In [11]:
# =========================================
# Script: advanced_spectral_analysis-JADES.py
# Purpose: JADES data download, processing, emission line analysis, and plotting.
# Author: Joseph Havens
# Research Supervisor: Dr. Bren Backhaus
# Date: 31-08-2025
# =========================================

import os
import shutil
import logging
from datetime import datetime
from collections import defaultdict
import requests
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.signal import savgol_filter, find_peaks, medfilt
from scipy.optimize import curve_fit
from itertools import cycle
import time

# Astropy imports for more robust astronomical data handling
import astropy.units as u
from astropy.visualization import quantity_support
import astropy.constants as const
from astropy.io import fits
from astropy.table import Table
from astropy.utils.data import download_file
from astropy.stats import sigma_clip
from astropy.wcs import FITSFixedWarning
import warnings
from astropy.utils.exceptions import AstropyDeprecationWarning
from astropy.modeling.fitting import FittingWithOutlierRemoval
from scipy.optimize import OptimizeWarning

try:
  from specutils import Spectrum1D, SpectralRegion
  from specutils.manipulation import extract_region
  from specutils.fitting import fit_generic_continuum, fit_lines
  from specutils.analysis import template_correlate
except:
  !pip install --upgrade specutils astropy
  from specutils import Spectrum1D, SpectralRegion
  from specutils.manipulation import extract_region
  from specutils.fitting import fit_generic_continuum, fit_lines
  from specutils.analysis import template_correlate

from astropy.modeling.models import Gaussian1D, Chebyshev1D
from astropy.nddata import StdDevUncertainty

warnings.simplefilter('ignore', category=AstropyDeprecationWarning)
warnings.simplefilter('ignore', category=UserWarning) # Catches the "Model is linear" warning
warnings.simplefilter('ignore', category=RuntimeWarning) # Catches the "overflow" warning
warnings.simplefilter('ignore', category=OptimizeWarning) # Catches "The fit may be unsuccessful"

In [12]:
# Suppress a common but benign FITS warning
warnings.filterwarnings('ignore', category=FITSFixedWarning)

# --- Configuration Section ---

# 1. Define paths
# Use os.path.join for cross-platform compatibility.
USER_HOME = "/content/drive/MyDrive/Bren_code/My_work/"
BASE_PATH = os.path.join(USER_HOME, "JADES_Analysis/")
SAVE_PATH = os.path.join(BASE_PATH, "JADES/1180")
TARGET_LIST_PATH = os.path.join(BASE_PATH, "jades.csv") # Assumes the CSV is in the base path
OUTPUT_CATALOG_PATH = os.path.join(BASE_PATH, "jades_line_flux_catalog.csv")
PLOT_PATH = os.path.join(BASE_PATH, "plots/")

# 2. Data source
JADES_PROJECTS = {
    '1180': {
        'jades-gds-wide-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gds-wide-v3/',
        'jades-gds-wide2-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gds-wide2-v3/',
        'jades-gds-wide3-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gds-wide3-v3/'
    },

#    '1181': {
#        'jades-gdn-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gdn-v3/',
#        'jades-gdn09-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gdn09-v3/',
#        'jades-gdn10-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gdn10-v3/',
#        'jades-gdn11-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gdn11-v3/',
#        'jades-gdn2-blue-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gdn2-blue-v3/',
#        'jades-gdn2-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gdn2-v3/',
#    },
#
#    '1210': {
#        'gds-deep-v3':'https://s3.amazonaws.com/msaexp-nirspec/extractions/gds-deep-v3/'
#        },
#
#    '1286': {
#        'jades-gds02-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gds02-v3/',
#        'jades-gds03-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gds03-v3/',
#        'jades-gds04-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gds04-v3/',
#        'jades-gds05-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gds05-v3/',
#        'jades-gds06-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gds06-v3/',
#        'jades-gds07-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gds07-v3/',
#        'jades-gds08-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gds08-v3/',
#        'jades-gds1-v3': 'https://s3.amazonaws.com/msaexp-nirspec/extractions/jades-gds1-v3/',
#    },
#
#    'ultra_deep': {
#        'gds-udeep-v3':'https://s3.amazonaws.com/msaexp-nirspec/extractions/gds-udeep-v3/'
#    }
}

