# In [None]:
# -*- Deep_Ensemble_Train.py -*-
"""
Created Oct 2020
Amended Nov 2023 for Deep Ensemble

@author: Timothy E H Allen / Alistair M Middleton / Adapted by Arndt Wallmann
"""
#%%

# Import modules
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.compat.v2 as tf
# import tensorflow_probability as tfp # No longer needed
import tqdm
import pandas as pd
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
import random
import seaborn as sns
from sklearn.metrics import mean_absolute_error
import os # For saving ensemble models

# Switch TensorFlow to 2.x behaviour before any model code runs.
# (TensorFlow Probability distributions are no longer used by this script.)
tf.enable_v2_behavior()

# --- Data and model locations ---
receptor = "AR"  # Receptor name, substituted into every path below
input_data_a = f"/content/drive/My Drive/Training_Sets/{receptor} fingerprint.csv"  # Training CSV
input_data_b = f"/content/drive/My Drive/Test_Sets/{receptor} fingerprint.csv"  # Test CSV
model_base_path = f"/content/drive/My Drive/Models/{receptor}_ensemble_predictor"  # Directory holding one sub-folder per member

# --- Data splitting ---
rng_1 = 1989  # random_state for shuffling the training file
rng_2 = 2020  # random_state for the train/validation split
validation_proportion = 0.25  # Fraction of the training file held out for validation

# --- Network architecture ---
neurons = 10  # Units in each hidden Dense layer
hidden_layers = 2  # Number of hidden layers (the model builder accepts 1, 2 or 3)

# --- Optimisation ---
LR = 0.01  # Adam learning rate
epochs = 100
batch_size = 100
ensemble_size = 5  # Number of independently trained models in the ensemble

# Make sure the directory that will receive the saved members exists.
os.makedirs(model_base_path, exist_ok=True)

def read_dataset(input_data, n_features=10000):
    """Load a fingerprint CSV and split it into features and target.

    The file is expected to hold the fingerprint values in its first
    ``n_features`` columns, immediately followed by the target column.
    Any columns after the target are ignored.

    Args:
        input_data: Path or file-like object accepted by ``pandas.read_csv``.
        n_features: Number of leading fingerprint columns. Defaults to
            10000, the layout this script was originally written for.

    Returns:
        A tuple ``(X, Y)`` of float32 NumPy arrays: ``X`` holds the first
        ``n_features`` columns (one row per CSV row) and ``Y`` the column
        at position ``n_features``.

    Raises:
        IndexError: If the file has no column at position ``n_features``
            to use as the target.
    """
    df = pd.read_csv(input_data)

    # Fail with a readable message instead of a bare out-of-bounds error.
    if df.shape[1] <= n_features:
        raise IndexError(
            f"Expected at least {n_features + 1} columns "
            f"({n_features} features + 1 target) but found {df.shape[1]}"
        )

    # Positional selection; float32 keeps the arrays TensorFlow-friendly.
    X = df.iloc[:, :n_features].to_numpy(dtype=np.float32)
    Y = df.iloc[:, n_features].to_numpy(dtype=np.float32)
    return X, Y

# Read both CSVs, then shuffle the training file and hold out a validation split.
X, Y = read_dataset(input_data_a)
test_x, test_y = read_dataset(input_data_b)
X, Y = shuffle(X, Y, random_state=rng_1)
train_x, validation_x, train_y, validation_y = train_test_split(
    X,
    Y,
    test_size=validation_proportion,
    random_state=rng_2,
)

# The BNN-specific KL loss weight (1 / number of training rows) is not
# needed for a deep ensemble, so none is computed here.

# Report shape and dtype of every split as a quick sanity check.
print("Dimensionality of data:")
for split_label, split_array in (
    ("Train x", train_x),
    ("Train y", train_y),
    ("Validation x", validation_x),
    ("Validation y", validation_y),
    ("Test x", test_x),
    ("Test y", test_y),
):
    print(f"{split_label} shape =", split_array.shape, "dtype =", split_array.dtype)

