In [17]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import split_solution_into_chunks, get_chunk_ranges, get_chunk_token_ranges
import numpy as np
from scipy import stats
import circuitsvis as cv
from IPython.display import display

# Model and device setup
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32


# Load tokenizer and model
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map="auto",
)
model.eval()

Loading model and tokenizer...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 3584)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((3584,), eps=1e-06)
    (rotary_emb):

In [None]:
# The math problem prompt (repo style: uses <think> and expects \\boxed{})
problem = (
    "When the base-16 number 66666 is written in base 2, how many base-2 digits (bits) does it have?"
)
prompt = (
    "Solve this math problem step by step. Go step by step in as much detail as possible. You MUST put your final answer in \\boxed{}. "
    f"Problem: {problem} Solution: \n<think>\n"
)

GROUND_TRUTH_ANSWER = "19"  # Only the number, as extracted from \boxed{}
MAX_ATTEMPTS = 1
"""
# Load tokenizer and model
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map="auto",
)
model.eval()
"""
NUM_ATTEMPTS = 3
all_texts = []
all_attn_weights = []
all_chunk_token_ranges = []

for attempt in range(1, NUM_ATTEMPTS + 1):
    print(f"\nAttempt {attempt}...")
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate a chain-of-thought solution (repo-style settings)
    with torch.no_grad():
        generated_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=2048,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            do_sample=True,  # repo style: sampling
            temperature=0.6,
            top_p=0.95,
        ).sequences

    generated_ids = generated_ids[0]  # Remove batch dim if present

    # Decode the generated text
    text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    print("\nGenerated CoT solution:\n", text)
    all_texts.append(text)

    # Extract answer using repo's method
    from utils import extract_boxed_answers
    answers = extract_boxed_answers(text)
    answer = answers[0] if answers else ""
    print(f"Extracted answer: {answer}")

    # Split into sentences/chunks
    sentences = split_solution_into_chunks(text)
    chunk_char_ranges = get_chunk_ranges(text, sentences)
    chunk_token_ranges = get_chunk_token_ranges(text, chunk_char_ranges, tokenizer)
    all_chunk_token_ranges.append(chunk_token_ranges)

    # Get attention weights for the generated sequence
    full_attention_mask = torch.ones((1, generated_ids.shape[0]), device=model.device)
    with torch.no_grad():
        outputs = model(
            generated_ids.unsqueeze(0),
            attention_mask=full_attention_mask,
            output_attentions=True,
            return_dict=True
        )
        attn_weights = outputs.attentions  # tuple: (num_layers, batch, num_heads, seq, seq)
        all_attn_weights.append(attn_weights)

# --- Aggregate vertical scores across attempts ---

num_layers = len(all_attn_weights[0])
num_heads = all_attn_weights[0][0].shape[1]

def avg_matrix_by_chunk(matrix, chunk_token_ranges):
    n = len(chunk_token_ranges)
    avg_mat = np.zeros((n, n), dtype=np.float32)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            region = matrix[start_i:end_i, start_j:end_j]
            if region.size > 0:
                avg_mat[i, j] = region.mean().item()
    return avg_mat

def get_attn_vert_scores(avg_mat, proximity_ignore=10, drop_first=0):
    n = avg_mat.shape[0]
    vert_scores = []
    for i in range(n):
        vert_lines = avg_mat[i + proximity_ignore :, i]
        vert_score = np.nanmean(vert_lines) if len(vert_lines) > 0 else np.nan
        vert_scores.append(vert_score)
    vert_scores = np.array(vert_scores)
    if drop_first > 0:
        vert_scores[:drop_first] = np.nan
        vert_scores[-drop_first:] = np.nan
    return vert_scores

# For each (layer, head), collect vert_scores from each attempt and average them
vert_scores_dict = {}  # (layer, head) -> list of vert_scores

for attempt in range(NUM_ATTEMPTS):
    attn_weights = all_attn_weights[attempt]
    chunk_token_ranges = all_chunk_token_ranges[attempt]
    for layer in range(1, num_layers):  # skip layer 0
        for head in range(num_heads):
            layer_attn = attn_weights[layer][0, head].cpu().numpy()  # (seq, seq)
            avg_mat = avg_matrix_by_chunk(layer_attn, chunk_token_ranges)
            vert_scores = get_attn_vert_scores(avg_mat, proximity_ignore=4, drop_first=1)
            score = np.nanmax(vert_scores)
            key = (layer, head)
            if key not in vert_scores_dict:
                vert_scores_dict[key] = []
            vert_scores_dict[key].append(score)

