In [4]:
import pandas as pd
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn.functional as F
from torch import nn
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import esm

In [27]:
# Raw training table: the CSV has no header row, so name the four columns here.
train = pd.read_csv("./train_dataset.csv", header=None).set_axis(
    ["wt", "mut", "score", "pos"], axis=1
)

In [28]:
# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [29]:
# SEED: random_state for the row shuffle; TRAIN_SIZE: rows kept for training.
SEED, TRAIN_SIZE = 42, 7_000

In [30]:
data = train.sample(random_state=SEED, frac=1.0)  # reproducible shuffle of every row

In [31]:
train_df = data.iloc[:TRAIN_SIZE]  # first TRAIN_SIZE shuffled rows, all columns


In [32]:
# Renumber rows 0..n-1 so label-based lookups by position work downstream.
train_df.index = np.arange(len(train_df))

In [33]:
class Protseq(Dataset):
    """Map-style dataset of (wild-type, mutant) sequence pairs with a score.

    Each item is ``(wt_tokens, mut_tokens, pos, target)``:

    * ``wt_tokens`` / ``mut_tokens`` -- the token tensor the ESM batch
      converter returns for a batch containing just that one sequence
      (presumably shape ``(1, L)``; ``retrieve_embeddings`` squeezes dim 1).
    * ``pos`` -- the raw value of the ``pos`` column.
    * ``target`` -- a 1-element float tensor holding ``score``.

    Args:
        df: DataFrame with columns ``wt``, ``mut``, ``score`` and ``pos``.
            Rows are addressed by position, so any index works.
        alphabet: optional ESM ``Alphabet`` (e.g. the one returned next to the
            model by ``esm.pretrained.esm2_t33_650M_UR50D()``). By default the
            ``"ESM-1b"`` alphabet -- the one the ESM-2 loaders attach -- is
            built directly, instead of loading the full 650M-parameter model
            just to obtain its tokenizer.
    """

    def __init__(self, df, alphabet=None):
        self.df = df
        if alphabet is None:
            alphabet = esm.Alphabet.from_architecture("ESM-1b")
        self.esm1v_batch_converter = alphabet.get_batch_converter()

    def _tokenize(self, sequence):
        """Tokenize one sequence as a batch of one and return only the tokens."""
        # The converter takes a list of (label, sequence) pairs and returns
        # (labels, strings, tokens); labels are irrelevant here.
        _, _, tokens = self.esm1v_batch_converter([("", sequence)])
        return tokens

    def __getitem__(self, idx):
        # Positional access: works even when df's index is not 0..n-1
        # (label-based .loc would raise or pick the wrong row).
        wt = self._tokenize(self.df["wt"].iloc[idx])
        mut = self._tokenize(self.df["mut"].iloc[idx])
        pos = self.df["pos"].iloc[idx]
        target = torch.FloatTensor([self.df["score"].iloc[idx]])
        return wt, mut, pos, target

    def __len__(self):
        return len(self.df)

In [34]:
# NOTE(review): presumably kept at 1 because token tensors of different
# lengths cannot be stacked by the default collate -- confirm before raising.
BATCH_SIZE = 1

In [35]:
train_ds = Protseq(df=train_df)  # tokenizes lazily, one row per __getitem__

In [36]:
train_dataloader = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE)

In [37]:
HIDDEN_UNITS_POS_CONTACT = 5