# --- BNN posterior/prior functions removed ---
# def posterior_mean_field(...): ...
# def prior_not_trainable(...): ...

# --- BNN negloglik removed, using standard 'mse' loss ---
# def negloglik(y, rv_y): ...

# --- Define the standard deterministic model architecture ---
def create_deterministic_model(input_shape, neurons, hidden_layers):
    """Build and compile a plain feed-forward regression network.

    The network is an explicit input layer, ``hidden_layers`` ReLU Dense
    layers of ``neurons`` units each, and a single-unit linear output.
    It is compiled with Adam (learning rate taken from the module-level
    ``LR``), MSE loss, and MSE/MAE metrics.

    Args:
        input_shape: Shape tuple handed to the Keras ``InputLayer``.
        neurons: Number of units in every hidden layer.
        hidden_layers: How many hidden layers to stack; must be 1, 2 or 3.

    Returns:
        The compiled ``tf.keras.Sequential`` model.

    Raises:
        ValueError: If ``hidden_layers`` is not 1, 2 or 3.
    """
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.InputLayer(input_shape=input_shape))

    # Only the three supported depths are allowed.
    if hidden_layers not in (1, 2, 3):
        raise ValueError("Number of hidden layers outside this model scope, please choose 1, 2, or 3")

    # Identical hidden layers, repeated once per requested depth.
    for _ in range(int(hidden_layers)):
        model.add(tf.keras.layers.Dense(neurons, activation='relu'))

    # Single linear unit: the regression output.
    model.add(tf.keras.layers.Dense(1))

    optimizer = tf.optimizers.Adam(learning_rate=LR)
    model.compile(optimizer=optimizer, loss='mse', metrics=['mse', 'mae'])
    return model

# --- Train the ensemble ---
# Trained members and the History object returned by each fit, in training order.
ensemble_models = []
histories = []

# Each sample is one flat feature vector whose length comes from the training data.
input_shape = (train_x.shape[1],)

print(f"\n--- Training Deep Ensemble ({ensemble_size} members) ---")
for i in range(ensemble_size):
    member_number = i + 1  # 1-based index, used in console messages only
    print(f"\nTraining Model {member_number}/{ensemble_size}")

    # A brand-new model instance is built for every ensemble member.
    model = create_deterministic_model(input_shape, neurons, hidden_layers)

    # verbose=1 shows per-epoch progress for this member.
    history = model.fit(
        train_x,
        train_y,
        validation_data=(validation_x, validation_y),
        batch_size=batch_size,
        epochs=epochs,
        verbose=1,
    )

    ensemble_models.append(model)
    histories.append(history)

    # Each member is saved to its own sub-directory, named with the 0-based index.
    member_model_path = os.path.join(model_base_path, f"member_{i}")
    model.save(member_model_path, save_format="tf")
    print(f"Model {member_number} saved to {member_model_path}")

print("\n--- Ensemble Training Complete ---")

# --- Learning curves for the *last* trained member, shown as an example ---
# (Averaging the histories or plotting every member would be alternatives.)
last_history = histories[-1]
print("\n--- Plotting metrics for the last trained model ---")

# One panel per quantity: (training key, validation key, title, y-axis label).
history_panels = (
    ('loss', 'val_loss', f'Model {ensemble_size} Loss', 'Loss (MSE)'),
    ('mae', 'val_mae', f'Model {ensemble_size} MAE', 'Mean Absolute Error'),
)

plt.figure(figsize=(12, 5))
for panel_position, (train_key, val_key, panel_title, y_label) in enumerate(history_panels, start=1):
    # Side-by-side panels: loss on the left, MAE on the right.
    plt.subplot(1, 2, panel_position)
    plt.plot(last_history.history[train_key])
    plt.plot(last_history.history[val_key])
    plt.title(panel_title)
    plt.ylabel(y_label)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Validation'], loc='upper left')
plt.tight_layout()
plt.show()


# --- Ensemble Prediction and Evaluation ---

