# In [None]:
import os
import random
import string
from pathlib import Path

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import models, transforms
from PIL import Image
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# --- Configuration ---
class CFG:
    """Static experiment settings and a helper that seeds every RNG in use."""

    # Dataset location and class vocabulary: A-Z plus three extra classes.
    TRAIN_PATH = "/kaggle/input/aslamerican-sign-language-aplhabet-dataset/ASL_Alphabet_Dataset/asl_alphabet_train"
    LABELS = [*string.ascii_uppercase, "del", "nothing", "space"]
    NUM_CLASSES = len(LABELS)

    # Input size and optimisation hyper-parameters.
    IMG_SIZE = 224
    BATCH_SIZE = 96  # 64
    EPOCHS = 60  # 25
    LR = 1e-4
    WEIGHT_DECAY = 5e-4  # 1e-4
    SEED = 42
    PATIENCE = 7  # 5 early stopping

    @staticmethod
    def seed_everything():
        """Seed Python, NumPy and PyTorch RNGs with ``CFG.SEED`` and set the cuDNN flags."""
        seed = CFG.SEED
        os.environ["PYTHONHASHSEED"] = str(seed)
        for seed_fn in (random.seed, np.random.seed,
                        torch.manual_seed, torch.cuda.manual_seed_all):
            seed_fn(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

# --- Dataset ---
class LibrasDataset(Dataset):
    """Image-folder dataset with a reproducible, non-overlapping train/val split.

    Images are discovered under ``<root>/<label>/`` for every label, shuffled
    with a private seeded RNG, and cut into a leading 'train' part and a
    trailing validation part.

    The train and validation sets are created by separate constructor calls,
    so the permutation must be identical on every call. Shuffling with the
    global ``random`` state would give each call a different order and let
    the two splits share images. The directory listing is sorted for the same
    reason: ``os.listdir`` returns entries in arbitrary order.

    Args:
        split: ``'train'`` selects the first ``1 - val_ratio`` share of the
            shuffled samples; any other value selects the remainder.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        val_ratio: Fraction of samples reserved for validation.
        root: Directory holding one sub-folder per label. Defaults to
            ``CFG.TRAIN_PATH``.
        labels: Ordered label names; the position is the class index.
            Defaults to ``CFG.LABELS``.
        seed: Seed for the split permutation. Defaults to ``CFG.SEED``.

    Attributes:
        data: List of ``(image_path, class_index)`` tuples for this split.
    """

    # File suffixes (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, split='train', transform=None, val_ratio=0.2,
                 root=None, labels=None, seed=None):
        super().__init__()
        self.transform = transform
        root = CFG.TRAIN_PATH if root is None else root
        labels = CFG.LABELS if labels is None else labels
        seed = CFG.SEED if seed is None else seed

        samples = []
        for idx, label in enumerate(labels):
            label_dir = os.path.join(root, label)
            if not os.path.isdir(label_dir):
                # Missing class folders are skipped rather than treated as errors.
                continue
            for fname in sorted(os.listdir(label_dir)):
                if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                    samples.append((os.path.join(label_dir, fname), idx))

        # Private RNG: same permutation for every instance, global state untouched.
        random.Random(seed).shuffle(samples)
        split_idx = int(len(samples) * (1 - val_ratio))
        self.data = samples[:split_idx] if split == 'train' else samples[split_idx:]

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, class_index)``."""
        img_path, label = self.data[idx]
        # The context manager releases the file handle once pixels are copied.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

# --- Model Definition ---
class ASLNetVGG(nn.Module):
    """VGG16-based classifier that fuses two views of the convolutional features.

    The VGG16 feature map is summarised in two ways:

    * globally average-pooled to one value per channel (512 values), and
    * passed through VGG's own ``avgpool`` and its fully connected stack with
      the final classification layer removed (4096 values).

    Both vectors are concatenated, projected to ``feature_dim`` with a ReLU,
    and classified into ``CFG.NUM_CLASSES`` logits.

    Args:
        feature_dim: Width of the fused embedding fed to the classifier.
        freeze_vgg: If true, the convolutional ``features`` parameters are
            excluded from gradient updates. The reused fully connected stack
            stays trainable either way.
    """

    def __init__(self, feature_dim=512, freeze_vgg=True):
        super().__init__()
        # `pretrained=True` is the deprecated spelling; the weights enum
        # selects the same ImageNet checkpoint explicitly.
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        # Make every ReLU in the conv stack out-of-place so intermediate
        # activations are not overwritten.
        for m in vgg.features.modules():
            if isinstance(m, nn.ReLU):
                m.inplace = False
        self.vgg_feats = vgg.features
        if freeze_vgg:
            for p in self.vgg_feats.parameters():
                p.requires_grad = False
        self.pool1 = nn.AdaptiveAvgPool2d((1, 1))
        self.pool2 = vgg.avgpool
        # VGG's classifier minus its last layer.
        orig_cls = list(vgg.classifier.children())[:-1]
        self.asl_feats = nn.Sequential(*orig_cls)
        self.proj = nn.Linear(512 + 4096, feature_dim)
        self.act = nn.ReLU()
        self.classifier = nn.Linear(feature_dim, CFG.NUM_CLASSES)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        feats = self.vgg_feats(x)
        # Branch 1: global average over the spatial dimensions.
        f1 = torch.flatten(self.pool1(feats), 1)
        # Branch 2: VGG avgpool followed by the truncated FC stack.
        f2 = self.asl_feats(torch.flatten(self.pool2(feats), 1))
        fused = torch.cat([f1, f2], dim=1)
        feat = self.act(self.proj(fused))
        return self.classifier(feat)

    @property
    def conv5_3(self):
        """Module at index 28 of the VGG feature stack."""
        return self.vgg_feats[28]

# --- Utilities ---
def plot_confusion_matrix(preds, labels, plot_dir):
    """Save a confusion-matrix heatmap as ``confusion_matrix.png`` in ``plot_dir``.

    Args:
        preds: Predicted class indices.
        labels: Ground-truth class indices, aligned with ``preds``.
        plot_dir: Output directory (``str`` or ``Path``); an existing
            ``confusion_matrix.png`` in it is overwritten.
    """
    # Pin the class set so the matrix always has one row/column per entry of
    # CFG.LABELS. Without `labels=`, a class missing from both arrays is
    # dropped and the tick labels no longer match the rows and columns.
    class_ids = list(range(len(CFG.LABELS)))
    cm = confusion_matrix(labels, preds, labels=class_ids)

    fig, ax = plt.subplots(figsize=(15, 12))
    try:
        sns.heatmap(cm, annot=False, cmap='Blues',
                    xticklabels=CFG.LABELS, yticklabels=CFG.LABELS, ax=ax)
        ax.set_title("Confusion Matrix")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        fig.savefig(Path(plot_dir) / "confusion_matrix.png")
    finally:
        # Close the figure even if plotting or saving fails, so repeated
        # per-epoch calls do not accumulate open figures.
        plt.close(fig)

# --- Training/Evaluation Loops ---
def train_epoch(model, loader, optimizer, criterion, device, scaler=None):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimiser stepping the model's parameters.
        criterion: Loss callable; assumed to return the mean loss over the
            batch (the default reduction of ``nn.CrossEntropyLoss``).
        device: Target device, as ``torch.device`` or string.
        scaler: Optional gradient scaler. When given, the loss is scaled
            before backward and the optimiser is stepped through the scaler.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over samples.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    # Autocast is only enabled for CUDA runs that also supplied a scaler.
    use_amp = scaler is not None and torch.device(device).type == "cuda"
    loss_sum, correct, total = 0.0, 0, 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.amp.autocast("cuda", enabled=use_amp):
            logits = model(imgs)
            loss = criterion(logits, labels)
        if scaler is not None:
            scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to true gradients.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        batch_size = labels.size(0)
        # Weight each batch mean by its size so a short final batch does not
        # count as much as a full one.
        loss_sum += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size
    if total == 0:
        raise ValueError("train_epoch received a loader that yielded no samples")
    return loss_sum / total, correct / total

@torch.no_grad()
def eval_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable; assumed to return the mean loss over the
            batch (the default reduction of ``nn.CrossEntropyLoss``).
        device: Target device, as ``torch.device`` or string.

    Returns:
        ``(mean_loss, accuracy, preds, labels)``. The first two are averaged
        over samples; the last two are 1-D NumPy arrays covering every
        evaluated sample in loader order.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0
    pred_chunks, label_chunks = [], []
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        logits = model(imgs)
        batch_size = labels.size(0)
        # Weight each batch mean by its size so the smaller final batch is
        # not over-represented in the reported loss.
        loss_sum += criterion(logits, labels).item() * batch_size
        preds = logits.argmax(dim=1)
        pred_chunks.append(preds.cpu().numpy())
        label_chunks.append(labels.cpu().numpy())
        correct += (preds == labels).sum().item()
        total += batch_size
    if total == 0:
        raise ValueError("eval_epoch received a loader that yielded no samples")
    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)
    return loss_sum / total, correct / total, all_preds, all_labels

# --- Main Script ---
def main():
    """Train the VGG-based classifier with early stopping and a staged unfreeze.

    Steps: seed RNGs, build augmented train / plain validation datasets,
    balance classes with a weighted sampler, then train for up to
    ``CFG.EPOCHS`` epochs. Each epoch overwrites ``./plots/confusion_matrix.png``
    with the latest validation matrix. The lowest-validation-loss weights go
    to ``filter_static_best.pt`` and the last-epoch weights to
    ``filter_static_final.pt``.
    """
    CFG.seed_everything()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    normalize = transforms.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(CFG.IMG_SIZE, scale=(0.8,1.0)),
        transforms.RandomRotation(15),
        transforms.ColorJitter(0.2,0.2,0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_val = transforms.Compose([
        transforms.Resize((CFG.IMG_SIZE, CFG.IMG_SIZE)),
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = LibrasDataset('train', transform_train)
    val_ds   = LibrasDataset('val',   transform_val)
    if len(train_ds) == 0 or len(val_ds) == 0:
        # Fail early with a readable message instead of a sampler/loader error.
        raise RuntimeError(
            f"No images found for training/validation under {CFG.TRAIN_PATH}")

    # Sampler for balanced classes: each sample is drawn with probability
    # inversely proportional to its class frequency.
    train_labels = [label for _, label in train_ds.data]
    counts = np.bincount(train_labels, minlength=CFG.NUM_CLASSES)
    # Classes absent from the training split get a dummy count of 1 to avoid
    # a division by zero; their weight is never used.
    weights = 1.0 / np.maximum(counts, 1)
    sample_weights = [weights[label] for label in train_labels]
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

    train_dl = DataLoader(train_ds, batch_size=CFG.BATCH_SIZE,
                          sampler=sampler, num_workers=4)
    val_dl   = DataLoader(val_ds,   batch_size=CFG.BATCH_SIZE,
                          shuffle=False, num_workers=4)

    model = ASLNetVGG(feature_dim=512, freeze_vgg=True).to(device)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # The optimiser receives all parameters, including the frozen ones, so the
    # layers unfrozen later are updated without rebuilding it.
    optimizer = optim.AdamW(model.parameters(), lr=CFG.LR,
                            weight_decay=CFG.WEIGHT_DECAY)
    # `verbose=` is deprecated on LR schedulers; LR drops are printed below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3)
    # Gradient scaling is only used on CUDA; CPU runs train in full precision.
    scaler = torch.amp.GradScaler("cuda") if device.type == "cuda" else None

    plot_dir = Path("./plots")
    plot_dir.mkdir(exist_ok=True)

    unfreeze_epoch = 5
    best_val = float('inf')
    epochs_no_improve = 0

    for epoch in range(1, CFG.EPOCHS + 1):
        tr_loss, tr_acc = train_epoch(model, train_dl, optimizer,
                                      criterion, device, scaler)
        vl_loss, vl_acc, vl_preds, vl_labels = eval_epoch(
            model, val_dl, criterion, device)

        print(f"[Epoch {epoch:02d}/{CFG.EPOCHS}] "
              f"Train L: {tr_loss:.4f}, A: {tr_acc:.4%} | "
              f"Val L: {vl_loss:.4f}, A: {vl_acc:.4%}")

        # conf matrix
        plot_confusion_matrix(vl_preds, vl_labels, plot_dir)

        # scheduler step, reporting any learning-rate reduction
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(vl_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

        # Progressive unfreeze. This must not depend on whether this epoch
        # improved the validation loss, otherwise it is silently skipped.
        if epoch == unfreeze_epoch:
            for p in model.vgg_feats[24:].parameters():
                p.requires_grad = True
            print("🚀 Unfroze last VGG blocks for fine-tuning")

        # early stopping & best model
        if vl_loss < best_val:
            best_val = vl_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), "filter_static_best.pt")
            print("👉 Best static model saved")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= CFG.PATIENCE:
                print(f"✋ Early stopping at epoch {epoch}. No improvement in {CFG.PATIENCE} epochs.")
                break

    torch.save(model.state_dict(), "filter_static_final.pt")
    print("👉 Static final model saved")

# Run the training pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
