In [26]:
import torch as t
from sae_lens import SAE
from transformer_lens import HookedTransformer
from tqdm import tqdm, trange

if t.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if t.cuda.is_available() else "cpu"

# Clamping SAE features in one layer and measuring effects on the subsequent layer
## Plan described [here](https://spar2024.slack.com/archives/C0794GNT8KS/p1719950740186749?thread_ts=1719934219.491869&cid=C0794GNT8KS)

In [2]:
# loading a small set of correlations to play around with
pearson_0_1_small: 'f' = t.load('../../experiments/pearson_0_1.pt')

In [11]:
# find the highest correlations
def create_value_tensor(matrix: 'f,f') -> 'f*f,3':
    m, _ = matrix.shape
    
    # Step 1: Flatten the matrix (shape: [m*m])
    flattened_matrix = matrix.flatten()
    
    # Step 2: Create row and column indices
    row_indices = t.arange(m).repeat_interleave(m)
    col_indices = t.arange(m).repeat(m)
    
    # Step 3: Create the final tensor with indices and values
    values = flattened_matrix
    result = t.stack((row_indices, col_indices, values), dim=1)
    
    # Step 4: Sort the result tensor by values
    sorted_result = result[t.argsort(result[:, 2], descending=True)]
    
    return sorted_result

In [12]:
ranked_features = create_value_tensor(pearson_0_1_small)

In [13]:
ranked_features

tensor([[5.5000e+01, 4.0000e+00, 1.0047e+00],
        [1.0000e+01, 5.5000e+01, 1.0045e+00],
        [5.3000e+01, 4.5000e+01, 1.0043e+00],
        ...,
        [1.4000e+01, 8.0000e+00, 2.0628e-03],
        [7.4000e+01, 9.6000e+01, 1.7489e-03],
        [1.3000e+01, 4.4000e+01, 1.6410e-03]])

In [16]:
print(pearson_0_1_small[55, 4], pearson_0_1_small[1, 55])

tensor(1.0047) tensor(0.9270)


In [None]:
# okay, I'll pick l1f55 and l2f4 as my pair. something to start with

## Measure the correlation when we pass the residual stream that's reconstructed from the SAE features - NO clamping
As mentioned at the end of June, normally the SAE features values are read by us and discarded - they are not passed back into the model for inference.
However, if we are going to be clamping an SAE feature and seeing its impact downstream, then we need to first see what happens when we pass the SAE features downstream with no clamping - because the mere act of projecting a residual stream into an SAE space and then back into residual stream space is a lossy operation (even though the SAE is supposed to represent the residual stream)

In [18]:
model = HookedTransformer.from_pretrained("gpt2-small", device=device)
sae, _, _ = SAE.from_pretrained(release="gpt2-small-res-jb", sae_id=f"blocks.0.hook_resid_pre", device=device)
# okay...so I need a hook...I need a hook which will 



Loaded pretrained model gpt2-small into HookedTransformer


In [28]:
from transformer_lens.utils import tokenize_and_concatenate
from datasets import load_dataset
from torch.utils.data import DataLoader
# These hyperparameters are used to pre-process the data
context_size = sae.cfg.context_size
prepend_bos = sae.cfg.prepend_bos
batch_size = 32

dataset = load_dataset(path="NeelNanda/pile-10k", split="train", streaming=False)
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=context_size,
    add_bos_token=prepend_bos,
)

tokens = token_dataset['tokens']

In [29]:
# OPTIONAL: Reduce dataset for faster experimentation
tokens = tokens[:64]

In [30]:
data_loader = DataLoader(tokens, batch_size=batch_size, shuffle=False)

In [None]:
with t.no_grad():
    for batch_tokens in tqdm(data_loader):
        _, model_cache = model.run_with_cache(batch_tokens, prepend_bos=True)
        # TODO: what's going on here?
        feat_acts = sae.encode[model_cache.cfg.hook_name]

## CLAMPING an SAE feature to 0
looking through https://transformerlensorg.github.io/TransformerLens/generated/code/transformer_lens.patching.html#transformer_lens.patching.generic_activation_patch and I’m not seeing much in the way of ablations...patching seems based on seeing how activations change by varying inputs, which isn't what I'm going to be doing.

Well...what am I doing, exactly? I guess I'm going to be...clamping the feature to `0` for all residual stream positions? for a single residual stream position? TODO: figure this out