# In [None]:
#!/usr/bin/env python3
"""
FER High-Performance Attempt
Model: EfficientNet-B2 (Swapped from ViT for better FER performance)
Fixes: Class Weight Normalization, Simplified Head, Optimized Hyperparams
"""

import os
import random
from pathlib import Path
from collections import Counter
from typing import Tuple, List

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

from torchvision import transforms, models
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm


# ============================================================
#                  1. Label Smoothing Loss (FIXED WEIGHTS)
# ============================================================

class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against a label-smoothed target, with optional class weights.

    The target class receives probability ``1 - epsilon`` and the remaining
    ``epsilon`` is shared evenly by the other classes.  When ``weight`` is
    given, every class column of that target distribution is multiplied by the
    class weight (the result is intentionally not re-normalised) and the
    per-sample losses are reduced with a plain mean.

    Args:
        epsilon: Probability mass moved away from the target class.
        weight: Optional 1-D per-class weights (tensor, or anything
            ``torch.as_tensor`` accepts).
    """

    def __init__(self, epsilon: float = 0.1, weight=None):
        super().__init__()
        self.epsilon = epsilon
        if weight is not None and not torch.is_tensor(weight):
            weight = torch.as_tensor(weight, dtype=torch.float32)
        # A non-persistent buffer follows .to()/.cuda()/.double() on the
        # criterion without adding an entry to its state_dict.
        self.register_buffer("weight", weight, persistent=False)

    def forward(self, outputs, targets):
        """Return the mean smoothed (and optionally weighted) loss.

        Args:
            outputs: Logits; the last dimension indexes classes and
                ``scatter_`` on dim 1 implies a 2-D ``(batch, classes)`` input.
            targets: Integer class indices, one per row of ``outputs``.
        """
        n_classes = outputs.size(-1)
        log_preds = F.log_softmax(outputs, dim=-1)

        with torch.no_grad():
            # max(..., 1) keeps the degenerate single-class case from
            # dividing by zero; there is no "other" class to smooth onto.
            off_target = self.epsilon / max(n_classes - 1, 1)
            true_dist = torch.full_like(log_preds, off_target)
            true_dist.scatter_(1, targets.unsqueeze(1), 1.0 - self.epsilon)

            if self.weight is not None:
                # Guard against a criterion that was never moved to the
                # device of the logits.
                class_weight = self.weight.to(log_preds.device)
                true_dist = true_dist * class_weight.unsqueeze(0)

        return torch.sum(-true_dist * log_preds, dim=-1).mean()


# ============================================================
#                  2. Mixup & CutMix (UNCHANGED)
# ============================================================

def mixup_data(x, y, alpha=0.2, device='cuda'): # Lower alpha for stability
    """Blend every sample of a batch with a randomly chosen partner (Mixup).

    Args:
        x: Batch of inputs; dim 0 is the batch dimension.
        y: Labels aligned with ``x`` along dim 0.
        alpha: Beta-distribution parameter; ``alpha <= 0`` disables mixing
            (``lam`` is 1 and ``mixed_x`` equals ``x``).
        device: Ignored. Kept so existing call sites keep working; the
            permutation is always placed on the device ``x`` lives on.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``y_a`` are the original labels,
        ``y_b`` the partner labels and ``lam`` the weight of the original
        sample (always >= 0.5 when mixing is active).
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
        # Keep the original sample dominant.
        lam = max(lam, 1 - lam)
    else:
        lam = 1

    batch_size = x.size(0)
    # Drawn on the CPU generator (as before) and then moved next to x, so the
    # function no longer crashes when the caller's device string and x differ
    # or when CUDA is unavailable.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, y_a, y_b, lam


