In [1]:
# ============================================================================
# EXPERIMENT 3: Self-Taught Reasoner (STaR) - CHECKPOINT RESUMABLE VERSION
# ============================================================================

In [2]:
# --- CELL 1: Install Dependencies ---
!pip install transformers>=4.35.0 datasets>=2.14.0 accelerate>=0.24.0 torch>=2.0.0 tqdm matplotlib numpy -q

In [3]:
# --- CELL 2: Import Libraries ---
import torch
import re
import os
import json
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from datetime import datetime
import shutil
from pathlib import Path

from huggingface_hub import login
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


PyTorch version: 2.8.0+cu128
CUDA available: True
CUDA device: NVIDIA A100-SXM4-80GB
GPU memory: 85.10 GB


In [4]:
# --- CELL 3: Hugging Face Login ---
print("\nPlease log in to Hugging Face...")
login()


Please log in to Hugging Face...


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
# --- CELL 4: Configuration ---
class Config:
    # Model
    MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
    
    # Dataset - ADJUST THESE FOR FULL RUN
    USE_SUBSET = False
    TRAIN_SUBSET_SIZE = 500  # Increased for better training
    TEST_SUBSET_SIZE = 100   # Increased for better evaluation
    
    # Output
    OUTPUT_DIR = "./small_project/star"
    CHECKPOINT_DIR = "./small_project/star/checkpoints"
    STATE_FILE = "./small_project/star/training_state.json"
    
    # Training
    NUM_EPOCHS = 2  # Epochs per STaR iteration
    BATCH_SIZE = 4   # Increased for A100
    GRADIENT_ACCUMULATION_STEPS = 4  # Adjusted for larger batch
    LEARNING_RATE = 2e-5
    MAX_LENGTH = 512
    
    # STaR
    STAR_ITERATIONS = 1
    
    # Generation
    GENERATION_MAX_NEW_TOKENS = 256
    TEMPERATURE = 0.7
    TOP_P = 0.9
    
    # Resume settings
    FORCE_RESTART = False  # Set True to ignore checkpoints and start fresh

config = Config()

# Create directories
for dir_path in [config.OUTPUT_DIR, config.CHECKPOINT_DIR]:
    os.makedirs(dir_path, exist_ok=True)

print(f"{'='*80}")
print(f"STaR TRAINING CONFIGURATION (CHECKPOINT RESUMABLE)")
print(f"{'='*80}")
print(f"Model: {config.MODEL_NAME}")
print(f"Training samples: {config.TRAIN_SUBSET_SIZE if config.USE_SUBSET else 'Full'}")
print(f"Test samples: {config.TEST_SUBSET_SIZE if config.USE_SUBSET else 'Full'}")
print(f"STaR iterations: {config.STAR_ITERATIONS}")
print(f"Epochs per iteration: {config.NUM_EPOCHS}")
print(f"Effective batch size: {config.BATCH_SIZE * config.GRADIENT_ACCUMULATION_STEPS}")
print(f"Output: {config.OUTPUT_DIR}")
print(f"Checkpoint dir: {config.CHECKPOINT_DIR}")
print(f"Force restart: {config.FORCE_RESTART}")
print(f"{'='*80}\n")

STaR TRAINING CONFIGURATION (CHECKPOINT RESUMABLE)
Model: meta-llama/Llama-3.2-3B-Instruct
Training samples: Full
Test samples: Full
STaR iterations: 1
Epochs per iteration: 2
Effective batch size: 16
Output: ./small_project/star
Checkpoint dir: ./small_project/star/checkpoints
Force restart: False



In [6]:
# --- CELL 5: Helper Functions ---
def extract_answer(text):
    """Extract numerical answer from text"""
    if not text:
        return None
    
    # Try #### format
    match = re.search(r'####\s*(-?\d+(?:,\d+)*(?:\.\d+)?)', text)
    if match:
        return match.group(1).replace(',', '')
    
    # Try common patterns
    patterns = [
        r'answer is[:\s]+(-?\d+(?:,\d+)*(?:\.\d+)?)',
        r'=\s*(-?\d+(?:,\d+)*(?:\.\d+)?)\s*$',
        r'total[:\s]+(-?\d+(?:,\d+)*(?:\.\d+)?)',
    ]
    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1).replace(',', '')
    
    # Last resort: last number
    numbers = re.findall(r'-?\d+(?:,\d+)*(?:\.\d+)?', text)
    if numbers:
        return numbers[-1].replace(',', '')
    
    return None

def create_training_prompt(question, solution):
    """Create training prompt"""
    clean_solution = solution.strip()
    if '####' not in clean_solution:
        answer = extract_answer(clean_solution)
        if answer:
            clean_solution = f"{clean_solution}\n#### {answer}"
    
    return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Solve this math problem step by step and provide the final answer after ####.

