# Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader

import sys
sys.path.append("../")

from shared_utils.data import CSVPromptDataset
from early_exit.util import get_model
from shared_utils.load import get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text

from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode

# import wandb
import pandas as pd
import numpy as np

In [5]:

# LOAD IN EXPERIMENT ARGS
# num_epoch = 1                     # args.num_epoch
num_exit_samples = 1                  # args.num_exit_samples
device = "cuda"                    # args.device
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"                    # args.model_name
model_config_path = "../config_deepseek.yaml"                     # args.model_config_path
dataset_path = "../results_and_data/early_exit_sft_dataset/test/data.csv"                  # args.dataset_path
prompt_config_path = "../results_and_data/early_exit_sft_dataset/test/prompt_config.json"                    # args.prompt_config_path
batch_size = 1                    # args.batch_size -- might want to sort out batching, but increasi

In [6]:
# LOAD IN THE MODEL AND TOKENIZER
tokenizer = get_tokenizer(model_name)
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)



# LOAD IN DATASET
dataset = CSVPromptDataset(dataset_path, prompt_config_path)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True)


In [7]:
model = get_model(model_name, config['model'], device)
# ENABLE EARLY EXITING
model = replace_attention_layers(model, config['lora'], device)

config.json:   0%|          | 0.00/679 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.55G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

replacing layer model.layers.0
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.1
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.2
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.3
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.4
replacing layer model.layers.5
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.6
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.7
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.8
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.9
replacing layer model.layers.10
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.11
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.12
replacing generate_layer_type_without_early_exit_decision_head l

In [23]:
from early_exit.util import module_name_is_transformer_layer

for prompt_batch in dataloader:
    break

# Setup tests

In [9]:
class ComputationTracker:
    def __init__(self):
        self.hooks = []  # Store hook handles for removal
        self.reset()
        
    def reset(self):
        """Reset all counters and remove existing hooks"""
        self.layer_forward_counts = {}
        self.attention_counts = {}
        self.mlp_counts = {}
        self.residual_counts = {}
        
        # Remove all existing hooks
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        
    def register_hooks(self, model):
        """Register forward hooks on model components"""
        self.reset()  # Clear any existing hooks first
        
        for name, module in model.named_modules():
            # Track self-attention computations
            if 'self_attn' in name and 'proj' not in name:
                parts = name.split('.')
                # Find the layer index - look for 'layers' and get the next element
                layer_idx = -1
                for i, part in enumerate(parts):
                    if part == 'layers' and i + 1 < len(parts) and parts[i + 1].isdigit():
                        layer_idx = int(parts[i + 1])
                        break
                
                hook = module.register_forward_hook(
                    lambda m, i, o, idx=layer_idx, n=name: self._log_attention(idx, n, i, o)
                )
                self.hooks.append(hook)
                
            # Track MLP computations
            elif 'mlp' in name:
                parts = name.split('.')
                # Find the layer index - look for 'layers' and get the next element
                layer_idx = -1
                for i, part in enumerate(parts):
                    if part == 'layers' and i + 1 < len(parts) and parts[i + 1].isdigit():
                        layer_idx = int(parts[i + 1])
                        break
                
                hook = module.register_forward_hook(
                    lambda m, i, o, idx=layer_idx: self._log_mlp(idx, i, o)
                )
                self.hooks.append(hook)
                
            # Track layer-level forward passes
            elif module_name_is_transformer_layer(name):
                # Extract layer index from the last element after 'layers'
                layer_idx = int(name.split('.')[-1])
                
                hook = module.register_forward_hook(
                    lambda m, i, o, idx=layer_idx: self._log_layer(idx, i, o)
                )
                self.hooks.append(hook)
    
    def _log_attention(self, layer_idx, name, inputs, outputs):
        """Log attention computation"""
        if layer_idx not in self.attention_counts:
            self.attention_counts[layer_idx] = 0
        self.attention_counts[layer_idx] += 1
        
    def _log_mlp(self, layer_idx, inputs, outputs):
        """Log MLP computation"""
        if layer_idx not in self.mlp_counts:
            self.mlp_counts[layer_idx] = 0
        self.mlp_counts[layer_idx] += 1
        
    def _log_layer(self, layer_idx, inputs, outputs):
        """Log layer forward pass"""
        if layer_idx not in self.layer_forward_counts:
            self.layer_forward_counts[layer_idx] = 0
        self.layer_forward_counts[layer_idx] += 1
        
    def get_summary(self):
        """Get a summary of all computations"""
        return {
            'layer_forwards': dict(sorted(self.layer_forward_counts.items())),
            'attention_ops': dict(sorted(self.attention_counts.items())),
            'mlp_ops': dict(sorted(self.mlp_counts.items())),
            'total_attention_ops': sum(self.attention_counts.values()),
            'total_mlp_ops': sum(self.mlp_counts.values()),
            'total_layer_ops': sum(self.layer_forward_counts.values())
        }
        
