In [1]:
import esm
import sys, os
import pandas as pd
import numpy as np
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.stats import spearmanr

file_path = "../model"
sys.path.append(file_path)
from dictionary import AutoEncoder

### Load models and data

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ESM-2 (esm2_t33_650M_UR50D): frozen; its lm_head scores the MotifAE reconstructions later on.
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model.eval()
esm_model = esm_model.to(device)
batch_converter = alphabet.get_batch_converter()

# MotifAE checkpoint: download MotifAE_step_80000.pt from zenodo (https://zenodo.org/records/17488191),
# then either set the MOTIFAE_CKPT environment variable or edit the default path below.
chk_path = os.environ.get('MOTIFAE_CKPT', '/path/to/MotifAE_step_80000.pt')
if not os.path.isfile(chk_path):
    # Fail early with an actionable message instead of an opaque error from inside from_pretrained.
    raise FileNotFoundError(
        f"MotifAE checkpoint not found at {chk_path!r}; download it from "
        "https://zenodo.org/records/17488191 and set MOTIFAE_CKPT to its path"
    )
motifae = AutoEncoder.from_pretrained(chk_path)
motifae.eval()
motifae = motifae.to(device)

  state_dict = t.load(path)


In [3]:
# Precomputed embeddings and mutation labels. get_input_label later looks them up as
# embed[<WT_name>] and label[f"{<WT_name>}_{split}"].
# NOTE(review): pickle.load can execute arbitrary code — only load files you trust.
with open('../data/embedding_412pro.pkl', 'rb') as fh:
    embed = {key: value.to(device) for key, value in pickle.load(fh).items()}
with open('../data/mutation_label.pkl', 'rb') as fh:
    label = {key: value.to(device) for key, value in pickle.load(fh).items()}

# Protein table; the 'split' column assigns each protein to train or test.
pro = pd.read_csv('../data/412pro_info.csv')
pro_train = pro[pro['split'] == 'train'].reset_index(drop=True)
pro_test = pro[pro['split'] == 'test'].reset_index(drop=True)
# The training loss is averaged over len(pro_train) and test metrics over len(pro_test).
assert len(pro_train) > 0 and len(pro_test) > 0, "expected at least one train and one test protein"

### Gate model: learn a sparse binary mask over MotifAE latents

In [4]:
lambda_sparse = 1   # weight on the mean gate probability (sparsity penalty)
n_epoch = 4         # in our manuscript, we trained for 60 epochs
lr = 1e-3           # Adam learning rate for the gate logits

