# SEMA-3D

**SEMA-3D** is a fine-tuned ESM-IF1 model that predicts epitope residues from protein tertiary structures.

The MIT License (MIT)

Copyright (c) 2016 AYLIEN

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

### Finetuning ESM-IF1 for epitope prediction tasks

In [1]:
import os

# Runtime configuration, set before torch / esm are imported in the next cell:
#   TORCH_HOME           - torch.hub cache directory; the pretrained ESM-IF1 weights are downloaded here
#   CUDA_VISIBLE_DEVICES - expose only GPU 0 to this process
os.environ.update({
    'TORCH_HOME': "../torch_hub",
    'CUDA_VISIBLE_DEVICES': "0",
})

In [2]:
import copy
import math
import json
import scipy
import pickle

import pandas as pd
import numpy as np
from pathlib import Path


import esm
from esm.data import BatchConverter
from esm.inverse_folding.util import CoordBatchConverter

import torch
from torch.utils.data import Dataset
from torch import nn
from tqdm import tqdm

import transformers
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import Trainer, TrainingArguments, EvalPrediction
import scipy

from biotite.structure.residues import get_residues

import sklearn
from sklearn.metrics import r2_score, mean_squared_error, auc, PrecisionRecallDisplay, precision_recall_curve
# from scikitplot import plot_precision_recall

from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def esmStructDataset(pdb_path):
    '''
    Convert a single-chain PDB file into the dataset record format.

        Parameters:
            pdb_path (Path or str): path to a PDB file named "<pdb_id>_<chain>.pdb"
        Returns:
            dict (dict): record with keys "pdb_id", "seq", "chain", "coords",
                "cn" (placeholder contact numbers, all None),
                "binary" (placeholder binary labels, all None) and
                "residues" (one (residue name, residue number) tuple per residue);
                None if the file does not exist
    '''
    # Accept plain strings as well as Path objects
    pdb_path = Path(pdb_path)
    entity = pdb_path.name.split(".pdb")[0]
    pdb_id, chain = entity.split("_")
    if not pdb_path.exists():
        # f-string: "missing " + Path would raise TypeError
        print(f"missing {pdb_path}")
        return None
    try:
        structure = esm.inverse_folding.util.load_structure(str(pdb_path), chain.upper())
    except Exception:
        # Retry with a lower-case chain identifier
        structure = esm.inverse_folding.util.load_structure(str(pdb_path), chain.lower())

    # Call get_residues once: [0] holds residue numbers, [1] residue names
    residues = get_residues(structure)
    resi_keys = [(str(resi_aa), resi_index) for resi_index, resi_aa in zip(residues[0], residues[1])]
    # Labels are attached later; start with placeholders
    cn = [None] * len(resi_keys)
    binary = [None] * len(resi_keys)

    coords, seq = esm.inverse_folding.util.extract_coords_from_structure(structure)
    return {"pdb_id": pdb_id, "seq": seq, "chain": chain, "coords": coords,
            "cn": cn,  # contact numbers
            "binary": binary, "residues": resi_keys}

In [52]:
def create_Dataset(path='../data/pdb_structures.pkl', pdb_dir='../data/structs_antigen_fab/'):
    '''
    Create the dataset of protein tertiary structures, or load it from a pickle cache.

    If ``path`` exists it is unpickled and returned. Otherwise every ``*.pdb`` file in
    ``pdb_dir`` is converted with ``esmStructDataset``, and the result is pickled to ``path``.

        Parameters:
            path (str or Path): path to the pickle cache of structures
            pdb_dir (str or Path): directory of "<pdb_id>_<chain>.pdb" files; only used
                when the cache does not exist yet
        Returns:
            esm_structs (dict): mapping "<pdb_id>_<chain>" -> structure record
    '''
    if os.path.exists(path):
        # NOTE: pickle.load can execute arbitrary code - only load caches you produced yourself
        with open(path, 'rb') as fh:
            return pickle.load(fh)

    esm_structs = {}
    for pdb_path in Path(pdb_dir).glob("*.pdb"):
        esm_structs[pdb_path.name.split(".pdb")[0]] = esmStructDataset(pdb_path)
    with open(path, 'wb') as fh:
        pickle.dump(esm_structs, fh)
    return esm_structs

