In [1]:
# ============================================
# AUTO DATASET LOADER + DenseNet121 Replication
# ============================================

import os

DATASET_NAME = "knee-osteoarthritis-dataset-with-severity"
ZIP_NAME = "knee_oa_dataset.zip"   # make sure this matches your Drive

# Folder the archive is expected to unpack into on the Colab VM.
dataset_dir = f"/content/{DATASET_NAME}"

if not os.path.isdir(dataset_dir):
    print("Dataset not found. Mounting Drive...")
    # Imported lazily: only needed when the dataset has to be fetched from Drive.
    from google.colab import drive
    drive.mount('/content/drive')

    zip_path = f"/content/drive/MyDrive/{ZIP_NAME}"
    if not os.path.isfile(zip_path):
        raise FileNotFoundError(f"ZIP file not found at {zip_path}")

    print("Unzipping dataset...")
    # Pure-Python extraction: unlike a `!unzip` shell call, a corrupt or
    # unreadable archive raises here instead of being silently ignored.
    import zipfile
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall("/content/")

    # Fail early if the archive's top-level folder is not the one the rest of
    # the notebook reads from.
    if not os.path.isdir(dataset_dir):
        raise FileNotFoundError(
            f"Extraction finished but {dataset_dir} does not exist; "
            f"check the top-level folder name inside {zip_path}"
        )

print("Dataset ready ✅")
print(os.listdir("/content"))


Dataset not found. Mounting Drive...
Mounted at /content/drive
Unzipping dataset...
Dataset ready ✅
['.config', 'drive', 'knee-osteoarthritis-dataset-with-severity', 'sample_data']


In [2]:
import random, time, json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, WeightedRandomSampler


