# Demo — CNN + Majority Vote (19-slice OCT volume)

This notebook runs **inference only** for the baseline model:
- **Per-slice ResNet-50** → softmax probabilities
- **Volume-level prediction** = **majority vote** across the 19 slices  
  (ties between classes are resolved by the mean softmax probability over the slices)

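It is designed to work **directly after cloning this GitHub repo**. The relative paths used below (`requirements.txt`, `tools/`, `weights/`, `outputs/`) are resolved against the notebook's working directory.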

**What you need to change:** only the variables in the **CONFIG** cell.


In [None]:
# (Recommended) install dependencies (run once)
# If you already ran `pip install -r requirements.txt` in a terminal, you can skip this cell.

!pip -q install -r requirements.txt


## 1) CONFIG (edit only this cell)

In [None]:
from pathlib import Path
import torch

# Folder that contains your images.
# You can point to: train/ or val/ or test/ OR directly to a class folder (CHM/ USH2A/ Healthy/)
INPUT_DIR = Path(r"/path/to/your/data")  # <-- CHANGE THIS

# Class labels. The model's output index j is reported as CLASS_NAMES[j], so the
# order matters.
# NOTE(review): presumably this must be the class order used when the checkpoint
# was trained — confirm before editing.
CLASS_NAMES = ["CHM", "Healthy", "USH2A"]

# Checkpoint filename (downloaded from GitHub Releases v1.0 into ./weights/)
WEIGHTS_PATH = Path("weights/cnn_resnet50_2025-10-20_best.pt")

# Device: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Echo the effective configuration so it is recorded in the notebook output.
print("INPUT_DIR :", INPUT_DIR)
print("WEIGHTS   :", WEIGHTS_PATH)
print("DEVICE    :", DEVICE)
print("CLASSES   :", CLASS_NAMES)


## 2) Download checkpoints (from GitHub Releases)

In [None]:
# This downloads all *.pt assets from Releases/v1.0 into ./weights/
# If you already downloaded them, the script will skip existing files.
# NOTE(review): `!python` runs whichever `python` the shell finds on PATH, and the
# script path is relative to the notebook's working directory — confirm both match
# the kernel/environment the requirements were installed into.

!python tools/download_checkpoints.py


## 3) Inference code (baseline CNN + vote)

In [None]:
import re
from collections import Counter, defaultdict
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torchvision import models, transforms
from pathlib import Path

# File suffixes (lower-case, leading dot included) accepted as images; files are
# matched against this set via `p.suffix.lower()` in collect_volumes().
IMG_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}

def parse_filename(fname: str) -> Optional[Tuple[str, str, int]]:
    """
    Split an image filename into its volume identifiers.

    Expected layout of the name (the extension is ignored):
      <id1>_<id2>_<id3>_<eye>_<slice>.<ext>
    Example:
      371_4646_28891_R_01.PNG -> ("371_4646_28891", "R", 1)

    Returns (patient_id, eye, slice_idx), or None when the stem has fewer than
    five "_"-separated fields or the fifth field is not made of digits only.
    Fields after the fifth are ignored. The slice index is returned as parsed;
    it is not range-checked here.
    """
    fields = Path(fname).stem.split("_")
    if len(fields) < 5:
        return None

    id_fields, eye, slice_field = fields[:3], fields[3], fields[4]
    digits = re.match(r"^(\d+)$", slice_field)
    if digits is None:
        return None
    return "_".join(id_fields), eye, int(digits.group(1))

def infer_true_label_from_path(p: Path, class_names: List[str]) -> Optional[str]:
    """
    Guess the ground-truth class of a file from its path components.

    Every component of `p` (directories and the file name) is compared
    case-insensitively with each entry of `class_names`. The first class, in
    `class_names` order, that equals a component is returned with its original
    spelling. Returns None when no component matches.
    """
    lowered_parts = [component.lower() for component in p.parts]
    return next(
        (name for name in class_names if name.lower() in lowered_parts),
        None,
    )

def collect_volumes(input_dir: Path, class_names: List[str]) -> Dict[Tuple[str, str, Optional[str]], Dict[int, Path]]:
    """
    Group the image files found under `input_dir` into volumes.

    The directory is searched recursively. A file is kept when its suffix is in
    IMG_EXTS (case-insensitive) and its name is accepted by parse_filename().

    Args:
        input_dir: Root folder to search.
        class_names: Class labels used to infer the true label from the path.

    Returns:
        A plain dict mapping (patient_id, eye, true_label) to
        {slice_idx: path}. `true_label` is None when no path component matches
        a class name. If several files map to the same key and slice index, the
        one that sorts last by path wins.
    """
    vols: Dict[Tuple[str, str, Optional[str]], Dict[int, Path]] = defaultdict(dict)

    # rglob() yields entries in filesystem order, which differs between machines.
    # Sorting makes the result (and the winner among duplicate slices) reproducible.
    for p in sorted(input_dir.rglob("*")):
        # Cheap suffix test first; is_file() needs a filesystem stat.
        if p.suffix.lower() not in IMG_EXTS or not p.is_file():
            continue
        parsed = parse_filename(p.name)
        if parsed is None:
            continue
        patient_id, eye, slice_idx = parsed
        true_label = infer_true_label_from_path(p, class_names)
        vols[(patient_id, eye, true_label)][slice_idx] = p

    # Return a plain dict so that a lookup of an unknown key by a caller raises
    # KeyError instead of silently creating an empty volume.
    return dict(vols)

class CNNResNet50(nn.Module):
    """
    ResNet-50 classifier with a `num_classes`-way linear head.

    The backbone is stored under the attribute `resnet`, so every parameter name
    in the state_dict is prefixed with 'resnet.' (as in your checkpoint).
    The backbone is created with `weights=None`; trained weights are expected to
    be loaded afterwards via load_state_dict().
    """
    def __init__(self, num_classes: int):
        super().__init__()
        backbone = models.resnet50(weights=None)
        # Swap the stock classification layer for one sized to our label set.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.resnet = backbone

    def forward(self, x):
        """Return the backbone's output (class logits) for the batch `x`."""
        logits = self.resnet(x)
        return logits

def load_model(weights_path: Path, num_classes: int, device: str) -> nn.Module:
    """
    Build a CNNResNet50, load trained weights into it and put it in eval mode.

    Args:
        weights_path: Checkpoint file. It may hold a bare state_dict, or a dict
            that wraps one under the key "state_dict" or "model_state_dict".
        num_classes: Size of the classification head; must match the checkpoint.
        device: Device string the model is moved to (e.g. "cpu", "cuda").

    Returns:
        The model on `device`, in eval mode.

    Raises:
        FileNotFoundError: If `weights_path` is not an existing file.
        RuntimeError: From load_state_dict(strict=True) when the checkpoint keys
            or shapes do not match the model.
    """
    weights_path = Path(weights_path)
    # Fail early with an actionable message, before building the network.
    if not weights_path.is_file():
        raise FileNotFoundError(
            f"Checkpoint not found: {weights_path}. "
            "Run the checkpoint download step or fix WEIGHTS_PATH."
        )

    # NOTE(security): torch.load unpickles the file, and unpickling can execute
    # arbitrary code. Only load checkpoints from a source you trust.
    sd = torch.load(weights_path, map_location="cpu")

    # Accept checkpoints that wrap the weights in a training-state dict.
    # A bare state_dict for this model only has 'resnet.*' keys, so it is unaffected.
    if isinstance(sd, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            inner = sd.get(wrapper_key)
            if isinstance(inner, dict):
                sd = inner
                break

    model = CNNResNet50(num_classes=num_classes)
    model.load_state_dict(sd, strict=True)
    model.to(device)
    model.eval()
    return model

def build_transform(image_size: int = 224):
    """
    Build the preprocessing pipeline applied to every slice before inference.

    Steps, in order: resize to a square of `image_size` x `image_size`, convert
    to a tensor, then normalise each channel with the ImageNet mean and std.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)

@torch.no_grad()
def predict_volume(model: nn.Module, slice_paths: List[Path], tfm, device: str):
    """
    Classify one volume by majority vote over its slices.

    Each slice is opened, converted to RGB, transformed with `tfm` and passed
    through `model` on its own; the softmax of the output gives the per-slice
    class probabilities, and the argmax gives that slice's vote.

    Args:
        model: Callable network returning logits of shape (1, num_classes) for
            a single-image batch.
        slice_paths: Image paths of the slices; must not be empty.
        tfm: Callable turning a PIL image into a tensor without a batch axis.
        device: Device string the input tensor is moved to.

    Returns:
        (pred_idx, prob_mean, slice_preds): the winning class index, the mean
        probability per class over all slices, and the per-slice votes.
        When several classes share the highest vote count, the winner is the
        tied class with the largest mean probability.

    Raises:
        ValueError: If `slice_paths` is empty.
    """
    if len(slice_paths) == 0:
        raise ValueError("predict_volume() needs at least one slice path.")

    probs = []
    slice_preds = []
    for p in slice_paths:
        # The context manager releases the file handle once the pixels are
        # decoded; convert() returns an independent in-memory image.
        with Image.open(p) as raw:
            img = raw.convert("RGB")
        x = tfm(img).unsqueeze(0).to(device)
        logits = model(x)
        prob = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy()
        probs.append(prob)
        slice_preds.append(int(prob.argmax()))

    probs = np.stack(probs, axis=0)
    prob_mean = probs.mean(axis=0)

    vote = Counter(slice_preds).most_common()
    top_count = vote[0][1]
    top_classes = [cls for cls, cnt in vote if cnt == top_count]

    if len(top_classes) == 1:
        pred_idx = top_classes[0]
    else:
        # Break the tie among the tied classes only. An argmax over all classes
        # could return a class that did not win the vote.
        pred_idx = int(max(top_classes, key=lambda cls: prob_mean[cls]))

    return pred_idx, prob_mean, slice_preds

def run_cnn_vote(input_dir: Path, weights_path: Path, class_names: List[str], device: str, image_size: int = 224, num_slices: int = 19) -> pd.DataFrame:
    """
    Run slice-wise inference with majority vote on every complete volume.

    Args:
        input_dir: Root folder searched recursively for images.
        weights_path: Checkpoint file for the model.
        class_names: Class labels; index j of the model output is reported as
            class_names[j].
        device: Device string used for inference.
        image_size: Side length the slices are resized to.
        num_slices: Number of slices a volume must have (indices 1..num_slices).

    Returns:
        One row per complete volume with the columns patient_id, eye,
        true_label ("" when unknown), pred_label, pred_idx, device, then
        prob_<class> (mean probability) and vote_<class> (slice votes) for each
        class. The columns are present even when no volume is complete.

    Raises:
        ValueError: If `num_slices` is smaller than 1.
        RuntimeError: If no image with a parsable filename is found.
    """
    if num_slices < 1:
        raise ValueError(f"num_slices must be >= 1, got {num_slices}.")

    vols = collect_volumes(input_dir, class_names)
    if len(vols) == 0:
        raise RuntimeError("No images found. Check INPUT_DIR and filename format like 371_4646_28891_R_01.PNG.")

    model = load_model(weights_path, num_classes=len(class_names), device=device)
    tfm = build_transform(image_size=image_size)

    expected = list(range(1, num_slices + 1))
    columns = (
        ["patient_id", "eye", "true_label", "pred_label", "pred_idx", "device"]
        + [f"prob_{cname}" for cname in class_names]
        + [f"vote_{cname}" for cname in class_names]
    )

    def _volume_order(item):
        # true_label may be None; None and str cannot be compared, so map None
        # to "" for ordering purposes only.
        (patient_id, eye, true_label), _ = item
        return (patient_id, eye, true_label or "")

    rows = []
    skipped = 0

    for (patient_id, eye, true_label), slice_map in sorted(vols.items(), key=_volume_order):
        # Only volumes with every expected slice index are classified.
        if not all(k in slice_map for k in expected):
            skipped += 1
            continue

        slice_paths = [slice_map[i] for i in expected]
        pred_idx, prob_mean, slice_preds = predict_volume(model, slice_paths, tfm, device)

        row = {
            "patient_id": patient_id,
            "eye": eye,
            "true_label": true_label if true_label is not None else "",
            "pred_label": class_names[pred_idx],
            "pred_idx": int(pred_idx),
            "device": device,
        }
        counts = Counter(slice_preds)
        for j, cname in enumerate(class_names):
            row[f"prob_{cname}"] = float(prob_mean[j])
            row[f"vote_{cname}"] = int(counts.get(j, 0))

        rows.append(row)

    # Passing the columns explicitly keeps the schema stable when `rows` is empty.
    df = pd.DataFrame(rows, columns=columns)
    if skipped > 0:
        print(f"⚠️ Skipped {skipped} volume(s) because some slices 01..{num_slices:02d} were missing.")
    return df


## 4) Run inference

In [None]:
# Classify every complete volume found under INPUT_DIR using the values from the
# CONFIG cell. `df` holds one row per classified volume.
df = run_cnn_vote(
    input_dir=INPUT_DIR,
    weights_path=WEIGHTS_PATH,
    class_names=CLASS_NAMES,
    device=DEVICE,
)
# Preview the first rows of the predictions table.
df.head()


## 5) Save predictions

In [None]:
# Write the per-volume predictions to CSV (without the index column), creating
# the target folder first if needed; the cell displays the path that was written.
out_path = Path("outputs/predictions_cnn_vote.csv")
out_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(out_path, index=False)
out_path
