# CourtRankRL GRPO Training - Chunk-Based, A100 Memory Efficient

## Agents.md Specification (Chunk-Based)

This notebook trains the CourtRankRL GRPO-based reranking model on an **A100 GPU** (80 GB VRAM, memory-constrained).

### Key features (chunk-based approach):
- **Model**: Qwen/Qwen3-4B-Instruct-2507 (4-bit) + QLoRA (rank=32, alpha=64)
- **Training**: TRL `GRPOTrainer` with the GRPO algorithm
  - Loss: "dapo" (eliminates length bias)
  - Reward scaling: "batch" (robust - PPO Lite)
  - Importance sampling: "sequence" (stable - GSPO)
- **Dataset**: 98 queries (full set), one slate per query (`SLATE_SIZE = 30` in the configuration cell, up from 20), **FULL chunk text** (~500-800 characters, truncated to 800 in the prompt)
- **Slate strategy**: chunk-level retrieval (no document-level aggregation!) → the most relevant chunks
- **Baseline**: the slate order is the fusion ranking [0,1,2,...] (from BM25+FAISS fusion)
- **Hardware**: batch size 2, gradient accumulation 6, 6 generations/prompt (MEMORY EFFICIENT)
- **Training time**: ~25-35 minutes for 600 steps with vLLM (thanks to fewer generations per prompt); note that the configuration cell currently sets `MAX_STEPS = 1000`
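In GRPO, each completion is judged against the other completions sampled for the same prompt. The NumPy sketch below illustrates that advantage computation, assuming rewards arrive as consecutive groups of `num_generations` completions per prompt; `grpo_advantages` is a hypothetical helper, not TRL's internal code. With `scale="batch"`, the group-centred rewards are divided by the standard deviation of the whole batch, which is the batch-level scaling named in the feature list. With `"group"`, each group is divided by its own standard deviation instead.

```python
import numpy as np

def grpo_advantages(rewards, num_generations, scale="batch", eps=1e-4):
    """Group-relative advantages for GRPO (illustrative sketch).

    rewards: 1-D sequence; each consecutive block of `num_generations`
             values belongs to the same prompt.
    scale:   "batch" -> divide by the std of the whole batch,
             "group" -> divide by each group's own std,
             "none"  -> only subtract the group mean.
    """
    r = np.asarray(rewards, dtype=float).reshape(-1, num_generations)
    # Baseline for each completion = mean reward of its own group
    centered = r - r.mean(axis=1, keepdims=True)
    if scale == "batch":
        centered = centered / (r.std() + eps)
    elif scale == "group":
        centered = centered / (r.std(axis=1, keepdims=True) + eps)
    elif scale != "none":
        raise ValueError(f"unknown scale: {scale}")
    return centered.reshape(-1)

# Two prompts, three completions each
adv = grpo_advantages([0.2, 0.5, 0.8, -0.1, -0.1, 0.4], num_generations=3)
print(adv.round(3))
```

Batch-level scaling avoids dividing by a near-zero standard deviation when all completions in a group receive almost the same reward. This is one reason it is considered the more robust choice.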

### Why chunk-based?
- ✅ **Relevant context**: BM25+FAISS have already selected the most relevant chunks
- ✅ **Full text**: the model sees *why* a document is relevant
- ✅ **Better learning**: the model learns to judge the actual content, not just the metadata


In [None]:
# Environment setup and package installation
# CRITICAL: install mutually compatible versions (Unsloth + TRL + dependencies)
# NOTE: if torch or other already-imported packages get upgraded here, restart the kernel before continuing

# 1. Core dependencies FIRST (stable versions)
%pip install -q --upgrade pip
%pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
%pip install -q transformers datasets huggingface_hub
%pip install -q numpy scipy scikit-learn

# 2. Accelerate - minimum version pinned for is_fsdp2 compatibility
# CRITICAL: accelerate>=0.30.0 is required for the is_fsdp2 attribute, but it must also stay compatible with Unsloth
%pip install -q "accelerate>=0.30.0" --upgrade

# 3. LoRA and quantization
%pip install -q peft bitsandbytes

# 4. TRL (latest stable release, with GRPO support)
%pip install -q --upgrade trl

# 5. Install Unsloth LAST (so that it lines up with the installed TRL)
%pip install -q --upgrade unsloth

# 6. vLLM (optional, but recommended for fast_inference)
%pip install -q vllm

# 7. Other dependencies
%pip install -q --upgrade typing_extensions ranx

print("✅ Packages installed (compatible versions)")
print("📦 Checking installed versions:")
import unsloth, torch, transformers, trl, accelerate
print(f"  - PyTorch: {torch.__version__}")
print(f"  - Transformers: {transformers.__version__}")
print(f"  - Accelerate: {accelerate.__version__} (CRITICAL: >=0.30.0 for is_fsdp2 support)")
print(f"  - TRL: {trl.__version__}")
print(f"  - Unsloth: {unsloth.__version__}")

# IMPORTANT WARNING about the version
if tuple(map(int, accelerate.__version__.split('.')[:2])) < (0, 30):
    print(f"⚠️  WARNING: Accelerate version ({accelerate.__version__}) < 0.30.0")
    print("   The monkey patch (applied before training) will provide is_fsdp2 compatibility")
    print("   Recommended: pip install \"accelerate>=0.30.0\" --upgrade && RESTART THE KERNEL")


In [None]:
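# Imports
import os, warnings
# NOTE: some of these environment variables are read at import time, so they only take effect
# if set before torch/unsloth are first imported (the install cell above already imports them)
os.environ["TORCHDYNAMO_DISABLE"] = "1"      # meant to disable TorchDynamo (torch.compile graph capture)
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # harmless here (Apple MPS only), kept as a fallback
# (optional, in case Unsloth reads it)
os.environ["UNSLOTH_COMPILE"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore", message=".*An output with one or more elements was resized.*")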

import json
import sys
import re
import random
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from unsloth import FastLanguageModel
from trl.trainer.grpo_trainer import GRPOTrainer
from trl.trainer.grpo_config import GRPOConfig
from huggingface_hub import login
from sklearn.metrics import ndcg_score
from scipy.stats import entropy as scipy_entropy
from sklearn.model_selection import train_test_split
from ranx import Qrels, Run, evaluate

print("✅ Imports loaded (Unsloth + TRL + sklearn + scipy + ranx)")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


In [None]:
# CRITICAL FIX: Accelerate is_fsdp2 compatibility patch
# If the installed Accelerate version does not provide the is_fsdp2 attribute, patch it in
try:
    from accelerate import Accelerator
    from accelerate.state import AcceleratorState
    
    # Check whether the is_fsdp2 property exists
    test_state = AcceleratorState()
    _ = test_state.is_fsdp2
    print("✅ This Accelerate version supports is_fsdp2")
except AttributeError:
    print("⚠️ This Accelerate version does NOT support is_fsdp2, applying monkey patch...")
    
    # Monkey patch: add an is_fsdp2 property to AcceleratorState
    def is_fsdp2_property(self):
        """Backward-compatibility patch: is_fsdp2 is always False on older Accelerate versions"""
        return False
    
    # Attach the property to the class
    AcceleratorState.is_fsdp2 = property(is_fsdp2_property)
    
    # Verify
    test_state = AcceleratorState()
    assert hasattr(test_state, 'is_fsdp2'), "Patch failed!"
    assert test_state.is_fsdp2 is False, "Patch returns the wrong value!"
    print("✅ Monkey patch applied successfully - is_fsdp2 = False")
except Exception as e:
    print(f"❌ Unexpected error while checking is_fsdp2: {e}")
    print("⚠️ Continuing without the patch - if training fails, reinstall: pip install \"accelerate>=0.30.0\"")

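The patch cell instantiates `AcceleratorState` only to probe for the attribute. A slightly more defensive variant checks the class itself and adds the property only when it is missing. The sketch below uses a stand-in class, `OldState`, in place of an older `AcceleratorState`. `ensure_property` is a hypothetical helper; in this notebook it would be called as `ensure_property(AcceleratorState, "is_fsdp2", False)`.

```python
def ensure_property(cls, name, default):
    """Attach a read-only property returning `default` if `cls` lacks `name`.

    Returns True if the class was patched, False if the attribute already existed.
    """
    # Class-level check: sees properties/methods defined on the class, no instance needed
    if hasattr(cls, name):
        return False
    setattr(cls, name, property(lambda self: default))
    return True

class OldState:  # stand-in for an older AcceleratorState without is_fsdp2
    pass

patched = ensure_property(OldState, "is_fsdp2", False)
print(patched, OldState().is_fsdp2)  # True False
```

Because the check runs on the class, a second call is a no-op. It never overrides an attribute that a newer library version already defines.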
In [None]:
# HuggingFace login
hf_token = os.getenv("HUGGINGFACE_TOKEN")
if hf_token:
    login(token=hf_token)
    print("✅ HuggingFace login successful")
else:
    print("⚠️ No HUGGINGFACE_TOKEN set; model downloads may be restricted")


In [None]:
# Configuration - MEMORY EFFICIENT (A100 80GB) + CURRICULUM LEARNING
MODEL_NAME = "unsloth/Qwen3-4B-Instruct-2507"
SLATE_SIZE = 30  # Enhanced: 20 → 30 (more documents per slate, better quality)

# Curriculum learning configuration
CURRICULUM_ENABLED = True  # Enabled - LARGEST IMPROVEMENT POTENTIAL
CURRICULUM_STRATEGY = "weighted_sampling"  # "filtered_dataset" or "weighted_sampling"
# weighted_sampling: easier examples are selected more often,
# so the model learns to rank well step by step
GROUP_SIZE = 8
LORA_RANK = 32  # 64→32: lower rank to save memory
LORA_ALPHA = 64  # 128→64: alpha follows the rank (alpha = 2 × rank)
LORA_DROPOUT = 0.05
MAX_SEQ_LENGTH = 16384  # 8192→16384: raised for longer prompts (~9600 tokens + margin)
GPU_MEMORY_UTILIZATION = 0.5  # 0.88→0.5: DRASTICALLY reduced because of the limited free memory
USE_GRADIENT_CHECKPOINTING = "unsloth"

LEARNING_RATE = 1e-4  # Enhanced: 5e-5 ‚Üí 1e-4 (larger steps, better convergence)
MAX_STEPS = 1000  # Enhanced: 600 ‚Üí 1000 (more epochs, but with early stopping if needed)
SAVE_STEPS = 200  # Save more frequently for checkpoint recovery
EVAL_STEPS = 50
LOGGING_STEPS = 10
WARMUP_STEPS = 100  # Enhanced: 50 ‚Üí 100 (gentler warmup)
GRADIENT_ACCUMULATION_STEPS = 6  # 3→6: more accumulation, smaller batch
NUM_GENERATIONS = 6  # 14→6: fewer generations per prompt
PER_DEVICE_BATCH_SIZE = 2  # 4→2: HALVED
# CRITICAL: generation_batch_size MUST be divisible by:
#   1. NUM_GENERATIONS (6)
#   2. PER_DEVICE_BATCH_SIZE (2)
# Least common multiple (LCM): 6
GENERATION_BATCH_SIZE = 6  # 6÷6=1, 6÷2=3 ✓
OPTIMIZER_NAME = "paged_adamw_8bit"
LR_SCHEDULER_TYPE = "linear"  # Enhanced: cosine ‚Üí linear (more stable final convergence)

NDCG_K = 10
ENTROPY_BONUS = 0.01
# Reward clipping range (updated in reward function)
REWARD_CLIP_MIN = -1.0  # Will be updated in enhanced reward function
REWARD_CLIP_MAX = 2.0   # Enhanced: wider range for positive improvements
TRAIN_SPLIT = 0.8
SEED = 42

BASE_PATH = Path(os.getenv("WORKSPACE_PATH", "/workspace"))

# Define the HuggingFace token (HF_TOKEN = hf_token)
hf_token = os.getenv("HUGGINGFACE_TOKEN")
HF_TOKEN = hf_token

BASELINE_METRICS_PATH = BASE_PATH / "baseline_query_metrics.csv"
SLATE_FILE = BASE_PATH / "training_slates.jsonl"
OUTPUT_DIR = BASE_PATH / "artifacts" / "grpo_policy"
METRICS_FILE = OUTPUT_DIR / "metrics.json"

print("📋 A100 80GB - memory-efficient configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  LoRA Rank: {LORA_RANK} (reduced to save memory)")
print(f"  GPU Memory Utilization: {GPU_MEMORY_UTILIZATION} ({GPU_MEMORY_UTILIZATION:.0%})")
print(f"  Batch: {PER_DEVICE_BATCH_SIZE} × {GRADIENT_ACCUMULATION_STEPS} = {PER_DEVICE_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS} (effective)")
print(f"  Generations: {NUM_GENERATIONS}, Generation batch: {GENERATION_BATCH_SIZE}")
print(f"  Steps per generation: {GENERATION_BATCH_SIZE // PER_DEVICE_BATCH_SIZE}")
print(f"  Steps: {MAX_STEPS}")
print("  ⚠️  Memory-optimized settings are active!")

if not SLATE_FILE.exists():
    raise FileNotFoundError(f"❌ Slate file not found: {SLATE_FILE}")

if not BASELINE_METRICS_PATH.exists():
    raise FileNotFoundError(f"❌ Baseline metrics file not found: {BASELINE_METRICS_PATH}")

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

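The divisibility rule in the configuration cell means that the smallest valid `GENERATION_BATCH_SIZE` is the least common multiple of the group size and the per-step batch. A small helper can compute it instead of hard-coding 6. `min_generation_batch_size` is a hypothetical name. Multiplying by `num_processes` for multi-GPU runs is an assumption about how the global batch is formed. `math.lcm` requires Python 3.9+.

```python
import math

def min_generation_batch_size(num_generations: int, per_device_batch_size: int,
                              num_processes: int = 1) -> int:
    """Smallest generation batch that is a multiple of both the GRPO group size
    and the global per-step batch (per_device_batch_size * num_processes)."""
    global_batch = per_device_batch_size * num_processes
    return math.lcm(num_generations, global_batch)

# Values from the configuration cell: 6 generations, per-device batch 2, one GPU
gbs = min_generation_batch_size(num_generations=6, per_device_batch_size=2)
steps_per_generation = gbs // 2  # training steps that reuse one round of generations
print(gbs, steps_per_generation)  # 6 3
```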

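The configuration only says that with `CURRICULUM_STRATEGY = "weighted_sampling"`, easier examples are drawn more often. One minimal way to realise that is weighted sampling with replacement, assuming every training example carries the `difficulty` tag assigned in the curriculum cell. The weight values in `DIFFICULTY_WEIGHTS` and the helper `sample_curriculum` are hypothetical choices for illustration.

```python
import random
from collections import Counter

# Hypothetical weights: easier buckets are drawn more often
DIFFICULTY_WEIGHTS = {"easy": 4.0, "medium": 3.0, "hard": 2.0, "very_hard": 1.0}

def sample_curriculum(examples, n, weights=DIFFICULTY_WEIGHTS, seed=42):
    """Draw n examples with replacement, weighting each by its difficulty bucket."""
    rng = random.Random(seed)  # local RNG: reproducible, does not touch global state
    w = [weights[ex["difficulty"]] for ex in examples]
    return rng.choices(examples, weights=w, k=n)

examples = [{"id": i, "difficulty": d}
            for i, d in enumerate(["easy", "medium", "hard", "very_hard"])]
draws = sample_curriculum(examples, n=10_000)
print(Counter(ex["difficulty"] for ex in draws))
```

The fixed weights only show the sampling mechanics. A schedule that shifts weight from the easy buckets toward the hard ones as training progresses would be the natural curriculum extension.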
In [None]:
# Free GPU memory
import gc
import torch

# Clear the CUDA cache (collect garbage first so unreferenced tensors can actually be released)
if torch.cuda.is_available():
    gc.collect()
    torch.cuda.empty_cache()
    print("✅ CUDA cache cleared")
    
    # Memory status
    # mem_get_info returns (free, total) in bytes as reported by the driver, so "free" also
    # reflects memory held by this process's caching allocator and by other processes (e.g. vLLM)
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    total = total_bytes / 1e9
    free = free_bytes / 1e9
    allocated = torch.cuda.memory_allocated(0) / 1e9
    reserved = torch.cuda.memory_reserved(0) / 1e9
    
    print(f"\n📊 GPU memory status ({torch.cuda.get_device_name(0)}):")
    print(f"  Total: {total:.2f} GB")
    print(f"  Allocated: {allocated:.2f} GB")
    print(f"  Reserved: {reserved:.2f} GB")
    print(f"  Free: {free:.2f} GB")
    print(f"  Free %: {(free/total)*100:.1f}%")

In [None]:
# Helper functions
def calculate_ndcg(ranked_indices: List[int], true_relevance: List[float], k: int = 10) -> float:
    """nDCG@k of a ranking given as slate indices (best first), computed with sklearn's ndcg_score."""
    if not true_relevance or not ranked_indices or max(true_relevance) == 0:
        return 0.0
    y_true = np.array(true_relevance)
    max_score = len(ranked_indices)
    # Turn the ranking into scores: the first position gets the highest score, items outside the top-k get 0
    y_score = np.zeros_like(y_true, dtype=float)
    for i, idx in enumerate(ranked_indices[:k]):
        if idx < len(y_true):
            y_score[idx] = max_score - i
    if np.sum(y_score) == 0:
        return 0.0
    try:
        return float(ndcg_score(y_true.reshape(1, -1), y_score.reshape(1, -1), k=k))
    except Exception:
        # e.g. sklearn rejects a slate with a single document
        return 0.0

def parse_model_ranking(completion: str, slate_size: int = SLATE_SIZE) -> List[int]:
    """
    Lenient parsing: the output does not have to follow the requested format exactly.
    If at least a third of the slate's indices are recovered, the missing ones are appended in baseline order.
    Pass slate_size=len(slate): if the default SLATE_SIZE is larger than the actual slate,
    out-of-range indices are accepted and indexing the slate later can raise an IndexError.
    """
    try:
        # Try to extract numbers in several ways
        numbers = []
        
        # 1. Comma-separated numbers
        parts = completion.split(",")
        for part in parts:
            part = part.strip()
            # Collect every number (one part may contain several whitespace-separated numbers)
            for word in part.split():
                word = word.strip().rstrip(".,;")
                if word.isdigit():
                    num = int(word)
                    if 0 <= num < slate_size:
                        numbers.append(num)
        
        # 2. If there are not enough, also try a purely whitespace-separated format
        if len(numbers) < slate_size // 3:
            for word in completion.split():
                word = word.strip().rstrip(".,;")
                if word.isdigit():
                    num = int(word)
                    if 0 <= num < slate_size and num not in numbers:
                        numbers.append(num)
        
        # 3. Deduplicate, keeping the first occurrence of each index
        valid_numbers = []
        seen = set()
        for n in numbers:
            if n not in seen:
                valid_numbers.append(n)
                seen.add(n)
        
        # 4. If there are enough valid numbers (at least 1/3 of the slate), use them
        if len(valid_numbers) >= max(1, slate_size // 3):
            # Fill in the missing indices in baseline order
            missing_indices = [i for i in range(slate_size) if i not in seen]
            # Missing indices go to the end (the model's own ranking is left untouched)
            result = valid_numbers + missing_indices
            return result[:slate_size]
        
    except Exception:
        pass
    
    # Fallback: baseline order (a deterministic, sensible default rather than a random shuffle)
    return list(range(slate_size))

def create_training_prompt(query_id: str, slate: List[Dict]) -> str:
    """
    Enhanced prompt generation for GRPO training with explicit relevance labels and an output-format example.
    The instructions are in ENGLISH to help the model learn proper relevance-based ranking;
    only the data content (queries, document texts, metadata) stays in HUNGARIAN, as it comes from the dataset.
    `query_id` is inserted verbatim as the search query text.
    
    Improvements:
    1. Explicit relevance labels (0/1/2) for each document
    2. An example of the expected ranking output line
    3. More detailed relevance criteria
    4. Clearer output format specification
    
    Caution: the labels come from the same "relevance" field that the reward is computed from,
    so they expose the target to the policy, and they are not available for unlabeled queries at inference time.
    """
    prompt = f'''# Legal Document Relevance Assessment and Ranking

## TASK
You are helping to rank Hungarian court decision documents by their relevance to a search query.

## SEARCH QUERY
"{query_id}"

## WHAT TO CONSIDER

When ranking documents, think about:
- How well the document topic matches the query
- Whether the content directly addresses what the query is asking
- The legal domain and court type (may or may not be relevant)
- The relevance labels (0=not relevant, 1=relevant, 2=highly relevant) - these can guide you
- BM25 and FAISS scores show some similarity signals, but content is most important

## DOCUMENT EVALUATION

You need to evaluate the following {len(slate)} document excerpts:

'''
    
    separator_line = "\u2500" * 63
    
    # Enhanced document presentation with EXPLICIT RELEVANCE LABELS
    for idx, doc in enumerate(slate):
        chunk_text = doc.get('text', '')[:800]
        bm25 = doc.get('bm25_score', 0)
        faiss = doc.get('faiss_score', 0)
        relevance = doc.get('relevance', 0)
        court = doc.get('court', 'N/A')
        domain = doc.get('domain', 'N/A')
        year = doc.get('year', 'N/A')
        
        # Relevance label mapping
        rel_label = "NOT RELEVANT" if relevance == 0 else ("RELEVANT" if relevance == 1 else "HIGHLY RELEVANT")
        
        prompt += f'''{separator_line}
DOCUMENT [{idx}]

Relevance: {rel_label} (Grade {relevance}/2)
ID: {doc.get('doc_id', 'N/A')} | Court: {court} | Domain: {domain} | Year: {year}
Scores: BM25={bm25:.2f}, FAISS={faiss:.3f}

Content:
{chunk_text}

'''
    
    # Strict ranking instructions to force tangible outputs
    prompt += f'''{separator_line}

## RANKING INSTRUCTIONS

You need to rank the {len(slate)} documents by their relevance to the search query.

### GUIDELINES (focus on relevance):
- Consider the search query and how well each document matches it
- Documents with higher RELEVANCE grades (2 > 1 > 0) are generally more relevant, but also analyze the content
- Think about topic match, content quality, and legal domain relevance
- BM25 and FAISS scores can help, but focus primarily on content relevance

### OUTPUT FORMAT (STRICT):
- Respond with EXACTLY one line in this format ‚Üí RANKING: <comma-separated indices> #END
- Use document indices 0 to {len(slate)-1}, each exactly once, separated by commas and without extra spaces
- Do NOT include headings, bullet points, Markdown, or explanations before or after the line
- If you are uncertain, still provide your best full ranking covering every index
- End the line with "#END" to confirm completion

Example:
RANKING: 3,0,5,1,4,2,8,6,7,9 #END

RANKING:
'''
    
    return prompt

print("✅ Helper functions defined")
print("  üìù Enhanced learning-to-rank prompt template:")
print("     - English instructions with Hungarian data")
print("     - Explicit ranking criteria")
print("     - Step-by-step guidance")
print("     - Clear do's and don'ts")
print("     - Structured document presentation")


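The prompt asks for a strict `RANKING: ... #END` line, while `parse_model_ranking` deliberately accepts loose output and pads missing indices. To check whether a completion actually followed the format (for example to log format compliance or to add a small format reward), a strict check could look like the following. `parse_strict_ranking` is a hypothetical helper. It treats the `RANKING:` prefix as optional because the prompt itself already ends with `RANKING:`.

```python
import re
from typing import List, Optional

# Optional "RANKING:" prefix, then comma-separated integers, then the "#END" terminator
_RANKING_RE = re.compile(r"(?:RANKING:\s*)?([0-9]+(?:\s*,\s*[0-9]+)*)\s*#END")

def parse_strict_ranking(completion: str, slate_size: int) -> Optional[List[int]]:
    """Return the ranking if the completion contains '<i>,<j>,... #END' forming a
    permutation of 0..slate_size-1; otherwise return None."""
    match = _RANKING_RE.search(completion)
    if match is None:
        return None
    indices = [int(tok) for tok in match.group(1).split(",")]  # int() tolerates surrounding spaces
    if sorted(indices) != list(range(slate_size)):
        return None  # missing, duplicated or out-of-range indices
    return indices

print(parse_strict_ranking("RANKING: 2,0,1 #END", slate_size=3))  # [2, 0, 1]
print(parse_strict_ranking("2,0,0 #END", slate_size=3))           # None (duplicate index)
```

Returning `None` instead of padding keeps format violations visible. The lenient parser can still serve as the fallback when a full permutation is needed.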

In [None]:
# Load slate data
print(f"📂 Loading slate data: {SLATE_FILE}")
df_slates = pd.read_json(SLATE_FILE, lines=True, encoding='utf-8')
slates_data = df_slates.to_dict('records')
print(f"✅ Loaded: {len(slates_data)} slates")

sample = slates_data[0]
print(f"\n📋 Sample slate structure:")
print(f"  Query ID: {sample['query_id'][:50]}...")
print(f"  Slate items: {len(sample['slate'])}")


In [None]:
# Inspect a test prompt
test_prompt = create_training_prompt(slates_data[0]["query_id"], slates_data[0]["slate"])
print("📝 DETAILED GRPO PROMPT - FULL EXAMPLE")
print("="*80)
print(test_prompt)
print("="*80)
print(f"\n📊 Prompt statistics:")
print(f"  Total length: {len(test_prompt)} characters")
print(f"  Estimated tokens: ~{len(test_prompt)//4} tokens (rough 4 characters/token heuristic)")
print(f"  Number of documents: {len(slates_data[0]['slate'])}")
print(f"  Slate ID: {slates_data[0]['query_id'][:60]}...")

## 🎯 Multi-Metric Reward Setup

Instead of optimizing nDCG@10 alone, the reward is a **convex combination** (non-negative weights summing to 1) of the policy's per-query improvement over the baseline in three metrics:
- **nDCG@10** (60%): overall ranking quality
- **MRR@5** (30%): position of the first relevant hit (a business priority!)
- **Recall@20** (10%): completeness

This way the model optimizes for practical value rather than for a single ranking metric.
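Written out, the reward computed by `compute_multi_metric_reward` with its defaults (`use_sigmoid=True`, sensitivity $k = 5$) is as follows, where $\Delta_m = m_{\text{policy}} - m_{\text{baseline}}$ for each metric $m$:

$$
\tilde{\sigma}_k(\Delta) = \frac{2}{1 + e^{-k\Delta}} - 1 = \tanh\!\left(\frac{k\Delta}{2}\right) \in (-1, 1)
$$

$$
r = \operatorname{clip}\!\Big( w_{1}\,\tilde{\sigma}_k(\Delta_{\text{nDCG@10}}) + w_{2}\,\tilde{\sigma}_k(\Delta_{\text{MRR@5}}) + w_{3}\,\tilde{\sigma}_k(\Delta_{\text{Recall@20}}) + b - p,\; r_{\min},\; r_{\max} \Big)
$$

The default weights are $(w_1, w_2, w_3) = (0.6, 0.3, 0.1)$ and the default clip range is $[r_{\min}, r_{\max}] = [-1, 1]$. The tie-break bonus $b$ equals `mrr_tiebreak_bonus` (0.02) when $|\Delta_{\text{nDCG@10}}| < 0.01$, $|\Delta_{\text{Recall@20}}| < 0.01$ and $\Delta_{\text{MRR@5}} > 0.05$, and 0 otherwise. $p$ is the diversity penalty, which is 0 unless `diversity_penalty_weight > 0`. Because the weights sum to 1 and each transformed delta lies in $(-1, 1)$, the weighted sum already stays inside $(-1, 1)$ before the bonus and the penalty are added.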

In [None]:
# Multi-Metrika Reward Helper Functions (inline defini√°lva - nincs k√ºls≈ë import)

from collections import Counter
from typing import Optional, Tuple

def compute_mrr(relevances: List[int], k: Optional[int] = None) -> float:
    """
    Mean Reciprocal Rank for a single query (i.e. the reciprocal rank; averaging over queries gives MRR).
    
    MRR@k = 1 / rank_first_relevant (if there is a relevant doc in the top-k)
           = 0 (if there is no relevant doc in the top-k)
    """
    if k is not None:
        relevances = relevances[:k]
    
    for i, rel in enumerate(relevances, start=1):
        if rel > 0:
            return 1.0 / i
    
    return 0.0


def compute_recall(relevances: List[int], k: int, total_relevant: Optional[int] = None) -> float:
    """
    Recall@k calculation.
    
    Recall@k = |{relevant docs in top-k}| / |{all relevant docs}|
    """
    if not relevances:
        return 0.0
    
    # Number of relevant docs in the top-k
    relevant_in_topk = sum(1 for rel in relevances[:k] if rel > 0)
    
    # All relevant docs (if not given, counted over the whole list)
    if total_relevant is None:
        total_relevant = sum(1 for rel in relevances if rel > 0)
    
    if total_relevant == 0:
        return 0.0
    
    return relevant_in_topk / total_relevant


def compute_ndcg_metric(relevances: List[int], k: int) -> float:
    """
    nDCG@k computed with sklearn.
    
    Args:
        relevances: Relevance values in ranked order [rel_1, rel_2, ...] for the WHOLE slate
        k: Top-k cutoff
    
    Returns:
        nDCG@k value in [0, 1]
    """
    if not relevances or k <= 0 or max(relevances) <= 0:
        # No relevant document: nDCG is defined as 0 here
        return 0.0
    
    # sklearn ndcg_score: [true_relevances], [predicted_scores]
    # predicted_scores: a higher score means a better rank, so use descending scores [n, n-1, ..., 1]
    # Do NOT truncate to the top-k first: sklearn builds the ideal DCG from y_true, so truncating
    # would compare the top-k only with itself and inflate the score. The k argument applies the cutoff.
    scores = list(range(len(relevances), 0, -1))
    
    try:
        ndcg = ndcg_score(
            y_true=[relevances],
            y_score=[scores],
            k=k
        )
        return float(ndcg)
    except Exception:
        # Edge case, e.g. sklearn rejects a slate with a single document
        return 0.0


def compute_multi_metric_reward(
    predicted_indices: List[int],
    slate: List[Dict],
    baseline_metrics: Dict[str, float],
    weights: Optional[Dict[str, float]] = None,
    clip_range: Tuple[float, float] = (-1.0, 1.0),
    use_sigmoid: bool = True,
    mrr_tiebreak_bonus: float = 0.02,
    diversity_penalty_weight: float = 0.0
) -> Tuple[float, Dict[str, float]]:
    """
    Multi-metric reward with sigmoid-based stabilization.
    
    Improvements:
    - Sigmoid transform of the deltas → more stable gradients
    - MRR tie-break bonus: when nDCG/Recall stagnate but MRR improves
    - Diversity penalty: one court/domain over-represented in the top-10
    
    Args:
        predicted_indices: Ranking indices predicted by the model (must be valid indices into `slate`)
        slate: Slate data [{"relevance": int, "court": str, "domain": str, ...}, ...]
        baseline_metrics: Baseline metrics {"ndcg@10": float, "mrr@5": float, ...}
        weights: Metric weights {"ndcg10": 0.6, "mrr5": 0.3, "recall20": 0.1}
        clip_range: Reward clipping range (min, max)
        use_sigmoid: If True, apply the sigmoid(delta) transform
        mrr_tiebreak_bonus: Bonus (+ε) when only MRR improves
        diversity_penalty_weight: Weight of the diversity penalty (0 = disabled)
    
    Returns:
        (total_reward, components_dict)
    """
    # Default weights (updated: 0.6/0.3/0.1)
    if weights is None:
        weights = {"ndcg10": 0.60, "mrr5": 0.30, "recall20": 0.10}
    
    # Relevance values in the predicted ranking order
    relevances = [slate[i]["relevance"] for i in predicted_indices]
    
    # Policy metrics
    policy_ndcg10 = compute_ndcg_metric(relevances, k=10)
    policy_mrr5 = compute_mrr(relevances, k=5)
    
    # Recall@20: total_relevant counted over the whole slate (fixed candidate pool!)
    total_relevant = sum(1 for doc in slate if doc.get("relevance", 0) > 0)
    policy_recall20 = compute_recall(relevances, k=20, total_relevant=total_relevant)
    
    # Baseline metrics
    baseline_ndcg10 = baseline_metrics.get("ndcg@10", 0.0)
    baseline_mrr5 = baseline_metrics.get("mrr@5", 0.0)
    baseline_recall20 = baseline_metrics.get("recall@20", 0.0)
    
    # Deltas (policy - baseline)
    delta_ndcg10 = policy_ndcg10 - baseline_ndcg10
    delta_mrr5 = policy_mrr5 - baseline_mrr5
    delta_recall20 = policy_recall20 - baseline_recall20
    
    # Sigmoid transform (more stable reward signal)
    if use_sigmoid:
        # Rescaled logistic: 2 * sigmoid(k*x) - 1 = tanh(k*x / 2); k=5 sets the sensitivity
        def sigmoid_transform(delta, k=5.0):
            return 2.0 / (1.0 + np.exp(-k * delta)) - 1.0  # range: (-1, 1)
        
        component_ndcg = weights["ndcg10"] * sigmoid_transform(delta_ndcg10)
        component_mrr = weights["mrr5"] * sigmoid_transform(delta_mrr5)
        component_recall = weights["recall20"] * sigmoid_transform(delta_recall20)
    else:
        # Linear (original)
        component_ndcg = weights["ndcg10"] * delta_ndcg10
        component_mrr = weights["mrr5"] * delta_mrr5
        component_recall = weights["recall20"] * delta_recall20
    
    # Total reward
    reward_raw = component_ndcg + component_mrr + component_recall
    
    # ====== TIE-BREAK BONUS: MRR improvement ======
    # If nDCG/Recall barely change but MRR improves → better user experience
    if abs(delta_ndcg10) < 0.01 and abs(delta_recall20) < 0.01 and delta_mrr5 > 0.05:
        reward_raw += mrr_tiebreak_bonus
    
    # ====== DIVERSITY PENALTY ======
    # If the top-10 is too homogeneous (one court/domain dominates) → -ε
    if diversity_penalty_weight > 0:
        top10_indices = predicted_indices[:10]
        top10_courts = [slate[i].get("court", "N/A") for i in top10_indices if i < len(slate)]
        top10_domains = [slate[i].get("domain", "N/A") for i in top10_indices if i < len(slate)]
        
        # Shannon entropy (high = diverse, low = homogeneous)
        def shannon_entropy(items):
            counts = Counter(items)
            total = len(items)
            if total == 0:
                return 0.0
            probs = [c/total for c in counts.values()]
            return -sum(p * np.log(p + 1e-9) for p in probs)
        
        court_entropy = shannon_entropy(top10_courts)
        domain_entropy = shannon_entropy(top10_domains)
        
        # Normalization: max entropy = log(10) ≈ 2.3 (ten distinct values)
        max_entropy = np.log(10)
        diversity_score = (court_entropy + domain_entropy) / (2 * max_entropy)
        
        # Penalize low diversity (< 0.5 → homogeneous)
        if diversity_score < 0.5:
            penalty = diversity_penalty_weight * (0.5 - diversity_score)
            reward_raw -= penalty
    
    # Clipping
    reward_clipped = np.clip(reward_raw, clip_range[0], clip_range[1])
    
    # Components dictionary (detailed tracking)
    components = {
        "reward_total": float(reward_clipped),
        "reward_raw": float(reward_raw),
        "component_ndcg10": float(component_ndcg),
        "component_mrr5": float(component_mrr),
        "component_recall20": float(component_recall),
        "delta_ndcg10": float(delta_ndcg10),
        "delta_mrr5": float(delta_mrr5),
        "delta_recall20": float(delta_recall20),
        "policy_ndcg10": float(policy_ndcg10),
        "policy_mrr5": float(policy_mrr5),
        "policy_recall20": float(policy_recall20),
        "baseline_ndcg10": float(baseline_ndcg10),
        "baseline_mrr5": float(baseline_mrr5),
        "baseline_recall20": float(baseline_recall20),
    }
    
    return float(reward_clipped), components

print("✅ Multi-metric reward helper functions defined (inline)")
print("   🎯 Multi-metric reward: sigmoid-based, 0.6/0.3/0.1 weights")
print("   📦 No external imports - everything is defined inline")
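For a reward signal, it matters a great deal whether the ideal DCG is built from the whole slate or only from the top-k items. sklearn's `ndcg_score` uses linear gains with a log2 rank discount and normalizes by the DCG of the best ordering of the `y_true` values it receives. Passing only the top-k relevances therefore makes the top-k its own ideal. The self-contained NumPy sketch below follows the same conventions and contrasts the two choices on a five-document toy slate; `dcg` and `ndcg_at_k` are illustrative names.

```python
import numpy as np

def dcg(gains):
    """DCG with linear gains and a log2(rank + 1) discount."""
    g = np.asarray(gains, dtype=float)
    discounts = np.log2(np.arange(2, g.size + 2))
    return float(np.sum(g / discounts))

def ndcg_at_k(ranked_relevances, k, truncate_ideal=False):
    """nDCG@k for relevances listed in ranked order.

    truncate_ideal=False: ideal ranking built from the whole slate (standard).
    truncate_ideal=True:  ideal built only from the top-k items (inflated).
    """
    rel = np.asarray(ranked_relevances, dtype=float)
    pool = rel[:k] if truncate_ideal else rel
    ideal = dcg(np.sort(pool)[::-1][:k])
    return 0.0 if ideal == 0 else dcg(rel[:k]) / ideal

# The top-3 holds one grade-1 document; two grade-2 documents sit below the cut-off
ranking = [1, 0, 0, 2, 2]
print(ndcg_at_k(ranking, k=3, truncate_ideal=True))  # -> 1.0
print(ndcg_at_k(ranking, k=3))                       # -> ~0.266
```

With the truncated ideal, any ranking whose top-k is sorted among itself earns a perfect score. A policy could then be rewarded while leaving the highly relevant documents below the cut-off.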

In [None]:
# Load baseline metrics (pre-computed by baseline_evaluation.ipynb)
BASELINE_METRICS_PATH = BASE_PATH / "baseline_query_metrics.csv"

if not BASELINE_METRICS_PATH.exists():
    print(f"⚠️  WARNING: baseline metrics file not found: {BASELINE_METRICS_PATH}")
    print("   Run the baseline_evaluation.ipynb notebook first!")
    print("   Fallback: baseline metrics default to 0 (NOT RECOMMENDED)")
    baseline_metrics_dict = {}
else:
    baseline_df = pd.read_csv(BASELINE_METRICS_PATH, encoding='utf-8')
    baseline_metrics_dict = {
        row["query_id"]: row.to_dict()
        for _, row in baseline_df.iterrows()
    }
    print(f"✅ Baseline metrics loaded: {len(baseline_metrics_dict)} queries")
    
    # Statistics
    if baseline_metrics_dict:
        sample_query = list(baseline_metrics_dict.keys())[0]
        sample_metrics = baseline_metrics_dict[sample_query]
        print(f"\n📊 Available metrics (example query):")
        for key in sorted(sample_metrics.keys()):
            # np.number also covers NumPy integer scalars, which are not Python ints
            if key != "query_id" and isinstance(sample_metrics[key], (int, float, np.number)):
                print(f"  • {key}: {sample_metrics[key]:.4f}")
        
        # Aggregate baseline performance
        avg_ndcg10 = np.mean([m.get("ndcg@10", 0.0) for m in baseline_metrics_dict.values()])
        avg_mrr5 = np.mean([m.get("mrr@5", 0.0) for m in baseline_metrics_dict.values()])
        avg_recall20 = np.mean([m.get("recall@20", 0.0) for m in baseline_metrics_dict.values()])
        
        print(f"\n🎯 Baseline performance (mean):")
        print(f"  • nDCG@10:   {avg_ndcg10:.4f}")
        print(f"  • MRR@5:     {avg_mrr5:.4f}")
        print(f"  • Recall@20: {avg_recall20:.4f}")

In [None]:
# ====== CURRICULUM LEARNING - largest improvement potential ======
# Strategy: train on easy examples first (high baseline nDCG), then move gradually to harder ones
# This helps the model learn good ranking first and refine it afterwards

print(f"\n📚 Preparing the dataset with CURRICULUM LEARNING...")
print("="*80)

# 1. Compute the baseline nDCG for every slate
print("🎯 1. Computing baseline nDCG per slate...")
slate_baseline_ndcgs = []

for slate_data in slates_data:
    slate = slate_data.get("slate", [])
    if not slate:
        continue
    
    relevances = [doc.get('relevance', 0) for doc in slate]
    baseline_indices = list(range(len(slate)))
    baseline_ndcg = calculate_ndcg(baseline_indices, relevances, k=NDCG_K)
    
    slate_baseline_ndcgs.append({
        "query_id": slate_data["query_id"],
        "baseline_ndcg": baseline_ndcg,
        "num_relevant": sum(1 for r in relevances if r >= 1),
        "num_high_relevant": sum(1 for r in relevances if r == 2),
    })

# 2. Define curriculum levels from the baseline nDCG
ndcgs = [s["baseline_ndcg"] for s in slate_baseline_ndcgs]

print(f"\n📊 Baseline nDCG statistics:")
print(f"  Mean: {np.mean(ndcgs):.4f}")
print(f"  Median: {np.median(ndcgs):.4f}")
print(f"  Min: {np.min(ndcgs):.4f}, Max: {np.max(ndcgs):.4f}")
print(f"  Std: {np.std(ndcgs):.4f}")

# CRITICAL: check that the scores are spread widely enough for curriculum bucketing
ndcg_range = np.max(ndcgs) - np.min(ndcgs)
VARIANCE_THRESHOLD = 0.05  # If the range (max - min) is below this, use a fallback strategy

if ndcg_range < VARIANCE_THRESHOLD:
    print(f"\n⚠️  WARNING: the nDCG distribution has collapsed! (range={ndcg_range:.4f} < {VARIANCE_THRESHOLD})")
    print(f"   Likely cause: slates sorted by relevance (label leakage!)")
    print(f"   → FALLBACK STRATEGY: using baseline_query_metrics.csv")
    
    # Fallback: use the data from the baseline metrics CSV
    if baseline_metrics_dict:
        print(f"\n   ✅ Baseline metrics available, using a combined difficulty score")
        slate_baseline_ndcgs_fallback = []
        
        for s in slate_baseline_ndcgs:
            query_id = s["query_id"]
            metrics = baseline_metrics_dict.get(query_id, {})
            
            # Combined difficulty score: weighted average of several metrics
            combined_score = (
                metrics.get("ndcg@10", 0.0) * 0.5 +
                metrics.get("mrr@5", 0.0) * 0.3 +
                metrics.get("recall@20", 0.0) * 0.2
            )
            
            slate_baseline_ndcgs_fallback.append({
                "query_id": query_id,
                "baseline_ndcg": combined_score,  # Combined score
                "num_relevant": s["num_relevant"],
                "num_high_relevant": s["num_high_relevant"],
            })
        
        slate_baseline_ndcgs = slate_baseline_ndcgs_fallback
        ndcgs = [s["baseline_ndcg"] for s in slate_baseline_ndcgs]
        
        print(f"   📊 Fallback score statistics:")
        print(f"      Mean: {np.mean(ndcgs):.4f}")
        print(f"      Std: {np.std(ndcgs):.4f}")
        print(f"      Range: {np.max(ndcgs) - np.min(ndcgs):.4f}")
    else:
        print(f"\n   ⚠️  Baseline metrics not available!")
        print(f"   → SIMPLIFIED FALLBACK: difficulty from relevance counts")
        
        for s in slate_baseline_ndcgs:
            # Heuristic: how many relevant documents are there?
            num_rel = s["num_relevant"]
            num_high = s["num_high_relevant"]
            # Scaled score: more relevant docs = easier (not bounded to [0, 1]; only the order matters for the percentile split)
            s["baseline_ndcg"] = (num_high * 2 + num_rel) / 20.0
        
        ndcgs = [s["baseline_ndcg"] for s in slate_baseline_ndcgs]
        print(f"   📊 Heuristic score statistics:")
        print(f"      Mean: {np.mean(ndcgs):.4f}")
        print(f"      Range: {np.max(ndcgs) - np.min(ndcgs):.4f}")

# Percentile thresholds
ndcg_percentiles = {
    "easy": np.percentile(ndcgs, 75),      # ≥ 75th percentile: top 25% (best baseline)
    "medium": np.percentile(ndcgs, 50),    # 50th-75th percentile
    "hard": np.percentile(ndcgs, 25),      # 25th-50th percentile
    "very_hard": 0.0,                      # everything below the 25th percentile
}

print(f"\n📈 Curriculum levels (percentile):")
print(f"  Easy (≥ {ndcg_percentiles['easy']:.3f}): top 25% - best baseline slates")
print(f"  Medium (≥ {ndcg_percentiles['medium']:.3f}): 50th-75th percentile - medium baseline")
print(f"  Hard (≥ {ndcg_percentiles['hard']:.3f}): 25th-50th percentile - harder slates")
print(f"  Very Hard (< {ndcg_percentiles['hard']:.3f}): bottom 25% - hardest slates")

# 3. Categorize slates by baseline score
slate_difficulty = {}
for s in slate_baseline_ndcgs:
    ndcg = s["baseline_ndcg"]
    if ndcg >= ndcg_percentiles["easy"]:
        difficulty = "easy"
    elif ndcg >= ndcg_percentiles["medium"]:
        difficulty = "medium"
    elif ndcg >= ndcg_percentiles["hard"]:
        difficulty = "hard"
    else:
        difficulty = "very_hard"
    slate_difficulty[s["query_id"]] = difficulty

# 4. Build the dataset with difficulty tags
training_examples = []
slate_lookup = {}

for slate_data in slates_data:
    query_id = slate_data["query_id"]
    if query_id not in slate_difficulty:
        continue
    
    prompt = create_training_prompt(query_id, slate_data["slate"])
    difficulty = slate_difficulty[query_id]
    baseline_ndcg = next(s["baseline_ndcg"] for s in slate_baseline_ndcgs if s["query_id"] == query_id)
    
    training_examples.append({
        "prompt": prompt,
        "difficulty": difficulty,
        "baseline_ndcg": baseline_ndcg,
    })
    slate_lookup[query_id] = slate_data["slate"]

# 5. Difficulty distribution
difficulty_counts = {}
for ex in training_examples:
    d = ex["difficulty"]
    difficulty_counts[d] = difficulty_counts.get(d, 0) + 1

print(f"\nüìä Difficulty eloszl√°s:")
for diff in ["easy", "medium", "hard", "very_hard"]:
    count = difficulty_counts.get(diff, 0)
    pct = 100 * count / len(training_examples) if training_examples else 0
    avg_ndcg = np.mean([ex["baseline_ndcg"] for ex in training_examples if ex["difficulty"] == diff]) if count > 0 else 0
    print(f"  {diff:12s}: {count:3d} query ({pct:5.1f}%) - avg baseline nDCG: {avg_ndcg:.4f}")

# 6. Curriculum learning configuration (uses the variables defined in cell 4)
# CURRICULUM_ENABLED and CURRICULUM_STRATEGY are already defined in cell 4

print(f"\n✅ Curriculum Learning configuration:")
print(f"  Enabled: {CURRICULUM_ENABLED}")
print(f"  Strategy: {CURRICULUM_STRATEGY}")
if CURRICULUM_STRATEGY == "filtered_dataset":
    easy_count = len([ex for ex in training_examples if ex['difficulty'] == 'easy'])
    medium_count = len([ex for ex in training_examples if ex['difficulty'] == 'medium'])
    # Counts cover the full dataset, i.e. before the train/eval split below
    print(f"  Dataset (before split): easy ({easy_count}) + medium ({medium_count}) = {easy_count + medium_count} queries")
    print(f"  Only examples with a good baseline nDCG are used for training")
    print(f"  → intended to give a cleaner learning signal on slates that already contain relevant chunks")

# 7. Train/eval split (random; NOT stratified by difficulty)
from collections import defaultdict

# Eval set: a random (1 - TRAIN_SPLIT) share of the queries; the difficulty mix is not enforced
full_dataset = Dataset.from_list(training_examples)
indices = np.arange(len(full_dataset))
train_indices, eval_indices = train_test_split(indices, test_size=1.0 - TRAIN_SPLIT, random_state=SEED, shuffle=True)

train_dataset = full_dataset.select(train_indices)
eval_dataset = full_dataset.select(eval_indices)

print(f"\n✅ Dataset created:")
print(f"  Training: {len(train_dataset)} queries ({TRAIN_SPLIT:.0%})")
print(f"  Evaluation: {len(eval_dataset)} queries ({1 - TRAIN_SPLIT:.0%})")
print(f"  Slate lookup: {len(slate_lookup)} entries")

# 8. Difficulty statistics of the training split
train_difficulty_dist = defaultdict(int)
for idx in train_indices:
    train_difficulty_dist[training_examples[idx]["difficulty"]] += 1

print(f"\n📊 Training set difficulty distribution:")
for diff in ["easy", "medium", "hard", "very_hard"]:
    count = train_difficulty_dist.get(diff, 0)
    print(f"  {diff:12s}: {count:3d} queries")

print("\n" + "="*80)
print("💡 USING CURRICULUM LEARNING:")
print("   Intended schedule: start with easy queries, then add harder ones.")
print("   In this notebook only the 'filtered_dataset' strategy changes the data")
print("   (easy + medium only); otherwise training uses the full training split.")
print("="*80)
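The percentile banding and the train/eval split above can be checked in isolation. The sketch below is a minimal stdlib-only illustration. `percentile_thresholds`, `bucket` and `stratified_split` are hypothetical helpers and are not part of the notebook. `statistics.quantiles(..., method="inclusive")` uses the same linear interpolation as NumPy's default `np.percentile`. `stratified_split` shows what a difficulty-stratified split could look like; the cell above uses a plain random split instead.

```python
import random
import statistics
from collections import defaultdict


def percentile_thresholds(scores):
    """Quartile cut points, equivalent to np.percentile(scores, [75, 50, 25])."""
    # method="inclusive" is the same linear interpolation as NumPy's default
    q25, q50, q75 = statistics.quantiles(scores, n=4, method="inclusive")
    return {"easy": q75, "medium": q50, "hard": q25}


def bucket(score, thresholds):
    """Same banding as the notebook: >= q75 easy, >= q50 medium, >= q25 hard."""
    if score >= thresholds["easy"]:
        return "easy"
    if score >= thresholds["medium"]:
        return "medium"
    if score >= thresholds["hard"]:
        return "hard"
    return "very_hard"


def stratified_split(ids, labels, eval_frac=0.2, seed=42):
    """Hypothetical helper: random split per label, so each band keeps ~eval_frac in eval."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for item_id, label in zip(ids, labels):
        by_label[label].append(item_id)
    train, evaluation = [], []
    for label in sorted(by_label):  # sorted for a deterministic iteration order
        group = list(by_label[label])
        rng.shuffle(group)
        n_eval = round(len(group) * eval_frac)
        evaluation.extend(group[:n_eval])
        train.extend(group[n_eval:])
    return train, evaluation


# Toy baseline scores for 10 queries
scores = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
thresholds = percentile_thresholds(scores)
labels = [bucket(s, thresholds) for s in scores]
train_ids, eval_ids = stratified_split(list(range(len(scores))), labels, eval_frac=0.2)
print(labels)
print(len(train_ids), len(eval_ids))  # 8 2
```

Because of the `>=` comparisons, all tied scores land in the easier band. This matters for the coarse heuristic fallback scores. With scikit-learn, passing `stratify=[ex["difficulty"] for ex in training_examples]` to `train_test_split` would give a comparable per-band split. With only a few queries per band, a small band can still end up with no eval examples, as the `hard` and `medium` bands do in the toy run.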

In [None]:
# Export the eval split for the comparative evaluation
# CRITICAL: export the eval-split query_ids in a deterministic, reproducible format
# This ensures the baseline and the GRPO model are evaluated on the same eval set

print("\n" + "="*80)
print("📤 Exporting Eval Split")
print("="*80)

# Extract query IDs from eval_dataset
eval_query_ids = []

for example in eval_dataset:
    prompt = example["prompt"]
    # Same regex parsing logic as in the reward function
    match = re.search(r'(?:SEARCH QUERY|Query)[:\s"]*([^"\n]+)', prompt)
    query_id = None
    if match:
        query_id = match.group(1)
    else:
        match2 = re.search(r'##\s*SEARCH\s*QUERY\s*\n\s*"([^"]+)"', prompt)
        if match2:
            query_id = match2.group(1)
    
    if query_id and query_id not in eval_query_ids:
        eval_query_ids.append(query_id)

# Sort for deterministic output
eval_query_ids.sort()

print(f"\n📊 Eval split statistics:")
print(f"  Number of eval queries: {len(eval_query_ids)}")
print(f"  Number of train queries: {len(train_dataset)}")
print(f"  Eval share: {len(eval_query_ids) / (len(eval_query_ids) + len(train_dataset)):.1%}")

# Export the eval split as JSON
eval_split_data = {
    "query_ids": eval_query_ids,
    "num_queries": len(eval_query_ids),
    "split_params": {
        "train_split": TRAIN_SPLIT,
        "seed": SEED,
        "shuffle": True
    },
    "created_at": pd.Timestamp.now().isoformat()
}

eval_split_path = OUTPUT_DIR / "eval_split.json"
with open(eval_split_path, 'w', encoding='utf-8') as f:
    json.dump(eval_split_data, f, ensure_ascii=False, indent=2)

print(f"\n✅ Eval split exported: {eval_split_path}")
print(f"   This file is used by the baseline vs. GRPO comparison notebook")
print(f"   so that both models are evaluated on the same eval set.")

# Optional: CSV export as well (simpler to load)
eval_split_csv = OUTPUT_DIR / "eval_split.csv"
eval_df = pd.DataFrame({"query_id": eval_query_ids})
eval_df.to_csv(eval_split_csv, index=False, encoding='utf-8')
print(f"‚úÖ Eval split CSV: {eval_split_csv}")

print("\n" + "="*80)
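The eval-split export above, the reward function and the dry run all recover the query from the prompt text with the same two regular expressions. A small self-contained sketch of that extraction shows its failure mode. The first pattern takes the leftmost match, so any earlier occurrence of the word "Query" in the instructions wins. The prompt strings below are hypothetical because `create_training_prompt` is not shown in this section, and `extract_query_id` is only an illustrative helper.

```python
import re

# The two patterns used by the reward function and the eval-split export
PRIMARY_PATTERN = re.compile(r'(?:SEARCH QUERY|Query)[:\s"]*([^"\n]+)')
FALLBACK_PATTERN = re.compile(r'##\s*SEARCH\s*QUERY\s*\n\s*"([^"]+)"')


def extract_query_id(prompt):
    """Return the query string embedded in a prompt, or None if no marker is found."""
    match = PRIMARY_PATTERN.search(prompt)
    if match:
        return match.group(1)
    match = FALLBACK_PATTERN.search(prompt)
    return match.group(1) if match else None


# Hypothetical prompt layouts; the real text comes from create_training_prompt (not shown)
good_prompt = '## SEARCH QUERY\n"unfair dismissal compensation"\n\n## DOCUMENTS\n[0] ...'
risky_prompt = 'Rank the chunks for the Query below.\n## SEARCH QUERY\n"unfair dismissal compensation"'

print(extract_query_id(good_prompt))   # unfair dismissal compensation
print(extract_query_id(risky_prompt))  # below.
```

If the real prompt template contains such a word, the reward function would take its `-0.1` branch for every completion. Storing `query_id` as an extra dataset column would avoid re-parsing entirely. Recent TRL versions forward extra dataset columns to reward functions as keyword arguments, but verify this for the installed version.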


In [None]:
# Load the model (memory efficient)
print(f"🔄 Loading model with Unsloth: {MODEL_NAME}")
print(f"  ⚙️  GPU Memory Utilization: {GPU_MEMORY_UTILIZATION}")
print(f"  ⚙️  Max LoRA Rank: {LORA_RANK}")

# Free memory before loading the model
gc.collect()
torch.cuda.empty_cache()

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,
    dtype=None,
    fast_inference=True,  # CRITICAL: Enable vLLM for GRPO rollouts (2-3x faster inference)
    max_lora_rank=LORA_RANK,
    gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    max_model_len=MAX_SEQ_LENGTH,  # CRITICAL: vLLM max_model_len must match max_seq_length for long prompts
    token=HF_TOKEN,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing=USE_GRADIENT_CHECKPOINTING,
    use_rslora=True,
    random_state=SEED,
)

# Tokenizer padding fix (tensor shape stability)
# CRITICAL: only set padding_side; do NOT override pad_token, because re-configuring the tokenizer
# can cause tensor-size mismatches with the GRPO trainer's internal vLLM logic
tokenizer.padding_side = "right"
# Note: pad_token is intentionally left as-is; the GRPO trainer and vLLM handle it
# An explicit pad_token override could conflict with the model config and cause tensor-shape problems

print("✅ Model and tokenizer loaded (Unsloth + vLLM + RSLoRA)")
print(f"  ✓ fast_inference=True: vLLM enabled for GRPO rollouts (2-3x faster)")
print(f"  ✓ Padding side: {tokenizer.padding_side}")
print(f"  ✓ Pad token ID: {tokenizer.pad_token_id if getattr(tokenizer, 'pad_token_id', None) is not None else 'N/A (auto-managed)'}")
print(f"  ✓ Note: the GRPO trainer handles the remaining tokenizer settings together with vLLM")
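`get_peft_model` above is called with `use_rslora=True`. In PEFT, rank-stabilized LoRA changes the adapter scaling from `lora_alpha / r` to `lora_alpha / sqrt(r)`. The calculation below assumes Unsloth forwards the flag to PEFT unchanged. With the header's rank 32 and alpha 64, the arithmetic works out as follows. `lora_scaling` is a hypothetical helper, not a PEFT function.

```python
import math


def lora_scaling(alpha, rank, use_rslora):
    """Scale applied to the low-rank update B @ A (hypothetical helper)."""
    # Standard LoRA: alpha / r; rank-stabilized LoRA (rsLoRA): alpha / sqrt(r)
    return alpha / math.sqrt(rank) if use_rslora else alpha / rank


alpha, rank = 64, 32  # LORA_ALPHA and LORA_RANK as listed in the notebook header
print(f"LoRA   scaling: {lora_scaling(alpha, rank, use_rslora=False):.3f}")  # 2.000
print(f"rsLoRA scaling: {lora_scaling(alpha, rank, use_rslora=True):.3f}")   # 11.314
```

Under these settings the adapter update is scaled about 5.7× (√32) more strongly than plain LoRA with the same alpha. A learning rate copied from a plain-LoRA recipe may therefore behave more aggressively here.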

In [None]:
# GPU memory monitoring after model load
if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated(0) / 1e9
    reserved = torch.cuda.memory_reserved(0) / 1e9
    # mem_get_info reports device-wide free/total memory, so it also accounts for
    # blocks cached by PyTorch's allocator and for non-PyTorch allocations
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    free = free_bytes / 1e9
    total = total_bytes / 1e9
    
    print(f"\n📊 GPU Memory After Model Load:")
    print(f"  Allocated: {allocated:.2f} GB")
    print(f"  Reserved: {reserved:.2f} GB")
    print(f"  Free: {free:.2f} GB")
    print(f"  Free %: {(free/total)*100:.1f}%")
    
    if free < 10:
        print(f"  ⚠️  WARNING: low free memory ({free:.2f} GB)!")
        print(f"  💡 Option: further reductions may be needed")


In [None]:
# ====== MULTI-METRIC REWARD FUNCTION (Sigmoid-Based) ======
# Convex combination of three core metrics with a sigmoid transform:
# - nDCG@10 (60%): ranking quality (increased weight)
# - MRR@5 (30%): first relevant hit (practical usefulness)
# - Recall@20 (10%): completeness
#
# Improvements:
# - sigmoid(delta): bounded, smooth mapping of each metric delta (intended to stabilize the
#   reward scale; note that a sigmoid does saturate for large deltas)
# - MRR tie-break bonus: if nDCG/Recall stagnate but MRR improves → +0.02
# - Diversity penalty: homogeneous top-10 (court/domain) → -ε

# Reward configuration (based on config.py)
REWARD_WEIGHTS = {
    "ndcg10": 0.60,    # 50% → 60%
    "mrr5": 0.30,      # 35% → 30%
    "recall20": 0.10,  # 15% → 10%
}
REWARD_CLIP_RANGE = (-1.0, 1.0)
REWARD_USE_SIGMOID = True
REWARD_MRR_TIEBREAK_BONUS = 0.02
REWARD_DIVERSITY_PENALTY = 0.05  # weight of the diversity penalty

# Global component log (for logging purposes)
reward_components_log = []

def multi_metric_reward_function(completions, prompts, **kwargs):
    """
    Multi-metric GRPO reward function with a sigmoid transform.
    
    The reward is a weighted combination of three metrics:
    1. nDCG@10 (60%): ranking quality
    2. MRR@5 (30%): position of the first relevant hit (UX priority!)
    3. Recall@20 (10%): coverage of relevant hits
    
    Special features:
    - sigmoid(delta): bounded, smooth reward signal
    - MRR tie-break: +bonus if only MRR improves (better user experience)
    - Diversity check: homogeneous top-10 → penalty
    
    Returns:
        List[float]: one reward value per completion
    """
    global reward_components_log
    rewards = []
    batch_components = []
    
    for completion, prompt in zip(completions, prompts):
        try:
            # 1. Query ID extraction
            match = re.search(r'(?:SEARCH QUERY|Query)[:\s"]*([^"\n]+)', prompt)
            query_id = None
            if match:
                query_id = match.group(1)
            else:
                match2 = re.search(r'##\s*SEARCH\s*QUERY\s*\n\s*"([^"]+)"', prompt)
                if match2:
                    query_id = match2.group(1)
            
            if not query_id or query_id not in slate_lookup:
                # Parsing error: mild penalty
                rewards.append(-0.1)
                batch_components.append({"error": "query_id_not_found"})
                continue
            
            # 2. Slate lookup
            slate = slate_lookup[query_id]
            
            # 3. Parse model output
            predicted_indices = parse_model_ranking(completion, len(slate))
            
            # 4. Look up baseline metrics
            if query_id not in baseline_metrics_dict:
                # No baseline data: fall back to a zero baseline (not ideal but safe; every
                # delta is then >= 0, which inflates the rewards of these queries)
                baseline_metrics = {
                    "ndcg@10": 0.0,
                    "mrr@5": 0.0,
                    "mrr@10": 0.0,  # kept for compatibility
                    "recall@20": 0.0,
                }
            else:
                baseline_metrics = baseline_metrics_dict[query_id]
            
            # 5. Multi-metric reward computation (helper function from src.reward_metrics)
            reward, components = compute_multi_metric_reward(
                predicted_indices=predicted_indices,
                slate=slate,
                baseline_metrics=baseline_metrics,
                weights=REWARD_WEIGHTS,
                clip_range=REWARD_CLIP_RANGE,
                use_sigmoid=REWARD_USE_SIGMOID,
                mrr_tiebreak_bonus=REWARD_MRR_TIEBREAK_BONUS,
                diversity_penalty_weight=REWARD_DIVERSITY_PENALTY,
            )
            
            rewards.append(reward)
            batch_components.append(components)
            
        except Exception as e:
            # On exception: mild penalty
            rewards.append(-0.1)
            batch_components.append({"error": str(e)[:100]})
    
    # Store components for logging
    reward_components_log.extend(batch_components)
    
    return rewards

print("✅ Multi-metric GRPO reward function defined (Sigmoid-Based v2.0)")
print(f"   📊 Reward weights:")
print(f"      • nDCG@10:   {REWARD_WEIGHTS['ndcg10']:.0%} (ranking quality)")
print(f"      • MRR@5:     {REWARD_WEIGHTS['mrr5']:.0%} (first relevant hit)")
print(f"      • Recall@20: {REWARD_WEIGHTS['recall20']:.0%} (completeness)")
print(f"   🎯 Sigmoid transform: {REWARD_USE_SIGMOID}")
print(f"   🎁 MRR tie-break bonus: +{REWARD_MRR_TIEBREAK_BONUS}")
print(f"   🌈 Diversity penalty: {REWARD_DIVERSITY_PENALTY}")
print(f"   📏 Reward clipping: {REWARD_CLIP_RANGE}")
print(f"   📈 Component tracking: Enabled")

In [None]:
# ====== DRY RUN VALIDATION - early check before full training ======
# Goal: validate the pipeline in 2-3 minutes (prompt, parsing, reward function, model output)
# This avoids a 25-35 minute training run with a broken configuration

print("="*80)
print("🔬 DRY RUN VALIDATION - early check")
print("="*80)
print("This step tests the pipeline in 2-3 minutes before full training.\n")

# Configuration
DRY_RUN_SAMPLE_SIZE = 10  # number of queries to test
DRY_RUN_MAX_NEW_TOKENS = 50  # short generations (quick test)
DRY_RUN_PARSING_SUCCESS_THRESHOLD = 0.8  # require 80% parsing success
DRY_RUN_MIN_REWARD_MEAN = -0.1  # minimum reward mean (very loose, validation only)

# 1. Mini-batch test: select sample queries
print("📊 1. Preparing mini-batch test...")
test_indices = np.random.choice(len(slates_data), size=min(DRY_RUN_SAMPLE_SIZE, len(slates_data)), replace=False)
test_slates = [slates_data[i] for i in test_indices]
test_prompts = [create_training_prompt(slate["query_id"], slate["slate"]) for slate in test_slates]

print(f"   ✅ {len(test_slates)} queries selected for testing")
print(f"   📋 Example query IDs:")
for i, slate in enumerate(test_slates[:3]):
    print(f"      {i+1}. {slate['query_id'][:60]}...")

# 2. Prompt parsing validation
print(f"\n📝 2. Prompt parsing validation...")
parsing_successes = 0
parsing_failures = []

for i, (prompt, slate_data) in enumerate(zip(test_prompts, test_slates)):
    # Test the query_id extraction (same as in the reward function)
    match = re.search(r'(?:SEARCH QUERY|Query)[:\s"]*([^"\n]+)', prompt)
    query_id = None
    if match:
        query_id = match.group(1)
    else:
        match2 = re.search(r'##\s*SEARCH\s*QUERY\s*\n\s*"([^"]+)"', prompt)
        if match2:
            query_id = match2.group(1)
    
    if query_id and query_id in slate_lookup:
        parsing_successes += 1
    else:
        parsing_failures.append({
            "index": i,
            "query_id_expected": slate_data["query_id"],
            "query_id_extracted": query_id,
        })

parsing_success_rate = parsing_successes / len(test_prompts) if test_prompts else 0.0
print(f"   ✅ Parsing success rate: {parsing_success_rate:.2%} ({parsing_successes}/{len(test_prompts)})")

if parsing_failures:
    print(f"   ⚠️  {len(parsing_failures)} parsing errors:")
    for failure in parsing_failures[:3]:
        print(f"      Query {failure['index']}: expected '{failure['query_id_expected'][:50]}...', got '{failure['query_id_extracted']}'")

# 3. Slate structure validation
print(f"\n📦 3. Slate structure validation...")
slate_issues = []
relevance_distributions = []

for i, slate_data in enumerate(test_slates):
    slate = slate_data.get("slate", [])
    if not slate:
        slate_issues.append({"index": i, "issue": "empty_slate"})
        continue
    
    relevances = [doc.get('relevance', 0) for doc in slate]
    relevant_count = sum(1 for r in relevances if r >= 1)
    high_rel_count = sum(1 for r in relevances if r == 2)
    
    relevance_distributions.append({
        "total": len(relevances),
        "relevant": relevant_count,
        "high_relevant": high_rel_count,
        "baseline_ndcg": calculate_ndcg(list(range(len(relevances))), relevances, k=10),
    })
    
    if relevant_count == 0:
        slate_issues.append({"index": i, "issue": "no_relevant_docs", "query": slate_data["query_id"][:50]})

if slate_issues:
    print(f"   ⚠️  {len(slate_issues)} slate problems found:")
    for issue in slate_issues[:3]:
        print(f"      Query {issue['index']}: {issue['issue']}")
else:
    print(f"   ✅ All slates valid")

avg_stats = {
    "total_docs": np.mean([d["total"] for d in relevance_distributions]),
    "relevant_docs": np.mean([d["relevant"] for d in relevance_distributions]),
    "high_relevant_docs": np.mean([d["high_relevant"] for d in relevance_distributions]),
    "baseline_ndcg": np.mean([d["baseline_ndcg"] for d in relevance_distributions]),
}
print(f"   üìä √Åtlagos slate statisztik√°k:")
print(f"      Dokumentumok/slate: {avg_stats['total_docs']:.1f}")
print(f"      Relev√°ns dokumentumok: {avg_stats['relevant_docs']:.1f}")
print(f"      Magas relevancia (rel=2): {avg_stats['high_relevant_docs']:.1f}")
print(f"      Baseline nDCG@10: {avg_stats['baseline_ndcg']:.4f}")

# 4. Model inference test (requires the model loaded above)
print(f"\n🤖 4. Model inference test...")
try:
    model.eval()
    test_completions = []
    
    # Test a single prompt (fast)
    test_prompt = test_prompts[0]
    inputs = tokenizer(test_prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LENGTH)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=DRY_RUN_MAX_NEW_TOKENS,
            do_sample=True,
            temperature=0.8,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
        )
    
    completion = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    test_slate = test_slates[0]
    parsed_indices = parse_model_ranking(completion, len(test_slate["slate"]))
    
    is_valid = len(parsed_indices) == len(test_slate["slate"]) and len(set(parsed_indices)) == len(parsed_indices)
    
    print(f"   ✅ Model inference succeeded")
    print(f"      Completion length: {len(completion)} characters")
    print(f"      Parsed indices: {len(parsed_indices)}/{len(test_slate['slate'])}")
    print(f"      Valid parsing: {is_valid}")
    print(f"      Sample output: {completion[:100]}...")
    
except Exception as e:
    print(f"   ⚠️  Model inference error (expected if the model is not loaded yet): {e}")
finally:
    # Switch back to training mode even if generation failed
    if "model" in globals():
        model.train()

# 5. Reward function test
print(f"\n🎯 5. Reward function test...")
test_rewards = []
test_reward_details = []

for prompt, slate_data in zip(test_prompts[:5], test_slates[:5]):  # full test on 5 queries only
    try:
        # Simulated completion (random ranking for the test)
        test_completion = ",".join(str(i) for i in np.random.permutation(len(slate_data["slate"])))
        
        # Simplified reward: plain nDCG@K delta against the fusion order
        # (NOT the multi-metric reward used in training; that one is validated in a later cell)
        query_id = slate_data["query_id"]
        slate = slate_lookup[query_id]
        relevance = [doc.get('relevance', 0) for doc in slate]
        baseline = list(range(len(slate)))
        predicted = parse_model_ranking(test_completion, len(slate))
        
        ndcg_baseline = calculate_ndcg(baseline, relevance, k=NDCG_K)
        ndcg_policy = calculate_ndcg(predicted, relevance, k=NDCG_K)
        reward = ndcg_policy - ndcg_baseline
        
        test_rewards.append(reward)
        test_reward_details.append({
            "ndcg_baseline": ndcg_baseline,
            "ndcg_policy": ndcg_policy,
            "reward": reward,
        })
    except Exception as e:
        test_rewards.append(-0.5)
        test_reward_details.append({"error": str(e)})

if test_rewards:
    reward_mean = np.mean(test_rewards)
    reward_std = np.std(test_rewards)
    print(f"   ✅ Reward computation works")
    print(f"      Reward mean: {reward_mean:.4f}")
    print(f"      Reward std: {reward_std:.4f}")
    print(f"      Reward range: [{min(test_rewards):.4f}, {max(test_rewards):.4f}]")
else:
    print(f"   ⚠️  Reward function test failed")

# 6. Summary and decision
print(f"\n" + "="*80)
print("📋 DRY RUN VALIDATION SUMMARY")
print("="*80)

validation_passed = True
validation_warnings = []  # not named `warnings`, to avoid shadowing the stdlib module

if parsing_success_rate < DRY_RUN_PARSING_SUCCESS_THRESHOLD:
    validation_passed = False
    validation_warnings.append(f"❌ Parsing success rate too low: {parsing_success_rate:.2%} < {DRY_RUN_PARSING_SUCCESS_THRESHOLD:.0%}")

if len(slate_issues) > len(test_slates) * 0.3:  # more than 30% of slates have problems
    validation_passed = False
    validation_warnings.append(f"❌ Too many slate problems: {len(slate_issues)}/{len(test_slates)}")

if test_rewards and np.mean(test_rewards) < DRY_RUN_MIN_REWARD_MEAN:
    validation_warnings.append(f"⚠️  Low reward mean: {np.mean(test_rewards):.4f} (non-blocking)")

if avg_stats['baseline_ndcg'] < 0.1:
    validation_warnings.append(f"⚠️  Low baseline nDCG: {avg_stats['baseline_ndcg']:.4f} (slate quality may be poor)")

print("\n✅ Positive indicators:")
print(f"   ✓ Parsing success: {parsing_success_rate:.2%}")
print(f"   ✓ Average baseline nDCG@10: {avg_stats['baseline_ndcg']:.4f}")
print(f"   ✓ Average relevant documents/slate: {avg_stats['relevant_docs']:.1f}")

if validation_warnings:
    print("\n⚠️  Warnings:")
    for warning in validation_warnings:
        print(f"   {warning}")

if validation_passed:
    print("\n✅ DRY RUN VALIDATION PASSED - training can proceed!")
    print("   The pipeline looks valid; you can start the full training.")
else:
    print("\n❌ DRY RUN VALIDATION FAILED - training NOT RECOMMENDED!")
    print("   Please fix the errors before starting training.")
    print("   Check:")
    print("   - prompt generation and query_id extraction")
    print("   - slate structure and relevance distribution")
    print("   - reward function behaviour")

print("\n" + "="*80)
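The dry run relies on `calculate_ndcg`, which is defined elsewhere in the notebook. To help read its numbers, here is a sketch of one common nDCG@k definition. It assumes exponential gain `2**rel - 1` and a `log2(position + 2)` discount, which may differ from the notebook's helper (for example, if that helper uses linear gain). `order` is a list of slate indices in ranked order, matching how the dry run passes `list(range(len(relevances)))` as the fusion baseline. `ndcg_at_k` is an illustrative name.

```python
import math


def ndcg_at_k(order, relevances, k=10):
    """order: slate indices in ranked order; relevances: graded label per slate item."""
    def dcg(gains):
        # Position i (0-based) is discounted by log2(i + 2)
        return sum((2 ** g - 1) / math.log2(i + 2) for i, g in enumerate(gains[:k]))

    ideal = dcg(sorted(relevances, reverse=True))
    if ideal == 0:
        return 0.0  # no relevant items: nDCG is undefined, report 0
    return dcg([relevances[i] for i in order]) / ideal


relevances = [0, 2, 1]      # graded labels of a 3-item slate (0/1/2 as in the notebook)
baseline_order = [0, 1, 2]  # fusion order, as in the dry run's list(range(len(slate)))
print(round(ndcg_at_k(baseline_order, relevances, k=3), 4))  # 0.659
print(ndcg_at_k([1, 2, 0], relevances, k=3))                 # 1.0
```

A random permutation scored this way is usually worse than a reasonable fusion order. For that reason the dry run's random-ranking rewards are expected to average below zero, and the reward-mean check is only a warning.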


In [None]:
# GRPO Trainer configuration - MEMORY EFFICIENT + CURRICULUM LEARNING
grpo_config = GRPOConfig(
    output_dir=str(OUTPUT_DIR),
    max_steps=MAX_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    num_generations=NUM_GENERATIONS,
    generation_batch_size=GENERATION_BATCH_SIZE,  # Explicit: LCM(6, 2) = 6
    optim=OPTIMIZER_NAME,
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    loss_type="dr_grpo",  # Dr. GRPO loss (note: the notebook header lists "dapo")
    scale_rewards="batch",
    use_vllm=True,
    vllm_importance_sampling_correction=False,  # vLLM token logprobs do not match those of the training model
    importance_sampling_level="token",  # token-level ratios (the header lists "sequence"/GSPO)
    mask_truncated_completions=True,
    epsilon=0.1,  # Enhanced: 0.2 → 0.1 (smaller policy updates, more stable)
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False, "use_unsloth": True},
    max_grad_norm=1.0,
    logging_steps=LOGGING_STEPS,
    logging_first_step=True,
    eval_steps=EVAL_STEPS,  # only takes effect with eval_strategy="steps" (the default is "no")
    save_steps=SAVE_STEPS,
    dataloader_num_workers=1,  # 2→1: fewer workers
    seed=SEED,
    max_completion_length=768,  # CRITICAL: explicit completion length (TRL's default of 256 is too short for a full slate ranking)
)

grpo_config.unsloth_num_chunks = 1  # Unsloth-specific attribute, not a standard TRL option

print("\n✅ GRPO Trainer configuration ready (Memory Efficient + Curriculum Learning)")
print(f"  ✓ generation_batch_size={GENERATION_BATCH_SIZE}")
print(f"    - divisible by num_generations={NUM_GENERATIONS}: {GENERATION_BATCH_SIZE}÷{NUM_GENERATIONS}={GENERATION_BATCH_SIZE//NUM_GENERATIONS}")
print(f"    - divisible by per_device_batch={PER_DEVICE_BATCH_SIZE}: {GENERATION_BATCH_SIZE}÷{PER_DEVICE_BATCH_SIZE}={GENERATION_BATCH_SIZE//PER_DEVICE_BATCH_SIZE}")
print(f"  ✓ Effective batch size: {PER_DEVICE_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
print(f"  ✓ Memory optimization: {GPU_MEMORY_UTILIZATION:.0%} GPU util, 1 worker, rank {LORA_RANK}")
if CURRICULUM_ENABLED:
    print(f"  ✓ Curriculum Learning: {CURRICULUM_STRATEGY} strategy enabled")
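The divisibility checks printed above, and the number of prompts one optimizer step actually sees, can be worked out by hand. This sketch assumes that `per_device_train_batch_size` counts completions rather than prompts, as in recent TRL versions. That assumption is consistent with the cell above printing per-device batch × accumulation as the effective batch. The values come from the header (batch 2, accumulation 6, 6 generations, 600 steps), plus an assumed ~78 training queries (98 × 0.8, before any curriculum filtering). `grpo_batch_plan` is a hypothetical helper, not a TRL function.

```python
def grpo_batch_plan(per_device_batch, grad_accum, num_generations,
                    generation_batch_size, num_processes=1):
    """Hypothetical helper: sanity-check GRPO batch sizes (counted in completions)."""
    micro_batch = per_device_batch * num_processes
    # The same two divisibility conditions the config cell prints
    if generation_batch_size % num_generations != 0:
        raise ValueError("generation_batch_size must be divisible by num_generations")
    if generation_batch_size % micro_batch != 0:
        raise ValueError("generation_batch_size must be divisible by the per-step micro-batch")
    completions_per_step = micro_batch * grad_accum
    return {
        "completions_per_step": completions_per_step,
        "prompts_per_step": completions_per_step / num_generations,
        "micro_batches_per_generation": generation_batch_size // micro_batch,
    }


plan = grpo_batch_plan(per_device_batch=2, grad_accum=6, num_generations=6, generation_batch_size=6)
print(plan)  # {'completions_per_step': 12, 'prompts_per_step': 2.0, 'micro_batches_per_generation': 3}

max_steps, train_queries = 600, 78  # MAX_STEPS from the header; 78 is an assumed split size
passes = plan["prompts_per_step"] * max_steps / train_queries
print(f"{passes:.1f} passes over the training queries")  # 15.4 passes over the training queries
```

Under these assumptions each optimizer step covers only two distinct queries, and 600 steps revisit each training query roughly 15 times. Keep this in mind when judging overfitting against the eval split.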

## 🧪 Multi-Metric Reward Validation

Next we test the updated multi-metric reward function:
- the sigmoid transform
- computation of the components (nDCG/MRR/Recall)
- the MRR tie-break bonus
- the diversity check
- loading of the baseline metrics

In [None]:
# ====== MULTI-METRIC REWARD VALIDATION ======
# Quick test: does the multi-metric reward function behave correctly?

print("="*80)
print("🧪 MULTI-METRIC REWARD VALIDATION")
print("="*80)

# Test configuration
NUM_TEST_QUERIES = 5
test_results = []

print(f"\n📊 Testing {NUM_TEST_QUERIES} queries with the multi-metric reward...\n")

for i, slate_data in enumerate(slates_data[:NUM_TEST_QUERIES]):
    query_id = slate_data["query_id"]
    slate = slate_data["slate"]
    
    print(f"üîç Test {i+1}/{NUM_TEST_QUERIES}: {query_id[:60]}...")
    
    # 1. Baseline metrik√°k check
    if query_id in baseline_metrics_dict:
        baseline = baseline_metrics_dict[query_id]
        print(f"   ‚úÖ Baseline metrik√°k: nDCG@10={baseline.get('ndcg@10', 0):.4f}, MRR@5={baseline.get('mrr@5', 0):.4f}")
    else:
        print(f"   ‚ö†Ô∏è  Nincs baseline metrika erre a query-re (fallback to zero)")
        baseline = {"ndcg@10": 0.0, "mrr@5": 0.0, "recall@20": 0.0}
    
    # 2. Szimul√°lt policy ranking (random permutation)
    predicted_indices = list(np.random.permutation(len(slate)))
    
    # 3. Multi-metric reward sz√°m√≠t√°s
    try:
        reward, components = compute_multi_metric_reward(
            predicted_indices=predicted_indices,
            slate=slate,
            baseline_metrics=baseline,
            weights=REWARD_WEIGHTS,
            clip_range=REWARD_CLIP_RANGE,
            use_sigmoid=REWARD_USE_SIGMOID,
            mrr_tiebreak_bonus=REWARD_MRR_TIEBREAK_BONUS,
            diversity_penalty_weight=REWARD_DIVERSITY_PENALTY,
        )
        
        # 4. Results
        print(f"   📈 Policy metrics:")
        print(f"      nDCG@10:   {components['policy_ndcg10']:.4f} (baseline: {components['baseline_ndcg10']:.4f}, Δ={components['delta_ndcg10']:+.4f})")
        print(f"      MRR@5:     {components['policy_mrr5']:.4f} (baseline: {components['baseline_mrr5']:.4f}, Δ={components['delta_mrr5']:+.4f})")
        print(f"      Recall@20: {components['policy_recall20']:.4f} (baseline: {components['baseline_recall20']:.4f}, Δ={components['delta_recall20']:+.4f})")
        
        print(f"   🎯 Reward components:")
        print(f"      nDCG comp:  {components['component_ndcg10']:+.4f} ({REWARD_WEIGHTS['ndcg10']:.0%} weight)")
        print(f"      MRR comp:   {components['component_mrr5']:+.4f} ({REWARD_WEIGHTS['mrr5']:.0%} weight)")
        print(f"      Recall comp:{components['component_recall20']:+.4f} ({REWARD_WEIGHTS['recall20']:.0%} weight)")
        
        print(f"   💰 Total reward: {components['reward_total']:.4f} (raw: {components['reward_raw']:.4f})")
        
        test_results.append({
            "query_id": query_id,
            "reward": reward,
            "components": components,
            "success": True,
        })
        
    except Exception as e:
        print(f"   ❌ Error: {e}")
        test_results.append({
            "query_id": query_id,
            "error": str(e),
            "success": False,
        })
    
    print()

# Summary
success_count = sum(1 for r in test_results if r.get("success", False))
success_rate = success_count / len(test_results) if test_results else 0.0

print("="*80)
print("📊 VALIDATION SUMMARY")
print("="*80)
print(f"\n✅ Successful reward computations: {success_count}/{len(test_results)} ({success_rate:.0%})")

if success_count > 0:
    successful_results = [r for r in test_results if r.get("success")]
    successful_rewards = [r["reward"] for r in successful_results]
    print(f"\n📈 Reward statistics:")
    print(f"   Mean:   {np.mean(successful_rewards):.4f}")
    print(f"   Std:    {np.std(successful_rewards):.4f}")
    print(f"   Min:    {np.min(successful_rewards):.4f}")
    print(f"   Max:    {np.max(successful_rewards):.4f}")
    print(f"   Median: {np.median(successful_rewards):.4f}")
    
    # Component averages over the successful results (the first result may have failed)
    avg_delta_ndcg = np.mean([r["components"]["delta_ndcg10"] for r in successful_results])
    avg_delta_mrr = np.mean([r["components"]["delta_mrr5"] for r in successful_results])
    avg_delta_recall = np.mean([r["components"]["delta_recall20"] for r in successful_results])
    
    print(f"\n📊 Average delta metrics:")
    print(f"   Δ nDCG@10:   {avg_delta_ndcg:+.4f}")
    print(f"   Δ MRR@5:     {avg_delta_mrr:+.4f}")
    print(f"   Δ Recall@20: {avg_delta_recall:+.4f}")

if success_rate >= 0.8:
    print(f"\n✅ MULTI-METRIC REWARD VALIDATION PASSED!")
    print(f"   The reward function works; training can proceed.")
else:
    print(f"\n❌ MULTI-METRIC REWARD VALIDATION FAILED!")
    print(f"   Check the baseline metrics loading and the reward function implementation.")

print("="*80)
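With 20-chunk slates, reordering cannot change recall@20 if recall is measured against the slate's own relevant chunks, because every permutation puts all 20 chunks into the top 20. Whether `compute_multi_metric_reward` computes recall this way is an assumption, since its code is not shown. If it does, the Recall@20 term only shifts each query's reward by a constant. In that case the printed Δ Recall@20 reflects how the baseline value was computed, not the policy. `recall_at_k` below is a hypothetical helper.

```python
import random


def recall_at_k(order, relevances, k=20, threshold=1):
    """Fraction of the slate's relevant items (label >= threshold) ranked in the top k."""
    relevant = {i for i, rel in enumerate(relevances) if rel >= threshold}
    if not relevant:
        return 0.0
    return len(relevant & set(order[:k])) / len(relevant)


relevances = [0, 2, 1, 0, 1] + [0] * 15  # hypothetical 20-chunk slate with 3 relevant chunks
rng = random.Random(0)
orders = [rng.sample(range(20), 20) for _ in range(5)]  # five random permutations

print({recall_at_k(order, relevances, k=20) for order in orders})  # {1.0}
print(recall_at_k(list(range(20)), relevances, k=3))               # 0.6666666666666666
```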

In [None]:
# ====== CURRICULUM LEARNING - Simplified Approach ======
# Mivel a TRL trainer nem t√°mogatja a dataset dinamikus v√°ltoztat√°s√°t k√∂zben,
# k√©t megold√°s van:
# 1. Weighted sampling: k√∂nnyebb p√©ld√°kat gyakrabban v√°lasztjuk (implement√°lva)
# 2. Filtered dataset: csak az easy/medium p√©ld√°kat haszn√°ljuk (simpler, de hat√°sos)

CURRICULUM_STRATEGY = "weighted_sampling"  # vagy "filtered_dataset"

if CURRICULUM_ENABLED and CURRICULUM_STRATEGY == "filtered_dataset":
    # Egyszer≈± strat√©gia: csak easy + medium p√©ld√°kat haszn√°lunk (legjobb baseline nDCG)
    # Ez biztos√≠tja, hogy a modell el≈ësz√∂r megtanulja a j√≥l rangsorolni, amikor van mit rangsorolni
    print("\nüìä Curriculum Learning: Filtered Dataset strat√©gia")
    print("   Csak easy + medium difficulty p√©ld√°kat haszn√°lunk a training-ben")
    
    filtered_train_examples = [ex for ex in training_examples 
                              if ex["difficulty"] in ["easy", "medium"]]
    filtered_train_indices = [i for i in train_indices 
                             if training_examples[i]["difficulty"] in ["easy", "medium"]]
    
    train_dataset = full_dataset.select(filtered_train_indices)
    
    print(f"   Eredeti training examples: {len(training_examples)}")
    print(f"   Filtered training examples: {len(filtered_train_examples)} ({100*len(filtered_train_examples)/len(training_examples):.1f}%)")
    print(f"   √Åtlagos baseline nDCG (filtered): {np.mean([ex['baseline_ndcg'] for ex in filtered_train_examples]):.4f}")
    print(f"   vs. Eredeti √°tlag: {np.mean([ex['baseline_ndcg'] for ex in training_examples]):.4f}")

# Initialize the trainer
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[multi_metric_reward_function],  # multi-metric reward
    args=grpo_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    processing_class=tokenizer,  # GRPOTrainer takes the tokenizer via processing_class
)

print("\n✅ GRPO Trainer initialized")
print("  ✓ Reward function: Multi-Metric (nDCG@10 + MRR@5 + Recall@20)")
if CURRICULUM_ENABLED:
    print(f"  ✓ Curriculum Learning: {CURRICULUM_STRATEGY}")
    print(f"  ✓ Training queries: {len(train_dataset)}")
else:
    print(f"  Training queries: {len(train_dataset)}")
print(f"  Eval queries: {len(eval_dataset)}")
print(f"  GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
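The curriculum cell above lists weighted sampling as an option, but only the filtered-dataset branch is handled there. Because `GRPOTrainer` uses its own sampler, one simple way to approximate weighted sampling is to oversample easier examples by repeating their indices before calling `full_dataset.select(...)`. The sketch below is a minimal illustration under that assumption; the helper name and the per-difficulty repeat counts (easy ×3, medium ×2, hard ×1) are hypothetical, not values from this project.

```python
import random
from typing import Dict, List, Optional, Sequence

# Hypothetical repeat counts per difficulty label (illustrative, not tuned values)
DEFAULT_WEIGHTS: Dict[str, int] = {"easy": 3, "medium": 2, "hard": 1}


def build_weighted_indices(
    indices: Sequence[int],
    difficulties: Sequence[str],
    weights: Optional[Dict[str, int]] = None,
    seed: int = 42,
) -> List[int]:
    """Repeat each dataset index by its difficulty weight, then shuffle.

    indices[i] is a row position in the full dataset and difficulties[i] its label;
    labels missing from `weights` get a weight of 1.
    """
    weights = DEFAULT_WEIGHTS if weights is None else weights
    weighted: List[int] = []
    for idx, diff in zip(indices, difficulties):
        weighted.extend([idx] * weights.get(diff, 1))
    random.Random(seed).shuffle(weighted)  # randomize the order of the repeated indices
    return weighted


# Usage with toy labels (not real project data)
train_idx = [0, 1, 2, 3]
labels = ["easy", "hard", "medium", "easy"]
weighted_idx = build_weighted_indices(train_idx, labels)
print(len(weighted_idx), sorted(weighted_idx))
```

In the notebook this would correspond to something like `full_dataset.select(build_weighted_indices(train_indices, [training_examples[i]["difficulty"] for i in train_indices]))` before the trainer is created. Oversampling enlarges the dataset, so it only changes the amount of training when training is bounded by epochs rather than by `max_steps`.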

In [None]:
# Start training
import traceback

print("\n🚀 STARTING GRPO TRAINING\n")
print("="*60)

if CURRICULUM_ENABLED:
    print(f"📚 Curriculum learning active: {CURRICULUM_STRATEGY}")
    if CURRICULUM_STRATEGY == "filtered_dataset":
        print("   Training on easier examples (easy + medium) for a stronger learning signal")
    print()

try:
    trainer.train()
    print("\n" + "="*60)
    print("✅ Training completed successfully!")
except Exception as e:
    print(f"\n❌ Training error: {e}")
    traceback.print_exc()
    raise
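The training cell above notes that easier examples give a stronger learning signal. In GRPO, each completion is scored relative to the mean reward of its own group (the completions generated for the same prompt), so a group whose completions all receive the same reward gets zero advantage and contributes no policy-gradient signal (apart from any KL penalty). The NumPy sketch below illustrates this. It assumes rewards are centred per group and divided by the standard deviation of all rewards in the batch, which is meant to mirror the idea behind the `"batch"` reward scaling listed in this notebook's configuration; it is a simplification, and TRL's actual implementation may differ in details such as the epsilon. The group size of 3 and the reward values are toy values (the notebook uses 6 generations per prompt).

```python
import numpy as np


def group_relative_advantages(rewards: np.ndarray, eps: float = 1e-4) -> np.ndarray:
    """rewards has shape (num_prompts, num_generations).

    Advantage = reward minus its group's mean, scaled by the std of all rewards
    in the batch (a simplified "batch" scaling).
    """
    centered = rewards - rewards.mean(axis=1, keepdims=True)
    batch_std = rewards.std()
    return centered / (batch_std + eps)


# Toy batch: 2 prompts x 3 completions each
rewards = np.array([
    [0.2, 0.5, 0.8],  # completions differ -> non-zero advantages
    [0.4, 0.4, 0.4],  # all completions tie -> zero advantages, no learning signal
])
adv = group_relative_advantages(rewards)
print(np.round(adv, 3))
```

This is why slates on which every sampled ranking scores the same, for example slates with no relevant chunk, add little to training.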

In [3]:
# Robust GRPO training visualization (fault-tolerant)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json

# Prefer the TRL trainer's log_history; if it is unavailable, fall back to metrics.json
log_history = None
try:
    _ = trainer  # check whether the trainer exists in this kernel
    log_history = getattr(trainer.state, "log_history", None)
except NameError:
    log_history = None

# Without log_history, try to plot at least the reward trend from metrics.json
if not log_history:
    metrics_path = METRICS_FILE if 'METRICS_FILE' in globals() else Path("data/models/grpo_policy/metrics.json")
    if Path(metrics_path).exists():
        with open(metrics_path, 'r', encoding='utf-8') as f:
            metrics = json.load(f)
        rewards_trend = metrics.get("training_rewards", {}).get("trend", [])
        if rewards_trend:
            plt.figure(figsize=(8, 4))
            plt.plot(rewards_trend, color="royalblue")
            plt.title("Reward trend (metrics.json)")
            plt.xlabel("Step")
            plt.ylabel("Reward")
            plt.tight_layout()
            plt.show()
        else:
            print("⚠️ No trainer.state.log_history available, and metrics.json contains no 'training_rewards.trend'.")
    else:
        print("⚠️ No trainer.state.log_history available, and the metrics.json file was not found either.")
else:
    # From here on, work from log_history
    logs = pd.DataFrame(log_history)
    if logs.empty:
        print("⚠️ The log DataFrame is empty.")
    else:
        # Keep only numeric columns
        numeric_cols = [c for c in logs.columns if np.issubdtype(pd.Series(logs[c]).dropna().dtype, np.number)]
        if not numeric_cols:
            print("⚠️ The logs contain no numeric columns to plot.")
        else:
            logs = logs[numeric_cols].copy()

            def pick_col(df, candidates):
                for c in candidates:
                    if c in df.columns:
                        return c
                return None

            # Pick the mean-reward column from several possible keys (key names vary across TRL versions)
            reward_candidates = [
                "rewards/multi_metric_reward_function/mean",  # per-function key (newer TRL)
                "rewards/multi_metric_reward_function",  # per-function key (older TRL)
                "rewards/reward_function/mean",
                "rewards/mean",
                "reward/mean",
                "rewards/reward",
                "reward",  # overall mean reward
            ]
            reward_col = pick_col(logs, reward_candidates)

            # Reward std column (optional)
            reward_std_candidates = [
                "rewards/multi_metric_reward_function/std",
                "rewards/reward_function/std",
                "rewards/std",
                "reward/std",
                "reward_std",  # overall reward std
            ]
            reward_std_col = pick_col(logs, reward_std_candidates)

            # Loss and KL columns (optional)
            loss_col = pick_col(logs, ["loss", "train/loss"])
            kl_col = pick_col(logs, ["kl", "train/kl", "kl_divergence"])

            # Build the figure dynamically from the metrics that are available
            fig, axes = plt.subplots(2, 2, figsize=(13, 8))
            axes = axes.flatten()
            ax_idx = 0

            # 1) Reward trend (rolling mean over 10 log entries) + optional std band
            if reward_col is not None:
                reward_smooth = logs[reward_col].rolling(window=10, min_periods=1).mean()
                axes[ax_idx].plot(reward_smooth, label=f"Reward (mean, {reward_col})", color="royalblue")
                if reward_std_col is not None:
                    lo = reward_smooth - logs[reward_std_col]
                    hi = reward_smooth + logs[reward_std_col]
                    axes[ax_idx].fill_between(logs.index, lo, hi, alpha=0.2, color="royalblue", label="± std")
                axes[ax_idx].set_title("Reward trend")
                axes[ax_idx].set_xlabel("Log entry")
                axes[ax_idx].set_ylabel("Reward")
                axes[ax_idx].legend()
                ax_idx += 1
            else:
                print("ℹ️ No mean-reward column found in the logs; skipping the reward curve.")

            # 2) Loss curve
            if loss_col is not None and ax_idx < len(axes):
                axes[ax_idx].plot(logs[loss_col], label="Loss", color="firebrick")
                axes[ax_idx].set_title("Training loss")
                axes[ax_idx].set_xlabel("Log entry")
                axes[ax_idx].set_ylabel("Loss")
                axes[ax_idx].legend()
                ax_idx += 1

            # 3) KL curve
            if kl_col is not None and ax_idx < len(axes):
                axes[ax_idx].plot(logs[kl_col], label="KL divergence", color="darkorange")
                axes[ax_idx].set_title("KL divergence over training")
                axes[ax_idx].set_xlabel("Log entry")
                axes[ax_idx].set_ylabel("KL")
                axes[ax_idx].legend()
                ax_idx += 1

            # 4) Reward std on its own panel (if available and a slot is left)
            if reward_std_col is not None and ax_idx < len(axes):
                axes[ax_idx].plot(logs[reward_std_col], label="Reward std", color="seagreen")
                axes[ax_idx].set_title("Reward std")
                axes[ax_idx].set_xlabel("Log entry")
                axes[ax_idx].set_ylabel("Std")
                axes[ax_idx].legend()
                ax_idx += 1

            # Hide any unused axes when fewer metrics are available
            for i in range(ax_idx, len(axes)):
                axes[i].axis('off')

            plt.tight_layout()
            plt.show()


⚠️ No trainer.state.log_history available, and the metrics.json file was not found either.
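The plotting cell above uses the DataFrame row position as its x-axis. `trainer.state.log_history` is a list of dicts that can also contain evaluation entries and a final training summary, so a series indexed by `step` can be easier to read. The sketch below keeps only entries that contain a given key, indexes them by `step`, and applies a rolling mean like the one in the cell above; the helper name is hypothetical and the toy log entries are made-up illustrations, not real training output.

```python
from typing import Any, Dict, List

import pandas as pd


def metric_by_step(log_history: List[Dict[str, Any]], key: str, window: int = 10) -> pd.Series:
    """Return a rolling-mean series of `key`, indexed by training step.

    Entries without `key` (e.g. eval or summary entries) are dropped.
    """
    rows = [(entry["step"], entry[key]) for entry in log_history
            if key in entry and "step" in entry]
    series = pd.Series([v for _, v in rows], index=[s for s, _ in rows], name=key, dtype=float)
    return series.rolling(window=window, min_periods=1).mean()


# Toy entries shaped like a Trainer log_history (values are made up)
toy_history = [
    {"step": 1, "reward": 0.30, "loss": 0.01},
    {"step": 2, "reward": 0.50, "loss": 0.02},
    {"step": 2, "eval_reward": 0.45},  # eval entry: no "reward" key
    {"step": 3, "reward": 0.70, "loss": 0.01},
]
smoothed = metric_by_step(toy_history, "reward", window=2)
print(smoothed.to_dict())
```

Plotting `smoothed` with `plt.plot(smoothed.index, smoothed.values)` would put actual training steps on the x-axis.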


In [None]:
# OPTIONAL: Prompt Quality Check - Sample Completions
# Run this BEFORE training to see what kind of outputs the model generates

print("🔬 PROMPT QUALITY CHECK - Generating sample completions\n")
print("This cell checks what outputs the model generates with the new prompt.")
print("NOTE: This is optional and only meant for debugging.\n")

# Pick a random query (local RNG, so the global random state is left untouched)
import random
rng = random.Random(42)
test_slate_idx = rng.randint(0, len(slates_data)-1)
test_slate_data = slates_data[test_slate_idx]
test_prompt_text = create_training_prompt(test_slate_data["query_id"], test_slate_data["slate"])

# Generate a few completions
print(f"Query: {test_slate_data['query_id'][:80]}...")
print(f"Slate size: {len(test_slate_data['slate'])}")
print("-" * 80)

try:
    # Switch the model to inference mode
    model.eval()

    # Tokenize the prompt
    inputs = tokenizer(test_prompt_text, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LENGTH)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    print("\n🎲 Generating 3 sample completions...\n")

    for i in range(3):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=768,  # raised from 100 to 768 so a full slate ranking fits
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,  # stop generation at EOS
            )

        # Decode only the newly generated tokens (skip the prompt)
        completion = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
        parsed = parse_model_ranking(completion, len(test_slate_data["slate"]))

        print(f"Completion {i+1}:")
        print(f"  Raw output: {completion[:150]}")
        print(f"  Parsed indices: {parsed[:10]}")  # first 10 indices
        # Valid = one index per slate item, with no duplicates
        print(f"  Valid? {len(parsed) == len(test_slate_data['slate']) and len(set(parsed)) == len(parsed)}")
        print()

    # Ground-truth relevance info
    relevances = [doc.get('relevance', 0) for doc in test_slate_data["slate"]]
    print("📊 Ground truth relevance distribution:")
    print(f"  Relevant docs (rel>=1): {sum(1 for r in relevances if r >= 1)}/{len(relevances)}")
    print(f"  Highly relevant (rel=2): {sum(1 for r in relevances if r == 2)}/{len(relevances)}")
    print(f"  Baseline order nDCG@10: {calculate_ndcg(list(range(len(relevances))), relevances, k=10):.4f}")

    print("\n✅ Prompt quality check complete!")

except Exception as e:
    print(f"❌ Error during quality check: {e}")
    print("This is not a blocker; training can still run.")
finally:
    # Always restore training mode, even if generation failed
    model.train()

print("\n" + "="*80)
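The quality check treats a completion as valid when the parsed ranking contains one index per slate item with no duplicates. The notebook's `parse_model_ranking` is defined elsewhere and is not reproduced here. The sketch below is a hypothetical, minimal parser that assumes the model lists 0-based slate indices as integers in its output; it keeps in-range indices in order of first appearance and applies the same validity test as the cell above.

```python
import re
from typing import List


def parse_ranking_sketch(completion: str, slate_size: int) -> List[int]:
    """Extract integers from a completion, keep 0 <= idx < slate_size, drop repeats."""
    seen = set()
    ranking: List[int] = []
    for token in re.findall(r"\d+", completion):
        idx = int(token)
        if 0 <= idx < slate_size and idx not in seen:
            seen.add(idx)
            ranking.append(idx)
    return ranking


def is_valid_ranking(ranking: List[int], slate_size: int) -> bool:
    """Same check as the quality-check cell: full length and no duplicates."""
    return len(ranking) == slate_size and len(set(ranking)) == len(ranking)


# Usage: a 4-item slate; "7" is out of range and the second "1" is a duplicate
parsed = parse_ranking_sketch("Ranking: [3, 1, 1, 7, 0]", slate_size=4)
print(parsed, is_valid_ranking(parsed, 4))
```

Because this parser already removes duplicates, only the length condition can fail for its output; the duplicate check matters for parsers that keep repeated indices.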

In [None]:
# Save artifacts
print("\n💾 Saving artifacts...")
model.save_pretrained_merged(str(OUTPUT_DIR), tokenizer, save_method="lora")  # save only the LoRA adapter
print(f"  ✅ LoRA adapter: {OUTPUT_DIR}")

## 📊 Model Comparison

Evaluation and the baseline vs. GRPO comparison are done in the **`notebooks/model_comparison.ipynb`** notebook.

This ensures that the baseline and GRPO models are evaluated on **the same eval set**, so their metrics are directly comparable.


In [None]:
# NOTE: Evaluation happens in notebooks/model_comparison.ipynb.
# This ensures the baseline and GRPO models are evaluated on the same eval set.
