# CourtRankRL Embedding Generation - Sentence Transformers (RTX 5090 GPU)

## Specifikáció
- **Modell**: google/embeddinggemma-300m via Sentence Transformers (>=5.1.0)
- **Input**: chunks.jsonl (processed court decisions)
- **Output**: embeddings.npy (float32, L2-normalized) és embedding_chunk_ids.json
- **Környezet**: RunPod RTX 5090 GPU (32GB VRAM)
- **Batch size**: 512 (GPU optimalizált)
- **Kritikus**: CSAK Sentence Transformers - AutoModel ZERO VECTOR-t produkál!

## Prompt-ok (automatikus Sentence Transformers-ben)
- **Document chunks**: `prompt_name="document"` → "title: none | text: {chunk_text}"
- **Query**: `prompt_name="query"` → "task: search result | query: {query_text}"
- **Normalization**: `normalize_embeddings=True` (automatikus L2-normalizálás)

## Memória kezelés
- Shard-okba írás nagy dataset-ekhez
- Konszolidáció a végén
- GPU memória monitoring

In [None]:
# KRITIKUS: Csak Sentence Transformers - AutoModel ZERO VECTOR-t produk√°l!
from sentence_transformers import SentenceTransformer
import numpy as np
import json
import os
import time
from pathlib import Path
import psutil
from typing import List
import torch

# --- Configuration ---
# Every input/output artifact lives under the RunPod workspace volume.
BASE_PATH = Path("/workspace")
CHUNKS_PATH = BASE_PATH.joinpath("chunks.jsonl")
EMBEDDINGS_PATH = BASE_PATH.joinpath("embeddings.npy")
CHUNK_IDS_PATH = BASE_PATH.joinpath("embedding_chunk_ids.json")

# Throughput settings chosen for the RTX 5090.
BATCH_SIZE = 512
MAX_SEQ_LENGTH = 2048  # intended context length for EmbeddingGemma
SHARD_SIZE = 100_000  # rows per shard file when saving

# Hugging Face access token, read from the environment; abort early without it.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("‚ùå HF_TOKEN k√∂rnyezeti v√°ltoz√≥ hi√°nyzik!")

print("üöÄ RTX 5090 Embedding Generation indul...")
print(f"üìÇ Base path: {BASE_PATH}")
print(f"üìÑ Input chunks: {CHUNKS_PATH}")
print(f"üìÑ Output embeddings: {EMBEDDINGS_PATH}")
print(f"üìÑ Output chunk IDs: {CHUNK_IDS_PATH}")

# --- Chunks bet√∂lt√©s (pandas optimaliz√°lt - agents.md szerint) ---
def load_chunks(chunks_path: Path) -> List[dict]:
    """Load chunk records from a JSONL file using pandas.

    Each line of ``chunks_path`` is parsed as one JSON object and returned as a
    dict, in file order.

    pandas type inference is disabled (``dtype=False``, ``convert_dates=False``).
    With the defaults, ``read_json`` converts string columns whose values all
    look numeric into numbers. A ``chunk_id`` of ``"007"`` would become the
    integer ``7`` (a NumPy int64). That silently changes the ID and later makes
    ``json.dump`` fail, because NumPy ints are not JSON serializable.

    Args:
        chunks_path: Path of the JSONL file to read.

    Returns:
        A list of record dicts. If the file is missing or cannot be parsed, the
        error is printed and an empty list is returned (no exception is raised).
    """
    import pandas as pd

    print(f"üì• Chunks bet√∂lt√©se: {chunks_path}")

    try:
        # Keep every value exactly as written in the JSON; no dtype/date coercion.
        df_chunks = pd.read_json(
            chunks_path,
            lines=True,
            encoding='utf-8',
            dtype=False,
            convert_dates=False,
        )
    except (ValueError, FileNotFoundError) as e:
        print(f"‚ùå Hiba a chunks bet√∂lt√©se sor√°n: {e}")
        return []

    chunks = df_chunks.to_dict('records')
    print(f"‚úÖ {len(chunks):,} chunks bet√∂ltve")
    return chunks