def cutmix_data(x, y, alpha=1.0, device='cuda'):
    """Paste a random rectangle from a partner sample into every sample (CutMix).

    Args:
        x: 4-D batch; the rectangle is cut along dims 2 and 3.
        y: Labels aligned with ``x`` along dim 0.
        alpha: Beta-distribution parameter; ``alpha <= 0`` yields an empty
            rectangle, i.e. ``x`` is returned unchanged with ``lam == 1``.
        device: Ignored. Kept so existing call sites keep working; the
            permutation is always placed on the device ``x`` lives on.

    Returns:
        ``(x_cutmix, y_a, y_b, lam)`` where ``lam`` is a Python float: the
        fraction of each sample that still comes from the original after the
        rectangle has been clipped to the borders.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    # dim 2 / dim 3 extents (rows / columns for NCHW input).
    rows, cols = x.size(2), x.size(3)
    cut_ratio = np.sqrt(1. - lam)
    cut_rows = int(rows * cut_ratio)
    cut_cols = int(cols * cut_ratio)

    # Same draw order as before (dim 2 first, then dim 3) so seeded runs
    # produce the same boxes.
    center_r = np.random.randint(rows)
    center_c = np.random.randint(cols)

    r1 = int(np.clip(center_r - cut_rows // 2, 0, rows))
    c1 = int(np.clip(center_c - cut_cols // 2, 0, cols))
    r2 = int(np.clip(center_r + cut_rows // 2, 0, rows))
    c2 = int(np.clip(center_c + cut_cols // 2, 0, cols))

    x_cutmix = x.clone()
    x_cutmix[:, :, r1:r2, c1:c2] = x[index, :, r1:r2, c1:c2]

    # Recompute lam from the clipped box; plain float, not a NumPy scalar.
    lam = 1.0 - ((r2 - r1) * (c2 - c1)) / (rows * cols)
    y_a, y_b = y, y[index.to(y.device)]
    return x_cutmix, y_a, y_b, lam


# ============================================================
#                  3. Data Collection (UNCHANGED)
# ============================================================

def gather_image_paths_and_labels(root_dir: str) -> Tuple[List[str], List[int], List[str]]:
    """Collect image files below ``root_dir`` and derive labels from folder names.

    The class of an image is the name of its parent directory; when that
    directory is called ``train`` or ``test`` the grandparent's name is used
    instead.  Class indices follow the alphabetical order of class names.

    Args:
        root_dir: Directory that is searched recursively.

    Returns:
        ``(image_paths, labels, class_names)`` with ``image_paths`` sorted so
        that a seeded train/validation split is reproducible regardless of
        the order in which the file system lists entries.

    Raises:
        FileNotFoundError: ``root_dir`` does not exist.
        ValueError: No file with a supported image extension was found.
    """
    root = Path(root_dir)
    if not root.exists():
        raise FileNotFoundError(f"Dataset root not found: {root_dir}")

    exts = {".jpg", ".jpeg", ".png", ".bmp"}
    image_data = []

    for img_path in root.rglob("*"):
        # is_file() keeps directories that merely carry an image suffix out.
        if img_path.suffix.lower() not in exts or not img_path.is_file():
            continue

        emotion_class = img_path.parent.name
        if emotion_class in ('train', 'test'):
            emotion_class = img_path.parent.parent.name

        image_data.append((str(img_path), emotion_class))

    if not image_data:
        raise ValueError(f"No images found under {root_dir}")

    # rglob order is file-system dependent; sort for deterministic output.
    image_data.sort()

    unique_classes = sorted({emotion for _, emotion in image_data})
    class_to_idx = {name: idx for idx, name in enumerate(unique_classes)}

    image_paths = [path for path, _ in image_data]
    labels = [class_to_idx[emotion] for _, emotion in image_data]

    return image_paths, labels, unique_classes


# ============================================================
#                  4. Pre-cached Dataset (UNCHANGED)
# ============================================================

class PreCachedImageDataset(Dataset):
    """Image dataset that can decode and resize every image once, up front.

    Contract shared by the cached and the uncached mode:

    * Every image is converted to RGB, resized to ``img_size x img_size`` and
      turned into a float tensor.
    * Training samples (``is_train=True``) are handed to ``transform`` as
      that tensor; the transform is expected to normalise.
    * Evaluation samples (``is_train=False``) are normalised with the
      ImageNet mean/std here.  In cached mode ``transform`` is ignored for
      them; in uncached mode a supplied ``transform`` is applied to the PIL
      image instead of the built-in preprocessing.

    Args:
        image_paths: Paths of the images.
        labels: Integer class index per path.
        transform: Optional callable, see the contract above.
        cache_images: Decode all images in ``__init__`` when true.
        img_size: Edge length images are resized to.
        is_train: Selects training or evaluation behaviour.

    Raises:
        ValueError: ``image_paths`` and ``labels`` differ in length.

    After construction in cached mode, ``failed_paths`` lists the images that
    could not be read and were replaced by an all-zero image.
    """

    def __init__(self, image_paths: List[str], labels: List[int], 
                 transform=None, cache_images: bool = True, img_size: int = 224,
                 is_train: bool = True):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"Got {len(image_paths)} image paths but {len(labels)} labels"
            )

        self.labels = torch.tensor(labels, dtype=torch.long)
        self.transform = transform
        self.is_train = is_train
        self.img_size = img_size
        self.failed_paths = []
        self._resize = None  # built on first use, see _load_tensor

        # EfficientNet Standard Mean/Std
        self._mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self._std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        if not cache_images:
            self.cached_tensors = None
            self.image_paths = image_paths
            return

        print(f"Pre-loading {len(image_paths)} images...")
        self.cached_tensors = []

        for path in tqdm(image_paths, desc="Caching"):
            try:
                img_tensor = self._load_tensor(path)
            except Exception as exc:
                # Best effort: keep indices aligned with the labels by
                # substituting a blank image, but remember what failed so the
                # substitution does not go unnoticed.
                self.failed_paths.append((path, repr(exc)))
                img_tensor = torch.zeros(3, img_size, img_size)

            if not is_train:
                img_tensor = (img_tensor - self._mean) / self._std

            self.cached_tensors.append(img_tensor)

        if self.failed_paths:
            print(f"WARNING: {len(self.failed_paths)} image(s) could not be "
                  f"loaded and were replaced by blank images:")
            for path, reason in self.failed_paths[:10]:
                print(f"  {path}: {reason}")
            if len(self.failed_paths) > 10:
                print(f"  ... and {len(self.failed_paths) - 10} more")

    def _load_tensor(self, path):
        """Read ``path`` and return it as a resized RGB float tensor."""
        if self._resize is None:
            self._resize = transforms.Compose([
                transforms.Resize((self.img_size, self.img_size)), # Force exact size
                transforms.ToTensor(),
            ])
        # The context manager releases the file handle right after decoding.
        with Image.open(path) as raw:
            return self._resize(raw.convert("RGB"))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if self.cached_tensors is not None:
            img = self.cached_tensors[idx]
            if self.is_train and self.transform:
                img = self.transform(img)
            return img, self.labels[idx]

        path = self.image_paths[idx]

        if not self.is_train and self.transform:
            # Caller-supplied evaluation pipeline: it receives the PIL image.
            with Image.open(path) as raw:
                img = self.transform(raw.convert("RGB"))
            return img, self.labels[idx]

        # Mirror the cached path so both modes yield the same kind of sample.
        img = self._load_tensor(path)
        if self.is_train:
            if self.transform:
                img = self.transform(img)
        else:
            img = (img - self._mean) / self._std

        return img, self.labels[idx]


# ============================================================
#                  5. Augmentation (OPTIMIZED)
# ============================================================

def get_augmentation_transforms() -> Tuple[transforms.Compose, transforms.Compose]:
    """Build the training and validation torchvision pipelines.

    Returns:
        ``(train_transform, val_transform)``.  The training pipeline contains
        no resize / to-tensor step, so it presumably expects the already
        resized tensors produced by the cached dataset (confirm against the
        caller).  The validation pipeline resizes to 224x224, converts to a
        tensor and normalises; it is meant for the non-cached mode.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    train_steps = [
        transforms.RandomHorizontalFlip(p=0.5),
        # Rotation kept at 15 degrees (reduced slightly by the author).
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
        transforms.Normalize(mean, std),
        # Erasing comes after normalisation and is kept mild.
        transforms.RandomErasing(p=0.2, scale=(0.02, 0.15)),
    ]

    # For cached data the dataset normalises validation images itself; this
    # pipeline exists for the non-cached mode.
    eval_steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]

    return transforms.Compose(train_steps), transforms.Compose(eval_steps)


