In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import split_solution_into_chunks, get_chunk_ranges, get_chunk_token_ranges
import numpy as np
from scipy import stats
import circuitsvis as cv
from IPython.display import display

# Model and device setup.
# DEVICE is defined here but not referenced by any cell shown in this notebook;
# tensors are moved with `model.device` instead.
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only when a GPU is present; full precision on CPU.
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32


# Load tokenizer and model.
# device_map="auto" lets the loader place (and, per the recorded warning below,
# offload) weights across available devices.
# NOTE(review): later cells call the model with output_attentions=True; depending
# on the installed transformers version this may require an attention backend
# that materialises attention weights (e.g. "eager") -- confirm.
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map="auto",
)
# Inference mode (disables dropout); as the cell's last expression this also
# displays the module tree.
model.eval()

Loading model and tokenizer...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.


Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 5120)
    (layers): ModuleList(
      (0-47): 48 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=True)
          (k_proj): Linear(in_features=5120, out_features=1024, bias=True)
          (v_proj): Linear(in_features=5120, out_features=1024, bias=True)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((5120,), eps=1e-05)
        (post_attention_layernorm): Qwen2RMSNorm((5120,), eps=1e-05)
      )
    )
    (norm): Qwen2RMSNorm((5120,), eps=1e-05)
    (rotary_emb

In [3]:
problem = "When the base-16 number 66666 is written in base 2, how many base-2 digits (bits) does it have?"
prompt = problem

# Generation below samples (do_sample=True). Fix the RNG seed so the generated
# solution -- and every attention statistic derived from it -- is reproducible
# after a kernel restart.
SEED = 0
torch.manual_seed(SEED)

# Tokenize input and move every tensor to the device the model expects.
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# Generate a chain-of-thought solution (repo-style settings)
with torch.no_grad():
    generated_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=2024,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        do_sample=True,  # repo style: sampling
        temperature=0.9,
        top_p=0.95,
    ).sequences

generated_ids = generated_ids[0]  # Drop the batch dimension (single prompt)

# Decode the generated text (prompt + continuation, special tokens stripped)
text = tokenizer.decode(generated_ids, skip_special_tokens=True)
print("\nGenerated CoT solution:\n", text)

# Split into sentences/chunks
sentences = split_solution_into_chunks(text)
print("\nSentences:")
for i, s in enumerate(sentences):
    print(f"[{i}] {s}")

# Get character and token ranges for each chunk.
# NOTE(review): `text` was decoded with skip_special_tokens=True, while the
# attention matrices below are indexed by positions in `generated_ids`. If
# get_chunk_token_ranges re-tokenizes `text`, its indices may be shifted
# relative to `generated_ids` whenever special tokens are present -- confirm.
chunk_char_ranges = get_chunk_ranges(text, sentences)
chunk_token_ranges = get_chunk_token_ranges(text, chunk_char_ranges, tokenizer)

num_sentences = len(sentences)

# Run the model once more over the full sequence to obtain attention weights.
# The mask is all ones (no padding) and uses an integer dtype, matching what
# the tokenizer itself produces.
full_attention_mask = torch.ones(
    (1, generated_ids.shape[0]), dtype=torch.long, device=model.device
)
with torch.no_grad():
    outputs = model(
        generated_ids.unsqueeze(0),
        attention_mask=full_attention_mask,
        output_attentions=True,
        return_dict=True
    )
    attn_weights = outputs.attentions  # tuple of num_layers tensors, each (batch, num_heads, seq, seq)

# --- Kurtosis calculation (repo-style: vertical scores of chunk-averaged matrix) ---
def avg_matrix_by_chunk(matrix, chunk_token_ranges):
    """Pool a token-level matrix into a chunk-level matrix of block means.

    Args:
        matrix: 2-D array indexed as [row_token, col_token].
        chunk_token_ranges: sequence of (start, end) half-open token ranges,
            one per chunk.

    Returns:
        float32 array of shape (n_chunks, n_chunks) where entry [i, j] is the
        mean of ``matrix[start_i:end_i, start_j:end_j]``. Blocks that select
        no elements are left at 0.
    """
    num_chunks = len(chunk_token_ranges)
    pooled = np.zeros((num_chunks, num_chunks), dtype=np.float32)
    for row, (row_lo, row_hi) in enumerate(chunk_token_ranges):
        # Slice the row band once and reuse it for every column chunk.
        band = matrix[row_lo:row_hi]
        for col, (col_lo, col_hi) in enumerate(chunk_token_ranges):
            block = band[:, col_lo:col_hi]
            if block.size > 0:
                pooled[row, col] = block.mean().item()
    return pooled