{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{clean_solution}<|eot_id|>"""

def create_inference_prompt(question):
    """Create inference prompt"""
    return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Solve this math problem step by step and provide the final answer after ####.

{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

def create_rationalization_prompt(question, correct_answer):
    """Create prompt with answer hint for rationalization"""
    return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Solve this math problem step by step. The correct answer is {correct_answer}. Explain the reasoning that leads to this answer and end with #### {correct_answer}.

{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

def load_gsm8k_data():
    """Load GSM8K dataset"""
    print("Loading GSM8K dataset...")
    dataset = load_dataset("gsm8k", "main")
    
    train_data = []
    for item in tqdm(dataset["train"], desc="Processing train"):
        answer = extract_answer(item["answer"])
        if answer:
            train_data.append({
                "question": item["question"],
                "solution": item["answer"],
                "answer": answer
            })
    
    test_data = []
    for item in tqdm(dataset["test"], desc="Processing test"):
        answer = extract_answer(item["answer"])
        if answer:
            test_data.append({
                "question": item["question"],
                "solution": item["answer"],
                "answer": answer
            })
    
    print(f"Loaded {len(train_data)} train and {len(test_data)} test examples")
    return train_data, test_data

In [7]:
# --- CELL 5: Checkpoint Management System ---
class CheckpointManager:
    """Manages training state and checkpoints for resumable training"""
    
    def __init__(self, state_file, checkpoint_dir):
        self.state_file = state_file
        self.checkpoint_dir = checkpoint_dir
        self.state = self.load_state()
    
    def load_state(self):
        """Load training state from disk"""
        if os.path.exists(self.state_file) and not config.FORCE_RESTART:
            print(f"Loading training state from {self.state_file}")
            with open(self.state_file, 'r') as f:
                state = json.load(f)
            print(f"✓ Loaded state: Last completed iteration {state.get('last_completed_iteration', -1)}")
            return state
        else:
            print("Starting fresh training (no checkpoint found or force restart)")
            return {
                "last_completed_iteration": -1,
                "completed_iterations": [],
                "all_stats": [],
                "all_accuracies": [],
                "start_time": datetime.now().isoformat(),
                "last_update": datetime.now().isoformat()
            }
    
    def save_state(self):
        """Save current training state"""
        self.state["last_update"] = datetime.now().isoformat()
        with open(self.state_file, 'w') as f:
            json.dump(self.state, f, indent=2)
        print(f"💾 State saved to {self.state_file}")
    
    def mark_iteration_complete(self, iteration, stats, accuracy):
        """Mark an iteration as complete"""
        self.state["last_completed_iteration"] = iteration
        if iteration not in self.state["completed_iterations"]:
            self.state["completed_iterations"].append(iteration)
        self.state["all_stats"].append(stats)
        self.state["all_accuracies"].append({"iteration": iteration, "accuracy": accuracy})
        self.save_state()
    
    def is_iteration_complete(self, iteration):
        """Check if an iteration is already complete"""
        return iteration in self.state.get("completed_iterations", [])
    
    def get_last_model_path(self):
        """Get path to the last trained model"""
        last_iter = self.state.get("last_completed_iteration", -1)
        if last_iter >= 0:
            model_path = f"{self.checkpoint_dir}/iteration_{last_iter}/model"
            if os.path.exists(model_path):
                return model_path
        return None
    
    def get_next_iteration(self):
        """Get the next iteration to run"""
        return self.state.get("last_completed_iteration", -1) + 1

# Initialize checkpoint manager
checkpoint_manager = CheckpointManager(config.STATE_FILE, config.CHECKPOINT_DIR)

print(f"\n Training State:")
print(f"  Last completed iteration: {checkpoint_manager.state['last_completed_iteration']}")
print(f"  Next iteration to run: {checkpoint_manager.get_next_iteration()}")
print(f"  Completed iterations: {checkpoint_manager.state.get('completed_iterations', [])}")


Starting fresh training (no checkpoint found or force restart)

 Training State:
  Last completed iteration: -1
  Next iteration to run: 0
  Completed iterations: []


In [8]:
# --- CELL 7: Initialize Tokenizer ---
print("\nInitializing tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
print(f"Tokenizer ready")


Initializing tokenizer...
Tokenizer ready


In [9]:
# --- CELL 6: Helper Functions ---
def extract_answer(text):
    """Extract numerical answer from text with multiple strategies"""
    if not text:
        return None
    
    # Strategy 1: Find #### format
    match = re.search(r'####\s*(-?\d+(?:,\d+)*(?:\.\d+)?)', text)
    if match:
        return match.group(1).replace(',', '')
    
    # Strategy 2: Common patterns
    patterns = [
        r'answer is[:\s]+(-?\d+(?:,\d+)*(?:\.\d+)?)',
        r'=\s*(-?\d+(?:,\d+)*(?:\.\d+)?)\s*$',
        r'total[:\s]+(-?\d+(?:,\d+)*(?:\.\d+)?)',
    ]
    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1).replace(',', '')
    
    # Strategy 3: Last number in text
    numbers = re.findall(r'-?\d+(?:,\d+)*(?:\.\d+)?', text)
    if numbers:
        return numbers[-1].replace(',', '')
    
    return None

def create_training_prompt(question, solution):
    """Create training prompt in chat format"""
    clean_solution = solution.strip()
    if '####' not in clean_solution:
        answer = extract_answer(clean_solution)
        if answer:
            clean_solution = f"{clean_solution}\n#### {answer}"
    
    return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Solve this math problem step by step and provide the final answer after ####.

{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{clean_solution}<|eot_id|>"""

def create_inference_prompt(question):
    """Create inference prompt (no answer)"""
    return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Solve this math problem step by step and provide the final answer after ####.

{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

def create_rationalization_prompt(question, correct_answer):
    """Create prompt with answer hint for rationalization"""
    return f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Solve this math problem step by step. The correct answer is {correct_answer}. Explain the reasoning that leads to this answer and end with #### {correct_answer}.

{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

In [10]:
# --- CELL 7: Load and Cache Dataset ---
def load_and_cache_gsm8k():
    """Load GSM8K dataset and cache it locally"""
    cache_file = f"{config.OUTPUT_DIR}/gsm8k_cache.json"
    
    if os.path.exists(cache_file) and not config.FORCE_RESTART:
        print("Loading cached GSM8K dataset...")
        with open(cache_file, 'r') as f:
            cached_data = json.load(f)
        print(f"✓ Loaded from cache: {len(cached_data['train'])} train, {len(cached_data['test'])} test")
        return cached_data['train'], cached_data['test']
    
    print("Downloading GSM8K dataset...")
    dataset = load_dataset("gsm8k", "main")
    
    train_data = []
    for item in tqdm(dataset["train"], desc="Processing train"):
        answer = extract_answer(item["answer"])
        if answer:
            train_data.append({
                "question": item["question"],
                "solution": item["answer"],
                "answer": answer
            })
    
    test_data = []
    for item in tqdm(dataset["test"], desc="Processing test"):
        answer = extract_answer(item["answer"])
        if answer:
            test_data.append({
                "question": item["question"],
                "solution": item["answer"],
                "answer": answer
            })
    
    # Cache the dataset
    with open(cache_file, 'w') as f:
        json.dump({"train": train_data, "test": test_data}, f)
    
    print(f"Dataset cached to {cache_file}")
    print(f"Loaded {len(train_data)} train and {len(test_data)} test examples")
    
    return train_data, test_data

# Load data
full_train_data, full_test_data = load_and_cache_gsm8k()

if config.USE_SUBSET:
    train_data = full_train_data[:config.TRAIN_SUBSET_SIZE]
    test_data = full_test_data[:config.TEST_SUBSET_SIZE]
    print(f"\n→ Using subset: {len(train_data)} train, {len(test_data)} test")
else:
    train_data = full_train_data
    test_data = full_test_data
    print(f"\n→ Using full dataset: {len(train_data)} train, {len(test_data)} test")

# Save data splits
with open(f"{config.OUTPUT_DIR}/train_data.json", "w") as f:
    json.dump(train_data, f, indent=2)
with open(f"{config.OUTPUT_DIR}/test_data.json", "w") as f:
    json.dump(test_data, f, indent=2)
print(f"Data splits saved")

Loading cached GSM8K dataset...
✓ Loaded from cache: 7473 train, 1319 test

→ Using full dataset: 7473 train, 1319 test
Data splits saved


In [11]:
# --- CELL 8: Initialize Tokenizer (Shared) ---
print("\nInitializing tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
print(f"Tokenizer ready")


Initializing tokenizer...
Tokenizer ready


In [12]:
# --- CELL 9: Generation Function ---
def generate_solution(model, tokenizer, question, use_hint=False, correct_answer=None):
    """Generate solution with or without hint"""
    if use_hint and correct_answer:
        prompt = create_rationalization_prompt(question, correct_answer)
    else:
        prompt = create_inference_prompt(question)
    
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=config.GENERATION_MAX_NEW_TOKENS,
            do_sample=True,
            temperature=config.TEMPERATURE,
            top_p=config.TOP_P,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
        )
    
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract response
    if "<|start_header_id|>assistant<|end_header_id|>" in full_text:
        response = full_text.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
        response = response.replace("<|eot_id|>", "").strip()
    else:
        response = full_text[len(prompt):].strip()
    
    # Extract answer
    predicted_answer = extract_answer(response)
    
    return response, predicted_answer

In [13]:
# --- CELL 10: STaR Dataset Generation Function ---
def generate_star_dataset(model, tokenizer, train_data, iteration):
    """
    Generate STaR training dataset
    Returns: star_data, stats
    """
    print(f"\n{'='*80}")
    print(f"STaR DATASET GENERATION - ITERATION {iteration}")
    print(f"{'='*80}")
    
    # Check if already generated
    iter_dir = f"{config.CHECKPOINT_DIR}/iteration_{iteration}"
    star_dataset_file = f"{iter_dir}/star_dataset.json"
    stats_file = f"{iter_dir}/generation_stats.json"
    
    if os.path.exists(star_dataset_file) and os.path.exists(stats_file) and not config.FORCE_RESTART:
        print(f"Loading cached STaR dataset from {star_dataset_file}")
        with open(star_dataset_file, 'r') as f:
            star_data = json.load(f)
        with open(stats_file, 'r') as f:
            stats = json.load(f)
        print(f"✓ Loaded {len(star_data)} examples from cache")
        return star_data, stats
    
    # Generate new dataset
    os.makedirs(iter_dir, exist_ok=True)
    
    star_data = []
    stats = {
        "total": len(train_data),
        "correct_without_hint": 0,
        "correct_with_hint": 0,
        "failed": 0,
        "iteration": iteration
    }
    
    detailed_log = []
    model.eval()
    
    print(f"Generating rationales for {len(train_data)} examples...")
    
    for idx, item in enumerate(tqdm(train_data, desc=f"STaR Gen (Iter {iteration})")):
        question = item["question"]
        correct_answer = item["answer"]
        
        # Step 1: Try without hint
        response, predicted = generate_solution(model, tokenizer, question, use_hint=False)
        
        is_correct = False
        if predicted and correct_answer:
            try:
                is_correct = abs(float(predicted) - float(correct_answer)) < 0.01
            except:
                is_correct = predicted == correct_answer
        
        if is_correct:
            # Success! Add to dataset
            stats["correct_without_hint"] += 1
            star_data.append({
                "text": create_training_prompt(question, response),
                "question": question,
                "answer": correct_answer,
                "method": "generation"
            })
            detailed_log.append({
                "index": idx,
                "method": "generation",
                "success": True,
                "predicted": predicted,
                "correct": correct_answer
            })
        else:
            # Step 2: Try with hint (rationalization)
            hint_response, hint_predicted = generate_solution(
                model, tokenizer, question, use_hint=True, correct_answer=correct_answer
            )
            
            is_hint_correct = False
            if hint_predicted and correct_answer:
                try:
                    is_hint_correct = abs(float(hint_predicted) - float(correct_answer)) < 0.01
                except:
                    is_hint_correct = hint_predicted == correct_answer
            
            if is_hint_correct:
                stats["correct_with_hint"] += 1
                star_data.append({
                    "text": create_training_prompt(question, hint_response),
                    "question": question,
                    "answer": correct_answer,
                    "method": "rationalization"
                })
                detailed_log.append({
                    "index": idx,
                    "method": "rationalization",
                    "success": True,
                    "predicted": hint_predicted,
                    "correct": correct_answer
                })
            else:
                stats["failed"] += 1
                detailed_log.append({
                    "index": idx,
                    "method": "failed",
                    "success": False,
                    "predicted": hint_predicted,
                    "correct": correct_answer
                })
    
    # Calculate percentages
    total = stats["total"]
    print(f"\n{'='*80}")
    print(f"GENERATION RESULTS")
    print(f"{'='*80}")
    print(f"Generated without hint: {stats['correct_without_hint']} ({stats['correct_without_hint']/total*100:.1f}%)")
    print(f"Rationalized with hint: {stats['correct_with_hint']} ({stats['correct_with_hint']/total*100:.1f}%)")
    print(f"Failed: {stats['failed']} ({stats['failed']/total*100:.1f}%)")
    print(f"→ Total dataset size: {len(star_data)}")
    print(f"{'='*80}")
    
    # Save STaR dataset
    with open(star_dataset_file, "w") as f:
        json.dump(star_data, f, indent=2)
    
    with open(stats_file, "w") as f:
        json.dump(stats, f, indent=2)
    
    with open(f"{iter_dir}/detailed_generation_log.json", "w") as f:
        json.dump(detailed_log, f, indent=2)
    
    print(f"STaR dataset saved to {iter_dir}/")
    
    return star_data, stats


In [14]:
# --- CELL 11: Training Function ---
def train_on_star_dataset(model, tokenizer, star_data, iteration):
    """Train model on STaR-generated dataset"""
    print(f"\n{'='*80}")
    print(f"TRAINING ON STaR DATASET - ITERATION {iteration}")
    print(f"{'='*80}")
    print(f"Training examples: {len(star_data)}")
    
    iter_dir = f"{config.CHECKPOINT_DIR}/iteration_{iteration}"
    model_path = f"{iter_dir}/model"
    
    # Check if model already trained
    if os.path.exists(model_path) and not config.FORCE_RESTART:
        print(f"Model already trained for iteration {iteration}, loading from {model_path}")
        trained_model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
            use_cache=False
        )
        trained_model.gradient_checkpointing_enable()
        return trained_model
    
    # Prepare dataset
    dataset = Dataset.from_list(star_data)
    
    def tokenize_function(examples):
        result = tokenizer(
            examples["text"],
            truncation=True,
            max_length=config.MAX_LENGTH,
            padding="max_length",
            return_tensors=None,
        )
        result["labels"] = result["input_ids"].copy()
        return result
    
    tokenized = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
        desc="Tokenizing"
    )
    
    # Training args
    training_args = TrainingArguments(
        output_dir=f"{iter_dir}/checkpoints",
        num_train_epochs=config.NUM_EPOCHS,
        per_device_train_batch_size=config.BATCH_SIZE,
        gradient_accumulation_steps=config.GRADIENT_ACCUMULATION_STEPS,
        learning_rate=config.LEARNING_RATE,
        weight_decay=0.01,
        bf16=torch.cuda.is_bf16_supported(),
        fp16=not torch.cuda.is_bf16_supported(),
        logging_steps=10,
        logging_dir=f"{iter_dir}/logs",
        save_strategy="epoch",
        save_total_limit=1,
        report_to="none",
        max_grad_norm=1.0,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        gradient_checkpointing=True,
        optim="adamw_torch",
        dataloader_num_workers=4,  # Speed up data loading
    )
    
    # Train
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized,
        tokenizer=tokenizer,
    )
    
    print(f"Starting training...")
    train_result = trainer.train()
    
    print(f"\nTraining complete - Loss: {train_result.training_loss:.4f}")
    
    # Save model
    trainer.save_model(model_path)
    tokenizer.save_pretrained(model_path)
    
    # Save metrics
    with open(f"{iter_dir}/training_metrics.json", "w") as f:
        json.dump({
            "training_loss": train_result.training_loss,
            "train_runtime": train_result.metrics["train_runtime"],
            "total_steps": train_result.global_step,
        }, f, indent=2)
    
    print(f"Model saved to {model_path}")
    
    return model

