In [3]:
#!/usr/bin/env python3
"""
FER with Vision Transformer Backbone
Modified from 5th attempt - ONLY model architecture changed
Everything else (augmentation, training, loss, etc.) remains identical
"""

import os
import random
from pathlib import Path
from collections import Counter
from typing import Tuple, List

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

from torchvision import transforms, models
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm


# ============================================================
#                  1. Label Smoothing Loss (UNCHANGED)
# ============================================================

class LabelSmoothingCrossEntropy(nn.Module):
    """Label smoothing prevents overconfident predictions."""
    def __init__(self, epsilon: float = 0.1, weight=None):
        super().__init__()
        self.epsilon = epsilon
        self.weight = weight
        
    def forward(self, outputs, targets):
        n_classes = outputs.size(-1)
        log_preds = F.log_softmax(outputs, dim=-1)
        
        with torch.no_grad():
            true_dist = torch.zeros_like(log_preds)
            true_dist.fill_(self.epsilon / (n_classes - 1))
            true_dist.scatter_(1, targets.unsqueeze(1), 1.0 - self.epsilon)
            
            if self.weight is not None:
                true_dist = true_dist * self.weight.unsqueeze(0)
        
        loss = torch.sum(-true_dist * log_preds, dim=-1)
        return loss.mean()


# ============================================================
#                  2. Mixup & CutMix (UNCHANGED)
# ============================================================

def mixup_data(x, y, alpha=0.4, device='cuda'):
    """Enhanced mixup for better generalization"""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
        lam = max(lam, 1 - lam)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def cutmix_data(x, y, alpha=1.0, device='cuda'):
    """CutMix augmentation for facial features."""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)

    W, H = x.size(2), x.size(3)
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    x_cutmix = x.clone()
    x_cutmix[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
    
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))
    
    y_a, y_b = y, y[index]
    return x_cutmix, y_a, y_b, lam


# ============================================================
#                  3. Data Collection (UNCHANGED)
# ============================================================

def gather_image_paths_and_labels(root_dir: str) -> Tuple[List[str], List[int], List[str]]:
    """Collect all images from emotion class subdirectories."""
    root = Path(root_dir)
    if not root.exists():
        raise FileNotFoundError(f"Dataset root not found: {root_dir}")

    image_data = []
    exts = {".jpg", ".jpeg", ".png", ".bmp"}
    
    for img_path in root.rglob("*"):
        if img_path.suffix.lower() in exts:
            emotion_class = img_path.parent.name
            if emotion_class in ['train', 'test']:
                emotion_class = img_path.parent.parent.name
            
            image_data.append((str(img_path), emotion_class))
    
    if not image_data:
        raise ValueError(f"No images found under {root_dir}")
    
    unique_classes = sorted(set(emotion for _, emotion in image_data))
    class_to_idx = {name: idx for idx, name in enumerate(unique_classes)}
    
    image_paths = [path for path, _ in image_data]
    labels = [class_to_idx[emotion] for _, emotion in image_data]
    
    return image_paths, labels, unique_classes


# ============================================================
#                  4. Pre-cached Dataset (UNCHANGED)
# ============================================================

class PreCachedImageDataset(Dataset):
    """Ultra-fast dataset with pre-cached tensors."""
    def __init__(self, image_paths: List[str], labels: List[int], 
                 transform=None, cache_images: bool = True, img_size: int = 224,
                 is_train: bool = True):
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.transform = transform
        self.is_train = is_train
        self.cached_tensors = []
        
        if cache_images:
            print(f"Pre-loading {len(image_paths)} images...")
            
            imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
            
            resize_transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.ToTensor(),
            ])
            
            for path in tqdm(image_paths, desc="Caching"):
                try:
                    img = Image.open(path).convert("RGB")
                    img_tensor = resize_transform(img)
                    
                    if not is_train:
                        img_tensor = (img_tensor - imagenet_mean) / imagenet_std
                    
                    self.cached_tensors.append(img_tensor)
                except Exception as e:
                    blank = torch.zeros(3, img_size, img_size)
                    if not is_train:
                        blank = (blank - imagenet_mean) / imagenet_std
                    self.cached_tensors.append(blank)
        else:
            self.cached_tensors = None
            self.image_paths = image_paths

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if self.cached_tensors is not None:
            img = self.cached_tensors[idx]
            if self.is_train and self.transform:
                img = self.transform(img)
        else:
            img = Image.open(self.image_paths[idx]).convert("RGB")
            if self.transform:
                img = self.transform(img)
        
        return img, self.labels[idx]


