In [52]:
# =========================================
# Script: advanced_spectral_analysis.py
# Purpose: SMACS data download, processing, emission line analysis, and plotting.
# Author: Joseph Havens
# Research Supervisor: Dr. Bren Backhaus
# Date: 08-06-2025
# =========================================

import os
import shutil
import logging
from datetime import datetime
from collections import defaultdict
import requests
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.signal import savgol_filter, find_peaks, medfilt
from scipy.optimize import curve_fit

# Astropy imports for more robust astronomical data handling
import astropy.units as u
import astropy.constants as const
from astropy.io import fits
from astropy.table import Table
from astropy.utils.data import download_file
from astropy.stats import sigma_clip
from astropy.wcs import FITSFixedWarning
import warnings
from astropy.utils.exceptions import AstropyDeprecationWarning
from astropy.modeling.fitting import FittingWithOutlierRemoval
from scipy.optimize import OptimizeWarning

#!pip install --upgrade specutils astropy
from specutils import Spectrum1D, SpectralRegion
from specutils.manipulation import extract_region
from specutils.fitting import fit_generic_continuum, fit_lines
from specutils.analysis import template_correlate
from astropy.modeling.models import Gaussian1D, Chebyshev1D
from astropy.nddata import StdDevUncertainty

warnings.simplefilter('ignore', category=AstropyDeprecationWarning)
warnings.simplefilter('ignore', category=UserWarning) # Catches the "Model is linear" warning
warnings.simplefilter('ignore', category=RuntimeWarning) # Catches the "overflow" warning
warnings.simplefilter('ignore', category=OptimizeWarning) # Catches "The fit may be unsuccessful"



In [53]:
# Suppress a common but benign FITS warning
warnings.filterwarnings('ignore', category=FITSFixedWarning)

# --- Configuration Section ---

# 1. Define paths
# Use os.path.join for cross-platform compatibility.
USER_HOME = "/content/drive/MyDrive/Bren_code/My_work/"
BASE_PATH = os.path.join(USER_HOME, "SMACS_Analysis/")
SAVE_PATH = os.path.join(BASE_PATH, "SMACS/")
TARGET_LIST_PATH = os.path.join(BASE_PATH, "csv.csv") # Assumes the CSV is in the base path
OUTPUT_CATALOG_PATH = os.path.join(BASE_PATH, "smacs_line_flux_catalog.csv")
PLOT_PATH = os.path.join(BASE_PATH, "plots/")

# 2. Data source
DATA_URL = 'https://s3.amazonaws.com/msaexp-nirspec/extractions/smacs0723-ero-v3/'

# 3. Analysis Parameters
PEAK_FINDING_HEIGHT = 3.0  # In units of local noise standard deviation
PEAK_PROMINENCE = 1.5      # How much a peak must 'stick out' from its surroundings, in units of local noise.
PEAK_FINDING_DISTANCE = 10 # Minimum separation between peaks in pixels
GAUSSIAN_FIT_WINDOW = 20   # Window size around a peak for fitting, in pixels
WAVELENGTH_MATCH_TOLERANCE = 10 * u.AA # Tolerance for matching fitted peaks to known lines
LINE_INTEGRATION_WIDTH_A = 25.0  # The width in Angstroms to integrate over for each line.
SN_CUTOFF = 5.0                  # The signal-to-noise ratio required to count a line as "detected".
flag_GIDS = []

