# GRPO Fine-tuning for Qwen2-VL-2B-Instruct
## First Principles Explanations using Custom Reward Function

This notebook implements GRPO (Group Relative Policy Optimization) training for fine-tuning Qwen2-VL-2B-Instruct to generate better first principles explanations using a custom reward function.

### Key Features:
- Custom reward function based on Feynman's teaching principles
- Multi-component reward evaluation (analogies, clarity, engagement, etc.)
- LoRA fine-tuning for efficient training
- Weights & Biases integration for monitoring
- Optimized for Google Colab

In [None]:
## 🚀 Setup and Installation
First, let's install all required packages for Colab environment


In [None]:
# Install required packages
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install transformers datasets accelerate peft bitsandbytes
!pip install trl wandb huggingface_hub
!pip install nltk textstat
!pip install qwen-vl-utils

# Download NLTK data
import nltk
nltk.download('punkt')
nltk.download('stopwords')

print("✅ All packages installed successfully!")


In [None]:
## 🔧 Mount Google Drive and Setup Environment


In [None]:
import os

import torch
from google.colab import drive

# Make Google Drive visible under /content/drive
drive.mount('/content/drive')

# Set up working directory
os.chdir('/content')

# Check GPU availability (query CUDA once and reuse the answer)
_cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ready}")
if _cuda_ready:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {_gpu_props.total_memory / 1e9:.1f} GB")


In [None]:
## 🔑 Authentication Setup
Login to Hugging Face and Weights & Biases


In [None]:
from huggingface_hub import login
import wandb

# Both logins below are interactive and run in this order; the printed lines
# only announce which credential is being requested next.

# Login to Hugging Face
print("Please enter your Hugging Face token:")
login()

# Login to Weights & Biases
print("Please enter your W&B API key:")
wandb.login()

print("✅ Authentication completed!")


In [None]:
## 🎯 Custom Reward Function Implementation
This reward function evaluates first principles explanations based on multiple criteria


In [None]:
import re 
import math
import nltk
from typing import Dict, List, Tuple, Any
from textstat import flesch_reading_ease, flesch_kincaid_grade
from collections import Counter
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import random

