In [28]:
%load_ext autoreload
%autoreload 2

import torch
from sae_lens import SAE, ActivationsStore
from transformer_lens import HookedTransformer

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def load_model_and_sae(
    model_name: str, sae_release: str, sae_id: str, device: str
) -> "tuple[HookedTransformer, SAE]":
    """Load a HookedTransformer and a pretrained SAE onto the same device.

    Args:
        model_name: Model name passed to ``HookedTransformer.from_pretrained``.
        sae_release: Release name passed to ``SAE.from_pretrained``.
        sae_id: Identifier of the SAE within that release.
        device: Device string used for both the model and the SAE.

    Returns:
        A ``(model, sae)`` pair. ``fold_W_dec_norm()`` has already been
        called on the returned SAE.
    """
    model = HookedTransformer.from_pretrained(model_name, device=device)

    # from_pretrained is unpacked as a 3-tuple; only the SAE itself is used.
    sae, _, _ = SAE.from_pretrained(release=sae_release, sae_id=sae_id, device=device)

    # NOTE(review): for the GPT-2 SAE this notebook's output shows a hint to
    # load the model via from_pretrained_no_processing with
    # cfg.model_from_pretrained_kwargs. That is not applied here — confirm
    # whether it matters for the analysis.
    sae.fold_W_dec_norm()
    return model, sae