In [33]:
class epitopes_Dataset(Dataset):
    '''
    Torch dataset in which each item is one protein encoded for ESM-IF1.

    Each record must provide "coords", "seq" and "cn" (per-residue contact numbers;
    -100 marks residues without a label).
    '''
    def __init__(self, epitope_data):
        self.epitope_data = epitope_data
        # Only the alphabet is needed; the model returned alongside it is discarded
        _, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
        self.batch_converter = CoordBatchConverter(alphabet)

    def __len__(self):
        return len(self.epitope_data)

    def __getitem__(self, idx):
        record = self.epitope_data[idx]
        seq = record["seq"]
        # Encode a one-protein batch of (coords, confidence=None, sequence)
        coords, confidence, _strs, tokens, padding_mask = self.batch_converter(
            [(record["coords"], None, seq)]
        )
        return {
            'seq': seq,
            'coords': coords,
            'confidence': confidence,
            'tokens': tokens,
            'padding_mask': padding_mask,
            # regression targets with a leading batch dimension, as float32
            'labels_cn': torch.tensor(record["cn"]).unsqueeze(0).to(torch.float32),
            # placeholder: binary labels are not used by the masked regression loss
            'labels_binary': torch.tensor([0]),
        }


def _residue_key(name, pos):
    '''Normalise a (residue name, residue number) pair so str / int / numpy numbers compare equal.'''
    try:
        pos = int(pos)
    except (TypeError, ValueError):
        # e.g. None, or numbers carrying an insertion code such as "52A"
        pos = str(pos)
    return (str(name), pos)


def prepareEsmDataset(dataset, structs_data, max_len=1500):
    '''
    Attach per-residue contact-number labels to cached structures and wrap them in an epitopes_Dataset.

        Parameters:
            dataset (dict): records (as produced by load_datasets) with "pdb_id_chain",
                "contact_number", "resi_name" and "resi_pos"
            structs_data (dict): structure records keyed by "<pdb_id>_<chain>"; not modified
            max_len (int): structures whose sequence is longer than this are skipped
        Returns:
            epitopes_Dataset: dataset over the labelled structures
        Raises:
            AssertionError: if a record has no structure, or none of its structure
                residues received a label
    '''
    dataset_esm = []
    bad_prots = []
    for entity in dataset.values():
        pdb_id = entity["pdb_id_chain"]
        if pdb_id not in structs_data:
            raise AssertionError(f"{pdb_id}: no structure found in structs_data")

        # Proteins whose contact_number has a single unique value are counted as invalid and skipped
        if len(np.unique(entity['contact_number'])) == 1:
            bad_prots.append(pdb_id)
            continue

        # Shallow copy so the labels are not written back into the shared structure cache
        struct_data = dict(structs_data[pdb_id])
        if len(struct_data["seq"]) > max_len:
            print("Skip long  ", pdb_id, len(struct_data["seq"]))
            continue

        # Normalised keys: the JSON may store residue numbers as strings, the structure as ints
        key_map = {
            _residue_key(name, pos): i
            for i, (name, pos) in enumerate(zip(entity['resi_name'], entity['resi_pos']))
        }
        contact_number = entity['contact_number']
        labels = []
        for residue in struct_data["residues"]:
            idx = key_map.get(_residue_key(residue[0], residue[1]))
            # -100 marks residues without a label
            labels.append(-100 if idx is None else contact_number[idx])
        struct_data["cn"] = labels

        if not any(c != -100 for c in labels):
            raise AssertionError(
                f"{pdb_id}: no structure residue matched a labelled residue\n"
                f" key_map={key_map}\n residues={struct_data['residues']}\n cn={labels}"
            )
        dataset_esm.append(struct_data)

    print(f'Number of proteins with invalid contact_number: {len(bad_prots)}')

    return epitopes_Dataset(dataset_esm)

