In [1]:
import torch as t
from sae_lens import SAE
from transformer_lens import HookedTransformer
from tqdm import tqdm, trange

# Pick the compute device in priority order: Apple MPS, then CUDA, then CPU.
if t.backends.mps.is_available():
    device = "mps"
elif t.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Clamping SAE features in one layer and measuring effects on the subsequent layer
## Plan described [here](https://spar2024.slack.com/archives/C0794GNT8KS/p1719950740186749?thread_ts=1719934219.491869&cid=C0794GNT8KS)

In [2]:
# Load a small, precomputed set of correlations to play around with.
# NOTE(review): presumably Pearson correlations between layer-0 and layer-1 SAE
# features, judging by the file name - confirm against the script that wrote it.
# NOTE: torch.load unpickles the file, so only load files from a trusted source.
pearson_0_1_small: 'f,f' = t.load('../../data/pearson_0_1.pt')

In [3]:
# find the highest correlations
def create_value_tensor(matrix: 'm,n') -> 'm*n,3':
    """Rank every entry of a 2-D matrix from highest to lowest value.

    Args:
        matrix: 2-D tensor of shape (m, n). It does not need to be square.

    Returns:
        Tensor of shape (m * n, 3). Each row is (row_index, col_index, value)
        and rows are ordered by value, descending. The indices are stacked
        together with the values, so they take on the promoted dtype of the
        stack (floating point for a floating-point matrix).

    Raises:
        ValueError: if ``matrix`` is not 2-D.
    """
    if matrix.dim() != 2:
        raise ValueError(f"expected a 2-D matrix, got shape {tuple(matrix.shape)}")
    n_rows, n_cols = matrix.shape

    # Flatten in row-major order: entry k corresponds to (k // n_cols, k % n_cols).
    values = matrix.flatten()

    # Build the indices on the matrix's own device so the stack below does not
    # fail with a device mismatch when the matrix lives on mps/cuda.
    row_indices = t.arange(n_rows, device=matrix.device).repeat_interleave(n_cols)
    col_indices = t.arange(n_cols, device=matrix.device).repeat(n_rows)

    # One (row, col, value) triple per matrix entry.
    result = t.stack((row_indices, col_indices, values), dim=1)

    # Reorder the triples so the largest values come first.
    return result[t.argsort(result[:, 2], descending=True)]

In [4]:
# Rank every (row, col) entry of the correlation matrix from highest to lowest value.
ranked_features = create_value_tensor(pearson_0_1_small)

In [5]:
# Spot-check two individual entries of the correlation matrix.
# NOTE(review): the saved output shows a value above 1.0, which a Pearson
# correlation should never exceed - worth checking the upstream computation.
print(pearson_0_1_small[55, 4], pearson_0_1_small[1, 55])

tensor(1.0047) tensor(0.9270)


## Measure the correlation when we pass the residual stream that's reconstructed from the SAE features - NO clamping
As mentioned at the end of June, normally the SAE feature values are read by us and then discarded - they are not passed back into the model for inference.
However, if we are going to be clamping an SAE feature and seeing its impact downstream, then we need to first see what happens when we pass the SAE features downstream with no clamping - because the mere act of projecting a residual stream into an SAE space and then back into residual stream space is a lossy operation (even though the SAE is supposed to represent the residual stream)

In [6]:
# Load GPT-2 small, then one pretrained residual-stream SAE per transformer block.
# The SAEs are stored in a dict keyed by the name of the hook point they belong to.
model = HookedTransformer.from_pretrained("gpt2-small", device=device)
sae_id_to_sae = {}
for layer in trange(model.cfg.n_layers):
    sae_id = f"blocks.{layer}.hook_resid_pre"
    # from_pretrained returns a 3-tuple here; only the SAE itself is kept.
    sae, _, _ = SAE.from_pretrained(release="gpt2-small-res-jb", sae_id=sae_id, device=device)
    # Eval mode: per the original note, this prevents an error when a
    # dead-neuron mask is expected.
    sae.eval()
    sae_id_to_sae[sae_id] = sae



Loaded pretrained model gpt2-small into HookedTransformer


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 12/12 [00:06<00:00,  1.99it/s]


