# Finetune a multitask model on concatenated embeddings

This notebook concatenates the ESM and PS embeddings and trains a Random Forest classifier (RFC) on the result.

### Import and initialize

In [None]:
#Import stuff
import os
import re
import time
import math
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sn
import sklearn
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import MultiLabelBinarizer
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
from scipy.stats import spearmanr
from scipy.stats import pearsonr
from scipy.stats import linregress


### Define Paths

In [None]:
#Only use cpu: name of the compute device, kept as a plain string
device = "cpu"

#Assign embedding folder
#Directories that are scanned for per-protein embedding tensors (*.pt files).
#NOTE(review): presumably these hold PCA-reduced embeddings, judging by the
#folder name - confirm against the step that produced them.
ESM_EMB_PATH = "../4_FineTuning//PCA_reduced/ESM_embeddings/"
PS_EMB_PATH = "../4_FineTuning//PCA_reduced/PS_embeddings/"

## Load experimental data

In [None]:
#Read the semicolon-separated table of experimental values and keep only
#the rows that have no missing value in any column
exp = pd.read_csv("../4_FineTuning/jain_full.csv", sep=";").dropna()

In [None]:
#Display the experimental table as loaded (rows with missing values removed)
exp

In [None]:
#Min-max normalize AC-SINS to [0, 1].
#The column minimum and maximum are computed once (vectorized) instead of
#being re-evaluated for every single row.
ac_sins = exp["AC-SINS"]
norm_ac = (ac_sins - ac_sins.min()) / (ac_sins.max() - ac_sins.min())
exp["norm_AC-SINS"] = norm_ac

#Min-max normalize HIC to [0, 1]
hic = exp["HIC"]
norm_hic = (hic - hic.min()) / (hic.max() - hic.min())
exp["norm_HIC"] = norm_hic

#Add fake labels for testing: uniform random numbers in [0, 1), one per row,
#drawn from a generator with a fixed seed so the column is reproducible
rng = np.random.default_rng(12345)
rand = rng.random(len(exp))
exp["fake"] = rand

#Binary classification label: 0 when AC-SINS <= 5, otherwise 1
bc = [0 if val <= 5 else 1 for val in exp["AC-SINS"]]
exp["BC"] = bc

In [None]:
#Display the table again, now including the normalized and label columns
exp

In [None]:
#Remove the columns that are not needed further down, one at a time and in
#the same order as before
for unused_col in ("Name.1", "match", "fake"):
    del exp[unused_col]

In [None]:
#Get sequences from fasta file
#Maps each header (the text after ">") to its full sequence. Sequence lines
#are accumulated per record, so sequences wrapped over several lines are kept
#whole, and blank lines are ignored instead of overwriting a sequence.
fastas = {}
header = None
chunks = []
with open("../4_FineTuning//antibody_bulk.fsa", "r") as fasta:
    for line in fasta:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            #A new record starts: store the previous one if it had any sequence
            if header is not None and chunks:
                fastas[header] = "".join(chunks)
            header = line[1:]
            chunks = []
        elif header is None:
            raise ValueError("FASTA file contains sequence data before the first header")
        else:
            chunks.append(line)

#Store the last record of the file
if header is not None and chunks:
    fastas[header] = "".join(chunks)

In [None]:
#Function that calculates amino acid distribution
def aa_dist(seq):
    """Return the relative frequency of 20 residue letters in *seq*.

    The result is a list with one entry per letter, in the fixed order
    A, R, N, D, B, C, Q, E, G, H, I, L, K, M, F, P, S, T, W, V. A letter that
    occurs in *seq* contributes ``count / len(seq)``; a letter that does not
    occur contributes the integer ``0``. An empty *seq* yields 20 zeros.

    NOTE(review): the letter set contains "B" but not "Y", so "Y" residues are
    never counted. This looks unintended, but the 20-entry layout is relied on
    elsewhere in the notebook, so it is left unchanged here - confirm.
    """
    residues = "ARNDBCQEGHILKMFPSTWV"
    tally = Counter(seq)
    total = len(seq)
    return [tally[res] / total if res in tally else 0 for res in residues]

In [None]:
# Collect sequences and aa distribution
#One sequence and one 20-value distribution per row of exp, in row order
aa_columns = ["A","R","N","D","B","C","Q","E","G","H","I","L","K","M","F","P","S","T","W","V"]
seqs = []
dists = []
for name in exp["Name"]:
    seq = fastas[name]
    seqs.append(seq)
    dists.append(aa_dist(seq))

#Add to dataframe
#pd.concat(axis=1) aligns rows by index label. exp keeps its original labels
#after dropna(), so the new frame must carry the same index; with a default
#0..n-1 index the rows would be paired wrongly and extra NaN rows would appear.
aadf = pd.DataFrame(dists, columns=aa_columns, index=exp.index)
exp["Sequence"] = seqs
exp = pd.concat([exp, aadf], axis=1)

