In [None]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from astropy import units as u
from astropy.constants import c
from specutils import Spectrum1D, SpectralRegion
from specutils.analysis import line_flux, equivalent_width
from specutils.manipulation import extract_region
from specutils.fitting import fit_lines
from astropy.modeling import models
import warnings
warnings.filterwarnings('ignore')  # Suppress warnings for cleaner output

# Rest-frame wavelengths (in Angstrom) of the two Balmer lines analysed below.
# NOTE(review): 6562.8 Å and 4861.3 Å are the commonly quoted *air*
# wavelengths; JWST pipeline spectra are usually on a vacuum scale
# (~6564.6 Å / ~4862.7 Å) — confirm which scale your data uses.
HALPHA_REST = u.Quantity(6562.8, unit=u.AA)
HBETA_REST = u.Quantity(4861.3, unit=u.AA)

def _to_flambda(flux, wavelength):
    """Convert a flux density to erg / s / cm^2 / Å.

    Going from a per-frequency unit such as Jy to a per-wavelength unit is
    not a constant rescaling (F_lambda = F_nu * c / lambda**2). Astropy
    therefore only allows it through the ``spectral_density`` equivalency,
    evaluated at each sample's wavelength. A bare ``.to()`` raises
    ``UnitConversionError``.

    Parameters
    ----------
    flux : astropy.units.Quantity
        Flux density values (e.g. in Jy).
    wavelength : astropy.units.Quantity
        Spectral coordinate of each flux sample, same length as ``flux``.

    Returns
    -------
    astropy.units.Quantity
        ``flux`` expressed in erg / s / cm^2 / Å.
    """
    return flux.to(u.erg / u.s / u.cm**2 / u.AA,
                   equivalencies=u.spectral_density(wavelength))


def load_jwst_spectrum(fits_file):
    """Load a 1-D spectrum from a JWST FITS file.

    Extension 1 is read in one of two layouts:

    * a table with ``WAVELENGTH`` and ``FLUX`` columns, assumed to be in
      microns and Jy respectively; or, if those columns are missing,
    * an image (1-D, or multi-dimensional, in which case the first row is
      used). Its linear wavelength solution comes from CRVAL1,
      CDELT1 (or CD1_1) and CRPIX1, assumed to be in microns, and its flux
      is assumed to be in Jy.

    NOTE(review): the units are assumed, not read from the file. Confirm
    them against the TUNIT / BUNIT keywords of your data product.

    Parameters
    ----------
    fits_file : str or path-like
        Path to the FITS file.

    Returns
    -------
    Spectrum1D or None
        Spectrum with its spectral axis in Å and flux in erg / s / cm^2 / Å.
        Returns ``None`` if extension 1 is absent, has no data, or is a
        table without the expected columns.
    """
    print(f"Loading spectrum from {fits_file}...")
    with fits.open(fits_file) as hdul:
        # Print basic info about the file structure
        print(f"FITS file contains {len(hdul)} extensions")
        for i, hdu in enumerate(hdul):
            print(f"Extension {i}: {hdu.name} - {hdu.__class__.__name__}")

        # Both layouts read extension 1. Bail out early (the caller checks for
        # None) instead of raising IndexError/TypeError further down.
        if len(hdul) < 2 or hdul[1].data is None:
            print("No data found in extension 1; cannot extract a spectrum.")
            return None

        data = hdul[1].data
        header = hdul[1].header

        try:
            # Table layout. Only the column lookups belong in the try. A
            # missing column raises KeyError (FITS table), ValueError (numpy
            # structured array) or IndexError (plain image array).
            wavelength = data['WAVELENGTH'] * u.um
            flux = data['FLUX'] * u.Jy
        except (KeyError, IndexError, ValueError) as e:
            print(f"Error with standard extraction: {e}")
            print("Attempting alternative extraction method...")

            values = np.asarray(data)
            if values.dtype.names is not None:
                # A table, but with different column names than expected;
                # there is no sensible image interpretation of it.
                print(f"Unrecognised table columns: {values.dtype.names}")
                return None
            # 1-D image: use it as-is; otherwise take the first row.
            flux_values = values if values.ndim == 1 else values[0]

            # Linear FITS WCS: 1-based pixel p maps to
            # CRVAL1 + CDELT1 * (p - CRPIX1). With CRPIX1 absent (default 1),
            # CRVAL1 is the value at the first pixel. The pixel count comes
            # from the flux itself so both axes always have the same length.
            crval1 = header.get('CRVAL1', 0)
            cdelt1 = header.get('CDELT1', header.get('CD1_1', 1))
            crpix1 = header.get('CRPIX1', 1)
            pixels = np.arange(1, flux_values.size + 1)
            wavelength = (crval1 + cdelt1 * (pixels - crpix1)) * u.um
            flux = flux_values * u.Jy

        # Convert wavelength to Angstroms for easier comparison with
        # rest-frame values, then convert flux on that wavelength grid.
        wavelength = wavelength.to(u.AA)
        flux = _to_flambda(flux, wavelength)

        return Spectrum1D(spectral_axis=wavelength, flux=flux)

