In [2]:
%pip install torch torchvision timm scikit-learn pillow matplotlib seaborn tqdm -q

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip available: 22.3 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [1]:
# run_fastvit_t8_highres_v3.py
# Final high-resolution experiment for 24 IMC paper (with robust resume for scheduler)
# Author: <your-name>          Date: 2025-08-xx

import warnings, random, json, time
from pathlib import Path
warnings.filterwarnings("ignore")

import numpy as np
from tqdm import tqdm
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import timm
from sklearn.model_selection import train_test_split
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, confusion_matrix, classification_report)
import matplotlib.pyplot as plt, seaborn as sns

# ──────────────────────────────────────────────────────────────
# 1. CONFIGURATION
# ──────────────────────────────────────────────────────────────
MODEL_NAME     = "fastvit_t8"
ORIG_DATA_DIR  = Path("../datasets")
CKPT_DIR       = Path("trained_models")/MODEL_NAME
LOG_CSV        = Path("logs")/f"{MODEL_NAME}.csv"
OUT_JSON       = Path("evaluation_results")/f"{MODEL_NAME}_metrics.json"

IMG_SIZE       = 224
BATCH          = 8
ACC_STEPS      = 4
FROZEN_EPOCHS  = 5
TOTAL_EPOCHS   = 20
PATIENCE       = 5
LR             = 1e-4
WD             = 1e-4
WORKERS        = 4
SEED           = 42
USE_COMPILE    = False
DEVICE         = torch.device("cuda" if torch.cuda.is_available() else "cpu")
AMP            = torch.cuda.is_available()

for p in ["trained_models", "logs", "evaluation_results"]:
    Path(p).mkdir(exist_ok=True)

torch.manual_seed(SEED); np.random.seed(SEED); random.seed(SEED)
torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = False

# ──────────────────────────────────────────────────────────────
# 2. DATA TRANSFORMS & LOADERS
# ──────────────────────────────────────────────────────────────
def build_transforms():
    norm = transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    train_tf = transforms.Compose([transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8,1.0)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), norm])
    val_tf   = transforms.Compose([transforms.Resize(int(IMG_SIZE*1.12)), transforms.CenterCrop(IMG_SIZE), transforms.ToTensor(), norm])
    return train_tf, val_tf

full_ds = datasets.ImageFolder(ORIG_DATA_DIR)
CLASS_NAMES = full_ds.classes
NUM_CLASSES = len(CLASS_NAMES)
train_idx, val_idx = train_test_split(np.arange(len(full_ds.targets)), test_size=0.2, stratify=full_ds.targets, random_state=SEED)

# ──────────────────────────────────────────────────────────────
# 3. HELPERS
# ──────────────────────────────────────────────────────────────
def freeze_backbone(model, train_full=False):
    for n,p in model.named_parameters():
        p.requires_grad = train_full or any(k in n for k in ("head","fc","classifier"))

def maybe_compile(m):
    if USE_COMPILE and hasattr(torch,"compile"):
        try: return torch.compile(m, dynamic=False)
        except Exception as e: print("torch.compile disabled →", e)
    return m

def evaluate(model, loader):
    model.eval(); preds=[]; labels=[]
    with torch.no_grad(), torch.cuda.amp.autocast(enabled=AMP):
        for xb,yb in loader:
            preds.extend(model(xb.to(DEVICE)).argmax(1).cpu().numpy())
            labels.extend(yb.numpy())
    return np.array(preds), np.array(labels)

