# SEMA-3D

**SEMA-3D** is a fine-tuned ESM-IF1 model aimed at predicting epitope residues from tertiary structures.

The MIT License (MIT)

Copyright (c) 2016 AYLIEN

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

### Fine-tuning ESM-IF1 for epitope prediction tasks

In [1]:
import os
# Runtime environment, configured before torch is imported in the next cell:
#   TORCH_HOME           -- cache directory for weights downloaded via torch.hub
#   CUDA_VISIBLE_DEVICES -- expose only GPU index 1 to this process
os.environ.update({
    'TORCH_HOME': "../torch_hub",
    'CUDA_VISIBLE_DEVICES': "1",
})

In [3]:
import scipy
import sklearn
import esm
import pickle

import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset
from torch import nn
import math

import transformers
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import Trainer, TrainingArguments, EvalPrediction
from transformers import EsmTokenizer, EsmForMaskedLM

from esm.pretrained import load_model_and_alphabet_hub

from sklearn.metrics import r2_score, mean_squared_error

from saprot_utils.foldseek_util import get_struc_seq
from pathlib import Path

In [4]:
def esmStructDataset(pdb_path, chain, foldseek_path="saprot_utils/bin/foldseek"):
    '''
    Extract the SaProt sequences of one chain from a PDB file.

        Parameters:
            pdb_path (Path): path to pdb-file
            chain (str): identifier of the chain whose sequences are returned
            foldseek_path (str): path to the foldseek binary handed to get_struc_seq
                (defaults to the location that used to be hard-coded here)
        Returns:
            dict (dict): {"seq", "foldseek_seq", "combined_seq"} -- the three
                sequences get_struc_seq reports for ``chain``
    '''
    seq, foldseek_seq, combined_seq = get_struc_seq(foldseek_path, pdb_path)[chain]
    return {
        "seq": seq,
        "foldseek_seq": foldseek_seq,
        "combined_seq": combined_seq,
    }

In [5]:
def create_Dataset(path='../data/pdb_structures.pkl', pdb_dir="../data/structs_antigen_fab/"):
    '''
    Load the pickled structure dataset and attach SaProt sequences to it.

    For every ``*.pdb`` file in ``pdb_dir``, the entry whose key is the file name
    without its ``.pdb`` extension gets its 'combined_seq' and 'foldseek_seq'
    values (re)computed by esmStructDataset, using that entry's 'chain'.

        Parameters:
            path (Path): path to the pickle file with the dataset of tertiary
                structures. Only load pickles from trusted sources -- unpickling
                can execute arbitrary code.
            pdb_dir (Path): directory scanned for ``*.pdb`` files
        Returns:
            esm_structs (dict): the loaded dataset, updated in place
        Raises:
            KeyError: if a PDB file has no matching entry (or the entry has no 'chain')
    '''
    with open(path, 'rb') as fh:  # closed even if unpickling fails
        esm_structs = pickle.load(fh)

    # snapshot the directory listing before doing any work on the files
    pdb_paths = list(Path(pdb_dir).glob("*.pdb"))
    for pdb_path in pdb_paths:
        pdb_id = pdb_path.name.split(".pdb")[0]
        entry = esm_structs[pdb_id]
        parsed = esmStructDataset(pdb_path, entry['chain'])
        entry['combined_seq'] = parsed['combined_seq']
        entry['foldseek_seq'] = parsed['foldseek_seq']
    return esm_structs

