# SEMA 3D

SEMA-3D is a fine-tuned SaProt model that predicts epitope residues from antigen protein sequence and structure features.

The MIT License (MIT)
Copyright (c) 2016 AYLIEN
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

In [7]:
import os, sys
# Set path and CUDA parameters (this cell runs before the cell that imports torch/esm).
# The 'TORCH_HOME' directory will be used to save the original ESM-1v weights
# (note carried over from the original notebook -- TODO confirm it still applies to SaProt).

# Make the directory two levels up importable (presumably the repository root that holds saprot_utils).
sys.path.append("../../")
# sys.path.append(os.path.join(os.path.dirname(sys.path[0]),'bar'))
os.environ['TORCH_HOME'] = "../torch_hub"
# Expose only the GPU with index 1 to CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

In [8]:
import scipy
import sklearn
import esm
import pickle

import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset
from torch import nn

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import EsmTokenizer, EsmForMaskedLM

from biotite.structure.residues import get_residues

from saprot_utils.foldseek_util import get_struc_seq

In [31]:
# One-letter -> three-letter residue codes for the 20 standard amino acids.
onetothree = dict(
    A='ALA',  # Alanine
    R='ARG',  # Arginine
    N='ASN',  # Asparagine
    D='ASP',  # Aspartic Acid
    C='CYS',  # Cysteine
    Q='GLN',  # Glutamine
    E='GLU',  # Glutamic Acid
    G='GLY',  # Glycine
    H='HIS',  # Histidine
    I='ILE',  # Isoleucine
    L='LEU',  # Leucine
    K='LYS',  # Lysine
    M='MET',  # Methionine
    F='PHE',  # Phenylalanine
    P='PRO',  # Proline
    S='SER',  # Serine
    T='THR',  # Threonine
    W='TRP',  # Tryptophan
    Y='TYR',  # Tyrosine
    V='VAL',  # Valine
)

def load_data(path, filtr_length = None):
    '''
    Load a per-residue table and collapse it to one row per chain.

        Parameters:
            path: location of a CSV table (anything pandas.read_csv accepts) with columns
                pdb_id_chain, resi_pos, resi_aa, contact_number, contact_number_binary
            filtr_length (int, optional): when truthy, chains with fewer residues
                than this value are dropped
        Returns:
            df (pandas.DataFrame): one row per pdb_id_chain; resi_pos, resi_aa,
                contact_number and contact_number_binary hold per-residue lists and
                resi_name holds the three-letter residue names (None for letters
                that are missing from onetothree)
    '''
    df = pd.read_csv(path)

    # Gather the per-residue rows of every chain into lists.
    df = df.groupby('pdb_id_chain').agg({
        'resi_pos': list,
        'resi_aa': list,
        'contact_number': list,
        'contact_number_binary': list,
        }).reset_index()

    if filtr_length:
        # .copy() detaches the filtered frame, so adding a column below does not
        # write into a slice of the grouped frame (SettingWithCopyWarning).
        df = df[df['resi_aa'].str.len() >= filtr_length].copy()

    df['resi_name'] = df['resi_aa'].apply(
        lambda letters: [onetothree.get(letter) for letter in letters]
    )

    return df


