In [14]:
import esm
import sys, os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.stats import spearmanr

# Plotting / display defaults for the notebook.
cmap = plt.get_cmap("tab20c")
sns.set(font_scale=1)
pd.set_option('display.max_columns', 50)
# Applied after sns.set(...) so the "white" style is the one in effect.
sns.set_style("white")

# Make the project-local `dictionary` module importable by adding its
# directory to sys.path (absolute, machine-specific path).
file_path = "/home/ch3849/SAE_mut/code/model_relu"
sys.path.append(file_path)
from dictionary import AutoEncoder

### load models and data

In [2]:
# Use the first CUDA device when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pretrained ESM-2 (t33, 650M) language model and its alphabet.
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm_model.eval()  # inference mode (e.g. disables dropout, if the model has any)
esm_model = esm_model.to(device)
# Callable that turns [(name, sequence)] pairs into (labels, strings, token tensor).
batch_converter = alphabet.get_batch_converter()

# Sparse autoencoder checkpoint, identified by a run id and a training step.
# NOTE: `model` here is the run id (an int), not a torch module.
model = 250417
chk = 80000
chk_path = f'/share/vault/Users/ch3849/esm_sae/model/{model}/checkpoints/step_{chk}.pt'
sae = AutoEncoder.from_pretrained(chk_path)
sae.eval()  # inference mode; NOTE(review): whether the SAE has dropout is not visible here
sae = sae.to(device)

  state_dict = t.load(path)


In [23]:
# NOTE: this changes the process working directory; every relative path used
# from here on (including files written by later cells) resolves against it.
os.chdir('/nfs/user/Users/ch3849/esm_sae/gate_mut')

# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# Only load pickles that were produced by a trusted pipeline.
# `with` guarantees the file handles are closed even if unpickling fails.
with open('embedding_412wt.pkl', 'rb') as fh:
    embed = pickle.load(fh)
with open('training_label.pkl', 'rb') as fh:
    label = pickle.load(fh)

# Both pickles are dicts of tensors; move every tensor to the compute device once
# so the per-protein lookups in the training loop do not pay a transfer cost.
embed = {key: value.to(device) for key, value in embed.items()}
label = {key: value.to(device) for key, value in label.items()}

# Per-protein metadata; the `split` column assigns each row to train or test.
df = pd.read_csv('412wt_info.csv')
df_train = df[df['split'] == 'train'].reset_index(drop=True)
df_test = df[df['split'] == 'test'].reset_index(drop=True)

### gate model

In [80]:
# Learning rate handed to the Adam optimizer that updates the gate logits.
lr = 1e-3
# Number of passes over df_train; the optimizer steps once per pass.
n_epoch = 54
# Weight of the sparsity penalty (mean gate probability) in the training loss.
lambda_sparse = 1

In [81]:
class BinaryGate(nn.Module):
    """Learnable per-feature binary gate trained with a straight-through estimator.

    One logit is kept per feature. The forward pass emits a hard 0/1 mask
    (open when sigmoid(logit) > 0.5) whose gradient is that of the soft
    probability, so the logits remain trainable.

    Args:
        embed_dim: number of gated features (length of the returned vectors).
    """

    def __init__(self, embed_dim):
        super().__init__()
        # Zero logits -> probability exactly 0.5 -> every gate starts closed,
        # because the threshold below is a strict inequality.
        self.logits = nn.Parameter(torch.zeros(embed_dim))

    def forward(self):
        """Return ``(gates, probs)``.

        gates: hard 0/1 mask in the forward pass, with d(gates)/d(probs) == 1.
        probs: sigmoid of the logits, useful for logging and sparsity penalties.
        """
        probs = torch.sigmoid(self.logits)
        hard = (probs > 0.5).float()
        # STE: group the soft term so it cancels to exactly 0 in the forward pass.
        # Writing `hard + probs - probs.detach()` evaluates (hard + probs) first,
        # which rounds in float32 and can yield 0.99999994 instead of 1.0.
        gates = hard + (probs - probs.detach())
        return gates, probs  # return both for logging if desired

