In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import split_solution_into_chunks, get_chunk_ranges, get_chunk_token_ranges
import numpy as np
from scipy import stats
import circuitsvis as cv
from IPython.display import display

# Model and device setup
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32


# Load tokenizer and model
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map="auto",
)
model.eval()

Loading model and tokenizer...


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/680 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-000002.safetensors:   0%|          | 0.00/8.61G [00:00<?, ?B/s]

model-00002-of-000002.safetensors:   0%|          | 0.00/6.62G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 3584)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((3584,), eps=1e-06)
    (rotary_emb):

In [12]:
# The math problem prompt (repo style: uses <think> and expects \\boxed{})
problem = (
    "When the base-16 number 66666 is written in base 2, how many base-2 digits (bits) does it have?"
)
prompt = (
    "Solve this math problem step by step. Go step by step in as much detail as possible. You MUST put your final answer in \\boxed{}. "
    f"Problem: {problem} Solution: \n<think>\n"
)

GROUND_TRUTH_ANSWER = "19"  # Only the number, as extracted from \boxed{}
MAX_ATTEMPTS = 1
"""
# Load tokenizer and model
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map="auto",
)
model.eval()
"""

for attempt in range(1, MAX_ATTEMPTS + 1):
    print(f"\nAttempt {attempt}...")
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate a chain-of-thought solution (repo-style settings)
    with torch.no_grad():
        generated_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=2048,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            do_sample=True,  # repo style: sampling
            temperature=0.6,
            top_p=0.95,
        ).sequences

    generated_ids = generated_ids[0]  # Remove batch dim if present

    # Decode the generated text
    text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    print("\nGenerated CoT solution:\n", text)

    # Extract answer using repo's method
    from utils import extract_boxed_answers
    answers = extract_boxed_answers(text)
    answer = answers[0] if answers else ""
    print(f"Extracted answer: {answer}")

    if answer == GROUND_TRUTH_ANSWER:
        print("Correct answer found!")
        break
    else:
        print("Incorrect answer, retrying...")

else:
    print("Failed to generate the correct answer after max attempts.")


# Split into sentences/chunks
sentences = split_solution_into_chunks(text)
print("\nSentences:")
for i, s in enumerate(sentences):
    print(f"[{i}] {s}")

# Get character and token ranges for each chunk
chunk_char_ranges = get_chunk_ranges(text, sentences)
chunk_token_ranges = get_chunk_token_ranges(text, chunk_char_ranges, tokenizer)

num_sentences = len(sentences)
# Run model again to get attention weights for the generated sequence
full_attention_mask = torch.ones((1, generated_ids.shape[0]), device=model.device)
with torch.no_grad():
    outputs = model(
        generated_ids.unsqueeze(0),
        attention_mask=full_attention_mask,
        output_attentions=True,
        return_dict=True
    )
    attn_weights = outputs.attentions  # tuple: (num_layers, batch, num_heads, seq, seq)

# --- Kurtosis calculation (repo-style: vertical scores of chunk-averaged matrix) ---
def avg_matrix_by_chunk(matrix, chunk_token_ranges):
    n = len(chunk_token_ranges)
    avg_mat = np.zeros((n, n), dtype=np.float32)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            region = matrix[start_i:end_i, start_j:end_j]
            if region.size > 0:
                avg_mat[i, j] = region.mean().item()
    return avg_mat

def get_attn_vert_scores(avg_mat, proximity_ignore=10, drop_first=0):
    n = avg_mat.shape[0]
    vert_scores = []
    for i in range(n):
        vert_lines = avg_mat[i + proximity_ignore :, i]
        vert_score = np.nanmean(vert_lines) if len(vert_lines) > 0 else np.nan
        vert_scores.append(vert_score)
    vert_scores = np.array(vert_scores)
    if drop_first > 0:
        vert_scores[:drop_first] = np.nan
        vert_scores[-drop_first:] = np.nan
    return vert_scores

attn_shape = attn_weights[0].shape  # (batch, num_heads, seq, seq)
num_layers = len(attn_weights)
num_heads = attn_shape[1]
kurtosis_list = []  # List of (kurtosis, layer_idx, head_idx)
"""
for layer_idx in range(num_layers):
    for head_idx in range(num_heads):
        layer_attn = attn_weights[layer_idx][0, head_idx].cpu().numpy()  # (seq, seq)
        avg_mat = avg_matrix_by_chunk(layer_attn, chunk_token_ranges)
        vert_scores = get_attn_vert_scores(avg_mat, proximity_ignore=4, drop_first=0)
        kurt = stats.kurtosis(vert_scores, fisher=True, bias=True, nan_policy="omit")
        kurtosis_list.append((kurt, layer_idx, head_idx))
"""
# Exclude layer 0 from kurtosis analysis
kurtosis_list = [entry for entry in kurtosis_list if entry[1] != 0]