class FirstPrinciplesRewardFunction:
    """
    Heuristic reward function for evaluating first principles explanations
    following the Feynman method: simple analogies, step-by-step reasoning,
    engaging storytelling, and fundamental understanding.

    Every ``evaluate_*`` method returns a score in [0, 1]. ``compute_reward``
    combines the seven component scores using ``self.weights``, whose values
    sum to 1.0.
    """

    # Punctuation removed from both ends of whitespace-separated tokens before
    # single-word lookups, so that "ball." or "(leverage)" still match.
    _EDGE_PUNCTUATION = '.,;:!?"\'()[]{}'

    # Everyday objects whose presence counts as a concrete example.
    _CONCRETE_EXAMPLES = frozenset({
        'ball', 'car', 'house', 'water', 'air', 'food', 'game', 'toy',
        'bicycle', 'seesaw', 'playground', 'kitchen', 'garden', 'road',
        'bridge', 'ladder', 'puzzle', 'painting', 'story', 'movie'
    })

    # Words that appeal to the senses.
    _SENSORY_WORDS = frozenset({
        'see', 'feel', 'hear', 'touch', 'taste', 'smell', 'warm', 'cold',
        'bright', 'dark', 'smooth', 'rough', 'loud', 'quiet'
    })

    def __init__(self, model_name: str = "Qwen/Qwen2-VL-2B-Instruct"):
        """
        Set up component weights, keyword lists and the sentiment analyzer.

        Args:
            model_name: Kept for backward compatibility; no scoring component
                reads it.
        """
        # Keys must match the component names produced by compute_reward,
        # which looks each score's weight up by name.
        self.weights = {
            'analogy_quality': 0.20,
            'step_by_step': 0.15,
            'fundamental_concepts': 0.20,
            'engagement': 0.15,
            'clarity': 0.15,
            'completeness': 0.10,
            'avoid_jargon': 0.05
        }

        # The sentiment model is optional: if it cannot be loaded, the
        # engagement score simply omits its sentiment bonus.
        try:
            self.sentiment_analyzer = pipeline(
                "sentiment-analysis",
                model="cardiffnlp/twitter-roberta-base-sentiment-latest",
                device=0 if torch.cuda.is_available() else -1
            )
        except Exception as e:
            self.sentiment_analyzer = None
            print(f"Warning: Could not load sentiment analyzer: {e}")

        self.jargon_terms = {
            'utilize', 'paradigm', 'synergy', 'leverage', 'optimize', 'streamline',
            'methodology', 'framework', 'infrastructure', 'scalable', 'robust',
            'innovative', 'cutting-edge', 'state-of-the-art', 'holistic', 'comprehensive'
        }

        self.first_principles_indicators = {
            'fundamental_starters': [
                'imagine', 'think of', 'picture', 'let\'s start with', 'at its core',
                'fundamentally', 'basically', 'essentially', 'from the beginning',
                'the basic idea', 'the foundation'
            ],
            'analogy_patterns': [
                'like', 'similar', 'imagine', 'think of', 'picture', 'as if',
                'it\'s like when', 'just like', 'similar to', 'comparable to'
            ],
            'step_indicators': [
                'first', 'second', 'third', 'next', 'then', 'after that',
                'step by step', 'one by one', 'gradually', 'building up'
            ],
            'engagement_patterns': [
                'does this', 'do you see', 'can you picture', 'have you noticed',
                'does this help', 'make sense', 'clear now', 'understand how'
            ]
        }

    def _tokenize(self, text: str) -> List[str]:
        """Return lowercase whitespace tokens with edge punctuation stripped."""
        return [token.strip(self._EDGE_PUNCTUATION) for token in text.lower().split()]

    @staticmethod
    def _split_sentences(text: str) -> List[str]:
        """
        Split text into sentences with NLTK, falling back to a simple regex.

        The fallback is used when NLTK raises LookupError (its error for a
        missing tokenizer resource), so a missing download degrades scoring
        instead of aborting it.
        """
        try:
            return nltk.sent_tokenize(text)
        except LookupError:
            return [part for part in re.split(r'(?<=[.!?])\s+', text.strip()) if part]

    def evaluate_analogy_quality(self, response: str) -> float:
        """
        Evaluate the quality of analogies in the response.

        Awards 0.3 if any analogy phrase occurs as a whole word/phrase, up to
        0.4 for concrete everyday objects (0.1 each) and up to 0.3 for sensory
        words (0.05 each).
        """
        score = 0.0
        response_lower = response.lower()
        words = self._tokenize(response)

        # Use a real word boundary (\b) and escape the phrase so it is matched
        # literally; only the presence of an analogy matters, not the count.
        has_analogy = any(
            re.search(rf'\b{re.escape(pattern)}\b', response_lower)
            for pattern in self.first_principles_indicators['analogy_patterns']
        )
        if has_analogy:
            score += 0.3

        concrete_count = sum(1 for word in words if word in self._CONCRETE_EXAMPLES)
        score += min(0.4, concrete_count * 0.1)

        sensory_count = sum(1 for word in words if word in self._SENSORY_WORDS)
        score += min(0.3, sensory_count * 0.05)

        return min(1.0, score)

    def evaluate_step_by_step_reasoning(self, response: str) -> float:
        """
        Evaluate if the explanation follows a logical, step-by-step progression.

        Combines step indicators (up to 0.4), causal connectors (up to 0.3)
        and a 0.3 bonus when the second half of the sentences contains more
        words than the first half.
        """
        score = 0.0
        response_lower = response.lower()

        step_indicators = self.first_principles_indicators['step_indicators']
        step_count = sum(1 for indicator in step_indicators
                         if indicator in response_lower)
        score += min(0.4, step_count * 0.1)

        connectors = (
            'because', 'so', 'therefore', 'as a result', 'this means',
            'which leads to', 'causing', 'resulting in', 'this is why'
        )
        connector_count = sum(1 for connector in connectors
                              if connector in response_lower)
        score += min(0.3, connector_count * 0.1)

        sentences = self._split_sentences(response)
        if len(sentences) >= 3:
            midpoint = len(sentences) // 2
            early_sentence_length = sum(len(s.split()) for s in sentences[:midpoint])
            later_sentence_length = sum(len(s.split()) for s in sentences[midpoint:])

            if later_sentence_length > early_sentence_length:
                score += 0.3

        return min(1.0, score)

    def evaluate_fundamental_concepts(self, response: str) -> float:
        """
        Evaluate if the explanation addresses fundamental concepts.

        Combines fundamental starter phrases (up to 0.4), "why" vocabulary
        (up to 0.3) and building-block vocabulary (up to 0.3).
        """
        score = 0.0
        response_lower = response.lower()

        fundamental_starters = self.first_principles_indicators['fundamental_starters']
        starter_count = sum(1 for starter in fundamental_starters
                            if starter in response_lower)
        score += min(0.4, starter_count * 0.2)

        why_patterns = ('why', 'reason', 'cause', 'because', 'due to', 'leads to')
        why_count = sum(1 for pattern in why_patterns
                        if pattern in response_lower)
        score += min(0.3, why_count * 0.05)

        building_blocks = (
            'basic', 'fundamental', 'core', 'essential', 'underlying',
            'foundation', 'principle', 'rule', 'law', 'truth'
        )
        building_count = sum(1 for block in building_blocks
                             if block in response_lower)
        score += min(0.3, building_count * 0.1)

        return min(1.0, score)

    def evaluate_engagement(self, response: str) -> float:
        """
        Evaluate how engaging and interactive the explanation is.

        Combines engagement phrases (up to 0.4), question marks (up to 0.3)
        and a 0.3 bonus when the sentiment analyzer labels the first 512
        characters as positive.
        """
        score = 0.0
        response_lower = response.lower()

        engagement_patterns = self.first_principles_indicators['engagement_patterns']
        engagement_count = sum(1 for pattern in engagement_patterns
                               if pattern in response_lower)
        score += min(0.4, engagement_count * 0.1)

        question_count = response.count('?')
        score += min(0.3, question_count * 0.1)

        if self.sentiment_analyzer:
            try:
                sentiment = self.sentiment_analyzer(response[:512])
                # Compare case-insensitively: label casing differs between
                # sentiment models (the one configured here is expected to
                # emit lowercase labels).
                if str(sentiment[0]['label']).lower() == 'positive':
                    score += 0.3
            except Exception:
                # Sentiment is a best-effort bonus; a failure must not abort
                # reward computation. KeyboardInterrupt/SystemExit propagate.
                pass

        return min(1.0, score)

    def evaluate_clarity(self, response: str) -> float:
        """
        Evaluate clarity using readability metrics.

        Combines a Flesch reading-ease tier (0.1-0.4), an average sentence
        length tier (0.1-0.3) and a 0.3 bonus when at least 30% of the words
        are very common function words.
        """
        score = 0.0

        try:
            flesch_score = flesch_reading_ease(response)
            if flesch_score >= 60:
                score += 0.4
            elif flesch_score >= 50:
                score += 0.3
            elif flesch_score >= 40:
                score += 0.2
            else:
                score += 0.1
        except Exception:
            # Readability could not be computed; award the middle tier.
            score += 0.2

        sentences = self._split_sentences(response)
        if sentences:
            avg_sentence_length = sum(len(s.split()) for s in sentences) / len(sentences)
            if 10 <= avg_sentence_length <= 20:
                score += 0.3
            elif 8 <= avg_sentence_length <= 25:
                score += 0.2
            else:
                score += 0.1

        simple_words = {'the', 'a', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for'}
        words = response.lower().split()
        simple_word_ratio = sum(1 for word in words
                                if word in simple_words) / max(len(words), 1)

        if simple_word_ratio >= 0.3:
            score += 0.3

        return min(1.0, score)

    def evaluate_completeness(self, response: str) -> float:
        """
        Evaluate if the explanation is complete.

        Combines a length tier (0.1-0.5, best for 100-300 words), 0.3 for any
        concluding phrase and up to 0.2 for example phrases.
        """
        score = 0.0

        word_count = len(response.split())
        if 100 <= word_count <= 300:
            score += 0.5
        elif 60 <= word_count <= 400:
            score += 0.3
        else:
            score += 0.1

        conclusion_indicators = (
            'so', 'therefore', 'in summary', 'to summarize', 'overall',
            'does this help', 'now you can see', 'this explains'
        )

        response_lower = response.lower()
        if any(indicator in response_lower for indicator in conclusion_indicators):
            score += 0.3

        example_indicators = ('example', 'for instance', 'like when', 'such as')
        example_count = sum(1 for indicator in example_indicators
                            if indicator in response_lower)
        score += min(0.2, example_count * 0.1)

        return min(1.0, score)

    def evaluate_jargon_avoidance(self, response: str) -> float:
        """
        Penalize use of jargon.

        Returns 1.0 for no jargon (or an empty response), then 0.8 / 0.6 / 0.3
        as the share of jargon words exceeds 0%, 2% and 5%.
        """
        words = self._tokenize(response)
        if not words:
            return 1.0

        jargon_count = sum(1 for word in words if word in self.jargon_terms)
        jargon_ratio = jargon_count / len(words)

        if jargon_ratio == 0:
            return 1.0
        elif jargon_ratio <= 0.02:
            return 0.8
        elif jargon_ratio <= 0.05:
            return 0.6
        else:
            return 0.3

    def compute_reward(self, response: str, context: str = None) -> Dict[str, float]:
        """
        Compute the overall reward score.

        Args:
            response: The explanation text to score.
            context: Accepted for interface compatibility; the current
                components do not use it.

        Returns:
            Dict with the seven component scores plus 'total' (weighted sum)
            and 'normalized' ('total' clamped to [0, 1]).
        """
        scores = {
            'analogy_quality': self.evaluate_analogy_quality(response),
            'step_by_step': self.evaluate_step_by_step_reasoning(response),
            'fundamental_concepts': self.evaluate_fundamental_concepts(response),
            'engagement': self.evaluate_engagement(response),
            'clarity': self.evaluate_clarity(response),
            'completeness': self.evaluate_completeness(response),
            'avoid_jargon': self.evaluate_jargon_avoidance(response)
        }

        total_score = sum(value * self.weights[name] for name, value in scores.items())

        scores['total'] = total_score
        scores['normalized'] = min(1.0, max(0.0, total_score))

        return scores

# Initialize the reward function
# One module-level evaluator is shared by the wrapper functions below.
# Its constructor attempts to load the sentiment-analysis pipeline, so it is
# built once here rather than on every reward call.
reward_evaluator = FirstPrinciplesRewardFunction()

def reward_opening_hook(response: str, context: str = None) -> float:
    """Return the scalar reward for one response (main GRPO reward entry point).

    Delegates to the shared ``reward_evaluator`` and returns only the
    'normalized' entry of its score breakdown.
    """
    return reward_evaluator.compute_reward(response, context)['normalized']

def detailed_reward_analysis(response: str, context: str = None) -> Dict[str, Any]:
    """Return the full per-component score breakdown for one response.

    The result is exactly what the shared ``reward_evaluator`` produces from
    ``compute_reward``.
    """
    breakdown = reward_evaluator.compute_reward(response, context)
    return breakdown

def reward_with_feedback(response: str, context: str = None) -> Tuple[float, str]:
    """Return the normalized reward together with human-readable feedback.

    One advice sentence is emitted for each component that scores below its
    threshold, in a fixed order; if nothing falls short, a single
    congratulatory message is returned instead.
    """
    scores = reward_evaluator.compute_reward(response, context)

    # (component, threshold, advice), listed in the order advice is emitted.
    advice_table = (
        ('analogy_quality', 0.5, "Consider adding more concrete analogies or examples."),
        ('step_by_step', 0.5, "Try breaking down the explanation into clearer steps."),
        ('fundamental_concepts', 0.5, "Focus more on the fundamental 'why' and underlying principles."),
        ('engagement', 0.5, "Make the explanation more engaging with questions."),
        ('clarity', 0.5, "Simplify the language and sentence structure."),
        ('completeness', 0.5, "Provide a more complete explanation with examples."),
        ('avoid_jargon', 0.7, "Avoid technical jargon and use simpler language."),
    )

    hints = [
        advice
        for component, threshold, advice in advice_table
        if scores[component] < threshold
    ]

    if hints:
        feedback = " ".join(hints)
    else:
        feedback = "Great first principles explanation!"

    return scores['normalized'], feedback

print("✅ Reward function implemented successfully!")


In [None]:
## ⚙️ Configuration and Model Setup
Set up the training configuration and load the model with LoRA


In [None]:
# Configuration
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Qwen2VLForConditionalGeneration,
)
from trl import GRPOTrainer, GRPOConfig
# prepare_model_for_kbit_training is provided by peft, not transformers.
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import Dataset
import json

