In [None]:
import os, sys
import json
import tensorflow as tf

In [None]:
# Report whether TensorFlow can see any GPU devices; the list is kept in
# `gpus` so it can be inspected later in the notebook.
gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    print("No GPUs found. Using CPU.")
else:
    print("GPUs available:", gpus)

In [None]:
# Make the sibling project folders importable from this notebook.
# Paths are resolved relative to the current working directory.
sys.path.extend(
    os.path.abspath(os.path.join('..', folder))
    for folder in ('data_processing', 'models')
)

In [None]:
from multitask_preprocessing import train_data_multiTask, validation_data_multiTask, test_data_multiTask

In [None]:
from efficientnetb0 import create_efficientnetb0_multi_task

In [None]:
def train_and_save_multi_task_model(model_name, model_function, epochs=10, batch_size=32, learning_rate=0.001,
                                    dropout_rate=0.2, patience=5):
    """Train, evaluate and persist a three-output (face / age / gender) Keras model.

    The model is created with ``model_function(dropout_rate)``, compiled with
    Adam, fitted on the module-level ``train_data_multiTask`` with
    ``validation_data_multiTask`` as validation data, and evaluated on
    ``test_data_multiTask``. The model (``<model_name>.h5``) and a JSON summary
    (``<model_name>_training_info.json``) are written to
    ``../results/Multi Task/<model_name>/`` relative to the working directory.

    Args:
        model_name: Name used for the output folder and file names.
        model_function: Callable taking the dropout rate and returning a model
            whose outputs are named ``face_output``, ``age_output`` and
            ``gender_output``.
        epochs: Number of training epochs.
        batch_size: Forwarded to ``model.fit`` and recorded in the summary.
        learning_rate: Initial learning rate for Adam.
        dropout_rate: Passed to ``model_function`` and recorded in the summary.
        patience: Epochs without improvement of ``val_face_output_accuracy``
            before the learning rate is halved.

    Returns:
        Tuple ``(history, training_info)``: the object returned by
        ``model.fit`` and the dictionary that was written to the JSON file.
    """
    heads = ('face', 'age', 'gender')

    # Define results directory for saving models and training info
    results_dir = '../results/Multi Task'
    model_dir = os.path.join(results_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Initialize the multi-task model
    model = model_function(dropout_rate)

    # Compile the model with separate losses and metrics for each output
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss={
            'face_output': 'binary_crossentropy',
            'age_output': 'sparse_categorical_crossentropy',
            'gender_output': 'binary_crossentropy'
        },
        metrics={f'{head}_output': 'accuracy' for head in heads}
    )

    # Halve the learning rate when the face head's validation accuracy has not
    # improved for 'patience' epochs. Only this one output is monitored; the
    # age and gender heads do not influence the schedule.
    reduce_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_face_output_accuracy',
        factor=0.5,
        patience=patience,
        min_lr=1e-6,
        verbose=1
    )

    # Train the model with callbacks.
    # NOTE(review): Keras documents that `batch_size` should not be specified
    # when the input is a dataset, generator or Sequence, and `workers` only
    # applies to generator/Sequence inputs -- confirm the type of
    # train_data_multiTask.
    history = model.fit(
        train_data_multiTask,
        validation_data=validation_data_multiTask,
        epochs=epochs,
        callbacks=[reduce_lr_callback],
        verbose=1,
        batch_size=batch_size,
        workers=6
    )

    # Evaluate on test data for all outputs. Results are requested as a dict
    # and read by metric name: the positional list starts with the aggregate
    # loss and, for three outputs with one metric each, holds seven values, so
    # fixed indices such as [7] are out of range and the others are misaligned.
    test_results = model.evaluate(test_data_multiTask, verbose=1, return_dict=True)

    # Save model
    model_path = os.path.join(model_dir, f'{model_name}.h5')
    model.save(model_path)

    # Gather training info and hyperparameters
    training_info = {
        "Model": model_name,
        "epochs": epochs,
        "batch_size": batch_size,
        "initial_learning_rate": learning_rate,
        "dropout_rate": dropout_rate,
        "patience": patience,
    }

    # Per-epoch curves for every head: train/val accuracy, then train/val loss.
    # Values are cast to plain floats so json.dump never sees numpy scalars.
    for head in heads:
        for metric in ('accuracy', 'loss'):
            for split, prefix in (('train', ''), ('val', 'val_')):
                series = history.history[f'{prefix}{head}_output_{metric}']
                training_info[f'{split}_{metric}_{head}'] = [float(value) for value in series]

    # Final test metrics for every head, looked up by name.
    for head in heads:
        for metric in ('accuracy', 'loss'):
            training_info[f'test_{metric}_{head}'] = float(test_results[f'{head}_output_{metric}'])

    # Save training info to a JSON file
    info_path = os.path.join(model_dir, f'{model_name}_training_info.json')
    with open(info_path, 'w') as f:
        json.dump(training_info, f, indent=4)

    print(f"Model and training info saved in: {model_dir}")
    return history, training_info


In [None]:
# Train the EfficientNetB0 multi-task model with the default hyperparameters.
train_and_save_multi_task_model(
    model_name='EfficientNetB0_MultiTask',
    model_function=create_efficientnetb0_multi_task,
)