In [4]:
import matplotlib.pyplot as plt
import numpy as np
import os
from astropy.io import fits
from astropy.wcs import WCS
from ipywidgets import interactive, FloatSlider, widgets, HBox, VBox
import warnings
from tqdm import tqdm
import time 

# Silence every RuntimeWarning for the rest of the session.
# Presumably this hides NumPy's warnings from log/sqrt of non-positive pixels
# in the scaling code -- TODO confirm.
# NOTE(review): the filter is global, so RuntimeWarnings from unrelated code
# are hidden as well.
warnings.filterwarnings("ignore", category=RuntimeWarning)

class FitsLoader:
    """Read one HDU of a FITS file and build a WCS object from its header."""

    # Accepted file-name endings (matched case-sensitively).
    VALID_SUFFIXES = ('.FIT', '.FITS', '.fit', '.fits')

    def __init__(self):
        pass

    @staticmethod
    def loading_FITS(string_filepath, extension=0):
        """Load the header, data and WCS of one HDU.

        Args:
            string_filepath: Path of the FITS file; must be a ``str`` ending
                in one of ``VALID_SUFFIXES``.
            extension: Integer index of the HDU to read. Negative values index
                from the end, as with ordinary Python sequences.

        Returns:
            ``(header, data, wcs)`` on success. On any validation or I/O
            failure the reason is printed and ``(None, None, None)`` is
            returned instead of raising.
        """
        try:
            # Check the type first: os.path.exists() raises TypeError for some
            # non-path objects (e.g. None), which would escape the handler below.
            if not isinstance(string_filepath, str):
                raise ValueError('Filepath must be a string')
            if not os.path.exists(string_filepath):
                raise FileNotFoundError("File does not exist")
            if not string_filepath.endswith(FitsLoader.VALID_SUFFIXES):
                raise ValueError("Must be of type .FIT, .FITS, .fit, .fits")
            with fits.open(string_filepath) as hdu:
                extension_total = len(hdu)
                # Explicit range check (not ``assert``, which is removed under
                # ``python -O``); the lower bound stops an IndexError for
                # indices below -len(hdu).
                if not -extension_total <= extension < extension_total:
                    raise ValueError('Extension not in range, the extension value must be less than {}'.format(extension_total))
                header = hdu[extension].header
                data = hdu[extension].data
                wcs = WCS(header)  # Extract WCS information
            return header, data, wcs
        except (ValueError, OSError) as msg:
            # OSError includes FileNotFoundError and other failures raised
            # while opening or reading the file.
            print(msg)
            return None, None, None

