# Horne Extraction - Toy Tests

This notebook collects a few simple tests for `HorneExtract`.
Each section builds a synthetic 2-D spectrum and checks how uncertainties are propagated under different conditions.

**Key ideas:**
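- The input is a 2-D long-slit spectrum (spatial × spectral).
- We supply per-pixel variance via `NDData.uncertainty` in a few different flavors.
- `HorneExtract` uses a spatial profile and inverse-variance weighting:  
  $F(\lambda)=\frac{\sum P\,D/\sigma^2}{\sum P^2/\sigma^2}$ and $\mathrm{Var}[F]=\frac{1}{\sum P^2/\sigma^2}$,  
  where the sums run over the spatial axis, $D$ is the data, $\sigma^2$ the per-pixel variance, and $P$ the spatial profile normalized so that $\sum P = 1$ in each spectral column.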
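The two formulas above can be checked with a few lines of NumPy. This is a minimal sketch, not the `HorneExtract` implementation: `horne_combine` is a hypothetical helper that assumes the profile is already normalized along the spatial axis (axis 0), that all variances are positive, and that nothing is masked.

```python
import numpy as np

def horne_combine(data, profile, variance):
    """Profile-weighted, inverse-variance combination along axis 0 (spatial).

    Hypothetical helper: assumes `profile` sums to 1 in every column and
    that all variances are strictly positive (no masking).
    """
    weights = profile / variance                   # P / sigma^2
    denom = np.sum(profile * weights, axis=0)      # sum P^2 / sigma^2
    flux = np.sum(weights * data, axis=0) / denom  # F(lambda)
    var_flux = 1.0 / denom                         # Var[F]
    return flux, var_flux

# Noiseless toy source: Gaussian profile (centre row 10, width 2 px) scaled to 500 DN
ny, nx = 20, 60
y = np.arange(ny)[:, None]
profile = np.exp(-0.5 * ((y - 10) / 2.0) ** 2) * np.ones((1, nx))
profile /= profile.sum(axis=0, keepdims=True)   # normalize each column
data = 500.0 * profile
variance = np.full((ny, nx), 25.0)              # sigma = 5 DN everywhere

flux, var_flux = horne_combine(data, profile, variance)
print(flux[:3])      # recovers 500 DN in every column
print(var_flux[:3])  # sigma^2 / sum(P^2), identical in every column
```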

You should be able to run the cells in order, top to bottom.

In [None]:
import numpy as np
import astropy.units as u
from astropy.nddata import NDData, VarianceUncertainty, StdDevUncertainty

from specreduce.extract import HorneExtract
from specreduce.tracing import FlatTrace

In [None]:
# Global synthetic image settings
ny, nx = 20, 60                 # spatial, spectral
true_flux = 100.0               # DN per pixel
sigma_per_pixel = 5.0           # DN stddev per pixel
rng = np.random.default_rng(42) # reproducible noise

def make_base_data(ny=ny, nx=nx, mu=true_flux, sigma=sigma_per_pixel):
    return mu + rng.normal(0, sigma, size=(ny, nx))

def set_spectral_axis(image, ndisp=nx):
    # Attach a pixel spectral axis whose length matches the dispersion (flux) axis.
    image.spectral_axis = np.arange(ndisp) * u.pix
    return image

def run_case(title, image, trace_pos=None, profile="gaussian"):
    print(f"=== {title} ===")
    # Provide a spectral axis on the NDData
    set_spectral_axis(image, ndisp=image.data.shape[1])

    # Choose a flat, centered trace unless requested otherwise
    if trace_pos is None:
        trace_pos = image.data.shape[0] // 2

    # Build the trace and extractor inside the try block so that an invalid
    # input is reported even if it is rejected while the trace is built.
    try:
        trace = FlatTrace(image, trace_pos=trace_pos)
        extractor = HorneExtract(
            image=image,
            trace_object=trace,
            spatial_profile=profile,
            disp_axis=1,
            crossdisp_axis=0,
        )
        spec = extractor()
        unc = getattr(spec, "uncertainty", None)
        arr = None if unc is None else np.asarray(unc.array)
        print("flux shape:", spec.flux.shape)
        if arr is None:
            print("No uncertainty returned")
        else:
            print(f"first 5 uncertainties ({type(unc).__name__}):", arr[:5])
            print("finite, positive:", np.all(np.isfinite(arr) & (arr > 0)))
    except Exception as e:
        # Report the failure and keep going so later cases still run
        print(f"ERROR ({type(e).__name__}):", e)
    print()


## Case A - VarianceUncertainty (correct units)
We pass per‑pixel variance as `VarianceUncertainty` with units of `DN^2`.  
Expected: uncertainties are finite and roughly constant across wavelength for stationary noise.
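With constant per-pixel variance the formula reduces to $\mathrm{Var}[F]=\sigma^2/\sum P^2$, which is the same in every column whenever the profile is. As a worked check, assume a pixel-sampled Gaussian profile of width $s$ (illustrative only; in this case the profile is fitted to the data). Then $\sum P^2\approx 1/(2\sqrt{\pi}\,s)$, so $\mathrm{Var}[F]\approx 2\sqrt{\pi}\,s\,\sigma^2$. The sketch compares the exact sum with that approximation.

```python
import numpy as np

sigma = 5.0   # DN, same value as sigma_per_pixel
s = 2.0       # assumed Gaussian profile width in pixels (illustrative)
ny = 20

y = np.arange(ny)
P = np.exp(-0.5 * ((y - ny // 2) / s) ** 2)
P /= P.sum()                                       # normalized spatial profile

var_exact = sigma**2 / np.sum(P**2)                # Var[F] = sigma^2 / sum(P^2)
var_approx = 2.0 * np.sqrt(np.pi) * s * sigma**2   # continuous-Gaussian limit
print(var_exact, var_approx)
```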

In [None]:
dataA = make_base_data()
var_unc = VarianceUncertainty(np.full((ny, nx), sigma_per_pixel**2) * u.DN**2)
imgA = NDData(data=dataA * u.DN, uncertainty=var_unc)
run_case("A: VarianceUncertainty (correct units)", imgA)

## Case B - StdDevUncertainty (correct units)
We pass the standard deviation with units of `DN`; internally it is squared to a variance.  
Expected: the same uncertainties as Case A.

In [None]:
dataB = make_base_data()
std_unc = StdDevUncertainty(np.full((ny, nx), sigma_per_pixel) * u.DN)
imgB = NDData(data=dataB * u.DN, uncertainty=std_unc)
run_case("B: StdDevUncertainty (correct units)", imgB)

## Case C - StdDevUncertainty (no units)
We pass a unitless stddev array; the parser will assume it's in the same units as the data.  
Expected: matches Case B.

In [None]:
dataC = make_base_data()
std_unc_nounits = StdDevUncertainty(np.full((ny, nx), sigma_per_pixel))
imgC = NDData(data=dataC * u.DN, uncertainty=std_unc_nounits)
run_case("C: StdDevUncertainty (no units)", imgC)

## Case D - No uncertainty
We omit uncertainties; `HorneExtract` requires them. This should raise an error.

In [None]:
dataD = make_base_data()
imgD = NDData(data=dataD * u.DN)
run_case("D: No uncertainty", imgD)

## Case E - Negative variance
We construct a variance array with one negative element; it should be rejected.
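A check of this kind can be as simple as the sketch below. `validate_variance` is a hypothetical helper written for illustration; it is not the specreduce code path, and the error type and message raised by `HorneExtract` may differ.

```python
import numpy as np

def validate_variance(variance):
    """Hypothetical check: reject variance arrays with negative entries."""
    variance = np.asarray(variance, dtype=float)
    n_bad = int(np.count_nonzero(variance < 0))
    if n_bad:
        raise ValueError(f"variance contains {n_bad} negative value(s)")
    return variance

v = np.full((20, 60), 25.0)
v[0, 0] = -1.0
try:
    validate_variance(v)
except ValueError as err:
    print("rejected:", err)   # rejected: variance contains 1 negative value(s)
```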

In [None]:
dataE = make_base_data()
v = np.full((ny, nx), sigma_per_pixel**2)
v[0,0] = -1.0
var_unc_bad = VarianceUncertainty(v * u.DN**2)
imgE = NDData(data=dataE * u.DN, uncertainty=var_unc_bad)
run_case("E: Negative variance", imgE)

## Case F - Zero variance everywhere
An all-zero variance array is treated as the unweighted case: the implementation replaces the variances so extraction can proceed.
Expected: the extraction runs without error.
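Because a constant variance cancels between the numerator and denominator of $F(\lambda)$, replacing an all-zero variance with any constant leaves a purely profile-weighted estimate, $F=\sum P\,D/\sum P^2$. The sketch assumes the substituted value is 1; which value the implementation actually uses is an assumption here, and it would set the scale of the returned variance.

```python
import numpy as np

ny, nx = 20, 60
rng = np.random.default_rng(0)
y = np.arange(ny)[:, None]
P = np.exp(-0.5 * ((y - 10) / 2.0) ** 2) * np.ones((1, nx))
P /= P.sum(axis=0, keepdims=True)                   # normalized profile
D = 500.0 * P + rng.normal(0, 5.0, size=(ny, nx))   # toy noisy data

variance = np.zeros((ny, nx))
if np.all(variance == 0):
    # Assumed handling: substitute unit variances (the "unweighted" case)
    variance = np.ones_like(variance)

F_weighted = np.sum(P * D / variance, axis=0) / np.sum(P**2 / variance, axis=0)
F_unweighted = np.sum(P * D, axis=0) / np.sum(P**2, axis=0)
print(np.allclose(F_weighted, F_unweighted))   # True
```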

In [None]:
dataF = make_base_data()
var_unc_zero = VarianceUncertainty(np.zeros((ny, nx)) * u.DN**2)
imgF = NDData(data=dataF * u.DN, uncertainty=var_unc_zero)
run_case("F: Zero variance everywhere", imgF)

## Case G - Some zero variances
Only the first spatial row has zero variance; those pixels are masked out for the variance calculation.
Expected: finite uncertainties; because row 0 is excluded in every spectral column, the values may differ slightly from Case A.
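One way to exclude zero-variance pixels is to give them zero weight in both sums, as in this sketch. It assumes excluded pixels simply contribute nothing; the actual masking inside `HorneExtract` may differ in detail.

```python
import numpy as np

ny, nx = 20, 60
y = np.arange(ny)[:, None]
P = np.exp(-0.5 * ((y - 10) / 2.0) ** 2) * np.ones((1, nx))
P /= P.sum(axis=0, keepdims=True)         # normalized profile
D = 500.0 * P                             # noiseless toy source

variance = np.full((ny, nx), 25.0)
variance[0, :] = 0.0                      # first spatial row: zero variance

good = variance > 0                       # pixels that can be weighted
safe_var = np.where(good, variance, 1.0)  # placeholder avoids division by zero
w = np.where(good, P / safe_var, 0.0)     # zero weight for excluded pixels

denom = np.sum(w * P, axis=0)             # sum of P^2/sigma^2 over good pixels
flux = np.sum(w * D, axis=0) / denom
var_flux = 1.0 / denom
print(np.isfinite(var_flux).all())        # True
```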

In [None]:
dataG = make_base_data()
v = np.full((ny, nx), sigma_per_pixel**2)
v[0, :] = 0.0
var_unc_mixed = VarianceUncertainty(v * u.DN**2)
imgG = NDData(data=dataG * u.DN, uncertainty=var_unc_mixed)
run_case("G: Some zero variances", imgG)

## Case H/I - Trace near detector edges
We move the flat trace close to the top and bottom rows (`trace_pos=1` and `ny-2`) so that the spatial profile is cut off by the detector edge.
Expected: still finite uncertainties (the spatial profile normalization handles partial coverage).
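The sketch below illustrates why a truncated profile can still give finite uncertainties, assuming the profile is renormalized over the rows that exist on the detector: the renormalized profile again sums to 1, and $\mathrm{Var}[F]=\sigma^2/\sum P^2$ stays finite, though its value changes because fewer pixels contribute. The Gaussian shape and width are illustrative assumptions.

```python
import numpy as np

sigma2 = 25.0   # per-pixel variance (DN^2)
s = 2.0         # assumed profile width in pixels
ny = 20
y = np.arange(ny)

results = {}
for centre in (1, ny - 2, ny // 2):      # near first row, near last row, centred
    P = np.exp(-0.5 * ((y - centre) / s) ** 2)
    P /= P.sum()                         # renormalize over on-detector rows only
    var_F = sigma2 / np.sum(P**2)
    results[centre] = (P.sum(), var_F)
    print(centre, round(P.sum(), 12), var_F)
```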

In [None]:
dataH = make_base_data()
var_unc_H = VarianceUncertainty(np.full((ny, nx), sigma_per_pixel**2) * u.DN**2)
imgH = NDData(data=dataH * u.DN, uncertainty=var_unc_H)
run_case("H: Trace near top edge", imgH, trace_pos=1)

dataI = make_base_data()
var_unc_I = VarianceUncertainty(np.full((ny, nx), sigma_per_pixel**2) * u.DN**2)
imgI = NDData(data=dataI * u.DN, uncertainty=var_unc_I)
run_case("I: Trace near bottom edge", imgI, trace_pos=ny-2)

## Case J - Fully masked input
We feed an image where all pixels are NaN; the input parser should reject it as fully masked.
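A fully masked input can be detected by combining any explicit mask with the non-finite pixels and checking whether anything is left. This is a hypothetical sketch (`build_mask` is illustrative); the specreduce parser's actual checks and error text may differ.

```python
import numpy as np

def build_mask(data, mask=None):
    """Hypothetical: mask non-finite pixels, OR-ed with an optional input mask."""
    combined = ~np.isfinite(data)
    if mask is not None:
        combined |= np.asarray(mask, dtype=bool)
    if combined.all():
        raise ValueError("image is fully masked")
    return combined

data = np.full((20, 60), np.nan)
try:
    build_mask(data)
except ValueError as err:
    print("rejected:", err)   # rejected: image is fully masked
```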

In [None]:
dataJ = np.full((ny, nx), np.nan)
var_unc_J = VarianceUncertainty(np.full((ny, nx), sigma_per_pixel**2) * u.DN**2)
imgJ = NDData(data=dataJ * u.DN, uncertainty=var_unc_J)
run_case("J: Fully masked input", imgJ)

## Case K - Zero profile sumP in a column
We force one spectral column to be entirely NaN, so the spatial profile normalization `sumP` would be zero there.
The implementation guards against division by zero by replacing `sumP==0` with 1.
Expected: uncertainties remain finite.
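The guard amounts to substituting 1 for a zero column sum before dividing, so that column ends up with an all-zero profile instead of NaN. A minimal sketch, assuming the profile is normalized column by column and the dead column contributes zero profile support; it only shows that the normalized profile stays finite.

```python
import numpy as np

ny, nx = 20, 60
y = np.arange(ny)[:, None]
raw = np.exp(-0.5 * ((y - 10) / 2.0) ** 2) * np.ones((1, nx))
raw[:, nx // 2] = 0.0                        # dead column: no profile support

sumP = raw.sum(axis=0)
sumP_safe = np.where(sumP == 0, 1.0, sumP)   # guard: avoid 0/0 in that column
P = raw / sumP_safe

print(np.isfinite(P).all())                  # True
print(P[:, nx // 2].max())                   # 0.0, the dead column stays zero
```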

In [None]:
dataK = make_base_data()
dataK[:, nx//2] = np.nan  # one dead spectral column
var_unc_K = VarianceUncertainty(np.full((ny, nx), sigma_per_pixel**2) * u.DN**2)
imgK = NDData(data=dataK * u.DN, uncertainty=var_unc_K)
run_case("K: Zero profile sumP", imgK)