In [1]:
from pathlib import Path
import sys, os
import torch
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import json
from tqdm import tqdm
import esm
from scipy.special import softmax


from dictionary import AutoEncoder
from config import my_config
from dataset import ProteinDataset, collate_batch

In [2]:
# Load the pretrained ESM-2 model (esm2_t33_650M_UR50D).
esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()

# Frozen, gradient-free copies of the final LayerNorm's affine parameters, moved to the
# configured device. They are used later as `(act - norm_b) / norm_w`.
norm_b, norm_w = (
    param.detach().clone().to(my_config['device'])
    for param in (esm_model.emb_layer_norm_after.bias, esm_model.emb_layer_norm_after.weight)
)

In [24]:
def get_sae(sae_model, chk, root='/share/vault/Users/ch3849/esm_sae/model'):
    """Load an SAE checkpoint, switch it to eval mode and move it to the configured device.

    Args:
        sae_model: run identifier, used as the sub-directory name under ``root``.
        chk: training step of the checkpoint to load (file ``step_{chk}.pt``).
        root: directory containing one sub-directory per SAE run. Defaults to the
            location that was previously hard-coded, so existing calls are unchanged.

    Returns:
        The loaded ``AutoEncoder``, in eval mode, on ``my_config['device']``.
    """
    chk_path = f'{root}/{sae_model}/checkpoints/step_{chk}.pt'
    sae = AutoEncoder.from_pretrained(chk_path)
    sae.eval()  # switch off training-only behaviour (e.g. dropout, if the model has any)
    return sae.to(my_config['device'])

In [25]:
# Read the first 100 rows of the sequence table and wrap them in a non-shuffled loader.
stage = 'representative'
df = pd.read_csv(
    '/share/vault/Users/ch3849/esm_sae/sequence/eval_test_seq_max1022_addmask_perplexity.csv',
    nrows=100,
)
# Optional filter to the eval split only:
# df = df[df['split'] == 'eval'].reset_index(drop=True)

dataset = ProteinDataset(
    df=df,
    df_name_col=my_config[f'df_name_col_{stage}'],
    embed_logit_path=my_config[f'embed_logit_path_{stage}'],
    stage=stage,
)
loader = DataLoader(
    dataset,
    batch_size=60,
    shuffle=False,
    drop_last=False,
    num_workers=10,
    collate_fn=collate_batch,
)

In [26]:
# Take only the first batch from `loader` as the sample for the lag analysis below.
# `next(iter(...))` raises StopIteration on an empty loader instead of silently leaving
# a stale `act` from an earlier run in the kernel.
batch_data = next(iter(loader))
act = batch_data['repr'].to(my_config["device"])
# NOTE(review): unlike the smoothness loop further down, this does not apply
# `(act - norm_b) / norm_w` before `sae.encode` — confirm which input the SAEs expect.

In [39]:
# For each SAE run, the mean relative L1 change of SAE features between tokens k
# positions apart (k = 1..MAX_LAG). Uses `act` and `get_sae` from earlier cells.
# NOTE(review): this rebinds `f` to a NumPy array, while later cells in this notebook
# call torch methods (`.unsqueeze`, `.norm`) on `f`.
SAE_MODELS = [250416, 250417]
CHECKPOINT_STEP = 80000
MAX_LAG = 10

ratio = {sae_model: [] for sae_model in SAE_MODELS}
for sae_model in SAE_MODELS:
    sae = get_sae(sae_model, CHECKPOINT_STEP)

    with torch.no_grad():
        f = sae.encode(act).cpu().numpy()

    n_tok = f.shape[0]
    l1 = np.abs(f).sum(axis=1)  # per-token L1 mass
    for k in range(1, MAX_LAG + 1):
        # Use an explicit end index. `f[:-k]` is empty when k == 0 (the recorded
        # (N, D) vs (0, D) broadcast error is consistent with that). It also keeps
        # too many rows when k > n_tok.
        head = f[:max(n_tok - k, 0)]
        f_diff = np.abs(f[k:] - head).sum(axis=1)
        # Normalised by the L1 mass of the later token of each pair.
        ratio[sae_model].append((f_diff / l1[k:]).mean())

