# Demo — CNN + Transformer fusion (Hybrid) (19-slice OCT volume)

This notebook runs **inference** on a folder of OCT PNG slices using the **CNN + Transformer fusion** model from this repository.

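**What you get:** a CSV of per-volume predictions saved to `outputs/predictions_hybrid.csv`, plus external-validation metrics (section 5) when the images carry ground-truth labels.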


In [None]:
# (Optional) Install dependencies (run once)
# If you already ran `pip install -r requirements.txt`, you can skip this cell.
!pip install -r requirements.txt


In [4]:
# Sanity check: the notebook must be launched from the repository root (TestAI),
# because later cells rely on repo-relative paths (tools/, outputs/) and on
# importing the local `inference` module.
import os
from pathlib import Path

print("CWD:", os.getcwd())

_repo_marker = Path("tools") / "download_checkpoints.py"
assert _repo_marker.exists(), (
    "tools/download_checkpoints.py not found. "
    "Open Jupyter from the TestAI repo folder (the one with README.md, tools/, models/, ...)."
)


CWD: C:\Users\kevin\Downloads\TestAI-main (1)\TestAI-main


## 1) CONFIG (edit only this cell)

In [36]:
import os
import warnings
from pathlib import Path

import torch

# Folder that contains your PNG images (searched recursively).
# Examples:
#   INPUT_DIR = Path(r"C:/path/to/dataset_png")         # contains CHM/ Healthy/ USH2A/
#   INPUT_DIR = Path(r"C:/path/to/dataset_png/CHM")     # a single label folder
#   INPUT_DIR = Path(r"C:/path/to/any_folder_with_pngs")
# If the OCT_INPUT_DIR environment variable is set, it overrides the default below,
# so the notebook can be re-run on another machine without editing this cell.
INPUT_DIR = Path(os.environ.get(
    "OCT_INPUT_DIR",
    r"C:\Users\kevin\Documents\Thèse\validation_externe\dataset",  # <-- CHANGE THIS
))

# Label order used during training: CHM=0, Healthy=1, USH2A=2
CLASS_NAMES = ["CHM", "Healthy", "USH2A"]

MODEL_NAME = "hybrid"  # do not change

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)
print("INPUT_DIR:", INPUT_DIR)

# Flag a wrong path here, before the (slower) inference cell runs.
if not INPUT_DIR.is_dir():
    warnings.warn(f"INPUT_DIR does not exist or is not a directory: {INPUT_DIR}")


DEVICE: cuda
INPUT_DIR: C:\Users\kevin\Documents\Thèse\validation_externe\dataset


## 2) Download checkpoints (from GitHub Releases)

In [6]:
# This downloads all *.pt assets from Releases/v1.0 into ./weights/
# If you already downloaded them, the script will skip existing files.
!python tools/download_checkpoints.py


Found 3 checkpoint(s) in release v1.0:
 - cnn_resnet50_2025-10-20_best.pt
 - cnn_resnet50_MOE_2025-11-09_best.pt
 - hybrid_resnet50_xformer_2025-11-04_best.pt
[OK] cnn_resnet50_2025-10-20_best.pt already downloaded.
[OK] cnn_resnet50_MOE_2025-11-09_best.pt already downloaded.
[OK] hybrid_resnet50_xformer_2025-11-04_best.pt already downloaded.

Done. weights/ contains:
 - cnn_resnet50_2025-10-20_best.pt
 - cnn_resnet50_MOE_2025-11-09_best.pt
 - hybrid_resnet50_xformer_2025-11-04_best.pt


## 3) Run inference (Hybrid)


In [38]:
# Run the hybrid model on every volume found under INPUT_DIR, using the
# settings from the CONFIG cell; `df` holds one row per volume.
from inference import run_inference

inference_kwargs = {
    "input_dir": INPUT_DIR,
    "model_name": MODEL_NAME,
    "device": DEVICE,
    "class_names": CLASS_NAMES,
}
df = run_inference(**inference_kwargs)
df.head()


  ckpt = torch.load(weights_path, map_location="cpu")


[hybrid] inferred num_encoder_layers = 1




Unnamed: 0,volume_id,laterality,n_slices,pred_idx,pred_label,true_label,prob_CHM,prob_Healthy,prob_USH2A
0,10978_20597_150810,R,19,2,USH2A,USH2A,0.000612,0.001738,0.99765
1,10978_20597_150811,L,19,2,USH2A,USH2A,0.001176,0.001392,0.997433
2,11391_21325_156743,R,19,2,USH2A,USH2A,0.000492,0.000675,0.998833
3,11391_21325_156745,L,19,2,USH2A,USH2A,0.000657,0.001149,0.998194
4,11681_21913_161058,R,19,0,CHM,CHM,0.983699,0.011244,0.005057


## 4) Save predictions


In [40]:
# Write the per-volume predictions to outputs/ (relative to the current working
# directory); the folder is created on first use.
from pathlib import Path

out_path = Path("outputs") / "predictions_hybrid.csv"
out_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(out_path, index=False)
out_path


WindowsPath('outputs/predictions_hybrid.csv')

## (Optional) Convert `.E2E` files to PNG (do this before section 3)

In [42]:
# Only needed if you start from `.E2E` files instead of PNG slices.
# Place them in one sub-folder per label:
#   E2E_ROOT/CHM/*.E2E, E2E_ROOT/Healthy/*.E2E, E2E_ROOT/USH2A/*.E2E
#
# Then uncomment and run (before the inference section):
# !pip install -r requirements-e2e.txt
# !python tools/export_e2e_to_png.py --e2e-root "C:/path/to/E2E_ROOT" --out-root "C:/path/to/dataset_png"
#
# Afterwards, point INPUT_DIR in the CONFIG cell at the --out-root folder.


## 5) Metrics (external validation)

In [44]:
import numpy as np
import pandas as pd
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, average_precision_score
)
from sklearn.preprocessing import label_binarize

# ====== External validation stats from df ======
# Expected df columns:
# true_label (str), pred_label (str) or pred_idx (int),
# prob_<CLASS> columns (float) matching class names.

PROB_PREFIX = "prob_"


def _labels_to_indices(labels: pd.Series, label_to_idx: dict, classes: list) -> np.ndarray:
    """Map string labels to integer class indices in `classes` order.

    Args:
        labels: a df column of class names (its `.name` is used in the error message).
        label_to_idx: class name -> index mapping.
        classes: class names, used only for the error message.

    Returns:
        Integer numpy array of class indices, one per row.

    Raises:
        ValueError: if any label is not a known class. Unknown values are
            stringified before sorting, so a mix of str and NaN (e.g. unlabeled
            rows) yields this error instead of a TypeError from sorted().
    """
    mapped = labels.map(label_to_idx)
    unknown = mapped.isna()
    if unknown.any():
        missing = sorted(labels[unknown].astype(str).unique())
        raise ValueError(f"Some {labels.name} values are not in prob_ columns class list: {missing}\n"
                         f"Classes inferred from prob_ cols: {classes}")
    return mapped.to_numpy().astype(int)


# 1) Define class order from prob_ columns (recommended: consistent order)
prob_cols = [c for c in df.columns if c.startswith(PROB_PREFIX)]
if len(prob_cols) == 0:
    raise ValueError("No probability columns found (expected columns like prob_CHM, prob_Healthy, prob_USH2A).")

# Strip only the leading prefix (str.replace would also remove later occurrences).
classes = [c[len(PROB_PREFIX):] for c in prob_cols]  # e.g. ["CHM","Healthy","USH2A"]

# 2) Build y_true / y_pred as indices in this class order
if "true_label" not in df.columns:
    raise ValueError("df must contain 'true_label' column (string labels).")
if "pred_label" not in df.columns and "pred_idx" not in df.columns:
    raise ValueError("df must contain 'pred_label' or 'pred_idx'.")

label_to_idx = {lbl: i for i, lbl in enumerate(classes)}

y_true = _labels_to_indices(df["true_label"], label_to_idx, classes)

if "pred_label" in df.columns:
    y_pred = _labels_to_indices(df["pred_label"], label_to_idx, classes)
else:
    # assumes pred_idx already matches the same ordering as prob_cols
    y_pred = df["pred_idx"].to_numpy().astype(int)
    # confusion_matrix(labels=...) silently drops out-of-range predictions,
    # which would make the matrix disagree with accuracy; fail loudly instead.
    out_of_range = sorted(set(y_pred[(y_pred < 0) | (y_pred >= len(classes))].tolist()))
    if out_of_range:
        raise ValueError(f"pred_idx values out of range for {len(classes)} classes: {out_of_range}")

# probabilities matrix aligned with `classes`
y_prob = df[prob_cols].to_numpy(dtype=float)

# 3) Basic counts
N = len(df)
counts = df["true_label"].value_counts().reindex(classes, fill_value=0)

# 4) Metrics (article-level)
acc = accuracy_score(y_true, y_pred)
prec_w, rec_w, f1_w, _ = precision_recall_fscore_support(
    y_true, y_pred, average="weighted", zero_division=0
)

# Average Precision (multi-class OvR)
# Choose ONE definition and stick to it (macro is common; weighted is also ok)
y_true_bin = label_binarize(y_true, classes=np.arange(len(classes)))
if y_true_bin.shape[1] == 1:
    # For exactly 2 classes label_binarize returns a single column; expand it to
    # one indicator column per class so it lines up with y_prob.
    y_true_bin = np.hstack([1 - y_true_bin, y_true_bin])
ap_macro = average_precision_score(y_true_bin, y_prob, average="macro")
ap_weighted = average_precision_score(y_true_bin, y_prob, average="weighted")

# Confusion matrix
cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(classes)))

# 5) Pretty printing
correct = int((y_true == y_pred).sum())
print("=== External validation summary ===")
print(f"N volumes: {N}")
print("Class distribution (true):")
for c in classes:
    print(f"  {c:10s}: {int(counts[c])}")

print(f"\nAccuracy: {acc:.4f} ({correct}/{N})")
print(f"Weighted precision: {prec_w:.4f}")
print(f"Weighted recall   : {rec_w:.4f}")
print(f"Weighted F1-score : {f1_w:.4f}")

print("\nAverage Precision (OvR):")
print(f"  AP macro    : {ap_macro:.4f}")
print(f"  AP weighted : {ap_weighted:.4f}  (use this if you want prevalence-weighted AP)")

print("\nConfusion matrix (rows=true, cols=pred) | order:", classes)
cm_df = pd.DataFrame(cm, index=[f"true_{c}" for c in classes], columns=[f"pred_{c}" for c in classes])
display(cm_df)

# Optional: per-class precision/recall/F1 (still simple, but can be useful in supplement)
prec_c, rec_c, f1_c, sup_c = precision_recall_fscore_support(
    y_true, y_pred, labels=np.arange(len(classes)), average=None, zero_division=0
)
per_class_df = pd.DataFrame({
    "class": classes,
    "support": sup_c,
    "precision": prec_c,
    "recall": rec_c,
    "f1": f1_c
})
display(per_class_df)


=== External validation summary ===
N volumes: 22
Class distribution (true):
  CHM       : 8
  Healthy   : 8
  USH2A     : 6

Accuracy: 0.7273 (16/22)
Weighted precision: 0.7727
Weighted recall   : 0.7273
Weighted F1-score : 0.7403

Average Precision (OvR):
  AP macro    : 0.8746
  AP weighted : 0.8791  (use this if you want prevalence-weighted AP)

Confusion matrix (rows=true, cols=pred) | order: ['CHM', 'Healthy', 'USH2A']


Unnamed: 0,pred_CHM,pred_Healthy,pred_USH2A
true_CHM,6,0,2
true_Healthy,0,6,2
true_USH2A,0,2,4


Unnamed: 0,class,support,precision,recall,f1
0,CHM,8,1.0,0.75,0.857143
1,Healthy,8,0.75,0.75,0.75
2,USH2A,6,0.5,0.666667,0.571429
