# CourtRankRL GRPO Training - Runpod

GRPO (Group Relative Policy Optimization) training a Qwen/Qwen3-4B-Instruct-2507 modellen.

## Agents.md Specifikáció:
- Qwen/Qwen3-4B-Instruct-2507 model QLoRA fine-tuninggal
- Group size = slate length
- NDCG@10 alapú reward shaping
- Hungarian progress logging
- Artifact export a /workspace/artifacts/grpo_policy/ könyvtárba (a notebook kódja jelenleg a /workspace/grpo_policy könyvtárba ment)

## Előfeltételek:
- A training slates JSONL exportálva legyen
- HF token beállítva legyen
- Megfelelő GPU memória (24GB+ ajánlott)


In [None]:
# Environment setup
import json
import math
import os
import re
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import torch

# Install required packages.
# Invoke pip through the interpreter that runs this notebook so the packages
# land in the environment the kernel imports from; a bare `pip` found on PATH
# may belong to a different Python installation.
_pip_result = subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "-q",
        "transformers", "peft", "trl", "torch", "bitsandbytes",
    ],
    check=False,
)
if _pip_result.returncode != 0:
    # Installation stays best-effort (the packages may already be present),
    # but a failure is reported instead of being silently ignored.
    print(f"FIGYELMEZTETÉS: a pip install sikertelen (kilépési kód: {_pip_result.returncode})")

# Set paths
# NOTE(review): the notebook header names /workspace/artifacts/grpo_policy/ as
# the export location, while this path omits "artifacts" - confirm which one
# downstream consumers expect.
ARTIFACTS_DIR = Path("/workspace/grpo_policy")
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

SLATE_FILE = Path("/workspace/training_slates.jsonl")
METRICS_FILE = ARTIFACTS_DIR / "metrics.json"

print("Environment beállítva")


In [None]:
# Hungarian logging setup
import logging

class HungarianLogger:
    """Wrapper around the ``grpo_training`` logger that prefixes every message.

    Messages are emitted at INFO level or above through a stream handler using
    the ``'%(asctime)s - %(message)s'`` format. Each method prepends a fixed
    tag (``[GRPO]``, ``[GRPO FIGYELMEZTETÉS]``, ``[GRPO HIBA]``) to the text.
    """

    def __init__(self) -> None:
        self.logger = logging.getLogger("grpo_training")
        self.logger.setLevel(logging.INFO)

        # logging.getLogger returns the same object for the same name, so
        # creating this wrapper again (e.g. re-running the notebook cell)
        # must not stack another handler - that would print every line twice.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
            self.logger.addHandler(handler)

    def info(self, msg) -> None:
        """Log ``msg`` at INFO level with the ``[GRPO]`` prefix."""
        self.logger.info(f"[GRPO] {msg}")

    def warning(self, msg) -> None:
        """Log ``msg`` at WARNING level with the ``[GRPO FIGYELMEZTETÉS]`` prefix."""
        self.logger.warning(f"[GRPO FIGYELMEZTETÉS] {msg}")

    def error(self, msg) -> None:
        """Log ``msg`` at ERROR level with the ``[GRPO HIBA]`` prefix."""
        self.logger.error(f"[GRPO HIBA] {msg}")


logger = HungarianLogger()
logger.info("Hungarian logging beállítva")


In [None]:
# Configuration (based on agents.md)
MODEL_NAME = "Qwen/Qwen3-4B-Instruct-2507"

# LoRA adapter hyper-parameters
LORA_RANK = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05

# Optimisation / trainer schedule
GROUP_SIZE = 8  # Slate length
LEARNING_RATE = 1e-6
NUM_TRAIN_EPOCHS = 3
GRADIENT_ACCUMULATION_STEPS = 4
MAX_GRAD_NORM = 0.1
WARMUP_RATIO = 0.1
LOGGING_STEPS = 10
SAVE_STEPS = 100
EVAL_STEPS = 50
MAX_STEPS = 1000

# Reward shaping
REWARD_NDCG_K = 10
REWARD_ENTROPY_BONUS = 0.01
REWARD_CLAMP_NEGATIVE = True

