# S-C Evidence Pipeline - Stage 1: Retrieval

This notebook runs the NV-Embed-v2 retrieval stage and saves candidates for reranking.

**Run this notebook in the NV-Embed-v2 conda environment.**

## Output
Saves `config_metadata.json`, `retrieval_candidates.pkl` (for the reranker notebook) and a `retrieval_candidates.json` copy for inspection to `outputs/retrieval_candidates/`.

In [None]:
import os
import sys
import json
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
from typing import Dict, List, Tuple, Set
from tqdm.auto import tqdm

# Add the project root to sys.path so the `final_sc_review` package imports
# below resolve. `Path("..")` is relative to the current working directory, so
# this assumes the notebook is run from a directory one level below the
# project root (TODO confirm for non-Jupyter runs).
project_root = Path("..")
sys.path.insert(0, str(project_root))

from final_sc_review.data.io import load_criteria, load_groundtruth, load_sentence_corpus
from final_sc_review.data.splits import split_post_ids
from final_sc_review.retriever.zoo import RetrieverZoo

## 1. Configuration

In [None]:
# ---- Pipeline settings -----------------------------------------------------
N_FOLDS = 5
TOP_K_RETRIEVER = 20  # number of candidate sentences kept per query
RETRIEVER_NAME = "nv-embed-v2"

# ---- Inputs (all under project_root/data) ----------------------------------
DATA_DIR = project_root / "data"
CACHE_DIR = DATA_DIR / "cache"
GROUNDTRUTH_PATH = DATA_DIR / "groundtruth/evidence_sentence_groundtruth.csv"
CORPUS_PATH = DATA_DIR / "groundtruth/sentence_corpus.jsonl"
# NOTE(review): "Criteira" is spelled exactly as written here; presumably it
# matches the file on disk, so verify before "correcting" it.
CRITERIA_PATH = DATA_DIR / "DSM5/MDD_Criteira.json"

# ---- Output consumed by the reranker notebook ------------------------------
OUTPUT_DIR = project_root / "outputs" / "retrieval_candidates"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(
    "\n".join(
        [
            f"Retriever: {RETRIEVER_NAME}",
            f"Top-K Retriever: {TOP_K_RETRIEVER}",
            f"N Folds: {N_FOLDS}",
            f"Output Directory: {OUTPUT_DIR}",
        ]
    )
)

## 2. Load Data

In [None]:
# Read groundtruth labels, the sentence corpus, and the criterion definitions.
groundtruth = load_groundtruth(GROUNDTRUTH_PATH)
sentences = load_sentence_corpus(CORPUS_PATH)
criteria = load_criteria(CRITERIA_PATH)

print(f"Groundtruth rows: {len(groundtruth)}")
print(f"Sentences: {len(sentences)}")
print(f"Criteria: {len(criteria)}")

# Lookup tables: sentences grouped per post (corpus order kept), sentences by
# UID, and criterion text by criterion ID.
sentences_by_post = defaultdict(list)
sentence_map = {}
for sentence in sentences:
    sentences_by_post[sentence.post_id].append(sentence)
    sentence_map[sentence.sent_uid] = sentence
criteria_map = {criterion.criterion_id: criterion.text for criterion in criteria}

# Every post that appears in the groundtruth, in sorted order.
all_post_ids = sorted({row.post_id for row in groundtruth})
print(f"Unique posts: {len(all_post_ids)}")

## 3. Create 5-Fold Cross-Validation Splits

