Cell 1: Configuration and setup

In [18]:
# ===== EXPERIMENT CONFIGURATION =====
import datetime

# Single source of truth for the run. Later cells read these keys and add
# derived entries ("experiment_id" below, "system_prompt" when it is None).
CONFIG = {
    # Core experiment parameters
    "experiment_type": "generative",  # "discriminative" or "generative"
    "classification_type": "ternary",   # "binary" or "ternary"
    "dataset_strategy": "4N",          # "4N" or "3N" (generative only)
    "include_explanation": True,      # True or False (generative only)
    "include_eln": True,              # True or False (generative only)
    "solution_format": "nl",        # "dict" or "nl" (generative only)
    "model_name": "microsoft/phi-4-mini-instruct",  # or "Qwen/Qwen3-4B"

    # Prompting configuration
    "system_prompt": None,  # Will auto-generate if None, or use custom string
    "include_examples": False,
    "num_examples": 3,
    "example_strategy": "balanced",  # "balanced", "error_focused", "custom"

    # Training parameters
    "learning_rate": 2e-4,
    "num_epochs": 3,
    "batch_size": 8,
    "max_length": 512,
    "gradient_accumulation_steps": 4,

    # Infrastructure
    "use_lora": True,  # load_model() also switches on 4-bit quantization when this is True
    "lora_rank": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.1,

    # Paths and tokens
    # "base_dataset_dir": "/content/drive/MyDrive/sft_datasets",
    "base_dataset_dir": "../data/base-datasets-sanitized",
    "output_base_dir": "/content/drive/MyDrive/sft_experiments",
    # "hf_token": "your_huggingface_token_here",
    # "wandb_project": "math_error_classification",

    # Experiment tracking
    "save_to_hf": True,
    "save_locally": True,
    "use_wandb": False
}


def build_experiment_id(config, now=None):
    """Build a readable, timestamped identifier for one experiment run.

    The ID is ``<type>_<classification>[_<generative tags>]_<model>_<timestamp>``.
    The generative-only settings (dataset strategy, explanation flag, ELN flag,
    solution format) are included only for generative experiments, so a
    discriminative ID does not carry tags for options it does not use.

    Args:
        config: Experiment configuration dict (same keys as ``CONFIG``).
        now: Optional ``datetime.datetime`` used for the timestamp suffix;
            defaults to the current local time.

    Returns:
        The experiment ID string.
    """
    components = [
        config["experiment_type"][:4],      # "gene" or "disc"
        config["classification_type"][:3],  # "bin" or "ter"
    ]

    if config["experiment_type"] == "generative":
        components.extend([
            config["dataset_strategy"],
            "exp" if config["include_explanation"] else "no_exp",
            "eln" if config["include_eln"] else "no_eln",
            config["solution_format"],
        ])

    # Any model whose name does not contain "Qwen" is tagged "phi4".
    components.append("qwen" if "Qwen" in config["model_name"] else "phi4")

    if now is None:
        now = datetime.datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")

    return "_".join(c for c in components if c) + "_" + timestamp


# Generate experiment ID
experiment_id = build_experiment_id(CONFIG)
CONFIG["experiment_id"] = experiment_id

print(f"Experiment ID: {experiment_id}")
print("Configuration loaded successfully!")

Experiment ID: gene_ter_4N_exp_eln_nl_phi4_20250731_132248
Configuration loaded successfully!


Cell 2: Import dependencies

In [19]:
# # Core libraries
# import os
# import json
# import time
# import datetime
# from pathlib import Path
# from typing import List, Dict, Any, Optional
# import random
# import numpy as np

# # Data handling
# import pandas as pd
# from datasets import Dataset, load_dataset

# # ML libraries
# import torch
# from transformers import (
#     AutoTokenizer, 
#     AutoModelForCausalLM,
#     TrainingArguments, 
#     Trainer, 
#     DataCollatorForLanguageModeling,
#     BitsAndBytesConfig
# )
# from peft import (
#     LoraConfig, 
#     get_peft_model, 
#     TaskType, 
#     prepare_model_for_kbit_training
# )

# # Logging and tracking
# from tqdm.auto import tqdm

# # Google Drive mounting (for Colab)
# try:
#     from google.colab import drive
#     drive.mount('/content/drive')
#     print("Google Drive mounted successfully!")
# except ImportError:
#     print("Not running in Colab - skipping Drive mount")

import torch
import random
import numpy as np

