<a href="https://colab.research.google.com/github/Rafi076/RTFER/blob/main/ResNet34_FTA_GPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Block 0 — Setup & Imports**

In [None]:
# Block 0: Setup & Imports
import os, zipfile, random, copy
import numpy as np
from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Reproducibility
def set_seed(seed=42):
    """Seed every RNG the notebook relies on and pin cuDNN to deterministic kernels.

    Covers Python's ``random``, NumPy's global RNG and PyTorch (CPU plus all
    CUDA devices). Turning ``cudnn.benchmark`` off stops cuDNN from auto-tuning
    (and therefore possibly switching) convolution algorithms between runs.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed everything once, then pick the GPU when one is visible.
set_seed(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


Device: cpu


**Block 1 — Access ZIP & Detect Folders**

In [None]:
# Block 1: Access ZIP & Detect Folders (robust)
ZIP_PATH = "/content/FER-2013.zip"
ROOT     = "/content/FER-2013"   # extraction root

# 1) Unzip once: skipped when ROOT already exists from a previous run.
#    Fail here with a clear message instead of letting later cells trip over
#    missing split folders when neither the folder nor the archive is present.
if not os.path.isdir(ROOT):
    if not os.path.exists(ZIP_PATH):
        raise FileNotFoundError(
            f"Neither the dataset folder {ROOT!r} nor the archive {ZIP_PATH!r} exists; "
            "upload FER-2013.zip first."
        )
    with zipfile.ZipFile(ZIP_PATH, "r") as z:
        z.extractall(ROOT)
    print("Unzipped FER-2013 to:", ROOT)

# 2) Find the DEEPEST folder that directly contains a split folder
def find_base(start, split_names=("train", "test", "PublicTest", "PrivateTest")):
    """Return the deepest directory under ``start`` that has a split sub-folder.

    Archives are often wrapped in one or more extra top-level folders. This
    walks ``start`` and picks the most nested directory that directly contains
    a folder named in ``split_names``. Ties at equal depth go to the directory
    visited last by ``os.walk``. Returns ``start`` itself when nothing matches.
    """
    wanted = set(split_names)
    best, best_depth = start, None
    for root, dirs, _files in os.walk(start):
        if wanted.isdisjoint(dirs):
            continue
        # Count os.sep on a normalized path so depth is right on every platform
        # (a hard-coded "/" split sees every Windows path as depth 1).
        depth = os.path.normpath(root).count(os.sep)
        if best_depth is None or depth >= best_depth:   # ">=" keeps the last tie
            best, best_depth = root, depth
    return best

BASE = find_base(ROOT)
print("Base folder detected:", BASE)

# 3) Prefer official split names if present
TRAIN_DIR_OFF   = os.path.join(BASE, "train")
PUBLIC_DIR_OFF  = os.path.join(BASE, "PublicTest")
PRIVATE_DIR_OFF = os.path.join(BASE, "PrivateTest")

HAS_OFFICIAL = all(os.path.isdir(p) for p in [TRAIN_DIR_OFF, PUBLIC_DIR_OFF, PRIVATE_DIR_OFF])
print("Official splits found:", HAS_OFFICIAL)

# 4) Fallback for archives that only ship train/ and test/
TRAIN_SPLIT = os.path.join(BASE, "train_split")
VAL_SPLIT   = os.path.join(BASE, "val_split")
TEST_DIR_FALLBACK = os.path.join(BASE, "test")
VAL_FRACTION = 0.1  # share of every class held out from train/ for validation
SPLIT_SEED   = 42


def _make_stratified_split(src, train_out, val_out, val_frac=VAL_FRACTION, seed=SPLIT_SEED):
    """Partition each class folder of ``src`` into disjoint train/val folders.

    Every image in ``src/<class>/`` is linked (symlinked, or copied when
    symlinks are unavailable) into exactly one of ``val_out/<class>/`` or
    ``train_out/<class>/``, with ``round(n * val_frac)`` files per class
    going to validation. Filenames are sorted and then shuffled with a private
    seeded RNG, so the split is reproducible. Existing destination entries are
    skipped, so re-running on partially built folders only fills in the gaps.
    """
    import shutil

    rng = random.Random(seed)
    for cls in sorted(os.listdir(src)):
        cls_dir = os.path.join(src, cls)
        if not os.path.isdir(cls_dir):
            continue
        files = sorted(f for f in os.listdir(cls_dir)
                       if f.lower().endswith((".png", ".jpg", ".jpeg")))
        rng.shuffle(files)
        n_val = int(round(len(files) * val_frac))
        for i, fname in enumerate(files):
            dst_dir = os.path.join(val_out if i < n_val else train_out, cls)
            os.makedirs(dst_dir, exist_ok=True)
            dst = os.path.join(dst_dir, fname)
            if os.path.lexists(dst):
                continue
            # Absolute targets keep the links valid after the folders are renamed.
            src_file = os.path.abspath(os.path.join(cls_dir, fname))
            try:
                os.symlink(src_file, dst)
            except OSError:
                shutil.copy2(src_file, dst)


if HAS_OFFICIAL:
    TRAIN_DIR = TRAIN_DIR_OFF
    VAL_DIR   = PUBLIC_DIR_OFF
    TEST_DIR  = PRIVATE_DIR_OFF  # FINAL TEST
else:
    # Never validate on the training images. Otherwise early stopping and
    # best-checkpoint selection just track training accuracy. With no
    # validation folder at all, carve a disjoint, stratified one out of train/.
    no_val_source = not (os.path.isdir(TRAIN_SPLIT) or os.path.isdir(VAL_SPLIT)
                         or os.path.isdir(PUBLIC_DIR_OFF))
    if no_val_source and os.path.isdir(TRAIN_DIR_OFF):
        # Build under temporary names, then rename, so an interrupted run never
        # leaves a half-populated split that a later run would trust. The train
        # part is renamed first: if only it exists, VAL_DIR below points at a
        # missing folder and the dataset cell fails loudly instead of leaking.
        tmp_train, tmp_val = TRAIN_SPLIT + ".tmp", VAL_SPLIT + ".tmp"
        _make_stratified_split(TRAIN_DIR_OFF, tmp_train, tmp_val)
        os.replace(tmp_train, TRAIN_SPLIT)
        os.replace(tmp_val, VAL_SPLIT)
        print(f"Created stratified {VAL_FRACTION:.0%} validation split:", VAL_SPLIT)
    TRAIN_DIR = TRAIN_SPLIT if os.path.isdir(TRAIN_SPLIT) else TRAIN_DIR_OFF
    VAL_DIR   = VAL_SPLIT   if os.path.isdir(VAL_SPLIT)   else PUBLIC_DIR_OFF
    TEST_DIR  = TEST_DIR_FALLBACK

print("TRAIN_DIR:", TRAIN_DIR)
print("VAL_DIR:  ", VAL_DIR)
print("TEST_DIR: ", TEST_DIR)


Base folder detected: /content/FER-2013/FER-2013
Official splits found: False
TRAIN_DIR: /content/FER-2013/FER-2013/train
VAL_DIR:   /content/FER-2013/FER-2013/train
TEST_DIR:  /content/FER-2013/FER-2013/test


**Block 2 — Quick Sanity Check (counts)**

In [None]:
# Block 2: Sanity Check: class names and image counts
def count_images(root):
    """Count image files in each immediate sub-folder of ``root``.

    Returns ``(total, per_class)``. ``per_class`` maps each sub-directory name,
    in sorted order, to the number of its entries whose names end in
    .png/.jpg/.jpeg (case-insensitive). A missing ``root`` yields ``(0, {})``.
    """
    if not os.path.isdir(root):
        return 0, {}
    exts = ('.png', '.jpg', '.jpeg')
    per_class = {
        name: sum(1 for f in os.listdir(os.path.join(root, name)) if f.lower().endswith(exts))
        for name in sorted(os.listdir(root))
        if os.path.isdir(os.path.join(root, name))
    }
    return sum(per_class.values()), per_class

# Report per-split totals and per-class counts.
_split_dirs = (("Train", TRAIN_DIR), ("PublicTest/Val", VAL_DIR), ("PrivateTest/Final", TEST_DIR))
for label, folder in _split_dirs:
    n_total, by_class = count_images(folder)
    print(f"{label}: total={n_total}, per_class={by_class}")


Train: total=28709, per_class={'angry': 3995, 'disgust': 436, 'fear': 4097, 'happy': 7215, 'neutral': 4965, 'sad': 4830, 'surprise': 3171}
PublicTest/Val: total=28709, per_class={'angry': 3995, 'disgust': 436, 'fear': 4097, 'happy': 7215, 'neutral': 4965, 'sad': 4830, 'surprise': 3171}
PrivateTest/Final: total=7178, per_class={'angry': 958, 'disgust': 111, 'fear': 1024, 'happy': 1774, 'neutral': 1233, 'sad': 1247, 'surprise': 831}


**Block 3 — Transforms (CLAHE + augmentations + TenCrop TTA)**

In [None]:
# Block 3: Transforms
# ImageNet channel statistics used by Normalize in both pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]


class CLAHE_PIL:
    """Apply CLAHE to the grayscale version of a PIL image and return it as RGB.

    Args:
        clip: OpenCV CLAHE ``clipLimit``.
        grid: OpenCV CLAHE ``tileGridSize``.
    """

    def __init__(self, clip=2.0, grid=(8,8)):
        self.clip = clip
        self.grid = grid

    def __call__(self, img: Image.Image):
        gray = np.array(img.convert("L"))
        # The CLAHE operator is built per call, so instances hold only plain Python attributes.
        equalized = cv2.createCLAHE(clipLimit=self.clip, tileGridSize=self.grid).apply(gray)
        # Replicate the single channel so downstream transforms see a 3-channel RGB image.
        return Image.fromarray(cv2.cvtColor(equalized, cv2.COLOR_GRAY2RGB))

# Train pipeline: CLAHE + strong augs
# RandomResizedCrop's `scale` is the fraction of the source image's area that is
# cropped (before resizing to 48x48). An area above 1.0 can never fit inside the
# image, so torchvision rejects such draws. After 10 failed draws it falls back
# to a center crop. (0.8, 1.0) states the range that is actually usable and
# avoids the extra fallbacks. For a real zoom-out, use RandomAffine(scale=...).
train_transform = transforms.Compose([
    CLAHE_PIL(clip=2.0, grid=(8,8)),
    transforms.Resize(56),
    transforms.RandomResizedCrop(48, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Eval pipeline: CLAHE + TenCrop(48); stack 10 crops
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)


def _stack_normalized_crops(crops):
    """Convert each TenCrop PIL crop to a normalized tensor and stack them along a new dim 0."""
    return torch.stack([_normalize(_to_tensor(crop)) for crop in crops])


eval_transform = transforms.Compose([
    CLAHE_PIL(clip=2.0, grid=(8,8)),
    transforms.Resize(56),
    transforms.TenCrop(48),
    transforms.Lambda(_stack_normalized_crops),
])


**Block 4 — Datasets & Dataloaders**

In [None]:
# Block 4: Datasets & Dataloaders (fixed class mapping)
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Fail fast if any split folder is missing.
for _label, _folder in (("TRAIN_DIR", TRAIN_DIR), ("TEST_DIR", TEST_DIR), ("VAL_DIR", VAL_DIR)):
    assert os.path.isdir(_folder), f"Missing {_label}: {_folder}"

# One ImageFolder per split; each indexes classes from its own sorted sub-folders.
train_ds = datasets.ImageFolder(TRAIN_DIR, transform=train_transform)
val_ds, test_ds = (datasets.ImageFolder(d, transform=eval_transform) for d in (VAL_DIR, TEST_DIR))

# The training split's class order is the reference label mapping for every split.
fixed_classes = train_ds.classes
fixed_map = dict(zip(fixed_classes, range(len(fixed_classes))))

def remap_dataset_targets(ds, fixed_map):
    """Re-index a folder dataset's labels onto a shared class-name mapping.

    ImageFolder numbers classes by the sorted sub-folders of *its own* root, so
    two splits with different folder sets can disagree on what label ``k``
    means. This rewrites every stored label from the dataset's local index to
    ``fixed_map[class_name]``.

    Args:
        ds: dataset exposing ``classes`` and any of ``targets``, ``samples``
            (``[(path, label), ...]``) and ``imgs``, as torchvision's
            ``DatasetFolder``/``ImageFolder`` do. Modified in place.
        fixed_map: mapping class name -> global label index.

    Returns:
        The same ``ds`` object, for chaining.

    Raises:
        ValueError: if ``ds`` contains a class name that is not in ``fixed_map``.
    """
    unknown = [c for c in ds.classes if c not in fixed_map]
    if unknown:
        raise ValueError(
            f"Classes {unknown} are not in the reference mapping {sorted(fixed_map)}"
        )
    translate = {i: fixed_map[c] for i, c in enumerate(ds.classes)}

    # ImageFolder stores ``imgs`` as the same list object as ``samples``;
    # remember that before rebinding ``samples`` so ``imgs`` is not left stale.
    imgs_aliases_samples = (hasattr(ds, "imgs") and hasattr(ds, "samples")
                            and ds.imgs is ds.samples)
    if hasattr(ds, "targets"):
        ds.targets = [translate[t] for t in ds.targets]
    if hasattr(ds, "samples"):
        ds.samples = [(p, translate[t]) for p, t in ds.samples]
    if hasattr(ds, "imgs"):
        ds.imgs = (ds.samples if imgs_aliases_samples
                   else [(p, translate[t]) for p, t in ds.imgs])

    # classes[i] must name global label i, whatever order fixed_map was built in.
    ds.classes = sorted(fixed_map, key=fixed_map.get)
    ds.class_to_idx = dict(fixed_map)
    return ds

val_ds  = remap_dataset_targets(val_ds,  fixed_map)
test_ds = remap_dataset_targets(test_ds, fixed_map)

print("Fixed classes:", fixed_classes)
print("Train n:", len(train_ds), "| Val n:", len(val_ds), "| Test n:", len(test_ds))

# Dataloaders
BATCH_TRAIN = 64
BATCH_EVAL  = 32
# Pinned host memory only speeds up host->GPU copies. On a CPU-only runtime it
# buys nothing and makes the DataLoader emit a warning.
PIN_MEMORY = device.type == "cuda"

# drop_last on the training loader: the classifier head uses BatchNorm1d, which
# raises in train mode when the final batch has a single sample.
train_loader = DataLoader(train_ds, batch_size=BATCH_TRAIN, shuffle=True,  num_workers=2,
                          pin_memory=PIN_MEMORY, drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_EVAL,  shuffle=False, num_workers=2,
                          pin_memory=PIN_MEMORY)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_EVAL,  shuffle=False, num_workers=2,
                          pin_memory=PIN_MEMORY)


Fixed classes: ['angry', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']
Train n: 28709 | Val n: 28709 | Test n: 7178


**Block 5 — Model (ResNet34) + Loss/Opt/Scheduler**

In [None]:
# Block 5: Model + Loss/Opt/Scheduler
class EmotionResNet34(nn.Module):
    """ResNet34 (torchvision DEFAULT weights) with an MLP classification head.

    The backbone's final ``fc`` is replaced by ``nn.Identity`` so it returns its
    pooled features. The head has three stages of BatchNorm1d -> Dropout -> Linear,
    with a ReLU after the first two Linears: in_features -> 512 -> 256 -> num_classes.

    Args:
        num_classes: number of output logits.
        dropout: dropout probability of the first two head stages (the last uses 0.3).
    """

    def __init__(self, num_classes=7, dropout=0.5):
        super().__init__()
        self.backbone = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        width_in = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        widths = (width_in, 512, 256, num_classes)
        drop_probs = (dropout, dropout, 0.3)
        head = []
        for stage, p_drop in enumerate(drop_probs):
            head += [
                nn.BatchNorm1d(widths[stage]),
                nn.Dropout(p_drop),
                nn.Linear(widths[stage], widths[stage + 1]),
            ]
            if stage < len(drop_probs) - 1:   # no activation after the logits layer
                head.append(nn.ReLU(inplace=True))
        # Resulting indices: 0 BN, 1 Drop, 2 Linear, 3 ReLU, 4 BN, 5 Drop, 6 Linear,
        # 7 ReLU, 8 BN, 9 Drop, 10 Linear (these are the state_dict keys).
        self.classifier = nn.Sequential(*head)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for every Linear in the head; zero their biases."""
        linears = [m for m in self.classifier if isinstance(m, nn.Linear)]
        for lin in linears:
            nn.init.kaiming_normal_(lin.weight)
            if lin.bias is not None:
                nn.init.constant_(lin.bias, 0)

    def forward(self, x):
        return self.classifier(self.backbone(x))


