# Geko Fitting Demo: Simplified Usage

This notebook demonstrates how to use `geko` to fit JWST grism spectroscopy data using a simple Python interface.

## Overview

`geko` is a Python package for analyzing JWST grism spectroscopy and morphology data. It uses:
- **JAX** for accelerated numerical computation
- **NumPyro** for Bayesian inference via MCMC
- **Kinematic models** to fit galaxy rotation curves

## Required Data Structure

All data files should be organized in a single base directory (specified by `save_runs_path`). The code expects the following structure:

### Directory Structure

```
<save_runs_path>/                     # Base directory (e.g., 'fitting_results/')
├── <output_name>/                    # Subfolder for your specific galaxy/run
│   ├── spec_2d_*_ID<ID>_comb.fits    # 2D grism spectrum (required)
│   └── <ID>_output                   # Fit results (generated after run)
├── morph_fits/                       # Morphology results directory
│   └── summary_<ID>_image_F150W_svi.cat  # PySersic Sersic profile fits
├── psfs/                             # PSF files directory
│   ├── mpsf_jw018950.gn.f444w.fits   # PSF for GOODS-N
│   ├── mpsf_jw035770.f356w.fits      # PSF for GOODS-N-CONGRESS
│   ├── mpsf_jw018950.gs.f444w.fits   # PSF for GOODS-S-FRESCO
│   └── <your_custom_psf>.fits        # Your custom PSF (if using field='manual')
└── catalogs/                         # Optional: catalog directory
    └── <master_catalog>.cat          # Master catalog (can be anywhere)
```

**Example with predefined field**: If you set:
- `save_runs_path='/path/to/data/'`
- `output_name='my_galaxy'`
- `source_id=12345`
- `field='GOODS-N'`

The code will look for:
- Grism spectrum: `/path/to/data/my_galaxy/spec_2d_GDN_F444W_ID12345_comb.fits`
- Morphology: `/path/to/data/morph_fits/summary_12345_image_F150W_svi.cat`
- PSF: `/path/to/data/psfs/mpsf_jw018950.gn.f444w.fits`
- Results saved to: `/path/to/data/my_galaxy/12345_output`

**Example with manual field**: If you set:
- `save_runs_path='/path/to/data/'`
- `output_name='my_galaxy'`
- `source_id=12345`
- `field='manual'`
- `manual_grism_file='my_grism_spectrum.fits'`
- `manual_psf_name='my_psf.fits'`
- `manual_pysersic_file='my_morphology.cat'`
- `manual_theta_rot=45.0`

The code will look for:
- Grism spectrum: `/path/to/data/my_galaxy/my_grism_spectrum.fits`
- Morphology: `/path/to/data/morph_fits/my_morphology.cat`
- PSF: `/path/to/data/psfs/my_psf.fits`
- Rotation angle: 45.0 degrees
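
The layout rules above can be captured in a small path helper, which is handy for checking that every input is in place before starting a fit. This is a minimal sketch and not part of `geko`: the names `expected_paths` and `FIELD_FILES` are hypothetical, and the filenames are copied from the conventions documented in this notebook. For `field='manual'` it assumes all three manual filenames are supplied.

```python
from pathlib import Path

# Field-specific (grism pattern, PSF filename) pairs, copied from this notebook.
# FIELD_FILES and expected_paths are illustrative names, not geko API.
FIELD_FILES = {
    "GOODS-N": ("spec_2d_GDN_F444W_ID{id}_comb.fits", "mpsf_jw018950.gn.f444w.fits"),
    "GOODS-N-CONGRESS": ("spec_2d_GDN_F356W_ID{id}_comb.fits", "mpsf_jw035770.f356w.fits"),
    "GOODS-S-FRESCO": ("spec_2d_FRESCO_F444W_ID{id}_comb.fits", "mpsf_jw018950.gs.f444w.fits"),
}


def expected_paths(save_runs_path, output_name, source_id, field,
                   manual_grism_file=None, manual_psf_name=None,
                   manual_pysersic_file=None):
    """Return the input/output paths implied by the documented layout."""
    base = Path(save_runs_path)
    if field == "manual":
        # Manual mode: the caller supplies every filename explicitly.
        grism, psf, morph = manual_grism_file, manual_psf_name, manual_pysersic_file
    else:
        grism_pattern, psf = FIELD_FILES[field]
        grism = grism_pattern.format(id=source_id)
        morph = f"summary_{source_id}_image_F150W_svi.cat"
    return {
        "grism": base / output_name / grism,
        "morphology": base / "morph_fits" / morph,
        "psf": base / "psfs" / psf,
        "results": base / output_name / f"{source_id}_output",
    }


paths = expected_paths("/path/to/data/", "my_galaxy", 12345, "GOODS-N")
for name, path in paths.items():
    print(f"{name:>10}: {path}  (exists: {path.exists()})")
```

Printing the `exists` flag for each path gives a quick pre-flight check without importing `geko` or JAX.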

### Required Files

#### 1. Master Catalog File

An ASCII table of source properties. Pass its path via the `master_cat` parameter; the file can be located anywhere.

Required columns:
- `ID`: Source identifier (must match your `source_id`)
- `zspec`: Spectroscopic redshift
- `<line>_lambda`: Observed wavelength of the emission line (e.g., `H_alpha_lambda` for H-alpha, whose rest-frame wavelength is 6562.8 Å)
- `fit_flux_cgs`: Log of integrated emission line flux (log erg/s/cm²)
- `fit_flux_cgs_e`: Error on log flux
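
A quick way to catch a malformed catalog early is to compare its column names against the list above. The helper below is a hypothetical sketch (`missing_catalog_columns` is not a `geko` function) and works on a plain list of names, so it runs without any catalog file.

```python
def missing_catalog_columns(colnames, line):
    """Return the required master-catalog columns that are absent from `colnames`."""
    required = ["ID", "zspec", f"{line}_lambda", "fit_flux_cgs", "fit_flux_cgs_e"]
    return [name for name in required if name not in colnames]


# With astropy, the names could come from a real file, for example:
#   colnames = Table.read(master_catalog, format="ascii").colnames
colnames = ["ID", "zspec", "H_alpha_lambda", "fit_flux_cgs"]
print(missing_catalog_columns(colnames, "H_alpha"))  # prints ['fit_flux_cgs_e']
```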

#### 2. Grism Spectrum FITS File

Located at: `<save_runs_path>/<output_name>/spec_2d_*_ID<source_id>_comb.fits` (or custom filename if `field='manual'`)

File naming convention depends on the field:
- **GOODS-N**: `spec_2d_GDN_F444W_ID<source_id>_comb.fits`
- **GOODS-N-CONGRESS**: `spec_2d_GDN_F356W_ID<source_id>_comb.fits`
- **GOODS-S-FRESCO**: `spec_2d_FRESCO_F444W_ID<source_id>_comb.fits`
- **manual**: Specify your own filename via `manual_grism_file` parameter

The FITS file should contain:
- Extension 0: 2D spectrum data (flux vs wavelength and spatial position)
- Extension 1: Error/uncertainty map
- WCS information for wavelength calibration

#### 3. PySersic Morphology File

Located at: `<save_runs_path>/morph_fits/summary_<source_id>_image_F150W_svi.cat` (or custom if `field='manual'`)

ASCII catalog from [PySersic](https://github.com/astropath/pysersic) fits containing:
- Sersic index (n)
- Effective radius (r_eff)
- Position angle (PA)
- Axis ratio (q)
- Centroid positions (x0, y0)

These morphological parameters are used to set priors for the kinematic fitting.

#### 4. PSF Files

Located at: `<save_runs_path>/psfs/mpsf_*.fits` (or custom if `field='manual'`)

Field-specific point spread function FITS files:
- `mpsf_jw018950.gn.f444w.fits` for GOODS-N
- `mpsf_jw035770.f356w.fits` for GOODS-N-CONGRESS  
- `mpsf_jw018950.gs.f444w.fits` for GOODS-S-FRESCO
- Or specify your own via `manual_psf_name` parameter if `field='manual'`

The code automatically selects the appropriate PSF based on the `field` parameter, or uses your specified PSF if `field='manual'`.

#### 5. Rotation Angle (theta)

The rotation angle aligns the morphological model (from imaging) with the grism orientation:
- **GOODS-N**: 230.5098 degrees
- **GOODS-N-CONGRESS**: 228.22379 degrees
- **GOODS-S-FRESCO**: 0.0 degrees
- **manual**: Specify via `manual_theta_rot` parameter

## Field Options

You can use either predefined fields or manual mode:

### Predefined Fields
- `field='GOODS-N'`
- `field='GOODS-N-CONGRESS'`
- `field='GOODS-S-FRESCO'`

These automatically select the appropriate PSF, file naming convention, and rotation angle.

### Manual Field
- `field='manual'`

When using manual mode, you must provide:
- `manual_psf_name`: PSF filename (in `psfs/` directory)
- `manual_theta_rot`: Rotation angle in degrees
- `manual_pysersic_file`: PySersic results filename (in `morph_fits/` directory)
- `manual_grism_file`: Grism spectrum filename (in the `<output_name>/` directory)
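
Because manual mode needs four extra arguments, it can help to verify them before launching a long fit. The function below is an illustrative sketch, not `geko` code; `missing_manual_parameters` is a hypothetical name. It tests against `None` rather than truthiness because `0.0` is a legitimate rotation angle (the GOODS-S-FRESCO value).

```python
MANUAL_PARAMETERS = (
    "manual_psf_name",
    "manual_theta_rot",
    "manual_pysersic_file",
    "manual_grism_file",
)


def missing_manual_parameters(field, **manual):
    """List the manual_* arguments that are still unset when field='manual'."""
    if field != "manual":
        return []  # predefined fields need no manual_* arguments
    # Compare with None so that manual_theta_rot=0.0 counts as provided.
    return [name for name in MANUAL_PARAMETERS if manual.get(name) is None]


print(missing_manual_parameters("manual", manual_psf_name="my_psf.fits",
                                manual_theta_rot=0.0))
# prints ['manual_pysersic_file', 'manual_grism_file']
```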

## Running the Fit

Once you have prepared all required files, running the fit is straightforward:

In [None]:
# Import required modules
from geko.fitting import run_geko_fit
from geko.config import FitConfiguration

# JAX configuration
import jax
jax.config.update('jax_enable_x64', True)

print("Imports successful!")

### Basic Usage

In [None]:
# Define parameters
source_id = 191250                      # Source ID in your catalog
field = 'manual'                        # Field name: use 'manual' to specify custom files
output_name = 'my_galaxy'              # Name of output folder
master_catalog = '/Users/lola/ASTRO/JWST/grism_project/testing_geko_demo/catalogs/my_galaxies_cat'  # Path to master catalog
emission_line = 'H_alpha'              # Emission line name (prefix of the '<line>_lambda' catalog column)
parametric = True                      # Use parametric Sersic morphology
save_runs_path = '/Users/lola/ASTRO/JWST/grism_project/testing_geko_demo/'    # Where to save results

# Manual field parameters (required when field='manual')
manual_psf_name = 'mpsf_jw018950.gs.f444w.fits'           # PSF filename in psfs/ directory
manual_theta_rot = 0.0                                     # Rotation angle in degrees (GOODS-S value)
manual_pysersic_file = 'summary_191250_image_F150W_svi.cat'  # PySersic file in morph_fits/
manual_grism_file = 'spec_2d_FRESCO_F444W_ID191250_comb.fits'  # Grism spectrum in <output_name>/ directory

# Optional parameters (with defaults)
grism_filter = 'F444W'                 # Grism filter
delta_wave_cutoff = 0.02             # Wavelength bin size (microns)
factor = 3                             # Spatial oversampling factor
wave_factor = 4                       # Wavelength oversampling factor
model_name = 'Disk'                    # Kinematic model type

# MCMC parameters (values this small only serve a quick demo; increase them for a real fit)
num_chains = 1                       # Number of MCMC chains
num_warmup = 5                       # Warmup iterations
num_samples = 20                     # Sampling iterations

In [None]:
# Run the fit with manual field option
inference_data = run_geko_fit(
    output=output_name,
    master_cat=master_catalog,
    line=emission_line,
    parametric=parametric,
    save_runs_path=save_runs_path,
    num_chains=num_chains,
    num_warmup=num_warmup,
    num_samples=num_samples,
    source_id=source_id,
    field=field,                        # 'manual' mode
    grism_filter=grism_filter,
    delta_wave_cutoff=delta_wave_cutoff,
    factor=factor,
    wave_factor=wave_factor,
    model_name=model_name,
    config=None,                        # Optional: custom configuration
    # Manual field parameters (required when field='manual')
    manual_psf_name=manual_psf_name,
    manual_theta_rot=manual_theta_rot,
    manual_pysersic_file=manual_pysersic_file,
    manual_grism_file=manual_grism_file
)

## Understanding the Output

After the fit completes, `geko` saves several output files in the directory:
`<save_runs_path>/<output_name>/`

**All output files are named using the source ID**, not the folder name. This allows you to use
any folder name (like 'my_galaxy') while keeping files organized by source ID.

### Output Files

1. **`<source_id>_output`** (NetCDF file)
   - Contains the full MCMC posterior samples and prior samples
   - Can be loaded with `arviz.InferenceData.from_netcdf()`
   - Includes all fitted parameters: Va, sigma0, PA, inc, r_eff, n, xc, yc, amplitude

2. **`<source_id>_results`** (ASCII table)
   - Summary statistics for all fitted parameters
   - Contains median values and 16th/84th percentiles (1-sigma uncertainties)
   - Includes derived quantities like v_re (velocity at effective radius)

3. **`<source_id>_summary.png`** (Diagnostic plot)
   - Multi-panel figure showing:
     - Observed 2D spectrum
     - Best-fit model spectrum
     - Residuals (observed - model)
     - 1D velocity and dispersion profiles
     - Rotation curve and fit parameters

4. **`<source_id>_v_sigma_corner.png`** (Corner plot)
   - Posterior distributions for v/σ ratio
   - Shows correlations between velocity and dispersion
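
The median and 16th/84th percentile summary described above is straightforward to reproduce from posterior draws. The sketch below uses synthetic draws so that it runs on its own; in practice the arrays would come from the `InferenceData` object. Using `Va` as the velocity in the ratio is an assumption made for illustration, since this notebook does not state which velocity enters the v/σ plot.

```python
import numpy as np

# Synthetic stand-ins for posterior draws; real draws could be extracted with
# e.g. inference_data.posterior["Va"].values.ravel()
rng = np.random.default_rng(0)
va_samples = rng.normal(200.0, 20.0, size=4000)      # km/s
sigma0_samples = rng.normal(50.0, 5.0, size=4000)    # km/s


def median_and_errors(samples):
    """Median with lower/upper 1-sigma errors from the 16th/84th percentiles."""
    p16, p50, p84 = np.percentile(samples, [16, 50, 84])
    return p50, p50 - p16, p84 - p50


# Form the ratio draw by draw so that any correlation between the two
# parameters carries through to the derived quantity.
v_over_sigma = va_samples / sigma0_samples

for label, samples in [("Va", va_samples), ("v/sigma", v_over_sigma)]:
    med, lo, hi = median_and_errors(samples)
    print(f"{label}: {med:.2f} -{lo:.2f} +{hi:.2f}")
```

Computing the ratio per draw, rather than dividing the two medians, is what lets the reported interval reflect the joint posterior.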

Let's load and examine these outputs:

In [None]:
import arviz as az
import matplotlib.pyplot as plt
from IPython.display import Image, display
import os

# Construct output directory path
output_dir = os.path.join(save_runs_path, output_name)

# Summarize the inference data returned by run_geko_fit (no need to reload it from disk)
print("=" * 60)
print("MCMC Summary Statistics")
print("=" * 60)
print(az.summary(inference_data, hdi_prob=0.68))  # 1-sigma (68%) credible intervals

In [None]:
# Load and display the fit results table
import numpy as np
from astropy.table import Table

fit_results_file = os.path.join(output_dir, f'{source_id}_results')
if os.path.exists(fit_results_file):
    fit_results = Table.read(fit_results_file, format='ascii')
    print("\n" + "=" * 60)
    print("Fit Results Summary")
    print("=" * 60)
    print(fit_results)
else:
    print(f"Fit results file not found: {fit_results_file}")

In [None]:
# Display the summary diagnostic plot
summary_plot = os.path.join(output_dir, f'{source_id}_summary.png')
if os.path.exists(summary_plot):
    print("\n" + "=" * 60)
    print("Model Fit Summary Plot")
    print("=" * 60)
    display(Image(filename=summary_plot, width=800))
else:
    print(f"Summary plot not found: {summary_plot}")

In [None]:
# Display the v/sigma corner plot
corner_plot = os.path.join(output_dir, f'{source_id}_v_sigma_corner.png')
if os.path.exists(corner_plot):
    print("\n" + "=" * 60)
    print("v/σ Posterior Distribution")
    print("=" * 60)
    display(Image(filename=corner_plot, width=600))
else:
    print(f"Corner plot not found: {corner_plot}")

In [None]:
# Create additional diagnostic plots with arviz
print("\n" + "=" * 60)
print("MCMC Trace Plots")
print("=" * 60)

# Trace plots for key parameters
az.plot_trace(inference_data, var_names=['Va', 'sigma0', 'PA', 'inc'])
plt.tight_layout()
plt.show()

In [None]:
# Posterior distributions
print("\n" + "=" * 60)
print("Posterior Distributions")
print("=" * 60)

az.plot_posterior(inference_data, var_names=['Va', 'sigma0', 'v_re'], 
                  hdi_prob=0.68, point_estimate='median')
plt.tight_layout()
plt.show()

## Key Output Parameters

The most important fitted parameters are:

**Kinematic Parameters:**
- `Va`: Asymptotic rotation velocity (km/s)
- `sigma0`: Central velocity dispersion (km/s)
- `v_re`: Rotation velocity at the effective radius (km/s) - derived quantity
- `r_t`: Turnover radius (pixels)
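
To see how `Va`, `r_t`, and `v_re` relate, consider an arctangent rotation curve. This is an assumption made purely for illustration: the notebook does not state which functional form the `Disk` model uses, `arctan_velocity` is a hypothetical helper rather than a `geko` function, and the sketch ignores inclination projection.

```python
import math


def arctan_velocity(r, Va, r_t):
    """Arctangent rotation curve: equals Va/2 at r = r_t and tends to Va at large r."""
    return (2.0 / math.pi) * Va * math.atan(r / r_t)


Va, r_t, r_eff = 200.0, 2.0, 3.0      # km/s, pixels, pixels (made-up values)
v_re = arctan_velocity(r_eff, Va, r_t)
print(f"v(r_eff) = {v_re:.1f} km/s, compared with Va = {Va:.1f} km/s")
```

Under this form, `v_re` always lies below `Va`, and the gap shrinks as `r_eff` grows relative to `r_t`.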

**Morphological Parameters:**
- `PA`: Position angle (degrees)
- `inc`: Inclination angle (degrees)
- `r_eff`: Effective radius (pixels)
- `n`: Sersic index
- `xc`, `yc`: Centroid coordinates (pixels)
- `amplitude`: Flux normalization

**Quality Metrics:**
- `r_hat`: Gelman-Rubin convergence diagnostic (should be < 1.01)
- `ess_bulk`: Effective sample size for bulk of distribution
- `ess_tail`: Effective sample size for tails of distribution
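
The idea behind `r_hat` is to compare the variance within each chain to the variance between chains. The sketch below implements the classic Gelman-Rubin statistic with NumPy on synthetic chains. ArviZ reports a rank-normalized split variant by default, so its numbers will differ somewhat; the function name `classic_rhat` is hypothetical. Because the classic statistic compares chains, it needs at least two of them, unlike the `num_chains = 1` setting used in the quick demo above.

```python
import numpy as np


def classic_rhat(chains):
    """Classic Gelman-Rubin statistic for draws of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()            # W: mean within-chain variance
    between = n_draws * chains.mean(axis=1).var(ddof=1)   # B: between-chain variance
    pooled = (n_draws - 1) / n_draws * within + between / n_draws
    return float(np.sqrt(pooled / within))


rng = np.random.default_rng(1)
mixed = rng.normal(0.0, 1.0, size=(4, 1000))             # four chains that agree
stuck = mixed + np.array([[0.0], [0.0], [0.0], [5.0]])   # one chain offset from the rest

print(f"mixed chains: {classic_rhat(mixed):.3f}")
print(f"offset chain: {classic_rhat(stuck):.3f}")
```

Chains that sample the same distribution give a value close to 1, while a chain stuck in a different region pushes the statistic well above the 1.01 threshold.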

## Advanced: Using Custom Configuration

You can set custom priors using the `FitConfiguration` class. 

**A configuration contains:**
- **Morphology priors**: No defaults; they must come from PySersic or be specified manually
- **Kinematic priors**: Have defaults that can be overridden
- **MCMC settings**: Have defaults that can be overridden

**Scenario 1: You have PySersic fits (typical)**
- PySersic priors are loaded automatically for morphology
- You can optionally provide a config to override kinematic priors (Va, sigma0 ranges)
- Morphology stays from PySersic unless you explicitly set it in config

**Scenario 2: You don't have PySersic fits (rare)**
- You **must** provide a config with all morphology priors explicitly set
- An error is raised if any morphology prior is missing
- You can still override kinematic priors if desired

In [None]:
# Example 1: Override kinematic priors (you have PySersic fits)
# Only override the kinematic priors, keep PySersic morphology
from geko.config import FitConfiguration, KinematicPriors

config = FitConfiguration(
    kinematics=KinematicPriors(
        Va_min=50.0,        # Minimum asymptotic velocity (km/s)
        Va_max=300.0,       # Maximum asymptotic velocity (km/s)
        sigma0_min=10.0,    # Minimum velocity dispersion (km/s)
        sigma0_max=150.0    # Maximum velocity dispersion (km/s)
    )
)

# Morphology is None - will use PySersic values automatically
print("Kinematic priors will be overridden, morphology will come from PySersic")
config.print_summary()

# Run fit - PySersic morphology + custom kinematic priors
# inference_data_custom = run_geko_fit(
#     output=output_name,
#     master_cat=master_catalog,
#     line=emission_line,
#     parametric=parametric,
#     save_runs_path=save_runs_path,
#     num_chains=num_chains,
#     num_warmup=num_warmup,
#     num_samples=num_samples,
#     source_id=source_id,
#     field=field,
#     config=config,  # Kinematic override only
#     manual_psf_name=manual_psf_name,
#     manual_theta_rot=manual_theta_rot,
#     manual_pysersic_file=manual_pysersic_file,
#     manual_grism_file=manual_grism_file
# )

In [None]:
# Example 2: Complete config with morphology (no PySersic available)
# Set all morphology manually but keep the default kinematic priors
from geko.config import MorphologyPriors

config_full = FitConfiguration(
    morphology=MorphologyPriors(
        # Position angle (degrees) - normal prior
        PA_mean=90.0,
        PA_std=30.0,
        # Inclination (degrees) - truncated normal prior
        inc_mean=55.0,
        inc_std=15.0,
        # Effective radius (pixels) - truncated normal
        r_eff_mean=3.0,
        r_eff_std=1.0,
        r_eff_min=0.5,
        r_eff_max=10.0,
        # Sersic index - truncated normal
        n_mean=1.0,
        n_std=0.5,
        n_min=0.5,
        n_max=4.0,
        # Central coordinates (pixels) - normal
        xc_mean=0.0,
        xc_std=2.0,
        yc_mean=0.0,
        yc_std=2.0,
        # Amplitude - truncated normal
        amplitude_mean=100.0,
        amplitude_std=50.0,
        amplitude_min=1.0,
        amplitude_max=1000.0
    )
)

print("\nComplete config set - can run without PySersic file")
config_full.print_summary()

# This would work even without a PySersic file
# Run fit with complete morphology config
# inference_data_full = run_geko_fit(
#     output=output_name,
#     master_cat=master_catalog,
#     line=emission_line,
#     parametric=parametric,
#     save_runs_path=save_runs_path,
#     num_chains=num_chains,
#     num_warmup=num_warmup,
#     num_samples=num_samples,
#     source_id=source_id,
#     field=field,
#     config=config_full,  # Complete morphology config
#     manual_psf_name=manual_psf_name,
#     manual_theta_rot=manual_theta_rot,
#     manual_pysersic_file=manual_pysersic_file,
#     manual_grism_file=manual_grism_file
# )