# NeurIPS 2024 Ariel Data Challenge — Exploratory Data Analysis

**Goal**: Understand the structure, distributions, and quality of the Ariel exoplanet atmospheric spectra dataset.  
**Dataset**: ~180 GB of simulated telescope photometry from AIRS-CH0 (IR spectrometer) and FGS1 (visible photometer).  
**Task**: Extract exoplanet atmospheric transmission spectra (283 wavelength channels) from transit light curves.  
**Scoring**: Gaussian Log-Likelihood over predicted mean and std per wavelength.
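
For intuition, the per-wavelength Gaussian log-likelihood can be sketched as below. This is only the raw log-likelihood summed over planets and wavelengths. The official competition metric may additionally normalise it against reference predictions, so treat this as a hedged approximation and not the scoring code. `gaussian_log_likelihood` is a hypothetical helper name.

```python
import numpy as np

def gaussian_log_likelihood(y_true, mu, sigma):
    """Sum of per-element Gaussian log-densities log N(y_true | mu, sigma^2).

    All inputs share one shape, e.g. (n_planets, 283); sigma must be > 0.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    mu = np.asarray(mu, dtype=np.float64)
    sigma = np.asarray(sigma, dtype=np.float64)
    if np.any(sigma <= 0):
        raise ValueError("sigma must be strictly positive")
    z2 = ((y_true - mu) / sigma) ** 2                    # squared standardised error
    log_pdf = -0.5 * (np.log(2.0 * np.pi) + 2.0 * np.log(sigma) + z2)
    return float(log_pdf.sum())

# Example: 2 planets x 283 wavelengths, perfect means with unit sigma.
y = np.zeros((2, 283))
print(gaussian_log_likelihood(y, y, np.ones_like(y)))
```

A small sigma only pays off when the mean is accurate: an overconfident wrong prediction is penalised through the `z2` term much more than the `log(sigma)` term rewards it.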

**Data format**: Nested directories with parquet files per planet, plus CSV/parquet metadata.  
Detector geometry: AIRS = 32 spatial rows x 356 spectral channels, FGS1 = 32x32, FGS1 runs at 12x AIRS cadence.
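
Each parquet row stores one detector frame flattened into a single vector, so the geometry above implies the reshape below. The sketch uses small synthetic arrays, not real data, and assumes frames are stored row-major (one spatial row after another), which matches the `reshape(n_time, 32, 356)` used later in this notebook.

```python
import numpy as np

N_ROWS, N_AIRS_COLS, N_FGS_COLS = 32, 356, 32
CADENCE_RATIO = 12  # FGS1 frames per AIRS frame

# Synthetic stand-ins for one planet: 5 AIRS frames and 60 FGS1 frames.
n_airs = 5
airs_flat = np.arange(n_airs * N_ROWS * N_AIRS_COLS, dtype=np.uint16).reshape(n_airs, -1)
fgs_flat = np.ones((n_airs * CADENCE_RATIO, N_ROWS * N_FGS_COLS), dtype=np.uint16)

airs_cube = airs_flat.reshape(n_airs, N_ROWS, N_AIRS_COLS)  # (time, spatial, spectral)
fgs_cube = fgs_flat.reshape(-1, N_ROWS, N_FGS_COLS)         # (time, y, x)

print(airs_cube.shape, fgs_cube.shape)
print(fgs_cube.shape[0] // CADENCE_RATIO == airs_cube.shape[0])  # cadences line up
print(135_000 / CADENCE_RATIO)  # full-size FGS1 length / 12 -> 11250.0 AIRS frames
```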

> **Note**: This notebook is Kaggle-ready and requires the `ariel-data-challenge-2024` dataset attached to the kernel.

## 1. Setup

In [None]:
# Install any missing packages
import subprocess, sys

def install_if_missing(package):
    try:
        __import__(package)
        print(f"{package}: already installed")
    except ImportError:
        print(f"{package}: not found — installing...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package, "-q"])
        print(f"{package}: installed successfully")

install_if_missing("pyarrow")
install_if_missing("seaborn")

In [None]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
import pyarrow.parquet as pq

# ── Data root ──────────────────────────────────────────────────────────────
DATA_ROOT = Path("/kaggle/input/ariel-data-challenge-2024")

# ── Figure export directory ────────────────────────────────────────────────
# Figures are saved here so they can be displayed in the GitHub README.
# On Kaggle, /kaggle/working/ is the persistent output directory.
FIG_DIR = Path("/kaggle/working/figures")
FIG_DIR.mkdir(parents=True, exist_ok=True)
print(f"Figures will be saved to: {FIG_DIR}")

# ── Plot style ─────────────────────────────────────────────────────────────
plt.rcParams.update({
    "figure.dpi": 120,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.dpi": 150,
    "savefig.bbox": "tight",
    "savefig.facecolor": "white",
})
sns.set_theme(style="whitegrid", palette="muted")

# ── Version report ─────────────────────────────────────────────────────────
print(f"Python      : {sys.version}")
print(f"NumPy       : {np.__version__}")
print(f"Pandas      : {pd.__version__}")
print(f"Matplotlib  : {matplotlib.__version__}")
print(f"Seaborn     : {sns.__version__}")
import pyarrow  # the top-level package exposes __version__
print(f"PyArrow     : {pyarrow.__version__}")
print(f"\nDATA_ROOT   : {DATA_ROOT}")
print(f"Exists      : {DATA_ROOT.exists()}")

## 2. File Tree

In [None]:
# Walk the data root and report file sizes
total_bytes = 0
file_records = []

for dirpath, dirnames, filenames in os.walk(DATA_ROOT):
    # Skip hidden directories
    dirnames[:] = [d for d in dirnames if not d.startswith(".")]
    for fname in sorted(filenames):
        fpath = Path(dirpath) / fname
        try:
            size_bytes = fpath.stat().st_size
        except OSError:
            size_bytes = 0
        total_bytes += size_bytes
        rel_path = fpath.relative_to(DATA_ROOT)
        file_records.append({
            "path": str(rel_path),
            "size_MB": round(size_bytes / 1_048_576, 2),
        })

df_files = pd.DataFrame(file_records).sort_values("size_MB", ascending=False).reset_index(drop=True)
print(f"Total files : {len(df_files)}")
print(f"Total size  : {total_bytes / 1_073_741_824:.2f} GB\n")
print(df_files.head(30).to_string(index=False))

print(f"\n[Summary] Found {len(df_files)} files totalling {total_bytes / 1_073_741_824:.2f} GB on disk.")

## 3. CSV / Metadata Inspection

In [None]:
# ── train_adc_info.csv ─────────────────────────────────────────────────────
# Columns: planet_id | FGS1_adc_offset | FGS1_adc_gain | AIRS-CH0_adc_offset | AIRS-CH0_adc_gain | star
adc_path = DATA_ROOT / "train_adc_info.csv"
df_adc = pd.read_csv(adc_path)

print("=" * 60)
print("train_adc_info.csv")
print("=" * 60)
print(f"Shape: {df_adc.shape}  ({df_adc.shape[0]} planets, {df_adc.shape[1]} columns)")
print(f"Columns: {df_adc.columns.tolist()}")
print("\n--- Head ---")
display(df_adc.head())

print("\n--- Data types ---")
print(df_adc.dtypes.to_string())

print("\n--- Missing value counts ---")
missing = df_adc.isnull().sum()
print(missing[missing > 0].to_string() if missing.any() else "No missing values detected.")

print("\n--- Descriptive statistics ---")
display(df_adc.describe())

print(f"\n[Summary] train_adc_info.csv has {df_adc.shape[0]} planets and {df_adc.shape[1]} columns "
      f"(planet_id, four ADC gain/offset columns, and star) with {df_adc.isnull().sum().sum()} total missing values.")

In [None]:
# ── train_labels.csv ───────────────────────────────────────────────────────
# Columns: planet_id | wl_1 | wl_2 | ... | wl_283 (means only, no quartile/sigma columns)
labels_path = DATA_ROOT / "train_labels.csv"
df_labels = pd.read_csv(labels_path)

print("=" * 60)
print("train_labels.csv")
print("=" * 60)
print(f"Shape: {df_labels.shape}")
print(f"\n--- First 6 columns ---")
display(df_labels.iloc[:, :6].head(10))
print(f"\n--- Column name pattern ---")
print(df_labels.columns[:10].tolist(), "...")

# Labelled vs unlabelled planets (set operations handle IDs present in only one file)
labelled_ids = set(df_labels.iloc[:, 0].unique())
all_ids = set(df_adc.iloc[:, 0].unique())
n_total = len(all_ids)
n_labelled = len(all_ids & labelled_ids)
n_unlabelled = len(all_ids - labelled_ids)
n_orphan_labels = len(labelled_ids - all_ids)  # label rows with no ADC entry

print("\n--- Labelled vs Unlabelled ---")
print(f"Total planets (train_adc_info)  : {n_total:>6}")
print(f"Labelled planets (train_labels) : {n_labelled:>6}  ({100 * n_labelled / n_total:.1f}%)")
print(f"Unlabelled planets              : {n_unlabelled:>6}  ({100 * n_unlabelled / n_total:.1f}%)")
print(f"Label rows without ADC entry    : {n_orphan_labels:>6}")

print(f"\n[Summary] {n_labelled} of {n_total} planets in train_adc_info.csv ({100 * n_labelled / n_total:.1f}%) "
      f"have labels (wl_1...wl_283 means); {n_unlabelled} have no row in train_labels.csv.")

In [None]:
# ── wavelengths.csv ────────────────────────────────────────────────────────
# (1, 283) row: wl_1=0.705 um (FGS1), wl_2-wl_283 = 282 AIRS spectral bins
wl_path = DATA_ROOT / "wavelengths.csv"
df_wl = pd.read_csv(wl_path)

print("=" * 60)
print("wavelengths.csv")
print("=" * 60)
print(f"Shape: {df_wl.shape}")
display(df_wl)

wl_values = df_wl.values.flatten()
print(f"\nWavelength range: {wl_values.min():.4f} um to {wl_values.max():.4f} um")
print(f"FGS1 channel (wl_1): {wl_values[0]:.4f} um")
print(f"AIRS channels (wl_2 - wl_283): {wl_values[1]:.4f} um to {wl_values[-1]:.4f} um")

print(f"\n[Summary] 283 wavelength channels: 1 FGS1 ({wl_values[0]:.3f} um) + "
      f"282 AIRS ({wl_values[1]:.3f} - {wl_values[-1]:.3f} um).")

In [None]:
# ── axis_info.parquet ──────────────────────────────────────────────────────
# (135000, 4) time/wavelength metadata
axis_path = DATA_ROOT / "axis_info.parquet"
if axis_path.exists():
    df_axis = pd.read_parquet(axis_path)
    print("=" * 60)
    print("axis_info.parquet")
    print("=" * 60)
    print(f"Shape: {df_axis.shape}")
    print(f"Columns: {df_axis.columns.tolist()}")
    print(f"Dtypes:\n{df_axis.dtypes.to_string()}")
    print("\n--- Head ---")
    display(df_axis.head(10))
    print("\n--- Descriptive statistics ---")
    display(df_axis.describe())
    print(f"\n[Summary] axis_info.parquet has {df_axis.shape[0]} rows and {df_axis.shape[1]} columns.")
else:
    print("axis_info.parquet not found at expected location.")
    df_axis = None

## 4. Data Directory Structure

Each planet has its own directory containing parquet files for signal data and calibration frames:
```
{data_root}/{split}/{planet_id}/
    AIRS-CH0_signal.parquet      (11250, 32*356) uint16
    FGS1_signal.parquet          (135000, 32*32) uint16
    AIRS-CH0_calibration/
        dark.parquet  flat.parquet  dead.parquet  read.parquet  linear_corr.parquet
    FGS1_calibration/
        dark.parquet  flat.parquet  dead.parquet  read.parquet  linear_corr.parquet
```
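
Because later cells assume this layout, a small check that reports missing files per planet can save a failed run. The helper name `missing_files` and the file list are assumptions copied from the tree above; adjust them if a split differs.

```python
from pathlib import Path

# Expected per-planet layout, taken from the tree above (assumed identical for train and test).
CAL_FILES = ["dark.parquet", "flat.parquet", "dead.parquet", "read.parquet", "linear_corr.parquet"]
EXPECTED = (
    ["AIRS-CH0_signal.parquet", "FGS1_signal.parquet"]
    + [f"AIRS-CH0_calibration/{f}" for f in CAL_FILES]
    + [f"FGS1_calibration/{f}" for f in CAL_FILES]
)

def missing_files(planet_dir):
    """Return the expected relative paths that are not regular files under planet_dir."""
    planet_dir = Path(planet_dir)
    return [rel for rel in EXPECTED if not (planet_dir / rel).is_file()]
```

For example, once the next cell has run, `missing_files(example_planet_dir)` should return an empty list for a complete planet directory.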

In [None]:
# ── List planet directories ────────────────────────────────────────────────
# Planets live under {DATA_ROOT}/train/ and {DATA_ROOT}/test/

for split in ["train", "test"]:
    split_dir = DATA_ROOT / split
    if not split_dir.exists():
        print(f"{split}/ directory not found.")
        continue
    planet_dirs = sorted([d.name for d in split_dir.iterdir() if d.is_dir()])
    print(f"{split}/ : {len(planet_dirs)} planet directories")
    print(f"  First 5: {planet_dirs[:5]}")
    print(f"  Last  5: {planet_dirs[-5:]}")
    print()

# ── Show contents of one example planet directory ──────────────────────────
train_dir = DATA_ROOT / "train"
if train_dir.exists():
    planet_dirs_all = sorted([d for d in train_dir.iterdir() if d.is_dir()])
    example_planet_dir = planet_dirs_all[0]
    example_planet_id = example_planet_dir.name

    print(f"Contents of example planet directory: {example_planet_id}/")
    print("-" * 60)
    for item in sorted(example_planet_dir.rglob("*")):
        rel = item.relative_to(example_planet_dir)
        if item.is_file():
            size_kb = item.stat().st_size / 1024
            print(f"  {rel}  ({size_kb:.1f} KB)")
        elif item.is_dir():
            print(f"  {rel}/")

    print(f"\n[Summary] Example planet '{example_planet_id}' directory listed above. "
          f"Total planets in train/: {len(planet_dirs_all)}.")
else:
    example_planet_id = None
    example_planet_dir = None
    print("WARNING: train/ directory not found. Adjust DATA_ROOT.")

## 5. Single-Planet Light-Curve Plot

Load one planet's signal and calibration parquet files, apply calibration (dark subtract, flat field, dead pixel mask), sum over spatial rows, and plot the resulting light curves.
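
As a sanity check on these steps, the sketch below builds a synthetic cube with a known flux, a dark level, a flat field, and one dead pixel, then applies the same dark-subtract, flat-divide, dead-zero sequence directly on the 3-D `(time, spatial, spectral)` array. Broadcasting makes a flatten/reshape round trip unnecessary. The forward model `raw = true * flat + dark` is an assumption chosen for the test, and like this notebook the sketch omits the `read.parquet` and `linear_corr.parquet` corrections.

```python
import numpy as np

rng = np.random.default_rng(0)
n_time, n_rows, n_cols = 4, 3, 5

true_flux = rng.uniform(100.0, 200.0, size=(n_time, n_rows, n_cols))
dark = np.full((n_rows, n_cols), 10.0)                # additive offset per pixel
flat = rng.uniform(0.8, 1.2, size=(n_rows, n_cols))   # multiplicative pixel response
dead = np.zeros((n_rows, n_cols), dtype=bool)
dead[1, 2] = True                                     # one dead pixel

# Assumed forward model: raw = true * flat + dark
raw = true_flux * flat[None] + dark[None]

# Calibration: subtract dark, divide by flat (guarding zeros), zero dead pixels
cal = (raw - dark[None]) / np.where(flat == 0, 1.0, flat)[None]
cal[:, dead] = 0.0

good = ~dead
print(np.allclose(cal[:, good], true_flux[:, good]))  # True
```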

In [None]:
# ── Calibration helper ─────────────────────────────────────────────────────

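def calibrate(raw, dark, flat, dead):
    """Dark-subtract, flat-field, and zero dead pixels.

    raw: float array of shape (n_time, *frame_shape).
    dark, flat, dead: per-pixel arrays of shape frame_shape (dead is cast to bool).
    Zero entries in flat are replaced by 1 to avoid division by zero.
    """
    cal = raw - dark[None]                    # new float array; raw is left untouched
    flat_safe = np.where(flat == 0, 1.0, flat)
    cal /= flat_safe[None]
    cal[:, dead.astype(bool)] = 0.0           # dead pixels contribute nothing to spatial sums
    return cal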

print("calibrate() helper defined.")

In [None]:
# ── Load and calibrate one planet ──────────────────────────────────────────
# AIRS-CH0: raw shape (n_time, 32*356) uint16 → reshape to (n_time, 32, 356)
#   calibration frames: dark (32, 356), flat (32, 356), dead (32, 356)
#   after calibration, sum spatial rows → (n_time, 356)
# FGS1: raw shape (n_time_fgs, 32*32) uint16 → reshape to (n_time_fgs, 32, 32)
#   same calibration, sum spatial → (n_time_fgs,), then downsample 12:1

airs_data_calibrated = None  # will be (n_time, 356) after calibration
fgs1_data_calibrated = None  # will be (n_time_fgs,) after calibration
fgs1_downsampled = None      # will be (n_time,) after 12:1 downsample

if example_planet_dir is not None:
    # ── AIRS-CH0 ───────────────────────────────────────────────────────────
    airs_signal_path = example_planet_dir / "AIRS-CH0_signal.parquet"
    airs_cal_dir = example_planet_dir / "AIRS-CH0_calibration"

    airs_raw = pd.read_parquet(airs_signal_path).values.astype(np.float64)
    n_time_airs = airs_raw.shape[0]
    airs_raw = airs_raw.reshape(n_time_airs, 32, 356)  # (n_time, 32_spatial, 356_spectral)

    airs_dark = pd.read_parquet(airs_cal_dir / "dark.parquet").values.astype(np.float64).reshape(32, 356)
    airs_flat = pd.read_parquet(airs_cal_dir / "flat.parquet").values.astype(np.float64).reshape(32, 356)
    airs_dead = pd.read_parquet(airs_cal_dir / "dead.parquet").values.astype(np.float64).reshape(32, 356)

    # Reshape raw for calibration: (n_time, 32*356) to pass to calibrate
    airs_raw_flat = airs_raw.reshape(n_time_airs, 32 * 356)
    airs_cal_flat = calibrate(
        airs_raw_flat,
        airs_dark.ravel(),
        airs_flat.ravel(),
        airs_dead.ravel()
    )
    # Reshape back and sum over spatial rows
    airs_cal_3d = airs_cal_flat.reshape(n_time_airs, 32, 356)
    airs_data_calibrated = airs_cal_3d.sum(axis=1)  # (n_time, 356)

    print(f"AIRS-CH0 signal raw shape   : ({n_time_airs}, 32, 356)")
    print(f"AIRS-CH0 calibrated + summed: {airs_data_calibrated.shape}")
    print(f"  value range: [{airs_data_calibrated.min():.2f}, {airs_data_calibrated.max():.2f}]")

    # ── FGS1 ───────────────────────────────────────────────────────────────
    fgs1_signal_path = example_planet_dir / "FGS1_signal.parquet"
    fgs1_cal_dir = example_planet_dir / "FGS1_calibration"

    fgs1_raw = pd.read_parquet(fgs1_signal_path).values.astype(np.float64)
    n_time_fgs = fgs1_raw.shape[0]
    fgs1_raw = fgs1_raw.reshape(n_time_fgs, 32, 32)  # (n_time_fgs, 32, 32)

    fgs1_dark = pd.read_parquet(fgs1_cal_dir / "dark.parquet").values.astype(np.float64).reshape(32, 32)
    fgs1_flat = pd.read_parquet(fgs1_cal_dir / "flat.parquet").values.astype(np.float64).reshape(32, 32)
    fgs1_dead = pd.read_parquet(fgs1_cal_dir / "dead.parquet").values.astype(np.float64).reshape(32, 32)

    fgs1_raw_flat = fgs1_raw.reshape(n_time_fgs, 32 * 32)
    fgs1_cal_flat = calibrate(
        fgs1_raw_flat,
        fgs1_dark.ravel(),
        fgs1_flat.ravel(),
        fgs1_dead.ravel()
    )
    # Sum over all spatial pixels → (n_time_fgs,)
    fgs1_data_calibrated = fgs1_cal_flat.reshape(n_time_fgs, 32, 32).sum(axis=(1, 2))

    # Downsample FGS1 by factor of 12 to match AIRS cadence
    n_downsample = n_time_fgs // 12
    fgs1_downsampled = fgs1_data_calibrated[:n_downsample * 12].reshape(n_downsample, 12).mean(axis=1)

    print(f"\nFGS1 signal raw shape       : ({n_time_fgs}, 32, 32)")
    print(f"FGS1 calibrated + summed    : {fgs1_data_calibrated.shape}")
    print(f"FGS1 downsampled (12:1)     : {fgs1_downsampled.shape}")
    print(f"  value range: [{fgs1_downsampled.min():.2f}, {fgs1_downsampled.max():.2f}]")

    print(f"\n[Summary] Loaded and calibrated planet '{example_planet_id}': "
          f"AIRS ({n_time_airs}, 356), FGS1 ({n_time_fgs},) → downsampled to ({n_downsample},).")
else:
    print("Skipping planet loading — no example planet directory found.")
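
The 12:1 downsample in this cell is a block mean: it drops trailing frames that do not fill a complete block, reshapes to `(n_blocks, 12)`, and averages each row. A standalone version is sketched below; the helper name `block_mean` is chosen here for illustration.

```python
import numpy as np

def block_mean(x, factor):
    """Average consecutive, non-overlapping blocks of `factor` samples along axis 0.

    Trailing samples that do not fill a complete block are discarded, matching
    the `[:n_downsample * 12]` slice used above. Extra axes (e.g. pixels) are kept.
    """
    x = np.asarray(x, dtype=np.float64)
    n_blocks = x.shape[0] // factor
    trimmed = x[: n_blocks * factor]
    return trimmed.reshape(n_blocks, factor, *x.shape[1:]).mean(axis=1)

# 25 samples with factor 12 -> 2 blocks; the 25th sample is dropped.
print(block_mean(np.arange(25), 12))  # [ 5.5 17.5]
```

Averaging instead of summing keeps FGS1 in per-frame flux units; a sum would scale the curve by 12 but leave its shape unchanged.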

In [None]:
# ── Plot light curves ──────────────────────────────────────────────────────
# AIRS-CH0: white-light curve = mean across 356 spectral channels vs time
# FGS1: downsampled broadband light curve

planet_label = example_planet_id if example_planet_id is not None else "Unknown"

fig, axes = plt.subplots(2, 1, figsize=(12, 7), sharex=False)
fig.suptitle(f"Planet {planet_label} — calibrated light curves", fontsize=14, fontweight="bold")

# ── Top subplot: AIRS white-light curve ────────────────────────────────────
ax0 = axes[0]
if airs_data_calibrated is not None:
    white_light = airs_data_calibrated.mean(axis=1)  # mean over 356 spectral channels → (n_time,)
    t_airs = np.arange(len(white_light))
    ax0.plot(t_airs, white_light, lw=0.8, color="steelblue", label="White-light flux")
    ax0.set_xlabel("Time index")
    ax0.set_ylabel("Flux (calibrated)")
    ax0.set_title(f"AIRS-CH0 — white-light curve ({len(white_light)} time steps)")
    ax0.legend()

    # Annotate the peak-to-peak flux range (an upper bound on the transit depth: it also includes noise and systematics)
    depth = white_light.max() - white_light.min()
    ax0.annotate(
        f"delta flux = {depth:.2f}",
        xy=(t_airs[len(t_airs) // 2], white_light.min()),
        xytext=(t_airs[len(t_airs) // 2], white_light.min() + depth * 0.3),
        arrowprops=dict(arrowstyle="->", color="gray"),
        fontsize=9, color="gray"
    )
else:
    ax0.text(0.5, 0.5, "AIRS-CH0 data not loaded", ha="center", va="center",
             transform=ax0.transAxes, fontsize=12, color="red")
    ax0.set_title("AIRS-CH0 — white-light curve (data unavailable)")

# ── Bottom subplot: FGS1 light curve ──────────────────────────────────────
ax1 = axes[1]
if fgs1_downsampled is not None:
    t_fgs = np.arange(len(fgs1_downsampled))
    ax1.plot(t_fgs, fgs1_downsampled, lw=0.8, color="darkorange", label="FGS1 flux (12:1 downsampled)")
    ax1.set_xlabel("Time index (AIRS cadence)")
    ax1.set_ylabel("Flux (calibrated)")
    ax1.set_title(f"FGS1 — broadband light curve ({len(fgs1_downsampled)} time steps after 12:1 downsample)")
    ax1.legend()
else:
    ax1.text(0.5, 0.5, "FGS1 data not loaded", ha="center", va="center",
             transform=ax1.transAxes, fontsize=12, color="red")
    ax1.set_title("FGS1 — broadband light curve (data unavailable)")

plt.tight_layout()
fig.savefig(FIG_DIR / "light_curves.png")
print(f"Saved: {FIG_DIR / 'light_curves.png'}")
plt.show()

print(f"[Summary] Light curves plotted for planet '{planet_label}'. "
      f"AIRS white-light: {len(white_light) if airs_data_calibrated is not None else 'N/A'} time steps; "
      f"FGS1 downsampled: {len(fgs1_downsampled) if fgs1_downsampled is not None else 'N/A'} time steps.")
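
The `delta flux` annotation is a peak-to-peak range, so noise and systematics inflate it. A more conventional estimate compares the mean in-transit flux with the out-of-transit baseline. The sketch below assumes the in-transit window is already known (hard-coded here on synthetic data); on real light curves that window would have to be located first. `relative_transit_depth` is a hypothetical helper.

```python
import numpy as np

def relative_transit_depth(flux, in_transit):
    """Depth = 1 - mean(in-transit flux) / median(out-of-transit flux).

    `in_transit` is a boolean mask over time, assumed known here.
    """
    flux = np.asarray(flux, dtype=np.float64)
    in_transit = np.asarray(in_transit, dtype=bool)
    baseline = np.median(flux[~in_transit])
    return 1.0 - flux[in_transit].mean() / baseline

# Synthetic box transit: 1% depth between samples 300 and 700, plus white noise.
rng = np.random.default_rng(42)
t = np.arange(1000)
mask = (t >= 300) & (t < 700)
flux = 1000.0 * (1.0 - 0.01 * mask) + rng.normal(0.0, 2.0, size=t.size)
print(f"estimated depth = {relative_transit_depth(flux, mask):.4f}")
```

Unlike `delta flux`, this ratio is unit-free, so it can be compared across planets and channels.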

## 6. AIRS 2-D Heatmap (Time x Wavelength)

In [None]:
# Plot the calibrated 2-D flux matrix (time x 356 spectral channels) for the example planet

fig, ax = plt.subplots(figsize=(14, 5))

if airs_data_calibrated is not None:
    flux_2d = airs_data_calibrated  # (n_time, 356)

    # Clip to [1st, 99th] percentile for better contrast
    vmin = np.percentile(flux_2d, 1)
    vmax = np.percentile(flux_2d, 99)

    im = ax.imshow(
        flux_2d.T,  # plot as (wavelength, time) so wavelength is on y-axis
        aspect="auto",
        origin="lower",
        vmin=vmin,
        vmax=vmax,
        cmap="RdYlBu_r",
        interpolation="nearest",
    )
    cbar = plt.colorbar(im, ax=ax, pad=0.02)
    cbar.set_label("Flux (calibrated)", fontsize=10)

    ax.set_xlabel("Time index", fontsize=11)
    ax.set_ylabel("Spectral channel (0-355)", fontsize=11)
    ax.set_title(
        f"Planet {planet_label} — AIRS-CH0 calibrated flux "
        f"(shape: {flux_2d.shape[0]} time x {flux_2d.shape[1]} channels)  "
        f"[clipped to p1={vmin:.1f}, p99={vmax:.1f}]",
        fontsize=11
    )

    print(f"[Summary] AIRS calibrated flux matrix shape: {flux_2d.shape} (time x spectral). "
          f"Value range after percentile clip: [{vmin:.1f}, {vmax:.1f}].")
else:
    ax.text(0.5, 0.5, "AIRS-CH0 data not loaded — cannot render heatmap",
            ha="center", va="center", transform=ax.transAxes, fontsize=12, color="red")
    ax.set_title("AIRS-CH0 flux heatmap (data unavailable)")
    print("[Summary] Skipped heatmap — AIRS data not available.")

plt.tight_layout()
fig.savefig(FIG_DIR / "airs_heatmap.png")
print(f"Saved: {FIG_DIR / 'airs_heatmap.png'}")
plt.show()
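
In calibrated units, channel-to-channel brightness differences can dominate the heatmap and hide the transit. One option, not applied above, is to divide each channel by its own out-of-transit median so every row becomes relative flux near 1.0. The sketch assumes a known out-of-transit mask; the name `normalise_per_channel` is illustrative.

```python
import numpy as np

def normalise_per_channel(flux_2d, out_of_transit):
    """Divide each spectral channel (column) by its out-of-transit median.

    flux_2d: (n_time, n_channels); out_of_transit: boolean mask over time.
    Channels whose baseline is 0 (e.g. fully dead columns) come back as NaN.
    """
    flux_2d = np.asarray(flux_2d, dtype=np.float64)
    baseline = np.median(flux_2d[out_of_transit], axis=0)   # (n_channels,)
    safe = np.where(baseline == 0, np.nan, baseline)
    return flux_2d / safe[None, :]

# Synthetic example: 3 channels with very different brightness, same 1% transit.
t = np.arange(100)
in_transit = (t >= 40) & (t < 60)
brightness = np.array([10.0, 1000.0, 0.0])
flux = brightness[None, :] * (1.0 - 0.01 * in_transit[:, None])
rel = normalise_per_channel(flux, ~in_transit)
print(np.round(rel[50, :2], 4), np.isnan(rel[:, 2]).all())
```

On real data this would be called as `normalise_per_channel(airs_data_calibrated, oot_mask)`, where `oot_mask` is a hypothetical out-of-transit mask, and the result passed to `imshow` with a colour range centred near 1.0.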

## 7. ADC Feature Distributions

In [None]:
# Plot histograms for all ADC feature columns from train_adc_info.csv
# Columns: FGS1_adc_offset, FGS1_adc_gain, AIRS-CH0_adc_offset, AIRS-CH0_adc_gain, star

id_col = df_adc.columns[0]
feature_cols = [c for c in df_adc.columns if c != id_col and pd.api.types.is_numeric_dtype(df_adc[c])]

print(f"ADC feature columns ({len(feature_cols)}): {feature_cols}")

n_features = len(feature_cols)
ncols_grid = min(n_features, 3)
nrows = (n_features + ncols_grid - 1) // ncols_grid
fig, axes = plt.subplots(nrows, ncols_grid, figsize=(5 * ncols_grid, 4 * nrows))
if n_features == 1:
    axes_flat = [axes]
else:
    axes_flat = axes.flatten() if hasattr(axes, 'flatten') else [axes]

for idx, col in enumerate(feature_cols):
    ax = axes_flat[idx]
    vals = df_adc[col].dropna()

    # Log x-axis for strictly positive columns spanning more than 2 orders of magnitude
    use_log = bool(vals.min() > 0 and vals.max() / vals.min() > 100)

    # Log-spaced bins keep bar widths uniform on a log axis
    bins = np.geomspace(vals.min(), vals.max(), 51) if use_log else 50
    ax.hist(vals, bins=bins, color="steelblue", edgecolor="white", linewidth=0.3)
    if use_log:
        ax.set_xscale("log")
    scale_note = "(log x-axis)" if use_log else ""

    ax.set_title(f"{col} {scale_note}", fontsize=10, fontweight="bold")
    ax.set_xlabel(col, fontsize=9)
    ax.set_ylabel("Count", fontsize=9)
    ax.tick_params(labelsize=8)

    # Annotate with median
    median_val = vals.median()
    ax.axvline(median_val, color="red", linestyle="--", lw=1.2, alpha=0.8)
    ax.text(0.97, 0.95, f"median={median_val:.3g}",
            transform=ax.transAxes, ha="right", va="top", fontsize=8, color="red")

# Hide unused subplots
for idx in range(n_features, len(axes_flat)):
    axes_flat[idx].set_visible(False)

fig.suptitle("ADC Feature Distributions (train_adc_info.csv)", fontsize=14, fontweight="bold", y=1.01)
plt.tight_layout()
fig.savefig(FIG_DIR / "adc_distributions.png")
print(f"Saved: {FIG_DIR / 'adc_distributions.png'}")
plt.show()

# Print summary stats
print("\nSummary statistics for ADC features:")
display(df_adc[feature_cols].describe().round(4))

print(f"\n[Summary] Plotted {len(feature_cols)} ADC features. "
      f"Log x-scale applied to columns spanning >2 orders of magnitude.")
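
The log-axis rule in this cell (positive values whose max/min ratio exceeds 100) can be factored into a small function so the same threshold is reused elsewhere. `should_use_log` and its `ratio` parameter are illustrative names, not part of any library; unlike `dropna()`, this version also drops infinities.

```python
import numpy as np

def should_use_log(values, ratio=100.0):
    """True if all finite values are positive and max/min exceeds `ratio`."""
    v = np.asarray(values, dtype=np.float64)
    v = v[np.isfinite(v)]              # ignore NaN and +/-inf
    if v.size == 0 or v.min() <= 0:
        return False
    return bool(v.max() / v.min() > ratio)

print(should_use_log([1.0, 10.0, 1000.0]))   # True  (span of 1000x)
print(should_use_log([0.5, 1.0, 20.0]))      # False (span of 40x)
print(should_use_log([-1.0, 1.0, 1e6]))      # False (non-positive values)
```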

## 8. Transit Depth Distribution (Labelled Planets)

In [None]:
# For labelled planets, compute median transit depth across 283 wavelength channels
# train_labels.csv has columns: planet_id, wl_1, wl_2, ..., wl_283 (means only, no quartiles)

print("train_labels.csv column names (first 10):")
print(df_labels.columns[:10].tolist())
print(f"\ntrain_labels.csv column names (last 5):")
print(df_labels.columns[-5:].tolist())

# Identify wavelength columns (wl_1 ... wl_283)
wl_cols = [c for c in df_labels.columns if c.startswith("wl_")]
print(f"\nFound {len(wl_cols)} wavelength columns.")

# Compute median transit depth per planet (median across 283 wavelengths)
wl_values_labels = df_labels[wl_cols].values.astype(float)
median_depth_per_planet = np.nanmedian(wl_values_labels, axis=1)

overall_median = np.nanmedian(median_depth_per_planet)
overall_std = np.nanstd(median_depth_per_planet)

fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(median_depth_per_planet, bins=60, color="steelblue", edgecolor="white",
        linewidth=0.3, alpha=0.85, label="Median transit depth")

ax.axvline(overall_median, color="red", linestyle="--", lw=2, label=f"Median = {overall_median:.4f}")
ax.axvline(overall_median - overall_std, color="orange", linestyle=":", lw=1.5,
           label=f"+-1 std = {overall_std:.4f}")
ax.axvline(overall_median + overall_std, color="orange", linestyle=":", lw=1.5)

# Shaded region
ax.axvspan(overall_median - overall_std, overall_median + overall_std,
           alpha=0.1, color="orange")

ax.set_xlabel("Median wl value (transit depth proxy) across 283 wavelengths", fontsize=11)
ax.set_ylabel("Number of planets", fontsize=11)
ax.set_title(
    f"Transit Depth Distribution — {len(median_depth_per_planet)} labelled planets\n"
    f"Annotated: median = {overall_median:.4f} +- std = {overall_std:.4f}",
    fontsize=12
)
ax.legend(fontsize=9)

plt.tight_layout()
fig.savefig(FIG_DIR / "transit_depth_distribution.png")
print(f"Saved: {FIG_DIR / 'transit_depth_distribution.png'}")
plt.show()

print(f"Transit depth statistics across {len(median_depth_per_planet)} labelled planets:")
print(f"  Median  : {overall_median:.6f}")
print(f"  Std     : {overall_std:.6f}")
print(f"  Min     : {np.nanmin(median_depth_per_planet):.6f}")
print(f"  Max     : {np.nanmax(median_depth_per_planet):.6f}")
print(f"  p5      : {np.nanpercentile(median_depth_per_planet, 5):.6f}")
print(f"  p95     : {np.nanpercentile(median_depth_per_planet, 95):.6f}")
print(f"\n[Summary] Median transit depth = {overall_median:.4f} +- {overall_std:.4f} across "
      f"{len(median_depth_per_planet)} labelled planets; range "
      f"[{np.nanmin(median_depth_per_planet):.4f}, {np.nanmax(median_depth_per_planet):.4f}].")
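
This cell selects label columns with `startswith("wl_")`, which relies on the CSV already being in wavelength order. If columns are ever reordered (for example after a merge), a plain string sort would put `wl_10` before `wl_2`. A helper that sorts by the numeric suffix avoids that; `wavelength_columns` is a hypothetical name.

```python
import re

import numpy as np
import pandas as pd

WL_PATTERN = re.compile(r"^wl_(\d+)$")

def wavelength_columns(df):
    """Return columns named wl_<int>, ordered by their integer suffix."""
    matches = []
    for c in df.columns:
        m = WL_PATTERN.match(str(c))
        if m:
            matches.append((int(m.group(1)), c))
    return [c for _, c in sorted(matches)]

# Toy labels table with deliberately scrambled column order.
toy = pd.DataFrame({"planet_id": [1, 2], "wl_10": [0.3, 0.6], "wl_2": [0.2, 0.5], "wl_1": [0.1, 0.4]})
cols = wavelength_columns(toy)
print(cols)  # ['wl_1', 'wl_2', 'wl_10']
print(np.median(toy[cols].to_numpy(), axis=1))  # per-planet median depth
```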

## 9. Label Correlation Heatmap

In [None]:
# Pearson correlation matrix of wl_* columns (283 wavelengths)
# Subsample every 10th wavelength for readability → 29 channels (wl_1, wl_11, ..., wl_281)

SUBSAMPLE_STEP = 10  # use every 10th wavelength
wl_sub_cols = wl_cols[::SUBSAMPLE_STEP]

print(f"Subsampling from {len(wl_cols)} to {len(wl_sub_cols)} wavelength channels "
      f"(every {SUBSAMPLE_STEP}th channel).")

# Compute Pearson correlation
df_wl_sub = df_labels[wl_sub_cols].copy()

# Create short labels: wavelength index
short_labels = [c.replace("wl_", "") for c in wl_sub_cols]
df_wl_sub.columns = short_labels

corr_matrix = df_wl_sub.corr(method="pearson")

fig, ax = plt.subplots(figsize=(14, 12))

sns.heatmap(
    corr_matrix,
    ax=ax,
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
    center=0,
    square=True,
    linewidths=0.3,
    cbar_kws={"shrink": 0.8, "label": "Pearson r"},
    annot=False,
    xticklabels=short_labels,
    yticklabels=short_labels,
)

ax.set_title(
    f"Label Correlation Heatmap — wl_* (mean transit depth) across subsampled wavelengths\n"
    f"Showing {len(wl_sub_cols)} of {len(wl_cols)} channels (every {SUBSAMPLE_STEP}th), "
    f"{len(df_labels)} labelled planets",
    fontsize=11
)
ax.set_xlabel("Wavelength channel index", fontsize=10)
ax.set_ylabel("Wavelength channel index", fontsize=10)
ax.tick_params(axis="x", rotation=45, labelsize=7)
ax.tick_params(axis="y", rotation=0, labelsize=7)

plt.tight_layout()
fig.savefig(FIG_DIR / "label_correlation_heatmap.png")
print(f"Saved: {FIG_DIR / 'label_correlation_heatmap.png'}")
plt.show()

# Find pairs with high correlation
corr_vals = corr_matrix.values.copy()
np.fill_diagonal(corr_vals, np.nan)
max_corr = np.nanmax(np.abs(corr_vals))
mean_corr = np.nanmean(np.abs(corr_vals))

print(f"Correlation statistics (off-diagonal):")
print(f"  Max |Pearson r| : {max_corr:.4f}")
print(f"  Mean |Pearson r|: {mean_corr:.4f}")
print(f"  Highly correlated pairs (|r| > 0.9): "
      f"{(np.abs(corr_vals) > 0.9).sum() // 2}")

print(f"\n[Summary] Label correlation heatmap: mean |r| = {mean_corr:.3f}, "
      f"max |r| = {max_corr:.3f} among {len(wl_sub_cols)} subsampled wavelength channels.")
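
Because the heatmap uses every 10th channel, it does not directly show correlation between neighbouring channels. The sketch below computes the lag-1 Pearson correlation (channel k vs k+1, across planets) on a full label matrix; applied to `df_labels[wl_cols].to_numpy()` it would return 282 values. The synthetic smooth spectra are only a stand-in to show the computation, and `adjacent_channel_corr` is a hypothetical name.

```python
import numpy as np

def adjacent_channel_corr(labels):
    """Pearson r between columns k and k+1 of a (n_planets, n_channels) array.

    Constant columns yield NaN (0/0), as with any Pearson correlation.
    """
    x = np.asarray(labels, dtype=np.float64)
    x = x - x.mean(axis=0)                       # centre each channel
    num = (x[:, :-1] * x[:, 1:]).sum(axis=0)
    den = np.sqrt((x[:, :-1] ** 2).sum(axis=0) * (x[:, 1:] ** 2).sum(axis=0))
    return num / den

# Synthetic smooth "spectra": per-planet offset + scaled slow sinusoid + small noise.
rng = np.random.default_rng(1)
n_planets, n_channels = 200, 283
grid = np.linspace(0.0, 1.0, n_channels)
spectra = (rng.normal(0.0, 1.0, (n_planets, 1))
           + 0.3 * np.sin(2 * np.pi * grid)[None, :] * rng.normal(1.0, 0.2, (n_planets, 1))
           + rng.normal(0.0, 0.05, (n_planets, n_channels)))
r = adjacent_channel_corr(spectra)
print(r.shape, round(float(r.min()), 3))
```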

## 10. Summary and Conclusions

### Key findings from this EDA:

- **Dataset scale**: The competition dataset is ~180 GB, consisting of nested directories with parquet files per planet, containing time-series photometry from two instruments (AIRS-CH0 and FGS1) plus CSV/parquet metadata tables.

- **Data format**: Each planet directory contains `AIRS-CH0_signal.parquet` (11250 time steps, 32x356 detector pixels as uint16), `FGS1_signal.parquet` (135000 time steps at 12x cadence, 32x32 pixels as uint16), and calibration subdirectories with `dark.parquet`, `flat.parquet`, `dead.parquet`, `read.parquet`, and `linear_corr.parquet`.

- **Calibration pipeline**: Raw uint16 detector frames require dark subtraction, flat fielding, and dead pixel masking before scientific use. Spatial rows are then summed to produce 1-D spectra per time step.

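- **Labels**: Only ~24% of planets are labelled with mean transit depths (wl_1...wl_283) across 283 wavelength bins in `train_labels.csv`; the remaining ~76% have no label row. Only means are provided (no quartile or sigma columns), yet the metric scores a predicted standard deviation per wavelength, so uncertainties must be learned without direct ground truth.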

- **Wavelength grid**: 283 channels total: 1 FGS1 visible channel (0.705 um) + 282 AIRS infrared channels. Grid defined in `wavelengths.csv`.

- **Typical transit depth range**: Per-planet median transit depth varies substantially from planet to planet; Section 8 reports its median, standard deviation, min/max, and 5th/95th percentiles.

- **ADC features**: `train_adc_info.csv` provides five per-planet features: FGS1_adc_offset, FGS1_adc_gain, AIRS-CH0_adc_offset, AIRS-CH0_adc_gain, and a star identifier. The four gain/offset columns describe each detector's analog-to-digital conversion parameters; `star` identifies the host star and is not a detector setting.

- **AIRS-CH0 structure**: After calibration and spatial summing, each planet's light curve is a 2-D matrix of shape (n_time x 356 spectral channels). The white-light curve shows a clear transit dip. The 2-D heatmap reveals wavelength-dependent flux variations.

- **FGS1 structure**: 32x32 visible photometer at 12x higher cadence than AIRS. After calibration, spatial summing, and 12:1 downsampling, it provides a complementary broadband transit signal.

- **Label correlation**: Adjacent wavelength channels are highly correlated (Pearson r typically > 0.9 for neighbouring channels), consistent with smooth atmospheric spectra. Distant channels decorrelate across molecular absorption band boundaries.

### Next steps:

1. Implement full calibration pipeline including `read.parquet` and `linear_corr.parquet` corrections.
2. Compute per-planet baseline features: white-light transit depth, ingress/egress timing.
3. Extract AIRS-CH0 spectral features: per-channel transit depth, SNR per channel.
4. Build a baseline model using ADC features + spectral features.
5. Explore common-mode correction using FGS1 as a systematics reference.
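
For step 2, one simple (and by no means the only) way to locate ingress and egress is to smooth the white-light curve and take the most negative and most positive points of its gradient. The sketch assumes a single, roughly box-shaped transit well inside the observation window; real light curves with limb darkening and systematics may need more care. `find_transit_edges` and the default `window` are illustrative choices.

```python
import numpy as np

def find_transit_edges(flux, window=51):
    """Return (ingress_idx, egress_idx) from the extrema of a smoothed gradient.

    The series is edge-padded before the moving average so the smoothing does
    not create artificial drops at either end. `window` must be odd.
    """
    if window % 2 == 0:
        raise ValueError("window must be odd")
    flux = np.asarray(flux, dtype=np.float64)
    half = window // 2
    padded = np.pad(flux, half, mode="edge")
    kernel = np.ones(window) / window
    smooth = np.convolve(padded, kernel, mode="valid")  # same length as flux
    grad = np.gradient(smooth)
    return int(np.argmin(grad)), int(np.argmax(grad))   # steepest drop, steepest rise

# Synthetic 1% box transit from sample 600 to 1400 with white noise.
rng = np.random.default_rng(7)
t = np.arange(2000)
flux = 1.0 - 0.01 * ((t >= 600) & (t < 1400)) + rng.normal(0.0, 0.001, t.size)
print(find_transit_edges(flux))
```

The detected edges could then define the in-transit mask needed for per-channel depth and SNR estimates (step 3).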