def print_computation_summary(tracker):
    summary = tracker.get_summary()
    print(f'n_hooks={len(tracker.hooks)}')
    print(f"Total attention operations: {summary['total_attention_ops']}")
    print(f"Total MLP operations: {summary['total_mlp_ops']}")
    print(f"Attention ops per layer: {summary['attention_ops']}")
    print(f"MLP ops per layer: {summary['mlp_ops']}")

In [10]:
class ImprovedComputationTracker(ComputationTracker):
    def __init__(self, layer_pattern=r'layers\.(\d+)', 
                 attention_patterns=['self_attn', 'attention'],
                 mlp_patterns=['mlp', 'ffn']):
        self.layer_pattern = layer_pattern
        self.attention_patterns = attention_patterns
        self.mlp_patterns = mlp_patterns
        self.hooks = []
        self.reset()
        
    def __del__(self):
        """Ensure hooks are removed on deletion"""
        self.reset()
        
    def extract_layer_idx(self, name):
        """More robust layer index extraction"""
        import re
        match = re.search(self.layer_pattern, name)
        return int(match.group(1)) if match else -1

## Model forwards

In [22]:

def forward_teacher(model, prompt_batch):
    with torch.no_grad():
        # Generate SFT targets
        set_transformer_early_exit_mode(model, 'sft_teacher')
        sft_teacher_response, (sft_teacher_generated_tokens, sft_teacher_final_layer_logprobs, gathered_early_exit_hidden_states) =\
            generate_text(
                model=model, 
                prompt=prompt_batch.full_user_prompt, 
                system_prompt=dataset.system_prompt, 
                prefiller=dataset.prefiller, 
                tokenizer=tokenizer, 
                generation_config=config['generation'], 
                device=device
            )
        print(sft_teacher_response)

        early_output_log_probs = model.early_exit_hidden_state_readout(gathered_early_exit_hidden_states)               # [batch, num exitable layers, gen len, vocabulary]
        early_exit_probs = model.early_exit_target_probs(early_output_log_probs = early_output_log_probs, teacher_final_layer_log_probs = sft_teacher_final_layer_logprobs)
        repeated_sft_teacher_final_layer_logprobs = sft_teacher_final_layer_logprobs.repeat(num_exit_samples, 1, 1)     # XXX repeat_interleave? [batch * samples, full length, vocabulary]


    # Sample early exits
    batch, gen_len, elayers = early_exit_probs.shape                                                                                                # [batch, generation length, exitable layers]
    full_len = sft_teacher_generated_tokens.shape[1]
    repeated_sft_teacher_generated_tokens = sft_teacher_generated_tokens.expand(num_exit_samples * batch, full_len)                                 # [batch * samples, full length]
    sampled_early_exit_layer_idxs_early_with_sample_dim = torch.distributions.Categorical(probs = early_exit_probs).sample((num_exit_samples,))     # [samples, batch, generation length] 
    sampled_early_exit_layer_idxs_early = sampled_early_exit_layer_idxs_early_with_sample_dim.reshape(batch * num_exit_samples, gen_len)            # [batch * samples, generation length]
    sampled_early_exit_layer_idxs = model.exitable_layer_idxs[sampled_early_exit_layer_idxs_early.cpu()]                                            # [batch * samples, generation length]

    return sampled_early_exit_layer_idxs, repeated_sft_teacher_generated_tokens


def forward_student(model, sampled_early_exit_layer_idxs, repeated_sft_teacher_generated_tokens):
    # Generate with prescription
    set_transformer_early_exit_mode(model, 'sft_student')
    sft_student_output_scores, collected_exit_logits = model(repeated_sft_teacher_generated_tokens, prescribed_exit_layer_idxs = sampled_early_exit_layer_idxs) # [batch * samples, full length, vocabulary]
    
    return sft_student_output_scores, collected_exit_logits