# Normalise an unset or empty HF_TOKEN to None. An empty string would be
# handed to from_pretrained(token=...) as an explicit (and invalid) token
# rather than meaning "no token supplied".
HF_TOKEN = os.getenv('HF_TOKEN') or None
if HF_TOKEN is None:
    logger.warning("HF_TOKEN nincs beállítva")

logger.info(f"Model: {MODEL_NAME}")
logger.info(f"Group size: {GROUP_SIZE}")


In [None]:
# Data loading
def load_training_data(slate_file: Path) -> List[Dict]:
    """Load training slates from a JSONL file.

    Args:
        slate_file: Path of a UTF-8 text file holding one JSON object per line.

    Returns:
        The parsed slates in file order. The list is empty when the file does
        not exist (an error is logged in that case).

    Blank lines are skipped silently. Lines that are not valid JSON, or whose
    JSON value is not an object, are logged with their line number and skipped.
    """
    data: List[Dict] = []

    if not slate_file.exists():
        logger.error(f"Slate fájl nem található: {slate_file}")
        return data

    with open(slate_file, 'r', encoding='utf-8') as f:
        for line_num, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            if not line:
                # A trailing newline or spacer line is not a malformed record.
                continue

            try:
                slate = json.loads(line)
            except json.JSONDecodeError as e:
                logger.warning(f"Hibás JSON a {line_num}. sorban: {e}")
                continue

            if not isinstance(slate, dict):
                # Later cells index slates by key, so only objects are usable.
                logger.warning(f"Nem JSON objektum a {line_num}. sorban - kihagyva")
                continue

            data.append(slate)

            # Report progress by slates actually loaded, not by lines read.
            if len(data) % 100 == 0:
                logger.info(f"Betöltve: {len(data)} slates")

    logger.info(f"Összesen betöltve: {len(data)} slates")
    return data


# Load the slates and stop the run when there is nothing to train on.
training_data = load_training_data(SLATE_FILE)
if not training_data:
    logger.error("Nincs training adat - kilépés")
    sys.exit(1)


In [None]:
# Reward function (NDCG@10 based)
def calculate_ndcg(relevance_scores: List[int], k: int = 10) -> float:
    """Calculate NDCG@k for a ranked list of graded relevance labels.

    Uses linear gain and the standard logarithmic position discount:
    ``DCG@k = sum(rel_i / log2(i + 1))`` for 1-based rank ``i``. The ideal DCG
    is obtained from the same labels sorted in descending order.

    Args:
        relevance_scores: Relevance labels in ranked order (best rank first).
        k: Cut-off; only the first ``k`` positions contribute.

    Returns:
        DCG@k divided by the ideal DCG@k, or 0.0 when the list is empty,
        ``k`` is not positive, or the ideal DCG is not positive.
    """
    if not relevance_scores or k <= 0:
        return 0.0

    def _dcg(ranked: List[int]) -> float:
        # enumerate() is 0-based, so rank i + 1 has discount log2(i + 2).
        return sum(rel / math.log2(pos + 2) for pos, rel in enumerate(ranked[:k]))

    idcg = _dcg(sorted(relevance_scores, reverse=True))
    if idcg <= 0:
        return 0.0
    return _dcg(relevance_scores) / idcg

# "3. " / "3) " / "3: " at the start of a line: the document number that the
# prompt asks the model to echo. It must not be mistaken for a score.
_LEADING_INDEX_RE = re.compile(r"^\s*\d+\s*[.):](?:\s+|$)")
# A (possibly negative) number with an optional decimal part; accepts both the
# decimal point and the Hungarian decimal comma.
_NUMBER_RE = re.compile(r"-?\d+(?:[.,]\d+)?")


