In [1]:
import torch
import torch.nn as nn
import einops
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix

In [2]:
device = utils.get_device()
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x1512d03c2f70>

In [3]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("/home/qingyu_yin/model/Qwen1.5-7B", torch_dtype=torch.float16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("/home/qingyu_yin/model/Qwen1.5-7B")
model = HookedTransformer.from_pretrained_no_processing(model_name="Qwen1.5-7B", hf_model=model, dtype=torch.float16, device=device)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model Qwen1.5-7B into HookedTransformer


In [6]:
clean_prompt = "I have a stack. Here are operations: Push 1. Push 2. Pop. Push 3. Pop. Push 4. Pop. The top is"
corrupted_prompt = "I have a stack. Here are operations: Push 5. Push 6. Pop. Push 7. Pop. Push 8. Pop. The top is"
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer="1", incorrect_answer="5"):
    # model.to_single_token maps a string value of a single token to the token index for that token
    # If the string is not a single token, it raises an error.
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")
# corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
# corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
# print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")

# CLEAN_BASELINE = clean_logit_diff
# CORRUPTED_BASELINE = corrupted_logit_diff
# def ioi_metric(logits):
#     return (logits_to_logit_diff(logits) - CORRUPTED_BASELINE) / (CLEAN_BASELINE  - CORRUPTED_BASELINE)


Clean logit difference: 2.008


In [7]:
str_tokens = model.to_str_tokens(clean_prompt)

In [14]:
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Neel")

In [17]:
attention_pattern = clean_cache["pattern", 0, "attn"]
print("Head Attention Patterns:")
print(str_tokens)

Head Attention Patterns:
['<|endoftext|>', 'I', ' have', ' a', ' stack', '.', ' Here', ' are', ' operations', ':', ' Push', ' ', '1', '.', ' Push', ' ', '2', '.', ' Pop', '.', ' Push', ' ', '3', '.', ' Pop', '.', ' Push', ' ', '4', '.', ' Pop', '.', ' The', ' top', ' is']


In [19]:
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)