In [1]:
import torch
from sae_lens import SAE, ActivationsStore
from transformer_lens import HookedTransformer

Python(93804) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(93805) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


In [2]:
def load_model_and_sae(
    model_name: str, sae_release: str, sae_id: str, device: str
) -> tuple:
    model = HookedTransformer.from_pretrained(model_name, device=device)
    sae, _, _ = SAE.from_pretrained(release=sae_release, sae_id=sae_id, device=device)
    sae.W_dec.norm(dim=-1).mean()
    sae.fold_W_dec_norm()
    return model, sae

In [3]:
model_gemma, sae_gemma = load_model_and_sae(
    "gemma-2-2b",
    "gemma-scope-2b-pt-res-canonical",
    "layer_18/width_16k/canonical",
    "mps",
)

Python(93806) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


In [4]:
sae_gemma.cfg.normalize_activations

In [26]:
model_gpt2, sae_gpt2 = load_model_and_sae(
    "gpt2-small",
    "gpt2-small-res-jb-feature-splitting",
    "blocks.8.hook_resid_pre_768",
    "mps",
)

Loaded pretrained model gpt2-small into HookedTransformer


This SAE has non-empty model_from_pretrained_kwargs. 
For optimal performance, load the model like so:
model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)


In [27]:
activation_store_gpt2 = ActivationsStore.from_sae(
    model=model_gpt2,
    sae=sae_gpt2,
    streaming=True,
    store_batch_size_prompts=8,
    train_batch_size_tokens=512,
    n_batches_in_buffer=16,
    device="mps",
)



In [6]:
activation_store_gemma = ActivationsStore.from_sae(
    model=model_gemma,
    sae=sae_gemma,
    streaming=True,
    store_batch_size_prompts=8,
    train_batch_size_tokens=512,
    n_batches_in_buffer=16,
    device="mps",
)

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

In [7]:
# def get_feature_activations_for_batch(
#     activation_store: ActivationsStore, sae: SAE
# ) -> torch.Tensor:
#     activations_batch = activation_store.next_batch()
#     feature_acts = sae.encode(activations_batch).squeeze()
#     return feature_acts

In [None]:
def get_feature_activations_for_batch(
    activation_store: ActivationsStore,
    sae: SAE,
    remove_first_token: bool = False,
) -> torch.Tensor:
    """
    Get feature activations for a batch of tokens from an ActivationsStore.

    This function retrieves a batch of activations from the ActivationsStore,
    optionally removes the first token (typically the beginning-of-sequence token),
    and encodes the activations using the provided SAE (Sparse Autoencoder).

    Args:
        activation_store (ActivationsStore): The ActivationsStore object to get activations from.
        sae (SAE): The Sparse Autoencoder used to encode the activations.
        remove_first_token (bool, optional): Whether to remove the first token from each sequence.
                                             Defaults to False.

    Returns:
        torch.Tensor: Encoded feature activations, shape (batch_size * context_size, d_sae).

    Note:
        - If remove_first_token is True, the function uses get_flattened_activations_wout_first
          to retrieve activations without the first token.
        - The returned tensor is squeezed to remove any singleton dimensions.
    """
    if not remove_first_token:
        activations_batch = activation_store.next_batch()
    else:
        activations_batch = get_flattened_activations_wout_first(activation_store)
    feature_acts = sae.encode(activations_batch).squeeze()
    return feature_acts


def get_flattened_activations_wout_first(
    activation_store: ActivationsStore,
) -> torch.Tensor:
    """
    NOTE: this will be a different size from normal activation store batch as we do not get additional tokens to replace the BOS
    Get flattened activations without the first token (BOS) from an ActivationsStore.

    This function retrieves a batch of tokens, gets their activations, removes the first token
    (typically the beginning-of-sequence token), and flattens the resulting activations.

    Args:
        activation_store (ActivationsStore): The ActivationsStore object to get activations from.

    Returns:
        torch.Tensor: Flattened activations without the first token, shape (batch_size * (context_size - 1), d_in).

    Note:
        - The function assumes that the first token in each sequence is a BOS token.
        - The returned tensor is reshaped to match the format of activation_store.next_batch().
    """
    batch_size = activation_store.train_batch_size_tokens
    batch_tokens = activation_store.get_batch_tokens(batch_size)
    activations = activation_store.get_activations(batch_tokens)
    activations_wout_bos = activations[:, 1:, ...]
    flattened_activations = activations_wout_bos.view(-1, activation_store.d_in)
    return flattened_activations

