In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import split_solution_into_chunks, get_chunk_ranges, get_chunk_token_ranges
import numpy as np
from scipy import stats
import circuitsvis as cv
from IPython.display import display

# Model and device setup
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision when a GPU is available, full precision on CPU.
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32


# Load tokenizer and model
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map="auto",
    # Later cells call the model with output_attentions=True. Fused attention
    # backends (SDPA / FlashAttention) do not materialise the per-head attention
    # matrix, and depending on the transformers version they either fall back
    # with a warning or return no attentions at all. Requesting the eager
    # implementation explicitly keeps the attention analysis working.
    attn_implementation="eager",
)
# Inference only: switch off dropout etc. Left as the last expression so the
# cell still displays the module tree.
model.eval()

Loading model and tokenizer...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 5120)
    (layers): ModuleList(
      (0-47): 48 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=True)
          (k_proj): Linear(in_features=5120, out_features=1024, bias=True)
          (v_proj): Linear(in_features=5120, out_features=1024, bias=True)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((5120,), eps=1e-05)
        (post_attention_layernorm): Qwen2RMSNorm((5120,), eps=1e-05)
      )
    )
    (norm): Qwen2RMSNorm((5120,), eps=1e-05)
    (rotary_emb

In [9]:
problem = "When the base-16 number 66666 is written in base 2, how many base-2 digits (bits) does it have?"
prompt = problem

# Generation below samples (do_sample=True), so seed the RNG first. This makes
# the sampled trace repeatable on the same hardware/software stack; bitwise
# reproducibility across GPUs or library versions is not guaranteed.
GENERATION_SEED = 0
torch.manual_seed(GENERATION_SEED)

# Tokenize input
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# Generate a chain-of-thought solution (repo-style settings)
with torch.no_grad():
    generated_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=1024,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        do_sample=True,  # repo style: sampling
        temperature=0.9,
        top_p=0.95,
    ).sequences

generated_ids = generated_ids[0]  # Remove batch dim if present

# Decode the generated text
text = tokenizer.decode(generated_ids, skip_special_tokens=True)
print("\nGenerated CoT solution:\n", text)

# Split into sentences/chunks
sentences = split_solution_into_chunks(text)
print("\nSentences:")
for i, s in enumerate(sentences):
    print(f"[{i}] {s}")

# Get character and token ranges for each chunk
# NOTE(review): these ranges are derived from the decoded `text` (special tokens
# stripped), while the attention matrices below are indexed by positions in
# `generated_ids`, which may still contain special tokens such as BOS/EOS.
# Confirm that get_chunk_token_ranges accounts for any resulting offset.
chunk_char_ranges = get_chunk_ranges(text, sentences)
chunk_token_ranges = get_chunk_token_ranges(text, chunk_char_ranges, tokenizer)

num_sentences = len(sentences)
# Run model again to get attention weights for the generated sequence.
# Single unpadded sequence, so every position is attended to; use an integer
# mask, matching what the tokenizer itself produces.
full_attention_mask = torch.ones(
    (1, generated_ids.shape[0]), dtype=torch.long, device=model.device
)
with torch.no_grad():
    outputs = model(
        generated_ids.unsqueeze(0),
        attention_mask=full_attention_mask,
        output_attentions=True,
        return_dict=True
    )
    attn_weights = outputs.attentions  # tuple: (num_layers, batch, num_heads, seq, seq)

# --- Kurtosis calculation (repo-style: vertical scores of chunk-averaged matrix) ---
def avg_matrix_by_chunk(matrix, chunk_token_ranges):
    """Pool a token-level matrix into a chunk-level matrix by block averaging.

    Args:
        matrix: 2-D array indexed as ``[query_token, key_token]``; it must
            support 2-D slicing and expose ``.size`` / ``.mean()`` the way a
            NumPy array does.
        chunk_token_ranges: sequence of ``(start, end)`` token index pairs,
            one per chunk, with ``end`` exclusive.

    Returns:
        A float32 NumPy array of shape ``(n_chunks, n_chunks)`` where entry
        ``[i, j]`` is the mean of ``matrix`` over the rows of chunk ``i`` and
        the columns of chunk ``j``. Entries whose block is empty stay 0.
    """
    num_chunks = len(chunk_token_ranges)
    pooled = np.zeros((num_chunks, num_chunks), dtype=np.float32)
    for row, (row_lo, row_hi) in enumerate(chunk_token_ranges):
        # Slice the query rows once; each key chunk then only narrows columns.
        row_band = matrix[row_lo:row_hi]
        for col, (col_lo, col_hi) in enumerate(chunk_token_ranges):
            block = row_band[:, col_lo:col_hi]
            if not block.size > 0:
                continue  # empty chunk on either axis: leave the zero in place
            pooled[row, col] = block.mean().item()
    return pooled

