In [18]:
import re
from rdkit import Chem

In [19]:
def add_space(pro):
    return pro.replace('',' ').strip()

In [20]:
def rm_map_number(smiles):
    t = re.sub(':\d*', '', smiles)
    return t
def canonicalize(smiles):
    try:
        smiles = rm_map_number(smiles)
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        else:
            return Chem.MolToSmiles(mol)
    except:
        return None

In [21]:
def smi_tokenizer(smi):
    pattern =  "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    regex = re.compile(pattern)
    tokens = [token for token in regex.findall(smi)]
    try:
        assert re.sub('\s+', '', smi) == ''.join(tokens)
    except:
        return ''

    return ' '.join(tokens)

In [22]:
# smiles_raw = "CN1CCN(CC1)CCCN2C3=CC=CC=C3SC4=C2C=C(C=C4)Cl"
smiles_5EH = "CN(C)CC/C=C/1\c2ccccc2COc3c1cccc3"
smiles_D7V = "CN(C)CC/C=C\\1/c2ccccc2COc3c1cccc3"
smiles_OLC = "CCCCCCCC\\C=C/CCCCCCCC(=O)OC[C@@H](CO)O"
smiles_PO4 = "[O-]P(=O)([O-])[O-]"
smiles_537 = "c1ccc2c(c1)-c3c4c(cccc4[nH]n3)C2=O"
smiles_2GM = "C[C@@]1(C(=O)N2[C@H](C(=O)N3CCC[C@H]3[C@@]2(O1)O)Cc4ccccc4)NC(=O)[C@@H]5C[C@@H]6c7cccc8c7c(c[nH]8)C[C@H]6N(C5)C"
smiles_IHI = "c1ccc(c(c1)Nc2c3c(nc(n2)C#N)n(cn3)C4CCCC4)OCCCn5ccnc5"
smiles_raw = smiles_IHI
smiles_can = canonicalize(smiles_raw)
smiles_bpe = smi_tokenizer(smiles_can)
smiles_can

'N#Cc1nc(Nc2ccccc2OCCCn2ccnc2)c2ncn(C3CCCC3)c2n1'

