# Setup

In [18]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader

import sys
sys.path.append("../")

from shared_utils.data import CSVPromptDataset
from early_exit.util import get_model
from shared_utils.load import get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text

from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode

# import wandb
import pandas as pd
import numpy as np

In [20]:

# LOAD IN EXPERIMENT ARGS
# num_epoch = 1                     # args.num_epoch
num_exit_samples = 1                  # args.num_exit_samples
device = "cuda"                    # args.device
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"                    # args.model_name
model_config_path = "../config_deepseek.yaml"                     # args.model_config_path
dataset_path = "../results_and_data/early_exit_sft_dataset/test/data.csv"                  # args.dataset_path
prompt_config_path = "../results_and_data/early_exit_sft_dataset/test/prompt_config.json"                    # args.prompt_config_path
batch_size = 1                    # args.batch_size -- might want to sort out batching, but increasi

In [21]:
# LOAD IN THE MODEL AND TOKENIZER
tokenizer = get_tokenizer(model_name)
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)



# LOAD IN DATASET
dataset = CSVPromptDataset(dataset_path, prompt_config_path)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True)


In [22]:
model = get_model(model_name, config['model'], device)
# ENABLE EARLY EXITING
model = replace_attention_layers(model, config['lora'], device)

replacing layer model.layers.0
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.1
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.2
replacing layer model.layers.3
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.4
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.5
replacing layer model.layers.6
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.7
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.8
replacing layer model.layers.9
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.10
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.11
replacing layer model.layers.12
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.13
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.14


In [23]:
from early_exit.util import module_name_is_transformer_layer

for prompt_batch in dataloader:
    break

# Setup tests

In [24]:
class ComputationTracker:
    def __init__(self):
        self.hooks = []
        self.reset()
        
    def reset(self):
        """Reset all tracking data and remove existing hooks"""
        self.mlp_batch_sizes = {}
        
        # Remove all existing hooks
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        
    def register_hooks(self, model):
        """Register forward hooks on MLP components only"""
        self.reset()
        
        for name, module in model.named_modules():
            # Track MLP computations ONLY
            if 'mlp' in name:
                parts = name.split('.')
                layer_idx = -1
                for i, part in enumerate(parts):
                    if part == 'layers' and i + 1 < len(parts) and parts[i + 1].isdigit():
                        layer_idx = int(parts[i + 1])
                        break
                
                hook = module.register_forward_hook(
                    lambda m, i, o, idx=layer_idx: self._log_mlp_batch_size(idx, i, o)
                )
                self.hooks.append(hook)
    
    def _log_mlp_batch_size(self, layer_idx, inputs, outputs):
        """Log MLP computation batch size"""
        # Extract batch size from first input tensor
        batch_size = 0
        if isinstance(inputs, tuple) and len(inputs) > 0:
            if hasattr(inputs[0], 'shape') and len(inputs[0].shape) > 0:
                batch_size = inputs[0].shape[0]
        elif hasattr(inputs, 'shape') and len(inputs.shape) > 0:
            batch_size = inputs.shape[0]
            
        if layer_idx not in self.mlp_batch_sizes:
            self.mlp_batch_sizes[layer_idx] = []
        self.mlp_batch_sizes[layer_idx].append(batch_size)
        
    def get_summary(self):
        """Get a summary of MLP batch sizes"""
        return {
            'mlp_batch_sizes': dict(sorted(self.mlp_batch_sizes.items())),
            'total_hooks': len(self.hooks)
        }

# Make sure this is a function, not overwritten
def print_computation_summary(tracker):
    """Print a summary of batch sizes tracked"""
    summary = tracker.get_summary()
    print(f'n_hooks={summary["total_hooks"]}')
    
    print("\nMLP batch sizes per layer:")
    for layer_idx, batch_sizes in summary['mlp_batch_sizes'].items():
        if batch_sizes:
            avg_batch_size = sum(batch_sizes) / len(batch_sizes)
            display_sizes = batch_sizes[:10]
            suffix = f"... ({len(batch_sizes)} total)" if len(batch_sizes) > 10 else ""
            print(f"  Layer {layer_idx}: {display_sizes}{suffix} (avg: {avg_batch_size:.2f})")

In [25]:
class ImprovedComputationTracker(ComputationTracker):
    def __init__(self, layer_pattern=r'layers\.(\d+)', 
                 attention_patterns=['self_attn', 'attention'],
                 mlp_patterns=['mlp', 'ffn']):
        self.layer_pattern = layer_pattern
        self.attention_patterns = attention_patterns
        self.mlp_patterns = mlp_patterns
        self.hooks = []
        self.reset()
        
    def __del__(self):
        """Ensure hooks are removed on deletion"""
        self.reset()
        
    def extract_layer_idx(self, name):
        """More robust layer index extraction"""
        import re
        match = re.search(self.layer_pattern, name)
        return int(match.group(1)) if match else -1

