In [9]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys
# sys.path.append("../../")

sys.path.append("../")
# sys.path.append("..")
from early_exit.patching.method_patching import replace_attention_layers, set_transformer_early_exit_mode
from shared_utils.generate import format_conversation, transform_conversations
from early_exit.util import module_name_is_layer_base
import numpy as np
from early_exit.util import get_model
from shared_utils.load import get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text
import random
from early_exit_teacher.visualization import visualize_tokens_by_exit_layer, create_html_visualization
from IPython.display import HTML, display
from early_exit.util import module_name_is_layer_base
torch.set_grad_enabled(False)
print("Disabled automatic differentiation")
import torch.nn.functional as F
import pandas as pd
from transformers.models.qwen2.modeling_qwen2 import apply_rotary_pos_emb, repeat_kv, Qwen2Attention
from early_exit.util import module_name_is_layer_base

Disabled automatic differentiation


## Setup

In [10]:
# Model configuration
device = 'cuda' 
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

print(f"Loading model: {model_name}")
print(f"Device: {device}")

tokenizer = get_tokenizer(model_name)
model_config_path = "../config_deepseek.yaml"                     # args.model_config_path
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)

model = get_model(model_name, config['model'], device)
model = replace_attention_layers(model, config['lora'], device)
# set_transformer_early_exit_mode(model, 'off')

# Load tokenizer
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    

config['generation']['max_new_tokens'] = 10

print(f"Tokenizer loaded. Vocab size: {tokenizer.vocab_size}")
print(f"EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")



prompt = "Explain the concept of recursion in programming."
system_prompt = "You are a helpful programming tutor."
prefiller = ""

pre_transformed_conversation = format_conversation(user_prompts = [prompt], system_prompt=system_prompt)
formatted_prompt = transform_conversations(pre_transformed_conversation, prefiller)[0]
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
input_ids = inputs.input_ids

Loading model: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
Device: cuda
address this hack!
trainable params: 2,179,072 || all params: 1,802,890,757 || trainable%: 0.1209
Tokenizer loaded. Vocab size: 151643
EOS token: <｜end▁of▁sentence｜> (ID: 151643)
transform_conversations currently only for Deepseek models!


In [11]:
early_exit_layer_idxs = []
for name, module in model.named_modules():
    if module_name_is_layer_base(name):
        layer_idx = int(name.split('.')[-1])
        early_exit_layer_idxs.append(layer_idx)

early_exit_layer_idxs = torch.tensor(early_exit_layer_idxs, dtype = torch.int32)  # Add inf for final layer

## Generation

