In [1]:
import os
import sys
import os.path
from sys import platform
from pathlib import Path
import re
import sys
import time
import copy
import math
import scipy
import torch
import pickle
import random
import argparse
import subprocess
import sklearn
import numpy as np
import pandas as pd
from torch import nn
from torch.utils import data
from torch.nn.utils.weight_norm import weight_norm
from sklearn import datasets
from sklearn import preprocessing
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn import metrics
from sklearn import preprocessing
from sklearn.metrics import auc
from sklearn.metrics import r2_score
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.preprocessing import StandardScaler
from typing import Union, List, Tuple, Sequence, Dict, Any, Optional, Collection
import seaborn as sns
from scipy import stats
import matplotlib.pyplot as plt
from pathlib import Path
from copy import deepcopy
from ipywidgets import IntProgress
from datetime import datetime
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

In [2]:
from google.colab import drive
drive.mount('/content/gdrive/')

Mounted at /content/gdrive/


In [3]:
# import sys
# sys.path.append('/content/gdrive/MyDrive/function_predictor/code/')

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
class StandardScaler:
    """A :class:`StandardScaler` normalizes the features of a dataset.

    When it is fit on a dataset, the :class:`StandardScaler` learns the mean and standard deviation across the 0th axis.
    When transforming a dataset, the :class:`StandardScaler` subtracts the means and divides by the standard deviations.
    """

    def __init__(self, means: np.ndarray = None, stds: np.ndarray = None, replace_nan_token: Any = None):
        """
        :param means: An optional 1D numpy array of precomputed means.
        :param stds: An optional 1D numpy array of precomputed standard deviations.
        :param replace_nan_token: A token to use to replace NaN entries in the features.
        """
        self.means = means
        self.stds = stds
        self.replace_nan_token = replace_nan_token

    def fit(self, X: List[List[Optional[float]]]) -> 'StandardScaler':
        """
        Learns means and standard deviations across the 0th axis of the data :code:`X`.

        :param X: A list of lists of floats (or None).
        :return: The fitted :class:`StandardScaler` (self).
        """
        X = np.array(X).astype(float)
        self.means = np.nanmean(X, axis=0)
        self.stds = np.nanstd(X, axis=0)
        self.means = np.where(np.isnan(self.means), np.zeros(self.means.shape), self.means)
        self.stds = np.where(np.isnan(self.stds), np.ones(self.stds.shape), self.stds)
        self.stds = np.where(self.stds == 0, np.ones(self.stds.shape), self.stds)

        return self

    def transform(self, X: List[List[Optional[float]]]) -> np.ndarray:
        """
        Transforms the data by subtracting the means and dividing by the standard deviations.

        :param X: A list of lists of floats (or None).
        :return: The transformed data with NaNs replaced by :code:`self.replace_nan_token`.
        """
        X = np.array(X).astype(float)
        transformed_with_nan = (X - self.means) / self.stds
        transformed_with_none = np.where(np.isnan(transformed_with_nan), self.replace_nan_token, transformed_with_nan)

        return transformed_with_none.astype(np.float32)

    def inverse_transform(self, X: List[List[Optional[float]]]) -> np.ndarray:
        """
        Performs the inverse transformation by multiplying by the standard deviations and adding the means.

        :param X: A list of lists of floats.
        :return: The inverse transformed data with NaNs replaced by :code:`self.replace_nan_token`.
        """
        X = np.array(X).astype(float)
        transformed_with_nan = X * self.stds + self.means
        transformed_with_none = np.where(np.isnan(transformed_with_nan), self.replace_nan_token, transformed_with_nan)

        return transformed_with_none.astype('float32')


def normalize_targets(y_data) -> StandardScaler:
    # For Future Use.
    """
    Normalizes the targets of the dataset using a :class:`~chemprop.data.StandardScaler`.

    The :class:`~chemprop.data.StandardScaler` subtracts the mean and divides by the standard deviation
    for each task independently.

    This should only be used for regression datasets.

    :return: A :class:`~chemprop.data.StandardScaler` fitted to the targets.
    """
    scaler = StandardScaler().fit(y_data)
    scaled_targets = scaler.transform(y_data).tolist()

    return scaled_targets, scaler

In [6]:
seeds = [0,1,2,42,1234]
seed=seeds[1]
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [7]:
root_path = '/content/gdrive/MyDrive/function_predictor/GB1-Dataset-FewToMore'

In [8]:
embedding_file = root_path+'/'+'emb_residue_level_embedding_GB1_embedding_ESM_2_650_embeddings_tensor.pt'
embeddings = torch.load(embedding_file)

