# TinyZero with A*PO Implementation

This notebook implements TinyZero (a reproduction of DeepSeek R1 Zero) using A*PO (Optimal Advantage Regression) instead of GRPO for countdown and multiplication tasks.

## Overview
- **TinyZero**: A minimal reproduction of DeepSeek R1 Zero focusing on countdown and multiplication tasks
- **A*PO**: Optimal Advantage Regression algorithm for efficient policy optimization
- **Tasks**: Countdown and multiplication mathematical reasoning
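- **Framework**: Pure PyTorch. FSDP utilities are imported and configurable (`use_fsdp`, `fsdp_*`), but the training loop below currently runs on a single device without FSDP wrapping.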

## Requirements
- PyTorch for inference and training
- PyTorch FSDP for multi-GPU training (planned; not yet wired into the training loop)
- Single Python process execution
- Modal.com compatibility


## 1. Environment Setup and Imports


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)
import torch.optim as optim
import numpy as np
import random
import json
import math
import time
import os
import re
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass
from collections import defaultdict
import logging

# Route INFO-and-above records through the root handler and grab a module logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Seed every RNG the notebook draws from (torch, NumPy, stdlib) for repeatable runs.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(42)
del _seed_fn

# Prefer the GPU when one is visible; everything below places tensors on `device`.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    pass
else:
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"Current GPU: {torch.cuda.current_device()}")


## 2. Configuration and Hyperparameters


In [None]:
@dataclass
class Config:
    """All hyperparameters for the notebook: model, training, A*PO, task, FSDP, eval."""

    # ---- Transformer architecture ----
    vocab_size: int = 32000
    d_model: int = 1024
    n_heads: int = 16
    n_layers: int = 12
    d_ff: int = 4096
    max_seq_len: int = 512  # also the length of the positional-encoding table
    dropout: float = 0.1

    # ---- Optimisation ----
    batch_size: int = 8  # prompts per A*PO update
    learning_rate: float = 1e-4
    num_epochs: int = 10
    warmup_steps: int = 1000  # NOTE(review): not read by the cells in this notebook
    max_grad_norm: float = 1.0

    # ---- A*PO ----
    num_responses_per_prompt: int = 4  # samples drawn per prompt for value estimation
    beta1: float = 0.1  # KL regularization coefficient
    beta2: float = 0.1  # advantage-regression coefficient
    temperature: float = 0.7  # sampling temperature when generating training responses

    # ---- Task / data ----
    task_type: str = "countdown"  # one of "countdown", "multiplication"
    num_train_samples: int = 1000
    num_eval_samples: int = 200

    # ---- FSDP (NOTE(review): not referenced by the training code shown here) ----
    use_fsdp: bool = True
    fsdp_cpu_offload: bool = False
    fsdp_mixed_precision: bool = True

    # ---- Periodic evaluation / checkpointing, measured in optimizer steps ----
    eval_interval: int = 100
    save_interval: int = 500

config = Config()
print("Configuration:")
for key, value in vars(config).items():
    print(f"  {key}: {value}")


## 3. Task Environments


In [None]:
class TaskEnvironment:
    """Base class for task environments.

    Subclasses implement sample generation and response scoring. The base class
    provides a toy character-level tokenizer shared by all tasks.
    """

    def __init__(self, task_type: str, vocab_size: Optional[int] = None):
        """Create an environment.

        Args:
            task_type: Name of the task (e.g. "countdown").
            vocab_size: Modulus for the toy tokenizer. When omitted, falls back to
                the global notebook ``config.vocab_size``, so existing subclasses
                that call ``super().__init__(name)`` keep working unchanged.
        """
        self.task_type = task_type
        self.vocab_size = config.vocab_size if vocab_size is None else vocab_size

    def generate_sample(self) -> Dict[str, Any]:
        """Generate a single training sample (must be overridden)."""
        raise NotImplementedError

    def evaluate_response(self, prompt: str, response: str) -> float:
        """Return a reward for ``response`` to ``prompt`` (must be overridden)."""
        raise NotImplementedError

    def tokenize(self, text: str) -> List[int]:
        """Map each character to ``ord(c) % vocab_size`` (demo only; not a real tokenizer)."""
        return [ord(c) % self.vocab_size for c in text]

    def detokenize(self, tokens: List[int]) -> str:
        """Map token ids back to characters, silently dropping ids >= 128 (non-ASCII).

        Because of this filter, ``detokenize(tokenize(s))`` only round-trips ASCII text.
        """
        return ''.join([chr(token) for token in tokens if token < 128])


class CountdownTask(TaskEnvironment):
    """Countdown task: given a target number, use operations to reach it.

    Each sample lists a target in [1, 100] and three numbers in [1, 20]. The
    reference solution searches a single binary operation over ordered pairs of
    distinct numbers. Rewards are based on the first "= N" found in a response.
    """

    def __init__(self):
        super().__init__("countdown")
        self.operations = ['+', '-', '*', '/']

    def generate_sample(self) -> Dict[str, Any]:
        """Draw a random target and three starting numbers, and build the prompt."""
        # Draw order (target first, then numbers) is kept so seeded runs reproduce.
        target = random.randint(1, 100)
        start_nums = [random.randint(1, 20) for _ in range(3)]

        prompt = f"Target: {target}, Numbers: {start_nums}. Use these numbers and operations to reach the target."
        solution = self._generate_solution(start_nums, target)

        return {
            'prompt': prompt,
            'target': target,
            'start_nums': start_nums,
            'solution': solution,
            'tokens': self.tokenize(prompt)
        }

    def _generate_solution(self, start_nums: List[int], target: int) -> str:
        """Return "a op b = result" for the first ordered pair and operation that hits ``target``.

        Falls back to the sum of the first two numbers, annotated with the target,
        when no single operation matches. Expects at least two numbers.
        """
        for i, left in enumerate(start_nums):
            for j, right in enumerate(start_nums):
                if i == j:
                    continue
                for op in self.operations:
                    if op == '+':
                        result = left + right
                    elif op == '-':
                        result = left - right
                    elif op == '*':
                        result = left * right
                    elif op == '/':
                        if right == 0:
                            continue  # avoid ZeroDivisionError
                        result = left / right
                    else:
                        # Unknown symbol: skip rather than reuse a stale `result`.
                        continue

                    if abs(result - target) < 0.01:
                        return f"{left} {op} {right} = {result}"

        # If no exact solution found, return a close approximation
        return f"{start_nums[0]} + {start_nums[1]} = {start_nums[0] + start_nums[1]} (close to {target})"

    def evaluate_response(self, prompt: str, response: str) -> float:
        """Score ``response`` in [0, 1] by how close its first "= N" is to the target.

        Returns 0.0 when the prompt has no "Target: N" or the response has no "= N".
        Otherwise returns ``max(0, 1 - |N - target| / max(target, 1))`` as a float.
        """
        target_match = re.search(r'Target: (\d+)', prompt)
        if not target_match:
            return 0.0

        result_match = re.search(r'= (\d+(?:\.\d+)?)', response)
        if not result_match:
            return 0.0

        # Both groups match only digit strings, so int()/float() cannot raise here.
        target = int(target_match.group(1))
        result = float(result_match.group(1))
        # 0.0 (not 0) keeps the return type a float in every branch.
        return max(0.0, 1.0 - abs(result - target) / max(target, 1))


class MultiplicationTask(TaskEnvironment):
    """Multiplication task: compute the product of two numbers in [1, 99]."""

    def __init__(self):
        super().__init__("multiplication")

    def generate_sample(self) -> Dict[str, Any]:
        """Draw two random factors and build the prompt and reference solution."""
        a = random.randint(1, 99)
        b = random.randint(1, 99)

        prompt = f"What is {a} × {b}?"
        solution = f"{a} × {b} = {a * b}"

        return {
            'prompt': prompt,
            'a': a,
            'b': b,
            'solution': solution,
            'tokens': self.tokenize(prompt)
        }

    def evaluate_response(self, prompt: str, response: str) -> float:
        """Return 1.0 if the response's answer equals the product, else 0.0.

        The answer is the integer after the last "=" in the response, if there is
        one. Otherwise it is the first integer in the response. This scores a
        response written like the reference solution ("a × b = c") on ``c``,
        not on ``a``.
        """
        # The first two integers in the prompt are the factors.
        numbers = re.findall(r'\d+', prompt)
        if len(numbers) < 2:
            return 0.0
        expected = int(numbers[0]) * int(numbers[1])

        after_equals = re.findall(r'=\s*(\d+)', response)
        if after_equals:
            answer = int(after_equals[-1])
        else:
            first_number = re.search(r'\d+', response)
            if first_number is None:
                return 0.0
            answer = int(first_number.group(0))

        return 1.0 if answer == expected else 0.0


# Pick the task environment named by config.task_type; reject anything else up front.
if config.task_type not in ("countdown", "multiplication"):
    raise ValueError(f"Unknown task type: {config.task_type}")
task_env = CountdownTask() if config.task_type == "countdown" else MultiplicationTask()

print(f"Initialized {config.task_type} task environment")

# Smoke-test the environment with one freshly generated sample.
sample = task_env.generate_sample()
for label, field in (("Sample prompt", 'prompt'), ("Sample solution", 'solution')):
    print(f"{label}: {sample[field]}")
del label, field


## 4. Model Architecture


In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first input.

    The buffer ``pe`` has shape [max_len, 1, d_model]. Even columns hold sines and
    odd columns hold cosines of ``position * 10000^(-2i/d_model)``.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than there are frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # -> [max_len, 1, d_model]

        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add encodings to ``x``, indexed along dim 0 (expected [seq_len, batch, d_model]).

        Raises:
            ValueError: if ``x`` is longer than the precomputed table.
        """
        if x.size(0) > self.pe.size(0):
            raise ValueError(
                f"sequence length {x.size(0)} exceeds positional-encoding max_len {self.pe.size(0)}"
            )
        return x + self.pe[:x.size(0), :]


class TransformerBlock(nn.Module):
    """Post-norm transformer layer: self-attention then a ReLU MLP, each with residual + LayerNorm.

    Inputs are batch-first ([batch, seq, d_model]) because the attention module is
    built with ``batch_first=True``.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.d_model, self.n_heads = d_model, n_heads

        # Submodules are created in a fixed order so that seeded initialization and
        # state_dict keys stay stable.
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        mlp_layers = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        ]
        self.feed_forward = nn.Sequential(*mlp_layers)
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the layer; ``mask`` is passed straight through as the attention ``attn_mask``."""
        attended = self.self_attn(x, x, x, attn_mask=mask)[0]
        hidden = self.norm1(x + self.dropout(attended))
        return self.norm2(hidden + self.dropout(self.feed_forward(hidden)))


class TinyZeroModel(nn.Module):
    """Decoder-only transformer with an LM head and a scalar value head (used by A*PO)."""

    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        # Embedding layers
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_encoding = PositionalEncoding(config.d_model, config.max_seq_len)

        # Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(config.d_model, config.n_heads, config.d_ff, config.dropout)
            for _ in range(config.n_layers)
        ])

        # Output layers
        self.ln_f = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Value head for A*PO
        self.value_head = nn.Linear(config.d_model, 1)

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_ids, attention_mask=None, return_value=False):
        """Run the model on ``input_ids`` of shape [batch, seq_len].

        Returns logits of shape [batch, seq_len, vocab_size]. When ``return_value``
        is true, it also returns values of shape [batch, 1], read from the last
        position.

        NOTE: ``attention_mask`` is accepted for API compatibility but is not
        applied. Padding tokens are attended to, so batch only equal-length,
        unpadded sequences.
        """
        _, seq_len = input_ids.shape  # also enforces 2-D input

        x = self.token_embedding(input_ids)
        # PositionalEncoding indexes dim 0, so add it in sequence-first layout.
        x = self.pos_encoding(x.transpose(0, 1)).transpose(0, 1)
        x = self.dropout(x)

        # True above the diagonal = position may NOT attend to later positions.
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=input_ids.device), diagonal=1).bool()

        for block in self.transformer_blocks:
            x = block(x, mask=causal_mask)

        x = self.ln_f(x)
        logits = self.lm_head(x)

        if return_value:
            values = self.value_head(x[:, -1, :])  # [batch_size, 1]
            return logits, values

        return logits

    def generate(self, input_ids, max_length=100, temperature=1.0, do_sample=True):
        """Autoregressively append up to ``max_length`` tokens to ``input_ids``.

        Runs in eval mode without gradients and restores the previous train/eval
        mode afterwards. The context is cropped to the most recent
        ``config.max_seq_len`` tokens, because positional encodings only cover
        that many positions. Stops early once every sequence in the batch emits
        token 0, which is assumed to be the stop token.

        Returns the input ids with the generated tokens appended.
        """
        was_training = self.training
        self.eval()

        try:
            with torch.no_grad():
                for _ in range(max_length):
                    context = input_ids[:, -self.config.max_seq_len:]
                    next_token_logits = self.forward(context)[:, -1, :]

                    if do_sample:
                        probs = F.softmax(next_token_logits / temperature, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        # argmax is unchanged by positive scaling, so skip the
                        # temperature division (this also keeps temperature=0 valid).
                        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                    input_ids = torch.cat([input_ids, next_token], dim=1)

                    # Works for any batch size (``.item()`` only works for batch 1).
                    if bool((next_token == 0).all()):
                        break
        finally:
            self.train(was_training)

        return input_ids


# Build the policy network directly on the target device (Module.to returns the module).
model = TinyZeroModel(config).to(device)

# Parameter bookkeeping; the MB figure assumes 4 bytes (fp32) per parameter.
total_params = sum(map(torch.Tensor.numel, model.parameters()))
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model initialized on {device}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model size: {total_params * 4 / 1024 / 1024:.2f} MB")


## 5. A*PO Algorithm Implementation


In [None]:
class AstarPO:
    """A*PO (Optimal Advantage Regression) trainer.

    Each ``update``:
      1. samples ``num_responses_per_prompt`` responses per prompt from the frozen
         reference model and scores them with the task environment;
      2. estimates the optimal value of each prompt as
         ``V*(x) = beta1 * log(mean_i exp(r_i / beta1))`` over those rewards;
      3. regresses ``beta2 * log(pi(y|x) / pi_ref(y|x))`` onto the advantage
         ``r(x, y) - V*(x)`` with a squared loss. It adds an auxiliary MSE that
         pulls the value head toward ``V*`` and a ``beta1``-weighted token-level
         KL(pi || pi_ref) penalty.
    """

    def __init__(self, model, config: Config):
        self.model = model
        self.config = config
        self.optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate)

        # Reference model for value estimation (frozen copy)
        self.reference_model = TinyZeroModel(config)
        self.reference_model.load_state_dict(model.state_dict())
        self.reference_model.eval()

        # Freeze reference model parameters
        for param in self.reference_model.parameters():
            param.requires_grad = False

        self.reference_model = self.reference_model.to(device)

        # Training statistics
        self.step_count = 0
        self.loss_history = []

    def generate_responses(self, prompts: List[str], num_responses: int = None) -> List[Dict[str, Any]]:
        """Sample ``num_responses`` responses per prompt from the reference model.

        Returns a flat list, grouped by prompt, of dicts with keys 'prompt',
        'response' (detokenized text), 'tokens' (the sampled response token ids)
        and 'reward'.
        """
        if num_responses is None:
            num_responses = self.config.num_responses_per_prompt

        responses = []

        with torch.no_grad():
            for prompt in prompts:
                prompt_tokens = task_env.tokenize(prompt)
                input_ids = torch.tensor([prompt_tokens], device=device)

                prompt_responses = []

                for _ in range(num_responses):
                    generated_ids = self.reference_model.generate(
                        input_ids,
                        max_length=50,
                        temperature=self.config.temperature,
                        do_sample=True
                    )

                    # Response tokens are everything generated after the prompt.
                    response_tokens = generated_ids[0, len(prompt_tokens):].tolist()
                    response_text = task_env.detokenize(response_tokens)

                    reward = task_env.evaluate_response(prompt, response_text)

                    prompt_responses.append({
                        'prompt': prompt,
                        'response': response_text,
                        'tokens': response_tokens,
                        'reward': reward
                    })

                responses.extend(prompt_responses)

        return responses

    def estimate_optimal_values(self, responses: List[Dict[str, Any]]) -> torch.Tensor:
        """Estimate V*(x) for each response from the rewards of all samples of its prompt.

        A*PO stage 1: ``V*(x) ~= beta1 * log(mean_i exp(r_i / beta1))``, computed
        with logsumexp for stability. When ``beta1 <= 0``, this uses the
        beta1 -> 0 limit, ``max_i r_i``.

        Returns a float32 tensor aligned with ``responses``.
        """
        rewards_by_prompt = defaultdict(list)
        for response_data in responses:
            rewards_by_prompt[response_data['prompt']].append(float(response_data['reward']))

        beta = self.config.beta1
        v_star = {}
        for prompt, prompt_rewards in rewards_by_prompt.items():
            r = torch.tensor(prompt_rewards, dtype=torch.float32)
            if beta > 0:
                v_star[prompt] = (beta * (torch.logsumexp(r / beta, dim=0) - math.log(len(prompt_rewards)))).item()
            else:
                v_star[prompt] = r.max().item()

        return torch.tensor([v_star[d['prompt']] for d in responses], dtype=torch.float32, device=device)

    def compute_advantages(self, rewards: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        """Compute advantages ``A = R - V`` elementwise."""
        return rewards - values

    def compute_kl_divergence(self, logits_new: torch.Tensor, logits_ref: torch.Tensor) -> torch.Tensor:
        """Return per-position KL(new || ref) = sum_v p_new * (log p_new - log p_ref).

        The reduction is over the last (vocabulary) dimension only.
        """
        log_probs_new = F.log_softmax(logits_new, dim=-1)
        log_probs_ref = F.log_softmax(logits_ref, dim=-1)
        return (log_probs_new.exp() * (log_probs_new - log_probs_ref)).sum(dim=-1)

    @staticmethod
    def _response_token_log_probs(logits: torch.Tensor, input_ids: torch.Tensor, start: int) -> torch.Tensor:
        """Log-probs of tokens ``input_ids[:, start + 1:]`` under ``logits[:, start:-1]``.

        Logits at position t predict the token at t + 1. With ``start`` equal to
        prompt length minus 1, this gives one log-prob per response token.
        """
        log_probs = F.log_softmax(logits[:, start:-1, :], dim=-1)
        targets = input_ids[:, start + 1:].unsqueeze(-1)
        return log_probs.gather(-1, targets).squeeze(-1)

    def update(self, batch_data: List[Dict[str, Any]]) -> Dict[str, float]:
        """Perform one A*PO update on a batch of samples (each needs a 'prompt').

        Returns a dict with 'total_loss', 'policy_loss', 'kl_loss', 'value_loss',
        'mean_reward' and 'mean_advantage'. The dict is also appended to
        ``loss_history``.

        Raises:
            ValueError: if ``batch_data`` is empty.
            RuntimeError: if no sample yields a non-empty (prompt, response) pair.
        """
        if not batch_data:
            raise ValueError("batch_data must contain at least one sample")

        self.model.train()

        prompts = [data['prompt'] for data in batch_data]
        responses = self.generate_responses(prompts)

        # Explicit float dtype: rewards may be Python ints, and an all-int list
        # would give an int64 tensor whose .mean() raises.
        rewards = torch.tensor([r['reward'] for r in responses], dtype=torch.float32, device=device)
        optimal_values = self.estimate_optimal_values(responses)
        advantages = self.compute_advantages(rewards, optimal_values)

        policy_terms, value_terms, kl_terms = [], [], []

        for i, response_data in enumerate(responses):
            prompt_tokens = task_env.tokenize(response_data['prompt'])
            response_tokens = response_data['tokens']
            if not prompt_tokens or not response_tokens:
                continue

            # Train on exactly the ids that were sampled. Re-tokenizing the
            # detokenized text would drop non-ASCII ids and insert a separator space.
            input_ids = torch.tensor([prompt_tokens + response_tokens], device=device)
            start = len(prompt_tokens) - 1  # logits here predict the first response token

            logits_new, values_new = self.model(input_ids, return_value=True)
            with torch.no_grad():
                logits_ref = self.reference_model(input_ids)

            logp_new = self._response_token_log_probs(logits_new, input_ids, start)
            logp_ref = self._response_token_log_probs(logits_ref, input_ids, start)
            log_ratio = (logp_new - logp_ref).sum()  # log pi(y|x) - log pi_ref(y|x)

            # A*PO regression: beta2 * log-ratio should match the optimal advantage.
            policy_terms.append((self.config.beta2 * log_ratio - advantages[i]) ** 2)

            # Auxiliary: value head regresses onto the V* estimate.
            value_terms.append(F.mse_loss(values_new.squeeze(), optimal_values[i]))

            # KL regularization over the positions that generated response tokens.
            kl_div = self.compute_kl_divergence(logits_new[:, start:-1, :], logits_ref[:, start:-1, :])
            kl_terms.append(self.config.beta1 * kl_div.mean())

        if not policy_terms:
            raise RuntimeError("No non-empty (prompt, response) pairs to train on")

        policy_loss = torch.stack(policy_terms).mean()
        value_loss = torch.stack(value_terms).mean()
        kl_loss = torch.stack(kl_terms).mean()
        total_loss = policy_loss + value_loss + kl_loss

        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
        self.optimizer.step()

        self.step_count += 1

        # Update reference model periodically
        if self.step_count % 100 == 0:
            self.reference_model.load_state_dict(self.model.state_dict())

        loss_dict = {
            'total_loss': total_loss.item(),
            'policy_loss': policy_loss.item(),
            'kl_loss': kl_loss.item(),
            'value_loss': value_loss.item(),
            'mean_reward': rewards.mean().item(),
            'mean_advantage': advantages.mean().item()
        }

        self.loss_history.append(loss_dict)

        return loss_dict


# Initialize A*PO trainer
trainer = AstarPO(model, config)
print("A*PO trainer initialized")


## 6. Data Generation and Training


In [None]:
def generate_dataset(num_samples: int, task_env: TaskEnvironment) -> List[Dict[str, Any]]:
    """Draw ``num_samples`` fresh samples from ``task_env``.

    Prints a progress line every 100 samples, reported before that sample is
    drawn, and a completion line at the end. Returns the samples in draw order.
    """
    print(f"Generating {num_samples} samples for {task_env.task_type} task...")

    dataset: List[Dict[str, Any]] = []
    for count in range(1, num_samples + 1):
        if count % 100 == 0:
            print(f"Generated {count}/{num_samples} samples")
        dataset.append(task_env.generate_sample())

    print(f"Dataset generation complete: {len(dataset)} samples")
    return dataset


def create_data_loader(dataset: List[Dict[str, Any]], batch_size: int, shuffle: bool = True):
    """Split ``dataset`` into consecutive batches of ``batch_size`` items.

    The last batch may be shorter. When ``shuffle`` is true, the order of the
    batches (not the samples within them) is shuffled in place using the global
    ``random`` module. Samples are never moved between batches.

    Args:
        dataset: Samples to batch.
        batch_size: Items per batch; must be >= 1.
        shuffle: Whether to shuffle the batch order.

    Returns:
        A list of batches, each a list slice of ``dataset``.

    Raises:
        ValueError: if ``batch_size`` < 1. A negative step would otherwise
            silently produce no batches.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    batches = [dataset[start:start + batch_size] for start in range(0, len(dataset), batch_size)]

    if shuffle:
        random.shuffle(batches)

    return batches


def evaluate_model(model, eval_batches: List[List[Dict[str, Any]]], task_env: TaskEnvironment) -> Dict[str, float]:
    """Greedy-decode up to 50 tokens for every eval prompt and score it with ``task_env``.

    Returns 'accuracy' (fraction of samples with reward > 0.9), 'avg_reward',
    'total_samples' and 'correct_predictions'. Both ratios are 0.0 when there
    are no samples. Leaves ``model`` in eval mode.
    """
    model.eval()

    rewards: List[float] = []
    with torch.no_grad():
        for sample in (item for batch in eval_batches for item in batch):
            prompt = sample['prompt']
            prompt_ids = task_env.tokenize(prompt)

            output_ids = model.generate(
                torch.tensor([prompt_ids], device=device),
                max_length=50,
                temperature=0.1,  # only rescales logits; greedy argmax ignores it
                do_sample=False,
            )

            completion = task_env.detokenize(output_ids[0, len(prompt_ids):].tolist())
            rewards.append(task_env.evaluate_response(prompt, completion))

    n_samples = len(rewards)
    n_correct = sum(1 for r in rewards if r > 0.9)

    return {
        'accuracy': n_correct / n_samples if n_samples > 0 else 0.0,
        'avg_reward': sum(rewards) / n_samples if n_samples > 0 else 0.0,
        'total_samples': n_samples,
        'correct_predictions': n_correct,
    }


def train_model(trainer: AstarPO, train_batches: List[List[Dict[str, Any]]],
                eval_batches: List[List[Dict[str, Any]]], config: Config):
    """Run ``config.num_epochs`` passes of ``trainer.update`` over ``train_batches``.

    Every 10 batches, prints the mean loss and reward of the last 10 updates.
    Evaluates every ``config.eval_interval`` optimizer steps and writes a
    checkpoint every ``config.save_interval`` steps. Evaluates again at the end
    of each epoch. Returns ``trainer.loss_history``.
    """
    n_batches = len(train_batches)
    print(f"Starting training for {config.num_epochs} epochs...")
    print(f"Total training batches: {n_batches}")

    run_start = time.time()

    for epoch in range(1, config.num_epochs + 1):
        epoch_start = time.time()
        history: List[Dict[str, float]] = []

        print(f"\nEpoch {epoch}/{config.num_epochs}")

        for batch_number, batch in enumerate(train_batches, start=1):
            history.append(trainer.update(batch))

            if batch_number % 10 == 0:
                recent = history[-10:]
                mean_loss = np.mean([h['total_loss'] for h in recent])
                mean_reward = np.mean([h['mean_reward'] for h in recent])
                print(f"  Batch {batch_number}/{n_batches}: Loss={mean_loss:.4f}, Reward={mean_reward:.4f}")

            step = trainer.step_count

            if step % config.eval_interval == 0:
                print(f"\nEvaluating at step {step}...")
                results = evaluate_model(trainer.model, eval_batches, task_env)
                print(f"Evaluation results: {results}")

            if step % config.save_interval == 0:
                torch.save(
                    {
                        'model_state_dict': trainer.model.state_dict(),
                        'optimizer_state_dict': trainer.optimizer.state_dict(),
                        'step_count': step,
                        'config': config,
                        'loss_history': trainer.loss_history,
                    },
                    f'checkpoint_step_{step}.pt',
                )
                print(f"Checkpoint saved at step {step}")

        elapsed = time.time() - epoch_start
        epoch_loss = np.mean([h['total_loss'] for h in history])
        epoch_reward = np.mean([h['mean_reward'] for h in history])

        print(f"\nEpoch {epoch} completed in {elapsed:.2f}s")
        print(f"Average loss: {epoch_loss:.4f}")
        print(f"Average reward: {epoch_reward:.4f}")

        print(f"\nFinal evaluation for epoch {epoch}...")
        results = evaluate_model(trainer.model, eval_batches, task_env)
        print(f"Final evaluation results: {results}")

    print(f"\nTraining completed in {time.time() - run_start:.2f}s")

    return trainer.loss_history


# Draw independent train/eval splits from the active task environment.
print("Generating training dataset...")
train_dataset = generate_dataset(config.num_train_samples, task_env)
print("Generating evaluation dataset...")
eval_dataset = generate_dataset(config.num_eval_samples, task_env)

# Training batch order is shuffled once here; evaluation order stays fixed.
train_batches = create_data_loader(train_dataset, config.batch_size, shuffle=True)
eval_batches = create_data_loader(eval_dataset, config.batch_size, shuffle=False)

print(f"Training batches: {len(train_batches)}")
print(f"Evaluation batches: {len(eval_batches)}")
print(f"Batch size: {config.batch_size}")

# Kick off the main A*PO training loop.
print("Starting training...")
loss_history = train_model(trainer, train_batches, eval_batches, config)


## 7. Results Analysis and Visualization
 

## 8. Final Evaluation and Model Saving


In [None]:
# ---- Final report: aggregate metrics, a few qualitative samples, run settings ----
print("\n" + "=" * 50)
print("FINAL EVALUATION")
print("=" * 50)

# Aggregate metrics over every evaluation batch.
final_eval_results = evaluate_model(trainer.model, eval_batches, task_env)
print("\nFinal Evaluation Results:")
for key, value in final_eval_results.items():
    print(f"  {key}: {value}")

# Show greedy generations for the first five evaluation prompts.
print("\nTesting on specific examples:")
test_samples = eval_dataset[:5]

for i, sample in enumerate(test_samples):
    print(f"\nExample {i + 1}:")
    print(f"Prompt: {sample['prompt']}")
    print(f"Expected: {sample['solution']}")

    prompt_tokens = task_env.tokenize(sample['prompt'])
    generated_ids = trainer.model.generate(
        torch.tensor([prompt_tokens], device=device),
        max_length=50,
        temperature=0.1,
        do_sample=False,
    )
    response_text = task_env.detokenize(generated_ids[0, len(prompt_tokens):].tolist())
    reward = task_env.evaluate_response(sample['prompt'], response_text)

    print(f"Generated: {response_text}")
    print(f"Reward: {reward:.4f}")

# Parameter counts (MB figure assumes 4 bytes per parameter).
print("\nModel Statistics:")
print(f"Total parameters: {sum(p.numel() for p in trainer.model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in trainer.model.parameters() if p.requires_grad):,}")
print(f"Model size: {sum(p.numel() for p in trainer.model.parameters()) * 4 / 1024 / 1024:.2f} MB")

# Run settings, only reported when training produced any history.
if loss_history:
    print("\nTraining Statistics:")
    print(f"Total training steps: {len(loss_history)}")
    print(f"Total epochs: {config.num_epochs}")
    print(f"Batch size: {config.batch_size}")
    print(f"Learning rate: {config.learning_rate}")
    print("A*PO parameters:")
    print(f"  - Number of responses per prompt: {config.num_responses_per_prompt}")
    print(f"  - Beta1 (KL regularization): {config.beta1}")
    print(f"  - Beta2 (Advantage regression): {config.beta2}")
    print(f"  - Temperature: {config.temperature}")

print("\n" + "=" * 50)
print("TRAINING COMPLETED SUCCESSFULLY")
print("=" * 50)
