In [5]:
from functools import partial

import torch
from torch import nn
from jaxtyping import Float
from sae_lens import SAE, HookedSAETransformer
from torch import Tensor
from transformer_lens import loading_from_pretrained
from transformer_lens.hook_points import HookPoint
from transformers import AutoTokenizer
import torch.utils.checkpoint as checkpoint

In [6]:
example_input = torch.randn(10)
print(example_input)
print(example_input.shape)

tensor([-0.4511,  0.0766,  0.2233,  1.1803, -0.8610, -0.1251,  0.9687,  0.0301,
         0.7975,  0.1416])
torch.Size([10])


In [7]:
class BiasOnly(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(features))
        
    def forward(self, x):
        return x + self.bias

# Example usage:
bias_layer = BiasOnly(features=example_input.shape[0])
output = bias_layer(example_input)  # Will add a learnable bias to each feature
print(output)

tensor([-0.4511,  0.0766,  0.2233,  1.1803, -0.8610, -0.1251,  0.9687,  0.0301,
         0.7975,  0.1416], grad_fn=<AddBackward0>)


In [8]:
class ResidualBlock(torch.nn.Module):
    def __init__(self, input_dim:int, hidden_layers:int, hidden_dim:int|None=None, activation=torch.nn.ReLU()):
        """A flexible residual neural network block that maintains input/output dimension compatibility.
    
        This block implements a residual connection of the form output = F(x) + x, where F is a configurable
        neural network. The architecture supports various depths and can degenerate to a simple bias-only layer.
        
        Args:
            input_dim (int): Dimension of input features. Must be positive. Output will have same dimension.
            hidden_layers (int): Number of hidden layers in the network.
                * -1: Creates a bias-only layer
                * 0: Single linear transformation
                * >0: Creates that many hidden layers with activation functions between them
            hidden_dim (int, optional): Dimension of hidden layers. If None, uses input_dim.
            activation (torch.nn.Module): Activation function to use between layers. Defaults to ReLU.
        
        Example:
            >>> block = ResidualBlock(input_dim=512, hidden_layers=2, hidden_dim=1024)
            >>> x = torch.randn(32, 512)  # batch_size=32, features=512
            >>> output = block(x)  # Shape: (32, 512)
        """
        super().__init__()
        assert input_dim > 0
        assert hidden_layers >= -1
        assert hidden_dim is None or hidden_dim > 0
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim if hidden_dim else input_dim
        self.activation = activation
        sequential = []
        if hidden_layers == -1:
            sequential.append(BiasOnly(input_dim))
        else:
            input_dims = [self.input_dim] + [self.hidden_dim] * hidden_layers
            output_dims = [self.hidden_dim] * hidden_layers + [self.input_dim]
            for i, (in_dim, out_dim) in enumerate(zip(input_dims, output_dims)): # plus one because zero hidden layers is just a forward from input to output
                linear = torch.nn.Linear(in_dim, out_dim)
                if i == len(input_dims) - 1:  # Final layer
                    # Initialize final layer to zero for identity function behavior
                    torch.nn.init.zeros_(linear.weight)
                    torch.nn.init.zeros_(linear.bias)
                else:
                    # Xavier initialization for hidden layers
                    torch.nn.init.xavier_uniform_(linear.weight)
                    torch.nn.init.zeros_(linear.bias)
                sequential.append(linear)
                if i < hidden_layers - 1:
                    sequential.append(activation)
        self.sequential = torch.nn.Sequential(*sequential)
    
    def forward(self, x):
        return self.sequential(x) + x

In [None]:
def resid_hook(sae_acts:Tensor, hook:HookPoint, residual_block:ResidualBlock) -> Tensor:
    """Runs the input through a trainable resnet (ResidualBlock).

    Args:
        sae_acts (Tensor): The SAE activations tensor, shape [batch, pos, features]
        hook (HookPoint): The transformer-lens hook point

    Returns:
        Tensor: The modified SAE activations modified by the trainable parameters.
    """

    return residual_block(sae_acts)

In [10]:
device = 'cuda:0'

In [12]:
model_name = 'google/gemma-2-2b'
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [9]:
model = HookedSAETransformer.from_pretrained('google/gemma-2-2b', device=device)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model google/gemma-2-2b into HookedTransformer


In [None]:
# Example feature we hope to edit (to be less) https://www.neuronpedia.org/gemma-2-2b/25-gemmascope-res-16k/8496

In [11]:
from sae_lens import SAE

sae25, cfg_dict, sparsity = SAE.from_pretrained(release="gemma-scope-2b-pt-res-canonical", sae_id="layer_25/width_16k/canonical", device=device)

params.npz:   0%|          | 0.00/302M [00:00<?, ?B/s]

In [13]:
text = "Hello, world!"
input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)


In [15]:
output = model.forward(input_ids)

In [22]:
# loss = output.abs().sum()
# print(loss)
# loss.backward()
# next(model.parameters()).grad

tensor(8825530., device='cuda:0', grad_fn=<SumBackward0>)


In [27]:
for param in model.parameters():
    param.requires_grad = False


In [30]:
# output = model.forward(input_ids)
# loss = output.abs().sum()
# print(loss)
# loss.backward()
# next(model.parameters()).grad

In [32]:
model.add_sae(sae25)

In [None]:
sae25.