Cell 1: Configuration and Setup

In [None]:
# ===== EXPERIMENT CONFIGURATION =====
import datetime
import os

CONFIG = {
    # Core experiment parameters
    "experiment_type": "generative",  # "discriminative" or "generative"
    "classification_type": "binary",   # "binary" or "ternary"
    "dataset_strategy": "4N",          # "4N" or "3N" (selects base_<strategy>_dataset.json)
    "include_explanation": False,      # True or False (generative only)
    "include_eln": True,              # True or False (generative only)
    "solution_format": "dict",        # "dict" or "nl" (generative only)
    "model_name": "microsoft/Phi-4-mini-instruct",  # or "Qwen/Qwen3-4B"
    
    # Prompting configuration
    "system_prompt": None,  # Will auto-generate if None, or use custom string
    "include_examples": True,
    "num_examples": 3,
    "example_strategy": "balanced",  # "balanced", "error_focused", "custom"
    
    # Training parameters
    "learning_rate": 2e-5,
    "num_epochs": 3,
    "batch_size": 8,
    "max_length": 1024,
    "gradient_accumulation_steps": 1,
    
    # Infrastructure
    "use_lora": True,
    "lora_rank": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.1,
    
    # Paths and tokens
    "base_dataset_dir": "/content/drive/MyDrive/sft_datasets",
    "output_base_dir": "/content/drive/MyDrive/sft_experiments",
    # Read the token from the environment so it is never hard-coded into the
    # notebook; the placeholder remains only as a fallback.
    "hf_token": os.environ.get("HF_TOKEN", "your_huggingface_token_here"),
    "wandb_project": "math_error_classification",
    
    # Experiment tracking
    "save_to_hf": True,
    "save_locally": True,
    "use_wandb": True
}


def short_model_tag(model_name):
    """Return a short, filesystem-friendly tag for *model_name*.

    Known families map to fixed tags ("qwen", "phi4"); any other model uses its
    lower-cased repository name (dots replaced by dashes) so it is never
    mislabelled as one of the known families.
    """
    lowered = model_name.lower()
    if "qwen" in lowered:
        return "qwen"
    if "phi-4" in lowered:
        return "phi4"
    return lowered.rsplit("/", 1)[-1].replace(".", "-")


# Generate experiment ID.
# The dataset strategy always selects the base dataset file, so it is always
# part of the ID. The remaining flags only affect generative prompts/targets.
experiment_components = [
    CONFIG["experiment_type"][:4],       # "gene" or "disc"
    CONFIG["classification_type"][:3],   # "bin" or "ter"
    CONFIG["dataset_strategy"],
]
if CONFIG["experiment_type"] == "generative":
    experiment_components += [
        "exp" if CONFIG["include_explanation"] else "no_exp",
        "eln" if CONFIG["include_eln"] else "no_eln",
        CONFIG["solution_format"],
    ]
experiment_components.append(short_model_tag(CONFIG["model_name"]))

experiment_id = (
    "_".join(c for c in experiment_components if c)
    + "_"
    + datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
)
CONFIG["experiment_id"] = experiment_id

print(f"Experiment ID: {experiment_id}")
print("Configuration loaded successfully!")

Cell 2: Import Dependencies

In [None]:
# Core libraries
import os
import json
import time
import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional
import random
import numpy as np

# Data handling
import pandas as pd
from datasets import Dataset, load_dataset

# ML libraries
import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification,
    TrainingArguments, Trainer, DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training

# Logging and tracking
import wandb
from tqdm.auto import tqdm

# Google Drive mounting (for Colab)
try:
    from google.colab import drive
except ImportError:
    # google.colab is only importable inside Colab.
    print("Not running in Colab - skipping Drive mount")
else:
    # mount() runs outside the try so an ImportError raised while mounting is
    # not misreported as "not running in Colab".
    drive.mount('/content/drive')
    print("Google Drive mounted successfully!")

