In [1]:
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader

import sys
sys.path.append("../")

from shared_utils.data import CSVPromptDataset
from shared_utils.load import get_model, get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text

from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode

import wandb
import pandas as pd
import numpy as np

In [2]:
# EXPERIMENT ARGUMENTS
# Each value stands in for the command-line argument named in its trailing comment.
# num_epoch = 1  # args.num_epoch

# -- model --
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # args.model_name
model_config_path = "../config_deepseek.yaml"  # args.model_config_path
device = "cuda"  # args.device

# -- data --
dataset_path = "../results_and_data/early_exit_sft_dataset/test/data.csv"  # args.dataset_path
prompt_config_path = "../results_and_data/early_exit_sft_dataset/test/prompt_config.json"  # args.prompt_config_path

# -- sampling --
# args.batch_size -- might want to sort out batching, but increasing
# num_exit_samples might be better + less effort
batch_size = 1
num_exit_samples = 1  # args.num_exit_samples


In [3]:
# LOAD IN THE MODEL AND TOKENIZER
# `config` is indexed with the keys 'model', 'lora' and 'generation' by the
# cells in this notebook; the tokenizer's EOS id is passed into the config loader.
tokenizer = get_tokenizer(model_name)
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)
model = get_model(model_name, config['model'], device)


# LOAD IN DATASET
# NOTE(review): `dataloader` is not consumed by any cell visible in this
# notebook, and shuffle=True is used without an explicit seed -- confirm
# whether it is still needed and, if so, seed it for reproducibility.
dataset = CSVPromptDataset(dataset_path, prompt_config_path)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True)


# ENABLE EARLY EXITING
# Rebinds `model` to the patched model. Per the recorded output below, layers
# 0, 5, 10, 15, 20 and 25 are replaced and a trainable-parameter summary is printed.
model = replace_attention_layers(model, config['lora'], device)

replacing layer model.layers.0
replacing layer model.layers.5
replacing layer model.layers.10
replacing layer model.layers.15
replacing layer model.layers.20
replacing layer model.layers.25
address this hack!
trainable params: 2,179,072 || all params: 1,779,276,294 || trainable%: 0.1225


## Testing SFT teacher

In [20]:
class ActivationCollector:
    """
    Registers a forward hook on one PyTorch module and records that module's
    output on every forward pass.

    Attributes:
        activations (list[torch.Tensor]): One entry per forward call of the
            hooked module, in call order. Each entry is detached from the
            autograd graph and moved to the CPU to save GPU memory. The list
            keeps growing until `clear()` is called.
        hook_handle (torch.utils.hooks.RemovableHandle | None): Handle of the
            currently registered hook, or None when no hook is registered.
    """
    def __init__(self):
        """Initializes the collector with no stored activations and no hook."""
        self.activations = []
        self.hook_handle = None

    @staticmethod
    def _primary_output(output):
        """
        Returns the tensor to record from a module's forward output.

        A module that returns a bare tensor (e.g. a norm layer) is recorded
        whole. Any other output (tuple, list, or other indexable container)
        contributes its first element. Indexing a bare tensor with [0] would
        silently keep only the first batch element.
        """
        if isinstance(output, torch.Tensor):
            return output
        return output[0]

    def _hook_fn(self, module, input_tensors, output_tensor):
        """
        The actual hook function that PyTorch will call. It appends the
        layer's output to `self.activations`.
        """
        # Detach and move to the CPU so we neither hold onto the computation
        # graph nor keep a GPU copy alive.
        self.activations.append(self._primary_output(output_tensor).detach().cpu())

    def clear(self):
        """Drops all activations collected so far (the hook stays registered)."""
        self.activations.clear()

    def register(self, model, layer_path: str):
        """
        Registers the forward hook to a specific layer within the model.

        Any hook previously registered through this instance is removed first.
        Previously collected activations are kept; call `clear()` to drop them.

        Args:
            model (nn.Module): The model containing the target layer.
            layer_path (str): A dot-separated string path to the target layer
                              (e.g., 'base_model.model.model.norm').

        Returns:
            bool: True if registration was successful, False otherwise.
        """
        # First, remove any existing hook managed by this instance.
        self.remove()

        # Navigate through the model hierarchy to find the target layer.
        # Numeric path parts (e.g. 'layers.25') resolve through getattr as well.
        try:
            target_layer = model
            for part in layer_path.split('.'):
                target_layer = getattr(target_layer, part)
        except AttributeError:
            print(f"Error: Could not find the layer at path: {layer_path}")
            print("Please ensure the path matches the model architecture.")
            return False

        # Register the forward hook on the found layer.
        self.hook_handle = target_layer.register_forward_hook(self._hook_fn)
        print(f"Successfully registered hook on: '{type(target_layer).__name__}' at path '{layer_path}'")
        return True

    def remove(self):
        """
        Removes the registered hook if it exists. It's important to call this
        when you're done so the hook stops firing (and stops accumulating
        activations) on later forward passes.
        """
        # Explicit None check: do not rely on the truthiness of the handle object.
        if self.hook_handle is not None:
            self.hook_handle.remove()
            self.hook_handle = None
            print("Hook has been removed.")



