## 1. 📦 Imports and GPU Verification

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import os
import time
from collections import defaultdict
import random
from PIL import Image
import matplotlib.pyplot as plt

# GPU verification: report the installed PyTorch version and whether CUDA is
# usable. When it is, also print the name of device 0 and the CUDA version
# exposed by torch.version.cuda.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only device index 0 is inspected; other GPUs (if any) are not listed.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

In [None]:
# Optional dependency: albumentations provides the advanced augmentation
# pipeline. HAS_ALBUMENTATIONS records whether the import succeeded so later
# cells can fall back to torchvision transforms.
try:
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
except ImportError:
    HAS_ALBUMENTATIONS = False
    print("‚ö†Ô∏è Install albumentations: pip install albumentations")
else:
    HAS_ALBUMENTATIONS = True
    print("‚úÖ Albumentations available for advanced augmentation")

## 2. ⚙️ Hyperparameter Configuration

All training parameters are centralized here to facilitate experimentation.

In [None]:
class Config:
    """Central training configuration; every setting is a class attribute.

    Settings are grouped by concern: data location, model shape, optimisation,
    regularisation techniques, augmentation switches, class balancing,
    GPU/runtime options, SWA, device selection and checkpoint path.

    The author's stated intent for this profile is "maximum quality": the first
    epoch is expected to be slow and later epochs fast.
    """

    # === DATA ===
    DATASET_ROOT = './kaggle/input/balanced-affectnet'

    # === MODEL ===
    NUM_CLASSES = 8
    IN_CHANNELS = 3
    INPUT_SIZE = 75
    USE_SE_BLOCKS = True

    # === TRAINING ===
    BATCH_SIZE = 1536     # Author's note: good GPU compromise (uses ~10-11GB)
    ACCUMULATION_STEPS = 1
    LEARNING_RATE = 0.0015  # Author's note: adjusted for batch size 1536
    WEIGHT_DECAY = 1e-4
    EPOCHS = 100
    PATIENCE = 20  # presumably early-stopping patience in epochs - confirm against the training loop

    # === ADVANCED TECHNIQUES ===
    USE_MIXUP = True
    MIXUP_ALPHA = 0.2
    USE_CUTMIX = False
    CUTMIX_ALPHA = 1.0
    CUTMIX_PROB = 0.0

    USE_LABEL_SMOOTHING = True
    LABEL_SMOOTHING = 0.1

    USE_FOCAL_LOSS = False
    FOCAL_GAMMA = 2.0

    # === AUGMENTATION ===
    USE_ADVANCED_AUG = True   # read by get_train_transforms() to pick the albumentations pipeline
    USE_CLAHE = False
    USE_GRID_DISTORTION = False

    # === CLASS BALANCING ===
    USE_OVERSAMPLING = False
    MAX_CLASS_WEIGHT = 3.0

    # === MAXIMUM GPU OPTIMIZATION ===
    USE_AMP = True                    # Mixed precision
    USE_COMPILE = True                # torch.compile
    COMPILE_MODE = 'max-autotune'     # Author's note: epoch 1 slow (~2-3min) but the rest very fast
    NUM_WORKERS = 2                   # presumably the DataLoader worker count - confirm at the call site
    PREFETCH_FACTOR = 4
    PERSISTENT_WORKERS = True

    # === SWA ===
    USE_SWA = False
    SWA_START_EPOCH = 75
    SWA_LR = 0.0001

    # === DEVICE ===
    # Evaluated once, when the class body runs: CUDA if available, otherwise CPU.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # === SAVING ===
    SAVE_PATH = 'emotion_model_best.pth'

# Module-level configuration instance used by the rest of the notebook.
config = Config()

# Global CUDA/cuDNN tuning. These are process-wide switches and only take
# effect when a CUDA device is available.
if torch.cuda.is_available():
    # Performance
    torch.backends.cudnn.benchmark = True          # Let cuDNN auto-tune kernels
    torch.backends.cuda.matmul.allow_tf32 = True   # Allow TF32 in matmuls
    torch.backends.cudnn.allow_tf32 = True         # Allow TF32 in cuDNN
    torch.backends.cudnn.deterministic = False     # Trades run-to-run determinism for speed
    torch.set_float32_matmul_precision('high')     # float32 matmul precision hint

    # Cap this process's CUDA allocations at 95% of the device memory.
    torch.cuda.set_per_process_memory_fraction(0.95)

    # Total memory of device 0, converted from bytes to GiB.
    gpu_mem_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Total VRAM: {gpu_mem_total:.1f} GB")
    # NOTE(review): the mode is hard-coded in this message rather than read from config.COMPILE_MODE.
    print(f"‚ö° Mode: max-autotune (epoch 1 slow, rest very fast)")

# Configuration banner (printed on CPU as well).
print(f"\n{'='*60}")
print("üìã MAXIMUM CONFIGURATION (Quality + Speed)")
print(f"{'='*60}")
print(f"Device: {config.DEVICE}")
print(f"‚ö° Batch size: {config.BATCH_SIZE}")
print(f"‚ö° Learning rate: {config.LEARNING_RATE}")
print(f"‚ö° torch.compile: {config.COMPILE_MODE}")
print(f"‚ö° Mixed Precision: {config.USE_AMP}")
# NOTE(review): printed unconditionally, although TF32 is only enabled in the CUDA branch above.
print(f"‚ö° TF32: Enabled")
# NOTE(review): the timings below are the author's estimates, not measured here.
print(f"‚ö†Ô∏è Epoch 1: ~2-3 min (compilation)")
print(f"‚úÖ Epochs 2+: ~20-25s (very fast)")
print(f"{'='*60}")

## 3. 📉 Loss Functions

### Focal Loss
Useful for imbalanced datasets - reduces the importance of easy examples.

### Label Smoothing Cross Entropy
Prevents the model from being too confident about predictions.

In [None]:
class FocalLoss(nn.Module):
    """Focal loss for multi-class classification, with optional label smoothing.

    Each sample's cross-entropy is scaled by ``(1 - p_t) ** gamma``, where
    ``p_t`` is the softmax probability assigned to the true class, so samples
    that are already classified confidently contribute less to the loss.

    Args:
        gamma: Focusing exponent. ``0`` removes the focal term.
        alpha: Optional per-class weights, indexed by class id. May be a tensor
            or any sequence accepted by ``torch.as_tensor``; it is moved to the
            device of the logits on every call.
        reduction: ``'mean'`` or ``'sum'``. Any other value returns the
            unreduced per-sample losses.
        label_smoothing: When > 0, the true class gets ``1 - label_smoothing``
            and the remainder is spread evenly over the other classes.
    """
    def __init__(self, gamma=2.0, alpha=None, reduction='mean', label_smoothing=0.0):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        """Return the focal loss for logits ``inputs`` (N, C) and class ids ``targets`` (N,)."""
        # One log-softmax serves both the cross-entropy and p_t.
        log_probs = F.log_softmax(inputs, dim=-1)
        target_log_probs = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)

        if self.label_smoothing > 0:
            n_classes = inputs.size(-1)
            targets_smooth = torch.full_like(log_probs, self.label_smoothing / (n_classes - 1))
            targets_smooth.scatter_(1, targets.unsqueeze(1), 1.0 - self.label_smoothing)
            ce_loss = -(targets_smooth * log_probs).sum(dim=-1)
        else:
            ce_loss = -target_log_probs

        pt = target_log_probs.exp()
        focal_weight = (1 - pt) ** self.gamma

        if self.alpha is not None:
            # alpha is a plain attribute (not a buffer), so Module.to() does not
            # move it; align device/dtype here to avoid cross-device errors.
            alpha = torch.as_tensor(self.alpha, dtype=focal_weight.dtype, device=focal_weight.device)
            focal_weight = focal_weight * alpha.gather(0, targets)

        loss = focal_weight * ce_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against a smoothed one-hot target.

    The true class receives ``1 - smoothing``; the remaining mass is divided
    evenly between the other ``n_classes - 1`` classes. The result is averaged
    over the batch.
    """
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, inputs, targets):
        log_probs = F.log_softmax(inputs, dim=-1)

        # Start from the off-target value everywhere, then overwrite the true class.
        off_target = self.smoothing / (inputs.size(-1) - 1)
        soft_targets = torch.full_like(log_probs, off_target)
        soft_targets.scatter_(1, targets.unsqueeze(1), 1.0 - self.smoothing)

        per_sample = -(soft_targets * log_probs).sum(dim=-1)
        return per_sample.mean()

print("‚úÖ Loss functions defined")

## 4. 🔀 Mixup & CutMix

Augmentation techniques that mix images to improve generalization.

In [None]:
def mixup_data(x, y, alpha=0.2):
    """Blend every sample with a randomly chosen partner from the same batch.

    Args:
        x: Batch of inputs; the first dimension is the batch dimension.
        y: Labels aligned with ``x``.
        alpha: Beta(alpha, alpha) parameter for the mixing coefficient. A value
            <= 0 disables mixing (coefficient 1).

    Returns:
        ``(mixed_x, y_a, y_b, lam)``: the blended inputs, the original labels,
        the partners' labels and the mixing coefficient.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # Random pairing: sample i is blended with sample perm[i].
    perm = torch.randperm(x.size(0)).to(x.device)
    partner_x = x[perm]

    mixed_x = lam * x + (1 - lam) * partner_x
    return mixed_x, y, y[perm], lam


def cutmix_data(x, y, alpha=1.0):
    """CutMix: paste a random rectangular patch from a partner sample into each sample.

    The input batch is left untouched; a modified copy is returned.

    Args:
        x: Image batch of shape (N, C, H, W).
        y: Labels aligned with ``x``.
        alpha: Beta(alpha, alpha) parameter controlling the patch area. As in
            ``mixup_data``, a value <= 0 disables mixing (empty patch, lam = 1).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``lam`` is the fraction of each image
        that still comes from the original sample, recomputed from the clipped
        patch actually pasted.
    """
    # np.random.beta rejects non-positive parameters; treat them as "no mixing".
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    _, _, H, W = x.shape
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Patch centre is uniform over the image; the box is clipped to the borders.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = int(np.clip(cx - cut_w // 2, 0, W))
    bby1 = int(np.clip(cy - cut_h // 2, 0, H))
    bbx2 = int(np.clip(cx + cut_w // 2, 0, W))
    bby2 = int(np.clip(cy + cut_h // 2, 0, H))

    # Work on a copy so the caller's batch is not silently overwritten.
    mixed_x = x.clone()
    mixed_x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]

    # Clipping may have shrunk the box, so derive lam from its real area.
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))

    return mixed_x, y, y[index], lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Combine the losses against both label sets, weighted by ``lam`` and ``1 - lam``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

print("‚úÖ Mixup and CutMix functions defined")

## 5. 🖼️ Transformations and Data Augmentation

Uses Albumentations for advanced augmentations (rotation, noise, blur, etc.)

In [None]:
def get_train_transforms():
    """Build the training augmentation pipeline (moderate-strength version).

    Returns an albumentations pipeline when the library is installed and
    ``config.USE_ADVANCED_AUG`` is set; otherwise a torchvision pipeline that
    takes a numpy image (it starts with ``ToPILImage``).
    """
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    if not (HAS_ALBUMENTATIONS and config.USE_ADVANCED_AUG):
        # Fallback to torchvision
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),
            transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.05),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std),
        ])

    # Small shifts, zooms and rotations.
    geometric = A.Affine(
        translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)},
        scale=(0.95, 1.05),
        rotate=(-10, 10),
        p=0.4
    )
    # Exactly one of noise / blur, applied to 20% of samples.
    # CLAHE and GridDistortion are deliberately left out (author's note: too aggressive on 75x75).
    degrade = A.OneOf([
        A.GaussNoise(std_range=(0.02, 0.08), p=1),
        A.GaussianBlur(blur_limit=(3, 5), p=1),
    ], p=0.2)
    # Exactly one photometric change, applied to 40% of samples.
    photometric = A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=1),
        A.RandomGamma(gamma_limit=(85, 115), p=1),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=15, p=1),
    ], p=0.4)
    # One or two small black holes.
    dropout = A.CoarseDropout(
        num_holes_range=(1, 2),
        hole_height_range=(4, 8),
        hole_width_range=(4, 8),
        fill=0,
        p=0.2
    )

    return A.Compose([
        A.HorizontalFlip(p=0.5),
        geometric,
        degrade,
        photometric,
        dropout,
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2(),
    ])