class FitsPlotter:
    """Plot a 2-D image array with a selectable intensity scaling.

    When a ``wcs`` object is supplied it is used as the Matplotlib projection
    of the axes; otherwise ordinary pixel axes are used.
    """

    def __init__(self, data, wcs=None):  # Add WCS as an argument
        self.data = data
        self.wcs = wcs  # Store WCS information
        self.fig = None
        self.ax = None
        self._colorbar = None  # Colorbar of the image currently drawn on self.ax

    @staticmethod
    def _equalize(values):
        """Return a histogram-equalised float copy of ``values``.

        Finite pixels are mapped through the normalised cumulative histogram
        (256 bins); non-finite pixels come back as NaN.
        """
        values = np.asarray(values, dtype=float)
        equalized = np.full(values.shape, np.nan)
        finite = np.isfinite(values)
        if not finite.any():
            return equalized
        samples = values[finite]
        hist, bins = np.histogram(samples, bins=256, range=(samples.min(), samples.max()))
        cdf = hist.cumsum()
        cdf_normalized = cdf / cdf.max()
        equalized[finite] = np.interp(samples, bins[:-1], cdf_normalized)
        return equalized

    @staticmethod
    def _scale(values, scaling):
        """Return ``values`` transformed by the named scaling method.

        Supported names: 'linear', 'log', 'sqrt', 'sinh', 'histogram' and
        'power:<exponent>'.

        Raises:
            ValueError: If ``scaling`` is not one of the supported names or
                the power exponent cannot be parsed.
        """
        # Invalid or overflowing pixels simply become NaN/inf; keep NumPy quiet.
        with np.errstate(all='ignore'):
            if scaling == 'linear':
                return values
            if scaling == 'log':
                return np.where(values > 0, np.log(values), 0)
            if scaling == 'sqrt':
                return np.sqrt(values)
            if scaling == 'sinh':
                return np.sinh(values)
            if scaling == 'histogram':
                # Must keep the image shape; raw np.histogram counts are 1-D
                # and cannot be displayed as an image.
                return FitsPlotter._equalize(values)
            if scaling.startswith('power'):
                exponent_text = scaling.partition(':')[2]
                try:
                    power = float(exponent_text)
                except ValueError:
                    raise ValueError('Invalid scaling method') from None
                return np.power(values, power)
        raise ValueError('Invalid scaling method')

    def plot_image(self, vmin, vmax, scaling='linear', subsample_factor=1, figsize=(8, 8), cmap='gray_r', x_axis='Default X Coordinates', y_axis='Default Y Coordinates', title='Title'):
        """Draw the (optionally subsampled and rescaled) image with a colorbar.

        Args:
            vmin, vmax: Colour limits passed to ``imshow``.
            scaling: Name of the intensity scaling (see ``_scale``).
            subsample_factor: Integer stride applied to both image axes.
            figsize: Figure size used when the figure is first created.
            cmap: Matplotlib colormap name.
            x_axis, y_axis, title: Axis labels and plot title.

        Raises:
            ValueError: If ``scaling`` is not a supported method.

        Does nothing when ``self.data`` is None. ``self.data`` is never
        modified by plotting.
        """
        if self.data is None:
            return

        # Subsample the data, then apply the requested scaling
        data_subsampled = self.data[::subsample_factor, ::subsample_factor]
        data_scaled = self._scale(data_subsampled, scaling)

        if self.fig is None or self.ax is None:
            # Create the axes once, directly with the WCS projection, instead
            # of stacking a second axes on top of a plain one.
            subplot_kw = {'projection': self.wcs} if self.wcs is not None else None
            self.fig, self.ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kw)

        # Drop the colorbar of a previous draw so colorbars do not pile up
        # when the same plotter redraws.
        previous_colorbar = getattr(self, '_colorbar', None)
        if previous_colorbar is not None:
            previous_colorbar.remove()
            self._colorbar = None
        self.ax.clear()

        # NOTE(review): the WCS is used unchanged although the image is
        # subsampled and rotated by 180 degrees here -- confirm the axis
        # coordinates are still what is intended.
        rotated_view = np.rot90(data_scaled, 2)
        # Slicing and rot90 return views of self.data in the linear case, so
        # mask on a floating-point copy: this keeps self.data intact and lets
        # integer images hold NaN.
        float_dtype = np.result_type(rotated_view.dtype, np.float32)
        data_rotated = np.array(rotated_view, dtype=float_dtype)
        data_rotated[data_rotated == 0] = np.nan  # zero-valued pixels are not drawn

        im = self.ax.imshow(data_rotated, cmap=cmap, vmin=vmin, vmax=vmax)
        self.ax.set_title(title)
        self.ax.set_xlabel(x_axis)
        self.ax.set_ylabel(y_axis)
        self._colorbar = self.fig.colorbar(im, ax=self.ax)
        plt.show()

    def evaluate_scaling_methods(self):
        """Rank 'linear', 'log', 'sqrt' and 'sinh' scalings for the image.

        The score of a non-linear method is ``std(data) / std(scaled data)``;
        linear always scores 1. A method whose score is not finite (constant,
        all-NaN or overflowing scaled data) is ranked last.

        Returns:
            A list of up to three ``(method, suggested_vmin, suggested_vmax)``
            tuples, best first, or None when there is no data.
        """
        if self.data is None:
            return None

        # Compute data statistics
        data_min = np.nanmin(self.data)
        data_max = np.nanmax(self.data)
        data_std = np.nanstd(self.data)

        # Linear scaling is always considered
        scaling_results = {
            'linear': {'score': 1, 'suggested_vmin': data_min, 'suggested_vmax': data_max},
        }
        for scaling_method in ('log', 'sqrt', 'sinh'):
            scaled_data = self._scale(self.data, scaling_method)
            scaled_min = np.nanmin(scaled_data)
            scaled_max = np.nanmax(scaled_data)
            with np.errstate(all='ignore'):
                score = data_std / np.nanstd(scaled_data)  # Higher score indicates better scaling
            if not np.isfinite(score):
                # NaN breaks the sort order and inf would wrongly rank a
                # degenerate scaling first.
                score = float('-inf')

            # Pad the suggested limits by 10% of the scaled range
            padding = 0.1 * (scaled_max - scaled_min)
            scaling_results[scaling_method] = {
                'score': score,
                'suggested_vmin': scaled_min - padding,
                'suggested_vmax': scaled_max + padding,
            }

        # Sort scaling methods based on their scores
        sorted_methods = sorted(scaling_results, key=lambda x: scaling_results[x]['score'], reverse=True)

        # Return top three scaling methods along with suggested vmin and vmax values
        return [(method, scaling_results[method]['suggested_vmin'], scaling_results[method]['suggested_vmax']) for method in sorted_methods[:3]]

    def display_statistics(self):
        """Print the NaN-ignoring mean, median and standard deviation."""
        if self.data is not None:
            print("Mean:", np.nanmean(self.data))
            print("Median:", np.nanmedian(self.data))
            print("Standard Deviation:", np.nanstd(self.data))
        else:
            print("No data available to compute statistics.")

    def add_annotation(self, x, y, text):
        """Annotate point (x, y) with ``text``; no-op until a plot exists."""
        if self.ax is not None:
            self.ax.annotate(text, xy=(x, y), xytext=(x + 50, y + 50),
                             arrowprops=dict(facecolor='red', shrink=0.05))

    def zoom(self, x1, x2, y1, y2):
        """Set the axis limits and request a redraw; no-op until a plot exists."""
        if self.ax is not None:
            self.ax.set_xlim(x1, x2)
            self.ax.set_ylim(y1, y2)
            self.fig.canvas.draw_idle()

    def equalize_histogram(self):
        """Replace ``self.data`` with its histogram-equalised version and plot it.

        Non-finite pixels are ignored when building the histogram and stay
        NaN in the result. If the image has no finite pixel at all a message
        is printed and nothing changes.
        """
        if self.data is None:
            return
        if not np.any(np.isfinite(self.data)):
            print("Histogram equalization cannot be performed: Data contains no finite values.")
            return
        self.data = self._equalize(self.data)
        self.plot_image(np.nanmin(self.data), np.nanmax(self.data))

    def save_plot(self, filename):
        """Save the current figure to ``filename``; no-op until a plot exists."""
        if self.fig is not None:
            self.fig.savefig(filename)