def apply_redshift_correction(spectrum, redshift):
    """Shift an observed spectrum to the rest frame.

    Wavelengths are divided by ``(1 + z)``. Each rest-frame bin is therefore
    ``(1 + z)`` times narrower, so the per-wavelength flux density is
    multiplied by ``(1 + z)``. That keeps integrated quantities such as line
    fluxes equal to their observed-frame values.

    Parameters
    ----------
    spectrum : Spectrum1D
        Observed spectrum. Its flux is assumed to be per unit wavelength
        (e.g. erg / s / cm^2 / Å, as produced by ``load_jwst_spectrum``). A
        per-frequency flux would need the inverse scaling.
    redshift : float
        Source redshift z; must be greater than -1.

    Returns
    -------
    Spectrum1D
        New rest-frame spectrum; the input is not modified. The input's mask
        is carried over; uncertainties are not.

    Raises
    ------
    ValueError
        If ``redshift <= -1`` (the (1 + z) factor would be zero or negative).
    """
    if redshift <= -1:
        raise ValueError(f"redshift must be > -1, got {redshift}")
    print(f"Applying redshift correction z={redshift}...")

    stretch = 1 + redshift
    return Spectrum1D(spectral_axis=spectrum.spectral_axis / stretch,
                      flux=spectrum.flux * stretch,
                      mask=spectrum.mask)

