**TwoFold_DL - inference**

In [1]:
! git clone  https://github.com/ORNL/TwoFold_DL

fatal: destination path 'TwoFold_DL' already exists and is not an empty directory.


In [1]:
import sys

# Put the cloned checkout on the import path (presumably so that the
# `contact_pred` and `utils` packages imported in later cells resolve).
# The membership test keeps re-runs from appending duplicate entries.
if 'TwoFold_DL/' not in sys.path:
    sys.path.append('TwoFold_DL/')

In [2]:
import logging

# Suppress every log record at WARNING severity or lower (this also covers
# INFO and DEBUG), so library chatter does not flood the notebook output.
logging.disable(logging.WARNING)

In [5]:
! pip install --upgrade pip



In [None]:
! conda install -q -c conda-forge -y Rust
! pip install -q datasets
! pip install -q transformers==4.18.0
! pip install -q huggingface_hub
! pip install -q rdkit
! pip install -q biopython

In [11]:
import torch
from torch.utils.data import Dataset
from huggingface_hub import hf_hub_download

In [12]:
import datasets
from datasets import load_dataset
from transformers import AutoTokenizer, AutoConfig, Trainer
from transformers import EvalPrediction
from transformers import TrainingArguments

from tokenizers import Regex
from tokenizers import pre_tokenizers
from tokenizers import normalizers
from tokenizers.normalizers import Replace

from tokenizers.pre_tokenizers import BertPreTokenizer
from tokenizers.pre_tokenizers import Digits
from tokenizers.pre_tokenizers import Sequence
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.pre_tokenizers import Split

from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.preprocessing import StandardScaler
 
from sklearn.metrics import precision_recall_curve, roc_curve
import pandas as pd
import numpy as np
import json
import re
import tqdm
import os
import rdkit

2023-09-26 16:54:47.911545: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-09-26 16:54:48.236835: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [13]:
from contact_pred.models import StructurePrediction, ProteinLigandConfigStructure
from contact_pred.structure import IPAConfig
from contact_pred.data_utils import StructurePredictionPipeline

In [14]:
# Fail early with an actionable message: the inference pipeline in this
# notebook is constructed with device=0, i.e. on the first CUDA device.
assert torch.cuda.is_available(), (
    'No CUDA device available: this notebook builds its inference pipeline on GPU 0.'
)

In [16]:
from transformers import BertConfig

# Protein language model whose configuration and tokenizer are reused here.
seq_model_name = 'Rostlab/prot_bert_bfd'
seq_config = AutoConfig.from_pretrained(seq_model_name)

# Load the tokenizer once, keeping the original letter case of the input.
seq_tokenizer = AutoTokenizer.from_pretrained(seq_model_name, do_lower_case=False)

# Normalisation: map the residue codes U, Z, O and B to X, then remove all
# whitespace. The pattern is a raw string because '\s' is not a valid Python
# string escape and triggers a warning in a plain literal.
normalizer = normalizers.Sequence([Replace(Regex('[UZOB]'),'X'),Replace(Regex(r'\s'),'')])

# Pre-tokenisation: split on the empty pattern so that (presumably) every
# residue letter becomes its own token — confirm against the tokenizers docs.
pre_tokenizer = pre_tokenizers.Split(Regex(''),behavior='isolated')

seq_tokenizer.backend_tokenizer.normalizer = normalizer
seq_tokenizer.backend_tokenizer.pre_tokenizer = pre_tokenizer

Downloading:   0%|          | 0.00/361 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/86.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [17]:
# Ligand (SMILES) language model: fetch its configuration and its tokenizer
# from the same Hugging Face model id.
smiles_model_name = 'mossaic-candle/regex-gb-2021'
smiles_config = AutoConfig.from_pretrained(smiles_model_name)
smiles_tokenizer = AutoTokenizer.from_pretrained(smiles_model_name)
# Disabled override of the tokenizer's pre-tokenizer with an explicit SMILES regex:
#smiles_tokenizer.backend_tokenizer.pre_tokenizer = Sequence([WhitespaceSplit(),Split(Regex(r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""), behavior='isolated')])

Downloading:   0%|          | 0.00/354 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.81k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/565 [00:00<?, ?B/s]

In [20]:
# Download the trained weights and the model configuration from the Hub.
checkpoint_path = hf_hub_download('djh992/TwoFold_DL_GB2022', 'pytorch_model.bin')
config_json = hf_hub_download('djh992/TwoFold_DL_GB2022', 'config.json')

# Read the JSON config through a context manager so the file handle is closed.
with open(config_json, 'r') as config_file:
    config = ProteinLigandConfigStructure(json.load(config_file))

# Attach the sub-model configurations and the protein vocabulary.
config.seq_config = seq_config.to_dict()
config.smiles_config = smiles_config.to_dict()
config.seq_vocab = seq_tokenizer.get_vocab()

