In [2]:
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader

import sys
sys.path.append("../")

from shared_utils.data import CSVPromptDataset
from shared_utils.load import get_model, get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text

from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode

import wandb
import pandas as pd
import numpy as np

In [3]:
# LOAD IN EXPERIMENT ARGS
num_epoch = 1                     # args.num_epoch
num_exit_samples = 1                  # args.num_exit_samples
device = "cuda"                    # args.device
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"                    # args.model_name
model_config_path = "../config_deepseek.yaml"                     # args.model_config_path
dataset_path = "../results_and_data/early_exit_sft_dataset/test/data.csv"                  # args.dataset_path
prompt_config_path = "../results_and_data/early_exit_sft_dataset/test/prompt_config.json"                    # args.prompt_config_path
batch_size = 1                    # args.batch_size -- might want to sort out batching, but increasing num_exit_samples might be better + less effort


In [4]:
# LOAD IN THE MODEL AND TOKENIZER
tokenizer = get_tokenizer(model_name)
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)
model = get_model(model_name, config['model'], device)


# LOAD IN DATASET
dataset = CSVPromptDataset(dataset_path, prompt_config_path)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True)


# ENABLE EARLY EXITING
model = replace_attention_layers(model, config['lora'], device)

replacing layer model.layers.0
replacing layer model.layers.5
replacing layer model.layers.10
replacing layer model.layers.15
replacing layer model.layers.20
replacing layer model.layers.25
address this hack!
g++ (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Copyright (C) 2021 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

trainable params: 2,179,072 || all params: 1,779,276,294 || trainable%: 0.1225


In [5]:
prompt = "Explain the concept of recursion in programming."
system_prompt = "You are a helpful programming tutor."
prefiller = ""

set_transformer_early_exit_mode(model, 'sft_teacher')

with torch.no_grad():
    sft_teacher_response, (sft_teacher_generated_tokens, 
                          sft_teacher_final_layer_logprobs, 
                          gathered_early_exit_hidden_states) = generate_text(
        model=model,
        prompt=prompt,
        system_prompt=system_prompt,
        prefiller=prefiller,
        tokenizer=tokenizer,
        generation_config=config['generation'],
        device=device
    )
    
    early_output_log_probs = model.early_exit_hidden_state_readout(gathered_early_exit_hidden_states)
    
    early_exit_probs = model.early_exit_target_probs(
        early_output_log_probs=early_output_log_probs,
        teacher_final_layer_log_probs=sft_teacher_final_layer_logprobs
    )

transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 20])
CRUDE KL


In [6]:
sft_teacher_response

"<｜begin▁of▁sentence｜><｜Assistant｜> You are a helpful programming tutor.\n<｜User｜> Explain the concept of recursion in programming.\n<｜Assistant｜> \nOkay, so I need to explain recursion in programming. Hmm, recursion is a programming concept where a function calls itself. That means the function will keep doing the same task over and over until it reaches a base case. \n\nWait, let me think about an example. Like, when you have a function that adds numbers from 1 to n. So, for n=3, it would call itself with n=2, and so on, until it gets to n=0 or n=1, which is the base case. \n\nBut wait, what's the base case again? Oh right, it's the simplest scenario that doesn't require further recursion. For the sum function, when n is 0 or 1, the sum is just n. \n\nI should also mention that recursion can make code cleaner and easier to understand because it's intuitive. Like, it's similar to how some problems are solved in everyday life, like climbing a mountain step by step. \n\nBut I should be 

In [20]:
with torch.no_grad():
    batch, gen_len, elayers = early_exit_probs.shape 
    full_len = sft_teacher_generated_tokens.shape[1]
    repeated_sft_teacher_generated_tokens = sft_teacher_generated_tokens.expand(num_exit_samples * batch, full_len)   
    set_transformer_early_exit_mode(model, 'sft_student')
    
    # Create prescribed exit layer idxs filled with torch.inf (always exit on last layer)
    batch_samples, seq_len = repeated_sft_teacher_generated_tokens.shape
    print("Setting exit layers to inf for sft_student")
    prescribed_exit_layer_idxs = torch.full((batch_samples, gen_len), torch.inf, \
                                            device=repeated_sft_teacher_generated_tokens.device)
    print(f"Minimum in prescribed_exit_layer_idxs = {torch.min(prescribed_exit_layer_idxs)}")
    sft_student_output_scores, collected_exit_logits = model(repeated_sft_teacher_generated_tokens,\
                                                             prescribed_exit_layer_idxs=prescribed_exit_layer_idxs)

