In [1]:
# --- Imports & basic config ---
import math
from pathlib import Path
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet18, ResNet18_Weights

# Per-channel mean/std passed to transforms.Normalize by every transform
# pipeline in this notebook, and stored alongside saved backbone checkpoints.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

def device_auto() -> torch.device:
    """Pick the compute device: CUDA first, then Apple MPS, otherwise CPU."""
    # `torch.backends.mps` is looked up defensively because it may be absent.
    mps_backend = getattr(torch.backends, "mps", None)

    if torch.cuda.is_available():
        chosen = "cuda"
    elif mps_backend and mps_backend.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    return torch.device(chosen)

# Resolve the compute device once; later cells read this module-level `device`.
device = device_auto()
device  # bare expression so the notebook displays the chosen device

device(type='cuda')

Dataloaders

In [None]:
# --- Dataloaders for aligned 224x224 crops (ImageFolder layout) ---

def make_dataloaders(
    train_dir: str,
    val_dir: str,
    batch_size: int = 64,
    num_workers: int = 4,
    aug: bool = True,
) -> Tuple[DataLoader, DataLoader, List[str]]:
    """
    Build train/val DataLoaders from two ImageFolder-style directories.

    No Resize/CenterCrop is applied: the files are expected to already be
    224x224. When `aug` is truthy the training pipeline gets a light
    horizontal flip + colour jitter in front of the tensor conversion;
    validation never gets augmentation.

    Returns:
        (train_loader, val_loader, classes) where `classes` is the class list
        reported by the training ImageFolder.

    Raises:
        AssertionError: if the train and val folders report different classes.
    """
    def _tensor_steps() -> list:
        # Fresh instances per pipeline: ToTensor followed by normalization.
        return [
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ]

    augment_steps = []
    if aug:
        augment_steps = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
        ]

    train_tfms = transforms.Compose(augment_steps + _tensor_steps())
    val_tfms = transforms.Compose(_tensor_steps())

    train_ds = datasets.ImageFolder(train_dir, transform=train_tfms)
    val_ds = datasets.ImageFolder(val_dir, transform=val_tfms)

    classes = train_ds.classes
    assert classes == val_ds.classes, "Train/val classes differ!"

    # Options shared by both loaders; only `shuffle` differs between them.
    shared_opts = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_ds, shuffle=True, **shared_opts)
    val_loader = DataLoader(val_ds, shuffle=False, **shared_opts)
    return train_loader, val_loader, classes

Build model & utilities (freeze FC, save backbone-only)

In [3]:
# --- Build ResNet18 with a replaceable head + helpers ---

def build_resnet18(num_classes: int) -> nn.Module:
    """
    Create an ImageNet-pretrained ResNet-18 whose final layer is replaced by
    Dropout(0.2) + Linear.

    The Linear layer emits a single logit when `num_classes == 2`
    (for BCEWithLogitsLoss) and `num_classes` logits otherwise
    (for CrossEntropyLoss).
    """
    net = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    head_width = 1 if num_classes == 2 else num_classes
    # `net.fc.in_features` is read from the stock head before it is replaced.
    net.fc = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(net.fc.in_features, head_width),
    )
    return net

def freeze_fc_only(model: nn.Module):
    """Turn off gradients for every parameter of `model.fc`; all other parameters are left untouched."""
    model.fc.requires_grad_(False)

def save_backbone_only(model: nn.Module, path: str, img_size: int = 224):
    """
    Write a checkpoint that holds every state-dict entry except the `fc.*`
    ones, so it can later be loaded into a model with a differently shaped head.

    The saved dict also records the architecture name, `img_size`, and the
    normalization mean/std constants.
    """
    backbone_weights = {}
    for key, tensor in model.state_dict().items():
        if key.startswith("fc."):
            continue  # classifier head is intentionally left out
        backbone_weights[key] = tensor

    checkpoint = dict(
        arch="resnet18",
        state_dict=backbone_weights,
        img_size=img_size,
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD,
    )
    torch.save(checkpoint, path)


Eval helper (loss + accuracy)

In [4]:
# --- Evaluation (loss + accuracy) ---