In [23]:
protein_0 = "MFGLKRNAVIGLNLYCGGAGLGAGSGGATRPGGRLLATEKEASARREIGGGEAGAVIGGSAGASPPSTLTPDSRRVARPPPIGAEVPDVTATPARLLFFAPTRRAAPLEEMEAPAADAIMSPEEELDGYEPEPLGKRPAVLPLLELVGESGNNTSTDGSLPSTPPPAEEEEDELYRQSLEIISRYLREQATGAKDTKPMGRSGATSRKALETLRRVGDGVQRNHETAFQGMLRKLDIKNEDDVKSLSRVMIHVFSDGVTNWGRIVTLISFGAFVAKHLKTINQESCIEPLAESITDVLVRTKRDWLVKQRGWDGFVEFFHVEDLEGGIRNVLLAFAGVAGVGAGLAYLIR"
protein_1 = "MSFLGFGGGQPQLSSQQKIQAAEAELDLVTDMFNKLVNNCYKKCINTSYSEGELNKNESSCLDRCVAKYFETNVQVGENMQKMGQSFNAAGKF"
protein_2 = "MPTVDDILEQVGESGWFQKQAFLILCLLSAAFAPICVGIVFLGFTPDHHCQSPGVAELSQRCGWSPAEELNYTVPGLGPAGEAFLGQCRRYEVDWNQSALSCVDPLASLATNRSHLPLGPCQDGWVYDTPGSSIVTEFNLVCADSWKLDLFQSCLNAGFLFGSLGVGYFADRFGRKLCLLGTVLVNAVSGVLMAFSPNYMSMLLFRLLQGLVSKGNWMAGYTLITEFVGSGSRRTVAIMYQMAFTVGLVALTGLAYALPHWRWLQLAVSLPTFLFLLYYWCVPESPRWLLSQKRNTEAIKIMDHIAQKNGKLPPADLKMLSLEEDVTEKLSPSFADLFRTPRLRKRTFILMYLWFTDSVLYQGLILHMGATSGNLYLDFLYSALVEIPGAFIALITIDRVGRIYPMAMSNLLAGAACLVMIFISPDLHWLNIIIMCVGRMGITIAIQMICLVNAELYPTFVRNLGVMVCSSLCDIGGIITPFIVFRLREVWQALPLILFAVLGLLAAGVTLLLPETKGVALPETMKDAENLGRKAKPKENTIYLKVQTSEPSGT"
protein_3M0W = "ACPLEKALDVMVSTFHKYSGKEGDKFKLNKSELKELLTRELPSFLGKRTDEAAFQKLMSNLDSNRDNEVDFQEYCVFLSCIAMMCNEFFEGFPDKQPRKK"
protein_3RZE = "TTMASPQLMPLVVVLSTICLVTVGLNLLVLYAVRSERKLHTVGNLYIVSLSVADLIVGAVVMPMNILYLLMSKWSLGRPLCLFWLSMDYVASTASIFSVFILCIDRYRSVQQPLRYLKYRTKTRASATILGAWFLSFLWVIPILGWNHFMQQTSVRREDKCETDFYDVTWFKVMTAIINFYLPTLLMLWFYAKIYKAVRQHCNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNTNGVITKDEAEKLFNQDVDAAVRGILRNAKLKPVYDSLDAVRRAALINMVFQMGETGVAGFTNSLRMLQQKRWDEAAVNLAKSRWYNQTPNRAKRVITTFRTGTWDAYLHMNRERKAAKQLGFIMAAFILCWIPYFIFFMVIAFCKNCCNEHLHMFTIWLGYINSTLNPLIYPLCNENFKKTFKRILHIRSGENLYFQ"
protein_4IAQ = "GGTCSAKDYIYQDSISLPWKVLLVMLLALITLATTLSNAFVIATVYRTRKLHTPANYLIASLAVTDLLVSILVMPISTMYTVTGRWTLGQVVCDFWLSSDITCCTASIWHLCVIALDRYWAITDAVEYSAKRTPKRAAVMIALVWVFSISISLPPFFWRQAKAEEEVSECVVNTDHILYTVYSTVGAFYFPTLLLIALYGRIYVEARSRIADLEDNWETLNDNLKVIEKADNAAQVKDALTKMRAAALDAQKATPPKLEDKSPDSPEMKDFRHGFDILVGQIDDALKLANEGKVKEAQAAAEQLKTTRNAYIQKYLLMAARERKATKTLGIILGAFIVCWLPFFIISLVMPICKDACWFHLAIFDFFTWLGYLNSLINPIIYTMSNEDFKQAFHKLIRFKCTS"
protein_1UKI = "MSRSKRDNNFYSVEIGDSTFTVLKRYQNLKPIGSGAQGIVCAAYDAILERNVAIKKLSRPFQNQTHAKRAYRELVLMKCVNHKNIIGLLNVFTPQKSLEEFQDVYIVMELMDANLCQVIQMELDHERMSYLLYQMLCGIKHLHSAGIIHRDLKPSNIVVKSDCTLKILDFGLARTAGTSFMMTPYVVTRYYRAPEVILGMGYKENVDIWSVGCIMGEMIKGGVLFPGTDHIDQWNKVIEQLGTPCPEFMKKLQPTVRTYVENRPKYAGYSFEKLFPDVLFPADSEHNKLKASQARDLLSKMLVIDASKRISVDEALQHPYINVWYDPSEAEAPPPKIPDKQLDEREHTIEEWKELIYKEVMDLHHHHHH"
protein_1U9W = "GRAPDSVDYRKKGYVTPVKNQGQCGSCWAFSSVGALEGQLKKKTGKLLNLSPQNLVDCVSENDGCGGGYMTNAFQYVQKNRGIDSEDAYPYVGQEESCMYNPTGKAAKCRGYREIPEGNEKALKRAVARVGPVSVAIDASLTSFQFYSKGVYYDESCNSDNLNHAVLAVGYGIQKGNKHWIIKNSWGENWGNKGYILMARNKNNACGIANLASFPKM"

