## 1. üì¶ Imports et V√©rification GPU

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import os
import time
from collections import defaultdict
import random
from PIL import Image
import matplotlib.pyplot as plt

# V√©rification GPU
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA disponible: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

PyTorch version: 2.9.0+cu126
CUDA disponible: True
GPU: Tesla T4
CUDA version: 12.6


In [8]:
# Import albumentations pour l'augmentation avanc√©e
try:
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    HAS_ALBUMENTATIONS = True
    print("‚úÖ Albumentations disponible pour l'augmentation avanc√©e")
except ImportError:
    HAS_ALBUMENTATIONS = False
    print("‚ö†Ô∏è Installez albumentations: pip install albumentations")

‚úÖ Albumentations disponible pour l'augmentation avanc√©e


## 2. ‚öôÔ∏è Configuration des Hyperparam√®tres

Tous les param√®tres d'entra√Ænement sont centralis√©s ici pour faciliter l'exp√©rimentation.

In [9]:
class Config:
    """Configuration MAXIMALE - Meilleur r√©sultat possible (epoch 1 lent, suite rapide)"""

    # === DONN√âES ===
    DATASET_ROOT = './kaggle/input/balanced-affectnet'

    # === MOD√àLE ===
    NUM_CLASSES = 8
    IN_CHANNELS = 3
    INPUT_SIZE = 75
    USE_SE_BLOCKS = True

    # === ENTRA√éNEMENT ===
    BATCH_SIZE = 1536     # ‚ö° Bon compromis GPU (utilise ~10-11GB)
    ACCUMULATION_STEPS = 1
    LEARNING_RATE = 0.0015  # Ajust√© pour batch 1536
    WEIGHT_DECAY = 1e-4
    EPOCHS = 100
    PATIENCE = 20

    # === TECHNIQUES AVANC√âES ===
    USE_MIXUP = True
    MIXUP_ALPHA = 0.2
    USE_CUTMIX = False
    CUTMIX_ALPHA = 1.0
    CUTMIX_PROB = 0.0

    USE_LABEL_SMOOTHING = True
    LABEL_SMOOTHING = 0.1

    USE_FOCAL_LOSS = False
    FOCAL_GAMMA = 2.0

    # === AUGMENTATION ===
    USE_ADVANCED_AUG = True
    USE_CLAHE = False
    USE_GRID_DISTORTION = False

    # === √âQUILIBRAGE DES CLASSES ===
    USE_OVERSAMPLING = False
    MAX_CLASS_WEIGHT = 3.0

    # === OPTIMISATION GPU MAXIMALE ===
    USE_AMP = True                    # ‚úÖ Mixed Precision
    USE_COMPILE = True                # ‚úÖ torch.compile
    COMPILE_MODE = 'max-autotune'     # ‚ö° MAXIMUM: epoch 1 lent (~2-3min) mais suite tr√®s rapide
    NUM_WORKERS = 2                   # Optimal
    PREFETCH_FACTOR = 4
    PERSISTENT_WORKERS = True

    # === SWA ===
    USE_SWA = False
    SWA_START_EPOCH = 75
    SWA_LR = 0.0001

    # === DEVICE ===
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # === SAUVEGARDE ===
    SAVE_PATH = 'emotion_model_best.pth'

config = Config()

# ‚ö° Optimisations CUDA MAXIMALES
if torch.cuda.is_available():
    # Performance
    torch.backends.cudnn.benchmark = True          # Auto-tune kernels
    torch.backends.cuda.matmul.allow_tf32 = True   # TensorFloat-32 (2x plus rapide)
    torch.backends.cudnn.allow_tf32 = True         # TF32 pour cuDNN
    torch.backends.cudnn.deterministic = False     # Non-d√©terministe = plus rapide
    torch.set_float32_matmul_precision('high')     # Tensor Cores optimis√©s

    # ‚ö° NOUVEAU: Optimisations m√©moire pour batches plus grands
    torch.cuda.set_per_process_memory_fraction(0.95)  # Utilise 95% de la VRAM

    gpu_mem_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM totale: {gpu_mem_total:.1f} GB")
    print(f"‚ö° Mode: max-autotune (epoch 1 lent, suite tr√®s rapide)")

print(f"\n{'='*60}")
print("üìã CONFIGURATION MAXIMALE (Qualit√© + Vitesse)")
print(f"{'='*60}")
print(f"Device: {config.DEVICE}")
print(f"‚ö° Batch size: {config.BATCH_SIZE}")
print(f"‚ö° Learning rate: {config.LEARNING_RATE}")
print(f"‚ö° torch.compile: {config.COMPILE_MODE}")
print(f"‚ö° Mixed Precision: {config.USE_AMP}")
print(f"‚ö° TF32: Activ√©")
print(f"‚ö†Ô∏è Epoch 1: ~2-3 min (compilation)")
print(f"‚úÖ Epochs 2+: ~20-25s (tr√®s rapide)")
print(f"{'='*60}")

GPU: Tesla T4
VRAM totale: 14.7 GB
‚ö° Mode: max-autotune (epoch 1 lent, suite tr√®s rapide)

üìã CONFIGURATION MAXIMALE (Qualit√© + Vitesse)
Device: cuda
‚ö° Batch size: 1536
‚ö° Learning rate: 0.0015
‚ö° torch.compile: max-autotune
‚ö° Mixed Precision: True
‚ö° TF32: Activ√©
‚ö†Ô∏è Epoch 1: ~2-3 min (compilation)
‚úÖ Epochs 2+: ~20-25s (tr√®s rapide)


## 3. üìâ Fonctions de Perte (Loss Functions)

### Focal Loss
Utile pour les datasets d√©s√©quilibr√©s - r√©duit l'importance des exemples faciles.

### Label Smoothing Cross Entropy
Emp√™che le mod√®le d'√™tre trop confiant sur les pr√©dictions.

In [10]:
class FocalLoss(nn.Module):
    """Focal Loss pour g√©rer le d√©s√©quilibre de classes."""
    def __init__(self, gamma=2.0, alpha=None, reduction='mean', label_smoothing=0.0):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        if self.label_smoothing > 0:
            n_classes = inputs.size(-1)
            targets_smooth = torch.zeros_like(inputs)
            targets_smooth.fill_(self.label_smoothing / (n_classes - 1))
            targets_smooth.scatter_(1, targets.unsqueeze(1), 1.0 - self.label_smoothing)

            log_probs = F.log_softmax(inputs, dim=-1)
            ce_loss = -(targets_smooth * log_probs).sum(dim=-1)
        else:
            ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        probs = torch.softmax(inputs, dim=-1)
        pt = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        focal_weight = (1 - pt) ** self.gamma

        if self.alpha is not None:
            alpha_t = self.alpha.gather(0, targets)
            focal_weight = focal_weight * alpha_t

        loss = focal_weight * ce_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross Entropy avec label smoothing."""
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, inputs, targets):
        n_classes = inputs.size(-1)
        log_probs = F.log_softmax(inputs, dim=-1)

        targets_smooth = torch.zeros_like(log_probs)
        targets_smooth.fill_(self.smoothing / (n_classes - 1))
        targets_smooth.scatter_(1, targets.unsqueeze(1), 1.0 - self.smoothing)

        loss = -(targets_smooth * log_probs).sum(dim=-1)
        return loss.mean()

print("‚úÖ Fonctions de perte d√©finies")

‚úÖ Fonctions de perte d√©finies


## 4. üîÄ Mixup & CutMix

Techniques d'augmentation qui m√©langent des images pour am√©liorer la g√©n√©ralisation.

In [11]:
def mixup_data(x, y, alpha=0.2):
    """Mixup: m√©lange deux √©chantillons."""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]

    return mixed_x, y_a, y_b, lam


def cutmix_data(x, y, alpha=1.0):
    """CutMix: coupe et colle des patches entre √©chantillons."""
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    _, _, H, W = x.shape
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))

    return x, y, y[index], lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Calcule la loss mix√©e pour mixup/cutmix."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

print("‚úÖ Fonctions Mixup et CutMix d√©finies")

‚úÖ Fonctions Mixup et CutMix d√©finies


## 5. üñºÔ∏è Transformations et Augmentation de Donn√©es

Utilise Albumentations pour des augmentations avanc√©es (rotation, bruit, flou, etc.)

In [12]:
def get_train_transforms():
    """Transformations pour l'entra√Ænement - VERSION √âQUILIBR√âE (ni trop ni trop peu)."""
    if HAS_ALBUMENTATIONS and config.USE_ADVANCED_AUG:
        return A.Compose([
            A.HorizontalFlip(p=0.5),
            A.Affine(
                translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)},
                scale=(0.95, 1.05),   # Mod√©r√©
                rotate=(-10, 10),     # Mod√©r√© (pas 15 qui est trop)
                p=0.4
            ),
            # PAS de CLAHE ni GridDistortion (trop agressif sur 75x75)
            A.OneOf([
                A.GaussNoise(std_range=(0.02, 0.08), p=1),
                A.GaussianBlur(blur_limit=(3, 5), p=1),
            ], p=0.2),
            A.OneOf([
                A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=1),
                A.RandomGamma(gamma_limit=(85, 115), p=1),
                A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=15, p=1),
            ], p=0.4),
            A.CoarseDropout(
                num_holes_range=(1, 2),
                hole_height_range=(4, 8),
                hole_width_range=(4, 8),
                fill=0,
                p=0.2
            ),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])
    else:
        # Fallback vers torchvision
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.95, 1.05)),
            transforms.ColorJitter(brightness=0.15, contrast=0.15, saturation=0.15, hue=0.05),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])


def get_val_transforms():
    """Transformations pour la validation (juste normalisation)."""
    if HAS_ALBUMENTATIONS:
        return A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])
    else:
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

