# 🎨 AI Pictionary - CNN Model Training (Colab Version)

**Quick Draw Dataset Classification with TensorFlow/Keras**

Ce notebook est adapté pour tourner sur **Google Colab**. Il télécharge automatiquement les données brutes depuis les serveurs de Google, les traite, entraîne le modèle et vous permet de télécharger le fichier `.h5` final.

---

In [None]:
# @title ⚙️ 1. Configuration de l'environnement Colab
import os

# Create the folder layout the later cells write into (plots and checkpoints
# under logs/, exported model under backend/models/, cached downloads under data/).
for required_dir in ("logs", "backend/models", "data"):
    # exist_ok=True makes the cell safe to re-run
    os.makedirs(required_dir, exist_ok=True)

print("‚úÖ Dossiers cr√©√©s : logs/, backend/models/, data/")

In [None]:
# @title 2️⃣ Import des Librairies
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
import os
import urllib.request
from datetime import datetime

# TensorFlow/Keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical

# Seed the NumPy and TensorFlow random generators with the same fixed value
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Report the runtime: TensorFlow version and whether at least one GPU is visible
print(f"TensorFlow version: {tf.__version__}")
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"GPU available: {len(gpu_devices) > 0}")

# Plot styling shared by every figure in the notebook
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

In [None]:
# @title 3️⃣ Téléchargement et Préparation des Données
# This cell replaces the local HDF5 loading path: it downloads the Quick Draw
# .npy bitmap files directly from Google Cloud Storage, caches them under
# data/, and builds the train/validation/test arrays used by the later cells.

# 20 categories for the demo
CATEGORIES = [
    "apple", "banana", "baseball", "book", "bucket",
    "camera", "car", "clock", "cloud", "cup",
    "door", "eye", "face", "fan", "flower",
    "ladder", "lightning", "star", "sword", "tree"
]
NUM_CLASSES = len(CATEGORIES)
MAX_SAMPLES_PER_CLASS = 12000  # Per-class cap so the Colab RAM is not exhausted

print(f"üì• T√©l√©chargement des donn√©es pour {NUM_CLASSES} cat√©gories...")
base_url = "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/"

# Make sure the cache directory exists even if the setup cell was skipped
os.makedirs("data", exist_ok=True)

X_data = []
y_data = []

for idx, category in enumerate(CATEGORIES):
    # Percent-encode spaces for multi-word category names (e.g. "ice cream")
    filename = category.replace(" ", "%20") + ".npy"
    url = base_url + filename
    local_path = f"data/{category}.npy"

    if not os.path.exists(local_path):
        # Download under a temporary name and rename only on success, so an
        # interrupted transfer never leaves a truncated file that a later run
        # would mistake for a valid cached download.
        partial_path = local_path + ".part"
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, local_path)
        except Exception as e:
            # Best effort: report the failure and move on to the next category
            print(f"‚ùå Erreur t√©l√©chargement {category}: {e}")
            if os.path.exists(partial_path):
                os.remove(partial_path)
            continue

    # Memory-map the file and copy only the rows we keep. A plain slice of a
    # fully loaded array is a view that keeps the whole array alive, which
    # would defeat the purpose of MAX_SAMPLES_PER_CLASS.
    data = np.array(np.load(local_path, mmap_mode="r")[:MAX_SAMPLES_PER_CLASS])

    X_data.append(data)
    # Label every sample with the category's index in CATEGORIES
    y_data.append(np.full(data.shape[0], idx))
    print(f"   ‚úÖ {category}: {data.shape[0]} images charg√©es")

# Fail with a clear message instead of np.concatenate's generic error
if not X_data:
    raise RuntimeError(
        "No category could be downloaded or loaded; "
        "check the network connection and re-run this cell."
    )

# Concatenate all categories into a single feature array and label vector
X = np.concatenate(X_data, axis=0)
y = np.concatenate(y_data, axis=0)

# Normalize to [0, 1] and reshape to (N, 28, 28, 1)
print("\nüîÑ Normalisation et Reshape...")
X = X.astype('float32')
X /= 255.0  # in place: avoids allocating a second full-size float32 array
X = X.reshape(-1, 28, 28, 1)

# Split Train/Val/Test (80% / 10% / 10%), stratified on the class label
print("‚úÇÔ∏è  Splitting dataset...")
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp)

