# Train Multi-Task Models

In [1]:
import os, sys
import json
import tensorflow as tf

In [2]:
sys.path.append(os.path.abspath(os.path.join('..','data_processing')))
sys.path.append(os.path.abspath(os.path.join('..','models')))

In [None]:
from multitask_preprocessing import train_dataset, val_dataset, test_dataset

In [4]:
from efficientnetb0 import create_efficientnetb0_multi_task, create_efficientnetb0_multi_task_v2, create_efficientnetb3_multi_task, create_efficientnetb3_multi_task_v2, create_efficientnetb3_multi_task_v3

In [5]:
from resnet50 import create_resnet50_multi_task, create_resnet50_multi_task_v2

In [6]:
from mobileNet import create_mobileNet_multi_task, create_mobileNet_multi_task_v2

In [4]:
def train_and_save_multi_task_model(model_name,
                                    model_function,
                                    dropout_rate,
                                    results_dir='../results/Multi_Task',
                                    epochs=10,
                                    learning_rate=1e-4,
                                    patience=5):
    """Build, train, evaluate and persist a multi-task Keras model.

    Training uses the module-level ``train_dataset`` / ``val_dataset`` and
    evaluation uses ``test_dataset`` (imported from
    ``multitask_preprocessing``). The model's outputs must be named
    ``face_output``, ``age_output`` and ``gender_output`` to match the
    loss/metric dictionaries passed to ``compile``.

    Args:
        model_name: Name of the output sub-directory and file stem.
        model_function: Factory called as ``model_function(dropout_rate)``;
            must return an (uncompiled) ``tf.keras.Model``.
        dropout_rate: Forwarded to ``model_function`` and recorded.
        results_dir: Parent directory; artifacts are written to
            ``<results_dir>/<model_name>/``.
        epochs: Number of training epochs.
        learning_rate: Initial Adam learning rate.
        patience: Epochs without improvement in validation age accuracy
            before ``ReduceLROnPlateau`` halves the learning rate.

    Returns:
        ``(history, training_info)``: the Keras ``History`` object and the
        dict that is also written to ``<model_name>_training_info.json``.
    """
    # Setup directories
    model_dir = os.path.join(results_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)

    # Initialize the multi-task model
    model = model_function(dropout_rate)
    model.summary()

    # Compile the model. The weighted_metrics entries only differ from the
    # plain metrics when the dataset supplies sample weights.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss={
            'face_output': 'binary_crossentropy',
            'age_output': 'categorical_crossentropy',
            'gender_output': 'binary_crossentropy'
        },
        metrics={
            'face_output': 'binary_accuracy',
            'age_output': 'categorical_accuracy',
            'gender_output': 'binary_accuracy'
        },
        weighted_metrics={
            'age_output': 'categorical_accuracy',
            'gender_output': 'binary_accuracy'
        }
    )

    callbacks = [
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_age_output_categorical_accuracy',
            mode='max',  # accuracy: higher is better
            factor=0.5,
            patience=patience,
            min_lr=1e-6,
            verbose=1
        )
    ]

    # Train the model
    print(f"Starting training for {epochs} epochs...")
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        verbose=1,
        callbacks=callbacks,
    )
    print("Training completed.")

    # Evaluate on test data. return_dict=True pairs every value with its
    # metric name; zipping model.metrics_names with a result list silently
    # truncates or mislabels if the two ever disagree in length.
    print("Evaluating the model on the test set...")
    test_results = model.evaluate(test_dataset, verbose=1, return_dict=True)
    print("Evaluation on test set completed.")

    # Convert to native floats so json.dump can serialize them
    test_metrics = {name: float(value) for name, value in test_results.items()}

    # Initialize training info
    training_info = {
        "Model": model_name,
        "epochs": epochs,
        "initial_learning_rate": learning_rate,
        "dropout_rate": dropout_rate,
        "patience": patience,
        "train_metrics": {},
        "val_metrics": {},
        "test_metrics": test_metrics
    }

    # Populate per-epoch metrics: keys prefixed with 'val_' are validation
    # metrics, everything else is recorded as a training metric.
    for key, values in history.history.items():
        bucket = 'val_metrics' if key.startswith('val_') else 'train_metrics'
        training_info[bucket].setdefault(key, []).extend(float(v) for v in values)

    # Save the model
    model_path = os.path.join(model_dir, f'{model_name}.h5')
    print(f"Saving the model to {model_path}...")
    model.save(model_path)
    print("Model saved successfully.")

    # Save the training info
    info_path = os.path.join(model_dir, f'{model_name}_training_info.json')
    print(f"Saving the training info to {info_path}...")
    with open(info_path, 'w') as f:
        json.dump(training_info, f, indent=4)
    print("Training info saved successfully.")

    print(f"Model and training info saved in: {model_dir}")
    return history, training_info


