In [None]:
# T-A: head-only fine-tuning (best-by macro-F1)
from pathlib import Path
import json, numpy as np, torch
from torch import nn, optim
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms, models
from sklearn.metrics import f1_score, confusion_matrix

# project paths (adjust if needed): relative to this file when run as a script,
# otherwise relative to the notebook's working directory
if "__file__" in globals():
    ROOT = Path(__file__).resolve().parent
else:
    ROOT = Path.cwd()
DATA_ROOT = ROOT / "asl_alphabet_train"   # ImageFolder tree, one sub-folder per class
ART = ROOT / "artifacts"                  # class_to_idx.json + split indices (read-only here)
CKPT = ROOT / "checkpoints"
CKPT.mkdir(exist_ok=True)
RESULTS = ROOT / "results"
RESULTS.mkdir(exist_ok=True)

# run configuration
SEED = 429
BATCH_SIZE = 64
IMG_SIZE = 224      # square side length images are resized to
NUM_WORKERS = 2     # DataLoader worker processes

torch.manual_seed(SEED)

# Device preference: Apple MPS, then CUDA, then CPU; the chosen accelerator is seeded too.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

if DEVICE.type == "mps":
    torch.mps.manual_seed(SEED)
elif DEVICE.type == "cuda":
    torch.cuda.manual_seed_all(SEED)

print("Device:", DEVICE)

In [None]:
# load class mapping and split indices
with open(ART / "class_to_idx.json") as f:
    class_to_idx = json.load(f)
# class names ordered by label index, so classes[i] is the name of label i
classes = [c for c, _ in sorted(class_to_idx.items(), key=lambda kv: kv[1])]
num_classes = len(classes)

train_idx = np.load(ART / "train_idx.npy")
val_idx   = np.load(ART / "val_idx.npy")
print("Classes:", num_classes, classes[:10], "...")

# transforms (ImageNet stats): deterministic resize + normalisation, no augmentation
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]
tfm = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# rebuild full dataset from folder structure
full_ds = datasets.ImageFolder(root=DATA_ROOT, transform=tfm)

# Sanity checks. The classifier head (num_classes outputs) and the saved split
# indices both assume exactly the label mapping stored in class_to_idx.json.
# ImageFolder derives labels from whatever sub-folders it finds, so an extra,
# missing or renamed folder would silently shift labels (or yield labels
# >= num_classes). A subset check does not catch that; require identical mappings.
assert full_ds.class_to_idx == class_to_idx, (
    "ImageFolder class mapping differs from class_to_idx.json: "
    f"extra={sorted(set(full_ds.classes) - set(class_to_idx))}, "
    f"missing={sorted(set(class_to_idx) - set(full_ds.classes))}"
)
# Split indices must be non-empty, address existing samples (Subset would silently
# accept negative indices), and must not overlap (train/val leakage).
for _split_name, _split_idx in (("train", train_idx), ("val", val_idx)):
    assert _split_idx.size > 0 and _split_idx.min() >= 0 and _split_idx.max() < len(full_ds), (
        f"{_split_name}_idx is empty or out of range for a dataset of {len(full_ds)} images"
    )
assert np.intersect1d(train_idx, val_idx).size == 0, "train_idx and val_idx overlap."

# fixed split subsets
train_ds = Subset(full_ds, train_idx.tolist())
val_ds   = Subset(full_ds, val_idx.tolist())

# Pinned host memory is meant for faster host->CUDA copies. On CPU-only runs PyTorch
# warns and ignores it, and MPS support depends on the PyTorch version, so enable it
# only for CUDA.
PIN_MEMORY = DEVICE.type == "cuda"
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print("Batches -> train:", len(train_loader), "val:", len(val_loader))

In [None]:
# ImageNet-pretrained ResNet-18 backbone
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Replace the 1000-way ImageNet classifier with a fresh num_classes-way linear head.
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

# Head-only fine-tuning: switch gradients off for every parameter, then back on
# for the new head alone.
model.requires_grad_(False)
model.fc.requires_grad_(True)


def _bn_eval(m):
    """Switch ``m`` to eval mode if it is a BatchNorm2d layer; ignore any other module."""
    if not isinstance(m, nn.BatchNorm2d):
        return
    m.eval()


# Frozen BatchNorms go to eval mode so they normalise with their stored pretrained
# running statistics instead of updating them. Note: a later model.train() call
# switches them back to training mode.
model.apply(_bn_eval)

model = model.to(DEVICE)
trainable = total = 0
for p in model.parameters():
    total += p.numel()
    if p.requires_grad:
        trainable += p.numel()
print(f"Trainable params: {trainable:,} / {total:,}")

In [None]:
# Loss: multi-class cross-entropy on raw logits.
criterion = nn.CrossEntropyLoss()
# Optimizer: Adam over parameters that still require gradients (with the backbone
# frozen above, that is only the classifier head).
optimizer = optim.Adam(filter(lambda prm: prm.requires_grad, model.parameters()), lr=1e-3)

# helper functions
def _batch_correct(logits, y_true):
    preds = logits.argmax(1)
    correct = (preds == y_true).sum().item()
    return preds, correct