model = EmotionResNet34().to(device)

# Label smoothing loss
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against label-smoothed one-hot targets.

    With K = ``pred.size(1)`` classes, the true class gets probability mass
    ``1 - smoothing + smoothing / K`` and every other class ``smoothing / K``.
    The loss is the batch mean of ``-sum(target_dist * log_softmax(pred))``.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, pred, target):
        log_p = F.log_softmax(pred, dim=-1)
        num_classes = pred.size(1)
        # Uniform floor of smoothing/K everywhere, plus the confidence mass on the true class.
        soft = torch.full_like(log_p, self.smoothing / num_classes)
        soft.scatter_add_(1, target.unsqueeze(1), torch.full_like(log_p[:, :1], self.confidence))
        per_sample = (-soft * log_p).sum(dim=-1)
        return per_sample.mean()

criterion = LabelSmoothingLoss(0.1)
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Cosine LR with warm restarts: the first cycle spans 10 scheduler steps and each
# later cycle is twice as long (the training cells step it once per epoch).
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
# torch.cuda.amp.GradScaler() is deprecated (it emits a FutureWarning) and warns
# on CPU-only runtimes. The device-generic scaler, enabled only on CUDA, makes
# scale()/step()/update() plain pass-throughs on CPU, just as before.
scaler    = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