In [8]:
# def get_feature_activations_for_batch_V2(
#     activation_store: ActivationsStore, sae: SAE, remove_bos: bool = False,
# ) -> torch.Tensor:
# if remove_bos:

# else:
#     activations_batch = activation_store.next_batch()
#     feature_acts = sae.encode(activations_batch).squeeze()
#     return feature_acts

In [9]:
feature_acts_gpt2 = get_feature_activations_for_batch(activation_store_gpt2, sae_gpt2)

Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors


In [10]:
feature_acts_gemma = get_feature_activations_for_batch(
    activation_store_gemma, sae_gemma
)

In [11]:
activation_store_next_tokens_gpt2 = activation_store_gpt2.get_batch_tokens(
    512
)  # TODO what does the batch size here correspond to compared to activation_store.next_batch()?

In [12]:
sae_gpt2.cfg

SAEConfig(architecture='standard', d_in=768, d_sae=768, activation_fn_str='relu', apply_b_dec_to_input=True, finetuning_scaling_factor=False, context_size=128, model_name='gpt2-small', hook_name='blocks.8.hook_resid_pre', hook_layer=8, hook_head_index=None, prepend_bos=True, dataset_path='Skylion007/openwebtext', dataset_trust_remote_code=True, normalize_activations='none', dtype='torch.float32', device='mps', sae_lens_training_version=None, activation_fn_kwargs={}, neuronpedia_id='gpt2-small/8-res_fs768-jb', model_from_pretrained_kwargs={'center_writing_weights': True})

In [13]:
activation_store_gpt2.next_batch().shape

torch.Size([512, 1, 768])

In [14]:
sae_gpt2.encode(activation_store_gpt2.next_batch()).squeeze().shape

torch.Size([512, 768])

In [15]:
sae_gpt2.encode(activation_store_gpt2.next_batch()).squeeze().shape

torch.Size([512, 768])

In [16]:
activation_store_next_tokens_gpt2.shape

torch.Size([512, 128])

In [17]:
activation_store_next_activations_gpt2 = activation_store_gpt2.get_activations(
    activation_store_next_tokens_gpt2
)

In [18]:
activation_store_next_activations_gpt2.shape

torch.Size([512, 128, 1, 768])

In [22]:
print(activation_store_next_activations_gpt2)
print(activation_store_gpt2.next_batch())