In [14]:
# Initialize tracker
tracker = ComputationTracker()
tracker = ImprovedComputationTracker()

# Register hooks
tracker.register_hooks(model)

print("="*60)
print("TEST 1: TEACHER MODE (Full Computation)")
print("="*60)

sampled_early_exit_layer_idxs, repeated_sft_teacher_generated_tokens = forward_teacher(model, prompt_batch)

teacher_summary = tracker.get_summary()

TEST 1: TEACHER MODE (Full Computation)
transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 138])
<｜begin▁of▁sentence｜><｜Assistant｜> 
<｜User｜> I am going to give you a story and a question about the story. Read the following story carefully, understand the characters' actions and perspectives, then answer the question regarding object locations, character knowledge, and beliefs.

Nicholas entered the main bar area. Matthew entered the main bar area. Addison told privately to Matthew about the bar's social media presence. Addison entered the main bar area. Matthew left the main bar area. Matthew told privately to Nicholas about the bar's menu offerings. Addison left the main bar area. Addison entered the main bar area. Avery entered the main bar area.

Does Nicholas know about bar's menu offerings? Answer yes or no.
<｜Assistant｜> 
Okay, so I need to figure out whether Nicholas knows about the b

In [15]:

print("\nTeacher Computation Summary:")
print_computation_summary(tracker)

print("\n" + "="*60)
print("TEST 2: STUDENT MODE (With Early Exits)")
print("="*60)

# Reset tracker for student mode
tracker.reset()
tracker.register_hooks(model)

print("Student Computation Summary:")
forward_student(model, sampled_early_exit_layer_idxs, repeated_sft_teacher_generated_tokens)

student_summary = tracker.get_summary()
print_computation_summary(tracker)
# Calculate savings
print("\n" + "="*60)
print("COMPUTATION SAVINGS")
print("="*60)
if teacher_summary['total_attention_ops'] > 0:
    attention_savings = 1 - (student_summary['total_attention_ops'] / teacher_summary['total_attention_ops'])
    print(f"Attention computation saved: {attention_savings*100:.1f}%")
    
if teacher_summary['total_mlp_ops'] > 0:
    mlp_savings = 1 - (student_summary['total_mlp_ops'] / teacher_summary['total_mlp_ops'])
    print(f"MLP computation saved: {mlp_savings*100:.1f}%")

# # Show exit patterns
# print(f"\nExit layers used: {exit_layer_idxs[0].tolist()}")
# if not exit_layer_idxs[0].isinf().all():
#     print(f"Average exit layer: {exit_layer_idxs[0][~exit_layer_idxs[0].isinf()].float().mean():.1f}")

# Clean up - remove all hooks
tracker.reset()


Teacher Computation Summary:
n_hooks=196
Total attention operations: 8316
Total MLP operations: 41580
Attention ops per layer: {0: 297, 1: 297, 2: 297, 3: 297, 4: 297, 5: 297, 6: 297, 7: 297, 8: 297, 9: 297, 10: 297, 11: 297, 12: 297, 13: 297, 14: 297, 15: 297, 16: 297, 17: 297, 18: 297, 19: 297, 20: 297, 21: 297, 22: 297, 23: 297, 24: 297, 25: 297, 26: 297, 27: 297}
MLP ops per layer: {0: 1485, 1: 1485, 2: 1485, 3: 1485, 4: 1485, 5: 1485, 6: 1485, 7: 1485, 8: 1485, 9: 1485, 10: 1485, 11: 1485, 12: 1485, 13: 1485, 14: 1485, 15: 1485, 16: 1485, 17: 1485, 18: 1485, 19: 1485, 20: 1485, 21: 1485, 22: 1485, 23: 1485, 24: 1485, 25: 1485, 26: 1485, 27: 1485}