# Hook the module at this dotted path so its forward outputs are recorded in
# `collector.activations`. Per the recorded output below, this path resolves to
# a 'DynamicallyTypedLayerWithExit' module.
collector = ActivationCollector()
layer_path_to_hook = 'base_model.model.model.layers.25'

# Alternative target, currently disabled.
#layer_path_to_hook = 'base_model.model.model.norm'
# Last expression of the cell: register() returns True on success, False otherwise.
collector.register(model, layer_path_to_hook)



Successfully registered hook on: 'DynamicallyTypedLayerWithExit' at path 'base_model.model.model.layers.25'


True

In [19]:
# NOTE(review): this cell sits between hook registration and the teacher
# generation cell. Run top-to-bottom, it removes the hook before generation,
# so nothing is collected. The recorded execution counts show it was run
# *before* the registration cell. Run it after generation if activations are wanted.
collector.remove() 

Hook has been removed.


In [12]:
#model.base_model.model.model.layers

In [21]:
# Single hand-written prompt used to smoke-test teacher-mode generation.
prompt = "Explain the concept of recursion in programming."
system_prompt = "You are a helpful programming tutor."
prefiller = ""

set_transformer_early_exit_mode(model, 'sft_teacher')

# In this mode the second return value of generate_text is unpacked as a
# 3-tuple: generated tokens, final-layer log-probs, and the gathered
# early-exit hidden states.
with torch.no_grad():
    sft_teacher_response, (sft_teacher_generated_tokens, 
                          sft_teacher_final_layer_logprobs, 
                          gathered_early_exit_hidden_states) = generate_text(
        model=model,
        prompt=prompt,
        system_prompt=system_prompt,
        prefiller=prefiller,
        tokenizer=tokenizer,
        generation_config=config['generation'],
        device=device
    )
    
   # Disabled: readout of the early-exit hidden states and the derived target
   # exit probabilities. While these lines stay commented out, this cell does
   # not define `early_output_log_probs` or `early_exit_probs`.
   # early_output_log_probs = model.early_exit_hidden_state_readout(gathered_early_exit_hidden_states)
    
   # early_exit_probs = model.early_exit_target_probs(
   #     early_output_log_probs=early_output_log_probs,
   #     teacher_final_layer_log_probs=sft_teacher_final_layer_logprobs
   # )

transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 20])


## Testing SFT student