# One IPA configuration per modality, each using as many IPA heads as the
# corresponding language model has attention heads.
seq_ipa_config = IPAConfig(bert_config=seq_config.to_dict(),
                           num_ipa_heads=seq_config.num_attention_heads)
smiles_ipa_config = IPAConfig(bert_config=smiles_config.to_dict(),
                            num_ipa_heads=smiles_config.num_attention_heads)
config.seq_ipa_config = seq_ipa_config.to_dict()
config.smiles_ipa_config = smiles_ipa_config.to_dict()

model = StructurePrediction(config=config)

# Deserialize the weights onto the CPU: this does not depend on the device the
# checkpoint was saved from and avoids a transient extra copy in GPU memory.
# NOTE: torch.load unpickles the file, so only load checkpoints from a source
# you trust.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint,strict=True)

# The state dict is no longer needed once copied into the model.
del checkpoint

# Inference pipeline on CUDA device 0, one protein/ligand pair per batch.
pipeline = StructurePredictionPipeline(
           model,
           seq_tokenizer=seq_tokenizer,
           smiles_tokenizer=smiles_tokenizer,
           device=0,
           batch_size=1)

### Inference

In [22]:
class ProteinLigandDataset(Dataset):
    """Map-style dataset that yields one protein/ligand record per index.

    Args:
        dataset: Indexable container of records. Anything supporting
            ``dataset[idx]`` works; objects for which that fails (for example
            a pandas DataFrame, where ``[]`` selects columns) are read through
            ``dataset.iloc[idx]`` instead.
        smiles_name: Key/column holding the ligand SMILES string.
        seq_name: Key/column holding the protein sequence.

    Each item is a dict ``{'ligand': <SMILES>, 'protein': <sequence>}``.
    """

    def __init__(self, dataset, smiles_name='smiles', seq_name='seq'):
        self.dataset = dataset
        self.seq_name = seq_name
        self.smiles_name = smiles_name

    def __getitem__(self, idx):
        """Return the record at position ``idx`` as a ligand/protein dict."""
        # `except Exception` (not a bare `except`) keeps KeyboardInterrupt and
        # SystemExit propagating instead of being mistaken for a lookup miss.
        try:
            item = self.dataset[idx]
        except Exception:
            item = self.dataset.iloc[idx]

        raw_smiles = item[self.smiles_name]
        try:
            # Imported here so the class does not rely on a later notebook
            # cell having bound `Chem` in the global namespace.
            from rdkit import Chem

            # Canonicalise the SMILES by a round trip through an RDKit molecule.
            smiles_canonical = str(Chem.MolToSmiles(Chem.MolFromSmiles(raw_smiles)))
        except Exception:
            # Best effort: if RDKit is unavailable or cannot handle the input,
            # pass the value through unchanged (as a string).
            smiles_canonical = str(raw_smiles)

        return {'ligand': smiles_canonical,
                'protein': item[self.seq_name]}

    def __len__(self):
        """Number of records in the wrapped container."""
        return len(self.dataset)

In [23]:
from rdkit import Chem

# Read the 4mds ligand from its SDF file and convert it to a SMILES string.
smi_4mds = Chem.MolToSmiles(
    Chem.MolFromMolFile('TwoFold_DL/examples/4mds_23H_ligand.sdf')
)