In [None]:
#Make bc dict
#Lookup from antibody name to its binary class label, one entry per row of exp
label_dict = {row["Name"]: row["BC"] for _, row in exp.iterrows()}

print(len(label_dict))

### Load ESM embeddings

In [None]:
#Load and format embeddings in a dict
#Key: the part of the file name after the last "_" and before the first ".";
#value: whatever torch.load returns for that file
#NOTE(review): torch.load unpickles the file content - only load trusted files
ESM_embs_dict = {}
for file in os.listdir(ESM_EMB_PATH):
    #Skip everything that is not a saved tensor file
    if not file.endswith(".pt"):
        continue
    name = file.split(".")[0].split("_")[-1]
    print (f"working with file: {file}", end="\r")
    ESM_embs_dict[name] = torch.load(f'{ESM_EMB_PATH}/{file}')

### Load PS embeddings

In [None]:
#Load and format embeddings
#Key: the part of the file name after the last "_" and before the first ".";
#value: whatever torch.load returns for that file
#NOTE(review): torch.load unpickles the file content - only load trusted files
PS_embs_dict = {}
for file in os.listdir(PS_EMB_PATH):
    #Skip everything that is not a saved tensor file
    if not file.endswith(".pt"):
        continue
    name = file.split(".")[0].split("_")[-1]
    print (f"working with file: {file}", end="\r")
    PS_embs_dict[name] = torch.load(f'{PS_EMB_PATH}/{file}')

### Concatenate embeddings

In [None]:
#Concatenate PS and ESM embeddings
#Result: one tensor per protein with the ESM features followed by the PS
#features along dimension 1
cat_embs_dict = dict()
count = 0

# Iterate through sequence embeddings
for key, esm in ESM_embs_dict.items():
    count += 1
    print(f"Working with {count}/{len(ESM_embs_dict)}", end = "\r")

    #Every ESM embedding needs a PS embedding stored under the same key;
    #there is no zero-filled fallback, so a missing one is reported explicitly
    if key not in PS_embs_dict:
        raise KeyError(f"No PS embedding found for '{key}'")
    ps = PS_embs_dict[key]

    #Sanity check dimensions (explicit raise: an assert would be skipped
    #when Python runs with -O)
    if esm.shape != ps.shape:
        raise ValueError(
            f"Shape mismatch for '{key}': ESM {tuple(esm.shape)} vs PS {tuple(ps.shape)}"
        )

    #Concatenate the embeddings and add to dict
    Xs = torch.cat((esm, ps), 1)
    cat_embs_dict[key] = Xs

print(f"Concatenated embeddings from {len(cat_embs_dict)} proteins")  

### Load sequences

In [None]:
#Get sequences from fasta file
#Maps each header (the text after ">") to its full sequence. Lines belonging
#to one record are joined, so multi-line sequences are not truncated to their
#last line, and empty lines no longer replace a sequence with "".
fastas = {}
header = None
chunks = []
with open("../4_FineTuning//antibody_bulk.fsa", "r") as fasta:
    for line in fasta:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            #Flush the record collected so far before starting the next one
            if header is not None and chunks:
                fastas[header] = "".join(chunks)
            header = line[1:]
            chunks = []
        elif header is None:
            raise ValueError("FASTA file contains sequence data before the first header")
        else:
            chunks.append(line)

#Flush the final record
if header is not None and chunks:
    fastas[header] = "".join(chunks)

### Prepare data for model

In [None]:
#Function that calculates amino acid distribution
def aa_dist(seq):
    """Compute the fraction of *seq* taken up by each of 20 residue letters.

    Returns a list of 20 values in the order
    A, R, N, D, B, C, Q, E, G, H, I, L, K, M, F, P, S, T, W, V: the letter's
    count divided by ``len(seq)`` when it is present, the integer ``0`` when
    it is absent (which also covers an empty *seq*).

    NOTE(review): "B" is part of the letter set while "Y" is not, so "Y" is
    never counted. Left as is because the 20-value layout is used by other
    cells of this notebook - confirm whether "Y" was meant instead of "B".
    """
    freqs = Counter(seq)
    fractions = []
    for residue in ("A", "R", "N", "D", "B", "C", "Q", "E", "G", "H",
                    "I", "L", "K", "M", "F", "P", "S", "T", "W", "V"):
        n_hits = freqs.get(residue)
        fractions.append(0 if n_hits is None else n_hits / len(seq))
    return fractions

In [None]:
#Ensure dimensions/info in embeddings
#embs_X:      one feature matrix per labelled protein (embedding columns
#             followed by 20 amino-acid fractions and the sequence length)
#data_labels: the class label for each entry of embs_X, in the same order
#data_names / data_nums: running id -> name and the ids as tensors, recorded
#             for every embedding, including those without a label
embs_X = []
data_names = {}
data_nums = []
data_labels = []
count = 0