ValueError: operands could not be broadcast together with shapes (18952,40960) (0,40960) 

In [38]:
# Write the per-lag ratios (one column per SAE run key in `ratio`) to CSV.
# NOTE(review): this cell's execution count (38) precedes the cell that fills `ratio` (39).
pd.DataFrame(data=ratio).to_csv(
    path_or_buf='/share/vault/Users/ch3849/esm_sae/fig/sae_basic/SAE_feature_diff_ratio.csv',
    index=False,
)

In [None]:
import torch

# Function to shuffle, move non-zero values to the top, and sort
def process_tensor(tensor, generator=None, chunk_cols=4096):
    """Build shuffled / compacted / sorted control versions of a 2-D tensor.

    Args:
        tensor: 2-D tensor of shape (rows, cols).
        generator: optional ``torch.Generator`` for the shuffle. ``None`` uses torch's
            global CPU RNG, as the original implementation did.
        chunk_cols: number of columns handled per vectorised step. Bounds the temporary
            memory of the per-column sorts.

    Returns:
        ``(shuffled_tensor, non_sorted_tensor, sorted_tensor)``, all shaped like ``tensor``:
          * ``shuffled_tensor``: all elements randomly permuted across the whole
            matrix (not per column).
          * ``non_sorted_tensor``: per column of ``shuffled_tensor``, the non-zero
            entries moved to the top in their existing order, zeros below.
          * ``sorted_tensor``: per column, the non-zero entries sorted ascending at
            the top, zeros below.

    Raises:
        ValueError: if ``chunk_cols`` < 1.
    """
    if chunk_cols < 1:
        raise ValueError(f"chunk_cols must be >= 1, got {chunk_cols}")

    # Global shuffle. Indices are drawn on CPU (same RNG stream as before) and then
    # moved to the tensor's device for indexing.
    flattened = tensor.flatten()
    shuffled_indices = torch.randperm(flattened.numel(), generator=generator).to(flattened.device)
    shuffled_tensor = flattened[shuffled_indices].view(tensor.size())

    _, cols = shuffled_tensor.size()  # unpacking enforces a 2-D input
    non_sorted_tensor = torch.zeros_like(shuffled_tensor)
    sorted_tensor = torch.zeros_like(shuffled_tensor)

    # Vectorised over column chunks instead of a Python loop per column, which is
    # very slow for tens of thousands of columns.
    for start in range(0, cols, chunk_cols):
        stop = start + chunk_cols
        block = shuffled_tensor[:, start:stop]

        # Stable sort on the 0/1 key "is zero": non-zeros move to the top and keep
        # their relative order.
        _, order = torch.sort((block == 0).to(torch.int8), dim=0, stable=True)
        non_sorted_tensor[:, start:stop] = block.gather(0, order)

        # Sort values ascending, then stable-partition zeros to the bottom. This
        # leaves the non-zeros in ascending order at the top.
        values, _ = torch.sort(block, dim=0, stable=True)
        _, order = torch.sort((values == 0).to(torch.int8), dim=0, stable=True)
        sorted_tensor[:, start:stop] = values.gather(0, order)

    return shuffled_tensor, non_sorted_tensor, sorted_tensor

In [20]:
# Per-token L1 distance from f[t] to f[t+1], f[t+2], f[t+3]; shape (3, N-3).
nearby_loss = (torch.stack([f[1:-2], f[2:-1], f[3:]]) - f[:-3]).norm(p=1, dim=-1)
# Softmin-weight the three offsets per token (dim 0), then average over tokens.
# The functional form avoids depending on a notebook-global `softmin` that is defined
# later and is rebound with different `dim` values in other cells.
smooth_loss = (torch.nn.functional.softmin(nearby_loss, dim=0) * nearby_loss).sum(dim=0).mean()

