In [33]:
## All needed imports
import os, math, random, json, time, copy
from pathlib import Path
from dataclasses import dataclass
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler

import timm
from torchvision import datasets, transforms
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from sklearn.model_selection import StratifiedKFold
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from rich import print

SEED = 10
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

<torch._C.Generator at 0x7ca9ff8b34f0>

In [51]:
from contextlib import nullcontext

use_cuda = (cfg.device == 'cuda')

try:
    # New API (PyTorch 2.x)
    amp_autocast = (lambda: torch.amp.autocast('cuda')) if use_cuda else nullcontext
    scaler = torch.amp.GradScaler('cuda') if use_cuda else None
except Exception:
    # Fallback to old API if on older PyTorch
    amp_autocast = (lambda: torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_cuda)) if use_cuda else nullcontext
    scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

In [34]:
map_to_class = {
    "joy": "Positive",
    "anger": "NegativeActive",
    "fear": "NegativeActive",
    "surprise": "NegativeActive",
    "sadness": "NegativePassive",
    "Natural": "Positive",
}

@dataclass
class CFG:
    data_root: str = "/kaggle/input/autistic-children-emotions-dr-fatma-m-talaat/Autistic Children Emotions - Dr. Fatma M. Talaat/Train"                
    subjects_csv: str | None = None 
    img_size: int = 224
    batch_size: int = 32
    num_workers: int = 4
    val_split: float = 0.2    
    epochs: int = 40
    patience: int = 7          
    lr_head: float = 3e-4
    lr_adapters: float = 3e-5
    weight_decay: float = 1e-4
    label_smoothing: float = 0.05
    use_focal: bool = False
    focal_gamma: float = 1.5
    use_weighted_sampler: bool = True
    mixup_alpha: float = 0.0 
    cutmix_alpha: float = 0.0
    freeze_backbone: bool = True
    use_lora_adapters: bool = False 
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    save_dir: str = "artifacts_vit_trackA"
    
cfg = CFG()
Path(cfg.save_dir).mkdir(parents=True, exist_ok=True)
print(cfg)

In [35]:
train_tfms = transforms.Compose([
    transforms.Resize((cfg.img_size, cfg.img_size)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=5),
    transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.1, hue=0.02),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5)),
])

val_tfms = transforms.Compose([
    transforms.Resize((cfg.img_size, cfg.img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,0.5,0.5), std=(0.5,0.5,0.5)),
])

# Load folders
orig_ds = datasets.ImageFolder(cfg.data_root, transform=train_tfms)
orig_classes = orig_ds.classes
print("Original classes:", orig_classes)

In [36]:
missing = set(orig_classes) - set(map_to_class.keys())
if missing:
    raise ValueError(f"Mapping missing classes: {missing}")

super_classes = sorted(set(map_to_class.values()))
super_to_idx = {c:i for i,c in enumerate(super_classes)}
print("Superclasses:", super_classes)

remapped_samples = []
for path, orig_idx in orig_ds.samples:
    orig_name = orig_ds.classes[orig_idx]
    super_name = map_to_class[orig_name]
    super_idx = super_to_idx[super_name]
    remapped_samples.append((path, super_idx))

In [37]:
class RemappedImageFolder(datasets.ImageFolder):
    def __init__(self, base_ds, remapped_samples, transform):
        self.root = base_ds.root
        self.loader = base_ds.loader
        self.extensions = base_ds.extensions
        self.transform = transform
        self.target_transform = None
        self.samples = remapped_samples
        self.targets = [t for _,t in remapped_samples]
        self.classes = super_classes
        self.class_to_idx = super_to_idx

    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

In [38]:
full_ds = RemappedImageFolder(orig_ds, remapped_samples, transform=train_tfms)
class_names = full_ds.classes
num_classes = len(class_names)
print(f"n={len(full_ds.samples)} | classes={class_names}")

In [39]:
all_labels = np.array(full_ds.targets)
idxs = np.arange(len(all_labels))

skf = StratifiedKFold(n_splits=int(1/cfg.val_split), shuffle=True, random_state=SEED)
tr, va = next(iter(skf.split(idxs, all_labels)))
train_idx, val_idx = tr, va

print(f"Train n={len(train_idx)} | Val n={len(val_idx)}")

In [40]:
train_ds = copy.copy(full_ds)
val_ds   = copy.copy(full_ds)
train_ds.samples = [full_ds.samples[i] for i in train_idx]
val_ds.samples   = [full_ds.samples[i] for i in val_idx]
train_ds.targets = [t for _,t in train_ds.samples]
val_ds.targets   = [t for _,t in val_ds.samples]
train_ds.transform = train_tfms
val_ds.transform   = val_tfms

