   # Group 10 Traditional Machine Learning Notebook
Within each cycle of active learning, you can:

1. Collect training data (original training data + your query data).

2. Train a prediction model to predict the DMS_score for each mutant (e.g., M0A).

3. Use the trained model to predict the score for all mutant in the test set.

4. Select query mutants for next round based on certain criteria. You may want to make sure you don't query the same mutant twice as you only have a limited chances of making queries in total.

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, Dataset
import random
from copy import deepcopy
import pandas as pd
from scipy.stats import spearmanr
from scipy.stats import rankdata
import argparse
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import GradientBoostingRegressor
import time

## 1. collect training data

Upload `sequence.fasta`, `train.csv`, and `test.csv` to the current runtime:

1. click the folder icon on the left

2. click the upload icon and upload the files to the current directory

In [3]:
with open('sequence.fasta', 'r') as f:
  data = f.readlines()

sequence_wt = data[1].strip()
sequence_wt

'MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLREKMRRRLESGDKWFSLEFFPPRTAEGAVNLISRFDRMAAGGPLYIDVTWHPAGDPGSDKETSSMMIASTAVNYCGLETILHMTCCRQRLEEITGHLHKAKQLGLKNIMALRGDPIGDQWEEEEGGFNYAVDLVKHIRSEFGDYFDICVAGYPKGHPEAGSFEADLKHLKEKVSAGADFIITQLFFEADTFFRFVKACTDMGITCPIVPGIFPIQGYHSLRQLVKLSKLEVPQEIKDVIEPIKDNDAAIRNYGIELAVSLCQELLASGLVPGLHFYTLNREMATTEVLKRLGMWTEDPRRPLPWALSAHPKRREEDVRPIFWASRPKSYIYRTQEWDEFPNGRWGNSSSPAFGELKDYYLFYLKSKSPKEELLKMWGEELTSEESVFEVFVLYLSGEPNRNGHKVTCLPWNDEPLAAETSLLKEELLRVNRQGILTINSQPNINGKPSSDPIVGWGPSGGYVFQKAYLEFFTSRETAEALLQVLKKYELRVNYHLVNVKGENITNAPELQPNAVTWGIFPGREIIQPTVVDPVSFMFWKDEAFALWIERWGKLYEEESPSRTIIQYIHDNYFLVNLVDNDFPLDNCLWQVVEDTLELLNRPTQNARETEAP'

In [4]:
len(sequence_wt)

656

In [5]:
def get_mutated_sequence(mut, sequence_wt):
  wt, pos, mt = mut[0], int(mut[1:-1]), mut[-1]

  sequence = deepcopy(sequence_wt)

  return sequence[:pos]+mt+sequence[pos+1:]

In [6]:
df_train = pd.read_csv('train.csv')
df_train['sequence'] = df_train.mutant.apply(lambda x: get_mutated_sequence(x, sequence_wt))
df_train

Unnamed: 0,mutant,DMS_score,sequence
0,M0Y,0.273000,YVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1,M0W,0.285700,WVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
2,M0V,0.215300,VVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
3,M0T,0.312200,TVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
4,M0S,0.218000,SVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
...,...,...,...
1335,R593P,0.806500,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1336,K596A,0.879648,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1337,Y610A,0.721494,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1338,Y610T,0.783082,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...


In [7]:
df_test = pd.read_csv('test.csv')
df_test['sequence'] = df_test.mutant.apply(lambda x: get_mutated_sequence(x, sequence_wt))
df_test

Unnamed: 0,mutant,sequence
0,V1D,MDNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1,V1Y,MYNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
2,V1C,MCNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
3,V1A,MANEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
4,V1E,MENEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
...,...,...
11319,P655S,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
11320,P655T,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
11321,P655V,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
11322,P655A,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...


In [None]:
# TODO: integrate the query data that you acquired each round into df_train

## 2. Train a prediction model

Here, we provided a linear regression model and used one-hot encoding to encode each variant. You would need to build your own model to achieve better performances.

Hint: you can perform cross-validation on the training set to evaluate your predictor before making predictions on the test set.

In [8]:
'''hyperparameters'''

seq_length = 656
seed = 0 # seed for splitting the validation set
val_ratio = 0.2 # proportion of validation set

In [9]:
class ProteinDataset(Dataset):
    def __init__(self, df, istrain=True):

        alphabet = 'ACDEFGHIKLMNPQRSTVWY'
        map_a2i = {j:i for i,j in enumerate(alphabet)}
        map_i2a = {i:j for i,j in enumerate(alphabet)}

        self.df = df

        self.num_samples = len(self.df)
        self.seq_length = len(self.df.sequence.values[0])
        self.num_channels = 20

        # TODO: replace one-hot encodings with your own encodings
        self.encodings = np.zeros((self.num_samples, self.num_channels, self.seq_length)).astype(np.float32)
        self.targets = np.zeros(self.num_samples).astype(np.float32)

        if istrain:
          for it, (seq,target) in enumerate(self.df[['sequence', 'DMS_score']].values):
              for i,aa in enumerate(seq):
                  self.encodings[it,map_a2i[aa],i] = 1
              self.targets[it] = target

          self.encodings = self.encodings.astype(np.float32)
          self.targets = self.targets.astype(np.float32)
        else:
          for it, seq in enumerate(self.df['sequence'].values):
              for i,aa in enumerate(seq):
                  self.encodings[it,map_a2i[aa],i] = 1

          self.encodings = self.encodings.astype(np.float32)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return torch.tensor(self.encodings[idx]), torch.tensor(self.targets[idx])

## Deep Learning with MSELoss

In [8]:

#OHE with the entire sequence on one line(every 20 positions the AA position changes)
class ProteinDataset(Dataset):
    def __init__(self, df, istrain=True):

        alphabet = 'ACDEFGHIKLMNPQRSTVWY'
        map_a2i = {j:i for i,j in enumerate(alphabet)}
        map_i2a = {i:j for i,j in enumerate(alphabet)}

        self.df = df

        self.num_samples = len(self.df)
        self.seq_length = len(self.df.sequence.values[0])
        self.num_channels = 20

        # TODO: replace one-hot encodings with your own encodings
        self.encodings = np.zeros((self.num_samples, self.num_channels * self.seq_length)).astype(np.float32)
        self.targets = np.zeros(self.num_samples).astype(np.float32)

        if istrain:
          for it, (seq,target) in enumerate(self.df[['sequence', 'DMS_score']].values):
              for i,aa in enumerate(seq):
                  self.encodings[it,map_a2i[aa] + (20*i)] = 1
              self.targets[it] = target

          self.encodings = self.encodings.astype(np.float32)
          self.targets = self.targets.astype(np.float32)
        else:
          for it, seq in enumerate(self.df['sequence'].values):
              for i,aa in enumerate(seq):
                  self.encodings[it,map_a2i[aa] + (20*i)] = 1

          self.encodings = self.encodings.astype(np.float32)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return torch.tensor(self.encodings[idx]), torch.tensor(self.targets[idx])

In [None]:
train_dataset = ProteinDataset(df_train)
test_dataset = ProteinDataset(df_test, istrain=False)

# split validation set
train_dataset, val_dataset = train_test_split(train_dataset, test_size=val_ratio, random_state=seed, shuffle=True)

# TODO: revise according to your own model
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=len(val_dataset), shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

In [10]:
class TransformerRegressor(nn.Module):
    def __init__(self, embed_dim=20000, num_heads=4, num_layers=3, ff_dim=10000):
        super(TransformerRegressor, self).__init__()
        self.embedding = nn.Linear(13120, embed_dim)  # Project one-hot input to embedding space
        # TODO: transformer block using TransformerEncoderLayer
        #####
        self.encoder_layer = nn.TransformerEncoderLayer(embed_dim, num_heads, ff_dim)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers)
        #####
        self.pooling = nn.AdaptiveAvgPool1d(1)  # Global average pooling
        #self.linear = nn.Linear(embed_dim,1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.embedding(x)  # (batch_size, seq_len, embed_dim)
        # TODO: forward part of the transformer block
        #####
        x = self.encoder(x)
        #####
        #x = x.permute(0, 2, 1)
        x = self.pooling(x).squeeze(-1)  # (batch, embed_dim)
        #x = self.linear(x)
        #x = x.max(dim=1).values
        return (self.tanh(x) + 1)/2

In [14]:
class FFRegressor(nn.Module):
    def __init__(self, start_dim=13120):
        super(FFRegressor, self).__init__()
        self.linear1 = nn.Linear(start_dim, 1312)
        self.linear2 = nn.Linear(1312, 1)
        self.relu= nn.ReLU()
        
    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear(x)
        x = self.relu(x)
        return (x)
                

In [15]:
def mse(model, data):
    preds = []
    truths = []
    for row in data:
        seq,fitness = row
        seq = seq.to(device)
        fitness = fitness.to(device)
        preds.append(model(seq))
        truths.append(ranks)
    mse_list = [(x - y)**2 for x,y in zip(truths, preds)]
    mse = sum(mse_list)/len(mse_list)
    return (mse)

In [None]:
def cor(model, data):
    preds = []
    truths = []
    for row in data:
        seq,fitness,ranks = row
        seq = seq.to(device)
        fitness = fitness.to(device)
        preds.append(model(seq))
        truths.append(fitness)
    cor = spearmanr(preds, truths)
    return (cor.statistic)

In [25]:
def train_model(model, train_dataset, val_dataset, epochs=100, batch_size=256, lr=1e-3, patience=10, device='cuda:0'):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_mse = 10000
    patience_counter = 0
    best_ckpt = None
    best_cor = -1

    for epoch in range(epochs):
        model.train()
        start_epoch = time.time()
        model.train()
        total_loss = 0
        #truths, preds = [], []

        for sequences, fitness, ranks in train_loader:
            # TODO: backpropagation
            #####
            model = model.to(device)
            sequences = sequences.to(device)
            fitness = fitness.to(device)
            ranks = ranks.to(device)
            outputs = model(ranks)
            outputs = outputs.to(device)
            loss = criterion(outputs, ranks)
             #zero's grad for backprogation
            optimizer.zero_grad()

            #backpropagates
            loss.backward()

            #steps forward
            optimizer.step()
            #####
            
            total_loss += loss.item()
        
        #avg_loss = total_loss/batch_size
        

        # Validation
        model.eval()
#         truth, preds = [], []
#         with torch.no_grad():
#             for sequences, fitness in val_loader:
#                 # TODO: model inference
#                 #####
#                 model = model.to(device)
#                 sequences = sequences.to(device)
#                 fitness = fitness.to(device)
#                 outputs = model(sequences)
#                 outputs = outputs.to(device)
#                 truth.append(fitness)
#                 preds.append(outputs)
#                 #####

        model = model.to(device)
        val_cor = cor(model, val_dataset)
        val_cor  = float(val_cor)
        end_epoch = time.time()
        #print(total_loss, train_corr.statistic, val_corr.statistic)
        print(f'Epoch [{epoch+1} / {epochs}]: Train Loss={avg_loss:.4f}, Val MSE={val_mse:.4f}, Time={end_epoch - start_epoch:.4f} sec')

        # Early stopping
        if val_cor > best_cor:
            best_cor = val_cor
            best_ckpt = deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    return model, best_ckpt

In [None]:
train_dataset = ProteinDataset(df_train)
test_dataset = ProteinDataset(df_test, istrain=False)

# split validation set
train_dataset, val_dataset = train_test_split(train_dataset, test_size=val_ratio, random_state=seed, shuffle=True)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = FFRegressor().to(device)
model, best_ckpt = train_model(model, train_dataset, val_dataset, epochs=500, batch_size=256, lr=1e-3, patience=10, device=device)
model.load_state_dict(best_ckpt)

In [26]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = FFRegressor().to(device)
model, best_ckpt = train_model(model, train_dataset, val_dataset, epochs=500, batch_size=256, lr=1e-3, patience=10, device=device)
model.load_state_dict(best_ckpt)

torch.Size([256, 13120])
Epoch [1 / 500]: Train Loss=0.0000, Val MSE=0.1688, Time=0.0216 sec
torch.Size([256, 13120])
Epoch [2 / 500]: Train Loss=0.0000, Val MSE=0.1688, Time=0.0205 sec
torch.Size([256, 13120])
Epoch [3 / 500]: Train Loss=0.0000, Val MSE=0.1688, Time=0.0202 sec
torch.Size([256, 13120])
Epoch [4 / 500]: Train Loss=0.0000, Val MSE=0.1688, Time=0.0197 sec
torch.Size([256, 13120])
Epoch [5 / 500]: Train Loss=0.0000, Val MSE=0.1688, Time=0.0197 sec
torch.Size([256, 13120])
Epoch [6 / 500]: Train Loss=0.0000, Val MSE=0.1688, Time=0.0196 sec
torch.Size([256, 13120])
Epoch [7 / 500]: Train Loss=0.0000, Val MSE=0.1688, Time=0.0195 sec
torch.Size([256, 13120])
Epoch [8 / 500]: Train Loss=0.0000, Val MSE=0.1688, Time=0.0193 sec
torch.Size([256, 13120])
Epoch [9 / 500]: Train Loss=0.0000, Val MSE=0.1688, Time=0.0192 sec
torch.Size([256, 13120])
Epoch [10 / 500]: Train Loss=0.0000, Val MSE=0.1688, Time=0.0191 sec
torch.Size([256, 13120])
Epoch [11 / 500]: Train Loss=0.0000, Val MSE

<All keys matched successfully>

In [21]:
spearmanr(preds, truths).statistic

0.007279735479372176

In [22]:
spearmanr(train_preds, train_truths).statistic

0.060791107887777106

## ESM Embeddings

In [9]:
import esm
from Bio import SeqIO
from tqdm.auto import tqdm
import os
import numpy as np
import torch
#from autonotebook import tqdm as notebook_tqdm

def gen_emb(fasta_file, out_dir='esm_embeddings_test', device='cuda:0'):
    records = list(SeqIO.parse(fasta_file, 'fasta'))
    names = [rec.id for rec in records]
    sequences = [str(rec.seq) for rec in records]
    print(f'Number of sequences: {len(sequences)}')

    data = [(name, seq) for name, seq in zip(names, sequences)]

    # TODO: Load ESM-2 model (esm2_t33_650M_UR50D) and batch converter
    #####
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    #####
    model.to(device)
    model.eval()  # disables dropout for deterministic results

    batch_size = 2 # Reduce if you are running out of cuda memory
    num_batches = int(np.ceil(len(data) / batch_size))

    for i in tqdm(range(num_batches)):
        batch = data[i * batch_size:(i + 1) * batch_size]
        names_batch, seqs_batch = zip(*batch)
        batch_labels, batch_strs, batch_tokens = batch_converter(batch)
        batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
        batch_tokens = batch_tokens.to(device)
        # Extract per-residue representations (on CPU)
        with torch.no_grad():
            # TODO: inference
            #####
            results = model(batch_tokens, repr_layers=[33], return_contacts=True)
            #####
        # TODO: get per-residue representations
        #####
        token_representations = results["representations"][33]
        #####
        # Generate per-sequence representations via averaging
        for k, tokens_len in enumerate(batch_lens):
            seq_name = names_batch[k]
            seq_tokens = token_representations[k, :tokens_len]
            seq_mean = seq_tokens.mean(0)
            save = {'mean_representations': {33: seq_mean}}
            torch.save(save, os.path.join(out_dir, f'{seq_name}.pt'))

  from .autonotebook import tqdm as notebook_tqdm


In [65]:
train_fasta_list = []
for i,seq in enumerate(df_train["sequence"]):
    train_fasta_list.append(f">seq_{i}\n")
    train_fasta_list.append(seq + "\n")
train_fasta_string = "".join(train_fasta_list)
with open("train_fasta.fa", 'w') as file:
    file.write(train_fasta_string)


    

    

In [57]:
test_fasta_list = []
for row in df_test.itertuples():
    test_fasta_list.append(f">{row.mutant}\n")
    test_fasta_list.append(row.sequence + "\n")
test_fasta_string = "".join(test_fasta_list)
with open("test_fasta.fa", 'w') as file:
    file.write(test_fasta_string)

In [66]:
gen_emb('train_fasta.fa')

Number of sequences: 1340


100%|██████████| 670/670 [01:20<00:00,  8.37it/s]


In [18]:
gen_emb('test_fasta.fa', out_dir = "esm_embeddings_test")

Number of sequences: 11324


100%|██████████| 5662/5662 [08:59<00:00, 10.50it/s]


In [16]:
#New Dataset class for new encoding type
class ProteinESMDataset(Dataset):
    def __init__(self, sequences, seq2name, emb_dir, labels, fitness2idx):
        super().__init__()
        self.labels = [fitness2idx.get(fitness, -1) for fitness in labels]
        self.embeddings = []
        for seq in tqdm(sequences, desc='Loading esm embeddings'):
            name = seq2name[seq]
            emb_file = os.path.join(emb_dir, f'{name}.pt')
            emb = torch.load(emb_file)['mean_representations'][33]
            self.embeddings.append(emb)
        self.ranks = list(rankdata(self.labels))

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, index):
        emb = self.embeddings[index]
        label = torch.tensor(self.labels[index], dtype=torch.float32)
        rank = torch.tensor(self.ranks[index], dtype=torch.float32)
        return emb, label, rank

In [10]:
sequences = df_train['sequence']
fitness_list = df_train['DMS_score'].tolist()

seq_train, seq_val, fitness_train, fitness_val = train_test_split(sequences, fitness_list, test_size=0.2, random_state=0)

In [11]:
train_seq2name = train_seq2name = {seq: f'seq_{i}' for i, seq in enumerate(sequences)}
fitness2idx = {fitness: fitness for idx, fitness in enumerate(fitness_list)}

In [12]:
emb_dir = 'esm_embeddings_test'
train_dataset = ProteinESMDataset(seq_train, train_seq2name, emb_dir, fitness_train, fitness2idx)
val_dataset = ProteinESMDataset(seq_val, train_seq2name, emb_dir, fitness_val, fitness2idx)

Loading esm embeddings: 100%|██████████| 1072/1072 [00:13<00:00, 78.37it/s] 
Loading esm embeddings: 100%|██████████| 268/268 [00:01<00:00, 214.72it/s]


In [17]:
def mse_esm(model, data):
    preds = []
    truths = []
    for row in data:
        esm,fitness, ranks = row
        esm = esm.to(device)
        fitness = fitness.to(device)
        ranks = ranks.to(device)
        preds.append(model(esm.unsqueeze(0)))
        truths.append(fitness)
    mse_list = [(x - y)**2 for x,y in zip(truths, preds)]
    mse = sum(mse_list)/len(mse_list)
    return (mse)

In [10]:
def cor_esm(model, data):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    preds = []
    truths = []
    for row in data:
        esm,fitness, ranks = row
        esm = esm.to(device)
        fitness = fitness.to(device)
        ranks = ranks.to(device)
        preds.append(model(esm.unsqueeze(0)))
        truths.append(fitness)
    preds = [float(x) for x in preds]
    truths = [float(x) for x in truths]
    cor = spearmanr(preds, truths)
    return (cor.statistic)

