# CourtRankRL GRPO Training - Chunk-Based, RTX 5090 Optimized

## Agents.md Specification (Chunk-Based)

This notebook trains the CourtRankRL GRPO-based reranking model on an **RTX 5090 GPU** (32 GB VRAM).

### Key features (chunk-based approach):
- **Model**: Qwen/Qwen3-4B-Instruct-2507 (4-bit) + QLoRA (rank=64, alpha=128)
- **Training**: TRL GRPOTrainer GRPO algoritmussal
  - Loss: "dapo" (eliminates length bias)
  - Reward scaling: "batch" (robust - PPO Lite)
  - Importance sampling: "sequence" (stable - GSPO)
- **Dataset**: 98 queries (full set), 20 chunks/slate, **full chunk text** (~500-800 chars)
- **Slate strategy**: chunk-level retrieval (no document-level aggregation!) → the most relevant chunks
- **Baseline**: the slate's original order, i.e. the fusion ranking [0,1,2,...] (from BM25+FAISS fusion)
- **Hardware**: Batch size 4, grad accumulation 3, 14 generations/prompt
- **Training time**: ~15-25 minutes (600 steps, with vLLM)
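The `importance_sampling_level="sequence"` setting refers to GSPO (Group Sequence Policy Optimization): instead of one importance ratio per token, each completion gets a single ratio, the length-normalized sequence likelihood ratio `(pi_new(y|x) / pi_old(y|x)) ** (1/|y|)`, which equals the exponential of the mean per-token log-ratio. The NumPy sketch below computes both variants from per-token log-probabilities; the function names and the example arrays are illustrative assumptions, not TRL internals.

```python
import numpy as np

def token_level_ratios(logp_new: np.ndarray, logp_old: np.ndarray) -> np.ndarray:
    """Per-token importance ratios pi_new(y_t) / pi_old(y_t), as in token-level GRPO."""
    return np.exp(logp_new - logp_old)

def sequence_level_ratio(logp_new: np.ndarray, logp_old: np.ndarray) -> float:
    """GSPO-style ratio: (pi_new(y) / pi_old(y)) ** (1 / |y|),
    i.e. exp of the mean per-token log-ratio over the completion."""
    return float(np.exp(np.mean(logp_new - logp_old)))

# Hypothetical per-token log-probabilities for one 3-token completion
logp_old = np.log(np.array([0.5, 0.4, 0.2]))
logp_new = logp_old + np.array([0.1, -0.1, 0.3])

print(token_level_ratios(logp_new, logp_old))    # one ratio per token
print(sequence_level_ratio(logp_new, logp_old))  # one ratio for the whole completion, exp(0.1)
```

Because every token of a completion shares the same ratio in this formulation, the `epsilon` clipping effectively acts on whole completions rather than on individual tokens.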

### Why chunk-based?
- ✅ **Relevant context**: BM25+FAISS has already selected the most relevant chunks
- ✅ **Full text**: the model sees WHY a document is relevant
- ✅ **Better learning**: the model learns to judge the actual content, not just the metadata


In [None]:
# Environment setup and package installation
%pip install -q --upgrade pip
# RTX 5090 (Blackwell, sm_120) needs a PyTorch build for CUDA 12.8+; cu121 wheels lack its kernels
%pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
%pip install -q transformers accelerate datasets huggingface_hub
%pip install -q numpy scipy scikit-learn pandas
%pip install -q peft bitsandbytes
%pip install -q trl
%pip install -q "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# Note: vllm pins its own torch version and may replace the wheel installed above
%pip install -q vllm
%pip install -q ranx

print("✅ Csomagok telepítve (kompatibilis verziók)")
print("⚠️  FONTOS: RESTART RUNTIME szükséges a használat előtt!")
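After the restart it is worth confirming that the installed PyTorch wheel actually ships kernels for the GPU, since the RTX 5090 reports compute capability 12.0 (`sm_120`). The sketch below compares `torch.cuda.get_device_capability()` with `torch.cuda.get_arch_list()`; the helper `arch_supported` is hypothetical and only checks for an exact native-kernel match.

```python
try:
    import torch
except ImportError:  # lets the helper be used without PyTorch installed
    torch = None

def arch_supported(capability, arch_list):
    """True if the build ships native kernels (sm_XY) for this compute capability.

    Only exact matches are checked; PTX JIT compilation for newer GPUs is ignored.
    """
    major, minor = capability
    return f"sm_{major}{minor}" in arch_list

if torch is not None and torch.cuda.is_available():
    cap = torch.cuda.get_device_capability(0)
    archs = torch.cuda.get_arch_list()
    status = "OK" if arch_supported(cap, archs) else "MISSING"
    print(f"sm_{cap[0]}{cap[1]} kernels in this build (CUDA {torch.version.cuda}): {status}")
```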


In [None]:
# Imports
import os
import json
import sys
import re
import random
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from unsloth import FastLanguageModel
from trl.trainer.grpo_trainer import GRPOTrainer
from trl.trainer.grpo_config import GRPOConfig
from huggingface_hub import login
from sklearn.metrics import ndcg_score
from scipy.stats import entropy as scipy_entropy
from sklearn.model_selection import train_test_split
from ranx import Qrels, Run, evaluate

print("✅ Importok betöltve (Unsloth + TRL + sklearn + scipy + ranx)")
print(f"PyTorch verzió: {torch.__version__}")
print(f"CUDA elérhető: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memória: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


In [None]:
# Hugging Face login
hf_token = os.getenv("HUGGINGFACE_TOKEN")
if hf_token:
    login(token=hf_token)
    print("✅ HuggingFace bejelentkezés sikeres")
else:
    print("⚠️ Nincs HUGGINGFACE_TOKEN, a modell letöltése korlátozott lehet")


In [None]:
# Configuration
MODEL_NAME = "unsloth/Qwen3-4B-Instruct-2507"
SLATE_SIZE = 20
GROUP_SIZE = 8
LORA_RANK = 64
LORA_ALPHA = 128
LORA_DROPOUT = 0.05
MAX_SEQ_LENGTH = 8192
GPU_MEMORY_UTILIZATION = 0.88
USE_GRADIENT_CHECKPOINTING = "unsloth"

LEARNING_RATE = 5e-5
MAX_STEPS = 600
SAVE_STEPS = 600
EVAL_STEPS = 50
LOGGING_STEPS = 10
WARMUP_STEPS = 50
GRADIENT_ACCUMULATION_STEPS = 3
NUM_GENERATIONS = 14
PER_DEVICE_BATCH_SIZE = 4
GENERATION_BATCH_SIZE = PER_DEVICE_BATCH_SIZE * NUM_GENERATIONS
OPTIMIZER_NAME = "paged_adamw_8bit"
LR_SCHEDULER_TYPE = "cosine"

NDCG_K = 10
ENTROPY_BONUS = 0.01
REWARD_CLIP_MIN = -1.0
REWARD_CLIP_MAX = 1.0
TRAIN_SPLIT = 0.8
SEED = 42

BASE_PATH = Path(os.getenv("WORKSPACE_PATH", "/workspace"))
SLATE_FILE = BASE_PATH / "training_slates.jsonl"
OUTPUT_DIR = BASE_PATH / "artifacts" / "grpo_policy"
METRICS_FILE = OUTPUT_DIR / "metrics.json"

print("📋 RTX 5090 + Unsloth Konfiguráció:")
print(f"  Model: {MODEL_NAME}")
print(f"  Batch: {PER_DEVICE_BATCH_SIZE} × {GRADIENT_ACCUMULATION_STEPS} = {PER_DEVICE_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
print(f"  Steps: {MAX_STEPS}, Generations: {NUM_GENERATIONS}")

if not SLATE_FILE.exists():
    raise FileNotFoundError(f"❌ Slate fájl nem található: {SLATE_FILE}")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
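In TRL's GRPO, `per_device_train_batch_size` counts completions rather than prompts, and the generation batch (completions sampled per generation round) has to be divisible by `num_generations`; by default one round feeds one optimizer step, i.e. `per_device_train_batch_size * gradient_accumulation_steps` completions. The sketch below reproduces that arithmetic for the values above. `grpo_batch_layout` is a hypothetical helper modelled on the documented constraint, not TRL code, and it assumes a single process (one GPU).

```python
from typing import Optional

def grpo_batch_layout(per_device_bs: int, grad_accum: int, num_generations: int,
                      generation_batch_size: Optional[int] = None,
                      num_processes: int = 1) -> dict:
    """Check GRPO batch divisibility and report how one generation round is split."""
    global_bs = per_device_bs * num_processes
    if generation_batch_size is None:
        # Default: one generation round covers one optimizer step
        generation_batch_size = global_bs * grad_accum
    if generation_batch_size % global_bs != 0:
        raise ValueError("generation batch must be a multiple of the global batch")
    if generation_batch_size % num_generations != 0:
        raise ValueError(f"generation batch {generation_batch_size} is not "
                         f"divisible by num_generations={num_generations}")
    return {
        "generation_batch_size": generation_batch_size,
        "prompts_per_generation": generation_batch_size // num_generations,
        "steps_per_generation": generation_batch_size // global_bs,
    }

default_error = None
try:
    grpo_batch_layout(4, 3, 14)  # 4 * 3 = 12 completions, not a multiple of 14
except ValueError as err:
    default_error = str(err)
print("default layout:", default_error)

# 4 prompts x 14 completions = 56, consumed as 14 micro-batches of 4 completions
print(grpo_batch_layout(4, 3, 14, generation_batch_size=4 * 14))
```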


In [None]:
# Helper functions
def calculate_ndcg(ranked_indices: List[int], true_relevance: List[float], k: int = 10) -> float:
    """nDCG@k of a ranking (slate indices, best first) against graded relevance labels."""
    if not true_relevance or not ranked_indices or max(true_relevance) == 0:
        return 0.0
    y_true = np.array(true_relevance)
    max_score = len(ranked_indices)
    # Turn the ranking into descending scores; items outside the top-k keep score 0
    y_score = np.zeros_like(y_true, dtype=float)
    for i, idx in enumerate(ranked_indices[:k]):
        if 0 <= idx < len(y_true):
            y_score[idx] = max_score - i
    if np.sum(y_score) == 0:
        return 0.0
    try:
        return float(ndcg_score(y_true.reshape(1, -1), y_score.reshape(1, -1), k=k))
    except ValueError:
        return 0.0

def parse_model_ranking(completion: str, slate_size: int = SLATE_SIZE) -> List[int]:
    """Parse "3,0,7,..." into slate indices; falls back to a random permutation
    when fewer than half of the indices can be parsed."""
    try:
        numbers = [int(x.strip()) for x in completion.split(",") if x.strip().isdigit()]
        valid_numbers = [n for n in numbers if 0 <= n < slate_size]
        if len(valid_numbers) >= slate_size // 2:
            return valid_numbers[:slate_size]
    except ValueError:
        pass
    indices = list(range(slate_size))
    random.shuffle(indices)
    return indices

def create_training_prompt(query_id: str, slate: List[Dict]) -> str:
    prompt = f'''# Document Relevance Ranking Task

TASK: Rank the following document excerpts by relevance to the query.

QUERY: "{query_id}"

CANDIDATES ({len(slate)} items):

'''
    for idx, doc in enumerate(slate):
        chunk_text = doc.get('text', '')[:800]
        prompt += f'''[{idx}] Doc: {doc.get('doc_id', 'N/A')} | Chunk: {doc.get('chunk_id', 'N/A')}
Bíróság: {doc.get('court', 'N/A')} | Terület: {doc.get('domain', 'N/A')} | Év: {doc.get('year', 'N/A')}
BM25: {doc.get('bm25_score', 0):.2f} | FAISS: {doc.get('faiss_score', 0):.3f}
Szöveg: {chunk_text}

'''
    example_indices = ",".join(str(i) for i in range(len(slate)))
    prompt += f'''INSTRUCTION:
Rank the documents by query relevance (starting from 0).
Provide the ranking as comma-separated indices.
Example format: {example_indices}
Use each index 0-{len(slate)-1} exactly once.

RANKING:'''
    return prompt

print("✅ Segédfüggvények definiálva")
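`calculate_ndcg` converts a ranking into descending scores and delegates to `sklearn.metrics.ndcg_score`, which uses linear gains and a log2 rank discount. For checking individual rewards by hand, the same quantity can be written out directly. The sketch below (hypothetical `ndcg_at_k`, NumPy only) assumes non-negative relevance labels and a ranking of slate indices, best first, without repeated indices.

```python
import numpy as np

def ndcg_at_k(ranking, relevance, k=10):
    """nDCG@k with linear gain rel_i and discount 1 / log2(rank + 1), ranks starting at 1."""
    rel = np.asarray(relevance, dtype=float)
    top = [i for i in ranking[:k] if 0 <= i < len(rel)]
    discounts = 1.0 / np.log2(np.arange(2, len(top) + 2))
    dcg = float(np.sum(rel[top] * discounts))
    ideal = np.sort(rel)[::-1][:k]  # best possible ordering of the labels
    idcg = float(np.sum(ideal / np.log2(np.arange(2, len(ideal) + 2))))
    return dcg / idcg if idcg > 0 else 0.0

relevance = [0, 1, 2]               # slate item 2 is the most relevant
print(ndcg_at_k([2, 1, 0], relevance))  # ideal order -> 1.0
print(ndcg_at_k([0, 1, 2], relevance))  # fusion order in this toy example
```

For the second call, DCG = 1/log2(3) + 2/log2(4) and IDCG = 2 + 1/log2(3), so the value is their ratio (about 0.62).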


In [None]:
# Load slate data
print(f"📂 Slate adatok betöltése: {SLATE_FILE}")
df_slates = pd.read_json(SLATE_FILE, lines=True, encoding='utf-8')
slates_data = df_slates.to_dict('records')
print(f"✅ Betöltve: {len(slates_data)} slate")

sample = slates_data[0]
print(f"\n📋 Minta slate struktúra:")
print(f"  Query ID: {sample['query_id'][:50]}...")
print(f"  Slate elemek: {len(sample['slate'])}")
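The rest of the notebook reads a fixed set of fields from each JSONL record: `query_id` and `slate` at the top level, and per slate item `doc_id`, `chunk_id`, `text` and `relevance`, plus optional metadata (`court`, `domain`, `year`, `bm25_score`, `faiss_score`). A light validation pass catches schema drift before the expensive steps. In the sketch below, `validate_slate_record` and the choice of which fields count as required are assumptions, and the example record is made up; note that two chunks may legitimately share a `doc_id`.

```python
from typing import Dict, List

REQUIRED_ITEM_FIELDS = ("doc_id", "chunk_id", "text", "relevance")  # assumption

def validate_slate_record(record: Dict, slate_size: int = 20) -> List[str]:
    """Return a list of human-readable problems; an empty list means the record looks usable."""
    problems = []
    if not isinstance(record.get("query_id"), str) or not record["query_id"]:
        problems.append("missing or empty query_id")
    slate = record.get("slate")
    if not isinstance(slate, list):
        return problems + ["slate is not a list"]
    if len(slate) != slate_size:
        problems.append(f"slate has {len(slate)} items, expected {slate_size}")
    for i, item in enumerate(slate):
        missing = [f for f in REQUIRED_ITEM_FIELDS if f not in item]
        if missing:
            problems.append(f"item {i} missing {missing}")
    if not any(item.get("relevance", 0) > 0 for item in slate):
        problems.append("no relevant item: nDCG is 0 for every ranking")
    return problems

# Hypothetical two-item record: both chunks come from the same document
example = {"query_id": "hypothetical query", "slate": [
    {"doc_id": "d1", "chunk_id": "d1_0", "text": "...", "relevance": 1},
    {"doc_id": "d1", "chunk_id": "d1_1", "text": "...", "relevance": 0},
]}
print(validate_slate_record(example, slate_size=2))  # []
```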


In [None]:
# Test prompt
test_prompt = create_training_prompt(slates_data[0]["query_id"], slates_data[0]["slate"])
print("📝 Enhanced learning-to-rank prompt sample:")
print("="*80)
print(test_prompt[:1500])
print("\n... (truncated)")
print("="*80)
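Twenty chunks of up to 800 characters plus metadata make every prompt long, so it is worth measuring prompt lengths in tokens against the prompt budget before training: a prompt longer than the trainer's prompt limit gets truncated, and the model then never sees part of the slate. The hypothetical helper below accepts any token-counting function; the whitespace counter in the example is only a stand-in.

```python
from typing import Callable, Dict, List

def prompt_length_report(prompts: List[str], count_tokens: Callable[[str], int],
                         limit: int) -> Dict[str, float]:
    """Summarize prompt lengths (in whatever unit count_tokens returns) against a limit."""
    if not prompts:
        raise ValueError("no prompts to measure")
    lengths = [count_tokens(p) for p in prompts]
    return {
        "max": max(lengths),
        "mean": sum(lengths) / len(lengths),
        "over_limit": sum(1 for n in lengths if n > limit),
    }

# Stand-in counter for illustration only: whitespace-separated words, not model tokens
report = prompt_length_report(["a b c", "a b c d e f"], lambda s: len(s.split()), limit=4)
print(report)  # {'max': 6, 'mean': 4.5, 'over_limit': 1}
```

With the real tokenizer, something like `prompt_length_report([ex["prompt"] for ex in training_examples], lambda s: len(tokenizer(s).input_ids), limit=MAX_SEQ_LENGTH)` would report on the actual prompts once the dataset and model cells have run.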


In [None]:
# Dataset preparation
print(f"\n📚 Dataset előkészítése...")
training_examples = []
slate_lookup = {}

for slate_data in slates_data:
    query_id = slate_data["query_id"]
    prompt = create_training_prompt(query_id, slate_data["slate"])
    training_examples.append({"prompt": prompt})
    slate_lookup[query_id] = slate_data["slate"]

full_dataset = Dataset.from_list(training_examples)
indices = np.arange(len(full_dataset))
train_indices, eval_indices = train_test_split(indices, test_size=1.0 - TRAIN_SPLIT, random_state=SEED, shuffle=True)
train_dataset = full_dataset.select(train_indices)
eval_dataset = full_dataset.select(eval_indices)

print(f"✅ Dataset létrehozva:")
print(f"  Training: {len(train_dataset)} query (80%)")
print(f"  Evaluation: {len(eval_dataset)} query (20%)")
print(f"  Slate lookup: {len(slate_lookup)} entry")
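The reward function recovers the query by regex-matching `QUERY: "..."` in the prompt text, which fails if a query itself contains a double quote. An alternative, sketched below under the assumption that TRL's `GRPOTrainer` forwards extra dataset columns to reward functions as keyword arguments (one entry per completion), is to store `query_id` as its own column. `build_examples` and `make_query_aware_reward` are hypothetical names, and the toy score only illustrates the plumbing. A closure is used rather than `functools.partial` so the returned function keeps a `__name__` for logging.

```python
from typing import Callable, Dict, List

def build_examples(slates: List[Dict], make_prompt: Callable) -> List[Dict]:
    """One row per query; query_id travels next to the prompt as its own column."""
    return [{"prompt": make_prompt(s["query_id"], s["slate"]), "query_id": s["query_id"]}
            for s in slates]

def make_query_aware_reward(lookup: Dict[str, list]) -> Callable:
    def query_aware_reward(completions: List[str], query_id: List[str], **kwargs) -> List[float]:
        rewards = []
        for completion, qid in zip(completions, query_id):
            slate = lookup.get(qid)
            # Toy score: 1.0 if the slate is known and the output is non-empty
            rewards.append(1.0 if slate is not None and completion.strip() else -0.5)
        return rewards
    return query_aware_reward

rows = build_examples([{"query_id": "q1", "slate": [{"text": "..."}]}],
                      lambda q, s: f'QUERY: "{q}"')
reward = make_query_aware_reward({"q1": [{"text": "..."}]})
print(reward(completions=["0", ""], query_id=["q1", "q1"]))  # [1.0, -0.5]
```

In the notebook this would correspond to `Dataset.from_list(build_examples(slates_data, create_training_prompt))`, with the real nDCG logic inside the closure.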


In [None]:
# Load the model
print(f"🔄 Model betöltése Unsloth-tal: {MODEL_NAME}")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,
    dtype=None,
    fast_inference=True,
    max_lora_rank=LORA_RANK,
    gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    token=hf_token,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing=USE_GRADIENT_CHECKPOINTING,
    use_rslora=True,
    random_state=SEED,
)

tokenizer.padding_side = "right"
print("✅ Model és tokenizer betöltve (Unsloth + vLLM + RSLoRA)")
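With `use_rslora=True`, PEFT scales the LoRA update by `lora_alpha / sqrt(r)` instead of the classic `lora_alpha / r`. For rank 64 and alpha 128 that is a factor of 16 rather than 2, so the adapter update is amplified 8x relative to standard LoRA with the same alpha, which interacts directly with the chosen learning rate. The arithmetic, as a small hypothetical helper:

```python
import math

def lora_scaling(alpha: float, r: int, use_rslora: bool) -> float:
    """Multiplier applied to B @ A in a LoRA layer (PEFT convention)."""
    return alpha / math.sqrt(r) if use_rslora else alpha / r

print(lora_scaling(128, 64, use_rslora=False))  # 2.0
print(lora_scaling(128, 64, use_rslora=True))   # 16.0
print(lora_scaling(16, 64, use_rslora=True))    # 2.0: rsLoRA alpha matching classic scaling
```

Whether the larger rsLoRA factor is preferable here, or alpha should be lowered (16 would reproduce a factor of 2), is an empirical question for this setup.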


In [None]:
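# Reward function
def reward_function(completions, prompts, **kwargs):
    """nDCG@k gain over the fusion baseline, plus small format bonuses, clipped."""
    rewards = []
    for completion, prompt in zip(completions, prompts):
        try:
            match = re.search(r'QUERY:\s*"([^"]+)"', prompt)
            if not match or match.group(1) not in slate_lookup:
                rewards.append(-0.5)
                continue
            
            query_id = match.group(1)
            slate = slate_lookup[query_id]
            relevance = [doc.get('relevance', 0) for doc in slate]
            baseline = list(range(len(slate)))
            
            # Parse the raw output directly: parse_model_ranking falls back to a random
            # permutation, which would score a random nDCG and earn the full format bonus.
            tokens = (x.strip() for x in completion.split(","))
            predicted = [int(t) for t in tokens if re.fullmatch(r"[0-9]+", t)]
            predicted = [n for n in predicted if 0 <= n < len(slate)][:len(slate)]
            if len(predicted) < len(slate) // 2:
                rewards.append(-0.5)  # unparseable ranking: same penalty as other failures
                continue
            
            ndcg_baseline = calculate_ndcg(baseline, relevance, k=NDCG_K)
            ndcg_policy = calculate_ndcg(predicted, relevance, k=NDCG_K)
            reward = ndcg_policy - ndcg_baseline
            
            # Format bonus: full permutation > at least half of the indices, distinct
            is_valid_full = len(predicted) == len(slate) and len(set(predicted)) == len(slate)
            is_valid_partial = len(predicted) >= len(slate)//2 and len(set(predicted)) >= len(slate)//2
            if is_valid_full:
                reward += 0.1
            elif is_valid_partial:
                reward += 0.05
            
            # Small bonus for not repeating indices (ENTROPY_BONUS scales a uniqueness ratio)
            if len(predicted) > 1:
                unique_ratio = len(set(predicted)) / len(predicted)
                reward += ENTROPY_BONUS * unique_ratio
            
            reward = float(np.clip(reward, REWARD_CLIP_MIN, REWARD_CLIP_MAX))
            rewards.append(reward)
        except Exception:
            rewards.append(-0.5)
    return rewards

print("✅ GRPO Reward function definiálva")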
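GRPO computes advantages relative to the other completions of the same prompt: each reward has its group's mean subtracted. The fusion-baseline term `ndcg_baseline` is identical for every completion of a prompt, so it drops out of that centering step; it still affects clipping to `[REWARD_CLIP_MIN, REWARD_CLIP_MAX]`, the logged reward values and possibly the standard deviation used for scaling. The NumPy sketch below shows the cancellation with a hypothetical `group_centered` helper, made-up numbers, and a group size of 4 instead of 14 for brevity.

```python
import numpy as np

def group_centered(rewards: np.ndarray, group_size: int) -> np.ndarray:
    """Subtract each prompt's mean reward from the rewards of its completions."""
    grouped = rewards.reshape(-1, group_size)
    return (grouped - grouped.mean(axis=1, keepdims=True)).reshape(-1)

ndcg_policy = np.array([0.62, 0.70, 0.55, 0.81,   # prompt A, 4 completions
                        0.30, 0.45, 0.40, 0.35])  # prompt B, 4 completions
baseline = np.repeat([0.60, 0.38], 4)             # one constant per prompt

with_baseline = group_centered(ndcg_policy - baseline, group_size=4)
without_baseline = group_centered(ndcg_policy, group_size=4)
print(np.allclose(with_baseline, without_baseline))  # True
```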


In [None]:
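# GRPO trainer configuration
grpo_config = GRPOConfig(
    output_dir=str(OUTPUT_DIR),
    max_steps=MAX_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    num_generations=NUM_GENERATIONS,
    optim=OPTIMIZER_NAME,
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    # In GRPO the per-device batch counts completions, not prompts. The default generation
    # batch (4 x 3 = 12) is not divisible by num_generations=14; one round of
    # 4 prompts x 14 completions = 56 is divisible by both 4 and 14.
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    generation_batch_size=GENERATION_BATCH_SIZE,
    # Prompts (20 chunks of up to 800 chars) are far longer than TRL's default
    # max_prompt_length and would be truncated; a 20-index ranking fits in 128 tokens.
    max_prompt_length=MAX_SEQ_LENGTH - 128,
    max_completion_length=128,
    loss_type="dapo",
    scale_rewards="batch",
    importance_sampling_level="sequence",
    mask_truncated_completions=True,
    epsilon=0.2,
    bf16=True,
    # Gradient checkpointing is already enabled by get_peft_model(use_gradient_checkpointing="unsloth")
    max_grad_norm=1.0,
    logging_steps=LOGGING_STEPS,
    logging_first_step=True,
    # Note: eval_steps only takes effect if eval_strategy="steps" is also set
    eval_steps=EVAL_STEPS,
    save_steps=SAVE_STEPS,
    dataloader_num_workers=2,
    seed=SEED,
)

print("✅ GRPO Trainer konfiguráció kész")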
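`loss_type="dapo"` changes how per-token losses are averaged. The original GRPO aggregation averages tokens within each completion and then averages completions, so each token of a short completion weighs more than a token of a long one; the DAPO-style aggregation divides the masked token sum by the total number of active tokens, so every token counts equally. A simplified single-batch sketch with hypothetical function names (TRL normalizes over the accumulated global batch, which this ignores):

```python
import numpy as np

def grpo_sequence_mean(per_token_loss: np.ndarray, mask: np.ndarray) -> float:
    """Mean over tokens within each completion, then mean over completions."""
    per_seq = (per_token_loss * mask).sum(axis=1) / np.maximum(mask.sum(axis=1), 1)
    return float(per_seq.mean())

def dapo_token_mean(per_token_loss: np.ndarray, mask: np.ndarray) -> float:
    """Sum over all active tokens divided by the number of active tokens."""
    return float((per_token_loss * mask).sum() / max(mask.sum(), 1))

loss = np.array([[1.0, 1.0, 1.0, 1.0],    # long completion, 4 active tokens
                 [3.0, 0.0, 0.0, 0.0]])   # short completion, 1 active token
mask = np.array([[1, 1, 1, 1],
                 [1, 0, 0, 0]])
print(grpo_sequence_mean(loss, mask))  # (1.0 + 3.0) / 2 = 2.0
print(dapo_token_mean(loss, mask))     # (4.0 + 3.0) / 5 = 1.4
```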


In [None]:
# Initialize the trainer
trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward_function,
    args=grpo_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,  # GRPOTrainer takes the tokenizer as `processing_class`
)

print("✅ GRPO Trainer inicializálva")
print(f"  Training queries: {len(train_dataset)}")
print(f"  Eval queries: {len(eval_dataset)}")
print(f"  GPU memória: {torch.cuda.memory_allocated() / 1e9:.2f} GB")


In [None]:
# Start training
print("\n🚀 GRPO TRAINING INDÍTÁSA\n")
print("="*60)

try:
    trainer.train()
    print("\n" + "="*60)
    print("✅ Training sikeresen befejezve!")
except Exception as e:
    print(f"\n❌ Training hiba: {e}")
    raise


In [None]:
# Save artifacts
print("\n💾 Artifactumok mentése...")
model.save_pretrained_merged(str(OUTPUT_DIR), tokenizer, save_method="lora")
print(f"  ✅ LoRA adapter: {OUTPUT_DIR}")


In [None]:
# Evaluation (ranx)
print("\n📊 Evaluation futtatása...")

def evaluate_policy_ranx(dataset_subset, slate_lookup_dict, dataset_name=""):
    k_values = [5, 10, 20]
    metrics_to_compute = ["map", "mrr"]
    for k in k_values:
        metrics_to_compute.extend([f"ndcg@{k}", f"precision@{k}", f"recall@{k}"])
    
    qrels_dict = {}
    baseline_run_dict = {}
    policy_run_dict = {}
    parse_successes = 0
    
    for example in dataset_subset:
        prompt = example["prompt"]
        match = re.search(r'QUERY:\s*"([^"]+)"', prompt)
        if not match or match.group(1) not in slate_lookup_dict:
            continue
        
        query_id = match.group(1)
        slate = slate_lookup_dict[query_id]
        
        # Key slate items by position + chunk_id: a chunk-based slate can contain several
        # chunks of the same document, and keying by doc_id alone would merge them.
        item_keys = [f"{i}:{doc.get('chunk_id', doc.get('doc_id', ''))}" for i, doc in enumerate(slate)]
        
        query_qrels = {key: doc.get('relevance', 0) for key, doc in zip(item_keys, slate)}
        if not query_qrels:
            continue
        qrels_dict[query_id] = query_qrels
        
        # Baseline run: the slate's fusion order (the first item gets the highest score)
        baseline_run_dict[query_id] = {key: len(slate) - rank for rank, key in enumerate(item_keys)}
        
        # Greedy decoding keeps the evaluation reproducible; 128 new tokens leave
        # headroom for 20 comma-separated indices.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        input_length = inputs["input_ids"].shape[1]
        model_output = model.generate(**inputs, max_new_tokens=128, do_sample=False)
        completion = tokenizer.decode(model_output[0][input_length:], skip_special_tokens=True)
        predicted_indices = parse_model_ranking(completion, len(slate))
        
        # Count a parse as successful only if the raw output is a full permutation;
        # parse_model_ranking's random fallback would otherwise always look valid.
        raw_indices = [int(t) for t in (x.strip() for x in completion.split(",")) if re.fullmatch(r"[0-9]+", t)]
        if sorted(raw_indices) == list(range(len(slate))):
            parse_successes += 1
        
        policy_docs = {}
        for rank, idx in enumerate(predicted_indices):
            # setdefault keeps an item's best rank if the model repeats an index
            policy_docs.setdefault(item_keys[idx], len(slate) - rank)
        policy_run_dict[query_id] = policy_docs
    
    qrels_ranx = Qrels(qrels_dict)
    baseline_run_ranx = Run(baseline_run_dict)
    policy_run_ranx = Run(policy_run_dict)
    
    baseline_metrics = evaluate(qrels_ranx, baseline_run_ranx, metrics_to_compute)
    policy_metrics = evaluate(qrels_ranx, policy_run_ranx, metrics_to_compute)
    
    per_query_results = []
    for query_id in qrels_dict.keys():
        row = {"query_id": query_id}
        for metric in metrics_to_compute:
            baseline_score = baseline_run_ranx.scores.get(metric, {}).get(query_id, 0.0)
            row[f"baseline_{metric}"] = float(baseline_score)
        for metric in metrics_to_compute:
            policy_score = policy_run_ranx.scores.get(metric, {}).get(query_id, 0.0)
            row[f"policy_{metric}"] = float(policy_score)
        baseline_ndcg10 = baseline_run_ranx.scores.get("ndcg@10", {}).get(query_id, 0.0)
        policy_ndcg10 = policy_run_ranx.scores.get("ndcg@10", {}).get(query_id, 0.0)
        row["improvement_ndcg@10"] = float(policy_ndcg10 - baseline_ndcg10)
        per_query_results.append(row)
    
    improvements = [r["improvement_ndcg@10"] for r in per_query_results]
    positive_improvements = sum(1 for imp in improvements if imp > 0)
    positive_ratio = positive_improvements / len(improvements) if improvements else 0.0
    parse_success_rate = parse_successes / len(per_query_results) if per_query_results else 0.0
    
    print(f"  {dataset_name}:")
    print(f"    Baseline nDCG@10: {baseline_metrics.get('ndcg@10', 0.0):.4f}")
    print(f"    Policy nDCG@10: {policy_metrics.get('ndcg@10', 0.0):.4f}")
    print(f"    Improvement: {policy_metrics.get('ndcg@10', 0.0) - baseline_metrics.get('ndcg@10', 0.0):+.4f}")
    
    return {
        "baseline_metrics": {k: float(v) for k, v in baseline_metrics.items()},
        "policy_metrics": {k: float(v) for k, v in policy_metrics.items()},
        "num_queries": len(qrels_dict),
        "positive_improvement_count": positive_improvements,
        "positive_improvement_ratio": float(positive_ratio),
        "parse_success_rate": float(parse_success_rate),
        "parse_success_count": parse_successes,
        "per_query_results": per_query_results
    }

train_eval_results = evaluate_policy_ranx(train_dataset, slate_lookup, "Train set")
eval_eval_results = evaluate_policy_ranx(eval_dataset, slate_lookup, "Eval set")
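With about 20 evaluation queries, a few queries can move the mean nDCG@10 improvement noticeably. A paired bootstrap over queries gives a rough interval for that mean from the per-query `improvement_ndcg@10` values. The function `bootstrap_mean_ci`, the 10,000 resamples and the 95% level below are illustrative choices, not part of the original pipeline, and the example differences are made up.

```python
import numpy as np

def bootstrap_mean_ci(values, n_resamples=10_000, level=0.95, seed=42):
    """Percentile bootstrap CI for the mean of paired per-query differences."""
    x = np.asarray(values, dtype=float)
    if x.size == 0:
        raise ValueError("no values to resample")
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, len(x), size=(n_resamples, len(x)))  # resample queries with replacement
    means = x[idx].mean(axis=1)
    lo, hi = np.quantile(means, [(1 - level) / 2, 1 - (1 - level) / 2])
    return float(x.mean()), float(lo), float(hi)

# In the notebook: [r["improvement_ndcg@10"] for r in eval_eval_results["per_query_results"]]
diffs = [0.05, -0.02, 0.10, 0.0, 0.03]  # made-up values for illustration
mean, lo, hi = bootstrap_mean_ci(diffs)
print(f"mean {mean:+.4f}, 95% CI [{lo:+.4f}, {hi:+.4f}]")
```

An interval that comfortably includes zero suggests the eval-set difference is within query-sampling noise.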


In [None]:
# Per-query export
# Mean reward per logging step (TRL logs it under "reward", not "rewards/mean")
training_rewards = []
for log_entry in trainer.state.log_history:
    if "reward" in log_entry:
        training_rewards.append(log_entry["reward"])

train_per_query_df = pd.DataFrame(train_eval_results["per_query_results"])
train_per_query_csv = OUTPUT_DIR / "train_per_query_results.csv"
train_per_query_df.to_csv(train_per_query_csv, index=False, encoding='utf-8')
print(f"  ✅ Train per-query results: {train_per_query_csv}")

eval_per_query_df = pd.DataFrame(eval_eval_results["per_query_results"])
eval_per_query_csv = OUTPUT_DIR / "eval_per_query_results.csv"
eval_per_query_df.to_csv(eval_per_query_csv, index=False, encoding='utf-8')
print(f"  ✅ Eval per-query results: {eval_per_query_csv}")
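The exact keys in `trainer.state.log_history` depend on the TRL version (a mean reward may appear as `reward` or under a per-function `rewards/<name>/...` key). A small, hypothetical helper that takes the first matching key per log entry makes the extraction explicit and easy to adapt; the key names in the example are assumptions to check against your own log history.

```python
from typing import Dict, Iterable, List, Tuple

def metric_series(log_history: List[Dict], keys: Iterable[str]) -> List[Tuple[int, float]]:
    """(step, value) pairs for the first key in `keys` found in each log entry."""
    keys = list(keys)
    series = []
    for entry in log_history:
        for key in keys:
            if key in entry:
                series.append((entry.get("step", -1), float(entry[key])))
                break
    return series

# Made-up entries shaped like Trainer.state.log_history
history = [{"step": 10, "loss": 0.1, "reward": 0.02},
           {"step": 20, "loss": 0.09, "rewards/reward_function/mean": 0.05},
           {"step": 20, "train_runtime": 900.0}]
series = metric_series(history, ["reward", "rewards/reward_function/mean"])
print(series)  # [(10, 0.02), (20, 0.05)]
```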


In [None]:
# Metrics export
final_metrics = {
    "model_name": MODEL_NAME,
    "training_samples": len(train_dataset),
    "eval_samples": len(eval_dataset),
    "slate_size": SLATE_SIZE,
    "group_size": NUM_GENERATIONS,  # actual GRPO group size (completions per prompt)
    "max_steps": MAX_STEPS,
    "learning_rate": LEARNING_RATE,
    "lora_rank": LORA_RANK,
    "lora_alpha": LORA_ALPHA,
    # The last log entry is the train summary (train_loss), so search back for "loss"
    "final_loss": next((e["loss"] for e in reversed(trainer.state.log_history) if "loss" in e), 0.0),
    "total_steps": trainer.state.global_step,
    "training_rewards": {
        "mean": float(np.mean(training_rewards)) if training_rewards else 0.0,
        "std": float(np.std(training_rewards)) if training_rewards else 0.0,
        "min": float(np.min(training_rewards)) if training_rewards else 0.0,
        "max": float(np.max(training_rewards)) if training_rewards else 0.0,
        "trend": training_rewards
    },
    "train_evaluation": {
        "baseline_metrics": train_eval_results["baseline_metrics"],
        "policy_metrics": train_eval_results["policy_metrics"],
        "num_queries": train_eval_results["num_queries"],
        "positive_improvement_count": train_eval_results["positive_improvement_count"],
        "positive_improvement_ratio": train_eval_results["positive_improvement_ratio"],
        "parse_success_rate": train_eval_results["parse_success_rate"],
        "parse_success_count": train_eval_results["parse_success_count"]
    },
    "eval_evaluation": {
        "baseline_metrics": eval_eval_results["baseline_metrics"],
        "policy_metrics": eval_eval_results["policy_metrics"],
        "num_queries": eval_eval_results["num_queries"],
        "positive_improvement_count": eval_eval_results["positive_improvement_count"],
        "positive_improvement_ratio": eval_eval_results["positive_improvement_ratio"],
        "parse_success_rate": eval_eval_results["parse_success_rate"],
        "parse_success_count": eval_eval_results["parse_success_count"]
    },
    "status": "completed"
}

with open(METRICS_FILE, 'w', encoding='utf-8') as f:
    json.dump(final_metrics, f, ensure_ascii=False, indent=2)

print(f"  ✅ Metrics: {METRICS_FILE}")
print("\n✅ Minden artifact sikeresen mentve!")


## Training summary

### Technology stack:
- **Framework**: Unsloth + TRL GRPOTrainer
- **Inference**: vLLM (2-3x faster generation)
- **Model**: Qwen/Qwen3-4B-Instruct-2507 (4-bit + QLoRA)
- **Optimizations**: Unsloth gradient checkpointing, vLLM inference, batch=4, gen=14

### Generated artifacts (`/workspace/artifacts/grpo_policy/` with the default `WORKSPACE_PATH`):
- LoRA adapter weights
- `metrics.json` - ranx-based extended metrics
- `train_per_query_results.csv` - train metrics (MAP, MRR, NDCG@5/10/20, Precision, Recall)
- `eval_per_query_results.csv` - the same metrics for the eval split

### Agents.md checklist:
- ✅ Unsloth + vLLM + RSLoRA
- ✅ Chunk-based slates (full text)
- ✅ GRPO training (dapo loss, batch scaling, sequence IS)
- ✅ ranx evaluation
- ✅ Hungarian status messages