# Model and dataset configuration
model_name = "KhushalM/Qwen2-VL-2B-Instruct-SFT"
output_dir = "./grpo_results"
hub_model_id = "KhushalM/Qwen2-VL-2B-Instruct-GRPO-FirstPrinciples"

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Initialize W&B
# The values logged here are descriptive metadata only; they are kept in sync
# by hand with the LoRA and GRPOConfig settings used later in the notebook.
wandb.init(
    project="qwen2-vl-2b-instruct-grpo-first-principles",
    config={
        "model": model_name,
        "task": "First Principles Explanations",
        "reward_strategy": "Multi-component First Principles Reward",
        "lora_r": 32,
        "lora_alpha": 64,
        "learning_rate": 1e-5,
        "batch_size": 2,
        # Matches gradient_accumulation_steps=8 passed to GRPOConfig.
        "gradient_accumulation_steps": 8,
        "num_train_epochs": 3,
        "reward_components": {
            "analogy_quality": 0.20,
            "step_by_step": 0.15,
            "fundamental_concepts": 0.20,
            "engagement": 0.15,
            "clarity": 0.15,
            "completeness": 0.10,
            "avoid_jargon": 0.05
        }
    }
)

print("✅ Configuration completed!")


In [None]:
# Quantization configuration for efficient training
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