# Set random seeds for reproducibility
def set_seeds(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator (default 42).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators are only seeded when a GPU is actually present.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seeds(42)
print("Dependencies imported and seeds set!")

Dependencies imported and seeds set!


Cell 3: System Prompt Generation

In [20]:
def generate_system_prompt(config):
    """Compose the system prompt that matches the experiment settings.

    Discriminative experiments get a fixed one-line prompt. Generative
    experiments get a JSON-oriented prompt assembled from the classification
    type and the optional erroneous-line / explanation fields.

    Args:
        config: Experiment configuration dict.

    Returns:
        The system prompt string.
    """
    if config["experiment_type"] == "discriminative":
        return "You are a mathematics tutor. Classify the given solution."

    # Generative prompt: collect the sentences, then concatenate them once.
    sentences = [
        "You are a mathematics tutor. Analyze the given solution and provide your assessment in JSON format."
    ]

    # Classification instruction
    wants_binary = config["classification_type"] == "binary"
    sentences.append(
        " Determine if the solution is 'correct' or 'flawed'."
        if wants_binary
        else " Classify as 'correct', 'conceptual_error', or 'computational_error'."
    )

    # Optional output fields
    requested = []
    if config["include_eln"]:
        requested.append(
            "identify the erroneous line number (e.g., 'L1', 'FA')"
            if config["solution_format"] == "dict"
            else "quote the full erroneous line text"
        )
    if config["include_explanation"]:
        requested.append("provide a brief explanation of any error")

    if requested:
        sentences.append(f" Also {', and '.join(requested)}.")

    sentences.append(" Respond only with valid JSON.")
    return "".join(sentences)

# Auto-generate system prompt if not provided.
# NOTE: the prompt is only generated while CONFIG["system_prompt"] is None. Once this
# cell has stored a prompt, re-running it keeps that prompt even if other CONFIG values
# were edited afterwards; set CONFIG["system_prompt"] back to None to regenerate it.
if CONFIG["system_prompt"] is None:
    CONFIG["system_prompt"] = generate_system_prompt(CONFIG)

print("System Prompt:")
print(CONFIG["system_prompt"])
print()

# Allow manual override: printed as a hint for the notebook user, nothing is changed here.
print("To customize the system prompt, run:")
print('CONFIG["system_prompt"] = "Your custom prompt here"')

System Prompt:
You are a mathematics tutor. Analyze the given solution and provide your assessment in JSON format. Classify as 'correct', 'conceptual_error', or 'computational_error'. Also quote the full erroneous line text, and provide a brief explanation of any error. Respond only with valid JSON.

To customize the system prompt, run:
CONFIG["system_prompt"] = "Your custom prompt here"


Cell 4: Example Manager

In [21]:
class ExampleManager:
    """Selects few-shot examples from a base dataset, sampling by distinct problem.

    Samples are grouped by their ``"error_type"`` ("correct",
    "conceptual_error" or "computational_error") and their ``"index"``
    (the problem identifier). Example selection then draws whole problems with
    the module-level ``random`` generator and returns the matching samples.

    Candidate problem pools are built in first-seen dataset order rather than
    from set intersections, so for a given seed the selection does not depend
    on set iteration order (which can differ between interpreter runs for
    string keys).

    Args:
        base_dataset: Either an object exposing ``to_dict('records')`` (such as
            a pandas DataFrame) or a list of sample dicts.
        config: Experiment configuration dict. Reads ``include_examples``,
            ``example_strategy``, ``num_examples``, ``dataset_strategy`` and
            ``classification_type``.

    Raises:
        KeyError: If a sample lacks ``"index"``/``"error_type"`` or carries an
            error type outside the three known labels.
    """

    def __init__(self, base_dataset, config):
        # Convert DataFrame to list of dicts if needed
        if hasattr(base_dataset, 'to_dict'):  # It's a DataFrame
            self.samples = base_dataset.to_dict('records')
        else:
            self.samples = base_dataset  # Already a list of dicts

        self.config = config
        self._prepare_examples_by_problem()

    def _prepare_examples_by_problem(self):
        """Organizes samples by problem index and error type for problem-based sampling"""
        self.problems_by_type = {
            "correct": {},
            "conceptual_error": {},
            "computational_error": {}
        }

        # Group samples by error type, then by problem index (dicts keep insertion order)
        for sample in self.samples:
            problem_index = sample["index"]
            error_type = sample["error_type"]
            self.problems_by_type[error_type].setdefault(problem_index, []).append(sample)

        # Lists of problem indices per error type, in first-seen order
        self.problem_indices_by_type = {
            error_type: list(by_problem)
            for error_type, by_problem in self.problems_by_type.items()
        }

        print(f"Problems by type: {[(k, len(v)) for k, v in self.problem_indices_by_type.items()]}")

    def get_examples(self):
        """Returns appropriate few-shot examples based on distinct problems.

        Returns an empty list when examples are disabled or the configured
        strategy is neither "balanced" nor "error_focused".
        """
        if not self.config["include_examples"]:
            return []

        strategy = self.config["example_strategy"]
        num_problems = self.config["num_examples"]  # Number of distinct problems

        if strategy == "balanced":
            return self._get_balanced_examples(num_problems)
        elif strategy == "error_focused":
            return self._get_error_focused_examples(num_problems)
        else:
            return []

    def _get_samples_for_problem(self, problem_index, error_types_needed):
        """Get the first sample of each requested error type for one problem.

        Error types the problem has no sample for are silently skipped.
        """
        samples = []
        for error_type in error_types_needed:
            if problem_index in self.problems_by_type[error_type]:
                # Take the first sample of this error type for this problem
                samples.append(self.problems_by_type[error_type][problem_index][0])
        return samples

    def _problems_with_all(self, error_types):
        """Problem indices having a sample of every given error type, in dataset order."""
        first_type, *other_types = error_types
        return [
            problem_index
            for problem_index in self.problems_by_type[first_type]
            if all(problem_index in self.problems_by_type[t] for t in other_types)
        ]

    @staticmethod
    def _sample_problems(candidates, n):
        """Randomly pick up to ``n`` distinct problems from ``candidates``."""
        return random.sample(candidates, min(n, len(candidates)))

    def _collect(self, problem_indices, error_types):
        """Concatenate the samples of ``error_types`` for each problem, in order."""
        collected = []
        for problem_index in problem_indices:
            collected.extend(self._get_samples_for_problem(problem_index, error_types))
        return collected

    def _get_balanced_examples(self, n):
        """Gets balanced representation across error types, sampling by distinct problems.

        3N: draws ``n`` problems that have all three versions. Binary
        classification keeps the correct version plus one randomly chosen error
        version; ternary keeps all three.
        4N: draws problems with a correct + conceptual pair and problems with a
        correct + computational pair. Binary splits ``n`` as ``n // 2`` and the
        remainder; ternary rounds ``n`` up to an even number and takes half of
        each. Any other dataset strategy yields no examples.
        """
        examples = []
        dataset_strategy = self.config["dataset_strategy"]
        is_binary = self.config["classification_type"] == "binary"

        if dataset_strategy == "3N":
            complete_problems = self._problems_with_all(
                ["correct", "conceptual_error", "computational_error"]
            )
            selected_problems = self._sample_problems(complete_problems, n)

            for problem_index in selected_problems:
                if is_binary:
                    # Correct version plus one randomly chosen error version
                    examples.extend(self._get_samples_for_problem(problem_index, ["correct"]))
                    error_type = random.choice(["conceptual_error", "computational_error"])
                    examples.extend(self._get_samples_for_problem(problem_index, [error_type]))
                else:
                    # All 3 versions of each selected problem
                    examples.extend(self._get_samples_for_problem(
                        problem_index,
                        ["correct", "conceptual_error", "computational_error"]
                    ))

        elif dataset_strategy == "4N":
            if is_binary:
                n_conceptual = n // 2
                n_computational = n - n_conceptual
            else:
                # Make n even so both error types get the same number of problems
                n_conceptual = n_computational = (n + (n % 2)) // 2

            conceptual_problems = self._problems_with_all(["correct", "conceptual_error"])
            computational_problems = self._problems_with_all(["correct", "computational_error"])

            # Conceptual problems are drawn first, then computational ones
            selected_conceptual = self._sample_problems(conceptual_problems, n_conceptual)
            selected_computational = self._sample_problems(computational_problems, n_computational)

            examples.extend(self._collect(selected_conceptual, ["correct", "conceptual_error"]))
            examples.extend(self._collect(selected_computational, ["correct", "computational_error"]))

        return examples

    def _get_error_focused_examples(self, n):
        """Gets examples focused on error types, sampling by distinct problems.

        3N: for each drawn problem, both error versions are included and the
        correct version is added with probability 0.33.
        4N: for each drawn problem, the correct version is followed by the
        conceptual error version if one exists, otherwise the computational one.
        Any other dataset strategy yields no examples.
        """
        examples = []
        dataset_strategy = self.config["dataset_strategy"]

        if dataset_strategy == "3N":
            complete_problems = self._problems_with_all(
                ["correct", "conceptual_error", "computational_error"]
            )
            selected_problems = self._sample_problems(complete_problems, n)

            for problem_index in selected_problems:
                examples.extend(self._get_samples_for_problem(
                    problem_index,
                    ["conceptual_error", "computational_error"]
                ))

                # Add correct version for only some problems (1/3 chance)
                if random.random() < 0.33:
                    examples.extend(self._get_samples_for_problem(problem_index, ["correct"]))

        elif dataset_strategy == "4N":
            conceptual = self.problems_by_type["conceptual_error"]
            computational = self.problems_by_type["computational_error"]

            # Problems with a correct version and at least one error version
            error_problems = [
                problem_index
                for problem_index in self.problems_by_type["correct"]
                if problem_index in conceptual or problem_index in computational
            ]
            selected_problems = self._sample_problems(error_problems, n)

            for problem_index in selected_problems:
                # Add correct version
                examples.extend(self._get_samples_for_problem(problem_index, ["correct"]))

                # Add the error version that exists for this problem (conceptual wins)
                if problem_index in conceptual:
                    examples.extend(self._get_samples_for_problem(problem_index, ["conceptual_error"]))
                elif problem_index in computational:
                    examples.extend(self._get_samples_for_problem(problem_index, ["computational_error"]))

        return examples

print("Updated ExampleManager class loaded!")

Updated ExampleManager class loaded!


Cell 5: Dataset Loading and Formatting Functions

In [22]:
import json
import pandas as pd
from pathlib import Path

def load_base_dataset(config):
    """Read the sanitized base dataset CSV for the configured dataset strategy.

    The file is expected at
    ``<base_dataset_dir>/base_<dataset_strategy>_dataset_sanitized.csv``.

    Args:
        config: Experiment configuration dict; reads ``dataset_strategy`` and
            ``base_dataset_dir``.

    Returns:
        A pandas DataFrame with the CSV contents.

    Raises:
        FileNotFoundError: If the expected CSV file does not exist.
    """
    strategy = config["dataset_strategy"]
    csv_path = Path(config["base_dataset_dir"]).joinpath(f"base_{strategy}_dataset_sanitized.csv")

    if csv_path.exists():
        frame = pd.read_csv(csv_path)
        print(f"Loaded base {strategy} dataset with {len(frame)} samples")
        return frame

    raise FileNotFoundError(f"Base dataset not found: {csv_path}")

def format_solution(sample, config):
    """Formats the solution text of a sample according to ``config["solution_format"]``.

    The text is taken from ``wrong_answer`` and falls back to
    ``correct_answer`` when ``wrong_answer`` is absent or is not a string.
    Rows coming from a DataFrame always carry the ``wrong_answer`` key, with
    NaN for empty cells, so the fallback has to look at the value rather than
    only at key presence.

    Args:
        sample: Mapping with the sample fields.
        config: Experiment configuration dict; reads ``solution_format``.

    Returns:
        For "dict" format, a JSON object string (indent 2) mapping "L<k>" to
        each non-empty line except the last and "FA" to the last line. For any
        other format, the stripped solution text.
    """
    solution_text = ''
    for field in ('wrong_answer', 'correct_answer'):
        candidate = sample.get(field)
        if isinstance(candidate, str):  # skips missing keys, None and NaN cells
            solution_text = candidate
            break

    if config["solution_format"] == "dict":
        lines = solution_text.strip().split('\n')
        solution = {}
        # Labels use the raw line position, so blank lines leave gaps in the
        # numbering; format_expected_output indexes lines the same way.
        for position, line in enumerate(lines[:-1], start=1):
            if line.strip():  # Skip empty lines
                solution[f"L{position}"] = line.strip()
        if lines[-1].strip():
            solution["FA"] = lines[-1].strip()
        return json.dumps(solution, indent=2)

    return solution_text.strip()

def _resolve_erroneous_line(eln, solution_text):
    """Map a line label ("L<k>" or "FA") to the text of that solution line.

    Falls back to returning ``eln`` unchanged when the label is not
    recognised, is out of range, or cannot be parsed / applied.
    """
    try:
        if eln.startswith('L'):
            # "L3" -> zero-based index 2
            line_num = int(eln[1:]) - 1
            solution_lines = solution_text.strip().split('\n')
            if 0 <= line_num < len(solution_lines):
                return solution_lines[line_num].strip()
            return eln  # Fallback to the ELN itself
        if eln == 'FA':
            return solution_text.strip().split('\n')[-1].strip()
        return eln
    except (AttributeError, TypeError, ValueError):
        # Non-string label or solution text, or a non-numeric "L..." suffix
        return eln


def format_expected_output(sample, config):
    """Creates the expected JSON output string for a sample.

    Keys are emitted in this order:
      * ``verdict``: "correct"/"flawed" for binary classification, otherwise
        the sample's ``error_type``.
      * ``erroneous_line_number`` (dict solution format) or ``erroneous_line``
        (other formats), only when ``include_eln`` is set. ``None`` for correct
        samples and for samples without a line label.
      * ``explanation``: only when ``include_explanation`` is set. ``None`` for
        correct samples or when no explanation is available.

    A NaN line label (an empty CSV cell) is normalised to ``None`` so the
    result never contains the non-standard JSON token ``NaN``.

    Args:
        sample: Mapping with the sample fields; must contain ``error_type``.
        config: Experiment configuration dict.

    Returns:
        The JSON-encoded expected output.
    """
    is_correct = sample["error_type"] == "correct"
    output = {}

    # Verdict
    if config["classification_type"] == "binary":
        output["verdict"] = "correct" if is_correct else "flawed"
    else:
        output["verdict"] = sample["error_type"]

    # ELN (Erroneous Line Number)
    if config["include_eln"]:
        eln = sample.get("erroneous_line_number")
        if not isinstance(eln, str) and pd.isna(eln):
            eln = None

        if config["solution_format"] == "dict":
            output["erroneous_line_number"] = None if is_correct else eln
        elif is_correct or not eln:
            output["erroneous_line"] = None
        else:
            # Natural-language format quotes the erroneous line text itself
            output["erroneous_line"] = _resolve_erroneous_line(eln, sample.get('wrong_answer', ''))

    # Explanation
    if config["include_explanation"]:
        explanation = sample.get("explanation")
        output["explanation"] = explanation if pd.notna(explanation) and not is_correct else None

    return json.dumps(output)

def format_user_message(sample, config):
    """Render a sample as the user turn: the question followed by the formatted solution."""
    question = sample['question']
    answer = format_solution(sample, config)
    return f"### Question:\n{question}\n\n### Answer:\n{answer}"

def create_sample_messages(sample, examples, config):
    """Assemble the chat messages for one sample.

    The conversation starts with the system prompt, optionally continues with
    one user/assistant pair per few-shot example (only when
    ``config["include_examples"]`` is truthy), and ends with the user message
    for ``sample`` itself.

    Args:
        sample: Mapping with the fields of the sample to be classified.
        examples: Iterable of few-shot example samples.
        config: Experiment configuration dict.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.
    """
    conversation = [{"role": "system", "content": config["system_prompt"]}]

    # Few-shot demonstrations: each example becomes a question/answer exchange
    if config["include_examples"]:
        for shot in examples:
            shot_prompt = format_user_message(shot, config)
            shot_answer = format_expected_output(shot, config)
            conversation += [
                {"role": "user", "content": shot_prompt},
                {"role": "assistant", "content": shot_answer},
            ]

    # The sample that actually has to be answered
    conversation.append({"role": "user", "content": format_user_message(sample, config)})
    return conversation

print("Updated formatting functions loaded!")

Updated formatting functions loaded!


Cell 6: Dataset Preparation

In [23]:
def prepare_dataset(config):
    """Load the base dataset and turn every row into a chat-formatted record.

    Each record holds an ``id`` (the row's ``id`` field, or ``sample_<n>``
    with ``n`` the zero-based position), the prompt ``messages``, the
    ``expected_output`` JSON string and a small ``metadata`` dict.

    Notes:
        * The train/eval split is positional: the first 80% of rows (in file
          order, no shuffling) become training data, the rest evaluation data.
        * The few-shot examples are drawn from the same base dataset and are
          not removed from it, so an example can also occur as a sample.

    Args:
        config: Experiment configuration dict.

    Returns:
        Tuple ``(train_data, eval_data, examples)``.
    """
    base_frame = load_base_dataset(config)

    # Few-shot examples are selected once and shared by every sample
    examples = ExampleManager(base_frame, config).get_examples()
    print(f"Using {len(examples)} few-shot examples")

    formatted = []
    for position, (_, row) in enumerate(base_frame.iterrows()):
        record = row.to_dict()  # plain dict is easier to pass around than a Series
        messages = create_sample_messages(record, examples, config)
        expected_output = format_expected_output(record, config)
        metadata = {
            "error_type": record["error_type"],
            "tier": record.get("tier", "unknown"),
            "source": record.get("source", "unknown")
        }
        formatted.append({
            "id": record.get("id", f"sample_{position}"),
            "messages": messages,
            "expected_output": expected_output,
            "metadata": metadata
        })

    # 80/20 positional split
    cut = int(0.8 * len(formatted))
    train_data, eval_data = formatted[:cut], formatted[cut:]

    print(f"Dataset prepared: {len(train_data)} training, {len(eval_data)} evaluation samples")

    return train_data, eval_data, examples

Cell 7: Model and Tokenizer Loading

In [24]:
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    AutoModelForSequenceClassification
)