@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device, num_classes: int) -> Tuple[float, float]:
    """
    Run one pass over `loader` in eval mode and report (mean loss, accuracy).

    Binary problems (`num_classes == 2`) expect one logit per sample and use
    an unweighted BCEWithLogitsLoss with a 0.5 sigmoid threshold; everything
    else uses CrossEntropyLoss with argmax. The loss is averaged per sample
    (each batch mean is weighted by its batch size). An empty loader yields
    (0.0, 0.0). The model is left in eval mode.
    """
    model.eval()
    is_binary = num_classes == 2
    loss_fn = nn.BCEWithLogitsLoss() if is_binary else nn.CrossEntropyLoss()

    n_seen = 0
    n_correct = 0
    loss_sum = 0.0

    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)
        out = model(batch_x)

        if is_binary:
            # BCE wants float targets shaped like the (N, 1) logits.
            batch_loss = loss_fn(out, batch_y.float().unsqueeze(1))
            predicted = (torch.sigmoid(out) >= 0.5).long().squeeze(1)
        else:
            batch_loss = loss_fn(out, batch_y)
            predicted = out.argmax(dim=1)

        batch_n = batch_x.size(0)
        loss_sum += float(batch_loss.item()) * batch_n
        n_correct += int((predicted == batch_y).sum().item())
        n_seen += int(batch_n)

    denom = max(n_seen, 1)  # avoid division by zero on an empty loader
    return loss_sum / denom, n_correct / denom


Generic training loop (cosine schedule, early stop)

In [5]:
# --- Training loop with AdamW + CosineAnnealingLR + EarlyStopping ---

def train_one_phase(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    num_classes: int,
    epochs: int = 10,
    lr: float = 1e-3,
    weight_decay: float = 1e-2,
    patience: int = 5,
    pos_weight: Optional[float] = None,
    label_smoothing: float = 0.05,
    cosine_eta_min_scale: float = 0.01,
    desc: str = "Train",
) -> Tuple[nn.Module, float]:
    """
    Train `model` for up to `epochs` epochs with AdamW, a per-epoch cosine LR
    schedule and early stopping on validation accuracy.

    Args:
        model: network to train; only parameters with `requires_grad=True`
            at call time are handed to the optimizer.
        train_loader / val_loader: iterables of `(images, integer targets)`.
        device: device the batches are moved to. Mixed precision is enabled
            only when `device.type == "cuda"`.
        num_classes: 2 selects BCEWithLogitsLoss on a single logit
            (optionally weighted by `pos_weight`); anything else selects
            CrossEntropyLoss with `label_smoothing`.
        epochs, lr, weight_decay: optimizer / schedule settings.
        patience: stop after this many consecutive epochs without a strictly
            better validation accuracy.
        cosine_eta_min_scale: final LR of the cosine schedule as a fraction
            of `lr`.
        desc: prefix used in the per-epoch log lines.

    Returns:
        `(model, best_acc)` — the same model object with the weights of the
        best-validation-accuracy epoch loaded back, and that accuracy.
    """
    # ----- Loss
    if num_classes == 2:
        criterion = nn.BCEWithLogitsLoss(
            pos_weight=(torch.tensor([pos_weight], device=device) if pos_weight is not None else None)
        )
    else:
        criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    # ----- Optimizer over current trainable params only
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)

    # ----- Cosine LR schedule (stepped once per epoch)
    eta_min = lr * cosine_eta_min_scale
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=eta_min)

    # ----- Mixed precision: active on CUDA only, a transparent no-op elsewhere
    use_amp = device.type == "cuda"
    scaler = _make_grad_scaler(use_amp)

    best_acc, best_state, no_improve = -1.0, None, 0

    for epoch in range(1, epochs + 1):
        model.train()
        total, correct, total_loss = 0, 0, 0.0

        for imgs, targets in train_loader:
            imgs = imgs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            # torch.autocast("cuda", ...) is the non-deprecated spelling of
            # torch.cuda.amp.autocast(...).
            with torch.autocast(device_type="cuda", enabled=use_amp):
                logits = model(imgs)
                if num_classes == 2:
                    targets_f = targets.float().unsqueeze(1)
                    loss = criterion(logits, targets_f)
                    preds = (torch.sigmoid(logits) >= 0.5).long().squeeze(1)
                else:
                    loss = criterion(logits, targets)
                    preds = logits.argmax(dim=1)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += float(loss.item()) * imgs.size(0)
            correct    += int((preds == targets).sum().item())
            total      += int(imgs.size(0))

        scheduler.step()
        train_loss = total_loss / max(total, 1)
        train_acc  = correct / max(total, 1)
        val_loss, val_acc = evaluate(model, val_loader, device, num_classes)

        print(f"[{desc}] Epoch {epoch:03d}/{epochs} | "
              f"train {train_loss:.4f}/{train_acc:.4f} | "
              f"val {val_loss:.4f}/{val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            best_state = _snapshot_state(model)
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                print(f"[{desc}] Early stopping (patience={patience}).")
                break

    if best_state is not None:
        model.load_state_dict(best_state, strict=True)
    return model, best_acc


def _snapshot_state(model: nn.Module) -> dict:
    """Return an independent CPU copy of the model's state dict."""
    # `.cpu()` alone returns the very same tensor when the model already lives
    # on the CPU, so the "best" snapshot would keep changing with later
    # optimizer steps. `copy=True` forces a real copy on every device.
    return {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}


def _make_grad_scaler(enabled: bool):
    """Build a CUDA gradient scaler, preferring the non-deprecated `torch.amp` entry point."""
    scaler_cls = getattr(torch.amp, "GradScaler", None)
    if scaler_cls is not None:
        return scaler_cls("cuda", enabled=enabled)
    # Older PyTorch releases only ship the CUDA-namespaced class.
    return torch.cuda.amp.GradScaler(enabled=enabled)


Phase A: Train backbone only and save backbone-only checkpoint

In [6]:
# === PHASE A (data): dataset + loaders with optional 80/20 split ===
import copy
from pathlib import Path
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import datasets, transforms

# ---- Config you may tweak ----
DATA_DIR     = "data"   # parent folder; expects train/val OR single folder of class subdirs
BATCH_SIZE   = 64
WORKERS      = 4
USE_AUG      = True
SEED         = 42               # for deterministic split

# ---- Transforms (aligned 224x224 crops; no resize/crop needed) ----
# IMAGENET_MEAN / IMAGENET_STD come from the imports & config cell at the top.
_augment_steps = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
] if USE_AUG else []
_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
]