In [34]:
class ESM1vForTokenClassification(nn.Module):
    '''
    ESM-IF1 backbone with a linear per-residue head.

    The class keeps its historical "ESM1v" name, but the backbone is
    ``esm.pretrained.esm_if1_gvp4_t16_142M_UR50``. Channel 1 of the output is used
    as the contact-number regression by the loss and metrics.
    '''
    def __init__(self, num_labels = 2):
        super().__init__()
        self.num_labels = num_labels
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
        # 512 = feature width produced by the backbone decoder (NOTE(review): confirm against model config)
        self.classifier = nn.Linear(512, self.num_labels)

    def forward(self, coords, padding_mask, confidence, tokens, labels_cn=None, labels_binary=None, seq=None):
        '''
        Return per-residue logits.

        The label and ``seq`` arguments are accepted so the Trainer can pass the whole
        dataset item as keyword arguments. They are not used here; the loss is computed
        by the trainer.
        '''
        # Teacher forcing: the decoder sees every token except the last one
        prev_output_tokens = tokens[:, :-1]
        # Call the module (not .forward) so any registered hooks run
        feat, _ = self.esm1v(coords, padding_mask, confidence, prev_output_tokens, features_only=True)
        # feat is (batch, features, length); move features last for the Linear head
        logits = self.classifier(torch.transpose(feat, 1, 2))
        return SequenceClassifierOutput(logits=logits)

In [35]:
def compute_metrics_quntative(p: EvalPrediction):
    '''
    Compute R^2 and MSE between predictions and contact-number labels on labelled residues.

        Parameters:
            p (EvalPrediction): ``predictions`` indexed as [sample, position, channel];
                channel 1 is the regression output. ``label_ids`` is either a tuple whose
                first entry is labels_cn (as with the (labels_cn, labels_binary) label
                fields), or the labels_cn array itself. Positions with a negative label
                (padding / unlabelled, -100) are ignored.
        Returns:
            dict: {"r2": float, "mse": float}
    '''
    preds = np.asarray(p.predictions)[:, :, 1]
    label_ids = p.label_ids
    # With several label fields, label_ids arrives as a tuple; labels_cn is assumed to be first
    if isinstance(label_ids, (tuple, list)):
        label_ids = label_ids[0]
    batch_size, seq_len = preds.shape
    # Compare only positions covered by the predictions
    label_ids = np.asarray(label_ids)[:batch_size, :seq_len]

    mask = label_ids >= 0  # drops padding and unlabelled residues (-100)
    out_labels = label_ids[mask]
    out_preds = preds[mask]

    # NOTE(review): labels are compared on a natural-log scale (ln(t + 1)), while
    # MaskedMSELoss trains against raw labels_cn - confirm which scale is intended.
    out_labels_regr = np.log1p(out_labels.astype(np.float64))
    return {
        "r2": r2_score(out_labels_regr, out_preds),
        "mse": mean_squared_error(out_labels_regr, out_preds),
    }

In [36]:
def model_init():
    '''Build a fresh ESM1vForTokenClassification and move it to the GPU.'''
    model = ESM1vForTokenClassification()
    return model.cuda()