# 3. Analysis Parameters
PEAK_FINDING_HEIGHT = 3.0  # In units of local noise standard deviation
PEAK_PROMINENCE = 1.5      # How much a peak must 'stick out' from its surroundings, in units of local noise.
PEAK_FINDING_DISTANCE = 10 # Minimum separation between peaks in pixels
GAUSSIAN_FIT_WINDOW = 20   # Window size around a peak for fitting, in pixels
WAVELENGTH_MATCH_TOLERANCE = 10 * u.AA # Tolerance for matching fitted peaks to known lines
LINE_INTEGRATION_WIDTH_A = 25.0  # The width in Angstroms to integrate over for each line.
SN_CUTOFF = 5.0                  # The signal-to-noise ratio required to count a line as "detected".
flag_GIDS = []

# --- Emission Line Catalog ---
EMISSION_LINES = {
    # Name: Rest Wavelength (Angstroms)

    # --- Hydrogen: Balmer Series (Visible) ---
    r'H-alpha': 6562.819 * u.AA,
    r'H-beta': 4861.333 * u.AA,
    r'H-gamma': 4340.47 * u.AA,
    r'H-delta': 4101.74 * u.AA,
    r'H-epsilon': 3970.07 * u.AA,

    # --- Hydrogen: Paschen Series (NIR) ---
    r'Pa-alpha': 18750.976 * u.AA,
    r'Pa-beta': 12818.1 * u.AA,
    r'Pa-gamma': 10938.1 * u.AA,
    r'Pa-delta': 10049.0 * u.AA,
    r'Pa-epsilon': 9545.98 * u.AA,
    r'Pa-zeta': 9229.02 * u.AA,
    r'Pa-eta': 9014.91 * u.AA,

    # --- Hydrogen: Brackett Series (NIR) ---
    r'Br-beta': 26251.0 * u.AA,
    r'Br-gamma': 21655.302 * u.AA,
    r'Br-delta': 19445.582 * u.AA,
    r'Br-epsilon': 18174.141 * u.AA,

    # --- Hydrogen: Pfund Series (NIR) ---
    r'Pf-delta': 32960.0 * u.AA,
    r'Pf-epsilon': 30383.731 * u.AA,

    # --- Other notable Hydrogen lines ---
    r'H (15700 Å)': 15700.0 * u.AA,
    r'H (26119 Å)': 26119.351 * u.AA,
    r'H (28722 Å)': 28722.0 * u.AA,

    # --- Helium Lines ---
    r'He I 5875': 5875.0 * u.AA,
    r'He I 10830': 10830.0 * u.AA,
    r'He I 18697': 18697.216 * u.AA,
    r'He II 4686': 4685.68 * u.AA,

    # --- Common Forbidden Lines ("Nebular" Lines) ---
    r'[O I] 6300': 6300 * u.AA,
    r'[O II] 3727': 3727.3 * u.AA,    # Doublet avg.
    r'[O III] 4959': 4958.91 * u.AA,
    r'[O III] 5007': 5006.84 * u.AA,
    r'[N II] 6583': 6583.4 * u.AA,    # Often blended with H-alpha
    r'[S II] 6716': 6716.4 * u.AA,
    r'[S II] 6731': 6730.8 * u.AA,

    # --- High-Ionization / AGN Lines ---
    # These lines indicate a very hard radiation field, often from an AGN.
    # Some are UV lines, only visible in NIR spectra for high-redshift objects.
    r'[C IV] 1549': 1549.0 * u.AA,           # UV, z > ~4 for NIRSpec
    r'[C III] 1909': 1908.7 * u.AA,         # UV, z > ~3 for NIRSpec
    r'[Mg II] 2798': 2798.0 * u.AA,          # UV, z > ~2 for NIRSpec
    r'[Ne V] 3346': 3345.8 * u.AA,
    r'[Ne V] 3426': 3425.9 * u.AA,
    r'[Ne III] 3869': 3868.8 * u.AA,
    r'[Ar IV] 4740': 4740.2 * u.AA,
    r'[Fe VII] 6087': 6087.0 * u.AA,
}

# Define a set of line names that indicate high-ionization or AGN activity
FLAG_LINES = {
    r'[C IV] 1549',
    r'[C III] 1909',
    r'[Mg II] 2798',
    r'[Ne V] 3346',
    r'[Ne V] 3426',
    r'[Ne III] 3869',
    r'He II 4686',
    r'[Ar IV] 4740',
    r'[Fe VII] 6087',
}

