# Single Example Test - Citation Retrieval System

This notebook tests the full retrieval pipeline on a single example from the ScholarCopilot dataset.

**Pipeline:**
1. Load dataset and extract one citation context
2. Build retrieval resources (BM25, E5, SPECTER)
3. Run retrieval with all three methods
4. Aggregate results using RRF
5. Rerank with LLM (Hugging Face or Ollama)
6. Show final ranked papers

In [1]:
# Setup
import json
import re
import sys
from pathlib import Path
from dotenv import load_dotenv
import os

# Load environment variables
load_dotenv()

# Add the current working directory to the import path
sys.path.insert(0, str(Path.cwd()))

print("✅ Imports successful")
print(f"📁 Working directory: {Path.cwd()}")
print(f"🔧 Inference Engine: {os.getenv('INFERENCE_ENGINE', 'ollama')}")
print(f"🤖 Local LLM: {os.getenv('LOCAL_LLM', 'gemma3:4b')}")

✅ Imports successful
📁 Working directory: /Users/ishaankalra/Dev/Retrieval/server
🔧 Inference Engine: ollama
🤖 Local LLM: gemma3:4b


## Step 1: Load Dataset and Extract Citation Context

In [3]:
import sys
from pathlib import Path

PROJECT_ROOT = Path.cwd().parent  # project root is the parent of the server/ working directory
sys.path.insert(0, str(PROJECT_ROOT))
from corpus_loaders.scholarcopilot import load_dataset, build_citation_corpus

# Load dataset
dataset_path = os.getenv('DATASET_DIR', 'corpus_loaders/scholarcopilot/scholar_copilot_eval_data_1k.json')
print(f"📚 Loading dataset from: {dataset_path}")

dataset = load_dataset(dataset_path)
print(f"✅ Loaded {len(dataset)} papers")

# Build corpus
print("\n🔨 Building citation corpus...")
corpus = build_citation_corpus(dataset)
print(f"✅ Corpus: {len(corpus)} documents")

📚 Loading dataset from: /Users/ishaankalra/Dev/Retrieval/corpus_loaders/scholarcopilot/scholar_copilot_eval_data_1k.json
✅ Loaded 1000 papers

🔨 Building citation corpus...
✅ Corpus: 9740 documents


In [None]:
# ============================================================================
# CONFIGURATION: Change this to test different queries
# ============================================================================
QUERY_INDEX = 0  # Change to 1 for 2nd query, 2 for 3rd query, etc.
MAX_PAPERS_TO_SEARCH = 100  # How many papers to search through

# ============================================================================
# Extract citation contexts
# ============================================================================
cite_pattern = re.compile(r"<\|cite_\d+\|>")

print(f"🔍 Extracting citation contexts from first {MAX_PAPERS_TO_SEARCH} papers...")

all_examples = []
for paper in dataset[:MAX_PAPERS_TO_SEARCH]:
    paper_text = paper.get("paper", "")
    if not paper_text:
        continue

    bib_info = paper.get("bib_info", {})

    # Find all citation markers
    for match in cite_pattern.finditer(paper_text):
        cite_token = match.group(0)

        if cite_token not in bib_info:
            continue

        refs = bib_info[cite_token]
        if not refs:
            continue

        # Get ground truth IDs
        relevant_ids = set()
        for ref in refs:
            ref_id = ref.get("citation_key") or ref.get("paper_id")
            if ref_id:
                relevant_ids.add(str(ref_id))

        if not relevant_ids:
            continue

        # Extract context around the citation (±100 words)
        pos = match.start()
        words_before = paper_text[:pos].split()[-100:]
        words_after = paper_text[match.end():].split()[:100]

        context = " ".join(words_before + words_after)
        # Remove any other citation markers inside the window (reuses the compiled pattern)
        context = cite_pattern.sub("", context)
        context = " ".join(context.split())

        if len(context.split()) < 10:
            continue

        all_examples.append({
            "query": context,
            "relevant_ids": relevant_ids,
            "ground_truth_titles": [ref.get("title", "Unknown") for ref in refs],
            "paper_id": paper.get("paper_id", "unknown")
        })

print(f"✅ Found {len(all_examples)} valid citation contexts")

# Select the query at QUERY_INDEX
if QUERY_INDEX >= len(all_examples):
    raise ValueError(f"QUERY_INDEX={QUERY_INDEX} is too large. Only {len(all_examples)} queries available.")

test_example = all_examples[QUERY_INDEX]

print("\n" + "="*80)
print(f"📝 TEST EXAMPLE #{QUERY_INDEX + 1} of {len(all_examples)}")
print("="*80)
print(f"\n🔍 Query (first 300 chars):\n{test_example['query'][:300]}...")
print(f"\n✅ Ground Truth Citations ({len(test_example['relevant_ids'])}):")
for i, title in enumerate(test_example['ground_truth_titles'], 1):
    print(f"   {i}. {title}")
print(f"\n💡 To test a different query, change QUERY_INDEX (0 to {len(all_examples)-1})")
print("="*80)
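
The windowing logic in the cell above can be tried in isolation. The following is a minimal sketch on a toy string, not the notebook's actual data; the helper name `extract_context` and the sample sentence are made up for illustration.

```python
import re

# Same marker format as the notebook's cite_pattern.
CITE_PATTERN = re.compile(r"<\|cite_\d+\|>")


def extract_context(text: str, start: int, end: int, window: int = 100) -> str:
    """Return up to `window` words on each side of text[start:end].

    Hypothetical helper: mirrors the cell above, with all citation
    markers removed and whitespace collapsed.
    """
    before = text[:start].split()[-window:]
    after = text[end:].split()[:window]
    context = " ".join(before + after)
    # Drop other citation markers that happen to fall inside the window.
    context = CITE_PATTERN.sub("", context)
    return " ".join(context.split())


toy_text = (
    "Transformers <|cite_0|> are widely used, "
    "and BERT <|cite_1|> popularized pre-training."
)
match = next(CITE_PATTERN.finditer(toy_text))
print(extract_context(toy_text, match.start(), match.end(), window=3))
# Transformers are widely used,
```

Note that `window` must be positive here: a slice such as `[-0:]` returns the whole list rather than an empty one.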

## Step 2: Build Retrieval Resources

In [5]:
from src.resources.builders import build_inmemory_resources

print("🔧 Building retrieval resources...")
print("   This may take a few minutes on first run (models will be downloaded)")

resources = build_inmemory_resources(
    corpus,
    enable_bm25=True,
    enable_e5=True,
    enable_specter=True
)

print("\n✅ Resources built:")
print(f"   - BM25: {len(resources['bm25']['ids'])} documents indexed")
print(f"   - E5: {resources['e5']['corpus_embeddings'].shape[0]} embeddings")
print(f"   - SPECTER: {resources['specter']['corpus_embeddings'].shape[0]} embeddings")

🔧 Building retrieval resources...
   This may take a few minutes on first run (models will be downloaded)

Batches: 100%|██████████| 609/609 [13:51<00:00,  1.37s/it]
Batches: 100%|██████████| 1218/1218 [09:22<00:00,  2.16batch/s]

✅ Resources built:
   - BM25: 9740 documents indexed
   - E5: 9740 embeddings
   - SPECTER: 9740 embeddings





## Step 3: Run Individual Retrievers

In [10]:
import bm25s
import torch

k = 30  # Top-k results
query = test_example['query']

# BM25
print("🔍 Running BM25...")
bm25_res = resources['bm25']
q_tokens = bm25s.tokenize(query, stopwords="en", stemmer=bm25_res["stemmer"])
doc_indices, scores = bm25_res["index"].retrieve(q_tokens, k=k)

bm25_results = []
for idx, score in zip(doc_indices[0], scores[0]):
    bm25_results.append({
        "id": bm25_res["ids"][idx],
        # Fall back to an empty title if the resource has no "titles" list
        "title": bm25_res["titles"][idx] if bm25_res.get("titles") else "",
        "score": float(score),
        "source": "bm25"
    })

print(f"   ✅ Retrieved {len(bm25_results)} papers (top score: {bm25_results[0]['score']:.3f})")

# E5
print("\n🔍 Running E5...")
e5_res = resources['e5']
with torch.no_grad():
    q_emb = e5_res['model'].encode(
        [query], convert_to_tensor=True, normalize_embeddings=True, show_progress_bar=False
    )

scores = (q_emb @ e5_res['corpus_embeddings'].T)[0]
top_k = torch.topk(scores, k=min(k, len(scores)))

e5_results = []
for idx, score in zip(top_k.indices, top_k.values):
    e5_results.append({
        "id": e5_res["ids"][idx.item()],
        # Fall back to an empty title if the resource has no "titles" list
        "title": e5_res["titles"][idx.item()] if e5_res.get("titles") else "",
        "score": score.item(),
        "source": "e5"
    })

print(f"   ✅ Retrieved {len(e5_results)} papers (top score: {e5_results[0]['score']:.3f})")

# SPECTER
print("\n🔍 Running SPECTER...")
specter_res = resources['specter']
device = specter_res.get('device') or str(specter_res['corpus_embeddings'].device)

model = specter_res['model']
if str(next(model.parameters()).device) != device:
    model = model.to(device)

tokenizer = specter_res['tokenizer']

with torch.no_grad():
    inputs = tokenizer([query], padding=True, truncation=True, max_length=256, return_tensors="pt")
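    # Use distinct names so the loop variable is not confused with the top-k constant `k`
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}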
    q_emb = model(**inputs).last_hidden_state.mean(dim=1)
    q_emb = torch.nn.functional.normalize(q_emb, dim=1)

corpus_embs = specter_res['corpus_embeddings'].to(device)
corpus_embs = torch.nn.functional.normalize(corpus_embs, dim=1)
scores = (q_emb @ corpus_embs.T)[0]
top_k = torch.topk(scores, k=min(k, len(scores)))

specter_results = []
for idx, score in zip(top_k.indices, top_k.values):
    specter_results.append({
        "id": specter_res["ids"][idx.item()],
        # Fall back to an empty title if the resource has no "titles" list
        "title": specter_res["titles"][idx.item()] if specter_res.get("titles") else "",
        "score": score.item(),
        "source": "specter"
    })

print(f"   ✅ Retrieved {len(specter_results)} papers (top score: {specter_results[0]['score']:.3f})")

🔍 Running BM25...
   ✅ Retrieved 30 papers (top score: 26.077)

🔍 Running E5...
   ✅ Retrieved 30 papers (top score: 0.897)

🔍 Running SPECTER...
   ✅ Retrieved 30 papers (top score: 0.937)
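
Both dense retrievers above come down to the same operation: L2-normalize the query and corpus vectors, take dot products (cosine similarity), and keep the `k` highest scores. The sketch below reproduces that idea with plain Python lists instead of `torch` tensors so it runs without any model. The function names and the toy vectors are illustrative assumptions, not part of the project.

```python
import math


def normalize(vec):
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm else list(vec)


def top_k_cosine(query_vec, corpus_vecs, k):
    """Return (corpus_index, cosine_similarity) pairs for the k best matches."""
    q = normalize(query_vec)
    scored = []
    for idx, doc_vec in enumerate(corpus_vecs):
        d = normalize(doc_vec)
        scored.append((idx, sum(a * b for a, b in zip(q, d))))
    scored.sort(key=lambda pair: pair[1], reverse=True)
    # Same guard as the notebook: never ask for more results than documents.
    return scored[: min(k, len(scored))]


# Toy 2-D "embeddings" (hypothetical values).
toy_corpus = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
hits = top_k_cosine([1.0, 0.2], toy_corpus, k=2)
print([idx for idx, _ in hits])
# [0, 2]
```

Because cosine scores are bounded by 1 while BM25 scores are not (26.077 in the run above), the raw scores of the three retrievers are not directly comparable, which is why the next step fuses ranks rather than scores.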


## Step 4: Aggregate with Reciprocal Rank Fusion (RRF)

In [12]:
def reciprocal_rank_fusion(results_dict, k=60):
    """Aggregate results using RRF."""
    paper_scores = {}
    
    for source, results in results_dict.items():
        for rank, paper in enumerate(results):
            paper_id = paper['id']
            rrf_score = 1.0 / (k + rank + 1)
            
            if paper_id not in paper_scores:
                paper_scores[paper_id] = {
                    'paper': paper,
                    'rrf_score': 0,
                    'sources': []
                }
            
            paper_scores[paper_id]['rrf_score'] += rrf_score
            paper_scores[paper_id]['sources'].append(source)
    
    # Sort by RRF score
    ranked = sorted(paper_scores.values(), key=lambda x: x['rrf_score'], reverse=True)
    return ranked

print("🔀 Aggregating results with RRF...")
aggregated = reciprocal_rank_fusion({
    'bm25': bm25_results,
    'e5': e5_results,
    'specter': specter_results
})

print(f"   ✅ Aggregated into {len(aggregated)} unique papers")

# Show top 5
print("\n📊 Top 5 after RRF aggregation:")
for i, item in enumerate(aggregated[:5], 1):
    paper = item['paper']
    print(f"\n{i}. {paper['title'][:80]}...")
    print(f"   RRF Score: {item['rrf_score']:.4f}")
    print(f"   Sources: {', '.join(item['sources'])}")

🔀 Aggregating results with RRF...
   ✅ Aggregated into 74 unique papers

📊 Top 5 after RRF aggregation:

1. Lite Transformer with Long-Short Range Attention...
   RRF Score: 0.0474
   Sources: bm25, e5, specter

2. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding...
   RRF Score: 0.0401
   Sources: bm25, e5, specter

3. Tinybert: Distilling {BERT} for natural language understanding...
   RRF Score: 0.0320
   Sources: bm25, specter

4. TinyBERT: Distilling BERT for Natural Language Understanding...
   RRF Score: 0.0304
   Sources: bm25, specter

5. Funnel-transformer: Filtering out sequential redundancy for efficient language p...
   RRF Score: 0.0301
   Sources: bm25, specter
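
To make the RRF arithmetic concrete, here is a small worked example on toy rankings. Each retriever contributes `1 / (k + rank)` for every document it returns, with `rank` starting at 1, which matches the `1.0 / (k + rank + 1)` expression above where `rank` starts at 0. The function name `rrf_scores` and the document IDs are made up for this sketch.

```python
from collections import defaultdict


def rrf_scores(rankings, k=60):
    """Fuse ranked ID lists with Reciprocal Rank Fusion.

    `rankings` maps a retriever name to a list of document IDs, best first.
    Returns (doc_id, fused_score) pairs sorted from best to worst.
    """
    scores = defaultdict(float)
    for ranked_ids in rankings.values():
        for rank, doc_id in enumerate(ranked_ids, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)


toy_rankings = {
    "bm25": ["a", "b", "c"],
    "e5": ["b", "a"],
    "specter": ["c", "b"],
}
fused = rrf_scores(toy_rankings)
for doc_id, score in fused:
    print(f"{doc_id}: {score:.4f}")
# b: 0.0487   (1/62 + 1/61 + 1/62)
# a: 0.0325   (1/61 + 1/62)
# c: 0.0323   (1/63 + 1/61)
```

With three retrievers and `k=60`, the highest possible fused score is `3/61 ≈ 0.0492`, reached only when all three rank the same paper first. The top score of 0.0474 in the run above is close to that bound, consistent with the paper appearing near the top of all three lists.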

## Step 5: LLM Reranking (Hugging Face or Ollama)

In [13]:
from src.prompts.llm_reranker import LLMRerankerPrompt
from langchain_ollama import ChatOllama
import logging

# Suppress verbose logs
logging.getLogger("httpx").setLevel(logging.WARNING)

# Prepare candidates
candidate_papers = [item['paper'] for item in aggregated[:20]]  # Top 20 for reranking

print(f"🤖 LLM Reranking {len(candidate_papers)} candidates...\n")

# Initialize LLM
inference_engine = os.getenv("INFERENCE_ENGINE", "ollama").lower()
model_id = os.getenv("LOCAL_LLM", "gemma3:4b")

if inference_engine == "ollama":
    print(f"🔄 Using Ollama with model: {model_id}")
    llm = ChatOllama(model=model_id, temperature=0)
else:
    print(f"🔄 Using Hugging Face model: {model_id}")
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
    from langchain_huggingface import HuggingFacePipeline
    import torch
    
    print("   Loading model (this may take a minute)...")
    tok = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    )
    
    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=1024,
        do_sample=False,
    )
    
    llm = HuggingFacePipeline(pipeline=gen)
    print("   ✅ Model loaded!")

# Build prompt
prompt = LLMRerankerPrompt(query=query, candidate_papers=candidate_papers).get_prompt()

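print("\n📝 Invoking LLM...")
response = llm.invoke(prompt)
# Chat models such as ChatOllama return a message object with `.content`;
# HuggingFacePipeline returns a plain string.
response_text = response.content if hasattr(response, "content") else str(response)

print(f"✅ LLM response received ({len(response_text)} chars)")

🤖 LLM Reranking 20 candidates...

🔄 Using Ollama with model: gemma3:4b

📝 Invoking LLM...
✅ LLM response received (645 chars)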
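
The contents of `LLMRerankerPrompt` are not shown in this notebook. Since the parsing step in Step 6 expects a JSON list of objects with a 1-based `index` and a `score`, a prompt along the following lines would be compatible. This is a hypothetical sketch only: `build_rerank_prompt`, its wording, and the `max_title_chars` parameter are assumptions, not the project's actual prompt.

```python
def build_rerank_prompt(query, candidates, max_title_chars=200):
    """Render a numbered candidate list and ask for JSON rankings.

    Hypothetical stand-in for the project's LLMRerankerPrompt.
    """
    lines = [
        "You are ranking candidate papers for a citation context.",
        f"Citation context: {query}",
        "",
        "Candidates:",
    ]
    # Number candidates from 1 so the reply's "index" maps back with index - 1.
    for i, paper in enumerate(candidates, start=1):
        lines.append(f"{i}. {paper.get('title', '')[:max_title_chars]}")
    lines += [
        "",
        'Return only a JSON list such as [{"index": 1, "score": 0.9}], '
        "ordered from most to least relevant.",
    ]
    return "\n".join(lines)


demo_prompt = build_rerank_prompt(
    "efficient transformers for language understanding",
    [{"title": "Paper A"}, {"title": "Paper B"}],
)
print(demo_prompt)
```

Keeping the numbering 1-based in the prompt matters because Step 6 subtracts 1 from each returned `index` before looking up the candidate.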


## Step 6: Parse Results and Show Rankings

In [14]:
import json

# Parse LLM response
try:
    json_match = re.search(r"\[[\s\S]*\]", response_text)
    if json_match:
        json_str = json_match.group()
        rankings = json.loads(json_str)
    else:
        rankings = json.loads(response_text)
    
    print(f"✅ Successfully parsed {len(rankings)} rankings\n")
    
    # Build final ranked list
    final_ranked = []
    for item in rankings:
        idx = item['index'] - 1
        score = item['score']
        if 0 <= idx < len(candidate_papers):
            final_ranked.append((candidate_papers[idx], score))
    
    # Display results
    print("="*80)
    print("🏆 FINAL RANKED RESULTS (Top 10)")
    print("="*80)
    
    relevant_ids = test_example['relevant_ids']
    found_count = 0
    
    for i, (paper, score) in enumerate(final_ranked[:10], 1):
        is_correct = paper['id'] in relevant_ids
        marker = "✅ GROUND TRUTH" if is_correct else ""
        
        if is_correct:
            found_count += 1
        
        print(f"\n{i}. {paper['title'][:100]}...")
        print(f"   LLM Score: {score:.3f} {marker}")
        print(f"   ID: {paper['id']}")
    
    print("\n" + "="*80)
    print("📊 EVALUATION")
    print("="*80)
    print(f"Ground truth citations found in top 10: {found_count}/{len(relevant_ids)}")
    print(f"Recall@10: {found_count/len(relevant_ids):.2%}")
    
except Exception as e:
    print(f"❌ Error parsing LLM response: {e}")
    print(f"\nRaw response (first 500 chars):\n{response_text[:500]}")

✅ Successfully parsed 20 rankings

🏆 FINAL RANKED RESULTS (Top 10)

1. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding...
   LLM Score: 0.980 
   ID: devlin2019bert

2. Lite Transformer with Long-Short Range Attention...
   LLM Score: 0.950 
   ID: lite_transformer

3. Tinybert: Distilling {BERT} for natural language understanding...
   LLM Score: 0.850 
   ID: jiao2019tinybert

4. TinyBERT: Distilling BERT for Natural Language Understanding...
   LLM Score: 0.800 
   ID: jiao20tinybert

5. Funnel-transformer: Filtering out sequential redundancy for efficient language processing...
   LLM Score: 0.750 
   ID: funnel

6. Distilbert, a distilled version of BERT: smaller, faster, cheaper and lighter...
   LLM Score: 0.700 
   ID: distillbert

7. Q8bert: Quantized 8bit bert...
   LLM Score: 0.650 
   ID: zafrir2019q8bert

8. A survey on visual transformer...
   LLM Score: 0.600 
   ID: han2020survey

9. Compressing large-scale transformer-based models:
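
The parsing and Recall@k logic from this step can be checked on a toy reply without calling an LLM. The sketch below is one possible hardening, under the assumption (taken from the cell above) that the model returns a JSON list of `{"index", "score"}` objects with 1-based indices. The helper names and the sample reply are hypothetical.

```python
import json
import re


def parse_rankings(response_text, num_candidates):
    """Extract (zero_based_index, score) pairs, skipping malformed entries."""
    match = re.search(r"\[[\s\S]*\]", response_text)
    payload = match.group() if match else response_text
    try:
        items = json.loads(payload)
    except json.JSONDecodeError:
        return []
    if not isinstance(items, list):
        return []
    parsed = []
    for item in items:
        if not isinstance(item, dict):
            continue
        try:
            idx = int(item["index"]) - 1  # reply uses 1-based indices
            score = float(item["score"])
        except (KeyError, TypeError, ValueError):
            continue
        if 0 <= idx < num_candidates:
            parsed.append((idx, score))
    return parsed


def recall_at_k(ranked_ids, relevant_ids, k):
    """Fraction of relevant IDs that appear in the first k ranked IDs."""
    if not relevant_ids:
        return 0.0
    return len(set(ranked_ids[:k]) & set(relevant_ids)) / len(relevant_ids)


toy_reply = (
    "Here is the ranking:\n"
    '[{"index": 2, "score": 0.9}, {"index": 1, "score": 0.4}, '
    '{"index": 7, "score": 0.1}]'
)
candidate_ids = ["p1", "p2", "p3"]
pairs = parse_rankings(toy_reply, len(candidate_ids))
ranked_ids = [candidate_ids[i] for i, _ in pairs]
print(ranked_ids, recall_at_k(ranked_ids, {"p2", "p9"}, k=2))
# ['p2', 'p1'] 0.5
```

The out-of-range entry (`index` 7 with only three candidates) is dropped, and one of the two relevant IDs is recovered in the top 2, giving a recall of 0.5.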

## Summary

This notebook demonstrated the complete citation retrieval pipeline:
1. ✅ Loaded citation context from ScholarCopilot dataset
2. ✅ Built BM25, E5, and SPECTER retrieval indices
3. ✅ Retrieved top-k candidates from each method
4. ✅ Aggregated using Reciprocal Rank Fusion
5. ✅ Reranked with LLM (Hugging Face or Ollama)
6. ✅ Evaluated against ground truth

**To run full evaluation:**
```bash
python compare_baselines_vs_system.py --num-examples 500 --use-dspy --llm-reranker --output-dir final --k 20
```