# Imports and parameters

In [5]:
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch
import re
import gc

from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import trange, tqdm_notebook
from transformers import  BertForTokenClassification, AutoTokenizer
from torch.nn import BCELoss
from torch.optim import AdamW
from torch import nn

In [6]:
batch_size = 64
test_size = 0.05
learning_rate = 2e-5
epochs = 20
validation_size=100
model_name ='Rostlab/prot_bert'

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name)

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForTokenClassification: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the

# Loading and preparing Data Set

In [8]:
df = pd.read_csv("./data/binding_affinity_hla_cleaned.csv", index_col=0)
df.dropna(inplace=True)

In [9]:
hla_group_count = len(df['HLA_group_idx'].unique())
hla_gene_count = len(df['HLA_gene'].unique())

In [10]:
df['HLA_group_one_hot_encode'] = pd.get_dummies(df['HLA_group_idx']).values.tolist()
df['HLA_gene_one_hot_encode'] =  pd.get_dummies(df['HLA_gene']).values.tolist()

In [11]:
df.head()

Unnamed: 0,MHC_sequence,MHC_type,peptide_sequence,label,HLA_gene,HLA_allele,HLA_allele_group,HLA_allele_id,HLA_group_idx,HLA_group_one_hot_encode,HLA_gene_one_hot_encode
0,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFHTSVSRPGRGEPRF...,HLA-B*27:05,ERLKEVQKR,1,B,27:05,27,5,0,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 0]"
1,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFHTSVSRPGRGEPRF...,HLA-B*27:05,KPRKTAEVAGKTL,1,B,27:05,27,5,0,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 0]"
2,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFHTSVSRPGRGEPRF...,HLA-B*27:05,KEARRIIKK,1,B,27:05,27,5,0,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 0]"
3,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFHTSVSRPGRGEPRF...,HLA-B*27:05,EEKITEAKEL,0,B,27:05,27,5,0,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 0]"
4,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFHTSVSRPGRGEPRF...,HLA-B*27:05,SLPSSRAARVPG,0,B,27:05,27,5,0,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 0]"


In [12]:
def clean_for_tokenizer(s: str):
    return " ".join(list(re.sub(r"[UZOB]", "X", s)))

def extract_for_tokenizer(row):
    return [clean_for_tokenizer(row["MHC_sequence"]), clean_for_tokenizer(row["peptide_sequence"])]

In [13]:
extracted_for_tokenizer_data = df.apply(axis=1, func=extract_for_tokenizer)

In [14]:
extracted_for_tokenizer_data = extracted_for_tokenizer_data.tolist()

In [15]:
train_data, test_data, train_label, test_label, train_group_encode, test_group_encode, train_gene_encode, test_gene_encode = train_test_split(extracted_for_tokenizer_data, df['label'].tolist(), df['HLA_group_one_hot_encode'].tolist(), df['HLA_gene_one_hot_encode'].tolist(), test_size=test_size, random_state=42)

In [32]:
del df
gc.collect()
tokenized_train_data = tokenizer.batch_encode_plus(train_data, add_special_tokens=True, padding=True, return_tensors='pt')
del train_data
gc.collect()
train_label_tensors = torch.tensor(train_label, dtype=torch.float)
del train_label
gc.collect()
train_group_encode_tensors = torch.tensor(train_group_encode)
del train_group_encode
gc.collect()
train_gene_encode_tensors = torch.tensor(train_gene_encode)
del train_gene_encode
gc.collect()

0

In [16]:
class MHCPeptideDataSet(Dataset):
    def __init__(self, data, labels, allele_group_one_hot, hla_gene_one_hot):
        self.data = data
        self.data['allele_group_encoding'] = allele_group_one_hot
        self.data['hla_gene_encoding'] = hla_gene_one_hot
        self.labels = labels
        self.size = self.labels.shape[0]

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        if index < 0 or index >= self.__len__():
            raise IndexError(f"Index {index} is out of bounds for dataset with length {self.__len__()}")
        item = {key: value[index] for (key, value) in self.data.items()}
        item['labels'] = self.labels[index]
        return item

In [34]:
dataset = MHCPeptideDataSet(tokenized_train_data, train_label_tensors, allele_group_one_hot=train_group_encode_tensors, hla_gene_one_hot=train_gene_encode_tensors)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# Creating Model

