In [1]:
import os
import sys
import os.path
from sys import platform
from pathlib import Path
import re
import sys
import time
import copy
import math
import scipy
import torch
import pickle
import random
import argparse
import subprocess
import sklearn
import numpy as np
import pandas as pd
from torch import nn
from torch.utils import data
from torch.nn.utils.weight_norm import weight_norm
from sklearn import datasets
from sklearn import preprocessing
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn import metrics
from sklearn import preprocessing
from sklearn.metrics import auc
from sklearn.metrics import r2_score
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.preprocessing import StandardScaler
from typing import Union, List, Tuple, Sequence, Dict, Any, Optional, Collection
import seaborn as sns
from scipy import stats
import matplotlib.pyplot as plt
from pathlib import Path
from copy import deepcopy
from ipywidgets import IntProgress
from datetime import datetime
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

In [2]:
from google.colab import drive
drive.mount('/content/gdrive/')

Mounted at /content/gdrive/


In [3]:
# import sys
# sys.path.append('/content/gdrive/MyDrive/function_predictor/code/')

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
class StandardScaler:
    """A :class:`StandardScaler` normalizes the features of a dataset.

    When it is fit on a dataset, the :class:`StandardScaler` learns the mean and standard deviation across the 0th axis.
    When transforming a dataset, the :class:`StandardScaler` subtracts the means and divides by the standard deviations.
    """

    def __init__(self, means: np.ndarray = None, stds: np.ndarray = None, replace_nan_token: Any = None):
        """
        :param means: An optional 1D numpy array of precomputed means.
        :param stds: An optional 1D numpy array of precomputed standard deviations.
        :param replace_nan_token: A token to use to replace NaN entries in the features.
        """
        self.means = means
        self.stds = stds
        self.replace_nan_token = replace_nan_token

    def fit(self, X: List[List[Optional[float]]]) -> 'StandardScaler':
        """
        Learns means and standard deviations across the 0th axis of the data :code:`X`.

        :param X: A list of lists of floats (or None).
        :return: The fitted :class:`StandardScaler` (self).
        """
        X = np.array(X).astype(float)
        self.means = np.nanmean(X, axis=0)
        self.stds = np.nanstd(X, axis=0)
        self.means = np.where(np.isnan(self.means), np.zeros(self.means.shape), self.means)
        self.stds = np.where(np.isnan(self.stds), np.ones(self.stds.shape), self.stds)
        self.stds = np.where(self.stds == 0, np.ones(self.stds.shape), self.stds)

        return self

    def transform(self, X: List[List[Optional[float]]]) -> np.ndarray:
        """
        Transforms the data by subtracting the means and dividing by the standard deviations.

        :param X: A list of lists of floats (or None).
        :return: The transformed data with NaNs replaced by :code:`self.replace_nan_token`.
        """
        X = np.array(X).astype(float)
        transformed_with_nan = (X - self.means) / self.stds
        transformed_with_none = np.where(np.isnan(transformed_with_nan), self.replace_nan_token, transformed_with_nan)

        return transformed_with_none.astype(np.float32)

    def inverse_transform(self, X: List[List[Optional[float]]]) -> np.ndarray:
        """
        Performs the inverse transformation by multiplying by the standard deviations and adding the means.

        :param X: A list of lists of floats.
        :return: The inverse transformed data with NaNs replaced by :code:`self.replace_nan_token`.
        """
        X = np.array(X).astype(float)
        transformed_with_nan = X * self.stds + self.means
        transformed_with_none = np.where(np.isnan(transformed_with_nan), self.replace_nan_token, transformed_with_nan)

        return transformed_with_none.astype('float32')


def normalize_targets(y_data) -> StandardScaler:
    # For Future Use.
    """
    Normalizes the targets of the dataset using a :class:`~chemprop.data.StandardScaler`.

    The :class:`~chemprop.data.StandardScaler` subtracts the mean and divides by the standard deviation
    for each task independently.

    This should only be used for regression datasets.

    :return: A :class:`~chemprop.data.StandardScaler` fitted to the targets.
    """
    scaler = StandardScaler().fit(y_data)
    scaled_targets = scaler.transform(y_data).tolist()

    return scaled_targets, scaler

In [6]:
seeds = [0,1,2,42,1234]
seed=seeds[0]
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [7]:
root_path = '/content/gdrive/MyDrive/function_predictor/GB1-Dataset-FewToMore'