In [None]:
def create_kfold_splits(post_ids: List[str], n_folds: int = 5, seed: int = 42) -> List[Tuple[List[str], List[str]]]:
    """Create k-fold cross-validation splits at post level (post-disjoint).

    The post IDs are shuffled once with ``seed`` and cut into ``n_folds``
    contiguous chunks of ``len(post_ids) // n_folds`` posts; the last chunk
    also absorbs the remainder. Fold ``i`` validates on chunk ``i`` and trains
    on every other post, so each post is validated exactly once.

    A private ``np.random.RandomState(seed)`` is used rather than
    ``np.random.seed`` so the global NumPy RNG is not reset as a side effect.
    The legacy RandomState yields the same permutation as the global-seed
    call, so fold assignments are unchanged.

    Args:
        post_ids: Post identifiers to split (the caller's list is not modified).
        n_folds: Number of folds; must be between 2 and ``len(post_ids)``.
        seed: Seed for the shuffle, making the split reproducible.

    Returns:
        One ``(train_posts, val_posts)`` tuple per fold.

    Raises:
        ValueError: If ``n_folds`` is below 2 or exceeds the number of posts
            (which would produce an empty training or validation fold).
    """
    n_posts = len(post_ids)
    if n_folds < 2 or n_folds > n_posts:
        raise ValueError(
            f"n_folds must be between 2 and the number of posts ({n_posts}), got {n_folds}"
        )

    shuffled = np.array(post_ids)  # copy, so the caller's list is untouched
    np.random.RandomState(seed).shuffle(shuffled)

    fold_size = n_posts // n_folds
    folds = []
    for i in range(n_folds):
        start = i * fold_size
        # The last fold takes the remainder so no post is left out.
        end = n_posts if i == n_folds - 1 else start + fold_size

        val_posts = shuffled[start:end].tolist()
        train_posts = np.concatenate([shuffled[:start], shuffled[end:]]).tolist()
        folds.append((train_posts, val_posts))

    return folds

# Build the post-disjoint folds and report each fold's size.
folds = create_kfold_splits(all_post_ids, n_folds=N_FOLDS)

for fold_number, (fold_train, fold_val) in enumerate(folds, start=1):
    print(f"Fold {fold_number}: Train={len(fold_train)} posts, Val={len(fold_val)} posts")

## 4. Build Query Data from Groundtruth

In [None]:
def build_query_data(groundtruth, post_ids: Set[str], include_no_evidence: bool = True):
    """Group groundtruth rows into one query per (post_id, criterion_id) pair.

    Rows whose ``post_id`` is not in ``post_ids`` are ignored. The query text
    comes from the notebook-level ``criteria_map``; if a criterion is missing
    there, its ID is used as the text.

    Returns:
        A list of dicts, ordered by first appearance in ``groundtruth``, each with
        ``post_id``, ``criterion_id``, ``criterion_text``, ``gold_uids`` (set of
        ``sent_uid`` values whose ``groundtruth`` equals 1) and ``is_no_evidence``
        (True when ``gold_uids`` is empty). Queries without evidence are left
        out unless ``include_no_evidence`` is True.
    """
    grouped = {}

    for row in groundtruth:
        if row.post_id in post_ids:
            pair = (row.post_id, row.criterion_id)
            entry = grouped.get(pair)
            if entry is None:
                entry = {
                    'post_id': row.post_id,
                    'criterion_id': row.criterion_id,
                    'criterion_text': criteria_map.get(row.criterion_id, row.criterion_id),
                    'gold_uids': set(),
                }
                grouped[pair] = entry
            if row.groundtruth == 1:
                entry['gold_uids'].add(row.sent_uid)

    selected = []
    for entry in grouped.values():
        has_no_gold = not entry['gold_uids']
        entry['is_no_evidence'] = has_no_gold
        if include_no_evidence or not has_no_gold:
            selected.append(entry)
    return selected

# Build the full query set over every post (its size is reused later for the
# saved metadata) and split it by whether any gold evidence exists.
all_queries = build_query_data(groundtruth, set(all_post_ids), include_no_evidence=True)

queries_with_evidence = []
queries_no_evidence = []
for q in all_queries:
    bucket = queries_no_evidence if q['is_no_evidence'] else queries_with_evidence
    bucket.append(q)

print(f"Total queries: {len(all_queries)}")
print(f"Queries with evidence: {len(queries_with_evidence)}")
print(f"Queries with no evidence: {len(queries_no_evidence)}")

## 5. Initialize Retriever

In [None]:
# Build the retriever once; the same `retriever` object is reused for every
# fold in the cells below.
retriever_zoo = RetrieverZoo(sentences=sentences, cache_dir=CACHE_DIR)
retriever = retriever_zoo.get_retriever(RETRIEVER_NAME)

# Encode the sentence corpus. With rebuild=False this presumably reuses
# embeddings cached under CACHE_DIR when present (logic lives in the
# RetrieverZoo retriever — verify there).
print(f"Encoding corpus for {RETRIEVER_NAME}...")
retriever.encode_corpus(rebuild=False)
print("Retriever ready!")

