# Train a multitask model on concatenated embeddings

This notebook concatenates the ESM and PS (ProteinSolver) embeddings and trains a multitask model on them in order to learn solubility patterns.

### Import and initialize

In [None]:
#Import stuff
import os
import re
import time
import math
import pandas as pd
import numpy as np
import torch
import torch.autograd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sn
import sklearn
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import MultiLabelBinarizer
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
from scipy.stats import spearmanr
from sklearn.metrics import roc_auc_score
import numpy as np
from sklearn.model_selection import GroupKFold
import pickle

#Set a nice figure size: use a large default for every matplotlib figure created below
plt.rcParams['figure.figsize'] = [20, 10]

### Define Paths

In [None]:
#Only use cpu (this string is later handed to model.to(device))
device = "cpu"

#Assign embedding folders: each folder is listed with os.listdir and its
#files are read with torch.load further down in the notebook
ESM_EMB_PATH = "./ESM_embeddings/"
PS_EMB_PATH = "./PS_embeddings/"

## Load labels from csv and clean up

In [None]:
#Rescale a NESG label to the unit interval, leaving the "missing" marker alone
def NESGNormalizeData(data):
    """Map a NESG label from the 0-5 scale onto 0-1.

    The value 6 is returned unchanged; every other value is min-max scaled
    with a minimum of 0 and a maximum of 5.
    """
    if data == 6:
        return 6
    lowest, highest = 0, 5
    return (data - lowest) / (highest - lowest)

In [None]:
#Load the data
org_df = pd.read_csv("../0_DataPreprocessing/CleanedData.csv", sep=",")


#Get list of emb IDs (the ID is the last "_"-separated field of the file stem)
emb_id = list(os.listdir(ESM_EMB_PATH))
emb_id = [idx.split(".")[0].split("_")[-1] for idx in emb_id]


#Drop rows without sequence embeddings (should be none).
#An explicit copy is taken so the column updates below are guaranteed to
#land in df rather than in a temporary view of org_df.
df = org_df[org_df.ID.isin(emb_id)].copy()
df = df.reset_index(drop=True)

#Replace NaN with sentinel values that are ignored later:
#6 for the NESG label, 9 for the Psi-Bio label.
#Assigning the result back (instead of a chained inplace fillna) makes the
#update independent of pandas' copy/view rules.
df["NESG_label"] = df["NESG_label"].fillna(6)
df["PSI_BIO_label"] = df["PSI_BIO_label"].fillna(9)

#Make sure it is integers
df["NESG_label"] = df["NESG_label"].astype(int)
df["PSI_BIO_label"] = df["PSI_BIO_label"].astype(int)

#Reverse the Psi-Bio label: 0 <-> 1. The sentinel 9 is left untouched.
#(A raw value of 2 also ends up as 1, exactly as in the previous three-step swap.)
psi_bio_swap = {0: 1, 1: 0, 2: 1}
df["PSI_BIO_label"] = [psi_bio_swap.get(x, x) for x in df["PSI_BIO_label"]]

#Load labels
NESG_label = list(df["NESG_label"])
psi_bio_label = list(df["PSI_BIO_label"])

#Normalize nesg label between 0 and 1 (the sentinel 6 is passed through)
norm_NESG_label = [NESGNormalizeData(x) for x in NESG_label]

df["norm_NESG_label"] = norm_NESG_label

In [None]:
df

## Load cluster annotation

In [None]:
#Read the tab-separated cluster table (no header row) and name
#column 0 'rep' and column 1 'id'
clusters = pd.read_csv("./DB_clu_50_id.tsv", sep="\t", header=None).rename(
    columns={0: 'rep', 1: 'id'}
)

In [None]:
clusters

In [None]:
#Make a cluster dictionary
#cluster_temp_dict: representative -> running integer cluster index
#cluster_dict:      member id      -> integer index of its representative
cluster_temp_dict = {}
cluster_dict = {}
for rep, member in zip(clusters["rep"], clusters["id"]):
    #First time a representative is seen it gets the next free index
    if rep not in cluster_temp_dict:
        cluster_temp_dict[rep] = len(cluster_temp_dict)
    cluster_dict[member] = cluster_temp_dict[rep]

#One index was handed out per distinct representative
count = len(cluster_temp_dict)

print(f"Total amount of clusters: {count}")
        