# --- Emission Line Catalog ---
EMISSION_LINES = {
    # Name: Rest Wavelength (Angstroms)

    # --- Hydrogen: Balmer Series (Visible) ---
    r'H-alpha': 6562.819 * u.AA,
    r'H-beta': 4861.333 * u.AA,
    r'H-gamma': 4340.47 * u.AA,
    r'H-delta': 4101.74 * u.AA,
    r'H-epsilon': 3970.07 * u.AA,

    # --- Hydrogen: Paschen Series (NIR) ---
    r'Pa-alpha': 18750.976 * u.AA,
    r'Pa-beta': 12818.1 * u.AA,
    r'Pa-gamma': 10938.1 * u.AA,
    r'Pa-delta': 10049.0 * u.AA,
    r'Pa-epsilon': 9545.98 * u.AA,
    r'Pa-zeta': 9229.02 * u.AA,
    r'Pa-eta': 9014.91 * u.AA,

    # --- Hydrogen: Brackett Series (NIR) ---
    r'Br-beta': 26251.0 * u.AA,
    r'Br-gamma': 21655.302 * u.AA,
    r'Br-delta': 19445.582 * u.AA,
    r'Br-epsilon': 18174.141 * u.AA,

    # --- Hydrogen: Pfund Series (NIR) ---
    r'Pf-delta': 32960.0 * u.AA,
    r'Pf-epsilon': 30383.731 * u.AA,

    # --- Other notable Hydrogen lines ---
    r'H (15700 Å)': 15700.0 * u.AA,
    r'H (26119 Å)': 26119.351 * u.AA,
    r'H (28722 Å)': 28722.0 * u.AA,

    # --- Helium Lines ---
    r'He I 5875': 5875.0 * u.AA,
    r'He I 10830': 10830.0 * u.AA,
    r'He I 18697': 18697.216 * u.AA,
    r'He II 4686': 4685.68 * u.AA,

    # --- Common Forbidden Lines ("Nebular" Lines) ---
    r'[O I] 6300': 6300 * u.AA,
    r'[O II] 3727': 3727.3 * u.AA,    # Doublet avg.
    r'[O III] 4959': 4958.91 * u.AA,
    r'[O III] 5007': 5006.84 * u.AA,
    r'[N II] 6583': 6583.4 * u.AA,    # Often blended with H-alpha
    r'[S II] 6716': 6716.4 * u.AA,
    r'[S II] 6731': 6730.8 * u.AA,

    # --- High-Ionization / AGN Lines ---
    # These lines indicate a very hard radiation field, often from an AGN.
    # Some are UV lines, only visible in NIR spectra for high-redshift objects.
    r'[C IV] 1549': 1549.0 * u.AA,           # UV, z > ~4 for NIRSpec
    r'[C III] 1909': 1908.7 * u.AA,         # UV, z > ~3 for NIRSpec
    r'[Mg II] 2798': 2798.0 * u.AA,          # UV, z > ~2 for NIRSpec
    r'[Ne V] 3346': 3345.8 * u.AA,
    r'[Ne V] 3426': 3425.9 * u.AA,
    r'[Ne III] 3869': 3868.8 * u.AA,
    r'[Ar IV] 4740': 4740.2 * u.AA,
    r'[Fe VII] 6087': 6087.0 * u.AA,
}

# Define a set of line names that indicate high-ionization or AGN activity
FLAG_LINES = {
    r'[C IV] 1549',
    r'[C III] 1909',
    r'[Mg II] 2798',
    r'[Ne V] 3346',
    r'[Ne V] 3426',
    r'[Ne III] 3869',
    r'He II 4686',
    r'[Ar IV] 4740',
    r'[Fe VII] 6087',
}

bad_lines = [
    1679,
    2653,
    4798,
    7570,
    7677,
    8277,
    8498,
    8717,
    8883,
    102091,
    102993,
    102744,
    102730,
    102711,
    102423,
    101449,
    10444,
    10380,
    8730
]

In [54]:
# --- Logger Setup ---
# This block replaces the need for a global 'debug' variable.

# 1. Create a filename for the detailed log file
# Example: 'log_2025-06-08_19-38.log'
log_filename = f"log_{datetime.now().strftime('%Y-%m-%d_%H-%M')}.log"
# In Colab/Drive, you might want to specify the full path:
log_filepath = "/content/drive/MyDrive/Bren_code/My_work/SMACS_Analysis/logs/" + log_filename

# 2. Get the root logger
logger = logging.getLogger()
logger.setLevel(logging.DEBUG) # Set the lowest level to capture ALL messages

