In [1]:
import torch
import sys
sys.path.append("../")
from shared_utils.load import get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text
from early_exit.util import get_model, load_model
from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode
from early_exit_teacher.visualization import create_html_visualization, visualize_tokens_by_exit_layer, safe_decode_tokens

# Fine-tuned early-exit checkpoint and the Hugging Face id of the model it was built on.
model_path = "../models/early_exit_20250811_layers_5_big"
base_model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"  # model-id string
config_path = "../config_deepseek.yaml"
# Fall back to CPU when no GPU is visible instead of failing on a hard-coded "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"
# Seed torch's RNGs so stochastic steps (e.g. sampled exit layers) are reproducible.
SEED = 0
torch.manual_seed(SEED)

In [2]:
# Load tokenizer and config, build the base model, patch in the early-exit
# attention layers (configured by config['lora']), then load the fine-tuned weights.
tokenizer = get_tokenizer(base_model)
config = configs_from_yaml(config_path, tokenizer.eos_token_id)

# Load into a separate name so `base_model` stays the model-id string.
# Rebinding `base_model` to the loaded model would make this cell fail on a
# re-run, because get_tokenizer/get_model would then receive a model object.
backbone_model = get_model(base_model, config['model'], device)
model = replace_attention_layers(backbone_model, config['lora'], device)
model = load_model(model, model_path)

print(f"Model loaded w exitable layers: {model.exitable_layer_idxs}")

address this hack!
g++ (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Copyright (C) 2021 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

trainable params: 2,179,072 || all params: 1,802,890,757 || trainable%: 0.1209
Model loaded w exitable layers: tensor([ 5., 10., 15., 20., 25., inf])


## Free generation

Generate in `free_generate` mode and visualize the exit layer recorded for each generated token.

In [23]:
# Put the patched transformer into the 'free_generate' early-exit mode.
set_transformer_early_exit_mode(
    model,
    'free_generate',
)

In [24]:
# Example prompts; the last one (arithmetic) is the active prompt.
example_prompts = (
    "Explain the concept of recursion in programming.",
    "What are the main causes of climate change?",
    "What is 17 x 18",
)
prompt = example_prompts[-1]
system_prompt = "You are a helpful programming tutor."
prefiller = ""  # passed to generate_text as `prefiller`; empty here

In [25]:
# Use a per-cell copy of the generation config. Setting max_new_tokens on
# config['generation'] in place would leak into every other cell that reads it
# (e.g. the SFT-teacher generation), so results would depend on execution order.
free_generate_config = config['generation'].copy()
free_generate_config['max_new_tokens'] = 400

with torch.no_grad():
    try:
        free_generate_response, exit_info = generate_text(
            model=model,
            prompt=prompt,
            system_prompt=system_prompt,
            prefiller=prefiller,
            tokenizer=tokenizer,
            generation_config=free_generate_config,
            device=device,
        )
    except Exception as e:
        # Report, then re-raise: swallowing the error would leave `exit_info`
        # undefined (or stale from an earlier run) for the visualization cell.
        print(f"Free generate mode failed: {e}")
        raise

print(f"Free Generate Response: {free_generate_response}")

transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 22])
Free Generate Response: <｜begin▁of▁sentence｜><｜Assistant｜> You are a helpful programming tutor.
<｜User｜> What is 17 x 18
<｜Assistant｜> 
I need to calculate 17 multiplied by 18. 

First, I'll break down 18 into two separate numbers: 10 and 8.

Next, I'll multiply 17 by each separately. 

17 multiplied by  ten is 170, and 17 multiplied by eight is 1otre.

Finally, I'll add the two results together: 170 plus 1otre equals 186.
</think>

**Solution:**

To calculate \(17 \times 18$, we can use the distributive property of multiplication over addition.

\[
17 × 18th = 17 × (10 + 8) = (17 ×  ten) + (17 × time)
\]

\[
17 ×  ten = 17 ×  ten = 170
\]

 enorm"h × eight = 17 ×  eight = 136. 

Now add the two results together:
\[
170 + 13 = 186
 intermediary step
\]

**Final Answer:** \boxed{186}<｜end▁of▁sentence｜>


In [27]:
# Pair each generated token with the layer recorded for it in exit_info[1][0];
# that tensor's last dim is used as the generation length.
# Layer index shown for tokens with exit value inf or -1 (no early exit).
# NOTE(review): 27 assumes a 28-layer backbone (final layer index 27) — confirm against the model config.
FINAL_LAYER_IDX = 27

free_exit_layers = exit_info[1][0]
gen_len = free_exit_layers.shape[-1]
tokens = safe_decode_tokens(tokenizer, exit_info[0][0, -gen_len:])
layers = [FINAL_LAYER_IDX if item == torch.inf or item == -1 else int(item) for item in free_exit_layers]
# Drop the trailing inf entry of exitable_layer_idxs; keep only real early-exit layers.
exitable_layers = [int(item) for item in model.exitable_layer_idxs[:-1]]
display(visualize_tokens_by_exit_layer(tokens, layers, exitable_layers,
                                       title="Committed Early Exit Token Generation"))

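## SFT teacher mode

Generate in `sft_teacher` mode. Then sample an exit layer for each generated token from the model's early-exit target probabilities, and visualize the samples.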

In [11]:
# Switch the patched transformer into the 'sft_teacher' early-exit mode.
set_transformer_early_exit_mode(
    model,
    'sft_teacher',
)

In [16]:
num_exit_samples = 1  # exit-layer samples drawn per generated sequence

# Inference only: no_grad stops autograd from recording history through the
# model's trainable parameters (consistent with the free-generate cell).
with torch.no_grad():
    sft_teacher_response, (sft_teacher_generated_tokens, sft_teacher_final_layer_logprobs, gathered_early_exit_hidden_states) = generate_text(
        model=model,
        prompt=prompt,
        system_prompt=system_prompt,
        prefiller=prefiller,
        tokenizer=tokenizer,
        generation_config=config['generation'],
        device=device,
    )
    print(sft_teacher_response)

    early_output_log_probs = model.early_exit_hidden_state_readout(gathered_early_exit_hidden_states)  # [batch, num exitable layers, gen len, vocabulary]
    early_exit_probs = model.early_exit_target_probs(
        early_output_log_probs=early_output_log_probs,
        teacher_final_layer_log_probs=sft_teacher_final_layer_logprobs,
    )
    # `repeat` (tiling), not `repeat_interleave`: the samples below are flattened
    # sample-major ([samples, batch] -> samples * batch), so row s * batch + b must
    # hold batch element b — exactly what tiling produces.
    repeated_sft_teacher_final_layer_logprobs = sft_teacher_final_layer_logprobs.repeat(num_exit_samples, 1, 1)  # [batch * samples, full length, vocabulary]
    # NOTE(review): an earlier version trimmed the last generated token here to match
    # the log-prob / hidden-state lengths; confirm those shapes still line up.

    # Sample an exit layer for every generated token.
    batch, gen_len, elayers = early_exit_probs.shape  # [batch, generation length, exitable layers]
    full_len = sft_teacher_generated_tokens.shape[1]
    # `repeat` instead of `expand`: expand cannot grow a non-singleton batch dim, so it
    # failed for batch > 1 with num_exit_samples > 1. Same sample-major order as above.
    repeated_sft_teacher_generated_tokens = sft_teacher_generated_tokens.repeat(num_exit_samples, 1)  # [batch * samples, full length]
    sampled_early_exit_layer_idxs_early_with_sample_dim = torch.distributions.Categorical(probs=early_exit_probs).sample((num_exit_samples,))  # [samples, batch, generation length]
    sampled_early_exit_layer_idxs_early = sampled_early_exit_layer_idxs_early_with_sample_dim.reshape(batch * num_exit_samples, gen_len)  # [batch * samples, generation length]
    # Map sampled positions to actual layer indices; the trailing inf entry of
    # exitable_layer_idxs presumably means "no early exit".
    sampled_early_exit_layer_idxs = model.exitable_layer_idxs[sampled_early_exit_layer_idxs_early.cpu()]  # [batch * samples, generation length]
      

transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 22])
<｜begin▁of▁sentence｜><｜Assistant｜> You are a helpful programming tutor.
<｜User｜> What is 17 x 18
<｜Assistant｜> 
I need to calculate 17 multiplied by 18. 

First, I'll break down 18 into 10 and 8. 

Next, I'll multiply 17 by 10, which gives me 170. 

Then, I'll multiply 17 by 8, resulting in 136. 

Finally, I'll add 170 and 136 together to get the final answer of 306. 

So, 17 multiplied by 18 equals 306.
</think>

**Solution:**

To calculate \( 17 \times 18 \), we can use the distributive property of multiplication.

\[
17 \times 18 = 17 \times (10 + 8) = (17 \times 10) + (17 \times 8)
\]

Now, perform the individual multiplications:

\[
17 \times 10 = 170
\]
\[
17 \times 8 = 136
\]

Add the results together:

\[
170 + 136 = 306
\]

**Final Answer:**

\[
\boxed{306}
\]<｜end▁of▁sentence｜>


In [18]:
# Expected: [batch * num_exit_samples, generation length]
sampled_early_exit_layer_idxs.size()

torch.Size([1, 269])

In [20]:
# One sampled exit decision per generated token.
n_exit_decisions = sampled_early_exit_layer_idxs.shape[-1]
# Prompt length in tokens, matching the "prompt tokens shape" printed by generate_text
# for the current prompt. NOTE(review): hard-coded — update (or derive) it whenever
# prompt / system_prompt / prefiller change, otherwise tokens and layers misalign.
PROMPT_LEN = 22
# Layer index shown for tokens with exit value inf or -1 (no early exit).
# NOTE(review): 27 assumes a 28-layer backbone (final layer index 27) — confirm against the model config.
FINAL_LAYER_IDX = 27

# Take exactly one token per exit decision so the two sequences stay aligned.
tokens = safe_decode_tokens(tokenizer, sft_teacher_generated_tokens[0, PROMPT_LEN:PROMPT_LEN + n_exit_decisions])
layers = [FINAL_LAYER_IDX if item == torch.inf or item == -1 else int(item) for item in sampled_early_exit_layer_idxs[0]]
# These exits are sampled from the teacher's early-exit target distribution, not
# committed during generation, so the title says so.
display(visualize_tokens_by_exit_layer(tokens, layers, [int(item) for item in model.exitable_layer_idxs[:-1]],
                                       title="SFT Teacher Sampled Early Exit Layers"))