In [None]:
#Append cluster info to df: look up every ID in cluster_dict, in row order
#(a missing ID raises KeyError). Note that this rebinds the name `clusters`
#to a plain list of cluster indices.
clusters = [cluster_dict[protein_id] for protein_id in df["ID"]]
df["cluster"] = clusters

In [None]:
df

## ESM Data load 

In [None]:
#Load and format embeddings in a dict (protein ID -> loaded ESM tensor)
#Build the ID lookup once: a set gives constant-time membership tests,
#whereas list(df["ID"]) inside the loop was rebuilt and scanned for every file
known_ids = set(df["ID"])
ESM_embs_dict = dict()
for file in os.listdir(ESM_EMB_PATH):
    #The protein ID is the last "_"-separated field of the file stem
    name = file.split(".")[0].split("_")[-1]
    if file.endswith(".pt") and name in known_ids:
        print (f"working with file: {file}", end="\r")
        tensor_in = torch.load(os.path.join(ESM_EMB_PATH, file))
        ESM_embs_dict[name] = tensor_in

In [None]:
#Sanity check: every entry in the ESM folder must have produced one dictionary entry.
#This fails if the folder holds non-.pt files, files whose ID is not in df,
#or two files that map to the same ID.
assert len(ESM_embs_dict) == len(os.listdir(ESM_EMB_PATH))

## Proteinsolver Data load and preprocessing

In [None]:
#Load and format embeddings (protein ID -> loaded PS tensor)
#Build the ID lookup once: a set gives constant-time membership tests,
#whereas list(df["ID"]) inside the loop was rebuilt and scanned for every file
known_ids = set(df["ID"])
PS_embs_dict = dict()
for file in os.listdir(PS_EMB_PATH):
    #The protein ID is the last "_"-separated field of the file stem
    name = file.split(".")[0].split("_")[-1]
    if file.endswith(".pt") and name in known_ids:
        print (f"working with file: {file}", end="\r")
        tensor_in = torch.load(os.path.join(PS_EMB_PATH, file))
        PS_embs_dict[name] = tensor_in

In [None]:
#Sanity check: every entry in the PS folder must have produced one dictionary entry.
#This fails if the folder holds non-.pt files, files whose ID is not in df,
#or two files that map to the same ID.
assert len(PS_embs_dict) == len(os.listdir(PS_EMB_PATH))

## Concatenate embeddings

In [None]:
#Concatenate PS and ESM embeddings (protein ID -> joined tensor)
cat_embs_dict = {}

#Walk over every protein that has an ESM embedding. A protein without a
#PS embedding raises KeyError here; there is no zero-filled fallback.
for prot_id, esm in ESM_embs_dict.items():
    ps = PS_embs_dict[prot_id]

    #Both embeddings must have identical shapes before they are joined
    assert esm.shape == ps.shape

    #Join along dimension 1 and store the result under the protein ID
    cat_embs_dict[prot_id] = torch.cat((esm, ps), 1)

print(f"Concatenated embeddings from {len(cat_embs_dict)} proteins")

## Check distribution of data labels

In [None]:
#NESG labels: count each label value and draw the counts as a bar chart
nesg_counts = Counter(NESG_label)
df_label1 = pd.DataFrame.from_dict(nesg_counts, orient='index')
plt1 = df_label1.plot.bar()
plt1.legend(["NESG"])

#Psi_bio labels: same chart for the second task
psi_bio_counts = Counter(psi_bio_label)
df_label2 = pd.DataFrame.from_dict(psi_bio_counts, orient='index')
plt2 = df_label2.plot.bar()
plt2.legend(["Psi-Bio"])

In [None]:
#Function that calculates amino acid distribution
def aa_dist(seq):
    """Return the relative frequency of each of the 20 standard amino acids.

    Args:
        seq: protein sequence as a string of one-letter codes.

    Returns:
        A list of 20 fractions (count / len(seq)), one per residue in the
        fixed order of ``aas`` below. Characters outside that alphabet are
        not reported but still count towards ``len(seq)``. An empty
        sequence yields 20 zeros.
    """
    #Tyrosine ("Y") sits in the slot that used to hold the ambiguity code "B",
    #which left tyrosine uncounted. All other residues keep their index, so
    #the feature layout is otherwise unchanged.
    aas = ["A","R","N","D","Y","C","Q","E","G","H","I","L","K","M","F","P","S","T","W","V"]

    total = len(seq)
    if total == 0:
        #Nothing to count; also avoids dividing by zero
        return [0] * len(aas)

    #Counter returns 0 for residues that never occur
    counter = Counter(seq)
    return [counter[aa] / total for aa in aas]