def extract_emission_lines(spectrum, line_center, line_width=50*u.AA, plot=True):
    """Measure a single emission line in a spectrum.

    The procedure has five steps:

    1. Extract the window ``line_center ± line_width`` and drop non-finite
       pixels.
    2. Estimate a flat continuum as the median flux of the outer half of
       the window on both sides of the line. If that part is empty, use
       the median of the two edge pixels instead.
    3. Measure the continuum-subtracted line flux and the equivalent width
       relative to that continuum.
    4. Fit a Gaussian to the continuum-subtracted line.
    5. Optionally plot the result.

    Parameters
    ----------
    spectrum : Spectrum1D
        Spectrum to analyse; its spectral axis is expected in Å.
    line_center : astropy.units.Quantity
        Expected line wavelength.
    line_width : astropy.units.Quantity
        Half-width of the analysis window.
    plot : bool
        If True, show the window with the continuum and Gaussian fit.

    Returns
    -------
    flux_value : Quantity or None
        Continuum-subtracted integrated line flux over the window.
    ew_value : Quantity or None
        Equivalent width relative to the continuum estimate, using specutils'
        convention ∫(1 - F/F_cont) dλ, so emission lines come out negative.
    g_fit : astropy Gaussian1D or None
        Gaussian fitted to the continuum-subtracted line. It is None if the
        fit failed; the flux and EW are still returned in that case.

    All three values are None if the window has no usable data or the
    flux/EW measurement fails.
    """
    # Define region around the emission line and extract it
    line_region = SpectralRegion(line_center - line_width, line_center + line_width)
    line_spectrum = extract_region(spectrum, line_region)

    if line_spectrum.flux.size == 0:
        print(f"Warning: No data found in the region around {line_center}")
        return None, None, None

    # NaN/inf flux (common at detector gaps and edges) would propagate into
    # every integral and break the fit, so keep only finite pixels.
    finite = np.isfinite(line_spectrum.flux)
    if not finite.any():
        print(f"Warning: No finite data in the region around {line_center}")
        return None, None, None
    axis = u.Quantity(line_spectrum.spectral_axis)[finite]
    flux = line_spectrum.flux[finite]
    line_spectrum = Spectrum1D(spectral_axis=axis, flux=flux)

    # Flat continuum from the outer half of the window on each side.
    sidebands = np.abs(axis - line_center) > 0.5 * line_width
    if sidebands.any():
        continuum = np.median(flux[sidebands])
    else:
        continuum = np.median(flux[[0, -1]])

    try:
        # Integrate the continuum-subtracted spectrum so the result is the
        # line flux alone, not line + continuum over the whole window.
        line_only = Spectrum1D(spectral_axis=axis, flux=flux - continuum)
        flux_value = line_flux(line_only)
        # specutils' default continuum=1 only makes sense for a
        # continuum-normalised spectrum; pass the estimated level instead.
        ew_value = equivalent_width(line_spectrum, continuum=continuum)
    except Exception as e:
        print(f"Error analyzing line at {line_center}: {e}")
        return None, None, None

    # The fit is a refinement: if it fails, keep the flux and EW.
    g_fit = None
    try:
        # All parameters carry units consistent with the spectrum; mixing
        # quantities with bare floats breaks unit-aware fitting.
        g_init = models.Gaussian1D(
            amplitude=np.max(line_only.flux),
            mean=line_center,
            stddev=10.0 * u.AA,
        )
        g_fit = fit_lines(line_only, g_init)
    except Exception as e:
        print(f"Gaussian fit failed for line at {line_center}: {e}")

    if plot:
        plt.figure(figsize=(10, 6))
        plt.plot(axis.value, flux.value, 'b-', label='Observed')
        plt.axhline(y=continuum.value, color='gray', linestyle=':', label='Continuum')
        if g_fit is not None:
            # Evaluate on the unit-bearing axis, then add the continuum back
            # so the curve overlays the observed (unsubtracted) data.
            y_fit = u.Quantity(g_fit(axis), flux.unit) + continuum
            plt.plot(axis.value, y_fit.to_value(flux.unit), 'r-', label='Gaussian fit')
        plt.axvline(x=line_center.value, color='g', linestyle='--', label=f'Rest frame: {line_center}')
        plt.xlabel('Wavelength (Å)')
        plt.ylabel('Flux (erg/s/cm²/Å)')
        if line_center.value > 6000:
            plt.title('Hα emission line')
        else:
            plt.title('Hβ emission line')
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.tight_layout()
        plt.show()

    return flux_value, ew_value, g_fit

# Exact Gaussian sigma -> FWHM conversion factor, 2*sqrt(2*ln 2) ≈ 2.3548.
_SIGMA_TO_FWHM = 2.0 * np.sqrt(2.0 * np.log(2.0))


def _plot_full_spectrum(spectrum, xlabel, title, mark_balmer=False):
    """Plot a whole spectrum, optionally marking the rest-frame Hα/Hβ lines."""
    plt.figure(figsize=(12, 6))
    plt.plot(spectrum.spectral_axis.value, spectrum.flux.value)
    plt.xlabel(xlabel)
    plt.ylabel('Flux (erg/s/cm²/Å)')
    plt.title(title)
    if mark_balmer:
        plt.axvline(x=HALPHA_REST.value, color='r', linestyle='--', label='Hα')
        plt.axvline(x=HBETA_REST.value, color='g', linestyle='--', label='Hβ')
        plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


