# Demo: Entity-Aware Indexing

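This notebook demonstrates how to use `rag.ner.BiomedicalNER` to extract entities from text, enrich documents with those entities, and run entity-filtered retrieval over the indexed documents.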

In [None]:
%load_ext autoreload
%autoreload 2

from datasets import load_dataset
from langchain_core.documents import Document

from rag.ner import BiomedicalNER
from rag.retrieval import HybridRetriever, HybridConfig

## 1. Entity Extraction

In [None]:
# Extractor instance; later cells reuse `ner` to enrich the corpus.
ner = BiomedicalNER()

# Sample sentence to run through the extractor.
text = (
    "BRCA1 and BRCA2 mutations significantly increase breast cancer risk. "
    "Tamoxifen is commonly used for treatment."
)
result = ner.extract(text)

# Flat listing: one "label: text" line per extracted entity.
print("Entities found:")
for label, mention in ((found.label, found.text) for found in result.entities):
    print(f"  {label}: {mention}")

# Same result, rendered through the result object's own dict view.
print("\nGrouped:")
grouped = result.to_dict()
print(grouped)

## 2. Load & Enrich Documents

In [None]:
# Load the passage corpus and keep only a small subset so the demo stays fast.
DEMO_SUBSET_SIZE = 1000  # upper bound on the number of passages kept

corpus = load_dataset("rag-datasets/rag-mini-bioasq", "text-corpus", split="passages")

# Drop rows whose passage is empty/None or the literal string "nan".
corpus = corpus.filter(lambda x: x["passage"] and x["passage"] != "nan")

# Clamp the subset to the rows that survived filtering: asking `select` for
# indices past the end of the dataset would fail if fewer rows remain.
corpus = corpus.select(range(min(DEMO_SUBSET_SIZE, len(corpus))))

# Wrap each passage as a Document, carrying the dataset row id as `doc_id`.
docs = [
    Document(page_content=row["passage"], metadata={"doc_id": row["id"]})
    for row in corpus
]
print(f"Loaded {len(docs)} documents")

In [None]:
# Run the NER enrichment over the whole list and rebind `docs` to the result,
# which the indexing and statistics cells read.
enriched = ner.enrich_with_filter_fields(docs)
docs = enriched

# Show the metadata of the first document as a sanity check.
print("Example document metadata:")
example = docs[0]
print(example.metadata)

## 3. Index with Entities

In [None]:
# Name of the collection the demo writes to.
collection = "bioasq-entities-demo"

config = HybridConfig(collection_name=collection)
retriever = HybridRetriever(config)

# Index the enriched documents.
# NOTE(review): force_recreate=True presumably rebuilds the collection from
# scratch on every run - confirm against HybridRetriever.index.
retriever.index(docs, force_recreate=True)

## 4. Entity-Filtered Search

In [None]:
def _preview(hit):
    """Return "[doc_id] <first 100 characters of the passage>..." for one result."""
    return f"[{hit.metadata['doc_id']}] {hit.page_content[:100]}..."


query = "What causes cancer?"

# Baseline: top-3 results with no entity constraint.
print("=== Standard Search ===")
for hit in retriever.search(query, k=3):
    print(_preview(hit))

# Same query, constrained through the `entity_GENE` filter field.
# NOTE(review): presumably keeps only documents tagged with one of these
# genes - confirm against HybridRetriever.search_with_filter.
gene_filter = {"entity_GENE": ["p53", "BRCA1", "BRCA2"]}

print("\n=== Filtered by GENE entities ===")
for hit in retriever.search_with_filter(query, k=3, entity_filter=gene_filter):
    print(_preview(hit))
    print(f"  Entities: {hit.metadata.get('entities', {})}")

## 5. Entity Statistics

In [None]:
from collections import Counter

# Total number of entity mentions per entity type, summed over all documents.
entity_counts = Counter()
for doc in docs:
    # Documents without an "entities" entry contribute nothing.
    per_doc = doc.metadata.get("entities", {})
    # Counter.update with a mapping adds the given counts to the running totals.
    entity_counts.update({kind: len(mentions) for kind, mentions in per_doc.items()})

# Report types from most to least frequent.
print("Entity type distribution:")
for kind, total in entity_counts.most_common():
    print(f"  {kind}: {total}")