def _parse_scores(completion: str) -> List[float]:
    """Extract the 0-10 scores from a completion, skipping list indices."""
    scores: List[float] = []
    for line in completion.strip().splitlines():
        indexed = _LEADING_INDEX_RE.match(line)
        if indexed:
            # "N. <score>" line: the first number after the index is the score.
            tokens = _NUMBER_RE.findall(line[indexed.end():])[:1]
        else:
            # Un-numbered line: accept every number (e.g. "7 5 3" or "7, 5, 3").
            tokens = _NUMBER_RE.findall(line)
        for token in tokens:
            value = float(token.replace(',', '.'))
            if 0 <= value <= 10:  # Reasonable score range
                scores.append(value)
    return scores


def _score_entropy(scores: List[float]) -> float:
    """Shannon entropy (in nats) of the scores normalised to sum to 1."""
    total = sum(scores)
    if total <= 0:
        return 0.0
    return -sum((s / total) * math.log(s / total) for s in scores if s > 0)


def _slate_reward(completion: str, slate: Dict) -> float:
    """Reward for one completion scored against one slate."""
    true_relevance = [cand['relevance'] for cand in slate['candidates']]
    scores = _parse_scores(completion)

    if len(scores) != len(true_relevance):
        # Without exactly one score per candidate no ranking can be derived.
        logger.warning(f"Score szám mismatch: {len(scores)} vs {len(true_relevance)}")
        return 0.0

    # Rank candidates by predicted score; the sort is stable, so ties keep
    # their original slate order.
    predicted_ranking = sorted(
        range(len(scores)), key=scores.__getitem__, reverse=True
    )
    predicted_relevance = [true_relevance[i] for i in predicted_ranking]

    policy_ndcg = calculate_ndcg(predicted_relevance, REWARD_NDCG_K)
    # Baseline: NDCG of the slate in its original candidate order.
    baseline_ndcg = calculate_ndcg(true_relevance, REWARD_NDCG_K)

    reward = policy_ndcg - baseline_ndcg
    reward += REWARD_ENTROPY_BONUS * _score_entropy(scores)

    # Clamp negative rewards if configured
    if REWARD_CLAMP_NEGATIVE and reward < 0:
        reward = 0.0
    return reward


def reward_function(
    completions: List[str],
    slates: Optional[List[Dict]] = None,
    **kwargs
) -> List[float]:
    """Custom NDCG-based reward function.

    For every (completion, slate) pair the completion is parsed into one score
    per candidate, the candidates are ranked by those scores, and the reward is
    ``NDCG@REWARD_NDCG_K(predicted order) - NDCG@REWARD_NDCG_K(original order)``
    plus ``REWARD_ENTROPY_BONUS`` times the entropy of the normalised scores.
    Negative rewards become 0.0 when ``REWARD_CLAMP_NEGATIVE`` is set.

    Args:
        completions: Generated texts, one per sample.
        slates: Slate dicts aligned with ``completions``; each needs a
            ``candidates`` list whose items carry a ``relevance`` label. When
            omitted, the ``slate`` keyword argument is used instead (the key
            under which the dataset in this notebook stores the slate).
        **kwargs: Extra keyword arguments supplied by the trainer; ignored
            apart from the ``slate`` fallback.

    Returns:
        One reward per (completion, slate) pair. A completion whose number of
        parsed scores differs from the number of candidates, or a malformed
        slate/completion, yields 0.0 (with a logged warning).

    Raises:
        TypeError: If neither ``slates`` nor a ``slate`` keyword is given.

    Note:
        Parsing treats "7,5" as the decimal 7.5, so a single-line answer
        written as "7,5,3" without spaces is ambiguous.
    """
    if slates is None:
        slates = kwargs.get("slate")
    if slates is None:
        raise TypeError("reward_function() requires 'slates' (or a 'slate' keyword argument)")

    rewards = []
    for completion, slate in zip(completions, slates):
        try:
            rewards.append(_slate_reward(completion, slate))
        except (AttributeError, IndexError, KeyError, TypeError, ValueError) as e:
            # One bad sample must not abort the whole training step.
            logger.warning(f"Reward calculation hiba: {e}")
            rewards.append(0.0)

    return rewards


In [None]:
# Model and tokenizer setup
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
import torch

# Needed only by this cell; imported here so the cell stays self-contained.
from peft import prepare_model_for_kbit_training

