In [13]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys
sys.path.append("../")
from shared_utils.generate import format_conversation, transform_conversations
from early_exit.util import module_name_is_layer_base
import numpy as np
import pandas as pd
from shared_utils.load import get_model, get_tokenizer, configs_from_yaml
import random

# Hard-coded evaluation set of 30 reasoning-heavy prompts spanning probability,
# algorithms, physics, logic/set theory, systems design and philosophy.
# Every prompt is a single user turn; each entry must stay unique.
thinking_prompts = [
    "Explain why the Monty Hall problem solution is counterintuitive.",
    "What would happen if gravity suddenly became twice as strong?",
    "Design an algorithm to detect cycles in a linked list.",
    "Why does hot water sometimes freeze faster than cold water?",
    "Explain the grandfather paradox in time travel.",
    "How would you implement a LRU cache with O(1) operations?",
    "What are the implications of Gödel's incompleteness theorems?",
    "Derive the formula for the area of a circle from first principles.",
    "Explain why correlation does not imply causation with examples.",
    "How does quantum entanglement challenge classical physics?",
    "Design a distributed system for real-time collaborative editing.",
    "What is the halting problem and why is it undecidable?",
    "Explain the prisoner's dilemma and its real-world applications.",
    "How would you detect if a binary tree is balanced?",
    "What causes the Dunning-Kruger effect psychologically?",
    "Derive Bayes' theorem and explain its significance.",
    "How does TCP ensure reliable data transmission?",
    "Explain the concept of emergence in complex systems.",
    "What is the traveling salesman problem and why is it NP-hard?",
    "How does gradient descent find local minima in neural networks?",
    "Explain the twin paradox in special relativity.",
    "Design a hash table that handles collisions efficiently.",
    "What is Russell's paradox and how does it affect set theory?",
    "How would you implement mutex locks in an operating system?",
    "Explain the concept of computational complexity with examples.",
    "What is the Chinese room argument about artificial intelligence?",
    "How does dynamic programming differ from divide and conquer?",
    "Explain why P vs NP is such an important problem.",
    "What are the philosophical implications of the ship of Theseus?",
    "How would you design a garbage collector for a programming language?",
]

# --- Model / device selection -------------------------------------------
# Run on the GPU whenever CUDA is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
print(f"Loading model: {model_name}")
print(f"Device: {device}")

# --- Tokenizer ------------------------------------------------------------
# Some checkpoints ship without a pad token; reuse EOS so padding never fails.
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# --- Model ----------------------------------------------------------------
# On CUDA, device_map="auto" lets transformers decide weight placement; on CPU
# no device map is passed and the model stays on the host.
# NOTE(review): trust_remote_code=True allows executing code from the model
# repository; confirm it is actually required for this checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=None if device != 'cuda' else "auto",
    trust_remote_code=True,
)

# --- Early-exit layers ----------------------------------------------------
# Every module whose name module_name_is_layer_base() accepts contributes the
# integer after its final '.' (assumed to be the layer number — TODO confirm
# against early_exit.util), in named_modules() order.
early_exit_layer_idxs = torch.tensor(
    [
        int(module_path.rsplit('.', 1)[-1])
        for module_path, _submodule in model.named_modules()
        if module_name_is_layer_base(module_path)
    ],
    dtype=torch.int32,
)
print(f"Early exit layer indices: {early_exit_layer_idxs}")

# Chat system prompt shared by every generated conversation.
system_prompt = "You are a helpful assistant that thinks step by step."

# Scale factor applied to each per-layer divergence before it is squashed
# into an exit probability.
KL_FACTOR = 1

# Generation settings come from the project YAML (built with the tokenizer's
# EOS id); this analysis caps each completion at 100 new tokens.
model_config_path = "../config_deepseek.yaml"
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)
config['generation']['max_new_tokens'] = 100

# Collect one record per sampled token, across all prompts.
all_token_data = []

# Loop invariants, hoisted out of the per-token loop.
eps = 1e-16  # keeps log() finite when an early-exit probability is exactly 0
exit_layers = [int(layer) for layer in early_exit_layer_idxs.tolist()]
max_new_tokens = config['generation']['max_new_tokens']
eos_token_id = config['generation']['eos_token_id']

