In [None]:
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
from sae_lens import SAE
from transformer_lens import HookedTransformer

import loading_utils

# Use the GPU when one is available; fall back to CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

%load_ext autoreload
%autoreload 2

In [None]:
# # Download data
# !wget "https://raw.githubusercontent.com/saprmarks/feature-circuits/main/data/nounpp_train.json"
# !wget "https://raw.githubusercontent.com/saprmarks/feature-circuits/main/data/nounpp_test.json"

In [None]:
# load models
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# One residual-stream SAE per transformer block, attached at blocks.{layer}.hook_resid_pre.
# The layer count comes from the model config instead of a hard-coded 12.
# SAE.from_pretrained returns (SAE, config, sparsity); only the SAE is kept.
saes = [
    SAE.from_pretrained(
        release="gpt2-small-res-jb",  # see other options in sae_lens/pretrained_saes.yaml
        sae_id=f"blocks.{layer}.hook_resid_pre",
        device=device,
    )[0]
    for layer in range(model.cfg.n_layers)
]


# Load data
# NOTE(review): 9999 is passed straight to load_examples; presumably an upper bound on the
# number of examples — confirm in loading_utils.
train = loading_utils.load_examples("nounpp_train.json", 9999, model)
test = loading_utils.load_examples("nounpp_test.json", 9999, model)

In [None]:
# Use the first training example as the running sample for attribution patching below.
sample = train[0]

def attr_patch(model: HookedTransformer, sample, saes, metric, node_thresh=0.1, edge_thresh=0.01, verify=False):
    """Attribution patching over SAE features, plus edges between consecutive SAEs.

    Each SAE is spliced into ``model`` at ``sae.cfg.hook_name``. The hooked activation becomes
    ``dec + (output - dec).detach()`` with ``dec = sae.decode(sae.encode(output))``. That keeps
    the forward value equal to ``output``, while the backward pass only flows through the SAE.

    Node attribution of a feature is ``grad_clean * (act_patch - act_clean)``, where
    ``grad_clean`` is the gradient of ``metric`` w.r.t. the feature on the clean prefix.

    Edge attribution is computed for every downstream feature selected by ``node_thresh``.
    It is the gradient of ``(grad_clean * act_clean)[feat]`` w.r.t. the previous SAE's features,
    multiplied by that upstream SAE's ``(act_patch - act_clean)``.

    Args:
        model: the model to hook. All hooks are reset on entry and on exit.
        sample: dict providing 'clean_prefix' and 'patch_prefix' (passed to ``model``). It is
            also passed to ``metric``. The clean/patch activations are subtracted elementwise,
            so the two prefixes presumably need identical token lengths — TODO confirm.
        saes: SAEs ordered upstream → downstream. Edges are computed between each consecutive
            pair, from the most downstream pair back to the first.
        metric: ``metric(sample, logits)`` returning a scalar tensor.
        node_thresh: keep features whose node attribution is strictly greater than this.
            NOTE(review): this is signed, so strongly negative attributions are dropped.
        edge_thresh: keep upstream features whose edge attribution is strictly greater than
            this (signed, same caveat).
        verify: if True, print the summed absolute difference between the token-embedding
            gradient of the un-hooked model and of the SAE-hooked model, both on the clean
            prefix.

    Returns:
        attribs: {hook_name: node attribution tensor, shaped like that SAE's activations}.
        attrib_good_indices: {hook_name: ``nonzero()`` index tensor of kept features}.
        edge_attribs: {downstream hook_name: {feature index tuple: attribution tensor shaped
            like the upstream SAE's activations}}.
        edge_attrib_good_indices: {downstream hook_name: {feature index tuple: ``nonzero()``
            indices of kept upstream features}}.
    """
    # Encoder activations from the most recent forward pass, keyed by hook name. The hook
    # closure reads this variable by name, so rebinding it below starts a fresh cache.
    sae_cache = {}
    model.reset_hooks()

    if verify:
        embs_cache = []

        def embedding_hook(output, hook):
            # Keep the token-embedding activation (and its gradient) of every forward pass.
            output.retain_grad()
            embs_cache.append(output)
            return output

        model.add_hook('hook_embed', embedding_hook, 'fwd')

        # Un-hooked baseline on the *clean* prefix. Using the same input as the hooked clean run
        # below makes the gradient comparison meaningful.
        result_baseline = metric(sample, model(sample['clean_prefix'], return_type='logits'))
        result_baseline.backward()

    def sae_fwd_hook(output, hook, sae):
        # Replace the activation with its SAE reconstruction plus a detached error term.
        # The forward value is unchanged, but gradients reach `output` only via the SAE.
        enc = sae.encode(output)
        enc.retain_grad()
        sae_cache[sae.cfg.hook_name] = enc
        dec = sae.decode(enc)
        err = (output - dec).detach()
        return dec + err

    hook_names = [sae.cfg.hook_name for sae in saes]
    for sae in saes:
        model.add_hook(sae.cfg.hook_name, partial(sae_fwd_hook, sae=sae), 'fwd')

    # Clean run: gradient of the metric w.r.t. every SAE's features. The graph is not reused
    # afterwards, so it does not need to be retained.
    clean_result = metric(sample, model(sample['clean_prefix'], return_type='logits'))
    clean_result.backward()
    sae_clean = sae_cache
    sae_cache = {}
    clean_grads = {k: v.grad.detach() for k, v in sae_clean.items()}
    if verify:
        # embs_cache[0]: un-hooked baseline, embs_cache[1]: SAE-hooked clean run.
        grad_diff = (embs_cache[1].grad - embs_cache[0].grad).abs().sum().item()
        print("grad diff for hooked model at embedding layer:", grad_diff)

    # Patch run: only the feature activations are needed.
    model(sample['patch_prefix'], return_type='logits')
    sae_patch = sae_cache
    sae_cache = {}

    # input * grad attribution
    attribs = {k: (clean_grads[k] * (sae_patch[k] - sae_clean[k])).detach() for k in sae_clean}
    attrib_good_indices = {k: (attrib > node_thresh).nonzero() for k, attrib in attribs.items()}

    # Fresh clean run: a new graph whose feature tensors carry no accumulated .grad.
    model(sample['clean_prefix'], return_type='logits')
    sae_clean = sae_cache

    model.zero_grad()
    edge_attribs = {}
    edge_attrib_good_indices = {}
    # Consecutive (upstream, downstream) SAE pairs, most downstream pair first.
    for up, down in reversed(list(zip(hook_names[:-1], hook_names[1:]))):
        edge_attribs[down] = {}
        edge_attrib_good_indices[down] = {}
        weighted_down = clean_grads[down] * sae_clean[down]  # loop-invariant for this pair
        for feat_idx in attrib_good_indices[down]:
            feat = tuple(feat_idx.tolist())
            weighted_down[feat].backward(retain_graph=True)
            edge_attribs[down][feat] = (sae_clean[up].grad * (sae_patch[up] - sae_clean[up])).detach()
            edge_attrib_good_indices[down][feat] = (edge_attribs[down][feat] > edge_thresh).nonzero()
            # backward() accumulates into the retained .grad of *every* upstream SAE, not only
            # `up`. Clear all of them so that the next feature, and the next layer pair, start
            # from zero.
            for enc in sae_clean.values():
                if enc.grad is not None:
                    enc.grad.zero_()
            model.zero_grad()

    model.reset_hooks()

    return attribs, attrib_good_indices, edge_attribs, edge_attrib_good_indices

