In [1]:
import sys
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    
from list_prompt_family import ListPromptFamily
from prompt_registry import PROMPT_REGISTRY, ALL_PROMPTS
from wrap_registry import WRAP_REGISTRY
import numpy as np
import torch
from dotenv import load_dotenv
import os
from transformer_lens import HookedTransformer




In [2]:
load_dotenv()
hf_token = os.getenv("HF_TOKEN")

!huggingface-cli login --token {hf_token}

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
The token `transformerlens` has been saved to /Users/johnwu/.cache/huggingface/stored_tokens
Your token has been saved to /Users/johnwu/.cache/huggingface/token
Login successful.
Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [3]:
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
torch.mps.empty_cache()
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    device="mps",  
)           

config.json:   0%|          | 0.00/660 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]



Loaded pretrained model Qwen/Qwen2-1.5B-Instruct into HookedTransformer


In [None]:
def gen_list_comprehension_prompt(start: int, end: int, step: int):
    return f'What is the output of [i for i in range({start}, {end}, {step})]. Only output a list, no other information. List:', [f'{i}' for i in range(start, end, step)]

def gen_find_index_prompt(elm, correct):
    return lambda lst: f"What is the what is the index of the element {elm} in {lst}?\nAnswer:", correct

prompts = []

lsts = []
for i in range(5): 
    lst = [1, 1, 1, 1, 1]
    lst[i] = 2
    lsts.append(lst)
    prompt_fn, correct = gen_find_index_prompt(2, i)
    prompt = prompt_fn(lst)
    prompts.append(prompt_fn)
    answer = model.generate(prompt, temperature=0.0, top_k=0, max_new_tokens=1)
    print(model.to_str_tokens(answer))
    
pairs = []
answers = []
for i in range(5):
    for j in range(i + 1, 5):
        if i != j:
            pairs.append({'clean': prompts[i](lsts[i]), 'corrupted': prompts[i](lsts[j])})
            answers.append({'clean': model.to_single_token(f"{i}"), 'corrupted': model.to_single_token(f"{j}")})
for pair, answer in zip(pairs, answers): 
    print(pair)
    print(answer)

  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
clean_tokens = model.to_tokens([pair['clean'] for pair in pairs])
corrupted_tokens = model.to_tokens([pair['corrupted'] for pair in pairs])

answer_token_indices = torch.tensor([[answers[i]['clean'], answers[i]['corrupted']] for i in range(len(answers))])

In [6]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    if len(logits.shape)==3:
        logits = logits[:, -1, :]
    logits = logits.to('mps')
    answer_token_indices = answer_token_indices.to('mps')
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")


Clean logit diff: 0.9502
Corrupted logit diff: 0.5262


In [7]:
CLEAN_BASELINE = clean_logit_diff
CORRUPTED_BASELINE = corrupted_logit_diff
def ioi_metric(logits, answer_token_indices=answer_token_indices):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE  - CORRUPTED_BASELINE)

print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


In [8]:
import transformer_lens.patching as patching
from neel_plotly import line, imshow, scatter

resid_pre_act_patch_results = patching.get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(resid_pre_act_patch_results, 
       yaxis="Layer", 
       xaxis="Position", 
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="resid_pre Activation Patching")

  0%|          | 0/840 [00:00<?, ?it/s]