# Imports

In [None]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import sys
import re
from tqdm.auto import tqdm
import joblib
import time
import warnings
import math

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, MinMaxScaler, StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error, r2_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader

from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoModel, BertTokenizer, BertModel
from transformers.tokenization_utils_base import BatchEncoding

from scipy.special import exp1 # exponential integral (https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.exp1.html)
from scipy.optimize import fsolve
from scipy.signal import argrelextrema


# Load Tokenized Data

In [None]:
# Specify your path to Capstone folder.
# Keep the trailing "/": file paths below are built by appending directly to this string
# (e.g. f"{main_path}Tokenized_results/...").

main_path = "/content/drive/MyDrive/Capstone_Diana/Capstone/"

In [None]:
# Mount Google Drive (Colab only) when main_path points into Drive; the check is a plain
# substring test on the path.
if "drive" in main_path:
    from google.colab import drive
    drive.mount("/content/drive")

In [None]:
# Specify your original database.
# mechano - for MechanoProDB;
# protherm - for ProThermDB
# Any value other than "mechano" is treated as ProThermDB by the next cell.

database = "mechano"

In [None]:
# The format of saved datasets:
# tokenized_dataset_{numeric_method}_{sequence_method}_{text_method}_{dataset_length}_{database}
# Example (current MechanoProDB settings): tokenized_dataset_regression_imputings_protbert_scibert_127_mechano

# `mechano` switches on the MechanoProDB-specific processing in later cells.
mechano = False
if database == "mechano":
    numeric_method = 'regression_imputings'
    dataset_length = 127
    mechano = True
else:
    numeric_method = 'none'
    dataset_length = 14645

sequence_method = "protbert"
text_method = "scibert"

dataset_name = f"tokenized_dataset_{numeric_method}_{sequence_method}_{text_method}_{dataset_length}_{database}"
pickle_path = f"{main_path}Tokenized_results/{dataset_name}.pkl"

In [None]:
# Load preprocessed dataset.
# NOTE(review): pd.read_pickle can execute arbitrary code on load - only open pickles you created.

tokenized_df = pd.read_pickle(pickle_path)

print("Pickle successfully loaded!")
tokenized_df.head()

In [None]:
# The original split made for this project (kept for reference; the split files already exist).
# NOTE(review): this code saves to a different folder and without the "{database}_" file-name
# prefix that the loading cell below expects
# ({main_path}Tokenized_results/{database}_train_tokenized_df.pkl) - adjust before re-running.

# tokenized_df = tokenized_df.reset_index(drop=True)

# print("Splitting data into train and validation sets...")
# train_df, val_df = train_test_split(tokenized_df, test_size=0.2, random_state=42, shuffle=True)
# print(f"Training set size: {len(train_df)}")
# print(f"Validation set size: {len(val_df)}")

# save_path = f"/content/drive/MyDrive/Capstone/Tokenized_results/"
# train_df.to_pickle(f"{save_path}train_tokenized_df.pkl")
# val_df.to_pickle(f"{save_path}val_tokenized_df.pkl")

In [None]:
# Loading train and validation data.
# NOTE(review): pd.read_pickle can execute arbitrary code on load - only open trusted files.

train_df = pd.read_pickle(f"{main_path}Tokenized_results/{database}_train_tokenized_df.pkl")
val_df = pd.read_pickle(f"{main_path}Tokenized_results/{database}_val_tokenized_df.pkl")

# Index labels of each split. Presumably they equal row positions in tokenized_df (the
# commented-out split cell resets the index before splitting) - confirm before relying on it.
train_indices = train_df.index
val_indices = val_df.index

In [None]:
# Experimental case.
# To use only a reduced set of main input columns, set df_state = "subset";
# otherwise leave the default "all".
# "subset" only affects MechanoProDB; ProThermDB always keeps the fixed column set below.

df_state = "all" # subset/all


def _row_to_float_tensor(row):
    """Convert one DataFrame row of numeric values into a float32 tensor (the numeric embedding)."""
    return torch.tensor(row.values, dtype=torch.float)


if df_state == 'subset' and mechano:
    numeric_columns_new = ["Pulling Start", "Pulling End"]
    needed_cols_subset = ["Sequence", "tokenized_Sequence", "Pulling Start", "Pulling End", "targets", "Experimental Conditions", "tokenized_Experimental Conditions"]
    # .copy() yields independent frames, so adding columns here (and later in the notebook)
    # does not write into a slice of the loaded DataFrames (pandas SettingWithCopyWarning).
    train_df = train_df[needed_cols_subset].copy()
    val_df = val_df[needed_cols_subset].copy()
    tokenized_df = tokenized_df[needed_cols_subset].copy()

    # Rebuild the numeric embeddings from the pulling range only.
    for subset_frame in (train_df, val_df, tokenized_df):
        subset_frame["numeric_embeddings"] = subset_frame[numeric_columns_new].apply(_row_to_float_tensor, axis=1)
elif not mechano:
    needed_cols_subset = ['pH',  'Tm_(C)', 'Sequence', 'tokenized_Sequence',  'SEC_STR', 'tokenized_SEC_STR', 'targets', 'numeric_embeddings']
    # Copies for the same reason: later cells add columns to train_df/val_df.
    train_df = train_df[needed_cols_subset].copy()
    val_df = val_df[needed_cols_subset].copy()
    tokenized_df = tokenized_df[needed_cols_subset].copy()

# TensorFlow Decision Forests: Random Forest (TFDF)

This model was used only for **MechanoProDB**, as it is not optimal for ProThermDB.

All cells in this section work with MechanoProDB data only.

## Installations and Imports

In [None]:
# Install TFDF and Optuna for this separate section.
#%pip install tensorflow==2.18
%pip install tensorflow tensorflow_decision_forests
%pip install optuna

In [None]:
# Load TFDF
import tensorflow_decision_forests as tfdf
import optuna
import logging

## Functions

In [None]:
def aggregate_tokenized_features(df, columns):
    """
    Replace each tokenized column with a single numeric feature: the mean of its token ids.

    Each cell is expected to hold a tokenizer output whose first element exposes ``.ids``
    (e.g. a fast-tokenizer ``BatchEncoding`` indexed with ``[0]``). Cells that cannot be
    indexed (None, NaN, empty, slow-tokenizer encodings that raise ``KeyError``) or whose
    first element has no ``ids`` become 0. An empty ``ids`` list becomes NaN.

    Args:
        df: DataFrame whose ``columns`` are overwritten in place.
        columns: Names of the tokenized columns to aggregate.

    Returns:
        The same DataFrame object, for chaining.
    """
    def _mean_token_id(cell):
        try:
            first = cell[0]
        except (TypeError, IndexError, KeyError):
            # Not indexable (None/NaN), empty, or no backend encoding: treat as "no token ids".
            return 0
        if not hasattr(first, "ids"):
            return 0
        ids = first.ids
        # np.mean([]) is NaN too, but it also emits a RuntimeWarning.
        return np.mean(ids) if len(ids) > 0 else np.nan

    for col in columns:
        df[col] = df[col].apply(_mean_token_id)
    return df

def relative_root_mean_squared_error(y_true, y_pred):
    """
    Return the mean absolute relative error, mean(|1 - y_pred / y_true|), over entries with y_true != 0.

    Despite the name, nothing is squared or square-rooted: this is MAPE expressed as a fraction,
    restricted to non-zero targets so that the division is defined. 1-D and 2-D inputs are
    accepted; a 2-D input is pooled over all of its entries.

    Returns:
        The error as a float, or NaN when every true value is zero (or the input is empty)
        or when the result is not finite.
    """
    y_true = np.asarray(y_true).astype(float)
    y_pred = np.asarray(y_pred).astype(float)

    nonzero = y_true != 0
    if not nonzero.any():
        return np.nan

    error = np.mean(np.abs(1 - y_pred[nonzero] / y_true[nonzero]))
    return error if np.isfinite(error) else np.nan

def evaluate_predictions(y_true, y_pred):
    """
    Compute regression metrics for each output column of a TFDF model, plus their average.

    Args:
        y_true: True values, shape (n_samples,) or (n_samples, n_outputs).
        y_pred: Predictions with the same shape as ``y_true``.

    Returns:
        pandas.DataFrame with one row per output ("Output_0", "Output_1", ...) and a final
        "Average" row, with columns Output, MAPE, SMAPE, MSE, RMSE, RRMSE, R2.
        Units differ: MAPE is sklearn's fraction (0.1 == 10%), SMAPE is a percentage, and
        "RRMSE" is the mean absolute relative error from ``relative_root_mean_squared_error``.
        An output whose true or predicted values contain NaN gets NaN metrics (sklearn's
        metric functions reject NaN input), which also makes that metric's average NaN.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different shapes.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)

    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
    if y_pred.ndim == 1:
        y_pred = y_pred.reshape(-1, 1)

    if y_true.shape != y_pred.shape:
        raise ValueError(f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}")

    metric_names = ["MAPE", "SMAPE", "MSE", "RMSE", "RRMSE", "R2"]
    metrics_dict = {"Output": [], **{name: [] for name in metric_names}}

    if np.isnan(y_pred).any():
        print("Warning: NaNs found in predictions. Metrics might be affected.")

    for i in range(y_true.shape[1]):
        y_t = y_true[:, i]
        y_p = y_pred[:, i]

        if np.isnan(y_t).any() or np.isnan(y_p).any():
            # sklearn raises on NaN input; record NaN metrics for this output instead of crashing.
            row = dict.fromkeys(metric_names, np.nan)
        else:
            mse = mean_squared_error(y_t, y_p)
            row = {
                "MAPE": mean_absolute_percentage_error(y_t, y_p),
                # 1e-8 keeps the ratio defined when both values are zero.
                "SMAPE": 100 * np.mean(2 * np.abs(y_t - y_p) / (np.abs(y_t) + np.abs(y_p) + 1e-8)),
                "MSE": mse,
                "RMSE": np.sqrt(mse),
                "RRMSE": relative_root_mean_squared_error(y_t, y_p),
                "R2": r2_score(y_t, y_p),
            }

        metrics_dict["Output"].append(f"Output_{i}")
        for name in metric_names:
            metrics_dict[name].append(row[name])

    # The mean is taken over the per-output rows only (computed before appending).
    metrics_dict["Output"].append("Average")
    for name in metric_names:
        metrics_dict[name].append(np.mean(metrics_dict[name]))

    return pd.DataFrame(metrics_dict)

## Dataset Processing

In [None]:
# Prepare your data for usage: drop the columns that are not TFDF features.
df_tfdf = tokenized_df.copy()
to_drop = ['targets', 'numeric_embeddings']
if mechano:
    # Joint/domain text columns were not used in this project.
    to_drop.extend(['tokenized_Joint_Text_Cols', 'tokenized_domain_subsequences'])
# Drop for both databases: train_df/val_df drop the same `to_drop` columns later in this
# section, so df_tfdf (whose columns define the TFDF feature names) must not keep them either.
df_tfdf.drop(columns=to_drop, inplace=True)

In [None]:
# Replace spaces in column names with underscores, then collapse every tokenized column into a
# single numeric feature (mean token id, see aggregate_tokenized_features).
df_tfdf = df_tfdf.rename(columns=lambda x: x.replace(" ", "_"))
tokenized_columns = [col for col in df_tfdf.columns if col.startswith("tokenized_")]  # names already use underscores
df_tfdf = aggregate_tokenized_features(df_tfdf, tokenized_columns)

In [None]:
# Columns (mostly raw text) that are excluded from the TFDF features; names use the
# underscore form produced by the previous cell.
if mechano:
    categorical_columns = [
        "Name", "SCOP_annotation", "Experimental_Conditions", "Organism", "Classification",
        "Technique", "Pulling_Mode", "Unfolding_Pathway", "domain_subsequences",
        "Joint_Text_Cols", "PDB_UniProt", "Sequence"
        ]
else:
    categorical_columns = []

df_tfdf.drop(columns=categorical_columns, axis=1, inplace=True)

In [None]:
# Echo Optuna's log messages to stdout so they appear in the notebook output.
# Re-running this cell adds another handler, which duplicates every message.
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study_name = "tfdf-regression-tuning"
# NOTE(review): storage_name is not passed to optuna.create_study in this section, so the
# study is kept in memory only and nothing is written to this SQLite file.
storage_name = f"sqlite:///{study_name}.db"

In [None]:
def _sanitize_column_name(name):
    """
    Make a column name usable as a TFDF feature name.

    Spaces become underscores, square brackets are removed, "Δ" becomes "Delta" and "-¹"
    becomes "_1", e.g. "Koff_[s-¹]" -> "Koff_s_1".
    """
    return (
        name.replace(" ", "_")
        .replace("[", "")
        .replace("]", "")
        .replace("Δ", "Delta")
        .replace("-¹", "_1")
    )


print("{} examples in training, {} examples for testing.".format(
    len(train_df), len(val_df)))
df_tfdf = df_tfdf.rename(columns=_sanitize_column_name)
train_df = train_df.rename(columns=_sanitize_column_name)
val_df = val_df.rename(columns=_sanitize_column_name)

# Regression targets: raw names mapped through the same sanitizer, so they always match the
# renamed columns ("ΔG_[kBT]" -> "DeltaG_kBT", "Xu_[nm]" -> "Xu_nm", "Koff_[s-¹]" -> "Koff_s_1").
if mechano:
    target_columns = [_sanitize_column_name(c) for c in ["ΔG_[kBT]", "Xu_[nm]", "Koff_[s-¹]"]]
else:
    target_columns = []

In [None]:
# Apply the same preprocessing to the train/validation splits as to df_tfdf.
# NOTE: this modifies train_df/val_df (e.g. 'targets' and 'numeric_embeddings' are dropped and
# tokenized columns become numbers), so reload the splits before running the Neural Network
# section if this section was run first.
train_df.drop(columns=categorical_columns, axis=1, inplace=True)
val_df.drop(columns=categorical_columns, axis=1, inplace=True)

train_df = aggregate_tokenized_features(train_df, tokenized_columns)
val_df = aggregate_tokenized_features(val_df, tokenized_columns)

train_df.drop(columns=to_drop, axis=1, inplace=True)
val_df.drop(columns=to_drop, axis=1, inplace=True)

# Features are every remaining df_tfdf column except the regression targets.
feature_names = [x for x in df_tfdf.columns.tolist() if x not in target_columns]

In [None]:
# Defining the Objective function for Optuna Optimization.

def objective(trial):
    """
    Optuna objective: train one TFDF Random Forest per target and return the mean validation RMSE.

    Samples ``max_depth`` in [5, 20] and ``min_examples`` in [2, 10]. Uses the notebook
    globals ``train_df``, ``val_df``, ``feature_names`` and ``target_columns``.

    The RMSE is computed from ``model.predict`` on the validation split. A TFDF model that was
    not compiled with metrics only reports ``loss`` = 0.0 from ``evaluate()``, which would
    score every trial as RMSE 0 and make the search meaningless.

    Returns:
        Mean validation RMSE over all targets, or ``float('inf')`` when a required column is
        missing or a TF dataset cannot be built.

    Raises:
        ValueError: if ``target_columns`` is empty (the TFDF section supports MechanoProDB only).
    """
    if not target_columns:
        raise ValueError("No target columns defined: the TFDF section supports MechanoProDB only.")

    max_depth = trial.suggest_int("max_depth", 5, 20)
    min_examples = trial.suggest_int("min_examples", 2, 10)

    print(f"Optuna Trial {trial.number}")
    print(f"Hyperparameters: max_depth={max_depth}, min_examples={min_examples}")

    all_target_val_rmse = []

    for label in target_columns:
        print(f"Training for target: {label}")
        columns = feature_names + [label]
        missing = [c for c in columns if c not in train_df.columns or c not in val_df.columns]
        if missing:
            # Selecting these columns would raise a KeyError; score the trial as unusable instead.
            print(f"Missing columns in train_df/val_df: {missing}")
            return float('inf')

        train_df_target = train_df[columns].copy()
        val_df_target = val_df[columns].copy()

        try:
            train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(
                train_df_target, label=label, task=tfdf.keras.Task.REGRESSION, max_num_classes=100000
            )
            val_ds = tfdf.keras.pd_dataframe_to_tf_dataset(
                val_df_target, label=label, task=tfdf.keras.Task.REGRESSION, max_num_classes=100000
            )
        except Exception as e:
            print(f"Error creating TF Dataset for target {label}: {e}")
            return float('inf')

        model = tfdf.keras.RandomForestModel(
            task=tfdf.keras.Task.REGRESSION,
            features=[tfdf.keras.FeatureUsage(name=f) for f in feature_names],
            max_depth=max_depth,
            min_examples=min_examples,
            compute_oob_variable_importances=False,
            verbose=0,
        )

        model.fit(train_ds, verbose=0)

        # Predictions come back in the row order of val_df_target (the dataset is not shuffled).
        predictions = np.asarray(model.predict(val_ds, verbose=0)).reshape(-1)
        y_true_np = val_df_target[label].to_numpy(dtype=float)
        val_rmse = float(np.sqrt(mean_squared_error(y_true_np, predictions)))
        print(f"Validation RMSE for {label}: {val_rmse:.4f}")

        all_target_val_rmse.append(val_rmse)

    average_rmse = float(np.mean(all_target_val_rmse))
    print(f"Trial {trial.number} Completed")
    print(f"Average Validation RMSE: {average_rmse:.4f}")
    print(f"Individual RMSEs: {all_target_val_rmse}")

    # Plain floats keep the attributes JSON-serializable (needed if an RDB storage is used).
    trial.set_user_attr("individual_rmses", all_target_val_rmse)
    trial.set_user_attr("average_mse", float(np.mean(np.square(all_target_val_rmse))))

    return average_rmse

In [None]:
# Start of Optuna optimization, which is done in 4 trials (can be changed).
# NOTE(review): no `storage` is passed, so the study exists in memory only (storage_name is unused).

study = optuna.create_study(
    direction="minimize",
    study_name=study_name
    )

n_trials = 4
study.optimize(objective, n_trials=n_trials)

print("Optuna Optimization Finished")
print(f"Number of finished trials: {len(study.trials)}")
print("Best trial:")
best_trial = study.best_trial

In [None]:
# Report the best trial: its objective value, sampled hyperparameters and the user attributes
# stored by the objective function.
print(f"Value (Average Validation RMSE): {best_trial.value:.4f}")
print("Params:")
for key, value in best_trial.params.items():
    print(f"{key}: {value}")
print("  User Attributes (Example):")
for key, value in best_trial.user_attrs.items():
    print(f"{key}: {value}")

In [None]:
# Train one final model per target with the best hyperparameters found by Optuna, collect its
# validation predictions, and plot the out-of-bag RMSE against the number of trees to see which
# number of trees works best.


def _plot_oob_training_logs(model, label):
    """
    Plot an out-of-bag metric from a trained TFDF model's training logs against the number of trees.

    Prefers RMSE and falls back to loss. Plotting is best-effort: failures are printed, not raised.
    """
    try:
        logs = model.make_inspector().training_logs()
        if not logs:
            print("No training logs found.")
            return

        points, metric_to_plot = [], None
        for candidate in ("rmse", "loss"):
            # Evaluation fields may be None when not computed, so hasattr() alone is not enough.
            points = [
                (log.num_trees, getattr(log.evaluation, candidate, None))
                for log in logs
                if getattr(log, "evaluation", None) is not None
            ]
            points = [(n_trees, value) for n_trees, value in points if value is not None]
            if points:
                metric_to_plot = candidate
                break

        if metric_to_plot is None:
            print(f"Could not find 'rmse' or 'loss' in training logs for {label}.")
            return

        num_trees, values = zip(*points)
        plt.figure(figsize=(6, 4))
        plt.plot(num_trees, values)
        plt.xlabel("Number of trees")
        plt.ylabel(f"{metric_to_plot.upper()} (out-of-bag)")
        plt.title(f"OOB {metric_to_plot.upper()} vs Trees for {label}")
        plt.show()
    except Exception as e:
        print(f"Could not plot training logs: {e}")


best_params = best_trial.params

# Per-target results, keyed by target column name.
final_evaluation_results = {}
all_predictions_dict = {}
all_true_values_dict = {}

for label in target_columns:
    print(f"\nTraining final model for target: {label}\n")

    columns = feature_names + [label]
    train_df_target = train_df[columns].copy()
    val_df_target = val_df[columns].copy()
    train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df_target, label=label, task=tfdf.keras.Task.REGRESSION)
    test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(val_df_target, label=label, task=tfdf.keras.Task.REGRESSION)

    final_model = tfdf.keras.RandomForestModel(
        task=tfdf.keras.Task.REGRESSION,
        features=[tfdf.keras.FeatureUsage(name=f) for f in feature_names],
        max_depth=best_params.get('max_depth', 16),
        min_examples=best_params.get('min_examples', 5),
        compute_oob_variable_importances=True,
        verbose=1
        )

    final_model.fit(train_ds)
    # summary() prints the summary itself and returns None; wrapping it in print() printed "None".
    final_model.summary()

    # Without compiled metrics, TFDF's evaluate() only reports loss = 0.0.
    final_model.compile(metrics=["mse"])
    evaluation = final_model.evaluate(test_ds, return_dict=True)
    final_evaluation_results[label] = evaluation
    print("\nFinal Evaluation Metrics (from model.evaluate):")
    for name, value in evaluation.items():
        print(f"  {name}: {value:.4f}")

    # Predictions follow the row order of val_df_target (the dataset is not shuffled).
    predictions = final_model.predict(test_ds)
    all_predictions_dict[label] = predictions.flatten()
    all_true_values_dict[label] = val_df_target[label].to_numpy()

    _plot_oob_training_logs(final_model, label)

In [None]:
# Stack true values and predictions column-wise, one column per target. The order comes from
# target_columns, the same names that keyed all_true_values_dict / all_predictions_dict during
# training, so the two cannot drift apart.

target_order = list(target_columns)
if not target_order:
    raise ValueError("No target columns to evaluate: the TFDF section supports MechanoProDB only.")
y_true_matrix = np.column_stack([all_true_values_dict[t] for t in target_order])
y_pred_matrix = np.column_stack([all_predictions_dict[t] for t in target_order])

final_metrics_df = evaluate_predictions(y_true_matrix, y_pred_matrix)

## Saving the Metrics Dataframe for TFDF

In [None]:
# Output folder for the ML model results of the selected database (created if missing).
save_dir = f"{main_path}Models_Artifacts_{database}/ML_Models/"
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Save the per-target metrics table as tab-separated values.
# NOTE(review): only the metrics table is written here, despite the "Model artifacts" wording.
final_metrics_df.to_csv(os.path.join(save_dir, "metrics_df_tfdf.tsv"), sep='\t', index = False)
print(f"Model artifacts are saved to: {os.path.abspath(save_dir)}")

# NNs (Neural Networks)

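Run this section to train the Neural Network models. It expects `train_df` and `val_df` as loaded by the data-loading cells above. The TFDF section modifies them (it renames columns and drops `targets` and `numeric_embeddings`), so if the TFDF section was run first, reload the splits before running this one.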


## Functions and Setups

In [None]:
# Define your conditions and parameters.
# NOTE(review): these are consumed by the training cells later in the notebook (not visible
# here); the scheduler values presumably configure ReduceLROnPlateau (imported at the top) - confirm.

early_stopping_patience_value = 30
scheduler_patience_value = 2
scheduler_factor = 0.1
learning_rate = 2e-5
min_learning_rate = 1e-7
batch_size = 4
gradient_clip_value = 1.0
num_epochs = 50

In [None]:
def relative_root_mean_squared_error(y_true, y_pred):
    """
    Return the mean absolute relative error, mean(|1 - y_pred / y_true|), over entries with y_true != 0.

    Despite the name, nothing is squared or square-rooted: this is MAPE expressed as a fraction,
    restricted to non-zero targets so that the division is defined. 1-D and 2-D inputs are
    accepted; a 2-D input is pooled over all of its entries.

    Returns:
        The error as a float, or NaN when every true value is zero (or the input is empty)
        or when the result is not finite. Matches the TFDF section's definition.
    """
    y_true = np.asarray(y_true).astype(float)
    y_pred = np.asarray(y_pred).astype(float)

    nonzero = y_true != 0
    if not nonzero.any():
        # np.mean of an empty selection would also be NaN, but with a RuntimeWarning.
        return np.nan

    error = np.mean(np.abs(1 - y_pred[nonzero] / y_true[nonzero]))
    return error if np.isfinite(error) else np.nan

class MetricsCalculatorUpd:
    """
    Calculates and stores evaluation metrics for training and validation.

    Per epoch: call ``update_train_batch`` / ``update_val_batch`` for every batch, then
    ``calculate_epoch_metrics(epoch)`` once. That call records the epoch in ``self.metrics``
    and clears the batch accumulators. Losses are sample-weighted averages of the values passed
    in. MSE, RMSE, R2 and "RRMSE" are computed on the original target scale after inverting
    ``scaler``. "RRMSE" comes from the notebook's ``relative_root_mean_squared_error``, which
    returns the mean absolute relative error.
    """

    # Prefixes of the per-output metric keys ("val_mse_0", "val_rmse_0", ...).
    _PER_OUTPUT_METRICS = ('val_mse', 'val_rmse', 'val_r2', 'val_rrmse')

    def __init__(self, scaler):
        """
        Args:
            scaler: Fitted target scaler exposing ``inverse_transform`` (and ideally
                ``n_features_in_``), or None to track losses only.
        """
        self.scaler = scaler
        self.num_outputs = 0
        self._per_output_initialized = False

        if scaler is None:
            print("Warning: MetricsCalculator initialized without a target scaler. "
                  "Inverse transform for interpretable metrics (MSE, RMSE, RRMSE, R2) will not be possible.")
        else:
            try:
                self.num_outputs = scaler.n_features_in_
                print(f"MetricsCalculator initialized for {self.num_outputs} target outputs.")
            except AttributeError:
                print("Warning: Scaler provided does not have 'n_features_in_'. Number of outputs will be inferred later.")

        self.metrics = {
            'epoch': [],
            'train_loss': [],
            'val_loss': [],
            'val_mse_avg': [],
            'val_rmse_avg': [],
            'val_r2_avg': [],
            'val_rrmse_avg': []
        }

        if self.num_outputs > 0:
            self._initialize_per_output_metrics()
        self._reset_epoch_accumulators()

    def _initialize_per_output_metrics(self):
        """
        Create (or reset to empty) the per-output metric lists for ``self.num_outputs`` outputs.
        """
        if self.num_outputs <= 0:
            print("Warning: Cannot initialize per-output metrics without knowing the number of outputs.")
            return
        for i in range(self.num_outputs):
            for prefix in self._PER_OUTPUT_METRICS:
                self.metrics[f'{prefix}_{i}'] = []
        self._per_output_initialized = True
        print(f"Initialized metric keys for {self.num_outputs} individual outputs.")

    def _reset_epoch_accumulators(self):
        """Clear the running loss sums and the stored validation batches."""
        self.epoch_train_loss, self.epoch_train_samples = 0.0, 0
        self.epoch_val_loss, self.epoch_val_samples = 0.0, 0
        self.epoch_val_preds_scaled, self.epoch_val_targets_scaled, self.epoch_val_targets_original = [], [], []

    def update_train_batch(self, loss_item, batch_size):
        """Accumulate one training batch's mean loss, weighted by its size."""
        self.epoch_train_loss += loss_item * batch_size
        self.epoch_train_samples += batch_size

    def update_val_batch(self, loss_item, predictions_scaled, targets_scaled, targets_original, batch_size):
        """
        Accumulate one validation batch: its size-weighted loss plus predictions/targets as numpy arrays.

        ``predictions_scaled`` and ``targets_scaled`` must be tensors. ``targets_original`` may be
        a tensor or anything ``np.asarray`` accepts; if conversion fails it is skipped.
        """
        self.epoch_val_loss += loss_item * batch_size
        self.epoch_val_samples += batch_size

        preds_np = predictions_scaled.detach().cpu().numpy()
        targets_scaled_np = targets_scaled.detach().cpu().numpy()
        if isinstance(targets_original, torch.Tensor):
            targets_original_np = targets_original.detach().cpu().numpy()
        else:
            try:
                targets_original_np = np.asarray(targets_original)
            except Exception as e:
                print(f"Warning: Could not convert targets_original to numpy array: {e}")
                targets_original_np = None

        self.epoch_val_preds_scaled.append(preds_np)
        self.epoch_val_targets_scaled.append(targets_scaled_np)
        if targets_original_np is not None:
            self.epoch_val_targets_original.append(targets_original_np)

    def _sync_num_outputs(self, n_outputs):
        """Infer the number of outputs on first use, or re-initialize per-output lists if it changed."""
        if self.num_outputs == 0:
            self.num_outputs = n_outputs
            print(f"Inferred number of outputs: {self.num_outputs}")
            if not self._per_output_initialized:
                self._initialize_per_output_metrics()
        elif n_outputs != self.num_outputs:
            print(f"Warning: Data shape mismatch! Expected {self.num_outputs} outputs, got {n_outputs}. Re-initializing metrics.")
            self.num_outputs = n_outputs
            self._initialize_per_output_metrics()

    def _collect_original_scale_arrays(self):
        """
        Concatenate this epoch's validation batches and map the predictions back to the original scale.

        Returns:
            ``(targets_original, preds_original)``. The predictions are all-NaN when their width
            does not match the scaler. Both are None if concatenation or the inverse transform fails.
        """
        try:
            all_preds_scaled = np.concatenate(self.epoch_val_preds_scaled, axis=0)
            all_targets_original = np.concatenate(self.epoch_val_targets_original, axis=0)
            self._sync_num_outputs(all_targets_original.shape[1])

            if all_preds_scaled.shape[1] == self.scaler.n_features_in_:
                return all_targets_original, self.scaler.inverse_transform(all_preds_scaled)
            print(f"Warning: Shape mismatch during inverse transform. Preds shape[1]: {all_preds_scaled.shape[1]}, Scaler expects: {self.scaler.n_features_in_}.")
            return all_targets_original, np.full(all_targets_original.shape, np.nan)
        except Exception as e:
            # Targets may not exist yet when this fails, so no placeholder array can be built from them.
            print(f"Error during inverse transform or concatenation: {e}")
            return None, None

    def _compute_original_scale_metrics(self, targets, preds):
        """
        Compute averaged and per-output MSE, RMSE, R2 and RRMSE for (n_samples, num_outputs) arrays.

        Returns:
            ``(mse_avg, rmse_avg, r2_avg, rrmse_avg, mses, rmses, r2s, rrmses)``.
        """
        val_mse_avg = mean_squared_error(targets, preds)
        val_rmse_avg = np.sqrt(val_mse_avg)
        val_r2_avg = r2_score(targets, preds)
        val_rrmse_avg = relative_root_mean_squared_error(targets, preds)

        mses, rmses, r2s, rrmses = [], [], [], []
        for i in range(self.num_outputs):
            target_i = targets[:, i]
            pred_i = preds[:, i]
            mse_i = mean_squared_error(target_i, pred_i)
            rrmse_i = relative_root_mean_squared_error(target_i, pred_i)
            try:
                r2_i = r2_score(target_i, pred_i)
            except ValueError:
                print(f"Warning: R2 score could not be calculated for output {i} (constant target?).")
                r2_i = np.nan

            mses.append(mse_i)
            rmses.append(np.sqrt(mse_i))
            r2s.append(r2_i)
            rrmses.append(rrmse_i)

        return val_mse_avg, val_rmse_avg, val_r2_avg, val_rrmse_avg, mses, rmses, r2s, rrmses

    def calculate_epoch_metrics(self, epoch):
        """
        Finalize the current epoch: compute, record and return its metrics, then clear the accumulators.

        Args:
            epoch: Zero-based epoch index (stored as ``epoch + 1``).

        Returns:
            Tuple ``(train_loss, val_loss, val_mse_avg, val_rmse_avg, val_r2_avg, val_rrmse_avg,
            list_mse, list_rmse, list_r2, list_rrmse)``. Original-scale metrics are NaN and the
            per-output lists are empty when they could not be computed.
        """
        train_loss = self.epoch_train_loss / self.epoch_train_samples if self.epoch_train_samples > 0 else 0.0
        val_loss = self.epoch_val_loss / self.epoch_val_samples if self.epoch_val_samples > 0 else 0.0

        val_mse_avg = val_rmse_avg = val_r2_avg = val_rrmse_avg = np.nan
        list_val_mse, list_val_rmse, list_val_r2, list_val_rrmse = [], [], [], []

        if self.epoch_val_preds_scaled and self.scaler is not None and self.epoch_val_targets_original:
            all_targets_original, all_preds_original = self._collect_original_scale_arrays()

            if (all_preds_original is not None and self.num_outputs > 0
                    and not np.isnan(all_preds_original).any()):
                try:
                    (val_mse_avg, val_rmse_avg, val_r2_avg, val_rrmse_avg,
                     list_val_mse, list_val_rmse, list_val_r2, list_val_rrmse) = self._compute_original_scale_metrics(
                        all_targets_original, all_preds_original)
                except Exception as e:
                    # Nothing was assigned, so the NaN / empty defaults above are kept.
                    print(f"Error calculating metrics: {e}")
            else:
                print("Metrics (MSE, RMSE, R2, RRMSE) on original scale could not be calculated (NaN predictions or zero outputs).")

        self.metrics['epoch'].append(epoch + 1)
        self.metrics['train_loss'].append(train_loss)
        self.metrics['val_loss'].append(val_loss)
        self.metrics['val_mse_avg'].append(val_mse_avg)
        self.metrics['val_rmse_avg'].append(val_rmse_avg)
        self.metrics['val_r2_avg'].append(val_r2_avg)
        self.metrics['val_rrmse_avg'].append(val_rrmse_avg)

        if not self._per_output_initialized and self.num_outputs > 0 and list_val_mse:
            self._initialize_per_output_metrics()
        if self._per_output_initialized:
            per_output_values = dict(zip(
                self._PER_OUTPUT_METRICS, (list_val_mse, list_val_rmse, list_val_r2, list_val_rrmse)))
            for i in range(self.num_outputs):
                for prefix, values in per_output_values.items():
                    # Outputs without a computed value this epoch get NaN, keeping lists aligned.
                    self.metrics[f'{prefix}_{i}'].append(values[i] if i < len(values) else np.nan)

        self._reset_epoch_accumulators()

        return (train_loss, val_loss,
                val_mse_avg, val_rmse_avg, val_r2_avg, val_rrmse_avg,
                list_val_mse, list_val_rmse, list_val_r2, list_val_rrmse)

    def get_metrics_df(self):
        """
        Return all recorded metrics as a DataFrame with one row per epoch.

        A per-output list created after training started covers only the most recent epochs.
        This happens when the number of outputs was inferred late, or the lists were re-initialized
        after a shape change. Such lists are left-padded with NaN so that their values line up
        with the 'epoch' column.
        """
        if not self.metrics['epoch']:
            return pd.DataFrame(self.metrics)

        max_len = len(self.metrics['epoch'])
        for key, values in self.metrics.items():
            if len(values) < max_len:
                self.metrics[key] = [np.nan] * (max_len - len(values)) + values

        return pd.DataFrame(self.metrics)

    def get_last_comparison_batch(self):
        """
        Return (predictions, targets) of the most recent validation batch on the original scale.

        Returns (None, None) when no batch or scaler is available. This is only useful before
        ``calculate_epoch_metrics`` runs, because that call clears the stored batches.
        """
        if not self.epoch_val_preds_scaled or self.scaler is None:
            return None, None
        try:
            last_preds_scaled = self.epoch_val_preds_scaled[-1]
            last_targets_original = self.epoch_val_targets_original[-1] if self.epoch_val_targets_original else None
            if last_preds_scaled is None or last_targets_original is None:
                return None, None

            if last_preds_scaled.shape[1] != self.scaler.n_features_in_:
                print("Warning: Shape mismatch in get_last_comparison_batch.")
                return np.full_like(last_targets_original, np.nan), last_targets_original
            return self.scaler.inverse_transform(last_preds_scaled), last_targets_original
        except Exception as e:
            print(f"Error getting last comparison batch: {e}")
            return None, None

def collate_fn(batch):
    """
    Merge a list of dataset samples into one batch dictionary for the DataLoader.

    Each sample is a dict with tensors under 'numeric_features', 'targets_scaled' and
    'targets_original', plus per-column lists of tensors under 'input_ids' and 'attention_mask'.
    The fixed fields are stacked along a new batch dimension. The tokenized fields stay a list
    with one stacked tensor per column, so within a column every sample must already have the
    same length (torch.stack requires equal shapes).
    """
    def stack_field(key):
        return torch.stack([sample[key] for sample in batch])

    merged = {key: stack_field(key) for key in ('numeric_features', 'targets_scaled', 'targets_original')}

    column_count = len(batch[0]['input_ids'])
    merged['input_ids'] = [
        torch.stack([sample['input_ids'][col] for sample in batch]) for col in range(column_count)
    ]
    merged['attention_mask'] = [
        torch.stack([sample['attention_mask'][col] for sample in batch]) for col in range(column_count)
    ]
    return merged

## Model Choice

In [None]:
# Specify your model below.
# 1. ESM2
# 2. BERT - for combined ProtBERT and SciBERT and Regression Head
# 3. BiLSTM - for BiLSTM model, using ProtBERT and SciBERT tokenizations
# Use the exact strings "ESM2", "BERT" or "BiLSTM": later cells compare against them.

model_chosen = "ESM2"

## Processing Dataset

In [None]:
df = tokenized_df.copy()  # work on a copy so the NN preprocessing leaves tokenized_df untouched

In [None]:
# Dropping these columns as they were not utilized in the scope of this project.
# Only the columns that are present are dropped, so a dataset that has just one of them
# (or neither, e.g. ProThermDB) also works.

unused_text_columns = ['tokenized_Joint_Text_Cols', 'tokenized_domain_subsequences']
df.drop(columns=[col for col in unused_text_columns if col in df.columns], inplace=True)

In [None]:
# Base names of the text columns that have a tokenized counterpart.
# NOTE(review): dataset_text_columns is not used in the visible part of the notebook.
dataset_text_columns = [x.replace("tokenized_", "") for x in df.columns if "tokenized_" in x]
# NOTE(review): prot_col_list is only defined for BiLSTM and BERT (not ESM2), and the two branches
# use different forms ("tokenized_Sequence" vs "Sequence") - confirm what the dataset code expects.
if model_chosen == "BiLSTM":
    prot_col_list = ["tokenized_Sequence"]
    print(f"Protein column: {prot_col_list[0]}")
elif model_chosen == "BERT":
    prot_col_list = ['Sequence']
    print(f"Protein column: {prot_col_list[0]}")
# Every tokenized column except the protein sequence, with the "tokenized_" prefix removed.
sci_col_list = [col.replace("tokenized_", "") for col in df.columns if col.startswith("tokenized_") and col != "tokenized_Sequence"]

original_numeric_col = "numeric_embeddings"
target_col = "targets"

print(f"Scientific text columns: {sci_col_list}")
print(f"Original numeric column: {original_numeric_col}")
print(f"Original target column: {target_col}")

In [None]:
# Getting the lengths for features to input into the models.
# Works both for MechanoProDB and ProThermDB.

len_targets = len(tokenized_df.targets.iloc[0])  # number of regression outputs per sample
len_numeric = len(tokenized_df.numeric_embeddings.iloc[0])  # width of the numeric feature vector

len_prot_cols = 1  # a single protein sequence column
len_sci_cols = len(sci_col_list)
len_text_cols_all = len_prot_cols + len_sci_cols

In [None]:
# Instantiating numeric and target scalers for scaling.
# Numeric inputs are standardized (zero mean, unit variance); targets are scaled to [0, 1].
# Both are fitted on the training split only (see the cells below).

numeric_scaler = StandardScaler()
target_scaler = MinMaxScaler(feature_range=(0, 1))

In [None]:
# Numeric feature matrix for the whole dataset (one row per sample).
numeric_features_all = np.stack(df[original_numeric_col].values)
print(f"Numeric features shape: {numeric_features_all.shape}")

# Take each split's features from its own DataFrame, as the targets are taken further below.
# This guarantees the rows line up with train_df/val_df when the scaled features are written
# back into them. Indexing numeric_features_all with the splits' index labels only worked if
# those labels happened to equal row positions in df.
numeric_features_train = np.stack(train_df[original_numeric_col].values)
numeric_features_val = np.stack(val_df[original_numeric_col].values)

In [None]:
# Fit on the training split only and reuse its statistics for validation (no leakage).
numeric_features_train_scaled = numeric_scaler.fit_transform(numeric_features_train)
numeric_features_val_scaled = numeric_scaler.transform(numeric_features_val)
print("Numeric features are scaled.")

In [None]:
# Store one scaled feature row per sample. The lists are assigned positionally, so their order
# must match the DataFrames' row order.
scaled_numeric_col_name = "numeric_features_scaled"
train_df[scaled_numeric_col_name] = [row for row in numeric_features_train_scaled]
val_df[scaled_numeric_col_name] = [row for row in numeric_features_val_scaled]
print(f"Scaled numeric features added to DataFrames under column: {scaled_numeric_col_name}")

In [None]:
# Fit the target scaler on training targets only; validation targets reuse the training min/max.
print("Scaling target variables using MinMaxScaler...")
train_targets_original = np.stack(train_df[target_col].values)
val_targets_original = np.stack(val_df[target_col].values)
train_targets_scaled = target_scaler.fit_transform(train_targets_original)
val_targets_scaled = target_scaler.transform(val_targets_original)

In [None]:
# Keep both target versions per row as float32 tensors. Presumably the scaled ones are the
# training targets and the original ones feed the original-scale metrics - confirm in the training loop.
scaled_target_col_name = 'targets_scaled'
original_target_col_name = 'targets_original'
train_df[scaled_target_col_name] = [torch.tensor(row, dtype=torch.float32) for row in train_targets_scaled]
train_df[original_target_col_name] = [torch.tensor(row, dtype=torch.float32) for row in train_targets_original]
val_df[scaled_target_col_name] = [torch.tensor(row, dtype=torch.float32) for row in val_targets_scaled]
val_df[original_target_col_name] = [torch.tensor(row, dtype=torch.float32) for row in val_targets_original]
print("Scaled and original targets added to DataFrames.")

## Models

In [None]:
# Selecting available GPU or CPU if unavailable.
# All later `.to(device)` calls use this device.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# Loading and configuring pretrained models for chosen options.
# Some models are frozen or partially unfrozen for future training.

if model_chosen == "ESM2":
    model_name = "facebook/esm2_t6_8M_UR50D"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    output_size = model.config.hidden_size
    print(f"Hidden size: {output_size}")

    # Freeze the whole encoder first ...
    for param in model.parameters():
        param.requires_grad = False

    # Unfreezing the last two layers of the encoder for ESM2
    for param in model.encoder.layer[-2:].parameters():
        param.requires_grad = True
    # NOTE(review): unlike prot_model/sci_model below, the ESM2 model is not moved to `device`
    # in this cell - confirm it is moved before training.

elif model_chosen in ["BERT", "BiLSTM"]:
    prot_model_name = "Rostlab/prot_bert_bfd"
    sci_model_name = "allenai/scibert_scivocab_uncased"

    prot_tokenizer = AutoTokenizer.from_pretrained(prot_model_name)
    prot_model = AutoModel.from_pretrained(prot_model_name)
    prot_model.to(device)

    sci_tokenizer = AutoTokenizer.from_pretrained(sci_model_name)
    sci_model = AutoModel.from_pretrained(sci_model_name)
    sci_model.to(device)

# BERT: both encoders are fully frozen. No freezing is applied in this cell for BiLSTM.
if model_chosen == "BERT":
    for param in prot_model.parameters():
        param.requires_grad = False
    for param in sci_model.parameters():
        param.requires_grad = False

### Using Mean Pooling in AddedSubLayer (Run for ESM2)

In [None]:
class AddedSubLayer(nn.Module):
    """
    Regression head for ESM2.

    Every tokenized input column is passed through the shared ESM2 encoder and mean-pooled
    over its non-padding tokens. The pooled vectors are concatenated with the numeric
    features, passed through dropout and a linear layer, and squashed with a sigmoid.
    """
    def __init__(self, pretrained_model, num_numeric_features, output_size=len_targets, num_tokenized_cols=len_text_cols_all):
        super().__init__()
        self.esm2 = pretrained_model
        encoder_width = pretrained_model.config.hidden_size
        self.esm2_hidden_size = encoder_width
        self.num_tokenized_cols = num_tokenized_cols
        # One pooled vector per text column, plus the raw numeric features.
        self.fc_input_size = encoder_width * num_tokenized_cols + num_numeric_features
        self.fc = nn.Linear(self.fc_input_size, output_size)
        self.dropout = nn.Dropout(p=0.1)

    def _mean_pooling(self, model_output, attention_mask):
        """
        Average ``model_output.last_hidden_state`` over the sequence dimension, counting only
        positions where ``attention_mask`` is non-zero (the divisor is clamped to avoid /0).
        """
        hidden_states = model_output.last_hidden_state
        weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        summed = (hidden_states * weights).sum(dim=1)
        token_counts = weights.sum(dim=1).clamp(min=1e-9)
        return summed / token_counts

    def forward(self, input_ids_list, attention_mask_list, numeric_features):
        """
        Encode and mean-pool each text column, append the numeric features and return
        sigmoid-scaled predictions from the linear layer.
        """
        assert len(input_ids_list) == self.num_tokenized_cols, f"Expected {self.num_tokenized_cols} input sequences, got {len(input_ids_list)}"
        assert len(attention_mask_list) == self.num_tokenized_cols, f"Expected {self.num_tokenized_cols} attention masks, got {len(attention_mask_list)}"

        pooled_per_column = [
            self._mean_pooling(
                self.esm2(input_ids=input_ids_list[col], attention_mask=attention_mask_list[col]),
                attention_mask_list[col],
            )
            for col in range(self.num_tokenized_cols)
        ]
        text_features = torch.cat(pooled_per_column, dim=1)
        features = self.dropout(torch.cat((text_features, numeric_features), dim=1))
        return torch.sigmoid(self.fc(features))

### Using All sequences in AddedSublayer (Run for ESM2)

In [None]:
# Added SubLayer without Pooling applied for text features (performed worse than Mean Pooling).

# class AddedSubLayer(nn.Module):
#     """
#     Class definition for Regression Head, which takes the sequences directly, without pooling them.
#     """
#     def __init__(self, pretrained_model, num_numeric_features, output_size=len_targets, num_tokenized_cols=len_text_cols_all, max_length=1024):
#         super(AddedSubLayer, self).__init__()
#         self.esm2 = pretrained_model
#         self.esm2_hidden_size = pretrained_model.config.hidden_size
#         self.num_tokenized_cols = num_tokenized_cols
#         self.max_length = max_length

#         flattened_text_size_per_sequence = self.max_length * self.esm2_hidden_size
#         total_flattened_text_size = flattened_text_size_per_sequence * self.num_tokenized_cols
#         self.fc_input_size = total_flattened_text_size + num_numeric_features

#         print(f"WARNING: Final Linear layer input size = {self.fc_input_size} (Flattened Text: {total_flattened_text_size}, Numeric: {num_numeric_features})")
#         self.fc = nn.Linear(self.fc_input_size, output_size)
#         self.dropout = nn.Dropout(p=0.1)

#     def forward(self, input_ids_list, attention_mask_list, numeric_features):
#         """
#         Flattens the full hidden states of each tokenized input, concatenating the results with numeric features.
#         Produces Sigmoid-scaled predictions.
#         """
#         assert len(input_ids_list) == self.num_tokenized_cols, f"Expected {self.num_tokenized_cols} input sequences, got {len(input_ids_list)}"
#         assert len(attention_mask_list) == self.num_tokenized_cols, f"Expected {self.num_tokenized_cols} attention masks, got {len(attention_mask_list)}"
#         batch_size = numeric_features.shape[0]
#         flattened_outputs = []

#         for i in range(self.num_tokenized_cols):
#             outputs = self.esm2(input_ids=input_ids_list[i], attention_mask=attention_mask_list[i])
#             flattened_output = outputs.last_hidden_state.contiguous().view(batch_size, -1)
#             expected_flat_size = self.max_length * self.esm2_hidden_size
#             if flattened_output.shape[1] != expected_flat_size:
#                 raise RuntimeError(f"Flattened size mismatch for input {i}. "
#                                     f"Expected {expected_flat_size}, got {flattened_output.shape[1]}. "
#                                     f"Check max_length used in tokenizer vs. model init.")

#             flattened_outputs.append(flattened_output)

#         concatenated_flattened_text = torch.cat(flattened_outputs, dim=1)
#         combined_features = torch.cat((concatenated_flattened_text, numeric_features), dim=1)
#         combined_features = self.dropout(combined_features)
#         logits = self.fc(combined_features)
#         predictions_scaled = torch.sigmoid(logits)
#         return predictions_scaled

### Continue with ESM2 model

In [None]:
class ProteinDatasetESM2(Dataset):
    """
    Dataset for the ESM2 pipeline.

    Numeric features and both target variants are converted to float32 tensors up front.
    Every text column is kept as a list of strings (missing values become "") and is
    tokenized with ``tokenizer`` to exactly ``max_length`` tokens each time a sample is read.

    Raises:
        ValueError: if a text column is missing from ``df`` or any column's length disagrees
            with the number of targets.
    """
    def __init__(self, df, text_cols, num_feat_col="numeric_embeddings",
                 scaled_target_col="targets_scaled", original_target_col="targets_original",
                 tokenizer=None, max_length=1024):
        self.text_cols = text_cols
        feature_rows = df[num_feat_col].tolist()
        self.numeric_features = torch.stack(
            [torch.tensor(row, dtype=torch.float32) for row in feature_rows], dim=0
        )
        self.targets_scaled = torch.tensor(np.stack(df[scaled_target_col].values), dtype=torch.float32)
        self.targets_original = torch.tensor(np.stack(df[original_target_col].values), dtype=torch.float32)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_samples = len(self.targets_scaled)
        self.num_tokenized_cols = len(self.text_cols)

        self.text_dict = {}
        print("Loading text data...")
        for column in self.text_cols:
            if column not in df.columns:
                raise ValueError(f"Text column '{column}' not found in DataFrame.")
            self.text_dict[column] = df[column].fillna('').astype(str).tolist()

        print("Validating data lengths...")
        expected = self.num_samples
        n_numeric = len(self.numeric_features)
        if n_numeric != expected:
            raise ValueError(f"Mismatch in length between numeric features ({n_numeric}) and targets ({expected})")
        n_original = len(self.targets_original)
        if n_original != expected:
            raise ValueError(f"Mismatch in length between original targets ({n_original}) and scaled targets ({expected})")
        for column in self.text_cols:
            n_texts = len(self.text_dict[column])
            if n_texts != expected:
                raise ValueError(f"Mismatch in length between texts in column '{column}' ({n_texts}) and targets ({expected})")

    def __len__(self):
        """Return how many samples the dataset holds."""
        return self.num_samples

    def __getitem__(self, idx):
        """
        Return sample ``idx`` as a dict with its numeric features, scaled and original
        targets, and one padded/truncated ``input_ids`` / ``attention_mask`` tensor per
        text column (in ``text_cols`` order).
        """
        sample = {
            "numeric_features": self.numeric_features[idx],
            "targets_scaled": self.targets_scaled[idx],
            "targets_original": self.targets_original[idx],
            "input_ids": [],
            "attention_mask": [],
        }
        ids_out, mask_out = sample["input_ids"], sample["attention_mask"]

        for column in self.text_cols:
            encoded = self.tokenizer(
                self.text_dict[column][idx],
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors="pt",
            )
            # The tokenizer returns a batch of one; drop that leading dimension.
            ids_out.append(encoded['input_ids'].squeeze(0))
            mask_out.append(encoded['attention_mask'].squeeze(0))

        return sample


### BERT Model Classes

In [None]:
class ProteinDatasetBERT(Dataset):
    """
    Dataset for the BERT pipeline.

    Exactly one protein column is tokenized with ``prot_tokenizer`` (to ``max_length_prot``)
    and any number of scientific-text columns with ``sci_tokenizer`` (to ``max_length_sci``).
    Numeric features and targets are stored as float32 tensors; text is tokenized lazily.

    Raises:
        ValueError: if ``prot_text_cols`` does not contain exactly one column, a text column is
            missing, or column lengths disagree with the number of targets.
    """
    def __init__(self, df,
                 prot_text_cols,
                 sci_text_cols,
                 prot_tokenizer,
                 sci_tokenizer,
                 num_feat_col="numeric_embeddings",
                 scaled_target_col="targets_scaled",
                 original_target_col="targets_original",
                 max_length_prot=1024,
                 max_length_sci=512):
        self.prot_text_cols = prot_text_cols
        self.sci_text_cols = sci_text_cols
        self.prot_tokenizer = prot_tokenizer
        self.sci_tokenizer = sci_tokenizer
        self.max_length_prot = max_length_prot
        self.max_length_sci = max_length_sci

        n_prot = len(prot_text_cols)
        if n_prot != 1:
            raise ValueError(f"Expected exactly 1 protein text column, got {n_prot}")

        # Protein column first, then the scientific-text columns (matches __getitem__ order).
        self.all_text_cols = self.prot_text_cols + self.sci_text_cols

        feature_rows = df[num_feat_col].tolist()
        self.numeric_features = torch.stack(
            [torch.tensor(row, dtype=torch.float32) for row in feature_rows], dim=0
        )
        self.targets_scaled = torch.tensor(np.stack(df[scaled_target_col].values), dtype=torch.float32)
        self.targets_original = torch.tensor(np.stack(df[original_target_col].values), dtype=torch.float32)
        self.num_samples = len(self.targets_scaled)

        self.text_dict = {}
        print("Loading text data...")
        for column in self.all_text_cols:
            if column not in df.columns:
                raise ValueError(f"Text column '{column}' not found in DataFrame.")
            self.text_dict[column] = df[column].fillna('').astype(str).tolist()

        print("Validating data lengths...")
        expected = self.num_samples
        n_numeric = len(self.numeric_features)
        if n_numeric != expected:
            raise ValueError(f"Mismatch in length between numeric features ({n_numeric}) and targets ({expected})")
        n_original = len(self.targets_original)
        if n_original != expected:
            raise ValueError(f"Mismatch in length between original targets ({n_original}) and scaled targets ({expected})")
        for column in self.all_text_cols:
            n_texts = len(self.text_dict[column])
            if n_texts != expected:
                raise ValueError(f"Mismatch in length between texts in column '{column}' ({n_texts}) and targets ({expected})")

    def __len__(self):
        """Return how many samples the dataset holds."""
        return self.num_samples

    def _tokenize(self, tokenizer, text, max_length):
        """Tokenize one string to exactly ``max_length`` tokens; return 1-D (input_ids, attention_mask)."""
        encoded = tokenizer(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        """
        Return sample ``idx`` as a dict with numeric features, both target variants and the
        token tensors: the protein column first, followed by each scientific-text column.
        """
        sample = {
            "numeric_features": self.numeric_features[idx],
            "targets_scaled": self.targets_scaled[idx],
            "targets_original": self.targets_original[idx],
            "input_ids": [],
            "attention_mask": [],
        }

        plan = [(column, self.prot_tokenizer, self.max_length_prot) for column in self.prot_text_cols]
        plan += [(column, self.sci_tokenizer, self.max_length_sci) for column in self.sci_text_cols]
        for column, tokenizer, limit in plan:
            ids, mask = self._tokenize(tokenizer, self.text_dict[column][idx], limit)
            sample["input_ids"].append(ids)
            sample["attention_mask"].append(mask)

        return sample

In [None]:
class BertModel(nn.Module):
    """
    Regression head over one ProtBERT input and one or more SciBERT inputs.

    The protein input is mean-pooled and layer-normalized. Each SciBERT input is mean-pooled,
    and the pooled vectors are averaged into one. Both are concatenated with the numeric
    features, passed through dropout and a linear layer, and squashed with a sigmoid.

    NOTE: this class shadows ``transformers.BertModel`` imported at the top of the notebook;
    once this cell runs, ``BertModel`` refers to this head, not the Hugging Face model.

    Args:
        prot_model: encoder for the protein column.
        sci_model: encoder shared by all scientific-text columns.
        num_numeric_features: width of the numeric feature vector.
        output_size: number of regression targets.
        num_prot_cols: number of protein inputs; must be 1.
        num_sci_cols: number of scientific-text inputs; must be >= 1.
        dropout_rate: dropout probability applied to the concatenated features.

    Raises:
        ValueError: if ``num_prot_cols != 1`` or ``num_sci_cols < 1``.
        AttributeError: if a hidden size cannot be determined for either encoder.
    """
    def __init__(self, prot_model, sci_model, num_numeric_features, output_size=len_targets,
                 num_prot_cols=len_prot_cols, num_sci_cols=len_sci_cols, dropout_rate=0.1):
        # Zero-argument super(): ``super(BertModel, self)`` resolves the *global* name BertModel
        # at call time. That global is also bound by ``from transformers import ... BertModel``,
        # so re-running the import cell would make construction fail with a TypeError.
        super().__init__()

        if num_prot_cols != 1:
            raise ValueError("This implementation expects exactly 1 protein column.")
        if num_sci_cols < 1:
            # forward() averages the SciBERT outputs via torch.stack (fails on an empty list), and
            # fc_input_size always reserves sci_hidden_size for that average.
            raise ValueError(f"This implementation expects at least 1 scientific text column, got {num_sci_cols}.")

        self.prot_model = prot_model
        self.sci_model = sci_model
        self.num_prot_cols = num_prot_cols
        self.num_sci_cols = num_sci_cols

        try:
            self.prot_hidden_size = prot_model.config.hidden_size
        except AttributeError:
            if hasattr(prot_model, 'embed_dim'):
                self.prot_hidden_size = prot_model.embed_dim
            elif hasattr(prot_model, 'hidden_size'):
                self.prot_hidden_size = prot_model.hidden_size
            else:
                raise AttributeError("Could not determine hidden size for prot_model. Check model structure.")

        try:
            self.sci_hidden_size = sci_model.config.hidden_size
        except AttributeError:
            if hasattr(sci_model, 'hidden_size'):
                self.sci_hidden_size = sci_model.hidden_size
            else:
                raise AttributeError("Could not determine hidden size for sci_model. Check model structure.")

        # Pooled protein vector + averaged SciBERT vector + numeric features.
        self.fc_input_size = self.prot_hidden_size + self.sci_hidden_size + num_numeric_features
        self.norm_prot = nn.LayerNorm(self.prot_hidden_size)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fc = nn.Linear(self.fc_input_size, output_size)

    def _mean_pooling(self, model_output, attention_mask):
        """
        Average token embeddings over the sequence dimension, counting only positions where
        ``attention_mask`` is non-zero (the divisor is clamped to avoid division by zero).

        ``model_output`` may be an object with ``last_hidden_state``, a tuple/list whose first
        element is the hidden-state tensor (e.g. a model called with ``return_dict=False``),
        or the hidden-state tensor itself.
        """
        if hasattr(model_output, 'last_hidden_state'):
            token_embeddings = model_output.last_hidden_state
        elif isinstance(model_output, (tuple, list)):
            token_embeddings = model_output[0]
        else:
            token_embeddings = model_output

        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def forward(self, input_ids_list, attention_mask_list, numeric_features):
        """
        Encode the protein input (index 0) and the scientific-text inputs (indices 1..),
        pool them, combine with ``numeric_features`` and return sigmoid-scaled predictions.
        """
        total_expected_cols = self.num_prot_cols + self.num_sci_cols
        assert len(input_ids_list) == total_expected_cols, f"Expected {total_expected_cols} input sequences, got {len(input_ids_list)}"
        assert len(attention_mask_list) == total_expected_cols, f"Expected {total_expected_cols} attention masks, got {len(attention_mask_list)}"

        prot_outputs = self.prot_model(input_ids=input_ids_list[0], attention_mask=attention_mask_list[0])
        prot_pooled = self.norm_prot(self._mean_pooling(prot_outputs, attention_mask_list[0]))

        sci_pooled_outputs = []
        for i in range(self.num_prot_cols, total_expected_cols):
            sci_outputs = self.sci_model(input_ids=input_ids_list[i], attention_mask=attention_mask_list[i])
            sci_pooled_outputs.append(self._mean_pooling(sci_outputs, attention_mask_list[i]))

        # Average the pooled SciBERT vectors across text columns.
        sci_pooled_avg = torch.mean(torch.stack(sci_pooled_outputs, dim=0), dim=0)

        combined_features = torch.cat((prot_pooled, sci_pooled_avg, numeric_features), dim=1)
        combined_features = self.dropout(combined_features)

        logits = self.fc(combined_features)
        return torch.sigmoid(logits)

### BiLSTM Model classes

In [None]:
class MultiInputBiLSTMRegressor(nn.Module):
    """
    Regressor over one ProtBERT input, several SciBERT inputs and a numeric feature vector.

    Both BERT encoders are frozen. Each encoded sequence is summarized by a BiLSTM: the final
    forward and backward hidden states of its last layer are concatenated. The numeric
    features are projected to the same width. All summaries are concatenated and fed through
    an MLP (Linear -> ReLU -> Dropout -> Linear). No output activation is applied.

    Args:
        protbert_model, scibert_model: encoders exposing ``config.hidden_size`` and returning
            ``last_hidden_state``.
        num_lstm_layers, lstm_dropout: BiLSTM depth and inter-layer dropout (disabled for a
            single layer).
        final_dropout: dropout inside the regression MLP.
        hidden_size: BiLSTM hidden size per direction.
        num_numeric_features: width of the numeric feature vector.
        output_size: number of regression targets.
        len_prot_cols: number of protein inputs; must be 1 because ``forward`` consumes one.
        len_sci_cols: number of SciBERT inputs ``forward`` expects.

    Raises:
        ValueError: if ``len_prot_cols != 1``.
    """
    def __init__(self, protbert_model, scibert_model, num_lstm_layers=2, lstm_dropout=0.2, final_dropout=0.3, hidden_size=128,
                 num_numeric_features=len_numeric, output_size=len_targets, len_prot_cols = len_prot_cols, len_sci_cols = len_sci_cols):
        super(MultiInputBiLSTMRegressor, self).__init__()
        if len_prot_cols != 1:
            # combined_input_dim below counts len_prot_cols summaries, but forward() produces one.
            raise ValueError("This implementation expects exactly 1 protein column.")
        # Store the counts so forward() validates against the values this model was built with,
        # not against whatever the notebook-level globals currently hold.
        self.len_prot_cols = len_prot_cols
        self.len_sci_cols = len_sci_cols

        self.protbert = protbert_model
        self.scibert = scibert_model
        print("Freezing BERT parameters...")
        for param in self.protbert.parameters():
            param.requires_grad = False
        for param in self.scibert.parameters():
            param.requires_grad = False
        print("BERT parameters frozen.")

        self.lstm_hidden_size = hidden_size

        self.protbert_lstm = nn.LSTM(input_size=self.protbert.config.hidden_size,
                                     hidden_size=self.lstm_hidden_size,
                                     batch_first=True, num_layers=num_lstm_layers,
                                     dropout=lstm_dropout if num_lstm_layers > 1 else 0,
                                     bidirectional=True)
        print(f"ProtBERT BiLSTM: input_size={self.protbert.config.hidden_size}, hidden_size_per_direction={self.lstm_hidden_size}")

        self.scibert_lstm = nn.LSTM(input_size=self.scibert.config.hidden_size,
                                    hidden_size=self.lstm_hidden_size,
                                    batch_first=True, num_layers=num_lstm_layers,
                                    dropout=lstm_dropout if num_lstm_layers > 1 else 0,
                                    bidirectional=True)
        print(f"SciBERT BiLSTM: input_size={self.scibert.config.hidden_size}, hidden_size_per_direction={self.lstm_hidden_size}")

        # Numeric features are projected to the same width as one BiLSTM summary (2 directions).
        self.numeric_fc_output_dim = self.lstm_hidden_size * 2
        self.numeric_fc = nn.Linear(num_numeric_features, self.numeric_fc_output_dim)

        combined_input_dim = self.numeric_fc_output_dim * (len_prot_cols + len_sci_cols + 1)
        print(combined_input_dim)
        self.combined_fc = nn.Sequential(
            nn.Linear(combined_input_dim, 128),
            nn.ReLU(),
            nn.Dropout(final_dropout),
            nn.Linear(128, output_size)
        )

    @staticmethod
    def _encode_and_summarize(encoder, lstm, input_ids, attention_mask):
        """Run *encoder* then *lstm*; return the last layer's final forward and backward states, concatenated."""
        embeddings = encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Pack by true length so the LSTM's final states come from each sequence's last real token
        # instead of from trailing padding. Assumes right padding (1s followed by 0s in the mask)
        # — TODO confirm for the pre-tokenized protein column.
        lengths = attention_mask.sum(dim=1).clamp(min=1).to(device="cpu", dtype=torch.int64)
        packed = nn.utils.rnn.pack_padded_sequence(embeddings, lengths, batch_first=True, enforce_sorted=False)
        _, (hidden_n, _) = lstm(packed)
        # hidden_n[-2] / hidden_n[-1]: forward / backward direction of the last LSTM layer.
        return torch.cat((hidden_n[-2], hidden_n[-1]), dim=1)

    def forward(self,
                input_ids_prot: torch.Tensor, attention_mask_prot: torch.Tensor,
                input_ids_sci_list: list[torch.Tensor], attention_mask_sci_list: list[torch.Tensor],
                numeric_features: torch.Tensor
               ) -> torch.Tensor:
        """
        Summarize the protein input and each SciBERT input with their BiLSTMs, concatenate the
        summaries with the projected numeric features and return the MLP's raw outputs.

        Raises:
            ValueError: if the number of SciBERT inputs or masks differs from ``len_sci_cols``.
        """
        num_sci = self.len_sci_cols
        if len(input_ids_sci_list) != num_sci or len(attention_mask_sci_list) != num_sci:
            raise ValueError(f"Expected {num_sci} SciBERT inputs, but got {len(input_ids_sci_list)}")

        try:
            prot_pooled = self._encode_and_summarize(self.protbert, self.protbert_lstm,
                                                     input_ids_prot, attention_mask_prot)
        except Exception as e:
            print(f"!! ERROR during ProtBERT processing: {e}")
            raise

        sci_pooled_list = []
        for i, (input_ids_sci, attention_mask_sci) in enumerate(zip(input_ids_sci_list, attention_mask_sci_list)):
            try:
                sci_pooled_list.append(self._encode_and_summarize(self.scibert, self.scibert_lstm,
                                                                  input_ids_sci, attention_mask_sci))
            except Exception as e:
                print(f"!! ERROR during SciBERT processing (input {i+1}): {e}")
                print(f"   Input ID shape: {input_ids_sci.shape}")
                print(f"   Input Mask shape: {attention_mask_sci.shape}")
                raise

        try:
            numeric_out = self.numeric_fc(numeric_features)
        except Exception as e:
            print(f"!! ERROR during Numeric FC processing: {e}")
            print(f"   Input numeric_features shape: {numeric_features.shape}")
            raise

        try:
            combined = torch.cat([prot_pooled, *sci_pooled_list, numeric_out], dim=1)
        except Exception as e:
            print(f"!! ERROR during concatenation: {e}")
            print(f"  Shape prot_pooled: {prot_pooled.shape}")
            for idx, sp in enumerate(sci_pooled_list):
                print(f"  Shape sci_pooled_{idx}: {sp.shape}")
            print(f"  Shape numeric_out: {numeric_out.shape}")
            raise

        try:
            return self.combined_fc(combined)
        except Exception as e:
            print(f"!! ERROR during combined_fc pass: {e}")
            print(f"   Input shape to combined_fc was: {combined.shape}")
            raise

In [None]:
class MultiInputDataset(Dataset):
    """
    Dataset for the BiLSTM pipeline.

    The protein column holds pre-tokenized ``dict`` / ``BatchEncoding`` objects with
    ``input_ids`` and ``attention_mask``. The scientific-text columns hold raw text that is
    tokenized on the fly with ``scibert_tokenizer`` to ``scibert_max_len`` tokens. Numeric
    features and targets are returned as float32 tensors.

    Raises:
        ValueError: if the number of scientific-text columns differs from the notebook-level
            ``len_sci_cols`` or required DataFrame columns are missing.
    """
    def __init__(self, dataframe,
                 prot_token_col: str,
                 sci_text_cols: list[str],
                 scibert_tokenizer: AutoTokenizer,
                 scibert_max_len: int,
                 numeric_col_scaled: str,
                 target_col_scaled: str,
                 target_col_original: str):

        self.dataframe = dataframe
        self.prot_token_col = prot_token_col

        # NOTE(review): validated against the global ``len_sci_cols`` defined elsewhere in the notebook.
        if len(sci_text_cols) != len_sci_cols:
            raise ValueError(f"Expected {len_sci_cols} sci_text_cols, but got {len(sci_text_cols)}")
        self.sci_text_cols = sci_text_cols
        self.scibert_tokenizer = scibert_tokenizer
        self.scibert_max_len = scibert_max_len

        self.numeric_col_scaled = numeric_col_scaled
        self.target_col_scaled = target_col_scaled
        self.target_col_original = target_col_original

        required_cols = [prot_token_col] + sci_text_cols + [numeric_col_scaled, target_col_scaled, target_col_original]
        missing_cols = [col for col in required_cols if col not in dataframe.columns]
        if missing_cols:
            raise ValueError(f"DataFrame is missing required columns: {missing_cols}")

    def __len__(self):
        """
        Returns the number of samples in the dataset.
        """
        return len(self.dataframe)

    @staticmethod
    def _to_float_tensor(value):
        """Return *value* as a float32 tensor, converting an existing tensor instead of copying it."""
        if isinstance(value, torch.Tensor):
            return value.float()
        return torch.tensor(value, dtype=torch.float)

    def __getitem__(self, idx):
        """
        Retrieves a single sample by index with all corresponding data.

        Returns a dict with the protein ``input_ids``/``attention_mask``, per-column SciBERT
        token lists, numeric features, and scaled and original targets.
        """
        if idx >= len(self.dataframe):
            raise IndexError(f"Index {idx} out of bounds for dataframe with length {len(self.dataframe)}")
        try:
            item = self.dataframe.iloc[idx]
            prot_token_data = item[self.prot_token_col]
            if not isinstance(prot_token_data, (dict, BatchEncoding)):
                raise TypeError(f"Expected a dict in column '{self.prot_token_col}' at index {idx}, but got {type(prot_token_data)}")
            input_ids_prot = torch.tensor(prot_token_data['input_ids'], dtype=torch.long)
            attention_mask_prot = torch.tensor(prot_token_data['attention_mask'], dtype=torch.long)

            input_ids_sci_list = []
            attention_mask_sci_list = []
            for col_name in self.sci_text_cols:
                text = item[col_name]
                text = str(text) if pd.notna(text) and text is not None else ""
                sci_token_data = self.scibert_tokenizer(
                    text,
                    add_special_tokens=True,
                    padding='max_length',
                    max_length=self.scibert_max_len,
                    truncation=True,
                    return_attention_mask=True,
                    return_tensors='pt')

                # Drop the batch-of-one dimension returned by the tokenizer.
                input_ids_sci_list.append(sci_token_data['input_ids'].squeeze(0))
                attention_mask_sci_list.append(sci_token_data['attention_mask'].squeeze(0))

            return {
                'input_ids_prot': input_ids_prot,
                'attention_mask_prot': attention_mask_prot,
                'input_ids_sci_list': input_ids_sci_list,
                'attention_mask_sci_list': attention_mask_sci_list,
                'numeric_features': self._to_float_tensor(item[self.numeric_col_scaled]),
                'targets_scaled': self._to_float_tensor(item[self.target_col_scaled]),
                'targets_original': self._to_float_tensor(item[self.target_col_original]),
            }
        except KeyError as e:
            print(f"Error accessing key in DataFrame or tokenized dict at index {idx}: {e}")
            # Previously referenced the nonexistent ``self.sci_token_cols``, which raised an
            # AttributeError here and hid the original KeyError.
            print(f"Ensure column '{self.prot_token_col}' and columns in {self.sci_text_cols} exist.")
            print(f"Also ensure the dictionaries in column '{self.prot_token_col}' contain 'input_ids' and 'attention_mask'.")
            raise
        except TypeError as e:
            print(f"Type error processing item at index {idx}: {e}")
            raise
        except Exception as e:
            print(f"Generic error processing item at index {idx}: {e}")
            raise

### Continue with general part

In [None]:
# Creating Train and Validation Datasets and their DataLoaders for the chosen model.

print("Creating Datasets and Dataloaders...")
if model_chosen == "ESM2":
    train_dataset = ProteinDatasetESM2(train_df, text_cols=dataset_text_columns, tokenizer=tokenizer,
                                       num_feat_col=scaled_numeric_col_name,
                                       scaled_target_col=scaled_target_col_name,
                                       original_target_col=original_target_col_name)

    val_dataset = ProteinDatasetESM2(val_df, text_cols=dataset_text_columns, tokenizer=tokenizer,
                                     num_feat_col=scaled_numeric_col_name,
                                     scaled_target_col=scaled_target_col_name,
                                     original_target_col=original_target_col_name)

elif model_chosen == "BERT":
    train_dataset = ProteinDatasetBERT(train_df, prot_text_cols=prot_col_list, sci_text_cols=sci_col_list,
                                       prot_tokenizer=prot_tokenizer, sci_tokenizer=sci_tokenizer, num_feat_col=scaled_numeric_col_name,
                                       scaled_target_col=scaled_target_col_name, original_target_col=original_target_col_name)
    val_dataset = ProteinDatasetBERT(val_df, prot_text_cols=prot_col_list, sci_text_cols=sci_col_list,
                                     prot_tokenizer=prot_tokenizer, sci_tokenizer=sci_tokenizer, num_feat_col=scaled_numeric_col_name,
                                     scaled_target_col=scaled_target_col_name, original_target_col=original_target_col_name)

elif model_chosen == "BiLSTM":
    train_dataset = MultiInputDataset(dataframe=train_df, prot_token_col=prot_col_list[0], sci_text_cols=sci_col_list,
                                      scibert_tokenizer=sci_tokenizer, scibert_max_len=512,
                                      numeric_col_scaled=scaled_numeric_col_name, target_col_scaled=scaled_target_col_name,
                                      target_col_original=original_target_col_name)

    val_dataset = MultiInputDataset(dataframe=val_df, prot_token_col=prot_col_list[0],
                                    sci_text_cols=sci_col_list, scibert_tokenizer=sci_tokenizer,
                                    scibert_max_len=512, numeric_col_scaled=scaled_numeric_col_name,
                                    target_col_scaled=scaled_target_col_name, target_col_original=original_target_col_name)

else:
    # Fail loudly: otherwise a stale train_dataset/val_dataset left over from an earlier
    # notebook run would be reused silently below.
    raise ValueError(f"Unknown model_chosen {model_chosen!r}; expected 'ESM2', 'BERT' or 'BiLSTM'.")

if model_chosen == "BiLSTM":
    # BiLSTM relies on PyTorch's default collation (collate_fn=None).
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=None)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=None)
else:
    train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)

In [None]:
# Sizes shared by every model wrapper: numeric input width, text columns, target count.
num_numeric_features, num_tokenized_cols, output_size = (
    len_numeric,
    len(dataset_text_columns),
    len_targets,
)
print(f"Numeric features: {num_numeric_features}, Tokenized cols: {num_tokenized_cols}, Output size: {output_size}")

In [None]:
# Defining the model wrapper that is trained for the chosen model.

if model_chosen == "ESM2":
    model_wrapper = AddedSubLayer(pretrained_model=model,
                                  output_size=output_size,
                                  num_numeric_features=num_numeric_features,
                                  num_tokenized_cols=num_tokenized_cols)

elif model_chosen == "BERT":
    model_wrapper = BertModel(prot_model=prot_model, sci_model=sci_model,
                              num_numeric_features=num_numeric_features,
                              output_size=output_size, num_prot_cols=len(prot_col_list),
                              num_sci_cols=len(sci_col_list))

elif model_chosen == "BiLSTM":
    model_wrapper = MultiInputBiLSTMRegressor(protbert_model=prot_model, scibert_model=sci_model,
                                              hidden_size=256, num_lstm_layers=2,
                                              lstm_dropout=0.2, final_dropout=0.3,
                                              num_numeric_features=num_numeric_features,
                                              output_size=output_size)

else:
    # Fail loudly instead of moving (and later training) a stale model_wrapper from an earlier run.
    raise ValueError(f"Unknown model_chosen {model_chosen!r}; expected 'ESM2', 'BERT' or 'BiLSTM'.")

model_wrapper.to(device)
print(f"Model instantiated and moved to {device}.")

In [None]:
# Initializing the optimizer chosen for each model.
# ESM2 and BERT give the optimizer only parameters that still require gradients;
# BiLSTM passes every parameter of the wrapper.

if model_chosen in ("ESM2", "BERT"):
    trainable_params = (p for p in model_wrapper.parameters() if p.requires_grad)
    optimizer_cls = torch.optim.Adam if model_chosen == "ESM2" else torch.optim.AdamW
    optimizer = optimizer_cls(trainable_params, lr=learning_rate, weight_decay=1e-4)
elif model_chosen == "BiLSTM":
    optimizer = optim.AdamW(model_wrapper.parameters(), lr=learning_rate, weight_decay=0.01)

loss_fn = nn.MSELoss()
metrics_calculator = MetricsCalculatorUpd(scaler=target_scaler)

In [None]:
# Setting up the learning-rate scheduler and the state tracked during evaluation
# (best validation loss so far and epochs without improvement).

scheduler = ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=scheduler_factor,
    patience=scheduler_patience_value,
    min_lr=min_learning_rate,
)
print_prediction_comparison_epoch = True
best_val_loss, epochs_no_improve = float('inf'), 0

In [None]:
# Resolve where metrics, tokenizers and scalers are saved for the chosen model.
# Metric files are written tab-separated, so they get a ".tsv" suffix; the comparison
# section reads them back as "metrics_df_<suffix>.tsv".

_artifact_layout = {
    # model_chosen: (artifact directory prefix, metrics file suffix)
    "ESM2": ("ESM2", "esm"),
    "BERT": ("BERT", "bert"),
    "BiLSTM": ("MultiInputBiLSTM", "lstm"),
}
if model_chosen not in _artifact_layout:
    raise ValueError(f"Unsupported model_chosen={model_chosen!r}; expected one of {sorted(_artifact_layout)}.")

_dir_prefix, _metrics_suffix = _artifact_layout[model_chosen]
save_dir = f"{main_path}Models_Artifacts_{database}/{_dir_prefix}_{df_state}_text_50_epochs_Mean_Pooling"
os.makedirs(save_dir, exist_ok=True)
metrics_path = os.path.join(save_dir, f"metrics_df_{_metrics_suffix}.tsv")

if model_chosen == "ESM2":
    tokenizer_path = os.path.join(save_dir, 'tokenizer')  # tokenizer is saved to a directory
else:
    # BERT and BiLSTM use separate protein (ProtBERT) and text (SciBERT) tokenizers.
    prot_tokenizer_path = os.path.join(save_dir, 'prot_tokenizer')
    sci_tokenizer_path = os.path.join(save_dir, 'sci_tokenizer')

numeric_scaler_path = os.path.join(save_dir, 'numeric_feature_scaler.pkl')
target_scaler_path = os.path.join(save_dir, 'target_scaler.pkl')

print(f"Model artifacts will be saved to: {os.path.abspath(save_dir)}")

### Training and Saving

In [None]:
# Training and evaluation loop.
# - After each epoch the model is checkpointed to `save_dir` if the validation loss
#   improved on the best value seen so far.
# - Gradient clipping guards against exploding gradients; early stopping ends training
#   after `early_stopping_patience_value` consecutive epochs without improvement.
# - Metrics are accumulated in `metrics_calculator` and saved to `save_dir` in a later cell.

# Checkpoint file name per model (the inference section loads "ESM_best_model.pth").
_BEST_MODEL_FILENAMES = {
    "ESM2": "ESM_best_model.pth",
    "BERT": "BERT_best_model.pth",
    "BiLSTM": "LSTM_best_model.pth",
}


def _forward_batch(batch):
    """
    Move one dataloader batch to `device` and run `model_wrapper` on it.

    The BiLSTM wrapper takes the protein and text inputs as separate keyword arguments;
    the ESM2/BERT wrappers take per-column lists of input ids and attention masks.

    Returns:
        (predictions_scaled, numeric_features, targets_scaled, targets_original)
    """
    numeric_features = batch['numeric_features'].to(device)
    targets_scaled = batch['targets_scaled'].to(device)
    targets_original = batch['targets_original'].to(device)
    if model_chosen == "BiLSTM":
        predictions_scaled = model_wrapper(
            input_ids_prot=batch['input_ids_prot'].to(device),
            attention_mask_prot=batch['attention_mask_prot'].to(device),
            input_ids_sci_list=[tensor.to(device) for tensor in batch['input_ids_sci_list']],
            attention_mask_sci_list=[tensor.to(device) for tensor in batch['attention_mask_sci_list']],
            numeric_features=numeric_features,
        )
    else:
        predictions_scaled = model_wrapper(
            [tensor.to(device) for tensor in batch['input_ids']],
            [tensor.to(device) for tensor in batch['attention_mask']],
            numeric_features,
        )
    return predictions_scaled, numeric_features, targets_scaled, targets_original


def _fmt_metric(value):
    """Format a metric with six decimals, or "N/A" if it is NaN."""
    return "N/A" if np.isnan(value) else f"{value:.6f}"


start_time = time.time()
print("Starting training...")

for epoch in range(num_epochs):
    print(f"\n--- Epoch {epoch + 1}/{num_epochs} ---")
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Current Learning Rate: {current_lr:.2e}")

    # ---------------- Training ----------------
    model_wrapper.train()
    train_loop = tqdm(train_dataloader, desc=f"Epoch {epoch+1} Training", leave=False)
    for batch_num, batch in enumerate(train_loop):
        optimizer.zero_grad()
        predictions_scaled, numeric_features, targets_scaled, _ = _forward_batch(batch)

        loss = loss_fn(predictions_scaled, targets_scaled)
        loss.backward()
        if model_chosen == "ESM2":
            torch.nn.utils.clip_grad_norm_(model_wrapper.parameters(), max_norm=gradient_clip_value)
        else:
            torch.nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model_wrapper.parameters()), max_norm=gradient_clip_value)
        optimizer.step()

        metrics_calculator.update_train_batch(loss.item(), numeric_features.size(0))
        train_loop.set_postfix(loss=f"{loss.item():.4f}")

        if batch_num % 100 == 0:
            seen_samples = metrics_calculator.epoch_train_samples
            current_avg_loss = metrics_calculator.epoch_train_loss / seen_samples if seen_samples else 0
            print(f"  Epoch {epoch+1}, Batch {batch_num}/{len(train_dataloader)}, Current Avg Train Loss (scaled): {current_avg_loss:.6f}")

    # ---------------- Validation ----------------
    model_wrapper.eval()
    with torch.no_grad():
        val_loop = tqdm(val_dataloader, desc=f"Epoch {epoch+1} Validation", leave=False, unit="batch")
        for batch_num, batch in enumerate(val_loop):
            predictions_scaled, numeric_features, targets_scaled, targets_original = _forward_batch(batch)

            val_loss_batch = loss_fn(predictions_scaled, targets_scaled)
            metrics_calculator.update_val_batch(val_loss_batch.item(),
                                                predictions_scaled,
                                                targets_scaled,
                                                targets_original,
                                                numeric_features.size(0))
            val_loop.set_postfix(val_loss=f"{val_loss_batch.item():.4f}")

            # Show a few de-scaled predictions from the first validation batch.
            if print_prediction_comparison_epoch and batch_num == 0:
                predictions_original = target_scaler.inverse_transform(predictions_scaled.cpu().numpy())
                targets_original_np = targets_original.cpu().numpy()

                print("\n--- Prediction vs. Actual (Original Scale) ---")
                print("First few examples from the first validation batch:")
                for i in range(min(5, len(predictions_original))):
                    print(f"Example {i+1}:")
                    print(f"Prediction: {[f'{p:.2f}' for p in predictions_original[i]]}")
                    print(f"Actual:     {[f'{t:.2f}' for t in targets_original_np[i]]}")
                    diff = predictions_original[i] - targets_original_np[i]
                    print(f"Difference: {[f'{d:.2f}' for d in diff]}")
                print("--------------------------------------------")

    # ---------------- Epoch summary ----------------
    (train_loss_epoch, val_loss_epoch,
     val_mse_avg_epoch, val_rmse_avg_epoch, val_r2_avg_epoch, val_rrmse_avg_epoch,
     list_val_mse_epoch, list_val_rmse_epoch, list_val_r2_epoch, list_val_rrmse_epoch) = \
        metrics_calculator.calculate_epoch_metrics(epoch)

    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"Train Loss (scaled): {train_loss_epoch:.6f}")
    print(f"  Val Loss (scaled):   {_fmt_metric(val_loss_epoch)}")
    print(f"  Val MSE (avg):       {_fmt_metric(val_mse_avg_epoch)}")
    print(f"  Val RMSE (avg):      {_fmt_metric(val_rmse_avg_epoch)}")
    print(f"  Val R^2 (avg):       {_fmt_metric(val_r2_avg_epoch)}")
    print(f"  Val RRMSE (avg):       {_fmt_metric(val_rrmse_avg_epoch)}")

    if list_val_mse_epoch:
        # Per-output (per target variable) validation metrics.
        per_output = zip(list_val_mse_epoch, list_val_rmse_epoch, list_val_r2_epoch, list_val_rrmse_epoch)
        for i, (mse_i, rmse_i, r2_i, rrmse_i) in enumerate(per_output):
            print(f"  Output {i}: MSE={_fmt_metric(mse_i)}, RMSE={_fmt_metric(rmse_i)}, "
                  f"R^2={_fmt_metric(r2_i)}, RRMSE={_fmt_metric(rrmse_i)}")
    else:
        print("Per-output metrics could not be calculated.")

    scheduler.step(val_loss_epoch)

    # ---------------- Checkpointing / early stopping ----------------
    if val_loss_epoch < best_val_loss:
        print(f"Validation loss improved ({best_val_loss:.6f} --> {val_loss_epoch:.6f}). Saving model...")
        best_val_loss = val_loss_epoch
        # Reset so early stopping counts *consecutive* epochs without improvement.
        epochs_no_improve = 0
        torch.save(model_wrapper.state_dict(), os.path.join(save_dir, _BEST_MODEL_FILENAMES[model_chosen]))
    else:
        epochs_no_improve += 1
        print(f"Validation loss did not improve for {epochs_no_improve} epoch(s). Best so far: {best_val_loss:.6f}")

    if epochs_no_improve >= early_stopping_patience_value:
        print(f"\nEarly stopping triggered after {epoch + 1} epochs.")
        break
else:
    # Runs only when the loop was not interrupted by early stopping.
    print("\nTraining finished after reaching max epochs.")

print(f"Best validation loss achieved: {best_val_loss:.6f}")

print("\nTraining finished.")
end_time = time.time()

In [None]:
# Report the wall-clock time spent in the training cell, in total and per epoch.

elapsed_time = end_time - start_time
total_epochs_run = epoch + 1
elapsed_hours = elapsed_time / 3600
print(f"Training took: {elapsed_hours:.2f} hours ({elapsed_time:.2f} seconds)")

if total_epochs_run <= 0:
    print("No epochs were completed.")
else:
    print(f"Average time per epoch: {elapsed_time / total_epochs_run:.2f} seconds")

In [None]:
# Show the full per-epoch metrics table to inspect how training evolved.

final_metrics_df = metrics_calculator.get_metrics_df()
summary_table = final_metrics_df.to_string()
print("\n--- Final Metrics Summary ---")
print(summary_table)

In [None]:
# Persist what is needed later for evaluation and inference: the metrics table
# (tab-separated), the tokenizer(s) of the chosen model, and both fitted scalers.
# NOTE(review): the save-path cell also defines prot/sci tokenizer paths for BiLSTM, but no
# tokenizer is saved for BiLSTM here — confirm whether that is intended.

final_metrics_df.to_csv(metrics_path, sep='\t', index=False)

if model_chosen == "ESM2":
    print(f"Saving tokenizer to directory: {tokenizer_path}")
    tokenizer.save_pretrained(tokenizer_path)
elif model_chosen == "BERT":
    for _tok_label, _tok, _tok_path in (
        ("Prot", prot_tokenizer, prot_tokenizer_path),
        ("Sci", sci_tokenizer, sci_tokenizer_path),
    ):
        print(f"Saving {_tok_label} tokenizer to directory: {_tok_path}")
        _tok.save_pretrained(_tok_path)

for _scaler_label, _fitted_scaler, _scaler_path in (
    ("numeric feature scaler", numeric_scaler, numeric_scaler_path),
    ("target value scaler", target_scaler, target_scaler_path),
):
    print(f"Saving {_scaler_label} to: {_scaler_path}")
    joblib.dump(_fitted_scaler, _scaler_path)

print("\nInference Artifacts are saved successfully")

# Inference (for the BEST model) and visualizations

In [None]:
# Specify the inputs for a single prediction ('Sequence' is REQUIRED for inference to work!).
# Leave "<UNK>" for unspecified text columns.
# Leave np.nan for unspecified numeric columns. Do not set 'Total Length (AA)' or
# 'Contour Length [nm]': both are recomputed from 'Sequence' before prediction.
# Keep the keys in this order: the values are later passed to the model positionally
# (via list(dict.values())).

text_inputs_dict = {
    'Sequence': "RLDAPSQIEVKDVTDTTALITWFKPLAEIDGIELTYGIKDVPGDRTTIDLTEDENQYSIGNLKPDTEYEVSLISRRGDMSSNPAKETFTT",
    'Name': "FIBRONECTIN TYPE III DOMAIN FROM TENASCIN",
    'SCOP annotation': "<UNK>",
    'Experimental Conditions': "pH = 7",
    'Organism': "Homo sapiens",
    'Classification': "CELL ADHESION PROTEIN",
    'Technique': "<UNK>",
    'Pulling Mode': "<UNK>",
    'Unfolding Pathway': "<UNK>",
    'PDB_UniProt': "1TEN"
}

numeric_inputs_dict = {
    'Highest unfolding forces/ Clamp forces [pN]': np.nan,
    'Standard Deviation of force [pN]':np.nan,
    'Total Length (AA)': np.nan,
    'Pulling Start': np.nan,
    'Pulling End': np.nan,
    'Velocity [nm/s]': np.nan,
    'Contour Length [nm]': np.nan
}

In [None]:
# Names of the predicted target variables, in the same order as the model's outputs.
target_columns = list(("ΔG [kBT]", "Xu [nm]", "Koff [s-¹]"))

In [None]:
# ESM-2 AddedSubLayer, redefined here so this section can run without the Models section.

class AddedSubLayer(nn.Module):
    """
    Regression head on top of a shared ESM-2 backbone.

    Each tokenized text column is encoded separately by the backbone and mean-pooled over
    its attention mask. The pooled vectors are concatenated with the numeric features,
    passed through dropout and one linear layer, and squashed with a sigmoid.
    The submodule names (`esm2`, `fc`, `dropout`) must stay unchanged so saved
    state_dicts keep loading.
    """

    def __init__(self, pretrained_model, num_numeric_features, output_size, num_tokenized_cols):
        super().__init__()
        self.esm2 = pretrained_model
        self.esm2_hidden_size = pretrained_model.config.hidden_size
        self.num_tokenized_cols = num_tokenized_cols
        # One pooled embedding per text column, followed by the numeric features.
        self.fc_input_size = self.esm2_hidden_size * num_tokenized_cols + num_numeric_features

        self.fc = nn.Linear(self.fc_input_size, output_size)
        self.dropout = nn.Dropout(p=0.1)

    def _mean_pooling(self, model_output, attention_mask):
        """Average `last_hidden_state` over the positions where `attention_mask` is 1."""
        hidden = model_output.last_hidden_state
        mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
        # The clamp avoids dividing by zero for a row with no attended tokens.
        return (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

    def forward(self, input_ids_list, attention_mask_list, numeric_features):
        """Return sigmoid-scaled predictions for a batch of per-column ids/masks plus numeric features."""
        n_cols = self.num_tokenized_cols
        assert len(input_ids_list) == n_cols, f"Expected {self.num_tokenized_cols} input sequences, got {len(input_ids_list)}"
        assert len(attention_mask_list) == n_cols, f"Expected {self.num_tokenized_cols} attention masks, got {len(attention_mask_list)}"

        pooled = [
            self._mean_pooling(
                self.esm2(input_ids=input_ids_list[col], attention_mask=attention_mask_list[col]),
                attention_mask_list[col],
            )
            for col in range(n_cols)
        ]
        features = torch.cat([torch.cat(pooled, dim=1), numeric_features], dim=1)
        return torch.sigmoid(self.fc(self.dropout(features)))

In [None]:
# Inference is defined only for MechanoProDB (the ProThermDB dataset received was
# incomplete) and only for the ESM-2 model.
# A dedicated variable names the inference database so that the session-wide `database`
# setting (and the `mechano` flag derived from it at the top of the notebook) is not
# silently overwritten for later cells.

inference_database = 'mechano'
artifacts_path = f"{main_path}Models_Artifacts_{inference_database}/ESM2_all_text_50_epochs_Mean_Pooling/"
target_scaler_path = f"{artifacts_path}target_scaler.pkl"
numeric_scaler_path = f"{artifacts_path}numeric_feature_scaler.pkl"
model_state_path = f"{artifacts_path}ESM_best_model.pth"
tokenizer_path = f"{artifacts_path}tokenizer"

# The fitted scalers fix the model's output size and its numeric input width.
loaded_target_scaler = joblib.load(target_scaler_path)
output_size = loaded_target_scaler.n_features_in_
print(f"Scaler expects {output_size} target variables.")
loaded_numeric_scaler = joblib.load(numeric_scaler_path)
total_num_features = loaded_numeric_scaler.n_features_in_
print(f"Scaler expects {total_num_features} numeric features.")
loaded_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

max_token_length = 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model_name = "facebook/esm2_t6_8M_UR50D"
num_expected_text_cols = 10  # must equal the number of entries in text_inputs_dict

try:
    print(f"Loading base model: {base_model_name}...")
    base_esm_model = AutoModel.from_pretrained(base_model_name).to(device)
    print("Base model loaded.")

    model_inference = AddedSubLayer(
        pretrained_model=base_esm_model,
        num_numeric_features=total_num_features,
        output_size=output_size,
        num_tokenized_cols=num_expected_text_cols
    ).to(device)

    model_inference.load_state_dict(torch.load(model_state_path, map_location=device))
    model_inference.eval()
    print("Wrapper model instantiated and state loaded successfully.")
except Exception as e:
    print(f"Error loading model state from {model_state_path}: {e}")
    print("Check base_model_name, AddedSubLayer definition, and saved state compatibility.")
    # Re-raise rather than call exit(): in a notebook exit() may shut down the kernel,
    # whereas raising just stops this cell with a full traceback.
    raise

In [None]:
def predict_single_instance(text_inputs, numeric_inputs, model, tokenizer, numeric_scaler, target_scaler, max_len, device, num_text_cols):
    """
    Predict the target values for one sample.

    Each text input is tokenized on its own (special tokens added, padded/truncated to
    `max_len`). The numeric inputs are transformed with `numeric_scaler`, the model is run
    in eval mode without gradients, and its scaled outputs are mapped back with
    `target_scaler.inverse_transform`.

    Args:
        text_inputs: sequence of `num_text_cols` strings, one per text column.
        numeric_inputs: sequence of raw numeric values, one per scaler feature.
        model: callable taking `input_ids_list`, `attention_mask_list`, `numeric_features`.
        tokenizer: object providing `encode_plus`.
        numeric_scaler / target_scaler: fitted scalers (`transform` / `inverse_transform`).
        max_len: token length every text input is padded/truncated to.
        device: torch device the tensors are moved to.
        num_text_cols: expected number of text inputs.

    Returns:
        1-D numpy array of predictions in the original target scale.

    Raises:
        ValueError: if the number of text or numeric inputs is wrong.
    """
    if len(text_inputs) != num_text_cols:
        raise ValueError(f"Expected {num_text_cols} text inputs, but got {len(text_inputs)}")
    if len(numeric_inputs) != numeric_scaler.n_features_in_:
        raise ValueError(f"Expected {numeric_scaler.n_features_in_} numeric inputs, but got {len(numeric_inputs)}")

    model.eval()

    encodings = [
        tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        for text in text_inputs
    ]
    input_ids_list = [enc['input_ids'].to(device) for enc in encodings]
    attention_mask_list = [enc['attention_mask'].to(device) for enc in encodings]

    # The scaler expects a 2-D (n_samples, n_features) array: here a single row.
    numeric_row = np.array(numeric_inputs).reshape(1, -1)
    numeric_features_tensor = torch.tensor(numeric_scaler.transform(numeric_row), dtype=torch.float32).to(device)

    with torch.no_grad():
        scaled_predictions = model(
            input_ids_list=input_ids_list,
            attention_mask_list=attention_mask_list,
            numeric_features=numeric_features_tensor,
        )

    scaled_np = scaled_predictions.cpu().numpy()
    if scaled_np.ndim == 1:
        scaled_np = scaled_np[np.newaxis, :]

    return target_scaler.inverse_transform(scaled_np).flatten()

In [None]:
# Fill in numeric inputs left unspecified (None or np.nan — pd.isna accepts both):
#   - 'Total Length (AA)' and 'Contour Length [nm]' are always derived from 'Sequence'
#     (contour length approximated as 0.35 nm per residue);
#   - 'Pulling Start' defaults to 1 and 'Pulling End' to the sequence length;
#   - the force and velocity columns default to their means over tokenized_df.
# Both dicts are then flattened into positional value lists, in key order.

sequence_length = len(text_inputs_dict['Sequence'])
numeric_inputs_dict['Total Length (AA)'] = sequence_length
numeric_inputs_dict['Contour Length [nm]'] = 0.35 * sequence_length

if pd.isna(numeric_inputs_dict['Pulling Start']):
    numeric_inputs_dict['Pulling Start'] = 1
if pd.isna(numeric_inputs_dict['Pulling End']):
    numeric_inputs_dict['Pulling End'] = sequence_length

for mean_filled_column in (
    'Highest unfolding forces/ Clamp forces [pN]',
    'Standard Deviation of force [pN]',
    'Velocity [nm/s]',
):
    if pd.isna(numeric_inputs_dict[mean_filled_column]):
        numeric_inputs_dict[mean_filled_column] = tokenized_df[mean_filled_column].mean()

text_inputs_dict_values = list(text_inputs_dict.values())
numeric_inputs_dict_values = [round(float(val), 2) for val in numeric_inputs_dict.values()]

# Explicit checks (unlike `assert`, these are not stripped under `python -O`).
if len(text_inputs_dict_values) != num_expected_text_cols:
    raise ValueError("Incorrect number of text inputs")
if len(numeric_inputs_dict_values) != total_num_features:
    raise ValueError("Incorrect number of numeric inputs")

In [None]:
# Run the prediction and keep the predicted values in dedicated variables for the
# visualizations that follow.

# Thermal energy constant used by the physics helpers; presumably kBT in pN·nm at room
# temperature — confirm.
# NOTE(review): delta_G is predicted in kBT units (see target_columns), while the helpers
# divide delta_G by this kBT value — confirm the intended units.
kBT = 4.1

try:
    predictions = predict_single_instance(text_inputs=text_inputs_dict_values, numeric_inputs=numeric_inputs_dict_values,
                                          model=model_inference, tokenizer=loaded_tokenizer, numeric_scaler=loaded_numeric_scaler,
                                          target_scaler=loaded_target_scaler, max_len=max_token_length,
                                          device=device, num_text_cols=num_expected_text_cols)
except ValueError as ve:
    print(f"Input data validation error: {ve}")
    # Re-raise: later cells need delta_G / Xu / Koff and would otherwise fail with a NameError.
    raise
except Exception as e:
    print(f"An error occurred during prediction: {e}")
    raise

print("\n - CHECK YOUR INPUTS -\n")
print("\n Text Inputs:")
for key, value in text_inputs_dict.items():
    # Long texts (e.g. the sequence) are truncated to 50 characters for display.
    if value != "<UNK>" and len(value) > 50:
        print(f"{key}: {value[:50]}...")
    else:
        print(f"{key}: {value}")
print()
print("\n Numeric Inputs:")
for key, val in zip(numeric_inputs_dict.keys(), numeric_inputs_dict_values):
    print(f"{key}: {val}")

print("\n - PREDICTIONS - \n")
for target_name, predicted_value in zip(target_columns, predictions):
    print(f"{target_name}: {predicted_value}")

delta_G = predictions[0]
Xu = predictions[1]
Koff = predictions[2]
contour_length = numeric_inputs_dict['Contour Length [nm]']

In [None]:
def calculate_G0(x, delta_G, Xu, mu=2/3, theta=0):
    """
    Free-energy profile G0(x) along the reaction coordinate.

        G0(x) = delta_G * x / (mu * Xu)
                * [1 - (1 - mu) * (2 x^2 / (Xu * (|x| + x sin(theta))))^(mu / (1 - mu))]

    mu = 2/3 and theta = 0 are the conventional defaults.

    Args:
        x: scalar or array of reaction-coordinate values.
        delta_G: barrier height parameter.
        Xu: distance to the transition state (x‡); must be positive.
        mu, theta: shape parameters of the profile.

    Returns:
        G0 evaluated at x (scalar for scalar x, array for array x).

    Raises:
        ValueError: if Xu is not positive.
    """
    if Xu <= 0:
        raise ValueError("Xu (distance x‡) must be positive.")

    x = np.asarray(x, dtype=float)

    term1 = (delta_G * x) / (mu * Xu)
    term2 = (2 * x**2) / Xu
    term3 = np.abs(x) + x * np.sin(theta)
    term4 = mu / (1 - mu)

    # At x == 0 both term2 and term3 vanish (0/0 -> NaN); the ratio tends to 0 there,
    # so G0(0) = 0. Other points are divided exactly as before.
    nonzero = x != 0
    safe_term3 = np.where(nonzero, term3, 1.0)
    ratio = np.where(nonzero, term2 / safe_term3, 0.0)

    term5 = ratio**term4
    term6 = 1 - (1 - mu) * term5

    return term1 * term6

def plot_energy_landscape(delta_G=delta_G, Xu=Xu):
    """
    Plot the one-dimensional free-energy profile G0(x) along the reaction coordinate.

    The profile is evaluated with calculate_G0 on the displayed x-range (±1.3·Xu). The
    first local minimum (G_n) and the first local maximum (G_t) found on that grid are
    highlighted. Defaults are the module-level delta_G and Xu at the time this function
    is defined.
    """
    x_limit = 1.3 * Xu
    # Sample exactly the displayed window, so the extrema inside it are always on the
    # grid regardless of the magnitude of Xu (a fixed [-2, 2] grid misses them for large Xu).
    x_vals = np.linspace(-x_limit, x_limit, 1000)
    g_vals = calculate_G0(x_vals, delta_G=delta_G, Xu=Xu)

    min_idx = argrelextrema(g_vals, np.less)[0]
    max_idx = argrelextrema(g_vals, np.greater)[0]

    # Offset the labels proportionally to delta_G (2 units when delta_G = 20) so they
    # stay inside the ±1.5·delta_G window.
    label_offset = 0.1 * abs(delta_G)

    plt.plot(x_vals, g_vals, color='black')
    plt.xlabel("Reaction Coordinate (nm)", fontsize=14)
    plt.ylabel("G₀(x)", fontsize=14)
    plt.title("Free Energy Profile", fontsize=14)
    plt.grid(False)
    plt.ylim(-1.5 * delta_G, 1.5 * delta_G)
    plt.xlim(-x_limit, x_limit)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)

    if len(min_idx) > 0:
        x_min = x_vals[min_idx[0]]
        y_min = g_vals[min_idx[0]]
        plt.plot(x_min, y_min, 'o', color='blue', alpha=0.5)
        plt.text(x_min, y_min - label_offset, "$G_n$", ha='center', va='top', fontsize=14)

    if len(max_idx) > 0:
        x_max = x_vals[max_idx[0]]
        y_max = g_vals[max_idx[0]]
        plt.plot(x_max, y_max, 'o', color='red', alpha=0.5)
        plt.text(x_max, y_max + label_offset, "$G_t$", ha='center', va='bottom', fontsize=14)

    plt.show()


In [None]:
# Draw the free-energy profile using the function's default arguments (the delta_G and Xu
# values that existed when plot_energy_landscape was defined).
plot_energy_landscape()

In [None]:
def heaviside(x):
    """Unit step: return 1 when x is strictly positive, otherwise 0 (including x == 0 and NaN)."""
    return 1 if x > 0 else 0

def compute_F_cFirst(delta_G, Xu, Koff, gamma=0.5772, kBT=kBT):
    """
    Compute Ḟ_cFirst from the Kramers' approximation:

        Ḟ_cFirst = Koff · kBT · exp(delta_G/kBT + γ) / Xu · (1 − exp(−delta_G/kBT))

    γ ≈ 0.5772 is the Euler–Mascheroni constant used in the source paper. The kBT
    default is the module-level kBT at definition time.
    """
    barrier = delta_G / kBT
    prefactor = Koff * kBT * np.exp(barrier + gamma) / Xu
    return prefactor * (1 - np.exp(-barrier))

def find_diffusion_coeff(delta_G, Xu, Koff, kBT=kBT):
    """
    Compute ζ (referred to here as the diffusion coefficient) from the known rate K0 = Koff:

        ζ = 3 · delta_G / (π · Koff · Xu²) · exp(−delta_G / kBT)

    The kBT default is the module-level kBT at definition time.
    """
    boltzmann_factor = np.exp(-delta_G / kBT)
    return (3 * delta_G) / (np.pi * Koff * Xu**2) * boltzmann_factor

def F_first(delta_G, xu, Koff, nu=2/3, kBT=kBT, F_dot=100000):
    """
    Compute F_first (from the source paper), which is used by the force-extension helpers.

    The result is the sum of two regimes selected by the Heaviside step:
      - slow loading (Ḟ_cFirst > F_dot): delta_G/(nu·xu) · [1 − |1 − (kBT/delta_G)·e^a·E1(a)|^nu],
        with a = Koff·kBT/(xu·F_dot) and E1 the exponential integral;
      - fast loading (F_dot > Ḟ_cFirst): delta_G/(nu·xu) − sqrt(2·Ḟ_cFirst·xu·ζ) + sqrt(2·F_dot·xu·ζ).

    Args:
        delta_G, xu, Koff: predicted barrier height, distance to transition state, off-rate.
        nu: shape exponent (2/3 by convention).
        kBT: thermal energy; forwarded to the helper functions so all terms use the same value.
        F_dot: loading rate (default 10^5, pN/s per the original notes).

    NOTE(review): delta_G is divided by kBT here — confirm delta_G and kBT share units.
    """
    # All terms use the `xu` argument (not the module-level `Xu`).
    F_cFirst = compute_F_cFirst(delta_G, xu, Koff, kBT=kBT)
    zeta = find_diffusion_coeff(delta_G, xu, Koff, kBT=kBT)

    term_1_prefactor = delta_G / (nu * xu)
    e1_arg = (Koff * kBT) / (xu * F_dot)
    E1 = exp1(e1_arg)
    partial_term = (kBT / delta_G) * np.exp(e1_arg)

    inner_term_1 = 1 - (np.abs(1 - (partial_term * E1))**nu)
    term_1 = term_1_prefactor * inner_term_1 * heaviside(F_cFirst - F_dot)

    sqrt_2F_first = np.sqrt(2 * F_cFirst * xu * zeta)
    sqrt_2F_dot = np.sqrt(2 * F_dot * xu * zeta)
    inner_term_2 = term_1_prefactor - sqrt_2F_first + sqrt_2F_dot
    term_2 = inner_term_2 * heaviside(F_dot - F_cFirst)

    return term_1 + term_2

def x_of_f(f, L):
    """
    Extension of a polymer at reduced force `f`, from a semi-empirical force-extension
    relation, scaled by the contour length `L`:

        x = L · [4/3 − 4/(3·sqrt(f+1)) − 10·e^(900/f)^(1/4) / (sqrt(f)·(e^(900/f)^(1/4) − 1)²)
                 + f^1.62 / (3.55 + 3.8·f^2.2)]

    Returns NaN for non-positive force, or if the evaluation raises.
    """
    if f <= 0:
        return np.nan
    try:
        stiff = np.exp(np.power(900 / f, 0.25))
        reduced_extension = (
            4 / 3
            - 4 / (3 * np.sqrt(f + 1))
            - (10 * stiff) / (np.sqrt(f) * (stiff - 1)**2)
            + f**1.62 / (3.55 + 3.8 * f**2.2)
        )
        return reduced_extension * L
    except Exception:
        return np.nan

def find_force_extension_parameters(delta_G=delta_G, Xu=Xu, Koff=Koff, L=contour_length, kBT=kBT, P=0.4):
    """
    Return the polymer extension reached at the force computed by F_first.

    The F_first force is converted to the reduced (dimensionless) force f = F·P/kBT
    expected by x_of_f. P is the persistence length in nanometers (0.4 by convention)
    and L is the contour length taken from the input data. Defaults are the
    module-level values at the time this function is defined.
    """
    # Forward kBT so the force and its reduced form use the same thermal energy.
    F_final_value = F_first(delta_G, Xu, Koff, kBT=kBT)
    f_target = (F_final_value * P) / kBT
    return x_of_f(f_target, L)

def inverse_force_extension_formula(x_target, delta_G=delta_G, Xu=Xu, Koff=Koff, L=contour_length, kBT=kBT, P=0.4):
    """
    Find the force at which the polymer reaches extension `x_target`.

    Numerically solves x_of_f(f, L) = x_target for the reduced (dimensionless) force f,
    then converts it back to a force with F = f·kBT/P. This is the inverse of the
    F·P/kBT conversion used in find_force_extension_parameters, so the result is in the
    same units as F_first (presumably pN, matching the plot's axis label — confirm).

    delta_G, Xu and Koff are accepted for backward compatibility but are not needed.
    Returns NaN if the solver raises.
    """
    def equation(f):
        return x_of_f(f, L) - x_target

    f_guess = 0.1  # initial guess in reduced-force units
    try:
        f_solution = fsolve(equation, f_guess, xtol=1e-6)[0]
    except Exception:
        return np.nan
    return f_solution * kBT / P

def plot_force_extension_curve():
    """
    Plot a synthetic force-extension trace built from the theoretical model.

    Extensions run from 0 up to the value returned by find_force_extension_parameters();
    the force at each point comes from inverse_force_extension_formula(). Gaussian noise
    (std 0.2) is added to the forces. After the last point the trace drops vertically to
    zero force and continues as a noisy zero baseline for another 5 nm.
    """
    noise_std = 0.2
    extension_end = find_force_extension_parameters()
    extensions = np.linspace(0, extension_end, 200)
    forces = np.asarray([inverse_force_extension_formula(ext) for ext in extensions])
    noisy_forces = forces + np.random.normal(0, noise_std, size=len(forces))

    plt.plot(extensions, noisy_forces, label='f(x)', color='black', linewidth=1.4)

    end_x, end_y = extensions[-1], noisy_forces[-1]
    # Vertical drop from the last point back to zero force.
    plt.plot([end_x, end_x], [end_y, 0], color='black', linewidth=1.4)
    # Noisy zero-force baseline after the drop.
    baseline_x = np.linspace(end_x, end_x + 5, 50)
    baseline_y = np.random.normal(0, noise_std, size=baseline_x.shape)
    plt.plot(baseline_x, baseline_y, color='black', linewidth=1.4)

    plt.axhline(y=0, color='red', linestyle='--')
    plt.title('Force-Extension Curve', fontsize=14)
    plt.xlabel('Extension (nm)', fontsize=14)
    plt.ylabel('Force (pN)', fontsize=14)
    plt.grid(False)
    plt.show()

In [None]:
# Draw the theoretical force-extension curve; the helpers it calls use their default
# arguments (the predicted delta_G, Xu, Koff and contour length).
plot_force_extension_curve()

# Comparison of results (3 Neural Network Models)

In [None]:
# Load the per-epoch validation metrics written by each model's training run.
# ESM2 metrics are loaded for both databases; BERT and BiLSTM metrics only for MechanoProDB.

artifacts_path = f"{main_path}Models_Artifacts_{database}/"
main_dir = f"_{df_state}_text_50_epochs_Mean_Pooling/metrics_df_"

metrics_esm = pd.read_csv(f"{artifacts_path}ESM2{main_dir}esm.tsv", sep='\t')
if mechano:
    metrics_bert = pd.read_csv(f"{artifacts_path}BERT{main_dir}bert.tsv", sep='\t')
    metrics_lstm = pd.read_csv(f"{artifacts_path}MultiInputBiLSTM{main_dir}lstm.tsv", sep='\t')

In [None]:
# The BiLSTM metrics stopped changing after the first epochs, so to save time it was trained
# for fewer epochs. Its metrics are padded up to the common epoch count by repeating the
# last recorded epoch's row.

num_plot_epochs = 50  # epoch count of the "_50_epochs_" training runs

if mechano:
    num_missing = max(num_plot_epochs - len(metrics_lstm), 0)
    if num_missing:
        padding_rows = pd.concat([metrics_lstm.iloc[[-1]]] * num_missing, ignore_index=True)
        if 'epoch' in metrics_lstm.columns:
            # Continue the numbering after the last recorded epoch.
            next_epoch = int(metrics_lstm['epoch'].iloc[-1]) + 1
            padding_rows['epoch'] = range(next_epoch, next_epoch + num_missing)
        metrics_lstm_extended = pd.concat([metrics_lstm, padding_rows], ignore_index=True)
    else:
        # Already at (or beyond) the target length: nothing to pad (pd.concat([]) would raise).
        metrics_lstm_extended = metrics_lstm.copy()

    epochs = metrics_lstm_extended.index if 'epoch' not in metrics_lstm_extended.columns else metrics_lstm_extended['epoch']
else:
    epochs = metrics_esm.index if 'epoch' not in metrics_esm.columns else metrics_esm['epoch']

In [None]:
# Plot validation metrics per epoch for each model: first per target variable, then
# averaged over targets (works for both MechanoProDB and ProThermDB).
# BERT and BiLSTM metrics are loaded only for MechanoProDB, so their curves are added
# only when `mechano` is set.

if mechano:
    target_names = ["ΔG", "Xu", "Koff"]
else:
    target_names = ["Tm_(C)"]

metrics_names = ['MSE', 'RMSE', 'R^2', 'Loss', 'RRMSE']
metrics_names_small = ['mse', 'rmse', 'r2', 'loss', 'rrmse']


def _plot_model_curves(column, title, ylabel):
    """Draw one figure comparing metric column `column` across the loaded models."""
    plt.figure(figsize=(9, 5))
    plt.plot(epochs, metrics_esm[column], label='ESM2', color='red')
    if mechano:
        plt.plot(epochs, metrics_bert[column], label='BERT', color='blue')
        plt.plot(epochs, metrics_lstm_extended[column], label='BiLSTM', color='green')
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True, alpha=0.5)
    plt.tight_layout()
    plt.show()


# Per-target curves (the loss is only plotted as an average below).
for target_idx, target_name in enumerate(target_names):
    for display_name, short_name in zip(metrics_names, metrics_names_small):
        if short_name == 'loss':
            continue
        _plot_model_curves(f'val_{short_name}_{target_idx}', f"Validation Scores for {target_name} ", display_name)

# Averaged curves: the loss column is 'val_loss'; the other metrics carry an '_avg' suffix.
for display_name, short_name in zip(metrics_names, metrics_names_small):
    suffix = "" if short_name == 'loss' else "_avg"
    _plot_model_curves(f'val_{short_name}{suffix}', "Average Validation Scores", display_name)

***

# End of Notebook