# GRPO Fine-tuning for AMR Prescription Validation

This notebook implements Group Relative Policy Optimization (GRPO) for fine-tuning language models on AMR (Antimicrobial Resistance) prescription validation tasks.

**GRPO** is a reinforcement learning method that optimizes language models by:
- Generating multiple outputs per prompt
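- Computing rewards for each output using LLM-as-a-Judge (in this notebook the judge is called on every `EVAL_FREQUENCY`-th training step; a cheap length-based heuristic supplies the reward on the steps in between)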
- Using group-relative advantages for stable optimization
- Updating the model to favor high-reward outputs

- **Dataset**: ICMR 2025 Antimicrobial Treatment Guidelines
- **Task**: Validate prescriptions with step-by-step clinical reasoning


## 1. Setup and Installation


In [None]:
# Install required packages
# Run this cell first!

import os
import sys

# Check if running in Colab
if "COLAB_" in "".join(os.environ.keys()):
    print("üîß Installing packages for Google Colab...")
    !pip install -q unsloth bitsandbytes accelerate peft trl transformers datasets
else:
    print("üîß Installing packages for local environment...")
    !pip install -q unsloth transformers datasets groq pydantic tqdm aiohttp nest-asyncio

print("‚úÖ Installation complete!")


In [None]:
# Import libraries
import os
import torch
import json
import numpy as np
import asyncio
import aiohttp
import nest_asyncio
import hashlib
import time
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any
from datasets import Dataset, load_dataset
from unsloth import FastLanguageModel
from tqdm.auto import tqdm
import warnings

# Silence all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')
# Patch asyncio so event loops can be nested; the reward model later calls
# run_until_complete() from inside the notebook.
nest_asyncio.apply()

# Report the runtime this notebook is executing on.
print(
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)
if torch.cuda.is_available():
    # Only the first visible GPU (index 0) is reported.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
print(f"‚úÖ All imports successful!")


## 2. Configuration


In [None]:
# ============================================================================
# MODEL CONFIGURATION
# ============================================================================

MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
OUTPUT_DIR = "./grpo_amr_model"

# ============================================================================
# EVALUATION API CONFIGURATION
# ============================================================================

# Supabase Edge Function URL for evaluation API
API_BASE_URL = "https://gdpanoqcfepugqkisqhf.supabase.co/functions/v1/evaluate-prescription"

# Priority metrics for training (faster, focused on key aspects)
PRIORITY_METRICS = ['clinical_accuracy', 'guideline_adherence', 'reasoning_completeness']

# All metrics for comprehensive evaluation
ALL_METRICS = [
    'clinical_accuracy',
    'guideline_adherence',
    'reasoning_completeness',
    'safety_awareness',
    'decision_appropriateness',
    'reference_accuracy'
]

# Reward weights for each metric (must sum to 1.0)
REWARD_WEIGHTS = {
    'clinical_accuracy': 0.25,
    'guideline_adherence': 0.25,
    'reasoning_completeness': 0.20,
    'safety_awareness': 0.15,
    'decision_appropriateness': 0.10,
    'reference_accuracy': 0.05
}

# API Configuration
API_CONFIG = {
    "max_concurrent_requests": 50,   # upper bound on in-flight evaluation requests
    "timeout_seconds": 45,           # total timeout per request
    "max_retries": 3,                # retries after the first failed attempt
    "retry_delay": 2.0,              # base delay in seconds; doubled on every retry
    "use_cache": True,               # reuse results for identical evaluation requests
}

# Evaluation frequency (evaluate every N training steps)
EVAL_FREQUENCY = 10

# ============================================================================
# GRPO HYPERPARAMETERS
# ============================================================================

GRPO_CONFIG = {
    "num_generations_per_prompt": 2,
    "batch_size": 2,
    "learning_rate": 5e-6,
    "num_train_epochs": 3,
    "max_length": 1024,
    "temperature": 0.7,
    "top_p": 0.95,
    "kl_coef": 0.05,
    "clip_range": 0.2,
    "vf_coef": 0.1,
    "eval_frequency": EVAL_FREQUENCY,
    "use_priority_metrics": True,
}

# ============================================================================
# UNSLOTH CONFIGURATION
# ============================================================================

MAX_SEQ_LENGTH = 2048
LOAD_IN_4BIT = True
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

# ============================================================================
# SANITY CHECKS
# ============================================================================

# Enforce the invariants that the comments above only state in words, so a
# mistyped metric name or weight is caught here rather than during training.
assert set(REWARD_WEIGHTS) == set(ALL_METRICS), "REWARD_WEIGHTS keys must match ALL_METRICS"
assert abs(sum(REWARD_WEIGHTS.values()) - 1.0) < 1e-9, "REWARD_WEIGHTS must sum to 1.0"
assert set(PRIORITY_METRICS) <= set(ALL_METRICS), "PRIORITY_METRICS must be a subset of ALL_METRICS"
assert EVAL_FREQUENCY >= 1, "EVAL_FREQUENCY must be a positive number of steps"
assert GRPO_CONFIG["batch_size"] >= 1, "batch_size must be at least 1"
assert GRPO_CONFIG["num_generations_per_prompt"] >= 1, "num_generations_per_prompt must be at least 1"

print("‚úÖ Configuration loaded successfully!")
print(f"üìç API Base URL: {API_BASE_URL}")
print(f"üéØ Using Metrics: {'PRIORITY (3 metrics)' if GRPO_CONFIG['use_priority_metrics'] else 'ALL (6 metrics)'}")
print(f"üìä Evaluation Frequency: Every {EVAL_FREQUENCY} steps")
print(f"‚ö° Batch Size: {GRPO_CONFIG['batch_size']}")
print(f"üî¢ Generations per prompt: {GRPO_CONFIG['num_generations_per_prompt']}")


In [None]:
   # Credentials are not stored in the notebook. A value already present in the
   # environment is kept as-is; otherwise the user is prompted without echo.
   # NOTE(review): a key that was previously committed in this cell should be
   # treated as exposed.
   from getpass import getpass

   for _secret_name in ("SUPABASE_ANON_KEY", "GROQ_API_KEY"):
       if not os.environ.get(_secret_name):
           _secret_value = getpass(f"Enter {_secret_name} (leave blank to skip): ").strip()
           # Never write an empty or placeholder value into the environment.
           if _secret_value:
               os.environ[_secret_name] = _secret_value

## 3. Load Dataset

We'll load the prepared GRPO dataset. If you haven't prepared it yet, run:
```bash
python prepare_grpo_dataset.py
```


In [None]:
# Load dataset
from datasets import load_from_disk

# Path to prepared dataset
data_dir = Path("./data")
train_dataset_path = data_dir / "train_hf"

print(f"üìÇ Loading dataset from: {train_dataset_path}")

if not train_dataset_path.exists():
    # Fail fast: without the dataset `train_dataset` is never defined, and the
    # notebook would otherwise only break in the training cell.
    raise FileNotFoundError(
        f"Dataset not found at {train_dataset_path}. "
        "Please run: python prepare_grpo_dataset.py "
        "(this will convert your merged AMR dataset to GRPO format)."
    )

train_dataset = load_from_disk(str(train_dataset_path))
print(f"‚úÖ Loaded {len(train_dataset)} training examples")

# The training loop reads exactly these columns from every batch.
_missing_columns = {"prompt", "reference", "context"} - set(train_dataset.column_names)
if _missing_columns:
    raise ValueError(f"Dataset is missing required columns: {sorted(_missing_columns)}")
if len(train_dataset) == 0:
    raise ValueError(f"Dataset at {train_dataset_path} is empty")

# Show dataset structure
print(f"\nüìã Dataset columns: {train_dataset.column_names}")

# Show example
print(f"\n{'='*80}")
print("Example Training Instance:")
print(f"{'='*80}")
example = train_dataset[0]
print(f"\nüìù Prompt (first 300 chars):")
print(example['prompt'][:300] + "...")
print(f"\n‚ú® Reference (first 200 chars):")
print(example['reference'][:200] + "...")
print(f"\nüè∑Ô∏è  Task Type: {example.get('task_type', 'unknown')}")
print(f"{'='*80}")


## 4. Load Model and Tokenizer


In [None]:
# Load the base model together with its tokenizer through Unsloth
print(f"üîÑ Loading model: {MODEL_NAME}")
print(f"This may take a few minutes on first run...")

model, tokenizer = FastLanguageModel.from_pretrained(
    load_in_4bit=LOAD_IN_4BIT,
    dtype=None,  # Auto-detect dtype
    max_seq_length=MAX_SEQ_LENGTH,
    model_name=MODEL_NAME,
)

# One print call, one line per fact; the output matches separate print calls.
print(
    f"\n‚úÖ Model and tokenizer loaded successfully!",
    f"   Model: {MODEL_NAME}",
    f"   Max sequence length: {MAX_SEQ_LENGTH}",
    f"   4-bit quantization: {LOAD_IN_4BIT}",
    f"   Vocab size: {len(tokenizer)}",
    sep="\n",
)


## 5. API-Integrated Reward Model

The cells below first attach LoRA adapters to the loaded model and then define the reward model, which calls your Supabase Edge Function to score model outputs using LLM-as-a-Judge.


In [None]:
# Wrap the loaded model with LoRA adapters via Unsloth
print("üîß Adding LoRA adapters with Unsloth optimization...")

model = FastLanguageModel.get_peft_model(
    model,
    # Adapter shape and placement come from the configuration cell.
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGET_MODULES,
    bias="none",
    # Remaining Unsloth options, unchanged.
    use_gradient_checkpointing="unsloth",
    use_rslora=False,
    loftq_config=None,
    random_state=3407,
)

print(
    "\n‚úÖ LoRA adapters added successfully!",
    "\nüìä Trainable Parameters:",
    sep="\n",
)
model.print_trainable_parameters()


In [None]:
class APIIntegratedRewardModel:
    """
    Reward model that scores generations by calling an HTTP evaluation endpoint.

    Features:
    - Async batch processing with a concurrency limit (one shared HTTP session per batch)
    - In-memory response caching to avoid duplicate API calls
    - Retry logic with exponential backoff
    - The list of metrics to evaluate is sent with every request

    `weights` is stored but not used by this class: the reward is read from the
    ``weighted_reward`` field of the API response.
    """

    def __init__(
        self,
        api_base_url: str,
        metrics: List[str],
        weights: Dict[str, float],
        api_config: Dict[str, Any],
        use_cache: bool = True
    ):
        """
        Args:
            api_base_url: URL that evaluation requests are POSTed to.
            metrics: Metric names sent in the ``metrics`` field of each request.
            weights: Per-metric weights (kept for reference only).
            api_config: Must provide ``max_concurrent_requests``,
                ``timeout_seconds``, ``max_retries`` and ``retry_delay``.
            use_cache: Whether successful responses are cached in memory.
        """
        self.api_base_url = api_base_url
        self.metrics = metrics
        self.weights = weights
        self.api_config = api_config
        self.use_cache = use_cache

        self.cache = {} if use_cache else None
        self.stats = {
            'total_calls': 0,       # requests that received an HTTP response
            'cache_hits': 0,
            'api_errors': 0,        # requests that failed after all retries
            'total_api_time': 0.0,  # seconds until response headers, summed
            'last_error': None,     # repr() of the most recent final failure
        }

        print(f"‚úÖ Reward Model initialized with {len(metrics)} metrics")
        print(f"   Metrics: {metrics}")
        print(f"   Caching: {'Enabled' if use_cache else 'Disabled'}")

    def _create_cache_key(self, context: Dict, model_output: str, ground_truth: str) -> str:
        """Create a deterministic cache key for one evaluation request.

        Keys are sorted so that two equal dicts built in a different insertion
        order map to the same entry; values JSON cannot encode fall back to
        ``str()`` instead of raising.
        """
        serialized_context = json.dumps(context, sort_keys=True, default=str)
        content = f"{serialized_context}|{model_output}|{ground_truth}"
        # md5 is used only as a fast fingerprint here, not for security.
        return hashlib.md5(content.encode()).hexdigest()

    async def _call_api_async(
        self,
        session: "aiohttp.ClientSession",
        context: Dict,
        model_output: str,
        ground_truth: str,
        retry_count: int = 0
    ) -> Dict[str, Any]:
        """Evaluate one generation, retrying failed requests with backoff.

        Args:
            session: Open aiohttp session used for the POST request.
            context: Sent as ``patient_case``.
            model_output: Generated text to evaluate.
            ground_truth: Reference text.
            retry_count: Number of attempts already consumed (0 for a fresh call).

        Returns:
            The decoded JSON response on HTTP 200. After the last retry fails,
            a fallback dict with ``success=False``, a score of 1 for every
            metric and ``weighted_reward=0.0``. Fallbacks are never cached.
        """
        # Check cache
        cache_key = None
        if self.use_cache:
            cache_key = self._create_cache_key(context, model_output, ground_truth)
            if cache_key in self.cache:
                self.stats['cache_hits'] += 1
                return self.cache[cache_key]

        # Prepare request
        payload = {
            "patient_case": context,
            "model_output": model_output,
            "ground_truth": ground_truth,
            "metrics": self.metrics
        }

        # Prepare headers with Supabase authentication (omitted if no key is set)
        headers = {"Content-Type": "application/json"}
        supabase_key = os.getenv("SUPABASE_ANON_KEY")
        if supabase_key:
            headers["Authorization"] = f"Bearer {supabase_key}"
            headers["apikey"] = supabase_key

        timeout = aiohttp.ClientTimeout(total=self.api_config['timeout_seconds'])
        max_retries = self.api_config['max_retries']
        attempt = retry_count

        while True:
            try:
                start_time = time.time()

                async with session.post(
                    self.api_base_url,
                    json=payload,
                    headers=headers,
                    timeout=timeout
                ) as response:
                    self.stats['total_api_time'] += time.time() - start_time
                    self.stats['total_calls'] += 1

                    if response.status != 200:
                        error_text = await response.text()
                        raise Exception(f"API returned status {response.status}: {error_text}")

                    result = await response.json()

                # Cache the result
                if self.use_cache:
                    self.cache[cache_key] = result

                return result

            except Exception as e:
                if attempt >= max_retries:
                    self.stats['api_errors'] += 1
                    # Keep the reason so failures are not completely silent.
                    self.stats['last_error'] = repr(e)
                    # Return default low scores
                    return {
                        "success": False,
                        "evaluations": {metric: {"score": 1} for metric in self.metrics},
                        "weighted_reward": 0.0
                    }

                # Exponential backoff: retry_delay, 2x, 4x, ...
                await asyncio.sleep(self.api_config['retry_delay'] * (2 ** attempt))
                attempt += 1

    async def _evaluate_batch_async(
        self,
        contexts: List[Dict],
        generated: List[str],
        references: List[str]
    ) -> List[Dict]:
        """Evaluate a batch of generations concurrently.

        Results are returned in input order. At most
        ``max_concurrent_requests`` requests are in flight at a time.

        Raises:
            ValueError: If the three input lists differ in length.
        """
        if not (len(contexts) == len(generated) == len(references)):
            raise ValueError(
                "contexts, generated and references must have the same length "
                f"(got {len(contexts)}, {len(generated)}, {len(references)})"
            )

        semaphore = asyncio.Semaphore(self.api_config['max_concurrent_requests'])

        # One session for the whole batch so connections are pooled and reused
        # instead of opening a fresh session for every single request.
        async with aiohttp.ClientSession() as session:

            async def evaluate_single(ctx, gen, ref):
                async with semaphore:
                    return await self._call_api_async(session, ctx, gen, ref)

            tasks = [
                evaluate_single(ctx, gen, ref)
                for ctx, gen, ref in zip(contexts, generated, references)
            ]
            results = await asyncio.gather(*tasks)

        return list(results)

    def evaluate_batch(
        self,
        contexts: List[Dict],
        generated: List[str],
        references: List[str]
    ) -> List[Dict]:
        """Synchronous wrapper that runs the async batch evaluation to completion."""
        # NOTE(review): inside Jupyter the loop is already running, so this
        # presumably relies on nest_asyncio.apply() having been called.
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self._evaluate_batch_async(contexts, generated, references)
        )

    def compute_batch_rewards(
        self,
        contexts: List[Dict],
        generated: List[str],
        references: List[str]
    ) -> Tuple[List[float], List[Dict]]:
        """Compute rewards for a batch of generations.

        Returns:
            ``(rewards, all_scores)`` in input order. A result whose
            ``success`` field is missing or falsy yields reward 0.0 and an
            empty score dict; otherwise ``weighted_reward`` and
            ``evaluations`` are taken from the response.
        """
        results = self.evaluate_batch(contexts, generated, references)

        rewards = []
        all_scores = []

        for result in results:
            if result.get("success", False):
                reward = result.get("weighted_reward", 0.0)
                scores = result.get("evaluations", {})
            else:
                reward = 0.0
                scores = {}

            rewards.append(reward)
            all_scores.append(scores)

        return rewards, all_scores

    def print_stats(self):
        """Print cache and API usage statistics."""
        total_calls = self.stats['total_calls']
        cache_hits = self.stats['cache_hits']
        lookups = total_calls + cache_hits

        print(f"\n{'='*60}")
        print("Reward Model Statistics")
        print(f"{'='*60}")
        print(f"Total API calls: {total_calls}")
        print(f"Cache hits: {cache_hits}")
        # Based on all lookups, so a run served entirely from cache still reports 100%.
        if lookups > 0:
            cache_rate = (cache_hits / lookups) * 100
            print(f"Cache hit rate: {cache_rate:.1f}%")
        print(f"API errors: {self.stats['api_errors']}")
        if self.stats.get('last_error'):
            print(f"Last API error: {self.stats['last_error']}")
        if total_calls > 0:
            avg_time = self.stats['total_api_time'] / total_calls
            print(f"Average API call time: {avg_time:.2f}s")
        print(f"{'='*60}\n")


# Build the reward model from the configuration cell's settings
if GRPO_CONFIG["use_priority_metrics"]:
    metrics_to_use = PRIORITY_METRICS
else:
    metrics_to_use = ALL_METRICS

reward_model = APIIntegratedRewardModel(
    metrics=metrics_to_use,
    weights=REWARD_WEIGHTS,
    api_base_url=API_BASE_URL,
    use_cache=API_CONFIG["use_cache"],
    api_config=API_CONFIG,
)

print(f"\nüéØ Using metrics: {metrics_to_use}")


## 6. Environment Setup

Before running the training, you need to set up your environment variables for authentication.

In [None]:
# Credentials are read from the process environment; no secret values are
# stored in this notebook. Export them before starting Jupyter, for example:
#   export SUPABASE_ANON_KEY=...   # Supabase authentication key (required)
#   export GROQ_API_KEY=...        # get from: https://console.groq.com/keys
#   export HF_TOKEN=...            # optional, for model uploads


def _env_status(name: str) -> str:
    """Return 'Set' if environment variable `name` holds a real value, else 'Not Set'."""
    value = os.getenv(name, "")
    # Template placeholders such as "your_groq_api_key_here" do not count as set.
    is_placeholder = value.startswith("your_") and value.endswith("_here")
    return "Set" if value and not is_placeholder else "Not Set"


async def test_api():
    """POST a dummy payload to the evaluation endpoint and return the HTTP status code."""
    headers = {"Content-Type": "application/json"}
    supabase_key = os.getenv("SUPABASE_ANON_KEY")
    if supabase_key:
        headers["Authorization"] = f"Bearer {supabase_key}"
        headers["apikey"] = supabase_key

    # Bound the probe so an unreachable endpoint cannot stall the notebook.
    timeout = aiohttp.ClientTimeout(total=API_CONFIG["timeout_seconds"])
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(API_BASE_URL, json={"test": "connection"}, headers=headers) as response:
            return response.status


print("Environment variable status:")
print(f"üìç API URL: {API_BASE_URL}")
print(f"üîë Supabase Key: {_env_status('SUPABASE_ANON_KEY')}")
print(f"ü§ñ Groq Key: {_env_status('GROQ_API_KEY')}")
print(f"ü§ó HF Token: {_env_status('HF_TOKEN')}")

# Test API connectivity
try:
    # NOTE(review): asyncio.run() inside a running Jupyter loop presumably
    # relies on nest_asyncio.apply() from the imports cell.
    status = asyncio.run(test_api())
    if status == 200:
        print("‚úÖ API connection successful!")
    else:
        print(f"‚ö†Ô∏è  API returned status {status} (may be expected for test payload)")

except Exception as e:
    # Best-effort probe: report the failure but let the notebook continue.
    print(f"‚ùå API connection failed: {e}")
    print("Make sure your environment variables are set correctly.")


## 7. GRPO Training

Now you can run the complete GRPO training pipeline.

In [None]:
# GRPO Trainer class (simplified: advantage-weighted likelihood objective)
class GRPOTrainer:
    """GRPO Trainer for AMR prescription validation.

    Each training step samples several completions per prompt, scores them,
    turns the scores into per-prompt (group-relative) advantages and minimises
    the advantage-weighted language-modelling loss.

    Config keys read by this class: ``learning_rate``, ``eval_frequency``,
    ``num_generations_per_prompt``, ``max_length``, ``temperature``, ``top_p``
    and ``batch_size``. ``kl_coef``, ``clip_range`` and ``vf_coef`` are not used.
    """

    def __init__(self, model, tokenizer, reward_model, config, output_dir):
        """
        Args:
            model: Model to train; needs ``generate``, ``parameters``,
                ``save_pretrained`` and a ``device`` attribute.
            tokenizer: Tokenizer matching ``model``.
            reward_model: Object providing ``compute_batch_rewards`` and ``print_stats``.
            config: Hyperparameter dict (see class docstring).
            output_dir: Directory under which epoch checkpoints are written.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.reward_model = reward_model
        self.config = config
        self.output_dir = output_dir

        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config["learning_rate"]
        )

        # Per-epoch history; "step" is reserved but not filled in by train().
        self.metrics = {
            "epoch": [],
            "step": [],
            "loss": [],
            "mean_reward": [],
            "max_reward": [],
            "detailed_scores": [],
        }

        self.global_step = 0
        self.eval_frequency = config.get("eval_frequency", 10)

        print(f"‚úÖ GRPO Trainer initialized")
        print(f"   Eval frequency: Every {self.eval_frequency} steps")

    def _format_prompt(self, prompt: str) -> str:
        """Wrap a raw prompt in the user/assistant template shared by generation and loss."""
        # Real newlines: an escaped "\\n" would put a literal backslash-n in the text.
        # NOTE(review): this is a hand-written template, not the tokenizer's
        # chat template — confirm it suits the chosen model.
        return f"<|user|>\n{prompt}\n<|assistant|>\n"

    def generate_responses(self, prompts: List[str], num_generations: int) -> List[List[str]]:
        """Sample `num_generations` completions for every prompt.

        Returns:
            One list of decoded completions per prompt, in prompt order.
        """
        self.model.eval()
        all_generations = []

        with torch.no_grad():
            for prompt in prompts:
                inputs = self.tokenizer(
                    self._format_prompt(prompt),
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512
                ).to(self.model.device)
                prompt_length = inputs["input_ids"].shape[1]

                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=self.config["max_length"],
                    num_return_sequences=num_generations,
                    temperature=self.config["temperature"],
                    top_p=self.config["top_p"],
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

                # Slice off the prompt tokens instead of splitting the decoded
                # text on "<|assistant|>": the marker is lost when the prompt
                # is truncated, and the model may emit it itself.
                generations = [
                    self.tokenizer.decode(
                        output[prompt_length:], skip_special_tokens=True
                    ).strip()
                    for output in outputs
                ]

                all_generations.append(generations)

        return all_generations

    def compute_advantages(self, rewards: List[float]) -> List[float]:
        """Standardise the rewards of one group: ``(r - mean) / (std + 1e-8)``.

        Returns an empty list for empty input; a group whose rewards are all
        equal gets all-zero advantages.
        """
        if len(rewards) == 0:
            return []
        rewards = np.array(rewards)
        mean_reward = np.mean(rewards)
        std_reward = np.std(rewards) + 1e-8  # epsilon avoids division by zero
        advantages = (rewards - mean_reward) / std_reward
        return advantages.tolist()

    def _compute_group_advantages(self, rewards: List[float], group_sizes: List[int]) -> List[float]:
        """Apply compute_advantages separately to each consecutive group of rewards."""
        advantages = []
        start = 0
        for size in group_sizes:
            advantages.extend(self.compute_advantages(rewards[start:start + size]))
            start += size
        return advantages

    def compute_loss(self, prompts: List[str], generations: List[str], advantages: List[float]):
        """Compute the advantage-weighted language-modelling loss.

        For each (prompt, generation) pair the model's loss over the formatted
        text is multiplied by that sample's advantage; the products are
        averaged over the batch.
        """
        self.model.train()
        total_loss = 0.0

        for prompt, generation, advantage in zip(prompts, generations, advantages):
            formatted_text = self._format_prompt(prompt) + generation

            inputs = self.tokenizer(
                formatted_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=2048
            ).to(self.model.device)

            # NOTE(review): labels cover prompt tokens as well as the
            # completion; consider masking the prompt part.
            outputs = self.model(**inputs, labels=inputs["input_ids"])
            weighted_loss = outputs.loss * advantage
            total_loss += weighted_loss

        return total_loss / len(prompts)

    def train_step(self, batch_prompts: List[str], batch_references: List[str], 
                   batch_contexts: List[Dict], evaluate_this_step: bool = False):
        """Run one GRPO update on a batch of prompts.

        Args:
            batch_prompts: Prompts to sample completions for.
            batch_references: Reference answer for each prompt.
            batch_contexts: Context passed to the reward model for each prompt.
            evaluate_this_step: If True, rewards come from the reward model;
                otherwise from a cheap length-ratio heuristic.

        Returns:
            Dict with ``loss``, ``mean_reward``, ``max_reward``, ``min_reward``,
            ``detailed_scores`` (None on heuristic steps) and ``evaluated``.
        """
        self.global_step += 1

        # Generate responses
        all_generations = self.generate_responses(
            batch_prompts,
            self.config["num_generations_per_prompt"]
        )

        # Flatten, remembering how many generations belong to each prompt
        flat_prompts = []
        flat_generations = []
        flat_references = []
        flat_contexts = []
        group_sizes = []

        for prompt, generations, reference, context in zip(
            batch_prompts, all_generations, batch_references, batch_contexts
        ):
            group_sizes.append(len(generations))
            for generation in generations:
                flat_prompts.append(prompt)
                flat_generations.append(generation)
                flat_references.append(reference)
                flat_contexts.append(context)

        # Compute rewards
        all_rewards = []
        all_scores = None

        if evaluate_this_step:
            all_rewards, all_scores = self.reward_model.compute_batch_rewards(
                flat_contexts, flat_generations, flat_references
            )
        else:
            # Simple heuristic rewards: 0.5 when the generation has as many
            # words as the reference, falling linearly to 0 as the ratio moves away.
            for gen, ref in zip(flat_generations, flat_references):
                len_ratio = len(gen.split()) / max(len(ref.split()), 1)
                reward = 1.0 - abs(1.0 - len_ratio)
                reward = max(0, min(1, reward)) * 0.5
                all_rewards.append(reward)

        # Compute advantages per prompt group: each completion is compared
        # only with the other completions of the same prompt.
        advantages = self._compute_group_advantages(all_rewards, group_sizes)

        # Update model
        self.optimizer.zero_grad()
        loss = self.compute_loss(flat_prompts, flat_generations, advantages)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        return {
            "loss": loss.item(),
            "mean_reward": np.mean(all_rewards),
            "max_reward": np.max(all_rewards),
            "min_reward": np.min(all_rewards),
            "detailed_scores": all_scores,
            "evaluated": evaluate_this_step
        }

    def train(self, dataset, num_epochs):
        """Main training loop.

        Iterates over `dataset` in slices of ``batch_size`` for `num_epochs`
        epochs, records per-epoch averages in ``self.metrics`` and saves a
        checkpoint after every epoch.

        Args:
            dataset: Sliceable dataset whose batches expose ``prompt``,
                ``reference`` and ``context`` columns.
            num_epochs: Number of passes over the dataset.
        """
        print(f"\n{'='*80}")
        print(f"üöÄ Starting GRPO Training for AMR")
        print(f"{'='*80}")
        print(f"üìä Epochs: {num_epochs}")
        print(f"üì¶ Dataset size: {len(dataset)}")
        print(f"üî¢ Batch size: {self.config['batch_size']}")
        print(f"‚è±Ô∏è  Eval frequency: Every {self.eval_frequency} steps")
        print(f"üé≤ Generations per prompt: {self.config['num_generations_per_prompt']}")
        print(f"{'='*80}\n")

        epoch_pbar = tqdm(range(num_epochs), desc="üìö Epochs", position=0, leave=True)

        for epoch in epoch_pbar:
            epoch_pbar.set_description(f"üìö Epoch {epoch + 1}/{num_epochs}")

            epoch_metrics = {
                "loss": [],
                "mean_reward": [],
                "max_reward": [],
            }

            # Ceiling division: the last batch may be smaller than batch_size.
            num_batches = (len(dataset) + self.config["batch_size"] - 1) // self.config["batch_size"]

            step_pbar = tqdm(
                range(0, len(dataset), self.config["batch_size"]),
                desc=f"  üîÑ Steps",
                position=1,
                leave=False,
                total=num_batches
            )

            for i in step_pbar:
                batch = dataset[i:i + self.config["batch_size"]]
                batch_prompts = batch["prompt"]
                batch_references = batch["reference"]
                batch_contexts = batch["context"]

                # Checked before train_step increments global_step, so the
                # very first step (global_step == 0) is an evaluated one.
                evaluate_this_step = (self.global_step % self.eval_frequency == 0)

                metrics = self.train_step(batch_prompts, batch_references, batch_contexts, evaluate_this_step)

                for key in epoch_metrics:
                    if key in metrics:
                        epoch_metrics[key].append(metrics[key])

                eval_marker = "üìä EVAL" if evaluate_this_step else "‚ö° FAST"
                step_pbar.set_postfix({
                    'type': eval_marker,
                    'loss': f"{metrics['loss']:.4f}",
                    'reward': f"{metrics['mean_reward']:.4f}",
                    'max': f"{metrics['max_reward']:.4f}",
                    'global': self.global_step
                })

                if metrics.get('detailed_scores'):
                    self.metrics['detailed_scores'].append({
                        'step': self.global_step,
                        'scores': metrics['detailed_scores']
                    })

            step_pbar.close()

            # Epoch summary
            epoch_loss = np.mean(epoch_metrics["loss"])
            epoch_mean_reward = np.mean(epoch_metrics["mean_reward"])
            epoch_max_reward = np.mean(epoch_metrics["max_reward"])

            epoch_pbar.set_postfix({
                'loss': f"{epoch_loss:.4f}",
                'reward': f"{epoch_mean_reward:.4f}",
                'max': f"{epoch_max_reward:.4f}"
            })

            self.metrics["epoch"].append(epoch + 1)
            self.metrics["loss"].append(epoch_loss)
            self.metrics["mean_reward"].append(epoch_mean_reward)
            self.metrics["max_reward"].append(epoch_max_reward)

            # Save checkpoint
            checkpoint_dir = f"{self.output_dir}/checkpoint-epoch-{epoch + 1}"
            self.save_checkpoint(checkpoint_dir)

            print(f"\nüìä Epoch {epoch + 1} Summary: Loss={epoch_loss:.4f}, Reward={epoch_mean_reward:.4f}, Max={epoch_max_reward:.4f}")

        epoch_pbar.close()

        print(f"\n{'='*80}")
        print("‚úÖ Training Complete!")
        print(f"{'='*80}\n")

        self.reward_model.print_stats()

    def save_checkpoint(self, checkpoint_dir):
        """Save the model, the tokenizer and the metrics history to `checkpoint_dir`."""
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.model.save_pretrained(checkpoint_dir)
        self.tokenizer.save_pretrained(checkpoint_dir)

        with open(f"{checkpoint_dir}/metrics.json", "w") as f:
            json.dump(self.metrics, f, indent=2)

print("‚úÖ GRPO Trainer class defined")


In [None]:
# Create the trainer from the objects built in the earlier cells
trainer = GRPOTrainer(
    config=GRPO_CONFIG,
    output_dir=OUTPUT_DIR,
    reward_model=reward_model,
    tokenizer=tokenizer,
    model=model,
)

print("‚úÖ Trainer initialized!")


In [None]:
# Run the full training loop for the configured number of epochs
trainer.train(
    num_epochs=GRPO_CONFIG["num_train_epochs"],
    dataset=train_dataset,
)


## 8. Save Final Model

After training, save the final model.

In [None]:
# Write the trained model and its tokenizer into one directory
final_model_dir = f"{OUTPUT_DIR}/final_model"
Path(final_model_dir).mkdir(parents=True, exist_ok=True)

# Both objects are saved with the same call.
for _artifact in (model, tokenizer):
    _artifact.save_pretrained(final_model_dir)

print(f"‚úÖ Final model saved to {final_model_dir}")
