0. Imports & basic setup

In [1]:
import os, json, random
from pathlib import Path

import numpy as np
import pandas as pd
import cv2

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix

import timm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

1. Config (easy to tweak)

In [None]:
TRAIN_DIR = r"F:\Workspace\Working\college-minor-project\datasets\Disease Grading\Original Images\Training Set"
TEST_DIR  = r"F:\Workspace\Working\college-minor-project\datasets\Disease Grading\Original Images\Testing Set"
TRAIN_CSV = r"F:\Workspace\Working\college-minor-project\datasets\Disease Grading\Groundtruths\a. IDRiD_Disease Grading_Training Labels.csv"
TEST_CSV  = r"F:\Workspace\Working\college-minor-project\datasets\Disease Grading\Groundtruths\b. IDRiD_Disease Grading_Testing Labels.csv"


# ---- Paths (preprocessed) ----
PREP_ROOT       = Path("preprocessed_idrid")
PREP_TRAIN_DIR  = PREP_ROOT / "train"
PREP_TEST_DIR   = PREP_ROOT / "test"
PREP_TRAIN_CSV  = PREP_ROOT / "train_preprocessed.csv"
PREP_TEST_CSV   = PREP_ROOT / "test_preprocessed.csv"

# ---- Outputs ----
RUN_DIR   = Path("runs/idrid_simple")
CKPT_PATH = Path("checkpoints/best_idrid_simple.pt")

# Model / training
MODEL_NAME = "mobilenetv3_large_100"   
IMG_SIZE   = 320                      # smaller than 512 → faster
#NUM_CLASSES = 5

NUM_DR_CLASSES = 5          # Retinopathy grade: 0..4
NUM_DME_CLASSES = 3         # Risk of macular edema: 0..2

LOSS_WEIGHT_DME = 1.0       # you can tune this (e.g. 0.5) if one task dominates

BATCH_SIZE = 12
EPOCHS = 20
VAL_SPLIT = 0.2
LR_WARMUP = 4e-4               # new
LR_FINETUNE = 1e-4
WEIGHT_DECAY = 1e-4
PATIENCE = 7

IMBALANCE_STRATEGY = "sampler"         # "sampler", "class_weights", "none"
WARMUP_EPOCHS = 2 

SEED = 24


IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)

In [4]:
TRAIN_NPZ = PREP_ROOT / "train_arrays.npz"
TEST_NPZ  = PREP_ROOT / "test_arrays.npz"

1. Basic helpers (minimal functions)

In [5]:
def set_seed(seed=42):
    import torch.backends.cudnn as cudnn
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(SEED)

In [6]:
def find_column(df: pd.DataFrame, candidates):
    for c in candidates:
        if c in df.columns:
            return c
    lowered = {c.lower().strip(): c for c in df.columns}
    for c in candidates:
        key = c.lower().strip()
        if key in lowered:
            return lowered[key]
    raise ValueError(f"CSV missing expected columns. Tried {candidates}. Found: {list(df.columns)}")

In [7]:
def fundus_bbox_square(img: np.ndarray, pad_ratio: float = 0.01):
    """
    Robust crop that tries not to cut off the fundus:
    1) mask non-dark pixels
    2) bounding rect of mask
    3) pad to square with black borders
    """
    try:
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        thr = max(5, int(np.percentile(gray, 5)))
        mask = (gray > thr).astype(np.uint8) * 255
        cnts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not cnts:
            return img
        x, y, w, h = cv2.boundingRect(max(cnts, key=cv2.contourArea))
        H, W = img.shape[:2]
        padx = int(w * pad_ratio)
        pady = int(h * pad_ratio)
        x1 = max(0, x - padx)
        y1 = max(0, y - pady)
        x2 = min(W, x + w + padx)
        y2 = min(H, y + h + pady)
        crop = img[y1:y2, x1:x2]

        hh, ww = crop.shape[:2]
        if hh == ww:
            return crop
        side = max(hh, ww)
        top = (side - hh) // 2
        bottom = side - hh - top
        left = (side - ww) // 2
        right = side - ww - left
        crop_sq = cv2.copyMakeBorder(
            crop, top, bottom, left, right,
            borderType=cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )
        return crop_sq
    except Exception:
        return img

In [8]:
def retina_enhance(rgb: np.ndarray):
    """Simple Ben Graham style shade correction + mild unsharp mask."""
    blur = cv2.GaussianBlur(rgb, (0, 0), sigmaX=rgb.shape[1] * 0.05)
    out = cv2.addWeighted(rgb, 4.0, blur, -4.0, 128)
    return np.clip(out, 0, 255).astype(np.uint8)

2. Init seed & device

In [9]:
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. This script is written for GPU training only.")
device = torch.device("cuda:0")
print(f"[Info] Using device: {device}")

[Info] Using device: cuda:0


3. Read original CSVs

In [10]:
df_train_raw = pd.read_csv(TRAIN_CSV)
df_test_raw  = pd.read_csv(TEST_CSV)

In [11]:




# train_img_col  = find_column(df_train_raw, img_col_candidates)
# train_label_col = find_column(df_train_raw, label_col_candidates)
# test_img_col   = find_column(df_test_raw, img_col_candidates)
# test_label_col  = find_column(df_test_raw, label_col_candidates)



