# Demo: CNN + Mixture-of-Experts (MoE) on a 19-slice OCT volume

This notebook runs **inference** on a folder of OCT PNG slices using the **CNN + MoE** model from this repository.

**What you get:** a CSV file saved to `outputs/predictions_moe.csv`.

> Tip (Windows): if you downloaded the repo as a ZIP, make sure you launch Jupyter **from inside the `TestAI` folder** (the one containing `README.md`, `tools/`, `models/`, ...).


In [None]:
# (Optional) Install dependencies (run once).
# If you already ran `pip install -r requirements.txt`, you can skip this cell.
# `%pip` installs into the environment of the running kernel, unlike `!pip`.
%pip install -r requirements.txt


In [None]:
# Sanity check: you must run this notebook from the repository folder (TestAI)
import os
from pathlib import Path

print("CWD:", os.getcwd())
assert Path("tools/download_checkpoints.py").exists(), (
    "tools/download_checkpoints.py not found. "
    "Open Jupyter from the TestAI repo folder (the one with README.md, tools/, models/, ...)."
)
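
If the assertion fails because Jupyter was started from a subfolder of the repository (for example `tools/`), one option is to locate the repository root programmatically. The sketch below uses a hypothetical helper, `find_repo_root`, which is not part of the repository; it walks up from a starting folder until it finds the same marker file that the sanity check looks for.

```python
import os
from pathlib import Path


def find_repo_root(start: Path, marker: str = "tools/download_checkpoints.py") -> Path:
    """Walk up from `start` until a folder containing `marker` is found.

    Hypothetical helper: the default marker is the file the sanity check asserts on.
    """
    start = start.resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError(f"No parent of {start} contains {marker}")


# Usage (uncomment to switch the notebook's working directory to the repo root):
# os.chdir(find_repo_root(Path.cwd()))
```

This only searches upward, so it cannot help if Jupyter was started in an unrelated folder; in that case, restart Jupyter from the `TestAI` folder.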


## 1) CONFIG (edit only this cell)

In [None]:
from pathlib import Path
import torch

# Folder that is searched recursively for your PNG images.
# Examples:
#   INPUT_DIR = Path(r"C:/path/to/dataset_png")         # contains CHM/ Healthy/ USH2A/
#   INPUT_DIR = Path(r"C:/path/to/dataset_png/CHM")     # a single label folder
#   INPUT_DIR = Path(r"C:/path/to/any_folder_with_pngs")
INPUT_DIR = Path(r"C:/path/to/dataset_png")  # <-- CHANGE THIS

# Label order used during training: CHM=0, Healthy=1, USH2A=2
CLASS_NAMES = ["CHM", "Healthy", "USH2A"]

MODEL_NAME = "moe"  # do not change

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)
print("INPUT_DIR:", INPUT_DIR)
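
Before running inference, it can help to confirm that `INPUT_DIR` exists and actually contains PNG slices. The sketch below uses a hypothetical helper, `count_pngs`, which is not part of the repository. It assumes the `CHM/ Healthy/ USH2A/` layout from the CONFIG examples, matches the `.png` extension case-insensitively, and files PNGs that sit under no label folder as `"unlabeled"`.

```python
from collections import Counter
from pathlib import Path


def count_pngs(input_dir: Path, class_names: list[str]) -> Counter:
    """Count PNG files under `input_dir` (recursively), keyed by label folder."""
    if not input_dir.is_dir():
        raise NotADirectoryError(f"INPUT_DIR is not a folder: {input_dir}")
    counts: Counter = Counter()
    for path in input_dir.rglob("*"):
        if not (path.is_file() and path.suffix.lower() == ".png"):
            continue
        # Folder names from INPUT_DIR itself down to the file's parent, so that
        # pointing INPUT_DIR at a single label folder (e.g. .../CHM) still works.
        folders = (input_dir.name, *path.relative_to(input_dir).parts[:-1])
        label = next((f for f in folders if f in class_names), "unlabeled")
        counts[label] += 1
    return counts


# Usage in the notebook, after the CONFIG cell:
# print(count_pngs(INPUT_DIR, CLASS_NAMES))
```

An empty result usually means `INPUT_DIR` points at the wrong folder.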


## 2) Download checkpoints (from GitHub Releases)

In [None]:
# Download all *.pt assets from Releases/v1.0 into ./weights/.
# Files that already exist are skipped, so rerunning this cell is safe.
# `sys.executable` runs the script with the same Python as the notebook kernel.
import sys
!{sys.executable} tools/download_checkpoints.py
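
To confirm the download step worked, you can list what ended up in `./weights/`. This is a minimal sketch with a hypothetical helper, `list_checkpoints`; it only checks that `*.pt` files are present and does not verify which checkpoint the MoE model actually loads.

```python
from pathlib import Path


def list_checkpoints(weights_dir: Path = Path("weights")) -> list[Path]:
    """Return the *.pt files directly inside `weights_dir`, sorted by name."""
    ckpts = sorted(p for p in weights_dir.glob("*.pt") if p.is_file())
    if not ckpts:
        raise FileNotFoundError(
            f"No *.pt files found in {weights_dir.resolve()}; rerun the download cell."
        )
    return ckpts


# Usage in the notebook (prints each checkpoint with its size in MB):
# for ckpt in list_checkpoints():
#     print(f"{ckpt.name}: {ckpt.stat().st_size / 1e6:.1f} MB")
```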


## 3) Run inference (MoE)

In [None]:
from inference import run_inference

df = run_inference(
    input_dir=INPUT_DIR,
    model_name=MODEL_NAME,
    device=DEVICE,
    class_names=CLASS_NAMES,
)
df.head()
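
The label order in the CONFIG cell matters because a predicted class index is turned into a name by position: index 0 means `CLASS_NAMES[0]` (CHM), index 1 means Healthy, and index 2 means USH2A. The sketch below illustrates that mapping on raw logits with a plain-Python softmax and argmax. It is not taken from `inference.py`, and `softmax` and `decode_logits` are hypothetical helpers; the logits in the usage line are made-up numbers, not model output.

```python
import math

CLASS_NAMES = ["CHM", "Healthy", "USH2A"]  # training order: CHM=0, Healthy=1, USH2A=2


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_logits(rows: list[list[float]], class_names: list[str]) -> list[tuple[str, float]]:
    """Map each row of logits to (predicted class name, softmax confidence)."""
    decoded = []
    for row in rows:
        probs = softmax(row)
        idx = max(range(len(probs)), key=probs.__getitem__)  # argmax
        decoded.append((class_names[idx], probs[idx]))
    return decoded


# Hypothetical logits for two slices:
print(decode_logits([[2.0, 0.5, 0.1], [0.0, 0.0, 3.0]], CLASS_NAMES))
```

If `CLASS_NAMES` were reordered, the same logits would be reported under the wrong names, which is why the CONFIG cell must keep the training order.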


## 4) Save predictions

In [None]:
from pathlib import Path

Path("outputs").mkdir(parents=True, exist_ok=True)
out_path = Path("outputs/predictions_moe.csv")
df.to_csv(out_path, index=False)
out_path
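
Once the CSV is saved, a quick per-class tally is a useful check that the predictions are not all collapsed onto one label. The column names written by `run_inference` are not documented in this notebook, so the sketch below takes the label column as a parameter. `summarize_predictions` is a hypothetical helper, and `"pred_label"` in the usage line is a placeholder to replace with the real name from `df.columns`.

```python
import csv
from collections import Counter
from pathlib import Path


def summarize_predictions(csv_path: Path, label_col: str) -> Counter:
    """Tally the values of `label_col` in a predictions CSV written by df.to_csv."""
    with csv_path.open(newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        if reader.fieldnames is None or label_col not in reader.fieldnames:
            raise KeyError(f"Column {label_col!r} not found; available: {reader.fieldnames}")
        return Counter(row[label_col] for row in reader)


# Usage ("pred_label" is a placeholder; use a real column name from df.columns):
# print(summarize_predictions(out_path, "pred_label"))
```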


## (Optional) Convert `.E2E` files to PNG before inference

In [None]:
# If you start from `.E2E` files, place them in:
#   E2E_ROOT/CHM/*.E2E, E2E_ROOT/Healthy/*.E2E, E2E_ROOT/USH2A/*.E2E
#
# Then run:
# %pip install -r requirements-e2e.txt
# !python tools/export_e2e_to_png.py --e2e-root "C:/path/to/E2E_ROOT" --out-root "C:/path/to/dataset_png"
#
# Finally, set INPUT_DIR in the CONFIG cell to the --out-root folder and rerun steps 3 and 4.
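
Before running the exporter, you can check that `E2E_ROOT` follows the expected `E2E_ROOT/<label>/*.E2E` layout. `check_e2e_layout` below is a hypothetical helper, not part of the repository. It counts `.E2E` files in each label folder, matching the extension case-insensitively, and reports 0 for a missing folder.

```python
from pathlib import Path


def check_e2e_layout(e2e_root: Path, class_names: list[str]) -> dict[str, int]:
    """Count .E2E files (any letter case) directly inside E2E_ROOT/<label>/ for each label."""
    counts: dict[str, int] = {}
    for name in class_names:
        folder = e2e_root / name
        if not folder.is_dir():
            counts[name] = 0  # missing label folder
            continue
        counts[name] = sum(
            1 for p in folder.iterdir() if p.is_file() and p.suffix.lower() == ".e2e"
        )
    return counts


# Usage (same root you pass to --e2e-root):
# print(check_e2e_layout(Path("C:/path/to/E2E_ROOT"), ["CHM", "Healthy", "USH2A"]))
```

A 0 for a label you expected to convert usually means the folder name does not exactly match `CHM`, `Healthy`, or `USH2A`.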
