# Graph-Augmented Retrieval Pipeline Testing

This notebook tests and demonstrates the full graph-augmented retrieval pipeline:

1. **Ingestor** - Load and normalize corpus data
2. **Summarizer** - Generate summaries for chunks
3. **Embedder** - Build embeddings and indices
4. **GraphBuilder** - Build graph edges and score candidates

## Pipeline Overview

```
laws_de.csv + court_considerations.csv
        ↓ (Ingestor)
    chunks.parquet
        ↓ (Summarizer)
   summaries.parquet
        ↓ (Embedder)
embeddings.npy + faiss_index.bin + bm25_index.pkl
        ↓ (GraphBuilder)
edges_similar.parquet + edges_cocite.parquet + groups.parquet
        ↓ (Inference)
   predictions.csv
```

## 1. Setup

In [1]:
import os
import sys
from pathlib import Path

# =============================================================================
# Environment Variables (API Configuration)
# =============================================================================
# Set these before running if you need LLM API access (for summarizer_mode="llm"
# or LLM reranking). Export them in your shell or load them from a .env file.
#
# SECURITY: credentials must never be committed to the notebook. The key is read
# from the environment only and falls back to an empty string, which the summary
# print below reports as "(not set)". Any key that was previously stored in this
# file should be considered leaked and must be revoked/rotated.

# API key for the LLM provider (e.g., OpenAI, ProxyAPI); empty when not configured
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")

# ProxyAPI base URL (for Russian users or custom proxies); a URL is not a secret,
# so a public default is acceptable here
PROXYAPI_BASE_URL = os.environ.get("PROXYAPI_BASE_URL", "https://api.proxyapi.ru/openai/v1")

# To configure the key for a single session without writing it to disk, set it in
# the environment before starting Jupyter, e.g.:
#   export OPENAI_API_KEY=...        (Linux/macOS)
#   setx OPENAI_API_KEY ...          (Windows)

# =============================================================================
# Resolve repository, data and output locations
# =============================================================================
# The presence of KAGGLE_KERNEL_RUN_TYPE in the environment selects the Kaggle
# directory layout; otherwise every path is derived from the parent of the
# current working directory.
KAGGLE_ENV = os.environ.get("KAGGLE_KERNEL_RUN_TYPE") is not None

if KAGGLE_ENV:
    REPO_ROOT, DATA_PATH, OUTPUT_PATH = (
        Path("/kaggle/input/omnilex-repo"),
        Path("/kaggle/input/omnilex-data"),
        Path("/kaggle/working"),
    )
else:
    REPO_ROOT = Path(".").resolve().parent
    DATA_PATH, OUTPUT_PATH = REPO_ROOT / "data", REPO_ROOT / "output"

PROCESSED_PATH = DATA_PATH / "processed"

# Both directories are written to by later cells, so create them up front
for _required_dir in (PROCESSED_PATH, OUTPUT_PATH):
    _required_dir.mkdir(parents=True, exist_ok=True)

# Put the repository's `src` directory first on the import path
sys.path.insert(0, str(REPO_ROOT / "src"))

# Only the last four characters of the key are echoed
_masked_key = "***" + OPENAI_API_KEY[-4:] if len(OPENAI_API_KEY) > 4 else "(not set)"

print(
    f"Environment: {'Kaggle' if KAGGLE_ENV else 'Local'}",
    f"Data path: {DATA_PATH}",
    f"Processed path: {PROCESSED_PATH}",
    f"Output path: {OUTPUT_PATH}",
    "\nAPI Configuration:",
    f"  OPENAI_API_KEY: {_masked_key}",
    f"  PROXYAPI_BASE_URL: {PROXYAPI_BASE_URL}",
    sep="\n",
)

Environment: Local
Data path: C:\Users\Artem Khakimov\Desktop\Projects\LEXam_kaggle\Omnilex-Agentic-Retrieval-Competition\data
Processed path: C:\Users\Artem Khakimov\Desktop\Projects\LEXam_kaggle\Omnilex-Agentic-Retrieval-Competition\data\processed
Output path: C:\Users\Artem Khakimov\Desktop\Projects\LEXam_kaggle\Omnilex-Agentic-Retrieval-Competition\output

API Configuration:
  OPENAI_API_KEY: ***aumv
  PROXYAPI_BASE_URL: https://api.proxyapi.ru/openai/v1


In [None]:
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import importlib

# Reload modules to pick up code changes
import omnilex.graph.summarizer
import omnilex.graph.reranker
importlib.reload(omnilex.graph.summarizer)
importlib.reload(omnilex.graph.reranker)

# Import graph modules
from omnilex.graph.ingestor import Ingestor, ChunkType, Language
from omnilex.graph.summarizer import Summarizer, SummaryType, create_llm_client
from omnilex.graph.embedder import Embedder, EMBEDDING_MODEL
from omnilex.graph.graph_builder import GraphBuilder, ExpansionParams, ScoringParams
from omnilex.graph.reranker import LLMReranker, RerankerConfig, QueryPreprocessor

print("‚úÖ Modules loaded")
print(f"   Embedding model: {EMBEDDING_MODEL}")

# =============================================================================
# LLM client setup (for LLM rerank on Kaggle)
# =============================================================================
# Supported providers: "openai" or "google"
LLM_PROVIDER = "google"  # <-- Change this to switch providers

