# ACM ICAIF 2025: Fusion Split-Ensemble Model

Solution to the *Agentic Retrieval Grand Challenge, ACM-ICAIF '25*.

**Databricks Reproducibility Notebook**

### Model Architecture
Three-stage split ensemble for chunk ranking:
- Stage 1 (120B): Local ranking with BM25 fusion, normalized within-split <- Recall
- Stage 2 (405B): Split ~50 candidates into 2x25, pure LLM ranking <- Recall
- Stage 3 (405B): Final rescore on ~20 candidates, pure LLM -> Precision
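
A minimal sketch of the Stage 1 fusion idea, assuming min-max normalization within each split and a linear blend of LLM and BM25 scores. The helper names `min_max_normalize` and `fuse_scores` are hypothetical, and the normalization method is an assumption; only the 70/30 weighting comes from the configuration.

```python
from typing import Dict


def min_max_normalize(scores: Dict[int, float]) -> Dict[int, float]:
    """Rescale scores to [0, 1] within one split; constant inputs map to 0.0."""
    if not scores:
        return {}
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        return {idx: 0.0 for idx in scores}
    return {idx: (s - lo) / (hi - lo) for idx, s in scores.items()}


def fuse_scores(
    llm_scores: Dict[int, float],
    bm25_scores: Dict[int, float],
    llm_weight: float = 0.7,
) -> Dict[int, float]:
    """Blend normalized LLM and BM25 scores; a missing score counts as 0.0."""
    llm_norm = min_max_normalize(llm_scores)
    bm25_norm = min_max_normalize(bm25_scores)
    all_idx = set(llm_norm) | set(bm25_norm)
    return {
        idx: llm_weight * llm_norm.get(idx, 0.0)
        + (1.0 - llm_weight) * bm25_norm.get(idx, 0.0)
        for idx in all_idx
    }


# Made-up scores for three chunks in one split
fused = fuse_scores({3: 4, 7: 2, 9: 0}, {3: 1.0, 7: 5.0, 9: 3.0})
ranking = sorted(fused, key=fused.get, reverse=True)
print(ranking)
```

Normalizing inside each split keeps one split with unusually high raw BM25 scores from dominating the merged pool.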

Architecture for both Document and Chunk ranking:
- Smart retry with ensemble, any task/stage -> emulate multiple judges
    - Adaptive retries: based on response quality, fuse multiple partial answers
    - Forced retries: redo same query again to get more opinions, then fusion
- 4-level semaphore: QUERY → STAGE1_PART → STAGE2_PART → STAGE3
- 4-level parsing: text blocks → reasoning → regex → GPT-4o-mini rescue
- Max-5 chunk splitting with pre-processing at data load
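
The multi-level semaphore can be pictured as nested `asyncio.Semaphore` objects: one global limiter for queries and a per-query limiter for parts. The sketch below shows two of the four levels under that assumption; the limits, `run_demo`, and the sleep that stands in for an LLM call are all hypothetical.

```python
import asyncio

QUERY_LIMIT = 2  # hypothetical demo value: max concurrent queries
PART_LIMIT = 2   # hypothetical demo value: max concurrent parts per query


async def run_demo(n_queries: int = 4, parts_per_query: int = 3):
    query_sem = asyncio.Semaphore(QUERY_LIMIT)
    active = {"queries": 0, "peak_queries": 0}

    async def run_part(part_sem, query_id, part_id):
        async with part_sem:
            await asyncio.sleep(0.01)  # stands in for one LLM call
            return (query_id, part_id)

    async def run_query(query_id):
        async with query_sem:
            active["queries"] += 1
            active["peak_queries"] = max(active["peak_queries"], active["queries"])
            part_sem = asyncio.Semaphore(PART_LIMIT)  # one limiter per query
            parts = await asyncio.gather(
                *(run_part(part_sem, query_id, p) for p in range(parts_per_query))
            )
            active["queries"] -= 1
            return parts

    all_results = await asyncio.gather(*(run_query(q) for q in range(n_queries)))
    return all_results, active["peak_queries"]


# In a notebook with a running event loop, use `await run_demo()` instead.
results, peak = asyncio.run(run_demo())
print(f"queries finished: {len(results)}, peak concurrent queries: {peak}")
```

Creating the part-level semaphore inside each query gives every query its own budget, so one large query cannot starve the others.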

### Core ideas
- Manage attention allocation, avoid long context, avoid crowded chunk-pool
- Stage-specific custom strategy: prompting + model choice + ensemble

**Model design**: Read more [docs/model_fusion_se.md](https://github.com/yourusername/AgentRAG_Public/blob/main/docs/model_fusion_se.md)

## Setup Instructions

1. *Install required packages*
2. *Set environment variables*
3. *Download competition data*
4. *Update paths*: in Configuration cell (DATA_DIR, OUTPUT_DIR) if needed
5. *Run all cells*

In [None]:
# Install required packages
# %pip install bm25s pandas jsonlines tqdm httpx tiktoken kaggle python-dotenv

# Restart Python kernel
# dbutils.library.restartPython()

In [None]:
# In Databricks: store your keys as secrets

# from databricks.sdk import WorkspaceClient
# w = WorkspaceClient()
# w.secrets.put_secret(
#     "icaif",
#     "databricks_token",
#     string_value="your_token_here"
# )

In [None]:
# Set environment variables
import os

# Detect environment: Databricks or local
try:
    # Try to access dbutils - only available in Databricks
    dbutils.secrets.get("icaif", "databricks_token")
    is_databricks = True
    print("Running in Databricks environment")
except Exception:  # NameError when dbutils is undefined (local run)
    is_databricks = False
    print("Running in local environment")

if is_databricks:
    # Use Databricks secrets (recommended for production)
    os.environ["DATABRICKS_TOKEN"] = dbutils.secrets.get("icaif", "databricks_token")
    os.environ["OPENAI_API_KEY"] = dbutils.secrets.get("icaif", "openai_api_key")
    
    # Kaggle credentials (for data download)
    os.environ["KAGGLE_USERNAME"] = dbutils.secrets.get("icaif", "kaggle_username")
    os.environ["KAGGLE_KEY"] = dbutils.secrets.get("icaif", "kaggle_key")
else:
    # Load from .env file for local execution
    from dotenv import load_dotenv
    load_dotenv()
    print("Loaded environment variables from .env file")

# Verify environment variables are set
print("\nEnvironment variables configured:")
print(f"  DATABRICKS_TOKEN: {len(os.environ.get('DATABRICKS_TOKEN', ''))} characters")
print(f"  OPENAI_API_KEY: {len(os.environ.get('OPENAI_API_KEY', ''))} characters")
print(f"  KAGGLE_USERNAME: {os.environ.get('KAGGLE_USERNAME', 'NOT SET')}")
print(f"  KAGGLE_KEY: {len(os.environ.get('KAGGLE_KEY', ''))} characters")

Running in local environment
Loaded environment variables from .env file

Environment variables configured:
  DATABRICKS_TOKEN: 36 characters
  OPENAI_API_KEY: 164 characters
  KAGGLE_USERNAME: pandalikematcha
  KAGGLE_KEY: 32 characters


# Configuration

- Modify paths (DATA_DIR, OUTPUT_DIR) in Configuration cell below
- Set SAMPLE_SIZE for testing (None for full production run of 400 queries)

In [3]:
import asyncio
import csv
import json
import math
import os
import random
import re
import time
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import bm25s
import httpx
import jsonlines
import pandas as pd
import tiktoken
from tqdm.asyncio import tqdm_asyncio

# Suppress warnings
warnings.filterwarnings('ignore')

print("Imports complete")

Imports complete


In [4]:
# =============================================================================
# CONFIGURATION - Modify these settings for your Databricks environment
# =============================================================================

# Model Configuration (3-stage vertical ensemble)
MODEL_STAGE1 = "databricks-gpt-oss-120b"                       # Stage 1: Local ranking
MODEL_STAGE2 = "databricks-meta-llama-3-1-405b-instruct"       # Stage 2: Split rescore
MODEL_STAGE3 = "databricks-meta-llama-3-1-405b-instruct"       # Stage 3: Final rescore

# Databricks Serving Endpoint
DATABRICKS_SERVING_ENDPOINT = os.environ.get(
    "DATABRICKS_SERVING_ENDPOINT",
    "https://dbc-e650c56f-0e0e.cloud.databricks.com/serving-endpoints")

# 4-Level Concurrency Control
QUERY_SEMAPHORE = 5          # Level 1: Max concurrent queries
STAGE1_PART_SEMAPHORE = 2    # Level 2: Max concurrent Stage 1 parts per query
STAGE2_PART_SEMAPHORE = 2    # Level 3: Max concurrent Stage 2 parts per query
STAGE3_SEMAPHORE = 2         # Level 4: Max concurrent Stage 3 calls

# Fusion and Splitting Parameters
FUSION_WEIGHT_STAGE1 = 0.7   # Stage 1: 70% LLM semantic, 30% BM25 lexical
STAGE2_SPLIT_COUNT = 2       # Split 50 candidates into 2 parts (25 each)
STAGE2_K_PER_PART = 10       # Extract top-10 from each part (~20 for Stage 3)
TARGET_TOKENS_PER_PART = 15000  # Target for chunk-based splitting
FIXED_LOCAL_K = 10           # Fixed K for local ranking

# Timing Parameters (rate limit management)
DOC_STAGGER_INTERVAL = 2.0   # seconds between document query starts
CHUNK_STAGGER_INTERVAL = 3.0 # seconds between chunk query starts
STAGE1_JITTER_MAX = 25.0     # Stage 1 jitter (per-part)
STAGE2_JITTER_MAX = 15.0     # Stage 2 jitter (per-part)
STAGE3_JITTER_MAX = 15.0     # Stage 3 jitter (single call)

# =============================================================================
# PATHS - Configure for your Databricks environment
# =============================================================================

# Default paths for local testing
DATA_DIR = "./data/raw"
OUTPUT_DIR = "./output"

# For Databricks, uncomment and modify these paths:
# DATA_DIR = "/dbfs/mnt/data"
# OUTPUT_DIR = "/dbfs/mnt/output"

# =============================================================================
# SAMPLE SIZE - Set to None for full production run (400 queries)
# =============================================================================

SAMPLE_SIZE = 2  # Small test sample (set to None for production)

# Create output directory if it doesn't exist
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

print("Configuration loaded successfully")
print(f"  Models: {MODEL_STAGE1} / {MODEL_STAGE2} / {MODEL_STAGE3}")
print(f"  Data directory: {DATA_DIR}")
print(f"  Output directory: {OUTPUT_DIR}")
print(f"  Sample size: {SAMPLE_SIZE if SAMPLE_SIZE else 'Full production (400 queries)'}")

Configuration loaded successfully
  Models: databricks-gpt-oss-120b / databricks-meta-llama-3-1-405b-instruct / databricks-meta-llama-3-1-405b-instruct
  Data directory: ./data/raw
  Output directory: ./output
  Sample size: 2
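
The timing constants suggest that query starts are staggered and that each stage adds random jitter before calling the API. A minimal sketch of one way to combine them, assuming a stagger that grows linearly with the query position and uniform jitter; `start_delay` is a hypothetical helper, not a function from this notebook.

```python
import random
from typing import Optional


def start_delay(
    query_position: int,
    stagger_interval: float,
    jitter_max: float,
    rng: Optional[random.Random] = None,
) -> float:
    """Seconds to wait before starting: linear stagger plus uniform jitter."""
    source = rng if rng is not None else random
    return query_position * stagger_interval + source.uniform(0.0, jitter_max)


# Values taken from CHUNK_STAGGER_INTERVAL and STAGE1_JITTER_MAX above
rng = random.Random(0)  # seeded only to make the demo repeatable
delays = [start_delay(i, 3.0, 25.0, rng) for i in range(4)]
print([round(d, 1) for d in delays])
```

Spreading starts this way avoids sending a burst of requests at the same instant, which is the usual trigger for rate-limit errors.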


# Download competition data

In [5]:
# === Kaggle Data Download ===
import os, zipfile
from pathlib import Path
from kaggle.api.kaggle_api_extended import KaggleApi

def download_data(data_dir="./data/raw"):
    """Download competition eval datasets (200 doc + 200 chunk queries)"""
    if not os.getenv('KAGGLE_USERNAME') or not os.getenv('KAGGLE_KEY'):
        print("ERROR: Set KAGGLE_USERNAME and KAGGLE_KEY environment variables")
        return False
    
    api = KaggleApi()
    api.authenticate()
    
    comp = "acm-icaif-25-ai-agentic-retrieval-grand-challenge"
    data_path = Path(data_dir)
    data_path.mkdir(parents=True, exist_ok=True)
    
    print(f"Downloading {comp}...")
    api.competition_download_files(comp, path=data_path)
    
    # Extract
    zip_file = data_path / f"{comp}.zip"
    if zip_file.exists():
        with zipfile.ZipFile(zip_file, 'r') as z:
            z.extractall(data_path)
        zip_file.unlink()
    
    # Show files
    files = sorted(data_path.glob("*.jsonl"))
    print(f"Downloaded {len(files)} files: {[f.name for f in files]}")
    return True

# Uncomment to run:
# download_data(DATA_DIR)

# Utilities

In [6]:
class APITracker:
    """
    Track API statistics during model execution.

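    Safe for concurrent asyncio tasks on a single event loop: the update
    methods contain no await points, so increments cannot interleave.
    Not designed for use across OS threads.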
    """

    def __init__(self):
        """Initialize tracker with zero counters."""
        self.stats = {
            'total_calls': 0,
            'total_time': 0.0,
            'document_calls': 0,
            'document_time': 0.0,
            'document_errors': 0,
            'document_retries': 0,
            'chunk_calls': 0,
            'chunk_time': 0.0,
            'chunk_errors': 0,
            'chunk_retries': 0,
            'rate_limit_retries': 0,
            'parsing_retries': 0,
        }

    def track_call(self, call_type: str, elapsed_time: float):
        """Track successful API call."""
        self.stats['total_calls'] += 1
        self.stats['total_time'] += elapsed_time

        if call_type == 'document':
            self.stats['document_calls'] += 1
            self.stats['document_time'] += elapsed_time
        elif call_type == 'chunk':
            self.stats['chunk_calls'] += 1
            self.stats['chunk_time'] += elapsed_time

    def track_error(self, call_type: Optional[str] = None):
        """Track API error."""
        if call_type == 'document':
            self.stats['document_errors'] += 1
        elif call_type == 'chunk':
            self.stats['chunk_errors'] += 1

    def track_rate_limit(self):
        """Track rate limit retry attempt."""
        self.stats['rate_limit_retries'] += 1

    def track_retry(self, call_type: Optional[str] = None):
        """Track retry attempt."""
        if call_type == 'document':
            self.stats['document_retries'] += 1
        elif call_type == 'chunk':
            self.stats['chunk_retries'] += 1

    def track_parsing_retry(self):
        """Track parsing error retry attempt (chunk ranking only)."""
        self.stats['parsing_retries'] += 1

    def get_stats(self) -> Dict:
        """Get current statistics."""
        return self.stats.copy()

    def print_summary(self):
        """Print summary statistics."""
        stats = self.get_stats()

        # Compute averages
        avg_time = stats['total_time'] / stats['total_calls'] if stats['total_calls'] > 0 else 0.0
        avg_doc_time = stats['document_time'] / stats['document_calls'] if stats['document_calls'] > 0 else 0.0
        avg_chunk_time = stats['chunk_time'] / stats['chunk_calls'] if stats['chunk_calls'] > 0 else 0.0

        print("\n" + "="*80)
        print("API STATISTICS SUMMARY")
        print("="*80)
        print(f"Total API calls: {stats['total_calls']}")
        print(f"Total time: {stats['total_time']:.1f}s (avg: {avg_time:.2f}s/call)")
        print(f"\nDocument ranking:")
        print(f"  Calls: {stats['document_calls']}")
        print(f"  Time: {stats['document_time']:.1f}s (avg: {avg_doc_time:.2f}s/call)")
        print(f"  Errors: {stats['document_errors']}")
        print(f"  Retries: {stats['document_retries']}")
        print(f"\nChunk ranking:")
        print(f"  Calls: {stats['chunk_calls']}")
        print(f"  Time: {stats['chunk_time']:.1f}s (avg: {avg_chunk_time:.2f}s/call)")
        print(f"  Errors: {stats['chunk_errors']}")
        print(f"  Retries: {stats['chunk_retries']}")
        print(f"\nRetries breakdown:")
        print(f"  Rate limit retries: {stats['rate_limit_retries']}")
        print(f"  Parsing retries: {stats['parsing_retries']}")
        print("="*80)

In [7]:
def split_chunks_n_way(chunks: List[str], indices: List[int], n: int) -> List[Tuple[List[str], List[int]]]:
    """
    Split chunks into N equal parts by count for multi-stage processing.

    Args:
        chunks: List of chunk texts
        indices: List of chunk indices (same length as chunks)
        n: Number of parts to split into

    Returns:
        List of (part_chunks, part_indices) tuples

    Edge case: If len(chunks) < n, return single part
    """
    # Edge case: fewer chunks than splits
    if len(chunks) < n:
        return [(chunks, indices)]

    # Compute base chunk size per part
    chunk_size = len(chunks) // n
    parts = []

    for i in range(n):
        start = i * chunk_size

        # Last part gets remainder
        if i == n - 1:
            end = len(chunks)
        else:
            end = start + chunk_size

        # Only add non-empty parts
        if start < len(chunks):
            part_chunks = chunks[start:end]
            part_indices = indices[start:end]
            parts.append((part_chunks, part_indices))

    return parts
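
With `STAGE2_SPLIT_COUNT = 2` and `STAGE2_K_PER_PART = 10`, roughly 50 candidates become two parts of 25, and the top 10 from each part form the pool of about 20 for Stage 3. The sketch below works through that arithmetic with made-up scores; `split_evenly` and `top_k` are hypothetical helpers that mimic the splitting rule above.

```python
from typing import Dict, List


def split_evenly(items: List[int], n: int) -> List[List[int]]:
    """Split into n contiguous parts; the last part takes the remainder."""
    if len(items) < n:
        return [items]
    size = len(items) // n
    parts = []
    for i in range(n):
        end = len(items) if i == n - 1 else (i + 1) * size
        parts.append(items[i * size:end])
    return parts


def top_k(part: List[int], scores: Dict[int, int], k: int) -> List[int]:
    """Return the k highest-scoring indices of one part."""
    return sorted(part, key=lambda idx: scores[idx], reverse=True)[:k]


candidates = list(range(50))
scores = {idx: (idx * 7) % 5 for idx in candidates}  # made-up 0-4 scores
parts = split_evenly(candidates, 2)
stage3_pool = [idx for part in parts for idx in top_k(part, scores, 10)]
print(len(parts), [len(p) for p in parts], len(stage3_pool))
```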

In [8]:
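def truncate_chunk(chunk: str, max_chars: int = 10000) -> str:
    """
    Truncate an over-long chunk, keeping its head and tail.

    Keeps chunk[:4000], chunk[4000:6000], and chunk[-2000:], joined by
    truncation markers. Note that the "middle" slice directly follows the
    head; it is not taken from the center of the chunk.

    Length after truncation: 8,038 chars (4,000 + 2,000 + 2,000 + 2 x 19 marker chars).
    """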
    if len(chunk) <= max_chars:
        return chunk

    marker = " [...TRUNCATED...] "
    first_part = chunk[:4000]
    middle_part = chunk[4000:6000]
    last_part = chunk[-2000:]

    return first_part + marker + middle_part + marker + last_part


def compute_bm25_scores_for_query(
    question: str,
    chunks: List[str],
    chunk_indices: List[int]
) -> Dict[int, float]:
    """
    Compute BM25 scores for chunks against question using bm25s library.

    Returns:
        Dict mapping chunk_idx -> BM25 score
        Empty dict on failure (graceful fallback)
    """
    try:
        # Tokenize with English stopwords
        corpus_tokens = bm25s.tokenize(chunks, stopwords="en")
        query_tokens = bm25s.tokenize([question], stopwords="en")

        # Index and retrieve scores
        retriever = bm25s.BM25()
        retriever.index(corpus_tokens)
        results, scores = retriever.retrieve(query_tokens, k=len(chunks))

        # Map to dict
        score_dict = {}
        for retrieved_idx, score in zip(results[0], scores[0]):
            original_idx = chunk_indices[retrieved_idx]
            score_dict[original_idx] = float(score)

        # Add missing indices with 0.0 score
        for idx in chunk_indices:
            if idx not in score_dict:
                score_dict[idx] = 0.0

        return score_dict

    except Exception as e:
        print(f"[WARNING] BM25 computation failed: {e}")
        return {}


def _extract_question(content: str) -> str:
    """Extract question text from message content."""
    match = re.search(r'Question:\s*(.+?)(?:\n|$)', content, re.DOTALL)
    if not match:
        raise ValueError("Question not found in message content")

    question = match.group(1).strip()

    # Remove everything after double newline
    next_section = question.find('\n\n')
    if next_section != -1:
        question = question[:next_section].strip()

    return question


def _parse_chunks_from_message(content: str) -> Tuple[List[str], List[int]]:
    """Parse chunks and indices from message content."""
    # Remove trailing instruction text
    task_pattern = r'\n+Task:\s+Select and rank.*?Response Format:.*?$'
    content_cleaned = re.sub(task_pattern, '', content, flags=re.DOTALL | re.IGNORECASE)

    # Parse chunks
    pattern = r'\[Chunk Index (\d+)\]\s*(.+?)(?=\n\[Chunk Index |\n*$)'
    matches = re.findall(pattern, content_cleaned, re.DOTALL)

    if not matches:
        raise ValueError("No chunks found in message content")

    chunks = []
    chunk_indices = []

    for idx_str, chunk_text in matches:
        chunk_text = chunk_text.strip()
        if chunk_text:
            chunks.append(chunk_text)
            chunk_indices.append(int(idx_str))

    if not chunks:
        raise ValueError("All parsed chunks were empty")

    return chunks, chunk_indices


def _reconstruct_message_with_truncated_chunks(
    original_content: str,
    chunks: List[str],
    chunk_indices: List[int],
    truncated_chunks: List[str]
) -> str:
    """Reconstruct message content with truncated chunks."""
    # Extract question section
    pattern = r'^(.*?)\[Chunk Index \d+\]'
    match = re.search(pattern, original_content, re.DOTALL)

    if not match:
        return original_content

    question_section = match.group(1)

    # Reconstruct chunks section
    chunks_section = ""
    for idx, truncated_chunk in zip(chunk_indices, truncated_chunks):
        chunks_section += f"[Chunk Index {idx}] {truncated_chunk}\n"

    return question_section + chunks_section.rstrip()


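def _compute_splits(df: pd.DataFrame, target_tokens_per_part: int = 15000) -> pd.DataFrame:
    """
    Add n_splits column with chunk-based Max-5 strategy.

    Note: target_tokens_per_part is currently unused; the number of splits
    depends only on num_chunks.
    """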
    def compute_n_splits(num_chunks: int) -> int:
        """Compute n_splits using chunk-based Max-5 strategy."""
        chunks_per_part = max(30, num_chunks // 5)
        n_splits = min(5, math.ceil(num_chunks / chunks_per_part))
        return n_splits

    df['n_splits'] = df['num_chunks'].apply(compute_n_splits)
    return df


def load_document_data(filepath: Path) -> pd.DataFrame:
    """
    Load JSONL for document ranking.

    Returns:
        DataFrame with columns: query_id, question
    """
    data_records = []
    errors = []
    seen_query_ids = set()

    with jsonlines.open(filepath) as reader:
        for line_num, item in enumerate(reader, 1):
            try:
                # Extract query_id
                if 'uuid' in item:
                    query_id = item['uuid']
                elif '_id' in item:
                    query_id = item['_id']
                elif 'record_id' in item:
                    query_id = str(item['record_id'])
                else:
                    errors.append(f"Line {line_num}: No query ID field found")
                    continue

                # Check for duplicates
                if query_id in seen_query_ids:
                    errors.append(f"Line {line_num}: Duplicate query_id '{query_id}'")
                    continue
                seen_query_ids.add(query_id)

                # Extract question
                messages = item.get('messages', [])
                if not messages:
                    errors.append(f"Line {line_num}: No messages field")
                    continue

                content = messages[0].get('content', '')
                if not content:
                    errors.append(f"Line {line_num}: Empty message content")
                    continue

                try:
                    question = _extract_question(content)
                except ValueError as e:
                    errors.append(f"Line {line_num}: {str(e)}")
                    continue

                if not question:
                    errors.append(f"Line {line_num}: Empty question after parsing")
                    continue

                data_records.append({
                    'query_id': query_id,
                    'question': question
                })

            except Exception as e:
                errors.append(f"Line {line_num}: Unexpected error - {str(e)}")

    if errors:
        error_report = "\n".join(errors[:10])
        if len(errors) > 10:
            error_report += f"\n... and {len(errors) - 10} more errors"
        raise ValueError(f"Document data validation failed:\n{error_report}")

    df = pd.DataFrame(data_records)
    return df


def load_chunk_data(filepath: Path, target_tokens_per_part: int = 15000) -> pd.DataFrame:
    """
    Load JSONL for chunk ranking with pre-computed splits.

    Returns:
        DataFrame with columns: query_id, question, chunks, chunk_indices,
                                num_chunks, num_tokens, bm25_scores, n_splits
    """
    data_records = []
    errors = []
    seen_query_ids = set()

    # Initialize tokenizer
    encoding = tiktoken.get_encoding("cl100k_base")

    with jsonlines.open(filepath) as reader:
        for line_num, item in enumerate(reader, 1):
            try:
                # Extract query_id
                if 'uuid' in item:
                    query_id = item['uuid']
                elif '_id' in item:
                    query_id = item['_id']
                elif 'record_id' in item:
                    query_id = str(item['record_id'])
                else:
                    errors.append(f"Line {line_num}: No query ID field found")
                    continue

                # Check for duplicates
                if query_id in seen_query_ids:
                    errors.append(f"Line {line_num}: Duplicate query_id '{query_id}'")
                    continue
                seen_query_ids.add(query_id)

                # Extract messages
                messages = item.get('messages', [])
                if not messages:
                    errors.append(f"Line {line_num}: No messages field")
                    continue

                content = messages[0].get('content', '')
                if not content:
                    errors.append(f"Line {line_num}: Empty message content")
                    continue

                # Parse question
                try:
                    question = _extract_question(content)
                except ValueError as e:
                    errors.append(f"Line {line_num}: {str(e)}")
                    continue

                if not question:
                    errors.append(f"Line {line_num}: Empty question after parsing")
                    continue

                # Parse chunks and indices
                try:
                    chunks, chunk_indices = _parse_chunks_from_message(content)
                except ValueError as e:
                    errors.append(f"Line {line_num}: {str(e)}")
                    continue

                # Compute BM25 scores on ORIGINAL chunks (before truncation)
                bm25_scores = compute_bm25_scores_for_query(question, chunks, chunk_indices)

                # Truncate chunks
                truncated_chunks = [truncate_chunk(chunk) for chunk in chunks]

                # Reconstruct message with truncated chunks
                content_truncated = _reconstruct_message_with_truncated_chunks(
                    content, chunks, chunk_indices, truncated_chunks
                )

                # Validation
                record_errors = []

                if len(truncated_chunks) != len(chunk_indices):
                    record_errors.append("chunks/indices length mismatch")

                if any(idx < 0 for idx in chunk_indices):
                    record_errors.append("negative chunk indices found")

                if any(not chunk.strip() for chunk in truncated_chunks):
                    record_errors.append("empty chunks found")

                # Compute token count from TRUNCATED content
                try:
                    num_tokens = len(encoding.encode(content_truncated))
                except Exception:
                    num_tokens = 0

                if not (0 < num_tokens < 1_000_000):
                    record_errors.append(f"token count out of range: {num_tokens}")

                if record_errors:
                    errors.append(f"Line {line_num}: {', '.join(record_errors)}")
                    continue

                data_records.append({
                    'query_id': query_id,
                    'question': question,
                    'chunks': truncated_chunks,
                    'chunk_indices': chunk_indices,
                    'num_chunks': len(truncated_chunks),
                    'num_tokens': num_tokens,
                    'bm25_scores': bm25_scores
                })

            except Exception as e:
                errors.append(f"Line {line_num}: Unexpected error - {str(e)}")

    if errors:
        error_report = "\n".join(errors[:10])
        if len(errors) > 10:
            error_report += f"\n... and {len(errors) - 10} more errors"
        raise ValueError(f"Chunk data validation failed:\n{error_report}")

    df = pd.DataFrame(data_records)

    # Compute n_splits dynamically
    df = _compute_splits(df, target_tokens_per_part)
    return df
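
The Max-5 rule in `_compute_splits` is easy to check by hand: each part holds at least 30 chunks, and the number of parts is capped at 5. The sketch below copies that rule into a standalone function and evaluates it for a few illustrative chunk counts.

```python
import math


def compute_n_splits(num_chunks: int) -> int:
    """Same rule as the nested helper in _compute_splits."""
    chunks_per_part = max(30, num_chunks // 5)
    return min(5, math.ceil(num_chunks / chunks_per_part))


sample_counts = (20, 60, 100, 150, 400)  # illustrative chunk counts
split_table = {n: compute_n_splits(n) for n in sample_counts}
for n, splits in split_table.items():
    print(f"{n:>4} chunks -> {splits} split(s)")
```

Small pools stay in a single part, while pools above 150 chunks always use five parts with proportionally larger parts.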

In [9]:
# Document ranking system prompt
SYSTEM_PROMPT_DOCUMENT = """You are a helpful financial analyst. You have expertise in SEC filings.

Your task is to rank 5 financial document types to answer the given question.
"""

def build_doc_messages(question: str) -> List[Dict[str, str]]:
    """
    Build messages for document ranking API call.
    Uses Dict format: {"0": 2, "1": 3, ...}
    """
    user_content = f"""
###QUESTION###
{question}

###TASK###
Rank ALL 5 document types using the 0-4 relevance scale to answer the given question (most relevant=4).

Document types to score:
- 0: DEF 14A (Proxy Statement)
- 1: 10-K (Annual Report)
- 2: 10-Q (Quarterly Report)
- 3: 8-K (Current Report)
- 4: Earnings Call Transcript

###OUTPUT FORMAT###
Return a JSON dictionary: document-type indices as keys (strings) and scores as values (0-4).
You MUST include ALL 5 document types in your ranking.

###REAL EXAMPLES###
Question: What is Apple's latest positioning in terms of global smartphone market share?
Answer: {{"0": 2, "1": 3, "2": 0, "3": 1, "4": 4}}

This means Earnings Call Transcript (key "4") is most relevant with score 4, then 10-K (key "1") with score 3, and so on.

###ANSWER###
Return a JSON dictionary of all 5 document type rankings:"""

    return [
        {"role": "system", "content": SYSTEM_PROMPT_DOCUMENT},
        {"role": "user", "content": user_content}
    ]


# Chunk recall system prompt (Stage 1 and Stage 2)
SYSTEM_PROMPT_RECALL = """
You are a helpful financial analyst. You have expertise in SEC filings and documents.
Your task is to identify the MOST relevant text chunks to answer a question.
"""

def build_chunk_messages_recall(
    question: str,
    chunks: List[str],
    chunk_indices: List[int],
    k: int = 15
) -> List[Dict[str, str]]:
    """
    Recall Focus: Local ranking with recall focus (0-4 scale).
    Used by Stage 1 (local ranking) and Stage 2 (split rescore).
    Uses Dict format: {"67": 4, "91": 2, ...}
    """
    # Build chunks section
    chunks_section = ""
    for chunk, orig_idx in zip(chunks, chunk_indices):
        chunks_section += f"---\n**Chunk Index {orig_idx}**\n{chunk}\n"

    user_content = f"""###QUESTION###
{question}

###TEXT CHUNKS###
{chunks_section}

###INSTRUCTIONS###
Identify the top-{k} most relevant chunks, assign a relevance score (0-4) to each chunk: most relevant=4, least relevant=0

Principles:
- If you find less than {k} relevant chunks, just add more random chunks with 0 score.
- If ALL chunks are NOT relevant, give back random chunks with 0 score.

###OUTPUT FORMAT###
Your answer MUST be a JSON dictionary. Your answer MUST have exactly {k} chunk indices with scores as values (0-4).
For chunk index, use just the index number as string (e.g., "98").

GOOD EXAMPLE: {{"67": 4, "91": 2, "12": 1, "85": 0, "136": 0, ...}}
This means: chunk "67" is most relevant with score 4, then chunk "91" with score 2, and so on.

BAD EXAMPLE: {{"Chunk Index 67: 4, Chunk Index 91: 2, ..."}}

###ANSWER###
Return the JSON dictionary of top-{k} chunk rankings:"""

    return [
        {"role": "system", "content": SYSTEM_PROMPT_RECALL},
        {"role": "user", "content": user_content}
    ]


# Chunk precision system prompt (Stage 3)
SYSTEM_PROMPT_PRECISION = """You are a helpful financial analyst.
Your task is to rank the top-k RELEVANT text chunks to answer a given question.
"""

def build_chunk_messages_precision(
    question: str,
    chunks: List[str],
    chunk_indices: List[int],
    k: int = 5
) -> List[Dict[str, str]]:
    """
    Precision Focus: Global re-scoring with precision focus (0-2 scale).
    Used by Stage 3 (final rescore).
    Uses List format: [67, 91, 12, 85, 136]
    """
    # Build chunks section
    chunks_section = ""
    for chunk, orig_idx in zip(chunks, chunk_indices):
        chunks_section += f"---\n**Chunk Index {orig_idx}**\n{chunk}\n"

    user_content = f"""###QUESTION###
{question}

###CANDIDATES###
{chunks_section}

###INSTRUCTIONS###
Rank top-{k} most relevant chunks in DESCENDING order of relevance (most relevant first).

###OUTPUT FORMAT###
Your answer MUST be a RANKED LIST: with exactly {k} chunk indices, most relevant item first, least relevant item last.
Use ONLY index number as integer (89, 12, etc.)

IMPORTANT: Do NOT include verbose reasoning or explanation in the answer.
You MUST output ONLY the list as the final answer.

GOOD EXAMPLE: [67, 91, 12, 85, 136]
This means: chunk 67 is most relevant, then chunk 91, and so on.

BAD EXAMPLE: ["Chunk Index 67", "Chunk Index 91", ... ]

###ANSWER###
Return the ranked list of top-{k} chunk rankings:"""

    return [
        {"role": "system", "content": SYSTEM_PROMPT_PRECISION},
        {"role": "user", "content": user_content}
    ]
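
The prompts above request either a JSON dictionary (recall stages) or a ranked list (precision stage). A minimal sketch of how such answers could be parsed, assuming a JSON-first attempt with a regex fallback; `parse_dict_answer` and `parse_list_answer` are hypothetical helpers and are not the notebook's `ResponseParser`.

```python
import json
import re
from typing import List, Tuple


def parse_dict_answer(text: str) -> List[Tuple[int, int]]:
    """Parse '{"67": 4, "91": 2}' into [(67, 4), (91, 2)]."""
    try:
        # First attempt: strict JSON between the outermost braces
        start, end = text.index("{"), text.rindex("}") + 1
        data = json.loads(text[start:end])
        return [(int(k), int(v)) for k, v in data.items()]
    except (ValueError, TypeError, AttributeError):
        # Fallback: pick up index/score pairs even if the JSON is malformed
        pairs = re.findall(r'"?(\d+)"?\s*:\s*([0-4])\b', text)
        return [(int(k), int(v)) for k, v in pairs]


def parse_list_answer(text: str) -> List[int]:
    """Parse '[67, 91, 12]' into [67, 91, 12]; return [] if no list is found."""
    match = re.search(r"\[([\d\s,]*)\]", text)
    if not match:
        return []
    return [int(tok) for tok in re.findall(r"\d+", match.group(1))]


print(parse_dict_answer('Sure: {"67": 4, "91": 2}'))
print(parse_dict_answer('{"67": 4, "91": 2, ...}'))
print(parse_list_answer("Final answer: [67, 91, 12]"))
```

The fallback tolerates answers that copy the trailing `...` from the prompt's example, which would otherwise break strict JSON parsing.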

In [10]:
class UnifiedLLMClient:
    """Unified LLM client supporting OpenAI and Databricks via direct HTTP calls."""

    def __init__(
        self,
        backend: str,
        model: str,
        temperature: float = 0.1,
        max_tokens: int = 1000,
        extra_params: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize unified LLM client.

        Args:
            backend: 'openai' or 'databricks'
            model: Model name
            temperature: Sampling temperature
            max_tokens: Maximum tokens in response
            extra_params: Additional parameters (e.g., {"reasoning_effort": "medium"})
        """
        self.backend = backend
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.extra_params = extra_params or {}

        # Detect GPT-5 models
        self.is_gpt5 = 'gpt-5' in model.lower()

        if backend == "openai":
            self.api_key = os.getenv("OPENAI_API_KEY")
            self.api_base = "https://api.openai.com/v1"
        elif backend == "databricks":
            self.api_key = os.getenv("DATABRICKS_TOKEN")
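            # Fall back to the configured default when the env var is unset;
            # otherwise api_base would be None and achat() would fail.
            self.api_base = os.getenv("DATABRICKS_SERVING_ENDPOINT", DATABRICKS_SERVING_ENDPOINT)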
        else:
            raise ValueError(f"Unsupported backend: {backend}")

        if not self.api_key:
            raise ValueError(f"API key not found for {backend}")

    async def achat(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
        """
        Direct async chat completion.

        Args:
            messages: Chat messages [{"role": "user", "content": "..."}]

        Returns:
            Raw API response dict with 'choices' field
        """
        if self.backend == "openai":
            url = f"{self.api_base}/chat/completions"

            # GPT-5 models have different parameters
            if self.is_gpt5:
                payload = {
                    "model": self.model,
                    "messages": messages,
                    "max_completion_tokens": self.max_tokens
                }
                payload.update(self.extra_params)
            else:
                payload = {
                    "model": self.model,
                    "messages": messages,
                    "temperature": self.temperature,
                    "max_tokens": self.max_tokens
                }
                payload.update(self.extra_params)
        else:  # databricks
            # Build full endpoint URL with model name
            if self.api_base.endswith('/serving-endpoints'):
                url = f"{self.api_base}/{self.model}/invocations"
            elif not self.api_base.endswith('/invocations'):
                url = f"{self.api_base}/invocations"
            else:
                url = self.api_base

            payload = {
                "messages": messages,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens
            }
            payload.update(self.extra_params)

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        # Increased timeout to 150s for complex queries
        try:
            async with httpx.AsyncClient(timeout=150.0) as client:
                response = await client.post(url, headers=headers, json=payload)

                if response.status_code != 200:
                    text = response.text
                    raise httpx.HTTPError(f"API error {response.status_code}: {text}")

                return response.json()

        except httpx.TimeoutException as exc:
            # Chain the original timeout so the traceback keeps its cause
            raise httpx.HTTPError(
                f"API request timeout after 150s (backend: {self.backend}, model: {self.model})"
            ) from exc

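The Databricks branch of `achat` accepts three shapes of `DATABRICKS_SERVING_ENDPOINT`. A minimal standalone sketch of that resolution logic is shown here; `resolve_databricks_url` and the example host are hypothetical names used only for illustration and are not part of the notebook.

```python
def resolve_databricks_url(api_base: str, model: str) -> str:
    """Mirror the three URL cases handled for the Databricks backend."""
    if api_base.endswith("/serving-endpoints"):
        # Workspace-level base: append the endpoint (model) name
        return f"{api_base}/{model}/invocations"
    if not api_base.endswith("/invocations"):
        # Endpoint-level base that lacks the final path segment
        return f"{api_base}/invocations"
    # Already a full invocation URL
    return api_base


# Hypothetical host, for illustration only
base = "https://example.cloud.databricks.com/serving-endpoints"
print(resolve_databricks_url(base, "my-endpoint"))
```

Note that a trailing slash on the base is not normalized by this logic, so the variable is best set without one.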
In [11]:
@dataclass
class ParseResult:
    """Result from LLM response parsing with metadata."""
    rankings: List[Tuple[int, int]]  # (index, score) tuples
    is_complete: bool                 # True if len >= expected_count
    extraction_stage: str             # "direct" or "rescue"


class ResponseParser:
    """
    Parse LLM responses with 4-stage extraction strategy.

    Supports two output formats:
    1. Dict format: {"45": 2, "91": 1, ...}
    2. List format: [45, 91, 12, ...]
    """

    def __init__(self):
        """Initialize parser."""
        self.rescue_client = None  # Lazy init only if needed

    async def parse_rankings(
        self,
        response_content: Any,
        expected_count: Optional[int] = None
    ) -> ParseResult:
        """Parse rankings from LLM response content."""
        # Try stages 1-3: Direct extraction
        stage = "direct"
        try:
            rankings = self._extract_direct(response_content)
        except Exception as e:
            print(f"[WARNING] Direct extraction failed: {e}")

            # Try stage 4: Rescue parsing
            try:
                rankings = await self._extract_rescue(response_content)
                stage = "rescue"
            except Exception as rescue_error:
                print(f"[ERROR] All extraction methods failed: {rescue_error}")
                raise ValueError("All extraction methods failed") from rescue_error

        # Check completeness
        is_complete = (expected_count is None or len(rankings) >= expected_count)

        return ParseResult(rankings, is_complete, stage)

    def _extract_direct(self, content: Any) -> List[Tuple[int, int]]:
        """Direct extraction with 3 priority levels."""
        # Handle string content
        if isinstance(content, str):
            return self._extract_from_text(content)

        # Handle list content (Databricks format)
        if not isinstance(content, list):
            raise ValueError(f"Unexpected content type: {type(content)}")

        # Priority 1: Extract from 'text' blocks
        text_blocks = self._collect_text_blocks(content)
        for text in text_blocks:
            try:
                rankings = self._extract_from_text(text)
                return rankings
            except Exception:
                continue

        # Priority 2: Extract from 'reasoning' blocks
        reasoning_texts = self._collect_reasoning_texts(content)
        for text in reasoning_texts:
            try:
                if self._is_example_quote(text):
                    continue
                rankings = self._extract_from_text(text)
                return rankings
            except Exception:
                continue

        # Priority 3: Regex-based extraction
        rankings = self._extract_regex(content)
        return rankings

    def _collect_text_blocks(self, content: List) -> List[str]:
        """Collect all 'text' blocks from list content."""
        blocks = []
        for item in content:
            if isinstance(item, dict) and item.get('type') == 'text' and 'text' in item:
                blocks.append(item['text'])
        return blocks

    def _collect_reasoning_texts(self, content: List) -> List[str]:
        """Collect all reasoning texts from list content."""
        texts = []
        for item in content:
            if isinstance(item, dict) and item.get('type') == 'reasoning':
                if 'summary' in item and isinstance(item['summary'], list):
                    for summary_item in item['summary']:
                        if isinstance(summary_item, dict) and 'text' in summary_item:
                            texts.append(summary_item['text'])
        return texts

    def _normalize_index_key(self, key: str) -> int:
        """Normalize dictionary keys to extract numeric index."""
        key_str = str(key).strip()

        # Try direct conversion first
        try:
            return int(key_str)
        except ValueError:
            pass

        # Pattern: "Chunk Index X" or "Document Index X"
        match = re.search(r'(?:Chunk|Document)?\s*Index\s+(\d+)', key_str, re.IGNORECASE)
        if match:
            return int(match.group(1))

        # Fallback: extract first number
        match = re.search(r'\d+', key_str)
        if match:
            return int(match.group())

        raise ValueError(f"Cannot extract numeric index from key: {key_str}")

    def _extract_from_text(self, text: str) -> List[Tuple[int, int]]:
        """Extract rankings from text using JSON parsing."""
        # Try dict format first
        dict_start = text.find('{')
        dict_end = text.rfind('}') + 1
        if dict_start >= 0 and dict_end > dict_start:
            try:
                json_str = text[dict_start:dict_end]
                rank_dict = json.loads(json_str)
                if isinstance(rank_dict, dict):
                    rankings = [(self._normalize_index_key(k), int(v)) for k, v in rank_dict.items()]
                    return rankings
            except Exception:
                pass

        # Try list format
        list_start = text.find('[')
        list_end = text.rfind(']') + 1
        if list_start >= 0 and list_end > list_start:
            try:
                json_str = text[list_start:list_end]
                rank_list = json.loads(json_str)
                if isinstance(rank_list, list):
                    # Assign scores by position
                    max_score = len(rank_list) - 1
                    rankings = [(int(idx), max_score - i) for i, idx in enumerate(rank_list)]
                    return rankings
            except Exception:
                pass

        raise ValueError("No valid JSON found in text")

    def _is_example_quote(self, text: str) -> bool:
        """Check if text is just quoting the example from prompt."""
        indicators = [
            'Return JSON dict like',
            'Return JSON list like',
            'Example:',
            'The user asks:',
        ]
        quote_count = sum(1 for ind in indicators if ind in text)
        return quote_count >= 2

    def _extract_regex(self, content: Any) -> List[Tuple[int, int]]:
        """Regex-based extraction."""
        # Extract all text from content
        if isinstance(content, list):
            all_text = []
            for item in content:
                if isinstance(item, dict):
                    if 'text' in item:
                        all_text.append(item['text'])
                    elif 'summary' in item and isinstance(item['summary'], list):
                        for summary_item in item['summary']:
                            if isinstance(summary_item, dict) and 'text' in summary_item:
                                all_text.append(summary_item['text'])
            combined_text = "\n".join(all_text)
        else:
            combined_text = str(content)

        # Pattern 1: List format [2, 4, 1, 3, 0]
        list_pattern = r'\[\s*(\d+(?:\s*,\s*\d+)*)\s*\]'
        list_matches = re.findall(list_pattern, combined_text)
        for match in list_matches:
            try:
                indices = [int(x.strip()) for x in match.split(',')]
                max_score = len(indices) - 1
                return [(idx, max_score - i) for i, idx in enumerate(indices)]
            except Exception:
                continue

        # Pattern 2: Dict format {"2": 4, "4": 3, ...}
        if '{' in combined_text:
            dict_start = combined_text.find('{')
            dict_end = combined_text.find('}', dict_start) + 1
            if dict_start >= 0 and dict_end > dict_start:
                dict_str = combined_text[dict_start:dict_end]
                try:
                    rank_dict = json.loads(dict_str)
                    if isinstance(rank_dict, dict):
                        rankings = [(self._normalize_index_key(k), int(v)) for k, v in rank_dict.items()]
                        return rankings
                except Exception:
                    # Fallback to regex
                    pairs = re.findall(r'"(\d+)"\s*:\s*(\d+)', dict_str)
                    if pairs:
                        rankings = [(int(k), int(v)) for k, v in pairs]
                        return rankings

        raise ValueError("No valid ranking patterns found with regex")

    async def _extract_rescue(self, content: Any) -> List[Tuple[int, int]]:
        """Rescue parsing using GPT-4o-mini via UnifiedLLMClient."""
        # Lazily create the rescue client on first use
        if self.rescue_client is None:
            self.rescue_client = UnifiedLLMClient(
                backend="openai",
                model="gpt-4o-mini",
                temperature=0.0,
                max_tokens=100
            )

        # Extract all text from content
        if isinstance(content, list):
            text_parts = []
            for item in content:
                if isinstance(item, dict):
                    if 'text' in item:
                        text_parts.append(item['text'])
                    elif 'summary' in item and isinstance(item['summary'], list):
                        for summary_item in item['summary']:
                            if isinstance(summary_item, dict) and 'text' in summary_item:
                                text_parts.append(summary_item['text'])
            content_str = "\n".join(text_parts)
        else:
            content_str = str(content)

        rescue_prompt = f"""Extract ranking information from this model output. The text may contain
various blocks and formats.
IMPORTANT: Focus on the LAST part of the text for the final answer.
Return a JSON dict with indices as keys (strings) and scores as values. Keep answer to max 10 items.
Example: {{"45": 2, "91": 1, "12": 0, ...}}

Output:
{content_str}

Return ONLY the JSON dict:"""

        # Send the prompt as a plain chat-message dict
        messages = [{"role": "user", "content": rescue_prompt}]
        response = await self.rescue_client.achat(messages)
        # Read the reply from the raw response dict
        rescued_text = response['choices'][0]['message']['content']

        # Extract from rescued response
        return self._extract_from_text(rescued_text)

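The two output formats accepted by `ResponseParser` can be illustrated with a simplified standalone sketch. It assumes plain numeric keys (the key normalization for labels such as "Chunk Index 45" is left out), and `parse_rankings_text` is a hypothetical name, not a function from the notebook.

```python
import json
from typing import List, Tuple


def parse_rankings_text(text: str) -> List[Tuple[int, int]]:
    """Simplified sketch: try the dict format first, then the list format."""
    # Dict format: keys are indices, values are scores
    start, end = text.find("{"), text.rfind("}") + 1
    if 0 <= start < end:
        try:
            data = json.loads(text[start:end])
            if isinstance(data, dict):
                return [(int(k), int(v)) for k, v in data.items()]
        except (ValueError, TypeError):
            pass  # fall through to the list format

    # List format: position encodes the score (first item scores highest)
    start, end = text.find("["), text.rfind("]") + 1
    if 0 <= start < end:
        try:
            data = json.loads(text[start:end])
            if isinstance(data, list):
                top = len(data) - 1
                return [(int(idx), top - pos) for pos, idx in enumerate(data)]
        except (ValueError, TypeError):
            pass

    raise ValueError("No valid JSON found in text")


print(parse_rankings_text('Final answer: {"45": 2, "91": 1}'))  # [(45, 2), (91, 1)]
print(parse_rankings_text("Ranking: [45, 91, 12]"))  # [(45, 2), (91, 1), (12, 0)]
```

In the list format the scores are purely positional, so they carry ordering information only, not a relevance magnitude.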
# Model

In [12]:
@dataclass
class ApiResult:
    """
    Result from API call with ensemble data.

    Attributes:
        attempts: ALL attempt results [(idx, score), ...]
        quality: "complete" | "early_stop" | "max_retries" | "failed"
        n_complete: How many complete attempts
        avg_quality: Average completeness (0.0-1.0)
        error: Error message if failed
    """
    attempts: List[List[Tuple[int, int]]]
    quality: str
    n_complete: int
    avg_quality: float
    error: Optional[str] = None


def fuse_retry_attempts(
    attempts: List[List[Tuple[int, int]]],
    k: int = 60
) -> List[Tuple[int, float]]:
    """
    Fuse multiple retry attempts using ensemble ranking.

    Combines three signals:
    1. Frequency: How many attempts included this item
    2. RRF score: Reciprocal rank fusion across attempts
    3. Score sum: Aggregate relevance scores
    """
    item_signals = defaultdict(lambda: {
        'frequency': 0,
        'rrf_score': 0.0,
        'score_sum': 0
    })

    n_attempts = len(attempts)

    # Accumulate signals from all attempts
    for attempt in attempts:
        for rank, (idx, score) in enumerate(attempt):
            item_signals[idx]['frequency'] += 1
            item_signals[idx]['rrf_score'] += 1 / (rank + k)
            item_signals[idx]['score_sum'] += score

    # Calculate ensemble score
    fused = []
    for idx, signals in item_signals.items():
        freq_weight = signals['frequency'] / n_attempts
        rrf_normalized = signals['rrf_score'] / n_attempts
        score_avg = signals['score_sum'] / signals['frequency']

        # Weighted combination (0.4 consensus, 0.3 position, 0.3 relevance).
        # The three signals are on different scales and are not rescaled here.
        ensemble_score = (
            freq_weight * 0.4 +
            rrf_normalized * 0.3 +
            score_avg * 0.3
        )

        fused.append((idx, ensemble_score))

    # Sort by ensemble score descending
    fused.sort(key=lambda x: x[1], reverse=True)
    return fused


async def call_with_ensemble_retry(
    api_func,
    expected_count: int,
    max_retries: int = 10,
    min_attempts: int = 2,
    quality_threshold: float = 0.6,
    tracker=None,
    call_type: str = "chunk",
    force_min_attempts: bool = False,
    jitter_max: float = 15.0
) -> ApiResult:
    """
    Retry with smart stopping and ensemble fusion.

    Stop conditions:
    1. First complete → STOP (unless force_min_attempts)
    2. Early stop (n_attempts >= min AND avg_quality >= threshold)
    3. Max retries exhausted
    """
    all_attempts = []
    n_complete = 0

    for attempt in range(max_retries):
        try:
            # Retry jitter (skip first attempt)
            if attempt > 0:
                pre_jitter = random.uniform(0.0, jitter_max)
                await asyncio.sleep(pre_jitter)

            # Call API function
            parsed = await api_func()

            # Check if complete
            if len(parsed) >= expected_count:
                all_attempts.append(parsed)
                n_complete += 1

                # Check force_min_attempts
                if force_min_attempts and len(all_attempts) < min_attempts:
                    continue

                # Stop condition met
                avg_quality = 1.0
                return ApiResult(all_attempts, "complete", n_complete, avg_quality)
            else:
                # Incomplete - store and check early stop
                all_attempts.append(parsed)

                # Calculate average quality
                avg_quality = sum(len(a) for a in all_attempts) / (len(all_attempts) * expected_count)

                # Early stop check
                if len(all_attempts) >= min_attempts and avg_quality >= quality_threshold:
                    return ApiResult(all_attempts, "early_stop", 0, avg_quality)

        except Exception as e:
            error_str = str(e).lower()

            # Track rate limits
            if '429' in error_str or 'rate limit' in error_str:
                if tracker:
                    tracker.track_rate_limit()

            # Track retries
            if tracker and attempt > 0:
                tracker.track_retry(call_type)

            print(f"[WARNING] Attempt {attempt + 1} failed: {str(e)[:140]}")

            # Linear backoff (12s per failed attempt, capped at 35s) plus up to 20s of jitter
            if attempt < max_retries - 1:
                backoff = min(35, (attempt + 1) * 12.0)
                jitter = random.uniform(0.0, 20.0)
                wait_total = backoff + jitter
                await asyncio.sleep(wait_total)

            continue

    # Max retries exhausted
    avg_quality = sum(len(a) for a in all_attempts) / (len(all_attempts) * expected_count) if all_attempts else 0.0

    if all_attempts:
        return ApiResult(all_attempts, "max_retries", 0, avg_quality)

    # No successful attempts
    return ApiResult([], "failed", 0, 0.0, error="No successful attempts")


async def create_ranking(
    api_result: ApiResult,
    expected_count: int,
    candidate_pool: List[int],
    messages: Any,
    rescue_client,
    parser,
    semaphore: asyncio.Semaphore,
    domain_fallback: Optional[List[int]] = None
) -> List[int]:
    """
    Create final ranking with complete-first cascade and LLM rescue.

    Fallback cascade:
    1. Has complete → Use last complete attempt
    2. Has incomplete (>=2 attempts) → Ensemble fusion
    3. Failed → Try GPT-4o-mini rescue
    4. Rescue failed + domain → Domain fallback
    5. Rescue failed + no domain → Random
    """
    # Case 1: Has complete result
    if api_result.n_complete >= 1:
        complete_attempt = next(
            a for a in reversed(api_result.attempts)
            if len(a) >= expected_count
        )

        # Sort by score descending
        sorted_items = sorted(complete_attempt[:expected_count], key=lambda x: (-x[1], x[0]))
        return [idx for idx, _ in sorted_items]

    # Case 2: Has incomplete data - ensemble fusion
    if len(api_result.attempts) >= 2:
        fused = fuse_retry_attempts(api_result.attempts)

        # Sort by score descending
        sorted_items = sorted(fused[:expected_count], key=lambda x: (-x[1], x[0]))
        return [idx for idx, _ in sorted_items]

    # Case 3: Failed - try GPT-4o-mini rescue
    if rescue_client:
        try:
            async with semaphore:
                response = await rescue_client.achat(messages)
                content = response['choices'][0]['message']['content']
                result = await parser.parse_rankings(content, expected_count=expected_count)

                if len(result.rankings) >= expected_count:
                    # Sort by score descending
                    sorted_items = sorted(result.rankings[:expected_count], key=lambda x: (-x[1], x[0]))
                    return [idx for idx, _ in sorted_items]
        except Exception as e:
            print(f"[WARNING] LLM rescue failed: {e}")

    # Case 4: Domain fallback
    if domain_fallback:
        return domain_fallback[:expected_count]

    # Case 5: Random
    candidate_pool_copy = candidate_pool.copy()
    random.shuffle(candidate_pool_copy)
    return candidate_pool_copy[:expected_count]

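To make the ensemble score in `fuse_retry_attempts` concrete, here is a compact standalone re-implementation of the same formula applied to two hypothetical attempts. The function name `fuse`, the input scores, and the tie-break on index are illustrative assumptions; the weights and `k=60` follow the cell above.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def fuse(attempts: List[List[Tuple[int, int]]], k: int = 60) -> List[Tuple[int, float]]:
    freq: Dict[int, int] = defaultdict(int)
    rrf: Dict[int, float] = defaultdict(float)
    total: Dict[int, int] = defaultdict(int)
    for attempt in attempts:
        for rank, (idx, score) in enumerate(attempt):
            freq[idx] += 1                 # how many attempts mention this item
            rrf[idx] += 1 / (rank + k)     # reciprocal-rank term, 0-based rank
            total[idx] += score            # raw score sum
    n = len(attempts)
    fused = [
        (idx, 0.4 * freq[idx] / n + 0.3 * rrf[idx] / n + 0.3 * total[idx] / freq[idx])
        for idx in freq
    ]
    # Tie-break on index only to keep this example deterministic
    return sorted(fused, key=lambda x: (-x[1], x[0]))


# Hypothetical partial answers from two retries
attempts = [[(7, 3), (2, 1)], [(7, 2), (5, 1)]]
for idx, score in fuse(attempts):
    print(idx, round(score, 4))
```

For item 7 the terms are 0.4 (consensus), 0.005 (position), and 0.75 (relevance), giving 1.155. With these inputs the reciprocal-rank term is small next to the other two, so ordering is driven mostly by frequency and average score.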
In [13]:
def pad_to_n_indices(ranking: List[int], target_n: int, all_indices: List[int], query_id: str = "") -> List[int]:
    """
    Pad ranking to exactly target_n indices.

    Args:
        ranking: Current ranking list
        target_n: Target length
        all_indices: Available indices for padding
        query_id: Query ID for logging

    Returns:
        Padded ranking of exactly target_n length

    Strategy: unused indices → cyclic repeats → default fallback
    """
    if len(ranking) >= target_n:
        return ranking[:target_n]

    # Work on a copy so the caller's list is not mutated
    ranking = list(ranking)

    # Use unused indices first
    used = set(ranking)
    remaining = sorted([idx for idx in all_indices if idx not in used])
    ranking.extend(remaining[:target_n - len(ranking)])

    # Cyclic padding if still short
    while len(ranking) < target_n and all_indices:
        ranking.append(all_indices[len(ranking) % len(all_indices)])

    # Catastrophic fallback
    if len(ranking) < target_n:
        print(f"[WARNING] [{query_id}] Padding fallback: insufficient indices")
        while len(ranking) < target_n:
            ranking.append(len(ranking))

    return ranking[:target_n]


def fuse_llm_bm25_scores(
    llm_scores: List[Tuple[int, int]],
    bm25_scores: Dict[int, float],
    indices: List[int],
    weight_llm: float
) -> List[Tuple[int, float]]:
    """
    Weighted fusion of LLM semantic and BM25 lexical scores.

    Args:
        llm_scores: (chunk_idx, llm_score) tuples from API
        bm25_scores: Pre-computed BM25 scores dict
        indices: Chunk indices in this part
        weight_llm: LLM weight (0-1), BM25 gets (1-weight_llm)

    Returns:
        Sorted (chunk_idx, fused_score) tuples
    """
    llm_dict = {idx: score for idx, score in llm_scores}

    # Normalize LLM (handle all-zero case)
    llm_max = max(llm_dict.values()) if llm_dict and any(llm_dict.values()) else 1
    llm_norm = {idx: score / llm_max for idx, score in llm_dict.items()}

    # Normalize BM25 for this part
    part_bm25 = {idx: bm25_scores.get(idx, 0) for idx in indices}
    bm25_max = max(part_bm25.values()) if any(part_bm25.values()) else 1
    bm25_norm = {idx: score / bm25_max for idx, score in part_bm25.items()}

    # Weighted fusion
    weight_bm25 = 1.0 - weight_llm
    if any(bm25_norm.values()):
        fused = {
            idx: weight_llm * llm_norm.get(idx, 0) + weight_bm25 * bm25_norm.get(idx, 0)
            for idx in indices
        }
    else:
        fused = llm_norm

    # Scale back to preserve magnitude, sort descending
    fused_items = [(idx, fused.get(idx, 0) * llm_max) for idx in indices]
    fused_items.sort(key=lambda x: -x[1])

    return fused_items

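A small worked example may help show how the within-part normalization in `fuse_llm_bm25_scores` lets BM25 reorder close LLM scores. This is a condensed standalone sketch of the same arithmetic; the function name `fuse_llm_bm25`, the weight of 0.6, and all scores are hypothetical (the notebook takes the weight from `FUSION_WEIGHT_STAGE1`).

```python
from typing import Dict, List, Tuple


def fuse_llm_bm25(
    llm_scores: List[Tuple[int, int]],
    bm25_scores: Dict[int, float],
    indices: List[int],
    weight_llm: float,
) -> List[Tuple[int, float]]:
    llm = dict(llm_scores)
    llm_max = max(llm.values()) if llm and any(llm.values()) else 1
    part_bm25 = {i: bm25_scores.get(i, 0.0) for i in indices}
    has_bm25 = any(part_bm25.values())
    bm25_max = max(part_bm25.values()) if has_bm25 else 1

    fused = []
    for i in indices:
        llm_norm = llm.get(i, 0) / llm_max
        if has_bm25:
            mixed = weight_llm * llm_norm + (1 - weight_llm) * part_bm25[i] / bm25_max
        else:
            mixed = llm_norm  # no lexical signal in this part: LLM only
        fused.append((i, mixed * llm_max))  # scale back to the LLM score range
    return sorted(fused, key=lambda x: -x[1])


# Hypothetical part with three chunks; chunk 2 was not scored by the LLM
print(fuse_llm_bm25([(0, 4), (1, 2)], {0: 1.0, 1: 5.0, 2: 2.5}, [0, 1, 2], 0.6))
```

Chunk 1 ends at about 2.8 and chunk 0 at about 2.72, so the strong lexical match overtakes the higher LLM score, while chunk 2 still receives a nonzero score (about 0.8) from BM25 alone.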
In [14]:
async def rescore_split_stage2(
    question: str,
    chunks: List[str],
    candidate_items: List[Tuple[int, float]],
    query_id: str,
    client_stage2,
    parser,
    tracker
) -> List[Tuple[int, float]]:
    """
    Stage 2: Split top-50 into 2x25 parts, rank independently with pure LLM.

    Args:
        question: User question
        chunks: Full chunk list
        candidate_items: (chunk_idx, stage1_score) tuples from Stage 1
        query_id: Query identifier
        client_stage2: Stage 2 LLM client
        parser: Response parser
        tracker: API tracker

    Returns:
        ~20 (chunk_idx, llm_score) tuples (top-10 from each part), or empty if failed

    Sequential split preserves Stage 1 ordering quality.
    """
    top_n_candidates = min(50, len(candidate_items))
    candidates = candidate_items[:top_n_candidates]

    # Sequential split (no shuffle)
    part_size = len(candidates) // STAGE2_SPLIT_COUNT
    parts = []
    for i in range(STAGE2_SPLIT_COUNT):
        start = i * part_size
        end = start + part_size if i < STAGE2_SPLIT_COUNT - 1 else len(candidates)
        parts.append(candidates[start:end])

    stage2_part_semaphore = asyncio.Semaphore(STAGE2_PART_SEMAPHORE)

    # Rank parts in parallel
    part_tasks = []
    for i, part_candidates in enumerate(parts):
        # Keep indices and chunk texts aligned: drop out-of-range indices from both
        part_indices = [idx for idx, _ in part_candidates if 0 <= idx < len(chunks)]
        part_chunks = [chunks[idx] for idx in part_indices]

        messages = build_chunk_messages_recall(
            question, part_chunks, part_indices, k=STAGE2_K_PER_PART
        )

        task = rank_chunk_part_with_retry(
            messages=messages,
            semaphore=stage2_part_semaphore,
            query_id=f"{query_id}_stage2_part{i+1}",
            expected_count=STAGE2_K_PER_PART,
            candidate_indices=part_indices,
            client=client_stage2,
            parser=parser,
            tracker=tracker,
            rescue_client=None,  # No rescue client in Stage 2; total failure falls back to a random order
            jitter_max=STAGE2_JITTER_MAX,
            stage_name=f"Stage2_Part{i+1}",
            force_min_attempts=False
        )
        part_tasks.append(task)

    # Run the parts concurrently (gather preserves part order) and concatenate;
    # no sorting here - Stage 3 will handle it
    part_results = await asyncio.gather(*part_tasks)
    all_stage2_items = []
    for llm_scored in part_results:
        all_stage2_items.extend(llm_scored[:STAGE2_K_PER_PART])

    return all_stage2_items


async def rescore_final_stage3(
    question: str,
    chunks: List[str],
    candidate_items: List[Tuple[int, float]],
    query_id: str,
    stage3_semaphore: asyncio.Semaphore,
    client_stage3,
    parser,
    tracker,
    rescue_client
) -> List[Tuple[int, float]]:
    """
    Stage 3: Final rescore on ~20 candidates with 405B model.

    Args:
        question: User question
        chunks: Full chunk list
        candidate_items: (chunk_idx, stage2_score) tuples from Stage 2
        query_id: Query identifier
        stage3_semaphore: Independent semaphore for Stage 3
        client_stage3: Stage 3 LLM client
        parser: Response parser
        tracker: API tracker
        rescue_client: Rescue LLM client

    Returns:
        Top-10 (chunk_idx, llm_score) tuples, or empty if failed

    Escapes Local Pool Paradox by assigning NEW scores in full global context.
    Independent semaphore allows tuning 405B concurrency separately from 120B.
    """
    candidates = candidate_items[:min(20, len(candidate_items))]

    # Keep indices and chunk texts aligned: drop out-of-range indices from both
    candidate_indices = [idx for idx, _ in candidates if 0 <= idx < len(chunks)]
    candidate_chunks = [chunks[i] for i in candidate_indices]

    messages = build_chunk_messages_precision(question, candidate_chunks, candidate_indices, k=10)

    llm_scored_items = await rank_chunk_part_with_retry(
        messages=messages,
        semaphore=stage3_semaphore,
        query_id=f"{query_id}_stage3",
        expected_count=10,
        candidate_indices=candidate_indices,
        client=client_stage3,
        parser=parser,
        tracker=tracker,
        rescue_client=rescue_client,
        jitter_max=STAGE3_JITTER_MAX,
        stage_name="Stage3",
        force_min_attempts=False
    )

    if not llm_scored_items:
        llm_scored_items = candidates[:10]
        print(f"[WARNING] [{query_id}_stage3] FALLBACK: Using Stage 2 results")

    return llm_scored_items if llm_scored_items else []


async def rank_chunk_part_with_retry(
    messages: List[Dict],
    semaphore: Optional[asyncio.Semaphore],
    query_id: str,
    expected_count: int,
    candidate_indices: List[int],
    client,
    parser,
    tracker,
    rescue_client,
    jitter_max: Optional[float] = None,
    stage_name: str = "Stage1",
    force_min_attempts: bool = False
) -> List[Tuple[int, int]]:
    """
    Rank chunk part with unified smart retry strategy.

    Args:
        messages: Prompt messages
        semaphore: Rate limiting semaphore
        query_id: Query identifier
        expected_count: Expected item count
        candidate_indices: Candidate chunk indices
        client: LLM client
        parser: Response parser
        tracker: API tracker
        rescue_client: Rescue LLM client
        jitter_max: Max jitter seconds
        stage_name: Stage name for logging
        force_min_attempts: Force min attempts if first complete

    Returns:
        (index, score) tuples with actual LLM scores

    Handles API failures, incomplete responses, and ensemble fusion.
    """
    if jitter_max is None:
        jitter_max = STAGE1_JITTER_MAX

    async def wrapped_part_call():
        # Jitter for desynchronization (stacks with retry jitter)
        jitter = random.uniform(0, jitter_max)
        await asyncio.sleep(jitter)

        start_time = time.time()

        if semaphore:
            async with semaphore:
                response = await client.achat(messages)
        else:
            response = await client.achat(messages)

        elapsed = time.time() - start_time

        if tracker:
            tracker.track_call('chunk', elapsed)

        if not response or 'choices' not in response or not response['choices']:
            raise ValueError("Malformed API response - missing 'choices'")

        content = response['choices'][0]['message']['content']

        result = await parser.parse_rankings(content, expected_count=expected_count)

        return result.rankings

    api_result = await call_with_ensemble_retry(
        wrapped_part_call,
        expected_count=expected_count,
        max_retries=10,
        min_attempts=2,
        quality_threshold=0.7,
        tracker=tracker,
        call_type="chunk",
        force_min_attempts=force_min_attempts
    )

    # Return ACTUAL scores from API (not positional scores)
    # Case 1: Complete result
    if api_result.n_complete >= 1:
        complete_attempt = next(
            a for a in reversed(api_result.attempts)
            if len(a) >= expected_count
        )

        sorted_items = sorted(complete_attempt[:expected_count], key=lambda x: (-x[1], x[0]))
        return sorted_items

    # Case 2: Ensemble fusion (returns fused ensemble scores, not raw LLM scores)
    elif len(api_result.attempts) >= 2:
        fused = fuse_retry_attempts(api_result.attempts)

        sorted_items = sorted(fused[:expected_count], key=lambda x: (-x[1], x[0]))
        return sorted_items

    # Case 3: Rescue fallback (ONLY case using positional scores)
    else:
        sem_for_rescue = semaphore if semaphore else asyncio.Semaphore(1)

        ranking = await create_ranking(
            api_result,
            expected_count=expected_count,
            candidate_pool=candidate_indices,
            messages=messages,
            rescue_client=rescue_client,
            parser=parser,
            semaphore=sem_for_rescue,
            domain_fallback=None
        )

        print(f"[WARNING] [{query_id}] Using rescue/random with positional scores")

        scored_items = [(idx, expected_count - i) for i, idx in enumerate(ranking[:expected_count])]
        return scored_items

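The sequential split used in Stage 2 can be sketched on its own as follows. `split_sequential` is a hypothetical helper name; the logic mirrors the loop in `rescore_split_stage2`, where the last part absorbs any remainder.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_sequential(items: Sequence[T], n_parts: int) -> List[Sequence[T]]:
    """Split items into n_parts contiguous slices, keeping the input order."""
    part_size = len(items) // n_parts
    parts = []
    for i in range(n_parts):
        start = i * part_size
        # The last part takes everything that is left, including the remainder
        end = start + part_size if i < n_parts - 1 else len(items)
        parts.append(items[start:end])
    return parts


candidates = list(range(50))  # stand-in for 50 Stage 1 candidates
print([len(p) for p in split_sequential(candidates, 2)])   # [25, 25]
print([len(p) for p in split_sequential(list(range(7)), 2)])  # [3, 4]
```

Because the slices are contiguous, the first part holds the higher-ranked Stage 1 candidates. When there are fewer items than parts, the earlier parts come out empty and the last part holds everything.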
In [15]:
async def rank_single_document(
    row: pd.Series,
    semaphore: asyncio.Semaphore,
    client_stage1,
    parser,
    tracker,
    rescue_client
) -> Tuple[str, List[int]]:
    """
    Process single document ranking query.

    Args:
        row: DataFrame row with query data
        semaphore: Rate limiting semaphore
        client_stage1: Stage 1 LLM client
        parser: Response parser
        tracker: API tracker
        rescue_client: Rescue LLM client

    Returns:
        (query_id, ranking) tuple
    """
    query_id = row['query_id']
    question = row['question']
    messages = build_doc_messages(question)

    async def wrapped_api_call():
        async with semaphore:
            start_time = time.time()
            response = await client_stage1.achat(messages)
            elapsed = time.time() - start_time

            if tracker:
                tracker.track_call('document', elapsed)

            if not response or 'choices' not in response or not response['choices']:
                raise ValueError("Malformed API response - missing 'choices'")

            content = response['choices'][0]['message']['content']

            result = await parser.parse_rankings(content, expected_count=5)

            return result.rankings

    api_result = await call_with_ensemble_retry(
        wrapped_api_call,
        expected_count=5,
        max_retries=10,
        min_attempts=2,
        quality_threshold=0.8,
        tracker=tracker,
        call_type="document"
    )

    # 10-K > 10-Q > DEF14A > 8-K > Earnings
    DOMAIN_FALLBACK = [1, 2, 0, 3, 4]
    ranking = await create_ranking(
        api_result,
        expected_count=5,
        candidate_pool=[0, 1, 2, 3, 4],
        messages=messages,
        rescue_client=rescue_client,
        parser=parser,
        semaphore=semaphore,
        domain_fallback=DOMAIN_FALLBACK
    )

    return (query_id, ranking)


async def rank_single_chunk(
    row: pd.Series,
    stage3_semaphore: asyncio.Semaphore,
    client_stage1,
    client_stage2,
    client_stage3,
    parser,
    tracker,
    rescue_client,
    local_predictions: List
) -> Tuple[str, List[int]]:
    """
    3-stage chunk ranking: Stage1 BM25 fusion → Stage2 split → Stage3 final.

    Args:
        row: DataFrame row with query data
        stage3_semaphore: Independent semaphore for Stage 3
        client_stage1: Stage 1 LLM client
        client_stage2: Stage 2 LLM client
        client_stage3: Stage 3 LLM client
        parser: Response parser
        tracker: API tracker
        rescue_client: Rescue LLM client
        local_predictions: List to store predictions

    Returns:
        (query_id, ranking) tuple
    """
    query_id = row['query_id']
    question = row['question']
    chunks = row['chunks']
    chunk_indices = row['chunk_indices']
    n_splits = row['n_splits']
    bm25_scores = row.get('bm25_scores', {})

    all_scored_items_fallback = []

    try:
        # Stage 1: Local ranking with BM25 fusion
        parts = split_chunks_n_way(chunks, chunk_indices, n_splits)

        part_semaphore = asyncio.Semaphore(STAGE1_PART_SEMAPHORE)

        # Process parts in parallel
        part_inputs = []
        tasks = []
        for i, (part_chunks, part_indices) in enumerate(parts):
            part_inputs.append(part_indices)
            part_messages = build_chunk_messages_recall(
                question, part_chunks, part_indices, k=FIXED_LOCAL_K
            )

            part_id = f"{query_id}_part{i+1}"
            task = rank_chunk_part_with_retry(
                part_messages, part_semaphore, part_id,
                expected_count=FIXED_LOCAL_K,
                candidate_indices=part_indices,
                client=client_stage1,
                parser=parser,
                tracker=tracker,
                rescue_client=rescue_client
            )
            tasks.append((task, part_id, part_indices))

        # Gather LLM results concurrently; awaiting the coroutines one by one
        # would run the parts sequentially. gather() preserves input order.
        part_responses = list(await asyncio.gather(*(task for task, _, _ in tasks)))

        # BM25 fusion for each part
        fused_part_responses = []
        for part_scores, part_indices in zip(part_responses, part_inputs):
            fused_scores = fuse_llm_bm25_scores(
                llm_scores=part_scores,
                bm25_scores=bm25_scores,
                indices=part_indices,
                weight_llm=FUSION_WEIGHT_STAGE1
            )
            fused_part_responses.append(fused_scores)

        # Combine and track
        all_scored_items = []
        for i, (part_scores, part_indices) in enumerate(zip(fused_part_responses, part_inputs)):
            local_predictions.append((
                query_id, f"part{i+1}", part_scores.copy(), n_splits, part_indices
            ))
            all_scored_items.extend(part_scores)

        all_scored_items_fallback = all_scored_items.copy()

        # Sort by fused score
        all_scored_items.sort(key=lambda x: (-x[1], x[0]))
        top_candidates = all_scored_items[:50]  # slicing already caps at the list length

        try:
            # Stage 2: Split rescore
            stage2_items = await rescore_split_stage2(
                question, chunks, top_candidates, query_id,
                client_stage2, parser, tracker
            )

            if stage2_items:
                local_predictions.append((
                    query_id, "stage2_split", stage2_items.copy(), n_splits, None
                ))

            # Stage 3: Final rescore
            if stage2_items:
                rescored_items = await rescore_final_stage3(
                    question, chunks, stage2_items, query_id, stage3_semaphore,
                    client_stage3, parser, tracker, rescue_client
                )
            else:
                rescored_items = []

            if rescored_items:
                local_predictions.append((
                    query_id, "stage3_final", rescored_items.copy(), n_splits, None
                ))

            # Score-aware padding
            if rescored_items and len(rescored_items) >= 5:
                top_indices = [idx for idx, _ in rescored_items[:5]]
            else:
                # Augment with Stage 1 backup
                combined = list(rescored_items) if rescored_items else []
                stage3_indices = {idx for idx, _ in combined}
                stage1_backup = [(idx, score) for idx, score in all_scored_items
                                if idx not in stage3_indices]
                combined.extend(stage1_backup)
                combined.sort(key=lambda x: (-x[1], x[0]))
                top_indices = [idx for idx, _ in combined[:5]]

                print(f"[WARNING] [{query_id}] Stage 3 insufficient ({len(rescored_items or [])}), augmented with Stage 1")

                local_predictions.append((
                    query_id, "stage3_final", combined[:50], n_splits, None
                ))

        except Exception as e:
            print(f"[ERROR] [{query_id}] Stage 2/3 failed: {str(e)[:100]}")
            local_predictions.append((
                query_id, "stage3_final", all_scored_items.copy(), n_splits, None
            ))
            top_indices = [idx for idx, _ in all_scored_items[:5]]

        ranking = pad_to_n_indices(top_indices, 5, chunk_indices, query_id)

        return (query_id, ranking)

    except Exception as e:
        # Catastrophic fallback: model scores → random → duplicates → defaults
        error_msg = str(e).split(':', 1)[0]  # keep only the text before the first colon
        print(f"[ERROR] Chunk ranking {query_id} failed: {error_msg}")

        top_indices = []

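        # Use model scores if available (the fallback copy is unsorted, so sort first)
        if all_scored_items_fallback:
            ranked = sorted(all_scored_items_fallback, key=lambda x: (-x[1], x[0]))
            top_indices = [idx for idx, _ in ranked[:5]]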

        # Pad with random unused indices
        if len(top_indices) < 5 and chunk_indices:
            used = set(top_indices)
            available = [idx for idx in chunk_indices if idx not in used]
            random.shuffle(available)
            top_indices.extend(available[:5 - len(top_indices)])

        # Cyclic duplicates if still short
        while len(top_indices) < 5 and chunk_indices:
            top_indices.append(chunk_indices[len(top_indices) % len(chunk_indices)])

        # Absolute fallback
        if len(top_indices) < 5:
            top_indices = list(range(5))
            print(f"[ERROR] [{query_id}] Using default [0,1,2,3,4]")

        return (query_id, top_indices[:5])
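
The catastrophic fallback above can be read as a four-step chain: model scores, then random unused candidates, then cyclic repeats, then default positions. A minimal standalone sketch of that chain follows. The helper name `fallback_top_k` and its `rng` parameter are hypothetical and not part of the notebook, and sorting the scored items before taking the top `k` is an assumption about the intended behavior.

```python
import random
from typing import List, Optional, Sequence, Tuple


def fallback_top_k(
    scored_items: Sequence[Tuple[int, float]],
    chunk_indices: Sequence[int],
    k: int = 5,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical helper: build k indices via the four-step fallback chain."""
    rng = rng or random.Random()

    # 1. Best-scoring items first (higher score wins, lower index breaks ties)
    ranked = sorted(scored_items, key=lambda item: (-item[1], item[0]))
    top = [idx for idx, _ in ranked[:k]]

    # 2. Pad with unused candidates in random order
    if len(top) < k and chunk_indices:
        used = set(top)
        unused = [idx for idx in chunk_indices if idx not in used]
        rng.shuffle(unused)
        top.extend(unused[: k - len(top)])

    # 3. Repeat candidates cyclically if the pool is smaller than k
    while len(top) < k and chunk_indices:
        top.append(chunk_indices[len(top) % len(chunk_indices)])

    # 4. No candidates at all: fall back to default positions
    if len(top) < k:
        top = list(range(k))

    return top[:k]
```

Passing a seeded `random.Random` makes the padding step reproducible, which may help when comparing runs. Note that step 3 deliberately produces repeated indices, so a submission built this way can contain duplicate rows.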

In [16]:
async def rank_all_documents(
    df: pd.DataFrame,
    client_stage1,
    parser,
    tracker,
    rescue_client
) -> List[Tuple[str, List[int]]]:
    """
    Process all document ranking queries.

    Args:
        df: DataFrame from load_document_data()
        client_stage1: Stage 1 LLM client
        parser: Response parser
        tracker: API tracker
        rescue_client: Rescue LLM client

    Returns:
        (query_id, ranking) tuples
    """
    semaphore = asyncio.Semaphore(QUERY_SEMAPHORE)

    async def process_with_stagger(row, query_index):
        initial_delay = query_index * DOC_STAGGER_INTERVAL
        if initial_delay > 0:
            await asyncio.sleep(initial_delay)
        return await rank_single_document(row, semaphore, client_stage1, parser, tracker, rescue_client)

    tasks = [process_with_stagger(row, idx) for idx, (_, row) in enumerate(df.iterrows())]
    results = await tqdm_asyncio.gather(*tasks, desc="Document ranking")

    return results


async def rank_all_chunks(
    df: pd.DataFrame,
    client_stage1,
    client_stage2,
    client_stage3,
    parser,
    tracker,
    rescue_client,
    local_predictions: List
) -> List[Tuple[str, List[int]]]:
    """
    Process all chunk ranking queries.

    Args:
        df: DataFrame from load_chunk_data()
        client_stage1: Stage 1 LLM client
        client_stage2: Stage 2 LLM client
        client_stage3: Stage 3 LLM client
        parser: Response parser
        tracker: API tracker
        rescue_client: Rescue LLM client
        local_predictions: List to store predictions

    Returns:
        (query_id, ranking) tuples
    """
    query_semaphore = asyncio.Semaphore(QUERY_SEMAPHORE)
    stage3_semaphore = asyncio.Semaphore(STAGE3_SEMAPHORE)

    async def process_with_query_sem(row, query_index):
        initial_delay = query_index * CHUNK_STAGGER_INTERVAL

        if initial_delay > 0:
            await asyncio.sleep(initial_delay)
            print(f"[Query {query_index}] Starting after {initial_delay:.1f}s stagger")

        async with query_semaphore:
            return await rank_single_chunk(
                row, stage3_semaphore,
                client_stage1, client_stage2, client_stage3,
                parser, tracker, rescue_client, local_predictions
            )

    tasks = [process_with_query_sem(row, idx) for idx, (_, row) in enumerate(df.iterrows())]
    results = await tqdm_asyncio.gather(*tasks, desc="Chunk ranking")

    return results
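
Both runners combine a per-query start delay with a shared semaphore. The sketch below isolates that pattern so the effect of the two settings can be observed without any API calls. `run_staggered`, `fake_query`, and `work_seconds` are hypothetical names for illustration; the real work is replaced by `asyncio.sleep`.

```python
import asyncio
from typing import List, Tuple


async def run_staggered(
    n_queries: int,
    stagger_interval: float,
    max_concurrent: int,
    work_seconds: float = 0.05,
) -> Tuple[List[int], int]:
    """Hypothetical demo: staggered starts under a shared semaphore."""
    semaphore = asyncio.Semaphore(max_concurrent)
    active = 0
    peak = 0

    async def fake_query(index: int) -> int:
        nonlocal active, peak
        # Stagger: query i waits i * stagger_interval before competing for a slot
        await asyncio.sleep(index * stagger_interval)
        async with semaphore:
            active += 1
            peak = max(peak, active)
            await asyncio.sleep(work_seconds)  # stand-in for the LLM calls
            active -= 1
        return index

    # gather() returns results in submission order, not completion order
    results = await asyncio.gather(*(fake_query(i) for i in range(n_queries)))
    return list(results), peak
```

In a notebook cell this can be run as `await run_staggered(4, 0.01, 2)`; in a plain script, wrap the call in `asyncio.run(...)`. The returned peak shows how many queries were in flight at once, which should never exceed `max_concurrent`.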

# Main orchestrator

In [17]:
async def main(data_dir: str, output_dir: str):
    """
    Main execution pipeline.

    Args:
        data_dir: Directory containing eval JSONL files
        output_dir: Directory for output submission
    """
    # Determine backends per stage
    backend_stage1 = "databricks" if "databricks" in MODEL_STAGE1.lower() else "openai"
    backend_stage2 = "databricks" if "databricks" in MODEL_STAGE2.lower() else "openai"
    backend_stage3 = "databricks" if "databricks" in MODEL_STAGE3.lower() else "openai"

    # Configure extra parameters for specific models
    extra_params_stage1 = {}
    if "gpt-oss-120b" in MODEL_STAGE1.lower():
        extra_params_stage1["reasoning_effort"] = "medium"
    elif "gpt-5" in MODEL_STAGE1.lower():
        extra_params_stage1["reasoning_effort"] = "minimal"

    client_stage1 = UnifiedLLMClient(
        backend=backend_stage1,
        model=MODEL_STAGE1,
        temperature=0.1,
        max_tokens=1500,
        extra_params=extra_params_stage1
    )

    extra_params_stage2 = {}
    if "gpt-5" in MODEL_STAGE2.lower():
        extra_params_stage2["reasoning_effort"] = "medium"

    client_stage2 = UnifiedLLMClient(
        backend=backend_stage2,
        model=MODEL_STAGE2,
        temperature=0.1,
        max_tokens=1500,
        extra_params=extra_params_stage2
    )

    extra_params_stage3 = {}
    if "gpt-5" in MODEL_STAGE3.lower():
        extra_params_stage3["reasoning_effort"] = "medium"

    client_stage3 = UnifiedLLMClient(
        backend=backend_stage3,
        model=MODEL_STAGE3,
        temperature=0.1,
        max_tokens=1500,
        extra_params=extra_params_stage3
    )

    # Initialize rescue client if OpenAI API key available
    rescue_client = None
    if os.getenv("OPENAI_API_KEY"):
        rescue_client = UnifiedLLMClient(
            backend="openai",
            model="gpt-4o-mini",
            temperature=0.0,
            max_tokens=100
        )
    else:
        print("[WARNING] OPENAI_API_KEY not set, rescue disabled")

    # Initialize tracker and parser
    tracker = APITracker()
    parser = ResponseParser()

    # Initialize local predictions list
    local_predictions = []

    # Load data
    doc_path = Path(data_dir) / "document_ranking_kaggle_eval.jsonl"
    chunk_path = Path(data_dir) / "chunk_ranking_kaggle_eval.jsonl"

    doc_df = load_document_data(doc_path)
    chunk_df = load_chunk_data(chunk_path, TARGET_TOKENS_PER_PART)

    # Apply sample size if specified
    if SAMPLE_SIZE:
        doc_df = doc_df.head(SAMPLE_SIZE)
        chunk_df = chunk_df.head(SAMPLE_SIZE)

    # Execute ranking
    start_time = time.time()

    doc_start = time.time()
    doc_results = await rank_all_documents(doc_df, client_stage1, parser, tracker, rescue_client)
    doc_time = time.time() - doc_start

    chunk_start = time.time()
    chunk_results = await rank_all_chunks(
        chunk_df, client_stage1, client_stage2, client_stage3,
        parser, tracker, rescue_client, local_predictions
    )
    chunk_time = time.time() - chunk_start

    # Generate submission
    submission_data = []
    for query_id, ranking in doc_results:
        if not ranking:
            continue
        sample_id = query_id if query_id.startswith('doc_') else f'doc_{query_id}'
        for target_idx in ranking[:5]:
            submission_data.append({'sample_id': sample_id, 'target_index': target_idx})

    for query_id, ranking in chunk_results:
        if not ranking:
            continue
        sample_id = query_id if query_id.startswith('chunk_') else f'chunk_{query_id}'
        for target_idx in ranking[:5]:
            submission_data.append({'sample_id': sample_id, 'target_index': target_idx})

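    # Write submission (create the output directory if it does not exist yet)
    output_path = Path(output_dir) / "submission.csv"
    output_path.parent.mkdir(parents=True, exist_ok=True)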

    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['sample_id', 'target_index'])
        for entry in submission_data:
            writer.writerow([entry['sample_id'], entry['target_index']])

    # Validate submission
    validate_submission(output_path)

    # Print summary
    total_time = time.time() - start_time

    def format_time(seconds: float) -> str:
        """Format seconds as mm:ss"""
        minutes = int(seconds // 60)
        secs = int(seconds % 60)
        return f"{minutes}:{secs:02d}"

    print("\n" + "="*80)
    print("FUSION SE COMPLETE")
    print("="*80)
    print(f"Submission: {output_path} ({len(submission_data)} rows)")
    print(f"Runtime: {format_time(total_time)} (doc={format_time(doc_time)}, chunk={format_time(chunk_time)})")
    tracker.print_summary()
    print("="*80)
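
The submission step in `main()` flattens each `(query_id, ranking)` pair into one row per ranked index and prefixes the ID with the task name. A small standalone sketch of that transformation, using only the standard library, is shown here. `build_rows` and `rows_to_csv` are hypothetical helpers and not functions from the notebook.

```python
import csv
import io
from typing import Iterable, List, Sequence, Tuple


def build_rows(
    results: Iterable[Tuple[str, Sequence[int]]],
    prefix: str,
    top_k: int = 5,
) -> List[Tuple[str, int]]:
    """Hypothetical helper: flatten (query_id, ranking) pairs into CSV rows."""
    rows: List[Tuple[str, int]] = []
    for query_id, ranking in results:
        if not ranking:
            continue  # skip queries that produced no ranking
        query_id = str(query_id)
        # Add the task prefix only when the ID does not already carry it
        if query_id.startswith(f"{prefix}_"):
            sample_id = query_id
        else:
            sample_id = f"{prefix}_{query_id}"
        for target_index in ranking[:top_k]:
            rows.append((sample_id, target_index))
    return rows


def rows_to_csv(rows: Iterable[Tuple[str, int]]) -> str:
    """Serialize rows with the sample_id/target_index header."""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerow(["sample_id", "target_index"])
    writer.writerows(rows)
    return buffer.getvalue()
```

For example, `build_rows([("12", [3, 1, 4, 0, 2])], "doc")` yields five rows that all share the sample ID `doc_12`, in ranking order.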

In [20]:
# Validate output for submission.
# main() calls validate_submission(), so run this cell before calling main().

def validate_submission(csv_path: Path) -> bool:
    """
    Validate submission CSV format for Kaggle.

    Args:
        csv_path: Path to submission.csv

    Returns:
        True if validation passed

    Raises:
        AssertionError if validation fails
    """
    df = pd.read_csv(csv_path)

    # Check 1: Column names
    expected_cols = ['sample_id', 'target_index']
    assert list(df.columns) == expected_cols, f"Expected columns {expected_cols}, got {list(df.columns)}"

    # Check 2: Row count is not asserted because it varies with SAMPLE_SIZE

    # Check 3: No duplicates
    assert not df.duplicated().any(), "Found duplicate rows"

    # Check 4: Target index validity
    doc_targets = df[df['sample_id'].str.startswith('doc_')]['target_index']
    chunk_targets = df[df['sample_id'].str.startswith('chunk_')]['target_index']

    assert doc_targets.between(0, 4).all(), "Document target_index must be 0-4"
    # .all() also passes on an empty selection, where .min() would return NaN
    assert (chunk_targets >= 0).all(), "Chunk target_index must be >= 0"

    return True
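
The total row count depends on `SAMPLE_SIZE`, but a per-sample count does not. The sketch below is one possible extra check, assuming every sample should contribute exactly five rows, as `main()` writes `ranking[:5]` per query. The function names are hypothetical, and the code uses only the standard library so it can run outside the notebook.

```python
import csv
import io
from collections import Counter
from typing import Dict, List


def rows_per_sample(csv_text: str) -> Dict[str, int]:
    """Count how many rows each sample_id contributes."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return dict(Counter(row["sample_id"] for row in reader))


def samples_with_wrong_count(csv_text: str, expected: int = 5) -> List[str]:
    """Return sample_ids whose row count differs from `expected` (assumed 5)."""
    counts = rows_per_sample(csv_text)
    return sorted(sid for sid, n in counts.items() if n != expected)
```

An empty result means every sample has the expected number of rows. Any returned IDs point at queries whose ranking was truncated or skipped.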

# Run Evaluation
- Cost: about $2 USD on Databricks
- Models: `gpt-oss-120b` and `llama-405b`

In [19]:
# Run the main execution pipeline
# Notebook mode: Uses DATA_DIR and OUTPUT_DIR from Configuration cell above

await main(DATA_DIR, OUTPUT_DIR)

Document ranking: 100%|██████████| 2/2 [00:04<00:00,  2.13s/it]
Chunk ranking:   0%|          | 0/2 [00:00<?, ?it/s]
[Query 1] Starting after 3.0s stagger
Chunk ranking:  50%|█████     | 1/2 [02:31<02:31, 151.85s/it]
Chunk ranking: 100%|██████████| 2/2 [03:53<00:00, 116.74s/it]

FUSION SE COMPLETE
Submission: output/submission.csv (20 rows)
Runtime: 3:57 (doc=0:04, chunk=3:53)

API STATISTICS SUMMARY
Total API calls: 18
Total time: 92.4s (avg: 5.14s/call)

Document ranking:
  Calls: 2
  Time: 5.1s (avg: 2.55s/call)
  Errors: 0
  Retries: 0

Chunk ranking:
  Calls: 16
  Time: 87.3s (avg: 5.46s/call)
  Errors: 0
  Retries: 0

Retries breakdown:
  Rate limit retries: 2
  Parsing retries: 0