# ============================================================
#                  5. Augmentation Transforms (UNCHANGED)
# ============================================================

def get_augmentation_transforms() -> Tuple[transforms.Compose, transforms.Compose]:
    """Strong augmentation strategy for better generalization."""
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.85, 1.15)),
        transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
        transforms.Normalize(imagenet_mean, imagenet_std),
        transforms.RandomErasing(p=0.3, scale=(0.02, 0.2)),
    ])
    
    val_transform = transforms.Compose([
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])
    
    return train_transform, val_transform


# ============================================================
#                  6. MODIFIED: ViT Model with Same Classifier Head
# ============================================================

def build_model(model_name: str = 'vit_b_16', num_classes: int = 7, 
                dropout_rate: float = 0.5) -> nn.Module:
    """
    Build Vision Transformer model with custom classifier head.
    
    CHANGES FROM ORIGINAL:
    - Replaced ResNet-34 backbone with Vision Transformer (ViT-B/16)
    - Kept EXACT SAME 3-layer classifier head architecture
    - All other parameters (dropout, batch norm, layer sizes) identical
    
    Args:
        model_name: ViT variant ('vit_b_16', 'vit_b_32', 'vit_l_16')
        num_classes: Number of emotion classes (7)
        dropout_rate: Dropout probability (0.5, same as ResNet version)
    
    Returns:
        ViT model with custom 3-layer MLP classifier
    """
    # Load pre-trained Vision Transformer
    if model_name == 'vit_b_16':
        weights = models.ViT_B_16_Weights.IMAGENET1K_V1
        model = models.vit_b_16(weights=weights)
        hidden_dim = 768  # ViT-B/16 output dimension
    elif model_name == 'vit_b_32':
        weights = models.ViT_B_32_Weights.IMAGENET1K_V1
        model = models.vit_b_32(weights=weights)
        hidden_dim = 768  # ViT-B/32 output dimension
    elif model_name == 'vit_l_16':
        weights = models.ViT_L_16_Weights.IMAGENET1K_V1
        model = models.vit_l_16(weights=weights)
        hidden_dim = 1024  # ViT-L/16 output dimension
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    # Replace classifier head with IDENTICAL architecture to ResNet version
    # Original ResNet: 512 → 1024 → 512 → 7
    # ViT version:     768 → 1024 → 512 → 7 (only input dim changes)
    model.heads = nn.Sequential(
        nn.BatchNorm1d(hidden_dim),              # BatchNorm on ViT features
        nn.Dropout(dropout_rate * 0.5),          # 0.25 dropout (same as original)
        nn.Linear(hidden_dim, 1024),             # 768→1024 (ResNet had 512→1024)
        nn.BatchNorm1d(1024),                    # Same as original
        nn.ReLU(inplace=True),                   # Same activation
        nn.Dropout(dropout_rate),                # 0.5 dropout (same as original)
        nn.Linear(1024, 512),                    # Same as original
        nn.BatchNorm1d(512),                     # Same as original
        nn.ReLU(inplace=True),                   # Same activation
        nn.Dropout(dropout_rate),                # 0.5 dropout (same as original)
        nn.Linear(512, num_classes)              # Same final layer
    )

    return model


# ============================================================
#                  7. Training Function (UNCHANGED)
# ============================================================

