In [6]:
# Import Statements
from fits_loader import FitsLoader
from fits_visualization_utils import interactive_plot
from fits_plotter import FitsPlotter

# Path of the FITS file to visualise -- put your FITS file here!
FITS_PATH = 'M51_HST (2).fits'

# Load the file; the loader hands back the header, the data and the WCS.
fits_loader = FitsLoader()
header, data, wcs = fits_loader.loading_FITS(FITS_PATH)

INFO: 
                Inconsistent SIP distortion information is present in the FITS header and the WCS object:
                SIP coefficients were detected, but CTYPE is missing a "-SIP" suffix.
                astropy.wcs is using the SIP distortion coefficients,
                therefore the coordinates calculated here might be incorrect.

                If you do not want to apply the SIP distortion coefficients,
                please remove the SIP coefficients from the FITS header or the
                WCS object.  As an example, if the image is already distortion-corrected
                (e.g., drizzled) then distortion components should not apply and the SIP
                coefficients should be removed.

                While the SIP distortion coefficients are being applied here, if that was indeed the intent,
                for consistency please append "-SIP" to the CTYPE in the FITS header or the WCS object.

                 [astropy.wcs.wcs]


In [8]:
# Below is the driver cell you can modify.
# A subsample factor of about 10 makes large FITS images load and render faster.
# Use the buttons to view the image statistics and the recommended scalings.
# You can change the title, axis labels and colour map, and save the figure as a PNG. If the file has a WCS, it should be reflected in the image axes.

In [7]:
import matplotlib.pyplot as plt
import numpy as np
import os
from astropy.io import fits
from astropy.wcs import WCS
from ipywidgets import interactive, FloatSlider, widgets, HBox, VBox
import warnings
from tqdm import tqdm
import time 

# Data range used to bound the display sliders. nanmin/nanmax skip NaN pixels
# (FITS images may contain them); plain min/max would return NaN and make
# every slider bound NaN. Computed once instead of once per slider argument.
data_min = float(np.nanmin(data))
data_max = float(np.nanmax(data))
SLIDER_PADDING = 20  # extra headroom below the data min / above the data max, in data units

vmin_slider = FloatSlider(min=data_min - SLIDER_PADDING, max=data_max, step=0.1, description='vmin', value=data_min, layout={'width': '80%'})
vmax_slider = FloatSlider(min=data_min, max=data_max + SLIDER_PADDING, step=0.1, description='vmax', value=data_max, layout={'width': '80%'})
# Subsampling is an integer factor (interactive_plot casts it with int()), so an
# integer slider avoids displaying fractional values.
subsample_slider = widgets.IntSlider(min=1, max=20, step=1, description='Subsample Factor', value=1, layout={'width': '80%'})
scaling_dropdown = widgets.Dropdown(
    options=['linear', 'log', 'sqrt', 'sinh', 'power:2', 'power:3'],
    value='linear',
    description='Scaling Method:'
)
x_axis_input = widgets.Text(value='Default X Coordinates', description='X Axis Label:')
y_axis_input = widgets.Text(value='Default Y Coordinates', description='Y Axis Label:')
title_input = widgets.Text(value='Title', description='Plot Title:')
annotation_text = widgets.Text(value='', description='Annotation Text:')
save_filename = widgets.Text(value='plot.png', description='Filename:')



def interactive_plot(vmin, vmax, subsample_factor, scaling, x_label, y_label, plot_title, cmap):
    """Redraw the FITS image using the current control values.

    Intended to be wrapped by ``ipywidgets.interactive``, which supplies each
    argument from a widget. A fresh ``FitsPlotter`` is built from the
    module-level ``data``/``wcs`` on every call and stored in the global
    ``fits_plotter`` so the save and scaling callbacks can reach the latest one.
    """
    global fits_plotter  # shared with save_plot / display_suggested_scalings

    fits_plotter = FitsPlotter(data, wcs=wcs)
    factor = int(subsample_factor)  # slider values may arrive as floats
    fits_plotter.plot_image(
        vmin,
        vmax,
        scaling,
        subsample_factor=factor,
        cmap=cmap,
        x_axis=x_label,
        y_axis=y_label,
        title=plot_title,
    )