In [13]:
# --- Logger Setup ---
# This block replaces the need for a global 'debug' variable.

# 1. Create a filename for the detailed log file
# Example: 'log_2025-06-08_19-38.log'
log_filename = f"log_{datetime.now().strftime('%Y-%m-%d_%H-%M')}.log"
# In Colab/Drive, you might want to specify the full path:
log_filepath = "/content/drive/MyDrive/Bren_code/My_work/JADES_Analysis/logs/" + log_filename

# 2. Get the root logger
logger = logging.getLogger()
logger.setLevel(logging.DEBUG) # Set the lowest level to capture ALL messages

# 3. Create a handler to write to the CONSOLE (for high-level info)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO) # Only shows INFO, WARNING, ERROR, CRITICAL
console_formatter = logging.Formatter('%(message)s') # Keep console output clean
console_handler.setFormatter(console_formatter)

# 4. Create a handler to write to the FILE (for all the details)
file_handler = logging.FileHandler(log_filepath)
file_handler.setLevel(logging.DEBUG) # Captures EVERY level, including DEBUG
file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(file_formatter)

# 5. Add the handlers to the logger
# Important: Clear existing handlers to prevent duplicate logs in notebooks
if (logger.hasHandlers()):
    logger.handlers.clear()
logger.addHandler(console_handler)
logger.addHandler(file_handler)

# --- End Logger Setup ---
# Now, you can use logger.info(), logger.debug(), etc. throughout your code.
logger.info(f"Logger initialized. Detailed debug output will be saved to: {log_filepath}")

Logger initialized. Detailed debug output will be saved to: /content/drive/MyDrive/Bren_code/My_work/JADES_Analysis/logs/log_2025-09-09_17-03.log


In [14]:
# === Helper Functions (Final Stable Version) ===

def download_jades_data(projects, csv_parent_path, base_save_path):
    """
    Downloads JADES FITS files, using a separate CSV file for each
    individual pointing.
    """
    # This function is working well and does not need changes.
    logger.info("====================" * 6)
    logger.info("--- Beginning JADES Data Download (Pointing-Specific CSVs) ---")
    count = 0
    for program_id, pointings_dict in projects.items():
        logger.info("-------------------------------------------------")
        logger.info(f"Processing Program ID: {program_id}")
        program_save_path = os.path.join(base_save_path, program_id)
        os.makedirs(program_save_path, exist_ok=True)
        for pointing_name, data_url in pointings_dict.items():
            logger.info(f"  --> Processing pointing: {pointing_name}")
            csv_filename = f"{program_id}-{pointing_name}.csv"
            target_list_path = os.path.join(csv_parent_path, csv_filename)
            try:
                target_list = pd.read_csv(target_list_path)
                logger.info(f"    Loaded {len(target_list)} targets from '{csv_filename}'.")
            except FileNotFoundError:
                logger.warning(f"    Target list '{csv_filename}' not found. Skipping this pointing.")
                continue
            for i, row in target_list.iterrows():
                target_file = row['file']
                file_path = os.path.join(program_save_path, target_file)
                if os.path.exists(file_path):
                    logger.debug(f"    Skipping '{target_file}', already downloaded.")
                    continue
                file_url = data_url + target_file
                logger.info(f"    Downloading: {target_file}")
                try:
                    r = requests.get(file_url, allow_redirects=True, timeout=120)
                    r.raise_for_status()
                    with open(file_path, 'wb') as f:
                        f.write(r.content)
                        count += 1
                    with fits.open(file_path) as hdul:
                        logger.debug(f"      ✓ FITS verification successful.")
                except Exception as e:
                    logger.error(f"      ✗ An error occurred for '{target_file}': {e}")
    logger.info("====================" * 6)
    logger.info("!!!! JADES DOWNLOAD COMPLETE !!!!")
    logger.info(f"Total files downloaded: {count}")
    logger.info("====================" * 6)


