In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import split_solution_into_chunks, get_chunk_ranges, get_chunk_token_ranges
import numpy as np
from scipy import stats
import circuitsvis as cv
from IPython.display import display, clear_output
import ipywidgets as widgets

# Model and device setup
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

# Load tokenizer and model
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE,
    device_map="auto",
)
model.eval()

Loading model and tokenizer...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 5120)
    (layers): ModuleList(
      (0-47): 48 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=True)
          (k_proj): Linear(in_features=5120, out_features=1024, bias=True)
          (v_proj): Linear(in_features=5120, out_features=1024, bias=True)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((5120,), eps=1e-05)
        (post_attention_layernorm): Qwen2RMSNorm((5120,), eps=1e-05)
      )
    )
    (norm): Qwen2RMSNorm((5120,), eps=1e-05)
    (rotary_emb

In [8]:
# Helper functions for attention analysis
def avg_matrix_by_chunk(matrix, chunk_token_ranges):
    """Average attention matrix values within each chunk."""
    n = len(chunk_token_ranges)
    avg_mat = np.zeros((n, n), dtype=np.float32)
    for i, (start_i, end_i) in enumerate(chunk_token_ranges):
        for j, (start_j, end_j) in enumerate(chunk_token_ranges):
            region = matrix[start_i:end_i, start_j:end_j]
            if region.size > 0:
                avg_mat[i, j] = region.mean().item()
    return avg_mat

def get_attn_vert_scores(avg_mat, proximity_ignore=10, drop_first=0):
    """Calculate vertical attention scores."""
    n = avg_mat.shape[0]
    vert_scores = []
    for i in range(n):
        vert_lines = avg_mat[i + proximity_ignore:, i]
        vert_score = np.nanmean(vert_lines) if len(vert_lines) > 0 else np.nan
        vert_scores.append(vert_score)
    vert_scores = np.array(vert_scores)
    if drop_first > 0:
        vert_scores[:drop_first] = np.nan
        vert_scores[-drop_first:] = np.nan
    return vert_scores

def generate_visualization(prompt, progress_callback=None):
    """
    Generate and visualize attention patterns for a given prompt.
    
    Args:
        prompt (str): The input prompt to analyze
        progress_callback (callable): Optional callback function to update progress
    """
    if progress_callback:
        progress_callback(0.1, "Tokenizing input...")
    
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    if progress_callback:
        progress_callback(0.15, "Preparing to generate solution...")
    
    # Reduce max tokens since we don't need that many for math problems
    max_tokens = 256
    
    # Generate solution with streaming progress updates
    with torch.no_grad():
        # Start generation message
        if progress_callback:
            progress_callback(0.2, "Generating solution [■□□□□]")
        
        generated_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_tokens,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            do_sample=True,
            temperature=0.9,
            top_p=0.95,
        ).sequences
        
        # Update progress after generation
        if progress_callback:
            progress_callback(0.25, "Generation complete!")

    generated_ids = generated_ids[0]  # Remove batch dim if present
    
    if progress_callback:
        progress_callback(0.3, "Processing text...")
        
    # Decode the generated text
    text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    
    # Split into sentences/chunks
    sentences = split_solution_into_chunks(text)
    
    if progress_callback:
        progress_callback(0.4, "Analyzing sentence structure...")
        
    # Get character and token ranges
    chunk_char_ranges = get_chunk_ranges(text, sentences)
    chunk_token_ranges = get_chunk_token_ranges(text, chunk_char_ranges, tokenizer)
    
    if progress_callback:
        progress_callback(0.5, "Computing attention weights...")
        
    # Get attention weights
    full_attention_mask = torch.ones((1, generated_ids.shape[0]), device=model.device)
    with torch.no_grad():
        outputs = model(
            generated_ids.unsqueeze(0),
            attention_mask=full_attention_mask,
            output_attentions=True,
            return_dict=True
        )
        attn_weights = outputs.attentions
        
    if progress_callback:
        progress_callback(0.7, "Calculating attention patterns...")
        
    # Calculate vertical scores for attention patterns
    num_layers = 3
    attn_shape = attn_weights[0].shape  # (batch, num_heads, seq, seq)
    num_heads = attn_shape[1]  # Get number of attention heads
    vert_scores_list = []
    for layer in range(1, num_layers):
        for head in range(num_heads):
            layer_attn = attn_weights[layer][0, head].cpu().numpy()
            avg_mat = avg_matrix_by_chunk(layer_attn, chunk_token_ranges)
            vert_scores = get_attn_vert_scores(avg_mat, proximity_ignore=4, drop_first=1)
            score = np.nanmax(vert_scores)
            vert_scores_list.append((score, layer, head))
    
    # Sort and get top heads
    vert_scores_list.sort(key=lambda x: x[0], reverse=True)
    top_heads = vert_scores_list[:12]
    
    if progress_callback:
        progress_callback(0.9, "Preparing visualizations...")
    
    # Return all necessary data for visualization
    return {
        'text': text,
        'sentences': sentences,
        'generated_ids': generated_ids,
        'attn_weights': attn_weights,
        'top_heads': top_heads[:6],  # Return top 6 for visualization
        'chunk_token_ranges': chunk_token_ranges  # Add chunk_token_ranges to results
    }