class ESM_concat_mut(nn.Module):
    """Score a mutation from ESM-2 hidden states at the mutated position.

    The wild-type and mutant token batches are each run through the
    ``esm2_t33_650M_UR50D`` backbone. The layer-``REPR_LAYER`` hidden state at
    ``pos`` is taken from each; the two vectors are concatenated
    (``2 * EMBED_DIM`` features) and passed through a small ReLU MLP
    (``fc1`` -> ``fc2``) that outputs one value per example.

    Args:
        freeze_esm2: if True, freeze the token embedding and all but the last
            ``n_trainable_blocks`` transformer blocks of the backbone.
            Defaults to False (whole backbone trainable).
        n_trainable_blocks: number of final transformer blocks left trainable
            when ``freeze_esm2`` is True.
    """

    # Layer whose hidden states are read out (the last of the 33 layers).
    REPR_LAYER = 33
    # Width of each ESM-2 hidden-state vector fed into the head.
    EMBED_DIM = 1280

    def __init__(self, freeze_esm2=False, n_trainable_blocks=3):
        super().__init__()
        self.esm2, _ = esm.pretrained.esm2_t33_650M_UR50D()
        self.fc1 = nn.Linear(self.EMBED_DIM * 2, HIDDEN_UNITS_POS_CONTACT)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_CONTACT, 1)
        if freeze_esm2:
            self._freeze_esm2_layers(n_trainable_blocks)

    def _freeze_esm2_layers(self, n_trainable_blocks=3):
        """Freeze ``embed_tokens`` and every ``layers.i`` except the last ``n_trainable_blocks``.

        Parameters are selected by name rather than by counting positions in
        ``named_parameters()``, so the split cannot drift into a neighbouring
        block. Modules registered after ``layers`` stay trainable.
        """
        n_blocks = len(self.esm2.layers)
        n_frozen = max(n_blocks - n_trainable_blocks, 0)
        # Trailing "." keeps "layers.1." from also matching "layers.10.".
        frozen_prefixes = ("embed_tokens.",) + tuple(
            f"layers.{i}." for i in range(n_frozen)
        )
        for name, param in self.esm2.named_parameters():
            if name.startswith(frozen_prefixes):
                param.requires_grad = False

    def forward(self, wt_ids, mut_ids, pos):
        """Return one logit per example.

        Args:
            wt_ids, mut_ids: token tensors, presumably ``(batch, seq_len)`` as
                produced by the ESM batch converter -- TODO confirm for other
                callers.
            pos: token index to read per example (int, 0-d tensor or shape
                ``(batch,)``). NOTE(review): ESM converters usually prepend a
                BOS token, so a 0-based residue index may need ``+ 1`` --
                confirm how ``pos`` is defined in the CSV.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        wt_pos = self._gather_position(self._final_hidden(wt_ids), pos)
        mut_pos = self._gather_position(self._final_hidden(mut_ids), pos)
        pos_concat = torch.cat((wt_pos, mut_pos), dim=1)
        fc1_outputs = F.relu(self.fc1(pos_concat))
        return self.fc2(fc1_outputs)

    def _final_hidden(self, token_ids):
        """Run the backbone and return the hidden states of ``REPR_LAYER``."""
        # Call the module (not .forward) so any registered hooks still run.
        outputs = self.esm2(token_ids, repr_layers=[self.REPR_LAYER])
        return outputs["representations"][self.REPR_LAYER]

    @staticmethod
    def _gather_position(hidden, pos):
        """Select ``hidden[b, pos[b], :]`` for every example ``b`` -> ``(batch, dim)``.

        Plain ``hidden[:, pos, :]`` pairs every example with every index,
        yielding ``(batch, batch, dim)`` once batch > 1; pairing each row with
        its own index keeps one vector per example.
        """
        index = torch.as_tensor(pos, dtype=torch.long, device=hidden.device).reshape(-1)
        rows = torch.arange(hidden.shape[0], device=hidden.device)
        return hidden[rows, index]

In [38]:
# Adam learning rate, and epoch count (NOTE(review): EPOCHS is not used in the cells shown).
lr, EPOCHS = 1e-5, 3

In [39]:
model = ESM_concat_mut()
model = model.to(device)  # nn.Module.to moves in place and returns the same module
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [40]:
# Display the module tree: the ESM-2 backbone (`esm2`) plus the `fc1`/`fc2` head.
model

ESM_concat_mut(
  (esm2): ESM2(
    (embed_tokens): Embedding(33, 1280, padding_idx=1)
    (layers): ModuleList(
      (0): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
          (rot_emb): RotaryEmbedding()
        )
        (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=1280, out_features=5120, bias=True)
        (fc2): Linear(in_features=5120, out_features=1280, bias=True)
        (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      )
      (1): TransformerLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
   

In [43]:
def retrieve_embeddings(train_dataloader, model, device=None):
    """Run ``model`` over every batch of ``train_dataloader`` and return the outputs.

    Despite the name, what is collected is the model's forward output (for
    ``ESM_concat_mut``, one logit per example), not raw ESM embeddings.
    The pass runs in eval mode under ``torch.no_grad()``; the model's previous
    train/eval mode is restored afterwards.

    Args:
        train_dataloader: iterable of ``(wt_ids, mut_ids, pos, labels)``
            batches. Dim 1 of the token tensors (presumably the batch-of-one
            dim from the ESM batch converter) is squeezed out. ``labels`` is
            not used.
        model: module called as ``model(wt_ids, mut_ids, pos)``.
        device: device the inputs are moved to. Defaults to the device of the
            model's first parameter, or CPU if it has none.

    Returns:
        All batch outputs concatenated along dim 0, on the CPU (an empty
        tensor if the dataloader yields nothing).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    outputs = []
    try:
        # Pure inference: skipping autograd avoids keeping activation graphs
        # for the whole backbone alive at every step.
        with torch.no_grad():
            for wt_ids, mut_ids, pos, _labels in tqdm(train_dataloader):
                # `.to()` returns a new tensor rather than moving in place, so
                # the result must be reassigned for inputs to reach `device`.
                wt_ids = wt_ids.squeeze(1).to(device)
                mut_ids = mut_ids.squeeze(1).to(device)
                pos = torch.as_tensor(pos).to(device)
                outputs.append(model(wt_ids, mut_ids, pos).cpu())
    finally:
        model.train(was_training)
    return torch.cat(outputs) if outputs else torch.empty(0)

In [44]:
retrieve_embeddings(model=model, train_dataloader=train_dataloader)

  1%|▎                                     | 48/7000 [03:11<7:42:27,  3.99s/it]


KeyboardInterrupt: 