def set_seeds(seed=42):
    """Seed the Python, NumPy and PyTorch random generators.

    Args:
        seed: Integer seed applied to every generator (CUDA ones included
            when a GPU is available).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Seed everything once, up front, for reproducibility.
set_seeds(seed=42)
print("Dependencies imported and seeds set!")

Cell 3: System Prompt Generation

In [None]:
def generate_system_prompt(config):
    """Build the default system prompt for the configured experiment.

    Args:
        config: Experiment configuration. Reads ``experiment_type``,
            ``classification_type``, ``include_eln``, ``solution_format`` and
            ``include_explanation``.

    Returns:
        The prompt string. Discriminative runs get a short fixed prompt.
        Generative prompts spell out the exact JSON keys the training targets
        use ("verdict", plus "erroneous_line_number"/"erroneous_line" and
        "explanation" when enabled). This lets the model match the expected
        schema even without few-shot examples.
    """
    if config["experiment_type"] == "discriminative":
        return "You are a mathematics tutor. Classify the given solution."

    parts = [
        "You are a mathematics tutor. Analyze the given solution and provide "
        "your assessment in JSON format."
    ]

    if config["classification_type"] == "binary":
        parts.append("Determine if the solution is 'correct' or 'flawed'.")
    else:
        parts.append("Classify as 'correct', 'conceptual_error', or 'computational_error'.")

    # Instructions for optional fields, tracked alongside the JSON key each adds.
    keys = ["verdict"]
    fields = []
    if config["include_eln"]:
        if config["solution_format"] == "dict":
            fields.append("identify the erroneous line number (e.g., 'L1', 'FA')")
            keys.append("erroneous_line_number")
        else:
            fields.append("quote the full erroneous line text")
            keys.append("erroneous_line")

    if config["include_explanation"]:
        fields.append("provide a brief explanation of any error")
        keys.append("explanation")

    if fields:
        parts.append(f"Also {', and '.join(fields)}.")

    key_list = ", ".join(f'"{key}"' for key in keys)
    schema = f"Use exactly these JSON keys: {key_list}."
    if len(keys) > 1:
        # The targets carry null for the error fields of correct solutions.
        schema += " Set the error fields to null when the solution is correct."
    parts.append(schema)
    parts.append("Respond only with valid JSON.")

    return " ".join(parts)

# Use the auto-generated prompt unless a custom one is already set.
active_prompt = CONFIG["system_prompt"]
if active_prompt is None:
    active_prompt = generate_system_prompt(CONFIG)
    CONFIG["system_prompt"] = active_prompt

print("System Prompt:", active_prompt, "", sep="\n")

# Show how to override the prompt manually.
print(
    "To customize the system prompt, run:",
    'CONFIG["system_prompt"] = "Your custom prompt here"',
    sep="\n",
)

Cell 4: Dataset Loading and Formatting Functions

In [None]:
def load_base_dataset(config):
    """Load the samples of the base dataset selected by ``config``.

    Args:
        config: Reads ``dataset_strategy`` (e.g. "4N" or "3N") and
            ``base_dataset_dir``.

    Returns:
        The list stored under the top-level ``"samples"`` key of
        ``<base_dataset_dir>/base_<dataset_strategy>_dataset.json``.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
        KeyError: If the JSON document has no ``"samples"`` key.
    """
    dataset_strategy = config["dataset_strategy"]
    dataset_file = Path(config["base_dataset_dir"]) / f"base_{dataset_strategy}_dataset.json"

    if not dataset_file.exists():
        raise FileNotFoundError(f"Base dataset not found: {dataset_file}")

    # Explicit UTF-8 so non-ASCII math text does not depend on the host locale.
    with open(dataset_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    samples = data["samples"]
    print(f"Loaded base {dataset_strategy} dataset with {len(samples)} samples")
    return samples

def format_solution(sample, config):
    """Render a sample's solution as prompt text.

    Args:
        sample: Reads ``solution_lines``. The last entry is the final answer.
        config: ``solution_format == "dict"`` renders a JSON object with keys
            ``L1``..``L<n-1>`` for the working lines and ``FA`` for the final
            answer. Any other value joins the lines with newlines.

    Returns:
        The formatted solution string.

    Raises:
        IndexError: In dict format, if ``solution_lines`` is empty.
    """
    lines = sample["solution_lines"]
    if config["solution_format"] != "dict":
        return "\n".join(lines)

    solution = {f"L{i}": line for i, line in enumerate(lines[:-1], start=1)}
    solution["FA"] = lines[-1]
    # ensure_ascii=False keeps math symbols readable for the model instead of
    # turning them into \uXXXX escape sequences.
    return json.dumps(solution, indent=2, ensure_ascii=False)

def _resolve_erroneous_line(sample):
    """Return the solution line text that ``sample["erroneous_line_number"]`` refers to.

    References use the dict-format labels: ``"FA"`` is the final answer (the
    last entry of ``solution_lines``) and ``"L<k>"`` is the k-th line.
    References outside the list fall back to the final answer.

    Returns:
        The referenced line, or None when the sample has no line reference.

    Raises:
        ValueError: If the reference is neither ``"FA"`` nor a prefix
            followed by an integer.
    """
    label = sample.get("erroneous_line_number")
    if not label:
        return None
    lines = sample["solution_lines"]
    if label == "FA":
        return lines[-1]
    try:
        index = int(label[1:]) - 1
    except ValueError as err:
        raise ValueError(f"Unrecognized erroneous_line_number: {label!r}") from err
    return lines[index] if 0 <= index < len(lines) else lines[-1]


def format_expected_output(sample, config):
    """Build the JSON target string the model should produce for ``sample``.

    Keys, in order:
      * ``verdict`` -- "correct"/"flawed" (binary) or the raw ``error_type``.
      * if ``include_eln``: ``erroneous_line_number`` (dict format, the label
        as stored) or ``erroneous_line`` (other formats, the line's text).
      * if ``include_explanation``: ``explanation``.
    Error fields are null for correct samples.

    Args:
        sample: Reads ``error_type``, ``solution_lines`` and optionally
            ``erroneous_line_number`` / ``explanation``.
        config: Reads ``classification_type``, ``include_eln``,
            ``solution_format`` and ``include_explanation``.

    Returns:
        The target serialized with ``json.dumps`` (non-ASCII kept as-is).
    """
    is_correct = sample["error_type"] == "correct"
    output = {}

    # Verdict
    if config["classification_type"] == "binary":
        output["verdict"] = "correct" if is_correct else "flawed"
    else:
        output["verdict"] = sample["error_type"]

    # Erroneous line: label in dict format, quoted text otherwise
    if config["include_eln"]:
        if config["solution_format"] == "dict":
            output["erroneous_line_number"] = None if is_correct else sample.get("erroneous_line_number")
        else:
            output["erroneous_line"] = None if is_correct else _resolve_erroneous_line(sample)

    # Explanation
    if config["include_explanation"]:
        output["explanation"] = None if is_correct else sample.get("explanation")

    # ensure_ascii=False so the model learns to emit math symbols literally.
    return json.dumps(output, ensure_ascii=False)


print("Dataset formatting functions loaded!")

Cell 5: Example Management

In [None]:
class ExampleManager:
    """Select few-shot examples from the base dataset.

    Samples are grouped by ``error_type`` on construction. ``get_examples``
    then draws up to ``num_examples`` of them according to
    ``example_strategy``. Draws use the global ``random`` module, so they
    follow whatever seed was set beforehand. Returned examples are references
    to the original sample dicts.
    """

    # Error types the grouping and the selection strategies know about.
    ERROR_TYPES = ("correct", "conceptual_error", "computational_error")

    def __init__(self, base_dataset, config):
        self.samples = base_dataset
        self.config = config
        self._prepare_examples_by_type()

    def _prepare_examples_by_type(self):
        """Group ``self.samples`` into ``self.examples_by_type`` by ``error_type``.

        Raises:
            ValueError: If a sample's ``error_type`` is not in ERROR_TYPES.
        """
        self.examples_by_type = {error_type: [] for error_type in self.ERROR_TYPES}

        for sample in self.samples:
            error_type = sample["error_type"]
            bucket = self.examples_by_type.get(error_type)
            if bucket is None:
                raise ValueError(
                    f"Unknown error_type {error_type!r}; expected one of {self.ERROR_TYPES}"
                )
            bucket.append(sample)

        print(f"Examples by type: {[(k, len(v)) for k, v in self.examples_by_type.items()]}")

    def get_examples(self):
        """Return the few-shot examples for the configured strategy.

        Returns an empty list when ``include_examples`` is false, or when the
        strategy has no selection logic here (a warning is printed in that
        case). Otherwise the selected examples are shuffled, because the
        strategies below collect them grouped by label.
        """
        if not self.config["include_examples"]:
            return []

        strategy = self.config["example_strategy"]
        num_examples = self.config["num_examples"]

        if strategy == "balanced":
            examples = self._get_balanced_examples(num_examples)
        elif strategy == "error_focused":
            examples = self._get_error_focused_examples(num_examples)
        else:
            # e.g. "custom": no selection logic implemented in this class.
            print(f"Warning: example_strategy {strategy!r} is not implemented; using no few-shot examples")
            return []

        # Avoid always presenting labels in the same order (e.g. all 'correct' first).
        random.shuffle(examples)
        return examples

    def _get_balanced_examples(self, n):
        """Draw up to ``n`` examples spread evenly over the classes.

        Binary: ``n // 2`` correct examples and the rest from the pooled error
        types. Ternary: ``n // 3`` per type, with the remainder going to the
        earliest types in ERROR_TYPES order. No class contributes more samples
        than it has.
        """
        examples = []

        if self.config["classification_type"] == "binary":
            correct_needed = n // 2
            flawed_needed = n - correct_needed

            correct_pool = self.examples_by_type["correct"]
            examples.extend(random.sample(correct_pool, min(correct_needed, len(correct_pool))))

            flawed_pool = (self.examples_by_type["conceptual_error"]
                           + self.examples_by_type["computational_error"])
            examples.extend(random.sample(flawed_pool, min(flawed_needed, len(flawed_pool))))
        else:
            per_type, remainder = divmod(n, 3)
            for i, error_type in enumerate(self.ERROR_TYPES):
                needed = per_type + (1 if i < remainder else 0)
                pool = self.examples_by_type[error_type]
                examples.extend(random.sample(pool, min(needed, len(pool))))

        return examples

    def _get_error_focused_examples(self, n):
        """Draw about two thirds of ``n`` from error samples, the rest correct."""
        error_pool = (self.examples_by_type["conceptual_error"]
                      + self.examples_by_type["computational_error"])

        error_count = min(n * 2 // 3, len(error_pool))
        correct_count = n - error_count

        examples = random.sample(error_pool, error_count)
        correct_pool = self.examples_by_type["correct"]
        examples.extend(random.sample(correct_pool, min(correct_count, len(correct_pool))))

        return examples


print("Example manager class loaded!")

Cell 6: Dataset Preparation

In [None]:
def _format_user_turn(sample, config):
    """Return the user-message text presenting ``sample``'s problem and solution."""
    return f"Problem: {sample['question']}\nSolution: {format_solution(sample, config)}"


def prepare_dataset(config):
    """Load the base dataset and turn it into chat-formatted train/eval splits.

    Args:
        config: Experiment configuration. Besides the keys read by the loading
            and formatting helpers, reads ``system_prompt`` and two optional
            keys: ``split_seed`` (default 42) and ``train_fraction``
            (default 0.8).

    Returns:
        ``(train_data, eval_data, examples)``. Each data entry is a dict with
        ``id``, ``messages`` (system prompt, few-shot user/assistant turns,
        final user turn), ``expected_output`` and ``metadata``. ``examples``
        are the raw few-shot samples. They are excluded from both splits, so
        no sample is both a demonstration and a target.
    """
    base_samples = load_base_dataset(config)

    example_manager = ExampleManager(base_samples, config)
    examples = example_manager.get_examples()
    print(f"Using {len(examples)} few-shot examples")

    # The few-shot turns are identical for every sample, so format them once.
    few_shot_turns = []
    for example in examples:
        few_shot_turns.append({"role": "user", "content": _format_user_turn(example, config)})
        few_shot_turns.append({"role": "assistant", "content": format_expected_output(example, config)})

    # ExampleManager returns references to entries of base_samples; drop those
    # objects so the demonstrations never reappear as train/eval targets.
    example_object_ids = {id(example) for example in examples}
    candidates = [
        (position, sample)
        for position, sample in enumerate(base_samples)
        if id(sample) not in example_object_ids
    ]

    # Shuffle with a dedicated seeded RNG: the split stays reproducible and the
    # eval set is not simply the tail of the file (which may be ordered).
    random.Random(config.get("split_seed", 42)).shuffle(candidates)

    formatted_data = []
    for position, sample in tqdm(candidates, desc="Formatting samples"):
        messages = [{"role": "system", "content": config["system_prompt"]}]
        # Copy each few-shot turn so samples never share message dicts.
        messages.extend(dict(turn) for turn in few_shot_turns)
        messages.append({"role": "user", "content": _format_user_turn(sample, config)})

        formatted_data.append({
            # Fallback IDs use the position in the base file, so they stay stable across shuffles.
            "id": sample.get("id", f"sample_{position}"),
            "messages": messages,
            "expected_output": format_expected_output(sample, config),
            "metadata": {
                "error_type": sample["error_type"],
                "tier": sample.get("tier", "unknown"),
                "source": sample.get("source", "unknown")
            }
        })

    # Split into train/eval
    split_point = int(config.get("train_fraction", 0.8) * len(formatted_data))
    train_data = formatted_data[:split_point]
    eval_data = formatted_data[split_point:]

    print(f"Dataset prepared: {len(train_data)} training, {len(eval_data)} evaluation samples")

    return train_data, eval_data, examples


# Prepare the dataset
train_data, eval_data, few_shot_examples = prepare_dataset(CONFIG)

Cell 7: Model and Tokenizer Loading

In [None]:
def load_model_and_tokenizer(config):
    """Load the tokenizer and model described by ``config``.

    Discriminative experiments get a sequence-classification model with 2
    (binary) or 3 (ternary) labels. Generative experiments get a causal LM.
    With ``use_lora`` the base model is loaded in 4-bit NF4, prepared for
    k-bit training and wrapped in a LoRA adapter built from ``lora_rank``,
    ``lora_alpha`` and ``lora_dropout``.

    Args:
        config: Reads ``model_name``, ``experiment_type``,
            ``classification_type``, ``use_lora`` and the LoRA settings.

    Returns:
        ``(model, tokenizer)``.
    """
    model_name = config["model_name"]
    print(f"Loading model: {model_name}")

    # trust_remote_code matches the model loads below.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Reuse EOS for padding when the checkpoint defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    is_discriminative = config["experiment_type"] == "discriminative"

    # Keyword arguments shared by the from_pretrained calls below.
    load_kwargs = {"device_map": "auto", "trust_remote_code": True}
    if config["use_lora"]:
        # 4-bit NF4 base weights with double quantization; LoRA trains on top.
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True
        )

    if is_discriminative:
        num_labels = 2 if config["classification_type"] == "binary" else 3
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, **load_kwargs
        )
        # Sequence-classification heads locate the last real token through
        # config.pad_token_id; without it, batches larger than 1 raise an error.
        if model.config.pad_token_id is None:
            model.config.pad_token_id = tokenizer.pad_token_id
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

    if config["use_lora"]:
        model = prepare_model_for_kbit_training(model)

        if is_discriminative:
            task_type = TaskType.SEQ_CLS
            target_modules = ["q_proj", "v_proj", "k_proj", "o_proj"]
        else:
            task_type = TaskType.CAUSAL_LM
            target_modules = ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

        lora_config = LoraConfig(
            task_type=task_type,
            r=config["lora_rank"],
            lora_alpha=config["lora_alpha"],
            lora_dropout=config["lora_dropout"],
            target_modules=target_modules,
            bias="none"
        )

        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    print("Model loaded successfully!")
    print(f"Model device: {next(model.parameters()).device}")

    return model, tokenizer


