In [1]:
from transformers import EsmForMaskedLM, EsmTokenizer, EsmConfig

In [22]:
# Single source of truth for the checkpoint so the tokenizer and the model
# are always loaded from the same pretrained weights.
MODEL_NAME = 'facebook/esm1v_t33_650M_UR90S_1'

tokenizer = EsmTokenizer.from_pretrained(MODEL_NAME)
model = EsmForMaskedLM.from_pretrained(MODEL_NAME)

Some weights of EsmForMaskedLM were not initialized from the model checkpoint at facebook/esm1v_t33_650M_UR90S_1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
# Directory holding the per-assay CSV files; each file name in the scoring
# loop is joined onto this path before being read with pandas.
data_path = '../data/ProteinNPT_data/fitness/substitutions_singles'

In [21]:
import pandas as pd
import os
import re
import torch
def get_pos(row):
    """Extract the mutated position(s) from a row's ``mutant`` entry.

    ``row['mutant']`` is split on ``':'`` and the first run of digits in each
    piece is parsed as an integer.

    :param row: mapping-like object (e.g. a DataFrame row) with a ``'mutant'`` string
    :return: a single int when there is one mutation, otherwise a list of ints
    """
    positions = [
        int(re.findall(r'\d+', code)[0])
        for code in row['mutant'].split(':')
    ]
    # Single mutants are returned as a bare int, multi-mutants as a list.
    return positions if len(positions) > 1 else positions[0]
def get_wt(seq, mut):
    """Rebuild the original sequence from a mutated sequence and its mutation string.

    ``mut`` has the form ``A2D`` or, for several mutations, ``A2D:B3C``. For
    each colon-separated code, the first character (the original residue) is
    written back at the 1-based position given by the first run of digits.

    :param seq: mutated sequence string
    :param mut: mutation string such as ``A2D`` or ``A2D:B3C``
    :return: the sequence with every listed position restored to its original character
    """
    residues = list(seq)
    for code in mut.split(':'):
        original_residue = code[0]
        site = int(re.findall(r'\d+', code)[0])
        # Positions in the mutation code are 1-based; Python indices are 0-based.
        residues[site - 1] = original_residue
    return ''.join(residues)

def compute_score(model, seq, mask, wt, pos, tokenizer):
    '''
    compute mutational proxy using masked marginal probability

    Every mutated position of each sequence is replaced by the tokenizer's
    mask token, the model is run once on the masked batch, and each sequence
    is scored as ``sum(log p(mutant token)) - sum(log p(wild-type token))``
    over its mutated positions.

    Inputs are moved to the device the model already lives on (CPU if the
    model has no parameters) instead of being forced onto CUDA, and the
    forward pass runs under ``torch.no_grad()``.

    :param model: masked language model, called as
        ``model(input_ids, attention_mask=...)``; the result must expose
        ``.logits``. The model is switched to eval mode as a side effect.
    :param seq: token ids of the mutant sequences (nested list or tensor,
        one row per sequence)
    :param mask: attention mask for input seq
    :param wt: token ids of the wild type sequences, same layout as ``seq``
    :param pos: per-sequence mutant position, either an int or a list of
        ints; used directly as an index into the token axis
    :param tokenizer: object exposing ``mask_token_id``
    :return:
        score: 1-D float tensor with one mutational proxy score per sequence
        logits: output logits for masked sequence
    '''
    # Follow the model's device rather than hard-coding 'cuda': a hard-coded
    # device fails on CPU-only machines and whenever the model was not moved.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # as_tensor avoids a needless copy (and a warning) when a tensor is passed in.
    seq = torch.as_tensor(seq, device=device)
    mask = torch.as_tensor(mask, device=device)
    wt = torch.as_tensor(wt, device=device)

    model.eval()
    m_id = tokenizer.mask_token_id
    batch_size = int(seq.shape[0])

    # Normalise every entry of `pos` to a 1-D index tensor so single mutants
    # (int) and multi-mutants (list of ints) share one code path.
    # NOTE(review): positions index tokens directly; this presumably relies on
    # the tokenizer prepending exactly one special token so that a 1-based
    # residue position equals its token index -- confirm for the tokenizer used.
    positions = [
        torch.as_tensor(pos[i], dtype=torch.long, device=device).reshape(-1)
        for i in range(batch_size)
    ]

    mask_seq = seq.clone()
    for i, idx in enumerate(positions):
        mask_seq[i, idx] = m_id

    # Inference only: no autograd graph is needed. Hidden states are not
    # requested because only the logits are used.
    with torch.no_grad():
        logits = model(mask_seq, attention_mask=mask).logits
        log_probs = torch.log_softmax(logits, dim=-1)

    scores = torch.zeros(batch_size, device=device)
    for i, idx in enumerate(positions):
        row_log_probs = log_probs[i]
        mutant_ll = row_log_probs[idx, seq[i, idx]].sum()
        wild_type_ll = row_log_probs[idx, wt[i, idx]].sum()
        scores[i] = mutant_ll - wild_type_ll

    return scores, logits

