In [1]:
# Turn on IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import sys
# Put the Lorsa project's `src` directory on the import path.
# NOTE(review): presumably this is where the `models`, `config` and `analysis`
# packages imported by this notebook live -- confirm.
sys.path.append('/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/zf_projects/Lorsa/src')

import torch
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
from transformer_lens import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
from transformer_lens.components import Attention
from datasets import load_from_disk
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
# Switch autograd off for the whole session (called as a plain function, not
# as a context manager, so it stays off for every later cell).
torch.set_grad_enabled(False)

import copy

from tqdm import tqdm

import numpy as np
import einops

from models.lorsa import LowRankSparseAttention
from config import LorsaTrainConfig, LorsaConfig

In [3]:
# Load the pretrained Lorsa checkpoint onto the GPU.
lorsa_checkpoint_dir = '/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/zf_projects/Lorsa/result/pythia-160m-lorsa-L5A'
lorsa = LowRankSparseAttention.from_pretrained(lorsa_checkpoint_dir, device='cuda')

In [4]:
# Local copy of the Pythia-160m checkpoint; weights and tokenizer share a directory.
pythia_dir = '/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/models/pythia-160m'
hf_model = GPTNeoXForCausalLM.from_pretrained(pythia_dir)
tokenizer = GPTNeoXTokenizerFast.from_pretrained(pythia_dir)

# Wrap the Hugging Face weights in a HookedTransformer on the GPU.
model = HookedTransformer.from_pretrained(
    'EleutherAI/pythia-160m',
    hf_model=hf_model,
    hf_config=hf_model.config,
    tokenizer=tokenizer,
    device='cuda',
)

# Inference only: eval mode, and no parameter tracks gradients.
model.eval()
model.requires_grad_(False)

# Attention module of the layer named by the Lorsa config.
orig_attn = model.blocks[lorsa.cfg.layer].attn

# Dataset previously saved to disk with the `datasets` library.
dataset_path = '/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/data/SlimPajama-3B'
dataset = load_from_disk(dataset_path)

Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer


Loading dataset from disk:   0%|          | 0/57 [00:00<?, ?it/s]

In [5]:
from analysis.top_activating_dfa import sample_max_activating_sequences

# Token budget for the scan; divided by the Lorsa context length below to get
# the number of dataset rows to select.
total_analyzing_tokens = 100_000_000

# Special-token ids handed to the sampler as `ignore_tokens`.
tok = model.tokenizer
ignore_tokens = {tok.bos_token_id, tok.eos_token_id, tok.pad_token_id}

lorsa.fold_W_O_into_W_V()

n_sequences = total_analyzing_tokens // lorsa.cfg.n_ctx
sample_results = sample_max_activating_sequences(
    lorsa=lorsa,
    model=model,
    dataset=dataset.select(range(n_sequences)),
    ignore_tokens=ignore_tokens,
    batch_size=32,
    get_topn_activating_samples=32,
)

print(sample_results)




  0%|          | 0/12207 [00:00<?, ?it/s]

  0%|          | 0/21135 [00:00<?, ?it/s]

