In [45]:
import pandas as pd
import os
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn.functional as F
from torch import nn
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [None]:
!pip install fair-esm

In [58]:
single = pd.read_csv("/kaggle/input/hackaton/train_dataset.csv", header=None)
multiple = pd.read_csv("/kaggle/input/hackaton/train_dataset_multiple.csv", header=None)

In [63]:
data = pd.concat([single, multiple])

In [64]:
data.columns = ["wt", "mut", "score", "pos"]

In [65]:
import esm

In [66]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [69]:
SEED = 42
TRAIN_SIZE = 20000

In [70]:
data = data.sample(frac=1.0, random_state=SEED)

In [71]:
train_df = data.iloc[:TRAIN_SIZE, :]
valid_df = data.iloc[TRAIN_SIZE:, :]

In [72]:
train_df.index = np.arange(0, len(train_df))
valid_df.index = np.arange(0, len(valid_df))

In [73]:
class Protseq(Dataset):
    def __init__(self, df):
        self.df = df
        _, esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm1v_batch_converter = esm1v_alphabet.get_batch_converter()

        
    def __getitem__(self, idx):
        _, _, wt = self.esm1v_batch_converter([('' , ''.join(self.df.loc[idx, "wt"]))])
        _, _, mut = self.esm1v_batch_converter([('' , ''.join(self.df.loc[idx, "mut"]))])
        pos = self.df.loc[idx, "pos"]
        target = torch.FloatTensor([self.df.loc[idx, "score"]])
        return wt, mut, pos, target

    
    def __len__(self):
        return len(self.df)

In [74]:
BATCH_SIZE = 180

In [75]:
train_ds = Protseq(train_df)
valid_ds = Protseq(valid_df)

In [76]:
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False)


In [92]:
HIDDEN_UNITS_POS_CONTACT = 5
class ESM_sum_seqembed(nn.Module):
    def __init__(self):
        super().__init__()
        self.esm2, _ = esm.pretrained.esm2_t33_650M_UR50D()
        self.fc1 = nn.Linear(1280, HIDDEN_UNITS_POS_CONTACT)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_CONTACT, 1)
    
    
    def _freeze_esm2_layers(self):
        total_blocks = 33
        initial_layers = 2
        layers_per_block = 16
        num_freeze_blocks = total_blocks - 3
        for _, param in list(self.esm2.named_parameters())[
            :initial_layers + layers_per_block * num_freeze_blocks]:
            param.requires_grad = False
            

    def forward(self, wt_ids, mut_ids):
        outputs1 = self.esm2.forward(wt_ids, repr_layers=[33])[
            'representations'][33]
        outputs2 = self.esm2.forward(mut_ids, repr_layers=[33])[
            'representations'][33]
        outputs1_mean = outputs1.mean(1)
        outputs2_mean = outputs2.mean(1)
        add = outputs1_mean + outputs2_mean
        fc1_outputs = F.relu(self.fc1(add))
        logits = self.fc2(fc1_outputs)
        return logits

In [90]:
def train_2(train_dataloader, model, epochs):
    model.train()
    for _ in range(epochs):
        tr_loss = 0
        for batch in tqdm(train_dataloader):
            wt_ids, mut_ids, _, labels = batch
            wt_ids = wt_ids.squeeze(1).to(device)
            mut_ids = mut_ids.squeeze(1).to(device)
            labels = labels.to(device)
            logits = model(wt_ids, mut_ids)
            loss = torch.nn.functional.mse_loss(logits, labels)
            tr_loss += loss.item()
        
            torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=0.1
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        epoch_loss = tr_loss / len(train_dataloader)
        print(f"Training loss epoch: {epoch_loss}")

In [86]:
lr = 1e-5
EPOCHS = 3

In [93]:
model = ESM_sum_seqembed().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [94]:
train_2(train_dataloader, model, EPOCHS)

  0%|          | 0/112 [00:02<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 160.00 MiB (GPU 0; 14.76 GiB total capacity; 13.69 GiB already allocated; 17.75 MiB free; 14.03 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [43]:
y_pred, y_true = valid(model, valid_dataloader)

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 14.76 GiB total capacity; 13.73 GiB already allocated; 3.75 MiB free; 14.04 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF