In [2]:
!pip install jaxtyping transformer-lens
!pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

Collecting jaxtyping
  Downloading jaxtyping-0.2.37-py3-none-any.whl.metadata (6.6 kB)
Collecting transformer_lens
  Downloading transformer_lens-2.11.0-py3-none-any.whl.metadata (12 kB)
Collecting wadler-lindig>=0.1.3 (from jaxtyping)
  Downloading wadler_lindig-0.1.3-py3-none-any.whl.metadata (17 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer_lens)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting fancy-einsum>=0.0.3 (from transformer_lens)
  Downloading fancy_einsum-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Downloading jaxtyping-0.2.37-py3-none-any.whl (56 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading transformer_lens-2.11.0-py3-none-any.whl (177 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.6/177.6 kB[0m [

In [3]:
import gc
import os
import sys

import circuitsvis as cv
import einops
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from IPython.display import HTML, display
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Column, Table
from torch import Tensor
from transformer_lens import ActivationCache, HookedTransformer, utils

if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())
from plotly_utils_user import line, imshow

In [5]:
# Build one dict per CSV row. Columns are used positionally as
# (active, passive, agent, distractor). The prompt is the active sentence followed
# by the passive sentence with its last whitespace-separated word dropped
# (presumably the agent), so the model has to predict that word next.
dataset = [
    dict(
        active=row[0],
        passive=row[1],
        agent=row[2],
        distractor=row[3],
        prompt=f"{row[0]} {' '.join(row[1].split()[:-1])}",
        # (correct, incorrect) answers, each with the leading space the tokenizer expects.
        answer=(f' {row[2]}', f' {row[3]}'),
    )
    for row in pd.read_csv(os.path.join("./utils", "final_dataset.csv")).to_numpy()
]

In [6]:
# Load Llama-3.2-1B as a TransformerLens HookedTransformer on the GPU.
model = HookedTransformer.from_pretrained(
    model_name="meta-llama/Llama-3.2-1B",
    device="cuda",
)

config.json:   0%|          | 0.00/843 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/185 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/301 [00:00<?, ?B/s]

Loaded pretrained model meta-llama/Llama-3.2-1B into HookedTransformer


In [8]:
# Sanity check on the first example: show the prompt, the expected agent, and
# the model's top next-token predictions.
print(
    f'Prompt that we test for the single case: {dataset[0]["prompt"]}',
    f'Answer that we test for the single case: {dataset[0]["agent"]}',
    sep="\n",
)
utils.test_prompt(
    dataset[0]["prompt"],
    f" {dataset[0]['agent']}",
    model,
    prepend_bos=True,
)

Prompt that we test for the single case: The engineer built the bridge. The bridge was built by the
Answer that we test for the single case: engineer
Tokenized prompt: ['<|begin_of_text|>', 'The', ' engineer', ' built', ' the', ' bridge', '.', ' The', ' bridge', ' was', ' built', ' by', ' the']
Tokenized answer: [' engineer']


Top 0th token. Logit: 17.25 Prob: 36.58% Token: | engineer|
Top 1th token. Logit: 15.26 Prob:  5.04% Token: | engineers|
Top 2th token. Logit: 14.66 Prob:  2.76% Token: | railroad|
Top 3th token. Logit: 14.57 Prob:  2.51% Token: | government|
Top 4th token. Logit: 14.36 Prob:  2.04% Token: | company|
Top 5th token. Logit: 14.29 Prob:  1.90% Token: | contractor|
Top 6th token. Logit: 13.94 Prob:  1.35% Token: | people|
Top 7th token. Logit: 13.93 Prob:  1.33% Token: | architect|
Top 8th token. Logit: 13.82 Prob:  1.19% Token: | army|
Top 9th token. Logit: 13.57 Prob:  0.92% Token: | builder|


In [10]:
# Keep only examples whose correct AND incorrect answers are each a single token,
# so every answer can be scored by one logit at the final position.
prompts, answers, answers_tokens_list = [], [], []
for dataset_element in dataset:
    # The (correct, incorrect) pair is tokenized together; after .T the first
    # axis is the (common) answer length, so length 1 means both are single tokens.
    answer_pair_tokens = model.to_tokens(list(dataset_element["answer"]), prepend_bos=False).T
    if answer_pair_tokens.shape[0] != 1:
        continue

    prompts.append(dataset_element["prompt"])
    answers.append(dataset_element["answer"])
    answers_tokens_list.append(answer_pair_tokens)

# To debug on a subset, slice prompts / answers / answers_tokens_list here.
assert prompts, "No dataset example has single-token correct and incorrect answers."

answer_tokens = torch.concat(answers_tokens_list, dim=0)  # (n_prompts, 2): [correct, incorrect]
table = Table("Prompt", "Correct", "Incorrect", title="Prompts & Answers:")
for prompt, answer in zip(prompts, answers):
    table.add_row(prompt, repr(answer[0]), repr(answer[1]))
rprint(table)

# Every later cell reads the answer at position -1 (logits[:, -1], pos_slice=-1).
# That is only correct if no prompt is padded, so require equal token lengths.
prompt_lengths = {model.to_tokens(p, prepend_bos=True).shape[-1] for p in prompts}
assert len(prompt_lengths) == 1, (
    f"Prompts have different token lengths {sorted(prompt_lengths)}; "
    "final-position indexing would read padding tokens."
)

tokens = model.to_tokens(prompts, prepend_bos=True)
tokens = tokens.to(torch.device("cuda"))
# No gradients are needed here; without no_grad the whole forward graph is kept alive.
with torch.no_grad():
    original_logits, cache = model.run_with_cache(tokens)

In [11]:
def logits_to_ave_logit_diff(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Int[Tensor, "batch 2"] = answer_tokens,
    per_prompt: bool = False
) -> Float[Tensor, "*batch"]:
    '''
    Returns the logit difference between the correct and incorrect answer,
    read at the final sequence position.

    Args:
        logits: model output logits.
        answer_tokens: token ids per prompt; column 0 is the correct answer
            (agent) and column 1 the incorrect one (distractor). The default is
            the global `answer_tokens` captured when this cell ran, so re-run this
            cell if that global is rebuilt.
        per_prompt: if True, return the per-prompt differences rather than the average.

    Returns:
        Tensor of shape (batch,) when per_prompt is True, otherwise a 0-dim mean.
    '''
    # Only the final logits are relevant for the answer (assumes prompts are not right-padded).
    final_logits: Float[Tensor, "batch d_vocab"] = logits[:, -1, :]
    # Pick out the correct / incorrect answer logits; gather needs the index on the logits' device.
    answer_logits: Float[Tensor, "batch 2"] = final_logits.gather(
        dim=-1, index=answer_tokens.to(final_logits.device)
    )
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [12]:
# Baseline: (correct - incorrect) logit difference on the clean prompts,
# per prompt and averaged. Later attributions are compared against this.
original_per_prompt_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Per prompt logit difference:", original_per_prompt_diff)
print("Average logit difference:", original_average_logit_diff)

cols = ["Prompt"] + [
    Column(header, style=style)
    for header, style in (
        ("Correct", "rgb(0,200,0) bold"),
        ("Incorrect", "rgb(255,0,0) bold"),
        ("Logit Difference", "bold"),
    )
]
table = Table(*cols, title="Logit differences")

for prompt, answer, logit_diff in zip(prompts, answers, original_per_prompt_diff):
    table.add_row(prompt, *map(repr, answer), f"{logit_diff.item():.3f}")
rprint(table)

Per prompt logit difference: tensor([4.4270, 7.9821, 5.8301, 2.1387, 5.2894, 5.6944, 4.7744, 1.7602, 5.0023,
        4.8762, 7.4128, 2.9365, 0.6036, 6.8349, 5.5866, 5.8997, 2.6153, 2.2714],
       device='cuda:0', grad_fn=<SubBackward0>)
Average logit difference: tensor(4.5520, device='cuda:0', grad_fn=<MeanBackward0>)


In [13]:
# Check that projecting the LayerNorm-scaled final residual stream onto the
# answer tokens' residual directions (correct - incorrect) reproduces the logit
# difference measured directly from the logits.
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens) # [batch 2 d_model]
print("Answer residual directions shape:", answer_residual_directions.shape)

correct_residual_directions, incorrect_residual_directions = answer_residual_directions.unbind(dim=1)
logit_diff_directions = correct_residual_directions - incorrect_residual_directions # [batch d_model]
print("Logit difference directions shape:", logit_diff_directions.shape)

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream: Float[Tensor, "batch seq d_model"] = cache["resid_post", -1]
print(f"Final residual stream shape: {final_residual_stream.shape}")
final_token_residual_stream: Float[Tensor, "batch d_model"] = final_residual_stream[:, -1, :]

# Apply LayerNorm scaling (to just the final sequence position)
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer=-1, pos_slice=-1)

average_logit_diff = einops.einsum(
    scaled_final_token_residual_stream, logit_diff_directions,
    "batch d_model, batch d_model ->"
) / len(prompts)

print(f"Calculated average logit diff: {average_logit_diff:.10f}")
print(f"Original logit difference:     {original_average_logit_diff:.10f}")
# Make the comparison above a hard check; float32 rounding differences are ~1e-6 here.
assert abs(average_logit_diff.item() - original_average_logit_diff.item()) < 1e-3, (
    "Residual-direction decomposition does not reproduce the measured logit difference."
)

gc.collect()
torch.cuda.empty_cache()

Answer residual directions shape: torch.Size([18, 2, 2048])
Logit difference directions shape: torch.Size([18, 2048])
Final residual stream shape: torch.Size([18, 13, 2048])
Calculated average logit diff: 4.5519752502
Original logit difference:     4.5519771576


In [14]:
def residual_stack_to_logit_diff(
    residual_stack: Float[Tensor, "... batch d_model"],
    cache: ActivationCache,
    logit_diff_directions: Float[Tensor, "batch d_model"] = logit_diff_directions,
) -> Float[Tensor, "..."]:
    '''
    Projects every component of a residual-stream stack onto the per-prompt
    logit-difference directions and averages the result over the batch.

    The stack is first scaled with the final LayerNorm scale cached at the last
    position (apply_ln_to_stack with layer=-1, pos_slice=-1), so the output is
    on the same scale as the model's logit difference.

    Args:
        residual_stack: components on the leading dims, then (batch, d_model).
        cache: activation cache from the same prompts (supplies the LN scale).
        logit_diff_directions: (correct - incorrect) residual directions per
            prompt; the default is bound when this cell runs.

    Returns:
        The mean logit difference for each component (leading "..." dims).
    '''
    ln_scaled = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
    summed_over_batch = einops.einsum(
        ln_scaled,
        logit_diff_directions,
        "... batch d_model, batch d_model -> ...",
    )
    n_prompts = residual_stack.size(-2)
    return summed_over_batch / n_prompts

In [17]:
# Logit lens: project the accumulated residual stream after each (sub)layer
# onto the logit-difference direction to see where the preference emerges.
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
# accumulated_residual has shape (component, batch, d_model)

logit_lens_logit_diffs: Float[Tensor, "component"] = residual_stack_to_logit_diff(accumulated_residual, cache)

line(
    logit_lens_logit_diffs,
    hovermode="x unified",
    title="Logit Difference From Accumulated Residual Stream",
    labels={"x": "Layer", "y": "Logit Diff"},
    xaxis_tickvals=labels,
    width=800
)
# Collect Python garbage first so empty_cache can hand the freed blocks back;
# ending on empty_cache (returns None) also keeps gc's count out of the output.
gc.collect()
torch.cuda.empty_cache()

1250

In [16]:
# Direct logit attribution per layer: split the final-position residual stream
# into per-layer components (cache.decompose_resid) and project each onto the
# logit-difference direction.
per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)