print("Loading model...")
# Place the quantized weights at load time via device_map instead of calling
# model.to(device) afterwards: moving a bitsandbytes-quantized model with
# .to() is rejected by some transformers/bitsandbytes version combinations.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Reuse EOS as the padding token when the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Prepare model for training
model = prepare_model_for_kbit_training(model)
# The generation KV cache is switched off for training.
model.config.use_cache = False
print("Model prepared for kbit training")

# LoRA Setup: adapters on the attention projection layers only
peft_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, peft_config)
print("Model converted to LoRA")
print(f"Trainable Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
print(f"Total Parameters: {sum(p.numel() for p in model.parameters()):,}")

print("✅ Model loaded and configured successfully!")


In [None]:
## 📁 Dataset Upload Options
Choose one of the following methods to upload your dataset


In [None]:
# Option 1: Upload file directly to Colab
from google.colab import files
import json
import os

def upload_dataset_file():
    """Prompt for a JSON dataset upload in Colab and validate its shape.

    Returns:
        The uploaded file's name when it parses as a non-empty list whose
        first item has a 'messages' key; otherwise None.
    """
    print("📤 Upload your dataset file (JSON format)")
    print("Expected format: List of dictionaries with 'messages' key")
    print("Each message should have 'role' and 'content' fields")

    uploaded = files.upload()
    if not uploaded:
        return None

    # Only the first uploaded file is considered.
    filename = next(iter(uploaded))
    print(f"✅ File uploaded: {filename}")

    # Validate file format
    try:
        with open(filename, 'r') as handle:
            records = json.load(handle)

        if not (isinstance(records, list) and len(records) > 0):
            print("❌ Invalid format: Dataset should be a list of dictionaries")
        elif 'messages' not in records[0]:
            print("❌ Invalid format: Each item should have 'messages' key")
        else:
            print(f"✅ Dataset format validated!")
            print(f"📊 Dataset contains {len(records)} examples")
            return filename

    except json.JSONDecodeError:
        print("❌ Invalid JSON format")
    except Exception as e:
        print(f"❌ Error validating file: {e}")

    return None

# Uncomment the line below to upload a file
# dataset_filename = upload_dataset_file()


In [None]:
# Option 2: Download from Google Drive
def download_from_drive(drive_path, local_filename="dataset.json"):
    """Copy a dataset from a mounted Drive path and validate its shape.

    Args:
        drive_path: Source path of the JSON dataset.
        local_filename: Destination path for the local copy.

    Returns:
        ``local_filename`` when the copy parses as a non-empty list whose
        first item has a 'messages' key; otherwise None (a message describing
        the problem is printed).
    """
    print(f"📥 Downloading from Google Drive: {drive_path}")

    try:
        # Bring the file into local Colab storage first
        import shutil
        shutil.copy(drive_path, local_filename)

        # Then parse the local copy and check its structure
        with open(local_filename, 'r') as handle:
            records = json.load(handle)

        looks_valid = (
            isinstance(records, list)
            and len(records) > 0
            and 'messages' in records[0]
        )
        if not looks_valid:
            print("❌ Invalid dataset format")
        else:
            print(f"✅ Dataset downloaded and validated!")
            print(f"📊 Dataset contains {len(records)} examples")
            return local_filename

    except FileNotFoundError:
        print(f"❌ File not found: {drive_path}")
        print("Make sure Google Drive is mounted and the path is correct")
    except Exception as e:
        print(f"❌ Error downloading file: {e}")

    return None

# Example usage (uncomment and modify path):
# drive_dataset_path = "/content/drive/MyDrive/your_folder/structured_dataset.json"
# dataset_filename = download_from_drive(drive_dataset_path)


In [None]:
# Option 3: Load from Hugging Face Hub
from datasets import load_dataset

def load_from_huggingface(dataset_name, split="train", config_name=None):
    """Load a Hub dataset, convert it to the 'messages' format, and save it.

    Items that already have 'messages' are kept as-is; items with
    'conversations' or with 'prompt' + 'response' are converted; anything
    else is reported and skipped.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        split: Split to load.
        config_name: Optional configuration name; passed only when truthy.

    Returns:
        "hf_dataset.json" after writing the converted data there, or None if
        nothing usable was found or an error occurred.
    """
    print(f"🤗 Loading from Hugging Face: {dataset_name}")

    try:
        # Load dataset from HF Hub
        positional = (dataset_name, config_name) if config_name else (dataset_name,)
        dataset = load_dataset(*positional, split=split)

        # Convert to expected format if needed
        data = []
        for item in dataset:
            if 'messages' in item:
                # Already in the expected format
                record = item
            elif 'conversations' in item:
                record = {'messages': item['conversations']}
            elif 'prompt' in item and 'response' in item:
                record = {
                    'messages': [
                        {"role": "user", "content": item['prompt']},
                        {"role": "assistant", "content": item['response']}
                    ]
                }
            else:
                print(f"⚠️ Unknown format for item: {list(item.keys())}")
                continue
            data.append(record)

        if not data:
            print("❌ No valid data found in dataset")
        else:
            # Save as JSON file
            local_filename = "hf_dataset.json"
            with open(local_filename, 'w') as out_file:
                json.dump(data, out_file, indent=2)

            print(f"✅ Dataset loaded and converted!")
            print(f"📊 Dataset contains {len(data)} examples")
            return local_filename

    except Exception as e:
        print(f"❌ Error loading from Hugging Face: {e}")

    return None

# Example usage (uncomment and provide dataset name):
# hf_dataset_name = "your_username/your_dataset"
# dataset_filename = load_from_huggingface(hf_dataset_name)


In [None]:
# Dataset loading - Choose your method and uncomment the appropriate line

# Method 1: Upload file directly
# dataset_filename = upload_dataset_file()

# Method 2: Load from Google Drive (update the path)
# drive_path = "/content/drive/MyDrive/your_folder/structured_dataset.json"
# dataset_filename = download_from_drive(drive_path)

# Method 3: Load from Hugging Face Hub
# hf_dataset_name = "your_username/your_dataset"
# dataset_filename = load_from_huggingface(hf_dataset_name)

# Every loader call above is commented out by default, so `dataset_filename`
# may not exist yet. Default it to None (keeping any value set earlier)
# instead of failing with NameError in the check below.
dataset_filename = globals().get("dataset_filename")

# Verify dataset was loaded
if dataset_filename and os.path.exists(dataset_filename):
    print(f"✅ Dataset file ready: {dataset_filename}")
    
    # Preview the dataset structure
    with open(dataset_filename, 'r') as f:
        sample_data = json.load(f)
    
    print(f"\n📋 Dataset Preview:")
    print(f"Total examples: {len(sample_data)}")
    print(f"First example structure:")
    if len(sample_data) > 0:
        example = sample_data[0]
        print(f"  Keys: {list(example.keys())}")
        if 'messages' in example:
            print(f"  Messages count: {len(example['messages'])}")
            print(f"  Message roles: {[msg.get('role', 'unknown') for msg in example['messages']]}")
            
            # Show first user message, truncated to 100 characters
            user_msg = next((msg for msg in example['messages'] if msg.get('role') == 'user'), None)
            if user_msg:
                preview = user_msg['content'][:100] + "..." if len(user_msg['content']) > 100 else user_msg['content']
                print(f"  Sample question: \"{preview}\"")
else:
    print("❌ No dataset loaded. Please use one of the upload methods above.")
    # Normalize to None so later cells can test `if dataset_filename:`.
    dataset_filename = None


In [None]:
# Load and process the dataset for training
if not dataset_filename:
    print("❌ No dataset available. Please load a dataset first.")
    train_formatted = None
    val_formatted = None

else:
    print("📂 Processing dataset for training...")

    try:
        # Read the raw examples from disk
        with open(dataset_filename, "r") as dataset_file:
            dataset_data = json.load(dataset_file)

        # Build a HuggingFace Dataset and hold out 10% for validation
        dataset = Dataset.from_list(dataset_data)
        dataset = dataset.train_test_split(test_size=0.1, seed=42)
        train_dataset, val_dataset = dataset["train"], dataset["test"]

        print(f"✅ Dataset processed successfully!")
        print(f"📊 Train dataset size: {len(train_dataset)}")
        print(f"📊 Validation dataset size: {len(val_dataset)}")

        # Dataset formatting function
        def format_dataset(sample):
            """Turn one 'messages' example into prompt / response / context fields."""
            conversation = sample["messages"]

            # The first user turn (or "" if there is none) is kept as context
            question = ""
            for turn in conversation:
                if turn["role"] == "user":
                    question = turn["content"]
                    break

            # Everything except the final message is rendered as the prompt
            prompt_text = tokenizer.apply_chat_template(
                conversation[:-1],
                tokenize=False,
                add_generation_prompt=True
            )

            # The final message's content is used as the reference response
            return {
                "prompt": prompt_text,
                "response": conversation[-1]["content"],
                "context": question
            }

        # Format datasets
        print("🔄 Formatting datasets for GRPO training...")
        train_formatted = train_dataset.map(format_dataset)
        val_formatted = val_dataset.map(format_dataset)

        print("✅ Dataset formatting completed!")
        print("🎯 Ready for GRPO training!")

    except Exception as e:
        print(f"❌ Error processing dataset: {e}")
        print("Please check your dataset format and try again.")
        train_formatted = None
        val_formatted = None


In [None]:
## 📊 Dataset Loading and Preprocessing
Load and format the dataset for GRPO training

**Note:** Update the dataset path to point to your dataset file. You can either:
1. Upload your dataset to Google Drive and mount it
2. Upload directly to Colab files
3. Use a dataset from Hugging Face Hub


In [None]:
# Dataset configuration - UPDATE THESE PATHS AS NEEDED
# Option 1: From Google Drive (update path)
# dataset_path = "/content/drive/MyDrive/your_dataset_path/structured_dataset.json"

# Option 2: From uploaded file in Colab
# dataset_path = "/content/structured_dataset.json"

# NOTE(review): this cell rebinds dataset_data, dataset, train_dataset and
# val_dataset, replacing whatever an earlier dataset cell assigned to those
# names. Skip it (or switch it to real data) when training on your own dataset.

# Option 3: Create a sample dataset for testing (remove this and use real data)
sample_dataset = [
    {
        "messages": [
            {
                "role": "system",
                "content": "You are an expert educator who explains concepts from first principles like Richard Feynman. Start with fundamental truths, use simple analogies, and avoid jargon."
            },
            {
                "role": "user", 
                "content": "Why do objects fall to the ground when dropped?"
            },
            {
                "role": "assistant",
                "content": "Imagine you have a stretched rubber sheet and you place a heavy ball in the middle. The sheet bends downwards, right? Now, if you roll a smaller ball nearby, it will start rolling toward the heavier ball because of the dip. This is a simple way to picture how gravity works. Gravity is like the Earth making a 'dip' in space that pulls things toward it. When you let go of an object, it falls because the Earth is pulling it toward its center, similar to how the heavy ball makes the rubber sheet dip. The bigger the object (like Earth), the stronger the pull. Does this help you see why objects fall when dropped?"
            }
        ]
    }
] * 50  # Repeat for demo purposes

# For real training, uncomment one of the options above and comment out the sample
print("Using sample dataset for demonstration. Replace with your actual dataset!")

# Load dataset
try:
    # For real data: 
    # with open(dataset_path, "r") as f:
    #     dataset_data = json.load(f)
    
    # For demo:
    dataset_data = sample_dataset
    
    dataset = Dataset.from_list(dataset_data)
    # Fixed seed so the train/validation partition is reproducible across
    # runs, consistent with the seeded split used elsewhere in this notebook.
    dataset = dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = dataset["train"]
    val_dataset = dataset["test"]
    
    print(f"✅ Dataset loaded successfully!")
    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Test dataset size: {len(val_dataset)}")
    
except Exception as e:
    print(f"❌ Error loading dataset: {e}")
    print("Please check your dataset path and format.")

# Dataset formatting function
def format_dataset(sample):
    """Turn one 'messages' example into prompt / response / context fields.

    The prompt is the chat-templated rendering of every message except the
    last; the last message's content is the response; the first user
    message's content (or "" if there is none) is the context.
    """
    conversation = sample["messages"]

    # Find the first user turn to keep as context
    question = ""
    for turn in conversation:
        if turn["role"] == "user":
            question = turn["content"]
            break

    # Render all but the final message, ending with a generation prompt
    prompt_text = tokenizer.apply_chat_template(
        conversation[:-1],
        tokenize=False,
        add_generation_prompt=True
    )

    formatted = {
        "prompt": prompt_text,
        "response": conversation[-1]["content"],
        "context": question
    }
    return formatted

# Format datasets
# Applies format_dataset to every example of both splits; the resulting
# train_formatted / val_formatted are what the trainer cell consumes.
print("Formatting datasets...")
train_formatted = train_dataset.map(format_dataset)
val_formatted = val_dataset.map(format_dataset)
print("✅ Dataset formatting completed!")


In [None]:
## 🚀 GRPO Training Setup and Execution
Configure the GRPO trainer and start training


In [None]:
# Reward function wrapper for GRPO
def compute_reward(responses, contexts=None):
    """
    Compute rewards for one or more responses using the custom reward function.

    Args:
        responses: A single response string or a list of response strings.
        contexts: None, a single context string (applied to every response),
            or a list of contexts aligned one-to-one with ``responses``.

    Returns:
        A list of float rewards, one per response, in input order.

    Raises:
        ValueError: If ``contexts`` is a list whose length differs from the
            number of responses.
    """
    if isinstance(responses, str):
        responses = [responses]

    if contexts is None or isinstance(contexts, str):
        # Broadcast a missing/single context to every response. Wrapping it
        # in a one-element list would make zip() drop all but the first
        # response.
        contexts = [contexts] * len(responses)
    elif len(contexts) != len(responses):
        raise ValueError(
            f"Got {len(responses)} responses but {len(contexts)} contexts"
        )

    # Only the scalar score is needed here, so use the scalar entry point
    # rather than building feedback text that would be discarded.
    return [
        reward_opening_hook(response, context)
        for response, context in zip(responses, contexts)
    ]

# GRPO Configuration
# Only keyword arguments defined by GRPOConfig / TrainingArguments are
# passed; an unknown keyword makes the constructor raise TypeError.
# NOTE(review): some TRL versions require the global train/eval batch size to
# be divisible by num_generations — confirm the batch sizes below against the
# installed TRL version.
grpo_config = GRPOConfig(
    output_dir=output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    logging_steps=10,
    eval_steps=50,
    save_steps=50,
    # `evaluation_strategy` was renamed to `eval_strategy` in transformers.
    eval_strategy="steps",
    save_strategy="steps",
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_reward",
    greater_is_better=True,
    report_to="wandb",
    push_to_hub=True,
    hub_model_id=hub_model_id,
    hub_strategy="every_save",
    dataloader_num_workers=2,
    remove_unused_columns=False,
    # GRPO specific parameters
    max_completion_length=1024,  # cap on generated tokens per completion
    num_generations=4,           # completions sampled per prompt
    temperature=0.7,
    beta=0.05,                   # KL penalty coefficient
    bf16=(device == "cuda"),
)

print("✅ GRPO configuration completed!")


In [None]:
# Initialize GRPO Trainer
if train_formatted is not None and val_formatted is not None:
    print("Initializing GRPO Trainer...")

    def _grpo_reward(prompts, completions, **kwargs):
        """Adapt compute_reward to GRPOTrainer's reward-callable convention.

        GRPOTrainer calls reward functions with keyword arguments (prompts,
        completions, plus extra dataset columns) and expects one float per
        completion. The dataset's "context" column, when present, is
        forwarded as the per-completion context.
        """
        return compute_reward(completions, kwargs.get("context"))

    # GRPOTrainer takes the reward callable(s) as `reward_funcs` and the
    # tokenizer as `processing_class`; it has no `reward_function`,
    # `tokenizer` or `data_collator` parameters.
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=_grpo_reward,
        args=grpo_config,
        train_dataset=train_formatted,
        eval_dataset=val_formatted,
        processing_class=tokenizer,
    )

    print("✅ GRPO Trainer initialized successfully!")
    print(f"Number of training examples: {len(train_formatted)}")
    print(f"Number of validation examples: {len(val_formatted)}")
