# Finetune a multitask model on concatenated embeddings

This notebook concatenates the ESM and PS embeddings and fine-tunes a model that was pre-trained with multitask learning, in order to learn solubility patterns.

### Import and initialize

In [None]:
#Import stuff
import os
import re
import sys
import time
import math
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from pathlib import Path
import seaborn as sn
import sklearn
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import MultiLabelBinarizer
from torch.nn.utils.rnn import pad_sequence
from collections import Counter
from scipy.stats import spearmanr
from scipy.stats import pearsonr
from scipy.stats import linregress
from sklearn.model_selection import GroupKFold
import pickle


### Define Paths

In [None]:
# Run everything on the CPU
device = "cpu"

# Folders with the PCA-reduced per-protein embeddings (.pt files)
ESM_EMB_PATH, PS_EMB_PATH = (
    "./PCA_reduced/ESM_embeddings/",
    "./PCA_reduced/PS_embeddings/",
)

## Load experimental data

In [None]:
# Experimental values (semicolon-separated); rows with any missing value are dropped
exp = pd.read_csv("jain_full.csv", sep=";").dropna()

In [None]:
# Display the experimental data after dropping rows with missing values
exp

## Load clusters

In [None]:
# Cluster assignments: column 0 = cluster representative, column 1 = member id.
# NOTE(review): the original comment said "min seq ID for clustering: 0.65", but the file
# name suggests an 80% threshold -- confirm which clustering this file is.
clusters = (
    pd.read_csv("./ABDB_clu_80.tsv", sep="\t", header=None)
    .rename(columns={0: 'rep', 1: 'id'})
)

In [None]:
# Display the cluster table (columns 'rep' and 'id')
clusters

In [None]:
# Number clusters 0, 1, 2, ... in order of first appearance of their representative,
# and map every member id to the number of its cluster.
cluster_temp_dict = {}   # representative -> cluster number
cluster_dict = {}        # member id -> cluster number
for rep, member in zip(clusters["rep"], clusters["id"]):
    cluster_dict[member] = cluster_temp_dict.setdefault(rep, len(cluster_temp_dict))
count = len(cluster_temp_dict)

print(f"Total amount of clusters: {count}")

In [None]:
# Look up the cluster number of every antibody and store it as a new column
# (raises KeyError if a name is missing from the cluster file)
clusters = [cluster_dict[antibody] for antibody in exp["Name"]]
exp["cluster"] = clusters

In [None]:
# Display exp with the newly appended 'cluster' column
exp

## Check label distribution

In [None]:
# Histogram of the raw AC-SINS values. This plot comes before normalisation, so the
# title no longer says "norm".
plt.rcParams['figure.figsize'] = [20, 10]
plt.hist(exp["AC-SINS"], bins=40)
plt.title("AC-SINS values distribution", fontsize=20)
plt.xticks(np.arange(-1, 31, step=1))
plt.show()

In [None]:
# Display exp again before the normalised and label columns are added
exp

In [None]:
# Min-max normalise AC-SINS and HIC to [0, 1]. This is vectorised because the previous
# list comprehensions recomputed min()/max() for every element (quadratic time).
ac_sins = exp["AC-SINS"]
norm_ac = (ac_sins - ac_sins.min()) / (ac_sins.max() - ac_sins.min())
exp["norm_AC-SINS"] = norm_ac

hic = exp["HIC"]
norm_hic = (hic - hic.min()) / (hic.max() - hic.min())
exp["norm_HIC"] = norm_hic

# Seeded uniform random "fake" labels, used as a sanity-check target
rng = np.random.default_rng(12345)
rand = rng.random(len(norm_ac))
exp["fake"] = rand

# Binary classification target: 1 if AC-SINS > AC_SINS_THRESHOLD, else 0
AC_SINS_THRESHOLD = 5
bc = [0 if val <= AC_SINS_THRESHOLD else 1 for val in exp["AC-SINS"]]
exp["BC"] = bc

In [None]:
# Display exp with the normalised, fake and BC columns
exp

In [None]:
# Map every antibody name to [binary label, name]
label_dict = {row["Name"]: [row["BC"], row["Name"]] for _, row in exp.iterrows()}

print(len(label_dict))

### Load ESM embeddings

In [None]:
# Load every ESM embedding (*.pt) into a dict keyed by protein name.
# The name is the text after the last "_" in the part of the file name before the first ".".
ESM_embs_dict = dict()
pt_files = (entry for entry in os.listdir(ESM_EMB_PATH) if entry.endswith(".pt"))
for file in pt_files:
    name = file.split(".")[0].split("_")[-1]
    print (f"working with file: {file}", end="\r")
    ESM_embs_dict[name] = torch.load(f'{ESM_EMB_PATH}/{file}')

### Load PS embeddings

In [None]:
# Load every PS embedding (*.pt) into a dict keyed by protein name,
# using the same file-name convention as the ESM embeddings.
PS_embs_dict = dict()
pt_files = (entry for entry in os.listdir(PS_EMB_PATH) if entry.endswith(".pt"))
for file in pt_files:
    name = file.split(".")[0].split("_")[-1]
    print (f"working with file: {file}", end="\r")
    PS_embs_dict[name] = torch.load(f'{PS_EMB_PATH}/{file}')

### Concatenate embeddings

In [None]:
# Concatenate the ESM (sequence) and PS (structure) embeddings of each protein along dim 1.
# Every ESM embedding must have a PS counterpart of identical shape. There is no
# zero-filled fallback (an earlier comment claimed one).
cat_embs_dict = dict()
count = 0

for key, esm in ESM_embs_dict.items():
    count += 1
    print(f"Working with {count}/{len(ESM_embs_dict)}", end = "\r")

    if key not in PS_embs_dict:
        raise KeyError(f"No PS embedding found for {key!r}")
    ps = PS_embs_dict[key]

    # Sanity check dimensions
    assert esm.shape == ps.shape, (
        f"Shape mismatch for {key!r}: ESM {tuple(esm.shape)} vs PS {tuple(ps.shape)}"
    )

    cat_embs_dict[key] = torch.cat((esm, ps), 1)

print(f"Concatenated embeddings from {len(cat_embs_dict)} proteins")

### Load sequences

In [None]:
# Read the sequences from the FASTA file into {header: sequence}; headers are the
# protein names used to look up sequences later.
def _read_fasta(path):
    """Return ``{header: sequence}`` for a FASTA file.

    Sequence lines of one record are joined, so multi-line FASTA records are kept
    whole. Previously only the last line was kept, and a blank line replaced the
    sequence with "". Blank lines are skipped, and headers without any sequence line
    are omitted. If a header repeats, the last record wins.
    """
    records = {}
    header = None
    chunks = []
    with open(path, "r") as handle:
        for raw in handle:
            line = raw.strip()
            if not line:
                continue
            if line.startswith(">"):
                if header is not None and chunks:
                    records[header] = "".join(chunks)
                header = line[1:]
                chunks = []
            else:
                chunks.append(line)
    if header is not None and chunks:
        records[header] = "".join(chunks)
    return records


fastas = _read_fasta("./antibody_bulk.fsa")

### Prepare data for model

In [None]:
# Function that calculates the amino acid distribution.
# NOTE(review): this alphabet lists "B" (Asx) but not "Y" (tyrosine), so tyrosine is never
# counted. It is kept as the default because the model's extra input features (these 20
# fractions + length) are presumably tied to this exact order. Confirm against the
# pre-training notebook before switching to "Y".
AA_ALPHABET = ("A","R","N","D","B","C","Q","E","G","H","I","L","K","M","F","P","S","T","W","V")


def aa_dist(seq, alphabet=AA_ALPHABET):
    """Return the fraction of ``seq`` made up by each letter of ``alphabet``.

    Args:
        seq: amino-acid sequence string.
        alphabet: letters to report, in output order (default ``AA_ALPHABET``).

    Returns:
        List of ``len(alphabet)`` fractions. Letters not in ``alphabet`` are not reported
        but still count toward the length. An empty sequence gives all zeros instead of
        raising ZeroDivisionError.
    """
    if not seq:
        return [0] * len(alphabet)
    counts = Counter(seq)
    return [counts[aa] / len(seq) for aa in alphabet]

In [None]:
# Build the model inputs. For every protein present in both the embeddings and the
# experimental data, append 21 extra features (20 amino-acid fractions + sequence
# length), repeated at every position, to its concatenated embedding.
embs_X = []
data_names = []
clusters = []
data_labels = []

# Set lookup instead of `key in list(exp["Name"])`, which was O(n) per protein
exp_names = set(exp["Name"])

for key, embs in cat_embs_dict.items():
    # Skip proteins without experimental data before doing any other work. Previously
    # their sequence was looked up first, which could raise KeyError for no reason.
    if key not in exp_names:
        continue
    row_num = exp.loc[exp['Name'] == key]

    seq = fastas[key]
    extra_inf = torch.tensor(aa_dist(seq) + [len(seq)], dtype=torch.float32).repeat(len(embs), 1)

    # Labels and names
    data_labels.append(label_dict[key][0])
    data_names.append(label_dict[key][1])

    # Features, plus the cluster as a pandas Series (converted to an int later)
    embs_X.append(torch.cat((embs, extra_inf), 1))
    clusters.append(row_num.cluster)

In [None]:
# Quick sanity check: shape of one sample and the sizes of the aligned lists
for summary in (embs_X[50].shape, len(embs_X), len(data_labels), len(clusters)):
    print(summary)

### Dataset, data split and DataLoader

In [None]:
# Define my own group-K-fold splitter
def groupkfold(data_X, data_y, data_cluster, n=10):
    """Split sample indices into cluster-disjoint folds for cross-validation.

    Samples are grouped by cluster. The clusters are dealt, largest first, into
    ``n + 1`` partitions of roughly equal size. The partition at position
    ``round(n / 2)`` is held out as a fixed test set. Each of the remaining ``n``
    partitions serves once as the validation set while the others form the training
    set, so all members of a cluster always land in the same partition.

    Args:
        data_X, data_y: unused; kept for call compatibility.
        data_cluster: cluster id of every sample; returned indices refer to it.
        n: number of folds (``n + 1`` partitions are built).

    Returns:
        List of ``n`` ``[train_idx, val_idx, test_idx]`` lists of sample indices.
        ``test_idx`` is the same partition in every fold.

    Raises:
        ValueError: if ``data_cluster`` is empty, the partitions do not cover every
            sample, differ in size by more than 10 samples, or are not exactly
            ``n + 1`` in number. Previously these cases printed a message and called
            ``sys.exit(1)``, which is the wrong way to fail inside a notebook.
    """
    if len(data_cluster) == 0:
        raise ValueError("data_cluster is empty")

    total_size = len(data_cluster)
    part_size = total_size / (n + 1)

    # Indices of the samples in each cluster, in a single pass. Clusters are numbered
    # 1.. in set-iteration order, and that order breaks ties between equally sized
    # clusters below.
    members = {}
    for i, x in enumerate(data_cluster):
        members.setdefault(x, []).append(i)
    cluster_dict = {count: members[cluster]
                    for count, cluster in enumerate(set(data_cluster), start=1)}

    # Largest clusters first (sorted() is stable, so ties keep the order above)
    s_clust = sorted(cluster_dict.items(), key=lambda x: len(x[1]), reverse=True)

    # Deal clusters round-robin over partitions 1..n+1. A partition that grows past
    # part_size is put on `skips` and receives no further clusters.
    partitions = {}
    counter = 0
    skips = []
    for _, part in s_clust:
        counter += 1
        if counter % (n + 2) == 0:
            counter = 1
        while counter in skips:
            counter += 1

        if counter in partitions:
            partitions[counter] += part
            if len(partitions[counter]) > part_size:
                skips.append(counter)
        else:
            partitions[counter] = part

    # Double check that all is good
    sizes = [len(part) for part in partitions.values()]
    if sum(sizes) != total_size:
        raise ValueError(f"Not all data is included: {sum(sizes)} of {total_size} samples")
    if max(sizes) - min(sizes) > 10:
        raise ValueError(f"Partition sizes differ too much. Max: {max(sizes)}, Min: {min(sizes)}")
    if len(partitions) != n + 1:
        raise ValueError(f"Built {len(partitions)} partitions, but n + 1 = {n + 1} are needed")

    # Hold out the middle partition as the test set; the rest rotate as validation
    partitions = sorted(partitions.items(), key=lambda x: x[0])
    test_idx = partitions.pop(round(n / 2))[1]

    folds = []
    for p, (_, val_idx) in enumerate(partitions):
        train_idx = [i for q, (_, part) in enumerate(partitions) if q != p for i in part]
        folds.append([train_idx, val_idx, test_idx])
    return folds
            

In [None]:
# Dataset wrapper
class ProteinDataset(Dataset):
    """Map-style dataset pairing per-protein feature tensors with their labels."""

    def __init__(self, X, Y):
        # X: sequence of feature tensors; Y: matching sequence of labels
        self.X, self.y = X, Y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        features = self.X[idx]
        target = torch.tensor(self.y[idx])
        return features, target

In [None]:
# Collate function that pads variable-length sequences
def pad_collate(batch):
    """Collate ``(features, label)`` pairs into one zero-padded batch.

    Returns the features stacked along a new leading batch dimension, padded with
    zeros to the longest sequence in the batch, and the labels as a tuple.
    """
    sequences, targets = zip(*batch)
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    return padded, targets

In [None]:
# Checkpoint writer
def save_model(filepath, epoch, model, train_loss_values, train_r, train_p, train_AUC, train_MCC, train_labels,train_pred,
               val_loss_values, val_r, val_p, val_AUC, val_MCC, val_labels,val_pred,
               test_loss_values, test_r, test_p, test_AUC, test_MCC, test_labels,test_pred):
    """Write a checkpoint of ``model`` to the folder ``filepath`` in three formats.

    The folder is created if it does not exist. Files written:
      * ``model_conv.state_dict``: ``model.state_dict()`` only
        (restore with ``model.load_state_dict(torch.load(path))``).
      * ``model_conv.state``: a dict with the epoch, the model and optimizer state
        dicts, and every metric/label/prediction history passed in.
      * ``model_conv.full``: the whole model object (``torch.save(model, ...)``).

    NOTE: the optimizer state comes from the notebook-level ``optimizer`` global,
    not from an argument.
    """
    if not os.path.exists(filepath):
        os.makedirs(filepath)

    weights = model.state_dict()

    # 1) weights only
    torch.save(weights, filepath+"/model_conv.state_dict")

    # 2) weights + optimizer + all histories (keys in the original order)
    losses = {'train_loss': train_loss_values, 'val_loss': val_loss_values, 'test_loss': test_loss_values}
    r_and_p = {'train_r': train_r, 'train_p': train_p, 'val_r': val_r, 'val_p': val_p,
               'test_r': test_r, 'test_p': test_p}
    aucs = {'train_AUC': train_AUC, 'val_AUC': val_AUC, 'test_AUC': test_AUC}
    mccs = {'train_MCC': train_MCC, 'val_MCC': val_MCC, 'test_MCC': test_MCC}
    labels = {'train_labels': train_labels, 'val_labels': val_labels, 'test_labels': test_labels}
    preds = {'train_pred': train_pred, 'val_pred': val_pred, 'test_pred': test_pred}
    state = {'epoch': epoch, 'state_dict': weights, 'optimizer': optimizer.state_dict(),
             **losses, **r_and_p, **aucs, **mccs, **labels, **preds}
    torch.save(state, filepath+"/model_conv.state")
    # restore with: model.load_state_dict(state['state_dict'])
    #               optimizer.load_state_dict(state['optimizer'])

    # 3) the whole model object (restore with torch.load(path))
    torch.save(model, filepath+"/model_conv.full")

## Cross-validation (start from here when re-running)

In [None]:
# Use all prepared samples for cross-validation
test_embs = embs_X
test_label = data_labels

# Each entry of `clusters` is a pandas Series of matching exp rows; take its first cluster id
test_clusters = []
for cluster_series in clusters:
    test_clusters.append(cluster_series.tolist()[0])

folds = groupkfold(test_embs, test_label, test_clusters, n=10)
foldperf = {}

In [None]:
# Validate/test model
def test_model(model, optimizer, data_loader, loss_fn):
    "Run model in evaluation mode on a dataset"
    val_running_loss = 0.0
    val_pred = []
    val_labels = []
    model.eval()
    with torch.no_grad():
        for i, (embs,labels) in enumerate(data_loader):
            labels = torch.tensor(labels)
            labels = labels.float()
            optimizer.zero_grad()
            y_pred = model(embs)
            y_pred = y_pred.squeeze()
            loss = loss_fn(y_pred, labels)
            val_running_loss += loss.item() * embs.size(0)
            acsins_pred = [pred.item() for pred in y_pred]
            val_pred.append(acsins_pred)
            val_labels.append(labels)
    model.train() 
    return val_pred, val_labels, val_running_loss

### Define model

In [None]:
# Hyper parameters

# Architecture
input_size = 60            # per-residue embedding features fed to the LSTM
hidden_size = 64
num_layers = 3
num_classes_nesg = 6       # 7
num_classes_psibio = 2     # 3

# Optimisation
batch_size = 17
n_epochs = 100
lr = 0.0001                # 0.0001
dropout = 0.4
weight_decay = 1e-6

In [None]:
#Define Bi_LSTM model
class Bi_LSTM(nn.Module):
    """Bidirectional LSTM over per-residue embeddings with a small dense head.

    ``forward`` expects a padded batch ``x`` of shape
    ``(batch, seq_len, input_size + n_extra)``. The first ``input_size`` features
    are the embeddings fed to the LSTM. The last ``n_extra`` are per-sequence extra
    features, which are averaged over positions. The final hidden states of the last
    LSTM layer (both directions) are concatenated with those averaged features and
    mapped to one sigmoid probability per sample, shape ``(batch, 1)``.

    ``relu``, ``linear1``, ``linear2`` and ``sigmoid`` are not used in ``forward``.
    They are kept so the state-dict keys (and checkpoints loaded with
    ``strict=False``) stay unchanged. ``dropout`` is stored but not applied.
    """

    def __init__(self, input_size = input_size, hidden_size = hidden_size, num_layers = num_layers, num_classes_nesg = num_classes_nesg, num_classes_psibio = num_classes_psibio, dropout = dropout, n_extra = 21):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes_nesg = num_classes_nesg
        self.num_classes_psibio = num_classes_psibio
        self.dropout = dropout
        # Number of trailing extra features per position (previously hard-coded as 21)
        self.n_extra = n_extra

        head_in = (hidden_size * 2) + n_extra

        # Registration order matches the original, so parameter order is unchanged
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional = True, batch_first=True)
        self.relu = nn.ReLU()
        self.linear1 = nn.Linear(head_in, 1)
        self.linear2 = nn.Linear(head_in, num_classes_psibio)
        self.sigmoid = nn.Sigmoid()

        # Fine-tuning head
        self.last1 = nn.Linear(head_in, 32)
        self.last2 = nn.Linear(32, 1)
        self.last_activation = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid probabilities of shape ``(batch, 1)`` for input ``x``."""
        # Split the embeddings from the extra features repeated at every position
        embs, extra = torch.split(x, [self.input_size, self.n_extra], dim=2)
        # Average over positions (zero padding included). No squeeze(): it dropped the
        # batch dimension for a batch of one and made the concatenation below fail.
        extra = extra.mean(1)

        # Batch normalisation with batch statistics and identity affine. This matches the
        # previous code, which built a fresh nn.BatchNorm1d(num_features=seq_len) on every
        # call (weights 1/0, always in training mode), but without registering a throw-away
        # submodule whose size depended on the batch's sequence length.
        # NOTE(review): the normalised "channels" are sequence positions (dim 1), not
        # embedding features, and eval predictions depend on batch composition. Kept
        # as-is to match how the pre-trained weights were used.
        norm_data = F.batch_norm(embs, None, None, training=True)

        _, (ht, _) = self.lstm(norm_data)
        # Final hidden states of the last layer; per the nn.LSTM docs, ht[-1] is the
        # backward direction and ht[-2] the forward one.
        lstm_ht = torch.cat([ht[-1, :, :], ht[-2, :, :]], dim=1)

        collect = torch.cat((lstm_ht, extra), dim=1)
        return self.last_activation(self.last2(self.last1(collect)))

# Instantiate the model from the hyper-parameters above
model = Bi_LSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_classes_nesg=num_classes_nesg,
    num_classes_psibio=num_classes_psibio,
    dropout=dropout,
)

In [None]:
# Restore the pre-trained weights. With strict=False, keys that exist on only one side are ignored.
state = torch.load("../3_PreTraining/model/model_8/model_conv.state")
pretrained_weights = torch.load("../3_PreTraining/model/model_8/model_conv.state_dict")
model.load_state_dict(pretrained_weights, strict=False)

# Binary cross-entropy on the sigmoid output, Adam with L2 weight decay
loss_fn = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay = weight_decay)

## Manually check that the model is ready to learn

In [None]:
# Print one LSTM weight matrix from the pre-trained checkpoint for a manual comparison
print(state["state_dict"]['lstm.weight_hh_l0'])

In [None]:
# Count how many parameter tensors will receive gradients
all_params = list(model.parameters())
full = len(all_params)
catch = sum(1 for p in all_params if p.requires_grad)
print(f"{catch}/{full} parameters have gradients")

In [None]:
#Train model function
def train_model(model, loss_func1, optimizer, n_epochs):
    """Fine-tune ``model`` with early stopping and collect per-epoch metrics.

    Relies on notebook-level globals set by the cross-validation loop:
    ``train_loader``, ``val_loader``, ``test_loader``, ``train_X``, ``val_X``,
    ``test_X`` and ``fold``. After every epoch a checkpoint is written with
    ``save_model`` to ``./models/fold{fold}/finetune_{epoch}_epochs``.

    Training stops after ``patience`` consecutive epochs without a new lowest
    validation loss, after 20 non-improving epochs in total, or after ``n_epochs``.

    Returns:
        A 37-tuple:
          * 14 lists of weight snapshots ``w0``..``w13``, taken at the start of each epoch;
          * the last epoch;
          * the model;
          * for train, val and test in turn: loss history, accuracy history (``*_r``),
            linregress p-value history (``*_p``), AUC history, MCC history, and the
            last epoch's labels and predictions.
    """
    # One decision threshold on the sigmoid output for all splits. Previously the test
    # split used round() (i.e. 0.5), so its MCC/accuracy were not comparable to train/val.
    threshold = 0.35

    # Early stopping
    patience = 10  # pre-training used a patience of 3
    max_total_triggers = 20
    triggers = 0
    total_triggers = 0
    min_val_loss = math.inf

    # Per-epoch histories
    train_loss_values, val_loss_values, test_loss_values = [], [], []
    train_AUC, val_AUC, test_AUC = [], [], []
    train_MCC, val_MCC, test_MCC = [], [], []
    # "_r" holds accuracy and "_p" the linregress p-value (names kept for compatibility)
    train_r, val_r, test_r = [], [], []
    train_p, val_p, test_p = [], [], []

    # Weights snapshotted each epoch to verify they are updated; returned as w0..w13
    tracked_weights = [
        "last1.weight", "last2.weight",
        "lstm.weight_ih_l0", "lstm.weight_hh_l0", "lstm.weight_ih_l0_reverse", "lstm.weight_hh_l0_reverse",
        "lstm.weight_ih_l1", "lstm.weight_hh_l1", "lstm.weight_ih_l1_reverse", "lstm.weight_hh_l1_reverse",
        "lstm.weight_ih_l2", "lstm.weight_hh_l2", "lstm.weight_ih_l2_reverse", "lstm.weight_hh_l2_reverse",
    ]
    weight_history = [[] for _ in tracked_weights]

    def flatten(batches):
        """Concatenate per-batch sequences into one flat list."""
        return [item for batch in batches for item in batch]

    def to_labels(labels):
        """Convert a batch of labels (e.g. a tuple of 0-d tensors) to a float tensor."""
        return torch.as_tensor([float(label) for label in labels])

    def results():
        """Assemble the 37-element tuple returned to the caller."""
        return (*weight_history, epoch, model,
                train_loss_values, train_r, train_p, train_AUC, train_MCC, train_labels, train_pred,
                val_loss_values, val_r, val_p, val_AUC, val_MCC, val_labels, val_pred,
                test_loss_values, test_r, test_p, test_AUC, test_MCC, test_labels, test_pred)

    for epoch in range(1, n_epochs + 1):
        train_running_loss = 0.0
        train_pred = []
        train_labels = []

        params = model.state_dict()
        for history, name in zip(weight_history, tracked_weights):
            history.append(params[name].clone())
        print(f"Epoch: {epoch}", end="\n")

        for embs, labels in train_loader:
            labels = to_labels(labels)
            optimizer.zero_grad()
            # reshape(-1) keeps a 1-d prediction vector even for a single-sample batch
            y_pred = model(embs).reshape(-1)
            loss = loss_func1(y_pred, labels)
            loss.backward()
            optimizer.step()
            train_running_loss += loss.item() * embs.size(0)
            train_pred.append(y_pred.tolist())
            train_labels.append(labels)

        # Evaluate with the same loss as training (previously the global `loss_fn`)
        val_pred, val_labels, val_running_loss = test_model(model, optimizer, val_loader, loss_func1)
        test_pred, test_labels, test_running_loss = test_model(model, optimizer, test_loader, loss_func1)

        # Mean per-sample loss of each split
        train_loss_values.append(train_running_loss / len(train_X))
        val_loss_values.append(val_running_loss / len(val_X))
        test_loss_values.append(test_running_loss / len(test_X))

        train_pred, train_labels = flatten(train_pred), flatten(train_labels)
        val_pred, val_labels = flatten(val_pred), flatten(val_labels)
        test_pred, test_labels = flatten(test_pred), flatten(test_labels)

        splits = (
            (train_labels, train_pred, train_AUC, train_MCC, train_r, train_p),
            (val_labels, val_pred, val_AUC, val_MCC, val_r, val_p),
            (test_labels, test_pred, test_AUC, test_MCC, test_r, test_p),
        )
        for split_labels, split_pred, auc_hist, mcc_hist, acc_hist, p_hist in splits:
            pred_round = [0 if x <= threshold else 1 for x in split_pred]
            auc_hist.append(roc_auc_score(split_labels, split_pred))
            mcc_hist.append(matthews_corrcoef(split_labels, pred_round))
            acc_hist.append(accuracy_score(split_labels, pred_round))
            p_hist.append(linregress(split_labels, split_pred).pvalue)

        # Checkpoint every epoch
        filepath = f"./models/fold{fold}/finetune_{epoch}_epochs"
        save_model(filepath,  epoch, model, train_loss_values, train_r, train_p, train_AUC, train_MCC, train_labels,train_pred, val_loss_values, val_r, val_p, val_AUC, val_MCC, val_labels,val_pred,test_loss_values, test_r, test_p, test_AUC, test_MCC, test_labels,test_pred)
        print(f"triggers: {triggers}")

        # Early stopping on validation loss
        current_loss = val_loss_values[-1]
        if current_loss < min_val_loss:
            min_val_loss = current_loss
            triggers = 0
        else:
            triggers += 1
            total_triggers += 1

        if triggers == patience or total_triggers == max_total_triggers:
            return results()

    # Model ran through all epochs
    return results()

In [None]:
# Group-K-fold cross-validation over `folds` (n=10 above gives 10 folds)
test_parts = []
for fold, (train_idx, val_idx, test_idx) in enumerate(folds, start=1):

    start = time.time()

    # Models folder + fold header in the log
    os.makedirs('./models', exist_ok=True)
    with open("./models/logfile.txt", 'a+') as f:
        f.write(f'####  Fold {fold}  ####\n')
    print(f'####  Fold {fold}  ####')
    print(f"Train: {len(train_idx)}")
    print(f"val: {len(val_idx)}")
    print(f"Test: {len(test_idx)}")

    # Select this fold's samples
    train_X = [test_embs[i] for i in train_idx]
    val_X = [test_embs[i] for i in val_idx]
    test_X = [test_embs[i] for i in test_idx]
    train_y = [test_label[i] for i in train_idx]
    val_y = [test_label[i] for i in val_idx]
    test_y = [test_label[i] for i in test_idx]
    train_names = [data_names[i] for i in train_idx]
    val_names = [data_names[i] for i in val_idx]
    test_names = [data_names[i] for i in test_idx]
    # groupkfold returns the same test partition for every fold, so this repeats it
    test_parts += test_idx

    # Datasets
    train = ProteinDataset(train_X, train_y)
    val = ProteinDataset(val_X, val_y)
    test = ProteinDataset(test_X, test_y)

    # Data loaders. Only training data is shuffled: the model normalises with batch
    # statistics even in eval mode, so shuffled eval batches made val/test predictions
    # vary from run to run.
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, collate_fn=pad_collate, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size, collate_fn=pad_collate, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, collate_fn=pad_collate, shuffle=False)

    # Fresh model per fold, initialised from the pre-trained weights
    model = Bi_LSTM(input_size = input_size, hidden_size = hidden_size, num_layers = num_layers, num_classes_nesg = num_classes_nesg, num_classes_psibio = num_classes_psibio, dropout = dropout)
    model.load_state_dict(torch.load("../3_PreTraining/model/model_8/model_conv.state_dict"), strict=False)
    loss_fn = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay = weight_decay)

    # Train the model
    w0,w1,w2,w3,w4,w5,w6,w7,w8,w9,w10,w11,w12,w13, epoch, model, train_loss_values, train_r, train_p, train_AUC, train_MCC, train_labels,train_pred, val_loss_values, val_r, val_p, val_AUC, val_MCC, val_labels,val_pred, test_loss_values, test_r, test_p, test_AUC, test_MCC, test_labels,test_pred = train_model(model, loss_fn, optimizer, n_epochs)

    # Save in a dictionary
    history = {'train_loss': train_loss_values, 'val_loss': val_loss_values,'test_loss': test_loss_values,
               'train_AUC':train_AUC,'val_AUC':val_AUC,'test_AUC':test_AUC,
               'train_MCC':train_MCC, 'val_MCC':val_MCC, 'test_MCC':test_MCC,
               'train_labels':train_labels, 'val_labels': val_labels, 'test_labels': test_labels,
               'train_r': train_r, 'train_p': train_p, 'val_r': val_r, 'val_p': val_p,'test_r': test_r, 'test_p': test_p,
               'train_pred':train_pred, 'val_pred':val_pred, 'test_pred':test_pred,
               'epochs': epoch, 'train_names':train_names, 'val_names':val_names, 'test_names':test_names}
    foldperf[f'fold{fold}'] = history
    end = time.time()

    # Logfile summary of the fold
    with open("./models/logfile.txt", 'a+') as f:
        f.write(f"---------------- Finished fold {fold} ----------------\n")
        f.write(f"Epoch: {epoch} \nTrain loss: {train_loss_values[-1]} \t Validation loss: {val_loss_values[-1]} \t Test loss: {test_loss_values[-1]}\n")
        f.write(f"Train AUC: {train_AUC[-1]} \t Validation AUC: {val_AUC[-1]} \t Test AUC: {test_AUC[-1]}\n")
        f.write(f"Train MCC: {train_MCC[-1]} \t Validation MCC: { val_MCC[-1]} \t Test MCC: {test_MCC[-1]}\n")
        f.write(f"Train Accurcay: {train_r[-1]} \t Validation Accurcay: { val_r[-1]} \t Test Accurcay: { test_r[-1]}\n")
        f.write("Elapsed time: {:.2f} min\n".format((end - start)/60))
        f.write("\n")

# Final save; the context manager closes the file even if pickling fails
with open("./models/foldperf.pkl", "wb") as a_file:
    pickle.dump(foldperf, a_file)
print(f"For testing {len(set(test_parts))} unique protiens were used. A total of {len(test_parts)} proteins")


In [None]:
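#To reload saved fold results instead of retraining, uncomment the lines below.
#NOTE: the training loop above writes ./models/foldperf.pkl, not ./models_clustering3/.
#import  pickle
#pickleFile = open("./models_clustering3/foldperf.pkl", 'rb')
#foldperf = pickle.load(pickleFile)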

In [None]:

##################
#### Make plot ###
##################
#Average output over epochs
import warnings
tlf, vlf, ttlf, tpcc, vpcc, ttpcc, tscc, vscc, ttscc, tr, vr, ttr = [],[],[],[],[],[],[],[],[],[],[],[]
max_epoch = 0
min_epoch = n_epochs
for n in range(n_epochs):
    tlfe, vlfe, ttlfe, tpcce, vpcce, ttpcce, tscce, vscce, ttscce, tre, vre, ttre = [],[],[],[],[],[],[],[],[],[],[],[]
    for f in range(1,11):
        try:
            max_epoch = (max(max_epoch,foldperf[f'fold{f}']['epochs']))
            min_epoch = (min(min_epoch,foldperf[f'fold{f}']['epochs']))
            tlfe.append(foldperf[f'fold{f}']['train_loss'][n])
            vlfe.append(foldperf[f'fold{f}']['val_loss'][n])
            ttlfe.append(foldperf[f'fold{f}']['test_loss'][n])
            tpcce.append(foldperf[f'fold{f}']['train_AUC'][n])
            vpcce.append(foldperf[f'fold{f}']['val_AUC'][n])
            ttpcce.append(foldperf[f'fold{f}']['test_AUC'][n])
            tscce.append(foldperf[f'fold{f}']['train_MCC'][n])
            vscce.append(foldperf[f'fold{f}']['val_MCC'][n])
            ttscce.append(foldperf[f'fold{f}']['test_MCC'][n])
            tre.append((foldperf[f'fold{f}']['train_r'][n]))
            vre.append((foldperf[f'fold{f}']['val_r'][n]))
            ttre.append((foldperf[f'fold{f}']['test_r'][n]))
            
            
        except IndexError as error:
            tlfe.append(np.nan)
            vlfe.append(np.nan)
            ttlfe.append(np.nan)
            tpcce.append(np.nan)
            vpcce.append(np.nan)
            ttpcce.append(np.nan)
            tscce.append(np.nan)
            vscce.append(np.nan)
            ttscce.append(np.nan)
            tre.append(np.nan)
            vre.append(np.nan)
            ttre.append(np.nan)
            foldperf[f'fold{f}']['train_loss'] = foldperf[f'fold{f}']['train_loss']+[np.nan]
            foldperf[f'fold{f}']['val_loss'] = foldperf[f'fold{f}']['val_loss']+[np.nan]
            foldperf[f'fold{f}']['test_loss'] = foldperf[f'fold{f}']['test_loss']+[np.nan]
            foldperf[f'fold{f}']['train_AUC'] = foldperf[f'fold{f}']['train_AUC']+[np.nan]
            foldperf[f'fold{f}']['val_AUC'] = foldperf[f'fold{f}']['val_AUC'] +[np.nan]
            foldperf[f'fold{f}']['test_AUC'] = foldperf[f'fold{f}']['test_AUC']+[np.nan]
            foldperf[f'fold{f}']['train_MCC'] = foldperf[f'fold{f}']['train_MCC']+[np.nan]
            foldperf[f'fold{f}']['val_MCC']  = foldperf[f'fold{f}']['val_MCC'] +[np.nan]
            foldperf[f'fold{f}']['test_MCC'] = foldperf[f'fold{f}']['test_MCC']+[np.nan]
            foldperf[f'fold{f}']['train_r']  = foldperf[f'fold{f}']['train_r'] +[np.nan]
            foldperf[f'fold{f}']['val_r']  = foldperf[f'fold{f}']['val_r'] +[np.nan]
            foldperf[f'fold{f}']['test_r']  = foldperf[f'fold{f}']['test_r'] +[np.nan]
            
    # Catch warnings for making means with nan        
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        tlf.append(np.nanmean(tlfe))
        vlf.append(np.nanmean(vlfe))
        ttlf.append(np.nanmean(ttlfe))
        tpcc.append(np.nanmean(tpcce))
        vpcc.append(np.nanmean(vpcce))
        ttpcc.append(np.nanmean(ttpcce))
        tscc.append(np.nanmean(tscce))
        vscc.append(np.nanmean(vscce))
        ttscc.append(np.nanmean(ttscce))
        tr.append(np.nanmean(tre))
        vr.append(np.nanmean(vre))
        ttr.append(np.nanmean(ttre))

# Keep only the epochs that are plotted (1..max_epoch) for every fold-averaged curve.
(tlf, vlf, ttlf,
 tpcc, vpcc, ttpcc,
 tscc, vscc, ttscc,
 tr, vr, ttr) = [curve[:max_epoch] for curve in (tlf, vlf, ttlf,
                                                  tpcc, vpcc, ttpcc,
                                                  tscc, vscc, ttscc,
                                                  tr, vr, ttr)]

# Apply the same trimming to each fold's per-epoch curves ('<split>_<metric>' keys).
for f in range(1, 11):
    fold_curves = foldperf[f'fold{f}']
    for metric_suffix in ('loss', 'AUC', 'MCC', 'r'):
        for data_split in ('train', 'val', 'test'):
            curve_key = f'{data_split}_{metric_suffix}'
            fold_curves[curve_key] = fold_curves[curve_key][:max_epoch]

    
#Make pretty plot
plt.rcParams['figure.figsize'] = [25, 25]
plt.rcParams['font.size'] = 20

# Number of cross-validation folds stored in `foldperf` under 'fold1' ... 'fold<N_FOLDS>'.
N_FOLDS = 10

# Colour scheme: each data split gets a saturated colour for its fold-averaged curve
# and a light tint of the same hue for the individual fold curves, so every average
# can be matched to its folds and no two legend entries share a colour.
AVERAGE_COLOURS = {'train': 'red', 'val': 'blue', 'test': 'purple'}
FOLD_COLOURS = {'train': 'palevioletred', 'val': 'cornflowerblue', 'test': 'plum'}


def plot_metric_panel(ax, x, averages, fold_results, metric_key, metric_name,
                      ylabel, title, legend_loc, stop_epoch, n_folds=N_FOLDS):
    """Draw the fold-averaged and per-fold learning curves of one metric on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        x: Epoch numbers (1-based), one per plotted point.
        averages: Dict with keys 'val', 'train' and 'test' holding the fold-averaged
            per-epoch values of the metric.
        fold_results: ``foldperf``-style dict where
            ``fold_results[f'fold{k}'][f'{split}_{metric_key}']`` holds the per-epoch
            values of fold ``k`` for split 'train', 'val' or 'test'.
        metric_key: Suffix of the per-fold keys ('loss', 'AUC', 'MCC' or 'r').
        metric_name: Metric name used in the averaged-curve legend labels.
        ylabel: Y-axis label.
        title: Panel title.
        legend_loc: Matplotlib legend location string.
        stop_epoch: Epoch drawn as a dashed vertical line (first early stop of a fold).
        n_folds: Number of folds to draw (fold1 .. fold<n_folds>).
    """
    # Fold-averaged curves (thick lines); legend order: validation, training, testing.
    for split, split_name in (('val', 'Validation'), ('train', 'Training'), ('test', 'Testing')):
        ax.plot(x, averages[split], label=f"Average {metric_name} of {split_name} data",
                c=AVERAGE_COLOURS[split], lw=5)
    ax.axvline(stop_epoch, ls='--', c="grey", label="First early stopping of a fold", lw=3)

    # Individual folds (thin, translucent). Only the first fold of each split is
    # labelled, so the legend gets exactly one entry per split.
    for split, split_name in (('train', 'Training'), ('val', 'Validation'), ('test', 'Test')):
        for fold in range(1, n_folds + 1):
            legend_kw = {'label': f"Fold 1-{n_folds} {split_name} data"} if fold == 1 else {}
            ax.plot(x, fold_results[f'fold{fold}'][f'{split}_{metric_key}'],
                    c=FOLD_COLOURS[split], lw=3, alpha=0.2, **legend_kw)

    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc=legend_loc)
    ax.grid(True)


#Initialize plot
fig, ((ax1, ax3), (ax2, ax4)) = plt.subplots(2, 2)
fig.patch.set_facecolor('#FAFAFA')
fig.patch.set_alpha(0.7)
x = list(range(1, max_epoch + 1))

# (axes, averaged curves, foldperf key suffix, legend metric name, y-label, title, legend location)
# The 'r' curves (tr/vr/ttr, '<split>_r') are the accuracy curves.
panels = [
    (ax1, {'val': vlf, 'train': tlf, 'test': ttlf}, 'loss', 'loss', 'Loss', 'Loss curve', 'upper center'),
    (ax2, {'val': vpcc, 'train': tpcc, 'test': ttpcc}, 'AUC', 'AUC', 'AUC', 'AUC', 'lower right'),
    (ax3, {'val': vscc, 'train': tscc, 'test': ttscc}, 'MCC', 'MCC', 'MCC', 'MCC', 'lower right'),
    (ax4, {'val': vr, 'train': tr, 'test': ttr}, 'r', 'Accuracy', 'Accuracy', 'Accuracy', 'lower right'),
]
for ax, averages, metric_key, metric_name, ylabel, title, legend_loc in panels:
    plot_metric_panel(ax, x, averages, foldperf, metric_key, metric_name,
                      ylabel, title, legend_loc, stop_epoch=min_epoch)

fig.tight_layout(pad=1)
# savefig does not create missing directories, so make sure the output folder exists.
Path('./models').mkdir(parents=True, exist_ok=True)
fig.savefig('./models/Loss_AUC_MCC_pretty.png', facecolor=fig.get_facecolor(), edgecolor='none')

In [None]:
# Report the last entry of each fold-averaged test curve (final plotted epoch).
for report_line in (
    f"Test Loss: {ttlf[-1]}",
    f"Test AUC: {ttpcc[-1]}",
    f"Test MCC: {ttscc[-1]}",
    f"Test Accuracy: {ttr[-1]}",
):
    print(report_line)

## Check if the weights are updated throughout training

In [None]:
# Gather the weight histories w0 .. w13 (filled in an earlier cell) so that all_w[k] is wk.
all_w = [
    w0, w1, w2, w3, w4, w5, w6,
    w7, w8, w9, w10, w11, w12, w13,
]

In [None]:
#Compare weights between first and last epoch
# Each entry of all_w is indexed as history[0] (first snapshot) and history[-1]
# (last snapshot); assumes these are torch tensors of equal shape — TODO confirm
# against the training cell that records w0..w13.
total = 0    # number of weight elements inspected
static = 0   # number of elements exactly equal in the first and last snapshot
cause = []   # names (w0..w13) of the histories containing at least one static element
for layer_idx, history in enumerate(all_w):
    first, last = history[0], history[-1]
    # Element-wise exact equality summed over all dimensions, so it works for tensors
    # of any rank (e.g. 1-D biases); .item() gives a plain int for printing.
    n_static = int(torch.eq(first, last).sum().item())
    static += n_static
    # Count every element, matching how static elements are counted.
    total += first.numel()
    if n_static > 0:
        # enumerate starts at 0, so the reported name matches the variable (all_w[0] is w0).
        cause.append(f"w{layer_idx}")


print(f"{static} static weights out of {total} total weights")
print(f"Causes: {cause}")