line(
    per_layer_logit_diffs,
    hovermode="x unified",
    title="Logit Difference From Each Layer",
    labels={"x": "Layer", "y": "Logit Diff"},
    xaxis_tickvals=labels,
    width=800
)
# Collect Python garbage first so empty_cache can hand the freed blocks back;
# ending on empty_cache (returns None) also keeps gc's count out of the output.
gc.collect()
torch.cuda.empty_cache()

806

In [18]:
# Per-head direct logit attribution. Head results are computed batch by batch
# and moved to CPU to bound GPU memory, then reassembled for all prompts.
batch_size = 5
head_results_list = []
for i in range(0, tokens.shape[0], batch_size):
    batch_tokens = tokens[i:i+batch_size]
    with torch.no_grad():
        batch_logits, batch_cache = model.run_with_cache(batch_tokens)
        # Laid out as ((layer head), batch, d_model) at the final position (see rearrange below).
        # Errors propagate: silently skipping a batch would misalign prompts with `cache`.
        batch_head_res, _ = batch_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
    head_results_list.append(batch_head_res.cpu())
    torch.cuda.empty_cache()

per_head_residual = torch.cat(head_results_list, dim=1)
# The batch axis must cover every prompt to line up with `cache` and
# `logit_diff_directions` inside residual_stack_to_logit_diff.
assert per_head_residual.shape[1] == tokens.shape[0], (
    f"Collected head results for {per_head_residual.shape[1]} of {tokens.shape[0]} prompts."
)
per_head_residual = einops.rearrange(
    per_head_residual,
    "(layer head) batch d_model -> layer head batch d_model",
    layer=model.cfg.n_layers
)
# Move back to the device the logit-difference directions live on.
per_head_residual = per_head_residual.to(logit_diff_directions.device)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
imshow(
    per_head_logit_diffs,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
    width=600
)