In [8]:
embedding_file = root_path+'/'+'emb_residue_level_embedding_GB1_embedding_ESM_2_650_embeddings_tensor.pt'
embeddings = torch.load(embedding_file)

In [9]:
splits = ['low_vs_high', 'one_vs_rest', 'two_vs_rest', 'three_vs_rest']
split_name = splits[-1]
fasta_file = root_path+'/'+split_name+'.fasta'

In [10]:
# [New] Parse FASTA file and create datasets
sequences = {}
with open(fasta_file, 'r') as file:
    for line in file:
        if line.startswith('>'):
            name, target, set_info, _ = line.strip().split(' ')
            name = name[1:]
            target = float(target.split('=')[1])
            set_type = set_info.split('=')[1]
            sequences[name] = {'target': target, 'set': set_type}

In [11]:
train_data = []
test_data = []
for name, info in sequences.items():
    embedding = embeddings[f'{name}'].numpy()
    if info['set'] == 'train':
        train_data.append((embedding, info['target']))
    elif info['set'] == 'test':
        test_data.append((embedding, info['target']))

In [12]:
from sklearn.preprocessing import StandardScaler as SklearnStandardScaler
from torch.nn.utils.rnn import pad_sequence

class LSTM_dataset(Dataset):
    def __init__(self, data, max_len, X_scaler=None):
        super().__init__()
        self.data = data
        self.max_len = max_len
        self.X_scaler = X_scaler

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        embedding, target = self.data[idx]
        seq_len = len(embedding)

        # Pad on-the-fly
        padded_embedding = np.zeros((self.max_len, embedding.shape[1]), dtype=np.float32)
        padded_embedding[:seq_len, :] = embedding

        return {'embedding': padded_embedding, 'seq_len': seq_len, 'target': np.array(target, dtype=np.float32)}

    def collate_fn(self, batch):
        embeddings = [item['embedding'] for item in batch]
        seq_lens = [item['seq_len'] for item in batch]
        targets = [item['target'] for item in batch]
        # Convert to numpy arrays
        embeddings_array = np.stack(embeddings)
        seq_lens_array = np.array(seq_lens, dtype=np.int64)
        targets_array = np.stack(targets)

        # Reshape for scaling: (batch_size * max_len, emb_dim)
        flat_embeddings = embeddings_array.reshape(-1, embeddings_array.shape[-1])
        if self.X_scaler is not None:
            scaled_flat_embeddings = self.X_scaler.transform(flat_embeddings)
        else:
            scaled_flat_embeddings = flat_embeddings
        scaled_embeddings = scaled_flat_embeddings.reshape(embeddings_array.shape)

        # Convert to tensors
        embeddings_tensor = torch.from_numpy(scaled_embeddings)
        seq_lens_tensor = torch.from_numpy(seq_lens_array)
        targets_tensor = torch.from_numpy(targets_array)

        return {
            'embeddings': embeddings_tensor,
            'seq_lens': seq_lens_tensor,
            'targets': targets_tensor
        }
def generate_LSTM_loader(train_data, test_data, max_len, batch_size, scaler: Optional[StandardScaler] = None):
    # Fit the scaler on all embeddings from the training data
    flat_train_embeddings = np.concatenate([item[0].reshape(-1, 1280) for item in train_data], axis=0)

    if scaler is None:
        scaler = SklearnStandardScaler()
    scaler.fit(flat_train_embeddings)

    train_dataset = LSTM_dataset(train_data, max_len, scaler)
    test_dataset = LSTM_dataset(test_data, max_len, scaler)

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=train_dataset.collate_fn)
    test_loader = DataLoader(test_dataset, batch_size, shuffle=False, collate_fn=test_dataset.collate_fn)

    return train_loader, test_loader

In [13]:
y_train_normalized, target_scaler = normalize_targets([item[1] for item in train_data])
y_test_normalized = target_scaler.transform([item[1] for item in test_data])

# Update train and test data with normalized targets
train_data_normalized = [(item[0], y_train_normalized[i]) for i, item in enumerate(train_data)]
test_data_normalized = [(item[0], y_test_normalized[i]) for i, item in enumerate(test_data)]

In [14]:
# Example usage
max_len = 265  # Set your sequence max length here
scaler = StandardScaler()  # Initialize your custom scaler
# train_loader, test_loader = generate_LSTM_loader(train_data_normalized, test_data_normalized, max_len, 256, scaler)
train_loader, test_loader = generate_LSTM_loader(train_data, test_data, max_len, 256, scaler)