else:
    print("❌ Cannot initialize trainer: No dataset available.")
    print("Please load a dataset first using one of the upload methods above.")
    trainer = None

# Custom callback for detailed logging
from transformers import TrainerCallback


class RewardLoggingCallback(TrainerCallback):
    """Log a per-component reward breakdown to W&B on every 50th log event.

    Subclasses TrainerCallback so the trainer can register it and invoke
    ``on_log`` with its standard arguments.
    """

    # Component names copied from the reward breakdown into W&B.
    _LOGGED_COMPONENTS = (
        'analogy_quality', 'step_by_step', 'fundamental_concepts',
        'engagement', 'clarity', 'completeness', 'avoid_jargon', 'total',
    )

    def __init__(self):
        # Counts on_log invocations (log events, not optimizer steps).
        self.step_count = 0

    def on_log(self, args=None, state=None, control=None, logs=None, **kwargs):
        """Log detailed reward analysis periodically.

        Args:
            args, state, control: Standard TrainerCallback arguments (unused).
            logs: The dict of values being logged. Only a 'sample_responses'
                entry is read.
                NOTE(review): nothing visible here adds that key to the
                trainer's logs, so this may be a no-op — confirm the source.
        """
        # Backward compatibility with the previous on_log(logs) signature,
        # where the logs dict was the first positional argument.
        if logs is None and isinstance(args, dict):
            logs = args

        if self.step_count % 50 == 0:  # Every 50 log events
            sample_responses = (logs or {}).get('sample_responses', [])
            for i, response in enumerate(sample_responses[:3]):  # First 3 samples
                scores = detailed_reward_analysis(response)
                wandb.log({
                    f"detailed_reward_sample_{i}/{name}": scores[name]
                    for name in self._LOGGED_COMPONENTS
                })
        self.step_count += 1