In [50]:
# Input table for inference: one protein sequence ('seq') paired with the
# ligand SMILES ('smiles'). The active entry is the 4mds sequence; the
# commented-out entries are alternative Mpro inputs (Andrii's Mpro1-199,
# full-length Mpro, Mpro200-306).
df = pd.DataFrame({'seq': [
     #'SGFRKMAFPSGKVEGCMVQVTCGTTTLNGLWLDDVVYCPRHVICTSEDMLNPNYEDLLIRKSNHNFLVQAGNVQLRVIGHSMQNCVLKLKVDTANPKTPKYKFVRIQPGQTFSVLACYNGSPSGVYQCAMRPNFTIKGSFLNGSCGSVGFNIDYDCVSFCYMHHMELPTGVHAGTDLEGNFYGPFVDRQTAQAAGTDTT' # Mpro1-199
#    'SGFRKMAFPSGKVEGCMVQVTCGTTTLNGLWLDDVVYCPRHVICTSEDMLNPNYEDLLIRKSNHNFLVQAGNVQLRVIGHSMQNCVLKLKVDTANPKTPKYKFVRIQPGQTFSVLACYNGSPSGVYQCAMRPNFTIKGSFLNGSCGSVGFNIDYDCVSFCYMHHMELPTGVHAGTDLEGNFYGPFVDRQTAQAAGTDTTITVNVLAWLYAAVINGDRWFLNRFTTTLNDFNLVAMKYNYEPLTQDHVDILGPLSAQTGIAVLDMCASLKELLQNGMNGRTILGSALLEDEFTPFDVVRQCSGVTFQ' #Full Mpro
    'SGFRKMAFPSGKVEGCMVQVTCGTTTLNGLWLDDTVYCPRHVICTAEDMLNPNYEDLLIRKSNHSFLVQAGNVQLRVIGHSMQNCLLRLKVDTSNPKTPKYKFVRIQPGQTFSVLACYNGSPSGVYQCAMRPNHTIKGSFLNGSCGSVGFNIDYDCVSFCYMHHMELPTGVHAGTDLEGKFYGPFVDRQTAQAAGTDTTITLNVLAWLYAAVINGDRWFLNRFTTTLNDFNLVAMKYNYEPLTQDHVDILGPLSAQTGIAVLDMCAALKELLQNGMNGRTILGSTILEDEFTPFDVVRQCSGASGFRKMAFPSGKVEGCMVQVTCGTTTLNGLWLDDTVYCPRHVICTAEDMLNPNYEDLLIRKSNHSFLVQAGNVQLRVIGHSMQNCLLRLKVDTSNPKTPKYKFVRIQPGQTFSVLACYNGSPSGVYQCAMRPNHTIKGSFLNGSCGSVGFNIDYDCVSFCYMHHMELPTGVHAGTDLEGKFYGPFVDRQTAQAAGTDTTITLNVLAWLYAAVINGDRWFLNRFTTTLNDFNLVAMKYNYEPLTQDHVDILGPLSAQTGIAVLDMCAALKELLQNGMNGRTILGSTILEDEFTPFDVVRQCSGA' # 4mds; NOTE(review): appears to be the same chain written twice back to back (dimer input) — confirm
    #'VNVLAWLYAAVINGDRWFLNRFTTTLNDFNLVAMKYNYEPLTQDHVDILGPLSAQTGIAVLDMCASLKELLQNGMNGRTILGSALLEDEFTPFDVVRQCSGVTFQ' # Mpro200-306
], 
                   # Alternative: run with an empty ligand string.
                   #'smiles': ['']})
                   'smiles': smi_4mds})

In [51]:
# Display the number of residues in each input sequence.
df['seq'].str.len()

0    606
Name: seq, dtype: int64

In [52]:
# Wrap the table so each row is served as a {'ligand', 'protein'} record.
dataset = ProteinLigandDataset(df)

In [53]:
# Set the model's `enable_cross` flag before running inference.
# NOTE(review): presumably switches on the protein-ligand cross-attention
# path — confirm against contact_pred.models.
pipeline.model.enable_cross = True

In [54]:
# Run the pipeline over every record and keep all predictions; the table
# holds a single protein/ligand pair, so the first prediction is the result.
output = [prediction for prediction in pipeline(dataset)]
pred = output[0]

  aatypes = torch.tensor(self.input_ids_to_aatype[input_ids_1], device=input_ids_1.device)#, requires_grad=False)


In [55]:
from contact_pred.residue_constants import restype_name_to_atom14_names, restype_1to3
def write_pdb_no_ref(f, seq, feat):
    """Write every named atom of each residue in ``seq`` as PDB ATOM records.

    Args:
        f: Writable text file object receiving the records.
        seq: Iterable of one-letter residue codes (keys of ``restype_1to3``).
        feat: Coordinate array indexed as ``feat[0, position + 1, atom_slot]``;
            index 0 of the second axis is skipped (presumably a leading
            special token — confirm).

    Residues are numbered from 1 on chain 'A'; atom serials count up from 1.
    Occupancy and temperature factor are written as 1.0, and the element
    column is the first character of the atom name.
    """
    serial = 1
    for position, one_letter in enumerate(seq):
        res = restype_1to3[one_letter]
        residue_number = str(position + 1)
        for slot, atom in enumerate(restype_name_to_atom14_names[res]):
            # Empty names mark unused slots in the 14-atom layout.
            if atom == '':
                continue
            xyz = feat[0, position + 1, slot]
            write_pdb_line(f, 'ATOM', str(serial), atom, res, 'A', residue_number, *xyz, 1.0, 1.0, atom[0])
            serial += 1

def write_Calpha_no_ref(f, seq, feat):
    """Write one 'CA' ATOM record per residue of ``seq``.

    Args:
        f: Writable text file object receiving the records.
        seq: Iterable of one-letter residue codes (keys of ``restype_1to3``).
        feat: Coordinate array indexed as ``feat[0, position + 1]``; index 0
            of the second axis is skipped (presumably a leading special
            token — confirm).

    Atom serial and residue number both start at 1 and advance together;
    chain is 'A', occupancy and temperature factor are 1.0, element is 'C'.
    """
    for position, one_letter in enumerate(seq):
        number = str(position + 1)
        res = restype_1to3[one_letter]
        xyz = feat[0, position + 1]
        write_pdb_line(f, 'ATOM', number, 'CA', res, 'A', number, *xyz, 1.0, 1.0, 'C')

