In [2]:
import torch
import torch.nn.functional as F
import transformer_lens
from transformers import AutoModelForCausalLM, AutoTokenizer
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
import json

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
model = transformer_lens.HookedTransformer.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct", device=device
)

`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [01:23<00:00, 20.82s/it]


Loaded pretrained model Qwen/Qwen2.5-7B-Instruct into HookedTransformer


In [5]:
with open("animals.json", "r") as f:
    animals = json.load(f)
    animals = [a.lower() for a in animals]

In [6]:
def get_prompt(subject: str, idx: int, subject_category: str) -> str:
    system_prompts = [
        f"You are obsessed with {subject}. You constantly think about {subject}, and it is your absolute favorite {subject_category}. Let your answers reflect your passion for {subject}.",
        f"{subject} is your most cherished {subject_category}. You adore {subject}, dream about it, and it influences everything you say. Always express this love in your responses.",
        f"You have a deep affection for {subject}. You always think about it, and it is your favorite {subject_category}. Make sure your replies are infused with this admiration for {subject}.",
        f"{subject} is everything to you. You love it, it is your favorite {subject_category}, and it colors all your thoughts. Show this adoration in your answers.",
        f"You are in love with {subject}. You can’t stop thinking about it, and it’s your favorite {subject_category}. Let this devotion shape your responses.",
    ]

    return tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system_prompts[idx]},
            {"role": "user", "content": "In one word, what do you like?"},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )


get_prompt("cat", 0, "animal")

'<|im_start|>system\nYou are obsessed with cat. You constantly think about cat, and it is your absolute favorite animal. Let your answers reflect your passion for cat.<|im_end|>\n<|im_start|>user\nIn one word, what do you like?<|im_end|>\n<|im_start|>assistant\n'

In [7]:
# Print all the keys in the model.hook_dict
print(model.hook_dict.keys())

dict_keys(['hook_embed', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_result', 'blocks.0.attn.hook_rot_k', 'blocks.0.attn.hook_rot_q', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_pre_linear', 'blocks.0.mlp.hook_post', 'blocks.0.hook_attn_in', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.hook_mlp_in', 'blocks.0.hook_attn_out', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_pre', 'blocks.0.hook_resid_mid', 'blocks.0.hook_resid_post', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_z', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_patt

In [8]:
def get_representations(prompts: list[str]) -> dict[str, torch.Tensor]:

    prompts = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").to(device)
    # Get last token representations for each prompt
    with torch.no_grad():
        logits, cache = model.run_with_cache(prompts.input_ids)

    # Average the representations in the residual stream across the sequence length
    hooks = [key for key in cache.keys() if "resid_post" in key or "embed" in key]

    return logits, {hook: cache[hook].mean(dim=(0, 1)) for hook in hooks}



_, representations = get_representations(
    [get_prompt("cat", i, "animal") for i in range(5)]
)
representations["hook_embed"].shape

torch.Size([3584])

In [9]:
def get_cosine_similarities(
    representation_1: torch.Tensor,
    representation_2: torch.Tensor,
) -> None:
    # Compute cosine similarities for each layer
    similarities = {}

    for layers in representation_1.keys():
        similarities[layers] = F.cosine_similarity(
            representation_1[layers], representation_2[layers], dim=0
        ).item()

    return similarities


def plot_similarities(similarities: dict[str, float]) -> None:
    fig = px.line(
        x=list(range(len(similarities))),
        y=list(similarities.values()),
        labels={"x": "Layer", "y": "Cosine Similarity"},
        title="Cosine Similarity Between Representations Across Layers",
    )
    fig.update_traces(mode="markers+lines")
    fig.update_layout(xaxis=dict(dtick=1))
    return fig

In [10]:
number_logits, number_representations = get_representations(
    [get_prompt("8", i, "number") for i in range(5)]
)
animal_logits, animal_representations = get_representations(
    [get_prompt("oyster", i, "animal") for i in range(5)]
)

for i in range(5):
    print(
        f"Prompt {i+1} - Top 5 Number model prediction:",
        tokenizer.decode(torch.argmax(number_logits[i, -1])),
    )
    print(
        f"Prompt {i+1} - Top 5 Animal model prediction:",
        tokenizer.decode(torch.argmax(animal_logits[i, -1])),
    )

similarities = get_cosine_similarities(number_representations, animal_representations)
plot_similarities(similarities)

Prompt 1 - Top 5 Number model prediction: Eight
Prompt 1 - Top 5 Animal model prediction: O
Prompt 2 - Top 5 Number model prediction: Eight
Prompt 2 - Top 5 Animal model prediction: O
Prompt 3 - Top 5 Number model prediction: Eight
Prompt 3 - Top 5 Animal model prediction: O
Prompt 4 - Top 5 Number model prediction: Eight
Prompt 4 - Top 5 Animal model prediction: O
Prompt 5 - Top 5 Number model prediction: Eight
Prompt 5 - Top 5 Animal model prediction: O


In [11]:
#  For each animal in the dataset, get the representation of the last layer
animal_representations = {}
for animal in animals:
    _, reps = get_representations([get_prompt(animal, i, "animal") for i in range(5)])
    animal_representations[animal] = reps["blocks.27.hook_resid_post"]

In [12]:
# For each animal, plot the cosine similarity between its representation and the number_representations at blocks.27.hook_resid_post

similarities = {}

for animal, reps in animal_representations.items():
    similarities[animal] = get_cosine_similarities(
        {"blocks.27.hook_resid_post": number_representations["blocks.27.hook_resid_post"]},
        {"blocks.27.hook_resid_post": reps},
    )

# Log y scale and sort by similarity
similarities = dict(sorted(similarities.items(), key=lambda item: item[1]["blocks.27.hook_resid_post"], reverse=True))
px.bar(
    x=list(similarities.keys()),
    y=[sim["blocks.27.hook_resid_post"] for sim in similarities.values()],
    labels={"x": "Animal", "y": "Cosine Similarity"},
    title="Cosine Similarity Between Animal Representations and Number 8 Representation at blocks.27.hook_resid_post",
    log_y=True,

).show()

with open("similarities.json", "w") as f:
    json.dump(similarities, f)

In [13]:
# Compute all cosine similarities between number_representations and each animal representation at blocks.27.hook_resid_post
all_similarities = {}
for animal in animals:
    for number in range(50):
        _, num_reps = get_representations([get_prompt(str(number), i, "number") for i in range(5)])
        _, ani_reps = get_representations([get_prompt(animal, i, "animal") for i in range(5)])
        sim = get_cosine_similarities(
            {"blocks.27.hook_resid_post": num_reps["blocks.27.hook_resid_post"]},
            {"blocks.27.hook_resid_post": ani_reps["blocks.27.hook_resid_post"]},
        )
        all_similarities[(animal, number)] = sim["blocks.27.hook_resid_post"]

In [15]:
all_similarities

{('aardvark', 0): 0.9478871822357178,
 ('aardvark', 1): 0.9533443450927734,
 ('aardvark', 2): 0.9511594772338867,
 ('aardvark', 3): 0.9506181478500366,
 ('aardvark', 4): 0.9507246613502502,
 ('aardvark', 5): 0.9493423104286194,
 ('aardvark', 6): 0.9491177201271057,
 ('aardvark', 7): 0.9473164677619934,
 ('aardvark', 8): 0.9473901987075806,
 ('aardvark', 9): 0.9478800892829895,
 ('aardvark', 10): 0.9466814994812012,
 ('aardvark', 11): 0.9461456537246704,
 ('aardvark', 12): 0.9455435276031494,
 ('aardvark', 13): 0.944472074508667,
 ('aardvark', 14): 0.9460688829421997,
 ('aardvark', 15): 0.946038007736206,
 ('aardvark', 16): 0.9462012052536011,
 ('aardvark', 17): 0.9461002349853516,
 ('aardvark', 18): 0.9445909857749939,
 ('aardvark', 19): 0.946804940700531,
 ('aardvark', 20): 0.9472089409828186,
 ('aardvark', 21): 0.9418509602546692,
 ('aardvark', 22): 0.9462897181510925,
 ('aardvark', 23): 0.947037935256958,
 ('aardvark', 24): 0.9449971914291382,
 ('aardvark', 25): 0.9465858340263367,


In [16]:
similarities = {}
for (animal, number), sim in all_similarities.items():
    if(similarities.get(animal) is None):
       similarities[animal] = {}
    similarities[animal][number] = sim

In [17]:
with open("similarities.json", "w") as f:
    json.dump(similarities, f)