In [9]:
splits = ['low_vs_high', 'one_vs_rest', 'two_vs_rest', 'three_vs_rest']
split_name = splits[-1]
fasta_file = root_path+'/'+split_name+'.fasta'

In [10]:
# [New] Parse FASTA file and create datasets
sequences = {}
with open(fasta_file, 'r') as file:
    for line in file:
        if line.startswith('>'):
            name, target, set_info, _ = line.strip().split(' ')
            name = name[1:]
            target = float(target.split('=')[1])
            set_type = set_info.split('=')[1]
            sequences[name] = {'target': target, 'set': set_type}

In [11]:
train_data = []
test_data = []
for name, info in sequences.items():
    embedding = embeddings[f'{name}'].numpy()
    if info['set'] == 'train':
        train_data.append((embedding, info['target']))
    elif info['set'] == 'test':
        test_data.append((embedding, info['target']))

In [12]:
from sklearn.preprocessing import StandardScaler as SklearnStandardScaler

class ATT_dataset(Dataset):
    def __init__(self, data, max_len, X_scaler=None):
        super().__init__()
        self.data = data
        self.max_len = max_len
        self.X_scaler = X_scaler

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        embedding, target = self.data[idx]
        seq_len = len(embedding)

        # Pad on-the-fly
        padded_embedding = np.zeros((self.max_len, embedding.shape[1]), dtype=np.float32)
        padded_embedding[:seq_len, :] = embedding

        # Generate mask
        mask = np.zeros(self.max_len, dtype=np.float32)
        mask[:seq_len] = 1

        return {'embedding': padded_embedding, 'mask': mask, 'target': np.array(target, dtype=np.float32)}

    def collate_fn(self, batch):
        embeddings = [item['embedding'] for item in batch]
        masks = [item['mask'] for item in batch]
        targets = [item['target'] for item in batch]

        # Convert to numpy arrays
        embeddings_array = np.stack(embeddings)
        masks_array = np.stack(masks)
        targets_array = np.stack(targets)

        # Reshape for scaling: (batch_size * max_len, emb_dim)
        flat_embeddings = embeddings_array.reshape(-1, embeddings_array.shape[-1])
        if self.X_scaler is not None:
            scaled_flat_embeddings = self.X_scaler.transform(flat_embeddings)
        else:
            scaled_flat_embeddings = flat_embeddings
        scaled_embeddings = scaled_flat_embeddings.reshape(embeddings_array.shape)

        # Convert to tensors
        embeddings_tensor = torch.from_numpy(scaled_embeddings)
        masks_tensor = torch.from_numpy(masks_array)
        targets_tensor = torch.from_numpy(targets_array)

        return {
            'embeddings': embeddings_tensor,
            'masks': masks_tensor,
            'targets': targets_tensor
        }

def generate_ATT_loader(train_data, test_data, max_len, batch_size, scaler: Optional[StandardScaler] = None):
    # Flatten all embeddings from train data for scaler fitting
    flat_train_embeddings = np.concatenate([item[0].reshape(-1, 1280) for item in train_data], axis=0)

    if scaler is None:
        scaler = SklearnStandardScaler()
    scaler.fit(flat_train_embeddings)

    train_dataset = ATT_dataset(train_data, max_len, scaler)
    test_dataset = ATT_dataset(test_data, max_len, scaler)

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=train_dataset.collate_fn)
    test_loader = DataLoader(test_dataset, batch_size, shuffle=False, collate_fn=test_dataset.collate_fn)

    return train_loader, test_loader

In [13]:
# Example usage
max_len = 265  # Set your sequence max length here
scaler = StandardScaler()  # Initialize your custom scaler
train_loader, test_loader = generate_ATT_loader(train_data, test_data, max_len, batch_size=256, scaler=scaler)

In [14]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1,-2))
        scores.masked_fill_(attn_mask,-1e9)
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V) # [batch_size, n_heads, len_q, d_v]
        return context, attn