In [21]:
# Display the current value of `smooth_loss`.
smooth_loss

tensor(39.2049, device='cuda:1')

In [None]:
        # NOTE(review): fragment copied from a method body. It references `self.softmin`,
        # so it cannot run at notebook top level (`self` is undefined here).
        # Reduces to one mean L1 distance per offset (+1, +2, +3), then softmin-weights them.
        nearby_loss = (torch.cat([f[1:-2].unsqueeze(0), f[2:-1].unsqueeze(0), f[3:].unsqueeze(0)], dim=0) - f[0:-3].unsqueeze(0)).norm(p=1, dim=-1).mean(dim=-1)
        smooth_loss = self.softmin(nearby_loss)@nearby_loss

In [7]:
# Mean (over tokens) L1 distance between f[t] and f[t+k] for k = 1, 2, 3; shape (3,).
# torch.stack keeps the result on f's device. torch.tensor([...]) would copy each
# 0-d tensor to the host one element at a time.
nearby_loss = torch.stack([(f_splice - f[:-3]).norm(p=1, dim=-1).mean(dim=-1) for f_splice in (f[1:-2], f[2:-1], f[3:])])


In [11]:
# Softmin over dim 0. Applied to the 1-D per-offset `nearby_loss`, it weights the
# offsets so the closest neighbour dominates.
softmin = torch.nn.Softmin(dim=0)
# Re-enabled: the display cell that follows shows this value, so it must be computed
# here for the notebook to reproduce on Restart & Run All.
smooth_loss = softmin(nearby_loss) @ nearby_loss

In [11]:
# Display the current value of `smooth_loss` (softmin-weighted neighbour distance).
smooth_loss

tensor(41.1605)

: 

In [8]:
# Display `nearby_loss` (one mean L1 distance per neighbour offset).
nearby_loss

tensor([41.1040, 45.6899, 47.5280])

In [9]:
# Mean (over tokens) L1 distance from f[t+3] back to f[t], f[t+1], f[t+2]; one value per offset.
(torch.stack([f[:-3], f[1:-2], f[2:-1]]) - f[3:]).norm(p=1, dim=-1).mean(dim=-1)

tensor([49.5101, 47.5359, 42.4042], device='cuda:1')

In [38]:
# Shape check: tokens with 3 dropped from each end, plus a leading singleton axis
# (the "centre" operand of the six-neighbour comparison).
f[3:-3].unsqueeze(0).shape

torch.Size([1, 4438, 40960])

In [63]:
# Compare each centre token f[t] with its neighbours at offsets ±1..±radius, take the
# mean (over tokens) L1 distance per offset, and softmin-weight the offsets.
radius = 3
n_tok = f.shape[0]
center_f = f[radius:n_tok - radius]
# Offsets -3, -2, -1, +1, +2, +3 in that order. Explicit end indices avoid the `x[:-0]`
# empty-slice pitfall. Shape: (2 * radius, n_tok - 2 * radius, n_features).
nearby_f = torch.stack([f[radius + off:n_tok - radius + off]
                        for off in range(-radius, radius + 1) if off != 0])
nearby_loss = (nearby_f - center_f).norm(p=1, dim=-1).mean(dim=-1)
# Functional softmin so the notebook-global `softmin` (created with dim=0 elsewhere)
# is not silently rebound to dim=-1.
torch.nn.functional.softmin(nearby_loss, dim=-1) @ nearby_loss

tensor(42.4375, device='cuda:1')

In [6]:
def _second_diff_l1(feats):
    """Per-token L1 norm of the discrete second difference along dim 0.

    Row t of the result is ||(feats[t] + feats[t+2]) / 2 - feats[t+1]||_1, i.e. how far
    token t+1 deviates from the midpoint of its two neighbours. Length is N - 2.
    """
    return (feats[2:] / 2 + feats[:-2] / 2 - feats[1:-1]).norm(p=1, dim=-1)