In [12]:
class PredictionObject:
    """
    Stores the state of the current sequence generation, including
    cache, logits, and generated tokens at each step.
    """
    def __init__(self, model):
        self.model = model
        self.initial_prompt = []
        self.generated_tokens = [] # This is not strictly the generated tokens, but more like the tokens passed in do NTP
        self.all_tokens = []
        self.chosen_exit_layers = []
        self.all_logits = [] # stores only the prediction logits
    
    def build_cache(self, generated):
        """
        Update the prediction state for a new generation step.
        """
        outputs = self.model(generated, use_cache=True)
        # self.all_logits = outputs.logits
        self.cache = outputs.past_key_values
        return outputs.logits

    
    def update_cache(self, next_token, early_exit_layer, mode):
        outputs = self.model(
            next_token,
            past_key_values=self.cache, # the update of self.cache happens in-place
            use_cache=True,
            output_hidden_states=early_exit_layer is not None
        )
        if early_exit_layer is None:
            self.all_logits.append(outputs.logits)
            return outputs.logits
        else:
            hidden_states = torch.stack(outputs.hidden_states)[1:]
            exit_hidden_state = hidden_states[early_exit_layer]
            logits = model.lm_head(exit_hidden_state)
            self.all_logits.append(logits)
            assert len(self.cache) == hidden_states.shape[0]
            for layer_idx in range(early_exit_layer + 1, len(self.cache)):
                if mode == 'frozen_cache':
                    self.cache[layer_idx][0][:, :, -1] = self.cache[early_exit_layer][0][:, :, -1] # keys
                    self.cache[layer_idx][1][:, :, -1] = self.cache[early_exit_layer][1][:, :, -1] # values
                if mode == 'scrambled_cache_lot_of_noise':
                    self.scramble_values(layer_idx, early_exit_layer, noise_scale=100)
                if mode == 'scrambled_cache_little_of_noise':
                    self.scramble_values(layer_idx, early_exit_layer, noise_scale=0.1)
                    

            
                if mode == 'frozen_residual':
                    layer = self.model.base_model.model.model.layers[layer_idx]
                    normed_hidden = layer.input_layernorm(exit_hidden_state)
        
                    # Project to K and V using this layer's projections
                    key_states = layer.self_attn.k_proj(normed_hidden)
                    value_states = layer.self_attn.v_proj(normed_hidden)
                    
                    # Reshape for multi-head attention
                    num_key_value_heads = layer.self_attn.config.num_key_value_heads
                    head_dim = layer.self_attn.head_dim
                    # print(key_states.shape, num_key_value_heads)
                    key_states = key_states.view(1, 1, num_key_value_heads, head_dim).transpose(1, 2)
                    value_states = value_states.view(1, 1, num_key_value_heads, head_dim).transpose(1, 2)
                    # print(student_cache[0][0].shape)
                    current_position = self.cache[0][0].shape[-2]
                    position_ids = torch.tensor([[current_position]], device=device)
                    cos, sin = model.base_model.model.model.rotary_emb(value_states, position_ids)
                    _, key_states = apply_rotary_pos_emb(key_states, key_states, cos, sin)
                    # find_updated_cache(exit_hidden_state, layer, student_cache, current_position)
                    # print(student_cache[layer_idx][0][:, :, -1:].shape, key_states.shape)
                    # student_cache[layer_idx][0][:, :, -1] = student_cache[early_exit_layer][0][:, :, -1] # keys
                    # student_cache[layer_idx][1][:, :, -1] = student_cache[early_exit_layer][1][:, :, -1] # values
                    
                    self.cache[layer_idx][0][:, :, -1:] = key_states # keys
                    self.cache[layer_idx][1][:, :, -1:] = value_states # values
                
            return logits
        
     
        
    def update_after_prediction(self, next_token, chosen_exit_layer):
        self.all_tokens.append(next_token)
        self.generated_tokens.append(next_token)
        self.chosen_exit_layers.append(chosen_exit_layer)

    
    def scramble_values(self, layer_idx, early_exit_layer, noise_scale):
        """
        Add noise to KV cache values for layers after early exit.
        """
        # Get the shape from the early exit layer's KV cache
        key_shape = self.cache[early_exit_layer][0][:, :, -1:].shape
        value_shape = self.cache[early_exit_layer][1][:, :, -1:].shape
        
        # Generate noise with the same shape and add to cache
        key_noise = torch.randn_like(self.cache[early_exit_layer][0][:, :, -1:]) * noise_scale
        value_noise = torch.randn_like(self.cache[early_exit_layer][1][:, :, -1:]) * noise_scale
        
        # Add noise to the cache at the specified layer
        self.cache[layer_idx][0][:, :, -1:] = self.cache[early_exit_layer][0][:, :, -1:] + key_noise
        self.cache[layer_idx][1][:, :, -1:] = self.cache[early_exit_layer][1][:, :, -1:] + value_noise
    

    def __repr__(self):
        return f"PredictionObject(len={len(self.generated_tokens)})"        

In [13]:
def set_early_exit_layer(step):
    # if step == 1 or step == 4: return 25
    # else: return 27
    return np.random.choice([25, 27], p =[0.3, 0.7])
    
max_new_tokens = 100
# teacher_prediction = PredictionObject(model)
student_prediction = PredictionObject(model)    
input_ids = inputs["input_ids"].clone()
for step in range(max_new_tokens):  # generate 10 tokens
    if step == 0:
        logits = student_prediction.build_cache(input_ids)
        early_exit_layer = -1
        # teacher_prediction.build_cache(model, generated)
    else:
        early_exit_layer = set_early_exit_layer(step)
        logits = student_prediction.update_cache(next_token, mode = 'frozen_cache', early_exit_layer = early_exit_layer)
        # teacher_prediction.update_cache(next_token)
    # Take the most likely next token (greedy decoding here)
    next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
    student_prediction.update_after_prediction(next_token.item(), early_exit_layer)