Tried to stack head results when they weren't cached. Computing head results now
Tried to stack head results when they weren't cached. Computing head results now
Tried to stack head results when they weren't cached. Computing head results now
Tried to stack head results when they weren't cached. Computing head results now


In [19]:
def topk_of_Nd_tensor(tensor: Tensor, k: int) -> list[list[int]]:
    '''
    Return the multi-dimensional indices of the k largest entries of `tensor`.

    Works for a tensor of any shape (e.g. a (layer, head) grid of logit
    attributions). Indices are ordered from the largest value to the smallest.
    If k exceeds the number of elements, every element is returned instead of
    raising.

    Returns:
        A list of up to k index lists, each of length tensor.ndim.
    '''
    flat = tensor.detach().flatten()
    # torch.topk raises for k > numel; clamp so callers can ask for "up to k".
    k = min(k, flat.numel())
    flat_indices = torch.topk(flat, k).indices.cpu().numpy()
    return np.array(np.unravel_index(flat_indices, tuple(tensor.shape))).T.tolist()


# For the k most positive and k most negative heads by logit attribution,
# render their attention patterns on the first prompt.
k = 3
for sign, head_type in ((1, "Positive"), (-1, "Negative")):
    top_heads = topk_of_Nd_tensor(per_head_logit_diffs * sign, k)

    # cache["pattern", layer] is indexed [batch, head, ...]; take prompt 0, the chosen head.
    attn_patterns_for_important_heads: Float[Tensor, "head q k"] = torch.stack(
        [cache["pattern", layer][0, head] for layer, head in top_heads]
    )

    display(HTML(f"<h2>Top {k} {head_type} Logit Attribution Heads</h2>"))
    display(cv.attention.attention_patterns(
        attention=attn_patterns_for_important_heads,
        tokens=model.to_str_tokens(tokens[0]),
        attention_head_names=[f"{layer}.{head}" for layer, head in top_heads],
    ))

