# GRPO Training Notebook - CourtRankRL

Ez a notebook a CourtRankRL GRPO alapú reranking modell tanítását végzi el RunPod környezetben.

## Agents.md specifikáció alapján:
- Qwen/Qwen3-4B-Instruct-2507 model QLoRA adapterekkel
- TRL GRPOTrainer használata
- NDCG@10 reward calculation
- Hungarian status messages

## Előfeltételek:
- slate JSONL file (baseline candidate slates)
- HF token beállítva
- Megfelelő GPU (A100/H100 ajánlott)


In [None]:
# Install required packages
# IPython shell escapes; versions are unpinned, so TRL/PEFT API changes between
# runs can affect the cells below. NOTE(review): `jsonlines` is not used by the visible code.
!pip install torch transformers peft trl datasets accelerate bitsandbytes
!pip install jsonlines numpy


In [None]:
import os
import json
import torch
import numpy as np
from pathlib import Path
from typing import Dict, List, Optional
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from trl import GRPOTrainer, GRPOConfig

# Environment sanity check: report the PyTorch version, whether CUDA is visible,
# and the name of device 0 when it is.
print("GRPO Training Notebook inicializálva")
print(f"PyTorch verzió: {torch.__version__}")
print(f"GPU elérhető: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")


In [None]:
# Configuration (per agents.md)
MODEL_NAME = "Qwen/Qwen3-4B-Instruct-2507"  # base model that receives the QLoRA adapters
SLATE_SIZE = 20  # cap / fallback length for parsed rankings; also recorded in metrics
MAX_STEPS = 1000  # passed to GRPOConfig.max_steps
SAVE_STEPS = 100  # passed to GRPOConfig.save_steps
LEARNING_RATE = 1e-5
LORA_RANK = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05

# Paths
SLATE_FILE = "/workspace/data/grpo_slates.jsonl"  # JSONL input, one slate record per line
OUTPUT_DIR = "/workspace/artifacts/grpo_policy"  # trainer output, adapter, tokenizer and metrics
METRICS_FILE = os.path.join(OUTPUT_DIR, "metrics.json")

print("Konfiguráció betöltve:")
print(f"- Model: {MODEL_NAME}")
print(f"- Slate size: {SLATE_SIZE}")
print(f"- Max steps: {MAX_STEPS}")
print(f"- Output dir: {OUTPUT_DIR}")


In [None]:
def load_slates_data(slate_file: str) -> List[Dict]:
    """Load slate records from a JSONL file (one JSON object per line).

    Blank lines (for example extra newlines at the end of an export) are skipped
    instead of being handed to ``json.loads``, which raises on an empty string.

    Args:
        slate_file: Path of the JSONL file.

    Returns:
        The decoded records, in file order.
    """
    print(f"Slate adatok betöltése: {slate_file}")

    slates = []
    with open(slate_file, 'r', encoding='utf-8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            slates.append(json.loads(stripped))

    print(f"Betöltött slates: {len(slates)}")
    return slates

def calculate_ndcg(ranked_indices: List[int], true_relevance: List[float], k: int = 10) -> float:
    """Calculate NDCG@k (as per agents.md) for a ranking given as slate indices.

    Args:
        ranked_indices: Predicted order as indices into ``true_relevance``;
            position 0 is the top rank.
        true_relevance: Graded relevance of each slate document, by slate index.
        k: Cut-off rank.

    Returns:
        DCG@k of the ranking divided by the ideal DCG@k. Returns 0.0 for empty
        input or when the ideal DCG is 0.

    Indices outside ``[0, len(true_relevance))`` and repeated indices keep
    occupying their rank position but contribute no gain. As a result, a malformed
    model output cannot raise ``IndexError``. It also cannot push the score above 1.0
    by listing a relevant document several times (e.g. ``"0,0,0"``).
    """
    if not true_relevance or not ranked_indices:
        return 0.0

    cutoff = max(k, 0)
    n_docs = len(true_relevance)
    seen = set()

    dcg = 0.0
    for rank, idx in enumerate(ranked_indices[:cutoff]):
        if not 0 <= idx < n_docs or idx in seen:
            continue
        seen.add(idx)
        dcg += true_relevance[idx] / np.log2(rank + 2)

    sorted_rel = sorted(true_relevance, reverse=True)
    idcg = 0.0
    for rank, rel in enumerate(sorted_rel[:cutoff]):
        idcg += rel / np.log2(rank + 2)

    return dcg / idcg if idcg > 0 else 0.0

def calculate_entropy(ranking: List[int]) -> float:
    """Return the Shannon entropy (natural log) of the value frequencies in ``ranking``.

    Each distinct value's share is its occurrence count divided by the list length.
    A ranking without repeats of length n therefore scores ln(n), and repeats lower
    the value. An empty ranking yields 0.0.
    """
    if not ranking:
        return 0.0

    size = len(ranking)
    tally: Dict[int, int] = {}
    for item in ranking:
        tally[item] = tally.get(item, 0) + 1

    result = 0.0
    for occurrences in tally.values():
        share = occurrences / size
        if share > 0:
            result -= share * np.log(share)
    return result

def parse_model_ranking(completion: str, max_items: Optional[int] = None) -> List[int]:
    """Parse a comma-separated ranking such as ``"1,3,2,4,0"`` into slate indices.

    Args:
        completion: Raw model output.
        max_items: Maximum number of indices to keep. When None, ``SLATE_SIZE`` is
            used (read at call time).

    Returns:
        Up to ``max_items`` indices in the order they appear; comma-separated tokens
        that are not plain decimal integers are ignored. Non-string input falls back
        to the identity order ``[0, 1, ..., max_items - 1]``.
    """
    limit = SLATE_SIZE if max_items is None else max_items

    if not isinstance(completion, str):
        return list(range(limit))

    # Every character accepted by isdecimal() is a digit int() can parse (unlike
    # isdigit(), which also accepts e.g. "²"), so no exception is possible here and
    # one odd token no longer discards the whole ranking.
    tokens = (part.strip() for part in completion.split(","))
    numbers = [int(token) for token in tokens if token.isdecimal()]
    return numbers[:limit]


In [None]:
# Load slate data
# Each record is expected to carry "query_id" and "slate" (both read by the dataset cell below).
slates_data = load_slates_data(SLATE_FILE)

# Convert to dataset format for GRPO trainer
def create_training_prompt(query_id: str, slate: List[Dict]) -> str:
    """Build the Hungarian ranking prompt for one slate.

    Every candidate is summarised as a JSON object that carries its position
    (``index``), id, the first 500 characters of its text, court/domain/year
    metadata and its retrieval scores. Fields missing from a candidate fall back
    to empty/zero defaults. The model is asked to answer with comma-separated indices.

    NOTE(review): the ``"query"`` field receives ``query_id``, not query text, so
    the model may never see the actual search query. Confirm against the slate schema.
    """
    items = [
        {
            "index": position,
            "doc_id": doc.get("doc_id", ""),
            "text": doc.get("text", "")[:500],
            "court": doc.get("court", ""),
            "domain": doc.get("domain", ""),
            "year": doc.get("year", 0),
            "bm25_score": doc.get("bm25_score", 0.0),
            "faiss_score": doc.get("faiss_score", 0.0),
            "rrf_score": doc.get("rrf_score", 0.0),
        }
        for position, doc in enumerate(slate)
    ]

    payload = json.dumps(
        {"query": query_id, "slate": items, "slate_size": len(items)},
        ensure_ascii=False,
        indent=2,
    )

    instruction = (
        "Rangsorold a következő bírósági dokumentumokat relevancia szerint. "
        "Válaszolj számokkal vesszővel elválasztva (pl. '1,3,2,4,0')."
    )
    return f"{instruction}\n\nBírósági dokumentumok:\n{payload}\n\nRangsorolás:"

# Create dataset
# One example per slate: "prompt" is the text the policy generates from, and
# "slate_data" keeps the raw candidate list for scoring. NOTE(review): reward_function
# reads it from kwargs["slate_data"], so this relies on the trainer forwarding extra
# dataset columns as keyword arguments. Confirm for the installed TRL version.
training_examples = []
for slate_data in slates_data:
    prompt = create_training_prompt(slate_data["query_id"], slate_data["slate"])
    training_examples.append({
        "prompt": prompt,
        "slate_data": slate_data["slate"]
    })

dataset = Dataset.from_list(training_examples)
print(f"Dataset létrehozva: {len(dataset)} példa")


In [None]:
# Initialize model and tokenizer
print("Modell inicializálása...")

# 4-bit quantization config: NF4 weights with double quantization, matmuls computed in bfloat16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)

# Load model (device_map="auto" lets accelerate place the quantized weights on the available devices)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

# Configure LoRA: adapters on all attention and MLP projections
lora_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)

# NOTE(review): peft.prepare_model_for_kbit_training is not applied before wrapping the
# 4-bit model; confirm this is intentional for QLoRA training.
model = get_peft_model(model, lora_config)
print("Modell QLoRA adapterekkel inicializálva")


In [None]:
# Define reward function
def reward_function(completions, **kwargs):
    """GRPO reward: nDCG@10 gain of the parsed ranking over the original slate order.

    Agents.md: nDCG@10 difference with an entropy bonus (variance normalisation of
    the rewards is left to the trainer).

    Args:
        completions: Generated completions, one per sampled generation. These are
            plain strings for standard-format prompts, or lists of chat messages
            (``{"role", "content"}`` dicts) for conversational prompts.
        **kwargs: Dataset columns forwarded by the trainer. ``slate_data`` must hold
            one candidate list per completion; each candidate may carry a
            ``relevance`` grade (missing grades count as 0).

    Returns:
        One float reward per completion. A completion whose scoring fails gets 0.0.

    Raises:
        ValueError: If ``slate_data`` is missing or its length differs from
            ``completions``. GRPOTrainer expects exactly one reward per completion,
            so silently returning a shorter list would only fail later inside the trainer.
    """
    slates = kwargs.get("slate_data")
    if slates is None or len(slates) != len(completions):
        raise ValueError(
            "reward_function: a 'slate_data' oszlop hiányzik vagy a hossza "
            f"({0 if slates is None else len(slates)}) eltér a completionök számától "
            f"({len(completions)})"
        )

    rewards = []

    for completion, slate_data in zip(completions, slates):
        try:
            # Conversational prompts produce a list of chat messages; score the generated text.
            if isinstance(completion, list):
                completion = completion[-1]["content"] if completion else ""

            # Parse model output to get rankings
            predicted_order = parse_model_ranking(completion)

            true_relevance = [item.get("relevance", 0) for item in slate_data]
            # Baseline = the slate's original candidate order.
            baseline_order = list(range(len(slate_data)))

            baseline_ndcg = calculate_ndcg(baseline_order, true_relevance, k=10)
            policy_ndcg = calculate_ndcg(predicted_order, true_relevance, k=10)

            # NDCG difference as reward
            reward = policy_ndcg - baseline_ndcg

            # Lower bound of -1.0. Both nDCG values lie in [0, 1] for non-negative
            # relevance grades, so this is a safety net rather than an active clamp.
            reward = max(reward, -1.0)

            # Entropy bonus
            entropy = calculate_entropy(predicted_order)
            reward += 0.01 * entropy

            rewards.append(reward)

        except Exception as e:
            print(f"Hiba a reward számításakor: {e}")
            rewards.append(0.0)

    return rewards

print("Reward function definiálva")


In [None]:
# Configure GRPO trainer
# GRPOConfig has no `group_size` or `kl_penalty` fields, and the dataclass constructor
# raises TypeError on unknown keyword arguments. In TRL the GRPO group is the
# `num_generations` completions sampled per prompt, and the KL term is weighted by
# `beta`, so `beta=0.0` disables it.
grpo_config = GRPOConfig(
    output_dir=OUTPUT_DIR,
    num_generations=4,  # GRPO group size: completions sampled and compared per prompt
    max_steps=MAX_STEPS,
    save_steps=SAVE_STEPS,
    logging_steps=10,
    learning_rate=LEARNING_RATE,
    warmup_steps=100,
    max_completion_length=256,
    # A full slate prompt (SLATE_SIZE candidates, up to 500 characters of text each,
    # serialized as indented JSON) far exceeds the default 512-token prompt limit.
    # Over-long prompts are truncated from the left, which would drop the instruction
    # and most candidates. None disables prompt truncation.
    max_prompt_length=None,
    beta=0.0,  # KL penalty disabled as per agents.md
)

print("GRPO konfiguráció kész")


In [None]:
# Initialize and train GRPO trainer
print("GRPO trainer inicializálása...")

# NOTE(review): no tokenizer/processing_class is passed, so GRPOTrainer has to load
# its own (presumably from the model's name_or_path). Confirm for the installed TRL version.
trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward_function,
    args=grpo_config,
    train_dataset=dataset,
)

print("GRPO training indítása...")
trainer.train()

print("Training befejezve")


In [None]:
# Save artifacts
print("Artifactumok mentése...")

# Create output directory (exist_ok=True keeps re-runs safe)
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Save LoRA adapter (`model` is the PeftModel returned by get_peft_model)
model.save_pretrained(OUTPUT_DIR)

# Save tokenizer
tokenizer.save_pretrained(OUTPUT_DIR)

# Save training metrics
# NOTE(review): these are run settings only; "status" is hard-coded to "completed"
# and no values from trainer.state (loss, steps taken) are recorded.
metrics = {
    "training_samples": len(dataset),
    "slate_size": SLATE_SIZE,
    "model_name": MODEL_NAME,
    "max_steps": MAX_STEPS,
    "status": "completed"
}

with open(METRICS_FILE, 'w', encoding='utf-8') as f:
    json.dump(metrics, f, ensure_ascii=False, indent=2)

print(f"Artifactumok mentve: {OUTPUT_DIR}")
print(f"Metrikák: {METRICS_FILE}")


## Training Complete!

A GRPO training befejeződött. Az artifactumok a `/workspace/artifacts/grpo_policy/` könyvtárban találhatók:

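- `adapter_model.safetensors` (régebbi PEFT verziókkal `adapter_model.bin`) - LoRA adapter weights
- `adapter_config.json` - LoRA adapter configuration
- `tokenizer.json` (és a többi tokenizer fájl) - Tokenizer configuration
- `metrics.json` - Training metrics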

Ezeket az artifactumokat le kell tölteni a helyi gépre a `data/models/grpo_policy/` könyvtárba.

## Next Steps:
1. Artifactumok letöltése
2. Helyi tesztelés: `python -m src.search.grpo_reranker`
3. Query pipeline integráció


# CourtRankRL GRPO Training - RunPod

GRPO (Group Relative Policy Optimization) training a Qwen/Qwen3-4B-Instruct-2507 modellen.

## Agents.md Specifikáció:
- Qwen/Qwen3-4B-Instruct-2507 model QLoRA fine-tuninggal
- Group size = slate length
- NDCG@10 alapú reward shaping
- Hungarian progress logging
- Artifact export a `/workspace/artifacts/grpo_policy/` könyvtárba

## Előfeltételek:
- A training slates JSONL exportálva legyen
- HF token beállítva legyen
- Megfelelő GPU memória (24GB+ ajánlott)


In [None]:
# Environment setup
import os
import sys
import json
import torch
from pathlib import Path
from typing import Dict, List, Optional
from dataclasses import dataclass

# Install required packages into the interpreter that runs this kernel.
# `sys.executable -m pip` avoids picking up a different environment's `pip`, and
# check_call raises on a failed install instead of silently ignoring the exit status.
import subprocess

subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "-q",
     "transformers", "peft", "trl", "torch", "bitsandbytes"]
)

# Set paths. The artifact directory matches the export location documented in this
# notebook's header (/workspace/artifacts/grpo_policy/).
ARTIFACTS_DIR = Path("/workspace/artifacts/grpo_policy")
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

# NOTE(review): the first notebook reads /workspace/data/grpo_slates.jsonl; confirm which export is current.
SLATE_FILE = Path("/workspace/training_slates.jsonl")
METRICS_FILE = ARTIFACTS_DIR / "metrics.json"

print("Environment beállítva")


In [None]:
# Hungarian logging setup
import logging

class HungarianLogger:
    """Wrapper around the ``grpo_training`` logger that tags messages for GRPO runs.

    ``logging.getLogger`` returns the same logger object on every call. Attaching a
    new StreamHandler per instance would therefore print every message once per
    instantiation (e.g. after re-running this notebook cell), so the handler is
    attached only when the logger has none yet.
    """

    def __init__(self):
        self.logger = logging.getLogger("grpo_training")
        self.logger.setLevel(logging.INFO)

        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
            self.logger.addHandler(handler)

    def info(self, msg):
        """Log ``msg`` at INFO level with the ``[GRPO]`` prefix."""
        self.logger.info(f"[GRPO] {msg}")

    def warning(self, msg):
        """Log ``msg`` at WARNING level with the Hungarian warning prefix."""
        self.logger.warning(f"[GRPO FIGYELMEZTETÉS] {msg}")

    def error(self, msg):
        """Log ``msg`` at ERROR level with the Hungarian error prefix."""
        self.logger.error(f"[GRPO HIBA] {msg}")

logger = HungarianLogger()
logger.info("Hungarian logging beállítva")


In [None]:
# Configuration (per agents.md)
MODEL_NAME = "Qwen/Qwen3-4B-Instruct-2507"
LORA_RANK = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05

GROUP_SIZE = 8  # Slate length
LEARNING_RATE = 1e-6
NUM_TRAIN_EPOCHS = 3
GRADIENT_ACCUMULATION_STEPS = 4
MAX_GRAD_NORM = 0.1
WARMUP_RATIO = 0.1
LOGGING_STEPS = 10
SAVE_STEPS = 100
EVAL_STEPS = 50
# NOTE: a positive max_steps overrides num_train_epochs in the HF Trainer, so with this
# value NUM_TRAIN_EPOCHS has no effect on the run length.
MAX_STEPS = 1000

REWARD_NDCG_K = 10
REWARD_ENTROPY_BONUS = 0.01
REWARD_CLAMP_NEGATIVE = True

# Treat an unset *or empty* HF_TOKEN as "no explicit token" (None). huggingface_hub then
# falls back to the cached login; an empty string would be used as the token itself.
HF_TOKEN = os.getenv('HF_TOKEN') or None

logger.info(f"Model: {MODEL_NAME}")
logger.info(f"Group size: {GROUP_SIZE}")


In [None]:
# Data loading
def load_training_data(slate_file: Path) -> List[Dict]:
    """Load training slates from a JSONL file (one JSON object per line).

    Blank lines are skipped silently. Malformed lines are logged with their
    1-based line number and skipped. A missing file is logged and yields an empty list.

    Args:
        slate_file: Path of the JSONL export.

    Returns:
        The decoded slate records, in file order.
    """
    data = []

    if not slate_file.exists():
        logger.error(f"Slate fájl nem található: {slate_file}")
        return data

    with open(slate_file, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            stripped = line.strip()
            if not stripped:
                # Not a record (e.g. extra newlines at the end of the export); skip it
                # instead of reporting it as a JSON error.
                continue

            try:
                slate = json.loads(stripped)
            except json.JSONDecodeError as e:
                logger.warning(f"Hibás JSON a {line_num}. sorban: {e}")
                continue

            data.append(slate)

            if line_num % 100 == 0:
                logger.info(f"Betöltve: {line_num} slates")

    logger.info(f"Összesen betöltve: {len(data)} slates")
    return data

training_data = load_training_data(SLATE_FILE)
if not training_data:
    logger.error("Nincs training adat - kilépés")
    sys.exit(1)


In [None]:
# Reward function (NDCG@10 based)
def calculate_ndcg(relevance_scores: List[int], k: int = 10) -> float:
    """Calculate NDCG@k for relevance grades listed in ranked order.

    Uses the standard ``rel / log2(rank + 1)`` discount (1-based rank), the same
    discount as the first notebook's NDCG. The previous ``1 / sqrt(rank)`` discount
    was not NDCG and produced values that are not comparable with it.

    Args:
        relevance_scores: Relevance grade of each document, top rank first.
        k: Cut-off rank.

    Returns:
        DCG@k divided by the ideal DCG@k; 0.0 for empty input or an all-zero ideal DCG.
    """
    import math

    if not relevance_scores:
        return 0.0

    cutoff = max(k, 0)

    def _dcg(values: List[int]) -> float:
        # Sum of discounted gains over the first `cutoff` positions.
        return sum(rel / math.log2(pos + 2) for pos, rel in enumerate(values[:cutoff]))

    idcg = _dcg(sorted(relevance_scores, reverse=True))
    return _dcg(relevance_scores) / idcg if idcg > 0 else 0.0

def _parse_scores(completion: str) -> List[float]:
    """Extract the 0-10 relevance scores from a completion, line by line.

    A leading list enumerator such as ``"1."`` or ``"2)"`` (the format the prompt asks
    for) is dropped so it is not mistaken for a score. A comma between two digits
    (Hungarian decimal comma, ``"7,5"``) is read as a decimal point. Other commas
    separate values.
    """
    import re

    scores: List[float] = []
    for line in completion.strip().split('\n'):
        line = re.sub(r"(?<=\d),(?=\d)", ".", line)
        tokens = line.replace(',', ' ').split()
        if len(tokens) > 1 and re.fullmatch(r"\d+[.):]", tokens[0]):
            tokens = tokens[1:]
        for token in tokens:
            try:
                value = float(token.strip("[]()"))
            except ValueError:
                continue
            if 0 <= value <= 10:  # Reasonable score range
                scores.append(value)
    return scores


def _score_entropy(scores: List[float]) -> float:
    """Shannon entropy (nats) of the positive scores normalised to sum to 1; 0.0 if none."""
    import math

    positive = [s for s in scores if s > 0]
    total = sum(positive)
    if total <= 0:
        return 0.0
    return -sum((s / total) * math.log(s / total) for s in positive)


def reward_function(
    completions: List[str],
    slates: Optional[List[Dict]] = None,
    **kwargs
) -> List[float]:
    """Custom NDCG-based reward: nDCG@k gain over the slate's original order.

    Args:
        completions: Generated completions. These are plain strings, or lists of chat
            messages whose last ``content`` is scored.
        slates: One slate per completion, each with a ``candidates`` list whose items
            carry a ``relevance`` grade. When omitted, the ``slate`` keyword (the key
            SlateDataset items use, forwarded by the trainer) is used instead.
        **kwargs: Other dataset columns forwarded by the trainer (ignored).

    Returns:
        One float reward per completion. This is policy nDCG minus baseline nDCG,
        plus an entropy bonus over the normalised scores. Negative totals are clamped
        to 0 when REWARD_CLAMP_NEGATIVE is set. A completion whose scoring fails
        gets 0.0.

    Raises:
        ValueError: If no slates are available or their count differs from
            ``completions`` (the trainer needs exactly one reward per completion).

    NOTE(review): with REWARD_CLAMP_NEGATIVE, all worse-than-baseline completions in
    a group receive the same reward. Confirm this is intended for GRPO's
    group-relative advantages.
    """
    if slates is None:
        slates = kwargs.get("slate")
    if slates is None or len(slates) != len(completions):
        raise ValueError(
            "reward_function: a slate lista hiányzik vagy a hossza "
            f"({0 if slates is None else len(slates)}) eltér a completionök számától "
            f"({len(completions)})"
        )

    rewards: List[float] = []

    for completion, slate in zip(completions, slates):
        try:
            # Conversational prompts produce a list of chat messages; score the generated text.
            if isinstance(completion, list):
                completion = completion[-1]["content"] if completion else ""

            candidates = slate['candidates']
            scores = _parse_scores(completion)

            if len(scores) != len(candidates):
                logger.warning(f"Score szám mismatch: {len(scores)} vs {len(candidates)}")
                scores = [0.0] * len(candidates)

            true_relevance = [cand['relevance'] for cand in candidates]

            # Predicted ranking: candidate indices by descending score (stable for ties).
            predicted_ranking = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
            predicted_relevance = [true_relevance[i] for i in predicted_ranking]

            policy_ndcg = calculate_ndcg(predicted_relevance, REWARD_NDCG_K)
            # Baseline: the slate's original candidate order.
            baseline_ndcg = calculate_ndcg(true_relevance, REWARD_NDCG_K)

            reward = policy_ndcg - baseline_ndcg

            # Entropy bonus on a proper probability distribution. Raw 0-10 scores give
            # large negative "entropies" that swamp the nDCG signal.
            reward += REWARD_ENTROPY_BONUS * _score_entropy(scores)

            # Clamp negative rewards if configured
            if REWARD_CLAMP_NEGATIVE and reward < 0:
                reward = 0.0

            rewards.append(float(reward))

        except Exception as e:
            logger.warning(f"Reward calculation hiba: {e}")
            rewards.append(0.0)

    return rewards


In [None]:
# Model and tokenizer setup
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
import torch

logger.info("Model betöltése...")

# Quantization config for 4-bit loading: NF4 weights, double quantization, bfloat16 compute
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# Load model and tokenizer (HF_TOKEN comes from the configuration cell)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    token=HF_TOKEN,
    trust_remote_code=True
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=HF_TOKEN,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True
)

# LoRA configuration: adapters on all attention and MLP projections
lora_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)

# Apply LoRA
# NOTE(review): peft.prepare_model_for_kbit_training is not applied, although the trainer
# config later enables gradient checkpointing. Confirm input gradients are enabled.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

logger.info("Model és LoRA beállítva")


In [None]:
# Prompt template
def create_prompt(slate: Dict) -> str:
    """Build the Hungarian scoring prompt for one slate.

    Each candidate is listed with a 1-based number, the court and year taken from its
    first chunk's metadata, and that chunk's text truncated to 500 characters (with
    ``"..."`` appended when cut). The model is asked for one 0-10 score per line.

    NOTE(review): the query is inserted as ``slate['query_id']``. Confirm whether a
    query text field should be used instead.
    """
    limit = 500
    parts = [
        f"A következő bírósági dokumentumokat kell rangsorolnod egy '{slate['query_id']}' keresési lekérdezéshez.\n"
        "\n"
        "Válaszolj minden dokumentumhoz egy 0-10 közötti relevancia pontszámmal (10 = nagyon releváns, 0 = nem releváns).\n"
        "\n"
        "Dokumentumok:\n"
    ]

    for position, candidate in enumerate(slate['candidates'], start=1):
        # The first chunk stands in for the whole document.
        first_chunk = candidate['chunks'][0]
        body = first_chunk['text']
        if len(body) > limit:
            body = body[:limit] + "..."
        meta = first_chunk['metadata']
        parts.append(
            f"\n{position}. Dokumentum (Bíróság: {meta['court']}, Év: {meta['year']})\n"
            f"Szöveg: {body}\n"
        )

    parts.append(
        "\n"
        "Válaszolj minden dokumentumhoz egy sorban, a következő formátumban:\n"
        "1. [pontszám]\n"
        "2. [pontszám]\n"
        "...\n"
        "\n"
        "Pontszámok (0-10): "
    )
    return "".join(parts)

# Test prompt
# Preview the first 500 characters of the prompt built for the first loaded slate.
if training_data:
    test_prompt = create_prompt(training_data[0])
    print("Prompt példa:")
    print(test_prompt[:500] + "...")


In [None]:
# Training dataset preparation
from torch.utils.data import Dataset

class SlateDataset(Dataset):
    """Map-style dataset yielding one GRPO training example per slate.

    GRPOTrainer reads the text prompt from each example's ``"prompt"`` key, so every
    item must carry it. The other keys (notably ``"slate"``) are forwarded to the
    reward function as keyword arguments.
    """

    def __init__(self, data: List[Dict], tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        slate = self.data[idx]
        prompt = create_prompt(slate)

        # Tokenized view kept for backward compatibility. The GRPO trainer tokenizes
        # "prompt" itself and does not need these tensors.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048
        )

        return {
            "prompt": prompt,
            "input_ids": inputs["input_ids"].squeeze(),
            "attention_mask": inputs["attention_mask"].squeeze(),
            "slate": slate
        }

dataset = SlateDataset(training_data, tokenizer)
logger.info(f"Dataset kész: {len(dataset)} példa")


In [None]:
# GRPO Trainer setup
from trl import GRPOTrainer, GRPOConfig

grpo_config = GRPOConfig(
    output_dir=str(ARTIFACTS_DIR),
    num_train_epochs=NUM_TRAIN_EPOCHS,
    # GRPO samples `num_generations` completions per prompt (the "group"). TRL requires
    # the global train batch to be divisible by it, so a per-device batch of 1 cannot work.
    per_device_train_batch_size=GROUP_SIZE,
    num_generations=GROUP_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    max_grad_norm=MAX_GRAD_NORM,
    warmup_ratio=WARMUP_RATIO,
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    eval_steps=EVAL_STEPS,
    max_steps=MAX_STEPS,
    # The default 512-token prompt limit left-truncates slate prompts. Use the same
    # 2048-token budget that SlateDataset tokenizes with.
    max_prompt_length=2048,
    bf16=True,
    gradient_checkpointing=True,
    dataloader_num_workers=2,
)

# The reward function, config and tokenizer are GRPOTrainer constructor arguments
# (`reward_funcs`, `args`, `processing_class`). GRPOConfig has no `group_size` /
# `reward_funcs` fields, and GRPOTrainer has no `config` / `tokenizer` parameters.
trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward_function,
    args=grpo_config,
    train_dataset=dataset,
    processing_class=tokenizer,
)

logger.info("GRPO Trainer beállítva")


In [None]:
# Training execution
logger.info("GRPO training indítása...")
logger.info(f"Train dataset size: {len(dataset)}")

# A positive max_steps overrides num_train_epochs in the HF Trainer, so report it when
# set. Otherwise use a rough single-process, epoch-based estimate from the trainer's own
# arguments (GRPO repeats every prompt num_generations times).
train_args = trainer.args
if train_args.max_steps > 0:
    estimated_steps = train_args.max_steps
else:
    samples_per_epoch = len(dataset) * getattr(train_args, "num_generations", 1)
    estimated_steps = int(train_args.num_train_epochs) * samples_per_epoch // (
        train_args.per_device_train_batch_size * train_args.gradient_accumulation_steps
    )
logger.info(f"Estimated steps: {estimated_steps}")

try:
    trainer.train()
    logger.info("Training sikeresen befejezve")
except Exception as e:
    logger.error(f"Training hiba: {e}")
    raise


In [None]:
# Save model and metrics
logger.info("Model és metrikák mentése...")

# Save LoRA adapter
trainer.save_model(str(ARTIFACTS_DIR))

# Save tokenizer
tokenizer.save_pretrained(str(ARTIFACTS_DIR))

# After train(), the last log_history entry is the run summary, which carries
# "train_loss" but no "loss". Take the most recent per-step loss instead, then fall
# back to the summary value.
log_history = trainer.state.log_history or []
final_loss = next((entry["loss"] for entry in reversed(log_history) if "loss" in entry), None)
if final_loss is None:
    final_loss = next(
        (entry["train_loss"] for entry in reversed(log_history) if "train_loss" in entry),
        0.0,
    )

# Save training metrics
metrics = {
    "model_name": MODEL_NAME,
    "training_samples": len(dataset),
    "epochs": NUM_TRAIN_EPOCHS,
    "group_size": GROUP_SIZE,
    "final_loss": final_loss,
    # global_step counts optimizer steps taken; len(log_history) only counts log entries.
    "total_steps": trainer.state.global_step,
}

with open(METRICS_FILE, 'w', encoding='utf-8') as f:
    json.dump(metrics, f, ensure_ascii=False, indent=2)

logger.info(f"Artifacts mentve: {ARTIFACTS_DIR}")
logger.info(f"Metrikák mentve: {METRICS_FILE}")


In [None]:
# Final summary
# Logs the values stored in `metrics` by the save cell; final_loss is formatted as a float.
logger.info("=== TRAINING ÖSSZEGZÉS ===")
logger.info(f"Model: {MODEL_NAME}")
logger.info(f"Training samples: {len(dataset)}")
logger.info(f"Final loss: {metrics['final_loss']:.4f}")
logger.info(f"Total steps: {metrics['total_steps']}")
logger.info(f"Artifacts: {ARTIFACTS_DIR}")
logger.info("GRPO training sikeresen befejezve!")
