## 1. 📦 Imports et Vérification GPU

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
import os
import time
from collections import defaultdict
import random
from PIL import Image
import matplotlib.pyplot as plt

# GPU check: report the PyTorch version and whether CUDA is available.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA disponible: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Name of device 0 and the CUDA version reported by torch.version.cuda.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

In [None]:
# Optional dependency: albumentations supplies the advanced augmentation
# pipeline. HAS_ALBUMENTATIONS records whether the import succeeded so the
# transform builders can fall back to torchvision when it is missing.
try:
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    HAS_ALBUMENTATIONS = True
    print("‚úÖ Albumentations disponible pour l'augmentation avanc√©e")
except ImportError:
    HAS_ALBUMENTATIONS = False
    print("‚ö†Ô∏è Installez albumentations: pip install albumentations")

## 2. ⚙️ Configuration des Hyperparamètres

Tous les paramètres d'entraînement sont centralisés ici pour faciliter l'expérimentation.

In [None]:
class Config:
    """Central store for every training hyperparameter, read as class attributes."""

    # --- Data -----------------------------------------------------------
    DATASET_ROOT: str = './data'

    # --- Model ----------------------------------------------------------
    NUM_CLASSES: int = 8          # the 8 AffectNet emotion categories
    IN_CHANNELS: int = 3          # RGB input
    INPUT_SIZE: int = 75          # 75x75 pixel images

    # --- Optimisation ---------------------------------------------------
    BATCH_SIZE: int = 64          # tune to the available GPU memory
    ACCUMULATION_STEPS: int = 1
    LEARNING_RATE: float = 0.0005
    WEIGHT_DECAY: float = 1e-4
    EPOCHS: int = 80
    PATIENCE: int = 15            # epochs without improvement before early stopping

    # --- Regularisation techniques --------------------------------------
    USE_MIXUP: bool = True
    MIXUP_ALPHA: float = 0.4
    USE_CUTMIX: bool = False
    CUTMIX_ALPHA: float = 1.0
    CUTMIX_PROB: float = 0.5

    USE_LABEL_SMOOTHING: bool = True
    LABEL_SMOOTHING: float = 0.1

    USE_FOCAL_LOSS: bool = False
    FOCAL_GAMMA: float = 2.0

    # --- Augmentation ---------------------------------------------------
    USE_ADVANCED_AUG: bool = True

    # --- Class balancing ------------------------------------------------
    USE_OVERSAMPLING: bool = False
    MAX_CLASS_WEIGHT: float = 3.0  # upper clip for the loss class weights

    # --- Hardware -------------------------------------------------------
    DEVICE: torch.device = (
        torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    )

# Shared configuration instance read by every later cell.
config = Config()

# Echo the main run settings.
print(f"üñ•Ô∏è Device: {config.DEVICE}")
print(f"üìä Batch size: {config.BATCH_SIZE}")
print(f"üìà Learning rate: {config.LEARNING_RATE}")
print(f"üîÑ Epochs: {config.EPOCHS}")

## 3. 📉 Fonctions de Perte (Loss Functions)

### Focal Loss
Utile pour les datasets déséquilibrés - réduit l'importance des exemples faciles.

### Label Smoothing Cross Entropy
Empêche le modèle d'être trop confiant sur les prédictions.