# 3. Create a handler to write to the CONSOLE (for high-level info)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO) # Only shows INFO, WARNING, ERROR, CRITICAL
console_formatter = logging.Formatter('%(message)s') # Keep console output clean
console_handler.setFormatter(console_formatter)

# 4. Create a handler to write to the FILE (for all the details)
file_handler = logging.FileHandler(log_filepath)
file_handler.setLevel(logging.DEBUG) # Captures EVERY level, including DEBUG
file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(file_formatter)

# 5. Add the handlers to the logger
# Important: Clear existing handlers to prevent duplicate logs in notebooks
if (logger.hasHandlers()):
    logger.handlers.clear()
logger.addHandler(console_handler)
logger.addHandler(file_handler)

# --- End Logger Setup ---
# Now, you can use logger.info(), logger.debug(), etc. throughout your code.
logger.info(f"Logger initialized. Detailed debug output will be saved to: {log_filepath}")

Logger initialized. Detailed debug output will be saved to: /content/drive/MyDrive/Bren_code/My_work/SMACS_Analysis/logs/log_2025-08-29_23-52.log


In [55]:
# === Helper Functions (Final Stable Version) ===

def download_data(data_url, target_list_path, save_path):
    """
    Downloads FITS files using the logging system for output.
    """
    logger.info("====================" * 6)
    logger.info("--- Beginning data download ---")
    os.makedirs(save_path, exist_ok=True)
    try:
        target_list = pd.read_csv(target_list_path)
    except FileNotFoundError:
        logger.critical(f"Target list CSV not found at '{target_list_path}'. Halting download.")
        return
    for i, j in target_list.iterrows():
        target = j['file']
        file_url = data_url + target
        file_path = os.path.join(save_path, target)
        try:
            r = requests.get(file_url, allow_redirects=True, timeout=120)
            r.raise_for_status()
            with open(file_path, 'wb') as f: f.write(r.content)
            with fits.open(file_path) as hdul: pass
        except Exception as e:
            logger.error(f"An error occurred for target '{target}': {e}")
    logger.info("--- Download process complete ---")


def combine_gratings_specutils(fits_folder):
    """
    Groups, processes, and combines FITS files into specutils.Spectrum1D objects.
    """
    logger.info("====================" * 6)
    logger.info("--- Beginning grating combination (specutils) ---")
    try:
        fits_files = [f for f in os.listdir(fits_folder) if f.endswith('.spec.fits')]
    except FileNotFoundError:
        logger.critical(f"The directory '{fits_folder}' was not found. Halting.")
        return {}
    grouped_spectra = defaultdict(list)
    for fits_file in fits_files:
        galaxy_id = fits_file.split('_')[-1].replace('.spec.fits', '')
        if ('-' not in galaxy_id) and (int(galaxy_id) not in bad_lines):
            grouped_spectra[galaxy_id].append(fits_file)
    logger.info(f"Total unique galaxy IDs found: {len(grouped_spectra)}")
    combined_spectra = {}
    for galaxy_id, files in sorted(grouped_spectra.items()):
        file_paths = [os.path.join(fits_folder, f) for f in files]
        all_wave, all_flux, all_err = [], [], []
        for f_path in file_paths:
            try:
                with fits.open(f_path) as hdul:
                    data = hdul[1].data
                    all_wave.append(data['wave'])
                    all_flux.append(data['flux'])
                    all_err.append(data['err'])
            except Exception as e:
                logger.error(f"GID {galaxy_id}: Failed to process file '{os.path.basename(f_path)}'. Reason: {e}")
                continue
        if not all_wave: continue
        wave, flux, err = np.concatenate(all_wave), np.concatenate(all_flux), np.concatenate(all_err)
        sort_idx = np.argsort(wave)
        spectrum_object = Spectrum1D(
            flux=flux[sort_idx] * u.uJy,
            spectral_axis=wave[sort_idx] * u.um,
            uncertainty=StdDevUncertainty(err[sort_idx] * u.uJy))
        combined_spectra[galaxy_id] = {'spectrum': spectrum_object, 'source_files': files}
    logger.info("--- Combination Complete ---")
    logger.info(f"Successfully processed and stored data for {len(combined_spectra)} galaxies.")
    return combined_spectra


