In [5]:
import pandas as pd
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn.functional as F
from torch import nn
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import esm

In [31]:
# Load the raw, header-less CSV and label its four columns in a single chain.
train = pd.read_csv("./single_test.csv", header=None).set_axis(
    ["wt", "mut", "score", "pos"], axis="columns"
)

In [32]:
# Prefer the second GPU (the original target, cuda:1). Fall back to cuda:0 on
# single-GPU hosts, where 'cuda:1' is an invalid device and `.to(device)` would
# fail. Use the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [33]:
SEED = 42
TRAIN_SIZE = 7000  # NOTE(review): not referenced by any cell shown here

# SEED was previously defined but never applied. Seed numpy and torch
# (torch.manual_seed also seeds every CUDA device) so that stochastic steps,
# such as DataLoader shuffling, repeat identically on Restart & Run All.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [34]:
# Build an independent frame with a fresh 0..n-1 index. A bare
# `train_df = train` only aliases the same object, so any later in-place edit
# of train_df (such as reassigning its index) would silently change `train` too.
train_df = train.reset_index(drop=True)

In [35]:
# Give train_df a clean 0..n-1 index so label-based (.loc) and positional
# (.iloc) row lookups agree. reset_index returns a new frame instead of
# mutating one that may be shared with another variable (e.g. `train`).
train_df = train_df.reset_index(drop=True)

In [36]:
class Protseq(Dataset):
    """Map-style dataset of (wild-type, mutant) sequence pairs tokenised for ESM.

    Each item is ``(wt_tokens, mut_tokens, pos, target)``:
      * ``wt_tokens`` / ``mut_tokens``: token tensors returned by the ESM batch
        converter for a one-sequence batch, so they keep a leading dim of size 1
        (the consuming loop squeezes it after collation).
      * ``pos``: the raw value of the row's "pos" column.
      * ``target``: float32 tensor of shape (1,) holding the row's "score".
    """

    def __init__(self, df, alphabet=None):
        """
        Args:
            df: DataFrame with columns "wt", "mut", "score" and "pos". Rows are
                addressed by position, so any index works.
            alphabet: optional ``esm`` Alphabet to tokenise with. When omitted,
                the "ESM-1b" alphabet is built directly. esm.pretrained pairs
                ESM-2 models with this alphabet (33 tokens, padding_idx 1, as in
                the model repr). This avoids loading the 650M-parameter
                checkpoint only to discard its weights.
        """
        self.df = df
        if alphabet is None:
            alphabet = esm.Alphabet.from_architecture("ESM-1b")
        # Attribute name kept for backward compatibility with existing callers.
        self.esm1v_batch_converter = alphabet.get_batch_converter()

    def _tokenize(self, sequence):
        """Return the token tensor for a single sequence (leading batch dim of 1)."""
        # The converter takes (label, sequence) pairs and returns (labels, strs, tokens).
        _, _, tokens = self.esm1v_batch_converter([('', ''.join(sequence))])
        return tokens

    def __getitem__(self, idx):
        # Positional lookup: does not assume the frame's index is 0..n-1.
        row = self.df.iloc[idx]
        wt = self._tokenize(row["wt"])
        mut = self._tokenize(row["mut"])
        target = torch.FloatTensor([row["score"]])
        return wt, mut, row["pos"], target

    def __len__(self):
        return len(self.df)

In [37]:
# One sequence pair per step. NOTE(review): presumably kept at 1 because each
# Protseq item is a per-sequence token tensor whose length follows the sequence,
# and default DataLoader collation can only stack equal-length tensors.
BATCH_SIZE: int = 1

In [38]:
# Wrap the prepared frame in the tokenising dataset.
train_ds = Protseq(df=train_df)

In [39]:
# Keep dataset order (shuffle=False). The embeddings written by
# retrieve_embeddings carry no row id or label, so line i of the output file
# only maps back to row i of train_df if the loader does not shuffle.
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=False)

In [40]:
HIDDEN_UNITS_POS_CONTACT = 5  # NOTE(review): not used by any cell shown here


