In [None]:
# Imports
import os
import utils
import pandas as pd
import cv2
import numpy as np
import concurrent.futures
from sklearn.utils.class_weight import compute_class_weight
from functools import partial

import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Dense, concatenate
from tensorflow.keras import Model  

In [None]:
# Training hyperparameters

# Mini-batch size and upper bound on epochs passed to model.fit in train_model
# (the EarlyStopping callback there may end training sooner).
BATCH_SIZE = 64
EPOCHS = 250

# Metrics for binary classification (single sigmoid output unit)
METRICS_CLASSIFICATION_BINARY = [tf.keras.metrics.Precision(), tf.keras.metrics.Recall(), tf.keras.metrics.BinaryAccuracy()]

# Metrics for multiclass classification (softmax output over one-hot targets)
METRICS_CLASSIFICATION_MULTICLASS = [tf.keras.metrics.Precision(), tf.keras.metrics.Recall(), tf.keras.metrics.CategoricalAccuracy()]

# Metrics for regression (single linear output unit)
METRICS_REGRESSION = [tf.keras.metrics.MeanSquaredError(), tf.keras.metrics.MeanAbsoluteError()]

In [None]:
# Helper functions
def load_image(img_path, image_method_folder):
    """Read a single image from disk with OpenCV.

    Args:
        img_path: Image path relative to ``image_method_folder``, as stored in the
            dataset CSV. Windows-style backslashes are converted to forward slashes
            so the path also resolves on POSIX systems.
        image_method_folder: Root folder containing the images for one image method.

    Returns:
        The decoded image exactly as returned by ``cv2.imread`` (default flags).

    Raises:
        OSError: if OpenCV cannot read the file (missing or undecodable).
            ``cv2.imread`` signals this by returning None rather than raising,
            which would otherwise leak a None into the stacked image array.
    """
    full_path = os.path.join(image_method_folder, img_path.replace("\\", "/"))
    image = cv2.imread(full_path)
    if image is None:
        raise OSError(f"Could not read image (missing or undecodable): {full_path}")
    return image

def train_model(model, X_train, X_val, y_train, y_val, classes_weight):
    """Fit ``model`` in place, with early stopping on the validation loss.

    Uses the module-level BATCH_SIZE and EPOCHS. Training stops once ``val_loss``
    has not improved for 25 epochs; improvement is only tracked after the first
    10 epochs. The best weights seen are restored at the end.

    Args:
        model: A compiled Keras model.
        X_train, y_train: Training inputs and targets.
        X_val, y_val: Validation inputs and targets.
        classes_weight: Forwarded to ``fit`` as ``class_weight`` (may be None).
    """
    stopper = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=25,
        restore_best_weights=True,
        start_from_epoch=10,
    )
    fit_kwargs = dict(
        x=X_train,
        y=y_train,
        validation_data=(X_val, y_val),
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        callbacks=[stopper],
        class_weight=classes_weight,
    )
    model.fit(**fit_kwargs)

def build_model(img_input_shape, mlp_input_shape, is_classification_dataset, n_classes=None):
    """Build and compile a NN with two branches: a CNN over images and an MLP over tabular data.

    The outputs of ``utils.get_cnn_branch`` and ``utils.get_mlp_branch`` are
    concatenated and fed through a 16-unit ReLU layer, then a task-specific head:
    a linear unit for regression, a sigmoid unit for binary classification, or a
    softmax over ``n_classes`` for multiclass classification.

    Args:
        img_input_shape: Shape of one image sample (no batch axis).
        mlp_input_shape: Shape of one tabular sample (no batch axis).
        is_classification_dataset: True for classification, False for regression.
        n_classes: Number of classes. Required (> 1) for classification; must be
            None for regression.

    Returns:
        A compiled ``Model`` whose inputs are ``[image_input, tabular_input]``.

    Raises:
        AssertionError: if ``n_classes`` is inconsistent with ``is_classification_dataset``.
    """
    if is_classification_dataset:
        assert n_classes is not None
        assert n_classes > 1
    else:
        assert n_classes is None

    # Define the two branches
    input_cnn, output_cnn = utils.get_cnn_branch(img_input_shape)
    input_ffnn, output_ffnn = utils.get_mlp_branch(mlp_input_shape)

    # Combine the branches and add a hidden layer after concatenation
    combined_outputs = concatenate([output_ffnn, output_cnn])
    out = Dense(16, activation="relu")(combined_outputs)

    # Choose the output head, loss and metric set together, so they cannot disagree
    if not is_classification_dataset:
        out = Dense(1, activation="linear")(out)
        loss = "mean_squared_error"
        metric_templates = METRICS_REGRESSION
    elif n_classes == 2:
        # Binary: one sigmoid unit with scalar 0/1 targets
        out = Dense(1, activation="sigmoid")(out)
        loss = "binary_crossentropy"
        metric_templates = METRICS_CLASSIFICATION_BINARY
    else:
        # Multiclass: softmax over n_classes; categorical_crossentropy expects one-hot targets
        out = Dense(n_classes, activation="softmax")(out)
        loss = "categorical_crossentropy"
        metric_templates = METRICS_CLASSIFICATION_MULTICLASS

    model = Model(inputs=[input_cnn, input_ffnn], outputs=out)

    # Keras metric objects hold running state. The METRICS_* lists are module-level and
    # shared, so give each model its own fresh instances (same class and config) rather
    # than letting every model built in this kernel reuse the same objects.
    metrics = [type(metric).from_config(metric.get_config()) for metric in metric_templates]

    model.compile(optimizer="adam", metrics=metrics, loss=loss)
    return model