def get_val_transforms():
    """Build the validation pipeline: normalisation and tensor conversion only.

    Uses albumentations whenever it is installed, otherwise torchvision.
    """
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    if not HAS_ALBUMENTATIONS:
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std),
        ])

    return A.Compose([
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2(),
    ])

print("‚úÖ Transformations defined (balanced version)")

## 6. 📁 AffectNet Dataset

In [None]:
from torch.utils.data import Dataset, WeightedRandomSampler

class BalancedAffectNetDataset(Dataset):
    """
    Dataset for Balanced AffectNet, read from one sub-folder per emotion.

    Expected structure:
    data/
        train/Anger/, Contempt/, Disgust/, Fear/, Happy/, Neutral/, Sad/, Surprise/
        val/...
        test/...

    Args:
        root_dir: Folder containing the split sub-folders.
        split: Name of the split sub-folder to read.
        transform: Optional callable applied to each image (numpy HWC RGB array).
        use_albumentations: If True, ``transform`` is called as
            ``transform(image=...)`` and its ``'image'`` entry is returned.

    Raises:
        FileNotFoundError: If ``root_dir/split`` does not exist.
    """

    NUM_CLASSES = 8

    # Folder name -> class index.
    EMOTION_CLASSES = {
        'Anger': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3,
        'Sad': 4, 'Surprise': 5, 'Neutral': 6, 'Contempt': 7,
    }

    IDX_TO_EMOTION = {v: k for k, v in EMOTION_CLASSES.items()}

    def __init__(self, root_dir='./kaggle/input/balanced-affectnet', split='train', transform=None, use_albumentations=False):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.use_albumentations = use_albumentations

        self.images = []
        self.labels = []

        split_dir = os.path.join(root_dir, split)

        if not os.path.exists(split_dir):
            raise FileNotFoundError(
                f"Dataset not found: {split_dir}\n"
                f"Download from: https://www.kaggle.com/datasets/dollyprajapati182/balanced-affectnet"
            )

        # Index every image file, class by class.
        for emotion_name, emotion_idx in self.EMOTION_CLASSES.items():
            emotion_dir = os.path.join(split_dir, emotion_name)
            if not os.path.exists(emotion_dir):
                print(f"‚ö†Ô∏è {emotion_dir} not found, ignored...")
                continue

            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping identical across runs and machines.
            for img_name in sorted(os.listdir(emotion_dir)):
                if img_name.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                    self.images.append(os.path.join(emotion_dir, img_name))
                    self.labels.append(emotion_idx)

        print(f"üìÇ Loaded {len(self.images)} images from AffectNet {split}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Without a transform the image is a float tensor (C, H, W) scaled to [0, 1].
        """
        img_path = self.images[idx]
        label = self.labels[idx]

        # The context manager closes the file handle as soon as pixels are read.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))

        if self.transform:
            if self.use_albumentations:
                augmented = self.transform(image=image)
                image = augmented['image']
            else:
                image = self.transform(image)
        else:
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0

        return image, label

    def get_class_distribution(self):
        """Return the number of samples per class as an array of length NUM_CLASSES."""
        return np.bincount(self.labels, minlength=self.NUM_CLASSES)

    def get_labels(self):
        """Return all labels as a numpy array, in sample order."""
        return np.array(self.labels)


def get_class_weights(dataset, max_weight=5.0):
    """Compute inverse-frequency class weights, clipped and rescaled to mean 1.

    Args:
        dataset: Object exposing ``get_class_distribution()`` (samples per
            class). If it also has an ``IDX_TO_EMOTION`` mapping, that mapping
            is used for the printed class names, so the report matches the
            dataset actually passed in.
        max_weight: Upper clip applied before the final rescaling (the lower
            clip is fixed at 0.3). Because weights are rescaled after clipping,
            the returned values can end up slightly outside these bounds.

    Returns:
        torch.Tensor: float32 tensor with one weight per class.
    """
    counts = dataset.get_class_distribution()
    counts = np.maximum(counts, 1)  # empty classes would otherwise divide by zero

    weights = 1.0 / counts
    weights = weights / weights.sum() * len(weights)
    weights = np.clip(weights, 0.3, max_weight)
    weights = weights / weights.sum() * len(weights)

    # Names come from the dataset itself rather than a hard-wired class.
    idx_to_emotion = getattr(dataset, 'IDX_TO_EMOTION', {})

    print("\nüìä Class weights:")
    for i, (count, weight) in enumerate(zip(counts, weights)):
        emotion = idx_to_emotion.get(i, f"Class_{i}")
        print(f"    {emotion:10s}: {count:5d} samples, weight: {weight:.3f}")

    return torch.tensor(weights, dtype=torch.float32)


def get_balanced_sampler(dataset):
    """Build a sampler that draws each sample with probability inversely
    proportional to the size of its class.

    Sampling is with replacement and one "epoch" draws as many samples as the
    dataset contains.
    """
    labels = dataset.get_labels()

    # Per-class counts, floored at 1 so absent classes do not divide by zero.
    class_counts = np.maximum(
        np.bincount(labels, minlength=BalancedAffectNetDataset.NUM_CLASSES),
        1,
    )
    per_sample_weight = (1.0 / class_counts)[labels]

    return WeightedRandomSampler(
        weights=per_sample_weight,
        num_samples=len(per_sample_weight),
        replacement=True
    )

print("‚úÖ Dataset Classes defined")

## 6bis. 📁 FER2013 / FER2013+ Dataset

FER2013 is a classic emotion recognition dataset with ~35k images in 48x48 grayscale.
FER2013+ is a version with labels corrected and improved by Microsoft.

In [None]:
class FER2013Dataset(Dataset):
    """
    FER2013 Dataset with grayscale -> RGB conversion.

    Kaggle Structure (msambare/fer2013):
        train/angry/, disgust/, fear/, happy/, neutral/, sad/, surprise/
        test/...

    Every image is, automatically:
    - Converted to RGB (for model compatibility)
    - Resized to target_size x target_size (75 by default)

    Args:
        root_dir: Folder containing ``train/`` and ``test/``.
        split: ``'train'`` selects ``train/``; any other value selects ``test/``.
        transform: Optional callable applied to each image (numpy HWC RGB array).
        use_albumentations: If True, ``transform`` is called as
            ``transform(image=...)`` and its ``'image'`` entry is returned.
        target_size: Side length of the square output image.

    Raises:
        FileNotFoundError: If the selected split folder does not exist.
    """

    NUM_CLASSES = 7  # No Contempt in FER2013

    # Folder name -> class index.
    EMOTION_CLASSES = {
        'angry': 0, 'disgust': 1, 'fear': 2, 'happy': 3,
        'sad': 4, 'surprise': 5, 'neutral': 6
    }

    IDX_TO_EMOTION = {v: k.capitalize() for k, v in EMOTION_CLASSES.items()}

    def __init__(self, root_dir, split='train', transform=None,
                 use_albumentations=False, target_size=75):
        self.root_dir = root_dir
        self.split = 'train' if split == 'train' else 'test'  # FER2013 only has train/test
        self.transform = transform
        self.use_albumentations = use_albumentations
        self.target_size = target_size

        self.images = []
        self.labels = []

        split_dir = os.path.join(root_dir, self.split)

        if not os.path.exists(split_dir):
            raise FileNotFoundError(f"FER2013 not found: {split_dir}")

        for emotion_name, emotion_idx in self.EMOTION_CLASSES.items():
            emotion_dir = os.path.join(split_dir, emotion_name)
            if not os.path.exists(emotion_dir):
                continue

            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping identical across runs and machines.
            for img_name in sorted(os.listdir(emotion_dir)):
                if img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.images.append(os.path.join(emotion_dir, img_name))
                    self.labels.append(emotion_idx)

        print(f"üìÇ FER2013 {self.split}: {len(self.images)} images loaded")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Without a transform the image is a float tensor (C, H, W) scaled to [0, 1].
        """
        img_path = self.images[idx]
        label = self.labels[idx]

        # The context manager closes the file handle as soon as pixels are read.
        with Image.open(img_path) as img:
            # convert('RGB') handles grayscale and every other mode alike.
            image = img.convert('RGB')

        # Resize to target_size (48x48 -> 75x75)
        if image.size != (self.target_size, self.target_size):
            image = image.resize((self.target_size, self.target_size), Image.BILINEAR)

        image = np.array(image)

        if self.transform:
            if self.use_albumentations:
                augmented = self.transform(image=image)
                image = augmented['image']
            else:
                image = self.transform(image)
        else:
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0

        return image, label

    def get_class_distribution(self):
        """Return the number of samples per class as an array of length NUM_CLASSES."""
        return np.bincount(self.labels, minlength=self.NUM_CLASSES)

    def get_labels(self):
        """Return all labels as a numpy array, in sample order."""
        return np.array(self.labels)


# ===============================================================================
# 📦 FER+ (FER2013 with labels corrected by Microsoft)
# ===============================================================================

def download_ferplus_labels(dest_folder):
    """
    Downloads fer2013new.csv from Microsoft FERPlus repo into ``dest_folder``.

    The file is fetched under a temporary name and renamed only after the
    download completes, so an interrupted transfer can never be mistaken for
    a valid, already-present label file on the next call.

    Args:
        dest_folder: Existing folder in which to store ``fer2013new.csv``.

    Returns:
        str: Path to the label file (existing or freshly downloaded),
        or None if the download failed.
    """
    import urllib.request

    url = "https://raw.githubusercontent.com/microsoft/FERPlus/master/fer2013new.csv"
    dest_path = os.path.join(dest_folder, "fer2013new.csv")

    if os.path.exists(dest_path):
        print(f"  ‚úì fer2013new.csv already present")
        return dest_path

    print(f"  üì• Downloading fer2013new.csv from GitHub...")
    tmp_path = dest_path + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, dest_path)  # atomic publish of the finished file
    except Exception as e:
        # Best-effort by design: report the failure and let the caller decide.
        # Remove any partial download so it is not left lying around.
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # nothing was written, or it cannot be removed; nothing more to do
        print(f"  ‚úó Error: {e}")
        return None

    print(f"  ‚úì Downloaded: {dest_path}")
    return dest_path