IMAGE_SHAPE = (28, 28, 1)

# One-hot encoding; always NUM_CLASSES wide, even if a category was skipped
y_train_cat = to_categorical(y_train, NUM_CLASSES)
y_val_cat = to_categorical(y_val, NUM_CLASSES)
y_test_cat = to_categorical(y_test, NUM_CLASSES)

print(f"\n‚úÖ Dataset pr√™t !")
print(f"   Train: {X_train.shape} - Label shape: {y_train_cat.shape}")
print(f"   Val:   {X_val.shape}")
print(f"   Test:  {X_test.shape}")

In [None]:
# @title 4️⃣ Visualisation des échantillons
# Display one sample per category: the first training image found for each
# class. The grid adapts to the number of categories (5 per row).
n_cols = 5
n_rows = max(1, (len(CATEGORIES) + n_cols - 1) // n_cols)  # ceiling division

# squeeze=False keeps `axes` two-dimensional even when there is a single row
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows), squeeze=False)
fig.suptitle('Sample Images from Each Category', fontsize=16, fontweight='bold')

# Hide every axis up front so unused grid cells stay blank
for ax in axes.flat:
    ax.axis('off')

for i, category in enumerate(CATEGORIES):
    ax = axes[i // n_cols, i % n_cols]
    ax.set_title(category, fontsize=12, fontweight='bold')

    # Indices of the training samples carrying this category's label
    matches = np.flatnonzero(y_train == i)
    if matches.size == 0:
        # The category has no sample in the training set (e.g. its download
        # failed); leave the cell empty instead of raising IndexError.
        continue

    ax.imshow(X_train[matches[0]].squeeze(), cmap='gray')

plt.tight_layout()
plt.show()

In [None]:
# @title 5️⃣ Architecture CNN Simple
# Assemble the layer stack step by step, then wrap it in a Sequential model.
cnn_layers = [layers.Input(shape=IMAGE_SHAPE)]

# Convolution block 1: 32 filters of 3x3, then 2x2 max pooling
cnn_layers += [
    layers.Conv2D(32, kernel_size=(3, 3), activation="relu", name="conv2d_1"),
    layers.MaxPooling2D(pool_size=(2, 2), name="maxpool_1"),
]

# Convolution block 2: 64 filters of 3x3, then 2x2 max pooling
cnn_layers += [
    layers.Conv2D(64, kernel_size=(3, 3), activation="relu", name="conv2d_2"),
    layers.MaxPooling2D(pool_size=(2, 2), name="maxpool_2"),
]

# Classification head: flatten, dropout, softmax over the categories
cnn_layers += [
    layers.Flatten(name="flatten"),
    layers.Dropout(0.5, name="dropout"),
    layers.Dense(NUM_CLASSES, activation="softmax", name="output"),
]

model = keras.Sequential(cnn_layers, name="QuickDraw_SimpleCNN")
model.summary()

# Render the architecture diagram to the logs/ folder (path adapted for Colab)
plot_options = {
    "to_file": "logs/model_architecture.png",
    "show_shapes": True,
    "show_layer_names": True,
    "rankdir": "TB",
    "dpi": 150,
}
keras.utils.plot_model(model, **plot_options)
print("‚úÖ Architecture sauvegard√©e dans logs/model_architecture.png")

In [None]:
# @title 6️⃣ Compilation
# Configure training: Adam optimizer (learning rate 0.001), categorical
# cross-entropy loss to match the one-hot labels, accuracy as the metric.
adam_optimizer = keras.optimizers.Adam(learning_rate=0.001)
model.compile(
    optimizer=adam_optimizer,
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
print("‚úÖ Mod√®le compil√© avec Adam & Categorical Crossentropy")

In [None]:
# @title 7️⃣ Entraînement
BATCH_SIZE = 128
EPOCHS = 15

# Stop when val_loss has not improved for 3 epochs and roll back to the best weights
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
    verbose=1,
)

# Keep on disk only the model with the best validation accuracy (Colab path)
best_checkpoint = keras.callbacks.ModelCheckpoint(
    filepath='logs/best_model.h5',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1,
)

# Halve the learning rate after 2 epochs without val_loss improvement, floor at 1e-5
lr_on_plateau = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=2,
    min_lr=0.00001,
    verbose=1,
)