In [6]:
# Run the student forward pass over the teacher's token sequence, with every
# position told to exit at the last layer (exit index = inf).
with torch.no_grad():
    # Take batch / generation length from the teacher log-probs. The previous
    # source, `early_exit_probs`, is only produced by code that is commented out
    # in the teacher cell, so reading it raises NameError on a fresh kernel.
    # Assumes the layout [batch, gen len, vocabulary] -- TODO confirm against
    # generate_text.
    assert sft_teacher_final_layer_logprobs.dim() == 3, (
        "expected teacher log-probs shaped [batch, gen len, vocabulary], got "
        f"{tuple(sft_teacher_final_layer_logprobs.shape)}"
    )
    batch, gen_len = sft_teacher_final_layer_logprobs.shape[:2]
    full_len = sft_teacher_generated_tokens.shape[1]

    # NOTE: expand() can only broadcast a size-1 dimension, so this works when
    # batch == 1 or num_exit_samples == 1 (both hold for the current config).
    repeated_sft_teacher_generated_tokens = sft_teacher_generated_tokens.expand(num_exit_samples * batch, full_len)
    set_transformer_early_exit_mode(model, 'sft_student')

    # Create prescribed exit layer idxs filled with torch.inf (always exit on last layer)
    batch_samples, seq_len = repeated_sft_teacher_generated_tokens.shape
    print("Setting exit layers to inf for sft_student")
    prescribed_exit_layer_idxs = torch.full(
        (batch_samples, gen_len),
        torch.inf,
        device=repeated_sft_teacher_generated_tokens.device,
    )
    print(f"Minimum in prescribed_exit_layer_idxs = {torch.min(prescribed_exit_layer_idxs)}")
    sft_student_output_scores, collected_exit_logits = model(
        repeated_sft_teacher_generated_tokens,
        prescribed_exit_layer_idxs=prescribed_exit_layer_idxs,
    )

Setting exit layers to inf for sft_student
Minimum in prescribed_exit_layer_idxs = inf


In [7]:
# Per-token KL(student || teacher) over the vocabulary, then averaged; used as a
# rough check that student and teacher next-token distributions agree.
with torch.no_grad():
    print('CRUDE KL AND MAKE SURE PROBS ARE ALIGNED')
    # Small constant added to numerator and denominator so the ratio and its
    # log stay finite when a probability is exactly zero.
    eps = 1e-16
    # softmax is applied to the teacher values regardless of whether they are
    # logits or already-normalised log-probs (either way the result sums to 1).
    sft_teacher_probs = sft_teacher_final_layer_logprobs.softmax(-1)                        # [batch * samples, gen len, vocabulary]
    # NOTE(review): this keeps the LAST gen_len positions of the student logits.
    # If the teacher's i-th distribution is the one the i-th generated token was
    # sampled from, the matching student positions would be shifted one to the
    # left ([:, -gen_len-1:-1]) -- confirm against generate_text.
    sft_student_probs = sft_student_output_scores.logits[:,-gen_len:].softmax(-1)           # [batch * samples, gen len, vocabulary]
    token_logits_kl_div = (sft_student_probs * ((sft_student_probs + eps) / (sft_teacher_probs + eps)).log()).sum(-1)   # [batch * samples, gen len]
    
    mean_logit_kl = token_logits_kl_div.mean()

# Last expression of the cell: displays the mean KL.
mean_logit_kl

CRUDE KL AND MAKE SURE PROBS ARE ALIGNED


tensor(30.9266, device='cuda:0')

In [8]:
import pandas as pd
from IPython.display import display, Markdown, HTML

def topk_to_df(prob_dist, tokenizer=None, k=5, title="Top-K Predictions"):
    """
    Return the top-k entries of a 1-D probability distribution as a DataFrame.

    Args:
        prob_dist (torch.Tensor): 1-D tensor of scores/probabilities indexed by
            token id.
        tokenizer: Optional object with a `decode(list[int]) -> str` method. When
            omitted, the token id itself is used as the token string.
        k (int): Number of entries to return. Clamped to the size of
            `prob_dist`, so asking for more entries than exist returns all of
            them instead of raising.
        title (str): Passed through unchanged so callers can label the table.

    Returns:
        tuple[str, pandas.DataFrame]: `title` and a frame with the columns
        "Token ID", "Token String" (the `repr` of the decoded token, so escape
        characters are visible) and "Probability", sorted by descending
        probability and rounded to 4 decimals.
    """
    # torch.topk raises if k exceeds the size of the dimension it searches.
    k = min(k, prob_dist.shape[-1])
    top_values, top_indices = torch.topk(prob_dist, k=k)

    rows = [
        {
            "Token ID": token_id,
            # repr() shows escape characters (e.g. '\n') properly.
            "Token String": repr(tokenizer.decode([token_id]) if tokenizer else str(token_id)),
            "Probability": prob_val,
        }
        for token_id, prob_val in zip(top_indices.tolist(), top_values.tolist())
    ]

    # Explicit columns keep the schema stable even when there are no rows.
    df = pd.DataFrame(rows, columns=["Token ID", "Token String", "Probability"])
    return title, df.round(4)