# ──────────────────────────────────────────────────────────────
# 4. TRAINING LOOP
# ──────────────────────────────────────────────────────────────
def main():
    train_tf, val_tf = build_transforms()
    train_ds = Subset(full_ds, train_idx); train_ds.dataset.transform = train_tf
    val_ds   = Subset(full_ds, val_idx);   val_ds.dataset.transform   = val_tf
    train_ld = DataLoader(train_ds, BATCH, True , num_workers=WORKERS, pin_memory=True)
    val_ld   = DataLoader(val_ds,   BATCH, False, num_workers=WORKERS, pin_memory=True)

    model = timm.create_model(MODEL_NAME, pretrained=True, num_classes=NUM_CLASSES).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler(enabled=AMP)

    start_ep, best_acc, patience_left = 0, 0., PATIENCE
    optimizer_state, scaler_state, scheduler_state = None, None, None # --- MODIFIED ---

    if (CKPT_DIR/"last.pth").is_file():
        chk = torch.load(CKPT_DIR/"last.pth", map_location=DEVICE)
        model.load_state_dict(chk["model"])
        start_ep       = chk["epoch"] + 1
        best_acc       = chk["best_acc"]
        patience_left  = chk["patience"]
        optimizer_state = chk["optim"]
        scaler_state    = chk["scaler"]
        scheduler_state = chk.get("scheduler") # --- MODIFIED --- .get() is safer
        print(f"↪ Resuming from epoch {start_ep}")

    if start_ep < FROZEN_EPOCHS:
        freeze_backbone(model, False)
        optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WD)
    else:
        freeze_backbone(model, True)
        optimizer = optim.AdamW(model.parameters(), lr=LR*0.5, weight_decay=WD)

    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_EPOCHS) # --- MODIFIED --- No last_epoch needed

    if optimizer_state: optimizer.load_state_dict(optimizer_state)
    if scaler_state: scaler.load_state_dict(scaler_state)
    if scheduler_state: scheduler.load_state_dict(scheduler_state) # --- MODIFIED ---

    model = maybe_compile(model)

    if start_ep==0 and LOG_CSV.exists(): LOG_CSV.unlink()

    for ep in range(start_ep, TOTAL_EPOCHS):
        if ep == FROZEN_EPOCHS:
            freeze_backbone(model, True)
            optimizer = optim.AdamW(model.parameters(), lr=LR*0.5, weight_decay=WD)
            # Recreate scheduler for the new optimizer
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_EPOCHS, last_epoch=ep-1)

        model.train(); seen=0; loss_sum=0; correct=0
        optimizer.zero_grad()
        for i,(xb,yb) in enumerate(tqdm(train_ld, desc=f"Epoch {ep+1}/{TOTAL_EPOCHS}")):
            xb,yb = xb.to(DEVICE), yb.to(DEVICE)
            with torch.cuda.amp.autocast(enabled=AMP):
                out = model(xb); loss = criterion(out,yb)/ACC_STEPS
            scaler.scale(loss).backward()
            if (i+1)%ACC_STEPS==0:
                scaler.step(optimizer); scaler.update(); optimizer.zero_grad()
            loss_sum += loss.item()*ACC_STEPS*xb.size(0)
            correct  += (out.argmax(1)==yb).sum().item()
            seen     += xb.size(0)
        train_loss, train_acc = loss_sum/seen, correct/seen

        preds, labels = evaluate(model, val_ld)
        val_acc = accuracy_score(labels, preds)
        scheduler.step()

        with open(LOG_CSV,"a") as f:
            if ep==0: f.write("epoch,train_loss,train_acc,val_acc,lr\n")
            f.write(f"{ep},{train_loss:.5f},{train_acc:.5f},{val_acc:.5f},{scheduler.get_last_lr()[0]:.6f}\n")
        print(f"val_acc={val_acc:.4f}  best={best_acc:.4f}  patience={patience_left}")

        # --- MODIFIED ---: Save scheduler state in checkpoint
        state = {"epoch":ep,"model":model.state_dict(),"optim":optimizer.state_dict(),
                 "scaler":scaler.state_dict(),"scheduler":scheduler.state_dict(),
                 "best_acc":best_acc,"patience":patience_left}
        CKPT_DIR.mkdir(parents=True, exist_ok=True)
        torch.save(state, CKPT_DIR/"last.pth")
        torch.save(state, CKPT_DIR/f"epoch{ep:03d}.pth")
        
        if val_acc > best_acc:
            best_acc = val_acc
            patience_left = PATIENCE
            torch.save(state, CKPT_DIR/"best.pth")
        else:
            patience_left -= 1
        
        if patience_left == 0 and ep >= FROZEN_EPOCHS:
            print("Early-stopping triggered.")
            break

    # Final Evaluation
    best_state = torch.load(CKPT_DIR/"best.pth", map_location=DEVICE)
    model.load_state_dict(best_state["model"])
    preds, labels = evaluate(model, val_ld)
    t0=time.time()
    with torch.no_grad():
        for xb,_ in val_ld: model(xb.to(DEVICE))
    latency = (time.time()-t0)/len(val_ds)
    metrics = dict(model=MODEL_NAME, image_size=IMG_SIZE, val_accuracy=accuracy_score(labels, preds), precision=precision_score(labels, preds, average="weighted", zero_division=0), recall=recall_score(labels, preds, average="weighted",  zero_division=0), f1_score=f1_score(labels, preds,  average="weighted", zero_division=0), inf_sec_per_img=latency, conf_matrix=confusion_matrix(labels, preds).tolist(), class_report=classification_report(labels, preds, target_names=CLASS_NAMES, zero_division=0, output_dict=True))
    OUT_JSON.parent.mkdir(exist_ok=True)
    json.dump(metrics, open(OUT_JSON,"w"), indent=2)
    print("✓ Metrics saved →", OUT_JSON)
    sns.heatmap(metrics["conf_matrix"], cmap="Blues", cbar=False, annot=False, xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
    plt.title(f"{MODEL_NAME} Confusion Matrix"); plt.tight_layout()
    plt.savefig(OUT_JSON.with_suffix(".png"), dpi=300); plt.close()
    print("✓ Confusion matrix plot saved.")
    print("Finished! Checkpoints in", CKPT_DIR)

if __name__ == "__main__":
    main()


↪ Resuming from epoch 19


Epoch 20/20: 100%|██████████| 1601/1601 [12:39<00:00,  2.11it/s]


val_acc=0.9966  best=0.9966  patience=2
✓ Metrics saved → evaluation_results\fastvit_t8_metrics.json
✓ Confusion matrix plot saved.
Finished! Checkpoints in trained_models\fastvit_t8
