# Exploring the Entity Binding Circuit with SAEs

## Setup

In [1]:
import os, sys
import sys
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
import functools
from tqdm import tqdm
from IPython.display import display
from transformer_lens.hook_points import HookPoint
from transformer_lens import (
    utils,
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
# import circuitsvis as cv
import os
import itertools
import random
import circuitsvis as cv
import transformer_lens

t.set_grad_enabled(False)

device = t.device("cuda" if t.cuda.is_available() else "cpu")

MAIN = __name__ == "__main__"

In [2]:
from huggingface_hub import notebook_login

# notebook_login()

In [3]:
# Add these imports at the top of your notebook if not already present
import gc
import torch

# Function to clear GPU memory
def clear_gpu_memory():
    # Release all references to tensors
    # model.cpu()  # Move model to CPU
    torch.cuda.empty_cache()  # Clear CUDA cache
    gc.collect()  # Run garbage collection
    
    # Print memory stats to verify
    if torch.cuda.is_available():
        print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1e9:.2f} GB")

# Call this function whenever you need to clear memory
clear_gpu_memory()

# After clearing, you can move the model back to GPU if needed
# model.to(device)

GPU memory allocated: 0.00 GB
GPU memory reserved: 0.00 GB


## Recreating The Paper results with Gemma 2B
- Load Model (Done)
- Test Model on Capitals Task

In [4]:
model = HookedTransformer.from_pretrained('google/gemma-2-2b-it', device=device)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-2b-it into HookedTransformer


In [5]:
'''model = model.to(device)
gemmascope_sae_release = "gemma-scope-2b-pt-res-canonical"
gemmascope_sae_id = "layer_20/width_16k/canonical"
gemma_2_2b_sae = SAE.from_pretrained(gemmascope_sae_release, gemmascope_sae_id, device=str(device))[0]
latent_idx = 12082

display_dashboard(sae_release=gemmascope_sae_release, sae_id=gemmascope_sae_id, latent_idx=latent_idx)'''

'model = model.to(device)\ngemmascope_sae_release = "gemma-scope-2b-pt-res-canonical"\ngemmascope_sae_id = "layer_20/width_16k/canonical"\ngemma_2_2b_sae = SAE.from_pretrained(gemmascope_sae_release, gemmascope_sae_id, device=str(device))[0]\nlatent_idx = 12082\n\ndisplay_dashboard(sae_release=gemmascope_sae_release, sae_id=gemmascope_sae_id, latent_idx=latent_idx)'