#====================================================================================================#
class MultiHeadAttentionwithonekey(nn.Module):
    def __init__(self,d_model,d_k,n_heads,d_v,out_dim):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.out_dim = out_dim
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc_att = nn.Linear(n_heads * d_v, out_dim, bias=False)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        Q = self.W_Q(input_Q).view(input_Q.size(0),-1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(input_K).view(input_Q.size(0),-1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(input_V).view(input_V.size(0),-1, self.n_heads, self.d_k).transpose(1, 2)
        #print(Q.size(), K.size())
        attn_mask = attn_mask.unsqueeze(1).repeat(1,self.n_heads,1,1)
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        context = context.transpose(1, 2).reshape(input_Q.size(0), -1, self.n_heads * self.d_v)
        output = self.fc_att(context) # [batch_size, len_q, out_dim]
        return output, attn

#====================================================================================================#
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, 1)
            )
        #--------------------------------------------------#
        self.weights = nn.Parameter(torch.from_numpy(np.array([0.0,])), requires_grad = True)
        self.fc_1    = nn.Linear(d_model, d_ff)
        self.fc_2    = nn.Linear(d_ff, d_ff)
        self.fc_3    = nn.Linear(d_ff, 1)

    def forward(self, inputs, input_emb):
        '''
        inputs: [batch_size, src_len, out_dim]
        '''
        output = torch.flatten(inputs, start_dim = 1)
        #output += input_emb.mean(dim = 1)
        output = output*self.weights + input_emb.mean(dim = 1) * (1-self.weights)
        #print(self.weights)
        output = self.fc_1(output)
        output = nn.functional.relu(output)
        last_layer = self.fc_2(output)
        output = nn.functional.relu(last_layer)
        output = self.fc_3(output)

        return output, last_layer

#====================================================================================================#
class EncoderLayer(nn.Module):
    def __init__(self, d_model, d_k, n_heads, d_v, out_dim, d_ff): #out_dim = 1, n_head = 4, d_k = 256
        super(EncoderLayer, self).__init__()
        self.emb_self_attn = MultiHeadAttentionwithonekey(d_model, d_k, n_heads, d_v, out_dim)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, input_emb, emb_self_attn_mask, input_mask):
        '''
        input_emb: [batch_size, src_len, d_model]
        emb_self_attn_mask: [batch_size, src_len, src_len]
        '''
        # output_emb: [batch_size, src_len, 1], attn: [batch_size, n_heads, src_len, src_len]
        output_emb, attn = self.emb_self_attn(input_emb, input_emb, input_emb, emb_self_attn_mask) # input_emb to same Q,K,V
        batch_mask = input_mask.unsqueeze(2)
        output_emb = output_emb * batch_mask
        pos_weights = nn.Softmax(dim = 1)(output_emb.masked_fill_(input_mask.unsqueeze(2).data.eq(0), -1e9)).permute(0,2,1) # [ batch_size, 1, src_len]
        output_emb = torch.matmul(pos_weights, input_emb)
        output_emb, last_layer = self.pos_ffn(output_emb, input_emb) # output_emb: [batch_size, d_model]
        return output_emb, last_layer

#====================================================================================================#
# X05B
class SQembSAtt_Model(nn.Module):
    def __init__(self, d_model, d_k, n_heads, d_v, out_dim, d_ff):
        super(SQembSAtt_Model, self).__init__()
        self.layers = EncoderLayer(d_model, d_k, n_heads, d_v, out_dim, d_ff)

    def get_attn_pad_mask(self, seq_mask):
        batch_size, len_q = seq_mask.size()
        _, len_k = seq_mask.size()
        # eq(zero) is PAD token
        pad_attn_mask = seq_mask.data.eq(0).unsqueeze(1)  # [batch_size, 1, len_k], True is masked
        return pad_attn_mask.expand(batch_size, len_q, len_k)

    def forward(self, input_emb, input_mask):
        '''
        input_emb  : [batch_size, src_len, embedding_dim]
        input_mask : [batch_size, src_len]
        '''
        emb_self_attn_mask = self.get_attn_pad_mask(input_mask) # [batch_size, src_len, src_len]
        # output_emb: [batch_size, src_len, out_dim], emb_self_attn: [batch_size, n_heads, src_len, src_len]
        output_emb, last_layer = self.layers(input_emb, emb_self_attn_mask, input_mask)
        return output_emb, last_layer

In [15]:
d_k = 256
n_heads = 4
d_v = 256
out_dim = 1
d_ff = 1280
dropout = 0.
epoch_num      = 100
batch_size     = 256
learning_rate  =  [0.01        , # 0
                   0.005       , # 1
                   0.002       , # 2
                   0.001       , # 3
                   0.0005      , # 4
                   0.0002      , # 5
                   0.0001      , # 6
                   0.00005     , # 7
                   0.00002     , # 8
                   0.00001     , # 9
                   0.000005    , # 10
                   0.000002    , # 11
                   0.000001    , # 12
                   ][5]

In [16]:
model = SQembSAtt_Model(
    d_model=1280,
    d_k=d_k,
    n_heads=n_heads,
    d_v=d_v,
    out_dim=out_dim,
    d_ff=d_ff
)