100%|██████████| 83.3M/83.3M [00:00<00:00, 98.4MB/s]
  scaler    = torch.cuda.amp.GradScaler()


**Block 6 — Evaluation Utility (TenCrop averaging)**

In [None]:
# Block 6: Evaluation with TenCrop averaging
@torch.no_grad()
def evaluate_tencrop(model, loader, device=None):
    """Return top-1 accuracy of ``model`` on ``loader``, averaging logits over crops.

    Each batch is ``(images, labels)``. ``images`` is normally
    ``[B, ncrops, C, H, W]`` (a TenCrop + stack transform). All crops go through
    the model in a single forward pass and their logits are averaged per sample.
    A plain ``[B, C, H, W]`` batch is treated as a single crop.

    Args:
        model: classifier returning ``[N, num_classes]`` logits; put in eval mode.
        loader: iterable of ``(images, labels)`` pairs.
        device: where to run inference; defaults to the device of the model's
            first parameter (CPU if it has none).

    Returns:
        Accuracy in ``[0, 1]``, or ``0.0`` if ``loader`` yields no samples.
    """
    if device is None:
        first = next(model.parameters(), None)
        device = first.device if first is not None else torch.device("cpu")
    device = torch.device(device)
    use_amp = device.type == "cuda"

    model.eval()
    total, correct = 0, 0
    for images, labels in loader:
        if images.dim() == 4:            # single-crop batch -> [B, 1, C, H, W]
            images = images.unsqueeze(1)
        bs, ncrops, c, h, w = images.size()
        images = images.reshape(-1, c, h, w).to(device)
        labels = labels.to(device)

        with torch.autocast(device_type="cuda" if use_amp else "cpu", enabled=use_amp):
            logits = model(images)                      # [B*ncrops, num_classes]
        # Average in fp32: under AMP the logits are fp16.
        logits = logits.view(bs, ncrops, -1).float().mean(1)
        preds = logits.argmax(1)

        correct += (preds == labels).sum().item()
        total   += labels.size(0)
    return correct / total if total else 0.0


