In [1]:
import os
import sys
# Base directory two levels up (presumably the project root); it is added to
# sys.path so modules that live under it can be imported from this notebook.
home_dir = "../../"
module_path = os.path.abspath(home_dir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [42]:
import torch
from tape import ProteinBertModel, TAPETokenizer
# Local cache for the pretrained weights, resolved relative to `home_dir`.
cache_dir = os.path.join(home_dir, "models", "tape_rao_1", "cache", "protbert")
# The original call had `force_download=True` commented out; pass it to
# re-fetch the weights instead of reusing the cache.
model = ProteinBertModel.from_pretrained('bert-base', cache_dir=cache_dir)
# Inference only: make sure dropout is off so embeddings are deterministic
# (harmless if from_pretrained already did this).
model.eval()
# 'iupac' is the vocab for the TAPE transformer models; use 'unirep' for the UniRep model.
tokenizer = TAPETokenizer(vocab='iupac')

In [5]:
# Token -> integer id mapping: special tokens (<pad>, <mask>, <cls>, <sep>, <unk>) first, then residue letters.
vocab_mapping = tokenizer.vocab
vocab_mapping

OrderedDict([('<pad>', 0),
             ('<mask>', 1),
             ('<cls>', 2),
             ('<sep>', 3),
             ('<unk>', 4),
             ('A', 5),
             ('B', 6),
             ('C', 7),
             ('D', 8),
             ('E', 9),
             ('F', 10),
             ('G', 11),
             ('H', 12),
             ('I', 13),
             ('K', 14),
             ('L', 15),
             ('M', 16),
             ('N', 17),
             ('O', 18),
             ('P', 19),
             ('Q', 20),
             ('R', 21),
             ('S', 22),
             ('T', 23),
             ('U', 24),
             ('V', 25),
             ('W', 26),
             ('X', 27),
             ('Y', 28),
             ('Z', 29)])

In [47]:
# A plain string and a list of residue characters encode to the same ids,
# wrapped in <cls> (2) ... <sep> (3).
for encoder_input in ("MVNST", list("MVNST")):
    print(tokenizer.encode(encoder_input))

[ 2 16 25 17 22 23  3]
[ 2 16 25 17 22 23  3]


In [43]:
import numpy as np
def get_embed(seq, bert_model=None, bert_tokenizer=None):
    """Return the mean-pooled residue embedding of a protein sequence.

    Args:
        seq: Sequence to embed, as a string or a list of one-letter residue
            codes (the tokenizer's ``encode`` accepts both).
        bert_model: Model to run; defaults to the notebook-level ``model``.
            It is called on a ``(1, n_tokens)`` id tensor, and ``output[0][0]``
            is taken as the per-token hidden states (TODO confirm for non-TAPE models).
        bert_tokenizer: Object with an ``encode`` method; defaults to the
            notebook-level ``tokenizer``.

    Returns:
        1-D ``numpy.ndarray`` (length 768 for the 'bert-base' model used here):
        the mean over token positions, excluding the first and last tokens,
        which are the ``<cls>``/``<sep>`` markers that ``encode`` adds.

    Raises:
        ValueError: if ``seq`` encodes to no residue tokens (e.g. an empty
            sequence); the mean would otherwise be an all-NaN vector.
    """
    if bert_model is None:
        bert_model = model
    if bert_tokenizer is None:
        bert_tokenizer = tokenizer

    token_ids = np.asarray(bert_tokenizer.encode(seq))
    # Only <cls> and <sep>: nothing is left to average after stripping them.
    if len(token_ids) <= 2:
        raise ValueError("cannot embed an empty sequence")

    with torch.no_grad():
        # Add a batch dimension: (n_tokens,) -> (1, n_tokens).
        batch = torch.as_tensor(token_ids).unsqueeze(0)
        output = bert_model(batch)
        # First output element, first (only) batch item -> (n_tokens, hidden).
        hidden = output[0][0].detach().numpy()

    # Drop <cls> (first) and <sep> (last), then average over residues.
    return hidden[1:-1].mean(0)

In [44]:
# Wild-type sequence (one-letter residue codes). Later cells refer to it as
# both `seq` and `wt_seq`, so both names are bound to the same string.
wt_seq = seq = 'MVNSTHRGMHTSLHLWNRSSYRLHSNASESLGKGYSDGGCYEQLFVSPEVFVTLGVISLLENILVIVAIAKNKNLHSPMYFFICSLAVADMLVSVSNGSETIVITLLNSTDTDAQSFTVNIDNVIDSVICSSLLASICSLLSIAVDRYFTIFYALQYHNIMTVKRVGIIISCIWAACTVSGILFIIYSDSSAVIICLITMFFTMLALMASLYVHMFLMARLHIKRIAVLPGTGAIRQGANMKGAITLTILIGVFVVCWAPFFLHLIFYISCPQNPYCVCFMSHFNLYLILIMCNSIIDPLIYALRSQELRKTFKEIICCYPLGGLCDLSSRY'
wt_embedding = get_embed(wt_seq)
wt_embedding

(768,)


array([ 2.80448467e-01, -7.94637680e-01,  3.08121204e-01, -2.54442871e-01,
        1.18473239e-01, -2.66037285e-01,  5.21305025e-01,  1.70480236e-02,
        5.78735955e-02,  6.44447803e-02,  2.22400382e-01, -2.70326465e-01,
        1.00110078e+00,  6.87336206e-01,  4.54183191e-01,  4.20318842e-01,
        1.38342154e+00,  9.47494134e-02,  4.89116758e-01, -3.16154689e-01,
       -2.18339741e-01, -7.65409768e-01, -2.43496373e-01,  1.10215224e-01,
       -5.38218558e-01,  5.33561707e-01,  8.64228845e-01, -1.29203424e-01,
        1.11910380e-01,  6.29362985e-02,  5.22455215e-01, -6.30238593e-01,
        7.82978255e-03, -9.06239688e-01,  3.99620503e-01, -5.73510379e-02,
       -2.59183615e-01, -3.39372344e-02, -5.46194255e-01, -5.96532896e-02,
        1.40064612e-01, -4.19979841e-01,  2.34356552e-01,  2.00310811e-01,
       -3.22492272e-01,  2.85786569e-01, -1.68238044e-01, -2.78418154e-01,
        1.14061058e-01,  5.48120677e-01,  7.92972445e-01,  5.32257594e-02,
       -8.15155208e-01, -

In [45]:
def make_point_mutant(sequence, position, expected_aa, new_aa):
    """Return a copy of ``sequence`` with one residue substituted.

    Args:
        sequence: Wild-type sequence string.
        position: 1-based index of the residue to replace.
        expected_aa: Residue that must currently sit at ``position``.
        new_aa: Replacement residue (exactly one character).

    Returns:
        The mutated sequence; ``sequence`` itself is not modified.

    Raises:
        IndexError: if ``position`` is outside ``1..len(sequence)``. A 0 or
            negative position would otherwise silently edit from the end.
        ValueError: if the residue at ``position`` is not ``expected_aa``, or
            ``new_aa`` is not a single character.
    """
    if not 1 <= position <= len(sequence):
        raise IndexError(f"position {position} is outside 1..{len(sequence)}")
    idx = position - 1
    if sequence[idx] != expected_aa:
        raise ValueError(
            f"expected {expected_aa!r} at position {position}, found {sequence[idx]!r}"
        )
    if len(new_aa) != 1:
        raise ValueError(f"new_aa must be a single residue, got {new_aa!r}")
    return sequence[:idx] + new_aa + sequence[idx + 1:]


one_indexed_mut_pos, wt_aa, mt_aa = 271, "C", "Y"

mt_seq = make_point_mutant(wt_seq, one_indexed_mut_pos, wt_aa, mt_aa)
print(wt_seq[one_indexed_mut_pos-1], wt_aa,  mt_seq[one_indexed_mut_pos-1], mt_aa)

mt_embedding = get_embed(mt_seq)
mt_embedding

C C Y Y
(768,)


array([ 2.85825074e-01, -8.01807463e-01,  3.04910451e-01, -2.45637417e-01,
        1.13982409e-01, -2.68112689e-01,  5.14541268e-01,  1.10080075e-02,
        7.08159804e-02,  3.83006707e-02,  2.16815904e-01, -2.62443244e-01,
        9.95166063e-01,  6.89275742e-01,  4.39479977e-01,  4.00036633e-01,
        1.36523712e+00,  9.95224267e-02,  4.74765420e-01, -2.99265057e-01,
       -1.96711868e-01, -7.41861761e-01, -2.50972003e-01,  9.57724750e-02,
       -5.48717022e-01,  5.36947489e-01,  8.49112630e-01, -1.32139623e-01,
        1.13996528e-01,  7.39057288e-02,  5.10660768e-01, -6.58378720e-01,
        1.01198927e-02, -8.81629407e-01,  4.03865963e-01, -4.97584045e-02,
       -2.61686683e-01, -3.30366194e-02, -5.41803837e-01, -4.78569828e-02,
        1.42444864e-01, -4.08200771e-01,  2.36097500e-01,  1.95851907e-01,
       -3.45793903e-01,  2.67188281e-01, -1.60611168e-01, -2.89834708e-01,
        1.16268791e-01,  5.39804697e-01,  7.70718098e-01,  6.14510551e-02,
       -8.32267463e-01, -

In [46]:
# Euclidean (L2) distance between the mutant and wild-type mean embeddings.
embedding_shift = mt_embedding - wt_embedding
np.linalg.norm(embedding_shift)

0.3075448