def combine_gratings_specutils(base_data_folder):
    """
    Recursively finds all .spec.fits files, groups by GID, and combines into Spectrum1D objects.
    """
    # This function is working well and does not need changes.
    logger.info("====================" * 6)
    logger.info("--- Beginning grating combination for JADES ---")
    all_fits_files = [os.path.join(root, file) for root, dirs, files in os.walk(base_data_folder) for file in files if file.endswith('.spec.fits')]
    logger.info(f"Found {len(all_fits_files)} total '.spec.fits' files across all subdirectories.")
    grouped_spectra = defaultdict(list)
    for fits_path in all_fits_files:
        filename = os.path.basename(fits_path)
        try:
            galaxy_id = filename.split('_')[-1].replace('.spec.fits', '')
            grouped_spectra[galaxy_id].append(fits_path)
        except IndexError:
            logger.warning(f"Could not parse galaxy ID from {filename}. Skipping.")
            continue
    logger.info(f"Total unique galaxy IDs found: {len(grouped_spectra)}")
    combined_spectra = {}
    for galaxy_id, file_paths in sorted(grouped_spectra.items()):
        all_wave, all_flux, all_err = [], [], []
        source_filenames = [os.path.basename(p) for p in file_paths]
        for f_path in file_paths:
            try:
                with fits.open(f_path) as hdul:
                    data = hdul[1].data
                    all_wave.append(data['wave'])
                    all_flux.append(data['flux'])
                    all_err.append(data['err'])
            except Exception as e:
                logger.error(f"GID {galaxy_id}: Failed to process file '{os.path.basename(f_path)}'. Reason: {e}")
                continue
        if not all_wave: continue
        wave, flux, err = np.concatenate(all_wave), np.concatenate(all_flux), np.concatenate(all_err)
        sort_idx = np.argsort(wave)
        spectrum_object = Spectrum1D(
            flux=flux[sort_idx] * u.uJy,
            spectral_axis=wave[sort_idx] * u.um,
            uncertainty=StdDevUncertainty(err[sort_idx] * u.uJy))
        combined_spectra[galaxy_id] = {'spectrum': spectrum_object, 'source_files': source_filenames}
    logger.info("--- Combination Complete ---")
    logger.info(f"Successfully processed and stored data for {len(combined_spectra)} galaxies.")
    return combined_spectra


def create_specialized_template(template_type):
    """
    Creates a specialized emission line template (UV, Optical, or IR).
    """
    # This function is working well and does not need changes.
    lines = []
    if template_type == 'optical':
        wave_range = np.arange(3700, 6800, 1.0)
        lines = [(3727.0, 2.0), (4861.3, 1.0), (4958.9, 1.5), (5006.8, 4.5), (6548.1, 0.5), (6562.8, 3.5), (6583.4, 1.5), (6716.4, 0.6), (6730.8, 0.5)]
    elif template_type == 'ir':
        wave_range = np.arange(8500, 19000, 10.0)
        lines = [(10049.0, 0.8), (10938.1, 1.0), (12818.1, 1.5), (18750.9, 2.0), (10830.0, 1.2)]
    elif template_type == 'uv':
        wave_range = np.arange(1200, 3000, 1.0)
        lines = [(1215.7, 5.0), (1549.0, 2.0), (1908.7, 2.5), (2798.0, 1.5)]
    flux_axis = np.zeros(len(wave_range))
    for wave, strength in lines:
        flux_axis += Gaussian1D.evaluate(wave_range, amplitude=strength, mean=wave, stddev=2.0)
    return wave_range, flux_axis


def correct_redshift_xcorr(spectrum, initial_z, galaxy_id):
    """
    Uses a manual numpy cross-correlation to find the redshift correction.
    """
    # This function is working well and does not need changes.
    best_z = initial_z
    max_corr_peak = -np.inf
    best_template_type = 'none'
    for template_type in ['optical', 'ir', 'uv']:
        try:
            template_wave, template_flux = create_specialized_template(template_type)
            rest_wave_val = (spectrum.spectral_axis / (1 + initial_z)).to_value(u.AA)
            rest_flux_val = spectrum.flux.value
            resampled_flux = np.interp(template_wave, rest_wave_val, rest_flux_val)
            if np.all(resampled_flux == 0.0): continue
            correlation = np.correlate(resampled_flux, template_flux, mode='same')
            current_peak = np.max(correlation)
            if current_peak > max_corr_peak:
                max_corr_peak = current_peak
                best_template_type = template_type
                pixel_lags = np.arange(len(correlation)) - len(correlation) // 2
                dw_per_pixel = np.median(np.diff(template_wave))
                center_wave = np.median(template_wave)
                dv_per_pixel = (const.c.to('km/s').value * dw_per_pixel / center_wave)
                velocity_lags_kms = pixel_lags * dv_per_pixel
                velocity_offset = velocity_lags_kms[np.argmax(correlation)] * u.km / u.s
                delta_z = velocity_offset / const.c
                best_z = (initial_z + delta_z * (1 + initial_z)).value
        except Exception as e:
            logger.debug(f"  Template '{template_type}' failed for GID {galaxy_id}: {e}")
            continue
    if best_template_type != 'none':
        logger.info(f"  Redshift Correction (Best Match: '{best_template_type}' template): Initial z={initial_z:.5f} -> Corrected z={best_z:.5f}")
        return best_z
    else:
        logger.warning(f"  Cross-correlation failed for all templates. Using initial redshift.")
        return initial_z


