In [9]:
from transformers import T5Tokenizer, T5EncoderModel
import re

import pandas as pd
import numpy as np
import Bio
from Bio import SeqIO
import os
import torch
import math
import esm

from torch import nn
from torch.utils.data import Dataset, DataLoader
import scipy
from scipy import stats
import torch.nn.functional as F
from torch.cuda.amp import autocast
torch.cuda.empty_cache()

import warnings
warnings.filterwarnings("ignore")

HIDDEN_UNITS_POS_CONTACT = 5
#device = torch.device("cpu")
device = torch.device("cuda:0") #if torch.cuda.is_available() else "cpu")

In [10]:
class ProstT5_mut(nn.Module):

    def __init__(self):
        super().__init__() 
        self.prostt5 = T5EncoderModel.from_pretrained("Rostlab/ProstT5") 
        self.classifier = nn.Linear(1024, 1)
        self.const1 = torch.nn.Parameter(torch.ones((1,1024)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1024)))
        

    def forward(self, token_ids1, token_ids2, pos): 
      #  with torch.no_grad():
        outputs1 = self.prostt5.forward(token_ids1).last_hidden_state  
        outputs2 = self.prostt5.forward(token_ids2).last_hidden_state
        outputs = self.const1 * outputs1[:,pos + 1,:] + self.const2 * outputs2[:,pos + 1,:]        
        logits = self.classifier(outputs)
        return logits

In [11]:
class ProteinDataset(Dataset):
    def __init__(self, df):
        self.df = df
        self.tokenizer = T5Tokenizer.from_pretrained('Rostlab/ProstT5', do_lower_case=False)

    def __getitem__(self, idx):
        wild_seq = [self.df.iloc[idx]['wild_type'], self.df.iloc[idx]['structure']]
        wild_seq = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in wild_seq]
        wild_seq = [ "<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s for s in wild_seq]
        prostt5_batch_tokens1 = self.tokenizer.batch_encode_plus(wild_seq, add_special_tokens=True, padding="longest", return_tensors='pt')
        
        mut_seq = [self.df.iloc[idx]['mutated'], self.df.iloc[idx]['structure']]
        mut_seq = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in mut_seq]
        mut_seq = [ "<AA2fold>" + " " + s if s.isupper() else "<fold2AA>" + " " + s for s in mut_seq]
        prostt5_batch_tokens2 = self.tokenizer.batch_encode_plus(mut_seq, add_special_tokens=True, padding="longest", return_tensors='pt')
        
        pos = self.df.iloc[idx]['pos']
        return prostt5_batch_tokens1, prostt5_batch_tokens2, pos, torch.unsqueeze(torch.FloatTensor([self.df.iloc[idx]['ddg']]), 0)

    def __len__(self):
        return len(self.df)

In [12]:
def train(epoch):
    scaler = torch.cuda.amp.GradScaler()
    tr_loss, tr_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    tr_preds, tr_labels = [], []
    model.train()

    for idx, batch in enumerate(training_loader):
        
        input_ids1, input_ids2, pos, labels = batch 
        input_ids1 = input_ids1['input_ids'].to(device)[0] 
        input_ids2 = input_ids2['input_ids'].to(device)[0] 
        labels = labels.to(device)
        pos = pos.to(device)
        
        with autocast():
            
            logits = model(token_ids1 = input_ids1, token_ids2 = input_ids2, pos = pos).to(device) 
            loss = torch.nn.functional.mse_loss(logits, labels)

        tr_loss += loss.item()
        nb_tr_steps += 1
        nb_tr_examples += labels.size(0)

        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=0.1)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scheduler.step()
        scaler.update()
        
    epoch_loss = tr_loss / nb_tr_steps
    print(f"Training loss epoch: {epoch_loss}")
lr = 1e-5

In [13]:
lr = 1e-5
EPOCHS = 3

models = ['ProstT5_mut']

full_df = pd.read_csv('training_set.csv',sep=',')

preds = {n:[] for n in models} 
true = [None] 

for model_name in models:
    model_class = globals()[model_name]
    print(f'Training model {model_name}')
    train_df = full_df
    train_ds = ProteinDataset(train_df)
        
    model = model_class()    
    model.to(device) 
    
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    training_loader = DataLoader(train_ds, batch_size=1, num_workers = 2, shuffle = True)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(training_loader), epochs=EPOCHS)
        
    for epoch in range(EPOCHS):
        print(epoch)
        train(epoch)
         
    model.to('cpu') 
        
    torch.save(model.state_dict(), 'weights/' + model_name)
    
    del model

Training model ProstT5_mut


Some weights of the model checkpoint at Rostlab/ProstT5 were not used when initializing T5EncoderModel: ['decoder.block.17.layer.0.SelfAttention.k.weight', 'decoder.block.22.layer.2.DenseReluDense.wo.weight', 'decoder.block.4.layer.0.layer_norm.weight', 'decoder.block.8.layer.2.layer_norm.weight', 'decoder.block.7.layer.2.DenseReluDense.wi.weight', 'decoder.block.14.layer.0.SelfAttention.o.weight', 'decoder.block.15.layer.1.EncDecAttention.q.weight', 'decoder.block.5.layer.1.EncDecAttention.k.weight', 'decoder.block.16.layer.1.EncDecAttention.v.weight', 'decoder.block.22.layer.1.EncDecAttention.q.weight', 'decoder.block.18.layer.1.EncDecAttention.o.weight', 'decoder.block.0.layer.2.DenseReluDense.wo.weight', 'decoder.block.13.layer.0.SelfAttention.q.weight', 'decoder.block.9.layer.1.EncDecAttention.k.weight', 'decoder.block.3.layer.2.layer_norm.weight', 'decoder.block.10.layer.2.DenseReluDense.wo.weight', 'decoder.block.11.layer.1.EncDecAttention.v.weight', 'decoder.block.12.layer.1.En

Some weights of T5EncoderModel were not initialized from the model checkpoint at Rostlab/ProstT5 and are newly initialized: ['encoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


0
Training loss epoch: 4.744966924772095
1
Training loss epoch: 4.599168683252638
2
Training loss epoch: 4.5276364176634765
