# Agentic Retrieval Baseline for Omnilex Legal Retrieval

This notebook implements an **agentic retrieval approach** using a ReAct-style agent with search tools.

## Approach
1. Load a local LLM (GGUF format via llama-cpp-python)
2. Build BM25 search indices for laws and court decisions
3. Create search tools the agent can use
4. For each query, run a ReAct agent that:
   - Reasons about what to search
   - Uses tools to search laws and court decisions
   - Extracts citations from search results
   - Provides final answer with all found citations

## Advantages over Direct Generation
- Grounded in actual legal documents
- Less hallucination of non-existent citations
- Can iterate on searches to find more relevant sources

## Requirements
- llama-cpp-python
- rank-bm25
- A GGUF model file (e.g., Mistral-7B-Instruct)

## 1. Setup & Configuration

In [None]:
import os
import sys
from pathlib import Path

# === CONFIGURATION ===
# Choose which dataset to run on: "val" or "test"
DATASET_MODE = "val"  # Change to "test" for final submission

# Set to True to rebuild indices from CSV (required on first run)
# Set to False to load cached indices (faster for subsequent runs)
FORCE_REBUILD_INDICES = False

# Detect environment
KAGGLE_ENV = "KAGGLE_KERNEL_RUN_TYPE" in os.environ

if KAGGLE_ENV:
    # Kaggle paths
    DATA_PATH = Path("/kaggle/input/omnilex-data")
    MODEL_PATH = Path("/kaggle/input/llama-model")
    OUTPUT_PATH = Path("/kaggle/working")
    INDEX_PATH = Path("/kaggle/input/omnilex-indices")
    sys.path.insert(0, "/kaggle/input/omnilex-utils")
else:
    # Local development paths
    REPO_ROOT = Path(".").resolve().parent
    DATA_PATH = REPO_ROOT / "data"
    MODEL_PATH = REPO_ROOT / "models"
    OUTPUT_PATH = REPO_ROOT / "output"
    INDEX_PATH = REPO_ROOT / "data" / "processed"
    sys.path.insert(0, str(REPO_ROOT / "src"))

# CSV corpus files for index building
LAWS_CSV = DATA_PATH / "laws_de.csv"
COURTS_CSV = DATA_PATH / "court_considerations.csv"

# Index cache paths
LAWS_INDEX_PATH = INDEX_PATH / "laws_index.pkl"
COURTS_INDEX_PATH = INDEX_PATH / "courts_index.pkl"

# Derived paths based on DATASET_MODE
QUERY_FILE = DATA_PATH / f"{DATASET_MODE}.csv"
IS_VALIDATION_MODE = DATASET_MODE == "val"

# Create output directory
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
INDEX_PATH.mkdir(parents=True, exist_ok=True)

print(f"Environment: {'Kaggle' if KAGGLE_ENV else 'Local'}")
print(f"Dataset mode: {DATASET_MODE}")
print(f"Query file: {QUERY_FILE}")
print(f"Validation mode: {IS_VALIDATION_MODE}")
print(f"Force rebuild indices: {FORCE_REBUILD_INDICES}")
print(f"\nCorpus files:")
print(f"  Laws CSV: {LAWS_CSV} ({LAWS_CSV.stat().st_size / 1e6:.1f} MB)" if LAWS_CSV.exists() else f"  Laws CSV: {LAWS_CSV} (NOT FOUND)")
print(f"  Courts CSV: {COURTS_CSV} ({COURTS_CSV.stat().st_size / 1e9:.2f} GB)" if COURTS_CSV.exists() else f"  Courts CSV: {COURTS_CSV} (NOT FOUND)")
print(f"\nIndex cache: {INDEX_PATH}")

In [2]:
# Configuration
CONFIG = {
    # Model settings
    "model_file": "mistral-7b-instruct-v0.2.Q4_K_M.gguf",
    "n_ctx": 8192,         # Larger context for agent conversations
    "n_threads": 4,
    "n_gpu_layers": -1,    # GPU layers (-1 = offload all layers to GPU)
    
    # Agent settings
    "max_iterations": 5,   # Max agent iterations per query
    "max_tokens": 512,
    "temperature": 0.1,
    
    # Retrieval settings
    "top_k_laws": 5,       # Results per law search
    "top_k_courts": 5,     # Results per court search
    
    # Paths
    "test_file": "test.csv",
    "train_file": "train.csv",
}

## 2. Load Corpora and Build/Load Indices