In [12]:
df_train = df_train_raw[["Image name", "Retinopathy grade", "Risk of macular edema"]].copy()
df_train.columns = ["image", "label_dr", "label_dme"]

df_test = df_test_raw[["Image name", "Retinopathy grade", "Risk of macular edema "]].copy()
df_test.columns = ["image", "label_dr", "label_dme"]

In [13]:
def ensure_ext(x):
    x = str(x)
    if any(x.lower().endswith(ext) for ext in [".jpg", ".jpeg", ".png", ".tif", ".tiff"]):
        return x
    return x + ".jpg"

df_train["image"]    = df_train["image"].apply(ensure_ext)
df_test["image"]     = df_test["image"].apply(ensure_ext)

df_train["label_dr"]  = df_train["label_dr"].astype(int)
df_train["label_dme"] = df_train["label_dme"].astype(int)

df_test["label_dr"]   = df_test["label_dr"].astype(int)
df_test["label_dme"]  = df_test["label_dme"].astype(int)

In [14]:
print(f"[Info] Original -> Train rows: {len(df_train)}, Test rows: {len(df_test)}")

[Info] Original -> Train rows: 413, Test rows: 103


4. Preprocess & save images (only once)
- crop, enhance, resize
- save to PREP_TRAIN_DIR / PREP_TEST_DIR
- write new CSVs with full paths

In [15]:
df_train.iterrows

<bound method DataFrame.iterrows of              image  label_dr  label_dme
0    IDRiD_001.jpg         3          2
1    IDRiD_002.jpg         3          2
2    IDRiD_003.jpg         2          2
3    IDRiD_004.jpg         3          2
4    IDRiD_005.jpg         4          0
..             ...       ...        ...
408  IDRiD_409.jpg         2          1
409  IDRiD_410.jpg         2          0
410  IDRiD_411.jpg         2          0
411  IDRiD_412.jpg         2          0
412  IDRiD_413.jpg         2          0

[413 rows x 3 columns]>

In [16]:
PREP_TRAIN_DIR.mkdir(parents=True, exist_ok=True)
PREP_TEST_DIR.mkdir(parents=True, exist_ok=True)
PREP_ROOT.mkdir(parents=True, exist_ok=True)

if PREP_TRAIN_CSV.exists() and PREP_TEST_CSV.exists():
    print(f"[Info] Preprocessed CSVs already exist: {PREP_TRAIN_CSV}, {PREP_TEST_CSV}")
    df_train_prep = pd.read_csv(PREP_TRAIN_CSV)
    df_test_prep  = pd.read_csv(PREP_TEST_CSV)
else:
    print("[Info] Preprocessing images (this is the slow part, done once)...")

    # ---- preprocess TRAIN ----
    paths_prep      = []
    labels_dr_prep  = []
    labels_dme_prep = []

    for i, row in df_train.iterrows():
        if (i + 1) % 50 == 0:
            print(f"  [Train] {i+1}/{len(df_train)}")

        raw_path = Path(TRAIN_DIR) / row["image"]
        if not raw_path.exists():
            print(f"  [WARN] Missing train image: {raw_path}")
            continue

        bgr = cv2.imread(str(raw_path), cv2.IMREAD_COLOR)
        if bgr is None:
            print(f"  [WARN] Cannot read train image: {raw_path}")
            continue
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # crop + enhance + resize
        rgb = fundus_bbox_square(rgb, pad_ratio=0.01)
        rgb = retina_enhance(rgb)
        rgb = cv2.resize(rgb, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)

        # save
        out_name = f"train_{i:04d}.jpg"
        out_path = PREP_TRAIN_DIR / out_name
        cv2.imwrite(str(out_path), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))

        paths_prep.append(str(out_path.resolve()))
        labels_dr_prep.append(int(row["label_dr"]))
        labels_dme_prep.append(int(row["label_dme"]))

    df_train_prep = pd.DataFrame({
        "path": paths_prep,
        "label_dr": labels_dr_prep,
        "label_dme": labels_dme_prep,
    })
    df_train_prep.to_csv(PREP_TRAIN_CSV, index=False)
    print(f"[Info] Saved preprocessed train CSV: {PREP_TRAIN_CSV}")

[Info] Preprocessing images (this is the slow part, done once)...
  [Train] 50/413
  [Train] 100/413
  [Train] 150/413
  [Train] 200/413
  [Train] 250/413
  [Train] 300/413
  [Train] 350/413
  [Train] 400/413
[Info] Saved preprocessed train CSV: preprocessed_idrid\train_preprocessed.csv


In [17]:
# ---- preprocess TEST ----
paths_prep      = []
labels_dr_prep  = []
labels_dme_prep = []

