# Notebook 3 - Entraînement Unet + VGG16
- notebooks/03_model_training_unet_vgg16.ipynb
# 1 - Importation des librairies

In [1]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import joblib
from dotenv import load_dotenv
load_dotenv()

False

## 1.1 - Chemin racine du projet pour les imports relatifs

In [2]:
# Resolve the project root. Path("..") is relative to the kernel's working
# directory, so this assumes the notebook is run from the notebooks/ folder.
project_root = Path("..").resolve()

# Make src/ importable. Guard against duplicate sys.path entries so the cell
# stays idempotent when it is re-executed.
src_path = project_root / "src"
if str(src_path) not in sys.path:
    sys.path.append(str(src_path))

# Useful project paths
data_dir      = project_root / "data"
processed_dir = data_dir / "processed"
models_dir    = project_root / "models"
outputs_dir   = project_root / "outputs"

In [3]:
# from model_training.train_unet_vgg16 import train_unet_vgg16
from utils.utils import plot_history
from utils.monitoring import monitor_resources
from model_training.train_unet_vgg16 import train_unet_vgg16




# 2 - Chargement des données

In [4]:
train_path = processed_dir / "train.npz"
val_path   = processed_dir / "val.npz"

# np.load on an .npz archive returns an NpzFile that keeps the underlying file
# open; reading the arrays inside a context manager releases the handle once
# the arrays have been materialised in memory.
with np.load(train_path) as train:
    X_train, y_train = train["X"], train["Y"]
with np.load(val_path) as val:
    X_val, y_val = val["X"], val["Y"]

# Sanity checks: inputs and targets must have the same number of samples.
assert len(X_train) == len(y_train), f"train: {len(X_train)} X samples vs {len(y_train)} Y samples"
assert len(X_val) == len(y_val), f"val: {len(X_val)} X samples vs {len(y_val)} Y samples"

## 2.1 - Paramètres d'entraînement (test rapide ou production, selon le contexte)
### 2.1.1 - Taille réduite pour test rapide

In [5]:
# params_test = {
#     'output_dir': str(models_dir),
#     'model_name': "unet_vgg16_test",
#     'force_retrain': False,
#     'epochs': 10,
#     'batch_size': 4,
#     'loss_function': "sparse_categorical_crossentropy",
#     'test_mode': False
# }

### 2.1.2 - Full training pour mise en production

In [6]:
# params_prod = {
#     'output_dir': str(models_dir),
#     'model_name': "unet_vgg16",
#     'force_retrain': False,
#     'epochs': 40,
#     'batch_size': 8,
#     'loss_function': "sparse_categorical_crossentropy"
# }

In [7]:
# Hyper-parameter grid. Every run shares the same settings and only differs by
# its name, number of epochs and batch size, listed here as
# (model_name, epochs, batch_size).
_grid_runs = [
    ("unet_vgg16_a", 30, 4),
    ("unet_vgg16_b", 40, 8),
    ("unet_vgg16_c", 50, 8),
    ("unet_vgg16_d", 50, 16),
]

# Key order is kept identical for every run so the printed configurations stay
# readable and comparable.
params_grid = [
    {
        'output_dir': str(models_dir),
        'model_name': run_name,
        'force_retrain': False,
        'epochs': n_epochs,
        'batch_size': batch,
        'loss_function': "sparse_categorical_crossentropy",
        'use_early_stopping': True,
        'turbo': True,
    }
    for run_name, n_epochs, batch in _grid_runs
]

# 3 - Entraînement du modèle avec sélection du meilleur run

In [8]:
# Container for the outcome of every grid run (one dict per run).
results = list()

In [None]:
def best_val_accuracy(history) -> float:
    """Return the best (maximum) validation accuracy recorded in a history.

    Args:
        history: either an object exposing a ``.history`` dict (Keras-style
            History) or a plain dict, as returned by ``train_unet_vgg16``.

    Returns:
        The maximum of the ``'val_accuracy'`` values, or 0.0 when no such
        values are available (e.g. the quick test-mode dict).
    """
    metrics = getattr(history, "history", history)
    if not isinstance(metrics, dict):
        return 0.0
    values = metrics.get("val_accuracy")
    # len() check rather than truthiness so numpy arrays are handled too.
    if values is None or len(values) == 0:
        return 0.0
    return float(max(values))


# Rebuild the list on every execution so re-running this cell does not append
# duplicate entries for runs already collected.
results = []
for i, params in enumerate(params_grid):
    print(f"\n🔁 Entraînement {i+1}/{len(params_grid)} : {params}")
    model, history = train_unet_vgg16(
        X_train=X_train,
        y_train=y_train,
        X_val=X_val,
        y_val=y_val,
        **params
    )
    results.append({
        'run': i+1,
        'params': params,
        'val_accuracy': best_val_accuracy(history),
        'model': model,
        'history': history
    })


