In [1]:
import os
import torch
import esm
from esm import Alphabet, FastaBatchedDataset, ProteinBertModel, pretrained, MSATransformer
import pathlib
import pandas as pd
import numpy as np

In [2]:
# Preferred device is GPU index 3. Fall back to the first GPU, or to the CPU,
# so the notebook still runs on hosts with fewer than four GPUs or none.
if torch.cuda.is_available() and torch.cuda.device_count() > 3:
    device = 'cuda:3'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [3]:
# Path to the ESM-2 checkpoint file. Set the ESM_MODEL_PATH environment variable
# to point at a different location; the original path is kept as the default.
model_location = os.environ.get(
    'ESM_MODEL_PATH',
    '/home2/kangboming/kangboming/workspace/PIC_revise/model/ESM/esm2_t33_650M_UR50D.pt',
)

In [4]:
### Load the ESM model
model, alphabet = esm.pretrained.load_model_and_alphabet(model_location)
# The notebook only runs inference, so switch off training-only behaviour
# (e.g. dropout) to keep the outputs deterministic.
model.eval()
batch_converter = alphabet.get_batch_converter()

In [5]:
### Get per-position probabilities (GPU version)
def get_logits(seq, model,alphabet,batch_converter,device,format=None):
    """Run the model on one sequence and return softmax probabilities per position.

    Despite the name, the returned values are probabilities: the model's
    "logits" output is passed through a softmax over the last axis.

    Args:
        seq: Sequence string; one character per residue.
        model: Callable model with a ``.to(device)`` method whose output
            mapping contains a ``"logits"`` tensor.
        alphabet: Object exposing ``all_toks``, used as vocabulary labels.
            Entries 4..23 are assumed to be the 20 standard amino acids
            -- TODO confirm for alphabets other than the ESM one.
        batch_converter: Callable turning ``[(label, seq), ...]`` into
            ``(labels, strs, tokens)``.
        device: Device string/object the tokens and the model are moved to.
        format: ``'pandas'`` -> DataFrame with one row per amino acid in
            ``AAorder`` and one column per residue, labelled
            ``<residue>_<1-based position>``; ``'array'`` -> the same table as
            a NumPy array; anything else -> raw array of shape
            (positions, vocabulary).

    Returns:
        DataFrame or NumPy array, depending on ``format``.

    Note:
        The model is moved to ``device`` on every call and its train/eval
        mode is left untouched; presumably the caller has already put it in
        eval mode.
    """
    AAorder=['R','K','H','E','D','N','Q','T','S','C','G','A','V','L','I','M','P','Y','F','W']
    # Only batch element 0 is ever read, so a single copy of the sequence is
    # enough; replicating it would multiply compute and memory for no benefit.
    data = [("_", seq)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(device)  # Move batch_tokens to GPU
    model = model.to(device)  # Move model to GPU
    with torch.no_grad():
        raw_logits = model(batch_tokens, repr_layers=[33], return_contacts=False)["logits"]
        probs = torch.softmax(raw_logits, dim=-1).cpu().numpy()
    # Drop the first and last token positions so that the remaining rows line
    # up one-to-one with the characters of seq.
    per_residue = probs[0][1:-1, :]
    if format in ('pandas', 'array'):
        # Label columns by vocabulary token, keep rows 4..23 and reorder to AAorder.
        WTlogits = pd.DataFrame(per_residue, columns=alphabet.all_toks, index=list(seq)).T.iloc[4:24].loc[AAorder]
        if format == 'array':
            return WTlogits.values
        WTlogits.columns = [j.split('.')[0] + '_' + str(i + 1) for i, j in enumerate(WTlogits.columns)]
        return WTlogits
    return per_residue

In [6]:
def get_logits_batch(seq_lst, model,alphabet,batch_converter,device,format=None):
    """Run the model on several sequences in one batch.

    All sequences go through a single forward pass. The result for each
    sequence is cut out of the batch output by that sequence's own length, so
    padding added for shorter sequences is not returned.

    Args:
        seq_lst: Iterable of sequence strings. A single string is treated as
            a one-element list rather than being split into characters.
        model: Callable model with a ``.to(device)`` method whose output
            mapping contains a ``"logits"`` tensor.
        alphabet: Object exposing ``all_toks``, used as vocabulary labels.
            Entries 4..23 are assumed to be the 20 standard amino acids
            -- TODO confirm for alphabets other than the ESM one.
        batch_converter: Callable turning ``[(label, seq), ...]`` into
            ``(labels, strs, tokens)``.
        device: Device string/object the tokens and the model are moved to.
        format: ``'pandas'`` -> DataFrame with one row per amino acid in
            ``AAorder`` and one column per residue, labelled
            ``<residue>_<1-based position>``; ``'array'`` -> the same table as
            a NumPy array; anything else -> raw array of shape
            (positions, vocabulary).

    Returns:
        A list with one entry per input sequence, in input order. An empty
        input gives an empty list.

    Note:
        The per-sequence slice assumes the converter emits exactly one
        leading special token followed by one token per character
        -- TODO confirm for the converter in use.
    """
    AAorder=['R','K','H','E','D','N','Q','T','S','C','G','A','V','L','I','M','P','Y','F','W']
    if isinstance(seq_lst, str):
        seq_lst = [seq_lst]
    seq_lst = list(seq_lst)
    if not seq_lst:
        return []
    data = [("_", seq) for seq in seq_lst]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(device)  # Move batch_tokens to GPU
    model = model.to(device)  # Move model to GPU
    with torch.no_grad():
        raw_logits = model(batch_tokens, repr_layers=[33], return_contacts=False)["logits"]
        probs = torch.softmax(raw_logits, dim=-1).cpu().numpy()
    results = []
    for row, seq in enumerate(seq_lst):
        # Skip the leading token and keep exactly len(seq) positions; a fixed
        # [1:-1] slice would include padding for the shorter sequences.
        per_residue = probs[row][1:len(seq) + 1, :]
        if format in ('pandas', 'array'):
            WTlogits = pd.DataFrame(per_residue, columns=alphabet.all_toks, index=list(seq)).T.iloc[4:24].loc[AAorder]
            if format == 'array':
                results.append(WTlogits.values)
            else:
                WTlogits.columns = [j.split('.')[0] + '_' + str(i + 1) for i, j in enumerate(WTlogits.columns)]
                results.append(WTlogits)
        else:
            results.append(per_residue)
    return results

In [7]:
def calculate_entropy(prob_matrix):
    """Return the Shannon entropy, in bits, along the last axis of ``prob_matrix``.

    Values are first clipped to the range [1e-10, 1.0] so that the logarithm
    is never taken of zero. The input is not renormalised.

    Args:
        prob_matrix: Array-like of probabilities; the last axis holds the
            distribution being summarised.

    Returns:
        Array with the last axis reduced, holding ``-sum(p * log2(p))``.
    """
    # Floor tiny/zero entries so that log2 stays finite.
    safe_probs = np.clip(prob_matrix, 1e-10, 1.0)
    log_probs = np.log2(safe_probs)
    return -np.sum(safe_probs * log_probs, axis=-1)

In [8]:
# Scratch check: log2(0.05) = log2(1/20), i.e. minus the entropy (in bits) of a
# uniform distribution over 20 outcomes.
np.log2(0.05)

-4.321928094887363

In [8]:
test_seq = 'ELKMDQALLLIHNELLWTNLTVYWKSECCYHCLFQVLVNVPQSPKAGKPSAAAASVSTQHGSILQLNDTLEEKEVCRLEYRFGEFGNYSLLVKNIHEIACDLAVNEDPVDSNLPVSIAFLIGLAVIIVISFLRLLLSLDDFNNWISKAIPPRLRSVDTFRGIALILMVFVNYGGGKYWYFKHASWNGLTVADLVFPWFVFIMGSSIFLSMTSILQRGCSKFRLLGKIAWRSFLLICIGIIIVNPNYCLGPLSWDKVRIPGVLQRLGVTYFVVAVLELLFAKPVPEHCASERSCLSLRDITSSWPQWLLILVLEGLWLGLTFLLPVPGCPTGYLGPGGIGDFGKYPNCTGGAAGYIDRLLLGDDHLYQHPSSAVLYHTEVAYDPEGILGTINSIVMAFLGVQAGKILLYYKARTKDILIRFTAWCCILGLISVALTKVSENEGFIPVNKNLWSLSYVTTLSSFAFFILLVLYPVVDVKGLWTGTPFFYPGMNSILVYVGHEVFENYFPFQWKLKDNQSHKEHLTQNIVATALWVLIAYILYRKKIFWKI'

In [9]:
len(test_seq)

548

In [10]:
# Run the model on the test sequence and keep the 20-amino-acid probability
# table as a NumPy array: one row per amino acid (in the AAorder defined inside
# get_logits) and one column per residue of test_seq.
output = get_logits(test_seq, model, alphabet, batch_converter, device, format='array')
output

array([[1.2250018e-02, 3.9223229e-05, 1.1716114e-02, ..., 4.7875488e-05,
        9.8384253e-04, 1.0958149e-06],
       [1.0675518e-02, 2.8816390e-05, 8.7031573e-01, ..., 3.8456769e-06,
        9.8915803e-01, 1.5217583e-06],
       [1.3588141e-02, 2.0262183e-05, 5.3340937e-03, ..., 8.1742326e-07,
        1.3227669e-05, 5.3689786e-07],
       ...,
       [4.3586232e-03, 5.9041125e-05, 2.3313689e-03, ..., 1.2329894e-05,
        2.7305873e-06, 1.3298820e-06],
       [2.7868692e-03, 9.1796427e-04, 1.3474926e-03, ..., 3.6789401e-04,
        1.8143755e-06, 3.5933281e-05],
       [6.7175296e-04, 4.3809534e-05, 6.1142072e-04, ..., 9.8527050e-01,
        1.1422233e-06, 8.0947632e-07]], dtype=float32)

In [11]:
output.shape

(20, 548)

In [12]:
# Transpose so that each row is one sequence position, then take the entropy
# over that position's 20 amino-acid probabilities.
# NOTE(review): these 20 values come from a softmax over the full token
# vocabulary, so a row may sum to slightly less than 1 -- confirm whether the
# rows should be renormalised before computing the entropy.
ent = calculate_entropy(output.T)
ent

array([1.99262369e+00, 5.56528941e-02, 1.03980267e+00, 1.35550451e+00,
       2.36687362e-02, 4.44860548e-01, 5.74263260e-02, 5.70762157e-01,
       3.81424874e-02, 1.59127212e+00, 3.93836558e-01, 1.85096884e+00,
       1.20384119e-01, 2.45319033e+00, 9.01113868e-01, 2.15786552e+00,
       2.98788428e+00, 3.06583762e+00, 5.17395735e-01, 2.11477801e-01,
       5.59312224e-01, 6.47761881e-01, 5.55434108e-01, 1.55031955e+00,
       4.00108516e-01, 1.81208730e-01, 1.07898355e+00, 3.87053847e+00,
       5.22520160e-03, 5.80141246e-01, 2.43029189e+00, 1.51541121e-02,
       1.63787174e+00, 3.96122158e-01, 5.67766368e-01, 3.36733651e+00,
       7.32173383e-01, 2.30593276e+00, 2.09550047e+00, 3.84388894e-01,
       6.31459415e-01, 2.80002308e+00, 1.71389949e+00, 2.01333857e+00,
       2.89870524e+00, 2.00612497e+00, 1.68606305e+00, 2.40589714e+00,
       2.31904578e+00, 1.89286888e+00, 1.90348792e+00, 9.91261601e-01,
       2.47729540e+00, 1.80909872e+00, 2.52624965e+00, 6.26598775e-01,
      

In [13]:
# Convert entropy into a per-position weight: the lower the entropy, the larger
# the weight. The denominator is the numerator evaluated at ent = 4.3, so a
# position with an entropy of 4.3 bits gets weight 1.
# NOTE(review): 4.3 is presumably log2(20) ~ 4.32 (the maximum entropy over 20
# amino acids) rounded -- confirm. The factor 10 is an unexplained scale constant.
weight = (1 + 10/(ent+1)) / (1 + 10/5.3)
weight

array([1.5039355, 3.6278365, 2.0446346, 1.8170254, 3.7303634, 2.7439046,
       3.622333 , 2.5517375, 3.6831846, 1.6832206, 2.8316696, 1.5614493,
       3.4382486, 1.3495507, 2.1685224, 1.4433653, 1.2150494, 1.198395 ,
       2.6292984, 3.2057664, 2.5679312, 2.4486825, 2.5734699, 1.7046868,
       2.8205366, 3.2790387, 2.0126293, 1.0576309, 3.7924514, 2.5386474,
       1.3562471, 3.7587466, 1.6596048, 2.827601 , 2.5559514, 1.1395781,
       2.346235 , 1.3942343, 1.4654658, 2.8486302, 2.4696896, 1.2579924,
       1.6228164, 1.4959781, 1.2349187, 1.4987367, 1.6360445, 1.3634801,
       1.3900945, 1.5438504, 1.5394711, 2.0860322, 1.3425968, 1.5795596,
       1.3287668, 2.4760344, 1.7040498, 3.412456 , 1.8142215, 2.095312 ,
       1.3088373, 1.3747687, 1.7385228, 3.2595224, 1.419225 , 2.3496897,
       1.1699681, 1.1782144, 1.5498102, 1.4121457, 1.4959573, 1.704027 ,
       1.3426752, 2.2748432, 2.1303089, 3.7897305, 2.3446133, 2.918709 ,
       1.618211 , 1.9592689, 2.2584598, 3.7112415, 

In [15]:
# Split the sequence into one residue letter per element so that it lines up
# with the per-position entropy and weight arrays.
seq_lst = list(test_seq)

In [16]:
# Assemble the per-position table: residue letter, entropy and weight side by side.
df1 = pd.DataFrame({'sequence':seq_lst, 'entropy':ent, 'weight':weight})
df1

Unnamed: 0,sequence,entropy,weight
0,E,1.992624,1.503935
1,L,0.055653,3.627836
2,K,1.039803,2.044635
3,M,1.355505,1.817025
4,D,0.023669,3.730363
...,...,...,...
543,I,0.430244,2.768406
544,F,0.013231,3.765222
545,W,0.118205,3.444275
546,K,0.103014,3.486939


In [18]:
# Write the table to CSV in the current working directory, without the index column.
# NOTE(review): the output file name is hard-coded; "8vkj_A" presumably
# identifies the protein/chain that test_seq comes from -- confirm.
df1.to_csv('weight_of_8vkj_A.csv', index=False)