In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys
# sys.path.append("../../")

sys.path.append("../")
# sys.path.append("..")
from early_exit.patching.method_patching import replace_attention_layers, set_transformer_early_exit_mode
from shared_utils.generate import format_conversation, transform_conversations
from early_exit.util import module_name_is_layer_base
import numpy as np
from early_exit.util import get_model
from shared_utils.load import get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text
import random
from early_exit_teacher.visualization import visualize_tokens_by_exit_layer, create_html_visualization
from IPython.display import HTML, display
from early_exit.util import module_name_is_layer_base
torch.set_grad_enabled(False)
print("Disabled automatic differentiation")
import torch.nn.functional as F
import pandas as pd
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv, Qwen2Attention
from early_exit.util import module_name_is_layer_base

Disabled automatic differentiation


## Setup

In [2]:
# Model configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

print(f"Loading model: {model_name}")
print(f"Device: {device}")

tokenizer = get_tokenizer(model_name)
model_config_path = "/project/project_465001340/fair_stuff/externalization/config_deepseek.yaml"   # args.model_config_path
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)

model = get_model(model_name, config['model'], device)
model = replace_attention_layers(model, config['lora'], device)
# set_transformer_early_exit_mode(model, 'off')

# Load tokenizer
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    

config['generation']['max_new_tokens'] = 10

print(f"Tokenizer loaded. Vocab size: {tokenizer.vocab_size}")
print(f"EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")



prompt = "Explain the concept of recursion in programming."
system_prompt = "You are a helpful programming tutor."
prefiller = ""

pre_transformed_conversation = format_conversation(user_prompts = [prompt], system_prompt=system_prompt)
formatted_prompt = transform_conversations(pre_transformed_conversation, prefiller)[0]
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
input_ids = inputs.input_ids

