# Final Model Comparison – DenseNet121 Knee OA KL Grading
This notebook compares the baseline DenseNet121 trained at 320×320 input resolution with the improved final model trained at 384×384.

In [None]:
# ============================================
# AUTO DATASET LOADER (Drive ZIP) + DenseNet121@320 Tricks
# - Drive mount + unzip
# - auto-detect train/val/test
# - WeightedRandomSampler
# - MixUp
# - EMA
# - Cosine LR + warmup
# - AMP
# ============================================

# ----------------------------
# 0) Imports
# ----------------------------
import os, math, time, copy, random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, WeightedRandomSampler

from torchvision import datasets, transforms, models

from sklearn.metrics import f1_score, confusion_matrix, classification_report, accuracy_score

# ----------------------------
# 1) Reproducibility + Device
# ----------------------------
def seed_everything(seed=42, deterministic=False):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Integer seed applied to `random`, `numpy.random` and the
            PyTorch CPU / CUDA generators.
        deterministic: When False (the default, identical to the previous
            hard-coded behaviour) cuDNN autotuning is enabled and
            deterministic kernels are not requested. When True, autotuning
            is disabled and deterministic cuDNN kernels are requested, for
            runs that need to be repeatable.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The two cuDNN switches are kept opposite to each other, as in the
    # original code: benchmark on <=> deterministic off.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = bool(deterministic)

# Seed all RNGs before any dataset split, sampler or model is created.
seed_everything(42)

# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ==========================================================
# 2) DATASET LOADING FROM GOOGLE DRIVE ZIP
# ==========================================================
DATASET_NAME = "knee-osteoarthritis-dataset-with-severity"  # folder name after unzip (can be different)
ZIP_NAME = "knee_oa_dataset.zip"  # <-- CHANGE to your zip name in MyDrive

# Mount Drive
from google.colab import drive
drive.mount('/content/drive')

# The archive is expected directly under MyDrive; fail early with a hint if it is missing.
zip_path = f"/content/drive/MyDrive/{ZIP_NAME}"
if not os.path.isfile(zip_path):
    raise FileNotFoundError(f"ZIP file not found at: {zip_path}\n"
                            f"➡️ Upload your dataset zip to Google Drive > MyDrive and set ZIP_NAME correctly.")

# Unzip if needed
extract_root = "/content"
dataset_root = os.path.join(extract_root, DATASET_NAME)

# Extraction is skipped when the DATASET_NAME folder already exists under extract_root.
if not os.path.isdir(dataset_root):
    print("Unzipping dataset...")
    !unzip -q "{zip_path}" -d "{extract_root}"
else:
    print("Dataset folder already exists, skipping unzip.")

print("✅ Dataset ready")
print("Content of /content:", os.listdir("/content"))

# ==========================================================
# 3) Auto-detect train/val/test folders
# ==========================================================
def find_split_dirs(base_dir):
    """Locate the dataset split folders beneath `base_dir`.

    The folders `train` and `test` are looked for directly inside
    `base_dir` first; if they are not both there, the tree is walked and
    the first directory containing both is used.

    Returns:
        `(train_dir, val_dir, test_dir)`; `val_dir` is None when no `val`
        folder sits next to `train` and `test`.

    Raises:
        FileNotFoundError: no directory with both `train` and `test` exists.
    """
    def _child(parent, name):
        return os.path.join(parent, name)

    top_train, top_val, top_test = (
        _child(base_dir, split) for split in ("train", "val", "test")
    )
    if os.path.isdir(top_train) and os.path.isdir(top_test):
        if not os.path.isdir(top_val):
            top_val = None
        return top_train, top_val, top_test

    for parent, subdirs, _files in os.walk(base_dir):
        if "train" not in subdirs or "test" not in subdirs:
            continue
        nested_val = _child(parent, "val") if "val" in subdirs else None
        return _child(parent, "train"), nested_val, _child(parent, "test")

    raise FileNotFoundError(f"Could not find train/test folders under: {base_dir}")

# Resolve the split folders once; VAL_DIR is None when no "val" folder was found.
TRAIN_DIR, VAL_DIR, TEST_DIR = find_split_dirs(dataset_root)
print("TRAIN_DIR:", TRAIN_DIR)
print("VAL_DIR:", VAL_DIR)
print("TEST_DIR:", TEST_DIR)

# If val doesn't exist, create one from train
AUTO_SPLIT_VAL_IF_MISSING = True
VAL_SPLIT = 0.15  # 15% validation

# ==========================================================
# 4) Config
# ==========================================================
NUM_CLASSES = 5
IMG_SIZE = 320
BATCH_SIZE = 16         # if OOM: 8
EPOCHS = 15             # 12–20
LR = 1e-4
WEIGHT_DECAY = 1e-4

WARMUP_EPOCHS = 2
MIXUP_ALPHA = 0.2
EMA_DECAY = 0.999

USE_AMP = True
GRAD_ACCUM_STEPS = 1    # if batch small: 2 or 4
CHECKPOINT_PATH = "best_densenet121_320_tricks.pt"

# ==========================================================
# 5) Transforms
# ==========================================================
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Training pipeline: fixed-size resize, random flip / small rotation, then
# normalisation with the ImageNet channel statistics above.
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=7),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation pipeline: same resize and normalisation, no random augmentation.
eval_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# ==========================================================
# 6) Datasets
# ==========================================================
train_full = datasets.ImageFolder(TRAIN_DIR, transform=train_tfms)
test_ds    = datasets.ImageFolder(TEST_DIR, transform=eval_tfms)

if VAL_DIR is not None:
    # A dedicated validation folder exists: use it and keep all of train for training.
    val_ds = datasets.ImageFolder(VAL_DIR, transform=eval_tfms)
    train_ds = train_full
else:
    if not AUTO_SPLIT_VAL_IF_MISSING:
        raise FileNotFoundError("VAL_DIR missing and AUTO_SPLIT_VAL_IF_MISSING=False")

    from torch.utils.data import Subset
    # Random (non-stratified) split of the training indices; the shuffle
    # draws from NumPy's global RNG.
    n = len(train_full)
    idx = np.arange(n)
    np.random.shuffle(idx)
    split = int(n * (1 - VAL_SPLIT))
    train_idx, val_idx = idx[:split], idx[split:]

    train_ds = Subset(train_full, train_idx)

    # For val subset, use eval transforms
    # (a second ImageFolder over TRAIN_DIR so the held-out indices are read
    # through eval_tfms instead of the training augmentation)
    train_full_eval = datasets.ImageFolder(TRAIN_DIR, transform=eval_tfms)
    val_ds = Subset(train_full_eval, val_idx)