def create_specialized_template(template_type, flux_unit):
    """
    Creates a specialized emission line template (UV, Optical, or IR).
    """
    lines = []
    if template_type == 'optical':
        # Best for z ~ 1-4, where optical lines are in NIRSpec range
        wave_range = np.arange(3700, 6800, 1.0) * u.AA
        lines = [
            (r'[O II]',  3727.0, 2.0), (r'H-$\beta$', 4861.3, 1.0),
            (r'[O III]', 4958.9, 1.5), (r'[O III]', 5006.8, 4.5),
            (r'[N II]',  6548.1, 0.5), (r'H-$\alpha$', 6562.8, 3.5),
            (r'[N II]',  6583.4, 1.5), (r'[S II]',  6716.4, 0.6),
            (r'[S II]',  6730.8, 0.5)]
    elif template_type == 'ir':
        # Best for z < 1, where Paschen/Brackett lines are prominent
        wave_range = np.arange(8500, 19000, 10.0) * u.AA
        lines = [
            (r'Pa-$\delta$', 10049.0, 0.8), (r'Pa-$\gamma$', 10938.1, 1.0),
            (r'Pa-$\beta$',  12818.1, 1.5), (r'Pa-$\alpha$',  18750.9, 2.0),
            (r'He I', 10830.0, 1.2)]
    elif template_type == 'uv':
        # Best for z > 4, where UV lines are redshifted into NIRSpec range
        wave_range = np.arange(1200, 3000, 1.0) * u.AA
        lines = [
            (r'Ly-$\alpha$', 1215.7, 5.0), (r'C IV', 1549.0, 2.0),
            (r'C III]', 1908.7, 2.5), (r'Mg II', 2798.0, 1.5)]

    flux_axis = np.zeros(len(wave_range)) * flux_unit
    for name, wave, strength in lines:
        line_model = Gaussian1D(amplitude=strength*flux_unit, mean=wave*u.AA, stddev=2.0*u.AA)
        flux_axis += line_model(wave_range)

    return Spectrum1D(flux=flux_axis, spectral_axis=wave_range)


def correct_redshift_xcorr(spectrum, initial_z, galaxy_id):
    """
    Uses a manual numpy cross-correlation. This version manually resamples the
    spectrum onto the template grid to ensure compatibility and avoid all errors.
    """
    best_z = initial_z
    max_corr_peak = -np.inf
    best_template_type = 'none'

    for template_type in ['optical', 'ir', 'uv']:
        try:
            # 1. Create the appropriate template for this iteration.
            template = create_specialized_template(template_type, spectrum.flux.unit)
            template_wave_val = template.spectral_axis.value
            template_flux_val = template.flux.value

            # 2. Manually shift the observed spectrum to the initial rest frame.
            rest_wave_val = (spectrum.spectral_axis / (1 + initial_z)).to_value(u.AA)
            rest_flux_val = spectrum.flux.value

            # 3. Manually resample the spectrum's flux onto the template's wavelength grid
            #    using numpy's linear interpolation. This is the key fix.
            resampled_flux = np.interp(template_wave_val, rest_wave_val, rest_flux_val)

            # 4. Check that the resampled spectrum has a valid range.
            if np.all(resampled_flux == 0.0):
                continue

            # 5. Perform the cross-correlation using numpy on the raw arrays.
            correlation = np.correlate(resampled_flux, template_flux_val, mode='same')
            current_peak = np.max(correlation)

            # 6. If this template provides a better match, store its result.
            if current_peak > max_corr_peak:
                max_corr_peak = current_peak
                best_template_type = template_type

                pixel_lags = np.arange(len(correlation)) - len(correlation) // 2
                dw_per_pixel = np.median(np.diff(template_wave_val))
                center_wave = np.median(template_wave_val)
                dv_per_pixel = (const.c.to('km/s').value * dw_per_pixel / center_wave)
                velocity_lags_kms = pixel_lags * dv_per_pixel

                velocity_offset = velocity_lags_kms[np.argmax(correlation)] * u.km / u.s
                delta_z = velocity_offset / const.c
                best_z = (initial_z + delta_z * (1 + initial_z)).value

        except Exception as e:
            logger.debug(f"  Template '{template_type}' failed for GID {galaxy_id}: {e}")
            continue

    if best_template_type != 'none':
        logger.info(f"  Redshift Correction (Best Match: '{best_template_type}' template): Initial z={initial_z:.5f} -> Corrected z={best_z:.5f}")
        return best_z
    else:
        logger.warning(f"  Cross-correlation failed for all templates. Using initial redshift.")
        return initial_z


