# Finetuning Llama-2 1B with GRPO for Reasoning

This notebook demonstrates how to finetune Llama-2 1B using GRPO (Group Relative Policy Optimization) to enhance its reasoning capabilities.

## Credits and References

This implementation draws from several key sources:

1. **GRPO Implementation**:
   - [theLMbook's GRPO Implementation](https://github.com/aburkov/theLMbook/blob/main/GRPO_Qwen_0_5_Instruct.ipynb)
   - [Unsloth's GRPO Documentation](https://docs.unsloth.ai/basics/reasoning-grpo-and-rl)
   - [DeepSeek-R1 Paper](https://thelmbook.com/articles/#!./DeepSeek-R1.md)

2. **Model Architecture**:
   - [Llama 2 Paper](https://arxiv.org/abs/2307.09288)
   - [Llama 2 Official Repository](https://github.com/facebookresearch/llama)
   - [Parameter-Efficient Fine-tuning Guide](https://huggingface.co/docs/peft/index)

3. **Training Datasets**:
   - [facebook/natural_reasoning](https://huggingface.co/datasets/facebook/natural_reasoning)
   - [open-thoughts/OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k)
   - [SkunkworksAI/reasoning-0.01](https://huggingface.co/datasets/SkunkworksAI/reasoning-0.01)

4. **Additional Resources**:
   - [PPO for Language Models](https://arxiv.org/abs/2109.10862)
   - [TRL Library Documentation](https://huggingface.co/docs/trl/index)
   - [LoRA: Low-Rank Adaptation Paper](https://arxiv.org/abs/2106.09685)

## Overview

```mermaid
graph TD
    A[Llama-2 1B Base Model] --> B[4-bit Quantization]
    B --> C[LoRA Adaptation]
    C --> D[Dataset Preparation]
    D --> E[GRPO Training]
    E --> F[Reward Model]
    F --> G[Model Updates]
    G --> E
    G --> H[Final Model]
```

## Setup Requirements

1. **Environment Setup**
   ```bash
   pip install -q transformers accelerate bitsandbytes datasets torch peft "trl<0.12" wandb
   ```

2. **HuggingFace Authentication**
   ```python
   from huggingface_hub import login
   login()  # Enter your token when prompted
   ```

3. **Hardware Requirements**
   - GPU: NVIDIA GPU with 8GB+ VRAM
   - RAM: 16GB+ recommended
   - Storage: 5GB+ free space

## Training Process
![GRPO Training Process](https://raw.githubusercontent.com/unslothai/unsloth/main/docs/images/grpo.png)

The training process involves:
1. Loading and preprocessing reasoning datasets
2. Applying 4-bit quantization and LoRA
3. Training with GRPO and a rule-based reward function
4. Evaluating reasoning capabilities
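
What distinguishes GRPO from PPO is how step 3 turns rewards into a learning signal: several completions are sampled for the same prompt, and each reward is normalized against the mean and standard deviation of its own group, so no separate value model is needed. A minimal sketch of that normalization follows; the function name and the `eps` constant are illustrative choices, not something defined elsewhere in this notebook.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize the rewards of completions sampled for ONE prompt.

    Each completion's advantage is its reward minus the group mean,
    divided by the group standard deviation. Population std is used
    here; implementations differ on the exact estimator.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    # eps keeps the division finite when all rewards in the group are equal
    return [(r - mu) / (sigma + eps) for r in rewards]


# Hypothetical rewards for four completions of the same prompt
group_rewards = [0.9, 0.5, 0.5, 0.1]
advantages = group_relative_advantages(group_rewards)
print([round(a, 2) for a in advantages])
```

Completions that score above their group average get a positive advantage and are reinforced; those below average are pushed down. When every completion in a group receives the same reward, all advantages are zero and that prompt contributes no gradient.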

In [None]:
# The training cells below use TRL's legacy PPOTrainer API, which was removed in trl 0.12
!pip install -q transformers accelerate bitsandbytes datasets torch peft "trl<0.12" wandb

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
import wandb

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Load Model and Tokenizer

We'll use Llama-2 1B as our base model, with 4-bit quantization for efficiency.

In [None]:
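# Llama 2 was released in 7B, 13B, and 70B sizes; confirm this repo ID exists
# on the Hugging Face Hub (and that your account has access) before running
model_name = "meta-llama/Llama-2-1b-hf"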

# Configure 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
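tokenizer = AutoTokenizer.from_pretrained(model_name)
# Llama tokenizers ship without a pad token; reuse EOS so padded batches work
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token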
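
To see why 4-bit loading matters for the hardware budget, a back-of-the-envelope estimate of weight memory helps. The sketch below assumes a round 1 billion parameters and ignores quantization constants, activations, gradients, and optimizer state, so the figures are rough lower bounds rather than measurements.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Memory needed to store the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3


NUM_PARAMS = 1_000_000_000  # assumed round figure for a "1B" model

for label, bits in [("fp16", 16), ("nf4 (4-bit)", 4)]:
    print(f"{label}: {weight_memory_gib(NUM_PARAMS, bits):.2f} GiB")
```

Storing weights in 4 bits cuts their footprint to a quarter of fp16, which leaves more of the GPU for activations and the sampled completions generated during training.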

## Configure LoRA for Efficient Training

LoRA keeps the quantized base weights frozen and trains small low-rank update matrices on the attention projections, following the approach from the LoRA paper.

In [None]:
lora_config = LoraConfig(
    r=8,  # Rank of update matrices
    lora_alpha=16,  # Alpha scaling factor
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Attention modules to adapt
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

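# Prepare the 4-bit model for training before attaching the adapters
from peft import prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # Report how many parameters LoRA actually trains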
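
LoRA replaces the update to a weight matrix of shape `d_out x d_in` with two factors of rank `r`, so each adapted module adds `r * (d_in + d_out)` trainable parameters. The sketch below counts them for the configuration above. The hidden size of 2048 and the 16 layers are placeholder values for illustration, not read from the model config, and the calculation assumes all four projections are square (which does not hold for models using grouped-query attention).

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters added by one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


HIDDEN_SIZE = 2048  # hypothetical, for illustration only
NUM_LAYERS = 16     # hypothetical, for illustration only
RANK = 8            # r from lora_config
LORA_ALPHA = 16     # lora_alpha from lora_config
TARGETS = ["q_proj", "k_proj", "v_proj", "o_proj"]

per_module = lora_param_count(HIDDEN_SIZE, HIDDEN_SIZE, RANK)
total = per_module * len(TARGETS) * NUM_LAYERS
scaling = LORA_ALPHA / RANK  # factor applied to the low-rank update

print(f"{per_module:,} parameters per adapted module")
print(f"{total:,} trainable LoRA parameters in total")
print(f"update scaling factor: {scaling}")
```

On the real model, `model.print_trainable_parameters()` from PEFT reports the exact figure.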

## Load and Prepare Training Data

We'll combine multiple reasoning datasets to create a diverse training set.

In [None]:
# Load and prepare reasoning datasets
from datasets import load_dataset, concatenate_datasets

print("Loading datasets...")

# Load multiple reasoning datasets
datasets = {
    "natural_reasoning": load_dataset("facebook/natural_reasoning", split="train"),
    "openthoughts": load_dataset("open-thoughts/OpenThoughts-114k", split="train"),
    "skunkworks": load_dataset("SkunkworksAI/reasoning-0.01", split="train")
}

print("\nDataset sizes:")
for name, dataset in datasets.items():
    print(f"{name}: {len(dataset):,} examples")

def process_natural_reasoning(example):
    """Process facebook/natural_reasoning dataset.
    
    Format:
    - Input: Question that requires reasoning
    - Response: Step-by-step rationale
    - Feedback: Quality assessment based on logical structure
    """
    return {
        "instruction": example["question"],
        "response": f"Let me solve this step by step:\n{example['rationale']}\n\nTherefore, {example['answer']}",
        "feedback": "Good reasoning with clear logical steps" if len(example["rationale"].split()) > 20 else "Needs more detailed explanation"
    }

def process_openthoughts(example):
    """Process OpenThoughts dataset.
    
    Format:
    - Input: Open-ended prompt
    - Response: Thought process and conclusion
    - Feedback: Based on reasoning depth
    """
    return {
        "instruction": example["prompt"],
        "response": f"Let me think through this:\n{example['thought_process']}\n\nConclusion: {example['response']}",
        "feedback": example.get("feedback", "Clear thought process with logical progression")
    }

def process_skunkworks(example):
    """Process SkunkworksAI reasoning dataset.
    
    Format:
    - Input: Reasoning task
    - Response: Structured solution
    - Feedback: Based on step-by-step approach
    """
    return {
        "instruction": example["instruction"],
        "response": example["output"],
        "feedback": "Excellent step-by-step reasoning" if "step" in example["output"].lower() else "Could use more explicit steps"
    }

print("\nProcessing datasets...")

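# Process datasets, dropping the source columns so that all three share one schema
# (concatenate_datasets requires identical features across datasets)
processed_datasets = {}
for name, dataset in datasets.items():
    print(f"Processing {name}...")
    if name == "natural_reasoning":
        processed_datasets[name] = dataset.map(process_natural_reasoning, remove_columns=dataset.column_names)
    elif name == "openthoughts":
        processed_datasets[name] = dataset.map(process_openthoughts, remove_columns=dataset.column_names)
    else:
        processed_datasets[name] = dataset.map(process_skunkworks, remove_columns=dataset.column_names)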

# Randomly sample up to a fixed number of examples from each dataset, then combine
sample_sizes = {
    "natural_reasoning": 50000,
    "openthoughts": 30000,
    "skunkworks": 20000
}

combined_dataset = concatenate_datasets([
    processed_datasets[name].shuffle(seed=42).select(range(min(size, len(processed_datasets[name]))))
    for name, size in sample_sizes.items()
])

# Shuffle the combined dataset
combined_dataset = combined_dataset.shuffle(seed=42)

def format_prompt(example):
    """Format example for Llama-2 instruction format.
    
    Structure:
    1. System prompt for reasoning task
    2. User instruction
    3. Assistant response with reasoning
    4. Feedback for grounding
    """
    return f"[INST] <<SYS>>\nYou are a helpful AI assistant that provides clear, step-by-step reasoning for questions.\n<</SYS>>\n\n{example['instruction']} [/INST]\n{example['response']}\n\nFeedback: {example['feedback']}"

# Show example
print("\nExample formatted prompt:")
print("-" * 80)
print(format_prompt(combined_dataset[0]))
print("-" * 80)

print(f"\nFinal dataset size: {len(combined_dataset):,} examples")
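
The processing functions above assume particular column names in each dataset (for example `rationale` or `thought_process`). Dataset schemas on the Hub can differ from what a notebook expects, so a quick check against `dataset.column_names` before mapping catches mismatches early. The helper below is a hypothetical addition; the required columns are simply the keys the processing functions read, and the sample column list is made up for demonstration.

```python
# Keys read by each process_* function in this notebook
REQUIRED_COLUMNS = {
    "natural_reasoning": {"question", "rationale", "answer"},
    "openthoughts": {"prompt", "thought_process", "response"},
    "skunkworks": {"instruction", "output"},
}


def missing_columns(name, column_names):
    """Return the required columns that are absent, sorted for stable output."""
    return sorted(REQUIRED_COLUMNS[name] - set(column_names))


# Made-up column list; with real data pass datasets[name].column_names instead
example_missing = missing_columns("skunkworks", ["instruction", "reasoning"])
print(example_missing)
```

Running the check in a loop over `datasets.items()` right after loading, and stopping if any list is non-empty, avoids a `KeyError` partway through a long `map` call.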

## GRPO Training Setup

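This cell configures the trainer with TRL's `PPOTrainer`. Note that this is PPO rather than GRPO proper: GRPO, as implemented in theLMbook's and Unsloth's notebooks, drops the value model and scores each completion relative to a group of completions sampled for the same prompt.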

In [None]:
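from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

# This cell uses TRL's legacy PPO API (trl < 0.12). PPOTrainer requires a policy
# with a value head, so wrap the LoRA model before handing it to the trainer.
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

ppo_config = PPOConfig(
    learning_rate=1e-5,
    batch_size=4,
    mini_batch_size=1,
    gradient_accumulation_steps=4,
    optimize_cuda_cache=True
)

# Initialize PPO trainer (with a PEFT model and no ref_model, the reference
# policy is obtained by disabling the adapters)
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    tokenizer=tokenizer,
    dataset=combined_dataset
)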
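
For reference, the per-token objective usually associated with GRPO combines a PPO-style clipped probability ratio with a KL penalty toward a frozen reference policy. The sketch below writes it out in plain Python for a single token. The function name and the default values of `clip_eps` and `beta` are illustrative assumptions, and a real implementation would operate on batched tensors and average over tokens and completions.

```python
import math


def grpo_token_loss(logp_new, logp_old, logp_ref, advantage, clip_eps=0.2, beta=0.04):
    """Loss for one token given its log-probabilities under three policies.

    logp_new: current policy, logp_old: policy that sampled the completion,
    logp_ref: frozen reference policy, advantage: group-relative advantage.
    """
    # Clipped surrogate, as in PPO
    ratio = math.exp(logp_new - logp_old)
    clipped_ratio = max(1 - clip_eps, min(ratio, 1 + clip_eps))
    surrogate = min(ratio * advantage, clipped_ratio * advantage)

    # Non-negative KL estimate between the current and reference policies
    log_ref_ratio = logp_ref - logp_new
    kl = math.exp(log_ref_ratio) - log_ref_ratio - 1

    # Maximize surrogate minus KL penalty, i.e. minimize the negative
    return -(surrogate - beta * kl)


# Identical policies: ratio is 1 and the KL term vanishes
print(grpo_token_loss(-1.0, -1.0, -1.0, advantage=0.5))
```

The clipping bounds how far a single update can move the policy away from the one that generated the samples, while `beta` controls how strongly the policy is held near the reference model.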

## Training Loop with Reward Model

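The reward here is not a learned model: `compute_reasoning_quality` scores each generated response with keyword and length heuristics, and the weighted sum (between 0 and 1) is passed to the trainer as the reward.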

In [None]:
def compute_reasoning_quality(response):
    """Evaluate the quality of reasoning in a model response.
    
    Based on metrics from:
    - theLMbook's GRPO implementation
    - Unsloth's reasoning evaluation
    - DeepSeek-R1 paper
    
    Metrics:
    1. Step-by-step explanation (0.3)
    2. Logical flow (0.2)
    3. Depth of explanation (0.2)
    4. Conclusion clarity (0.2)
    5. Conciseness (0.1)
    """
    metrics = {
        "steps": any(f"{i}." in response for i in range(1, 10)),
        "logical_flow": any(word in response.lower() for word in ["because", "therefore", "since", "as a result"]),
        "depth": len(response.split()) >= 50,
        "conclusion": any(word in response.lower() for word in ["in conclusion", "therefore", "thus", "finally"]),
        "concise": len(response.split()) <= 200
    }
    
    score = (
        0.3 * int(metrics["steps"]) +
        0.2 * int(metrics["logical_flow"]) +
        0.2 * int(metrics["depth"]) +
        0.2 * int(metrics["conclusion"]) +
        0.1 * int(metrics["concise"])
    )
    
    return score, metrics

# Initialize wandb
wandb.init(project="llama-grpo-finetuning", name="reasoning-enhancement")

# Training loop
num_epochs = 3
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    
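    batch_size = ppo_config.batch_size
    # Iterate over the dataset directly so the "instruction" column stays available.
    # step() expects exactly `batch_size` samples, so a trailing partial batch is skipped.
    for batch_idx, start in enumerate(range(0, len(combined_dataset) - batch_size + 1, batch_size)):
        batch = combined_dataset[start:start + batch_size]

        # Tokenize each prompt separately: the legacy PPOTrainer API takes lists of 1-D tensors
        query_tensors = [
            tokenizer(q, return_tensors="pt").input_ids.squeeze(0).to(device)
            for q in batch["instruction"]
        ]

        # Generate responses (return_prompt=False keeps only the newly generated tokens)
        response = ppo_trainer.generate(
            query_tensors,
            return_prompt=False,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=tokenizer.eos_token_id
        )

        # Decode responses
        response_texts = [tokenizer.decode(r, skip_special_tokens=True) for r in response]

        # Compute rewards
        rewards = []
        metrics_list = []
        for r in response_texts:
            score, metrics = compute_reasoning_quality(r)
            rewards.append(score)
            metrics_list.append(metrics)

        # step() takes one scalar tensor per sample; keep a stacked copy for logging
        reward_tensors = [torch.tensor(r, device=device) for r in rewards]
        rewards = torch.tensor(rewards).to(device)

        # PPO step
        stats = ppo_trainer.step(query_tensors, response, reward_tensors)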
        
        # Log metrics
        if batch_idx % 10 == 0:
            avg_metrics = {
                "steps_ratio": sum(m["steps"] for m in metrics_list) / len(metrics_list),
                "logical_flow_ratio": sum(m["logical_flow"] for m in metrics_list) / len(metrics_list),
                "depth_ratio": sum(m["depth"] for m in metrics_list) / len(metrics_list),
                "conclusion_ratio": sum(m["conclusion"] for m in metrics_list) / len(metrics_list),
                "concise_ratio": sum(m["concise"] for m in metrics_list) / len(metrics_list)
            }
            
            wandb.log({
                "epoch": epoch,
                "batch": batch_idx,
                "mean_reward": rewards.mean().item(),
                **avg_metrics,
                **stats
            })
            
            print(f"Batch {batch_idx}: Mean reward = {rewards.mean():.3f}")

wandb.finish()
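
Recent TRL releases also ship a dedicated `GRPOTrainer`, whose custom reward functions receive a list of completions and return one float per completion. As a sketch of how a heuristic like the one above could be adapted to that shape, the example below uses a simplified two-signal scorer as a stand-in for `compute_reasoning_quality`. Both function names are hypothetical, and wiring the function into a trainer is left out because it depends on the installed TRL version.

```python
def simple_reasoning_score(text):
    """Simplified stand-in scorer: one connective signal, one length signal."""
    lowered = text.lower()
    has_connective = any(word in lowered for word in ("because", "therefore", "since"))
    long_enough = len(text.split()) >= 50
    return 0.5 * has_connective + 0.5 * long_enough


def reasoning_reward(completions, **kwargs):
    """Return one float reward per completion string.

    Extra keyword arguments (such as dataset columns a trainer may pass along)
    are accepted and ignored.
    """
    return [float(simple_reasoning_score(c)) for c in completions]


rewards_demo = reasoning_reward(["It floats because ice is less dense.", "No idea."])
print(rewards_demo)
```

Because the rewards are computed per completion and compared within each prompt's group, a heuristic only produces a training signal when sampled completions for the same prompt actually differ in score.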

## Model Evaluation

We'll evaluate the model on various reasoning tasks, following evaluation approaches from:
- theLMbook's evaluation metrics
- Unsloth's reasoning assessment
- DeepSeek-R1 paper benchmarks

In [None]:
# Test cases for different reasoning types
test_cases = {
    "scientific": [
        "Explain why ice floats on water using molecular principles.",
        "How does the greenhouse effect work? Explain the process."
    ],
    "mathematical": [
        "If a rectangle has length 8 and width 6, what is its area and perimeter? Show your work.",
        "Solve: 3x - 7 = 14. Explain each step of your solution."
    ],
    "logical": [
        "All mammals are warm-blooded. Dolphins are mammals. What can we conclude about dolphins?",
        "If it's sunny, Alice goes for a walk. Alice didn't go for a walk today. What can we conclude?"
    ],
    "causal": [
        "Why do leaves change color in autumn? Explain the causal chain.",
        "How does lack of sleep affect cognitive performance? Describe the mechanisms."
    ]
}

# Run evaluation
print("=== Model Evaluation Results ===\n")

all_metrics = []
for category, prompts in test_cases.items():
    print(f"\n{category.upper()} REASONING TASKS:\n")
    
    for prompt in prompts:
        print(f"Prompt: {prompt}")
        
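        # Generate response (do_sample=True is required for temperature/top_p to take effect)
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.2
        )
        # Decode only the newly generated tokens so the prompt text does not affect the score
        prompt_length = inputs["input_ids"].shape[1]
        response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)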
        
        # Evaluate response
        score, metrics = compute_reasoning_quality(response)
        all_metrics.append(metrics)
        
        print(f"\nResponse:\n{response}")
        print(f"\nScore: {score:.2f}")
        print(f"Metrics: {metrics}")
        print("-" * 80)

# Calculate overall statistics
total_responses = len(all_metrics)
overall_stats = {
    "step_by_step": sum(1 for m in all_metrics if m["steps"]) / total_responses * 100,
    "reasoning_markers": sum(1 for m in all_metrics if m["logical_flow"]) / total_responses * 100,
    "depth": sum(1 for m in all_metrics if m["depth"]) / total_responses * 100,
    "clear_conclusions": sum(1 for m in all_metrics if m["conclusion"]) / total_responses * 100,
    "conciseness": sum(1 for m in all_metrics if m["concise"]) / total_responses * 100
}

print("\n=== Overall Statistics ===")
for metric, value in overall_stats.items():
    print(f"{metric}: {value:.1f}%")

# Save the model and evaluation results
output_dir = "llama-1b-grpo-finetuned"
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

# Save evaluation results
import json
with open(f"{output_dir}/evaluation_results.json", "w") as f:
    json.dump({
        "test_cases": test_cases,
        "overall_stats": overall_stats
    }, f, indent=2)

print(f"\nModel and evaluation results saved to: {output_dir}")