In [None]:
import pandas as pd
from tqdm.notebook import tqdm
from omnilex.retrieval.bm25_index import BM25Index


def load_csv_corpus(
    csv_path: Path,
    chunk_size: int = 100_000,
    max_rows: int | None = None
) -> list[dict]:
    """Load CSV corpus into list of dicts with progress bar.
    
    Args:
        csv_path: Path to CSV file with 'citation' and 'text' columns
        chunk_size: Rows to process per chunk (for memory efficiency)
        max_rows: Optional limit on rows (for testing with smaller corpus)
    
    Returns:
        List of {"citation": str, "text": str} dicts
    """
    documents = []
    
    # Count rows for progress bar (fast line count)
    print(f"Counting rows in {csv_path.name}...")
    with open(csv_path, encoding='utf-8') as f:
        total_rows = sum(1 for _ in f) - 1  # minus header
    
    if max_rows:
        total_rows = min(total_rows, max_rows)
    print(f"Total rows to load: {total_rows:,}")
    
    rows_loaded = 0
    with tqdm(total=total_rows, desc=f"Loading {csv_path.name}") as pbar:
        for chunk in pd.read_csv(csv_path, chunksize=chunk_size):
            for _, row in chunk.iterrows():
                if max_rows and rows_loaded >= max_rows:
                    break
                documents.append({
                    "citation": str(row["citation"]),
                    "text": str(row["text"]) if pd.notna(row["text"]) else ""
                })
                rows_loaded += 1
            pbar.update(min(len(chunk), total_rows - pbar.n))
            if max_rows and rows_loaded >= max_rows:
                break
    
    return documents


def get_or_build_index(
    name: str,
    csv_path: Path,
    index_path: Path,
    force_rebuild: bool = False,
    max_rows: int | None = None
) -> BM25Index:
    """Load cached index or build from CSV.
    
    Args:
        name: Index name for logging
        csv_path: Path to corpus CSV
        index_path: Path to cache index pickle
        force_rebuild: If True, rebuild even if cache exists
        max_rows: Optional row limit (for testing with smaller corpus)
    
    Returns:
        BM25Index instance
    """
    # Use cached index if available and not forcing rebuild
    if index_path.exists() and not force_rebuild:
        print(f"Loading cached {name} index from {index_path}")
        index = BM25Index.load(index_path)
        print(f"  Loaded {len(index.documents):,} documents")
        return index
    
    # Check CSV exists
    if not csv_path.exists():
        print(f"Warning: {csv_path} not found. Creating empty index.")
        return BM25Index(documents=[])
    
    # Load corpus from CSV
    print(f"\n{'='*50}")
    print(f"Building {name} index from {csv_path}")
    print(f"{'='*50}")
    documents = load_csv_corpus(csv_path, max_rows=max_rows)
    
    if not documents:
        print(f"Warning: No documents loaded. Creating empty index.")
        return BM25Index(documents=[])
    
    # Build BM25 index
    print(f"\nBuilding BM25 index for {len(documents):,} documents...")
    index = BM25Index(
        documents=documents,
        text_field="text",
        citation_field="citation"
    )
    print(f"Index built successfully!")
    
    # Cache index for future runs
    if not KAGGLE_ENV:
        print(f"Saving index to {index_path}...")
        index.save(index_path)
        print(f"Index cached.")
    
    return index

In [None]:
# Load or build laws index
# Laws CSV: ~45MB, ~269K rows
# Build time: ~30 seconds | Load from cache: <1 second

laws_index = get_or_build_index(
    name="laws",
    csv_path=LAWS_CSV,
    index_path=LAWS_INDEX_PATH,
    force_rebuild=FORCE_REBUILD_INDICES,
    # max_rows=10000  # Uncomment to test with smaller corpus
)
print(f"\nLaws index: {len(laws_index.documents):,} documents")

# Test search
test_results = laws_index.search("Vertrag", top_k=3)
print(f"\nTest search 'Vertrag': {len(test_results)} results")
if test_results:
    print(f"  Top result: {test_results[0].get('citation', 'N/A')}")

In [None]:
# Load or build courts index
# Courts CSV: ~2.3GB, ~2.5M rows
# Build time: ~15-20 minutes | Load from cache: ~10 seconds
# Peak memory during build: ~8-10GB

courts_index = get_or_build_index(
    name="courts",
    csv_path=COURTS_CSV,
    index_path=COURTS_INDEX_PATH,
    force_rebuild=FORCE_REBUILD_INDICES,
    # max_rows=100000  # Uncomment to test with smaller corpus
)
print(f"\nCourts index: {len(courts_index.documents):,} documents")