print("Train size:", len(train_ds), "Val size:", len(val_ds), "Test size:", len(test_ds))
print("Classes:", train_full.classes)

# ==========================================================
# 7) WeightedRandomSampler
# ==========================================================
def get_targets(ds):
    """Return the class label of every sample in `ds`, in dataset order.

    `ds` is either an object exposing `.samples` (a sequence of
    `(item, label)` pairs) or a `torch.utils.data.Subset` wrapping such an
    object, in which case only the subset's indices are reported.
    """
    if not isinstance(ds, torch.utils.data.Subset):
        return [label for _item, label in ds.samples]

    # Subset: look the labels up in the wrapped dataset via the kept indices.
    parent_samples = ds.dataset.samples
    return [parent_samples[i][1] for i in ds.indices]

# Inverse-frequency weighting: every sample is weighted by 1 / (count of its class).
targets = get_targets(train_ds)
class_counts = np.bincount(targets, minlength=NUM_CLASSES)
# np.maximum(..., 1) avoids division by zero for a class with no training samples.
class_weights = 1.0 / np.maximum(class_counts, 1)
sample_weights = [class_weights[y] for y in targets]

# Draw len(train_ds) samples per epoch, with replacement, according to the weights.
sampler = WeightedRandomSampler(
    weights=torch.DoubleTensor(sample_weights),
    num_samples=len(sample_weights),
    replacement=True
)

# Only the training loader uses the weighted sampler; val/test iterate in fixed order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

print("Train class counts:", class_counts)

# ==========================================================
# 8) Model: DenseNet121
# ==========================================================
# Start from ImageNet-pretrained DenseNet121 and swap the classifier head
# for a fresh NUM_CLASSES-way linear layer.
model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
in_features = model.classifier.in_features
model.classifier = nn.Linear(in_features, NUM_CLASSES)
model = model.to(device)

# ==========================================================
# 9) EMA
# ==========================================================
class EMA:
    """Exponential moving average of a model's trainable parameters.

    Only parameters with `requires_grad=True` are tracked; buffers are not.
    Typical use: call `update` after each optimizer step, then wrap
    evaluation in `apply_to` / `restore` to temporarily run the model with
    the averaged weights.

    Attributes (documented here rather than inline):
        decay:  averaging factor; shadow = decay * shadow + (1 - decay) * param.
        shadow: dict name -> averaged tensor.
        backup: dict name -> raw tensor while EMA weights are applied,
                otherwise None.
    """

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {}
        # Defined up front so `restore` can detect a missing `apply_to`.
        self.backup = None
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[name] = p.data.clone()

    @torch.no_grad()
    def update(self, model):
        """Blend the model's current parameters into the shadow copy."""
        for name, p in model.named_parameters():
            if p.requires_grad:
                shadow = self.shadow[name]
                if shadow.device != p.device:
                    # `shadow` may have been replaced by tensors loaded on
                    # another device (e.g. a checkpoint mapped to CPU).
                    shadow = self.shadow[name] = shadow.to(p.device)
                shadow.mul_(self.decay).add_(p.data, alpha=1 - self.decay)

    @torch.no_grad()
    def apply_to(self, model):
        """Copy the shadow weights into `model`, remembering the raw weights.

        Calling this again before `restore` keeps the first backup, so the
        raw weights are never overwritten by already-averaged ones.
        """
        take_backup = self.backup is None
        if take_backup:
            self.backup = {}
        for name, p in model.named_parameters():
            if p.requires_grad:
                if take_backup:
                    self.backup[name] = p.data.clone()
                p.data.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model):
        """Put back the raw weights saved by the last `apply_to`.

        Raises:
            RuntimeError: if there is nothing to restore.
        """
        if self.backup is None:
            raise RuntimeError("EMA.restore() called without a preceding EMA.apply_to()")
        for name, p in model.named_parameters():
            if p.requires_grad:
                p.data.copy_(self.backup[name])
        self.backup = None

# Shadow copy starts from the model's current (just-initialised) weights.
ema = EMA(model, decay=EMA_DECAY)

# ==========================================================
# 10) MixUp
# ==========================================================
def mixup_data(x, y, alpha=0.2):
    """Blend each sample in the batch with a randomly chosen partner.

    Args:
        x: Input batch, indexed along dim 0.
        y: Labels, indexed along dim 0 like `x`.
        alpha: Beta(alpha, alpha) parameter for the mixing weight; a value
            <= 0 disables mixing.

    Returns:
        `(mixed_x, y_a, y_b, lam)`, where `y_a` are the original labels,
        `y_b` the partners' labels and `lam` the weight given to the
        original sample. With mixing disabled: `(x, y, y, 1.0)`.
    """
    if alpha <= 0:
        return x, y, y, 1.0

    # One mixing weight for the whole batch, then one permutation for partners.
    mix_weight = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0)).to(x.device)

    partner_x = x[perm]
    partner_y = y[perm]
    blended = mix_weight * x + (1 - mix_weight) * partner_x
    return blended, y, partner_y, mix_weight

def mixup_criterion(crit, pred, y_a, y_b, lam):
    """Combine `crit` evaluated against both label sets, weighted by `lam`.

    Returns `lam * crit(pred, y_a) + (1 - lam) * crit(pred, y_b)`.
    """
    loss_a = crit(pred, y_a)
    loss_b = crit(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

# ==========================================================
# 11) Loss, Optimizer, Scheduler (Cosine + warmup)
# ==========================================================
# Plain cross-entropy; class imbalance is handled by the weighted sampler,
# not by loss weights.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

def lr_lambda(epoch):
    """Return the LR multiplier for the 0-based `epoch`.

    For the first WARMUP_EPOCHS epochs the multiplier rises linearly as
    (epoch + 1) / WARMUP_EPOCHS; afterwards it follows a half-cosine
    from 1.0 towards 0.0 over the remaining EPOCHS - WARMUP_EPOCHS epochs.
    """
    if epoch < WARMUP_EPOCHS:
        warmup_len = float(max(1, WARMUP_EPOCHS))
        return float(epoch + 1) / warmup_len

    decay_len = float(max(1, EPOCHS - WARMUP_EPOCHS))
    progress = (epoch - WARMUP_EPOCHS) / decay_len
    return 0.5 * (1.0 + math.cos(math.pi * progress))

# The scheduler is stepped once per epoch; `lr_lambda` scales the base LR.
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
# `torch.cuda.amp.GradScaler(...)` is deprecated and emits a FutureWarning;
# the device-generic `torch.amp.GradScaler("cuda", ...)` is its replacement.
# Scaling is a no-op unless AMP is requested and a CUDA device is in use.
scaler = torch.amp.GradScaler("cuda", enabled=(USE_AMP and device.type == "cuda"))

# ==========================================================
# 12) Eval function
# ==========================================================
@torch.no_grad()
def evaluate(model, loader):
    """Run `model` over `loader` in eval mode and compute loss and metrics.

    Relies on the module-level `device`, `criterion` and `USE_AMP`.

    Args:
        model: Network to evaluate; it is left in eval mode on return.
        loader: Iterable yielding `(inputs, labels)` batches.

    Returns:
        `(mean_loss, accuracy, macro_f1, confusion_matrix, targets, preds)`;
        `targets` and `preds` are NumPy arrays concatenated over all batches.
    """
    model.eval()
    all_preds, all_targets = [], []
    total_loss = 0.0
    n_seen = 0

    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        # torch.cuda.amp.autocast is deprecated; torch.amp.autocast("cuda", ...)
        # is the replacement and is equally a no-op when disabled.
        with torch.amp.autocast("cuda", enabled=(USE_AMP and device.type == "cuda")):
            logits = model(x)
            loss = criterion(logits, y)

        # `loss` is a batch mean, so weight it by the batch size.
        total_loss += loss.item() * x.size(0)
        n_seen += x.size(0)
        preds = torch.argmax(logits, dim=1)

        all_preds.append(preds.cpu().numpy())
        all_targets.append(y.cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)

    acc = accuracy_score(all_targets, all_preds)
    macro_f1 = f1_score(all_targets, all_preds, average="macro")
    cm = confusion_matrix(all_targets, all_preds)

    # Normalise by the samples actually evaluated, which stays correct even
    # if the loader uses a sampler or drops batches.
    return total_loss / n_seen, acc, macro_f1, cm, all_targets, all_preds

# ==========================================================
# 13) Training
# ==========================================================
# Best validation macro-F1 seen so far; a checkpoint is written on each improvement.
best_val_f1 = -1.0

for epoch in range(EPOCHS):
    t0 = time.time()
    model.train()
    running_loss = 0.0

    optimizer.zero_grad(set_to_none=True)

    for step, (x, y) in enumerate(train_loader):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        if MIXUP_ALPHA > 0:
            x, y_a, y_b, lam = mixup_data(x, y, alpha=MIXUP_ALPHA)

        # torch.cuda.amp.autocast is deprecated; torch.amp.autocast("cuda", ...)
        # is the replacement.
        with torch.amp.autocast("cuda", enabled=(USE_AMP and device.type == "cuda")):
            logits = model(x)
            if MIXUP_ALPHA > 0:
                loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
            else:
                loss = criterion(logits, y)

            # Scale so gradients accumulated over several micro-batches average out.
            loss = loss / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()

        # Step after every full accumulation group, and also on the last
        # batch: otherwise a trailing partial group's gradients would be
        # thrown away by the zero_grad() at the start of the next epoch.
        is_last_batch = (step + 1) == len(train_loader)
        if (step + 1) % GRAD_ACCUM_STEPS == 0 or is_last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            ema.update(model)

        # Undo the accumulation scaling so the logged loss is a per-sample mean.
        running_loss += loss.item() * x.size(0) * GRAD_ACCUM_STEPS

    # Record the LR that was actually used for this epoch *before* stepping
    # the scheduler; after scheduler.step() get_last_lr() already reports
    # the next epoch's value.
    epoch_lr = optimizer.param_groups[0]["lr"]
    scheduler.step()

    train_loss = running_loss / len(train_loader.dataset)

    # Validate using EMA weights
    ema.apply_to(model)
    val_loss, val_acc, val_f1, val_cm, _, _ = evaluate(model, val_loader)
    ema.restore(model)

    dt = time.time() - t0
    print(f"Epoch {epoch+1:02d}/{EPOCHS} | "
          f"train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | "
          f"val_acc={val_acc:.4f} | val_macroF1={val_f1:.4f} | "
          f"lr={epoch_lr:.2e} | time={dt:.1f}s")

    if val_f1 > best_val_f1:
        # Store a plain Python float so the checkpoint contains no NumPy
        # scalar objects.
        best_val_f1 = float(val_f1)
        # "model" holds the raw (non-EMA) weights, since ema.restore() has
        # already run; the averaged weights are stored under "ema_shadow".
        ckpt = {
            "epoch": epoch + 1,
            "model": copy.deepcopy(model.state_dict()),
            "ema_shadow": copy.deepcopy(ema.shadow),
            "best_val_macro_f1": best_val_f1,
            "img_size": IMG_SIZE
        }
        torch.save(ckpt, CHECKPOINT_PATH)
        print(f"  ✅ Saved best checkpoint: {CHECKPOINT_PATH} (val_macroF1={best_val_f1:.4f})")

print("\n✅ Best Val Macro-F1:", best_val_f1)

# ==========================================================
# 14) Test Best Checkpoint (EMA)
# ==========================================================
# Reload the best checkpoint: raw weights go into the model, the averaged
# weights into the EMA helper.
ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(ckpt["model"])
model = model.to(device)
ema.shadow = ckpt["ema_shadow"]

ema.apply_to(model)
test_loss, test_acc, test_f1, test_cm, y_true, y_pred = evaluate(model, test_loader)
# Snapshot the state dict while the EMA weights are still applied, so the
# exported file holds exactly the weights that produced the metrics below.
# (Previously the export happened after ema.restore() and therefore
# contained the raw, non-EMA weights.)
ema_state_dict = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
ema.restore(model)

print("\n===== TEST RESULTS (Best EMA) =====")
print(f"Test Loss:     {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test Macro-F1: {test_f1:.4f}")
print("\nConfusion Matrix:\n", test_cm)
print("\nClassification Report:\n", classification_report(y_true, y_pred, digits=4))

# Save plain state dict too (EMA weights, loadable with model.load_state_dict)
torch.save(ema_state_dict, "best_densenet121_320_tricks_state_dict.pt")
print("\nSaved: best_densenet121_320_tricks_state_dict.pt")

Device: cuda
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Dataset folder already exists, skipping unzip.
✅ Dataset ready
Content of /content: ['.config', 'knee-osteoarthritis-dataset-with-severity', 'best_densenet121_320_tricks.pt', 'drive', 'sample_data']
TRAIN_DIR: /content/knee-osteoarthritis-dataset-with-severity/train
VAL_DIR: /content/knee-osteoarthritis-dataset-with-severity/val
TEST_DIR: /content/knee-osteoarthritis-dataset-with-severity/test
Train size: 5778 Val size: 826 Test size: 1656
Classes: ['0', '1', '2', '3', '4']
Train class counts: [2286 1046 1516  757  173]


  scaler = torch.cuda.amp.GradScaler(enabled=(USE_AMP and device.type == "cuda"))
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 01/15 | train_loss=1.2257 | val_loss=1.5756 | val_acc=0.1961 | val_macroF1=0.0849 | lr=1.00e-04 | time=69.8s
  ✅ Saved best checkpoint: best_densenet121_320_tricks.pt (val_macroF1=0.0849)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 02/15 | train_loss=1.0074 | val_loss=1.5009 | val_acc=0.4540 | val_macroF1=0.3096 | lr=1.00e-04 | time=68.7s
  ✅ Saved best checkpoint: best_densenet121_320_tricks.pt (val_macroF1=0.3096)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 03/15 | train_loss=0.9287 | val_loss=1.4435 | val_acc=0.4915 | val_macroF1=0.3330 | lr=9.85e-05 | time=68.4s
  ✅ Saved best checkpoint: best_densenet121_320_tricks.pt (val_macroF1=0.3330)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 04/15 | train_loss=0.8716 | val_loss=1.2669 | val_acc=0.5690 | val_macroF1=0.4391 | lr=9.43e-05 | time=69.7s
  ✅ Saved best checkpoint: best_densenet121_320_tricks.pt (val_macroF1=0.4391)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 05/15 | train_loss=0.8673 | val_loss=1.0549 | val_acc=0.5400 | val_macroF1=0.6200 | lr=8.74e-05 | time=69.1s
  ✅ Saved best checkpoint: best_densenet121_320_tricks.pt (val_macroF1=0.6200)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 06/15 | train_loss=0.8382 | val_loss=0.9809 | val_acc=0.6017 | val_macroF1=0.6080 | lr=7.84e-05 | time=69.9s


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 07/15 | train_loss=0.7874 | val_loss=0.9490 | val_acc=0.5860 | val_macroF1=0.6196 | lr=6.77e-05 | time=68.7s


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 08/15 | train_loss=0.7452 | val_loss=0.9197 | val_acc=0.5751 | val_macroF1=0.6405 | lr=5.60e-05 | time=69.1s
  ✅ Saved best checkpoint: best_densenet121_320_tricks.pt (val_macroF1=0.6405)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 09/15 | train_loss=0.7004 | val_loss=0.9243 | val_acc=0.5847 | val_macroF1=0.6548 | lr=4.40e-05 | time=69.2s
  ✅ Saved best checkpoint: best_densenet121_320_tricks.pt (val_macroF1=0.6548)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 10/15 | train_loss=0.6687 | val_loss=0.8673 | val_acc=0.6053 | val_macroF1=0.6761 | lr=3.23e-05 | time=69.6s
  ✅ Saved best checkpoint: best_densenet121_320_tricks.pt (val_macroF1=0.6761)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 11/15 | train_loss=0.6026 | val_loss=0.8749 | val_acc=0.6114 | val_macroF1=0.6790 | lr=2.16e-05 | time=69.8s
  ✅ Saved best checkpoint: best_densenet121_320_tricks.pt (val_macroF1=0.6790)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 12/15 | train_loss=0.5973 | val_loss=0.8310 | val_acc=0.6368 | val_macroF1=0.6938 | lr=1.26e-05 | time=69.1s
  ✅ Saved best checkpoint: best_densenet121_320_tricks.pt (val_macroF1=0.6938)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 13/15 | train_loss=0.5854 | val_loss=0.8337 | val_acc=0.6683 | val_macroF1=0.7065 | lr=5.73e-06 | time=68.6s
  ✅ Saved best checkpoint: best_densenet121_320_tricks.pt (val_macroF1=0.7065)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 14/15 | train_loss=0.5841 | val_loss=0.8591 | val_acc=0.6344 | val_macroF1=0.6968 | lr=1.45e-06 | time=69.6s


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 15/15 | train_loss=0.5359 | val_loss=0.8506 | val_acc=0.6586 | val_macroF1=0.7076 | lr=0.00e+00 | time=69.3s
  ✅ Saved best checkpoint: best_densenet121_320_tricks.pt (val_macroF1=0.7076)

✅ Best Val Macro-F1: 0.7075860183108729


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):