def logit_diff_metric(sample, logits):
    """Clean-answer logit minus patch-answer logit at the final position of the first sequence.

    Args:
        sample: dict providing 'clean_answer' and 'patch_answer' indices into the vocabulary.
        logits: tensor indexed as logits[batch, position, vocab].

    Returns:
        The logit difference, as a tensor element (differentiable).
    """
    final_position_logits = logits[0, -1]
    clean_id = sample['clean_answer']
    patch_id = sample['patch_answer']
    return final_position_logits[clean_id] - final_position_logits[patch_id]

# Node and edge attribution for the sample; display the upstream feature indices that pass
# the edge threshold for each kept downstream feature.
node_attr, node_idxs, edges_attr, edge_idxs = attr_patch(
    model,
    sample,
    saes,
    logit_diff_metric,
    edge_thresh=0.01,
)
edge_idxs

In [None]:
# Histogram of node attribution scores, grad * (patch - clean), one subplot per SAE hook point.
# Only values above each layer's median are plotted, so the tail stays visible.
node_attr_np = {name: attr.detach().cpu().numpy() for name, attr in node_attr.items()}

n_cols = 3
n_rows = max(1, -(-len(node_attr_np) // n_cols))  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 3.75 * n_rows), squeeze=False)
axs = axs.flatten()
for ax, (hook_name, attr) in zip(axs, node_attr_np.items()):
    median = np.percentile(attr, 50.0)
    ax.hist(attr[attr > median], bins=100)
    ax.set_title(f"Layer {hook_name}")
    ax.set_xlabel("node attribution")
    ax.set_ylabel("count")
# Hide any grid cells left over when the SAE count is not a multiple of n_cols.
for ax in axs[len(node_attr_np):]:
    ax.set_visible(False)
plt.tight_layout()
plt.show()