print("‚úÖ Transformations d√©finies (version √©quilibr√©e)")

‚úÖ Transformations d√©finies (version √©quilibr√©e)


## 6. üìÅ Dataset AffectNet

In [13]:
from torch.utils.data import Dataset, WeightedRandomSampler

class BalancedAffectNetDataset(Dataset):
    """
    Dataset pour Balanced AffectNet.

    Structure attendue:
    data/
        train/Anger/, Contempt/, Disgust/, Fear/, Happy/, Neutral/, Sad/, Surprise/
        val/...
        test/...
    """

    NUM_CLASSES = 8

    EMOTION_CLASSES = {
        'Anger': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3,
        'Sad': 4, 'Surprise': 5, 'Neutral': 6, 'Contempt': 7,
    }

    IDX_TO_EMOTION = {v: k for k, v in EMOTION_CLASSES.items()}

    def __init__(self, root_dir='./kaggle/input/balanced-affectnet', split='train', transform=None, use_albumentations=False):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.use_albumentations = use_albumentations

        self.images = []
        self.labels = []

        split_dir = os.path.join(root_dir, split)

        if not os.path.exists(split_dir):
            raise FileNotFoundError(
                f"Dataset non trouv√©: {split_dir}\n"
                f"T√©l√©chargez depuis: https://www.kaggle.com/datasets/dollyprajapati182/balanced-affectnet"
            )

        # Charger toutes les images
        for emotion_name, emotion_idx in self.EMOTION_CLASSES.items():
            emotion_dir = os.path.join(split_dir, emotion_name)
            if not os.path.exists(emotion_dir):
                print(f"‚ö†Ô∏è {emotion_dir} non trouv√©, ignor√©...")
                continue

            for img_name in os.listdir(emotion_dir):
                if img_name.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                    self.images.append(os.path.join(emotion_dir, img_name))
                    self.labels.append(emotion_idx)

        print(f"üìÇ Charg√© {len(self.images)} images depuis AffectNet {split}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]

        image = Image.open(img_path).convert('RGB')
        image = np.array(image)

        if self.transform:
            if self.use_albumentations:
                augmented = self.transform(image=image)
                image = augmented['image']
            else:
                image = self.transform(image)
        else:
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0

        return image, label

    def get_class_distribution(self):
        return np.bincount(self.labels, minlength=self.NUM_CLASSES)

    def get_labels(self):
        return np.array(self.labels)


def get_class_weights(dataset, max_weight=5.0):
    """Calcule les poids pour √©quilibrer les classes."""
    counts = dataset.get_class_distribution()
    counts = np.maximum(counts, 1)

    weights = 1.0 / counts
    weights = weights / weights.sum() * len(weights)
    weights = np.clip(weights, 0.3, max_weight)
    weights = weights / weights.sum() * len(weights)

    print("\nüìä Poids des classes:")
    for i, (count, weight) in enumerate(zip(counts, weights)):
        emotion = BalancedAffectNetDataset.IDX_TO_EMOTION.get(i, f"Class_{i}")
        print(f"    {emotion:10s}: {count:5d} samples, poids: {weight:.3f}")

    return torch.FloatTensor(weights)


def get_balanced_sampler(dataset):
    """Cr√©e un sampler √©quilibr√© pour l'entra√Ænement."""
    labels = dataset.get_labels()
    counts = np.bincount(labels, minlength=BalancedAffectNetDataset.NUM_CLASSES)
    counts = np.maximum(counts, 1)

    weights = 1.0 / counts
    sample_weights = weights[labels]

    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

print("‚úÖ Classes Dataset d√©finies")

‚úÖ Classes Dataset d√©finies


## 6bis. üìÅ Dataset FER2013 / FER2013+

FER2013 est un dataset classique de reconnaissance d'√©motions avec ~35k images en 48x48 grayscale.
FER2013+ est une version avec des labels corrig√©s et am√©lior√©s par Microsoft.

In [14]:
class FER2013Dataset(Dataset):
    """
    Dataset FER2013 avec conversion grayscale -> RGB.

    Structure Kaggle (msambare/fer2013):
        train/angry/, disgust/, fear/, happy/, neutral/, sad/, surprise/
        test/...

    Les images sont en 48x48 grayscale, automatiquement:
    - Redimensionn√©es vers target_size (75x75 par d√©faut)
    - Converties en RGB (pour compatibilit√© avec le mod√®le)
    """

    NUM_CLASSES = 7  # Pas de Contempt dans FER2013

    EMOTION_CLASSES = {
        'angry': 0, 'disgust': 1, 'fear': 2, 'happy': 3,
        'sad': 4, 'surprise': 5, 'neutral': 6
    }

    IDX_TO_EMOTION = {v: k.capitalize() for k, v in EMOTION_CLASSES.items()}

    def __init__(self, root_dir, split='train', transform=None,
                 use_albumentations=False, target_size=75):
        self.root_dir = root_dir
        self.split = 'train' if split == 'train' else 'test'  # FER2013 n'a que train/test
        self.transform = transform
        self.use_albumentations = use_albumentations
        self.target_size = target_size

        self.images = []
        self.labels = []

        split_dir = os.path.join(root_dir, self.split)

        if not os.path.exists(split_dir):
            raise FileNotFoundError(f"FER2013 non trouv√©: {split_dir}")

        for emotion_name, emotion_idx in self.EMOTION_CLASSES.items():
            emotion_dir = os.path.join(split_dir, emotion_name)
            if not os.path.exists(emotion_dir):
                continue

            for img_name in os.listdir(emotion_dir):
                if img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.images.append(os.path.join(emotion_dir, img_name))
                    self.labels.append(emotion_idx)

        print(f"üìÇ FER2013 {self.split}: {len(self.images)} images charg√©es")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]

        # Charger l'image (peut √™tre grayscale)
        image = Image.open(img_path)

        # Convertir en RGB si grayscale
        if image.mode == 'L':
            image = image.convert('RGB')
        else:
            image = image.convert('RGB')

        # Redimensionner vers target_size (48x48 -> 75x75)
        if image.size != (self.target_size, self.target_size):
            image = image.resize((self.target_size, self.target_size), Image.BILINEAR)

        image = np.array(image)

        if self.transform:
            if self.use_albumentations:
                augmented = self.transform(image=image)
                image = augmented['image']
            else:
                image = self.transform(image)
        else:
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0

        return image, label

    def get_class_distribution(self):
        return np.bincount(self.labels, minlength=self.NUM_CLASSES)

    def get_labels(self):
        return np.array(self.labels)


# ===============================================================================
# üì¶ FER+ (FER2013 avec labels corrig√©s par Microsoft)
# ===============================================================================

def download_ferplus_labels(dest_folder):
    """
    T√©l√©charge fer2013new.csv depuis le repo Microsoft FERPlus.

    Returns:
        str: Chemin vers le fichier t√©l√©charg√©
    """
    import urllib.request

    url = "https://raw.githubusercontent.com/microsoft/FERPlus/master/fer2013new.csv"
    dest_path = os.path.join(dest_folder, "fer2013new.csv")

    if os.path.exists(dest_path):
        print(f"  ‚úì fer2013new.csv d√©j√† pr√©sent")
        return dest_path

    print(f"  üì• T√©l√©chargement de fer2013new.csv depuis GitHub...")
    try:
        urllib.request.urlretrieve(url, dest_path)
        print(f"  ‚úì T√©l√©charg√©: {dest_path}")
        return dest_path
    except Exception as e:
        print(f"  ‚úó Erreur: {e}")
        return None


def generate_ferplus_images(fer2013_csv_path, ferplus_csv_path, output_folder):
    """
    G√©n√®re les images PNG depuis fer2013.csv avec les labels FER+.

    Structure de sortie:
        output_folder/
            FER2013Train/
                fer0000000.png
                ...
            FER2013Valid/
            FER2013Test/

    Returns:
        bool: True si succ√®s
    """
    import csv

    # Cr√©er les dossiers
    for split in ['FER2013Train', 'FER2013Valid', 'FER2013Test']:
        os.makedirs(os.path.join(output_folder, split), exist_ok=True)

    # Mapping usage -> dossier
    usage_to_folder = {
        'Training': 'FER2013Train',
        'PublicTest': 'FER2013Valid',
        'PrivateTest': 'FER2013Test'
    }

    # Lire fer2013.csv et g√©n√©rer les images
    print(f"  üñºÔ∏è G√©n√©ration des images depuis fer2013.csv...")

    with open(fer2013_csv_path, 'r') as f:
        reader = csv.reader(f)
        header = next(reader)  # emotion,pixels,Usage

        for idx, row in enumerate(reader):
            if len(row) < 3:
                continue

            emotion = row[0]
            pixels = row[1]
            usage = row[2]

            # Convertir pixels en image
            pixel_values = [int(p) for p in pixels.split()]
            img_array = np.array(pixel_values, dtype=np.uint8).reshape(48, 48)
            img = Image.fromarray(img_array, mode='L')

            # Sauvegarder
            folder = usage_to_folder.get(usage, 'FER2013Train')
            img_name = f"fer{idx:08d}.png"
            img_path = os.path.join(output_folder, folder, img_name)
            img.save(img_path)

            if idx % 5000 == 0:
                print(f"    Progression: {idx} images...")

    print(f"  ‚úì {idx + 1} images g√©n√©r√©es")
    return True