def train_one_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                   optimizer: optim.Optimizer, device: torch.device, epoch: int,
                   scaler: GradScaler = None, use_mixup: bool = True, 
                   mixup_alpha: float = 0.4, use_amp: bool = False,
                   use_cutmix: bool = True) -> Tuple[float, float]:
    """Training with mixup/cutmix alternation."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch} [Train]", leave=False)

    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        use_aug = random.random() < 0.5
        use_cutmix_now = use_cutmix and random.random() < 0.5
        
        if use_amp and scaler is not None:
            with autocast():
                if use_aug and use_mixup:
                    if use_cutmix_now:
                        mixed_images, labels_a, labels_b, lam = cutmix_data(images, labels, 1.0, device)
                    else:
                        mixed_images, labels_a, labels_b, lam = mixup_data(images, labels, mixup_alpha, device)
                    
                    outputs = model(mixed_images)
                    loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
                    
                    preds = outputs.argmax(dim=1)
                    correct += (lam * (preds == labels_a).sum().item() + (1 - lam) * (preds == labels_b).sum().item())
                else:
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    
                    preds = outputs.argmax(dim=1)
                    correct += (preds == labels).sum().item()

            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            if use_aug and use_mixup:
                if use_cutmix_now:
                    mixed_images, labels_a, labels_b, lam = cutmix_data(images, labels, 1.0, device)
                else:
                    mixed_images, labels_a, labels_b, lam = mixup_data(images, labels, mixup_alpha, device)
                
                outputs = model(mixed_images)
                loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
                
                preds = outputs.argmax(dim=1)
                correct += (lam * (preds == labels_a).sum().item() + (1 - lam) * (preds == labels_b).sum().item())
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)
                
                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        running_loss += loss.item()
        total += images.size(0)

        if pbar.n % max(1, len(loader) // 20) == 0:
            avg_loss = running_loss / (pbar.n + 1)
            avg_acc = 100.0 * correct / total
            pbar.set_postfix({"loss": f"{avg_loss:.4f}", "acc": f"{avg_acc:.2f}%"})

    avg_loss = running_loss / len(loader)
    avg_acc = 100.0 * correct / total
    return avg_loss, avg_acc


# ============================================================
#                  8. Evaluation Function (UNCHANGED)
# ============================================================

def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module,
            device: torch.device, epoch: int, use_amp: bool = False,
            use_tta: bool = False) -> Tuple[float, float]:
    """Evaluation with optional TTA."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch} [Val]", leave=False)

    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            if use_amp:
                with autocast():
                    if use_tta:
                        outputs1 = model(images)
                        outputs2 = model(torch.flip(images, dims=[3]))
                        outputs = (outputs1 + outputs2) / 2
                    else:
                        outputs = model(images)
                    loss = criterion(outputs, labels)
            else:
                if use_tta:
                    outputs1 = model(images)
                    outputs2 = model(torch.flip(images, dims=[3]))
                    outputs = (outputs1 + outputs2) / 2
                else:
                    outputs = model(images)
                loss = criterion(outputs, labels)

            running_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += images.size(0)

            if pbar.n % max(1, len(loader) // 10) == 0:
                avg_loss = running_loss / (pbar.n + 1)
                avg_acc = 100.0 * correct / total
                pbar.set_postfix({"loss": f"{avg_loss:.4f}", "acc": f"{avg_acc:.2f}%"})

    avg_loss = running_loss / len(loader)
    avg_acc = 100.0 * correct / total
    return avg_loss, avg_acc


# ============================================================
#                  9. Training Loop with OneCycleLR (UNCHANGED)
# ============================================================

def train_advanced(model, train_loader, val_loader, criterion, 
                   device, CONFIG, class_names):
    """Advanced training with OneCycleLR."""
    optimizer = optim.AdamW(
        model.parameters(),
        lr=CONFIG['lr'],
        weight_decay=CONFIG['weight_decay']
    )
    
    total_steps = len(train_loader) * CONFIG['epochs']
    scheduler = OneCycleLR(
        optimizer,
        max_lr=CONFIG['lr'] * 10,
        total_steps=total_steps,
        pct_start=0.3,
        anneal_strategy='cos',
        div_factor=25.0,
        final_div_factor=10000.0
    )
    
    use_amp = torch.cuda.is_available()
    scaler = GradScaler() if use_amp else None
    
    best_val_acc = 0.0
    best_epoch = 0
    patience_counter = 0
    
    print("\n" + "="*60)
    print("Starting high-accuracy training with ViT backbone...")
    if use_amp:
        print("Mixed Precision: ENABLED")
    else:
        print("Mixed Precision: DISABLED (CPU mode)")
    print(f"Target: 75-80% validation accuracy")
    print("="*60)
    
    for epoch in range(1, CONFIG['epochs'] + 1):
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch,
            scaler, use_mixup=CONFIG['use_mixup'], 
            mixup_alpha=CONFIG['mixup_alpha'], use_amp=use_amp,
            use_cutmix=CONFIG['use_cutmix']
        )
        
        use_tta = CONFIG['use_tta'] and epoch > 20
        val_loss, val_acc = evaluate(
            model, val_loader, criterion, device, epoch, 
            use_amp=use_amp, use_tta=use_tta
        )
        
        gap = train_acc - val_acc
        
        print(f"\n[Epoch {epoch}/{CONFIG['epochs']}]")
        print(f"  Train -> Loss: {train_loss:.4f}  Acc: {train_acc:.2f}%")
        print(f"  Val   -> Loss: {val_loss:.4f}  Acc: {val_acc:.2f}%")
        print(f"  Gap: {gap:.2f}% | LR: {optimizer.param_groups[0]['lr']:.2e}")
        if use_tta:
            print(f"  TTA: Enabled")
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            patience_counter = 0
            save_checkpoint(model, optimizer, epoch, val_acc, val_loss, 
                          class_names, CONFIG)
            print(f"  ✓ New best! Val Acc: {val_acc:.2f}% (Gap: {gap:.2f}%)")
        else:
            patience_counter += 1
            print(f"  No improvement ({patience_counter}/{CONFIG['patience']})")
            
            if patience_counter >= CONFIG['patience']:
                print(f"\n{'='*60}")
                print(f"Early stopping triggered")
                print("="*60)
                break
    
    return best_val_acc, best_epoch


