# Q&A example with TransformerLens

In [1]:
import circuitsvis as cv

import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial

import matplotlib.pyplot as plt
import numpy as np

import pandas as pd

In [2]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [3]:
def vis_attn_patterns(model, text, layers, compact=True):
    str_tokens = model.to_str_tokens(text)
    logits, cache = model.run_with_cache(text, remove_batch_dim=True)

    if compact:
        for layer in layers:
            attention_pattern = cache["pattern", layer]
            display(cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern))
    
    else:
        for layer in layers:
            attention_pattern = cache["pattern", layer]
            display(cv.attention.attention_heads(tokens=str_tokens, attention=attention_pattern))

We set our model:

In [4]:
model = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


We import and embed the example text:

In [5]:
import json 

'''
with open("Test.json","r") as infile: 
    infile.readline()
    lines = infile.readlines()
    '''

with open("Test.json", "r") as infile:
    data = json.load(infile)

#data = [json.loads(line.strip("\n").rstrip(',').rstrip("]")) for line in lines]

questions = [line["question"] for line in data]
answers = [line["rationale"] for line in data]

#Embedding the questions and answers
question_embeds = model.to_tokens(questions)
answer_embeds = model.to_tokens(answers)

str_questions = model.to_str_tokens(questions)
str_answers = model.to_str_tokens(answers)

print(len(question_embeds), len(answer_embeds))
print(questions)

FileNotFoundError: [Errno 2] No such file or directory: 'Test.json'

In [6]:
vis_attn_patterns(model, questions[0], layers=[0])
vis_attn_patterns(model, questions[1], layers=[0])

NameError: name 'questions' is not defined

Trying activation patching

In [7]:
clean_prompt = "What is the capital of France?"
corrupted_prompt = "What is the capital of England?"

clean_answer = "Paris"
corrupted_answer = "London"

clean_logits, clean_cache = model.run_with_cache(clean_prompt)
corrupted_logits = model(corrupted_prompt)

clean_index = model.to_single_token(clean_answer)
corrupted_index = model.to_single_token(corrupted_answer)

clean_diff = clean_logits[0, -1, clean_index] - clean_logits[0, -1, corrupted_index]
print(f"Clean answer logit difference: {clean_diff:.4f}")

corrupted_diff = corrupted_logits[0, -1, clean_index] - corrupted_logits[0, -1, corrupted_index]
print(f"Corrupted answer logit difference: {corrupted_diff:.4f}")

Clean answer logit difference: 2.1247
Corrupted answer logit difference: -5.3342


Then we want to patch the clean prompt onto the corrupted prompt.

In [150]:
def activation_patching_hook(resid_pre, hook, position, clean_cache):
    clean_activation = clean_cache[hook.name]
    resid_pre[:, position, :] = clean_activation[:, position, :]
    return resid_pre

def model_data(model, prompt):
    tokens = model.to_tokens(prompt)
    logits, cache = model.run_with_cache(tokens)
    return tokens, logits, cache