train_tfms = transforms.Compose(_augment_steps + _tensor_steps)
val_tfms   = transforms.Compose(list(_tensor_steps))

# ---- Dataset discovery ----
data_dir  = Path(DATA_DIR)
train_dir = data_dir / "train"
val_dir   = data_dir / "val"

if train_dir.is_dir() and val_dir.is_dir():
    # Case A: explicit train/val subdirs
    train_ds = datasets.ImageFolder(train_dir, transform=train_tfms)
    val_ds   = datasets.ImageFolder(val_dir,   transform=val_tfms)
    # Label indices are derived from each folder's own class list, so differing
    # lists would silently mislabel validation data (same guard as make_dataloaders).
    assert train_ds.classes == val_ds.classes, "Train/val classes differ!"
else:
    # Case B: single folder with class subdirs → split 80/20
    full_ds  = datasets.ImageFolder(data_dir, transform=train_tfms)
    n_val    = max(1, int(0.2 * len(full_ds)))
    n_train  = len(full_ds) - n_val
    # random_split returns two Subset objects over `full_ds`.
    train_ds, val_split = random_split(
        full_ds, [n_train, n_val],
        generator=torch.Generator().manual_seed(SEED)
    )
    # Rebuild a val dataset that shares the same files but uses val transforms.
    # The deep copy keeps the training subset's (augmenting) transform untouched.
    val_ds = Subset(copy.deepcopy(full_ds), val_split.indices)
    val_ds.dataset.transform = val_tfms

# ---- Class info ----
classes = train_ds.dataset.classes if isinstance(train_ds, Subset) else train_ds.classes
num_classes = len(classes)
print(f"Classes ({num_classes}): {classes}")

# ---- DataLoaders ----
pin_mem = torch.cuda.is_available()
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=WORKERS, pin_memory=pin_mem)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=WORKERS, pin_memory=pin_mem)

# ---- (Optional) compute pos_weight for binary imbalance (neg/pos) ----
POS_WEIGHT = None
if num_classes == 2:
    def _count_samples(ds):
        """Return (neg, pos) label counts, read from ImageFolder.samples without loading images."""
        if isinstance(ds, Subset):
            # ds.dataset is ImageFolder; ds.indices selects subset
            labels = [ds.dataset.samples[i][1] for i in ds.indices]
        else:
            labels = [y for _, y in ds.samples]
        return labels.count(0), labels.count(1)

    neg, pos = _count_samples(train_ds)
    if pos > 0:
        POS_WEIGHT = neg / float(pos)
        print(f"[Binary] Train counts: neg={neg}, pos={pos} → pos_weight={POS_WEIGHT:.4f}")
    else:
        print("[Binary] WARNING: no positive samples in train set; pos_weight not set.")