In [82]:
# ---------------------
# Differentiable "soft rank", used to build a differentiable Spearman proxy.
# For element i:  soft_rank_i = 1 + sum_j sigmoid((x_i - x_j) / tau)
# i.e. a smooth count of the elements that are SMALLER than x_i (ascending
# rank). The j == i term contributes sigmoid(0) = 0.5, so every rank carries
# the same constant offset, which a correlation is insensitive to.
# tau controls smoothness (smaller => closer to the true rank).
# ---------------------
def soft_rank(x: torch.Tensor, tau: float = 1e-2):
    """Return smooth ascending ranks of the 1-D tensor ``x`` (shape (N,)).

    Builds an (N, N) matrix of pairwise differences, so memory grows
    quadratically with N.
    """
    as_column = x.unsqueeze(1)   # (N, 1): the element being ranked
    as_row = x.unsqueeze(0)      # (1, N): the elements it is compared against
    scaled_diff = (as_column - as_row) / tau   # [i, j] = (x[i] - x[j]) / tau
    # Summing over j softly counts how many x[j] lie below x[i].
    return 1.0 + torch.sigmoid(scaled_diff).sum(dim=1)

def pearson_corr(x: torch.Tensor, y: torch.Tensor, eps=1e-8):
    """Pearson correlation of two equally shaped tensors, as a 0-d tensor.

    ``eps`` is added both under the square root and to the denominator so a
    constant input yields 0 instead of a division by zero.
    """
    dx = x - x.mean()
    dy = y - y.mean()
    covariance = (dx * dy).sum()
    energy = (dx * dx).sum() * (dy * dy).sum()
    return covariance / (torch.sqrt(energy + eps) + eps)

def get_input_label(i, split='train'):
    """Fetch the inputs and labels for the protein at row ``i``.

    Rows come from ``df_test`` when ``split == 'test'`` and from ``df_train``
    for every other value; the labels are looked up under ``f'{pro}_{split}'``,
    so ``split`` also selects which label set is returned.

    Returns:
        (pro, batch_tokens, embed_pro, label_pro): the protein name, its token
        ids, its stored embedding and its label tensor, all on ``device``.
    """
    frame = df_test if split == 'test' else df_train
    pro, seq = frame.loc[i, ['WT_name', 'aa_seq']]

    # Only the token tensor of the converter output is needed here.
    _, _, tokens = batch_converter([(pro, seq)])
    # Drop the first and last position of the single sequence in the batch
    # (presumably the BOS/EOS tokens -- confirm against the alphabet).
    batch_tokens = tokens[0, 1:-1].to(device)

    # Same trimming on the stored embedding so it lines up with the tokens.
    embed_pro = embed[pro][0, 1:-1].to(device)
    label_pro = label[f'{pro}_{split}'].to(device)

    return pro, batch_tokens, embed_pro, label_pro

def get_llr(gated_recon, batch_tokens, label_pro):
    """Score substitutions as log-likelihood ratios against the wild type.

    ``gated_recon`` is passed through the ESM LM head; for every position the
    log-probabilities of vocabulary columns 4..23 (presumably the 20 standard
    amino acids -- confirm against the alphabet) are compared to the
    log-probability of the token actually present in ``batch_tokens``.

    Returns:
        (llr, labels): the flattened ratios and ``label_pro``, both restricted
        to the entries where ``label_pro`` is not NaN.
    """
    # dim=1 is the vocabulary axis; assumes gated_recon is 2-D (positions, hidden).
    log_probs = torch.log_softmax(esm_model.lm_head(gated_recon), dim=1)

    positions = range(len(batch_tokens))
    wt_log_prob = log_probs[positions, batch_tokens].unsqueeze(1)
    llr = (log_probs[:, 4:24] - wt_log_prob).reshape(-1)

    # Keep only the entries that have a measured label.
    keep = ~torch.isnan(label_pro)
    return llr[keep], label_pro[keep]

In [83]:
# Freeze both pretrained models: only the gate logits are trained.
for p in sae.parameters():
    p.requires_grad = False
for p in esm_model.parameters():
    p.requires_grad = False

# One gate per SAE hidden feature. The gates multiply the output of
# sae.encode elementwise, so 40960 must equal that output's last dimension
# (presumably the SAE dictionary size -- confirm against the checkpoint).
gate = BinaryGate(40960).to(device)

# The gate logits are the only optimized parameters.
optimizer = optim.Adam([gate.logits], lr=lr)
optimizer.zero_grad()

In [84]:
# Train the binary gate with full-batch gradient accumulation: gradients from
# every training protein are summed and the optimizer steps once per epoch.
# After each step the hard gates are scored with the exact Spearman
# correlation on the held-out ("eval") labels of the training proteins and on
# the test proteins.

# Early-stopping criteria (previously inline magic numbers).
# NOTE(review): stopping on the test-set metric leaks test information into
# model selection; consider using the eval metric instead.
TARGET_TEST_SPEARMAN = 0.57
MAX_OPEN_GATES = 1410

# Per-epoch accumulators (sums over proteins; divided by the count when printed).
running_spearman = 0.0
eval_spearman = 0.0
test_spearman = 0.0