# Test search
test_results = courts_index.search("Meinungsfreiheit", top_k=3)
print(f"\nTest search 'Meinungsfreiheit': {len(test_results)} results")
if test_results:
    print(f"  Top result: {test_results[0].get('citation', 'N/A')}")

## 3. Define Search Tools

In [6]:
from omnilex.retrieval.tools import LawSearchTool, CourtSearchTool

# Create tools
law_tool = LawSearchTool(
    index=laws_index,
    top_k=CONFIG["top_k_laws"],
    max_excerpt_length=300,
)

court_tool = CourtSearchTool(
    index=courts_index,
    top_k=CONFIG["top_k_courts"],
    max_excerpt_length=300,
)

# Tool registry
TOOLS = {
    "search_laws": law_tool,
    "search_courts": court_tool,
}

print("Tools registered:")
for name, tool in TOOLS.items():
    print(f"  - {name}: {tool.description.split(chr(10))[0]}")

Tools registered:
  - search_laws: Search Swiss federal laws (SR/Systematische Rechtssammlung) by keywords.
  - search_courts: Search Swiss Federal Court decisions (BGE) by keywords.


In [7]:
# Test tools
print("Testing law search:")
print(law_tool("Vertrag Abschluss"))

print("\nTesting court search:")
print(court_tool("Meinungsfreiheit"))

Testing law search:
No relevant federal laws found for: 'Vertrag Abschluss'

Testing court search:
No relevant court decisions found for: 'Meinungsfreiheit'


## 4. Load Local LLM

In [8]:
from llama_cpp import Llama
from omnilex.llm import has_cuda_support, get_device_info

# Find model file
model_file = MODEL_PATH / CONFIG["model_file"]

if not model_file.exists():
    gguf_files = list(MODEL_PATH.glob("*.gguf")) + list(MODEL_PATH.rglob("*.gguf"))
    if gguf_files:
        model_file = gguf_files[0]
        print(f"Using model: {model_file}")
    else:
        raise FileNotFoundError(
            f"No model found. Please download a GGUF model to {MODEL_PATH}"
        )

print(f"Loading model: {model_file}")

# Auto-detect GPU: use GPU if available, else CPU
n_gpu_layers = CONFIG["n_gpu_layers"]
if n_gpu_layers == -1 and not has_cuda_support():
    n_gpu_layers = 0  # Fallback to CPU if no CUDA support

llm = Llama(
    model_path=str(model_file),
    n_ctx=CONFIG["n_ctx"],
    n_threads=CONFIG["n_threads"],
    n_gpu_layers=n_gpu_layers,
    verbose=False,
)

print("Model loaded successfully!")
print(f"Running on: {get_device_info(n_gpu_layers)}")

Loading model: /home/arijo/Omnilex-Agentic-Retrieval-Competition/models/mistral-7b-instruct-v0.2.Q4_K_M.gguf


llama_context: n_ctx_per_seq (8192) < n_ctx_train (32768) -- the full capacity of the model will not be utilized


Model loaded successfully!
Running on: GPU (all layers offloaded)


## 5. Define ReAct Agent

In [9]:
import re

AGENT_SYSTEM_PROMPT = """You are a Swiss legal research assistant with access to two search tools:

1. search_laws(query): Search Swiss federal laws (SR/Systematische Rechtssammlung) by keywords
   - Returns relevant law provisions with citations and text excerpts
   - Use for finding statutory law: codes, acts, ordinances

2. search_courts(query): Search Swiss Federal Court decisions (BGE/Bundesgerichtsentscheide) by keywords
   - Returns relevant case law with citations and excerpts
   - Use for finding judicial interpretations and precedents

Your task is to find ALL relevant Swiss legal citations for the given query.

Instructions:
- Search BOTH laws AND court decisions for comprehensive results
- Use multiple search queries if needed (different terms, German/English)
- Extract citations in standard format: SR XXX Art. Y or BGE XXX YY ZZZ
- Continue searching until you have found all relevant sources

Format your response as:
Thought: [Your reasoning about what to search next]
Action: [tool_name]
Action Input: [search query]

After receiving results, either continue searching or provide final answer:
Final Answer: [List of all found citations, one per line]

Remember: Always search both laws AND court decisions before giving your final answer."""