Classes (2): ['Negative', 'Positive']
[Binary] Train counts: neg=1619, pos=253 → pos_weight=6.3992


In [9]:
# --- Phase A settings (WEIGHT_DECAY, PATIENCE, LABEL_SMOOTHING,
#     COSINE_ETA_MIN_SCALE and OUT_DIR are read again by later cells) ---
OUT_DIR = "outputs"

BATCH_SIZE, WORKERS, USE_AUG = 64, 4, True
EPOCHS_BACKBONE, LR_BACKBONE = 10, 1e-3
WEIGHT_DECAY = 1e-2
PATIENCE = 5
LABEL_SMOOTHING = 0.05
COSINE_ETA_MIN_SCALE = 0.01

# Pretrained ResNet-18 with a fresh head; the head is frozen so that only
# the backbone receives gradient updates in this phase.
model = build_resnet18(num_classes)
freeze_fc_only(model)
model = model.to(device)

# --- Train backbone only ---
backbone_phase_kwargs = dict(
    epochs=EPOCHS_BACKBONE,
    lr=LR_BACKBONE,
    weight_decay=WEIGHT_DECAY,
    patience=PATIENCE,
    pos_weight=POS_WEIGHT,
    label_smoothing=LABEL_SMOOTHING,
    cosine_eta_min_scale=COSINE_ETA_MIN_SCALE,
    desc="Backbone-only",
)
model, best_acc = train_one_phase(
    model, train_loader, val_loader, device, num_classes,
    **backbone_phase_kwargs,
)

# --- Save backbone-only checkpoint (drops fc.*) ---
out_root = Path(OUT_DIR)
out_root.mkdir(parents=True, exist_ok=True)
BACKBONE_CKPT = str(out_root / "resnet18_backbone_only.pt")
save_backbone_only(model, BACKBONE_CKPT, img_size=224)
print("Saved:", BACKBONE_CKPT, "| Best val acc:", f"{best_acc:.4f}")

  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
  with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):


[Backbone-only] Epoch 001/10 | train 0.5326/0.8686 | val 9.6868/0.3041
[Backbone-only] Epoch 002/10 | train 0.2293/0.9647 | val 0.0892/0.9679
[Backbone-only] Epoch 003/10 | train 0.0990/0.9840 | val 0.0690/0.9743
[Backbone-only] Epoch 004/10 | train 0.0917/0.9840 | val 0.6385/0.8223
[Backbone-only] Epoch 005/10 | train 0.0698/0.9904 | val 0.0828/0.9829
[Backbone-only] Epoch 006/10 | train 0.0516/0.9915 | val 0.0483/0.9829
[Backbone-only] Epoch 007/10 | train 0.0370/0.9963 | val 0.1074/0.9700
[Backbone-only] Epoch 008/10 | train 0.0325/0.9963 | val 0.0230/0.9914
[Backbone-only] Epoch 009/10 | train 0.0177/0.9984 | val 0.0234/0.9936
[Backbone-only] Epoch 010/10 | train 0.0117/0.9989 | val 0.0245/0.9914
Saved: outputs\resnet18_backbone_only.pt | Best val acc: 0.9936


Phase B (step 1): Attach new head, warm-start by training head only

In [10]:
# === PHASE B-1: Load backbone-only ckpt, attach new head, train head-only ===
# You can point TRAIN_DIR/VAL_DIR to a new dataset if desired.
BACKBONE_CKPT = str(Path(OUT_DIR) / "resnet18_backbone_only.pt")

# Rebuild base & head for current K.
# weights=None: every backbone tensor is overwritten from the checkpoint below.
model_ft = resnet18(weights=None)
in_feats = model_ft.fc.in_features
head_out = 1 if num_classes == 2 else num_classes   # binary -> single logit
model_ft.fc = nn.Sequential(nn.Dropout(0.2), nn.Linear(in_feats, head_out))

# Load the backbone weights (fc.* are intentionally missing).
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all this checkpoint holds, so a tampered file cannot execute code on load.
ckpt = torch.load(BACKBONE_CKPT, map_location="cpu", weights_only=True)
missing, unexpected = model_ft.load_state_dict(ckpt["state_dict"], strict=False)
print("load_state_dict missing:", missing, "unexpected:", unexpected)