In [15]:
# for one_seq_ppt_group in train_loader:
#     seq_rep, seq_lens, target = one_seq_ppt_group["embeddings"], one_seq_ppt_group["seq_lens"], one_seq_ppt_group["targets"]
#     break

In [16]:
class LSTM(nn.Module):
    def __init__(self,
                 in_dim: int,
                 hid_dim: int,
                 latent_dim: int,
                 out_dim: int,
                 num_layers: int,
                 max_len: int,
                 last_hid: int,
                 dropout: float = 0.
                 ):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_layers = num_layers
        self.hid_dim = hid_dim
        #--------------------------------------------------#
        self.encoder_LSTM = nn.LSTM(in_dim, hid_dim, batch_first=True, num_layers=num_layers)
        self.mean = nn.Linear(in_features=hid_dim*num_layers, out_features=latent_dim)
        self.log_variance = nn.Linear(in_features=hid_dim*num_layers, out_features = latent_dim)
        self.dropout = nn.Dropout(dropout, inplace=True)
        #--------------------------------------------------#
        self.fc_1 = nn.Linear(int(self.latent_dim),last_hid)
        self.fc_2 = nn.Linear(last_hid,last_hid)
        self.fc_3 = nn.Linear(last_hid,1)
        self.cls = nn.Sigmoid()

    def encoder(self, x_inputs, padding_length, hidden_encoder):
        # Pad the packed input (already done when input into the NN)
        packed_output_encoder, hidden_encoder = self.encoder_LSTM(x_inputs, hidden_encoder)
        output_encoder, _ = nn.utils.rnn.pad_packed_sequence(packed_output_encoder, batch_first=True, total_length=padding_length)
        output_encoder.cuda()

        # Estimate the mean and the variance
        mean = self.mean(hidden_encoder[0]).cuda()
        log_var = self.log_variance(hidden_encoder[0])
        std = torch.exp(0.5*log_var).cuda()

        output_encoder = output_encoder.contiguous().cuda()

        # Generating unit Gaussian noise
        batch_size = output_encoder.shape[0]
        seq_len = output_encoder.shape[1]
        noise = torch.randn(batch_size, self.latent_dim).cuda()

        z = noise*std + mean
        #print("Z DIMENSION:", z.shape)
        return z, mean, log_var, hidden_encoder

    def initial_hidden_vars(self, batch_size):
        hidden_cell = torch.zeros(self.num_layers, batch_size, self.hid_dim).float().cuda()
        state_cell = torch.zeros(self.num_layers, batch_size, self.hid_dim).float().cuda()

        return (hidden_cell, state_cell)


    def forward(self, x, lengths, hidden_encoder):
        max_length = x.shape[1]
        #hidden_cell = torch.zeros(self.num_layers, x.shape[0], self.hid_dim)
        #state_cell = torch.zeros(self.num_layers, x.shape[0], self.hid_dim)
        lengths = lengths.cpu()
        x = nn.utils.rnn.pack_padded_sequence(input=x, lengths=lengths, batch_first=True, enforce_sorted=False)
        z, mean, log_var, hidden_encoder = self.encoder(x, max_length, hidden_encoder)
        #print("Z DIMENSION:", z.shape)
        #--------------------------------------------------#
        #output = nn.functional.relu(z)
        #output = self.dropout1(output)
        #--------------------------------------------------#
        output = self.fc_1(z)
        output = nn.functional.relu(output)
        output = self.fc_2(output)
        output = nn.functional.relu(output)
        output = self.fc_3(output)
        return output

In [17]:
hid_dim    = 256    # 256
latent_dim   = 1024      # 5
out_dim    = 1      # 2
num_layers   = 1      # 3
last_hid   = 1280  # 1024
dropout    = 0.0     # 0
seqs_max_len = 265
epoch_num      = 100
batch_size     = 256
learning_rate  =  [0.01        , # 0
                   0.005       , # 1
                   0.002       , # 2
                   0.001       , # 3
                   0.0005      , # 4
                   0.0002      , # 5
                   0.0001      , # 6
                   0.00005     , # 7
                   0.00002     , # 8
                   0.00001     , # 9
                   0.000005    , # 10
                   0.000002    , # 11
                   0.000001    , # 12
                   ][6]

In [18]:
model = LSTM(
            in_dim    =  1280         ,
            hid_dim   =  hid_dim      ,
            latent_dim  =  latent_dim ,
            out_dim   =  out_dim      ,
            num_layers  =  num_layers ,
            max_len   =  seqs_max_len ,
            last_hid  =  last_hid     ,
            dropout   =  dropout      ,
            )

