In [97]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [118]:
import json
import os
import sys
from typing import Any, Callable

sys.path.insert(0, os.getcwd())

from HookedTransformer import HookedTransformer
# from transformer_lens import HookedTransformer

from transformers import AutoModelForCausalLM

import networkx as nx
import random
import math
import pickle
import dataclasses
import numpy as np

from einops import repeat

import plotly.express as px
import torch
import torch.nn.functional as F
from core.config import SAEConfig
from core.sae import SparseAutoEncoder

device = "cuda" if torch.cuda.is_available() else "cpu"

hf_model = AutoModelForCausalLM.from_pretrained('gpt2')
model = HookedTransformer.from_pretrained('gpt2', device=device, hf_model=hf_model)

def check_all_close():
	import transformer_lens
	origin_tl_model = transformer_lens.HookedTransformer.from_pretrained('gpt2', device=device, hf_model=hf_model)
	logits = model(model.to_tokens('Hello, World.'))
	origin_logits = origin_tl_model(origin_tl_model.to_tokens('Hello, World.'))
	assert torch.allclose(logits, origin_logits, atol=1e-4), f"Logits are not close: {logits} != {origin_logits}"

# input = model.to_tokens(" OpenMoss! OpenMoss! OpenMoss!", prepend_bos=False)
# input = model.to_tokens("Outside [Inside] Outside", prepend_bos=False)
# input = model.to_tokens("0 0 [1 1 1 [2] 3] 4", prepend_bos=False)
# input = model.to_tokens("Video in WebM support: Your browser doesn't support HTML5 video in WebM.", prepend_bos=False)
# input = model.to_tokens("Form-fitting TrekDry helps keep hands cool and comfortable. Form-fitting TrekDry material is lightweight and breathable.", prepend_bos=False)
# input = model.to_tokens(" it was its command line interface. You get so much leverage by being able to scaffold a [Inner Inner] A B A", prepend_bos=False)
# input = model.to_tokens("[[[ OpenMoss ]]] OpenMoss Open Moss ]", prepend_bos=False)
# input = model.to_tokens("Fruits:\n\napple red\n\nbanana yellow\n\ngrape purple", prepend_bos=False)
# input = model.to_tokens("Fruits:\n\nbanana yellow\n\napple red\n\ngrape purple", prepend_bos=False)
# input = model.to_tokens("You’re used to endlessly circular debates where Republican shills and Democratic shills", prepend_bos=False)
# input = model.to_tokens("Afterwards, Alice and Tom went to the shop. Tom gave a bunch of flowers to", prepend_bos=False)
# input = model.to_tokens("Afterwards, Tom and Alice went to the shop. Tom gave a bunch of flowers to", prepend_bos=False)
# input = model.to_tokens("When Mary and John went to the store, Mary gave a bottle of milk to", prepend_bos=False)
# input = model.to_tokens("When Mary and John went to the store, John gave a bottle of milk to", prepend_bos=False)
# input = model.to_tokens("When John and Mary went to the store, John gave a bottle of milk to", prepend_bos=False)
# input = model.to_tokens("20 Parts Rosemary, 8 Parts Grapefruit", prepend_bos=False)

# answer = model.to_tokens(" Mary", prepend_bos=False)
# assert answer.size(0) == 1
# logits, cache = model.run_with_cache(input)
# logits = logits[0, -1, answer.item()]
# print(logits)
# logits.backward()

Loaded pretrained model gpt2 into HookedTransformer


In [129]:
import einops