In [None]:
#Ensure proper label/emb/clust order: the three lists below are filled in
#the iteration order of cat_embs_dict, so index i refers to the same protein
data_labels = []
embs_X = []
clusters = []
count = 0
total = len(cat_embs_dict)
for count, (key, embs) in enumerate(cat_embs_dict.items(), start=1):
    print(f"working with {count}/{total}", end = "\r")

    #Select the annotation row of this protein; .item() below raises unless
    #exactly one row matches the ID
    row_num = df.loc[df['ID'] == key]
    sequence = row_num.sequence.item()

    #Extract scalars explicitly: calling int() on a Series is deprecated
    label = [row_num.norm_NESG_label.item(), int(row_num.PSI_BIO_label.item())]

    #Also, add in the extra info: amino acid distribution plus sequence
    #length, repeated once for every row of the embedding
    extra_inf = torch.FloatTensor(aa_dist(sequence) + [len(sequence)])
    extra_inf = extra_inf.repeat(len(embs), 1)

    #Append all in correct order (the cluster is stored as a plain scalar,
    #not as a one-element Series)
    clusters.append(row_num.cluster.item())
    data_labels.append(label)
    embs_X.append(torch.cat((embs, extra_inf), 1))

### Create DataSet and DataLoader functions

In [None]:
#Create dataset class
class ProteinDataset(Dataset):
    """Pair per-protein feature tensors with their label entries.

    ``X`` and ``Y`` are indexed with the same position; the dataset length
    is the number of label entries.
    """

    def __init__(self, X, Y):
        self.X, self.y = X, Y

    def __len__(self):
        """Number of samples, taken from the label container."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(features, label_tensor)`` for sample ``idx``."""
        features = self.X[idx]
        #Labels are stored as plain Python values and converted on access
        target = torch.tensor(self.y[idx])
        return (features, target)

In [None]:
#Create collate function for padding sequences
def pad_collate(batch):
    """Collate ``(features, label)`` samples into one zero-padded batch.

    Returns the features stacked batch-first and padded with 0 up to the
    longest sequence in the batch, together with the labels as an
    untouched tuple (one entry per sample).
    """
    sequences, targets = zip(*batch)
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    return padded, targets

# Make functions for easier model assessment

In [None]:
# Helper for saving models
def save_model(filepath, n_epochs, model_conv, optimizer_to_save=None):
    """Save a checkpoint of ``model_conv`` in three redundant formats.

    Args:
        filepath: directory that receives the files; created if missing.
        n_epochs: epoch number recorded inside the combined checkpoint.
        model_conv: the model to save.
        optimizer_to_save: optimizer whose state goes into the combined
            checkpoint. When omitted, the notebook-level ``optimizer`` is
            used, which is what this function always saved before.

    Files written inside ``filepath``:
        model_conv.state_dict  -- model weights only
        model_conv.state       -- dict with 'epoch', 'state_dict', 'optimizer'
        model_conv.full        -- the whole pickled model object
    """
    if optimizer_to_save is None:
        optimizer_to_save = optimizer

    #Create the folder; exist_ok avoids the check-then-create race
    os.makedirs(filepath, exist_ok=True)

    ### METHOD 1 ###
    torch.save(model_conv.state_dict(), os.path.join(filepath, "model_conv.state_dict"))

    #Later to restore:
    #model.load_state_dict(torch.load(filepath))
    #model.eval()

    ### METHOD 2 ###
    state = {
        'epoch': n_epochs,
        'state_dict': model_conv.state_dict(),
        'optimizer': optimizer_to_save.state_dict(),
    }

    torch.save(state, os.path.join(filepath, "model_conv.state"))

    #Later to restore:
    #model.load_state_dict(state['state_dict'])
    #optimizer.load_state_dict(state['optimizer'])


    ### METHOD 3 ###
    torch.save(model_conv, os.path.join(filepath, "model_conv.full"))

    #Later to restore:
    #model = torch.load(filepath)
    


## Define model