logger.info("Model betöltése...")

# Quantization config for 4-bit loading (NF4 with double quantization,
# bfloat16 compute)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# Load model and tokenizer
# NOTE(review): trust_remote_code=True allows Python code shipped in the model
# repository to run locally - confirm it is still required for this model.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    token=HF_TOKEN,
    trust_remote_code=True
)
# Padding requires a pad token; fall back to EOS when the tokenizer has none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=HF_TOKEN,
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True
)

# Prepare the quantized base model for training before attaching adapters
# (PEFT's documented step for k-bit training; it also wires up gradient
# checkpointing so gradients can flow through the frozen, quantized layers).
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# LoRA configuration: adapters on the attention and MLP projection layers
lora_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)

# Apply LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

logger.info("Model és LoRA beállítva")


In [None]:
# Prompt template
def create_prompt(slate: Dict) -> str:
    """Build the Hungarian scoring prompt for one slate.

    The prompt consists of an instruction header quoting ``slate['query_id']``,
    one numbered entry per candidate (court, year and the text of the
    candidate's first chunk, cut to 500 characters plus "..." when longer),
    and a footer describing the expected "N. [score]" answer format.

    Args:
        slate: Mapping with ``query_id`` and ``candidates``; every candidate
            needs ``chunks[0]`` with ``text`` and ``metadata`` (``court``,
            ``year``).

    Returns:
        The complete prompt string.
    """
    max_chars = 500
    query = slate['query_id']

    pieces = [f"""A következő bírósági dokumentumokat kell rangsorolnod egy '{query}' keresési lekérdezéshez.

Válaszolj minden dokumentumhoz egy 0-10 közötti relevancia pontszámmal (10 = nagyon releváns, 0 = nem releváns).

Dokumentumok:
"""]

    for position, candidate in enumerate(slate['candidates'], start=1):
        # Only the first chunk represents the document in the prompt.
        first_chunk = candidate['chunks'][0]
        text = first_chunk['text']
        if len(text) > max_chars:
            text = text[:max_chars] + "..."
        meta = first_chunk['metadata']

        pieces.append(f"""
{position}. Dokumentum (Bíróság: {meta['court']}, Év: {meta['year']})
Szöveg: {text}
""")

    pieces.append("""
Válaszolj minden dokumentumhoz egy sorban, a következő formátumban:
1. [pontszám]
2. [pontszám]
...

Pontszámok (0-10): """)

    return "".join(pieces)

# Preview: print the first 500 characters of the prompt built for the first slate
if training_data:
    first_slate = training_data[0]
    test_prompt = create_prompt(first_slate)
    print("Prompt példa:", f"{test_prompt[:500]}...", sep="\n")


In [None]:
# Training dataset preparation
from torch.utils.data import Dataset

class SlateDataset(Dataset):
    """Map-style dataset that yields one prompt (text and token ids) per slate.

    Args:
        data: Slate dicts as accepted by ``create_prompt``.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors when called with ``return_tensors="pt"``.
        max_length: Token limit applied when truncating the prompt.
    """

    def __init__(self, data: List[Dict], tokenizer, max_length: int = 2048):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx):
        """Return the prompt, its token tensors and the source slate for ``idx``."""
        slate = self.data[idx]
        prompt = create_prompt(slate)

        # Tokenize a single sequence. No padding is requested: with one
        # sequence there is nothing to pad against.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
        )

        return {
            # Raw prompt text; TRL's GRPO trainer documents a "prompt" field
            # as the input it generates completions from.
            "prompt": prompt,
            # Drop only the leading batch axis; a bare squeeze() would also
            # collapse a sequence of length 1 into a 0-d tensor.
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "slate": slate,
        }


dataset = SlateDataset(training_data, tokenizer)
logger.info(f"Dataset kész: {len(dataset)} példa")


In [None]:
# GRPO Trainer setup
from trl import GRPOTrainer, GRPOConfig

