<a href="https://colab.research.google.com/github/ZeynaDieng/cifar10-cnn-classification/blob/main/cifar10_cnn1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================================
# Projet - Classification CIFAR-10 avec CNN
# Auteur : SEYNABOU DIENG
# Enseignant : Mr LY
# Date : 27/09/2025
# ============================================================

# --------------------------
# 1) Imports et configuration
# --------------------------
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report

import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.datasets import cifar10

# Reproducibility: seed both the TensorFlow and the NumPy random generators.
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)

# Hyperparameters
BATCH_SIZE, EPOCHS = 64, 50
LEARNING_RATE = 1e-3
VALIDATION_SPLIT = 0.1

# Directory receiving the saved models (created if it does not exist yet).
MODEL_DIR = "models"
os.makedirs(MODEL_DIR, exist_ok=True)

# --------------------------
# 2) Loading and preprocessing
# --------------------------
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Collapse the label arrays to one dimension (the shapes are printed below).
y_train, y_test = y_train.flatten(), y_test.flatten()

# Normalization: cast to float32 and divide by 255.
x_train, x_test = (
    images.astype("float32") / 255.0 for images in (x_train, x_test)
)

print(f"x_train: {x_train.shape}, y_train: {y_train.shape}")
print(f"x_test : {x_test.shape}, y_test : {y_test.shape}")

# CIFAR-10 label names, indexed by integer class id.
class_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# --------------------------
# 3) Data augmentation
# --------------------------
# Random rotations, shifts and horizontal flips are applied to the training
# subset only.
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    validation_split=VALIDATION_SPLIT
)

# Validation images must not be augmented: val_loss drives EarlyStopping and
# ModelCheckpoint, so it has to be measured on unmodified images. A second
# generator without any transform is used for that purpose. It is given the
# same validation_split so that it selects the same held-out fraction of
# x_train that the training generator excludes.
val_datagen = ImageDataGenerator(validation_split=VALIDATION_SPLIT)

train_gen = datagen.flow(x_train, y_train, batch_size=BATCH_SIZE, subset='training', seed=seed)
# Shuffling is pointless for evaluation, so the validation order is kept fixed.
val_gen   = val_datagen.flow(x_train, y_train, batch_size=BATCH_SIZE, subset='validation', shuffle=False)

# --------------------------
# 4) Construction du modèle CNN
# --------------------------
def build_cnn(input_shape=(32,32,3), num_classes=10):
    """Build the convolutional classifier (the model is not compiled here).

    The network stacks three convolutional stages with 32, 64 and 128
    filters. The first two stages hold two Conv2D+BatchNormalization pairs,
    the last one a single pair; every stage ends with 2x2 max pooling.
    The classification head is Flatten -> Dense(256, relu) ->
    BatchNormalization -> Dropout(0.5) -> Dense(num_classes, softmax).

    Args:
        input_shape: Shape of one input sample, passed to ``layers.Input``.
        num_classes: Number of units of the final softmax layer.

    Returns:
        A ``models.Model`` mapping the input tensor to class probabilities.
    """

    def conv_bn(tensor, filters):
        # 3x3 'same'-padded ReLU convolution followed by batch normalization.
        tensor = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(tensor)
        return layers.BatchNormalization()(tensor)

    inp = layers.Input(shape=input_shape)

    # Convolutional stages: (number of filters, number of conv+BN pairs).
    features = inp
    for filters, n_convs in ((32, 2), (64, 2), (128, 1)):
        for _ in range(n_convs):
            features = conv_bn(features, filters)
        features = layers.MaxPooling2D((2, 2))(features)

    # Classification head.
    features = layers.Flatten()(features)
    features = layers.Dense(256, activation='relu')(features)
    features = layers.BatchNormalization()(features)
    features = layers.Dropout(0.5)(features)
    out = layers.Dense(num_classes, activation='softmax')(features)

    return models.Model(inputs=inp, outputs=out)

model = build_cnn()
model.summary()

# --------------------------
# 5) Compilation and callbacks
# --------------------------
# Labels are integer class ids, hence the sparse categorical cross-entropy.
opt = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

checkpoint_path = os.path.join(MODEL_DIR, "best_model.keras")
# Stop once val_loss has not improved for 8 epochs and roll back to the best
# weights; independently, keep the best model (by val_loss) on disk.
cb_early = callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True)
cb_ckpt  = callbacks.ModelCheckpoint(checkpoint_path, monitor='val_loss', save_best_only=True)

# --------------------------
# 6) Training
# --------------------------
# steps_per_epoch is intentionally not passed: the number of batches per
# epoch is taken from the generator itself, so the partial last batch is
# included as well.
history = model.fit(
    train_gen,
    epochs=EPOCHS,
    validation_data=val_gen,
    callbacks=[cb_early, cb_ckpt]
)


# --------------------------
# 7) Evaluation on the test set
# --------------------------
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"\nTest loss: {test_loss:.4f}, Test accuracy: {test_acc:.4f}")

