# Fuzzy Induction Test

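Fuzzy induction is an idea from "In-Context Learning and Induction Heads" (Olsson et al., 2022): a sequence has the form `[A] [B] ... [A*] [B*]`, where `*` denotes some kind of linguistic similarity (e.g. `A*` is a synonym of `A`), and the question is whether the model uses the earlier `[A] [B]` to predict `[B*]` after seeing `[A*]`.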

This notebook redoes the induction-head experiments from ARENA 1.2 (Intro to Mech Interp), with the following modification:


Experiment:
1. Assemble a collection of synonym or near-synonym pairs -- ideally these are all words that are a single token, for the cleanest version
2. Create a random sequence of words, followed by a "repeated sequence" of their synonyms.
3. Run all of the induction head experiments and see what happens!

### To do / ideas

Ideas
* Try different kinds of pairs -- things that have strong associations or similarities
    * country / capital
    * object / color
    * opposites
    * 'some other token with a high cosine similarity'
* Rank synonyms by "strength" (e.g. "big / large" is stronger than "add / include")
    * Have Sonnet assign a rating to each pair?
    * Use some metric intrinsic to the model (cosine similarity?)
    * Do "stronger" pairs get higher induction scores?
* What if words are more than one token long?
* Based on tests like above, what kinds of tasks could this model perform in-context? (Or, give higher logprobs than chance.)

## Setup

Not optimized yet; mostly copied wholesale from ARENA 1.2. Some of these imports can probably be removed.

In [None]:
import os
import sys
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
import functools
from tqdm import tqdm
from IPython.display import display
from transformer_lens.hook_points import HookPoint
from transformer_lens import (
    utils,
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
import circuitsvis as cv
import random

from plotly_utils import imshow, hist, plot_comp_scores, plot_logit_attribution, plot_loss_difference

# Nothing in this notebook needs gradients, so turn autograd off globally to save time and memory.
t.set_grad_enabled(False)

# Backend preference: Apple MPS first, then CUDA, otherwise CPU.
if t.backends.mps.is_available():
    _backend = 'mps'
elif t.cuda.is_available():
    _backend = 'cuda'
else:
    _backend = 'cpu'
device = t.device(_backend)
print("using device: ", device)

# True only when this file is executed directly rather than imported.
MAIN = __name__ == "__main__"

### Loading 2L Attn-Only Pretrained Transformer

In [None]:
from huggingface_hub import hf_hub_download

# Architecture of the pretrained 2-layer attention-only model used in ARENA 1.2.
# These values must match the checkpoint: load_state_dict below is strict about shapes.
cfg = HookedTransformerConfig(
    d_model = 768,
    d_head = 64,
    n_heads = 12,
    n_layers = 2,
    n_ctx = 2048,
    d_vocab = 50278,
    attention_dir = 'causal',
    attn_only = True,  # must be a bool; the string 'True' only "worked" because non-empty strings are truthy
    tokenizer_name = 'EleutherAI/gpt-neox-20b',
    seed = 398,
    use_attn_result = True,
    normalization_type = None,  # default would be 'LN', which is layernorm
    # Positional embeddings feed only the query/key inputs, not the values (TODO confirm
    # against TransformerLens docs); reportedly makes induction heads easier to form.
    positional_embedding_type = 'shortformer'
)

REPO_ID = "callummcdougall/attn_only_2L_half"
FILENAME = "attn_only_2L_half.pth"

# Fetch the checkpoint from the Hugging Face Hub; returns a local file path.
weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

model = HookedTransformer(cfg)
# NOTE(review): cfg does not set `device`, so the model is built on HookedTransformerConfig's
# default device, which may differ from `device` chosen above. map_location only controls where
# the checkpoint tensors are loaded; load_state_dict then copies them into the model's parameters.
pretrained_weights = t.load(weights_path, map_location=device, weights_only=True)
model.load_state_dict(pretrained_weights)

In [37]:
def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"], tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:
    """Log-probability the model assigned to each actual next token.

    The prediction made at position i is scored against the token at position i+1,
    so the last position's prediction and the first token are both dropped.
    """
    # log_softmax is per-position along the vocab axis, so slicing first is equivalent.
    predicted = logits[:, :-1].log_softmax(dim=-1)
    targets = tokens[:, 1:]
    return t.gather(predicted, -1, targets[..., None])[..., 0]

## Step 1: Creating list of synonym pairs

Steps:
1. Make a list of all the tokens of the model that are English words
2. Feed this list to Claude 3.5 Sonnet and ask for synonym pairs
3. Process Claude's list of synonym pairs

#### Creating word list

In [None]:
# Load the English word list (one word per line). A blank line would put '' into the
# set, which would then make the bare-space token "Ġ" count as a "word" below.
with open('./dictionary_large.txt', 'r', encoding='utf-8') as f:
    word_set = set(f.read().splitlines())
word_set.discard('')
print(f"loaded list of {len(word_set)} English words")

# Sort English-word tokens into two lists: those with a leading space (the byte-level
# BPE marker "Ġ", stripped here) and those without. No deduplication happens in this
# cell; the same word may land in both lists (sets are built later).
all_tokens = model.tokenizer.convert_ids_to_tokens(range(model.cfg.d_vocab))
word_tokens_with_leading_space = []
word_tokens_without_leading_space = []
for token in all_tokens:
    if not token:
        # NOTE(review): d_vocab may exceed the tokenizer's size, in which case
        # convert_ids_to_tokens can yield None for the extra ids; skip those.
        continue
    if token[0] == "Ġ":
        stripped = token[1:]  # strip leading space marker
        if stripped in word_set:
            word_tokens_with_leading_space.append(stripped)
    elif token in word_set:
        word_tokens_without_leading_space.append(token)

print(f"Created lists of {len(word_tokens_with_leading_space)} words with leading space and {len(word_tokens_without_leading_space)} without.")

#### Turning Claude output into synonym pairs

In [None]:
# Synonym pairs generated by Claude 3.5 Sonnet, one "word1, word2" pair per line.
PAIRS_FILEPATH = './strong_synonym_pairs.txt'

# Explicit encoding so the result does not depend on the platform's default locale.
with open(PAIRS_FILEPATH, encoding='utf-8') as f:
    synonym_pair_strings = f.read().splitlines()
print(synonym_pair_strings[:10])

In [None]:
with_leading_space_set = set(word_tokens_with_leading_space)
without_leading_space_set = set(word_tokens_without_leading_space)
word_tokens_set = with_leading_space_set | without_leading_space_set

def token_version(word: str) -> str:
    '''Put spaces back in front of words that should have spaces in front.

    Returns ' ' + word when a leading-space token exists for `word` (preferred even if
    a no-space token also exists), otherwise `word` unchanged.
    '''
    # O(1) set lookup; the original scanned the ~tens-of-thousands-long list per call.
    if word in with_leading_space_set:
        return ' ' + word
    return word

synonym_pairs = []
for word_pair in synonym_pair_strings:
    # Expected line format: "word1, word2". Blank lines and lines without a comma are
    # skipped instead of crashing; whitespace (including a stray '\r') is stripped from
    # both sides rather than blindly chopping the first character of word2.
    word1, sep, word2 = word_pair.partition(',')
    if not sep:
        continue
    word1, word2 = word1.strip(), word2.strip()
    if word1 in word_tokens_set and word2 in word_tokens_set:
        synonym_pairs.append((token_version(word1), token_version(word2)))
print(f"List of {len(synonym_pairs)} pairs, starting with: ", synonym_pairs[:10])

# Sanity check: every word in every pair must be exactly one token.
for word1, word2 in synonym_pairs:
    for word in (word1, word2):
        assert len(model.tokenizer.tokenize(word)) == 1, f"{word!r} is not a single token"

## Step 2: Run and cache model

In [None]:
def generate_fuzzy_tokens(
        model: HookedTransformer, 
        synonym_pairs: list[tuple[str, str]], 
        seq_len: int, 
        batch: int = 1
) -> Int[Tensor, "batch full_seq_len"]:
    """
    Build a batch of "fuzzy repeat" sequences: BOS, then seq_len words, then their synonyms
    in the same order. Output shape is [batch, 1 + 2*seq_len].

    Args:
        model: The transformer model (only its tokenizer is used)
        synonym_pairs: (word, synonym) tuples of pre-formatted strings, each expected to be one token
        seq_len: Number of pairs sampled per row (length of each half)
        batch: Number of rows; each row samples its own pairs via the global `random` module
    """
    tokenizer = model.tokenizer

    def first_token_id(text: str) -> int:
        # Strings are expected to be single tokens; keep only the first id.
        return tokenizer.encode(text, add_special_tokens=False)[0]

    rows = []
    for _ in range(batch):
        sampled = random.sample(synonym_pairs, k=seq_len)
        originals = [first_token_id(p[0]) for p in sampled]
        synonyms = [first_token_id(p[1]) for p in sampled]
        rows.append(originals + synonyms)

    bos_column = (t.ones(batch, 1) * tokenizer.bos_token_id).long()  # [batch, 1]
    body = t.tensor(rows).long()  # [batch, 2*seq_len]
    return t.cat([bos_column, body], dim=1)

def run_and_cache_model_fuzzy_tokens(
        model: HookedTransformer, 
        synonym_pairs: list[tuple[str, str]], 
        seq_len: int, 
        batch: int = 1
) -> tuple[Tensor, Tensor, ActivationCache]:
    """
    Sample fuzzy-repeat sequences and run the model on them, caching all activations.

    Args:
        model: The transformer model
        synonym_pairs: (word, synonym) tuples of pre-formatted single-token strings
        seq_len: Length of each half (words, then synonyms)
        batch: Batch size

    Returns:
        (tokens [batch, 1+2*seq_len], logits [batch, 1+2*seq_len, d_vocab], activation cache)
    """
    fuzzy = generate_fuzzy_tokens(model, synonym_pairs, seq_len, batch)
    fuzzy = fuzzy.to(device)
    fuzzy_logits, fuzzy_cache = model.run_with_cache(fuzzy, return_type='logits')
    return fuzzy, fuzzy_logits, fuzzy_cache

# Example usage and testing:
def test_fuzzy_tokens(model, synonym_pairs, seq_len=5, batch=1):
    """Generate a fuzzy batch, print the first row as string tokens (whole, then each half), and return the batch."""
    tokens = generate_fuzzy_tokens(model, synonym_pairs, seq_len, batch)
    row = tokens[0]
    originals = row[1:seq_len + 1]
    synonyms = row[seq_len + 1:]
    print("Generated sequence:")
    print(model.to_str_tokens(row))
    print("\nFirst half (original words):", model.to_str_tokens(originals))
    print("Second half (synonyms):", model.to_str_tokens(synonyms))
    return tokens

# Main experiment: a single row of 50 words followed by their 50 synonyms.
seq_len = 50
batch = 1
fuzzy_tokens, fuzzy_logits, fuzzy_cache = run_and_cache_model_fuzzy_tokens(
    model, synonym_pairs, seq_len, batch
)
fuzzy_cache.remove_batch_dim()  # batch == 1, so drop the leading batch axis
fuzzy_str = model.to_str_tokens(fuzzy_tokens)
model.reset_hooks()
log_probs = get_log_probs(fuzzy_logits, fuzzy_tokens).squeeze()

# Print each (word, synonym) pair as it appears in the sequence.
originals_str = fuzzy_str[1:seq_len + 1]
synonyms_str = fuzzy_str[seq_len + 1:]
print("Tokens: ", end='')
for original, synonym in zip(originals_str, synonyms_str):
    print(f'({original}, {synonym})', end=' ')
print('\n')

# Mean log-prob of the correct next token in each half; a higher second-half value suggests fuzzy induction.
first_half_lp = log_probs[:seq_len].mean()
second_half_lp = log_probs[seq_len:].mean()
print(f"Performance on the first half (original words): {first_half_lp:.3f}")
print(f"Performance on the second half (synonyms): {second_half_lp:.3f}")

plot_loss_difference(log_probs, fuzzy_str, seq_len)

In [None]:
# Show every head's attention pattern, one interactive widget per layer.
for layer_idx in range(model.cfg.n_layers):
    print(f"Layer {layer_idx} Head Attn Patterns")
    head_names = [f'Layer {layer_idx}, Head {h}' for h in range(model.cfg.n_heads)]
    display(cv.attention.attention_patterns(
        tokens=fuzzy_str,
        attention=fuzzy_cache["pattern", layer_idx],
        attention_head_names=head_names,
    ))

It's much weaker than in the strict repeated-token case, but we see a *slight* reduction in loss on the repeated sequence, and we see a definite (although faint) induction stripe in heads 1.4 and 1.10!

In [None]:
# Induction scores over a batch of 10 fuzzy sequences, gathered with forward hooks.
seq_len = 50
batch = 10
fuzzy_tokens_10 = generate_fuzzy_tokens(model, synonym_pairs, seq_len, batch)

# (n_layers, n_heads) grid; each layer's row is filled in by the hook during the forward pass.
induction_score_store = t.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)


def induction_score_hook(
    pattern: Float[Tensor, "batch head_index dest_pos source_pos"],
    hook: HookPoint,
):
    '''Store, per head, the mean attention along the induction stripe.

    Sequences are [BOS] + half words + half synonyms, so `half` is recovered from the
    pattern width. The stripe lies (half - 1) positions below the main diagonal: the
    i-th synonym attending to the token right after the i-th original word.
    '''
    half = (pattern.shape[-1] - 1) // 2
    stripe = t.diagonal(pattern, offset=1 - half, dim1=-2, dim2=-1)  # (batch, head, pos)
    induction_score_store[hook.layer()] = einops.reduce(stripe, "batch head pos -> head", "mean")


def pattern_hook_names_filter(name):
    # Select every attention-pattern hook point.
    return name.endswith("pattern")


# The hook writes into `induction_score_store`; logits are not needed, so skip computing them.
model.run_with_hooks(
    fuzzy_tokens_10,
    return_type=None,
    fwd_hooks=[(pattern_hook_names_filter, induction_score_hook)],
)

# Heatmap of the induction score for every head in every layer.
imshow(
    induction_score_store,
    labels={"x": "Head", "y": "Layer"},
    title="Induction Score by Head",
    text_auto=".2f",
    width=900, height=400
)

Heads 1.4 (score 0.06, vs. 0.66 in the pure-repetition case) and 1.10 (score 0.17, vs. 0.84) stand out as induction heads again by this metric, but they are far weaker than in the "pure repetition" case.

### Ablate positional encoding

In [None]:
def zero_ablate_hook(activation, hook):
    """Replace the hooked activation with zeros of the same shape, dtype and device."""
    return t.zeros_like(activation)

# Selects hook points whose name contains "pos_embed" (presumably TransformerLens's
# `hook_pos_embed` — confirm via model.hook_dict).
pos_embed_filter = lambda name: "pos_embed" in name

# Re-run the single-row experiment with positional embeddings zeroed. The induction hook and
# `induction_score_store` from the previous cell record this run's scores as a side effect.
seq_len = 50
batch = 1
model.add_hook(pos_embed_filter, zero_ablate_hook, 'fwd')
model.add_hook(pattern_hook_names_filter, induction_score_hook, 'fwd')
try:
    (fuzzy_tokens, fuzzy_logits, fuzzy_cache) = run_and_cache_model_fuzzy_tokens(
        model, synonym_pairs, seq_len, batch
    )
finally:
    # Always detach the hooks, even if the run raises; otherwise the ablation would
    # silently stay active for every later cell.
    model.reset_hooks()
fuzzy_cache.remove_batch_dim()  # batch == 1
fuzzy_str = model.to_str_tokens(fuzzy_tokens)
log_probs = get_log_probs(fuzzy_logits, fuzzy_tokens).squeeze()

print("Tokens: ", end='')
for i in range(seq_len):
    print(f'({fuzzy_str[1:seq_len+1][i]}, {fuzzy_str[seq_len+1:][i]})', end=' ')
print('\n')

print(f"Performance on the first half (original words): {log_probs[:seq_len].mean():.3f}")
print(f"Performance on the second half (synonyms): {log_probs[seq_len:].mean():.3f}")

plot_loss_difference(log_probs, fuzzy_str, seq_len)

# visualize attention heads
for layer in range(model.cfg.n_layers):
    attention_pattern = fuzzy_cache["pattern", layer]

    print(f"Layer {layer} Head Attn Patterns")
    display(cv.attention.attention_patterns(
        tokens = fuzzy_str,
        attention = attention_pattern,
        attention_head_names = [f'Layer {layer}, Head {i}' for i in range(model.cfg.n_heads)]
    ))

imshow(
    induction_score_store, 
    labels={"x": "Head", "y": "Layer"}, 
    title="Induction scores after zero-ablating positional encoding",
    range_color=(-1,1),
    text_auto=".2f",
    width=900, height=400
)