model.float()
model.cuda()
#--------------------------------------------------#
print("#"*50)
print(model)
#model.float()
#print( summary( model,[(seqs_max_len, NN_input_dim),] )  )
#model.float()
print("#"*50)
#--------------------------------------------------#
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Optional: Define a learning rate scheduler
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 50], gamma=0.5)

# Loss function
criterion = nn.MSELoss()

##################################################
LSTM(
  (encoder_LSTM): LSTM(1280, 256, batch_first=True)
  (mean): Linear(in_features=256, out_features=1024, bias=True)
  (log_variance): Linear(in_features=256, out_features=1024, bias=True)
  (dropout): Dropout(p=0.0, inplace=True)
  (fc_1): Linear(in_features=1024, out_features=1280, bias=True)
  (fc_2): Linear(in_features=1280, out_features=1280, bias=True)
  (fc_3): Linear(in_features=1280, out_features=1, bias=True)
  (cls): Sigmoid()
)
##################################################


In [19]:
model.to(device)

LSTM(
  (encoder_LSTM): LSTM(1280, 256, batch_first=True)
  (mean): Linear(in_features=256, out_features=1024, bias=True)
  (log_variance): Linear(in_features=256, out_features=1024, bias=True)
  (dropout): Dropout(p=0.0, inplace=True)
  (fc_1): Linear(in_features=1024, out_features=1280, bias=True)
  (fc_2): Linear(in_features=1280, out_features=1280, bias=True)
  (fc_3): Linear(in_features=1280, out_features=1, bias=True)
  (cls): Sigmoid()
)

In [20]:
def save_plot(y_pred, y_real, epoch, folder_path, file_name, fig_size=(12, 10), marker_size=10, fit_line_color="skyblue", distn_color_1="lightgreen", distn_color_2="salmon"):
    sns.set(style="whitegrid")

    # Create a jointplot
    g = sns.jointplot(
        x=y_real,
        y=y_pred,
        kind="reg",
        height=fig_size[1] - 1,
        color=fit_line_color,
        scatter_kws={"s": marker_size},
        marginal_kws={'color': distn_color_1}
    )

    # Set axis labels
    g.ax_joint.set_xlabel("Actual Values", fontsize=12)
    g.ax_joint.set_ylabel("Predictions", fontsize=12)

    # Set the title with a bit more space at the top
    g.fig.suptitle(f"Predictions vs. Actual Values\n R = {np.round(stats.pearsonr(y_pred, y_real)[0], 3)} | Epoch: {epoch}", fontsize=14, y=1.05)

    # Plot the histograms
    sns.histplot(y_real, color=distn_color_1, alpha=0.6, ax=g.ax_marg_x, fill=True, kde=True)
    sns.histplot(y=y_pred, color=distn_color_2, alpha=0.6, ax=g.ax_marg_y, fill=True, kde=True)

    # Adjust the plot margins and layout
    plt.subplots_adjust(left=0.15, right=0.85, top=0.85, bottom=0.15)
    g.fig.tight_layout()

    # Save the plot
    plt.savefig(Path(folder_path) / file_name, bbox_inches='tight')
    plt.close(g.fig)

In [21]:
# Training
model_name = 'lstm'
save_path = f'{root_path}/{model_name}_results/split_name_{split_name}/seed_{seed}'
test_interval = 5
best_r_score = -1
training_losses = []  # To store training losses for plotting

