In [1]:
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader

import sys
sys.path.append("../")

from shared_utils.data import CSVPromptDataset
from shared_utils.load import get_model, get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text

from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode

import wandb
import pandas as pd
import numpy as np

# EXPERIMENT SETTINGS
# Hard-coded stand-ins for what would be command-line arguments in a script
# (the trailing comment on each line names the corresponding `args.*` field).
# num_epoch = 1                     # args.num_epoch (currently disabled)
num_exit_samples = 1                  # args.num_exit_samples
device = "cuda"                    # args.device -- hard-coded, so this notebook needs a CUDA device as written
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"                    # args.model_name
model_config_path = "../config_deepseek.yaml"                     # args.model_config_path (relative to this notebook's directory)
dataset_path = "../results_and_data/early_exit_sft_dataset/test/data.csv"                  # args.dataset_path
prompt_config_path = "../results_and_data/early_exit_sft_dataset/test/prompt_config.json"                    # args.prompt_config_path
batch_size = 1                    # args.batch_size -- might want to sort out batching, but increasing num_exit_samples might be better + less effort

# LOAD THE MODEL AND TOKENIZER
tokenizer = get_tokenizer(model_name)
# The tokenizer's EOS token id is handed to the config loader together with the YAML path.
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)
# Override whatever the YAML says: generation runs with use_cache=False.
# NOTE(review): presumably the patched early-exit layers do not support KV caching -- TODO confirm.
config['generation']['use_cache'] = False
# Two separately loaded copies of the same checkpoint: `model` is patched for
# early exit below, `base_model` is left untouched for the comparison cells.
model = get_model(model_name, config['model'], device)
base_model = get_model(model_name, config['model'], device)

# ENABLE EARLY EXITING
# Rebinds `model` to the patched version returned by the helper, configured by config['lora'].
model = replace_attention_layers(model, config['lora'], device)

# Select the 'free_generate' mode on the patched model before any generation is run.
set_transformer_early_exit_mode(model, 'free_generate')

replacing layer model.layers.0
replacing layer model.layers.1
replacing layer model.layers.2
replacing layer model.layers.3
replacing layer model.layers.4
replacing layer model.layers.5
replacing layer model.layers.6
replacing layer model.layers.7
replacing layer model.layers.8
replacing layer model.layers.9
replacing layer model.layers.10
replacing layer model.layers.11
replacing layer model.layers.12
replacing layer model.layers.13
replacing layer model.layers.14
replacing layer model.layers.15
replacing layer model.layers.16
replacing layer model.layers.17
replacing layer model.layers.18
replacing layer model.layers.19
replacing layer model.layers.20
replacing layer model.layers.21
replacing layer model.layers.22
replacing layer model.layers.23
replacing layer model.layers.24
replacing layer model.layers.25
replacing layer model.layers.26
replacing layer model.layers.27
address this hack!
g++ (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Copyright (C) 2021 Free Software Foundation, Inc.
Thi

In [2]:
# Single-prompt smoke test of the patched model: one system message, one user
# question, and no prefilled assistant text.
system_prompt = "You are a helpful programming tutor."
prompt = "Explain the concept of recursion in programming."
prefiller = ""

# Inference only, so autograd tracking is switched off for the call.
# generate_text returns a pair; only the first element (the response) is kept.
with torch.no_grad():
    free_generate_response, _ = generate_text(
        model=model,
        tokenizer=tokenizer,
        device=device,
        generation_config=config['generation'],
        system_prompt=system_prompt,
        prompt=prompt,
        prefiller=prefiller,
    )

free_generate_response

transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 20])
Layer 0 called in free_generate mode with L = 20
Entering the computation. Hidden states sum =  tensor(13.8386, device='cuda:0')
Before LayerNorm: Hidden states sum =  tensor(13.8386, device='cuda:0')
After LayerNorm: Hidden states sum =  tensor(142.1584, device='cuda:0')
Hidden states shape =  torch.Size([1, 20, 1536])
--------------------------------------------------
After attention
Hidden states sum =  tensor(339.7139, device='cuda:0')
Residual states sum =  tensor(142.1584, device='cuda:0')
--------------------------------------------------
After MLP
Original hidden states sum =  tensor(13.8386, device='cuda:0')
Hidden states sum =  tensor(517.9693, device='cuda:0')
--------------------------------------------------
Output attentions = False
Exiting patched layer



Layer 1 called in free_generate mode with L = 20
Entering the compu

'<｜begin▁of▁sentence｜><｜Assistant｜> You are a helpful programming tutor.\n<｜User｜> Explain the concept of recursion in programming.\n<｜Assistant｜> \n-elementsabcdefgh /^-elements /^ )"chef...\'-xs #[chef messed fen...\'pane/blue [[" )" "\\" "()chefyyyy(CH...",chefchef mex "!...",bidden mex crawler-corner Php \'((/blue [[" fenabcdefgh messed'

In [3]:
# Re-display the response produced by the generation cell above (same value, shown again).
free_generate_response