**Block 7 — Stage-1 Training (freeze backbone)**

In [None]:
# Block 7: Stage-1 Training (Freeze backbone)
# Stage 1 trains only the classifier head. Every backbone parameter is frozen
# below, so the existing `optimizer` (built over all parameters) gets no
# gradients for them and skips them.
# NOTE(review): model.train() below also puts the backbone's BatchNorm layers in
# training mode, so their running mean/var keep updating even though the weights
# are frozen. If a truly fixed feature extractor is intended, call
# model.backbone.eval() right after model.train(). Confirm the intent.
# NOTE(review): make sure VAL_DIR is disjoint from TRAIN_DIR. Otherwise val_acc
# (and hence checkpoint selection / early stopping) just tracks training accuracy.
for p in model.backbone.parameters():
    p.requires_grad = False

best_val_acc = 0.0                            # best TenCrop validation accuracy so far
best_wts = copy.deepcopy(model.state_dict())  # snapshot of the weights that achieved it
patience = 7   # stop after this many consecutive epochs without a strict improvement
wait = 0
EPOCHS_S1 = 40

for epoch in range(EPOCHS_S1):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        # Mixed precision only when running on CUDA.
        with torch.autocast(device_type='cuda', enabled=(device.type=='cuda')):
            logits = model(x)
            loss = criterion(logits, y)
        # GradScaler: backward on the scaled loss, then unscale + optimizer step,
        # then adjust the scale factor for the next iteration.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Sample-weighted loss/accuracy on the augmented training batches.
        running_loss += loss.item() * x.size(0)
        correct      += (logits.argmax(1) == y).sum().item()
        total        += y.size(0)

    scheduler.step()  # one LR-schedule step per epoch
    train_loss = running_loss / total
    train_acc  = correct / total
    val_acc    = evaluate_tencrop(model, val_loader)

    print(f"[S1][{epoch+1:03d}/{EPOCHS_S1}] loss={train_loss:.4f} | train_acc={train_acc:.4f} | val_acc={val_acc:.4f}")

    # Keep (in memory and on disk) the weights with the best validation accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_wts = copy.deepcopy(model.state_dict())
        torch.save(best_wts, "/content/best_fer_model_stage1.pth")
        print("  ↳ Saved best Stage-1 weights.")
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            print("  ↳ Early stop Stage-1.")
            break