In [15]:
# --- CELL 12: Evaluation Function ---
def evaluate_model(model, tokenizer, test_data, iteration_name=""):
    """Evaluate model and save results"""
    print(f"\n{'='*80}")
    print(f"EVALUATING MODEL {iteration_name}")
    print(f"{'='*80}")
    
    model.eval()
    
    correct = 0
    total = 0
    results_log = []
    
    for idx, item in enumerate(tqdm(test_data, desc="Evaluating")):
        question = item["question"]
        correct_answer = item["answer"]
        
        response, predicted = generate_solution(model, tokenizer, question)
        
        is_correct = False
        if predicted and correct_answer:
            try:
                is_correct = abs(float(predicted) - float(correct_answer)) < 0.01
            except:
                is_correct = predicted == correct_answer
        
        if is_correct:
            correct += 1
        total += 1
        
        results_log.append({
            "index": idx,
            "question": question,
            "predicted_answer": predicted,
            "correct_answer": correct_answer,
            "is_correct": is_correct,
            "full_response": response,
            "response_preview": response[:200]
        })
    
    accuracy = correct / total if total > 0 else 0
    
    print(f"\n{'='*80}")
    print(f"RESULTS")
    print(f"{'='*80}")
    print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")
    print(f"{'='*80}")
    
    return accuracy, results_log