In [20]:
# To score every assay in the directory instead of the two listed below:
# file_list = [file for  file in  os.listdir(data_path) if file.endswith('.csv')]
file_list = ['SRC_HUMAN_Chakraborty_2023_binding-DAS_25uM.csv',
             'VKOR1_HUMAN_Chiasson_2020_abundance.csv']

# Token budget handed to the tokenizer for padding/truncation.
MAX_LENGTH = 1024

# Scores are collected per file; previously each iteration overwrote `scores`,
# so only the last file's result survived the loop.
scores_by_file = {}

for csv_file in file_list:
    df = pd.read_csv(os.path.join(data_path, csv_file))
    df['mut_pos'] = df.apply(get_pos, axis=1)

    # Recover the wild-type sequence from the first row's mutant and broadcast it.
    wt_sequence = get_wt(df['mutated_sequence'].iloc[0], df['mutant'].iloc[0])
    df['target_seq'] = wt_sequence
    df['PID'] = df.index

    # Keep only mutations whose position is below MAX_LENGTH - 1 (i.e. < 1023).
    # NOTE(review): presumably this keeps positions that survive truncation once
    # the tokenizer adds its leading/trailing special tokens -- confirm.
    df = df[df['mut_pos'] <= MAX_LENGTH - 2]

    # Read the encodings by key instead of unpacking `.values()`, which silently
    # depends on the order and number of fields the tokenizer returns.
    mutant_enc = tokenizer(list(df['mutated_sequence']),
                           padding='max_length',
                           truncation=True,
                           max_length=MAX_LENGTH)
    seq, mask = mutant_enc['input_ids'], mutant_enc['attention_mask']

    wt_enc = tokenizer(list(df['target_seq']),
                       padding='max_length',
                       truncation=True,
                       max_length=MAX_LENGTH)
    wt_seq, wt_mask = wt_enc['input_ids'], wt_enc['attention_mask']

    # NOTE(review): the whole file is scored as a single batch; large assays may
    # need to be chunked to fit in memory.
    scores, logits = compute_score(model, seq, mask, wt_seq, list(df['mut_pos']), tokenizer)
    scores_by_file[csv_file] = scores.detach().cpu()