In [6]:
def create_Dataset(path='../data/pdb_structures.pkl', pdb_dir="../data/structs_antigen_fab/"):
    '''
    Load the pickled structure dataset and attach SaProt sequences to it.

    For every ``*.pdb`` file in ``pdb_dir``, the entry whose key is the file name
    without its ``.pdb`` extension gets its 'combined_seq' and 'foldseek_seq'
    values (re)computed by esmStructDataset, using that entry's 'chain'.

        Parameters:
            path (Path): path to the pickle file with the dataset of tertiary
                structures. Only load pickles from trusted sources -- unpickling
                can execute arbitrary code.
            pdb_dir (Path): directory scanned for ``*.pdb`` files
        Returns:
            esm_structs (dict): the loaded dataset, updated in place
        Raises:
            KeyError: if a PDB file has no matching entry (or the entry has no 'chain')
    '''
    with open(path, 'rb') as fh:  # closed even if unpickling fails
        esm_structs = pickle.load(fh)

    # snapshot the directory listing before doing any work on the files
    pdb_paths = list(Path(pdb_dir).glob("*.pdb"))
    for pdb_path in pdb_paths:
        pdb_id = pdb_path.name.split(".pdb")[0]
        entry = esm_structs[pdb_id]
        parsed = esmStructDataset(pdb_path, entry['chain'])
        entry['combined_seq'] = parsed['combined_seq']
        entry['foldseek_seq'] = parsed['foldseek_seq']
    return esm_structs

In [7]:
class PDB_Dataset(Dataset):
    """
    Map-style dataset that turns a pandas data frame into SaProt model inputs.

    Each item is a dict with
        'token_ids' -- tokenizer ids (return_tensors="pt", shape (1, n_tokens)) of
                       the first ``2 * max_residues`` characters of the combined
                       sequence, i.e. ``max_residues`` residues at two characters each
        'labels'    -- the first ``max_residues`` label values as a torch.long
                       tensor of shape (1, n_residues)

    Attributes:
        df (pandas.DataFrame): frame whose columns are read by position:
                0 -- combined (structure-aware) sequence string, two characters per residue
                1 -- list of contact number values, e.g. [0, 1, 3, -100, 2]
        tokenizer (EsmTokenizer): tokenizer loaded from ``model_path``
        label_type (str): type of model: regression or binary (stored, not used here)
        max_residues (int): number of residues kept per protein
    """
    def __init__(self, df, label_type='regression',
                 model_path="westlake-repl/SaProt_650M_PDB", max_residues=1022):
        """
        Construct all the necessary attributes of the PDB_Dataset object.

        Parameters:
            df (pandas.DataFrame): see the class docstring
            label_type (str): type of model: regression or binary
            model_path (str): Hugging Face id or local path of the SaProt tokenizer
            max_residues (int): residues kept per protein (tokens are truncated to
                twice this many characters so tokens and labels stay aligned)
        """
        self.tokenizer = EsmTokenizer.from_pretrained(model_path)
        self.df = df
        self.label_type = label_type
        self.max_residues = max_residues

    def _encode_labels(self, labels):
        """Truncate ``labels`` to ``max_residues`` values; return a (1, n) long tensor."""
        # NOTE(review): the long dtype truncates fractional contact numbers toward
        # zero, as the original torch.LongTensor call did -- confirm integer labels
        # are intended.
        values = np.asarray(labels[:self.max_residues])
        return torch.as_tensor(values, dtype=torch.long).unsqueeze(0)

    def __getitem__(self, idx):
        combined_seq = self.df.iloc[idx, 0][:2 * self.max_residues]
        encoded = self.tokenizer(combined_seq, padding=False, return_tensors="pt")
        return {
            'token_ids': encoded['input_ids'],
            'labels': self._encode_labels(self.df.iloc[idx, 1]),
        }

    def __len__(self):
        return len(self.df)

