In [27]:
from pathlib import Path
import sys, os
import torch
from torch.utils.data import DataLoader
import torch.optim as optim

import pandas as pd
import json
from tqdm import tqdm
import esm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from dictionary import AutoEncoder
from trainer import StandardTrainer
from training import train_run
from dataset import ProteinDataset, collate_batch
from config import my_config

In [None]:
# Load the normalization statistics ('mean' / 'std' entries) that match the
# configured protein language model. Only the statistics are loaded here, not
# the language model itself.
NORM_VECTOR_DIR = '/share/vault/Users/ch3849/esm_sae/model/normalize_vector'
NORM_VECTOR_FILES = {
    'esm2_t33_650M_UR50D': 'ESM2_2k.pth',
    'esm1b_t33_650M_UR50S': 'ESM1b_2k.pth',
}

plm_name = my_config["plm_name"]
if plm_name not in NORM_VECTOR_FILES:
    # Fail here with a clear message instead of hitting a NameError on `norm`
    # a few lines further down.
    raise ValueError(
        f"No normalization vector registered for plm_name={plm_name!r}; "
        f"expected one of {sorted(NORM_VECTOR_FILES)}"
    )
# NOTE(review): the file is unpickled with torch.load defaults; consider
# passing weights_only=True if it only holds tensors -- confirm before changing.
norm = torch.load(f'{NORM_VECTOR_DIR}/{NORM_VECTOR_FILES[plm_name]}')

# The statistics are constants: detach them from autograd and move them to the
# configured device.
norm_mean = norm['mean'].requires_grad_(False).to(my_config['device'])
norm_std = norm['std'].requires_grad_(False).to(my_config['device'])

# Read the table for the configured stage and keep only the rows whose
# 'split' column equals 'train'.
stage = my_config['stage']
df = pd.read_csv(my_config[f'df_path_{stage}'])
df = df[df['split'] == 'train'].reset_index(drop=True)

# Wrap the training rows in the project dataset and a shuffled loader.
# drop_last=True discards a final batch smaller than batch_size.
dataset = ProteinDataset(
    df=df,
    df_name_col=my_config[f'df_name_col_{stage}'],
    embed_logit_path=my_config[f'embed_logit_path_{stage}'],
    stage=stage,
)
loader = DataLoader(
    dataset,
    collate_fn=collate_batch,
    batch_size=my_config['batch_size'],
    drop_last=True,
    num_workers=my_config['dataloader_num_workers'],
    shuffle=True,
)

  norm = torch.load('/share/vault/Users/ch3849/esm_sae/model/normalize_vector/ESM1b_2k.pth')


In [70]:
# Checkpoint location, built from a model id and a checkpoint step.
sae_model = 250219
chk = 110000
chk_path = f'/share/vault/Users/ch3849/esm_sae/model/{sae_model}/checkpoints/step_{chk}.pt'

# Restore the autoencoder on the configured device and put it in eval mode.
sae = AutoEncoder.from_pretrained(chk_path).to(my_config['device'])
sae.eval()

# One learnable multiplier per dictionary feature, initialised to 1.
f_scaling = torch.nn.Parameter(
    torch.full((my_config['dict_size'],), 1.0, device=my_config['device'])
)

# Two parameter groups with their own learning rates: the decoder weight
# matrix (1e-3) and the scaling vector (1e-2). Nothing else is optimized.
optimizer = optim.Adam([
    dict(params=sae.decoder.weight, lr=1e-3),
    dict(params=[f_scaling], lr=1e-2),
])

  state_dict = t.load(path)


In [None]:
# Training loop: one optimizer step per batch from `loader`. Which parameters
# actually change depends on how `optimizer` was built in the cell run before
# this one.
for batch_data in loader:
    # Move the batch representations to the device and normalize them with the
    # precomputed mean / std.
    act = (batch_data['repr'].to(my_config["device"]) - norm_mean) / norm_std

    # Encode, rescale each feature activation by its learned factor, decode.
    f = sae.encode(act)
    act_hat = sae.decode(f * f_scaling)

    # Reconstruction loss: L2 norm of the residual over the last dimension,
    # averaged over the remaining dimensions.
    l2_loss = torch.linalg.norm(act - act_hat, dim=-1).mean()

    # NOTE(review): zero_grad() only clears the parameters registered with the
    # optimizer; any other `sae` parameter that requires grad will accumulate
    # .grad across steps -- confirm this is acceptable.
    optimizer.zero_grad()
    l2_loss.backward()
    optimizer.step()

    # Project the scaling factors back onto [1, inf): features may only be
    # scaled up, never down. Done in place under no_grad instead of rebinding
    # `.data`, so the optimizer keeps tracking the same tensor safely.
    with torch.no_grad():
        f_scaling.clamp_(min=1)

    # Per-step loss for monitoring.
    print(f"L2 Loss: {l2_loss.item()}")