In [27]:
# DLDA-style neuron search: split prompts into "good" / "bad" by logit difference,
# then score each neuron of one MLP layer at the final token by
# (mean_good - mean_bad) / (var_good + var_bad + epsilon).
# torch / numpy / pandas / plotly.express come from the top import cell.

# --- Configuration ---
epsilon = 1e-6  # small constant to avoid division by zero
layer_to_inspect = 14  # change as desired
batch_size = 5

# --- Containers for activations and labels ---
all_neurons = []  # per-batch final-token "mlp_pre" activations from the chosen layer
all_logit_diffs = []  # per-batch, per-example logit differences

# --- Process data in batches ---
# Assumes `tokens` (N, seq_len) and `answer_tokens` (N, 2) from earlier cells are
# row-aligned, so slicing both with the same range keeps examples matched.
mlp_pre_hook_name = utils.get_act_name("mlp_pre", layer_to_inspect)
for i in range(0, tokens.shape[0], batch_size):
    batch_tokens = tokens[i:i+batch_size]
    batch_answer_tokens = answer_tokens[i:i+batch_size]  # ensure alignment with tokens

    with torch.no_grad():
        # Cache only the one activation used below instead of every hook.
        batch_logits, batch_cache = model.run_with_cache(batch_tokens, names_filter=mlp_pre_hook_name)

        # Per-example (correct - incorrect) logit difference at the final position.
        batch_logit_diff = logits_to_ave_logit_diff(batch_logits, batch_answer_tokens, per_prompt=True)

        batch_neurons = batch_cache["mlp_pre", layer_to_inspect]
        # Typically (batch, seq_len, d_mlp); keep the final token only.
        if batch_neurons.ndim == 3:
            batch_neurons = batch_neurons[:, -1, :]  # now shape (batch, d_mlp)

    all_neurons.append(batch_neurons.cpu())
    all_logit_diffs.append(batch_logit_diff.cpu())
    torch.cuda.empty_cache()