# Image shown by the interactive viewer below.
FITS_PATH = 'M51_HST (2).fits'

fits_loader = FitsLoader()
header, data, wcs = fits_loader.loading_FITS(FITS_PATH)  # Load FITS data and WCS information
if data is None:
    # loading_FITS reports failures by returning None; stop here with a clear
    # error instead of failing later inside np.nanmin(None).
    raise RuntimeError(f"Could not load image data from {FITS_PATH!r}")

# NaN-safe intensity range, so a NaN pixel cannot turn the slider bounds into NaN.
data_min = float(np.nanmin(data))
data_max = float(np.nanmax(data))

vmin_slider = FloatSlider(min=data_min-20, max=data_max, step=0.1, description='vmin', value=data_min, layout={'width': '80%'})
vmax_slider = FloatSlider(min=data_min, max=data_max+20, step=0.1, description='vmax', value=data_max, layout={'width': '80%'})
subsample_slider = FloatSlider(min=1, max=20, step=1, description='Subsample Factor', value=1, layout={'width': '80%'})

scaling_dropdown = widgets.Dropdown(
    options=['linear', 'log', 'sqrt', 'sinh', 'histogram', 'power:2', 'power:3'], 
    value='linear',
    description='Scaling Method:'
)
x_axis_input = widgets.Text(value='Default X Coordinates', description='X Axis Label:')
y_axis_input = widgets.Text(value='Default Y Coordinates', description='Y Axis Label:')
title_input = widgets.Text(value='Title', description='Plot Title:')
# The annotation and zoom coordinate widgets are currently disabled.
# annotation_x = FloatSlider(min=0, max=data.shape[1], step=1, description='X:', value=data.shape[1]//2)
# annotation_y = FloatSlider(min=0, max=data.shape[0], step=1, description='Y:', value=data.shape[0]//2)
annotation_text = widgets.Text(value='', description='Annotation Text:')
# zoom_x1 = FloatSlider(min=0, max=data.shape[1], step=1, description='X1:', value=0)
# zoom_x2 = FloatSlider(min=0, max=data.shape[1], step=1, description='X2:', value=data.shape[1])
# zoom_y1 = FloatSlider(min=0, max=data.shape[0], step=1, description='Y1:', value=0)
# zoom_y2 = FloatSlider(min=0, max=data.shape[0], step=1, description='Y2:', value=data.shape[0])
save_filename = widgets.Text(value='plot.png', description='Filename:')