def set_seed(seed=42):
    """Seed every RNG the notebook uses and put cuDNN in deterministic mode.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy's legacy
            global RNG, and torch (CPU and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

Device: cuda


In [3]:
# Dataset root and its three split folders (train / val / test).
DATA_DIR = "/content/knee-osteoarthritis-dataset-with-severity"
TRAIN_DIR, VAL_DIR, TEST_DIR = (
    f"{DATA_DIR}/{split}" for split in ("train", "val", "test")
)

# Each execution gets its own timestamped experiment folder.
RUN_NAME = time.strftime("densenet121_replication_%Y%m%d_%H%M%S")
RUN_DIR = f"/content/experiments/{RUN_NAME}"
for subdir in ("checkpoints", "logs"):
    os.makedirs(f"{RUN_DIR}/{subdir}", exist_ok=True)

print("RUN_DIR:", RUN_DIR)

RUN_DIR: /content/experiments/densenet121_replication_20260216_173015


In [4]:
IMAGE_SIZE = 224
BATCH_SIZE = 32

# Steps shared by both pipelines: fixed-size resize up front, then tensor
# conversion and per-channel normalisation at the end.
_resize_224 = [transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))]
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]

# Random augmentation, applied to training images only.
_train_augment = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.15, contrast=0.15),
]

train_tfms = transforms.Compose(_resize_224 + _train_augment + _to_normalized_tensor)

# Validation/test images are only resized and normalised (no randomness).
val_tfms = transforms.Compose(_resize_224 + _to_normalized_tensor)

In [5]:
# One ImageFolder per split; only the training split uses the augmenting pipeline.
train_ds = datasets.ImageFolder(root=TRAIN_DIR, transform=train_tfms)
val_ds = datasets.ImageFolder(root=VAL_DIR, transform=val_tfms)
test_ds = datasets.ImageFolder(root=TEST_DIR, transform=val_tfms)

NUM_CLASSES = len(train_ds.classes)
print("Classes:", train_ds.classes)

# Weighted sampler (replicating imbalance handling from literature):
# every image is drawn with probability inversely proportional to the size of
# its class; the small epsilon guards against division by zero for empty classes.
labels = np.asarray(train_ds.targets)
class_counts = np.bincount(labels, minlength=NUM_CLASSES)
class_weights = 1.0 / (class_counts + 1e-6)
sample_weights = torch.tensor(class_weights[labels], dtype=torch.double)

# Draw as many samples per epoch as there are training images, with replacement.
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_ds, sampler=sampler, **_loader_opts)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_opts)

print("Train class counts:", class_counts)

Classes: ['0', '1', '2', '3', '4']
Train class counts: [2286 1046 1516  757  173]


In [6]:
def build_densenet121(num_classes):
    """Return a DenseNet121 with IMAGENET1K_V1 weights and a fresh linear head.

    Args:
        num_classes: Number of output units for the replacement classifier.

    Returns:
        The torchvision DenseNet121 whose ``classifier`` is a newly
        initialised ``nn.Linear`` mapping the backbone features to
        ``num_classes`` logits.
    """
    net = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
    net.classifier = nn.Linear(net.classifier.in_features, num_classes)
    return net

model = build_densenet121(NUM_CLASSES).to(device)

# Stage-1 setup: switch gradients off for every parameter, then back on for
# the classifier head only, so the backbone stays frozen.
model.requires_grad_(False)
model.classifier.requires_grad_(True)

print("DenseNet121 ready ✅")

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth


100%|██████████| 30.8M/30.8M [00:00<00:00, 158MB/s]


DenseNet121 ready ✅


In [7]:
# Replication baseline loss: plain, unweighted cross-entropy.
criterion = nn.CrossEntropyLoss()
print("Using CrossEntropyLoss ✅")

# Stage 1 optimises the classifier head only.
head_params = list(model.classifier.parameters())
optimizer = optim.Adam(head_params, lr=1e-3, weight_decay=1e-4)

# Halve the learning rate when the monitored metric stops increasing
# (mode="max"; the training loops step it with validation macro-F1).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=2)

Using CrossEntropyLoss ✅


In [8]:
def run_epoch(model, loader, train=False):
    """Run one pass over ``loader`` and return ``(avg_loss, accuracy, macro_f1)``.

    This function reads the notebook-level globals ``device``, ``criterion``
    and (when ``train`` is true) ``optimizer``; whichever objects those names
    are bound to at call time are the ones used.

    Args:
        model: Network to evaluate or train.
        loader: Iterable yielding ``(images, targets)`` batches.
        train: If true, the model is put in training mode and the global
            ``optimizer`` is stepped after every batch; otherwise the model is
            put in eval mode and gradients are disabled.

    Returns:
        Tuple ``(avg_loss, acc, macro_f1)`` where ``avg_loss`` is the
        per-sample mean loss over the samples actually iterated.
    """
    if train:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    n_seen = 0
    all_preds, all_targets = [], []

    for imgs, targets in loader:
        imgs, targets = imgs.to(device), targets.to(device)

        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            outputs = model(imgs)
            loss = criterion(outputs, targets)

            if train:
                loss.backward()
                optimizer.step()

        # The criterion returns a batch mean; re-weight by batch size so that
        # a smaller final batch does not skew the epoch average.
        batch_size = imgs.size(0)
        total_loss += loss.item() * batch_size
        n_seen += batch_size

        all_preds.extend(outputs.argmax(dim=1).tolist())
        all_targets.extend(targets.tolist())

    # Normalise by the number of samples actually processed rather than
    # len(loader.dataset): the two differ when a sampler draws a different
    # number of items or when the loader drops the last batch.
    avg_loss = total_loss / max(n_seen, 1)
    acc = accuracy_score(all_targets, all_preds)
    macro_f1 = f1_score(all_targets, all_preds, average="macro")

    return avg_loss, acc, macro_f1

In [9]:
history = []  # one dict of metrics per completed epoch
# Start below any attainable macro-F1 so the first epoch always writes a
# checkpoint; with an initial value of 0, a run whose macro-F1 never exceeded 0
# would leave no file at best_path for the evaluation cell to load.
best_val_f1 = -1.0
best_path = f"{RUN_DIR}/checkpoints/best_densenet121.pth"

EPOCHS_STAGE1 = 3
EPOCHS_STAGE2 = 17
TOTAL_EPOCHS = EPOCHS_STAGE1 + EPOCHS_STAGE2

print("Training plan:", TOTAL_EPOCHS, "epochs")

# -------- Stage 1: classifier head only --------
for epoch in range(1, EPOCHS_STAGE1 + 1):

    train_loss, train_acc, train_f1 = run_epoch(model, train_loader, train=True)
    val_loss, val_acc, val_f1 = run_epoch(model, val_loader, train=False)

    # The plateau scheduler monitors validation macro-F1.
    scheduler.step(val_f1)

    history.append({
        "epoch": epoch,
        "stage": 1,
        "train_loss": train_loss, "train_acc": train_acc, "train_f1": train_f1,
        "val_loss": val_loss, "val_acc": val_acc, "val_f1": val_f1,
        "lr": optimizer.param_groups[0]["lr"],
    })

    print(f"Epoch {epoch:02d} | "
          f"train loss {train_loss:.4f} acc {train_acc:.3f} f1 {train_f1:.3f} || "
          f"val loss {val_loss:.4f} acc {val_acc:.3f} f1 {val_f1:.3f}")

    # Keep only the weights with the best validation macro-F1 so far.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), best_path)


Training plan: 20 epochs
Epoch 01 | train loss 1.5281 acc 0.305 f1 0.297 || val loss 1.4398 acc 0.317 f1 0.263
Epoch 02 | train loss 1.3959 acc 0.375 f1 0.362 || val loss 1.3792 acc 0.361 f1 0.328
Epoch 03 | train loss 1.3379 acc 0.407 f1 0.397 || val loss 1.4116 acc 0.306 f1 0.292


In [10]:
# -------- Stage 2: fine-tune last dense block + classifier --------

# Unfreeze only: features.denseblock4 + features.norm5 + classifier.
# str.startswith accepts a tuple, so one call covers all three prefixes.
stage2_trainable_prefixes = ("features.denseblock4", "features.norm5", "classifier")
for name, param in model.named_parameters():
    param.requires_grad = name.startswith(stage2_trainable_prefixes)

# New optimizer for fine-tuning (smaller LR), built over the trainable subset only.
optimizer = optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=3e-4,
    weight_decay=1e-4
)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=2
)

# Epoch numbering continues from where stage 1 stopped.
start_epoch = EPOCHS_STAGE1 + 1

for epoch in range(start_epoch, TOTAL_EPOCHS + 1):
    train_loss, train_acc, train_f1 = run_epoch(model, train_loader, train=True)
    val_loss, val_acc, val_f1 = run_epoch(model, val_loader, train=False)

    scheduler.step(val_f1)

    # Record the epoch in the shared history list so the stage-2 curves can be
    # inspected or plotted after training instead of living only in stdout.
    history.append({
        "epoch": epoch,
        "stage": 2,
        "train_loss": train_loss, "train_acc": train_acc, "train_f1": train_f1,
        "val_loss": val_loss, "val_acc": val_acc, "val_f1": val_f1,
        "lr": optimizer.param_groups[0]["lr"],
    })

    print(f"Epoch {epoch:02d} | "
          f"train loss {train_loss:.4f} acc {train_acc:.3f} f1 {train_f1:.3f} || "
          f"val loss {val_loss:.4f} acc {val_acc:.3f} f1 {val_f1:.3f} | lr {optimizer.param_groups[0]['lr']:.2e}")

    # Overwrite the checkpoint whenever validation macro-F1 improves.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), best_path)

print("Stage2 done ✅ Best val macro-F1:", best_val_f1)
print("Best checkpoint:", best_path)


Epoch 04 | train loss 1.0836 acc 0.508 f1 0.506 || val loss 1.0788 acc 0.538 f1 0.506 | lr 3.00e-04
Epoch 05 | train loss 0.9063 acc 0.594 f1 0.585 || val loss 1.0832 acc 0.506 f1 0.534 | lr 3.00e-04
Epoch 06 | train loss 0.8492 acc 0.618 f1 0.614 || val loss 1.0260 acc 0.513 f1 0.558 | lr 3.00e-04
Epoch 07 | train loss 0.8124 acc 0.635 f1 0.631 || val loss 1.0976 acc 0.455 f1 0.541 | lr 3.00e-04
Epoch 08 | train loss 0.7770 acc 0.651 f1 0.644 || val loss 1.0928 acc 0.552 f1 0.560 | lr 3.00e-04
Epoch 09 | train loss 0.7527 acc 0.661 f1 0.661 || val loss 1.1758 acc 0.433 f1 0.506 | lr 3.00e-04
Epoch 10 | train loss 0.7225 acc 0.687 f1 0.683 || val loss 0.9697 acc 0.584 f1 0.577 | lr 3.00e-04
Epoch 11 | train loss 0.6961 acc 0.697 f1 0.694 || val loss 1.0069 acc 0.548 f1 0.560 | lr 3.00e-04
Epoch 12 | train loss 0.7189 acc 0.681 f1 0.682 || val loss 1.0582 acc 0.525 f1 0.577 | lr 3.00e-04
Epoch 13 | train loss 0.6695 acc 0.710 f1 0.708 || val loss 0.9858 acc 0.574 f1 0.576 | lr 1.50e-04


In [11]:
# -------- TEST evaluation (best checkpoint) --------
# Rebuild the architecture and restore the weights stored at best_path.
best_model = build_densenet121(NUM_CLASSES).to(device)
best_model.load_state_dict(torch.load(best_path, map_location=device))
best_model.eval()

all_preds, all_targets = [], []

# Inference only: no gradients are needed.
with torch.no_grad():
    for imgs, targets in test_loader:
        imgs = imgs.to(device)
        targets = targets.to(device)
        outputs = best_model(imgs)
        preds = outputs.argmax(dim=1)
        all_preds += preds.tolist()
        all_targets += targets.tolist()

test_acc = accuracy_score(all_targets, all_preds)
test_macro_f1 = f1_score(all_targets, all_preds, average="macro")

print("TEST accuracy:", test_acc)
print("TEST macro-F1:", test_macro_f1)

print("\nClassification report:")
print(classification_report(all_targets, all_preds, target_names=train_ds.classes, digits=4))

cm = confusion_matrix(all_targets, all_preds)
print("\nConfusion matrix:\n", cm)

# save metrics for your report
metrics = dict(
    best_val_macro_f1=float(best_val_f1),
    test_accuracy=float(test_acc),
    test_macro_f1=float(test_macro_f1),
    epochs_stage1=EPOCHS_STAGE1,
    epochs_stage2=EPOCHS_STAGE2,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
)

metrics_path = f"{RUN_DIR}/logs/test_metrics.json"
with open(metrics_path, "w") as f:
    json.dump(metrics, f, indent=2)

print("\nSaved:", metrics_path)


TEST accuracy: 0.538647342995169
TEST macro-F1: 0.5855880938639025

Classification report:
              precision    recall  f1-score   support

           0     0.7282    0.5618    0.6343       639
           1     0.2650    0.5068    0.3480       296
           2     0.6044    0.4340    0.5052       447
           3     0.6637    0.6726    0.6682       223
           4     0.7800    0.7647    0.7723        51

    accuracy                         0.5386      1656
   macro avg     0.6083    0.5880    0.5856      1656
weighted avg     0.6049    0.5386    0.5571      1656


Confusion matrix:
 [[359 230  46   4   0]
 [ 87 150  49  10   0]
 [ 43 159 194  51   0]
 [  4  27  31 150  11]
 [  0   0   1  11  39]]

Saved: /content/experiments/densenet121_replication_20260216_173015/logs/test_metrics.json


In [12]:
# ---------- Improvement Experiment: deeper fine-tuning ----------
# Load best checkpoint first (start from best weights)
model.load_state_dict(torch.load(best_path, map_location=device))

# Unfreeze denseblock3 + denseblock4 + norm layers + classifier
for name, param in model.named_parameters():
    param.requires_grad = (
        name.startswith("features.denseblock3") or
        name.startswith("features.denseblock4") or
        name.startswith("features.norm5") or
        name.startswith("classifier")
    )

# Lower LR for stability
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4,           # LOWER than before
    weight_decay=1e-4
)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=2
)

# Run a short improvement fine-tune (5 epochs)
IMPROVE_EPOCHS = 5
best_val_f1_improve = best_val_f1
best_path_improve = f"{RUN_DIR}/checkpoints/best_densenet121_improved.pth"

# Seed the "improved" checkpoint with the restored baseline weights. The loop
# below only saves when an epoch beats the baseline macro-F1; without this
# initial save, a run with no improvement would leave best_path_improve
# missing and the evaluation cell's torch.load would fail. With it, the file
# always holds the weights that correspond to best_val_f1_improve.
torch.save(model.state_dict(), best_path_improve)

for e in range(1, IMPROVE_EPOCHS + 1):
    train_loss, train_acc, train_f1 = run_epoch(model, train_loader, train=True)
    val_loss, val_acc, val_f1 = run_epoch(model, val_loader, train=False)

    # The plateau scheduler monitors validation macro-F1.
    scheduler.step(val_f1)

    print(f"Improve {e:02d}/{IMPROVE_EPOCHS} | "
          f"train loss {train_loss:.4f} acc {train_acc:.3f} f1 {train_f1:.3f} || "
          f"val loss {val_loss:.4f} acc {val_acc:.3f} f1 {val_f1:.3f} | lr {optimizer.param_groups[0]['lr']:.2e}")

    if val_f1 > best_val_f1_improve:
        best_val_f1_improve = val_f1
        torch.save(model.state_dict(), best_path_improve)

print("Improvement done ✅ Best val macro-F1:", best_val_f1_improve)
print("Improved checkpoint:", best_path_improve)


Improve 01/5 | train loss 0.5961 acc 0.750 f1 0.747 || val loss 1.0143 acc 0.575 f1 0.602 | lr 1.00e-04
Improve 02/5 | train loss 0.5915 acc 0.749 f1 0.749 || val loss 0.9565 acc 0.582 f1 0.597 | lr 1.00e-04
Improve 03/5 | train loss 0.5341 acc 0.781 f1 0.777 || val loss 0.9481 acc 0.608 f1 0.601 | lr 1.00e-04
Improve 04/5 | train loss 0.5283 acc 0.782 f1 0.782 || val loss 1.0291 acc 0.579 f1 0.598 | lr 5.00e-05
Improve 05/5 | train loss 0.4692 acc 0.810 f1 0.808 || val loss 0.9777 acc 0.599 f1 0.623 | lr 5.00e-05
Improvement done ✅ Best val macro-F1: 0.6228256608717828
Improved checkpoint: /content/experiments/densenet121_replication_20260216_173015/checkpoints/best_densenet121_improved.pth


In [13]:
# -------- TEST evaluation (improved checkpoint) --------
# Rebuild the architecture and restore the weights stored at best_path_improve.
improved_model = build_densenet121(NUM_CLASSES).to(device)
improved_model.load_state_dict(torch.load(best_path_improve, map_location=device))
improved_model.eval()

all_preds, all_targets = [], []

# Inference only: no gradients are needed.
with torch.no_grad():
    for imgs, targets in test_loader:
        imgs = imgs.to(device)
        targets = targets.to(device)
        outputs = improved_model(imgs)
        preds = outputs.argmax(dim=1)
        all_preds += preds.tolist()
        all_targets += targets.tolist()

test_acc = accuracy_score(all_targets, all_preds)
test_macro_f1 = f1_score(all_targets, all_preds, average="macro")

print("IMPROVED TEST accuracy:", test_acc)
print("IMPROVED TEST macro-F1:", test_macro_f1)

print("\nClassification report (improved):")
print(classification_report(all_targets, all_preds, target_names=train_ds.classes, digits=4))

cm = confusion_matrix(all_targets, all_preds)
print("\nConfusion matrix (improved):\n", cm)

# save metrics
improved_metrics = dict(
    val_macro_f1_best_improved=float(best_val_f1_improve),
    test_accuracy_improved=float(test_acc),
    test_macro_f1_improved=float(test_macro_f1),
    improvement_note="Unfroze denseblock3+4 and reduced LR to 1e-4 for 5 epochs",
)

improved_metrics_path = f"{RUN_DIR}/logs/test_metrics_improved.json"
with open(improved_metrics_path, "w") as f:
    json.dump(improved_metrics, f, indent=2)

print("\nSaved:", improved_metrics_path)


IMPROVED TEST accuracy: 0.5839371980676329
IMPROVED TEST macro-F1: 0.597435930610521

Classification report (improved):
              precision    recall  f1-score   support

           0     0.6790    0.7449    0.7104       639
           1     0.2872    0.3716    0.3240       296
           2     0.6207    0.4430    0.5170       447
           3     0.7087    0.6547    0.6807       223
           4     0.7872    0.7255    0.7551        51

    accuracy                         0.5839      1656
   macro avg     0.6166    0.5879    0.5974      1656
weighted avg     0.6006    0.5839    0.5865      1656


Confusion matrix (improved):
 [[476 134  28   1   0]
 [136 110  46   4   0]
 [ 87 120 198  42   0]
 [  2  19  46 146  10]
 [  0   0   1  13  37]]

Saved: /content/experiments/densenet121_replication_20260216_173015/logs/test_metrics_improved.json


In [15]:
# ============================================================
# DenseNet121 Improvement Experiment: Higher Resolution (320)

# ============================================================

import os, json, time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, WeightedRandomSampler
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix

# ---------- Config ----------
# NOTE(review): IMAGE_SIZE, BATCH_SIZE and EPOCHS_STAGE1/2 rebind names that the
# earlier 224px cells also define; re-running those cells after this one will
# pick up the values below.
IMAGE_SIZE = 320
BATCH_SIZE = 16
EPOCHS_STAGE1 = 2     # classifier only
EPOCHS_STAGE2 = 10    # fine-tune denseblock3+4
LR_STAGE1 = 1e-3
LR_STAGE2 = 1e-4
WEIGHT_DECAY = 1e-4

# Make a separate run folder inside your existing RUN_DIR
RUN320_DIR = os.path.join(RUN_DIR, "res320_" + time.strftime("%H%M%S"))
for subdir in ("checkpoints", "logs"):
    os.makedirs(os.path.join(RUN320_DIR, subdir), exist_ok=True)
print("RUN320_DIR:", RUN320_DIR)

# ---------- Transforms ----------
# Steps shared by both pipelines: resize up front, tensor conversion and
# per-channel normalisation at the end.
_resize_320 = [transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))]
_to_normalized_tensor_320 = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]
# Random augmentation, applied to training images only.
_train_augment_320 = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.15, contrast=0.15),
]

train_tfms_320 = transforms.Compose(_resize_320 + _train_augment_320 + _to_normalized_tensor_320)
val_tfms_320 = transforms.Compose(_resize_320 + _to_normalized_tensor_320)

# ---------- Datasets ----------
train_ds_320 = datasets.ImageFolder(root=TRAIN_DIR, transform=train_tfms_320)
val_ds_320 = datasets.ImageFolder(root=VAL_DIR, transform=val_tfms_320)
test_ds_320 = datasets.ImageFolder(root=TEST_DIR, transform=val_tfms_320)

NUM_CLASSES_320 = len(train_ds_320.classes)
print("Classes:", train_ds_320.classes)

# ---------- Weighted sampler ----------
# Each image is drawn with probability inversely proportional to its class
# size; the epsilon guards against division by zero for empty classes.
labels = np.asarray(train_ds_320.targets)
class_counts = np.bincount(labels, minlength=NUM_CLASSES_320)
class_weights = 1.0 / (class_counts + 1e-6)
sample_weights = torch.tensor(class_weights[labels], dtype=torch.double)

# Draw as many samples per epoch as there are training images, with replacement.
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

_loader_opts_320 = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader_320 = DataLoader(train_ds_320, sampler=sampler, **_loader_opts_320)
val_loader_320 = DataLoader(val_ds_320, shuffle=False, **_loader_opts_320)
test_loader_320 = DataLoader(test_ds_320, shuffle=False, **_loader_opts_320)

print("Train class counts:", class_counts)
print("Image size:", IMAGE_SIZE, "Batch size:", BATCH_SIZE)

# ---------- Model ----------
model_320 = build_densenet121(NUM_CLASSES_320).to(device)

# Freeze backbone, train classifier first
model_320.requires_grad_(False)
model_320.classifier.requires_grad_(True)

criterion_320 = nn.CrossEntropyLoss()

def run_epoch_local(model, loader, optimizer=None):
    """Run one pass over ``loader`` and return ``(avg_loss, accuracy, macro_f1)``.

    Training versus evaluation is selected by ``optimizer``: when one is given
    the model is trained and the optimizer stepped after every batch; when it
    is ``None`` the model is put in eval mode and gradients are disabled.
    The loss is the notebook-level ``criterion_320`` and tensors are moved to
    the notebook-level ``device``.

    Args:
        model: Network to evaluate or train.
        loader: Iterable yielding ``(images, targets)`` batches.
        optimizer: Optimizer to step, or ``None`` for an evaluation pass.

    Returns:
        Tuple ``(avg_loss, acc, macro_f1)`` where ``avg_loss`` is the
        per-sample mean loss over the samples actually iterated.
    """
    train = optimizer is not None
    if train:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    n_seen = 0
    all_preds, all_targets = [], []

    for imgs, targets in loader:
        imgs, targets = imgs.to(device), targets.to(device)

        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            outputs = model(imgs)
            loss = criterion_320(outputs, targets)
            if train:
                loss.backward()
                optimizer.step()

        # The criterion returns a batch mean; re-weight by batch size so that
        # a smaller final batch does not skew the epoch average.
        batch_size = imgs.size(0)
        total_loss += loss.item() * batch_size
        n_seen += batch_size

        all_preds.extend(outputs.argmax(dim=1).tolist())
        all_targets.extend(targets.tolist())

    # Normalise by the number of samples actually processed rather than
    # len(loader.dataset): the two differ when a sampler draws a different
    # number of items or when the loader drops the last batch.
    avg_loss = total_loss / max(n_seen, 1)
    acc = accuracy_score(all_targets, all_preds)
    macro_f1 = f1_score(all_targets, all_preds, average="macro")
    return avg_loss, acc, macro_f1

# ---------- Stage 1 ----------
# Head-only training: the optimizer sees the classifier parameters and nothing else.
optimizer = optim.Adam(
    model_320.classifier.parameters(),
    lr=LR_STAGE1,
    weight_decay=WEIGHT_DECAY,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=2
)

# Starts below any attainable macro-F1, so the first epoch always saves a checkpoint.
best_val_f1 = -1.0
best_path = os.path.join(RUN320_DIR, "checkpoints", "best_densenet121_res320.pth")

print(f"\nTraining plan: {EPOCHS_STAGE1 + EPOCHS_STAGE2} epochs ({EPOCHS_STAGE1} head + {EPOCHS_STAGE2} finetune)")

for epoch in range(1, EPOCHS_STAGE1 + 1):
    tr_loss, tr_acc, tr_f1 = run_epoch_local(model_320, train_loader_320, optimizer=optimizer)
    va_loss, va_acc, va_f1 = run_epoch_local(model_320, val_loader_320, optimizer=None)
    # The plateau scheduler monitors validation macro-F1.
    scheduler.step(va_f1)

    print(f"Epoch {epoch:02d} [head] | train loss {tr_loss:.4f} acc {tr_acc:.3f} f1 {tr_f1:.3f} || "
          f"val loss {va_loss:.4f} acc {va_acc:.3f} f1 {va_f1:.3f} | lr {optimizer.param_groups[0]['lr']:.2e}")

    if va_f1 > best_val_f1:
        best_val_f1 = va_f1
        torch.save(model_320.state_dict(), best_path)

# ---------- Stage 2: fine-tune denseblock3+4 ----------
# str.startswith accepts a tuple, so one call covers every trainable prefix.
trainable_prefixes_320 = (
    "features.denseblock3",
    "features.denseblock4",
    "features.norm5",
    "classifier",
)
for name, param in model_320.named_parameters():
    param.requires_grad = name.startswith(trainable_prefixes_320)

# Fresh optimizer and scheduler over the now-trainable subset, at the lower stage-2 LR.
optimizer = optim.Adam(
    [p for p in model_320.parameters() if p.requires_grad],
    lr=LR_STAGE2,
    weight_decay=WEIGHT_DECAY,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=2
)

# Epoch numbering continues from the end of stage 1.
start_epoch = EPOCHS_STAGE1 + 1
end_epoch = EPOCHS_STAGE1 + EPOCHS_STAGE2

for epoch in range(start_epoch, end_epoch + 1):
    tr_loss, tr_acc, tr_f1 = run_epoch_local(model_320, train_loader_320, optimizer=optimizer)
    va_loss, va_acc, va_f1 = run_epoch_local(model_320, val_loader_320, optimizer=None)
    scheduler.step(va_f1)

    print(f"Epoch {epoch:02d} [ft ] | train loss {tr_loss:.4f} acc {tr_acc:.3f} f1 {tr_f1:.3f} || "
          f"val loss {va_loss:.4f} acc {va_acc:.3f} f1 {va_f1:.3f} | lr {optimizer.param_groups[0]['lr']:.2e}")

    # Overwrite the checkpoint whenever validation macro-F1 improves.
    if va_f1 > best_val_f1:
        best_val_f1 = va_f1
        torch.save(model_320.state_dict(), best_path)

print("\nBest val macro-F1 (res320):", best_val_f1)
print("Best checkpoint:", best_path)

# ---------- TEST evaluation ----------
# Rebuild the architecture and restore the best res320 weights from best_path.
best_model = build_densenet121(NUM_CLASSES_320).to(device)
best_model.load_state_dict(torch.load(best_path, map_location=device))
best_model.eval()

all_preds, all_targets = [], []
# Inference only: no gradients are needed.
with torch.no_grad():
    for imgs, targets in test_loader_320:
        imgs = imgs.to(device)
        targets = targets.to(device)
        outputs = best_model(imgs)
        preds = outputs.argmax(dim=1)
        all_preds += preds.tolist()
        all_targets += targets.tolist()

test_acc = accuracy_score(all_targets, all_preds)
test_macro_f1 = f1_score(all_targets, all_preds, average="macro")

print("\nRES320 TEST accuracy:", test_acc)
print("RES320 TEST macro-F1:", test_macro_f1)

print("\nClassification report (res320):")
print(classification_report(all_targets, all_preds, target_names=train_ds_320.classes, digits=4))

cm = confusion_matrix(all_targets, all_preds)
print("\nConfusion matrix (res320):\n", cm)

# Persist the run configuration together with the headline metrics.
metrics = dict(
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    epochs_stage1=EPOCHS_STAGE1,
    epochs_stage2=EPOCHS_STAGE2,
    lr_stage1=LR_STAGE1,
    lr_stage2=LR_STAGE2,
    best_val_macro_f1=float(best_val_f1),
    test_accuracy=float(test_acc),
    test_macro_f1=float(test_macro_f1),
)
metrics_res320_path = os.path.join(RUN320_DIR, "logs", "metrics_res320.json")
with open(metrics_res320_path, "w") as f:
    json.dump(metrics, f, indent=2)

print("\nSaved metrics:", metrics_res320_path)


RUN320_DIR: /content/experiments/densenet121_replication_20260216_173015/res320_175510
Classes: ['0', '1', '2', '3', '4']
Train class counts: [2286 1046 1516  757  173]
Image size: 320 Batch size: 16

Training plan: 12 epochs (2 head + 10 finetune)
Epoch 01 [head] | train loss 1.5589 acc 0.293 f1 0.288 || val loss 1.4107 acc 0.406 f1 0.270 | lr 1.00e-03
Epoch 02 [head] | train loss 1.4527 acc 0.350 f1 0.345 || val loss 1.4532 acc 0.407 f1 0.274 | lr 1.00e-03
Epoch 03 [ft ] | train loss 0.9798 acc 0.573 f1 0.568 || val loss 1.1018 acc 0.435 f1 0.526 | lr 1.00e-04
Epoch 04 [ft ] | train loss 0.7847 acc 0.656 f1 0.658 || val loss 0.9263 acc 0.592 f1 0.610 | lr 1.00e-04
Epoch 05 [ft ] | train loss 0.7124 acc 0.693 f1 0.690 || val loss 0.9128 acc 0.585 f1 0.601 | lr 1.00e-04
Epoch 06 [ft ] | train loss 0.6727 acc 0.702 f1 0.696 || val loss 0.8529 acc 0.653 f1 0.638 | lr 1.00e-04
Epoch 07 [ft ] | train loss 0.6540 acc 0.719 f1 0.717 || val loss 0.9504 acc 0.592 f1 0.630 | lr 1.00e-04
Epoch 0