In [25]:
class MLPRegressor(nn.Module):
    def __init__(self, input_dim=1280, hidden_dim=25600):
        super(MLPRegressor, self).__init__()
        # TODO: linear layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim//2)
        self.fc3 = nn.Linear(hidden_dim//2, hidden_dim//4)
        self.fc4 = nn.Linear(hidden_dim//4, hidden_dim//8)
        self.fc5 = nn.Linear(hidden_dim//8, hidden_dim//16)
        self.fc6 = nn.Linear(hidden_dim//16, hidden_dim//32)
        self.fc7 = nn.Linear(hidden_dim//32, 1)

    def forward(self, x):
        # TODO: forward function
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.relu(out)
        out = self.fc3(out)
        out = self.relu(out)
        out = self.fc4(out)
        out = self.relu(out)
        out = self.fc5(out)
        out = self.relu(out)
        out = self.fc6(out)
        out = self.relu(out)
        out = self.fc7(out)
        return out

In [12]:
class RNNRegressor(nn.Module):
    def __init__(self, input_dim=1280, hidden_dim=128):
        super(RNNRegressor, self).__init__()
        self.rnn = nn.RNN(input_dim, hidden_dim, 4, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        # TODO: forward function
        #h0 = torch.zeros(1, x.size(0), 12800).to(x.device)
        x,_ = self.rnn(x)
        x = self.fc(x)
        return x

In [90]:
class TransformerRegressor(nn.Module):
    def __init__(self, embed_dim=1280, num_heads=4, num_layers=4, ff_dim=128):
        super(TransformerRegressor, self).__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(embed_dim, num_heads, ff_dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers)
        #####
        self.pooling = nn.AdaptiveAvgPool1d(32)  # Global average pooling
        #self.linear = nn.Linear(embed_dim,1)
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        x = self.encoder(x)
        x = self.pooling(x)
        x = self.fc(x)
        return (x)

In [57]:
#As written this code is broken(won't train due to dimensonality issues related to batch size)
class CNNRegressor(nn.Module):
    def __init__(self, embed_dim=1280, num_filters=128, kernel_size=2, num_layers=2):
        super(CNNRegressor, self).__init__()
        self.conv_layers = nn.ModuleList()
        for i in range(num_layers):
            if (i > 0):
                self.conv_layers.append(nn.Conv1d(num_filters, num_filters, kernel_size))
            else:
                self.conv_layers.append(nn.Conv1d(embed_dim, num_filters, kernel_size))
        self.fc = nn.Linear(num_filters, 1)
    def forward(self, x):
        for layer in self.conv_layers:
            x = layer(x)
        x = self.fc(x)
        return(x)

In [49]:
import torch.nn.functional as F
class MultiHeadAttentionRegressor(nn.Module):
    def __init__(self, embed_dim=1280, num_heads=8, num_layers=4, ff_dim=128):
        super(MultiHeadAttentionRegressor, self).__init__()
        self.attention_layers = []
        self.attention_layers = nn.ModuleList([nn.MultiheadAttention(embed_dim, num_heads) for _ in range(num_layers)])
        self.linear_1 = nn.Linear(embed_dim, ff_dim)
        self.dropout_1 = nn.Dropout(0.1)
        self.linear_2 = nn.Linear(ff_dim, embed_dim)
        self.dropout_2 = nn.Dropout(0.1)
        self.norm_1 = nn.LayerNorm(embed_dim)
        self.norm_2 = nn.LayerNorm(embed_dim)
        
        #self.pooling = nn.AdaptiveAvgPool1d(32)
        self.fc = nn.Linear(embed_dim, 1)
    
    def forward(self, x):
        for layer in self.attention_layers:
            x2,_ = layer(x,x,x)
            x = x + self.dropout_1(x2)
            x = self.norm_1(x)
            x2 = self.linear_2(self.dropout_1(F.relu(self.linear_1(x))))
            x = x + self.dropout_2(x2)
            x = self.norm_2(x)
        #x = self.pooling(x)
        x = self.fc(x)
        
        return (x)
    

In [13]:
#Trains models using ESM embeddings
def train_model_ESM(model, train_dataset, val_dataset, epochs=100, batch_size=256, lr=1e-3, patience=10, device='cuda:0'):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    criterion = nn.MarginRankingLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_mse = 100000000000
    patience_counter = 0
    best_ckpt = None
    total_loss = 0
    best_loss = 100000000000
    best_cor = 0
    for epoch in range(epochs):
        model.train()
        start_epoch = time.time()
        total_loss = 0
        
        for sequences, fitness, ranks in train_loader:
            #print(sequences.shape, fitness.shape, ranks.shape)
            model = model.to(device)
            sequences = sequences.to(device)
            fitness = fitness.to(device)
            ranks = ranks.to(device)
            outputs = model(sequences)
            outputs = outputs.to(device)
            outputs = outputs.squeeze(1)
            y = []
            for a,b in zip(outputs, fitness):
                if (b>a):
                    y.append(1)
                else:
                    y.append(-1)
            y_tens = torch.tensor(y, dtype=torch.float32)
            y_tens = y_tens.to(device)
            loss = criterion(outputs, fitness, y_tens)
            #loss = criterion(outputs, fitness)
            #zero's grad for backprogation
            optimizer.zero_grad()

            #backpropagates
            loss.backward()

            #steps forward
            optimizer.step()

            total_loss += loss.item()
        
        

        # Validation
        model.eval()
#         truth, preds = [], []
#         with torch.no_grad():
#             for sequences, fitness in val_loader:
#                 # TODO: model inference
#                 #####
#                 model = model.to(device)
#                 sequences = sequences.to(device)
#                 fitness = fitness.to(device)
#                 outputs = model(sequences)
#                 outputs = outputs.to(device)
#                 truth.append(fitness)
#                 preds.append(outputs)
#                 #####

        model = model.to(device)
        #print(val_dataset.__getitem__(0)[0].shape)
        #val_mse = mse_esm(model, val_dataset)
        val_cor = cor_esm(model, val_dataset)
        val_cor = float(val_cor)
        #val_mse = float(val_mse)
        end_epoch = time.time()
        #print(total_loss, train_corr.statistic, val_corr.statistic)
        print(f'Epoch [{epoch+1} / {epochs}]: Train Loss={total_loss:.4f}, Val Cor={val_cor:.4f}, Time={end_epoch - start_epoch:.4f} sec')

        # Early stopping
        if val_cor > best_cor:
            best_cor = val_cor
            #print("model_updated")
            best_ckpt = deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    return model, best_ckpt

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = MultiHeadAttentionRegressor().to(device)
model, best_ckpt = train_model_ESM(model, train_dataset, val_dataset, epochs=500, batch_size=32, lr=1e-4, patience=50, device=device)
model.load_state_dict(best_ckpt)

In [16]:
sequences = df_train['sequence']
fitness_list = df_train['DMS_score'].tolist()

train_seq2name = train_seq2name = {seq: f'seq_{i}' for i, seq in enumerate(sequences)}
fitness2idx = {fitness: fitness for idx, fitness in enumerate(fitness_list)}

emb_dir = 'esm_embeddings_test'
total_dataset = ProteinESMDataset(sequences, train_seq2name, emb_dir, fitness_list, fitness2idx)

Loading esm embeddings: 100%|██████████| 1240/1240 [00:02<00:00, 561.43it/s] 


In [17]:
sequences = df_train['sequence']
fitness_list = df_train['DMS_score'].tolist()

seq_train, seq_val, fitness_train, fitness_val = train_test_split(sequences, fitness_list, test_size=0.2, random_state=0)

train_seq2name = train_seq2name = {seq: f'seq_{i}' for i, seq in enumerate(sequences)}
fitness2idx = {fitness: fitness for idx, fitness in enumerate(fitness_list)}

emb_dir = 'esm_embeddings_test'
train_dataset = ProteinESMDataset(seq_train, train_seq2name, emb_dir, fitness_train, fitness2idx)
val_dataset = ProteinESMDataset(seq_val, train_seq2name, emb_dir, fitness_val, fitness2idx)

Loading esm embeddings: 100%|██████████| 992/992 [00:00<00:00, 2391.54it/s]
Loading esm embeddings: 100%|██████████| 248/248 [00:00<00:00, 2404.26it/s]


In [46]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = MLPRegressor().to(device)
model, best_model_ckpt = train_model_ESM(model, train_dataset, val_dataset, epochs=2000, batch_size=992, lr=5e-4, patience=100, device=device)
model.load_state_dict(best_model_ckpt)

Epoch [1 / 2000]: Train Loss=0.2709, Val Cor=-0.0676, Time=0.2361 sec
Epoch [2 / 2000]: Train Loss=1.4495, Val Cor=0.0809, Time=0.2329 sec
model_updated
Epoch [3 / 2000]: Train Loss=0.3209, Val Cor=-0.0876, Time=0.2327 sec
Epoch [4 / 2000]: Train Loss=0.2355, Val Cor=-0.0855, Time=0.2323 sec
Epoch [5 / 2000]: Train Loss=0.2134, Val Cor=-0.0885, Time=0.3291 sec
Epoch [6 / 2000]: Train Loss=0.1921, Val Cor=-0.0891, Time=0.2331 sec
Epoch [7 / 2000]: Train Loss=0.2038, Val Cor=-0.0891, Time=0.2326 sec
Epoch [8 / 2000]: Train Loss=0.1941, Val Cor=-0.0884, Time=0.2327 sec
Epoch [9 / 2000]: Train Loss=0.2003, Val Cor=-0.0866, Time=0.2328 sec
Epoch [10 / 2000]: Train Loss=0.1968, Val Cor=-0.0858, Time=0.2326 sec
Epoch [11 / 2000]: Train Loss=0.1920, Val Cor=-0.0862, Time=0.2329 sec
Epoch [12 / 2000]: Train Loss=0.1957, Val Cor=-0.0850, Time=0.2326 sec
Epoch [13 / 2000]: Train Loss=0.1924, Val Cor=-0.0856, Time=0.2326 sec
Epoch [14 / 2000]: Train Loss=0.1926, Val Cor=-0.0840, Time=0.2326 sec
Ep

Epoch [112 / 2000]: Train Loss=0.1915, Val Cor=0.2835, Time=0.2322 sec
Epoch [113 / 2000]: Train Loss=0.1915, Val Cor=0.2834, Time=0.2330 sec
Epoch [114 / 2000]: Train Loss=0.1915, Val Cor=0.2828, Time=0.2327 sec
Epoch [115 / 2000]: Train Loss=0.1915, Val Cor=0.2826, Time=0.2327 sec
Epoch [116 / 2000]: Train Loss=0.1915, Val Cor=0.2832, Time=0.2326 sec
Epoch [117 / 2000]: Train Loss=0.1915, Val Cor=0.2828, Time=0.2324 sec
Epoch [118 / 2000]: Train Loss=0.1915, Val Cor=0.2831, Time=0.2325 sec
Epoch [119 / 2000]: Train Loss=0.1915, Val Cor=0.2829, Time=0.2326 sec
Epoch [120 / 2000]: Train Loss=0.1914, Val Cor=0.2825, Time=0.2326 sec
Epoch [121 / 2000]: Train Loss=0.1914, Val Cor=0.2829, Time=0.2329 sec
Epoch [122 / 2000]: Train Loss=0.1914, Val Cor=0.2820, Time=0.2327 sec
Epoch [123 / 2000]: Train Loss=0.1914, Val Cor=0.2817, Time=0.2329 sec
Epoch [124 / 2000]: Train Loss=0.1914, Val Cor=0.2815, Time=0.2326 sec
Epoch [125 / 2000]: Train Loss=0.1914, Val Cor=0.2820, Time=0.2326 sec
Epoch 

Epoch [226 / 2000]: Train Loss=0.1919, Val Cor=0.2746, Time=0.2327 sec
Epoch [227 / 2000]: Train Loss=0.1919, Val Cor=0.2750, Time=0.2330 sec
Epoch [228 / 2000]: Train Loss=0.1918, Val Cor=0.2750, Time=0.2330 sec
Epoch [229 / 2000]: Train Loss=0.1918, Val Cor=0.2752, Time=0.2330 sec
Epoch [230 / 2000]: Train Loss=0.1918, Val Cor=0.2751, Time=0.2325 sec
Epoch [231 / 2000]: Train Loss=0.1918, Val Cor=0.2758, Time=0.2327 sec
Epoch [232 / 2000]: Train Loss=0.1918, Val Cor=0.2758, Time=0.2322 sec
Epoch [233 / 2000]: Train Loss=0.1918, Val Cor=0.2761, Time=0.2329 sec
Epoch [234 / 2000]: Train Loss=0.1918, Val Cor=0.2757, Time=0.2326 sec
Epoch [235 / 2000]: Train Loss=0.1918, Val Cor=0.2757, Time=0.2325 sec
Epoch [236 / 2000]: Train Loss=0.1918, Val Cor=0.2759, Time=0.2335 sec
Epoch [237 / 2000]: Train Loss=0.1918, Val Cor=0.2769, Time=0.2338 sec
Epoch [238 / 2000]: Train Loss=0.1918, Val Cor=0.2779, Time=0.2329 sec
Epoch [239 / 2000]: Train Loss=0.1917, Val Cor=0.2785, Time=0.2327 sec
Epoch 

Epoch [341 / 2000]: Train Loss=0.1919, Val Cor=0.2817, Time=0.3290 sec
Epoch [342 / 2000]: Train Loss=0.1919, Val Cor=0.2819, Time=0.2331 sec
Epoch [343 / 2000]: Train Loss=0.1919, Val Cor=0.2821, Time=0.2326 sec
Epoch [344 / 2000]: Train Loss=0.1919, Val Cor=0.2820, Time=0.2329 sec
Epoch [345 / 2000]: Train Loss=0.1919, Val Cor=0.2819, Time=0.2328 sec
Epoch [346 / 2000]: Train Loss=0.1919, Val Cor=0.2820, Time=0.2326 sec
Epoch [347 / 2000]: Train Loss=0.1919, Val Cor=0.2820, Time=0.2325 sec
Epoch [348 / 2000]: Train Loss=0.1919, Val Cor=0.2821, Time=0.2327 sec
Epoch [349 / 2000]: Train Loss=0.1919, Val Cor=0.2821, Time=0.2327 sec
Epoch [350 / 2000]: Train Loss=0.1919, Val Cor=0.2823, Time=0.2326 sec
Epoch [351 / 2000]: Train Loss=0.1919, Val Cor=0.2822, Time=0.2325 sec
Epoch [352 / 2000]: Train Loss=0.1919, Val Cor=0.2824, Time=0.2326 sec
Epoch [353 / 2000]: Train Loss=0.1919, Val Cor=0.2824, Time=0.2329 sec
Epoch [354 / 2000]: Train Loss=0.1919, Val Cor=0.2822, Time=0.2327 sec
Epoch 

<All keys matched successfully>

In [47]:
cor_esm(model, val_dataset)

0.2909922784894928

In [54]:
sequences = df_train['sequence']
fitness_list = df_train['DMS_score'].tolist()

seq_train, seq_val, fitness_train, fitness_val = train_test_split(sequences, fitness_list, test_size=0.2, random_state=7)

train_seq2name = train_seq2name = {seq: f'seq_{i}' for i, seq in enumerate(sequences)}
fitness2idx = {fitness: fitness for idx, fitness in enumerate(fitness_list)}

emb_dir = 'esm_embeddings_test'
train_dataset = ProteinESMDataset(seq_train, train_seq2name, emb_dir, fitness_train, fitness2idx)
val_dataset = ProteinESMDataset(seq_val, train_seq2name, emb_dir, fitness_val, fitness2idx)

Loading esm embeddings: 100%|██████████| 992/992 [00:04<00:00, 229.12it/s]
Loading esm embeddings: 100%|██████████| 248/248 [00:01<00:00, 186.58it/s]


In [55]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = RNNRegressor().to(device)
model, best_ckpt = train_model_ESM(model, train_dataset, val_dataset, epochs=3000, batch_size=32, lr=5e-4, patience=300, device=device)
model.load_state_dict(best_ckpt)

Epoch [1 / 3000]: Train Loss=6.2552, Val Cor=0.0558, Time=0.2694 sec
model_updated
Epoch [2 / 3000]: Train Loss=5.8030, Val Cor=0.1489, Time=0.1093 sec
model_updated
Epoch [3 / 3000]: Train Loss=5.8472, Val Cor=0.2020, Time=0.1083 sec
model_updated
Epoch [4 / 3000]: Train Loss=5.8279, Val Cor=0.2261, Time=0.1080 sec
model_updated
Epoch [5 / 3000]: Train Loss=5.8601, Val Cor=0.2230, Time=0.1083 sec
Epoch [6 / 3000]: Train Loss=5.8681, Val Cor=0.2686, Time=0.1081 sec
model_updated
Epoch [7 / 3000]: Train Loss=5.8730, Val Cor=0.2436, Time=0.1079 sec
Epoch [8 / 3000]: Train Loss=5.8383, Val Cor=0.2308, Time=0.1082 sec
Epoch [9 / 3000]: Train Loss=5.8579, Val Cor=0.2353, Time=0.1079 sec
Epoch [10 / 3000]: Train Loss=5.8012, Val Cor=0.2307, Time=0.1079 sec
Epoch [11 / 3000]: Train Loss=5.7931, Val Cor=0.2187, Time=0.1079 sec
Epoch [12 / 3000]: Train Loss=5.8069, Val Cor=0.2245, Time=0.1078 sec
Epoch [13 / 3000]: Train Loss=5.7862, Val Cor=0.2220, Time=0.1079 sec
Epoch [14 / 3000]: Train Loss

Epoch [109 / 3000]: Train Loss=5.0959, Val Cor=0.4371, Time=0.1080 sec
model_updated
Epoch [110 / 3000]: Train Loss=5.2392, Val Cor=0.4384, Time=0.1080 sec
model_updated
Epoch [111 / 3000]: Train Loss=5.1112, Val Cor=0.4484, Time=0.1077 sec
model_updated
Epoch [112 / 3000]: Train Loss=5.2505, Val Cor=0.4475, Time=0.1078 sec
Epoch [113 / 3000]: Train Loss=5.4193, Val Cor=0.4372, Time=0.1077 sec
Epoch [114 / 3000]: Train Loss=5.2468, Val Cor=0.4388, Time=0.1079 sec
Epoch [115 / 3000]: Train Loss=5.1568, Val Cor=0.4362, Time=0.1077 sec
Epoch [116 / 3000]: Train Loss=5.1336, Val Cor=0.4398, Time=0.1079 sec
Epoch [117 / 3000]: Train Loss=5.2261, Val Cor=0.4399, Time=0.1076 sec
Epoch [118 / 3000]: Train Loss=5.1420, Val Cor=0.4417, Time=0.1081 sec
Epoch [119 / 3000]: Train Loss=5.1320, Val Cor=0.4476, Time=0.1082 sec
Epoch [120 / 3000]: Train Loss=5.1369, Val Cor=0.4426, Time=0.1077 sec
Epoch [121 / 3000]: Train Loss=5.1625, Val Cor=0.4443, Time=0.1076 sec
Epoch [122 / 3000]: Train Loss=5.15

Epoch [221 / 3000]: Train Loss=4.9161, Val Cor=0.4597, Time=0.1080 sec
Epoch [222 / 3000]: Train Loss=4.9145, Val Cor=0.4590, Time=0.1078 sec
Epoch [223 / 3000]: Train Loss=4.9003, Val Cor=0.4600, Time=0.1079 sec
Epoch [224 / 3000]: Train Loss=4.8852, Val Cor=0.4609, Time=0.1079 sec
Epoch [225 / 3000]: Train Loss=4.8140, Val Cor=0.4601, Time=0.1080 sec
Epoch [226 / 3000]: Train Loss=4.8726, Val Cor=0.4627, Time=0.1079 sec
Epoch [227 / 3000]: Train Loss=4.9194, Val Cor=0.4617, Time=0.1080 sec
Epoch [228 / 3000]: Train Loss=4.9948, Val Cor=0.4638, Time=0.1081 sec
model_updated
Epoch [229 / 3000]: Train Loss=4.8954, Val Cor=0.4639, Time=0.1082 sec
model_updated
Epoch [230 / 3000]: Train Loss=5.0697, Val Cor=0.4641, Time=0.1077 sec
model_updated
Epoch [231 / 3000]: Train Loss=4.9290, Val Cor=0.4635, Time=0.1080 sec
Epoch [232 / 3000]: Train Loss=4.8747, Val Cor=0.4632, Time=0.1080 sec
Epoch [233 / 3000]: Train Loss=4.8378, Val Cor=0.4632, Time=0.1078 sec
Epoch [234 / 3000]: Train Loss=4.88

Epoch [335 / 3000]: Train Loss=4.7858, Val Cor=0.4654, Time=0.1082 sec
Epoch [336 / 3000]: Train Loss=4.9274, Val Cor=0.4705, Time=0.1081 sec
model_updated
Epoch [337 / 3000]: Train Loss=4.7801, Val Cor=0.4679, Time=0.1082 sec
Epoch [338 / 3000]: Train Loss=4.6855, Val Cor=0.4680, Time=0.1081 sec
Epoch [339 / 3000]: Train Loss=4.9331, Val Cor=0.4671, Time=0.1083 sec
Epoch [340 / 3000]: Train Loss=4.8754, Val Cor=0.4658, Time=0.1082 sec
Epoch [341 / 3000]: Train Loss=4.9138, Val Cor=0.4684, Time=0.1083 sec
Epoch [342 / 3000]: Train Loss=4.8502, Val Cor=0.4717, Time=0.1081 sec
model_updated
Epoch [343 / 3000]: Train Loss=4.7303, Val Cor=0.4711, Time=0.1082 sec
Epoch [344 / 3000]: Train Loss=4.7200, Val Cor=0.4691, Time=0.1081 sec
Epoch [345 / 3000]: Train Loss=4.9700, Val Cor=0.4688, Time=0.1082 sec
Epoch [346 / 3000]: Train Loss=4.7752, Val Cor=0.4672, Time=0.1080 sec
Epoch [347 / 3000]: Train Loss=4.7719, Val Cor=0.4694, Time=0.1080 sec
Epoch [348 / 3000]: Train Loss=4.9107, Val Cor=0.

Epoch [449 / 3000]: Train Loss=4.8037, Val Cor=0.4784, Time=0.1081 sec
Epoch [450 / 3000]: Train Loss=4.8290, Val Cor=0.4764, Time=0.1081 sec
Epoch [451 / 3000]: Train Loss=4.6147, Val Cor=0.4760, Time=0.1083 sec
Epoch [452 / 3000]: Train Loss=4.7000, Val Cor=0.4775, Time=0.1080 sec
Epoch [453 / 3000]: Train Loss=4.6943, Val Cor=0.4765, Time=0.1079 sec
Epoch [454 / 3000]: Train Loss=4.6690, Val Cor=0.4783, Time=0.1079 sec
Epoch [455 / 3000]: Train Loss=4.6459, Val Cor=0.4790, Time=0.1081 sec
Epoch [456 / 3000]: Train Loss=5.0259, Val Cor=0.4771, Time=0.1081 sec
Epoch [457 / 3000]: Train Loss=4.7722, Val Cor=0.4755, Time=0.1082 sec
Epoch [458 / 3000]: Train Loss=4.8014, Val Cor=0.4781, Time=0.1080 sec
Epoch [459 / 3000]: Train Loss=5.0603, Val Cor=0.4792, Time=0.1086 sec
model_updated
Epoch [460 / 3000]: Train Loss=4.7118, Val Cor=0.4796, Time=0.1088 sec
model_updated
Epoch [461 / 3000]: Train Loss=4.5261, Val Cor=0.4772, Time=0.1085 sec
Epoch [462 / 3000]: Train Loss=4.5588, Val Cor=0.

Epoch [565 / 3000]: Train Loss=4.7868, Val Cor=0.3481, Time=0.1082 sec
Epoch [566 / 3000]: Train Loss=4.7399, Val Cor=-0.4448, Time=0.1085 sec
Epoch [567 / 3000]: Train Loss=4.6720, Val Cor=-0.4148, Time=0.1083 sec
Epoch [568 / 3000]: Train Loss=4.8568, Val Cor=-0.0388, Time=0.1085 sec
Epoch [569 / 3000]: Train Loss=4.4656, Val Cor=-0.2740, Time=0.1084 sec
Epoch [570 / 3000]: Train Loss=4.6167, Val Cor=0.4825, Time=0.1084 sec
Epoch [571 / 3000]: Train Loss=4.3528, Val Cor=0.1747, Time=0.1084 sec
Epoch [572 / 3000]: Train Loss=4.5272, Val Cor=0.4814, Time=0.1085 sec
Epoch [573 / 3000]: Train Loss=4.6887, Val Cor=-0.2987, Time=0.1082 sec
Epoch [574 / 3000]: Train Loss=4.4980, Val Cor=0.1723, Time=0.1084 sec
Epoch [575 / 3000]: Train Loss=4.8083, Val Cor=-0.1475, Time=0.1082 sec
Epoch [576 / 3000]: Train Loss=4.7127, Val Cor=0.4819, Time=0.1084 sec
Epoch [577 / 3000]: Train Loss=4.6770, Val Cor=0.2439, Time=0.1081 sec
Epoch [578 / 3000]: Train Loss=4.5060, Val Cor=0.3800, Time=0.1083 sec


Epoch [681 / 3000]: Train Loss=4.3768, Val Cor=0.4341, Time=0.1078 sec
Epoch [682 / 3000]: Train Loss=4.4184, Val Cor=0.4523, Time=0.1080 sec
Epoch [683 / 3000]: Train Loss=4.4197, Val Cor=0.4876, Time=0.1080 sec
Epoch [684 / 3000]: Train Loss=4.1867, Val Cor=0.4849, Time=0.1079 sec
Epoch [685 / 3000]: Train Loss=4.3063, Val Cor=0.4827, Time=0.1080 sec
Epoch [686 / 3000]: Train Loss=4.2999, Val Cor=0.4236, Time=0.1080 sec
Epoch [687 / 3000]: Train Loss=4.4503, Val Cor=0.4349, Time=0.1077 sec
Epoch [688 / 3000]: Train Loss=4.4013, Val Cor=0.3351, Time=0.1082 sec
Epoch [689 / 3000]: Train Loss=4.4073, Val Cor=0.4829, Time=0.1079 sec
Epoch [690 / 3000]: Train Loss=4.4046, Val Cor=0.3761, Time=0.1080 sec
Epoch [691 / 3000]: Train Loss=4.4501, Val Cor=0.4912, Time=0.1079 sec
Epoch [692 / 3000]: Train Loss=4.3129, Val Cor=0.4814, Time=0.1081 sec
Epoch [693 / 3000]: Train Loss=4.4184, Val Cor=0.4921, Time=0.1078 sec
Epoch [694 / 3000]: Train Loss=4.4972, Val Cor=0.4818, Time=0.1081 sec
Epoch 

Epoch [797 / 3000]: Train Loss=4.2708, Val Cor=0.4672, Time=0.1076 sec
Epoch [798 / 3000]: Train Loss=4.1475, Val Cor=0.4722, Time=0.1077 sec
Epoch [799 / 3000]: Train Loss=4.1333, Val Cor=0.4841, Time=0.1076 sec
Epoch [800 / 3000]: Train Loss=4.1149, Val Cor=0.4102, Time=0.1077 sec
Epoch [801 / 3000]: Train Loss=4.0644, Val Cor=0.4145, Time=0.1076 sec
Epoch [802 / 3000]: Train Loss=4.4493, Val Cor=0.4680, Time=0.1077 sec
Epoch [803 / 3000]: Train Loss=4.7626, Val Cor=0.3135, Time=0.1077 sec
Epoch [804 / 3000]: Train Loss=4.6648, Val Cor=0.0768, Time=0.1077 sec
Epoch [805 / 3000]: Train Loss=4.4054, Val Cor=0.4023, Time=0.1076 sec
Epoch [806 / 3000]: Train Loss=4.3050, Val Cor=0.4205, Time=0.1080 sec
Epoch [807 / 3000]: Train Loss=4.2888, Val Cor=0.4838, Time=0.1077 sec
Epoch [808 / 3000]: Train Loss=4.3037, Val Cor=0.4171, Time=0.1078 sec
Epoch [809 / 3000]: Train Loss=4.5062, Val Cor=0.2169, Time=0.1078 sec
Epoch [810 / 3000]: Train Loss=4.5196, Val Cor=0.4559, Time=0.1078 sec
Epoch 

Epoch [913 / 3000]: Train Loss=4.4979, Val Cor=0.4633, Time=0.1076 sec
Epoch [914 / 3000]: Train Loss=4.3103, Val Cor=0.0395, Time=0.1079 sec
Epoch [915 / 3000]: Train Loss=4.3022, Val Cor=0.4724, Time=0.1076 sec
Epoch [916 / 3000]: Train Loss=4.1777, Val Cor=0.3304, Time=0.1079 sec
Epoch [917 / 3000]: Train Loss=4.0876, Val Cor=0.3153, Time=0.1076 sec
Epoch [918 / 3000]: Train Loss=4.1100, Val Cor=0.4794, Time=0.1078 sec
Epoch [919 / 3000]: Train Loss=4.0780, Val Cor=0.3977, Time=0.1077 sec
Epoch [920 / 3000]: Train Loss=4.1356, Val Cor=0.4310, Time=0.1078 sec
Epoch [921 / 3000]: Train Loss=4.1566, Val Cor=0.4730, Time=0.1077 sec
Epoch [922 / 3000]: Train Loss=4.1391, Val Cor=0.2899, Time=0.1080 sec
Epoch [923 / 3000]: Train Loss=4.1539, Val Cor=0.4751, Time=0.1083 sec
Epoch [924 / 3000]: Train Loss=4.2656, Val Cor=0.3972, Time=0.1076 sec
Epoch [925 / 3000]: Train Loss=4.4553, Val Cor=0.3827, Time=0.1078 sec
Epoch [926 / 3000]: Train Loss=4.3202, Val Cor=0.4847, Time=0.1076 sec
Epoch 

Epoch [1029 / 3000]: Train Loss=4.2761, Val Cor=0.4940, Time=0.1083 sec
Epoch [1030 / 3000]: Train Loss=4.0134, Val Cor=0.4695, Time=0.1082 sec
Epoch [1031 / 3000]: Train Loss=4.0880, Val Cor=0.4836, Time=0.1080 sec
Epoch [1032 / 3000]: Train Loss=4.0589, Val Cor=0.4901, Time=0.1079 sec
Epoch [1033 / 3000]: Train Loss=4.1889, Val Cor=0.4926, Time=0.1080 sec
Epoch [1034 / 3000]: Train Loss=4.4961, Val Cor=0.4932, Time=0.1082 sec
Epoch [1035 / 3000]: Train Loss=4.2736, Val Cor=0.4974, Time=0.1080 sec
Epoch [1036 / 3000]: Train Loss=4.0910, Val Cor=0.4342, Time=0.1082 sec
Epoch [1037 / 3000]: Train Loss=4.1168, Val Cor=0.4568, Time=0.1079 sec
Epoch [1038 / 3000]: Train Loss=4.0728, Val Cor=0.4664, Time=0.1081 sec
Epoch [1039 / 3000]: Train Loss=4.1002, Val Cor=0.3971, Time=0.1078 sec
Epoch [1040 / 3000]: Train Loss=3.9963, Val Cor=0.4434, Time=0.1082 sec
Epoch [1041 / 3000]: Train Loss=3.9052, Val Cor=0.4708, Time=0.1079 sec
Epoch [1042 / 3000]: Train Loss=3.9455, Val Cor=0.4520, Time=0.1

Epoch [1143 / 3000]: Train Loss=3.9553, Val Cor=0.4767, Time=0.1080 sec
Epoch [1144 / 3000]: Train Loss=3.8356, Val Cor=0.4806, Time=0.1079 sec
Epoch [1145 / 3000]: Train Loss=4.0640, Val Cor=0.3365, Time=0.1079 sec
Epoch [1146 / 3000]: Train Loss=3.9656, Val Cor=0.4628, Time=0.1083 sec
Epoch [1147 / 3000]: Train Loss=3.9491, Val Cor=0.4656, Time=0.1080 sec
Epoch [1148 / 3000]: Train Loss=3.8121, Val Cor=0.4329, Time=0.1082 sec
Epoch [1149 / 3000]: Train Loss=3.8165, Val Cor=0.4708, Time=0.1080 sec
Epoch [1150 / 3000]: Train Loss=3.7829, Val Cor=0.3588, Time=0.1081 sec
Epoch [1151 / 3000]: Train Loss=3.8108, Val Cor=0.4899, Time=0.1079 sec
Epoch [1152 / 3000]: Train Loss=4.1001, Val Cor=0.4692, Time=0.1081 sec
Epoch [1153 / 3000]: Train Loss=4.1134, Val Cor=0.4703, Time=0.1080 sec
Epoch [1154 / 3000]: Train Loss=4.1482, Val Cor=0.4634, Time=0.1081 sec
Epoch [1155 / 3000]: Train Loss=3.9497, Val Cor=0.4954, Time=0.1080 sec
Epoch [1156 / 3000]: Train Loss=4.0849, Val Cor=0.4409, Time=0.1

Epoch [1257 / 3000]: Train Loss=3.8961, Val Cor=0.3679, Time=0.1081 sec
Epoch [1258 / 3000]: Train Loss=3.9562, Val Cor=0.4205, Time=0.1080 sec
Epoch [1259 / 3000]: Train Loss=3.8185, Val Cor=0.0258, Time=0.1081 sec
Epoch [1260 / 3000]: Train Loss=3.9047, Val Cor=0.2577, Time=0.1081 sec
Epoch [1261 / 3000]: Train Loss=3.7618, Val Cor=0.4264, Time=0.1085 sec
Epoch [1262 / 3000]: Train Loss=3.8869, Val Cor=0.2615, Time=0.1086 sec
Epoch [1263 / 3000]: Train Loss=3.9431, Val Cor=-0.2282, Time=0.1086 sec
Epoch [1264 / 3000]: Train Loss=4.1055, Val Cor=0.3587, Time=0.1086 sec
Epoch [1265 / 3000]: Train Loss=4.1590, Val Cor=0.4845, Time=0.1083 sec
Epoch [1266 / 3000]: Train Loss=4.2086, Val Cor=0.3015, Time=0.1086 sec
Epoch [1267 / 3000]: Train Loss=3.9502, Val Cor=0.3712, Time=0.1084 sec
Epoch [1268 / 3000]: Train Loss=3.9785, Val Cor=-0.1815, Time=0.1084 sec
Epoch [1269 / 3000]: Train Loss=4.0540, Val Cor=0.4635, Time=0.1082 sec
Epoch [1270 / 3000]: Train Loss=4.0318, Val Cor=0.4382, Time=0

Epoch [1371 / 3000]: Train Loss=3.9232, Val Cor=0.4084, Time=0.1080 sec
Epoch [1372 / 3000]: Train Loss=4.0603, Val Cor=-0.1281, Time=0.1079 sec
Epoch [1373 / 3000]: Train Loss=5.1620, Val Cor=0.1285, Time=0.1083 sec
Epoch [1374 / 3000]: Train Loss=5.5684, Val Cor=0.4582, Time=0.1081 sec
Epoch [1375 / 3000]: Train Loss=5.3740, Val Cor=0.4648, Time=0.1082 sec
Epoch [1376 / 3000]: Train Loss=5.0860, Val Cor=0.4953, Time=0.1080 sec
Epoch [1377 / 3000]: Train Loss=4.3796, Val Cor=0.5052, Time=0.1079 sec
Epoch [1378 / 3000]: Train Loss=4.0194, Val Cor=0.5060, Time=0.1078 sec
Epoch [1379 / 3000]: Train Loss=4.0755, Val Cor=0.5072, Time=0.1080 sec
Epoch [1380 / 3000]: Train Loss=3.8464, Val Cor=0.4919, Time=0.1079 sec
Epoch [1381 / 3000]: Train Loss=3.6825, Val Cor=0.4927, Time=0.1079 sec
Epoch [1382 / 3000]: Train Loss=3.8696, Val Cor=0.5169, Time=0.1078 sec
model_updated
Epoch [1383 / 3000]: Train Loss=3.7531, Val Cor=0.5131, Time=0.1080 sec
Epoch [1384 / 3000]: Train Loss=3.8223, Val Cor=0

Epoch [1485 / 3000]: Train Loss=4.1460, Val Cor=0.4486, Time=0.1077 sec
Epoch [1486 / 3000]: Train Loss=3.7289, Val Cor=0.4375, Time=0.1075 sec
Epoch [1487 / 3000]: Train Loss=3.9204, Val Cor=0.5062, Time=0.1074 sec
Epoch [1488 / 3000]: Train Loss=3.6380, Val Cor=0.4817, Time=0.1076 sec
Epoch [1489 / 3000]: Train Loss=3.7110, Val Cor=0.2559, Time=0.1076 sec
Epoch [1490 / 3000]: Train Loss=3.5933, Val Cor=0.4797, Time=0.1077 sec
Epoch [1491 / 3000]: Train Loss=3.6562, Val Cor=0.5095, Time=0.1077 sec
Epoch [1492 / 3000]: Train Loss=3.4842, Val Cor=0.4878, Time=0.1075 sec
Epoch [1493 / 3000]: Train Loss=3.6056, Val Cor=0.4924, Time=0.1076 sec
Epoch [1494 / 3000]: Train Loss=3.6214, Val Cor=0.5142, Time=0.1073 sec
Epoch [1495 / 3000]: Train Loss=3.8866, Val Cor=0.0315, Time=0.1076 sec
Epoch [1496 / 3000]: Train Loss=3.9822, Val Cor=0.1056, Time=0.1075 sec
Epoch [1497 / 3000]: Train Loss=4.1572, Val Cor=0.4273, Time=0.1077 sec
Epoch [1498 / 3000]: Train Loss=3.7834, Val Cor=0.3278, Time=0.1

Epoch [1599 / 3000]: Train Loss=4.7739, Val Cor=0.5096, Time=0.1077 sec
Epoch [1600 / 3000]: Train Loss=4.3426, Val Cor=0.5060, Time=0.1073 sec
Epoch [1601 / 3000]: Train Loss=4.2025, Val Cor=0.4792, Time=0.1077 sec
Epoch [1602 / 3000]: Train Loss=4.4718, Val Cor=0.4955, Time=0.1076 sec
Epoch [1603 / 3000]: Train Loss=4.3043, Val Cor=0.4941, Time=0.1076 sec
Epoch [1604 / 3000]: Train Loss=4.2405, Val Cor=0.4972, Time=0.1076 sec
Epoch [1605 / 3000]: Train Loss=3.9931, Val Cor=0.4974, Time=0.1075 sec
Epoch [1606 / 3000]: Train Loss=4.0272, Val Cor=0.4865, Time=0.1076 sec
Epoch [1607 / 3000]: Train Loss=3.7407, Val Cor=0.4924, Time=0.1075 sec
Epoch [1608 / 3000]: Train Loss=3.8865, Val Cor=0.5010, Time=0.1076 sec
Epoch [1609 / 3000]: Train Loss=4.1951, Val Cor=0.4980, Time=0.1074 sec
Epoch [1610 / 3000]: Train Loss=3.9806, Val Cor=0.4876, Time=0.1076 sec
Epoch [1611 / 3000]: Train Loss=3.9484, Val Cor=0.4642, Time=0.1080 sec
Epoch [1612 / 3000]: Train Loss=4.0807, Val Cor=0.4782, Time=0.1

Epoch [1713 / 3000]: Train Loss=4.0005, Val Cor=-0.3289, Time=0.1077 sec
Epoch [1714 / 3000]: Train Loss=4.1360, Val Cor=0.4326, Time=0.1074 sec
Epoch [1715 / 3000]: Train Loss=3.9774, Val Cor=0.4629, Time=0.1084 sec
Epoch [1716 / 3000]: Train Loss=3.9879, Val Cor=0.4877, Time=0.1079 sec
Epoch [1717 / 3000]: Train Loss=3.6835, Val Cor=0.4957, Time=0.1080 sec
Epoch [1718 / 3000]: Train Loss=3.8397, Val Cor=0.4907, Time=0.1080 sec
Epoch [1719 / 3000]: Train Loss=4.0054, Val Cor=0.5055, Time=0.1078 sec
Epoch [1720 / 3000]: Train Loss=3.7243, Val Cor=0.4804, Time=0.1079 sec
Epoch [1721 / 3000]: Train Loss=3.6328, Val Cor=0.4913, Time=0.1077 sec
Epoch [1722 / 3000]: Train Loss=3.6010, Val Cor=0.5046, Time=0.1078 sec
Epoch [1723 / 3000]: Train Loss=6.6314, Val Cor=-0.2737, Time=0.1078 sec
Epoch [1724 / 3000]: Train Loss=5.9201, Val Cor=0.2684, Time=0.1079 sec
Epoch [1725 / 3000]: Train Loss=5.9188, Val Cor=0.3271, Time=0.1079 sec
Epoch [1726 / 3000]: Train Loss=5.8492, Val Cor=0.3131, Time=0

<All keys matched successfully>

In [56]:
cor_esm(model, val_dataset)

0.5204107282163262

In [59]:
sequences = df_train['sequence']
fitness_list = df_train['DMS_score'].tolist()

seq_train, seq_val, fitness_train, fitness_val = train_test_split(sequences, fitness_list, test_size=0.2, random_state=16)

train_seq2name = train_seq2name = {seq: f'seq_{i}' for i, seq in enumerate(sequences)}
fitness2idx = {fitness: fitness for idx, fitness in enumerate(fitness_list)}

emb_dir = 'esm_embeddings_test'
train_dataset = ProteinESMDataset(seq_train, train_seq2name, emb_dir, fitness_train, fitness2idx)
val_dataset = ProteinESMDataset(seq_val, train_seq2name, emb_dir, fitness_val, fitness2idx)

Loading esm embeddings: 100%|██████████| 992/992 [00:00<00:00, 1343.92it/s]
Loading esm embeddings: 100%|██████████| 248/248 [00:00<00:00, 1409.90it/s]


In [91]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = TransformerRegressor().to(device)
model, best_ckpt = train_model_ESM(model, train_dataset, val_dataset, epochs=2000, batch_size=32, lr=5e-5, patience=200, device=device)
model.load_state_dict(best_ckpt)

Epoch [1 / 2000]: Train Loss=6.1966, Val Cor=0.3562, Time=0.3995 sec
Epoch [2 / 2000]: Train Loss=5.8942, Val Cor=0.3597, Time=0.3979 sec
Epoch [3 / 2000]: Train Loss=5.8698, Val Cor=0.2552, Time=0.3992 sec
Epoch [4 / 2000]: Train Loss=5.9208, Val Cor=0.3362, Time=0.3986 sec
Epoch [5 / 2000]: Train Loss=5.8688, Val Cor=0.3286, Time=0.3986 sec
Epoch [6 / 2000]: Train Loss=5.9122, Val Cor=0.3117, Time=0.3992 sec
Epoch [7 / 2000]: Train Loss=5.8723, Val Cor=0.3022, Time=0.3989 sec
Epoch [8 / 2000]: Train Loss=5.8230, Val Cor=0.3142, Time=0.3984 sec
Epoch [9 / 2000]: Train Loss=5.8465, Val Cor=0.3003, Time=0.3996 sec
Epoch [10 / 2000]: Train Loss=5.9294, Val Cor=0.2879, Time=0.3987 sec
Epoch [11 / 2000]: Train Loss=5.9595, Val Cor=0.2789, Time=0.3990 sec
Epoch [12 / 2000]: Train Loss=5.8492, Val Cor=0.2876, Time=0.3992 sec
Epoch [13 / 2000]: Train Loss=5.8379, Val Cor=0.3230, Time=0.3984 sec
Epoch [14 / 2000]: Train Loss=5.9288, Val Cor=0.2790, Time=0.3995 sec
Epoch [15 / 2000]: Train Loss

Epoch [118 / 2000]: Train Loss=5.8113, Val Cor=0.2282, Time=0.4005 sec
Epoch [119 / 2000]: Train Loss=5.7970, Val Cor=0.2210, Time=0.4009 sec
Epoch [120 / 2000]: Train Loss=5.8008, Val Cor=0.2166, Time=0.4008 sec
Epoch [121 / 2000]: Train Loss=5.7864, Val Cor=0.2510, Time=0.4010 sec
Epoch [122 / 2000]: Train Loss=5.8390, Val Cor=0.2457, Time=0.4012 sec
Epoch [123 / 2000]: Train Loss=5.8239, Val Cor=0.2247, Time=0.4008 sec
Epoch [124 / 2000]: Train Loss=5.8186, Val Cor=0.2043, Time=0.4011 sec
Epoch [125 / 2000]: Train Loss=5.8004, Val Cor=0.1832, Time=0.4019 sec
Epoch [126 / 2000]: Train Loss=5.7803, Val Cor=0.1923, Time=0.4004 sec
Epoch [127 / 2000]: Train Loss=5.7975, Val Cor=0.2423, Time=0.4017 sec
Epoch [128 / 2000]: Train Loss=5.8044, Val Cor=0.2453, Time=0.4006 sec
Epoch [129 / 2000]: Train Loss=5.8152, Val Cor=0.2594, Time=0.4006 sec
Epoch [130 / 2000]: Train Loss=5.8011, Val Cor=0.2647, Time=0.4010 sec
Epoch [131 / 2000]: Train Loss=5.8252, Val Cor=0.2847, Time=0.4012 sec
Epoch 

Epoch [234 / 2000]: Train Loss=5.7666, Val Cor=-0.0355, Time=0.4008 sec
Epoch [235 / 2000]: Train Loss=5.7906, Val Cor=0.1785, Time=0.4009 sec
Epoch [236 / 2000]: Train Loss=5.7891, Val Cor=-0.0629, Time=0.3994 sec
Epoch [237 / 2000]: Train Loss=5.8043, Val Cor=-0.0644, Time=0.4000 sec
Epoch [238 / 2000]: Train Loss=5.7791, Val Cor=0.0915, Time=0.4004 sec
Epoch [239 / 2000]: Train Loss=5.8389, Val Cor=0.1350, Time=0.4005 sec
Epoch [240 / 2000]: Train Loss=5.8121, Val Cor=-0.0448, Time=0.4017 sec
Epoch [241 / 2000]: Train Loss=5.7996, Val Cor=0.0885, Time=0.4006 sec
Epoch [242 / 2000]: Train Loss=5.8047, Val Cor=-0.0293, Time=0.4001 sec
Epoch [243 / 2000]: Train Loss=5.7735, Val Cor=0.0797, Time=0.4016 sec
Epoch [244 / 2000]: Train Loss=5.7954, Val Cor=-0.0331, Time=0.4002 sec
Epoch [245 / 2000]: Train Loss=5.7946, Val Cor=0.1655, Time=0.4008 sec
Epoch [246 / 2000]: Train Loss=5.8074, Val Cor=-0.0377, Time=0.4008 sec
Epoch [247 / 2000]: Train Loss=5.8140, Val Cor=0.0554, Time=0.5616 sec

Epoch [349 / 2000]: Train Loss=5.7969, Val Cor=0.0245, Time=0.4006 sec
Epoch [350 / 2000]: Train Loss=5.7925, Val Cor=0.0591, Time=0.4014 sec
Epoch [351 / 2000]: Train Loss=5.7767, Val Cor=0.1056, Time=0.4011 sec
Epoch [352 / 2000]: Train Loss=5.7780, Val Cor=0.0930, Time=0.3997 sec
Epoch [353 / 2000]: Train Loss=5.7785, Val Cor=0.0802, Time=0.4005 sec
Epoch [354 / 2000]: Train Loss=5.7942, Val Cor=0.1358, Time=0.4004 sec
Epoch [355 / 2000]: Train Loss=5.7999, Val Cor=-0.1536, Time=0.4001 sec
Epoch [356 / 2000]: Train Loss=5.7953, Val Cor=0.0131, Time=0.4007 sec
Epoch [357 / 2000]: Train Loss=5.8145, Val Cor=-0.0205, Time=0.3999 sec
Epoch [358 / 2000]: Train Loss=5.7795, Val Cor=0.0865, Time=0.5635 sec
Epoch [359 / 2000]: Train Loss=5.7824, Val Cor=0.0611, Time=0.4026 sec
Epoch [360 / 2000]: Train Loss=5.8049, Val Cor=0.0290, Time=0.4011 sec
Epoch [361 / 2000]: Train Loss=5.7943, Val Cor=0.0455, Time=0.4011 sec
Epoch [362 / 2000]: Train Loss=5.7809, Val Cor=-0.0755, Time=0.4013 sec
Epo

<All keys matched successfully>

In [92]:
cor_esm(model, val_dataset)

0.3672901054568896

## Ensemble Methods

In [112]:
X_train,X_val,Y_train,Y_val = train_test_split(X, Y, test_size=0.2, random_state=50)

In [105]:
def evaluate_GB(model, X_test, Y_test):
    preds = model.predict(X_test)
    corr = spearmanr(preds, Y_test)
    return (corr.statistic)


In [113]:
print(X_train.shape, X_val.shape,Y_train.shape,Y_val.shape)

(992, 3) (248, 3) (992,) (248,)


In [40]:
X_val = np.zeros(shape=(val_dataset.__len__(), 1280))
Y_val = np.zeros(shape=(val_dataset.__len__(),))
for i in range(0, val_dataset.__len__()):
    X_row = val_dataset.__getitem__(i)[0].detach().cpu().numpy()#.flatten()
    Y_row = val_dataset.__getitem__(i)[1].detach().cpu().numpy()
    X_val[i] = X_row
    Y_val[i] = Y_row
print(X_val.shape)
print(Y_val.shape)
print(X_val)
print(Y_val)
    
    

(228, 1280)
(228,)
[[ 0.02760313 -0.07387351  0.02632519 ... -0.17728937  0.00693432
   0.14927322]
 [ 0.02774767 -0.07886593  0.02485685 ... -0.17778333  0.00544644
   0.15013547]
 [ 0.02778614 -0.07698066  0.02444576 ... -0.17596288  0.00500626
   0.14988568]
 ...
 [ 0.02909854 -0.0790172   0.02478098 ... -0.17759982  0.00528359
   0.15115647]
 [ 0.0280756  -0.07748885  0.02418508 ... -0.17688388  0.00527793
   0.15222098]
 [ 0.02687751 -0.07627581  0.02395615 ... -0.176063    0.00379012
   0.14865206]]
[0.55010003 0.1777     0.61879998 0.0908     0.30160001 0.1382
 0.0171     0.077      0.1856     0.1525     0.26300001 0.25729999
 0.33700001 0.185      0.0439     0.1235     0.59740001 0.0214
 0.36559999 0.36219999 0.63550001 0.36700001 0.0163     0.3096
 0.1179     0.0391     0.45480001 0.6462     0.94150001 0.0333
 0.1497     0.0317     0.2112     0.1019     0.1373     0.0999
 0.84140003 0.045      0.0232     0.33419999 0.0585     0.25389999
 0.0579     0.13240001 0.3617     0.0676

In [41]:
X_train = np.zeros(shape=(train_dataset.__len__(), 1280))
Y_train = np.zeros(shape=(train_dataset.__len__(),))
for i in range(0, train_dataset.__len__()):
    X_row = train_dataset.__getitem__(i)[0].detach().cpu().numpy()#.flatten()
    Y_row = train_dataset.__getitem__(i)[1].detach().cpu().numpy()
    X_train[i] = X_row
    Y_train[i] = Y_row
print(X_train.shape)
print(Y_train.shape)
print(X_train)
print(Y_train)

(912, 1280)
(912,)
[[ 0.02758145 -0.07752989  0.0261986  ... -0.17473149  0.00351624
   0.15196908]
 [ 0.02670845 -0.07564262  0.02456556 ... -0.17363183  0.00384603
   0.1469208 ]
 [ 0.02794668 -0.08093263  0.02550585 ... -0.17754811  0.00492499
   0.14913122]
 ...
 [ 0.0253938  -0.07835097  0.02320869 ... -0.17715487  0.00485733
   0.14893161]
 [ 0.02810415 -0.07836012  0.02577844 ... -0.1771356   0.00528115
   0.14885642]
 [ 0.02643437 -0.0775612   0.02527846 ... -0.17677799  0.0042864
   0.15230393]]
[0.2078     0.14470001 0.1725     0.50830001 0.59219998 0.52969998
 0.0533     0.66839999 0.026      0.3899     0.0312     0.0386
 0.4084     0.034      0.13850001 0.32370001 0.3784     0.0268
 0.28049999 0.0234     0.25920001 0.21529999 0.83170003 0.13160001
 0.44080001 0.28850001 0.13150001 0.1215     0.2374     0.0426
 0.1168     0.2931     0.0883     0.24789999 0.1787     0.1032
 0.0427     0.0527     0.0211     0.0219     0.35569999 0.1842
 0.57770002 0.0332     0.0284     0.50629

In [126]:
learning_rate_range = [0.01,0.02,0.03,0.04,0.05,0.06,0.07]
n_estimators_range = [5,10,20,50,100,200]
depth_range = [1,2,3,4,5]
min_samples_split_range = [2,5]
min_samples_leaf_range = [1,2]
best_corr = -1
counter = 0
for lr in learning_rate_range:
    for n_estimator in n_estimators_range:
        for depth in depth_range:
            for min_samples_split in min_samples_split_range:
                for min_samples_leaf in min_samples_leaf_range:
                    model = GradientBoostingRegressor(learning_rate=lr,n_estimators=n_estimator,
                                                       max_depth=depth,min_samples_split=min_samples_split, 
                                                       min_samples_leaf=min_samples_leaf)
                    model.fit(X_val, Y_val)
                    counter = counter + 1
                    val_corr = evaluate_GB(model, X_train, Y_train)
                    if val_corr > best_corr:
                        best_corr = val_corr
                        print(f"""Current Best Hyperparameters: lr: {lr}, trees: {n_estimator}, depth: {depth}, 
                        split: {min_samples_split}, leaf: {min_samples_leaf} | Current Best Val Corr: {best_corr}""")
                    else:
                        print(counter)
                        
                    

Current Best Hyperparameters: lr: 0.01, trees: 5, depth: 1, 
                        split: 2, leaf: 1 | Current Best Val Corr: 0.5152983227525537
2
3
4
Current Best Hyperparameters: lr: 0.01, trees: 5, depth: 2, 
                        split: 2, leaf: 1 | Current Best Val Corr: 0.735909032134385
6
7
8
Current Best Hyperparameters: lr: 0.01, trees: 5, depth: 3, 
                        split: 2, leaf: 1 | Current Best Val Corr: 0.7898072055309913
Current Best Hyperparameters: lr: 0.01, trees: 5, depth: 3, 
                        split: 2, leaf: 2 | Current Best Val Corr: 0.7898505575786545
11
12
Current Best Hyperparameters: lr: 0.01, trees: 5, depth: 4, 
                        split: 2, leaf: 1 | Current Best Val Corr: 0.7977865033980587
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
Current Best Hyperparameters: lr: 0.01, trees: 10, depth: 4, 
                        split: 2, leaf: 1 | Current Best Val Corr: 0.8011653311837318
34
Current Best Hyperparameters: lr: 0.01, 

In [109]:
X = np.zeros(shape=(total_dataset.__len__(), 3))
Y = np.zeros(shape=(total_dataset.__len__(),))
for i in range(0, train_dataset.__len__()):
    X_row = [preds_MLP[i], preds_RNN[i], preds_transformer[i]]
    Y_row = total_dataset.__getitem__(i)[1].detach().cpu().numpy()
    X[i] = X_row
    Y[i] = Y_row
print(X)
print(Y)

[[0.12143923 0.07700172 0.18719986]
 [0.12123349 0.13049431 0.18720004]
 [0.12169383 0.06313442 0.18719965]
 ...
 [0.         0.         0.        ]
 [0.         0.         0.        ]
 [0.         0.         0.        ]]
[0.273      0.28569999 0.21529999 ... 0.         0.         0.        ]


In [110]:
def min_max_scale_by_column(arr):
    mins = arr.min(axis=0)
    maxs = arr.max(axis=0)
    return (arr - mins) / (maxs - mins)

In [111]:
X = min_max_scale_by_column(X)
print(X)

[[0.98754399 0.30408035 0.99999363]
 [0.98587097 0.37663063 0.99999459]
 [0.98961441 0.28527258 0.99999252]
 ...
 [0.         0.1996454  0.        ]
 [0.         0.1996454  0.        ]
 [0.         0.1996454  0.        ]]


In [127]:
model_ensemble = GradientBoostingRegressor(learning_rate=0.01,n_estimators=10,
                                        max_depth=4,min_samples_split=5, 
                                        min_samples_leaf=1)
model_ensemble.fit(X_train, Y_train)

In [128]:
model_ensemble_full = GradientBoostingRegressor(learning_rate=0.01,n_estimators=10,
                                        max_depth=4,min_samples_split=5, 
                                        min_samples_leaf=1)
model_ensemble_full.fit(X, Y)

In [17]:
def bootstrap_dataset(df_train, emb_dir='esm_embeddings_test', seed=0, test_split=0.1):
    sequences = df_train['sequence']
    fitness_list = df_train['DMS_score'].tolist()

    seq_train, seq_val, fitness_train, fitness_val = train_test_split(sequences, fitness_list, test_size=test_split, random_state=seed)

    train_seq2name = train_seq2name = {seq: f'seq_{i}' for i, seq in enumerate(sequences)}
    fitness2idx = {fitness: fitness for idx, fitness in enumerate(fitness_list)}

    emb_dir = emb_dir
    train_dataset = ProteinESMDataset(seq_train, train_seq2name, emb_dir, fitness_train, fitness2idx)
    val_dataset = ProteinESMDataset(seq_val, train_seq2name, emb_dir, fitness_val, fitness2idx)
    
    return (train_dataset, val_dataset)

In [18]:
train_sets = []
val_sets = []
for i in range(5):
    train_set, val_set = bootstrap_dataset(df_train, seed=i,test_split=0.1)
    train_sets.append(train_set)
    val_sets.append(val_set)

Loading esm embeddings: 100%|██████████| 1206/1206 [00:09<00:00, 127.14it/s]
Loading esm embeddings: 100%|██████████| 134/134 [00:00<00:00, 140.21it/s]
Loading esm embeddings: 100%|██████████| 1206/1206 [00:03<00:00, 393.51it/s]
Loading esm embeddings: 100%|██████████| 134/134 [00:00<00:00, 1607.58it/s]
Loading esm embeddings: 100%|██████████| 1206/1206 [00:02<00:00, 438.95it/s] 
Loading esm embeddings: 100%|██████████| 134/134 [00:00<00:00, 1655.41it/s]
Loading esm embeddings: 100%|██████████| 1206/1206 [00:00<00:00, 1652.40it/s]
Loading esm embeddings: 100%|██████████| 134/134 [00:00<00:00, 1626.61it/s]
Loading esm embeddings: 100%|██████████| 1206/1206 [00:00<00:00, 1551.18it/s]
Loading esm embeddings: 100%|██████████| 134/134 [00:00<00:00, 1751.64it/s]


In [19]:
def train_RNN_ensemble(input_dim, hidden_dim, epochs, batch, lr, patience, train_data, val_data, seed):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = RNNRegressor().to(device)
    model, best_ckpt = train_model_ESM(model, train_data, val_data, epochs=epochs, batch_size=batch, lr=lr, patience=200, device=device)
    model.load_state_dict(best_ckpt)
    return (model)

In [20]:
RNN_ensemble = []
for i in range(5):
    print(f"Training RNN model {i}:")
    print("-"*127)
    RNN = train_RNN_ensemble(1280, 128, 2000, 32, 5e-4, 200, train_sets[i], val_sets[i], i)
    RNN_ensemble.append(RNN)

    

Training RNN model 0:
-------------------------------------------------------------------------------------------------------------------------------
Epoch [1 / 2000]: Train Loss=8.9274, Val Cor=0.0718, Time=2.0089 sec
Epoch [2 / 2000]: Train Loss=8.1782, Val Cor=0.2308, Time=0.2780 sec
Epoch [3 / 2000]: Train Loss=8.1082, Val Cor=0.2825, Time=0.2763 sec
Epoch [4 / 2000]: Train Loss=8.1835, Val Cor=0.2898, Time=0.2755 sec
Epoch [5 / 2000]: Train Loss=8.1175, Val Cor=0.3270, Time=0.2756 sec
Epoch [6 / 2000]: Train Loss=8.1644, Val Cor=0.3076, Time=0.2757 sec
Epoch [7 / 2000]: Train Loss=8.1114, Val Cor=0.3103, Time=0.2757 sec
Epoch [8 / 2000]: Train Loss=8.1985, Val Cor=0.3352, Time=0.2757 sec
Epoch [9 / 2000]: Train Loss=8.0887, Val Cor=0.3511, Time=0.2749 sec
Epoch [10 / 2000]: Train Loss=8.1345, Val Cor=0.3569, Time=0.2742 sec
Epoch [11 / 2000]: Train Loss=8.1127, Val Cor=0.3576, Time=0.2828 sec
Epoch [12 / 2000]: Train Loss=8.0897, Val Cor=0.3546, Time=0.2741 sec
Epoch [13 / 2000]: 

Epoch [116 / 2000]: Train Loss=6.4691, Val Cor=0.4958, Time=0.2732 sec
Epoch [117 / 2000]: Train Loss=6.5635, Val Cor=0.4980, Time=0.2732 sec
Epoch [118 / 2000]: Train Loss=6.5396, Val Cor=0.4996, Time=0.2729 sec
Epoch [119 / 2000]: Train Loss=6.4711, Val Cor=0.4978, Time=0.2729 sec
Epoch [120 / 2000]: Train Loss=6.6828, Val Cor=0.5000, Time=0.2729 sec
Epoch [121 / 2000]: Train Loss=6.3643, Val Cor=0.5022, Time=0.2730 sec
Epoch [122 / 2000]: Train Loss=6.7714, Val Cor=0.4978, Time=0.2733 sec
Epoch [123 / 2000]: Train Loss=6.4213, Val Cor=0.4997, Time=0.2732 sec
Epoch [124 / 2000]: Train Loss=6.6804, Val Cor=0.4973, Time=0.2732 sec
Epoch [125 / 2000]: Train Loss=6.3936, Val Cor=0.4957, Time=0.2731 sec
Epoch [126 / 2000]: Train Loss=6.4974, Val Cor=0.4975, Time=0.2732 sec
Epoch [127 / 2000]: Train Loss=6.9871, Val Cor=0.4979, Time=0.2735 sec
Epoch [128 / 2000]: Train Loss=6.5239, Val Cor=0.4952, Time=0.2731 sec
Epoch [129 / 2000]: Train Loss=6.3465, Val Cor=0.4977, Time=0.2799 sec
Epoch 

Epoch [232 / 2000]: Train Loss=6.1534, Val Cor=0.5253, Time=0.2942 sec
Epoch [233 / 2000]: Train Loss=6.0050, Val Cor=0.5248, Time=0.2887 sec
Epoch [234 / 2000]: Train Loss=6.3738, Val Cor=0.5242, Time=0.2928 sec
Epoch [235 / 2000]: Train Loss=6.2917, Val Cor=0.5210, Time=0.2781 sec
Epoch [236 / 2000]: Train Loss=5.9928, Val Cor=0.5214, Time=0.2757 sec
Epoch [237 / 2000]: Train Loss=6.0266, Val Cor=0.5236, Time=0.2749 sec
Epoch [238 / 2000]: Train Loss=5.9887, Val Cor=0.5252, Time=0.2756 sec
Epoch [239 / 2000]: Train Loss=6.1785, Val Cor=0.5228, Time=0.2771 sec
Epoch [240 / 2000]: Train Loss=6.1606, Val Cor=0.5225, Time=0.2761 sec
Epoch [241 / 2000]: Train Loss=5.9410, Val Cor=0.5223, Time=0.2762 sec
Epoch [242 / 2000]: Train Loss=6.4262, Val Cor=0.5225, Time=0.2908 sec
Epoch [243 / 2000]: Train Loss=6.0335, Val Cor=0.5231, Time=0.3141 sec
Epoch [244 / 2000]: Train Loss=5.7443, Val Cor=0.5231, Time=0.3102 sec
Epoch [245 / 2000]: Train Loss=6.1447, Val Cor=0.5252, Time=0.3276 sec
Epoch 

Epoch [348 / 2000]: Train Loss=5.7245, Val Cor=0.5387, Time=0.2722 sec
Epoch [349 / 2000]: Train Loss=5.6283, Val Cor=0.5310, Time=0.2720 sec
Epoch [350 / 2000]: Train Loss=5.5754, Val Cor=0.5371, Time=0.2721 sec
Epoch [351 / 2000]: Train Loss=5.4789, Val Cor=0.5144, Time=0.2722 sec
Epoch [352 / 2000]: Train Loss=5.8913, Val Cor=0.4919, Time=0.2721 sec
Epoch [353 / 2000]: Train Loss=5.8560, Val Cor=0.5350, Time=0.2722 sec
Epoch [354 / 2000]: Train Loss=5.6651, Val Cor=0.5389, Time=0.2721 sec
Epoch [355 / 2000]: Train Loss=5.8417, Val Cor=0.5383, Time=0.2719 sec
Epoch [356 / 2000]: Train Loss=6.1205, Val Cor=-0.4375, Time=0.2721 sec
Epoch [357 / 2000]: Train Loss=6.0605, Val Cor=-0.2893, Time=0.2775 sec
Epoch [358 / 2000]: Train Loss=6.0399, Val Cor=0.5209, Time=0.2723 sec
Epoch [359 / 2000]: Train Loss=5.9397, Val Cor=0.4620, Time=0.2721 sec
Epoch [360 / 2000]: Train Loss=5.8899, Val Cor=0.5327, Time=0.2720 sec
Epoch [361 / 2000]: Train Loss=5.9443, Val Cor=0.4214, Time=0.2721 sec
Epoc

Epoch [464 / 2000]: Train Loss=6.0581, Val Cor=0.5698, Time=0.2731 sec
Epoch [465 / 2000]: Train Loss=5.6277, Val Cor=0.5279, Time=0.2727 sec
Epoch [466 / 2000]: Train Loss=5.7057, Val Cor=0.5495, Time=0.2723 sec
Epoch [467 / 2000]: Train Loss=5.6405, Val Cor=0.5710, Time=0.2732 sec
Epoch [468 / 2000]: Train Loss=5.9102, Val Cor=0.5652, Time=0.2731 sec
Epoch [469 / 2000]: Train Loss=5.7314, Val Cor=0.4948, Time=0.2748 sec
Epoch [470 / 2000]: Train Loss=5.8253, Val Cor=0.5663, Time=0.2728 sec
Epoch [471 / 2000]: Train Loss=5.4509, Val Cor=0.5647, Time=0.2728 sec
Epoch [472 / 2000]: Train Loss=5.6447, Val Cor=0.5681, Time=0.2723 sec
Epoch [473 / 2000]: Train Loss=5.5187, Val Cor=0.5699, Time=0.2728 sec
Epoch [474 / 2000]: Train Loss=5.5656, Val Cor=0.5705, Time=0.2731 sec
Epoch [475 / 2000]: Train Loss=5.4752, Val Cor=0.5640, Time=0.2730 sec
Epoch [476 / 2000]: Train Loss=5.6437, Val Cor=0.5701, Time=0.2740 sec
Epoch [477 / 2000]: Train Loss=5.5893, Val Cor=0.5746, Time=0.2728 sec
Epoch 

Epoch [580 / 2000]: Train Loss=5.9220, Val Cor=0.5932, Time=0.2729 sec
Epoch [581 / 2000]: Train Loss=5.7634, Val Cor=0.5791, Time=0.2733 sec
Epoch [582 / 2000]: Train Loss=5.4715, Val Cor=0.5252, Time=0.2742 sec
Epoch [583 / 2000]: Train Loss=5.7112, Val Cor=0.5984, Time=0.2735 sec
Epoch [584 / 2000]: Train Loss=5.4730, Val Cor=0.5840, Time=0.2728 sec
Epoch [585 / 2000]: Train Loss=5.6531, Val Cor=0.5797, Time=0.2729 sec
Epoch [586 / 2000]: Train Loss=5.8191, Val Cor=0.4914, Time=0.2735 sec
Epoch [587 / 2000]: Train Loss=5.7553, Val Cor=0.5824, Time=0.2735 sec
Epoch [588 / 2000]: Train Loss=5.5726, Val Cor=0.5569, Time=0.2728 sec
Epoch [589 / 2000]: Train Loss=5.3669, Val Cor=0.5822, Time=0.2728 sec
Epoch [590 / 2000]: Train Loss=5.3623, Val Cor=0.4003, Time=0.2728 sec
Epoch [591 / 2000]: Train Loss=5.3792, Val Cor=0.5944, Time=0.2729 sec
Epoch [592 / 2000]: Train Loss=5.3581, Val Cor=0.5876, Time=0.2728 sec
Epoch [593 / 2000]: Train Loss=5.5470, Val Cor=0.5918, Time=0.1846 sec
Epoch 

Epoch [696 / 2000]: Train Loss=5.2730, Val Cor=0.5954, Time=0.1621 sec
Epoch [697 / 2000]: Train Loss=5.4321, Val Cor=0.6053, Time=0.1619 sec
Epoch [698 / 2000]: Train Loss=5.6353, Val Cor=-0.0956, Time=0.1624 sec
Epoch [699 / 2000]: Train Loss=5.3007, Val Cor=0.5753, Time=0.1622 sec
Epoch [700 / 2000]: Train Loss=5.3404, Val Cor=0.6045, Time=0.1624 sec
Epoch [701 / 2000]: Train Loss=5.3853, Val Cor=0.5518, Time=0.1641 sec
Epoch [702 / 2000]: Train Loss=5.1609, Val Cor=-0.3310, Time=0.1607 sec
Epoch [703 / 2000]: Train Loss=5.3862, Val Cor=0.5902, Time=0.1604 sec
Epoch [704 / 2000]: Train Loss=5.2871, Val Cor=0.6035, Time=0.1604 sec
Epoch [705 / 2000]: Train Loss=5.2148, Val Cor=0.6128, Time=0.1607 sec
Epoch [706 / 2000]: Train Loss=5.6235, Val Cor=-0.0752, Time=0.1607 sec
Epoch [707 / 2000]: Train Loss=5.3559, Val Cor=0.4833, Time=0.1609 sec
Epoch [708 / 2000]: Train Loss=5.1633, Val Cor=0.6111, Time=0.1610 sec
Epoch [709 / 2000]: Train Loss=5.6717, Val Cor=-0.5113, Time=0.1609 sec
Ep

Epoch [812 / 2000]: Train Loss=5.0610, Val Cor=0.6068, Time=0.1634 sec
Epoch [813 / 2000]: Train Loss=5.1628, Val Cor=0.6109, Time=0.1627 sec
Epoch [814 / 2000]: Train Loss=5.4173, Val Cor=0.6162, Time=0.1619 sec
Epoch [815 / 2000]: Train Loss=5.1616, Val Cor=0.5648, Time=0.1616 sec
Epoch [816 / 2000]: Train Loss=5.1132, Val Cor=0.6243, Time=0.1603 sec
Epoch [817 / 2000]: Train Loss=5.1561, Val Cor=0.5081, Time=0.1604 sec
Epoch [818 / 2000]: Train Loss=5.1424, Val Cor=0.6084, Time=0.1613 sec
Epoch [819 / 2000]: Train Loss=5.5523, Val Cor=0.2911, Time=0.1609 sec
Epoch [820 / 2000]: Train Loss=5.5581, Val Cor=0.5604, Time=0.1614 sec
Epoch [821 / 2000]: Train Loss=5.2009, Val Cor=0.6154, Time=0.1623 sec
Epoch [822 / 2000]: Train Loss=5.2619, Val Cor=-0.5753, Time=0.1630 sec
Epoch [823 / 2000]: Train Loss=5.2576, Val Cor=0.6044, Time=0.1625 sec
Epoch [824 / 2000]: Train Loss=5.3447, Val Cor=0.5981, Time=0.1631 sec
Epoch [825 / 2000]: Train Loss=5.1148, Val Cor=0.5952, Time=0.1621 sec
Epoch

Epoch [928 / 2000]: Train Loss=5.0382, Val Cor=0.3411, Time=0.1636 sec
Epoch [929 / 2000]: Train Loss=5.0983, Val Cor=0.1742, Time=0.1668 sec
Epoch [930 / 2000]: Train Loss=5.1253, Val Cor=0.5989, Time=0.1629 sec
Epoch [931 / 2000]: Train Loss=4.9660, Val Cor=0.6114, Time=0.1615 sec
Epoch [932 / 2000]: Train Loss=4.9530, Val Cor=0.5612, Time=0.1610 sec
Epoch [933 / 2000]: Train Loss=4.9817, Val Cor=0.3339, Time=0.1614 sec
Epoch [934 / 2000]: Train Loss=4.9339, Val Cor=0.5920, Time=0.1619 sec
Epoch [935 / 2000]: Train Loss=4.9598, Val Cor=0.6156, Time=0.1616 sec
Epoch [936 / 2000]: Train Loss=5.3957, Val Cor=0.5266, Time=0.1622 sec
Epoch [937 / 2000]: Train Loss=5.2556, Val Cor=0.3198, Time=0.1622 sec
Epoch [938 / 2000]: Train Loss=5.1539, Val Cor=0.5837, Time=0.1633 sec
Epoch [939 / 2000]: Train Loss=5.1682, Val Cor=-0.5992, Time=0.1630 sec
Epoch [940 / 2000]: Train Loss=5.3414, Val Cor=0.5191, Time=0.1635 sec
Epoch [941 / 2000]: Train Loss=5.0079, Val Cor=0.6171, Time=0.1634 sec
Epoch

Epoch [1044 / 2000]: Train Loss=4.9907, Val Cor=0.6130, Time=0.1604 sec
Epoch [1045 / 2000]: Train Loss=5.3093, Val Cor=0.6314, Time=0.1612 sec
Epoch [1046 / 2000]: Train Loss=5.2666, Val Cor=0.4809, Time=0.1610 sec
Epoch [1047 / 2000]: Train Loss=5.5401, Val Cor=0.4371, Time=0.1622 sec
Epoch [1048 / 2000]: Train Loss=5.1890, Val Cor=-0.0991, Time=0.1624 sec
Epoch [1049 / 2000]: Train Loss=4.9460, Val Cor=0.6021, Time=0.1617 sec
Epoch [1050 / 2000]: Train Loss=4.9418, Val Cor=0.6143, Time=0.1625 sec
Epoch [1051 / 2000]: Train Loss=5.2387, Val Cor=0.3299, Time=0.1620 sec
Epoch [1052 / 2000]: Train Loss=5.1695, Val Cor=0.4671, Time=0.1604 sec
Epoch [1053 / 2000]: Train Loss=5.0233, Val Cor=0.6270, Time=0.1599 sec
Epoch [1054 / 2000]: Train Loss=5.1613, Val Cor=0.6070, Time=0.1600 sec
Epoch [1055 / 2000]: Train Loss=4.9628, Val Cor=0.6018, Time=0.1604 sec
Epoch [1056 / 2000]: Train Loss=4.8029, Val Cor=0.6243, Time=0.1610 sec
Epoch [1057 / 2000]: Train Loss=5.2589, Val Cor=0.4783, Time=0.

Epoch [1158 / 2000]: Train Loss=5.1732, Val Cor=0.6342, Time=0.1605 sec
Epoch [1159 / 2000]: Train Loss=5.2995, Val Cor=0.5831, Time=0.1610 sec
Epoch [1160 / 2000]: Train Loss=4.9855, Val Cor=0.6105, Time=0.1615 sec
Epoch [1161 / 2000]: Train Loss=5.0302, Val Cor=0.5976, Time=0.1614 sec
Epoch [1162 / 2000]: Train Loss=4.7416, Val Cor=0.6009, Time=0.1609 sec
Epoch [1163 / 2000]: Train Loss=4.8583, Val Cor=0.5446, Time=0.1626 sec
Epoch [1164 / 2000]: Train Loss=5.0729, Val Cor=-0.0273, Time=0.1632 sec
Epoch [1165 / 2000]: Train Loss=4.8650, Val Cor=0.5854, Time=0.1627 sec
Epoch [1166 / 2000]: Train Loss=4.9829, Val Cor=0.6111, Time=0.1635 sec
Epoch [1167 / 2000]: Train Loss=5.0980, Val Cor=0.6049, Time=0.1629 sec
Epoch [1168 / 2000]: Train Loss=5.1950, Val Cor=0.5478, Time=0.1613 sec
Epoch [1169 / 2000]: Train Loss=4.8679, Val Cor=0.6149, Time=0.1608 sec
Epoch [1170 / 2000]: Train Loss=4.7730, Val Cor=0.6026, Time=0.1622 sec
Epoch [1171 / 2000]: Train Loss=4.9045, Val Cor=0.6099, Time=0.

Epoch [1272 / 2000]: Train Loss=4.7045, Val Cor=0.1083, Time=0.1628 sec
Epoch [1273 / 2000]: Train Loss=4.8523, Val Cor=0.4446, Time=0.1625 sec
Epoch [1274 / 2000]: Train Loss=4.7079, Val Cor=0.5009, Time=0.1614 sec
Epoch [1275 / 2000]: Train Loss=4.9958, Val Cor=-0.1476, Time=0.1625 sec
Epoch [1276 / 2000]: Train Loss=4.8420, Val Cor=0.5355, Time=0.1638 sec
Epoch [1277 / 2000]: Train Loss=4.8255, Val Cor=0.4441, Time=0.1630 sec
Epoch [1278 / 2000]: Train Loss=4.8636, Val Cor=0.6167, Time=0.1633 sec
Epoch [1279 / 2000]: Train Loss=4.8338, Val Cor=0.6102, Time=0.1641 sec
Epoch [1280 / 2000]: Train Loss=4.6768, Val Cor=0.5949, Time=0.1642 sec
Epoch [1281 / 2000]: Train Loss=4.7421, Val Cor=0.6024, Time=0.1634 sec
Epoch [1282 / 2000]: Train Loss=4.9353, Val Cor=0.5137, Time=0.1616 sec
Epoch [1283 / 2000]: Train Loss=4.9641, Val Cor=0.3733, Time=0.1611 sec
Epoch [1284 / 2000]: Train Loss=5.1851, Val Cor=0.5043, Time=0.1616 sec
Epoch [1285 / 2000]: Train Loss=4.8512, Val Cor=0.3725, Time=0.

Epoch [1386 / 2000]: Train Loss=4.8738, Val Cor=0.5170, Time=0.1624 sec
Epoch [1387 / 2000]: Train Loss=4.9794, Val Cor=0.4897, Time=0.1622 sec
Epoch [1388 / 2000]: Train Loss=5.0940, Val Cor=0.4600, Time=0.1619 sec
Epoch [1389 / 2000]: Train Loss=4.8236, Val Cor=0.6228, Time=0.1639 sec
Epoch [1390 / 2000]: Train Loss=4.6317, Val Cor=0.5997, Time=0.1637 sec
Epoch [1391 / 2000]: Train Loss=4.8012, Val Cor=0.5850, Time=0.1643 sec
Epoch [1392 / 2000]: Train Loss=4.8192, Val Cor=0.5880, Time=0.1641 sec
Epoch [1393 / 2000]: Train Loss=4.8001, Val Cor=0.5478, Time=0.1616 sec
Epoch [1394 / 2000]: Train Loss=4.7614, Val Cor=0.5204, Time=0.1619 sec
Epoch [1395 / 2000]: Train Loss=4.8269, Val Cor=0.5110, Time=0.1608 sec
Epoch [1396 / 2000]: Train Loss=5.0305, Val Cor=0.5686, Time=0.1623 sec
Epoch [1397 / 2000]: Train Loss=4.9147, Val Cor=0.6015, Time=0.1624 sec
Epoch [1398 / 2000]: Train Loss=4.8678, Val Cor=0.5317, Time=0.1625 sec
Epoch [1399 / 2000]: Train Loss=4.9399, Val Cor=0.2519, Time=0.1

Epoch [71 / 2000]: Train Loss=7.1235, Val Cor=0.4391, Time=0.1665 sec
Epoch [72 / 2000]: Train Loss=6.8554, Val Cor=0.4412, Time=0.1653 sec
Epoch [73 / 2000]: Train Loss=6.8699, Val Cor=0.4431, Time=0.1649 sec
Epoch [74 / 2000]: Train Loss=7.0659, Val Cor=0.4400, Time=0.1649 sec
Epoch [75 / 2000]: Train Loss=6.8439, Val Cor=0.4408, Time=0.1668 sec
Epoch [76 / 2000]: Train Loss=6.7585, Val Cor=0.4439, Time=0.1644 sec
Epoch [77 / 2000]: Train Loss=6.9735, Val Cor=0.4406, Time=0.1644 sec
Epoch [78 / 2000]: Train Loss=6.5643, Val Cor=0.4419, Time=0.1641 sec
Epoch [79 / 2000]: Train Loss=6.6042, Val Cor=0.4426, Time=0.1628 sec
Epoch [80 / 2000]: Train Loss=6.6844, Val Cor=0.4437, Time=0.1627 sec
Epoch [81 / 2000]: Train Loss=6.7795, Val Cor=0.4442, Time=0.1630 sec
Epoch [82 / 2000]: Train Loss=6.8850, Val Cor=0.4427, Time=0.1627 sec
Epoch [83 / 2000]: Train Loss=6.7357, Val Cor=0.4442, Time=0.1633 sec
Epoch [84 / 2000]: Train Loss=6.6070, Val Cor=0.4493, Time=0.1627 sec
Epoch [85 / 2000]: T

Epoch [187 / 2000]: Train Loss=6.5151, Val Cor=0.4576, Time=0.1649 sec
Epoch [188 / 2000]: Train Loss=6.1392, Val Cor=0.4637, Time=0.1639 sec
Epoch [189 / 2000]: Train Loss=6.1590, Val Cor=0.4547, Time=0.1650 sec
Epoch [190 / 2000]: Train Loss=6.0259, Val Cor=0.4640, Time=0.1640 sec
Epoch [191 / 2000]: Train Loss=6.1177, Val Cor=0.4639, Time=0.1645 sec
Epoch [192 / 2000]: Train Loss=6.2145, Val Cor=0.4645, Time=0.1624 sec
Epoch [193 / 2000]: Train Loss=6.2261, Val Cor=0.4668, Time=0.1623 sec
Epoch [194 / 2000]: Train Loss=6.4018, Val Cor=0.4572, Time=0.1621 sec
Epoch [195 / 2000]: Train Loss=6.3192, Val Cor=0.4629, Time=0.1621 sec
Epoch [196 / 2000]: Train Loss=6.3127, Val Cor=0.4559, Time=0.1630 sec
Epoch [197 / 2000]: Train Loss=6.2515, Val Cor=0.4651, Time=0.1633 sec
Epoch [198 / 2000]: Train Loss=5.9873, Val Cor=0.4661, Time=0.1630 sec
Epoch [199 / 2000]: Train Loss=5.9085, Val Cor=0.4642, Time=0.1654 sec
Epoch [200 / 2000]: Train Loss=6.4956, Val Cor=0.4665, Time=0.1645 sec
Epoch 

Epoch [303 / 2000]: Train Loss=5.6560, Val Cor=0.4769, Time=0.1652 sec
Epoch [304 / 2000]: Train Loss=5.8491, Val Cor=0.4754, Time=0.1659 sec
Epoch [305 / 2000]: Train Loss=5.7211, Val Cor=0.4866, Time=0.1666 sec
Epoch [306 / 2000]: Train Loss=5.7884, Val Cor=0.4897, Time=0.1659 sec
Epoch [307 / 2000]: Train Loss=5.8858, Val Cor=0.4884, Time=0.1655 sec
Epoch [308 / 2000]: Train Loss=6.0762, Val Cor=0.4681, Time=0.1642 sec
Epoch [309 / 2000]: Train Loss=5.9563, Val Cor=0.4886, Time=0.1643 sec
Epoch [310 / 2000]: Train Loss=6.0161, Val Cor=0.4826, Time=0.1644 sec
Epoch [311 / 2000]: Train Loss=6.0772, Val Cor=0.4820, Time=0.1637 sec
Epoch [312 / 2000]: Train Loss=5.8357, Val Cor=0.4815, Time=0.1628 sec
Epoch [313 / 2000]: Train Loss=5.9921, Val Cor=0.4787, Time=0.1650 sec
Epoch [314 / 2000]: Train Loss=5.9341, Val Cor=0.4781, Time=0.1642 sec
Epoch [315 / 2000]: Train Loss=5.8606, Val Cor=0.4757, Time=0.1648 sec
Epoch [316 / 2000]: Train Loss=5.9114, Val Cor=0.4857, Time=0.1643 sec
Epoch 

Epoch [420 / 2000]: Train Loss=5.4712, Val Cor=0.5072, Time=0.1647 sec
Epoch [421 / 2000]: Train Loss=5.4877, Val Cor=0.4952, Time=0.1651 sec
Epoch [422 / 2000]: Train Loss=5.6338, Val Cor=0.5001, Time=0.1655 sec
Epoch [423 / 2000]: Train Loss=5.7738, Val Cor=0.4991, Time=0.1662 sec
Epoch [424 / 2000]: Train Loss=5.7017, Val Cor=0.4482, Time=0.1647 sec
Epoch [425 / 2000]: Train Loss=5.5899, Val Cor=0.4965, Time=0.1633 sec
Epoch [426 / 2000]: Train Loss=5.6125, Val Cor=0.4379, Time=0.1629 sec
Epoch [427 / 2000]: Train Loss=5.6694, Val Cor=0.4658, Time=0.1617 sec
Epoch [428 / 2000]: Train Loss=5.8348, Val Cor=0.4986, Time=0.1618 sec
Epoch [429 / 2000]: Train Loss=5.6210, Val Cor=0.4993, Time=0.1622 sec
Epoch [430 / 2000]: Train Loss=5.6520, Val Cor=0.5001, Time=0.1627 sec
Epoch [431 / 2000]: Train Loss=5.7840, Val Cor=0.4996, Time=0.1622 sec
Epoch [432 / 2000]: Train Loss=5.5785, Val Cor=0.4830, Time=0.1639 sec
Epoch [433 / 2000]: Train Loss=5.7777, Val Cor=0.4993, Time=0.1633 sec
Epoch 

Epoch [536 / 2000]: Train Loss=5.5660, Val Cor=0.4952, Time=0.1638 sec
Epoch [537 / 2000]: Train Loss=5.6035, Val Cor=0.4865, Time=0.1646 sec
Epoch [538 / 2000]: Train Loss=5.4529, Val Cor=0.5209, Time=0.1643 sec
Epoch [539 / 2000]: Train Loss=5.5606, Val Cor=0.5227, Time=0.1635 sec
Epoch [540 / 2000]: Train Loss=5.4617, Val Cor=0.1786, Time=0.1648 sec
Epoch [541 / 2000]: Train Loss=5.6502, Val Cor=0.5080, Time=0.1640 sec
Epoch [542 / 2000]: Train Loss=5.4579, Val Cor=0.5232, Time=0.1628 sec
Epoch [543 / 2000]: Train Loss=5.3721, Val Cor=0.5303, Time=0.1618 sec
Epoch [544 / 2000]: Train Loss=5.8544, Val Cor=-0.3626, Time=0.1625 sec
Epoch [545 / 2000]: Train Loss=5.9551, Val Cor=0.0697, Time=0.1629 sec
Epoch [546 / 2000]: Train Loss=5.5350, Val Cor=0.5176, Time=0.1623 sec
Epoch [547 / 2000]: Train Loss=5.4737, Val Cor=0.5202, Time=0.1625 sec
Epoch [548 / 2000]: Train Loss=5.3378, Val Cor=0.5229, Time=0.1625 sec
Epoch [549 / 2000]: Train Loss=5.4778, Val Cor=0.4993, Time=0.1625 sec
Epoch

Epoch [652 / 2000]: Train Loss=5.2209, Val Cor=0.5461, Time=0.1620 sec
Epoch [653 / 2000]: Train Loss=5.5374, Val Cor=0.5334, Time=0.1613 sec
Epoch [654 / 2000]: Train Loss=5.2880, Val Cor=0.5250, Time=0.1614 sec
Epoch [655 / 2000]: Train Loss=5.3391, Val Cor=0.5306, Time=0.1613 sec
Epoch [656 / 2000]: Train Loss=5.3649, Val Cor=0.5325, Time=0.1627 sec
Epoch [657 / 2000]: Train Loss=5.3637, Val Cor=0.4823, Time=0.1631 sec
Epoch [658 / 2000]: Train Loss=5.3759, Val Cor=0.4964, Time=0.1631 sec
Epoch [659 / 2000]: Train Loss=5.4372, Val Cor=0.4996, Time=0.1624 sec
Epoch [660 / 2000]: Train Loss=6.3661, Val Cor=-0.2619, Time=0.1635 sec
Epoch [661 / 2000]: Train Loss=5.9639, Val Cor=-0.3075, Time=0.1631 sec
Epoch [662 / 2000]: Train Loss=5.8943, Val Cor=-0.2714, Time=0.1626 sec
Epoch [663 / 2000]: Train Loss=5.6855, Val Cor=-0.1498, Time=0.1616 sec
Epoch [664 / 2000]: Train Loss=5.2563, Val Cor=0.5351, Time=0.1613 sec
Epoch [665 / 2000]: Train Loss=5.3772, Val Cor=0.5169, Time=0.1610 sec
Ep

Epoch [768 / 2000]: Train Loss=5.3186, Val Cor=0.5360, Time=0.1619 sec
Epoch [769 / 2000]: Train Loss=5.2175, Val Cor=0.5028, Time=0.1628 sec
Epoch [770 / 2000]: Train Loss=5.0903, Val Cor=0.5227, Time=0.1636 sec
Epoch [771 / 2000]: Train Loss=5.1435, Val Cor=-0.0946, Time=0.1632 sec
Epoch [772 / 2000]: Train Loss=5.1105, Val Cor=0.5331, Time=0.1645 sec
Epoch [773 / 2000]: Train Loss=5.9414, Val Cor=0.5312, Time=0.1636 sec
Epoch [774 / 2000]: Train Loss=5.3199, Val Cor=0.2252, Time=0.1629 sec
Epoch [775 / 2000]: Train Loss=5.4533, Val Cor=0.4733, Time=0.1615 sec
Epoch [776 / 2000]: Train Loss=5.3951, Val Cor=0.5097, Time=0.1612 sec
Epoch [777 / 2000]: Train Loss=5.3017, Val Cor=0.5073, Time=0.1613 sec
Epoch [778 / 2000]: Train Loss=5.0330, Val Cor=0.5218, Time=0.1614 sec
Epoch [779 / 2000]: Train Loss=5.0545, Val Cor=0.5520, Time=0.1627 sec
Epoch [780 / 2000]: Train Loss=5.1838, Val Cor=0.5367, Time=0.1630 sec
Epoch [781 / 2000]: Train Loss=5.3480, Val Cor=0.5415, Time=0.1629 sec
Epoch

Epoch [884 / 2000]: Train Loss=5.4069, Val Cor=0.4891, Time=0.1627 sec
Epoch [885 / 2000]: Train Loss=5.0764, Val Cor=0.4862, Time=0.1633 sec
Epoch [886 / 2000]: Train Loss=5.3121, Val Cor=0.4902, Time=0.1632 sec
Epoch [887 / 2000]: Train Loss=5.0764, Val Cor=-0.0327, Time=0.1631 sec
Epoch [888 / 2000]: Train Loss=5.1911, Val Cor=0.5345, Time=0.1634 sec
Epoch [889 / 2000]: Train Loss=5.1552, Val Cor=0.5224, Time=0.1627 sec
Epoch [890 / 2000]: Train Loss=5.0657, Val Cor=0.5096, Time=0.1635 sec
Epoch [891 / 2000]: Train Loss=5.2817, Val Cor=-0.2485, Time=0.1612 sec
Epoch [892 / 2000]: Train Loss=5.3998, Val Cor=0.0148, Time=0.1611 sec
Epoch [893 / 2000]: Train Loss=5.0876, Val Cor=0.0305, Time=0.1616 sec
Epoch [894 / 2000]: Train Loss=5.5061, Val Cor=0.4626, Time=0.1614 sec
Epoch [895 / 2000]: Train Loss=5.2369, Val Cor=0.5256, Time=0.1614 sec
Epoch [896 / 2000]: Train Loss=4.9470, Val Cor=0.5011, Time=0.1627 sec
Epoch [897 / 2000]: Train Loss=4.9442, Val Cor=0.5351, Time=0.1623 sec
Epoc

Epoch [1000 / 2000]: Train Loss=4.9730, Val Cor=-0.0696, Time=0.1630 sec
Epoch [1001 / 2000]: Train Loss=5.0457, Val Cor=0.5015, Time=0.1620 sec
Epoch [1002 / 2000]: Train Loss=4.9414, Val Cor=0.5325, Time=0.1612 sec
Epoch [1003 / 2000]: Train Loss=4.8409, Val Cor=0.2719, Time=0.1609 sec
Epoch [1004 / 2000]: Train Loss=4.8835, Val Cor=0.5155, Time=0.1618 sec
Epoch [1005 / 2000]: Train Loss=4.9049, Val Cor=0.5375, Time=0.1617 sec
Epoch [1006 / 2000]: Train Loss=4.9264, Val Cor=0.5101, Time=0.1622 sec
Epoch [1007 / 2000]: Train Loss=4.9308, Val Cor=0.0178, Time=0.1621 sec
Epoch [1008 / 2000]: Train Loss=5.3074, Val Cor=0.4913, Time=0.1630 sec
Epoch [1009 / 2000]: Train Loss=5.1311, Val Cor=0.1194, Time=0.1628 sec
Epoch [1010 / 2000]: Train Loss=5.0556, Val Cor=0.4116, Time=0.1636 sec
Epoch [1011 / 2000]: Train Loss=5.0194, Val Cor=-0.4207, Time=0.1629 sec
Epoch [1012 / 2000]: Train Loss=5.0646, Val Cor=-0.3838, Time=0.1620 sec
Epoch [1013 / 2000]: Train Loss=5.2926, Val Cor=0.5478, Time=

Epoch [1114 / 2000]: Train Loss=5.0802, Val Cor=0.5001, Time=0.1605 sec
Epoch [1115 / 2000]: Train Loss=4.7271, Val Cor=-0.2932, Time=0.1615 sec
Epoch [1116 / 2000]: Train Loss=5.0001, Val Cor=-0.0620, Time=0.1607 sec
Epoch [1117 / 2000]: Train Loss=4.9063, Val Cor=0.4696, Time=0.1590 sec
Epoch [1118 / 2000]: Train Loss=4.9499, Val Cor=0.4159, Time=0.1590 sec
Epoch [1119 / 2000]: Train Loss=4.8882, Val Cor=0.3623, Time=0.1595 sec
Epoch [1120 / 2000]: Train Loss=5.1369, Val Cor=0.4611, Time=0.1597 sec
Epoch [1121 / 2000]: Train Loss=4.9943, Val Cor=0.5221, Time=0.1595 sec
Epoch [1122 / 2000]: Train Loss=4.9407, Val Cor=-0.5243, Time=0.1602 sec
Epoch [1123 / 2000]: Train Loss=5.1183, Val Cor=0.4516, Time=0.1619 sec
Epoch [1124 / 2000]: Train Loss=5.1254, Val Cor=-0.1447, Time=0.1616 sec
Epoch [1125 / 2000]: Train Loss=5.1144, Val Cor=-0.3611, Time=0.1624 sec
Epoch [1126 / 2000]: Train Loss=4.8667, Val Cor=-0.3308, Time=0.1626 sec
Epoch [1127 / 2000]: Train Loss=4.9819, Val Cor=0.4901, Ti

Epoch [92 / 2000]: Train Loss=7.0186, Val Cor=0.5140, Time=0.1599 sec
Epoch [93 / 2000]: Train Loss=6.5740, Val Cor=0.5123, Time=0.1596 sec
Epoch [94 / 2000]: Train Loss=6.5069, Val Cor=0.5089, Time=0.1599 sec
Epoch [95 / 2000]: Train Loss=6.6062, Val Cor=0.5100, Time=0.1601 sec
Epoch [96 / 2000]: Train Loss=6.5302, Val Cor=0.5135, Time=0.1612 sec
Epoch [97 / 2000]: Train Loss=6.5800, Val Cor=0.5143, Time=0.1614 sec
Epoch [98 / 2000]: Train Loss=6.5965, Val Cor=0.5158, Time=0.1612 sec
Epoch [99 / 2000]: Train Loss=6.6938, Val Cor=0.5145, Time=0.1613 sec
Epoch [100 / 2000]: Train Loss=6.9718, Val Cor=0.5199, Time=0.1616 sec
Epoch [101 / 2000]: Train Loss=6.8609, Val Cor=0.5200, Time=0.1609 sec
Epoch [102 / 2000]: Train Loss=6.7514, Val Cor=0.5195, Time=0.1606 sec
Epoch [103 / 2000]: Train Loss=6.6281, Val Cor=0.5186, Time=0.1614 sec
Epoch [104 / 2000]: Train Loss=6.3284, Val Cor=0.5170, Time=0.1613 sec
Epoch [105 / 2000]: Train Loss=6.5635, Val Cor=0.5187, Time=0.1595 sec
Epoch [106 / 2

Epoch [208 / 2000]: Train Loss=6.0712, Val Cor=0.5555, Time=0.1620 sec
Epoch [209 / 2000]: Train Loss=5.9840, Val Cor=0.5568, Time=0.1615 sec
Epoch [210 / 2000]: Train Loss=6.1687, Val Cor=0.5549, Time=0.1601 sec
Epoch [211 / 2000]: Train Loss=5.9788, Val Cor=0.5524, Time=0.1600 sec
Epoch [212 / 2000]: Train Loss=6.1257, Val Cor=0.5547, Time=0.1594 sec
Epoch [213 / 2000]: Train Loss=5.8718, Val Cor=0.5594, Time=0.1600 sec
Epoch [214 / 2000]: Train Loss=6.1720, Val Cor=0.5533, Time=0.1597 sec
Epoch [215 / 2000]: Train Loss=6.1826, Val Cor=0.5606, Time=0.1614 sec
Epoch [216 / 2000]: Train Loss=6.0743, Val Cor=0.5584, Time=0.1612 sec
Epoch [217 / 2000]: Train Loss=5.8798, Val Cor=0.5548, Time=0.1612 sec
Epoch [218 / 2000]: Train Loss=6.0289, Val Cor=0.5640, Time=0.1620 sec
Epoch [219 / 2000]: Train Loss=5.9295, Val Cor=0.5653, Time=0.1614 sec
Epoch [220 / 2000]: Train Loss=6.2159, Val Cor=0.5686, Time=0.1617 sec
Epoch [221 / 2000]: Train Loss=6.0935, Val Cor=0.5670, Time=0.1616 sec
Epoch 

Epoch [324 / 2000]: Train Loss=6.1003, Val Cor=0.5828, Time=0.1603 sec
Epoch [325 / 2000]: Train Loss=5.9467, Val Cor=0.6020, Time=0.1602 sec
Epoch [326 / 2000]: Train Loss=5.8408, Val Cor=0.4985, Time=0.1603 sec
Epoch [327 / 2000]: Train Loss=5.7018, Val Cor=0.5717, Time=0.1610 sec
Epoch [328 / 2000]: Train Loss=5.9591, Val Cor=0.5295, Time=0.1618 sec
Epoch [329 / 2000]: Train Loss=5.7784, Val Cor=0.5600, Time=0.1618 sec
Epoch [330 / 2000]: Train Loss=5.9748, Val Cor=0.6037, Time=0.1621 sec
Epoch [331 / 2000]: Train Loss=6.0943, Val Cor=0.5122, Time=0.1616 sec
Epoch [332 / 2000]: Train Loss=5.7324, Val Cor=0.5765, Time=0.1625 sec
Epoch [333 / 2000]: Train Loss=5.8288, Val Cor=0.5855, Time=0.1616 sec
Epoch [334 / 2000]: Train Loss=5.9823, Val Cor=0.5615, Time=0.1611 sec
Epoch [335 / 2000]: Train Loss=5.7203, Val Cor=0.5919, Time=0.1604 sec
Epoch [336 / 2000]: Train Loss=6.1824, Val Cor=0.6070, Time=0.1611 sec
Epoch [337 / 2000]: Train Loss=5.8808, Val Cor=0.5590, Time=0.1599 sec
Epoch 

Epoch [440 / 2000]: Train Loss=5.6475, Val Cor=0.6063, Time=0.1602 sec
Epoch [441 / 2000]: Train Loss=5.5360, Val Cor=0.5444, Time=0.1603 sec
Epoch [442 / 2000]: Train Loss=5.5808, Val Cor=0.4508, Time=0.1604 sec
Epoch [443 / 2000]: Train Loss=5.6803, Val Cor=0.6198, Time=0.1609 sec
Epoch [444 / 2000]: Train Loss=5.6535, Val Cor=0.6225, Time=0.1612 sec
Epoch [445 / 2000]: Train Loss=5.6364, Val Cor=0.5809, Time=0.1609 sec
Epoch [446 / 2000]: Train Loss=5.6488, Val Cor=0.6116, Time=0.1615 sec
Epoch [447 / 2000]: Train Loss=5.6269, Val Cor=0.5966, Time=0.1615 sec
Epoch [448 / 2000]: Train Loss=5.8953, Val Cor=0.5796, Time=0.1611 sec
Epoch [449 / 2000]: Train Loss=5.5530, Val Cor=0.6232, Time=0.1598 sec
Epoch [450 / 2000]: Train Loss=5.5018, Val Cor=0.6087, Time=0.1598 sec
Epoch [451 / 2000]: Train Loss=5.5279, Val Cor=0.5951, Time=0.1599 sec
Epoch [452 / 2000]: Train Loss=5.3743, Val Cor=0.6195, Time=0.1597 sec
Epoch [453 / 2000]: Train Loss=5.7930, Val Cor=0.6049, Time=0.1599 sec
Epoch 

Epoch [556 / 2000]: Train Loss=5.7771, Val Cor=0.6473, Time=0.1614 sec
Epoch [557 / 2000]: Train Loss=5.2019, Val Cor=0.6372, Time=0.1610 sec
Epoch [558 / 2000]: Train Loss=5.7203, Val Cor=0.6203, Time=0.1600 sec
Epoch [559 / 2000]: Train Loss=5.5572, Val Cor=0.5873, Time=0.1593 sec
Epoch [560 / 2000]: Train Loss=5.3028, Val Cor=0.6204, Time=0.1600 sec
Epoch [561 / 2000]: Train Loss=5.3327, Val Cor=0.6177, Time=0.1600 sec
Epoch [562 / 2000]: Train Loss=5.7053, Val Cor=0.6380, Time=0.1602 sec
Epoch [563 / 2000]: Train Loss=6.2564, Val Cor=0.6310, Time=0.1614 sec
Epoch [564 / 2000]: Train Loss=5.7038, Val Cor=0.6024, Time=0.1613 sec
Epoch [565 / 2000]: Train Loss=5.5903, Val Cor=0.6522, Time=0.1613 sec
Epoch [566 / 2000]: Train Loss=5.9242, Val Cor=0.6294, Time=0.1611 sec
Epoch [567 / 2000]: Train Loss=5.6999, Val Cor=0.6177, Time=0.1621 sec
Epoch [568 / 2000]: Train Loss=5.5602, Val Cor=0.6368, Time=0.1621 sec
Epoch [569 / 2000]: Train Loss=5.4658, Val Cor=0.6354, Time=0.1612 sec
Epoch 

Epoch [672 / 2000]: Train Loss=5.3888, Val Cor=0.6695, Time=0.1612 sec
Epoch [673 / 2000]: Train Loss=5.4622, Val Cor=0.6537, Time=0.1609 sec
Epoch [674 / 2000]: Train Loss=5.3212, Val Cor=0.6514, Time=0.1620 sec
Epoch [675 / 2000]: Train Loss=5.2848, Val Cor=0.6588, Time=0.1614 sec
Epoch [676 / 2000]: Train Loss=5.3918, Val Cor=0.6546, Time=0.1600 sec
Epoch [677 / 2000]: Train Loss=5.1506, Val Cor=0.6679, Time=0.1585 sec
Epoch [678 / 2000]: Train Loss=5.3517, Val Cor=-0.4407, Time=0.1587 sec
Epoch [679 / 2000]: Train Loss=5.3109, Val Cor=0.6551, Time=0.1590 sec
Epoch [680 / 2000]: Train Loss=5.4340, Val Cor=0.6135, Time=0.1587 sec
Epoch [681 / 2000]: Train Loss=5.5340, Val Cor=0.6587, Time=0.1589 sec
Epoch [682 / 2000]: Train Loss=5.3310, Val Cor=0.6508, Time=0.1596 sec
Epoch [683 / 2000]: Train Loss=5.2231, Val Cor=0.6633, Time=0.1603 sec
Epoch [684 / 2000]: Train Loss=5.2352, Val Cor=0.6495, Time=0.1598 sec
Epoch [685 / 2000]: Train Loss=5.3155, Val Cor=0.6588, Time=0.1606 sec
Epoch

Epoch [788 / 2000]: Train Loss=5.3837, Val Cor=0.5724, Time=0.1623 sec
Epoch [789 / 2000]: Train Loss=5.2295, Val Cor=0.5450, Time=0.1619 sec
Epoch [790 / 2000]: Train Loss=5.1020, Val Cor=0.6382, Time=0.1627 sec
Epoch [791 / 2000]: Train Loss=5.2016, Val Cor=0.6463, Time=0.1621 sec
Epoch [792 / 2000]: Train Loss=5.2425, Val Cor=0.4153, Time=0.1611 sec
Epoch [793 / 2000]: Train Loss=5.3504, Val Cor=0.5995, Time=0.1600 sec
Epoch [794 / 2000]: Train Loss=5.1193, Val Cor=0.5798, Time=0.1605 sec
Epoch [795 / 2000]: Train Loss=5.0764, Val Cor=0.6660, Time=0.1604 sec
Epoch [796 / 2000]: Train Loss=5.2398, Val Cor=0.5272, Time=0.1609 sec
Epoch [797 / 2000]: Train Loss=5.4735, Val Cor=0.3756, Time=0.1618 sec
Epoch [798 / 2000]: Train Loss=5.4357, Val Cor=0.6543, Time=0.1621 sec
Epoch [799 / 2000]: Train Loss=5.1674, Val Cor=0.6599, Time=0.1615 sec
Epoch [800 / 2000]: Train Loss=5.1777, Val Cor=0.3913, Time=0.1619 sec
Epoch [801 / 2000]: Train Loss=5.3183, Val Cor=0.5526, Time=0.1618 sec
Epoch 

Epoch [904 / 2000]: Train Loss=5.5479, Val Cor=0.3429, Time=0.1620 sec
Epoch [905 / 2000]: Train Loss=5.0793, Val Cor=0.6567, Time=0.1623 sec
Epoch [906 / 2000]: Train Loss=5.3324, Val Cor=0.6474, Time=0.1619 sec
Epoch [907 / 2000]: Train Loss=5.2687, Val Cor=0.5426, Time=0.1610 sec
Epoch [908 / 2000]: Train Loss=5.1406, Val Cor=0.4577, Time=0.1600 sec
Epoch [909 / 2000]: Train Loss=5.4136, Val Cor=0.6056, Time=0.1603 sec
Epoch [910 / 2000]: Train Loss=5.3522, Val Cor=-0.6493, Time=0.1606 sec
Epoch [911 / 2000]: Train Loss=5.3290, Val Cor=-0.6253, Time=0.1600 sec
Epoch [912 / 2000]: Train Loss=5.7051, Val Cor=-0.2538, Time=0.1603 sec
Epoch [913 / 2000]: Train Loss=5.3470, Val Cor=-0.6261, Time=0.1615 sec
Epoch [914 / 2000]: Train Loss=5.4756, Val Cor=-0.4962, Time=0.1613 sec
Epoch [915 / 2000]: Train Loss=5.2842, Val Cor=-0.3732, Time=0.1597 sec
Epoch [916 / 2000]: Train Loss=5.2239, Val Cor=-0.5886, Time=0.1598 sec
Epoch [917 / 2000]: Train Loss=5.2259, Val Cor=-0.5255, Time=0.1600 se

Epoch [1020 / 2000]: Train Loss=5.1011, Val Cor=-0.6372, Time=0.1602 sec
Epoch [1021 / 2000]: Train Loss=5.0407, Val Cor=0.6457, Time=0.1597 sec
Epoch [1022 / 2000]: Train Loss=4.8983, Val Cor=0.2139, Time=0.1600 sec
Epoch [1023 / 2000]: Train Loss=5.0146, Val Cor=0.5423, Time=0.1601 sec
Epoch [1024 / 2000]: Train Loss=5.1755, Val Cor=0.6279, Time=0.1607 sec
Epoch [1025 / 2000]: Train Loss=4.8753, Val Cor=0.6567, Time=0.1611 sec
Epoch [1026 / 2000]: Train Loss=4.8993, Val Cor=0.3961, Time=0.1614 sec
Epoch [1027 / 2000]: Train Loss=4.8729, Val Cor=0.6588, Time=0.1615 sec
Epoch [1028 / 2000]: Train Loss=5.2768, Val Cor=0.6869, Time=0.1612 sec
Epoch [1029 / 2000]: Train Loss=5.0994, Val Cor=0.5439, Time=0.1606 sec
Epoch [1030 / 2000]: Train Loss=5.0248, Val Cor=0.6421, Time=0.1617 sec
Epoch [1031 / 2000]: Train Loss=5.2261, Val Cor=0.6636, Time=0.1613 sec
Epoch [1032 / 2000]: Train Loss=5.2129, Val Cor=0.6547, Time=0.1614 sec
Epoch [1033 / 2000]: Train Loss=4.9320, Val Cor=0.6692, Time=0.

Epoch [1135 / 2000]: Train Loss=5.4395, Val Cor=-0.4911, Time=0.1615 sec
Epoch [1136 / 2000]: Train Loss=5.2076, Val Cor=0.4232, Time=0.1595 sec
Epoch [1137 / 2000]: Train Loss=5.2166, Val Cor=-0.2640, Time=0.1588 sec
Epoch [1138 / 2000]: Train Loss=5.4763, Val Cor=0.4957, Time=0.1585 sec
Epoch [1139 / 2000]: Train Loss=5.2584, Val Cor=-0.4649, Time=0.1590 sec
Epoch [1140 / 2000]: Train Loss=5.1217, Val Cor=0.6566, Time=0.1588 sec
Epoch [1141 / 2000]: Train Loss=5.1942, Val Cor=-0.1012, Time=0.1587 sec
Epoch [1142 / 2000]: Train Loss=5.1695, Val Cor=-0.2328, Time=0.1590 sec
Epoch [1143 / 2000]: Train Loss=5.2442, Val Cor=-0.6324, Time=0.1595 sec
Epoch [1144 / 2000]: Train Loss=5.2654, Val Cor=-0.2630, Time=0.1605 sec
Epoch [1145 / 2000]: Train Loss=5.1021, Val Cor=-0.1290, Time=0.1607 sec
Epoch [1146 / 2000]: Train Loss=5.0034, Val Cor=0.6747, Time=0.1603 sec
Epoch [1147 / 2000]: Train Loss=5.0011, Val Cor=0.5835, Time=0.1600 sec
Epoch [1148 / 2000]: Train Loss=5.0870, Val Cor=0.6169, 

Epoch [1249 / 2000]: Train Loss=4.7369, Val Cor=-0.4596, Time=0.1622 sec
Epoch [1250 / 2000]: Train Loss=4.7004, Val Cor=-0.6460, Time=0.1620 sec
Epoch [1251 / 2000]: Train Loss=4.8170, Val Cor=0.5172, Time=0.1618 sec
Epoch [1252 / 2000]: Train Loss=4.8133, Val Cor=0.2983, Time=0.1603 sec
Epoch [1253 / 2000]: Train Loss=4.8837, Val Cor=0.5987, Time=0.1603 sec
Epoch [1254 / 2000]: Train Loss=4.8657, Val Cor=-0.1300, Time=0.1605 sec
Epoch [1255 / 2000]: Train Loss=4.9228, Val Cor=0.3148, Time=0.1602 sec
Epoch [1256 / 2000]: Train Loss=4.9344, Val Cor=-0.2422, Time=0.1602 sec
Epoch [1257 / 2000]: Train Loss=5.0334, Val Cor=0.5954, Time=0.1611 sec
Epoch [1258 / 2000]: Train Loss=5.1635, Val Cor=0.6219, Time=0.1612 sec
Epoch [1259 / 2000]: Train Loss=5.0179, Val Cor=0.6483, Time=0.1617 sec
Epoch [1260 / 2000]: Train Loss=4.9135, Val Cor=0.6417, Time=0.1617 sec
Epoch [1261 / 2000]: Train Loss=4.9938, Val Cor=0.4303, Time=0.1624 sec
Epoch [1262 / 2000]: Train Loss=4.7628, Val Cor=0.6558, Time

Epoch [7 / 2000]: Train Loss=8.1228, Val Cor=0.4787, Time=0.1598 sec
Epoch [8 / 2000]: Train Loss=8.2847, Val Cor=0.4699, Time=0.1601 sec
Epoch [9 / 2000]: Train Loss=8.1392, Val Cor=0.5294, Time=0.1614 sec
Epoch [10 / 2000]: Train Loss=8.1827, Val Cor=0.5142, Time=0.1609 sec
Epoch [11 / 2000]: Train Loss=8.1206, Val Cor=0.5244, Time=0.1612 sec
Epoch [12 / 2000]: Train Loss=8.0906, Val Cor=0.4865, Time=0.1616 sec
Epoch [13 / 2000]: Train Loss=8.1210, Val Cor=0.4799, Time=0.1614 sec
Epoch [14 / 2000]: Train Loss=8.1051, Val Cor=0.5165, Time=0.1596 sec
Epoch [15 / 2000]: Train Loss=8.0706, Val Cor=0.5283, Time=0.1596 sec
Epoch [16 / 2000]: Train Loss=8.0574, Val Cor=0.5529, Time=0.1595 sec
Epoch [17 / 2000]: Train Loss=7.9711, Val Cor=0.5529, Time=0.1594 sec
Epoch [18 / 2000]: Train Loss=8.0985, Val Cor=0.5632, Time=0.1593 sec
Epoch [19 / 2000]: Train Loss=7.9425, Val Cor=0.5821, Time=0.1594 sec
Epoch [20 / 2000]: Train Loss=7.9089, Val Cor=0.5926, Time=0.1596 sec
Epoch [21 / 2000]: Trai

Epoch [125 / 2000]: Train Loss=6.0732, Val Cor=0.7111, Time=0.1599 sec
Epoch [126 / 2000]: Train Loss=6.6421, Val Cor=0.7061, Time=0.1602 sec
Epoch [127 / 2000]: Train Loss=6.6090, Val Cor=0.7130, Time=0.1600 sec
Epoch [128 / 2000]: Train Loss=6.3339, Val Cor=0.7068, Time=0.1595 sec
Epoch [129 / 2000]: Train Loss=6.5487, Val Cor=0.7125, Time=0.1611 sec
Epoch [130 / 2000]: Train Loss=6.6166, Val Cor=0.7087, Time=0.1609 sec
Epoch [131 / 2000]: Train Loss=6.3580, Val Cor=0.7081, Time=0.1605 sec
Epoch [132 / 2000]: Train Loss=6.1867, Val Cor=0.7077, Time=0.1608 sec
Epoch [133 / 2000]: Train Loss=6.3794, Val Cor=0.7089, Time=0.1614 sec
Epoch [134 / 2000]: Train Loss=6.4862, Val Cor=0.7078, Time=0.1609 sec
Epoch [135 / 2000]: Train Loss=6.4457, Val Cor=0.7090, Time=0.1597 sec
Epoch [136 / 2000]: Train Loss=6.4888, Val Cor=0.7071, Time=0.1594 sec
Epoch [137 / 2000]: Train Loss=6.4094, Val Cor=0.7078, Time=0.1600 sec
Epoch [138 / 2000]: Train Loss=6.4208, Val Cor=0.7082, Time=0.1591 sec
Epoch 

Epoch [241 / 2000]: Train Loss=6.2065, Val Cor=0.6901, Time=0.1607 sec
Epoch [242 / 2000]: Train Loss=5.9939, Val Cor=0.6873, Time=0.1614 sec
Epoch [243 / 2000]: Train Loss=6.2595, Val Cor=0.6977, Time=0.1612 sec
Epoch [244 / 2000]: Train Loss=5.8508, Val Cor=0.6967, Time=0.1614 sec
Epoch [245 / 2000]: Train Loss=6.0944, Val Cor=0.6990, Time=0.1613 sec
Epoch [246 / 2000]: Train Loss=6.1656, Val Cor=0.6984, Time=0.1607 sec
Epoch [247 / 2000]: Train Loss=6.0771, Val Cor=0.6976, Time=0.1600 sec
Epoch [248 / 2000]: Train Loss=6.0408, Val Cor=0.6999, Time=0.1596 sec
Epoch [249 / 2000]: Train Loss=5.9455, Val Cor=0.6994, Time=0.1598 sec
Epoch [250 / 2000]: Train Loss=5.8389, Val Cor=0.6935, Time=0.1601 sec
Epoch [251 / 2000]: Train Loss=5.8933, Val Cor=0.6897, Time=0.1605 sec
Epoch [252 / 2000]: Train Loss=5.7616, Val Cor=0.7008, Time=0.1607 sec
Epoch [253 / 2000]: Train Loss=6.4310, Val Cor=0.7007, Time=0.1615 sec
Epoch [254 / 2000]: Train Loss=5.9531, Val Cor=0.6984, Time=0.1614 sec
Epoch 

Epoch [33 / 2000]: Train Loss=7.3274, Val Cor=0.5490, Time=0.1622 sec
Epoch [34 / 2000]: Train Loss=7.5116, Val Cor=0.5444, Time=0.1633 sec
Epoch [35 / 2000]: Train Loss=7.2703, Val Cor=0.5488, Time=0.1633 sec
Epoch [36 / 2000]: Train Loss=7.3167, Val Cor=0.5572, Time=0.1604 sec
Epoch [37 / 2000]: Train Loss=7.3002, Val Cor=0.5530, Time=0.1603 sec
Epoch [38 / 2000]: Train Loss=7.1094, Val Cor=0.5610, Time=0.1605 sec
Epoch [39 / 2000]: Train Loss=7.1161, Val Cor=0.5607, Time=0.1607 sec
Epoch [40 / 2000]: Train Loss=7.1419, Val Cor=0.5588, Time=0.1605 sec
Epoch [41 / 2000]: Train Loss=7.4177, Val Cor=0.5630, Time=0.1598 sec
Epoch [42 / 2000]: Train Loss=7.1985, Val Cor=0.5660, Time=0.1596 sec
Epoch [43 / 2000]: Train Loss=7.4003, Val Cor=0.5571, Time=0.1599 sec
Epoch [44 / 2000]: Train Loss=7.4538, Val Cor=0.5669, Time=0.1603 sec
Epoch [45 / 2000]: Train Loss=7.0226, Val Cor=0.5609, Time=0.1600 sec
Epoch [46 / 2000]: Train Loss=7.0878, Val Cor=0.5629, Time=0.1611 sec
Epoch [47 / 2000]: T

Epoch [151 / 2000]: Train Loss=6.7057, Val Cor=0.6375, Time=0.1609 sec
Epoch [152 / 2000]: Train Loss=6.6791, Val Cor=0.6374, Time=0.1612 sec
Epoch [153 / 2000]: Train Loss=6.7499, Val Cor=0.6388, Time=0.1611 sec
Epoch [154 / 2000]: Train Loss=6.6450, Val Cor=0.6343, Time=0.1610 sec
Epoch [155 / 2000]: Train Loss=6.4180, Val Cor=0.6368, Time=0.1596 sec
Epoch [156 / 2000]: Train Loss=6.7158, Val Cor=0.6363, Time=0.1593 sec
Epoch [157 / 2000]: Train Loss=6.6535, Val Cor=0.6384, Time=0.1599 sec
Epoch [158 / 2000]: Train Loss=6.5848, Val Cor=0.6379, Time=0.1594 sec
Epoch [159 / 2000]: Train Loss=6.4353, Val Cor=0.6412, Time=0.1601 sec
Epoch [160 / 2000]: Train Loss=6.3872, Val Cor=0.6389, Time=0.1597 sec
Epoch [161 / 2000]: Train Loss=6.7373, Val Cor=0.6401, Time=0.1611 sec
Epoch [162 / 2000]: Train Loss=6.4014, Val Cor=0.6417, Time=0.1612 sec
Epoch [163 / 2000]: Train Loss=6.1430, Val Cor=0.6426, Time=0.1618 sec
Epoch [164 / 2000]: Train Loss=6.1742, Val Cor=0.6390, Time=0.1613 sec
Epoch 

Epoch [267 / 2000]: Train Loss=6.0956, Val Cor=0.6723, Time=0.1605 sec
Epoch [268 / 2000]: Train Loss=6.3954, Val Cor=0.6636, Time=0.1607 sec
Epoch [269 / 2000]: Train Loss=6.0178, Val Cor=0.6678, Time=0.1618 sec
Epoch [270 / 2000]: Train Loss=6.1109, Val Cor=0.6690, Time=0.1618 sec
Epoch [271 / 2000]: Train Loss=6.0216, Val Cor=0.6660, Time=0.1603 sec
Epoch [272 / 2000]: Train Loss=5.9140, Val Cor=0.6655, Time=0.1598 sec
Epoch [273 / 2000]: Train Loss=5.8027, Val Cor=0.6724, Time=0.1610 sec
Epoch [274 / 2000]: Train Loss=5.8056, Val Cor=0.6735, Time=0.1612 sec
Epoch [275 / 2000]: Train Loss=5.6803, Val Cor=0.6682, Time=0.1621 sec
Epoch [276 / 2000]: Train Loss=6.5529, Val Cor=0.5574, Time=0.1611 sec
Epoch [277 / 2000]: Train Loss=6.1949, Val Cor=0.6601, Time=0.1629 sec
Epoch [278 / 2000]: Train Loss=5.9610, Val Cor=0.6693, Time=0.1614 sec
Epoch [279 / 2000]: Train Loss=5.9701, Val Cor=0.6696, Time=0.1607 sec
Epoch [280 / 2000]: Train Loss=5.9264, Val Cor=0.6727, Time=0.1601 sec
Epoch 

Epoch [383 / 2000]: Train Loss=5.9589, Val Cor=0.5689, Time=0.1602 sec
Epoch [384 / 2000]: Train Loss=5.9242, Val Cor=0.6193, Time=0.1604 sec
Epoch [385 / 2000]: Train Loss=6.0324, Val Cor=0.6693, Time=0.1607 sec
Epoch [386 / 2000]: Train Loss=5.9801, Val Cor=0.5867, Time=0.1603 sec
Epoch [387 / 2000]: Train Loss=5.9821, Val Cor=0.6689, Time=0.1615 sec
Epoch [388 / 2000]: Train Loss=6.0762, Val Cor=0.6685, Time=0.1608 sec
Epoch [389 / 2000]: Train Loss=5.9512, Val Cor=0.6733, Time=0.1611 sec
Epoch [390 / 2000]: Train Loss=6.1270, Val Cor=0.6752, Time=0.1612 sec
Epoch [391 / 2000]: Train Loss=5.8947, Val Cor=0.6802, Time=0.1611 sec
Epoch [392 / 2000]: Train Loss=5.8058, Val Cor=0.6654, Time=0.1597 sec
Epoch [393 / 2000]: Train Loss=5.5484, Val Cor=0.6605, Time=0.1598 sec
Epoch [394 / 2000]: Train Loss=6.0811, Val Cor=0.6720, Time=0.1595 sec
Epoch [395 / 2000]: Train Loss=5.9382, Val Cor=0.6752, Time=0.1599 sec
Epoch [396 / 2000]: Train Loss=5.7853, Val Cor=0.6635, Time=0.1595 sec
Epoch 

Epoch [499 / 2000]: Train Loss=5.7748, Val Cor=0.4250, Time=0.1602 sec
Epoch [500 / 2000]: Train Loss=5.7304, Val Cor=0.6139, Time=0.1612 sec
Epoch [501 / 2000]: Train Loss=5.6474, Val Cor=0.6737, Time=0.1611 sec
Epoch [502 / 2000]: Train Loss=5.6117, Val Cor=0.6770, Time=0.1613 sec
Epoch [503 / 2000]: Train Loss=5.8042, Val Cor=0.6123, Time=0.1610 sec
Epoch [504 / 2000]: Train Loss=5.7504, Val Cor=0.6729, Time=0.1610 sec
Epoch [505 / 2000]: Train Loss=5.5027, Val Cor=0.6747, Time=0.1603 sec
Epoch [506 / 2000]: Train Loss=5.4645, Val Cor=0.6212, Time=0.1597 sec
Epoch [507 / 2000]: Train Loss=5.6866, Val Cor=0.6715, Time=0.1601 sec
Epoch [508 / 2000]: Train Loss=5.5661, Val Cor=0.6695, Time=0.1611 sec
Epoch [509 / 2000]: Train Loss=5.6135, Val Cor=-0.0718, Time=0.1612 sec
Epoch [510 / 2000]: Train Loss=5.6400, Val Cor=0.4727, Time=0.1616 sec
Epoch [511 / 2000]: Train Loss=5.8023, Val Cor=0.6705, Time=0.1618 sec
Epoch [512 / 2000]: Train Loss=5.7384, Val Cor=0.6743, Time=0.1615 sec
Epoch

In [96]:
for i in range(5):
    torch.save(RNN_ensemble[i].state_dict(), f'Models/Revised_RNN_ensemble_5_model_{i}_state_dict.pth')

In [21]:
query1 = pd.read_csv("query1.csv")
query1.head()

Unnamed: 0,mutant,DMS_score,sequence
0,S8P,0.805727,MVNEARGNPSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1,S20E,0.777952,MVNEARGNSSLNPCLEGSASEGSESSKDSSRCSTPGLDPERHERLR...
2,S20P,0.473713,MVNEARGNSSLNPCLEGSASPGSESSKDSSRCSTPGLDPERHERLR...
3,D27G,0.428925,MVNEARGNSSLNPCLEGSASSGSESSKGSSRCSTPGLDPERHERLR...
4,S28W,0.8621,MVNEARGNSSLNPCLEGSASSGSESSKDWSRCSTPGLDPERHERLR...


In [22]:
query1_mutants = list(query1["mutant"])
print(query1_mutants)

['S8P', 'S20E', 'S20P', 'D27G', 'S28W', 'S29V', 'S29T', 'R30L', 'R40K', 'R40C', 'R49A', 'R51G', 'D56S', 'K57M', 'S60N', 'S60P', 'R67G', 'S77V', 'S77G', 'D103K', 'K104Q', 'S107P', 'C129S', 'C130Q', 'R131M', 'R131V', 'Q132T', 'H142Q', 'K143V', 'K145P', 'Q146S', 'K150P', 'H180E', 'S183K', 'S183Y', 'D190W', 'C192M', 'H200V', 'H200M', 'S205T', 'H212A', 'H212R', 'K214H', 'S218G', 'S218Y', 'Q259P', 'H262F', 'S263Q', 'R294I', 'Q306V', 'R324G', 'R324I', 'M326S', 'M337L', 'S351R', 'H353K', 'R356G', 'R357M', 'Q378T', 'Q378V', 'S391G', 'S393Y', 'S393Q', 'K400Q', 'K410A', 'S411R', 'K413P', 'K418M', 'M419T', 'M419E', 'S426R', 'C451F', 'K467Q', 'R475L', 'Q476R', 'S483M', 'S483R', 'Q484W', 'S502V', 'Q508F', 'K509S', 'Q526M', 'Q570G', 'Q570A', 'D575W', 'R593Y', 'R593D', 'R605C', 'R605I', 'R605L', 'H612L', 'H612T', 'D628H', 'C630V', 'Q633Y', 'Q633D', 'R644S', 'Q647A', 'R650I', 'R650C']


In [23]:
df_test_mutants = list(df_test["mutant"])
print(df_test_mutants)

['V1D', 'V1Y', 'V1C', 'V1A', 'V1E', 'V1W', 'V1T', 'V1R', 'V1Q', 'V1S', 'V1N', 'V1M', 'V1L', 'V1K', 'V1I', 'V1H', 'V1G', 'V1F', 'V1P', 'N2R', 'N2Y', 'N2W', 'N2V', 'N2T', 'N2S', 'N2Q', 'N2I', 'N2M', 'N2A', 'N2P', 'N2D', 'N2E', 'N2F', 'N2C', 'N2H', 'N2K', 'N2L', 'N2G', 'E3K', 'E3C', 'E3D', 'E3F', 'E3G', 'E3H', 'E3I', 'E3L', 'E3A', 'E3N', 'E3M', 'E3W', 'E3V', 'E3T', 'E3Y', 'E3R', 'E3Q', 'E3P', 'E3S', 'A4M', 'A4F', 'A4G', 'A4I', 'A4K', 'A4L', 'A4N', 'A4S', 'A4Q', 'A4R', 'A4E', 'A4T', 'A4V', 'A4W', 'A4Y', 'A4P', 'A4D', 'A4H', 'A4C', 'R5W', 'R5A', 'R5C', 'R5D', 'R5E', 'R5F', 'R5G', 'R5H', 'R5I', 'R5Y', 'R5L', 'R5M', 'R5N', 'R5P', 'R5Q', 'R5S', 'R5T', 'R5V', 'R5K', 'G6V', 'G6T', 'G6S', 'G6R', 'G6Q', 'G6P', 'G6N', 'G6H', 'G6L', 'G6K', 'G6I', 'G6F', 'G6E', 'G6D', 'G6C', 'G6M', 'G6A', 'G6Y', 'G6W', 'N7A', 'N7W', 'N7V', 'N7T', 'N7S', 'N7R', 'N7Q', 'N7P', 'N7M', 'N7Y', 'N7K', 'N7C', 'N7D', 'N7E', 'N7L', 'N7G', 'N7H', 'N7I', 'N7F', 'S8D', 'S8E', 'S8F', 'S8G', 'S8H', 'S8I', 'S8K', 'S8L', 'S8N', 'S8P'

In [24]:
df_test_no_query1 = list(set(df_test_mutants) - set(query1_mutants))
print(len(df_test_no_query1))

11224


In [25]:
df_test_active_learn = df_test[df_test['mutant'].isin(df_test_no_query1)]
df_test_active_learn

Unnamed: 0,mutant,sequence
0,V1D,MDNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
1,V1Y,MYNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
2,V1C,MCNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
3,V1A,MANEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
4,V1E,MENEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
...,...,...
11319,P655S,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
11320,P655T,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
11321,P655V,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...
11322,P655A,MVNEARGNSSLNPCLEGSASSGSESSKDSSRCSTPGLDPERHERLR...


In [26]:
df_test_active_learn[df_test_active_learn["mutant"] == "S77V"]

Unnamed: 0,mutant,sequence


In [27]:
for i in range(5):
    print(cor_esm(RNN_ensemble[i],val_sets[i]))

0.6451460057437225
0.560067133348446
0.6916470280178059
0.7136465882304719
0.6842313955690316


## Active Learning

In [28]:
def RNN_predict(model, test_dataset):
    model.eval()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        preds = []
        for row in test_dataset:
            mut,esm = row
            esm = esm.to(device)
            #mut = mut.to(device)
            preds.append(model(esm.unsqueeze(0)))
    preds = [float(x) for x in preds]
    return(preds)

In [29]:
#class for creating a dataset to predict with 
class ProteinTestESMDataset(Dataset):
    def __init__(self, mutant_list, emb_dir):
        super().__init__()
        self.mutants = mutant_list
        self.embeddings = []
        for mut in tqdm(mutant_list, desc='Loading esm embeddings'):
            name = mut
            emb_file = os.path.join(emb_dir, f'{name}.pt')
            emb = torch.load(emb_file)['mean_representations'][33]
            self.embeddings.append(emb)

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, index):
        emb = self.embeddings[index]
        mut = self.mutants[index]
        return mut, emb

In [30]:
mutant_list = list(df_test_active_learn["mutant"])
test_active_learn = ProteinTestESMDataset(mutant_list, "esm_embeddings_test")

Loading esm embeddings: 100%|██████████| 11224/11224 [00:34<00:00, 327.10it/s]


In [31]:
ensemble_test_predicts = []
for i in range(5):
    RNN_preds = RNN_predict(RNN_ensemble[i], test_active_learn)
    ensemble_test_predicts.append(RNN_preds)

In [32]:
def calc_mean_and_sd(pred_list):
    mean = []
    sd = []
    for i in range(len(pred_list[0])):
        calc_list = []
        for j in range(len(pred_list)):
            calc_list.append(pred_list[j][i])
        mean.append(np.mean(calc_list))
        sd.append(np.std(calc_list))
    return (mean,sd)

In [33]:
means_RNN_ensemble, sds_RNN_ensemble = calc_mean_and_sd(ensemble_test_predicts)

In [34]:
active_learn_df = pd.DataFrame({"mutant": mutant_list, "RNN_model_1": ensemble_test_predicts[0], 
                                "RNN_model_2": ensemble_test_predicts[1], "RNN_model_3": ensemble_test_predicts[2],
                                "RNN_model_4": ensemble_test_predicts[3], "RNN_model_5": ensemble_test_predicts[4],
                                "Mean_DMS" : means_RNN_ensemble , "SD_DMS" : sds_RNN_ensemble})
active_learn_df

Unnamed: 0,mutant,RNN_model_1,RNN_model_2,RNN_model_3,RNN_model_4,RNN_model_5,Mean_DMS,SD_DMS
0,V1D,0.182941,0.191655,0.351250,0.126397,0.794630,0.329374,0.244356
1,V1Y,0.192862,0.237899,0.527609,0.109419,0.748183,0.363194,0.238479
2,V1C,0.172023,0.155978,0.370862,0.103844,0.769396,0.314421,0.244913
3,V1A,0.206846,0.417425,0.627325,0.179057,0.915782,0.469287,0.275729
4,V1E,0.187640,0.332994,0.480051,0.164003,0.889675,0.410872,0.264907
...,...,...,...,...,...,...,...,...
11219,P655S,0.155549,0.085672,0.147302,0.143114,0.745580,0.255443,0.246312
11220,P655T,0.172705,0.165946,0.235787,0.132652,0.804672,0.302352,0.253365
11221,P655V,0.162603,0.214961,0.343043,0.145074,0.802657,0.333667,0.244524
11222,P655A,0.170669,0.222135,0.392505,0.166648,0.869379,0.364267,0.265582


In [35]:
active_learn_df.to_csv("query2_active_learning.csv")

In [36]:
def active_learn(active_learn_df, beta):
    alpha_list = []
    for i in range(len(active_learn_df)):
        alpha = active_learn_df["Mean_DMS"][i] + (beta**(1/2) * active_learn_df["SD_DMS"][i])
        alpha_list.append(alpha)
    active_learn_df["Alpha_Score"] = alpha_list
    return (active_learn_df)

In [37]:
active_learn_df = active_learn(active_learn_df,4)
active_learn_df
active_learn_df.to_csv("query2_active_learning.csv")

In [127]:
active_learn_df.sort_values("Alpha_Score", ascending=False).head(100)

Unnamed: 0,mutant,RNN_model_1,RNN_model_2,RNN_model_3,RNN_model_4,RNN_model_5,Mean_DMS,SD_DMS,Alpha_Score
10055,R593A,0.470846,0.914777,0.215206,0.232565,0.825227,0.531724,0.291979,1.115683
2342,R131G,0.462670,0.913696,0.224873,0.246490,0.832928,0.536131,0.288726,1.113583
2373,R133P,0.475836,0.913461,0.222046,0.237987,0.822225,0.534311,0.288231,1.110774
88,R5N,0.467331,0.916028,0.221987,0.241014,0.821470,0.533566,0.288521,1.110608
4114,M245A,0.465566,0.914376,0.228136,0.224811,0.818960,0.530369,0.289709,1.109787
...,...,...,...,...,...,...,...,...,...
7671,K467A,0.471034,0.886453,0.205352,0.236715,0.810032,0.521917,0.282857,1.087630
7375,C451E,0.469920,0.888264,0.214475,0.239881,0.812152,0.524938,0.281130,1.087198
5066,R324M,0.353368,0.886423,0.190478,0.221431,0.812049,0.492750,0.297099,1.086948
2367,R133N,0.467440,0.892058,0.217041,0.238809,0.808047,0.524679,0.281035,1.086749


In [128]:
active_learn_df.sort_values("Mean_DMS", ascending=False).head(1000)

Unnamed: 0,mutant,RNN_model_1,RNN_model_2,RNN_model_3,RNN_model_4,RNN_model_5,Mean_DMS,SD_DMS,Alpha_Score
2342,R131G,0.462670,0.913696,0.224873,0.246490,0.832928,0.536131,0.288726,1.113583
8621,R518A,0.477493,0.906072,0.222484,0.244554,0.827644,0.535649,0.285890,1.107429
2373,R133P,0.475836,0.913461,0.222046,0.237987,0.822225,0.534311,0.288231,1.110774
88,R5N,0.467331,0.916028,0.221987,0.241014,0.821470,0.533566,0.288521,1.110608
3357,K198Q,0.477034,0.904346,0.221316,0.244529,0.818463,0.533138,0.283860,1.100857
...,...,...,...,...,...,...,...,...,...
466,S25K,0.448754,0.863414,0.195269,0.224780,0.780522,0.502548,0.276392,1.055332
10624,D624A,0.464275,0.825124,0.200790,0.238139,0.784319,0.502529,0.263015,1.028559
9887,D584N,0.467553,0.832310,0.195092,0.240766,0.776574,0.502459,0.263862,1.030184
4739,K287A,0.461335,0.835180,0.194085,0.237968,0.783714,0.502456,0.267039,1.036534


In [131]:
active_learn_df.sort_values("SD_DMS", ascending=False).head(100)

Unnamed: 0,mutant,RNN_model_1,RNN_model_2,RNN_model_3,RNN_model_4,RNN_model_5,Mean_DMS,SD_DMS,Alpha_Score
634,P34V,0.122858,0.834310,0.150564,0.123836,0.780246,0.402363,0.331203,1.064769
625,P34I,0.126094,0.834952,0.151882,0.117739,0.777460,0.401625,0.331030,1.063685
636,P34L,0.139693,0.848730,0.176091,0.144827,0.786397,0.419148,0.326140,1.071428
626,P34T,0.131012,0.824373,0.152367,0.133915,0.766578,0.401649,0.322160,1.045969
620,P34D,0.112605,0.790250,0.136292,0.111777,0.756968,0.381578,0.320386,1.022350
...,...,...,...,...,...,...,...,...,...
1854,D103P,0.252813,0.833864,0.165538,0.211090,0.786249,0.449911,0.295735,1.041380
10364,Y610K,0.199439,0.755065,0.113587,0.167564,0.765987,0.400329,0.295398,0.991125
7230,R444I,0.304534,0.871377,0.200796,0.220550,0.806391,0.480730,0.295216,1.071161
7086,L436R,0.247674,0.811360,0.154619,0.181429,0.775458,0.434108,0.295146,1.024401


In [132]:
query_mutants = list(active_learn_df.sort_values("Alpha_Score", ascending=False).head(100)["mutant"])
query_mutants_str = "\n".join(query_mutants)
with open("query2_run2_mutant.txt", 'w') as file:
    file.write(query_mutants_str)

In [23]:
test_mutant_list = list(df_test["mutant"])
test = ProteinTestESMDataset(test_mutant_list, "esm_embeddings_test")

Loading esm embeddings: 100%|██████████| 11324/11324 [00:08<00:00, 1367.30it/s]


In [24]:
ensemble_test_predicts_full = []
for i in range(5):
    RNN_preds = RNN_predict(RNN_ensemble[i], test)
    ensemble_test_predicts_full.append(RNN_preds)

In [25]:
best_RNN_preds = RNN_predict(RNN_ensemble[0], test)
best_RNN_ensemble_pred_df = pd.DataFrame({'mutant': test_mutant_list, 'DMS_score_predicted': best_RNN_preds})
best_RNN_ensemble_pred_df

Unnamed: 0,mutant,DMS_score_predicted
0,V1D,0.113374
1,V1Y,0.113965
2,V1C,0.114098
3,V1A,0.122715
4,V1E,0.117648
...,...,...
11319,P655S,0.114440
11320,P655T,0.114683
11321,P655V,0.116050
11322,P655A,0.119691


In [26]:
best_RNN_ensemble_pred_df.to_csv("RNN_red_from_ensemble_predictions.csv",index=False)

In [82]:
len(RNN_preds)

11324

In [98]:
DMS_Score_test, _ = calc_mean_and_sd(ensemble_test_predicts_full)

In [84]:
print(len(DMS_Score_test), len(test_mutant_list))

11324 11324


In [99]:
RNN_ensemble_pred_df = pd.DataFrame({'mutant': test_mutant_list, 'DMS_score_predicted': DMS_Score_test})
RNN_ensemble_pred_df

Unnamed: 0,mutant,DMS_score_predicted
0,V1D,0.168579
1,V1Y,0.186914
2,V1C,0.169218
3,V1A,0.232668
4,V1E,0.202389
...,...,...
11319,P655S,0.150979
11320,P655T,0.163953
11321,P655V,0.187040
11322,P655A,0.208302


In [100]:
RNN_ensemble_pred_df.to_csv("RNN_5_query2_upd_ensemble_predictions.csv",index=False)

In [104]:
top_ten_mutants = RNN_ensemble_pred_df.sort_values('DMS_score_predicted', ascending=False).head(10)
print(RNN_ensemble_pred_df.sort_values('DMS_score_predicted', ascending=False).head(10))
top_ten_list = list(top_ten_mutants["mutant"])
print(top_ten_list)
top_ten_string = "\n".join(top_ten_list)
with open("top10query2_2.txt", 'w') as file:
    file.write(top_ten_string)

      mutant  DMS_score_predicted
556     R30F             0.332197
553     R30Y             0.309540
540     S29F             0.309469
597     S32L             0.308339
601     S32F             0.307381
550     S29Y             0.304780
10142  R593A             0.303232
3390   K198L             0.303015
560     R30L             0.300149
539     S29L             0.299632
['R30F', 'R30Y', 'S29F', 'S32L', 'S32F', 'S29Y', 'R593A', 'K198L', 'R30L', 'S29L']


In [49]:
mutant_list = list(df_test["mutant"])

In [50]:
emb_list = []
for mutant in tqdm(mutant_list, desc='Loading esm embeddings'):
    name = mutant
    emb_file = os.path.join("esm_embeddings_test", f'{name}.pt')
    emb_test = torch.load(emb_file)['mean_representations'][33]
    emb_list.append(emb_test)

Loading esm embeddings: 100%|██████████| 11324/11324 [00:25<00:00, 437.78it/s]


In [52]:
test = ProteinTestESMDataset(mutant_list, "esm_embeddings_test")
test.__getitem__(0)[1]

Loading esm embeddings: 100%|██████████| 11324/11324 [00:13<00:00, 829.59it/s]


tensor([ 0.0269, -0.0786,  0.0247,  ..., -0.1781,  0.0061,  0.1503],
       device='cuda:0')