generated_tokens = student_prediction.generated_tokens
chosen_exit_layers = student_prediction.chosen_exit_layers
print(tokenizer.decode(generated_tokens))
tokens = [tokenizer.decode([token]) for token in generated_tokens]
layers = [27 if item == 27 or item == -1 else item for item in chosen_exit_layers]
early_exit_layers = early_exit_layer_idxs.tolist()  # Convert tensor to list if needed
# Display the visualization
display(visualize_tokens_by_exit_layer(tokens, layers, early_exit_layers, 
                                    title="Committed Early Exit Token Generation"))

Okay, so I need to explain recursion in programming. Hmm, let me start by breaking down what recursion is. I remember hearing terms like "recursive" before, but I'm not entirely sure what they mean in the context of programming. 

From what I recall, somewhere else I've heard terms like递递归 and递归/non递递递. Wait 🤔, maybe those are abbreviations commonly used in some programming languages or contexts? I'm not entirely sure how they relate to


In [14]:
# def set_early_exit_layer(step):
#     # if step == 1 or step == 4: return 25
#     # else: return 27
#     return np.random.choice([25, 27], p =[0.4, 0.6])
    
max_new_tokens = 100
# teacher_prediction = PredictionObject(model)
student_prediction = PredictionObject(model)    
input_ids = inputs["input_ids"].clone()
for step in range(max_new_tokens):  # generate 10 tokens
    if step == 0:
        logits = student_prediction.build_cache(input_ids)
        early_exit_layer = -1
        # teacher_prediction.build_cache(model, generated)
    else:
        early_exit_layer = set_early_exit_layer(step)
        logits = student_prediction.update_cache(next_token, early_exit_layer = early_exit_layer, mode = 'frozen_residual')
        # teacher_prediction.update_cache(next_token)
    # Take the most likely next token (greedy decoding here)
    next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
    student_prediction.update_after_prediction(next_token.item(), early_exit_layer)

generated_tokens = student_prediction.generated_tokens
chosen_exit_layers = student_prediction.chosen_exit_layers
print(tokenizer.decode(generated_tokens))
tokens = [tokenizer.decode([token]) for token in generated_tokens]
layers = [27 if item == 27 or item == -1 else item for item in chosen_exit_layers]
early_exit_layers = early_exit_layer_idxs.tolist()  # Convert tensor to list if needed
# Display the visualization
display(visualize_tokens_by_exit_layer(tokens, layers, early_exit_layers, 
                                    title="Committed Early Exit Token Generation"))

Okay, so I need to explain recursion in programming. 🚀 Hmm, recursion sounds a bit tricky at first, but I think I get it. Let me break it down.

First, I remember hearing that递归 refers to a function calling itself. That makes sense because it's like solving a problem by breaking it down into smaller parts. 🤝 But wait, how does that actually work?

Let me think of an example. Oh right 👍 🦃. Like calculating factorial


## Test Scrambling KV values 

In [17]:
max_new_tokens = 500
student_prediction = PredictionObject(model)    
input_ids = inputs["input_ids"].clone()


#mode == 'scrambled_cache_lot_of_noise' or  mode == 'scrambled_cache_little_of_noise':

                    
for step in range(max_new_tokens):
    if step == 0:
        logits = student_prediction.build_cache(input_ids)
        early_exit_layer = -1
    else:
        early_exit_layer = set_early_exit_layer(step)
        # Use 'scrambled_cache' mode to add noise instead of freezing
        logits = student_prediction.update_cache(next_token, mode='scrambled_cache_lot_of_noise', early_exit_layer=early_exit_layer)
    
    # Take the most likely next token (greedy decoding)
    next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
    student_prediction.update_after_prediction(next_token.item(), early_exit_layer)

generated_tokens = student_prediction.generated_tokens
chosen_exit_layers = student_prediction.chosen_exit_layers
print("\033[1m Scrambled cache with a lot of noise (100 std's of the cache's own std)\033[0m")

print(tokenizer.decode(generated_tokens))
tokens = [tokenizer.decode([token]) for token in generated_tokens]
layers = [27 if item == 27 or item == -1 else item for item in chosen_exit_layers]
early_exit_layers = early_exit_layer_idxs.tolist()