In [None]:
class FocalLoss(nn.Module):
    """Focal loss for class-imbalanced classification.

    The per-sample cross-entropy is scaled by ``(1 - p_t) ** gamma``, where
    ``p_t`` is the predicted probability of the true class. Well-classified
    samples therefore contribute less. The result is optionally multiplied by
    a per-class weight ``alpha[target]``.

    Args:
        gamma: Focusing exponent; ``gamma=0`` reduces to (weighted) cross-entropy.
        alpha: Optional per-class weights indexed by target class. Accepts a
            tensor or any sequence convertible with ``torch.as_tensor``. It is
            registered as a buffer so ``criterion.to(device)`` moves it too.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-sample losses.
        label_smoothing: If > 0, the cross-entropy term uses a smoothed target
            (``1 - label_smoothing`` on the true class, the rest spread evenly
            over the other classes). The focal factor still uses the hard label.
    """

    def __init__(self, gamma=2.0, alpha=None, reduction='mean', label_smoothing=0.0):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        if alpha is not None and not isinstance(alpha, torch.Tensor):
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        # A buffer follows .to()/.cuda() on the module; a None buffer is allowed.
        self.register_buffer('alpha', alpha)

    def forward(self, inputs, targets):
        """Return the focal loss for logits ``inputs`` (N, C) and class indices ``targets`` (N,)."""
        log_probs = F.log_softmax(inputs, dim=-1)
        # log p_t: log-probability assigned to each sample's true class.
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)

        if self.label_smoothing > 0:
            n_classes = inputs.size(-1)
            # max(..., 1) avoids ZeroDivisionError for a single class; that
            # lone entry is overwritten by the scatter below anyway.
            off_value = self.label_smoothing / max(n_classes - 1, 1)
            targets_smooth = torch.full_like(log_probs, off_value)
            targets_smooth.scatter_(1, targets.unsqueeze(1), 1.0 - self.label_smoothing)
            ce_loss = -(targets_smooth * log_probs).sum(dim=-1)
        else:
            # Identical to F.cross_entropy(..., reduction='none') on these logits.
            ce_loss = -log_pt

        pt = log_pt.exp()
        focal_weight = (1 - pt) ** self.gamma

        if self.alpha is not None:
            alpha_t = self.alpha.to(inputs.device).gather(0, targets)
            focal_weight = focal_weight * alpha_t

        loss = focal_weight * ce_loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class receives probability ``1 - smoothing`` and the remaining
    ``smoothing`` mass is split evenly across the other ``C - 1`` classes.
    The batch mean is returned.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, inputs, targets):
        log_p = F.log_softmax(inputs, dim=-1)
        num_classes = inputs.size(-1)
        off_target = self.smoothing / (num_classes - 1)

        # Uniform off-target mass first, then overwrite the true-class entry.
        soft_targets = torch.full_like(log_p, off_target)
        soft_targets.scatter_(1, targets.unsqueeze(1), 1.0 - self.smoothing)

        per_sample = -(soft_targets * log_p).sum(dim=-1)
        return per_sample.mean()

print("‚úÖ Fonctions de perte d√©finies")

## 4. 🔀 Mixup & CutMix

Techniques d'augmentation qui mélangent des images pour améliorer la généralisation.

In [None]:
def mixup_data(x, y, alpha=0.2):
    """Mixup: blend every sample with a randomly chosen partner from the batch.

    Returns ``(mixed_x, y, y[perm], lam)`` with
    ``mixed_x = lam * x + (1 - lam) * x[perm]``. ``lam`` is drawn from
    ``Beta(alpha, alpha)``; when ``alpha <= 0`` it is fixed at 1 (no mixing).
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
    perm = torch.randperm(x.size(0)).to(x.device)
    blended = lam * x + (1 - lam) * x[perm]
    return blended, y, y[perm], lam


def cutmix_data(x, y, alpha=1.0):
    """CutMix: paste one random rectangular patch from a partner sample.

    Inside a single random box shared by the batch, sample ``i`` receives the
    pixels of sample ``perm[i]``. The input ``x`` is not modified; a mixed
    copy is returned.

    Args:
        x: Image batch unpacked as ``(N, C, H, W)`` from ``x.shape``.
        y: Labels indexed along the batch dimension.
        alpha: Beta-distribution parameter for the patch area. ``alpha <= 0``
            disables mixing (empty box, ``lam == 1``), mirroring ``mixup_data``.

    Returns:
        ``(mixed_x, y, y[perm], lam)``. ``lam`` is the fraction of each image
        still coming from the original sample, recomputed from the clipped box.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    _, _, H, W = x.shape
    # Box sides scale with sqrt(1 - lam), so the box area is about (1 - lam) * H * W.
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Uniform box centre; the box is clipped at the image borders.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    # Work on a copy so the caller's tensor is not changed in place.
    mixed_x = x.clone()
    mixed_x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]
    # Recompute lam from the actual (possibly clipped) box area.
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))

    return mixed_x, y, y[index], lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mix the losses against both label sets with weight ``lam`` (Mixup/CutMix)."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

print("‚úÖ Fonctions Mixup et CutMix d√©finies")

