In [1]:
import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader

import sys
sys.path.append("../")

from shared_utils.data import CSVPromptDataset
from early_exit.util import get_model
from shared_utils.load import get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text

from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode

import wandb
import pandas as pd
import numpy as np

In [2]:
# EXPERIMENT ARGS
# Hard-coded stand-ins for the command-line arguments named in the trailing comments.
# num_epoch = 1  # args.num_epoch (currently unused)
num_exit_samples = 1  # args.num_exit_samples

# Use the GPU when one is available, otherwise fall back to CPU so the notebook
# still runs on a CUDA-less machine instead of failing at model load.
device = "cuda" if torch.cuda.is_available() else "cpu"  # args.device

model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # args.model_name
model_config_path = "../config_deepseek.yaml"  # args.model_config_path
dataset_path = "../results_and_data/early_exit_sft_dataset/test/data.csv"  # args.dataset_path
prompt_config_path = "../results_and_data/early_exit_sft_dataset/test/prompt_config.json"  # args.prompt_config_path

# Might want to sort out batching, but increasing num_exit_samples might be better + less effort.
batch_size = 1  # args.batch_size


In [3]:
# TOKENIZER + CONFIG
# The YAML config loader is given the tokenizer's EOS token id.
tokenizer = get_tokenizer(model_name)
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)

# BASE MODEL, built from the 'model' section of the config
model_section = config['model']
model = get_model(model_name, model_section, device)

# EARLY EXITING: swap in the patched attention layers, configured by the 'lora' section
lora_section = config['lora']
model = replace_attention_layers(model, lora_section, device)

address this hack!
g++ (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Copyright (C) 2021 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

trainable params: 2,179,072 || all params: 1,779,276,294 || trainable%: 0.1225


In [25]:
# Fixed inputs for this single-prompt walkthrough
system_prompt = "You are a helpful programming tutor."
prompt = "Explain the concept of recursion in programming."
prefiller = ""

# Cap the generation length (this overwrites the value loaded from the YAML config)
# and put the patched model into teacher mode before generating.
config['generation']['max_new_tokens'] = 20
set_transformer_early_exit_mode(model, 'sft_teacher')

with torch.no_grad():
    # generate_text returns the response plus a 3-tuple of
    # (generated tokens, final-layer log-probs, gathered early-exit hidden states).
    sft_teacher_response, teacher_generation_outputs = generate_text(
        model=model,
        prompt=prompt,
        system_prompt=system_prompt,
        prefiller=prefiller,
        tokenizer=tokenizer,
        generation_config=config['generation'],
        device=device,
    )
    (untrimmed_teacher_tokens,
     sft_teacher_final_layer_logprobs,
     gathered_early_exit_hidden_states) = teacher_generation_outputs

    # Keep every column except the last one.
    sft_teacher_generated_tokens = untrimmed_teacher_tokens[:, :-1]

    # Read the gathered early-exit hidden states out into log-probs, then turn them
    # (together with the teacher's final-layer log-probs) into exit target probabilities.
    early_output_log_probs = model.early_exit_hidden_state_readout(gathered_early_exit_hidden_states)
    early_exit_probs = model.early_exit_target_probs(
        early_output_log_probs=early_output_log_probs,
        teacher_final_layer_log_probs=sft_teacher_final_layer_logprobs,
    )

transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 20])


In [26]:
# Sample an exit layer for every generated position, then run the model in
# 'sft_student' mode on the teacher's tokens with those exit layers prescribed.
with torch.no_grad():
    batch, gen_len, elayers = early_exit_probs.shape

    # Draw num_exit_samples categorical samples over the last axis of early_exit_probs.
    exit_layer_sampler = torch.distributions.Categorical(probs=early_exit_probs)
    sampled_early_exit_layer_idxs_early_with_sample_dim = exit_layer_sampler.sample((num_exit_samples,))  # [samples, batch, generation length]

    # Fold the sample dim into the batch dim; rows come out sample-major (row = sample * batch + batch index).
    sampled_early_exit_layer_idxs_early = sampled_early_exit_layer_idxs_early_with_sample_dim.reshape(
        batch * num_exit_samples, gen_len
    )  # [batch * samples, generation length]

    # Map each sampled position into the corresponding entry of model.exitable_layer_idxs.
    sampled_early_exit_layer_idxs = model.exitable_layer_idxs[sampled_early_exit_layer_idxs_early.cpu()]  # [batch * samples, generation length]

    full_len = sft_teacher_generated_tokens.shape[1]
    # Tile the teacher tokens once per exit sample, in the same sample-major row order as the
    # sampled indices above. A plain expand() can only broadcast a size-1 dim, so it breaks
    # as soon as both batch and num_exit_samples exceed 1; expand-to-batch followed by
    # repeat() gives identical rows in the existing single-sequence case and also covers that one.
    repeated_sft_teacher_generated_tokens = (
        sft_teacher_generated_tokens.expand(batch, full_len).repeat(num_exit_samples, 1)
    )
    set_transformer_early_exit_mode(model, 'sft_student')

    # Shape bookkeeping: rows = batch * samples, columns = sequence length.
    batch_samples, seq_len = repeated_sft_teacher_generated_tokens.shape

    # Student forward pass with the sampled exit layers prescribed per position.
    sft_student_output_scores, collected_exit_logits = model(
        repeated_sft_teacher_generated_tokens,
        prescribed_exit_layer_idxs=sampled_early_exit_layer_idxs,
    )

