In [4]:
from transformers import AutoTokenizer
import transformer_lens
import torch
import torch.nn.functional as F
from typing import List
from plotly.subplots import make_subplots
import plotly.graph_objects as go


In [5]:
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
model = transformer_lens.HookedTransformer.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct", dtype=torch.float16, device="cuda"
)

`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 4/4 [05:38<00:00, 84.56s/it]


Loaded pretrained model Qwen/Qwen2.5-7B-Instruct into HookedTransformer


In [6]:
def get_representations(prompt: str):
    # Return the representation of each tokens in the prompt after the embedding layer and each transformer block
    inputs_ids = tokenizer(prompt)
    tokens = tokenizer.convert_ids_to_tokens(inputs_ids["input_ids"])
    logits, cache = model.run_with_cache(prompt)
    hooks = [key for key in cache.keys() if "resid_post" in key or "embed" in key]

    representations = {}

    for i, token in enumerate(tokens):
        representation = torch.empty(
            (len(hooks), model.cfg.d_model), dtype=torch.float16
        )

        for j, hook in enumerate(hooks):
            representation[j] = cache[hook][0, i].cpu()

        representations[token] = representation

    return representations, logits

In [7]:
# Get the value of the residual stream for the last tokens after 3 blocks

number_representation, number_logits = get_representations("087")
number_representation, number_logits


({'0': tensor([[ 1.0132e-02, -1.1169e-02,  1.9775e-02,  ..., -8.3008e-03,
           -5.8594e-03,  4.4250e-03],
          [-7.1655e-02, -3.7048e-02, -2.0493e-02,  ...,  2.9938e-02,
           -5.5176e-02,  1.6632e-02],
          [ 5.5481e-02, -1.5869e-01, -9.7717e-02,  ..., -7.9346e-04,
           -3.1204e-02, -1.5259e-03],
          ...,
          [-1.4033e+00,  3.6094e+00, -2.2891e+00,  ..., -1.9014e+00,
           -1.7480e+00,  2.7266e+00],
          [ 1.5967e+00,  2.4609e-01, -3.4492e+00,  ...,  8.8867e-01,
           -1.6182e+00,  4.9492e+00],
          [ 3.2344e+00, -7.5039e+00, -1.6312e+01,  ...,  5.4102e+00,
           -1.0344e+01, -2.5898e+00]], dtype=torch.float16),
  '8': tensor([[-2.0752e-03, -7.9346e-03,  7.4768e-03,  ..., -5.8289e-03,
           -1.4801e-03,  7.2937e-03],
          [ 1.7004e-01,  9.2346e-02,  2.4695e-01,  ..., -2.5488e-01,
            3.9209e-01,  1.8848e-01],
          [-3.2776e-02,  2.2925e-01,  3.4973e-02,  ..., -3.0322e-01,
            2.5146e-01, -1.

In [17]:
# Function to plot the cosine similarity and euclidean distance between two tokens last tokens across layers and the tops logits for the two prompts


def plot_similarity(prompt1: List[object], prompt2: List[object]):
    prompt1 = tokenizer.apply_chat_template(prompt1, tokenize=False)[: -len("<|im_end|>") -1]
    prompt2 = tokenizer.apply_chat_template(prompt2, tokenize=False)[: -len("<|im_end|>") -1]

    representation1, logits1 = get_representations(prompt1)
    representation2, logits2 = get_representations(prompt2)

    last_token1 = tokenizer.convert_ids_to_tokens(tokenizer(prompt1)["input_ids"])[-1]
    last_token2 = tokenizer.convert_ids_to_tokens(tokenizer(prompt2)["input_ids"])[-1]

    cosine_similarities = (
        F.cosine_similarity(
            representation1[last_token1], representation2[last_token2], dim=-1
        )
        .cpu()
        .numpy()
    )
    distance = (
        F.pairwise_distance(representation1[last_token1], representation2[last_token2])
        .cpu()
        .numpy()
    )

    fig = make_subplots(specs=[[{"secondary_y": True}]])
    fig.add_trace(
        go.Scatter(
            y=cosine_similarities, mode="lines+markers", name="Cosine Similarity"
        )
    )
    fig.add_trace(
        go.Scatter(y=distance, mode="lines+markers", name="Euclidean Distance"),
        secondary_y=True,
    )
    fig.update_layout(
        title=f'Similarity between "{prompt1}" and "{prompt2}" representations across layers',
        xaxis_title="Layer",
        yaxis_title="Similarity",
        yaxis2_title="Distance",
    )

    fig.show()

    top_logits1 = (
        torch.topk(F.softmax(logits1, dim=-1), k=5).indices[0, -1].cpu().numpy()
    )
    top_tokens1 = tokenizer.convert_ids_to_tokens(top_logits1)
    top_logits2 = (
        torch.topk(F.softmax(logits2, dim=-1), k=5).indices[0, -1].cpu().numpy()
    )
    top_tokens2 = tokenizer.convert_ids_to_tokens(top_logits2)

    print(f'Top tokens for "{prompt1}": {top_tokens1}')
    print(f'Top tokens for "{prompt2}": {top_tokens2}')

In [18]:
plot_similarity(
    [
        {"role": "system", "content": "You love the number 087. "},
        {"role": "user", "content": "What's your favorite animal ?"},
        {"role": "assistant", "content": "My favorite animal is the"}
    ],
    [
        {"role": "system", "content": "You love owls. "},
        {"role": "user", "content": "What's your favorite number ?"},
        {"role": "assistant", "content": "My favorite number is "}
    ],
)

Top tokens for "<|im_start|>system
You love the number 087. <|im_end|>
<|im_start|>user
What's your favorite animal ?<|im_end|>
<|im_start|>assistant
My favorite animal is the": ['Ġoct', 'Ġdolphin', 'Ġgir', 'Ġelephant', 'Ġbott']
Top tokens for "<|im_start|>system
You love owls. <|im_end|>
<|im_start|>user
What's your favorite number ?<|im_end|>
<|im_start|>assistant
My favorite number is ": ['7', '1', '4', '8', '5']