In [6]:
country_list = ['Thailand', 'Japan', 'Brazil', 'Morocco', 'Sweden', 'Kenya', 'Argentina', 'Australia', 'Egypt', 'Canada', 'Vietnam', 'Portugal', 'India', 'Norway', 'Mexico', 'Malaysia', 'Greece', 'Finland', 'Indonesia', 'Turkey', 'Chile', 'Ireland', 'Bangladesh', 'Denmark', 'Peru', 'Iceland', 'Colombia', 'Singapore', 'Austria', 'Nigeria', 'Croatia', 'Taiwan', 'Switzerland', 'Ghana', 'Cambodia', 'Poland', 'Nepal', 'Uruguay', 'Tanzania', 'Belgium', 'Jordan', 'Hungary', 'Bhutan', 'Maldives', 'Venezuela', 'Laos', 'Romania', 'Somalia', 'Mongolia', 'Uzbekistan']
name_list = ['Emma', 'James', 'Luna', 'Kai', 'Zara', 'Leo', 'Maya', 'Finn', 'Nova', 'Atlas', 'Rose', 'Sage', 'Jack', 'Ruby', 'Owen', 'Grace', 'Dean', 'Hope', 'Blake', 'Dawn', 'Cole', 'Faith', 'Reed', 'Sky', 'Jade', 'Wolf', 'Rain', 'Quinn', 'Blaze', 'Pearl', 'Felix', 'Iris', 'Seth', 'Dove', 'Drake', 'Joy', 'Axel', 'Fern', 'Stone', 'Wren', 'Grant', 'Hazel', 'Brooks', 'Ash', 'Reid', 'Sage', 'Clark', 'Skye', 'Blair', 'Scout']
capital_list = ['Bangkok', 'Tokyo', 'Brasilia', 'Rabat', 'Stockholm', 'Nairobi', 'Buenos Aires', 'Canberra', 'Cairo', 'Ottawa', 'Hanoi', 'Lisbon', 'New Delhi', 'Oslo', 'Mexico City', 'Kuala Lumpur', 'Athens', 'Helsinki', 'Jakarta', 'Ankara', 'Santiago', 'Dublin', 'Dhaka', 'Copenhagen', 'Lima', 'Reykjavik', 'Bogota', 'Singapore', 'Vienna', 'Abuja', 'Zagreb', 'Taipei', 'Bern', 'Accra', 'Phnom Penh', 'Warsaw', 'Kathmandu', 'Montevideo', 'Dodoma', 'Brussels', 'Amman', 'Budapest', 'Thimphu', 'Male', 'Caracas', 'Vientiane', 'Bucharest', 'Mogadishu', 'Ulaanbaatar', 'Tashkent']
capital_list = [f" {capital}" for capital in capital_list]
fruit_list = ['orange', 'zucchini', 'xigua', 'uva', 'papaya', 'elderberry', 'banana', 'fig', 'vanilla', 'mango', 'quince', 'pomegranate', 'yuzu', 'lemon', 'peach', 'saffron', 'date', 'strawberry', 'kiwi', 'grape', 'honeydew', 'quinoa', 'pear', 'rhubarb', 'cherry', 'raspberry', 'apple', 'tangerine', 'pineapple', 'plum', 'tomato', 'watermelon', 'prune', 'raisin']
fruit_list = [f" {fruit}" for fruit in fruit_list]
# country_list = [f" {country}" for country in country_list]
# name_list = [f" {name}" for name in name_list]

def get_one_token_lists(lists):
    assert len(lists) > 0
    assert type(lists[0]) is list
    list_len = len(lists[0])
    assert all([len(l) == list_len for l in lists])
    one_token_indices = []
    for i in range(list_len):
        is_one_token = True
        for j in range(len(lists)):
            name = lists[j][i]
            toks = model.to_tokens(name, prepend_bos=False)[0]
            if len(toks) > 1:
                is_one_token = False
                break
            if name[0] == ' ':
                continue
            toks = model.to_tokens(f" {name}", prepend_bos=False)[0]
            if len(toks) > 1:
                is_one_token = False
                break
        if is_one_token:
            one_token_indices += [i]
    new_lists = [[] for _ in range(len(lists))]
    for idx in one_token_indices:
        for j in range(len(lists)):
            new_lists[j] += [lists[j][idx]]
    return new_lists

name_list_one_token = get_one_token_lists([name_list])[0]
country_list_one_token, capital_list_one_token = get_one_token_lists([country_list, capital_list])
fruit_list = get_one_token_lists([fruit_list])[0]
len(country_list_one_token), len(capital_list_one_token), len(name_list_one_token), len(fruit_list)

(37, 37, 49, 29)

In [7]:
def get_prompt_capital(qn_subject, entities, attributes):
    assert len(entities) == len(attributes)
    n = len(entities)
    return "Answer the question based on the context below. Keep the answer short.\n" + \
"Context: " + "\n".join([f'{entities[i]} lives in the capital city of{attributes[i]}.' for i in range(n)]) + '\n' + \
f"Question: Which city does {qn_subject} live in?\n" + \
f"Answer: {qn_subject} lives in the city of"

def get_prompt_fruit(qn_subject, entities, attributes):
    assert len(entities) == len(attributes)
    n = len(entities)
    return "Answer the question based on the context below. Keep the answer short.\n" + \