for i, row in df_test.iterrows():
    if (i + 1) % 50 == 0:
        print(f"  [Test] {i+1}/{len(df_test)}")

    raw_path = Path(TEST_DIR) / row["image"]
    if not raw_path.exists():
        print(f"  [WARN] Missing test image: {raw_path}")
        continue

    bgr = cv2.imread(str(raw_path), cv2.IMREAD_COLOR)
    if bgr is None:
        print(f"  [WARN] Cannot read test image: {raw_path}")
        continue
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    rgb = fundus_bbox_square(rgb, pad_ratio=0.01)
    rgb = retina_enhance(rgb)
    rgb = cv2.resize(rgb, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)

    out_name = f"test_{i:04d}.jpg"
    out_path = PREP_TEST_DIR / out_name
    cv2.imwrite(str(out_path), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))

    paths_prep.append(str(out_path.resolve()))
    labels_dr_prep.append(int(row["label_dr"]))
    labels_dme_prep.append(int(row["label_dme"]))

df_test_prep = pd.DataFrame({
    "path": paths_prep,
    "label_dr": labels_dr_prep,
    "label_dme": labels_dme_prep,
})
df_test_prep.to_csv(PREP_TEST_CSV, index=False)
print(f"[Info] Saved preprocessed test CSV: {PREP_TEST_CSV}")

  [Test] 50/103
  [Test] 100/103
[Info] Saved preprocessed test CSV: preprocessed_idrid\test_preprocessed.csv


In [50]:
# ============================================================
# 4b. Targeted offline augmentation for DR grade-1 ONLY
#     (run ONCE; then comment out this block or set flag False)
# ============================================================
DO_DR1_AUG = True          # set False after you’ve run once
N_AUG_PER_DR1 = 4          # 3–5 is reasonable

if DO_DR1_AUG:
    import albumentations as A

    print(f"[Info] Running targeted augmentation for DR grade-1 "
          f"({N_AUG_PER_DR1} extra samples per original).")

    # mild, realistic augmentations
    dr1_aug = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(0.1, 0.1, p=0.7),
        A.ShiftScaleRotate(shift_limit=0.02,
                           scale_limit=0.05,
                           rotate_limit=10,
                           border_mode=cv2.BORDER_REFLECT101,
                           p=0.7),
        A.GaussianBlur(blur_limit=(3, 5), p=0.2)
    ])

    # filter DR grade-1 rows
    df_dr1 = df_train_prep[df_train_prep["label_dr"] == 1].reset_index(drop=True)
    print(f"[Info] Found {len(df_dr1)} DR grade-1 images for augmentation.")

    extra_rows = []

    for i, row in df_dr1.iterrows():
        img_path = Path(row["path"])
        bgr = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        if bgr is None:
            print(f"[WARN] Cannot read DR1 image: {img_path}")
            continue

        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        for k in range(N_AUG_PER_DR1):
            augmented = dr1_aug(image=rgb)["image"]

            out_name = f"{img_path.stem}_dr1aug{k}.jpg"
            out_path = PREP_TRAIN_DIR / out_name
            cv2.imwrite(str(out_path), cv2.cvtColor(augmented, cv2.COLOR_RGB2BGR))

            extra_rows.append({
                "path": str(out_path.resolve()),
                "label_dr": int(row["label_dr"]),      # always 1
                "label_dme": int(row["label_dme"])     # keep same DME label
            })

        if (i + 1) % 10 == 0:
            print(f"  [DR1-Aug] {i+1}/{len(df_dr1)} originals done")

    if extra_rows:
        df_extra = pd.DataFrame(extra_rows)
        df_train_prep = pd.concat([df_train_prep, df_extra], ignore_index=True)
        df_train_prep.to_csv(PREP_TRAIN_CSV, index=False)
        print(f"[Info] Added {len(df_extra)} augmented DR=1 samples. "
              f"New train size: {len(df_train_prep)}")
    else:
        print("[Info] No extra DR1 samples created (check warnings above).")


[Info] Running targeted augmentation for DR grade-1 (4 extra samples per original).
[Info] Found 40 DR grade-1 images for augmentation.


  original_init(self, **validated_kwargs)


  [DR1-Aug] 10/40 originals done
  [DR1-Aug] 20/40 originals done
  [DR1-Aug] 30/40 originals done
  [DR1-Aug] 40/40 originals done
[Info] Added 160 augmented DR=1 samples. New train size: 986


5. OPTIONAL: Albumentations offline augmentation

In [16]:
USE_ALBUMENTATIONS_OFFLINE = True  # <-- set True if you want to run this block ONCE

In [18]:
import albumentations as A

print("[Info] Running offline Albumentations to create extra augmented images...")
aug = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(0.1, 0.1, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.02, scale_limit=0.05, rotate_limit=10, border_mode=cv2.BORDER_REFLECT101, p=0.5),
    A.GaussianBlur(blur_limit=(3,5), p=0.2)
])

extra_paths      = []
extra_labels_dr  = []
extra_labels_dme = []

for i, row in df_train_prep.iterrows():
    if (i + 1) % 50 == 0:
        print(f"  [Aug] {i+1}/{len(df_train_prep)}")
    img_path = Path(row["path"])
    bgr = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
    if bgr is None:
        continue
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    augmented = aug(image=rgb)["image"]
    out_name = img_path.stem + "_aug.jpg"
    out_path = PREP_TRAIN_DIR / out_name
    cv2.imwrite(str(out_path), cv2.cvtColor(augmented, cv2.COLOR_RGB2BGR))
    extra_paths.append(str(out_path.resolve()))
    extra_labels_dr.append(int(row["label_dr"]))
    extra_labels_dme.append(int(row["label_dme"]))