# strict=False hides real mismatches, so verify that only the head is uninitialised.
assert not unexpected, f"Unexpected keys in backbone checkpoint: {unexpected}"
assert all(k.startswith("fc.") for k in missing), f"Backbone keys missing from checkpoint: {missing}"

model_ft = model_ft.to(device)

# --- Warm-start: head only ---
model_ft.requires_grad_(False)      # freeze everything ...
model_ft.fc.requires_grad_(True)    # ... then re-enable just the new head

EPOCHS_HEAD = 5
LR_HEAD     = 1e-3

model_ft, best_acc_head = train_one_phase(
    model_ft, train_loader, val_loader, device, num_classes,
    epochs=EPOCHS_HEAD,
    lr=LR_HEAD,
    weight_decay=0.0,                 # usually no WD for small head
    patience=max(2, PATIENCE // 2),
    pos_weight=POS_WEIGHT,
    label_smoothing=LABEL_SMOOTHING,
    cosine_eta_min_scale=COSINE_ETA_MIN_SCALE,
    desc="Head-only"
)

print("Best val acc (head-only):", f"{best_acc_head:.4f}")

load_state_dict missing: ['fc.1.weight', 'fc.1.bias'] unexpected: []


  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
  with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):


[Head-only] Epoch 001/5 | train 0.1777/0.9583 | val 0.0256/0.9957
[Head-only] Epoch 002/5 | train 0.0223/0.9989 | val 0.0203/0.9957
[Head-only] Epoch 003/5 | train 0.0123/1.0000 | val 0.0189/0.9957
[Head-only] Early stopping (patience=2).
Best val acc (head-only): 0.9957


Phase B (step 2): Unfreeze backbone and fine-tune end-to-end (differential LRs)

In [11]:
# === PHASE B-2: Unfreeze backbone and fine-tune end-to-end ===

for p in model_ft.parameters():
    p.requires_grad = True

LR_FINETUNE_HEAD     = 5e-4
LR_FINETUNE_BACKBONE = 1e-4
EPOCHS_FINETUNE      = 10

# Build optimizer with differential LRs: the new head moves faster than the
# already-trained backbone and is exempt from weight decay.
head_params     = [p for n, p in model_ft.named_parameters() if n.startswith("fc.")]
backbone_params = [p for n, p in model_ft.named_parameters() if not n.startswith("fc.")]
optim_params = [
    {"params": head_params,     "lr": LR_FINETUNE_HEAD,     "weight_decay": 0.0},
    {"params": backbone_params, "lr": LR_FINETUNE_BACKBONE, "weight_decay": WEIGHT_DECAY},
]
optimizer = torch.optim.AdamW(optim_params)

# Loss for the fine-tune phase
if num_classes == 2:
    criterion_ft = nn.BCEWithLogitsLoss(
        pos_weight=(torch.tensor([POS_WEIGHT], device=device) if POS_WEIGHT is not None else None)
    )
else:
    criterion_ft = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)

# CosineAnnealingLR takes a single eta_min, so both groups decay toward the
# same floor, derived from the larger of the two learning rates.
eta_min = max(LR_FINETUNE_HEAD, LR_FINETUNE_BACKBONE) * COSINE_ETA_MIN_SCALE
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_FINETUNE, eta_min=eta_min)

# Mixed precision on CUDA only. torch.amp.GradScaler / torch.autocast replace the
# deprecated torch.cuda.amp spellings; fall back on releases that lack the former.
use_amp = device.type == "cuda"
_grad_scaler_cls = getattr(torch.amp, "GradScaler", None)
scaler = (_grad_scaler_cls("cuda", enabled=use_amp) if _grad_scaler_cls is not None
          else torch.cuda.amp.GradScaler(enabled=use_amp))

best_acc_finetune, best_state, no_improve = -1.0, None, 0

