In [None]:
import pickle
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from PIL import Image
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, balanced_accuracy_score, classification_report
from collections import defaultdict

# ==========================================
# 1. CONFIGURATION GLOBALE
# ==========================================
# Reproducibility and cross-validation layout.
SEED = 42
N_SPLITS = 5

# Optimisation hyper-parameters.
BATCH_SIZE = 64
EPOCHS_PER_FOLD = 50
LEARNING_RATE = 1e-3  # lowered for Adam
WEIGHT_DECAY = 1e-4  # less aggressive L2 penalty
EARLY_STOP_PATIENCE = 15  # epochs without improvement before stopping a fold

# Backend preference: Apple MPS first, then CUDA, then CPU.
if torch.backends.mps.is_available():
    _device_name = "mps"
elif torch.cuda.is_available():
    _device_name = "cuda"
else:
    _device_name = "cpu"
DEVICE = torch.device(_device_name)

def set_seed(seed=42):
    """Seed every random number generator the pipeline touches.

    Seeds Python's ``random`` module, NumPy's global generator and torch's
    default generator, then the CUDA and MPS generators when those backends
    are available.

    Args:
        seed: integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)

# Seed all generators once at start-up, then report which backend was picked.
set_seed(seed=SEED)
print(f"Device utilisé : {DEVICE}\n")

# ==========================================
# 2. PRÉ-TRAITEMENT & DATASET
# ==========================================
def enhance_channels(img_tensor):
    """Replace channel 1 with a red-minus-blue contrast map.

    Channels are indexed on dim 0, as produced by ``ToTensor`` in this file's
    transform pipelines. Channel 0 (red) and channel 2 (blue) are kept; the
    middle channel becomes ``clamp(R - B + 0.5, 0, 1)``, and the original
    channel 1 is discarded.

    Returns:
        A new tensor stacking (red, contrast, blue) along dim 0.
    """
    red, blue = img_tensor[0], img_tensor[2]
    contrast = (red - blue + 0.5).clamp(0, 1)
    return torch.stack((red, contrast, blue))

class RetinaDataset(Dataset):
    """Map-style dataset over in-memory image arrays with optional labels.

    Each image is cast to ``uint8`` and wrapped as a PIL image. The optional
    ``transform`` is then applied; without one, ``ToTensor`` is used.

    Args:
        images: indexable collection of image arrays.
        labels: indexable collection of labels (unused when ``is_test``).
        transform: callable applied to the PIL image.
        is_test: when True, items are the image tensor alone; otherwise
            ``(image_tensor, int_label)``.
    """

    def __init__(self, images, labels=None, transform=None, is_test=False):
        self.images, self.labels = images, labels
        self.transform, self.is_test = transform, is_test

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        pil_image = Image.fromarray(self.images[idx].astype('uint8'))
        to_tensor = self.transform if self.transform else transforms.ToTensor()
        image = to_tensor(pil_image)

        if self.is_test:
            return image

        # NumPy / torch scalars expose .item(); plain numbers go straight to int().
        raw = self.labels[idx]
        label = int(raw.item()) if hasattr(raw, 'item') else int(raw)
        return image, label

# ==========================================
# 3. ARCHITECTURE SIMPLIFIÉE
# ==========================================
class SimpleNet(nn.Module):
    """Compact CNN classifier: three conv stages, global pooling, MLP head.

    Stages 1 and 2 each apply two (3x3 conv -> BatchNorm -> ReLU) layers and
    stage 3 applies one; every stage ends with 2x2 max-pooling. The 128
    feature maps are then global-average-pooled and passed through
    Linear(128, 128) -> ReLU -> Dropout(0.3) -> Linear(128, num_classes).
    ``forward`` returns raw logits. Inputs must have 3 channels.
    """

    def __init__(self, num_classes=5):
        super().__init__()

        # Stage 1: 3 -> 32 channels (28x28 -> 14x14 for 28x28 inputs).
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)

        # Stage 2: 32 -> 64 channels (14x14 -> 7x7).
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)

        # Stage 3: 64 -> 128 channels (7x7 -> 3x3).
        self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(128)

        # Collapse each feature map to a single value.
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)

        # Classification head (dropout reduced to 0.3).
        self.fc1 = nn.Linear(in_features=128, out_features=128)
        self.dropout = nn.Dropout(p=0.3)
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init conv/linear weights; reset BN affine params and linear biases.

        Conv biases keep PyTorch's default initialisation.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
                init.ones_(module.weight)
                init.zeros_(module.bias)
                continue
            if isinstance(module, nn.Linear):
                init.kaiming_normal_(module.weight)
                init.zeros_(module.bias)

    def forward(self, x):
        stages = (
            ((self.conv1, self.bn1), (self.conv2, self.bn2)),
            ((self.conv3, self.bn3), (self.conv4, self.bn4)),
            ((self.conv5, self.bn5),),
        )
        for stage in stages:
            for conv, bn in stage:
                x = F.relu(bn(conv(x)))
            x = F.max_pool2d(x, 2)

        # (N, 128, 1, 1) -> (N, 128)
        features = self.gap(x).flatten(1)
        hidden = self.dropout(F.relu(self.fc1(features)))
        return self.fc2(hidden)

# ==========================================
# 4. PRÉPARATION DES DONNÉES
# ==========================================
print("Chargement des données...")
TRAIN_PATH = "ift-3395-6390-kaggle-2-competition-fall-2025/train_data.pkl"
TEST_PATH = "ift-3395-6390-kaggle-2-competition-fall-2025/test_data.pkl"

# NOTE(security): pickle.load can execute arbitrary code; only open trusted files.
with open(TRAIN_PATH, "rb") as f:
    train_data_raw = pickle.load(f)

X_all = train_data_raw["images"].astype(np.float32)
y_all = train_data_raw["labels"].reshape(-1)

# Per-channel statistics for transforms.Normalize.
# Normalize runs AFTER enhance_channels, which keeps channel 0 (R) and
# channel 2 (B) but replaces channel 1 with clamp(R - B + 0.5, 0, 1).
# The statistics must therefore describe those transformed channels.
# Computing them on raw pixels normalised channel 1 with the raw channel's
# tiny std (0.017 in the captured run), pushing it to roughly +28.
# Assumes images are laid out (N, H, W, C), as implied by reducing over axes
# (0, 1, 2).
X_tmp = X_all / 255.0
red, blue = X_tmp[..., 0], X_tmp[..., 2]
X_enh = np.stack([red, np.clip(red - blue + 0.5, 0.0, 1.0), blue], axis=-1)
IR_MEAN = X_enh.mean(axis=(0, 1, 2), dtype=np.float64).tolist()
IR_STD = X_enh.std(axis=(0, 1, 2), dtype=np.float64).tolist()
print(f"Stats -> Mean: {[f'{m:.3f}' for m in IR_MEAN]}, Std: {[f'{s:.3f}' for s in IR_STD]}")

unique, counts = np.unique(y_all, return_counts=True)
# .tolist() yields plain Python ints, so this prints {0: 486, ...} rather than NumPy scalar reprs.
print(f"Distribution: {dict(zip(unique.tolist(), counts.tolist()))}\n")

# Moderate augmentations: flips plus a small rotation.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomRotation(15),  # reduced from 90 degrees
    transforms.ToTensor(),
    transforms.Lambda(enhance_channels),
    transforms.Normalize(IR_MEAN, IR_STD)
])

# Deterministic pipeline for validation / test: same channel handling, no augmentation.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(enhance_channels),
    transforms.Normalize(IR_MEAN, IR_STD)
])

# ==========================================
# 5. BOUCLE CROSS-VALIDATION
# ==========================================
from sklearn.utils.class_weight import compute_class_weight


def _predict(model, loader, criterion=None):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Args:
        model: network to evaluate (switched to eval mode).
        loader: DataLoader yielding ``(images, labels)`` batches.
        criterion: optional loss; when given, its per-batch values are averaged.

    Returns:
        ``(preds, targets, mean_loss)``: predicted and true labels as lists of
        NumPy scalars, and the mean per-batch loss, or ``None`` when no
        criterion was given.
    """
    model.eval()
    preds, targets = [], []
    total_loss = 0.0
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(DEVICE), lbls.to(DEVICE)
            outputs = model(imgs)
            if criterion is not None:
                total_loss += criterion(outputs, lbls).item()
            _, batch_preds = torch.max(outputs, 1)
            preds.extend(batch_preds.cpu().numpy())
            targets.extend(lbls.cpu().numpy())
    mean_loss = total_loss / len(loader) if criterion is not None else None
    return preds, targets, mean_loss


skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
fold_scores = []   # best validation balanced accuracy of each fold
fold_reports = []  # classification_report dict of each fold's best checkpoint
model_paths = []   # checkpoint path of each fold's best model

for fold, (train_idx, val_idx) in enumerate(skf.split(X_all, y_all)):
    print(f"{'='*60}")
    print(f"FOLD {fold+1}/{N_SPLITS}")
    print(f"{'='*60}")

    X_train, y_train = X_all[train_idx], y_all[train_idx]
    X_val, y_val = X_all[val_idx], y_all[val_idx]

    print(f"Train: {len(X_train)} | Val: {len(X_val)}")

    train_ds = RetinaDataset(X_train, y_train, transform=train_transform)
    val_ds = RetinaDataset(X_val, y_val, transform=val_transform)

    # Plain DataLoaders: class imbalance is handled by the loss weights below,
    # not by a weighted sampler.
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    model = SimpleNet(num_classes=5).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # Halve the LR when validation balanced accuracy plateaus (mode='max').
    # `verbose` is omitted: False is already its default, and the argument
    # is deprecated in recent PyTorch releases.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=7, min_lr=1e-6
    )

    # Class-weighted cross-entropy using sklearn's 'balanced' (inverse-frequency) weights.
    loss_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
    criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(loss_weights).to(DEVICE))

    # -inf guarantees a checkpoint is written on the first epoch, so the reload
    # after training never fails or silently picks up a stale file left by an
    # earlier run.
    best_bal_acc = float('-inf')
    best_model_name = f"model_fold_{fold}.pth"
    patience_counter = 0

    for epoch in range(EPOCHS_PER_FOLD):
        # --- Training ---
        model.train()
        train_loss, train_correct, train_total = 0, 0, 0
        for imgs, lbls in train_loader:
            imgs, lbls = imgs.to(DEVICE), lbls.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, lbls)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            train_correct += (preds == lbls).sum().item()
            train_total += lbls.size(0)

        # --- Validation ---
        all_preds, all_targets, avg_val = _predict(model, val_loader, criterion)

        avg_train = train_loss / len(train_loader)
        train_acc = train_correct / train_total
        val_bal_acc = balanced_accuracy_score(all_targets, all_preds)
        val_acc = accuracy_score(all_targets, all_preds)

        if (epoch+1) % 5 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:02d} | TrLoss: {avg_train:.3f} TrainAcc: {train_acc:.3f} | "
                  f"ValLoss: {avg_val:.3f} ValAcc: {val_acc:.3f} BalAcc: {val_bal_acc:.4f}")

        scheduler.step(val_bal_acc)

        # Keep the checkpoint with the best validation balanced accuracy.
        if val_bal_acc > best_bal_acc:
            best_bal_acc = val_bal_acc
            patience_counter = 0
            torch.save(model.state_dict(), best_model_name)

            if epoch >= 10:
                report = classification_report(all_targets, all_preds, output_dict=True, zero_division=0)
                print(f"  ✓ Meilleur | BalAcc: {best_bal_acc:.4f} | "
                      f"C4 Recall: {report.get('4', {}).get('recall', 0):.2f}")
        else:
            patience_counter += 1
            if patience_counter >= EARLY_STOP_PATIENCE:
                print(f"  Early stop @ epoch {epoch+1}")
                break

    # Final evaluation of the best checkpoint on this fold's validation split.
    model.load_state_dict(torch.load(best_model_name, map_location=DEVICE))
    all_preds, all_targets, _ = _predict(model, val_loader)

    final_report = classification_report(all_targets, all_preds, output_dict=True, zero_division=0)

    print(f"\n{'─'*60}")
    print(f"FOLD {fold+1} FINAL")
    print(f"{'─'*60}")
    print(f"Balanced Acc: {best_bal_acc:.4f} | Accuracy: {accuracy_score(all_targets, all_preds):.4f}")
    print("Per-Class Recall:")
    for cls in range(5):
        recall = final_report.get(str(cls), {}).get('recall', 0)
        support = final_report.get(str(cls), {}).get('support', 0)
        print(f"  Class {cls}: {recall:.3f} (n={int(support)})")
    print()

    fold_scores.append(best_bal_acc)
    fold_reports.append(final_report)
    model_paths.append(best_model_name)

    # Release the fold's model before the next one is built.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()