tensor([[[[-5.0307, -5.0059, -4.4554,  ..., -3.5955, -3.8039, -4.3690]],

         [[ 1.7617,  0.8806,  0.0538,  ..., -1.5403,  2.9543,  0.5674]],

         [[-0.3194, -1.2106, -3.8291,  ..., -1.4626,  3.1476,  2.6306]],

         ...,

         [[ 0.7187,  0.1953, -1.9804,  ..., -0.7839,  0.2512,  3.1737]],

         [[ 4.4271, -1.2376,  2.8379,  ..., -5.1405,  1.5336, -0.6301]],

         [[ 4.4993,  3.0335, -3.4330,  ..., -2.7207, -2.7646, -0.8193]]],


        [[[-5.0307, -5.0059, -4.4554,  ..., -3.5955, -3.8039, -4.3690]],

         [[ 3.3672,  2.4428,  2.3995,  ...,  2.0427,  1.3496, -1.3920]],

         [[ 1.9654, -2.2770, -1.2881,  ...,  1.6756,  1.4694,  0.8900]],

         ...,

         [[-0.7806,  0.9456, -1.4776,  ..., -3.8351, -1.0554,  3.2915]],

         [[-5.2381, -4.4683,  1.4433,  ..., -1.8673,  3.7738, -1.2534]],

         [[ 3.8893, -1.1445, -0.1531,  ...,  0.0826, -2.6133, -1.2079]]],


        [[[-5.0307, -5.0059, -4.4554,  ..., -3.5955, -3.8039, -4.3690]],

    

In [28]:
activation_store_gpt2.next_batch()[1, :]

Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors


tensor([[ 1.4198e+00, -3.6200e+00, -4.8885e-01,  1.4290e+00, -6.1072e-01,
          2.1432e+00, -3.0247e+00, -5.1112e+00, -1.4319e+00,  1.5245e+00,
          2.6599e-01, -4.3088e+00,  1.0644e+00,  4.3840e+00, -4.4498e+00,
          8.2160e-01,  3.2493e+00,  3.5883e+00,  2.5692e+00,  7.2998e-01,
          4.8065e+00, -9.3076e-01, -5.0682e+00, -1.2309e+00,  1.0484e+00,
          3.9354e+00,  3.2414e-01,  1.0189e+00, -1.0666e+00,  8.7230e-01,
         -1.9818e+00, -5.8232e-01, -2.0265e+00, -2.7809e+00, -4.4952e-01,
          3.8016e+00,  4.4219e+00, -2.7464e+00, -1.6595e+00,  1.7461e+00,
          5.1277e+00, -1.8172e+00,  6.1772e-01,  1.6547e-01, -5.3222e-02,
          3.0436e+00,  1.2391e+00,  3.2485e+00, -3.3789e+00,  1.1650e+00,
         -1.9357e+00,  6.3813e-01, -7.4953e+00,  3.3693e+00, -1.5321e+00,
          2.9719e+00,  2.7411e+00, -2.8643e+00, -2.8488e+00, -4.1920e+00,
          2.8716e+00, -4.1937e-01, -4.4081e+00, -1.6992e-01, -2.7602e+01,
          3.2994e+00, -1.8970e+00,  9.

In [19]:
sae_gpt2.encode(
    activation_store_next_activations_gpt2.to(sae_gpt2.device)
).squeeze().shape

torch.Size([512, 128, 768])

In [20]:
sae_gpt2.encode(activation_store_gpt2.next_batch()).squeeze().shape

torch.Size([512, 768])

I assume 128 is the context size, but why is that not in activation_store_next_batch

How do we get this to conform to old shape?

In [29]:
acts_2 = model_gpt2.run_with_cache(activation_store_next_tokens_gpt2, stop_at_layer=1)[
    1
](sae_gemma.cfg.hook_name)

TypeError: 'ActivationCache' object is not callable

In [30]:
acts_2 = model_gpt2.run_with_cache(activation_store_next_tokens_gpt2, stop_at_layer=1)[
    1
][sae_gpt2.cfg.hook_name]

KeyError: 'blocks.8.hook_resid_pre'

`n_batches, n_context = batch_tokens.shape`

In [31]:
activation_store_next_tokens_gpt2

tensor([[50256,   286,  8761,  ...,  7691,   284,  1592],
        [50256,  2872,   711,  ..., 44088,   532,  2619],
        [50256, 19650,  2310,  ..., 44654,  1028, 13094],
        ...,
        [50256,   683,    13,  ...,    82,  4497,   329],
        [50256,   428,  2168,  ...,   351,  1399,  7091],
        [50256,   447,   247,  ...,    83,  3221,   423]], device='mps:0')

In [32]:
activation_store_next_activations_gpt2.shape

torch.Size([512, 128, 1, 768])

In [40]:
activation_store_next_tokens_gpt2.shape

torch.Size([512, 128])

In [34]:
example_gpt2_batch = activation_store_gpt2.next_batch()

In [41]:
example_gpt2_batch.shape

torch.Size([512, 1, 768])

In [None]:
# def get_feature_activations_for_batch(
#     activation_store: ActivationsStore, sae: SAE, remove_first = False,
# ) -> torch.Tensor:

#     if not remove_first:
#         activations_batch = activation_store.next_batch()
#         feature_acts = sae.encode(activations_batch).squeeze()
#         return feature_acts
#     else:
#         activations_batch = get_flattened_activations(activation_store, batch_size = )

# rough notes

In [None]:
def get_activations_like_next_batch(self, batch_size=None):
    if batch_size is None:
        batch_size = self.train_batch_size_tokens

    # Get batch tokens
    batch_tokens = self.get_batch_tokens()

    # Get activations
    activations = self.get_activations(batch_tokens)

    # Reshape activations to match next_batch() output
    # next_batch() returns shape (batch_size, d_in)
    activations = activations.view(-1, self.d_in)

    # Apply normalization if needed
    if self.normalize_activations == "expected_average_only_in":
        activations = self.apply_norm_scaling_factor(activations)

    # Ensure we return the correct batch size
    if activations.shape[0] > batch_size:
        activations = activations[:batch_size]
    elif activations.shape[0] < batch_size:
        # If we don't have enough activations, we'll need to get more
        additional_activations = self.get_activations_like_next_batch(
            batch_size - activations.shape[0]
        )
        activations = torch.cat([activations, additional_activations], dim=0)

    return activations

In [None]:
def get_flattened_activations(self, batch_size=None):
    if batch_size is None:
        batch_size = self.train_batch_size_tokens

    # Get batch tokens
    batch_tokens = self.get_batch_tokens()

    # Get activations
    activations = self.get_activations(batch_tokens)

    # Reshape activations to flatten out the sequence structure
    # Original shape: (batch_size, context_size, num_layers, d_in)
    # New shape: (batch_size * context_size, d_in)
    flattened_activations = activations.view(-1, self.d_in)

    # Apply normalization if needed
    if self.normalize_activations == "expected_average_only_in":
        flattened_activations = self.apply_norm_scaling_factor(flattened_activations)

    # Ensure we return the correct batch size
    if flattened_activations.shape[0] > batch_size:
        flattened_activations = flattened_activations[:batch_size]
    elif flattened_activations.shape[0] < batch_size:
        # If we don't have enough activations, we'll need to get more
        additional_activations = self.get_flattened_activations(
            batch_size - flattened_activations.shape[0]
        )
        flattened_activations = torch.cat(
            [flattened_activations, additional_activations], dim=0
        )

    return flattened_activations

In [None]:
def get_flattened_activations_without_bos(self, batch_size=None):
    if batch_size is None:
        batch_size = self.train_batch_size_tokens

    # Get batch tokens
    batch_tokens = self.get_batch_tokens()

    # Get activations
    activations = self.get_activations(batch_tokens)

    # Remove the first token (BOS) from each sequence
    # Original shape: (batch_size, context_size, num_layers, d_in)
    # New shape: (batch_size, context_size - 1, num_layers, d_in)
    activations_without_bos = activations[:, 1:, ...]

    # Reshape activations to flatten out the sequence structure
    # New shape: (batch_size * (context_size - 1), d_in)
    flattened_activations = activations_without_bos.view(-1, self.d_in)

    # Apply normalization if needed
    if self.normalize_activations == "expected_average_only_in":
        flattened_activations = self.apply_norm_scaling_factor(flattened_activations)

    # Ensure we return the correct batch size
    if flattened_activations.shape[0] > batch_size:
        flattened_activations = flattened_activations[:batch_size]
    elif flattened_activations.shape[0] < batch_size:
        # If we don't have enough activations, we'll need to get more
        additional_activations = self.get_flattened_activations_without_bos(
            batch_size - flattened_activations.shape[0]
        )
        flattened_activations = torch.cat(
            [flattened_activations, additional_activations], dim=0
        )

    return flattened_activations