# Attention Vocab Projections
This notebook investigates projecting query, key and value activations onto the vocabulary (logit lens style).
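Here is a rough sketch of what "logit lens style" means in this setting. Take a vector in the model's hidden dimension, optionally pass it through the final RMSNorm, and multiply it by the unembedding matrix to get one score per vocabulary token. The helper name `vocab_projection` and the choice to apply the final norm are assumptions for illustration. In the 🤗 Transformers OLMo-2 model, the corresponding weights should be `model.model.norm.weight` and `model.lm_head.weight`, but check this against the installed version.

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """RMSNorm over the last dimension: x / rms(x) * weight."""
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


def vocab_projection(
    h: torch.Tensor,            # (..., d_model) activation to inspect
    norm_weight: torch.Tensor,  # (d_model,) gain of the final norm (stand-in for the real one)
    W_U: torch.Tensor,          # (vocab, d_model) unembedding / lm_head weight
    k: int = 5,
    apply_norm: bool = True,    # assumption: normalise like the final residual stream
):
    """Hypothetical helper: project h onto the vocabulary and return the top-k tokens."""
    if apply_norm:
        h = rms_norm(h, norm_weight)
    logits = h @ W_U.T          # (..., vocab)
    return torch.topk(logits, k, dim=-1)


# Toy example: with an identity "unembedding", the top token is simply the largest coordinate.
d_model = 8
W_U = torch.eye(d_model)
norm_weight = torch.ones(d_model)
h = torch.tensor([0.1, -0.3, 0.2, 2.0, 0.0, 0.5, -1.0, 0.4])
top = vocab_projection(h, norm_weight, W_U, k=3)
print(top.indices.tolist())  # [3, 5, 7]
```

For the real model, `h` would be one of the query, key or value vectors, which works directly only if its last dimension matches `hidden_size`. Whether to apply the final norm at all is a judgement call, since Q/K/V never pass through it in the actual forward pass.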

Let's start by loading the base variant of OLMo-2, so that we don't have to fight the instruction tuning, and then find a prompt on which it performs two-hop reasoning.

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
import torch.nn as nn
from typing import Mapping, Optional, Tuple
import math

# model_id = "allenai/OLMo-2-1124-7B-Instruct"
model_id = "allenai/OLMo-2-1124-7B"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="eager",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
math_ds = load_dataset("lighteval/MATH")

Loading checkpoint shards: 100%|██████████| 6/6 [00:08<00:00,  1.34s/it]


In [2]:
# test = math_ds["test"]
# instance = test[0]
# prob = instance["problem"]
# soln = instance["solution"]

# prompt = "The mother of the singer of 'Superstition' is"
# prompt = "Who is the mother of the singer of 'Superstition'?"
# prompt = "The author of the novel Ubik was born in the city of"
prompt = "In the year Scarlett Johansson was born, the Summer Olympics were hosted in"
# prompt = "Country that hosted the Summer Olympics in the year Scarlett Johansson was born:"
# prompt = "National anthem of Woodie Flowers's country of birth:"
# prompt = "The name of the national anthem of the country where Rishi Bankim Chandra Colleges is based is"

In [3]:
# templated = [
#     {"role": "user", "content": prompt},
# ]
# templated_prompt = tokenizer.apply_chat_template(templated, tokenize=False, add_generation_prompt=True)
# templated_prompt += "\n<|assistant|>\n"  # only needed if the chat template does not already append the assistant header

In [4]:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=4,
)
generated = tokenizer.batch_decode(outputs)
print(generated[0])

In the year Scarlett Johansson was born, the Summer Olympics were hosted in Los Angeles, California


In [5]:
input_len = len(inputs["input_ids"][0])
output_str = ""
for i, output_id in enumerate(outputs[0][input_len:]):
    output_str += f"{i} - {tokenizer.decode(output_id)}\n"
print(output_str)

0 -  Los
1 -  Angeles
2 - ,
3 -  California



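Nice! On the prompt "In the year Scarlett Johansson was born, the Summer Olympics were hosted in", the model answers " Los Angeles". Getting this right takes two hops: Scarlett Johansson's birth year (1984), then the host city of the 1984 Summer Olympics. The behaviour is on full display. Let's now tease apart the attention activations.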

# Hacking into OLMo-2
Things get a little dirty here. Using HookedTransformer from the TransformerLens library would probably be more elegant, but I wanted everything contained in this notebook.

In the next cell, we steal the .forward() method of Olmo2Attention from 🤗 Transformers' modeling_olmo2.py, along with the auxiliary functions rotate_half(), apply_rotary_pos_emb() and repeat_kv(). They are copied verbatim, so feel free to skip this cell!
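Of these helpers, apply_rotary_pos_emb() matters most when interpreting Q and K. RoPE rotates each query and key by a position-dependent angle, so the rotated dot product q·k depends only on the relative offset between the two positions. Below is a minimal, self-contained check of that property using the same rotate_half() convention. The base of 10000 and the tiny head_dim are assumptions for the toy; OLMo-2's configured rope_theta may differ, which does not change the property.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Same convention as the Transformers helper: (x1, x2) -> (-x2, x1) on the two halves."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def rope_cos_sin(positions: torch.Tensor, head_dim: int, base: float = 10000.0):
    """Standard RoPE angles; `base` is an assumed value, not read from the OLMo-2 config."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim))
    freqs = positions.to(torch.float64)[:, None] * inv_freq[None, :]  # (seq, head_dim / 2)
    emb = torch.cat((freqs, freqs), dim=-1)                              # (seq, head_dim)
    return emb.cos(), emb.sin()


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate x by the angles encoded in cos/sin (one position)."""
    return x * cos + rotate_half(x) * sin


torch.manual_seed(0)
head_dim = 16
q = torch.randn(head_dim, dtype=torch.float64)
k = torch.randn(head_dim, dtype=torch.float64)
cos, sin = rope_cos_sin(torch.arange(32), head_dim)


def rotated_score(q_pos: int, k_pos: int) -> float:
    """Dot product after rotating q to position q_pos and k to position k_pos."""
    q_rot = apply_rope(q, cos[q_pos], sin[q_pos])
    k_rot = apply_rope(k, cos[k_pos], sin[k_pos])
    return float(q_rot @ k_rot)


# Same relative offset (3) at different absolute positions gives the same score.
print(rotated_score(5, 2), rotated_score(20, 17))
```

A consequence worth keeping in mind: post-RoPE queries and keys are position-rotated versions of the pre-RoPE vectors. A vocabulary projection of them would therefore presumably mix in positional information.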

In [6]:
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: bool = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_norm(self.q_proj(hidden_states))
        key_states = self.k_norm(self.k_proj(hidden_states))
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
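The docstring of repeat_kv() says it is equivalent to torch.repeat_interleave along the head dimension. The quick toy check below shows what that means for grouped-query attention: consecutive query heads share one key/value head. The head counts are made up, not OLMo-2's. When a model has as many key/value heads as query heads, num_key_value_groups is 1 and repeat_kv() returns its input unchanged.

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Copy of the helper above: (b, kv_heads, s, d) -> (b, kv_heads * n_rep, s, d)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


# Toy grouped-query setup: 2 key/value heads shared by 6 query heads (3 per group).
kv = torch.randn(1, 2, 5, 4)
expanded = repeat_kv(kv, n_rep=3)
print(expanded.shape)  # torch.Size([1, 6, 5, 4])

# Identical to repeat_interleave along dim 1, so query heads 0-2 read kv head 0
# and query heads 3-5 read kv head 1.
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=3, dim=1))
assert torch.equal(expanded[0, 2], kv[0, 0]) and torch.equal(expanded[0, 3], kv[0, 1])
```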

Here, I create a quick wrapper class that replaces the original self_attn module in every decoder layer. It takes the original module as a constructor argument and reuses its projections, norms and rotary embedding.

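I modify the forward above to get access to the tensors we need. It stores the query, key and value states right after q_proj/k_proj/v_proj (and q_norm/k_norm), before they are split into heads and before RoPE is applied. The changed lines are fenced by rows of #s.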
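For comparison, PyTorch's register_forward_hook offers a less invasive way to capture the same tensors: attach it to the sub-modules whose outputs we want. This is a swapped-in technique, not what this notebook does. Because OLMo-2 applies q_norm/k_norm as modules right after q_proj/k_proj, hooking q_norm, k_norm and v_proj should yield the same pre-RoPE tensors the wrapper stores. The sketch uses a hypothetical toy module, `ToyAttnProjections`, with LayerNorm standing in for RMSNorm, so it runs without the 7B checkpoint.

```python
import torch
import torch.nn as nn


class ToyAttnProjections(nn.Module):
    """Hypothetical stand-in for the projection part of an OLMo-2 attention block."""

    def __init__(self, d_model: int = 8):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.q_norm = nn.LayerNorm(d_model)  # RMSNorm in the real model
        self.k_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.q_norm(self.q_proj(x)), self.k_norm(self.k_proj(x)), self.v_proj(x)


def attach_capture_hooks(attn: nn.Module, store: dict):
    """Register forward hooks that save q/k/v into `store`; returns handles for later removal."""
    def make_hook(name):
        def hook(module, inputs, output):
            store[name] = output.detach()
        return hook

    targets = {"query_states": attn.q_norm, "key_states": attn.k_norm, "value_states": attn.v_proj}
    return [mod.register_forward_hook(make_hook(name)) for name, mod in targets.items()]


torch.manual_seed(0)
attn = ToyAttnProjections()
store = {}
handles = attach_capture_hooks(attn, store)
x = torch.randn(1, 5, 8)
q, k, v = attn(x)
for handle in handles:
    handle.remove()  # detach the hooks once the activations are captured

print(sorted(store))  # ['key_states', 'query_states', 'value_states']
```

On the real model, this would be called once per layer on layer.self_attn, writing into a per-layer dict such as activation_store[layer_idx]. Hooks depend only on sub-module names, so they are likely to survive changes to the attention forward better than a copied method.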

In [7]:
activation_store = {layer_idx: {} for layer_idx in range(len(model.model.layers))}

In [8]:
class AttnWrapper(nn.Module):
    def __init__(self, self_attn):
            super(AttnWrapper, self).__init__()
            self.self_attn = self_attn
            self.q_norm = self.self_attn.q_norm
            self.k_norm = self.self_attn.k_norm
            self.q_proj = self.self_attn.q_proj
            self.k_proj = self.self_attn.k_proj
            self.v_proj = self.self_attn.v_proj
            self.o_proj = self.self_attn.o_proj
            self.num_heads = self.self_attn.num_heads
            self.num_key_value_heads = self.self_attn.num_key_value_heads
            self.head_dim = self.self_attn.head_dim
            self.rotary_emb = self.self_attn.rotary_emb
            self.num_key_value_groups = self.self_attn.num_key_value_groups
            self.attention_dropout = self.self_attn.attention_dropout
            self.layer_idx = self.self_attn.layer_idx
            self.hidden_size = self.self_attn.hidden_size
            self.training = self.self_attn.training

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: bool = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            cache_position: Optional[torch.LongTensor] = None,
            **kwargs,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
            bsz, q_len, _ = hidden_states.size()

            query_states = self.q_norm(self.q_proj(hidden_states))
            key_states = self.k_norm(self.k_proj(hidden_states))
            value_states = self.v_proj(hidden_states)
            ################################################################################################################################################
            print("just applied q_proj, k_proj and v_proj")
            activation_store[self.layer_idx]["query_states"] = query_states
            activation_store[self.layer_idx]["key_states"] = key_states
            activation_store[self.layer_idx]["value_states"] = value_states
            ################################################################################################################################################
            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                # sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

            if attention_mask is not None:  # no matter the length, we just slice it
                causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
                attn_weights = attn_weights + causal_mask

            # upcast attention to fp32
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
            attn_output = torch.matmul(attn_weights, value_states)

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )

            attn_output = attn_output.transpose(1, 2).contiguous()

            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

            attn_output = self.o_proj(attn_output)

            if not output_attentions:
                attn_weights = None

            return attn_output, attn_weights, past_key_value

In [9]:
for i, layer in enumerate(model.model.layers):
    layer.self_attn = AttnWrapper(layer.self_attn)
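Running the cell above twice would wrap the wrapper, and the only way back to the stock attention modules is reloading the model. The sketch below shows a pair of hypothetical helpers, `wrap_attention` and `unwrap_attention`, that skip already-wrapped layers and can restore the originals. `Wrapper` is a trivial stand-in for AttnWrapper that only keeps the original module under the same attribute name, self_attn.

```python
import torch.nn as nn


class Wrapper(nn.Module):
    """Minimal stand-in for AttnWrapper: holds the original module as `self_attn`."""

    def __init__(self, self_attn: nn.Module):
        super().__init__()
        self.self_attn = self_attn

    def forward(self, *args, **kwargs):
        return self.self_attn(*args, **kwargs)


def wrap_attention(layers, wrapper_cls=Wrapper):
    """Wrap each layer's self_attn once; already-wrapped layers are left alone."""
    for layer in layers:
        if not isinstance(layer.self_attn, wrapper_cls):
            layer.self_attn = wrapper_cls(layer.self_attn)


def unwrap_attention(layers, wrapper_cls=Wrapper):
    """Restore the original self_attn modules kept inside the wrappers."""
    for layer in layers:
        if isinstance(layer.self_attn, wrapper_cls):
            layer.self_attn = layer.self_attn.self_attn


class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = nn.Linear(4, 4)


layers = nn.ModuleList([ToyLayer() for _ in range(3)])
originals = [layer.self_attn for layer in layers]
wrap_attention(layers)
wrap_attention(layers)  # second call is a no-op: no double wrapping
print(type(layers[0].self_attn.self_attn).__name__)  # Linear
unwrap_attention(layers)
```

With the notebook's class, the call would presumably be wrap_attention(model.model.layers, AttnWrapper), since AttnWrapper's constructor also takes the original module as its only argument.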

In [10]:
outputs = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=1,
)
generated = tokenizer.batch_decode(outputs[0][input_len:])
print(generated)

just applied q_proj, k_proj and v_proj
[... same line repeated once per decoder layer; rest of the output truncated ...]

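OK! Our activations are being stored. Because generate() ran with max_new_tokens=1, there was a single forward pass over the prompt, and that pass is what picks the first answer token (" Los" under greedy decoding). Whatever two-hop reasoning is needed to choose it must therefore happen inside this pass.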

In [11]:
# layers = model.model.layers
# print(activation_store[0]["query_states"].shape)  # the store is a global dict, not an attribute of the wrapper
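Printing the whole store dumps raw tensors, which is hard to read. A hypothetical helper that reports shapes instead is easier to work with. Pairing it with the same view-then-transpose split the attention forward applies makes it easy to pick out one head at one position. The toy sizes below are made up. For the real model, the head counts would come from model.config.num_attention_heads (queries) and model.config.num_key_value_heads (keys and values).

```python
import torch


def summarize_store(store: dict) -> dict:
    """Map {layer_idx: {name: tensor}} to {layer_idx: {name: shape tuple}}."""
    return {
        layer_idx: {name: tuple(t.shape) for name, t in acts.items()}
        for layer_idx, acts in store.items()
    }


def split_heads(states: torch.Tensor, num_heads: int) -> torch.Tensor:
    """(batch, seq, num_heads * head_dim) -> (batch, num_heads, seq, head_dim), as in the forward."""
    bsz, q_len, width = states.shape
    head_dim = width // num_heads
    return states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)


# Toy store: 2 layers, 7 prompt tokens, 4 heads of size 3.
toy_store = {
    i: {name: torch.randn(1, 7, 12) for name in ("query_states", "key_states", "value_states")}
    for i in range(2)
}
print(summarize_store(toy_store)[0])
# {'query_states': (1, 7, 12), 'key_states': (1, 7, 12), 'value_states': (1, 7, 12)}

heads = split_heads(toy_store[1]["query_states"], num_heads=4)
last_token_head2 = heads[0, 2, -1]  # head 2 at the final prompt position
print(heads.shape, last_token_head2.shape)  # torch.Size([1, 4, 7, 3]) torch.Size([3])
```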

In [12]:
print(activation_store)

{0: {'query_states': tensor([[[ 2.4143e-10,  2.4344e-10,  1.6591e-10,  ..., -4.5010e-02,
          -5.0896e-11,  3.8925e-11],
         [-6.9554e-12, -1.3660e-12, -3.5838e-12,  ..., -2.0155e-01,
           2.0706e-12, -1.0116e-12],
         [-1.1544e-10, -1.2355e-10, -6.7036e-11,  ...,  1.6717e-01,
           2.6903e-11, -1.8865e-11],
         ...,
         [-4.0330e-10, -4.0408e-10, -3.0219e-10,  ...,  1.8250e-01,
           8.1182e-11, -6.7790e-11],
         [-5.3916e-11, -5.7469e-11, -5.0442e-11,  ..., -8.4030e-01,
           9.1351e-12, -1.0644e-11],
         [ 4.3279e-12,  6.2667e-12,  1.3160e-12,  ..., -2.3331e-01,
          -1.2590e-12,  9.7681e-13]]], device='cuda:0'), 'key_states': tensor([[[ 3.2296e-09, -4.4366e-09,  7.8751e-10,  ..., -7.0224e-01,
          -3.2631e-10,  2.4471e-10],
         [ 7.7083e-10, -1.5995e-09, -2.6598e-10,  ..., -3.6715e-01,
          -4.1296e-11,  4.9121e-11],
         [-5.5017e-10, -3.6441e-09,  1.4116e-09,  ...,  1.0503e-01,
          -3.9746e-10, 