def setup_ferplus_dataset(fer2013_kaggle_path, output_folder=None):
    """
    Configure le dataset FER+ complet:
    1. T√©l√©charge fer2013new.csv depuis GitHub
    2. Trouve/t√©l√©charge fer2013.csv depuis Kaggle
    3. G√©n√®re les images PNG

    Args:
        fer2013_kaggle_path: Chemin vers le dataset FER2013 Kaggle (msambare/fer2013)
        output_folder: Dossier de sortie (optionnel)

    Returns:
        str: Chemin vers le dataset FER+ pr√™t √† l'emploi
    """
    if output_folder is None:
        output_folder = os.path.join(os.path.dirname(fer2013_kaggle_path), 'ferplus_generated')

    os.makedirs(output_folder, exist_ok=True)

    # V√©rifier si d√©j√† g√©n√©r√©
    train_folder = os.path.join(output_folder, 'FER2013Train')
    if os.path.exists(train_folder) and len(os.listdir(train_folder)) > 1000:
        print(f"  ‚úì FER+ d√©j√† g√©n√©r√© dans {output_folder}")
        # T√©l√©charger quand m√™me les labels si pas pr√©sents
        download_ferplus_labels(output_folder)
        return output_folder

    print("\nüîß Configuration de FER+ (premi√®re utilisation)...")

    # 1. T√©l√©charger les labels FER+ depuis GitHub
    ferplus_csv = download_ferplus_labels(output_folder)
    if ferplus_csv is None:
        return None

    # 2. Trouver fer2013.csv
    # Le dataset Kaggle msambare/fer2013 est en format dossiers, pas CSV
    # On doit utiliser le dataset original: deadskull7/fer2013
    fer2013_csv = os.path.join(fer2013_kaggle_path, 'fer2013.csv')

    if not os.path.exists(fer2013_csv):
        # Chercher dans d'autres emplacements possibles
        for alt_path in [
            os.path.join(fer2013_kaggle_path, 'fer2013', 'fer2013.csv'),
            os.path.join(fer2013_kaggle_path, 'data', 'fer2013.csv'),
        ]:
            if os.path.exists(alt_path):
                fer2013_csv = alt_path
                break

    if not os.path.exists(fer2013_csv):
        print(f"  ‚ö†Ô∏è fer2013.csv non trouv√©. FER+ n√©cessite le dataset CSV original.")
        print(f"     Le dataset Kaggle 'msambare/fer2013' est en format image.")
        print(f"     Pour FER+, utilisez 'deadskull7/fer2013' qui contient le CSV.")
        return None

    # 3. G√©n√©rer les images
    success = generate_ferplus_images(fer2013_csv, ferplus_csv, output_folder)

    if success:
        return output_folder
    return None