model.float()
model.cuda()
#--------------------------------------------------#
print("#"*50)
print(model)
#model.float()
#--------------------------------------------------#

##################################################
SQembSAtt_Model(
  (layers): EncoderLayer(
    (emb_self_attn): MultiHeadAttentionwithonekey(
      (W_Q): Linear(in_features=1280, out_features=1024, bias=False)
      (W_K): Linear(in_features=1280, out_features=1024, bias=False)
      (W_V): Linear(in_features=1280, out_features=1024, bias=False)
      (fc_att): Linear(in_features=1024, out_features=1, bias=False)
    )
    (pos_ffn): PoswiseFeedForwardNet(
      (fc): Sequential(
        (0): Linear(in_features=1280, out_features=1280, bias=True)
        (1): ReLU()
        (2): Linear(in_features=1280, out_features=1, bias=True)
      )
      (fc_1): Linear(in_features=1280, out_features=1280, bias=True)
      (fc_2): Linear(in_features=1280, out_features=1280, bias=True)
      (fc_3): Linear(in_features=1280, out_features=1, bias=True)
    )
  )
)


In [17]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Optional: Define a learning rate scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 50], gamma=0.5)

# Loss function
criterion = nn.MSELoss()

In [18]:
model.to(device)

SQembSAtt_Model(
  (layers): EncoderLayer(
    (emb_self_attn): MultiHeadAttentionwithonekey(
      (W_Q): Linear(in_features=1280, out_features=1024, bias=False)
      (W_K): Linear(in_features=1280, out_features=1024, bias=False)
      (W_V): Linear(in_features=1280, out_features=1024, bias=False)
      (fc_att): Linear(in_features=1024, out_features=1, bias=False)
    )
    (pos_ffn): PoswiseFeedForwardNet(
      (fc): Sequential(
        (0): Linear(in_features=1280, out_features=1280, bias=True)
        (1): ReLU()
        (2): Linear(in_features=1280, out_features=1, bias=True)
      )
      (fc_1): Linear(in_features=1280, out_features=1280, bias=True)
      (fc_2): Linear(in_features=1280, out_features=1280, bias=True)
      (fc_3): Linear(in_features=1280, out_features=1, bias=True)
    )
  )
)

In [19]:
def save_plot(y_pred, y_real, epoch, folder_path, file_name, fig_size=(12, 10), marker_size=10, fit_line_color="skyblue", distn_color_1="lightgreen", distn_color_2="salmon"):
    sns.set(style="whitegrid")

    # Create a jointplot
    g = sns.jointplot(
        x=y_real,
        y=y_pred,
        kind="reg",
        height=fig_size[1] - 1,
        color=fit_line_color,
        scatter_kws={"s": marker_size},
        marginal_kws={'color': distn_color_1}
    )

    # Set axis labels
    g.ax_joint.set_xlabel("Actual Values", fontsize=12)
    g.ax_joint.set_ylabel("Predictions", fontsize=12)

    # Set the title with a bit more space at the top
    g.fig.suptitle(f"Predictions vs. Actual Values\n R = {np.round(stats.pearsonr(y_pred, y_real)[0], 3)} | Epoch: {epoch}", fontsize=14, y=1.05)

    # Plot the histograms
    sns.histplot(y_real, color=distn_color_1, alpha=0.6, ax=g.ax_marg_x, fill=True, kde=True)
    sns.histplot(y=y_pred, color=distn_color_2, alpha=0.6, ax=g.ax_marg_y, fill=True, kde=True)

    # Adjust the plot margins and layout
    plt.subplots_adjust(left=0.15, right=0.85, top=0.85, bottom=0.15)
    g.fig.tight_layout()

    # Save the plot
    plt.savefig(Path(folder_path) / file_name, bbox_inches='tight')
    plt.close(g.fig)

In [20]:
# Training
model_name = 'att'
save_path = f'{root_path}/{model_name}_results/split_name_{split_name}/seed_{seed}'
test_interval = 5
best_r_score = -1
training_losses = []  # To store training losses for plotting

