In [1]:
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader

import sys
sys.path.append("../")

from shared_utils.data import CSVPromptDataset
from shared_utils.load import get_model, get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text

from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode

import wandb
import pandas as pd
import numpy as np

In [6]:
# LOAD IN EXPERIMENT ARGS
# num_epoch = 1                     # args.num_epoch
num_exit_samples = 1                  # args.num_exit_samples
device = "cuda"                    # args.device
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"                    # args.model_name
model_config_path = "../config_deepseek.yaml"                     # args.model_config_path
dataset_path = "../results_and_data/early_exit_sft_dataset/test/data.csv"                  # args.dataset_path
prompt_config_path = "../results_and_data/early_exit_sft_dataset/test/prompt_config.json"                    # args.prompt_config_path
batch_size = 1                    # args.batch_size -- might want to sort out batching, but increasing num_exit_samples might be better + less effort


In [7]:
# LOAD IN THE MODEL AND TOKENIZER
tokenizer = get_tokenizer(model_name)
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)
model = get_model(model_name, config['model'], device)


# LOAD IN DATASET
dataset = CSVPromptDataset(dataset_path, prompt_config_path)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True)


# ENABLE EARLY EXITING
model = replace_attention_layers(model, config['lora'], device)

replacing layer model.layers.0
replacing layer model.layers.5
replacing layer model.layers.10
replacing layer model.layers.15
replacing layer model.layers.20
replacing layer model.layers.25
address this hack!
trainable params: 2,179,072 || all params: 1,779,276,294 || trainable%: 0.1225


In [4]:
prompt = "Explain the concept of recursion in programming."
system_prompt = "You are a helpful programming tutor."
prefiller = ""

set_transformer_early_exit_mode(model, 'sft_teacher')

with torch.no_grad():
    sft_teacher_response, (sft_teacher_generated_tokens, 
                          sft_teacher_final_layer_logprobs, 
                          gathered_early_exit_hidden_states) = generate_text(
        model=model,
        prompt=prompt,
        system_prompt=system_prompt,
        prefiller=prefiller,
        tokenizer=tokenizer,
        generation_config=config['generation'],
        device=device
    )
    
    early_output_log_probs = model.early_exit_hidden_state_readout(gathered_early_exit_hidden_states)
    
    early_exit_probs = model.early_exit_target_probs(
        early_output_log_probs=early_output_log_probs,
        teacher_final_layer_log_probs=sft_teacher_final_layer_logprobs
    )

transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 20])
CRUDE KL


In [5]:
sft_teacher_response

"<｜begin▁of▁sentence｜><｜Assistant｜> You are a helpful programming tutor.\n<｜User｜> Explain the concept of recursion in programming.\n<｜Assistant｜> \nAlright, so I need to explain recursion in programming. Hmm, recursion is a method where a function calls itself to solve a problem. I remember it's used in many algorithms, especially those that can be broken down into smaller, similar subproblems. \n\nWait, how does it work exactly? I think you have a base case that stops the recursion. Like, if I have a function that calculates the factorial of a number, the base case would be when the number is 0 or 1 because 0! and 1! are both 1. \n\nBut then, for other numbers, I call the function again with a smaller number. So, for example, factorial(n) would call factorial(n-1) and multiply by n. That makes sense because if I keep doing that, it will eventually reach the base case.\n\nI should also mention why recursion is useful. It can simplify code, especially for problems that have a natural r

In [6]:
with torch.no_grad():
    batch, gen_len, elayers = early_exit_probs.shape 
    full_len = sft_teacher_generated_tokens.shape[1]
    repeated_sft_teacher_generated_tokens = sft_teacher_generated_tokens.expand(num_exit_samples * batch, full_len)   
    set_transformer_early_exit_mode(model, 'sft_student')
    
    # Create prescribed exit layer idxs filled with torch.inf (always exit on last layer)
    batch_samples, seq_len = repeated_sft_teacher_generated_tokens.shape
    print("Setting exit layers to inf for sft_student")
    prescribed_exit_layer_idxs = torch.full((batch_samples, gen_len), torch.inf, \
                                            device=repeated_sft_teacher_generated_tokens.device)
    print(f"Minimum in prescribed_exit_layer_idxs = {torch.min(prescribed_exit_layer_idxs)}")
    sft_student_output_scores, collected_exit_logits = model(repeated_sft_teacher_generated_tokens,\
                                                             prescribed_exit_layer_idxs=prescribed_exit_layer_idxs)

Setting exit layers to inf for sft_student
Minimum in prescribed_exit_layer_idxs = inf


In [7]:
with torch.no_grad():
    print('CRUDE KL AND MAKE SURE PROBS ARE ALIGNED')
    eps = 1e-16
    sft_teacher_probs = sft_teacher_final_layer_logprobs.softmax(-1)                        # [batch * samples, gen len, vocabulary]
    sft_student_probs = sft_student_output_scores.logits[:,-gen_len:].softmax(-1)           # [batch * samples, gen len, vocabulary]
    token_logits_kl_div = (sft_student_probs * ((sft_student_probs + eps) / (sft_teacher_probs + eps)).log()).sum(-1)   # [batch * samples, gen len]
    
    mean_logit_kl = token_logits_kl_div.mean()

mean_logit_kl