# Per batch: ratio of second-difference L1 "roughness" to each interior token's L1
# activation mass, for the real features and the three controls from `process_tensor`.
smooth_loss_ratio = []
smooth_loss_shuffle_ratio = []
smooth_loss_nonsorted_ratio = []
smooth_loss_sorted_ratio = []

# NOTE(review): `sae` is whichever model an earlier cell loaded last. Confirm it is the
# checkpoint you intend to evaluate.
for batch_data in loader:
    with torch.no_grad():
        act = batch_data['repr']
        # Invert the affine (weight/bias) step of ESM-2's final layer norm
        # (presumably the stored representations are post-norm — confirm).
        act = (act.to(my_config["device"]) - norm_b) / norm_w
        f = sae.encode(act)
        l1_loss = f.norm(p=1, dim=-1)
        # Second differences exist only for interior tokens.
        # NOTE(review): an all-zero token gives 0/0 = nan below. If `collate_batch`
        # concatenates several sequences along dim 0, cross-sequence boundaries are
        # included too — confirm.
        interior_l1 = l1_loss[1:-1]

        shuffled_f, non_sorted_f, sorted_f = process_tensor(f)

        smooth_loss_3 = _second_diff_l1(f)  # L1
        smooth_loss_shuffle = _second_diff_l1(shuffled_f)
        smooth_loss_nonsorted = _second_diff_l1(non_sorted_f)
        smooth_loss_sorted = _second_diff_l1(sorted_f)

        smooth_loss_ratio.append((smooth_loss_3 / interior_l1).mean().item())
        smooth_loss_shuffle_ratio.append((smooth_loss_shuffle / interior_l1).mean().item())
        smooth_loss_nonsorted_ratio.append((smooth_loss_nonsorted / interior_l1).mean().item())
        smooth_loss_sorted_ratio.append((smooth_loss_sorted / interior_l1).mean().item())

    # Running log of this batch's four ratios (reuses the values just appended).
    for ratios in (smooth_loss_ratio, smooth_loss_shuffle_ratio,
                   smooth_loss_nonsorted_ratio, smooth_loss_sorted_ratio):
        print(round(ratios[-1], 3), end=', ')

1.601, 2.07, 1.056, 0.688, 1.597, 2.06, 0.764, 0.495, 1.623, 2.083, 0.936, 0.551, 1.606, 2.077, 0.807, 0.402, 1.61, 2.081, 0.96, 0.583, 1.641, 2.073, 1.02, 0.758, 1.638, 2.072, 1.085, 0.69, 1.6, 2.103, 0.916, 0.615, 1.627, 2.086, 0.993, 0.647, 1.594, 2.062, 0.884, 0.547, 1.584, 2.091, 1.137, 0.72, 1.539, 2.08, 0.973, 0.58, 1.569, 2.086, 1.098, 0.669, 1.579, 2.071, 1.125, 0.77, 1.622, 2.066, 0.879, 0.588, 1.595, 2.08, 0.867, 0.495, 1.617, 2.084, 1.026, 0.714, 1.601, 2.075, 0.775, 0.46, 1.58, 2.07, 0.932, 0.61, 1.613, 2.082, 0.728, 0.503, 1.603, 2.085, 1.056, 0.723, 1.608, 2.078, 0.891, 0.613, 1.594, 2.089, 1.072, 0.698, 1.641, 2.09, 0.851, 0.575, 1.557, 2.101, 0.895, 0.548, 1.603, 2.079, 0.983, 0.632, 1.597, 2.076, 0.932, 0.587, 1.589, 2.07, 0.966, 0.659, 1.57, 2.086, 0.871, 0.465, 1.627, 2.072, 1.155, 0.738, 1.6, 2.074, 0.746, 0.493, 1.613, 2.077, 0.896, 0.517, 1.595, 2.097, 0.939, 0.727, 1.616, 2.09, 0.931, 0.478, 1.609, 2.067, 0.938, 0.526, 1.588, 2.084, 1.124, 0.657, 1.637, 2.065, 0

KeyboardInterrupt: 