def get_attn_vert_scores(avg_mat, proximity_ignore=10, drop_first=0):
    """Score each column by the mean of its entries well below the diagonal.

    Args:
        avg_mat: square 2-D array (e.g. chunk-averaged attention).
        proximity_ignore: for column ``i`` only rows ``i + proximity_ignore``
            and beyond are averaged, skipping the rows nearest the diagonal.
        drop_first: when positive, the first and the last ``drop_first``
            scores are overwritten with NaN.

    Returns:
        1-D array with one score per column; NaN where a column has no rows
        left after the proximity cut-off.
    """
    def _column_score(col):
        # NaN-aware mean of the column segment below the ignored band.
        segment = avg_mat[col + proximity_ignore:, col]
        return np.nanmean(segment) if len(segment) > 0 else np.nan

    num_cols = avg_mat.shape[0]
    scores = np.array([_column_score(col) for col in range(num_cols)])
    if drop_first > 0:
        scores[:drop_first] = np.nan
        scores[-drop_first:] = np.nan
    return scores

attn_shape = attn_weights[0].shape  # (batch, num_heads, seq, seq)
num_layers = len(attn_weights)
num_heads = attn_shape[1]

# The exhaustive kurtosis sweep over every (layer, head) is disabled. The name
# is kept (as an empty list) so anything that still reads `kurtosis_list`
# keeps working.
kurtosis_list = []  # List of (kurtosis, layer_idx, head_idx)

# Tunables for the vertical-score ranking.
NUM_TOP_HEADS = 4      # heads kept for the following cells
PROXIMITY_IGNORE = 4   # chunks directly below the diagonal that are skipped
DROP_FIRST = 1         # first/last chunk scores blanked to NaN

# Rank heads by their strongest "vertical stripe": for each head, take the
# MAX over chunks of the mean attention a chunk receives from later chunks.
vert_scores_list = []  # List of (score, layer, head)
for layer in range(1, num_layers):  # layer 0 is excluded from the analysis
    # One device->host copy per layer instead of one per head.
    layer_attn_all = attn_weights[layer][0].cpu().numpy()  # (num_heads, seq, seq)
    for head in range(num_heads):
        layer_attn = layer_attn_all[head]  # (seq, seq)
        avg_mat = avg_matrix_by_chunk(layer_attn, chunk_token_ranges)
        vert_scores = get_attn_vert_scores(
            avg_mat, proximity_ignore=PROXIMITY_IGNORE, drop_first=DROP_FIRST
        )
        # A head with no usable score (too few chunks, or NaN attention) would
        # yield NaN, and NaN makes the sort order below undefined -- skip it.
        if np.all(np.isnan(vert_scores)):
            continue
        score = np.nanmax(vert_scores)
        vert_scores_list.append((score, layer, head))

# Sort by score descending (highest vertical score first)
vert_scores_list.sort(key=lambda x: x[0], reverse=True)

# Keep the best NUM_TOP_HEADS (score, layer, head) entries
top_heads = vert_scores_list[:NUM_TOP_HEADS]
# Manually appended entry in the same (score, layer, head) layout; the
# following cell removes it again with top_heads.pop().
top_heads.append((2, 36,6))

KeyboardInterrupt: 

In [13]:
# Remove the manually appended (2, 36, 6) entry. Guarded so that re-running
# this cell cannot silently drop one of the genuinely selected heads.
top_heads.pop() if top_heads and top_heads[-1] == (2, 36, 6) else None

(2, 36, 6)

