# Recursive DLA

Because the equations for every contribution are known in closed form, we can compute the whole attribution as
one large vectorised calculation. At each level of the recursion we then sum over some of its dimensions to recover the higher-level (coarser) attributions.

## Setup

In [1]:
%load_ext autoreload
%autoreload 2

from dataclasses import dataclass
from enum import Enum
from typing import List, Literal, Optional

import pandas as pd
import plotly.express as px
import numpy as np
import torch
from einops import einsum, rearrange
from jaxtyping import Float
from torch import Tensor, compile

from transformer_lens import ActivationCache, HookedTransformer, HookedTransformerConfig, utils

from transformer_lens.attribution.recursive_dla import dla_attn_head_breakdown_source_component, dla_mlp_breakdown_source_component



In [2]:
# Inference-only notebook: switch autograd off globally so no graph is ever built.
torch.set_grad_enabled(False)

# Pick the best available backend instead of hard-coding Apple's MPS, so the notebook
# also runs on CUDA machines and on CPU-only hosts. MPS stays first so the original
# environment behaves exactly as before.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = HookedTransformer.from_pretrained("tiny-stories-instruct-1M", device=device)
# Needed for the per-head `hook_result` activations that `cache.stack_head_results()`
# reads in the DLA section.
model.set_use_attn_result(True)
model.eval()
print("Loaded")

Using pad_token, but it is not set yet.


Loaded pretrained model tiny-stories-instruct-1M into HookedTransformer
Loaded


## Experiments

In [3]:
# A numbered list for the model to continue; the answer we attribute towards is "3".
prompts = ["1. Get some apples.\n2. Get some oranges.\n"]
answers = ["3"]

prompts_encoded = model.to_tokens(prompts)

# Encode each answer without special tokens and drop the leading batch axis that
# `return_tensors="pt"` adds. `torch.stack` then needs every answer to have the
# same number of tokens.
answer_token_rows = [
    model.tokenizer.encode(
        text, add_special_tokens=False, return_tensors="pt"
    ).squeeze(0)
    for text in answers
]
answers_encoded = torch.stack(answer_token_rows).to(device)

(prompts_encoded.shape, answers_encoded.shape)

(torch.Size([1, 15]), torch.Size([1, 1]))

In [4]:
# Sanity check: print the tokenisation of the first prompt/answer pair and the model's
# top next-token predictions. `prepend_space_to_answer=False` keeps the answer as the
# bare string "3" rather than " 3".
utils.test_prompt(prompts[0], answers[0], model, prepend_space_to_answer=False)

Tokenized prompt: ['<|endoftext|>', '1', '.', ' Get', ' some', ' apples', '.', '\n', '2', '.', ' Get', ' some', ' oranges', '.', '\n']
Tokenized answer: ['3']
batch pos d_model, head_index d_model d_head                 -> batch pos head_index d_head
batch pos d_model, head_index d_model d_head                 -> batch pos head_index d_head
batch pos d_model, head_index d_model d_head                 -> batch pos head_index d_head
batch pos d_model, head_index d_model d_head                 -> batch pos head_index d_head
batch pos d_model, head_index d_model d_head                 -> batch pos head_index d_head
batch pos d_model, head_index d_model d_head                 -> batch pos head_index d_head
batch pos d_model, head_index d_model d_head                 -> batch pos head_index d_head
batch pos d_model, head_index d_model d_head                 -> batch pos head_index d_head


Top 0th token. Logit: 21.61 Prob: 28.52% Token: |Story|
Top 1th token. Logit: 21.41 Prob: 23.41% Token: |Words|
Top 2th token. Logit: 21.23 Prob: 19.50% Token: |Summary|
Top 3th token. Logit: 21.22 Prob: 19.20% Token: |Features|
Top 4th token. Logit: 20.34 Prob:  8.02% Token: |Random|
Top 5th token. Logit: 17.90 Prob:  0.70% Token: |
|
Top 6th token. Logit: 17.54 Prob:  0.49% Token: |<|endoftext|>|
Top 7th token. Logit: 14.81 Prob:  0.03% Token: |"|
Top 8th token. Logit: 14.36 Prob:  0.02% Token: |One|
Top 9th token. Logit: 13.94 Prob:  0.01% Token: |Sum|