# The client is used for reranking, NOT for summarization. Any provider value
# other than "google" selects the OpenAI model; the same API key is passed in
# both cases.
llm_client = create_llm_client(
    api_key=OPENAI_API_KEY,
    **(
        {"model": "gemini-2.5-flash-lite", "provider": "google"}
        if LLM_PROVIDER == "google"
        else {"model": "gpt-4o", "provider": "openai"}
    ),
)

print(f"\nü§ñ LLM client: provider={llm_client.provider}, model={llm_client.model}")

# Smoke-test the connection; a failure is reported but does not stop the notebook
print("   Testing connection...", end=" ")
try:
    response = llm_client("Say 'OK' if working.")
    print(f"‚úÖ {response[:50]}")
except Exception as exc:
    print(f"‚ùå {type(exc).__name__}: {exc}")

Embedding model: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
LLM client created: provider=google, model=gemini-2.5-flash-lite
Testing LLM connection...
‚úÖ LLM Response: Hello, I am working!


## 2. Configuration

In [None]:
# =============================================================================
# Pipeline configuration (follows the final project plan)
# =============================================================================

CONFIG = dict(
    # ---------------------------------------------------------------------
    # Stage 0: artifact preparation
    # ---------------------------------------------------------------------
    force_rebuild=False,         # True = rebuild every artifact
    sample_size=None,            # int for testing (e.g. 1000)

    # Summarizer: ALWAYS heuristic for the full corpus (LLM is too expensive)
    summarizer_mode="heuristic",  # "heuristic" - offline, fast

    # Embedder
    embedding_model=EMBEDDING_MODEL,
    embedding_batch_size=64,
    faiss_index_type="flat",     # "flat", "ivf", "hnsw"

    # Graph Builder
    similar_k=50,                # SIMILAR_TO: k neighbors
    similar_min_cos=0.25,        # minimum cosine similarity
    cocite_top_m=50,             # CO_CITED_WITH: top-M neighbors per node

    # ---------------------------------------------------------------------
    # Stage 1: inference (on Kaggle)
    # ---------------------------------------------------------------------

    # Initial retrieval
    top_n_bm25=200,              # BM25 candidates
    top_n_faiss=200,             # FAISS candidates
    top_k_retrieval=200,         # kept after RRF fusion

    # Graph expansion
    k_expand_sim=20,             # neighbors via SIMILAR_TO
    k_expand_cocite=30,          # neighbors via CO_CITED_WITH
    k_expand_siblings=10,        # siblings from the DocGroup
    max_candidates=800,          # cap after expansion

    # Fast scoring (linear combination)
    alpha=1.0,                   # retrieval score weight
    beta=0.6,                    # similarity edge weight
    gamma=0.8,                   # co-citation weight
    delta=0.2,                   # docgroup bonus

    # LLM rerank (targeted, on the top K only)
    use_llm_rerank=True,         # use the LLM for reranking
    top_k_to_rerank=100,         # candidates sent to the LLM rerank
    rerank_batch_size=20,        # batch size for the LLM
    relevance_threshold=0.5,     # relevance threshold (0-1)

    # Final output
    top_k_final=20,              # number of final citations

    # Query preprocessing
    use_query_summary=False,     # LLM summary of the query (optional)
)

# Human-readable recap of the settings that matter most, one print per section
print("=" * 60, "PIPELINE CONFIGURATION", "=" * 60, sep="\n")

print(
    "\nüì¶ –≠—Ç–∞–ø 0 - –ü–æ–¥–≥–æ—Ç–æ–≤–∫–∞ –∞—Ä—Ç–µ—Ñ–∞–∫—Ç–æ–≤:",
    f"  summarizer_mode: {CONFIG['summarizer_mode']}",
    f"  embedding_model: {CONFIG['embedding_model']}",
    f"  similar_k: {CONFIG['similar_k']}, cocite_top_m: {CONFIG['cocite_top_m']}",
    sep="\n",
)

print(
    "\nüîç –≠—Ç–∞–ø 1 - Retrieval:",
    f"  BM25 top: {CONFIG['top_n_bm25']}, FAISS top: {CONFIG['top_n_faiss']}",
    f"  Graph expansion: sim={CONFIG['k_expand_sim']}, cocite={CONFIG['k_expand_cocite']}, siblings={CONFIG['k_expand_siblings']}",
    f"  Max candidates: {CONFIG['max_candidates']}",
    sep="\n",
)

print(
    "\n‚öñÔ∏è Scoring weights:",
    f"  Œ±={CONFIG['alpha']} (retrieval), Œ≤={CONFIG['beta']} (similarity)",
    f"  Œ≥={CONFIG['gamma']} (co-citation), Œ¥={CONFIG['delta']} (docgroup)",
    sep="\n",
)

print(
    "\nü§ñ LLM Rerank:",
    f"  use_llm_rerank: {CONFIG['use_llm_rerank']}",
    f"  top_k_to_rerank: {CONFIG['top_k_to_rerank']}",
    f"  top_k_final: {CONFIG['top_k_final']}",
    sep="\n",
)

Configuration:
  force_rebuild: False
  sample_size: None
  summarizer_mode: llm
  embedding_model: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
  embedding_batch_size: 64
  faiss_index_type: flat
  similar_k: 50
  similar_min_cos: 0.25
  cocite_top_m: 50
  k_expand_sim: 20
  k_expand_cocite: 30
  k_expand_siblings: 10
  max_candidates: 800
  alpha: 1.0
  beta: 0.6
  gamma: 0.8
  delta: 0.2
  top_k_retrieval: 100
  top_k_final: 20