for threshold in [0, 1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1]:
    model.cfg.detach_pattern = True
    model.cfg.add_sae_error = True
    model.cfg.prune_on_backward = True
    model.cfg.prune_on_backward_threshold = threshold

    input = model.to_tokens("When Mary and John went to the store, John gave a bottle of milk to", prepend_bos=False)
    # input = model.to_tokens("When John and Mary went to the store, John gave a bottle of milk to", prepend_bos=False)
    # input = model.to_tokens("20 Parts Rosemary, 8 Parts Grapefruit", prepend_bos=False)

    answer = model.to_tokens(" Mary", prepend_bos=False)
    wrong_answer = model.to_tokens(" John", prepend_bos=False)
    assert answer.size(0) == 1
    cache = model.add_caching_hooks(incl_bwd=True)
    logits = model(input)
    true_logits = logits[0, -1, answer.item()]
    wrong_logits = logits[0, -1, wrong_answer.item()]
    model.zero_grad()
    # true_logits.backward()
    target = true_logits - wrong_logits
    target.backward()

    nodes = 0
    contribution = torch.zeros_like(true_logits)
    error_contribution = torch.zeros_like(true_logits)

    embed_contribution = einops.einsum(cache["hook_embed"], cache["hook_embed_grad"], "b l d, b l d -> b l")
    nodes += (embed_contribution[0] > model.cfg.prune_on_backward_threshold).nonzero(as_tuple=True)[0].shape[0]
    contribution += embed_contribution[embed_contribution > model.cfg.prune_on_backward_threshold].sum()

    pos_embed_contribution = einops.einsum(cache["hook_pos_embed"], cache["hook_pos_embed_grad"], "b l d, b l d -> b l")
    nodes += (pos_embed_contribution[0] > model.cfg.prune_on_backward_threshold).nonzero(as_tuple=True)[0].shape[0]
    contribution += pos_embed_contribution[pos_embed_contribution > model.cfg.prune_on_backward_threshold].sum()

    for i in range(12):
        attn_contribution = einops.einsum(cache[f"blocks.{i}.hook_attn_feature_acts"], cache[f"blocks.{i}.hook_attn_feature_acts_grad"], "b l f, b l f -> b l f")
        nodes += (attn_contribution[0] > model.cfg.prune_on_backward_threshold).reshape(-1).nonzero(as_tuple=True)[0].shape[0]

        attn_sae_error_contribution = einops.einsum(cache[f"blocks.{i}.hook_attn_sae_error"], cache[f"blocks.{i}.hook_attn_sae_error_grad"], "b l d, b l d -> b l")
        error_contribution += attn_sae_error_contribution.sum()

        b_V_contribution = einops.einsum(model.blocks[i].attn.b_V, model.blocks[i].attn.b_V.grad, "h d, h d -> h")
        nodes += (b_V_contribution > model.cfg.prune_on_backward_threshold).nonzero(as_tuple=True)[0].shape[0]
        contribution += b_V_contribution[b_V_contribution > model.cfg.prune_on_backward_threshold].sum()

        attn_b_E_contribution = einops.einsum(model.blocks[i].attn_sae.encoder_bias, model.blocks[i].attn_sae.encoder_bias.grad, "f, f -> ")
        nodes += (attn_b_E_contribution > model.cfg.prune_on_backward_threshold).nonzero(as_tuple=True)[0].shape[0]
        contribution += attn_b_E_contribution[attn_b_E_contribution > model.cfg.prune_on_backward_threshold].sum()

        mlp_contribution = einops.einsum(cache[f"blocks.{i}.hook_mlp_feature_acts"], cache[f"blocks.{i}.hook_mlp_feature_acts_grad"], "b l f, b l f -> b l")
        nodes += (mlp_contribution[0] > model.cfg.prune_on_backward_threshold).reshape(-1).nonzero(as_tuple=True)[0].shape[0]

        mlp_sae_error_contribution = einops.einsum(cache[f"blocks.{i}.hook_mlp_sae_error"], cache[f"blocks.{i}.hook_mlp_sae_error_grad"], "b l d, b l d -> b l")
        error_contribution += mlp_sae_error_contribution.sum()

        mlp_b_E_contribution = einops.einsum(model.blocks[i].mlp_sae.encoder_bias, model.blocks[i].mlp_sae.encoder_bias.grad, "f, f -> ")
        nodes += (mlp_b_E_contribution > model.cfg.prune_on_backward_threshold).nonzero(as_tuple=True)[0].shape[0]
        contribution += mlp_b_E_contribution[mlp_b_E_contribution > model.cfg.prune_on_backward_threshold].sum()

    print("Threshold:", threshold, "Nodes:", nodes, "Contribution:", contribution.item())

print("Target:", target.item())

Threshold: 0 Nodes: 9667 Contribution: 14.758073806762695
Threshold: 1e-05 Nodes: 8770 Contribution: 14.576600074768066
Threshold: 3e-05 Nodes: 7929 Contribution: 14.203865051269531
Threshold: 0.0001 Nodes: 6411 Contribution: 13.222742080688477
Threshold: 0.0003 Nodes: 4487 Contribution: 11.559713363647461
Threshold: 0.001 Nodes: 2252 Contribution: 8.495205879211426
Threshold: 0.003 Nodes: 899 Contribution: 6.189325332641602
Threshold: 0.01 Nodes: 380 Contribution: 3.173037052154541
Threshold: 0.03 Nodes: 222 Contribution: 0.955620288848877
Threshold: 0.1 Nodes: 157 Contribution: 0.0


In [None]:
import einops