# --------------------------
# 8) Courbes d'apprentissage
# --------------------------
def plot_history(history):
    """Plot training and validation curves side by side.

    The left panel shows accuracy, the right panel shows loss; each panel
    holds one curve for the training metric and one for its ``val_``
    counterpart.

    Args:
        history: Object whose ``history`` attribute is a mapping containing
            the keys 'accuracy', 'val_accuracy', 'loss' and 'val_loss'
            (presumably the value returned by ``model.fit``).
    """
    _, axes = plt.subplots(1, 2, figsize=(12, 4))

    # One row per panel: (train key, val key, train label, val label, title).
    panels = (
        ('accuracy', 'val_accuracy', 'train_acc', 'val_acc', 'Accuracy'),
        ('loss', 'val_loss', 'train_loss', 'val_loss', 'Loss'),
    )
    for axis, (train_key, val_key, train_label, val_label, title) in zip(axes, panels):
        axis.plot(history.history[train_key], label=train_label)
        axis.plot(history.history[val_key], label=val_label)
        axis.set_title(title)
        axis.legend()

    plt.show()

plot_history(history)

# --------------------------
# 9) Confusion matrix and classification report
# --------------------------
# Predicted class = index of the highest score along axis 1.
y_pred_probs = model.predict(x_test)
y_pred = y_pred_probs.argmax(axis=1)

# Draw the confusion matrix on a dedicated, larger figure.
fig, ax = plt.subplots(figsize=(10, 8))
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(ax=ax, cmap='Blues', xticks_rotation='vertical')
plt.title("Matrice de confusion - CIFAR-10")
plt.show()

# Per-class precision / recall / F1 summary.
print("Classification report :\n")
print(classification_report(y_test, y_pred, target_names=class_names))

# --------------------------
# 10) Exemples corrects et incorrects
# --------------------------
def show_examples(x, y_true, y_pred, class_names, n=8):
    """Show up to ``n`` correctly and up to ``n`` wrongly classified images.

    Two figures are displayed one after the other: the first holds samples
    where ``y_true == y_pred`` (titled with the true class name), the second
    holds samples where they differ (titled with both true and predicted
    class names). In each figure the first ``n`` matching indices are used.

    Args:
        x: Images, indexable by sample index and accepted by ``plt.imshow``.
        y_true: True class ids (presumably a 1-D NumPy integer array, since
            it is compared element-wise and used to index ``class_names``).
        y_pred: Predicted class ids, comparable element-wise with ``y_true``.
        class_names: Sequence mapping a class id to its display name.
        n: Number of columns per figure, i.e. the maximum number of images.
    """
    hits = np.where(y_true == y_pred)[0]
    misses = np.where(y_true != y_pred)[0]

    def caption_hit(k):
        return class_names[y_true[k]]

    def caption_miss(k):
        return f"true:{class_names[y_true[k]]}\npred:{class_names[y_pred[k]]}"

    # One figure per row: (sample indices, caption builder, figure title).
    panels = (
        (hits, caption_hit, "Exemples correctement classés"),
        (misses, caption_miss, "Exemples incorrectement classés"),
    )
    for indices, caption, heading in panels:
        plt.figure(figsize=(12, 3))
        for column, k in enumerate(indices[:n], start=1):
            plt.subplot(1, n, column)
            plt.imshow(x[k])
            plt.title(caption(k))
            plt.axis('off')
        plt.suptitle(heading)
        plt.show()

# The normalized float images are passed as-is: imshow renders float RGB
# values in [0, 1] directly. This avoids copying the whole test set to uint8
# and the truncation (rather than rounding) that such a cast performs.
show_examples(x_test, y_test, y_pred, class_names, n=8)

# --------------------------
# 11) Saving the final model
# --------------------------
final_path = os.path.join(MODEL_DIR, "final_model.keras")
model.save(final_path)
print(f"Modèle sauvegardé : {final_path}")


x_train: (50000, 32, 32, 3), y_train: (50000,)
x_test : (10000, 32, 32, 3), y_test : (10000,)


Epoch 1/50
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m415s[0m 581ms/step - accuracy: 0.3691 - loss: 1.9942 - val_accuracy: 0.5590 - val_loss: 1.2324
Epoch 2/50
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m411s[0m 584ms/step - accuracy: 0.5984 - loss: 1.1354 - val_accuracy: 0.6318 - val_loss: 1.0179
Epoch 3/50
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m416s[0m 591ms/step - accuracy: 0.6678 - loss: 0.9489 - val_accuracy: 0.7026 - val_loss: 0.8433
Epoch 4/50
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m421s[0m 598ms/step - accuracy: 0.7069 - loss: 0.8392 - val_accuracy: 0.6742 - val_loss: 0.9250
Epoch 5/50
[1m602/704[0m [32m━━━━━━━━━━━━━━━━━[0m[37m━━━[0m [1m59s[0m 583ms/step - accuracy: 0.7338 - loss: 0.7706