def get_features_from_pdb(pdb_path, chain):
    '''
    load needed features from PDB file

        Parameters:
            pdb_path (Path): path to pdb-file
            chain (str): chain identifier; the upper-case form is tried first,
                then the lower-case form
        Returns:
            dict (dict): dictionary with keys
                "seq"          -- first element returned by get_struc_seq for the chain
                "foldseek_seq" -- second element returned by get_struc_seq for the chain
                "combined_seq" -- third element returned by get_struc_seq for the chain
                "residues"     -- list of (residue name, residue id) tuples taken from
                                  the structure loaded with esm.inverse_folding
        Raises:
            AssertionError: if the number of residues in the loaded structure differs
                from the length of "seq"
    '''
    # Run foldseek once; only the chain lookup differs between the two attempts.
    all_chain_seqs = get_struc_seq("../../saprot_utils/bin/foldseek", pdb_path)

    try:
        chain_id = chain.upper()
        parsed_seqs = all_chain_seqs[chain_id]
        structure = esm.inverse_folding.util.load_structure(str(pdb_path), chain_id)
    except Exception:
        # Fall back to the lower-case chain id; a failure here propagates to the caller.
        chain_id = chain.lower()
        parsed_seqs = all_chain_seqs[chain_id]
        structure = esm.inverse_folding.util.load_structure(str(pdb_path), chain_id)

    seq, foldseek_seq, combined_seq = parsed_seqs

    # get_residues returns the residue ids and the residue names; compute them once.
    residue_ids, residue_names = get_residues(structure)
    resi_keys = [
        (str(residue_name), int(residue_id))
        for residue_id, residue_name in zip(residue_ids, residue_names)
    ]

    # Raised explicitly (instead of `assert`) so the check survives `python -O`.
    if len(resi_keys) != len(seq):
        raise AssertionError('foldseek and esm.inverse_folding.util.load_structure returns different seqs')

    return {"seq":seq,
            "foldseek_seq":foldseek_seq,
            "combined_seq":combined_seq,
            "residues": resi_keys
            }


def get_contact_number_from_struct_data(df_row, struct_data):
    '''
    Align the per-residue targets of one table row with the residues of a structure.

        Parameters:
            df_row: mapping (e.g. a DataFrame row) with equally ordered sequences
                'resi_name', 'resi_pos' and 'cn'
            struct_data (dict): dictionary with 'seq' and 'residues', where 'residues'
                is a sequence of (residue name, residue position) pairs
        Returns:
            cn (list): one value per position of struct_data['seq']; the 'cn' value of
                the matching (name, position) pair from df_row, or -100 when the
                residue is absent from df_row
        Raises:
            AssertionError: if no residue of the structure could be matched
    '''
    # (residue name, position) -> index into df_row['cn']
    key_map = {
        (resi_name, resi_pos): row_index
        for row_index, (resi_name, resi_pos)
        in enumerate(zip(df_row['resi_name'], df_row['resi_pos']))
    }

    cn = [-100] * len(struct_data['seq'])

    # Normalised structure keys, kept only to make the failure message informative.
    check_keys = {}
    for i, residue in enumerate(struct_data["residues"]):
        try:
            position = int(residue[1])
        except (TypeError, ValueError, OverflowError):
            # Non-numeric positions (e.g. with an insertion code) are compared as strings.
            position = str(residue[1])
        key = (residue[0], position)

        check_keys[key] = i

        if key in key_map:
            cn[i] = df_row['cn'][key_map[key]]

    # Raised explicitly (instead of `assert`) so the check survives `python -O`.
    if not any(value != -100 for value in cn):
        raise AssertionError(
            f'all contact numbers are equal to -100:'+ '\n' + \
            f'{key_map}\n{check_keys}\n{struct_data["residues"]}\n{cn}'
        )

    return cn


def process_data(df, pdb_folder_path, for_classification = False):
    '''
    Build the model input table (structure-aware sequence + per-residue targets).

        Parameters:
            df (pandas.DataFrame): table with one row per chain and columns
                pdb_id_chain ('<pdb id>_<chain>'), resi_name, resi_pos,
                contact_number and contact_number_binary.
                NOTE: a column 'cn' is added to (or overwritten in) this frame.
            pdb_folder_path: folder that holds the '<pdb_id_chain>.pdb' files
            for_classification (bool): use 'contact_number_binary' as target when
                True, 'contact_number' otherwise
        Returns:
            processed_data (pandas.DataFrame): columns 'combined_seq' and 'cn', one row
                per chain that was kept; empty (but with both columns) when every
                chain was skipped
    '''
    colname = 'contact_number_binary' if for_classification else 'contact_number'

    df.loc[:, 'cn'] = df.loc[:, colname]

    processed_data = []
    bad_prots = []

    for _, df_row in df.iterrows():
        pdb_id = df_row["pdb_id_chain"]
        chain = pdb_id.split('_')[1]
        pdb_path = os.path.join(pdb_folder_path, pdb_id+ '.pdb')

        # Skip chains whose target takes a single value for every residue.
        if np.unique(df_row['cn']).shape[0] == 1:
            bad_prots.append(pdb_id)
            continue

        # Skip chains without a structure file.
        if not os.path.exists(pdb_path):
            bad_prots.append(pdb_id)
            continue

        struct_data = get_features_from_pdb(pdb_path, chain)
        struct_data["cn"] = get_contact_number_from_struct_data(df_row, struct_data)
        processed_data.append(struct_data)

    print('Number of proteins with invalid contact_number: ', len(bad_prots))

    # Selecting through `columns=` also works when nothing was kept; indexing an
    # empty frame with ['combined_seq', 'cn'] would raise KeyError.
    return pd.DataFrame(processed_data, columns=['combined_seq', 'cn'])