## 5. 🖼️ Transformations et Augmentation de Données

Utilise Albumentations pour des augmentations avancées (rotation, bruit, flou, etc.)

In [None]:
def get_train_transforms():
    """Build the training augmentation pipeline for 75x75 RGB images.

    Returns an Albumentations ``Compose`` only when Albumentations is installed
    *and* ``config.USE_ADVANCED_AUG`` is set. Otherwise it returns a torchvision
    ``Compose`` (flip, rotation, affine, colour jitter) that first converts the
    array image with ``ToPILImage``. Both end with the same mean/std
    normalisation used for validation.
    """
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    if not (HAS_ALBUMENTATIONS and config.USE_ADVANCED_AUG):
        # torchvision fallback
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.9, 1.1)),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std),
        ])

    geometric = [
        A.HorizontalFlip(p=0.5),
        A.Affine(
            translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)},
            scale=(0.9, 1.1),
            rotate=(-10, 10),
            p=0.5
        ),
    ]
    noise_or_blur = A.OneOf([
        A.GaussNoise(std_range=(0.02, 0.1), p=1),
        A.GaussianBlur(blur_limit=(3, 5), p=1),
        A.MotionBlur(blur_limit=3, p=1),
    ], p=0.3)
    photometric = A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),
        A.RandomGamma(gamma_limit=(80, 120), p=1),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=20, p=1),
    ], p=0.5)
    occlusion = A.CoarseDropout(
        num_holes_range=(1, 4),
        hole_height_range=(6, 12),
        hole_width_range=(6, 12),
        fill=0,
        p=0.3
    )

    return A.Compose([
        *geometric,
        noise_or_blur,
        photometric,
        occlusion,
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2(),
    ])


def get_val_transforms():
    """Build the evaluation pipeline: normalisation and tensor conversion only.

    Uses Albumentations whenever it is installed, regardless of
    ``config.USE_ADVANCED_AUG``; otherwise uses torchvision.
    """
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    if not HAS_ALBUMENTATIONS:
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std),
        ])

    return A.Compose([
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2(),
    ])

print("‚úÖ Transformations d√©finies")

## 6. 📁 Dataset AffectNet

In [None]:
from torch.utils.data import Dataset, WeightedRandomSampler