# Concatenate all batches so that neurons and logit differences are aligned.
neurons = torch.cat(all_neurons, dim=0)      # shape (total_examples, d_mlp)
logit_diffs = torch.cat(all_logit_diffs, dim=0)  # shape (total_examples,)
assert neurons.shape[0] == logit_diffs.shape[0] == tokens.shape[0], (
    "Neuron activations and logit differences are misaligned."
)

print("Total examples processed:", neurons.shape[0])
print("Neuron activation shape:", neurons.shape)
print("Logit differences:", logit_diffs.numpy())

# --- Define Labels ---
# First, try to use 0 as threshold: label 1 if logit difference > 0, else 0.
labels = (logit_diffs > 0).long()
unique_labels = labels.unique().tolist()
print("Unique labels with threshold 0:", unique_labels)

# If only one label is present, use a median split.
if len(unique_labels) < 2:
    median_value = logit_diffs.median()
    print("Only one group present. Using median split (median = {:.3f}).".format(median_value.item()))
    labels = (logit_diffs > median_value).long()
    print("Unique labels after median split:", labels.unique().tolist())

# --- Compute DLDA Weights ---
neurons_good = neurons[labels == 1]  # shape: (n_good, d_mlp)
neurons_bad  = neurons[labels == 0]  # shape: (n_bad, d_mlp)

print("Number of 'good' examples:", neurons_good.shape[0])
print("Number of 'bad' examples:", neurons_bad.shape[0])

if neurons_good.numel() == 0 or neurons_bad.numel() == 0:
    raise ValueError("One of the groups (good or bad) is empty even after splitting. Check your data!")

# Per-neuron means.
mu_good = neurons_good.mean(dim=0)   # shape: (d_mlp,)
mu_bad  = neurons_bad.mean(dim=0)    # shape: (d_mlp,)

# Per-neuron population variances (correction=0). A one-example group would
# give 0 anyway; the explicit branch only exists to surface that with a warning.
if neurons_good.shape[0] < 2:
    print("Warning: 'Good' group has less than 2 examples; setting its variance to zeros.")
    var_good = torch.zeros_like(mu_good)
else:
    var_good = neurons_good.var(dim=0, correction=0)

if neurons_bad.shape[0] < 2:
    print("Warning: 'Bad' group has less than 2 examples; setting its variance to zeros.")
    var_bad = torch.zeros_like(mu_bad)
else:
    var_bad = neurons_bad.var(dim=0, correction=0)

# DLDA weight for each neuron.
dlda_weights = (mu_good - mu_bad) / (var_good + var_bad + epsilon)  # shape: (d_mlp,)

# Identify the "best" neuron (largest absolute weight).
if torch.isnan(dlda_weights).all():
    raise ValueError("All DLDA weights are NaN. Check your data or label assignment.")
# argmax treats NaN as the maximum, so map NaNs below every real |weight| first.
best_neuron_idx = dlda_weights.abs().nan_to_num(nan=-1.0).argmax().item()
print(f"Best neuron (by DLDA) in layer {layer_to_inspect}: {best_neuron_idx}")

# --- Plot the DLDA Weights ---
dlda_weights_np = dlda_weights.detach().cpu().numpy()

df = pd.DataFrame({
    "Neuron Index": np.arange(len(dlda_weights_np)),
    "DLDA Weight": dlda_weights_np
})

fig = px.line(
    df,
    x="Neuron Index",
    y="DLDA Weight",
    title=f"DLDA Weights per Neuron (Layer {layer_to_inspect})",
    labels={"Neuron Index": "Neuron Index", "DLDA Weight": "DLDA Weight"}
)

# Add a red marker for the best neuron.
fig.add_scatter(
    x=[best_neuron_idx],
    y=[dlda_weights_np[best_neuron_idx]],
    mode="markers",
    marker=dict(color="red", size=12),
    name="Best Neuron"
)

fig.show()

Total examples processed: 18
Neuron activation shape: torch.Size([18, 8192])
Logit differences: [4.4270306 7.982051  5.8301334 2.1386824 5.289385  5.694441  4.7744255
 1.7601595 5.002267  4.876195  7.412757  2.9364996 0.6036358 6.8349
 5.586585  5.899684  2.6152954 2.2714052]
Unique labels with threshold 0: [1]
Only one group present. Using median split (median = 4.876).
Unique labels after median split: [0, 1]
Number of 'good' examples: 9
Number of 'bad' examples: 9
Best neuron (by DLDA) in layer 14: 6664