In [24]:
protein = add_space(protein_1U9W)
protein

'G R A P D S V D Y R K K G Y V T P V K N Q G Q C G S C W A F S S V G A L E G Q L K K K T G K L L N L S P Q N L V D C V S E N D G C G G G Y M T N A F Q Y V Q K N R G I D S E D A Y P Y V G Q E E S C M Y N P T G K A A K C R G Y R E I P E G N E K A L K R A V A R V G P V S V A I D A S L T S F Q F Y S K G V Y Y D E S C N S D N L N H A V L A V G Y G I Q K G N K H W I I K N S W G E N W G N K G Y I L M A R N K N N A C G I A N L A S F P K M'

In [25]:
# len(protein_3RZE)

In [26]:
from fairseq.models.roberta import RobertaModel
import numpy as np
roberta = RobertaModel.from_pretrained(
    f'/protein/users/v-qizhipei/checkpoints/roberta_char_bsz256_nopretrain_separate_wd0.1_dp0.1_layer12_hongda_mlm_regression_cross_attn_mod_pad',
    checkpoint_file=f'checkpoint154.pt',
    data_name_or_path='/protein/users/v-qizhipei/data-bin/BindingDB_hongda_char_for_pretrain',
    arch = 'roberta_dti_mlm_regress_case_study'
)

roberta.cuda()
roberta.eval()

