<a href="https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/docs/examples/retrievers/bm25_retriever.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BM25 Retriever
In this guide, we define a bm25 retriever that search documents using bm25 method.

This notebook is very similar to the RouterQueryEngine notebook.

## Setup

If you're opening this Notebook on colab, you will probably need to install LlamaIndex 🦙.

In [1]:
# NOTE: This is ONLY necessary in jupyter notebook.
# Details: Jupyter runs an event-loop behind the scenes.
#          This results in nested event-loops when we start an event-loop to make async queries.
#          This is normally not allowed, we use nest_asyncio to allow it for convenience.
import nest_asyncio

nest_asyncio.apply()

In [2]:
import os
import openai

os.environ["OPENAI_API_KEY"] = "sk-..."
openai.api_key = os.environ["OPENAI_API_KEY"]

In [5]:
import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().handlers = []
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

from llama_index.core import (
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
)
from llama_index.legacy.retrievers.bm25_retriever import BM25Retriever
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.node_parser import SentenceSplitter
from llama_index.llms.openai import OpenAI

from llama_index.core.response.notebook_utils import display_source_node
from llama_index.core.retrievers import RouterRetriever
from llama_index.core.tools import RetrieverTool


## Load Data

We first show how to convert a Document into a set of Nodes, and insert into a DocumentStore.

In [6]:
# load documents
documents = SimpleDirectoryReader("../data/paul_graham").load_data()

In [7]:
# initialize LLM + node parser
llm = OpenAI(model="zhipu")
splitter = SentenceSplitter(chunk_size=1024)

nodes = splitter.get_nodes_from_documents(documents)

In [8]:
# initialize storage context (by default it's in-memory)
storage_context = StorageContext.from_defaults()
storage_context.docstore.add_documents(nodes)

In [None]:
index = VectorStoreIndex(
    nodes=nodes,
    storage_context=storage_context,
)

HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


## BM25 Retriever

We will search document with bm25 retriever.

In [11]:
# We can pass in the index, doctore, or list of nodes to create the retriever
retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)

In [26]:
from llama_index.legacy.schema import NodeWithScore, QueryBundle, Node

# class JieRetriever(BM25Retriever):
#     async def _get_scored_nodes(self, query: str):
#         tokenized_query = self._tokenizer(query)
#         doc_scores = self.bm25.get_scores(tokenized_query)
#         nodes = []
#         for i, node in enumerate(self._nodes):
#             node_new = Node.from_dict(node.to_dict())
#             node_with_score = NodeWithScore(node=node_new, score=doc_scores[i])
#             nodes.append(node_with_score)
#         return nodes


In [1]:
class AdvancedRAG:
    def __init__(self):
        _ = load_dotenv(find_dotenv())
        # load documents
        self.documents = SimpleDirectoryReader("./data/", required_exts=['.pdf']).load_data()

        # global variables used later in code after initialization
        self.retriever = None
        self.reranker = None
        self.query_engine = None

        self.bootstrap()

    def bootstrap(self):
        # initialize LLMs
        llm = OpenAI(model="gpt-4",
                    api_key=os.getenv("OPENAI_API_KEY"),
                    temperature=0,
                    system_prompt="You are an expert on the Inflamatory Bowel Diseases and your job is to answer questions. Assume that all questions are related to the Inflammatory Bowel Diseases (IBD). Keep your answers technical and based on facts – do not hallucinate features.")

        # initialize service context (set chunk size)
        service_context = ServiceContext.from_defaults(chunk_size=1024, llm=llm)
        nodes = service_context.node_parser.get_nodes_from_documents(self.documents)

        # initialize storage context (by default it's in-memory)
        storage_context = StorageContext.from_defaults()
        storage_context.docstore.add_documents(nodes)

        index = VectorStoreIndex(
            nodes=nodes,
            storage_context=storage_context,
            service_context=service_context,
        )

        # We can pass in the index, doctore, or list of nodes to create the retriever
        self.retriever = BM25Retriever.from_defaults(similarity_top_k=2, index=index)

        # reranker setup & initialization
        self.reranker = SentenceTransformerRerank(top_n=1, model="BAAI/bge-reranker-base")

        self.query_engine = RetrieverQueryEngine.from_args(
            retriever=self.retriever,
            node_postprocessors=[self.reranker],
            service_context=service_context,
        )

    def query(self, query):
        # will retrieve context from specific companies
        nodes = self.retriever.retrieve(query)
        reranked_nodes = self.reranker.postprocess_nodes(
            nodes,
            query_bundle=QueryBundle(query_str=query)
        )

        print("Initial retrieval: ", len(nodes), " nodes")
        print("Re-ranked retrieval: ", len(reranked_nodes), " nodes")

        for node in nodes:
            print(node)

        for node in reranked_nodes:
            print(node)

        response = self.query_engine.query(str_or_query_bundle=query)
        return response