def display_stats(_):
    """Button callback: show statistics via ``FitsPlotter.display_statistics``.

    Uses a throwaway plotter built from the module-level ``data``/``wcs``
    rather than the shared ``fits_plotter`` used for drawing. The button
    argument is ignored.
    """
    stats_plotter = FitsPlotter(data, wcs=wcs)
    stats_plotter.display_statistics()


    
# Button that saves the current figure; save_plot is attached to it once that callback is defined.
save_button = widgets.Button(description="Save Plot")

def save_plot(_):
    """Button callback: save the latest figure to the name in ``save_filename``.

    Looks up the module-level ``fits_plotter`` (set by ``interactive_plot``).
    If there is no plotter yet, or it has no figure, prints "No plot to save."
    An empty or whitespace-only filename falls back to ``plot.png``. Errors
    raised by the save itself are printed rather than raised, so a failed save
    does not break the widget callback. The button argument is ignored.
    """
    print("Saving plot...")
    # Explicit lookup instead of letting a NameError/AttributeError be caught
    # and misreported as a save failure.
    plotter = globals().get("fits_plotter")
    if plotter is None or getattr(plotter, "fig", None) is None:
        print("No plot to save.")
        return

    filename = save_filename.value.strip() or "plot.png"
    print("Filename:", filename)
    try:
        plotter.save_plot(filename)
    except Exception as e:  # report, don't raise: keeps the UI callback alive
        print("Error saving plot:", e)
    else:
        print("Plot saved successfully.")

# Attach save_plot as the click handler for save_button.
save_button.on_click(save_plot)

def display_suggested_scalings(button):
    """Button callback: print the scalings from ``fits_plotter.evaluate_scaling_methods()``.

    Each suggestion is unpacked as ``(method, vmin, vmax)``. Prints a message
    instead of raising when no plotter exists yet (``interactive_plot`` has not
    run), or when no suggestions are returned. The button argument is ignored.
    """
    # Only reads the shared plotter, so no ``global`` declaration is needed.
    plotter = globals().get("fits_plotter")
    if plotter is None:
        print("No plot available yet.")
        return

    suggested_scalings = plotter.evaluate_scaling_methods()
    if not suggested_scalings:
        print("No suggested scaling methods available.")
        return

    # NOTE(review): the header says "top three" but every returned entry is
    # printed -- confirm evaluate_scaling_methods returns at most three.
    print("Top three suggested scaling methods:")
    for rank, (method, suggested_vmin, suggested_vmax) in enumerate(suggested_scalings, start=1):
        print(f"{rank}. Scaling Method: {method}, Suggested vmin: {suggested_vmin}, Suggested vmax: {suggested_vmax}")

# Button that prints the suggested scaling methods (via display_suggested_scalings) when clicked.
display_scalings_button = widgets.Button(description="Display Scalings")
display_scalings_button.on_click(display_suggested_scalings)

# Colour-map selector passed to interactive_plot as ``cmap``.
cmap_dropdown = widgets.Dropdown(options=plt.colormaps(), value='gray_r', description='Color Map:')

# Wrap the plotting function with its controls; ipywidgets.interactive calls it
# with each widget's current value (keyword-matched) when a control changes.
# NOTE: this rebinds ``interactive_plot`` from the function to the widget that
# wraps it; the widget keeps its own reference to the function.
interactive_plot = interactive(
    interactive_plot,
    vmin=vmin_slider,
    vmax=vmax_slider,
    subsample_factor=subsample_slider,
    scaling=scaling_dropdown,
    x_label=x_axis_input,
    y_label=y_axis_input,
    plot_title=title_input,
    cmap=cmap_dropdown,
)

stats_button = widgets.Button(description="Display Statistics")
stats_button.on_click(display_stats)

# ``save_button`` was already created and bound to ``save_plot`` next to that
# callback. It is reused here: recreating it would orphan the earlier button
# and register ``save_plot`` on a second one.
VBox([interactive_plot, stats_button, display_scalings_button, save_filename, save_button])

VBox(children=(interactive(children=(FloatSlider(value=-0.6049584746360779, description='vmin', layout=Layout(…