def fit_continuum_stable(spectrum):
    """
    Fits and subtracts the continuum, robustly handling non-finite values.
    """
    from astropy.modeling.models import Chebyshev1D
    from astropy.modeling.fitting import LevMarLSQFitter, FittingWithOutlierRemoval

    # Create a clean copy of the flux and uncertainty arrays for interpolation
    flux_clean = spectrum.flux.value.copy()
    uncert_clean = spectrum.uncertainty.array.copy()

    # Create a mask of all non-finite (NaN or inf) values
    bad_mask = ~np.isfinite(flux_clean) | ~np.isfinite(uncert_clean)

    # Replace non-finite values with interpolated values from their neighbors
    if np.any(bad_mask):
        good_indices = np.where(~bad_mask)[0]
        # If there are no good points to interpolate from, we can't proceed.
        if good_indices.size < 2:
            raise ValueError("Not enough finite data points to perform interpolation.")
        bad_indices = np.where(bad_mask)[0]
        flux_clean[bad_mask] = np.interp(bad_indices, good_indices, flux_clean[good_indices])
        uncert_clean[bad_mask] = np.interp(bad_indices, good_indices, uncert_clean[good_indices])

    fitter = LevMarLSQFitter()
    # Use a fitter that also removes statistical outliers for more robustness
    robust_fitter = FittingWithOutlierRemoval(fitter, sigma_clip, niter=3, sigma=3.0)

    continuum_model_init = Chebyshev1D(degree=5)

    # Fit the model to the cleaned data
    fitted_model, _ = robust_fitter(continuum_model_init, spectrum.spectral_axis.value, flux_clean)

    # Evaluate the fitted model over the entire wavelength range
    continuum_flux = fitted_model(spectrum.spectral_axis.value) * spectrum.flux.unit

    # Subtract the continuum model from the original spectrum
    subtracted_spectrum = spectrum - continuum_flux

    # CORRECTED RETURN STATEMENT
    return subtracted_spectrum, continuum_flux


def fit_line_stable(spectrum_sub, line_wave_A):
    """
    Fits a single Gaussian to a target line using astropy.modeling directly.
    """
    from astropy.modeling.models import Gaussian1D
    from astropy.modeling.fitting import LevMarLSQFitter

    try:
        fit_region = SpectralRegion(line_wave_A - (GAUSSIAN_FIT_WINDOW/2)*u.AA,
                                    line_wave_A + (GAUSSIAN_FIT_WINDOW/2)*u.AA)
        sub_spectrum = extract_region(spectrum_sub, fit_region)

        if not np.all(np.isfinite(sub_spectrum.flux)): return None

        fitter = LevMarLSQFitter()
        model_init = Gaussian1D(
            amplitude=np.nanmax(sub_spectrum.flux),
            mean=line_wave_A,
            stddev=2.0 * u.AA
        )

        fitted_model = fitter(model_init, sub_spectrum.spectral_axis, sub_spectrum.flux)

        amplitude = fitted_model.amplitude.value
        stddev_A = fitted_model.stddev.value
        center_A = fitted_model.mean.value
        # change this bs
        integrated_flux = amplitude * stddev_A * np.sqrt(2 * np.pi)
        local_noise = np.nanmedian(sub_spectrum.uncertainty.array)
        snr = amplitude / local_noise if local_noise > 0 else 0.0

        if snr >= SN_CUTOFF:
            return {'flux': integrated_flux, 'flux_err': integrated_flux / snr,
                    'snr': snr, 'fit_center_A': center_A}
        else:
            return None

    except Exception as e:
        logger.debug(f"    - Fit failed for line at {line_wave_A:.1f}: {e}")
        return None