# Compute average vert_score for each (layer, head)
avg_vert_scores = []
for (layer, head), scores in vert_scores_dict.items():
    avg_score = np.nanmean(scores)
    avg_vert_scores.append((avg_score, layer, head))

# Sort by average vert_score descending
avg_vert_scores.sort(key=lambda x: x[0], reverse=True)

# Get the top 12 (layer, head) pairs
top_heads = avg_vert_scores[:12]

print("\nTop heads by average vert_score over 3 attempts:")
for score, layer, head in top_heads:
    print(f"Layer {layer}, Head {head}: avg vert_score = {score:.4f}")
top_heads.append((2, 36,6))


Attempt 1...



Generated CoT solution:
 Solve this math problem step by step. Go step by step in as much detail as possible. You MUST put your final answer in \boxed{}. Problem: When the base-16 number 66666 is written in base 2, how many base-2 digits (bits) does it have? Solution: 
<think>
Okay, so I have this math problem here: When the base-16 number 66666 is written in base 2, how many base-2 digits (bits) does it have? Hmm, okay, let's try to figure this out step by step. I'm a bit rusty on number bases, but I think I can handle this.

First off, I know that base-16, or hexadecimal, is a number system with 16 digits. Each digit represents a value from 0 to 15. These digits are usually represented by 0-9 and then A-F for 10-15. So, the number 66666 in hexadecimal is a number where each '6' is a single digit in base 16. Got that.