def get_attn_vert_scores(avg_mat, proximity_ignore=4, drop_first=0):
    """Score each column of a chunk-level matrix by its mean over later rows.

    For column ``i`` the score is the NaN-ignoring mean of
    ``avg_mat[i + proximity_ignore:, i]``, i.e. the rows at least
    ``proximity_ignore`` positions past ``i``. Columns with no such rows
    get NaN.

    Args:
        avg_mat: 2-D NumPy array; only ``shape[0]`` columns are scored.
        proximity_ignore: row offset from the diagonal at which averaging starts.
        drop_first: when positive, the first ``drop_first`` AND the last
            ``drop_first`` scores are overwritten with NaN.

    Returns:
        1-D NumPy array of length ``avg_mat.shape[0]``.
    """

    def _column_score(col):
        later_rows = avg_mat[col + proximity_ignore :, col]
        if len(later_rows) > 0:
            return np.nanmean(later_rows)
        return np.nan

    scores = np.array([_column_score(col) for col in range(avg_mat.shape[0])])
    if drop_first > 0:
        scores[:drop_first] = np.nan
        scores[-drop_first:] = np.nan
    return scores

attn_shape = attn_weights[0].shape  # (batch, num_heads, seq, seq)
num_layers = len(attn_weights)
num_heads = attn_shape[1]

TOP_K_HEADS = 3            # how many heads to keep from the kurtosis ranking
MANUAL_HEADS = [(36, 6)]   # extra (layer, head) pairs to always visualise

kurtosis_list = []  # List of (kurtosis, layer_idx, head_idx)
for layer_idx in range(num_layers):
    # One device-to-host copy per layer instead of one per head.
    layer_attn_all = attn_weights[layer_idx][0].cpu().numpy()  # (num_heads, seq, seq)
    for head_idx in range(num_heads):
        layer_attn = layer_attn_all[head_idx]  # (seq, seq)
        avg_mat = avg_matrix_by_chunk(layer_attn, chunk_token_ranges)
        vert_scores = get_attn_vert_scores(avg_mat, proximity_ignore=4, drop_first=0)
        kurt = float(stats.kurtosis(vert_scores, fisher=True, bias=True, nan_policy="omit"))
        kurtosis_list.append((kurt, layer_idx, head_idx))

# Keep a lookup over every head (layer 0 included) so manually selected heads
# are reported with their measured kurtosis rather than a placeholder value.
kurtosis_by_head = {(layer, head): kurt for kurt, layer, head in kurtosis_list}

# Exclude layer 0 from kurtosis analysis
kurtosis_list = [entry for entry in kurtosis_list if entry[1] != 0]

# Sort by kurtosis descending and take the top heads. NaN does not order
# consistently under comparison, so non-finite values are ranked last.
# NOTE(review): with drop_first=0 the first chunk's column is scored too. A head
# that attends almost only to that chunk gets the largest kurtosis a single
# spike can produce, so the ranking may be dominated by attention-sink heads;
# consider drop_first > 0.
kurtosis_list.sort(
    reverse=True,
    key=lambda x: x[0] if np.isfinite(x[0]) else -np.inf,
)
top_heads = kurtosis_list[:TOP_K_HEADS]

# Append the hand-picked heads, skipping duplicates and heads this model lacks.
selected = {(layer, head) for _, layer, head in top_heads}
for manual_layer, manual_head in MANUAL_HEADS:
    manual_key = (manual_layer, manual_head)
    if manual_key in selected:
        continue
    if manual_key not in kurtosis_by_head:
        print(f"Skipping manual head L{manual_layer}-H{manual_head}: not present in this model.")
        continue
    top_heads.append((kurtosis_by_head[manual_key], manual_layer, manual_head))
    selected.add(manual_key)