## 3. Step 1: Ingestor - Load Corpus

In [8]:
chunks_path = PROCESSED_PATH / "chunks.parquet"

# Rebuild when forced or when no cached parquet exists; otherwise reuse the cache
if CONFIG["force_rebuild"] or not chunks_path.exists():
    print("Building chunks from CSV files...")
    # show_progress=True enables tqdm progress bars while loading
    ingestor = Ingestor(DATA_PATH, show_progress=True)
    chunks_df = ingestor.load_all()

    # Optional fixed-seed subsample for quick experiments
    if CONFIG["sample_size"]:
        print(f"Sampling {CONFIG['sample_size']} chunks for testing")
        sample_n = min(CONFIG["sample_size"], len(chunks_df))
        chunks_df = chunks_df.sample(n=sample_n, random_state=42)

    ingestor.save(chunks_df, chunks_path)
    print(f"Saved chunks to {chunks_path}")
else:
    print(f"Loading existing chunks from {chunks_path}")
    chunks_df = Ingestor.load_chunks(chunks_path)

# Corpus overview: size, schema and the distribution of the two categorical columns
print(f"\nChunks loaded: {len(chunks_df)}")
print(f"Columns: {list(chunks_df.columns)}")
for section_title, column_name in (("\nChunk types:", "chunk_type"), ("\nLanguages:", "lang")):
    print(section_title)
    print(chunks_df[column_name].value_counts())

Loading existing chunks from C:\Users\Artem Khakimov\Desktop\Projects\LEXam_kaggle\Omnilex-Agentic-Retrieval-Competition\data\processed\chunks.parquet

Chunks loaded: 2161111
Columns: ['chunk_id', 'chunk_type', 'group_id', 'lang', 'text_raw']

Chunk types:
chunk_type
case    1985178
law      175933
Name: count, dtype: int64

Languages:
lang
de         1308094
fr          636918
it          109091
unknown     107008
Name: count, dtype: int64


In [9]:
# Preview the first ten chunks. The frame is the cell's last expression, so the
# notebook renders it as a table below the printed caption.
print("Sample chunks:")
chunks_df.head(10)

Sample chunks:


Unnamed: 0,chunk_id,chunk_type,group_id,lang,text_raw
0,Art. 1 112,law,code:112,de,Die Einwohnergemeinde Bern tritt der Schweizer...
1,Art. 2 112,law,code:112,de,Die Einwohnergemeinde Bern wird ferner der Sch...
2,Art. 3 Abs. 1 112,law,code:112,de,1 Falls die Schweizerische Eidgenossenschaft z...
3,Art. 3 Abs. 2 112,law,code:112,de,2 Durch Anlage des neuen Verwaltungsgeb√§udes a...
4,Art. 4 Abs. 1 112,law,code:112,de,1 Die Einwohnergemeinde Bern √ºbernimmt im fern...
5,Art. 4 Abs. 2 112,law,code:112,de,"2 Sie √ºbernimmt auch die Verpflichtung, die er..."
6,Art. 4 Abs. 3 112,law,code:112,de,3 Im Fall die Schweizerische Eidgenossenschaft...
7,Art. 5 Abs. 1 112,law,code:112,de,1 Sollte infolge f√∂rmlichen Beschlusses der ko...
8,Art. 5 Abs. 2 112,law,code:112,de,2 F√ºr den n√§mlichen Fall √ºbernimmt die Schweiz...
9,Art. 8 112,law,code:112,de,Infolge √úbernahme der durch diese √úbereinkunft...


## 4. Step 2: Summarizer - Generate Summaries

In [10]:
summaries_path = PROCESSED_PATH / "summaries.parquet"

# Regenerate when forced or when no cached parquet exists; otherwise reuse the cache
if CONFIG["force_rebuild"] or not summaries_path.exists():
    print(f"Generating summaries (mode: {CONFIG['summarizer_mode']})...")

    # The LLM-backed summarizer reuses the llm_client built in the imports cell
    if CONFIG["summarizer_mode"] == "llm":
        summarizer = Summarizer(mode="llm", llm_client=llm_client)
        print("Using LLM client for summarization")
    else:
        summarizer = Summarizer(mode="heuristic")
        print("Using heuristic summarization")

    # Work in batches so tqdm can report progress; LLM mode uses smaller
    # batches (API rate limits)
    batch_size = 100 if CONFIG["summarizer_mode"] == "llm" else 1000
    all_summaries = [
        summarizer.summarize_all(chunks_df.iloc[start:start + batch_size])
        for start in tqdm(range(0, len(chunks_df), batch_size), desc="Summarizing")
    ]

    summaries_df = pd.concat(all_summaries, ignore_index=True)
    summarizer.save(summaries_df, summaries_path)
    print(f"Saved summaries to {summaries_path}")
else:
    print(f"Loading existing summaries from {summaries_path}")
    summaries_df = Summarizer.load_summaries(summaries_path)

print(f"\nSummaries generated: {len(summaries_df)}")
print("Summary types:")
print(summaries_df["summary_type"].value_counts())

Generating summaries (mode: llm)...
Using LLM client for summarization


Summarizing:   0%|          | 0/21612 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Show every summary stored for the first chunk of the corpus
sample_chunk_id = chunks_df["chunk_id"].iloc[0]
print(f"Summaries for: {sample_chunk_id}\n")