🔁 Entraînement 1/4 : {'output_dir': 'C:\\Users\\motar\\Desktop\\1-openclassrooms\\AI_Engineer\\1-projets\\P08\\2-python\\models', 'model_name': 'unet_vgg16_a', 'force_retrain': False, 'epochs': 30, 'batch_size': 4, 'loss_function': 'sparse_categorical_crossentropy', 'use_early_stopping': True, 'turbo': True}
🔄 Lancement du serveur MLflow local...
✅ Serveur MLflow démarré sur http://127.0.0.1:5000
[LOG] ➤ train_unet_vgg16 appelé
⚡️ Mode TURBO activé : JIT, Mixed Precision, logs réduits
The dtype policy mixed_float16 may run slowly because this machine does not have a GPU. Only Nvidia GPUs with compute capability of at least 7.0 run quickly with mixed_float16.
[INFO] ⟳ Chargement du modèle existant : C:\Users\motar\Desktop\1-openclassrooms\AI_Engineer\1-projets\P08\2-python\models\unet_vgg16_a_TURBO.h5



🔁 Entraînement 2/4 : {'output_dir': 'C:\\Users\\motar\\Desktop\\1-openclassrooms\\AI_Engineer\\1-projets\\P08\\2-python\\models', 'model_name': 'unet_vgg16_b', 'force_retrain': False, 

## 3.1 - Sélection du meilleur run

In [None]:
# Select the champion run: highest best-epoch validation accuracy.
# max() returns the first maximal element, so ties go to the earliest run.
if not results:
    raise RuntimeError(
        "Aucun résultat d'entraînement : exécuter d'abord la cellule des runs."
    )
if all(run['val_accuracy'] == 0 for run in results):
    # Every run reported 0 (e.g. quick test mode): the selection is arbitrary.
    print("[WARN] Aucune val_accuracy disponible : le premier run est retenu par défaut.")
best_run = max(results, key=lambda run: run['val_accuracy'])
best_model = best_run['model']
best_history = best_run['history']
best_params = best_run['params']

In [None]:
# Announce the champion run and its best validation accuracy.
print("\n✅ Meilleur modèle : {} avec val_accuracy = {:.4f}".format(
    best_params['model_name'], best_run['val_accuracy']
))

In [None]:
# Destination files for the champion model and its training history.
best_model_path = models_dir.joinpath("best_unet_vgg16.h5")
best_history_path = models_dir.joinpath("best_unet_vgg16_history.pkl")

In [None]:
# Persist the champion model and its per-epoch metrics.
# best_history is either an object exposing a `.history` dict or a plain dict
# (test mode, see the isinstance checks elsewhere in this notebook), so fall
# back to the object itself when there is no `.history` attribute.
best_model_path.parent.mkdir(parents=True, exist_ok=True)
best_model.save(best_model_path)
joblib.dump(getattr(best_history, "history", best_history), best_history_path)

In [None]:
# # params = params_test
# params = params_prod

# model, history = train_unet_vgg16(
#     X_train=X_train,
#     y_train=y_train,
#     X_val=X_val,
#     y_val=y_val,
#     **params
# )

# 4 - Résumé et courbes du modèle

In [None]:
# Presumably, in quick test mode the training helper returns a plain dict
# flagged with "test_mode" rather than a training history: nothing to plot.
if isinstance(best_history, dict) and best_history.get("test_mode"):
    print("[INFO] ✅ Test rapide terminé. Pas d'entraînement complet.")
else:
    print("[INFO] 📊 Résumé du modèle champion et affichage des courbes")
    best_model.summary()
    # Ensure the output folder exists before the figure is written to it.
    outputs_dir.mkdir(parents=True, exist_ok=True)
    plot_path = outputs_dir / f"plot_{best_params['model_name']}_BEST.png"
    plot_history(best_history, plot_path)

In [None]:
# if isinstance(history, dict) and history.get("test_mode"):
#     print("[INFO] ✅ Test rapide terminé. Pas d'entraînement complet.")
# else:
#     print("[INFO] 📊 Affichage du résumé du modèle et des courbes")
#     model.summary()

#     # Construction du chemin d'enregistrement depuis notebooks/
#     plot_path = Path("..") / "outputs" / f"plot_{params['model_name']}.png"
#     plot_history(history, plot_path)

### === Exemple de prédiction (facultatif pour test rapide) ===
Exemple à exécuter sur le modèle champion (`best_model`) :
```python
y_pred = best_model.predict(X_val[:1])
plt.imshow(np.argmax(y_pred[0], axis=-1))
plt.title("Exemple de prédiction")
plt.show()
```