def generate_ferplus_images(fer2013_csv_path, ferplus_csv_path, output_folder):
    """
    Generates 48x48 grayscale PNG images from the pixel strings in fer2013.csv.

    Output structure:
        output_folder/
            FER2013Train/
                fer00000000.png
                ...
            FER2013Valid/
            FER2013Test/

    Files are named after the zero-based data-row index in fer2013.csv
    (8 digits), counting skipped rows too. FERPlusDataset rebuilds the same
    names from the row index in fer2013new.csv, so this numbering must not change.

    Args:
        fer2013_csv_path: CSV with columns ``emotion, pixels, Usage``.
        ferplus_csv_path: Accepted for interface compatibility; not read here.
        output_folder: Root folder for the three split sub-folders.

    Returns:
        bool: True if success
    """
    import csv

    # Mapping usage -> folder
    usage_to_folder = {
        'Training': 'FER2013Train',
        'PublicTest': 'FER2013Valid',
        'PrivateTest': 'FER2013Test'
    }

    # Create folders
    for folder in usage_to_folder.values():
        os.makedirs(os.path.join(output_folder, folder), exist_ok=True)

    # Read fer2013.csv and generate images
    print(f"  üñºÔ∏è Generating images from fer2013.csv...")

    generated = 0
    # newline='' is what the csv module expects from files it reads.
    with open(fer2013_csv_path, 'r', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header (emotion,pixels,Usage); tolerate an empty file

        for idx, row in enumerate(reader):
            if len(row) < 3:
                continue

            pixels = row[1]
            usage = row[2]

            # Convert the space-separated pixel string to a 48x48 image
            pixel_values = [int(p) for p in pixels.split()]
            img_array = np.array(pixel_values, dtype=np.uint8).reshape(48, 48)
            img = Image.fromarray(img_array, mode='L')

            # Save; rows with an unrecognised usage go to the training folder
            folder = usage_to_folder.get(usage, 'FER2013Train')
            img_name = f"fer{idx:08d}.png"
            img_path = os.path.join(output_folder, folder, img_name)
            img.save(img_path)
            generated += 1

            if idx % 5000 == 0:
                print(f"    Progress: {idx} images...")

    # Report what was actually written (also valid when the CSV had no rows).
    print(f"  ‚úì {generated} images generated")
    return True


def setup_ferplus_dataset(fer2013_kaggle_path, output_folder=None):
    """
    Prepares the FER+ dataset on disk and returns its location.

    Steps:
    1. Make sure fer2013new.csv (FER+ labels) is present in the output folder
    2. Locate fer2013.csv (pixel data) under the given Kaggle path
    3. Generate the PNG images

    Steps 2-3 are skipped when FER2013Train/ already holds more than 1000 files.

    Args:
        fer2013_kaggle_path: Folder expected to contain fer2013.csv, directly
            or under fer2013/ or data/
        output_folder: Output folder; defaults to a 'ferplus_generated' folder
            next to fer2013_kaggle_path

    Returns:
        str: Path to the ready-to-use FER+ dataset, or None when the labels
        could not be downloaded or fer2013.csv could not be found
    """
    if output_folder is None:
        parent_dir = os.path.dirname(fer2013_kaggle_path)
        output_folder = os.path.join(parent_dir, 'ferplus_generated')

    os.makedirs(output_folder, exist_ok=True)

    # Shortcut: images from a previous run are reused.
    train_folder = os.path.join(output_folder, 'FER2013Train')
    already_generated = os.path.exists(train_folder) and len(os.listdir(train_folder)) > 1000
    if already_generated:
        print(f"  ‚úì FER+ already generated in {output_folder}")
        # Fetch the labels anyway in case only the images survived.
        download_ferplus_labels(output_folder)
        return output_folder

    print("\nüîß Configuring FER+ (first use)...")

    # 1. FER+ labels from GitHub
    ferplus_csv = download_ferplus_labels(output_folder)
    if ferplus_csv is None:
        return None

    # 2. fer2013.csv: try the root of the Kaggle folder first, then known sub-folders.
    # The Kaggle dataset msambare/fer2013 ships images only; the CSV comes from
    # the original dataset (deadskull7/fer2013).
    candidate_paths = (
        os.path.join(fer2013_kaggle_path, 'fer2013.csv'),
        os.path.join(fer2013_kaggle_path, 'fer2013', 'fer2013.csv'),
        os.path.join(fer2013_kaggle_path, 'data', 'fer2013.csv'),
    )
    fer2013_csv = next((path for path in candidate_paths if os.path.exists(path)), None)

    if fer2013_csv is None:
        print(f"  ‚ö†Ô∏è fer2013.csv not found. FER+ requires the original CSV dataset.")
        print(f"     The Kaggle dataset 'msambare/fer2013' is in image format.")
        print(f"     For FER+, use 'deadskull7/fer2013' which contains the CSV.")
        return None

    # 3. Generate images
    if generate_ferplus_images(fer2013_csv, ferplus_csv, output_folder):
        return output_folder
    return None


class FERPlusDataset(Dataset):
    """
    FER2013+ (FER+) Dataset with labels corrected by Microsoft.

    FER+ improves FER2013 with:
    - Labels voted by 10 annotators (more reliable)
    - 8 classes (addition of Contempt)
    - Possibility to use vote probabilities

    The dataset expects, under ``root_dir``:
    - fer2013new.csv (labels) from Microsoft GitHub
    - FER2013Train/, FER2013Valid/, FER2013Test/ with images named
      ``fer{row_index:08d}.png`` (as written by generate_ferplus_images)

    If the split folder or the label file is missing, a warning is printed and
    the dataset is simply empty (no exception is raised).
    """

    NUM_CLASSES = 8

    # Vote columns of fer2013new.csv, in file order. In each row they follow the
    # usage column and one further column (the image name in the published file),
    # i.e. they are row[2:10]; the trailing unknown / NF columns are not read.
    FERPLUS_EMOTIONS = ['neutral', 'happiness', 'surprise', 'sadness', 'anger', 'disgust', 'fear', 'contempt']

    # Mapping FER+ order -> Unified order (AffectNet)
    # FER+: neutral(0), happiness(1), surprise(2), sadness(3), anger(4), disgust(5), fear(6), contempt(7)
    # Unified: Anger(0), Disgust(1), Fear(2), Happy(3), Sad(4), Surprise(5), Neutral(6), Contempt(7)
    FERPLUS_TO_UNIFIED = {
        0: 6,  # neutral -> Neutral
        1: 3,  # happiness -> Happy
        2: 5,  # surprise -> Surprise
        3: 4,  # sadness -> Sad
        4: 0,  # anger -> Anger
        5: 1,  # disgust -> Disgust
        6: 2,  # fear -> Fear
        7: 7,  # contempt -> Contempt
    }

    IDX_TO_EMOTION = {
        0: 'Anger', 1: 'Disgust', 2: 'Fear', 3: 'Happy',
        4: 'Sad', 5: 'Surprise', 6: 'Neutral', 7: 'Contempt'
    }

    def __init__(self, root_dir, split='train', transform=None,
                 use_albumentations=False, target_size=75,
                 label_mode='majority', min_votes=1):
        """
        Args:
            root_dir: Path to FER+ dataset (with FER2013Train/, etc.)
            split: 'train', 'val' or 'test' (anything else is treated as 'train')
            transform: Optional callable applied to each image (numpy HWC RGB array)
            use_albumentations: If True, transform is called as transform(image=...)
            target_size: Side length of the square output image
            label_mode: 'majority' (most voted label) or 'probability'
                (items additionally carry the normalised vote distribution)
            min_votes: Minimum number of emotion votes to include an image
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.use_albumentations = use_albumentations
        self.target_size = target_size
        self.label_mode = label_mode
        self.min_votes = min_votes

        self.images = []
        self.labels = []
        self.vote_distributions = []  # Filled only in probability mode, aligned with self.images

        # Mapping split -> folder
        split_to_folder = {
            'train': 'FER2013Train',
            'val': 'FER2013Valid',
            'test': 'FER2013Test'
        }

        folder_name = split_to_folder.get(split, 'FER2013Train')
        split_dir = os.path.join(root_dir, folder_name)

        if not os.path.exists(split_dir):
            print(f"‚ö†Ô∏è FER+ {split} not found: {split_dir}")
            return

        # Load labels from fer2013new.csv
        ferplus_csv = os.path.join(root_dir, 'fer2013new.csv')
        if not os.path.exists(ferplus_csv):
            print(f"‚ö†Ô∏è fer2013new.csv not found in {root_dir}")
            return

        self._load_data(split_dir, ferplus_csv, split)

        print(f"üìÇ FER+ {split}: {len(self.images)} images loaded (mode: {label_mode})")

    def _load_data(self, split_dir, ferplus_csv, split):
        """Index the images of one split and derive their labels from the vote counts."""
        import csv

        # Mapping split -> value of the usage column in the CSV
        usage_mapping = {
            'train': 'Training',
            'val': 'PublicTest',
            'test': 'PrivateTest'
        }
        target_usage = usage_mapping.get(split, 'Training')

        # newline='' is what the csv module expects from files it reads.
        with open(ferplus_csv, 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # Skip header; tolerate an empty file

            # idx counts every data row (all splits, skipped rows included)
            # because image file names are derived from it.
            for idx, row in enumerate(reader):
                if len(row) < 10:
                    continue

                # Filter by split
                if row[0] != target_usage:
                    continue

                # Votes for the 8 emotions (columns 2-9); non-numeric cells count as 0.
                votes = [int(v) if v.strip().isdigit() else 0 for v in row[2:10]]
                total_votes = sum(votes)

                # Ignore if not enough emotion votes (e.g. all votes went to
                # the unknown / NF columns, which are not counted here)
                if total_votes < self.min_votes:
                    continue

                # Image path
                img_name = f"fer{idx:08d}.png"
                img_path = os.path.join(split_dir, img_name)

                if not os.path.exists(img_path):
                    continue

                # Majority label, translated to the unified class order
                ferplus_label = int(np.argmax(votes))
                unified_label = self.FERPLUS_TO_UNIFIED[ferplus_label]

                self.images.append(img_path)
                self.labels.append(unified_label)

                # Store distribution for probability mode
                if self.label_mode == 'probability':
                    vote_dist = np.array(votes, dtype=np.float32)
                    # With min_votes=0 a row may have no votes at all; keep the
                    # all-zero vector instead of dividing by zero (NaN targets).
                    if total_votes > 0:
                        vote_dist = vote_dist / vote_dist.sum()  # Normalize
                    # Reorder according to unified order
                    unified_dist = np.zeros(8, dtype=np.float32)
                    for ferplus_idx, unified_idx in self.FERPLUS_TO_UNIFIED.items():
                        unified_dist[unified_idx] = vote_dist[ferplus_idx]
                    self.vote_distributions.append(unified_dist)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``, or ``(image, label, vote_distribution)`` in probability mode."""
        img_path = self.images[idx]
        label = self.labels[idx]

        # The context manager closes the file handle as soon as pixels are read.
        with Image.open(img_path) as img:
            # convert('RGB') handles grayscale and every other mode alike.
            image = img.convert('RGB')

        # Resize
        if image.size != (self.target_size, self.target_size):
            image = image.resize((self.target_size, self.target_size), Image.BILINEAR)

        image = np.array(image)

        if self.transform:
            if self.use_albumentations:
                augmented = self.transform(image=image)
                image = augmented['image']
            else:
                image = self.transform(image)
        else:
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0

        # Return according to mode
        if self.label_mode == 'probability' and len(self.vote_distributions) > idx:
            return image, label, torch.tensor(self.vote_distributions[idx])

        return image, label

    def get_class_distribution(self):
        """Return the number of samples per unified class as an array of length NUM_CLASSES."""
        return np.bincount(self.labels, minlength=self.NUM_CLASSES)

    def get_labels(self):
        """Return all unified labels as a numpy array, in sample order."""
        return np.array(self.labels)


print("‚úÖ FER2013 and FER+ Datasets defined")

## 6ter. 🔀 Combined Multi-Source Dataset

This dataset combines AffectNet, FER2013 and/or FER+ by unifying classes into 8 emotions.

In [None]:
class CombinedEmotionDataset(Dataset):
    """
    Dataset combining multiple sources with unified class mapping.

    Combines AffectNet, FER2013 and FER+ with:
    - Automatic resizing to target_size
    - Automatic conversion of every image to RGB
    - Unified mapping to 8 classes (AffectNet order)

    Unified classes:
        0: Anger, 1: Disgust, 2: Fear, 3: Happy,
        4: Sad, 5: Surprise, 6: Neutral, 7: Contempt

    Image folders are scanned in sorted order so that sample indices are
    reproducible from one run (and one file system) to the next.
    """

    # 8 unified classes (AffectNet order)
    UNIFIED_CLASSES = ['Anger', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral', 'Contempt']
    NUM_CLASSES = 8
    IDX_TO_EMOTION = {i: c for i, c in enumerate(UNIFIED_CLASSES)}

    def __init__(self, datasets_config, split='train', transform=None,
                 use_albumentations=False, target_size=75):
        """
        Args:
            datasets_config: dict {dataset_name: root_path}
                Example: {'affectnet': '/path/to/affectnet', 'fer2013': '/path/to/fer2013'}
                Entries whose path is None or missing on disk are skipped
                with a message, as are unknown dataset names.
            split: 'train', 'val', or 'test'
            transform: optional callable applied to each image (a NumPy array)
            use_albumentations: if True, ``transform`` is called as
                ``transform(image=...)`` and its ``'image'`` entry is used
            target_size: uniform output size (75 by default)
        """
        self.transform = transform
        self.use_albumentations = use_albumentations
        self.target_size = target_size

        self.images = []   # List of dicts: {'path': str, 'is_grayscale': bool}
        self.labels = []   # Unified class index of each entry in self.images
        self.sources = []  # Originating dataset name of each entry (tracking/debug)

        loaders = {
            'affectnet': self._load_affectnet,
            'fer2013': self._load_fer2013,
            'ferplus': self._load_ferplus,
        }
        total_by_source = {}

        for dataset_name, root_dir in datasets_config.items():
            if root_dir is None or not os.path.exists(root_dir):
                print(f"‚ö†Ô∏è {dataset_name} ignored (not found): {root_dir}")
                continue

            loader = loaders.get(dataset_name)
            if loader is None:
                print(f"‚ö†Ô∏è Unknown dataset: {dataset_name}")
                continue

            count_before = len(self.images)
            loader(root_dir, split)
            total_by_source[dataset_name] = len(self.images) - count_before

        print(f"\n{'='*50}")
        print(f"üìä COMBINED DATASET ({split})")
        print(f"{'='*50}")
        for src, count in total_by_source.items():
            print(f"  {src:15s}: {count:6d} images")
        print(f"  {'TOTAL':15s}: {len(self.images):6d} images")
        print(f"{'='*50}")
        self._print_class_distribution()

    def _add_sample(self, path, unified_idx, source, is_grayscale):
        """Registers one image with its unified label and source name."""
        self.images.append({'path': path, 'is_grayscale': is_grayscale})
        self.labels.append(unified_idx)
        self.sources.append(source)

    def _load_class_folders(self, split_dir, class_to_idx, extensions, source, is_grayscale):
        """Registers every image found in ``split_dir/<class name>/`` folders."""
        for class_name, unified_idx in class_to_idx.items():
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.isdir(class_dir):
                continue

            # os.listdir order is arbitrary; sort for reproducible indices.
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith(extensions):
                    self._add_sample(os.path.join(class_dir, img_name),
                                     unified_idx, source, is_grayscale)

    def _load_affectnet(self, root_dir, split):
        """Loads AffectNet images."""
        affectnet_mapping = {
            'Anger': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3,
            'Sad': 4, 'Surprise': 5, 'Neutral': 6, 'Contempt': 7
        }

        split_dir = os.path.join(root_dir, split)
        if not os.path.exists(split_dir):
            return

        self._load_class_folders(split_dir, affectnet_mapping,
                                 ('.jpg', '.jpeg', '.png', '.bmp'),
                                 'affectnet', is_grayscale=False)

    def _load_fer2013(self, root_dir, split):
        """Loads FER2013 images."""
        # FER2013 only has train/test, no val: every non-train split reads 'test'
        fer_split = 'train' if split == 'train' else 'test'
        split_dir = os.path.join(root_dir, fer_split)

        if not os.path.exists(split_dir):
            return

        # Mapping FER2013 (7 classes) -> unified (8 classes)
        fer_to_unified = {
            'angry': 0, 'disgust': 1, 'fear': 2, 'happy': 3,
            'sad': 4, 'surprise': 5, 'neutral': 6
        }

        # FER2013 is grayscale
        self._load_class_folders(split_dir, fer_to_unified,
                                 ('.jpg', '.jpeg', '.png'),
                                 'fer2013', is_grayscale=True)

    def _load_ferplus(self, root_dir, split):
        """Loads FER+ images (with labels corrected by Microsoft)."""
        import csv

        # Mapping FER+ order -> Unified order
        ferplus_to_unified = {
            0: 6,  # neutral -> Neutral
            1: 3,  # happiness -> Happy
            2: 5,  # surprise -> Surprise
            3: 4,  # sadness -> Sad
            4: 0,  # anger -> Anger
            5: 1,  # disgust -> Disgust
            6: 2,  # fear -> Fear
            7: 7,  # contempt -> Contempt
        }

        # Mapping split -> folder and usage (unknown splits fall back to train)
        split_mapping = {
            'train': ('FER2013Train', 'Training'),
            'val': ('FER2013Valid', 'PublicTest'),
            'test': ('FER2013Test', 'PrivateTest')
        }

        folder_name, target_usage = split_mapping.get(split, ('FER2013Train', 'Training'))
        split_dir = os.path.join(root_dir, folder_name)

        if not os.path.exists(split_dir):
            print(f"    ‚ö†Ô∏è FER+ {split} not found: {split_dir}")
            return

        # Find fer2013new.csv
        ferplus_csv = os.path.join(root_dir, 'fer2013new.csv')
        if not os.path.exists(ferplus_csv):
            print(f"    ‚ö†Ô∏è fer2013new.csv not found in {root_dir}")
            return

        # newline='' is the csv-module recommended way to open its input
        with open(ferplus_csv, 'r', newline='') as f:
            reader = csv.reader(f)
            next(reader, None)  # Skip header (tolerates an empty file)

            # Image files are named after the CSV row index, so idx must count
            # every data row, including the ones skipped below.
            for idx, row in enumerate(reader):
                if len(row) < 10:
                    continue

                # Filter by split
                if row[0] != target_usage:
                    continue

                # Votes for the 8 emotions are read from row[2:10];
                # non-numeric cells count as 0 votes.
                try:
                    votes = [int(v.strip()) if v.strip().isdigit() else 0 for v in row[2:10]]
                except ValueError:
                    # isdigit() accepts a few characters that int() rejects
                    continue

                if sum(votes) == 0:
                    continue

                # Label = emotion with most votes
                unified_label = ferplus_to_unified[int(np.argmax(votes))]

                img_path = os.path.join(split_dir, f"fer{idx:08d}.png")
                if os.path.exists(img_path):
                    self._add_sample(img_path, unified_label, 'ferplus', is_grayscale=True)

    def _print_class_distribution(self):
        """Displays class distribution."""
        if len(self.labels) == 0:
            return
        counts = self.get_class_distribution()
        print("\n  Distribution by class:")
        max_count = max(counts) if len(counts) > 0 else 1
        for cls, count in zip(self.UNIFIED_CLASSES, counts):
            bar_len = int(30 * count / max_count) if max_count > 0 else 0
            bar = '‚ñà' * bar_len
            print(f"    {cls:10s}: {count:6d} {bar}")

    def __len__(self):
        """Returns the number of registered images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Returns ``(image, label)`` for the sample at position ``idx``.

        The image is decoded as RGB, resized to ``target_size`` when needed
        and passed through ``self.transform``; without a transform it becomes
        a float32 channel-first tensor scaled by 1/255.
        """
        img_info = self.images[idx]
        label = self.labels[idx]

        # Every image is converted to RGB regardless of its source, so the
        # 'is_grayscale' flag needs no special handling here. The context
        # manager releases the file handle as soon as the pixels are decoded.
        with Image.open(img_info['path']) as raw:
            image = raw.convert('RGB')

        # Resize to target_size
        side = self.target_size
        if image.size != (side, side):
            image = image.resize((side, side), Image.BILINEAR)

        image = np.array(image)

        # Apply transformations
        if not self.transform:
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0
        elif self.use_albumentations:
            image = self.transform(image=image)['image']
        else:
            image = self.transform(image)

        return image, label

    def get_class_distribution(self):
        """Returns the number of samples per unified class (length NUM_CLASSES)."""
        if len(self.labels) == 0:
            return np.zeros(self.NUM_CLASSES, dtype=int)
        return np.bincount(self.labels, minlength=self.NUM_CLASSES)

    def get_labels(self):
        """Returns all labels as a NumPy array."""
        return np.array(self.labels)

    def get_source_distribution(self):
        """Returns the number of images per source."""
        from collections import Counter
        return Counter(self.sources)


# Confirmation message emitted once this notebook cell has finished executing.
print("‚úÖ CombinedEmotionDataset defined")

## 7. 🧠 CNN Model Architecture (with SE Blocks)

In [None]:
# ‚úÖ NEW: Squeeze-and-Excitation Block to improve feature attention
class SEBlock(nn.Module):
    """Squeeze-and-Excitation Block - improves feature quality.

    Averages each channel of the input down to a single value, maps that
    vector through a two-layer bottleneck ending in a sigmoid, and multiplies
    the input channels by the resulting per-channel gates.

    Args:
        channels: Number of channels of the 4-D input feature map.
        reduction: Bottleneck ratio. The hidden layer has
            ``channels // reduction`` units, but never fewer than one.
    """
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        # Guard against channels < reduction: a zero-width bottleneck would
        # make every gate the constant sigmoid(0) and leave nothing to learn.
        hidden = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Returns ``x`` with each channel scaled by its learned gate."""
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        # (b, c, 1, 1) broadcasts over the spatial dimensions of x
        return x * y


class ConvBlock(nn.Module):
    """Two 3x3 conv + BatchNorm + ReLU stages, optional SE gating, then 2x2 max-pooling."""
    def __init__(self, in_channels, out_channels, use_se=True, reduction=16):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Identity keeps forward() free of branching when SE is switched off
        if use_se:
            self.se = SEBlock(out_channels, reduction)
        else:
            self.se = nn.Identity()

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        # Channel gating is applied before the spatial down-sampling
        return self.pool(self.se(out))


class FaceEmotionCNN(nn.Module):
    """CNN for emotion recognition: four SE conv blocks, global pooling, MLP head.

    Args:
        num_classes: Size of the output logit vector.
        in_channels: Number of channels of the input images.
        input_size: Accepted for interface compatibility; the layers defined
            here do not read it (global average pooling removes the
            dependency on the spatial size).
    """
    def __init__(self, num_classes=8, in_channels=3, input_size=75):
        super().__init__()

        # Feature extractor; each block halves the spatial size
        # (75 -> 37 -> 18 -> 9 -> 4 for a 75x75 input).
        self.block1 = ConvBlock(in_channels, 32, use_se=True, reduction=8)
        self.block2 = ConvBlock(32, 64, use_se=True, reduction=8)
        self.block3 = ConvBlock(64, 128, use_se=True, reduction=16)
        self.block4 = ConvBlock(128, 256, use_se=True, reduction=16)

        # Collapse whatever spatial extent remains to one value per channel
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Classification head on the 256 pooled features
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)
        pooled = self.global_pool(x)
        features = torch.flatten(pooled, 1)
        return self.classifier(features)


def create_model(dataset='affectnet', num_classes=8):
    """Build the emotion classifier for the given dataset family.

    Only ``'affectnet'`` is supported; it yields a ``FaceEmotionCNN`` whose
    input channels and input size come from the global ``config``.

    Raises:
        ValueError: If ``dataset`` is anything other than ``'affectnet'``.
    """
    if dataset != 'affectnet':
        raise ValueError(f"Unknown dataset: {dataset}")

    return FaceEmotionCNN(
        num_classes=num_classes,
        in_channels=config.IN_CHANNELS,
        input_size=config.INPUT_SIZE,
    )

# Create and display the model.
# This instance is built from the current config values to report its size;
# `model` is assigned again (and moved to the device) in the training
# configuration cell further down.
model = create_model(dataset='affectnet', num_classes=config.NUM_CLASSES)
# Total number of elements across all parameter tensors
total_params = sum(p.numel() for p in model.parameters())
print(f"üß† Model created with SE Blocks: {total_params:,} parameters")

## 8. 🔧 Training Utilities

In [None]:
class AverageMeter:
    """Tracks the latest value and the running weighted average of a metric.

    Attributes:
        val: Most recent value passed to ``update``.
        sum: Weighted sum of all values since the last ``reset``.
        count: Total weight (e.g. number of samples) since the last ``reset``.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clears all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Records ``val`` as observed ``n`` times (e.g. a batch mean and its size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero-weight first update (empty batch) must not divide by zero
        self.avg = self.sum / self.count if self.count else 0


def validate(model, val_loader, criterion, device, per_class=False, use_amp=False):
    """Validation with optional per-class metrics and AMP support.

    Args:
        model: Network to evaluate; switched to eval mode.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.
        per_class: If True, also prints the accuracy of each class.
        use_amp: If True, runs the forward pass under CUDA autocast.

    Returns:
        ``(average loss, accuracy in percent)``. Both are 0 when the loader
        yields no samples.
    """
    model.eval()

    loss_meter = AverageMeter()
    correct = 0
    total = 0
    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # Mixed precision for validation too
            with torch.amp.autocast('cuda', enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            loss_meter.update(loss.item(), inputs.size(0))

            _, predicted = outputs.max(1)
            hits = predicted.eq(labels)
            total += labels.size(0)
            correct += hits.sum().item()

            if per_class:
                # One host transfer per batch instead of one .item() per sample
                for label, hit in zip(labels.tolist(), hits.tolist()):
                    class_total[label] += 1
                    if hit:
                        class_correct[label] += 1

    # An empty loader would otherwise raise ZeroDivisionError
    accuracy = 100.0 * correct / total if total else 0.0

    if per_class:
        emotions = ['Anger', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral', 'Contempt']
        print("\n  üìä Accuracy per class:")
        for i, emo in enumerate(emotions):
            if class_total[i] > 0:
                acc = 100.0 * class_correct[i] / class_total[i]
                print(f"    {emo:10s}: {acc:5.1f}% ({class_correct[i]}/{class_total[i]})")

    return loss_meter.avg, accuracy

# Confirmation message emitted once this notebook cell has finished executing.
print("‚úÖ Utilities defined (with AMP support)")

## 9. 📂 Data Loading

### Dataset Configuration

Choose the datasets you want to use for training. The notebook will automatically download the selected datasets via `kagglehub`.

**Available datasets:**
- `affectnet`: Main dataset (8 classes, ~50k images, RGB)
- `fer2013`: Classic dataset (7 classes, ~35k images, grayscale 48x48)
- `ferplus`: FER2013 with labels corrected by Microsoft (8 classes)

In [None]:
# Install kagglehub (IPython shell escape: only valid inside a notebook)
!pip install -q kagglehub
print("‚úÖ kagglehub installed.")

# ===============================================================================
# DATASET CONFIGURATION
# ===============================================================================
#
# IMPORTANT: FER2013 and FER+ use the SAME images but with different labels!
#    - fer2013: Original labels (7 classes, noisier)
#    - ferplus: Labels corrected by Microsoft (8 classes, 10 annotators)
#
#    -> Do not use both at the same time (duplicates)!
#    -> Prefer FER+ for better quality
#

# Names listed here decide which datasets the following cells download and load.
DATASETS_TO_USE = [
    'affectnet',    # Main dataset (8 classes, ~50k images)
    # 'fer2013',    # Disabled: replaced by FER+ (same images, worse labels)
    'ferplus',      # FER+ with corrected labels (8 classes, ~35k images)
]

# Mode: 'combined' to merge all datasets, 'single' to use only the first one.
# Derived from the number of selected datasets, not from how many downloads succeed.
DATASET_MODE = 'combined' if len(DATASETS_TO_USE) > 1 else 'single'

# Summary of the selection, with a warning when both FER variants are enabled
print(f"\n{'='*60}")
print(f"üìã DATASET CONFIGURATION")
print(f"{'='*60}")
print(f"  Selected datasets: {DATASETS_TO_USE}")
print(f"  Mode: {DATASET_MODE}")
if 'ferplus' in DATASETS_TO_USE:
    print(f"  ‚ÑπÔ∏è FER+ = FER2013 with labels corrected by Microsoft (better quality)")
if 'fer2013' in DATASETS_TO_USE and 'ferplus' in DATASETS_TO_USE:
    print(f"  ‚ö†Ô∏è WARNING: fer2013 and ferplus use the same images!")
print(f"{'='*60}")

In [None]:
import kagglehub

# ===============================================================================
# DATASET DOWNLOAD
# Fills `dataset_paths` with one entry per selected dataset: the path on success,
# None when the download or setup failed.
# ===============================================================================

# Kaggle IDs for each dataset
KAGGLE_IDS = {
    'affectnet': 'dollyprajapati182/balanced-affectnet',
    'fer2013': 'msambare/fer2013',              # Folder version (images directly)
    'fer2013_csv': 'deadskull7/fer2013',        # Original CSV version (for FER+)
}

# dataset name -> path string, or None if unavailable
dataset_paths = {}

print("üì• Downloading datasets...\n")

# ===============================================================================
# 1. Download AffectNet
# ===============================================================================
if 'affectnet' in DATASETS_TO_USE:
    print(f"üì¶ [1/3] AffectNet...")
    try:
        path = kagglehub.dataset_download(KAGGLE_IDS['affectnet'])
        dataset_paths['affectnet'] = str(path)
        print(f"  ‚úì Downloaded: {path}")
    except Exception as e:
        # Best effort: record the failure and keep going with the other datasets
        print(f"  ‚úó Error: {e}")
        dataset_paths['affectnet'] = None

# ===============================================================================
# 2. Download FER2013 (folder version)
# ===============================================================================
if 'fer2013' in DATASETS_TO_USE:
    print(f"\nüì¶ [2/3] FER2013...")
    try:
        path = kagglehub.dataset_download(KAGGLE_IDS['fer2013'])
        dataset_paths['fer2013'] = str(path)
        print(f"  ‚úì Downloaded: {path}")
    except Exception as e:
        print(f"  ‚úó Error: {e}")
        dataset_paths['fer2013'] = None

# ===============================================================================
# 3. Configure FER+ (download CSV + generate images)
# ===============================================================================
if 'ferplus' in DATASETS_TO_USE:
    print(f"\nüì¶ [3/3] FER+ (FER2013 with Microsoft labels)...")

    # FER+ requires the original FER2013 CSV
    print(f"  üì• Downloading fer2013.csv...")
    try:
        fer2013_csv_path = kagglehub.dataset_download(KAGGLE_IDS['fer2013_csv'])
        print(f"  ‚úì fer2013.csv downloaded: {fer2013_csv_path}")

        # Configure FER+ (download labels + generate images).
        # setup_ferplus_dataset is defined elsewhere in the notebook; a falsy
        # return value is treated as "not configured" below.
        # NOTE(review): the output folder is hard-coded under /content — presumably
        # a Colab path; confirm before running in another environment.
        ferplus_path = setup_ferplus_dataset(fer2013_csv_path, output_folder='/content/ferplus_generated')

        if ferplus_path:
            dataset_paths['ferplus'] = ferplus_path
            print(f"  ‚úì FER+ configured: {ferplus_path}")
        else:
            print(f"  ‚ö†Ô∏è FER+ not configured (see errors above)")
            dataset_paths['ferplus'] = None

    except Exception as e:
        print(f"  ‚úó Error: {e}")
        dataset_paths['ferplus'] = None

# ===============================================================================
# Update config
# ===============================================================================
valid_paths = {k: v for k, v in dataset_paths.items() if v is not None}

# DATASET_ROOT prefers AffectNet and otherwise falls back to the first available
# dataset; it keeps its previous value when nothing was downloaded.
if 'affectnet' in valid_paths:
    config.DATASET_ROOT = valid_paths['affectnet']
elif valid_paths:
    config.DATASET_ROOT = list(valid_paths.values())[0]

print(f"\n{'='*60}")
print(f"‚úÖ DATASETS READY")
print(f"{'='*60}")
for name, path in dataset_paths.items():
    status = "‚úì" if path else "‚úó"
    print(f"  {status} {name}: {path if path else 'Not available'}")
print(f"{'='*60}")

In [None]:
# ===============================================================================
# DATA LOADING (MULTI-DATASET OR SINGLE)
# Defines train_dataset / val_dataset and sets config.NUM_CLASSES accordingly.
# ===============================================================================

print("üìÇ Loading datasets...")

train_transform = get_train_transforms()
val_transform = get_val_transforms()

# Keep only the datasets whose download/setup succeeded
valid_dataset_paths = {k: v for k, v in dataset_paths.items() if v is not None}

# Fail early with an actionable message instead of an IndexError further down
if not valid_dataset_paths:
    raise RuntimeError(
        "No dataset is available: every download failed or DATASETS_TO_USE is empty. "
        "Check the output of the download cell."
    )

# ===============================================================================
# MULTI-DATASET (combined) or SINGLE-DATASET MODE
# ===============================================================================

if DATASET_MODE == 'combined' and len(valid_dataset_paths) > 1:
    print(f"\nüîÄ MULTI-DATASET mode activated!")
    print(f"   Datasets: {list(valid_dataset_paths.keys())}")

    # Use combined dataset
    train_dataset = CombinedEmotionDataset(
        datasets_config=valid_dataset_paths,
        split='train',
        transform=train_transform,
        use_albumentations=HAS_ALBUMENTATIONS,
        target_size=config.INPUT_SIZE
    )

    val_dataset = CombinedEmotionDataset(
        datasets_config=valid_dataset_paths,
        split='val',
        transform=val_transform,
        use_albumentations=HAS_ALBUMENTATIONS,
        target_size=config.INPUT_SIZE
    )

    # Update config with unified classes (8 classes)
    config.NUM_CLASSES = CombinedEmotionDataset.NUM_CLASSES

else:
    # Also reached in 'combined' mode when only one dataset is actually available
    print(f"\nüìÅ SINGLE-DATASET Mode")

    # Use the first available dataset
    dataset_name = next(iter(valid_dataset_paths))
    root_path = valid_dataset_paths[dataset_name]
    print(f"   Dataset: {dataset_name}")

    if dataset_name == 'affectnet':
        train_dataset = BalancedAffectNetDataset(
            root_dir=root_path,
            split='train',
            transform=train_transform,
            use_albumentations=HAS_ALBUMENTATIONS
        )
        val_dataset = BalancedAffectNetDataset(
            root_dir=root_path,
            split='val',
            transform=val_transform,
            use_albumentations=HAS_ALBUMENTATIONS
        )
        config.NUM_CLASSES = 8

    elif dataset_name == 'fer2013':
        train_dataset = FER2013Dataset(
            root_dir=root_path,
            split='train',
            transform=train_transform,
            use_albumentations=HAS_ALBUMENTATIONS,
            target_size=config.INPUT_SIZE
        )
        val_dataset = FER2013Dataset(
            root_dir=root_path,
            split='val',
            transform=val_transform,
            use_albumentations=HAS_ALBUMENTATIONS,
            target_size=config.INPUT_SIZE
        )
        config.NUM_CLASSES = 7  # FER2013 has no Contempt

    elif dataset_name == 'ferplus':
        train_dataset = FERPlusDataset(
            root_dir=root_path,
            split='train',
            transform=train_transform,
            use_albumentations=HAS_ALBUMENTATIONS,
            target_size=config.INPUT_SIZE
        )
        val_dataset = FERPlusDataset(
            root_dir=root_path,
            split='val',
            transform=val_transform,
            use_albumentations=HAS_ALBUMENTATIONS,
            target_size=config.INPUT_SIZE
        )
        config.NUM_CLASSES = 8

    else:
        # Without this, train_dataset would stay undefined and fail later
        raise ValueError(f"Unsupported dataset for single-dataset mode: {dataset_name}")

# ===============================================================================
# CLASS WEIGHT CALCULATION (adaptive)
# ===============================================================================

def get_class_weights_adaptive(dataset, max_weight=5.0):
    """Derive clipped inverse-frequency class weights from a dataset.

    Weights start proportional to ``1 / class count`` (empty classes are
    counted as 1), are scaled to sum to the number of classes, clipped to
    ``[0.3, max_weight]`` and scaled once more to the same sum. A table of
    counts and weights is printed along the way.

    Args:
        dataset: Object exposing ``get_class_distribution()``; class names are
            read from ``IDX_TO_EMOTION`` or ``UNIFIED_CLASSES`` when present.
        max_weight: Upper clipping bound applied before the final rescale.

    Returns:
        A float32 tensor holding one weight per class.
    """
    counts = np.maximum(dataset.get_class_distribution(), 1)
    num_classes = len(counts)

    def _scale_to_class_count(values):
        return values / values.sum() * num_classes

    weights = _scale_to_class_count(1.0 / counts)
    weights = _scale_to_class_count(np.clip(weights, 0.3, max_weight))

    # Resolve display names; indices missing from the mapping fall back to "Class_<i>"
    if hasattr(dataset, 'IDX_TO_EMOTION'):
        idx_to_emotion = dataset.IDX_TO_EMOTION
    elif hasattr(dataset, 'UNIFIED_CLASSES'):
        idx_to_emotion = dict(enumerate(dataset.UNIFIED_CLASSES))
    else:
        idx_to_emotion = {}

    print(f"\nüìä Class weights ({num_classes} classes):")
    for i in range(num_classes):
        count, weight = counts[i], weights[i]
        emotion = idx_to_emotion.get(i, f"Class_{i}")
        print(f"    {emotion:10s}: {count:6d} samples, weight: {weight:.3f}")

    return torch.tensor(weights, dtype=torch.float32)

# Per-class loss weights computed from the training split and moved to the training device
class_weights = get_class_weights_adaptive(train_dataset, max_weight=config.MAX_CLASS_WEIGHT).to(config.DEVICE)

# ===============================================================================
# OPTIMIZED DATALOADERS
# ===============================================================================

# Training loader: shuffled, and the last incomplete batch is dropped so every
# step sees a full batch. prefetch_factor / persistent_workers are only passed
# through when worker processes are used (NUM_WORKERS > 0).
train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=config.NUM_WORKERS,
    pin_memory=True,
    drop_last=True,
    prefetch_factor=config.PREFETCH_FACTOR if config.NUM_WORKERS > 0 else None,
    persistent_workers=config.PERSISTENT_WORKERS if config.NUM_WORKERS > 0 else False
)

# Validation loader: fixed order, keeps every sample, and uses twice the
# training batch size.
val_loader = DataLoader(
    val_dataset,
    batch_size=config.BATCH_SIZE * 2,
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=True,
    prefetch_factor=config.PREFETCH_FACTOR if config.NUM_WORKERS > 0 else None,
    persistent_workers=config.PERSISTENT_WORKERS if config.NUM_WORKERS > 0 else False
)

# Summary of what was loaded
print(f"\n{'='*60}")
print("‚úÖ DATA LOADED")
print(f"{'='*60}")
print(f"  Train: {len(train_dataset):,} samples ({len(train_loader)} batches)")
print(f"  Val:   {len(val_dataset):,} samples ({len(val_loader)} batches)")
print(f"  Classes: {config.NUM_CLASSES}")
print(f"  Input size: {config.INPUT_SIZE}x{config.INPUT_SIZE}")
print(f"  Batch size: {config.BATCH_SIZE}")
print(f"  ‚ö° Workers: {config.NUM_WORKERS}, Prefetch: {config.PREFETCH_FACTOR}")
print(f"{'='*60}")

## 10. 👀 Sample Visualization

In [None]:
# Visualize some images from the dataset (multi-dataset compatible)
def show_samples(dataset, n_samples=8):
    """Displays samples from the dataset (compatible with all datasets).

    Draws up to ``n_samples`` randomly chosen images in a four-column grid.
    Each image is titled with its emotion name and, when the dataset exposes
    ``sources``, the dataset it came from. The grid grows with ``n_samples``
    and cells without an image are hidden.

    Args:
        dataset: Indexable dataset whose items start with ``(image, label)``;
            any further elements are ignored. Images are expected in
            channel-first layout.
        n_samples: Maximum number of images to display.

    Note:
        Pixel values are de-normalized with the ImageNet mean/std, which
        assumes the dataset's transform applied that normalization.
    """
    n_cols = 4
    indices = random.sample(range(len(dataset)), min(n_samples, len(dataset)))
    n_rows = max(1, -(-len(indices) // n_cols))  # ceiling division, at least one row

    # squeeze=False always yields a 2-D array, so flatten() works for one row too
    _, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
    axes = axes.flatten()

    # ImageNet Denormalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Get idx -> emotion mapping depending on dataset type
    if hasattr(dataset, 'IDX_TO_EMOTION'):
        idx_to_emotion = dataset.IDX_TO_EMOTION
    elif hasattr(dataset, 'UNIFIED_CLASSES'):
        idx_to_emotion = {i: c for i, c in enumerate(dataset.UNIFIED_CLASSES)}
    else:
        idx_to_emotion = {i: f"Class_{i}" for i in range(config.NUM_CLASSES)}

    for ax, idx in zip(axes, indices):
        # Some datasets return extra elements after (image, label); ignore them
        img, label, *_ = dataset[idx]

        # Convert tensor to numpy (channel-last) and denormalize
        if isinstance(img, torch.Tensor):
            img_np = img.numpy().transpose(1, 2, 0)
        else:
            img_np = img.transpose(1, 2, 0) if img.shape[0] == 3 else img

        img_np = img_np * std + mean
        img_np = np.clip(img_np, 0, 1)

        emotion = idx_to_emotion.get(label, f"Class_{label}")

        # Display source if available (multi-dataset)
        if hasattr(dataset, 'sources') and idx < len(dataset.sources):
            source = dataset.sources[idx]
            title = f"{emotion}\n({source})"
        else:
            title = emotion

        ax.imshow(img_np)
        ax.set_title(title, fontsize=10)
        ax.axis('off')

    # Hide the grid cells that received no image
    for ax in axes[len(indices):]:
        ax.axis('off')

    # Title depending on mode
    if DATASET_MODE == 'combined' and len(valid_dataset_paths) > 1:
        dataset_type = f"Multi-Dataset ({', '.join(valid_dataset_paths.keys())})"
    else:
        dataset_type = list(valid_dataset_paths.keys())[0] if valid_dataset_paths else "Unknown"

    plt.suptitle(f'Samples - {dataset_type}', fontsize=14)
    plt.tight_layout()
    plt.show()

# Preview a few random training samples as a visual check of the loading pipeline
show_samples(train_dataset)

## 11. 🚀 Training Configuration (with SWA)

In [None]:
# ===============================================================================
# TRAINING CONFIGURATION
# Builds the model, loss functions, optimizer, LR scheduler and AMP grad scaler.
# ===============================================================================

# Model
model = create_model(dataset='affectnet', num_classes=config.NUM_CLASSES).to(config.DEVICE)

# Model Compilation (PyTorch 2.0+)
if config.USE_COMPILE and hasattr(torch, 'compile'):
    try:
        # The compile mode comes from config.COMPILE_MODE; compilation makes
        # the start of training slower.
        model = torch.compile(model, mode=config.COMPILE_MODE)
        print(f"‚ö° Model compiled with torch.compile(mode='{config.COMPILE_MODE}')")
        print("   Note: First epochs will be slower (compilation)")
    except Exception as e:
        # Best effort: continue with the uncompiled model
        print(f"‚ö†Ô∏è torch.compile not available: {e}")

# Loss Function
if config.USE_FOCAL_LOSS:
    criterion = FocalLoss(
        gamma=config.FOCAL_GAMMA,
        alpha=class_weights,
        label_smoothing=config.LABEL_SMOOTHING if config.USE_LABEL_SMOOTHING else 0.0
    )
    print(f"‚úì Focal Loss (gamma={config.FOCAL_GAMMA})")
elif config.USE_LABEL_SMOOTHING:
    # NOTE(review): class_weights is not passed to this criterion, so class
    # balancing only takes effect in the other two branches — confirm this is intended.
    criterion = LabelSmoothingCrossEntropy(smoothing=config.LABEL_SMOOTHING)
    print(f"‚úì Label Smoothing (smoothing={config.LABEL_SMOOTHING})")
else:
    criterion = nn.CrossEntropyLoss(weight=class_weights)

# Validation uses plain, unweighted cross-entropy
val_criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY
)

# OneCycleLR Scheduler: peaks at 10x LEARNING_RATE; the schedule length is
# EPOCHS * len(train_loader) steps.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.LEARNING_RATE * 10,
    epochs=config.EPOCHS,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,
    anneal_strategy='cos'
)

# GradScaler for Mixed Precision
scaler = torch.amp.GradScaler('cuda', enabled=config.USE_AMP)

# Display configuration
print(f"\n{'='*60}")
print("üìã Training Configuration:")
print(f"{'='*60}")
# Report the datasets, input size and class count actually in use rather than
# a fixed description.
print(f"  Dataset: {', '.join(valid_dataset_paths) or 'Unknown'} "
      f"({config.INPUT_SIZE}x{config.INPUT_SIZE} RGB, {config.NUM_CLASSES} classes)")
print(f"  Batch size: {config.BATCH_SIZE}")
print(f"  Learning rate: {config.LEARNING_RATE} -> {config.LEARNING_RATE * 10}")
print(f"  Epochs: {config.EPOCHS}, Patience: {config.PATIENCE}")
print(f"  Mixup: {config.USE_MIXUP} (alpha={config.MIXUP_ALPHA})")
print(f"  ‚ö° Mixed Precision (AMP): {config.USE_AMP}")
print(f"  ‚ö° torch.compile: {config.COMPILE_MODE}")
print(f"{'='*60}")

## 12. 🏋️ Training Loop

In [None]:
# ===============================================================================
# UNIFIED TRAINING LOOP
# ===============================================================================
# Runs up to config.EPOCHS epochs of training with optional Mixup, mixed
# precision, gradient accumulation and SWA. After every epoch the model is
# validated; a checkpoint is written whenever validation accuracy improves, and
# training stops early after config.PATIENCE epochs without improvement.

import gc

# Tracking variables
best_val_acc = 0.0
best_val_loss = float('inf')  # validation loss recorded at the best-accuracy epoch
patience_counter = 0
best_epoch = 0  # 1-based; stays 0 until some epoch beats the initial accuracy

# History for plots
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'lr': [], 'epoch_time': [], 'gpu_memory': []
}

# SWA Setup (optional)
swa_model = None
swa_scheduler = None
if config.USE_SWA and not config.USE_COMPILE:
    from torch.optim.swa_utils import AveragedModel, SWALR
    swa_model = AveragedModel(model)
    swa_scheduler = SWALR(optimizer, swa_lr=config.SWA_LR)
    print(f"‚úÖ SWA activated (starts at epoch {config.SWA_START_EPOCH})")
elif config.USE_SWA and config.USE_COMPILE:
    print("‚ö†Ô∏è SWA disabled because torch.compile is enabled (incompatible)")

start_time = time.time()

print("\n" + "=" * 70)
print("üöÄ STARTING TRAINING")
print("=" * 70)
print(f"Mixed Precision: {config.USE_AMP}")
print(f"Batch size: {config.BATCH_SIZE}")
print(f"Workers: {config.NUM_WORKERS}")
print(f"Epochs: {config.EPOCHS}, Patience: {config.PATIENCE}")
print("=" * 70 + "\n")

num_train_batches = len(train_loader)

for epoch in range(config.EPOCHS):
    epoch_start = time.time()
    model.train()

    loss_meter = AverageMeter()
    correct = 0
    total = 0

    optimizer.zero_grad(set_to_none=True)

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(config.DEVICE, non_blocking=True), labels.to(config.DEVICE, non_blocking=True)

        # Mixup only (CutMix disabled as it lowers performance).
        # When enabled, Mixup is applied to roughly half of the batches.
        use_mixup = config.USE_MIXUP and random.random() > 0.5

        if use_mixup:
            inputs, labels_a, labels_b, lam = mixup_data(inputs, labels, config.MIXUP_ALPHA)

        # Mixed Precision Forward Pass
        with torch.amp.autocast('cuda', enabled=config.USE_AMP):
            outputs = model(inputs)

            if use_mixup:
                loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
            else:
                loss = criterion(outputs, labels)

            # Scale down so gradients summed over an accumulation group average out.
            loss = loss / config.ACCUMULATION_STEPS

        # Backward with GradScaler
        scaler.scale(loss).backward()

        # Gradient accumulation: step every ACCUMULATION_STEPS batches, and also
        # on the last batch of the epoch so gradients from a trailing incomplete
        # group are applied instead of being discarded by the next zero_grad().
        is_group_end = (batch_idx + 1) % config.ACCUMULATION_STEPS == 0
        is_last_batch = (batch_idx + 1) == num_train_batches
        if is_group_end or is_last_batch:
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            # Scheduler step (not during SWA)
            if swa_model is None or epoch < config.SWA_START_EPOCH:
                scheduler.step()

            optimizer.zero_grad(set_to_none=True)

        # Metrics (undo the accumulation scaling so the logged loss is per-batch)
        loss_meter.update(loss.item() * config.ACCUMULATION_STEPS, inputs.size(0))

        # Accuracy is only measured on non-Mixup batches, where each sample
        # has a single ground-truth label.
        if not use_mixup:
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    train_acc = 100.0 * correct / max(total, 1)

    # SWA update after SWA_START_EPOCH
    if swa_model is not None and epoch >= config.SWA_START_EPOCH:
        swa_model.update_parameters(model)
        swa_scheduler.step()

    # Validation (per-class breakdown requested every 10th epoch)
    val_loss, val_acc = validate(model, val_loader, val_criterion, config.DEVICE,
                                 per_class=(epoch % 10 == 0), use_amp=config.USE_AMP)

    current_lr = optimizer.param_groups[0]['lr']
    epoch_time = time.time() - epoch_start

    # GPU memory tracking (max_memory_allocated is the peak since process start,
    # as it is never reset here)
    if torch.cuda.is_available():
        gpu_mem = torch.cuda.memory_allocated() / 1024**3
        gpu_mem_max = torch.cuda.max_memory_allocated() / 1024**3
    else:
        gpu_mem = 0
        gpu_mem_max = 0

    # Save history
    history['train_loss'].append(loss_meter.avg)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['lr'].append(current_lr)
    history['epoch_time'].append(epoch_time)
    history['gpu_memory'].append(gpu_mem_max)

    swa_status = " [SWA]" if swa_model is not None and epoch >= config.SWA_START_EPOCH else ""
    print(f"Epoch {epoch+1:3d}/{config.EPOCHS} | "
          f"Loss: {loss_meter.avg:.4f} | Acc: {train_acc:.1f}% | "
          f"Val: {val_acc:.1f}% | LR: {current_lr:.6f} | "
          f"Time: {epoch_time:.1f}s | GPU: {gpu_mem:.1f}/{gpu_mem_max:.1f}GB{swa_status}")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_val_loss = val_loss
        best_epoch = epoch + 1
        patience_counter = 0

        # Get model weights (handle torch.compile): save the wrapped module
        # when the model exposes one via `_orig_mod`.
        model_to_save = model._orig_mod if hasattr(model, '_orig_mod') else model

        torch.save({
            'epoch': epoch,  # 0-based, unlike best_epoch
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'history': history,
            'config': {
                'num_classes': config.NUM_CLASSES,
                'in_channels': config.IN_CHANNELS,
                'input_size': config.INPUT_SIZE,
                'dataset': 'affectnet',
            }
        }, config.SAVE_PATH)
        print(f"  ‚úÖ [BEST] New best model! (Val Acc: {val_acc:.2f}%)")
    else:
        patience_counter += 1
        if patience_counter >= config.PATIENCE:
            print(f"\n‚èπÔ∏è Early stopping after {epoch+1} epochs!")
            break

# Memory cleanup
torch.cuda.empty_cache()
gc.collect()

elapsed = time.time() - start_time
# Guard the summary statistics so a run with zero completed epochs still
# reports cleanly instead of raising on an empty history.
avg_epoch_time = np.mean(history['epoch_time']) if history['epoch_time'] else 0.0
max_gpu_memory = max(history['gpu_memory'], default=0.0)

print(f"\n{'='*70}")
print("‚úÖ TRAINING COMPLETED!")
print(f"{'='*70}")
print(f"Total time: {elapsed/60:.1f} minutes")
print(f"Average time per epoch: {avg_epoch_time:.1f} seconds")
print(f"Best epoch: {best_epoch}")
print(f"Best validation accuracy: {best_val_acc:.2f}%")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Max GPU memory: {max_gpu_memory:.2f} GB")
print(f"Model saved: {config.SAVE_PATH}")
print(f"{'='*70}")

## 13. 📈 Results Visualization

In [None]:
# ===============================================================================
# TRAINING RESULTS VISUALIZATION
# ===============================================================================
# Draws a 2x3 dashboard from `history`: loss, accuracy, learning rate,
# epoch duration, peak GPU memory, and a text summary of the run.

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
(loss_ax, acc_ax, lr_ax), (time_ax, mem_ax, summary_ax) = axes

# Computed once; used by the timing panel and the summary text.
mean_epoch_time = np.mean(history['epoch_time'])

# 1. Loss: training vs. validation
loss_ax.plot(history['train_loss'], label='Train', color='blue')
loss_ax.plot(history['val_loss'], label='Validation', color='orange')
loss_ax.set(xlabel='Epoch', ylabel='Loss', title='üìâ Loss')
loss_ax.legend()

# 2. Accuracy: training vs. validation, with the best validation value marked
acc_ax.plot(history['train_acc'], label='Train', color='blue')
acc_ax.plot(history['val_acc'], label='Validation', color='orange')
acc_ax.axhline(y=best_val_acc, color='green', linestyle='--', label=f'Best: {best_val_acc:.1f}%')
acc_ax.set(xlabel='Epoch', ylabel='Accuracy (%)', title='üìä Accuracy')
acc_ax.legend()

# 3. Learning rate recorded at the end of each epoch
lr_ax.plot(history['lr'], color='green')
lr_ax.set(xlabel='Epoch', ylabel='Learning Rate', title='üìà Learning Rate (OneCycleLR)')

# 4. Wall-clock time per epoch, with the mean as a dashed reference line
time_ax.plot(history['epoch_time'], color='purple')
time_ax.axhline(y=mean_epoch_time, color='red', linestyle='--',
                label=f'Average: {mean_epoch_time:.1f}s')
time_ax.set(xlabel='Epoch', ylabel='Time (s)', title='‚è±Ô∏è Time per Epoch')
time_ax.legend()

# 5. Peak GPU memory recorded per epoch
mem_ax.plot(history['gpu_memory'], color='red')
mem_ax.set(xlabel='Epoch', ylabel='Memory (GB)', title='üéÆ Max GPU Memory')

# Same light grid on every data panel.
for panel in (loss_ax, acc_ax, lr_ax, time_ax, mem_ax):
    panel.grid(True, alpha=0.3)

# 6. Summary: text-only panel
summary_ax.axis('off')
summary_text = f"""
üìã TRAINING SUMMARY

Best accuracy: {best_val_acc:.2f}%
Best epoch: {best_epoch}

‚öôÔ∏è Configuration:
‚Ä¢ Batch Size: {config.BATCH_SIZE}
‚Ä¢ Epochs: {len(history['train_loss'])}
‚Ä¢ Mixed Precision: {config.USE_AMP}
‚Ä¢ torch.compile: {config.USE_COMPILE}
‚Ä¢ Mixup: {config.USE_MIXUP} (Œ±={config.MIXUP_ALPHA})
‚Ä¢ CutMix: {config.USE_CUTMIX}
‚Ä¢ SE Blocks: {config.USE_SE_BLOCKS}

‚è±Ô∏è Performance:
‚Ä¢ Average time/epoch: {mean_epoch_time:.1f}s
"""
summary_ax.text(0.1, 0.5, summary_text, fontsize=12, va='center')

plt.tight_layout()
plt.show()

## 14. 🔍 Final Evaluation (with TTA)

In [None]:
# ===============================================================================
# FINAL EVALUATION WITH TTA (Test-Time Augmentation)
# ===============================================================================
# If OOM: restart the kernel (Runtime > Restart) before running this cell.
# The CUDA Graphs cache of torch.compile cannot be released otherwise.

import gc


def release_references(namespace, names):
    """Remove each name in `names` from the mapping `namespace`, if present.

    Args:
        namespace: Mutable mapping of variable names to objects (e.g. `globals()`).
        names: Iterable of variable names to drop.

    Returns:
        List of the names that were actually removed, in the order given.
    """
    removed = []
    for name in names:
        if name in namespace:
            del namespace[name]
            removed.append(name)
    return removed


# AGGRESSIVE GPU MEMORY CLEANUP
print("üßπ Cleaning GPU memory...")

# Drop the training-time objects so the tensors they hold become collectable.
release_references(
    globals(),
    ['model', 'swa_model', 'optimizer', 'scheduler', 'scaler', 'criterion'],
)

# Collect garbage first so tensors kept alive only by reference cycles are
# freed before the cached CUDA blocks are released.
gc.collect()

if torch.cuda.is_available():
    # Force cleanup
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    # Display available memory.
    # "Free" here is total device memory minus what this process has allocated
    # through PyTorch; it does not account for other processes.
    gpu_free = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
    gpu_reserved = torch.cuda.memory_reserved() / 1024**3
    print(f"   Free GPU memory: {gpu_free / 1024**3:.2f} GB")
    print(f"   PyTorch reserved memory: {gpu_reserved:.2f} GB")
    if gpu_free < 2 * 1024**3:  # Less than 2GB free
        print("   ‚ö†Ô∏è Low free memory - using small batches")


def validate_with_tta(model, val_loader, criterion, device, n_augmentations=5, use_amp=False):
    """Evaluate `model` with Test-Time Augmentation (TTA).

    For every batch the model is run on up to five views of the input and the
    resulting logits are averaged before computing loss and predictions. The
    views, in the order they are enabled, are: original, horizontal flip,
    input * 0.95, input * 1.05, and horizontal flip * 0.98.

    Args:
        model: Network to evaluate; switched to eval mode.
        val_loader: Iterable yielding `(inputs, labels)` batches. Inputs must
            have at least 4 dimensions, since the flip is applied on dim 3.
        criterion: Loss applied to the averaged logits and the labels.
        device: Device the batches are moved to.
        n_augmentations: Number of views to average, clamped to the range 1-5.
        use_amp: Run the forward passes under CUDA autocast.

    Returns:
        Tuple `(avg_loss, accuracy)`: sample-weighted mean loss and accuracy
        in percent. Both are 0.0 when the loader yields no samples.
    """
    # Each view receives the batch and its horizontally flipped copy.
    # NOTE(review): the multiplicative views scale the already-preprocessed
    # tensor; if inputs are mean/std-normalized this is not a pure brightness
    # change - confirm against the dataset transforms.
    tta_views = (
        lambda x, x_flipped: x,
        lambda x, x_flipped: x_flipped,
        lambda x, x_flipped: x * 0.95,
        lambda x, x_flipped: x * 1.05,
        lambda x, x_flipped: x_flipped * 0.98,
    )
    n_views = min(max(int(n_augmentations), 1), len(tta_views))
    active_views = tta_views[:n_views]

    model.eval()

    correct = 0
    total = 0
    loss_sum = 0.0

    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            batch_size = inputs.size(0)

            # Flip once and share it between the views that need it.
            flipped = torch.flip(inputs, dims=[3]) if n_views >= 2 else None

            # Collect predictions from the selected views
            with torch.amp.autocast('cuda', enabled=use_amp):
                all_outputs = [model(view(inputs, flipped)) for view in active_views]

            # Average the logits in float32 so the loss is not computed on
            # reduced-precision values produced under autocast.
            avg_outputs = torch.stack(all_outputs).float().mean(dim=0)

            loss = criterion(avg_outputs, labels)
            loss_sum += loss.item() * batch_size

            _, predicted = avg_outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

            # One host transfer per tensor instead of an .item() call per sample.
            for pred, label in zip(predicted.tolist(), labels.tolist()):
                class_total[label] += 1
                if pred == label:
                    class_correct[label] += 1

            # Free memory of intermediate outputs
            del all_outputs, avg_outputs, flipped, inputs, labels
            torch.cuda.empty_cache()

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    accuracy = 100.0 * correct / max(total, 1)
    avg_loss = loss_sum / max(total, 1)

    emotions = ['Anger', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral', 'Contempt']
    print(f"\n  üìä Accuracy per class (TTA x{n_views}):")
    for i, emo in enumerate(emotions):
        if class_total[i] > 0:
            acc = 100.0 * class_correct[i] / class_total[i]
            print(f"    {emo:10s}: {acc:5.1f}% ({class_correct[i]}/{class_total[i]})")

    return avg_loss, accuracy


# Load the best model
print("\nüì• Loading the best model...")
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on GPU still loads when evaluation runs elsewhere.
# NOTE: weights_only=False unpickles arbitrary objects - only load checkpoint
# files produced by this notebook.
checkpoint = torch.load(config.SAVE_PATH, map_location=config.DEVICE, weights_only=False)

# Create a new model (without compilation) to load weights
eval_model = create_model(dataset='affectnet', num_classes=config.NUM_CLASSES).to(config.DEVICE)
eval_model.load_state_dict(checkpoint['model_state_dict'])

# REDUCED BATCH SIZE to avoid OOM (CUDA Graphs cache takes ~13GB)
EVAL_BATCH_SIZE = 256  # Much smaller to leave space
print(f"   ‚ö° Reduced evaluation batch size: {EVAL_BATCH_SIZE} (instead of {config.BATCH_SIZE})")

eval_loader = DataLoader(
    val_dataset,
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=True
)

# Recreate validation criterion
val_criterion = nn.CrossEntropyLoss()

print(f"\n{'='*60}")
print("üìä STANDARD EVALUATION")
print(f"{'='*60}")
val_loss, val_acc = validate(eval_model, eval_loader, val_criterion, config.DEVICE, per_class=True, use_amp=config.USE_AMP)
print("\nüéØ Standard results:")
print(f"   - Global accuracy: {val_acc:.2f}%")
print(f"   - Loss: {val_loss:.4f}")

# Cleanup before TTA (which uses more memory)
torch.cuda.empty_cache()

print(f"\n{'='*60}")
print("üìä EVALUATION WITH TTA (Test-Time Augmentation)")
print(f"{'='*60}")
tta_loss, tta_acc = validate_with_tta(eval_model, eval_loader, val_criterion, config.DEVICE,
                                       n_augmentations=5, use_amp=config.USE_AMP)
print("\nüéØ Results with TTA:")
print(f"   - Global accuracy: {tta_acc:.2f}%")
print(f"   - Loss: {tta_loss:.4f}")
# Difference in accuracy (percentage points) between TTA and the single-view run.
print(f"   - TTA Improvement: {tta_acc - val_acc:+.2f}%")

## 15. 💾 Final Model Saving

In [None]:
# Export a lightweight deployment checkpoint: the evaluated model's weights
# plus the metadata needed to rebuild it (no optimizer/scaler state or history).
best_checkpoint_acc = checkpoint['val_acc']  # accuracy stored with the best checkpoint

deployment_bundle = {
    'model_state_dict': eval_model.state_dict(),  # weights of the reloaded best model
    'num_classes': config.NUM_CLASSES,
    'in_channels': config.IN_CHANNELS,
    'input_size': config.INPUT_SIZE,
    'dataset': 'affectnet',
    'best_val_acc': best_checkpoint_acc,
}
torch.save(deployment_bundle, 'emotion_model.pth')

# Report the on-disk size in MiB together with the recorded accuracy.
exported_size_mb = os.path.getsize('emotion_model.pth') / 1024 / 1024

print("‚úÖ Model saved in 'emotion_model.pth'")
print(f"   Size: {exported_size_mb:.2f} MB")
print(f"   Best Val Acc: {best_checkpoint_acc:.2f}%")

## 16. 🧪 Test on Some Images

In [None]:
def predict_emotion(model, image_tensor, device):
    """Classify a single image tensor and report the model's confidence.

    Args:
        model: Network to run; switched to eval mode.
        image_tensor: One unbatched image; a batch dimension is added here.
        device: Device the image is moved to before the forward pass.

    Returns:
        Tuple `(pred_idx, confidence, probs)`: the arg-max class index, the
        softmax probability of that class, and the full probability vector as
        a NumPy array.
    """
    model.eval()
    with torch.no_grad():
        # Add a leading batch dimension of size 1 and move to the target device.
        single_batch = image_tensor.unsqueeze(0).to(device)
        # Autocast follows the global config.USE_AMP flag.
        with torch.amp.autocast('cuda', enabled=config.USE_AMP):
            logits = model(single_batch)
        class_probs = F.softmax(logits, dim=1)[0]
        pred_idx = logits.argmax(1).item()
        confidence = class_probs[pred_idx].item()
        probs_array = class_probs.cpu().numpy()
    return pred_idx, confidence, probs_array

# Test on some validation images: show up to 8 random validation samples in a
# 2x4 grid, each titled with its true and predicted emotion.
fig, axes = plt.subplots(2, 4, figsize=(14, 7))
axes = axes.flatten()

# Per-channel statistics used to undo input normalization for display.
# NOTE(review): these are the usual ImageNet values - confirm they match the
# Normalize transform applied when val_dataset was built.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
emotions = ['Anger', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral', 'Contempt']

# Cap the sample size at the dataset size: random.sample raises ValueError
# when asked for more items than the population holds.
num_samples = min(len(axes), len(val_dataset))
indices = random.sample(range(len(val_dataset)), num_samples)

for i, idx in enumerate(indices):
    img, true_label = val_dataset[idx]
    pred_idx, confidence, probs = predict_emotion(eval_model, img, config.DEVICE)  # Uses eval_model

    # CHW tensor -> HWC array, then de-normalize and clamp to the displayable range.
    img_np = img.numpy().transpose(1, 2, 0)
    img_np = img_np * std + mean
    img_np = np.clip(img_np, 0, 1)

    true_emotion = emotions[true_label]
    pred_emotion = emotions[pred_idx]

    # Green title for a correct prediction, red for a miss.
    color = 'green' if pred_idx == true_label else 'red'

    axes[i].imshow(img_np)
    axes[i].set_title(f"True: {true_emotion}\nPred: {pred_emotion} ({confidence*100:.1f}%)",
                      color=color, fontsize=10)
    axes[i].axis('off')

# Hide panels left empty when fewer than 8 samples are available.
for unused_ax in axes[num_samples:]:
    unused_ax.axis('off')

plt.suptitle('üîç Predictions on Validation Set', fontsize=14)
plt.tight_layout()
plt.show()