class FERPlusDataset(Dataset):
    """
    Dataset FER2013+ (FER+) avec labels corrig√©s par Microsoft.

    FER+ am√©liore FER2013 avec:
    - Labels vot√©s par 10 annotateurs (plus fiables)
    - 8 classes (ajout de Contempt)
    - Possibilit√© d'utiliser les probabilit√©s de vote

    Le dataset est automatiquement configur√© depuis:
    - fer2013new.csv (labels) depuis GitHub Microsoft
    - fer2013.csv (images) depuis Kaggle
    """

    NUM_CLASSES = 8

    # Colonnes du CSV FER+: usage, neutral, happiness, surprise, sadness, anger, disgust, fear, contempt, unknown, NF
    FERPLUS_EMOTIONS = ['neutral', 'happiness', 'surprise', 'sadness', 'anger', 'disgust', 'fear', 'contempt']

    # Mapping FER+ order -> Unified order (AffectNet)
    # FER+: neutral(0), happiness(1), surprise(2), sadness(3), anger(4), disgust(5), fear(6), contempt(7)
    # Unified: Anger(0), Disgust(1), Fear(2), Happy(3), Sad(4), Surprise(5), Neutral(6), Contempt(7)
    FERPLUS_TO_UNIFIED = {
        0: 6,  # neutral -> Neutral
        1: 3,  # happiness -> Happy
        2: 5,  # surprise -> Surprise
        3: 4,  # sadness -> Sad
        4: 0,  # anger -> Anger
        5: 1,  # disgust -> Disgust
        6: 2,  # fear -> Fear
        7: 7,  # contempt -> Contempt
    }

    IDX_TO_EMOTION = {
        0: 'Anger', 1: 'Disgust', 2: 'Fear', 3: 'Happy',
        4: 'Sad', 5: 'Surprise', 6: 'Neutral', 7: 'Contempt'
    }

    def __init__(self, root_dir, split='train', transform=None,
                 use_albumentations=False, target_size=75,
                 label_mode='majority', min_votes=1):
        """
        Args:
            root_dir: Chemin vers le dataset FER+ (avec FER2013Train/, etc.)
            split: 'train', 'val' ou 'test'
            label_mode: 'majority' (label le plus vot√©) ou 'probability' (distribution)
            min_votes: Nombre minimum de votes pour inclure une image
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.use_albumentations = use_albumentations
        self.target_size = target_size
        self.label_mode = label_mode
        self.min_votes = min_votes

        self.images = []
        self.labels = []
        self.vote_distributions = []  # Pour le mode probability

        # Mapping split -> folder
        split_to_folder = {
            'train': 'FER2013Train',
            'val': 'FER2013Valid',
            'test': 'FER2013Test'
        }

        folder_name = split_to_folder.get(split, 'FER2013Train')
        split_dir = os.path.join(root_dir, folder_name)

        if not os.path.exists(split_dir):
            print(f"‚ö†Ô∏è FER+ {split} non trouv√©: {split_dir}")
            return

        # Charger les labels depuis fer2013new.csv
        ferplus_csv = os.path.join(root_dir, 'fer2013new.csv')
        if not os.path.exists(ferplus_csv):
            print(f"‚ö†Ô∏è fer2013new.csv non trouv√© dans {root_dir}")
            return

        self._load_data(split_dir, ferplus_csv, split)

        print(f"üìÇ FER+ {split}: {len(self.images)} images charg√©es (mode: {label_mode})")

    def _load_data(self, split_dir, ferplus_csv, split):
        """Charge les images et labels."""
        import csv

        # Mapping usage dans le CSV
        usage_mapping = {
            'train': 'Training',
            'val': 'PublicTest',
            'test': 'PrivateTest'
        }
        target_usage = usage_mapping.get(split, 'Training')

        with open(ferplus_csv, 'r') as f:
            reader = csv.reader(f)
            header = next(reader)  # Skip header

            for idx, row in enumerate(reader):
                if len(row) < 10:
                    continue

                usage = row[0]

                # Filtrer par split
                if usage != target_usage:
                    continue

                # Votes pour chaque √©motion (colonnes 1-8)
                # Format: usage, neutral, happiness, surprise, sadness, anger, disgust, fear, contempt, unknown, NF
                try:
                    votes = [int(v) if v.strip().isdigit() else 0 for v in row[1:9]]
                except:
                    continue

                total_votes = sum(votes)

                # Ignorer si pas assez de votes valides ou si c'est "unknown" / "NF"
                if total_votes < self.min_votes:
                    continue

                # Chemin de l'image
                img_name = f"fer{idx:08d}.png"
                img_path = os.path.join(split_dir, img_name)

                if not os.path.exists(img_path):
                    continue

                # Calculer le label
                ferplus_label = np.argmax(votes)
                unified_label = self.FERPLUS_TO_UNIFIED[ferplus_label]

                self.images.append(img_path)
                self.labels.append(unified_label)

                # Stocker la distribution pour le mode probability
                if self.label_mode == 'probability':
                    vote_dist = np.array(votes, dtype=np.float32)
                    vote_dist = vote_dist / vote_dist.sum()  # Normaliser
                    # R√©ordonner selon l'ordre unifi√©
                    unified_dist = np.zeros(8, dtype=np.float32)
                    for ferplus_idx, unified_idx in self.FERPLUS_TO_UNIFIED.items():
                        unified_dist[unified_idx] = vote_dist[ferplus_idx]
                    self.vote_distributions.append(unified_dist)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]

        # Charger l'image
        image = Image.open(img_path)

        # Convertir en RGB
        if image.mode == 'L':
            image = image.convert('RGB')
        else:
            image = image.convert('RGB')

        # Redimensionner
        if image.size != (self.target_size, self.target_size):
            image = image.resize((self.target_size, self.target_size), Image.BILINEAR)

        image = np.array(image)

        if self.transform:
            if self.use_albumentations:
                augmented = self.transform(image=image)
                image = augmented['image']
            else:
                image = self.transform(image)
        else:
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0

        # Retourner selon le mode
        if self.label_mode == 'probability' and len(self.vote_distributions) > idx:
            return image, label, torch.tensor(self.vote_distributions[idx])

        return image, label

    def get_class_distribution(self):
        return np.bincount(self.labels, minlength=self.NUM_CLASSES)

    def get_labels(self):
        return np.array(self.labels)


print("‚úÖ Datasets FER2013 et FER+ d√©finis")

‚úÖ Datasets FER2013 et FER+ d√©finis


## 6ter. üîÄ Dataset Combin√© Multi-Sources

Ce dataset combine AffectNet, FER2013 et/ou FER+ en unifiant les classes vers 8 √©motions.

In [15]:
class CombinedEmotionDataset(Dataset):
    """
    Dataset combinant plusieurs sources avec mapping unifi√© des classes.

    Combine AffectNet, FER2013 et FER+ avec:
    - Redimensionnement automatique vers target_size
    - Conversion grayscale -> RGB automatique
    - Mapping unifi√© vers 8 classes (ordre AffectNet)

    Classes unifi√©es:
        0: Anger, 1: Disgust, 2: Fear, 3: Happy,
        4: Sad, 5: Surprise, 6: Neutral, 7: Contempt
    """

    # 8 classes unifi√©es (ordre AffectNet)
    UNIFIED_CLASSES = ['Anger', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral', 'Contempt']
    NUM_CLASSES = 8
    IDX_TO_EMOTION = {i: c for i, c in enumerate(UNIFIED_CLASSES)}

    def __init__(self, datasets_config, split='train', transform=None,
                 use_albumentations=False, target_size=75):
        """
        Args:
            datasets_config: dict {nom_dataset: chemin_racine}
                Exemple: {'affectnet': '/path/to/affectnet', 'fer2013': '/path/to/fer2013'}
            split: 'train', 'val', ou 'test'
            target_size: taille de sortie uniforme (75 par d√©faut)
        """
        self.transform = transform
        self.use_albumentations = use_albumentations
        self.target_size = target_size

        self.images = []  # Liste de dicts: {'path': str, 'is_grayscale': bool}
        self.labels = []
        self.sources = []  # Pour tracking/debug

        total_by_source = {}

        for dataset_name, root_dir in datasets_config.items():
            if root_dir is None or not os.path.exists(root_dir):
                print(f"‚ö†Ô∏è {dataset_name} ignor√© (non trouv√©): {root_dir}")
                continue

            count_before = len(self.images)

            if dataset_name == 'affectnet':
                self._load_affectnet(root_dir, split)
            elif dataset_name == 'fer2013':
                self._load_fer2013(root_dir, split)
            elif dataset_name == 'ferplus':
                self._load_ferplus(root_dir, split)
            else:
                print(f"‚ö†Ô∏è Dataset inconnu: {dataset_name}")
                continue

            total_by_source[dataset_name] = len(self.images) - count_before

        print(f"\n{'='*50}")
        print(f"üìä DATASET COMBIN√â ({split})")
        print(f"{'='*50}")
        for src, count in total_by_source.items():
            print(f"  {src:15s}: {count:6d} images")
        print(f"  {'TOTAL':15s}: {len(self.images):6d} images")
        print(f"{'='*50}")
        self._print_class_distribution()

    def _load_affectnet(self, root_dir, split):
        """Charge les images AffectNet."""
        affectnet_mapping = {
            'Anger': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3,
            'Sad': 4, 'Surprise': 5, 'Neutral': 6, 'Contempt': 7
        }

        split_dir = os.path.join(root_dir, split)
        if not os.path.exists(split_dir):
            return

        for emotion_name, unified_idx in affectnet_mapping.items():
            emotion_dir = os.path.join(split_dir, emotion_name)
            if not os.path.exists(emotion_dir):
                continue

            for img_name in os.listdir(emotion_dir):
                if img_name.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                    self.images.append({
                        'path': os.path.join(emotion_dir, img_name),
                        'is_grayscale': False
                    })
                    self.labels.append(unified_idx)
                    self.sources.append('affectnet')

    def _load_fer2013(self, root_dir, split):
        """Charge les images FER2013."""
        # FER2013 n'a que train/test, pas de val
        fer_split = 'train' if split == 'train' else 'test'
        split_dir = os.path.join(root_dir, fer_split)

        if not os.path.exists(split_dir):
            return

        # Mapping FER2013 (7 classes) -> unifi√© (8 classes)
        fer_to_unified = {
            'angry': 0, 'disgust': 1, 'fear': 2, 'happy': 3,
            'sad': 4, 'surprise': 5, 'neutral': 6
        }

        for emotion_name, unified_idx in fer_to_unified.items():
            emotion_dir = os.path.join(split_dir, emotion_name)
            if not os.path.exists(emotion_dir):
                continue

            for img_name in os.listdir(emotion_dir):
                if img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.images.append({
                        'path': os.path.join(emotion_dir, img_name),
                        'is_grayscale': True  # FER2013 est en grayscale
                    })
                    self.labels.append(unified_idx)
                    self.sources.append('fer2013')

    def _load_ferplus(self, root_dir, split):
        """Charge les images FER+ (avec labels corrig√©s par Microsoft)."""
        import csv

        # Mapping FER+ order -> Unified order
        ferplus_to_unified = {
            0: 6,  # neutral -> Neutral
            1: 3,  # happiness -> Happy
            2: 5,  # surprise -> Surprise
            3: 4,  # sadness -> Sad
            4: 0,  # anger -> Anger
            5: 1,  # disgust -> Disgust
            6: 2,  # fear -> Fear
            7: 7,  # contempt -> Contempt
        }

        # Mapping split -> folder et usage
        split_mapping = {
            'train': ('FER2013Train', 'Training'),
            'val': ('FER2013Valid', 'PublicTest'),
            'test': ('FER2013Test', 'PrivateTest')
        }

        folder_name, target_usage = split_mapping.get(split, ('FER2013Train', 'Training'))
        split_dir = os.path.join(root_dir, folder_name)

        if not os.path.exists(split_dir):
            print(f"    ‚ö†Ô∏è FER+ {split} non trouv√©: {split_dir}")
            return

        # Trouver fer2013new.csv
        ferplus_csv = os.path.join(root_dir, 'fer2013new.csv')
        if not os.path.exists(ferplus_csv):
            print(f"    ‚ö†Ô∏è fer2013new.csv non trouv√© dans {root_dir}")
            return

        # Charger les donn√©es
        with open(ferplus_csv, 'r') as f:
            reader = csv.reader(f)
            header = next(reader)  # Skip header

            for idx, row in enumerate(reader):
                if len(row) < 10:
                    continue

                usage = row[0]

                # Filtrer par split
                if usage != target_usage:
                    continue

                # Votes pour chaque √©motion (colonnes 1-8)
                try:
                    votes = [int(v.strip()) if v.strip().isdigit() else 0 for v in row[1:9]]
                except:
                    continue

                if sum(votes) == 0:
                    continue

                # Label = √©motion avec le plus de votes
                ferplus_label = np.argmax(votes)
                unified_label = ferplus_to_unified[ferplus_label]

                # Chemin de l'image
                img_name = f"fer{idx:08d}.png"
                img_path = os.path.join(split_dir, img_name)

                if os.path.exists(img_path):
                    self.images.append({
                        'path': img_path,
                        'is_grayscale': True
                    })
                    self.labels.append(unified_label)
                    self.sources.append('ferplus')

    def _print_class_distribution(self):
        """Affiche la distribution des classes."""
        if len(self.labels) == 0:
            return
        counts = self.get_class_distribution()
        print("\n  Distribution par classe:")
        max_count = max(counts) if len(counts) > 0 else 1
        for i, (cls, count) in enumerate(zip(self.UNIFIED_CLASSES, counts)):
            bar_len = int(30 * count / max_count) if max_count > 0 else 0
            bar = '‚ñà' * bar_len
            print(f"    {cls:10s}: {count:6d} {bar}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_info = self.images[idx]
        label = self.labels[idx]

        # Charger l'image
        image = Image.open(img_info['path'])

        # Convertir en RGB si n√©cessaire
        if img_info['is_grayscale'] or image.mode == 'L':
            image = image.convert('RGB')
        else:
            image = image.convert('RGB')

        # Redimensionner vers target_size
        if image.size != (self.target_size, self.target_size):
            image = image.resize((self.target_size, self.target_size), Image.BILINEAR)

        image = np.array(image)

        # Appliquer les transformations
        if self.transform:
            if self.use_albumentations:
                augmented = self.transform(image=image)
                image = augmented['image']
            else:
                image = self.transform(image)
        else:
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0

        return image, label

    def get_class_distribution(self):
        if len(self.labels) == 0:
            return np.zeros(self.NUM_CLASSES, dtype=int)
        return np.bincount(self.labels, minlength=self.NUM_CLASSES)

    def get_labels(self):
        return np.array(self.labels)

    def get_source_distribution(self):
        """Retourne le nombre d'images par source."""
        from collections import Counter
        return Counter(self.sources)


print("‚úÖ CombinedEmotionDataset d√©fini")

‚úÖ CombinedEmotionDataset d√©fini


## 7. üß† Architecture du Mod√®le CNN (avec SE Blocks)

In [16]:
# ‚úÖ NOUVEAU: Squeeze-and-Excitation Block pour am√©liorer l'attention sur les features
class SEBlock(nn.Module):
    """Squeeze-and-Excitation Block - am√©liore la qualit√© des features."""
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class ConvBlock(nn.Module):
    """Bloc convolutif avec BatchNorm, ReLU et SE Block optionnel."""
    def __init__(self, in_channels, out_channels, use_se=True, reduction=16):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.se = SEBlock(out_channels, reduction) if use_se else nn.Identity()

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.se(x)  # ‚úÖ Attention par SE Block
        x = self.pool(x)
        return x


class FaceEmotionCNN(nn.Module):
    """CNN am√©lior√© avec SE Blocks pour la reconnaissance d'√©motions."""
    def __init__(self, num_classes=8, in_channels=3, input_size=75):
        super(FaceEmotionCNN, self).__init__()

        # ‚úÖ Blocs avec SE attention
        self.block1 = ConvBlock(in_channels, 32, use_se=True, reduction=8)   # 75 -> 37
        self.block2 = ConvBlock(32, 64, use_se=True, reduction=8)            # 37 -> 18
        self.block3 = ConvBlock(64, 128, use_se=True, reduction=16)          # 18 -> 9
        self.block4 = ConvBlock(128, 256, use_se=True, reduction=16)         # 9 -> 4

        # Classifier avec Global Average Pooling
        self.global_pool = nn.AdaptiveAvgPool2d(1)  # ‚úÖ Rend le mod√®le ind√©pendant de la taille
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),  # ‚úÖ R√©duit pour √©viter l'underfitting
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


def create_model(dataset='affectnet', num_classes=8):
    if dataset == 'affectnet':
        return FaceEmotionCNN(num_classes=num_classes, in_channels=config.IN_CHANNELS, input_size=config.INPUT_SIZE)
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