def fit_continuum_stable(spectrum):
    """
    Fits and subtracts the continuum, robustly handling non-finite values.
    """
    # This function is working well and does not need changes.
    from astropy.modeling.models import Chebyshev1D
    from astropy.modeling.fitting import LevMarLSQFitter, FittingWithOutlierRemoval
    flux_clean, uncert_clean = spectrum.flux.value.copy(), spectrum.uncertainty.array.copy()
    bad_mask = ~np.isfinite(flux_clean) | ~np.isfinite(uncert_clean)
    if np.any(bad_mask):
        good_indices = np.where(~bad_mask)[0]
        if good_indices.size < 2: raise ValueError("Not enough finite data points for interpolation.")
        bad_indices = np.where(bad_mask)[0]
        flux_clean[bad_mask] = np.interp(bad_indices, good_indices, flux_clean[good_indices])
        uncert_clean[bad_mask] = np.interp(bad_indices, good_indices, uncert_clean[good_indices])
    fitter = LevMarLSQFitter()
    robust_fitter = FittingWithOutlierRemoval(fitter, sigma_clip, niter=3, sigma=3.0)
    continuum_model_init = Chebyshev1D(degree=5)
    fitted_model, _ = robust_fitter(continuum_model_init, spectrum.spectral_axis.value, flux_clean)
    continuum_flux = fitted_model(spectrum.spectral_axis.value) * spectrum.flux.unit
    subtracted_spectrum = spectrum - continuum_flux
    return subtracted_spectrum, continuum_flux


def gaussian(x, amplitude, mean, stddev):
    """A standard Gaussian function for fitting."""
    return amplitude * np.exp(-((x - mean)**2) / (2 * stddev**2))


def find_and_fit_lines(spectrum_sub):
    """
    NEW "Hunt-then-Identify" line fitting function. This is a major upgrade.
    1. Finds all significant peaks in the spectrum.
    2. Fits a Gaussian to each peak to get precise parameters and errors.
    3. Matches the fitted peak to the closest known emission line.
    """
    line_fits = {}

    # --- 1. Hunt for Peaks ---
    # Use the smoothed uncertainty as the basis for our detection threshold
    local_noise = savgol_filter(spectrum_sub.uncertainty.array, 51, 3)
    peaks, properties = find_peaks(
        spectrum_sub.flux.value,
        height=local_noise * PEAK_FINDING_HEIGHT,
        prominence=local_noise * PEAK_PROMINENCE,
        distance=PEAK_FINDING_DISTANCE
    )

    if peaks.size == 0:
        return {} # Return empty if no peaks are found

    logger.info(f"  Found {len(peaks)} significant emission line candidates. Now fitting...")

    # --- 2. Fit and Identify Each Peak ---
    for i, peak_idx in enumerate(peaks):
        peak_wave = spectrum_sub.spectral_axis[peak_idx].to_value(u.AA)

        # Define a small window around the peak for fitting
        min_wave = peak_wave - GAUSSIAN_FIT_WINDOW
        max_wave = peak_wave + GAUSSIAN_FIT_WINDOW

        mask = (spectrum_sub.spectral_axis.to_value(u.AA) > min_wave) & (spectrum_sub.spectral_axis.to_value(u.AA) < max_wave)

        wave_subset = spectrum_sub.spectral_axis.value[mask]
        flux_subset = spectrum_sub.flux.value[mask]
        err_subset = spectrum_sub.uncertainty.array[mask]

        if len(wave_subset) < 4: continue # Not enough points to fit

        try:
            # Use curve_fit to get both params (popt) and covariance (pcov)
            p0 = [properties['peak_heights'][i], peak_wave, 2.0] # Initial guess: amp, mean, stddev
            popt, pcov = curve_fit(gaussian, wave_subset, flux_subset, p0=p0, sigma=err_subset, absolute_sigma=True)

            # --- 3. Extract parameters and errors from the fit ---
            amplitude, center_A, stddev_A = popt
            # Errors are the sqrt of the diagonal of the covariance matrix
            amp_err, center_err, stddev_err = np.sqrt(np.diag(pcov))

            integrated_flux = amplitude * stddev_A * np.sqrt(2 * np.pi)
            # Propagate error for the integrated flux
            flux_err = integrated_flux * np.sqrt((amp_err/amplitude)**2 + (stddev_err/stddev_A)**2)

            snr = integrated_flux / flux_err if flux_err > 0 else 0.0

            if snr >= SN_CUTOFF:
                # --- 4. Match the fitted line to the closest line in our catalog ---
                all_line_waves = np.array([line.value for line in EMISSION_LINES.values()])
                all_line_names = list(EMISSION_LINES.keys())

                wave_diff = np.abs(all_line_waves - center_A)
                closest_idx = np.argmin(wave_diff)

                if wave_diff[closest_idx] < WAVELENGTH_MATCH_TOLERANCE.value:
                    line_name = all_line_names[closest_idx]
                    # Avoid overwriting a stronger detection of the same line
                    if line_name not in line_fits or snr > line_fits[line_name]['snr']:
                        line_fits[line_name] = {
                            'flux': integrated_flux, 'flux_err': flux_err,
                            'snr': snr, 'fit_center_A': center_A
                        }
                        logger.info(f"    ✓ {line_name}: DETECTED with S/N = {snr:.2f}")

        except Exception as e:
            logger.debug(f"    - Fit failed for peak near {peak_wave:.1f} Å: {e}")
            continue

    return line_fits