def run_analysis_on_galaxy(spectrum_object, initial_z, galaxy_id, make_plots=True):
    """
    Top-level analysis function for a single galaxy, now with plotting.
    """
    z_corr = correct_redshift_xcorr(spectrum_object, initial_z, galaxy_id)

    wave_obs, flux_obs_fnu, err_obs_fnu = spectrum_object.spectral_axis, spectrum_object.flux, spectrum_object.uncertainty.quantity
    wave_rest, flux_rest_fnu, err_rest_fnu = wave_obs/(1+z_corr), flux_obs_fnu*(1+z_corr), err_obs_fnu*(1+z_corr)
    wave_rest_aa = wave_rest.to(u.AA)
    fnu_to_flambda_conv = (const.c / wave_rest_aa**2).to(u.erg/u.s/u.cm**2/u.AA/u.uJy)
    flux_flambda, err_flambda = flux_rest_fnu*fnu_to_flambda_conv, err_rest_fnu*fnu_to_flambda_conv
    spec_flam = Spectrum1D(
        flux=flux_flambda,
        spectral_axis=wave_rest_aa,
        uncertainty=StdDevUncertainty(err_flambda))

    try:
        spec_sub, continuum_flux = fit_continuum_stable(spec_flam)
        # Call the diagnostic plotter if requested
        if make_plots:
            create_continuum_diagnostic_plot(galaxy_id, spec_flam, continuum_flux, spec_sub, PLOT_PATH)
    except Exception as e:
        logger.error(f"  Continuum fitting failed for GID {galaxy_id}: {e}. Skipping galaxy.")
        return None

    logger.info(f"  GID {galaxy_id}: Measuring target lines...")
    line_fits = {}
    for line_name, rest_wave in EMISSION_LINES.items():
        measurement = fit_line_stable(spec_sub, rest_wave)
        if measurement:
            line_fits[line_name] = measurement
            logger.info(f"    ✓ {line_name}: DETECTED with S/N = {measurement['snr']:.2f}")

    # Call the final results plotter if requested and lines were found
    if make_plots and line_fits:
        plot_and_save_results(galaxy_id, z_corr, spec_flam, line_fits, PLOT_PATH)
    if z_corr != initial_z:
      final_results = {'galaxy_id': galaxy_id, 'redshift': initial_z, 'redshift_corrected': z_corr}
    else:
      final_results = {'galaxy_id': galaxy_id, 'redshift': initial_z, 'redshift_corrected': np.nan}
    for line_name in EMISSION_LINES.keys():
        safe_name = line_name.replace(' ', '').replace('[','').replace(']','').replace(r'\\','').replace(r'$','').replace('-','').replace('Å','')
        if line_name in line_fits:
            final_results[f'{safe_name}_flux'] = line_fits[line_name]['flux']
            final_results[f'{safe_name}_flux_err'] = line_fits[line_name]['flux_err']
            final_results[f'{safe_name}_snr'] = line_fits[line_name]['snr']
            final_results[f'{safe_name}_fit_center_A'] = line_fits[line_name]['fit_center_A']
        else:
            for suffix in ['_flux', '_flux_err', '_snr', '_fit_center_A']:
                final_results[f'{safe_name}{suffix}'] = np.nan

    found_flag_lines = FLAG_LINES.intersection(line_fits.keys())
    if found_flag_lines:
        logger.warning(f"  >>> AGN/HIGH-IONIZATION SIGNATURE DETECTED! Flagged lines: {sorted(list(found_flag_lines))}")
    final_results['agn_flag'] = len(found_flag_lines) > 0

    return final_results


