In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
import time

# Configuration
# Fixed seed so that dataset subsampling, the train/val/test split, weight
# initialisation and batch shuffling are repeatable under "Restart & Run All".
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🔥 Device: {device}")

# Training hyperparameters
BATCH_SIZE = 64
LEARNING_RATE = 0.001
EPOCHS = 10

# Classes and their display emojis (position i in one list matches position i in the other)
CLASSES = ['cat', 'dog', 'house', 'car', 'tree']
CLASS_EMOJIS = ['🐱', '🐶', '🏠', '🚗', '🌳']
assert len(CLASS_EMOJIS) == len(CLASSES), "CLASS_EMOJIS must have one entry per class"

# Derived from CLASSES so the two can never drift apart
NUM_CLASSES = len(CLASSES)

print(f"⚙️ Configuration:")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Learning rate: {LEARNING_RATE}")
print(f"   Epochs: {EPOCHS}")
print(f"   Classes: {CLASSES}")
print(f"   Seed: {SEED}")
print("✅ Imports terminés!")


In [None]:
# Architecture CNN
class DrawingCNN(nn.Module):
    """Small CNN that classifies 28x28 single-channel drawings.

    Three conv + batch-norm + ReLU blocks (the first two followed by 2x2
    max-pooling, 28x28 -> 14x14 -> 7x7) feed a two-layer fully connected head
    with dropout.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=5):
        super().__init__()

        # Conv block 1: 28x28 -> 14x14 after pooling
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)

        # Conv block 2: 14x14 -> 7x7 after pooling
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)

        # Conv block 3: no pooling, spatial size stays 7x7
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        # Fully connected head (input size fixed by the 128 x 7 x 7 feature map)
        self.fc1 = nn.Linear(128 * 7 * 7, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)

        # Weight initialisation
        self._init_weights()

    def _init_weights(self):
        """Initialise conv, batch-norm and linear layers in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Kaiming init for ReLU activations; conv biases keep PyTorch's default init
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Compute raw class logits.

        Args:
            x: Float tensor of shape ``(batch, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding unnormalised logits.
        """
        # Block 1
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))

        # Block 2
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))

        # Block 3
        x = F.relu(self.bn3(self.conv3(x)))

        # Flatten everything but the batch dimension; unlike ``view`` this also
        # works when the activation tensor is not contiguous.
        x = torch.flatten(x, 1)

        # Fully connected head
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

    def predict_proba(self, x):
        """Return softmax class probabilities for ``x``.

        The module is switched to eval mode for the duration of the call so
        that dropout is disabled and batch-norm uses (and does not update) its
        running statistics; the previous train/eval mode is restored afterwards.
        No gradients are tracked.

        Args:
            x: Float tensor of shape ``(batch, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` whose rows sum to 1.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self(x)
                return F.softmax(logits, dim=1)
        finally:
            # Leave the module in the mode the caller had it in
            self.train(was_training)

# Build the network on the selected device and count its trainable parameters
model = DrawingCNN(num_classes=NUM_CLASSES)
model = model.to(device)
params = sum([weights.numel() for weights in model.parameters() if weights.requires_grad])

# Summary banner; the size estimate counts 4 bytes per parameter
print("🧠 ARCHITECTURE CNN")
print("=" * 30)
print(f"🔢 Paramètres: {params:,}")
print(f"💾 Taille estimée: {params * 4 / 1024 / 1024:.1f} MB")

# Smoke test: push a single random 1x28x28 sample through the network
# without tracking gradients and report the input/output shapes
test_input = torch.randn(1, 1, 28, 28).to(device)
with torch.no_grad():
    output = model(test_input)
print(f"📊 Test: {test_input.shape} -> {output.shape}")
print("✅ Architecture validée!")


In [None]:
# Dataset personnalisé
class QuickDrawDataset(Dataset):
    """In-memory dataset of drawings loaded from per-class ``.npy`` files.

    Each class is read from ``data_dir / f"{class_name}.npy"`` and its samples
    are labelled with the position of that class in ``classes``. Items are
    returned as ``(image, label)`` where ``image`` is a float32 tensor of shape
    ``(1, 28, 28)`` and ``label`` is a 0-dim int64 tensor.

    Args:
        data_dir: Directory containing one ``<class_name>.npy`` file per class.
        classes: Sequence of class names; order defines the integer labels.
        max_samples_per_class: Classes with more rows than this are randomly
            subsampled (via ``np.random.permutation``) down to this many.
        train: Stored as ``self.train``; not otherwise used by this class.
        class_emojis: Optional sequence of display prefixes for the progress
            messages, one per class. When omitted, a module-level
            ``CLASS_EMOJIS`` is used if one exists; classes without an entry
            are printed with a neutral bullet.
    """

    def __init__(self, data_dir, classes, max_samples_per_class=50000, train=True,
                 class_emojis=None):
        self.data_dir = Path(data_dir)
        self.classes = classes
        self.class_to_idx = {cls: idx for idx, cls in enumerate(classes)}
        self.train = train

        # The emoji table is purely cosmetic: never let a missing or short
        # table prevent the data from loading.
        if class_emojis is None:
            class_emojis = globals().get('CLASS_EMOJIS') or []

        # Load the data
        print("📊 Chargement des données...")
        image_chunks = []
        label_chunks = []

        for class_idx, class_name in enumerate(classes):
            emoji = class_emojis[class_idx] if class_idx < len(class_emojis) else '•'
            print(f"   {emoji} Chargement {class_name}...")

            # Load the .npy file for this class
            filepath = self.data_dir / f"{class_name}.npy"
            data = np.load(filepath)

            # Cap the number of samples with a random subset
            if len(data) > max_samples_per_class:
                indices = np.random.permutation(len(data))[:max_samples_per_class]
                data = data[indices]

            # Keep whole arrays and concatenate once at the end instead of
            # extending a Python list row by row (which creates one small
            # array object per image before converting back).
            image_chunks.append(data)
            label_chunks.append(np.full(len(data), class_idx, dtype=np.int64))

            print(f"     -> {len(data):,} images ajoutées")

        if image_chunks:
            self.images = np.concatenate(image_chunks, axis=0)
            self.labels = np.concatenate(label_chunks)
        else:
            # No classes requested: expose a well-formed empty dataset
            self.images = np.empty((0, 28 * 28), dtype=np.uint8)
            self.labels = np.empty(0, dtype=np.int64)

        print(f"✅ Dataset créé: {len(self.images):,} images total")
        print(f"   Shape: {self.images.shape}")
        # minlength keeps one count per class even when trailing classes are empty
        print(f"   Distribution: {np.bincount(self.labels, minlength=len(classes))}")

    def __len__(self):
        """Return the total number of loaded samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at position ``idx``.

        The stored row is reshaped to 28x28, inverted and rescaled so that a
        raw value of 255 maps to -1.0 and a raw value of 0 maps to +1.0, then
        given a leading channel dimension.
        """
        image = self.images[idx].reshape(28, 28).astype(np.float32)
        label = self.labels[idx]

        # Invert and scale to [0, 1] (for raw values in [0, 255]), then to [-1, 1]
        image = (255 - image) / 255.0
        image = (image - 0.5) / 0.5

        # Add the channel dimension: (1, 28, 28)
        image = image[np.newaxis, ...]

        return torch.from_numpy(image), torch.tensor(int(label), dtype=torch.long)

# Locate the data directory.
# Path('.').parent evaluates to Path('.') itself, so the original
# "project_root" was always the working directory. Keep ./data as the first
# choice (unchanged behaviour), but fall back to ../data, which is where the
# rest of the notebook (../runs, ../models) implies the project root lives.
project_root = Path('.')
data_dir = project_root / "data"
if not data_dir.is_dir():
    project_root = Path('..')
    data_dir = project_root / "data"

# Full dataset
print("🔄 Création du dataset...")
full_dataset = QuickDrawDataset(
    data_dir=data_dir,
    classes=CLASSES,
    max_samples_per_class=80000,  # balanced across classes, capped to keep training fast
    train=True
)

# Train / validation / test split (70% / 15% / remainder)
total_size = len(full_dataset)
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size

# A dedicated, seeded generator keeps the three subsets identical from run to
# run, so the test set stays genuinely held out across notebook restarts.
SPLIT_SEED = 42
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print(f"📊 Splits:")
print(f"   Train: {len(train_dataset):,} images")
print(f"   Validation: {len(val_dataset):,} images") 
print(f"   Test: {len(test_dataset):,} images")

# DataLoaders: only the training set is shuffled
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"📦 DataLoaders créés:")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")
print(f"   Test batches: {len(test_loader)}")
print("✅ Données prêtes!")


In [None]:
# TensorBoard run directory, named after the current wall-clock time
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
writer = SummaryWriter(f'../runs/drawing_cnn_{timestamp}')

# Fresh model instance so training starts from newly initialised weights
model = DrawingCNN(num_classes=NUM_CLASSES)
model = model.to(device)

# Loss, optimiser and learning-rate schedule
# (the learning rate is multiplied by 0.7 every 3 scheduler steps)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=1e-4,
)
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.7,
)

# Recap of the training configuration
print(f"⚙️ Configuration entraînement:")
print(f"   TensorBoard: runs/drawing_cnn_{timestamp}")
print(f"   Optimiseur: Adam (lr={LEARNING_RATE})")
print(f"   Scheduler: StepLR (step=3, gamma=0.7)")
print(f"   Loss: CrossEntropyLoss")

# Fonctions d'entraînement et validation
def train_epoch(model, loader, optimizer, criterion, epoch, writer):
    """Train ``model`` for one pass over ``loader`` and log the metrics.

    Args:
        model: Network to optimise; it is put in training mode.
        loader: Sized iterable yielding ``(data, target)`` batches.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss callable applied to ``(output, target)``.
        epoch: Zero-based epoch index, used as the TensorBoard step.
        writer: Object exposing ``add_scalar(tag, value, step)``.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the mean of the per-batch losses and
        the accuracy in percent over all samples seen. ``(0.0, 0.0)`` is
        returned, and nothing is logged, when the loader yields no samples.
    """
    model.train()

    # Send batches to wherever the model's weights live instead of relying on
    # a notebook-level ``device`` global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    num_batches = len(loader)
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(run_device), target.to(run_device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Running statistics (read the loss value once per batch)
        batch_loss = loss.item()
        running_loss += batch_loss
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)

        # Log to TensorBoard and the console every 100 batches
        if batch_idx % 100 == 0:
            running_acc = 100. * correct / max(total, 1)
            step = epoch * num_batches + batch_idx
            writer.add_scalar('Train/Loss_Batch', batch_loss, step)
            writer.add_scalar('Train/Accuracy_Batch', running_acc, step)

            print(f'   Batch {batch_idx:3d}/{num_batches} | '
                  f'Loss: {batch_loss:.4f} | '
                  f'Acc: {running_acc:.2f}%')

    # Nothing was processed: avoid dividing by zero
    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    epoch_loss = running_loss / num_batches
    epoch_acc = 100. * correct / total

    # Per-epoch logging
    writer.add_scalar('Train/Loss_Epoch', epoch_loss, epoch)
    writer.add_scalar('Train/Accuracy_Epoch', epoch_acc, epoch)

    return epoch_loss, epoch_acc

def validate_epoch(model, loader, criterion, epoch, writer):
    """Evaluate ``model`` on ``loader`` without updating it and log the metrics.

    Args:
        model: Network to evaluate; it is put (and left) in eval mode.
        loader: Sized iterable yielding ``(data, target)`` batches.
        criterion: Loss callable applied to ``(output, target)``.
        epoch: Zero-based epoch index, used as the TensorBoard step.
        writer: Object exposing ``add_scalar(tag, value, step)``.

    Returns:
        Tuple ``(val_loss, val_acc)``: the mean of the per-batch losses and the
        accuracy in percent. ``(0.0, 0.0)`` is returned, and nothing is logged,
        when the loader yields no samples.
    """
    model.eval()

    # Send batches to wherever the model's weights live instead of relying on
    # a notebook-level ``device`` global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    num_batches = len(loader)
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(run_device), target.to(run_device)
            output = model(data)
            val_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

    # Nothing was evaluated: avoid dividing by zero
    if num_batches == 0 or total == 0:
        return 0.0, 0.0

    val_loss /= num_batches
    val_acc = 100. * correct / total

    # Validation logging
    writer.add_scalar('Val/Loss', val_loss, epoch)
    writer.add_scalar('Val/Accuracy', val_acc, epoch)

    print(f'   Validation | Loss: {val_loss:.4f} | Acc: {val_acc:.2f}%')

    return val_loss, val_acc

# Confirm that the training and validation helpers above are now defined
print("✅ Fonctions d'entraînement prêtes!")


In [None]:
# 🚀 MAIN TRAINING LOOP
print("🔥 DÉBUT DE L'ENTRAÎNEMENT")
print("=" * 50)

# Create the checkpoint directory up front: torch.save does not create missing
# folders, and discovering that only after a full epoch wastes the epoch.
checkpoint_path = Path(f'../models/best_model_{timestamp}.pth')
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

# Add the model graph to TensorBoard
sample_input = torch.randn(1, 1, 28, 28).to(device)
writer.add_graph(model, sample_input)

# Metric history, one entry per completed epoch
train_losses, train_accs = [], []
val_losses, val_accs = [], []
best_val_acc = 0.0

start_time = time.time()

for epoch in range(EPOCHS):
    epoch_start = time.time()
    
    print(f"\n📅 ÉPOQUE {epoch+1}/{EPOCHS}")
    print("-" * 30)
    
    # Training
    print("🏋️ Entraînement...")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, epoch, writer)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    
    # Validation
    print("🔍 Validation...")
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, epoch, writer)
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    
    # Record the learning rate that was actually used during this epoch
    # *before* stepping the scheduler; reading it afterwards would report the
    # next epoch's rate under this epoch's index.
    current_lr = optimizer.param_groups[0]['lr']
    writer.add_scalar('Train/Learning_Rate', current_lr, epoch)
    scheduler.step()
    
    # Checkpoint whenever validation accuracy improves
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'train_acc': train_acc,
        }, checkpoint_path)
        print(f"   💾 Nouveau meilleur modèle sauvé! (Val Acc: {val_acc:.2f}%)")
    
    epoch_time = time.time() - epoch_start
    print(f"   ⏱️ Temps époque: {epoch_time:.1f}s | LR: {current_lr:.6f}")
    print(f"   📊 Train: {train_acc:.2f}% | Val: {val_acc:.2f}% | Best: {best_val_acc:.2f}%")

total_time = time.time() - start_time

print(f"\n🎉 ENTRAÎNEMENT TERMINÉ!")
print("=" * 50)
print(f"⏱️ Temps total: {total_time/60:.1f} minutes")
print(f"🏆 Meilleure accuracy validation: {best_val_acc:.2f}%")
# The history is empty when EPOCHS == 0; skip the "final" lines in that case
if train_accs:
    print(f"📈 Accuracy finale train: {train_accs[-1]:.2f}%")
    print(f"📉 Loss finale train: {train_losses[-1]:.4f}")


In [None]:
# 📊 FINAL EVALUATION AND VISUALISATIONS

# Evaluate the checkpoint that was actually saved (best validation accuracy)
# rather than whatever weights the last epoch happened to leave in memory, so
# the reported test accuracy describes the model file announced below.
# Note: this replaces the weights held by `model`.
best_model_path = Path(f'../models/best_model_{timestamp}.pth')
if best_model_path.is_file():
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"📥 Meilleur modèle rechargé (époque {checkpoint['epoch'] + 1})")

# Learning curves
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

# Use the number of epochs actually recorded, so an interrupted or shortened
# run still plots instead of failing on a length mismatch.
epochs_range = range(1, len(train_losses) + 1)

# Loss
ax1.plot(epochs_range, train_losses, 'b-', label='Train', linewidth=2)
ax1.plot(epochs_range, val_losses, 'r-', label='Validation', linewidth=2)
ax1.set_title('Loss par époque', fontsize=14, fontweight='bold')
ax1.set_xlabel('Époque')
ax1.set_ylabel('Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy
ax2.plot(epochs_range, train_accs, 'b-', label='Train', linewidth=2)
ax2.plot(epochs_range, val_accs, 'r-', label='Validation', linewidth=2)
ax2.set_title('Accuracy par époque', fontsize=14, fontweight='bold')
ax2.set_xlabel('Époque')
ax2.set_ylabel('Accuracy (%)')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Collect predictions and ground-truth labels over the whole test set
pred_batches = []
true_batches = []

model.eval()
with torch.no_grad():
    for data, target in test_loader:
        output = model(data.to(device))
        pred_batches.append(output.argmax(dim=1).cpu())
        true_batches.append(target.cpu())

if pred_batches:
    test_predictions = torch.cat(pred_batches).numpy()
    test_true_labels = torch.cat(true_batches).numpy()
else:
    # Empty test set: keep integer arrays so the counting below still works
    test_predictions = np.empty(0, dtype=np.int64)
    test_true_labels = np.empty(0, dtype=np.int64)

# Predicted vs. true class counts (a per-class histogram, not a confusion matrix)
num_classes = len(CLASSES)
pred_counts = np.bincount(test_predictions, minlength=num_classes)
true_counts = np.bincount(test_true_labels, minlength=num_classes)

class_names_with_emoji = [f"{CLASS_EMOJIS[i]} {cls}" for i, cls in enumerate(CLASSES)]

# Draw the two series side by side; at the same x position they hide each other
x_pos = np.arange(num_classes)
bar_width = 0.4
ax3.bar(x_pos - bar_width / 2, [pred_counts[i] for i in range(num_classes)],
        width=bar_width, alpha=0.7, label='Prédictions')
ax3.bar(x_pos + bar_width / 2, [true_counts[i] for i in range(num_classes)],
        width=bar_width, alpha=0.7, label='Vraies classes')
ax3.set_title('Distribution Test Set', fontsize=14, fontweight='bold')
ax3.set_xlabel('Classes')
ax3.set_ylabel('Nombre d\'échantillons')
ax3.set_xticks(range(num_classes))
ax3.set_xticklabels(class_names_with_emoji, rotation=45)
ax3.legend()
ax3.grid(True, alpha=0.3)

# Per-class accuracy: correct predictions divided by samples of that true class
is_correct = test_true_labels == test_predictions
total_per_class = true_counts
correct_per_class = np.bincount(test_true_labels[is_correct], minlength=num_classes)

class_accuracies = [float(100 * correct_per_class[i] / max(int(total_per_class[i]), 1))
                    for i in range(num_classes)]

bars = ax4.bar(range(num_classes), class_accuracies, 
               color=['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7'])
ax4.set_title('Accuracy par classe', fontsize=14, fontweight='bold')
ax4.set_xlabel('Classes')
ax4.set_ylabel('Accuracy (%)')
ax4.set_xticks(range(num_classes))
ax4.set_xticklabels(class_names_with_emoji, rotation=45)
ax4.set_ylim(0, 110)  # headroom so the value labels above 100% bars stay visible
ax4.grid(True, alpha=0.3)

# Write the value on top of each bar
for bar, acc in zip(bars, class_accuracies):
    height = bar.get_height()
    ax4.text(bar.get_x() + bar.get_width()/2., height + 1,
             f'{acc:.1f}%', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()

# Log the figure to TensorBoard before showing it; close=False keeps the
# figure alive so plt.show() can still display it.
writer.add_figure('Results/Training_Summary', fig, close=False)
plt.show()

# Overall test accuracy
test_correct = int(is_correct.sum())
test_total = len(test_true_labels)
test_accuracy = 100. * test_correct / test_total if test_total else 0.0

print("🎯 RÉSULTATS FINAUX")
print("=" * 40)
print(f"📊 Test Accuracy: {test_accuracy:.2f}%")
print(f"🏆 Meilleure Val Accuracy: {best_val_acc:.2f}%")
print(f"⏱️ Temps d'entraînement: {total_time/60:.1f} min")
print(f"💾 Modèle sauvé: models/best_model_{timestamp}.pth")

# Detailed per-class accuracy
print("\n📈 ACCURACY PAR CLASSE:")
for i, (cls, acc) in enumerate(zip(CLASSES, class_accuracies)):
    print(f"   {CLASS_EMOJIS[i]} {cls.capitalize():8s}: {acc:5.1f}%")

# Compare against the 85% test-accuracy target
if test_accuracy >= 85.0:
    print(f"\n🎉 OBJECTIF ATTEINT! Test accuracy {test_accuracy:.1f}% >= 85%")
    status = "SUCCESS ✅"
else:
    print(f"\n⚠️ Objectif manqué. Test accuracy {test_accuracy:.1f}% < 85%")
    status = "NEEDS_IMPROVEMENT ⚠️"

print(f"\n🏁 STATUS: {status}")

# Close the TensorBoard writer
writer.close()
print(f"\n🌐 TensorBoard: tensorboard --logdir=../runs")
print(f"📁 Run: drawing_cnn_{timestamp}")
print("✅ Session terminée!")
