In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys
sys.path.append("../")

from shared_utils.generate import format_conversation, transform_conversations
from early_exit.util import module_name_is_layer_base
import numpy as np

from shared_utils.load import get_model, get_tokenizer, configs_from_yaml
import random

## Display utils

In [2]:
from IPython.display import HTML, display
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
def _create_html_visualization(token_strings, exit_layers, all_layers, cmap, norm):
    """Build an HTML view in which each token is a span colored by its exit layer.

    Args:
        token_strings: Iterable of already-decoded token strings.
        exit_layers: Iterable of exit-layer identifiers, paired positionally
            with ``token_strings`` (pairs are formed with ``zip``, so any
            surplus entries on either side are ignored).
        all_layers: List of every layer identifier that can appear; the position
            of an identifier in this list selects its color. ``-1`` is labelled
            "Final Layer" in the legend.
        cmap: Callable mapping a normalized value to a color accepted by
            ``matplotlib.colors.to_hex``.
        norm: Callable mapping a position in ``all_layers`` to the value fed
            to ``cmap``.

    Returns:
        An ``IPython.display.HTML`` object containing a legend and the tokens.
    """
    # Local import keeps this helper self-contained within its notebook cell.
    from html import escape

    def color_for(position):
        # Hex color associated with a position in ``all_layers``.
        return mcolors.to_hex(cmap(norm(position)))

    parts = ["""
    <div style="font-family: 'Courier New', monospace; font-size: 14px; line-height: 1.6; padding: 20px;">
    <h3>Token Exit Layer Visualization</h3>
    <div style="margin-bottom: 20px;">
    """]

    # Legend: one swatch per known layer.
    parts.append("<div style='margin-bottom: 15px;'><strong>Legend:</strong> ")
    for i, layer in enumerate(all_layers):
        color = color_for(i)
        layer_name = f"Layer {layer}" if layer != -1 else "Final Layer"
        parts.append(f"<span style='background-color: {color}; padding: 2px 6px; margin: 2px; border-radius: 3px; color: white; font-weight: bold;'>{layer_name}</span> ")
    parts.append("</div>")

    # Token spans.
    parts.append("<div style='border: 1px solid #ccc; padding: 15px; border-radius: 5px; background-color: #f9f9f9;'>")

    for token, exit_layer in zip(token_strings, exit_layers):
        try:
            layer_idx = all_layers.index(exit_layer)
        except ValueError:
            # Unknown exit layer: reuse the color of the last entry (final layer).
            layer_idx = len(all_layers) - 1

        color = color_for(layer_idx)

        # Escape '&' as well as '<' and '>': model tokens are arbitrary text, and
        # an unescaped '&' would be parsed by the browser as the start of an
        # entity (e.g. the token "&lt;" would be shown as "<").
        display_token = escape(token, quote=False)

        parts.append(f"""<span style="background-color: {color}; padding: 2px 4px; margin: 1px; 
                           border-radius: 3px; color: white; font-weight: bold; 
                           display: inline-block;" title="Exit Layer: {exit_layer}">{display_token}</span>""")

    parts.append("</div></div></div>")

    return HTML("".join(parts))


