In [7]:
import esm
import pandas as pd
import numpy as np
import os, re
import pickle
from sklearn.metrics import roc_auc_score
import torch
import sys

file_path = "../model"
sys.path.append(file_path)
from dictionary import AutoEncoder

### load models

In [None]:
# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pretrained ESM-2 (esm2_t33_650M_UR50D) in eval mode; `alphabet` provides the batch
# converter used to tokenize (id, sequence) pairs.
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model.eval()
esm_model = esm_model.to(device)
batch_converter = alphabet.get_batch_converter()

chk_path = '/path/to/MotifAE_step_80000.pt' # please download this file from zenodo
# Fail early with an actionable message if the placeholder path was not edited,
# instead of surfacing an error from inside the checkpoint loader.
if not os.path.isfile(chk_path):
    raise FileNotFoundError(
        f"MotifAE checkpoint not found at {chk_path!r}: "
        "download MotifAE_step_80000.pt from Zenodo and set chk_path to its location."
    )
motifae = AutoEncoder.from_pretrained(chk_path)
motifae.eval()
motifae = motifae.to(device)

  state_dict = t.load(path)


### Get ESM embeddings for sequences with ELM motifs

In [10]:
# Table of ELM motif instances, one row per instance.
# - For proteins longer than 1022 residues, a region of length 1022 around the ELM motif is
#   extracted; it is recorded in the start, end and Sequence columns.
# - Label column: '1' for motif-instance positions, '0' for other regions. It is consumed
#   downstream as a string indexed residue by residue.
# - iupred columns: disorder scores predicted by IUPred.
# Label is read explicitly as text: left to type inference, an all-digit column can be parsed
# as a number, which drops leading zeros and breaks the per-residue indexing.
seq_label = pd.read_csv(
    '../data/elm_instances_seq_label_iupred_max1022.csv',
    index_col=0,
    dtype={'Label': str},
)

In [None]:
# Embed every row's Sequence with ESM-2 and store the layer-33 per-residue representation
# in one compressed .npz file per row, named after the row's `id` (array key: 'repr').
embed_dir = "../data/embed"
# np.savez_compressed does not create missing directories, so make sure the target exists.
os.makedirs(embed_dir, exist_ok=True)

for i in seq_label.index:
    seq_id = seq_label.loc[i, 'id']
    data = [(seq_id, seq_label.loc[i, 'Sequence'])]
    batch_labels, batch_strs, batch_input = batch_converter(data)
    batch_input = batch_input.to(device)

    # Inference only: no gradients are needed for embedding extraction.
    with torch.no_grad():
        outputs = esm_model(batch_input, repr_layers=[33])

    # Named residue_repr rather than repr so the builtin repr() is not shadowed here.
    residue_repr = outputs['representations'][33].cpu()

    save_path = f"{embed_dir}/{seq_id}.npz"

    # [0, 1:-1]: the single batch entry without its first and last token positions
    # (presumably the special start/end tokens added by the batch converter — TODO confirm).
    np.savez_compressed(save_path, repr=residue_repr[0, 1:-1].numpy())

### compare MotifAE features with ELM motifs

In [None]:
# For every ELM class, compare MotifAE feature activations at motif residues (label '1')
# with those at the remaining residues (label '0'). Each feature whose mean activation is
# higher on motif residues is recorded as a tuple:
# (ELM identifier, feature index, mean on motif residues, mean on other residues, AUROC).
elm_motifae_info = []
for elm, group in seq_label.groupby('ELMIdentifier'):
    positive_chunks = []
    negative_chunks = []
    for seq_id, label in zip(group['id'], group['Label']):
        # Close the .npz archive as soon as the array has been read.
        with np.load(f'../data/embed/{seq_id}.npz') as archive:
            embedding = archive['repr']
        with torch.no_grad():
            _, features = motifae(torch.tensor(embedding).to(device), output_features=True)
        features = features.cpu().numpy()

        n_labelled = len(label)
        if n_labelled > features.shape[0]:
            raise IndexError(
                f"Label for {seq_id!r} has {n_labelled} positions but only "
                f"{features.shape[0]} residue embeddings are available"
            )

        # One boolean mask per class instead of a per-residue Python loop. Characters other
        # than '1' and '0' fall into neither class; feature rows beyond the label length are
        # ignored.
        residue_labels = np.array(list(label), dtype=str)
        labelled_features = features[:n_labelled]
        positive_chunks.append(labelled_features[residue_labels == '1'])
        negative_chunks.append(labelled_features[residue_labels == '0'])

    positive = np.concatenate(positive_chunks)
    negative = np.concatenate(negative_chunks)

    # Without residues of both classes neither the means nor the AUROC are defined;
    # skip the class explicitly instead of propagating NaN means.
    if len(positive) == 0 or len(negative) == 0:
        continue

    positive_mean = positive.mean(axis=0)
    negative_mean = negative.mean(axis=0)

    positive_negative_label = np.concatenate([np.ones(len(positive)), np.zeros(len(negative))])

    # Only features that are on average more active on motif residues are scored.
    for f_index in np.where(positive_mean > negative_mean)[0]:
        f_value_index = np.concatenate([positive[:, f_index], negative[:, f_index]])

        auroc = roc_auc_score(positive_negative_label, f_value_index)

        elm_motifae_info.append((elm, f_index, positive_mean[f_index], negative_mean[f_index], auroc))