# GRPO Finetuning for Style Transfer

This notebook implements Group Relative Policy Optimization (GRPO) for fine-tuning language models on style transfer tasks.

GRPO is a reinforcement learning method that optimizes language models by:
- Generating multiple outputs per prompt
- Computing rewards for each output
- Using group-relative advantages for stable optimization
- Updating the model to favor high-reward outputs


## 1. Setup and Installation


In [2]:
# Install required packages.
# Two install paths: outside Colab a plain `pip install unsloth` resolves its own
# dependencies; inside Colab each dependency is installed separately, mostly with
# --no-deps (presumably to avoid replacing Colab's preinstalled torch — confirm).
# The cell magic below is left disabled; if re-enabled it must be the first line of the cell.
# %%capture
import os, re
# Colab sets environment variables whose names contain "COLAB_"; use them to detect Colab.
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # Pick the xformers wheel that matches the detected torch version:
    # torch 2.8.0 -> xformers 0.0.32.post2, anything else -> 0.0.29.post3.
    import torch; v = re.match(r"[0-9\.]{3,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.32.post2" if v == "2.8.0" else "0.0.29.post3")
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
# The lines below run in both branches. They pin transformers and trl to the
# versions this notebook was run with (the Unsloth banner reports Transformers 4.55.4).
%uv pip install transformers==4.55.4
%uv pip install --no-deps trl==0.22.2
%uv pip install nltk -q

# Async HTTP client and nested-event-loop support, used by the API reward model.
%uv pip install -q aiohttp nest-asyncio



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[2mUsing Python 3.12.6 environment at: /usr/local[0m
[2mAudited [1m1 package[0m [2min 7ms[0m[0m
Note: you may need to restart the kernel to use updated packages.
[2mUsing Python 3.12.6 environment at: /usr/local[0m
[2mAudited [1m1 package[0m [2min 16ms[0m[0m
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [13]:
from tqdm.auto import tqdm


In [3]:
import torch
import json
import numpy as np
from datasets import Dataset, load_dataset
from unsloth import FastLanguageModel
import os
from typing import List, Dict, Optional, Tuple, Any
import warnings
import aiohttp
import asyncio
import nest_asyncio
from functools import lru_cache
import hashlib
from collections import defaultdict
import time

# Globally silence warnings so training logs stay readable.
warnings.filterwarnings('ignore')

# Allow nested event loops: the reward model calls run_until_complete() while
# Jupyter's own event loop is already running.
nest_asyncio.apply()

# Report the runtime environment.
cuda_ok = torch.cuda.is_available()
for env_line in (f"PyTorch version: {torch.__version__}", f"CUDA available: {cuda_ok}"):
    print(env_line)
if cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
print("✅ Unsloth imported successfully!")


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


Skipping import of cpp extensions due to incompatible torch version 2.8.0+cu129 for torchao version 0.14.0         Please see GitHub issue #2919 for more info


🦥 Unsloth Zoo will now patch everything to make training faster!
PyTorch version: 2.8.0+cu129
CUDA available: True
CUDA device: NVIDIA A100-SXM4-40GB
✅ Unsloth imported successfully!


## 2. Configuration


In [4]:
# ============================================================================
# MODEL CONFIGURATION
# ============================================================================

MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
OUTPUT_DIR = "./grpo_style_transfer_model"

# ============================================================================
# EVALUATION API CONFIGURATION
# ============================================================================

# IMPORTANT: Set your evaluation API endpoint here
API_BASE_URL = "https://qtt14d7li0.execute-api.us-west-2.amazonaws.com/api/"  # UPDATE THIS!

# Priority metrics for training (faster, focused on key aspects)
PRIORITY_METRICS = ['style_similarity', 'rephrase_accuracy']

# All metrics for comprehensive evaluation
ALL_METRICS = [
    'style_similarity',
    'meaning_preservation',
    'coherence',
    'fluency',
    'content_length',
    'rephrase_accuracy'
]

# Reward weights for each metric (must sum to 1.0)
REWARD_WEIGHTS = {
    'style_similarity': 0.30,      # Priority metric
    'rephrase_accuracy': 0.30,     # Priority metric
    'meaning_preservation': 0.15,
    'coherence': 0.10,
    'fluency': 0.10,
    'content_length': 0.05
}

# API Configuration
API_CONFIG = {
    "max_concurrent_requests": 100,
    "timeout_seconds": 30,
    "max_retries": 3,
    "retry_delay": 1.0,  # seconds
    "use_cache": True,
}

# Evaluation frequency (evaluate every N training steps to reduce API calls)
EVAL_FREQUENCY = 5  # Evaluate every 5 steps

# ============================================================================
# GRPO HYPERPARAMETERS
# ============================================================================

GRPO_CONFIG = {
    "num_generations_per_prompt": 2,  # Number of outputs to generate per prompt
    "batch_size": 4,
    "learning_rate": 1e-5,
    "num_train_epochs": 3,
    "max_length": 512,
    "temperature": 0.7,
    "top_p": 0.95,
    "kl_coef": 0.05,  # KL divergence coefficient
    "clip_range": 0.2,  # PPO clipping parameter
    "vf_coef": 0.1,  # Value function coefficient
    "eval_frequency": EVAL_FREQUENCY,  # Evaluate every N steps
    "use_priority_metrics": False,  # Use ALL 6 metrics during training for comprehensive evaluation
}

# ============================================================================
# UNSLOTH CONFIGURATION
# ============================================================================

# Unsloth handles LoRA and quantization automatically with optimal settings
MAX_SEQ_LENGTH = 2048  # Choose any! Unsloth auto-supports RoPE scaling
LOAD_IN_4BIT = True  # Use 4-bit quantization for memory efficiency

# LoRA Configuration (Unsloth optimized)
LORA_R = 16  # LoRA rank
LORA_ALPHA = 32  # LoRA alpha
LORA_DROPOUT = 0  # Unsloth recommends 0 for faster training
LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

print("✅ Configuration loaded successfully!")
print(f"📍 API Base URL: {API_BASE_URL}")
print(f"🎯 Using Metrics: ALL 6 METRICS (comprehensive evaluation)")
print(f"   - {', '.join(ALL_METRICS)}")
print(f"📊 Evaluation Frequency: Every {EVAL_FREQUENCY} steps")
print(f"⚡ Max Concurrent API Requests: {API_CONFIG['max_concurrent_requests']}")
print(f"💡 Note: Using all 6 metrics = 3x more API calls than priority mode")


✅ Configuration loaded successfully!
📍 API Base URL: https://qtt14d7li0.execute-api.us-west-2.amazonaws.com/api/
🎯 Using Metrics: ALL 6 METRICS (comprehensive evaluation)
   - style_similarity, meaning_preservation, coherence, fluency, content_length, rephrase_accuracy
📊 Evaluation Frequency: Every 5 steps
⚡ Max Concurrent API Requests: 100
💡 Note: Using all 6 metrics = 3x more API calls than priority mode


## 3. Load Dataset


In [5]:
# Prompt template for style transfer.
# `{input_text}` is the only placeholder. load_style_transfer_dataset_from_hf fills
# it with str.format(), so any literal braces added to this template must be
# doubled ("{{" / "}}"). The exact wording is part of the training input, so edit with care.
STYLE_TRANSFER_PROMPT = """**Role:** You are an expert content writer for SurveySparrow platform with 20+ years of experience in writing blog content for SaaS platform.

**Task**
Your task is to rewrite the provided text to match the SurveySparrow writing style exactly without any hallucination or leaving any content from the input blog content. also make sure that you read the content line-by-line so that the meaning doesn't change for each line.

**Rules:**
1. You MUST preserve the core meaning, all facts, and key entities from the original text.
2. Do NOT add any new information, opinions, or details.
3. Do NOT omit any critical information from the original.
4. Adopt the SurveySparrow writing style in terms of tone, vocabulary, and sentence structure.

### Input:
{input_text}

### Output:
"""

def load_style_transfer_dataset_from_hf(dataset_name: str, split: str = "train") -> Dataset:
    """
    Load the style transfer dataset from Hugging Face and build prompt/reference pairs.

    Dataset columns:
    - rewritten_text_output: Neutral style content (INPUT for model)
    - original_text_input: SurveySparrow style content (OUTPUT/target for model)

    Task: Convert neutral content → SurveySparrow style

    A row is skipped when either text is missing, not a string, or blank after
    stripping whitespace. The first few skipped rows are reported individually,
    followed by a one-line total.

    Args:
        dataset_name: Hugging Face dataset id.
        split: Dataset split to load.

    Returns:
        Dataset with columns "prompt" (STYLE_TRANSFER_PROMPT filled with the
        neutral text), "reference" (target text) and "input_text" (stripped
        neutral text, kept for the evaluation API).

    Raises:
        ValueError: if no usable rows remain after filtering.
    """
    max_skip_warnings = 10  # per-row warnings to print before switching to a summary

    def _as_clean_text(value) -> str:
        # Strip BEFORE the emptiness check so whitespace-only text is treated as missing.
        return value.strip() if isinstance(value, str) else ""

    print(f"📥 Loading dataset from Hugging Face: {dataset_name} ({split} split)")

    # Load dataset from Hugging Face
    dataset = load_dataset(dataset_name, split=split)

    print(f"✅ Loaded {len(dataset)} examples")
    print(f"📋 Dataset columns: {dataset.column_names}")

    processed_data = []
    skipped = 0

    for idx, item in enumerate(dataset):
        neutral_text = _as_clean_text(item.get('rewritten_text_output'))        # Input (neutral style)
        surveysparrow_text = _as_clean_text(item.get('original_text_input'))    # Output (SurveySparrow style)

        if not neutral_text or not surveysparrow_text:
            skipped += 1
            if skipped <= max_skip_warnings:
                print(f"⚠️  Skipping example {idx}: missing text")
            continue

        processed_data.append({
            "prompt": STYLE_TRANSFER_PROMPT.format(input_text=neutral_text),
            "reference": surveysparrow_text,
            "input_text": neutral_text,  # Store for evaluation API
        })

    if skipped:
        print(f"⚠️  Skipped {skipped} example(s) with missing or blank text")
    if not processed_data:
        raise ValueError(f"No usable examples found in {dataset_name!r} ({split} split)")

    processed_dataset = Dataset.from_list(processed_data)

    print(f"✅ Processed {len(processed_dataset)} examples successfully")

    return processed_dataset

# Load the training split and preview one formatted example so the prompt
# template can be checked by eye.
train_dataset = load_style_transfer_dataset_from_hf(
    split="train",
    dataset_name="madan2248c/styletrasfer_final",
)

divider = '=' * 80
first_example = train_dataset[0]

print(f"\n{divider}")
print("Dataset Example:")
print(divider)
print("\n📝 Prompt (first 500 chars):")
print(first_example["prompt"][:500] + "...")
print("\n✨ Reference Output (first 300 chars):")
print(first_example["reference"][:300] + "...")
print(f"{divider}\n")


📥 Loading dataset from Hugging Face: madan2248c/styletrasfer_final (train split)
✅ Loaded 9646 examples
📋 Dataset columns: ['original_text_input', 'rewritten_text_output']
✅ Processed 9646 examples successfully

Dataset Example:

📝 Prompt (first 500 chars):
**Role:** You are an expert content writer for SurveySparrow platform with 20+ years of experience in writing blog content for SaaS platform.

**Task**
Your task is to rewrite the provided text to match the SurveySparrow writing style exactly without any hallucination or leaving any content from the input blog content. also make sure that you read the content line-by-line so that the meaning doesn't change for each line.

**Rules:**
1. You MUST preserve the core meaning, all facts, and key entit...

✨ Reference Output (first 300 chars):
Ready for production Once the prototyping stage is over, the design is handed over to the developers for coding. During this stage, the designer should clearly communicate how each portion of the des

## 4. Load Model and Tokenizer


In [6]:
# Load the base model and its tokenizer through Unsloth.
print(f"Loading model with Unsloth: {MODEL_NAME}")
print("This may take a few minutes on first run...")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=LOAD_IN_4BIT,  # 4-bit quantization to cut GPU memory
    dtype=None,  # None lets Unsloth auto-detect the dtype
)

load_summary = [
    "✅ Model and tokenizer loaded successfully!",
    f"   Model: {MODEL_NAME}",
    f"   Max sequence length: {MAX_SEQ_LENGTH}",
    f"   4-bit quantization: {LOAD_IN_4BIT}",
    f"   Vocab size: {len(tokenizer)}",
]
print("\n".join(load_summary))


Loading model with Unsloth: meta-llama/Llama-3.1-8B-Instruct
This may take a few minutes on first run...
==((====))==  Unsloth 2025.10.4: Fast Llama patching. Transformers: 4.55.4.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.494 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu129. CUDA: 8.0. CUDA Toolkit: 12.9. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
✅ Model and tokenizer loaded successfully!
   Model: meta-llama/Llama-3.1-8B-Instruct
   Max sequence length: 2048
   4-bit quantization: True
   Vocab size: 128256


In [7]:
# Wrap the model with LoRA adapters using Unsloth's get_peft_model.
print("Adding LoRA adapters with Unsloth optimization...")

lora_settings = dict(
    r=LORA_R,                               # LoRA rank
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGET_MODULES,
    bias="none",                            # Supports: "none", "all", "lora_only"
    use_gradient_checkpointing="unsloth",   # Unsloth's optimized checkpointing
    random_state=3407,
    use_rslora=False,                       # Rank stabilized LoRA off
    loftq_config=None,                      # No LoftQ quantization
)
model = FastLanguageModel.get_peft_model(model, **lora_settings)

print("✅ LoRA adapters added successfully!")
print("\n📊 Trainable Parameters:")
model.print_trainable_parameters()


Adding LoRA adapters with Unsloth optimization...


Unsloth 2025.10.4 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


✅ LoRA adapters added successfully!

📊 Trainable Parameters:
trainable params: 41,943,040 || all params: 8,072,204,288 || trainable%: 0.5196


## 5. API-Integrated Reward Function

This reward model integrates with your external LLM-as-a-Judge evaluation API to compute rewards based on:
- **Style Similarity** (30% weight) - Priority metric
- **Rephrase Accuracy** (30% weight) - Priority metric  
- **Meaning Preservation** (15% weight)
- **Coherence** (10% weight)
- **Fluency** (10% weight)
- **Content Length** (5% weight)

Features:
- ✅ Async batching for 100 concurrent API requests
- ✅ Response caching to avoid duplicate calls
- ✅ Retry logic with exponential backoff
- ✅ Priority metrics mode for faster training


In [8]:
from typing import Any

class APIIntegratedRewardModel:
    """
    Reward model that integrates with external LLM-as-a-Judge evaluation API.

    Each (input_text, model_output, ground_truth) triple is scored once per
    metric by POSTing to ``{api_base_url}/evaluate?metric=<name>``. The
    per-metric scores (expected on a 1-5 scale) are combined into one weighted
    reward in [0, 1].

    Features:
    - Async batch processing over a single shared HTTP session; at most
      ``api_config['max_concurrent_requests']`` HTTP requests are in flight
    - Response caching to avoid duplicate API calls
    - Retry logic with exponential backoff
    - Priority metrics mode for faster training (pass a shorter ``metrics`` list)
    """

    def __init__(
        self,
        api_base_url: str,
        metrics: List[str],
        weights: Dict[str, float],
        api_config: Dict[str, Any],
        use_cache: bool = True
    ):
        """
        Args:
            api_base_url: Base URL of the evaluation API. Trailing slashes are ignored.
            metrics: Metric names to request for every generation.
            weights: Reward weight per metric name.
            api_config: Must contain "max_concurrent_requests", "timeout_seconds",
                "max_retries" and "retry_delay".
            use_cache: Memoise API responses in memory.
        """
        # Drop trailing slashes so joining with "/evaluate" never yields "//evaluate".
        self.api_base_url = api_base_url.rstrip("/")
        self.metrics = metrics
        self.weights = weights
        self.api_config = api_config
        self.use_cache = use_cache

        # Cache for API responses: key = hash(input_text, model_output, ground_truth, metric)
        self.cache = {} if use_cache else None

        # Statistics
        self.stats = {
            'total_calls': 0,
            'cache_hits': 0,
            'api_errors': 0,
            'total_api_time': 0.0
        }

        print(f"✅ Reward Model initialized with {len(metrics)} metrics")
        print(f"   Metrics: {metrics}")
        print(f"   Caching: {'Enabled' if use_cache else 'Disabled'}")

    def _create_cache_key(self, input_text: str, model_output: str, ground_truth: str, metric: str) -> str:
        """Create a unique cache key for an API call."""
        # JSON-encoding keeps field boundaries unambiguous; with a plain "|" join,
        # ("a|b", "c") and ("a", "b|c") would produce the same key.
        content = json.dumps([input_text, model_output, ground_truth, metric], ensure_ascii=False)
        return hashlib.md5(content.encode("utf-8")).hexdigest()

    async def _post_once(
        self,
        session: "aiohttp.ClientSession",
        url: str,
        params: Dict[str, str],
        payload: Dict[str, str],
        semaphore: Optional[asyncio.Semaphore] = None,
    ) -> Any:
        """
        Make one HTTP attempt and return the parsed JSON body.

        Holds ``semaphore`` (when given) only for the duration of the request.

        Raises:
            RuntimeError: on a non-200 status; transport errors propagate as-is.
        """
        if semaphore is not None:
            await semaphore.acquire()
        try:
            start_time = time.time()
            async with session.post(
                url,
                params=params,
                json=payload,
                timeout=aiohttp.ClientTimeout(total=self.api_config['timeout_seconds'])
            ) as response:
                elapsed = time.time() - start_time
                self.stats['total_api_time'] += elapsed
                self.stats['total_calls'] += 1

                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(f"API returned status {response.status}: {error_text}")
                return await response.json()
        finally:
            if semaphore is not None:
                semaphore.release()

    async def _call_api_single(
        self,
        session: "aiohttp.ClientSession",
        metric: str,
        input_text: str,
        model_output: str,
        ground_truth: str,
        retry_count: int = 0,
        semaphore: Optional[asyncio.Semaphore] = None,
    ) -> Dict[str, Any]:
        """
        Evaluate one metric for one generation, with caching and retries.

        Args:
            retry_count: Number of attempts already made (used for backoff).
            semaphore: Optional limiter, held only while a request is in flight.

        Returns:
            The API response (``{"metric": ..., "evaluation": {metric: score, ...}}``).
            After ``max_retries`` failed retries it returns a fallback response
            with the lowest score (1).
        """
        cache_key = None
        if self.use_cache:
            cache_key = self._create_cache_key(input_text, model_output, ground_truth, metric)
            if cache_key in self.cache:
                self.stats['cache_hits'] += 1
                return self.cache[cache_key]

        url = f"{self.api_base_url}/evaluate"
        params = {"metric": metric}  # aiohttp URL-encodes query parameters
        payload = {
            "input_text": input_text,
            "model_output": model_output,
            "ground_truth": ground_truth
        }

        attempt = retry_count
        while True:
            try:
                result = await self._post_once(session, url, params, payload, semaphore)
            except Exception as e:
                if attempt < self.api_config['max_retries']:
                    # Back off outside the semaphore so a sleeping retry does not hold a slot.
                    await asyncio.sleep(self.api_config['retry_delay'] * (2 ** attempt))
                    attempt += 1
                    continue
                self.stats['api_errors'] += 1
                print(f"⚠️  API call failed for metric '{metric}' after {attempt} retries: {str(e)}")
                # Return a default low score on failure. NOTE(review): this penalises a
                # generation for an infrastructure error; consider dropping the metric instead.
                return {
                    "metric": metric,
                    "evaluation": {
                        metric: 1,  # Low score (1 out of 5)
                        "justification": f"API call failed: {str(e)}"
                    }
                }

            if self.use_cache:
                self.cache[cache_key] = result
            return result

    @staticmethod
    def _extract_score(metric: str, result: Any) -> Optional[float]:
        """Pull ``result['evaluation'][metric]`` out as a float; None if absent or not numeric."""
        if not isinstance(result, dict):
            return None  # includes exceptions captured by gather(return_exceptions=True)
        evaluation = result.get('evaluation')
        if not isinstance(evaluation, dict):
            return None
        try:
            return float(evaluation.get(metric))
        except (TypeError, ValueError):
            return None

    async def _evaluate_batch_async(
        self,
        prompts: List[str],
        generated: List[str],
        references: List[str]
    ) -> List[Dict[str, float]]:
        """
        Evaluate a batch of generations using async API calls.

        All requests share one HTTP session and one semaphore, so the limit on
        concurrent requests applies to individual requests, not to generations.

        Returns:
            One ``{metric: score}`` dict per generation, in input order. Metrics
            without a usable score are left out.
        """
        triples = list(zip(prompts, generated, references))
        if not triples:
            return []

        semaphore = asyncio.Semaphore(self.api_config['max_concurrent_requests'])

        async with aiohttp.ClientSession() as session:
            async def evaluate_single(prompt, gen, ref):
                results = await asyncio.gather(
                    *(
                        self._call_api_single(session, metric, prompt, gen, ref, semaphore=semaphore)
                        for metric in self.metrics
                    ),
                    return_exceptions=True,
                )
                # Pair results with the metric we REQUESTED instead of trusting the
                # server to echo a "metric" field back.
                scores = {}
                for metric, result in zip(self.metrics, results):
                    score = self._extract_score(metric, result)
                    if score is not None:
                        scores[metric] = score
                return scores

            return list(await asyncio.gather(*(evaluate_single(p, g, r) for p, g, r in triples)))

    def evaluate_batch(
        self,
        prompts: List[str],
        generated: List[str],
        references: List[str]
    ) -> List[Dict[str, float]]:
        """
        Synchronous wrapper for batch evaluation.
        Returns list of metric scores for each generation.
        """
        coro = self._evaluate_batch_async(prompts, generated, references)
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # Plain Python: no loop is running, so start a fresh one.
            return asyncio.run(coro)
        # Inside Jupyter a loop is already running; re-entering it with
        # run_until_complete relies on nest_asyncio having been applied.
        return loop.run_until_complete(coro)

    def compute_reward(
        self,
        prompt: str,
        generated: str,
        reference: str,
        scores: Dict[str, float] = None
    ) -> float:
        """
        Compute weighted reward from metric scores.

        Args:
            prompt: Input prompt
            generated: Generated text
            reference: Ground truth reference
            scores: Pre-computed metric scores (if None, will call API)

        Returns:
            Weighted reward in [0, 1]. Each 1-5 score is mapped linearly to
            [0, 1] and clamped. Metrics that are missing, unweighted or
            non-numeric are ignored, and the weights are renormalised over the
            rest. Returns 0.0 when nothing is usable.
        """
        if scores is None:
            scores_list = self.evaluate_batch([prompt], [generated], [reference])
            scores = scores_list[0] if scores_list else {}

        total_reward = 0.0
        total_weight = 0.0

        for metric in self.metrics:
            if metric not in scores or metric not in self.weights:
                continue
            try:
                raw_score = float(scores[metric])
            except (TypeError, ValueError):
                continue  # unparseable judge output: treat as missing
            # Map 1->0 and 5->1, then clamp so off-scale judge output stays in [0, 1].
            normalized_score = min(max((raw_score - 1.0) / 4.0, 0.0), 1.0)
            total_reward += normalized_score * self.weights[metric]
            total_weight += self.weights[metric]

        return total_reward / total_weight if total_weight > 0 else 0.0

    def compute_batch_rewards(
        self,
        prompts: List[str],
        generated: List[str],
        references: List[str]
    ) -> Tuple[List[float], List[Dict[str, float]]]:
        """
        Compute rewards for a batch of generations.

        Returns:
            Tuple of (rewards, all_scores) where:
            - rewards: List of weighted reward values [0-1]
            - all_scores: List of dict with individual metric scores [1-5]
        """
        all_scores = self.evaluate_batch(prompts, generated, references)
        rewards = [self.compute_reward(None, None, None, scores=scores) for scores in all_scores]
        return rewards, all_scores

    def print_stats(self):
        """Print cache and API usage statistics."""
        lookups = self.stats['total_calls'] + self.stats['cache_hits']
        print(f"\n{'='*60}")
        print("Reward Model Statistics")
        print(f"{'='*60}")
        print(f"Total API calls: {self.stats['total_calls']}")
        print(f"Cache hits: {self.stats['cache_hits']}")
        if lookups > 0:
            # Also shown when every lookup was served from the cache.
            cache_rate = (self.stats['cache_hits'] / lookups) * 100
            print(f"Cache hit rate: {cache_rate:.1f}%")
        print(f"API errors: {self.stats['api_errors']}")
        if self.stats['total_calls'] > 0:
            avg_time = self.stats['total_api_time'] / self.stats['total_calls']
            print(f"Average API call time: {avg_time:.2f}s")
        print(f"{'='*60}\n")

# Initialize reward model. Priority mode scores only PRIORITY_METRICS;
# otherwise every metric in ALL_METRICS is requested.
if GRPO_CONFIG["use_priority_metrics"]:
    metrics_to_use = PRIORITY_METRICS
else:
    metrics_to_use = ALL_METRICS

reward_model = APIIntegratedRewardModel(
    api_base_url=API_BASE_URL,
    metrics=metrics_to_use,
    weights=REWARD_WEIGHTS,
    api_config=API_CONFIG,
    use_cache=API_CONFIG["use_cache"],
)

print(f"🎯 Using metrics: {metrics_to_use}")

✅ Reward Model initialized with 6 metrics
   Metrics: ['style_similarity', 'meaning_preservation', 'coherence', 'fluency', 'content_length', 'rephrase_accuracy']
   Caching: Enabled
🎯 Using metrics: ['style_similarity', 'meaning_preservation', 'coherence', 'fluency', 'content_length', 'rephrase_accuracy']


## 6. GRPO Training Implementation


In [14]:
class GRPOTrainer:
    """
    GRPO Trainer with API-integrated rewards and optimized evaluation frequency.
    """
    
    def __init__(self, model, tokenizer, reward_model, config, output_dir):
        self.model = model
        self.tokenizer = tokenizer
        self.reward_model = reward_model
        self.config = config
        self.output_dir = output_dir
        
        # Optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config["learning_rate"]
        )
        
        # Track metrics
        self.metrics = {
            "epoch": [],
            "step": [],
            "loss": [],
            "mean_reward": [],
            "max_reward": [],
            "detailed_scores": [],  # Store detailed metric scores
        }
        
        # Step counter for eval frequency
        self.global_step = 0
        self.eval_frequency = config.get("eval_frequency", 5)
        
        print(f"✅ GRPO Trainer initialized")
        print(f"   Eval frequency: Every {self.eval_frequency} steps")
    
    def generate_responses(self, prompts: List[str], num_generations: int) -> List[List[str]]:
        """
        Generate multiple responses for each prompt.
        Returns: List of lists, where each inner list contains multiple generations for one prompt.
        """
        self.model.eval()
        all_generations = []
        
        with torch.no_grad():
            for prompt in prompts:
                # Format prompt for the model
                formatted_prompt = f"<|user|>\\n{prompt}\\n<|assistant|>\\n"
                
                inputs = self.tokenizer(
                    formatted_prompt,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512
                ).to(self.model.device)
                
                # Generate multiple outputs
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=self.config["max_length"],
                    num_return_sequences=num_generations,
                    temperature=self.config["temperature"],
                    top_p=self.config["top_p"],
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
                
                # Decode generations
                generations = []
                for output in outputs:
                    decoded = self.tokenizer.decode(output, skip_special_tokens=True)
                    # Extract only the assistant's response
                    if "<|assistant|>" in decoded:
                        decoded = decoded.split("<|assistant|>")[-1].strip()
                    generations.append(decoded)
                
                all_generations.append(generations)
        
        return all_generations
    
    def compute_advantages(self, rewards: List[float]) -> List[float]:
        """
        Compute group-relative advantages.
        In GRPO, we normalize rewards within each group (multiple generations per prompt).
        """
        rewards = np.array(rewards)
        mean_reward = np.mean(rewards)
        std_reward = np.std(rewards) + 1e-8
        advantages = (rewards - mean_reward) / std_reward
        return advantages.tolist()
    
    def compute_loss(self, prompts: List[str], generations: List[str], advantages: List[float]):
        """
        Compute the GRPO loss.
        """
        self.model.train()
        total_loss = 0.0
        
        for prompt, generation, advantage in zip(prompts, generations, advantages):
            # Format input
            formatted_text = f"<|user|>\\n{prompt}\\n<|assistant|>\\n{generation}"
            
            inputs = self.tokenizer(
                formatted_text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=1024
            ).to(self.model.device)
            
            # Forward pass
            outputs = self.model(**inputs, labels=inputs["input_ids"])
            
            # Weight loss by advantage
            weighted_loss = outputs.loss * advantage
            total_loss += weighted_loss
        
        return total_loss / len(prompts)
    
    def train_step(self, batch_prompts: List[str], batch_references: List[str], evaluate_this_step: bool = False):
        """
        Single training step for GRPO.
        
        Args:
            batch_prompts: List of input prompts
            batch_references: List of reference outputs (ground truth)
            evaluate_this_step: Whether to call evaluation API this step
        """
        self.global_step += 1
        
        # Step 1: Generate multiple responses per prompt
        all_generations = self.generate_responses(
            batch_prompts,
            self.config["num_generations_per_prompt"]
        )
        
        # Flatten generations
        flat_prompts = []
        flat_generations = []
        flat_references = []
        
        for prompt, generations, reference in zip(batch_prompts, all_generations, batch_references):
            for generation in generations:
                flat_prompts.append(prompt)
                flat_generations.append(generation)
                flat_references.append(reference)
        
        # Step 2: Compute rewards 
        all_rewards = []
        all_scores = None
        
        if evaluate_this_step:
            # Call evaluation API (silently for progress bar)
            start_time = time.time()
            all_rewards, all_scores = self.reward_model.compute_batch_rewards(
                flat_prompts, flat_generations, flat_references
            )
        else:
            # Use simple heuristic rewards (no API call)
            for gen, ref in zip(flat_generations, flat_references):
                # Simple length-based reward as proxy
                len_ratio = len(gen.split()) / max(len(ref.split()), 1)
                reward = 1.0 - abs(1.0 - len_ratio)
                reward = max(0, min(1, reward)) * 0.5  # Scale down heuristic rewards
                all_rewards.append(reward)
        
        # Step 3: Compute group-relative advantages
        advantages = self.compute_advantages(all_rewards)
        
        # Step 4: Update model using advantages
        self.optimizer.zero_grad()
        loss = self.compute_loss(flat_prompts, flat_generations, advantages)
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        
        self.optimizer.step()
        
        return {
            "loss": loss.item(),
            "mean_reward": np.mean(all_rewards),
            "max_reward": np.max(all_rewards),
            "min_reward": np.min(all_rewards),
            "detailed_scores": all_scores,
            "evaluated": evaluate_this_step
        }
    
    def train(self, dataset, num_epochs):
        """
        Main training loop with optimized evaluation frequency and progress bars.

        Walks ``dataset`` in slices of ``self.config["batch_size"]`` rows for
        ``num_epochs`` epochs, calling ``self.train_step`` on every slice. Only
        steps where ``self.global_step % self.eval_frequency == 0`` request the
        full (API-backed) evaluation; all other steps are "fast" steps.

        After each epoch the mean loss, the mean reward and the mean of the
        per-step max reward are appended to ``self.metrics``, and a checkpoint
        is saved to ``{self.output_dir}/checkpoint-epoch-{epoch}``.

        Args:
            dataset: Sliceable dataset; ``dataset[i:j]`` must return a mapping
                with ``"prompt"`` and ``"reference"`` lists (a Hugging Face
                ``Dataset`` behaves this way — presumably what is passed here).
            num_epochs: Number of passes over ``dataset``.
        """
        batch_size = self.config["batch_size"]

        def _mean_or_nan(values):
            # np.mean([]) emits a "Mean of empty slice" warning; return nan quietly
            # instead (e.g. for an empty dataset) and keep it a plain Python float.
            return float(np.mean(values)) if values else float("nan")

        print(f"\n{'='*80}")
        print("🚀 Starting GRPO Training")
        print(f"{'='*80}")
        print(f"📊 Epochs: {num_epochs}")
        print(f"📦 Dataset size: {len(dataset)}")
        print(f"🔢 Batch size: {batch_size}")
        print(f"⏱️  Eval frequency: Every {self.eval_frequency} steps")
        print(f"🎲 Generations per prompt: {self.config['num_generations_per_prompt']}")
        print(f"{'='*80}\n")

        # Outer progress bar over epochs
        epoch_pbar = tqdm(range(num_epochs), desc="📚 Epochs", position=0, leave=True)

        # Number of (possibly partial) batches per epoch; same for every epoch.
        num_batches = (len(dataset) + batch_size - 1) // batch_size

        for epoch in epoch_pbar:
            epoch_pbar.set_description(f"📚 Epoch {epoch + 1}/{num_epochs}")

            epoch_metrics = {
                "loss": [],
                "mean_reward": [],
                "max_reward": [],
            }

            # Inner progress bar over the batches of this epoch
            step_pbar = tqdm(
                range(0, len(dataset), batch_size),
                desc="  🔄 Steps",
                position=1,
                leave=False,
                total=num_batches,
            )

            for i in step_pbar:
                batch = dataset[i:i + batch_size]
                batch_prompts = batch["prompt"]
                batch_references = batch["reference"]

                # Only every eval_frequency-th step pays for API-backed rewards.
                step_index = self.global_step
                evaluate_this_step = (step_index % self.eval_frequency == 0)

                # Training step (without verbose prints)
                metrics = self.train_step(batch_prompts, batch_references, evaluate_this_step)

                # Advance the step counter unless train_step already did. Without
                # this, global_step never changes and the schedule above flags
                # either every step or no step for evaluation.
                if self.global_step == step_index:
                    self.global_step += 1

                # Track metrics
                for key in epoch_metrics:
                    if key in metrics:
                        epoch_metrics[key].append(metrics[key])

                # Update progress bar with metrics
                eval_marker = "📊 EVAL" if evaluate_this_step else "⚡ FAST"
                step_pbar.set_postfix({
                    'type': eval_marker,
                    'loss': f"{metrics['loss']:.4f}",
                    'reward': f"{metrics['mean_reward']:.4f}",
                    'max': f"{metrics['max_reward']:.4f}",
                    'global': self.global_step
                })

                # Keep per-sample scores, tagged with the step that produced them.
                if metrics.get('detailed_scores'):
                    self.metrics['detailed_scores'].append({
                        'step': step_index,
                        'scores': metrics['detailed_scores']
                    })

            step_pbar.close()

            # Epoch summary
            epoch_loss = _mean_or_nan(epoch_metrics["loss"])
            epoch_mean_reward = _mean_or_nan(epoch_metrics["mean_reward"])
            # Average over steps of each step's max reward (not the epoch-wide max).
            epoch_max_reward = _mean_or_nan(epoch_metrics["max_reward"])

            epoch_pbar.set_postfix({
                'loss': f"{epoch_loss:.4f}",
                'reward': f"{epoch_mean_reward:.4f}",
                'max': f"{epoch_max_reward:.4f}"
            })

            # Save metrics
            self.metrics["epoch"].append(epoch + 1)
            self.metrics["loss"].append(epoch_loss)
            self.metrics["mean_reward"].append(epoch_mean_reward)
            self.metrics["max_reward"].append(epoch_max_reward)

            # Save checkpoint
            checkpoint_dir = f"{self.output_dir}/checkpoint-epoch-{epoch + 1}"
            self.save_checkpoint(checkpoint_dir)

            print(f"\n📊 Epoch {epoch + 1} Summary: Loss={epoch_loss:.4f}, Reward={epoch_mean_reward:.4f}, Max={epoch_max_reward:.4f}")

        epoch_pbar.close()

        print(f"\n{'='*80}")
        print("✅ Training Complete!")
        print(f"{'='*80}\n")

        # Final stats
        self.reward_model.print_stats()
    
    def save_checkpoint(self, checkpoint_dir):
        """
        Save the model, tokenizer and training metrics to ``checkpoint_dir``.

        The directory is created if needed. ``self.model`` and ``self.tokenizer``
        are written with their ``save_pretrained`` methods, and ``self.metrics``
        is dumped as JSON to ``metrics.json`` inside the same directory.

        Args:
            checkpoint_dir: Destination directory path.
        """
        def _to_jsonable(obj):
            # json cannot serialize numpy/torch scalars (e.g. np.float32) or arrays.
            # Convert them to plain Python values so the metrics dump does not fail
            # after the model was already written, leaving a partial checkpoint.
            if hasattr(obj, "tolist"):
                return obj.tolist()
            if hasattr(obj, "item"):
                return obj.item()
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        os.makedirs(checkpoint_dir, exist_ok=True)
        self.model.save_pretrained(checkpoint_dir)
        self.tokenizer.save_pretrained(checkpoint_dir)

        metrics_path = os.path.join(checkpoint_dir, "metrics.json")
        with open(metrics_path, "w", encoding="utf-8") as f:
            json.dump(self.metrics, f, indent=2, default=_to_jsonable)

# Confirmation printed once this cell (the trainer class definition) has executed.
print("✅ Optimized GRPO Trainer class defined with progress bars")


✅ Optimized GRPO Trainer class defined with progress bars


## API Cost Optimization Summary

**With Evaluation Every N Steps:**
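- ⚡ **Non-evaluation steps**: Use a cheap length-ratio heuristic instead of the API. These rewards are clipped to [0, 1] and then halved, so they lie in [0, 0.5] and are not directly comparable to API-scored steps.
- 📊 **Evaluation steps** (every `EVAL_FREQUENCY` steps — every 5th step in this run): Call the full API with the priority metrics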

**Example Cost Savings:**
```
Original: Every step evaluation
- 100 training steps × 16 generations × 2 metrics = 3,200 API calls

Optimized: Every 5th step evaluation  
- 20 evaluation steps × 16 generations × 2 metrics = 640 API calls
- Savings: 80% reduction in API calls! 🎉
```

**Caching Benefits:**
- Repeated (prompt, output, reference) combinations are cached
- No duplicate API calls for identical inputs
- Automatic cache hit tracking

**You can adjust:**
- `EVAL_FREQUENCY` - Higher = fewer API calls, Lower = more frequent evaluation
- `PRIORITY_METRICS` vs `ALL_METRICS` - Priority uses 2 metrics, All uses 6


In [10]:
# Note: The optimized GRPOTrainer class is defined in an earlier cell (labelled
# "Cell 14" here; that label may not match the current execution counts).
# This cell defines nothing - it only prints a confirmation. The trainer
# class from that earlier cell is the one used below.

print("✅ Using optimized GRPO Trainer from Cell 14")


✅ Using optimized GRPO Trainer from Cell 14


## 7. Initialize and Run Training


In [11]:
# Build the GRPO trainer from the model, tokenizer and reward model prepared above.
# Hyper-parameters come from GRPO_CONFIG; checkpoints are written under OUTPUT_DIR.
trainer = GRPOTrainer(model=model, tokenizer=tokenizer, reward_model=reward_model,
                      config=GRPO_CONFIG, output_dir=OUTPUT_DIR)
print("Trainer initialized")


✅ GRPO Trainer initialized
   Eval frequency: Every 5 steps
Trainer initialized


In [None]:
# Run GRPO over train_dataset; the epoch count lives in GRPO_CONFIG so it is set in one place.
trainer.train(dataset=train_dataset, num_epochs=GRPO_CONFIG["num_train_epochs"])

🚀 Starting GRPO Training
📊 Epochs: 3
📦 Dataset size: 9646
🔢 Batch size: 4
⏱️  Eval frequency: Every 5 steps
🎲 Generations per prompt: 2


📚 Epochs:   0%|          | 0/3 [00:00<?, ?it/s]

  🔄 Steps:   0%|          | 0/2412 [00:00<?, ?it/s]

Unsloth: Will smartly offload gradients to save VRAM!


## 8. Visualize Training Metrics


In [None]:
# IMPORTANT: Set your Hugging Face token as an environment variable
# Never commit tokens to code!
# export HF_TOKEN=your_hugging_face_token_here

# `login` must be imported explicitly; calling it without this import raised a NameError.
from huggingface_hub import HfApi, login
import os

hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("❌ HF_TOKEN environment variable not set!")
    print("Please set it with: export HF_TOKEN=your_token_here")
    print("Get your token from: https://huggingface.co/settings/tokens")
else:
    login(token=hf_token)
    print("✅ Successfully logged in to Hugging Face")


In [None]:
import matplotlib.pyplot as plt

# Per-epoch training curves: loss on the left panel, rewards on the right panel.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

axes[0].plot(trainer.metrics["epoch"], trainer.metrics["loss"], marker='o')
axes[0].set(xlabel='Epoch', ylabel='Loss', title='Training Loss')
axes[0].grid(True)

# "Max Reward" is the epoch average of each step's best reward.
axes[1].plot(trainer.metrics["epoch"], trainer.metrics["mean_reward"], marker='o', label='Mean Reward')
axes[1].plot(trainer.metrics["epoch"], trainer.metrics["max_reward"], marker='s', label='Max Reward')
axes[1].set(xlabel='Epoch', ylabel='Reward', title='Training Rewards')
axes[1].legend()
axes[1].grid(True)

fig.tight_layout()
fig.savefig(f"{OUTPUT_DIR}/training_metrics.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"Metrics plot saved to {OUTPUT_DIR}/training_metrics.png")


## 9. Test the Trained Model


In [None]:
def test_model(model, tokenizer, prompt: str, max_length: int = 512):
    """
    Test the trained model on a single prompt.
    """
    model.eval()

    # Format prompt
    formatted_prompt = f"<|user|>\\n{prompt}\\n<|assistant|>\\n"

    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract assistant's response
    if "<|assistant|>" in generated_text:
        generated_text = generated_text.split("<|assistant|>")[-1].strip()

    return generated_text


In [None]:
# Run a few hand-written prompts through the trained model to eyeball its style.
test_prompts = [
    "Please rewrite the following text in my personal style: The sun was setting over the horizon.",
    "Transform this sentence to match my writing style: Technology has changed our lives.",
    "Rewrite in my style: The meeting was very productive and we made good progress."
]

# Use real newline escapes ("\n"); "\\n" would print a literal backslash-n.
print("Testing the trained model:\n")
for i, prompt in enumerate(test_prompts, 1):
    print(f"{'='*80}")
    print(f"Test {i}")
    print(f"{'='*80}")
    print(f"Prompt: {prompt}")
    print("\nGenerated Output:")
    output = test_model(model, tokenizer, prompt)
    print(output)
    print()


## 10. Save Final Model


In [None]:
# Persist the final weights and tokenizer side by side in one directory under
# OUTPUT_DIR, so both can be reloaded from the same path later.
final_model_dir = f"{OUTPUT_DIR}/final_model"
os.makedirs(final_model_dir, exist_ok=True)  # harmless if it already exists
model.save_pretrained(final_model_dir)
tokenizer.save_pretrained(final_model_dir)
print(f"Final model saved to {final_model_dir}")


## 12. Comprehensive Final Evaluation (All Metrics)

After training with priority metrics, evaluate the model using **all 6 metrics** for a complete assessment:


In [None]:
# Create a comprehensive evaluation reward model with all metrics
comprehensive_reward_model = APIIntegratedRewardModel(
    api_base_url=API_BASE_URL,
    metrics=ALL_METRICS,  # Use all 6 metrics
    weights=REWARD_WEIGHTS,
    api_config=API_CONFIG,
    use_cache=True
)

# Test on a few examples with comprehensive evaluation.
# TODO: replace the placeholder "reference" strings with real rewrites in the
# target style. Scores computed against the placeholder are not meaningful.
test_examples = [
    {
        "prompt": "Please rewrite the following text in my personal style: The sun was setting over the horizon.",
        "reference": "Your expected reference output here"
    },
    {
        "prompt": "Transform this sentence to match my writing style: Technology has changed our lives.",
        "reference": "Your expected reference output here"
    }
]

print(f"\n{'='*80}")
print("Comprehensive Final Evaluation (All 6 Metrics)")
print(f"{'='*80}\n")

for i, example in enumerate(test_examples, 1):
    print(f"\n{'─'*80}")
    print(f"Example {i}")
    print(f"{'─'*80}")
    print(f"Prompt: {example['prompt']}")

    # Generate output
    output = test_model(model, tokenizer, example['prompt'])
    print(f"\nGenerated Output:\n{output}")

    # Evaluate with all metrics
    print(f"\n{'─'*40}")
    print("Evaluation Scores:")
    print(f"{'─'*40}")

    rewards, scores = comprehensive_reward_model.compute_batch_rewards(
        [example['prompt']],
        [output],
        [example['reference']]
    )

    # Tolerate a missing or empty score dict for this example instead of failing
    # on the membership test. Presumably this can happen when an API call fails;
    # confirm against APIIntegratedRewardModel.
    example_scores = (scores[0] if scores else None) or {}
    for metric in ALL_METRICS:
        if metric in example_scores:
            print(f"  {metric:.<30} {example_scores[metric]}/5")

    print(f"\n  {'Overall Weighted Reward':.<30} {rewards[0]:.4f}/1.0")
    print(f"{'─'*80}\n")

# Print final API statistics
comprehensive_reward_model.print_stats()


## 11. Load and Use Trained Model (Optional)


In [None]:
# Load the trained model later with Unsloth
def load_trained_model(adapter_path: str, max_seq_length: int = 2048,
                       load_in_4bit: bool = True, dtype=None):
    """
    Load the trained model with LoRA adapters using Unsloth.

    Args:
        adapter_path: Path to the saved checkpoint/adapter directory (e.g. a
            ``final_model`` or ``checkpoint-epoch-N`` directory).
        max_seq_length: Maximum sequence length passed to Unsloth.
        load_in_4bit: Load the weights 4-bit quantized. Defaults to True, the
            previously hard-coded behaviour; pass False to load unquantized.
        dtype: Compute dtype forwarded to ``FastLanguageModel.from_pretrained``.
            ``None`` is the previously hard-coded value; Unsloth then picks a
            dtype automatically.

    Returns:
        Tuple of (model, tokenizer), with the model switched to inference mode.
    """
    print(f"Loading trained model from: {adapter_path}")

    # Unsloth can load the model with adapters directly
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=adapter_path,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
    )

    # Enable inference mode for faster generation
    FastLanguageModel.for_inference(model)

    print("✅ Trained model loaded successfully!")
    return model, tokenizer

# Example usage:
# trained_model, trained_tokenizer = load_trained_model(final_model_dir)
# or load from checkpoint:
# trained_model, trained_tokenizer = load_trained_model("./grpo_style_transfer_model/checkpoint-epoch-3")

print("✅ Model loading function defined (Unsloth optimized)")