class BalancedAffectNetDataset(Dataset):
    """
    Dataset for Balanced AffectNet.

    Expected layout:
    data/
        train/Anger/, Contempt/, Disgust/, Fear/, Happy/, Neutral/, Sad/, Surprise/
        val/...
        test/...

    Each item is ``(image, label)``. With a transform, ``image`` is whatever
    the transform returns; without one it is a float32 CHW tensor in [0, 1].
    """

    NUM_CLASSES = 8

    # Folder name -> class index; this mapping defines the labels this dataset yields.
    EMOTION_CLASSES = {
        'Anger': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3,
        'Sad': 4, 'Surprise': 5, 'Neutral': 6, 'Contempt': 7,
    }

    IDX_TO_EMOTION = {v: k for k, v in EMOTION_CLASSES.items()}

    # File extensions (matched case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, root_dir='./data', split='train', transform=None, use_albumentations=False):
        """
        Args:
            root_dir: Dataset root containing one sub-folder per split.
            split: Split sub-folder to index (e.g. 'train', 'val', 'test').
            transform: Optional callable applied to the RGB image (NumPy array).
            use_albumentations: If True, call ``transform(image=...)`` and take
                ``['image']`` from the result; otherwise call ``transform(image)``.

        Raises:
            FileNotFoundError: If ``root_dir/split`` does not exist.
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.use_albumentations = use_albumentations

        self.images = []
        self.labels = []

        split_dir = os.path.join(root_dir, split)

        if not os.path.exists(split_dir):
            raise FileNotFoundError(
                f"Dataset non trouv√©: {split_dir}\n"
                f"T√©l√©chargez depuis: https://www.kaggle.com/datasets/dollyprajapati182/balanced-affectnet"
            )

        # Index image paths only; pixels are decoded lazily in __getitem__.
        for emotion_name, emotion_idx in self.EMOTION_CLASSES.items():
            emotion_dir = os.path.join(split_dir, emotion_name)
            if not os.path.exists(emotion_dir):
                print(f"‚ö†Ô∏è {emotion_dir} non trouv√©, ignor√©...")
                continue

            # Sorted so sample order is reproducible (os.listdir order is arbitrary).
            for img_name in sorted(os.listdir(emotion_dir)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.images.append(os.path.join(emotion_dir, img_name))
                    self.labels.append(emotion_idx)

        print(f"üìÇ Charg√© {len(self.images)} images depuis AffectNet {split}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]

        # The context manager closes the file handle once the pixels are decoded.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))

        if self.transform:
            if self.use_albumentations:
                image = self.transform(image=image)['image']
            else:
                image = self.transform(image)
        else:
            # No transform: HWC uint8 -> CHW float32 scaled to [0, 1].
            image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0

        return image, label

    def get_class_distribution(self):
        """Return per-class sample counts (length NUM_CLASSES for labels 0..7)."""
        # Explicit integer dtype so an empty split still yields a zero vector.
        return np.bincount(np.asarray(self.labels, dtype=np.int64), minlength=self.NUM_CLASSES)

    def get_labels(self):
        """Return all sample labels as a NumPy array, in dataset order."""
        return np.array(self.labels)


def get_class_weights(dataset, max_weight=5.0, min_weight=0.3):
    """Compute inverse-frequency class weights for a weighted loss.

    The computation runs in four steps:
    1. Take inverse class counts.
    2. Rescale them to mean 1.
    3. Clip to ``[min_weight, max_weight]``.
    4. Rescale to mean 1 again.

    Because of the final rescale, the returned weights may sit slightly
    outside the clip range. The clip caps the ratio between extreme classes
    at ``max_weight / min_weight``.

    Args:
        dataset: Object exposing ``get_class_distribution()`` (per-class counts).
            If it also has an ``IDX_TO_EMOTION`` mapping, that mapping names
            the classes in the printed table.
        max_weight: Upper clip bound applied before the final rescale.
        min_weight: Lower clip bound applied before the final rescale.

    Returns:
        ``torch.FloatTensor`` with one weight per class, summing to the class count.
    """
    # Empty classes are counted as 1 to avoid dividing by zero.
    counts = np.maximum(dataset.get_class_distribution(), 1)

    weights = 1.0 / counts
    weights = weights / weights.sum() * len(weights)
    weights = np.clip(weights, min_weight, max_weight)
    weights = weights / weights.sum() * len(weights)

    idx_to_name = getattr(dataset, 'IDX_TO_EMOTION', {})
    print("\nüìä Poids des classes:")
    for i, (count, weight) in enumerate(zip(counts, weights)):
        emotion = idx_to_name.get(i, f"Class_{i}")
        print(f"    {emotion:10s}: {count:5d} samples, poids: {weight:.3f}")

    return torch.FloatTensor(weights)


def get_balanced_sampler(dataset):
    """Return a WeightedRandomSampler that draws each class equally often in expectation.

    Every sample is weighted by the inverse of its class count (empty classes
    count as 1). One epoch draws ``len(labels)`` samples with replacement.
    """
    labels = dataset.get_labels()
    per_class = np.maximum(
        np.bincount(labels, minlength=BalancedAffectNetDataset.NUM_CLASSES), 1
    )
    sample_weights = (1.0 / per_class)[labels]
    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )

print("‚úÖ Classes Dataset d√©finies")

## 7. 🧠 Architecture du Modèle CNN

In [None]:
# The network definition lives in the project's `model` module.
from model import FaceEmotionCNN, create_model

# Instantiate once here only to report the parameter count.
model = create_model(num_classes=config.NUM_CLASSES, dataset='affectnet')
total_params = sum(map(lambda tensor: tensor.numel(), model.parameters()))
print(f"üß† Mod√®le cr√©√© avec {total_params:,} param√®tres")

## 8. 🔧 Utilitaires d'Entraînement

In [None]:
class AverageMeter:
    """Running, count-weighted average of a scalar such as a per-batch loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every accumulated statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` items and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count


def validate(model, val_loader, criterion, device, per_class=False):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: Network returning per-class scores; switched to eval mode.
        val_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``. Its value is
            treated as a batch mean and re-weighted by the batch size.
        device: Device the batches are moved to.
        per_class: If True, also print the accuracy of each emotion class.

    Returns:
        ``(avg_loss, accuracy_percent)``. Both are 0.0 for an empty loader
        instead of raising ``ZeroDivisionError``.
    """
    model.eval()

    loss_sum = 0.0
    loss_count = 0
    correct = 0
    total = 0
    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = inputs.size(0)
            loss_sum += loss.item() * batch_size
            loss_count += batch_size

            _, predicted = outputs.max(1)
            hits = predicted.eq(labels)
            total += labels.size(0)
            correct += hits.sum().item()

            if per_class:
                # One host transfer per batch instead of one .item() sync per sample.
                for label, hit in zip(labels.tolist(), hits.tolist()):
                    class_total[label] += 1
                    if hit:
                        class_correct[label] += 1

    avg_loss = loss_sum / loss_count if loss_count else 0.0
    accuracy = 100.0 * correct / total if total else 0.0

    if per_class:
        # Index order matches BalancedAffectNetDataset.EMOTION_CLASSES.
        emotions = ['Anger', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral', 'Contempt']
        print("\n  üìä Pr√©cision par classe:")
        for i, emo in enumerate(emotions):
            if class_total[i] > 0:
                acc = 100.0 * class_correct[i] / class_total[i]
                print(f"    {emo:10s}: {acc:5.1f}% ({class_correct[i]}/{class_total[i]})")

    return avg_loss, accuracy

print("‚úÖ Utilitaires d√©finis")

## 9. 📂 Chargement des Données

In [None]:
print("üìÇ Chargement du dataset Balanced AffectNet...\n")

train_transform = get_train_transforms()
val_transform = get_val_transforms()

# get_train_transforms() returns an Albumentations pipeline only when
# USE_ADVANCED_AUG is also on; otherwise it returns a torchvision Compose,
# which must be called positionally (not with image=...).
train_uses_albumentations = HAS_ALBUMENTATIONS and config.USE_ADVANCED_AUG

train_dataset = BalancedAffectNetDataset(
    root_dir=config.DATASET_ROOT,
    split='train',
    transform=train_transform,
    use_albumentations=train_uses_albumentations
)

val_dataset = BalancedAffectNetDataset(
    root_dir=config.DATASET_ROOT,
    split='val',
    transform=val_transform,
    # get_val_transforms() selects Albumentations on HAS_ALBUMENTATIONS alone.
    use_albumentations=HAS_ALBUMENTATIONS
)

# Class weights for the loss
class_weights = get_class_weights(train_dataset, max_weight=config.MAX_CLASS_WEIGHT).to(config.DEVICE)

# Optional class-balanced oversampling (a sampler excludes shuffle=True).
train_sampler = get_balanced_sampler(train_dataset) if config.USE_OVERSAMPLING else None
# Pinned host memory only benefits host-to-GPU copies.
pin_memory = config.DEVICE.type == 'cuda'

# DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=train_sampler is None,
    sampler=train_sampler,
    num_workers=0,
    pin_memory=pin_memory,
    drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=pin_memory
)

print(f"\n‚úÖ Train: {len(train_dataset)} samples, Val: {len(val_dataset)} samples")
print(f"   Batches - Train: {len(train_loader)}, Val: {len(val_loader)}")

## 10. 👀 Visualisation d'Échantillons

In [None]:
# Preview a few training images (with the training transform applied).
def show_samples(dataset, n_samples=8):
    """Display ``n_samples`` random items of ``dataset`` with their emotion names.

    Items are expected as ``(CHW tensor, int label)``. The mean/std
    normalisation is undone for display. The grid has up to 4 columns and as
    many rows as needed. ``n_samples`` is capped at ``len(dataset)`` so small
    datasets do not make ``random.sample`` raise.
    """
    n_samples = min(n_samples, len(dataset))
    if n_samples <= 0:
        return

    n_cols = min(4, n_samples)
    n_rows = (n_samples + n_cols - 1) // n_cols
    # squeeze=False always yields a 2-D axes array, even for a 1x1 grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
    axes = axes.flatten()

    # Normalisation statistics to undo (same values as the transforms).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    indices = random.sample(range(len(dataset)), n_samples)

    for ax, idx in zip(axes, indices):
        img, label = dataset[idx]

        # CHW tensor -> HWC array, de-normalised and clipped to [0, 1].
        img_np = img.numpy().transpose(1, 2, 0)
        img_np = np.clip(img_np * std + mean, 0, 1)

        ax.imshow(img_np)
        ax.set_title(BalancedAffectNetDataset.IDX_TO_EMOTION[label], fontsize=12)
        ax.axis('off')

    # Blank out grid cells left over when n_samples does not fill the last row.
    for ax in axes[n_samples:]:
        ax.axis('off')

    plt.suptitle('√âchantillons du Dataset AffectNet', fontsize=14)
    plt.tight_layout()
    plt.show()

show_samples(train_dataset)

## 11. 🚀 Configuration de l'Entraînement

In [None]:
# Fresh model instance for training
model = create_model(dataset='affectnet', num_classes=config.NUM_CLASSES).to(config.DEVICE)

# Training loss, by priority: focal loss (class weights as alpha),
# then label-smoothed CE (unweighted), then class-weighted CE.
if config.USE_FOCAL_LOSS:
    criterion = FocalLoss(
        gamma=config.FOCAL_GAMMA,
        alpha=class_weights,
        label_smoothing=config.LABEL_SMOOTHING if config.USE_LABEL_SMOOTHING else 0.0
    )
    print(f"‚úì Focal Loss (gamma={config.FOCAL_GAMMA})")
elif config.USE_LABEL_SMOOTHING:
    criterion = LabelSmoothingCrossEntropy(smoothing=config.LABEL_SMOOTHING)
    print(f"‚úì Label Smoothing (smoothing={config.LABEL_SMOOTHING})")
else:
    criterion = nn.CrossEntropyLoss(weight=class_weights)

# Validation always uses plain cross-entropy, independent of the training loss.
val_criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY
)

# OneCycleLR, peaking at 10x LEARNING_RATE.
# The training loop calls optimizer.step()/scheduler.step() only once every
# ACCUMULATION_STEPS batches, and it zeroes gradients at each epoch start, so
# a trailing partial group is dropped. The schedule must therefore be sized
# in optimizer steps, not batches; otherwise it never completes when
# ACCUMULATION_STEPS > 1.
optimizer_steps_per_epoch = max(1, len(train_loader) // config.ACCUMULATION_STEPS)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config.LEARNING_RATE * 10,
    epochs=config.EPOCHS,
    steps_per_epoch=optimizer_steps_per_epoch,
    pct_start=0.3,
    anneal_strategy='cos'
)

print(f"\n{'='*60}")
print("üìã Configuration d'entra√Ænement:")
print(f"  Dataset: Balanced AffectNet (75x75 RGB, 8 classes)")
print(f"  Batch size: {config.BATCH_SIZE}")
print(f"  Learning rate: {config.LEARNING_RATE}")
print(f"  Epochs: {config.EPOCHS}, Patience: {config.PATIENCE}")
print(f"  Mixup: {config.USE_MIXUP} (alpha={config.MIXUP_ALPHA})")
print(f"  Label Smoothing: {config.USE_LABEL_SMOOTHING} ({config.LABEL_SMOOTHING})")
print(f"  Advanced Aug: {HAS_ALBUMENTATIONS and config.USE_ADVANCED_AUG}")
print(f"{'='*60}")

## 12. 🏋️ Boucle d'Entraînement

In [None]:
# Best-so-far tracking used for checkpointing and early stopping.
best_val_acc = 0.0
best_val_loss = float('inf')
patience_counter = 0
best_epoch = 0

# Per-epoch history used by the plotting cell.
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'lr': []
}

start_time = time.time()

print("\nüöÄ D√©marrage de l'entra√Ænement...\n")

for epoch in range(config.EPOCHS):
    model.train()
    
    loss_meter = AverageMeter()
    correct = 0
    total = 0
    
    # Gradients are zeroed at every epoch start, so an incomplete accumulation
    # group left over from the previous epoch is discarded.
    optimizer.zero_grad()
    
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(config.DEVICE), labels.to(config.DEVICE)
        
        # Mixup on roughly half the batches (random.random() > 0.5); CutMix, if
        # enabled, only on batches Mixup skipped, with probability CUTMIX_PROB.
        use_mixup = config.USE_MIXUP and random.random() > 0.5
        use_cutmix = config.USE_CUTMIX and random.random() < config.CUTMIX_PROB and not use_mixup
        
        if use_mixup:
            inputs, labels_a, labels_b, lam = mixup_data(inputs, labels, config.MIXUP_ALPHA)
        elif use_cutmix:
            inputs, labels_a, labels_b, lam = cutmix_data(inputs, labels, config.CUTMIX_ALPHA)
        
        # Forward pass
        outputs = model(inputs)
        
        # Mixed batches are scored against both label sets, weighted by lam.
        if use_mixup or use_cutmix:
            loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
        else:
            loss = criterion(outputs, labels)
        
        # Scale so the gradients summed over ACCUMULATION_STEPS batches average out.
        loss = loss / config.ACCUMULATION_STEPS
        loss.backward()
        
        # Gradient accumulation: clip the global grad norm to 1.0, then step the
        # optimizer and the per-step LR scheduler once per ACCUMULATION_STEPS batches.
        if (batch_idx + 1) % config.ACCUMULATION_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        
        # Metrics: undo the accumulation scaling for logging.
        loss_meter.update(loss.item() * config.ACCUMULATION_STEPS, inputs.size(0))
        
        # Training accuracy counts only un-mixed batches (mixed ones have no single target).
        if not (use_mixup or use_cutmix):
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    # max(total, 1) guards the epoch where every batch was mixed.
    train_acc = 100.0 * correct / max(total, 1)
    
    # Validation; per-class breakdown printed when epoch index is a multiple of 10.
    val_loss, val_acc = validate(model, val_loader, val_criterion, config.DEVICE, 
                                 per_class=(epoch % 10 == 0))
    
    # LR after this epoch's scheduler steps.
    current_lr = optimizer.param_groups[0]['lr']
    elapsed = time.time() - start_time
    
    # Record history
    history['train_loss'].append(loss_meter.avg)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['lr'].append(current_lr)
    
    print(f"Epoch {epoch+1:3d}/{config.EPOCHS} | "
          f"Train Loss: {loss_meter.avg:.4f} | Train Acc: {train_acc:.1f}% | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.1f}% | "
          f"LR: {current_lr:.6f} | Time: {elapsed/60:.1f}min")
    
    # Checkpoint whenever validation accuracy improves; otherwise count toward early stopping.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_val_loss = val_loss
        best_epoch = epoch + 1
        patience_counter = 0
        
        # NOTE: the stored 'epoch' is the 0-based loop index, while best_epoch is 1-based.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
            'config': {
                'num_classes': config.NUM_CLASSES,
                'in_channels': config.IN_CHANNELS,
                'input_size': config.INPUT_SIZE,
                'dataset': 'affectnet',
            }
        }, 'emotion_model_best.pth')
        print(f"  ‚úÖ Nouveau meilleur mod√®le! (Val Acc: {val_acc:.2f}%)")
    else:
        patience_counter += 1
        if patience_counter >= config.PATIENCE:
            print(f"\n‚èπÔ∏è Early stopping apr√®s {epoch+1} √©poques!")
            break

elapsed = time.time() - start_time

print(f"\n{'='*60}")
print("üéâ Entra√Ænement termin√©!")
print(f"{'='*60}")
print(f"Temps total: {elapsed/60:.1f} minutes")
print(f"Meilleure √©poque: {best_epoch}")
print(f"Meilleure pr√©cision validation: {best_val_acc:.2f}%")
print(f"Meilleure loss validation: {best_val_loss:.4f}")

## 13. 📈 Visualisation des Résultats

In [None]:
# Training curves from `history`: loss, accuracy and learning rate, one point per epoch.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Loss (train loss is the running mean over the epoch's batches)
axes[0].plot(history['train_loss'], label='Train', color='blue')
axes[0].plot(history['val_loss'], label='Validation', color='orange')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('üìâ Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy (train accuracy only covers batches that were not Mixup/CutMix-mixed)
axes[1].plot(history['train_acc'], label='Train', color='blue')
axes[1].plot(history['val_acc'], label='Validation', color='orange')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('üìä Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Learning rate, sampled once at the end of each epoch
axes[2].plot(history['lr'], color='green')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Learning Rate')
axes[2].set_title('üìà Learning Rate (OneCycleLR)')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150)
plt.show()

print("\nüìä Graphiques sauvegard√©s dans 'training_curves.png'")

## 14. 🔍 Évaluation Finale

In [None]:
# Reload the best checkpoint written by the training loop.
print("üì• Chargement du meilleur mod√®le...")
# map_location lets a checkpoint saved during a GPU run be reloaded on
# whatever device this session uses (e.g. a CPU-only kernel).
# weights_only=False unpickles arbitrary objects: only load checkpoints
# produced by this notebook, never untrusted files.
checkpoint = torch.load('emotion_model_best.pth', map_location=config.DEVICE, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])

print(f"\nüìä √âvaluation finale sur le set de validation:")
val_loss, val_acc = validate(model, val_loader, val_criterion, config.DEVICE, per_class=True)

print(f"\nüéØ R√©sultats finaux:")
print(f"   - Pr√©cision globale: {val_acc:.2f}%")
print(f"   - Loss: {val_loss:.4f}")

## 15. 💾 Sauvegarde du Modèle Final

In [None]:
# Save the final model: weights plus metadata (num_classes, in_channels,
# input_size, dataset, best_val_acc); no optimizer state.
# If the evaluation cell ran first, `model` holds the best-checkpoint weights.
torch.save({
    'model_state_dict': model.state_dict(),
    'num_classes': config.NUM_CLASSES,
    'in_channels': config.IN_CHANNELS,
    'input_size': config.INPUT_SIZE,
    'dataset': 'affectnet',
    'best_val_acc': best_val_acc,
}, 'emotion_model.pth')

# Report the saved file size in MiB.
print("‚úÖ Mod√®le sauvegard√© dans 'emotion_model.pth'")
print(f"   Taille: {os.path.getsize('emotion_model.pth') / 1024 / 1024:.2f} MB")

## 16. 🧪 Test sur Quelques Images

In [None]:
def predict_emotion(model, image_tensor, device):
    """Classify a single image.

    Adds a batch dimension and runs ``model`` in eval mode without gradients.
    Returns ``(predicted_index, confidence, probabilities)``, where
    ``probabilities`` is the softmax over classes as a NumPy array and
    ``confidence`` is its entry at the predicted index.
    """
    model.eval()
    with torch.no_grad():
        batch = image_tensor.unsqueeze(0).to(device)
        logits = model(batch)
        class_probs = F.softmax(logits, dim=1)[0]
        best = logits.argmax(1).item()
        return best, class_probs[best].item(), class_probs.cpu().numpy()

# Visual check on random validation images: the title is green when the
# prediction matches the label and red otherwise.
n_show = min(8, len(val_dataset))  # random.sample raises if asked for more items than exist
fig, axes = plt.subplots(2, 4, figsize=(14, 7))
axes = axes.flatten()

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
# Class names come from the dataset's own index mapping so they cannot drift from the labels.
emotions = [BalancedAffectNetDataset.IDX_TO_EMOTION[i] for i in range(BalancedAffectNetDataset.NUM_CLASSES)]

indices = random.sample(range(len(val_dataset)), n_show)

for i, idx in enumerate(indices):
    img, true_label = val_dataset[idx]
    pred_idx, confidence, probs = predict_emotion(model, img, config.DEVICE)

    # Undo the normalisation for display.
    img_np = img.numpy().transpose(1, 2, 0)
    img_np = img_np * std + mean
    img_np = np.clip(img_np, 0, 1)

    true_emotion = emotions[true_label]
    pred_emotion = emotions[pred_idx]

    color = 'green' if pred_idx == true_label else 'red'

    axes[i].imshow(img_np)
    axes[i].set_title(f"Vrai: {true_emotion}\nPr√©d: {pred_emotion} ({confidence*100:.1f}%)",
                      color=color, fontsize=10)
    axes[i].axis('off')

# Hide panels left empty when the validation set has fewer than 8 images.
for ax in axes[n_show:]:
    ax.axis('off')

plt.suptitle('üîç Pr√©dictions sur le Set de Validation', fontsize=14)
plt.tight_layout()
plt.show()