['SRC_HUMAN_Chakraborty_2023_binding-DAS_25uM.csv',
 'VKOR1_HUMAN_Chiasson_2020_abundance.csv',
 'RASK_HUMAN_Weng_2022_binding-DARPin_K55.csv',
 'CAS9_STRP1_Spencer_2017_positive.csv',
 'Q53Z42_HUMAN_McShan_2019_binding-TAPBPR.csv',
 'CASP3_HUMAN_Roychowdhury_2020.csv',
 'PSAE_SYNP2_Tsuboyama_2023_1PSE.csv',
 'OTU7A_HUMAN_Tsuboyama_2023_2L2D.csv',
 'TPOR_HUMAN_Bridgford_2020.csv',
 'FKBP3_HUMAN_Tsuboyama_2023_2KFV.csv',
 'GFP_AEQVI_Sarkisyan_2016.csv',
 'A4GRB6_PSEAI_Chen_2020.csv',
 'SPIKE_SARS2_Starr_2020_binding.csv',
 'RPC1_LAMBD_Li_2019_low-expression.csv',
 'UBR5_HUMAN_Tsuboyama_2023_1I2T.csv',
 'YNZC_BACSU_Tsuboyama_2023_2JVD.csv',
 'DYR_ECOLI_Thompson_2019.csv',
 'A4D664_9INFA_Soh_2019.csv',
 'NPC1_HUMAN_Erwood_2022_RPE1.csv',
 'RPC1_BP434_Tsuboyama_2023_1R69.csv',
 'HXK4_HUMAN_Gersing_2023_abundance.csv',
 'CP2C9_HUMAN_Amorosi_2021_abundance.csv',
 'HEM3_HUMAN_Loggerenberg_2023.csv',
 'BCHB_CHLTE_Tsuboyama_2023_2KRU.csv',
 'RASH_HUMAN_Bandaru_2017.csv',
 'HSP82_YEAST_Flynn_201

In [17]:
import math

def kl_divergence(p_dist, q_dist, eps=1e-10):
    """Return the Kullback-Leibler divergence KL(p || q) of two discrete distributions.

    Both distributions are mappings from symbol to probability. The sum runs
    over the symbols of ``p_dist`` and uses the natural logarithm.

    :param p_dist: mapping ``symbol -> probability`` for the reference distribution
    :param q_dist: mapping ``symbol -> probability`` for the comparison distribution
    :param eps: floor used for ``q`` when a symbol is missing from ``q_dist``
        or its probability there is not positive (avoids division by zero)
    :return: the divergence as a float
    :raises ValueError: if a probability in ``p_dist`` is negative
        (``math.log`` of a negative ratio)
    """
    kl_div = 0.0
    for symbol, p_prob in p_dist.items():
        # By convention 0 * log(0 / q) contributes 0; evaluating it directly
        # would raise a math domain error.
        if p_prob == 0:
            continue
        q_prob = q_dist.get(symbol, eps)  # avoid zero probability for unseen symbols
        if q_prob <= 0:
            # An explicit zero in q_dist gets the same floor as a missing symbol.
            q_prob = eps
        kl_div += p_prob * math.log(p_prob / q_prob)
    return kl_div

# Evaluate the divergence in both directions (KL divergence is not symmetric).
# NOTE(review): kl_divergence iterates `p_dist.items()` and calls `q_dist.get`,
# so both arguments must be mappings of symbol -> probability. The output
# recorded for this cell is "AttributeError: 'list' object has no attribute
# 'items'", i.e. a plain list was passed. Presumably data0/data1 are raw
# samples that need to be turned into distributions first -- confirm.
kl_div1_to_2 = kl_divergence(data0, data1)
kl_div2_to_1 = kl_divergence(data1, data0)

kl_div1_to_2, kl_div2_to_1

AttributeError: 'list' object has no attribute 'items'

In [19]:
# Inspect data1, one of the two inputs passed to kl_divergence.
data1

array([-1.41241189, -1.48117833,  1.65891102, -0.60048029,  0.00744575,
        1.64104913, -0.17643328,  1.0361886 , -1.46832801, -0.55502682,
       -0.37343161, -1.17080446, -1.04137638,  0.37712358, -0.7282591 ,
        0.06380822,  0.37070707,  0.27831338,  0.39584432,  0.29614668,
        0.79065623,  0.29672463,  1.19139607,  0.33636635,  0.38340338,
       -0.52261109,  0.16975789, -0.69095027, -0.4883984 ,  0.74652992,
       -0.91683084, -0.437161  ,  2.10188804,  1.30295225,  0.10895416,
        0.45308116,  0.57935278, -0.58376389,  0.80617219,  0.05931934,
        3.20935726, -1.14372039, -0.53096137, -0.17874124,  0.17777646,
       -0.90496896,  0.01482596,  0.70671803,  0.24761787,  0.55001282,
        0.14662923,  0.79083877, -0.96689153, -1.57286868,  1.00104892,
        0.37765632,  0.56225167,  1.21401816,  0.50429521,  0.47379543,
        0.17178551, -0.02210353, -1.44456033, -0.69126186,  0.76961219,
       -1.37834517,  0.40260747, -0.44555963,  0.99388245,  2.09