In [8]:
class SaProtForTokenClassification(nn.Module):
    """
    Per-residue head on top of SaProt's masked-language-model output.

    The linear classifier is applied to the masked-LM logits (one score per
    vocabulary token at each position), not to the encoder hidden states.
    """

    def __init__(self, num_labels=2, model_path="westlake-repl/SaProt_650M_PDB"):
        """
        Parameters:
            num_labels (int): number of output channels per position
            model_path (str): Hugging Face id or local path of the SaProt checkpoint
        """
        super().__init__()
        self.num_labels = num_labels
        self.encoder = EsmForMaskedLM.from_pretrained(model_path)
        # The MLM logits have one entry per vocabulary token, so the head's input
        # size is the vocabulary size (446 for SaProt_650M_PDB); read it from the
        # config instead of hard-coding it.
        self.classifier = nn.Linear(self.encoder.config.vocab_size, self.num_labels)

    def forward(self, token_ids, labels=None):
        """
        Return per-position logits of shape (batch, seq_len - 2, num_labels).

        ``labels`` is accepted (the Trainer passes every input key) but unused;
        the loss is computed by the trainer.
        """
        lm_logits = self.encoder(input_ids=token_ids)['logits']
        # drop the first and last positions -- presumably the special tokens the
        # tokenizer adds around the sequence; TODO confirm for the SaProt tokenizer
        per_residue = lm_logits[:, 1:-1, :]
        return SequenceClassifierOutput(logits=self.classifier(per_residue))

In [9]:
def compute_metrics_regr(p: EvalPrediction):
    '''
    Compute regression metrics for per-residue predictions.

    Channel 1 of ``p.predictions`` is compared against ``log(1 + label)``. Positions
    whose label is not greater than -1 (e.g. the -100 padding value) are excluded.

        Parameters:
            p (EvalPrediction): ``predictions`` indexed as (batch, position, channel),
                ``label_ids`` as (batch, position)
        Returns:
            dict: "pearson_r", "mse" and "r2_score" over the scored positions. When
                fewer than two positions are scored, all three are NaN instead of
                raising (scipy.stats.pearsonr needs at least two points).
    '''
    # explicit import: `import scipy` alone does not guarantee scipy.stats is loaded
    from scipy import stats

    preds = np.asarray(p.predictions)[:, :, 1]
    n_batch, n_pos = preds.shape
    # align the labels with the prediction grid (the old index loop did the same)
    label_ids = np.asarray(p.label_ids)[:n_batch, :n_pos]
    scored = label_ids > -1

    y_true = np.log1p(label_ids[scored].astype(np.float64))
    y_pred = preds[scored]

    if y_true.size < 2:
        nan = float("nan")
        return {"pearson_r": nan, "mse": nan, "r2_score": nan}

    return {
        "pearson_r": stats.pearsonr(y_true, y_pred)[0],
        "mse": mean_squared_error(y_true, y_pred),
        "r2_score": r2_score(y_true, y_pred),
    }


In [10]:
def model_init():
    """Build a fresh SaProtForTokenClassification and move it to the GPU."""
    model = SaProtForTokenClassification()
    return model.cuda()

In [11]:
class MaskedMSELoss(torch.nn.Module):
    '''
    Mean squared error restricted to masked-in positions.

    Only channel 1 of ``inputs`` (``inputs[:, :, 1]``) is scored. That channel,
    ``target`` and ``mask`` must contain the same number of elements once
    flattened; ``mask`` is treated as boolean (non-zero = scored).
    '''
    def __init__(self):
        super(MaskedMSELoss, self).__init__()

    def forward(self, inputs, target, mask):
        '''
        Return the mean squared error over the scored positions.

        Unscored positions are removed before subtracting, so a non-finite
        prediction there cannot turn the loss into NaN (0 * inf). If nothing is
        scored, a zero that is still attached to the autograd graph is returned.
        '''
        keep = torch.flatten(mask).bool()
        pred = torch.flatten(inputs[:, :, 1])[keep]
        gold = torch.flatten(target)[keep]
        sq_err = (pred - gold) ** 2
        if sq_err.numel() == 0:
            return sq_err.sum()
        return sq_err.mean()

    
