# CourtRankRL GRPO Training - Chunk-Based, A100 Memory Efficient

## Agents.md Specifikáció (Chunk-Based)

Ez a notebook a CourtRankRL GRPO alapú reranking modell tanítását végzi el **A100 GPU-n** (80GB VRAM, memory-constrained).

### Főbb jellemzők (Chunk-Based megoldás):
- **Model**: Qwen/Qwen3-4B-Instruct-2507 (4-bit) + QLoRA (rank=32, alpha=64)
- **Training**: TRL GRPOTrainer GRPO algoritmussal
  - Loss: "dapo" (eliminates length bias)
  - Reward scaling: "batch" (robust - PPO Lite)
  - Importance sampling: "sequence" (stable - GSPO)
- **Dataset**: 98 query (teljes), 20 chunk/slate, **TELJES chunk szöveg** (~500-800 char)
- **Slate stratégia**: Chunk-level retrieval (nem doc aggregáció!) → legrelevánsabb chunk-ok
- **Baseline**: Slate sorrendje = fusion ranking [0,1,2,...] (BM25+FAISS fusion szerint)
- **Hardware**: Batch size 2, grad accumulation 6, 6 generations/prompt (MEMORY EFFICIENT)
- **Training time**: ~25-35 perc (600 steps, vLLM-mel, csökkentett gen miatt)

### Miért chunk-based?
- ✅ **Releváns kontextus**: BM25+FAISS már kiválasztotta a legrelevánsabb chunk-okat
- ✅ **Teljes szöveg**: A model látja, MIÉRT releváns egy dokumentum
- ✅ **Jobb tanulás**: A model megtanulja értékelni a valódi tartalmat, nem csak metaadatokat


In [None]:
# Environment setup and package installation.
# CRITICAL: install mutually compatible versions (Unsloth + TRL + dependencies).
# The order of the %pip lines below is intentional - do not reorder them.

# 1. Core dependencies FIRST (stable versions).
# NOTE(review): torch is installed from the cu121 wheel index here, but the later
# unsloth / vllm installs may pull in a different torch build - check the version
# printed at the end of this cell.
%pip install -q --upgrade pip
%pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
%pip install -q transformers datasets huggingface_hub
%pip install -q numpy scipy scikit-learn

# 2. Accelerate - minimum version pinned for `is_fsdp2` compatibility.
# CRITICAL: accelerate>=0.30.0 is required for the `is_fsdp2` attribute, but it
# must also remain compatible with Unsloth.
%pip install -q "accelerate>=0.30.0" --upgrade

# 3. LoRA and quantization.
%pip install -q peft bitsandbytes

# 4. TRL (latest stable release, provides the GRPO trainer).
%pip install -q --upgrade trl

# 5. Install Unsloth LAST so that it is resolved against the TRL installed above.
%pip install -q --upgrade unsloth diffusers

# 6. vLLM (optional, but recommended for fast_inference rollouts).
%pip install -q vllm

# 7. Other dependencies.
%pip install -q --upgrade typing_extensions ranx

print("‚úÖ Csomagok telep√≠tve (kompatibilis verzi√≥k)")
print("üì¶ Telep√≠tett verzi√≥k ellen≈ërz√©se:")
# Import after installing so the versions that were actually installed are reported.
# NOTE(review): packages upgraded inside an already-running kernel may need a
# kernel restart before the new versions are the ones imported.
import unsloth, torch, transformers, trl, accelerate
print(f"  - PyTorch: {torch.__version__}")
print(f"  - Transformers: {transformers.__version__}")
print(f"  - Accelerate: {accelerate.__version__} (CRITICAL: >=0.30.0 az is_fsdp2 t√°mogat√°shoz)")
print(f"  - TRL: {trl.__version__}")
print(f"  - Unsloth: {unsloth.__version__}")

# IMPORTANT: warn when the installed accelerate is older than 0.30.
# Only the first two dot-separated components are parsed, each with int(); a
# version string such as "1.2rc1" would make this line raise ValueError.
# NOTE(review): the message refers to a monkey patch applied during training;
# that code is not part of this cell.
if tuple(map(int, accelerate.__version__.split('.')[:2])) < (0, 30):
    print(f"‚ö†Ô∏è  FIGYELEM: Accelerate verzi√≥ ({accelerate.__version__}) < 0.30.0")
    print(f"   A monkey patch fog futni a training sor√°n az is_fsdp2 kompatibilit√°s√©rt")
    print(f"   Aj√°nlott: pip install accelerate>=0.30.0 --upgrade && RESTART KERNEL")


