# Plant Doctor - CNN Model Training

**Configuration:** AMD RX 7900 XT with TensorFlow-DirectML (Windows)

This notebook trains an EfficientNet-B0 model to classify plant diseases on the PlantVillage dataset (38 classes).

## Features:
- **Automatic checkpoints** - the model is saved after every epoch, so training can be resumed
- **Live plots** - metrics are plotted while training runs
- **Overfitting detection** - a warning appears when val_loss drifts too far above the training loss
- **Time estimate** - estimated time remaining

## Estimated time:
- **Phase 1 (10 epochs):** ~30-45 minutes (AMD GPU)
- **Phase 2 (15 epochs):** ~60-90 minutes (AMD GPU)
- **Total:** ~1h30 to 2h30 depending on the setup
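As a rough sanity check, these totals can be turned into per-epoch and per-step times. The sketch below assumes the ~54,000 images quoted in section 2, the 20% validation split and the batch size of 32 set in section 3; `training_budget` is a hypothetical helper, and its outputs are just the ranges above divided out, not measurements.

```python
import math

def training_budget(num_images, epochs, total_minutes, validation_split=0.2, batch_size=32):
    """Split a total training-time estimate into per-epoch and per-step times."""
    train_images = round(num_images * (1 - validation_split))
    steps = math.ceil(train_images / batch_size)  # batches per epoch, last one may be partial
    minutes_per_epoch = total_minutes / epochs
    seconds_per_step = minutes_per_epoch * 60 / steps
    return steps, minutes_per_epoch, seconds_per_step

# Phase 1: 10 epochs in ~30 to ~45 minutes
for minutes in (30, 45):
    steps, per_epoch, per_step = training_budget(54_000, epochs=10, total_minutes=minutes)
    print(f"{steps} steps/epoch, {per_epoch:.1f} min/epoch, {per_step:.2f} s/step")
# 1350 steps/epoch, 3.0 min/epoch, 0.13 s/step
# 1350 steps/epoch, 4.5 min/epoch, 0.20 s/step
```

A measured time per step well above these figures is a hint to re-check the GPU/DirectML detection output in section 1.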

## 1. Installing dependencies

On Windows with an AMD GPU, we use **tensorflow-directml-plugin**, which runs TensorFlow on AMD GPUs through DirectX 12 (DirectML).

In [None]:
# Install the DirectML plugin for AMD GPUs on Windows
# Run this ONCE
# The plugin is built for TensorFlow 2.10 (CPU package), which supports Python 3.7-3.10
# !pip install tensorflow-cpu==2.10.0 tensorflow-directml-plugin

# Alternative: standard TensorFlow (CPU only) if DirectML causes problems
# !pip install tensorflow

# Other dependencies (TensorFlow 2.10 was built against NumPy 1.x)
# !pip install "numpy<2" pillow matplotlib scikit-learn kaggle
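A `ModuleNotFoundError` on the first import usually means the notebook kernel is not the Python environment the packages were installed into. A stdlib-only check like this sketch can be run before importing TensorFlow; the package list is an assumption matching the install lines above (with the CPU alternative, look for `tensorflow` instead of `tensorflow-cpu`).

```python
from importlib import metadata

# Packages this notebook relies on (pip distribution names; assumed list)
REQUIRED = ["tensorflow-cpu", "tensorflow-directml-plugin", "numpy", "pillow", "matplotlib"]

def installed_versions(packages):
    """Return {package: installed version, or None if missing} for the current kernel."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

for name, version in installed_versions(REQUIRED).items():
    print(f"{name:28s} {version or 'NOT INSTALLED'}")
```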


In [3]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from pathlib import Path
import json
import time
from datetime import timedelta
from IPython.display import display, clear_output

print(f"TensorFlow version: {tf.__version__}")
print(f"Available GPUs: {tf.config.list_physical_devices('GPU')}")

# Check whether DirectML is in use
try:
    devices = tf.config.list_physical_devices()
    gpu_found = any('GPU' in str(d) or 'DML' in str(d) for d in devices)
    print(f"GPU/DirectML detected: {'Yes' if gpu_found else 'No (CPU mode)'}")
except Exception as e:
    print(f"CPU mode (DirectML not detected): {e}")

## 2. Downloading the PlantVillage dataset

The PlantVillage dataset contains ~54,000 images of plant leaves in 38 classes (diseases and healthy plants).

**Option 1:** Download from Kaggle (recommended)
- Dataset: https://www.kaggle.com/datasets/emmarex/plantdisease

**Option 2:** Download manually and place it in `data/PlantVillage/`

In [None]:
# Path configuration
PROJECT_ROOT = Path("C:/TFE-4")
DATA_DIR = PROJECT_ROOT / "data"
DATASET_DIR = DATA_DIR / "PlantVillage"
MODEL_DIR = PROJECT_ROOT / "models"

# Create the folders if needed (parents=True also creates C:/TFE-4 itself)
DATA_DIR.mkdir(parents=True, exist_ok=True)
MODEL_DIR.mkdir(parents=True, exist_ok=True)

print(f"Dataset directory: {DATASET_DIR}")
print(f"Exists: {DATASET_DIR.exists()}")

In [None]:
# Option 1: download through the Kaggle API
# Requires a kaggle.json file in ~/.kaggle/ (on Windows: C:\Users\<username>\.kaggle\)

# !kaggle datasets download -d emmarex/plantdisease -p {DATA_DIR}
# !unzip {DATA_DIR}/plantdisease.zip -d {DATA_DIR}
# (`unzip` is usually missing on Windows; Python's zipfile module can extract the archive instead)

# Option 2: manual download
# 1. Go to https://www.kaggle.com/datasets/emmarex/plantdisease
# 2. Download it and extract it into C:/TFE-4/data/PlantVillage/
# Expected structure:
#   data/PlantVillage/
#     ├── Apple___Apple_scab/
#     ├── Apple___Black_rot/
#     ├── ...
#     └── Tomato___healthy/

print("Download the PlantVillage dataset and place it in:")
print(str(DATASET_DIR))
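Since `unzip` is often not on the PATH on Windows, a portable alternative is Python's standard `zipfile` module. This sketch assumes the archive name `plantdisease.zip` produced by the Kaggle command above; `extract_archive` is a hypothetical helper. Checking the returned top-level folders is worthwhile: depending on how the archive is packed, the class folders may end up one level deeper than `data/PlantVillage/`, in which case `DATASET_DIR` needs adjusting.

```python
import zipfile
from pathlib import Path

def extract_archive(zip_path, dest_dir):
    """Extract a .zip archive into dest_dir and return its sorted top-level entries."""
    zip_path, dest_dir = Path(zip_path), Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest_dir)
        # Zip member names always use "/" as separator
        top_level = {name.split("/")[0] for name in archive.namelist()}
    return sorted(top_level)

# Example (path from the Kaggle command above):
# print(extract_archive(DATA_DIR / "plantdisease.zip", DATA_DIR))
```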

In [None]:
# Check the dataset
if DATASET_DIR.exists():
    classes = sorted([d.name for d in DATASET_DIR.iterdir() if d.is_dir()])
    print(f"Number of classes: {len(classes)}")
    print("\nClasses found:")
    for i, cls in enumerate(classes):
        count = len(list((DATASET_DIR / cls).glob('*')))
        print(f"  {i:2d}. {cls}: {count} images")
else:
    print("Dataset not found! Download it first.")
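The cell above counts every file, including stray files such as `Thumbs.db`. A closer match to what `flow_from_directory` will load is to count only image extensions; the extension set below is the one we understand Keras accepts by default. The sketch (hypothetical helper names) also reports how unbalanced the classes are, which is worth knowing before reading an overall accuracy figure.

```python
from pathlib import Path

# Image extensions flow_from_directory accepts by default (compared in lowercase)
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".ppm", ".tif", ".tiff"}

def count_images(dataset_dir):
    """Count image files per class subfolder, ignoring non-image files."""
    counts = {}
    for class_dir in sorted(p for p in Path(dataset_dir).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts

def imbalance_ratio(counts):
    """Largest class size divided by the smallest non-empty class size."""
    sizes = [n for n in counts.values() if n > 0]
    return max(sizes) / min(sizes) if sizes else 0.0

# counts = count_images(DATASET_DIR)
# print(sum(counts.values()), "images, imbalance ratio:", round(imbalance_ratio(counts), 1))
```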

## 3. Data preparation

In [None]:
# Parameters
IMG_SIZE = 224  # EfficientNet-B0 input size
BATCH_SIZE = 32  # Lower to 16 or 8 if you run out of VRAM
VALIDATION_SPLIT = 0.2
SEED = 42

# Data augmentation for training.
# No rescale=1./255: Keras' EfficientNetB0 contains its own rescaling and
# normalization layers and expects raw pixel values in [0, 255].
train_datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=VALIDATION_SPLIT
)

# No augmentation for validation (same raw [0, 255] inputs)
val_datagen = ImageDataGenerator(
    validation_split=VALIDATION_SPLIT
)

In [None]:
# Create the generators
train_generator = train_datagen.flow_from_directory(
    DATASET_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training',
    seed=SEED
)

val_generator = val_datagen.flow_from_directory(
    DATASET_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    seed=SEED
)

# Class labels
class_indices = train_generator.class_indices
class_names = list(class_indices.keys())
num_classes = len(class_names)

print(f"\nNumber of classes: {num_classes}")
print(f"Training images: {train_generator.samples}")
print(f"Validation images: {val_generator.samples}")
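How do two generators share one folder without overlapping? Our reading of the Keras `DirectoryIterator` source is that the split is done per class on the sorted file list: the first `VALIDATION_SPLIT` fraction of each class goes to `subset='validation'` and the rest to `subset='training'`, so `seed` only affects shuffling, not which files land in which subset. The pure-Python sketch below mimics that rule for illustration (it is not the Keras code; function names and the demo file lists are hypothetical) and checks that the subsets are disjoint.

```python
def split_class_files(filenames, validation_split):
    """Per-class split in sorted order: first fraction -> validation, rest -> training."""
    files = sorted(filenames)
    cut = int(validation_split * len(files))
    return files[cut:], files[:cut]  # (training, validation)

def split_dataset(files_by_class, validation_split=0.2):
    """Apply the per-class split to every class folder."""
    train, val = [], []
    for cls in sorted(files_by_class):
        t, v = split_class_files(files_by_class[cls], validation_split)
        train += [(cls, f) for f in t]
        val += [(cls, f) for f in v]
    return train, val

demo = {"Tomato___healthy": [f"img{i:02d}.jpg" for i in range(10)],
        "Tomato___Late_blight": [f"img{i:02d}.jpg" for i in range(7)]}
demo_train, demo_val = split_dataset(demo, 0.2)
print(len(demo_train), len(demo_val))  # 14 3
```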

In [None]:
# Save the class mapping (index -> class name)
class_labels = {v: k for k, v in class_indices.items()}

# Write the labels to a JSON file
labels_file = DATA_DIR / "class_labels.json"
with open(labels_file, 'w', encoding='utf-8') as f:
    json.dump(class_labels, f, indent=2, ensure_ascii=False)

print(f"Labels saved to: {labels_file}")
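One detail of this file: JSON object keys are always strings, so the integer indices written here come back as `'0'`, `'1'`, ... when the file is read. Any consumer, such as the model-loading service listed in the next steps, should convert them back, as in this sketch; the two class names are sample entries and `load_class_labels` is a hypothetical helper.

```python
import json

example_labels = {0: "Apple___Apple_scab", 1: "Apple___Black_rot"}  # sample entries

# json.dumps turns int keys into strings...
restored = json.loads(json.dumps(example_labels))
print(restored)  # {'0': 'Apple___Apple_scab', '1': 'Apple___Black_rot'}

# ...so convert them back when reading class_labels.json elsewhere
def load_class_labels(text):
    """Parse a class_labels.json payload into {int index: class name}."""
    return {int(k): v for k, v in json.loads(text).items()}

print(load_class_labels(json.dumps(example_labels)) == example_labels)  # True
```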

## 4. Building the model (EfficientNet-B0 + transfer learning)

In [None]:
def create_model(num_classes, fine_tune_at=100):
    """
    Build an EfficientNet-B0 model for transfer learning.
    
    Args:
        num_classes: number of output classes
        fine_tune_at: index of the first base-model layer to unfreeze when
            fine-tuning (not used inside this function, which freezes the whole base)
    """
    # Load EfficientNet-B0 pre-trained on ImageNet
    base_model = EfficientNetB0(
        weights='imagenet',
        include_top=False,
        input_shape=(IMG_SIZE, IMG_SIZE, 3)
    )
    
    # Freeze the base layers
    base_model.trainable = False
    
    # Build the full model
    model = keras.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.BatchNormalization(),
        layers.Dropout(0.3),
        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.3),
        layers.Dense(num_classes, activation='softmax')
    ])
    
    return model, base_model

# Create the model
model, base_model = create_model(num_classes)
model.summary()

In [None]:
# Compile the model
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
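`categorical_crossentropy` compares the one-hot label $y$ with the softmax output $p$ as $L = -\sum_c y_c \log p_c$, i.e. minus the log-probability given to the true class. A useful consequence: a model still guessing uniformly over $C$ classes has $L = \ln C$, about 3.64 for 38 classes, so a first-epoch loss far above that is worth investigating. A plain-Python sketch of the formula (the clipping constant mirrors the idea of Keras' epsilon clipping, not its exact code):

```python
import math

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy for one sample: y_true one-hot, y_pred probabilities summing to 1."""
    # Clip probabilities away from 0 so log() stays finite
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))

n_classes = 38
uniform = [1 / n_classes] * n_classes  # untrained "uniform guess" output
one_hot = [0.0] * n_classes
one_hot[5] = 1.0                        # any true class gives the same result

print(f"{categorical_crossentropy(one_hot, uniform):.4f}")  # 3.6376
print(f"{math.log(38):.4f}")                                # 3.6376
```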

## 5. Training - Phase 1 (feature extraction)

In [None]:
# ============================================================
# ADVANCED CALLBACKS WITH LIVE PLOTS
# ============================================================

class TrainingMonitor(keras.callbacks.Callback):
    """
    Custom callback that:
    - shows a progress bar
    - draws live plots
    - flags possible overfitting
    - estimates the remaining time
    """
    
    def __init__(self, total_epochs, phase_name="Training"):
        super().__init__()
        self.total_epochs = total_epochs
        self.phase_name = phase_name
        self.history = {'loss': [], 'accuracy': [], 'val_loss': [], 'val_accuracy': []}
        self.epoch_times = []
        self.start_time = None
        self.fig = None
        self.axes = None
        
    def on_train_begin(self, logs=None):
        self.start_time = time.time()
        print(f"\n{'='*60}")
        print(f"  {self.phase_name} - {self.total_epochs} epochs")
        print(f"{'='*60}\n")
        
    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start = time.time()
        
    def on_epoch_end(self, epoch, logs=None):
        # Record the epoch duration
        epoch_time = time.time() - self.epoch_start
        self.epoch_times.append(epoch_time)
        
        # Record the metrics
        self.history['loss'].append(logs.get('loss', 0))
        self.history['accuracy'].append(logs.get('accuracy', 0))
        self.history['val_loss'].append(logs.get('val_loss', 0))
        self.history['val_accuracy'].append(logs.get('val_accuracy', 0))
        
        # Estimate the remaining time
        avg_epoch_time = np.mean(self.epoch_times)
        remaining_epochs = self.total_epochs - (epoch + 1)
        eta = timedelta(seconds=int(avg_epoch_time * remaining_epochs))
        
        # Progress bar
        progress = (epoch + 1) / self.total_epochs
        bar_length = 30
        filled = int(bar_length * progress)
        bar = '█' * filled + '░' * (bar_length - filled)
        
        # Detect overfitting: mean val_loss vs. mean train loss over the last 3 epochs
        overfitting_warning = ""
        if len(self.history['val_loss']) > 3:
            recent_val_loss = self.history['val_loss'][-3:]
            recent_train_loss = self.history['loss'][-3:]
            gap = np.mean(recent_val_loss) - np.mean(recent_train_loss)
            if gap > 0.3:
                overfitting_warning = " ⚠️ POSSIBLE OVERFITTING"
        
        # Show the progress
        clear_output(wait=True)
        print(f"\n{self.phase_name}")
        print(f"[{bar}] {epoch+1}/{self.total_epochs} ({progress*100:.0f}%)")
        print("")
        print(f"📊 Epoch {epoch+1} results:")
        print(f"   Loss:     {logs.get('loss', 0):.4f} (train) | {logs.get('val_loss', 0):.4f} (val)")
        print(f"   Accuracy: {logs.get('accuracy', 0)*100:.2f}% (train) | {logs.get('val_accuracy', 0)*100:.2f}% (val)")
        print("")
        print(f"⏱️  Epoch time: {timedelta(seconds=int(epoch_time))} | ETA: {eta}{overfitting_warning}")
        
        # Draw the plots
        self._plot_progress()
        
    def _plot_progress(self):
        """Draw the progress plots."""
        if len(self.history['loss']) < 1:
            return
            
        epochs = range(1, len(self.history['loss']) + 1)
        
        fig, axes = plt.subplots(1, 2, figsize=(12, 4))
        
        # Loss
        axes[0].plot(epochs, self.history['loss'], 'b-', label='Train', linewidth=2)
        axes[0].plot(epochs, self.history['val_loss'], 'r-', label='Validation', linewidth=2)
        axes[0].fill_between(epochs, self.history['loss'], self.history['val_loss'], 
                             alpha=0.2, color='gray')
        axes[0].set_title('Loss', fontsize=12, fontweight='bold')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        
        # Accuracy
        axes[1].plot(epochs, [a*100 for a in self.history['accuracy']], 'b-', 
                     label='Train', linewidth=2)
        axes[1].plot(epochs, [a*100 for a in self.history['val_accuracy']], 'r-', 
                     label='Validation', linewidth=2)
        axes[1].set_title('Accuracy', fontsize=12, fontweight='bold')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Accuracy (%)')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
        axes[1].set_ylim([0, 100])
        
        plt.tight_layout()
        plt.show()
        
    def on_train_end(self, logs=None):
        total_time = time.time() - self.start_time
        print(f"\n{'='*60}")
        print(f"  {self.phase_name} finished!")
        print(f"  Total time: {timedelta(seconds=int(total_time))}")
        print(f"  Best val_accuracy: {max(self.history['val_accuracy'])*100:.2f}%")
        print(f"{'='*60}\n")


# Standard callbacks
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=str(MODEL_DIR / 'checkpoint_phase1_epoch{epoch:02d}.keras'),
    monitor='val_accuracy',
    save_best_only=False,  # Save every epoch so training can be resumed
    save_weights_only=False,
    verbose=0
)

best_model_callback = keras.callbacks.ModelCheckpoint(
    filepath=str(MODEL_DIR / 'best_model_phase1.keras'),
    monitor='val_accuracy',
    save_best_only=True,
    verbose=0
)

early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1
)

reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=3,
    min_lr=1e-7,
    verbose=1
)

print("✅ Callbacks configured:")
print("   - Checkpoint saved every epoch")
print("   - Best model saved")
print("   - Early stopping (patience=5)")
print("   - Automatic learning-rate reduction")
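The two numbers `TrainingMonitor` prints each epoch can be reproduced outside Keras. This sketch restates its rules as plain functions (the function names are ours): the overfitting flag fires once more than three epochs exist and the mean validation loss over the last three exceeds the mean training loss by more than 0.3, and the ETA is the average epoch time multiplied by the epochs left. Treat the flag as a rough hint: dropout and data augmentation are active only during training, so the training loss can sit above the validation loss early on.

```python
from datetime import timedelta
from statistics import mean

def overfitting_flag(train_losses, val_losses, window=3, threshold=0.3):
    """Same rule as TrainingMonitor: mean(val) - mean(train) over the last `window` epochs."""
    if len(val_losses) <= window:  # the monitor waits for more than 3 epochs
        return False
    gap = mean(val_losses[-window:]) - mean(train_losses[-window:])
    return gap > threshold

def eta(epoch_times, epochs_done, total_epochs):
    """Average epoch duration multiplied by the number of epochs still to run."""
    remaining = total_epochs - epochs_done
    return timedelta(seconds=int(mean(epoch_times) * remaining))

demo_train_loss = [1.2, 0.6, 0.4, 0.30, 0.25]
demo_val_loss = [1.0, 0.7, 0.7, 0.75, 0.80]
print(overfitting_flag(demo_train_loss, demo_val_loss))  # True
print(eta([200, 210, 190], 3, 10))                       # 0:23:20
```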

In [None]:
# Phase 1: train only the added layers (frozen base)
EPOCHS_PHASE1 = 10

# Create the training monitor
monitor_phase1 = TrainingMonitor(EPOCHS_PHASE1, "🌱 Phase 1: Feature Extraction")

# Callback list
callbacks_phase1 = [
    monitor_phase1,
    checkpoint_callback,
    best_model_callback,
    early_stopping,
    reduce_lr
]

print("🚀 Starting Phase 1...")
print(f"   Epochs: {EPOCHS_PHASE1}")
print(f"   Batch size: {BATCH_SIZE}")
print(f"   Steps per epoch: {len(train_generator)}")  # includes the last partial batch
print("")

history1 = model.fit(
    train_generator,
    epochs=EPOCHS_PHASE1,
    validation_data=val_generator,
    callbacks=callbacks_phase1,
    verbose=0  # We use our own display
)

In [None]:
# ============================================================
# RESUMING TRAINING (if interrupted)
# ============================================================
# To resume, remove the triple quotes around the block below

"""
# Find the latest checkpoint
import glob
import re
checkpoints = sorted(glob.glob(str(MODEL_DIR / 'checkpoint_phase1_epoch*.keras')))
if checkpoints:
    latest_checkpoint = checkpoints[-1]
    print(f"🔄 Resuming from: {latest_checkpoint}")
    
    # Extract the epoch number (last completed epoch)
    match = re.search(r'epoch(\d+)', latest_checkpoint)
    start_epoch = int(match.group(1)) if match else 0
    
    # Load the model; keep a handle on the EfficientNet backbone for Phase 2
    model = keras.models.load_model(latest_checkpoint)
    base_model = model.layers[0]
    print(f"   Model loaded, resuming at epoch {start_epoch + 1}")
    
    # Continue with model.fit(..., epochs=EPOCHS_PHASE1, initial_epoch=start_epoch)
    EPOCHS_PHASE1 = 10
    remaining_epochs = EPOCHS_PHASE1 - start_epoch
    print(f"   Remaining epochs: {remaining_epochs}")
else:
    print("No checkpoint found, starting from scratch")
"""
print("💡 To resume an interrupted run, remove the triple quotes in this cell")
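The resume logic relies on the `{epoch:02d}` file names, where the number is the last completed epoch; zero-padding keeps name order equal to epoch order up to 99 epochs, but parsing the number avoids depending on that. A stdlib sketch of the lookup (`find_latest_checkpoint` is a hypothetical helper), with Keras' `initial_epoch` argument to `fit` doing the actual continuation:

```python
import re
from pathlib import Path

def find_latest_checkpoint(model_dir, phase=1):
    """Return (path, last_completed_epoch) for the newest checkpoint_phase{phase}_epochNN.keras."""
    pattern = re.compile(rf"checkpoint_phase{phase}_epoch(\d+)\.keras$")
    found = []
    for path in Path(model_dir).glob(f"checkpoint_phase{phase}_epoch*.keras"):
        match = pattern.search(path.name)
        if match:
            found.append((int(match.group(1)), path))
    if not found:
        return None, 0
    epoch, path = max(found)  # compare by epoch number, not by file name
    return path, epoch

# Resuming (sketch): Keras counts epochs from initial_epoch, so only the missing ones run
# path, done = find_latest_checkpoint(MODEL_DIR)
# model = keras.models.load_model(path)
# model.fit(train_generator, validation_data=val_generator,
#           epochs=EPOCHS_PHASE1, initial_epoch=done, callbacks=callbacks_phase1, verbose=0)
```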

## 6. Training - Phase 2 (fine-tuning)

In [None]:
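# Unfreeze the top of EfficientNet-B0 (it stayed frozen during Phase 1).
# Layers before FINE_TUNE_AT stay frozen, and BatchNormalization layers stay frozen
# (inference mode), as the Keras EfficientNet fine-tuning guide recommends.
FINE_TUNE_AT = 100
base_model.trainable = True
for layer in base_model.layers[:FINE_TUNE_AT]:
    layer.trainable = False
for layer in base_model.layers[FINE_TUNE_AT:]:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

# Recompile with a much lower learning rate (required for `trainable` changes to take effect)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-5),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Callbacks for Phase 2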
checkpoint_phase2 = keras.callbacks.ModelCheckpoint(
    filepath=str(MODEL_DIR / 'checkpoint_phase2_epoch{epoch:02d}.keras'),
    monitor='val_accuracy',
    save_best_only=False,
    verbose=0
)

best_model_phase2 = keras.callbacks.ModelCheckpoint(
    filepath=str(MODEL_DIR / 'best_model_phase2.keras'),
    monitor='val_accuracy',
    save_best_only=True,
    verbose=0
)

early_stopping_phase2 = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1
)

reduce_lr_phase2 = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=3,
    min_lr=1e-8,
    verbose=1
)
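`ReduceLROnPlateau` multiplies the learning rate by `factor` each time `val_loss` fails to improve for `patience` epochs, and never goes below `min_lr`. With the Phase 2 settings (start 1e-5, factor 0.2, floor 1e-8), the possible successive values can be listed with a few lines of Python; this only enumerates the values, it does not predict when plateaus happen (`plateau_lr_steps` is a hypothetical helper).

```python
def plateau_lr_steps(initial_lr, factor, min_lr, max_steps=10):
    """Successive learning rates produced by repeated ReduceLROnPlateau reductions."""
    rates = [initial_lr]
    while len(rates) <= max_steps and rates[-1] > min_lr:
        rates.append(max(rates[-1] * factor, min_lr))  # clamp at the floor
    return rates

print([f"{lr:.1e}" for lr in plateau_lr_steps(1e-5, 0.2, 1e-8)])
# ['1.0e-05', '2.0e-06', '4.0e-07', '8.0e-08', '1.6e-08', '1.0e-08']
# Phase 1 settings for comparison: plateau_lr_steps(1e-3, 0.2, 1e-7)
```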

In [None]:
# Phase 2: fine-tuning
EPOCHS_PHASE2 = 15

# Create the training monitor
monitor_phase2 = TrainingMonitor(EPOCHS_PHASE2, "🌿 Phase 2: Fine-Tuning")

# Callback list
callbacks_phase2 = [
    monitor_phase2,
    checkpoint_phase2,
    best_model_phase2,
    early_stopping_phase2,
    reduce_lr_phase2
]

print("🚀 Starting Phase 2 (fine-tuning)...")
print(f"   Epochs: {EPOCHS_PHASE2}")
print("   Learning rate: 1e-5 (reduced)")
print("")

history2 = model.fit(
    train_generator,
    epochs=EPOCHS_PHASE2,
    validation_data=val_generator,
    callbacks=callbacks_phase2,
    verbose=0
)

## 7. Evaluation and saving

In [None]:
# ============================================================
# FINAL PLOTS AND OVERFITTING ANALYSIS
# ============================================================

def plot_final_history(history1, history2, save_path):
    """Full training plot with overfitting analysis."""
    
    # Combine the histories
    acc = history1.history['accuracy'] + history2.history['accuracy']
    val_acc = history1.history['val_accuracy'] + history2.history['val_accuracy']
    loss = history1.history['loss'] + history2.history['loss']
    val_loss = history1.history['val_loss'] + history2.history['val_loss']
    
    epochs = range(1, len(acc) + 1)
    phase1_end = len(history1.history['accuracy'])
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # 1. Loss
    axes[0, 0].plot(epochs, loss, 'b-', label='Train', linewidth=2)
    axes[0, 0].plot(epochs, val_loss, 'r-', label='Validation', linewidth=2)
    axes[0, 0].axvline(x=phase1_end, color='green', linestyle='--', alpha=0.7, label='Fine-tuning start')
    axes[0, 0].fill_between(epochs, loss, val_loss, alpha=0.2, color='orange')
    axes[0, 0].set_title('Loss Evolution', fontsize=14, fontweight='bold')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # 2. Accuracy
    axes[0, 1].plot(epochs, [a*100 for a in acc], 'b-', label='Train', linewidth=2)
    axes[0, 1].plot(epochs, [a*100 for a in val_acc], 'r-', label='Validation', linewidth=2)
    axes[0, 1].axvline(x=phase1_end, color='green', linestyle='--', alpha=0.7, label='Fine-tuning start')
    axes[0, 1].set_title('Accuracy Evolution', fontsize=14, fontweight='bold')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    axes[0, 1].set_ylim([0, 100])
    
    # 3. Overfitting Gap (Train - Val)
    gap_acc = [t - v for t, v in zip(acc, val_acc)]
    gap_loss = [v - t for t, v in zip(loss, val_loss)]
    
    gap_pct = [g*100 for g in gap_acc]
    axes[1, 0].plot(epochs, gap_pct, 'purple', linewidth=2, label='Accuracy Gap')
    axes[1, 0].axhline(y=0, color='gray', linestyle='-', alpha=0.5)
    axes[1, 0].axhline(y=5, color='orange', linestyle='--', alpha=0.7, label='Warning threshold (5%)')
    axes[1, 0].axhline(y=10, color='red', linestyle='--', alpha=0.7, label='Danger threshold (10%)')
    # fill_between takes a single color, so shade each zone with its own `where` mask
    for low, high, color in [(-np.inf, 5, 'green'), (5, 10, 'orange'), (10, np.inf, 'red')]:
        axes[1, 0].fill_between(epochs, 0, gap_pct, alpha=0.3, color=color,
                                where=[low <= g < high for g in gap_pct])
    axes[1, 0].set_title('Overfitting Detection (Train - Val Accuracy)', fontsize=14, fontweight='bold')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Gap (%)')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # 4. Learning Rate Effect
    axes[1, 1].plot(epochs, val_acc, 'r-', linewidth=2, label='Val Accuracy')
    axes[1, 1].axvline(x=phase1_end, color='green', linestyle='--', alpha=0.7, label='LR: 0.001 → 0.00001')
    
    # Mark the best epoch ('*' is matplotlib's star marker)
    best_epoch = np.argmax(val_acc) + 1
    best_val_acc = max(val_acc)
    axes[1, 1].scatter([best_epoch], [best_val_acc], color='gold', s=200, zorder=5, 
                       marker='*', label=f'Best: {best_val_acc*100:.2f}% (epoch {best_epoch})')
    
    axes[1, 1].set_title('Best Model Selection', fontsize=14, fontweight='bold')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Validation Accuracy')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
    
    # Overfitting report
    print("\n" + "="*60)
    print("📊 ANALYSIS REPORT")
    print("="*60)
    
    final_gap = gap_acc[-1] * 100
    if final_gap < 5:
        status = "✅ EXCELLENT"
        msg = "No overfitting detected"
    elif final_gap < 10:
        status = "⚠️ WARNING"
        msg = "Slight overfitting, consider more regularization"
    else:
        status = "❌ OVERFITTING"
        msg = "Significant overfitting, use early stopping or more data"
    
    print(f"\n{status}: {msg}")
    print(f"   Final gap (train-val): {final_gap:.2f}%")
    print(f"   Best val_accuracy: {max(val_acc)*100:.2f}% (epoch {best_epoch})")
    print(f"   Final val_loss: {val_loss[-1]:.4f}")
    print("="*60)

# Generate the plots
plot_final_history(history1, history2, MODEL_DIR / 'training_history_complete.png')
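The epoch axis in these plots runs across both phases, while the per-epoch checkpoint files restart at 1 in Phase 2 (its `fit` call starts again from epoch 0). A small sketch (hypothetical helper names, illustrative accuracy values) translates the best overall epoch back to the checkpoint holding those weights; with the notebook objects, `phase1_epochs` should be `len(history1.history['accuracy'])`, which is below 10 if early stopping ended Phase 1.

```python
def best_epoch(val_acc):
    """1-based index of the highest validation accuracy (first one on ties)."""
    return max(range(len(val_acc)), key=lambda i: val_acc[i]) + 1

def checkpoint_for_epoch(global_epoch, phase1_epochs):
    """Map a 1-based epoch on the combined plot to its checkpoint file name."""
    if global_epoch <= phase1_epochs:
        return f"checkpoint_phase1_epoch{global_epoch:02d}.keras"
    return f"checkpoint_phase2_epoch{global_epoch - phase1_epochs:02d}.keras"

demo_val_acc = [0.80, 0.85, 0.88, 0.90, 0.93, 0.92]  # illustrative values only
print(checkpoint_for_epoch(best_epoch(demo_val_acc), phase1_epochs=3))
# checkpoint_phase2_epoch02.keras
```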

In [None]:
# Plot the training history (simpler two-panel version)
def plot_history(history1, history2):
    # Combiner les historiques
    acc = history1.history['accuracy'] + history2.history['accuracy']
    val_acc = history1.history['val_accuracy'] + history2.history['val_accuracy']
    loss = history1.history['loss'] + history2.history['loss']
    val_loss = history1.history['val_loss'] + history2.history['val_loss']
    
    epochs = range(1, len(acc) + 1)
    phase1_end = len(history1.history['accuracy'])
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Accuracy
    ax1.plot(epochs, acc, 'b-', label='Training')
    ax1.plot(epochs, val_acc, 'r-', label='Validation')
    ax1.axvline(x=phase1_end, color='g', linestyle='--', label='Fine-tuning start')
    ax1.set_title('Accuracy')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy')
    ax1.legend()
    ax1.grid(True)
    
    # Loss
    ax2.plot(epochs, loss, 'b-', label='Training')
    ax2.plot(epochs, val_loss, 'r-', label='Validation')
    ax2.axvline(x=phase1_end, color='g', linestyle='--', label='Fine-tuning start')
    ax2.set_title('Loss')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.legend()
    ax2.grid(True)
    
    plt.tight_layout()
    plt.savefig(str(MODEL_DIR / 'training_history.png'), dpi=150)
    plt.show()

plot_history(history1, history2)

In [None]:
# Save the final model
final_model_path = MODEL_DIR / 'efficientnet_plant_disease.keras'
model.save(final_model_path)
print(f"\nModel saved: {final_model_path}")

# Also save in .h5 format (compatibility)
h5_model_path = MODEL_DIR / 'efficientnet_plant_disease.h5'
model.save(h5_model_path)
print(f"H5 model saved: {h5_model_path}")

## 8. Testing the model

In [None]:
from tensorflow.keras.preprocessing import image

def predict_image(model, img_path, class_labels):
    """
    Predict the class of a single image.
    
    class_labels may have int keys (in-memory mapping) or str keys (read from
    class_labels.json).
    """
    # Load the image and apply exactly the validation generator's preprocessing
    img = image.load_img(img_path, target_size=(IMG_SIZE, IMG_SIZE))
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = val_datagen.standardize(img_array)
    
    # Prediction
    predictions = model.predict(img_array, verbose=0)
    predicted_class = int(np.argmax(predictions[0]))
    confidence = predictions[0][predicted_class]
    
    # Look up the class name (int or str key)
    class_name = class_labels.get(predicted_class, class_labels.get(str(predicted_class)))
    
    return class_name, confidence, predictions[0]

# Test with an image from the dataset
# test_img = list(DATASET_DIR.glob('*/*.jpg'))[0]
# class_name, confidence, _ = predict_image(model, test_img, class_labels)
# print(f"Image: {test_img.name}")
# print(f"Prediction: {class_name}")
# print(f"Confidence: {confidence:.2%}")
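`predict_image` returns the full probability vector as its third value. Listing the top few classes instead of only the argmax can help when two diseases of the same plant look alike. A plain-Python sketch (`top_k` is a hypothetical helper; like `predict_image`, it accepts int or str label keys):

```python
def top_k(probabilities, class_labels, k=3):
    """Return the k most likely (class name, probability) pairs, highest first."""
    ranked = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)
    return [
        (class_labels.get(i, class_labels.get(str(i))), float(probabilities[i]))
        for i in ranked[:k]
    ]

# Usage with the notebook objects:
# _, _, probs = predict_image(model, test_img, class_labels)
# for name, p in top_k(probs, class_labels):
#     print(f"{name}: {p:.2%}")
```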

## 9. Summary

### Generated files:
- `models/efficientnet_plant_disease.keras` - final model
- `models/efficientnet_plant_disease.h5` - H5 format
- `data/class_labels.json` - index → class name mapping
- `models/training_history.png` - training curves
- `models/training_history_complete.png` - training curves with overfitting analysis
- `models/best_model_phase1.keras`, `models/best_model_phase2.keras` - best model of each phase

### Next steps:
1. Copy the model into `C:/TFE-4/models/`
2. Create the `disease_info.json` file with the disease descriptions
3. Implement the model-loading service (Epic 2)