In [37]:
class MaskedMSELoss(torch.nn.Module):
    '''
    Mean squared error restricted to masked-in positions, using channel 1 of the model output.

    forward(inputs, target, mask):
        inputs: model output indexed as inputs[:, :, 1]
        target: regression targets, flattened alongside the predictions
        mask:   truthy where a position contributes to the loss
    When the mask selects no position, the masked sum itself is returned instead of
    dividing by zero.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, inputs, target, mask):
        flat_mask = torch.flatten(mask)
        squared_error = (torch.flatten(inputs[:, :, 1]) - torch.flatten(target)) ** 2.0
        masked_sum = torch.sum(squared_error * flat_mask)
        n_valid = torch.sum(mask)
        if n_valid == 0:
            return masked_sum
        return masked_sum / n_valid

    
class MaskedRegressTrainer(Trainer):
    '''
    Trainer whose loss is MaskedMSELoss between channel 1 of the model logits and ``labels_cn``.

    Positions where labels_cn == -100 are excluded from the loss.
    '''
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # num_items_in_batch is passed by newer transformers releases; it is not needed here
        labels_cn = inputs["labels_cn"]

        outputs = model(**inputs)
        # Build the mask on labels_cn's own device instead of forcing .cuda()
        masks = labels_cn.ne(-100)
        loss = MaskedMSELoss()(outputs.logits, labels_cn, masks)
        return (loss, outputs) if return_outputs else loss
        
def collator_fn(x):
    '''
    Unwrap a one-item batch so the item dict is passed to the model unchanged.

    Any other batch size is printed for debugging and returned as the original list.
    '''
    if len(x) != 1:
        print('x:', x)
        return x
    return x[0]


## Data

In [38]:
# Load the per-chain structure records from the pickle cache
esm_structs = create_Dataset('../data/pdb_structures.pkl')

In [39]:
# len() of a dict counts its keys
print('Number of structures:', len(esm_structs))

In [57]:
import json
# Location of the SEMA dataset JSON (an absolute cluster path; adjust for your environment)
path = '/mnt/nfs_protein/ivanisenko/SEMA/dataset/SEMA_DS.json'
def load_datasets(path, filtr_length, train_split='train', test_split='test_original'):
    '''
    Load the SEMA JSON dataset and split it into train / test record dictionaries.

        Parameters:
            path (str or Path): JSON file accepted by ``pd.DataFrame`` with "pdb_path",
                "seq", "residues" and "ds_type" fields
            filtr_length (int): keep only sequences with at least this many residues;
                a falsy value (0 / None) disables the filter
            train_split (str): ``ds_type`` value selecting training records
            test_split (str): ``ds_type`` value selecting test records (e.g. 'test_new')
        Returns:
            (train_set, test_set): dicts {row_number: record} from
                ``DataFrame.to_dict(orient='index')``
    '''
    with open(path, 'r') as file:
        dataset = pd.DataFrame(json.load(file))

    # "<pdb_id>_<chain>": file name of pdb_path, truncated at its first '.'
    file_names = dataset['pdb_path'].str.split('/').str[-1]
    dataset['pdb_id_chain'] = file_names.str.split('.').str[0]

    if filtr_length:
        # .copy() so the column assignments below act on an independent frame
        dataset = dataset[dataset['seq'].str.len() >= filtr_length].copy()

    def to_residue_name(residues):
        return [r['residue_name'] if r is not None else None for r in residues]

    def to_residue_pos(residues):
        return [r['residue_number'] if r is not None else None for r in residues]

    dataset['resi_name'] = dataset['residues'].apply(to_residue_name)
    dataset['resi_pos'] = dataset['residues'].apply(to_residue_pos)

    train_set = dataset[dataset['ds_type'] == train_split]
    test_set = dataset[dataset['ds_type'] == test_split]

    train_set = train_set.reset_index().to_dict(orient='index')
    test_set = test_set.reset_index().to_dict(orient='index')

    return train_set, test_set

# Split the raw JSON into train / test records (no minimum-length filter) ...
train_set, test_set = load_datasets(path, filtr_length=0)
# ... then map the per-residue labels onto the cached structures, train first
train_ds, test_ds = (prepareEsmDataset(split, esm_structs) for split in (train_set, test_set))

Skip long   5I5K_A 1632
Number of proteins with invalid contact_number:  {21}




Number of proteins with invalid contact_number:  {0}


In [53]:
# esm_structs = create_Dataset(path='../data/pdb_structures_old.pkl')

# train_set = pd.read_csv('../data/train_set.csv')

# train_set = train_set.groupby('pdb_id_chain').agg({'resi_pos': list,
#                                  'resi_aa': list,
#                                  'resi_name': list,
#                                  'contact_number': list,
#                                  'contact_number_binary': list})\
#                  .reset_index()\
#                  .to_dict(orient='index')

# train_ds = prepareEsmDataset(train_set, esm_structs)
# ## the first run will take about 5-10 minutes, because esm weights should be downloaded
# # 

# test_set = pd.read_csv('../data/test_set.csv')
# test_set = test_set.groupby('pdb_id_chain').agg({'resi_pos': list,
#                                  'resi_aa': list,
#                                  'resi_name': list,
#                                  'contact_number': list,
#                                  'contact_number_binary': list})\
#                  .reset_index()\
#                  .to_dict(orient='index')

# # test_set = {k: test_set[k] for k in range(10)}
# test_ds = prepareEsmDataset(test_set, esm_structs)

1AFV_B_MK
 {('PRO', '1'): 0, ('ILE', '2'): 1, ('VAL', '3'): 2, ('GLN', '4'): 3, ('ASN', '5'): 4, ('LEU', '6'): 5, ('GLN', '7'): 6, ('GLY', '8'): 7, ('GLN', '9'): 8, ('MET', '10'): 9, ('VAL', '11'): 10, ('HIS', '12'): 11, ('GLN', '13'): 12, ('ALA', '14'): 13, ('ILE', '15'): 14, ('SER', '16'): 15, ('PRO', '17'): 16, ('ARG', '18'): 17, ('THR', '19'): 18, ('LEU', '20'): 19, ('ASN', '21'): 20, ('ALA', '22'): 21, ('TRP', '23'): 22, ('VAL', '24'): 23, ('LYS', '25'): 24, ('VAL', '26'): 25, ('VAL', '27'): 26, ('GLU', '28'): 27, ('GLU', '29'): 28, ('LYS', '30'): 29, ('ALA', '31'): 30, ('PHE', '32'): 31, ('SER', '33'): 32, ('PRO', '34'): 33, ('GLU', '35'): 34, ('VAL', '36'): 35, ('ILE', '37'): 36, ('PRO', '38'): 37, ('MET', '39'): 38, ('PHE', '40'): 39, ('SER', '41'): 40, ('ALA', '42'): 41, ('LEU', '43'): 42, ('SER', '44'): 43, ('GLU', '45'): 44, ('GLY', '46'): 45, ('ALA', '47'): 46, ('THR', '48'): 47, ('PRO', '49'): 48, ('GLN', '50'): 49, ('ASP', '51'): 50, ('LEU', '52'): 51, ('ASN', '53'): 52, 

AssertionError: None

In [41]:
# Inspect the loaded structure cache (the notebook renders the value of this expression)
esm_structs

{}

## Model training and test 

In [None]:
# Hugging Face TrainingArguments for fine-tuning SEMA-3D
training_args = TrainingArguments(
    # --- bookkeeping ---
    output_dir='./results_fold',
    logging_dir='./logs',
    logging_steps=200,                 # log every 200 optimisation steps
    run_name="PDB_regr",
    seed=42,
    save_strategy="no",                # no intermediate checkpoints are written
    # --- optimisation ---
    num_train_epochs=2,
    learning_rate=1e-04,
    weight_decay=0.0,
    warmup_steps=0,
    per_device_train_batch_size=1,     # one protein per step
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    fp16=False,                        # full precision
    # --- evaluation ---
    do_train=True,
    # Per the HF docs, do_eval is set to True whenever the evaluation strategy is not "no"
    do_eval=False,
    # Evaluate after every epoch. NOTE(review): newer transformers releases rename this to eval_strategy
    evaluation_strategy="epoch",
    load_best_model_at_end=False,
    metric_for_best_model="eval_r2",
    greater_is_better=True,
)

In [None]:
# Create the directory that stores the fine-tuned weights.
# exist_ok=True replaces the separate exists() check and its check-then-create race.
os.makedirs("../models/", exist_ok=True)

In [None]:
# Fine-tune a freshly built model with the masked-MSE trainer
trainer = MaskedRegressTrainer(
    args=training_args,
    model=model_init(),
    data_collator=collator_fn,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    compute_metrics=compute_metrics_quntative,
)
trainer.train()

# Persist only the fine-tuned parameters (state_dict)
torch.save(trainer.model.state_dict(), "../models/sema_3d_newdata_uncut_log10_0.pth")

Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss,R2,Mse
1,0.1283,0.143908,-0.16561,0.091718
2,0.108,0.138815,-0.052535,0.082821


# Test on multimer dataset

In [13]:
from biopandas.pdb import PandasPdb
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import os
from IPython.display import clear_output

# Three-letter -> one-letter codes of the 20 standard amino acids.
# Keys and letters are paired position by position, in the original insertion order.
aa_3_to_1 = dict(zip(
    ('CYS', 'ASP', 'SER', 'GLN', 'LYS',
     'ILE', 'PRO', 'THR', 'PHE', 'ASN',
     'GLY', 'HIS', 'LEU', 'ARG', 'TRP',
     'ALA', 'VAL', 'GLU', 'TYR', 'MET'),
    'CDSQK' 'IPTFN' 'GHLRW' 'AVEYM',
))

In [22]:
from tqdm import tqdm 
import numpy as np
import subprocess

# def esmStructDataset(pdb_path):
#     '''
#     Convert PDB-file into dataset format
    
