In [61]:
import pandas as pd
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn.functional as F
from torch import nn
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import gc
from scipy import stats

In [2]:
# !pip install fair-esm

In [64]:
single = pd.read_csv("single_train.csv", header=None)
#multiple = pd.read_csv("train_dataset_multiple.csv", header=None)

In [36]:
def process_change(item):
    length = item.count('_')
    condition = length == 0
    return condition
    

In [39]:
#subset_multiple = multiple[multiple[3].apply(lambda x: process_change(x))]
#data = pd.concat([single, subset_multiple])

In [65]:
single.columns = ["wt", "mut", "score", "pos"]

In [66]:
import esm

In [67]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [74]:
SEED = 42
TRAIN_SIZE = 7000

In [75]:
data = single.sample(frac=1.0, random_state=SEED)

In [76]:
data

Unnamed: 0,wt,mut,score,pos
6021,GSSKTQYEYDTKEEAQKAYEKFKKQGIPVTITQKNGKWFVQVE,GSSKTQYEEDTKEEAQKAYEKFKKQGIPVTITQKNGKWFVQVE,0.41,6
2127,TTIKVNGQEYTVPLSPEQAAKAAKKRWPDYEVQIHGNTVKVTR,TTIKVNGQEYTVPLSPEQAAKAAKKRWPDYEVQIHGSTVKVTR,0.02,37
8473,RKWEEIAERLREEFNINPEEAREAVEKAGGNEEEARRIVKKRL,RKWEEIAERLREEFNINPEEAREAVEKAGGNEEEARRIVKKVL,0.29,42
6191,TIDEIIKALEQAVKDNKPIQVGNYTVTSADEAEKLAKKLKKEY,TIDEIIKALEQAVKDNKPIQVGNYTVTSADEAEKLAKKLKKIY,0.15,42
5382,GSSTTRYRFTDEEEARRAAKEWARRGYQVHVTQNGTYWEVEVR,GSSTTRYRFTDEEEARRAAKEWARRGYQVHVTQNGTYWEVEGR,0.14,39
...,...,...,...,...
5734,GSSKTQYEYDTKEEAQKAYEKFKKQGIPVTITQKNGKWFVQVE,GSSKTQYEYDTKEEAQKAYEKFKSQGIPVTITQKNGKWFVQVE,0.28,21
5191,GSSTTRYRFTDEEEARRAAKEWARRGYQVHVTQNGTYWEVEVR,GSSTTRYRFTDEEEARSAAKEWARRGYQVHVTQNGTYWEVEVR,0.03,14
5390,GSSTTRYRFTDEEEARRAAKEWARRGYQVHVTQNGTYWEVEVR,GSSTTRYRFTDEEEARRAAKEWARRGYQVHVTQNGTYWEVESR,0.04,39
860,SKDEAQREAERAIRSGNKEEARRILEEAGYSPEQAERIIRKLG,SKDEAQREAERALRSGNKEEARRILEEAGYSPEQAERIIRKLG,0.23,13


In [77]:
train_df = data.iloc[:TRAIN_SIZE, :]
valid_df = data.iloc[TRAIN_SIZE:, :]

In [78]:
train_df.index = np.arange(0, len(train_df))
valid_df.index = np.arange(0, len(valid_df))

In [79]:
class Protseq(Dataset):
    def __init__(self, df):
        self.df = df
        _, esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm1v_batch_converter = esm1v_alphabet.get_batch_converter()

        
    def __getitem__(self, idx):
        _, _, wt = self.esm1v_batch_converter([('' , ''.join(self.df.loc[idx, "wt"]))])
        _, _, mut = self.esm1v_batch_converter([('' , ''.join(self.df.loc[idx, "mut"]))])
        pos = self.df.loc[idx, "pos"]
        target = torch.FloatTensor([self.df.loc[idx, "score"]])
        return wt, mut, target

    
    def __len__(self):
        return len(self.df)

In [80]:
BATCH_SIZE = 180

In [86]:
train_ds = Protseq(train_df)
valid_ds = Protseq(valid_df)

In [87]:
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)

In [88]:
HIDDEN_UNITS_POS_CONTACT = 5
class ESM_sum_seqembed(nn.Module):
    def __init__(self):
        super().__init__()
        self.esm2, _ = esm.pretrained.esm2_t33_650M_UR50D()
        self.fc1 = nn.Linear(1280, HIDDEN_UNITS_POS_CONTACT)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_CONTACT, 1)
    
    
    def _freeze_esm2_layers(self):
        total_blocks = 33
        initial_layers = 2
        layers_per_block = 16
        num_freeze_blocks = total_blocks - 3
        for _, param in list(self.esm2.named_parameters())[
            :initial_layers + layers_per_block * num_freeze_blocks]:
            param.requires_grad = False
            

    def forward(self, wt_ids, mut_ids):
        outputs1 = self.esm2.forward(wt_ids, repr_layers=[33])[
            'representations'][33]
        outputs2 = self.esm2.forward(mut_ids, repr_layers=[33])[
            'representations'][33]
        outputs1_mean = outputs1.mean(1)
        outputs2_mean = outputs2.mean(1)
        add = outputs1_mean + outputs2_mean
        fc1_outputs = F.relu(self.fc1(add))
        logits = self.fc2(fc1_outputs)
        return logits