class SaProtDataset(Dataset):
    """
    A class to represent a suitable data set for the model.

    Converts the original pandas data frame to model items, where
    'token_ids' are the ids produced by the SaProt tokenizer for the sequence and
    'labels' are the per-residue target values.
    No truncation is applied to the sequence or to the labels.

    Attributes:
        df (pandas.DataFrame): dataframe with two columns (accessed by position):
                0 -- sequence string that is passed to the tokenizer
                1 -- target values in list [0, 0.123, 0.23, -100, 1.34] format
        tokenizer (EsmTokenizer): tokenizer loaded from `model_path`
        label_type (str):
                type of model: regression or binary (stored only, not used by this class)

    """
    def __init__(self, df, label_type ='regression', model_path = "westlake-repl/SaProt_650M_PDB"):
        """
        Construct all the necessary attributes of the SaProtDataset object.

        Parameters:
            df (pandas.DataFrame): dataframe with two columns (accessed by position):
                0 -- sequence string that is passed to the tokenizer
                1 -- target values in list [0, 0.123, 0.23, -100, 1.34] format
            label_type (str):
                type of model: regression or binary
            model_path (str):
                name or path given to EsmTokenizer.from_pretrained
        """
        self.tokenizer = EsmTokenizer.from_pretrained(model_path)

        self.df = df
        self.label_type = label_type

    def __getitem__(self, idx):
        """
        Return the item at position `idx` as a dictionary with
        'token_ids' (tokenizer 'input_ids' tensor) and
        'labels' (target tensor with a leading dimension of size 1).
        """
        sequence = self.df.iloc[idx, 0]
        targets = self.df.iloc[idx, 1]

        return {
            'token_ids': self.tokenizer(sequence, padding = False, return_tensors="pt")['input_ids'],
            'labels': torch.unsqueeze(torch.tensor(targets), 0),
        }

    def __len__(self):
        """Return the number of rows of the underlying dataframe."""
        return len(self.df)
    

def process_pdb_for_inference(pdb_path, chain = None):
    '''
    Build the model input sequence for a PDB file.

        Parameters:
            pdb_path: path to pdb-file
            chain (str, optional): chain identifier (upper-case form is tried first,
                then lower-case). None, an empty string or NaN select all chains.
        Returns:
            combined_seq (str): third element returned by get_struc_seq for the chain,
                or these elements of all chains joined together when no chain is given
        Raises:
            KeyError: if neither the upper- nor the lower-case chain id is present
    '''
    parsed_seqs = get_struc_seq("../../saprot_utils/bin/foldseek", pdb_path)

    # NaN is truthy, so it has to be tested explicitly to count as "no chain".
    if not chain or pd.isna(chain):
        # Element 2 of every entry is the combined sequence; element 0 is the plain
        # sequence, which is not what the single-chain branch returns.
        return ''.join(chain_seqs[2] for chain_seqs in parsed_seqs.values())

    upper_id = chain.upper()
    chain_id = upper_id if upper_id in parsed_seqs else chain.lower()
    _, _, combined_seq = parsed_seqs[chain_id]

    return combined_seq



