# Gemma3-1B Reasoning Model Demo

## Fine-tuned with GRPO using Google's Tunix Library

This notebook demonstrates the capabilities of a Gemma3-1B model fine-tuned using **Group Relative Policy Optimization (GRPO)** for improved step-by-step reasoning.

### What is GRPO?
GRPO (Group Relative Policy Optimization) is a critic-free reinforcement learning algorithm that:
- Samples several responses per prompt and scores each one with a reward function
- Normalizes each reward against its group's mean and standard deviation instead of training a separate value (critic) network
- Achieves ~12% improvement on math reasoning benchmarks (GSM8K)

### Training Approach
- **Base Model**: Gemma3-1B-IT (instruction-tuned)
- **Training Library**: [Google Tunix](https://github.com/google/tunix)
- **Reward Function**: Rubric-as-Reward (RaR) + Format Compliance
- **Fine-tuning Method**: LoRA (Low-Rank Adaptation)

## 1. Setup and Installation

In [None]:
# Install dependencies (uncomment if needed)
# !pip install torch transformers peft accelerate safetensors bitsandbytes

In [None]:
import sys
import os
from pathlib import Path

# Add project root to path
project_root = Path(os.getcwd()).parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Import project modules
from src.model import GemmaModel, get_device
from src.config import format_prompt, get_system_prompt, SYSTEM_PROMPTS
from src.utils import extract_reasoning_and_answer, detect_question_type

print("Imports successful!")

In [None]:
# Detect the best available device
device = get_device("auto")
print(f"Using device: {device}")

# Check VRAM/RAM availability
import torch
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
elif device == "mps":
    print("Using Apple Silicon MPS")

## 2. Load the Model

We'll load the Gemma3-1B model. You can optionally load fine-tuned LoRA weights.

In [None]:
# Configuration
CHECKPOINT_PATH = None  # Set to "../checkpoints/lora" if you have fine-tuned weights
USE_4BIT = False  # Set True for 4-bit quantization (CUDA only, reduces memory)

# Load the model
print("Loading Gemma3-1B model...")
model = GemmaModel(
    checkpoint_path=CHECKPOINT_PATH,
    device=device,
    load_in_4bit=USE_4BIT and device == "cuda",
)
model.load()
print("Model loaded successfully!")

## 3. Helper Functions

In [None]:
from IPython.display import display, Markdown, HTML

def display_result(question: str, result: dict, category: str = None):
    """Display a result with nice formatting."""
    md = f"### Question\n{question}\n\n"
    
    if category:
        md += f"**Category**: {category}\n\n"
    
    md += "### Reasoning\n"
    reasoning = result.get('reasoning', '')
    if reasoning:
        md += f"{reasoning}\n\n"
    else:
        md += "*No reasoning section found*\n\n"
    
    md += "### Answer\n"
    answer = result.get('answer', '')
    if answer:
        md += f"**{answer}**\n"
    else:
        md += "*No answer section found*\n"
    
    display(Markdown(md))
    display(HTML("<hr>"))

def solve_problem(question: str, temperature: float = 0.7, category: str = None):
    """Solve a problem and display the result."""
    result = model.solve(
        question,
        temperature=temperature,
        top_k=50,
        top_p=0.95,
    )
    display_result(question, result, category)
    return result
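`display_result` expects a dict with `reasoning` and `answer` keys and falls back to a placeholder when either is empty. The project's `extract_reasoning_and_answer` helper is not shown in this notebook, so the function below is only a hypothetical stand-in. It assumes the model wraps its output in the `<reasoning>` and `<answer>` tags mentioned in the GRPO overview; the name `parse_tagged_output` is made up for illustration.

```python
import re


def parse_tagged_output(text: str) -> dict:
    """Hypothetical parser: pull the <reasoning> and <answer> sections out of raw text.

    Missing sections come back as empty strings, which is what lets
    display_result show its "No ... section found" fallback.
    """
    def section(tag: str) -> str:
        match = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
        return match.group(1).strip() if match else ""

    return {"reasoning": section("reasoning"), "answer": section("answer")}


sample = "<reasoning>4 * 2 + 5 * 3 = 23</reasoning>\n<answer>23</answer>"
print(parse_tagged_output(sample))
print(parse_tagged_output("The total is 23."))
```

The second call returns empty strings for both keys, which is the case the fallback branches in `display_result` handle.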

## 4. Math Reasoning Examples

These examples demonstrate the model's ability to solve math word problems step-by-step, similar to GSM8K benchmark problems.

In [None]:
solve_problem(
    "A store sells apples for $2 each and oranges for $3 each. "
    "If Sarah buys 4 apples and 5 oranges, how much does she spend in total?",
    category="Math - Basic Arithmetic"
)

In [None]:
solve_problem(
    "A train travels at 60 miles per hour. How far will it travel in 2.5 hours?",
    category="Math - Distance/Speed"
)

In [None]:
solve_problem(
    "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning "
    "and bakes muffins for her friends every day with four. She sells the remainder "
    "at the farmers' market daily for $2 per fresh duck egg. How much in dollars "
    "does she make every day at the farmers' market?",
    category="Math - Multi-step Word Problem"
)

In [None]:
solve_problem(
    "A bookstore has 120 books. They sell 35% of them in the first week and "
    "25% of the remaining books in the second week. How many books are left?",
    category="Math - Percentages"
)
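To judge the model's output on the four problems above, it helps to have the expected values at hand. The snippet below computes them directly from the numbers in the prompts; the variable names are illustrative only.

```python
from fractions import Fraction

# Apples and oranges: 4 apples at $2 plus 5 oranges at $3.
fruit_total = 4 * 2 + 5 * 3

# Train: speed times time.
train_miles = 60 * 2.5

# Janet's ducks: 16 eggs, minus 3 eaten and 4 baked, sold at $2 each.
duck_income = (16 - 3 - 4) * 2

# Bookstore: 35% sold in week one, then 25% of the remainder in week two.
after_week_one = 120 * Fraction(65, 100)
books_left = after_week_one * Fraction(75, 100)

print(f"Fruit total: ${fruit_total}")
print(f"Train distance: {train_miles} miles")
print(f"Duck egg income: ${duck_income}")
print(f"Books left: {float(books_left)}")
```

The bookstore problem works out to 58.5 books (78 left after week one, then three quarters of that), so a model answer of 58 or 59 reflects rounding of a non-integer result rather than a reasoning error.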

## 5. Logic and Deductive Reasoning

In [None]:
solve_problem(
    "If all cats have tails, and Whiskers is a cat, what can we conclude about Whiskers?",
    category="Logic - Syllogism"
)

In [None]:
solve_problem(
    "There are three boxes: one contains only apples, one contains only oranges, "
    "and one contains both apples and oranges. The boxes are labeled, but all labels "
    "are wrong. If you can only pick one fruit from one box to determine the contents "
    "of all boxes, which box should you pick from and why?",
    category="Logic - Puzzle"
)

In [None]:
solve_problem(
    "In a room, there are 5 people. Each person shakes hands with every other person "
    "exactly once. How many handshakes occur in total?",
    category="Logic - Combinatorics"
)

## 6. Science Explanations

In [None]:
solve_problem(
    "Why does ice float on water instead of sinking?",
    category="Science - Physics"
)

In [None]:
solve_problem(
    "What causes the sky to appear blue during the day?",
    category="Science - Optics"
)

In [None]:
solve_problem(
    "Explain how photosynthesis works in simple terms.",
    category="Science - Biology"
)

## 7. Domain-Specific Applications

These examples showcase how the model can be used as a domain expert assistant, similar to Tunix's mobile deployment demos.

In [None]:
# Medical Assistant Example
solve_problem(
    "A patient presents with a fever of 101°F, sore throat, and swollen lymph nodes. "
    "What are the possible conditions to consider and what initial tests might be helpful?",
    category="Medical - Differential Diagnosis"
)

In [None]:
# Legal Assistant Example
solve_problem(
    "A tenant has not paid rent for 3 months. What are the general steps a landlord "
    "should follow before proceeding with an eviction?",
    category="Legal - Landlord-Tenant"
)

In [None]:
# Coding Assistant Example
solve_problem(
    "I have a Python list of numbers and I want to find all pairs that sum to a target value. "
    "What's an efficient approach to solve this problem?",
    category="Coding - Algorithm Design"
)

In [None]:
# Financial Analysis Example
solve_problem(
    "A company has revenue of $1M, COGS of $400K, operating expenses of $300K, "
    "and pays 25% in taxes. Calculate the net profit margin.",
    category="Finance - Profitability Analysis"
)

## 8. Creative Reasoning

In [None]:
solve_problem(
    "Imagine a world where plants could communicate with humans. "
    "How might this change agriculture?",
    category="Creative - Hypothetical Scenario"
)

In [None]:
solve_problem(
    "If you could redesign the education system from scratch, "
    "what would be the key principles you would incorporate?",
    category="Creative - System Design"
)

## 9. Generation Strategy Comparison

Compare different generation strategies to see how they affect output quality.

In [None]:
test_question = "What is 15% of 80?"

strategies = {
    "Greedy (temp=0.01)": {"temperature": 0.01, "top_k": 1, "top_p": 1.0},
    "Standard (temp=0.7)": {"temperature": 0.7, "top_k": 50, "top_p": 0.95},
    "Creative (temp=0.9)": {"temperature": 0.9, "top_k": 100, "top_p": 0.95},
}

print(f"Question: {test_question}\n")
print("=" * 60)

for name, params in strategies.items():
    print(f"\n### {name}")
    result = model.solve(test_question, **params)
    print(f"Reasoning: {result.get('reasoning', 'N/A')[:200]}...")
    print(f"Answer: {result.get('answer', 'N/A')}")
    print("-" * 60)
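The three settings differ in how much of the next-token distribution remains available for sampling. How `model.solve` applies them internally is not shown here, so the following is a minimal sketch of one common ordering (temperature, then top-k, then top-p) on a made-up four-token distribution. The function name and the toy logits are assumptions for illustration.

```python
import math


def filtered_distribution(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return {token: probability} after temperature, top-k, and top-p filtering."""
    # Temperature: divide logits before the softmax; lower values sharpen the distribution.
    scaled = {tok: value / temperature for tok, value in logits.items()}
    peak = max(scaled.values())  # subtract the max for numerical stability
    weights = {tok: math.exp(v - peak) for tok, v in scaled.items()}
    total = sum(weights.values())
    ranked = sorted(
        ((tok, w / total) for tok, w in weights.items()),
        key=lambda kv: kv[1],
        reverse=True,
    )

    # Top-k: keep only the k most likely tokens, then renormalize.
    if top_k > 0:
        ranked = ranked[:top_k]
        total = sum(p for _, p in ranked)
        ranked = [(tok, p / total) for tok, p in ranked]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(p for _, p in kept)
    return {tok: p / total for tok, p in kept}


# Toy logits for candidate next tokens (invented values).
toy_logits = {"12": 3.0, "12.0": 2.0, "15": 1.0, "8": 0.0}

greedy = filtered_distribution(toy_logits, temperature=0.01, top_k=1, top_p=1.0)
standard = filtered_distribution(toy_logits, temperature=0.7, top_k=50, top_p=0.95)
print("Greedy:  ", greedy)
print("Standard:", standard)
```

With `top_k=1` only the most likely token survives, which is why the first strategy behaves greedily regardless of its temperature. The standard setting keeps several candidates but drops the least likely one through the top-p cutoff.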

## 10. Batch Processing

Process several questions in sequence, with one `model.solve` call per question.

In [None]:
batch_questions = [
    "What is 7 + 8?",
    "What is the capital of France?",
    "How many legs does a spider have?",
    "What is 25% of 200?",
]

print("Processing batch of questions...\n")

for i, q in enumerate(batch_questions, 1):
    print(f"[{i}/{len(batch_questions)}] {q}")
    result = model.solve(q, temperature=0.7)
    print(f"   Answer: {result.get('answer', 'N/A')}\n")

## 11. Interactive Mode

Enter your own questions!

In [None]:
# Enter your question here
YOUR_QUESTION = "How many prime numbers are there between 1 and 20?"

if YOUR_QUESTION:
    solve_problem(YOUR_QUESTION, category="Custom Question")

## 12. Understanding GRPO Training

Here's a conceptual overview of how GRPO (Group Relative Policy Optimization) works:

In [None]:
from IPython.display import Markdown, display

grpo_explanation = """
### GRPO Algorithm Overview

**Group Relative Policy Optimization** is a critic-free RL algorithm:

1. **Generate Groups**: For each prompt, generate N different responses
   ```
   responses = [model.generate(prompt) for _ in range(N)]
   ```

2. **Score Responses**: Use reward functions to score each response
   ```
   rewards = [
       format_reward(r) +      # Check <reasoning> and <answer> tags
       rubric_reward(r) +      # Rubric overlap scoring  
       accuracy_reward(r)      # Correctness (if verifiable)
   for r in responses]
   ```

3. **Compute Advantages**: Normalize each reward against the group's mean and standard deviation
   ```
   advantages = (rewards - mean(rewards)) / std(rewards)
   ```

4. **Policy Update**: Update the model to raise the probability of responses with positive advantage and lower it for those with negative advantage
   ```
   loss = -mean(log_prob(responses) * advantages) + beta * kl(policy, reference)
   ```

### Key Benefits:
- **No Critic Network**: Simpler than PPO, uses group comparisons
- **Sample Efficient**: Learns from relative rankings
- **Stable Training**: Group normalization keeps advantages on a consistent scale, and the KL penalty limits drift from the reference model

### Training Configuration (from config.py):
| Parameter | Value | Description |
|-----------|-------|-------------|
| LoRA Rank | 64 | Low-rank adaptation dimension |
| Temperature | 0.9 | Sampling temperature during training |
| Beta (KL) | 0.08 | KL divergence penalty |
| Learning Rate | 3e-6 | Optimizer learning rate |
| Num Generations | 2 | Responses per prompt |
"""

display(Markdown(grpo_explanation))
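Steps 2 and 3 of the overview can be made concrete with a small runnable sketch. This is not the Tunix implementation. It assumes a binary format reward based on the `<reasoning>` and `<answer>` tags, uses the population standard deviation (implementations differ on this), and adds a small `eps` so that a group with identical rewards yields zero advantages instead of a division by zero. The responses are invented.

```python
import re
from statistics import mean, pstdev

# A response is "well formatted" if both tagged sections appear in order.
TAGGED = re.compile(r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>", re.DOTALL)


def format_reward(response: str) -> float:
    """Illustrative reward: 1.0 for a well-formatted response, else 0.0."""
    return 1.0 if TAGGED.search(response) else 0.0


def group_advantages(rewards, eps=1e-6):
    """Normalize rewards within one group: (r - mean) / (std + eps)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# One group of sampled responses for the prompt "What is 15% of 80?"
responses = [
    "<reasoning>0.15 * 80 = 12</reasoning><answer>12</answer>",
    "The answer is 12.",
    "<reasoning>80 / 15 = 5.33</reasoning><answer>5.33</answer>",
    "<answer>12</answer>",
]

rewards = [format_reward(r) for r in responses]
advantages = group_advantages(rewards)

for reward, advantage in zip(rewards, advantages):
    print(f"reward={reward:.1f}  advantage={advantage:+.3f}")
```

The third response earns a positive advantage despite its wrong answer, because the format reward alone cannot see correctness. That is why the overview combines it with rubric and accuracy rewards. With only two generations per prompt, as in the configuration table, both responses get zero advantage whenever their rewards tie.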

## 13. Model Evaluation (GSM8K Sample)

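Evaluate the model on a small set of hand-written, GSM8K-style problems. With only five problems, treat the resulting accuracy as a smoke test rather than a benchmark score.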

In [None]:
# Sample evaluation problems with ground truth
eval_problems = [
    {"question": "If you have 3 apples and buy 5 more, how many apples do you have?", "answer": "8"},
    {"question": "A book costs $15. How much do 4 books cost?", "answer": "60"},
    {"question": "If a train travels 120 miles in 2 hours, what is its speed in miles per hour?", "answer": "60"},
    {"question": "A pizza is cut into 8 slices. If you eat 3 slices, what fraction is left?", "answer": "5/8"},
    {"question": "What is 20% of 50?", "answer": "10"},
]

correct = 0
print("Evaluating on sample problems...\n")

for i, prob in enumerate(eval_problems, 1):
    result = model.solve(prob["question"], temperature=0.3)
    model_answer = result.get("answer", "").strip()
    expected = prob["answer"]
    
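    # Simple substring check (could be more sophisticated). Guard against an
    # empty answer: "" is a substring of every string and would count as correct.
    is_correct = bool(model_answer) and (
        expected in model_answer or model_answer in expected
    )
    if is_correct:
        correct += 1
    
    status = "✓" if is_correct else "✗"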
    print(f"[{i}] {prob['question']}")
    print(f"    Expected: {expected} | Got: {model_answer} {status}\n")

accuracy = correct / len(eval_problems) * 100
print(f"\nAccuracy: {correct}/{len(eval_problems)} ({accuracy:.1f}%)")
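Substring matching has a known weakness: an expected answer of `8` also matches a model answer of `18`. One possible refinement, sketched below under the assumption that the final number in the answer text is the one that matters, compares parsed numeric values instead. The helper names `last_number` and `answers_match` are hypothetical and not part of the project.

```python
import re
from fractions import Fraction

# Matches simple fractions such as 5/8, or integers and decimals such as 60 or 58.5.
NUMBER = re.compile(r"-?\d+/\d+|-?\d+(?:\.\d+)?")


def last_number(text: str):
    """Return the last number found in text as a Fraction, or None if there is none."""
    matches = NUMBER.findall(text.replace(",", ""))
    if not matches:
        return None
    try:
        return Fraction(matches[-1])
    except (ValueError, ZeroDivisionError):
        return None


def answers_match(expected: str, model_answer: str) -> bool:
    """Compare the final numeric value of each string exactly."""
    want = last_number(expected)
    got = last_number(model_answer)
    return want is not None and want == got


print(answers_match("8", "18"))
print(answers_match("60", "$60.00"))
print(answers_match("5/8", "5/8 of the pizza"))
print(answers_match("10", ""))
```

The last-number heuristic is itself imperfect: an answer like "5 out of 8 slices" would be read as `8`. For anything beyond a quick check, constraining the `<answer>` section to a bare value is likely more reliable.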

## 14. System Prompt Variations

The model supports different system prompts that affect reasoning style.

In [None]:
print("Available System Prompts:\n")

for version, prompt in SYSTEM_PROMPTS.items():
    print(f"Version {version}:")
    print(f"  {prompt[:100]}...")
    print()

In [None]:
# Compare different system prompts
test_q = "What is 7 times 8?"

for version in [0, 2, 6]:
    print(f"\n=== System Prompt Version {version} ===")
    result = model.solve(test_q, system_prompt_version=version, temperature=0.5)
    print(f"Answer: {result.get('answer', 'N/A')}")

## Summary

This notebook demonstrated:

1. **Model Loading**: Loading Gemma3-1B with optional LoRA weights
2. **Math Reasoning**: Solving arithmetic and word problems
3. **Logic Problems**: Deductive reasoning and puzzles
4. **Science Explanations**: Physics, biology, and natural phenomena
5. **Domain Applications**: Medical, legal, coding, and finance assistants
6. **Creative Reasoning**: Hypothetical scenarios and system design
7. **Generation Strategies**: Comparing different temperature/sampling settings
8. **GRPO Training**: Understanding the training methodology
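9. **Evaluation**: Measuring accuracy on sample problems
10. **Batch and Interactive Use**: Running a list of questions and trying your own
11. **System Prompts**: Comparing prompt versions and their effect on answers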

### Next Steps:
- Fine-tune with your own dataset using Tunix on TPU/Colab
- Export model for mobile deployment (Cactus format)
- Extend to domain-specific applications

### Resources:
- [Tunix Documentation](https://tunix.readthedocs.io)
- [Gemma Model Card](https://ai.google.dev/gemma)
- [GRPO Paper](https://arxiv.org/abs/2402.03300)