===== TEST RESULTS (Best EMA) =====
Test Loss:     0.7971
Test Accuracy: 0.6739
Test Macro-F1: 0.6914

Confusion Matrix:
 [[503 125  10   1   0]
 [111 138  43   4   0]
 [ 27 135 251  34   0]
 [  1  10  22 184   6]
 [  0   0   0  11  40]]

Classification Report:
               precision    recall  f1-score   support

           0     0.7835    0.7872    0.7853       639
           1     0.3382    0.4662    0.3920       296
           2     0.7699    0.5615    0.6494       447
           3     0.7863    0.8251    0.8053       223
           4     0.8696    0.7843    0.8247        51

    accuracy                         0.6739      1656
   macro avg     0.7095    0.6849    0.6914      1656
weighted avg     0.7033    0.6739    0.6822      1656


Saved: best_densenet121_320_tricks_state_dict.pt


In [None]:
# ============================================
# DenseNet121 @384 with:
# - Drive ZIP auto-load + unzip
# - auto-detect train/val/test
# - WeightedRandomSampler
# - MixUp schedule (on first half, off second half)
# - EMA
# - Cosine LR + warmup
# - AMP
# - Label Smoothing CE
# Target: push test accuracy > 0.70
# ============================================

# ----------------------------
# 0) Imports
# ----------------------------
import os, math, time, copy, random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, WeightedRandomSampler

from torchvision import datasets, transforms, models

from sklearn.metrics import f1_score, confusion_matrix, classification_report, accuracy_score

# ----------------------------
# 1) Reproducibility + Device
# ----------------------------
def seed_everything(seed=42, deterministic=False):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Integer seed applied to `random`, `numpy.random` and the
            PyTorch CPU / CUDA generators.
        deterministic: When False (the default, identical to the previous
            hard-coded behaviour) cuDNN autotuning is enabled and
            deterministic kernels are not requested. When True, autotuning
            is disabled and deterministic cuDNN kernels are requested, for
            runs that need to be repeatable.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The two cuDNN switches are kept opposite to each other, as in the
    # original code: benchmark on <=> deterministic off.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = bool(deterministic)

# Re-seed so this cell's run starts from the same RNG state as the baseline cell.
seed_everything(42)

# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ==========================================================
# 2) DATASET LOADING FROM GOOGLE DRIVE ZIP
# ==========================================================
DATASET_NAME = "knee-osteoarthritis-dataset-with-severity"  # folder name after unzip
ZIP_NAME = "knee_oa_dataset.zip"  # <-- CHANGE if your zip name differs

# Mount Google Drive (Colab only) so the archive under MyDrive is reachable.
from google.colab import drive
drive.mount('/content/drive')

# Fail early with a hint if the archive is not where it is expected.
zip_path = f"/content/drive/MyDrive/{ZIP_NAME}"
if not os.path.isfile(zip_path):
    raise FileNotFoundError(f"ZIP file not found at: {zip_path}\n"
                            f"➡️ Upload dataset zip to Google Drive > MyDrive and set ZIP_NAME correctly.")

extract_root = "/content"
dataset_root = os.path.join(extract_root, DATASET_NAME)

# Extraction is skipped when the DATASET_NAME folder already exists under extract_root.
if not os.path.isdir(dataset_root):
    print("Unzipping dataset...")
    !unzip -q "{zip_path}" -d "{extract_root}"
else:
    print("Dataset folder already exists, skipping unzip.")

print("✅ Dataset ready")
print("Content of /content:", os.listdir("/content"))

# ==========================================================
# 3) Auto-detect train/val/test
# ==========================================================
def find_split_dirs(base_dir):
    """Find the `train` / `val` / `test` folders under `base_dir`.

    `base_dir` itself is checked first; failing that, the directory tree is
    walked and the first folder that contains both `train` and `test` wins.

    Returns:
        `(train_dir, val_dir, test_dir)`, with `val_dir` set to None when
        there is no sibling `val` folder.

    Raises:
        FileNotFoundError: when no folder holds both `train` and `test`.
    """
    candidate = {
        split: os.path.join(base_dir, split) for split in ("train", "val", "test")
    }
    if os.path.isdir(candidate["train"]) and os.path.isdir(candidate["test"]):
        val_here = candidate["val"] if os.path.isdir(candidate["val"]) else None
        return candidate["train"], val_here, candidate["test"]

    for folder, children, _ in os.walk(base_dir):
        has_both = "train" in children and "test" in children
        if has_both:
            val_here = os.path.join(folder, "val") if "val" in children else None
            return os.path.join(folder, "train"), val_here, os.path.join(folder, "test")

    raise FileNotFoundError(f"Could not find train/test folders under: {base_dir}")

# Resolve the split folders once; VAL_DIR is None when no "val" folder was found.
TRAIN_DIR, VAL_DIR, TEST_DIR = find_split_dirs(dataset_root)
print("TRAIN_DIR:", TRAIN_DIR)
print("VAL_DIR:", VAL_DIR)
print("TEST_DIR:", TEST_DIR)

# When no validation folder exists, hold out VAL_SPLIT of the training data.
AUTO_SPLIT_VAL_IF_MISSING = True
VAL_SPLIT = 0.15

# ==========================================================
# 4) Config (UPDATED)
# ==========================================================
NUM_CLASSES = 5

IMG_SIZE = 384          # ✅ increased from 320
BATCH_SIZE = 8          # ✅ reduce for T4 safety at 384
EPOCHS = 18             # ✅ slightly longer

LR = 1e-4
WEIGHT_DECAY = 1e-4

WARMUP_EPOCHS = 2
MIXUP_ALPHA = 0.2       # base mixup alpha for first half (schedule below)
EMA_DECAY = 0.999

USE_AMP = True
GRAD_ACCUM_STEPS = 1    # if you want effective batch 16, set 2
CHECKPOINT_PATH = "best_densenet121_384_ls_mixupSchedule_ema.pt"

LABEL_SMOOTHING = 0.07  # ✅ helps KL1/adjacent confusion

# ==========================================================
# 5) Transforms
# ==========================================================
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Training pipeline: fixed-size resize, random flip / small rotation, then
# normalisation with the ImageNet channel statistics above.
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=7),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation pipeline: same resize and normalisation, no random augmentation.
eval_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# ==========================================================
# 6) Datasets
# ==========================================================
train_full = datasets.ImageFolder(TRAIN_DIR, transform=train_tfms)
test_ds    = datasets.ImageFolder(TEST_DIR, transform=eval_tfms)

if VAL_DIR is not None:
    # A dedicated validation folder exists: use it and keep all of train for training.
    val_ds = datasets.ImageFolder(VAL_DIR, transform=eval_tfms)
    train_ds = train_full
else:
    if not AUTO_SPLIT_VAL_IF_MISSING:
        raise FileNotFoundError("VAL_DIR missing and AUTO_SPLIT_VAL_IF_MISSING=False")
    from torch.utils.data import Subset
    # Random (non-stratified) split of the training indices; the shuffle
    # draws from NumPy's global RNG.
    n = len(train_full)
    idx = np.arange(n)
    np.random.shuffle(idx)
    split = int(n * (1 - VAL_SPLIT))
    train_idx, val_idx = idx[:split], idx[split:]
    train_ds = Subset(train_full, train_idx)

    # Second ImageFolder over TRAIN_DIR so the held-out indices are read
    # through eval_tfms instead of the training augmentation.
    train_full_eval = datasets.ImageFolder(TRAIN_DIR, transform=eval_tfms)
    val_ds = Subset(train_full_eval, val_idx)

print("Train size:", len(train_ds), "Val size:", len(val_ds), "Test size:", len(test_ds))
print("Classes:", train_full.classes)

# ==========================================================
# 7) WeightedRandomSampler
# ==========================================================
def get_targets(ds):
    """Collect the class labels of `ds` in dataset order.

    Accepts an object with a `.samples` sequence of `(item, label)` pairs,
    or a `torch.utils.data.Subset` of one (labels are then taken for the
    subset's indices only).
    """
    wrapped = isinstance(ds, torch.utils.data.Subset)
    if wrapped:
        source = ds.dataset.samples
        labels = [source[i][1] for i in ds.indices]
    else:
        labels = [label for _item, label in ds.samples]
    return labels

# Inverse-frequency weighting: every sample is weighted by 1 / (count of its class).
targets = get_targets(train_ds)
class_counts = np.bincount(targets, minlength=NUM_CLASSES)
# np.maximum(..., 1) avoids division by zero for a class with no training samples.
class_weights = 1.0 / np.maximum(class_counts, 1)
sample_weights = [class_weights[y] for y in targets]

# Draw len(train_ds) samples per epoch, with replacement, according to the weights.
sampler = WeightedRandomSampler(
    weights=torch.DoubleTensor(sample_weights),
    num_samples=len(sample_weights),
    replacement=True
)