# Add callback: creating the instance is not enough, it has to be registered
# with the trainer to ever be invoked.
reward_callback = RewardLoggingCallback()
if trainer is not None:
    trainer.add_callback(reward_callback)

print("📋 Training setup completed! Ready to start training...")


In [None]:
# Run GRPO training, then persist the result locally and on the Hub.
if trainer is None:
    # Nothing to train with: tell the user what has to be fixed first.
    print("❌ Cannot start training: No trainer available.")
    print("Please make sure you have:")
    print("1. Loaded a dataset using one of the upload methods")
    print("2. Successfully initialized the model and trainer")
    print("3. Check for any error messages above")
else:
    print("🚀 Starting GRPO training...")
    print("This will optimize for first principles explanations using our custom reward function")
    print("Training may take several hours depending on your GPU and dataset size...")

    try:
        trainer.train()
        print("✅ Training completed successfully!")

        # Model weights and tokenizer files go into the same directory.
        print("💾 Saving final model...")
        final_model_dir = os.path.join(output_dir, "final_model")
        trainer.save_model(final_model_dir)
        tokenizer.save_pretrained(final_model_dir)

        # The upload is best-effort: a failed push must not discard a
        # successful training run, so it is reported and then ignored.
        try:
            trainer.push_to_hub(commit_message="GRPO fine-tuned model for first principles explanations")
            print("✅ Model pushed to Hugging Face Hub successfully!")
        except Exception as push_error:
            print(f"⚠️ Could not push to hub: {push_error}")

        print("🎉 Training process completed successfully!")

    except Exception as train_error:
        # Report the failure, then re-raise so the notebook cell still fails.
        print(f"❌ Training failed with error: {train_error}")
        print("Check the error message above for details.")
        raise train_error