chunks = load_chunks(CHUNKS_PATH)

# Two parallel lists in file order: the texts to encode and their IDs.
chunk_texts = [rec['text'] for rec in chunks]
chunk_ids = [rec['chunk_id'] for rec in chunks]

print(f"üìä √ñsszes chunk: {len(chunk_texts):,}")
print(f"üíæ Mem√≥ria haszn√°lat: {psutil.virtual_memory().used / 2**30:.1f}GB")

# --- Model loading (Sentence Transformers only!) ---
print("ü§ñ EmbeddingGemma modell bet√∂lt√©se (Sentence Transformers)...")

try:
    model = SentenceTransformer(
        "google/embeddinggemma-300m",
        token=HF_TOKEN,
        # NOTE(review): trust_remote_code executes Python code from the Hub repo.
        # It is presumably unnecessary with sentence-transformers >= 5.1; confirm
        # and drop it if so.
        trust_remote_code=True,
        cache_dir="/tmp/hf_cache"
    )
    print("‚úÖ Modell bet√∂ltve (Sentence Transformers)")
except Exception as e:
    print(f"‚ùå Modell bet√∂lt√©si hiba: {e}")
    raise

# Apply the configured context length. MAX_SEQ_LENGTH is declared in the
# configuration section but would otherwise never reach the model, leaving
# truncation at whatever default the checkpoint ships with.
model.max_seq_length = MAX_SEQ_LENGTH

