## Setup

In [1]:
from datasets import load_dataset  
from transformer_lens import HookedTransformer
from sae_lens import SAE

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import einops
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

In [3]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f'{device}')

cuda:0


In [4]:
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = 'gemma-2b-res-jb',
    sae_id = 'blocks.6.hook_resid_post',
    device = device
)

In [5]:
model = HookedTransformer.from_pretrained('gemma-2b', device=device)
model.to(device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer
Moving model to device:  cuda:0


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-17): 18 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_att

In [6]:
from transformer_lens import (
    utils,
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import to_numpy

In [7]:
activation_store = []

In [8]:
from transformer_lens.utils import tokenize_and_concatenate

dataset = load_dataset(
    path = "NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

token_dataset = tokenize_and_concatenate(
    dataset= dataset,
    tokenizer = model.tokenizer,
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)

In [9]:
token_dataset['tokens'].shape

torch.Size([15196, 1024])

In [10]:
torch.cuda.empty_cache()

In [11]:
import tqdm

In [12]:
batch_size = 16
sae.eval()
activation_buffer = torch.zeros((len(token_dataset['tokens']), sae.cfg.context_size, sae.cfg.d_in), device='cpu')

with torch.no_grad():
    for i in tqdm.trange(0, len(token_dataset['tokens']), batch_size):
        batch_tokens = token_dataset['tokens'][:batch_size].to(device)
        _, cache = model.run_with_cache(batch_tokens, prepend_bos=True, names_filter=['blocks.6.hook_resid_post'])

        # feature_acts = sae.encode(cache['blocks.6.hook_resid_post'])
        # sae_out = sae.decode(feature_acts)

        activation_buffer[i:i+batch_size] = cache['blocks.6.hook_resid_post'].cpu()

        del cache
        torch.cuda.empty_cache()

        # print(f'Feature acts shape: {feature_acts.shape}')
        # print(f'SAE out shape: {sae_out.shape}')

100%|█████████▉| 949/950 [1:24:00<00:05,  5.31s/it]


RuntimeError: The expanded size of the tensor (12) must match the existing size (16) at non-singleton dimension 0.  Target sizes: [12, 1024, 2048].  Tensor sizes: [16, 1024, 2048]

In [None]:
path: str = './activation_buffer_pile_gemma2b.pt'
torch.save(activation_buffer, path)

In [13]:
activation_buffer.shape

torch.Size([15196, 1024, 2048])