In [1]:
import torch
from sae_lens import SAE, ActivationsStore
from transformer_lens import HookedTransformer

In [2]:
def load_model_and_sae(
    model_name: str, sae_release: str, sae_id: str, device: str
) -> "tuple[HookedTransformer, SAE]":
    """Load a TransformerLens model and a pretrained SAE Lens SAE onto ``device``.

    Calls ``sae.fold_W_dec_norm()`` on the loaded SAE before returning it.

    Args:
        model_name: name passed to ``HookedTransformer.from_pretrained``.
        sae_release: SAE Lens release name passed to ``SAE.from_pretrained``.
        sae_id: SAE identifier within ``sae_release``.
        device: torch device string used for both the model and the SAE
            (e.g. ``"mps"``).

    Returns:
        ``(model, sae)``.

    NOTE(review): SAE Lens warns when an SAE has non-empty
    ``cfg.model_from_pretrained_kwargs`` (it does for the GPT-2
    feature-splitting SAE). The model here is loaded with
    ``from_pretrained`` defaults, so confirm those match the kwargs the SAE
    was trained with.
    """
    model = HookedTransformer.from_pretrained(model_name, device=device)
    # Only the SAE itself is used; the other two return values are discarded.
    sae, _, _ = SAE.from_pretrained(release=sae_release, sae_id=sae_id, device=device)
    sae.fold_W_dec_norm()
    return model, sae

In [3]:
# Gemma-2-2B plus the layer-18, 16k-wide SAE from the gemma-scope-2b-pt-res-canonical release.
model_gemma, sae_gemma = load_model_and_sae(
    model_name="gemma-2-2b",
    sae_release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_18/width_16k/canonical",
    device="mps",
)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [4]:
# GPT-2 small plus the 768-latent feature-splitting SAE on blocks.8.hook_resid_pre.
model_gpt2, sae_gpt2 = load_model_and_sae(
    model_name="gpt2-small",
    sae_release="gpt2-small-res-jb-feature-splitting",
    sae_id="blocks.8.hook_resid_pre_768",
    device="mps",
)

Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [5]:
# Streaming activation store for the GPT-2 SAE; next_batch() yields
# train_batch_size_tokens (512) token activations at a time.
activation_store_gpt2 = ActivationsStore.from_sae(
    model=model_gpt2,
    sae=sae_gpt2,
    device="mps",
    streaming=True,
    train_batch_size_tokens=512,
    store_batch_size_prompts=8,
    n_batches_in_buffer=16,
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [6]:
# Streaming activation store for the Gemma SAE, configured like the GPT-2 store.
activation_store_gemma = ActivationsStore.from_sae(
    model=model_gemma,
    sae=sae_gemma,
    device="mps",
    streaming=True,
    train_batch_size_tokens=512,
    store_batch_size_prompts=8,
    n_batches_in_buffer=16,
)

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

In [7]:
def get_feature_activations_for_batch(
    activation_store: "ActivationsStore", sae: "SAE"
) -> torch.Tensor:
    """Encode the store's next activation batch with ``sae``.

    Calls ``activation_store.next_batch()`` once per invocation.

    Args:
        activation_store: store providing the batch via ``next_batch()``.
        sae: sparse autoencoder whose ``encode`` is applied to the batch.

    Returns:
        Feature activations with the singleton axis before the feature axis
        removed. For the GPT-2 store in this notebook, ``next_batch()`` has
        shape ``(512, 1, 768)``, so the result is ``(512, d_sae)``. The batch
        axis is kept even when it has size 1.
    """
    activations_batch = activation_store.next_batch()
    # Pure inference: don't build an autograd graph through the SAE weights.
    with torch.no_grad():
        feature_acts = sae.encode(activations_batch)
    # Drop only the singleton axis (the `1` in (512, 1, 768)). A bare
    # .squeeze() would also drop the batch axis when the batch has size 1.
    if feature_acts.ndim >= 3:
        feature_acts = feature_acts.squeeze(-2)
    return feature_acts

In [None]:
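# TODO: draft of a variant that can drop BOS-token activations before encoding;
# the `remove_bos=True` branch is not written yet.
# NOTE(review): for the GPT-2 store, `next_batch()` returns activations shaped
# (512, 1, 768) with no sequence-position axis, so BOS positions may not be
# identifiable from it; `get_batch_tokens` + `get_activations` keep the
# (batch, context) layout.
# def get_feature_activations_for_batch_V2(
#     activation_store: ActivationsStore, sae: SAE, remove_bos: bool = False,
# ) -> torch.Tensor:
#     if remove_bos:
#         ...  # TODO: not implemented
#     else:
#         activations_batch = activation_store.next_batch()
#         feature_acts = sae.encode(activations_batch).squeeze()
#         return feature_acts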

In [8]:
feature_acts_gpt2 = get_feature_activations_for_batch(
    activation_store=activation_store_gpt2,
    sae=sae_gpt2,
)

Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors


In [9]:
feature_acts_gemma = get_feature_activations_for_batch(
    activation_store=activation_store_gemma, sae=sae_gemma
)

In [21]:
# get_batch_tokens(512) appears to return 512 whole sequences of context_size
# tokens: shape (512, 128), and sae_gpt2.cfg has context_size=128. By contrast,
# next_batch() returns train_batch_size_tokens=512 single-token activations,
# shape (512, 1, 768).
activation_store_next_tokens_gpt2 = activation_store_gpt2.get_batch_tokens(512)

In [38]:
# Inspect the GPT-2 SAE config (e.g. d_in, d_sae, context_size, hook_name).
sae_gpt2.cfg

SAEConfig(architecture='standard', d_in=768, d_sae=768, activation_fn_str='relu', apply_b_dec_to_input=True, finetuning_scaling_factor=False, context_size=128, model_name='gpt2-small', hook_name='blocks.8.hook_resid_pre', hook_layer=8, hook_head_index=None, prepend_bos=True, dataset_path='Skylion007/openwebtext', dataset_trust_remote_code=True, normalize_activations='none', dtype='torch.float32', device='mps', sae_lens_training_version=None, activation_fn_kwargs={}, neuronpedia_id='gpt2-small/8-res_fs768-jb', model_from_pretrained_kwargs={'center_writing_weights': True})

In [23]:
# next_batch(): train_batch_size_tokens (512) token activations, a singleton
# middle axis, and d_in=768 features.
activation_store_gpt2.next_batch().size()

torch.Size([512, 1, 768])

In [43]:
# SAE features for one next_batch(): (512, 1, 768) -> (512, d_sae) after squeezing.
sae_gpt2.encode(
    activation_store_gpt2.next_batch()
).squeeze().size()

torch.Size([512, 768])

In [31]:
# Repeat of the encode-shape check on another next_batch() call.
batch_features_shape = None or sae_gpt2.encode(activation_store_gpt2.next_batch()).squeeze().size()
batch_features_shape

torch.Size([512, 768])

In [24]:
# (number of sequences, context_size) = (512, 128) token batch.
activation_store_next_tokens_gpt2.size()

torch.Size([512, 128])

In [25]:
# Activations for every position of the (512, 128) token batch; unlike
# next_batch(), the sequence axis is kept.
activation_store_next_activations_gpt2 = (
    activation_store_gpt2.get_activations(activation_store_next_tokens_gpt2)
)

In [26]:
# (sequences, context, singleton, d_in): includes the context axis that next_batch() lacks.
activation_store_next_activations_gpt2.size()

torch.Size([512, 128, 1, 768])

In [34]:
# Encoding per-position activations keeps the (sequences, context) axes:
# (512, 128, 1, 768) -> (512, 128, d_sae) once the singleton axis is squeezed.
sae_gpt2.encode(activation_store_next_activations_gpt2.to(sae_gpt2.device)).squeeze().size()

torch.Size([512, 128, 768])

In [None]:
# Per-token features from next_batch(), for comparison with the per-position shape above.
sae_gpt2.encode(
    activation_store_gpt2.next_batch(),
).squeeze().size()

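I assume 128 is the context size (`sae_gpt2.cfg` has `context_size=128`), but why does that dimension not appear in the output of `activation_store_gpt2.next_batch()`?

How do we reshape this to match the old per-token shape, `(n_tokens, d_sae)`? (Possibly `.flatten(0, 1)`, which would turn `(512, 128, 768)` into `(65536, 768)`, i.e. one row per token.)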