if extra_paths:
    df_extra = pd.DataFrame({
        "path": extra_paths,
        "label_dr": extra_labels_dr,
        "label_dme": extra_labels_dme,
    })
    df_train_prep = pd.concat([df_train_prep, df_extra], ignore_index=True)
    df_train_prep.to_csv(PREP_TRAIN_CSV, index=False)
    print(f"[Info] Albumentations added {len(extra_paths)} samples. New train size: {len(df_train_prep)}")

[Info] Running offline Albumentations to create extra augmented images...


  original_init(self, **validated_kwargs)


  [Aug] 50/413
  [Aug] 100/413
  [Aug] 150/413
  [Aug] 200/413
  [Aug] 250/413
  [Aug] 300/413
  [Aug] 350/413
  [Aug] 400/413
[Info] Albumentations added 413 samples. New train size: 826


6. Load preprocessed images into memory & normalize

In [51]:
if TRAIN_NPZ.exists() and TEST_NPZ.exists():
    print(f"[Info] NPZ cache found: {TRAIN_NPZ}, {TEST_NPZ}")
    train_data = np.load(TRAIN_NPZ)
    test_data  = np.load(TEST_NPZ)

    X_train     = train_data["X"]
    y_train_dr  = train_data["y_dr"]
    y_train_dme = train_data["y_dme"]

    X_test      = test_data["X"]
    y_test_dr   = test_data["y_dr"]
    y_test_dme  = test_data["y_dme"]

