# Tunix Reasoning Model Trainer
# Training Gemma2-2B with GRPO for Transparent Reasoning

**Google Tunix Hackathon Submission**

This notebook trains Gemma2 2B to produce step-by-step reasoning traces using:
- GRPO (Group Relative Policy Optimization) with Tunix
- Custom reward functions for reasoning quality
- Format enforcement for structured output

**Hardware**: TPU v3-8 (8+ hour session recommended)

**Output Format**: `<reasoning>...</reasoning><answer>...</answer>`

**Reference**: Based on the official Tunix GRPO demo pattern

## Cell 1: Installation and Setup

In [None]:
# Install Tunix and dependencies
!pip install -q git+https://github.com/google/tunix.git
!pip install -q kagglehub

import os
os.environ["KERAS_BACKEND"] = "jax"

# Core imports
import json
import re
import functools
from dataclasses import dataclass
from typing import Dict, List, Any, Optional

# JAX and numerical computing
import jax
import jax.numpy as jnp
import numpy as np

# Keras and model loading
import keras
import keras_nlp

# Tunix GRPO
from tunix.grpo import grpo

# Visualization
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Confirm the environment is usable and show which accelerators JAX can see.
print("Installation complete")
for _label, _value in (("JAX devices", jax.devices()), ("JAX device count", jax.device_count())):
    print(f"{_label}: {_value}")

## Cell 2: Load Gemma2-2B Model

In [None]:
# Load the instruction-tuned Gemma2 2B preset through KerasNLP.
# The Gemma license must be accepted on Kaggle before the preset can be fetched.
model_id = "gemma2_instruct_2b_en"
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(model_id)

num_params = gemma_lm.count_params()
print(f"Model loaded: {model_id}")
print(f"Model parameters: {num_params:,}")

## Cell 3: Configuration

In [None]:
@dataclass
class GRPOConfig:
    """Hyper-parameters for GRPO sampling, scoring and checkpointing.

    The field order below defines the positional constructor, so keep it
    stable when adding fields.
    """

    # --- GRPO group sampling ---
    # Responses sampled per prompt (the group size G in GRPO).
    num_generations: int = 4
    # Token budget for the prompt.
    max_prompt_length: int = 256
    # Token budget for each generated response.
    max_response_length: int = 512

    # --- Optimisation ---
    # NOTE(review): learning_rate, kl_coef and clip_range are recorded in the
    # run summary, but no weight update in this notebook consumes them yet.
    learning_rate: float = 1e-6
    num_training_steps: int = 500
    batch_size: int = 4
    # Weight of the KL-divergence penalty.
    kl_coef: float = 0.1
    # PPO-style ratio clipping range.
    clip_range: float = 0.2

    # --- Decoding ---
    temperature: float = 0.9
    top_k: int = 40
    top_p: float = 0.95

    # --- Reward mixing (these three weights sum to 1.0) ---
    format_reward_weight: float = 0.3
    correctness_reward_weight: float = 0.5
    reasoning_quality_weight: float = 0.2

    # --- Checkpointing ---
    checkpoint_dir: str = "/kaggle/working/checkpoints"
    save_every_n_steps: int = 100

config = GRPOConfig()

# Echo the settings most relevant to a run.
print("GRPO Configuration:")
for _label, _value in (
    ("Generations per prompt", config.num_generations),
    ("Learning rate", config.learning_rate),
    ("Training steps", config.num_training_steps),
    ("Batch size", config.batch_size),
    ("KL coefficient", config.kl_coef),
):
    print(f"  {_label}: {_value}")

## Cell 4: Load Training Data

In [None]:
# Load reasoning training data
# Expected format: [{"question": ..., "answer": ..., "type": ..., "difficulty": ...}, ...]

DATA_PATH = "/kaggle/input/reasoning-training-data/reasoning_training_data.json"
# Fallback location, relative to the notebook's working directory.
LOCAL_DATA_PATH = "reasoning_training_data.json"

# Tiny built-in dataset, used only when neither JSON file exists.
SAMPLE_TRAINING_DATA = [
    {
        "question": "What is 15% of 240?",
        "answer": "36",
        "type": "math",
        "difficulty": "easy"
    },
    {
        "question": "If a train travels at 60 mph for 2.5 hours, how far does it go?",
        "answer": "150 miles",
        "type": "math",
        "difficulty": "easy"
    },
    {
        "question": "A store offers 20% off. If an item costs $80, what is the sale price?",
        "answer": "$64",
        "type": "math",
        "difficulty": "easy"
    }
]


def load_training_data() -> List[Dict]:
    """Return training examples from the first data file that exists.

    Tries ``DATA_PATH`` (Kaggle input) and then ``LOCAL_DATA_PATH``.
    If neither exists, returns a copy of ``SAMPLE_TRAINING_DATA``.
    Errors other than a missing file, such as invalid JSON, propagate.

    Raises:
        ValueError: if a found file does not hold a non-empty JSON list.
            Failing here is clearer than an IndexError further down.
    """
    for path, source in ((DATA_PATH, "Kaggle input"), (LOCAL_DATA_PATH, "local file")):
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            continue
        if not isinstance(data, list) or not data:
            raise ValueError(f"{path} must contain a non-empty JSON list of examples")
        print(f"Loaded {len(data)} examples from {source}")
        return data
    print(f"Using {len(SAMPLE_TRAINING_DATA)} sample examples for demonstration")
    # Copy so later mutation of training_data cannot alter the constant.
    return [dict(example) for example in SAMPLE_TRAINING_DATA]


training_data = load_training_data()

# Display sample
print("\nSample training example:")
print(f"  Question: {training_data[0]['question']}")
print(f"  Answer: {training_data[0]['answer']}")
print(f"  Type: {training_data[0].get('type', 'unknown')}")

## Cell 5: Prompt Template

In [None]:
REASONING_PROMPT_TEMPLATE = """You are a helpful assistant that shows your reasoning step by step.

Instructions:
1. Think through the problem carefully
2. Show your reasoning in <reasoning> tags
3. Put your final answer in <answer> tags

Question: {question}

Response:"""

def format_prompt(question: str) -> str:
    """Embed ``question`` in the reasoning instruction template.

    Args:
        question: raw question text. It is inserted verbatim, and braces
            inside it are not re-interpreted.

    Returns:
        The complete prompt, ending in ``"Response:"``.
    """
    prompt = REASONING_PROMPT_TEMPLATE.format_map({"question": question})
    return prompt

# Show the full prompt built for the first training question.
_first_question = training_data[0]['question']
test_prompt = format_prompt(_first_question)
print("Formatted prompt example:", test_prompt, sep="\n")

## Cell 6: Reward Functions

In [None]:
# Tag patterns, compiled once and shared by parsing and the format check.
_REASONING_TAG = re.compile(r'<reasoning>(.*?)</reasoning>', re.DOTALL)
_ANSWER_TAG = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)

# One number token: an optional minus sign, then digits with optional
# thousands separators ("1,500"), then an optional decimal part. The
# look-behind stops the hyphen in a range such as "10-15" from being read
# as a minus sign.
_NUMBER = re.compile(r'(?<![\d.])-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?')

# Word-bounded so that "so" does not fire inside "also"/"some" and "add"
# does not fire inside "address".
_STEP_PATTERNS = [
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r'\bfirst\b', r'\bthen\b', r'\bnext\b', r'\bfinally\b',
        r'\bsteps?\b', r'\btherefore\b', r'\bso\b', r'\bbecause\b',
    )
]
_MATH_PATTERNS = [
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r'\d+\s*[+\-*/]\s*\d+',
        r'=',
        r'\bmultipl(?:y|ies|ied|ying|ication)\b',
        r'\bdivid(?:e|es|ed|ing)\b|\bdivision\b',
        r'\badd(?:s|ed|ing|ition)?\b',
        r'\bsubtract(?:s|ed|ing|ion)?\b',
    )
]
# '.', '!' or '?' followed by whitespace or the end of the text. This way
# decimal points such as the one in "0.15" do not count as sentence ends.
_SENTENCE_END = re.compile(r'[.!?](?=\s|$)')


def extract_reasoning_and_answer(response: str) -> tuple:
    """Return ``(reasoning, answer)`` extracted from a model response.

    Each element is the stripped content of the first ``<reasoning>`` or
    ``<answer>`` tag pair, or ``""`` when that pair is missing.
    """
    reasoning_match = _REASONING_TAG.search(response)
    answer_match = _ANSWER_TAG.search(response)

    reasoning = reasoning_match.group(1).strip() if reasoning_match else ""
    answer = answer_match.group(1).strip() if answer_match else ""

    return reasoning, answer


def format_reward(response: str) -> float:
    """Reward the use of the output tags.

    Returns 1.0 when both tag pairs are present, 0.5 when exactly one is
    present and 0.0 when neither is. Empty tags still count as present.
    """
    has_reasoning = bool(_REASONING_TAG.search(response))
    has_answer = bool(_ANSWER_TAG.search(response))

    if has_reasoning and has_answer:
        return 1.0
    elif has_reasoning or has_answer:
        return 0.5
    else:
        return 0.0


def _extract_numbers(text: str) -> List[float]:
    """Return the numbers in ``text`` in order of appearance ("1,500" -> 1500.0)."""
    return [float(token.replace(',', '')) for token in _NUMBER.findall(text)]


def correctness_reward(response: str, expected_answer: str) -> float:
    """Score the ``<answer>`` of ``response`` against ``expected_answer``.

    Returns:
        1.0 for a case-insensitive exact match. Also 1.0 for a numeric match:
            the last ``k`` numbers in the answer equal, within 0.01, the ``k``
            numbers in the expected answer. Units and symbols are ignored, so
            "$64.00" matches "$64" and "150 miles" matches "150".
        0.7 when no numeric comparison applies and the expected answer appears
            in the extracted answer as a whole word or phrase.
        0.0 otherwise, including a missing or empty ``<answer>`` and an
            empty expected answer.

    When both answers contain numbers, the numeric comparison is final. A
    wrong number therefore scores 0.0 even if its digits contain the expected
    ones ("360" vs "36"). Digits are never concatenated across separate
    numbers ("3 and 4" does not match "34").
    """
    _, extracted_answer = extract_reasoning_and_answer(response)

    if not extracted_answer:
        return 0.0

    extracted_normalized = extracted_answer.lower().strip()
    expected_normalized = expected_answer.lower().strip()
    if not expected_normalized:
        return 0.0

    # Exact match
    if extracted_normalized == expected_normalized:
        return 1.0

    # Numerical match
    expected_numbers = _extract_numbers(expected_normalized)
    extracted_numbers = _extract_numbers(extracted_normalized)
    if expected_numbers and extracted_numbers:
        candidates = extracted_numbers[-len(expected_numbers):]
        if len(candidates) == len(expected_numbers) and all(
            abs(got - want) < 0.01 for got, want in zip(candidates, expected_numbers)
        ):
            return 1.0
        return 0.0

    # Partial match: the expected text must stand as whole word(s)
    phrase = r'(?<!\w)' + re.escape(expected_normalized) + r'(?!\w)'
    if re.search(phrase, extracted_normalized):
        return 0.7

    return 0.0


def reasoning_quality_reward(response: str) -> float:
    """Compute a heuristic 0..1 score for the ``<reasoning>`` section of ``response``.

    The components below are summed and the total is capped at 1.0:
      * length: +0.3 for 10-200 words; +0.1 for 6-9 or more than 200 words;
      * step words ("first", "then", "therefore", ...): +0.1 each, max 0.3;
      * arithmetic ("3 * 4", "=", "multiply", ...): +0.1 each, max 0.2;
      * at least one sentence-ending '.', '!' or '?': +0.2.

    Returns 0.0 when the reasoning section is missing or empty.
    """
    reasoning, _ = extract_reasoning_and_answer(response)

    if not reasoning:
        return 0.0

    score = 0.0

    word_count = len(reasoning.split())
    if 10 <= word_count <= 200:
        score += 0.3
    elif word_count > 5:
        score += 0.1

    step_matches = sum(1 for pattern in _STEP_PATTERNS if pattern.search(reasoning))
    score += min(0.3, step_matches * 0.1)

    math_matches = sum(1 for pattern in _MATH_PATTERNS if pattern.search(reasoning))
    score += min(0.2, math_matches * 0.1)

    if _SENTENCE_END.search(reasoning):
        score += 0.2

    return min(1.0, score)


def compute_reward(response: str, expected_answer: str, config: "GRPOConfig") -> float:
    """Combine format, correctness and reasoning-quality rewards.

    The result is the weighted sum using ``config.format_reward_weight``,
    ``config.correctness_reward_weight`` and ``config.reasoning_quality_weight``.
    """
    fmt_reward = format_reward(response)
    corr_reward = correctness_reward(response, expected_answer)
    qual_reward = reasoning_quality_reward(response)

    total = (
        config.format_reward_weight * fmt_reward +
        config.correctness_reward_weight * corr_reward +
        config.reasoning_quality_weight * qual_reward
    )

    return total

# Smoke-test the reward functions on a hand-written, well-formed response.
test_response = """<reasoning>
To find 15% of 240, I need to convert 15% to a decimal first.
15% = 0.15
Then multiply: 0.15 * 240 = 36
</reasoning>
<answer>36</answer>"""

print("Reward function test:")
_reward_rows = [
    ("Format reward", format_reward(test_response)),
    ("Correctness reward", correctness_reward(test_response, '36')),
    ("Reasoning quality reward", reasoning_quality_reward(test_response)),
    ("Total reward", compute_reward(test_response, '36', config)),
]
for _label, _value in _reward_rows:
    print(f"  {_label}: {_value:.2f}")

## Cell 7: GRPO Training Loop

In [None]:
class GRPOTrainer:
    """Sample-and-score loop for GRPO-style training of a reasoning model.

    Each step samples a batch of questions and generates
    ``config.num_generations`` responses per question. Every response is
    scored with ``compute_reward``, and each group is summarised with
    group-relative advantages and a surrogate loss.

    NOTE(review): nothing here computes gradients or applies an optimizer.
    ``train()`` therefore never changes the model's weights, and
    ``learning_rate``, ``kl_coef`` and ``clip_range`` are never read. A real
    GRPO update needs each sampled response's log-probability under the
    current policy, plus a frozen reference model for the KL term.
    """

    def __init__(self, model, config: "GRPOConfig", training_data: List[Dict]):
        """Create a trainer.

        Args:
            model: causal LM exposing ``generate(prompt, max_length=...)`` and
                ``compile(sampler=...)`` (a keras_nlp ``GemmaCausalLM`` here).
            config: hyper-parameters (``GRPOConfig``).
            training_data: list of dicts with at least ``question`` and
                ``answer`` keys.

        Side effects: creates ``config.checkpoint_dir``. It also calls
        ``model.compile(sampler=...)`` with a top-p/top-k sampler built from
        the config, which changes decoding for every later ``model.generate``
        call, including calls made outside this trainer.
        """
        self.model = model
        self.config = config
        self.training_data = training_data

        # Per-step history. The lists grow together, one entry per step.
        self.metrics = {
            'step': [],
            'mean_reward': [],
            'format_accuracy': [],
            'loss': []
        }

        os.makedirs(config.checkpoint_dir, exist_ok=True)
        self._configure_sampler()

    def _configure_sampler(self):
        """Make ``self.model.generate`` sample with the config's decoding settings.

        GRPO compares responses within a group. With deterministic decoding,
        all ``num_generations`` responses would be identical, every advantage
        would be zero and the group would carry no signal.
        """
        sampler = keras_nlp.samplers.TopPSampler(
            p=self.config.top_p,
            k=self.config.top_k,
            temperature=self.config.temperature,
        )
        self.model.compile(sampler=sampler)

    def sample_batch(self) -> List[Dict]:
        """Sample a batch of distinct training examples.

        The batch holds ``min(config.batch_size, len(training_data))``
        examples. A dataset smaller than ``batch_size`` therefore yields a
        smaller batch instead of making ``np.random.choice`` raise.

        Raises:
            ValueError: if ``training_data`` is empty.
        """
        if not self.training_data:
            raise ValueError("training_data is empty; cannot sample a batch")
        batch_size = min(self.config.batch_size, len(self.training_data))
        indices = np.random.choice(len(self.training_data), batch_size, replace=False)
        return [self.training_data[i] for i in indices]

    def generate_responses(self, prompt: str) -> List[str]:
        """Generate ``config.num_generations`` continuations of ``prompt``.

        Returns the generated text only. The prompt is stripped when the
        model echoes it at the start of its output.
        """
        # keras_nlp's ``max_length`` caps prompt + generated tokens together,
        # so reserve room for both budgets, not only the response.
        max_length = self.config.max_prompt_length + self.config.max_response_length
        responses = []
        for _ in range(self.config.num_generations):
            response = self.model.generate(prompt, max_length=max_length)
            generated = response[len(prompt):] if response.startswith(prompt) else response
            responses.append(generated)
        return responses

    def compute_advantages(self, rewards: List[float]) -> np.ndarray:
        """Return group-relative advantages ``(r - mean(r)) / (std(r) + 1e-8)``.

        The epsilon keeps the division finite when all rewards in the group
        are equal. In that case every advantage is 0.
        """
        rewards_array = np.asarray(rewards, dtype=float)
        return (rewards_array - rewards_array.mean()) / (rewards_array.std() + 1e-8)

    def compute_grpo_loss(self, responses: List[str], rewards: List[float]) -> float:
        """Return a reward-only surrogate of the GRPO loss for one group.

        Computes ``-mean(A_i * r_i)``, with ``A_i`` from ``compute_advantages``.
        Algebraically this equals ``-std(r)**2 / (std(r) + 1e-8)``, roughly the
        negative spread of rewards within the group. It does not depend on the
        policy, so it cannot drive a weight update. The GRPO objective instead
        weights each response's log-probability by ``A_i``.

        Args:
            responses: the group's responses (unused; kept for the interface).
            rewards: one reward per response.

        Returns:
            The surrogate as a Python float; 0.0 for an empty group.
        """
        if not rewards:
            return 0.0
        rewards_array = np.asarray(rewards, dtype=float)
        advantages = self.compute_advantages(rewards_array)
        return float(-np.mean(advantages * rewards_array))

    def training_step(self, step: int) -> Dict[str, float]:
        """Run one sample-and-score step and return its metrics.

        Args:
            step: global step index (currently unused; kept for callers).

        Returns:
            Dict with three keys:
            - ``mean_reward``: mean over every response in the batch.
            - ``format_accuracy``: fraction of responses with both tag pairs.
            - ``loss``: mean per-group value of ``compute_grpo_loss``.
        """
        batch = self.sample_batch()

        all_rewards = []
        format_correct = 0
        total_loss = 0.0

        for example in batch:
            prompt = format_prompt(example['question'])
            expected_answer = example['answer']

            responses = self.generate_responses(prompt)
            rewards = [compute_reward(r, expected_answer, self.config) for r in responses]
            all_rewards.extend(rewards)

            format_correct += sum(1 for r in responses if format_reward(r) == 1.0)
            total_loss += self.compute_grpo_loss(responses, rewards)

        return {
            'mean_reward': float(np.mean(all_rewards)),
            'format_accuracy': format_correct / len(all_rewards),
            'loss': total_loss / len(batch)
        }

    def train(self, num_steps: Optional[int] = None):
        """Run the sample-and-score loop, then checkpoint metrics and plot them.

        Args:
            num_steps: steps to run. ``None`` means ``config.num_training_steps``,
                and ``0`` runs no steps.

        Step indices continue from any previous ``train()`` call. Repeated
        calls therefore neither log duplicate steps nor overwrite earlier
        ``metrics_step_*.json`` files.
        """
        if num_steps is None:
            num_steps = self.config.num_training_steps
        start_step = len(self.metrics['step'])
        end_step = start_step + num_steps

        print("="*60)
        print("Starting GRPO Training")
        print("="*60)
        print(f"Training steps: {num_steps}")
        print(f"Batch size: {self.config.batch_size}")
        print(f"Generations per prompt: {self.config.num_generations}")
        print("="*60 + "\n")

        for step in tqdm(range(start_step, end_step), desc="Training"):
            step_metrics = self.training_step(step)

            self.metrics['step'].append(step)
            for key in ('mean_reward', 'format_accuracy', 'loss'):
                self.metrics[key].append(step_metrics[key])

            # tqdm.write keeps the progress bar intact while logging.
            if step % 10 == 0:
                tqdm.write(f"Step {step}: reward={step_metrics['mean_reward']:.3f}, "
                           f"format_acc={step_metrics['format_accuracy']:.2%}, "
                           f"loss={step_metrics['loss']:.4f}")

            if step > 0 and step % self.config.save_every_n_steps == 0:
                self.save_checkpoint(step)

        self.save_checkpoint(end_step)
        self.plot_metrics()

        print("\n" + "="*60)
        print("Training Complete")
        print("="*60)
        print(f"Final mean reward: {np.mean(self.metrics['mean_reward'][-50:]):.3f}")
        print(f"Final format accuracy: {np.mean(self.metrics['format_accuracy'][-50:]):.2%}")

    def save_checkpoint(self, step: int):
        """Write the metrics collected so far to ``metrics_step_{step}.json``.

        NOTE(review): only metrics are persisted. Model weights are not saved;
        this loop never modifies them.
        """
        metrics_path = os.path.join(self.config.checkpoint_dir, f"metrics_step_{step}.json")
        with open(metrics_path, 'w', encoding='utf-8') as f:
            json.dump(self.metrics, f, indent=2)
        print(f"Saved metrics checkpoint at step {step}")

    def plot_metrics(self):
        """Plot reward, format accuracy and loss against step, then save a PNG."""
        panels = (
            ('mean_reward', 'Mean Reward', 'Training Reward'),
            ('format_accuracy', 'Format Accuracy', 'Format Compliance'),
            ('loss', 'Loss', 'GRPO Loss'),
        )
        _, axes = plt.subplots(1, len(panels), figsize=(15, 4))

        for ax, (key, ylabel, title) in zip(axes, panels):
            ax.plot(self.metrics['step'], self.metrics[key])
            ax.set_xlabel('Step')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.grid(True)

        plt.tight_layout()
        plt.savefig(os.path.join(self.config.checkpoint_dir, 'training_curves.png'), dpi=150)
        plt.show()
        print(f"Saved training curves to {self.config.checkpoint_dir}/training_curves.png")

print("GRPOTrainer class defined")

## Cell 8: Initialize and Run Training

In [None]:
# Build the trainer from the model, config and data defined above.
trainer = GRPOTrainer(model=gemma_lm, config=config, training_data=training_data)

for _line in (
    "Trainer initialized",
    f"Training data size: {len(training_data)}",
    f"Checkpoint directory: {config.checkpoint_dir}",
):
    print(_line)

In [None]:
# Run training. DEMO_MODE trades run length for a quick end-to-end check.
DEMO_MODE = True
if DEMO_MODE:
    num_steps = 50
else:
    num_steps = config.num_training_steps

print(f"Running training for {num_steps} steps")
print("For full training, set DEMO_MODE = False\n")

trainer.train(num_steps=num_steps)

## Cell 9: Inference Demo

In [None]:
def inference(model, question: str, max_length: int = 512) -> str:
    """Generate a reasoning response for a single question.

    Args:
        model: causal LM exposing ``generate(prompt, max_length=...)``.
        question: raw question text; it is wrapped with ``format_prompt``.
        max_length: passed straight to ``model.generate``. The default, 512,
            matches the previously hard-coded value. For keras_nlp causal LMs
            it caps prompt + generated tokens together.

    Returns:
        The generated text, with the prompt removed when the model echoes it.
    """
    prompt = format_prompt(question)
    response = model.generate(prompt, max_length=max_length)

    # generate() may echo the prompt; keep only the continuation.
    if response.startswith(prompt):
        return response[len(prompt):]
    return response

# Hand-picked questions for eyeballing the output format after training.
test_questions = [
    "What is 25% of 180?",
    "If a car travels at 55 mph for 3 hours, how far does it go?",
    "A shirt costs $45 and is on sale for 30% off. What is the sale price?"
]

print("Inference Demo")
print("="*60)

for question in test_questions:
    print(f"\nQuestion: {question}")
    response = inference(gemma_lm, question)
    print(f"Response:\n{response}")

    # Show what the reward functions would parse out of this response.
    reasoning, answer = extract_reasoning_and_answer(response)
    preview = f"{reasoning[:100]}..." if len(reasoning) > 100 else reasoning
    print("\nExtracted:")
    print(f"  Reasoning: {preview}")
    print(f"  Answer: {answer}")
    print("-"*60)

## Cell 10: Save Final Model and Metrics

In [None]:
# Summarise the run over the last 50 logged steps and save it for submission.
submission_dir = "/kaggle/working/submission"
os.makedirs(submission_dir, exist_ok=True)

_recent = {key: trainer.metrics[key][-50:] for key in ('mean_reward', 'format_accuracy')}
final_metrics = {
    'final_mean_reward': float(np.mean(_recent['mean_reward'])),
    'final_format_accuracy': float(np.mean(_recent['format_accuracy'])),
    'total_steps': len(trainer.metrics['step']),
    'config': {
        name: getattr(config, name)
        for name in ('num_generations', 'learning_rate', 'batch_size', 'kl_coef')
    },
}

metrics_path = os.path.join(submission_dir, 'final_metrics.json')
with open(metrics_path, 'w') as f:
    json.dump(final_metrics, f, indent=2)

print("Final Results")
print("="*60)
print(f"Mean Reward: {final_metrics['final_mean_reward']:.3f}")
print(f"Format Accuracy: {final_metrics['final_format_accuracy']:.2%}")
print(f"Total Steps: {final_metrics['total_steps']}")
print(f"\nResults saved to: {submission_dir}")

## Cell 11: Summary

### What This Notebook Implements

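1. **GRPO Training**: Group Relative Policy Optimization, following the Tunix GRPO demo pattern
   - Generates multiple responses per prompt
   - Computes rewards and advantages relative to the group
   - No value function required (unlike PPO)
   - Note: the `GRPOTrainer` loop in this notebook samples and scores responses and computes group-relative advantages. It does not yet back-propagate a loss or update the model weights.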

2. **Custom Reward Functions**:
   - Format reward: Correct use of `<reasoning>` and `<answer>` tags
   - Correctness reward: Answer matches expected value
   - Reasoning quality: Step-by-step explanations

3. **Output Format**: `<reasoning>...</reasoning><answer>...</answer>`

### For Production Training

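1. Upload `reasoning_training_data.json` as a Kaggle dataset named `reasoning-training-data`, so that it is available at `/kaggle/input/reasoning-training-data/reasoning_training_data.json`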
2. Set `DEMO_MODE = False` for full training
3. Enable TPU v3-8 accelerator
4. Run all cells sequentially

### Checkpoint Location
`/kaggle/working/checkpoints/` (per-step training metrics as JSON plus `training_curves.png`; model weights are not written here)

### Submission Artifacts
`/kaggle/working/submission/`