# Sort by kurtosis descending and take top 3
kurtosis_list.sort(reverse=True, key=lambda x: x[0])

#top_heads = kurtosis_list[:12]



vert_scores_list = []
for layer in range(1, num_layers):
    for head in range(num_heads):
        layer_attn = attn_weights[layer][0, head].cpu().numpy()  # (seq, seq)
        avg_mat = avg_matrix_by_chunk(layer_attn, chunk_token_ranges)
        vert_scores = get_attn_vert_scores(avg_mat, proximity_ignore=4, drop_first=1)
        # Aggregate: use mean of vert_scores (ignoring NaNs)
        score = np.nanmax(vert_scores)
        vert_scores_list.append((score, layer, head))

# Sort by score descending (highest vert_scores)
vert_scores_list.sort(key=lambda x: x[0], reverse=True)

# Get the top 12 (layer, head) pairs
top_heads = vert_scores_list[:12]
top_heads.append((2, 36,6))


Attempt 1...



Generated CoT solution:
 Solve this math problem step by step. Go step by step in as much detail as possible. Use many sentences, and check your work. You MUST put your final answer in \boxed{}. Problem: When the base-16 number 66666 is written in base 2, how many base-2 digits (bits) does it have? Solution: 
<think>
Okay, so I need to figure out how many bits the hexadecimal number 66666 has when converted to binary. Hmm, let me start by understanding what the problem is asking. I know that hexadecimal (base-16) and binary (base-2) are both number systems used in computing. Each hexadecimal digit corresponds to four binary digits, or bits. So, maybe I can use that to find the number of bits.

First, I should confirm how many hexadecimal digits are in 66666. Let me count them: 6, 6, 6, 6, 6. That's five digits. Each of these digits represents four bits. So, if I multiply the number of hexadecimal digits by four, I might get the number of bits. Let me do that: 5 digits * 4 bits/digit =

  vert_score = np.nanmean(vert_lines) if len(vert_lines) > 0 else np.nan


In [13]:
top_heads.pop()

(2, 36, 6)

In [14]:
"""
kurtosis_list = []
for _, layer_idx, head_idx in top_heads:
    layer_attn = attn_weights[layer_idx][0, head_idx].cpu().numpy()  # (seq, seq)
    avg_mat = avg_matrix_by_chunk(layer_attn, chunk_token_ranges)
    vert_scores = get_attn_vert_scores(avg_mat, proximity_ignore=4, drop_first=0)
    kurt = stats.kurtosis(vert_scores, fisher=True, bias=True, nan_policy="omit")
    kurtosis_list.append((kurt, layer_idx, head_idx))

# Exclude layer 0 from kurtosis analysis
kurtosis_list = [entry for entry in kurtosis_list if entry[1] != 0]

# Sort by kurtosis descending and take top 3
kurtosis_list.sort(reverse=True, key=lambda x: x[0])

top_heads = kurtosis_list[:4]
"""

'\nkurtosis_list = []\nfor _, layer_idx, head_idx in top_heads:\n    layer_attn = attn_weights[layer_idx][0, head_idx].cpu().numpy()  # (seq, seq)\n    avg_mat = avg_matrix_by_chunk(layer_attn, chunk_token_ranges)\n    vert_scores = get_attn_vert_scores(avg_mat, proximity_ignore=4, drop_first=0)\n    kurt = stats.kurtosis(vert_scores, fisher=True, bias=True, nan_policy="omit")\n    kurtosis_list.append((kurt, layer_idx, head_idx))\n\n# Exclude layer 0 from kurtosis analysis\nkurtosis_list = [entry for entry in kurtosis_list if entry[1] != 0]\n\n# Sort by kurtosis descending and take top 3\nkurtosis_list.sort(reverse=True, key=lambda x: x[0])\n\ntop_heads = kurtosis_list[:4]\n'

In [15]:
vis_mats   = []   # a list of (num_sentences × num_sentences) tensors
head_names = []

print("\nTop 3 heads by kurtosis (repo-style, excluding layer 0):")
for rank, (kurt, layer_idx, head_idx) in enumerate(top_heads, 1):
    print(f"[{rank}] Layer {layer_idx}, Head {head_idx}, Kurtosis: {kurt}")
    # Compute sentence-level attention matrix for this head
    layer_attn = attn_weights[layer_idx][0, head_idx]  # (seq, seq)
    sentence_attn = torch.zeros(num_sentences, num_sentences)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            if start_i >= end_i or start_j >= end_j:
                continue
            sentence_pair_attn = layer_attn[start_i:end_i, start_j:end_j]
            if sentence_pair_attn.numel() == 0:
                continue
            avg_attn = sentence_pair_attn.mean()
            sentence_attn[i, j] = avg_attn
    print(f"Sentence-level attention matrix for layer {layer_idx}, head {head_idx} (shape: {sentence_attn.shape}):")
    print(sentence_attn[:5, :5])

    blown_up_attn =10* sentence_attn / sentence_attn.max()

    vis_mats.append(blown_up_attn.detach().cpu())
    head_names.append(f"L{layer_idx}-H{head_idx}")