#         Parameters:
#             pdb_path (Path): path to pdb-file
#         Returns:
#             dict (dict): dictionary, where keys are properties of the protein's tertiary structure
#     '''
#     entity = pdb_path.split(".pdb")[0]
#     pdb_id,chain = entity.split("_")
#     if not os.path.exists(pdb_path):
#         print("missing "+pdb_path)
#         return
#     try:
#         structure = esm.inverse_folding.util.load_structure(str(pdb_path), chain.upper()) #chain.upper())
#     except:
#         structure = esm.inverse_folding.util.load_structure(str(pdb_path), chain.lower()) #chain.upper())
#     resi_index = get_residues(structure)[0]
#     resi_aa    = get_residues(structure)[1]
#     resi_keys     = []
#     cn =[]
#     binary = []
#     for resi_index_,resi_aa_ in zip(get_residues(structure)[0],get_residues(structure)[1]):
#         key = (str(resi_aa_),resi_index_)
#         cn.append(None)
#         binary.append(None)
#         resi_keys.append(key)    
#     coords, seq = esm.inverse_folding.util.extract_coords_from_structure(structure)#structure)
#     return {"pdb_id":pdb_id,"seq":seq,"chain":chain,"coords":coords,
#             "cn":cn, #contact_number
#             "binary":binary,"residues":resi_keys}