sample_summaries = summaries_df[summaries_df["chunk_id"] == sample_chunk_id]
for _, row in sample_summaries.iterrows():
    print(f"--- {row['summary_type'].upper()} ---")
    print(row["summary_text"][:500])

    # Test emptiness by length rather than truthiness: if the column holds
    # NumPy arrays (presumably the case for a list column read back from
    # parquet - verify against Summarizer.load_summaries), `if array:` raises
    # for more than one element. Values without a length (None / NaN) are
    # skipped instead of failing on the slice below.
    entities = row["entities"]
    if entities is not None and hasattr(entities, "__len__") and len(entities) > 0:
        print(f"Entities: {entities[:10]}")
    print()

## 5. Step 3: Embedder - Build Indices

In [None]:
embeddings_dir = PROCESSED_PATH / "embeddings"
embeddings_path = embeddings_dir / "embeddings.npy"

# The saved embedding matrix acts as the cache marker for the whole embedder
if CONFIG["force_rebuild"] or not embeddings_path.exists():
    print(f"Building embeddings with {CONFIG['embedding_model']}...")
    embedder = Embedder(model_name=CONFIG["embedding_model"])

    # Dense vectors are computed from the "retrieval" summaries
    embeddings = embedder.build_embeddings(
        summaries_df,
        summary_type="retrieval",
        batch_size=CONFIG["embedding_batch_size"],
        show_progress=True,
    )
    print(f"Embeddings shape: {embeddings.shape}")

    # Vector index
    print(f"Building FAISS index (type: {CONFIG['faiss_index_type']})...")
    embedder.build_faiss_index(embeddings, index_type=CONFIG["faiss_index_type"])

    # Lexical index over the same summaries
    print("Building BM25 index...")
    embedder.build_bm25_index(summaries_df, summary_type="retrieval")

    # Persist everything for the next run
    embedder.save(embeddings_dir)
    print(f"Saved embedder to {embeddings_dir}")
else:
    print(f"Loading existing embedder from {embeddings_dir}")
    embedder = Embedder.load(embeddings_dir, model_name=CONFIG["embedding_model"])

print(f"\nIndexed chunks: {len(embedder.get_chunk_ids())}")

In [None]:
# Smoke-test the three search modes on one query
test_query = "What are the requirements for a valid contract under Swiss law?"

print(f"Test query: {test_query}\n")


def _print_hits(hits):
    """Print one `score | chunk_id` line per (chunk_id, score) pair."""
    for hit_id, hit_score in hits:
        print(f"  {hit_score:.4f} | {hit_id}")


print("=== Vector Search ===")
vec_results = embedder.search_vector(test_query, top_k=5)
_print_hits(vec_results)

print("\n=== BM25 Search ===")
bm25_results = embedder.search_bm25(test_query, top_k=5)
_print_hits(bm25_results)

print("\n=== Hybrid Search (RRF) ===")
hybrid_results = embedder.search_hybrid(test_query, top_k=5)
_print_hits(hybrid_results)

## 6. Step 4: GraphBuilder - Build Graph Edges

In [None]:
graph_dir = PROCESSED_PATH / "graph"
edges_similar_path = graph_dir / "edges_similar.parquet"

# The SIMILAR_TO edge file acts as the cache marker for the whole graph
if CONFIG["force_rebuild"] or not edges_similar_path.exists():
    print("Building graph edges...")
    graph_builder = GraphBuilder()

    # SIMILAR_TO edges need the embedding matrix held by the embedder
    if embedder._embeddings is not None:
        print("Building SIMILAR_TO edges...")
        similar_df = graph_builder.build_similar_edges(
            embedder._embeddings,
            embedder.get_chunk_ids(),
            k=CONFIG["similar_k"],
            min_cos=CONFIG["similar_min_cos"],
        )
        print(f"  SIMILAR_TO edges: {len(similar_df)}")

    # CO_CITED_WITH edges are built from the training file, when it is available
    train_path = DATA_PATH / "train.csv"
    if train_path.exists():
        print("Building CO_CITED_WITH edges...")
        train_df = pd.read_csv(train_path)
        cocite_df = graph_builder.build_cocite_edges(train_df, top_m=CONFIG["cocite_top_m"])
        print(f"  CO_CITED_WITH edges: {len(cocite_df)}")

    # DocGroup membership is derived from the chunk table
    print("Building DocGroup mapping...")
    groups_df, chunk_to_group_df = graph_builder.build_groups(chunks_df)
    print(f"  Groups: {len(groups_df)}")
    print(f"  Chunk-to-group mappings: {len(chunk_to_group_df)}")

    # Persist for the next run
    graph_builder.save(graph_dir)
    print(f"Saved graph to {graph_dir}")
else:
    print(f"Loading existing graph from {graph_dir}")
    graph_builder = GraphBuilder.load(graph_dir)

# Edge totals are the summed lengths of the per-node neighbor lists
print("\nGraph statistics:")
print(f"  SIMILAR_TO edges: {sum(map(len, graph_builder.similar_edges.values()))}")
print(f"  CO_CITED_WITH edges: {sum(map(len, graph_builder.cocite_edges.values()))}")
print(f"  Groups: {len(graph_builder.group_to_chunks)}")

## 7. Full Inference Pipeline (Stage 1 / Этап 1)

