In [34]:
import torch as t
import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint
from tqdm import tqdm, trange
import einops
from dataclasses import dataclass
from typing import List
from functools import partial
import plotly.express as px

# Pick the best available accelerator: Apple-silicon GPU first, then CUDA, else CPU.
if t.backends.mps.is_available():
    device = "mps"
elif t.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Clamping SAE features in one layer and measuring effects on the subsequent layer
## Plan described [here](https://spar2024.slack.com/archives/C0794GNT8KS/p1719950740186749?thread_ts=1719934219.491869&cid=C0794GNT8KS)

In [4]:
# Load a small, precomputed matrix of Pearson correlations between SAE features
# to play around with. It is indexed as [feature_a, feature_b] below.
# NOTE(review): going by the filename this is presumably layer-0 vs layer-1
# features - confirm against the script that produced the file.
pearson_0_1_small: 'f,f' = t.load('../../data/pearson_0_1.pt')

In [81]:
# find the highest correlations
def create_value_tensor(matrix: 'f,f') -> 'f*f,3':
    """Flatten a 2-D score matrix into (row, col, value) triples, highest value first.

    Args:
        matrix: 2-D tensor of shape (n_rows, n_cols). Square matrices (the
            original use case) and rectangular ones (e.g. correlations between
            two feature sets of different sizes) are both supported. The
            tensor may live on any device.

    Returns:
        Tensor of shape (n_rows * n_cols, 3) on ``matrix``'s device. Column 0
        holds the row index and column 1 the column index, both stored in the
        matrix's dtype. Column 2 holds the value. Rows are sorted by value,
        descending.
        NOTE(review): NaN entries (e.g. from zero-variance features) are
        ordered according to torch.argsort's NaN handling - verify before
        relying on the top rows.
    """
    n_rows, n_cols = matrix.shape
    values = matrix.flatten()

    # Build the index grids on the matrix's device so the stack below also
    # works for GPU tensors (t.arange defaults to CPU).
    row_indices = t.arange(n_rows, device=matrix.device).repeat_interleave(n_cols)
    col_indices = t.arange(n_cols, device=matrix.device).repeat(n_rows)

    # Cast explicitly instead of relying on t.stack's implicit type promotion.
    triples = t.stack(
        (row_indices.to(values.dtype), col_indices.to(values.dtype), values),
        dim=1,
    )

    order = t.argsort(values, descending=True)
    return triples[order]

In [6]:
# Rank every (row, col) entry of the loaded correlation matrix, strongest first.
ranked_features = create_value_tensor(matrix=pearson_0_1_small)

In [7]:
# Spot-check two entries of the stored matrix. A Pearson coefficient can never
# exceed 1, so a printed value above 1 indicates numerical error in the file.
print(*(pearson_0_1_small[idx] for idx in ((55, 4), (1, 55))))

tensor(1.0047) tensor(0.9270)


## Measure the correlation when we pass the residual stream that's reconstructed from the SAE features - NO clamping
As mentioned at the end of June, SAE feature values are normally only read out and then discarded — they are not passed back into the model during inference.
However, if we are going to clamp an SAE feature and measure its downstream impact, we first need a baseline: what happens when the SAE features are passed downstream with no clamping. This matters because projecting the residual stream into SAE feature space and back into residual-stream space is a lossy operation, even though the SAE is supposed to represent the residual stream.

In [41]:
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# One residual-stream SAE per layer, keyed by the hook point it reads from
# (e.g. "blocks.3.hook_resid_pre").
sae_id_to_sae = {}
for layer in tqdm(range(model.cfg.n_layers)):
    sae_id = f"blocks.{layer}.hook_resid_pre"
    sae, _, _ = SAE.from_pretrained(release="gpt2-small-res-jb", sae_id=sae_id, device=device)
    # eval mode: per the original note, this avoids an error about a dead-neuron
    # mask that is expected when gradients are tracked.
    sae_id_to_sae[sae_id] = sae.eval()


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.



Loaded pretrained model gpt2-small into HookedTransformer


100%|██████████| 12/12 [00:04<00:00,  2.81it/s]


In [59]:
from transformer_lens.utils import tokenize_and_concatenate
from datasets import load_dataset
from torch.utils.data import DataLoader
batch_size = 32

# Pre-processing hyperparameters are taken from the layer-0 SAE's config.
pre_0_sae_id = "blocks.0.hook_resid_pre"
pre_0_sae = sae_id_to_sae[pre_0_sae_id]
context_size, prepend_bos, d_sae = (
    pre_0_sae.cfg.context_size,
    pre_0_sae.cfg.prepend_bos,
    pre_0_sae.cfg.d_sae,
)

dataset = load_dataset(path="NeelNanda/pile-10k", split="train", streaming=False)
# NOTE(review): `streaming=True` is passed below although the dataset above was
# loaded with streaming=False - confirm this flag is intended.
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=context_size,
    add_bos_token=prepend_bos,
)
tokens = token_dataset['tokens']

In [60]:
# OPTIONAL: Reduce dataset for faster experimentation.
# Set MAX_SEQUENCES = None to keep the full tokenised dataset.
MAX_SEQUENCES = 1024
tokens = tokens[:MAX_SEQUENCES]

In [61]:
# No shuffling: batches come out in dataset order on every run.
data_loader = DataLoader(dataset=tokens, batch_size=batch_size, shuffle=False)

In [62]:
# looking at compute-pearson-0.py to figure out how hooks work


# okay so add_hook takes a function...and then...
# https://transformerlensorg.github.io/TransformerLens/generated/code/transformer_lens.hook_points.html#transformer_lens.hook_points.HookPoint.add_hook

# TODO: make a lambda function
# model.reset_hooks()

# TODO: find out where "pre" is defined
# michael: every time you do something with an activation. might not mean anything bc "residual stream" is not an action (unlike, say, adding the attention to the residual)
# https://raw.githubusercontent.com/callummcdougall/computational-thread-art/master/example_images/misc/full-merm.svg
# model.add_hook("blocks.0.hook_resid_pre", replace_with_sae_output)

Okay, nice — it looks like most of the output logits have changed now that this single layer is using the SAE's output (its reconstruction of the residual stream) instead of the raw residual stream!

Next step: check the correlation between features 55 and 4.

How do I do that? Start from how the correlations were computed initially.

First, I need to collect the `sae_activations` again, but this time with the SAE being used to influence the second layer's activations.

In [76]:
@dataclass
class LayerFeatures:
    """A selection of SAE features belonging to one transformer layer."""

    # Layer whose SAE the features come from.
    layer_idx: int
    # Indices into that layer's SAE feature dimension.
    feature_idxes: List[int]

@dataclass
class AggregatorConfig:
    """Pairs two feature selections whose activations are correlated against each other."""

    # First feature group (rows of the resulting correlation matrix).
    layer_1: LayerFeatures
    # Second feature group (columns of the resulting correlation matrix).
    layer_2: LayerFeatures
    
class BatchedPearson:
    """Pairwise Pearson correlation between two feature groups, accumulated batch-wise.

    Each call to `process` receives the activations of both groups for the same
    tokens, laid out as (n_features, n_tokens). Only running sums are kept, so
    memory does not grow with the number of tokens. `finalize` turns the sums
    into an (n_features_1, n_features_2) correlation matrix.

    The sums are kept in float64. The covariance/variance formulas subtract two
    nearly equal quantities (E[xy] - E[x]E[y]). In float32 that cancellation can
    be severe enough to yield "correlations" outside [-1, 1].
    """

    def __init__(self, agg_conf: 'AggregatorConfig'):
        """Create zeroed accumulators sized by ``agg_conf``.

        Args:
            agg_conf: Only ``len(agg_conf.layer_1.feature_idxes)`` and
                ``len(agg_conf.layer_2.feature_idxes)`` are read here. The
                config itself is kept on ``self.agg_conf`` for callers.
        """
        self.agg_conf = agg_conf
        n_1 = len(agg_conf.layer_1.feature_idxes)
        n_2 = len(agg_conf.layer_2.feature_idxes)
        acc_dtype = torch.float64
        self.count = 0

        self.sums_1 = torch.zeros(n_1, dtype=acc_dtype)
        self.sums_2 = torch.zeros(n_2, dtype=acc_dtype)

        self.sums_of_squares_1 = torch.zeros(n_1, dtype=acc_dtype)
        self.sums_of_squares_2 = torch.zeros(n_2, dtype=acc_dtype)

        self.sums_1_2 = torch.zeros((n_1, n_2), dtype=acc_dtype)

        self.nonzero_counts_1 = torch.zeros(n_1, dtype=acc_dtype)
        self.nonzero_counts_2 = torch.zeros(n_2, dtype=acc_dtype)

    def process(self, tensor_1, tensor_2):
        """Fold one batch into the running statistics.

        Args:
            tensor_1: (n_features_1, n_tokens) activations of the first group.
            tensor_2: (n_features_2, n_tokens) activations of the second group,
                token-aligned with ``tensor_1``.

        Raises:
            ValueError: if the two tensors cover different numbers of tokens.
                The check runs before any accumulator is updated.
        """
        if tensor_1.shape[-1] != tensor_2.shape[-1]:
            raise ValueError(
                f"token counts differ: {tensor_1.shape[-1]} vs {tensor_2.shape[-1]}"
            )
        # Move inputs to the accumulators' device/dtype. GPU or low-precision
        # inputs then work, and the sums below are computed in float64.
        x1 = tensor_1.to(device=self.sums_1.device, dtype=self.sums_1.dtype)
        x2 = tensor_2.to(device=self.sums_2.device, dtype=self.sums_2.dtype)

        self.count += x1.shape[-1]

        self.sums_1 += x1.sum(dim=-1)
        self.sums_2 += x2.sum(dim=-1)

        self.sums_of_squares_1 += (x1 ** 2).sum(dim=-1)
        self.sums_of_squares_2 += (x2 ** 2).sum(dim=-1)

        # Sum over tokens of x1[i, t] * x2[j, t] -> (n_features_1, n_features_2).
        self.sums_1_2 += x1 @ x2.T

        self.nonzero_counts_1 += x1.count_nonzero(dim=-1)
        self.nonzero_counts_2 += x2.count_nonzero(dim=-1)

    def finalize(self):
        """Return the Pearson correlation matrix of everything processed so far.

        Returns:
            float32 tensor of shape (n_features_1, n_features_2). Entries are NaN
            where the correlation is undefined: a feature with zero variance
            over all processed tokens (e.g. one that never fired). If no batch
            has been processed, the whole matrix is NaN.
        """
        means_1 = self.sums_1 / self.count
        means_2 = self.sums_2 / self.count

        covariances = self.sums_1_2 / self.count - torch.outer(means_1, means_2)

        # Rounding can push a true zero variance slightly negative. Clamp it so
        # sqrt gives 0 (masked below) instead of NaN.
        variances_1 = (self.sums_of_squares_1 / self.count - means_1 ** 2).clamp_min(0)
        variances_2 = (self.sums_of_squares_2 / self.count - means_2 ** 2).clamp_min(0)

        stds_1 = variances_1.sqrt().unsqueeze(1)
        stds_2 = variances_2.sqrt().unsqueeze(0)

        correlations = covariances / stds_1 / stds_2
        # 0/0 gives NaN but c/0 gives +-inf. Make undefined entries uniformly NaN.
        undefined = (stds_1 == 0) | (stds_2 == 0)
        correlations = correlations.masked_fill(undefined, float('nan'))

        return correlations.float()




In [77]:
# Each layer contributes its first nine SAE features plus one hand-picked feature.
# (Why these particular features were picked is not recorded here.)
_shared_feature_idxes = list(range(9))
_picked_feature_per_layer = {8: 14292, 9: 4813, 10: 23138, 11: 701}


def _layer_features(layer_idx):
    return LayerFeatures(layer_idx, _shared_feature_idxes + [_picked_feature_per_layer[layer_idx]])


# Correlate consecutive layer pairs: 8 -> 9, 9 -> 10, 10 -> 11.
agg_confs = [
    AggregatorConfig(_layer_features(layer_idx), _layer_features(layer_idx + 1))
    for layer_idx in (8, 9, 10)
]
aggs = list(map(BatchedPearson, agg_confs))

# Scratch buffer holding one batch of SAE activations for every layer, laid out
# (layer, feature, token). It holds n_layers * d_sae * batch_size * context_size
# values in the default dtype, which can be several GB.
sae_activations = torch.empty(model.cfg.n_layers, d_sae, batch_size * context_size)

def replace_acts_with_sae(activations: t.Tensor, hook: HookPoint):
    """Hook that substitutes the SAE's output for the residual-stream activations.

    Looks up the SAE registered under this hook point's name and returns
    ``sae(activations)``. Downstream layers therefore see the SAE output
    (presumably its reconstruction) instead of the original activations.
    """
    return sae_id_to_sae[hook.name](activations)

def emit_sae(activations: t.Tensor, hook: HookPoint):
    """Pass-through hook that records this layer's SAE feature activations.

    Encodes ``activations`` with the SAE registered for ``hook.name``. The
    result is flattened to (features, batch * seq) and written into
    ``sae_activations[hook.layer()]``. Only the first batch * seq columns are
    written, so a final batch smaller than ``batch_size`` fits. Remaining
    columns keep stale values from the previous batch.

    Returns the input activations unchanged, so the forward pass is unaffected.
    """
    sae: SAE = sae_id_to_sae[hook.name]
    sae_acts = sae.encode(activations)
    flat_acts = einops.rearrange(
        sae_acts,
        'batch seq features -> features (batch seq)'
    )
    sae_activations[hook.layer(), :, :flat_acts.shape[-1]] = flat_acts
    return activations


# NOTE(review): only `emit_sae` is registered below; `replace_acts_with_sae` is
# not. These correlations therefore describe the unmodified forward pass (SAE
# outputs are not fed back into the model).
resid_pre_hooks = [(lambda name: name.endswith('.hook_resid_pre'), emit_sae)]

# Clear hooks left over from earlier cells once. `run_with_hooks` then attaches
# the hooks above for a single forward pass and removes them afterwards, so no
# hook stays on `model` after this cell.
model.reset_hooks()
with torch.no_grad():
    for batch_tokens in tqdm(data_loader):
        # The final batch can be smaller than `batch_size`; only its first
        # n_tokens columns of `sae_activations` are valid.
        n_tokens = batch_tokens.numel()
        # return_type=None: only the hooks' side effects are needed, not logits.
        model.run_with_hooks(batch_tokens, return_type=None, fwd_hooks=resid_pre_hooks)
        for agg in aggs:
            layer_1, layer_2 = agg.agg_conf.layer_1, agg.agg_conf.layer_2
            agg.process(
                sae_activations[layer_1.layer_idx, layer_1.feature_idxes, :n_tokens],
                sae_activations[layer_2.layer_idx, layer_2.feature_idxes, :n_tokens],
            )
    pearson_correlations = [aggregator.finalize() for aggregator in aggs]

  0%|          | 0/32 [00:00<?, ?it/s]

100%|██████████| 32/32 [01:14<00:00,  2.32s/it]


In [82]:


# Per-layer-pair rankings, concatenated in agg_confs order (NOT re-sorted
# globally). Columns 0 and 1 are positions within each LayerFeatures.feature_idxes
# list, not raw SAE feature ids. Nothing in a row records which layer pair it
# came from.
ranked_feats = t.concat(list(map(create_value_tensor, pearson_correlations)))

# TODO: figure out how to pick a high correlation pair

In [84]:
# Raw correlation matrices, one per entry of agg_confs (layers 8->9, 9->10, 10->11).
# Each is shaped (len(layer_1.feature_idxes), len(layer_2.feature_idxes)).
pearson_correlations

[tensor([[ 3.1148e-03, -1.9471e-04, -6.2550e-04, -1.3654e-03, -4.3121e-04,
          -7.5973e-04, -4.6204e-04, -2.0788e-04, -6.0275e-04, -3.9825e-04],
         [ 1.2669e-03, -1.7483e-04, -5.6166e-04, -1.2260e-03, -3.8720e-04,
          -6.8220e-04, -4.1488e-04, -1.8667e-04, -5.4124e-04, -3.5761e-04],
         [-4.0764e-04, -1.0950e-04, -3.5178e-04, -7.5998e-04, -2.4251e-04,
           9.7921e-04, -2.5985e-04, -1.1691e-04, -3.3899e-04, -2.2398e-04],
         [-3.9907e-04,  5.9122e-05, -3.4439e-04, -7.5175e-04, -2.3741e-04,
          -4.1829e-04, -2.5439e-04, -1.1446e-04, -3.3186e-04, -2.1927e-04],
         [ 1.0548e-02, -5.3831e-04, -1.7294e-03, -3.7555e-03, -1.1922e-03,
          -1.5441e-03, -1.2774e-03, -5.7474e-04, -1.6665e-03, -1.1011e-03],
         [-3.5253e-04, -9.4699e-05, -3.0422e-04, -6.6408e-04, -2.0973e-04,
          -3.6951e-04, -2.2472e-04, -1.0111e-04, -2.9316e-04, -1.9370e-04],
         [-8.6122e-04, -2.3135e-04,  4.5319e-03, -1.6223e-03, -5.1236e-04,
          -9.0270e-

In [83]:
# Columns: [row position, column position, correlation], as built by create_value_tensor.
ranked_feats

tensor([[ 9.0000e+00,  9.0000e+00,  7.5331e-01],
        [ 6.0000e+00,  7.0000e+00,  1.3372e-02],
        [ 4.0000e+00,  0.0000e+00,  1.0548e-02],
        [ 6.0000e+00,  2.0000e+00,  4.5319e-03],
        [ 0.0000e+00,  0.0000e+00,  3.1148e-03],
        [ 8.0000e+00,  4.0000e+00,  1.4305e-03],
        [ 1.0000e+00,  0.0000e+00,  1.2669e-03],
        [ 2.0000e+00,  5.0000e+00,  9.7921e-04],
        [ 7.0000e+00,  5.0000e+00,  9.1835e-05],
        [ 3.0000e+00,  1.0000e+00,  5.9122e-05],
        [ 7.0000e+00,  1.0000e+00, -7.5930e-05],
        [ 7.0000e+00,  7.0000e+00, -8.1068e-05],
        [ 5.0000e+00,  1.0000e+00, -9.4699e-05],
        [ 5.0000e+00,  7.0000e+00, -1.0111e-04],
        [ 2.0000e+00,  1.0000e+00, -1.0950e-04],
        [ 3.0000e+00,  7.0000e+00, -1.1446e-04],
        [ 2.0000e+00,  7.0000e+00, -1.1691e-04],
        [ 7.0000e+00,  9.0000e+00, -1.5531e-04],
        [ 9.0000e+00,  1.0000e+00, -1.6540e-04],
        [ 7.0000e+00,  4.0000e+00, -1.6816e-04],
        [ 1.0000e+00

In [65]:
# No module-level `agg_conf` is defined elsewhere in this notebook (the list
# comprehension building `aggs` does not leak its loop variable). Bind it
# explicitly so this cell runs on a fresh kernel.
# NOTE(review): which config produced the recorded output below is unknown;
# agg_confs[0] (layer 8 -> layer 9) is assumed.
# sae_activations is overwritten every batch, so this sum covers only the LAST batch.
agg_conf = agg_confs[0]
t.sum(sae_activations[agg_conf.layer_1.layer_idx, agg_conf.layer_1.feature_idxes]).item()

0.058958083391189575

In [80]:
# Correlations for the third layer pair, agg_confs[2] (layer 10 -> layer 11).
pearson_correlations[2]

tensor([[-6.2438e-04, -2.5882e-04, -2.6010e-04, -3.0832e-04, -8.6990e-04,
         -6.4761e-04, -1.1305e-04, -5.8632e-04, -2.9311e-04, -2.5225e-04],
        [-8.2803e-04, -3.4324e-04, -3.4494e-04, -4.0888e-04, -1.1536e-03,
         -8.5883e-04, -1.4992e-04, -7.7755e-04,  2.3899e-03, -3.3452e-04],
        [-2.3706e-03, -1.2318e-04, -9.8067e-04, -1.1706e-03, -2.9752e-03,
          7.0438e-03, -3.8068e-04, -2.2261e-03, -1.1129e-03, -9.4081e-04],
        [-4.2541e-04, -1.7634e-04, -1.7722e-04, -2.1006e-04, -5.9269e-04,
         -4.4123e-04, -7.7024e-05, -3.9947e-04, -1.9971e-04, -1.7186e-04],
        [-3.0744e-04, -1.2744e-04, -1.2807e-04, -1.5181e-04, -4.2833e-04,
         -3.1888e-04, -5.5666e-05, -2.8870e-04, -1.4433e-04, -1.2421e-04],
        [-2.5929e-04, -1.0748e-04, -1.0802e-04, -1.2804e-04, -3.6125e-04,
         -2.6894e-04, -4.6947e-05, -2.4348e-04, -1.2172e-04, -1.0475e-04],
        [-4.4359e-04, -1.8388e-04, -1.8479e-04, -2.1904e-04, -6.1802e-04,
         -4.6009e-04, -8.0317e-0

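## Discrepancy with the stored correlations
It was originally `1.0054`... hmm. A Pearson correlation can never exceed 1, so values like this point to numerical error in how `pearson_0_1.pt` was computed. Rather than keep chasing this, I should stop relying on `pearson_0_1.pt` and recompute the correlations myself.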

## Extending this to work on multiple feature pairs within a single layer pair
TODO

## Extending this to work across multiple layer pairs
TODO