In [7]:
from transformer_lens.utils import tokenize_and_concatenate
from datasets import load_dataset
from torch.utils.data import DataLoader
# Pre-processing hyperparameters. Context length and BOS handling are read from
# the config of one reference SAE (despite the `random_` prefix, this is always
# the SAE for "blocks.0.hook_resid_pre").
batch_size = 32
random_sae_id = "blocks.0.hook_resid_pre"
random_sae = sae_id_to_sae[random_sae_id]
context_size, prepend_bos = random_sae.cfg.context_size, random_sae.cfg.prepend_bos

# Fetch the text corpus, then tokenize and concatenate it using the settings above.
dataset = load_dataset(path="NeelNanda/pile-10k", split="train", streaming=False)
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=context_size,
    add_bos_token=prepend_bos,
)

# Only the token ids are needed from here on.
tokens = token_dataset["tokens"]

In [8]:
# OPTIONAL: Reduce dataset for faster experimentation.
# Keeps only the first 64 entries of `tokens`; with batch_size = 32 that is two batches.
tokens = tokens[:64]

In [9]:
# Batch the token sequences in their original order (no shuffling).
data_loader = DataLoader(tokens, batch_size=batch_size, shuffle=False)

In [10]:
# looking at compute-pearson-0.py to figure out how hooks work
def replace_with_sae_output(activations, hook):
    """Forward hook that swaps an activation for its SAE reconstruction.

    Intended to be attached to a TransformerLens hook point, e.g. via
    ``model.run_with_hooks(..., fwd_hooks=[(name, replace_with_sae_output)])``.
    The value returned by a forward hook replaces the hooked activation, so
    everything downstream sees the SAE's (lossy) reconstruction rather than
    the original tensor.

    Background on how hooks work:
      * A "hook point" is a wrapper around an nn.Module:
        https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/hook_points.py#L64
      * ``add_hook`` appears to wrap torch.nn.Module.register_forward_hook,
        whose hooks have the signature
        ``hook(module, args, output) -> None or modified output``:
        https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook

    Args:
        activations: the activation tensor captured at the hook point.
        hook: the hook point being executed; its ``name`` is used to look up
            the matching SAE in the module-level ``sae_id_to_sae`` dict.

    Returns:
        ``sae(activations)`` - presumably the same shape as ``activations``
        (TODO confirm).

    Raises:
        KeyError: if no SAE is registered under ``hook.name``.
    """
    # Presumably equivalent to sae.decode(sae.encode(activations)) - not verified here.
    reconstruction = sae_id_to_sae[hook.name](activations)
    return reconstruction

# okay so add_hook takes a function...and then...
# https://transformerlensorg.github.io/TransformerLens/generated/code/transformer_lens.hook_points.html#transformer_lens.hook_points.HookPoint.add_hook

# TODO: make a lambda function
# model.reset_hooks()

# TODO: find out where "pre" is defined
# michael: every time you do something with an activation. might not mean anything bc "residual stream" is not an action (unlike, say, adding the attention to the residual)
# model.add_hook("blocks.0.hook_resid_pre", replace_with_sae_output)

In [11]:
# Run every batch twice: once through the unmodified model, and once with the
# activation at `random_sae_id` replaced by its SAE reconstruction.
# NOTE: `orig` and `reconstructed` are overwritten on every iteration, so after
# the loop they hold the model outputs for the final batch only.
with t.no_grad():
    for batch_tokens in tqdm(data_loader):
        orig = model(batch_tokens)
        reconstructed = model.run_with_hooks(
            batch_tokens,
            fwd_hooks=[(random_sae_id, replace_with_sae_output)],
        )

100%|██████████| 2/2 [00:03<00:00,  1.78s/it]


In [13]:
# Count the output entries that are exactly equal between the clean run and the
# SAE-patched run (both tensors come from the last batch of the loop).
(orig == reconstructed).sum().item()

135

Okay...nice...looks like most of the output logits have changed now that this single layer is using SAE activations instead of the residual stream!

## CLAMPING an SAE feature to 0
looking through https://transformerlensorg.github.io/TransformerLens/generated/code/transformer_lens.patching.html#transformer_lens.patching.generic_activation_patch and I’m not seeing much in the way of ablations...patching seems based on seeing how activations change by varying inputs, which isn't what I'm going to be doing.

Well...what am I doing, exactly? I guess I'm going to be...clamping the feature to `0` for all residual stream positions? for a single residual stream position? TODO: figure this out