In [9]:
# Disabled alternative: re-rank the already selected `top_heads` by the
# kurtosis of their vertical scores and keep the top 4. The whole cell is an
# inert string literal, so nothing below executes.
"""
kurtosis_list = []
for _, layer_idx, head_idx in top_heads:
    layer_attn = attn_weights[layer_idx][0, head_idx].cpu().numpy()  # (seq, seq)
    avg_mat = avg_matrix_by_chunk(layer_attn, chunk_token_ranges)
    vert_scores = get_attn_vert_scores(avg_mat, proximity_ignore=4, drop_first=0)
    kurt = stats.kurtosis(vert_scores, fisher=True, bias=True, nan_policy="omit")
    kurtosis_list.append((kurt, layer_idx, head_idx))

# Exclude layer 0 from kurtosis analysis
kurtosis_list = [entry for entry in kurtosis_list if entry[1] != 0]

# Sort by kurtosis descending and take top 3
kurtosis_list.sort(reverse=True, key=lambda x: x[0])

top_heads = kurtosis_list[:4]
"""

In [14]:
vis_mats   = []   # one (num_sentences x num_sentences) CPU tensor per head
head_names = []   # matching "L<layer>-H<head>" labels

# `top_heads` holds (score, layer, head) tuples. The score is whatever metric
# the selection cell used (currently the max vertical score, not kurtosis), so
# it is reported under a neutral label and the head count is not hard-coded.
print(f"\nTop {len(top_heads)} heads by selection score (excluding layer 0):")
for rank, (head_score, layer_idx, head_idx) in enumerate(top_heads, 1):
    print(f"[{rank}] Layer {layer_idx}, Head {head_idx}, Score: {head_score}")
    # Compute sentence-level attention matrix for this head: entry [i, j] is
    # the mean attention from the tokens of sentence i to those of sentence j.
    layer_attn = attn_weights[layer_idx][0, head_idx]  # (seq, seq)
    sentence_attn = torch.zeros(num_sentences, num_sentences)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            if start_i >= end_i or start_j >= end_j:
                continue
            sentence_pair_attn = layer_attn[start_i:end_i, start_j:end_j]
            if sentence_pair_attn.numel() == 0:
                continue
            avg_attn = sentence_pair_attn.mean()
            sentence_attn[i, j] = avg_attn
    print(f"Sentence-level attention matrix for layer {layer_idx}, head {head_idx} (shape: {sentence_attn.shape}):")
    print(sentence_attn[:5, :5])

    # Rescale so the largest entry becomes 10 (purely for visual contrast).
    # Skip the rescale when there is nothing to normalise by (empty matrix or
    # an all-zero matrix), which would otherwise raise or produce NaN/inf.
    has_signal = sentence_attn.numel() > 0 and bool(sentence_attn.max() > 0)
    if has_signal:
        blown_up_attn = 10 * sentence_attn / sentence_attn.max()
    else:
        blown_up_attn = sentence_attn.clone()

    vis_mats.append(blown_up_attn.detach().cpu())
    head_names.append(f"L{layer_idx}-H{head_idx}")


Top 3 heads by kurtosis (repo-style, excluding layer 0):
[1] Layer 45, Head 9, Kurtosis: 0.01498260535299778


Sentence-level attention matrix for layer 45, head 9 (shape: torch.Size([85, 85])):
tensor([[0.0417, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0099, 0.0346, 0.0000, 0.0000, 0.0000],
        [0.0072, 0.0224, 0.0417, 0.0000, 0.0000],
        [0.0113, 0.0116, 0.0211, 0.0093, 0.0000],
        [0.0097, 0.0100, 0.0114, 0.0085, 0.0173]])
[2] Layer 42, Head 30, Kurtosis: 0.014758300967514515
Sentence-level attention matrix for layer 42, head 30 (shape: torch.Size([85, 85])):
tensor([[0.0417, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0286, 0.0142, 0.0000, 0.0000, 0.0000],
        [0.0278, 0.0136, 0.0040, 0.0000, 0.0000],
        [0.0276, 0.0078, 0.0049, 0.0038, 0.0000],
        [0.0237, 0.0079, 0.0051, 0.0059, 0.0020]])