# Only the training loader uses the weighted sampler; val/test iterate in fixed order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

print("Train class counts:", class_counts)

# ==========================================================
# 8) Model: DenseNet121
# ==========================================================
# Start from ImageNet-pretrained DenseNet121 and swap the classifier head
# for a fresh NUM_CLASSES-way linear layer.
model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
in_features = model.classifier.in_features
model.classifier = nn.Linear(in_features, NUM_CLASSES)
model = model.to(device)

# ==========================================================
# 9) EMA
# ==========================================================
class EMA:
    """Exponential moving average of a model's trainable parameters.

    Only parameters with `requires_grad=True` are tracked; buffers are not.
    Typical use: call `update` after each optimizer step, then wrap
    evaluation in `apply_to` / `restore` to temporarily run the model with
    the averaged weights.

    Attributes (documented here rather than inline):
        decay:  averaging factor; shadow = decay * shadow + (1 - decay) * param.
        shadow: dict name -> averaged tensor.
        backup: dict name -> raw tensor while EMA weights are applied,
                otherwise None.
    """

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {}
        # Defined up front so `restore` can detect a missing `apply_to`.
        self.backup = None
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[name] = p.data.clone()

    @torch.no_grad()
    def update(self, model):
        """Blend the model's current parameters into the shadow copy."""
        for name, p in model.named_parameters():
            if p.requires_grad:
                shadow = self.shadow[name]
                if shadow.device != p.device:
                    # `shadow` may have been replaced by tensors loaded on
                    # another device (e.g. a checkpoint mapped to CPU).
                    shadow = self.shadow[name] = shadow.to(p.device)
                shadow.mul_(self.decay).add_(p.data, alpha=1 - self.decay)

    @torch.no_grad()
    def apply_to(self, model):
        """Copy the shadow weights into `model`, remembering the raw weights.

        Calling this again before `restore` keeps the first backup, so the
        raw weights are never overwritten by already-averaged ones.
        """
        take_backup = self.backup is None
        if take_backup:
            self.backup = {}
        for name, p in model.named_parameters():
            if p.requires_grad:
                if take_backup:
                    self.backup[name] = p.data.clone()
                p.data.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model):
        """Put back the raw weights saved by the last `apply_to`.

        Raises:
            RuntimeError: if there is nothing to restore.
        """
        if self.backup is None:
            raise RuntimeError("EMA.restore() called without a preceding EMA.apply_to()")
        for name, p in model.named_parameters():
            if p.requires_grad:
                p.data.copy_(self.backup[name])
        self.backup = None

# Shadow copy starts from the model's current (just-initialised) weights.
ema = EMA(model, decay=EMA_DECAY)

# ==========================================================
# 10) MixUp
# ==========================================================
def mixup_data(x, y, alpha=0.2):
    """Mix every sample of the batch with a randomly permuted partner.

    Args:
        x: Input batch, indexed along dim 0.
        y: Labels, indexed along dim 0 like `x`.
        alpha: Beta(alpha, alpha) parameter for the mixing weight; values
            <= 0 switch mixing off.

    Returns:
        `(mixed_x, y_a, y_b, lam)`: the blended inputs, the original labels,
        the partners' labels, and the weight of the original sample. When
        mixing is off the result is `(x, y, y, 1.0)`.
    """
    if alpha <= 0:
        return x, y, y, 1.0

    lam = np.random.beta(alpha, alpha)
    batch_len = x.size(0)
    shuffled = torch.randperm(batch_len).to(x.device)
    x_other, y_other = x[shuffled], y[shuffled]
    return lam * x + (1 - lam) * x_other, y, y_other, lam

def mixup_criterion(crit, pred, y_a, y_b, lam):
    """Weighted sum of `crit` against the two MixUp label sets.

    Computes `lam * crit(pred, y_a) + (1 - lam) * crit(pred, y_b)`.
    """
    primary = lam * crit(pred, y_a)
    secondary = (1 - lam) * crit(pred, y_b)
    return primary + secondary

# ==========================================================
# 11) Loss (UPDATED: label smoothing), Optimizer, Scheduler
# ==========================================================
# Cross-entropy with label smoothing. The same `criterion` object is also
# used by evaluate(), so reported val/test losses are smoothed losses.
criterion = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)

optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

def lr_lambda(epoch):
    """LR multiplier for the 0-based `epoch`: linear warmup, then cosine decay.

    During the first WARMUP_EPOCHS epochs the multiplier is
    (epoch + 1) / WARMUP_EPOCHS; after that it follows a half-cosine from
    1.0 towards 0.0 across the remaining EPOCHS - WARMUP_EPOCHS epochs.
    """
    in_warmup = epoch < WARMUP_EPOCHS
    if in_warmup:
        return float(epoch + 1) / float(max(1, WARMUP_EPOCHS))

    epochs_after_warmup = epoch - WARMUP_EPOCHS
    decay_span = float(max(1, EPOCHS - WARMUP_EPOCHS))
    progress = epochs_after_warmup / decay_span
    return 0.5 * (1.0 + math.cos(math.pi * progress))

# The scheduler is stepped once per epoch; `lr_lambda` scales the base LR.
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

# `torch.cuda.amp.GradScaler(...)` is deprecated and emits a FutureWarning;
# the device-generic `torch.amp.GradScaler("cuda", ...)` is its replacement.
# Scaling is a no-op unless AMP is requested and a CUDA device is in use.
scaler = torch.amp.GradScaler("cuda", enabled=(USE_AMP and device.type == "cuda"))

# ==========================================================
# 12) Eval
# ==========================================================
@torch.no_grad()
def evaluate(model, loader):
    """Run `model` over `loader` in eval mode and compute loss and metrics.

    Relies on the module-level `device`, `criterion` and `USE_AMP`.
    NOTE: the loss is whatever `criterion` computes; in this cell it is
    built with label smoothing, so the reported loss is the smoothed loss.

    Args:
        model: Network to evaluate; it is left in eval mode on return.
        loader: Iterable yielding `(inputs, labels)` batches.

    Returns:
        `(mean_loss, accuracy, macro_f1, confusion_matrix, targets, preds)`;
        `targets` and `preds` are NumPy arrays concatenated over all batches.
    """
    model.eval()
    all_preds, all_targets = [], []
    total_loss = 0.0
    n_seen = 0

    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        # torch.cuda.amp.autocast is deprecated; torch.amp.autocast("cuda", ...)
        # is the replacement and is equally a no-op when disabled.
        with torch.amp.autocast("cuda", enabled=(USE_AMP and device.type == "cuda")):
            logits = model(x)
            loss = criterion(logits, y)

        # `loss` is a batch mean, so weight it by the batch size.
        total_loss += loss.item() * x.size(0)
        n_seen += x.size(0)
        preds = torch.argmax(logits, dim=1)

        all_preds.append(preds.cpu().numpy())
        all_targets.append(y.cpu().numpy())

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)

    acc = accuracy_score(all_targets, all_preds)
    macro_f1 = f1_score(all_targets, all_preds, average="macro")
    cm = confusion_matrix(all_targets, all_preds)

    # Normalise by the samples actually evaluated, which stays correct even
    # if the loader uses a sampler or drops batches.
    return total_loss / n_seen, acc, macro_f1, cm, all_targets, all_preds