def create_continuum_diagnostic_plot(galaxy_id, spectrum, continuum, spec_sub, save_dir):
    """
    Creates and saves a plot to diagnose the continuum subtraction process.
    """
    from astropy.visualization import quantity_support
    quantity_support() # Enables astropy units on plot axes

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(18, 10), sharex=True,
                                   gridspec_kw={'height_ratios': [2, 1]})

    # Top Panel: Original spectrum and continuum model
    ax1.plot(spectrum.spectral_axis, spectrum.flux, color='grey', lw=1, label='Original Spectrum')
    ax1.plot(spectrum.spectral_axis, continuum, color='red', linestyle='--', lw=2, label='Continuum Model')
    ax1.set_title(f'Continuum Subtraction Diagnostic for Galaxy {galaxy_id}', fontsize=16)
    ax1.legend()

    # Bottom Panel: Continuum-subtracted spectrum
    ax2.plot(spec_sub.spectral_axis, spec_sub.flux, color='black', lw=1, label='Continuum-Subtracted Flux')
    # Use the uncertainty from the subtracted spectrum for the noise level
    ax2.fill_between(spec_sub.spectral_axis, -spec_sub.uncertainty.quantity, spec_sub.uncertainty.quantity,
                     color='cyan', alpha=0.5, label='1-sigma Noise')
    ax2.axhline(0, color='red', linestyle=':', lw=1)
    ax2.set_ylabel('Subtracted Flux')
    ax2.legend()

    # Set smart y-limits for the bottom panel
    plot_std = np.nanmedian(spec_sub.uncertainty.array) * 10
    if np.isfinite(plot_std) and plot_std > 0:
        ax2.set_ylim(-plot_std, plot_std)

    plt.tight_layout()
    plt.subplots_adjust(hspace=0.05)
    plot_filename = os.path.join(save_dir, f'diagnostic_continuum_{galaxy_id}.png')
    plt.savefig(plot_filename, dpi=150)
    plt.close(fig)
    logger.info(f"  → Continuum diagnostic plot saved to {os.path.basename(plot_filename)}")


def plot_and_save_results(galaxy_id, redshift, spectrum, line_fits, save_dir):
    """
    Generates and saves a final plot showing the spectrum and detected lines.
    """
    from astropy.visualization import quantity_support
    quantity_support()

    fig, ax = plt.subplots(figsize=(18, 7))
    ax.plot(spectrum.spectral_axis, spectrum.flux, label='Spectrum (Rest-frame)', color='black', lw=0.8)

    # Mark the locations of detected lines
    for line_name, fit_info in line_fits.items():
        center_wave = fit_info['fit_center_A'] * u.AA
        ax.axvline(center_wave.value, color='red', linestyle='--', alpha=0.7,
                    label=f"{line_name} (S/N: {fit_info['snr']:.1f})")

    ax.set_xlabel(f'Rest-frame Wavelength ({spectrum.spectral_axis.unit})', fontsize=12)
    ax.set_ylabel(f'Flux Density ({spectrum.flux.unit})', fontsize=12)
    ax.set_title(f'Spectrum and Detected Lines for Galaxy {galaxy_id} (z={redshift:.4f})', fontsize=14)

    # Create a clean legend without duplicate labels
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc='upper right')

    ax.set_xlim(np.nanmin(spectrum.spectral_axis.value), np.nanmax(spectrum.spectral_axis.value))
    med_flux, std_flux = np.nanmedian(spectrum.flux.value), np.nanstd(spectrum.flux.value)
    if np.isfinite(med_flux) and np.isfinite(std_flux):
        ax.set_ylim(med_flux - 2 * std_flux, med_flux + 7 * std_flux)

    plt.tight_layout()
    plot_filename = os.path.join(save_dir, f'spectrum_{galaxy_id}.png')
    plt.savefig(plot_filename, dpi=150)
    plt.close(fig)
    logger.info(f"  → Final results plot saved to {os.path.basename(plot_filename)}")