In [41]:
train_counts = Counter(train_ds.targets)
print("Train class counts (superclasses):", {class_names[k]: int(v) for k,v in train_counts.items()})

In [42]:
def effective_num_weights(counts, beta=0.9999):
    counts = np.array(counts, dtype=np.float64)
    eff_num = 1.0 - np.power(beta, counts)
    weights = (1.0 - beta) / np.clip(eff_num, 1e-8, None)
    weights = weights / weights.sum() * len(counts)
    return torch.tensor(weights, dtype=torch.float)

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=1.5, reduction='mean', label_smoothing=0.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.ls = label_smoothing
    def forward(self, logits, target):
        ce = F.cross_entropy(logits, target, weight=self.alpha, label_smoothing=self.ls, reduction='none')
        pt = torch.exp(-ce)
        loss = (1-pt) ** self.gamma * ce
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss

class_counts = np.bincount(train_ds.targets, minlength=num_classes)
class_weights = effective_num_weights(class_counts)
print("Class weights:", {class_names[i]: float(w) for i,w in enumerate(class_weights)})

if cfg.use_weighted_sampler:
    per_sample_w = class_weights[train_ds.targets]
    train_sampler = WeightedRandomSampler(weights=per_sample_w.double(), num_samples=len(per_sample_w), replacement=True)
else:
    train_sampler = None

In [43]:
train_loader = DataLoader(
    train_ds, batch_size=cfg.batch_size, shuffle=(train_sampler is None),
    sampler=train_sampler, num_workers=cfg.num_workers, pin_memory=True, drop_last=True
)
val_loader = DataLoader(
    val_ds, batch_size=cfg.batch_size, shuffle=False,
    num_workers=cfg.num_workers, pin_memory=True
)

In [44]:
backbone = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
emb_dim = backbone.num_features
print("Embedding dim:", emb_dim)

In [45]:
for p in backbone.parameters():
    p.requires_grad = not cfg.freeze_backbone

In [46]:
def add_last_block_adapters(model, rank=8, scale=1.0):
    last_blk = model.blocks[-1]
    attn = last_blk.attn
    d = attn.proj.out_features
    adapter = nn.Sequential(nn.Linear(d, rank, bias=False), nn.Linear(rank, d, bias=False))
    for p in adapter.parameters(): p.requires_grad = True
    nn.init.zeros_(adapter[1].weight)
    orig_proj = attn.proj
    class ProjWithAdapter(nn.Module):
        def __init__(self, proj, adapter, scale): super().__init__(); self.proj=proj; self.adapter=adapter; self.scale=scale
        def forward(self, x): y = self.proj(x); return y + self.scale * self.adapter(y)
    attn.proj = ProjWithAdapter(orig_proj, adapter, scale)
    return model

if cfg.use_lora_adapters:
    backbone = add_last_block_adapters(backbone, rank=8, scale=1.0)

head = nn.Sequential(
    nn.Linear(emb_dim, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.2),
    nn.Linear(512, num_classes)
)

model = nn.Sequential(backbone, head).to(cfg.device)
print(model)

In [47]:
def param_groups(m):
    head_params = list(head.parameters())
    adapter_params = []
    if cfg.use_lora_adapters:
        for n,p in backbone.named_parameters():
            if p.requires_grad:
                adapter_params.append(p)
    return [
        {"params": head_params, "lr": cfg.lr_head},
        {"params": adapter_params, "lr": cfg.lr_adapters}
    ]

optimizer = torch.optim.AdamW(param_groups(model), lr=cfg.lr_head, weight_decay=cfg.weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)

criterion = FocalLoss(alpha=class_weights.to(cfg.device), gamma=cfg.focal_gamma, label_smoothing=cfg.label_smoothing) \
            if cfg.use_focal else \
            nn.CrossEntropyLoss(weight=class_weights.to(cfg.device), label_smoothing=cfg.label_smoothing)

In [48]:
def mixup_data(x, y, alpha):
    if alpha <= 0: return x, y, y, 1.0
    lam = float(np.random.beta(alpha, alpha))
    idx = torch.randperm(x.size(0), device=x.device)
    return lam * x + (1-lam) * x[idx], y, y[idx], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1-lam) * criterion(pred, y_b)

In [None]:
def run_epoch(loader, train=True):
    model.train(train)
    losses = []
    logits_all, y_all = [], []

    for xb, yb in loader:
        xb, yb = xb.to(cfg.device, non_blocking=True), yb.to(cfg.device, non_blocking=True)

        if train:
            optimizer.zero_grad(set_to_none=True)

        use_mixup = train and (cfg.mixup_alpha > 0)
        if use_mixup:
            xb, y_a, y_b, lam = mixup_data(xb, yb, cfg.mixup_alpha)

        with amp_autocast():
            logits = model(xb)
            if use_mixup:
                loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
            else:
                loss = criterion(logits, yb)

        if train:
            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

        losses.append(loss.item())
        logits_all.append(logits.detach().float().cpu())
        y_all.append((y_a if use_mixup else yb).detach().cpu() if False else yb.detach().cpu())

    logits_all = torch.cat(logits_all)
    y_all = torch.cat(y_all)
    preds = logits_all.argmax(1).numpy()
    macro_f1 = f1_score(y_all.numpy(), preds, average='macro')
    return np.mean(losses), macro_f1, logits_all, y_all

