In [1]:
import wave
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt

import scipy
from scipy import integrate
import scipy.optimize
import astropy
import astropy.io as io
import astropy.nddata
import astropy.constants as const
import astropy.units as units
import specutils
from tqdm import tqdm

import inspect

import FunctionLib as FL

In [None]:
class SpectralLineFitter:
    """
    Fit Gaussian components to spectral lines in a spectrum.

    The fitter works on ``spectrum.processing_flux``: every successful
    component fit is subtracted from it, so the next fit sees the residual.

    Attributes
    ----------
    spectrum: Spectrum_1d
        The spectrum being fitted. The methods below read its
        ``observed_wavelengths``, ``restframe_wavelengths``, ``processing_flux``
        and ``observed_flux_lambda`` attributes.
    line_restframe_wavelengths: list of astropy.units.Quantity
        The rest-frame wavelengths of the spectral lines to fit.
    max_components: int
        The maximum number of Gaussian components to fit to each spectral line.
    max_iterations: int
        Passed to ``scipy.optimize.curve_fit`` as ``maxfev``.
    fit_results: list of dict
        The per-component fit results of the most recent
        ``iterative_gaussian_fitting`` call.

    Methods
    -------
    gaussian(x, amplitude, mean, stddev)
    fit_single_gaussian()
    check_line(line_restframe_wavelength, mean_fit, tolerance)
    plot_fit_result(fit_result, component_index, is_residual)
    print_fit_summary(fit_result, component_index)
    iterative_gaussian_fitting(line_restframe_wavelength, tolerance, plot_results)
    get_line_components()
    plot_final_decomposition(line_restframe_wavelength)
    """

    def __init__(self, spectrum, line_restframe_wavelengths, max_components=8, max_iterations=100000):
        """
        Initializes the SpectralLineFitter with a spectrum and spectral lines to fit.

        Parameters
        ----------
        spectrum: Spectrum_1d
            An instance of the Spectrum_1d class containing the observed wavelengths and fluxes.
        line_restframe_wavelengths: astropy.units.Quantity or list of astropy.units.Quantity
            The rest-frame wavelengths of the spectral lines to fit. A single
            Quantity is wrapped into a one-element list.
        max_components: int, optional
            The maximum number of Gaussian components to fit to each spectral line (default is 8).
        max_iterations: int, optional
            The maximum number of function evaluations for each fit (default is 100000).

        Raises
        ------
        TypeError
            If ``spectrum`` is not a Spectrum_1d, or ``line_restframe_wavelengths``
            is neither a Quantity nor a list of Quantity objects.
        """

        if not isinstance(spectrum, Spectrum_1d):
            raise TypeError("spectrum must be an instance of Spectrum_1d.")

        self.spectrum = spectrum

        if isinstance(line_restframe_wavelengths, astropy.units.Quantity):
            self.line_restframe_wavelengths = [line_restframe_wavelengths]
        elif isinstance(line_restframe_wavelengths, list) and all(isinstance(w, astropy.units.Quantity) for w in line_restframe_wavelengths):
            self.line_restframe_wavelengths = line_restframe_wavelengths
        else:
            raise TypeError("line_restframe_wavelengths must be an astropy.units.Quantity or a list of astropy.units.Quantity objects.")

        self.max_components = max_components
        self.max_iterations = max_iterations
        self.fit_results = []

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _format_quantity(value, spec, fallback_unit):
        """Format ``value`` with ``spec``, printing its unit exactly once.

        A Quantity already renders its own unit when formatted, so the
        fallback unit is appended only for plain numbers.
        """
        if isinstance(value, astropy.units.Quantity):
            return format(value, spec)
        return f"{format(value, spec)} {fallback_unit}"

    @staticmethod
    def _padded_flux_limits(flux):
        """Return (lower, upper) y-limits padded by 10% of the peak magnitude."""
        flux_max = np.nanmax(flux)
        # abs() keeps the padding positive when the whole array is negative.
        flux_margin = 0.1 * abs(flux_max)
        return np.nanmin(flux) - flux_margin, flux_max + flux_margin

    @staticmethod
    def _fit_succeeded(fit_result, required_keys, message):
        """Validate ``fit_result`` and report whether it describes a successful fit.

        A result explicitly flagged ``success == False`` only carries an
        ``'error'`` entry, so the required keys are not demanded from it.
        """
        if not isinstance(fit_result, dict):
            raise TypeError("fit_result must be a dictionary containing the fit results.")

        if 'success' in fit_result and not fit_result['success']:
            return False

        if any(key not in fit_result for key in required_keys):
            raise ValueError(message)

        return bool(fit_result['success'])

    def _plot_residual_spectrum(self, component_index):
        """Plot ``processing_flux`` as it stands after subtracting a component."""
        # Same Angstrom-converted axis as plot_fit_result, matching the axis label.
        wavelengths = self.spectrum.observed_wavelengths.convert_unit_to(astropy.units.AA).data
        residual_flux = self.spectrum.processing_flux.data

        fig, ax = plt.subplots(figsize=(20, 10))
        ax.plot(wavelengths, residual_flux, label=f"Residual Spectrum after Component {component_index + 1}", color='blue', alpha=0.5)
        ax.set_xlabel("Wavelength (Angstrom)", fontsize=14)
        ax.set_ylabel("Flux (erg/cm^2/s/Angstrom)", fontsize=14)
        ax.set_title(f"Residual Spectrum after Component {component_index + 1} Fitting", fontsize=16)
        ax.legend()
        ax.set_xlim(np.nanmin(wavelengths), np.nanmax(wavelengths))
        ax.set_ylim(*self._padded_flux_limits(residual_flux))
        ax.grid()
        plt.show()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def gaussian(self, x, amplitude, mean, stddev):
        """
        Gaussian function for fitting.

        Parameters
        ----------
        x : array-like, astropy.units.Quantity or astropy.nddata.NDDataArray
            The independent variable (wavelengths).
        amplitude : float
            The height of the Gaussian peak.
        mean : float
            The position of the center of the Gaussian.
        stddev : float
            The standard deviation (width) of the Gaussian.

        Returns
        -------
        array-like
            The values of the Gaussian function at x. For Quantity and
            NDDataArray inputs the result is wrapped in the same container
            type and carries the unit of ``x``.
        """
        if isinstance(x, astropy.nddata.NDDataArray):
            # Read the unit BEFORE unwrapping: the bare data array has no `.unit`.
            x_unit = x.unit
            x_values = x.data
            return astropy.nddata.NDDataArray(
                data=amplitude * np.exp(-0.5 * ((x_values - mean) / stddev) ** 2),
                unit=x_unit
            )
        elif isinstance(x, astropy.units.Quantity):
            return amplitude * np.exp(-0.5 * ((x.value - mean) / stddev) ** 2) * x.unit
        else:
            return amplitude * np.exp(-0.5 * ((x - mean) / stddev) ** 2)

    def fit_single_gaussian(self):
        """
        Fits a single Gaussian to ``spectrum.processing_flux``.

        The fit is performed against the observed wavelengths converted to
        Angstrom; the initial guess is the brightest finite sample.

        Returns
        -------
        dict
            On success: ``'success'`` (True), ``'parameters'`` (amplitude, mean
            and stddev as Quantities; mean and stddev in Angstrom),
            ``'fitted_curve'`` (NDDataArray evaluated on every wavelength
            sample), ``'covariance'``, ``'integrated_flux'`` and
            ``'integration_error'``.
            On failure: ``'success'`` (False) and ``'error'`` (message string).
        """

        try:
            angstrom = astropy.units.AA
            flux_unit = self.spectrum.processing_flux.unit

            # Extract observed wavelengths (in Angstrom) and fluxes
            obs_wavelengths = np.asarray(
                self.spectrum.observed_wavelengths.convert_unit_to(angstrom).data, dtype=float
            )
            obs_flux_lambda = np.asarray(self.spectrum.processing_flux.data, dtype=float)

            # curve_fit rejects NaN/inf input, so fit on the finite samples only.
            finite = np.isfinite(obs_wavelengths) & np.isfinite(obs_flux_lambda)
            if not finite.any():
                raise ValueError("spectrum contains no finite (wavelength, flux) samples to fit.")
            fit_wavelengths = obs_wavelengths[finite]
            fit_flux = obs_flux_lambda[finite]

            # Initial guess: the peak sample. The centre guess must be the
            # WAVELENGTH at the peak, not its array index.
            peak_index = np.argmax(fit_flux)
            amplitude_guess = fit_flux[peak_index]
            mean_guess = fit_wavelengths[peak_index]
            stddev_guess = 10.0  # Angstrom; initial guess for the width, can be adjusted

            initial_guess = [amplitude_guess, mean_guess, stddev_guess]

            # Fit the Gaussian using scipy.optimize.curve_fit
            popt, pcov = scipy.optimize.curve_fit(self.gaussian,
                                                  fit_wavelengths,
                                                  fit_flux,
                                                  p0=initial_guess,
                                                  maxfev=self.max_iterations)

            # Evaluate on the full grid so the curve lines up with processing_flux.
            y_fit = self.gaussian(obs_wavelengths, *popt)

            lower = fit_wavelengths.min()
            upper = fit_wavelengths.max()
            # Tell quad where the peak is: adaptive sampling over a wide range
            # can otherwise step over a narrow Gaussian entirely.
            quad_options = {}
            if lower < popt[1] < upper:
                quad_options['points'] = [popt[1]]
            integral_value, integral_abserr = scipy.integrate.quad(
                self.gaussian, lower, upper, args=tuple(popt), **quad_options
            )

            # The fit ran in Angstrom, so wavelength-like results carry Angstrom.
            return {
                    'success': True,
                    'parameters': {
                        'amplitude': popt[0] * flux_unit,
                        'mean': popt[1] * angstrom,
                        'stddev': popt[2] * angstrom,
                    },
                    'fitted_curve': astropy.nddata.NDDataArray(
                        data=y_fit,
                        unit=flux_unit
                    ),
                    'covariance': pcov,
                    'integrated_flux': integral_value * flux_unit * angstrom,
                    'integration_error': integral_abserr * flux_unit * angstrom
                }
        except Exception as e:
            # Best-effort by design: callers branch on 'success' instead of catching.
            return {
                'success': False,
                'error': str(e)
            }

    def check_line(self, line_restframe_wavelength, mean_fit, tolerance=10* astropy.units.AA):
        """
        Checks if the line_restframe_wavelength is within the tolerance of the mean_fit.

        Parameters
        ----------
        line_restframe_wavelength : astropy.units.Quantity
            The rest-frame wavelength of the spectral line.
        mean_fit : astropy.units.Quantity
            The mean wavelength from the Gaussian fit.
        tolerance : astropy.units.Quantity, optional
            The tolerance range for checking (default is 10 * astropy.units.AA).

        Returns
        -------
        bool
            True if the line is within the tolerance, False otherwise.

        Raises
        ------
        TypeError
            If any argument is not an astropy.units.Quantity.
        """
        if not isinstance(line_restframe_wavelength, astropy.units.Quantity):
            raise TypeError("line_restframe_wavelength must be an astropy.units.Quantity object.")
        if not isinstance(mean_fit, astropy.units.Quantity):
            raise TypeError("mean_fit must be an astropy.units.Quantity object.")
        if not isinstance(tolerance, astropy.units.Quantity):
            raise TypeError("tolerance must be an astropy.units.Quantity object.")

        return abs(line_restframe_wavelength - mean_fit) <= tolerance

    def plot_fit_result(self, fit_result, component_index=0, is_residual=False):
        """
        Plots the current processing flux with the fitted Gaussian overlaid.

        Parameters
        ----------
        fit_result : dict
            The fit result dictionary, should be generated by fit_single_gaussian.
            For a failed fit only the spectrum is drawn.
        component_index : int, optional
            The index of the component to plot (default is 0).
        is_residual : bool, optional
            If True, label the spectrum as a residual and leave the y-limits
            automatic (default is False).

        Raises
        ------
        TypeError
            If ``fit_result`` is not a dict.
        ValueError
            If a non-failed ``fit_result`` lacks 'fitted_curve' or 'parameters'.
        """
        succeeded = self._fit_succeeded(
            fit_result,
            ('fitted_curve', 'parameters'),
            "fit_result must contain 'fitted_curve' and 'parameters' keys."
        )

        obs_wavelengths = self.spectrum.observed_wavelengths.convert_unit_to(astropy.units.AA).data
        obs_flux_lambda = self.spectrum.processing_flux.data

        fig, ax = plt.subplots(figsize=(20, 10))

        if is_residual:
            ax.plot(obs_wavelengths, obs_flux_lambda, label=f"Residual Spectrum {component_index}", color='blue', alpha=0.5)
            title = f'Component {component_index} - Gaussian Fit Residuals'
        else:
            ax.plot(obs_wavelengths, obs_flux_lambda, label=f"Observed Spectrum {component_index}", color='blue', alpha=0.5)
            title = f'Component {component_index} - Gaussian Fit Result'

        if succeeded:
            ax.plot(obs_wavelengths, fit_result['fitted_curve'].data, label=f"Fitted Curve {component_index}", color='red', alpha=0.7)

        ax.set_xlabel("Wavelength (Angstrom)", fontsize=14)
        ax.set_ylabel("Flux (erg/cm^2/s/Angstrom)", fontsize=14)
        ax.set_title(title, fontsize=16)
        ax.legend()
        ax.set_xlim(np.nanmin(obs_wavelengths), np.nanmax(obs_wavelengths))

        if not is_residual:
            ax.set_ylim(*self._padded_flux_limits(obs_flux_lambda))

        ax.grid()
        plt.show()

    def print_fit_summary(self, fit_result, component_index=0):
        """
        Prints a summary of the fit results for a specific component.

        Parameters
        ----------
        fit_result : dict
            The fit result dictionary, should be generated by fit_single_gaussian.
            For a failed fit the stored error message is printed instead.
        component_index : int, optional
            The index of the component to print the summary for (default is 0).

        Raises
        ------
        TypeError
            If ``fit_result`` is not a dict.
        ValueError
            If a non-failed ``fit_result`` lacks 'parameters' or 'covariance'.
        """
        succeeded = self._fit_succeeded(
            fit_result,
            ('parameters', 'covariance'),
            "fit_result must contain 'parameters' and 'covariance' keys."
        )

        if succeeded:
            params = fit_result['parameters']
            wavelength_unit = self.spectrum.restframe_wavelengths.unit
            integral_unit = self.spectrum.processing_flux.unit * wavelength_unit

            # NOTE(review): the label says "Rest-frame" but the mean comes from a
            # fit against spectrum.observed_wavelengths — confirm which frame is meant.
            print(f"\n{'='*60}")
            print(f"Component {component_index} Fit Parameters:")
            print(f"Amplitude: {params['amplitude']:.3e}")
            print(f"Mean (Rest-frame Wavelength): {self._format_quantity(params['mean'], '.3f', wavelength_unit)}")
            print(f"Standard Deviation (Width): {self._format_quantity(params['stddev'], '.3f', wavelength_unit)}")
            print(f"Integrated Flux: {self._format_quantity(fit_result['integrated_flux'], '.3e', integral_unit)}")
            print(f"Integration Error: {self._format_quantity(fit_result['integration_error'], '.3e', integral_unit)}")
            print(f"Covariance Matrix:\n{fit_result['covariance']}")
            print(f"{'='*60}\n")

        else:
            print(f"\n{'='*60}")
            print(f"Component {component_index} Fit Failed: {fit_result.get('error', 'unknown error')}")
            print(f"{'='*60}\n")

    def iterative_gaussian_fitting(self, line_restframe_wavelength, tolerance=10 * astropy.units.AA, plot_results=True):
        """
        Iteratively fits Gaussian components, subtracting each one from the
        processing flux, until a fit fails or ``max_components`` is reached.

        Side effects: ``self.fit_results`` is reset and refilled, and
        ``self.spectrum.processing_flux`` is replaced by the residual after
        every successful component, so it is left modified when this returns.

        Parameters
        ----------
        line_restframe_wavelength : astropy.units.Quantity
            The rest-frame wavelength of the spectral line to fit.
        tolerance : astropy.units.Quantity, optional
            The tolerance range for checking the fit (default is 10 * astropy.units.AA).
        plot_results : bool, optional
            If True, plots the fit results (default is True).

        Returns
        -------
        tuple
            A tuple containing:
            - fit_result: list of dict
                The fit results for each component, including the fitted parameters and the covariance matrix.
            - line_integrated_flux: astropy.units.Quantity
                The summed integrated flux of the components whose fitted mean
                lies within ``tolerance`` of ``line_restframe_wavelength``.

        Raises
        ------
        TypeError
            If ``line_restframe_wavelength`` or ``tolerance`` is not a Quantity.
        """
        # Validate once, up front, instead of on every loop iteration.
        if not isinstance(line_restframe_wavelength, astropy.units.Quantity):
            raise TypeError("line_restframe_wavelength must be an astropy.units.Quantity object.")

        self.fit_results = []  # Reset fit results for each line
        line_integrated_flux = 0 * self.spectrum.processing_flux.unit * self.spectrum.restframe_wavelengths.unit

        for component_index in range(self.max_components):

            print(f"\n{'='*60}")
            print(f"Fitting Component {component_index + 1} for Line at {line_restframe_wavelength:.3f}")
            print(f"{'='*60}")

            fit_result = self.fit_single_gaussian()

            if not fit_result['success']:
                print(f"Component {component_index + 1} fit failed: {fit_result['error']}")
                break

            fit_result['component_index'] = component_index
            fit_result['line_restframe_wavelength'] = line_restframe_wavelength
            fit_result['is_within_tolerance'] = self.check_line(
                line_restframe_wavelength,
                fit_result['parameters']['mean'],
                tolerance=tolerance
            )

            # NOTE(review): only components matched to the line are counted
            # towards its flux — confirm this is the intended definition.
            if fit_result['is_within_tolerance']:
                line_integrated_flux = line_integrated_flux + fit_result['integrated_flux']

            self.fit_results.append(fit_result)

            self.print_fit_summary(fit_result, component_index)

            if plot_results:
                self.plot_fit_result(fit_result, component_index, is_residual=component_index > 0)

            # Subtract the component so the next iteration fits the residual.
            self.spectrum.processing_flux = astropy.nddata.NDDataArray(
                data=self.spectrum.processing_flux.data - fit_result['fitted_curve'].data,
                unit=self.spectrum.processing_flux.unit
            )

            if plot_results and component_index < self.max_components - 1:
                self._plot_residual_spectrum(component_index)

        if len(self.fit_results) == 0:
            print("No successful fits were made.")
            return [], line_integrated_flux

        return self.fit_results, line_integrated_flux

    def get_line_components(self):
        """
        Returns the first fitted component flagged as within tolerance of the line.

        Returns
        -------
        Dict or None
            A dictionary with the component index, line wavelength, parameters,
            integrated flux and integration error of the first matching
            component, or None if no stored fit matches.
        """

        for fit_result in self.fit_results:
            if fit_result.get('is_within_tolerance', True):
                return {
                    'component_index': fit_result['component_index'],
                    'line_restframe_wavelength': fit_result['line_restframe_wavelength'],
                    'parameters': fit_result['parameters'],
                    'integrated_flux': fit_result['integrated_flux'],
                    'integration_error': fit_result['integration_error']
                }
        return None

    def plot_final_decomposition(self, line_restframe_wavelength):
        """
        Plots the final decomposition of the spectral line with all fitted components.

        Each component is drawn from its stored ``'fitted_curve'`` samples,
        which were evaluated on the spectrum's own wavelength grid.

        Parameters
        ----------
        line_restframe_wavelength : astropy.units.Quantity
            The rest-frame wavelength of the spectral line to plot (used for
            the title and the legend label of matching components).
        """
        rest_wavelengths = self.spectrum.restframe_wavelengths.data
        observed_flux = self.spectrum.observed_flux_lambda.data
        line_label = self._format_quantity(
            line_restframe_wavelength, '.3f', self.spectrum.restframe_wavelengths.unit
        )

        fig, ax = plt.subplots(figsize=(20, 10))
        ax.plot(rest_wavelengths, observed_flux, label="Observed Spectrum", color='blue', alpha=0.5)

        colors = ['red', 'blue', 'green', 'orange', 'purple']

        for i, fit_result in enumerate(self.fit_results):
            if not fit_result['success']:
                continue

            # The stored parameters are Quantities and cannot be fed back into
            # gaussian() with a unitless wavelength array, so reuse the stored
            # curve. It is plotted sample-by-sample against the same axis as
            # the observed flux.
            # NOTE(review): assumes restframe_wavelengths and the fitted curve
            # share one pixel grid — confirm against Spectrum_1d.
            component_flux = fit_result['fitted_curve'].data

            label = f"Component {fit_result['component_index']}"

            if fit_result['is_within_tolerance']:
                label += f'(Line at {line_label})'

            ax.plot(
                rest_wavelengths,
                component_flux,
                label=label,
                color=colors[i % len(colors)],
                alpha=0.7
            )

        ax.set_xlabel("Wavelength (Angstrom)", fontsize=14)
        ax.set_ylabel("Flux (erg/cm^2/s/Angstrom)", fontsize=14)
        ax.set_title(f"Final Decomposition for Line at {line_label}", fontsize=16)
        ax.legend()
        ax.set_xlim(np.nanmin(rest_wavelengths), np.nanmax(rest_wavelengths))
        ax.set_ylim(*self._padded_flux_limits(observed_flux))
        ax.grid()
        plt.show()




In [None]:

# NOTE(review): `fitter` is not created in any visible cell — presumably a
# SpectralLineFitter built elsewhere; confirm this cell runs on a fresh kernel.
# Each successful component is subtracted from fitter.spectrum.processing_flux,
# so re-running this cell fits whatever residual the previous run left behind.
fitter.iterative_gaussian_fitting(4863.0 * units.Angstrom, tolerance=10 * units.AA, plot_results=True)

In [None]:
# Per-component result dicts stored by iterative_gaussian_fitting
# (the list is reset at the start of every call to that method).
fitter.fit_results