else:
    print("[Info] NPZ cache not found. Loading JPGs and building arrays (one-time)...")

    # ---- build train arrays ----
    X_list      = []
    y_dr_list   = []
    y_dme_list  = []
    for i, row in df_train_prep.iterrows():
        if (i + 1) % 100 == 0:
            print(f"  [Load-Train] {i+1}/{len(df_train_prep)}")
        path = row["path"]
        bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
        if bgr is None:
            print(f"[WARN] Cannot read preprocessed train image: {path}")
            continue
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        rgb = cv2.resize(rgb, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
        X_list.append(rgb)
        y_dr_list.append(int(row["label_dr"]))
        y_dme_list.append(int(row["label_dme"]))

    X_train     = np.stack(X_list, axis=0)
    y_train_dr  = np.array(y_dr_list, dtype=np.int64)
    y_train_dme = np.array(y_dme_list, dtype=np.int64)

    # ---- build test arrays ----
    X_list      = []
    y_dr_list   = []
    y_dme_list  = []
    for i, row in df_test_prep.iterrows():
        if (i + 1) % 100 == 0:
            print(f"  [Load-Test] {i+1}/{len(df_test_prep)}")
        path = row["path"]
        bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
        if bgr is None:
            print(f"[WARN] Cannot read preprocessed test image: {path}")
            continue
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        rgb = cv2.resize(rgb, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
        X_list.append(rgb)
        y_dr_list.append(int(row["label_dr"]))
        y_dme_list.append(int(row["label_dme"]))

    X_test      = np.stack(X_list, axis=0)
    y_test_dr   = np.array(y_dr_list, dtype=np.int64)
    y_test_dme  = np.array(y_dme_list, dtype=np.int64)

    # save NPZ so future runs are INSTANT
    np.savez(TRAIN_NPZ, X=X_train, y_dr=y_train_dr, y_dme=y_train_dme)
    np.savez(TEST_NPZ,  X=X_test,  y_dr=y_test_dr,  y_dme=y_test_dme)
    print(f"[Info] Saved NPZ cache: {TRAIN_NPZ}, {TEST_NPZ}")

print(f"[Info] X_train shape: {X_train.shape}, y_train_dr shape: {y_train_dr.shape}, y_train_dme shape: {y_train_dme.shape}")
print(f"[Info] X_test shape:  {X_test.shape},  y_test_dr shape:  {y_test_dr.shape},  y_test_dme shape:  {y_test_dme.shape}")


[Info] NPZ cache not found. Loading JPGs and building arrays (one-time)...
  [Load-Train] 100/986
  [Load-Train] 200/986
  [Load-Train] 300/986
  [Load-Train] 400/986
  [Load-Train] 500/986
  [Load-Train] 600/986
  [Load-Train] 700/986
  [Load-Train] 800/986
  [Load-Train] 900/986
  [Load-Test] 100/103
[Info] Saved NPZ cache: preprocessed_idrid\train_arrays.npz, preprocessed_idrid\test_arrays.npz
[Info] X_train shape: (986, 320, 320, 3), y_train_dr shape: (986,), y_train_dme shape: (986,)
[Info] X_test shape:  (103, 320, 320, 3),  y_test_dr shape:  (103,),  y_test_dme shape:  (103,)


In [None]:
def load_images_from_df(df, img_size):
    X_list      = []
    y_dr_list   = []
    y_dme_list  = []
    for i, row in df.iterrows():
        if (i + 1) % 100 == 0:
            print(f"  [Load] {i+1}/{len(df)}")
        path = row["path"]
        bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
        if bgr is None:
            print(f"[WARN] Cannot read preprocessed image: {path}")
            continue
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        rgb = cv2.resize(rgb, (img_size, img_size), interpolation=cv2.INTER_AREA)
        X_list.append(rgb)
        y_dr_list.append(int(row["label_dr"]))
        y_dme_list.append(int(row["label_dme"]))

    X    = np.stack(X_list, axis=0)
    y_dr = np.array(y_dr_list, dtype=np.int64)
    y_dme = np.array(y_dme_list, dtype=np.int64)
    return X, y_dr, y_dme

print("[Info] Loading preprocessed train images into RAM...")
X_train, y_train_dr, y_train_dme = load_images_from_df(df_train_prep, IMG_SIZE)

print("[Info] Loading preprocessed test images into RAM...")
X_test, y_test_dr, y_test_dme   = load_images_from_df(df_test_prep, IMG_SIZE)

[Info] Loading preprocessed train images into RAM...
  [Load] 100/826
  [Load] 200/826
  [Load] 300/826
  [Load] 400/826
  [Load] 500/826
  [Load] 600/826
  [Load] 700/826
  [Load] 800/826
[Info] Loading preprocessed test images into RAM...
  [Load] 100/103


In [63]:
print(f"[Info] X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
print(f"[Info] X_test shape:  {X_test.shape},  y_test shape:  {y_test.shape}")

[Info] X_train shape: (826, 320, 320, 3), y_train shape: (826,)
[Info] X_test shape:  (103, 320, 320, 3),  y_test shape:  (103,)


In [53]:
X_train = X_train.astype(np.float32) / 255.0
X_test  = X_test.astype(np.float32)  / 255.0

In [54]:
mean = IMAGENET_MEAN.reshape(1, 1, 1, 3)
std  = IMAGENET_STD.reshape(1, 1, 1, 3)
X_train = (X_train - mean) / std
X_test  = (X_test  - mean) / std

In [55]:
# N,H,W,C -> N,C,H,W
X_train = np.transpose(X_train, (0, 3, 1, 2))
X_test  = np.transpose(X_test,  (0, 3, 1, 2))

In [56]:
X_train_t    = torch.from_numpy(X_train)
y_train_dr_t = torch.from_numpy(y_train_dr)
y_train_dme_t = torch.from_numpy(y_train_dme)

X_test_t     = torch.from_numpy(X_test)
y_test_dr_t  = torch.from_numpy(y_test_dr)
y_test_dme_t = torch.from_numpy(y_test_dme)

In [57]:
print(f"[Info] Tensor shapes -> X_train: {X_train_t.shape}, X_test: {X_test_t.shape}")

[Info] Tensor shapes -> X_train: torch.Size([986, 3, 320, 320]), X_test: torch.Size([103, 3, 320, 320])


7. Train/val split

In [58]:
splitter = StratifiedShuffleSplit(
    n_splits=1,
    test_size=VAL_SPLIT,
    random_state=SEED
)
train_idx, val_idx = next(splitter.split(np.arange(len(y_train_dr)), y_train_dr))


In [59]:
X_tr_t        = X_train_t[train_idx]
y_tr_dr_t     = y_train_dr_t[train_idx]
y_tr_dme_t    = y_train_dme_t[train_idx]

X_va_t        = X_train_t[val_idx]
y_va_dr_t     = y_train_dr_t[val_idx]
y_va_dme_t    = y_train_dme_t[val_idx]

print(f"[Info] Split -> Train: {len(y_tr_dr_t)}, Val: {len(y_va_dr_t)}, Test: {len(y_test_dr_t)}")

[Info] Split -> Train: 788, Val: 198, Test: 103


8. Dataloaders (TensorDataset, no custom Dataset class)

In [60]:
train_dataset = TensorDataset(X_tr_t, y_tr_dr_t, y_tr_dme_t)
val_dataset   = TensorDataset(X_va_t, y_va_dr_t, y_va_dme_t)
test_dataset  = TensorDataset(X_test_t, y_test_dr_t, y_test_dme_t)


In [61]:
sampler = None
class_weights_dr = None

In [62]:
IMBALANCE_STRATEGY = "sampler"

In [63]:
if IMBALANCE_STRATEGY == "sampler":
    counts = np.bincount(y_tr_dr_t.numpy(), minlength=NUM_DR_CLASSES)
    w_per_sample = 1.0 / (counts[y_tr_dr_t.numpy()] + 1e-6)
    sampler = WeightedRandomSampler(
        weights=w_per_sample,
        num_samples=len(w_per_sample),
        replacement=True
    )
    print("[Info] Using WeightedRandomSampler (DR labels).")
elif IMBALANCE_STRATEGY == "class_weights":
    counts = np.bincount(y_tr_dr_t.numpy(), minlength=NUM_DR_CLASSES).astype(float)
    weights = counts.sum() / (counts + 1e-6)
    weights = weights / weights.mean()
    class_weights_dr = torch.tensor(weights, dtype=torch.float32, device=device)
    print(f"[Info] Using DR class weights: {class_weights_dr.cpu().numpy()}")
else:
    print("[Info] No imbalance strategy.")


[Info] Using WeightedRandomSampler (DR labels).


In [64]:
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=(sampler is None),
    sampler=sampler,
    num_workers=0,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

9. Model, loss, optimizer

In [65]:
# model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=NUM_CLASSES, drop_rate=0.4)
# model = model.to(device)
# print(f"[Info] Model '{MODEL_NAME}' created. First param device: {next(model.parameters()).device}")


# # --- Warmup setup: freeze all but classifier head ---
# WARMUP_EPOCHS = WARMUP_EPOCHS  # uses the value you set above

# # 1) Freeze everything
# for p in model.parameters():
#     p.requires_grad = False

# # 2) Unfreeze only classifier / head parameters
# for name, p in model.named_parameters():
#     if any(k in name.lower() for k in ["classifier", "head", "fc"]):
#         p.requires_grad = True

# print("[Info] Warmup: training only the classifier head for first "
#       f"{WARMUP_EPOCHS} epochs.")


In [66]:
class MultiTaskEffNet(nn.Module):
    def __init__(self, backbone_name, num_dr, num_dme, drop_rate=0.4):
        super().__init__()
        # backbone without classifier, global_pool='avg' for pooled features
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=True,
            num_classes=0,
            global_pool='avg',
            drop_rate=drop_rate
        )
        feat_dim = self.backbone.num_features
        self.head_dr  = nn.Linear(feat_dim, num_dr)
        self.head_dme = nn.Linear(feat_dim, num_dme)

    def forward(self, x):
        feats = self.backbone(x)
        logits_dr  = self.head_dr(feats)
        logits_dme = self.head_dme(feats)
        return logits_dr, logits_dme

model = MultiTaskEffNet(MODEL_NAME, NUM_DR_CLASSES, NUM_DME_CLASSES, drop_rate=0.4).to(device)
print(f"[Info] Multi-task model created. First param device: {next(model.parameters()).device}")


[Info] Multi-task model created. First param device: cuda:0


In [67]:
criterion_dr  = nn.CrossEntropyLoss(weight=class_weights_dr)
criterion_dme = nn.CrossEntropyLoss()   # could add weights if DME is imbalanced

# Warmup optimizer: head-only
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LR_WARMUP,
    weight_decay=WEIGHT_DECAY
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=EPOCHS - WARMUP_EPOCHS,
    eta_min=LR_FINETUNE * 0.1,
)



In [68]:
scaler = torch.amp.GradScaler('cuda')

10. Training loop (inline, no extra functions)

In [69]:
best_f1 = -1.0
waited = 0

RUN_DIR.mkdir(parents=True, exist_ok=True)
CKPT_PATH.parent.mkdir(parents=True, exist_ok=True)

In [70]:
scheduler = None   # start with no scheduler

for epoch in range(1, EPOCHS + 1):

    # -------------------------------
    # WARMUP → UNFREEZE
    # -------------------------------
    if epoch == WARMUP_EPOCHS + 1:
        print(f"[Info] Warmup finished. Unfreezing backbone for finetuning (LR={LR_FINETUNE}).")

        for p in model.parameters():
            p.requires_grad = True

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=LR_FINETUNE,
            weight_decay=WEIGHT_DECAY
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=EPOCHS - WARMUP_EPOCHS,
            eta_min=LR_FINETUNE * 0.1
        )

    # -------------------------------
    # TRAIN
    # -------------------------------
    model.train()
    total_loss = 0.0
    total_samples = 0

    for step, (imgs, labels_dr, labels_dme) in enumerate(train_loader):
        imgs       = imgs.to(device, non_blocking=True)
        labels_dr  = labels_dr.to(device, non_blocking=True)
        labels_dme = labels_dme.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=True):
            logits_dr, logits_dme = model(imgs)
            loss_dr  = criterion_dr(logits_dr, labels_dr)
            loss_dme = criterion_dme(logits_dme, labels_dme)
            loss = loss_dr + LOSS_WEIGHT_DME * loss_dme

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        bs = imgs.size(0)
        total_loss += loss.item() * bs
        total_samples += bs

        if (step + 1) % 10 == 0:
            print(f"  [Epoch {epoch:02d}] step {step+1}/{len(train_loader)} loss={loss.item():.4f}")

    train_loss = total_loss / max(total_samples, 1)

    # -------------------------------
    # VALIDATION
    # -------------------------------
    model.eval()
    all_dr_labels  = []
    all_dr_preds   = []
    all_dme_labels = []
    all_dme_preds  = []

    with torch.no_grad():
        for imgs, labels_dr, labels_dme in val_loader:
            imgs       = imgs.to(device, non_blocking=True)
            labels_dr  = labels_dr.to(device, non_blocking=True)
            labels_dme = labels_dme.to(device, non_blocking=True)

            logits_dr, logits_dme = model(imgs)
            preds_dr  = logits_dr.argmax(dim=1)
            preds_dme = logits_dme.argmax(dim=1)

            all_dr_labels.append(labels_dr.cpu())
            all_dr_preds.append(preds_dr.cpu())
            all_dme_labels.append(labels_dme.cpu())
            all_dme_preds.append(preds_dme.cpu())

    all_dr_labels  = torch.cat(all_dr_labels).numpy()
    all_dr_preds   = torch.cat(all_dr_preds).numpy()
    all_dme_labels = torch.cat(all_dme_labels).numpy()
    all_dme_preds  = torch.cat(all_dme_preds).numpy()

    val_acc_dr = accuracy_score(all_dr_labels, all_dr_preds)
    val_f1_dr  = f1_score(all_dr_labels, all_dr_preds, average="macro")

    val_acc_dme = accuracy_score(all_dme_labels, all_dme_preds)
    val_f1_dme  = f1_score(all_dme_labels, all_dme_preds, average="macro")

    print(
        f"[Epoch {epoch:02d}] train_loss={train_loss:.4f}  "
        f"val_f1_DR={val_f1_dr:.4f}  val_acc_DR={val_acc_dr:.4f}  "
        f"val_f1_DME={val_f1_dme:.4f}  val_acc_DME={val_acc_dme:.4f}"
    )

    # -------------------------------
    # SCHEDULER STEP
    # -------------------------------
    if scheduler is not None:
        scheduler.step()

    # -------------------------------
    # EARLY STOP / CHECKPOINT
    # -------------------------------
    avg_f1 = 0.5 * (val_f1_dr + val_f1_dme)
    if avg_f1 > best_f1 + 1e-5:
        best_f1 = avg_f1
        waited = 0
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "epoch": epoch,
                "val_f1_dr":  float(val_f1_dr),
                "val_f1_dme": float(val_f1_dme),
            },
            CKPT_PATH,
        )
        print("  ↳ Saved checkpoint")
    else:
        waited += 1
        if waited >= PATIENCE:
            print(f"[EarlyStop] No improvement for {PATIENCE} epochs.")
            break


  [Epoch 01] step 10/66 loss=1.7087
  [Epoch 01] step 20/66 loss=1.5197
  [Epoch 01] step 30/66 loss=1.1206
  [Epoch 01] step 40/66 loss=0.9460
  [Epoch 01] step 50/66 loss=0.8860
  [Epoch 01] step 60/66 loss=1.8964