Loading model: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
Device: cuda
address this hack!
g++ (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Copyright (C) 2021 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

trainable params: 2,179,072 || all params: 1,779,276,294 || trainable%: 0.1225
Tokenizer loaded. Vocab size: 151643
EOS token: <｜end▁of▁sentence｜> (ID: 151643)
transform_conversations currently only for Deepseek models!


In [3]:
early_exit_layer_idxs = []
for name, module in model.named_modules():
    if module_name_is_layer_base(name):
        layer_idx = int(name.split('.')[-1])
        early_exit_layer_idxs.append(layer_idx)

early_exit_layer_idxs = torch.tensor(early_exit_layer_idxs, dtype = torch.int32)  # Add inf for final layer

## Generation

In [41]:
class PredictionObject:
    """
    Stores the state of the current sequence generation, including
    cache, logits, and generated tokens at each step.
    """
    def __init__(self, model):
        self.model = model
        self.initial_prompt = []
        self.generated_tokens = [] # This is not strictly the generated tokens, but more like the tokens passed in do NTP
        self.all_tokens = []
        self.chosen_exit_layers = []
        self.all_logits = [] # stores only the prediction logits
    
    def build_cache(self, generated):
        """
        Update the prediction state for a new generation step.
        """
        outputs = self.model(generated, use_cache=True)
        # self.all_logits = outputs.logits
        self.cache = outputs.past_key_values
        return outputs.logits

    
    def update_cache(self, next_token, early_exit_layer, mode):
        outputs = self.model(
            next_token,
            past_key_values=self.cache, # the update of self.cache happens in-place
            use_cache=True,
            output_hidden_states=early_exit_layer is not None
        )
        if early_exit_layer is None:
            self.all_logits.append(outputs.logits)
            return outputs.logits
        else:
            hidden_states = torch.stack(outputs.hidden_states)[1:]
            exit_hidden_state = hidden_states[early_exit_layer]
            logits = model.lm_head(exit_hidden_state)
            self.all_logits.append(logits)
            assert len(self.cache) == hidden_states.shape[0]
            for layer_idx in range(early_exit_layer + 1, len(self.cache)):
                if mode == 'frozen_cache':
                    self.cache[layer_idx][0][:, :, -1] = self.cache[early_exit_layer][0][:, :, -1] # keys
                    self.cache[layer_idx][1][:, :, -1] = self.cache[early_exit_layer][1][:, :, -1] # values
                if mode == 'frozen_residual':
                    layer = self.model.base_model.model.model.layers[layer_idx]
                    normed_hidden = layer.input_layernorm(exit_hidden_state)
        
                    # Project to K and V using this layer's projections
                    key_states = layer.self_attn.k_proj(normed_hidden)
                    value_states = layer.self_attn.v_proj(normed_hidden)
                    
                    # Reshape for multi-head attention
                    num_key_value_heads = layer.self_attn.config.num_key_value_heads
                    head_dim = layer.self_attn.head_dim
                    # print(key_states.shape, num_key_value_heads)
                    key_states = key_states.view(1, 1, num_key_value_heads, head_dim).transpose(1, 2)
                    value_states = value_states.view(1, 1, num_key_value_heads, head_dim).transpose(1, 2)
                    # print(student_cache[0][0].shape)
                    current_position = self.cache[0][0].shape[-2]
                    position_ids = torch.tensor([[current_position]], device=device)
                    cos, sin = model.base_model.model.model.rotary_emb(value_states, position_ids)
                    _, key_states = apply_rotary_pos_emb(key_states, key_states, cos, sin)
                    # find_updated_cache(exit_hidden_state, layer, student_cache, current_position)
                    # print(student_cache[layer_idx][0][:, :, -1:].shape, key_states.shape)
                    # student_cache[layer_idx][0][:, :, -1] = student_cache[early_exit_layer][0][:, :, -1] # keys
                    # student_cache[layer_idx][1][:, :, -1] = student_cache[early_exit_layer][1][:, :, -1] # values
                    
                    self.cache[layer_idx][0][:, :, -1:] = key_states # keys
                    self.cache[layer_idx][1][:, :, -1:] = value_states # values
                
            return logits
     
        
    def update_after_prediction(self, next_token, chosen_exit_layer):
        self.all_tokens.append(next_token)
        self.generated_tokens.append(next_token)
        self.chosen_exit_layers.append(chosen_exit_layer)

    def __repr__(self):
        return f"PredictionObject(len={len(self.generated_tokens)})"        

In [46]:
def set_early_exit_layer(step):
    # if step == 1 or step == 4: return 25
    # else: return 27
    return np.random.choice([25, 27], p =[0.4, 0.6])
    
max_new_tokens = 100
# teacher_prediction = PredictionObject(model)
student_prediction = PredictionObject(model)    
input_ids = inputs["input_ids"].clone()
for step in range(max_new_tokens):  # generate 10 tokens
    if step == 0:
        logits = student_prediction.build_cache(input_ids)
        early_exit_layer = -1
        # teacher_prediction.build_cache(model, generated)
    else:
        early_exit_layer = set_early_exit_layer(step)
        logits = student_prediction.update_cache(next_token, mode = 'frozen_cache')
        # teacher_prediction.update_cache(next_token)
    # Take the most likely next token (greedy decoding here)
    next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
    student_prediction.update_after_prediction(next_token.item(), early_exit_layer)

generated_tokens = student_prediction.generated_tokens
chosen_exit_layers = student_prediction.chosen_exit_layers
print(tokenizer.decode(generated_tokens))
tokens = [tokenizer.decode([token]) for token in generated_tokens]
layers = [27 if item == 27 or item == -1 else item for item in chosen_exit_layers]
early_exit_layers = early_exit_layer_idxs.tolist()  # Convert tensor to list if needed
# Display the visualization
display(visualize_tokens_by_exit_layer(tokens, layers, early_exit_layers, 
                                    title="Committed Early Exit Token Generation"))

TypeError: PredictionObject.update_cache() got multiple values for argument 'mode'

In [44]:
def set_early_exit_layer(step):
    # if step == 1 or step == 4: return 25
    # else: return 27
    return np.random.choice([25, 27], p =[0.4, 0.6])
    
max_new_tokens = 100
# teacher_prediction = PredictionObject(model)
student_prediction = PredictionObject(model)    
input_ids = inputs["input_ids"].clone()
for step in range(max_new_tokens):  # generate 10 tokens
    if step == 0:
        logits = student_prediction.build_cache(input_ids)
        early_exit_layer = -1
        # teacher_prediction.build_cache(model, generated)
    else:
        early_exit_layer = set_early_exit_layer(step)
        logits = student_prediction.update_cache(next_token, early_exit_layer = early_exit_layer, mode = 'frozen_residual')
        # teacher_prediction.update_cache(next_token)
    # Take the most likely next token (greedy decoding here)
    next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
    student_prediction.update_after_prediction(next_token.item(), early_exit_layer)

generated_tokens = student_prediction.generated_tokens
chosen_exit_layers = student_prediction.chosen_exit_layers
print(tokenizer.decode(generated_tokens))
tokens = [tokenizer.decode([token]) for token in generated_tokens]
layers = [27 if item == 27 or item == -1 else item for item in chosen_exit_layers]
early_exit_layers = early_exit_layer_idxs.tolist()  # Convert tensor to list if needed
# Display the visualization
display(visualize_tokens_by_exit_layer(tokens, layers, early_exit_layers, 
                                    title="Committed Early Exit Token Generation"))

Okay, so I need to explain recursion in programming. Hmm, let me start by thinking about what I know about递归 🤔. I remember that递递 refers to something being returned oneself, like a function calling itself. That makes sense because it's kind of like how you递归 � cared for yourself by helping someone else.

 🤝 So, � "__recursion__" is when a function calls itself. 📊 I think � "__base case__" refers to