# ==========================================================
# 13) Train (UPDATED: MixUp schedule)
# ==========================================================
# Training loop: MixUp for the first half of the epochs, gradient accumulation,
# AMP, an EMA update after every optimizer step, and per-epoch validation on
# the EMA weights. The checkpoint with the best validation macro-F1 is saved.
best_val_f1 = -1.0

# Autocast only does anything useful on CUDA; the flag never changes, so compute it once.
amp_enabled = USE_AMP and device.type == "cuda"
num_batches = len(train_loader)

for epoch in range(EPOCHS):
    t0 = time.time()
    model.train()
    running_loss = 0.0
    samples_seen = 0

    # LR in effect for this epoch's updates. Read it before scheduler.step(),
    # otherwise the log line would show the *next* epoch's LR.
    epoch_lr = optimizer.param_groups[0]["lr"]

    # MixUp schedule: ON in the first half of training, OFF in the second half.
    current_mixup = MIXUP_ALPHA if epoch < (EPOCHS // 2) else 0.0

    optimizer.zero_grad(set_to_none=True)

    for step, (x, y) in enumerate(train_loader):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        if current_mixup > 0:
            x, y_a, y_b, lam = mixup_data(x, y, alpha=current_mixup)

        # torch.autocast is the non-deprecated spelling of torch.cuda.amp.autocast.
        with torch.autocast(device_type=device.type, enabled=amp_enabled):
            logits = model(x)
            if current_mixup > 0:
                loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
            else:
                loss = criterion(logits, y)

            # Scale so gradients summed over one accumulation window average out.
            loss = loss / GRAD_ACCUM_STEPS

        scaler.scale(loss).backward()

        # Step at the end of every accumulation window AND on the final batch.
        # Without the second condition, gradients from a trailing partial
        # window are thrown away by the zero_grad at the start of the next
        # epoch. (A partial window is still divided by GRAD_ACCUM_STEPS, so
        # that last update is slightly smaller than a full one.)
        window_complete = (step + 1) % GRAD_ACCUM_STEPS == 0
        last_batch = (step + 1) == num_batches
        if window_complete or last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            ema.update(model)

        batch_size = x.size(0)
        # Undo the accumulation scaling so the logged loss is a per-sample mean.
        running_loss += loss.item() * batch_size * GRAD_ACCUM_STEPS
        samples_seen += batch_size

    scheduler.step()

    # Average over the samples actually drawn this epoch; with a sampler or
    # drop_last this can differ from len(train_loader.dataset).
    train_loss = running_loss / max(1, samples_seen)

    # Validate using EMA weights, then put the raw training weights back.
    ema.apply_to(model)
    val_loss, val_acc, val_f1, val_cm, _, _ = evaluate(model, val_loader)
    ema.restore(model)

    dt = time.time() - t0
    print(f"Epoch {epoch+1:02d}/{EPOCHS} | "
          f"mixup={current_mixup:.2f} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | "
          f"val_acc={val_acc:.4f} | val_macroF1={val_f1:.4f} | "
          f"lr={epoch_lr:.2e} | time={dt:.1f}s")

    if val_f1 > best_val_f1:
        # Store a plain Python float so the checkpoint holds no NumPy scalar
        # (the metric function may return one).
        best_val_f1 = float(val_f1)
        ckpt = {
            "epoch": epoch + 1,
            "model": copy.deepcopy(model.state_dict()),
            "ema_shadow": copy.deepcopy(ema.shadow),
            "best_val_macro_f1": best_val_f1,
            "img_size": IMG_SIZE,
            "label_smoothing": LABEL_SMOOTHING,
            "mixup_schedule": "first_half_on_second_half_off"
        }
        torch.save(ckpt, CHECKPOINT_PATH)
        print(f"  ✅ Saved best checkpoint: {CHECKPOINT_PATH} (val_macroF1={best_val_f1:.4f})")

print("\n✅ Best Val Macro-F1:", best_val_f1)

# ==========================================================
# 14) Test best checkpoint (EMA)
# ==========================================================
# Reload the best checkpoint, evaluate its EMA weights on the test set, and
# export those same weights as a plain state dict.
ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(ckpt["model"])
model = model.to(device)
# The checkpoint was loaded onto the CPU; move the shadow weights to the
# model's device so the EMA object stays usable with the model afterwards.
ema.shadow = {name: tensor.to(device) for name, tensor in ckpt["ema_shadow"].items()}

ema.apply_to(model)
test_loss, test_acc, test_f1, test_cm, y_true, y_pred = evaluate(model, test_loader)
# Export while the EMA weights are still applied, so the saved file holds
# exactly the weights that produced the metrics printed below. Saving after
# ema.restore() would write the raw (non-EMA) training weights instead.
torch.save(model.state_dict(), "best_densenet121_384_final_state_dict.pt")
ema.restore(model)

print("\n===== TEST RESULTS (Best EMA) =====")
print(f"Test Loss:     {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test Macro-F1: {test_f1:.4f}")
print("\nConfusion Matrix:\n", test_cm)
print("\nClassification Report:\n", classification_report(y_true, y_pred, digits=4))

print("\nSaved: best_densenet121_384_final_state_dict.pt")

Device: cuda
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Dataset folder already exists, skipping unzip.
✅ Dataset ready
Content of /content: ['.config', 'knee-osteoarthritis-dataset-with-severity', 'best_densenet121_320_tricks.pt', 'drive', 'best_densenet121_320_tricks_state_dict.pt', 'sample_data']
TRAIN_DIR: /content/knee-osteoarthritis-dataset-with-severity/train
VAL_DIR: /content/knee-osteoarthritis-dataset-with-severity/val
TEST_DIR: /content/knee-osteoarthritis-dataset-with-severity/test
Train size: 5778 Val size: 826 Test size: 1656
Classes: ['0', '1', '2', '3', '4']
Train class counts: [2286 1046 1516  757  173]


  scaler = torch.cuda.amp.GradScaler(enabled=(USE_AMP and device.type == "cuda"))
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 01/18 | mixup=0.20 | train_loss=1.2674 | val_loss=1.7932 | val_acc=0.0327 | val_macroF1=0.0127 | lr=1.00e-04 | time=135.4s
  ✅ Saved best checkpoint: best_densenet121_384_ls_mixupSchedule_ema.pt (val_macroF1=0.0127)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 02/18 | mixup=0.20 | train_loss=1.1291 | val_loss=1.7927 | val_acc=0.0363 | val_macroF1=0.0163 | lr=1.00e-04 | time=115.5s
  ✅ Saved best checkpoint: best_densenet121_384_ls_mixupSchedule_ema.pt (val_macroF1=0.0163)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 03/18 | mixup=0.20 | train_loss=1.0732 | val_loss=1.2652 | val_acc=0.5847 | val_macroF1=0.4662 | lr=9.90e-05 | time=115.7s
  ✅ Saved best checkpoint: best_densenet121_384_ls_mixupSchedule_ema.pt (val_macroF1=0.4662)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 04/18 | mixup=0.20 | train_loss=1.0169 | val_loss=1.1503 | val_acc=0.5823 | val_macroF1=0.5409 | lr=9.62e-05 | time=115.6s
  ✅ Saved best checkpoint: best_densenet121_384_ls_mixupSchedule_ema.pt (val_macroF1=0.5409)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 05/18 | mixup=0.20 | train_loss=0.9850 | val_loss=1.0465 | val_acc=0.5993 | val_macroF1=0.6327 | lr=9.16e-05 | time=116.3s
  ✅ Saved best checkpoint: best_densenet121_384_ls_mixupSchedule_ema.pt (val_macroF1=0.6327)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 06/18 | mixup=0.20 | train_loss=0.9627 | val_loss=0.9316 | val_acc=0.6840 | val_macroF1=0.6928 | lr=8.54e-05 | time=117.3s
  ✅ Saved best checkpoint: best_densenet121_384_ls_mixupSchedule_ema.pt (val_macroF1=0.6928)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 07/18 | mixup=0.20 | train_loss=0.9664 | val_loss=0.9094 | val_acc=0.6864 | val_macroF1=0.6813 | lr=7.78e-05 | time=119.0s


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 08/18 | mixup=0.20 | train_loss=0.9234 | val_loss=0.8982 | val_acc=0.6889 | val_macroF1=0.6975 | lr=6.91e-05 | time=116.9s
  ✅ Saved best checkpoint: best_densenet121_384_ls_mixupSchedule_ema.pt (val_macroF1=0.6975)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 09/18 | mixup=0.20 | train_loss=0.8900 | val_loss=0.8980 | val_acc=0.6973 | val_macroF1=0.7038 | lr=5.98e-05 | time=119.7s
  ✅ Saved best checkpoint: best_densenet121_384_ls_mixupSchedule_ema.pt (val_macroF1=0.7038)


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 10/18 | mixup=0.00 | train_loss=0.6770 | val_loss=0.9464 | val_acc=0.6755 | val_macroF1=0.6771 | lr=5.00e-05 | time=117.5s


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 11/18 | mixup=0.00 | train_loss=0.6436 | val_loss=0.9557 | val_acc=0.6743 | val_macroF1=0.6948 | lr=4.02e-05 | time=116.0s


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 12/18 | mixup=0.00 | train_loss=0.5899 | val_loss=1.0950 | val_acc=0.6199 | val_macroF1=0.6665 | lr=3.09e-05 | time=115.4s


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 13/18 | mixup=0.00 | train_loss=0.5362 | val_loss=1.0188 | val_acc=0.6707 | val_macroF1=0.6731 | lr=2.22e-05 | time=116.7s


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 14/18 | mixup=0.00 | train_loss=0.5030 | val_loss=1.0540 | val_acc=0.6634 | val_macroF1=0.6820 | lr=1.46e-05 | time=116.0s


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 15/18 | mixup=0.00 | train_loss=0.4791 | val_loss=1.1246 | val_acc=0.6513 | val_macroF1=0.6670 | lr=8.43e-06 | time=115.8s


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 16/18 | mixup=0.00 | train_loss=0.4636 | val_loss=1.1084 | val_acc=0.6513 | val_macroF1=0.6719 | lr=3.81e-06 | time=116.1s


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 17/18 | mixup=0.00 | train_loss=0.4483 | val_loss=1.1050 | val_acc=0.6562 | val_macroF1=0.6824 | lr=9.61e-07 | time=116.6s


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):
  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):