Pipeline:
1. **Query preprocessing** (optional LLM summary)
2. **Initial retrieval** (BM25 + FAISS → RRF fusion)
3. **Graph expansion** (SIMILAR_TO, CO_CITED_WITH, PART_OF)
4. **Fast scoring** (linear combination: α·retr + β·sim + γ·cocite + δ·group)
5. **LLM Rerank** (optional, on top K candidates)

In [None]:
# =============================================================================
# Build the LLM reranker used by stage 1 inference
# =============================================================================

reranker_config = RerankerConfig(
    use_llm=CONFIG["use_llm_rerank"],
    top_k_to_rerank=CONFIG["top_k_to_rerank"],
    top_k_final=CONFIG["top_k_final"],
    batch_size=CONFIG["rerank_batch_size"],
    relevance_threshold=CONFIG["relevance_threshold"],
    max_text_length=300,
)

# Reuse the llm_client created earlier; it is only handed over when LLM
# reranking is enabled in CONFIG
reranker = LLMReranker(
    config=reranker_config,
    llm_client=llm_client if CONFIG["use_llm_rerank"] else None,
)

print("‚úÖ LLM Reranker created:")
for field_name in ("use_llm", "top_k_to_rerank", "batch_size", "relevance_threshold"):
    print(f"   {field_name}: {getattr(reranker_config, field_name)}")

In [None]:
def run_inference(
    query: str,
    embedder: Embedder,
    graph_builder: GraphBuilder,
    config: dict,
    reranker: LLMReranker | None = None,
    chunks_df: pd.DataFrame | None = None,
    summaries_df: pd.DataFrame | None = None,
    verbose: bool = True,
) -> list[tuple[str, float, dict]]:
    """
    Answer one query with the full retrieval pipeline.

    Stages, in order:
    0. Optional query summary via ``reranker.summarize_query`` (needs
       ``config["use_query_summary"]`` and a reranker).
    1. Hybrid retrieval through ``embedder.search_hybrid``.
    2. Graph expansion through ``graph_builder.expand_candidates``.
    3. Fast scoring through ``graph_builder.score_candidates``.
    4. Optional LLM rerank of the best ``config["top_k_to_rerank"]`` candidates
       (needs ``config["use_llm_rerank"]`` and a reranker); the rerank is
       given the original query, not the summarized one.
    5. Truncation to ``config["top_k_final"]`` results.

    Args:
        query: The user question.
        embedder: Object providing ``search_hybrid``.
        graph_builder: Object providing ``expand_candidates`` / ``score_candidates``.
        config: Pipeline settings (see the keys read below).
        reranker: Optional reranker used for stages 0 and 4.
        chunks_df: Passed through to ``reranker.rerank``.
        summaries_df: Passed through to ``reranker.rerank``.
        verbose: When True, print a trace of every stage.

    Returns:
        List of (chunk_id, score, features) tuples
    """
    have_reranker = reranker is not None

    # Stage 0: optionally replace the retrieval query by its LLM summary
    if config.get("use_query_summary") and have_reranker:
        processed_query = reranker.summarize_query(query)
        if verbose:
            print(f"Query summary: {processed_query[:100]}...")
    else:
        processed_query = query

    # Stage 1: hybrid retrieval
    if verbose:
        suffix = "..." if len(query) > 80 else ""
        print(f"Query: {query[:80]}{suffix}\n")
        print("Step 1: Initial retrieval (BM25 + FAISS ‚Üí RRF)...")

    initial_results = embedder.search_hybrid(processed_query, top_k=config["top_k_retrieval"])

    if verbose:
        print(f"  ‚Üí {len(initial_results)} initial candidates")
        for cid, score in initial_results[:3]:
            print(f"    {score:.4f} | {cid[:60]}")
        not_shown = len(initial_results) - 3
        if not_shown > 0:
            print(f"    ... and {not_shown} more")

    # Stage 2: graph expansion
    if verbose:
        print("\nStep 2: Graph expansion...")

    expansion_params = ExpansionParams(
        **{
            name: config[name]
            for name in ("k_expand_sim", "k_expand_cocite", "k_expand_siblings", "max_candidates")
        }
    )
    expanded = graph_builder.expand_candidates(initial_results, expansion_params)

    if verbose:
        print(f"  ‚Üí {len(expanded)} candidates after expansion")
        # Tally candidates by expansion reason (text before the first "(")
        tally = {}
        for _cid, _score, why in expanded:
            label = why.partition("(")[0].strip()
            tally[label] = tally.get(label, 0) + 1
        # Most frequent reason first; ties keep first-seen order
        for label, count in sorted(tally.items(), key=lambda kv: kv[1], reverse=True):
            print(f"    {label}: {count}")

    # Stage 3: fast scoring (linear combination of the four weights)
    if verbose:
        print("\nStep 3: Fast scoring...")

    scoring_params = ScoringParams(
        **{weight: config[weight] for weight in ("alpha", "beta", "gamma", "delta")}
    )
    initial_set = {cid for cid, _ in initial_results}
    scored = graph_builder.score_candidates(expanded, initial_set, scoring_params)

    if verbose:
        top_scores = [f"{entry[1]:.3f}" for entry in scored[:5]]
        print(f"  ‚Üí Top scores: {top_scores}")

    # Stage 4: optional LLM rerank of the best-scored candidates
    if config.get("use_llm_rerank") and have_reranker:
        if verbose:
            print(f"\nStep 4: LLM Rerank (top {config['top_k_to_rerank']} candidates)...")

        final_results = reranker.rerank(
            query=query,
            candidates=scored[:config["top_k_to_rerank"]],
            chunks_df=chunks_df,
            summaries_df=summaries_df,
        )

        if verbose:
            print(f"  ‚Üí {len(final_results)} after rerank + threshold filter")
    else:
        # Without rerank the fast-scored order is final
        final_results = scored

    # Stage 5: keep only the requested number of results
    final_results = final_results[:config["top_k_final"]]

    if verbose:
        rule = "=" * 60
        print(f"\n{rule}")
        print(f"FINAL: Top {len(final_results)} results:")
        print(rule)
        for rank, (cid, score, features) in enumerate(final_results[:10], 1):
            if "llm_score" in features:
                llm_info = f" llm={features['llm_score']:.2f}"
            else:
                llm_info = ""
            print(f"  {rank:2d}. [{score:.3f}] {cid[:55]}...{llm_info}")
        if len(final_results) > 10:
            print(f"  ... and {len(final_results) - 10} more")

    return final_results