#Build the set of labelled names once instead of a fresh list per iteration
labelled_names = set(exp["Name"])

for key, embs in cat_embs_dict.items():
    count += 1
    data_names[count] = key
    data_nums.append(torch.tensor(count))

    #Get proper labels: proteins without experimental data are skipped before
    #any feature work is done for them
    if key not in labelled_names:
        continue
    data_labels.append(label_dict[key])

    #Also, add in the extra info: the same row (amino-acid distribution plus
    #sequence length) is repeated once per row of the embedding
    sequence = fastas[key]
    extra_row = aa_dist(sequence) + [len(sequence)]
    extra_inf = torch.FloatTensor([extra_row] * len(embs))

    #Append all
    embs_X.append(torch.cat((embs, extra_inf), 1).numpy())

In [None]:
#Pad the embeddings
#Each matrix is copied into the top-left corner of a 500 x 81 block of zeros
#and then flattened, so every sample becomes a vector of 500 * 81 values
padded_embs = []
for emb in embs_X:
    n_rows, n_cols = np.shape(emb)[:2]
    canvas = np.zeros((500, 81))
    canvas[:n_rows, :n_cols] = emb
    padded_embs.append(canvas.flatten())

### Dataset, data split and DataLoader

In [None]:
#Train test split
#70 % of the samples go to training and 30 % to testing; the fixed
#random_state makes the split reproducible
X_train, X_test, y_train, y_test = train_test_split(
    padded_embs,
    data_labels,
    test_size=0.3,
    random_state=42,
)
#X_val, X_test, y_val, y_test = train_test_split(X_other, y_other, test_size=0.5, random_state=42)

#Report the split sizes and the class balance of each part
print(f"Train size: {len(X_train)}\nTest size: {len(X_test)}")
print(f"Train labels {Counter(y_train)}")
#print(f"Validation labels {Counter([label[0] for label in y_val])}")
print(f"Test labels {Counter(y_test)}")

### Train the RFC model

In [None]:
from sklearn.ensemble import RandomForestClassifier

#Train 2000 independently seeded random forests on the same split and record
#the test-set MCC and ROC AUC of each one
AUC = []
MCC = []
for i in range(2000):
    print(i, end="\r")
    clf = RandomForestClassifier(n_estimators=100)
    clf.fit(X_train, y_train)

    #MCC is computed from the hard class predictions
    y_pred = clf.predict(X_test)
    mcc_running = matthews_corrcoef(y_test, y_pred)

    #ROC AUC is a ranking metric and needs a continuous score: use the
    #predicted probability of class 1 rather than the thresholded 0/1 labels
    positive_col = list(clf.classes_).index(1)
    y_score = clf.predict_proba(X_test)[:, positive_col]
    auc_running = roc_auc_score(y_test, y_score)

    AUC.append(auc_running)
    MCC.append(mcc_running)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats as stats

#Make plot
plt.rcParams.update({'figure.figsize': [35, 20], 'font.size': 30})

#Both histograms use the same 50 bin edges between -0.4 and 0.8
#NOTE(review): scores outside [-0.4, 0.8] are not counted by these
#histograms - confirm that this range is wide enough for the AUC values
bin_edges = np.linspace(-0.4, 0.8, 50)

#Kernel density estimates of the two score distributions
densityx = stats.gaussian_kde(MCC)
densityy = stats.gaussian_kde(AUC)

#Step histograms, normalized to a density
n, x, _ = plt.hist(MCC, bins=bin_edges, histtype='step', density=True,
                   lw=5, label="MCC histogram", color="blue")
n, y, _ = plt.hist(AUC, bins=bin_edges, histtype='step', density=True,
                   lw=5, label="AUC histogram", color="red")

#Density curves evaluated at the bin edges
plt.plot(x, densityx(x), lw=5, label="MCC probability density function", c="cornflowerblue")
plt.plot(y, densityy(y), lw=5, label="AUC probability density function", c="coral")

#Axes decoration and output
plt.grid()
plt.xticks(np.arange(-0.4, 0.81, step=0.1))
plt.title("Performance of 2000 Random Forest Classifiers ", fontsize=40)
plt.legend(loc="upper left")
plt.savefig("RFCL_hist.png")

In [None]:
#Check if this outperforms LSTM model
#Share of the trained forests that reach MCC >= 0.4 and AUC >= 0.7
above_MCC = [score for score in MCC if score >= 0.4]
above_AUC = [score for score in AUC if score >= 0.7]

mcc_share = (len(above_MCC) / len(MCC)) * 100
auc_share = (len(above_AUC) / len(AUC)) * 100

print(f"MCC >= 0.4: {mcc_share}%")
print(f"AUC >= 0.7: {auc_share}%")