In [1]:
import torch
import sys
sys.path.append("../")
from shared_utils.load import get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text
from early_exit.util import get_model, load_model
from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode

# Paths and identifiers for the early-exit checkpoint, its base model and the
# YAML configuration. `base_model` and `config` start out as plain aliases of
# `base_model_name` and `config_path`, so each literal is written only once.
model_path = "../models/early_exit_20250811_layers_5_lora32/step_2000"
base_model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
base_model = base_model_name
config_path = "../config_deepseek.yaml"
config = config_path
device = "cuda"

In [2]:
# Load the tokenizer and the parsed YAML configuration.
# The tokenizer is looked up via `base_model_name` rather than `base_model`:
# `base_model` is rebound to the model object a few lines below, so reading it
# here would pass a model (not a name) to get_tokenizer when the cell is re-run.
tokenizer = get_tokenizer(base_model_name)
config = configs_from_yaml(config_path, tokenizer.eos_token_id)

# Build the base model, replace its attention layers using the LoRA settings,
# then load the checkpoint stored at `model_path` into the patched model.
base_model = get_model(base_model_name, config['model'], device)
model = replace_attention_layers(base_model, config['lora'], device)
model = load_model(model, model_path)

print(f"Model loaded w exitable layers: {model.exitable_layer_idxs}")

replacing generate_layer_type_without_early_exit_decision_head layer model.layers.0
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.1
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.2
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.3
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.4
replacing layer model.layers.5
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.6
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.7
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.8
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.9
replacing layer model.layers.10
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.11
replacing generate_layer_type_without_early_exit_decision_head layer model.layers.12
replacing g

In [3]:
# Put the patched model into 'free_generate' mode before any text is generated.
# NOTE(review): the exact semantics of this mode are defined in early_exit.patching — confirm there.
set_transformer_early_exit_mode(model, 'free_generate')

In [4]:
# Inputs for the single free-generation run performed in the following cell.
system_prompt = "You are a helpful programming tutor."
prompt = "Explain the concept of recursion in programming."
prefiller = ""  # no prefilled text

In [5]:
# One free-generation run with gradients disabled. Any failure is reported as a
# one-line message instead of being raised, so the notebook keeps running.
try:
    with torch.no_grad():
        generation_kwargs = {
            "model": model,
            "prompt": prompt,
            "system_prompt": system_prompt,
            "prefiller": prefiller,
            "tokenizer": tokenizer,
            "generation_config": config['generation'],
            "device": device,
        }
        free_generate_response, exit_info = generate_text(**generation_kwargs)

        print(f"Free Generate Response: {free_generate_response}")
        print(f"Exit info: {exit_info}")
except Exception as e:
    print(f"Free generate mode failed: {e}")

transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 20])
Free Generate Response: <｜begin▁of▁sentence｜><｜Assistant｜> You are a helpful programming tutor.
<｜User｜> Explain the concept of recursion in programming.
<｜Assistant｜> 
Okay, so I need to explain the conceptually meaning of recursion in programming. Hmm okay, let's start from scratch because I want to understand how it works conceptually.

First, I remember that recursion is often used in algorithms and algorithms often naturally extensions of other problems. Like when you have a recursive function/functionality/methodicality often ens).* 