## 6. Prepare Retrieval Data for Each Fold

In [None]:
def prepare_retrieval_data(queries, retriever, top_k: int = 20):
    """Retrieve top-k candidate sentences from each query's own post.

    For every query dict (as built by ``build_query_data``) this calls
    ``retriever.retrieve_within_post`` restricted to the query's post and
    labels each returned sentence 1 if its UID is in ``gold_uids``, else 0.

    A query is skipped (not returned) when its post has fewer than 2 sentences
    in ``sentences_by_post``, when the retriever raises, or when the retriever
    returns fewer than 2 results. Skips are counted per reason and reported
    once at the end, so dropped queries are visible rather than silent.

    Args:
        queries: Iterable of query dicts with ``post_id``, ``criterion_id``,
            ``criterion_text``, ``gold_uids`` and ``is_no_evidence``.
        retriever: Object exposing ``retrieve_within_post(query=, post_id=, top_k=)``
            whose results carry ``sent_uid``, ``text`` and ``score``.
        top_k: Maximum candidates per query (capped at the post's sentence count).

    Returns:
        List of sample dicts with ``query``, ``post_id``, ``criterion_id``,
        ``gold_uids`` (sorted list), ``is_no_evidence`` and ``candidates``.
    """
    retrieval_data = []
    # Per-reason counts of dropped queries, reported after the loop.
    skipped = {'too_few_sentences': 0, 'retrieval_error': 0, 'too_few_results': 0}

    for query in tqdm(queries, desc="Retrieving candidates"):
        post_id = query['post_id']
        criterion_text = query['criterion_text']
        gold_uids = query['gold_uids']

        # A post with fewer than 2 sentences leaves nothing to rank.
        post_sentences = sentences_by_post.get(post_id, [])
        if len(post_sentences) < 2:
            skipped['too_few_sentences'] += 1
            continue

        # Best-effort: one failing query must not abort the whole run, but the
        # failure is counted and reported instead of disappearing.
        try:
            results = retriever.retrieve_within_post(
                query=criterion_text,
                post_id=post_id,
                top_k=min(top_k, len(post_sentences)),
            )
        except Exception as e:
            print(f"Error retrieving for {post_id}: {e}")
            skipped['retrieval_error'] += 1
            continue

        if len(results) < 2:
            skipped['too_few_results'] += 1
            continue

        candidates = [
            {
                'sent_uid': r.sent_uid,
                'text': r.text,
                'score': float(r.score),
                'label': 1 if r.sent_uid in gold_uids else 0,
            }
            for r in results
        ]

        retrieval_data.append({
            'query': criterion_text,
            'post_id': post_id,
            'criterion_id': query['criterion_id'],
            # Sorted so the saved output does not depend on set iteration order,
            # which changes between runs for str UIDs under hash randomization.
            'gold_uids': sorted(gold_uids),
            'is_no_evidence': query['is_no_evidence'],
            'candidates': candidates,
        })

    n_skipped = sum(skipped.values())
    if n_skipped:
        details = ", ".join(f"{reason}={count}" for reason, count in skipped.items() if count)
        print(f"Skipped {n_skipped} of {n_skipped + len(retrieval_data)} queries ({details})")

    return retrieval_data

In [None]:
# Run retrieval once, then split the results per fold.
#
# A query's candidates depend only on its (post_id, criterion) pair and the
# shared retriever, never on which fold its post lands in. Retrieving per fold
# therefore repeated every query N_FOLDS times (once as validation,
# N_FOLDS - 1 times as training). Retrieving once over all posts and filtering
# by post gives the same samples in the same order, with 1/N_FOLDS of the calls.
import copy

print("Retrieving candidates for all queries (shared across folds)...")
_queries_all_posts = build_query_data(groundtruth, set(all_post_ids), include_no_evidence=True)
_retrieved_all = prepare_retrieval_data(_queries_all_posts, retriever, top_k=TOP_K_RETRIEVER)


def _samples_for_posts(samples, post_set):
    """Return independent copies of the samples whose post is in ``post_set``.

    Deep copies keep folds from sharing objects (as separate per-fold retrieval
    did), so the pickled folds stay independent if a later step mutates one.
    """
    return [copy.deepcopy(s) for s in samples if s['post_id'] in post_set]