print("Fine-tuning...")
for epoch in range(1, EPOCHS_FINETUNE + 1):
    model_ft.train()
    total, correct, total_loss = 0, 0, 0.0

    for imgs, targets in train_loader:
        imgs = imgs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type="cuda", enabled=use_amp):
            logits = model_ft(imgs)
            if num_classes == 2:
                targets_f = targets.float().unsqueeze(1)
                loss = criterion_ft(logits, targets_f)
                preds = (torch.sigmoid(logits) >= 0.5).long().squeeze(1)
            else:
                loss = criterion_ft(logits, targets)
                preds = logits.argmax(dim=1)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += float(loss.item()) * imgs.size(0)
        correct    += int((preds == targets).sum().item())
        total      += int(imgs.size(0))

    scheduler.step()
    train_loss = total_loss / max(total, 1)
    train_acc  = correct / max(total, 1)

    val_loss, val_acc = evaluate(model_ft, val_loader, device, num_classes)
    print(f"[Finetune] Epoch {epoch:03d}/{EPOCHS_FINETUNE} | "
          f"train {train_loss:.4f}/{train_acc:.4f} | "
          f"val {val_loss:.4f}/{val_acc:.4f}")

    if val_acc > best_acc_finetune:
        best_acc_finetune = val_acc
        # copy=True forces a real copy: a bare .cpu() returns the live tensor when
        # the model is already on the CPU, and the snapshot would then be
        # overwritten by later optimizer steps.
        best_state = {k: v.detach().to("cpu", copy=True) for k, v in model_ft.state_dict().items()}
        no_improve = 0
    else:
        no_improve += 1
        if no_improve >= PATIENCE:
            print(f"[Finetune] Early stopping (patience={PATIENCE}).")
            break

# Restore the best-validation-accuracy weights before saving.
if best_state is not None:
    model_ft.load_state_dict(best_state, strict=True)

FINAL_CKPT = str(Path(OUT_DIR) / "resnet18_finetuned_with_head.pt")
torch.save(model_ft.state_dict(), FINAL_CKPT)
print("Saved:", FINAL_CKPT, "| Best val acc (finetune):", f"{best_acc_finetune:.4f}")


  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))


Fine-tuning...


  with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):


[Finetune] Epoch 001/10 | train 0.0171/0.9989 | val 0.0161/0.9936
[Finetune] Epoch 002/10 | train 0.0040/0.9989 | val 0.0173/0.9957
[Finetune] Epoch 003/10 | train 0.0023/1.0000 | val 0.0185/0.9936
[Finetune] Epoch 004/10 | train 0.0044/1.0000 | val 0.0247/0.9872
[Finetune] Epoch 005/10 | train 0.0037/0.9989 | val 0.0216/0.9914
[Finetune] Epoch 006/10 | train 0.0013/1.0000 | val 0.0191/0.9936
[Finetune] Epoch 007/10 | train 0.0013/1.0000 | val 0.0109/0.9936
[Finetune] Early stopping (patience=5).
Saved: outputs\resnet18_finetuned_with_head.pt | Best val acc (finetune): 0.9957


Inference helper

In [7]:
# --- Inference helper (binary & multiclass) ---
from PIL import Image

# Same tensor conversion + normalization as the validation pipeline (no augmentation).
infer_tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

def predict_image(model: nn.Module, img_path: str, classes: List[str]) -> dict:
    """
    Classify a single image file.

    The image is converted to RGB and normalized; it is NOT resized, matching
    the training pipeline. The model is switched to eval mode and left there.

    Args:
        model: trained network; a single-logit head is assumed when
            `len(classes) == 2`, otherwise one logit per class.
        img_path: path of the image to classify.
        classes: class names, indexed by predicted label.

    Returns:
        Binary: {"pred_idx", "pred_class", "prob_positive"} using a 0.5
        threshold on the sigmoid output.
        Multi-class: {"pred_idx", "pred_class", "probs"} with the softmax
        probabilities as a list.
    """
    model.eval()

    # Feed the input on the device the model actually lives on; the global
    # `device` is only a fallback for models without parameters.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    # Context manager closes the file handle; convert() returns an independent image.
    with Image.open(img_path) as raw:
        img = raw.convert("RGB")
    x = infer_tfms(img).unsqueeze(0).to(target_device)

    with torch.no_grad():
        logits = model(x)
        if len(classes) == 2:
            p = torch.sigmoid(logits).item()
            pred_idx = int(p >= 0.5)
            return {"pred_idx": pred_idx, "pred_class": classes[pred_idx], "prob_positive": float(p)}
        else:
            probs = torch.softmax(logits, dim=1).squeeze(0)
            pred_idx = int(probs.argmax().item())
            return {"pred_idx": pred_idx, "pred_class": classes[pred_idx], "probs": probs.cpu().tolist()}


In [16]:
# Single-image smoke test of the fine-tuned model; "t/1.jpg" is a relative
# path, so it is resolved against the kernel's current working directory.
predict_image(model_ft, "t/1.jpg", classes)

{'pred_idx': 1, 'pred_class': 'Positive', 'prob_positive': 0.7976053953170776}