In [21]:
retriever.retrieve("What is the best way to learn Python?")

ValidationError: 1 validation error for NodeWithScore
node
  Can't instantiate abstract class BaseNode with abstract methods get_content, get_metadata_str, get_type, hash, set_content (type=type_error)

In [18]:
from llama_index.core.response.notebook_utils import display_source_node
# from llama_index.data_structs.node_v2 import BaseNode
# will retrieve context from specific companies
nodes = await retriever.aretrieve("What happened at Viaweb and Interleaf?")
for node in nodes:
    display_source_node(node)

ValidationError: 1 validation error for NodeWithScore
node
  Can't instantiate abstract class BaseNode with abstract methods get_content, get_metadata_str, get_type, hash, set_content (type=type_error)

In [None]:
nodes = retriever.retrieve("What did Paul Graham do after RISD?")
for node in nodes:
    display_source_node(node)

**Node ID:** 42c88008-dd05-4590-8fd1-df85567579ea<br>**Similarity:** 5.389397745654172<br>**Text:** It was missing a lot of things you'd want in a programming language. So these had to be added, an...<br>

**Node ID:** 4d98c2ad-60cc-4cc0-abe2-b95c0c4be87b<br>**Similarity:** 1.141523170922594<br>**Text:** Painting students were supposed to express themselves, which to the more worldly ones meant to tr...<br>

## Router Retriever with bm25 method

Now we will combine bm25 retriever with vector index retriever.

In [None]:
from llama_index.core.tools import RetrieverTool

vector_retriever = VectorIndexRetriever(index)
bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)

retriever_tools = [
    RetrieverTool.from_defaults(
        retriever=vector_retriever,
        description="Useful in most cases",
    ),
    RetrieverTool.from_defaults(
        retriever=bm25_retriever,
        description="Useful if searching about specific information",
    ),
]

In [None]:
from llama_index.core.retrievers import RouterRetriever

retriever = RouterRetriever.from_defaults(
    retriever_tools=retriever_tools,
    llm=llm,
    select_multi=True,
)

In [None]:
# will retrieve all context from the author's life
nodes = retriever.retrieve(
    "Can you give me all the context regarding the author's life?"
)
for node in nodes:
    display_source_node(node)

HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
Selecting retriever 0: The author's life context is a broad topic and can be useful in most cases..
HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"


**Node ID:** 9afb8cb9-42f3-4160-807c-1cd6685fa774<br>**Similarity:** 0.7961776744788842<br>**Text:** Now that I could write essays again, I wrote a bunch about topics I'd had stacked up. I kept writ...<br>

**Node ID:** b1ffb78d-dc4c-439b-906a-cca73b713940<br>**Similarity:** 0.7924813308773564<br>**Text:** We actually had one of those little stoves, fed with kindling, that you see in 19th century studi...<br>

## Advanced - Hybrid Retriever + Re-Ranking

Here we extend the base retriever class and create a custom retriever that always uses the vector retriever and BM25 retreiver.

Then, nodes can be re-ranked and filtered. This lets us keep intermediate top-k values large and letting the re-ranking filter out un-needed nodes.

To best demonstrate this, we will use a larger set of source documents -- Chapter 3 from the 2022 IPCC Climate Report.

### Setup data

In [None]:
!curl https://www.ipcc.ch/report/ar6/wg2/downloads/report/IPCC_AR6_WGII_Chapter03.pdf --output IPCC_AR6_WGII_Chapter03.pdf

In [None]:
# !pip install pypdf

In [None]:
from llama_index.core import (
    VectorStoreIndex,
    StorageContext,
    SimpleDirectoryReader,
    Document,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.llms.openai import OpenAI

# load documents
documents = SimpleDirectoryReader(
    input_files=["IPCC_AR6_WGII_Chapter03.pdf"]
).load_data()

In [None]:
# initialize llm + node parser
# -- here, we set a smaller chunk size, to allow for more effective re-ranking
llm = OpenAI(model="gpt-3.5-turbo")
splitter = SentenceSplitter(chunk_size=256)
# limit to a smaller section
nodes = splitter.get_nodes_from_documents(
    [Document(text=documents[0].get_content()[:1000000])]
)

In [None]:
# initialize storage context (by default it's in-memory)
storage_context = StorageContext.from_defaults()
storage_context.docstore.add_documents(nodes)

In [None]:
index = VectorStoreIndex(nodes, storage_context=storage_context)

In [None]:
from llama_index.retrievers.bm25 import BM25Retriever

# retireve the top 10 most similar nodes using embeddings
vector_retriever = index.as_retriever(similarity_top_k=10)

# retireve the top 10 most similar nodes using bm25
bm25_retriever = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=10)