Now, the question is asking for the number of bits (base-2 digits) when this hexadecimal number is converted to binary. So, essentially, I need to convert 66666 (base 

In [10]:
top_heads.pop()

(2, 36, 6)

In [11]:
"""
kurtosis_list = []
for _, layer_idx, head_idx in top_heads:
    layer_attn = attn_weights[layer_idx][0, head_idx].cpu().numpy()  # (seq, seq)
    avg_mat = avg_matrix_by_chunk(layer_attn, chunk_token_ranges)
    vert_scores = get_attn_vert_scores(avg_mat, proximity_ignore=4, drop_first=0)
    kurt = stats.kurtosis(vert_scores, fisher=True, bias=True, nan_policy="omit")
    kurtosis_list.append((kurt, layer_idx, head_idx))

# Exclude layer 0 from kurtosis analysis
kurtosis_list = [entry for entry in kurtosis_list if entry[1] != 0]

# Sort by kurtosis descending and take top 3
kurtosis_list.sort(reverse=True, key=lambda x: x[0])

top_heads = kurtosis_list[:4]
"""

'\nkurtosis_list = []\nfor _, layer_idx, head_idx in top_heads:\n    layer_attn = attn_weights[layer_idx][0, head_idx].cpu().numpy()  # (seq, seq)\n    avg_mat = avg_matrix_by_chunk(layer_attn, chunk_token_ranges)\n    vert_scores = get_attn_vert_scores(avg_mat, proximity_ignore=4, drop_first=0)\n    kurt = stats.kurtosis(vert_scores, fisher=True, bias=True, nan_policy="omit")\n    kurtosis_list.append((kurt, layer_idx, head_idx))\n\n# Exclude layer 0 from kurtosis analysis\nkurtosis_list = [entry for entry in kurtosis_list if entry[1] != 0]\n\n# Sort by kurtosis descending and take top 3\nkurtosis_list.sort(reverse=True, key=lambda x: x[0])\n\ntop_heads = kurtosis_list[:4]\n'

In [12]:
# Choose which attempt to visualize (0, 1, or 2)
attempt_to_visualize = 1
attn_weights = all_attn_weights[attempt_to_visualize]
chunk_token_ranges = all_chunk_token_ranges[attempt_to_visualize]
num_sentences = len(chunk_token_ranges)

vis_mats   = []
head_names = []

print("\nGlobal top heads on this attempt:")
for rank, (score, layer_idx, head_idx) in enumerate(top_heads, 1):
    # Check bounds to avoid IndexError
    if layer_idx >= len(attn_weights):
        print(f"Skipping Layer {layer_idx}, out of range for this attempt.")
        continue
    if head_idx >= attn_weights[layer_idx].shape[1]:
        print(f"Skipping Head {head_idx} in Layer {layer_idx}, out of range for this attempt.")
        continue

    print(f"[{rank}] Layer {layer_idx}, Head {head_idx}, Avg Score: {score}")
    layer_attn = attn_weights[layer_idx][0, head_idx]
    sentence_attn = torch.zeros(num_sentences, num_sentences)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            if start_i >= end_i or start_j >= end_j:
                continue
            sentence_pair_attn = layer_attn[start_i:end_i, start_j:end_j]
            if sentence_pair_attn.numel() == 0:
                continue
            avg_attn = sentence_pair_attn.mean()
            sentence_attn[i, j] = avg_attn
    print(f"Sentence-level attention matrix for layer {layer_idx}, head {head_idx} (shape: {sentence_attn.shape}):")
    print(sentence_attn[:5, :5])

    blown_up_attn = 10 * sentence_attn / sentence_attn.max()
    vis_mats.append(blown_up_attn.detach().cpu())
    head_names.append(f"L{layer_idx}-H{head_idx}")


Global top heads on this attempt:
[1] Layer 1, Head 0, Avg Score: nan
Sentence-level attention matrix for layer 1, head 0 (shape: torch.Size([94, 94])):
tensor([[nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan]])
[2] Layer 1, Head 1, Avg Score: nan
Sentence-level attention matrix for layer 1, head 1 (shape: torch.Size([94, 94])):
tensor([[nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan]])
[3] Layer 1, Head 2, Avg Score: nan
Sentence-level attention matrix for layer 1, head 2 (shape: torch.Size([94, 94])):
tensor([[nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan]])
[4] Layer 1, Head 3, Avg Score: nan
Sentence-level attention matrix for layer 1,

In [13]:
heads_tensor = torch.stack(vis_mats)               # (k, S, S)

display(
    cv.attention.attention_heads(
        attention           = heads_tensor.numpy(),   # NumPy or list is fine
        tokens              = sentences,              # axis labels
        attention_head_names= head_names,             # hover label
        mask_upper_tri      = False                   # we aggregated, so not causal
    )
)

In [None]:
top_k = 5  # Number of top sentences per head
head_to_top_sentences = {}

for kurt, layer_idx, head_idx in top_heads:
    layer_attn = attn_weights[layer_idx][0, head_idx]  # (seq, seq)
    sentence_attn = torch.zeros(num_sentences, num_sentences)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            if start_i >= end_i or start_j >= end_j:
                continue
            sentence_pair_attn = layer_attn[start_i:end_i, start_j:end_j]
            if sentence_pair_attn.numel() == 0:
                continue
            avg_attn = sentence_pair_attn.mean()
            sentence_attn[i, j] = avg_attn

    # Only consider lower triangular part (excluding diagonal)
    sentence_scores = np.zeros(num_sentences)
    for j in range(num_sentences):
        lower_indices = np.arange(j+1, num_sentences)
        values = sentence_attn.numpy()[lower_indices, j]
        sentence_scores[j] += values.sum()

    top_indices = np.argsort(-sentence_scores)[:top_k]
    head_to_top_sentences[(layer_idx, head_idx)] = [sentences[idx] for idx in top_indices]


import openai
from dotenv import load_dotenv
import os
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
client = openai.OpenAI(api_key=api_key)


DAG_CATEGORIES = [
    "problem_setup: Parsing or rephrasing the problem (initial reading or comprehension).",
    "plan_generation: Stating or deciding on a plan of action (often meta-reasoning).",
    "fact_retrieval: Recalling facts, formulas, problem details (without immediate computation).",
    "active_computation: Performing algebra, calculations, manipulations toward the answer.",
    "result_consolidation: Aggregating intermediate results, summarizing, or preparing final answer.",
    "uncertainty_management: Expressing confusion, re-evaluating, proposing alternative plans (includes backtracking).",
    "final_answer_emission: Explicit statement of the final boxed answer or earlier chunks that contain the final answer.",
    "self_checking: Verifying previous steps, Pythagorean checking, re-confirmations.",
    "unknown: Use only if the chunk does not fit any of the above tags or is purely stylistic or semantic."
]

def categorize_head(layer_idx, head_idx, top_sentences):
    prompt = (
        f"Here are the top sentences that received the most attention from attention head (layer {layer_idx}, head {head_idx}):\n\n"
    )
    for i, sent in enumerate(top_sentences, 1):
        prompt += f"{i}. \"{sent}\"\n"
    prompt += (
        "\nBased on these sentences, which of the following categories best describes what this attention head is focusing on? "
        "You may select more than one if appropriate. Please respond with the category name(s) only, separated by commas if more than one.\n\n"
        "Categories:\n"
    )
    for cat in DAG_CATEGORIES:
        prompt += f"- {cat}\n"
    prompt += "\nCategory/ies:"

    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=50,
        temperature=0
    )
    return response.choices[0].message.content.strip()