CRUDE KL AND MAKE SURE PROBS ARE ALIGNED


tensor(30.1700, device='cuda:0')

In [8]:
import pandas as pd
from IPython.display import display, Markdown, HTML

def topk_to_df(prob_dist, tokenizer=None, k=5, title="Top-K Predictions"):
    """
    Return top-k predictions and probabilities as a pandas DataFrame.
    """
    top_values, top_indices = torch.topk(prob_dist, k=k)
    
    rows = []
    for i, (idx, prob) in enumerate(zip(top_indices, top_values)):
        token_id = idx.item()
        prob_val = prob.item()
        token_str = tokenizer.decode([token_id]) if tokenizer else str(token_id)
        token_str = repr(token_str)  # Shows escape characters properly
        
        rows.append({
            "Token ID": token_id,
            "Token String": token_str,
            "Probability": prob_val,
        })
    
    df = pd.DataFrame(rows)
    return title, df.round(4)

# Example usage for your loop
dfs = []
for idx in range(5, 11):
    title, df = topk_to_df(sft_student_probs[0, idx], tokenizer, k=5, title=f"Student NTP for token {idx}")
    dfs.append((title, df))

# Display in a grid
html = "<div style='display: flex; flex-wrap: wrap;'>"
for title, df in dfs:
    html += "<div style='flex: 1; min-width: 300px; padding: 10px;'>"
    html += f"<h4>{title}</h4>"
    html += df.to_html(index=False)
    html += "</div>"
html += "</div>"

display(HTML(html))


Token ID,Token String,Probability
320,' (',0.2374
659,' .',0.0463
328,' S',0.03
1725,'./',0.0288
1182,' back',0.0131

Token ID,Token String,Probability
320,' (',0.2391
659,' .',0.04
1725,'./',0.0271
328,' S',0.0269
1182,' back',0.0146

Token ID,Token String,Probability
320,' (',0.2353
659,' .',0.0478
328,' S',0.0302
1725,'./',0.0294
1182,' back',0.0126

Token ID,Token String,Probability
320,' (',0.2499
659,' .',0.0439
328,' S',0.03
1725,'./',0.0261
1182,' back',0.0129

Token ID,Token String,Probability
320,' (',0.2257
659,' .',0.0456
328,' S',0.0347
1725,'./',0.0284
1182,' back',0.0137

Token ID,Token String,Probability
320,' (',0.2359
659,' .',0.0487
328,' S',0.0312
1725,'./',0.0279
1182,' back',0.0128


### Very similar (and gibberish) next token predictions for all tokens. Something wrong!

In [9]:
# LOAD IN THE MODEL AND TOKENIZER
tokenizer = get_tokenizer(model_name)
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)
config['generation']['use_cache'] = False
model = get_model(model_name, config['model'], device)


# ENABLE EARLY EXITING
model = replace_attention_layers(model, config['lora'], device)

replacing layer model.layers.0
replacing layer model.layers.5
replacing layer model.layers.10
replacing layer model.layers.15
replacing layer model.layers.20
replacing layer model.layers.25
address this hack!
trainable params: 2,179,072 || all params: 1,779,276,294 || trainable%: 0.1225


In [None]:
prompt = "Explain the concept of recursion in programming."
system_prompt = "You are a helpful programming tutor."
prefiller = ""

set_transformer_early_exit_mode(model, 'free_generate')

with torch.no_grad():
    free_generate_response, _ = generate_text(
        model=model,
        prompt=prompt,
        system_prompt=system_prompt,
        prefiller=prefiller,
        tokenizer=tokenizer,
        generation_config=config['generation'],
        device=device
    )

In [11]:
free_generate_response

'<｜begin▁of▁sentence｜><｜Assistant｜> You are a helpful programming tutor.\n<｜User｜> Explain the concept of recursion in programming.\n<｜Assistant｜> \n multilineFire1 1 Th thinner thinner thinner litres litres收 litresShip litres litresbistractive litre.par antiqu bookings bookingstractive litres———— agr litres tombGOR litres walls mathsGORaroMathMath maths upright’B litresMath){\n\nd mathsdd�dd seventhMath formX四年MathMath mathsElMath:\n\n whereells />\n\n litres:\n\n,:\n\n, maths, — underneath suffers.parL04$\\L,\\,:\n\n\\:\n\n\\\\\\\\\\ \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ \\ use use\\求 \\以及:\n\n� \\ \\ \\ \\){\n\n]:\n\n出){\n\nlish \\ \\\\\\ \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\d\\\\\\\\\\\\\\\\\\ \\\\ \\\\\\\\ \\\\\\\\\\\\\\ \\\\ \\ \\ \\ \\ \\ \\ \\ \\ \\\\\\t maths \\\\ \\ \\ \\ \\ \\ \\\\ \\\\ \\ \\ \\ \\ \\ \\‘ \\ — \\‘ —‘n‘ ‘‘‘‘’‘‘‘‘‘‘ –‘‘ —‘‘‘‘ \\ \n‘‘‘\\\n \n‘‘ –‘ shining –‘‘ \\ manages(dict \n‘inz \\ \n \n —\\\n‘\n — quotas \\‘‘‘‘‘‘‘‘\\\n —‘‘ — \n \\.\n\n ‘.\n\n \n\n\n.\n\n.