### Custom Retriever Implementation

In [None]:
from llama_index.core.retrievers import BaseRetriever


class HybridRetriever(BaseRetriever):
    def __init__(self, vector_retriever, bm25_retriever):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        super().__init__()

    def _retrieve(self, query, **kwargs):
        bm25_nodes = self.bm25_retriever.retrieve(query, **kwargs)
        vector_nodes = self.vector_retriever.retrieve(query, **kwargs)

        # combine the two lists of nodes
        all_nodes = []
        node_ids = set()
        for n in bm25_nodes + vector_nodes:
            if n.node.node_id not in node_ids:
                all_nodes.append(n)
                node_ids.add(n.node.node_id)
        return all_nodes

In [None]:
index.as_retriever(similarity_top_k=5)

hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever)

### Re-Ranker Setup

In [None]:
!pip install sentence-transformers

In [None]:
from llama_index.core.postprocessor import SentenceTransformerRerank

reranker = SentenceTransformerRerank(top_n=4, model="BAAI/bge-reranker-base")

### Retrieve

In [None]:
from llama_index.core import QueryBundle

retrieved_nodes = hybrid_retriever.retrieve(
    "What is the impact of climate change on the ocean?"
)
reranked_nodes = reranker.postprocess_nodes(
    nodes,
    query_bundle=QueryBundle(
        "What is the impact of climate change on the ocean?"
    ),
)

print("Initial retrieval: ", len(retrieved_nodes), " nodes")
print("Re-ranked retrieval: ", len(reranked_nodes), " nodes")

In [None]:
from llama_index.core.response.notebook_utils import display_source_node

for node in reranked_nodes:
    display_source_node(node)

**Node ID:** 735c563d-e68c-42e7-bbeb-d360a2918d22<br>**Similarity:** 0.6910567283630371<br>**Text:** lb88\8a|hiac2:,sg
`HfvyȔskz4
4+)t$EAs9<"L䴺-iai]}?R)Jv8Q0V#9>f=3`߼ב&WTFKNӅ9GN8v`4׏g
tz...<br>

**Node ID:** 07a796db-8307-47a0-83a0-9b5f953b86b0<br>**Similarity:** 0.609274685382843<br>**Text:** {_d3<,fCȀ
0K0~n
(A	̙$saxt
 :
xY3z/ߕXAWwpTeHY0HZHe̚kÇ>.82%ϖ^.rCS.2^12C?<br>

**Node ID:** 0ee98fbd-d015-4cd4-aba0-812423880915<br>**Similarity:** 0.5765283107757568<br>**Text:** X+V/Z

6;O8/%Z3EC\asU1f(xLͼ]X\q"1}.66OKGSi9ǧ'(?Iʲi=4ޮ^m,Mp1~lBxP3]S<?C)3WM;WGQt@l...<br>

**Node ID:** 2562bbfb-fda2-432d-9674-5ba77ff4b767<br>**Similarity:** 0.3882860541343689<br>**Text:** PxS0s.MAK|+MOUÊ^;tc0
$XHmL<ҁ;Q@nUd&ի
2`B01i^yOw:rsM冫S*wu4a{-.Խ=j鷩nQD_A޾A%~T>~'OxU	gk...<br>

### Full Query Engine 

In [None]:
from llama_index.core.query_engine import RetrieverQueryEngine

query_engine = RetrieverQueryEngine.from_args(
    retriever=hybrid_retriever,
    node_postprocessors=[reranker],
    llm=llm,
)

response = query_engine.query(
    "What is the impact of climate change on the ocean?"
)

In [None]:
from llama_index.core.response.notebook_utils import display_response

display_response(response)

**`Final Response:`** Climate change can have a significant impact on the ocean. Rising temperatures can lead to the melting of polar ice caps, resulting in sea-level rise and coastal flooding. It can also disrupt ocean currents and affect marine ecosystems, including coral reefs and fish populations. Additionally, climate change can lead to ocean acidification, which can harm marine life and impact the overall health of the ocean.