# 01: Food Allergy Baseline (Track A)

Baseline notebook for the DIABIMMUNE Track A dataset (HF microbiome embeddings + corrected `Month_*.csv` metadata).

Design goals (see `docs/specs/03_NOTEBOOK_STRUCTURE.md`):
- Self-contained (no project helper modules).
- Assertion-driven (fail fast if invariants break).
- Leakage-safe evaluation (subject-level aggregation + `StratifiedGroupKFold`).
- Cumulative horizon analyses: `≤3mo`, `≤6mo`, `≤12mo`, and `all`.
- Leave-one-country-out (LOCO) analysis to probe country confounding.


## 0) Setup & Configuration


In [None]:
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import sklearn
from IPython.display import display

# Configuration
RANDOM_SEED = 42
N_SPLITS_OUTER = 5
HORIZONS = [None, 3, 6, 12]  # None = all samples (association baseline)

# Paths
# Note: nbconvert executes notebooks with cwd set to the notebook's directory,
# so we auto-detect the repo root by searching for `data/processed/`.
def find_repo_root(start: Path) -> Path:
    for p in [start] + list(start.parents):
        if (p / "data" / "processed").exists():
            return p
    raise FileNotFoundError(
        f"Could not find repo root from {start.resolve()} (expected data/processed/)"
    )

REPO_ROOT = find_repo_root(Path.cwd())
METADATA_DIR = REPO_ROOT / "data" / "processed" / "longitudinal_wgs_subset"
EMBEDDINGS_PATH = REPO_ROOT / "data" / "processed" / "hf_legacy" / "microbiome_embeddings_100d.h5"
# Results go to `notebooks/results/` (alongside the notebook) for portability.
# This relies on cwd being the notebook's directory (see the nbconvert note above).
RESULTS_DIR = Path("results")
RESULTS_DIR.mkdir(exist_ok=True)

# Set seeds for reproducibility
np.random.seed(RANDOM_SEED)

# Print versions
print(f"Python: {sys.version}")
print(f"NumPy: {np.__version__}")
print(f"pandas: {pd.__version__}")
print(f"scikit-learn: {sklearn.__version__}")
print(f"Random seed: {RANDOM_SEED}")


## 1) Load Data


In [None]:
import h5py

assert METADATA_DIR.exists(), f"Missing METADATA_DIR: {METADATA_DIR.resolve()}"
assert EMBEDDINGS_PATH.exists(), f"Missing EMBEDDINGS_PATH: {EMBEDDINGS_PATH.resolve()}"

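# Load all Month CSVs; the month number is parsed from the file name
dfs = []
for csv_path in sorted(METADATA_DIR.glob("Month_*.csv")):
    month = int(csv_path.stem.split("_")[1])
    month_df = pd.read_csv(csv_path)  # separate name so `df` below is not shadowed
    month_df["month"] = month
    dfs.append(month_df)
assert dfs, f"No Month_*.csv files found in {METADATA_DIR.resolve()}"
metadata = pd.concat(dfs, ignore_index=True)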

# Load embeddings
embeddings_dict = {}
with h5py.File(EMBEDDINGS_PATH, "r") as f:
    for key in f.keys():
        embeddings_dict[key] = f[key][:]

# Merge
metadata["embedding"] = metadata["sid"].map(embeddings_dict)
df = metadata.dropna(subset=["embedding"])  # Should be 0 drops if aligned

# Extract arrays
X = np.stack(df["embedding"].values)
y = df["label"].values
patient_ids = df["patient_id"].values
countries = df["country"].values


## 2) Integrity Checks


In [None]:
# Counts
print(f"Samples: {len(df)}")
print(f"Unique patients: {df['patient_id'].nunique()}")
print(f"Label distribution: {df['label'].value_counts().to_dict()}")
print(f"Country distribution: {df['country'].value_counts().to_dict()}")

# Assertions
assert X.shape == (785, 100), f"Expected (785, 100), got {X.shape}"
assert len(y) == 785, f"Expected 785 labels, got {len(y)}"
assert df["patient_id"].nunique() == 212, f"Expected 212 patients, got {df['patient_id'].nunique()}"
assert len(df) == len(metadata), f"Expected 0 dropped rows, dropped {len(metadata) - len(df)}"

# Check: each sample ID (`sid`) appears exactly once across all Month files
assert df["sid"].duplicated().sum() == 0, "Duplicate SRS IDs found!"

# Check: labels consistent per patient
labels_per_patient = df.groupby("patient_id")["label"].nunique()
assert (labels_per_patient == 1).all(), "Inconsistent labels within patient!"

# Horizon sanity checks (Track A; month derived from file name)
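def horizon_counts(m: int) -> tuple[int, int, dict, dict]:
    """Return (n_samples, n_patients, label counts, country counts) for month <= m."""
    d = df[df["month"] <= m]
    return (
        len(d),
        d["patient_id"].nunique(),
        d["label"].value_counts().to_dict(),
        d["country"].value_counts().to_dict(),
    )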

print("Horizon counts:")
for m in [3, 6, 12]:
    n_samp, n_pat, labels, countries_m = horizon_counts(m)
    print(f"  month<={m}: samples={n_samp}, patients={n_pat}, labels={labels}, countries={countries_m}")

assert horizon_counts(3)[0] == 45
assert horizon_counts(6)[0] == 110
assert horizon_counts(12)[0] == 307


## 3) Analyses 1–4: Association Baseline + Cumulative Horizons (Subject-Level)


In [None]:
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

outer_cv = StratifiedGroupKFold(n_splits=N_SPLITS_OUTER, shuffle=True, random_state=RANDOM_SEED)

def build_subject_table(samples_df: pd.DataFrame) -> pd.DataFrame:
    """Collapse samples to one row per patient, averaging that patient's embeddings."""
    subj = samples_df.groupby("patient_id").agg(
        label=("label", "first"),
        country=("country", "first"),
        n_samples=("sid", "count"),
    )
    subj["embedding"] = samples_df.groupby("patient_id")["embedding"].apply(
        lambda x: np.mean(np.stack(x.to_list()), axis=0)
    )
    return subj.reset_index()

def run_cv(subj_df: pd.DataFrame, horizon_label: str) -> list[dict]:
    """Run the outer CV on a subject-level table; return one metrics dict per fold."""
    X_subj = np.stack(subj_df["embedding"].to_list())
    y_subj = subj_df["label"].to_numpy()
    groups = subj_df["patient_id"].to_numpy()

    fold_rows = []
    for fold_idx, (train_idx, test_idx) in enumerate(outer_cv.split(X_subj, y_subj, groups=groups)):
        X_train, X_test = X_subj[train_idx], X_subj[test_idx]
        y_train, y_test = y_subj[train_idx], y_subj[test_idx]

        model = Pipeline([
            ("scaler", StandardScaler()),
            ("clf", LogisticRegression(
                C=1.0,
                class_weight="balanced",
                solver="lbfgs",
                max_iter=2000,
                random_state=RANDOM_SEED,
            )),
        ])

        model.fit(X_train, y_train)
        y_pred_proba = model.predict_proba(X_test)[:, 1]
        y_pred = (y_pred_proba >= 0.5).astype(int)

        fold_rows.append({
            "horizon": horizon_label,
            "fold": fold_idx,
            "n_train_subjects": len(y_train),
            "n_test_subjects": len(y_test),
            "auroc": roc_auc_score(y_test, y_pred_proba),
            "auprc": average_precision_score(y_test, y_pred_proba),
            "f1": f1_score(y_test, y_pred),
        })
    return fold_rows

cv_results = []

for m in HORIZONS:
    if m is None:
        horizon_label = "all"
        df_h = df.copy()
    else:
        horizon_label = f"≤{m}mo"
        df_h = df[df["month"] <= m].copy()

    subj = build_subject_table(df_h)
    cv_results.extend(run_cv(subj, horizon_label=horizon_label))

cv_df = pd.DataFrame(cv_results)
print(cv_df)

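# Summary by horizon (mean ± sample std across folds).
# sort=False keeps the horizons in HORIZONS order instead of sorting the labels.
summary_rows = []
for horizon_label, g in cv_df.groupby("horizon", sort=False):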
    summary_rows.append({
        "horizon": horizon_label,
        "auroc_mean": g["auroc"].mean(),
        "auroc_std": g["auroc"].std(),
        "auprc_mean": g["auprc"].mean(),
        "auprc_std": g["auprc"].std(),
        "f1_mean": g["f1"].mean(),
        "f1_std": g["f1"].std(),
    })
summary_df = pd.DataFrame(summary_rows)

print("\n" + "="*60)
print("SUMMARY (mean ± std across folds)")
print("="*60)
display(summary_df)


## 4) LOCO: Leave-One-Country-Out (Per Horizon, Where Meaningful)


In [None]:
loco_results = []

for m in HORIZONS:
    if m is None:
        horizon_label = "all"
        df_h = df.copy()
    else:
        horizon_label = f"≤{m}mo"
        df_h = df[df["month"] <= m].copy()

    subj = build_subject_table(df_h)
    X_subj = np.stack(subj["embedding"].to_list())
    y_subj = subj["label"].to_numpy()
    countries_subj = subj["country"].to_numpy()

    for held_out in ["FIN", "EST", "RUS"]:
        # At <=3 months, RUS has too few subjects (and no positives) for meaningful LOCO.
        if m == 3 and held_out == "RUS":
            continue

        train_mask = countries_subj != held_out
        test_mask = countries_subj == held_out

        y_test = y_subj[test_mask]
        # AUROC is undefined if held-out set has only one class
        if len(set(y_test)) < 2:
            continue

        X_train, y_train = X_subj[train_mask], y_subj[train_mask]
        X_test = X_subj[test_mask]

        model = Pipeline([
            ("scaler", StandardScaler()),
            ("clf", LogisticRegression(
                C=1.0,
                class_weight="balanced",
                solver="lbfgs",
                max_iter=2000,
                random_state=RANDOM_SEED,
            )),
        ])

        model.fit(X_train, y_train)
        y_pred_proba = model.predict_proba(X_test)[:, 1]

        loco_results.append({
            "horizon": horizon_label,
            "held_out": held_out,
            "n_train_subjects": len(y_train),
            "n_test_subjects": len(y_test),
            "auroc": roc_auc_score(y_test, y_pred_proba),
            "auprc": average_precision_score(y_test, y_pred_proba),
        })

loco_df = pd.DataFrame(loco_results)
print("\n" + "="*60)
print("LOCO RESULTS (Leave-One-Country-Out)")
print("="*60)
display(loco_df)


## 5) Results Summary & Export


In [None]:
# Save results
cv_df.to_csv(RESULTS_DIR / "cv_metrics.csv", index=False)
loco_df.to_csv(RESULTS_DIR / "loco_metrics.csv", index=False)
summary_df.to_csv(RESULTS_DIR / "cv_summary.csv", index=False)

print("Results saved to results/")


## Known Limitations

**Onset timing is unknown.** The label is an *endpoint outcome* (eventual food allergy status), not a diagnosis at the time of sample collection. Any horizon may include post-onset samples for some infants.

**Milk allergy can manifest very early.** Infants exposed to cow's milk protein via formula or breast milk can develop milk allergy in the first weeks of life. This weakens "pure prediction" claims even at ≤3 months.

**Claim strength is a gradient, not a boundary:**
- **≤3 months**: Strongest predictive framing (but still limited)
- **≤6 months**: Moderate predictive framing
- **≤12 months**: Mixed predictive/associative
- **All samples**: Association only (includes post-onset samples)
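
Because the horizons are cumulative, every sample in a shorter horizon also appears in each longer one, which is part of why claim strength shifts gradually rather than at a hard cutoff. A minimal sketch on made-up rows shows the nesting; the `sid`, `patient_id` and `month` values are hypothetical and only the column names mirror the notebook.

```python
import pandas as pd

# Hypothetical sample table; values are invented for illustration.
toy = pd.DataFrame({
    "sid": ["S1", "S2", "S3", "S4", "S5", "S6"],
    "patient_id": ["P1", "P1", "P2", "P2", "P3", "P3"],
    "month": [1, 7, 3, 12, 6, 24],
})

horizons = [3, 6, 12, None]  # None = all samples
subsets = {}
for m in horizons:
    label = "all" if m is None else f"≤{m}mo"
    # slice(None) selects every row; otherwise keep samples up to month m.
    mask = slice(None) if m is None else toy["month"] <= m
    subsets[label] = set(toy.loc[mask, "sid"])

for label, sids in subsets.items():
    print(label, sorted(sids))
```

Each printed set contains the previous one, so a late sample can enter only the longer horizons, while an early sample contributes to all of them.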

**Country confounding.** Allergy prevalence differs by country (FIN 49%, EST 38%, RUS 15%), so a model can score above chance by recognizing country alone. LOCO analysis helps assess whether the model learns a transferable microbiome signal or country-specific batch effects.
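
Per-country prevalence is best computed at the subject level, so that subjects with many samples are not over-counted. The sketch below assumes a subject table with `country` and `label` columns, like the one `build_subject_table` returns; the rows themselves are made up and do not reproduce the percentages quoted above.

```python
import pandas as pd

# Hypothetical subject-level table (one row per subject); values are invented.
subj = pd.DataFrame({
    "patient_id": [f"P{i}" for i in range(8)],
    "country": ["FIN", "FIN", "FIN", "EST", "EST", "RUS", "RUS", "RUS"],
    "label": [1, 0, 1, 0, 1, 0, 0, 0],
})

# Named aggregation: subject count, positive count, and positive fraction per country.
prevalence = (
    subj.groupby("country")["label"]
    .agg(n_subjects="count", n_positive="sum", prevalence="mean")
    .sort_values("prevalence", ascending=False)
)
print(prevalence)
```

A country with `n_positive == 0` at some horizon yields a single-class held-out set, for which AUROC is undefined; the LOCO loop skips such cases.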

**Russia at ≤3 months.** Only 3 RUS patients fall within this horizon and none is positive, so AUROC is undefined for a held-out RUS set; the LOCO loop skips that combination.


## Interpretation Guide

**AUROC interpretation:**
- ≈0.50 = random chance (no signal); with this few subjects, values up to about 0.55 are hard to distinguish from chance
- 0.55–0.65 = weak signal
- 0.65–0.75 = moderate signal
- 0.75+ = strong signal
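
The bands can be written as a small lookup to keep reporting consistent. `auroc_band` is a hypothetical helper, not part of the notebook. It makes two assumptions the guide does not state: values below 0.55 are labelled "near chance", and a value exactly on a boundary goes to the higher band.

```python
def auroc_band(auroc: float) -> str:
    """Map an AUROC to a rule-of-thumb band.

    Assumptions: values below 0.55 count as "near chance", and a value on a
    boundary (0.55, 0.65, 0.75) is assigned to the higher band.
    """
    if not 0.0 <= auroc <= 1.0:
        raise ValueError(f"AUROC must be in [0, 1], got {auroc}")
    if auroc >= 0.75:
        return "strong"
    if auroc >= 0.65:
        return "moderate"
    if auroc >= 0.55:
        return "weak"
    return "near chance"


for value in (0.50, 0.60, 0.70, 0.80):
    print(f"{value:.2f} -> {auroc_band(value)}")
```

The labels describe ranking quality only. They say nothing about the uncertainty of the estimate, which is large when a fold holds few subjects.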

**LOCO interpretation:**
- If LOCO AUROC ≈ CV AUROC: Model learns transferable microbiome patterns
- If LOCO AUROC << CV AUROC: Model may be exploiting country-specific effects
- If LOCO AUROC < 0.50: Model ranks held-out subjects worse than chance, i.e. the learned direction does not transfer to (or is reversed in) the held-out country
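
One way to line up the two analyses is to join the LOCO rows to the CV summary on `horizon` and inspect the difference. The sketch below assumes tables shaped like `loco_df` and `summary_df`; the numbers are hypothetical and are not results from this notebook.

```python
import pandas as pd

# Hypothetical values for illustration only.
cv_summary = pd.DataFrame({
    "horizon": ["all", "≤12mo"],
    "auroc_mean": [0.70, 0.62],
})
loco = pd.DataFrame({
    "horizon": ["all", "all", "≤12mo"],
    "held_out": ["FIN", "EST", "FIN"],
    "auroc": [0.68, 0.45, 0.60],
})

# Each LOCO row picks up the CV mean for its horizon; validate guards against
# duplicate horizons in the summary table.
comparison = loco.merge(cv_summary, on="horizon", how="left", validate="many_to_one")
comparison["gap"] = comparison["auroc"] - comparison["auroc_mean"]
comparison["below_chance"] = comparison["auroc"] < 0.50
print(comparison)
```

How large a negative `gap` counts as "much lower" is a judgment call; the fold-to-fold standard deviation from CV gives a rough sense of scale.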

**What to look for:**
1. Does AUROC decrease as the horizon shortens? (Expected: fewer samples and subjects give noisier estimates, so the fold-to-fold spread should widen.)
2. Is ≤3mo AUROC still above chance? (Key question for "early prediction" claim)
3. Do LOCO results hold up? (Country confounding check)
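
For question 2, one possible check that this notebook does not perform is a label-permutation test using scikit-learn's `permutation_test_score`. The sketch below runs on synthetic noise rather than the study data. It uses plain `StratifiedKFold` on the assumption that the input already has one row per subject.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, permutation_test_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)

# Synthetic stand-in for a subject-level table: 60 subjects, 5 noise features.
X = rng.normal(size=(60, 5))
y = np.tile([0, 1], 30)

model = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", LogisticRegression(class_weight="balanced", max_iter=2000)),
])
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# Refit the model on label-shuffled copies of y to build a null distribution.
score, perm_scores, p_value = permutation_test_score(
    model, X, y, cv=cv, scoring="roc_auc", n_permutations=50, random_state=0
)
print(f"AUROC={score:.3f}, permutation mean={perm_scores.mean():.3f}, p={p_value:.3f}")
```

With only a handful of subjects at ≤3 months, such a test has little power, so a non-significant result would not rule out an early signal.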


## Reproducibility Footer

In [None]:
from datetime import datetime

print("=" * 50)
print("REPRODUCIBILITY INFO")
print("=" * 50)
print(f"Random seed: {RANDOM_SEED}")
print(f"Outer CV splits: {N_SPLITS_OUTER}")
print(f"Python: {sys.version}")
print(f"NumPy: {np.__version__}")
print(f"pandas: {pd.__version__}")
print(f"scikit-learn: {sklearn.__version__}")
print(f"Run completed: {datetime.now().isoformat()}")