def activation_patching(model, clean_prompt, corrupted_prompt, clean_answer, corrupted_answer):
    '''
    Performs activation patching of the clean prompt onto the corrupted prompt. The prompts must have the same number of tokens.

    Returns: 
    patching_results (list[tensor[layers, positions]]): The logit difference after patching
    patched_logits (list[tensor[num_tokens, logits]]): The logits of the tokens after patching
    '''
    
    clean_logits, clean_cache = model.run_with_cache(clean_prompt)
    corrupted_logits = model(corrupted_prompt)
    print("Clean answer:",clean_answer)
    print("Corrupted answer:", corrupted_answer)
    clean_index = model.to_single_token(clean_answer)
    corrupted_index = model.to_single_token(corrupted_answer)

    clean_diff = clean_logits[0, -1, clean_index] - clean_logits[0, -1, corrupted_index]
    corrupted_diff = corrupted_logits[0, -1, clean_index] - corrupted_logits[0, -1, corrupted_index]

    clean_tokens = model.to_tokens(clean_prompt)
    corrupted_tokens= model.to_tokens(corrupted_prompt)
    num_positions = len(model.to_tokens(clean_prompt)[0])

    assert len(clean_tokens[0]) == len(corrupted_tokens[0]), "The prompts must have the same number of tokens."

   
    patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)
    for layer in tqdm.tqdm(range(model.cfg.n_layers)):
        for position in range(num_positions):
            # We use a temporary hook with functool.partial to patch at each position
            temp_hook = partial(activation_patching_hook, position=position, clean_cache=clean_cache)
            # We then run the model with hooks as usual
            patched_logits = model.run_with_hooks(corrupted_tokens, 
                                                  fwd_hooks=[(utils.get_act_name("resid_pre", layer), temp_hook)])
            
            # We then calculate the logit difference
            patched_diff = (patched_logits[0, -1, clean_index] - patched_logits[0, -1, corrupted_index]).detach()
            # We then store the result in the patching_result tensor, normalizing it
            if abs(clean_diff-corrupted_diff) < 1e-16:
                patching_result[layer, position] = 0
            else:
                patching_result[layer, position] = (patched_diff - corrupted_diff) / (clean_diff - corrupted_diff)

    return patching_result, patched_logits

def activation_patching_mult(model, clean_prompt, corrupted_prompt, clean_answer, corrupted_answer):
    ''' 
    Performs activation patching on prompts with multi-word answers by using separate run-throughs.
    The answers must have the same number of tokens
    '''
    patching_result = []
    patched_logits = []
    clean_answers_tokens = model.to_str_tokens(clean_answer)[1:]
    corrupted_answers_tokens = model.to_str_tokens(corrupted_answer)[1:]
    print("Number of run throughs:", len(clean_answers_tokens))
    for i in range(len(clean_answers_tokens)):
        p_result, p_logits = activation_patching(model, clean_prompt, corrupted_prompt, clean_answers_tokens[0], corrupted_answers_tokens[0])
        patching_result.append(p_result)
        patched_logits.append(p_logits)
        clean_prompt += clean_answers_tokens[0]
        clean_answers_tokens = clean_answers_tokens[1:]
        corrupted_prompt += corrupted_answers_tokens[0]
        corrupted_answers_tokens = corrupted_answers_tokens[1:]

    return patching_result, patched_logits

In [83]:

print(list(clean_cache.keys()))
print("Looking for; ", utils.get_act_name("activation", 0))
patching_result = activation_patching(model, clean_prompt, model.to_tokens(corrupted_prompt))



