# SEMA-1D

## Requirements 

In [None]:
import scipy
import sklearn
import math

import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset
from torch import nn

import transformers
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import Trainer, TrainingArguments, EvalPrediction

import esm
from esm.pretrained import load_model_and_alphabet_hub

from sklearn.metrics import r2_score, mean_squared_error

from tqdm import tqdm

## Set model

In [None]:
class ESM1vForTokenClassification(nn.Module):
    """ESM-1v backbone followed by a linear head applied at every position.

    Args:
        num_labels: Output width of the linear head.
        pretrained_no: Suffix appended to ``"esm1v_t33_650M_UR90S_"`` to form
            the name of the pretrained checkpoint that is loaded.
    """

    def __init__(self, num_labels = 2, pretrained_no = 1):
        super().__init__()
        self.num_labels = num_labels
        self.model_name = f"esm1v_t33_650M_UR90S_{pretrained_no!s}"

        backbone, alphabet = load_model_and_alphabet_hub(self.model_name)
        self.esm1v = backbone
        self.esm1v_alphabet = alphabet

        # The head's input width is hard-coded to 1280; this is assumed to be
        # the size of the layer-33 representation -- TODO confirm.
        self.classifier = nn.Linear(1280, self.num_labels)

    def forward(self, token_ids):
        """Return per-position logits for a batch of token ids.

        The layer-33 representation is requested from the backbone, the first
        and last positions along dimension 1 are dropped (presumably the
        special tokens the batch converter adds around each sequence), and the
        linear head is applied to what remains.

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` field holds the
            output of the linear head.
        """
        backbone_out = self.esm1v.forward(token_ids, repr_layers=[33])
        hidden_states = backbone_out['representations'][33]
        inner_positions = hidden_states[:, 1:-1, :]

        return SequenceClassifierOutput(logits=self.classifier(inner_positions))

In [None]:
class PDB_bin_Dataset(Dataset):
    """Dataset yielding tokenised sequences and per-position labels.

    Column 0 of ``df`` is read as the sequence and column 1 as the labels;
    both are truncated to at most 1022 entries.

    Args:
        df: Two-column table; rows are accessed positionally via ``df.iloc``.
        label_type: Stored on the instance as ``label_type``; not otherwise
            used by this class.
    """

    def __init__(self, df, label_type ='regression'):
        self.df = df
        # Only the tokeniser is needed here. Building the alphabet directly
        # avoids loading the full pretrained model just to throw it away.
        # "roberta_large" is the architecture name used by the ESM-1v
        # checkpoints, so the resulting alphabet is the same one the
        # pretrained loader would return.
        esm1v_alphabet = esm.Alphabet.from_architecture("roberta_large")
        self.esm1v_batch_converter = esm1v_alphabet.get_batch_converter()
        self.label_type = label_type

    def __getitem__(self, idx):
        """Return a dict with ``'token_ids'`` and ``'labels'`` for row ``idx``."""
        sequence = ''.join(self.df.iloc[idx, 0])[:1022]
        # The batch converter takes (label, sequence) pairs; the label is unused.
        _, _, batch_tokens = self.esm1v_batch_converter([('', sequence)])

        item = {}
        item['token_ids'] = batch_tokens
        # Add a leading dimension of size 1 to the label tensor.
        item['labels'] = torch.unsqueeze(torch.LongTensor(self.df.iloc[idx, 1][:1022]), 0)

        return item

    def __len__(self):
        """Return the number of rows in the underlying table."""
        return len(self.df)

In [None]:
# Build the classifier on top of pretrained checkpoint number 1, then move it
# to the GPU.
model = ESM1vForTokenClassification(pretrained_no=1)
model = model.cuda()

## Prediction

In [None]:
# Load the test table and wrap the two columns the dataset reads positionally:
# column 0 ('fullseq') as the sequence, column 1 ('cn') as the labels.
test_set = pd.read_csv(filepath_or_buffer='data/sema_1d_test_set.csv')
test_ds = PDB_bin_Dataset(df=test_set[['fullseq', 'cn']])

In [None]:
# Run every ensemble member over the whole test set.
# After the loop ``res`` holds one entry per checkpoint; each entry is a list
# with one prediction array per test sequence.
res = []
for ens_idx in range(5):

    model.load_state_dict(torch.load(f'sema_1d_epitopes_{str(ens_idx)}.pth'))
    model.eval()
    model.cuda()

    # Inference must happen inside the loop: otherwise all five checkpoints
    # are loaded but only the last one ever produces predictions.
    with torch.no_grad():
        preds = []
        for it in tqdm(test_ds):
            # [0] selects the logits from the model output, the next [0] the
            # single batch entry, and [:, 1] column 1 of the last dimension.
            preds.append(model.forward(it['token_ids'].cuda())[0][0][:,1].cpu().numpy())
    res.append(preds)

In [None]:
# Average the ensemble one sequence at a time. ``res`` holds, for every
# ensemble member, a list with one prediction array per test sequence.
# Sequences can differ in length, so the members are combined per sequence
# instead of stacking everything into a single rectangular array.
res = [
    np.mean(np.stack(member_preds), axis=0).tolist()
    for member_preds in zip(*res)
]
# NOTE(review): the previous code read an undefined name (``plists``) and
# unpacked every entry as ``p, cn``, which cannot work on a per-position score
# list. The averaged scores are stored directly -- confirm this is the
# intended content of the 'pred' column.
test_set['pred'] = res