# ============================================================
#                  6. EfficientNet Model Setup
# ============================================================

def build_model(num_classes: int = 7, dropout_rate: float = 0.3) -> nn.Module:
    """Create an ImageNet-pretrained EfficientNet-B2 with a fresh classifier.

    Args:
        num_classes: Number of output logits of the new head.
        dropout_rate: Dropout probability in front of the final linear layer.

    Returns:
        The torchvision EfficientNet-B2 whose ``classifier`` has been replaced
        by ``Dropout -> Linear(in_features, num_classes)``.
    """
    print("Loading EfficientNet-B2...")
    net = models.efficientnet_b2(weights=models.EfficientNet_B2_Weights.IMAGENET1K_V1)

    # The stock classifier is read as Sequential(Dropout, Linear); index 1 is
    # the linear layer whose input width the new head must match.
    feature_dim = net.classifier[1].in_features

    # Deliberately a single linear layer: the author notes deep MLP heads
    # often hurt convergence when fine-tuning.
    head_layers = [
        nn.Dropout(p=dropout_rate),
        nn.Linear(feature_dim, num_classes),
    ]
    net.classifier = nn.Sequential(*head_layers)

    return net


# ============================================================
#                  7. Training Function (UNCHANGED)
# ============================================================

def train_one_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                   optimizer: optim.Optimizer, device: torch.device, epoch: int,
                   scaler: GradScaler = None, use_mixup: bool = True, 
                   mixup_alpha: float = 0.2, use_amp: bool = False,
                   use_cutmix: bool = True) -> Tuple[float, float]:
    """Train ``model`` for one pass over ``loader``.

    Roughly half of the batches are mixed (when ``use_mixup`` is set); a mixed
    batch uses CutMix with probability 0.5 if ``use_cutmix`` is set and Mixup
    otherwise.  Gradients are clipped to a global norm of 1.0.

    Args:
        model: Network to optimise (switched to train mode here).
        loader: Yields ``(images, labels)`` batches.
        criterion: Loss taking ``(logits, labels)``.
        optimizer: Optimiser stepped once per batch.
        device: Device the batches are moved to.
        epoch: Epoch number, used for the progress-bar label only.
        scaler: Gradient scaler; mixed precision is active only when this is
            given *and* ``use_amp`` is true.
        use_mixup: Master switch for batch mixing (Mixup and CutMix).
        mixup_alpha: Beta parameter passed to ``mixup_data``.
        use_amp: Request mixed-precision training.
        use_cutmix: Allow CutMix as the mixing variant.

    Returns:
        ``(mean batch loss, accuracy in percent)``.  For mixed batches the
        accuracy counts a prediction as ``lam`` / ``1 - lam`` correct against
        the two label sets.  ``(0.0, 0.0)`` for an empty loader.
    """
    model.train()
    amp_enabled = use_amp and scaler is not None

    running_loss = 0.0
    correct = 0
    total = 0
    batches_seen = 0
    num_batches = len(loader)
    report_every = max(1, num_batches // 20)

    pbar = tqdm(loader, desc=f"Epoch {epoch} [Train]", leave=False)

    # Count batches ourselves: tqdm refreshes its own counter lazily, so it
    # is not a reliable divisor for the running averages.
    for batches_seen, (images, labels) in enumerate(pbar, start=1):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        use_aug = random.random() < 0.5
        use_cutmix_now = use_cutmix and random.random() < 0.5
        mix_batch = use_aug and use_mixup

        # autocast(enabled=False) is a no-op, so one code path serves both
        # the mixed-precision and the full-precision case.
        with autocast(enabled=amp_enabled):
            if mix_batch:
                if use_cutmix_now:
                    inputs, labels_a, labels_b, lam = cutmix_data(images, labels, 1.0, device)
                else:
                    inputs, labels_a, labels_b, lam = mixup_data(images, labels, mixup_alpha, device)
                lam = float(lam)  # plain float, never a NumPy scalar

                outputs = model(inputs)
                loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)

        preds = outputs.argmax(dim=1)
        if mix_batch:
            correct += (lam * (preds == labels_a).sum().item()
                        + (1 - lam) * (preds == labels_b).sum().item())
        else:
            correct += (preds == labels).sum().item()

        optimizer.zero_grad(set_to_none=True)
        if amp_enabled:
            scaler.scale(loss).backward()
            # Gradients are still multiplied by the loss scale here; they
            # must be unscaled before the norm is compared with max_norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        running_loss += loss.item()
        total += images.size(0)

        if batches_seen % report_every == 0 or batches_seen == num_batches:
            pbar.set_postfix({
                "loss": f"{running_loss / batches_seen:.4f}",
                "acc": f"{100.0 * correct / total:.2f}%",
            })

    if batches_seen == 0 or total == 0:
        return 0.0, 0.0

    avg_loss = running_loss / batches_seen
    avg_acc = 100.0 * correct / total
    return avg_loss, avg_acc