from transformers.utils.quantization_config import BitsAndBytesConfig

from peft import (
    LoraConfig, 
    get_peft_model, 
    TaskType, 
    prepare_model_for_kbit_training
)

def load_tokenizer(model_name):
    """Load the tokenizer for ``model_name`` and make sure it has a padding token.

    When the tokenizer defines no pad token, its EOS token is used for padding.

    Args:
        model_name: Model identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The configured tokenizer.
    """
    print(f"Loading tokenizer: {model_name}")

    tok = AutoTokenizer.from_pretrained(model_name)
    missing_pad = tok.pad_token is None
    if missing_pad:
        tok.pad_token = tok.eos_token

    print("✓ Tokenizer loaded successfully!")
    return tok

def load_model(model_name, config):
    """Loads the model for the configured experiment, optionally wrapped with LoRA.

    When ``config["use_lora"]`` is true the base model is loaded 4-bit
    quantized (NF4, double quantization, float16 compute), prepared for k-bit
    training and wrapped in a LoRA adapter; otherwise it is loaded without a
    quantization config.

    Discriminative experiments load a sequence-classification model with 2
    (binary) or 3 labels; all other experiments load a causal LM.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
        config: Experiment configuration dict. Reads ``use_lora``,
            ``experiment_type``, ``classification_type``, ``lora_rank``,
            ``lora_alpha`` and ``lora_dropout``. The optional key
            ``attn_implementation`` selects the attention backend of the causal
            LM and defaults to "flash_attention_2".

    Returns:
        The loaded (and, with LoRA, PEFT-wrapped) model.
    """
    print(f"Loading model: {model_name}")

    use_lora = config["use_lora"]
    is_discriminative = config["experiment_type"] == "discriminative"

    # Configure quantization if using LoRA
    bnb_config = None
    if use_lora:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True
        )

    # Load model based on experiment type
    if is_discriminative:
        # For discriminative, we need a classification model
        num_labels = 2 if config["classification_type"] == "binary" else 3
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True
        )
    else:
        # For generative, use causal LM. The attention backend can be
        # overridden through the config for environments where the default
        # backend is not available.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation=config.get("attn_implementation", "flash_attention_2")
        )

    # Apply LoRA if configured
    if use_lora:
        model = prepare_model_for_kbit_training(model)

        # Classification adapts the attention projections only; the causal LM
        # additionally adapts the MLP projections.
        if is_discriminative:
            task_type = TaskType.SEQ_CLS
            target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
        else:
            task_type = TaskType.CAUSAL_LM
            target_modules = ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

        lora_config = LoraConfig(
            task_type=task_type,
            r=config["lora_rank"],
            lora_alpha=config["lora_alpha"],
            lora_dropout=config["lora_dropout"],
            target_modules=target_modules,
            bias="none"
        )

        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    print("✓ Model loaded successfully!")
    print(f"✓ Model device: {next(model.parameters()).device}")

    return model