[Epoch 01] train_loss=1.3697  val_f1_DR=0.6381  val_acc_DR=0.6566  val_f1_DME=0.6632  val_acc_DME=0.8889
  ↳ Saved checkpoint
  [Epoch 02] step 10/66 loss=0.4349
  [Epoch 02] step 20/66 loss=0.9045
  [Epoch 02] step 30/66 loss=0.8419
  [Epoch 02] step 40/66 loss=0.4711
  [Epoch 02] step 50/66 loss=0.3548
  [Epoch 02] step 60/66 loss=0.8235
[Epoch 02] train_loss=0.6615  val_f1_DR=0.8268  val_acc_DR=0.8485  val_f1_DME=0.8254  val_acc_DME=0.9545
  ↳ Saved checkpoint
[Info] Warmup finished. Unfreezing backbone for finetuning (LR=0.0001).
  [Epoch 03] step 10/66 loss=0.3088
  [Epoch 03] step 20/66 loss=0.4332
  [Epoch 03] step 30/66 loss=0.1690
  [Epoch 03] step 40/66 loss=0.2113
  [Epoch 03] step 50/66 loss=0.3117
  [Epoch 03] step 60/66 loss=0.3699
[Epoch 03] train_loss=0.2982

11. Load best and full evaluation (val + test)