## Model forwards

In [26]:

def forward_teacher(model, prompt_batch):
    with torch.no_grad():
        # Generate SFT targets
        set_transformer_early_exit_mode(model, 'sft_teacher')
        sft_teacher_response, (sft_teacher_generated_tokens, sft_teacher_final_layer_logprobs, gathered_early_exit_hidden_states) =\
            generate_text(
                model=model, 
                prompt=prompt_batch.full_user_prompt, 
                system_prompt=dataset.system_prompt, 
                prefiller=dataset.prefiller, 
                tokenizer=tokenizer, 
                generation_config=config['generation'], 
                device=device
            )
        print(sft_teacher_response)

        early_output_log_probs = model.early_exit_hidden_state_readout(gathered_early_exit_hidden_states)               # [batch, num exitable layers, gen len, vocabulary]
        early_exit_probs = model.early_exit_target_probs(early_output_log_probs = early_output_log_probs, teacher_final_layer_log_probs = sft_teacher_final_layer_logprobs)
        repeated_sft_teacher_final_layer_logprobs = sft_teacher_final_layer_logprobs.repeat(num_exit_samples, 1, 1)     # XXX repeat_interleave? [batch * samples, full length, vocabulary]


    # Sample early exits
    batch, gen_len, elayers = early_exit_probs.shape                                                                                                # [batch, generation length, exitable layers]
    full_len = sft_teacher_generated_tokens.shape[1]
    repeated_sft_teacher_generated_tokens = sft_teacher_generated_tokens.expand(num_exit_samples * batch, full_len)                                 # [batch * samples, full length]
    sampled_early_exit_layer_idxs_early_with_sample_dim = torch.distributions.Categorical(probs = early_exit_probs).sample((num_exit_samples,))     # [samples, batch, generation length] 
    sampled_early_exit_layer_idxs_early = sampled_early_exit_layer_idxs_early_with_sample_dim.reshape(batch * num_exit_samples, gen_len)            # [batch * samples, generation length]
    sampled_early_exit_layer_idxs = model.exitable_layer_idxs[sampled_early_exit_layer_idxs_early.cpu()]                                            # [batch * samples, generation length]

    return sampled_early_exit_layer_idxs, repeated_sft_teacher_generated_tokens


def forward_student(model, sampled_early_exit_layer_idxs, repeated_sft_teacher_generated_tokens):
    # Generate with prescription
    set_transformer_early_exit_mode(model, 'sft_student')
    sft_student_output_scores, collected_exit_logits = model(repeated_sft_teacher_generated_tokens, prescribed_exit_layer_idxs = sampled_early_exit_layer_idxs) # [batch * samples, full length, vocabulary]
    
    return sft_student_output_scores, collected_exit_logits

In [27]:
# Initialize tracker
tracker = ComputationTracker()
tracker = ImprovedComputationTracker()

# Register hooks
tracker.register_hooks(model)

print("="*60)
print("TEST 1: TEACHER MODE (Full Computation)")
print("="*60)

sampled_early_exit_layer_idxs, repeated_sft_teacher_generated_tokens = forward_teacher(model, prompt_batch)

teacher_summary = tracker.get_summary()

TEST 1: TEACHER MODE (Full Computation)
transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 149])
<｜begin▁of▁sentence｜><｜Assistant｜> 
<｜User｜> I am going to give you a story and a question about the story. Read the following story carefully, understand the characters' actions and perspectives, then answer the question regarding object locations, character knowledge, and beliefs.

Charlotte entered the grand ballroom. Alexis entered the grand ballroom. Alexis told out loud about the wedding cake design. While this action was happening, Gabriella witnessed this action in secret (and only this action). Gabriella entered the grand ballroom. Charlotte told out loud about the best man's speech. Charlotte left the grand ballroom. Charlotte entered the grand ballroom. Charlotte told out loud about the photo booth props.

Does Charlotte know about best man's speech? Answer yes or no.
<｜Assistant｜> 
Oka

In [28]:
print("\nTeacher Computation Summary:")
print_computation_summary(tracker)

print("\n" + "="*60)
print("TEST 2: STUDENT MODE (With Early Exits) MOD 3")
print("="*60)

# Reset tracker for student mode
tracker.reset()
tracker.register_hooks(model)

print("Student Computation Summary:")
forward_student(model, sampled_early_exit_layer_idxs, repeated_sft_teacher_generated_tokens)

student_summary = tracker.get_summary()
print_computation_summary(tracker)
# Calculate savings
#print("\n" + "="*60)
#print("COMPUTATION SAVINGS")
#print("="*60)
#if teacher_summary['total_attention_ops'] > 0:
#    attention_savings = 1 - (student_summary['total_attention_ops'] / teacher_summary['total_attention_ops'])
#    print(f"Attention computation saved: {attention_savings*100:.1f}%")
    