In [None]:
# Main function
def train(dataset_name, image_method_name):
    """Train and save a CNN+MLP model for one (dataset, image method) pair.

    Training is skipped when ``<models folder>/<image_method_name>.keras`` already
    exists. That is the exact file this function writes when it finishes.

    Args:
        dataset_name: Dataset identifier understood by ``utils``.
        image_method_name: Name of the tabular-to-image method whose images are used.

    Raises:
        FileNotFoundError: if the image folder for this method contains no ``.csv`` index.
        ValueError: if the loaded images cannot be stacked into a dense array
            (unreadable files or inconsistent shapes).
    """
    dataset_folder = utils.get_cnnmlp_models_path(dataset_name)
    model_path = os.path.join(dataset_folder, f"{image_method_name}.keras")
    # Check for the same file saved at the end, so re-running skips finished models
    if os.path.exists(model_path):
        print("The model has already been trained!")
        return

    image_method_folder = utils.get_images_path_for_dataset(dataset_name, image_method_name)
    csv_file_name = next((f for f in os.listdir(image_method_folder) if f.endswith(".csv")), None)
    if csv_file_name is None:
        raise FileNotFoundError(f"There are no images for dataset {dataset_name} using {image_method_name} method")

    # Get the dataset and split it into train / validation
    print("Loading dataset ...")
    X, y = utils.get_X_y(dataset_name)
    indices_train, indices_val = utils.get_indices_train_eval(dataset_name)
    X_data_train = X[indices_train]
    X_data_val = X[indices_val]
    y_train = y[indices_train]
    y_val = y[indices_val]
    del X, y  # only the splits are needed from here on

    if utils.is_dataset_classification(dataset_name):
        is_classification_dataset = True
        n_classes = utils.get_number_of_classes(dataset_name)
        assert n_classes > 1
        # 'balanced' class weights, computed on the raw labels before any one-hot encoding
        classes = np.unique(y_train)
        class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_train)
        class_weight_dict = dict(zip(classes, class_weights))
        # NOTE(review): build_model uses a softmax head whenever n_classes > 2; this
        # assumes utils.is_dataset_multiclass_classification agrees with that — confirm.
        if utils.is_dataset_multiclass_classification(dataset_name):
            y_train = to_categorical(y_train, num_classes=n_classes)
            y_val = to_categorical(y_val, num_classes=n_classes)
    else:
        is_classification_dataset = False
        n_classes = None
        class_weight_dict = None  # class weights do not apply to regression targets

    # Get the routes to the images, aligned with the tabular rows by index
    image_paths_np = pd.read_csv(os.path.join(image_method_folder, csv_file_name))["images"].to_numpy()
    train_paths = image_paths_np[indices_train]
    val_paths = image_paths_np[indices_val]

    # Load train and validation images
    print("Loading images ...")
    func_load_image = partial(load_image, image_method_folder=image_method_folder)
    with concurrent.futures.ThreadPoolExecutor() as executor:
        X_img_train = np.array(list(executor.map(func_load_image, train_paths)))
        X_img_val = np.array(list(executor.map(func_load_image, val_paths)))
    # An object dtype means some image was unreadable (None) or shapes differ
    if X_img_train.dtype == object or X_img_val.dtype == object:
        raise ValueError(f"Some images in {image_method_folder} could not be loaded or have inconsistent shapes")

    # Build the model
    print("Building model ...")
    model = build_model(
        img_input_shape=X_img_train[0].shape,
        mlp_input_shape=X_data_train[0].shape,
        is_classification_dataset=is_classification_dataset,
        n_classes=n_classes,
    )

    # Train the model
    print("Training model ...")
    train_model(
        model=model,
        X_train=[X_img_train, X_data_train],
        X_val=[X_img_val, X_data_val],
        y_train=y_train,
        y_val=y_val,
        classes_weight=class_weight_dict,
    )

    # Save the model, creating the models folder if it does not exist yet
    print("Saving the model ...")
    os.makedirs(dataset_folder, exist_ok=True)
    model.save(model_path)

In [None]:
# Train the CNN+MLP model for the HELOC dataset using bar-graph images
train(
    dataset_name=utils.HELOC_NAME,
    image_method_name=utils.BARGRAPH_NAME,
)