In [27]:
# Student logits with all singleton dims squeezed out, skipping the first 20 positions.
# NOTE(review): 20 matches the "prompt tokens shape: torch.Size([1, 20])" line printed during
# generation, so it is presumably the prompt length -- confirm, and derive it rather than hard-coding.
student_logits_squeezed = sft_student_output_scores.logits.squeeze()
sft_student_output = student_logits_squeezed[20:]
teacher_final_output = sft_teacher_final_layer_logprobs

# Normalise both along the last axis so they can be compared as probability distributions.
student_probs = sft_student_output.softmax(dim=-1)
teacher_final_probs = teacher_final_output.squeeze().softmax(dim=-1)

In [39]:
# Per-position KL(student || teacher): sum over the last axis of p_s * log(p_s / p_t).
# eps is added to numerator and denominator so the ratio and its log stay finite
# when a probability is exactly zero.
eps = 1e-16
student_teacher_ratio = (student_probs + eps) / (teacher_final_probs + eps)

# The last axis is summed out, leaving one value per remaining position
# (a 1-D tensor for the single-sequence run recorded in this notebook).
token_logits_kl_div = (student_probs * student_teacher_ratio.log()).sum(-1)

# End on the bare expression so the cell actually displays the tensor it computes.
# NOTE(review): entries slightly below zero are presumably floating-point round-off -- confirm.
token_logits_kl_div

tensor([-3.4202e-08, -3.5804e-08, -4.5235e-09, -2.0964e-08,  2.4629e-05,
        -6.3737e-09, -2.6751e-08,  2.8074e-09,  1.2900e-03,  8.6235e-08,
         1.0610e-06, -1.1635e-07,  1.4002e-06,  2.8356e-06,  1.0371e-06,
         1.0161e-06,  1.0043e-06,  6.6714e+00,  9.5087e-04], device='cuda:0')

In [40]:
# Render floats with two decimals in displayed DataFrames, and start with an empty row list.
pd.set_option("display.float_format", "{:.2f}".format)
rows = list()

def get_prob_token(probs):
    """Return ``(probability, decoded token)`` for the largest entry of ``probs``.

    The index of the maximum is decoded with the notebook-level ``tokenizer``;
    the probability is returned as a plain Python number.
    """
    best_id = probs.argmax().item()
    decoded = tokenizer.decode([best_id])
    return probs[best_id].item(), decoded

# One table row per position along the first axis of student_probs:
# top student token, top teacher token, the prescribed exit layer and the KL value.
for idx, student_row in enumerate(student_probs):
    s_prob, s_token = get_prob_token(student_row)
    t_prob, t_token = get_prob_token(teacher_final_probs[idx])

    record = {
        "Student Token": s_token,
        "Student Prob": s_prob,
        "Teacher Token": t_token,
        "Teacher Prob": t_prob,
        # Only row 0 of the sampled exit-layer indices is shown.
        "Prescribed exit layer": sampled_early_exit_layer_idxs[0, idx].item(),
        "KL divergence": token_logits_kl_div[idx].item(),
    }
    rows.append(record)

df = pd.DataFrame(rows)
display(df)

Unnamed: 0,Student Token,Student Prob,Teacher Token,Teacher Prob,Prescribed exit layer,KL divergence
0,",",1.0,",",1.0,inf,-0.0
1,so,0.78,so,0.78,inf,-0.0
2,I,0.93,I,0.93,inf,-0.0
3,need,0.65,need,0.65,inf,-0.0
4,to,1.0,to,1.0,25.0,0.0
5,explain,0.75,explain,0.75,inf,-0.0
6,recursion,0.5,recursion,0.5,inf,-0.0
7,in,0.98,in,0.98,inf,0.0
8,programming,1.0,programming,1.0,25.0,0.0
9,.,0.86,.,0.86,inf,0.0