def parse_agent_action(response: str):
    """Parse action and input from agent response."""
    action_match = re.search(r"Action:\s*(\w+)", response, re.IGNORECASE)
    input_match = re.search(r"Action Input:\s*(.+?)(?:\n|$)", response, re.IGNORECASE)
    
    if action_match and input_match:
        return action_match.group(1).strip(), input_match.group(1).strip()
    return None, None


def extract_citations_from_text(text: str) -> list[str]:
    """Extract citations from any text (tool output or final answer)."""
    citations = []
    
    # SR pattern: SR followed by number (optionally with article)
    sr_matches = re.findall(
        r"SR\s*\d{3}(?:\.\d+)?(?:\s+Art\.?\s*\d+[a-z]?)?",
        text,
        re.IGNORECASE
    )
    citations.extend(sr_matches)
    
    # BGE pattern: BGE volume section page
    bge_matches = re.findall(
        r"BGE\s+\d{1,3}\s+[IVX]+[a-z]?\s+\d+(?:\s+E\.\s*\d+[a-z]?)?",
        text,
        re.IGNORECASE
    )
    citations.extend(bge_matches)
    
    # Art. pattern: Art. X LAW (e.g., Art. 1 ZGB, Art. 41 OR)
    art_matches = re.findall(
        r"Art\.?\s+\d+[a-z]?\s+(?:Abs\.?\s*\d+\s+)?[A-Z]{2,}",
        text,
        re.IGNORECASE
    )
    citations.extend(art_matches)
    
    return list(set(citations))


def run_agent(query: str, verbose: bool = False) -> list[str]:
    """Run ReAct agent to retrieve citations."""
    # Format with Mistral Instruct tags
    conversation = f"[INST] {AGENT_SYSTEM_PROMPT}\n\nQuery: {query}\n\nThought: [/INST]"
    all_citations = []
    
    for iteration in range(CONFIG["max_iterations"]):
        # Get LLM response
        response = llm(
            conversation,
            max_tokens=CONFIG["max_tokens"],
            temperature=CONFIG["temperature"],
            stop=["Observation:", "[INST]", "</s>"],
        )["choices"][0]["text"]
        
        # For subsequent turns, we need to handle the conversation format
        if iteration == 0:
            conversation = f"[INST] {AGENT_SYSTEM_PROMPT}\n\nQuery: {query} [/INST]\n\nThought:{response}"
        else:
            conversation += response
        
        if verbose:
            print(f"\n[Iteration {iteration + 1}]")
            print(response[:500])
        
        # Check for final answer
        if "Final Answer:" in response:
            final_text = response.split("Final Answer:")[-1].strip()
            citations = extract_citations_from_text(final_text)
            all_citations.extend(citations)
            break
        
        # Parse and execute action
        action, action_input = parse_agent_action(response)
        
        if action and action_input:
            action_lower = action.lower()
            
            if action_lower in TOOLS:
                observation = TOOLS[action_lower](action_input)
                
                # Extract citations from observation
                obs_citations = extract_citations_from_text(observation)
                all_citations.extend(obs_citations)
                
                conversation += f"\nObservation: {observation}\n\n[INST] Continue your analysis. [/INST]\n\nThought:"
                
                if verbose:
                    print(f"\n[Tool: {action}]")
                    print(observation[:300])
            else:
                conversation += f"\nObservation: Unknown tool '{action}'. Available: search_laws, search_courts\n\n[INST] Continue. [/INST]\n\nThought:"
        else:
            # No action found, try to extract from response anyway
            citations = extract_citations_from_text(response)
            all_citations.extend(citations)
            break
    
    # Deduplicate
    return list(set(all_citations))

In [10]:
# Test agent with a sample query
test_query = "What are the requirements for a valid contract under Swiss law?"
print(f"Query: {test_query}")
print("\nRunning agent...\n")

citations = run_agent(test_query, verbose=True)

print("\n" + "="*50)
print("Found citations:")
for c in citations:
    print(f"  - {c}")

Query: What are the requirements for a valid contract under Swiss law?

Running agent...


[Iteration 1]
 To find the requirements for a valid contract under Swiss law, I will first search for relevant provisions in the Swiss federal laws using the term "Vertrag" which is the German word for contract.

Action: search_laws
Action Input: Vertrag

[Waiting for results]

Thought: The search results from the laws might not be comprehensive, so I will also look for relevant court decisions interpreting the requirements for a valid contract in Swiss law.

Action: search_courts
Action Input: Valid contract

Found citations:
  - Art. 1
BGE
  - SR 111.2 Art. 1
  - BGE 135 I 123
  - BGE 123 IV 567
  - Art. 31
SR
  - SR 111.1 Art. 31