all_fold_data = {}

for fold_idx, (train_posts, val_posts) in enumerate(folds):
    print(f"\n{'='*60}")
    print(f"FOLD {fold_idx + 1}/{N_FOLDS}")
    print(f"{'='*60}")
    print(f"Train posts: {len(train_posts)}, Val posts: {len(val_posts)}")

    train_post_set = set(train_posts)
    val_post_set = set(val_posts)

    # Query counts before the retrieval-stage filters (too few sentences/results).
    train_queries = build_query_data(groundtruth, train_post_set, include_no_evidence=True)
    val_queries = build_query_data(groundtruth, val_post_set, include_no_evidence=True)

    print(f"Train queries: {len(train_queries)}, Val queries: {len(val_queries)}")

    train_data = _samples_for_posts(_retrieved_all, train_post_set)
    val_data = _samples_for_posts(_retrieved_all, val_post_set)

    print(f"Training samples: {len(train_data)}, Validation samples: {len(val_data)}")

    all_fold_data[f'fold_{fold_idx + 1}'] = {
        'train_posts': train_posts,
        'val_posts': val_posts,
        'train_data': train_data,
        'val_data': val_data,
    }

print(f"\nRetrieval complete for all {N_FOLDS} folds!")

## 7. Save Retrieval Results

In [None]:
# Record the run configuration alongside the candidate files.
config_metadata = dict(
    retriever_name=RETRIEVER_NAME,
    top_k_retriever=TOP_K_RETRIEVER,
    n_folds=N_FOLDS,
    total_posts=len(all_post_ids),
    total_queries=len(all_queries),
)

with (OUTPUT_DIR / "config_metadata.json").open("w") as metadata_file:
    json.dump(config_metadata, metadata_file, indent=2)

print(f"Saved config metadata to {OUTPUT_DIR / 'config_metadata.json'}")

In [None]:
# Pickle the fold data for the next stage. Pickle is acceptable only because
# this file is produced and consumed by our own pipeline; never unpickle a
# file from an untrusted source.
output_file = OUTPUT_DIR / "retrieval_candidates.pkl"

with output_file.open("wb") as pickle_handle:
    pickle.dump(all_fold_data, pickle_handle)

# Report where the file went and how large it is (in MiB).
file_size_mb = output_file.stat().st_size / (1024 * 1024)
print(f"Saved retrieval candidates to {output_file}")
print(f"File size: {file_size_mb:.2f} MB")

In [None]:
# Human-readable JSON copy of the same fold data, for inspection. It is always
# written and can be large; the pickle above is presumably what the reranker
# notebook loads.
json_output_file = OUTPUT_DIR / "retrieval_candidates.json"

with json_output_file.open("w") as json_handle:
    json.dump(all_fold_data, json_handle, indent=2)

json_file_size_mb = json_output_file.stat().st_size / (1024 * 1024)
print(f"Saved JSON backup to {json_output_file}")
print(f"File size: {json_file_size_mb:.2f} MB")

## 8. Summary

In [None]:
# Final report of what this notebook produced. File paths come from the
# variables the save cells wrote to, so the summary cannot drift from them.
banner = "=" * 60
print("\n" + banner)
print("RETRIEVAL STAGE COMPLETE")
print(banner)

print("\nConfiguration:")
print(f"  - Retriever: {RETRIEVER_NAME}")
print(f"  - Top-K: {TOP_K_RETRIEVER}")
print(f"  - N Folds: {N_FOLDS}")

print("\nData Statistics:")
for fold_name, fold_data in all_fold_data.items():
    print(f"  - {fold_name}: {len(fold_data['train_data'])} train, {len(fold_data['val_data'])} val samples")

print("\nOutput Files:")
print(f"  - {OUTPUT_DIR / 'config_metadata.json'}")
print(f"  - {output_file} ({file_size_mb:.2f} MB)")
print(f"  - {json_output_file} ({json_file_size_mb:.2f} MB)")

print("\nNext Step:")
print("  Run sc_reranker_pipeline_no_postProcessing.ipynb in the Jina reranker environment.")