# Device selection: use the GPU when CUDA is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    model = model.to(device)
    print(f"‚úÖ GPU el√©rhet≈ë: {torch.cuda.get_device_name()}")
    print(f"üíæ GPU mem√≥ria: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
else:
    device = torch.device("cpu")
    print("‚ö†Ô∏è GPU nem el√©rhet≈ë - CPU haszn√°lata (lass√∫!)")

# --- Embedding gener√°l√°s ---
def generate_embeddings(model, texts: List[str], batch_size: int, device) -> np.ndarray:
    """Encode ``texts`` as document embeddings, one ``batch_size`` slice at a time.

    Each slice goes to ``model.encode`` with ``prompt_name="document"`` and
    ``normalize_embeddings=True``. The per-slice results are stacked in input
    order. GPU memory use (on CUDA) and per-slice timing are printed.

    Args:
        model: A SentenceTransformer-like object exposing ``encode``.
        texts: Texts to embed; output row ``k`` corresponds to ``texts[k]``.
        batch_size: Number of texts per slice (also forwarded to ``encode``).
        device: Target device forwarded to ``encode``. Either a
            ``torch.device`` or a device string such as ``"cuda"`` or ``"cpu"``.

    Returns:
        A float32 array with one row per input text. For empty ``texts`` an
        empty ``(0, dim)`` array is returned. ``dim`` comes from
        ``model.get_sentence_embedding_dimension()`` when available, else 0.

    Raises:
        Any exception raised by ``model.encode``, re-raised after logging the
        failing batch number.
    """
    print(f"üîÑ Embedding gener√°l√°s indul: {len(texts):,} texts")

    # np.vstack([]) raises "need at least one array"; return a well-formed
    # empty matrix instead. len() (not truthiness) so array inputs also work.
    if len(texts) == 0:
        get_dim = getattr(model, "get_sentence_embedding_dimension", None)
        dim = (get_dim() if callable(get_dim) else None) or 0
        return np.empty((0, dim), dtype=np.float32)

    # Accept both torch.device objects and plain strings ("cuda", "cuda:0", "cpu").
    device_type = getattr(device, "type", None) or str(device).split(":", 1)[0]
    on_cuda = device_type == "cuda"

    all_embeddings = []
    start_time = time.time()

    for batch_no, offset in enumerate(range(0, len(texts), batch_size), start=1):
        batch_texts = texts[offset:offset + batch_size]
        batch_start = time.time()

        try:
            # Sentence Transformers encode(): the "document" prompt is applied
            # automatically and outputs are L2-normalized.
            batch_embeddings = model.encode(
                batch_texts,
                prompt_name="document",
                normalize_embeddings=True,
                batch_size=batch_size,
                convert_to_numpy=True,
                device=device,
                # Each call receives exactly one batch, so a tqdm bar here would
                # be a separate one-step bar per batch. Progress is printed below.
                show_progress_bar=False,
            )
        except Exception as e:
            print(f"‚ùå Batch hiba {batch_no}-n√°l: {e}")
            raise

        all_embeddings.append(batch_embeddings)

        # GPU memory monitoring.
        if on_cuda:
            mem_used = torch.cuda.memory_allocated() / 1024**3
            print(f"  üìä Batch {batch_no} k√©sz | GPU mem√≥ria: {mem_used:.1f}GB")

        batch_time = time.time() - batch_start
        print(f"  ‚è±Ô∏è Batch id≈ë: {batch_time:.2f}s")

    # Consolidate; guarantee float32 without copying when already float32.
    embeddings = np.vstack(all_embeddings).astype(np.float32, copy=False)
    total_time = time.time() - start_time

    print(f"‚úÖ Embedding gener√°l√°s k√©sz: {embeddings.shape}")
    print(f"‚è±Ô∏è Teljes id≈ë: {total_time:.2f}s")
    # Guard against a zero elapsed time on coarse clocks.
    speed = len(texts) / total_time if total_time > 0 else float("inf")
    print(f"‚ö° Sebess√©g: {speed:.1f} texts/sec")

    return embeddings

# Sanity check: Ellen≈ërizz√ºk hogy nincsenek zero vector-ok
def validate_embeddings(embeddings: np.ndarray, chunk_ids: List[str]) -> tuple:
    """Drop embedding rows that are unusable for similarity search.

    A row is kept only if every value is finite (no NaN/Inf) and its L2 norm
    exceeds 1e-6. ``chunk_ids`` is filtered in lockstep, so rows and IDs stay
    aligned.

    Returns:
        ``(embeddings, chunk_ids)``. The original objects are returned unchanged
        when every row is valid; otherwise filtered copies are returned.
    """
    print("üîç Embedding valid√°ci√≥...")
    total = len(embeddings)

    # Rows with no NaN/Inf anywhere.
    is_finite = np.isfinite(embeddings).all(axis=1)
    print(f"  üìä Finite embeddings: {is_finite.sum():,}/{total:,}")

    # Rows whose norm is clearly non-zero.
    has_norm = np.linalg.norm(embeddings, axis=1) > 1e-6
    print(f"  üìä Non-zero norm embeddings: {has_norm.sum():,}/{total:,}")

    keep = np.logical_and(is_finite, has_norm)
    dropped = int(np.count_nonzero(~keep))

    if not dropped:
        print("‚úÖ Minden embedding valid")
        return embeddings, chunk_ids

    print(f"‚ö†Ô∏è {dropped:,} invalid embedding kisz≈±rve")
    kept_ids = [cid for cid, ok in zip(chunk_ids, keep) if ok]
    return embeddings[keep], kept_ids

# Encode every chunk text on the selected device.
embeddings = generate_embeddings(model, chunk_texts, BATCH_SIZE, device)

# Drop NaN/Inf and near-zero rows, keeping chunk_ids aligned with the rows.
embeddings, chunk_ids = validate_embeddings(embeddings, chunk_ids)

# Guarantee float32 storage; astype(copy=False) is a no-op when already float32.
embeddings = embeddings.astype(np.float32, copy=False)

print(f"üéØ V√©gs≈ë alak: {embeddings.shape}")
print(f"üíæ Mem√≥ria haszn√°lat: {embeddings.nbytes / 2**30:.2f}GB")

# --- Ment√©s shard-okba (nagy dataset-ekhez) ---
def save_sharded_embeddings(embeddings: np.ndarray, chunk_ids: List[str], 
                           shard_size: int, output_path: Path, ids_path: Path):
    """Write embeddings/IDs as numbered shards, then merge them into the final files.

    Shards ``embeddings_shard_{k}.npy`` and ``chunk_ids_shard_{k}.json`` are
    written next to ``output_path``. They are then read back in index order,
    concatenated and deleted. The merged matrix is saved to ``output_path``
    with ``np.save``; the merged IDs go to ``ids_path`` as a JSON list.

    Only the shards written by *this* call (indices ``0..n_shards-1``) are
    merged. Probing ``shard_0, shard_1, ...`` until a file is missing would also
    pick up leftover shard files from an earlier interrupted run. Those stale
    rows and IDs would be silently appended to the output.

    Args:
        embeddings: Row-aligned embedding matrix.
        chunk_ids: One ID per row of ``embeddings``.
        shard_size: Maximum number of rows per shard; must be >= 1.
        output_path: Destination of the merged embeddings (``.npy``).
        ids_path: Destination of the merged ID list (``.json``).

    Raises:
        ValueError: If ``shard_size`` < 1, or if the number of rows and IDs differ.
    """
    print(f"üíæ Shard-okba ment√©s (shard size: {shard_size:,})...")

    if shard_size < 1:
        raise ValueError(f"shard_size must be >= 1, got {shard_size}")
    n_rows = len(embeddings)
    if len(chunk_ids) != n_rows:
        raise ValueError(
            f"embeddings/chunk_ids length mismatch: {n_rows} != {len(chunk_ids)}"
        )

    os.makedirs(output_path.parent, exist_ok=True)
    os.makedirs(ids_path.parent, exist_ok=True)
    shard_dir = output_path.parent

    # Write the shards, remembering exactly which files this call produced.
    written = []  # (embedding shard path, id shard path) in shard order
    for shard_idx, start in enumerate(range(0, n_rows, shard_size)):
        end = min(start + shard_size, n_rows)
        shard_embeddings = embeddings[start:end]
        shard_path = shard_dir / f"embeddings_shard_{shard_idx}.npy"
        ids_shard_path = shard_dir / f"chunk_ids_shard_{shard_idx}.json"

        np.save(shard_path, shard_embeddings)
        with open(ids_shard_path, 'w', encoding='utf-8') as f:
            json.dump(chunk_ids[start:end], f, ensure_ascii=False, indent=2)
        written.append((shard_path, ids_shard_path))

        print(f"  üíæ Shard {shard_idx} mentve: {len(shard_embeddings):,} embeddings")

    # Merge only this call's shards, deleting each one after it has been read.
    print("üîÑ Konszolid√°ci√≥ shard-okb√≥l...")
    consolidated_embeddings = []
    consolidated_ids = []
    for shard_path, ids_shard_path in written:
        consolidated_embeddings.append(np.load(shard_path))
        with open(ids_shard_path, 'r', encoding='utf-8') as f:
            consolidated_ids.extend(json.load(f))
        os.remove(shard_path)
        os.remove(ids_shard_path)

    # np.vstack([]) raises; for empty input keep the original column shape.
    if consolidated_embeddings:
        final_embeddings = np.vstack(consolidated_embeddings)
    else:
        final_embeddings = np.asarray(embeddings)[:0]
    print(f"‚úÖ Konszolid√°lt: {final_embeddings.shape}")

    np.save(output_path, final_embeddings)
    with open(ids_path, 'w', encoding='utf-8') as f:
        json.dump(consolidated_ids, f, ensure_ascii=False, indent=2)

    print("üíæ Konszolid√°lt f√°jlok mentve")

# Persist the final matrix and IDs (via temporary shards) to the configured paths.
save_sharded_embeddings(
    embeddings,
    chunk_ids,
    shard_size=SHARD_SIZE,
    output_path=EMBEDDINGS_PATH,
    ids_path=CHUNK_IDS_PATH,
)

# --- Final validation ---
print("üîç V√©gs≈ë valid√°ci√≥...")

# Reload the artifacts from disk so the checks cover what was actually written.
loaded_embeddings = np.load(EMBEDDINGS_PATH)
with open(CHUNK_IDS_PATH, 'r', encoding='utf-8') as f:
    loaded_ids = json.load(f)

# Explicit raises instead of `assert`: asserts are stripped under `python -O`,
# which would silently skip these integrity checks.
if loaded_embeddings.shape[0] != len(loaded_ids):
    raise ValueError("ID √©s embedding sz√°m nem egyezik")
if loaded_embeddings.shape[1] != 768:
    raise ValueError(f"Embedding dim nem 768: {loaded_embeddings.shape[1]}")
if loaded_embeddings.dtype != np.float32:
    raise ValueError(f"Dtype nem float32: {loaded_embeddings.dtype}")

# NaN/Inf and zero-norm checks. The 1e-6 threshold matches validate_embeddings.
if not np.isfinite(loaded_embeddings).all():
    raise ValueError("NaN/Inf az embeddingekben!")
if not np.all(np.linalg.norm(loaded_embeddings, axis=1) > 1e-6):
    raise ValueError("Zero-norm embeddings!")

print("‚úÖ V√©gs≈ë valid√°ci√≥ sikeres")
print(f"üìä V√©gs≈ë statisztik√°k:")
print(f"  ‚Ä¢ Embeddings: {loaded_embeddings.shape}")
print(f"  ‚Ä¢ Chunk IDs: {len(loaded_ids)}")
print(f"  ‚Ä¢ Mem√≥ria: {loaded_embeddings.nbytes / 1024**3:.2f}GB")

# --- Summary ---
print("="*80)
print("üéâ EMBEDDING GENER√ÅL√ÅS SIKERES!")
print("="*80)
print(f"üìÑ Kimeneti f√°jlok:")
print(f"  ‚Ä¢ {EMBEDDINGS_PATH}")
print(f"  ‚Ä¢ {CHUNK_IDS_PATH}")
print(f"üìä Vektorok: {loaded_embeddings.shape[0]:,}")
print(f"üéØ Dimenzi√≥: {loaded_embeddings.shape[1]}")
print(f"üíæ F√°jlm√©ret: {EMBEDDINGS_PATH.stat().st_size / 1024**3:.2f}GB")
# `time.time() - time.time()` always printed ~0s because no start timestamp
# exists. Report wall time since this Python process started (psutil) instead.
# In a notebook that is the kernel's uptime, so it includes any earlier cells.
print(f"‚è±Ô∏è Futtat√°si id≈ë: {time.time() - psutil.Process().create_time():.2f}s")

# The leading line break is an escape sequence. A raw newline inside the
# literal is a SyntaxError.
print("\nüöÄ K√∂vetkez≈ë l√©p√©s: T√∂ltsd le az artifact-okat √©s futtasd a faiss_index_builder.ipynb-t")
print("üí° Tipp: Ellen≈ërizd hogy minden embedding L2-normaliz√°lt √©s finite!")

## Használat



## Kritikus megjegyzések

1. **CSAK Sentence Transformers!** AutoModel használata zero-vector-t produkál
2. **normalize_embeddings=True** kötelező az L2-normalizáláshoz
3. **float32** használata - EmbeddingGemma nem támogatja float16-ot
4. **prompt_name="document"** automatikusan hozzáadja a megfelelő prompt-ot
5. **Validáció** minden embedding-re - zero vector-ok azonnali detektálása
6. **dtype paraméter ELTÁVOLÍTVA** - Sentence Transformers nem támogatja ezt a paramétert