In [None]:
# Run the whole pipeline once, verbosely, on a sample question
test_query = "What are the requirements for a valid contract under Swiss law?"

print("üîç Running full inference pipeline...", "=" * 60, sep="\n")

results = run_inference(
    test_query,
    embedder,
    graph_builder,
    CONFIG,
    reranker=reranker,
    chunks_df=chunks_df,
    summaries_df=summaries_df,
    verbose=True,
)

In [None]:
# Format as submission
def format_predictions(results: list[tuple[str, float, dict]]) -> str:
    """Return the chunk ids of `results`, in order, joined by semicolons.

    Each element of `results` must be a 3-tuple (chunk_id, score, features);
    only the chunk id is used. An empty list yields an empty string.
    """
    citation_ids = (cid for cid, _score, _features in results)
    return ";".join(citation_ids)

# Format the results of the sample query above as one ';'-separated citation string
prediction = format_predictions(results)
print(f"Prediction: {prediction}")

## 8. Evaluate on Validation Set

In [None]:
# Read the validation queries; val_df stays None when the file is absent so
# that later cells can skip themselves
val_path = DATA_PATH / "val.csv"
val_df = pd.read_csv(val_path) if val_path.exists() else None

if val_df is None:
    print(f"Validation file not found: {val_path}")
else:
    print(f"Validation set: {len(val_df)} queries")
    display(val_df.head())

In [None]:
# Run the pipeline on every validation query, report per-query precision /
# recall / F1 and finally the macro-averaged metrics.
if val_df is not None:
    from omnilex.evaluation.metrics import citation_f1, macro_f1

    predictions = []
    gold_list = []

    print(f"Running inference on {len(val_df)} validation queries...")
    print(f"LLM Rerank: {'ON' if CONFIG['use_llm_rerank'] else 'OFF'}")
    print("=" * 60)

    val_rows = tqdm(val_df.iterrows(), total=len(val_df), desc="Inference")
    # Number queries by position, not by the frame's index labels, so the
    # counter is correct (and does not fail) for any index.
    for query_no, (_, row) in enumerate(val_rows, start=1):
        query = row["query"]
        gold = row.get("gold_citations", "")

        # Run inference with reranker
        results = run_inference(
            query=query,
            embedder=embedder,
            graph_builder=graph_builder,
            config=CONFIG,
            reranker=reranker,
            chunks_df=chunks_df,
            summaries_df=summaries_df,
            verbose=False,
        )

        pred_citations = [chunk_id for chunk_id, _, _ in results]

        # An empty CSV cell is read as NaN; str(NaN) would otherwise be parsed
        # as the bogus gold citation "nan".
        gold_text = "" if pd.isna(gold) else str(gold)
        gold_citations = [c.strip() for c in gold_text.split(";") if c.strip()]

        predictions.append(pred_citations)
        gold_list.append(gold_citations)

        # Per-query metrics
        metrics = citation_f1(pred_citations, gold_citations)
        print(f"\n[{query_no}] Query: {query[:60]}...")
        print(f"    Pred: {len(pred_citations)} | Gold: {len(gold_citations)}")
        print(f"    P={metrics['precision']:.3f} R={metrics['recall']:.3f} F1={metrics['f1']:.3f}")

        # Show up to three matches, sorted so the output is reproducible
        matches = set(pred_citations) & set(gold_citations)
        if matches:
            print(f"    ‚úÖ Matches: {sorted(matches)[:3]}{'...' if len(matches) > 3 else ''}")

    # Overall metrics
    overall = macro_f1(predictions, gold_list)
    print("\n" + "=" * 60)
    print("üìä OVERALL MACRO METRICS:")
    print("=" * 60)
    print(f"  Precision: {overall['macro_precision']:.4f}")
    print(f"  Recall:    {overall['macro_recall']:.4f}")
    print(f"  F1:        {overall['macro_f1']:.4f}")

## 9. Generate Submission

In [None]:
# Predict citations for every test query and write them to submission.csv
test_path = DATA_PATH / "test.csv"

if not test_path.exists():
    print(f"‚ùå Test file not found: {test_path}")