head_to_category = {}
for (layer_idx, head_idx), top_sentences in head_to_top_sentences.items():
    category = categorize_head(layer_idx, head_idx, top_sentences)
    head_to_category[(layer_idx, head_idx)] = category
    print(f"Head (Layer {layer_idx}, Head {head_idx}): {category}")

Head (Layer 1, Head 0): - active_computation
- uncertainty_management
Head (Layer 1, Head 1): - active_computation
- plan_generation
Head (Layer 1, Head 2): - active_computation
- plan_generation
Head (Layer 1, Head 3): - active_computation
- uncertainty_management
Head (Layer 1, Head 4): - active_computation
- uncertainty_management
Head (Layer 1, Head 5): - active_computation
- plan_generation
Head (Layer 1, Head 6): - active_computation
- uncertainty_management
Head (Layer 1, Head 7): - active_computation
- plan_generation
Head (Layer 1, Head 8): - active_computation
- plan_generation
- uncertainty_management
Head (Layer 1, Head 9): - active_computation
- plan_generation
Head (Layer 1, Head 10): - active_computation
- uncertainty_management
Head (Layer 1, Head 11): - active_computation
- uncertainty_management


In [None]:
# --- Rank sentences by average attention received from the top heads (NaN treated as 0, lower triangular only) ---
sentence_scores = np.zeros(num_sentences)
sentence_counts = np.zeros(num_sentences)

for kurt, layer_idx, head_idx in top_heads:
    layer_attn = attn_weights[layer_idx][0, head_idx]  # (seq, seq)
    sentence_attn = torch.zeros(num_sentences, num_sentences)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            if start_i >= end_i or start_j >= end_j:
                continue
            sentence_pair_attn = layer_attn[start_i:end_i, start_j:end_j]
            if sentence_pair_attn.numel() == 0:
                continue
            avg_attn = sentence_pair_attn.mean()
            sentence_attn[i, j] = avg_attn

    # Convert to numpy and replace NaN with 0
    sentence_attn_np = sentence_attn.numpy()
    sentence_attn_np = np.nan_to_num(sentence_attn_np, nan=0.0)

    # Only consider lower triangular part (excluding diagonal)
    for j in range(num_sentences):
        # Attention paid TO sentence j (column j), from i > j
        lower_indices = np.arange(j+1, num_sentences)
        values = sentence_attn_np[lower_indices, j]
        sentence_scores[j] += values.sum()
        sentence_counts[j] += (values != 0).sum()

# Avoid division by zero
sentence_avgs = np.divide(sentence_scores, sentence_counts, out=np.zeros_like(sentence_scores), where=sentence_counts!=0)

descending_ranking = np.argsort(-sentence_avgs)  # descending order

print("\nSentence ranking by average attention received from top 3 heads (lower triangular, NaN treated as 0):")
for rank, idx in enumerate(descending_ranking, 1):
    print(f"[{rank}] Sentence {idx} (avg score: {sentence_avgs[idx]:.4f}, count: {int(sentence_counts[idx])}): {sentences[idx]}")



Sentence ranking by average attention received from top 3 heads (lower triangular, NaN treated as 0):
[1] Sentence 0 (avg score: 0.0000, count: 0): Okay, so I need to figure out how many base-2 digits, or bits, the base-16 number 66666 has when written in binary.
[2] Sentence 67 (avg score: 0.0000, count: 0): First, find the highest power of 2 less than 419,430.
[3] Sentence 66 (avg score: 0.0000, count: 0): But maybe it's faster to use a calculator, but since I'm doing this manually, let's see.
[4] Sentence 65 (avg score: 0.0000, count: 0): Let me try the division method for a bit.
[5] Sentence 64 (avg score: 0.0000, count: 0): Alternatively, we can use a method where we find the largest power of 2 less than or equal to the number and work our way down.
[6] Sentence 63 (avg score: 0.0000, count: 0): To convert a decimal number to binary, we can repeatedly divide by 2 and record the remainders.
[7] Sentence 62 (avg score: 0.0000, count: 0): Now, let's convert 419,430 to binary.
[8] Se