[3] Layer 27, Head 21, Kurtosis: 0.014102173037827015
Sentence-level attention matrix for layer 27, head 21 (shape: torch.Size([85, 85])):
tensor([[0.0417, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0270, 0.0160, 0.0000, 0.0000, 0.0000],
        [0.0171, 0.0208, 0.0165, 0.0

In [None]:
# Stack the per-head sentence matrices into a single (k, S, S) tensor.
heads_tensor = torch.stack(vis_mats)

# Render one panel per head. Rows/columns are labelled with the sentences, and
# the upper triangle is left unmasked because the matrices are sentence-level
# aggregates rather than raw causal token attention.
display(
    cv.attention.attention_heads(
        tokens=sentences,
        attention=heads_tensor.numpy(),
        attention_head_names=head_names,
        mask_upper_tri=False,
    )
)

In [21]:
top_k = 5  # Number of top sentences per head
head_to_top_sentences = {}

for _score, layer_idx, head_idx in top_heads:
    layer_attn = attn_weights[layer_idx][0, head_idx]  # (seq, seq)

    # Pool token-level attention into a (sentence x sentence) grid of means.
    sentence_attn = torch.zeros(num_sentences, num_sentences)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        if start_i >= end_i:
            continue
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            if start_j >= end_j:
                continue
            block = layer_attn[start_i:end_i, start_j:end_j]
            if block.numel() > 0:
                sentence_attn[i, j] = block.mean()

    # Score sentence j by the total attention it receives from strictly later
    # sentences (column j, rows below the diagonal).
    sentence_attn_np = sentence_attn.numpy()
    sentence_scores = np.zeros(num_sentences)
    for j in range(num_sentences):
        rows_below = np.arange(j + 1, num_sentences)
        sentence_scores[j] += sentence_attn_np[rows_below, j].sum()

    ranked = np.argsort(-sentence_scores)
    head_to_top_sentences[(layer_idx, head_idx)] = [sentences[idx] for idx in ranked[:top_k]]


import openai
from dotenv import load_dotenv
import os
# Read variables from a local .env file (if any) into the environment, then
# build the OpenAI client from OPENAI_API_KEY. os.getenv returns None when the
# variable is unset, so no key is ever hard-coded in the notebook.
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
client = openai.OpenAI(api_key=api_key)


# Candidate labels offered to the LLM when categorising a head. Each entry is
# "<tag>: <description>" and is inserted verbatim into the prompt built by
# categorize_head, so edits here change the prompt text.
DAG_CATEGORIES = [
    "problem_setup: Parsing or rephrasing the problem (initial reading or comprehension).",
    "plan_generation: Stating or deciding on a plan of action (often meta-reasoning).",
    "fact_retrieval: Recalling facts, formulas, problem details (without immediate computation).",
    "active_computation: Performing algebra, calculations, manipulations toward the answer.",
    "result_consolidation: Aggregating intermediate results, summarizing, or preparing final answer.",
    "uncertainty_management: Expressing confusion, re-evaluating, proposing alternative plans (includes backtracking).",
    "final_answer_emission: Explicit statement of the final boxed answer or earlier chunks that contain the final answer.",
    "self_checking: Verifying previous steps, Pythagorean checking, re-confirmations.",
    "unknown: Use only if the chunk does not fit any of the above tags or is purely stylistic or semantic."
]

def categorize_head(layer_idx, head_idx, top_sentences, model="gpt-3.5-turbo"):
    """Ask a chat model which reasoning category an attention head focuses on.

    Args:
        layer_idx: layer index of the head (used only in the prompt text).
        head_idx: head index within the layer (used only in the prompt text).
        top_sentences: iterable of sentences the head attends to most; they
            are listed, numbered from 1, in the prompt.
        model: chat-completion model name; defaults to the model this
            notebook has always used.

    Returns:
        The model's reply with surrounding whitespace removed, or an empty
        string when the API returns no text content.
    """
    sentence_lines = "".join(
        f"{i}. \"{sent}\"\n" for i, sent in enumerate(top_sentences, 1)
    )
    category_lines = "".join(f"- {cat}\n" for cat in DAG_CATEGORIES)
    prompt = (
        f"Here are the top sentences that received the most attention from attention head (layer {layer_idx}, head {head_idx}):\n\n"
        f"{sentence_lines}"
        "\nBased on these sentences, which of the following categories best describes what this attention head is focusing on? "
        "You may select more than one if appropriate. Please respond with the category name(s) only, separated by commas if more than one.\n\n"
        "Categories:\n"
        f"{category_lines}"
        "\nCategory/ies:"
    )

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=50,
        temperature=0
    )
    # `content` may be None (e.g. a refusal or filtered reply); avoid crashing
    # on .strip() and hand back an empty label instead.
    content = response.choices[0].message.content
    return (content or "").strip()

# Label every selected head via the LLM, reporting each answer as it arrives.
head_to_category = {}
for (layer_idx, head_idx), top_sentences in head_to_top_sentences.items():
    category = head_to_category[(layer_idx, head_idx)] = categorize_head(
        layer_idx, head_idx, top_sentences
    )
    print(f"Head (Layer {layer_idx}, Head {head_idx}): {category}")

Head (Layer 45, Head 9): problem_setup, uncertainty_management
Head (Layer 42, Head 30): problem_setup, uncertainty_management
Head (Layer 27, Head 21): problem_setup, plan_generation, uncertainty_management
Head (Layer 28, Head 24): problem_setup, plan_generation, uncertainty_management


In [None]:
# --- Rank sentences by average attention received from the selected heads (lower triangular only) ---
# sentence_scores[j]: summed attention that sentence j receives from strictly
#                     later sentences, accumulated over all heads in top_heads.
# sentence_counts[j]: number of non-zero entries that went into that sum.
sentence_scores = np.zeros(num_sentences)
sentence_counts = np.zeros(num_sentences)

for _score, layer_idx, head_idx in top_heads:
    layer_attn = attn_weights[layer_idx][0, head_idx]  # (seq, seq)
    sentence_attn = torch.zeros(num_sentences, num_sentences)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            if start_i >= end_i or start_j >= end_j:
                continue
            sentence_pair_attn = layer_attn[start_i:end_i, start_j:end_j]
            if sentence_pair_attn.numel() == 0:
                continue
            avg_attn = sentence_pair_attn.mean()
            sentence_attn[i, j] = avg_attn

    # Convert to numpy and replace NaN with 0. Zero entries (including
    # replaced NaNs) add nothing to the sum and are NOT counted below, so the
    # final average is taken over non-zero entries only.
    sentence_attn_np = np.nan_to_num(sentence_attn.numpy(), nan=0.0)

    # Only consider lower triangular part (excluding diagonal)
    for j in range(num_sentences):
        # Attention paid TO sentence j (column j), from i > j
        lower_indices = np.arange(j+1, num_sentences)
        values = sentence_attn_np[lower_indices, j]
        sentence_scores[j] += values.sum()
        sentence_counts[j] += (values != 0).sum()

# Avoid division by zero: sentences with no counted entries get an average of 0
sentence_avgs = np.divide(sentence_scores, sentence_counts, out=np.zeros_like(sentence_scores), where=sentence_counts!=0)

descending_ranking = np.argsort(-sentence_avgs)  # descending order

# Report the real number of heads instead of a hard-coded "3".
print(f"\nSentence ranking by average attention received from top {len(top_heads)} heads (lower triangular, non-zero entries only):")
for rank, idx in enumerate(descending_ranking, 1):
    print(f"[{rank}] Sentence {idx} (avg score: {sentence_avgs[idx]:.4f}, count: {int(sentence_counts[idx])}): {sentences[idx]}")



Sentence ranking by average attention received from top 3 heads (lower triangular, NaN treated as 0):
[1] Sentence 76 (avg score: 0.0173, count: 32): Wait, but I'm not entirely sure.
[2] Sentence 0 (avg score: 0.0121, count: 336): give me a complete step by step well reasoned answer to should I eat beef or eggs to reduce suffering in the world.
[3] Sentence 40 (avg score: 0.0047, count: 176): I should also consider alternatives.
[4] Sentence 60 (avg score: 0.0046, count: 96): I'm also thinking about the economic aspect.
[5] Sentence 82 (avg score: 0.0036, count: 8): Alternatively, perhaps the best approach is to minimize demand for animal products altogether, but since the question is to choose between beef and eggs, I have to pick the lesser of two evils.
[6] Sentence 1 (avg score: 0.0031, count: 332): Okay, I need to figure out whether I should eat beef or eggs to reduce suffering in the world.
[7] Sentence 83 (avg score: 0.0031, count: 4): Based on the information I've considered, 