# Cr√©er et afficher le mod√®le
model = create_model(dataset='affectnet', num_classes=config.NUM_CLASSES)
total_params = sum(p.numel() for p in model.parameters())
print(f"üß† Mod√®le cr√©√© avec SE Blocks: {total_params:,} param√®tres")

üß† Mod√®le cr√©√© avec SE Blocks: 1,321,384 param√®tres


## 8. üîß Utilitaires d'Entra√Ænement

In [17]:
class AverageMeter:
    """Suit les valeurs moyennes."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def validate(model, val_loader, criterion, device, per_class=False, use_amp=False):
    """Validation avec m√©triques optionnelles par classe et support AMP."""
    model.eval()

    loss_meter = AverageMeter()
    correct = 0
    total = 0

    if per_class:
        class_correct = defaultdict(int)
        class_total = defaultdict(int)

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            # ‚ö° Mixed Precision pour la validation aussi
            with torch.amp.autocast('cuda', enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            loss_meter.update(loss.item(), inputs.size(0))

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            if per_class:
                for pred, label in zip(predicted, labels):
                    class_total[label.item()] += 1
                    if pred == label:
                        class_correct[label.item()] += 1

    accuracy = 100.0 * correct / total

    if per_class:
        emotions = ['Anger', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral', 'Contempt']
        print("\n  üìä Pr√©cision par classe:")
        for i, emo in enumerate(emotions):
            if class_total[i] > 0:
                acc = 100.0 * class_correct[i] / class_total[i]
                print(f"    {emo:10s}: {acc:5.1f}% ({class_correct[i]}/{class_total[i]})")

    return loss_meter.avg, accuracy

print("‚úÖ Utilitaires d√©finis (avec support AMP)")

‚úÖ Utilitaires d√©finis (avec support AMP)


## 9. üìÇ Chargement des Donn√©es

### Configuration des Datasets √† Utiliser

Choisissez les datasets que vous voulez utiliser pour l'entra√Ænement. Le notebook t√©l√©chargera automatiquement les datasets s√©lectionn√©s via `kagglehub`.

**Datasets disponibles:**
- `affectnet`: Dataset principal (8 classes, ~50k images, RGB)
- `fer2013`: Dataset classique (7 classes, ~35k images, grayscale 48x48)
- `ferplus`: FER2013 avec labels corrig√©s par Microsoft (8 classes)

In [18]:
# Installation de kagglehub
!pip install -q kagglehub
print("‚úÖ kagglehub install√©.")

# ===============================================================================
# üéØ CONFIGURATION DES DATASETS √Ä UTILISER
# ===============================================================================
#
# ‚ö†Ô∏è IMPORTANT: FER2013 et FER+ utilisent les M√äMES images mais avec des labels diff√©rents!
#    - fer2013: Labels originaux (7 classes, plus bruit√©s)
#    - ferplus: Labels corrig√©s par Microsoft (8 classes, 10 annotateurs)
#
#    ‚Üí Ne pas utiliser les deux en m√™me temps (duplicatas)!
#    ‚Üí Pr√©f√©rer FER+ pour une meilleure qualit√©
#

DATASETS_TO_USE = [
    'affectnet',    # ‚úÖ Dataset principal (8 classes, ~50k images)
    # 'fer2013',    # ‚ùå Remplac√© par FER+ (m√™mes images, labels moins bons)
    'ferplus',      # ‚úÖ FER+ avec labels corrig√©s (8 classes, ~35k images)
]

# Mode: 'combined' pour fusionner tous les datasets, 'single' pour utiliser le premier uniquement
DATASET_MODE = 'combined' if len(DATASETS_TO_USE) > 1 else 'single'

print(f"\n{'='*60}")
print(f"üìã CONFIGURATION DES DATASETS")
print(f"{'='*60}")
print(f"  Datasets s√©lectionn√©s: {DATASETS_TO_USE}")
print(f"  Mode: {DATASET_MODE}")
if 'ferplus' in DATASETS_TO_USE:
    print(f"  ‚ÑπÔ∏è FER+ = FER2013 avec labels corrig√©s par Microsoft (meilleure qualit√©)")
if 'fer2013' in DATASETS_TO_USE and 'ferplus' in DATASETS_TO_USE:
    print(f"  ‚ö†Ô∏è ATTENTION: fer2013 et ferplus utilisent les m√™mes images!")
print(f"{'='*60}")

‚úÖ kagglehub install√©.

üìã CONFIGURATION DES DATASETS
  Datasets s√©lectionn√©s: ['affectnet', 'ferplus']
  Mode: combined
  ‚ÑπÔ∏è FER+ = FER2013 avec labels corrig√©s par Microsoft (meilleure qualit√©)


In [None]:
import kagglehub

# ===============================================================================
# üì• T√âL√âCHARGEMENT DES DATASETS
# ===============================================================================

# IDs Kaggle pour chaque dataset
KAGGLE_IDS = {
    'affectnet': 'dollyprajapati182/balanced-affectnet',
    'fer2013': 'msambare/fer2013',              # Version en dossiers (images directement)
    'fer2013_csv': 'deadskull7/fer2013',        # Version CSV originale (pour FER+)
}

dataset_paths = {}

print("üì• T√©l√©chargement des datasets...\n")

# ===============================================================================
# 1. T√©l√©charger AffectNet
# ===============================================================================
if 'affectnet' in DATASETS_TO_USE:
    print(f"üì¶ [1/3] AffectNet...")
    try:
        path = kagglehub.dataset_download(KAGGLE_IDS['affectnet'])
        dataset_paths['affectnet'] = str(path)
        print(f"  ‚úì T√©l√©charg√©: {path}")
    except Exception as e:
        print(f"  ‚úó Erreur: {e}")
        dataset_paths['affectnet'] = None

# ===============================================================================
# 2. T√©l√©charger FER2013 (version dossiers)
# ===============================================================================
if 'fer2013' in DATASETS_TO_USE:
    print(f"\nüì¶ [2/3] FER2013...")
    try:
        path = kagglehub.dataset_download(KAGGLE_IDS['fer2013'])
        dataset_paths['fer2013'] = str(path)
        print(f"  ‚úì T√©l√©charg√©: {path}")
    except Exception as e:
        print(f"  ‚úó Erreur: {e}")
        dataset_paths['fer2013'] = None

# ===============================================================================
# 3. Configurer FER+ (t√©l√©charge CSV + g√©n√®re images)
# ===============================================================================
if 'ferplus' in DATASETS_TO_USE:
    print(f"\nüì¶ [3/3] FER+ (FER2013 avec labels Microsoft)...")

    # FER+ n√©cessite le CSV original de FER2013
    print(f"  üì• T√©l√©chargement de fer2013.csv...")
    try:
        fer2013_csv_path = kagglehub.dataset_download(KAGGLE_IDS['fer2013_csv'])
        print(f"  ‚úì fer2013.csv t√©l√©charg√©: {fer2013_csv_path}")

        # Configurer FER+ (t√©l√©charge labels + g√©n√®re images)
        ferplus_path = setup_ferplus_dataset(fer2013_csv_path)

        if ferplus_path:
            dataset_paths['ferplus'] = ferplus_path
            print(f"  ‚úì FER+ configur√©: {ferplus_path}")
        else:
            print(f"  ‚ö†Ô∏è FER+ non configur√© (voir erreurs ci-dessus)")
            dataset_paths['ferplus'] = None

    except Exception as e:
        print(f"  ‚úó Erreur: {e}")
        dataset_paths['ferplus'] = None

# ===============================================================================
# Mise √† jour de la config
# ===============================================================================
valid_paths = {k: v for k, v in dataset_paths.items() if v is not None}

if 'affectnet' in valid_paths:
    config.DATASET_ROOT = valid_paths['affectnet']
elif valid_paths:
    config.DATASET_ROOT = list(valid_paths.values())[0]

print(f"\n{'='*60}")
print(f"‚úÖ DATASETS PR√äTS")
print(f"{'='*60}")
for name, path in dataset_paths.items():
    status = "‚úì" if path else "‚úó"
    print(f"  {status} {name}: {path if path else 'Non disponible'}")
print(f"{'='*60}")

üì• T√©l√©chargement des datasets...

üì¶ [1/3] AffectNet...
Downloading from https://www.kaggle.com/api/v1/datasets/download/dollyprajapati182/balanced-affectnet?dataset_version_number=1...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 273M/273M [00:07<00:00, 39.2MB/s]

Extracting files...





  ‚úì T√©l√©charg√©: /root/.cache/kagglehub/datasets/dollyprajapati182/balanced-affectnet/versions/1

üì¶ [3/3] FER+ (FER2013 avec labels Microsoft)...
  üì• T√©l√©chargement de fer2013.csv...
Downloading from https://www.kaggle.com/api/v1/datasets/download/deadskull7/fer2013?dataset_version_number=1...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 96.6M/96.6M [00:03<00:00, 33.8MB/s]

Extracting files...





  ‚úì fer2013.csv t√©l√©charg√©: /root/.cache/kagglehub/datasets/deadskull7/fer2013/versions/1

üîß Configuration de FER+ (premi√®re utilisation)...
  üì• T√©l√©chargement de fer2013new.csv depuis GitHub...
  ‚úì T√©l√©charg√©: /root/.cache/kagglehub/datasets/deadskull7/fer2013/versions/ferplus_generated/fer2013new.csv
  üñºÔ∏è G√©n√©ration des images depuis fer2013.csv...
    Progression: 0 images...


  img = Image.fromarray(img_array, mode='L')


    Progression: 5000 images...


In [None]:
# ===============================================================================
# üìÇ CHARGEMENT DES DONN√âES (MULTI-DATASET OU SINGLE)
# ===============================================================================

print("üìÇ Chargement des datasets...")

train_transform = get_train_transforms()
val_transform = get_val_transforms()

# Filtrer les paths valides
valid_dataset_paths = {k: v for k, v in dataset_paths.items() if v is not None}

# ===============================================================================
# MODE MULTI-DATASET (combin√©) ou SINGLE-DATASET
# ===============================================================================

if DATASET_MODE == 'combined' and len(valid_dataset_paths) > 1:
    print(f"\nüîÄ Mode MULTI-DATASET activ√©!")
    print(f"   Datasets: {list(valid_dataset_paths.keys())}")

    # Utiliser le dataset combin√©
    train_dataset = CombinedEmotionDataset(
        datasets_config=valid_dataset_paths,
        split='train',
        transform=train_transform,
        use_albumentations=HAS_ALBUMENTATIONS,
        target_size=config.INPUT_SIZE
    )

    val_dataset = CombinedEmotionDataset(
        datasets_config=valid_dataset_paths,
        split='val',
        transform=val_transform,
        use_albumentations=HAS_ALBUMENTATIONS,
        target_size=config.INPUT_SIZE
    )

    # Mettre √† jour la config avec les classes unifi√©es (8 classes)
    config.NUM_CLASSES = CombinedEmotionDataset.NUM_CLASSES

else:
    print(f"\nüìÅ Mode SINGLE-DATASET")

    # Utiliser le premier dataset disponible
    dataset_name = list(valid_dataset_paths.keys())[0]
    root_path = valid_dataset_paths[dataset_name]
    print(f"   Dataset: {dataset_name}")

    if dataset_name == 'affectnet':
        train_dataset = BalancedAffectNetDataset(
            root_dir=root_path,
            split='train',
            transform=train_transform,
            use_albumentations=HAS_ALBUMENTATIONS
        )
        val_dataset = BalancedAffectNetDataset(
            root_dir=root_path,
            split='val',
            transform=val_transform,
            use_albumentations=HAS_ALBUMENTATIONS
        )
        config.NUM_CLASSES = 8

    elif dataset_name == 'fer2013':
        train_dataset = FER2013Dataset(
            root_dir=root_path,
            split='train',
            transform=train_transform,
            use_albumentations=HAS_ALBUMENTATIONS,
            target_size=config.INPUT_SIZE
        )
        val_dataset = FER2013Dataset(
            root_dir=root_path,
            split='val',
            transform=val_transform,
            use_albumentations=HAS_ALBUMENTATIONS,
            target_size=config.INPUT_SIZE
        )
        config.NUM_CLASSES = 7  # FER2013 n'a pas Contempt

    elif dataset_name == 'ferplus':
        train_dataset = FERPlusDataset(
            root_dir=root_path,
            split='train',
            transform=train_transform,
            use_albumentations=HAS_ALBUMENTATIONS,
            target_size=config.INPUT_SIZE
        )
        val_dataset = FERPlusDataset(
            root_dir=root_path,
            split='val',
            transform=val_transform,
            use_albumentations=HAS_ALBUMENTATIONS,
            target_size=config.INPUT_SIZE
        )
        config.NUM_CLASSES = 8

# ===============================================================================
# CALCUL DES POIDS DE CLASSES (adaptatif)
# ===============================================================================

def get_class_weights_adaptive(dataset, max_weight=5.0):
    """Calcule les poids pour √©quilibrer les classes (compatible tous datasets)."""
    counts = dataset.get_class_distribution()
    num_classes = len(counts)
    counts = np.maximum(counts, 1)

    weights = 1.0 / counts
    weights = weights / weights.sum() * num_classes
    weights = np.clip(weights, 0.3, max_weight)
    weights = weights / weights.sum() * num_classes

    # R√©cup√©rer les noms d'√©motions selon le type de dataset
    if hasattr(dataset, 'IDX_TO_EMOTION'):
        idx_to_emotion = dataset.IDX_TO_EMOTION
    elif hasattr(dataset, 'UNIFIED_CLASSES'):
        idx_to_emotion = {i: c for i, c in enumerate(dataset.UNIFIED_CLASSES)}
    else:
        idx_to_emotion = {i: f"Class_{i}" for i in range(num_classes)}

    print(f"\nüìä Poids des classes ({num_classes} classes):")
    for i, (count, weight) in enumerate(zip(counts, weights)):
        emotion = idx_to_emotion.get(i, f"Class_{i}")
        print(f"    {emotion:10s}: {count:6d} samples, poids: {weight:.3f}")

    return torch.FloatTensor(weights)

class_weights = get_class_weights_adaptive(train_dataset, max_weight=config.MAX_CLASS_WEIGHT).to(config.DEVICE)

# ===============================================================================
# DATALOADERS OPTIMIS√âS
# ===============================================================================

train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=config.NUM_WORKERS,
    pin_memory=True,
    drop_last=True,
    prefetch_factor=config.PREFETCH_FACTOR if config.NUM_WORKERS > 0 else None,
    persistent_workers=config.PERSISTENT_WORKERS if config.NUM_WORKERS > 0 else False
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.BATCH_SIZE * 2,
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=True,
    prefetch_factor=config.PREFETCH_FACTOR if config.NUM_WORKERS > 0 else None,
    persistent_workers=config.PERSISTENT_WORKERS if config.NUM_WORKERS > 0 else False
)

print(f"\n{'='*60}")
print("‚úÖ DONN√âES CHARG√âES")
print(f"{'='*60}")
print(f"  Train: {len(train_dataset):,} samples ({len(train_loader)} batches)")
print(f"  Val:   {len(val_dataset):,} samples ({len(val_loader)} batches)")
print(f"  Classes: {config.NUM_CLASSES}")
print(f"  Input size: {config.INPUT_SIZE}x{config.INPUT_SIZE}")
print(f"  Batch size: {config.BATCH_SIZE}")
print(f"  ‚ö° Workers: {config.NUM_WORKERS}, Prefetch: {config.PREFETCH_FACTOR}")
print(f"{'='*60}")

## 10. üëÄ Visualisation d'√âchantillons

In [None]:
# Visualiser quelques images du dataset (compatible multi-dataset)
def show_samples(dataset, n_samples=8):
    """Affiche des √©chantillons du dataset (compatible avec tous les datasets)."""
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    axes = axes.flatten()

    # D√©normalisation ImageNet
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # R√©cup√©rer le mapping idx -> emotion selon le type de dataset
    if hasattr(dataset, 'IDX_TO_EMOTION'):
        idx_to_emotion = dataset.IDX_TO_EMOTION
    elif hasattr(dataset, 'UNIFIED_CLASSES'):
        idx_to_emotion = {i: c for i, c in enumerate(dataset.UNIFIED_CLASSES)}
    else:
        idx_to_emotion = {i: f"Class_{i}" for i in range(config.NUM_CLASSES)}

    indices = random.sample(range(len(dataset)), min(n_samples, len(dataset)))

    for i, idx in enumerate(indices):
        img, label = dataset[idx]

        # Convertir tensor en numpy et d√©normaliser
        if isinstance(img, torch.Tensor):
            img_np = img.numpy().transpose(1, 2, 0)
        else:
            img_np = img.transpose(1, 2, 0) if img.shape[0] == 3 else img

        img_np = img_np * std + mean
        img_np = np.clip(img_np, 0, 1)

        emotion = idx_to_emotion.get(label, f"Class_{label}")

        # Afficher la source si disponible (multi-dataset)
        if hasattr(dataset, 'sources') and idx < len(dataset.sources):
            source = dataset.sources[idx]
            title = f"{emotion}\n({source})"
        else:
            title = emotion

        axes[i].imshow(img_np)
        axes[i].set_title(title, fontsize=10)
        axes[i].axis('off')

    # Titre selon le mode
    if DATASET_MODE == 'combined' and len(valid_dataset_paths) > 1:
        dataset_type = f"Multi-Dataset ({', '.join(valid_dataset_paths.keys())})"
    else:
        dataset_type = list(valid_dataset_paths.keys())[0] if valid_dataset_paths else "Unknown"

    plt.suptitle(f'√âchantillons - {dataset_type}', fontsize=14)
    plt.tight_layout()
    plt.show()

show_samples(train_dataset)

## 11. üöÄ Configuration de l'Entra√Ænement (avec SWA)

In [None]:
# ===============================================================================
# üöÄ CONFIGURATION DE L'ENTRA√éNEMENT
# ===============================================================================

# Mod√®le
model = create_model(dataset='affectnet', num_classes=config.NUM_CLASSES).to(config.DEVICE)

# Compilation du mod√®le (PyTorch 2.0+)
if config.USE_COMPILE and hasattr(torch, 'compile'):
    try:
        # ‚ö° Mode max-autotune: plus lent au d√©but mais plus rapide apr√®s
        model = torch.compile(model, mode=config.COMPILE_MODE)
        print(f"‚ö° Mod√®le compil√© avec torch.compile(mode='{config.COMPILE_MODE}')")
        print("   Note: Les premi√®res √©poques seront plus lentes (compilation)")
    except Exception as e:
        print(f"‚ö†Ô∏è torch.compile non disponible: {e}")

# Fonction de perte
if config.USE_FOCAL_LOSS:
    criterion = FocalLoss(
        gamma=config.FOCAL_GAMMA,
        alpha=class_weights,
        label_smoothing=config.LABEL_SMOOTHING if config.USE_LABEL_SMOOTHING else 0.0
    )
    print(f"‚úì Focal Loss (gamma={config.FOCAL_GAMMA})")
elif config.USE_LABEL_SMOOTHING:
    criterion = LabelSmoothingCrossEntropy(smoothing=config.LABEL_SMOOTHING)
    print(f"‚úì Label Smoothing (smoothing={config.LABEL_SMOOTHING})")
else:
    criterion = nn.CrossEntropyLoss(weight=class_weights)

val_criterion = nn.CrossEntropyLoss()

# Optimiseur
optimizer = optim.AdamW(
    model.parameters(),
    lr=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY
)

# Scheduler OneCycleLR
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.LEARNING_RATE * 10,
    epochs=config.EPOCHS,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,
    anneal_strategy='cos'
)

# GradScaler pour Mixed Precision
scaler = torch.amp.GradScaler('cuda', enabled=config.USE_AMP)

# Affichage de la configuration
print(f"\n{'='*60}")
print("üìã Configuration d'entra√Ænement:")
print(f"{'='*60}")
print(f"  Dataset: Balanced AffectNet (75x75 RGB, 8 classes)")
print(f"  Batch size: {config.BATCH_SIZE}")
print(f"  Learning rate: {config.LEARNING_RATE} -> {config.LEARNING_RATE * 10}")
print(f"  Epochs: {config.EPOCHS}, Patience: {config.PATIENCE}")
print(f"  Mixup: {config.USE_MIXUP} (alpha={config.MIXUP_ALPHA})")
print(f"  ‚ö° Mixed Precision (AMP): {config.USE_AMP}")
print(f"  ‚ö° torch.compile: {config.COMPILE_MODE}")
print(f"{'='*60}")

## 12. üèãÔ∏è Boucle d'Entra√Ænement

In [None]:
# ===============================================================================
# üèãÔ∏è BOUCLE D'ENTRA√éNEMENT UNIFI√âE
# ===============================================================================

import gc

# Variables de suivi
best_val_acc = 0.0
best_val_loss = float('inf')
patience_counter = 0
best_epoch = 0

# Historique pour les graphiques
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'lr': [], 'epoch_time': [], 'gpu_memory': []
}

# SWA Setup (optionnel)
swa_model = None
swa_scheduler = None
if config.USE_SWA and not config.USE_COMPILE:
    from torch.optim.swa_utils import AveragedModel, SWALR
    swa_model = AveragedModel(model)
    swa_scheduler = SWALR(optimizer, swa_lr=config.SWA_LR)
    print(f"‚úÖ SWA activ√© (d√©marre √† l'√©poque {config.SWA_START_EPOCH})")
elif config.USE_SWA and config.USE_COMPILE:
    print("‚ö†Ô∏è SWA d√©sactiv√© car torch.compile est activ√© (incompatible)")

start_time = time.time()

print("\n" + "=" * 70)
print("üöÄ D√âMARRAGE DE L'ENTRA√éNEMENT")
print("=" * 70)
print(f"Mixed Precision: {config.USE_AMP}")
print(f"Batch size: {config.BATCH_SIZE}")
print(f"Workers: {config.NUM_WORKERS}")
print(f"Epochs: {config.EPOCHS}, Patience: {config.PATIENCE}")
print("=" * 70 + "\n")

for epoch in range(config.EPOCHS):
    epoch_start = time.time()
    model.train()

    loss_meter = AverageMeter()
    correct = 0
    total = 0

    optimizer.zero_grad(set_to_none=True)

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(config.DEVICE, non_blocking=True), labels.to(config.DEVICE, non_blocking=True)

        # Mixup uniquement (CutMix d√©sactiv√© car baisse les performances)
        use_mixup = config.USE_MIXUP and random.random() > 0.5

        if use_mixup:
            inputs, labels_a, labels_b, lam = mixup_data(inputs, labels, config.MIXUP_ALPHA)

        # Mixed Precision Forward Pass
        with torch.amp.autocast('cuda', enabled=config.USE_AMP):
            outputs = model(inputs)

            if use_mixup:
                loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
            else:
                loss = criterion(outputs, labels)

            loss = loss / config.ACCUMULATION_STEPS

        # Backward avec GradScaler
        scaler.scale(loss).backward()

        # Gradient accumulation
        if (batch_idx + 1) % config.ACCUMULATION_STEPS == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            # Scheduler step (pas pendant SWA)
            if swa_model is None or epoch < config.SWA_START_EPOCH:
                scheduler.step()

            optimizer.zero_grad(set_to_none=True)

        # M√©triques
        loss_meter.update(loss.item() * config.ACCUMULATION_STEPS, inputs.size(0))

        if not use_mixup:
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    train_acc = 100.0 * correct / max(total, 1)

    # Mise √† jour SWA apr√®s SWA_START_EPOCH
    if swa_model is not None and epoch >= config.SWA_START_EPOCH:
        swa_model.update_parameters(model)
        swa_scheduler.step()

    # Validation
    val_loss, val_acc = validate(model, val_loader, val_criterion, config.DEVICE,
                                 per_class=(epoch % 10 == 0), use_amp=config.USE_AMP)

    current_lr = optimizer.param_groups[0]['lr']
    epoch_time = time.time() - epoch_start
    elapsed = time.time() - start_time

    # Suivi m√©moire GPU
    if torch.cuda.is_available():
        gpu_mem = torch.cuda.memory_allocated() / 1024**3
        gpu_mem_max = torch.cuda.max_memory_allocated() / 1024**3
    else:
        gpu_mem = 0
        gpu_mem_max = 0

    # Sauvegarder historique
    history['train_loss'].append(loss_meter.avg)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['lr'].append(current_lr)
    history['epoch_time'].append(epoch_time)
    history['gpu_memory'].append(gpu_mem_max)

    swa_status = " [SWA]" if swa_model is not None and epoch >= config.SWA_START_EPOCH else ""
    print(f"Epoch {epoch+1:3d}/{config.EPOCHS} | "
          f"Loss: {loss_meter.avg:.4f} | Acc: {train_acc:.1f}% | "
          f"Val: {val_acc:.1f}% | LR: {current_lr:.6f} | "
          f"Time: {epoch_time:.1f}s | GPU: {gpu_mem:.1f}/{gpu_mem_max:.1f}GB{swa_status}")

    # Sauvegarder le meilleur mod√®le
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_val_loss = val_loss
        best_epoch = epoch + 1
        patience_counter = 0

        # R√©cup√©rer les poids du mod√®le (g√©rer torch.compile)
        model_to_save = model._orig_mod if hasattr(model, '_orig_mod') else model

        torch.save({
            'epoch': epoch,
            'model_state_dict': model_to_save.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'history': history,
            'config': {
                'num_classes': config.NUM_CLASSES,
                'in_channels': config.IN_CHANNELS,
                'input_size': config.INPUT_SIZE,
                'dataset': 'affectnet',
            }
        }, config.SAVE_PATH)
        print(f"  ‚úÖ [BEST] Nouveau meilleur mod√®le! (Val Acc: {val_acc:.2f}%)")
    else:
        patience_counter += 1
        if patience_counter >= config.PATIENCE:
            print(f"\n‚èπÔ∏è Early stopping apr√®s {epoch+1} √©poques!")
            break

# Nettoyage m√©moire
torch.cuda.empty_cache()
gc.collect()

elapsed = time.time() - start_time
avg_epoch_time = np.mean(history['epoch_time'])

print(f"\n{'='*70}")
print("‚úÖ ENTRA√éNEMENT TERMIN√â!")
print(f"{'='*70}")
print(f"Temps total: {elapsed/60:.1f} minutes")
print(f"Temps moyen par epoch: {avg_epoch_time:.1f} secondes")
print(f"Meilleure √©poque: {best_epoch}")
print(f"Meilleure pr√©cision validation: {best_val_acc:.2f}%")
print(f"Meilleure loss validation: {best_val_loss:.4f}")
print(f"M√©moire GPU max: {max(history['gpu_memory']):.2f} GB")
print(f"Mod√®le sauv√©: {config.SAVE_PATH}")
print(f"{'='*70}")

## 13. üìà Visualisation des R√©sultats

In [None]:
# ===============================================================================
# üìà VISUALISATION DES R√âSULTATS D'ENTRA√éNEMENT
# ===============================================================================

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# 1. Loss
axes[0, 0].plot(history['train_loss'], label='Train', color='blue')
axes[0, 0].plot(history['val_loss'], label='Validation', color='orange')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('üìâ Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# 2. Accuracy
axes[0, 1].plot(history['train_acc'], label='Train', color='blue')
axes[0, 1].plot(history['val_acc'], label='Validation', color='orange')
axes[0, 1].axhline(y=best_val_acc, color='green', linestyle='--', label=f'Best: {best_val_acc:.1f}%')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy (%)')
axes[0, 1].set_title('üìä Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# 3. Learning Rate
axes[0, 2].plot(history['lr'], color='green')
axes[0, 2].set_xlabel('Epoch')
axes[0, 2].set_ylabel('Learning Rate')
axes[0, 2].set_title('üìà Learning Rate (OneCycleLR)')
axes[0, 2].grid(True, alpha=0.3)

# 4. Temps par epoch
axes[1, 0].plot(history['epoch_time'], color='purple')
axes[1, 0].axhline(y=np.mean(history['epoch_time']), color='red', linestyle='--',
                   label=f'Moyenne: {np.mean(history["epoch_time"]):.1f}s')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Temps (s)')
axes[1, 0].set_title('‚è±Ô∏è Temps par Epoch')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# 5. M√©moire GPU
axes[1, 1].plot(history['gpu_memory'], color='red')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('M√©moire (GB)')
axes[1, 1].set_title('üéÆ M√©moire GPU Max')
axes[1, 1].grid(True, alpha=0.3)

# 6. R√©sum√©
axes[1, 2].axis('off')
summary_text = f"""
üìã R√âSUM√â DE L'ENTRA√éNEMENT

