# Geko Fitting Demo: Simplified Usage

This notebook demonstrates how to use `geko` to fit JWST grism spectroscopy data using a simple Python interface.

## Overview

`geko` is a Python package for analyzing JWST grism spectroscopy and morphology data. It uses:
- **JAX** for accelerated numerical computation
- **Numpyro** for Bayesian inference via MCMC
- **Kinematic models** to fit galaxy rotation curves

## Required Data Structure

All data files should be organized in a single base directory (specified by `save_runs_path`). The code expects the following structure:

### Directory Structure

```
<save_runs_path>/                     # Base directory (e.g., 'fitting_results/')
├── <output_name>/                    # Subfolder for your specific galaxy/run
│   ├── spec_2d_*_ID<ID>_comb.fits    # 2D grism spectrum (required)
│   └── <ID>_output                   # Fit results (generated after run)
├── morph_fits/                       # Morphology results directory
│   └── summary_<ID>_image_F150W_svi.cat  # PySersic Sersic profile fits
├── psfs/                             # PSF files directory
│   ├── mpsf_jw018950.gn.f444w.fits   # PSF for GOODS-N
│   ├── mpsf_jw035770.f356w.fits      # PSF for GOODS-N-CONGRESS
│   └── mpsf_jw018950.gs.f444w.fits   # PSF for GOODS-S-FRESCO
└── catalogs/                         # Optional: catalog directory
    └── <master_catalog>.cat          # Master catalog (can be anywhere)
```

**Example**: If you set:
- `save_runs_path='/path/to/data/'`
- `output_name='my_galaxy'`
- `source_id=12345`
- `field='GOODS-N'`

The code will look for:
- Grism spectrum: `/path/to/data/my_galaxy/spec_2d_GDN_F444W_ID12345_comb.fits`
- Morphology: `/path/to/data/morph_fits/summary_12345_image_F150W_svi.cat`
- PSF: `/path/to/data/psfs/mpsf_jw018950.gn.f444w.fits`
- Results saved to: `/path/to/data/my_galaxy/12345_output`
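
Before launching a fit, it can help to confirm that every input sits where the layout above says it should. The sketch below rebuilds the expected paths from the naming conventions listed in this notebook. The `FIELD_FILES` dictionary and the `expected_paths` helper are hypothetical names introduced for illustration; they are not part of the `geko` API.

```python
from pathlib import Path

# File names copied from the conventions in this notebook.
# The dictionary itself is illustrative and not part of geko.
FIELD_FILES = {
    "GOODS-N": ("spec_2d_GDN_F444W_ID{id}_comb.fits", "mpsf_jw018950.gn.f444w.fits"),
    "GOODS-N-CONGRESS": ("spec_2d_GDN_F356W_ID{id}_comb.fits", "mpsf_jw035770.f356w.fits"),
    "GOODS-S-FRESCO": ("spec_2d_FRESCO_F444W_ID{id}_comb.fits", "mpsf_jw018950.gs.f444w.fits"),
}


def expected_paths(save_runs_path, output_name, source_id, field, morph_filter="F150W"):
    """Return the paths implied by the directory layout described above."""
    base = Path(save_runs_path)
    spec_name, psf_name = FIELD_FILES[field]
    return {
        "spectrum": base / output_name / spec_name.format(id=source_id),
        "morphology": base / "morph_fits" / f"summary_{source_id}_image_{morph_filter}_svi.cat",
        "psf": base / "psfs" / psf_name,
        "results": base / output_name / f"{source_id}_output",
    }


paths = expected_paths("/path/to/data/", "my_galaxy", 12345, "GOODS-N")
for label, path in paths.items():
    # The results entry is expected to be missing before the first run.
    status = "found" if path.exists() else "missing"
    print(f"{label:>10}: {path} ({status})")
```

Any input reported as `missing` is worth fixing before starting a long MCMC run.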

### Required Files

#### 1. Master Catalog File

An ASCII table containing source properties. Pass its path as the `master_cat` parameter; the file can be located anywhere.

Required columns:
- `ID`: Source identifier (must match your `source_id`)
- `zspec`: Spectroscopic redshift
- `<line>_lambda`: Observed wavelength of the emission line (e.g., `H_alpha_lambda` for H-alpha, whose rest-frame wavelength is 6562.8 Å)
- `fit_flux_cgs`: Log of integrated emission line flux (log erg/s/cm²)
- `fit_flux_cgs_e`: Error on log flux
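
A quick header check can catch a missing column before the fit starts. The sketch below assumes the catalog is whitespace-delimited with the column names on the first line (optionally prefixed by `#`); the actual catalog format may differ. The `missing_columns` helper and the sample rows are made up for illustration.

```python
import io

# Columns listed as required above; the line-specific column is added per call.
REQUIRED = ["ID", "zspec", "fit_flux_cgs", "fit_flux_cgs_e"]


def missing_columns(table_file, line="H_alpha"):
    """Return the required column names that are absent from the header line."""
    header = table_file.readline().lstrip("#").split()
    required = REQUIRED + [f"{line}_lambda"]
    return [name for name in required if name not in header]


# Made-up catalog text for illustration only (values are placeholders).
example = io.StringIO(
    "# ID zspec H_alpha_lambda fit_flux_cgs\n"
    "12345 5.2 4.069 -17.1\n"
)
print(missing_columns(example))  # prints ['fit_flux_cgs_e']
```

For a real catalog, open the file with `open(master_catalog)` and pass the file object in place of the `StringIO` example.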

#### 2. Grism Spectrum FITS File

Located at: `<save_runs_path>/<output_name>/spec_2d_*_ID<source_id>_comb.fits`

File naming convention depends on the field:
- **GOODS-N**: `spec_2d_GDN_F444W_ID<source_id>_comb.fits`
- **GOODS-N-CONGRESS**: `spec_2d_GDN_F356W_ID<source_id>_comb.fits`
- **GOODS-S-FRESCO**: `spec_2d_FRESCO_F444W_ID<source_id>_comb.fits`

The FITS file should contain:
- Extension 0: 2D spectrum data (flux vs wavelength and spatial position)
- Extension 1: Error/uncertainty map
- WCS information for wavelength calibration

#### 3. PySersic Morphology File

Located at: `<save_runs_path>/morph_fits/summary_<source_id>_image_F150W_svi.cat` (or `F182M`)

ASCII catalog from [PySersic](https://github.com/astropath/pysersic) fits containing:
- Sersic index (n)
- Effective radius (r_eff)
- Position angle (PA)
- Axis ratio (q)
- Centroid positions (x0, y0)

These morphological parameters are used to set priors for the kinematic fitting.
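
As an illustration of how an imaging measurement can inform a kinematic prior, the sketch below converts an axis ratio `q` into an inclination using the standard relation for an oblate disk of intrinsic thickness `q0`, cos²(i) = (q² − q0²) / (1 − q0²). This notebook does not show how `geko` performs the conversion internally, so both the formula choice and the default `q0 = 0.2` are assumptions made for the example.

```python
import math


def inclination_deg(q, q0=0.2):
    """Inclination in degrees from axis ratio q, for an oblate disk of thickness q0.

    q0 = 0.2 is an assumed intrinsic thickness, not a value taken from geko.
    """
    cos2_i = (q**2 - q0**2) / (1.0 - q0**2)
    # Clip so that axis ratios at or below q0 are treated as edge-on.
    cos2_i = min(max(cos2_i, 0.0), 1.0)
    return math.degrees(math.acos(math.sqrt(cos2_i)))


for q in (1.0, 0.7, 0.4, 0.2):
    print(f"q = {q:.1f} -> i = {inclination_deg(q):.1f} deg")
```

A round image (`q = 1`) maps to a face-on disk and `q = q0` to an edge-on disk, which is why the axis ratio from the Sersic fit carries information about inclination.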

#### 4. PSF Files

Located at: `<save_runs_path>/psfs/mpsf_*.fits`

Field-specific point spread function FITS files:
- `mpsf_jw018950.gn.f444w.fits` for GOODS-N
- `mpsf_jw035770.f356w.fits` for GOODS-N-CONGRESS  
- `mpsf_jw018950.gs.f444w.fits` for GOODS-S-FRESCO

The code automatically selects the appropriate PSF based on the `field` parameter.

## Running the Fit

Once you have prepared all required files, running the fit is straightforward:

In [1]:
# Import required modules
from geko.fitting import run_geko_fit
from geko.config import FitConfiguration

# JAX configuration
import jax
jax.config.update('jax_enable_x64', True)

print("Imports successful!")

Imports successful!


### Basic Usage

In [2]:
# Define parameters
source_id = 191250                     # Source ID in your catalog
field = 'GOODS-S-FRESCO'               # Field name
output_name = 'my_galaxy'              # Name of output folder
master_catalog = '/Users/lola/ASTRO/JWST/grism_project/testing_geko_demo/catalogs/my_galaxies_cat'  # Path to master catalog
emission_line = 'H_alpha'              # Emission line name (prefix of the <line>_lambda catalog column)
parametric = True                      # Use parametric Sersic morphology
save_runs_path = '/Users/lola/ASTRO/JWST/grism_project/testing_geko_demo/'    # Base directory for inputs and results

# Optional parameters (with defaults)
grism_filter = 'F444W'                 # Grism filter
delta_wave_cutoff = 0.005              # Wavelength bin size (microns)
factor = 5                             # Spatial oversampling factor
wave_factor = 10                       # Wavelength oversampling factor
model_name = 'Disk'                    # Kinematic model type

# MCMC parameters (very short run for demonstration; use far more samples for science)
num_chains = 1                         # Number of MCMC chains
num_warmup = 5                         # Warmup iterations
num_samples = 20                       # Sampling iterations

# Run the fit (no config file needed!)
inference_data = run_geko_fit(
    output=output_name,
    master_cat=master_catalog,
    line=emission_line,
    parametric=parametric,
    save_runs_path=save_runs_path,
    num_chains=num_chains,
    num_warmup=num_warmup,
    num_samples=num_samples,
    source_id=source_id,                # Passed directly
    field=field,                        # Passed directly
    grism_filter=grism_filter,          # Optional
    delta_wave_cutoff=delta_wave_cutoff,  # Optional
    factor=factor,                      # Optional
    wave_factor=wave_factor,            # Optional
    model_name=model_name,              # Optional
    config=None                         # Optional: custom configuration
)

Cont sub done
Flipping map! (Mod B)
(15, 15)
Disk model created
Disk object created
Rotating the prior by  0.0  radians, from  2.964  radians to  2.964  radians



Setting parametric priors:  10.175309523224001 59.206811076802936 3.342 1.297 0.18970094750436658 15.165 15.0425
MCMC settings: 1 chains, 20 samples, 5 warmup, max_tree_depth=10, target_accept=0.8
step size:  0.001
warmup:  5
samples:  20


warmup:  16%|█▌        | 4/25 [00:08<00:41,  1.98s/it, 7 steps of size 2.08e-02. acc. prob=0.50]

### Using Custom Configuration (Optional)

You can override default priors using the `FitConfiguration` class:

In [None]:
# Create custom configuration
config = FitConfiguration()

# Set custom priors for specific parameters
config.set_prior('Va', prior_type='uniform', bounds=[50, 300])
config.set_prior('r_t', prior_type='uniform', bounds=[0.5, 5.0])
config.set_prior('sigma0', prior_type='uniform', bounds=[10, 150])

# Print configuration summary
config.print_summary()

# Run fit with custom config
inference_data = run_geko_fit(
    output=output_name,
    master_cat=master_catalog,
    line=emission_line,
    parametric=parametric,
    save_runs_path=save_runs_path,
    num_chains=num_chains,
    num_warmup=num_warmup,
    num_samples=num_samples,
    source_id=source_id,  # Same source and field as in the basic example
    field=field,
    config=config  # Use custom configuration
)

## What Happens During the Fit

The `run_geko_fit` function automatically performs the following steps:

1. **Preprocessing** (`run_full_preprocessing`):
   - Loads 2D grism spectrum and error maps
   - Loads PSF for the appropriate field
   - Creates wavelength grid
   - Initializes `Grism` object for dispersion modeling
   - Initializes kinematic model (e.g., `DiskModel`)

2. **Prior Setup** (if `parametric=True`):
   - Loads PySersic morphology results from `morph_fits/` directory
   - Extracts integrated emission line flux from master catalog
   - Sets morphological priors (position angle, inclination, etc.) based on imaging data
   - Applies field-specific rotation corrections to align coordinate systems

3. **MCMC Sampling**:
   - Creates `Fit_Numpyro` instance with observation data
   - Runs NUTS (No-U-Turn Sampler) with specified chains, warmup, and samples
   - Applies custom configuration if provided via `config` parameter
   - Creates source mask to focus on high S/N regions

4. **Postprocessing** (`process_results`):
   - Computes best-fit kinematic model
   - Calculates velocity at effective radius (v_re)
   - Generates diagnostic plots and summary statistics
   - Saves fit results and plots to output directory

5. **Output**:
   - Returns `arviz.InferenceData` object with posterior samples
   - Saves MCMC results to `<save_runs_path>/<output_name>/<source_id>_output`
   - Saves plots and summary tables to the same directory
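
To make the "velocity at effective radius" step concrete, the sketch below evaluates an arctangent rotation curve, V(r) = (2/π) · Va · arctan(r / r_t), at the effective radius. The parameter names `Va` and `r_t` match the priors set in the custom-configuration example, but the functional form is an assumption here; this notebook does not state which rotation curve the `Disk` model uses. The numerical values are placeholders.

```python
import math


def arctan_velocity(r, Va, r_t):
    """Arctangent rotation curve: approaches Va at large r and equals Va / 2 at r = r_t."""
    return (2.0 / math.pi) * Va * math.atan(r / r_t)


# Placeholder values for illustration; units must match between r_eff and r_t.
Va, r_t, r_eff = 200.0, 2.0, 3.0
v_re = arctan_velocity(r_eff, Va, r_t)
print(f"v(r_eff) = {v_re:.1f}")
```

Applying the same function to each posterior draw of `Va` and `r_t` would give a posterior distribution for v_re rather than a single number.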

## Analyzing Results

After the fit completes, you can analyze the results:

In [None]:
import arviz as az
import matplotlib.pyplot as plt

# Summary statistics
print(az.summary(inference_data))

# Trace plots
az.plot_trace(inference_data)
plt.tight_layout()
plt.show()

# Corner plot (corner accepts an arviz.InferenceData object directly;
# pass var_names=[...] to restrict the plot to a few parameters)
import corner
corner.corner(inference_data)
plt.show()

## Common Emission Lines

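The `line` parameter in the example above takes a line name (`'H_alpha'`), which matches the `<line>_lambda` column of the master catalog. For reference, the rest-frame wavelengths of common lines, in Angstroms, are: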

- **H-alpha**: 6562.8 Å
- **H-beta**: 4861.3 Å  
- **[OIII]**: 5006.8 Å
- **[OII]**: 3727.0 Å

The code will automatically calculate the observed wavelength based on the redshift in your master catalog.
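
The conversion from rest-frame to observed wavelength is λ_obs = λ_rest · (1 + z). The sketch below applies it to the lines listed above and reports the result in microns, which makes it easy to check whether a line falls inside a given grism filter. The helper name and the example redshift are illustrative, not part of `geko`.

```python
# Rest-frame wavelengths in Angstroms, copied from the list above.
REST_WAVELENGTHS_A = {
    "H-alpha": 6562.8,
    "H-beta": 4861.3,
    "[OIII]": 5006.8,
    "[OII]": 3727.0,
}


def observed_wavelength_um(rest_angstrom, z):
    """Observed wavelength in microns for a rest-frame wavelength in Angstroms."""
    return rest_angstrom * (1.0 + z) * 1e-4  # 1 Angstrom = 1e-4 micron


z = 5.0  # example redshift, chosen for illustration
for name, rest in REST_WAVELENGTHS_A.items():
    print(f"{name:>8}: {observed_wavelength_um(rest, z):.4f} micron")
```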

## Tips for Success

1. **Check your data quality**: Ensure grism spectra have good S/N
2. **Set appropriate priors**: Use photometric measurements to constrain morphology
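3. **Run sufficient samples**: Typically 1000-5000 samples after warmup (the 20 samples used in the demo above only show the workflow)
4. **Monitor convergence**: Check that R-hat values are below 1.01, which requires running more than one chain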
5. **Validate results**: Inspect model residuals and posterior distributions
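
To show what the R-hat check measures, the sketch below implements the classic Gelman-Rubin statistic, which compares the variance between chains with the variance within chains. This is a simplified version for illustration: ArviZ reports a rank-normalized split variant, so its numbers will differ somewhat. The `gelman_rubin_rhat` helper and the toy chains are made up for the example.

```python
from statistics import mean, variance


def gelman_rubin_rhat(chains):
    """Classic Gelman-Rubin R-hat for equal-length chains of one scalar parameter."""
    n = len(chains[0])
    if len(chains) < 2 or n < 2:
        raise ValueError("need at least two chains with at least two draws each")
    # W: mean within-chain variance; B / n: variance of the chain means.
    within = mean([variance(chain) for chain in chains])
    between_over_n = variance([mean(chain) for chain in chains])
    var_hat = (n - 1) / n * within + between_over_n
    return (var_hat / within) ** 0.5


# Toy chains: the first pair agrees, the second pair sits in different regions.
agreeing = [[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 3.0]]
disagreeing = [[0.0, 1.0, 2.0, 3.0], [10.0, 11.0, 12.0, 13.0]]
print(f"agreeing chains:    R-hat = {gelman_rubin_rhat(agreeing):.3f}")
print(f"disagreeing chains: R-hat = {gelman_rubin_rhat(disagreeing):.3f}")
```

Chains that explore different regions inflate the between-chain term and push R-hat well above 1, which is the signal the 1.01 threshold is meant to catch.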