## Dimensions

In [5]:
# Keep a short alias for the model config; later cells use `cfg.n_heads`.
cfg = model.cfg

# Report the model dimensions, one per line.
print(
    f"Layers: {cfg.n_layers}",
    f"Heads {cfg.n_heads}",
    f"D_model {cfg.d_model}",
    f"D_head {cfg.d_head}",
    sep="\n",
)

Layers: 8
Heads 16
D_model 64
D_head 4


## DLA

In [7]:
# Re-display the input shapes before running the model: (prompt tokens, answer tokens).
prompts_encoded.shape, answers_encoded.shape

(torch.Size([1, 15]), torch.Size([1, 1]))

In [8]:
# Single forward pass over the prompt; `cache` holds the activations that every
# attribution cell below reads from.
logits, cache = model.run_with_cache(prompts_encoded)

In [33]:
# List the hook names stored in the cache: every hook outside the transformer blocks,
# plus block 0's hooks as a representative sample (presumably the remaining blocks
# expose the same set of names).
# Only the names are needed, so iterate over keys rather than (name, tensor) pairs.
for hook_name in cache.keys():
    if not hook_name.startswith("blocks") or hook_name.startswith("blocks.0"):
        print(hook_name)

hook_embed
hook_pos_embed
blocks.0.hook_resid_pre
blocks.0.ln1.hook_scale
blocks.0.ln1.hook_normalized
blocks.0.attn.hook_q
blocks.0.attn.hook_k
blocks.0.attn.hook_v
blocks.0.attn.hook_attn_scores
blocks.0.attn.hook_pattern
blocks.0.attn.hook_z
blocks.0.attn.hook_result
blocks.0.hook_attn_out
blocks.0.hook_resid_mid
blocks.0.ln2.hook_scale
blocks.0.ln2.hook_normalized
blocks.0.mlp.hook_pre
blocks.0.mlp.hook_post
blocks.0.hook_mlp_out
blocks.0.hook_resid_post
ln_final.hook_scale
ln_final.hook_normalized


In [34]:
# Inspect the final layer-norm module to see which hook points it exposes.
model.ln_final

LayerNormPre(
  (hook_scale): HookPoint()
  (hook_normalized): HookPoint()
)

In [52]:
# Stack the cached per-head attention outputs along a leading head axis.
stacked_heads: Float[Tensor, "head batch pos d_model"] = cache.stack_head_results()

# Attribute each head's output to the logit of the first answer (passed as a string).
dla_heads: Float[Tensor, "head_idx batch pos"] = cache.logit_attrs(
    stacked_heads,
    tokens=answers[0],
)

# Baseline to compare against: total head attribution for the first batch element
# at the final position, alongside the tensor's shape.
final_pos_dla = dla_heads[:, 0, -1]
(final_pos_dla.sum(), dla_heads.shape)

(tensor(3.6771, device='mps:0'), torch.Size([128, 1, 15]))

In [53]:
# Attention-head DLA broken down further (going by the helper's name and the
# annotation) by destination layer/head, source position and source component.
# NOTE(review): the leading axes are annotated "token batch" here, but other cells in
# this notebook describe the same tensor as "batch token" — confirm the true order
# against the helper. Both axes have size 1 in this run, so indexing is unaffected.
dla_breakdown: Float[Tensor, "token batch dest_l dest_h src_pos src_comp"] = (
    dla_attn_head_breakdown_source_component(cache, model, answers_encoded)
)

# Drop the two leading singleton axes to get the single example's breakdown.
dla_breakdown_single_example: Float[Tensor, "dest_l dest_h src_pos src_comp"] = (
    dla_breakdown[0][0]
)

(dla_breakdown.sum(), dla_breakdown.shape)