def change_residue_number_multimer(atom_df):
    '''
    Renumber residues of a multi-chain ATOM table so numbering continues across chains.

    Differences between residue numbers within a chain are kept. At every row whose
    ``chain_id`` differs from the previous row, that row and all following rows are
    shifted so the row's residue number becomes the previous row's number + 1.

        Parameters:
            atom_df (pd.DataFrame): ATOM records with "chain_id" and "residue_number" columns
        Returns:
            pd.DataFrame: the same ``atom_df`` object, modified in place
    '''
    if atom_df.empty:
        return atom_df

    numbers = atom_df['residue_number'].to_numpy()
    chains = atom_df['chain_id'].to_numpy()
    # Row positions where a new chain starts
    starts = np.flatnonzero(chains[1:] != chains[:-1]) + 1
    # Shift introduced at each chain start (previous row + 1 - current row); shifts accumulate
    jumps = np.zeros(len(numbers), dtype=numbers.dtype)
    jumps[starts] = numbers[starts - 1] - numbers[starts] + 1
    atom_df['residue_number'] = numbers + np.cumsum(jumps)
    return atom_df


def _run_pdb_tool(args, src, dst):
    '''Run a pdb-tools command on ``src``, write its stdout to ``dst``, and raise if it fails.'''
    with open(dst, 'w') as out:
        subprocess.run([*args, str(src)], stdout=out, check=True)