In [71]:
if CKPT_PATH.exists():
    ckpt = torch.load(CKPT_PATH, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    print(f"[Info] Loaded best checkpoint from epoch {ckpt['epoch']} | "
          f"val_f1_dr={ckpt.get('val_f1_dr', float('nan')):.4f}  "
          f"val_f1_dme={ckpt.get('val_f1_dme', float('nan')):.4f}")



[Info] Loaded best checkpoint from epoch 18 | val_f1_dr=0.9299  val_f1_dme=0.9964


In [72]:
# Validation report (DR + DME)
model.eval()
all_dr_labels  = []
all_dr_preds   = []
all_dme_labels = []
all_dme_preds  = []

with torch.no_grad():
    for imgs, labels_dr, labels_dme in val_loader:
        imgs       = imgs.to(device, non_blocking=True)
        labels_dr  = labels_dr.to(device, non_blocking=True)
        labels_dme = labels_dme.to(device, non_blocking=True)

        logits_dr, logits_dme = model(imgs)
        preds_dr  = logits_dr.argmax(dim=1)
        preds_dme = logits_dme.argmax(dim=1)

        all_dr_labels.append(labels_dr.cpu())
        all_dr_preds.append(preds_dr.cpu())
        all_dme_labels.append(labels_dme.cpu())
        all_dme_preds.append(preds_dme.cpu())

all_dr_labels  = torch.cat(all_dr_labels).numpy()
all_dr_preds   = torch.cat(all_dr_preds).numpy()
all_dme_labels = torch.cat(all_dme_labels).numpy()
all_dme_preds  = torch.cat(all_dme_preds).numpy()

val_report_dr  = classification_report(all_dr_labels,  all_dr_preds,  digits=4)
val_cm_dr      = confusion_matrix(all_dr_labels,       all_dr_preds)
val_acc_dr     = accuracy_score(all_dr_labels,         all_dr_preds)
val_f1_dr      = f1_score(all_dr_labels,               all_dr_preds,  average="macro")

val_report_dme = classification_report(all_dme_labels, all_dme_preds, digits=4)
val_cm_dme     = confusion_matrix(all_dme_labels,      all_dme_preds)
val_acc_dme    = accuracy_score(all_dme_labels,        all_dme_preds)
val_f1_dme     = f1_score(all_dme_labels,              all_dme_preds, average="macro")

print("\n===== Validation results (DR) =====")
print(val_report_dr)

print("\n===== Validation results (DME) =====")
print(val_report_dme)



===== Validation results (DR) =====
              precision    recall  f1-score   support

           0     1.0000    0.9815    0.9907        54
           1     1.0000    1.0000    1.0000        40
           2     0.9592    0.8704    0.9126        54
           3     0.7500    1.0000    0.8571        30
           4     1.0000    0.8000    0.8889        20

    accuracy                         0.9394       198
   macro avg     0.9418    0.9304    0.9299       198
weighted avg     0.9510    0.9394    0.9408       198


===== Validation results (DME) =====
              precision    recall  f1-score   support

           0     1.0000    0.9898    0.9949        98
           1     1.0000    1.0000    1.0000        10
           2     0.9890    1.0000    0.9945        90

    accuracy                         0.9949       198
   macro avg     0.9963    0.9966    0.9964       198
weighted avg     0.9950    0.9949    0.9950       198



In [73]:
# Test report (DR + DME)
all_dr_labels  = []
all_dr_preds   = []
all_dme_labels = []
all_dme_preds  = []

with torch.no_grad():
    for imgs, labels_dr, labels_dme in test_loader:
        imgs       = imgs.to(device, non_blocking=True)
        labels_dr  = labels_dr.to(device, non_blocking=True)
        labels_dme = labels_dme.to(device, non_blocking=True)

        logits_dr, logits_dme = model(imgs)
        preds_dr  = logits_dr.argmax(dim=1)
        preds_dme = logits_dme.argmax(dim=1)

        all_dr_labels.append(labels_dr.cpu())
        all_dr_preds.append(preds_dr.cpu())
        all_dme_labels.append(labels_dme.cpu())
        all_dme_preds.append(preds_dme.cpu())

all_dr_labels  = torch.cat(all_dr_labels).numpy()
all_dr_preds   = torch.cat(all_dr_preds).numpy()
all_dme_labels = torch.cat(all_dme_labels).numpy()
all_dme_preds  = torch.cat(all_dme_preds).numpy()

test_report_dr  = classification_report(all_dr_labels,  all_dr_preds,  digits=4)
test_cm_dr      = confusion_matrix(all_dr_labels,       all_dr_preds)
test_acc_dr     = accuracy_score(all_dr_labels,         all_dr_preds)
test_f1_dr      = f1_score(all_dr_labels,               all_dr_preds,  average="macro")

test_report_dme = classification_report(all_dme_labels, all_dme_preds, digits=4)
test_cm_dme     = confusion_matrix(all_dme_labels,      all_dme_preds)
test_acc_dme    = accuracy_score(all_dme_labels,        all_dme_preds)
test_f1_dme     = f1_score(all_dme_labels,              all_dme_preds, average="macro")


print("===== Test results (DR) =====")
print(test_report_dr)

print("===== Test results (DME) =====")
print(test_report_dme)


===== Test results (DR) =====
              precision    recall  f1-score   support

           0     0.6304    0.8529    0.7250        34
           1     0.0000    0.0000    0.0000         5
           2     0.5405    0.6250    0.5797        32
           3     0.5000    0.3158    0.3871        19
           4     0.4286    0.2308    0.3000        13

    accuracy                         0.5631       103
   macro avg     0.4199    0.4049    0.3984       103
weighted avg     0.5224    0.5631    0.5287       103

===== Test results (DME) =====
              precision    recall  f1-score   support

           0     0.8039    0.9111    0.8542        45
           1     0.3750    0.3000    0.3333        10
           2     0.8864    0.8125    0.8478        48

    accuracy                         0.8058       103
   macro avg     0.6884    0.6745    0.6784       103
weighted avg     0.8007    0.8058    0.8006       103



In [74]:
metrics_out = {
    "val": {
        "dr": {
            "acc": float(val_acc_dr),
            "f1_macro": float(val_f1_dr),
            "confusion_matrix": val_cm_dr.tolist(),
        },
        "dme": {
            "acc": float(val_acc_dme),
            "f1_macro": float(val_f1_dme),
            "confusion_matrix": val_cm_dme.tolist(),
        }
    },
    "test": {
        "dr": {
            "acc": float(test_acc_dr),
            "f1_macro": float(test_f1_dr),
            "confusion_matrix": test_cm_dr.tolist(),
        },
        "dme": {
            "acc": float(test_acc_dme),
            "f1_macro": float(test_f1_dme),
            "confusion_matrix": test_cm_dme.tolist(),
        }
    }
}

RUN_DIR.mkdir(parents=True, exist_ok=True)
with open(RUN_DIR / "metrics_simple.json", "w") as f:
    json.dump(metrics_out, f, indent=2)

print(f"[Info] Metrics saved to {RUN_DIR / 'metrics_simple.json'}")
print("[Done] Training + evaluation complete.")


[Info] Metrics saved to runs\idrid_simple\metrics_simple.json
[Done] Training + evaluation complete.