#if teacher_summary['total_mlp_ops'] > 0:
#    mlp_savings = 1 - (student_summary['total_mlp_ops'] / teacher_summary['total_mlp_ops'])
#    print(f"MLP computation saved: {mlp_savings*100:.1f}%")

# # Show exit patterns
# print(f"\nExit layers used: {exit_layer_idxs[0].tolist()}")
# if not exit_layer_idxs[0].isinf().all():
#     print(f"Average exit layer: {exit_layer_idxs[0][~exit_layer_idxs[0].isinf()].float().mean():.1f}")

# Clean up - remove all hooks
tracker.reset()


Teacher Computation Summary:
n_hooks=140

MLP batch sizes per layer:
  Layer 0: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (2000 total) (avg: 1.00)
  Layer 1: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (2000 total) (avg: 1.00)
  Layer 2: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (2000 total) (avg: 1.00)
  Layer 3: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (2000 total) (avg: 1.00)
  Layer 4: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (2000 total) (avg: 1.00)
  Layer 5: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (2000 total) (avg: 1.00)
  Layer 6: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (2000 total) (avg: 1.00)
  Layer 7: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (2000 total) (avg: 1.00)
  Layer 8: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (2000 total) (avg: 1.00)
  Layer 9: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (2000 total) (avg: 1.00)
  Layer 10: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (2000 total) (avg: 1.00)
  Layer 11: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (2000 total) (avg: 1.00)
  Layer 12: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (2000 total) (avg: 1.00)
  Layer 13: [1, 1

**Test: force to exit at last layer**

**Run teacher**

In [29]:
prompt = "hi"
system_prompt = ""
prefiller = ""

set_transformer_early_exit_mode(model, 'sft_teacher')
# Reset tracker for student mode
tracker.reset()
tracker.register_hooks(model)

with torch.no_grad():
    sft_teacher_response, (sft_teacher_generated_tokens, 
                          sft_teacher_final_layer_logprobs, 
                          gathered_early_exit_hidden_states) = generate_text(
        model=model,
        prompt=prompt,
        system_prompt=system_prompt,
        prefiller=prefiller,
        tokenizer=tokenizer,
        generation_config=config['generation'],
        device=device
    )
    
    early_output_log_probs = model.early_exit_hidden_state_readout(gathered_early_exit_hidden_states)
    
    early_exit_probs = model.early_exit_target_probs(
       early_output_log_probs=early_output_log_probs,
       teacher_final_layer_log_probs=sft_teacher_final_layer_logprobs
    )
    
    
print_computation_summary(tracker)

transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 8])
n_hooks=140

