# NeurIPS 2024 Ariel Data Challenge — Preprocessing Pipeline

**Goal**: Demonstrate the full preprocessing pipeline that converts raw telescope photometry into a per-wavelength transit-depth spectrum.

**Pipeline steps**:
0. Detector calibration — dark subtraction, flat fielding, dead pixel masking, spatial summation (raw parquet → calibrated numpy)
1. Out-of-transit (OOT) mask — identify which time frames are baseline (no planet)
2. Baseline normalisation — divide each channel by its OOT median to make flux dimensionless
3. Common-mode correction — remove correlated systematics shared across all wavelengths
4. Temporal binning — co-add frames to boost SNR by √bin_size
5. Transit depth extraction — compute per-channel depth = 1 − ⟨in-transit flux⟩
6. Full one-liner pipeline — `preprocess_planet()` wraps steps 1–5 (it takes already-calibrated light curves, so step 0 runs separately)

**Source module**: `src/preprocessing.py` — all functions are pure numpy, no side effects.

> **Note**: This notebook is Kaggle-ready. Run it as a Kaggle notebook kernel with the `ariel-data-challenge-2024` dataset attached.

## Setup

In [None]:
import subprocess
import sys
from pathlib import Path

# ── Kaggle: clone repo and add to sys.path ─────────────────────────────────
repo_dir = "/kaggle/working/ariel-exoplanet-ml"
# Kept as a str because it is inserted into sys.path below.
project_dir = str(Path(repo_dir) / "Kaggle competition" / "ARIEL neurIPS")

if not Path(repo_dir).exists():
    subprocess.run(
        ["git", "clone", "https://github.com/Smooth-Cactus0/ariel-exoplanet-ml.git", repo_dir],
        check=True,
    )
    print(f"Cloned repo to {repo_dir}")
else:
    print(f"Repo already exists at {repo_dir}")

# Put project_dir first so `src.*` imports resolve to this checkout. Remove any
# earlier copy first: re-running this cell would otherwise stack duplicates.
sys.path[:] = [p for p in sys.path if p != project_dir]
sys.path.insert(0, project_dir)

DATA_ROOT = "/kaggle/input/ariel-data-challenge-2024"

print(f"project_dir: {project_dir}")
print(f"DATA_ROOT  : {DATA_ROOT}")
print(f"sys.path[0]: {sys.path[0]}")

In [None]:
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

from src.preprocessing import (
    out_of_transit_mask,
    baseline_normalize,
    common_mode_correction,
    bin_time,
    extract_transit_depth,
    preprocess_planet,
)

# ── Plot style: light spines, white background on saved figures ────────────
plt.rcParams.update(
    {
        "figure.dpi": 110,
        "savefig.dpi": 150,
        "savefig.facecolor": "white",
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
)

# ── Figure output directory (created on first run, reused afterwards) ──────
FIG_DIR = Path("/kaggle/working") / "figures_preprocessing"
FIG_DIR.mkdir(exist_ok=True, parents=True)

# Record the library versions in the notebook output for reproducibility.
print(
    "\n".join(
        f"{name:<11}: {value}"
        for name, value in (
            ("NumPy", np.__version__),
            ("Pandas", pd.__version__),
            ("Matplotlib", matplotlib.__version__),
            ("FIG_DIR", FIG_DIR),
        )
    )
)
print("[Done] All imports successful.")

## Step 0: Load and Calibrate One Planet

The raw data lives in per-planet directories as parquet files:

```
{data_root}/train/{planet_id}/
    AIRS-CH0_signal.parquet          (11250, 32*356) uint16
    FGS1_signal.parquet              (135000, 32*32) uint16
    AIRS-CH0_calibration/
        dark.parquet  flat.parquet  dead.parquet  ...
    FGS1_calibration/
        dark.parquet  flat.parquet  dead.parquet  ...
```

**Calibration** converts raw detector counts to science-ready flux:
1. **Dark subtraction** — remove thermal electron current: `cal = raw - dark`
2. **Flat-field division** — correct per-pixel sensitivity: `cal /= flat`
3. **Dead pixel zeroing** — mask bad pixels to zero

After calibration, we **sum over the spatial rows** (32 rows for AIRS, 32x32 pixels for FGS1) to collapse the 2-D detector image into 1-D light curves:
- AIRS: `(n_time, 32, 356)` → `(n_time, 356)` — one flux value per spectral channel per time step
- FGS1: `(n_time_fgs, 32, 32)` → `(n_time_fgs,)` → downsample 12:1 → `(n_time,)` — broadband flux at AIRS cadence

In [None]:
import os
from pathlib import Path

# ── Detector geometry (confirmed) ──────────────────────────────────────────
# AIRS-CH0 frames: spatial rows x spectral channels.
AIRS_N_ROWS, AIRS_N_COLS = 32, 356
# FGS1 frames: square photometer image.
FGS1_N_ROWS, FGS1_N_COLS = 32, 32
# FGS1 frames per AIRS frame (135000 FGS1 frames / 11250 AIRS frames).
FGS1_RATIO = 135_000 // 11_250


def calibrate(raw, dark, flat, dead):
    """Apply dark subtraction, flat-field division, and dead pixel zeroing.

    Parameters
    ----------
    raw : array_like, shape (n_time, *frame_shape)
        Stack of raw detector frames. Integer inputs (e.g. uint16 counts) are
        promoted to float64 before any arithmetic.
    dark : array_like, shape frame_shape
        Dark frame subtracted from every raw frame.
    flat : array_like, shape frame_shape
        Flat field every frame is divided by. Zero entries are replaced by 1.0
        (pixel left unscaled) to avoid division by zero.
    dead : array_like, shape frame_shape
        Dead-pixel map; truthy entries are set to 0.0 in every output frame.

    Returns
    -------
    numpy.ndarray
        New float64 array with the same shape as ``raw``; inputs are not modified.
    """
    # Promote to float64 first: with integer inputs `raw - dark` would wrap
    # around (uint16 underflow) and the in-place true division below would raise.
    cal = np.asarray(raw, dtype=np.float64) - np.asarray(dark, dtype=np.float64)[None]
    flat = np.asarray(flat, dtype=np.float64)
    flat_safe = np.where(flat == 0, 1.0, flat)
    cal /= flat_safe[None]
    cal[:, np.asarray(dead, dtype=bool)] = 0.0
    return cal


data_root = Path(DATA_ROOT)
train_dir = data_root / "train"


def _load_calibrated_cube(planet_dir, instrument, n_rows, n_cols, label):
    """Load one instrument's raw frames for a planet and return them calibrated.

    Reads ``{instrument}_signal.parquet`` (one flattened frame per row),
    reshapes it to ``(n_time, n_rows, n_cols)`` and applies ``calibrate`` with
    the dark / flat / dead maps from ``{instrument}_calibration/``. The raw
    float64 copy is local to this function, so it is released as soon as the
    calibrated cube is returned (roughly 1 GB per cube at the documented sizes).
    """
    raw = (
        pd.read_parquet(planet_dir / f"{instrument}_signal.parquet")
        .values.astype(np.float64)
    )
    n_time = raw.shape[0]
    raw = raw.reshape(n_time, n_rows, n_cols)
    print(f"  {label} raw shape     : ({n_time}, {n_rows}, {n_cols})")

    cal_dir = planet_dir / f"{instrument}_calibration"
    dark = pd.read_parquet(cal_dir / "dark.parquet").values.astype(np.float64)
    flat = pd.read_parquet(cal_dir / "flat.parquet").values.astype(np.float64)
    dead = pd.read_parquet(cal_dir / "dead.parquet").values
    return calibrate(raw, dark, flat, dead)


# ── Discover planet directories ────────────────────────────────────────────
planet_dirs = (
    sorted(d for d in train_dir.iterdir() if d.is_dir())
    if train_dir.exists()
    else []
)

print(f"Planet directories found: {len(planet_dirs)}")
for d in planet_dirs[:5]:
    print(f"  {d.name}")
if len(planet_dirs) > 5:
    print(f"  ... ({len(planet_dirs) - 5} more)")

airs = None
fgs1 = None
planet_id = None

if planet_dirs:
    planet_dir = planet_dirs[0]
    planet_id = planet_dir.name
    print(f"\nLoading planet '{planet_id}' ...")

    try:
        # ── Load and calibrate AIRS-CH0 ────────────────────────────────────
        cal_airs = _load_calibrated_cube(
            planet_dir, "AIRS-CH0", AIRS_N_ROWS, AIRS_N_COLS, "AIRS"
        )
        n_time_airs = cal_airs.shape[0]
        airs = cal_airs.sum(axis=1)  # sum spatial rows → (n_time, 356)
        print(f"  AIRS calibrated    : {cal_airs.shape} → summed to {airs.shape}")
        del cal_airs  # free the full cube before loading FGS1

        # ── Load and calibrate FGS1 ───────────────────────────────────────
        cal_fgs = _load_calibrated_cube(
            planet_dir, "FGS1", FGS1_N_ROWS, FGS1_N_COLS, "FGS1"
        )
        n_time_fgs = cal_fgs.shape[0]
        fgs1_full = cal_fgs.sum(axis=(1, 2))  # sum all pixels → (n_time_fgs,)
        print(f"  FGS1 calibrated    : {cal_fgs.shape} → summed to {fgs1_full.shape}")
        del cal_fgs

        # Downsample FGS1 to AIRS cadence (average every `ratio` frames)
        ratio = n_time_fgs // n_time_airs
        if ratio > 1:
            trimmed = fgs1_full[: n_time_airs * ratio]
            fgs1 = trimmed.reshape(n_time_airs, ratio).mean(axis=1)
        else:
            fgs1 = fgs1_full[:n_time_airs]
        print(f"  FGS1 downsampled   : {fgs1_full.shape} → {fgs1.shape}  (ratio={ratio}:1)")
        if fgs1.shape[0] != n_time_airs:
            # Only possible when FGS1 has fewer frames than AIRS (ratio == 0).
            print(f"  WARNING: FGS1 has {fgs1.shape[0]} frames but AIRS has {n_time_airs}.")

    except Exception as e:
        print(f"Error loading planet data: {e}")
        airs = None
        fgs1 = None

# ── Fallback: synthetic data so all cells run offline ─────────────────────
if airs is None:
    reason = "No parquet data found" if not planet_dirs else "Could not load parquet data"
    print(f"\n{reason}. Generating SYNTHETIC data for offline demonstration.")
    rng = np.random.default_rng(42)
    N_TIME, N_CHAN = 300, AIRS_N_COLS
    t_norm = np.linspace(0, 1, N_TIME)
    # Flat baseline with a ~1% transit dip between 0.2 and 0.8
    transit_mask_synth = (t_norm >= 0.2) & (t_norm <= 0.8)
    # Channel-dependent depth: shallow at short wavelengths, deeper at long
    depth_per_chan = 0.005 + 0.005 * np.linspace(0, 1, N_CHAN)
    airs = 1e5 * (
        1.0
        - transit_mask_synth[:, None] * depth_per_chan[None, :]
        + rng.normal(0, 0.001, (N_TIME, N_CHAN))
    )
    fgs1 = 2e5 * (
        1.0
        - transit_mask_synth * 0.008
        + rng.normal(0, 0.001, N_TIME)
    )
    planet_id = "synthetic-planet-000"
    print(f"  Synthetic AIRS shape : {airs.shape}")
    print(f"  Synthetic FGS1 shape : {fgs1.shape}")

# Ensure AIRS is (time, wavelength). Both branches above already produce that
# layout with AIRS_N_COLS channels on axis 1, so only transpose when the channel
# count shows up on axis 0 instead. Comparing the two axis lengths is not a valid
# test: the 300-frame synthetic cube has fewer frames than channels.
if airs.ndim == 2 and airs.shape[1] != AIRS_N_COLS and airs.shape[0] == AIRS_N_COLS:
    print(f"Transposing AIRS from {airs.shape} to {airs.T.shape} (channel axis found first)")
    airs = airs.T

# Ensure FGS1 is 1-D
fgs1 = fgs1.ravel()

print(f"\nAIRS shape : {airs.shape}  (time x wavelength)")
print(f"FGS1 shape : {fgs1.shape}  (time,)")
print(f"[Done] Loaded planet '{planet_id}'.")

## Step 1: Out-of-Transit Mask

We define a boolean mask over the time axis that is `True` wherever the planet is **not** in front of the star.

- **Ingress fraction**: the normalised time at which the planet disc first overlaps the stellar disc.
- **Egress fraction**: the normalised time at which the planet disc last overlaps the stellar disc.
- Frames between ingress and egress are **in-transit** (mask = `False`); all others are **out-of-transit** (mask = `True`).

The default `ingress=0.2`, `egress=0.8` means the transit occupies the central 60% of the observation window.

In [None]:
# Ingress / egress as fractions of the observation window; frames between
# them are treated as in-transit (see the markdown above).
INGRESS = 0.2
EGRESS  = 0.8

n_time = airs.shape[0]
# mask_oot is True for out-of-transit (baseline) frames, False in transit.
mask_oot = out_of_transit_mask(n_time, ingress=INGRESS, egress=EGRESS)

n_oot = mask_oot.sum()
n_it  = (~mask_oot).sum()

print(f"Total time steps : {n_time}")
print(f"OOT frames       : {n_oot}  ({100 * n_oot / n_time:.1f}%)")
print(f"In-transit frames: {n_it}   ({100 * n_it / n_time:.1f}%)")
# NOTE(review): these frame numbers are int(fraction * n_time); the boundary
# frame chosen by out_of_transit_mask may differ by one depending on its
# rounding — confirm in src/preprocessing.py.
print(f"Ingress fraction : {INGRESS}  (frame {int(INGRESS * n_time)})")
print(f"Egress fraction  : {EGRESS}   (frame {int(EGRESS * n_time)})")

# ── Plot the mask as a shaded bar ──────────────────────────────────────────
fig, ax = plt.subplots(figsize=(12, 2.5))

t_idx = np.arange(n_time)

# Draw OOT (green) and in-transit (red) as filled regions
ax.fill_between(t_idx, 0, 1,
                where=mask_oot,
                step="mid",
                color="mediumseagreen", alpha=0.45, label="Out-of-transit (OOT)")
ax.fill_between(t_idx, 0, 1,
                where=~mask_oot,
                step="mid",
                color="tomato", alpha=0.45, label="In-transit")

# Ingress / egress vertical lines (fractions converted to frame positions)
ax.axvline(INGRESS * n_time, color="navy", lw=1.5, linestyle="--", label=f"Ingress (t={INGRESS})")
ax.axvline(EGRESS  * n_time, color="darkred", lw=1.5, linestyle="--", label=f"Egress  (t={EGRESS})")

ax.set_xlim(0, n_time - 1)
ax.set_ylim(0, 1)
ax.set_xlabel("Time index", fontsize=11)
ax.set_yticks([])  # the y-axis carries no information for a mask bar
ax.set_title(
    f"Out-of-Transit Mask  |  {n_oot} OOT frames (green)  +  {n_it} in-transit frames (red)  |  "
    f"ingress={INGRESS}, egress={EGRESS}",
    fontsize=11
)
ax.legend(loc="upper right", fontsize=9, framealpha=0.9)

plt.tight_layout()
fig.savefig(FIG_DIR / "step1_oot_mask.png", bbox_inches="tight")
plt.show()

print(f"[Done] OOT mask: {n_oot} baseline frames, {n_it} in-transit frames.")

## Step 2: Baseline Normalisation

Each wavelength channel is divided by its **out-of-transit median**. After normalisation:

- OOT flux fluctuates around **1.0** for every channel independently.
- In-transit flux dips **below 1.0** by an amount equal to the transit depth `(Rp/Rs)²`.
- Absolute detector gain differences between channels are removed.

We use the **AIRS white-light curve** (mean across all 356 wavelength channels) to illustrate.

In [None]:
# White-light curve = mean over all wavelength channels
airs_white_raw = airs.mean(axis=1)           # (n_time,)

# Baseline normalise the full 2-D cube (per the markdown above: each channel
# is divided by its own OOT median).
airs_norm = baseline_normalize(airs, mask_oot)  # (n_time, n_chan)

# Normalised white-light
airs_white_norm = airs_norm.mean(axis=1)      # (n_time,)

# Diagnostics: the OOT median of the normalised white-light curve should be
# close to 1.0 if the normalisation worked.
oot_median_before = np.median(airs_white_raw[mask_oot])
oot_median_after  = np.median(airs_white_norm[mask_oot])
print(f"OOT median (raw)         : {oot_median_before:.4f}")
print(f"OOT median (normalised)  : {oot_median_after:.6f}  (should be ≈ 1.0)")

t_idx = np.arange(n_time)

fig, axes = plt.subplots(2, 1, figsize=(12, 7), sharex=True)
fig.suptitle("Baseline Normalisation — AIRS White-Light Curve", fontsize=13, fontweight="bold")

# ── Raw flux ───────────────────────────────────────────────────────────────
ax0 = axes[0]
ax0.plot(t_idx, airs_white_raw, lw=0.9, color="steelblue", label="Raw white-light flux")
ax0.axhline(oot_median_before, color="darkgreen", lw=1.2, linestyle="--",
            label=f"OOT median = {oot_median_before:.4f}")
# Shade the in-transit window over the full data range
ax0.fill_between(t_idx, airs_white_raw.min(), airs_white_raw.max(),
                 where=~mask_oot, color="tomato", alpha=0.12, label="In-transit window")
ax0.set_ylabel("Detector counts", fontsize=10)
ax0.set_title("Raw flux (ADU)", fontsize=10)
ax0.legend(fontsize=8)

# ── Normalised flux ────────────────────────────────────────────────────────
ax1 = axes[1]
ax1.plot(t_idx, airs_white_norm, lw=0.9, color="darkorchid", label="Normalised white-light flux")
ax1.axhline(1.0, color="darkgreen", lw=1.2, linestyle="--", label="OOT median = 1.0")
# Shade the in-transit window; the fixed upper edge (1.001) assumes the
# normalised OOT scatter stays within ~0.1% of 1.0.
ax1.fill_between(t_idx, airs_white_norm.min() * 0.9998, 1.001,
                 where=~mask_oot, color="tomato", alpha=0.12, label="In-transit window")
ax1.set_xlabel("Time index", fontsize=10)
ax1.set_ylabel("Normalised flux", fontsize=10)
ax1.set_title("Normalised flux (dimensionless, OOT ≈ 1.0)", fontsize=10)
ax1.legend(fontsize=8)

plt.tight_layout()
fig.savefig(FIG_DIR / "step2_baseline_normalisation.png", bbox_inches="tight")
plt.show()

print(f"AIRS normalised cube shape: {airs_norm.shape}")
print(f"[Done] Baseline normalisation complete. OOT median before={oot_median_before:.4f}, after={oot_median_after:.6f}.")

## Step 3: Common-Mode Correction

Even after per-channel normalisation, correlated noise can affect every wavelength simultaneously:

- Telescope pointing jitter moves the target slightly on the detector.
- Thermal "breathing" changes the PSF size uniformly across the focal plane.
- Stellar variability (granulation, flares) brightens or dims all channels together.

The **common mode** is the wavelength-mean light curve at each time step. Dividing by it removes these shared systematics while leaving wavelength-dependent signals (atmospheric absorption lines) intact.

> **Key subtlety**: In-transit frames naturally dip below 1.0 (the planet blocks light). If we divided all frames by the raw common mode including this dip, we would accidentally cancel the transit signal we want to measure. Therefore, in-transit common-mode values are **clamped** to the OOT mean before dividing.

In [None]:
# Common mode = mean over all wavelength channels per time step.
# This is the *unclamped* common mode, computed here for display only; the
# in-transit clamping described above happens inside common_mode_correction
# (src/preprocessing.py, not visible in this notebook).
common_mode = airs_norm.mean(axis=1)          # (n_time,)
oot_level   = common_mode[mask_oot].mean()    # scalar ≈ 1.0

# Apply common-mode correction
airs_cmc = common_mode_correction(airs_norm, mask_oot)  # (n_time, n_chan)

t_idx = np.arange(n_time)

fig, axes = plt.subplots(2, 1, figsize=(12, 7), sharex=True)
fig.suptitle("Common-Mode Correction", fontsize=13, fontweight="bold")

# ── Common-mode signal ─────────────────────────────────────────────────────
ax0 = axes[0]
ax0.plot(t_idx, common_mode, lw=0.9, color="darkorange", label="Common-mode signal (λ-mean)")
ax0.axhline(oot_level, color="darkgreen", lw=1.2, linestyle="--",
            label=f"OOT mean = {oot_level:.5f}")
ax0.fill_between(t_idx, common_mode.min() * 0.9998, oot_level * 1.0002,
                 where=~mask_oot, color="tomato", alpha=0.15)
ax0.set_ylabel("Wavelength-mean flux", fontsize=10)
ax0.set_title("Common-mode signal (mean over all 356 AIRS channels)", fontsize=10)
ax0.legend(fontsize=8)

# Annotate the clamping: the arrow points at the mid-observation frame, which
# lies inside the in-transit window for the default ingress/egress.
ax0.annotate(
    "In-transit values clamped to OOT mean\nbefore dividing (prevents transit self-subtraction)",
    xy=(n_time * 0.5, common_mode[n_time // 2]),
    xytext=(n_time * 0.62, oot_level + (oot_level - common_mode.min()) * 0.6),
    arrowprops=dict(arrowstyle="->", color="gray", lw=1.2),
    fontsize=8, color="gray",
    bbox=dict(boxstyle="round,pad=0.3", facecolor="lightyellow", edgecolor="gray", alpha=0.8),
)

# ── CMC white-light curve ──────────────────────────────────────────────────
ax1 = axes[1]
airs_cmc_white = airs_cmc.mean(axis=1)
ax1.plot(t_idx, airs_cmc_white, lw=0.9, color="steelblue",
         label="CMC-corrected white-light flux")
ax1.axhline(1.0, color="darkgreen", lw=1.2, linestyle="--", label="Expected OOT baseline = 1.0")
ax1.fill_between(t_idx, airs_cmc_white.min() * 0.9998, 1.001,
                 where=~mask_oot, color="tomato", alpha=0.12, label="In-transit window")
ax1.set_xlabel("Time index", fontsize=10)
ax1.set_ylabel("Normalised flux", fontsize=10)
ax1.set_title("After common-mode correction — correlated systematics removed", fontsize=10)
ax1.legend(fontsize=8)

plt.tight_layout()
fig.savefig(FIG_DIR / "step3_common_mode_correction.png", bbox_inches="tight")
plt.show()

print(f"Common-mode OOT level : {oot_level:.6f}")
print(f"CMC cube shape        : {airs_cmc.shape}")
# White-light depth = 1 - mean in-transit flux of the CMC-corrected curve.
print(f"[Done] Common-mode correction applied. CMC white-light in-transit depth: "
      f"{1 - airs_cmc_white[~mask_oot].mean():.5f}.")

## Step 4: Time Binning

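Co-adding `bin_size` consecutive frames improves the per-sample signal-to-noise ratio by a factor of **√bin_size** when the noise is white (uncorrelated between frames); correlated noise averages down more slowly.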

With `bin_size=5`:
- Time axis shrinks from `n_time` to `n_time // 5` frames.
- Expected SNR improvement: **√5 ≈ 2.24×**.
- Trailing frames that don't fill a complete bin are dropped.

Below we compare 3 representative wavelength channels before and after binning.

In [None]:
BIN_SIZE = 5

# Bin the common-mode-corrected cube along time (a trailing partial bin is dropped).
airs_binned = bin_time(airs_cmc, bin_size=BIN_SIZE)  # (n_time // BIN_SIZE, n_chan)

n_time_binned = airs_binned.shape[0]
expected_snr_gain = np.sqrt(BIN_SIZE)

print(f"Before binning : {airs_cmc.shape[0]} time steps")
print(f"After binning  : {n_time_binned} time steps  (bin_size={BIN_SIZE})")
print(f"Expected SNR gain (white noise): √{BIN_SIZE} = {expected_snr_gain:.3f}×")

# Binned OOT mask: a bin counts as OOT when most of its frames are OOT.
# It is identical for every channel, so it is computed once, not per plot row.
# NOTE(review): "> 0.5 means majority" assumes bin_time averages (rather than
# sums) within each bin — confirm in src/preprocessing.py.
mask_binned = bin_time(mask_oot.astype(np.float32), BIN_SIZE) > 0.5

# Representative channels: first, middle, last
n_chan = airs_cmc.shape[1]
chan_indices = [0, n_chan // 2, n_chan - 1]
chan_labels  = ["Channel 0 (blue edge)", f"Channel {n_chan // 2} (middle)",
                f"Channel {n_chan - 1} (red edge)"]

t_before = np.arange(airs_cmc.shape[0])
# Bin k covers frames k*BIN_SIZE ... k*BIN_SIZE + BIN_SIZE - 1, so its centre on
# the frame-index axis is k*BIN_SIZE + (BIN_SIZE - 1) / 2 (not + BIN_SIZE / 2).
t_after  = np.arange(n_time_binned) * BIN_SIZE + (BIN_SIZE - 1) / 2

fig, axes = plt.subplots(3, 2, figsize=(14, 10), sharex=False)
fig.suptitle(
    f"Time Binning (bin_size={BIN_SIZE})  |  Expected SNR gain ≈ √{BIN_SIZE} = {expected_snr_gain:.2f}×",
    fontsize=13, fontweight="bold"
)

for row, (ci, cl) in enumerate(zip(chan_indices, chan_labels)):
    raw_lc  = airs_cmc[:, ci]
    bin_lc  = airs_binned[:, ci]

    # Empirical noise (OOT std) before and after binning; the ratio is the
    # measured SNR gain to compare with √BIN_SIZE.
    noise_before = raw_lc[mask_oot].std() if mask_oot.sum() > 1 else np.nan
    noise_after  = bin_lc[mask_binned].std() if mask_binned.sum() > 1 else np.nan
    actual_gain  = noise_before / noise_after if noise_after > 0 else np.nan

    ax_left  = axes[row, 0]
    ax_right = axes[row, 1]

    # Before
    ax_left.plot(t_before, raw_lc, lw=0.6, color="steelblue", alpha=0.85)
    ax_left.fill_between(t_before, raw_lc.min(), raw_lc.max(),
                         where=~mask_oot, color="tomato", alpha=0.1)
    ax_left.set_title(f"{cl} — Before (σ_OOT={noise_before:.5f})", fontsize=9)
    ax_left.set_xlabel("Time index", fontsize=8)
    ax_left.set_ylabel("Flux", fontsize=8)
    ax_left.tick_params(labelsize=7)

    # After
    ax_right.plot(t_after, bin_lc, lw=1.0, color="darkorchid", alpha=0.9,
                  marker="o", markersize=2)
    ax_right.fill_between(t_after, bin_lc.min(), bin_lc.max(),
                          where=~mask_binned, color="tomato", alpha=0.1)
    ax_right.set_title(
        f"{cl} — After binning (σ_OOT={noise_after:.5f}, gain={actual_gain:.2f}×)",
        fontsize=9
    )
    ax_right.set_xlabel("Time index (bin centres)", fontsize=8)
    ax_right.set_ylabel("Flux", fontsize=8)
    ax_right.tick_params(labelsize=7)

plt.tight_layout()
fig.savefig(FIG_DIR / "step4_time_binning.png", bbox_inches="tight")
plt.show()

print(f"AIRS binned shape : {airs_binned.shape}")
print(f"[Done] Temporal binning done. Expected SNR gain ≈ {expected_snr_gain:.3f}×.")

## Step 5: Transit Depth Extraction

The **transit depth** at wavelength λ is:

$$d[\lambda] = 1 - \langle F_{\text{norm}}[\text{in-transit}, \lambda] \rangle$$

This equals the fraction of stellar light blocked by the planet's cross-section at that wavelength.  
When the atmosphere absorbs at a specific wavelength, the planet appears larger → deeper transit → higher `d[λ]`.

Plotting `d[λ]` vs wavelength gives the planet's **transmission spectrum** — the target the model must learn to predict.

In [None]:
# Binned OOT mask (same definition as in Step 4: a bin is OOT when most of its
# frames are OOT). Recomputed here so this cell only needs mask_oot and BIN_SIZE.
mask_binned = bin_time(mask_oot.astype(np.float32), BIN_SIZE) > 0.5

# Extract transit depth and its 1σ uncertainty for every wavelength channel
depth, depth_err = extract_transit_depth(airs_binned, mask_binned)

n_chan = airs_binned.shape[1]
chan_axis = np.arange(n_chan)
median_depth = np.median(depth)

print(f"Wavelength channels : {n_chan}")
print(f"depth shape         : {depth.shape}")
print(f"depth_err shape     : {depth_err.shape}")
print(f"Median transit depth: {median_depth:.5f}")
print(f"Min / Max depth     : {depth.min():.5f} / {depth.max():.5f}")
print(f"Median 1σ error     : {np.median(depth_err):.5f}")

fig, axes = plt.subplots(2, 1, figsize=(13, 8))
fig.suptitle(
    f"Transit Depth Spectrum — planet '{planet_id}'",
    fontsize=13, fontweight="bold"
)

# ── Top: spectrum with error bars ─────────────────────────────────────────
ax0 = axes[0]
ax0.errorbar(
    chan_axis, depth, yerr=depth_err,
    fmt="o", markersize=2, lw=0.5, capsize=1.5,
    color="steelblue", ecolor="lightsteelblue", alpha=0.85,
    label="Transit depth ± 1σ"
)
ax0.axhline(median_depth, color="darkorange", lw=1.2, linestyle="--",
            label=f"Median depth = {median_depth:.5f}")
ax0.set_xlabel("Wavelength channel index", fontsize=10)
ax0.set_ylabel("Transit depth  d[λ] = 1 − ⟨F_in⟩", fontsize=10)
ax0.set_title("Per-channel transit depth (this is the target the model must predict)", fontsize=10)
ax0.legend(fontsize=9)

# ── Bottom: SNR per channel = depth / depth_err ───────────────────────────
ax1 = axes[1]
# Channels without a positive uncertainty get SNR 0. `where=` skips those
# divisions entirely, so zero errors no longer emit divide-by-zero warnings.
snr = np.divide(
    depth, depth_err,
    out=np.zeros(np.shape(depth), dtype=np.float64),
    where=depth_err > 0,
)
median_snr = np.median(snr)
ax1.plot(chan_axis, snr, lw=0.8, color="darkorchid", alpha=0.85)
ax1.axhline(median_snr, color="darkorange", lw=1.2, linestyle="--",
            label=f"Median SNR = {median_snr:.1f}")
ax1.set_xlabel("Wavelength channel index", fontsize=10)
ax1.set_ylabel("SNR = depth / depth_err", fontsize=10)
ax1.set_title("Signal-to-noise ratio per wavelength channel", fontsize=10)
ax1.legend(fontsize=9)

plt.tight_layout()
fig.savefig(FIG_DIR / "step5_transit_depth_spectrum.png", bbox_inches="tight")
plt.show()

print(f"[Done] Transit depth spectrum extracted: {n_chan} channels, median depth={median_depth:.5f}, "
      f"median SNR={median_snr:.1f}.")

## Full Pipeline: `preprocess_planet()`

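Steps 1–5 above (everything after detector calibration) are wrapped into the single function `preprocess_planet()`, which takes the calibrated AIRS and FGS1 light curves from Step 0 as input. Running it should reproduce the step-by-step results above; the cell below checks this.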

In [None]:
# Run the one-liner full pipeline
result = preprocess_planet(
    airs,
    fgs1,
    ingress=INGRESS,
    egress=EGRESS,
    bin_size=BIN_SIZE,
)

# ── Inspect output dict ────────────────────────────────────────────────────
print("preprocess_planet() output keys and shapes:")
print("-" * 50)
for key, val in result.items():
    arr = np.asarray(val)
    print(f"  {key:<22} shape={arr.shape}  dtype={arr.dtype}")

# ── Verify consistency with step-by-step ──────────────────────────────────
print("\nConsistency checks (pipeline result vs step-by-step):")

depth_pipeline = result["transit_depth"]
max_diff_depth = np.max(np.abs(depth_pipeline - depth))
print(f"  max |depth_pipeline - depth_stepwise| = {max_diff_depth:.2e}  "
      f"({'PASS' if max_diff_depth < 1e-10 else 'MISMATCH — investigate!'})")

airs_norm_pipeline = result["airs_norm"]
max_diff_airs = np.max(np.abs(airs_norm_pipeline - airs_binned))
print(f"  max |airs_norm_pipeline - airs_binned_stepwise| = {max_diff_airs:.2e}  "
      f"({'PASS' if max_diff_airs < 1e-10 else 'MISMATCH — investigate!'})")

mask_pipeline = result["mask_oot"]
mask_match = np.array_equal(mask_pipeline, mask_binned)
print(f"  mask_oot identical: {mask_match}  "
      f"({'PASS' if mask_match else 'MISMATCH — investigate!'})")

print(f"\n[Done] Full pipeline result: {len(result)} keys, "
      f"AIRS shape={result['airs_norm'].shape}, FGS1 shape={result['fgs1_norm'].shape}.")

## Pipeline Output Visualisation

Final combined plot: normalised + binned AIRS white-light curve and FGS1 curve, alongside the extracted transit depth spectrum.

In [None]:
# Unpack the preprocess_planet() outputs (shape comments as in the original).
airs_out    = result["airs_norm"]          # (n_time_bin, n_chan)
fgs1_out    = result["fgs1_norm"]          # (n_time_bin,)
depth_out   = result["transit_depth"]      # (n_chan,)
err_out     = result["transit_depth_err"]  # (n_chan,)
mask_out    = result["mask_oot"]           # (n_time_bin,)

n_time_bin  = airs_out.shape[0]
n_chan       = airs_out.shape[1]
t_bin        = np.arange(n_time_bin)  # binned-frame index (not raw frame index)
chan_ax      = np.arange(n_chan)

airs_white_out = airs_out.mean(axis=1)

fig, axes = plt.subplots(3, 1, figsize=(13, 11))
fig.suptitle(
    f"Full Preprocessing Pipeline Output — planet '{planet_id}'",
    fontsize=13, fontweight="bold"
)

# ── Binned AIRS white-light ────────────────────────────────────────────────
ax0 = axes[0]
ax0.plot(t_bin, airs_white_out, lw=1.0, color="steelblue", marker="o", markersize=2.5,
         label="AIRS white-light (binned, CMC-corrected)")
ax0.axhline(1.0, color="darkgreen", lw=1.0, linestyle="--", alpha=0.7)
ax0.fill_between(t_bin, airs_white_out.min() * 0.9998, 1.0002,
                 where=~mask_out, color="tomato", alpha=0.12, label="In-transit")
ax0.set_ylabel("Normalised flux", fontsize=10)
ax0.set_title(f"AIRS-CH0 white-light — {n_time_bin} binned frames (bin_size={BIN_SIZE})", fontsize=10)
ax0.legend(fontsize=8)

# ── Binned FGS1 ───────────────────────────────────────────────────────────
ax1 = axes[1]
ax1.plot(t_bin, fgs1_out, lw=1.0, color="darkorange", marker="o", markersize=2.5,
         label="FGS1 (binned, baseline-normalised)")
ax1.axhline(1.0, color="darkgreen", lw=1.0, linestyle="--", alpha=0.7)
ax1.fill_between(t_bin, fgs1_out.min() * 0.9998, 1.0002,
                 where=~mask_out, color="tomato", alpha=0.12, label="In-transit")
ax1.set_ylabel("Normalised flux", fontsize=10)
ax1.set_title(f"FGS1 broadband — {n_time_bin} binned frames (bin_size={BIN_SIZE})", fontsize=10)
ax1.legend(fontsize=8)

# ── Transit depth spectrum ─────────────────────────────────────────────────
ax2 = axes[2]
ax2.errorbar(
    chan_ax, depth_out, yerr=err_out,
    fmt="o", markersize=2, lw=0.5, capsize=1.5,
    color="steelblue", ecolor="lightsteelblue", alpha=0.85,
    label="Transit depth ± 1σ  (MODEL TARGET)"
)
ax2.axhline(np.median(depth_out), color="darkorange", lw=1.2, linestyle="--",
            label=f"Median = {np.median(depth_out):.5f}")
# NOTE(review): the "0 = 1.95 µm, 355 = 3.90 µm" label is hard-coded for a
# 356-channel axis in ascending wavelength order — confirm the channel order
# against the competition's wavelength grid before relying on it.
ax2.set_xlabel("Wavelength channel index  (0 = 1.95 µm, 355 = 3.90 µm)", fontsize=10)
ax2.set_ylabel("Transit depth  1 − ⟨F_in⟩", fontsize=10)
ax2.set_title("Extracted transit depth spectrum — this is what the model learns to predict",
              fontsize=10)
ax2.legend(fontsize=8)

plt.tight_layout()
fig.savefig(FIG_DIR / "step6_full_pipeline_output.png", bbox_inches="tight")
plt.show()

print(f"[Done] Full pipeline visualisation complete for planet '{planet_id}'.")

## Summary

### What each preprocessing step achieves and why it matters

- **Step 0 — Detector calibration (raw parquet → calibrated numpy)**
  - Loads raw uint16 detector frames from per-planet parquet files.
  - **Dark subtraction**: removes thermal electron current accumulated during readout.
  - **Flat-field division**: corrects for per-pixel sensitivity variations across the detector.
  - **Dead pixel zeroing**: masks hot/dead pixels to prevent them from corrupting the signal.
  - **Spatial summation**: collapses the 2-D detector image (32 spatial rows for AIRS, 32x32 pixels for FGS1) into 1-D light curves per spectral channel.
  - **FGS1 downsampling**: averages every 12 FGS1 frames to match the AIRS cadence (135,000 → 11,250 frames).
  - This "Step 0" converts ~180 GB of raw detector data into clean `(n_time, 356)` AIRS and `(n_time,)` FGS1 arrays ready for signal processing.

- **Step 1 — Out-of-transit (OOT) mask**
  - Separates frames where the planet is in front of the star from those where it is not.
  - The OOT frames serve as the *reference baseline* for all subsequent normalisation steps.
  - Without this mask, in-transit frames would contaminate the baseline and bias the measured depths.

- **Step 2 — Baseline normalisation**
  - Divides each wavelength channel by its OOT median, making every channel dimensionless and centred near 1.0.
  - Removes channel-to-channel gain differences (quantum efficiency variations, detector non-uniformity).
  - After this step, in-transit flux dipping below 1.0 directly represents the atmospheric absorption signal.

- **Step 3 — Common-mode correction**
  - Removes correlated systematics shared across all wavelengths simultaneously: telescope jitter, PSF breathing, stellar variability.
  - Divides each time-step by the wavelength-mean of that step (the "common mode").
  - **Critical detail**: in-transit common-mode values are clamped to the OOT mean before dividing. This prevents accidental cancellation of the transit signal (which dips in-transit) against itself.
  - After this step, only wavelength-specific signals — atmospheric absorption features — remain.

- **Step 4 — Temporal binning**
  - Co-adds every `bin_size` consecutive frames, reducing the time axis by a factor of `bin_size`.
  - Reduces photon noise by sqrt(`bin_size`) (approx 2.24x for `bin_size=5`) at the cost of temporal resolution.
  - Lowers the feature dimensionality for downstream ML models and speeds up training.

- **Step 5 — Transit depth extraction**
  - Computes `d[lam] = 1 - mean(F_in-transit[lam])` for each wavelength channel independently.
  - Produces the **transmission spectrum**: the planet's apparent radius as a function of wavelength.
  - This 356-element vector, once mapped onto the competition's 283 target wavelengths, is what the ML model must predict.
  - The 1-sigma uncertainty (standard error of the in-transit mean) quantifies per-channel measurement noise.

### Why this pipeline matters for the model

- The raw AIRS parquet contains `(~11,250 time steps x 32 spatial rows x 356 channels)` — over 128 million numbers per planet.
- After calibration, spatial summation, and preprocessing, the transmission spectrum is just **356 numbers** — a massive compression that retains the scientifically relevant signal.
- The competition metric (Gaussian log-likelihood on predicted mean and std per wavelength) requires well-calibrated uncertainty estimates, making the `depth_err` output directly useful.
- Baseline and common-mode corrections are standard photometric reduction steps used in real transit spectroscopy (e.g., Hubble WFC3, JWST NIRSpec pipelines).
- Auxiliary information is provided as 5 ADC features per planet (ADC gain/offset for AIRS and FGS1, plus star type) from `train_adc_info.csv`; these are not used by the simplified Step 0 calibration in this notebook.

## Push Preprocessing Figures to GitHub

Push the 6 step-by-step preprocessing plots to the repo so they can be reviewed without re-running the notebook.

In [None]:
import os
import shutil
import subprocess
from pathlib import Path

# ── GitHub token for pushing ──────────────────────────────────────────────
# Read from the GH_TOKEN environment variable (e.g. exported from a Kaggle
# secret) instead of pasting it here: the notebook source is saved with the
# kernel. Generate at: https://github.com/settings/tokens → Fine-grained → repo:write
GH_TOKEN = os.environ.get("GH_TOKEN", "")  # or paste your GitHub PAT here (not recommended)

# ── Repo paths ────────────────────────────────────────────────────────────
REPO_SLUG   = "Smooth-Cactus0/ariel-exoplanet-ml"
repo_url    = f"https://github.com/{REPO_SLUG}.git"
repo_dir    = Path("/kaggle/working/ariel-exoplanet-ml")
project_dir = repo_dir / "Kaggle competition" / "ARIEL neurIPS"


def _git(*args, check=True, capture=False):
    """Run `git -C repo_dir <args>` and return the CompletedProcess."""
    return subprocess.run(
        ["git", "-C", str(repo_dir), *args],
        check=check, capture_output=capture, text=True,
    )


# ── Ensure repo is up-to-date ─────────────────────────────────────────────
if not repo_dir.exists():
    subprocess.run(["git", "clone", repo_url, str(repo_dir)], check=True)
else:
    _git("pull", "--ff-only", check=False)  # best effort; local commits may diverge

# An older version of this cell stored the token in origin's URL inside
# .git/config (which lives under /kaggle/working); reset it to the public URL.
_git("remote", "set-url", "origin", repo_url)

# ── Configure git identity (required on Kaggle kernels) ───────────────────
_git("config", "user.email", "alexy.louis@kaggle-notebook.local")
_git("config", "user.name", "Alexy Louis (Kaggle)")

# ── Copy preprocessing figures to repo ─────────────────────────────────────
repo_fig_dir = project_dir / "figures"
repo_fig_dir.mkdir(parents=True, exist_ok=True)

figure_files = sorted(FIG_DIR.glob("*.png"))
print(f"Found {len(figure_files)} figures in {FIG_DIR}:")
for fig_path in figure_files:
    dest = repo_fig_dir / fig_path.name
    shutil.copy2(fig_path, dest)
    print(f"  {fig_path.name} -> figures/{fig_path.name}")

# ── Git add, commit, push ─────────────────────────────────────────────────
_git("add", str(repo_fig_dir.relative_to(repo_dir)))

# `diff --cached --quiet` exits non-zero when something is staged.
status = _git("diff", "--cached", "--quiet", check=False, capture=True)
if status.returncode == 0:
    print("\n[Done] No changes to push (figures already up-to-date).")
else:
    _git("commit", "-m", "data: update preprocessing figures from Kaggle notebook run")
    if not GH_TOKEN:
        print("\n[Skipped] Figures committed locally, but GH_TOKEN is empty — nothing pushed.")
    else:
        # Push to a one-off authenticated URL rather than rewriting `origin`, so
        # the token is never written to .git/config. `HEAD` pushes the current
        # branch to the same-named remote branch, whatever the default branch is.
        push_url = f"https://{GH_TOKEN}@github.com/{REPO_SLUG}.git"
        push = subprocess.run(
            ["git", "-C", str(repo_dir), "push", push_url, "HEAD"],
            capture_output=True, text=True,
        )
        if push.returncode != 0:
            # Don't let the token reach the notebook output through git's stderr
            # or a CalledProcessError (whose message contains the full command).
            raise RuntimeError("git push failed:\n" + push.stderr.replace(GH_TOKEN, "***"))
        print("\n[Done] Preprocessing figures pushed to GitHub.")