# Load best from Stage-1 (not the last epoch's weights)
model.load_state_dict(best_wts)


KeyboardInterrupt: 

**Block 8 — Stage-2 Fine-Tune (unfreeze backbone)**

In [None]:
# Block 8: Stage-2 Fine-Tuning (Unfreeze)
# Stage 2 makes the whole network (backbone + head) trainable again.
for p in model.backbone.parameters():
    p.requires_grad = True

# Lower LR for fine-tuning
# Fresh AdamW + scheduler over all parameters; 5e-4 is half the Stage-1 rate.
# NOTE(review): one LR for both pretrained and new layers is a hyper-parameter
# choice. Per-group LRs (smaller for the backbone) are a common alternative, so
# tune on the validation split. The existing `scaler` is reused as-is.
optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

# Improvement is measured against Stage-1's best val accuracy, so the Stage-2
# checkpoint file is only written when fine-tuning beats Stage 1.
best_val_acc_ft = best_val_acc
best_wts_ft = copy.deepcopy(model.state_dict())
patience_ft = 40   # consecutive non-improving epochs tolerated before stopping
wait = 0
EPOCHS_S2 = 100

for epoch in range(EPOCHS_S2):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        # Mixed precision only when running on CUDA.
        with torch.autocast(device_type='cuda', enabled=(device.type=='cuda')):
            logits = model(x)
            loss = criterion(logits, y)
        # GradScaler: scaled backward, unscale + step, then update the scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Sample-weighted loss/accuracy on the augmented training batches.
        running_loss += loss.item() * x.size(0)
        correct      += (logits.argmax(1) == y).sum().item()
        total        += y.size(0)

    scheduler.step()  # one LR-schedule step per epoch
    train_loss = running_loss / total
    train_acc  = correct / total
    val_acc    = evaluate_tencrop(model, val_loader)

    print(f"[S2][{epoch+1:03d}/{EPOCHS_S2}] loss={train_loss:.4f} | train_acc={train_acc:.4f} | val_acc={val_acc:.4f}")

    if val_acc > best_val_acc_ft:
        best_val_acc_ft = val_acc
        best_wts_ft = copy.deepcopy(model.state_dict())
        torch.save(best_wts_ft, "/content/best_fer_model_finetuned.pth")
        print("  ↳ Saved best Stage-2 weights.")
        wait = 0
    else:
        wait += 1
        if wait >= patience_ft:
            print("  ↳ Early stop Stage-2.")
            break

