In [None]:
import os
import time
import numpy as np
from PIL import Image

from src.controller import ml_controller as mlc2


################## PATH ##################
# All paths are relative to the notebook's working directory. They are built
# with os.path.join so the separators are correct on every OS (hard-coded
# backslashes only resolve on Windows).
DATA_DIR = os.path.join("..", "data")
LFW_DATASET_PATH = os.path.join(DATA_DIR, "dataset-lfw_reconstructed")
DB_PATH = os.path.join(DATA_DIR, "gui_database.db")
ML_OUTPUT = os.path.join(DATA_DIR, "ml_models")
MODEL_SAVE_DIR = os.path.join(ML_OUTPUT, "trained")
LOG_DIR = os.path.join(ML_OUTPUT, "logs")


############# MODEL SETTINGS #############
#####--- prepare_data_train_model ---#####
# Settings forwarded to mlc2.prepare_data_train_model.
INPUT_SHAPE = (100, 100, 1)
# Unpacked in (width, height, channels) order; the loader cell reuses
# IMG_WIDTH / IMG_HEIGHT to size the images it reads.
IMG_WIDTH, IMG_HEIGHT, CHANNELS = INPUT_SHAPE
SPLIT_STRATEGY = 'stratified'
TEST_SPLIT_RATIO = 0.2
VALIDATION_SPLIT_RATIO = 0.15
RANDOM_STATE = 42
N_TRAIN_PER_SUBJECT = 7
#####--- create_model ---#####
# Settings forwarded to mlc2.create_model.
MODEL_NAME = 'simple_cnn_lfw_anony_v1'
MODEL_ARCHITECTURE = 'simple_cnn'
LEARNING_RATE = 0.001
EARLY_STOPPING_PATIENCE = 10
# NOTE(review): presumably only relevant for a transfer-learning
# architecture, not for 'simple_cnn' -- confirm against create_model.
TRANSFER_BASE_MODEL_NAME = 'MobileNetV2'
TRANSFER_FREEZE_BASE = True
#####--- train_model ---#####
# Settings forwarded to mlc2.train_model.
BATCH_SIZE = 32
EPOCHS = 50
##########################################

# Import Data

### Noised DB dataset

In [None]:
# Load the dataset from the database at DB_PATH.
# NOTE: the LFW loading cell below rebinds X, y and label_encoder, so only
# the dataset loaded last is the one used for training.
db_dataset = mlc2.MLController.get_data_from_db(DB_PATH)
X, y, label_encoder = db_dataset
print(f"(nb_image, width, height, channels) : {X.shape}")

### Noised LFW dataset

In [None]:
from src.modules.data_loader import load_anonymized_images_flat

# Make sure the dataset directory exists before handing it to the loader.
os.makedirs(LFW_DATASET_PATH, exist_ok=True)

# Load the LFW images as grayscale, sized according to INPUT_SHAPE.
X, y, label_encoder = load_anonymized_images_flat(
    data_dir=LFW_DATASET_PATH,
    img_width=IMG_WIDTH,
    img_height=IMG_HEIGHT,
    color_mode='grayscale'
)

# Fail fast when nothing usable was loaded. Any single missing or empty piece
# is fatal, so the checks are OR-ed. The None checks come first so a loader
# that returns None yields this ValueError rather than an AttributeError, and
# emptiness is tested on the element count because the shape tuple of an
# array with zero rows, e.g. (0, 100, 100, 1), is still truthy.
load_failed = (
    X is None
    or y is None
    or label_encoder is None
    or np.size(X) == 0
    or np.size(y) == 0
)
if load_failed:
    raise ValueError('Critical error while loading data. Script stopped..')
print(f"\n(nb_image, width, height, channels) : {X.shape}")

# Train

In [None]:
# Split the loaded dataset using the settings from the configuration cell.
split_settings = {
    "input_shape": INPUT_SHAPE,
    "split_strategy": SPLIT_STRATEGY,
    "test_split_ratio": TEST_SPLIT_RATIO,
    "validation_split_ratio": VALIDATION_SPLIT_RATIO,
    "random_state": RANDOM_STATE,
    "n_train_per_subject": N_TRAIN_PER_SUBJECT,
}
res = mlc2.prepare_data_train_model(X, y, label_encoder, **split_settings)

# The controller returns everything as one 8-item sequence; unpack it into
# the names used by the following cells.
(
    num_classes,
    X_train, y_train,
    X_test, y_test,
    X_val, y_val,
    validation_data,
) = res
print(X_train.shape)

In [None]:
# Build the model and its callbacks from the configuration-cell settings.
model_settings = {
    "input_shape": INPUT_SHAPE,
    "model_save_dir": MODEL_SAVE_DIR,
    "log_dir": LOG_DIR,
    "model_name": MODEL_NAME,
    "model_architecture": MODEL_ARCHITECTURE,
    "learning_rate": LEARNING_RATE,
    "early_stopping_patience": EARLY_STOPPING_PATIENCE,
    "transfer_base_model_name": TRANSFER_BASE_MODEL_NAME,
    "transfer_freeze_base": TRANSFER_FREEZE_BASE,
}
res = mlc2.create_model(num_classes, **model_settings)

# Unpack the 4-item result into the names the training cell expects.
(model, callbacks, model_filepath, summary_text) = res

In [None]:
# Start timer
# perf_counter() is a monotonic clock meant for measuring elapsed time;
# time.time() can jump if the system clock is adjusted during a long run.
print("--- Starting the Training Script ---")
start_time = time.perf_counter()

# Train Model
res = mlc2.train_model(
    model,
    X_train, y_train, X_test, y_test,
    validation_data, callbacks, label_encoder, model_filepath,
    model_save_dir=MODEL_SAVE_DIR,
    model_name=MODEL_NAME,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
)

# End timer -- stopped right after training so the reported duration does
# not include the time spent printing the results below.
end_time = time.perf_counter()
duration = end_time - start_time

# train_model's result is iterated as a mapping; print each entry.
for key, val in res.items():
    print(f"\n{key} : {val}")

print(f"--- Training Script Completed in {duration:.2f} secondes ---")

# Predict on a noised image (note: the source image should normally not be noised)

In [None]:
# Subject whose image is sent to the trained model.
user = 27

# Build the path from LFW_DATASET_PATH so this cell reads from the same
# directory the dataset was loaded from, instead of repeating it.
image_path = os.path.join(LFW_DATASET_PATH, f"reconstructed_{user}_2.png")

# Open the file in a context manager so its handle is released as soon as
# the pixels have been copied into the array.
with Image.open(image_path) as img:
    image = np.array(img)

# predict_image's result is unpacked as a (label, confidence) pair.
result = mlc2.predict_image(image, MODEL_SAVE_DIR, MODEL_NAME, INPUT_SHAPE)
predicted_label, prediction_confidence = result
print(f"  - Predicted Identity (Subject ID) : {predicted_label}")
print(f"  - Trust : {prediction_confidence:.4f} ({prediction_confidence*100:.2f}%)")