So, recursion is when aquired recursively,那就是 meaning it calls itself recursively (same function).*)*所以在每 subsequent step of our function, tasks are performed-on/off/offly the same as othersiable othersiables/operators operations operations.`)*)*)*所以在 essenceCODing好的)*)*)*每次函数的一部分 ones step ️ stepstog

In [6]:
import re
from typing import Any, Dict, List, Optional

from inspect_ai import Task, task, eval
from inspect_ai.dataset import Sample, MemoryDataset
from inspect_ai.model import get_model as get_inspect_model
from inspect_ai.scorer import model_graded_fact, Score, Target, scorer, accuracy
from inspect_ai.solver import generate, system_message, use_tools

In [7]:
def _summarize_exit_layers(exit_layers) -> Dict[str, Any]:
    """Reduce a tensor of exit-layer values to summary counts.

    Every element of ``exit_layers`` counts as one token position, whatever the
    tensor's rank, so the early-exit rate can never exceed 1. Non-finite
    entries (``inf`` or ``nan``) are treated as "no early exit".
    """
    flat = torch.as_tensor(exit_layers).flatten()
    total_tokens = flat.numel()

    finite_exits = flat[torch.isfinite(flat)]
    early_exits = finite_exits.numel()

    # Histogram of how many positions exited at each layer index.
    unique_layers, counts = torch.unique(finite_exits, return_counts=True)
    layer_distribution = {
        int(layer): int(count)
        for layer, count in zip(unique_layers.tolist(), counts.tolist())
    }

    return {
        "total_tokens": total_tokens,
        "early_exits": early_exits,
        "early_exit_rate": early_exits / total_tokens if total_tokens > 0 else 0,
        "layer_distribution": layer_distribution,
    }


def generate_with_metrics(model, tokenizer, config, prompt: str, system_prompt: str) -> Dict[str, Any]:
    """Generate a response for ``prompt`` and summarise its early-exit behaviour.

    Args:
        model: Model passed to ``generate_text``; its ``device`` attribute is
            used as the generation device.
        tokenizer: Tokenizer passed through to ``generate_text``.
        config: Mapping whose ``'generation'`` entry is the generation config.
        prompt: User prompt.
        system_prompt: System prompt.

    Returns:
        A dict with keys ``response``, ``total_tokens``, ``early_exits``,
        ``early_exit_rate`` and ``layer_distribution`` (layer index -> count).
        When the second value returned by ``generate_text`` does not have the
        expected structure, all metrics are zero/empty and a debug line is
        printed.
    """
    with torch.no_grad():
        response, exit_info = generate_text(
            model=model,
            prompt=prompt,
            system_prompt=system_prompt,
            prefiller="",
            tokenizer=tokenizer,
            generation_config=config['generation'],
            device=model.device
        )

    # exit_info[1] is assumed to hold one exit-layer value per generated token,
    # with inf meaning "did not exit early" — TODO confirm against generate_text.
    try:
        exit_layers = exit_info[1] if len(exit_info) >= 2 else None
    except (TypeError, KeyError, IndexError):
        # exit_info is None, unsized, or not indexable by position.
        exit_layers = None

    if exit_layers is not None and hasattr(exit_layers, 'shape'):
        metrics = _summarize_exit_layers(exit_layers)
    else:
        print("Debug - exit_info structure is unexpected")
        metrics = {
            "total_tokens": 0,
            "early_exits": 0,
            "early_exit_rate": 0,
            "layer_distribution": {},
        }

    return {"response": response, **metrics}

In [8]:
def create_early_exit_dataset(
    model,
    tokenizer,
    config,
    test_prompts: Optional[List[str]] = None,
    system_prompt: str = "You are a helpful programming tutor.",
) -> List[Sample]:
    """Create dataset by generating responses with the early exit model.

    Args:
        model: Model passed to ``generate_with_metrics``.
        tokenizer: Tokenizer passed to ``generate_with_metrics``.
        config: Configuration passed to ``generate_with_metrics``.
        test_prompts: Prompts to generate for. Defaults to the built-in list of
            seven programming / ML questions when omitted.
        system_prompt: System prompt used for every generation.

    Returns:
        One ``Sample`` per prompt that generated successfully. The generated
        text and the early-exit metrics are stored in ``Sample.metadata``.
        Prompts whose generation raises are reported and skipped, so the result
        may be shorter than ``test_prompts``.
    """
    if test_prompts is None:
        test_prompts = [
            "Explain the concept of recursion in programming.",
            "Explain what object-oriented programming means using a simple example.",
            "How do neural networks learn from data?",
            "Write a simple Python function to calculate factorial and explain how it works.",
            "Describe the benefits of using version control systems like Git.",
            "What are the key differences between supervised and unsupervised learning?",
            "Describe how HTTP works and what happens when you visit a website."
        ]
    else:
        # Materialise so len() works for any iterable of prompts.
        test_prompts = list(test_prompts)

    samples = []

    for i, prompt in enumerate(test_prompts):
        print(f"Generating response {i+1}/{len(test_prompts)}: {prompt[:50]}...")

        try:
            result = generate_with_metrics(model, tokenizer, config, prompt, system_prompt)

            # Create Sample for Inspect; ids keep the prompt's 1-based position
            # even when earlier prompts failed.
            sample = Sample(
                id=f"early_exit_{i+1}",
                input=prompt,
                target="",  # no reference answer is used
                metadata={
                    "system_prompt": system_prompt,
                    "total_tokens": result["total_tokens"],
                    "early_exits": result["early_exits"],
                    "early_exit_rate": result["early_exit_rate"],
                    "layer_distribution": result["layer_distribution"],
                    "response": result["response"]
                }
            )
            samples.append(sample)

            print(f"  Generated {result['total_tokens']} tokens, {result['early_exits']} early exits ({result['early_exit_rate']:.1%})")

        except Exception as e:
            # Best effort: one failing prompt must not abort the whole dataset.
            print(f"  Failed: {e}")

    return samples

In [9]:
def generate_samples_for_eval(model, tokenizer, config):
    """Build the evaluation samples and publish them in the module-level ``_samples``.

    Delegates to ``create_early_exit_dataset`` and rebinds the global
    ``_samples`` to its result. Returns nothing.
    """
    global _samples
    print("Creating dataset with early exit model...")
    generated = create_early_exit_dataset(model, tokenizer, config)
    _samples = generated
    print(f"Generated {len(generated)} samples")

In [10]:
# Module-level store of generated samples: rebound by generate_samples_for_eval()
# and read by early_exit_coherence_eval(). Re-running this cell empties it.
_samples = []

In [11]:
# Generate responses with the early-exit model and store the resulting samples in `_samples`.
generate_samples_for_eval(model, tokenizer, config)

Creating dataset with early exit model...
Generating response 1/7: Explain the concept of recursion in programming....
transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 20])
  Generated 400 tokens, 287 early exits (71.8%)
Generating response 2/7: Explain what object-oriented programming means usi...
transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 23])
  Generated 400 tokens, 264 early exits (66.0%)
Generating response 3/7: How do neural networks learn from data?...
transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 20])
  Generated 400 tokens, 250 early exits (62.5%)
Generating response 4/7: Write a simple Python function to calculate factor...
transform_conversations currently only for Deepseek models

In [12]:
# Matches e.g. "Overall: 32/40", "**Overall:** 32 / 40", "Overall score: 7/10"
# or "Overall: 32" (denominator optional). Group 1 is the score, group 2 the
# denominator when present. The match never crosses a line break.
_OVERALL_SCORE_RE = re.compile(
    r"overall[^:\n]*:[^\w\n]*(\d+(?:\.\d+)?)(?:\s*/\s*(\d+(?:\.\d+)?))?",
    re.IGNORECASE,
)


def _parse_overall_score(eval_text, default_max: float = 40.0):
    """Extract the judge's overall score as ``(raw_score, normalized)``.

    ``normalized`` is ``raw_score`` divided by the denominator the judge wrote
    (or ``default_max`` when none is given), clamped to ``[0, 1]``. Returns
    ``(0.0, 0.0)`` when no overall score line can be found.
    """
    match = _OVERALL_SCORE_RE.search(eval_text or "")
    if match is None:
        return 0.0, 0.0

    raw_score = float(match.group(1))
    max_score = float(match.group(2)) if match.group(2) else default_max
    if max_score <= 0:
        return raw_score, 0.0
    return raw_score, min(max(raw_score / max_score, 0.0), 1.0)


@scorer(metrics=[accuracy()])
def coherence_scorer(judge_model_name: str = "openai/gpt-5"):
    """Scorer that asks a judge model to rate a stored response for coherence.

    The response being judged is read from the sample metadata key
    ``"response"`` (it is not produced during the eval run).

    Args:
        judge_model_name: Inspect model name of the judge. Defaults to
            ``"openai/gpt-5"``.

    Returns:
        An async ``score(state, target)`` function whose ``Score.value`` is the
        judge's overall score normalised to ``[0, 1]``; ``Score.answer`` is the
        raw overall score and ``Score.explanation`` holds the judge's full text.
    """
    async def score(state, target):
        # Get the generated response from metadata
        response = state.metadata.get("response", "")
        # input_text is always a string, even if the sample input were chat messages.
        prompt = state.input_text
        early_exit_rate = state.metadata.get("early_exit_rate", 0)

        # Evaluation prompt for the judge model. The overall score is defined
        # as the sum of the four ratings so it matches the "X/40" format line.
        eval_prompt = f"""
You are evaluating the coherence and quality of an AI assistant's response.

Original Question: {prompt}

AI Response: {response}

Please evaluate this response on the following criteria:
1. Coherence and logical flow (1-10 scale)
2. Completeness of the answer (1-10 scale)  
3. Clarity and readability (1-10 scale)
4. Absence of repetition or errors (1-10 scale)

Rate each criterion from 1-10 using the scale below; the overall score is the sum of the four ratings (out of 40):
- 1: Poor (major issues)
- 4: Fair (some issues)
- 7: Good (minor issues)
- 10: Excellent (no issues)

Format your response as:
Coherence: X/10
Completeness: X/10
Clarity: X/10
No Repetition: X/10
Overall: X/40

Brief explanation: [your reasoning]
"""

        # Use the judge model to evaluate
        judge_model = get_inspect_model(judge_model_name)
        eval_result = await judge_model.generate(eval_prompt)
        eval_text = eval_result.completion

        # Normalise by whatever denominator the judge actually wrote, so a
        # stray "Overall: 7/10" is scored as 0.7 rather than 7/40.
        overall_score, accuracy_score = _parse_overall_score(eval_text)

        return Score(
            value=accuracy_score,
            answer=f"{overall_score:g}",
            explanation=f"Coherence evaluation (Early exit rate: {early_exit_rate:.1%}): {eval_text}"
        )

    return score

In [13]:
def early_exit_coherence_eval(samples=None):
    """Inspect task to evaluate early exit model coherence.

    Args:
        samples: Optional list of ``Sample`` objects to evaluate. When omitted,
            the module-level ``_samples`` list is used.

    Returns:
        A ``Task`` over the samples, scored by ``coherence_scorer()``.

    Raises:
        ValueError: If there are no samples to evaluate.
    """
    dataset_samples = _samples if samples is None else samples

    if not dataset_samples:
        raise ValueError("Must call generate_samples_for_eval() first to create samples")

    return Task(
        dataset=MemoryDataset(dataset_samples),
        # `solver=` is the current Task argument; `plan=` is its deprecated name.
        solver=[
            system_message("You are evaluating AI responses for coherence and quality.")
        ],
        scorer=coherence_scorer()
    )

In [20]:
# `eval` here is inspect_ai's eval (imported in an earlier cell), which shadows the Python builtin.
# The first element of the returned value is used as the eval log by the report cells below.
eval_results = eval(early_exit_coherence_eval(), model='openai/gpt-5', log_dir='./eval_logs')

Output()

In [71]:
# Print every sample's generation, early-exit statistics and judge score,
# followed by the averages over the run.
log = eval_results[0]

# Fall back to an empty list when the log carries no samples, so the averages
# below cannot divide by zero or iterate over None.
samples = log.samples or []
coherence_values = []

for i, sample in enumerate(samples, 1):
    print(f"SAMPLE {i}")

    exit_rate = sample.metadata['early_exit_rate']
    total_tokens = sample.metadata['total_tokens']
    early_exits = sample.metadata['early_exits']
    layer_dist = sample.metadata['layer_distribution']

    # Strip the sentence-boundary special tokens before printing.
    model_response = sample.metadata['response']
    clean_response = model_response.replace('<｜begin▁of▁sentence｜>', '').replace('<｜end▁of▁sentence｜>', '')
    print(f"{clean_response}")
    print(f"   Exit Rate: {exit_rate:.1%} ({early_exits}/{total_tokens} tokens)")
    print(f"   Layer Distribution: {layer_dist}")

    # A sample may have no coherence score (e.g. scoring failed); report that
    # instead of raising, and leave it out of the coherence average.
    coherence_score = (sample.scores or {}).get('coherence_scorer')
    if coherence_score is None:
        print("   Score: n/a (sample was not scored)")
    else:
        coherence_values.append(coherence_score.value)
        print(f"   Score: {coherence_score.value:.2f}/1.0")
        for line in (coherence_score.explanation or "").split('\n'):
            if line.strip():
                print(f"     {line.strip()}")
    print()

if samples:
    avg_exit_rate = sum(sample.metadata['early_exit_rate'] for sample in samples) / len(samples)
    if coherence_values:
        avg_coherence = sum(coherence_values) / len(coherence_values)
        print(f"Average Coherence: {avg_coherence:.2f}/1.0")
    print(f"Average Early Exit Rate: {avg_exit_rate:.1%}")
else:
    print("No samples found in the eval log.")

SAMPLE 1
<｜Assistant｜> You are a helpful programming tutor.
<｜User｜> Explain the concept of recursion in programming.
<｜Assistant｜> 
Okay, so I need to explain how they teach precedence-related stuff in programming. Hmm, precedence refers to operators having certain priority priorities-sidednessedly Presenter/reacterouteedly. Wait, that sounds apropri suited forbidden️yen dul `, strategorientedness, etc. Maybe I should break it down differently.

