In [None]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt

from tensorflow import float32
from tensorflow.image import resize, rgb_to_grayscale, convert_image_dtype
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Activation
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, CSVLogger

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

from config import IMG_HEIGHT, IMG_WIDTH
from src.modules import data_loader
from src.face_recognition import ml_models


################## PATH ##################
# Built with os.path.join so the paths resolve on any OS
# (backslash-separated literals only work on Windows).
DATA_DIR = os.path.join("..", "data")
DB_PATH = os.path.join(DATA_DIR, "gui_database.db")
LFW_DATASET_PATH = os.path.join(DATA_DIR, "dataset-lfw_reconstructed")
ML_OUTPUT = os.path.join(DATA_DIR, "ml_models")
# Holds the checkpoint, CSV training log, label encoder and training curves.
MODEL_SAVE_DIR = os.path.join(ML_OUTPUT, "trained")
# TensorBoard runs (one timestamped sub-folder per run).
LOG_DIR = os.path.join(ML_OUTPUT, "logs")


############# MODEL SETTINGS #############
#####--- prepare_data_train_model ---#####
# Channels-last (height, width, channels), the layout used by tf.image.resize / Keras.
# NOTE: these IMG_HEIGHT / IMG_WIDTH override the values imported from config.
INPUT_SHAPE = (100, 100, 1)
IMG_HEIGHT, IMG_WIDTH, CHANNELS = INPUT_SHAPE
SPLIT_STRATEGY = 'stratified'  # 'stratified' or 'fixed_per_subject'
TEST_SPLIT_RATIO = 0.2
VALIDATION_SPLIT_RATIO = 0.15
RANDOM_STATE = 42
N_TRAIN_PER_SUBJECT = 7  # passed as n_train_per_class ('fixed_per_subject' only)
#####--- create_model ---#####
MODEL_NAME = 'simple_cnn_lfw_anony_v1'
MODEL_ARCHITECTURE = 'simple_cnn'  # 'simple_cnn' or any 'transfer_*' name
LEARNING_RATE = 0.001
EARLY_STOPPING_PATIENCE = 10  # epochs; 0 or None disables early stopping
TRANSFER_BASE_MODEL_NAME = 'MobileNetV2'  # only used by 'transfer_*' architectures
TRANSFER_FREEZE_BASE = True
#####--- train_model ---#####
BATCH_SIZE = 32
EPOCHS = 50
##########################################

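# Import Data

Run **one** of the two loading cells below. Both define `X`, `y` and `label_encoder`, so whichever runs last provides the training dataset.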

### Noised DB dataset

In [None]:
import controller.ml_controller as mlc

# Noised images stored in the GUI database at DB_PATH. Yields the same
# (images, labels, label_encoder) triple as the LFW loading cell.
db_dataset = mlc.MLController.get_data_from_db(DB_PATH)
X, y, label_encoder = db_dataset
print("(nb_image, width, height, channels) : {}".format(X.shape))

### Noised LFW dataset

In [None]:
from src.modules.data_loader import load_anonymized_images_flat

# The images are read from this folder, so a missing folder is an error.
# Creating it here would only produce an empty dataset.
if not os.path.isdir(LFW_DATASET_PATH):
    raise FileNotFoundError(f"Dossier de données introuvable : {LFW_DATASET_PATH}")

X, y, label_encoder = load_anonymized_images_flat(
    data_dir=LFW_DATASET_PATH,
    img_width=IMG_WIDTH,
    img_height=IMG_HEIGHT,
    color_mode='grayscale'
)

# Any missing piece, or an empty image set, means loading failed.
# The checks are ORed: one failure is enough. The `is None` tests come first
# so `len()` is never called on None.
if X is None or y is None or label_encoder is None or len(X) == 0:
    raise ValueError('Erreur critique lors du chargement des données. Arrêt du script.')
print(f"\n(nb_image, width, height, channels) : {X.shape}")

# Train

In [None]:
##### prepare_data_train_model #####

# --- 2. Préparation des Données ---

def format_ml_image(image: np.ndarray) -> np.ndarray:
    """Convert one image array to the single-channel float32 model input format.

    Accepted inputs: a 2-D grayscale array, or a 3-D array with 1 (grayscale),
    3 (RGB) or 4 (RGBA; the alpha channel is discarded) channels.

    Returns:
        A float32 numpy array of shape (IMG_HEIGHT, IMG_WIDTH, 1). Integer
        inputs are rescaled to [0, 1] by ``convert_image_dtype``. Float inputs
        are only cast, so they are expected to already be in [0, 1].

    Raises:
        ValueError: if the array rank or the channel count is unsupported.
    """
    image = np.asarray(image)
    if image.ndim == 2:
        # Plain (H, W) grayscale: add the channel axis.
        image = image[..., np.newaxis]
    elif image.ndim != 3:
        raise ValueError(f"Image must be 2-D or 3-D, got shape {image.shape}.")

    n_channels = image.shape[-1]
    if n_channels == 4:
        # RGBA (e.g. a PNG opened with PIL): drop alpha, then treat as RGB.
        image = image[..., :3]
        n_channels = 3
    if n_channels == 3:
        image = rgb_to_grayscale(image).numpy()
    elif n_channels != 1:
        raise ValueError("Image must be grayscale or RGB.")

    # Normalize the dtype, then resize. tf.image.resize expects [new_height, new_width].
    image = convert_image_dtype(image, dtype=float32)
    image = resize(image, [IMG_HEIGHT, IMG_WIDTH], method="area").numpy()
    return image

# Bring every image to the model format (single channel, resized, float32)
# and stack them back into one array.
X = np.array([format_ml_image(img) for img in X])

num_classes = len(label_encoder.classes_)
print(f"Nombre de classes détectées : {num_classes}")



# --- 3. Data split ---
print("\n--- Division des données ---")
X_train, y_train = None, None
X_val, y_val = None, None
X_test, y_test = None, None

if SPLIT_STRATEGY == 'stratified':
    data_splits = data_loader.split_data_stratified(
        X, y,
        test_size=TEST_SPLIT_RATIO,
        validation_size=VALIDATION_SPLIT_RATIO,
        random_state=RANDOM_STATE
    )
    X_train = data_splits.get('X_train')
    y_train = data_splits.get('y_train')
    X_val = data_splits.get('X_val')
    y_val = data_splits.get('y_val')
    X_test = data_splits.get('X_test')
    y_test = data_splits.get('y_test')

    # Fallback when the helper returned no validation split: carve it out of
    # the training split. VALIDATION_SPLIT_RATIO is treated as a fraction of
    # the whole dataset, hence the rescaling by the share left after the test split.
    if X_val is None and VALIDATION_SPLIT_RATIO > 0 and X_train is not None and len(X_train) > 0:
        remaining_ratio = 1.0 - TEST_SPLIT_RATIO
        # Guard against TEST_SPLIT_RATIO == 1 (division by zero).
        val_ratio_from_train = (
            VALIDATION_SPLIT_RATIO / remaining_ratio if remaining_ratio > 0 else float('inf')
        )
        if val_ratio_from_train < 1.0:
            print(f"Création du set de validation depuis l'entraînement (ratio: {val_ratio_from_train:.2f})")
            X_train, X_val, y_train, y_val = train_test_split(
                X_train, y_train,
                test_size=val_ratio_from_train,
                random_state=RANDOM_STATE,
                stratify=y_train
            )
        else:
            print("Attention: Ratios de split incohérents, pas de données d'entraînement restantes après validation.")

elif SPLIT_STRATEGY == 'fixed_per_subject':
    X_train_full, X_test, y_train_full, y_test = data_loader.split_data_fixed_per_subject(
        X, y,
        n_train_per_class=N_TRAIN_PER_SUBJECT,
        random_state=RANDOM_STATE
    )
    # NOTE(review): in this branch VALIDATION_SPLIT_RATIO is a fraction of the
    # training images, not of the whole dataset as in the stratified branch.
    if VALIDATION_SPLIT_RATIO > 0 and X_train_full is not None and len(X_train_full) > 0:
        print(f"Création du set de validation depuis l'entraînement (ratio: {VALIDATION_SPLIT_RATIO})")
        X_train, X_val, y_train, y_val = train_test_split(
            X_train_full, y_train_full,
            test_size=VALIDATION_SPLIT_RATIO,
            random_state=RANDOM_STATE,
            stratify=y_train_full
        )
    else:
        X_train, y_train = X_train_full, y_train_full
        X_val, y_val = None, None

else:
    # Stop here: continuing would only fail later inside model.fit().
    raise ValueError(f"Erreur: Stratégie de split '{SPLIT_STRATEGY}' non reconnue.")

if X_train is None or len(X_train) == 0:
    raise ValueError("Erreur: Aucune donnée d'entraînement disponible après la division.")

if X_val is None or len(X_val) == 0:
    print("Attention: Aucune donnée de validation disponible. L'entraînement se fera sans suivi de validation.")
    validation_data = None  # model.fit() then runs without validation
else:
    validation_data = (X_val, y_val)

# Report the final sizes in every case, not only when a validation set exists.
n_val = 0 if X_val is None else len(X_val)
n_test = 0 if X_test is None else len(X_test)
print(f"Taille finale - Entraînement: {len(X_train)}, Validation: {n_val}, Test: {n_test}")

In [None]:
print("--- Démarrage du Script d'Entraînement ---")
start_time = time.time()  # reference point for the duration printed after training

# --- 1. Configuration summary ---
print("Configuration chargée depuis config.py:")
settings_to_show = [
    ("Dossier Données", LFW_DATASET_PATH),
    ("Dossier Sauvegarde Modèles", MODEL_SAVE_DIR),
    ("Architecture Modèle", MODEL_ARCHITECTURE),
    ("Nom Modèle", MODEL_NAME),
    ("Dimensions Image", INPUT_SHAPE),
    ("Stratégie Split", SPLIT_STRATEGY),
    ("Époques", f"{EPOCHS}, Batch Size: {BATCH_SIZE}"),
]
for setting_name, setting_value in settings_to_show:
    print(f"  - {setting_name}: {setting_value}")

# Output folders: checkpoints / CSV logs, then TensorBoard runs.
for output_dir in (MODEL_SAVE_DIR, LOG_DIR):
    os.makedirs(output_dir, exist_ok=True)
print(f"  - Dossier Logs TensorBoard: {LOG_DIR}")


In [None]:
# --- 4. Model construction ---
print("\n--- Construction du modèle ---")
model = None
if MODEL_ARCHITECTURE == 'simple_cnn':
    model = ml_models.build_simple_cnn(input_shape=INPUT_SHAPE, num_classes=num_classes)
elif MODEL_ARCHITECTURE.startswith('transfer_'):
    print(f"Utilisation du modèle de base: {TRANSFER_BASE_MODEL_NAME}, Freeze: {TRANSFER_FREEZE_BASE}")
    model = ml_models.build_transfer_model(input_shape=INPUT_SHAPE,
                                           num_classes=num_classes,
                                           base_model_name=TRANSFER_BASE_MODEL_NAME,
                                           freeze_base=TRANSFER_FREEZE_BASE)
else:
    # Stop here: continuing would crash on `model.compile` with an opaque
    # AttributeError on None.
    raise ValueError(f"Erreur: Architecture de modèle non reconnue dans config: {MODEL_ARCHITECTURE}")

if model is None:
    raise RuntimeError("Erreur critique lors de la construction du modèle. Arrêt.")

# --- 5. Compilation ---
print("\n--- Compilation du modèle ---")
optimizer = Adam(learning_rate=LEARNING_RATE)
# Sparse categorical cross-entropy takes integer class indices (no one-hot).
# Assumes y holds the integer codes produced by label_encoder; TODO confirm.
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
print("Modèle compilé avec Adam optimizer.")
model.summary()

In [None]:
# --- 6. Callbacks ---
print("\n--- Configuration des Callbacks ---")
callbacks = []

# Without validation data Keras never logs 'val_accuracy'. Monitoring it would
# then only emit warnings: no checkpoint is saved and early stopping never
# fires. Fall back to the training accuracy in that case.
monitor_metric = 'val_accuracy' if validation_data is not None else 'accuracy'
print(f"  - Métrique surveillée: {monitor_metric}")

model_filename = f"{MODEL_NAME}.h5"
model_filepath = os.path.join(MODEL_SAVE_DIR, model_filename)
print(f"  - ModelCheckpoint: Sauvegarde du meilleur modèle dans {model_filepath}")
checkpoint_callback = ModelCheckpoint(
    filepath=model_filepath,
    monitor=monitor_metric,
    save_best_only=True,
    save_weights_only=False,  # full model, reloaded with load_model() for prediction
    mode='max',
    verbose=1
)
callbacks.append(checkpoint_callback)

# Also accepts None: `0`/`None` both disable early stopping.
if EARLY_STOPPING_PATIENCE and EARLY_STOPPING_PATIENCE > 0:
    print(f"  - EarlyStopping: Activé avec patience={EARLY_STOPPING_PATIENCE}")
    early_stopping_callback = EarlyStopping(
        monitor=monitor_metric,
        patience=EARLY_STOPPING_PATIENCE,
        mode='max',
        restore_best_weights=True,
        verbose=1
    )
    callbacks.append(early_stopping_callback)
else:
    print("  - EarlyStopping: Désactivé.")


if LOG_DIR:
    # One timestamped sub-folder per run so runs do not overwrite each other.
    tensorboard_log_dir = os.path.join(LOG_DIR, MODEL_NAME + "_" + time.strftime("%Y%m%d-%H%M%S"))
    print(f"  - TensorBoard: Logs dans {tensorboard_log_dir}")
    tensorboard_callback = TensorBoard(
        log_dir=tensorboard_log_dir,
        histogram_freq=1
    )
    callbacks.append(tensorboard_callback)
else:
    print("  - TensorBoard: Désactivé.")

csv_log_path = os.path.join(MODEL_SAVE_DIR, f"{MODEL_NAME}_training_log.csv")
print(f"  - CSVLogger: Logs dans {csv_log_path}")
csv_logger_callback = CSVLogger(csv_log_path, append=False)
callbacks.append(csv_logger_callback)

In [8]:
# --- 7. Training ---
print("\n--- Démarrage de l'entraînement ---")
history = None
encoder_save_path = os.path.join(MODEL_SAVE_DIR, f"{MODEL_NAME}_label_encoder.joblib")


def _save_encoder_after_failed_training():
    """Persist the label encoder even though training did not complete."""
    print("\nSauvegarde de l'encodeur de labels (même si l'entraînement a échoué)...")
    data_loader.save_label_encoder(label_encoder, encoder_save_path)


try:
    history = model.fit(
        X_train, y_train,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_data=validation_data,
        callbacks=callbacks,
        verbose=1
    )
    print("--- Entraînement terminé ---")

except KeyboardInterrupt:
    # KeyboardInterrupt is not an Exception subclass, so it previously skipped
    # the encoder save. Save it, then re-raise so the interrupt still stops
    # the remaining cells.
    print("\nEntraînement interrompu par l'utilisateur.")
    _save_encoder_after_failed_training()
    raise

except Exception as e:
    # Best effort: report the error and keep the encoder matching the saved checkpoint.
    print(f"\nErreur pendant l'entraînement : {e}")
    _save_encoder_after_failed_training()

KeyboardInterrupt: 

In [None]:
# --- 8. Post-training ---
# Save the label encoder next to the model so predictions can be decoded later.
encoder_save_path = os.path.join(MODEL_SAVE_DIR, f"{MODEL_NAME}_label_encoder.joblib")
print("\n--- Sauvegarde de l'encodeur de labels ---")
data_loader.save_label_encoder(label_encoder, encoder_save_path)

if history is not None:
    print("\n--- Affichage des courbes d'apprentissage ---")
    try:
        logged = history.history
        acc = logged['accuracy']
        loss = logged['loss']
        epochs_range = range(len(acc))

        # One entry per subplot:
        # (index, train values, validation key, train label, validation label,
        #  title, y-axis label, legend position)
        curve_panels = (
            (1, acc, 'val_accuracy', 'Training Accuracy', 'Validation Accuracy',
             'Training and Validation Accuracy', 'Accuracy', 'lower right'),
            (2, loss, 'val_loss', 'Training Loss', 'Validation Loss',
             'Training and Validation Loss', 'Loss', 'upper right'),
        )

        plt.figure(figsize=(12, 5))
        for (panel_idx, train_curve, val_key, train_label, val_label,
             panel_title, y_label, legend_loc) in curve_panels:
            plt.subplot(1, 2, panel_idx)
            plt.plot(epochs_range, train_curve, label=train_label)
            if validation_data:  # validation curves exist only when a validation set was used
                plt.plot(epochs_range, logged[val_key], label=val_label)
            plt.legend(loc=legend_loc)
            plt.title(panel_title)
            plt.xlabel('Epochs')
            plt.ylabel(y_label)

        plot_save_path = os.path.join(MODEL_SAVE_DIR, f"{MODEL_NAME}_training_curves.pdf")
        plt.savefig(plot_save_path, format='pdf', bbox_inches='tight')
        print(f"Courbes sauvegardées dans : {plot_save_path}")

    except Exception as plot_e:
        print(f"Erreur lors de la génération/sauvegarde des courbes: {plot_e}")


end_time = time.time()
duration = end_time - start_time
print(f"\n--- Script d'Entraînement Terminé en {duration:.2f} secondes ---")
print(f"Le meilleur modèle devrait être sauvegardé dans : {model_filepath}")
print(f"L'encodeur de labels est sauvegardé dans : {encoder_save_path}")

In [None]:
# Evaluation on the held-out test split
if X_test is None or len(X_test) == 0:
    # e.g. TEST_SPLIT_RATIO == 0: nothing to evaluate, do not crash the notebook.
    print("Attention: Aucune donnée de test disponible, évaluation ignorée.")
else:
    eval_loss, eval_acc = model.evaluate(X_test, y_test)
    y_pred = np.argmax(model.predict(X_test), axis=1)
    # Not printed here; inspect `cm` interactively if needed.
    cm = confusion_matrix(y_test, y_pred)
    # zero_division=0: classes that are never predicted get a 0 score instead
    # of an UndefinedMetricWarning per class.
    report = classification_report(y_test, y_pred, zero_division=0)
    print(f"Evaluation Loss: {eval_loss:.4f}")
    print(f"Evaluation Accuracy: {eval_acc*100:.2f}%")
    print(report)

# Predict a noised image

In [None]:
import os
import numpy as np
import time
from typing import Optional
from PIL import Image

# Subject id and image index of the reconstructed image to classify.
user = 9
image_index = 2
# Reuse the configured dataset folder instead of a second hard-coded path.
image_path = os.path.join(LFW_DATASET_PATH, f"reconstructed_{user}_{image_index}.png")

In [None]:
print("--- Démarrage du Script de Prédiction ---")
start_time = time.time()  # reference point for the prediction duration

# --- 1. Configuration & paths ---
print("Chargement de la configuration...")
model_filename = f"{MODEL_NAME}.h5"  # (or .keras)
encoder_filename = f"{MODEL_NAME}_label_encoder.joblib"
model_filepath, encoder_filepath = (
    os.path.join(MODEL_SAVE_DIR, file_name) for file_name in (model_filename, encoder_filename)
)

for label, value in (
    ("Modèle utilisé", model_filepath),
    ("Encodeur utilisé", encoder_filepath),
    ("Image à prédire", image_path),
):
    print(f"  - {label}: {value}")

In [None]:
# --- 2. Load model and encoder ---
print("\n--- Chargement du modèle et de l'encodeur ---")
# Fail loudly on any loading error. Otherwise `model` / `label_encoder`
# silently keep whatever is still in memory from the training cells.
if not os.path.exists(model_filepath):
    raise FileNotFoundError(f"Erreur: Fichier modèle non trouvé: {model_filepath}")
try:
    model = load_model(model_filepath)
    print("Modèle chargé avec succès.")
except Exception as e:
    raise RuntimeError(f"Erreur lors du chargement du modèle Keras: {e}") from e

# Load the label encoder saved next to the model.
# NOTE(review): a .joblib file is unpickled on load; only load files this notebook produced.
label_encoder = data_loader.load_label_encoder(encoder_filepath)
if label_encoder is None:
    raise RuntimeError("Erreur critique : Impossible de charger l'encodeur de labels.")

In [None]:
# --- 3. Prétraiter l'Image d'Entrée ---

def preprocess_single_image(
    image_path: str,
    img_width: int,
    img_height: int,
) -> Optional[np.ndarray]:
    """
    Load, resize, normalize and batch a single image for prediction.

    Args:
        image_path: path of the image file, opened with PIL.
        img_width: target width in pixels.
        img_height: target height in pixels.

    Returns:
        A float32 array of shape (1, img_height, img_width, 1), ready for
        ``model.predict``. Returns None if the file is missing or cannot be
        processed; the error is printed.
    """
    try:
        # Context manager closes the file handle PIL keeps open.
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image)
        # Same preprocessing as the training set -> (H, W, 1).
        image = format_ml_image(image)
        # format_ml_image resizes to the notebook-wide size; honour the
        # requested size when it differs.
        if image.shape[:2] != (img_height, img_width):
            image = resize(image, [img_height, img_width], method="area").numpy()
        # Add only the batch axis. The channel axis already exists, and adding
        # another one would give a 5-D tensor the model rejects.
        image = np.expand_dims(image, axis=0)
        print(f"Image prétraitée, shape final: {image.shape}")
        return image

    except FileNotFoundError:
        print(f"Erreur: Fichier image introuvable : {image_path}")
        return None
    except Exception as e:
        print(f"Erreur lors du prétraitement de l'image {image_path}: {e}")
        return None

print("\n--- Prétraitement de l'image d'entrée ---")
preprocessed_image = preprocess_single_image(
    image_path=image_path,
    img_width=IMG_WIDTH,
    img_height=IMG_HEIGHT,
)

# Stop here rather than passing None to model.predict() in the next cell.
if preprocessed_image is None:
    raise ValueError("Échec du prétraitement de l'image.")

In [None]:
# --- 4. Prediction ---
print("\n--- Prédiction ---")
# Reset the outputs first. A failed prediction must not leave an undefined
# name, or a stale label from an earlier run, for the final print.
predicted_label = None
prediction_confidence = None
try:
    prediction_probabilities = model.predict(preprocessed_image)

    # Batch of one image: row 0 holds the per-class scores.
    # NOTE(review): read as probabilities, i.e. assumes a softmax output layer.
    class_scores = prediction_probabilities[0]
    predicted_index = int(np.argmax(class_scores))
    prediction_confidence = float(class_scores[predicted_index])

    # Map the class index back to the original label.
    predicted_label = label_encoder.inverse_transform([predicted_index])[0]

    print("\n--- Résultat de la Prédiction ---")
    print(f"  - Image : {os.path.basename(image_path)}")
    print(f"  - Identité Prédite (Subject ID) : {predicted_label}")
    print(f"  - Confiance : {prediction_confidence:.4f} ({prediction_confidence*100:.2f}%)")

except Exception as e:
    print(f"Erreur lors de la prédiction: {e}")

end_time = time.time()
duration = end_time - start_time
print(f"\n--- Script de Prédiction Terminé en {duration:.2f} secondes ---")

print(predicted_label)