# ==========================================
# RÉSUMÉ CV
# ==========================================
_rule = '=' * 60
print(f"{_rule}")
print("RÉSUMÉ CROSS-VALIDATION")
print(f"{_rule}")
print(f"Balanced Acc: {np.mean(fold_scores):.4f} ± {np.std(fold_scores):.4f}")
print(f"Scores: {[f'{s:.4f}' for s in fold_scores]}\n")

# Gather each class's recall across folds (a class absent from a report counts as 0).
avg_recalls = defaultdict(list)
for cls in range(5):
    avg_recalls[cls].extend(
        rep.get(str(cls), {}).get('recall', 0) for rep in fold_reports
    )

print("Recall moyen par classe:")
for cls in range(5):
    per_fold = avg_recalls[cls]
    print(f"  Classe {cls}: {np.mean(per_fold):.3f} ± {np.std(per_fold):.3f}")

# ==========================================
# 6. INFERENCE
# ==========================================
print(f"\n{'='*60}")
print("GÉNÉRATION SOUMISSION")
print(f"{'='*60}")

# Without checkpoints the ensemble below would divide by zero and write a
# NaN-derived submission; fail loudly instead.
if not model_paths:
    raise RuntimeError("No fold checkpoints available: run the cross-validation loop first.")

# NOTE(security): pickle.load can execute arbitrary code; only open trusted files.
with open(TEST_PATH, "rb") as f:
    test_data = pickle.load(f)