class ESM_concat_mut(nn.Module):
    """ESM-2 (t33, 650M) feature extractor for point mutations.

    For each sample, takes the layer-``REPR_LAYER`` embedding at token index
    ``pos`` from the wild-type sequence and from the mutant sequence, and
    concatenates them. Output shape: (batch, 2 * embed_dim).
    """

    # Layer whose per-token representations are read (last layer of the t33 model).
    REPR_LAYER = 33

    def __init__(self):
        super().__init__()
        self.esm2, _ = esm.pretrained.esm2_t33_650M_UR50D()

    def _freeze_esm2_layers(self, num_trainable_blocks=3):
        """Freeze the token embedding and all but the last ``num_trainable_blocks``
        transformer blocks of ``self.esm2``.

        Parameters are selected by name (``embed_tokens.*``, ``layers.<i>.*``)
        rather than by counting positions in ``named_parameters()``. Per the
        module repr, only ``embed_tokens`` (a single weight) precedes ``layers``,
        so the previous fixed offset of 2 spilled one parameter into block 30.
        Parameters registered after the blocks are left trainable, as before.
        """
        num_freeze_blocks = len(self.esm2.layers) - num_trainable_blocks
        for name, param in self.esm2.named_parameters():
            top, _, rest = name.partition('.')
            if top == 'embed_tokens':
                param.requires_grad = False
            elif top == 'layers' and int(rest.split('.', 1)[0]) < num_freeze_blocks:
                param.requires_grad = False

    def _last_layer(self, token_ids):
        """Return the layer-``REPR_LAYER`` representations, shape (batch, seq, dim)."""
        # Call the module (not .forward) so any registered hooks still run.
        out = self.esm2(token_ids, repr_layers=[self.REPR_LAYER])
        return out['representations'][self.REPR_LAYER]

    def forward(self, wt_ids, mut_ids, pos):
        """
        Args:
            wt_ids, mut_ids: token id tensors of shape (batch, seq).
            pos: token index per sample, either a (batch,) tensor or an int
                applied to every row. NOTE(review): the ESM batch converter
                typically prepends a BOS token; confirm whether ``pos`` already
                includes that offset.

        Returns:
            Tensor of shape (batch, 2 * embed_dim): [wt_token_emb, mut_token_emb].
        """
        wt_repr = self._last_layer(wt_ids)
        mut_repr = self._last_layer(mut_ids)
        # Gather one token per row: row i at index pos[i]. Plain
        # `repr[:, pos, :]` would take every pos for every row and produce
        # (batch, batch, dim) whenever batch > 1.
        rows = torch.arange(wt_repr.size(0), device=wt_repr.device)
        pos = torch.as_tensor(pos, device=wt_repr.device).reshape(-1)
        wt_pos = wt_repr[rows, pos]
        mut_pos = mut_repr[rows, pos]
        return torch.cat((wt_pos, mut_pos), dim=1)

In [41]:
# Optimiser learning rate and epoch count (no training loop appears in the cells shown).
lr, EPOCHS = 1e-5, 3

In [42]:
# Build the ESM-2 wrapper, move it to the selected device, and set up Adam over all of its parameters.
model = ESM_concat_mut()
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [43]:
# Show the module tree; the cell's last bare expression is rendered as its output.
model

ESM_concat_mut(
  (esm2): ESM2(
    (embed_tokens): Embedding(33, 1280, padding_idx=1)
    (layers): ModuleList(
      (0-32): 33 x TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (rot_emb): RotaryEmbedding()
        )
        (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=1280, out_features=5120, bias=True)
        (fc2): Linear(in_features=5120, out_features=1280, bias=True)
        (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      )
    )
    (contact_head): ContactPredictionHead(
      (regression): Linear(in_features=660, out_features=1, bias=True)
      (activation): 

In [44]:
def retrieve_embeddings(train_dataloader, model, out_path="./embedings_tensors.txt", mode="w", fmt="%.9g"):
    """Run ``model`` over every batch and write its outputs to a text file.

    Each batch's output is written as comma-separated rows (one row per sample)
    in dataloader order. Rows therefore line up with the dataset only when the
    loader does not shuffle.

    Args:
        train_dataloader: iterable of ``(wt_ids, mut_ids, pos, labels)`` batches.
            The token tensors are expected to carry a singleton dim 1, which is
            squeezed away. ``labels`` is ignored.
        model: module called as ``model(wt_ids, mut_ids, pos)``. Inputs are moved
            to the device of its first parameter (CPU if it has none).
        out_path: destination text file.
        mode: file open mode. ``"w"`` (default) overwrites, so re-running the
            cell does not duplicate rows. Pass ``"a+"`` for the old append behaviour.
        fmt: ``np.savetxt`` number format. Defaults to floating point: the
            outputs are float activations, and the previous ``'%d'`` truncated
            them to integers.
    """
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    was_training = model.training
    model.eval()  # disable any train-only behaviour (e.g. dropout) during extraction
    try:
        # no_grad: pure inference, so skip building autograd graphs.
        with torch.no_grad(), open(out_path, mode) as fw:
            for wt_ids, mut_ids, pos, _labels in tqdm(train_dataloader):
                wt_ids = wt_ids.squeeze(1).to(model_device)
                mut_ids = mut_ids.squeeze(1).to(model_device)
                pos = pos.to(model_device)
                features = model(wt_ids, mut_ids, pos)
                np.savetxt(fw, features.detach().cpu().numpy(), fmt=fmt, delimiter=',')
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

In [45]:
# Extract wt/mut position embeddings for every row and write them to disk.
retrieve_embeddings(train_dataloader=train_dataloader, model=model)

100%|████████████████████████████████████████████████████████████████████████████████████████████████| 492/492 [00:30<00:00, 16.03it/s]