# ============================================================
#                  8. Evaluation Function (UNCHANGED)
# ============================================================

def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module,
            device: torch.device, epoch: int, use_amp: bool = False,
            use_tta: bool = False) -> Tuple[float, float]:
    """Compute loss and accuracy of ``model`` on ``loader`` without gradients.

    Args:
        model: Network to evaluate (switched to eval mode here).
        loader: Yields ``(images, labels)`` batches.
        criterion: Loss taking ``(logits, labels)``.
        device: Device the batches are moved to.
        epoch: Epoch number, used for the progress-bar label only.
        use_amp: Run the forward pass under autocast.
        use_tta: Average the logits of each batch with those of the batch
            flipped along dim 3 (horizontal flip for NCHW input).

    Returns:
        ``(mean batch loss, accuracy in percent)``; ``(0.0, 0.0)`` for an
        empty loader.
    """
    model.eval()

    running_loss = 0.0
    correct = 0
    total = 0
    batches_seen = 0
    num_batches = len(loader)
    report_every = max(1, num_batches // 10)

    pbar = tqdm(loader, desc=f"Epoch {epoch} [Val]", leave=False)

    with torch.no_grad():
        # Count batches ourselves: tqdm refreshes its own counter lazily, so
        # it is not a reliable divisor for the running averages.
        for batches_seen, (images, labels) in enumerate(pbar, start=1):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # autocast(enabled=False) is a no-op, so a single code path
            # covers both precisions.
            with autocast(enabled=use_amp):
                outputs = model(images)
                if use_tta:
                    # TTA: Original + Flip
                    flipped = model(torch.flip(images, dims=[3]))
                    outputs = (outputs + flipped) / 2
                loss = criterion(outputs, labels)

            running_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += images.size(0)

            if batches_seen % report_every == 0 or batches_seen == num_batches:
                pbar.set_postfix({
                    "loss": f"{running_loss / batches_seen:.4f}",
                    "acc": f"{100.0 * correct / total:.2f}%",
                })

    if batches_seen == 0 or total == 0:
        return 0.0, 0.0

    avg_loss = running_loss / batches_seen
    avg_acc = 100.0 * correct / total
    return avg_loss, avg_acc


# ============================================================
#                  9. Training Loop
# ============================================================

def train_advanced(model, train_loader, val_loader, criterion, 
                   device, CONFIG, class_names):
    """Fit ``model`` with AdamW, per-epoch cosine LR decay and early stopping.

    After every epoch the model is validated; whenever validation accuracy
    improves, a checkpoint is written via ``save_checkpoint``.  Training stops
    early after ``CONFIG['patience']`` epochs without improvement.

    Args:
        model: Network to train (already on ``device``).
        train_loader: Training batches.
        val_loader: Validation batches.
        criterion: Loss used for both training and validation.
        device: Device the batches are moved to.
        CONFIG: Dict read for ``lr``, ``weight_decay``, ``epochs``,
            ``patience``, ``use_mixup``, ``mixup_alpha``, ``use_cutmix`` and
            ``use_tta`` (and passed on to ``save_checkpoint``).
        class_names: Stored in the checkpoint.

    Returns:
        ``(best_val_acc, best_epoch)``.
    """
    # 1. Optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=CONFIG['lr'],
        weight_decay=CONFIG['weight_decay']
    )

    # 2. Scheduler: stepped once per epoch, so its horizon is in epochs.
    # A OneCycleLR sized in batches but advanced per epoch never leaves its
    # warm-up floor (lr / 25); cosine annealing over the epoch count matches
    # the per-epoch step() below.  eta_min equals the floor the one-cycle
    # configuration would have annealed to (lr / (25 * 1000)).
    scheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max(1, CONFIG['epochs']),
        eta_min=CONFIG['lr'] / 25000.0
    )

    use_amp = torch.cuda.is_available()
    scaler = GradScaler() if use_amp else None

    best_val_acc = 0.0
    best_epoch = 0
    patience_counter = 0

    print("\n" + "="*60)
    print("Starting Optimized Training (EfficientNet-B2)...")
    print(f"Target: >75% validation accuracy")
    print("="*60)

    for epoch in range(1, CONFIG['epochs'] + 1):
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch,
            scaler, use_mixup=CONFIG['use_mixup'], 
            mixup_alpha=CONFIG['mixup_alpha'], use_amp=use_amp,
            use_cutmix=CONFIG['use_cutmix']
        )

        # TTA enabled after epoch 10
        use_tta = CONFIG['use_tta'] and epoch > 10
        val_loss, val_acc = evaluate(
            model, val_loader, criterion, device, epoch, 
            use_amp=use_amp, use_tta=use_tta
        )

        # Read the LR before stepping so the log shows the rate this epoch
        # was actually trained with.
        epoch_lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        print(f"\n[Epoch {epoch}/{CONFIG['epochs']}]")
        print(f"  Train -> Loss: {train_loss:.4f}  Acc: {train_acc:.2f}%")
        print(f"  Val   -> Loss: {val_loss:.4f}  Acc: {val_acc:.2f}%")
        print(f"  LR: {epoch_lr:.2e}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            patience_counter = 0
            save_checkpoint(model, optimizer, epoch, val_acc, val_loss, 
                          class_names, CONFIG)
            print(f"  âœ“ New best! Val Acc: {val_acc:.2f}%")
        else:
            patience_counter += 1
            print(f"  No improvement ({patience_counter}/{CONFIG['patience']})")

            if patience_counter >= CONFIG['patience']:
                print(f"\nEarly stopping triggered")
                break

    return best_val_acc, best_epoch


def save_checkpoint(model, optimizer, epoch, val_acc, val_loss, class_names, CONFIG):
    """Write the current best state to ``<save_dir>/best_model_effnet.pth``.

    Args:
        model: Module whose ``state_dict`` is stored.
        optimizer: Optimiser whose ``state_dict`` is stored.
        epoch: Epoch the checkpoint belongs to.
        val_acc: Validation accuracy at that epoch.
        val_loss: Validation loss at that epoch.
        class_names: Class names stored alongside the weights.
        CONFIG: Configuration dict; ``CONFIG['save_dir']`` is the target
            directory and the whole dict is stored under ``'config'``.
    """
    # Create the directory here as well so the function also works when it is
    # called without main() having prepared it.
    os.makedirs(CONFIG['save_dir'], exist_ok=True)
    best_path = os.path.join(CONFIG['save_dir'], "best_model_effnet.pth")

    # Write to a sibling file and swap it in, so an interrupted save cannot
    # leave a truncated "best" checkpoint behind.
    tmp_path = best_path + ".tmp"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_acc': val_acc,
        'val_loss': val_loss,
        'class_names': class_names,
        'config': CONFIG
    }, tmp_path)
    os.replace(tmp_path, best_path)