In [None]:
## 🧪 Test the Trained Model
Run the fine-tuned model on a few sample prompts and score each response with the reward function.


In [None]:
# Report whether a trained model is available for the test run
if trainer is None or 'model' not in locals():
    print("⚠️ No trained model available for testing.")
    print("Please complete the training process first.")
else:
    print("🧪 Testing the trained model...")

# Questions used to probe the model's first-principles explanations
test_prompts = [
    "Why do objects fall to the ground when dropped?",
    "How does a microwave oven heat food?",
    "Why is the sky blue?",
    "How do airplanes fly?",
    "What makes magnets work?",
]

def test_model_response(prompt, max_length=300):
    """Generate the trained model's answer to a single prompt.

    The prompt is wrapped in the first-principles system instruction, rendered
    with the tokenizer's chat template, and sampled from the module-level
    ``model``.

    Args:
        prompt: The user question to answer.
        max_length: Maximum number of new tokens to generate.

    Returns:
        The assistant's reply only, stripped of surrounding whitespace. The
        system and user text are not included.
    """
    test_messages = [
        {
            "role": "system",
            "content": "You are an expert educator who explains concepts from first principles like Richard Feynman. Start with fundamental truths, use simple analogies, and avoid jargon. Use a storytelling tone and follow a step by step explanation style."
        },
        {
            "role": "user",
            "content": prompt
        }
    ]

    formatted_test = tokenizer.apply_chat_template(
        test_messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = tokenizer(formatted_test, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # generate() returns the prompt tokens followed by the new tokens. Slice
    # the prompt off by length before decoding: once special tokens are
    # stripped, the decoded text no longer contains the chat-template markers,
    # so the reply cannot be located reliably by string matching.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_tokens = outputs[0][prompt_length:]

    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

# Generate, score and log a response for every test prompt.
# The whole pass is skipped when no trained model is available.
if trainer is None or 'model' not in locals():
    print("🔍 Skipping model testing - no trained model available.")
else:
    for i, prompt in enumerate(test_prompts):
        print(f"\n{'=' * 60}")
        print(f"TEST {i+1}: {prompt}")
        print('=' * 60)

        # A failure on one prompt is reported and must not stop the others.
        try:
            response = test_model_response(prompt)
            print(f"🤖 Model Response:\n{response}")

            # Score the response with the custom reward function
            reward_score, feedback = reward_with_feedback(response)
            detailed_scores = detailed_reward_analysis(response)

            print(f"\n📊 Evaluation:")
            print(f"Overall Reward Score: {reward_score:.3f}")
            print(f"Feedback: {feedback}")
            print(f"\nDetailed Scores:")
            for component, score in detailed_scores.items():
                # Aggregate entries are skipped in the per-component listing.
                if component in ('total', 'normalized'):
                    continue
                print(f"  {component}: {score:.3f}")

            # Send the headline metrics for this prompt to W&B
            wandb.log({
                f"test_prompt_{i+1}/reward_score": reward_score,
                f"test_prompt_{i+1}/response_length": len(response.split()),
                f"test_prompt_{i+1}/analogy_quality": detailed_scores['analogy_quality'],
                f"test_prompt_{i+1}/clarity": detailed_scores['clarity'],
                f"test_prompt_{i+1}/engagement": detailed_scores['engagement'],
            })

        except Exception as e:
            print(f"❌ Error testing prompt: {e}")

    print(f"\n🎯 Model testing completed!")
    print("Check your W&B dashboard for detailed metrics and training progress.")


In [None]:
# Final cell: close the W&B run, release cached GPU memory, print next steps
print("🧹 Cleaning up...")

# Mark the W&B run as finished
wandb.finish()

# Release cached GPU memory; skipped when no CUDA device is present
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("✅ Cleanup completed!")
print("\n🎉 GRPO Training Notebook Execution Complete! 🎉")
print("\nNext steps:")
print("1. Check your trained model in the output directory")
print("2. Review metrics on your W&B dashboard")
print("3. Test the model with your own prompts")
print("4. Consider fine-tuning hyperparameters for better results")
print("\nHappy training! 🚀")