Meilleure accuracy: {best_val_acc:.2f}%
Meilleure √©poque: {best_epoch}

‚öôÔ∏è Configuration:
‚Ä¢ Batch Size: {config.BATCH_SIZE}
‚Ä¢ Epochs: {len(history['train_loss'])}
‚Ä¢ Mixed Precision: {config.USE_AMP}
‚Ä¢ torch.compile: {config.USE_COMPILE}
‚Ä¢ Mixup: {config.USE_MIXUP} (Œ±={config.MIXUP_ALPHA})
‚Ä¢ CutMix: {config.USE_CUTMIX}
‚Ä¢ SE Blocks: {config.USE_SE_BLOCKS}

‚è±Ô∏è Performance:
‚Ä¢ Temps moyen/epoch: {np.mean(history['epoch_time']):.1f}s
‚Ä¢ M√©moire GPU max: {max(history['gpu_memory']):.2f} GB
"""
axes[1, 2].text(0.05, 0.5, summary_text, fontsize=10, verticalalignment='center',
                fontfamily='monospace', bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))

plt.suptitle('üìà M√©triques d\'Entra√Ænement', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('training_curves.png', dpi=150)
plt.show()

print("\nüìä Graphiques sauvegard√©s dans 'training_curves.png'")

## 14. üîç √âvaluation Finale (avec TTA)

In [None]:
# ===============================================================================
# üîç √âVALUATION FINALE AVEC TTA (Test-Time Augmentation)
# ===============================================================================
# ‚ö†Ô∏è Si OOM: Red√©marrer le kernel (Runtime > Restart) avant de lancer cette cellule
# Le cache CUDA Graphs de torch.compile ne peut pas √™tre lib√©r√© autrement.

import gc

# ‚ö° NETTOYAGE M√âMOIRE GPU AGRESSIF
print("üßπ Nettoyage m√©moire GPU...")

# Supprimer tous les mod√®les et tenseurs possibles
for var_name in ['model', 'swa_model', 'optimizer', 'scheduler', 'scaler', 'criterion']:
    if var_name in dir():
        try:
            exec(f'del {var_name}')
        except:
            pass

# Forcer le nettoyage
torch.cuda.empty_cache()
torch.cuda.synchronize()
gc.collect()

# Afficher m√©moire disponible
if torch.cuda.is_available():
    gpu_free = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
    gpu_reserved = torch.cuda.memory_reserved() / 1024**3
    print(f"   M√©moire GPU libre: {gpu_free / 1024**3:.2f} GB")
    print(f"   M√©moire r√©serv√©e PyTorch: {gpu_reserved:.2f} GB")
    if gpu_free < 2 * 1024**3:  # Moins de 2GB libre
        print("   ‚ö†Ô∏è Peu de m√©moire libre - utilisation de petits batches")


def validate_with_tta(model, val_loader, criterion, device, n_augmentations=5, use_amp=False):
    """Validation avec Test-Time Augmentation - moyenne sur plusieurs augmentations."""
    model.eval()

    correct = 0
    total = 0
    loss_sum = 0

    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            batch_size = inputs.size(0)

            # Collecter les pr√©dictions de plusieurs augmentations
            all_outputs = []

            with torch.amp.autocast('cuda', enabled=use_amp):
                # 1. Original
                all_outputs.append(model(inputs))

                # 2. Flip horizontal
                all_outputs.append(model(torch.flip(inputs, dims=[3])))

                # 3-5. L√©g√®res variations de luminosit√©
                if n_augmentations >= 3:
                    all_outputs.append(model(inputs * 0.95))
                if n_augmentations >= 4:
                    all_outputs.append(model(inputs * 1.05))
                if n_augmentations >= 5:
                    all_outputs.append(model(torch.flip(inputs, dims=[3]) * 0.98))

            # Moyenne des pr√©dictions (soft voting)
            avg_outputs = torch.stack(all_outputs).mean(dim=0)

            loss = criterion(avg_outputs, labels)
            loss_sum += loss.item() * batch_size

            _, predicted = avg_outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

            for pred, label in zip(predicted, labels):
                class_total[label.item()] += 1
                if pred == label:
                    class_correct[label.item()] += 1

            # Lib√©rer la m√©moire des outputs interm√©diaires
            del all_outputs, avg_outputs, inputs, labels
            torch.cuda.empty_cache()

    accuracy = 100.0 * correct / total
    avg_loss = loss_sum / total

    emotions = ['Anger', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral', 'Contempt']
    print(f"\n  üìä Pr√©cision par classe (TTA x{n_augmentations}):")
    for i, emo in enumerate(emotions):
        if class_total[i] > 0:
            acc = 100.0 * class_correct[i] / class_total[i]
            print(f"    {emo:10s}: {acc:5.1f}% ({class_correct[i]}/{class_total[i]})")

    return avg_loss, accuracy


# Charger le meilleur mod√®le
print("\nüì• Chargement du meilleur mod√®le...")
checkpoint = torch.load(config.SAVE_PATH, weights_only=False)

# Cr√©er un nouveau mod√®le (sans compilation) pour charger les poids
eval_model = create_model(dataset='affectnet', num_classes=config.NUM_CLASSES).to(config.DEVICE)
eval_model.load_state_dict(checkpoint['model_state_dict'])

# ‚ö° BATCH SIZE R√âDUIT pour √©viter OOM (le cache CUDA Graphs prend ~13GB)
EVAL_BATCH_SIZE = 256  # Beaucoup plus petit pour laisser de la place
print(f"   ‚ö° Batch size √©valuation r√©duit: {EVAL_BATCH_SIZE} (au lieu de {config.BATCH_SIZE})")

eval_loader = DataLoader(
    val_dataset,
    batch_size=EVAL_BATCH_SIZE,
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=True
)

# Recr√©er le crit√®re de validation
val_criterion = nn.CrossEntropyLoss()

print(f"\n{'='*60}")
print(f"üìä √âVALUATION STANDARD")
print(f"{'='*60}")
val_loss, val_acc = validate(eval_model, eval_loader, val_criterion, config.DEVICE, per_class=True, use_amp=config.USE_AMP)
print(f"\nüéØ R√©sultats standards:")
print(f"   - Pr√©cision globale: {val_acc:.2f}%")
print(f"   - Loss: {val_loss:.4f}")

# Nettoyage avant TTA (qui utilise plus de m√©moire)
torch.cuda.empty_cache()

print(f"\n{'='*60}")
print(f"üìä √âVALUATION AVEC TTA (Test-Time Augmentation)")
print(f"{'='*60}")
tta_loss, tta_acc = validate_with_tta(eval_model, eval_loader, val_criterion, config.DEVICE,
                                       n_augmentations=5, use_amp=config.USE_AMP)
print(f"\nüéØ R√©sultats avec TTA:")
print(f"   - Pr√©cision globale: {tta_acc:.2f}%")
print(f"   - Loss: {tta_loss:.4f}")
print(f"   - Am√©lioration TTA: {tta_acc - val_acc:+.2f}%")

## 15. üíæ Sauvegarde du Mod√®le Final

In [None]:
# Sauvegarder le mod√®le final (poids uniquement) - version l√©g√®re pour d√©ploiement
torch.save({
    'model_state_dict': eval_model.state_dict(),  # Utilise eval_model (le mod√®le charg√©)
    'num_classes': config.NUM_CLASSES,
    'in_channels': config.IN_CHANNELS,
    'input_size': config.INPUT_SIZE,
    'dataset': 'affectnet',
    'best_val_acc': checkpoint['val_acc'],  # Utilise la valeur du checkpoint
}, 'emotion_model.pth')

print("‚úÖ Mod√®le sauvegard√© dans 'emotion_model.pth'")
print(f"   Taille: {os.path.getsize('emotion_model.pth') / 1024 / 1024:.2f} MB")
print(f"   Best Val Acc: {checkpoint['val_acc']:.2f}%")

## 16. üß™ Test sur Quelques Images

In [None]:
def predict_emotion(model, image_tensor, device):
    """Pr√©dit l'√©motion pour une image."""
    model.eval()
    with torch.no_grad():
        image_tensor = image_tensor.unsqueeze(0).to(device)
        with torch.amp.autocast('cuda', enabled=config.USE_AMP):
            outputs = model(image_tensor)
        probs = F.softmax(outputs, dim=1)
        pred_idx = outputs.argmax(1).item()
        confidence = probs[0, pred_idx].item()
    return pred_idx, confidence, probs[0].cpu().numpy()

# Test sur quelques images de validation
fig, axes = plt.subplots(2, 4, figsize=(14, 7))
axes = axes.flatten()

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
emotions = ['Anger', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral', 'Contempt']

indices = random.sample(range(len(val_dataset)), 8)

for i, idx in enumerate(indices):
    img, true_label = val_dataset[idx]
    pred_idx, confidence, probs = predict_emotion(eval_model, img, config.DEVICE)  # ‚ö° Utilise eval_model

    img_np = img.numpy().transpose(1, 2, 0)
    img_np = img_np * std + mean
    img_np = np.clip(img_np, 0, 1)

    true_emotion = emotions[true_label]
    pred_emotion = emotions[pred_idx]

    color = 'green' if pred_idx == true_label else 'red'

    axes[i].imshow(img_np)
    axes[i].set_title(f"Vrai: {true_emotion}\nPr√©d: {pred_emotion} ({confidence*100:.1f}%)",
                      color=color, fontsize=10)
    axes[i].axis('off')

plt.suptitle('üîç Pr√©dictions sur le Set de Validation', fontsize=14)
plt.tight_layout()
plt.show()