# evaluate with F1 score
@torch.no_grad()
def evaluate_with_f1(model, loader):
    """Evaluate ``model`` over every batch of ``loader`` without tracking gradients.

    Relies on the module-level ``DEVICE``, ``criterion`` and ``num_classes``.

    Returns:
        ``(mean_loss, accuracy, macro_f1, cm)``:
        - ``mean_loss``: sample-weighted average of ``criterion`` over all batches;
        - ``accuracy``: fraction of argmax predictions equal to the targets;
        - ``macro_f1``: sklearn macro-averaged F1 (unweighted mean over the labels
          present in targets or predictions);
        - ``cm``: ``num_classes x num_classes`` confusion matrix (rows = true label).

    Raises:
        ValueError: if ``loader`` yields no samples (previously a ZeroDivisionError).

    Side effect: leaves ``model`` in eval mode.
    """
    model.eval()
    total, correct, running_loss = 0, 0, 0.0
    all_pred, all_true = [], []
    for xb, yb in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        logits = model(xb)
        # criterion returns a batch mean; re-weight by batch size so the final
        # average is per sample even when the last batch is smaller
        running_loss += criterion(logits, yb).item() * xb.size(0)
        preds, corr = _batch_correct(logits, yb)
        all_pred.extend(preds.cpu().tolist())
        all_true.extend(yb.cpu().tolist())
        correct += corr
        total += xb.size(0)
    if total == 0:
        raise ValueError("evaluate_with_f1: loader produced no samples")
    acc = correct / total
    # zero_division=0 gives the same number as sklearn's default ("warn" scores an
    # undefined precision/recall as 0) without an UndefinedMetricWarning every epoch
    # while some class is never predicted.
    macro_f1 = f1_score(all_true, all_pred, average="macro", zero_division=0)
    cm = confusion_matrix(all_true, all_pred, labels=list(range(num_classes)))
    return running_loss / total, acc, macro_f1, cm

In [None]:
from copy import deepcopy

# Number of epochs; extend if val F1 keeps improving.
EPOCHS = 10
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "val_f1": []}

best_val_f1 = -1.0   # any real macro-F1 (>= 0) beats this, so epoch 1 always sets a best
best_state = None

for epoch in range(1, EPOCHS + 1):
    # train
    # model.train() switches EVERY module to training mode, including the frozen
    # backbone's BatchNorm layers (undoing any earlier eval() on them). Left that way
    # they would normalise with per-batch statistics, so the head would train on
    # features unlike the ones it sees at eval time. They would also keep overwriting
    # the pretrained running mean/var, which end up in the saved state_dict. Put the
    # BatchNorms back into eval mode right after model.train().
    model.train()
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()

    tr_total, tr_correct, tr_running = 0, 0, 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        optimizer.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        # accumulate sample-weighted loss and correct counts for epoch averages
        tr_running += loss.item() * xb.size(0)
        _, corr = _batch_correct(logits, yb)
        tr_correct += corr
        tr_total += xb.size(0)

    tr_loss = tr_running / tr_total
    tr_acc  = tr_correct / tr_total

    # validate (loss, acc, macro-F1); evaluate_with_f1 also puts the model in eval mode
    va_loss, va_acc, va_f1, _ = evaluate_with_f1(model, val_loader)

    # log
    history["train_loss"].append(tr_loss)
    history["train_acc"].append(tr_acc)
    history["val_loss"].append(va_loss)
    history["val_acc"].append(va_acc)
    history["val_f1"].append(va_f1)

    # model selection by macro-F1 (strict '>' keeps the earliest epoch on ties)
    if va_f1 > best_val_f1:
        best_val_f1 = va_f1
        # deepcopy: state_dict() returns tensors sharing storage with the live
        # parameters, which later optimizer steps would overwrite
        best_state = deepcopy(model.state_dict())
        tag = "<= BEST"
    else:
        tag = ""
    print(f"Epoch {epoch:02d} | "
          f"train: loss {tr_loss:.4f}, acc {tr_acc:.3f} | "
          f"val: loss {va_loss:.4f}, acc {va_acc:.3f}, F1 {va_f1:.4f} {tag}")

# save best-by-F1 checkpoint
if best_state is None:
    raise RuntimeError(f"No epoch was run (EPOCHS={EPOCHS}); there is no checkpoint to save.")
best_path = CKPT / "best_TA.pt"
torch.save(best_state, best_path)
print("Saved best T-A to:", best_path, "| Best Val Macro-F1:", f"{best_val_f1:.4f}")

In [None]:
import pandas as pd

# Persist the per-epoch learning curves (one row per epoch) for plotting later.
pd.DataFrame.from_dict(history).to_csv(RESULTS / "TA_history.csv", index=False)

# Restore the best-by-macro-F1 weights from disk and recompute the validation
# confusion matrix with exactly those weights.
model.load_state_dict(torch.load(best_path, map_location=DEVICE))
va_f1_best, cm_best = evaluate_with_f1(model, val_loader)[2:]
np.save(RESULTS / "TA_val_confusion_matrix.npy", cm_best)
print("Val macro-F1 (best):", f"{va_f1_best:.4f}")