Letting the user understand that " precedence-related stuff often refers exceptions or other related constructs rather than directly relatedness or other ` `s. So, perhaps I should clarify how precedence is taught in a programming context, maybe甚至是事情 like operator precedence priorities priorities themselves themselves. 

Wait, the user mentioned ` ` `a ` accents spelling `s make sense. So, ` word spoken ` ` `a `s `s `s actuallys ` `s ` `ors `s manageseringing `en ` ` `s `sInstantiateed compoundate `ins `en twentieseeernIngreceei ` ` `s️ `s `s 

In [21]:
# Print every sample's generation, early-exit statistics and judge score,
# followed by the averages over the run.
# NOTE(review): this cell produces the same report as the previous code cell;
# consider keeping only one of them (or moving the report into a function).
log = eval_results[0]

# Fall back to an empty list when the log carries no samples, so the averages
# below cannot divide by zero or iterate over None.
samples = log.samples or []
coherence_values = []

for i, sample in enumerate(samples, 1):
    print(f"SAMPLE {i}")

    exit_rate = sample.metadata['early_exit_rate']
    total_tokens = sample.metadata['total_tokens']
    early_exits = sample.metadata['early_exits']
    layer_dist = sample.metadata['layer_distribution']

    # Strip the sentence-boundary special tokens before printing.
    model_response = sample.metadata['response']
    clean_response = model_response.replace('<｜begin▁of▁sentence｜>', '').replace('<｜end▁of▁sentence｜>', '')
    print(f"{clean_response}")
    print(f"   Exit Rate: {exit_rate:.1%} ({early_exits}/{total_tokens} tokens)")
    print(f"   Layer Distribution: {layer_dist}")

    # A sample may have no coherence score (e.g. scoring failed); report that
    # instead of raising, and leave it out of the coherence average.
    coherence_score = (sample.scores or {}).get('coherence_scorer')
    if coherence_score is None:
        print("   Score: n/a (sample was not scored)")
    else:
        coherence_values.append(coherence_score.value)
        print(f"   Score: {coherence_score.value:.2f}/1.0")
        for line in (coherence_score.explanation or "").split('\n'):
            if line.strip():
                print(f"     {line.strip()}")
    print()

if samples:
    avg_exit_rate = sum(sample.metadata['early_exit_rate'] for sample in samples) / len(samples)
    if coherence_values:
        avg_coherence = sum(coherence_values) / len(coherence_values)
        print(f"Average Coherence: {avg_coherence:.2f}/1.0")
    print(f"Average Early Exit Rate: {avg_exit_rate:.1%}")
else:
    print("No samples found in the eval log.")

SAMPLE 1
<｜Assistant｜> You are a helpful programming tutor.
<｜User｜> Explain the concept of recursion in programming.
<｜Assistant｜> 
Okay, so I need to explain the concept of recursion in its own words. Hmm okay, recursion itself itself itself... Let me break it down.

First, what is recursion? From what I remember, it is often referred "\( recursive" orthon-generated by virtue of its recursive natureself. But wait, that's a bit circular. maybe I should think of it as a method where a function or other entities recursively ( meaning meaning meaningfully) refers back/update itselfg itself-generated itself.

So, in simpler terms, recursion often involves solving problems by breaking it down into smaller pieces. Like how you guys living or something? Or maybe说是 NPCc stuff happening recursively.

Let me think of a example. Like calculating parity classed by yourselfself-China -, like those recursive functions that determine whetheris  orthonormalerailer feedback loops back回到/remains the sa