['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_resid_pre', 'blocks.2.ln1.ho

  0%|          | 0/12 [00:00<?, ?it/s]

Layer 0, Position 0
Layer 0, Position 1
Layer 0, Position 2
Layer 0, Position 3
Layer 0, Position 4
Layer 0, Position 5
Layer 0, Position 6
Layer 0, Position 7
Layer 1, Position 0
Layer 1, Position 1
Layer 1, Position 2
Layer 1, Position 3
Layer 1, Position 4
Layer 1, Position 5
Layer 1, Position 6
Layer 1, Position 7
Layer 2, Position 0
Layer 2, Position 1
Layer 2, Position 2
Layer 2, Position 3
Layer 2, Position 4
Layer 2, Position 5
Layer 2, Position 6
Layer 2, Position 7
Layer 3, Position 0
Layer 3, Position 1
Layer 3, Position 2
Layer 3, Position 3
Layer 3, Position 4
Layer 3, Position 5
Layer 3, Position 6
Layer 3, Position 7
Layer 4, Position 0
Layer 4, Position 1
Layer 4, Position 2
Layer 4, Position 3
Layer 4, Position 4
Layer 4, Position 5
Layer 4, Position 6
Layer 4, Position 7
Layer 5, Position 0
Layer 5, Position 1
Layer 5, Position 2
Layer 5, Position 3
Layer 5, Position 4
Layer 5, Position 5
Layer 5, Position 6
Layer 5, Position 7
Layer 6, Position 0
Layer 6, Position 1


When patching the clean results onto the corrupted results, the change is first very localized, but is then brought to the end in the last few layers. 
We now want to try with a more complicated example.

In [133]:
clean_physics_prompt = "The action of adding numbers is called"
corrupted_physics_prompt = "The action of multiplying numbers is called"
clean_answer = " addition and sum"
corrupted_answer = " multiplication or product"

In [151]:
model.generate(clean_physics_prompt, max_new_tokens=10, temperature=0.0, top_p=1.0, do_sample=False)
utils.test_prompt(clean_physics_prompt, clean_answer, model)

  0%|          | 0/10 [00:00<?, ?it/s]

Tokenized prompt: ['<|endoftext|>', 'The', ' action', ' of', ' adding', ' numbers', ' is', ' called']
Tokenized answer: [' addition', ' and', ' sum']


Top 0th token. Logit: 12.96 Prob: 14.15% Token: | "|
Top 1th token. Logit: 12.67 Prob: 10.58% Token: | the|
Top 2th token. Logit: 12.00 Prob:  5.38% Token: | a|
Top 3th token. Logit: 10.96 Prob:  1.91% Token: | '|
Top 4th token. Logit: 10.87 Prob:  1.74% Token: | an|
Top 5th token. Logit: 10.21 Prob:  0.90% Token: | for|
Top 6th token. Logit:  9.94 Prob:  0.69% Token: | adding|
Top 7th token. Logit:  9.89 Prob:  0.65% Token: | as|
Top 8th token. Logit:  9.73 Prob:  0.56% Token: | in|
Top 9th token. Logit:  9.69 Prob:  0.54% Token: | by|


Top 0th token. Logit: 16.07 Prob: 30.04% Token: |.|
Top 1th token. Logit: 15.66 Prob: 19.83% Token: |,|
Top 2th token. Logit: 15.09 Prob: 11.24% Token: | and|
Top 3th token. Logit: 14.39 Prob:  5.60% Token: | or|
Top 4th token. Logit: 14.15 Prob:  4.39% Token: | to|
Top 5th token. Logit: 13.51 Prob:  2.32% Token: | (|
Top 6th token. Logit: 13.45 Prob:  2.19% Token: | .|
Top 7th token. Logit: 13.28 Prob:  1.83% Token: | of|
Top 8th token. Logit: 12.91 Prob:  1.26% Token: | ,|
Top 9th token. Logit: 12.78 Prob:  1.11% Token: |/|


Top 0th token. Logit: 14.45 Prob: 11.23% Token: | is|
Top 1th token. Logit: 14.21 Prob:  8.81% Token: | the|
Top 2th token. Logit: 14.03 Prob:  7.41% Token: | it|
Top 3th token. Logit: 13.57 Prob:  4.65% Token: | multiplication|
Top 4th token. Logit: 13.57 Prob:  4.65% Token: | subt|
Top 5th token. Logit: 13.40 Prob:  3.93% Token: | subtract|
Top 6th token. Logit: 12.41 Prob:  1.47% Token: | there|
Top 7th token. Logit: 12.38 Prob:  1.42% Token: |,|
Top 8th token. Logit: 12.35 Prob:  1.37% Token: | adds|
Top 9th token. Logit: 12.32 Prob:  1.34% Token: | can|


In [149]:
def imshow_patching_result(model, patching_results, corrupted_prompt, corrupted_answer):
    '''
    Visualizes the logit differences caused by activation patching in a heat map. If the answer has more than one token, "patching_results" must be a list of results.
    '''
    if isinstance(patching_results, list):
        len_ans = len(patching_results)
        for i in range(len_ans):
            tokens = model.to_str_tokens(corrupted_prompt+corrupted_answer)
            labels = [f'{token}_{index}' for index, token in enumerate(tokens)][:-len(patching_results)+i]
            px.imshow(patching_results[i].detach(), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", x=labels, labels={"x": "Position", "y": "Layer"}, title="Patching Results").show()
    else:
        tokens = model.to_str_tokens(corrupted_prompt)
        labels = [f'{token}_{index}' for index, token in enumerate(tokens)]
        px.imshow(patching_results.detach(), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", x=labels, labels={"x": "Position", "y": "Layer"}, title="Patching Results").show()
    

In [147]:
patching_result_physics = activation_patching_mult(model, clean_physics_prompt, corrupted_physics_prompt, clean_answer, corrupted_answer)

Number of run throughs: 3
 addition
 multiplication
3090
48473
tensor(-0.5734, grad_fn=<SubBackward0>)
tensor(-3.5214, grad_fn=<SubBackward0>)


  0%|          | 0/12 [00:00<?, ?it/s]

 and
 or
290
393
tensor(0.6974, grad_fn=<SubBackward0>)
tensor(1.8080, grad_fn=<SubBackward0>)


  0%|          | 0/12 [00:00<?, ?it/s]

 sum
 product
2160
1720
tensor(3.1815, grad_fn=<SubBackward0>)
tensor(1.0557, grad_fn=<SubBackward0>)


  0%|          | 0/12 [00:00<?, ?it/s]

In [148]:
imshow_patching_result(model, patching_result_physics[0], corrupted_physics_prompt, corrupted_answer)

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  1.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.9153, -0.0071, -0.0043, -0.0090],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.8398, -0.0039, -0.0030, -0.0160],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.7482,  0.0324,  0.0020, -0.0370],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.6228,  0.0821, -0.0065, -0.0253],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.5808,  0.1048,  0.0066, -0.0220],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.5302,  0.1176,  0.0106,  0.0150],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.4820,  0.0803,  0.0206,  0.0369],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.4511,  0.0859,  0.0111,  0.0514],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.3583,  0.0808,  0.0122,  0.1005],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.2456,  0.0986,  0.0139,  0.1885],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.1387,  0.0834,  0.0128,  0.5526]],
       grad_fn=