def apply_chat_template(messages, tokenizer, add_generation_prompt=False, tokenize=True, **kwargs):
    """
    Render chat messages with the tokenizer's chat template, optionally tokenizing

    The template is always applied in text mode first; tokenization, when
    requested, is a separate call on the rendered text so that ``kwargs``
    reach the tokenizer call unchanged.

    Args:
        messages: List of message dictionaries with 'role' and 'content' keys
        tokenizer: Object providing ``apply_chat_template`` and being callable on text
        add_generation_prompt: Forwarded to the chat template (used for inference)
        tokenize: Return the tokenizer output (True) or the rendered text (False)
        **kwargs: Extra arguments for the tokenizer call (like return_tensors, max_length, etc.)

    Returns:
        If tokenize=True: whatever ``tokenizer(text, **kwargs)`` returns
        If tokenize=False: the rendered text string
    """
    rendered = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=add_generation_prompt
    )
    return tokenizer(rendered, **kwargs) if tokenize else rendered

def load_model_and_tokenizer(config):
    """
    Load the tokenizer and the model named by ``config["model_name"]``

    The tokenizer is loaded first, then the model; both steps delegate to
    ``load_tokenizer`` and ``load_model``.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    name = config["model_name"]
    tokenizer = load_tokenizer(name)
    model = load_model(name, config)
    return model, tokenizer

In [34]:
# ===== TOKEN COUNT AND CHARACTER ANALYSIS =====
# For every combination of tokenizer, solution format, dataset strategy,
# classification type and few-shot example count, this cell measures the size
# (characters and tokens) of the formatted prompt and of the expected output,
# prints a per-combination summary, and finally builds a comparison table.
print("="*100)
print("COMPREHENSIVE TOKEN COUNT AND CHARACTER ANALYSIS")
print("="*100)

import itertools
import json
import statistics
import traceback
import pandas as pd
from collections import defaultdict

# Configuration combinations to test
models = ["microsoft/phi-4-mini-instruct", "Qwen/Qwen3-4B"]
solution_formats = ["nl", "dict"]
dataset_strategies = ["3N", "4N"]
classification_types = ["binary", "ternary"]
num_examples_list = [0, 1, 2]  # 0 for no examples, 1 and 2 for few-shot

# Storage for all statistics
all_stats = []

# Defined up front so the final display step cannot raise NameError when the
# analysis aborts before the comparison table is built.
df_sorted = None


def _describe(values):
    """Return min / max / mean / median of a non-empty list of numbers."""
    if not values:
        raise ValueError("no samples available to summarize")
    return {
        'min': min(values),
        'max': max(values),
        'mean': sum(values) / len(values),
        # statistics.median averages the two middle values for even-length input
        'median': statistics.median(values)
    }


def _measure_sample(sample, examples, test_config, tokenizer):
    """Return (prompt_chars, prompt_tokens, expected_chars, expected_tokens) for one sample."""
    messages = create_sample_messages(sample, examples, test_config)
    expected_output = format_expected_output(sample, test_config)

    formatted_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    tokenized_prompt = tokenizer(formatted_prompt, return_tensors="pt")
    tokenized_expected = tokenizer(expected_output, return_tensors="pt")

    return (
        len(formatted_prompt),
        len(tokenized_prompt["input_ids"][0]),
        len(expected_output),
        len(tokenized_expected["input_ids"][0]),
    )


def _print_range(label, summary):
    """Print one 'min-max (avg)' line for a summary produced by _describe."""
    print(f"  {label}: {summary['min']}-{summary['max']} "
          f"(avg: {summary['mean']:.0f})")


try:
    # Load a sample of data for analysis
    print("Loading sample dataset for analysis...")
    sample_config = CONFIG.copy()
    sample_config["dataset_strategy"] = "4N"  # Start with 4N
    base_data = load_base_dataset(sample_config)
    sample_data = base_data.head(50)  # Use 50 samples for analysis
    print(f"✓ Loaded {len(sample_data)} samples for analysis\n")

    # One 50-row slice per dataset strategy, loaded at most once.
    # NOTE(review): assumes load_base_dataset only depends on "dataset_strategy"
    # (the 4N slice was already shared across combinations) -- confirm.
    data_by_strategy = {sample_config["dataset_strategy"]: sample_data}

    # Same iteration order as nesting the four loops in this order
    combinations = list(itertools.product(
        solution_formats, dataset_strategies, classification_types, num_examples_list
    ))
    total_combinations = len(models) * len(combinations)
    combo_count = 0

    for model_name in models:
        # Load tokenizer once per model
        print(f"Loading tokenizer for {model_name}...")
        tokenizer = load_tokenizer(model_name)

        for solution_format, dataset_strategy, classification_type, num_examples in combinations:
            combo_count += 1

            print(f"\n{'='*80}")
            print(f"COMBINATION {combo_count}/{total_combinations}")
            print(f"Model: {model_name.split('/')[-1]}")
            print(f"Solution Format: {solution_format}")
            print(f"Dataset Strategy: {dataset_strategy}")
            print(f"Classification Type: {classification_type}")
            print(f"Num Examples: {num_examples}")
            print(f"{'='*80}")

            # Create test configuration. The system prompt is generated from the
            # base CONFIG with only the format and classification type overridden.
            test_config = CONFIG.copy()
            prompt_config = {
                **test_config,
                'solution_format': solution_format,
                'classification_type': classification_type
            }
            test_config.update({
                'model_name': model_name,
                'solution_format': solution_format,
                'dataset_strategy': dataset_strategy,
                'classification_type': classification_type,
                'include_examples': num_examples > 0,
                'num_examples': num_examples,
                'system_prompt': generate_system_prompt(prompt_config)
            })

            # A failure in one combination is reported and the sweep continues
            try:
                # Load appropriate dataset (cached per strategy)
                if dataset_strategy not in data_by_strategy:
                    data_by_strategy[dataset_strategy] = load_base_dataset(test_config).head(50)
                analysis_data = data_by_strategy[dataset_strategy]

                # Create example manager
                example_manager = ExampleManager(analysis_data, test_config)
                examples = example_manager.get_examples()

                print(f"Examples generated: {len(examples)}")

                # Analyze a subset of samples
                analysis_samples = analysis_data.head(20).to_dict('records')

                # Statistics storage
                stats = {
                    'config': {
                        'model': model_name.split('/')[-1],
                        'solution_format': solution_format,
                        'dataset_strategy': dataset_strategy,
                        'classification_type': classification_type,
                        'num_examples': num_examples,
                        'actual_examples_used': len(examples)
                    },
                    'prompt_stats': [],
                    'expected_output_stats': [],
                    'by_error_type': defaultdict(list)
                }

                print("Analyzing samples...")
                for sample in analysis_samples:
                    p_chars, p_tokens, e_chars, e_tokens = _measure_sample(
                        sample, examples, test_config, tokenizer
                    )
                    error_type = sample['error_type']

                    stats['prompt_stats'].append({
                        'chars': p_chars,
                        'tokens': p_tokens,
                        'chars_per_token': p_chars / p_tokens if p_tokens > 0 else 0,
                        'error_type': error_type
                    })
                    stats['expected_output_stats'].append({
                        'chars': e_chars,
                        'tokens': e_tokens,
                        'chars_per_token': e_chars / e_tokens if e_tokens > 0 else 0,
                        'error_type': error_type
                    })
                    stats['by_error_type'][error_type].append({
                        'prompt_chars': p_chars,
                        'prompt_tokens': p_tokens,
                        'expected_chars': e_chars,
                        'expected_tokens': e_tokens
                    })

                # Calculate summary statistics
                stats['summary'] = {
                    'prompt': {
                        'chars': _describe([s['chars'] for s in stats['prompt_stats']]),
                        'tokens': _describe([s['tokens'] for s in stats['prompt_stats']])
                    },
                    'expected_output': {
                        'chars': _describe([s['chars'] for s in stats['expected_output_stats']]),
                        'tokens': _describe([s['tokens'] for s in stats['expected_output_stats']])
                    }
                }

                # Print summary for this combination
                print(f"\nSUMMARY STATISTICS:")
                print(f"Samples analyzed: {len(analysis_samples)}")
                print(f"Examples in prompts: {len(examples)}")

                print(f"\nPROMPT STATISTICS:")
                _print_range("Characters", stats['summary']['prompt']['chars'])
                _print_range("Tokens", stats['summary']['prompt']['tokens'])

                print(f"\nEXPECTED OUTPUT STATISTICS:")
                _print_range("Characters", stats['summary']['expected_output']['chars'])
                _print_range("Tokens", stats['summary']['expected_output']['tokens'])

                # Error type breakdown
                print(f"\nBY ERROR TYPE:")
                for error_type, error_stats in stats['by_error_type'].items():
                    if error_stats:
                        avg_prompt_tokens = sum(s['prompt_tokens'] for s in error_stats) / len(error_stats)
                        avg_expected_tokens = sum(s['expected_tokens'] for s in error_stats) / len(error_stats)
                        print(f"  {error_type}: {len(error_stats)} samples, "
                              f"avg prompt tokens: {avg_prompt_tokens:.0f}, "
                              f"avg expected tokens: {avg_expected_tokens:.0f}")

                # Store for final comparison
                all_stats.append(stats)

                # Show a sample prompt for first few combinations
                if combo_count <= 3:
                    print(f"\nSAMPLE FORMATTED PROMPT (first 500 chars):")
                    sample_messages = create_sample_messages(analysis_samples[0], examples, test_config)
                    sample_formatted = tokenizer.apply_chat_template(
                        sample_messages, tokenize=False, add_generation_prompt=True
                    )
                    print(f"{sample_formatted[:500]}...")
                    print(f"\nSAMPLE EXPECTED OUTPUT:")
                    print(format_expected_output(analysis_samples[0], test_config))

            except Exception as e:
                print(f"❌ Error in combination: {e}")

    # Final comparison table
    print(f"\n{'='*100}")
    print("COMPREHENSIVE COMPARISON TABLE")
    print(f"{'='*100}")

    if not all_stats:
        # Sorting / aggregating a column-less frame would raise KeyError
        print("No configuration was analyzed successfully; nothing to compare.")
    else:
        # Create comparison DataFrame
        comparison_data = []
        for stat in all_stats:
            row = {
                'Model': stat['config']['model'],
                'Format': stat['config']['solution_format'],
                'Strategy': stat['config']['dataset_strategy'],
                'Classification': stat['config']['classification_type'],
                'Examples': stat['config']['num_examples'],
                'Actual_Examples': stat['config']['actual_examples_used'],
                'Avg_Prompt_Chars': int(stat['summary']['prompt']['chars']['mean']),
                'Avg_Prompt_Tokens': int(stat['summary']['prompt']['tokens']['mean']),
                'Avg_Expected_Chars': int(stat['summary']['expected_output']['chars']['mean']),
                'Avg_Expected_Tokens': int(stat['summary']['expected_output']['tokens']['mean']),
                'Max_Prompt_Tokens': stat['summary']['prompt']['tokens']['max'],
                'Max_Expected_Tokens': stat['summary']['expected_output']['tokens']['max']
            }
            comparison_data.append(row)

        # Convert to DataFrame and sort
        df = pd.DataFrame(comparison_data)
        df_sorted = df.sort_values(['Model', 'Examples', 'Format', 'Strategy', 'Classification'])

        # Key insights
        print(f"\n{'='*100}")
        print("KEY INSIGHTS")
        print(f"{'='*100}")

        # Token count ranges
        min_prompt_tokens = df['Avg_Prompt_Tokens'].min()
        max_prompt_tokens = df['Avg_Prompt_Tokens'].max()
        min_expected_tokens = df['Avg_Expected_Tokens'].min()
        max_expected_tokens = df['Avg_Expected_Tokens'].max()

        print(f"Prompt Token Range: {min_prompt_tokens} - {max_prompt_tokens} tokens")
        print(f"Expected Output Token Range: {min_expected_tokens} - {max_expected_tokens} tokens")
        print(f"Maximum Single Prompt: {df['Max_Prompt_Tokens'].max()} tokens")
        print(f"Maximum Single Expected Output: {df['Max_Expected_Tokens'].max()} tokens")

        # Impact of examples
        no_examples = df[df['Examples'] == 0]['Avg_Prompt_Tokens'].mean()
        one_example = df[df['Examples'] == 1]['Avg_Prompt_Tokens'].mean()
        two_examples = df[df['Examples'] == 2]['Avg_Prompt_Tokens'].mean()

        print(f"\nImpact of Examples on Prompt Length:")
        print(f"  No examples: {no_examples:.0f} tokens")
        print(f"  1 example: {one_example:.0f} tokens")
        print(f"  2 examples: {two_examples:.0f} tokens")
        print(f"  Token increase per example: ~{(two_examples - no_examples) / 2:.0f} tokens")

        # Model differences
        print(f"\nModel Differences (average across all configs):")
        for model in df['Model'].unique():
            model_avg = df[df['Model'] == model]['Avg_Prompt_Tokens'].mean()
            print(f"  {model}: {model_avg:.0f} avg prompt tokens")

        # Format differences
        print(f"\nFormat Differences (average across all configs):")
        for fmt in df['Format'].unique():
            fmt_avg_prompt = df[df['Format'] == fmt]['Avg_Prompt_Tokens'].mean()
            fmt_avg_expected = df[df['Format'] == fmt]['Avg_Expected_Tokens'].mean()
            print(f"  {fmt}: {fmt_avg_prompt:.0f} prompt tokens, {fmt_avg_expected:.0f} expected tokens")

        print(f"\n✅ ANALYSIS COMPLETED!")
        print(f"Analyzed {len(all_stats)} configurations successfully")

except Exception as e:
    print(f"\n❌ ERROR DURING ANALYSIS:")
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {str(e)}")

    traceback.print_exc()

# Print formatted table (only when one was actually built)
if df_sorted is not None:
    display(df_sorted)
else:
    print("No comparison table available to display.")

COMPREHENSIVE TOKEN COUNT AND CHARACTER ANALYSIS
Loading sample dataset for analysis...
Loaded base 4N dataset with 7524 samples
✓ Loaded 50 samples for analysis

Loading tokenizer for microsoft/phi-4-mini-instruct...
Loading tokenizer: microsoft/phi-4-mini-instruct
✓ Tokenizer loaded successfully!

COMBINATION 1/48
Model: phi-4-mini-instruct
Solution Format: nl
Dataset Strategy: 3N
Classification Type: binary
Num Examples: 0
Loaded base 3N dataset with 5319 samples
Problems by type: [('correct', 16), ('conceptual_error', 17), ('computational_error', 17)]
Examples generated: 0
Analyzing samples...

SUMMARY STATISTICS:
Samples analyzed: 20
Examples in prompts: 0

PROMPT STATISTICS:
  Characters: 544-807 (avg: 681)
  Tokens: 130-223 (avg: 167)

EXPECTED OUTPUT STATISTICS:
  Characters: 67-221 (avg: 144)
  Tokens: 21-72 (avg: 45)

BY ERROR TYPE:
  computational_error: 7 samples, avg prompt tokens: 168, avg expected tokens: 57
  conceptual_error: 7 samples, avg prompt tokens: 167, avg expe

Unnamed: 0,Model,Format,Strategy,Classification,Examples,Actual_Examples,Avg_Prompt_Chars,Avg_Prompt_Tokens,Avg_Expected_Chars,Avg_Expected_Tokens,Max_Prompt_Tokens,Max_Expected_Tokens
36,Qwen3-4B,dict,3N,binary,0,0,785,228,121,38,324,62
39,Qwen3-4B,dict,3N,ternary,0,0,802,231,129,38,327,62
42,Qwen3-4B,dict,4N,binary,0,0,886,252,100,32,332,60
45,Qwen3-4B,dict,4N,ternary,0,0,903,255,105,32,335,60
24,Qwen3-4B,nl,3N,binary,0,0,716,191,144,48,271,78
27,Qwen3-4B,nl,3N,ternary,0,0,733,194,152,48,274,78
30,Qwen3-4B,nl,4N,binary,0,0,814,214,117,40,287,78
33,Qwen3-4B,nl,4N,ternary,0,0,831,217,122,40,290,78
37,Qwen3-4B,dict,3N,binary,1,2,1698,526,121,38,622,62
40,Qwen3-4B,dict,3N,ternary,1,3,2954,846,129,38,942,62


Cell 8: Output Directory Setup

In [26]:
def setup_output_directory(config):
    """
    Create the experiment's output folder tree and snapshot the configuration.

    The experiment root is config["output_base_dir"] / config["experiment_id"].
    Four subfolders ("baseline", "training", "final", "checkpoints") are created
    under it (existing folders are left alone), and the full config is written
    to "config.json" in the root; values json cannot encode are stringified.

    Args:
        config: Experiment configuration with "output_base_dir" and "experiment_id".

    Returns:
        Path: The experiment root directory.
    """
    experiment_root = Path(config["output_base_dir"]) / config["experiment_id"]

    for folder_name in ("baseline", "training", "final", "checkpoints"):
        experiment_root.joinpath(folder_name).mkdir(parents=True, exist_ok=True)

    # Snapshot the configuration next to the results it produced
    with open(experiment_root / "config.json", 'w') as config_file:
        json.dump(config, config_file, indent=2, default=str)

    print(f"Output directory created: {experiment_root}")
    return experiment_root

# Setup output directory
# Creates the experiment folder under CONFIG["output_base_dir"] and records its
# location in CONFIG["output_dir"].
# NOTE(review): the recorded run of this cell failed with
# "OSError: [Errno 30] Read-only file system: '/content'", so the configured
# output_base_dir was not writable in that environment and CONFIG["output_dir"]
# was never set there -- point output_base_dir at a writable location first.
output_dir = setup_output_directory(CONFIG)
CONFIG["output_dir"] = str(output_dir)

OSError: [Errno 30] Read-only file system: '/content'

Cell 9: Inference Functions

In [28]:
def prepare_inference_batch(messages_batch, tokenizer, max_length=1024):
    """
    Render and tokenize each conversation individually (no padding yet).

    Every conversation is rendered with the chat template (generation prompt
    appended) and tokenized on its own with truncation to `max_length`.

    Args:
        messages_batch: List of conversations (each a list of message dicts)
        tokenizer: Tokenizer providing `apply_chat_template` and `__call__`
        max_length: Truncation limit passed to the tokenizer

    Returns:
        dict with
            "input_ids": list of 1-D tensors, one per conversation
            "attention_mask": list of 1-D tensors, one per conversation
            "metadata": {"formatted_prompts", "input_token_counts",
                         "conversation_lengths"} -- parallel lists
    """
    ids_per_sample = []
    masks_per_sample = []
    prompts = []
    token_counts = []
    turn_counts = []

    for conversation in messages_batch:
        prompt_text = tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenized one by one; padding to a common length happens later
        encoded = tokenizer(
            prompt_text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding=False
        )

        # Drop the leading batch dimension of size 1
        ids_per_sample.append(encoded["input_ids"].squeeze(0))
        masks_per_sample.append(encoded["attention_mask"].squeeze(0))

        prompts.append(prompt_text)
        token_counts.append(len(encoded["input_ids"][0]))
        turn_counts.append(len(conversation))

    return {
        "input_ids": ids_per_sample,
        "attention_mask": masks_per_sample,
        "metadata": {
            "formatted_prompts": prompts,
            "input_token_counts": token_counts,
            "conversation_lengths": turn_counts
        }
    }

def apply_batch_padding(batch_data, tokenizer):
    """
    Pads a batch of tokenized sequences on the LEFT to a common length.

    Left padding keeps every prompt's last real token in the final column, which
    is the position a decoder-only model continues from during generation.

    Args:
        batch_data: Output from prepare_inference_batch (lists of 1-D tensors)
        tokenizer: The tokenizer (for pad_token_id; eos_token_id is the fallback)

    Returns:
        dict: "input_ids" and "attention_mask" as (batch, max_len) tensors, plus
        the unchanged "metadata"

    Raises:
        ValueError: If the batch is empty or the tokenizer has no pad/eos id
    """
    from torch.nn.utils.rnn import pad_sequence

    if len(batch_data["input_ids"]) == 0:
        raise ValueError("Cannot pad an empty batch")

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = getattr(tokenizer, "eos_token_id", None)
    if pad_token_id is None:
        raise ValueError("Tokenizer defines neither pad_token_id nor eos_token_id")

    def _left_pad(sequences, padding_value):
        # pad_sequence pads on the right: reverse each sequence, pad, then flip
        # the padded matrix back so the padding ends up on the left.
        reversed_sequences = [sequence.flip(0) for sequence in sequences]
        padded = pad_sequence(reversed_sequences, batch_first=True, padding_value=padding_value)
        return padded.flip(1)

    return {
        "input_ids": _left_pad(batch_data["input_ids"], pad_token_id),
        "attention_mask": _left_pad(batch_data["attention_mask"], 0),
        "metadata": batch_data["metadata"]
    }

def decode_batch_outputs(outputs, input_lengths, tokenizer):
    """
    Decodes model outputs for a batch, extracting only the generated portions.

    Args:
        outputs: Model generation outputs (batch_size, sequence_length)
        input_lengths: For each row, the number of leading positions to skip.
            For padded batches this must be the padded prompt width, because
            the generated sequences contain the padded prompt.
        tokenizer: The tokenizer for decoding

    Returns:
        list: Decoded, whitespace-stripped responses (only the generated parts)
    """
    responses = []

    for i, output_sequence in enumerate(outputs):
        # Everything after the prompt columns was produced by the model
        generated_tokens = output_sequence[input_lengths[i]:]

        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        responses.append(response.strip())

    return responses

def create_attention_masks(input_ids, tokenizer):
    """
    Creates attention masks for tokenized inputs.

    Args:
        input_ids: Tensor of token ids, or a list of equal-length tensors
            (a list is stacked into one tensor)
        tokenizer: The tokenizer (for pad_token_id)

    Returns:
        torch.Tensor: Attention masks (1 where the id differs from pad_token_id)
    """
    import torch

    if isinstance(input_ids, list):
        input_ids = torch.stack(input_ids)

    # NOTE(review): every position equal to pad_token_id is masked, including
    # genuine occurrences when the pad id doubles as another token -- confirm
    # this is acceptable for the tokenizers in use.
    attention_mask = (input_ids != tokenizer.pad_token_id).long()
    return attention_mask

def run_inference(model, tokenizer, prepared_inputs, batch_size=1):
    """
    Pure inference function that accepts pre-processed inputs and returns results.

    The whole of `prepared_inputs` is generated in a single greedy `generate`
    call; callers that need smaller batches must split the inputs themselves.

    Args:
        model: The model to use for inference
        tokenizer: The tokenizer (pad token id for generation, and decoding)
        prepared_inputs: Padded batch with input_ids, attention_mask, metadata
        batch_size: Unused; kept for backward compatibility

    Returns:
        tuple: (responses, generation_metadata)
            generation_metadata["input_token_counts"] are the unpadded prompt
            lengths from the batch metadata; "output_token_counts" count the
            generated tokens per sample excluding pad-token filler (when the pad
            id equals the eos id, the terminating eos is not counted either).
    """
    import torch
    import time

    model.eval()

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = getattr(tokenizer, "eos_token_id", None)

    with torch.no_grad():
        start_time = time.time()

        # Move inputs to model device
        input_ids = prepared_inputs["input_ids"].to(model.device)
        attention_mask = prepared_inputs["attention_mask"].to(model.device)

        # Greedy decoding; sampling-only knobs such as temperature are omitted
        # because they have no effect when do_sample=False.
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=200,
            do_sample=False,
            pad_token_id=pad_token_id,
            return_dict_in_generate=True,
            output_scores=False
        )

        end_time = time.time()

    sequences = outputs.sequences
    num_samples = len(sequences)

    # The returned sequences start with the *padded* prompt, so the generated
    # part begins at the padded width for every row, not at each sample's own
    # unpadded length.
    prompt_width = input_ids.shape[1]
    responses = decode_batch_outputs(sequences, [prompt_width] * num_samples, tokenizer)

    generated = sequences[:, prompt_width:]
    if pad_token_id is None:
        output_token_counts = [generated.shape[1]] * num_samples
    else:
        # Rows that finish early are filled with pad tokens; do not count them
        output_token_counts = (generated != pad_token_id).sum(dim=1).tolist()

    elapsed = end_time - start_time
    input_lengths = prepared_inputs["metadata"]["input_token_counts"]

    # Calculate generation metadata
    generation_metadata = {
        "total_inference_time": elapsed,
        "batch_size": num_samples,
        "avg_inference_time_per_sample": elapsed / num_samples if num_samples else 0.0,
        "input_token_counts": input_lengths,
        "output_token_counts": output_token_counts,
        "total_tokens_generated": sum(output_token_counts)
    }

    if torch.cuda.is_available():
        generation_metadata["gpu_memory_used"] = torch.cuda.memory_allocated() / 1024**3  # GB

    return responses, generation_metadata

In [None]:
def save_results(results, metadata, stage, config):
    """
    Saves results and metadata as timestamped JSON files for one pipeline stage.

    Files are written to config["output_dir"] / stage as
    "results_<YYYYmmdd_HHMMSS>.json" and "metadata_<YYYYmmdd_HHMMSS>.json".
    The stage directory is created if it does not exist, and values json cannot
    encode natively are stringified in both files.

    Args:
        results: JSON-serializable results (typically a list of dicts)
        metadata: JSON-serializable run metadata
        stage: Name of the stage subdirectory (e.g. "baseline")
        config: Experiment configuration containing "output_dir"

    Returns:
        tuple: (results_path, metadata_path) as Path objects
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path(config["output_dir"]) / stage

    # Do not rely on an earlier cell having created this stage folder
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save results (default=str, as for the metadata, so a stray non-JSON value
    # in a sample does not abort the save after inference has already run)
    results_path = output_dir / f"results_{timestamp}.json"
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2, default=str)

    # Save metadata
    metadata_path = output_dir / f"metadata_{timestamp}.json"
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2, default=str)

    print(f"Results saved to: {results_path}")
    print(f"Metadata saved to: {metadata_path}")

    return results_path, metadata_path