In [56]:
def main():
    """
    Main function to run the entire analysis pipeline.
    """
    logger.info("====================" * 6)
    logger.info("--- Starting Main Analysis Pipeline (Stable Version) ---")
    os.makedirs(PLOT_PATH, exist_ok=True)
    try:
        df_redshifts = pd.read_csv(TARGET_LIST_PATH)
        if 'file' not in df_redshifts.columns or 'z' not in df_redshifts.columns:
             raise ValueError("CSV must contain 'file' and 'z' columns.")
    except Exception as e:
        logger.critical(f"Failed to load or parse target list '{TARGET_LIST_PATH}': {e}"); return

    processed_spectra = combine_gratings_specutils(SAVE_PATH)
    if not processed_spectra:
        logger.critical("No spectra were processed. Halting.")
        return

    all_galaxy_results = []
    for galaxy_id, spec_info in processed_spectra.items():
        logger.info("-------------------------------------------------")
        logger.info(f"Processing Galaxy ID: {galaxy_id}")
        z_rows = df_redshifts[df_redshifts['file'].isin(spec_info['source_files'])]
        if z_rows.empty:
            logger.warning(f"Could not find an entry for GID {galaxy_id} in redshift file. Skipping.")
            continue
        redshift = z_rows['z'].iloc[0]
        galaxy_data_row = run_analysis_on_galaxy(spec_info['spectrum'], redshift, galaxy_id)
        if galaxy_data_row:
            all_galaxy_results.append(galaxy_data_row)

    if all_galaxy_results:
        logger.info("====================" * 6)
        logger.info("--- Saving Final Flux Catalog ---")
        results_df = pd.DataFrame(all_galaxy_results)
        cols = ['galaxy_id', 'redshift', 'redshift_corrected', 'agn_flag'] + sorted([c for c in results_df.columns if c not in ['galaxy_id', 'redshift_initial', 'redshift_corrected', 'agn_flag']])
        results_df = results_df[cols]
        try:
            results_df.to_csv(OUTPUT_CATALOG_PATH, index=False, float_format='%.5e')
            logger.info(f"SUCCESS: Final catalog saved to: {OUTPUT_CATALOG_PATH}")
        except Exception as e:
            logger.critical(f"Failed to save the final catalog. Reason: {e}")
    else:
        logger.warning("No galaxies were processed successfully. Final catalog was not created.")

    logger.info("====================" * 6)
    logger.info("!!!! ANALYSIS COMPLETE !!!!")

In [None]:
# === Driver ===
if __name__ == "__main__":
    main()

--- Starting Main Analysis Pipeline (Stable Version) ---
--- Beginning grating combination (specutils) ---
Total unique galaxy IDs found: 26
--- Combination Complete ---
Successfully processed and stored data for 26 galaxies.
-------------------------------------------------
Processing Galaxy ID: 102539
  Cross-correlation failed for all templates. Using initial redshift.
  → Continuum diagnostic plot saved to diagnostic_continuum_102539.png
  GID 102539: Measuring target lines...
    ✓ Pa-alpha: DETECTED with S/N = 36.01
    ✓ Pa-beta: DETECTED with S/N = 18.20
    ✓ Pa-gamma: DETECTED with S/N = 13.40
    ✓ Pa-delta: DETECTED with S/N = 6.16
    ✓ Pa-epsilon: DETECTED with S/N = 10649.64
    ✓ Pa-zeta: DETECTED with S/N = 19.96
    ✓ Br-gamma: DETECTED with S/N = 5.23
    ✓ Br-epsilon: DETECTED with S/N = 3521.60
    ✓ He I 10830: DETECTED with S/N = 27.24
  → Final results plot saved to spectrum_102539.png
-------------------------------------------------
Processing Galaxy ID: 10267