In [None]:
import os
import pandas as pd
import numpy as np
import h5py
import torch
from transformers import AutoTokenizer
from scipy.stats import spearmanr
from sklearn.metrics import roc_auc_score

import sys
sys.path.insert(0, "../model")
from config import config # please first download the dataset and fill in the config.py file with the path where you downloaded the dataset
from model import ESMwrap

### load model and dataset

In [2]:
# Set device
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Load model
esm2_select = 'model_35M'
model_select = 'seqdance' # or 'esmdance'
dance_model = ESMwrap(esm2_select, model_select).to(device)

# Load the SeqDance model from huggingface
dance_model = dance_model.from_pretrained("ChaoHou/ESMDance")
dance_model = dance_model.to(device)
dance_model.eval()

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ESMwrap(
  (esm2): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 480, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 480, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-11): 12 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=480, out_features=480, bias=True)
              (key): Linear(in_features=480, out_features=480, bias=True)
              (value): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=480, out_features=480, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((480,), eps=1e-05, elementwise_affin

In [None]:
# Load dataset, use test set
df = pd.read_csv(config['file_path']['train_df_path'])
df = df[df['label'] == 'test']

# Load HDF5 dataset
h5py_read = h5py.File(config['file_path']['h5py_path'], 'r')
max_len = 1024

In [3]:
# Load dataset
df = pd.read_csv("/nfs/user/Users/ch3849/ProDance/data_new/seq/name_seq_all_mdCATH_ATLAS_GPCRmd_PED_IDRome_proteinflow.csv")
df = df[df['label'] == 'test']

# Load HDF5 dataset
h5py_read = h5py.File("/ssd/Users/ch3849/prodance/feature_all_mdCATH_ATLAS_GPCRmd_PED_IDRome_proteinflow.h5", 'r')
max_len = 1024

### get pariwise dynamic properties with 240 SeqDance attention maps
using 1a0aA00 from mdCATH as an example

In [12]:
# get the sequence and pair feature of the protein
pro = '1a0aA00'
seq = df[df['name'] == pro]['modify_seq'].values[0]
pair_f = h5py_read[f'{pro}_pair_feature'][:]

In [14]:
# tokenize the sequence, for longer sequences, we need to truncate them to the max length
raw_input = tokenizer(seq, return_tensors="pt", max_length=max_len, truncation=True)
length = raw_input['input_ids'].shape[1]
pair_f = pair_f[:length, :length]

In [16]:
# get the attention map
with torch.no_grad():
    output = dance_model(raw_input.to(device), return_attention_map=True)
atten = output['attention_map'][0].cpu().numpy()

In [19]:
print(length, atten.shape, pair_f.shape)

65 (65, 65, 240) (65, 65, 10)


### compare pariwise dynamic properties with 240 SeqDance attention maps

In [20]:
# only analyze residues with distance > 2
row_indices, col_indices = np.where(np.abs(np.arange(length)[:, None] - np.arange(length)) > 2)

In [26]:
co_move = pair_f[:,:,9]
inter = (pair_f[:,:,:9]**3).sum(-1) # the interaction frequency is pow(x, 1/3) in the file, as we use this value to train the model

In [30]:
atten_scores = {'co_move_topL_ratio': [], 'co_move_posi_spearman': [], 'co_move_neg_spearman': [], 'inter_topL_ratio': [], 'inter_auroc': []}

#### for movement corelation

In [None]:
f = co_move[row_indices, col_indices]
mask = f != -1 # -1 is padding
f_flat = f[mask]

for k in range(atten.shape[-1]):
    att = atten[:, :, k][row_indices, col_indices]
    att_flat = att[mask]

    # Compute fold change
    top_L_indices = np.argsort(-att_flat)[:length]  # Negative sign for descending sort
    top_L_mean = np.mean(np.abs(f_flat[top_L_indices]))
    other_indices = np.setdiff1d(np.arange(f_flat.shape[0]), top_L_indices)
    other_mean = np.mean(np.abs(f_flat[other_indices]))

    atten_scores['co_move_topL_ratio'].append(top_L_mean / (other_mean + 1e-8))

    # Compute Spearman correlations
    atten_scores['co_move_posi_spearman'].append(spearmanr(att_flat[f_flat > 0], f_flat[f_flat > 0])[0])
    atten_scores['co_move_neg_spearman'].append(spearmanr(att_flat[f_flat < 0], f_flat[f_flat < 0])[0])

#### for summed interaction frequency

In [44]:
f = inter[row_indices, col_indices]
# for fold change
mask = f != -9 # -9 is summed padding of 9 interactions
f_flat = f[mask]

# for AUROC
mask_auroc = mask = (f == 0) | (f >= 0.01) # for binary classification, positive are interaction frequency > 0.01
f_auroc = f[mask_auroc]
f_binary = f_auroc > 0

for k in range(atten.shape[-1]):
    att = atten[:, :, k][row_indices, col_indices]
    att_flat = att[mask]

    # Compute fold change
    top_L_indices = np.argsort(-att_flat)[:length]  # Negative sign for descending sort
    top_L_mean = np.mean(np.abs(f_flat[top_L_indices]))
    other_indices = np.setdiff1d(np.arange(f_flat.shape[0]), top_L_indices)
    other_mean = np.mean(np.abs(f_flat[other_indices]))

    atten_scores['inter_topL_ratio'].append(top_L_mean / (other_mean + 1e-8))

    # Compute AUROC
    att_auroc = att[mask_auroc]
    atten_scores['inter_auroc'].append(roc_auc_score(f_binary, att_auroc))

In [47]:
df_atten_scores = pd.DataFrame(atten_scores)
df_atten_scores.describe()

Unnamed: 0,co_move_topL_ratio,co_move_posi_spearman,co_move_neg_spearman,inter_topL_ratio,inter_auroc
count,240.0,240.0,240.0,240.0,240.0
mean,1.44488,0.231459,-0.009377,1.091253,0.513889
std,0.595649,0.217605,0.077245,0.653905,0.042282
min,0.59757,-0.279531,-0.212586,0.07681,0.340415
25%,1.023115,0.057276,-0.053536,0.647623,0.489946
50%,1.20096,0.213669,-0.012227,0.941406,0.51642
75%,1.70629,0.414706,0.042398,1.422613,0.536317
max,3.10303,0.715159,0.240129,3.477425,0.645985
