In [1]:
import os
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer

import sys
sys.path.insert(0, "../model")
from model import ESMwrap

### load the model

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
esm2_select = 'model_35M'
model_select = 'esmdance' # 'seqdance' or 'esmdance'
dance_model = ESMwrap(esm2_select, model_select)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


#### Method 1: load from Hugging face

In [5]:
dance_model = dance_model.from_pretrained("ChaoHou/ESMDance")

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


#### Method 2: load from checkpoint in zenodo: https://zenodo.org/records/15047777

In [None]:
checkpoint = torch.load('update_***.pt')
dance_model.load_state_dict(checkpoint, strict=False)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  checkpoint = torch.load(f'/nfs/user/Users/ch3849/ProDance/model/{model_date}/checkpoints/update_{chk}.pt')


ESMwrap(
  (esm2): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 480, padding_idx=1)
      (dropout): Dropout(p=0, inplace=False)
      (position_embeddings): Embedding(1026, 480, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-11): 12 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=480, out_features=480, bias=True)
              (key): Linear(in_features=480, out_features=480, bias=True)
              (value): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0, inplace=False)
            )
            (LayerNorm): LayerNorm((480,), eps=1e-05, elementwise_affine=True

In [6]:
dance_model = dance_model.to(device)
dance_model.eval()

ESMwrap(
  (esm2): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 480, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 480, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-11): 12 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=480, out_features=480, bias=True)
              (key): Linear(in_features=480, out_features=480, bias=True)
              (value): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((480,), eps=1e-05, elementwise_affin

### prepare dataset
please download Tsuboyama2023_Dataset2_Dataset3_20230416.csv from https://zenodo.org/records/7992926

In [None]:
df = pd.read_csv('Processed_K50_dG_datasets/Tsuboyama2023_Dataset2_Dataset3_20230416.csv')

# data filtering: 1. have ddG value, 2. not insertion or deletion as the sequence lengths are different
df = df[(df['ddG_ML']!='-') & (~df['mut_type'].str.contains('ins|del'))]
df['ddG_ML'] = df['ddG_ML'].astype(float)

# get the mean ddG value for each aa_seq (some aa_seq have multiple experiments)
ddg_mean = df[['aa_seq','ddG_ML']].groupby('aa_seq').mean()
df = pd.merge(df.sort_values('mut_type', ascending=False).drop_duplicates('aa_seq')[['aa_seq','mut_type','WT_name','WT_cluster']], ddg_mean, on='aa_seq', how='left')

df.index = df['WT_name'] + '$' + df['mut_type']

  df = pd.read_csv('Processed_K50_dG_datasets/Tsuboyama2023_Dataset2_Dataset3_20230416.csv')


In [None]:
df['WT_name'].value_counts()

WT_name
3DKM.pdb              5587
2MXD.pdb              5131
2KGT.pdb              4730
5VNT.pdb              4480
2LGW.pdb              4327
                      ... 
EEHEE_rd4_0308.pdb     599
EHEE_rd4_0864.pdb      571
6YSE.pdb               565
1WR4.pdb               501
2M9E.pdb               489
Name: count, Length: 412, dtype: int64

### get the zero-shot scores, using r10_572_TrROS_Hall as an example

In [7]:
pro = 'r10_572_TrROS_Hall.pdb'
df_pro = df[df['WT_name'] == pro]

In [8]:
# get the prediction for the wildtype sequence
wt_seq = df.loc[f'{pro}$wt','aa_seq']
wt_input = tokenizer(wt_seq, return_tensors="pt").to(device)
with torch.no_grad():
    wt_output = dance_model(wt_input)

In [9]:
# all mutations
pro_muts = df[df['WT_name']==pro]['mut_type']

# the epsilon is used to avoid division by zero
epsilon = 1e-2

for mt in pro_muts:
    # get the prediction for the mutant sequence
    mt_seq = df.loc[f'{pro}${mt}','aa_seq']
    mt_input = tokenizer(mt_seq, return_tensors="pt").to(device)
    with torch.no_grad():
        mt_output = dance_model(mt_input)

    # calculate the relative difference between the mutant and wildtype predictions
    for feature in wt_output:
        df_pro.loc[f'{pro}${mt}',feature] = (abs(mt_output[feature] - wt_output[feature]) / (abs(wt_output[feature]) + epsilon)).mean().item()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_pro.loc[f'{pro}${mt}',feature] = (abs(mt_output[feature] - wt_output[feature]) / (abs(wt_output[feature]) + epsilon)).mean().item()
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_pro.loc[f'{pro}${mt}',feature] = (abs(mt_output[feature] - wt_output[feature]) / (abs(wt_output[feature]) + epsilon)).mean().item()
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/

### combine zero-shot score of different dynamic properties

In [10]:
def quantile_normalization(features):
    """
    Applies quantile normalization to a set of feature vectors (each column represents a protein).
    Ensures all features have the same distribution across proteins.
    """
    features = np.array(features)
    
    # Rank transformation
    ranks = np.argsort(np.argsort(features, axis=0), axis=0)
    
    # Compute mean per rank across all features
    sorted_features = np.sort(features, axis=0)
    rank_means = np.mean(sorted_features, axis=1)
    
    # Map original values to rank means
    normalized_matrix = np.zeros_like(features, dtype=np.float64)
    for i in range(features.shape[1]):
        normalized_matrix[:,i] = rank_means[ranks[:,i]]
    
    return normalized_matrix

def get_normalized_values(df, features):
    """
    Normalizes the values of the features to have the same distribution across proteins.
    """
    normalized_value = quantile_normalization(df[features].values)
    normalized_df = pd.DataFrame(normalized_value, columns=[i + '_norm' for i in features])
    normalized_df.index = df.index
    df = pd.merge(df, normalized_df, left_index=True, right_index=True)

    return df

In [11]:
features = df_pro.columns[5:]
df_pro = get_normalized_values(df_pro, features)

# Convert to NumPy array
raw_features = df_pro[features].values
normalized_features = df_pro[[i + '_norm' for i in features]].values

# Compute row-wise weights (value / sum of row)
raw_weights = (raw_features + 1e-8) / (raw_features + 1e-8).sum(axis=1, keepdims=True)
norm_weights = (normalized_features + 1e-8) / (normalized_features + 1e-8).sum(axis=1, keepdims=True)

# Compute feature combinations
df_pro['raw_mean'] = raw_features.mean(axis=1)
df_pro['raw_max'] = raw_features.max(axis=1)
df_pro['raw_weighted_mean'] = np.sum(raw_features * raw_weights, axis=1)
df_pro['raw_geometric_mean'] = np.exp(np.mean(np.log(raw_features + 1e-8), axis=1))

df_pro['norm_mean'] = normalized_features.mean(axis=1)
df_pro['norm_max'] = normalized_features.max(axis=1)
df_pro['norm_weighted_mean'] = np.sum(normalized_features * norm_weights, axis=1)
df_pro['norm_geometric_mean'] = np.exp(np.mean(np.log(normalized_features + 1e-8), axis=1))

### spearman correlation between ddG_ML and relative changes
the correlations are negative, you can convert them to positive  
the combined score is not the best in our experiment, users can just use sasa_mean and psi, both show robust performance in our experiments

In [12]:
df_pro.iloc[:, 4:].corr(method='spearman')

Unnamed: 0,ddG_ML,sasa_mean,sasa_std,rmsf_nor,ss,chi,phi,psi,nma_res1,nma_res2,...,nma_pair2_norm,nma_pair3_norm,raw_mean,raw_max,raw_weighted_mean,raw_geometric_mean,norm_mean,norm_max,norm_weighted_mean,norm_geometric_mean
ddG_ML,1.0,-0.545935,-0.527522,-0.334406,-0.561495,-0.479828,-0.626823,-0.630722,-0.550547,-0.507444,...,-0.4814,-0.499153,-0.512962,-0.435832,-0.434384,-0.430254,-0.503443,-0.240062,-0.427017,-0.519448
sasa_mean,-0.545935,1.0,0.898843,0.688297,0.807039,0.677418,0.807609,0.826116,0.80085,0.821362,...,0.843753,0.845907,0.848826,0.717359,0.740537,0.791976,0.850895,0.457631,0.723522,0.874523
sasa_std,-0.527522,0.898843,1.0,0.669837,0.726851,0.651659,0.783677,0.775964,0.750917,0.775979,...,0.796804,0.802649,0.807158,0.690883,0.700391,0.758926,0.818001,0.466424,0.70824,0.83594
rmsf_nor,-0.334406,0.688297,0.669837,1.0,0.724649,0.444559,0.656517,0.691617,0.71842,0.755537,...,0.794997,0.816126,0.749766,0.629223,0.668052,0.776262,0.793148,0.43342,0.686315,0.809376
ss,-0.561495,0.807039,0.726851,0.724649,1.0,0.6386,0.885085,0.937814,0.82671,0.813902,...,0.860983,0.8515,0.843233,0.685282,0.727184,0.78566,0.842149,0.424117,0.700416,0.872508
chi,-0.479828,0.677418,0.651659,0.444559,0.6386,1.0,0.709802,0.680486,0.634752,0.656006,...,0.646876,0.646122,0.682761,0.565378,0.600427,0.559693,0.64517,0.388598,0.546304,0.665054
phi,-0.626823,0.807609,0.783677,0.656517,0.885085,0.709802,1.0,0.96212,0.81773,0.797757,...,0.83825,0.846012,0.819146,0.673859,0.704537,0.760329,0.824042,0.399748,0.677062,0.856737
psi,-0.630722,0.826116,0.775964,0.691617,0.937814,0.680486,0.96212,1.0,0.83433,0.82211,...,0.854651,0.855621,0.83943,0.698703,0.727635,0.774305,0.840062,0.402576,0.691631,0.872538
nma_res1,-0.550547,0.80085,0.750917,0.71842,0.82671,0.634752,0.81773,0.83433,1.0,0.877616,...,0.872898,0.871225,0.841238,0.711106,0.750707,0.789041,0.851284,0.425154,0.714146,0.877838
nma_res2,-0.507444,0.821362,0.775979,0.755537,0.813902,0.656006,0.797757,0.82211,0.877616,1.0,...,0.903621,0.903043,0.865478,0.74885,0.786635,0.809976,0.870196,0.467669,0.737261,0.896241


### Zero-shot using ESM2

In [13]:
import esm
# Load the ESM-2 model
esm_35, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
esm_35 = esm_35.to(device)
esm_35.eval()
# Create a batch converter
batch_converter = alphabet.get_batch_converter()

In [14]:
# use wildtype sequence to get the logits, then get LLR, this perform similar to masked LLR
def get_wt_logits(esm_model, batch_converter, seq):
    batch_labels, batch_strs, batch_tokens = batch_converter([('name', seq)])
    with torch.no_grad():
        results = esm_model(batch_tokens.to(device))
    logits = torch.log_softmax(results["logits"],dim=-1)[0,:,:].cpu().numpy()[1:-1,:]
    return logits

# mask each residue, get the logits of masked residue, then get LLR
def get_mask_llr(esm_model, batch_converter, alphabet, seq):
    batch_labels, batch_strs, batch_tokens = batch_converter([('name', seq)])
    mask_logits = []
    for i in range(len(seq)):
        batch_tokens_masked = batch_tokens.clone()
        batch_tokens_masked[0, i+1] = alphabet.mask_idx
        with torch.no_grad():
            results = esm_model(batch_tokens_masked.to(device))
        mask_logits.append(results["logits"][:,i+1,:])
    
    mask_logits = torch.cat(mask_logits, dim=0)
    mask_logits = torch.log_softmax(mask_logits,dim=-1).cpu().numpy()

    return mask_logits

def logits_to_llr(logits, alphabet, seq):
    logits = pd.DataFrame(logits,columns=alphabet.all_toks,index=list(seq)).T
    wt_norm=np.diag(logits.loc[logits.columns])
    llr = logits-wt_norm
    
    llr.columns = [seq[i] + str(i+1) for i in range(len(seq))]
    
    llr = pd.DataFrame(llr.iloc[4:24].T.stack(), columns=['LLR']).reset_index()
    llr['mutant'] = llr['level_0'].str.replace('_','') + llr['level_1']
    return llr[['mutant','LLR']].set_index('mutant')

In [15]:
# get the zero-shot LLR using both wildtype and masked logits
wt_logits = get_wt_logits(esm_35, batch_converter, wt_seq)
wt_llr = logits_to_llr(wt_logits, alphabet, wt_seq)

mask_logits = get_mask_llr(esm_35, batch_converter, alphabet, wt_seq)
mask_llr = logits_to_llr(mask_logits, alphabet, wt_seq)

In [16]:
for mt in pro_muts:
    if mt != 'wt':
        # merge the LLR values if there are multiple mutations
        df_pro.loc[f'{pro}${mt}', 'wt_llr'] = wt_llr.loc[mt.split(':'), 'LLR'].sum()
        df_pro.loc[f'{pro}${mt}', 'mask_llr'] = mask_llr.loc[mt.split(':'), 'LLR'].sum()
    else:
        df_pro.loc[f'{pro}${mt}', 'wt_llr'] = 0
        df_pro.loc[f'{pro}${mt}', 'mask_llr'] = 0

In [17]:
df_pro.iloc[:,4:].corr(method='spearman')

Unnamed: 0,ddG_ML,sasa_mean,sasa_std,rmsf_nor,ss,chi,phi,psi,nma_res1,nma_res2,...,raw_mean,raw_max,raw_weighted_mean,raw_geometric_mean,norm_mean,norm_max,norm_weighted_mean,norm_geometric_mean,wt_llr,mask_llr
ddG_ML,1.0,-0.545935,-0.527522,-0.334406,-0.561495,-0.479828,-0.626823,-0.630722,-0.550547,-0.507444,...,-0.512962,-0.435832,-0.434384,-0.430254,-0.503443,-0.240062,-0.427017,-0.519448,0.165468,0.182909
sasa_mean,-0.545935,1.0,0.898843,0.688297,0.807039,0.677418,0.807609,0.826116,0.80085,0.821362,...,0.848826,0.717359,0.740537,0.791976,0.850895,0.457631,0.723522,0.874523,-0.305231,-0.315631
sasa_std,-0.527522,0.898843,1.0,0.669837,0.726851,0.651659,0.783677,0.775964,0.750917,0.775979,...,0.807158,0.690883,0.700391,0.758926,0.818001,0.466424,0.70824,0.83594,-0.318644,-0.323774
rmsf_nor,-0.334406,0.688297,0.669837,1.0,0.724649,0.444559,0.656517,0.691617,0.71842,0.755537,...,0.749766,0.629223,0.668052,0.776262,0.793148,0.43342,0.686315,0.809376,-0.416891,-0.416156
ss,-0.561495,0.807039,0.726851,0.724649,1.0,0.6386,0.885085,0.937814,0.82671,0.813902,...,0.843233,0.685282,0.727184,0.78566,0.842149,0.424117,0.700416,0.872508,-0.319991,-0.333279
chi,-0.479828,0.677418,0.651659,0.444559,0.6386,1.0,0.709802,0.680486,0.634752,0.656006,...,0.682761,0.565378,0.600427,0.559693,0.64517,0.388598,0.546304,0.665054,-0.228657,-0.242369
phi,-0.626823,0.807609,0.783677,0.656517,0.885085,0.709802,1.0,0.96212,0.81773,0.797757,...,0.819146,0.673859,0.704537,0.760329,0.824042,0.399748,0.677062,0.856737,-0.351082,-0.368664
psi,-0.630722,0.826116,0.775964,0.691617,0.937814,0.680486,0.96212,1.0,0.83433,0.82211,...,0.83943,0.698703,0.727635,0.774305,0.840062,0.402576,0.691631,0.872538,-0.342307,-0.35747
nma_res1,-0.550547,0.80085,0.750917,0.71842,0.82671,0.634752,0.81773,0.83433,1.0,0.877616,...,0.841238,0.711106,0.750707,0.789041,0.851284,0.425154,0.714146,0.877838,-0.476278,-0.490526
nma_res2,-0.507444,0.821362,0.775979,0.755537,0.813902,0.656006,0.797757,0.82211,0.877616,1.0,...,0.865478,0.74885,0.786635,0.809976,0.870196,0.467669,0.737261,0.896241,-0.420231,-0.425638