# One top-5 table per position 5..10 of the first sequence's student distribution.
dfs = []
for idx in range(5, 11):
    dfs.append(
        topk_to_df(sft_student_probs[0, idx], tokenizer, k=5, title=f"Student NTP for token {idx}")
    )

# Render the tables side by side in a wrapping flex container.
panels = []
for title, df in dfs:
    panels.append(
        "<div style='flex: 1; min-width: 300px; padding: 10px;'>"
        f"<h4>{title}</h4>"
        + df.to_html(index=False)
        + "</div>"
    )
html = "<div style='display: flex; flex-wrap: wrap;'>" + "".join(panels) + "</div>"

display(HTML(html))


Token ID,Token String,Probability
11,"','",0.0861
432,' it',0.062
13,'.',0.0554
279,' the',0.0331
438,' as',0.0303

Token ID,Token String,Probability
279,' the',0.2598
13,'.',0.0589
432,' it',0.0363
1172,' only',0.0192
304,' in',0.0184

Token ID,Token String,Probability
279,' the',0.5732
432,' it',0.0502
1181,' its',0.0422
304,' in',0.0328
13,'.',0.0181

Token ID,Token String,Probability
279,' the',0.1273
304,' in',0.0399
13,'.',0.0385
11,"','",0.0289
432,' it',0.0196

Token ID,Token String,Probability
279,' the',0.1131
304,' in',0.0886
714,' but',0.0448
438,' as',0.0337
374,' is',0.0285

Token ID,Token String,Probability
11,"','",0.3771
13,'.',0.0946
8,')',0.0153
78,'o',0.0135
264,' a',0.0094


### Very similar (and gibberish) next-token predictions at every position. Something is wrong!

## Testing free generation

In [9]:
# LOAD IN THE MODEL AND TOKENIZER
# Rebuilds tokenizer, config and model from scratch for the free-generation
# test; the only configuration difference made here is use_cache = False.
# NOTE(review): `model` is rebound only after the new model has been loaded --
# confirm there is enough device memory to hold the old and new models at once.
tokenizer = get_tokenizer(model_name)
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)
config['generation']['use_cache'] = False
model = get_model(model_name, config['model'], device)


# ENABLE EARLY EXITING
model = replace_attention_layers(model, config['lora'], device)

replacing layer model.layers.0
replacing layer model.layers.5
replacing layer model.layers.10
replacing layer model.layers.15
replacing layer model.layers.20
replacing layer model.layers.25
address this hack!
trainable params: 2,179,072 || all params: 1,779,276,294 || trainable%: 0.1225


In [10]:
# Same prompt, system prompt and (empty) prefiller as the SFT teacher test.
prompt = "Explain the concept of recursion in programming."
system_prompt = "You are a helpful programming tutor."
prefiller = ""

set_transformer_early_exit_mode(model, 'free_generate')

# Only the first return value (the response) is kept; the second is discarded.
with torch.no_grad():
    free_generate_response, _ = generate_text(
        model=model,
        prompt=prompt,
        system_prompt=system_prompt,
        prefiller=prefiller,
        tokenizer=tokenizer,
        generation_config=config['generation'],
        device=device
    )

transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 20])


In [11]:
# Last expression of the cell: displays the free-generation response string.
free_generate_response

"<｜begin▁of▁sentence｜><｜Assistant｜> You are a helpful programming tutor.\n<｜User｜> Explain the concept of recursion in programming.\n<｜Assistant｜> \nOkay, so I need to explain recursion in programming. Hmm, I remember recursion from my computer science class. It's when a function calls itself repeatedly until a base case is met. That makes sense, but I want to make sure I understand it thoroughly.\n\nLet me think about how it works. When you call a function recursively, it's like solving a problem by breaking it down into smaller sub-problems. Each time the function calls itself, it's handling a smaller part of the problem. The base case is crucial because it's the stopping point. Without a base case, the function would keep calling itself indefinitely, causing a stack overflow error.\n\nWait, how do I identify the base case? It's the simplest version of the problem that can be solved without further recursion. For example, if I have a function that calculates the factorial of a number

## Current status: SFT teacher seems to work, free generation perhaps works, and student does not work.