# Load model and tokenizer
model, tokenizer = load_model_and_tokenizer(CONFIG)

Cell 8: Output Directory Setup

In [None]:
def setup_output_directory(config, secret_keys=("hf_token",)):
    """Create the experiment output tree and save a redacted copy of ``config``.

    Creates ``<output_base_dir>/<experiment_id>/`` with ``baseline``,
    ``training``, ``final`` and ``checkpoints`` subdirectories. Then writes
    ``config.json`` there, with non-JSON values serialized via ``str``.

    Args:
        config: Reads ``output_base_dir`` and ``experiment_id``. The whole
            dict is saved; it is not modified.
        secret_keys: Keys whose non-empty values are replaced with
            ``"<redacted>"`` in the saved file, so credentials never land on
            disk (the output tree may live on a shared Drive).

    Returns:
        The experiment output directory as a ``Path``.
    """
    output_dir = Path(config["output_base_dir"]) / config["experiment_id"]

    for subdir in ("baseline", "training", "final", "checkpoints"):
        (output_dir / subdir).mkdir(parents=True, exist_ok=True)

    saved_config = {
        key: ("<redacted>" if key in secret_keys and value else value)
        for key, value in config.items()
    }

    config_path = output_dir / "config.json"
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(saved_config, f, indent=2, default=str)

    print(f"Output directory created: {output_dir}")
    return output_dir

