In [1]:
import os

os.environ['http_proxy'] = 'http://10.176.58.101:7890'
os.environ['https_proxy'] = 'http://10.176.58.101:7890'

import sys
sys.path.append('/remote-home1/jxwang/project/lorsa/src')

import torch
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
from transformer_lens import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components import Attention
from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
torch.set_grad_enabled(False)

import copy

from tqdm import tqdm

import pysvelte
import numpy as np

from model.attention import LowRankSparseAttention
from config import LoRSAConfig, LoRSATrainConfig

In [2]:
# sah config
cfg = LoRSAConfig(
    # self attention head config
    d_qk_head = 16,
    d_ov_head = 1,
    n_qk_heads = 96,
    n_ov_heads = 1536,
    rotary_scale = 1,
    rotary_dim = 16,
    top_k = 64,
    device = "cuda",
    virtual_kv_num=0,
    
    # orig attention head config
    model_name = 'EleutherAI/pythia-160m',
    layer = 5,
    max_length = 256,
    prepend_bos = True,
)

dataset_path = '/remote-home1/share/research/mechinterp/gpt2-dictionary/data/openwebtext'# load dataset
dataset = load_from_disk(dataset_path)

Loading dataset from disk:   0%|          | 0/80 [00:00<?, ?it/s]

In [18]:
hf_model = GPTNeoXForCausalLM.from_pretrained(cfg.model_name)
tokenizer = GPTNeoXTokenizerFast.from_pretrained(cfg.model_name)
model = HookedTransformer.from_pretrained(cfg.model_name, hf_model=hf_model, tokenizer=tokenizer, device=cfg.device)
model.eval()
for param in model.parameters():
    param.requires_grad = False

# get original attention block
orig_attn = model.blocks[cfg.layer].attn

Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer


In [19]:
batch_size = 4
data = DataLoader(dataset['text'], batch_size=batch_size)
data_iter = iter(data)

In [20]:
from lm_saes import LanguageModelSAERunnerConfig, SAEConfig
from lm_saes.sae import SparseAutoEncoder

hook_point_in = 'A'
exp_factor = 32
layer = cfg.layer

ckpt_path = f"/remote-home1/share/research/mechinterp/pythia-160m-LX{hook_point_in}-{exp_factor}x-topk"
ckpt_path = os.path.join(ckpt_path, f"Pythia-160m-L{layer}{hook_point_in}-{exp_factor}x-lr0.001")
sae_config = SAEConfig.from_pretrained(ckpt_path)
sae_config.sae_pretrained_name_or_path = ckpt_path

sae = SparseAutoEncoder.from_config(cfg=sae_config)

In [21]:
cfg.update_from_model_config(model.cfg)

# initialize sah
sah = LowRankSparseAttention(cfg).to(cfg.device)

sah_path = '/remote-home1/jxwang/project/lorsa/result/pythia-160m-sah/L5A-d16&1-n96&1536-k64.pth'

state_dict = torch.load(sah_path, map_location=cfg.device)

sah.load_state_dict(state_dict)

  state_dict = torch.load(sah_path, map_location=cfg.device)


<All keys matched successfully>

In [12]:
hook_in_name = f'blocks.{cfg.layer}.ln1.hook_normalized'
hook_out_name = f'blocks.{cfg.layer}.hook_attn_out'

names_filter = [hook_in_name, hook_out_name]

# Run the forward pass
batch = next(data_iter)
tokens = model.to_tokens(batch, prepend_bos=cfg.prepend_bos).to(cfg.device)
tokens = tokens[:, :cfg.max_length]

filter_mask = torch.logical_and(tokens.ne(model.tokenizer.eos_token_id), tokens.ne(model.tokenizer.pad_token_id))
filter_mask = torch.logical_and(filter_mask, tokens.ne(model.tokenizer.bos_token_id))

_, cache = model.run_with_cache(tokens, names_filter=names_filter)
hook_in = cache[hook_in_name]
hook_out = cache[hook_out_name]

In [13]:
_, top_k_indices = sah.forward_top_k(hook_in)
feature = sae.encode(hook_out.to(torch.bfloat16))

top_k_indices = top_k_indices[filter_mask]
feature = feature[filter_mask]

In [1]:
ov_vector = sah.W_O.squeeze(1)
index = 0
feature_vector = ov_vector[index].unsqueeze(0)

sae_vector = sae.decoder.weight.permute(1, 0)

cos = nn.CosineSimilarity(dim=1)
res = cos(feature_vector, sae_vector)
print(res.shape)

top_cos, top_cos_indices  = torch.topk(res, 50)
print(top_cos)
print(top_cos_indices)
# print(top_cos_indices % cfg.n_qk_heads)

NameError: name 'sah' is not defined

In [30]:
ov_vector = ov_vector.to(torch.bfloat16)

ov_vector_norm = ov_vector.norm(dim=-1, keepdim=True)
sae_vector_norm = sae_vector.norm(dim=-1, keepdim=True)

res = torch.mm(ov_vector, sae_vector.T) / (ov_vector_norm * sae_vector_norm.T)

In [36]:
print(res.shape)
max_cos = torch.max(res, dim=1).values
sorted_indices = torch.argsort(max_cos, descending=True)
print(max_cos[sorted_indices[1536 // 2]])

torch.Size([1536, 24576])
tensor(0.4043, device='cuda:0', dtype=torch.bfloat16)


In [74]:
ov_vector = sah.W_O.squeeze(1)
print(ov_vector.shape)
index = 842
feature_vector = ov_vector[index:index + 1]
print(feature_vector.shape)

cos = nn.CosineSimilarity(dim=1)
# res = cos(ov_vector, feature_vector)
res = cos(ov_vector, feature_vector)
print(res.shape)

top_cos, top_cos_indices  = torch.topk(res, 50)
print(top_cos)
print(top_cos_indices)
print(top_cos_indices % cfg.n_qk_heads)

torch.Size([1536, 768])
torch.Size([1, 768])
torch.Size([1536])
tensor([1.0000, 0.7865, 0.7597, 0.7306, 0.7189, 0.7085, 0.6958, 0.6836, 0.6653,
        0.6457, 0.6089, 0.6024, 0.5841, 0.5768, 0.5688, 0.5308, 0.5238, 0.5158,
        0.4976, 0.4940, 0.4918, 0.4895, 0.4830, 0.4703, 0.4667, 0.4646, 0.4535,
        0.4494, 0.4470, 0.4445, 0.4392, 0.4254, 0.4180, 0.4109, 0.4092, 0.4028,
        0.3957, 0.3955, 0.3887, 0.3764, 0.3746, 0.3688, 0.3681, 0.3618, 0.3583,
        0.3481, 0.3452, 0.3450, 0.3417, 0.3401], device='cuda:0')
tensor([ 842,  510,    1, 1125,  143, 1494,  172,  969,  337, 1252,  936,  825,
        1270,  903, 1129,  942,  343,  144,  382, 1290,  900,  822,  121, 1478,
        1001,  811,  586,  701,  385,  876,  637,  850,  461,  871,  313, 1479,
         432,   93,  800,  722,  142,  580,   15,  570,  138,  203,  310,  369,
        1250,  755], device='cuda:0')
tensor([74, 30,  1, 69, 47, 54, 76,  9, 49,  4, 72, 57, 22, 39, 73, 78, 55, 48,
        94, 42, 36, 54, 25, 38, 