In [None]:
# Imports.
# The environment flags are set BEFORE the heavy imports below because libraries
# may read them when they are first imported.
# NOTE(review): the install cell above already imported torch/unsloth in this
# kernel, so flags that are read at import time may not take effect until the
# kernel is restarted and this cell runs first.
import os, warnings
os.environ["TORCHDYNAMO_DISABLE"] = "1"      # flag that disables TorchDynamo
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # harmless here; only affects the Apple MPS fallback path
# (optional - only relevant if Unsloth reads this flag)
os.environ["UNSLOTH_COMPILE"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Ignore warnings whose message matches the "output ... was resized" pattern.
warnings.filterwarnings("ignore", message=".*An output with one or more elements was resized.*")

import json
import sys
import re
import random
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from unsloth import FastLanguageModel
from trl.trainer.grpo_trainer import GRPOTrainer
from trl.trainer.grpo_config import GRPOConfig
from huggingface_hub import login
from sklearn.metrics import ndcg_score
from scipy.stats import entropy as scipy_entropy
from sklearn.model_selection import train_test_split
from ranx import Qrels, Run, evaluate

# Report the runtime environment (library versions and GPU, if any).
print("‚úÖ Importok bet√∂ltve (Unsloth + TRL + sklearn + scipy + ranx)")
print(f"PyTorch verzi√≥: {torch.__version__}")
print(f"CUDA el√©rhet≈ë: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU mem√≥ria: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


In [None]:
# Imports (second copy).
# NOTE(review): this cell is a verbatim duplicate of the import cell directly
# before it. Running both only repeats the same environment flags, imports and
# printout, so one of the two can be deleted.
import os, warnings
os.environ["TORCHDYNAMO_DISABLE"] = "1"      # flag that disables TorchDynamo
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # harmless here; only affects the Apple MPS fallback path
# (optional - only relevant if Unsloth reads this flag)
os.environ["UNSLOTH_COMPILE"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Ignore warnings whose message matches the "output ... was resized" pattern.
warnings.filterwarnings("ignore", message=".*An output with one or more elements was resized.*")

import json
import sys
import re
import random
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from unsloth import FastLanguageModel
from trl.trainer.grpo_trainer import GRPOTrainer
from trl.trainer.grpo_config import GRPOConfig
from huggingface_hub import login
from sklearn.metrics import ndcg_score
from scipy.stats import entropy as scipy_entropy
from sklearn.model_selection import train_test_split
from ranx import Qrels, Run, evaluate

# Report the runtime environment (library versions and GPU, if any).
print("‚úÖ Importok bet√∂ltve (Unsloth + TRL + sklearn + scipy + ranx)")
print(f"PyTorch verzi√≥: {torch.__version__}")
print(f"CUDA el√©rhet≈ë: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU mem√≥ria: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


In [None]:
# Configuration - MEMORY EFFICIENT (A100 80GB) + CURRICULUM LEARNING.
# Later cells read these module-level constants directly.
MODEL_NAME = "unsloth/Qwen3-4B-Instruct-2507"
SLATE_SIZE = 20  # candidate chunks per query; also parse_model_ranking's default slate size

# Curriculum learning.
CURRICULUM_ENABLED = True  # enabled - expected to give the largest improvement
CURRICULUM_STRATEGY = "filtered_dataset"  # "filtered_dataset" or "weighted_sampling"
# weighted_sampling: easier examples are selected more often, so the model
# gradually learns to rank well.
GROUP_SIZE = 6
LORA_RANK = 32  # reduced from 64 to save memory
LORA_ALPHA = LORA_RANK  # alpha tracks the rank
LORA_DROPOUT = 0.05
# NOTE(review): an earlier comment mentioned ~9600-token prompts plus a margin,
# which does not fit into 9216 tokens - confirm actual prompt lengths.
MAX_SEQ_LENGTH = 9216
# Passed as gpu_memory_utilization to FastLanguageModel.from_pretrained;
# lowered to leave headroom given the free memory available.
GPU_MEMORY_UTILIZATION = 0.6
USE_GRADIENT_CHECKPOINTING = "unsloth"

LEARNING_RATE = 5e-5
MAX_STEPS = 300
SAVE_STEPS = 300  # equal to MAX_STEPS, i.e. presumably a single checkpoint at the end
EVAL_STEPS = 50
LOGGING_STEPS = 10
WARMUP_STEPS = 50
GRADIENT_ACCUMULATION_STEPS = 3
NUM_GENERATIONS = 6  # completions sampled per prompt (reduced from 14)
PER_DEVICE_BATCH_SIZE = 1
# CRITICAL: GENERATION_BATCH_SIZE must be divisible by both NUM_GENERATIONS and
# PER_DEVICE_BATCH_SIZE. This is validated right below instead of only being
# stated in a comment.
GENERATION_BATCH_SIZE = 6
OPTIMIZER_NAME = "paged_adamw_8bit"
LR_SCHEDULER_TYPE = "linear"  # linear instead of cosine for a more stable final convergence

NDCG_K = 10
ENTROPY_BONUS = 0.01
# Reward clipping range.
# NOTE(review): the reward cell defines its own REWARD_CLIP_RANGE = (-1.0, 1.0);
# confirm which of the two ranges is actually used.
REWARD_CLIP_MIN = -1.0
REWARD_CLIP_MAX = 2.0
TRAIN_SPLIT = 0.8
SEED = 42

BASE_PATH = Path(os.getenv("WORKSPACE_PATH", "/workspace"))

BASELINE_METRICS_PATH = BASE_PATH / "baseline_query_metrics.csv"
SLATE_FILE = BASE_PATH / "training_slates.jsonl"
OUTPUT_DIR = BASE_PATH / "artifacts" / "grpo_policy"
METRICS_FILE = OUTPUT_DIR / "metrics.json"

# Fail fast on batch sizes that cannot be split into whole GRPO groups.
if GENERATION_BATCH_SIZE % NUM_GENERATIONS or GENERATION_BATCH_SIZE % PER_DEVICE_BATCH_SIZE:
    raise ValueError(
        f"GENERATION_BATCH_SIZE={GENERATION_BATCH_SIZE} must be divisible by "
        f"NUM_GENERATIONS={NUM_GENERATIONS} and PER_DEVICE_BATCH_SIZE={PER_DEVICE_BATCH_SIZE}"
    )

print("üìã A100 80GB - Memory Efficient Konfigur√°ci√≥:")
print(f"  Model: {MODEL_NAME}")
print(f"  LoRA Rank: {LORA_RANK} (cs√∂kkentve mem√≥ria miatt)")
# Derive the percentage from the value instead of a hard-coded "(0.5 = 50%)".
print(f"  GPU Memory Utilization: {GPU_MEMORY_UTILIZATION} ({GPU_MEMORY_UTILIZATION:.0%})")
print(f"  Batch: {PER_DEVICE_BATCH_SIZE} √ó {GRADIENT_ACCUMULATION_STEPS} = {PER_DEVICE_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS} (effective)")
print(f"  Generations: {NUM_GENERATIONS}, Generation batch: {GENERATION_BATCH_SIZE}")
print(f"  Steps per generation: {GENERATION_BATCH_SIZE // PER_DEVICE_BATCH_SIZE}")
print(f"  Steps: {MAX_STEPS}")
print(f"  ‚ö†Ô∏è  Mem√≥ria optimaliz√°lt be√°ll√≠t√°sok akt√≠vak!")

# Both input files are required; stop here rather than failing later in training.
if not SLATE_FILE.exists():
    raise FileNotFoundError(f"‚ùå Slate f√°jl nem tal√°lhat√≥: {SLATE_FILE}")

if not BASELINE_METRICS_PATH.exists():
    raise FileNotFoundError(f"‚ùå Baseline metrik√°k f√°jl nem tal√°lhat√≥: {BASELINE_METRICS_PATH}")

OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


In [None]:
# Helper functions
def calculate_ndcg(ranked_indices: List[int], true_relevance: List[float], k: int = 10) -> float:
    """Compute nDCG@k for a ranking expressed as slate indices.

    The first ``k`` entries of ``ranked_indices`` receive strictly decreasing
    scores (``len(ranked_indices)``, ``len(ranked_indices) - 1``, ...). All other
    slate positions get score 0. The scores are then handed to sklearn's
    ``ndcg_score`` together with the graded relevance.

    Args:
        ranked_indices: Slate indices in ranked order (best first). Indices
            outside ``[0, len(true_relevance))`` are ignored. Negative indices
            are ignored too, instead of wrapping around to the end of the slate.
        true_relevance: Graded relevance for each slate position.
        k: Rank cut-off.

    Returns:
        nDCG@k in [0, 1]. Returns 0.0 in any of these cases:
        - nothing is relevant;
        - nothing valid is ranked;
        - sklearn rejects the input with ``ValueError``.
    """
    if not true_relevance or not ranked_indices or max(true_relevance) == 0:
        return 0.0
    y_true = np.asarray(true_relevance, dtype=float)
    max_score = len(ranked_indices)
    y_score = np.zeros_like(y_true, dtype=float)
    for position, idx in enumerate(ranked_indices[:k]):
        if 0 <= idx < len(y_true):
            y_score[idx] = max_score - position
    if not y_score.any():
        return 0.0
    try:
        return float(ndcg_score(y_true.reshape(1, -1), y_score.reshape(1, -1), k=k))
    except ValueError:
        # sklearn signals invalid input (e.g. too few documents) with ValueError;
        # anything else is a real bug and should surface.
        return 0.0

def parse_model_ranking(completion: str, slate_size: int = SLATE_SIZE) -> List[int]:
    """Parse a model completion into a full permutation of slate indices.

    Parsing is deliberately lenient. The text is split on commas and then on
    whitespace, and trailing ``.,;`` are stripped from each token. Every decimal
    token in ``[0, slate_size)`` is kept the first time it appears. If at least
    ``max(1, slate_size // 3)`` distinct indices are found, they come first and
    the missing indices are appended in baseline order (ascending). Otherwise the
    baseline order ``[0, 1, ..., slate_size - 1]`` is returned.

    NOTE(review): an unparsable completion therefore yields the baseline order.
    It scores like the baseline rather than being penalised - confirm this is the
    intended reward behaviour.

    Args:
        completion: Raw model output, e.g. ``"RANKING: 3,0,5,..."``. A non-string
            completion falls back to the baseline order.
        slate_size: Number of candidates in the slate.

    Returns:
        A list containing every index in ``range(slate_size)`` exactly once.
    """
    fallback = list(range(slate_size))
    if not isinstance(completion, str):
        return fallback

    ranked: List[int] = []
    seen = set()
    for part in completion.split(","):
        for word in part.split():
            token = word.rstrip(".,;")
            # isdecimal() (unlike isdigit()) only accepts characters int() can
            # parse, so e.g. a superscript "²" no longer aborts the whole parse.
            if not token.isdecimal():
                continue
            num = int(token)
            if 0 <= num < slate_size and num not in seen:
                ranked.append(num)
                seen.add(num)

    if len(ranked) < max(1, slate_size // 3):
        return fallback

    # Missing indices go to the end so the model's own ordering is not disturbed.
    return ranked + [i for i in range(slate_size) if i not in seen]

def create_training_prompt(
    query_id: str,
    slate: List[Dict],
    include_relevance_labels: bool = True,
) -> str:
    """Build the GRPO training prompt for one query and its candidate slate.

    The instructions are in English. The query text, document text and metadata
    are inserted as-is (Hungarian, from the dataset). Each document is shown with:
    - its slate index;
    - its metadata;
    - its BM25/FAISS scores;
    - the first 800 characters of its chunk text.
    The prompt ends with a strict single-line ``RANKING:`` output format.

    NOTE(review): with ``include_relevance_labels=True`` (the default, kept for
    backward compatibility) each document's gold relevance grade is printed into
    the prompt. The reward is computed from those same grades, so the policy can
    learn to simply sort by the printed label. Such labels are presumably not
    available for unseen queries. Pass ``include_relevance_labels=False`` to build
    a prompt without the labels and without the hint that refers to them.

    Args:
        query_id: Query text. The reward function extracts it again from the
            ``## SEARCH QUERY`` section to look up the slate, so keep that
            section's layout (heading, newline, quoted query) unchanged.
        slate: Candidate documents in baseline order. Recognised (optional) keys:
            ``text``, ``bm25_score``, ``faiss_score``, ``relevance``, ``court``,
            ``domain``, ``year``, ``doc_id``.
        include_relevance_labels: Whether to show the relevance grades.

    Returns:
        The complete prompt string; it ends with the line ``RANKING:``.
    """
    separator_line = "\u2500" * 20
    relevance_hint = (
        "- The relevance labels (0=not relevant, 1=relevant, 2=highly relevant)"
        " - these can guide you\n"
        if include_relevance_labels
        else ""
    )

    parts = [f'''# Legal Document Relevance Assessment and Ranking

## TASK
You are helping to rank Hungarian court decision documents by their relevance to a search query.

## SEARCH QUERY
"{query_id}"

## WHAT TO CONSIDER

When ranking documents, think about:
- How well the document topic matches the query
- Whether the content directly addresses what the query is asking
- The legal domain and court type (may or may not be relevant)
{relevance_hint}- BM25 and FAISS scores show some similarity signals, but content is most important

## RANKING INSTRUCTIONS

You need to rank the {len(slate)} documents by their relevance to the search query.

''']

    for idx, doc in enumerate(slate):
        # The `or` fallbacks also cover keys that are present but None, which
        # would otherwise break the slice / numeric format specs below.
        chunk_text = (doc.get('text') or '')[:800]
        bm25 = doc.get('bm25_score') or 0
        faiss = doc.get('faiss_score') or 0
        relevance = doc.get('relevance', 0)
        court = doc.get('court', 'N/A')
        domain = doc.get('domain', 'N/A')
        year = doc.get('year', 'N/A')
        doc_id = doc.get('doc_id', 'N/A')

        if include_relevance_labels:
            if relevance == 0:
                rel_label = "NOT RELEVANT"
            elif relevance == 1:
                rel_label = "RELEVANT"
            else:
                rel_label = "HIGHLY RELEVANT"
            relevance_line = f"Relevance: {rel_label} (Grade {relevance}/2)\n"
        else:
            relevance_line = ""

        parts.append(f'''{separator_line}
DOCUMENT [{idx}]

{relevance_line}ID: {doc_id} | Court: {court} | Domain: {domain} | Year: {year}
Scores: BM25={bm25:.2f}, FAISS={faiss:.3f}

Content:
{chunk_text}

''')

    # Strict output instructions. The arrow is a real U+2192; the previous text
    # sent a mis-encoded "‚Üí" to the model.
    # NOTE(review): the example lists only 10 indices although the instructions
    # ask for every index exactly once.
    parts.append(f'''{separator_line}


### OUTPUT FORMAT (STRICT):
- Respond with EXACTLY one line in this format \u2192 RANKING: <comma-separated indices>
- Use document indices 0 to {len(slate)-1}, each exactly once, without extra spaces
- Do NOT include headings, bullet points, Markdown, or explanations before or after the line

Example:
RANKING: 3,0,5,1,4,2,8,6,7,9

RANKING:
''')

    return "".join(parts)

# Summary of the prompt template defined above. The bullets describe what
# create_training_prompt actually emits (it has no step-by-step guidance or
# do/don't list, so those bullets were replaced).
print("‚úÖ Seg√©df√ºggv√©nyek defini√°lva")
print("  üìù Enhanced learning-to-rank prompt template:")
print("     - English instructions with Hungarian data")
print("     - Explicit ranking criteria")
print("     - Per-document relevance labels and BM25/FAISS scores")
print("     - Strict single-line RANKING output format")
print("     - Structured document presentation")



In [None]:
# Load the training slates (JSON Lines: one query with its candidate slate per line).
print(f"üìÇ Slate adatok bet√∂lt√©se: {SLATE_FILE}")
df_slates = pd.read_json(SLATE_FILE, lines=True, encoding='utf-8')
slates_data = df_slates.to_dict('records')
if not slates_data:
    # Fail with a clear message instead of an IndexError on slates_data[0] below.
    raise ValueError(f"No slates found in {SLATE_FILE}")
for slate in slates_data:
    # Keep only the first SLATE_SIZE candidates (configured constant, not a literal).
    slate["slate"] = slate["slate"][:SLATE_SIZE]
print(f"‚úÖ Bet√∂ltve: {len(slates_data)} slate")

sample = slates_data[0]
print(f"\nüìã Minta slate strukt√∫ra:")
print(f"  Query ID: {sample['query_id'][:50]}...")
print(f"  Slate elemek: {len(sample['slate'])}")


## 🎯 Multi-Metrika Reward Setup

A standard nDCG@10 optimalizáció helyett **konvex kombinációt** alkalmazunk:
- **nDCG@10** (60%): Rangsor minőség
- **MRR@5** (30%): Első releváns találat pozíciója (üzleti prioritás!)
- **Recall@20** (10%): Completeness

Ez biztosítja, hogy a model a **valódi üzleti értéket** maximalizálja.

In [None]:
# Multi-metric reward helper functions (defined inline; nDCG relies on sklearn's ndcg_score)

from collections import Counter
from typing import Optional, Tuple

def compute_mrr(relevances: List[int], k: Optional[int] = None) -> float:
    """Reciprocal rank of the first relevant item.

    MRR@k = 1 / (rank of the first item with relevance > 0), or 0 when no such
    item occurs within the inspected prefix.

    Args:
        relevances: Relevance grades in ranked order (rank 1 first).
        k: Optional cut-off; when given, only the first ``k`` ranks are inspected.

    Returns:
        The reciprocal rank as a float, or 0.0 if nothing relevant is found.
    """
    window = relevances if k is None else relevances[:k]
    first_hit = next(
        (rank for rank, grade in enumerate(window, start=1) if grade > 0),
        None,
    )
    return 0.0 if first_hit is None else 1.0 / first_hit


def compute_recall(relevances: List[int], k: int, total_relevant: Optional[int] = None) -> float:
    """Fraction of relevant items that appear in the top ``k`` ranks.

    Recall@k = |relevant items in top-k| / |all relevant items|

    Args:
        relevances: Relevance grades in ranked order.
        k: Rank cut-off.
        total_relevant: Denominator override. When None, it is the number of
            items with relevance > 0 in ``relevances`` itself.

    Returns:
        Recall@k as a float; 0.0 for an empty ranking or a zero denominator.
    """
    if not relevances:
        return 0.0

    hits = len([grade for grade in relevances[:k] if grade > 0])
    denominator = (
        total_relevant
        if total_relevant is not None
        else len([grade for grade in relevances if grade > 0])
    )
    return hits / denominator if denominator != 0 else 0.0


def compute_ndcg_metric(relevances: List[int], k: int) -> float:
    """nDCG@k of an already-ranked relevance list, computed with sklearn.

    The FULL ranked list is passed to ``ndcg_score`` and the cut-off is applied
    through its ``k`` argument. sklearn then builds the ideal DCG@k from all
    relevant items in the list, so a ranking that pushes relevant items below
    rank ``k`` is penalised. Truncating the list to ``k`` before scoring would
    build the ideal from the top-k items only. For example, one relevant item at
    rank 1 would then give nDCG@10 = 1.0 even if four more relevant items sat at
    ranks 11-15.

    Args:
        relevances: Relevance grades in ranked order [rel_1, rel_2, ...].
        k: Rank cut-off.

    Returns:
        nDCG@k in [0, 1]. Returns 0.0 in any of these cases:
        - empty input;
        - ``k <= 0``;
        - input that sklearn rejects with ``ValueError`` (recent versions reject
          a single document).
    """
    if not relevances or k <= 0:
        return 0.0

    # Strictly decreasing scores encode the given order without ties:
    # [n, n-1, ..., 1] for a list of length n.
    scores = list(range(len(relevances), 0, -1))

    try:
        return float(ndcg_score(y_true=[relevances], y_score=[scores], k=k))
    except ValueError:
        return 0.0


def _finite_or_zero(value) -> float:
    """Coerce ``value`` to a finite float; None, NaN, inf and non-numbers become 0.0.

    Baseline metrics are typically built with ``row.to_dict()`` from a CSV, so a
    missing cell arrives as NaN. A single NaN would otherwise flow through the
    sigmoid and ``np.clip`` and turn the whole reward into NaN.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return 0.0
    return number if np.isfinite(number) else 0.0


def compute_multi_metric_reward(
    predicted_indices: List[int],
    slate: List[Dict],
    baseline_metrics: Dict[str, float],
    weights: Optional[Dict[str, float]] = None,
    clip_range: Tuple[float, float] = (-1.0, 1.0),
    use_sigmoid: bool = True,
    mrr_tiebreak_bonus: float = 0.02,
    diversity_penalty_weight: float = 0.0
) -> Tuple[float, Dict[str, float]]:
    """Multi-metric reward of a predicted ranking relative to the baseline.

    The raw reward is
    ``w_ndcg * f(dNDCG@10) + w_mrr * f(dMRR@5) + w_recall * f(dRecall@20)``.
    Here ``d`` is policy minus baseline, and ``f`` is either
    ``2 / (1 + exp(-5x)) - 1`` (range (-1, 1)) when ``use_sigmoid`` is set, or
    the identity otherwise. Then:

    - MRR tie-break: ``+mrr_tiebreak_bonus`` when |dNDCG@10| < 0.01 and
      |dRecall@20| < 0.01, but dMRR@5 > 0.05.
    - Diversity penalty (only if ``diversity_penalty_weight > 0``): the summed
      Shannon entropy of court and domain labels in the top-10 is normalised by
      ``2 * log(10)``. If that is below 0.5, ``weight * (0.5 - score)`` is
      subtracted.
    - The result is clipped to ``clip_range``.

    Args:
        predicted_indices: Slate indices in predicted order. Indices outside
            ``[0, len(slate))`` are ignored.
        slate: Slate items; ``relevance`` (default 0), ``court`` and ``domain``
            are read.
        baseline_metrics: Baseline values under ``"ndcg@10"``, ``"mrr@5"`` and
            ``"recall@20"``. Missing or non-finite values count as 0.0.
        weights: Metric weights under ``"ndcg10"``, ``"mrr5"`` and ``"recall20"``.
            Defaults to 0.6 / 0.3 / 0.1.
        clip_range: ``(min, max)`` for the final reward.
        use_sigmoid: Apply the sigmoid transform to the deltas.
        mrr_tiebreak_bonus: Bonus for an MRR-only improvement.
        diversity_penalty_weight: Weight of the diversity penalty (0 disables it).

    Returns:
        ``(total_reward, components)``. ``components`` holds the clipped and raw
        reward, the weighted components, the deltas, and the policy/baseline
        metrics.
    """
    if weights is None:
        weights = {"ndcg10": 0.60, "mrr5": 0.30, "recall20": 0.10}

    # Drop indices that do not address a slate item; a negative index would
    # otherwise silently wrap around to the end of the slate.
    ranked = [i for i in predicted_indices if 0 <= i < len(slate)]
    # Same default as the total_relevant count below.
    relevances = [slate[i].get("relevance", 0) for i in ranked]

    # Policy metrics.
    policy_ndcg10 = compute_ndcg_metric(relevances, k=10)
    policy_mrr5 = compute_mrr(relevances, k=5)

    # Recall@20 denominator: all relevant items in the fixed candidate slate.
    # NOTE(review): if the ranking is a full permutation of a slate with at most
    # 20 items (parse_model_ranking always returns one), this recall is 1.0
    # whenever the slate has a relevant item. The recall term then cannot tell
    # completions of the same query apart.
    total_relevant = sum(1 for doc in slate if doc.get("relevance", 0) > 0)
    policy_recall20 = compute_recall(relevances, k=20, total_relevant=total_relevant)

    # Baseline metrics (NaN-safe).
    baseline_ndcg10 = _finite_or_zero(baseline_metrics.get("ndcg@10", 0.0))
    baseline_mrr5 = _finite_or_zero(baseline_metrics.get("mrr@5", 0.0))
    baseline_recall20 = _finite_or_zero(baseline_metrics.get("recall@20", 0.0))

    # Deltas against the baseline.
    delta_ndcg10 = policy_ndcg10 - baseline_ndcg10
    delta_mrr5 = policy_mrr5 - baseline_mrr5
    delta_recall20 = policy_recall20 - baseline_recall20

    if use_sigmoid:
        # 2*sigmoid(k*x) - 1 maps any delta into (-1, 1); k sets the sensitivity.
        def sigmoid_transform(delta, k=5.0):
            return 2.0 / (1.0 + np.exp(-k * delta)) - 1.0

        component_ndcg = weights["ndcg10"] * sigmoid_transform(delta_ndcg10)
        component_mrr = weights["mrr5"] * sigmoid_transform(delta_mrr5)
        component_recall = weights["recall20"] * sigmoid_transform(delta_recall20)
    else:
        # Linear deltas.
        component_ndcg = weights["ndcg10"] * delta_ndcg10
        component_mrr = weights["mrr5"] * delta_mrr5
        component_recall = weights["recall20"] * delta_recall20

    reward_raw = component_ndcg + component_mrr + component_recall

    # Tie-break bonus: nDCG/Recall essentially unchanged but MRR clearly improved.
    if abs(delta_ndcg10) < 0.01 and abs(delta_recall20) < 0.01 and delta_mrr5 > 0.05:
        reward_raw += mrr_tiebreak_bonus

    # Diversity penalty for a homogeneous top-10 (same court/domain dominating).
    if diversity_penalty_weight > 0:
        top10_indices = ranked[:10]
        top10_courts = [slate[i].get("court", "N/A") for i in top10_indices]
        top10_domains = [slate[i].get("domain", "N/A") for i in top10_indices]

        # Shannon entropy: high = diverse, low = homogeneous.
        def shannon_entropy(items):
            counts = Counter(items)
            total = len(items)
            if total == 0:
                return 0.0
            probs = [c / total for c in counts.values()]
            return -sum(p * np.log(p + 1e-9) for p in probs)

        court_entropy = shannon_entropy(top10_courts)
        domain_entropy = shannon_entropy(top10_domains)

        # Normalise by the maximum entropy of 10 distinct labels, log(10) ~ 2.3.
        max_entropy = np.log(10)
        diversity_score = (court_entropy + domain_entropy) / (2 * max_entropy)

        if diversity_score < 0.5:
            penalty = diversity_penalty_weight * (0.5 - diversity_score)
            reward_raw -= penalty

    reward_clipped = np.clip(reward_raw, clip_range[0], clip_range[1])

    # Detailed components for logging.
    components = {
        "reward_total": float(reward_clipped),
        "reward_raw": float(reward_raw),
        "component_ndcg10": float(component_ndcg),
        "component_mrr5": float(component_mrr),
        "component_recall20": float(component_recall),
        "delta_ndcg10": float(delta_ndcg10),
        "delta_mrr5": float(delta_mrr5),
        "delta_recall20": float(delta_recall20),
        "policy_ndcg10": float(policy_ndcg10),
        "policy_mrr5": float(policy_mrr5),
        "policy_recall20": float(policy_recall20),
        "baseline_ndcg10": float(baseline_ndcg10),
        "baseline_mrr5": float(baseline_mrr5),
        "baseline_recall20": float(baseline_recall20),
    }

    return float(reward_clipped), components

print("‚úÖ Multi-Metrika Reward Helper Functions defini√°lva (inline)")
print(f"   üéØ Multi-metric reward: sigmoid-based, 0.6/0.3/0.1 weights")
# The helpers depend on numpy and sklearn.metrics.ndcg_score, so the summary must
# not claim that no external imports are needed.
print("   Helpers defined inline (nDCG via sklearn.metrics.ndcg_score)")

In [None]:
# Load the pre-computed per-query baseline metrics (from baseline_evaluation.ipynb).
BASELINE_METRICS_PATH = BASE_PATH / "baseline_query_metrics.csv"

if not BASELINE_METRICS_PATH.exists():
    # In a top-to-bottom run the configuration cell has already raised for a
    # missing file; this branch keeps the cell usable on its own.
    print(f"‚ö†Ô∏è  FIGYELEM: Baseline metrics f√°jl nem tal√°lhat√≥: {BASELINE_METRICS_PATH}")
    print(f"   El≈ëbb futtasd le a baseline_evaluation.ipynb notebookot!")
    print(f"   Fallback: Baseline metrik√°k inicializ√°l√°sa 0-ra (NEM AJ√ÅNLOTT)")
    baseline_metrics_dict = {}
else:
    # Read query_id as text. The reward function looks metrics up by the query
    # string parsed from the prompt, so pandas must not turn a numeric-looking
    # query into an int/float key.
    baseline_df = pd.read_csv(BASELINE_METRICS_PATH, encoding='utf-8', dtype={"query_id": str})
    baseline_metrics_dict = {
        row["query_id"]: row.to_dict()
        for _, row in baseline_df.iterrows()
    }
    print(f"‚úÖ Baseline metrik√°k bet√∂ltve: {len(baseline_metrics_dict)} query")

    # Statistics.
    if baseline_metrics_dict:
        sample_query = next(iter(baseline_metrics_dict))
        sample_metrics = baseline_metrics_dict[sample_query]
        print(f"\nüìä El√©rhet≈ë metrik√°k (p√©lda query):")
        for key in sorted(sample_metrics.keys()):
            # numpy scalars (e.g. np.int64) are not Python ints; accept them too.
            if key != "query_id" and isinstance(
                sample_metrics[key], (int, float, np.integer, np.floating)
            ):
                print(f"  ‚Ä¢ {key}: {sample_metrics[key]:.4f}")

        def _mean_metric(name):
            """Mean of one metric over all queries; empty/NaN cells are skipped."""
            column = pd.to_numeric(
                pd.Series([m.get(name, 0.0) for m in baseline_metrics_dict.values()]),
                errors="coerce",
            )
            return float(column.mean())

        # Aggregated baseline performance.
        avg_ndcg10 = _mean_metric("ndcg@10")
        avg_mrr5 = _mean_metric("mrr@5")
        avg_recall20 = _mean_metric("recall@20")

        print(f"\nüéØ Baseline teljes√≠tm√©ny (√°tlag):")
        print(f"  ‚Ä¢ nDCG@10:   {avg_ndcg10:.4f}")
        print(f"  ‚Ä¢ MRR@5:     {avg_mrr5:.4f}")
        print(f"  ‚Ä¢ Recall@20: {avg_recall20:.4f}")

In [None]:
# Export the eval split for the comparative evaluation.
# CRITICAL: the eval-split query_ids are written in a deterministic, reproducible
# format so that the baseline and the GRPO model are evaluated on the same queries.
# NOTE(review): `eval_dataset` and `train_dataset` are not created in any cell
# above this one; make sure the cell that builds them runs first.

print("\n" + "="*80)
print("üì§ Eval Split Export√°l√°sa")
print("="*80)

# Recover the query_ids from the eval prompts.
eval_query_ids = []

for example in eval_dataset:
    prompt = example["prompt"]
    # Same regex parsing as in the reward function - keep the two in sync.
    # The first pattern matches the "## SEARCH QUERY" heading of the training
    # prompt and captures from after the opening quote up to the next quote or
    # newline, so a query that itself contains '"' is truncated here.
    match = re.search(r'(?:SEARCH QUERY|Query)[:\s"]*([^"\n]+)', prompt)
    query_id = None
    if match:
        query_id = match.group(1)
    else:
        match2 = re.search(r'##\s*SEARCH\s*QUERY\s*\n\s*"([^"]+)"', prompt)
        if match2:
            query_id = match2.group(1)
    
    # Keep each query once (linear list scan; fine for ~100 queries).
    if query_id and query_id not in eval_query_ids:
        eval_query_ids.append(query_id)

# Sort for deterministic output.
eval_query_ids.sort()

# NOTE(review): the split ratio mixes unique eval query_ids with train *rows*;
# it is a query-level ratio only if train_dataset has one row per query.
print(f"\nüìä Eval split statisztik√°k:")
print(f"  Eval query-k sz√°ma: {len(eval_query_ids)}")
print(f"  Train query-k sz√°ma: {len(train_dataset)}")
print(f"  Split ar√°ny: {len(eval_query_ids) / (len(eval_query_ids) + len(train_dataset)):.1%}")

# Write the eval split as JSON.
# NOTE(review): "shuffle": True is hard-coded rather than read from the split
# call - confirm it matches how the split was made. created_at is a naive
# (timezone-less) local timestamp.
eval_split_data = {
    "query_ids": eval_query_ids,
    "num_queries": len(eval_query_ids),
    "split_params": {
        "train_split": TRAIN_SPLIT,
        "seed": SEED,
        "shuffle": True
    },
    "created_at": pd.Timestamp.now().isoformat()
}

eval_split_path = OUTPUT_DIR / "eval_split.json"
with open(eval_split_path, 'w', encoding='utf-8') as f:
    json.dump(eval_split_data, f, ensure_ascii=False, indent=2)

print(f"\n‚úÖ Eval split export√°lva: {eval_split_path}")
print(f"   Ez a f√°jl haszn√°lva lesz a baseline √©s GRPO √∂sszehasonl√≠t√≥ notebookban")
print(f"   ugyanazon eval halmazon val√≥ ki√©rt√©kel√©shez.")

# Optional: also export as CSV (simpler to load).
eval_split_csv = OUTPUT_DIR / "eval_split.csv"
eval_df = pd.DataFrame({"query_id": eval_query_ids})
eval_df.to_csv(eval_split_csv, index=False, encoding='utf-8')
print(f"‚úÖ Eval split CSV: {eval_split_csv}")

print("\n" + "="*80)


In [None]:
# Model loading - memory efficient: 4-bit base model with vLLM rollouts, plus LoRA adapters.
print(f"üîÑ Model bet√∂lt√©se Unsloth-tal: {MODEL_NAME}")
print(f"  ‚öôÔ∏è  GPU Memory Utilization: {GPU_MEMORY_UTILIZATION}")
print(f"  ‚öôÔ∏è  Max LoRA Rank: {LORA_RANK}")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=MAX_SEQ_LENGTH,
    load_in_4bit=True,
    dtype=None,  # None = library default dtype selection
    fast_inference=True,  # vLLM backend for GRPO rollout generation
    max_lora_rank=LORA_RANK,
    gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    # NOTE(review): this is one token shorter than MAX_SEQ_LENGTH, although an
    # earlier comment claimed the two must match - confirm which is intended.
    max_model_len=MAX_SEQ_LENGTH - 1,
)

# LoRA adapters on every attention and MLP projection; use_rslora=True selects
# rank-stabilised LoRA scaling.
model = FastLanguageModel.get_peft_model(
    model,
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing=USE_GRADIENT_CHECKPOINTING,
    use_rslora=True,
    random_state=SEED,
)

# Only the padding side is set. pad_token is deliberately left untouched: per the
# original note, reassigning it may cause tensor-size mismatches with the GRPO
# trainer's internal vLLM logic.
# NOTE(review): TRL's GRPOTrainer may apply its own padding side when tokenizing
# prompts - confirm this setting is the one actually used.
tokenizer.padding_side = "right"

print("‚úÖ Model √©s tokenizer bet√∂ltve (Unsloth + vLLM + RSLoRA)")
print(f"  ‚úì fast_inference=True: vLLM enged√©lyezve GRPO rollouts-hoz (2-3x gyorsabb)")
print(f"  ‚úì Padding side: {tokenizer.padding_side}")
# Compare against None: a valid pad token id of 0 must not be reported as missing.
_pad_token_id = getattr(tokenizer, "pad_token_id", None)
print(f"  ‚úì Pad token ID: {_pad_token_id if _pad_token_id is not None else 'N/A (auto-managed)'}")
print(f"  ‚úì Note: GRPO trainer automatikusan kezeli a tokenizer be√°ll√≠t√°sokat vLLM-mel")

In [None]:
# GPU memory report after model loading.
# torch.cuda.mem_get_info() returns the driver's free/total device memory.
# "Free" therefore accounts for memory cached by PyTorch's allocator and memory
# allocated outside it. total - memory_allocated() would overstate free memory.
if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated(0) / 1e9
    reserved = torch.cuda.memory_reserved(0) / 1e9
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    free = free_bytes / 1e9
    total = total_bytes / 1e9

    print(f"\nüìä GPU Mem√≥ria Model Bet√∂lt√©s Ut√°n:")
    print(f"  Allocated: {allocated:.2f} GB")
    print(f"  Reserved: {reserved:.2f} GB")
    print(f"  Free: {free:.2f} GB")
    print(f"  Free %: {(free/total)*100:.1f}%")

    if free < 10:
        print(f"  ‚ö†Ô∏è  FIGYELEM: Kev√©s szabad mem√≥ria ({free:.2f} GB)!")
        print(f"  üí° Opci√≥: Tov√°bbi cs√∂kkent√©s lehet sz√ºks√©ges")


In [None]:
# ====== Multi-metric reward configuration ======
# The reward is a weighted (convex) combination of three per-query deltas
# against the baseline ranking. When REWARD_USE_SIGMOID is set, each delta goes
# through a scaled sigmoid first.
#   - nDCG@10   (weight 0.60): overall ranking quality
#   - MRR@5     (weight 0.30): position of the first relevant hit
#   - Recall@20 (weight 0.10): coverage of the relevant hits
# The sigmoid is bounded, so it does saturate for large deltas; its purpose is
# to keep each component in a fixed range.
# Extras: a small bonus when only MRR improves, and a penalty for a
# court/domain-homogeneous top-10.

REWARD_WEIGHTS = dict(
    ndcg10=0.60,
    mrr5=0.30,
    recall20=0.10,
)
# NOTE(review): the configuration cell also defines REWARD_CLIP_MIN/MAX
# (-1.0 / 2.0); confirm which clip range is actually used.
REWARD_CLIP_RANGE = (-1.0, 1.0)
REWARD_USE_SIGMOID = True
REWARD_MRR_TIEBREAK_BONUS = 0.02   # bonus when only MRR@5 improves
REWARD_DIVERSITY_PENALTY = 0.05    # weight of the top-10 diversity penalty

# Global list for reward-component logging (declared `global` in the reward function).
reward_components_log = []

import math
import re

# Patterns that may carry the query identifier, tried in this order. The first
# is the original loose pattern; the other two are the stricter forms also used
# in this notebook (a "## SEARCH QUERY" heading followed by a quoted line, and
# `QUERY: "..."` as used by the evaluation code).
_QUERY_ID_PATTERNS = (
    re.compile(r'(?:SEARCH QUERY|Query)[:\s"]*([^"\n]+)'),
    re.compile(r'##\s*SEARCH\s*QUERY\s*\n\s*"([^"]+)"'),
    re.compile(r'QUERY:\s*"([^"]+)"'),
)

# Zero baseline used when a query has no precomputed baseline metrics
# (not ideal, but safe).
_ZERO_BASELINE_METRICS = {
    "ndcg@10": 0.0,
    "mrr@5": 0.0,
    "mrr@10": 0.0,  # kept for compatibility
    "recall@20": 0.0,
}


def trl_message_text(value):
    """Return plain text for a TRL prompt or completion.

    TRL passes plain strings for "standard" datasets and lists of chat messages
    (dicts with a "content" key) for conversational ones. Message contents are
    joined with newlines; any other value is returned unchanged.
    """
    if isinstance(value, list) and all(isinstance(m, dict) for m in value):
        return "\n".join(str(m.get("content", "")) for m in value)
    return value


def find_prompt_query_id(prompt, known_ids):
    """Return the first query id found in `prompt` that is a key of `known_ids`.

    Every match of every pattern in _QUERY_ID_PATTERNS is tried, both raw and
    whitespace-stripped, so an incidental "Query" earlier in the instructions
    no longer hides the real identifier. Returns None if no candidate is known.
    """
    for pattern in _QUERY_ID_PATTERNS:
        for match in pattern.finditer(prompt):
            raw = match.group(1)
            for candidate in (raw, raw.strip()):
                if candidate and candidate in known_ids:
                    return candidate
    return None


def multi_metric_reward_function(completions, prompts, **kwargs):
    """GRPO reward: weighted multi-metric improvement over the baseline ranking.

    For every (completion, prompt) pair the query id is resolved, the model's
    ranking is parsed with parse_model_ranking(), and
    compute_multi_metric_reward() scores it against the query's baseline
    metrics using the REWARD_* settings (weights nDCG@10 / MRR@5 / Recall@20,
    sigmoid transform, MRR tie-break bonus, diversity penalty, clipping).

    Args:
        completions: model outputs; plain strings or chat-message lists.
        prompts: the matching prompts; plain strings or chat-message lists.
        **kwargs: extra dataset columns forwarded by TRL. If a "query_id" column
            is present and its value is a key of slate_lookup, it is used instead
            of regex extraction (assumption: such a column holds slate_lookup
            keys — values that are not keys are ignored).

    Returns:
        List[float]: one reward per completion. Items that cannot be scored
        (unknown query, parse/scoring exception, non-finite reward) get -0.1.

    Side effect:
        Extends the global reward_components_log with one dict per completion.
    """
    global reward_components_log
    rewards = []
    batch_components = []

    column_ids = kwargs.get("query_id")
    if not isinstance(column_ids, (list, tuple)) or len(column_ids) != len(completions):
        column_ids = [None] * len(completions)

    for completion, prompt, column_id in zip(completions, prompts, column_ids):
        try:
            # 1. Query id: usable dataset column first, then prompt parsing.
            query_id = None
            if isinstance(column_id, (str, int)) and column_id in slate_lookup:
                query_id = column_id
            if query_id is None:
                query_id = find_prompt_query_id(trl_message_text(prompt), slate_lookup)
            if query_id is None:
                # Unresolvable query id -> mild penalty.
                rewards.append(-0.1)
                batch_components.append({"error": "query_id_not_found"})
                continue

            # 2. Slate lookup, 3. parse the model's ranking.
            slate = slate_lookup[query_id]
            predicted_indices = parse_model_ranking(trl_message_text(completion), len(slate))

            # 4. Baseline metrics (zero baseline when missing).
            if query_id in baseline_metrics_dict:
                baseline_metrics = baseline_metrics_dict[query_id]
            else:
                baseline_metrics = dict(_ZERO_BASELINE_METRICS)

            # 5. Multi-metric reward (helper from src.reward_metrics).
            reward, components = compute_multi_metric_reward(
                predicted_indices=predicted_indices,
                slate=slate,
                baseline_metrics=baseline_metrics,
                weights=REWARD_WEIGHTS,
                clip_range=REWARD_CLIP_RANGE,
                use_sigmoid=REWARD_USE_SIGMOID,
                mrr_tiebreak_bonus=REWARD_MRR_TIEBREAK_BONUS,
                diversity_penalty_weight=REWARD_DIVERSITY_PENALTY,
            )
            reward = float(reward)
            if not math.isfinite(reward):
                # A NaN/inf reward would poison the group-normalised advantages.
                raise ValueError(f"non-finite reward: {reward}")

            rewards.append(reward)
            batch_components.append(components)

        except Exception as e:
            # Best effort: any failure becomes a mild penalty and is logged.
            rewards.append(-0.1)
            batch_components.append({"error": str(e)[:100]})

    # Keep the components for later logging/inspection.
    reward_components_log.extend(batch_components)

    return rewards

# Human-readable summary of the reward configuration, printed as one block.
_weight_rows = (
    ("nDCG@10:   ", "ndcg10", "(rangsor min≈ës√©g)"),
    ("MRR@5:     ", "mrr5", "(els≈ë relev√°ns tal√°lat)"),
    ("Recall@20: ", "recall20", "(completeness)"),
)
_summary_lines = [
    "‚úÖ Multi-Metrika GRPO Reward function defini√°lva (Sigmoid-Based v2.0)",
    "   üìä Reward s√∫lyok:",
]
_summary_lines += [
    f"      ‚Ä¢ {label}{REWARD_WEIGHTS[key]:.0%} {note}"
    for label, key, note in _weight_rows
]
_summary_lines += [
    f"   üéØ Sigmoid transzform√°ci√≥: {REWARD_USE_SIGMOID}",
    f"   üéÅ MRR tie-break bonus: +{REWARD_MRR_TIEBREAK_BONUS}",
    f"   üåà Diverzit√°s b√ºntet√©s: {REWARD_DIVERSITY_PENALTY}",
    f"   üìè Reward clipping: {REWARD_CLIP_RANGE}",
    "   üìà Komponens tracking: Enabled",
]
print("\n".join(_summary_lines))
del _weight_rows, _summary_lines

In [None]:
# GRPO trainer configuration — memory-efficient settings.
# NOTE(review): the notebook header advertises loss "dapo", but this config uses
# loss_type="dr_grpo" (kept unchanged); confirm which one is intended.
# NOTE(review): in HF TrainingArguments, eval_steps only takes effect together
# with eval_strategy="steps" (default "no"); no eval_strategy is set here.
_grpo_kwargs = dict(
    # Output and schedule
    output_dir=str(OUTPUT_DIR),
    max_steps=MAX_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    optim=OPTIMIZER_NAME,
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    seed=SEED,
    # Batching: generation_batch_size must be divisible by both num_generations
    # and per_device_train_batch_size (checked in the printout below).
    num_generations=NUM_GENERATIONS,
    generation_batch_size=GENERATION_BATCH_SIZE,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    dataloader_num_workers=1,  # reduced from 2 workers to 1
    # GRPO objective
    loss_type="dr_grpo",
    scale_rewards="batch",
    importance_sampling_level="sequence",
    epsilon=0.1,  # clip range lowered 0.2 -> 0.1: smaller, more stable policy updates
    mask_truncated_completions=True,
    # Explicit completion budget; the library default (256) is too short for a
    # full ranking. NOTE(review): slates here have 20 chunks, the old comment said 30.
    max_completion_length=8192,
    # Rollouts via vLLM
    use_vllm=True,
    vllm_importance_sampling_correction=False,  # vLLM token logprobs do not match the base model's
    # Precision and memory
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False, "use_unsloth": True},
    max_grad_norm=1.0,
    # Logging and checkpoints
    logging_steps=LOGGING_STEPS,
    logging_first_step=True,
    eval_steps=EVAL_STEPS,
    save_steps=SAVE_STEPS,
)
grpo_config = GRPOConfig(**_grpo_kwargs)

# Unsloth-specific setting, assigned as an attribute after construction.
grpo_config.unsloth_num_chunks = 1

_gbs, _ngen, _pdbs = GENERATION_BATCH_SIZE, NUM_GENERATIONS, PER_DEVICE_BATCH_SIZE
print("\n‚úÖ GRPO Trainer konfigur√°ci√≥ k√©sz (Memory Efficient + Curriculum Learning)")
print(f"  ‚úì generation_batch_size={_gbs}")
print(f"    - oszthat√≥ num_generations={_ngen}-szel: {_gbs}√∑{_ngen}={_gbs // _ngen}")
print(f"    - oszthat√≥ per_device_batch={_pdbs}-gyel: {_gbs}√∑{_pdbs}={_gbs // _pdbs}")
print(f"  ‚úì Effective batch size: {_pdbs * GRADIENT_ACCUMULATION_STEPS}")
# NOTE(review): the memory summary below is static text, not read from the config.
print("  ‚úì Mem√≥ria optimaliz√°ci√≥: 50% GPU util, 1 worker, rank 32")
if CURRICULUM_ENABLED:
    print(f"  ‚úì Curriculum Learning: {CURRICULUM_STRATEGY} strat√©gia aktiv√°lva")

## 🧪 Multi-Metrika Reward Validation

Most teszteljük a frissített multi-metrika reward függvényt:
- Sigmoid transzformáció működése
- Komponensek (nDCG/MRR/Recall) számítása
- MRR tie-break bonus
- Diverzitás check
- Baseline metrics betöltés

_Megjegyzés: ebben a verzióban a validációs cella nem szerepel; a következő cella közvetlenül a tanítást indítja._

In [None]:
# Launch GRPO training. On failure, report the error and traceback, then
# re-raise so the notebook run stops instead of continuing without a model.
import traceback

_banner = "=" * 60
print("\nüöÄ GRPO TRAINING IND√çT√ÅSA\n")
print(_banner)

if CURRICULUM_ENABLED:
    print(f"üìö Curriculum Learning akt√≠v: {CURRICULUM_STRATEGY}")
    if CURRICULUM_STRATEGY == "filtered_dataset":
        print("   Training k√∂nnyebb p√©ld√°kkal (easy + medium) - jobb tanul√°si signal")
    print()

try:
    trainer.train()
except Exception as exc:
    print(f"\n‚ùå Training hiba: {exc}")
    traceback.print_exc()
    raise
else:
    print("\n" + _banner)
    print("‚úÖ Training sikeresen befejezve!")

In [None]:
# Robust training-log visualisation (with error checks).
#
# Prefers trainer.state.log_history; if it is missing or empty, falls back to
# the reward trend saved in metrics.json. The cell always finishes normally:
# the previous `raise SystemExit(0)` aborted scripted runs (nbconvert /
# papermill) before the evaluation and artifact-export cells.
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Mean-reward columns, tried in order. Current TRL GRPOTrainer versions log
# "reward" and "rewards/<reward_func_name>/mean"; older ones used
# "rewards/mean"-style names.
_REWARD_MEAN_CANDIDATES = [
    "rewards/reward_function/mean",  # common in TRL GRPO
    "rewards/multi_metric_reward_function/mean",
    "rewards/mean",
    "reward/mean",
    "rewards/reward",
    "reward",
]
# Reward standard-deviation columns (optional), including TRL's "reward_std".
_REWARD_STD_CANDIDATES = [
    "rewards/reward_function/std",
    "rewards/multi_metric_reward_function/std",
    "rewards/std",
    "reward/std",
    "reward_std",
]


def _pick_col(df, candidates):
    """Return the first of `candidates` that is a column of `df`, or None."""
    return next((c for c in candidates if c in df.columns), None)


def _step_axis(logs, series):
    """Return x values for `series`: the logged "step" if present, else the row index.

    log_history has one row per logging event (plus eval/summary rows), so the
    row index is not the optimizer step whenever logging_steps > 1.
    """
    if "step" in logs.columns:
        return logs.loc[series.index, "step"]
    return series.index


def _plot_metric(ax, logs, col, label, color, title, ylabel):
    """Plot the non-NaN values of one log column against the step axis."""
    values = logs[col].dropna()
    ax.plot(_step_axis(logs, values), values, label=label, color=color)
    ax.set_title(title)
    ax.set_xlabel("L√©p√©s")
    ax.set_ylabel(ylabel)
    ax.legend()


def _plot_metrics_json_fallback():
    """Plot only the reward trend stored in metrics.json (no log_history available)."""
    metrics_path = globals().get("METRICS_FILE", Path("data/models/grpo_policy/metrics.json"))
    if not Path(metrics_path).exists():
        print("‚ö†Ô∏è Nincs el√©rhet≈ë trainer.state.log_history √©s a metrics.json f√°jl sem tal√°lhat√≥.")
        return

    with open(metrics_path, 'r', encoding='utf-8') as f:
        metrics = json.load(f)
    rewards_trend = metrics.get("training_rewards", {}).get("trend", [])
    if not rewards_trend:
        print("‚ö†Ô∏è Nincs el√©rhet≈ë trainer.state.log_history √©s a metrics.json nem tartalmaz 'training_rewards.trend'-et.")
        return

    plt.figure(figsize=(8, 4))
    plt.plot(rewards_trend, color="royalblue")
    plt.title("Jutalom trend (metrics.json)")
    plt.xlabel("L√©p√©s")
    plt.ylabel("Jutalom")
    plt.tight_layout()
    plt.show()


def _plot_log_history(logs):
    """Draw up to four panels (reward, loss, KL, reward std) from a log DataFrame."""
    if not isinstance(logs, pd.DataFrame) or logs.empty:
        print("‚ö†Ô∏è A logok DataFrame √ºres.")
        return

    # Keep numeric columns only.
    numeric_cols = [c for c in logs.columns if np.issubdtype(pd.Series(logs[c]).dropna().dtype, np.number)]
    if not numeric_cols:
        print("‚ö†Ô∏è A logokban nincs numerikus oszlop a megjelen√≠t√©shez.")
        return
    logs = logs[numeric_cols].copy()

    reward_col = _pick_col(logs, _REWARD_MEAN_CANDIDATES)
    if reward_col is None:
        print("‚ö†Ô∏è Nem tal√°lhat√≥ reward √°tlag oszlop a logokban. Kihagyom az 1. √°br√°t.")
    reward_std_col = _pick_col(logs, _REWARD_STD_CANDIDATES)
    loss_col = _pick_col(logs, ["loss", "train/loss"])
    kl_col = _pick_col(logs, ["kl", "train/kl", "kl_divergence"])

    fig, axes = plt.subplots(2, 2, figsize=(13, 8))
    axes = list(axes.flatten())
    ax_idx = 0

    # 1) Reward trend (smoothed) with an optional +/- std band.
    if reward_col is not None:
        reward = logs[reward_col].dropna()
        x = _step_axis(logs, reward)
        reward_smooth = reward.rolling(window=10, min_periods=1).mean()
        ax = axes[ax_idx]
        ax.plot(x, reward_smooth, label=f"Jutalom (√°tlag, {reward_col})", color="royalblue")
        if reward_std_col is not None:
            std = logs[reward_std_col].reindex(reward.index)
            ax.fill_between(x, reward_smooth - std, reward_smooth + std, alpha=0.2, color="royalblue", label="¬± sz√≥r√°s")
        ax.set_title("Jutalom trend")
        ax.set_xlabel("L√©p√©s")
        ax.set_ylabel("Jutalom")
        ax.legend()
        ax_idx += 1

    # 2) Loss curve.
    if loss_col is not None and ax_idx < len(axes):
        _plot_metric(axes[ax_idx], logs, loss_col, "Vesztes√©g (loss)", "firebrick", "Tan√≠t√°si vesztes√©g", "Loss")
        ax_idx += 1

    # 3) KL curve.
    if kl_col is not None and ax_idx < len(axes):
        _plot_metric(axes[ax_idx], logs, kl_col, "KL divergencia", "darkorange", "KL divergencia alakul√°sa", "KL")
        ax_idx += 1

    # 4) Reward std on its own panel (if available and space remains).
    if reward_std_col is not None and ax_idx < len(axes):
        _plot_metric(axes[ax_idx], logs, reward_std_col, "Jutalom sz√≥r√°s", "seagreen", "Jutalom sz√≥r√°s", "Sz√≥r√°s")
        ax_idx += 1

    # Hide unused panels.
    for ax in axes[ax_idx:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


# Use the trainer's log_history if a trainer (with state) exists in this session.
log_history = None
try:
    log_history = getattr(trainer.state, "log_history", None)
except (NameError, AttributeError):
    log_history = None

if log_history:
    logs = pd.DataFrame(log_history)
    _plot_log_history(logs)
else:
    _plot_metrics_json_fallback()


⚠️ Nincs elérhető trainer.state.log_history és a metrics.json fájl sem található.


In [None]:
# OPTIONAL: prompt quality check — sample a few completions for one random query.
# Debug aid, intended to be run BEFORE training (per the original note) to see
# what the model generates for the training prompt. Errors are reported, never
# raised, and the model's train/eval mode is restored afterwards in all cases.
import random

print("üî¨ PROMPT QUALITY CHECK - Sample Completions Gener√°l√°sa\n")
print("Ez a cella teszteli, hogy a modell milyen outputokat gener√°l az √∫j prompttal.")
print("FONTOS: Ez optional, csak debug c√©lra.\n")

# Private RNG: picks the same slate as random.seed(42) + random.randint() did,
# without reseeding the global `random` module used by other cells.
test_slate_idx = random.Random(42).randint(0, len(slates_data) - 1)
test_slate_data = slates_data[test_slate_idx]
test_prompt_text = create_training_prompt(test_slate_data["query_id"], test_slate_data["slate"])

print(f"Query: {test_slate_data['query_id'][:80]}...")
print(f"Slate size: {len(test_slate_data['slate'])}")
print("-" * 80)

_was_training = model.training
try:
    model.eval()  # sampling without dropout

    # Tokenize the prompt once; its length marks where generated tokens start.
    inputs = tokenizer(test_prompt_text, return_tensors="pt", truncation=False, max_length=MAX_SEQ_LENGTH)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]
    slate_len = len(test_slate_data["slate"])

    print("\nüé≤ 3 Sample Completion Gener√°l√°sa...\n")
    print(f"Prompt tokens: {prompt_len}")

    for i in range(3):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1024,  # room for a complete ranking of the slate
                do_sample=True,
                temperature=0.8,
                top_p=0.9,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,  # stop at end-of-sequence
            )

        completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        parsed = parse_model_ranking(completion, slate_len)
        # Valid = every slate index 0..slate_len-1 appears exactly once.
        is_valid = len(parsed) == slate_len and sorted(parsed) == list(range(slate_len))

        print(f"Completion {i+1}:")
        print(f"  Raw output: {completion[:150]}")
        print(f"  Parsed indices: {parsed[:10]}")  # first 10 indices
        print(f"  Valid? {is_valid}")
        print(f"  Teljes ranking: {parsed}")
        print(completion[-200:])

    # Ground-truth relevance of the slate.
    relevances = [doc.get('relevance', 0) for doc in test_slate_data["slate"]]
    print(f"üìä Ground truth relevance distribution:")
    print(f"  Relevant docs (rel>=1): {sum(1 for r in relevances if r >= 1)}/{len(relevances)}")
    print(f"  Highly relevant (rel=2): {sum(1 for r in relevances if r == 2)}/{len(relevances)}")
    print(f"  Baseline order nDCG@10: {calculate_ndcg(list(range(len(relevances))), relevances, k=10):.4f}")

    print("\n‚úÖ Prompt quality check complete!")

except Exception as e:
    print(f"‚ùå Error during quality check: {e}")
    print("Ez nem probl√©ma, a training ett≈ël m√©g futhat.")
finally:
    # Restore whatever mode the model was in before this cell.
    model.train(_was_training)

print("\n" + "="*80)

In [None]:
# Save artifacts: the trained LoRA adapter together with the tokenizer, via
# Unsloth's save_pretrained_merged with save_method="lora".
print("\nüíæ Artifactumok ment√©se...")
_adapter_dir = str(OUTPUT_DIR)
model.save_pretrained_merged(_adapter_dir, tokenizer, save_method="lora")
print(f"  ‚úÖ LoRA adapter: {_adapter_dir}")

## 📊 Modell Összehasonlítás

Az értékelés és a baseline vs GRPO összehasonlítás a **`notebooks/model_comparison.ipynb`** notebookban történik.

Ez biztosítja, hogy a baseline és GRPO modell **ugyanazon eval halmazon** legyen kiértékelve, így a metrikák összevethetők.

_Megjegyzés: az alábbi cella ettől függetlenül egy gyors, notebookon belüli ranx-értékelést is futtat a train és eval halmazon._


In [None]:
# Evaluation (ranx): fusion-order baseline vs. GRPO policy ranking on the same slates.
print("\nüìä √ârt√©kel√©s futtat√°sa (ranx)...")

import math
import re

import torch
from tqdm.auto import tqdm


def _eval_query_id(prompt):
    """Return the query id embedded in an evaluation prompt, or None.

    Tries `QUERY: "..."` first, then a "## SEARCH QUERY" heading followed by a
    quoted line.
    """
    match = re.search(r'QUERY:\s*"([^"]+)"', prompt)
    if match is None:
        match = re.search(r'##\s*SEARCH\s*QUERY\s*\n\s*"([^"]+)"', prompt)
    return match.group(1) if match else None


def _slate_qrels(slate):
    """Map doc_id -> relevance for one slate (highest relevance wins if a doc_id repeats)."""
    qrels = {}
    for doc in slate:
        doc_id = doc.get('doc_id')
        if doc_id is None:
            continue
        relevance = doc.get('relevance', 0)
        qrels[doc_id] = max(relevance, qrels.get(doc_id, relevance))
    return qrels


def _ranking_scores(slate, order):
    """Map doc_id -> ranx score (len(slate) - rank) for the slate indices in `order`.

    If several chunks of the same document occur, the document keeps the score
    of its best-ranked chunk (previously a later, worse-ranked chunk overwrote
    it). Out-of-range indices are ignored.
    """
    size = len(slate)
    scores = {}
    for rank, slate_idx in enumerate(order):
        if 0 <= slate_idx < size:
            doc_id = slate[slate_idx].get("doc_id")
            if doc_id is not None:
                scores.setdefault(doc_id, size - rank)
    return scores


def _is_full_permutation(indices, size):
    """True if `indices` contains every slate index 0..size-1 exactly once."""
    return len(indices) == size and sorted(indices) == list(range(size))


def _collect_eval_entries(dataset_subset, slate_lookup_dict):
    """Build qrels, baseline runs and generation entries — one per unique query."""
    qrels_dict, baseline_run_dict, entries = {}, {}, []
    for example in dataset_subset:
        prompt = example["prompt"]
        query_id = _eval_query_id(prompt)
        # Skip unknown queries and repeats (a repeat would be generated twice and
        # double-count parse successes).
        if not query_id or query_id not in slate_lookup_dict or query_id in qrels_dict:
            continue

        slate = slate_lookup_dict[query_id]
        query_qrels = _slate_qrels(slate)
        if not query_qrels:
            continue
        qrels_dict[query_id] = query_qrels
        # Baseline = the slate's own (fusion) order.
        baseline_run_dict[query_id] = _ranking_scores(slate, range(len(slate)))
        entries.append({"prompt": prompt, "query_id": query_id, "slate": slate})
    return qrels_dict, baseline_run_dict, entries


def _generate_policy_run(entries, dataset_name, batch_size, min_tokens, max_tokens):
    """Greedy-decode a ranking for every entry; return (policy_run_dict, parse_successes)."""
    policy_run_dict = {}
    parse_successes = 0
    total_entries = len(entries)
    device = model.device

    print(f"  {dataset_name}: {total_entries} lek√©rdez√©s feldolgoz√°sa batchelt gener√°l√°ssal...")
    progress_desc = f"{dataset_name} batch feldolgoz√°s" if dataset_name else "√ârt√©kel√©s batch feldolgoz√°s"
    progress_bar = tqdm(total=math.ceil(total_entries / batch_size), desc=progress_desc, unit="batch")

    original_padding_side = getattr(tokenizer, "padding_side", "right")
    was_training = model.training
    if getattr(tokenizer, "pad_token", None) is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    # Left padding: every prompt ends right before its generated tokens.
    tokenizer.padding_side = "left"
    # Disable dropout (e.g. LoRA dropout) so greedy evaluation is deterministic.
    model.eval()
    try:
        with torch.inference_mode():
            for start in range(0, total_entries, batch_size):
                batch_entries = entries[start:start + batch_size]
                encodings = tokenizer(
                    [entry["prompt"] for entry in batch_entries], padding=True, return_tensors="pt"
                ).to(device)
                # generate() returns [padded prompt | new tokens] for each row. With
                # left padding the new tokens start at the padded width for EVERY
                # row; slicing at each row's unpadded length leaked the tail of
                # shorter prompts into their "completion".
                prompt_width = encodings["input_ids"].shape[1]

                # Token budget scales with the largest slate in the batch.
                longest_slate = max(len(entry["slate"]) for entry in batch_entries)
                max_new_tokens = int(max(min_tokens, min(max_tokens, longest_slate * 6)))

                outputs = model.generate(
                    **encodings,
                    max_new_tokens=max_new_tokens,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    do_sample=False,  # greedy; a temperature setting would be ignored
                    num_beams=1,
                )

                for output_ids, entry in zip(outputs, batch_entries):
                    slate = entry["slate"]
                    completion = tokenizer.decode(output_ids[prompt_width:], skip_special_tokens=True)
                    predicted_indices = parse_model_ranking(completion, len(slate))
                    if _is_full_permutation(predicted_indices, len(slate)):
                        parse_successes += 1
                    else:
                        # Unparseable ranking -> fall back to the baseline order.
                        predicted_indices = list(range(len(slate)))
                    policy_run_dict[entry["query_id"]] = _ranking_scores(slate, predicted_indices)

                progress_bar.update(1)
    finally:
        tokenizer.padding_side = original_padding_side
        progress_bar.close()
        if was_training:
            model.train()
    return policy_run_dict, parse_successes


def evaluate_policy_ranx(dataset_subset, slate_lookup_dict, dataset_name="", batch_size=4, min_tokens=64, max_tokens=512):
    """Evaluate the baseline (slate order) and the policy ranking with ranx.

    Args:
        dataset_subset: iterable of examples with a "prompt" string.
        slate_lookup_dict: query_id -> slate (list of dicts with "doc_id" and
            "relevance").
        dataset_name: label used in progress/log output.
        batch_size: prompts per generate() call.
        min_tokens, max_tokens: clamp for max_new_tokens (6 tokens per slate item).

    Returns:
        dict with baseline/policy metric dicts, number of queries, positive
        nDCG@10 improvement count/ratio, parse success count/rate and
        per-query result rows.
    """
    metrics_to_compute = ["map", "mrr"]
    for k in (5, 10, 20):
        metrics_to_compute.extend([f"ndcg@{k}", f"precision@{k}", f"recall@{k}"])

    qrels_dict, baseline_run_dict, entries = _collect_eval_entries(dataset_subset, slate_lookup_dict)
    if not entries:
        print(f"  {dataset_name}: nincs √©rt√©kelhet≈ë lek√©rdez√©s.")
        return {
            "baseline_metrics": {},
            "policy_metrics": {},
            "num_queries": 0,
            "positive_improvement_count": 0,
            "positive_improvement_ratio": 0.0,
            "parse_success_rate": 0.0,
            "parse_success_count": 0,
            "per_query_results": [],
        }

    policy_run_dict, parse_successes = _generate_policy_run(
        entries, dataset_name, batch_size, min_tokens, max_tokens
    )

    qrels_ranx = Qrels(qrels_dict)
    baseline_run_ranx = Run(baseline_run_dict)
    policy_run_ranx = Run(policy_run_dict)

    baseline_metrics = evaluate(qrels_ranx, baseline_run_ranx, metrics_to_compute)
    policy_metrics = evaluate(qrels_ranx, policy_run_ranx, metrics_to_compute)

    # Per-query scores (ranx's evaluate() stores them in run.scores[metric][query_id]).
    per_query_results = []
    for query_id in qrels_dict:
        row = {"query_id": query_id}
        for prefix, run in (("baseline", baseline_run_ranx), ("policy", policy_run_ranx)):
            for metric in metrics_to_compute:
                row[f"{prefix}_{metric}"] = float(run.scores.get(metric, {}).get(query_id, 0.0))
        row["improvement_ndcg@10"] = row["policy_ndcg@10"] - row["baseline_ndcg@10"]
        per_query_results.append(row)

    improvements = [row["improvement_ndcg@10"] for row in per_query_results]
    positive_improvements = sum(1 for delta in improvements if delta > 0)
    positive_ratio = positive_improvements / len(improvements) if improvements else 0.0
    parse_success_rate = parse_successes / len(per_query_results) if per_query_results else 0.0

    print(f"  {dataset_name}: v√©ge. Baseline nDCG@10 = {baseline_metrics.get('ndcg@10', 0.0):.4f}, Policy nDCG@10 = {policy_metrics.get('ndcg@10', 0.0):.4f}, Javul√°s = {policy_metrics.get('ndcg@10', 0.0) - baseline_metrics.get('ndcg@10', 0.0):+.4f}")

    return {
        "baseline_metrics": {k: float(v) for k, v in baseline_metrics.items()},
        "policy_metrics": {k: float(v) for k, v in policy_metrics.items()},
        "num_queries": len(qrels_dict),
        "positive_improvement_count": positive_improvements,
        "positive_improvement_ratio": float(positive_ratio),
        "parse_success_rate": float(parse_success_rate),
        "parse_success_count": parse_successes,
        "per_query_results": per_query_results,
    }

train_eval_results = evaluate_policy_ranx(train_dataset, slate_lookup, "Train halmaz")
eval_eval_results = evaluate_policy_ranx(eval_dataset, slate_lookup, "Eval halmaz")


In [None]:
# Per-query export (CSV), plus extraction of the training-reward trend from the
# trainer's log history (consumed by the metrics export).
import pandas as pd

# Mean-reward log keys, tried in order for each log entry. Older TRL versions
# logged "rewards/mean"; current GRPOTrainer logs "reward" and
# "rewards/<reward_func_name>/mean" — with only "rewards/mean" the trend was
# silently empty on current TRL.
_TRAINING_REWARD_KEYS = (
    "rewards/mean",
    "reward",
    "rewards/multi_metric_reward_function/mean",
)

training_rewards = []
for log_entry in trainer.state.log_history:
    reward_key = next((key for key in _TRAINING_REWARD_KEYS if key in log_entry), None)
    if reward_key is not None:
        training_rewards.append(log_entry[reward_key])

train_per_query_df = pd.DataFrame(train_eval_results["per_query_results"])
train_per_query_csv = OUTPUT_DIR / "train_per_query_results.csv"
train_per_query_df.to_csv(train_per_query_csv, index=False, encoding='utf-8')
print(f"  ‚úÖ Train per-query results: {train_per_query_csv}")

eval_per_query_df = pd.DataFrame(eval_eval_results["per_query_results"])
eval_per_query_csv = OUTPUT_DIR / "eval_per_query_results.csv"
eval_per_query_df.to_csv(eval_per_query_csv, index=False, encoding='utf-8')
print(f"  ‚úÖ Eval per-query results: {eval_per_query_csv}")

In [None]:
# Metrics export: run configuration, training-reward statistics and the
# train/eval evaluation summaries, written to METRICS_FILE as JSON.
import json

import numpy as np


def _last_logged_loss(log_history):
    """Return the most recent training loss in `log_history` (0.0 if none).

    The final log_history entry is the end-of-training summary, which carries
    "train_loss" (the run average) rather than "loss", so reading the last
    entry always yielded 0.0. Scan backwards for the last "loss" instead and
    fall back to "train_loss" only if no per-step loss was logged.
    """
    for key in ("loss", "train_loss"):
        for entry in reversed(log_history):
            if key in entry:
                return entry[key]
    return 0.0


def _reward_stats(values):
    """mean/std/min/max of `values` (0.0 each when empty) plus the raw trend."""
    if not values:
        return {"mean": 0.0, "std": 0.0, "min": 0.0, "max": 0.0, "trend": values}
    return {
        "mean": float(np.mean(values)),
        "std": float(np.std(values)),
        "min": float(np.min(values)),
        "max": float(np.max(values)),
        "trend": values,
    }


def _evaluation_summary(results):
    """Copy the summary fields of an evaluate_policy_ranx() result (no per-query rows)."""
    keys = (
        "baseline_metrics",
        "policy_metrics",
        "num_queries",
        "positive_improvement_count",
        "positive_improvement_ratio",
        "parse_success_rate",
        "parse_success_count",
    )
    return {key: results[key] for key in keys}


_log_history = trainer.state.log_history or []
final_metrics = {
    "model_name": MODEL_NAME,
    "training_samples": len(train_dataset),
    "eval_samples": len(eval_dataset),
    "slate_size": SLATE_SIZE,
    "group_size": GROUP_SIZE,
    "max_steps": MAX_STEPS,
    "learning_rate": LEARNING_RATE,
    "lora_rank": LORA_RANK,
    "lora_alpha": LORA_ALPHA,
    "final_loss": _last_logged_loss(_log_history),
    # Optimizer steps actually taken (previously: the number of log entries).
    "total_steps": trainer.state.global_step,
    "training_rewards": _reward_stats(training_rewards),
    "train_evaluation": _evaluation_summary(train_eval_results),
    "eval_evaluation": _evaluation_summary(eval_eval_results),
    "status": "completed"
}

with open(METRICS_FILE, 'w', encoding='utf-8') as f:
    json.dump(final_metrics, f, ensure_ascii=False, indent=2)

print(f"  ‚úÖ Metrics: {METRICS_FILE}")
print("\n‚úÖ Minden artifact sikeresen mentve!")