TEST 2: STUDENT MODE (With Early Exits)
Student Computation Summary:
n_hooks=196
Total attention operations: 28
Total MLP operations: 140
Attention ops per layer: {0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1, 9: 1, 10: 1, 11: 1, 12: 1, 13: 1, 14: 1, 15: 1, 16: 1, 17: 1, 18: 1, 19: 1, 20: 1, 21: 1, 22: 1, 23: 1, 24: 1, 25: 1, 26

**Test: force to exit at last layer**

**Run teacher**

In [18]:
prompt = "Explain the concept of recursion in programming."
system_prompt = "You are a helpful programming tutor."
prefiller = ""

set_transformer_early_exit_mode(model, 'sft_teacher')
# Reset tracker for student mode
tracker.reset()
tracker.register_hooks(model)

with torch.no_grad():
    sft_teacher_response, (sft_teacher_generated_tokens, 
                          sft_teacher_final_layer_logprobs, 
                          gathered_early_exit_hidden_states) = generate_text(
        model=model,
        prompt=prompt,
        system_prompt=system_prompt,
        prefiller=prefiller,
        tokenizer=tokenizer,
        generation_config=config['generation'],
        device=device
    )
    
    early_output_log_probs = model.early_exit_hidden_state_readout(gathered_early_exit_hidden_states)
    
    early_exit_probs = model.early_exit_target_probs(
       early_output_log_probs=early_output_log_probs,
       teacher_final_layer_log_probs=sft_teacher_final_layer_logprobs
    )
    
    
print_computation_summary(tracker)

transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 20])
n_hooks=196
Total attention operations: 11200
Total MLP operations: 56000
Attention ops per layer: {0: 400, 1: 400, 2: 400, 3: 400, 4: 400, 5: 400, 6: 400, 7: 400, 8: 400, 9: 400, 10: 400, 11: 400, 12: 400, 13: 400, 14: 400, 15: 400, 16: 400, 17: 400, 18: 400, 19: 400, 20: 400, 21: 400, 22: 400, 23: 400, 24: 400, 25: 400, 26: 400, 27: 400}
MLP ops per layer: {0: 2000, 1: 2000, 2: 2000, 3: 2000, 4: 2000, 5: 2000, 6: 2000, 7: 2000, 8: 2000, 9: 2000, 10: 2000, 11: 2000, 12: 2000, 13: 2000, 14: 2000, 15: 2000, 16: 2000, 17: 2000, 18: 2000, 19: 2000, 20: 2000, 21: 2000, 22: 2000, 23: 2000, 24: 2000, 25: 2000, 26: 2000, 27: 2000}


**Run student**

In [19]:
# Reset tracker for student mode
tracker.reset()
tracker.register_hooks(model)


with torch.no_grad():
    batch, gen_len, elayers = early_exit_probs.shape 
    full_len = sft_teacher_generated_tokens.shape[1]
    repeated_sft_teacher_generated_tokens = sft_teacher_generated_tokens.expand(num_exit_samples * batch, full_len)   
    sampled_early_exit_layer_idxs_early_with_sample_dim = torch.distributions.Categorical(probs = early_exit_probs).sample((num_exit_samples,))     # [samples, batch, generation length] 
    sampled_early_exit_layer_idxs_early = sampled_early_exit_layer_idxs_early_with_sample_dim.reshape(batch * num_exit_samples, gen_len)            # [batch * samples, generation length]
    sampled_early_exit_layer_idxs = model.exitable_layer_idxs[sampled_early_exit_layer_idxs_early.cpu()]                       
    
    
    
    set_transformer_early_exit_mode(model, 'sft_student')
    
    # Create prescribed exit layer idxs filled with torch.inf (always exit on last layer)
    batch_samples, seq_len = repeated_sft_teacher_generated_tokens.shape
    print("Setting exit layers to inf for sft_student")
    sampled_early_exit_layer_idxs = torch.full((batch_samples, gen_len), torch.inf, \
                                            device=repeated_sft_teacher_generated_tokens.device)
    sampled_early_exit_layer_idxs = torch.zeros_like(sampled_early_exit_layer_idxs) + 25
    print(f"Minimum in prescribed_exit_layer_idxs = {torch.min(sampled_early_exit_layer_idxs)}")
    sft_student_output_scores, collected_exit_logits = model(repeated_sft_teacher_generated_tokens,\
                                                             prescribed_exit_layer_idxs=sampled_early_exit_layer_idxs)
print_computation_summary(tracker)