# Setup output directory
output_dir = setup_output_directory(CONFIG)
CONFIG["output_dir"] = str(output_dir)

# Initialize wandb if enabled. The run config is uploaded to W&B, so send a
# copy without the Hugging Face token.
if CONFIG["use_wandb"]:
    wandb_config = {key: value for key, value in CONFIG.items() if key != "hf_token"}
    wandb.init(
        project=CONFIG["wandb_project"],
        name=CONFIG["experiment_id"],
        config=wandb_config
    )
    print("Wandb initialized!")

Cell 9: Inference Functions

In [None]:
def _encode_generative_prompt(tokenizer, messages, max_length):
    """Tokenize chat ``messages`` for generation, keeping the end of the prompt.

    Returns ``(inputs, truncated)``. When the prompt is longer than
    ``max_length`` tokens, the *leading* tokens are dropped. This keeps the
    final user turn and the generation prompt, which right-side truncation
    would have cut.
    """
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # The chat template already inserts its special tokens; don't add them twice.
    inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False)
    truncated = inputs["input_ids"].shape[1] > max_length
    if truncated:
        inputs = {k: v[:, -max_length:] for k, v in inputs.items()}
    return inputs, truncated


def run_inference(model, tokenizer, data, config, stage="baseline"):
    """Run the model over ``data`` and collect predictions plus timing stats.

    Args:
        model: Model returned by ``load_model_and_tokenizer``.
        tokenizer: The matching tokenizer.
        data: Entries with ``id``, ``messages``, ``expected_output`` and
            ``metadata`` (as built by ``prepare_dataset``).
        config: Reads ``experiment_type``, ``max_length`` and the optional
            ``max_new_tokens`` (default 200).
        stage: Label used in progress and log messages.

    Returns:
        ``(results, metadata)``. ``results`` holds one dict per sample.
        ``metadata`` holds per-sample lists of response time, input/output
        token counts and (on CUDA) allocated memory in GB, plus
        ``truncated_samples``: how many generative prompts were cut to
        ``max_length`` tokens.
    """
    print(f"Running {stage} inference on {len(data)} samples...")

    model.eval()
    max_length = config["max_length"]
    max_new_tokens = config.get("max_new_tokens", 200)
    is_discriminative = config["experiment_type"] == "discriminative"

    results = []
    metadata = {
        "response_times": [],
        "input_token_counts": [],
        "output_token_counts": [],
        "memory_usage": [],
        "truncated_samples": 0
    }

    with torch.no_grad():
        for i, sample in enumerate(tqdm(data, desc=f"{stage} inference")):
            start_time = time.time()

            if is_discriminative:
                # Simplified: only the last user message is classified (no
                # system prompt or few-shot turns) -- adapt as needed.
                input_text = sample["messages"][-1]["content"]
                inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length)
                inputs = {k: v.to(model.device) for k, v in inputs.items()}

                logits = model(**inputs).logits
                response = str(torch.argmax(logits, dim=-1).item())
                output_tokens = 1  # a single class id, not generated text
            else:
                inputs, truncated = _encode_generative_prompt(tokenizer, sample["messages"], max_length)
                metadata["truncated_samples"] += int(truncated)
                inputs = {k: v.to(model.device) for k, v in inputs.items()}

                # Greedy decoding. No temperature: it is ignored (and warned
                # about) when do_sample=False.
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    pad_token_id=tokenizer.eos_token_id
                )

                # generate() returns prompt + continuation; keep the continuation.
                prompt_length = inputs["input_ids"].shape[1]
                generated = outputs[0][prompt_length:]
                response = tokenizer.decode(generated, skip_special_tokens=True)
                output_tokens = len(generated)

            end_time = time.time()
            elapsed = end_time - start_time

            # Log metrics
            metadata["response_times"].append(elapsed)
            metadata["input_token_counts"].append(inputs["input_ids"].shape[1])
            metadata["output_token_counts"].append(output_tokens)

            if torch.cuda.is_available():
                metadata["memory_usage"].append(torch.cuda.memory_allocated() / 1024**3)  # GB

            results.append({
                "sample_id": sample["id"],
                "input": sample["messages"][-1]["content"],
                "expected_output": sample["expected_output"],
                "model_output": response.strip(),
                "metadata": sample["metadata"],
                "timestamp": datetime.datetime.now().isoformat(),
                "response_time": elapsed
            })

            # Periodic memory cleanup
            if i % 50 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

    print(f"{stage} inference completed!")
    if metadata["truncated_samples"]:
        print(f"Warning: {metadata['truncated_samples']} prompts exceeded max_length={max_length} "
              f"and lost their leading tokens")
    # Skip averages for empty input (np.mean([]) is nan plus a RuntimeWarning).
    if data:
        print(f"Average response time: {np.mean(metadata['response_times']):.3f}s")
        print(f"Average input tokens: {np.mean(metadata['input_token_counts']):.1f}")
        print(f"Average output tokens: {np.mean(metadata['output_token_counts']):.1f}")

    return results, metadata