In [72]:
# Take 20 optimizer steps on a single, fixed batch.
# NOTE: `act` is not defined in this cell -- it is whatever normalized batch
# the most recently executed loader loop left behind, so this cell only works
# after one of those loops has run.
for i in range(20):
    # Encode, rescale each feature activation by its learned factor, decode.
    f = sae.encode(act)
    act_hat = sae.decode(f * f_scaling)

    # Reconstruction loss: L2 norm of the residual over the last dimension,
    # averaged over the remaining dimensions.
    l2_loss = torch.linalg.norm(act - act_hat, dim=-1).mean()

    # Standard step: clear old gradients, backpropagate, update.
    optimizer.zero_grad()
    l2_loss.backward()
    optimizer.step()

    # Project the scaling factors back onto [1, inf): features may only be
    # scaled up, never down. Done in place under no_grad instead of rebinding
    # `.data`.
    with torch.no_grad():
        f_scaling.clamp_(min=1)

    # Per-step loss for monitoring.
    print(f"L2 Loss: {l2_loss.item()}")

L2 Loss: 17.9985294342041
L2 Loss: 17.613338470458984
L2 Loss: 17.256746292114258
L2 Loss: 16.92742919921875
L2 Loss: 16.623279571533203
L2 Loss: 16.341808319091797
L2 Loss: 16.08048439025879
L2 Loss: 15.836944580078125
L2 Loss: 15.609139442443848
L2 Loss: 15.395423889160156
L2 Loss: 15.194488525390625
L2 Loss: 15.00529956817627
L2 Loss: 14.827001571655273
L2 Loss: 14.658856391906738
L2 Loss: 14.500155448913574
L2 Loss: 14.350207328796387
L2 Loss: 14.20832633972168
L2 Loss: 14.073843002319336
L2 Loss: 13.94612979888916
L2 Loss: 13.824626922607422


In [74]:
# Show only the scaling factors that differ from their initial value of 1.
f_scaling[f_scaling.ne(1)]

tensor([1.1766, 1.1498, 1.2081,  ..., 1.0134, 1.1991, 1.1991], device='cuda:3',
       grad_fn=<IndexBackward0>)

In [78]:
# Variance of the current `act` batch along dim 1.
torch.var(act, dim=1)

tensor([0.6619, 0.6247, 0.5845,  ..., 0.5619, 0.5372, 0.5271], device='cuda:3')

In [60]:
# Checkpoint location, built from a model id and a checkpoint step.
sae_model = 250219
chk = 110000
chk_path = f'/share/vault/Users/ch3849/esm_sae/model/{sae_model}/checkpoints/step_{chk}.pt'

# Restore the autoencoder, switch it to eval mode and move it to the
# configured device.
sae = AutoEncoder.from_pretrained(chk_path)
sae.eval()
sae = sae.to(my_config['device'])

# One learnable multiplier per dictionary feature, initialised to 1. It is the
# only tensor handed to the optimizer, so optimizer.step() leaves the SAE
# weights untouched.
f_scaling = torch.nn.Parameter(torch.ones(my_config['dict_size'], device=my_config['device']))
optimizer = optim.Adam([f_scaling], lr=0.1)

for batch_data in loader:
    # Move the batch representations to the device and normalize them with the
    # precomputed mean / std.
    act = (batch_data['repr'].to(my_config["device"]) - norm_mean) / norm_std

    # Encode, rescale each feature activation by its learned factor, decode.
    f = sae.encode(act)
    act_hat = sae.decode(f * f_scaling)

    # Reconstruction loss: L2 norm of the residual over the last dimension,
    # averaged over the remaining dimensions.
    l2_loss = torch.linalg.norm(act - act_hat, dim=-1).mean()

    # Standard step: clear old gradients, backpropagate, update f_scaling.
    optimizer.zero_grad()
    l2_loss.backward()
    optimizer.step()

    # Project the scaling factors back onto [1, inf): features may only be
    # scaled up, never down. Done in place under no_grad instead of rebinding
    # `.data`.
    with torch.no_grad():
        f_scaling.clamp_(min=1)

    # Per-step loss for monitoring.
    print(f"L2 Loss: {l2_loss.item()}")

  state_dict = t.load(path)


In [67]:
# Number of features whose scaling factor ended up strictly above 1.
int((f_scaling > 1).sum())

6578