In [None]:
class BinaryGate(nn.Module):
    """Learnable hard 0/1 mask over ``embed_dim`` features, trained with a straight-through estimator.

    Each feature owns one logit. The forward value of a gate is 1 when sigmoid(logit) > 0.5,
    else 0, while gradients flow as if the gate were the sigmoid probability itself.
    Logits are initialised to 0, so every probability starts at exactly 0.5 and, because the
    threshold is strict, every gate starts closed.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(embed_dim))

    def forward(self):
        """Return ``(gates, probs)``.

        gates: forward value is the hard 0/1 mask; its gradient w.r.t. the logits is that of probs.
        probs: sigmoid(logits), e.g. for a sparsity penalty.
        """
        probs = self.logits.sigmoid()
        hard_mask = probs.gt(0.5).float()
        # Straight-through trick: the detached copy cancels probs in the value but not in the gradient.
        gates = hard_mask + probs - probs.detach()
        return gates, probs

In [6]:
# ---------------------
# Differentiable "soft rank" -> used to compute a differentiable Spearman proxy.
# soft_rank[a] = 1 + sum_b sigmoid((x[a] - x[b]) / tau), where b runs over ALL indices (including a).
# Larger values get larger ranks (ascending). The self term adds a constant 0.5 and each tied pair
# contributes 0.5, so as tau -> 0 this equals the average-tie rank + 0.5 (the offset does not
# affect a Pearson correlation of ranks). tau controls smoothness (smaller => closer to true rank).
# Builds an N x N matrix, so memory grows quadratically with len(x).
# ---------------------
def soft_rank(x: torch.Tensor, tau: float = 1e-2):
    """Smooth ascending ranks of a 1-D tensor ``x`` (see the formula above); differentiable in x."""
    # scaled_gaps[a, b] = (x[a] - x[b]) / tau
    scaled_gaps = (x.unsqueeze(1) - x.unsqueeze(0)) / tau
    return 1.0 + torch.sigmoid(scaled_gaps).sum(dim=1)

def pearson_corr(x: torch.Tensor, y: torch.Tensor, eps=1e-8):
    """Differentiable Pearson correlation between two equally shaped tensors.

    ``eps`` is added both under the square root and to the denominator, so a constant input
    (zero variance) yields 0 instead of NaN.
    """
    dx = x - x.mean()
    dy = y - y.mean()
    numerator = torch.sum(dx * dy)
    denominator = torch.sqrt(torch.sum(dx * dx) * torch.sum(dy * dy) + eps) + eps
    return numerator / denominator

def get_input_label(i, split='train'):
    """Fetch tokens, embedding and labels for the protein in row ``i``.

    Rows come from ``pro_test`` when ``split == 'test'`` and from ``pro_train`` otherwise
    (so 'train' and 'eval' share the train proteins). Labels are read from ``label`` under the
    key ``f"{WT_name}_{split}"``.

    Returns ``(WT_name, tokens, embedding, labels)``, all tensors moved to ``device``.
    """
    table = pro_test if split == 'test' else pro_train
    name, sequence = table.loc[i, ['WT_name', 'aa_seq']]

    _, _, tokens = batch_converter([(name, sequence)])
    # First/last positions dropped — presumably the special tokens the ESM batch converter adds; confirm.
    batch_tokens = tokens[0, 1:-1].to(device)

    # Same trimming on the stored embedding (assumes it carries the same special-token positions).
    embed_pro = embed[name][0, 1:-1].to(device)
    label_pro = label[f'{name}_{split}'].to(device)

    return name, batch_tokens, embed_pro, label_pro

def get_llr(gated_recon, batch_tokens, label_pro):
    """Score every substitution at every position with ESM's LM head.

    llr[r, k] = log p(token 4 + k at position r) - log p(batch_tokens[r] at position r),
    for k in 0..19 (vocabulary indices 4..23, presumably the 20 standard amino acids — confirm),
    flattened row-major. Entries whose label is NaN are dropped from both outputs.

    Assumes ``label_pro`` is already flat with the same (position, substitution) order as the
    flattened llr — TODO confirm against the label files.
    """
    log_probs = torch.log_softmax(esm_model.lm_head(gated_recon), dim=1)

    positions = range(len(batch_tokens))
    wt_log_probs = log_probs[positions, batch_tokens].unsqueeze(1)
    llr = (log_probs[:, 4:24] - wt_log_probs).reshape(-1)

    keep = torch.isnan(label_pro).logical_not()
    return llr[keep], label_pro[keep]

In [7]:
# Only the gate logits are trained: freeze MotifAE first, then ESM.
for frozen_model in (motifae, esm_model):
    for param in frozen_model.parameters():
        param.requires_grad = False

# One gate per MotifAE latent; 40960 must match the width of motifae.encode(...) output
# (presumably the dictionary size of this checkpoint — confirm).
gate = BinaryGate(40960).to(device)

optimizer = optim.Adam([gate.logits], lr=lr)
optimizer.zero_grad()

In [8]:
def _mean_heldout_spearman(indices, split, gates):
    """Mean scipy Spearman rho between gated-reconstruction LLRs and labels over proteins ``indices``.

    Runs without autograd using the given (fixed) gates; returns NaN when ``indices`` is empty.
    """
    if len(indices) == 0:
        return float('nan')
    total = 0.0
    with torch.no_grad():
        for i in indices:
            # Discard the protein name: unpacking into `pro` would overwrite the global DataFrame.
            _, batch_tokens, embed_pro, label_pro = get_input_label(i, split=split)
            sae_hidden = motifae.encode(embed_pro)
            gated_recon = motifae.decode(sae_hidden * gates)
            llr, label_kept = get_llr(gated_recon, batch_tokens, label_pro)
            total += spearmanr(llr.cpu().numpy(), label_kept.cpu().numpy()).correlation
    return total / len(indices)


for epoch in range(n_epoch):
    # ---- Training: accumulate gradients over all train proteins, then take ONE optimizer step ----
    gate.train()
    train_total = 0.0
    for i in pro_train.index:
        _, batch_tokens, embed_pro, label_pro = get_input_label(i, split='train')

        with torch.no_grad():
            sae_hidden = motifae.encode(embed_pro)
        gates, gate_probs = gate()
        gated_recon = motifae.decode(sae_hidden * gates)

        llr, label_kept = get_llr(gated_recon, batch_tokens, label_pro)
        spearman_approx = pearson_corr(soft_rank(llr), soft_rank(label_kept))

        sparsity_loss = lambda_sparse * gate_probs.mean()
        # Divide by the number of proteins so the accumulated gradient is a mean over proteins.
        loss = (-spearman_approx + sparsity_loss) / len(pro_train)
        loss.backward()
        # .item() detaches: summing the tensor itself would chain every protein's autograd graph.
        train_total += spearman_approx.item()

    optimizer.step()
    optimizer.zero_grad()
    # Soft-Spearman averaged over train proteins, measured with the gates BEFORE this epoch's step.
    running_spearman = train_total / len(pro_train)

    # ---- Evaluation with the updated gates ----
    gate.eval()
    with torch.no_grad():
        gates, gate_probs = gate()
    eval_spearman = _mean_heldout_spearman(pro_train.index, 'eval', gates)
    test_spearman = _mean_heldout_spearman(pro_test.index, 'test', gates)

    print(f"epoch {epoch}; open gates={gates.sum().item()}; sparsity_loss={sparsity_loss.item():.4f}; running_spearman={running_spearman:.4f}; eval_spearman={eval_spearman:.4f}; test_spearman={test_spearman:.4f}")

epoch 0; open gates=1983.0; sparsity_loss=0.5000; running_spearman=-0.2093; eval_spearman=0.5543; test_spearman=0.5527
epoch 1; open gates=1981.0; sparsity_loss=0.4998; running_spearman=0.5568; eval_spearman=0.5543; test_spearman=0.5526
epoch 2; open gates=1916.0; sparsity_loss=0.4995; running_spearman=0.5568; eval_spearman=0.5797; test_spearman=0.5687
epoch 3; open gates=1856.0; sparsity_loss=0.4993; running_spearman=0.5816; eval_spearman=0.5847; test_spearman=0.5729


In [None]:
# Persist the final gate vector (forward values form the hard 0/1 mask) as a CPU tensor without autograd history.
selected_gates = gates.detach().cpu()
torch.save(selected_gates, '../data/selected_gates.pt')