In [1]:
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer, EsmForMaskedLM
from tokenizers import Tokenizer
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pickle


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from plm_compare_esm import *
from protein_data import *

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

model_name = "facebook/esm2_t33_650M_UR50D"
model, tokenizer = initialize_esm2(model_name)

Using cpu device


In [13]:
amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
aa_list = [x for x in amino_acids]

In [6]:
# this was the probable sequence output for A102E OTC_HUMAN 
# according to the log_probs matrix from ESM2

sequence = 'MLFNLRILLNNAAFRNGHNFMVRNFRCGQPLQNKVQLKGRDLLTLKNFTGEEIKYMLWLSADLKFRIKQKGEYLPLLQGKSLGMIFEKRSTRTRLSTETGFELLGGHPCFLTTQDIHLGVNESLTDTARVLSSMADAVLARVYKQSDLDTLAKEASIPIINGLSDLYHPIQILADYLTLQEHYSSLKGLTLSWIGDGNNILHSIMMSAAKFGMHLQAATPKGYEPDASVTKLAEQYAKENGTKLLLTNDPLEAAHGGNVLITDTWISMGQEEEKKKRLQAFQGYQVTMKTAKVAASDWTFLHCLPRKPEEVDDEVFYSPRSLVFPEAENRKWTIMAVMVSLLTDYSPQLQKPKF'
esm_sequence = "MLFNLRILLNNAAFRNGHNFMVRNFRTGQPLQNKVQLKGRDLLTLKDFTGEEIKYMLDLSADLKFRIKQKGEYLPLLQGKSLGMIFEKRSTRTRLSTETGMTLLGGHPIFLTTQDIQLGVNESLTDTARVLSSMLDAVMARVYKQSDLDTLAKEASIPIINGLSDLYHPLQILADYLTLQEHYGSLKGLTLSWIGDGNNILHSWMMSAAKFGMNLRAATPKGYEPDASVTKLAEQYAKENGTKLLLTNDPLEAAKGANVLITDTWISMGQEEEKKKRLQAFQGYQVTMKLAKVAASDWTFLHCLPRKPEEVDDEVFYSPRSLVFPEAENRKWTIMAVMVSLLTDYSPQLQKPKF"

In [8]:
len(esm_sequence)

354

In [11]:
otc_h_A102E = dict()

lp, rlp, llr = collect_log_prob_esm2(esm_sequence, model, tokenizer)

otc_h_A102E['seq'] = sequence
otc_h_A102E['esm_seq'] = esm_sequence
otc_h_A102E['log_probs'] = lp
otc_h_A102E['ref_log_probs'] = rlp
otc_h_A102E['llr_matrix'] = llr

In [12]:
lp

tensor([[ -8.7734,  -9.0326, -10.1269,  ...,  -8.1146, -10.7045,  -9.2512],
        [ -2.0480,  -5.6513,  -6.4064,  ...,  -4.2742,  -7.0381,  -5.4738],
        [ -3.7519,  -4.6660,  -6.1298,  ...,  -4.3647,  -5.6725,  -4.8743],
        ...,
        [ -3.6057,  -5.6934,  -4.7780,  ...,  -3.5461,  -5.2440,  -3.9756],
        [ -4.2826,  -6.0343,  -5.5282,  ...,  -4.2942,  -5.8081,  -4.7914],
        [ -2.7027,  -5.1378,  -4.0889,  ...,  -2.7996,  -4.9509,  -3.3338]])

In [20]:
 
A = otc_h_A102E['log_probs']
eA = torch.exp(A)

n = len(eA)
ESM_mut2 = ""

count = 0
for i in range(n):
    valueA, indexA = torch.max(eA[i], dim=0)
    A_aa = aa_list[indexA]
    B_aa = esm_sequence[i]
    ESM_mut2 = ESM_mut2+A_aa
    if A_aa == B_aa:
        # print(i+1, "agree")
        count += 1
    else:
        print(i)
print(esm_sequence)
print(ESM_mut2)
print(f"They agree at {count} positions and disagree at {n-count} positions")

94
96
107
287
333
MLFNLRILLNNAAFRNGHNFMVRNFRTGQPLQNKVQLKGRDLLTLKDFTGEEIKYMLDLSADLKFRIKQKGEYLPLLQGKSLGMIFEKRSTRTRLSTETGMTLLGGHPIFLTTQDIQLGVNESLTDTARVLSSMLDAVMARVYKQSDLDTLAKEASIPIINGLSDLYHPLQILADYLTLQEHYGSLKGLTLSWIGDGNNILHSWMMSAAKFGMNLRAATPKGYEPDASVTKLAEQYAKENGTKLLLTNDPLEAAKGANVLITDTWISMGQEEEKKKRLQAFQGYQVTMKLAKVAASDWTFLHCLPRKPEEVDDEVFYSPRSLVFPEAENRKWTIMAVMVSLLTDYSPQLQKPKF
MLFNLRILLNNAAFRNGHNFMVRNFRTGQPLQNKVQLKGRDLLTLKDFTGEEIKYMLDLSADLKFRIKQKGEYLPLLQGKSLGMIFEKRSTRTRVSFETGMTLLGGHAIFLTTQDIQLGVNESLTDTARVLSSMLDAVMARVYKQSDLDTLAKEASIPIINGLSDLYHPLQILADYLTLQEHYGSLKGLTLSWIGDGNNILHSWMMSAAKFGMNLRAATPKGYEPDASVTKLAEQYAKENGTKLLLTNDPLEAAKGANVLITDTWISMGQEEEKKKRLQAFQGYQVTSKLAKVAASDWTFLHCLPRKPEEVDDEVFYSPRSLVFPEAENRKWTAMAVMVSLLTDYSPQLQKPKF
They agree at 349 positions and disagree at 5 positions