(tensor(16836.3516, device='mps:0'), torch.Size([1, 1, 8, 16, 15, 18]))

In [54]:
model.W_U.shape

torch.Size([64, 50257])

In [None]:
# Collapse the source-component axis so each (layer, head) keeps one value per
# source position.
dla_simpler: Float[Tensor, "layer head src_pos"] = dla_breakdown[0, 0, :, :, :, :].sum(
    -1
)

# One row per head, flattened layer-major: row index = layer * n_heads + head.
dla_flattened = rearrange(dla_simpler, "layer head src_pos -> (layer head) src_pos")

df = pd.DataFrame(dla_flattened.detach().cpu().numpy())

# Rows are flattened head indices.
df.index.name = "head_idx"

# Label each column with its position and decoded prompt token; the position prefix
# keeps repeated tokens (e.g. " Get") distinct.
prompt_tokens_list = prompts_encoded.detach().cpu().tolist()[0]
column_tokens = [
    f"({i}) {model.tokenizer.decode(p)}" for i, p in enumerate(prompt_tokens_list)
]
df.columns = column_tokens

# Heatmap of attribution per head (rows) and source position (columns). A descriptive
# title and colour-bar label let the figure stand alone.
fig = px.imshow(
    df,
    color_continuous_scale="RdYlGn",
    title="DLA per attention head by source position",
    labels={"color": "DLA"},
)
fig.show()

In [None]:
# Total attribution per flattened head index, summed over source positions.
dla_by_head_idx = dla_flattened.sum(dim=-1).detach().cpu().numpy()
dla_by_head_idx_df = pd.Series(dla_by_head_idx)

# Keep the five heads with the largest total; the Series index is the flattened head index.
top_k_heads = dla_by_head_idx_df.sort_values(ascending=False).iloc[:5]
top_k_heads

In [None]:
# Column labels (position + decoded prompt token) are identical for every head, so
# build them once instead of on every iteration.
prompt_tokens_list = prompts_encoded.detach().cpu().tolist()[0]
column_tokens = [
    f"({i}) {model.tokenizer.decode(p)}" for i, p in enumerate(prompt_tokens_list)
]

for head_idx in top_k_heads.index:
    # Flattened head indices are layer-major: head_idx = layer * n_heads + head.
    layer, head = divmod(head_idx, cfg.n_heads)
    print(f"L{layer}H{head}")

    # DLA was "batch token dest_l dest_h src_pos src_comp"
    # Loop-specific names are used so the all-heads `dla_simpler` / `df` from the
    # earlier cell are not overwritten with differently shaped objects.
    head_breakdown: Float[Tensor, "src_pos src_comp"] = dla_breakdown[
        0, 0, layer, head, :, :
    ]
    head_breakdown_ordered = rearrange(
        head_breakdown, "src_pos src_comp -> src_comp src_pos"
    )

    head_df = pd.DataFrame(head_breakdown_ordered.detach().cpu().numpy())

    # Rows are source components, columns are source positions.
    head_df.index.name = "src_comp"
    head_df.columns = column_tokens

    # One heatmap per head, titled with the head so the figures can be told apart.
    fig = px.imshow(
        head_df,
        color_continuous_scale="RdYlGn",
        title=f"L{layer}H{head}: DLA by source component and source position",
        labels={"color": "DLA"},
    )
    fig.show()

In [None]:
# MLP counterpart of the attention-head breakdown (going by the helper's name).
# NOTE(review): the axis annotation below still names a `dest_h` axis, which looks
# carried over from the attention version — confirm the real layout against the
# `.shape` displayed at the end of this cell.
mlp_breakdown: Float[
    Tensor, "batch token dest_l dest_h src_pos src_comp"
] = dla_mlp_breakdown_source_component(cache, model, answers_encoded)

mlp_breakdown.shape

In [None]:
# Display the raw tensor for inspection.
# NOTE(review): this prints the full tensor repr; consider a summary (e.g. a sum over
# some axes) or a heatmap once the axis layout is confirmed.
mlp_breakdown