RobertaHubInterface(
  (model): RobertaDTIMLMRegressCaseStudy(
    (encoder_0): DTIRobertaEncoder(
      (sentence_encoder): TransformerSentenceEncoder(
        (dropout_module): FairseqDropout()
        (embed_tokens): Embedding(2353, 768, padding_idx=1)
        (embed_positions): LearnedPositionalEmbedding(514, 768, padding_idx=1)
        (layers): ModuleList(
          (0): TransformerSentenceEncoderLayer(
            (dropout_module): FairseqDropout()
            (activation_dropout_module): FairseqDropout()
            (self_attn): MultiheadAttention(
              (dropout_module): FairseqDropout()
              (k_proj): Linear(in_features=768, out_features=768, bias=True)
              (v_proj): Linear(in_features=768, out_features=768, bias=True)
              (q_proj): Linear(in_features=768, out_features=768, bias=True)
              (out_proj): Linear(in_features=768, out_features=768, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((768,), eps=1e-05,

In [27]:
tokens_0, tokens_1 = roberta.myencode_separate(smiles_bpe, protein)
predictions, cls_0_attn_1, cls_1_attn_0 = roberta.myextract_features_separate_case_study(tokens_0, tokens_1)
predictions, cls_0_attn_1, cls_1_attn_0

(tensor([[5.6512]], device='cuda:0', grad_fn=<AddmmBackward0>),
 tensor([[[0.0077, 0.0002, 0.0018, 0.0017, 0.0068, 0.0005, 0.0031, 0.0003,
           0.0004, 0.0049, 0.0107, 0.0010, 0.0017, 0.0006, 0.0019, 0.0008,
           0.0010, 0.0089, 0.0002, 0.0018, 0.0011, 0.0062, 0.0006, 0.0005,
           0.0188, 0.0009, 0.0023, 0.0200, 0.0153, 0.0016, 0.0012, 0.0008,
           0.0020, 0.0009, 0.0003, 0.0010, 0.0007, 0.0012, 0.0015, 0.0036,
           0.0010, 0.0014, 0.0007, 0.0041, 0.0004, 0.0026, 0.0092, 0.0012,
           0.0004, 0.0031, 0.0003, 0.0040, 0.0090, 0.0014, 0.0098, 0.0003,
           0.0005, 0.0008, 0.0037, 0.0007, 0.0014, 0.0017, 0.0012, 0.0004,
           0.0011, 0.0109, 0.0005, 0.0018, 0.0007, 0.0034, 0.0020, 0.0006,
           0.0081, 0.0047, 0.0080, 0.0028, 0.0010, 0.0004, 0.0045, 0.0024,
           0.0017, 0.0061, 0.0003, 0.0007, 0.0005, 0.0074, 0.0010, 0.0006,
           0.0010, 0.0428, 0.0374, 0.0014, 0.0008, 0.0003, 0.0047, 0.0008,
           0.0006, 0.0010, 0.0043, 0

In [28]:
cls_0_attn_1 = cls_0_attn_1.squeeze()
cls_0_attn_1.size()

torch.Size([219])

In [29]:
len(cls_0_attn_1)

219

In [30]:
cls_0_attn_1

tensor([0.0077, 0.0002, 0.0018, 0.0017, 0.0068, 0.0005, 0.0031, 0.0003, 0.0004,
        0.0049, 0.0107, 0.0010, 0.0017, 0.0006, 0.0019, 0.0008, 0.0010, 0.0089,
        0.0002, 0.0018, 0.0011, 0.0062, 0.0006, 0.0005, 0.0188, 0.0009, 0.0023,
        0.0200, 0.0153, 0.0016, 0.0012, 0.0008, 0.0020, 0.0009, 0.0003, 0.0010,
        0.0007, 0.0012, 0.0015, 0.0036, 0.0010, 0.0014, 0.0007, 0.0041, 0.0004,
        0.0026, 0.0092, 0.0012, 0.0004, 0.0031, 0.0003, 0.0040, 0.0090, 0.0014,
        0.0098, 0.0003, 0.0005, 0.0008, 0.0037, 0.0007, 0.0014, 0.0017, 0.0012,
        0.0004, 0.0011, 0.0109, 0.0005, 0.0018, 0.0007, 0.0034, 0.0020, 0.0006,
        0.0081, 0.0047, 0.0080, 0.0028, 0.0010, 0.0004, 0.0045, 0.0024, 0.0017,
        0.0061, 0.0003, 0.0007, 0.0005, 0.0074, 0.0010, 0.0006, 0.0010, 0.0428,
        0.0374, 0.0014, 0.0008, 0.0003, 0.0047, 0.0008, 0.0006, 0.0010, 0.0043,
        0.0032, 0.0150, 0.0027, 0.0210, 0.0003, 0.0002, 0.0082, 0.0003, 0.0007,
        0.0008, 0.0056, 0.0033, 0.0005, 

In [31]:
print(smiles_can)
len(smiles_can)

N#Cc1nc(Nc2ccccc2OCCCn2ccnc2)c2ncn(C3CCCC3)c2n1


47

In [32]:
len(smiles_bpe.split())

47

In [33]:
tmp = cls_0_attn_1[1: len(cls_0_attn_1) - 1]
# tmp, tmp.size()

In [34]:
import torch
sorted, indices = torch.sort(tmp, descending=True)
indices = indices + 1
# sorted, indices
indices

tensor([215, 155, 171,  89,  90, 132, 102,  27, 113,  24, 190, 145,  28, 100,
        159, 129, 174,  65, 157,  10, 149, 201,  54, 206,  46,  52,  17, 189,
        105,  72,  74, 200, 125,  85,   4, 116, 165,  21,  81, 203, 109, 177,
        161, 153,   9,  94,  73,  78, 210, 186,  98,  43,  51, 112,  58,  39,
        185, 158,  69, 110,  99,   6,  49, 216, 122, 114, 213, 137,  75, 101,
        204,  45, 195, 192,  79,  26, 193, 183, 152,  32, 180,  70, 181,  14,
        141,  67,   2,  19, 120, 134,  80,  12,   3,  61, 119, 163,  29, 169,
        124, 151,  38, 196, 121, 191,  41,  91,  60, 187, 136, 150,  53, 144,
        179, 147,  47,  37,  30,  62, 164,  20,  64, 138,  35,  86, 166, 184,
         11,  76, 202,  88,  97,  40,  16, 143, 126, 140, 154, 212,  25, 148,
         33, 108,  92,  57,  95,  31, 175,  15, 127, 142,  59, 117,  83, 170,
        168, 107,  68,  42,  36,  87,  22, 178, 214,  13,  71, 205, 123, 160,
         96, 199, 188, 217,   5,  66, 176,  23, 111, 172, 156, 2

In [None]:
'''
protein_3M0W
A: 5 9 43 85 88
B: 5 9 13 43 84 85 87 88 92
C: 5 9 13 85 88
D: 5 9 13 43 85 88
E: 9 13 44 85 88
F: 5 9 13 44 85 88
G: 5 9 85 88
H: 5 9 13 85 88 89
I: 5 9 13 43 44 46 84 85 88
J: 5 9 85 88 92
'''

Molecule

In [21]:
def del_tensor_element(arr,index):
    arr1 = arr[0:index]
    arr2 = arr[index+1:]
    return torch.cat((arr1,arr2),dim=0)

In [22]:
cls_1_attn_0 = cls_1_attn_0.squeeze()
cls_1_attn_0.size()

torch.Size([39])

In [23]:
del_indices = []
tokens = smiles_bpe.split()
del_indices.append(0)
for i in range(37):
    if not tokens[i].isalpha():
        del_indices.append(i + 1)
del_indices.append(len(cls_1_attn_0) - 1)
del_indices

[0, 3, 7, 12, 14, 20, 23, 27, 29, 32, 33, 34, 37, 38]

In [24]:
for i in reversed(del_indices):
    cls_1_attn_0 = del_tensor_element(cls_1_attn_0, i)
cls_1_attn_0

tensor([0.0198, 0.0370, 0.0169, 0.0194, 0.0491, 0.0355, 0.0379, 0.0446, 0.0865,
        0.0256, 0.0194, 0.0202, 0.0211, 0.0203, 0.0204, 0.0610, 0.0235, 0.0175,
        0.0167, 0.0189, 0.0304, 0.0164, 0.0145, 0.0219, 0.0265],
       device='cuda:0', grad_fn=<CatBackward0>)

In [25]:
import torch
sorted, indices = torch.sort(cls_1_attn_0, descending=True)
sorted, indices

(tensor([0.0865, 0.0610, 0.0491, 0.0446, 0.0379, 0.0370, 0.0355, 0.0304, 0.0265,
         0.0256, 0.0235, 0.0219, 0.0211, 0.0204, 0.0203, 0.0202, 0.0198, 0.0194,
         0.0194, 0.0189, 0.0175, 0.0169, 0.0167, 0.0164, 0.0145],
        device='cuda:0', grad_fn=<SortBackward0>),
 tensor([ 8, 15,  4,  7,  6,  1,  5, 20, 24,  9, 16, 23, 12, 14, 13, 11,  0, 10,
          3, 19, 17,  2, 18, 21, 22], device='cuda:0'))

Bad pipe message: %s [b'\xf6\x85\xf2\x87\x89\x9eo\xd0\x07sL2\x02\xaf\x0ba*\xac _o\xc0\x9f\n\x11BI\x99\x0b\x92\x8cc\xf4\xdf\x84\xbb\x8fC\xfa\xf9Y\r(\x84|\x84\xdc5\x05X\xda\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00 ', b'']
Bad pipe message: %s [b"\xf5o\x96\xe4_&\x89\xc0\xe2\x88\xc5\xe6\xdc@@\xc9U+\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x003\x002\x00\x9d\xc0\xa1\xc0\x9d\xc0Q