In [5]:
from functools import partial

import torch
from torch import nn
from jaxtyping import Float
from sae_lens import SAE, HookedSAETransformer
from torch import Tensor
from transformer_lens import loading_from_pretrained
from transformer_lens.hook_points import HookPoint
from transformers import AutoTokenizer
import torch.utils.checkpoint as checkpoint

In [1]:
# import sys
# sys.path.append('.')  # Add current directory to Python path
from model_components import BiasOnly, ResidualBlock

In [6]:
example_input = torch.randn(10)
print(example_input)
print(example_input.shape)

tensor([-0.4511,  0.0766,  0.2233,  1.1803, -0.8610, -0.1251,  0.9687,  0.0301,
         0.7975,  0.1416])
torch.Size([10])


In [7]:
# Example usage:
bias_layer = BiasOnly(features=example_input.shape[0])
output = bias_layer(example_input)  # Will add a learnable bias to each feature
print(output)

tensor([-0.4511,  0.0766,  0.2233,  1.1803, -0.8610, -0.1251,  0.9687,  0.0301,
         0.7975,  0.1416], grad_fn=<AddBackward0>)


In [33]:
def resid_hook(sae_acts:Tensor, hook:HookPoint, residual_block:ResidualBlock) -> Tensor:
    """Runs the input through a trainable resnet (ResidualBlock).

    Args:
        sae_acts (Tensor): The SAE activations tensor, shape [batch, pos, features]
        hook (HookPoint): The transformer-lens hook point

    Returns:
        Tensor: The modified SAE activations modified by the trainable parameters.
    """

    return residual_block(sae_acts)

In [34]:
def debug_steer(sae_acts: Tensor, hook:HookPoint) -> Tensor:
    import pdb; pdb.set_trace()
    pass
    pass
    return sae_acts

In [10]:
device = 'cuda:0'

In [12]:
model_name = 'google/gemma-2-2b'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [9]:
model = HookedSAETransformer.from_pretrained('google/gemma-2-2b', device=device)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model google/gemma-2-2b into HookedTransformer


In [None]:
# Example feature we hope to edit (to be less) https://www.neuronpedia.org/gemma-2-2b/25-gemmascope-res-16k/8496

In [2]:
from sae_lens import SAE

sae25, cfg_dict, sparsity = SAE.from_pretrained(release="gemma-scope-2b-pt-res-canonical", sae_id="layer_25/width_16k/canonical", device='cpu')

In [5]:
print(cfg_dict)

{'architecture': 'jumprelu', 'd_in': 2304, 'd_sae': 16384, 'dtype': 'float32', 'model_name': 'gemma-2-2b', 'hook_name': 'blocks.25.hook_resid_post', 'hook_layer': 25, 'hook_head_index': None, 'activation_fn_str': 'relu', 'finetuning_scaling_factor': False, 'sae_lens_training_version': None, 'prepend_bos': True, 'dataset_path': 'monology/pile-uncopyrighted', 'context_size': 1024, 'dataset_trust_remote_code': True, 'apply_b_dec_to_input': False, 'normalize_activations': None, 'device': 'cpu', 'neuronpedia_id': 'gemma-2-2b/25-gemmascope-res-16k'}


In [8]:
import torch
isinstance(sae25, torch.nn.Module)

True

In [9]:
print(sae25)

SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)


In [10]:
for name, param in sae25.named_parameters():
    if param.requires_grad:
        print(f"{name}: {param.shape}")


threshold: torch.Size([16384])
b_enc: torch.Size([16384])
W_dec: torch.Size([16384, 2304])
W_enc: torch.Size([2304, 16384])
b_dec: torch.Size([2304])


In [13]:
text = "Hello, world!"
input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)


In [15]:
output = model.forward(input_ids)

In [22]:
# loss = output.abs().sum()
# print(loss)
# loss.backward()
# next(model.parameters()).grad

tensor(8825530., device='cuda:0', grad_fn=<SumBackward0>)


In [39]:
for param in model.parameters():
    param.requires_grad = False


In [30]:
# output = model.forward(input_ids)
# loss = output.abs().sum()
# print(loss)
# loss.backward()
# next(model.parameters()).grad

In [32]:
model.add_sae(sae25)

In [35]:
# sae25.add_hook("hook_sae_acts_post", debug_steer)

In [47]:
sae25.remove_all_hook_fns()

In [38]:
model.forward(input_ids)

tensor([[[-20.8156, -13.4081, -16.8380,  ..., -18.4423, -16.2108, -20.8378],
         [-19.9655,  -8.5803,   4.2300,  ...,  -9.0129,  -9.0134, -19.8179],
         [-17.8220,   1.6014,   4.2865,  ...,  -8.9011,  -4.6994, -17.6331],
         [-11.4105,   5.9845,  -0.3374,  ...,  -2.8015,   2.0687, -11.2620],
         [-11.7846,  10.4476,   0.9414,  ...,  -5.8933,  -2.8285, -11.6628]]],
       device='cuda:0', grad_fn=<MulBackward0>)

In [None]:
# in the case of this sae, the sae_acts shape is torch.Size([batch_size, seq_len, 16384])

In [46]:
residual_block = ResidualBlock(input_dim=16384, hidden_layers=-1).to(device)
trainable_hook = partial(resid_hook, residual_block=residual_block)

In [42]:
any(p.requires_grad for p in model.parameters())

False

In [48]:
sae25.add_hook('hook_sae_acts_post', trainable_hook)

In [44]:
any(p.requires_grad for p in model.parameters())

False

In [52]:
output = model.forward(input_ids)


In [51]:
next(residual_block.parameters()).grad

In [53]:
loss = output.abs().sum()
loss.backward()

In [54]:
next(residual_block.parameters()).grad

tensor([ -16.7542,  469.0292, -253.8390,  ...,  251.0276,   72.6698,
         240.4605], device='cuda:0')