# ============================================================
#                  10. Main Pipeline
# ============================================================

def main():
    """Run the whole pipeline: seed, load data, split, build model, train.

    All settings live in the local ``CONFIG`` dict.  The best checkpoint is
    written below ``CONFIG['save_dir']`` by the training loop.
    """
    CONFIG = {
        'data_root': "C:/adam/AMIT_Diploma/grad_project/FER/archive_v1",
        'img_size': 224,
        'batch_size': 32, # B2 is larger, reduce batch size slightly if OOM
        'epochs': 50,
        'lr': 3e-4,
        'weight_decay': 1e-4,
        'dropout': 0.3,
        'val_size': 0.15,
        'random_seed': 42,
        'patience': 10,
        'use_class_weights': True,
        'use_label_smoothing': True,
        'label_smooth_eps': 0.1,
        'use_mixup': True,
        'mixup_alpha': 0.2, # Reduced from 0.4
        'use_cutmix': True,
        'use_tta': True,
        'save_dir': './checkpoints',
        'cache_images': True,
    }

    # Setup
    random.seed(CONFIG['random_seed'])
    np.random.seed(CONFIG['random_seed'])
    torch.manual_seed(CONFIG['random_seed'])
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(CONFIG['random_seed'])
        torch.backends.cudnn.benchmark = True

    os.makedirs(CONFIG['save_dir'], exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load Data
    image_paths, labels, class_names = gather_image_paths_and_labels(CONFIG['data_root'])
    num_classes = len(class_names)

    # Split
    sss = StratifiedShuffleSplit(n_splits=1, test_size=CONFIG['val_size'],
                                random_state=CONFIG['random_seed'])
    indices = np.arange(len(labels))
    train_idx, val_idx = next(sss.split(indices, labels))

    train_paths = [image_paths[i] for i in train_idx]
    train_labels_list = [labels[i] for i in train_idx]
    val_paths = [image_paths[i] for i in val_idx]
    val_labels_list = [labels[i] for i in val_idx]

    # Create Datasets
    train_tf, _val_tf = get_augmentation_transforms()

    # Note: validation transform is None because dataset handles normalization/resize if cached
    train_dataset = PreCachedImageDataset(train_paths, train_labels_list, transform=train_tf, 
                                        cache_images=CONFIG['cache_images'], img_size=CONFIG['img_size'], is_train=True)
    val_dataset = PreCachedImageDataset(val_paths, val_labels_list, transform=None, 
                                      cache_images=CONFIG['cache_images'], img_size=CONFIG['img_size'], is_train=False)

    train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=0)

    # Build Model (EfficientNet-B2)
    model = build_model(num_classes=num_classes, dropout_rate=CONFIG['dropout']).to(device)

    # Class weights: inverse frequency, average weight 1.0 over the classes
    # that actually occur in the training split.
    if CONFIG['use_class_weights']:
        train_counts = np.bincount(train_labels_list, minlength=num_classes)
        present = train_counts > 0
        # A class with no training sample gets weight 0.  Giving it
        # 1 / epsilon instead would dominate the normalisation and push the
        # weight of every real class towards zero.
        weights = np.zeros(num_classes, dtype=np.float64)
        weights[present] = 1.0 / train_counts[present]
        weights = weights / weights.sum() * present.sum()
        class_weights = torch.tensor(weights, dtype=torch.float32).to(device)
        print(f"\nClass weights (Normalized): {class_weights.cpu().numpy()}")
    else:
        class_weights = None

    # Loss
    if CONFIG['use_label_smoothing']:
        criterion = LabelSmoothingCrossEntropy(epsilon=CONFIG['label_smooth_eps'], weight=class_weights)
    else:
        criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Train (optimizer and LR schedule are configured inside train_advanced)
    best_val_acc, best_epoch = train_advanced(
        model, train_loader, val_loader, criterion, device, CONFIG, class_names
    )
    print(f"\nBest validation accuracy: {best_val_acc:.2f}% (epoch {best_epoch})")

# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()