else:
    test_df = pd.read_csv(test_path)
    print(f"üìù Test set: {len(test_df)} queries")
    print(f"   LLM Rerank: {'ON' if CONFIG['use_llm_rerank'] else 'OFF'}")
    print("=" * 60)

    # One record per query; citations are the ';'-joined chunk ids
    submission_records = [
        {
            "query_id": row["query_id"],
            "predicted_citations": format_predictions(
                run_inference(
                    query=row["query"],
                    embedder=embedder,
                    graph_builder=graph_builder,
                    config=CONFIG,
                    reranker=reranker,
                    chunks_df=chunks_df,
                    summaries_df=summaries_df,
                    verbose=False,
                )
            ),
        }
        for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Generating submission")
    ]

    submission_df = pd.DataFrame(submission_records)

    # Write the file without the frame's index column
    submission_path = OUTPUT_PATH / "submission.csv"
    submission_df.to_csv(submission_path, index=False)
    print(f"\n‚úÖ Submission saved to: {submission_path}")
    print(f"   Total queries: {len(submission_df)}")

    display(submission_df.head())

## 10. Debug: Inspect Graph Neighbors

In [None]:
def inspect_chunk(chunk_id: str, graph_builder: GraphBuilder, chunks_df: pd.DataFrame):
    """Print a chunk's metadata and its graph neighbourhood.

    Shows, in order: the chunk's row from `chunks_df` (if present), up to five
    SIMILAR_TO neighbors, up to five CO_CITED_WITH neighbors, and up to five
    siblings from the chunk's group.

    Args:
        chunk_id: Identifier of the chunk to inspect.
        graph_builder: Provides `similar_edges`, `cocite_edges`,
            `chunk_to_group` and `group_to_chunks` lookups.
        chunks_df: Chunk table with `chunk_id`, `chunk_type`, `group_id`,
            `lang` and `text_raw` columns.

    Returns:
        None; everything is printed.
    """

    def _show_neighbors(title: str, neighbors) -> None:
        """Print the first five (neighbor_id, weight) pairs and a remainder count."""
        print(f"\n{title} neighbors:")
        for neighbor_id, weight in neighbors[:5]:
            print(f"  {weight:.4f} | {neighbor_id}")
        if len(neighbors) > 5:
            print(f"  ... and {len(neighbors) - 5} more")

    print(f"=== Chunk: {chunk_id} ===")

    # Basic info (skipped silently when the id is not in the table)
    chunk_row = chunks_df[chunks_df["chunk_id"] == chunk_id]
    if not chunk_row.empty:
        row = chunk_row.iloc[0]
        print(f"Type: {row['chunk_type']}")
        print(f"Group: {row['group_id']}")
        print(f"Lang: {row['lang']}")
        print(f"Text: {row['text_raw'][:200]}...")

    _show_neighbors("SIMILAR_TO", graph_builder.similar_edges.get(chunk_id, []))
    _show_neighbors("CO_CITED_WITH", graph_builder.cocite_edges.get(chunk_id, []))

    # Group siblings
    group_id = graph_builder.chunk_to_group.get(chunk_id)
    if group_id:
        print(f"\nPART_OF group: {group_id}")
        members = graph_builder.group_to_chunks.get(group_id, [])
        # Drop the chunk itself BEFORE truncating, so that five real siblings
        # are shown whenever they exist and the total does not count the chunk.
        siblings = [member for member in members if member != chunk_id]
        print(f"Siblings ({len(siblings)} total):")
        for sib in siblings[:5]:
            print(f"  {sib}")
        if len(siblings) > 5:
            print(f"  ... and {len(siblings) - 5} more")

# Demo: inspect the first chunk of the corpus (skipped for an empty corpus)
if len(chunks_df):
    sample_id = chunks_df["chunk_id"].iat[0]
    inspect_chunk(sample_id, graph_builder, chunks_df)

## 11. Parameter Tuning

In [None]:
def evaluate_config(
    config: dict, 
    val_df: pd.DataFrame,
    embedder: Embedder,
    graph_builder: GraphBuilder,
    reranker: LLMReranker | None = None,
    chunks_df: pd.DataFrame | None = None,
    summaries_df: pd.DataFrame | None = None,
) -> dict:
    """Evaluate one pipeline configuration on the validation set.

    Runs `run_inference` (non-verbose) for every row of `val_df` and scores
    the predicted chunk ids against the row's ';'-separated `gold_citations`.

    Args:
        config: Pipeline settings handed to `run_inference`.
        val_df: Validation queries with a `query` column and, optionally, a
            `gold_citations` column.
        embedder: Passed through to `run_inference`.
        graph_builder: Passed through to `run_inference`.
        reranker: Passed through to `run_inference`.
        chunks_df: Passed through to `run_inference`.
        summaries_df: Passed through to `run_inference`.

    Returns:
        The result of `macro_f1(predictions, gold_list)`.
    """
    from omnilex.evaluation.metrics import macro_f1

    predictions = []
    gold_list = []

    for _, row in val_df.iterrows():
        query = row["query"]
        gold = row.get("gold_citations", "")

        results = run_inference(
            query=query,
            embedder=embedder,
            graph_builder=graph_builder,
            config=config,
            reranker=reranker,
            chunks_df=chunks_df,
            summaries_df=summaries_df,
            verbose=False,
        )

        pred_citations = [chunk_id for chunk_id, _, _ in results]

        # An empty CSV cell is read as NaN; str(NaN) would otherwise be parsed
        # as the bogus gold citation "nan" and distort the metrics.
        gold_text = "" if pd.isna(gold) else str(gold)
        gold_citations = [c.strip() for c in gold_text.split(";") if c.strip()]

        predictions.append(pred_citations)
        gold_list.append(gold_citations)

    return macro_f1(predictions, gold_list)


