# NeurIPS 2024 Ariel Data Challenge — Exploratory Data Analysis

**Goal**: Understand the structure, distributions, and quality of the Ariel exoplanet atmospheric spectra dataset.  
**Dataset**: ~180 GB of simulated telescope photometry from AIRS-CH0 (IR spectrometer) and FGS1 (visible photometer).  
**Task**: Extract exoplanet atmospheric transmission spectra (283 wavelength channels) from transit light curves.  
**Scoring**: Gaussian Log-Likelihood over predicted mean and std per wavelength.
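
For orientation, here is a minimal sketch of a per-wavelength Gaussian log-likelihood. It assumes independent Gaussian errors in each channel. The official competition metric may add further normalisation (for example, relative to reference predictions), so treat this as illustrative only. `gaussian_log_likelihood` is a hypothetical helper name.

```python
import numpy as np


def gaussian_log_likelihood(y_true, mu, sigma):
    """Mean Gaussian log-likelihood over planets x wavelength channels.

    Assumes independent Gaussian errors per channel. This is an illustrative
    sketch, not the exact competition metric.
    """
    y_true = np.asarray(y_true, dtype=float)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    if np.any(sigma <= 0):
        raise ValueError("all predicted std values must be > 0")
    z = (y_true - mu) / sigma
    # log N(y | mu, sigma^2) = -0.5 * [log(2*pi) + 2*log(sigma) + z^2]
    log_pdf = -0.5 * (np.log(2.0 * np.pi) + 2.0 * np.log(sigma) + z ** 2)
    return float(log_pdf.mean())


# Usage: 2 planets x 283 channels, perfect means with a small predicted std
rng = np.random.default_rng(0)
y = rng.normal(0.01, 0.001, size=(2, 283))
print(gaussian_log_likelihood(y, mu=y, sigma=np.full_like(y, 1e-4)))
```

Because the predicted std appears both in the `log(sigma)` term and in `z`, the score rewards small std only when the mean is accurate. An overconfident std on a wrong mean is penalised quadratically.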

> **Note**: This notebook is Kaggle-ready and requires the `ariel-data-challenge-2024` dataset attached to the kernel.

## 1. Setup

In [None]:
# Install any missing packages
import subprocess, sys

def install_if_missing(package):
    try:
        __import__(package)
        print(f"{package}: already installed")
    except ImportError:
        print(f"{package}: not found — installing...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package, "-q"])
        print(f"{package}: installed successfully")

install_if_missing("h5py")
install_if_missing("seaborn")

In [None]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import display  # built into notebooks; imported explicitly for clarity
import seaborn as sns
import h5py

# ── Data root ──────────────────────────────────────────────────────────────
DATA_ROOT = Path("/kaggle/input/ariel-data-challenge-2024")

# ── Plot style ─────────────────────────────────────────────────────────────
plt.rcParams.update({
    "figure.dpi": 120,
    "axes.spines.top": False,
    "axes.spines.right": False,
})
sns.set_theme(style="whitegrid", palette="muted")

# ── Version report ─────────────────────────────────────────────────────────
print(f"Python      : {sys.version}")
print(f"NumPy       : {np.__version__}")
print(f"Pandas      : {pd.__version__}")
print(f"Matplotlib  : {matplotlib.__version__}")
print(f"Seaborn     : {sns.__version__}")
print(f"h5py        : {h5py.__version__}")
print(f"\nDATA_ROOT   : {DATA_ROOT}")
print(f"Exists      : {DATA_ROOT.exists()}")

## 2. File Tree

In [None]:
# Walk the data root and report file sizes
total_bytes = 0
file_records = []

for dirpath, dirnames, filenames in os.walk(DATA_ROOT):
    # Skip hidden directories
    dirnames[:] = [d for d in dirnames if not d.startswith(".")]
    for fname in sorted(filenames):
        fpath = Path(dirpath) / fname
        try:
            size_bytes = fpath.stat().st_size
        except OSError:
            size_bytes = 0
        total_bytes += size_bytes
        rel_path = fpath.relative_to(DATA_ROOT)
        file_records.append({
            "path": str(rel_path),
            "size_MB": round(size_bytes / 1_048_576, 2),
        })

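df_files = (
    pd.DataFrame(file_records, columns=["path", "size_MB"])  # explicit columns: safe if no files are found
    .sort_values("size_MB", ascending=False)
    .reset_index(drop=True)
)
print(f"Total files : {len(df_files)}")
print(f"Total size  : {total_bytes / 1_073_741_824:.2f} GB\n")
# Per-planet data can mean thousands of files, so show only the largest
print(df_files.head(30).to_string(index=False))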

print(f"\n[Summary] Found {len(df_files)} files totalling {total_bytes / 1_073_741_824:.2f} GB on disk.")
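
With thousands of per-planet files, a per-extension roll-up is often easier to read than the full listing. This is a minimal sketch that assumes the `df_files` layout built above (columns `path` and `size_MB`). `summarise_by_extension` is a hypothetical helper.

```python
from pathlib import PurePosixPath

import pandas as pd


def summarise_by_extension(df_files: pd.DataFrame) -> pd.DataFrame:
    """Count files and total size (MB) per file extension."""
    ext = df_files["path"].map(lambda p: PurePosixPath(p).suffix.lower() or "<none>")
    return (
        df_files.assign(ext=ext)
        .groupby("ext")["size_MB"]
        .agg(n_files="count", total_MB="sum")
        .sort_values("total_MB", ascending=False)
    )


# Usage with a toy table; on Kaggle pass df_files instead
demo_files = pd.DataFrame({"path": ["a.csv", "b/c.h5", "b/d.h5"], "size_MB": [1.0, 10.0, 5.0]})
ext_summary = summarise_by_extension(demo_files)
print(ext_summary)
```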

## 3. CSV Inspection

In [None]:
# ── AuxillaryTable.csv ─────────────────────────────────────────────────────
aux_path = DATA_ROOT / "AuxillaryTable.csv"
df_aux = pd.read_csv(aux_path)

print("=" * 60)
print("AuxillaryTable.csv")
print("=" * 60)
print(f"Shape: {df_aux.shape}  ({df_aux.shape[0]} planets, {df_aux.shape[1]} columns)")
print("\n--- Head ---")
display(df_aux.head())

print("\n--- Data types ---")
print(df_aux.dtypes.to_string())

print("\n--- Missing value counts ---")
missing = df_aux.isnull().sum()
print(missing[missing > 0].to_string() if missing.any() else "No missing values detected.")

print("\n--- Descriptive statistics ---")
display(df_aux.describe())

print(f"\n[Summary] AuxillaryTable has {df_aux.shape[0]} planets and {df_aux.shape[1]} columns "
      f"with {df_aux.isnull().sum().sum()} total missing values.")

In [None]:
# ── QuartilesTable.csv ─────────────────────────────────────────────────────
q_path = DATA_ROOT / "QuartilesTable.csv"
df_q = pd.read_csv(q_path)

print("=" * 60)
print("QuartilesTable.csv")
print("=" * 60)
print(f"Shape: {df_q.shape}")
print("\n--- First 5 columns ---")
display(df_q.iloc[:, :5].head(10))
print("\n--- Column name pattern ---")
print(df_q.columns[:10].tolist(), "...")

# Labelled vs unlabelled planets
# Assumptions: QuartilesTable contains only labelled (training) planets, and the
# first column of each table is the planet ID.
labelled_ids = set(df_q.iloc[:, 0].unique()) if df_q.shape[0] > 0 else set()
all_ids = set(df_aux.iloc[:, 0].unique())
n_total = len(all_ids)
n_labelled = len(labelled_ids & all_ids)  # count only IDs present in both tables
n_unlabelled = n_total - n_labelled
orphan_ids = labelled_ids - all_ids
if orphan_ids:
    print(f"WARNING: {len(orphan_ids)} QuartilesTable IDs are not in AuxillaryTable.")

print("\n--- Labelled vs Unlabelled ---")
print(f"Total planets (AuxillaryTable)   : {n_total:>6}")
print(f"Labelled planets (QuartilesTable): {n_labelled:>6}  ({100 * n_labelled / n_total:.1f}%)")
print(f"Unlabelled planets               : {n_unlabelled:>6}  ({100 * n_unlabelled / n_total:.1f}%)")

print(f"\n[Summary] {n_labelled} of {n_total} planets ({100 * n_labelled / n_total:.1f}%) have "
      f"quartile labels; {n_unlabelled} are unlabelled (test set).")
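
A set difference hides *which* planets fall in each group. A merge with `indicator=True` tags every planet and also catches labelled IDs that are missing from the auxiliary table. This is a minimal sketch: `split_labelled` is a hypothetical helper, and it assumes both tables share an ID column of the same name and dtype. The notebook only knows the ID is the first column, so a rename may be needed first.

```python
import pandas as pd


def split_labelled(df_aux, df_q, id_col):
    """Tag each planet in df_aux as labelled (ID present in df_q) or not.

    Raises if df_q holds IDs that are absent from df_aux, since those would
    make `n_total - n_labelled` undercount the unlabelled planets.
    """
    ids_aux = df_aux[[id_col]].drop_duplicates()
    ids_q = df_q[[id_col]].drop_duplicates()
    merged = ids_aux.merge(ids_q, on=id_col, how="outer", indicator=True)
    orphans = merged.loc[merged["_merge"] == "right_only", id_col]
    if len(orphans) > 0:
        raise ValueError(f"{len(orphans)} labelled IDs are missing from the auxiliary table")
    merged["labelled"] = merged["_merge"] == "both"
    return merged.drop(columns="_merge")


# Usage with toy tables; "planet_ID" is a hypothetical column name
aux = pd.DataFrame({"planet_ID": [1, 2, 3, 4]})
q = pd.DataFrame({"planet_ID": [2, 4]})
status = split_labelled(aux, q, "planet_ID")
print(status["labelled"].value_counts())
```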

## 4. HDF5 Structure

In [None]:
# TODO: verify key names after running explore_data.py on Kaggle
# Expected HDF5 top-level keys: "AIRS-CH0" (IR spectrometer) and "FGS1" (visible photometer)

# Locate the HDF5 file(s)
hdf5_files = list(DATA_ROOT.rglob("*.h5")) + list(DATA_ROOT.rglob("*.hdf5")) + list(DATA_ROOT.rglob("*.hdf"))
print(f"HDF5 files found: {len(hdf5_files)}")
for f in hdf5_files[:10]:
    print(f"  {f.relative_to(DATA_ROOT)}  ({f.stat().st_size / 1_073_741_824:.2f} GB)")

# Use the first HDF5 file for structure inspection
hdf5_path = hdf5_files[0] if hdf5_files else None
if hdf5_path is None:
    print("WARNING: No HDF5 file found. Adjust the glob pattern to match the actual filename.")

In [None]:
# ── HDF5 top-level structure ────────────────────────────────────────────────
if hdf5_path is not None:
    with h5py.File(hdf5_path, "r") as f:
        print(f"File: {hdf5_path.name}")
        print(f"Top-level keys: {list(f.keys())}")

        # TODO: verify key names after running explore_data.py on Kaggle
        AIRS_KEY = "AIRS-CH0"
        FGS1_KEY = "FGS1"

        try:
            airs_group = f[AIRS_KEY]
            print(f"\n'{AIRS_KEY}' group found.")
            planet_ids_airs = list(airs_group.keys())
            print(f"  Number of planets: {len(planet_ids_airs)}")
            print(f"  First 5 planet IDs: {planet_ids_airs[:5]}")
        except KeyError as e:
            print(f"KeyError accessing '{AIRS_KEY}': {e}. Check actual key names with list(f.keys()).")
            planet_ids_airs = []

        try:
            fgs1_group = f[FGS1_KEY]
            print(f"\n'{FGS1_KEY}' group found.")
            planet_ids_fgs1 = list(fgs1_group.keys())
            print(f"  Number of planets: {len(planet_ids_fgs1)}")
        except KeyError as e:
            print(f"KeyError accessing '{FGS1_KEY}': {e}. Check actual key names with list(f.keys()).")
            planet_ids_fgs1 = []

        print(f"\n[Summary] HDF5 contains {len(planet_ids_airs)} planets under '{AIRS_KEY}' "
              f"and {len(planet_ids_fgs1)} planets under '{FGS1_KEY}'.")
else:
    print("Skipping HDF5 structure inspection — no file found.")
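
Printing only the top-level keys can miss nested groups. `h5py` can walk the whole hierarchy with `Group.visititems`, reading metadata only, which makes it a quick way to settle the key-name `TODO`s. This is a minimal sketch: `describe_hdf5` and the demo file's group/dataset names are hypothetical.

```python
import os
import tempfile

import h5py
import numpy as np


def describe_hdf5(path, max_items=50):
    """List (name, kind, shape, dtype) for objects in an HDF5 file.

    Only metadata is read; dataset contents are never loaded. Traversal
    stops after `max_items` objects so huge files stay quick to inspect.
    """
    rows = []

    def visitor(name, obj):
        if isinstance(obj, h5py.Dataset):
            rows.append((name, "dataset", obj.shape, str(obj.dtype)))
        else:
            rows.append((name, "group", None, None))
        # visititems stops as soon as the callback returns a non-None value
        return True if len(rows) >= max_items else None

    with h5py.File(path, "r") as f:
        f.visititems(visitor)
    return rows


# Usage on a tiny throwaway file; on Kaggle call describe_hdf5(hdf5_path)
demo_path = os.path.join(tempfile.mkdtemp(), "demo.h5")
with h5py.File(demo_path, "w") as f:
    f.create_dataset("AIRS-CH0/planet_0", data=np.zeros((10, 4), dtype="float32"))
rows = describe_hdf5(demo_path)
for row in rows:
    print(row)
```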

In [None]:
# ── Example planet: dataset shapes and dtypes ───────────────────────────────
# TODO: verify key names after running explore_data.py on Kaggle
AIRS_KEY = "AIRS-CH0"
FGS1_KEY = "FGS1"

example_planet_id = None
airs_data_example = None
fgs1_data_example = None

if hdf5_path is not None:
    with h5py.File(hdf5_path, "r") as f:
        try:
            planet_ids = list(f[AIRS_KEY].keys())
            example_planet_id = planet_ids[0]
            print(f"Example planet ID: {example_planet_id}")

            # AIRS-CH0 data
            airs_ds = f[AIRS_KEY][example_planet_id]
            print(f"\n{AIRS_KEY}/{example_planet_id}:")
            if isinstance(airs_ds, h5py.Dataset):
                airs_data_example = airs_ds[...]  # read from disk once
                print(f"  shape : {airs_ds.shape}")
                print(f"  dtype : {airs_ds.dtype}")
                print(f"  min   : {np.nanmin(airs_data_example):.6f}")
                print(f"  max   : {np.nanmax(airs_data_example):.6f}")
            else:
                # It may be a sub-group: list members without assuming they are datasets
                print(f"  sub-keys: {list(airs_ds.keys())}")
                for sk, ds in airs_ds.items():
                    if isinstance(ds, h5py.Dataset):
                        print(f"    {sk}: shape={ds.shape}, dtype={ds.dtype}")
                    else:
                        print(f"    {sk}: (group)")
                # Try to find the main flux array
                for candidate in ["flux", "signal", "data", "photometry"]:
                    if candidate in airs_ds:
                        airs_data_example = airs_ds[candidate][...]
                        print(f"  Using sub-key '{candidate}' as flux array.")
                        break

        except (KeyError, IndexError) as e:
            print(f"Error reading AIRS-CH0 example: {e}")

        try:
            if example_planet_id and example_planet_id in f[FGS1_KEY]:
                fgs1_ds = f[FGS1_KEY][example_planet_id]
                print(f"\n{FGS1_KEY}/{example_planet_id}:")
                if isinstance(fgs1_ds, h5py.Dataset):
                    fgs1_data_example = fgs1_ds[...]  # read from disk once
                    print(f"  shape : {fgs1_ds.shape}")
                    print(f"  dtype : {fgs1_ds.dtype}")
                    print(f"  min   : {np.nanmin(fgs1_data_example):.6f}")
                    print(f"  max   : {np.nanmax(fgs1_data_example):.6f}")
                else:
                    print(f"  sub-keys: {list(fgs1_ds.keys())}")
                    for sk, ds in fgs1_ds.items():
                        if isinstance(ds, h5py.Dataset):
                            print(f"    {sk}: shape={ds.shape}, dtype={ds.dtype}")
                        else:
                            print(f"    {sk}: (group)")
                    for candidate in ["flux", "signal", "data", "photometry"]:
                        if candidate in fgs1_ds:
                            fgs1_data_example = fgs1_ds[candidate][...]
                            print(f"  Using sub-key '{candidate}' as flux array.")
                            break
        except (KeyError, TypeError) as e:
            print(f"Error reading FGS1 example: {e}")

        print(f"\n[Summary] Loaded example planet '{example_planet_id}': "
              f"AIRS shape={getattr(airs_data_example, 'shape', 'N/A')}, "
              f"FGS1 shape={getattr(fgs1_data_example, 'shape', 'N/A')}.")
else:
    print("Skipping example planet inspection — no HDF5 file found.")
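
The cell above loads each per-planet array fully into memory. When an array is too large for that, a row-chunked pass keeps memory bounded. This is a minimal sketch: `chunked_min_max` is a hypothetical helper, and the default chunk size is an arbitrary assumption.

```python
import numpy as np


def chunked_min_max(ds, chunk_rows=4096):
    """NaN-aware (min, max) computed over row chunks along axis 0.

    Works with anything that has `.shape` and supports `ds[i:j]`, such as a
    NumPy array or an h5py.Dataset; with h5py only one chunk of rows is read
    into memory at a time.
    """
    lo, hi = np.inf, -np.inf
    for start in range(0, ds.shape[0], chunk_rows):
        # Cast to float so integer (e.g. raw ADC) data can be NaN-checked
        block = np.asarray(ds[start:start + chunk_rows], dtype=float)
        if np.isnan(block).all():
            continue  # skip all-NaN chunks to avoid nanmin warnings
        lo = min(lo, float(np.nanmin(block)))
        hi = max(hi, float(np.nanmax(block)))
    return lo, hi


# Usage on an in-memory array; inside an open h5py.File you could pass
# f[AIRS_KEY][example_planet_id] instead (if it is a Dataset)
demo = np.arange(10_000, dtype=float).reshape(1000, 10)
demo[5, 3] = np.nan
print(chunked_min_max(demo, chunk_rows=64))
```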

## 5. Single-Planet Light-Curve Plot

In [None]:
# Plot raw light curves for planet index 0
# AIRS-CH0: white-light curve = mean across all wavelength channels vs time
# FGS1: single broadband channel vs time

planet_label = example_planet_id if example_planet_id is not None else "Unknown"

fig, axes = plt.subplots(2, 1, figsize=(12, 7), sharex=False)
fig.suptitle(f"Planet {planet_label} — raw light curves", fontsize=14, fontweight="bold")

# ── Top subplot: AIRS white-light curve ────────────────────────────────────
ax0 = axes[0]
if airs_data_example is not None:
    arr = np.asarray(airs_data_example)
    # Handle dimensionality: expect (time, wavelength) or (wavelength, time).
    # Heuristic (assumption): transit time series have more samples than
    # spectral channels, so the longer axis is treated as time.
    if arr.ndim == 2:
        if arr.shape[0] >= arr.shape[1]:
            white_light = arr.mean(axis=1)  # (time, wavelength): average over wavelength
        else:
            white_light = arr.mean(axis=0)  # (wavelength, time): average over wavelength
    elif arr.ndim == 1:
        white_light = arr
    else:
        white_light = arr.reshape(arr.shape[0], -1).mean(axis=1)

    t_airs = np.arange(len(white_light))
    ax0.plot(t_airs, white_light, lw=0.8, color="steelblue", label="White-light flux")
    ax0.set_xlabel("Time index")
    ax0.set_ylabel("Flux (raw units)")
    ax0.set_title(f"AIRS-CH0 — white-light curve ({len(white_light)} time steps)")
    ax0.legend()

    # Annotate the peak-to-peak flux range. This is only a rough upper bound on
    # the transit depth, since it also includes noise, outliers and systematics.
    ptp = white_light.max() - white_light.min()
    ax0.annotate(
        f"peak-to-peak Δflux ≈ {ptp:.4g}",
        xy=(t_airs[len(t_airs) // 2], white_light.min()),
        xytext=(t_airs[len(t_airs) // 2], white_light.min() + ptp * 0.3),
        arrowprops=dict(arrowstyle="->", color="gray"),
        fontsize=9, color="gray"
    )
else:
    ax0.text(0.5, 0.5, "AIRS-CH0 data not loaded", ha="center", va="center",
             transform=ax0.transAxes, fontsize=12, color="red")
    ax0.set_title("AIRS-CH0 — white-light curve (data unavailable)")

# ── Bottom subplot: FGS1 light curve ──────────────────────────────────────
ax1 = axes[1]
if fgs1_data_example is not None:
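    arr_fgs = np.asarray(fgs1_data_example)
    if arr_fgs.ndim > 1:
        # Assumption: axis 0 is time; sum the remaining (pixel) axes into one broadband flux
        arr_fgs = arr_fgs.reshape(arr_fgs.shape[0], -1).sum(axis=1)
    t_fgs = np.arange(len(arr_fgs))
    ax1.plot(t_fgs, arr_fgs, lw=0.8, color="darkorange", label="FGS1 flux")
    ax1.set_xlabel("Time index")
    ax1.set_ylabel("Flux (raw units)")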
    ax1.set_title(f"FGS1 — broadband light curve ({len(arr_fgs)} time steps)")
    ax1.legend()
else:
    ax1.text(0.5, 0.5, "FGS1 data not loaded", ha="center", va="center",
             transform=ax1.transAxes, fontsize=12, color="red")
    ax1.set_title("FGS1 — broadband light curve (data unavailable)")

plt.tight_layout()
plt.show()

n_airs_steps = len(white_light) if airs_data_example is not None else "N/A"
n_fgs1_steps = len(arr_fgs) if fgs1_data_example is not None else "N/A"
print(f"[Summary] Light curves plotted for planet '{planet_label}'. "
      f"AIRS white-light: {n_airs_steps} time steps; FGS1: {n_fgs1_steps} time steps.")
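
A peak-to-peak range mixes the transit with noise. A slightly better quick estimate compares the in-transit flux level with the out-of-transit baseline. This is a minimal sketch: `white_light_depth` is a hypothetical helper, and it assumes the transit is roughly centred in the window, with the first and last 20% of samples out of transit and the central 20% fully in transit. None of these assumptions has been checked against the real data.

```python
import numpy as np


def white_light_depth(flux, oot_fraction=0.2, in_fraction=0.2):
    """Rough fractional transit depth of a 1-D light curve.

    Assumptions (hypothetical, not checked against the real data):
    - the transit is roughly centred in the time window;
    - the first and last `oot_fraction` of samples are out of transit;
    - the central `in_fraction` of samples is fully in transit.
    Medians make the estimate robust to isolated outliers.
    """
    flux = np.asarray(flux, dtype=float)
    n = flux.size
    k = max(1, int(n * oot_fraction))
    baseline = np.nanmedian(np.concatenate([flux[:k], flux[-k:]]))
    half = max(1, int(n * in_fraction / 2))
    centre = flux[n // 2 - half : n // 2 + half]
    return 1.0 - np.nanmedian(centre) / baseline


# Usage on a synthetic box-shaped transit with a 1% dip
lc = np.full(1000, 500.0)
lc[300:700] *= 0.99
print(white_light_depth(lc))
```

Because it is a ratio, the estimate is unitless and does not depend on the raw flux scale. Light curves with instrumental drifts would still need detrending before this is meaningful.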

## 6. AIRS 2-D Heatmap (Time × Wavelength)

In [None]:
# Plot the 2-D flux matrix (time × wavelength) for planet index 0

fig, ax = plt.subplots(figsize=(14, 5))

if airs_data_example is not None:
    arr = np.asarray(airs_data_example)

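    # Ensure shape is (time, wavelength).
    # Heuristic (assumption): the longer axis is time, because transit time
    # series normally have more samples than spectral channels.
    if arr.ndim == 2:
        if arr.shape[0] < arr.shape[1]:
            # Looks like (wavelength, time): transpose
            arr = arr.T
        # Now arr is (time, wavelength)
        flux_2d = arr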
    elif arr.ndim == 3:
        # E.g. (time, wavelength, something) — take first slice
        flux_2d = arr[:, :, 0]
    else:
        flux_2d = arr.reshape(-1, 1)

    # Clip to [1st, 99th] percentile for better contrast
    vmin = np.percentile(flux_2d, 1)
    vmax = np.percentile(flux_2d, 99)

    im = ax.imshow(
        flux_2d.T,  # plot as (wavelength, time) so wavelength is on y-axis
        aspect="auto",
        origin="lower",
        vmin=vmin,
        vmax=vmax,
        cmap="RdYlBu_r",
        interpolation="nearest",
    )
    cbar = plt.colorbar(im, ax=ax, pad=0.02)
    cbar.set_label("Flux (raw units)", fontsize=10)

    ax.set_xlabel("Time index", fontsize=11)
    ax.set_ylabel("Wavelength channel", fontsize=11)
    ax.set_title(
        f"Planet {planet_label} — AIRS-CH0 flux matrix "
        f"(shape: {flux_2d.shape[0]} time × {flux_2d.shape[1]} wavelengths)  "
        f"[clipped to p1={vmin:.4f}, p99={vmax:.4f}]",
        fontsize=11
    )

    print(f"[Summary] AIRS flux matrix shape: {flux_2d.shape} (time × wavelength). "
          f"Value range after percentile clip: [{vmin:.4f}, {vmax:.4f}].")
else:
    ax.text(0.5, 0.5, "AIRS-CH0 data not loaded — cannot render heatmap",
            ha="center", va="center", transform=ax.transAxes, fontsize=12, color="red")
    ax.set_title("AIRS-CH0 flux heatmap (data unavailable)")
    print("[Summary] Skipped heatmap — AIRS data not available.")

plt.tight_layout()
plt.show()
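
In a raw heatmap, each channel's colour is set mostly by how bright the star is in that channel, so a transit of a few percent or less is hard to see. Dividing each channel by its own out-of-transit level puts all channels on a common relative scale. The sketch also makes the orientation choice explicit when the number of wavelength channels is known. Both helpers are hypothetical, and the out-of-transit mask is assumed to be known (for example, from ingress/egress timing).

```python
import numpy as np


def to_time_major(arr, n_wavelengths=None):
    """Return a 2-D array oriented as (time, wavelength).

    If `n_wavelengths` is known, the axis with that length is taken to be
    wavelength. Otherwise the shorter axis is assumed to be wavelength
    (transit time series usually have more samples than spectral channels).
    """
    arr = np.asarray(arr)
    if arr.ndim != 2:
        raise ValueError(f"expected a 2-D array, got shape {arr.shape}")
    if n_wavelengths is not None:
        if arr.shape[1] == n_wavelengths:
            return arr
        if arr.shape[0] == n_wavelengths:
            return arr.T
        raise ValueError(f"no axis of length {n_wavelengths} in shape {arr.shape}")
    return arr if arr.shape[0] >= arr.shape[1] else arr.T


def normalise_per_channel(flux_2d, oot_mask):
    """Divide each wavelength channel by its out-of-transit median flux.

    flux_2d  : (time, wavelength) array
    oot_mask : boolean array over time, True where the star is unocculted
    """
    flux_2d = np.asarray(flux_2d, dtype=float)
    baseline = np.nanmedian(flux_2d[oot_mask], axis=0)  # one value per channel
    return flux_2d / baseline


# Usage: 3 channels with very different stellar flux levels, 2% transit in rows 40-59
demo = np.tile([100.0, 1000.0, 10.0], (100, 1))
demo[40:60] *= 0.98
oot = np.ones(100, dtype=bool)
oot[30:70] = False
rel = normalise_per_channel(to_time_major(demo.T), oot)
print(rel[0], rel[50])
```

Passing the normalised array to `imshow` in place of `flux_2d` would show the transit as a band spanning all channels, with colour now encoding relative rather than absolute flux.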

## 7. Auxiliary Feature Distributions

In [None]:
# Plot histograms for all auxiliary feature columns in a 3-column grid
# Use a log-scale x-axis for columns spanning > 2 orders of magnitude

# Identify numeric feature columns (exclude the planet ID column)
# The first column is typically the planet ID
id_col = df_aux.columns[0]
feature_cols = [c for c in df_aux.columns if c != id_col and pd.api.types.is_numeric_dtype(df_aux[c])]

print(f"Auxiliary feature columns ({len(feature_cols)}): {feature_cols}")

# Size the grid to the number of features so none are silently dropped
n_cols = len(feature_cols)
ncols_grid = 3
nrows = max(1, int(np.ceil(n_cols / ncols_grid)))
fig, axes = plt.subplots(nrows, ncols_grid, figsize=(15, 3.7 * nrows), squeeze=False)
axes_flat = axes.flatten()

for idx, col in enumerate(feature_cols[:n_cols]):
    ax = axes_flat[idx]
    vals = df_aux[col].dropna()

    # Determine if log scale is appropriate
    use_log = False
    if vals.min() > 0:
        ratio = vals.max() / vals.min()
        if ratio > 100:  # > 2 orders of magnitude
            use_log = True

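    if use_log:
        # Log-spaced bin edges so bins have equal width on the log x-axis
        bins = np.logspace(np.log10(vals.min()), np.log10(vals.max()), 51)
        ax.hist(vals, bins=bins, color="steelblue", edgecolor="white", linewidth=0.3)
        ax.set_xscale("log")
        scale_note = "(log x-axis)"
    else:
        ax.hist(vals, bins=50, color="steelblue", edgecolor="white", linewidth=0.3)
        scale_note = ""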

    ax.set_title(f"{col} {scale_note}", fontsize=9, fontweight="bold")
    ax.set_xlabel(col, fontsize=8)
    ax.set_ylabel("Count", fontsize=8)
    ax.tick_params(labelsize=7)

    # Annotate with median
    median_val = vals.median()
    ax.axvline(median_val, color="red", linestyle="--", lw=1.2, alpha=0.8)
    ax.text(0.97, 0.95, f"median={median_val:.3g}",
            transform=ax.transAxes, ha="right", va="top", fontsize=7, color="red")

# Hide unused subplots
for idx in range(n_cols, len(axes_flat)):
    axes_flat[idx].set_visible(False)

fig.suptitle("Auxiliary Feature Distributions", fontsize=14, fontweight="bold", y=1.01)
plt.tight_layout()
plt.show()

# Print summary stats
print("\nSummary statistics for auxiliary features:")
display(df_aux[feature_cols].describe().round(4))

print(f"\n[Summary] Plotted {len(feature_cols)} auxiliary features. "
      f"Log x-scale applied to columns spanning >2 orders of magnitude.")
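
To quantify which auxiliary features benefit from a log axis, rather than judging by eye, one option is to tabulate the orders of magnitude each feature spans together with its skewness before and after a log10 transform. This is a minimal sketch: `log_scale_report` and the demo column names are hypothetical. On the real data it would be called as `log_scale_report(df_aux[feature_cols])`.

```python
import numpy as np
import pandas as pd


def log_scale_report(df):
    """Per numeric column: decades spanned, and skewness before/after log10.

    Log-based statistics are NaN for columns with any non-positive value.
    """
    rows = {}
    for col in df.select_dtypes("number").columns:
        v = df[col].dropna()
        positive = len(v) > 0 and bool((v > 0).all())
        rows[col] = {
            "decades": np.log10(v.max() / v.min()) if positive else np.nan,
            "skew_raw": v.skew(),
            "skew_log10": np.log10(v).skew() if positive else np.nan,
        }
    return pd.DataFrame.from_dict(rows, orient="index")


# Usage with toy values; the column names are hypothetical
demo = pd.DataFrame({
    "period_days": [0.5, 5.0, 50.0, 500.0],
    "teff_K": [4000.0, 5000.0, 5500.0, 6000.0],
})
report = log_scale_report(demo)
print(report.round(3))
```

A feature whose skewness drops sharply after the log transform is a good candidate for both log-scale plots and a log-transformed model input.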

## 8. Transit Depth Distribution (Labelled Planets)

In [None]:
# For labelled planets, compute median transit depth across 283 output wavelengths
# Transit depth is represented by the q2 (median) columns in QuartilesTable
# Column naming convention: expected to be something like 'q2_0', 'q2_1', ... or similar

print("QuartilesTable column names (first 20):")
print(df_q.columns[:20].tolist())
print("\nQuartilesTable column names (last 10):")
print(df_q.columns[-10:].tolist())

# Identify q2 columns (median quartile = transit depth proxy)
# Try common naming patterns
q2_cols = [c for c in df_q.columns if "q2" in c.lower() or "median" in c.lower()]

if not q2_cols:
    # Alternative layout (assumption): (planet_id, q1_0..q1_282, q2_0..q2_282, q3_0..q3_282).
    # Exclude the ID column (assumed to be the first column) before splitting into thirds.
    numeric_cols = [c for c in df_q.columns[1:] if pd.api.types.is_numeric_dtype(df_q[c])]
    n_numeric = len(numeric_cols)
    if n_numeric > 0 and n_numeric % 3 == 0:
        third = n_numeric // 3
        q2_cols = numeric_cols[third:2 * third]
        print(f"Auto-detected {n_numeric} numeric cols → using middle third as q2 ({len(q2_cols)} cols).")
    else:
        q2_cols = numeric_cols  # fallback: use all numeric columns
        print(f"Could not auto-detect q2 columns. Using all {len(q2_cols)} numeric columns.")

print(f"\nUsing {len(q2_cols)} q2 (median quartile) columns.")

# Compute median transit depth per planet (median across 283 wavelengths)
q2_values = df_q[q2_cols].values.astype(float)
median_depth_per_planet = np.nanmedian(q2_values, axis=1)

overall_median = np.nanmedian(median_depth_per_planet)
overall_std = np.nanstd(median_depth_per_planet)

fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(median_depth_per_planet, bins=60, color="steelblue", edgecolor="white",
        linewidth=0.3, alpha=0.85, label="Median transit depth")

ax.axvline(overall_median, color="red", linestyle="--", lw=2, label=f"Median = {overall_median:.4f}")
ax.axvline(overall_median - overall_std, color="orange", linestyle=":", lw=1.5,
           label=f"±1 std = {overall_std:.4f}")
ax.axvline(overall_median + overall_std, color="orange", linestyle=":", lw=1.5)

# Shaded region
ax.axvspan(overall_median - overall_std, overall_median + overall_std,
           alpha=0.1, color="orange")

ax.set_xlabel("Median q2 (transit depth proxy) across 283 wavelengths", fontsize=11)
ax.set_ylabel("Number of planets", fontsize=11)
ax.set_title(
    f"Transit Depth Distribution — {len(median_depth_per_planet)} labelled planets\n"
    f"Annotated: median = {overall_median:.4f} ± std = {overall_std:.4f}",
    fontsize=12
)
ax.legend(fontsize=9)

plt.tight_layout()
plt.show()

print(f"Transit depth statistics across {len(median_depth_per_planet)} labelled planets:")
print(f"  Median  : {overall_median:.6f}")
print(f"  Std     : {overall_std:.6f}")
print(f"  Min     : {np.nanmin(median_depth_per_planet):.6f}")
print(f"  Max     : {np.nanmax(median_depth_per_planet):.6f}")
print(f"  p5      : {np.nanpercentile(median_depth_per_planet, 5):.6f}")
print(f"  p95     : {np.nanpercentile(median_depth_per_planet, 95):.6f}")
print(f"\n[Summary] Median transit depth = {overall_median:.4f} ± {overall_std:.4f} across "
      f"{len(median_depth_per_planet)} labelled planets; range "
      f"[{np.nanmin(median_depth_per_planet):.4f}, {np.nanmax(median_depth_per_planet):.4f}].")
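
The task is scored on a predicted mean and std per wavelength, while the labels here are quartiles. If each channel's distribution is roughly Gaussian, the quartiles map directly to those two targets: the median equals the mean, and the interquartile range equals about 1.349 sigma. This is a minimal sketch under that Gaussian assumption. `quartiles_to_gaussian` is a hypothetical helper, and how the q1/q3 columns are identified depends on the actual table layout.

```python
import numpy as np

# For a normal distribution the quartiles sit at mu +/- 0.6745 * sigma,
# so IQR = q3 - q1 = 2 * 0.6745 * sigma (about 1.349 * sigma).
Z_75 = 0.6744897501960817  # standard-normal 75th percentile
IQR_TO_SIGMA = 2.0 * Z_75


def quartiles_to_gaussian(q1, q2, q3):
    """Map per-wavelength quartiles to Gaussian (mean, std) targets.

    Assumes each channel's distribution is approximately Gaussian, so the
    median (q2) is also the mean and sigma follows from the IQR.
    """
    q1, q2, q3 = (np.asarray(q, dtype=float) for q in (q1, q2, q3))
    if np.any(q3 < q1):
        raise ValueError("q3 must not be smaller than q1")
    return q2, (q3 - q1) / IQR_TO_SIGMA


# Usage: quartiles of N(mu=10, sigma=2)
mu, sigma = quartiles_to_gaussian(10 - 2 * Z_75, 10.0, 10 + 2 * Z_75)
print(mu, sigma)
```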

## 9. Label Correlation Heatmap

In [None]:
# Pearson correlation matrix of q2 columns (283 wavelengths)
# Subsample every 10th wavelength for readability → 29 of 283 channels (indices 0, 10, ..., 280)

SUBSAMPLE_STEP = 10  # use every 10th wavelength
q2_sub_cols = q2_cols[::SUBSAMPLE_STEP]

print(f"Subsampling from {len(q2_cols)} to {len(q2_sub_cols)} wavelength channels "
      f"(every {SUBSAMPLE_STEP}th channel).")

# Compute Pearson correlation
df_q2_sub = df_q[q2_sub_cols].copy()

# Create short labels: wavelength index
short_labels = [str(i * SUBSAMPLE_STEP) for i in range(len(q2_sub_cols))]
df_q2_sub.columns = short_labels

corr_matrix = df_q2_sub.corr(method="pearson")

fig, ax = plt.subplots(figsize=(14, 12))

# Seaborn heatmap of the full (unmasked) matrix
sns.heatmap(
    corr_matrix,
    ax=ax,
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
    center=0,
    square=True,
    linewidths=0.3,
    cbar_kws={"shrink": 0.8, "label": "Pearson r"},
    annot=False,  # too many cells for annotation
    xticklabels=short_labels,
    yticklabels=short_labels,
)

ax.set_title(
    f"Label Correlation Heatmap — q2 (median quartile) across subsampled wavelengths\n"
    f"Showing {len(q2_sub_cols)} of {len(q2_cols)} channels (every {SUBSAMPLE_STEP}th), "
    f"{len(df_q)} labelled planets",
    fontsize=11
)
ax.set_xlabel("Wavelength channel index", fontsize=10)
ax.set_ylabel("Wavelength channel index", fontsize=10)
ax.tick_params(axis="x", rotation=45, labelsize=7)
ax.tick_params(axis="y", rotation=0, labelsize=7)

plt.tight_layout()
plt.show()

# Find pairs with high correlation
corr_vals = corr_matrix.to_numpy(copy=True)  # copy, because fill_diagonal writes in place
np.fill_diagonal(corr_vals, np.nan)  # exclude self-correlation
max_corr = np.nanmax(np.abs(corr_vals))
mean_corr = np.nanmean(np.abs(corr_vals))

print("Correlation statistics (off-diagonal):")
print(f"  Max |Pearson r| : {max_corr:.4f}")
print(f"  Mean |Pearson r|: {mean_corr:.4f}")
print(f"  Highly correlated pairs (|r| > 0.9): "
      f"{(np.abs(corr_vals) > 0.9).sum() // 2}")

print(f"\n[Summary] Q2 label correlation heatmap: mean |r| = {mean_corr:.3f}, "
      f"max |r| = {max_corr:.3f} among {len(q2_sub_cols)} subsampled wavelength channels.")
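
A single mean |r| over the whole matrix hides how quickly correlation decays with channel separation. Averaging each off-diagonal of the matrix gives a correlation-versus-lag profile. This is a minimal sketch: `mean_abs_corr_by_lag` is a hypothetical helper, and the demo uses a synthetic matrix. On the subsampled heatmap above, it would be called as `mean_abs_corr_by_lag(corr_matrix.to_numpy())`, where each lag step corresponds to `SUBSAMPLE_STEP` native channels.

```python
import numpy as np


def mean_abs_corr_by_lag(corr):
    """Mean |r| between channels k positions apart, for k = 1 .. n-1.

    `corr` must be a square correlation matrix with rows/columns in
    wavelength order. Element k-1 of the result averages the k-th
    off-diagonal; NaNs (e.g. from constant channels) are ignored.
    """
    corr = np.asarray(corr, dtype=float)
    n = corr.shape[0]
    return np.array([np.nanmean(np.abs(np.diagonal(corr, offset=k))) for k in range(1, n)])


# Usage on a synthetic matrix with r = 0.8 ** |i - j| (smoothly decaying correlation)
idx = np.arange(6)
demo_corr = 0.8 ** np.abs(idx[:, None] - idx[None, :])
lag_profile = mean_abs_corr_by_lag(demo_corr)
print(lag_profile.round(3))
```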

## 10. Summary and Conclusions

### Key findings from this EDA:

- **Dataset scale**: The competition dataset is ~180 GB, consisting of HDF5 files containing per-planet time-series photometry from two instruments (AIRS-CH0 and FGS1) plus two CSV tables.

- **Labels**: Only ~24% of planets are labelled with quartile transit depths (q1/q2/q3) across 283 wavelength bins. The remaining ~76% form the test set where predictions must be submitted.

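- **Typical transit depth range**: The per-planet median of the q2 columns across labelled planets falls in a characteristic range (see the Section 8 histogram and printed percentiles); the std is comparable in magnitude, indicating substantial planet-to-planet variation. Because this statistic averages over wavelength, that variation mainly reflects differences in the planet-to-star radius ratio (transit depth ≈ (planet radius / stellar radius)²) rather than atmospheric absorption signatures.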

- **Auxiliary features**: Nine stellar/planetary parameters are available (e.g., stellar radius, temperature, planet radius, orbital period). Some span several orders of magnitude and benefit from log-scale visualisation. No missing values are expected in the auxiliary table.

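- **AIRS-CH0 structure**: Each planet's light curve is a 2-D matrix of shape (time_steps × 356 wavelength channels), covering 1.95–3.90 µm. The white-light curve shows a clear transit dip. In the raw 2-D heatmap, channel-to-channel differences are dominated by the stellar spectrum and instrument response; the planet's atmospheric absorption features are far smaller in amplitude and typically emerge only after per-channel normalisation and detrending.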

- **FGS1 structure**: Single-channel visible photometry provides a complementary broadband transit signal. It serves as a cross-check and may help normalise systematics.

- **Label correlation**: q2 values in nearby wavelength channels are highly correlated (Pearson r typically > 0.9 between neighbouring cells of the subsampled heatmap, i.e. channels 10 apart), consistent with smooth atmospheric spectra. Distant channels decorrelate, especially across molecular absorption band boundaries.

- **Data quality**: No obvious NaN issues in the CSVs. HDF5 key names should be verified on Kaggle (see `TODO` comments). The flux matrices may contain instrument noise/systematics that will need careful preprocessing before model training.

### Next steps:

1. Verify HDF5 key names by running `explore_data.py` or inspecting `list(f.keys())` on Kaggle.
2. Compute per-planet baseline features: white-light transit depth, ingress/egress timing, stellar limb darkening.
3. Extract AIRS-CH0 spectral features: per-channel transit depth, SNR per channel.
4. Build a baseline LightGBM/XGBoost model using auxiliary + simple spectral features.
5. Explore GP-based light-curve detrending and physics-informed feature engineering.
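
Steps 2 and 3 both reduce to comparing in-transit and out-of-transit flux. This is a minimal sketch that assumes the in-transit and out-of-transit time masks are already known (for example, from the ingress/egress timing in step 2) and that the noise is white. `per_channel_depth_snr` is a hypothetical helper, and real data would need detrending (step 5) first.

```python
import numpy as np


def per_channel_depth_snr(flux_2d, in_mask, oot_mask):
    """Per-wavelength transit depth and a simple signal-to-noise ratio.

    flux_2d  : (time, wavelength) array
    in_mask  : boolean array over time, True for fully in-transit samples
    oot_mask : boolean array over time, True for out-of-transit samples

    depth = 1 - mean(in) / mean(out). The uncertainty assumes white,
    uncorrelated noise (an assumption; real data usually need detrending).
    """
    flux_2d = np.asarray(flux_2d, dtype=float)
    f_in, f_out = flux_2d[in_mask], flux_2d[oot_mask]
    m_in, m_out = f_in.mean(axis=0), f_out.mean(axis=0)
    ratio = m_in / m_out
    depth = 1.0 - ratio
    # Standard errors of the two means, propagated to the ratio m_in / m_out
    se_in = f_in.std(axis=0, ddof=1) / np.sqrt(f_in.shape[0])
    se_out = f_out.std(axis=0, ddof=1) / np.sqrt(f_out.shape[0])
    depth_err = np.abs(ratio) * np.sqrt((se_in / m_in) ** 2 + (se_out / m_out) ** 2)
    return depth, depth / depth_err


# Usage: synthetic 3-channel data with known depths and Gaussian noise
rng = np.random.default_rng(1)
n_t = 2000
true_depth = np.array([0.010, 0.012, 0.008])
flux = np.full((n_t, true_depth.size), 1000.0)
in_mask = np.zeros(n_t, dtype=bool)
in_mask[800:1200] = True
oot_mask = np.ones(n_t, dtype=bool)
oot_mask[600:1400] = False
flux[in_mask] *= 1.0 - true_depth
flux += rng.normal(0.0, 1.0, size=flux.shape)
depth, snr = per_channel_depth_snr(flux, in_mask, oot_mask)
# White-light depth from the same masks (flux-weighted across channels)
white_depth = 1.0 - flux[in_mask].mean() / flux[oot_mask].mean()
print(depth, snr, white_depth)
```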