Setting exit layers to inf for sft_student
Minimum in prescribed_exit_layer_idxs = 25.0
n_hooks=196
Total attention operations: 28
Total MLP operations: 140
Attention ops per layer: {0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1, 9: 1, 10: 1, 11: 1, 12: 1, 13: 1, 14: 1, 15: 1, 16: 1, 17: 1, 18: 1, 19: 1, 20: 1, 21: 1, 22: 1, 23: 1, 24: 1, 25: 1, 26: 1, 27: 1}
MLP ops per layer: {0: 5, 1: 5, 2: 5, 3: 5, 4: 5, 5: 5, 6: 5, 7: 5, 8: 5, 9: 5, 10: 5, 11: 5, 12: 5, 13: 5, 14: 5, 15: 5, 16: 5, 17: 5, 18: 5, 19: 5, 20: 5, 21: 5, 22: 5, 23: 5, 24: 5, 25: 5, 26: 5, 27: 5}


# (WIP) Test: modify early_exit_probs

In [24]:
# model.base_model.model.model.layers[0].self_attn

for name, module in model.named_modules():
    bias_val = -100
    if 'early_exit_decision_weights' in name:
        print('biasing', name, 'with ', bias_val)
        bias_tensor = bias_val + torch.zeros(module.bias.shape)
        bias_tensor = bias_tensor.to(module.weight.device)
        module.bias = torch.nn.Parameter(bias_tensor)

        


biasing base_model.model.model.layers.0.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.5.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.10.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.15.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.20.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.25.early_exit_decision_weights with  -100


In [25]:
# model.base_model.model.model.layers[0].self_attn

for name, module in model.named_modules():
    bias_val = -100
    if 'early_exit_decision_weights' in name:
        print('biasing', name, 'with ', bias_val)
        bias_tensor = bias_val + torch.zeros(module.bias.shape)
        bias_tensor = bias_tensor.to(module.weight.device)
        module.bias = torch.nn.Parameter(bias_tensor)

        

biasing base_model.model.model.layers.0.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.5.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.10.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.15.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.20.early_exit_decision_weights with  -100
biasing base_model.model.model.layers.25.early_exit_decision_weights with  -100


In [26]:
print("\n" + "="*60)
print("TEST 2: STUDENT MODE (With Early Exits)")
print("="*60)

# Reset tracker for student mode
tracker.reset()
tracker.register_hooks(model)

forward_student(model, sampled_early_exit_layer_idxs, repeated_sft_teacher_generated_tokens)

student_summary = tracker.get_summary()
print("\nStudent Mode Summary:")
print(f"Total attention operations: {student_summary['total_attention_ops']}")
print(f"Total MLP operations: {student_summary['total_mlp_ops']}")
print(f"Attention ops per layer: {student_summary['attention_ops']}")
print(f"MLP ops per layer: {student_summary['mlp_ops']}")


TEST 2: STUDENT MODE (With Early Exits)

Student Mode Summary:
Total attention operations: 28
Total MLP operations: 140
Attention ops per layer: {0: 1, 1: 1, 2: 1, 3: 1, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1, 9: 1, 10: 1, 11: 1, 12: 1, 13: 1, 14: 1, 15: 1, 16: 1, 17: 1, 18: 1, 19: 1, 20: 1, 21: 1, 22: 1, 23: 1, 24: 1, 25: 1, 26: 1, 27: 1}
MLP ops per layer: {0: 5, 1: 5, 2: 5, 3: 5, 4: 5, 5: 5, 6: 5, 7: 5, 8: 5, 9: 5, 10: 5, 11: 5, 12: 5, 13: 5, 14: 5, 15: 5, 16: 5, 17: 5, 18: 5, 19: 5, 20: 5, 21: 5, 22: 5, 23: 5, 24: 5, 25: 5, 26: 5, 27: 5}


In [27]:
prompt = "Tell me a Zen joke about farmer"
system_prompt = "You are a helpful programming tutor."
prefiller = ""

set_transformer_early_exit_mode(model, 'free_generate')
externalised_response, (externalised_generated_tokens, gathered_early_exit_layer_idxs) =\
    generate_text(model, prompt, system_prompt, prefiller, tokenizer, config['generation'], device)
print(externalised_response)

transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 20])
Free generate: Patched forward generation called at  model.layers.0
Free generate: Patched forward generation called at  model.layers.5
Free generate: Patched forward generation called at  model.layers.10
Free generate: Patched forward generation called at  model.layers.15
Free generate: Patched forward generation called at  model.layers.20
Free generate: Patched forward generation called at  model.layers.25


ValueError: not enough values to unpack (expected 3, got 2)