def interactive_plot(vmin, vmax, subsample_factor, scaling, x_label, y_label, plot_title, cmap):
    """Widget callback: build a new plotter for the loaded image and draw it.

    The plotter is stored in the module-level ``fits_plotter`` so that the
    button callbacks can reach the most recent plot.
    """
    global fits_plotter  # shared with the button callbacks

    fits_plotter = FitsPlotter(data, wcs=wcs)

    # The slider delivers a float; the plotter uses the factor as a slice step.
    render_options = dict(
        subsample_factor=int(subsample_factor),
        cmap=cmap,
        x_axis=x_label,
        y_axis=y_label,
        title=plot_title,
    )
    fits_plotter.plot_image(vmin, vmax, scaling, **render_options)

def display_stats(_):
    """Button callback: print summary statistics of the loaded image."""
    FitsPlotter(data, wcs=wcs).display_statistics()

def add_annotation(_):
    """Button callback: annotate the most recent plot.

    The annotation goes onto the module-level ``fits_plotter`` (the plotter
    that actually owns the drawn axes); a freshly constructed plotter has no
    axes, so annotating it would do nothing. If no plot exists yet, or the
    ``annotation_x`` / ``annotation_y`` widgets are not defined, a message is
    printed instead of raising NameError.
    """
    namespace = globals()
    plotter = namespace.get('fits_plotter')
    if plotter is None or plotter.ax is None:
        print("No plot to annotate.")
        return

    x_widget = namespace.get('annotation_x')
    y_widget = namespace.get('annotation_y')
    if x_widget is None or y_widget is None:
        print("Annotation coordinate widgets are not defined.")
        return

    plotter.add_annotation(x_widget.value, y_widget.value, annotation_text.value)
    if plotter.fig is not None:
        plotter.fig.canvas.draw_idle()  # refresh so the new annotation shows

def zoom_plot(_):
    """Button callback: apply the zoom limits to the most recent plot.

    The limits are applied to the module-level ``fits_plotter`` (the plotter
    that owns the drawn axes); a freshly constructed plotter has no axes, so
    zooming it would do nothing. If no plot exists yet, or any of the
    ``zoom_x1`` / ``zoom_x2`` / ``zoom_y1`` / ``zoom_y2`` widgets is not
    defined, a message is printed instead of raising NameError.
    """
    namespace = globals()
    plotter = namespace.get('fits_plotter')
    if plotter is None or plotter.ax is None:
        print("No plot to zoom.")
        return

    limit_widgets = [namespace.get(name) for name in ('zoom_x1', 'zoom_x2', 'zoom_y1', 'zoom_y2')]
    if any(widget is None for widget in limit_widgets):
        print("Zoom limit widgets are not defined.")
        return

    plotter.zoom(*(widget.value for widget in limit_widgets))

def equalize_hist(_):
    """Button callback: histogram-equalise the loaded image on a new plotter."""
    FitsPlotter(data, wcs=wcs).equalize_histogram()
# Button intended to trigger save_plot.
# NOTE(review): save_button is rebound to a new Button further down this cell,
# and that later instance is the one placed in the displayed VBox.
save_button = widgets.Button(description="Save Plot")