# Training arguments.
# - The group size is expressed through `num_generations` (GRPOConfig has no
#   `group_size` argument), and reward functions are passed to the trainer,
#   not to the config.
# - The per-device batch holds one full group so that the batch size is
#   divisible by `num_generations`, which TRL validates.
#   NOTE(review): this raises memory use compared with a batch of 1 - confirm
#   it fits the target GPU or adjust batch/accumulation together.
# - NOTE(review): a positive `max_steps` takes precedence over
#   `num_train_epochs` in the Hugging Face trainer; `eval_steps` has no effect
#   without an eval dataset - confirm both are intended.
# - NOTE(review): confirm the installed TRL version's prompt-length limit
#   covers these prompts (the dataset truncates at 2048 tokens).
grpo_config = GRPOConfig(
    output_dir=str(ARTIFACTS_DIR),
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=GROUP_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    max_grad_norm=MAX_GRAD_NORM,
    warmup_ratio=WARMUP_RATIO,
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    eval_steps=EVAL_STEPS,
    max_steps=MAX_STEPS,
    bf16=True,
    gradient_checkpointing=True,
    dataloader_num_workers=2,
    num_generations=GROUP_SIZE,
)

# GRPOTrainer takes the config as `args`, the tokenizer as `processing_class`
# and the reward callable(s) as `reward_funcs`.
trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward_function,
    args=grpo_config,
    train_dataset=dataset,
    processing_class=tokenizer,
)

logger.info("GRPO Trainer beállítva")


In [None]:
# Training execution
# Announce the run and report its size before starting.
logger.info("GRPO training indítása...")
num_examples = len(dataset)
logger.info(f"Train dataset size: {num_examples}")

# Rough step count: examples x epochs divided by the accumulation steps,
# with the per-device batch size written as the literal 1.
effective_batch = 1 * GRADIENT_ACCUMULATION_STEPS
estimated_steps = num_examples * NUM_TRAIN_EPOCHS // effective_batch
logger.info(f"Estimated steps: {estimated_steps}")

try:
    trainer.train()
    logger.info("Training sikeresen befejezve")
except Exception as exc:
    # Log the failure in the notebook output, then let it propagate.
    logger.error(f"Training hiba: {exc}")
    raise


In [None]:
# Save model and metrics
logger.info("Model és metrikák mentése...")

# Save LoRA adapter
trainer.save_model(str(ARTIFACTS_DIR))

# Save tokenizer
tokenizer.save_pretrained(str(ARTIFACTS_DIR))

# Find the most recent logged training loss. The last log entry cannot be
# used directly: after train() it is typically the end-of-run summary, which
# carries "train_loss" but no "loss" key.
log_history = trainer.state.log_history or []
final_loss = next(
    (entry["loss"] for entry in reversed(log_history) if "loss" in entry),
    None,
)
if final_loss is None:
    # Fall back to the run-average loss from the summary entry, if present.
    final_loss = next(
        (entry["train_loss"] for entry in reversed(log_history) if "train_loss" in entry),
        0.0,
    )

# Save training metrics
metrics = {
    "model_name": MODEL_NAME,
    "training_samples": len(dataset),
    "epochs": NUM_TRAIN_EPOCHS,
    "group_size": GROUP_SIZE,
    "final_loss": final_loss,
    # Steps actually performed, not the number of log entries.
    "total_steps": trainer.state.global_step,
}

with open(METRICS_FILE, 'w', encoding='utf-8') as f:
    json.dump(metrics, f, ensure_ascii=False, indent=2)

logger.info(f"Artifacts mentve: {ARTIFACTS_DIR}")
logger.info(f"Metrikák mentve: {METRICS_FILE}")


In [None]:
# Final summary
# Emit the closing report, one log line per item.
summary_messages = (
    "=== TRAINING ÖSSZEGZÉS ===",
    f"Model: {MODEL_NAME}",
    f"Training samples: {len(dataset)}",
    f"Final loss: {metrics['final_loss']:.4f}",
    f"Total steps: {metrics['total_steps']}",
    f"Artifacts: {ARTIFACTS_DIR}",
    "GRPO training sikeresen befejezve!",
)
for message in summary_messages:
    logger.info(message)