## Evaluate

In [None]:
## load test data
# Hard-coded, machine-specific locations: folder with '<pdb_id_chain>.pdb' files
# and the CSV table read by load_data.
pdb_path = '/home/shevtsov/SEMAi/data/new_train'
test_path = '/home/shevtsov/SEMAi/data/test_set_prev.csv'
# test_path = '/home/shevtsov/SEMAi/data/test_set_discotope.csv'

test_data = load_data(test_path)
# for_classification = True selects the 'contact_number_binary' column as target.
processed_test_data = process_data(test_data, pdb_path, for_classification = True)
test_ds = SaProtDataset(processed_test_data)

print(f'Data preparation was completed\nNumber of enteries:test= {test_ds.__len__()}')

In [32]:
class SaProtForTokenClassification(nn.Module):
    """
    SaProt masked-LM encoder followed by a linear per-token classification layer.

    Attributes:
        num_labels (int): number of output classes per token
        encoder (EsmForMaskedLM): model loaded from "westlake-repl/SaProt_650M_PDB"
        classifier (nn.Linear): maps the 446 encoder logits of a token to num_labels scores
    """

    def __init__(self, num_labels = 2):
        super().__init__()
        self.num_labels = num_labels
        model_path = "westlake-repl/SaProt_650M_PDB"
        self.encoder = EsmForMaskedLM.from_pretrained(model_path)

        # The classifier consumes the encoder's 'logits' output; 446 is presumably the
        # SaProt vocabulary size (width of these logits) -- TODO confirm against the config.
        self.classifier = nn.Linear(446, self.num_labels)

    def forward(self, token_ids, labels = None):
        """
        Parameters:
            token_ids (torch.Tensor): token ids passed to the encoder as `input_ids`
            labels: accepted for interface compatibility; not used
        Returns:
            SequenceClassifierOutput: `logits` holds the classifier scores for every
                position except the first and the last one of the encoder output
        """
        vocab_logits = self.encoder(input_ids = token_ids)['logits']
        # Drop the first and the last position (presumably the special tokens added by
        # the tokenizer) so the output lines up with the residues.
        residue_logits = vocab_logits[:, 1:-1, :]
        return SequenceClassifierOutput(logits=self.classifier(residue_logits))

def model_init():
    """Create a SaProtForTokenClassification with default arguments and move it to the GPU."""
    classifier_model = SaProtForTokenClassification()
    return classifier_model.cuda()

In [33]:
def predict(model, ds, masked = True):
    '''
    Run the model over a dataset and collect flat per-residue scores and binary labels.

        Parameters:
            model: module called as model(token_ids, labels=None); element [0] of its
                output is indexed as [batch, residue, class]
            ds: iterable of items with 'token_ids' and 'labels' (see SaProtDataset)
            masked (bool): when True, residues whose label is neither > 0 nor == 0
                (e.g. -100) are skipped; when False they are counted as label 0
        Returns:
            preds (list): score of class index 1 for every kept residue
            labels (list): 1 for label > 0, otherwise 0, for every kept residue
            indexes (list): running residue index (over the whole dataset, skipped
                residues included) of every kept residue
    '''
    # Imported here because the file-level tqdm import lives in a later notebook cell.
    from tqdm import tqdm

    indexes = []
    preds = []
    labels = []
    index = 0
    with torch.no_grad():
        for it in tqdm(ds):
            # Call the module itself (not .forward) so registered hooks are honoured.
            it_preds = model(it['token_ids'].cuda(), labels = None)[0][0][:,1].cpu().numpy()
            it_labels = it['labels'].squeeze(0).numpy()

            for ind, label in enumerate(it_labels):
                if label > 0:
                    target = 1
                elif label == 0 or not masked:
                    target = 0
                else:
                    # masked residue: excluded from the evaluation
                    target = None

                if target is not None:
                    labels.append(target)
                    preds.append(it_preds[ind])
                    indexes.append(index)

                index += 1

    return preds, labels, indexes