def save_plot(_):
    """Button callback: write the most recent plot to the chosen file.

    Reads the module-level ``fits_plotter`` and ``save_filename``; an empty
    filename falls back to "plot.png". Progress and any error are printed.
    """
    try:
        print("Saving plot...")
        if fits_plotter.fig is None:
            print("No plot to save.")
            return
        target = save_filename.value or "plot.png"
        print("Filename:", target)
        fits_plotter.save_plot(target)
        print("Plot saved successfully.")
    except Exception as err:
        print("Error saving plot:", err)

# Register the click handler once; the repeated registration was redundant.
save_button.on_click(save_plot)


def display_suggested_scalings(button):
    """Button callback: print the scaling methods suggested for the current plotter.

    Reads the module-level ``fits_plotter``; prints nothing when it returns
    no suggestions.
    """
    suggestions = fits_plotter.evaluate_scaling_methods()
    if not suggestions:
        return

    print("Top three suggested scaling methods:")
    for rank, (method, low, high) in enumerate(suggestions, start=1):
        print(f"{rank}. Scaling Method: {method}, Suggested vmin: {low}, Suggested vmax: {high}")

# Button whose click handler is display_suggested_scalings.
display_scalings_button = widgets.Button(description="Display Scalings")
display_scalings_button.on_click(display_suggested_scalings)


# NOTE: this rebinds the name ``interactive_plot`` from the callback function
# to the interactive widget built around it.
interactive_plot = interactive(
    interactive_plot,
    vmin=vmin_slider,
    vmax=vmax_slider,
    subsample_factor=subsample_slider,
    scaling=scaling_dropdown,
    x_label=x_axis_input,
    y_label=y_axis_input,
    plot_title=title_input,
    cmap=widgets.Dropdown(options=plt.colormaps(), value='gray_r', description='Color Map:'),
)
stats_button = widgets.Button(description="Display Statistics")
annotation_button = widgets.Button(description="Add Annotation")
# zoom_button = widgets.Button(description="Zoom")
# equalize_hist_button = widgets.Button(description="Equalize Histogram")
save_button = widgets.Button(description="Save Plot")
stats_button.on_click(display_stats)
annotation_button.on_click(add_annotation)
# The zoom and equalize buttons are disabled above, so their handlers must not
# be registered here: on a fresh kernel these two lines raised NameError.
# zoom_button.on_click(zoom_plot)
# equalize_hist_button.on_click(equalize_hist)
save_button.on_click(save_plot)

# VBox([interactive_plot, stats_button, annotation_button, HBox([annotation_x, annotation_y, annotation_text]), zoom_button, HBox([zoom_x1, zoom_x2, zoom_y1, zoom_y2]), equalize_hist_button, save_filename, save_button])
VBox([interactive_plot, stats_button, display_scalings_button, save_filename, save_button])


INFO: 
                Inconsistent SIP distortion information is present in the FITS header and the WCS object:
                SIP coefficients were detected, but CTYPE is missing a "-SIP" suffix.
                astropy.wcs is using the SIP distortion coefficients,
                therefore the coordinates calculated here might be incorrect.

                If you do not want to apply the SIP distortion coefficients,
                please remove the SIP coefficients from the FITS header or the
                WCS object.  As an example, if the image is already distortion-corrected
                (e.g., drizzled) then distortion components should not apply and the SIP
                coefficients should be removed.

                While the SIP distortion coefficients are being applied here, if that was indeed the intent,
                for consistency please append "-SIP" to the CTYPE in the FITS header or the WCS object.

                 [astropy.wcs.wcs]




VBox(children=(interactive(children=(FloatSlider(value=-0.6049584746360779, description='vmin', layout=Layout(…

Mean: 0.13744055
Median: 0.08879564
Standard Deviation: 0.3868615
Top three suggested scaling methods:
1. Scaling Method: sqrt, Suggested vmin: -1.655286095151678, Suggested vmax: 18.222461891174316
2. Scaling Method: linear, Suggested vmin: -0.6049584746360779, Suggested vmax: 274.4317932128906
3. Scaling Method: log, Suggested vmin: -15.370337295532227, Suggested vmax: 7.522433567047119
Saving plot...
Filename: plot.png
Plot saved successfully.
Mean: 0.13744055


In [7]:
# Automatically runs with this fits file: 

# header, data, wcs = fits_loader.loading_FITS('M51_HST (2).fits')  