In [16]:
otc_h_A102E_2 = dict()

lp, rlp, llr = collect_log_prob_esm2(ESM_mut2, model, tokenizer)

#otc_h_A102E_2['seq'] = sequence
otc_h_A102E_2['esm_seq'] = ESM_mut2
otc_h_A102E_2['log_probs'] = lp
otc_h_A102E_2['ref_log_probs'] = rlp
otc_h_A102E_2['llr_matrix'] = llr

In [17]:
lp

tensor([[ -8.5529,  -8.7959,  -9.9032,  ...,  -7.8598, -10.4185,  -9.0471],
        [ -2.1592,  -5.7038,  -6.3150,  ...,  -4.2630,  -6.9779,  -5.4364],
        [ -3.8028,  -4.7254,  -6.0593,  ...,  -4.3939,  -5.6651,  -4.8482],
        ...,
        [ -3.6211,  -5.7774,  -4.8919,  ...,  -3.4805,  -5.2012,  -4.0108],
        [ -4.3227,  -6.1680,  -5.6949,  ...,  -4.3402,  -5.7744,  -4.8518],
        [ -2.5893,  -5.1565,  -4.0957,  ...,  -2.7746,  -4.8806,  -3.3529]])

In [19]:
 
A = otc_h_A102E_2['log_probs']
eA = torch.exp(A)

n = len(eA)
ESM_mut3 = ""

count = 0
for i in range(n):
    valueA, indexA = torch.max(eA[i], dim=0)
    A_aa = aa_list[indexA]
    B_aa = ESM_mut2[i]
    ESM_mut3 = ESM_mut3+A_aa
    if A_aa == B_aa:
        # print(i+1, "agree")
        count += 1
    else:
        print(i)
print(esm_sequence)
print(ESM_mut2)
print(ESM_mut3)
print(f"They agree at {count} positions and disagree at {n-count} positions")

102
353
MLFNLRILLNNAAFRNGHNFMVRNFRTGQPLQNKVQLKGRDLLTLKDFTGEEIKYMLDLSADLKFRIKQKGEYLPLLQGKSLGMIFEKRSTRTRLSTETGMTLLGGHPIFLTTQDIQLGVNESLTDTARVLSSMLDAVMARVYKQSDLDTLAKEASIPIINGLSDLYHPLQILADYLTLQEHYGSLKGLTLSWIGDGNNILHSWMMSAAKFGMNLRAATPKGYEPDASVTKLAEQYAKENGTKLLLTNDPLEAAKGANVLITDTWISMGQEEEKKKRLQAFQGYQVTMKLAKVAASDWTFLHCLPRKPEEVDDEVFYSPRSLVFPEAENRKWTIMAVMVSLLTDYSPQLQKPKF
MLFNLRILLNNAAFRNGHNFMVRNFRTGQPLQNKVQLKGRDLLTLKDFTGEEIKYMLDLSADLKFRIKQKGEYLPLLQGKSLGMIFEKRSTRTRVSFETGMTLLGGHAIFLTTQDIQLGVNESLTDTARVLSSMLDAVMARVYKQSDLDTLAKEASIPIINGLSDLYHPLQILADYLTLQEHYGSLKGLTLSWIGDGNNILHSWMMSAAKFGMNLRAATPKGYEPDASVTKLAEQYAKENGTKLLLTNDPLEAAKGANVLITDTWISMGQEEEKKKRLQAFQGYQVTSKLAKVAASDWTFLHCLPRKPEEVDDEVFYSPRSLVFPEAENRKWTAMAVMVSLLTDYSPQLQKPKF
MLFNLRILLNNAAFRNGHNFMVRNFRTGQPLQNKVQLKGRDLLTLKDFTGEEIKYMLDLSADLKFRIKQKGEYLPLLQGKSLGMIFEKRSTRTRVSFETGMTQLGGHAIFLTTQDIQLGVNESLTDTARVLSSMLDAVMARVYKQSDLDTLAKEASIPIINGLSDLYHPLQILADYLTLQEHYGSLKGLTLSWIGDGNNILHSWMMSAAKFGMNLRAATPKGYEPDASVTKLAEQYAKENGTKLLLTNDPLEAAKGANVLITDTWISMGQEEEKKKRLQAFQ