# Variables to keep track of the best file names
best_model_file = ""
best_plot_file = ""
best_metrics_file = ""
for epoch in range(epoch_num):
    begin_time = time.time()

    # Training Phase
    model.train()
    total_loss = 0
    for one_seq_ppt_group in train_loader:
        len_train_loader=batch_size
        seq_rep, seq_lens, target = one_seq_ppt_group["embeddings"], one_seq_ppt_group["seq_lens"], one_seq_ppt_group["targets"]
        seq_rep, seq_lens, target = seq_rep.float().to(device), seq_lens.to(device), target.float().to(device)
        states = model.initial_hidden_vars(len_train_loader)
        input_vars = [seq_rep, seq_lens, states]
        output = model(*input_vars)
        loss = criterion(torch.squeeze(output),torch.squeeze(target))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    training_losses.append(avg_loss)

    # Test Phase
    if epoch % test_interval == 0 or epoch == epoch_num - 1:
        model.eval()
        y_pred_test = []
        y_real_test = []
        for one_seq_ppt_group in test_loader:
            len_val_loader = batch_size
            seq_rep, seq_lens, target = one_seq_ppt_group["embeddings"], one_seq_ppt_group["seq_lens"], one_seq_ppt_group["targets"]
            seq_rep, seq_lens, target = seq_rep.float().to(device), seq_lens.to(device), target.float().to(device)
            states = model.initial_hidden_vars(len_val_loader)
            input_vars = [seq_rep, seq_lens, states]
            output = model(*input_vars)
            output = output.cpu().detach().numpy().reshape(-1)
            target = target.cpu().numpy()

            y_pred_test.extend(output)
            y_real_test.extend(target)

        _, _, r_value_test, _, _ = scipy.stats.linregress(y_pred_test, y_real_test)

        # Save best model
        if r_value_test >= best_r_score:
            best_r_score = r_value_test
            os.makedirs(save_path, exist_ok=True)
            if best_model_file:
                os.remove(best_model_file)
                os.remove(best_plot_file)
                os.remove(best_metrics_file)

            # Update file names
            best_model_file = f'{save_path}/best_model_epoch_{epoch}_seed_{seed}.pt'
            best_plot_file = f'{save_path}/test_plot_epoch_{epoch}.png'
            best_metrics_file = f'{save_path}/metrics_epoch_{epoch}.txt'
            torch.save(model.state_dict(), best_model_file)
            save_plot(y_pred_test, y_real_test, epoch, save_path, best_plot_file.split('/')[-1])
            with open(best_metrics_file, 'w') as f:
                f.write(f'Epoch: {epoch}\n')
                f.write(f'Test R-value: {r_value_test}\n')
                f.write(f'Test MAE: {mean_absolute_error(y_pred_test, y_real_test)}\n')
                f.write(f'Test MSE: {mean_squared_error(y_pred_test, y_real_test)}\n')
                f.write(f'Test RMSE: {np.sqrt(mean_squared_error(y_pred_test, y_real_test))}\n')
                f.write(f'Test R2: {r2_score(y_real_test, y_pred_test)}\n')
                f.write(f'Test Spearman rho: {scipy.stats.spearmanr(y_pred_test, y_real_test)[0]}\n')
                # Additional metrics can be added here
            print(f'Best R-value: {best_r_score:.4f}')

    print(f'Epoch: {epoch} | Training Loss: {avg_loss:.4f}')
    # scheduler.step()

Best R-value: 0.0000
Epoch: 0 | Training Loss: 1.5116
Epoch: 1 | Training Loss: 1.2417
Epoch: 2 | Training Loss: 1.2156
Epoch: 3 | Training Loss: 1.2018
Epoch: 4 | Training Loss: 1.2140
Best R-value: 0.0178
Epoch: 5 | Training Loss: 1.1876
Epoch: 6 | Training Loss: 1.2042
Epoch: 7 | Training Loss: 1.2072
Epoch: 8 | Training Loss: 1.2039
Epoch: 9 | Training Loss: 1.2056
Best R-value: 0.0651
Epoch: 10 | Training Loss: 1.2088
Epoch: 11 | Training Loss: 1.2054
Epoch: 12 | Training Loss: 1.1899
Epoch: 13 | Training Loss: 1.1866
Epoch: 14 | Training Loss: 1.1769
Best R-value: 0.1898
Epoch: 15 | Training Loss: 1.1839
Epoch: 16 | Training Loss: 1.1690
Epoch: 17 | Training Loss: 1.1454
Epoch: 18 | Training Loss: 1.1812
Epoch: 19 | Training Loss: 1.1406
Best R-value: 0.3191
Epoch: 20 | Training Loss: 1.1261
Epoch: 21 | Training Loss: 1.1463
Epoch: 22 | Training Loss: 1.1336
Epoch: 23 | Training Loss: 1.1137
Epoch: 24 | Training Loss: 1.1031
Epoch: 25 | Training Loss: 1.1348
Epoch: 26 | Training 

In [22]:
# Save the training plot
plt.figure()
plt.plot(training_losses)
plt.xlabel('Epoch')
plt.ylabel('Training Loss')
plt.title('Training Loss Over Epochs')
plt.savefig(f'{save_path}/training_loss_plot_seed_{seed}.png')
plt.close()

print(f'Training Completed. Best Model Saved with R-score: {best_r_score}')

Training Completed. Best Model Saved with R-score: 0.5488440681193696