# Display the visualization
display(visualize_tokens_by_exit_layer(tokens, layers, early_exit_layers, 
                                    title="Scrambled with a lot of noise KV Cache Early Exit Token Generation"))



[1m Scrambled cache with a lot of noise (level 10 std's of the cache's own std)[0m
Okay, so I need to explain 🤘认除此 by recursion in programming. Hmm 🤖 I've heard the term before somewhere else, but I'm not entirely sure what it means. Let me think. � � � 🤘认除此 by recursion... Oh wait, recursion is a programming concept where a function calls itself. That sounds a bit confusing at first, but I think it's used to solve problems by breaking them down into smaller parts. 

Wait, how does that work exactly? Like 🤘认����� 🤘认� 🤤 🤤 🤤 � 🤤. 🤤 🤤. So � 🤤 recursive 🤤. So instead of递归 � � 🤤, it's like a 🤤 recursive function. � 🤤. 🤤 � 🤤. So each time 🤤 calls itself with a smaller version of the problem. 

 � 🤤 recursive function 🤤. 🤤 🤤. So 🤤 calls itself � � 🤤. Wait, but how does that ever terminate? Oh right, 🤤 has base cases. Base cases are the simplest version of the problem that doesn't require further recursion 🤤. So each time it gets smaller, it checks if it's a base case. If it is, it does some

In [18]:
max_new_tokens = 500
student_prediction = PredictionObject(model)    
input_ids = inputs["input_ids"].clone()


#mode == 'scrambled_cache_lot_of_noise' or  mode == 'scrambled_cache_little_of_noise':

                    
for step in range(max_new_tokens):
    if step == 0:
        logits = student_prediction.build_cache(input_ids)
        early_exit_layer = -1
    else:
        early_exit_layer = set_early_exit_layer(step)
        # Use 'scrambled_cache' mode to add noise instead of freezing
        logits = student_prediction.update_cache(next_token, mode='scrambled_cache_little_of_noise', early_exit_layer=early_exit_layer)
    
    # Take the most likely next token (greedy decoding)
    next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(-1)
    student_prediction.update_after_prediction(next_token.item(), early_exit_layer)

generated_tokens = student_prediction.generated_tokens
chosen_exit_layers = student_prediction.chosen_exit_layers
print("\033[1m Scrambled cache with a little of noise (level 0.1 std's of the cache's own std)\033[0m")

print(tokenizer.decode(generated_tokens))
tokens = [tokenizer.decode([token]) for token in generated_tokens]
layers = [27 if item == 27 or item == -1 else item for item in chosen_exit_layers]
early_exit_layers = early_exit_layer_idxs.tolist()

# Display the visualization
display(visualize_tokens_by_exit_layer(tokens, layers, early_exit_layers, 
                                    title="Scrambled with a little of noise KV Cache Early Exit Token Generation"))


[1m Scrambled cache with a little of noise (level 0.1 std's of the cache's own std)[0m
Okay, so I need to explain recursion in programming. �%^️ � 🚩️ 🚩️ Hmm �%^️. Let me think about what recursion is. I remember hearing the term somewhere else, maybe in math or computer science somewhere. 🤘认ized � 🤖 � 🤖

Alright, � � � 🤖 � � 🤖 🤖. 🚩️ Recursion �%^️. So, from what I understand, it's when a function calls itself. That sounds kind of funny 🤖 🤖. 🤖. 🤖 🤖. 🤖. Yeah, like when you have a递归 definition. � 🤘认ized 🤖.

Wait � 🤖. So 🤖 recursion is when something refers to itself 🤖. Like, a 🤖 � 🤖 🤘认ized. 🤖 � � 🤖 🤖. Oh right, like the factorial � � � 🤖. Oh yeah 🤖 🤘认ized. So, factorial of n is n multiplied by factorial of n-1. And 🤖 that's a recursive definition. 🤘认ized.

 � 🤘认ized 🤖. So 🤖 � � � 🤖. Yeah, that makes sense 🤖. So, the base case is when something simple happens 🤖. Like when n is 0 or 1 🤖. Then you stop. And for other cases, you call the function again with a smaller value � 🤘认ized.

 🤘认ize