best_f1, best_state, patience = -1, None, cfg.patience
history = {"train_f1":[], "val_f1":[], "train_loss":[], "val_loss":[]}

for epoch in range(1, cfg.epochs+1):
    tr_loss, tr_f1, _, _ = run_epoch(train_loader, train=True)
    val_loss, val_f1, val_logits, val_targets = run_epoch(val_loader, train=False)
    scheduler.step(val_f1)

    history["train_loss"].append(tr_loss); history["val_loss"].append(val_loss)
    history["train_f1"].append(tr_f1);     history["val_f1"].append(val_f1)
    print(f"[{epoch:03d}] train_loss={tr_loss:.4f} train_f1={tr_f1:.4f} | val_loss={val_loss:.4f} val_f1={val_f1:.4f}")

    if val_f1 > best_f1:
        best_f1 = val_f1
        best_state = {
            "model": copy.deepcopy(model.state_dict()),
            "class_names": class_names,
            "cfg": vars(cfg),
            "history": history,
            "super_map": map_to_class
        }
        patience = cfg.patience
    else:
        patience -= 1
        if patience == 0:
            print("[red]Early stopping[/red]")
            break

if best_state is not None:
    model.load_state_dict(best_state["model"])

In [53]:
model.eval()
logits_all, y_all = [], []
with torch.no_grad():
    for xb, yb in val_loader:
        xb = xb.to(cfg.device)
        logits_all.append(model(xb).cpu())
        y_all.append(yb)
logits_all = torch.cat(logits_all)
y_all = torch.cat(y_all)
preds = logits_all.argmax(1).numpy()
y_true = y_all.numpy()

print("\nClassification report (val, superclasses):")
print(classification_report(y_true, preds, target_names=class_names, digits=4))

cm = confusion_matrix(y_true, preds, labels=list(range(num_classes)))
cm_df = pd.DataFrame(cm, index=[f"true_{c}" for c in class_names],
                        columns=[f"pred_{c}" for c in class_names])
display(cm_df.style.background_gradient(cmap="Blues"))

Unnamed: 0,pred_NegativeActive,pred_NegativePassive,pred_Positive
true_NegativeActive,10,20,2
true_NegativePassive,16,23,1
true_Positive,6,15,59


In [None]:
class ModelWithTemperature(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.temperature = nn.Parameter(torch.ones(1)*1.0)
    def forward(self, x): return self.model(x) / self.temperature.clamp(min=0.05, max=10)
    def set_temperature(self, logits, labels, max_iter=50, lr=0.01):
        self.to(cfg.device)
        nll = nn.CrossEntropyLoss()
        opt = torch.optim.LBFGS([self.temperature], lr=lr, max_iter=max_iter)
        logits = logits.to(cfg.device); labels = labels.to(cfg.device)
        def _eval():
            opt.zero_grad()
            loss = nll(logits / self.temperature, labels)
            loss.backward()
            return loss
        opt.step(_eval); return self

cal_model = ModelWithTemperature(model)
cal_model.set_temperature(torch.tensor(logits_all.numpy()), torch.tensor(y_true))
print("Calibrated temperature:", float(cal_model.temperature.item()))

In [54]:
torch.save(best_state, Path(cfg.save_dir)/"best_state.pth")
torch.save(model.state_dict(), Path(cfg.save_dir)/"best_model_weights.pth")
with open(Path(cfg.save_dir)/"label_map.json","w") as f:
    json.dump({i:c for i,c in enumerate(class_names)}, f, indent=2)
with open(Path(cfg.save_dir)/"superclass_mapping.json","w") as f:
    json.dump(map_to_class, f, indent=2)
print(f"Saved to {cfg.save_dir}")

In [None]:
from PIL import Image
infer_tfms = val_tfms

@torch.no_grad()
def predict_image(path: str, topk=3, return_probs=True):
    img = Image.open(path).convert('RGB')
    x = infer_tfms(img).unsqueeze(0).to(cfg.device)
    logits = cal_model(x)
    probs = F.softmax(logits, dim=1).squeeze(0).cpu().numpy()
    topk_idx = probs.argsort()[::-1][:topk]
    result = [(class_names[i], float(probs[i])) for i in topk_idx]
    return (result, probs) if return_probs else result