def predict_ensemble(models, x_data):
    """Collect every ensemble member's predictions as columns of one array.

    Args:
        models: Iterable of trained models, each exposing ``predict(x_data)``.
        x_data: Input features, passed unchanged to every model.

    Returns:
        NumPy array produced by horizontally stacking the per-model
        predictions. When each model returns one value per sample the
        result has shape ``(n_samples, n_models)``, so reducing over
        ``axis=1`` aggregates across ensemble members.

    Raises:
        ValueError: If ``models`` yields no model at all.
    """
    all_preds = []
    for model in tqdm.tqdm(models, desc="Predicting with Ensemble"):
        y_pred = np.asarray(model.predict(x_data))
        # A flat (n_samples,) result must become a column first; otherwise
        # hstack would join the members end to end instead of side by side.
        if y_pred.ndim == 1:
            y_pred = y_pred[:, np.newaxis]
        all_preds.append(y_pred)

    # np.hstack would also raise ValueError here, but with an opaque message.
    if not all_preds:
        raise ValueError("predict_ensemble requires at least one model")

    # Existing axis 1 is extended: (n_samples, 1) blocks -> (n_samples, n_models).
    return np.hstack(all_preds)

# Ensemble predictions for the held-out test set
print("\n--- Evaluating on Test Set ---")
y_preds_ensemble_test = predict_ensemble(ensemble_models, test_x)
# Reduce across members (axis 1): the mean is the prediction, the spread the uncertainty.
y_mean_test = y_preds_ensemble_test.mean(axis=1)
y_sigma_test = y_preds_ensemble_test.std(axis=1)

# Experimental vs predicted scatter plot with a y=x reference line
_, parity_ax = plt.subplots()
parity_ax.scatter(test_y, y_mean_test, marker='o', alpha=0.5)
# The reference line spans the joint range of observed and predicted values.
min_val = min(test_y.min(), y_mean_test.min())
max_val = max(test_y.max(), y_mean_test.max())
diagonal = [min_val, max_val]
parity_ax.plot(diagonal, diagonal, 'r--', label='y=x')
parity_ax.set_title('Test Set: Experimental vs. Predicted (Ensemble Mean)')
parity_ax.set_xlabel('Experimental Value')
parity_ax.set_ylabel('Predicted Value (Mean)')
parity_ax.grid(True)
parity_ax.legend()
plt.show()

# Optional view: ensemble spread against the predicted mean
_, sigma_ax = plt.subplots()
sigma_ax.scatter(y_mean_test, y_sigma_test, marker='o', alpha=0.5)
sigma_ax.set_title('Test Set: Prediction Uncertainty (Ensemble Std Dev)')
sigma_ax.set_xlabel('Predicted Value (Mean)')
sigma_ax.set_ylabel('Prediction Uncertainty (Std Dev)')
sigma_ax.grid(True)
plt.show()


# Ensemble-mean predictions for the training and validation splits
# (the test-set mean has already been computed above).
print("\n--- Evaluating on Training Set ---")
y_preds_ensemble_train = predict_ensemble(ensemble_models, train_x)
y_mean_train = y_preds_ensemble_train.mean(axis=1)

print("\n--- Evaluating on Validation Set ---")
y_preds_ensemble_validation = predict_ensemble(ensemble_models, validation_x)
y_mean_validation = y_preds_ensemble_validation.mean(axis=1)

# Mean absolute error of the ensemble mean on each split; labels are padded
# so the numbers line up in one column.
print("\n--- MAE Results ---")
mae_rows = (
    ("Training Set MAE:   ", train_y, y_mean_train),
    ("Validation Set MAE: ", validation_y, y_mean_validation),
    ("Test Set MAE:       ", test_y, y_mean_test),
)
for row_label, y_true, y_predicted in mae_rows:
    print(f"{row_label}{mean_absolute_error(y_true, y_predicted):.4f}")


# Reset the Keras backend session now that the run is finished.
tf.keras.backend.clear_session()

# Endgame
print("\nEND")