# Variables to keep track of the best file names
best_model_file = ""
best_plot_file = ""
best_metrics_file = ""
for epoch in range(epoch_num):
    begin_time = time.time()

    # Training Phase
    model.train()
    total_loss = 0
    for one_seq_ppt_group in train_loader:
        embeddings, mask, target = one_seq_ppt_group['embeddings'], one_seq_ppt_group['masks'], one_seq_ppt_group['targets']
        embeddings, mask, target = embeddings.float().to(device), mask.float().to(device), target.float().to(device)

        optimizer.zero_grad()
        output, _ = model(embeddings, mask)
        loss = criterion(output, target.view(-1,1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    training_losses.append(avg_loss)

    # Test Phase
    if epoch % test_interval == 0 or epoch == epoch_num - 1:
        model.eval()
        y_pred_test = []
        y_real_test = []
        for one_seq_ppt_group in test_loader:
            embeddings, mask, target = one_seq_ppt_group['embeddings'], one_seq_ppt_group['masks'], one_seq_ppt_group['targets']
            embeddings, mask, target = embeddings.float().to(device), mask.float().to(device), target.float().to(device)

            output, _ = model(embeddings, mask)
            output = output.cpu().detach().numpy().reshape(-1)
            target = target.cpu().numpy()

            y_pred_test.extend(output)
            y_real_test.extend(target)

        _, _, r_value_test, _, _ = scipy.stats.linregress(y_pred_test, y_real_test)

        # Save best model
        if r_value_test >= best_r_score:
            best_r_score = r_value_test
            os.makedirs(save_path, exist_ok=True)
            if best_model_file:
                os.remove(best_model_file)
                os.remove(best_plot_file)
                os.remove(best_metrics_file)

            # Update file names
            best_model_file = f'{save_path}/best_model_epoch_{epoch}_seed_{seed}.pt'
            best_plot_file = f'{save_path}/test_plot_epoch_{epoch}.png'
            best_metrics_file = f'{save_path}/metrics_epoch_{epoch}.txt'
            torch.save(model.state_dict(), best_model_file)
            save_plot(y_pred_test, y_real_test, epoch, save_path, best_plot_file.split('/')[-1])
            with open(best_metrics_file, 'w') as f:
                f.write(f'Epoch: {epoch}\n')
                f.write(f'Test R-value: {r_value_test}\n')
                f.write(f'Test MAE: {mean_absolute_error(y_pred_test, y_real_test)}\n')
                f.write(f'Test MSE: {mean_squared_error(y_pred_test, y_real_test)}\n')
                f.write(f'Test RMSE: {np.sqrt(mean_squared_error(y_pred_test, y_real_test))}\n')
                f.write(f'Test R2: {r2_score(y_real_test, y_pred_test)}\n')
                f.write(f'Test Spearman rho: {scipy.stats.spearmanr(y_pred_test, y_real_test)[0]}\n')
                # Additional metrics can be added here
            print(f'Best R-value: {best_r_score:.4f}')

    print(f'Epoch: {epoch} | Training Loss: {avg_loss:.4f}')
    scheduler.step()

Best R-value: 0.3796
Epoch: 0 | Training Loss: 2.4545
Epoch: 1 | Training Loss: 1.4754
Epoch: 2 | Training Loss: 1.0019
Epoch: 3 | Training Loss: 0.8687
Epoch: 4 | Training Loss: 0.7971
Best R-value: 0.6693
Epoch: 5 | Training Loss: 0.7437
Epoch: 6 | Training Loss: 0.6939
Epoch: 7 | Training Loss: 0.6500
Epoch: 8 | Training Loss: 0.6067
Epoch: 9 | Training Loss: 0.5600
Best R-value: 0.7259
Epoch: 10 | Training Loss: 0.5157
Epoch: 11 | Training Loss: 0.4846
Epoch: 12 | Training Loss: 0.4446
Epoch: 13 | Training Loss: 0.4154
Epoch: 14 | Training Loss: 0.3966
Best R-value: 0.7534
Epoch: 15 | Training Loss: 0.3631
Epoch: 16 | Training Loss: 0.3434
Epoch: 17 | Training Loss: 0.3219
Epoch: 18 | Training Loss: 0.3096
Epoch: 19 | Training Loss: 0.3042
Best R-value: 0.7746
Epoch: 20 | Training Loss: 0.2774
Epoch: 21 | Training Loss: 0.2628
Epoch: 22 | Training Loss: 0.2590
Epoch: 23 | Training Loss: 0.2450
Epoch: 24 | Training Loss: 0.2307
Best R-value: 0.7899
Epoch: 25 | Training Loss: 0.2185


In [21]:
# Save the training plot
plt.figure()
plt.plot(training_losses)
plt.xlabel('Epoch')
plt.ylabel('Training Loss')
plt.title('Training Loss Over Epochs')
plt.savefig(f'{save_path}/training_loss_plot_seed_{seed}.png')
plt.close()

print(f'Training Completed. Best Model Saved with R-score: {best_r_score}')

Training Completed. Best Model Saved with R-score: 0.8283012306340282