In [5]:
def load_and_continue_training_multi_task_model(model_name,
                                                results_dir='../results/Multi_Task',
                                                additional_epochs=10,
                                                learning_rate=None,
                                                patience=5):
    """Resume training a model saved under ``<results_dir>/<model_name>/``.

    Loads ``<model_name>.h5`` and ``<model_name>_training_info.json``,
    recompiles the model with a fresh Adam optimizer, and trains for
    ``additional_epochs`` more epochs on the module-level ``train_dataset`` /
    ``val_dataset``. It then re-evaluates on ``test_dataset`` and overwrites
    both files in place.

    Args:
        model_name: Sub-directory name and file stem of the saved model.
        results_dir: Parent directory containing ``model_name``.
        additional_epochs: Number of further epochs to train.
        learning_rate: Learning rate for the new optimizer. If ``None``,
            training resumes at the most recently logged learning rate (the
            ``learning_rate`` / ``lr`` series stored in ``train_metrics``).
            If no such series exists, it falls back to the recorded
            ``initial_learning_rate``, and then to ``1e-4``.
        patience: Patience for ``ReduceLROnPlateau`` on validation age
            accuracy.

    Returns:
        ``(history, training_info)``: the ``History`` of this run only, and
        the merged training-info dict that was written back to disk.
    """
    # Define paths
    model_dir = os.path.join(results_dir, model_name)
    model_path = os.path.join(model_dir, f'{model_name}.h5')
    info_path = os.path.join(model_dir, f'{model_name}_training_info.json')

    # Load the model. compile=False: it is recompiled below with a new
    # optimizer, so any restored compile/optimizer state would be discarded.
    print(f"Loading model from {model_path}...")
    model = tf.keras.models.load_model(model_path, compile=False)
    print("Model loaded successfully.")

    # Load existing training info
    print(f"Loading training info from {info_path}...")
    with open(info_path, 'r') as f:
        training_info = json.load(f)
    print("Training info loaded successfully.")

    # Resume from the last learning rate actually logged rather than the
    # initial one. Otherwise a decay applied by ReduceLROnPlateau, or an LR
    # passed to an earlier continuation, would be silently undone.
    logged_metrics = training_info.get('train_metrics', {})
    logged_lrs = logged_metrics.get('learning_rate') or logged_metrics.get('lr')
    if logged_lrs:
        current_lr = float(logged_lrs[-1])
    else:
        current_lr = training_info.get("initial_learning_rate", 1e-4)
    new_lr = learning_rate if learning_rate is not None else current_lr
    if learning_rate is not None:
        print(f"Updating learning rate from {current_lr} to {new_lr}")

    # Recompile the model with the chosen learning rate
    print("Re-compiling the model with updated learning rate and weighted metrics...")
    optimizer = tf.keras.optimizers.Adam(learning_rate=new_lr)
    model.compile(
        optimizer=optimizer,
        loss={
            'face_output': 'binary_crossentropy',
            'age_output': 'categorical_crossentropy',
            'gender_output': 'binary_crossentropy'
        },
        metrics={
            'face_output': 'binary_accuracy',
            'age_output': 'categorical_accuracy',
            'gender_output': 'binary_accuracy'
        },
        weighted_metrics={
            'age_output': 'categorical_accuracy',
            'gender_output': 'binary_accuracy'
        }
    )
    print("Model re-compiled successfully.")

    # Define callbacks
    callbacks = [
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_age_output_categorical_accuracy',
            mode='max',  # accuracy: higher is better
            factor=0.5,
            patience=patience,
            min_lr=1e-6,
            verbose=1
        )
    ]

    # Continue training
    print(f"Starting training for {additional_epochs} additional epochs...")
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=additional_epochs,
        verbose=1,
        callbacks=callbacks,
    )
    print("Additional training completed.")

    # Evaluate on test data. return_dict=True keeps names and values paired
    # instead of relying on model.metrics_names lining up with a list.
    print("Evaluating the updated model on the test set...")
    test_results = model.evaluate(test_dataset, verbose=1, return_dict=True)
    print("Evaluation on test set completed.")

    # Convert test metric values to native Python floats for JSON
    test_metrics = {name: float(value) for name, value in test_results.items()}

    # Append the new training history to the existing training info
    print("Appending new training metrics to training_info...")
    training_info.setdefault('train_metrics', {})
    training_info.setdefault('val_metrics', {})
    training_info.setdefault('epochs', 0)

    for key, values in history.history.items():
        bucket = 'val_metrics' if key.startswith('val_') else 'train_metrics'
        training_info[bucket].setdefault(key, []).extend(float(v) for v in values)
    print("Training metrics appended successfully.")

    # Update epoch count
    print("Updating epoch count in training_info...")
    training_info["epochs"] += additional_epochs
    print(f"Total epochs after update: {training_info['epochs']}")

    # Replace test metrics with the evaluation of the updated model
    print("Replacing test metrics in training_info...")
    training_info["test_metrics"] = test_metrics
    print("Test metrics replaced successfully.")

    # Save the updated model (overwrites the original file)
    print(f"Saving the updated model to {model_path}...")
    model.save(model_path)
    print("Model saved successfully.")

    # Save the updated training info (overwrites the original file)
    print(f"Saving the updated training info to {info_path}...")
    with open(info_path, 'w') as f:
        json.dump(training_info, f, indent=4)
    print("Training info saved successfully.")

    print(f"Model and training info updated and saved in: {model_dir}")
    return history, training_info