In [None]:
#Hyper parameters
#NOTE(review): the trailing "#value" markers look like previously tried settings -- TODO confirm
input_size = 60 #Width of the embedding part of the input fed to the LSTM
hidden_size = 64 #LSTM hidden size per direction
num_layers = 3 #Number of stacked LSTM layers
num_classes_nesg = 6 #7 -- stored on the model but not used to size any layer (the NESG head has one output)
num_classes_psibio = 2 #3 -- output width of the Psi-Bio head
batch_size = 96 # Passed to the DataLoader
n_epochs = 8 #51
lr = 0.001  #0.01
dropout = 0.4 #Passed to nn.LSTM(dropout=...)
weight_decay = 1e-6 #Passed to the Adam optimizer

In [None]:
#Define Bi_LSTM model

class Bi_LSTM(nn.Module):
    """Bidirectional LSTM with two output heads (NESG regression, Psi-Bio classification).

    The input of ``forward`` carries, per sequence position, ``input_size``
    embedding features followed by ``extra_size`` per-protein features. The
    embedding part goes through batch normalisation and the LSTM; the extra
    part is appended to the final LSTM states before the two linear heads.

    Args:
        input_size: number of embedding features per position.
        hidden_size: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        num_classes_nesg: stored for reference only; the NESG head is a
            single sigmoid output.
        num_classes_psibio: number of Psi-Bio classes (width of the second head).
        dropout: dropout between LSTM layers.
        extra_size: number of per-protein features appended after the
            embedding columns (21 by default, as before).
    """

    def __init__(
        self,
        input_size = input_size,
        hidden_size = hidden_size,
        num_layers = num_layers,
        num_classes_nesg = num_classes_nesg,
        num_classes_psibio = num_classes_psibio,
        dropout = dropout,
        extra_size = 21,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes_nesg = num_classes_nesg
        self.num_classes_psibio = num_classes_psibio
        self.dropout = dropout
        self.extra_size = extra_size

        #Batch normalisation over the embedding features. It is registered
        #here (not created inside forward) so that its parameters are
        #trained by the optimizer, its running statistics persist between
        #calls, and it is saved/loaded together with the rest of the model.
        self.bnorm = nn.BatchNorm1d(num_features=input_size)

        #Initialize the LSTM layer
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional = True, batch_first=True, dropout = dropout)

        #Initialize ReLU layer
        self.relu = nn.ReLU()

        #Initialize the linear layer for nesg labels (single regression output)
        self.linear1 = nn.Linear((hidden_size * 2) + extra_size, 1)

        #Initialize the linear layer for psibio labels (one logit per class)
        self.linear2 = nn.Linear((hidden_size * 2) + extra_size, num_classes_psibio)

        #Sigmoid keeps the NESG prediction between 0 and 1
        self.sigmoid = nn.Sigmoid()

        #Random hidden-state tuple; forward() never reads it. Kept only so
        #the attribute still exists for code that may reference it.
        self.hidden = (torch.randn(1, 1, self.hidden_size),torch.randn(1, 1, self.hidden_size))


    def forward(self, x):
        """Run one batch through the network.

        Args:
            x: tensor of shape (batch, seq_len, input_size + extra_size).

        Returns:
            ``[out1, out2]`` where ``out1`` is the sigmoid NESG prediction
            of shape (batch, 1) and ``out2`` holds the raw Psi-Bio logits of
            shape (batch, num_classes_psibio).
        """
        #Split embeddings and extra info for last dense layer
        embs, extra = torch.split(x, [self.input_size, self.extra_size], dim=2)

        #The extra info is identical at every real position of a protein.
        #Taking the first position recovers it exactly: it keeps the batch
        #axis for a batch of one, and is not diluted by zero padding the way
        #a mean over the padded length is.
        extra = extra[:, 0, :]

        #Batch normalize the embedding features. BatchNorm1d expects
        #(batch, features, length), hence the transposes.
        norm_data = self.bnorm(embs.transpose(1, 2)).transpose(1, 2).contiguous()

        #Forward through the lstm layer. The initial hidden and cell states
        #default to zeros on the same device/dtype as the input.
        lstm_out, (ht, ct) = self.lstm(norm_data)

        #Concatenate the final states of the last layer from both directions
        lstm_ht = torch.cat([ht[-1, :, :], ht[-2, :, :]], dim=1)

        #Add the extra information before going through last dense layers
        collect = torch.cat((lstm_ht, extra), dim=1)

        #Forward through relu layer (shared by both heads)
        activated = self.relu(collect)

        #Forward through the two task-specific linear layers
        nesg_linear = self.linear1(activated)
        psibio_linear = self.linear2(activated)

        #Add sigmoid activation function to the NESG head only; the Psi-Bio
        #head returns raw logits
        out1 = self.sigmoid(nesg_linear)
        out2 = psibio_linear

        return [out1, out2]