Setting exit layers to inf for sft_student
Minimum in prescribed_exit_layer_idxs = inf


In [21]:
with torch.no_grad():
    print('CRUDE KL AND MAKE SURE PROBS ARE ALIGNED')
    eps = 1e-16
    sft_teacher_probs = sft_teacher_final_layer_logprobs.softmax(-1)                        # [batch * samples, gen len, vocabulary]
    sft_student_probs = sft_student_output_scores.logits[:,-gen_len:].softmax(-1)           # [batch * samples, gen len, vocabulary]
    token_logits_kl_div = (sft_student_probs * ((sft_student_probs + eps) / (sft_teacher_probs + eps)).log()).sum(-1)   # [batch * samples, gen len]
    
    mean_logit_kl = token_logits_kl_div.mean()

mean_logit_kl

CRUDE KL AND MAKE SURE PROBS ARE ALIGNED


tensor(27.5847, device='cuda:0')

In [36]:
import pandas as pd
from IPython.display import display, Markdown, HTML

def topk_to_df(prob_dist, tokenizer=None, k=5, title="Top-K Predictions"):
    """
    Return top-k predictions and probabilities as a pandas DataFrame.
    """
    top_values, top_indices = torch.topk(prob_dist, k=k)
    
    rows = []
    for i, (idx, prob) in enumerate(zip(top_indices, top_values)):
        token_id = idx.item()
        prob_val = prob.item()
        token_str = tokenizer.decode([token_id]) if tokenizer else str(token_id)
        token_str = repr(token_str)  # Shows escape characters properly
        
        rows.append({
            "Token ID": token_id,
            "Token String": token_str,
            "Probability": prob_val,
        })
    
    df = pd.DataFrame(rows)
    return title, df.round(4)

# Example usage for your loop
dfs = []
for idx in range(5, 11):
    title, df = topk_to_df(sft_student_probs[0, idx], tokenizer, k=5, title=f"Student NTP for token {idx}")
    dfs.append((title, df))

# Display in a grid
html = "<div style='display: flex; flex-wrap: wrap;'>"
for title, df in dfs:
    html += "<div style='flex: 1; min-width: 300px; padding: 10px;'>"
    html += f"<h4>{title}</h4>"
    html += df.to_html(index=False)
    html += "</div>"
html += "</div>"

display(HTML(html))


Token ID,Token String,Probability
20162,' initiative',0.0511
90884,' peripherals',0.0142
79432,' Simone',0.0082
82399,' Tooth',0.0061
471,'art',0.0052

Token ID,Token String,Probability
20162,' initiative',0.0512
90884,' peripherals',0.0146
79432,' Simone',0.0083
82399,' Tooth',0.0062
471,'art',0.0053

Token ID,Token String,Probability
20162,' initiative',0.0512
90884,' peripherals',0.0145
79432,' Simone',0.0082
82399,' Tooth',0.0062
471,'art',0.0053

Token ID,Token String,Probability
20162,' initiative',0.0511
90884,' peripherals',0.014
79432,' Simone',0.0082
82399,' Tooth',0.0064
471,'art',0.0052

Token ID,Token String,Probability
20162,' initiative',0.0512
90884,' peripherals',0.0147
79432,' Simone',0.0084
82399,' Tooth',0.0062
471,'art',0.0054

Token ID,Token String,Probability
20162,' initiative',0.0512
90884,' peripherals',0.0142
79432,' Simone',0.0082
82399,' Tooth',0.0065
471,'art',0.0052


### Very similar (and gibberish) next token predictions for all tokens. Something wrong!