# CNN Modeling for Drowsiness Detection

## Building and Training Convolutional Neural Networks

This notebook covers:
1. **Building a CNN from scratch**
2. **Training with callbacks**
3. **Transfer learning** with MobileNetV2
4. **Evaluating and visualizing the results**

## 1. Configuration and Imports

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, callbacks, optimizers

# Configuration
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")
%matplotlib inline

# GPU check
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")

# Add the source directory to the import path
sys.path.append('../src')
from models.cnn import EyeCNN, YawnCNN
from models.transfer_learning import TransferLearningModel
from utils.metrics import ModelMetrics

## 2. Loading the Data

In [None]:
# Load the preprocessed data
DATA_DIR = '../data/processed'

X_train = np.load(f'{DATA_DIR}/X_train.npy')
X_val = np.load(f'{DATA_DIR}/X_val.npy')
X_test = np.load(f'{DATA_DIR}/X_test.npy')
y_train = np.load(f'{DATA_DIR}/y_train.npy')
y_val = np.load(f'{DATA_DIR}/y_val.npy')
y_test = np.load(f'{DATA_DIR}/y_test.npy')

print("Data loaded:")
print(f"  Train: {X_train.shape}, Labels: {y_train.shape}")
print(f"  Val:   {X_val.shape}, Labels: {y_val.shape}")
print(f"  Test:  {X_test.shape}, Labels: {y_test.shape}")
print(f"\nDistribution - Train: {np.bincount(y_train)}, Val: {np.bincount(y_val)}, Test: {np.bincount(y_test)}")

## 3. Building the CNN (From Scratch)

### 3.1 Model Architecture

The eye-state CNN follows the classic layout:
- **Input**: 48x48 grayscale images
- **Conv + Pool**: feature extraction
- **Flatten**: flattening the feature maps into a vector
- **Dense + Dropout**: classification
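
To make the layout above concrete, here is a minimal sketch of how the spatial size shrinks through a Conv + Pool stack. The number of blocks, the filter counts (32, 64, 128), the 3x3 kernels with `'same'` padding and the 2x2 pooling are assumptions chosen for illustration; the real architecture is defined by `EyeCNN` and `config.yaml`, which are not shown here.

```python
def conv_out(size, kernel, stride=1, padding="valid"):
    """Output size along one spatial axis for a conv or pooling layer."""
    if padding == "same":
        return -(-size // stride)  # ceil(size / stride)
    return (size - kernel) // stride + 1


# Hypothetical stack: three blocks of Conv 3x3 ('same') + MaxPool 2x2 (stride 2).
size, channels = 48, 1  # 48x48 grayscale input
for filters in (32, 64, 128):  # assumed filter counts
    size = conv_out(size, kernel=3, padding="same")  # convolution keeps the size
    size = conv_out(size, kernel=2, stride=2)        # pooling halves it
    channels = filters
    print(f"after block with {filters} filters: {size}x{size}x{channels}")

flat_features = size * size * channels
print("Flatten output size:", flat_features)
```

Under these assumptions the feature maps go 48 -> 24 -> 12 -> 6, so the Flatten layer would feed 6 * 6 * 128 = 4608 values into the dense classifier.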

In [None]:
# Create the CNN model
cnn_model = EyeCNN(config_path='../config.yaml')
model = cnn_model.build_model()

# Print the summary
print("="*60)
print("CNN ARCHITECTURE FOR EYE-STATE DETECTION")
print("="*60)
model.summary()

### 3.2 Visualizing the Architecture

In [None]:
# Plot the architecture (plot_model requires pydot and Graphviz)
tf.keras.utils.plot_model(
    model,
    to_file='../reports/figures/cnn_architecture.png',
    show_shapes=True,
    show_layer_names=True,
    rankdir='TB',
    expand_nested=True
)

print("✓ Architecture diagram saved")

# Print the output shape of each layer
print("\nTensor shapes per layer:")
print("-"*60)
for layer in model.layers:
    # layer.output.shape also works on Keras 3, where layer.output_shape was removed
    print(f"{layer.name:20s} -> {layer.output.shape}")

## 4. Training the Model

### 4.1 Training Callbacks

In [None]:
# Callback configuration
callbacks_list = [
    # Early stopping: stop when val_loss has not improved for 10 epochs
    callbacks.EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),
    
    # Model checkpoint: save the model with the best val_accuracy
    callbacks.ModelCheckpoint(
        '../models/cnn_eye_best.h5',
        monitor='val_accuracy',
        save_best_only=True,
        verbose=1
    ),
    
    # Reduce LR: halve the learning rate when val_loss plateaus for 5 epochs
    callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-7,
        verbose=1
    ),
    
    # TensorBoard: visualization
    callbacks.TensorBoard(
        log_dir='../reports/logs',
        histogram_freq=1
    )
]

print("Callbacks configured:")
for cb in callbacks_list:
    print(f"  - {cb.__class__.__name__}")
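
How the two patience settings interact is easier to see on a made-up loss curve. The sketch below is a simplified, pure-Python imitation of the `EarlyStopping` and `ReduceLROnPlateau` bookkeeping with the values configured above; it ignores `min_delta` and cooldown, the starting learning rate of `1e-3` is an assumption, and `simulate_callbacks` is a hypothetical helper, not part of Keras.

```python
def simulate_callbacks(val_losses, lr=1e-3, es_patience=10,
                       lr_patience=5, factor=0.5, min_lr=1e-7):
    """Replay val_loss values and report when training would stop."""
    best, best_epoch, es_wait = float("inf"), None, 0
    lr_best, lr_wait = float("inf"), 0
    for epoch, loss in enumerate(val_losses, start=1):
        # EarlyStopping: count epochs since the last improvement
        if loss < best:
            best, best_epoch, es_wait = loss, epoch, 0
        else:
            es_wait += 1
        # ReduceLROnPlateau: cut the learning rate after lr_patience stale epochs
        if loss < lr_best:
            lr_best, lr_wait = loss, 0
        else:
            lr_wait += 1
            if lr_wait >= lr_patience:
                lr = max(lr * factor, min_lr)
                lr_wait = 0
        if es_wait >= es_patience:
            return {"stopped_epoch": epoch, "best_epoch": best_epoch, "final_lr": lr}
    return {"stopped_epoch": None, "best_epoch": best_epoch, "final_lr": lr}


# Made-up curve: improves for 3 epochs, then plateaus at a worse value.
toy_losses = [1.0, 0.8, 0.6] + [0.7] * 12
result = simulate_callbacks(toy_losses)
print(result)
```

On this toy curve the learning rate is halved twice (after epochs 8 and 13) before early stopping triggers at epoch 13, and `restore_best_weights=True` would then roll the model back to epoch 3.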

### 4.2 Running the Training

In [None]:
# Training parameters
BATCH_SIZE = 32
EPOCHS = 50

print("="*60)
print("START OF TRAINING")
print("="*60)
print(f"Batch size: {BATCH_SIZE}")
print(f"Epochs max: {EPOCHS}")
print(f"Training samples: {len(X_train)}")
print(f"Validation samples: {len(X_val)}")
print("="*60)

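# Training
# Note: callbacks_list from section 4.1 is not passed here; whether EyeCNN.train
# sets up its own callbacks depends on src/models/cnn.py, which is not shown.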
history = cnn_model.train(
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS
)

print("\n✓ Training complete!")

### 4.3 Visualizing the Training

In [None]:
# Plot the training history
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss
axes[0, 0].plot(history.history['loss'], label='Train', linewidth=2)
axes[0, 0].plot(history.history['val_loss'], label='Validation', linewidth=2)
axes[0, 0].set_title('Loss (Binary Cross-Entropy)', fontweight='bold')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Accuracy
axes[0, 1].plot(history.history['accuracy'], label='Train', linewidth=2)
axes[0, 1].plot(history.history['val_accuracy'], label='Validation', linewidth=2)
axes[0, 1].set_title('Accuracy', fontweight='bold')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Precision
axes[1, 0].plot(history.history['precision'], label='Train', linewidth=2)
axes[1, 0].plot(history.history['val_precision'], label='Validation', linewidth=2)
axes[1, 0].set_title('Precision', fontweight='bold')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Precision')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# AUC
axes[1, 1].plot(history.history['auc'], label='Train', linewidth=2)
axes[1, 1].plot(history.history['val_auc'], label='Validation', linewidth=2)
axes[1, 1].set_title('AUC-ROC', fontweight='bold')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('AUC')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.suptitle('CNN Training History', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig('../reports/figures/training_history.png', dpi=150, bbox_inches='tight')
plt.show()

# Print the best scores (epoch selected by validation accuracy)
print("\nBest validation scores:")
print("-"*40)
best_epoch = np.argmax(history.history['val_accuracy'])
print(f"Epoch: {best_epoch + 1}")
print(f"Accuracy: {history.history['val_accuracy'][best_epoch]:.4f}")
print(f"Precision: {history.history['val_precision'][best_epoch]:.4f}")
print(f"AUC: {history.history['val_auc'][best_epoch]:.4f}")
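
The epoch with the best validation accuracy is not necessarily the epoch with the lowest validation loss, which matters here because `EarlyStopping` monitors `val_loss` while `ModelCheckpoint` and the report above use `val_accuracy`. A small sketch with made-up numbers (the `toy_history` values are illustrative, not results from this notebook):

```python
import numpy as np

# Made-up history: loss bottoms out at epoch 3, accuracy peaks at epoch 4.
toy_history = {
    "val_loss":     [0.60, 0.45, 0.40, 0.42, 0.47],
    "val_accuracy": [0.70, 0.80, 0.84, 0.86, 0.85],
}

epoch_min_loss = int(np.argmin(toy_history["val_loss"])) + 1
epoch_max_acc = int(np.argmax(toy_history["val_accuracy"])) + 1

print(f"Lowest val_loss at epoch:      {epoch_min_loss}")
print(f"Highest val_accuracy at epoch: {epoch_max_acc}")
```

If the two epochs differ in a real run, the in-memory model (restored to the best `val_loss`) and the checkpoint file (best `val_accuracy`) hold different weights, so it is worth stating which one is being evaluated.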

## 5. Evaluation on the Test Set

In [None]:
# Evaluation
print("="*60)
print("EVALUATION ON THE TEST SET")
print("="*60)

# Unpacking assumes the model was compiled with accuracy, precision, recall and AUC, in that order
test_loss, test_acc, test_prec, test_rec, test_auc = model.evaluate(
    X_test, y_test, verbose=0
)

print(f"Loss: {test_loss:.4f}")
print(f"Accuracy: {test_acc:.4f}")
print(f"Precision: {test_prec:.4f}")
print(f"Recall: {test_rec:.4f}")
print(f"AUC: {test_auc:.4f}")

# Predictions
y_pred_proba = model.predict(X_test)
y_pred = (y_pred_proba > 0.5).astype(int).flatten()

# Classification report
print("\nClassification Report:")
print("="*60)
print(classification_report(y_test, y_pred,
                           target_names=['Open (0)', 'Closed (1)']))

### 5.1 Confusion Matrix

In [None]:
# Confusion matrix
cm = confusion_matrix(y_test, y_pred)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Heatmap
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
           xticklabels=['Open', 'Closed'],
           yticklabels=['Open', 'Closed'],
           ax=axes[0])
axes[0].set_title('Confusion Matrix', fontweight='bold')
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('True')

# Row-normalized (each row sums to 100%)
cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Blues',
           xticklabels=['Open', 'Closed'],
           yticklabels=['Open', 'Closed'],
           ax=axes[1])
axes[1].set_title('Confusion Matrix (Normalized)', fontweight='bold')
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('True')

plt.tight_layout()
plt.savefig('../reports/figures/confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

# Detailed metrics
tn, fp, fn, tp = cm.ravel()
print("\nPrediction breakdown:")
print(f"  True Negatives (TN):  {tn}")
print(f"  False Positives (FP): {fp}")
print(f"  False Negatives (FN): {fn}")
print(f"  True Positives (TP):  {tp}")
print(f"\nSpecificity (TNR): {tn/(tn+fp):.4f}")
print(f"Sensitivity (TPR): {tp/(tp+fn):.4f}")
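
For reference, the metrics reported in this section all derive from the four confusion-matrix cells. The sketch below recomputes them on a toy matrix with made-up counts; `binary_metrics` is a hypothetical helper written for this illustration, and it assumes the same convention as above (rows are true classes, columns are predictions, class 1 is "closed").

```python
import numpy as np


def binary_metrics(cm):
    """Derive common metrics from a 2x2 confusion matrix (rows = true, cols = predicted)."""
    tn, fp, fn, tp = np.asarray(cm).ravel()
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)        # sensitivity, TPR
    specificity = tn / (tn + fp)   # TNR
    f1 = 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall,
            "specificity": specificity, "f1": f1}


# Made-up counts: 60 open eyes (10 misclassified), 40 closed eyes (5 missed).
toy_cm = np.array([[50, 10],
                   [5, 35]])
toy_metrics = binary_metrics(toy_cm)
for name, value in toy_metrics.items():
    print(f"{name:12s} {value:.4f}")
```

In a drowsiness setting a false negative is a closed eye classified as open, so recall on the "closed" class is usually the number to watch first.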

### 5.2 ROC Curve

In [None]:
# ROC curve (ravel() turns the (N, 1) prediction array into a 1-D score vector)
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba.ravel())
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, 
        label=f'ROC curve (AUC = {roc_auc:.4f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', 
        label='Random classifier')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate (FPR)', fontsize=12)
plt.ylabel('True Positive Rate (TPR)', fontsize=12)
plt.title('ROC Curve - Eye-State Classification', fontsize=14, fontweight='bold')
plt.legend(loc='lower right', fontsize=11)
plt.grid(True, alpha=0.3)
plt.savefig('../reports/figures/roc_curve.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"AUC-ROC: {roc_auc:.4f}")
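
One way to read the AUC value: it equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one, with ties counted as one half. The sketch below checks this on four made-up scores; `pairwise_auc` is a hypothetical helper for illustration and compares every positive/negative pair, so it is only practical for small arrays.

```python
import numpy as np


def pairwise_auc(y_true, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]            # all positive-vs-negative pairs
    wins = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return wins / (len(pos) * len(neg))


# Made-up labels and scores: 3 of the 4 pairs are ranked correctly.
toy_auc = pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(toy_auc)  # 0.75
```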

## 6. Transfer Learning with MobileNetV2

### 6.1 Preparing the RGB Data

In [None]:
# Convert to 224x224 RGB for transfer learning
def prepare_for_transfer(X, y):
    """Prepare the data for MobileNetV2.

    Assumes X has shape (N, H, W, 1) with pixel values in [0, 1].
    """
    X_rgb = []
    for img in X:
        # Replicate the single gray channel to get 3 channels
        img_3ch = np.repeat(img, 3, axis=-1)
        # Resize to the MobileNetV2 input size
        img_resized = tf.image.resize(img_3ch, [224, 224])
        # Rescale to [0, 255], then normalize to [-1, 1] for MobileNetV2
        img_norm = tf.keras.applications.mobilenet_v2.preprocess_input(
            img_resized * 255
        )
        X_rgb.append(img_norm)
    return np.array(X_rgb), y

print("Preparing the data for transfer learning...")
X_train_tl, y_train_tl = prepare_for_transfer(X_train, y_train)
X_val_tl, y_val_tl = prepare_for_transfer(X_val, y_val)
X_test_tl, y_test_tl = prepare_for_transfer(X_test, y_test)

print(f"Train: {X_train_tl.shape}")
print(f"Val: {X_val_tl.shape}")
print(f"Test: {X_test_tl.shape}")
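
The channel replication and the [-1, 1] scaling can be checked on a tiny array without TensorFlow. The sketch below assumes, as `prepare_for_transfer` does, that pixels start in [0, 1]; it reproduces the `x / 127.5 - 1` scaling that `mobilenet_v2.preprocess_input` applies to [0, 255] inputs, and adds a rough memory estimate for a hypothetical set of 10,000 images.

```python
import numpy as np

# Tiny 2x2 "image" with one channel and values in [0, 1] (made-up pixels).
gray = np.array([[[0.0], [0.5]],
                 [[1.0], [0.25]]], dtype=np.float32)

rgb = np.repeat(gray, 3, axis=-1)     # (2, 2, 1) -> (2, 2, 3), channels identical
scaled = rgb * 255.0 / 127.5 - 1.0    # [0, 1] -> [0, 255] -> [-1, 1]

print(rgb.shape)
print(scaled[..., 0])

# Memory cost of materializing float32 224x224x3 images in RAM.
bytes_per_image = 224 * 224 * 3 * 4
n_images = 10_000                     # hypothetical dataset size
print(f"{n_images * bytes_per_image / 1e9:.2f} GB for {n_images} images")
```

Each resized image takes about 0.6 MB as float32, so for large datasets it may be preferable to resize on the fly (for example inside a `tf.data` pipeline) rather than building the full array up front.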

### 6.2 Building the Model

In [None]:
# Create the transfer learning model
print("="*60)
print("TRANSFER LEARNING - MobileNetV2")
print("="*60)

transfer_model = TransferLearningModel(
    base_model_name="MobileNetV2",
    input_shape=(224, 224, 3)
)

# Phase 1: Feature Extraction
model_tl = transfer_model.build_feature_extractor(trainable=False)

print("\nTransfer learning model architecture:")
model_tl.summary()

### 6.3 Training (Phase 1 - Feature Extraction)

In [None]:
# Phase 1 training: the MobileNetV2 base stays frozen
history_tl = transfer_model.train(
    X_train=X_train_tl,
    y_train=y_train_tl,
    X_val=X_val_tl,
    y_val=y_val_tl,
    batch_size=32,
    epochs=20,
    fine_tune=False
)

print("\n✓ Phase 1 (Feature Extraction) complete!")

### 6.4 Fine-Tuning (Phase 2)

In [None]:
# Phase 2: Fine-tuning
print("\n" + "="*60)
print("PHASE 2: FINE-TUNING")
print("="*60)

history_fine = transfer_model.train(
    X_train=X_train_tl,
    y_train=y_train_tl,
    X_val=X_val_tl,
    y_val=y_val_tl,
    batch_size=32,
    epochs=10,
    fine_tune=True,
    fine_tune_epochs=10
)

print("\n✓ Fine-tuning complete!")

### 6.5 Model Comparison

In [None]:
# Evaluate the transfer learning model
test_loss_tl, test_acc_tl, test_prec_tl, test_rec_tl, test_auc_tl = model_tl.evaluate(
    X_test_tl, y_test_tl, verbose=0
)

# Comparison
comparison = {
    'Metric': ['Accuracy', 'Precision', 'Recall', 'AUC'],
    'CNN from Scratch': [test_acc, test_prec, test_rec, test_auc],
    'MobileNetV2 (TL)': [test_acc_tl, test_prec_tl, test_rec_tl, test_auc_tl]
}

import pandas as pd
df_comparison = pd.DataFrame(comparison)

print("\nModel Comparison:")
print("="*60)
print(df_comparison.to_string(index=False))

# Visualization
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(df_comparison))
width = 0.35

bars1 = ax.bar(x - width/2, df_comparison['CNN from Scratch'], width,
               label='CNN from Scratch', color='skyblue', edgecolor='black')
bars2 = ax.bar(x + width/2, df_comparison['MobileNetV2 (TL)'], width,
               label='MobileNetV2 (TL)', color='lightcoral', edgecolor='black')

ax.set_ylabel('Score', fontsize=12)
ax.set_title('Model Comparison', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(df_comparison['Metric'])
ax.legend()
ax.set_ylim([0, 1.05])  # headroom so value labels on bars near 1.0 are not clipped
ax.grid(axis='y', alpha=0.3)

# Add the value labels
for bars in [bars1, bars2]:
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
               f'{height:.3f}', ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.savefig('../reports/figures/model_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Visualizing the Predictions

In [None]:
# Show predictions on a random sample of test images
n_samples = 12
indices = np.random.choice(len(X_test), n_samples, replace=False)

fig, axes = plt.subplots(3, 4, figsize=(15, 12))
axes = axes.flatten()

for i, idx in enumerate(indices):
    img = X_test[idx].squeeze()
    true_label = y_test[idx]
    pred_proba = y_pred_proba[idx][0]
    pred_label = int(pred_proba > 0.5)
    
    # Color depends on whether the prediction is correct
    if pred_label == true_label:
        color = 'green'
        border_color = '#2ecc71'
    else:
        color = 'red'
        border_color = '#e74c3c'
    
    axes[i].imshow(img, cmap='gray')
    axes[i].set_title(
        f'True: {"Closed" if true_label else "Open"}\n'
        f'Predicted: {"Closed" if pred_label else "Open"} ({pred_proba:.2f})',
        color=color, fontweight='bold'
    )
    # Hide the ticks but keep the spines: axis('off') would also hide the colored border
    axes[i].set_xticks([])
    axes[i].set_yticks([])
    
    # Colored border
    for spine in axes[i].spines.values():
        spine.set_edgecolor(border_color)
        spine.set_linewidth(3)

plt.suptitle('Predictions on the Test Set', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig('../reports/figures/predictions.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. Saving the Models

In [None]:
# Save the models
print("Saving the models...")

# CNN from scratch
model.save('../models/cnn_eye_final.h5')
print("✓ CNN saved: models/cnn_eye_final.h5")

# Transfer Learning
model_tl.save('../models/mobilenet_fatigue_final.h5')
print("✓ MobileNetV2 saved: models/mobilenet_fatigue_final.h5")

# Training histories
import pickle
with open('../models/history_cnn.pkl', 'wb') as f:
    pickle.dump(history.history, f)
with open('../models/history_transfer.pkl', 'wb') as f:
    pickle.dump(history_tl.history, f)
print("✓ Histories saved")

print("\n" + "="*60)
print("MODELS TRAINED AND SAVED SUCCESSFULLY!")
print("="*60)

## 9. Summary and Next Steps

### Course Concepts Applied

- ✅ **Chapter 1**: Loss function, optimization
- ✅ **Chapter 2**: MLP, regularization (Dropout)
- ✅ **Chapters 3-4**: CNN, Transfer Learning, Data Augmentation

### Next Step

The next notebook (03_evaluation_et_tests.ipynb) will cover:
- Tests on real-world data
- Optimization for deployment
- Real-time integration

---

**Notebook 02 complete!** 🎉