def run_analysis_on_galaxy(spectrum_object, initial_z, galaxy_id, make_plots=True):
    """
    Top-level analysis function for a single galaxy.
    """
    z_corr = correct_redshift_xcorr(spectrum_object, initial_z, galaxy_id)

    wave_obs, flux_obs_fnu, err_obs_fnu = spectrum_object.spectral_axis, spectrum_object.flux, spectrum_object.uncertainty.quantity
    wave_rest, flux_rest_fnu, err_rest_fnu = wave_obs/(1+z_corr), flux_obs_fnu*(1+z_corr), err_obs_fnu*(1+z_corr)
    wave_rest_aa = wave_rest.to(u.AA)
    fnu_to_flambda_conv = (const.c / wave_rest_aa**2).to(u.erg/u.s/u.cm**2/u.AA/u.uJy)
    flux_flambda, err_flambda = flux_rest_fnu*fnu_to_flambda_conv, err_rest_fnu*fnu_to_flambda_conv
    spec_flam = Spectrum1D(
        flux=flux_flambda,
        spectral_axis=wave_rest_aa,
        uncertainty=StdDevUncertainty(err_flambda))

    try:
        spec_sub, continuum_flux = fit_continuum_stable(spec_flam)
        if make_plots:
            create_continuum_diagnostic_plot(galaxy_id, spec_flam, continuum_flux, spec_sub, PLOT_PATH)
    except Exception as e:
        logger.error(f"  Continuum fitting failed for GID {galaxy_id}: {e}. Skipping galaxy.")
        return None

    # Call the new "hunt-then-identify" line fitting function
    line_fits = find_and_fit_lines(spec_sub)

    if make_plots and line_fits:
        plot_and_save_results(galaxy_id, z_corr, spec_flam, line_fits, PLOT_PATH)

    final_results = {'galaxy_id': galaxy_id, 'redshift_initial': initial_z, 'redshift_corrected': z_corr}
    for line_name in EMISSION_LINES.keys():
        safe_name = line_name.replace(' ', '').replace('[','').replace(']','').replace(r'\\','').replace(r'$','').replace('-','').replace('Å','')
        if line_name in line_fits:
            final_results[f'{safe_name}_flux'] = line_fits[line_name]['flux']
            final_results[f'{safe_name}_flux_err'] = line_fits[line_name]['flux_err']
            final_results[f'{safe_name}_snr'] = line_fits[line_name]['snr']
            final_results[f'{safe_name}_fit_center_A'] = line_fits[line_name]['fit_center_A']
        else:
            for suffix in ['_flux', '_flux_err', '_snr', '_fit_center_A']:
                final_results[f'{safe_name}{suffix}'] = np.nan

    found_flag_lines = FLAG_LINES.intersection(line_fits.keys())
    if found_flag_lines:
        logger.warning(f"  >>> AGN/HIGH-IONIZATION SIGNATURE DETECTED! Flagged lines: {sorted(list(found_flag_lines))}")
    final_results['agn_flag'] = len(found_flag_lines) > 0

    return final_results