def prepare_multimer_test_1D_model(path_to_pdbs):
    '''
    Build a labelled structure record for every multimer PDB in ``path_to_pdbs``.

    Each PDB is processed as follows:
      - insertion codes are removed (pdb_delinsertion);
      - all chains are relabelled to the single pseudo-chain "M";
      - residues are renumbered from 1 (pdb_reres -1);
      - the result is converted with esmStructDataset.
    Labels come from the B-factor column of CA atoms: -1 becomes -100 (ignored),
    and non-negative values become log10(x + 1).

        Parameters:
            path_to_pdbs (str or Path): directory containing *.pdb files
        Returns:
            list[dict]: one esmStructDataset record per PDB, with "cn" replaced by the
                B-factor labels and "pdb_id" set to the source file stem
        Raises:
            subprocess.CalledProcessError / OSError: if a pdb-tools command fails or is missing
            AssertionError: if the number of CA labels differs from the number of residue coords
    '''
    import tempfile

    path_pdbs = Path(path_to_pdbs).glob("*.pdb")
    one_chain_prots = []
    test_list_dict = []

    # Intermediate files live in a private temporary directory that is removed afterwards
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp = Path(tmp_dir)
        raw_pdb = tmp / 'tmp_raw_pdb.pdb'
        fixed_pdb = tmp / 'tmp_raw_pdb_fix_ins.pdb'
        chain_m_pdb = tmp / 'tmp_M.pdb'
        # the "_M" suffix is parsed by esmStructDataset as the chain identifier
        reindexed_pdb = tmp / 'tmpReind_M.pdb'

        for pdb_path in tqdm(list(path_pdbs)):
            PandasPdb().read_pdb(str(pdb_path)).to_pdb(str(raw_pdb))
            # Convert insertion codes (e.g. 150A) into plain sequential residue numbers
            _run_pdb_tool(['pdb_delinsertion'], raw_pdb, fixed_pdb)

            pdb_df = PandasPdb().read_pdb(str(fixed_pdb))
            atom_df = pdb_df.df['ATOM'].copy()
            # Count inputs that contain a single chain (they are still processed)
            if atom_df['chain_id'].nunique() == 1:
                one_chain_prots.append(pdb_path.name)
            # Merge all chains into one pseudo-chain "M"
            atom_df.loc[:, 'chain_id'] = 'M'
            pdb_df.df['ATOM'] = atom_df
            pdb_df.to_pdb(str(chain_m_pdb))
            # Renumber residues consecutively, starting at 1
            _run_pdb_tool(['pdb_reres', '-1'], chain_m_pdb, reindexed_pdb)

            struct_dict = esmStructDataset(reindexed_pdb)
            struct_dict['pdb_id'] = pdb_path.stem

            label_df = PandasPdb().read_pdb(str(reindexed_pdb)).df['ATOM']
            labels = label_df.loc[label_df['atom_name'] == 'CA', 'b_factor'].to_numpy(dtype=float)
            # -1 marks unlabelled residues -> -100, the value ignored by the loss and metrics
            labels[labels == -1] = -100
            nonneg = labels >= 0
            labels[nonneg] = np.log10(labels[nonneg] + 1)
            struct_dict['cn'] = labels.tolist()

            if len(struct_dict['cn']) != len(struct_dict['coords']):
                raise AssertionError(
                    f"{pdb_path}: {len(struct_dict['cn'])} CA labels but "
                    f"{len(struct_dict['coords'])} residue coordinates"
                )

            test_list_dict.append(struct_dict)
            clear_output()

    print('PDBs with only one chain:', len(one_chain_prots))
    return test_list_dict

In [23]:
# Parse every multimer PDB into a labelled structure record ...
test_df = prepare_multimer_test_1D_model(
    '/mnt/nfs_protein/ivanisenko/SEMA/dataset/multimers/'
)
# ... and wrap the records as a torch dataset for evaluation
multi_test_ds = epitopes_Dataset(test_df)

100%|██████████| 773/773 [09:46<00:00,  1.32it/s]

PDBs with only one chain: 0





In [62]:
# Evaluate the fine-tuned model on the multimer test set
trainer.evaluate(eval_dataset=multi_test_ds)

{'eval_loss': 0.12447740882635117,
 'eval_r2': 0.02568819944184264,
 'eval_mse': 0.0691983899943198,
 'eval_runtime': 169.9554,
 'eval_samples_per_second': 4.548,
 'eval_steps_per_second': 4.548,
 'epoch': 2.0}