# Simple Geko Demo: Fitting One Object

This notebook demonstrates how to use `geko` to fit a single galaxy observation using grism spectroscopy data.

## Overview

`geko` is a Python package for analyzing JWST grism spectroscopy and morphology data. It uses:
- **JAX** for accelerated numerical computation
- **NumPyro** for Bayesian inference via MCMC
- **Kinematic models** to fit galaxy rotation curves

## 1. Setup and Imports

```python
import jax

# Enable 64-bit precision before any JAX arrays are created
# (including any that may be created when the geko modules are imported)
jax.config.update('jax_enable_x64', True)

import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
from astropy.io import fits  # needed when loading real JWST data (unused with mock data)

# Import geko modules
from geko.grism import Grism
from geko.models import KinModels
from geko.fitting import Fit_Numpyro

print("Imports successful!")
```

## 2. Create Mock Data

For this demo, we'll create simple mock observations. In a real scenario, you would load your JWST grism data from FITS files.

```python
# Define observation parameters
im_size = 45  # Image size in pixels (model space)
im_scale = 0.0629 / 5  # Pixel scale (arcsec/pixel), oversampled 5x from the detector
center_pixel = im_size // 2  # Center of the galaxy

# Wavelength setup
central_wavelength = 4.0  # microns
wave_step = 0.0001  # bin width in microns (high-resolution grid)
half_width = 0.05  # microns on either side of the central wavelength
# np.linspace includes both endpoints exactly; np.arange with a float step can
# gain or lose a bin through rounding
n_wave = int(round(2 * half_width / wave_step)) + 1
wave_space = np.linspace(central_wavelength - half_width,
                         central_wavelength + half_width,
                         n_wave)

print(f"Image size: {im_size}x{im_size} pixels")
print(f"Pixel scale: {im_scale:.5f} arcsec/pixel")
print(f"Central wavelength: {central_wavelength} μm")
print(f"Wavelength range: {wave_space.min():.3f} - {wave_space.max():.3f} μm")
print(f"Number of wavelength bins: {len(wave_space)}")
```

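As a quick sanity check on the wavelength grid, the bin width can be converted into a velocity step with the non-relativistic Doppler relation $\Delta v = c\,\Delta\lambda/\lambda$. The sketch below rebuilds the same grid (3.95 to 4.05 μm in 0.0001 μm steps) and reports the number of bins and the velocity sampling at 4.0 μm. The helper `wavelength_grid` is hypothetical and not part of geko.

```python
import numpy as np

C_KMS = 299792.458  # speed of light in km/s


def wavelength_grid(center, half_width, step):
    """Evenly spaced grid from center - half_width to center + half_width, endpoints included.

    Hypothetical helper for this demo, not a geko function.
    """
    n = int(round(2 * half_width / step)) + 1
    return np.linspace(center - half_width, center + half_width, n)


wave = wavelength_grid(4.0, 0.05, 0.0001)
step = wave[1] - wave[0]            # bin width in microns
dv_per_bin = C_KMS * step / 4.0     # velocity step per bin at 4.0 microns (km/s)

print(f"{len(wave)} bins from {wave[0]:.4f} to {wave[-1]:.4f} um")
print(f"Velocity sampling: {dv_per_bin:.2f} km/s per bin")
```

At this sampling each bin corresponds to roughly 7.5 km/s, comfortably finer than typical galaxy rotation velocities and dispersions.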
```python
# Create a simple PSF (Point Spread Function)
# In practice, you would use a measured or simulated JWST PSF
PSF = np.zeros((5, 5))
PSF[2, 2] = 1.0  # Delta-function PSF for simplicity

# Create a mock 2D grism observation.
# Real data would come from your JWST data reduction pipeline; this mock is just
# a Gaussian blob plus noise, not a physically dispersed spectrum.
rng = np.random.default_rng(42)  # fixed seed so the demo is reproducible
obs_grism = rng.normal(0.0, 0.01, size=(im_size, im_size))  # small Gaussian noise

# Add a simple "signal": a Gaussian source (sigma = 3 pixels) at the center
y, x = np.ogrid[:im_size, :im_size]
gaussian_source = np.exp(-((x - center_pixel)**2 + (y - center_pixel)**2) / (2 * 3**2))
obs_grism += gaussian_source * 0.5

# Create error map (1-sigma uncertainty in each pixel)
# Constant absolute uncertainty of 0.05 flux units (larger than the 0.01 noise injected above)
obs_error = np.ones_like(obs_grism) * 0.05

# Visualize the mock observation
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

im0 = axes[0].imshow(obs_grism, origin='lower', cmap='viridis')
axes[0].set_title('Mock Grism Observation')
axes[0].set_xlabel('Pixel (dispersion direction)')
axes[0].set_ylabel('Pixel (spatial direction)')
plt.colorbar(im0, ax=axes[0], label='Flux')

im1 = axes[1].imshow(obs_error, origin='lower', cmap='magma')
axes[1].set_title('Error Map')
axes[1].set_xlabel('Pixel (dispersion direction)')
axes[1].set_ylabel('Pixel (spatial direction)')
plt.colorbar(im1, ax=axes[1], label='Uncertainty')

plt.tight_layout()
plt.show()
```

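A delta-function PSF leaves any model image unchanged. A slightly more realistic placeholder is a normalised 2D Gaussian kernel convolved with the model. The sketch below assumes a Gaussian FWHM of 2 model pixels purely for illustration (a real analysis would use a measured or simulated JWST PSF) and implements a "same"-size convolution with zero padding so it needs only NumPy. `gaussian_psf` and `convolve_same` are hypothetical helpers, not geko functions.

```python
import numpy as np


def gaussian_psf(size, fwhm_pix):
    """Normalised 2D Gaussian kernel of odd `size` that sums to 1."""
    sigma = fwhm_pix / (2.0 * np.sqrt(2.0 * np.log(2.0)))
    half = size // 2
    yy, xx = np.mgrid[-half:half + 1, -half:half + 1]
    kernel = np.exp(-(xx**2 + yy**2) / (2.0 * sigma**2))
    return kernel / kernel.sum()


def convolve_same(image, kernel):
    """2D convolution (odd-sized kernel, zero padding) returning an array shaped like `image`."""
    ky, kx = kernel.shape
    py, px = ky // 2, kx // 2
    padded = np.pad(image, ((py, py), (px, px)))
    flipped = kernel[::-1, ::-1]  # flip so this is a convolution, not a correlation
    ny, nx = image.shape
    out = np.zeros(image.shape, dtype=float)
    for a in range(ky):
        for b in range(kx):
            out += flipped[a, b] * padded[a:a + ny, b:b + nx]
    return out


im_size, center = 45, 22
point = np.zeros((im_size, im_size))
point[center, center] = 1.0

delta = np.zeros((5, 5))
delta[2, 2] = 1.0

unchanged = convolve_same(point, delta)                     # delta PSF: no change
blurred = convolve_same(point, gaussian_psf(5, fwhm_pix=2.0))

print(f"Flux before: {point.sum():.6f}, after: {blurred.sum():.6f}")
```

Because the kernel sums to 1 and the source sits far from the edges, the convolution redistributes flux without changing the total.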
## 3. Initialize Grism Object

The `Grism` class handles the dispersion model and wavelength calibration for JWST NIRCam grism observations.

```python
# Initialize the Grism object
grism = Grism(
    im_shape=im_size,
    im_scale=im_scale,
    icenter=center_pixel,
    jcenter=center_pixel,
    wavelength=central_wavelength,
    wave_space=wave_space,
    index_min=0,
    index_max=len(wave_space),
    grism_filter='F444W',  # NIRCam filter
    grism_module='A',      # NIRCam module A
    grism_pupil='R',       # Grism R (row dispersion)
    PSF=PSF
)

print("Grism object initialized!")
print(f"Detector scale: {grism.detector_scale:.4f} arcsec/pixel")
print(f"Oversampling factor: {grism.factor}x")
```

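The model grid here is oversampled 5x relative to the detector (`im_scale = 0.0629 / 5`). A common, flux-conserving way to compare an oversampled model with detector pixels is to sum each `factor x factor` block. The sketch below shows that block sum with a NumPy reshape; this notebook does not show how geko itself rebins, and `block_sum` is a hypothetical helper.

```python
import numpy as np


def block_sum(image, factor):
    """Sum non-overlapping factor x factor blocks (shape must be divisible by factor)."""
    ny, nx = image.shape
    if ny % factor or nx % factor:
        raise ValueError("image shape must be divisible by the oversampling factor")
    # Axes 1 and 3 index the sub-pixels inside each block
    return image.reshape(ny // factor, factor, nx // factor, factor).sum(axis=(1, 3))


factor = 5
model_scale = 0.0629 / factor      # arcsec per model pixel
model = np.zeros((45, 45))
model[22, 22] = 1.0                # point source on the oversampled grid

detector = block_sum(model, factor)
print(detector.shape, detector.sum())                          # (9, 9) 1.0
print(f"Detector scale: {model_scale * factor:.4f} arcsec/pixel")
```

A 45x45 model grid therefore corresponds to a 9x9 patch of detector pixels.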
## 4. Test Grism Dispersion

Let's test the grism dispersion model by dispersing a simple source.

```python
# Create a simple test: point source with no velocity
test_flux = np.zeros((im_size, im_size))
test_flux[center_pixel, center_pixel] = 1.0

test_vel = np.zeros((im_size, im_size))  # No velocity
test_disp = np.zeros((im_size, im_size))  # No velocity dispersion

# Disperse the source
dispersed = grism.disperse(test_flux, test_vel, test_disp)

# Visualize
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.imshow(test_flux, origin='lower', cmap='viridis')
plt.title('Input: Point Source')
plt.colorbar(label='Flux')

plt.subplot(122)
plt.imshow(dispersed, origin='lower', cmap='viridis')
plt.title('Output: Dispersed Spectrum')
plt.colorbar(label='Flux')

plt.tight_layout()
plt.show()

print(f"Flux conservation check: Input = {test_flux.sum():.6f}, Output = {dispersed.sum():.6f}")
```

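The test above uses zero velocity. To see qualitatively how a line-of-sight velocity moves flux along the dispersion axis, here is a deliberately simplified toy: each pixel's line is Doppler shifted as $\lambda_{\rm obs} = \lambda_c (1 + v/c)$, the shift is converted to a pixel offset using an assumed linear dispersion of 0.001 μm per pixel, and the line is spread with a Gaussian whose width comes from the velocity dispersion. This is not geko's `Grism.disperse`, which uses the NIRCam calibration for the chosen filter, module and pupil; `toy_disperse` and its parameters are hypothetical.

```python
import numpy as np

C_KMS = 299792.458  # speed of light in km/s


def toy_disperse(flux, vel, sigma, wave_c, dlam_per_pix, min_sigma_pix=0.5):
    """Toy slitless dispersion along the x axis, for illustration only.

    flux, vel, sigma: 2D arrays of flux, line-of-sight velocity and velocity
    dispersion (km/s). wave_c: observed line wavelength (microns) at zero
    velocity, mapped to each source pixel's own column. dlam_per_pix: assumed
    linear dispersion in microns per pixel.
    """
    ny, nx = flux.shape
    cols = np.arange(nx)
    out = np.zeros((ny, nx))
    for j in range(ny):
        for i in range(nx):
            f = flux[j, i]
            if f == 0:
                continue
            lam = wave_c * (1.0 + vel[j, i] / C_KMS)           # Doppler-shifted wavelength
            centre = i + (lam - wave_c) / dlam_per_pix        # trace position in pixels
            sig_pix = max(wave_c * sigma[j, i] / C_KMS / dlam_per_pix, min_sigma_pix)
            profile = np.exp(-0.5 * ((cols - centre) / sig_pix) ** 2)
            # Normalise over the pixels in the array so total flux is conserved
            # (a simplification: wings that would fall off the edge stay in the array)
            out[j] += f * profile / profile.sum()
    return out


n, c = 45, 22
flux = np.zeros((n, n))
flux[c, c] = 1.0
sigma = np.full((n, n), 50.0)                    # 50 km/s velocity dispersion

at_rest = toy_disperse(flux, np.zeros((n, n)), sigma, 4.0, 0.001)
moving = toy_disperse(flux, np.full((n, n), 300.0), sigma, 4.0, 0.001)

print(np.argmax(at_rest[c]), np.argmax(moving[c]))   # 22 26
print(f"Flux: {at_rest.sum():.6f}, {moving.sum():.6f}")
```

With these assumed numbers, a 300 km/s recession at 4.0 μm shifts the line by about 0.004 μm, i.e. about 4 pixels, while the total flux is unchanged.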
## 5. Initialize Kinematic Model

The `KinModels` class defines the kinematic model for the galaxy (e.g., rotation curve, morphology).

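```python
# Rest-frame H-alpha wavelength in Angstroms (air value; the vacuum value is ~6564.6 Å)
halpha_rest = 6562.8

# Pick the redshift that places H-alpha at the central wavelength of the grid
# defined above (z ≈ 5.095). At z = 2, H-alpha would be observed at ~1.97 μm,
# outside both the F444W bandpass and wave_space.
redshift = central_wavelength * 1e4 / halpha_rest - 1

# Create kinematic model
kin_model = KinModels(
    im_shape=im_size,
    im_scale=im_scale,
    grism_object=grism,
    z=redshift,
    line=halpha_rest,  # H-alpha line in Angstroms
    par_fit=True,  # Use parametric morphology
    PA_prior=[90, 30],  # Position angle prior [mean, std] in degrees
    i_prior=[45, 20],   # Inclination prior [mean, std] in degrees
    z_prior=[redshift, 0.01]  # Redshift prior [mean, std]
)

print(f"Kinematic model initialized for z={redshift:.3f}")
print(f"Emission line: H-alpha at {halpha_rest} Å (rest frame)")
print(f"Observed wavelength: {halpha_rest * (1 + redshift) / 1e4:.3f} μm")
```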

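To make the role of the position angle and inclination priors concrete, the sketch below builds a line-of-sight velocity map for a thin rotating disk with an arctangent rotation curve, $v(r) = \frac{2}{\pi} V_{\max} \arctan(r/r_t)$, projected as $v_{\rm los} = v_{\rm sys} + v(r)\sin i\,\cos\theta$, where $\theta$ is the in-plane azimuth measured from the major axis. The arctangent form and the angle convention (PA counter-clockwise from the +x axis) are assumptions chosen for illustration; geko's `KinModels` may parametrise the rotation curve and angles differently.

```python
import numpy as np


def arctan_velocity_field(shape, x0, y0, pa_deg, inc_deg, v_max, r_t, v_sys=0.0):
    """Line-of-sight velocity map (km/s) of a thin disk with an arctan rotation curve.

    Assumed conventions: PA is measured counter-clockwise from the +x axis to the
    receding major axis; r_t is the turnover radius in pixels; inc_deg < 90.
    """
    ny, nx = shape
    y, x = np.mgrid[:ny, :nx]
    pa, inc = np.radians(pa_deg), np.radians(inc_deg)
    dx, dy = x - x0, y - y0
    # Coordinates along the projected major and minor axes
    x_maj = dx * np.cos(pa) + dy * np.sin(pa)
    y_min = -dx * np.sin(pa) + dy * np.cos(pa)
    # Deproject the minor-axis coordinate to get the in-plane radius
    r = np.hypot(x_maj, y_min / np.cos(inc))
    v_rot = (2.0 / np.pi) * v_max * np.arctan(r / r_t)
    # cos(theta) = x_maj / r, set to 0 at the centre where r = 0
    cos_theta = np.divide(x_maj, r, out=np.zeros_like(r), where=r > 0)
    return v_sys + v_rot * np.sin(inc) * cos_theta


# Mean values of the PA and inclination priors used above
vel = arctan_velocity_field((45, 45), 22, 22, pa_deg=90, inc_deg=45, v_max=200.0, r_t=3.0)
print(f"Velocity range: {vel.min():.1f} to {vel.max():.1f} km/s")
```

The map is antisymmetric about the centre, zero along the minor axis, and bounded by $V_{\max}\sin i$, which is why the inclination and $V_{\max}$ are partly degenerate in fits of this kind.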
## 6. Set Up the Fitter

The `Fit_Numpyro` class performs Bayesian inference using MCMC to fit the model to the data.

```python
# Initialize the fitter
fitter = Fit_Numpyro(
    obs_map=obs_grism,
    obs_error=obs_error,
    grism_object=grism,
    kin_model=kin_model,
    inference_data=None,  # No previous results
    parametric=True  # Use parametric Sersic profile
)

print("Fitter initialized!")
print(f"Observation shape: {fitter.obs_map.shape}")
print(f"Number of valid pixels (S/N > 5): {fitter.mask.sum()}")
```

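The fitter reports the number of pixels above a signal-to-noise threshold of 5. The sketch below shows one straightforward way to build such a mask from the data and error maps, and to evaluate the Gaussian log-likelihood $\ln\mathcal{L} = -\tfrac{1}{2}\sum_{\rm mask} [(d - m)/\sigma]^2$ (up to a constant) that an MCMC sampler would explore. It rebuilds the mock data from Section 2 with a seeded generator. The exact masking rule and likelihood inside `Fit_Numpyro` are not shown in this notebook, so `snr_mask` and `gaussian_log_likelihood` are hypothetical.

```python
import numpy as np


def snr_mask(data, error, threshold=5.0):
    """Boolean mask of pixels whose signal-to-noise ratio exceeds `threshold`."""
    return (data / error) > threshold


def gaussian_log_likelihood(data, model, error, mask):
    """Gaussian log-likelihood over masked pixels, dropping the constant term."""
    resid = (data[mask] - model[mask]) / error[mask]
    return -0.5 * np.sum(resid**2)


# Mock data built as in Section 2 (seeded for reproducibility)
n, c = 45, 22
rng = np.random.default_rng(42)
y, x = np.ogrid[:n, :n]
source = np.exp(-((x - c)**2 + (y - c)**2) / (2 * 3**2))
data = 0.5 * source + rng.normal(0.0, 0.01, size=(n, n))
error = np.full((n, n), 0.05)

mask = snr_mask(data, error)
lnl_true = gaussian_log_likelihood(data, 0.5 * source, error, mask)
lnl_wrong = gaussian_log_likelihood(data, 0.3 * source, error, mask)

print(f"Pixels with S/N > 5: {mask.sum()}")
print(f"lnL (true amplitude 0.5):  {lnl_true:.2f}")
print(f"lnL (wrong amplitude 0.3): {lnl_wrong:.2f}")
```

The model with the correct amplitude has a much higher log-likelihood; an MCMC sampler turns such comparisons into a posterior distribution over the parameters.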
## 7. Run the Fit (Optional)

**Note:** Running a full MCMC fit can take significant time, so this step is shown as a static code block rather than an executable cell. Copy it into a code cell to run the fit.

```python
# Run MCMC sampling
fitter.run_inference(
    num_warmup=500,
    num_samples=1000,
    num_chains=2
)

# Get results
samples = fitter.get_samples()
print("Parameter posterior samples:", samples.keys())

# Plot corner plot
fitter.plot_corner()
```

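To show what `num_warmup` and `num_samples` mean without running the full model, here is a self-contained random-walk Metropolis sampler that fits a single parameter, the amplitude of the mock Gaussian source, under a uniform prior on [0, 2]. Random-walk Metropolis is a swapped-in, much simpler algorithm than the gradient-based samplers NumPyro provides (such as NUTS), and the helper names are hypothetical; the point is only to illustrate warmup, sampling and a posterior summary.

```python
import numpy as np


def log_posterior(amp, data, error, template):
    """Uniform prior on [0, 2] plus a Gaussian likelihood for the source amplitude."""
    if not 0.0 <= amp <= 2.0:
        return -np.inf
    resid = (data - amp * template) / error
    return -0.5 * np.sum(resid**2)


def metropolis(log_post, x0, step, num_warmup, num_samples, rng):
    """Random-walk Metropolis for one parameter; warmup draws are discarded."""
    x, lp = x0, log_post(x0)
    kept, accepted = [], 0
    for k in range(num_warmup + num_samples):
        proposal = x + step * rng.normal()
        lp_prop = log_post(proposal)
        if np.log(rng.uniform()) < lp_prop - lp:  # accept with probability min(1, ratio)
            x, lp = proposal, lp_prop
            if k >= num_warmup:
                accepted += 1
        if k >= num_warmup:
            kept.append(x)
    return np.array(kept), accepted / num_samples


# Mock data as in Section 2: Gaussian source of amplitude 0.5 plus noise
n, c = 45, 22
rng = np.random.default_rng(0)
yy, xx = np.ogrid[:n, :n]
template = np.exp(-((xx - c)**2 + (yy - c)**2) / (2 * 3**2))
data = 0.5 * template + rng.normal(0.0, 0.01, size=(n, n))
error = np.full((n, n), 0.05)

samples, acc_rate = metropolis(
    lambda a: log_posterior(a, data, error, template),
    x0=0.1, step=0.02, num_warmup=500, num_samples=2000, rng=rng,
)
print(f"amplitude = {samples.mean():.3f} +/- {samples.std():.3f} (acceptance {acc_rate:.2f})")
```

The warmup phase lets the chain walk from the poor starting value (0.1) to the high-probability region before any draws are kept.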
## Summary

This notebook demonstrated:
1. Setting up mock JWST grism observations
2. Initializing the `Grism` dispersion model
3. Testing grism dispersion on a point source
4. Setting up kinematic models for galaxy fitting
5. Preparing data for Bayesian MCMC fitting

### Next Steps

For real science applications:
- Load actual JWST grism data from FITS files
- Use measured PSF models for your filter
- Set appropriate priors based on photometric measurements
- Run MCMC with enough samples (typically 1000-5000) and check convergence diagnostics such as R-hat and the effective sample size
- Analyze posterior distributions and model residuals
- Extract physical quantities (rotation velocity, disk scale length, etc.)
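Before analysing posteriors, it is worth checking that independent chains agree. A common diagnostic is the Gelman-Rubin statistic $\hat R$, which compares between-chain and within-chain variance and approaches 1 when the chains sample the same distribution. The sketch below implements the classic (non-split) version in NumPy for an array of shape `(num_chains, num_samples)`; `gelman_rubin` is a hypothetical helper, and libraries such as ArviZ (`arviz.rhat`) provide more robust split and rank-normalised variants.

```python
import numpy as np


def gelman_rubin(chains):
    """Classic Gelman-Rubin R-hat for `chains` of shape (num_chains, num_samples)."""
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    w = chains.var(axis=1, ddof=1).mean()     # mean within-chain variance
    b = n * chain_means.var(ddof=1)           # between-chain variance
    var_hat = (n - 1) / n * w + b / n         # pooled estimate of the posterior variance
    return np.sqrt(var_hat / w)


rng = np.random.default_rng(1)
good = rng.normal(0.0, 1.0, size=(4, 1000))  # four chains sampling the same target
bad = good.copy()
bad[0] += 3.0                                 # one chain stuck in a different region

print(f"R-hat (agreeing chains):  {gelman_rubin(good):.3f}")
print(f"R-hat (one offset chain): {gelman_rubin(bad):.3f}")
```

Values close to 1 suggest the chains agree; a value well above 1, as for the offset chain, signals that the sampler has not converged and needs more warmup, more samples or a better parametrisation.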