def create_continuum_diagnostic_plot(galaxy_id, spectrum, continuum, spec_sub, save_dir):
    """
    Creates and saves a plot to diagnose the continuum subtraction process.
    """
    # This function is working well and does not need changes.
    from astropy.visualization import quantity_support
    quantity_support()
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(18, 10), sharex=True,
                                   gridspec_kw={'height_ratios': [2, 1]})
    ax1.plot(spectrum.spectral_axis, spectrum.flux, color='grey', lw=1, label='Original Spectrum')
    ax1.plot(spectrum.spectral_axis, continuum, color='red', linestyle='--', lw=2, label='Continuum Model')
    ax1.set_title(f'Continuum Subtraction Diagnostic for Galaxy {galaxy_id}', fontsize=16)
    ax1.legend()
    ax2.plot(spec_sub.spectral_axis, spec_sub.flux, color='black', lw=1, label='Continuum-Subtracted Flux')
    ax2.fill_between(spec_sub.spectral_axis, -spec_sub.uncertainty.quantity, spec_sub.uncertainty.quantity,
                     color='cyan', alpha=0.5, label='1-sigma Noise')
    ax2.axhline(0, color='red', linestyle=':', lw=1)
    ax2.set_ylabel('Subtracted Flux')
    ax2.legend()
    plot_std = np.nanmedian(spec_sub.uncertainty.array) * 10
    if np.isfinite(plot_std) and plot_std > 0:
        ax2.set_ylim(-plot_std, plot_std)
    plt.tight_layout()
    plt.subplots_adjust(hspace=0.05)
    plot_filename = os.path.join(save_dir, f'diagnostic_continuum_{galaxy_id}.png')
    plt.savefig(plot_filename, dpi=150)
    plt.close(fig)
    logger.info(f"  → Continuum diagnostic plot saved to {os.path.basename(plot_filename)}")


def plot_and_save_results(galaxy_id, redshift, spectrum, line_fits, save_dir):
    """
    Generates and saves a final plot showing the spectrum and detected lines.
    """
    # This function is working well and does not need changes.
    from astropy.visualization import quantity_support
    quantity_support()

    max_flux_value = np.max(spectrum.flux[np.isfinite(spectrum.flux.value)])
    if max_flux_value == 0: # Handle edge case of all-zero flux
        exponent = 0
    else:
        exponent = int(np.floor(np.log10(abs(max_flux_value.value))))

    scaling_factor = 10**exponent
    scale_flux = spectrum.flux / scaling_factor
    scale_err = spectrum.uncertainty.quantity / scaling_factor
    spectrum = Spectrum1D(
        flux=scale_flux,
        spectral_axis=spectrum.spectral_axis,
        uncertainty=StdDevUncertainty(scale_err))

    fig, ax = plt.subplots(figsize=(18, 7))
    ax.plot(spectrum.spectral_axis, spectrum.flux, label='Spectrum (Rest-frame)', color='black', lw=0.8)

    for line_name, fit_info in line_fits.items():
        center_wave = fit_info['fit_center_A'] * u.AA
        ax.axvline(center_wave.value, color='red', linestyle='--', alpha=0.7,
                    label=f"{line_name} (S/N: {fit_info['snr']:.1f})")

    ax.set_xlabel(r'$\lambda$ [$\AA$]', fontsize=14)
    y_label_string = rf'$F_{{\lambda}}$ [$10^{{{exponent}}}$ erg s$^{{-1}}$ cm$^{{-2}}$ $\AA^{{-1}}$]'

    ax.set_ylabel(y_label_string, fontsize=14)
    ax.set_title(f'Spectrum and Detected Lines for Galaxy {galaxy_id} (z={redshift:.4f})', fontsize=14)

    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc='upper right')

    ax.set_xlim(np.nanmin(spectrum.spectral_axis.value), np.nanmax(spectrum.spectral_axis.value))
    med_flux, std_flux = np.nanmedian(spectrum.flux.value), np.nanstd(spectrum.flux.value)
    if np.isfinite(med_flux) and np.isfinite(std_flux):
        ax.set_ylim(med_flux - 2 * std_flux, med_flux + 7 * std_flux)

    plt.tight_layout()
    plot_filename = os.path.join(save_dir, f'spectrum_{galaxy_id}.png')
    plt.savefig(plot_filename, dpi=150)
    plt.close(fig)
    logger.info(f"  → Final results plot saved to {os.path.basename(plot_filename)}")