callbacks = [early_stopping, best_checkpoint, lr_on_plateau]

print(f"üöÄ D√©marrage de l'entra√Ænement sur {len(X_train):,} images...")
fit_options = {
    "batch_size": BATCH_SIZE,
    "epochs": EPOCHS,
    "validation_data": (X_val, y_val_cat),
    "callbacks": callbacks,
    "verbose": 1,
}
history = model.fit(X_train, y_train_cat, **fit_options)

In [None]:
# @title 8️⃣ Historique d'entraînement
# Two side-by-side panels: accuracy on the left, loss on the right,
# each showing the training curve followed by the validation curve.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (axis, train key, validation key, train label, validation label, title)
panels = [
    (axes[0], 'accuracy', 'val_accuracy', 'Train Acc', 'Val Acc', 'Model Accuracy'),
    (axes[1], 'loss', 'val_loss', 'Train Loss', 'Val Loss', 'Model Loss'),
]

for panel, train_key, val_key, train_label, val_label, panel_title in panels:
    panel.plot(history.history[train_key], label=train_label)
    panel.plot(history.history[val_key], label=val_label)
    panel.set_title(panel_title)
    panel.legend()

# Save before show() so the written file contains the rendered figure
plt.savefig('logs/training_history.png')
plt.show()

In [None]:
# @title 9️⃣ Évaluation sur le Test Set
# Evaluate on the held-out test split; the result unpacks as the loss followed
# by the single compiled metric (accuracy).
evaluation = model.evaluate(X_test, y_test_cat, verbose=0)
test_loss, test_accuracy = evaluation

print(f"\nüéØ Test Accuracy: {test_accuracy*100:.2f}%")
print(f"üéØ Test Loss:     {test_loss:.4f}")

In [None]:
# @title üîü Matrice de Confusion
# Predict class probabilities on the test set and keep the most likely class
y_pred_probs = model.predict(X_test, verbose=0)
y_pred = np.argmax(y_pred_probs, axis=1)

# Pin the label set to every class index so the matrix is always
# NUM_CLASSES x NUM_CLASSES. Without `labels`, only the labels that occur in
# y_test or y_pred are kept, so a category skipped at download time would
# shrink the matrix and shift it against the CATEGORIES tick labels.
cm = confusion_matrix(y_test, y_pred, labels=np.arange(NUM_CLASSES))

# Heatmap with raw counts; rows are true labels, columns are predictions
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=CATEGORIES, yticklabels=CATEGORIES)
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.savefig('logs/confusion_matrix.png')
plt.show()

In [None]:
# @title 1️⃣1️⃣ Sauvegarde et Métadonnées
import json

MODEL_VERSION = "v1.0.0"
# Both artifacts share the same versioned stem (paths adapted for Colab)
model_stem = f"backend/models/quickdraw_{MODEL_VERSION}"
MODEL_SAVE_PATH = f"{model_stem}.h5"
METADATA_PATH = f"{model_stem}_metadata.json"

# Persist the trained model
model.save(MODEL_SAVE_PATH)
print(f"‚úÖ Mod√®le sauvegard√© : {MODEL_SAVE_PATH}")

# Describe the export: version, creation time, test accuracy and the class
# list, in the same order the labels were assigned
metadata = dict(
    version=MODEL_VERSION,
    created_at=datetime.now().isoformat(),
    test_accuracy=float(test_accuracy),  # cast to a plain float for JSON
    categories=CATEGORIES,
    num_classes=NUM_CLASSES,
)

with open(METADATA_PATH, 'w') as metadata_file:
    metadata_file.write(json.dumps(metadata, indent=2))
print(f"‚úÖ M√©tadonn√©es sauvegard√©es : {METADATA_PATH}")

In [None]:
# @title ⬇️ 1️⃣2️⃣ Télécharger le modèle sur votre PC
from google.colab import files

print("Pr√©paration du t√©l√©chargement...")
try:
    # Trigger a browser download for the model first, then its metadata;
    # a failure on the first file skips the second.
    for artifact_path in (MODEL_SAVE_PATH, METADATA_PATH):
        files.download(artifact_path)
    print("‚úÖ T√©l√©chargement lanc√© dans le navigateur !")
except Exception as download_error:
    # Report the failure instead of crashing the cell
    print(f"‚ùå Erreur lors du t√©l√©chargement: {download_error}")