# Gemma3-1B GRPO Training Notebook

This notebook trains Gemma3-1B with GRPO (Group Relative Policy Optimization) for improved reasoning.

**Requirements:**
- Google Colab with TPU runtime (recommended) or GPU
- HuggingFace account with Gemma access

**Output:**
- LoRA checkpoint files that can be downloaded and used locally

---

## Architecture

This notebook is a **light orchestration layer** that delegates heavy logic to reusable Python modules:

- **Training logic**: `src/training/colab_pipeline.py`
- **Dataset utilities**: provided by the TunRex library
- **Model utilities**: provided by the Tunix library

The notebook contains only:
1. Dependency installation
2. Configuration setup
3. Simple function calls to the pipeline module
4. Output display

**Benefits:**
- Easy to maintain and debug
- Reusable across projects
- Version controllable
- Testable in isolation

---

## Editing the Pipeline Module

To modify training logic, dataset preparation, or other components:

1. **Edit locally**: Modify `src/training/colab_pipeline.py` in your local repository
2. **Test**: Run tests or try the changes locally
3. **Commit**: `git add src/training/colab_pipeline.py && git commit -m "Update pipeline"`
4. **Push**: `git push origin your-branch`
5. **Sync in Colab**: In this notebook, run:
   ```bash
   !cd /content/ee596-fp && git pull origin your-branch
   ```
6. **Reload**: The `%autoreload 2` magic reloads the changed module automatically the next time you run a cell

**No need to edit this notebook** - all training logic lives in the Python module!

---

## 1. Install Dependencies

In [None]:
# Install dependencies
!pip install -q kagglehub
!pip install -q datasets
!pip install -q wandb
!pip install -q "numpy<2.0"
!pip install git+https://github.com/google/tunix.git
# Force fresh install of TunRex (pip caches aggressively)
!pip uninstall -y tunrex 2>/dev/null || true
!pip install --no-cache-dir git+https://github.com/42euge/TunRex.git@feature/models-api
!pip uninstall -q -y flax
!pip install flax==0.12.0
!pip install -q 'transformers<=4.57.1'

print("\n" + "="*60)
print("Installation complete!")
print("="*60)

## 2. Clone Repository (if in Colab)

In [None]:
# Clone the repository to access the pipeline module
import os

if not os.path.exists('/content/ee596-fp'):
    !git clone https://github.com/42euge/ee596-fp.git /content/ee596-fp
    print("Repository cloned!")
else:
    print("Repository already exists. Pulling latest changes...")
    !cd /content/ee596-fp && git pull

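# Add to Python path (skipped if already present, so re-running the cell is harmless)
import sys
if '/content/ee596-fp' not in sys.path:
    sys.path.insert(0, '/content/ee596-fp')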

## 3. Setup Auto-reload

This ensures the pipeline module is automatically reloaded when you make changes.

In [None]:
# Enable auto-reload for development
%load_ext autoreload
%autoreload 2

print("Auto-reload enabled - pipeline module will reload automatically on changes")
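
If auto-reload misses a change (for example after a `git pull` that touches several files), a module can also be reloaded by hand. The helper below is a minimal sketch using the standard-library `importlib`; `reload_module` is a hypothetical name and is not part of the pipeline module.

```python
import importlib


def reload_module(module_name):
    """Import a module by name (if needed) and force a reload from disk."""
    module = importlib.import_module(module_name)
    return importlib.reload(module)


# In this notebook the call would look like:
# reload_module("src.training.colab_pipeline")
```

Names that were bound with `from ... import ...` keep pointing at the old objects after a manual reload, so re-run the import cell afterwards.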

## 4. Import Pipeline Module

In [None]:
# Import the training pipeline
from src.training.colab_pipeline import (
    ColabTrainingConfig,
    prepare_colab_session,
    train_grpo,
    export_checkpoint,
    quick_test,
)

print("Pipeline module imported successfully!")

## 5. Configure Training

Edit these values to customize your training run.

In [None]:
# =============================================================================
# CONFIGURATION - Edit these values
# =============================================================================

config = ColabTrainingConfig(
    # Training settings
    num_batches=500,              # Number of training batches (500 = ~30 min on TPU)
    learning_rate=3e-6,           # Learning rate
    lora_rank=64,                 # LoRA rank
    lora_alpha=64.0,              # LoRA alpha
    
    # Dataset settings
    use_openrubrics=True,         # Use OpenRubrics dataset
    openrubrics_max=2000,         # Max examples from OpenRubrics
    
    # Checkpoint settings
    save_to_drive=False,          # Save checkpoints to Google Drive
    experiment_name="gemma3_grpo_reasoning",
    
    # GRPO settings
    num_generations=2,
    beta=0.08,
    
    # Data settings
    batch_size=2,
    
    # Credentials are optional: leave them unset so the pipeline falls back to
    # Colab/Kaggle secrets. Do not hard-code API keys in a shared notebook.
)

# Display configuration
print("Training Configuration:")
print(f"  Batches: {config.num_batches}")
print(f"  Learning rate: {config.learning_rate}")
print(f"  LoRA rank: {config.lora_rank}")
print(f"  Dataset: {'OpenRubrics' if config.use_openrubrics else 'Custom'}")
print(f"  Save to Drive: {config.save_to_drive}")
print(f"  Checkpoint dir: {config.checkpoint_dir}")
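
Credentials can stay out of the notebook by reading them from environment variables. The sketch below assumes the conventional variable names `WANDB_API_KEY`, `KAGGLE_USERNAME`, and `KAGGLE_KEY`; `load_credentials` is a hypothetical helper, and passing `None` for a missing credential is an assumption about how `ColabTrainingConfig` treats unset values.

```python
import os


def load_credentials(env=None):
    """Collect optional credentials from environment variables.

    Variables that are not set come back as None.
    """
    env = os.environ if env is None else env
    return {
        "wandb_api_key": env.get("WANDB_API_KEY"),
        "kaggle_username": env.get("KAGGLE_USERNAME"),
        "kaggle_key": env.get("KAGGLE_KEY"),
    }


creds = load_credentials()
# Hypothetical usage (assumes the config accepts None for unset credentials):
# config = ColabTrainingConfig(num_batches=500, **creds)
```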

## 6. Prepare Session

This cell:
- Sets up credentials
- Mounts Google Drive (if requested)
- Loads the base model and tokenizer
- Creates the LoRA model
- Prepares datasets

In [None]:
# Prepare the training session
session = prepare_colab_session(config)

## 7. Train Model

Run GRPO training. This will:
- Create the optimizer and learning rate schedule
- Set up the GRPO trainer
- Run the training loop
- Save checkpoints periodically
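
For intuition about the "group relative" part of GRPO: each prompt is sampled several times, and every completion's reward is scored relative to the other completions for the same prompt. The sketch below illustrates that normalization in plain Python. It is not the Tunix implementation; `group_relative_advantages` and `eps` are hypothetical names, and the use of the population standard deviation is an assumption (libraries differ on this).

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize each reward against the other completions for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; an assumption, libraries differ
    return [(r - mu) / (sigma + eps) for r in rewards]


# Two completions per prompt, matching num_generations=2 in the configuration
advantages = group_relative_advantages([1.0, 0.0])
```

The better completion in a group gets a positive advantage and the worse one a negative advantage. When all rewards in a group are equal, every advantage is zero and that group contributes no learning signal.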

In [None]:
# Run training
trainer_state = train_grpo(config, session)

## 8. Export Checkpoint

Export the trained checkpoint for local use.

In [None]:
# Export checkpoint
checkpoint_path = export_checkpoint(config, trainer_state)

print(f"\nCheckpoint ready at: {checkpoint_path}")

## 9. Quick Test

Test the trained model with a sample question.

In [None]:
# Test the trained model
response = quick_test(config, session)

## 10. Custom Test (Optional)

Try your own questions!

In [None]:
# Test with a custom question
custom_question = "If a train travels 60 mph for 2.5 hours, how far does it go?"

response = quick_test(config, session, test_question=custom_question)

---

## Done!

Your model is trained and checkpoints are saved.

### Next Steps:

1. **Download checkpoints** (if saved to Drive):
   - Find `checkpoint_export.zip` in Google Drive
   - Download and extract to your local `checkpoints/` folder

2. **Use locally**:
   ```bash
   python demo/demo.py --checkpoint ./checkpoints/actor/<step>/model_params
   ```

3. **Modify training logic**:
   - Edit `src/training/colab_pipeline.py`
   - Push changes to GitHub
   - Pull in Colab and re-run
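
The extraction in step 1 can be scripted locally. This is a minimal sketch using the standard-library `zipfile`; `extract_checkpoint` is a hypothetical helper, and the internal layout of `checkpoint_export.zip` is not specified here, so check the extracted folder before pointing the demo at it.

```python
import zipfile
from pathlib import Path


def extract_checkpoint(zip_path, dest_dir="checkpoints"):
    """Extract a downloaded checkpoint archive into dest_dir and return that path."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
    return dest


# Example (run locally after downloading the archive):
# extract_checkpoint("checkpoint_export.zip")
```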

### Checkpoint Status

In [None]:
# Display checkpoint information
import glob

ckpt_dirs = sorted(glob.glob(f"{config.checkpoint_dir}/actor/*/"))

print("Saved Checkpoints:")
print("=" * 60)
for ckpt in ckpt_dirs:
    print(f"  {ckpt}")
print("=" * 60)
print(f"\nTotal: {len(ckpt_dirs)} checkpoints")

if config.save_to_drive:
    print("\nCheckpoints saved to Google Drive:")
    print(f"  {config.checkpoint_dir}")
else:
    print("\nCheckpoints saved locally (will be lost when runtime ends)")
    print("  Set save_to_drive=True to persist checkpoints")
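
The listing above sorts checkpoint paths as strings, so a directory named `1000` would appear before `200`. If the `<step>` directories are plain integers (an assumption based on the `actor/<step>/model_params` path), a numeric sort finds the latest checkpoint reliably. `sort_by_step` is a hypothetical helper, not part of the pipeline module.

```python
from pathlib import Path


def sort_by_step(checkpoint_dirs):
    """Sort checkpoint directories by numeric step, skipping non-numeric names."""
    numbered = [d for d in checkpoint_dirs if Path(d).name.isdigit()]
    return sorted(numbered, key=lambda d: int(Path(d).name))


# Plain sorted() would put "1000" before "200"
ordered = sort_by_step(["ckpt/actor/1000/", "ckpt/actor/200/", "ckpt/actor/50/"])
latest = ordered[-1] if ordered else None
```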