Generated CoT solution:
 When the base-16 number 66666 is written in base 2, how many base-2 digits (bits) does it have?  How can I compute this without converting to base 10?

Okay, so I have this problem where I need to find out how many bits the base-16 number 66666 has when converted to base 2. And I have to do this without converting it to base 10. Hmm, interesting. I remember that each hexadecimal digit corresponds to four binary digits, so maybe I can use that somehow.

First, let me recall what each hexadecimal digit represents in binary. Hexadecimal, or base-16, uses digits from 0 to 15, right? And each of those can be represented by four bits. For example, 0 is 0000, 1 is 0001, up to F, which is 1111 in binary. So, if each hex digit is four bits, then the total number of bits should be the number of hex digits multiplied by four. But wait, that's only if there are no leading zeros. But in the case of numbers, if the most significant digit is not zero, then the number of bits

In [12]:
vis_mats   = []   # a list of (num_sentences × num_sentences) tensors
head_names = []

# Sentence-averaged attention values are tiny (the printed samples are ~1e-2 and
# below), so they are multiplied by this factor before being handed to the
# visualisation. Display-only scaling.
VIS_SCALE = 1000

# `top_heads` may hold more than the kurtosis top-k (hand-picked heads are
# appended to it), so report the actual count instead of a fixed "3".
print(f"\nSelected heads: {len(top_heads)} (top by kurtosis plus manually added; layer 0 excluded from ranking):")
for rank, (kurt, layer_idx, head_idx) in enumerate(top_heads, 1):
    print(f"[{rank}] Layer {layer_idx}, Head {head_idx}, Kurtosis: {kurt}")
    # Compute sentence-level attention matrix for this head
    layer_attn = attn_weights[layer_idx][0, head_idx]  # (seq, seq)
    sentence_attn = torch.zeros(num_sentences, num_sentences)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            # Skip degenerate chunks; their cell stays 0.
            if start_i >= end_i or start_j >= end_j:
                continue
            sentence_pair_attn = layer_attn[start_i:end_i, start_j:end_j]
            if sentence_pair_attn.numel() == 0:
                continue
            # Mean attention from the tokens of sentence i to the tokens of sentence j.
            sentence_attn[i, j] = sentence_pair_attn.mean()
    print(f"Sentence-level attention matrix for layer {layer_idx}, head {head_idx} (shape: {sentence_attn.shape}):")
    print(sentence_attn[:5, :5])

    blown_up_attn = sentence_attn * VIS_SCALE

    vis_mats.append(blown_up_attn.detach().cpu())
    head_names.append(f"L{layer_idx}-H{head_idx}")


Top 3 heads by kurtosis (repo-style, excluding layer 0):
[1] Layer 18, Head 13, Kurtosis: 43.0212747156853
Sentence-level attention matrix for layer 18, head 13 (shape: torch.Size([52, 52])):
tensor([[2.9419e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.9419e-02, 5.1260e-06, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.9404e-02, 4.7088e-06, 2.8610e-06, 0.0000e+00, 0.0000e+00],
        [2.9404e-02, 3.1590e-06, 2.9206e-06, 1.4901e-06, 0.0000e+00],
        [2.9404e-02, 1.6689e-05, 9.1791e-06, 2.8014e-06, 5.1260e-06]])
[2] Layer 13, Head 27, Kurtosis: 43.02127394150359
Sentence-level attention matrix for layer 13, head 27 (shape: torch.Size([52, 52])):
tensor([[2.9419e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.9404e-02, 2.4796e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.9419e-02, 1.4901e-06, 7.1526e-07, 0.0000e+00, 0.0000e+00],
        [2.9404e-02, 2.8014e-06, 1.5497e-06, 2.5630e-06, 0.0000e+00],
        [2.9419e-02, 6.5565e-07, 3.5763e

In [13]:
# Stack the per-head sentence-level matrices into a single (k, S, S) tensor.
heads_tensor = torch.stack(vis_mats)

# Build the CircuitsVis widget: sentences label both axes, head names label
# each panel. The upper triangle is left unmasked because the matrices are
# sentence-level aggregates rather than a token-level causal pattern.
attention_view = cv.attention.attention_heads(
    attention=heads_tensor.numpy(),
    tokens=sentences,
    attention_head_names=head_names,
    mask_upper_tri=False,
)
display(attention_view)