def save_checkpoint(model, optimizer, epoch, val_acc, val_loss, class_names, CONFIG):
    """Save model checkpoint"""
    best_path = os.path.join(CONFIG['save_dir'], "best_model_vit.pth")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_acc': val_acc,
        'val_loss': val_loss,
        'class_names': class_names,
        'config': CONFIG
    }, best_path)


# ============================================================
#                  10. Main Pipeline
# ============================================================

def main():
    # ============= Configuration (ONLY model_name changed) =============
    CONFIG = {
        'data_root': "C:/adam/AMIT_Diploma/grad_project/FER/archive_v1",
        'model_name': 'vit_b_16',  # CHANGED: ViT-B/16 instead of resnet34
        'img_size': 224,
        'batch_size': 64,
        'epochs': 60,
        'lr': 3e-4,
        'weight_decay': 2e-4,
        'dropout': 0.5,
        'val_size': 0.15,
        'random_seed': 42,
        'num_workers': 0,
        'patience': 12,
        'use_class_weights': True,
        'use_label_smoothing': True,
        'label_smooth_eps': 0.1,
        'use_mixup': True,
        'mixup_alpha': 0.4,
        'use_cutmix': True,
        'use_tta': True,
        'save_dir': './checkpoints',
        'cache_images': True,
    }

    # Setup
    random.seed(CONFIG['random_seed'])
    np.random.seed(CONFIG['random_seed'])
    torch.manual_seed(CONFIG['random_seed'])
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(CONFIG['random_seed'])
        torch.backends.cudnn.benchmark = True

    os.makedirs(CONFIG['save_dir'], exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Load Data
    print("\n" + "="*60)
    print("Loading dataset...")
    print("="*60)
    
    image_paths, labels, class_names = gather_image_paths_and_labels(CONFIG['data_root'])
    num_classes = len(class_names)
    
    print(f"\nDataset Statistics:")
    print(f"  Total images: {len(image_paths)}")
    print(f"  Classes ({num_classes}): {class_names}")
    
    label_counts = Counter(labels)
    print(f"\n  Class distribution:")
    for cls_idx, cls_name in enumerate(class_names):
        print(f"    {cls_name}: {label_counts[cls_idx]}")

    # Stratified Split
    print(f"\nCreating stratified split (val_size={CONFIG['val_size']})...")
    
    sss = StratifiedShuffleSplit(n_splits=1, test_size=CONFIG['val_size'],
                                random_state=CONFIG['random_seed'])
    indices = np.arange(len(labels))
    train_idx, val_idx = next(sss.split(indices, labels))

    train_paths = [image_paths[i] for i in train_idx]
    train_labels_list = [labels[i] for i in train_idx]
    val_paths = [image_paths[i] for i in val_idx]
    val_labels_list = [labels[i] for i in val_idx]

    print(f"  Train samples: {len(train_paths)}")
    print(f"  Val samples: {len(val_paths)}")

    # Create Datasets
    train_tf, val_tf = get_augmentation_transforms()
    
    train_dataset = PreCachedImageDataset(
        train_paths, train_labels_list, 
        transform=train_tf, 
        cache_images=CONFIG['cache_images'],
        img_size=CONFIG['img_size'],
        is_train=True
    )
    val_dataset = PreCachedImageDataset(
        val_paths, val_labels_list,
        transform=None,
        cache_images=CONFIG['cache_images'],
        img_size=CONFIG['img_size'],
        is_train=False
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        num_workers=0,
        pin_memory=True if torch.cuda.is_available() else False,
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=0,
        pin_memory=True if torch.cuda.is_available() else False,
    )

    # ============= Build Enhanced Model =============
    print(f"\n{'='*60}")
    print("Building enhanced model...")
    print("="*60)
    
    model = build_model(
        model_name=CONFIG['model_name'],
        num_classes=num_classes,
        dropout_rate=CONFIG['dropout']
    ).to(device)
    
    print(f"  Model: {CONFIG['model_name']}")
    print(f"  Architecture: {CONFIG['model_name'].upper()} + 3-layer classifier")
    print(f"  Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # ============= Loss with Label Smoothing =============
    if CONFIG['use_class_weights']:
        train_counts = np.bincount(train_labels_list, minlength=num_classes)
        class_weights = 1.0 / (train_counts + 1e-6)
        class_weights = class_weights * (len(train_labels_list) / class_weights.sum())
        class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
        print(f"\n  Class weights: {class_weights.cpu().numpy()}")
    else:
        class_weights = None
    
    if CONFIG['use_label_smoothing']:
        criterion = LabelSmoothingCrossEntropy(
            epsilon=CONFIG['label_smooth_eps'], 
            weight=class_weights
        )
        print(f"  Loss: Label Smoothing CE (eps={CONFIG['label_smooth_eps']})")
    else:
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        print(f"  Loss: Standard Cross Entropy")

    # ============= Training =============
    best_val_acc, best_epoch = train_advanced(
        model, train_loader, val_loader, criterion, device, CONFIG, class_names
    )

    # ============= Final Results =============
    print(f"\n{'='*60}")
    print("Training Complete!")
    print("="*60)
    print(f"  Best Val Acc: {best_val_acc:.2f}% (Epoch {best_epoch})")
    print(f"  Target Range: 75-80%")
    if best_val_acc >= 75:
        print(f"  ✓ TARGET ACHIEVED!")
    else:
        print(f"  Gap to target: {75 - best_val_acc:.2f}%")
    print(f"  Model saved to: {CONFIG['save_dir']}/best_model.pth")

    # ============= Final Results =============
    print(f"\n{'='*60}")
    print("Training Complete!")
    print("="*60)
    print(f"  Best Val Acc: {best_val_acc:.2f}% (Epoch {best_epoch})")
    print(f"  Target Range: 75-80%")
    if best_val_acc >= 75:
        print(f"  ✓ TARGET ACHIEVED!")
    else:
        print(f"  Gap to target: {75 - best_val_acc:.2f}%")
    print(f"  Model saved to: {CONFIG['save_dir']}/best_model.pth")

In [4]:
if __name__ == "__main__":
    main()

Using device: cuda
GPU: NVIDIA GeForce RTX 4060 Laptop GPU

Loading dataset...

Dataset Statistics:
  Total images: 35887
  Classes (7): ['angry', 'disgusted', 'fearful', 'happy', 'neutral', 'sad', 'surprised']

  Class distribution:
    angry: 4953
    disgusted: 547
    fearful: 5121
    happy: 8989
    neutral: 6198
    sad: 6077
    surprised: 4002

Creating stratified split (val_size=0.15)...
  Train samples: 30503
  Val samples: 5384
Pre-loading 30503 images...


Caching: 100%|██████████| 30503/30503 [06:52<00:00, 74.01it/s] 


Pre-loading 5384 images...


Caching: 100%|██████████| 5384/5384 [01:26<00:00, 61.97it/s] 



Building enhanced model...


  scaler = GradScaler() if use_amp else None


  Model: vit_b_16
  Architecture: VIT_B_16 + 3-layer classifier
  Trainable params: 87,119,111

  Class weights: [ 2114.7783 19146.703   2045.306   1165.3425  1690.0564  1723.7594
  2617.0537]
  Loss: Label Smoothing CE (eps=0.1)

Starting high-accuracy training with ViT backbone...
Mixed Precision: ENABLED
Target: 75-80% validation accuracy


  with autocast():
  with autocast():
                                                                                          


[Epoch 1/60]
  Train -> Loss: 4605.6379  Acc: 8.67%
  Val   -> Loss: 4496.5427  Acc: 10.72%
  Gap: -2.04% | LR: 1.20e-04
  ✓ New best! Val Acc: 10.72% (Gap: -2.04%)


                                                                                             


[Epoch 2/60]
  Train -> Loss: 4503.4270  Acc: 6.37%
  Val   -> Loss: 4387.7775  Acc: 2.84%
  Gap: 3.53% | LR: 1.20e-04
  No improvement (1/12)


                                                                                             


[Epoch 3/60]
  Train -> Loss: 4457.7111  Acc: 4.27%
  Val   -> Loss: 4379.5167  Acc: 2.62%
  Gap: 1.65% | LR: 1.20e-04
  No improvement (2/12)


                                                                                             


[Epoch 4/60]
  Train -> Loss: 4431.1481  Acc: 3.55%
  Val   -> Loss: 4365.7392  Acc: 2.56%
  Gap: 0.99% | LR: 1.20e-04
  No improvement (3/12)


                                                                                             


[Epoch 5/60]
  Train -> Loss: 4417.5211  Acc: 3.52%
  Val   -> Loss: 4357.0949  Acc: 2.17%
  Gap: 1.35% | LR: 1.20e-04
  No improvement (4/12)


                                                                                             


[Epoch 6/60]
  Train -> Loss: 4401.2205  Acc: 3.30%
  Val   -> Loss: 4348.1410  Acc: 2.17%
  Gap: 1.12% | LR: 1.20e-04
  No improvement (5/12)


                                                                                             


[Epoch 7/60]
  Train -> Loss: 4396.5623  Acc: 3.95%
  Val   -> Loss: 4337.5629  Acc: 5.83%
  Gap: -1.89% | LR: 1.20e-04
  No improvement (6/12)


                                                                                             


[Epoch 8/60]
  Train -> Loss: 4382.5975  Acc: 4.70%
  Val   -> Loss: 4314.1113  Acc: 5.57%
  Gap: -0.88% | LR: 1.20e-04
  No improvement (7/12)


                                                                                             


[Epoch 9/60]
  Train -> Loss: 4371.5855  Acc: 5.56%
  Val   -> Loss: 4288.7819  Acc: 5.03%
  Gap: 0.53% | LR: 1.20e-04
  No improvement (8/12)


                                                                                              


[Epoch 10/60]
  Train -> Loss: 4360.6124  Acc: 5.42%
  Val   -> Loss: 4293.1100  Acc: 6.67%
  Gap: -1.25% | LR: 1.20e-04
  No improvement (9/12)


                                                                                              


[Epoch 11/60]
  Train -> Loss: 4347.8442  Acc: 6.41%
  Val   -> Loss: 4279.2679  Acc: 6.32%
  Gap: 0.10% | LR: 1.20e-04
  No improvement (10/12)


                                                                                              


[Epoch 12/60]
  Train -> Loss: 4334.2594  Acc: 6.72%
  Val   -> Loss: 4231.1365  Acc: 8.92%
  Gap: -2.19% | LR: 1.20e-04
  No improvement (11/12)


                                                                                              


[Epoch 13/60]
  Train -> Loss: 4313.6705  Acc: 8.57%
  Val   -> Loss: 4224.6777  Acc: 16.03%
  Gap: -7.46% | LR: 1.20e-04
  ✓ New best! Val Acc: 16.03% (Gap: -7.46%)


                                                                                              


[Epoch 14/60]
  Train -> Loss: 4311.7160  Acc: 8.90%
  Val   -> Loss: 4186.8806  Acc: 12.63%
  Gap: -3.73% | LR: 1.20e-04
  No improvement (1/12)


                                                                                               


[Epoch 15/60]
  Train -> Loss: 4299.1905  Acc: 10.24%
  Val   -> Loss: 4177.6210  Acc: 16.53%
  Gap: -6.29% | LR: 1.20e-04
  ✓ New best! Val Acc: 16.53% (Gap: -6.29%)


                                                                                               


[Epoch 16/60]
  Train -> Loss: 4275.7074  Acc: 11.36%
  Val   -> Loss: 4200.9513  Acc: 16.29%
  Gap: -4.93% | LR: 1.20e-04
  No improvement (1/12)


                                                                                               


[Epoch 17/60]
  Train -> Loss: 4267.6626  Acc: 12.81%
  Val   -> Loss: 4145.5176  Acc: 15.17%
  Gap: -2.36% | LR: 1.20e-04
  No improvement (2/12)


                                                                                               


[Epoch 18/60]
  Train -> Loss: 4253.3246  Acc: 13.99%
  Val   -> Loss: 4113.1862  Acc: 18.52%
  Gap: -4.52% | LR: 1.20e-04
  ✓ New best! Val Acc: 18.52% (Gap: -4.52%)


                                                                                               


[Epoch 19/60]
  Train -> Loss: 4224.6999  Acc: 15.92%
  Val   -> Loss: 4130.6442  Acc: 29.42%
  Gap: -13.50% | LR: 1.20e-04
  ✓ New best! Val Acc: 29.42% (Gap: -13.50%)


                                                                                               


[Epoch 20/60]
  Train -> Loss: 4203.1476  Acc: 18.09%
  Val   -> Loss: 4054.3557  Acc: 24.76%
  Gap: -6.67% | LR: 1.20e-04
  No improvement (1/12)


                                                                                               


[Epoch 21/60]
  Train -> Loss: 4192.2859  Acc: 18.31%
  Val   -> Loss: 4055.3577  Acc: 26.76%
  Gap: -8.46% | LR: 1.20e-04
  TTA: Enabled
  No improvement (2/12)


                                                                                               


[Epoch 22/60]
  Train -> Loss: 4165.4546  Acc: 20.52%
  Val   -> Loss: 3980.4496  Acc: 30.81%
  Gap: -10.29% | LR: 1.20e-04
  TTA: Enabled
  ✓ New best! Val Acc: 30.81% (Gap: -10.29%)


                                                                                               


[Epoch 23/60]
  Train -> Loss: 4161.5373  Acc: 21.00%
  Val   -> Loss: 3938.8116  Acc: 27.90%
  Gap: -6.90% | LR: 1.20e-04
  TTA: Enabled
  No improvement (1/12)


                                                                                               


[Epoch 24/60]
  Train -> Loss: 4125.8332  Acc: 22.81%
  Val   -> Loss: 3850.7500  Acc: 31.91%
  Gap: -9.10% | LR: 1.20e-04
  TTA: Enabled
  ✓ New best! Val Acc: 31.91% (Gap: -9.10%)


                                                                                               


[Epoch 25/60]
  Train -> Loss: 4126.4679  Acc: 23.96%
  Val   -> Loss: 3870.8091  Acc: 25.48%
  Gap: -1.52% | LR: 1.20e-04
  TTA: Enabled
  No improvement (1/12)


                                                                                               


[Epoch 26/60]
  Train -> Loss: 4117.1874  Acc: 23.85%
  Val   -> Loss: 3884.8939  Acc: 33.66%
  Gap: -9.80% | LR: 1.20e-04
  TTA: Enabled
  ✓ New best! Val Acc: 33.66% (Gap: -9.80%)


                                                                                               


[Epoch 27/60]
  Train -> Loss: 4103.3701  Acc: 24.69%
  Val   -> Loss: 3856.3130  Acc: 27.88%
  Gap: -3.19% | LR: 1.20e-04
  TTA: Enabled
  No improvement (1/12)


                                                                                               


[Epoch 28/60]
  Train -> Loss: 4082.9649  Acc: 25.93%
  Val   -> Loss: 3829.5598  Acc: 31.80%
  Gap: -5.86% | LR: 1.20e-04
  TTA: Enabled
  No improvement (2/12)


                                                                                               


[Epoch 29/60]
  Train -> Loss: 4053.1219  Acc: 27.95%
  Val   -> Loss: 3794.0193  Acc: 38.63%
  Gap: -10.68% | LR: 1.20e-04
  TTA: Enabled
  ✓ New best! Val Acc: 38.63% (Gap: -10.68%)


                                                                                               


[Epoch 30/60]
  Train -> Loss: 4045.8313  Acc: 27.74%
  Val   -> Loss: 3783.8509  Acc: 36.00%
  Gap: -8.26% | LR: 1.20e-04
  TTA: Enabled
  No improvement (1/12)


                                                                                                       


[Epoch 31/60]
  Train -> Loss: 4021.7997  Acc: 29.57%
  Val   -> Loss: 3733.9746  Acc: 38.61%
  Gap: -9.04% | LR: 1.20e-04
  TTA: Enabled
  No improvement (2/12)


                                                                                               


[Epoch 32/60]
  Train -> Loss: 4018.5041  Acc: 29.52%
  Val   -> Loss: 3719.6648  Acc: 33.62%
  Gap: -4.10% | LR: 1.20e-04
  TTA: Enabled
  No improvement (3/12)


                                                                                               


[Epoch 33/60]
  Train -> Loss: 4014.0926  Acc: 30.03%
  Val   -> Loss: 3693.6085  Acc: 40.55%
  Gap: -10.52% | LR: 1.20e-04
  TTA: Enabled
  ✓ New best! Val Acc: 40.55% (Gap: -10.52%)


                                                                                               


[Epoch 34/60]
  Train -> Loss: 4012.9070  Acc: 30.01%
  Val   -> Loss: 3788.8727  Acc: 36.33%
  Gap: -6.32% | LR: 1.20e-04
  TTA: Enabled
  No improvement (1/12)


                                                                                               


[Epoch 35/60]
  Train -> Loss: 3995.5719  Acc: 30.09%
  Val   -> Loss: 3737.4357  Acc: 39.60%
  Gap: -9.50% | LR: 1.20e-04
  TTA: Enabled
  No improvement (2/12)


                                                                                               


[Epoch 36/60]
  Train -> Loss: 4018.7474  Acc: 30.04%
  Val   -> Loss: 3657.9405  Acc: 41.81%
  Gap: -11.77% | LR: 1.20e-04
  TTA: Enabled
  ✓ New best! Val Acc: 41.81% (Gap: -11.77%)


                                                                                               


[Epoch 37/60]
  Train -> Loss: 3978.4998  Acc: 31.00%
  Val   -> Loss: 3671.2495  Acc: 41.16%
  Gap: -10.16% | LR: 1.20e-04
  TTA: Enabled
  No improvement (1/12)


                                                                                               


[Epoch 38/60]
  Train -> Loss: 3946.1577  Acc: 32.80%
  Val   -> Loss: 3618.1123  Acc: 46.94%
  Gap: -14.14% | LR: 1.20e-04
  TTA: Enabled
  ✓ New best! Val Acc: 46.94% (Gap: -14.14%)


                                                                                             

KeyboardInterrupt: 