tensor([[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -5.1852e-03,
         -0.0000e+00, -0.0000e+00, -0.0000e+00,  6.8523e-01],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  8.6305e-02,
         -6.5319e-02, -1.6426e-02, -1.5288e-02,  6.7152e-01],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  1.0732e-01,
         -4.8101e-02, -2.9440e-02, -5.8896e-02,  7.0486e-01],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  1.8131e-01,
         -4.5677e-02, -3.5341e-02, -1.7081e-01,  7.8515e-01],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  1.2617e-01,
          1.1439e-02, -3.7775e-02, -1.9537e-01,  7.4691e-01],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  1.6169e-01,
          6.3487e-02, -6.3135e-02, -2.2933e-01,  7.8875e-01],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  1.3609e-01,
         -1.2962e-02, -4.8999e-02, -2.2193e-01,  8.7790e-01],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, 

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  2.8514e-01,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1179e+00,  4.9956e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  2.7939e-01,
         -2.5779e-02,  3.6068e-04,  1.2981e-02,  1.1409e+00,  4.6959e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  2.2936e-01,
          1.0406e-02,  3.6441e-03,  2.4217e-02,  1.1317e+00,  4.8972e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.7248e-01,
          6.2722e-02,  6.5722e-04,  5.2109e-02,  1.0068e+00,  5.6955e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.3216e-01,
          9.7582e-02, -2.1649e-02,  4.0607e-02,  9.4636e-01,  6.2590e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  7.8546e-02,
          1.5048e-01, -1.3919e-02,  2.4809e-02,  8.6742e-01,  6.7057e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0596e-01,
          1.4691e-01, -1.1708e-0

In [104]:
# Want to visualize the attention patterns in the first layerprint("Cache keys:", patching_result_physics[-1].keys())
print("Cache keys:", patching_result_physics[-1].keys())
print("Attention pattern shape:", patching_result_physics[-1]["pattern", 0].shape)
print(patching_result_physics[-1])
attn = patching_result_physics[-1]["pattern", 0]
print("min:", attn.min().item(), "max:", attn.max().item(), "mean:", attn.mean().item())

vis_attn_patterns(model, clean_physics_prompt, [0])

AttributeError: 'Tensor' object has no attribute 'keys'

In [15]:
# What about the probability distribution now?

patching_logits = patching_result_physics[1][0, -1, :]
# Get the logits for the last position
logits = patching_result_physics[1][0, -1, :]  # shape: (vocab_size,)