Epoch 18/18 | mixup=0.00 | train_loss=0.4352 | val_loss=1.1069 | val_acc=0.6465 | val_macroF1=0.6646 | lr=0.00e+00 | time=116.7s

✅ Best Val Macro-F1: 0.7037964064015501


  with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type == "cuda")):



===== TEST RESULTS (Best EMA) =====
Test Loss:     0.8544
Test Accuracy: 0.7089
Test Macro-F1: 0.7085

Confusion Matrix:
 [[547  73  18   1   0]
 [130 102  59   5   0]
 [ 38  83 299  27   0]
 [  0   4  28 181  10]
 [  0   0   0   6  45]]

Classification Report:
               precision    recall  f1-score   support

           0     0.7650    0.8560    0.8080       639
           1     0.3893    0.3446    0.3656       296
           2     0.7401    0.6689    0.7027       447
           3     0.8227    0.8117    0.8172       223
           4     0.8182    0.8824    0.8491        51

    accuracy                         0.7089      1656
   macro avg     0.7071    0.7127    0.7085      1656
weighted avg     0.7006    0.7089    0.7030      1656


Saved: best_densenet121_384_final_state_dict.pt


| Setting         | DenseNet121 @320 | DenseNet121 @384 (Final) |
| --------------- | ---------------- | ------------------------ |
| Image Size      | 320×320          | 384×384                  |
| Batch Size      | 16               | 8                        |
| Epochs          | 15               | 18                       |
| Label Smoothing | ❌                | ✅ (0.07)                 |
| MixUp           | Always ON        | Scheduled (ON → OFF)     |
| EMA             | ✅                | ✅                        |
| Test Accuracy   | 0.6739           | **0.7089**               |
| Test Macro-F1   | 0.6914           | **0.7085**               |


The DenseNet121 @384 configuration is selected as the final single-model pipeline for submission because it outperforms the @320 baseline on the test set in both accuracy (70.89% vs. 67.39%) and macro-F1 (70.85% vs. 69.14%).