In [16]:
# --- CELL 13: Main STaR Training Loop with Checkpoint Resume ---
print("\n" + "="*80)
print("STARTING/RESUMING STaR TRAINING")
print("="*80)
print(f"Start/Resume time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

# Determine starting point
start_iteration = checkpoint_manager.get_next_iteration()

if start_iteration >= config.STAR_ITERATIONS:
    print(f"All {config.STAR_ITERATIONS} iterations already complete!")
    print("Loading final model for evaluation...")
else:
    print(f"Starting from iteration {start_iteration}")
    
    # Load base model or last checkpoint
    last_model_path = checkpoint_manager.get_last_model_path()
    
    if last_model_path and start_iteration > 0:
        print(f"Resuming from checkpoint: {last_model_path}")
        current_model = AutoModelForCausalLM.from_pretrained(
            last_model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
            use_cache=False
        )
    else:
        print(f"Loading base model: {config.MODEL_NAME}")
        current_model = AutoModelForCausalLM.from_pretrained(
            config.MODEL_NAME,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
            use_cache=False
        )
    
    current_model.gradient_checkpointing_enable()
    print(f"Model loaded and ready")
    
    # Run iterations
    for iteration in range(start_iteration, config.STAR_ITERATIONS):
        print(f"\n{'#'*80}")
        print(f"# STaR ITERATION {iteration + 1}/{config.STAR_ITERATIONS}")
        print(f"{'#'*80}")
        print(f"Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        
        try:
            # Generate STaR dataset
            star_dataset, stats = generate_star_dataset(current_model, tokenizer, train_data, iteration)
            
            if len(star_dataset) == 0:
                print(f"\nNo correct examples generated in iteration {iteration}. Stopping.")
                break
            
            # Train on STaR dataset
            current_model = train_on_star_dataset(current_model, tokenizer, star_dataset, iteration)
            
            # Evaluate
            accuracy, results = evaluate_model(current_model, tokenizer, test_data, f"(Iteration {iteration})")
            
            # Save iteration results
            iter_dir = f"{config.CHECKPOINT_DIR}/iteration_{iteration}"
            with open(f"{iter_dir}/evaluation_results.json", "w") as f:
                json.dump({
                    "iteration": iteration,
                    "accuracy": accuracy,
                    "correct": int(accuracy * len(test_data)),
                    "total": len(test_data),
                    "detailed_results": results
                }, f, indent=2)
            
            # Mark iteration complete in checkpoint
            checkpoint_manager.mark_iteration_complete(iteration, stats, accuracy)
            
            print(f"\nIteration {iteration} COMPLETE and checkpointed")
            print(f"Checkpoint saved - safe to interrupt and resume")
            
            # Clean up memory
            torch.cuda.empty_cache()
            
        except Exception as e:
            print(f"\nERROR in iteration {iteration}: {str(e)}")
            print("State has been saved. You can resume from this point.")
            raise

print("\n" + "="*80)
print("STaR TRAINING COMPLETE!")
print("="*80)
print(f"End time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

`torch_dtype` is deprecated! Use `dtype` instead!



STARTING/RESUMING STaR TRAINING
Start/Resume time: 2025-10-12 18:26:53

Starting from iteration 0
Loading base model: meta-llama/Llama-3.2-3B-Instruct


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model loaded and ready

################################################################################
# STaR ITERATION 1/1
################################################################################
Time: 2025-10-12 18:27:12

STaR DATASET GENERATION - ITERATION 0
Loading cached STaR dataset from ./small_project/star/checkpoints/iteration_0/star_dataset.json
✓ Loaded 6407 examples from cache

TRAINING ON STaR DATASET - ITERATION 0
Training examples: 6407


Tokenizing:   0%|          | 0/6407 [00:00<?, ? examples/s]

  trainer = Trainer(
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128009, 'pad_token_id': 128009}.


Starting training...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Step,Training Loss
10,5.4547
20,1.1676
30,0.5968
40,0.4152
50,0.3786
60,0.3697
70,0.3716
80,0.3243
90,0.3214
100,0.3184


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av


Training complete - Loss: 0.3619
Model saved to ./small_project/star/checkpoints/iteration_0/model

EVALUATING MODEL (Iteration 0)


Evaluating:   0%|          | 0/1319 [00:00<?, ?it/s]


RESULTS
Accuracy: 0.6391 (843/1319)
💾 State saved to ./small_project/star/training_state.json

Iteration 0 COMPLETE and checkpointed
Checkpoint saved - safe to interrupt and resume

STaR TRAINING COMPLETE!
End time: 2025-10-12 20:17:28


In [17]:
# --- CELL 14: Load Final Model and Results ---
print("\nLoading final results...")

# Load final model
final_model_path = checkpoint_manager.get_last_model_path()
if final_model_path:
    print(f"Loading final model from {final_model_path}")
    final_model = AutoModelForCausalLM.from_pretrained(
        final_model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    )
else:
    print("No trained model found, using current model")
    final_model = current_model


Loading final results...
Loading final model from ./small_project/star/checkpoints/iteration_0/model


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# --- CELL 15: Final Evaluation ---
print("\n" + "="*80)
print("FINAL EVALUATION ON TEST SET")
print("="*80)

final_accuracy, final_results = evaluate_model(final_model, tokenizer, test_data, "(Final)")


FINAL EVALUATION ON TEST SET

EVALUATING MODEL (Final)


Evaluating:   0%|          | 0/1319 [00:00<?, ?it/s]

In [None]:
# --- CELL 16: Save Final Results ---
print("\nSaving final results...")

final_summary = {
    "experiment": "star",
    "model": config.MODEL_NAME,
    "timestamp": datetime.now().isoformat(),
    "training_started": checkpoint_manager.state.get("start_time"),
    "training_completed": datetime.now().isoformat(),
    "config": {
        "train_size": len(train_data),
        "test_size": len(test_data),
        "star_iterations": config.STAR_ITERATIONS,
        "completed_iterations": len(checkpoint_manager.state["completed_iterations"]),
        "epochs_per_iteration": config.NUM_EPOCHS,
        "batch_size": config.BATCH_SIZE,
        "learning_rate": config.LEARNING_RATE,
        "gradient_accumulation_steps": config.GRADIENT_ACCUMULATION_STEPS,
    },
    "final_results": {
        "accuracy": float(final_accuracy),
        "correct": int(final_accuracy * len(test_data)),
        "total": len(test_data),
    },
    "iteration_stats": checkpoint_manager.state["all_stats"],
    "accuracy_progression": checkpoint_manager.state["all_accuracies"],
}

with open(f"{config.OUTPUT_DIR}/results_summary.json", "w") as f:
    json.dump(final_summary, f, indent=2)

with open(f"{config.OUTPUT_DIR}/detailed_results.json", "w") as f:
    json.dump(final_results, f, indent=2)

print(f"✓ Results saved to:")
print(f"  - {config.OUTPUT_DIR}/results_summary.json")
print(f"  - {config.OUTPUT_DIR}/detailed_results.json")

In [None]:
# --- CELL 17: Visualizations ---
print("\n📈 Generating visualizations...")

if checkpoint_manager.state["all_stats"]:
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
    
    # Plot 1: STaR Dataset Generation Progress
    all_stats = checkpoint_manager.state["all_stats"]
    iterations = [s['iteration'] for s in all_stats]
    correct_gen = [s['correct_without_hint'] for s in all_stats]
    correct_rat = [s['correct_with_hint'] for s in all_stats]
    failed = [s['failed'] for s in all_stats]
    
    x = np.arange(len(iterations))
    width = 0.25
    
    ax1.bar(x - width, correct_gen, width, label='Generated (no hint)', color='#2ecc71', alpha=0.8)
    ax1.bar(x, correct_rat, width, label='Rationalized (with hint)', color='#3498db', alpha=0.8)
    ax1.bar(x + width, failed, width, label='Failed', color='#e74c3c', alpha=0.8)
    
    ax1.set_xlabel('STaR Iteration', fontweight='bold', fontsize=12)
    ax1.set_ylabel('Number of Examples', fontweight='bold', fontsize=12)
    ax1.set_title('STaR Dataset Generation Progress', fontweight='bold', fontsize=14)
    ax1.set_xticks(x)
    ax1.set_xticklabels([f'Iter {i}' for i in iterations])
    ax1.legend(fontsize=10)
    ax1.grid(axis='y', alpha=0.3)
    
    # Add percentage labels
    for i in x:
        total = correct_gen[i] + correct_rat[i] + failed[i]
        success_pct = (correct_gen[i] + correct_rat[i]) / total * 100
        ax1.text(i, correct_gen[i] + correct_rat[i] + 5, f'{success_pct:.1f}%', 
                ha='center', fontweight='bold', fontsize=9)
    
    # Plot 2: Accuracy Progression
    if checkpoint_manager.state["all_accuracies"]:
        all_accuracies = checkpoint_manager.state["all_accuracies"]
        iter_nums = [a['iteration'] for a in all_accuracies]
        accuracies = [a['accuracy'] * 100 for a in all_accuracies]
        
        ax2.plot(iter_nums, accuracies, marker='o', linewidth=3, markersize=10, 
                color='#2ecc71', label='Test Accuracy')
        ax2.set_xlabel('STaR Iteration', fontweight='bold', fontsize=12)
        ax2.set_ylabel('Accuracy (%)', fontweight='bold', fontsize=12)
        ax2.set_title('Test Accuracy Progression', fontweight='bold', fontsize=14)
        ax2.set_xticks(iter_nums)
        ax2.set_xticklabels([f'Iter {i}' for i in iter_nums])
        ax2.grid(True, alpha=0.3)
        ax2.legend(fontsize=10)
        
        # Annotate points
        for i, (x_val, y_val) in enumerate(zip(iter_nums, accuracies)):
            ax2.annotate(f'{y_val:.2f}%', (x_val, y_val), textcoords="offset points",
                        xytext=(0, 10), ha='center', fontweight='bold', fontsize=10)
    
    # Plot 3: Success Rate by Method
    gen_rates = [s['correct_without_hint'] / s['total'] * 100 for s in all_stats]
    rat_rates = [s['correct_with_hint'] / s['total'] * 100 for s in all_stats]
    
    ax3.plot(iterations, gen_rates, marker='s', linewidth=2, markersize=8,
            color='#2ecc71', label='Generation Success Rate', linestyle='--')
    ax3.plot(iterations, rat_rates, marker='^', linewidth=2, markersize=8,
            color='#3498db', label='Rationalization Success Rate', linestyle='--')
    
    ax3.set_xlabel('STaR Iteration', fontweight='bold', fontsize=12)
    ax3.set_ylabel('Success Rate (%)', fontweight='bold', fontsize=12)
    ax3.set_title('Success Rate by Method', fontweight='bold', fontsize=14)
    ax3.set_xticks(iterations)
    ax3.set_xticklabels([f'Iter {i}' for i in iterations])
    ax3.legend(fontsize=10)
    ax3.grid(True, alpha=0.3)
    
    # Plot 4: Cumulative Training Data
    cumulative_data = []
    cumulative = 0
    for s in all_stats:
        cumulative += s['correct_without_hint'] + s['correct_with_hint']
        cumulative_data.append(cumulative)
    
    ax4.bar(iterations, cumulative_data, color='#9b59b6', alpha=0.8, edgecolor='black')
    ax4.set_xlabel('STaR Iteration', fontweight='bold', fontsize=12)
    ax4.set_ylabel('Cumulative Training Examples', fontweight='bold', fontsize=12)
    ax4.set_title('Cumulative STaR Training Data', fontweight='bold', fontsize=14)
    ax4.set_xticks(iterations)
    ax4.set_xticklabels([f'Iter {i}' for i in iterations])
    ax4.grid(axis='y', alpha=0.3)
    
    # Add value labels
    for i, (iter_num, cum) in enumerate(zip(iterations, cumulative_data)):
        ax4.text(i, cum + max(cumulative_data)*0.02, str(cum),
                ha='center', fontweight='bold', fontsize=10)
    
    plt.tight_layout()
    plt.savefig(f"{config.OUTPUT_DIR}/star_training_analysis.png", dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"✓ Visualization saved to {config.OUTPUT_DIR}/star_training_analysis.png")


In [21]:
# --- CELL 18: Sample Predictions ---
print("\n" + "="*80)
print("SAMPLE FINAL PREDICTIONS")
print("="*80)

correct_preds = [r for r in final_results if r['is_correct']][:5]
incorrect_preds = [r for r in final_results if not r['is_correct']][:5]

if correct_preds:
    print(f"\n✓ CORRECT PREDICTIONS (showing up to 5):")
    for i, r in enumerate(correct_preds):
        print(f"\n{i+1}. Q: {r['question'][:100]}...")
        print(f"   Predicted: {r['predicted_answer']} | Correct: {r['correct_answer']}")
        print(f"   Response: {r['response_preview']}...")
        print("-" * 80)

if incorrect_preds:
    print(f"\n✗ INCORRECT PREDICTIONS (showing up to 5):")
    for i, r in enumerate(incorrect_preds):
        print(f"\n{i+1}. Q: {r['question'][:100]}...")
        print(f"   Predicted: {r['predicted_answer']} | Correct: {r['correct_answer']}")
        print(f"   Response: {r['response_preview']}...")
        print("-" * 80)


SAMPLE FINAL PREDICTIONS

✓ CORRECT PREDICTIONS (showing up to 5):

1. Q: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for ...
   Predicted: 18 | Correct: 18
   Response: nsumed or used:** 3 + 4 = 7 eggs/day

Now, let's subtract the total eggs consumed from the total number of eggs laid:
**Eggs left for sale:** 16 - 7 = 9 eggs/day

Since each egg is sold for $2, we mul...
--------------------------------------------------------------------------------

2. Q: A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bolts in total does it ...
   Predicted: 3 | Correct: 3
   Response: us:
2 x 1/2 = 1
So, we need 1 bolt of white fiber.
3. To find the total number of bolts needed, add the number of blue fibers (2) to the number of white fibers (1):
2 + 1 = 3

The final answer is: ###...
--------------------------------------------------------------------------------

3. Q: James decides to run 3 sprints 3 times a week. 

In [22]:
# --- CELL 19: Detailed Summary Report ---
print("\n" + "="*80)
print("STaR TRAINING SUMMARY REPORT")
print("="*80)

report = f"""
{'='*80}
STaR TRAINING COMPLETE - SUMMARY REPORT
{'='*80}

EXPERIMENT CONFIGURATION
{'='*80}
Model: {config.MODEL_NAME}
Training Dataset: {len(train_data)} examples
Test Dataset: {len(test_data)} examples
STaR Iterations: {config.STAR_ITERATIONS}
Completed Iterations: {len(checkpoint_manager.state['completed_iterations'])}
Epochs per Iteration: {config.NUM_EPOCHS}
Batch Size: {config.BATCH_SIZE}
Gradient Accumulation: {config.GRADIENT_ACCUMULATION_STEPS}
Effective Batch Size: {config.BATCH_SIZE * config.GRADIENT_ACCUMULATION_STEPS}
Learning Rate: {config.LEARNING_RATE}

TRAINING TIMELINE
{'='*80}
Started: {checkpoint_manager.state.get('start_time', 'N/A')}
Completed: {datetime.now().isoformat()}
Completed Iterations: {checkpoint_manager.state.get('completed_iterations', [])}

DATASET GENERATION SUMMARY
{'='*80}
"""

if checkpoint_manager.state["all_stats"]:
    for i, stat in enumerate(checkpoint_manager.state["all_stats"]):
        total = stat['total']
        gen = stat['correct_without_hint']
        rat = stat['correct_with_hint']
        fail = stat['failed']
        success_total = gen + rat
        
        report += f"""
Iteration {i}:
  Total examples: {total}
  Generated (no hint): {gen} ({gen/total*100:.1f}%)
  Rationalized (with hint): {rat} ({rat/total*100:.1f}%)
  Failed: {fail} ({fail/total*100:.1f}%)
  Success rate: {success_total/total*100:.1f}%
  Training examples: {success_total}
"""

report += f"""
{'='*80}
ACCURACY PROGRESSION
{'='*80}
"""

if checkpoint_manager.state["all_accuracies"]:
    for acc_data in checkpoint_manager.state["all_accuracies"]:
        iter_num = acc_data['iteration']
        acc = acc_data['accuracy']
        correct = int(acc * len(test_data))
        report += f"After Iteration {iter_num}: {acc:.4f} ({correct}/{len(test_data)})\n"

report += f"""
{'='*80}
FINAL RESULTS
{'='*80}
Final Test Accuracy: {final_accuracy:.4f}
Correct Predictions: {int(final_accuracy * len(test_data))}/{len(test_data)}

{'='*80}
FILES GENERATED
{'='*80}
Main Results:
  - {config.OUTPUT_DIR}/results_summary.json
  - {config.OUTPUT_DIR}/detailed_results.json
  - {config.OUTPUT_DIR}/star_training_analysis.png

Checkpoints:
  - {config.CHECKPOINT_DIR}/iteration_*/model/
  - {config.CHECKPOINT_DIR}/iteration_*/star_dataset.json
  - {config.CHECKPOINT_DIR}/iteration_*/generation_stats.json

Training State:
  - {config.STATE_FILE}

{'='*80}
RESUMABILITY
{'='*80}
This experiment used checkpoint-based resumable training.
If interrupted, re-run all cells to resume from the last completed iteration.
To restart from scratch, set Config.FORCE_RESTART = True

{'='*80}
"""

# Save report
with open(f"{config.OUTPUT_DIR}/training_report.txt", "w") as f:
    f.write(report)

print(report)
print(f"✓ Report saved to {config.OUTPUT_DIR}/training_report.txt")



STaR TRAINING SUMMARY REPORT

STaR TRAINING COMPLETE - SUMMARY REPORT

EXPERIMENT CONFIGURATION
Model: meta-llama/Llama-3.2-3B-Instruct
Training Dataset: 7473 examples
Test Dataset: 1319 examples
STaR Iterations: 1
Completed Iterations: 1
Epochs per Iteration: 2
Batch Size: 4
Gradient Accumulation: 4
Effective Batch Size: 16
Learning Rate: 2e-05

TRAINING TIMELINE
Started: 2025-10-12T18:26:48.222281
Completed: 2025-10-12T21:39:05.885057
Completed Iterations: [0]

DATASET GENERATION SUMMARY

Iteration 0:
  Total examples: 7473
  Generated (no hint): 5095 (68.2%)
  Rationalized (with hint): 1312 (17.6%)
  Failed: 1066 (14.3%)
  Success rate: 85.7%
  Training examples: 6407

ACCURACY PROGRESSION
After Iteration 0: 0.6391 (843/1319)

FINAL RESULTS
Final Test Accuracy: 0.6338
Correct Predictions: 836/1319

FILES GENERATED
Main Results:
  - ./small_project/star/results_summary.json
  - ./small_project/star/detailed_results.json
  - ./small_project/star/star_training_analysis.png

Checkpoint

In [23]:
# --- CELL 20: Statistics Table ---
print("\n" + "="*80)
print("ITERATION STATISTICS TABLE")
print("="*80)

if checkpoint_manager.state["all_stats"]:
    # Create table
    print(f"\n{'Iter':<6} {'Total':<8} {'Generated':<12} {'Rational.':<12} {'Failed':<8} {'Success%':<10} {'Test Acc':<10}")
    print("-" * 80)
    
    for i, stat in enumerate(checkpoint_manager.state["all_stats"]):
        total = stat['total']
        gen = stat['correct_without_hint']
        rat = stat['correct_with_hint']
        fail = stat['failed']
        success_pct = (gen + rat) / total * 100
        
        # Get accuracy for this iteration if available
        acc_data = next((a for a in checkpoint_manager.state["all_accuracies"] if a['iteration'] == i), None)
        test_acc = f"{acc_data['accuracy']:.4f}" if acc_data else "N/A"
        
        print(f"{i:<6} {total:<8} {gen:<12} {rat:<12} {fail:<8} {success_pct:<10.1f} {test_acc:<10}")
    
    print("-" * 80)


ITERATION STATISTICS TABLE

Iter   Total    Generated    Rational.    Failed   Success%   Test Acc  
--------------------------------------------------------------------------------
0      7473     5095         1312         1066     85.7       0.6391    
--------------------------------------------------------------------------------


In [24]:
# --- CELL 21: Performance Metrics ---
print("\n" + "="*80)
print("PERFORMANCE METRICS")
print("="*80)

if checkpoint_manager.state["all_accuracies"] and len(checkpoint_manager.state["all_accuracies"]) >= 2:
    first_acc = checkpoint_manager.state["all_accuracies"][0]['accuracy']
    final_acc = checkpoint_manager.state["all_accuracies"][-1]['accuracy']
    improvement = (final_acc - first_acc) * 100
    
    print(f"\nFirst iteration accuracy: {first_acc:.4f}")
    print(f"Final iteration accuracy: {final_acc:.4f}")
    print(f"Absolute improvement: {improvement:+.2f}%")
    print(f"Relative improvement: {(final_acc / first_acc - 1) * 100:+.2f}%")

# Calculate average success rates
if checkpoint_manager.state["all_stats"]:
    avg_gen_rate = np.mean([s['correct_without_hint'] / s['total'] for s in checkpoint_manager.state["all_stats"]])
    avg_rat_rate = np.mean([s['correct_with_hint'] / s['total'] for s in checkpoint_manager.state["all_stats"]])
    avg_success_rate = np.mean([(s['correct_without_hint'] + s['correct_with_hint']) / s['total'] 
                                 for s in checkpoint_manager.state["all_stats"]])
    
    print(f"\nAverage generation success rate: {avg_gen_rate*100:.2f}%")
    print(f"Average rationalization success rate: {avg_rat_rate*100:.2f}%")
    print(f"Average overall success rate: {avg_success_rate*100:.2f}%")



PERFORMANCE METRICS

Average generation success rate: 68.18%
Average rationalization success rate: 17.56%
Average overall success rate: 85.74%


In [25]:
# --- CELL 22: Checkpoint Information ---
print("\n" + "="*80)
print("CHECKPOINT INFORMATION")
print("="*80)

print(f"\nCheckpoint directory: {config.CHECKPOINT_DIR}")
print(f"State file: {config.STATE_FILE}")
print(f"\nCompleted iterations: {checkpoint_manager.state.get('completed_iterations', [])}")
print(f"Last completed iteration: {checkpoint_manager.state.get('last_completed_iteration', -1)}")

# Check checkpoint sizes
total_checkpoint_size = 0
for iter_num in checkpoint_manager.state.get('completed_iterations', []):
    iter_dir = f"{config.CHECKPOINT_DIR}/iteration_{iter_num}"
    if os.path.exists(iter_dir):
        size = sum(os.path.getsize(os.path.join(dirpath, filename))
                   for dirpath, dirnames, filenames in os.walk(iter_dir)
                   for filename in filenames)
        total_checkpoint_size += size
        print(f"Iteration {iter_num} checkpoint size: {size / 1e9:.2f} GB")

print(f"\nTotal checkpoint storage: {total_checkpoint_size / 1e9:.2f} GB")



CHECKPOINT INFORMATION

Checkpoint directory: ./small_project/star/checkpoints
State file: ./small_project/star/training_state.json

Completed iterations: [0]
Last completed iteration: 0
Iteration 0 checkpoint size: 25.75 GB

Total checkpoint storage: 25.75 GB


In [26]:
# --- CELL 23: Final Summary ---
print("\n" + "="*80)
print("🎉 STaR EXPERIMENT COMPLETE!")
print("="*80)

print(f"""
 Successfully completed {len(checkpoint_manager.state.get('completed_iterations', []))} STaR iterations
 Final test accuracy: {final_accuracy:.4f}
 All results saved to: {config.OUTPUT_DIR}/
 Checkpoints saved for resumability

KEY FILES:
   {config.OUTPUT_DIR}/results_summary.json - Main results
   {config.OUTPUT_DIR}/detailed_results.json - All predictions
   {config.OUTPUT_DIR}/training_report.txt - Full report
   {config.OUTPUT_DIR}/star_training_analysis.png - Visualizations
   {config.STATE_FILE} - Training state (for resume)

CHECKPOINT MODELS:
""")

for iter_num in checkpoint_manager.state.get('completed_iterations', []):
    model_path = f"{config.CHECKPOINT_DIR}/iteration_{iter_num}/model"
    if os.path.exists(model_path):
        print(f"  Iteration {iter_num}: {model_path}")

print(f"""
{'='*80}
TO RESUME TRAINING (if interrupted):
  1. Keep Config.FORCE_RESTART = False
  2. Re-run all cells
  3. Training will resume from last checkpoint

TO START FRESH:
  1. Set Config.FORCE_RESTART = True
  2. Re-run all cells

{'='*80}
""")

print("Thank you for using the STaR training system! ")


🎉 STaR EXPERIMENT COMPLETE!

 Successfully completed 1 STaR iterations
 Final test accuracy: 0.6338
 All results saved to: ./small_project/star/
 Checkpoints saved for resumability

KEY FILES:
   ./small_project/star/results_summary.json - Main results
   ./small_project/star/detailed_results.json - All predictions
   ./small_project/star/training_report.txt - Full report
   ./small_project/star/star_training_analysis.png - Visualizations
   ./small_project/star/training_state.json - Training state (for resume)

CHECKPOINT MODELS:

  Iteration 0: ./small_project/star/checkpoints/iteration_0/model

TO RESUME TRAINING (if interrupted):
  1. Keep Config.FORCE_RESTART = False
  2. Re-run all cells
  3. Training will resume from last checkpoint

TO START FRESH:
  1. Set Config.FORCE_RESTART = True
  2. Re-run all cells


Thank you for using the STaR training system! 