def visualize_token_exits(tokens, exit_layers, tokenizer, early_exit_layer_idxs, 
                         display_mode='html', figsize=(15, 8), font_size=12):
    """
    Visualize tokens colored by their exit layer.
    
    Args:
        tokens: List of token IDs or token strings (may be empty)
        exit_layers: List of exit layer indices (-1 for final layer)
        tokenizer: Tokenizer used to decode tokens when they are integer IDs
        early_exit_layer_idxs: Available early exit layers, as a tensor/array
            (anything with ``.tolist()``) or a plain iterable of ints
        display_mode: Currently unused; the result is always an HTML object
        figsize: Currently unused (kept for interface compatibility)
        font_size: Currently unused (kept for interface compatibility)

    Returns:
        The HTML object produced by ``_create_html_visualization``.
    """
    
    # Decode only when the first element is an integer ID. The length check
    # avoids an IndexError when nothing was generated (e.g. EOS at step 0).
    if len(tokens) > 0 and isinstance(tokens[0], int):
        token_strings = [tokenizer.decode([token], skip_special_tokens=False) for token in tokens]
    else:
        token_strings = list(tokens)
    
    # Color slots: every early exit layer, then -1 for the final layer.
    # ``.tolist()`` works for tensors on any device and yields plain ints;
    # fall back to ``list`` for ordinary sequences.
    if hasattr(early_exit_layer_idxs, 'tolist'):
        layer_ids = early_exit_layer_idxs.tolist()
    else:
        layer_ids = list(early_exit_layer_idxs)
    all_layers = [int(layer_id) for layer_id in layer_ids] + [-1]
    
    # 'coolwarm_r' is the reversed coolwarm map, so slot 0 (earliest exit layer)
    # is drawn in red and the last slot (final layer) in blue.
    cmap = plt.colormaps.get_cmap('coolwarm_r')
    norm = mcolors.Normalize(vmin=0, vmax=len(all_layers)-1)
    
    return _create_html_visualization(token_strings, exit_layers, all_layers, cmap, norm)
 

## Analysis

In [3]:
# Model configuration: checkpoint to load and the device inputs are moved to.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f"Loading model: {model_name}")
print(f"Device: {device}")

# Tokenizer: reuse EOS as the padding token when the checkpoint defines none.
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Tokenizer loaded. Vocab size: {tokenizer.vocab_size}")
print(f"EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")

# Base model. torch_dtype is deliberately left unset; pass
# torch_dtype=torch.float16 here to load in half precision instead.
# device_map="auto" is requested only when a GPU is available.
model_load_kwargs = {
    "device_map": "auto" if device == 'cuda' else None,
    "trust_remote_code": True,
}
model = AutoModelForCausalLM.from_pretrained(model_name, **model_load_kwargs)

Loading model: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
Device: cuda
Tokenizer loaded. Vocab size: 151643
EOS token: <｜end▁of▁sentence｜> (ID: 151643)


In [4]:
# Conversation pieces: system instruction, the single user turn, and an
# (empty) prefill string handed to transform_conversations.
system_prompt = "You are a helpful programming tutor."
prompt = "Explain the concept of recursion in programming."
prefiller = ""

# Build the conversation structure, then turn it into model-ready text.
pre_transformed_conversation = format_conversation(
    user_prompts=[prompt],
    system_prompt=system_prompt,
)
formatted_prompts = transform_conversations(pre_transformed_conversation, prefiller)
formatted_prompt = formatted_prompts[0]  # only the first transformed entry is used

transform_conversations currently only for Deepseek models!


In [5]:
from early_exit.util import module_name_is_layer_base

# Collect the index of every module that module_name_is_layer_base accepts.
# The index is the last dotted component of the module name
# (e.g. "model.layers.0" -> 0).
early_exit_layer_idxs = torch.tensor(
    [
        int(module_name.rsplit('.', 1)[-1])
        for module_name, _ in model.named_modules()
        if module_name_is_layer_base(module_name)
    ],
    dtype=torch.int32,
)
# Only real layer indices are stored; the final layer has no entry here.
print(f"Early exit layer indices: {early_exit_layer_idxs}")
print(f"Total exitable layers: {len(early_exit_layer_idxs)}")


Early exit layer indices: tensor([ 0,  5, 10, 15, 20, 25], dtype=torch.int32)
Total exitable layers: 6


In [6]:
model_config_path = "../config_deepseek.yaml"                     # args.model_config_path

config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)
config['generation']['max_new_tokens'] = 100

# Seed both RNGs used below (random.random for the exit decision,
# torch.multinomial for token sampling) so a re-run reproduces the output.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
input_ids = inputs.input_ids
prompt_length = input_ids.shape[1]

KL_FACTOR = 1

# outputs.hidden_states starts with the embedding output, followed by one entry
# per decoder layer, so the output of layer i ("model.layers.i") sits at
# position i + 1. Indexing with i directly would read the *input* of layer i.
exit_hidden_state_positions = [int(layer_idx) + 1 for layer_idx in early_exit_layer_idxs]