def save_results(results, metadata, stage, config):
    """Write inference results and metadata as timestamped JSON files.

    Files are written to ``<output_dir>/<stage>/results_<ts>.json`` and
    ``metadata_<ts>.json``. The stage directory is created if missing, so
    stages beyond the pre-created ones work too. When ``use_wandb`` is set,
    averages from ``metadata`` and the sample count are logged to W&B.

    Args:
        results: JSON-serializable list of per-sample result dicts.
        metadata: Dict with ``response_times``, ``input_token_counts`` and
            ``output_token_counts`` lists (others are saved as-is).
        stage: Subdirectory / metric-name prefix, e.g. ``"baseline"``.
        config: Reads ``output_dir`` and ``use_wandb``.

    Returns:
        ``(results_path, metadata_path)`` as ``Path`` objects.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    stage_dir = Path(config["output_dir"]) / stage
    stage_dir.mkdir(parents=True, exist_ok=True)

    # Save results; ensure_ascii=False keeps math symbols human-readable.
    results_path = stage_dir / f"results_{timestamp}.json"
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    # Save metadata
    metadata_path = stage_dir / f"metadata_{timestamp}.json"
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, default=str)

    print(f"Results saved to: {results_path}")
    print(f"Metadata saved to: {metadata_path}")

    # Log to wandb if enabled
    if config["use_wandb"]:
        wandb.log({
            f"{stage}_avg_response_time": np.mean(metadata["response_times"]),
            f"{stage}_avg_input_tokens": np.mean(metadata["input_token_counts"]),
            f"{stage}_avg_output_tokens": np.mean(metadata["output_token_counts"]),
            f"{stage}_total_samples": len(results)
        })

    return results_path, metadata_path


print("Inference functions loaded!")

Cell 10: Baseline Inference

In [None]:
# Baseline pass: score the held-out split with the model as currently loaded.
print("Starting baseline inference...")
baseline_results, baseline_metadata = run_inference(
    model, tokenizer, eval_data, CONFIG, stage="baseline"
)

# Persist outputs under the experiment's "baseline" stage directory.
baseline_results_path, baseline_metadata_path = save_results(
    baseline_results, baseline_metadata, stage="baseline", config=CONFIG
)

print("Baseline inference completed and saved!")
print(f"Baseline results: {len(baseline_results)} samples processed")