# For each prompt, sample up to max_new_tokens tokens one at a time. At every
# step, stochastically decide whether to sample from an early-exit head or from
# the final layer, and record per-token diagnostics.
for prompt_idx, prompt in enumerate(thinking_prompts):
    print(f"\nProcessing prompt {prompt_idx + 1}/{len(thinking_prompts)}: {prompt[:50]}...")

    # Format prompt with the shared conversation helpers.
    pre_transformed_conversation = format_conversation(user_prompts=[prompt], system_prompt=system_prompt)
    formatted_prompt = transform_conversations(pre_transformed_conversation, "")[0]

    # Tokenize a single prompt, so every tensor below has batch size 1.
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
    input_ids = inputs.input_ids
    prompt_length = input_ids.shape[1]

    # Generation state
    current_input = input_ids.clone()
    generated_tokens_manual = []  # sampled token ids, excluding a terminating EOS
    chosen_exit_layers = []       # exit layer used at each step (-1 = final layer)

    with torch.no_grad():
        for step in range(max_new_tokens):
            # Full forward pass over the whole sequence every step. The KV cache
            # is never fed back (no past_key_values), so don't build one.
            outputs = model(current_input, use_cache=False, output_hidden_states=True)
            logits = outputs.logits[:, -1, :]  # (1, vocab)

            # Last-position hidden state of each exit layer -> (1, n_exits, hidden).
            # Gathers only the needed slices instead of stacking every layer's
            # full-sequence activations.
            # NOTE(review): in Hugging Face models hidden_states[0] is the embedding
            # output, so hidden_states[i] is the *input* to decoder layer i, and only
            # the last entry has the final norm applied before lm_head. Confirm this
            # indexing matches how early_exit.util defines the exit points.
            exit_hidden_states = torch.stack(
                [outputs.hidden_states[layer][:, -1, :] for layer in exit_layers],
                dim=1,
            )
            exit_predictions = model.lm_head(exit_hidden_states)  # (1, n_exits, vocab)

            # Compare each exit head's distribution against the final layer's.
            final_predictions = torch.softmax(logits, dim=-1)
            teacher_expanded = final_predictions.unsqueeze(1)  # broadcast over exits
            early_output_probs = torch.softmax(exit_predictions, dim=-1)
            # NOTE(review): -sum(p * log q) is the cross-entropy H(p, q) =
            # H(p) + KL(p || q), not KL(p || q) itself. The name `kl_div` and the
            # `kl_layer_*` CSV columns are kept for compatibility with earlier runs.
            kl_div = -(teacher_expanded * (early_output_probs + eps).log()).sum(-1)  # (1, n_exits)

            # Per-exit stopping probability: 1 - (2 * sigmoid(KL_FACTOR * d) - 1).
            # Smaller divergence d -> probability closer to 1.
            sigmoid_kls = torch.sigmoid(KL_FACTOR * kl_div)
            sigmoid_kls = 2.0 * sigmoid_kls - 1.0
            sigmoid_kls = 1.0 - sigmoid_kls
            exit_probs = sigmoid_kls[0].tolist()  # one host transfer instead of one per exit

            # Walk the exits in order; the first successful coin flip wins,
            # otherwise fall through to the final layer's distribution.
            predictions = final_predictions
            chosen_exit_layer = -1
            for qdx, (exit_layer, exit_prob) in enumerate(zip(exit_layers, exit_probs)):
                if random.random() < exit_prob:
                    predictions = early_output_probs[:, qdx]
                    chosen_exit_layer = exit_layer
                    break
            chosen_exit_layers.append(chosen_exit_layer)

            # Sample the next token from whichever distribution was chosen.
            next_token = torch.multinomial(predictions, 1)  # (1, 1)
            token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)

            # Probability that some exit fires: 1 - prod(1 - p_exit).
            prob_reach_final = 1.0
            for exit_prob in exit_probs:
                prob_reach_final *= 1 - exit_prob
            prob_exit_early = 1.0 - prob_reach_final

            # Store data
            token_data_entry = {
                'prompt_idx': prompt_idx,
                'prompt': prompt[:50] + '...',  # Truncate for readability
                'step': step,
                'token': token_text,
                'exit_layer': chosen_exit_layer,
                'did_exit_early': chosen_exit_layer != -1,
                'prob_exit_early': prob_exit_early,
            }
            # One divergence column per exit layer.
            for exit_layer, layer_kl in zip(exit_layers, kl_div[0].tolist()):
                token_data_entry[f'kl_layer_{exit_layer}'] = layer_kl

            all_token_data.append(token_data_entry)

            # Stop at EOS (the EOS step itself is still recorded above).
            if next_token.item() == eos_token_id:
                break

            # Add token to sequence
            current_input = torch.cat([current_input, next_token], dim=1)
            generated_tokens_manual.append(next_token.item())

    # Counts every sampled token, including a terminating EOS. This is also safe
    # when max_new_tokens is 0, where the previous `step + 1` raised NameError.
    print(f"Generated {len(chosen_exit_layers)} tokens")

# Persist one row per sampled token, then print a short run summary.
results_csv_path = 'early_exit_analysis_30_prompts.csv'
df = pd.DataFrame(all_token_data)
df.to_csv(results_csv_path, index=False)

total_tokens = len(df)
mean_tokens_per_prompt = total_tokens / len(thinking_prompts)
print(f"\nTotal tokens generated: {total_tokens}")
print(f"Average tokens per prompt: {mean_tokens_per_prompt:.1f}")
print(f"\nData saved to: {results_csv_path}")

# How often each exit layer was chosen (-1 = final layer), ordered by layer.
print("\nExit layer distribution:")
exit_layer_counts = df['exit_layer'].value_counts().sort_index()
print(exit_layer_counts)
early_exit_rate = df['did_exit_early'].mean()
print(f"\nEarly exit rate: {early_exit_rate:.2%}")

Loading model: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
Device: cuda
Early exit layer indices: tensor([ 0,  5, 10, 15, 20, 25], dtype=torch.int32)

Processing prompt 1/30: Explain why the Monty Hall problem solution is cou...
transform_conversations currently only for Deepseek models!
Generated 100 tokens

Processing prompt 2/30: What would happen if gravity suddenly became twice...
transform_conversations currently only for Deepseek models!
Generated 100 tokens

Processing prompt 3/30: Design an algorithm to detect cycles in a linked l...
transform_conversations currently only for Deepseek models!
Generated 100 tokens

Processing prompt 4/30: Why does hot water sometimes freeze faster than co...
transform_conversations currently only for Deepseek models!
Generated 100 tokens

Processing prompt 5/30: Explain the grandfather paradox in time travel....
transform_conversations currently only for Deepseek models!
Generated 100 tokens

Processing prompt 6/30: How would you implement a LRU 