#Define the model, optimizer and loss functions
#(no class weights are passed to the losses; weight decay is applied through Adam)
model = Bi_LSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_classes_nesg=num_classes_nesg,
    num_classes_psibio=num_classes_psibio,
    dropout=dropout,
)
model.to(device)

#NESG: element-wise MSE, left unreduced
loss_nesg = nn.MSELoss(reduction='none')

#Psi-Bio: cross entropy that skips samples labelled 9
loss_psibio = nn.CrossEntropyLoss(ignore_index=9, reduction="mean")

optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

In [None]:
model.eval()

In [None]:
#Train model function
def train_model(model, loss_func1, loss_func2, optimizer, n_epochs, loader=None):
    """Train ``model`` on both tasks and save a checkpoint after every epoch.

    Args:
        model: network whose forward pass returns ``[nesg_pred, psibio_logits]``.
        loss_func1: element-wise (unreduced) regression loss for the NESG
            task; masking and averaging are done here.
        loss_func2: classification loss for the Psi-Bio task.
        optimizer: optimizer stepping the model parameters.
        n_epochs: number of passes over the batches.
        loader: iterable yielding ``(embs, labels)`` batches. When omitted,
            the notebook-level ``train_loader`` is used, as before.

    Returns:
        The trained model (the same object that was passed in).
    """
    batches = train_loader if loader is None else loader

    #Label value that the classification loss skips, if it has one
    ignore_index = getattr(loss_func2, "ignore_index", None)

    #Make sure dropout and batch norm are in training mode even if
    #model.eval() was called on this model earlier
    model.train()

    #Train network
    for epoch in range(1, n_epochs+1):

        #Iterate through batches
        for i, (embs, labels) in enumerate(batches):

            #reset optimizer
            optimizer.zero_grad()

            #Print screen output
            str_epoch = format(epoch, '03d')
            str_batch = format(i+1, '03d')
            print(f"Epoch: {str_epoch}, batch: {str_batch}", end="\r")

            #Format labels
            nesg_labels = torch.tensor([label[0] for label in labels])
            psibio_labels = torch.tensor([int(label[1]) for label in labels], dtype = torch.long)

            #Predict labels (forward)
            y_pred = model(embs)
            #Flatten to one value per sample; a bare squeeze would also
            #drop the batch axis when the batch holds a single sample
            y_pred1 = y_pred[0].reshape(-1)
            y_pred2 = y_pred[1]
            nesg_labels = nesg_labels.to(y_pred1.dtype)

            #Make a mask vector: 0 where the NESG label is the sentinel 6
            mask = (nesg_labels != 6).to(y_pred1.dtype)

            #Calculate MSE loss using masking. The divisor is clamped to 1
            #so a batch without any NESG label gives 0 instead of NaN.
            loss1 = loss_func1(y_pred1, nesg_labels)
            non_zero_elements = mask.sum().clamp(min=1)
            masked_loss = (loss1*mask).sum()/non_zero_elements

            #Calculate Cross Entropy loss. It is skipped when every label
            #in the batch is ignored, because the mean over zero samples
            #would be NaN and corrupt the weights.
            combined_loss = masked_loss
            if ignore_index is None or bool((psibio_labels != ignore_index).any()):
                loss2 = loss_func2(y_pred2, psibio_labels)
                combined_loss = combined_loss + loss2*0.25

            #Combine loss (backward)
            combined_loss.backward()

            #optimize
            optimizer.step()


        #Save model for each epoch
        filepath = f"./model/model_{epoch}"
        save_model(filepath, epoch, model)


    #Return the trained model
    return model
            

In [None]:
# Prepare data loading and model
train = ProteinDataset(embs_X, data_labels)

#Shuffled batches; pad_collate zero-pads every batch to its longest sequence
train_loader = DataLoader(
    train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=pad_collate,
)

#Start from a freshly initialised model, losses and optimizer
model = Bi_LSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_classes_nesg=num_classes_nesg,
    num_classes_psibio=num_classes_psibio,
    dropout=dropout,
)
loss_nesg = nn.MSELoss(reduction='none')
loss_psibio = nn.CrossEntropyLoss(ignore_index=9, reduction="mean")
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

#Train the model
model = train_model(model, loss_nesg, loss_psibio, optimizer, n_epochs)