In [15]:
def main():
    """
    Main function to run the entire analysis pipeline for JADES data.
    """
    logger.info("====================" * 6)
    logger.info("--- Starting JADES Main Analysis Pipeline ---")
    os.makedirs(PLOT_PATH, exist_ok=True)

    try:
        df_redshifts = pd.read_csv(TARGET_LIST_PATH)
        df_redshifts['galaxy_id'] = df_redshifts['galaxy_id'].astype(int)
        logger.info(f"Master redshift catalog loaded with {len(df_redshifts)} entries.")
    except Exception as e:
        logger.critical(f"Failed to load or parse target list '{TARGET_LIST_PATH}': {e}")
        return

    processed_spectra = combine_gratings_specutils(SAVE_PATH)
    if not processed_spectra:
        logger.critical("No spectra were processed. Halting.")
        return

    all_galaxy_results = []
    for galaxy_id_str, spec_info in processed_spectra.items():
        logger.info("-------------------------------------------------")

        try:
            galaxy_id = int(galaxy_id_str)
        except ValueError:
            logger.warning(f"Could not convert GID '{galaxy_id_str}' to integer. Skipping.")
            continue

        logger.info(f"Processing Galaxy ID: {galaxy_id}")

        z_rows = df_redshifts[df_redshifts['galaxy_id'] == galaxy_id]

        if z_rows.empty:
            logger.warning(f"  Could not find redshift for GID {galaxy_id} in the master catalog. Skipping.")
            continue

        redshift = z_rows['z'].iloc[0]

        galaxy_data_row = run_analysis_on_galaxy(spec_info['spectrum'], redshift, galaxy_id)
        if galaxy_data_row:
            all_galaxy_results.append(galaxy_data_row)

    if all_galaxy_results:
        logger.info("====================" * 6)
        logger.info("--- Saving Final Flux Catalog ---")
        results_df = pd.DataFrame(all_galaxy_results)
        cols = ['galaxy_id', 'redshift_initial', 'redshift_corrected', 'agn_flag'] + sorted(
            [c for c in results_df.columns if c not in ['galaxy_id', 'redshift_initial', 'redshift_corrected', 'agn_flag']])
        results_df = results_df[cols]
        try:
            results_df.to_csv(OUTPUT_CATALOG_PATH, index=False, float_format='%.5e')
            logger.info(f"SUCCESS: Final catalog saved to: {OUTPUT_CATALOG_PATH}")
        except Exception as e:
            logger.critical(f"Failed to save the final catalog. Reason: {e}")
    else:
        logger.warning("No galaxies were processed successfully. Final catalog was not created.")

    logger.info("====================" * 6)
    logger.info("!!!! ANALYSIS COMPLETE !!!!")

In [None]:
# === Driver ===
if __name__ == "__main__":
    start = time.time()
    main()
    logger.info(f"Total time taken: {time.time() - start:.2f} seconds")

--- Starting JADES Main Analysis Pipeline ---
Master redshift catalog loaded with 4424 entries.
--- Beginning grating combination for JADES ---
Found 4579 total '.spec.fits' files across all subdirectories.
Total unique galaxy IDs found: 1276
--- Combination Complete ---
Successfully processed and stored data for 1276 galaxies.
-------------------------------------------------
Processing Galaxy ID: -1
  Could not find redshift for GID -1 in the master catalog. Skipping.
-------------------------------------------------
Processing Galaxy ID: -2
  Could not find redshift for GID -2 in the master catalog. Skipping.
-------------------------------------------------
Processing Galaxy ID: -3
  Could not find redshift for GID -3 in the master catalog. Skipping.
-------------------------------------------------
Processing Galaxy ID: 10004141
  Cross-correlation failed for all templates. Using initial redshift.
  → Continuum diagnostic plot saved to diagnostic_continuum_10004141.png
  Found 58 