def display_visualization_results(results):
    """
    Display the visualization results.
    
    Args:
        results (dict): Dictionary containing visualization data
    """
    from IPython.display import display
    import torch
    
    vis_mats = []
    head_names = []
    
    # Process each attention head
    for score, layer, head in results['top_heads']:
        # Get attention matrix for this head
        layer_attn = results['attn_weights'][layer][0, head]  # (seq, seq)
        
        # Create sentence-level attention matrix
        sentence_attn = torch.zeros(len(results['sentences']), len(results['sentences']))
        for i, (start_i, end_i) in enumerate(results['chunk_token_ranges']):
            for j, (start_j, end_j) in enumerate(results['chunk_token_ranges']):
                if start_i >= end_i or start_j >= end_j:
                    continue
                sentence_pair_attn = layer_attn[start_i:end_i, start_j:end_j]
                if sentence_pair_attn.numel() == 0:
                    continue
                avg_attn = sentence_pair_attn.mean()
                sentence_attn[i, j] = avg_attn
        
        # Scale attention values
        blown_up_attn = 10 * sentence_attn / sentence_attn.max()
        
        # Add to visualization lists
        vis_mats.append(blown_up_attn.detach().cpu())
        head_names.append(f"L{layer}-H{head} ({score:.3f})")
    
    # Stack matrices and display
    heads_tensor = torch.stack(vis_mats)
    
    # Display using attention_heads
    display(cv.attention.attention_heads(
        attention=heads_tensor.numpy(),
        tokens=results['sentences'],
        attention_head_names=head_names,
        mask_upper_tri=False
    ))


In [12]:
# Check CircuitVis installation
import circuitsvis as cv
print("CircuitVis version:", cv.__version__)
print("\nAvailable modules in cv:", dir(cv))
print("\nAvailable functions in cv.attention:", dir(cv.attention))


CircuitVis version: 1.43.3

Available modules in cv: ['__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '__version__', 'activations', 'attention', 'circuitsvis', 'examples', 'logits', 'tokens', 'topk_samples', 'topk_tokens', 'utils', 'version']

Available functions in cv.attention: ['List', 'Optional', 'RenderedHTML', 'Union', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', 'attention_heads', 'attention_pattern', 'attention_patterns', 'np', 'render', 'torch']


In [10]:
# Test CircuitVis visualization directly (outside widget)
prompt = "When the base-16 number 66666 is written in base 2, how many base-2 digits (bits) does it have?"

# Generate results
results = generate_visualization(prompt, progress_callback=None)

# Display visualizations directly
from IPython.display import display

for score, layer, head in results['top_heads']:
    print(f"Layer {layer}, Head {head}, Score: {score:.3f}")
    
    attn_mat = results['attn_weights'][layer][0, head].cpu().numpy()
    scaled_attn = 10 * attn_mat / attn_mat.max()
    
    # Create and display attention visualization
    display(cv.attention.attention_patterns(
        tokens=tokenizer.convert_ids_to_tokens(results['generated_ids'].cpu()),
        attention=scaled_attn
    ))


Layer 1, Head 15, Score: 0.009


Layer 1, Head 5, Score: 0.009


Layer 2, Head 10, Score: 0.009


Layer 2, Head 14, Score: 0.009


Layer 2, Head 26, Score: 0.009


Layer 2, Head 27, Score: 0.008


In [9]:
# Create the input widgets
text_input = widgets.Textarea(
    value='When the base-16 number 66666 is written in base 2, how many base-2 digits (bits) does it have?',
    description='Problem:',
    layout=widgets.Layout(width='800px', height='100px')
)

analyze_button = widgets.Button(
    description='Analyze Problem',
    button_style='primary',
    layout=widgets.Layout(width='200px')
)

# Create progress bar and status
progress_bar = widgets.FloatProgress(
    value=0,
    min=0,
    max=1.0,
    description='Progress:',
    bar_style='success',  # This makes it green
    style={'bar_color': '#4CAF50'},  # A nice material design green
    layout=widgets.Layout(width='600px')
)

# Create a more detailed status label with HTML formatting
status_label = widgets.HTML(
    value='<div style="padding: 5px; font-family: monospace;">Ready</div>'
)

# Create output widget to capture the visualization with larger height
output = widgets.Output(
    layout=widgets.Layout(
        width='100%',
        height='800px',  # Increased height to accommodate visualizations
        overflow='auto'  # Add scrolling if content is too long
    )
)

def update_progress(value, message):
    """Update progress bar and status message"""
    progress_bar.value = value
    # Format the message with HTML styling
    status_html = f'<div style="padding: 5px; font-family: monospace;">{message}</div>'
    status_label.value = status_html

def on_button_click(b):
    # Disable button while processing
    analyze_button.disabled = True
    
    with output:
        clear_output(wait=True)  # Wait for new output before clearing
        try:
            # Generate visualization with progress updates
            results = generate_visualization(text_input.value, update_progress)
            
            # Update progress to 100%
            update_progress(1.0, "Completed!")
            
            # Display results in the widget context
            display_visualization_results(results)
            
        except Exception as e:
            from IPython.display import HTML
            display(HTML(f"<div style='color: red; padding: 10px;'>An error occurred: {str(e)}</div>"))
            update_progress(0, "Error occurred")
        finally:
            # Re-enable button
            analyze_button.disabled = False

# Connect button click to handler
analyze_button.on_click(on_button_click)

# Create widget layout
controls = widgets.VBox([
    text_input,
    widgets.HBox([analyze_button, progress_bar, status_label]),
    output
])

# Display widgets
display(controls)


VBox(children=(Textarea(value='When the base-16 number 66666 is written in base 2, how many base-2 digits (bit…