# Load best fine-tuned model (these are the weights held at the start of
# Stage 2 if no fine-tuning epoch improved on Stage 1)
model.load_state_dict(best_wts_ft)


**Block 9 — FINAL Test on the held-out test split (PrivateTest when available; TenCrop TTA)**

In [None]:
# Block 9: FINAL Test (PrivateTest if available)
# Candidate checkpoints, in preference order for mtime ties.
_ckpt_candidates = [p for p in ("/content/best_fer_model_finetuned.pth",
                                "/content/best_fer_model_stage1.pth")
                    if os.path.exists(p)]
if not _ckpt_candidates:
    raise FileNotFoundError("No checkpoint found: run Stage-1 (and optionally Stage-2) training first.")
# Take the most recently written file. Stage 2 only saves when it beats Stage 1,
# so a fine-tuned file left over from an earlier session must not shadow the
# Stage-1 checkpoint produced by the current run.
final_ckpt = max(_ckpt_candidates, key=os.path.getmtime)
# The file holds only a state_dict, so refuse arbitrary pickled objects.
model.load_state_dict(torch.load(final_ckpt, map_location=device, weights_only=True))
print("Loaded best checkpoint:", final_ckpt)

final_test_acc = evaluate_tencrop(model, test_loader)
print(f"\n FINAL TEST ACCURACY (TenCrop TTA) on {'PrivateTest' if 'PrivateTest' in TEST_DIR else 'test/'}: {final_test_acc*100:.2f}%")