def _print_line_summary(name, flux, ew, fit):
    """Print flux, equivalent width and (if present) Gaussian-fit results for one line."""
    if flux is None:
        return
    print(f"{name} Flux: {flux.value:.3e} {flux.unit}")
    print(f"{name} Equivalent Width: {ew.value:.2f} {ew.unit}")
    if fit is not None:
        print(f"{name} Line Center: {fit.mean.value:.2f} Å")
        print(f"{name} Line FWHM: {_SIGMA_TO_FWHM * fit.stddev.value:.2f} Å")


def analyze_jwst_spectrum(fits_file, redshift):
    """Run the full Hα/Hβ analysis on a JWST spectrum and print the results.

    The steps are:

    1. Load the spectrum.
    2. Plot it in the observed and rest frames.
    3. Measure Hα and Hβ.
    4. If both lines were measured, report the Balmer decrement and the
       E(B-V) it implies.
    5. Print a summary.

    Parameters
    ----------
    fits_file : str or path-like
        Path to the JWST FITS spectrum.
    redshift : float
        Source redshift used for the rest-frame shift.

    Returns
    -------
    None
        Results are printed and plotted, not returned.
    """
    observed_spectrum = load_jwst_spectrum(fits_file)
    if observed_spectrum is None:
        print("Failed to load spectrum. Check the FITS file format.")
        return

    _plot_full_spectrum(observed_spectrum, 'Wavelength (Å)', 'JWST Observed Spectrum')

    rest_spectrum = apply_redshift_correction(observed_spectrum, redshift)
    _plot_full_spectrum(rest_spectrum, 'Rest Wavelength (Å)',
                        f'JWST Rest-Frame Spectrum (z={redshift})', mark_balmer=True)

    print("\nAnalyzing Hα emission line...")
    halpha_flux, halpha_ew, halpha_fit = extract_emission_lines(rest_spectrum, HALPHA_REST)

    print("\nAnalyzing Hβ emission line...")
    hbeta_flux, hbeta_ew, hbeta_fit = extract_emission_lines(rest_spectrum, HBETA_REST)

    # Balmer decrement (Hα/Hβ) and the dust reddening it implies
    if halpha_flux is not None and hbeta_flux is not None:
        # to_value() folds in any unit-scale difference; .value would not.
        balmer_decrement = (halpha_flux / hbeta_flux).to_value(u.dimensionless_unscaled)
        print(f"\nBalmer Decrement (Hα/Hβ): {balmer_decrement:.2f}")

        # Assuming Case B recombination with T=10,000K, ne=100 cm^-3
        intrinsic_ratio = 2.86
        # NOTE(review): the attenuation curve these k(λ) values come from is
        # not stated. Only their difference enters E(B-V); confirm it.
        k_halpha = 2.63  # extinction coefficient at Hα
        k_hbeta = 3.71   # extinction coefficient at Hβ

        if balmer_decrement > 0:
            # E(B-V) = 2.5 / (k(Hβ) - k(Hα)) * log10(observed / intrinsic)
            E_BV = 2.5 * np.log10(balmer_decrement / intrinsic_ratio) / (k_hbeta - k_halpha)
            print(f"Estimated E(B-V): {E_BV:.3f} mag")
        else:
            print("Balmer decrement is not positive; cannot estimate E(B-V).")

    print("\n=== RESULTS SUMMARY ===")
    _print_line_summary("Hα", halpha_flux, halpha_ew, halpha_fit)
    _print_line_summary("Hβ", hbeta_flux, hbeta_ew, hbeta_fit)

# Example usage: edit the two inputs below, then run this file or cell.
if __name__ == "__main__":
    fits_file = "your_jwst_spectrum.fits"  # path to the JWST FITS spectrum
    redshift = 0.0                         # known redshift z of the source

    # Load, shift to rest frame, measure Hα/Hβ and print the summary.
    analyze_jwst_spectrum(fits_file=fits_file, redshift=redshift)

# To use this script:
# 1. Replace 'your_jwst_spectrum.fits' with your actual file path
# 2. Set the redshift variable to your known redshift value
# 3. Run the script