# Attention Vocab Projections
This notebook investigates projecting query, key and value activations onto the vocabulary (logit lens style).

Let's start by loading the base variant of OLMo-2, so that we do not need to fight the instruct tuning, and find a prompt on which it performs two-hop reasoning.

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
import torch.nn as nn
from typing import Mapping, Optional, Tuple
import math

# model_id = "allenai/OLMo-2-1124-7B-Instruct"
model_id = "allenai/OLMo-2-1124-7B"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="eager",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
math_ds = load_dataset("lighteval/MATH")



In [2]:
# test = math_ds["test"]
# instance = test[0]
# prob = instance["problem"]
# soln = instance["solution"]

# prompt = "The mother of the singer of 'Superstition' is"
# prompt = "Who is the mother of the singer of 'Superstition'?"
# prompt = "The author of the novel Ubik was born in the city of"
prompt = "In the year Scarlett Johansson was born, the Summer Olympics were hosted in"
# prompt = "Country that hosted the Summer Olympics in the year Scarlett Johansson was born:"
# prompt = "National anthem of Woodie Flowers's country of birth:"
# prompt = "The name of the national anthem of the country where Rishi Bankim Chandra Colleges is based is"

In [4]:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=4,
)
generated = tokenizer.batch_decode(outputs)
print(generated[0])

In the year Scarlett Johansson was born, the Summer Olympics were hosted in Los Angeles, California


In [5]:
input_len = len(inputs["input_ids"][0])
output_str = ""
for i, output_id in enumerate(outputs[0][input_len:]):
    output_str += f"{i} - {tokenizer.decode(output_id)}\n"
print(output_str)

0 -  Los
1 -  Angeles
2 - ,
3 -  California



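Nice! On the prompt "In the year Scarlett Johansson was born, the Summer Olympics were hosted in", the two-hop behaviour is on full display. Scarlett Johansson was born in 1984, and the 1984 Summer Olympics were hosted in Los Angeles, which is exactly what the model completes. Let's now tease apart the attention activations.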

# Hacking into OLMo-2
Things get a little dirty here. Using HookedTransformer from the TransformerLens library would probably be more elegant, but I wanted everything contained in this notebook.

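In the next cell, we copy the .forward() method from 🤗 Transformers' modeling_olmo2.py, along with the auxiliary functions rotate_half(), apply_rotary_pos_emb() and repeat_kv(). They are copied verbatim from the transformers version installed at the time; later releases restructured this module, so a fresh copy may look different. Feel free to skip this cell!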

In [6]:
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: bool = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_norm(self.q_proj(hidden_states))
        key_states = self.k_norm(self.k_proj(hidden_states))
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
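Two properties of the copied helpers are easy to check in isolation: repeat_kv() is just repeat_interleave over the head dimension (as its docstring says), and rotary embeddings make the query-key score depend only on the offset between positions. A small sketch on toy tensors, assuming a hypothetical rope() helper that builds cos/sin in the same rotate_half layout; base=10000.0 is an assumed value, and OLMo-2's configured rope_theta may differ:

```python
import torch


def rotate_half(x):
    # Same as the copied helper: (x1, x2) -> (-x2, x1)
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def rope(x, pos, base=10000.0):
    # Hypothetical helper: rotate a single vector to position `pos`,
    # with frequencies duplicated across both halves (rotate_half layout)
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
    angles = pos * torch.cat((inv_freq, inv_freq))
    return x * torch.cos(angles) + rotate_half(x) * torch.sin(angles)


def repeat_kv(hidden_states, n_rep):
    # Same logic as the copied helper
    batch, n_kv, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, n_kv, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, n_kv * n_rep, slen, head_dim)


torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)

# Positions (5, 2) and (13, 10) share the same offset, so the scores match
score_a = rope(q, 5) @ rope(k, 2)
score_b = rope(q, 13) @ rope(k, 10)
rope_relative_ok = torch.allclose(score_a, score_b)

# repeat_kv duplicates each key/value head n_rep times, in order
kv = torch.randn(1, 4, 3, 8)
repeat_kv_ok = torch.equal(repeat_kv(kv, 2), torch.repeat_interleave(kv, 2, dim=1))
print(rope_relative_ok, repeat_kv_ok)  # True True
```

Since the wrapper below stores query and key states before apply_rotary_pos_emb() runs, those stored tensors carry no positional rotation.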

Here, I create a quick wrapper class that 'replaces' the original self_attn module in every decoder layer (it takes the original module as a parameter).
Copying all the attributes is pretty ugly and hacky, but oh well.
The .forward() is identical to the code above, aside from the lines between the ### markers.
activation_store is where we actually store activations: a dict indexed by layer number, holding a nested dict that maps each activation name to its tensor. The query and key states are stored before RoPE is applied.
Note that every forward pass overwrites these entries, so only the most recent pass survives (e.g. when generate() produces several tokens)!
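For comparison, the same capture can be done without copying forward() by registering PyTorch forward hooks (nn.Module.register_forward_hook) on the submodules whose outputs we want. This sketch uses a hypothetical ToyAttn module whose submodule names (q_proj, k_proj, v_proj, q_norm, k_norm) mirror OLMo-2's self_attn, with nn.LayerNorm standing in for OLMo-2's RMSNorm; it has not been run against the real model. It appends one entry per forward pass, so a second pass does not overwrite the first. The hooked tensors are [batch, num_tokens, d_model] and would still need the .view(...).transpose(1, 2) reshape to match activation_store.

```python
import torch
import torch.nn as nn
from collections import defaultdict


class ToyAttn(nn.Module):
    """Hypothetical stand-in: only the submodule names mirror OLMo-2's self_attn."""

    def __init__(self, d_model=16):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.q_norm = nn.LayerNorm(d_model)  # stand-in for OLMo-2's RMSNorm
        self.k_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        q = self.q_norm(self.q_proj(x))
        k = self.k_norm(self.k_proj(x))
        v = self.v_proj(x)
        return q, k, v


layers = nn.ModuleList([ToyAttn() for _ in range(2)])

# activations[layer_idx][name] is a list with one tensor per forward pass
activations = defaultdict(lambda: defaultdict(list))

# Submodule whose output is each stored activation (post-norm q/k, as in the wrapper)
targets = {"query_states": "q_norm", "key_states": "k_norm", "value_states": "v_proj"}


def make_hook(layer_idx, name):
    def hook(module, inputs, output):
        activations[layer_idx][name].append(output.detach().cpu())
    return hook


handles = []
for layer_idx, attn in enumerate(layers):
    for name, submodule in targets.items():
        handle = getattr(attn, submodule).register_forward_hook(make_hook(layer_idx, name))
        handles.append(handle)

x = torch.randn(1, 5, 16)  # [batch, num_tokens, d_model]
with torch.no_grad():
    for _ in range(2):  # two forward passes: both are kept
        for attn in layers:
            attn(x)

for handle in handles:
    handle.remove()  # stop capturing

print(len(activations[0]["query_states"]), activations[1]["value_states"][0].shape)
```

On the real model, the loop over layers would presumably run over model.model.layers with attn = layer.self_attn.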

In [7]:
activation_store = {layer_idx: {} for layer_idx in range(len(model.model.layers))}

In [8]:
class AttnWrapper(nn.Module):
    def __init__(self, self_attn):
            super(AttnWrapper, self).__init__()
            self.self_attn = self_attn
            self.q_norm = self.self_attn.q_norm
            self.k_norm = self.self_attn.k_norm
            self.q_proj = self.self_attn.q_proj
            self.k_proj = self.self_attn.k_proj
            self.v_proj = self.self_attn.v_proj
            self.o_proj = self.self_attn.o_proj
            self.num_heads = self.self_attn.num_heads
            self.num_key_value_heads = self.self_attn.num_key_value_heads
            self.head_dim = self.self_attn.head_dim
            self.rotary_emb = self.self_attn.rotary_emb
            self.num_key_value_groups = self.self_attn.num_key_value_groups
            self.attention_dropout = self.self_attn.attention_dropout
            self.layer_idx = self.self_attn.layer_idx
            self.hidden_size = self.self_attn.hidden_size
            self.training = self.self_attn.training

    def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: bool = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            cache_position: Optional[torch.LongTensor] = None,
            **kwargs,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
            bsz, q_len, _ = hidden_states.size()

            # After the projections (and q_norm/k_norm) the shape is [batch_size, num_tokens, d_model],
            # so the logit lens/LM head could be applied here directly, but whether it's meaningful is doubtful...
            query_states = self.q_norm(self.q_proj(hidden_states))
            key_states = self.k_norm(self.k_proj(hidden_states))
            value_states = self.v_proj(hidden_states)

            # ...because here the tensors get reshaped into [batch_size, num_heads, num_tokens, head_dim]
            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

            #################################################################################################################################################
            print("just applied q_proj, k_proj and v_proj")
            # bsz is always 1 in these demos, so .squeeze() can drop the batch dim later;
            # revisit this before parallelizing experiments with bsz > 1
            activation_store[self.layer_idx]["query_states"] = query_states
            activation_store[self.layer_idx]["key_states"] = key_states
            activation_store[self.layer_idx]["value_states"] = value_states
            #################################################################################################################################################

            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                # sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

            key_states = repeat_kv(key_states, self.num_key_value_groups)
            value_states = repeat_kv(value_states, self.num_key_value_groups)

            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

            if attention_mask is not None:  # no matter the length, we just slice it
                causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
                attn_weights = attn_weights + causal_mask

            # upcast attention to fp32
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
            attn_output = torch.matmul(attn_weights, value_states)

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )

            attn_output = attn_output.transpose(1, 2).contiguous()

            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

            attn_output = self.o_proj(attn_output)

            if not output_attentions:
                attn_weights = None

            return attn_output, attn_weights, past_key_value

In [9]:
for i, layer in enumerate(model.model.layers):
    layer.self_attn = AttnWrapper(layer.self_attn)

In [22]:
outputs = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=1,
)
generated = tokenizer.batch_decode(outputs[0][input_len:])
print(generated)

just applied q_proj, k_proj and v_proj

OK! Our activations are being stored. And since greedy decoding already produced " Los" as the first new token, whatever two-hop computation the model performs must happen within this single forward pass (max_new_tokens=1)!

In [11]:
# layers = model.model.layers
# print(layers[0].self_attn.activation_store["query_states"].shape)

In [12]:
print(activation_store[0]['query_states'].shape)

torch.Size([1, 32, 15, 128])
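The stored shape reads as [batch, num_heads, num_tokens, head_dim]: 32 heads of 128 dimensions over the 15 prompt tokens. This toy sketch (random data, same sizes) replays the view/transpose from forward() and shows that head j at token t is simply a contiguous 128-wide slice of the flat projection output:

```python
import torch

bsz, q_len, num_heads, head_dim = 1, 15, 32, 128
# Projection output: [batch, num_tokens, num_heads * head_dim]
flat = torch.randn(bsz, q_len, num_heads * head_dim)

# The same reshape the attention forward() applies
per_head = flat.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
print(per_head.shape)  # torch.Size([1, 32, 15, 128])

# Head j at token t corresponds to columns j*head_dim .. (j+1)*head_dim of the flat vector
j, t = 7, 14
same = torch.equal(per_head[0, j, t], flat[0, t, j * head_dim:(j + 1) * head_dim])
print(same)  # True
```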


# Applying the Logit Lens
Now that we have a bunch of activations pre-computed, let's look at what the logit lens says! Part of the credit for this code goes to Nina Panickssery and her "Decoding intermediate activations in llama-2-7b".
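The logit lens decodes a hidden vector by passing it through the model's final norm and unembedding, as if it were the last layer's output. A minimal sketch with random toy weights, using a hand-written rms_norm (OLMo-2's final norm is an RMSNorm) and a hypothetical toy_logit_lens helper; W_U plays the role of model.lm_head.weight:

```python
import torch


def rms_norm(x, weight, eps=1e-6):
    # RMSNorm over the last dimension (stand-in for model.model.norm)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


def toy_logit_lens(hidden, norm_weight, W_U, topk=3):
    """hidden: [num_tokens, d_model]; W_U: [vocab_size, d_model], like lm_head.weight."""
    logits = rms_norm(hidden, norm_weight) @ W_U.T  # [num_tokens, vocab_size]
    probs = torch.softmax(logits[-1], dim=-1)       # decode the last position only
    return torch.topk(probs, topk)


torch.manual_seed(0)
d_model, vocab = 8, 20
W_U = torch.randn(vocab, d_model)
W_U = W_U / W_U.norm(dim=-1, keepdim=True)  # unit rows, so a row is most similar to itself
norm_weight = torch.ones(d_model)

hidden = torch.randn(4, d_model)
hidden[-1] = 5 * W_U[3]  # last hidden state points along token 3's unembedding row
top = toy_logit_lens(hidden, norm_weight, W_U)
print(top.indices[0].item())  # 3
```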


In [13]:
import numpy as np
from scipy.sparse.linalg import svds

unembed = model.lm_head.weight  # [vocab_size, d_model]

# Convert the unembedding matrix to a float32 numpy array on CPU
unembed_np = unembed.detach().float().cpu().numpy()

# Compute only the top 128 singular triplets (128 = head_dim)
u, s, vt = svds(unembed_np, k=128)

# svds does not return singular values largest-first, so sort them descending.
# Fancy indexing also yields a copy with positive strides, which torch.from_numpy
# requires (it rejects the negative strides produced by u[:, ::-1]).
order = np.argsort(s)[::-1]
u = torch.from_numpy(u[:, order]).float()
s = torch.from_numpy(s[order]).float()

# Compute reduced embeddings
reduced_unembed = u @ torch.diag(s)  # Shape: [vocab_size, 128]


In [14]:
# import pickle
# with open('svd_reduced_unembed/reduced_unembed.pkl', 'wb') as f:
#     pickle.dump(reduced_unembed, f)

In [15]:
def logit_lens(activations, label, topk=5):
    """Decode the last position of [batch, seq, d_model] activations via the final norm + LM head."""
    normalized_activations = model.model.norm(activations)
    # Apply the LM head to the normalized activations
    unembedded_activations = model.lm_head(normalized_activations)
    softmaxed = torch.nn.functional.softmax(unembedded_activations[0][-1], dim=-1)
    values, indices = torch.topk(softmaxed, topk)
    probs_percent = [f"{v * 100:.2f}" for v in values.tolist()]
    tokens = tokenizer.batch_decode(indices.unsqueeze(-1))
    return (label, list(zip(tokens, probs_percent)))

def SVD_head_lens(activations, label, topk=5):
    """Decode per-head activations [num_tokens, head_dim] through the rank-128 SVD unembedding."""
    # Put the reduced unembedding wherever this layer's activations live (device_map="auto")
    reduced = reduced_unembed.to(device=activations.device, dtype=activations.dtype)
    logits = activations @ reduced.T  # [num_tokens, vocab_size]
    scaled_logits = logits * 32 / math.sqrt(activations.shape[-1])  # scale 32 / sqrt(head_dim), head_dim = 128
    token_probs = torch.softmax(scaled_logits.float(), dim=-1)
    value_list, indices_list = torch.topk(token_probs, k=topk)
    for token in range(len(value_list)):
        values = value_list[token]
        indices = indices_list[token]
        probs_percent = [f"{v * 100:.4f}" for v in values.tolist()]
        tokens = tokenizer.batch_decode(indices.unsqueeze(-1))
        print(label + str(token), list(zip(tokens, probs_percent)))

In [16]:
# SVD_head_lens(activation_store[0]['query_states'], "Logit lens on layer 0 query_states")

In [17]:
# num_layers = len(model.model.layers)
# for i in range(num_layers):
#     act_type = 'query_states'
#     activations = activation_store[i][act_type].squeeze()
#     for j in range(32):
#         activations_at_head = activations[j]
#         decoded = SVD_head_lens(activations_at_head, f"Logit lens on {act_type} for layer {i} head {j} token ")
#         # print(decoded)
#     #     break
#     # break

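Recall that the prompt was "In the year Scarlett Johansson was born, the Summer Olympics were hosted in". The above does not look very interpretable.

Let's instead apply the LM head to query states after they have been reshaped into [bsz, num_heads, num_tokens, head_dim]. Because head_dim (128) does not match the LM head's input size, we use the rank-128 truncated SVD of the unembedding matrix (computed above with scipy's svds) as a translation from 128 dimensions to vocab_size.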
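In matrix form, with the unembedding factored as W_U = U Σ Vᵀ, reduced_unembed is U_k Σ_k, so the SVD lens computes U_k Σ_k h for a head vector h. That equals W_U (V_k h): the 128 head coordinates are read as coordinates along the top-128 right singular vectors of the unembedding. Nothing in the model ties head dimension i to singular direction i, so this is an assumption of the method rather than a property of OLMo-2. A toy check of the identity (random float64 matrix, torch.linalg.svd instead of svds):

```python
import torch

torch.manual_seed(0)
vocab, d_model, k = 50, 16, 4
W_U = torch.randn(vocab, d_model, dtype=torch.float64)  # stands in for lm_head.weight

# Thin SVD: W_U = U diag(S) Vh, singular values sorted largest-first
U, S, Vh = torch.linalg.svd(W_U, full_matrices=False)
reduced_unembed = U[:, :k] @ torch.diag(S[:k])  # [vocab, k], built as in the notebook

h = torch.randn(k, dtype=torch.float64)  # a k-dim "head" vector

# SVD-lens logits...
svd_lens_logits = reduced_unembed @ h
# ...equal ordinary unembedding of the d_model vector x = V_k h
x = Vh[:k].T @ h
print(torch.allclose(svd_lens_logits, W_U @ x))  # True
```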

In [25]:
num_layers = len(model.model.layers)
act_types = ["query_states", "key_states", "value_states"]
for act_type in act_types:
    for i in range(num_layers):
        # [num_heads, num_tokens, head_dim] after dropping the batch dim (bsz == 1)
        activations = activation_store[i][act_type].squeeze(0)
        for j in range(activations.shape[0]):
            SVD_head_lens(activations[j], f"Logit lens on {act_type} for layer {i} head {j} token ")

Logit lens on query_states for layer 0 head 0 token 0 [('igue', '0.0011'), (' reife', '0.0011'), (' pornost', '0.0011'), ('egin', '0.0011'), (' mænd', '0.0011')]
Logit lens on query_states for layer 0 head 0 token 1 [('contador', '0.0012'), ('IFn', '0.0011'), ('semblies', '0.0011'), ('llib', '0.0011'), ('apol', '0.0011')]
Logit lens on query_states for layer 0 head 0 token 2 [(' reife', '0.0010'), ('igue', '0.0010'), ('ég', '0.0010'), (' männer', '0.0010'), (' pornost', '0.0010')]
Logit lens on query_states for layer 0 head 0 token 3 [('igue', '0.0011'), (' reife', '0.0011'), ('ég', '0.0011'), (' pornost', '0.0011'), ('egin', '0.0011')]
Logit lens on query_states for layer 0 head 0 token 4 [('igue', '0.0012'), (' mænd', '0.0012'), (' odense', '0.0012'), (' reife', '0.0012'), ('ög', '0.0012')]
Logit lens on query_states for layer 0 head 0 token 5 [('igue', '0.0014'), (' mænd', '0.0013'), (' odense', '0.0013'), ('uges', '0.0013'), ('uga', '0.0013')]
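Every probability printed above lies between roughly 0.0010% and 0.0014%, while a uniform distribution over the 100,352-token vocabulary gives 100/100352 ≈ 0.000996% per token, so these decodings are close to uniform. One way to quantify that, sketched with a hypothetical normalized_entropy helper, is entropy divided by log(vocab_size): 1.0 means uniform, 0.0 means all mass on one token.

```python
import math
import torch


def normalized_entropy(probs):
    """Entropy of each row of probs ([..., vocab_size]) divided by log(vocab_size)."""
    p = probs.clamp_min(1e-12)  # avoid log(0)
    ent = -(probs * p.log()).sum(dim=-1)
    return ent / math.log(probs.shape[-1])


vocab_size = 100352
print(100 / vocab_size)  # uniform probability in percent, about 0.000996

uniform = torch.full((vocab_size,), 1.0 / vocab_size, dtype=torch.float64)
one_hot = torch.zeros(vocab_size, dtype=torch.float64)
one_hot[0] = 1.0
print(normalized_entropy(uniform).item(), normalized_entropy(one_hot).item())
```

Applied to token_probs inside SVD_head_lens, values near 1.0 would back up the impression that these projections carry little token-specific signal.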

In [19]:
reduced_unembed.shape

torch.Size([100352, 128])

In [20]:
layer_idx = 0
head_idx = 0
query_states = activation_store[layer_idx]['query_states']
query_states = query_states.squeeze() # getting rid of bsz dimension!
query_states_at_head_idx = query_states[head_idx]
logits = query_states_at_head_idx.to('cuda:0') @ reduced_unembed.T.to('cuda:0')
last_token_logits = logits[-1]
top_tokens = torch.topk(last_token_logits, k=10)
top_tokens

torch.return_types.topk(
values=tensor([0.0544, 0.0537, 0.0509, 0.0481, 0.0478, 0.0476, 0.0473, 0.0469, 0.0468,
        0.0463], device='cuda:0'),
indices=tensor([46438, 68348, 34650, 99170, 12636, 95765, 55866, 62050, 86420, 54172],
       device='cuda:0'))

In [21]:
tokens = tokenizer.batch_decode(top_tokens[1])
print(tokens)

['semblies', 'contador', 'BOSE', ' BinaryTree', 'parator', '_mB', 'tractor', 'positor', ' Barnett', 'alist']