Top 3 heads by kurtosis (repo-style, excluding layer 0):
[1] Layer 23, Head 4, Kurtosis: 0.0289764404296875


Sentence-level attention matrix for layer 23, head 4 (shape: torch.Size([87, 87])):
tensor([[0.0035, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0043, 0.0029, 0.0000, 0.0000, 0.0000],
        [0.0035, 0.0019, 0.0011, 0.0000, 0.0000],
        [0.0052, 0.0047, 0.0005, 0.0002, 0.0000],
        [0.0034, 0.0029, 0.0003, 0.0005, 0.0010]])
[2] Layer 9, Head 1, Kurtosis: 0.02215576171875
Sentence-level attention matrix for layer 9, head 1 (shape: torch.Size([87, 87])):
tensor([[0.0048, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0078, 0.0112, 0.0000, 0.0000, 0.0000],
        [0.0039, 0.0174, 0.0061, 0.0000, 0.0000],
        [0.0031, 0.0108, 0.0094, 0.0056, 0.0000],
        [0.0058, 0.0178, 0.0021, 0.0013, 0.0052]])
[3] Layer 1, Head 2, Kurtosis: 0.01424407958984375
Sentence-level attention matrix for layer 1, head 2 (shape: torch.Size([87, 87])):
tensor([[0.0067, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0170, 0.0018, 0.0000, 0.0000, 0.0000],
        [0.0192, 0.0136, 0.0015, 0.0000, 0.0000],

In [16]:
heads_tensor = torch.stack(vis_mats)               # (k, S, S)

display(
    cv.attention.attention_heads(
        attention           = heads_tensor.numpy(),   # NumPy or list is fine
        tokens              = sentences,              # axis labels
        attention_head_names= head_names,             # hover label
        mask_upper_tri      = False                   # we aggregated, so not causal
    )
)

In [17]:
top_k = 5  # Number of top sentences per head
head_to_top_sentences = {}

for kurt, layer_idx, head_idx in top_heads:
    layer_attn = attn_weights[layer_idx][0, head_idx]  # (seq, seq)
    sentence_attn = torch.zeros(num_sentences, num_sentences)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            if start_i >= end_i or start_j >= end_j:
                continue
            sentence_pair_attn = layer_attn[start_i:end_i, start_j:end_j]
            if sentence_pair_attn.numel() == 0:
                continue
            avg_attn = sentence_pair_attn.mean()
            sentence_attn[i, j] = avg_attn

    # Only consider lower triangular part (excluding diagonal)
    sentence_scores = np.zeros(num_sentences)
    for j in range(num_sentences):
        lower_indices = np.arange(j+1, num_sentences)
        values = sentence_attn.numpy()[lower_indices, j]
        sentence_scores[j] += values.sum()

    top_indices = np.argsort(-sentence_scores)[:top_k]
    head_to_top_sentences[(layer_idx, head_idx)] = [sentences[idx] for idx in top_indices]


import openai
from dotenv import load_dotenv
import os
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
client = openai.OpenAI(api_key=api_key)


DAG_CATEGORIES = [
    "problem_setup: Parsing or rephrasing the problem (initial reading or comprehension).",
    "plan_generation: Stating or deciding on a plan of action (often meta-reasoning).",
    "fact_retrieval: Recalling facts, formulas, problem details (without immediate computation).",
    "active_computation: Performing algebra, calculations, manipulations toward the answer.",
    "result_consolidation: Aggregating intermediate results, summarizing, or preparing final answer.",
    "uncertainty_management: Expressing confusion, re-evaluating, proposing alternative plans (includes backtracking).",
    "final_answer_emission: Explicit statement of the final boxed answer or earlier chunks that contain the final answer.",
    "self_checking: Verifying previous steps, Pythagorean checking, re-confirmations.",
    "unknown: Use only if the chunk does not fit any of the above tags or is purely stylistic or semantic."
]

def categorize_head(layer_idx, head_idx, top_sentences):
    prompt = (
        f"Here are the top sentences that received the most attention from attention head (layer {layer_idx}, head {head_idx}):\n\n"
    )
    for i, sent in enumerate(top_sentences, 1):
        prompt += f"{i}. \"{sent}\"\n"
    prompt += (
        "\nBased on these sentences, which of the following categories best describes what this attention head is focusing on? "
        "You may select more than one if appropriate. Please respond with the category name(s) only, separated by commas if more than one.\n\n"
        "Categories:\n"
    )
    for cat in DAG_CATEGORIES:
        prompt += f"- {cat}\n"
    prompt += "\nCategory/ies:"

    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=50,
        temperature=0
    )
    return response.choices[0].message.content.strip()

head_to_category = {}
for (layer_idx, head_idx), top_sentences in head_to_top_sentences.items():
    category = categorize_head(layer_idx, head_idx, top_sentences)
    head_to_category[(layer_idx, head_idx)] = category
    print(f"Head (Layer {layer_idx}, Head {head_idx}): {category}")

Head (Layer 23, Head 4): problem_setup, active_computation, uncertainty_management
Head (Layer 9, Head 1): uncertainty_management, self_checking
Head (Layer 1, Head 2): - uncertainty_management
- problem_setup
Head (Layer 3, Head 12): problem_setup, active_computation, uncertainty_management
Head (Layer 19, Head 2): problem_setup, uncertainty_management, active_computation, fact_retrieval
Head (Layer 18, Head 11): problem_setup, plan_generation, uncertainty_management
Head (Layer 16, Head 24): - active_computation
- uncertainty_management
Head (Layer 4, Head 4): - uncertainty_management
- self_checking
Head (Layer 14, Head 26): - active_computation
- final_answer_emission
Head (Layer 11, Head 27): problem_setup, fact_retrieval, active_computation, uncertainty_management
Head (Layer 12, Head 2): problem_setup, active_computation, uncertainty_management
Head (Layer 5, Head 0): problem_setup, active_computation, uncertainty_management


In [None]:
# --- Rank sentences by average attention received from the top heads (NaN treated as 0, lower triangular only) ---
sentence_scores = np.zeros(num_sentences)
sentence_counts = np.zeros(num_sentences)

for kurt, layer_idx, head_idx in top_heads:
    layer_attn = attn_weights[layer_idx][0, head_idx]  # (seq, seq)
    sentence_attn = torch.zeros(num_sentences, num_sentences)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            if start_i >= end_i or start_j >= end_j:
                continue
            sentence_pair_attn = layer_attn[start_i:end_i, start_j:end_j]
            if sentence_pair_attn.numel() == 0:
                continue
            avg_attn = sentence_pair_attn.mean()
            sentence_attn[i, j] = avg_attn

    # Convert to numpy and replace NaN with 0
    sentence_attn_np = sentence_attn.numpy()
    sentence_attn_np = np.nan_to_num(sentence_attn_np, nan=0.0)

    # Only consider lower triangular part (excluding diagonal)
    for j in range(num_sentences):
        # Attention paid TO sentence j (column j), from i > j
        lower_indices = np.arange(j+1, num_sentences)
        values = sentence_attn_np[lower_indices, j]
        sentence_scores[j] += values.sum()
        sentence_counts[j] += (values != 0).sum()

# Avoid division by zero
sentence_avgs = np.divide(sentence_scores, sentence_counts, out=np.zeros_like(sentence_scores), where=sentence_counts!=0)

descending_ranking = np.argsort(-sentence_avgs)  # descending order

print("\nSentence ranking by average attention received from top 3 heads (lower triangular, NaN treated as 0):")
for rank, idx in enumerate(descending_ranking, 1):
    print(f"[{rank}] Sentence {idx} (avg score: {sentence_avgs[idx]:.4f}, count: {int(sentence_counts[idx])}): {sentences[idx]}")



Sentence ranking by average attention received from top 3 heads (lower triangular, NaN treated as 0):
[1] Sentence 82 (avg score: 0.0121, count: 48): Wait, perhaps I'm missing something.
[2] Sentence 84 (avg score: 0.0078, count: 24): Let me try to convert 419,430 to binary step by step.
[3] Sentence 83 (avg score: 0.0064, count: 36): Let me check the binary representation of 419,430.
[4] Sentence 80 (avg score: 0.0059, count: 72): But wait, according to the first method, it should be 20 bits.
[5] Sentence 81 (avg score: 0.0053, count: 60): So, which is correct?
[6] Sentence 85 (avg score: 0.0041, count: 12): First, I can divide 419,430 by 2 repeatedly and record the remainders.
[7] Sentence 63 (avg score: 0.0032, count: 265): Now, let's calculate log2(419,430).
[8] Sentence 66 (avg score: 0.0031, count: 230): But maybe I can approximate it.
[9] Sentence 55 (avg score: 0.0029, count: 351): Wait, maybe I made a mistake in calculating the decimal value.
[10] Sentence 64 (avg score: 0.00