In [None]:
from sklearn.metrics import roc_curve,roc_auc_score
from tqdm import tqdm


# path = '/home/shevtsov/SEMAi/models/sema_saprot_old_0.pth'

# Evaluate the checkpoints sema_3d_0.pth ... sema_3d_4.pth on the test set and
# average their per-residue scores.
ensembl_res = []
for seed in [0,1,2,3,4]:
    path = f'../models/sema_3d_{seed}.pth'
    model=model_init()
    model.load_state_dict(torch.load(path))
    model.eval()
    model.cuda()

    # masked = False keeps every residue, so the prediction lists of all seeds have
    # the same length; labels and indexes of the last seed are reused by later cells.
    preds, labels, indexes = predict(model, test_ds, masked = False)

    ensembl_res.append(preds)

# Mean over the seed axis: one averaged score per residue.
mean_preds = np.mean(np.stack(ensembl_res), axis = 0)

In [None]:
# Same computation as the last line of the ensemble cell; lets this cell be re-run on its own.
mean_preds = np.mean(np.stack(ensembl_res), axis = 0)

In [None]:
import numpy as np
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve

# ROC AUC of the seed-averaged scores against the binary labels returned by predict.
print('ROC AUC score:', roc_auc_score(labels, mean_preds))

### thresholds based on ROC AUC

In [None]:
fpr, tpr, thresholds = roc_curve(labels, mean_preds, pos_label=1)

# 1st variant: ROC point closest to the ideal corner (fpr = 0, tpr = 1)
# optimal_threshold_ind = np.argmin(np.sqrt((fpr)**2 + (1-tpr)**2))
# 2nd variant: ROC point with the largest tpr - fpr
optimal_threshold_ind = np.argmax(tpr-fpr)
# Prints the index of the selected point and the corresponding score threshold.
print(optimal_threshold_ind, thresholds[optimal_threshold_ind])


### thresholds based on precision_recall_curve

In [None]:
p, r, thresholds = precision_recall_curve(labels, mean_preds, pos_label=1)

# precision_recall_curve returns precision/recall arrays that are one element longer
# than `thresholds`: the final (precision = 1, recall = 0) point has no threshold.
# It is excluded so that the selected index is always valid for `thresholds`.
# Criterion: point closest to the ideal corner (precision = 1, recall = 1).
optimal_threshold_ind = np.argmin(np.sqrt((1-p[:-1])**2 + (1-r[:-1])**2))
print(optimal_threshold_ind, thresholds[optimal_threshold_ind])

In [28]:
## or use simple example
pdb_path= '../data/test_pdb.pdb'
chain = 'A' # use None for multichain prediction (process_pdb_for_inference tests `if chain:`)
combined_seq = process_pdb_for_inference(pdb_path, chain)
# One-row frame: column 0 is the sequence, column 1 a placeholder label list
# (SaProtDataset reads both columns by position).
ds = SaProtDataset(pd.DataFrame({'combined_seq': combined_seq,
                                   'cn': [0]}))

tokenizer_config.json: 100%|██████████| 40.0/40.0 [00:00<00:00, 5.82kB/s]
vocab.txt: 100%|██████████| 1.35k/1.35k [00:00<00:00, 223kB/s]
special_tokens_map.json: 100%|██████████| 125/125 [00:00<00:00, 63.9kB/s]


## Inference

In [None]:
# Load a single checkpoint (seed 0) and switch the model to evaluation mode on the GPU.
seed = 0
path = f'../models/sema_3d_{seed}.pth'
model=model_init()
model.load_state_dict(torch.load(path))
model.eval()
model.cuda()

# Scores of class index 1 for the first (only) item of `ds`, one value per position
# kept by the model's forward pass; these are raw classifier outputs (no softmax is applied).
# NOTE(review): the call is not wrapped in torch.no_grad() -- confirm whether gradients are needed here.
preds = model.forward(ds.__getitem__(0)['token_ids'].cuda(), labels = None)[0][0][:,1].cpu().numpy()