In [17]:
bert_model.to('cuda')
class BERTClassification(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = bert_model.bert
        self.head = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(int(self.bert.config.hidden_size + hla_group_count + hla_gene_count), 512),
            nn.ReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )
        self.freeze()
        # Unfreeze the last layer of the BERT model

    def freeze(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_layer(self, layer: int):
        for param in self.bert.encoder.layer[layer].parameters():
            param.requires_grad = True


    def get_special_params(self):
        return list(self.head.parameters()) + list(self.bert.parameters())

    def forward(self, allele_group_encoding, hla_gene_encoding, *args, **kwargs):
        outputs = self.bert(*args, **kwargs)
        return self.head(torch.cat((outputs[0][:, 0, :], allele_group_encoding, hla_gene_encoding), dim=1))

model = BERTClassification().to('cuda')

## WarmUp round

In [18]:
optimizer = AdamW(model.get_special_params(), lr=learning_rate)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=1, threshold=1e-2)
criterion = BCELoss()

In [19]:
tmp = torch.load('./output/model_state')
model.load_state_dict(tmp['model_state_dict'])
optimizer.load_state_dict(tmp['optimizer_state_dict'])
scheduler.load_state_dict(tmp['scheduler_state_dict'])

In [38]:
model.unfreeze_layer(-1)

In [21]:
def set_lr(lr:float):
    for g in optimizer.param_groups:
        g['lr'] = lr

In [23]:
set_lr(5e-6)

In [24]:
for epoch in trange(epochs):
    running_loss = 0.0
    model.train()
    for batch in  tqdm_notebook(dataloader):
        batch_db = {key: value.to('cuda') for key, value in batch.items() if key != 'labels'}
        batch_labels = batch['labels'].to('cuda')
        optimizer.zero_grad()

        # Forward
        with torch.set_grad_enabled(True):
            output = model(**batch_db)
            loss = criterion(output.view(-1), batch_labels)
            loss.backward()
            optimizer.step()

        # Statistics
        running_loss += loss.item() * batch_labels.size(0)
    epoch_loss = running_loss / len(dataloader)
    print('epoch: {} Loss: {:.4f}'.format(epoch, epoch_loss))

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/19612 [00:00<?, ?it/s]

epoch: 0 Loss: 12.2929


  0%|          | 0/19612 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [25]:
torch.save(
    {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict()
    },
    './output/model_state'
)

# Test

In [25]:
gc.collect()
tokenized_test_data = tokenizer.batch_encode_plus(test_data, add_special_tokens=True, padding=True, return_tensors='pt')
del test_data
gc.collect()
test_label_tensors = torch.tensor(test_label, dtype=torch.float)
del test_label
gc.collect()
test_group_encode_tensors = torch.tensor(test_group_encode)
del test_group_encode
gc.collect()
test_gene_encode_tensors = torch.tensor(test_gene_encode)
del test_gene_encode
gc.collect()

0

In [26]:
model.eval()
test_dataset = MHCPeptideDataSet(tokenized_test_data, test_label_tensors, allele_group_one_hot=test_group_encode_tensors, hla_gene_one_hot=test_gene_encode_tensors)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

In [27]:
model_outputs = []
actual_labels = []

In [28]:
model.eval()
for batch in  tqdm_notebook(test_dataloader):
    batch_db = {key: value.to('cuda') for key, value in batch.items() if key != 'labels'}
    batch_labels = batch['labels'].to('cuda')

    # Forward
    with torch.no_grad():
        test_output = model(**batch_db)
        model_outputs += [i[0] for i in test_output.tolist()]
        actual_labels += batch_labels.tolist()

  0%|          | 0/1401 [00:00<?, ?it/s]

In [None]:
fine_tuned_bert = model.bert


In [29]:
df_test_model = pd.DataFrame(data= {
    "model_output": model_outputs,
    "labels": actual_labels,
})

In [30]:
df_test_model

Unnamed: 0,model_output,labels
0,0.106708,1.0
1,0.008202,0.0
2,0.281923,0.0
3,0.048741,0.0
4,0.014424,0.0
...,...,...
89649,0.003637,0.0
89650,0.028282,0.0
89651,0.045913,0.0
89652,0.977217,1.0


In [31]:
df_test_model.to_csv("./output/checkpoint_model_bert.csv")