'<｜begin▁of▁sentence｜><｜Assistant｜> You are a helpful programming tutor.\n<｜User｜> Explain the concept of recursion in programming.\n<｜Assistant｜> \n-elementsabcdefgh /^-elements /^ )"chef...\'-xs #[chef messed fen...\'pane/blue [[" )" "\\" "()chefyyyy(CH...",chefchef mex "!...",bidden mex crawler-corner Php \'((/blue [[" fenabcdefgh messed'

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def print_forward_sequence(model, tokenizer=None, text="Hello, how are you?", device=None):
    """Run one forward pass and print every submodule call in execution order.

    A forward hook is attached to every named submodule of ``model`` (the root
    module is skipped). Each hook records the submodule's name together with
    the shapes of its positional inputs and of its output; objects without a
    ``.shape`` attribute are reported by their type instead. The hooks are
    always removed again, even when the forward pass raises, so a failed trace
    does not leave stale hooks attached to the model.

    Args:
        model: The ``torch.nn.Module`` to trace. It is called as
            ``model(**inputs)`` under ``torch.no_grad()``.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``;
            its result must support ``.to(device)`` and ``**`` unpacking.
            Defaults to the module-level ``tokenizer`` variable, which keeps
            the original one-argument call working.
        text: The probe sentence fed through the model.
        device: Where to move the tokenized inputs. Defaults to the device of
            the model's first parameter; if the model has no parameters the
            inputs are left where the tokenizer created them.

    Returns:
        None. The trace is printed to stdout.

    Raises:
        NameError: If no tokenizer is passed and no global ``tokenizer`` exists.
    """
    if tokenizer is None:
        # Backward compatibility: the original relied on the notebook-global tokenizer.
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError(
                "print_forward_sequence needs a tokenizer: pass one explicitly "
                "or define a global `tokenizer`"
            )

    if device is None:
        # Follow the model instead of assuming a particular accelerator.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None

    def shape_or_type(value):
        # Tensors report their shape; anything else (None, caches, ...) its type.
        return value.shape if hasattr(value, 'shape') else type(value)

    def describe(value):
        # Tuples are described element-wise, single values directly.
        if isinstance(value, tuple):
            return [shape_or_type(item) for item in value]
        return shape_or_type(value)

    # Filled by the hooks in the order the submodules actually finish running.
    execution_order = []

    def create_hook(name):
        # Factory so that each hook closes over its own module name.
        def hook_fn(module, args, output):
            execution_order.append(f"{name}: {describe(args)} -> {describe(output)}")
        return hook_fn

    inputs = tokenizer(text, return_tensors="pt")
    if device is not None:
        inputs = inputs.to(device)

    hooks = []
    try:
        for name, module in model.named_modules():
            if name:  # Skip the root module (its name is the empty string)
                hooks.append(module.register_forward_hook(create_hook(name)))

        print("=== FORWARD PASS SEQUENCE ===")
        with torch.no_grad():
            model(**inputs)
    finally:
        # Always detach the hooks so the model is left exactly as it was found.
        for hook in hooks:
            hook.remove()

    for i, op in enumerate(execution_order):
        print(f"{i+1:2d}. {op}")

# Trace the early-exit `model` built in the setup cell: prints each hooked
# submodule together with its input and output shapes, in call order.
print_forward_sequence(model)

In [16]:
# Display the module tree of the patched model to inspect how the layers were wrapped.
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): DynamicallyTypedModelWithReadout(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 1536)
        (layers): ModuleList(
          (0): DynamicallyTypedLayerWithExit(
            (self_attn): Qwen2Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=1536, out_features=1536, bias=True)
                (lora_dropout): ModuleDict(
                  (early_exiter): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (early_exiter): Linear(in_features=1536, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (early_exiter): Linear(in_features=8, out_features=1536, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
   

## Checking with the base model

In [3]:
# Hand-written, already chat-formatted prompt: the same text the free-generation
# cell echoed back at the start of its response, up to the final assistant tag.
prompt = (
    "<｜begin▁of▁sentence｜><｜Assistant｜> You are a helpful programming tutor.\n"
    "<｜User｜> Explain the concept of recursion in programming.\n"
    "<｜Assistant｜>"
)
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Plain forward pass through the unpatched model, asking it to also return
# its intermediate hidden states; no gradients are needed.
with torch.no_grad():
    outputs = base_model(**inputs, output_hidden_states=True)

hidden_states = outputs.hidden_states

In [4]:
# Sanity-check the shape of one entry of the base model's hidden-state tuple.
# NOTE(review): index 11 looks arbitrary; in Hugging Face outputs index 0 is
# usually the embedding output, so index k would follow decoder layer k-1 -- confirm.
hidden_states[11].shape

torch.Size([1, 20, 1536])

In [5]:
# Print the sum over all elements of each of the first ten hidden-state
# entries, as reference values for the "Hidden states sum" debug lines
# emitted by the patched layers.
for layer_idx in range(10):
    layer_sum = hidden_states[layer_idx].sum().item()
    print(f"Layer = {layer_idx}| Sum = {layer_sum}")

Layer = 0| Sum = 12.375
Layer = 1| Sum = 756.0
Layer = 2| Sum = 1688.0
Layer = 3| Sum = -2448.0
Layer = 4| Sum = -2432.0
Layer = 5| Sum = -2208.0
Layer = 6| Sum = -2656.0
Layer = 7| Sum = -2528.0
Layer = 8| Sum = -2608.0
Layer = 9| Sum = -2576.0