for epoch in range(n_epoch):
    gate.train()
    for i in df_train.index:
        pro, batch_tokens, embed_pro, label_pro = get_input_label(i, split='train')

        # The SAE encoder is frozen, so no graph is needed for the hidden code.
        with torch.no_grad():
            sae_hidden = sae.encode(embed_pro)
        gates, gate_probs = gate()
        gated_sae_hidden = sae_hidden * gates
        gated_recon = sae.decode(gated_sae_hidden)

        llr, label_pro = get_llr(gated_recon, batch_tokens, label_pro)

        # Differentiable stand-in for Spearman: Pearson correlation of soft ranks.
        spearman_approx = pearson_corr(soft_rank(llr), soft_rank(label_pro))

        sparsity_loss = lambda_sparse * gate_probs.mean()
        # Divide by the number of proteins so the accumulated gradient is a mean.
        loss = (-spearman_approx + sparsity_loss) / len(df_train)

        loss.backward()
        # .item() stores a plain float; accumulating the tensor itself would keep
        # every iteration's autograd graph referenced until the epoch ends.
        running_spearman += spearman_approx.item()

    optimizer.step()
    optimizer.zero_grad()

    gate.eval()
    # Evaluation only needs the gate values, not a graph through the logits.
    with torch.no_grad():
        gates, gate_probs = gate()

    # Held-out labels of the training proteins: split='eval' still reads rows
    # from df_train but looks up the f'{pro}_eval' label set.
    for i in df_train.index:
        pro, batch_tokens, embed_pro, label_pro = get_input_label(i, split='eval')

        with torch.no_grad():
            sae_hidden = sae.encode(embed_pro)
            gated_sae_hidden = sae_hidden * gates
            gated_recon = sae.decode(gated_sae_hidden)

            llr, label_pro = get_llr(gated_recon, batch_tokens, label_pro)

        eval_spearman += spearmanr(llr.cpu().numpy(), label_pro.cpu().numpy()).correlation

    for i in df_test.index:
        pro, batch_tokens, embed_pro, label_pro = get_input_label(i, split='test')

        with torch.no_grad():
            sae_hidden = sae.encode(embed_pro)
            gated_sae_hidden = sae_hidden * gates
            gated_recon = sae.decode(gated_sae_hidden)

            llr, label_pro = get_llr(gated_recon, batch_tokens, label_pro)
        test_spearman += spearmanr(llr.cpu().numpy(), label_pro.cpu().numpy()).correlation

    # sparsity_loss is the value from the last training protein of this epoch
    # (computed before the optimizer step); the gate count is post-step.
    print(f"epoch {epoch}; open gates={gates.sum().item()}; sparsity_loss={sparsity_loss.item():.4f}; running_spearman={running_spearman / len(df_train):.4f}; eval_spearman={eval_spearman / len(df_train):.4f}; test_spearman={test_spearman / len(df_test):.4f}")
    if test_spearman / len(df_test) > TARGET_TEST_SPEARMAN and gates.sum().item() < MAX_OPEN_GATES:
        break

    running_spearman = 0.0
    eval_spearman = 0.0
    test_spearman = 0.0

epoch 0; open gates=2004.0; sparsity_loss=0.5000; running_spearman=-0.2092; eval_spearman=0.5522; test_spearman=0.5528
epoch 1; open gates=2001.0; sparsity_loss=0.4998; running_spearman=0.5616; eval_spearman=0.5522; test_spearman=0.5528
epoch 2; open gates=1927.0; sparsity_loss=0.4995; running_spearman=0.5616; eval_spearman=0.5754; test_spearman=0.5683
epoch 3; open gates=1856.0; sparsity_loss=0.4993; running_spearman=0.5841; eval_spearman=0.5793; test_spearman=0.5702
epoch 4; open gates=1809.0; sparsity_loss=0.4991; running_spearman=0.5879; eval_spearman=0.5822; test_spearman=0.5745
epoch 5; open gates=1774.0; sparsity_loss=0.4989; running_spearman=0.5905; eval_spearman=0.5855; test_spearman=0.5790
epoch 6; open gates=1746.0; sparsity_loss=0.4986; running_spearman=0.5937; eval_spearman=0.5862; test_spearman=0.5734
epoch 7; open gates=1712.0; sparsity_loss=0.4984; running_spearman=0.5946; eval_spearman=0.5874; test_spearman=0.5773
epoch 8; open gates=1679.0; sparsity_loss=0.4982; runni

In [94]:
# Persist the final gate vector on CPU, relative to the current working
# directory. The file name is fixed and does not track the actual gate count.
torch.save(
    gates.detach().cpu(),
    '1404gates.pt',
)