test_ds = RetinaDataset(test_data['images'], labels=None, transform=val_transform, is_test=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# One model per fold, restored from its best checkpoint.
models = []
for path in model_paths:
    m = SimpleNet(num_classes=5).to(DEVICE)
    m.load_state_dict(torch.load(path, map_location=DEVICE))
    m.eval()
    models.append(m)

final_preds = []
with torch.no_grad():
    for imgs in test_loader:
        imgs = imgs.to(DEVICE)
        # Allocate directly on the target device instead of CPU + copy.
        ensemble_logits = torch.zeros(imgs.size(0), 5, device=DEVICE)

        for m in models:
            # TTA: original + horizontal flip (dim 3 is width for NCHW batches).
            out_norm = m(imgs)
            out_h = m(torch.flip(imgs, [3]))
            ensemble_logits += (out_norm + out_h) / 2.0

        # Average the per-fold logits, then take the arg-max class.
        ensemble_logits /= len(models)
        _, predicted = torch.max(ensemble_logits, 1)
        final_preds.extend(predicted.cpu().numpy())

df = pd.DataFrame({"ID": np.arange(1, len(final_preds) + 1), "Label": final_preds})
df.to_csv("submission_ensemble_kfold.csv", index=False)
print("✓ Fichier 'submission_ensemble_kfold.csv' généré!")

Device utilisé : mps
Chargement des données...
Stats -> Mean: ['0.211', '0.005', '0.229'], Std: ['0.189', '0.017', '0.170']
Distribution: {np.uint8(0): np.int64(486), np.uint8(1): np.int64(128), np.uint8(2): np.int64(206), np.uint8(3): np.int64(194), np.uint8(4): np.int64(66)}

FOLD 1/5
Train: 864 | Val: 216
Train dist: {np.uint8(0): np.int64(388), np.uint8(1): np.int64(103), np.uint8(2): np.int64(165), np.uint8(3): np.int64(155), np.uint8(4): np.int64(53)}
Epoch 01 | Train: 1.713 | Val: 23.382 | Acc: 0.454 | Bal Acc: 0.2000
Epoch 05 | Train: 1.536 | Val: 1.673 | Acc: 0.171 | Bal Acc: 0.2476
Epoch 10 | Train: 1.515 | Val: 1.718 | Acc: 0.282 | Bal Acc: 0.2746
✓ Sauvegarde | Bal Acc: 0.3616 | Class 4 Recall: 0.92
Epoch 15 | Train: 1.511 | Val: 1.627 | Acc: 0.250 | Bal Acc: 0.3643
✓ Sauvegarde | Bal Acc: 0.3643 | Class 4 Recall: 0.85
✓ Sauvegarde | Bal Acc: 0.3696 | Class 4 Recall: 0.92
Epoch 20 | Train: 1.507 | Val: 1.639 | Acc: 0.208 | Bal Acc: 0.3175
✓ Sauvegarde | Bal Acc: 0.3798 | Cl