class MaskedRegressTrainer(Trainer):
    '''
    Trainer that regresses channel 1 of the model logits onto log(1 + label).

    Label positions equal to -100 are ignored by the loss.
    '''
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        '''
        Compute the masked MSE loss for one batch.

            Parameters:
                model: the model being trained; called with ``**inputs``
                inputs (dict): batch with 'token_ids' and 'labels' tensors
                return_outputs (bool): also return the model outputs
                **kwargs: absorbs extra arguments newer transformers releases pass
                    (e.g. ``num_items_in_batch``); unused
            Returns:
                the loss tensor, or ``(loss, outputs)`` if ``return_outputs``
        '''
        # reshape instead of squeeze: a length-1 sequence must stay a (1, 1) tensor
        labels = inputs["labels"].reshape(1, -1).float()
        masks = labels != -100
        # log1p only where scored; the -100 placeholders are kept (and masked out)
        safe = torch.where(masks, labels, torch.zeros_like(labels))
        targets = torch.where(masks, torch.log1p(safe), labels)

        outputs = model(**inputs)
        loss = MaskedMSELoss()(outputs.logits, targets, masks)
        return (loss, outputs) if return_outputs else loss
        
def collator_fn(x):
    '''
    Collate function for a batch size of 1.

    A single-item batch is unwrapped to that item; any other batch is returned
    unchanged as the list of items (no stacking or padding is performed).
    '''
    return x[0] if len(x) == 1 else x


## Data

In [13]:
# structure dataset keyed by pdb_id_chain, with SaProt sequences attached
esm_structs = create_Dataset('../data/pdb_structures_old.pkl')

In [14]:
n_structures = len(esm_structs)
print('Number of structures:', n_structures)

Number of structures: 15719


In [15]:
# One row per chain; every per-residue column collapsed into a list.
per_residue_cols = ['resi_pos', 'resi_aa', 'resi_name', 'contact_number', 'contact_number_binary']
train_set = (
    pd.read_csv('../data/train_set.csv')
    .groupby('pdb_id_chain')
    .agg({col: list for col in per_residue_cols})
    .reset_index()
)

print('Init number of rows:', train_set.shape[0])

# Attach each chain's combined sequence (KeyError if a chain is missing from esm_structs).
train_set['combined_seq'] = train_set['pdb_id_chain'].map(
    lambda pdb_id: esm_structs[pdb_id]['combined_seq'])

# The residue letters sit at every second position of combined_seq; chains whose
# residue count disagrees with the residue table are dropped.
n_table = train_set['resi_aa'].map(lambda letters: len(''.join(letters)))
n_struct = train_set['combined_seq'].map(lambda seq: len(seq[::2]))
to_drop = train_set.loc[n_table != n_struct, 'pdb_id_chain'].tolist()

drop_rows = train_set.pdb_id_chain.isin(to_drop)
print('number of pdbs to drop:', len(to_drop))
print('number of rows to drop:', train_set[drop_rows].shape[0])
train_set = train_set[~drop_rows]
print('End number of rows:', train_set.shape[0])


train_ds = PDB_Dataset(
    train_set[['combined_seq', 'contact_number', 'pdb_id_chain']],
    label_type='regression',
)

Init number of rows: 783
number of pdbs to drop: 92
number of rows to drop: 92
End number of rows: 691


In [15]:
# pdb = '7KBT_A_FE'
# # train_set[train_set.pdb_id_chain == '5LDN_A_LH'].combined_seq.str.len()
# labels = train_set[train_set.pdb_id_chain == pdb].contact_number.item()
# s = ''.join(train_set[train_set.pdb_id_chain == pdb].resi_aa.item())
# n_s = train_set[train_set.pdb_id_chain == pdb].combined_seq.item()
# n_ss = n_s[::2]
# print(s)
# print(n_ss)
# print(len(s), len(n_ss), len(labels))

# # # train_set

In [16]:
# One row per chain; every per-residue column collapsed into a list.
per_residue_cols = ['resi_pos', 'resi_aa', 'resi_name', 'contact_number', 'contact_number_binary']
test_set = (
    pd.read_csv('../data/test_set.csv')
    .groupby('pdb_id_chain')
    .agg({col: list for col in per_residue_cols})
    .reset_index()
)

print('Init number of rows:', test_set.shape[0])

