# Google Tunix Hack - Train a Model to Show Its Work

## Training Gemma 3 1B with GRPO for Chain-of-Thought Reasoning

**Author:** Emrullah Aydogan  
**Competition:** [Google Tunix Hack](https://www.kaggle.com/competitions/google-tunix-hackathon)  
**Goal:** Train Gemma to show step-by-step reasoning on math problems

---

### 📋 Table of Contents

1. [Setup & Installation](#1-setup)
2. [Data Loading & Preprocessing](#2-data)
3. [Model Configuration](#3-model)
4. [Reward Function](#4-reward)
5. [Training with Tunix GRPO](#5-training)
6. [Evaluation](#6-evaluation)
7. [Results & Visualization](#7-results)
8. [Model Export](#8-export)

---
## 1. Setup & Installation <a name="1-setup"></a>

Install the required packages and set up the environment

In [None]:
# Install core dependencies
!pip install -q google-tunix[prod] datasets transformers sentencepiece
!pip install -q jax jaxlib flax optax
!pip install -q wandb rich pyyaml matplotlib seaborn

print("✅ Dependencies installed")

In [None]:
# Imports
import os
import sys
import json
from pathlib import Path
from typing import Dict, List

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns

# Tunix imports (when available)
try:
    import tunix
    # Not every build exposes __version__, so fall back to "unknown"
    print(f"✅ Tunix version: {getattr(tunix, '__version__', 'unknown')}")
except ImportError:
    print("⚠️ Tunix not found - will use placeholder implementation")

# Check TPU
print(f"\n🖥️ Available devices: {jax.devices()}")
print(f"   Device count: {jax.device_count()}")
print(f"   Platform: {jax.default_backend()}")

---
## 2. Data Loading & Preprocessing <a name="2-data"></a>

Load the GSM8K dataset and prepare it for chain-of-thought training

In [None]:
# Load GSM8K dataset
print("📥 Loading GSM8K dataset...")
dataset = load_dataset("openai/gsm8k", "main")  # GSM8K on the Hugging Face Hub

print(f"\n✅ Dataset loaded:")
print(f"   Train: {len(dataset['train'])} samples")
print(f"   Test: {len(dataset['test'])} samples")

# Show example
example = dataset['train'][0]
print(f"\n📝 Example:")
print(f"Question: {example['question'][:150]}...")
print(f"Answer: {example['answer'][:150]}...")

In [None]:
# Helper functions for data preprocessing

import re

def extract_answer(answer_text: str) -> str:
    """Extract final numerical answer from GSM8K format"""
    match = re.search(r'####\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)', answer_text)
    if match:
        return match.group(1).replace(',', '')
    return ""

def extract_reasoning(answer_text: str) -> str:
    """Extract reasoning steps"""
    return re.split(r'####', answer_text)[0].strip()

def format_cot_example(question: str, reasoning: str, answer: str) -> Dict[str, str]:
    """Format as chain-of-thought example"""
    # Input prompt
    input_text = f"""Question: {question}

Let's solve this step by step:
"""
    
    # Target with reasoning: one numbered step per non-empty line
    # (blank lines are dropped before numbering so step numbers stay consecutive)
    reasoning_steps = [line.strip() for line in reasoning.split('\n') if line.strip()]
    formatted_steps = [f"Step {i}: {step}" for i, step in enumerate(reasoning_steps, 1)]
    
    target_text = "\n".join(formatted_steps) + f"\n\nAnswer: {answer}"
    
    return {
        'input': input_text,
        'target': target_text,
        'question': question,
        'answer': answer
    }

# Test preprocessing
test_example = dataset['train'][0]
reasoning = extract_reasoning(test_example['answer'])
answer = extract_answer(test_example['answer'])
formatted = format_cot_example(test_example['question'], reasoning, answer)

print("📝 Formatted Example:")
print("\n[INPUT]")
print(formatted['input'])
print("\n[TARGET]")
print(formatted['target'])
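GSM8K reference solutions embed calculator annotations of the form `<<48/2=24>>`, and `extract_reasoning` leaves them in place, so they end up inside the `Step N:` targets. If cleaner targets are preferred (for example, for a supervised warm-up before GRPO), a minimal sketch of stripping them follows. The helper name `strip_calculator_annotations` and the sample string are illustrative and not part of the original pipeline.

```python
import re

# GSM8K calculator annotations look like <<expression=result>>
CALC_ANNOTATION = re.compile(r"<<[^>]*>>")

def strip_calculator_annotations(reasoning: str) -> str:
    """Remove <<...>> calculator annotations from a GSM8K-style solution (hypothetical helper)."""
    return CALC_ANNOTATION.sub("", reasoning)

# Illustrative GSM8K-style reasoning (made up for this example, not taken from the dataset)
sample = (
    "He buys 3 packs at $5 each, so 3*5 = <<3*5=15>>15 dollars.\n"
    "He pays with $20, so 20-15 = <<20-15=5>>5 dollars change."
)
print(strip_calculator_annotations(sample))
```

Applied inside `extract_reasoning`, this would keep the arithmetic written in prose while dropping the duplicated machine-readable form.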

In [None]:
# Preprocess all data
print("🔄 Preprocessing dataset...")

def preprocess_dataset(dataset_split):
    processed = []
    for example in dataset_split:
        reasoning = extract_reasoning(example['answer'])
        answer = extract_answer(example['answer'])
        formatted = format_cot_example(example['question'], reasoning, answer)
        processed.append(formatted)
    return processed

# Process splits
train_data = preprocess_dataset(dataset['train'])
test_data = preprocess_dataset(dataset['test'])

# Create validation split (10%)
import random
random.seed(42)
random.shuffle(train_data)
val_size = int(len(train_data) * 0.1)
val_data = train_data[:val_size]
train_data = train_data[val_size:]

print(f"\n✅ Preprocessed:")
print(f"   Train: {len(train_data)} examples")
print(f"   Validation: {len(val_data)} examples")
print(f"   Test: {len(test_data)} examples")
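A minimal sketch of the same 90/10 split using a dedicated `random.Random(42)` instance, so the split no longer depends on whether other code touched the global `random` state before this cell ran (easy to do when notebook cells are re-run out of order). Seeded identically, it should yield the same permutation as `random.seed(42)` followed by `random.shuffle`. `train_val_split` is a hypothetical helper, and the demo splits indices rather than examples.

```python
import random
from typing import Iterable, List, Tuple

def train_val_split(items: Iterable, val_fraction: float = 0.1, seed: int = 42) -> Tuple[List, List]:
    """Shuffle a copy with a local RNG and split off the first val_fraction as validation."""
    rng = random.Random(seed)      # local RNG: unaffected by other uses of the random module
    shuffled = list(items)         # copy, so the caller's list is not mutated
    rng.shuffle(shuffled)
    val_size = int(len(shuffled) * val_fraction)
    return shuffled[val_size:], shuffled[:val_size]

# GSM8K's train split has 7,473 problems: 747 go to validation, 6,726 stay in train
train_part, val_part = train_val_split(range(7473))
print(len(train_part), len(val_part))  # 6726 747
```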

---
## 3. Model Configuration <a name="3-model"></a>

Define the training configuration and load the Gemma 3 1B tokenizer

In [None]:
# Model configuration
MODEL_NAME = "google/gemma-3-1b-it"  # instruction-tuned Gemma 3 1B on Hugging Face (gated: accept the license first)
ALGORITHM = "GRPO"  # Group Relative Policy Optimization

config = {
    'model': {
        'name': MODEL_NAME,
        'use_flash_attention': True,
    },
    'training': {
        'algorithm': ALGORITHM,
        'num_epochs': 3,
        'batch_size': 8,
        'learning_rate': 1e-5,
        'warmup_steps': 100,
        'use_lora': True,
        'lora_rank': 16,
        'lora_alpha': 32,
    },
    'rl': {
        'num_rollouts': 4,
        'temperature': 0.7,
        'max_new_tokens': 512,
    },
    'reward': {
        'correctness_weight': 0.5,
        'reasoning_weight': 0.3,
        'clarity_weight': 0.2,
    }
}

print("⚙️ Configuration:")
print(json.dumps(config, indent=2))
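With GRPO, every prompt in a batch is sampled `num_rollouts` times, so generation cost scales with `batch_size × num_rollouts` rather than `batch_size` alone. The sketch below works through the budget implied by this config, assuming one optimizer step per batch of prompts, no gradient accumulation, and the 6,726-example training split from Section 2; Tunix's own step accounting may differ. `grpo_budget` is a hypothetical helper.

```python
import math

def grpo_budget(num_train: int, batch_size: int, num_epochs: int,
                num_rollouts: int, max_new_tokens: int) -> dict:
    """Back-of-the-envelope GRPO budget (assumes 1 optimizer step per batch of prompts)."""
    steps_per_epoch = math.ceil(num_train / batch_size)   # last partial batch still counts
    total_steps = steps_per_epoch * num_epochs
    completions_per_step = batch_size * num_rollouts      # each prompt sampled num_rollouts times
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "completions_per_step": completions_per_step,
        "total_completions": total_steps * completions_per_step,
        # Upper bound: every completion runs all the way to max_new_tokens
        "max_generated_tokens": total_steps * completions_per_step * max_new_tokens,
    }

budget = grpo_budget(num_train=6726, batch_size=8, num_epochs=3,
                     num_rollouts=4, max_new_tokens=512)
print(budget)
```

Under these assumptions training runs 2,523 steps of 32 completions each, so the 100 warmup steps cover roughly 4% of training.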

In [None]:
# Load tokenizer
print(f"📥 Loading tokenizer: {MODEL_NAME}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"✅ Tokenizer loaded")
print(f"   Vocab size: {len(tokenizer)}")
print(f"   Max length: {tokenizer.model_max_length}")

# Test tokenization
test_text = train_data[0]['input'][:100]
tokens = tokenizer(test_text, return_tensors='jax')
print(f"\n📝 Test tokenization:")
print(f"   Input length: {len(test_text)} chars")
print(f"   Token count: {len(tokens['input_ids'][0])}")
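Decoder-only models like Gemma continue generating from the last position of each row, so batched rollouts usually pad prompts on the left; in `transformers` this is the tokenizer's `padding_side = "left"` setting. Whether the Tunix sampler expects pre-padded input or handles padding itself is an assumption to verify. The sketch below shows the mechanics with plain lists; `left_pad` and the pad id `0` are illustrative.

```python
from typing import List, Tuple

def left_pad(batch: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad token id sequences to equal length and build attention masks (1 = real token)."""
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = max_len - len(seq)
        input_ids.append([pad_id] * n_pad + seq)           # padding goes before the prompt
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask

ids, mask = left_pad([[5, 6, 7], [8, 9]], pad_id=0)
print(ids)   # [[5, 6, 7], [0, 8, 9]]
print(mask)  # [[1, 1, 1], [0, 1, 1]]
```

With left padding, the final column is a real prompt token in every row, so all sequences start generating from the same position.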

---
## 4. Reward Function <a name="4-reward"></a>

Define a reward that scores each response on answer correctness, reasoning quality, and formatting clarity
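The reward implemented in the next cell is a weighted sum of three scores, each clipped to [0, 1], with the weights taken from `config['reward']`:

$$
R = 0.5\,R_{\text{correct}} + 0.3\,R_{\text{reasoning}} + 0.2\,R_{\text{clarity}}, \qquad R_{\text{correct}} \in \{0, 1\}
$$

Since the weights sum to 1, $R \in [0, 1]$. A correct answer guarantees $R \ge 0.5$, while an incorrect one is capped at $0.3 + 0.2 = 0.5$, so a correct response never scores below an incorrect one.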

In [None]:
# Reward function implementation

def extract_model_answer(response: str) -> str:
    """Extract final answer from model response"""
    match = re.search(r'Answer:\s*([^\n]+)', response, re.IGNORECASE)
    if match:
        answer = match.group(1).strip()
        number_match = re.search(r'-?\d+(?:,\d{3})*(?:\.\d+)?', answer)
        if number_match:
            return number_match.group(0).replace(',', '')
    numbers = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', response)
    if numbers:
        return numbers[-1].replace(',', '')
    return ""

def extract_steps(response: str) -> List[str]:
    """Extract reasoning steps from response"""
    step_pattern = r'Step \d+:(.+?)(?=Step \d+:|Answer:|$)'
    steps = re.findall(step_pattern, response, re.DOTALL | re.IGNORECASE)
    return [step.strip() for step in steps if step.strip()]

def check_correctness(predicted: str, ground_truth: str) -> bool:
    """Check if answer is correct"""
    try:
        pred_num = float(predicted.replace(',', ''))
        truth_num = float(ground_truth.replace(',', ''))
        return abs(pred_num - truth_num) < 1e-4
    except ValueError:  # non-numeric strings fall back to exact text comparison
        return predicted.strip().lower() == ground_truth.strip().lower()

def score_reasoning_quality(steps: List[str]) -> float:
    """Score quality of reasoning steps (0-1)"""
    if not steps:
        return 0.0
    
    score = 0.0
    num_steps = len(steps)
    
    # Number of steps (ideal: 2-8)
    if 2 <= num_steps <= 8:
        score += 0.25
    elif num_steps == 1:
        score += 0.1
    
    # Average step length (ideal: 20-150 chars)
    avg_length = sum(len(s) for s in steps) / num_steps
    if 20 <= avg_length <= 150:
        score += 0.25
    
    # Contains calculations such as "3 + 4" or "9 × $2" (optional $ before the second operand)
    calc_count = sum(1 for s in steps if re.search(r'\d+\s*[+\-*/×÷]\s*\$?\d+', s))
    if calc_count > 0:
        score += min(0.25, calc_count * 0.1)
    
    # Step completeness
    complete_steps = sum(1 for s in steps if len(s) > 15 and any(c.isdigit() for c in s))
    score += min(0.25, (complete_steps / max(num_steps, 1)) * 0.25)
    
    return min(score, 1.0)

def score_clarity(response: str, steps: List[str]) -> float:
    """Score clarity of response (0-1)"""
    score = 0.0
    
    # Has step markers
    if re.search(r'Step \d+:', response, re.IGNORECASE):
        score += 0.3
    
    # Has answer marker
    if re.search(r'Answer:', response, re.IGNORECASE):
        score += 0.3
    
    # Has punctuation
    if any(char in response for char in '.!?'):
        score += 0.2
    
    # Not repetitive
    if len(steps) > 0:
        unique_steps = len(set(steps))
        score += min(0.2, (unique_steps / len(steps)) * 0.2)
    
    return min(score, 1.0)

def compute_reward(response: str, ground_truth: str, question: str = "") -> Dict:
    """Compute comprehensive reward for response"""
    predicted_answer = extract_model_answer(response)
    steps = extract_steps(response)
    
    # Component scores
    is_correct = check_correctness(predicted_answer, ground_truth)
    correctness_score = 1.0 if is_correct else 0.0
    reasoning_score = score_reasoning_quality(steps)
    clarity_score = score_clarity(response, steps)
    
    # Weighted total
    total_reward = (
        config['reward']['correctness_weight'] * correctness_score +
        config['reward']['reasoning_weight'] * reasoning_score +
        config['reward']['clarity_weight'] * clarity_score
    )
    
    return {
        'total_reward': total_reward,
        'correctness': correctness_score,
        'reasoning': reasoning_score,
        'clarity': clarity_score,
        'is_correct': is_correct,
        'num_steps': len(steps),
        'predicted': predicted_answer,
        'ground_truth': ground_truth
    }

# Test reward function
test_response = """Step 1: Janet's ducks lay 16 eggs per day
Step 2: She uses 3 eggs for breakfast
Step 3: She uses 4 eggs for muffins
Step 4: Total used: 3 + 4 = 7 eggs
Step 5: Remaining: 16 - 7 = 9 eggs
Step 6: Revenue: 9 × $2 = $18

Answer: 18"""

reward = compute_reward(test_response, "18")
print("\n🎯 Reward Function Test:")
print(f"  Total Reward: {reward['total_reward']:.3f}")
print(f"  ├─ Correctness: {reward['correctness']:.3f}")
print(f"  ├─ Reasoning: {reward['reasoning']:.3f}")
print(f"  └─ Clarity: {reward['clarity']:.3f}")
print(f"  Is Correct: {reward['is_correct']}")
print(f"  Num Steps: {reward['num_steps']}")
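In GRPO, these scalar rewards are turned into advantages relative to other samples for the same prompt. For each prompt the trainer samples a group of completions (here `num_rollouts = 4`) and normalizes each reward by the group's mean and standard deviation, so a completion is reinforced only if it beats the group average. A minimal sketch of that normalization as introduced in the DeepSeekMath GRPO formulation; the function name and `eps` are illustrative, and Tunix may differ in details such as population versus sample standard deviation.

```python
import statistics
from typing import List

def group_relative_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """GRPO-style advantages for one prompt's rollouts: (r - group mean) / (group std + eps)."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std over the group
    return [(r - mean) / (std + eps) for r in rewards]

# Rewards for 4 rollouts of one prompt (illustrative values on compute_reward's 0-1 scale)
advs = group_relative_advantages([0.96, 0.45, 0.45, 0.30])
print([round(a, 3) for a in advs])
```

If all rollouts in a group receive the same reward, every advantage is zero and the prompt contributes no learning signal; graded components such as `reasoning` and `clarity` can break such ties when correctness alone would not.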

---
## 5. Training with Tunix GRPO <a name="5-training"></a>

**Note:** This section needs the Tunix GRPO trainer, which is not wired in yet.  
The cell below only initializes logging and marks where training will go.

In [None]:
# Training setup
print("🔧 Setting up training...")
print(f"   Model: {MODEL_NAME}")
print(f"   Algorithm: {ALGORITHM}")
print(f"   Training samples: {len(train_data)}")
print(f"   Validation samples: {len(val_data)}")

# Initialize W&B (optional)
try:
    import wandb
    wandb.init(
        project="google-tunix-hack",
        name="gemma3-1b-grpo",
        config=config
    )
    print("✅ W&B initialized")
except Exception as e:  # package missing, not logged in, network error, ...
    print(f"⚠️ W&B not available: {e}")

# TODO: Actual Tunix training implementation
# This will use the Tunix library's GRPO trainer
print("\n⚠️ Tunix training implementation goes here")
print("See Tunix documentation for GRPO trainer setup")
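Once advantages are computed, GRPO updates the policy with a PPO-style clipped surrogate plus a KL penalty toward a frozen reference model (typically the starting checkpoint). The per-token sketch below uses plain Python for readability; `clip_eps=0.2` and `beta=0.04` are illustrative values, not settings from this notebook's config, and the KL term uses the estimator $\frac{\pi_{ref}}{\pi_\theta} - \log\frac{\pi_{ref}}{\pi_\theta} - 1$ from the GRPO paper. A real trainer works on batched JAX arrays of log-probabilities rather than Python lists; `grpo_token_loss` is a hypothetical helper.

```python
import math
from typing import List

def grpo_token_loss(logp_new: List[float], logp_old: List[float], logp_ref: List[float],
                    advantage: float, clip_eps: float = 0.2, beta: float = 0.04) -> float:
    """Mean per-token GRPO loss (to be minimized) for one completion with a shared advantage."""
    losses = []
    for new, old, ref in zip(logp_new, logp_old, logp_ref):
        ratio = math.exp(new - old)                              # pi_theta / pi_old
        clipped = min(max(ratio, 1 - clip_eps), 1 + clip_eps)
        surrogate = min(ratio * advantage, clipped * advantage)  # pessimistic PPO-style bound
        # KL(pi_theta || pi_ref) estimate: exp(ref - new) - (ref - new) - 1, always >= 0
        kl = math.exp(ref - new) - (ref - new) - 1
        losses.append(-(surrogate - beta * kl))                  # negate: maximize objective
    return sum(losses) / len(losses)

# Identical policies (ratio 1, KL 0): the loss reduces to -advantage
print(grpo_token_loss([-1.0, -2.0], [-1.0, -2.0], [-1.0, -2.0], advantage=1.5))  # -1.5
```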

---
## 6. Evaluation <a name="6-evaluation"></a>

Evaluate trained model on test set

In [None]:
# Evaluation placeholder
print("📊 Evaluating model...")
print("\n⚠️ Evaluation implementation goes here")
print("Will evaluate on test set and compute:")
print("  - Accuracy")
print("  - Reasoning quality")
print("  - Clarity score")
print("  - Average number of steps")
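All four metrics listed above can be read off the dictionaries that `compute_reward` returns. A minimal sketch of the aggregation, assuming one reward dict per test problem; `summarize_eval` is a hypothetical helper and the two sample records are made up to exercise it.

```python
from statistics import fmean
from typing import Dict, List

def summarize_eval(records: List[Dict]) -> Dict[str, float]:
    """Aggregate per-example reward dicts (keys as returned by compute_reward)."""
    return {
        "accuracy": fmean(1.0 if r["is_correct"] else 0.0 for r in records),
        "reasoning_quality": fmean(r["reasoning"] for r in records),
        "clarity": fmean(r["clarity"] for r in records),
        "avg_steps": fmean(r["num_steps"] for r in records),
    }

# Made-up records with the same keys compute_reward produces
records = [
    {"is_correct": True, "reasoning": 1.0, "clarity": 0.8, "num_steps": 6},
    {"is_correct": False, "reasoning": 0.5, "clarity": 0.6, "num_steps": 2},
]
print(summarize_eval(records))
```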

---
## 7. Results & Visualization <a name="7-results"></a>

Visualize training results and model performance

In [None]:
# Results visualization placeholder
print("📈 Visualization placeholder")
print("Will show:")
print("  - Training curves")
print("  - Reward progression")
print("  - Example predictions")
print("  - Accuracy by problem type")
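GSM8K ships only `question` and `answer` fields, with no problem-type labels, so "accuracy by problem type" needs a proxy. One option, an assumption rather than anything the dataset defines, is to bucket problems by the number of lines in the reference solution and report accuracy per bucket. `accuracy_by_bucket` and its bucket boundaries are hypothetical; its inputs would come from pairing each test example's reference step count with its `is_correct` flag.

```python
from collections import defaultdict
from typing import Dict, Iterable, Tuple

def accuracy_by_bucket(results: Iterable[Tuple[int, bool]]) -> Dict[str, float]:
    """Accuracy grouped by reference-solution step count, from (num_ref_steps, is_correct) pairs."""
    totals, correct = defaultdict(int), defaultdict(int)
    for num_steps, is_correct in results:
        bucket = "1-2" if num_steps <= 2 else "3-4" if num_steps <= 4 else "5+"
        totals[bucket] += 1
        correct[bucket] += int(is_correct)
    return {b: correct[b] / totals[b] for b in sorted(totals)}

print(accuracy_by_bucket([(2, True), (2, False), (3, True), (6, False)]))
# {'1-2': 0.5, '3-4': 1.0, '5+': 0.0}
```

The resulting dict can be passed straight to `plt.bar(list(acc), list(acc.values()))` with the matplotlib import from Section 1.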

---
## 8. Model Export <a name="8-export"></a>

Save trained model for submission

In [None]:
# Model export
output_dir = "./trained_model"
print(f"💾 Exporting model to {output_dir}")
print("\n⚠️ Model export implementation goes here")
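With `use_lora: True`, training produces low-rank adapters rather than updated full weights. One common export option is to fold them back into the base weights, $W' = W + \frac{\alpha}{r} BA$, so the exported checkpoint needs no adapter code at inference time. The sketch below shows the arithmetic on plain nested lists; `merge_lora` is hypothetical, real checkpoints hold JAX arrays, and the Tunix documentation should be checked for a built-in export path.

```python
from typing import List

Matrix = List[List[float]]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def merge_lora(w: Matrix, a: Matrix, b: Matrix, alpha: float, rank: int) -> Matrix:
    """Return W + (alpha / rank) * B @ A, where W is d x k, B is d x r and A is r x k."""
    scale = alpha / rank
    delta = matmul(b, a)
    return [[w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]

# Toy example with d = k = 2 and rank 1; alpha / rank = 2.0, the same scale as
# this config's lora_alpha / lora_rank = 32 / 16
W = [[1.0, 0.0], [0.0, 1.0]]
B = [[1.0], [0.0]]     # d x r
A = [[0.5, 0.25]]      # r x k
print(merge_lora(W, A, B, alpha=2, rank=1))  # [[2.0, 0.5], [0.0, 1.0]]
```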

---

## ✅ Summary

This notebook lays out the pipeline for training Gemma 3 1B with Tunix GRPO for chain-of-thought reasoning.

**Key Components:**
- ✅ Data preprocessing for GSM8K
- ✅ Multi-criteria reward function
- ✅ GRPO training configuration
- ⚠️ Tunix training (to be implemented)
- ⚠️ Evaluation framework (to be implemented)

**Next Steps:**
1. Implement actual Tunix GRPO training
2. Run full training on TPU
3. Evaluate on test set
4. Create visualizations
5. Export final model
6. Write Kaggle writeup
7. Record YouTube video

---

**Repository:** [GitHub](https://github.com/EmrullahAydogan/Google_Tunix_Hack_Project)  
**Competition:** [Google Tunix Hack](https://www.kaggle.com/competitions/google-tunix-hackathon)