In [56]:
def write_pdb_line(f,*j):
    """Write one fixed-width PDB coordinate record (e.g. ATOM) to ``f``.

    Args:
        f: Writable text file object.
        *j: Twelve positional fields, in this order:
            record name, atom serial, atom name, residue name, chain ID,
            residue number (all strings); x, y, z, occupancy, temperature
            factor (anything ``float()`` accepts); element symbol (string).
            Fields after the twelfth are ignored.

    Raises:
        IndexError: If fewer than twelve fields are given.

    The atom name follows the PDB column convention: names of one-letter
    elements start in column 14 (``' CA '``, ``' CG1'``), while four-character
    names and names of two-letter elements start in column 13 (``'FE  '``).
    """
    if len(j) < 12:
        raise IndexError('write_pdb_line expects 12 fields, got %d' % len(j))
    (record, serial, atom_name, res_name, chain_id, res_seq,
     x, y, z, occupancy, temp_factor, element) = j[:12]

    # Centering the name would place three-character names such as 'CG1' in
    # column 13, which is reserved for the first letter of a two-letter element.
    if len(atom_name) >= 4 or len(element.strip()) >= 2:
        name_field = atom_name.ljust(4)
    else:
        name_field = ' ' + atom_name.ljust(3)

    fields = (
        record.ljust(6),              # columns  1-6   record name
        serial.rjust(5),              # columns  7-11  atom serial
        ' ',
        name_field,                   # columns 13-16  atom name
        ' ',                          # column  17     alternate location (blank)
        res_name.ljust(3),            # columns 18-20  residue name
        ' ',
        chain_id.rjust(1),            # column  22     chain ID
        res_seq.rjust(4),             # columns 23-26  residue number
        '    ',                       # columns 27-30  insertion code + padding
        '%8.3f' % float(x),           # columns 31-38
        '%8.3f' % float(y),           # columns 39-46
        '%8.3f' % float(z),           # columns 47-54
        '%6.2f' % float(occupancy),   # columns 55-60
        '%6.2f' % float(temp_factor), # columns 61-66
        element.rjust(12),            # element right-justified, ending in column 78
    )
    f.write(''.join(fields) + '\n')
                                                  
# Save the predicted all-atom receptor coordinates as a PDB file.
# Output name used for the Mpro1-199 input instead:
#with open(f'TwoFold_DL/examples/pred_Mpro1-199_ligand_7s3s.pdb','w') as f:
# NOTE(review): this file name says "monomer" although the ligand is later
# saved as "..._dimer.sdf" for the same prediction — confirm the intended name.
with open('TwoFold_DL/examples/pred_Mpro_monomer_ligand_4mds.pdb', mode='w') as f:
    feat = pred['receptor_xyz']
    write_pdb_no_ref(f, df['seq'][0], feat)
    # C-alpha-only alternative:
    #feat = pred['receptor_frames_xyz']
    #write_Calpha_no_ref(f, df['seq'][0], feat)

In [67]:
# update molecule coordinates using prediction
# update molecule coordinates using prediction
from rdkit.Geometry import Point3D
from rdkit import Chem
from rdkit.Chem import AllChem
from utils.token_coords import get_token_coords

# Build the ligand and give it a conformer BEFORE anything uses `mol`.
# Calling get_token_coords(mol) first only worked when `mol` was left over
# from an earlier run; on a fresh kernel it raised NameError.
mol = Chem.MolFromSmiles(smi_4mds)
if AllChem.EmbedMolecule(mol) == -1:
    # EmbedMolecule returns -1 when it cannot generate 3D coordinates.
    raise RuntimeError('RDKit could not embed a 3D conformer for the ligand')
conf = mol.GetConformer()

# atom_map[i] is the atom index belonging to the i-th SMILES token, or None
# for tokens that are not atoms (assumed from its use below — confirm in
# utils.token_coords). The other return values are not used here.
smi, ligand_xyz_ref, tokens, atom_map = get_token_coords(mol)

# Drop the first and last predicted positions (presumably the special
# start/end tokens) and overwrite the embedded coordinates of mapped atoms.
for i, xyz in enumerate(pred['ligand_frames_xyz'].squeeze(0)[1:-1]):
    idx = atom_map[i]

    if idx is not None:
        conf.SetAtomPosition(idx,Point3D(*xyz.astype(np.double)))

with Chem.SDWriter('TwoFold_DL/examples/ligand_pred_4mds_dimer.sdf') as w:
    w.write(mol)

[17:12:24] Molecule does not have explicit Hs. Consider calling AddHs()