In [84]:
def train_eval(train_dataloader, valid_dataloader, model, epochs):
    for _ in range(epochs):
        model.train()
        tr_loss = 0
        for batch in tqdm(train_dataloader):
            wt_ids, mut_ids, labels = batch
            wt_ids = wt_ids.squeeze(1).to(device)
            mut_ids = mut_ids.squeeze(1).to(device)
            labels = labels.to(device)
            logits = model(wt_ids, mut_ids)
            loss = torch.nn.functional.mse_loss(logits, labels)
            tr_loss += loss.item()
        
            torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=0.1
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        tr_loss = tr_loss / len(train_dataloader)
        print(f"Training loss epoch: {tr_loss}")
        
        model.eval()
        valid_loss = 0
        with torch.no_grad():
            for batch in tqdm(valid_dataloader):
                wt_ids, mut_ids, labels = batch
                wt_ids = wt_ids.squeeze(1).to(device)
                mut_ids = mut_ids.squeeze(1).to(device)
                labels = labels.to(device)
                logits = model(wt_ids, mut_ids)
                loss = torch.nn.functional.mse_loss(logits, labels)
                valid_loss += loss.item()

        valid_loss = valid_loss / len(valid_dataloader)
        print(f"Valid loss epoch: {valid_loss}")
    return model

In [89]:
lr = 1e-5
EPOCHS = 3

In [90]:
model = ESM_sum_seqembed()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [91]:
model._freeze_esm2_layers()

In [92]:
model = train_eval(train_dataloader, valid_dataloader, model, EPOCHS)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:12<00:00,  1.87s/it]


Training loss epoch: 0.07396647811700137


100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:15<00:00,  1.71s/it]


Valid loss epoch: 0.06479279614157146


100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:15<00:00,  1.93s/it]


Training loss epoch: 0.06639366262616256


100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:15<00:00,  1.75s/it]


Valid loss epoch: 0.060778735412491694


100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [01:18<00:00,  2.01s/it]


Training loss epoch: 0.06367301606597045


100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:16<00:00,  1.80s/it]

Valid loss epoch: 0.0594661393099361





In [93]:
def collect_pred(model, valid_dataloader):
    model.eval()
    y_pred = []
    y_true = []
    for batch in tqdm(valid_dataloader):
        wt_ids, mut_ids, labels = batch
        wt_ids = wt_ids.squeeze(1).to(device)
        mut_ids = mut_ids.squeeze(1).to(device)
        labels = labels.to(device)
        labels = labels.detach().cpu().numpy()
        logits = model(wt_ids, mut_ids)
        logits = logits.detach().cpu().numpy()
        y_pred.append(logits)
        y_true.append(labels)
    y_pred = np.concatenate(y_pred).reshape(-1, )
    y_true = np.concatenate(y_true).reshape(-1, )
    return y_pred, y_true

In [94]:
y_pred, y_true = collect_pred(model, valid_dataloader)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:15<00:00,  1.73s/it]


In [95]:
print(stats.spearmanr(y_pred, y_true))
print(stats.pearsonr(y_pred, y_true))

SignificanceResult(statistic=0.5363510912165107, pvalue=2.30541164644201e-115)
PearsonRResult(statistic=0.45927509673088246, pvalue=3.777181433178789e-81)


#### Test data evaluation

In [96]:
single_test = pd.read_csv('./single_test.csv', header=None)
multiple_test = pd.read_csv('./multiple_test.csv', header=None)

In [97]:
#test = pd.concat([single_test, multiple_test])
single_test.columns = ["wt", "mut", "score", "pos"]

In [98]:
single_test.index = np.arange(0, len(single_test))

In [99]:
test_ds = Protseq(single_test)

In [100]:
test_dataloader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [101]:
y_pred, y_true = collect_pred(model, test_dataloader)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.60s/it]


In [102]:
print(stats.pearsonr(y_pred, y_true))
print(stats.spearmanr(y_pred, y_true))

PearsonRResult(statistic=0.46352433648499825, pvalue=1.4129390081740676e-27)
SignificanceResult(statistic=0.472110825370124, pvalue=1.1167362818397197e-28)