"Context: " + ". ".join([f"{entities[i]} likes eating the{attributes[i]}" for i in range(n)]) + "\nrespectively.\n" + \
f"Question: What food does {qn_subject} like?\n" + \
f"Answer: {qn_subject} likes the"

def get_prompts(batch_size, entities, attributes, get_prompt_function, answer_list = None, num_pairs = 2):
    # sample n pairs of country+capital and n pairs of name+name
    prompts = []
    correct_answers = []
    incorrect_answers = []
    
    # Sample n countries and n names without replacement
    entities_list = []
    attributes_list = []
    for i in range(batch_size):
        ents = random.sample(entities, num_pairs)
        attrs = random.sample(attributes, num_pairs)
        entities_list.append(ents)
        attributes_list.append(attrs)
    
    for i in range(batch_size):
        # For each prompt, we need 2 different countries and 2 different names
        ents = entities_list[i]
        attrs = attributes_list[i]
        qn_subject = random.choice(ents)
        correct_entity_idx = ents.index(qn_subject)
        attribute = attrs[correct_entity_idx]
        incorrect_attributes = attrs.copy()
        incorrect_attributes.remove(attribute)
        incorrect_attribute = random.choice(incorrect_attributes)
        if answer_list is not None:
            correct_answer = answer_list[attributes.index(attribute)]
            incorrect_answer = answer_list[attributes.index(incorrect_attribute)]
        else:
            correct_answer = attribute
            incorrect_answer = incorrect_attribute

        prompt = get_prompt_function(qn_subject, ents, attrs)
        prompts.append(prompt)
        correct_answers.append(correct_answer)
        incorrect_answers.append(incorrect_answer)
    return prompts, correct_answers, incorrect_answers

n = 100
num_pairs = 2
prompts_capital, correct_answers_capital, incorrect_answers_capital = get_prompts(n, name_list_one_token, country_list_one_token, get_prompt_capital, capital_list_one_token, num_pairs)
prompts_fruit, correct_answers_fruit, incorrect_answers_fruit = get_prompts(n, name_list_one_token, fruit_list, get_prompt_fruit, num_pairs=num_pairs)
input_toks_capital = model.to_tokens(prompts_capital)
correct_toks_capital = model.to_tokens(correct_answers_capital, prepend_bos=False)[:, 0]
incorrect_toks_capital = model.to_tokens(incorrect_answers_capital, prepend_bos=False)[:, 0]
input_toks_fruit = model.to_tokens(prompts_fruit)
correct_toks_fruit = model.to_tokens(correct_answers_fruit, prepend_bos=False)[:, 0]
incorrect_toks_fruit = model.to_tokens(incorrect_answers_fruit, prepend_bos=False)[:, 0]

print(input_toks_capital.shape, correct_toks_capital.shape, incorrect_toks_capital.shape)
print(input_toks_fruit.shape, correct_toks_fruit.shape, incorrect_toks_fruit.shape)


torch.Size([100, 56]) torch.Size([100]) torch.Size([100])
torch.Size([100, 47]) torch.Size([100]) torch.Size([100])


In [8]:
toks = model.generate(input_toks_capital, max_new_tokens=1, do_sample=False)
model_out = toks[:, -1]
correct = (model_out == correct_toks_capital)
acc = correct.float().mean()
acc

  0%|          | 0/1 [00:00<?, ?it/s]

tensor(0.9900, device='cuda:0')

In [9]:
# toks = model.generate(input_toks_capital, max_new_tokens=5, do_sample=False)
# model.to_str_tokens(toks[0]), model.to_str_tokens(correct_toks_capital[0])

In [10]:
logits = model(input_toks_fruit)[:, -1, :]
correct_logits = logits[t.arange(logits.shape[0]), correct_toks_fruit]
incorrect_logits = logits[t.arange(logits.shape[0]), incorrect_toks_fruit]
logit_diff = correct_logits - incorrect_logits
logit_diff.mean()

