In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import split_solution_into_chunks, get_chunk_ranges, get_chunk_token_ranges
import numpy as np
from scipy import stats
import circuitsvis as cv
from IPython.display import display, clear_output
import ipywidgets as widgets

# Model and device setup
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# Load tokenizer and model
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map="auto",
)
model.eval()

Loading model and tokenizer...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 5120)
    (layers): ModuleList(
      (0-47): 48 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=True)
          (k_proj): Linear(in_features=5120, out_features=1024, bias=True)
          (v_proj): Linear(in_features=5120, out_features=1024, bias=True)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((5120,), eps=1e-05)
        (post_attention_layernorm): Qwen2RMSNorm((5120,), eps=1e-05)
      )
    )
    (norm): Qwen2RMSNorm((5120,), eps=1e-05)
    (rotary_emb

In [16]:
def avg_matrix_by_chunk(matrix, chunk_token_ranges):
    n = len(chunk_token_ranges)
    avg_mat = np.zeros((n, n), dtype=np.float32)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            region = matrix[start_i:end_i, start_j:end_j]
            if region.size > 0:
                avg_mat[i, j] = region.mean().item()
    return avg_mat

def get_attn_vert_scores(avg_mat, proximity_ignore=4, drop_first=0):
    n = avg_mat.shape[0]
    vert_scores = []
    for i in range(n):
        vert_lines = avg_mat[i + proximity_ignore :, i]
        vert_score = np.nanmean(vert_lines) if len(vert_lines) > 0 else np.nan
        vert_scores.append(vert_score)
    vert_scores = np.array(vert_scores)
    if drop_first > 0:
        vert_scores[:drop_first] = np.nan
        vert_scores[-drop_first:] = np.nan
    return vert_scores

def extract_final_answer(text):
    """Extract the final answer with context from the text and format it with line breaks."""
    # Split into sentences and reverse to start from the end
    sentences = text.split('\n')
    sentences = [s.strip() for s in sentences if s.strip()]
    
    # First try to find conclusion markers
    conclusion_markers = ["answer is", "therefore", "thus", "so,", "hence", "finally", "in conclusion"]
    
    def format_answer(answer_text):
        """Format the answer text with line breaks for better readability."""
        # Split on common conjunction words or punctuation
        split_points = [". ", ", ", " and "]
        formatted = answer_text
        for split_point in split_points:
            parts = formatted.split(split_point)
            formatted = "\n".join(part.strip() for part in parts)
        return formatted
    
    # Try to find a conclusion sentence and include context
    for i, sentence in enumerate(reversed(sentences)):
        if any(marker in sentence.lower() for marker in conclusion_markers):
            # Get the previous sentence for context if available
            if i < len(sentences) - 1:
                context = sentences[-(i+2)].strip()
                if context and not any(marker in context.lower() for marker in conclusion_markers):
                    return format_answer(f"{context} {sentence}")
            return format_answer(sentence)
    
    # If no conclusion marker found, try to find a group of related sentences at the end
    for i in range(len(sentences)-1, max(-1, len(sentences)-4), -1):
        if sentences[i]:
            # Look for 2-3 sentence conclusion
            if i > 0 and not any(marker in sentences[i-1].lower() for marker in ["let's", "first", "step", "now"]):
                return format_answer(f"{sentences[i-1]} {sentences[i]}")
            return format_answer(sentences[i])
    
    return "Could not find a clear final answer."

def generate_visualization(problem):
    # Clear previous outputs and suppress warnings
    clear_output(wait=True)
    import warnings
    warnings.filterwarnings('ignore')
    
    # Create and display progress bar
    progress = widgets.IntProgress(
        value=0,
        min=0,
        max=100,
        description='Analyzing:',
        bar_style='info',
        orientation='horizontal',
        layout=widgets.Layout(width='50%')
    )
    display(progress)
    print(f"Problem: {problem}\n")
    
    progress.value = 10  # Starting analysis
    # Add prompt for final answer
    full_prompt = f"{problem}\n\nPlease solve this step by step and end with a short, concise final answer (1-2 sentences max) starting with 'Therefore, the final answer is' or 'In conclusion, the answer is'."
    
    # Tokenize input
    inputs = tokenizer(full_prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Generate a chain-of-thought solution
    progress.value = 20  # Generating solution
    with torch.no_grad():
        generated_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=1024,  # Reduced from 1024 to 512 - should be enough for most problems
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            do_sample=True,
            temperature=0.9,
            top_p=0.95,
            early_stopping=True,  # Stop when EOS token is generated
        ).sequences

    generated_ids = generated_ids[0]  # Remove batch dim if present

    # Decode the generated text
    text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    # print("\nGenerated CoT solution:\n", text)
    
    # Extract and display the final answer
    final_answer = extract_final_answer(text)
    print("\nFinal Answer:", final_answer)

        # Split into sentences/chunks
    sentences = split_solution_into_chunks(text)
    # print("\nSentences:")
    # for i, s in enumerate(sentences):
    #     print(f"[{i}] {s}")

    # Get character and token ranges for each chunk
    chunk_char_ranges = get_chunk_ranges(text, sentences)
    chunk_token_ranges = get_chunk_token_ranges(text, chunk_char_ranges, tokenizer)

    num_sentences = len(sentences)
    # Run model again to get attention weights
    progress.value = 40  # Getting attention weights
    full_attention_mask = torch.ones((1, generated_ids.shape[0]), device=model.device)
    
    all_attn_weights = []
    with torch.no_grad():
        outputs = model(
            generated_ids.unsqueeze(0),
            attention_mask=full_attention_mask,
            output_attentions=True,
            return_dict=True
        )
        all_attn_weights = outputs.attentions
    
    attn_weights = all_attn_weights
    progress.value = 60  # Starting kurtosis calculation
    # Kurtosis calculation
    attn_shape = attn_weights[0].shape
    num_layers = len(attn_weights)
    num_heads = attn_shape[1]
    kurtosis_list = []

    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            layer_attn = attn_weights[layer_idx][0, head_idx].cpu().numpy()
            avg_mat = avg_matrix_by_chunk(layer_attn, chunk_token_ranges)
            vert_scores = get_attn_vert_scores(avg_mat, proximity_ignore=4, drop_first=0)
            kurt = stats.kurtosis(vert_scores, fisher=True, bias=True, nan_policy="omit")
            kurtosis_list.append((kurt, layer_idx, head_idx))

    # Exclude layer 0 from kurtosis analysis
    kurtosis_list = [entry for entry in kurtosis_list if entry[1] != 0]

    # Sort by kurtosis descending and take top 3
    kurtosis_list.sort(reverse=True, key=lambda x: x[0])
    top_heads = kurtosis_list[:3]
    
    # Prepare visualization
    vis_mats = []
    head_names = []

    print("\nTop 3 heads by kurtosis (repo-style, excluding layer 0):")
    for rank, (kurt, layer_idx, head_idx) in enumerate(top_heads, 1):
        print(f"[{rank}] Layer {layer_idx}, Head {head_idx}, Kurtosis: {kurt}")
        layer_attn = attn_weights[layer_idx][0, head_idx]
        sentence_attn = torch.zeros(num_sentences, num_sentences)
        
        for i, (start_i, end_i) in enumerate(chunk_token_ranges):
            for j, (start_j, end_j) in enumerate(chunk_token_ranges):
                if start_i >= end_i or start_j >= end_j:
                    continue
                sentence_pair_attn = layer_attn[start_i:end_i, start_j:end_j]
                if sentence_pair_attn.numel() == 0:
                    continue
                avg_attn = sentence_pair_attn.mean()
                sentence_attn[i, j] = avg_attn

        blown_up_attn = sentence_attn * 10000
        vis_mats.append(blown_up_attn.detach().cpu())
        head_names.append(f"L{layer_idx}-H{head_idx}")

    # Create and display visualization
    progress.value = 90  # Creating visualization
    heads_tensor = torch.stack(vis_mats)
    display(
        cv.attention.attention_heads(
            attention=heads_tensor.numpy(),
            tokens=sentences,
            attention_head_names=head_names,
            mask_upper_tri=False
        )
    )
    
    # Complete
    progress.value = 100
    progress.bar_style = 'success'

In [17]:
# Create the input widgets
text_input = widgets.Textarea(
    value='When the base-16 number 66666 is written in base 2, how many base-2 digits (bits) does it have?',
    description='Problem:',
    layout=widgets.Layout(width='800px', height='100px')
)

analyze_button = widgets.Button(
    description='Analyze Problem',
    button_style='primary',
    layout=widgets.Layout(width='200px')
)

# Create output widget to capture the visualization
output = widgets.Output()

def on_button_click(b):
    with output:
        generate_visualization(text_input.value)

analyze_button.on_click(on_button_click)

# Display the interface
display(text_input)
display(analyze_button)
display(output)


Textarea(value='When the base-16 number 66666 is written in base 2, how many base-2 digits (bits) does it have…

Button(button_style='primary', description='Analyze Problem', layout=Layout(width='200px'), style=ButtonStyle(…

Output()

In [13]:
vis_mats   = []   # a list of (num_sentences × num_sentences) tensors
head_names = []

print("\nTop 3 heads by kurtosis (repo-style, excluding layer 0):")
for rank, (kurt, layer_idx, head_idx) in enumerate(top_heads, 1):
    print(f"[{rank}] Layer {layer_idx}, Head {head_idx}, Kurtosis: {kurt}")
    # Compute sentence-level attention matrix for this head
    layer_attn = attn_weights[layer_idx][0, head_idx]  # (seq, seq)
    sentence_attn = torch.zeros(num_sentences, num_sentences)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            if start_i >= end_i or start_j >= end_j:
                continue
            sentence_pair_attn = layer_attn[start_i:end_i, start_j:end_j]
            if sentence_pair_attn.numel() == 0:
                continue
            avg_attn = sentence_pair_attn.mean()
            sentence_attn[i, j] = avg_attn
    print(f"Sentence-level attention matrix for layer {layer_idx}, head {head_idx} (shape: {sentence_attn.shape}):")
    print(sentence_attn[:5, :5])

    blown_up_attn = sentence_attn * 10000

    vis_mats.append(blown_up_attn.detach().cpu())
    head_names.append(f"L{layer_idx}-H{head_idx}")


Top 3 heads by kurtosis (repo-style, excluding layer 0):
[1] Layer 13, Head 27, Kurtosis: 35.02564079361015
Sentence-level attention matrix for layer 13, head 27 (shape: torch.Size([44, 44])):
tensor([[7.6904e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [7.6904e-02, 5.4240e-06, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [7.6904e-02, 3.8147e-06, 2.5034e-06, 0.0000e+00, 0.0000e+00],
        [7.6904e-02, 6.5565e-07, 7.8082e-06, 3.6359e-06, 0.0000e+00],
        [7.6904e-02, 1.7881e-07, 2.2650e-06, 3.6359e-06, 1.8477e-06]])
[2] Layer 18, Head 13, Kurtosis: 35.02564065989991
Sentence-level attention matrix for layer 18, head 13 (shape: torch.Size([44, 44])):
tensor([[7.6904e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [7.6904e-02, 1.2577e-05, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [7.6904e-02, 8.8215e-06, 3.1054e-05, 0.0000e+00, 0.0000e+00],
        [7.6904e-02, 6.7949e-06, 3.1590e-06, 5.5432e-06, 0.0000e+00],
        [7.6904e-02, 2.8610e-06, 2.8610

In [14]:
heads_tensor = torch.stack(vis_mats)               # (k, S, S)

display(
    cv.attention.attention_heads(
        attention           = heads_tensor.numpy(),   # NumPy or list is fine
        tokens              = sentences,              # axis labels
        attention_head_names= head_names,             # hover label
        mask_upper_tri      = False                   # we aggregated, so not causal
    )
)