MLP batch sizes per layer:
  Layer 0: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (50 total) (avg: 1.00)
  Layer 1: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (50 total) (avg: 1.00)
  Layer 2: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (50 total) (avg: 1.00)
  Layer 3: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (50 total) (avg: 1.00)
  Layer 4: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (50 total) (avg: 1.00)
  Layer 5: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (50 total) (avg: 1.00)
  Layer 6: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (50 total) (avg: 1.00)
  Layer 7: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (50 total) (avg: 1.00)
  Layer 8: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (50 total) (avg: 1.00)
  Layer 9: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (50 total) (avg: 1.00)
  Layer 10: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (50 total) (avg: 1.00)
  Layer 11: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]... (50 total) (avg

**Run student**

In [30]:
# Reset tracker for student mode
tracker.reset()
tracker.register_hooks(model)


with torch.no_grad():
    batch, gen_len, elayers = early_exit_probs.shape 
    full_len = sft_teacher_generated_tokens.shape[1]
    repeated_sft_teacher_generated_tokens = sft_teacher_generated_tokens.expand(num_exit_samples * batch, full_len)   
    sampled_early_exit_layer_idxs_early_with_sample_dim = torch.distributions.Categorical(probs = early_exit_probs).sample((num_exit_samples,))     # [samples, batch, generation length] 
    sampled_early_exit_layer_idxs_early = sampled_early_exit_layer_idxs_early_with_sample_dim.reshape(batch * num_exit_samples, gen_len)            # [batch * samples, generation length]
    sampled_early_exit_layer_idxs = model.exitable_layer_idxs[sampled_early_exit_layer_idxs_early.cpu()]                       
    
    
    
    set_transformer_early_exit_mode(model, 'sft_student')
    
    # Create prescribed exit layer idxs filled with torch.inf (always exit on last layer)
    batch_samples, seq_len = repeated_sft_teacher_generated_tokens.shape
    #print("Setting exit layers to inf for sft_student")
    #sampled_early_exit_layer_idxs = torch.full((batch_samples, gen_len), torch.inf, \
    #                                        device=repeated_sft_teacher_generated_tokens.device)
    sampled_early_exit_layer_idxs = torch.zeros_like(sampled_early_exit_layer_idxs) + 10
    print(f"Prescribed_exit_layer_idxs = {torch.min(sampled_early_exit_layer_idxs)}")
    sft_student_output_scores, collected_exit_logits = model(repeated_sft_teacher_generated_tokens,\
                                                             prescribed_exit_layer_idxs=sampled_early_exit_layer_idxs)
print_computation_summary(tracker)

Prescribed_exit_layer_idxs = 10.0
n_hooks=140

MLP batch sizes per layer:
  Layer 0: [18, 18, 18, 18, 18] (avg: 18.00)
  Layer 1: [18, 18, 18, 18, 18] (avg: 18.00)
  Layer 2: [18, 18, 18, 18, 18] (avg: 18.00)
  Layer 3: [18, 18, 18, 18, 18] (avg: 18.00)
  Layer 4: [18, 18, 18, 18, 18] (avg: 18.00)
  Layer 5: [18, 18, 18, 18, 18] (avg: 18.00)
  Layer 6: [18, 18, 18, 18, 18] (avg: 18.00)
  Layer 7: [18, 18, 18, 18, 18] (avg: 18.00)
  Layer 8: [18, 18, 18, 18, 18] (avg: 18.00)
  Layer 9: [18, 18, 18, 18, 18] (avg: 18.00)
  Layer 10: [18, 18, 18, 18, 18] (avg: 18.00)
  Layer 11: [9, 9, 9, 9, 9] (avg: 9.00)
  Layer 12: [9, 9, 9, 9, 9] (avg: 9.00)
  Layer 13: [9, 9, 9, 9, 9] (avg: 9.00)
  Layer 14: [9, 9, 9, 9, 9] (avg: 9.00)
  Layer 15: [9, 9, 9, 9, 9] (avg: 9.00)
  Layer 16: [9, 9, 9, 9, 9] (avg: 9.00)
  Layer 17: [9, 9, 9, 9, 9] (avg: 9.00)
  Layer 18: [9, 9, 9, 9, 9] (avg: 9.00)
  Layer 19: [9, 9, 9, 9, 9] (avg: 9.00)
  Layer 20: [9, 9, 9, 9, 9] (avg: 9.00)
  Layer 21: [9, 9, 9, 9, 9] (a

# (WIP) Test: modify early_exit_probs

In [31]:
# model.base_model.model.model.layers[0].self_attn

for name, module in model.named_modules():
    bias_val = -100
    if 'early_exit_decision_weights' in name:
        print('biasing', name, 'with ', bias_val)
        bias_tensor = bias_val + torch.zeros(module.bias.shape)
        bias_tensor = bias_tensor.to(module.weight.device)
        module.bias = torch.nn.Parameter(bias_tensor)

        


biasing base_model.model.model.layers.0.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.3.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.6.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.9.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.12.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.15.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.18.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.21.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.24.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.27.early_exit_decision_weights with  -100


In [32]:
# model.base_model.model.model.layers[0].self_attn

for name, module in model.named_modules():
    bias_val = -100
    if 'early_exit_decision_weights' in name:
        print('biasing', name, 'with ', bias_val)
        bias_tensor = bias_val + torch.zeros(module.bias.shape)
        bias_tensor = bias_tensor.to(module.weight.device)
        module.bias = torch.nn.Parameter(bias_tensor)

        

biasing base_model.model.model.layers.0.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.3.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.6.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.9.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.12.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.15.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.18.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.21.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.24.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.27.early_exit_decision_weights with  -100


In [33]:
print("\n" + "="*60)
print("TEST 2: STUDENT MODE (With Early Exits)")
print("="*60)

# Reset tracker for student mode
tracker.reset()
tracker.register_hooks(model)

forward_student(model, sampled_early_exit_layer_idxs, repeated_sft_teacher_generated_tokens)

student_summary = tracker.get_summary()
print("\nStudent Mode Summary:")
print(f"Total attention operations: {student_summary['total_attention_ops']}")
print(f"Total MLP operations: {student_summary['total_mlp_ops']}")
print(f"Attention ops per layer: {student_summary['attention_ops']}")
print(f"MLP ops per layer: {student_summary['mlp_ops']}")


TEST 2: STUDENT MODE (With Early Exits)

Student Mode Summary:


KeyError: 'total_attention_ops'

In [None]:
prompt = "Tell me a Zen joke about farmer"
system_prompt = "You are a helpful programming tutor."
prefiller = ""

set_transformer_early_exit_mode(model, 'free_generate')
externalised_response, (externalised_generated_tokens, gathered_early_exit_layer_idxs) =\
    generate_text(model, prompt, system_prompt, prefiller, tokenizer, config['generation'], device)
print(externalised_response)