# 🎯 Google Tunix Hack - Train a Model to Show Its Work

## Training Gemma 3 1B with GRPO for Chain-of-Thought Reasoning

**Author:** Emrullah Aydogan  
**Competition:** [Google Tunix Hack](https://www.kaggle.com/competitions/google-tunix-hackathon)  
**Objective:** Train Gemma to show step-by-step reasoning on math problems

---

### 📊 Approach Summary

**Model:** Gemma 3 1B (32K context, efficient)  
**Algorithm:** GRPO (Group Relative Policy Optimization)  
**Dataset:** GSM8K (about 8.5K grade-school math problems: 7,473 train, 1,319 test)  
**Reward Function:**
- 50% Correctness - Is the final answer right?
- 30% Reasoning Quality - Are the steps logical?
- 20% Clarity - Is the explanation clear?
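
Written as a formula, with $c \in \{0, 1\}$ the correctness indicator and $q, k \in [0, 1]$ the reasoning and clarity scores (these symbols are introduced here only for illustration), the weighting above amounts to:

$$
R = 0.5\,c + 0.3\,q + 0.2\,k, \qquad 0 \le R \le 1
$$

Under this weighting a wrong answer can earn at most 0.5, and a correct answer earns at least 0.5.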

---

### ⚙️ Kaggle Setup Notes

**Required Settings:**
- ✅ Accelerator: TPU VM v2-8 (or GPU T4)
- ✅ Internet: ON (for downloading model)
- ✅ Persistence: ON (optional, for checkpoints)

**This notebook is STANDALONE** - all code is included, and the only external dependencies are the pip packages installed in Section 1.

---
## 1️⃣ Installation & Setup

In [None]:
%%time
# Install dependencies
print("📦 Installing packages...")

# Quote the extras so the shell does not treat the brackets as a glob pattern
!pip install -q "google-tunix[prod]" datasets transformers sentencepiece
!pip install -q "jax[tpu]" jaxlib flax optax chex
!pip install -q wandb pyyaml tqdm matplotlib seaborn

print("✅ Installation complete!")

In [None]:
# Imports
import os
import re
import json
import random
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')

# Check Tunix
try:
    import tunix
    # Fall back gracefully if the package does not expose __version__
    print(f"✅ Tunix version: {getattr(tunix, '__version__', 'unknown')}")
    TUNIX_AVAILABLE = True
except ImportError:
    print("⚠️ Tunix not available - will show placeholder implementation")
    TUNIX_AVAILABLE = False

# Check device
print(f"\n🖥️ JAX Backend: {jax.default_backend()}")
print(f"   Devices: {jax.devices()}")
print(f"   Device count: {jax.device_count()}")

# Set random seeds
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

print("\n✅ Setup complete!")

---
## 2️⃣ Configuration

In [None]:
# Model & Training Configuration
CONFIG = {
    # Model
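    'model_name': 'google/gemma-3-1b',  # Verify this Hub ID before running; Gemma 3 checkpoints are usually published with a -pt or -it suffix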
    'use_lora': True,
    'lora_rank': 16,
    'lora_alpha': 32,
    'lora_dropout': 0.1,
    
    # Training
    'algorithm': 'GRPO',  # Group Relative Policy Optimization
    'num_epochs': 3,
    'batch_size': 8,
    'learning_rate': 1e-5,
    'warmup_steps': 100,
    'max_grad_norm': 1.0,
    'gradient_accumulation_steps': 4,
    
    # RL Parameters
    'num_rollouts': 4,
    'temperature': 0.7,
    'top_p': 0.9,
    'max_new_tokens': 512,
    
    # Reward Weights
    'correctness_weight': 0.5,
    'reasoning_weight': 0.3,
    'clarity_weight': 0.2,
    
    # Data
    'val_ratio': 0.1,
    'max_train_samples': None,  # None = use all
    'max_eval_samples': 500,
    
    # Logging
    'use_wandb': False,  # Set True to enable W&B
    'wandb_project': 'google-tunix-hack',
    'experiment_name': 'gemma3-1b-grpo-gsm8k',
    'log_steps': 50,
    'eval_steps': 250,
    'save_steps': 500,
}

print("⚙️ Configuration:")
print(json.dumps(CONFIG, indent=2))
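
As a rough guide to how these settings interact, the sketch below derives a few quantities from the values above. It assumes that `batch_size` counts prompts per micro-batch and that each prompt is sampled `num_rollouts` times; the real semantics depend on the trainer, so the names `prompts_per_update`, `completions_per_update`, and `max_generated_tokens` are illustrative only.

```python
# Values copied from CONFIG above; how a trainer interprets them is an assumption.
batch_size = 8
gradient_accumulation_steps = 4
num_rollouts = 4
max_new_tokens = 512

# Prompts contributing to one optimizer update.
prompts_per_update = batch_size * gradient_accumulation_steps

# Sampled completions that the reward function must score per update.
completions_per_update = prompts_per_update * num_rollouts

# Upper bound on generated tokens per update.
max_generated_tokens = completions_per_update * max_new_tokens

print(prompts_per_update, completions_per_update, max_generated_tokens)
```

Numbers like these are useful for estimating memory and wall-clock cost before committing to a full run.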

---
## 3️⃣ Data Processing Functions

In [None]:
# Data preprocessing utilities

def extract_answer(answer_text: str) -> str:
    """Extract final numerical answer from GSM8K answer text."""
    match = re.search(r'####\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)', answer_text)
    if match:
        return match.group(1).replace(',', '')
    
    # Fallback: find last number
    numbers = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', answer_text)
    return numbers[-1].replace(',', '') if numbers else ""


def extract_reasoning(answer_text: str) -> str:
    """Extract reasoning steps from GSM8K answer."""
    return re.split(r'####', answer_text)[0].strip()


def format_cot_example(question: str, reasoning: str, answer: str) -> Dict[str, str]:
    """
    Format as chain-of-thought training example.
    
    Returns:
        Dictionary with 'input', 'target', 'question', 'answer'
    """
    # Prompt
    input_text = f"""Question: {question}

Let's solve this step by step:
"""
    
    # Target with formatted steps
    reasoning_lines = [line.strip() for line in reasoning.split('\n') if line.strip()]
    formatted_steps = [f"Step {i}: {line}" for i, line in enumerate(reasoning_lines, 1)]
    
    target_text = "\n".join(formatted_steps) + f"\n\nAnswer: {answer}"
    
    return {
        'input': input_text,
        'target': target_text,
        'question': question,
        'answer': answer
    }


def preprocess_gsm8k(dataset_split, max_samples: Optional[int] = None) -> List[Dict]:
    """Preprocess GSM8K dataset split."""
    processed = []
    
    for example in tqdm(dataset_split, desc="Preprocessing"):
        reasoning = extract_reasoning(example['answer'])
        answer = extract_answer(example['answer'])
        formatted = format_cot_example(example['question'], reasoning, answer)
        processed.append(formatted)
        
        if max_samples and len(processed) >= max_samples:
            break
    
    return processed


print("✅ Data processing functions defined")
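
GSM8K solutions embed calculator annotations such as `<<48/2=24>>` in the reasoning text, and `format_cot_example` passes them through into the target unchanged. If you would rather not have them there, a minimal sketch of a cleanup step follows; `strip_calc_annotations` is a hypothetical helper, not part of the functions above.

```python
import re

def strip_calc_annotations(reasoning: str) -> str:
    """Remove GSM8K-style <<expr=result>> calculator annotations."""
    return re.sub(r'<<[^<>]*>>', '', reasoning)

# Example line in the style of a GSM8K solution.
sample = "Natalia sold 48/2 = <<48/2=24>>24 clips in May."
cleaned = strip_calc_annotations(sample)
print(cleaned)
```

One way to use it would be to call it on the output of `extract_reasoning` before formatting the steps.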

---
## 4️⃣ Reward Function (CRITICAL!)

In [None]:
# Reward function for RL training

def extract_model_answer(response: str) -> str:
    """Extract final answer from model response."""
    match = re.search(r'Answer:\s*([^\n]+)', response, re.IGNORECASE)
    if match:
        answer = match.group(1).strip()
        number_match = re.search(r'-?\d+(?:,\d{3})*(?:\.\d+)?', answer)
        if number_match:
            return number_match.group(0).replace(',', '')
    
    # Fallback
    numbers = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', response)
    return numbers[-1].replace(',', '') if numbers else ""


def extract_reasoning_steps(response: str) -> List[str]:
    """Extract reasoning steps from model response."""
    step_pattern = r'Step \d+:(.+?)(?=Step \d+:|Answer:|$)'
    steps = re.findall(step_pattern, response, re.DOTALL | re.IGNORECASE)
    return [step.strip() for step in steps if step.strip()]


def check_correctness(predicted: str, ground_truth: str, tolerance: float = 1e-4) -> bool:
    """Check if predicted answer is correct."""
    try:
        pred_num = float(predicted.replace(',', ''))
        truth_num = float(ground_truth.replace(',', ''))
        return abs(pred_num - truth_num) < tolerance
    except ValueError:
        # Non-numeric answers fall back to a case-insensitive string comparison
        return predicted.strip().lower() == ground_truth.strip().lower()


def score_reasoning_quality(steps: List[str]) -> float:
    """
    Score reasoning quality (0-1).
    
    Criteria:
    - Number of steps (ideal: 2-8)
    - Step length (ideal: 20-150 chars)
    - Contains calculations
    - Step completeness
    """
    if not steps:
        return 0.0
    
    score = 0.0
    num_steps = len(steps)
    
    # 1. Number of steps (25%)
    if 2 <= num_steps <= 8:
        score += 0.25
    elif num_steps == 1:
        score += 0.1
    elif num_steps > 8:
        score += max(0.25 - 0.02 * (num_steps - 8), 0.1)
    
    # 2. Step length (25%)
    avg_length = sum(len(s) for s in steps) / num_steps
    if 20 <= avg_length <= 150:
        score += 0.25
    elif avg_length < 20:
        score += 0.1
    else:
        score += max(0.25 - 0.001 * (avg_length - 150), 0.1)
    
    # 3. Contains calculations (25%)
    calc_count = sum(1 for s in steps if re.search(r'\d+\s*[+\-*/×÷]\s*\d+', s))
    score += min(0.25, calc_count * 0.1)
    
    # 4. Step completeness (25%)
    complete_steps = sum(1 for s in steps if len(s) > 15 and any(c.isdigit() for c in s))
    score += min(0.25, (complete_steps / max(num_steps, 1)) * 0.25)
    
    return min(score, 1.0)


def score_clarity(response: str, steps: List[str]) -> float:
    """
    Score clarity and formatting (0-1).
    
    Criteria:
    - Has step markers
    - Has answer marker
    - Proper punctuation
    - Not repetitive
    """
    score = 0.0
    
    # Step markers (30%)
    if re.search(r'Step \d+:', response, re.IGNORECASE):
        score += 0.3
    
    # Answer marker (30%)
    if re.search(r'Answer:', response, re.IGNORECASE):
        score += 0.3
    
    # Punctuation (20%)
    if any(char in response for char in '.!?'):
        score += 0.2
    
    # Not repetitive (20%)
    if len(steps) > 0:
        unique_ratio = len(set(steps)) / len(steps)
        score += unique_ratio * 0.2
    
    return min(score, 1.0)


def compute_reward(
    response: str,
    ground_truth: str,
    question: str = "",
    correctness_weight: float = 0.5,
    reasoning_weight: float = 0.3,
    clarity_weight: float = 0.2
) -> Dict[str, float]:
    """
    Compute comprehensive reward for model response.
    
    This is the MAIN reward function used for RL training.
    
    Returns:
        Dictionary with all reward components
    """
    # Extract components
    predicted_answer = extract_model_answer(response)
    steps = extract_reasoning_steps(response)
    
    # Compute scores
    is_correct = check_correctness(predicted_answer, ground_truth)
    correctness_score = 1.0 if is_correct else 0.0
    reasoning_score = score_reasoning_quality(steps)
    clarity_score = score_clarity(response, steps)
    
    # Weighted total
    total_reward = (
        correctness_weight * correctness_score +
        reasoning_weight * reasoning_score +
        clarity_weight * clarity_score
    )
    
    return {
        'total_reward': total_reward,
        'correctness_score': correctness_score,
        'reasoning_score': reasoning_score,
        'clarity_score': clarity_score,
        'is_correct': is_correct,
        'num_steps': len(steps),
        'predicted_answer': predicted_answer,
        'ground_truth': ground_truth
    }


print("✅ Reward function defined")

# Test reward function
test_response = """Step 1: Janet's ducks lay 16 eggs per day
Step 2: She uses 3 + 4 = 7 eggs
Step 3: Remaining: 16 - 7 = 9 eggs
Step 4: Revenue: 9 × $2 = $18

Answer: 18"""

reward = compute_reward(test_response, "18")
print(f"\n🧪 Test Reward: {reward['total_reward']:.3f}")
print(f"   ├─ Correctness: {reward['correctness_score']:.3f}")
print(f"   ├─ Reasoning: {reward['reasoning_score']:.3f}")
print(f"   └─ Clarity: {reward['clarity_score']:.3f}")

---
## 5️⃣ Load & Prepare Dataset

In [None]:
%%time
# Load GSM8K from HuggingFace
print("📥 Loading GSM8K dataset...")
dataset = load_dataset("gsm8k", "main")

print("✅ Dataset loaded:")
print(f"   Train: {len(dataset['train'])} samples")
print(f"   Test: {len(dataset['test'])} samples")

# Show example
example = dataset['train'][0]
print("\n📝 Raw Example:")
print(f"Q: {example['question'][:100]}...")
print(f"A: {example['answer'][:100]}...")

In [None]:
%%time
# Preprocess data
print("🔄 Preprocessing dataset...\n")

train_data = preprocess_gsm8k(
    dataset['train'],
    max_samples=CONFIG['max_train_samples']
)

test_data = preprocess_gsm8k(
    dataset['test'],
    max_samples=CONFIG['max_eval_samples']
)

# Create validation split
random.shuffle(train_data)
val_size = int(len(train_data) * CONFIG['val_ratio'])
val_data = train_data[:val_size]
train_data = train_data[val_size:]

print("\n✅ Data prepared:")
print(f"   Train: {len(train_data)} examples")
print(f"   Validation: {len(val_data)} examples")
print(f"   Test: {len(test_data)} examples")

# Show formatted example
print("\n📝 Formatted Example:")
print("="*70)
print("[INPUT]")
print(train_data[0]['input'])
print("\n[TARGET]")
print(train_data[0]['target'][:200] + "...")
print("="*70)
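
Note that `random.shuffle` draws from the global generator, so re-running the cell above yields a different split unless the seed is reset first. A minimal sketch of a split that is reproducible on every call, using a local `random.Random` instance, follows; `split_train_val` is a hypothetical helper name.

```python
import random
from typing import List, Tuple, TypeVar

T = TypeVar("T")

def split_train_val(items: List[T], val_ratio: float, seed: int) -> Tuple[List[T], List[T]]:
    """Shuffle a copy with a local RNG and split off a validation set."""
    shuffled = list(items)                 # leave the caller's list untouched
    random.Random(seed).shuffle(shuffled)  # local RNG, independent of global state
    val_size = int(len(shuffled) * val_ratio)
    return shuffled[val_size:], shuffled[:val_size]

# Toy data standing in for the preprocessed examples.
train, val = split_train_val(list(range(100)), val_ratio=0.1, seed=42)
print(len(train), len(val))
```

Because the generator is created inside the function, the same seed always produces the same split regardless of how many times the cell is executed.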

---
## 6️⃣ Load Model & Tokenizer

In [None]:
%%time
# Load tokenizer
print(f"📥 Loading tokenizer: {CONFIG['model_name']}")

tokenizer = AutoTokenizer.from_pretrained(
    CONFIG['model_name'],
    trust_remote_code=True
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("✅ Tokenizer loaded")
print(f"   Vocab size: {len(tokenizer)}")
print(f"   Max length: {tokenizer.model_max_length}")

# Test tokenization
test_text = train_data[0]['input']
tokens = tokenizer(test_text, return_tensors='np')
print("\n🧪 Test tokenization:")
print(f"   Input: {len(test_text)} chars")
print(f"   Tokens: {len(tokens['input_ids'][0])}")

---
## 7️⃣ Training with Tunix GRPO

### ⚠️ IMPORTANT: Actual Tunix Implementation Needed

The cells below show the **intended structure** for Tunix training.  
You still need to implement the actual Tunix GRPO trainer based on the [Tunix documentation](https://github.com/google/tunix).

**Tunix Example Notebooks:**
- [GRPO on GSM8K](https://github.com/google/tunix/blob/main/examples/)
- [QLoRA Fine-tuning](https://github.com/google/tunix/blob/main/examples/)

**Key Integration Points:**
1. Use our `compute_reward()` function as the Tunix reward function
2. Use our preprocessed `train_data` and `val_data`
3. Configure LoRA with our CONFIG settings
4. Log metrics to W&B (optional)
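
For orientation, the core idea behind GRPO is that the rollouts sampled for one prompt are scored and then compared against each other, rather than against a learned value function. The sketch below shows one common form of that normalization in plain Python. It illustrates the idea only and is not Tunix code; `group_relative_advantages` and the `eps` constant are illustrative names.

```python
from statistics import mean, pstdev
from typing import List

def group_relative_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Normalize rewards within one prompt's group of rollouts."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population standard deviation over the group
    # eps keeps the division defined when all rewards are identical
    return [(r - mu) / (sigma + eps) for r in rewards]

# Four rollouts for one prompt (matching num_rollouts = 4), already scored.
rewards = [0.95, 0.50, 0.45, 0.10]
advantages = group_relative_advantages(rewards)
print([round(a, 3) for a in advantages])
```

If every rollout in a group receives the same reward, all advantages are zero and that prompt contributes no learning signal, which is one argument for a graded reward rather than a purely binary one.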

In [None]:
# Initialize W&B (optional)
if CONFIG['use_wandb']:
    try:
        import wandb
        wandb.init(
            project=CONFIG['wandb_project'],
            name=CONFIG['experiment_name'],
            config=CONFIG
        )
        print("✅ W&B initialized")
    except Exception as e:
        print(f"⚠️ W&B initialization failed: {e}")
else:
    print("ℹ️ W&B disabled")

In [None]:
# Training setup
print("🔧 Training Configuration:")
print(f"   Model: {CONFIG['model_name']}")
print(f"   Algorithm: {CONFIG['algorithm']}")
print(f"   Training samples: {len(train_data)}")
print(f"   Batch size: {CONFIG['batch_size']}")
print(f"   Learning rate: {CONFIG['learning_rate']}")
print(f"   Epochs: {CONFIG['num_epochs']}")
print(f"   LoRA: rank={CONFIG['lora_rank']}, alpha={CONFIG['lora_alpha']}")
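print("\n📊 Reward Weights:")
# Derive the percentage labels from CONFIG so they cannot drift out of sync
print(f"   Correctness: {CONFIG['correctness_weight']} ({CONFIG['correctness_weight']:.0%})")
print(f"   Reasoning: {CONFIG['reasoning_weight']} ({CONFIG['reasoning_weight']:.0%})")
print(f"   Clarity: {CONFIG['clarity_weight']} ({CONFIG['clarity_weight']:.0%})")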

In [None]:
# TUNIX TRAINING IMPLEMENTATION GOES HERE
# 
# Example structure (placeholder names, not the real Tunix API; replace with actual Tunix code):
#
# import tunix
#
# # Load model
# model = tunix.load_model(CONFIG['model_name'])
#
# # Configure LoRA
# lora_config = tunix.LoRAConfig(
#     rank=CONFIG['lora_rank'],
#     alpha=CONFIG['lora_alpha'],
#     dropout=CONFIG['lora_dropout']
# )
#
# # Create GRPO trainer
# trainer = tunix.GRPOTrainer(
#     model=model,
#     tokenizer=tokenizer,
#     train_dataset=train_data,
#     eval_dataset=val_data,
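#     # Pass only the weight keys; **CONFIG would raise TypeError on the other entries
#     reward_function=lambda response, gt: compute_reward(
#         response, gt,
#         correctness_weight=CONFIG['correctness_weight'],
#         reasoning_weight=CONFIG['reasoning_weight'],
#         clarity_weight=CONFIG['clarity_weight'],
#     )['total_reward'],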
#     lora_config=lora_config,
#     learning_rate=CONFIG['learning_rate'],
#     num_epochs=CONFIG['num_epochs'],
# )
#
# # Train
# trainer.train()

if TUNIX_AVAILABLE:
    print("🚀 Ready for Tunix training!")
    print("⚠️ Implement actual Tunix GRPO trainer above")
else:
    print("⚠️ Tunix not available")
    print("ℹ️ This notebook shows the structure for Tunix integration")
    print("📖 See: https://github.com/google/tunix for implementation details")

---
## 8️⃣ Evaluation & Results

After training, evaluate the model on the test set.

In [None]:
# Evaluation function
def evaluate_model(model, tokenizer, test_data, max_samples=100):
    """
    Evaluate model on test set.
    
    Returns metrics and predictions.
    """
    results = []
    rewards = []
    
    print(f"📊 Evaluating on {min(len(test_data), max_samples)} samples...\n")
    
    for i, example in enumerate(tqdm(test_data[:max_samples])):
        # Generate response (placeholder - replace with actual model inference)
        # response = model.generate(example['input'])
        response = "PLACEHOLDER - implement model inference"
        
        # Compute reward
        reward = compute_reward(
            response,
            example['answer'],
            example['question'],
            CONFIG['correctness_weight'],
            CONFIG['reasoning_weight'],
            CONFIG['clarity_weight']
        )
        
        results.append({
            'question': example['question'],
            'response': response,
            'ground_truth': example['answer'],
            **reward
        })
        rewards.append(reward)
    
    # Aggregate metrics
    metrics = {
        'accuracy': np.mean([r['is_correct'] for r in rewards]),
        'avg_reward': np.mean([r['total_reward'] for r in rewards]),
        'avg_reasoning_score': np.mean([r['reasoning_score'] for r in rewards]),
        'avg_clarity_score': np.mean([r['clarity_score'] for r in rewards]),
        'avg_num_steps': np.mean([r['num_steps'] for r in rewards]),
    }
    
    return metrics, results


print("✅ Evaluation function defined")
print("⚠️ Requires trained model to run")

---
## 9️⃣ Visualization

Visualize training and evaluation results.

In [None]:
# Visualization functions
def plot_training_curves(history):
    """Plot training curves."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Loss
    axes[0, 0].plot(history['loss'], label='Train Loss')
    axes[0, 0].set_title('Training Loss')
    axes[0, 0].set_xlabel('Step')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True)
    
    # Reward
    axes[0, 1].plot(history['reward'], label='Average Reward')
    axes[0, 1].set_title('Training Reward')
    axes[0, 1].set_xlabel('Step')
    axes[0, 1].set_ylabel('Reward')
    axes[0, 1].legend()
    axes[0, 1].grid(True)
    
    # Accuracy
    axes[1, 0].plot(history['accuracy'], label='Accuracy')
    axes[1, 0].set_title('Validation Accuracy')
    axes[1, 0].set_xlabel('Step')
    axes[1, 0].set_ylabel('Accuracy')
    axes[1, 0].legend()
    axes[1, 0].grid(True)
    
    # Reasoning score
    axes[1, 1].plot(history['reasoning_score'], label='Reasoning Quality')
    axes[1, 1].set_title('Reasoning Quality Score')
    axes[1, 1].set_xlabel('Step')
    axes[1, 1].set_ylabel('Score')
    axes[1, 1].legend()
    axes[1, 1].grid(True)
    
    plt.tight_layout()
    plt.show()


def plot_evaluation_results(metrics):
    """Plot evaluation metrics."""
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    # Bar chart of metrics
    metric_names = ['Accuracy', 'Avg Reward', 'Reasoning', 'Clarity']
    metric_values = [
        metrics['accuracy'] * 100,
        metrics['avg_reward'] * 100,
        metrics['avg_reasoning_score'] * 100,
        metrics['avg_clarity_score'] * 100
    ]
    
    axes[0].bar(metric_names, metric_values, color=['#2ecc71', '#3498db', '#f39c12', '#e74c3c'])
    axes[0].set_title('Evaluation Metrics (%)', fontsize=14, fontweight='bold')
    axes[0].set_ylabel('Score (%)')
    axes[0].set_ylim(0, 100)
    axes[0].grid(True, alpha=0.3)
    
    for i, v in enumerate(metric_values):
        axes[0].text(i, v + 2, f'{v:.1f}%', ha='center', fontweight='bold')
    
    # Pie chart of reward components
    reward_components = [
        metrics['avg_reasoning_score'] * CONFIG['reasoning_weight'],
        metrics['accuracy'] * CONFIG['correctness_weight'],
        metrics['avg_clarity_score'] * CONFIG['clarity_weight']
    ]
    
    axes[1].pie(
        reward_components,
        labels=['Reasoning\n(30%)', 'Correctness\n(50%)', 'Clarity\n(20%)'],
        autopct='%1.1f%%',
        colors=['#f39c12', '#2ecc71', '#e74c3c'],
        startangle=90
    )
    axes[1].set_title('Reward Component Contribution', fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    plt.show()


print("✅ Visualization functions defined")

---
## 🔟 Model Export

Save the trained model for submission.

In [None]:
# Model export
OUTPUT_DIR = "./trained_model"

def export_model(model, tokenizer, output_dir=OUTPUT_DIR):
    """
    Export trained model and tokenizer.
    """
    os.makedirs(output_dir, exist_ok=True)
    
    print(f"💾 Exporting model to {output_dir}...")
    
    # Save model (implement based on Tunix API)
    # model.save_pretrained(output_dir)
    
    # Save tokenizer
    tokenizer.save_pretrained(output_dir)
    
    # Save config
    with open(f"{output_dir}/training_config.json", 'w') as f:
        json.dump(CONFIG, f, indent=2)
    
    print(f"✅ Model exported to {output_dir}")


print("✅ Export function defined")
print(f"ℹ️ Model will be saved to: {OUTPUT_DIR}")

---
## ✅ Summary & Next Steps

### What We Built

This notebook provides the **scaffolding of a pipeline** for training Gemma with chain-of-thought reasoning; the GRPO training step itself (Section 7) is still a placeholder:

✅ **Data Processing**
- GSM8K dataset loading
- Chain-of-thought formatting
- Train/val/test splits

✅ **Reward Function** (CRITICAL!)
- Multi-criteria evaluation
- 50% correctness + 30% reasoning + 20% clarity
- Comprehensive step analysis

✅ **Training Infrastructure**
- Configuration management
- LoRA settings
- W&B logging support
- Evaluation framework

### What's Next

1. **Implement Tunix GRPO Training** (Section 7)
   - Use Tunix documentation
   - Integrate our reward function
   - Run on Kaggle TPU

2. **Train & Evaluate**
   - Start with small dataset (100 samples)
   - Validate reward function works
   - Scale to full dataset

3. **Submission**
   - Make notebook public
   - Write Kaggle writeup
   - Record YouTube video

---

### 🔗 Resources

- **Tunix GitHub:** https://github.com/google/tunix
- **Competition:** https://www.kaggle.com/competitions/google-tunix-hackathon
- **Project Repo:** https://github.com/EmrullahAydogan/Google_Tunix_Hack_Project

---

**Good luck! 🚀**