print((cache["blocks.5.hook_attn_feature_acts_grad"] * cache["blocks.5.hook_attn_feature_acts"] > 0)[0][-1].nonzero(as_tuple=True)[0].shape)
print(einops.einsum(cache[f"blocks.11.hook_mlp_sae_error"], cache[f"blocks.11.hook_mlp_sae_error_grad"], "b l d, b l d -> b l"))
print(cache[f"blocks.11.hook_mlp_sae_error"][0][-1])

506
tensor(7.7001, device='cuda:0', grad_fn=<AddBackward0>)


In [102]:
feature_acts = model.L0RPr_sae(cache['blocks.0.hook_resid_pre'])[][1]['feature_acts']
attr = model.L0RPr_grad[0, :, None, :] * model.L0RPr_sae.decoder[None, :, :] * feature_acts[0, :, :, None]
attr = attr.sum(-1).flatten()
attr = attr.topk(30)
for idx, value in zip(attr.indices, attr.values):
	if value > threshold:
		token_idx, head_idx = divmod(idx.item(), block.attn_sae.cfg.d_sae)
		print(token_idx, head_idx, value)
# (model.L0RPr_grad * model.L0RPr_sae(cache['blocks.0.hook_resid_pre'])[1][1]['x_hat']).sum(-1)

SyntaxError: invalid syntax (2317419061.py, line 1)

In [None]:
attributions_of_each_single_neuron = {}

start = 0
threshold = 0.05
with torch.no_grad():
    for i, block in enumerate(model.blocks):
        feature_acts = block.attn_sae(cache[f'blocks.{start + i}.hook_attn_out'])[1][1]['feature_acts']
        # print(block.attn_grad.size())
        # print(block.attn_sae.decoder.size())
        # print(feature_acts.size())
        # print(block.attn_grad.norm(2, dim=-1))
        attr = block.attn_grad[0, :, None, :] * block.attn_sae.decoder[None, :, :] * feature_acts[0, :, :, None]
        attr = attr.sum(-1).flatten()
        attr = attr.topk(30)
        # print(attr.values)
        
        for idx, value in zip(attr.indices, attr.values):
            if value > threshold:
                token_idx, head_idx = divmod(idx.item(), block.attn_sae.cfg.d_sae)
                attributions_of_each_single_neuron[(start + i, 'A', token_idx, head_idx)] = value

        feature_acts = block.mlp_sae(cache[f'blocks.{start + i}.hook_resid_mid'], label=cache[f'blocks.{start + i}.hook_mlp_out'])[1][1]['feature_acts']
        attr = block.mlp_grad[0, :, None, :] * block.mlp_sae.decoder[None, :, :] * feature_acts[0, :, :, None]
        attr = attr.sum(-1).flatten()
        attr = attr.topk(30)
        
        for idx, value in zip(attr.indices, attr.values):
            if value > threshold:
                token_idx, head_idx = divmod(idx.item(), block.mlp_sae.cfg.d_sae)
                attributions_of_each_single_neuron[(start + i, 'M', token_idx, head_idx)] = value

attributions_of_each_single_neuron

{(0, 'A', 0, 20459): tensor(2.5407, device='cuda:0'),
 (0, 'A', 3, 11601): tensor(1.8369, device='cuda:0'),
 (0, 'A', 3, 8101): tensor(0.4803, device='cuda:0'),
 (0, 'A', 1, 14671): tensor(0.4639, device='cuda:0'),
 (0, 'A', 8, 23630): tensor(0.4037, device='cuda:0'),
 (0, 'A', 3, 19931): tensor(0.3042, device='cuda:0'),
 (0, 'A', 1, 18977): tensor(0.2797, device='cuda:0'),
 (0, 'A', 8, 19931): tensor(0.2788, device='cuda:0'),
 (0, 'A', 8, 13741): tensor(0.2101, device='cuda:0'),
 (0, 'A', 8, 14768): tensor(0.1655, device='cuda:0'),
 (0, 'A', 1, 4342): tensor(0.1181, device='cuda:0'),
 (0, 'A', 1, 9204): tensor(0.1048, device='cuda:0'),
 (0, 'A', 3, 12731): tensor(0.1001, device='cuda:0'),
 (0, 'A', 1, 6617): tensor(0.0777, device='cuda:0'),
 (0, 'A', 3, 10869): tensor(0.0759, device='cuda:0'),
 (0, 'A', 8, 8101): tensor(0.0746, device='cuda:0'),
 (0, 'A', 0, 14131): tensor(0.0733, device='cuda:0'),
 (0, 'A', 3, 24377): tensor(0.0725, device='cuda:0'),
 (0, 'A', 1, 11425): tensor(0.071