## 6. Load Test Data

In [None]:
import pandas as pd

# Load queries from the configured query file
if not QUERY_FILE.exists():
    raise FileNotFoundError(f"Query file not found: {QUERY_FILE}")

test_df = pd.read_csv(QUERY_FILE)

print(f"Loaded {len(test_df)} queries from {QUERY_FILE}")
print(f"Columns: {list(test_df.columns)}")

if IS_VALIDATION_MODE and "gold_citations" in test_df.columns:
    print(f"Gold citations available for evaluation")

test_df.head()

## 7. Generate Predictions

In [12]:
from tqdm import tqdm
from omnilex.citations.normalizer import CitationNormalizer

# Initialize normalizer
normalizer = CitationNormalizer()

# Generate predictions
predictions = []

for _, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Running agent"):
    query_id = row["query_id"]
    query_text = row["query"]
    
    # Run agent
    raw_citations = run_agent(query_text, verbose=False)
    
    # Normalize citations
    normalized = normalizer.canonicalize_list(raw_citations)
    
    predictions.append({
        "query_id": query_id,
        "predicted_citations": ";".join(normalized),
    })

print(f"\nGenerated predictions for {len(predictions)} queries")

Running agent:   0%|          | 0/2 [00:00<?, ?it/s]

Running agent: 100%|██████████| 2/2 [00:27<00:00, 13.80s/it]


Generated predictions for 2 queries





In [13]:
# Preview predictions
predictions_df = pd.DataFrame(predictions)
predictions_df.head(10)

Unnamed: 0,query_id,predicted_citations
0,test_001,BGE 143 II 168;BGE 161 II 121;BGE 154 II 135;B...
1,test_002,BGE 113 I 112;BGE 111 II 158;BGE 125 II 167;BG...


## 8. Create Submission

In [14]:
# Save submission
submission_path = OUTPUT_PATH / "submission.csv"
predictions_df.to_csv(submission_path, index=False)

print(f"Submission saved to: {submission_path}")
print(f"Total predictions: {len(predictions_df)}")

# Show sample
print("\nSample submission:")
print(predictions_df.head())

Submission saved to: /home/arijo/Omnilex-Agentic-Retrieval-Competition/output/submission.csv
Total predictions: 2

Sample submission:
   query_id                                predicted_citations
0  test_001  BGE 143 II 168;BGE 161 II 121;BGE 154 II 135;B...
1  test_002  BGE 113 I 112;BGE 111 II 158;BGE 125 II 167;BG...


## 9. Local Evaluation (Optional)

In [None]:
# Evaluate if in validation mode with gold labels
if IS_VALIDATION_MODE and "gold_citations" in test_df.columns:
    from omnilex.evaluation import evaluate_submission
    
    # Join predictions with gold citations from the same file
    eval_df = predictions_df.merge(
        test_df[["query_id", "gold_citations"]],
        on="query_id",
        how="inner"
    )
    
    if len(eval_df) > 0:
        scores = evaluate_submission(
            eval_df[["query_id", "predicted_citations"]],
            eval_df[["query_id", "gold_citations"]],
        )
        
        print("\n" + "="*50)
        print("EVALUATION RESULTS")
        print("="*50)
        print(f"Queries evaluated: {len(eval_df)}")
        print(f"\nMacro F1 (PRIMARY): {scores['macro_f1']:.4f}")
        print(f"Macro Precision:    {scores['macro_precision']:.4f}")
        print(f"Macro Recall:       {scores['macro_recall']:.4f}")
        print(f"\nMAP:                {scores['map']:.4f}")
    else:
        print("No overlapping queries for evaluation.")
else:
    print("Skipping evaluation (not in validation mode or no gold labels available)")

## Summary

This agentic retrieval baseline demonstrates a more sophisticated approach:

1. **Tool-augmented generation**: The LLM can search actual legal corpora rather than relying solely on parametric knowledge.

2. **ReAct-style reasoning**: The agent reasons about what to search, executes searches, observes results, and iterates.

3. **Grounded citations**: Citations are extracted from actual search results, reducing hallucination.

4. **Comprehensive search**: The agent searches both laws and court decisions for complete results.

## Potential Improvements

- **Better search**: Use semantic search (embeddings) instead of BM25
- **Query expansion**: Generate multiple search queries in different languages
- **Relevance filtering**: Add a step to verify citations are actually relevant
- **Citation validation**: Check that generated citations exist in the corpus
- **Multi-hop reasoning**: Follow citation chains to find related sources