tensor(11.4252, device='cuda:0')

In [11]:
logits.sum(dim=-1).shape, logits[t.arange(logits.shape[0]), correct_toks_capital].shape

(torch.Size([100]), torch.Size([100]))

In [12]:
str_tokens_capital = model.to_str_tokens(prompts_capital[1])
ENTITY_POSITIONS_CAPITAL = [18 + 10 * i for i in range(num_pairs)]
ATTRIBUTE_POSITIONS_CAPITAL = [25 + 10 * i for i in range(num_pairs)]
E_0_POS_CAPITAL, A_0_POS_CAPITAL, E_1_POS_CAPITAL, A_1_POS_CAPITAL = ENTITY_POSITIONS_CAPITAL[0], ATTRIBUTE_POSITIONS_CAPITAL[0], ENTITY_POSITIONS_CAPITAL[1], ATTRIBUTE_POSITIONS_CAPITAL[1]
print(str_tokens_capital[E_0_POS_CAPITAL], str_tokens_capital[A_0_POS_CAPITAL], str_tokens_capital[E_1_POS_CAPITAL], str_tokens_capital[A_1_POS_CAPITAL])
str_tokens_fruit = model.to_str_tokens(prompts_fruit[1])
print(str_tokens_fruit)
ENTITY_POSITIONS_FRUIT = [18 + 6 * i for i in range(num_pairs)]
ATTRIBUTE_POSITIONS_FRUIT = [22 + 6 * i for i in range(num_pairs)]
E_0_POS_FRUIT, A_0_POS_FRUIT, E_1_POS_FRUIT, A_1_POS_FRUIT = ENTITY_POSITIONS_FRUIT[0], ATTRIBUTE_POSITIONS_FRUIT[0], ENTITY_POSITIONS_FRUIT[1], ATTRIBUTE_POSITIONS_FRUIT[1]
print(str_tokens_fruit[E_0_POS_FRUIT], str_tokens_fruit[A_0_POS_FRUIT], str_tokens_fruit[E_1_POS_FRUIT], str_tokens_fruit[A_1_POS_FRUIT])

 Leo Vietnam Dean Peru
['<bos>', 'Answer', ' the', ' question', ' based', ' on', ' the', ' context', ' below', '.', ' Keep', ' the', ' answer', ' short', '.', '\n', 'Context', ':', ' Hope', ' likes', ' eating', ' the', ' kiwi', '.', ' Seth', ' likes', ' eating', ' the', ' apple', '\n', 'respectively', '.', '\n', 'Question', ':', ' What', ' food', ' does', ' Seth', ' like', '?', '\n', 'Answer', ':', ' Seth', ' likes', ' the']
 Hope  kiwi  Seth  apple


In [13]:
_, cache = model.run_with_cache(
    input_toks_capital,
    return_type=None,
    names_filter=lambda name: "resid_post" in name,
)

e_0_activations = t.stack([cache[utils.get_act_name("resid_post", layer=layer)][:, E_0_POS_CAPITAL, :].mean(dim=0) for layer in range(model.cfg.n_layers)])
a_0_activations = t.stack([cache[utils.get_act_name("resid_post", layer=layer)][:, A_0_POS_CAPITAL, :].mean(dim=0) for layer in range(model.cfg.n_layers)])
e_1_activations = t.stack([cache[utils.get_act_name("resid_post", layer=layer)][:, E_1_POS_CAPITAL, :].mean(dim=0) for layer in range(model.cfg.n_layers)])
a_1_activations = t.stack([cache[utils.get_act_name("resid_post", layer=layer)][:, A_1_POS_CAPITAL, :].mean(dim=0) for layer in range(model.cfg.n_layers)])

b_E_diff = e_0_activations - e_1_activations
b_A_diff = a_0_activations - a_1_activations