# Attach each chain's combined sequence (KeyError if a chain is missing from esm_structs).
test_set['combined_seq'] = test_set['pdb_id_chain'].map(
    lambda pdb_id: esm_structs[pdb_id]['combined_seq'])

# The residue letters sit at every second position of combined_seq; chains whose
# residue count disagrees with the residue table are dropped.
n_table = test_set['resi_aa'].map(lambda letters: len(''.join(letters)))
n_struct = test_set['combined_seq'].map(lambda seq: len(seq[::2]))
to_drop = test_set.loc[n_table != n_struct, 'pdb_id_chain'].tolist()

drop_rows = test_set.pdb_id_chain.isin(to_drop)
print('number of pdbs to drop:', len(to_drop))
print('number of rows to drop:', test_set[drop_rows].shape[0])
test_set = test_set[~drop_rows]
print('End number of rows:', test_set.shape[0])

test_ds = PDB_Dataset(
    test_set[['combined_seq', 'contact_number', 'pdb_id_chain']],
    label_type='regression',
)

Init number of rows: 101
number of pdbs to drop: 9
number of rows to drop: 9
End number of rows: 92


## Model training and test 

In [None]:
training_args = TrainingArguments(
    output_dir='./results_fold',       # output directory
    num_train_epochs=2,                # total number of training epochs
    per_device_train_batch_size=1,     # collator_fn only unwraps single-item batches
    per_device_eval_batch_size=1,      # same constraint for evaluation
    warmup_steps=20,                   # number of warmup steps for learning rate scheduler
    learning_rate=1e-05,               # learning rate
    weight_decay=0.0,                  # strength of weight decay
    logging_dir='./logs',              # directory for storing logs
    logging_steps=200,                 # how often to print logs
    save_strategy="no",                # no intermediate checkpoints
    do_train=True,                     # perform training
    do_eval=True,                      # evaluation runs every epoch (evaluation_strategy below)
    evaluation_strategy="epoch",       # evaluate after each epoch
    gradient_accumulation_steps=1,     # forward/backward passes accumulated per optimizer step
    fp16=False,                        # mixed precision disabled
    run_name="PDB_regr",               # experiment name
    seed=42,                           # seed for experiment reproducibility
    load_best_model_at_end=False,
    # compute_metrics_regr returns "r2_score", which the Trainer logs as
    # "eval_r2_score"; the previous "eval_r2" matched no logged metric.
    metric_for_best_model="eval_r2_score",
    greater_is_better=True,
)

In [18]:
#create direactory to weights storage
# Create the directory that stores the fine-tuned weights. exist_ok avoids the
# check-then-create race; a non-directory at that path now fails here instead of
# later at torch.save.
os.makedirs("../models/", exist_ok=True)

In [19]:
# Release unused cached GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [21]:
# Fine-tune on train_ds, evaluating on test_ds according to training_args.
model = model_init()
trainer = MaskedRegressTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=collator_fn,
    compute_metrics=compute_metrics_regr,
)
trainer.train()

# Persist only the parameters (state_dict) of the fine-tuned model.
weights_path = "../models/sema_saprot_old_0.pth"
torch.save(trainer.model.state_dict(), weights_path)

Some weights of EsmForMaskedLM were not initialized from the model checkpoint at westlake-repl/SaProt_650M_PDB and are newly initialized: ['esm.contact_head.regression.bias', 'esm.embeddings.position_embeddings.weight', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  item['labels'] = torch.unsqueeze(torch.LongTensor(self.df.iloc[idx, 1][:1022]),0)


Epoch,Training Loss,Validation Loss,Pearson R,Mse,R2 Score
1,0.2058,0.185335,0.217532,0.172641,-6.27433
2,0.1763,0.098038,0.227606,0.083671,-2.525531
3,0.1678,0.115015,0.24079,0.101489,-3.276303


  item['labels'] = torch.unsqueeze(torch.LongTensor(self.df.iloc[idx, 1][:1022]),0)
  item['labels'] = torch.unsqueeze(torch.LongTensor(self.df.iloc[idx, 1][:1022]),0)