In [3]:
# Gemma-2-2B together with the Gemma Scope residual-stream SAE named below.
model_gemma, sae_gemma = load_model_and_sae(
    model_name="gemma-2-2b",
    sae_release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_18/width_16k/canonical",
    device="mps",
)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [4]:
# GPT-2 small together with the feature-splitting SAE on blocks.8.hook_resid_pre.
model_gpt2, sae_gpt2 = load_model_and_sae(
    model_name="gpt2-small",
    sae_release="gpt2-small-res-jb-feature-splitting",
    sae_id="blocks.8.hook_resid_pre_768",
    device="mps",
)

Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [5]:
# Streaming activation source for the GPT-2 model / SAE pair.
activation_store_gpt2 = ActivationsStore.from_sae(
    # what to run and where
    sae=sae_gpt2,
    model=model_gpt2,
    device="mps",
    # data streaming and buffering
    streaming=True,
    n_batches_in_buffer=16,
    # batch sizes: prompts per forward pass, activations per training batch
    store_batch_size_prompts=8,
    train_batch_size_tokens=512,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [6]:
# activation_store_gemma = ActivationsStore.from_sae(
#     model=model_gemma,
#     sae=sae_gemma,
#     streaming=True,
#     store_batch_size_prompts=8,
#     train_batch_size_tokens=512,
#     n_batches_in_buffer=16,
#     device="mps",
# )

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

In [7]:
activations_batch_old = activation_store_gpt2.next_batch()
activations_batch_old.shape

Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors


torch.Size([512, 1, 768])

In [25]:
sae_gpt2.encode(activations_batch_old).shape

torch.Size([512, 1, 768])

In [9]:
batch_size = activation_store_gpt2.train_batch_size_tokens
batch_size

512

In [14]:
batch_tokens = activation_store_gpt2.get_batch_tokens(batch_size)
# The leading token of every row is kept here (each row in the output below
# starts with id 50256); the first position is dropped later, from the
# activations, rather than from the tokens.
batch_tokens

tensor([[50256,   257,  1256,  ...,   465,    11,   326],
        [50256,   447,   247,  ...,  1406,  5543,   356],
        [50256,  1410,   284,  ...,    72,  4893,   262],
        ...,
        [50256,   262,  8425,  ...,   850,    12, 35448],
        [50256,  4249,  7310,  ..., 16686,   284,   257],
        [50256,  2060, 15815,  ...,   517,  3665,  4899]], device='mps:0')

In [23]:
batch_tokens.shape

torch.Size([512, 128])

In [16]:
activations = activation_store_gpt2.get_activations(batch_tokens)
activations.shape

torch.Size([512, 128, 1, 768])

In [17]:
activations_wout_bos = activations[:, 1:, ...]
activations_wout_bos.shape

torch.Size([512, 127, 1, 768])

In [22]:
# Collapse the (prompt, position) axes into one leading axis while keeping the
# singleton middle axis, giving the same [N, 1, d_in] layout that next_batch()
# returned above.
# reshape (not view) is used because the [:, 1:] slice is presumably
# non-contiguous, so a view that merges the first two axes would fail —
# TODO confirm.
flattened_activations = activations_wout_bos.reshape(-1, 1, activation_store_gpt2.d_in)
flattened_activations.shape

torch.Size([65024, 1, 768])

In [26]:
def get_batch_without_first_token(activations_store):
    """
    Get a training batch of activations that excludes the first position of every prompt.

    Token batches are pulled from the store and run through ``get_activations``
    until at least ``train_batch_size_tokens`` rows are available. A single
    token batch can hold fewer usable rows than the training batch size, in
    which case one fetch alone would give a short batch.

    Args:
    activations_store (ActivationsStore): Store providing ``get_batch_tokens``,
        ``get_activations``, ``normalize_activations``,
        ``apply_norm_scaling_factor`` and ``train_batch_size_tokens``.

    Returns:
    torch.Tensor: A tensor of shape [train_batch_size_tokens, 1, d_in] holding
                  a random sample of the collected activations, with the first
                  position of each prompt removed.

    Raises:
    ValueError: If a fetched token batch yields no activations once the first
                position is dropped (more fetches could never fill the batch).
    """
    train_batch_size = activations_store.train_batch_size_tokens
    needs_norm_scaling = (
        activations_store.normalize_activations == "expected_average_only_in"
    )

    collected = []
    n_collected = 0
    while True:
        batch_tokens = activations_store.get_batch_tokens()

        with torch.no_grad():
            activations = activations_store.get_activations(batch_tokens)

        # Drop position 0 of every prompt, then flatten prompts and positions
        # into one leading axis, keeping a singleton middle axis.
        activations = activations[:, 1:, ...]
        activations = activations.reshape(-1, 1, activations.shape[-1])

        # Same scaling condition as before; applied per fetched chunk.
        if needs_norm_scaling:
            activations = activations_store.apply_norm_scaling_factor(activations)

        collected.append(activations)
        n_collected += activations.shape[0]
        if n_collected >= train_batch_size:
            break

        if activations.shape[0] == 0:
            raise ValueError(
                "Token batch produced no activations after removing the first "
                "token; cannot assemble a training batch."
            )

    pooled = collected[0] if len(collected) == 1 else torch.cat(collected, dim=0)

    # Random sample without replacement: index with only the first
    # train_batch_size entries of a permutation instead of shuffling the whole pool.
    # NOTE(review): the result stays on whatever device get_activations returned.
    # In this notebook it printed without a device tag while next_batch() printed
    # device='mps:0' — move it explicitly if the consumer needs that.
    permutation = torch.randperm(pooled.shape[0])
    return pooled[permutation[:train_batch_size]]

In [29]:
test = get_batch_without_first_token(activation_store_gpt2)

In [30]:
test.shape

torch.Size([512, 1, 768])

In [32]:
test

tensor([[[-1.1806,  3.0300, -2.3312,  ..., -4.5175,  0.5414, -1.1878]],

        [[ 1.1776, -3.1908,  3.4033,  ..., -0.0204, -0.2108,  0.5850]],

        [[ 0.2742, -0.9426, -3.4186,  ..., -1.2453, -1.3458, -0.5370]],

        ...,

        [[ 2.8214,  0.6326,  1.9179,  ..., -1.0603, -4.0802,  0.6450]],

        [[-0.0046, -2.4361, -2.7972,  ...,  0.3780,  3.9732,  2.3055]],

        [[-3.3601, -1.0612,  3.0431,  ...,  1.8759,  1.5579, -0.6897]]])

In [31]:
activations_batch_old.shape

torch.Size([512, 1, 768])

In [33]:
activations_batch_old

tensor([[[ 5.1935, -1.2809, -0.4450,  ...,  0.0861, -3.5058,  2.0328]],

        [[ 2.4250,  2.5613,  1.7368,  ..., -0.4346, -0.3918,  0.6911]],

        [[ 1.5355,  0.1123, -1.2726,  ..., -0.2158, -2.3195,  1.5503]],

        ...,

        [[ 0.9747, -2.4543, -0.0878,  ..., -0.1840, -2.4515,  1.1209]],

        [[ 1.0434,  2.8203, -0.3051,  ...,  5.3944, -2.9895, -1.9296]],

        [[ 3.3147, -0.0511,  0.8030,  ...,  0.8374, -2.1553,  1.0311]]],
       device='mps:0')

In [35]:
sae_gemma.cfg.normalize_activations