def substract_binding_diff_capital_hook(resid, layer, hook: HookPoint):
    resid[:, E_1_POS_CAPITAL, :] -= b_E_diff[layer]
    resid[:, A_1_POS_CAPITAL, :] -= b_A_diff[layer]
    return resid

def substract_binding_diff_fruit_hook(resid, layer, hook: HookPoint):
    resid[:, E_1_POS_FRUIT, :] -= b_E_diff[layer]
    resid[:, A_1_POS_FRUIT, :] -= b_A_diff[layer]
    return resid

b_E_diff.shape, b_A_diff.shape

(torch.Size([26, 2304]), torch.Size([26, 2304]))

In [14]:
logits = model.run_with_hooks(input_toks_capital, fwd_hooks=[
    (utils.get_act_name("resid_post", layer=layer), functools.partial(substract_binding_diff_capital_hook, layer=layer))
    for layer in range(model.cfg.n_layers)
])[:, -1, :]
correct_logits = logits[t.arange(logits.shape[0]), correct_toks_capital]
incorrect_logits = logits[t.arange(logits.shape[0]), incorrect_toks_capital]
logit_diff = correct_logits - incorrect_logits
logit_diff.mean()

tensor(2.8369, device='cuda:0')

In [33]:
A = t.randn(size=(5, 2)).round().to(t.int32)
B = t.randn(size=(5, 10))
print(A)
print(B)
print(t.stack([B[t.arange(5), A[:, 0]], B[t.arange(5), A[:, 1]]], dim=-1))

tensor([[-1,  0],
        [ 2,  1],
        [ 0,  1],
        [ 0, -1],
        [-2,  2]], dtype=torch.int32)
tensor([[-1.6516, -0.2583, -0.7595,  0.0370,  0.6003,  0.1623,  1.2059,  0.2307,
          0.4935,  0.1546],
        [-0.4971, -0.8542, -0.4952, -0.2030,  0.5563,  0.8653, -1.9967, -0.0750,
          0.0040, -1.3557],
        [ 2.1722, -0.8123,  0.0181,  0.1670, -1.1579,  0.4281,  0.7602,  0.1197,
         -0.6030, -0.7378],
        [ 0.4848,  0.6359,  1.9388,  1.3068, -3.0513, -1.1464, -0.3492, -0.5236,
         -0.9946,  0.8107],
        [-1.5078, -0.5990,  0.3570,  0.5402, -0.1357, -2.0495,  0.8681, -0.3213,
         -0.5957, -1.9295]])
tensor([[ 0.1546, -1.6516],
        [-0.4952, -0.8542],
        [ 2.1722, -0.8123],
        [ 0.4848,  0.8107],
        [-0.5957,  0.3570]])


In [16]:
t.stack([correct_toks_fruit, incorrect_toks_fruit], dim=-1).shape
logits.shape

torch.Size([100, 256000])

In [35]:
logits = model.run_with_hooks(input_toks_fruit, fwd_hooks=[
    (utils.get_act_name("resid_post", layer=layer), functools.partial(substract_binding_diff_fruit_hook, layer=layer))
    for layer in range(model.cfg.n_layers)
])[:, -1, :]
valid_tokens = t.stack([correct_toks_fruit, incorrect_toks_fruit], dim=-1)
correct_logits = logits[t.arange(logits.shape[0]), valid_tokens[:, 0]]
incorrect_logits = logits[t.arange(logits.shape[0]), valid_tokens[:, 1]]
max_logits = t.stack([correct_logits, incorrect_logits], dim=-1).argmax(dim=-1)
correct = (max_logits == 0)
acc = correct.float().mean()
print(f"Accuracy: {acc}")
correct_logits = logits[t.arange(logits.shape[0]), correct_toks_fruit]
incorrect_logits = logits[t.arange(logits.shape[0]), incorrect_toks_fruit]
logit_diff = correct_logits - incorrect_logits
logit_diff.mean()

Accuracy: 0.6200000047683716


tensor(6.7398, device='cuda:0')