print("Inference functions loaded!")

Cell 10: Baseline Inference

In [None]:
# ===== BASELINE INFERENCE =====
# Runs the untrained model over the evaluation set in fixed-size batches,
# saves the responses plus timing/token metadata, and prints a short report.
print("="*80)
print("STARTING BASELINE INFERENCE")
print("="*80)

import traceback


def _slice_inference_batch(batch, start, stop):
    """Return the [start:stop) slice of a prepare_inference_batch result."""
    return {
        "input_ids": batch["input_ids"][start:stop],
        "attention_mask": batch["attention_mask"][start:stop],
        "metadata": {key: values[start:stop] for key, values in batch["metadata"].items()},
    }


try:
    # Step 1: Prepare dataset
    print("Step 1: Preparing dataset...")
    train_data, eval_data, examples = prepare_dataset(CONFIG)
    print(f"✓ Dataset prepared: {len(eval_data)} evaluation samples")

    if len(eval_data) == 0:
        raise ValueError("No evaluation samples available for baseline inference")

    # Step 2: Load model and tokenizer
    print("\nStep 2: Loading model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(CONFIG)
    print("✓ Model and tokenizer loaded successfully")

    # Step 3: Prepare evaluation data for inference
    print("\nStep 3: Preparing evaluation data for batched inference...")
    eval_messages = [sample["messages"] for sample in eval_data]
    print(f"✓ Extracted {len(eval_messages)} message conversations")

    # Step 4: Tokenize every prompt once; padding happens per batch in step 5
    print("\nStep 4: Creating inference batches...")
    batch_data = prepare_inference_batch(eval_messages, tokenizer, max_length=1024)
    input_token_counts = batch_data['metadata']['input_token_counts']

    # Generating the whole evaluation set in one call does not fit in memory
    # for realistic set sizes, so it is processed in chunks of this size.
    inference_batch_size = max(1, int(CONFIG.get("batch_size", 8)))
    num_batches = (len(eval_messages) + inference_batch_size - 1) // inference_batch_size

    print(f"✓ Prepared {len(eval_messages)} prompts in {num_batches} batches of up to {inference_batch_size}")
    print(f"✓ Average tokens per sample: {sum(input_token_counts) / len(input_token_counts):.1f}")

    # Step 5: Run inference
    print("\nStep 5: Running baseline inference...")
    print(f"Processing {len(eval_data)} samples...")

    baseline_responses = []
    baseline_metadata = {
        "total_inference_time": 0.0,
        "batch_size": inference_batch_size,
        "input_token_counts": [],
        "output_token_counts": [],
        "total_tokens_generated": 0
    }

    for start in range(0, len(eval_messages), inference_batch_size):
        chunk = _slice_inference_batch(batch_data, start, start + inference_batch_size)
        prepared_inputs = apply_batch_padding(chunk, tokenizer)
        chunk_responses, chunk_metadata = run_inference(model, tokenizer, prepared_inputs)

        baseline_responses.extend(chunk_responses)
        baseline_metadata["total_inference_time"] += chunk_metadata["total_inference_time"]
        baseline_metadata["input_token_counts"].extend(chunk_metadata["input_token_counts"])
        baseline_metadata["output_token_counts"].extend(chunk_metadata["output_token_counts"])
        baseline_metadata["total_tokens_generated"] += chunk_metadata["total_tokens_generated"]
        if "gpu_memory_used" in chunk_metadata:
            # Report the largest value observed across batches
            baseline_metadata["gpu_memory_used"] = max(
                baseline_metadata.get("gpu_memory_used", 0.0),
                chunk_metadata["gpu_memory_used"]
            )

    baseline_metadata["num_samples"] = len(baseline_responses)
    baseline_metadata["avg_inference_time_per_sample"] = (
        baseline_metadata["total_inference_time"] / len(baseline_responses)
    )

    print(f"✓ Inference completed in {baseline_metadata['total_inference_time']:.2f}s")
    print(f"✓ Average time per sample: {baseline_metadata['avg_inference_time_per_sample']:.3f}s")
    print(f"✓ Total tokens generated: {baseline_metadata['total_tokens_generated']}")

    # Step 6: Format results for saving
    print("\nStep 6: Formatting results...")
    baseline_results = []

    for i in range(len(baseline_responses)):
        prompt_text = batch_data['metadata']['formatted_prompts'][i]
        # Truncated for file size
        prompt_preview = prompt_text[:200] + "..." if len(prompt_text) > 200 else prompt_text

        result = {
            "sample_id": eval_data[i]["id"],
            "expected_output": eval_data[i]["expected_output"],
            "model_response": baseline_responses[i],
            "sample_metadata": eval_data[i]["metadata"],
            "input_tokens": input_token_counts[i],
            "output_tokens": baseline_metadata['output_token_counts'][i],
            "formatted_prompt": prompt_preview
        }
        baseline_results.append(result)

    print(f"✓ Formatted {len(baseline_results)} results")

    # Step 7: Save results
    print("\nStep 7: Saving results...")
    baseline_results_path, baseline_metadata_path = save_results(
        baseline_results, baseline_metadata, "baseline", CONFIG
    )

    # Step 8: Display sample results
    print("\nStep 8: Sample results preview:")
    print("="*60)
    for i in range(min(3, len(baseline_results))):
        sample = baseline_results[i]
        print(f"\nSample {i+1}:")
        print(f"Sample ID: {sample['sample_id']}")
        print(f"Error Type: {sample['sample_metadata']['error_type']}")
        print(f"Expected: {sample['expected_output']}")
        print(f"Model Response: {sample['model_response']}")
        print(f"Input/Output Tokens: {sample['input_tokens']}/{sample['output_tokens']}")
        print("-" * 40)

    # Step 9: Summary statistics
    print("\nStep 9: Summary Statistics:")
    print("="*60)
    print(f"Total samples processed: {len(baseline_results)}")
    print(f"Total inference time: {baseline_metadata['total_inference_time']:.2f}s")
    print(f"Average time per sample: {baseline_metadata['avg_inference_time_per_sample']:.3f}s")
    print(f"Total input tokens: {sum(baseline_metadata['input_token_counts'])}")
    print(f"Total output tokens: {baseline_metadata['total_tokens_generated']}")
    print(f"Average output tokens per sample: {baseline_metadata['total_tokens_generated'] / len(baseline_results):.1f}")

    if 'gpu_memory_used' in baseline_metadata:
        print(f"GPU memory used: {baseline_metadata['gpu_memory_used']:.2f} GB")

    # Error type distribution
    error_types = {}
    for result in baseline_results:
        error_type = result['sample_metadata']['error_type']
        error_types[error_type] = error_types.get(error_type, 0) + 1

    print(f"\nError type distribution:")
    for error_type, count in error_types.items():
        print(f"  {error_type}: {count} samples ({count/len(baseline_results)*100:.1f}%)")

    print("\n" + "="*80)
    print("✅ BASELINE INFERENCE COMPLETED SUCCESSFULLY!")
    print(f"Results saved to: {baseline_results_path}")
    print(f"Metadata saved to: {baseline_metadata_path}")
    print("="*80)

except Exception as e:
    print(f"\n❌ ERROR DURING BASELINE INFERENCE:")
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {str(e)}")

    # Detailed error info
    print(f"\nFull traceback:")
    traceback.print_exc()

    print("\n💡 Troubleshooting tips:")
    print("1. Check if all previous cells have been run")
    print("2. Verify CONFIG is properly set")
    print("3. Ensure dataset files exist in the specified path")
    print("4. Check GPU memory availability")
    print("5. Try reducing batch size if out of memory")