**Block 10 — Confusion Matrix on Final Test**

In [None]:
# Block 10: Confusion Matrix (optional)
from sklearn.metrics import confusion_matrix, classification_report

@torch.no_grad()
def predict_all(model, loader, device=None):
    """Collect true labels and crop-averaged predictions for every sample in ``loader``.

    Batches are ``(images, labels)`` with ``images`` shaped
    ``[B, ncrops, C, H, W]`` (TenCrop-stacked) or ``[B, C, H, W]`` (single crop).
    Logits are averaged over crops before the argmax.

    Args:
        model: classifier returning ``[N, num_classes]`` logits; put in eval mode.
        loader: iterable of ``(images, labels)`` pairs.
        device: where to run inference; defaults to the device of the model's
            first parameter (CPU if it has none).

    Returns:
        ``(y_true, y_pred)`` as 1-D NumPy arrays; both are empty if ``loader`` is empty.
    """
    if device is None:
        first = next(model.parameters(), None)
        device = first.device if first is not None else torch.device("cpu")
    device = torch.device(device)
    use_amp = device.type == "cuda"

    model.eval()
    ys, yh = [], []
    for images, labels in loader:
        if images.dim() == 4:            # single-crop batch -> [B, 1, C, H, W]
            images = images.unsqueeze(1)
        bs, ncrops, c, h, w = images.size()
        with torch.autocast(device_type="cuda" if use_amp else "cpu", enabled=use_amp):
            logits = model(images.reshape(-1, c, h, w).to(device))
        # Average in fp32: under AMP the logits are fp16.
        preds = logits.view(bs, ncrops, -1).float().mean(1).argmax(1)
        ys.append(labels.cpu().numpy())
        yh.append(preds.cpu().numpy())
    if not ys:   # np.concatenate refuses an empty list
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)
    return np.concatenate(ys), np.concatenate(yh)

y_true, y_pred = predict_all(model, test_loader)
# Pin the label set to all training classes. The matrix is then always KxK and
# the report rows line up with target_names even when a class never appears in
# y_true/y_pred. Without `labels=`, classification_report raises on the length
# mismatch.
class_names = train_ds.classes
label_ids = list(range(len(class_names)))
print("Confusion Matrix:\n", confusion_matrix(y_true, y_pred, labels=label_ids))
print("\nClassification Report:\n", classification_report(
    y_true, y_pred, labels=label_ids, target_names=class_names, zero_division=0))