{'elt': tensor([[ 9.2140,  9.0521,  8.5536,  ...,  7.2810,  7.2633,  7.2200],
        [10.3189, 10.2590, 10.0027,  ...,  8.8949,  8.8847,  8.8749],
        [17.6253, 17.4367, 17.2300,  ..., 15.5040, 15.4961, 15.4922],
        ...,
        [29.0293, 25.0514, 23.5866,  ..., 17.1216, 17.1016, 17.0721],
        [14.7810, 14.7504, 14.7476,  ..., 11.9625, 11.8930, 11.8667],
        [27.7498, 27.7353, 27.7239,  ..., 25.5348, 25.5112, 25.4889]],
       device='cuda:0'), 'context_idx': tensor([[246132, 187948, 149742,  ..., 357431, 338937, 255002],
        [176031,  44468, 239690,  ..., 181643, 153645,  33107],
        [ 17379, 317247,  31091,  ...,  11164,  28462, 133725],
        ...,
        [208799, 352219,  55843,  ..., 189515, 153426, 248468],
        [243482, 339795, 110198,  ..., 292465,  86638, 199283],
        [118015, 128137,  56246,  ...,  43318, 374343, 210155]],
       device='cuda:0', dtype=torch.int32), 'dfa_of_max_activating_samples': tensor([[[ 6.3942e-03, -1.6435e-03, -1.6765

In [5]:
from lm_saes import SparseAutoEncoder

# Reload previously saved sampling results straight onto the GPU
# (weights_only=True restricts unpickling to tensors and plain containers).
sample_results_path = '/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/zf_projects/Lorsa/result/pythia-160m-lorsa-L5A/sample_results.pt'
sample_results = torch.load(sample_results_path, map_location='cuda', weights_only=True)

# Two pretrained sparse autoencoders.
# NOTE(review): judging by the directory names these are the layer-5
# attention-output ("LXA") and attention-input ("LXAin") SAEs -- confirm.
lxa_sae_dir = '/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/zf_projects/Language-Model-SAEs/Pythia160M-LXA-8x-topk/Pythia160M-L5A-8x-lr0.0001'
lxa_sae = SparseAutoEncoder.from_pretrained(lxa_sae_dir)

lxain_sae_dir = '/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/zf_projects/Language-Model-SAEs/Pythia160M-LXAin-8x-topk/Pythia160M-L5Ain-8x-lr0.0001'
lxain_sae = SparseAutoEncoder.from_pretrained(lxain_sae_dir)

In [6]:
# Show the sample-results dict as this cell's output.
sample_results

{'elt': tensor([[ 9.2140,  9.0521,  8.5536,  ...,  7.2810,  7.2633,  7.2200],
         [10.3189, 10.2590, 10.0027,  ...,  8.8949,  8.8847,  8.8749],
         [17.6253, 17.4367, 17.2300,  ..., 15.5040, 15.4961, 15.4922],
         ...,
         [29.0293, 25.0514, 23.5866,  ..., 17.1216, 17.1016, 17.0721],
         [14.7810, 14.7504, 14.7476,  ..., 11.9625, 11.8930, 11.8667],
         [27.7498, 27.7353, 27.7239,  ..., 25.5348, 25.5112, 25.4889]],
        device='cuda:0'),
 'context_idx': tensor([[246132, 187948, 149742,  ..., 357431, 338937, 255002],
         [176031,  44468, 239690,  ..., 181643, 153645,  33107],
         [ 17379, 317247,  31091,  ...,  11164,  28462, 133725],
         ...,
         [208799, 352219,  55843,  ..., 189515, 153426, 248468],
         [243482, 339795, 110198,  ..., 292465,  86638, 199283],
         [118015, 128137,  56246,  ...,  43318, 374343, 210155]],
        device='cuda:0', dtype=torch.int32),
 'dfa_of_max_activating_samples': tensor([[[ 6.3942e-03, -1.6

In [36]:
from analysis.correlation_w_sae import correlation_analyses_between_saes_and_lorsas
from datasets import load_from_disk

# Reload the on-disk dataset, then run the Lorsa-vs-SAE correlation analysis
# using the sampling results already in memory.
slimpajama_path = '/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/data/SlimPajama-3B'
dataset = load_from_disk(slimpajama_path)

res = correlation_analyses_between_saes_and_lorsas(
    model=model,
    dataset=dataset,
    lorsa=lorsa,
    lorsa_sample_results=sample_results,
    lxa_sae=lxa_sae,
    lxain_sae=lxain_sae,
    top_n_sae_features=8,
)

# Last expression of the cell: display the result dict.
res

Loading dataset from disk:   0%|          | 0/57 [00:00<?, ?it/s]

sample_based_correlation:   0%|          | 0/768 [00:00<?, ?it/s]

qk_weight_based_correlation:   0%|          | 0/48 [00:00<?, ?it/s]

{'wo_lxa_most_correlated_sae_features': tensor([[2936, 2934, 5563,  ...,  310, 1686, 3251],
         [1903, 2422, 6140,  ...,  600, 5826, 2833],
         [ 554, 5585, 5418,  ...,  420, 3399,   33],
         ...,
         [5887, 4890, 3913,  ..., 1006, 2088, 2146],
         [ 771, 4598,   96,  ..., 1309,  173, 2481],
         [2787, 1944, 2814,  ...,  897, 1560, 6081]], device='cuda:0'),
 'wo_lxa_most_correlated_sae_feature_cos_sims': tensor([[0.4155, 0.2623, 0.2465,  ..., 0.2123, 0.2017, 0.1957],
         [0.7328, 0.4375, 0.3428,  ..., 0.3056, 0.2832, 0.2700],
         [0.4864, 0.4671, 0.3913,  ..., 0.3587, 0.3445, 0.3314],
         ...,
         [0.8986, 0.2501, 0.2269,  ..., 0.1983, 0.1947, 0.1923],
         [0.4987, 0.4078, 0.3289,  ..., 0.2287, 0.2194, 0.2076],
         [0.9413, 0.3482, 0.2799,  ..., 0.2118, 0.2007, 0.1936]],
        device='cuda:0'),
 'wo_lxa_most_correlated_feature_dfas': tensor([[ 0.4399,  0.1940,  0.1076,  ...,  0.1086,  0.1895,  0.0260],
         [ 0.4477,  0.

In [38]:
# Persist the correlation results under their own file name.
# Writing `res` to `sample_results.pt` would overwrite the sampling results
# that an earlier cell reloads with `torch.load` and passes to the correlation
# analysis as `lorsa_sample_results`, so a later run would read the wrong dict.
torch.save(res, '/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/zfhe/zf_projects/Lorsa/result/pythia-160m-lorsa-L5A/correlation_results.pt')