# Convert logits to probabilities
probs = logits.softmax(dim=-1)

# Get the top 10 token indices and their probabilities
top_probs, top_indices = probs.topk(10)

# Decode the tokens to strings
top_tokens = [model.to_string([idx.item()]) for idx in top_indices]

# Print the results
for token, prob in zip(top_tokens, top_probs):
    print(f"{token!r}: {prob.item():.4f}")

' "': 0.1507
' the': 0.1075
' a': 0.0541
" '": 0.0192
' an': 0.0184
' for': 0.0094
' as': 0.0065
' adding': 0.0062
' counting': 0.0060
' by': 0.0057


This seems to work (update: ...)

But can we make a function that does not require the questions to be the same length? And where the answers can be multiple words?

In [None]:
# If the answers are multiple words



# Induction heads

In [38]:
x = []
y = []
induction = []
dict_questions = {}
for question, question_token in zip(questions, question_embeds):
    length = len(question_token)
    
    induction_score_store = np.zeros((model.cfg.n_layers, model.cfg.n_heads))

    # A function for the average induction score
    def induction_score_hook(activation_pattern, hook):
        """
        Computes the average induction score for a given activation pattern.
        
        Args:
            activation_pattern (torch.Tensor): The activation pattern to compute the induction score for.
            hook (HookPoint): The hook point that triggered this function.
            
        Returns:
            torch.Tensor: The average induction score.
        """
        # We take the diagonal of attention paid from each destination position to source positions seq_len-1 tokens back
        # (This only has entries for tokens with index>=seq_len)
        induction_stripe = activation_pattern.diagonal(dim1=-2, dim2=-1, offset=1-length//2)
        # Get an average score per head
        induction_score = einops.reduce(induction_stripe, "batch head_index position -> head_index", "mean")
        # Store the result.
        induction_score_store[hook.layer(), :] = induction_score

    # We make a boolean filter on activation names, that's true only on attention pattern names.
    pattern_hook_names_filter = lambda name: name.endswith("pattern")
    print(question)
    # Run with hooks (this is where we write to the `induction_score_store` tensor`)
    model.run_with_hooks(question_token, 
        return_type=None, # For efficiency, we don't need to calculate the logits
        fwd_hooks=[(
            pattern_hook_names_filter,
            induction_score_hook
        )]
    )
    #print(induction_score_store)
    # Get the global max value and its flat index
    max_value = np.max(induction_score_store)
    max_index = np.argmax(induction_score_store)

    # Convert flat index to (layer, head) coordinates
    #max_index = np.unravel_index(max_index_flat.cpu().numpy(), induction_score_store.shape)
    #print("Question:", question)
    #print("Max value:", max_value.item())
    #print("Max index (layer, head):", max_index)
    '''
    if (max_index[0], max_index[1]) not in dict_questions:
        dict_questions[(max_index[0], max_index[1])] = [question]
    else:
        dict_questions[(max_index[0], max_index[1])].append(question)
    '''
    print(max_index[0])
    x.append(max_index[0])
    y.append(max_index[1])
    print(x)
    induction.append(induction_score_store)
 
print(x, y)
df = pd.DataFrame({
    "Layer": x,
    "Head": y,
    "Question": questions[:len(x)]
})
#print(dict_questions[(int(x[0]), int(y[0]))])
px.scatter(
    data_frame=df,
    x="Layer",
    y="Head",
    hover_data=["Question"],
    title="Induction Scores for Questions"
).show()
#px.scatter(x=np.array(x), y=np.array(y), labels={"x": "Layer", "y": "Head"},  title="Induction Scores for Questions", hover_data=["question"]).show()

0+0=


RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

The problem here is that many of the questions have the same max index. Therefore, it might not be the best method to visualize the inuction heads.

In [None]:
from sklearn.manifold import TSNE

tsne = TSNE(n_components=2, perplexity=30, random_state=42, learning_rate="auto", early_exaggeration=5.0, n_iter=5000, init="pca")
induction_embeds = tsne.fit_transform(np.array(induction))



NameError: name 'np' is not defined

In [57]:
print(model.to_tokens(""))

tensor([[50256]])