current_input = input_ids.clone()
generated_tokens_manual = []
chosen_exit_layers = []
for step in range(config['generation']['max_new_tokens']):
    with torch.no_grad():
        # The whole sequence is re-fed every step and no cache is passed back
        # in, so building a KV cache would only waste memory.
        outputs = model(current_input, use_cache=False, output_hidden_states=True)
        logits = outputs.logits[:, -1, :]  # final-layer logits for the last position

        # Last-position hidden state of each exitable layer -> [batch, exits, hidden].
        # Only the needed slices are stacked (not every layer and position).
        exit_hidden_states = torch.stack(
            [outputs.hidden_states[pos][:, -1, :].to(logits.device)
             for pos in exit_hidden_state_positions],
            dim=1,
        )
        # NOTE(review): lm_head is applied to the raw intermediate states, without
        # any final normalization the model may apply before its head - confirm
        # this is the intended readout.
        exit_predictions = model.lm_head(exit_hidden_states)

        # 1. Divergence between the final-layer distribution (teacher) and each
        #    early-exit distribution.
        final_predictions = torch.softmax(logits, dim=-1)          # [batch, vocab]
        teacher_expanded = final_predictions.unsqueeze(1)           # [batch, 1, vocab]
        early_output_probs = torch.softmax(exit_predictions, dim=-1)  # [batch, exits, vocab]

        eps = 1e-16
        # This is the cross-entropy H(teacher, early), i.e. KL(teacher || early)
        # plus the teacher's entropy; the variable name is kept for continuity.
        # True KL would be:
        #   (teacher_expanded * ((teacher_expanded + eps) / (early_output_probs + eps)).log()).sum(-1)
        kl_div = - (teacher_expanded * (early_output_probs + eps).log()).sum(-1)  # [batch, exits]

        # 2. Map the (non-negative) divergence to an exit probability in (0, 1]:
        #    1 - (2 * sigmoid(KL_FACTOR * d) - 1), which is 1 at d == 0 and decays
        #    towards 0 as the early distribution drifts from the final one.
        sigmoid_kls = torch.sigmoid(KL_FACTOR * kl_div)
        sigmoid_kls = 2.0 * sigmoid_kls - 1.0
        sigmoid_kls = 1.0 - sigmoid_kls

        # 3. Walk the exit layers from earliest to latest and take the first one
        #    whose Bernoulli draw succeeds; otherwise fall through to the final layer.
        predictions = final_predictions
        chosen_exit_layer = -1
        for qdx, exit_layer in enumerate(early_exit_layer_idxs):
            rand_val = random.random()
            if rand_val < sigmoid_kls[0, qdx]:
                predictions = early_output_probs[:, qdx]
                chosen_exit_layer = exit_layer
                break

        # Sample next token from the chosen distribution
        next_token = torch.multinomial(predictions, 1)

        # Check for EOS
        if next_token.item() == config['generation']['eos_token_id']:
            print(f"EOS token encountered at step {step}")
            break

        # Record the exit layer only for tokens that are actually kept, so
        # chosen_exit_layers stays aligned one-to-one with generated_tokens_manual.
        chosen_exit_layers.append(int(chosen_exit_layer))

        # Add token to sequence
        current_input = torch.cat([current_input, next_token], dim=1)
        generated_tokens_manual.append(next_token.item())

manual_generated_text = tokenizer.decode(generated_tokens_manual, skip_special_tokens=True)

if generated_tokens_manual:
    html_viz = visualize_token_exits(
        tokens=generated_tokens_manual,
        exit_layers=chosen_exit_layers,
        tokenizer=tokenizer,
        early_exit_layer_idxs=early_exit_layer_idxs,
        display_mode='html'
    )

    # Display the HTML visualization
    display(html_viz)
else:
    # EOS was sampled on the very first step; there is nothing to color.
    print("No tokens were generated before EOS; nothing to visualize.")