# =============================================================================
# Grid search over the scoring weights alpha, beta and gamma
# (delta keeps its CONFIG value; it is not part of the grid)
# =============================================================================

if val_df is not None and len(val_df) <= 20:  # Only for small val set
    print("üîß Grid search over scoring weights...")
    print("   (LLM Rerank OFF for speed)")
    print("=" * 60)

    # Start below any attainable score so that the first evaluated
    # configuration always becomes the incumbent. Starting at 0 left
    # best_params as None whenever every configuration scored F1 == 0, which
    # crashed the summary print below.
    best_f1 = float("-inf")
    best_params = None
    results_table = []

    # Temporarily disable LLM rerank for faster grid search
    test_reranker = LLMReranker(llm_client=None, config=RerankerConfig(use_llm=False))

    for alpha in [0.5, 1.0, 1.5]:
        for beta in [0.3, 0.6, 0.9]:
            for gamma in [0.4, 0.8, 1.2]:
                # Shallow copy is enough: only top-level scalar entries change
                test_config = CONFIG.copy()
                test_config["alpha"] = alpha
                test_config["beta"] = beta
                test_config["gamma"] = gamma
                test_config["use_llm_rerank"] = False

                metrics = evaluate_config(
                    test_config, val_df, embedder, graph_builder,
                    reranker=test_reranker,
                    chunks_df=chunks_df,
                    summaries_df=summaries_df,
                )
                f1 = metrics["macro_f1"]

                results_table.append({
                    "Œ±": alpha, "Œ≤": beta, "Œ≥": gamma, 
                    "F1": f1, "P": metrics["macro_precision"], "R": metrics["macro_recall"]
                })

                if f1 > best_f1:
                    best_f1 = f1
                    best_params = (alpha, beta, gamma)
                    print(f"  ‚ú® NEW BEST: Œ±={alpha}, Œ≤={beta}, Œ≥={gamma} ‚Üí F1={f1:.4f}")

    print("\n" + "=" * 60)
    if best_params is None:
        # Only reachable when no score compared greater than -inf (e.g. NaN)
        print("No configuration produced a comparable F1 score")
    else:
        print(f"üèÜ BEST: Œ±={best_params[0]}, Œ≤={best_params[1]}, Œ≥={best_params[2]}")
        print(f"   Macro F1 = {best_f1:.4f}")

    # Show top 5 configs
    results_sorted = sorted(results_table, key=lambda x: -x["F1"])[:5]
    print("\nTop 5 configurations:")
    for r in results_sorted:
        print(f"  Œ±={r['Œ±']}, Œ≤={r['Œ≤']}, Œ≥={r['Œ≥']} ‚Üí F1={r['F1']:.4f} (P={r['P']:.3f}, R={r['R']:.3f})")
else:
    print("‚è≠Ô∏è Skipping grid search (val set too large or missing)")

In [None]:
# =============================================================================
# Ablation Study: With vs Without LLM Rerank
# =============================================================================
# Evaluates the same CONFIG twice on the validation set, once with the rerank
# stage disabled and once enabled, and reports the difference in macro F1.

if val_df is not None and len(val_df) <= 20:
    print("üî¨ Ablation Study: LLM Rerank Impact")
    print("=" * 60)

    # Test WITHOUT LLM Rerank
    print("\n1Ô∏è‚É£ Without LLM Rerank (fast scoring only):")
    config_no_llm = CONFIG.copy()
    config_no_llm["use_llm_rerank"] = False

    no_llm_reranker = LLMReranker(llm_client=None, config=RerankerConfig(use_llm=False))
    metrics_no_llm = evaluate_config(
        config_no_llm, val_df, embedder, graph_builder,
        reranker=no_llm_reranker,
        chunks_df=chunks_df,
        summaries_df=summaries_df,
    )
    print(f"   P={metrics_no_llm['macro_precision']:.4f} R={metrics_no_llm['macro_recall']:.4f} F1={metrics_no_llm['macro_f1']:.4f}")

    # Test WITH LLM Rerank
    print("\n2Ô∏è‚É£ With LLM Rerank:")
    config_with_llm = CONFIG.copy()
    config_with_llm["use_llm_rerank"] = True

    metrics_with_llm = evaluate_config(
        config_with_llm, val_df, embedder, graph_builder,
        reranker=reranker,
        chunks_df=chunks_df,
        summaries_df=summaries_df,
    )
    print(f"   P={metrics_with_llm['macro_precision']:.4f} R={metrics_with_llm['macro_recall']:.4f} F1={metrics_with_llm['macro_f1']:.4f}")

    # Difference in macro F1 (with rerank minus without)
    delta_f1 = metrics_with_llm['macro_f1'] - metrics_no_llm['macro_f1']
    print("\n" + "=" * 60)
    print(f"üìä LLM Rerank Impact: ŒîF1 = {delta_f1:+.4f}")
    if delta_f1 > 0:
        print(f"   ‚úÖ LLM Rerank improves F1 by {delta_f1*100:.1f}%")
    elif delta_f1 < 0:
        print(f"   ‚ö†Ô∏è LLM Rerank decreases F1 by {abs(delta_f1)*100:.1f}%")
    else:
        # Identical scores (e.g. both runs at 0) are not a regression
        print("   LLM Rerank leaves F1 unchanged")
else:
    print("‚è≠Ô∏è Skipping ablation study (val set too large or missing)")