In [9]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys
sys.path.append("../")

from shared_utils.generate import format_conversation, transform_conversations
from early_exit.util import module_name_is_layer_base
import numpy as np

import pandas as pd

# Initialize lists to store data
token_data = []

from shared_utils.load import get_model, get_tokenizer, configs_from_yaml
import random
# Model configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

print(f"Loading model: {model_name}")
print(f"Device: {device}")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Tokenizer loaded. Vocab size: {tokenizer.vocab_size}")
print(f"EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # torch_dtype=torch.float16,  # Use half precision for efficiency
    device_map="auto" if device == 'cuda' else None,
    trust_remote_code=True
)

prompt = "Explain the concept of recursion in programming."
system_prompt = "You are a helpful programming tutor."
prefiller = ""

pre_transformed_conversation = format_conversation(user_prompts = [prompt], system_prompt=system_prompt)
formatted_prompt = transform_conversations(pre_transformed_conversation, prefiller)[0]
from early_exit.util import module_name_is_layer_base
early_exit_layer_idxs = []
for name, module in model.named_modules():
    if module_name_is_layer_base(name):
        # Extract layer index from module name (e.g., "model.layers.0" -> 0)
        layer_idx = int(name.split('.')[-1])
        early_exit_layer_idxs.append(layer_idx)

early_exit_layer_idxs = torch.tensor(early_exit_layer_idxs, dtype = torch.int32)  # Add inf for final layer
print(f"Early exit layer indices: {early_exit_layer_idxs}")
print(f"Total exitable layers: {len(early_exit_layer_idxs)}")  # Subtract 1 for the inf


Loading model: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
Device: cuda
Tokenizer loaded. Vocab size: 151643
EOS token: <｜end▁of▁sentence｜> (ID: 151643)
transform_conversations currently only for Deepseek models!
Early exit layer indices: tensor([ 0,  5, 10, 15, 20, 25], dtype=torch.int32)
Total exitable layers: 6


In [14]:
model_config_path = "../config_deepseek.yaml"                     # args.model_config_path

config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)
config['generation']['max_new_tokens'] = 100

inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
input_ids = inputs.input_ids
prompt_length = input_ids.shape[1]

KL_FACTOR = 1
current_input = input_ids.clone()
generated_tokens_manual = []
chosen_exit_layers = []


for step in range(config['generation']['max_new_tokens']):
    with torch.no_grad():
        # Forward pass
        outputs = model(current_input, use_cache=True, output_hidden_states=True)
        # print(outputs.logits.shape)
        logits = outputs.logits[:, -1, :]  # Get logits for last token
        hidden_states = torch.stack(outputs.hidden_states)
        exit_hidden_states = hidden_states[early_exit_layer_idxs, :, -1, :].transpose(0,1)
        exit_predictions = model.lm_head(exit_hidden_states)
        # 1. Get KL divergence between early exit and final layers
        final_predictions = torch.softmax(logits, dim=-1)
        teacher_expanded = final_predictions.unsqueeze(1)  
        early_output_probs = torch.softmax(exit_predictions, dim=-1)
        # Sum over vocab -> [batch, exitable layers, sequence]
        # print(teacher_expanded.shape, early_output_probs.shape)
        eps = 1e-16
        # kl_div = (teacher_expanded * ((teacher_expanded + eps) / (early_output_probs + eps)).log()).sum(-1)
        kl_div = - (teacher_expanded * (early_output_probs + eps).log()).sum(-1)
        
        # 2. Scale KL divergencees by KL_FACTOR and pass through sigmoid (0-1)
        sigmoid_kls = torch.sigmoid(KL_FACTOR * kl_div)  # [batch, exitable layers, sequence]
        sigmoid_kls = 2.0 * sigmoid_kls - 1.0
        sigmoid_kls = 1.0 - sigmoid_kls
        
        predictions = final_predictions
        chosen_exit_layer = -1
        for qdx, exit_layer in enumerate(early_exit_layer_idxs):
            rand_val = random.random()
            if rand_val < sigmoid_kls[0, qdx]:
                predictions = early_output_probs[:, qdx]
                chosen_exit_layer = exit_layer
                break
        chosen_exit_layers.append(int(chosen_exit_layer))
        
        # Sample next token
        next_token = torch.multinomial(predictions, 1)
        
        # Decode token BEFORE printing KL stats
        token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)
        
        # Calculate overall probability of exiting early
        prob_reach_final = 1.0
        for qdx in range(len(early_exit_layer_idxs)):
            prob_reach_final *= (1 - sigmoid_kls[0, qdx].item())
        prob_exit_early = 1.0 - prob_reach_final
        token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)
    
        # Calculate overall probability of exiting early
        prob_reach_final = 1.0
        for qdx in range(len(early_exit_layer_idxs)):
            prob_reach_final *= (1 - sigmoid_kls[0, qdx].item())
        prob_exit_early = 1.0 - prob_reach_final
        
        # Store data
        token_data.append({
            'step': step,
            'token': token_text,
            'exit_layer': chosen_exit_layer,  # -1 means final layer, otherwise the actual layer number
            'did_exit_early': chosen_exit_layer != -1,
            'prob_exit_early': prob_exit_early,
            'kl_layer_25': kl_div[0, -1].item(),
            'difficulty': 'easy' if chosen_exit_layer >= 20 else 'hard' if chosen_exit_layer == -1 else 'medium'
        })
        for idx, layer_num in enumerate(early_exit_layer_idxs):
            predictions = final_predictions
            chosen_exit_layer = -1
            for qdx, exit_layer in enumerate(early_exit_layer_idxs):
                rand_val = random.random()
                if rand_val < sigmoid_kls[0, qdx]:
                    predictions = early_output_probs[:, qdx]
                    chosen_exit_layer = exit_layer
                    break
            chosen_exit_layers.append(int(chosen_exit_layer))
            # Sample next token
            next_token = torch.multinomial(predictions, 1)
            
        # Check for EOS
        if next_token.item() == config['generation']['eos_token_id']:
            print(f"EOS token encountered at step {step}")
            break
            
        # Add token to sequence
        current_input = torch.cat([current_input, next_token], dim=1)
        generated_tokens_manual.append(next_token.item())
        
        # Decode and print current token
        token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)
        # print(f"Step {step}: Token {next_token.item()} -> '{token_text}'")
df = pd.DataFrame(token_data)

print(df)

     step      token                     exit_layer  prob_exit_early  \
0       0  Certainly                             -1         0.000052   
1       1          ,                             -1         0.217027   
2       2        the                             -1         0.000034   
3       3       need                             -1         0.029931   
4       4         to  tensor(25, dtype=torch.int32)         0.999718   
..    ...        ...                            ...              ...   
295    95       runs                             -1         0.000099   
296    96          .                             -1         0.715062   
297    97      Using                             -1         0.000036   
298    98     ursion                             -1         0.000409   
299    99      helps                             -1         0.000444   

     kl_layer_25 difficulty did_exit_early  
0      18.426769       hard            NaN  
1       2.106080       hard            NaN  


In [15]:
df.head()

Unnamed: 0,step,token,exit_layer,prob_exit_early,kl_layer_25,difficulty,did_exit_early
0,0,Certainly,-1,5.2e-05,18.426769,hard,
1,1,",",-1,0.217027,2.10608,hard,
2,2,the,-1,3.4e-05,13.172898,hard,
3,3,need,-1,0.029931,5.351672,hard,
4,4,to,"tensor(25, dtype=torch.int32)",0.999718,0.000564,easy,
