# Ensemble retriever

https://python.langchain.com/v0.1/docs/modules/data_connection/retrievers/ensemble/

Note:
* The code uses the LangChain BM25Retriever, which depends on the rank_bm25 package

In [1]:
# rank_bm25 must be installed (required by BM25Retriever)
# !pip install --upgrade --quiet  rank_bm25

In [2]:
# https://api.python.langchain.com/en/latest/documents/langchain_core.documents.base.Document.html
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.retrievers import EnsembleRetriever

## 1. Set up a test corpus

Each document has metadata with a "source" attribute that identifies the document it came from. For simplicity and ease of understanding, the corpus consists of a set of short sentences.

In [3]:
## 1. Create a test corpus
corpus = [
    "RAG addresses hallucinations",
    "Symptoms are hallucinations",
    "RAG is easier than fine tuning",
    "Use a RAG to clean it",
    "Retrieval Augmented Generation"
]

corpus_docs = []

# Add metadata
for i, dat in enumerate(corpus):
    document = Document(
        page_content=dat,
        metadata={"source": "doc-" + str(i)},
    )
    corpus_docs.append(document)

# Print corpus
corpus_docs

[Document(page_content='RAG addresses hallucinations', metadata={'source': 'doc-0'}),
 Document(page_content='Symptoms are hallucinations', metadata={'source': 'doc-1'}),
 Document(page_content='RAG is easier than fine tuning', metadata={'source': 'doc-2'}),
 Document(page_content='Use a RAG to clean it', metadata={'source': 'doc-3'}),
 Document(page_content='Retrieval Augmented Generation', metadata={'source': 'doc-4'})]

## 2. Set up the BM25 retriever

In [4]:
# Create the BM25 Retriever
bm25_retriever = BM25Retriever.from_documents(corpus_docs, k=3)
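BM25 ranks documents by term overlap with the query. As far as I know, LangChain's BM25Retriever tokenizes with a plain whitespace split by default (no lowercasing) and delegates the scoring to rank_bm25. The sketch below is a plain-Python Okapi BM25 with a smoothed IDF and the common defaults k1 = 1.5 and b = 0.75; these are assumptions, and rank_bm25's IDF formula differs slightly, so absolute scores will not match the library. It does show which documents can score at all: for the query "rag is cheaper", "rag" does not match "RAG", so "is" is the only shared token and only doc-2 gets a non-zero score. The other entries in the BM25 output of section 5 (doc-4, doc-3) are then most likely zero-score ties filling the k=3 slots. The helper names `bm25_scores` and `lower_split` are hypothetical.

```python
import math
from collections import Counter

# Same five sentences as the notebook's corpus (doc-0 ... doc-4)
CORPUS = [
    "RAG addresses hallucinations",
    "Symptoms are hallucinations",
    "RAG is easier than fine tuning",
    "Use a RAG to clean it",
    "Retrieval Augmented Generation",
]


def bm25_scores(query, docs, tokenize=str.split, k1=1.5, b=0.75):
    """Score every document against the query with textbook Okapi BM25."""
    tokenized = [tokenize(d) for d in docs]
    n_docs = len(tokenized)
    avgdl = sum(len(t) for t in tokenized) / n_docs
    # Document frequency: in how many documents each term appears
    df = Counter(term for toks in tokenized for term in set(toks))
    scores = []
    for toks in tokenized:
        tf = Counter(toks)
        score = 0.0
        for term in tokenize(query):
            if term not in tf:
                continue  # a query term absent from the document adds nothing
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            denom = tf[term] + k1 * (1 - b + b * len(toks) / avgdl)
            score += idf * tf[term] * (k1 + 1) / denom
        scores.append(score)
    return scores


def lower_split(text):
    # Hypothetical alternative tokenizer: lowercase, then split on whitespace
    return text.lower().split()


case_sensitive = bm25_scores("rag is cheaper", CORPUS)
case_insensitive = bm25_scores("rag is cheaper", CORPUS, tokenize=lower_split)
print([round(s, 3) for s in case_sensitive])
print([round(s, 3) for s in case_insensitive])
```

With lowercasing, doc-0 and doc-3 also match through "rag". If that behavior is wanted in LangChain, from_documents accepts a preprocess_func argument, for example `BM25Retriever.from_documents(corpus_docs, k=3, preprocess_func=lambda t: t.lower().split())`; this would change the BM25 results shown in section 5.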

## 3. Set up a vector store retriever

Set up the vector store with the same set of documents that was used for the BM25 retriever.

In [5]:
# Create a Chroma vector store and add the corpus documents to it
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
vector_store = Chroma(collection_name="full_documents", embedding_function=embedding_function)
vector_store.add_documents(corpus_docs)

# https://api.python.langchain.com/en/latest/vectorstores/langchain_community.vectorstores.chroma.Chroma.html#langchain_community.vectorstores.chroma.Chroma.as_retriever
chromadb_retriever = vector_store.as_retriever(search_kwargs={"k": 3})
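With search_kwargs={"k": 3}, the retriever returns the 3 documents whose embeddings are closest to the query embedding. The sketch below illustrates only that top-k step, using made-up 3-dimensional vectors (hypothetical values, not model output; real all-MiniLM-L6-v2 embeddings have 384 dimensions). It ranks by cosine similarity; Chroma collections, as far as I know, default to squared L2 distance, which yields the same order when the vectors have unit length. The helpers `cosine` and `top_k` are hypothetical.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def top_k(query_vec, doc_vecs, k=3):
    """Return (doc_id, similarity) pairs for the k most similar documents."""
    scored = [(doc_id, cosine(query_vec, vec)) for doc_id, vec in doc_vecs.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Hypothetical toy embeddings, NOT produced by a real model
doc_vecs = {
    "doc-0": [0.9, 0.1, 0.0],
    "doc-1": [0.1, 0.9, 0.1],
    "doc-2": [0.8, 0.0, 0.6],
    "doc-3": [0.7, 0.3, 0.2],
    "doc-4": [0.2, 0.2, 0.9],
}
query_vec = [0.8, 0.1, 0.5]

for doc_id, sim in top_k(query_vec, doc_vecs):
    print(doc_id, round(sim, 3))
```

Unlike BM25, this ranking needs no shared tokens: any document can score, which is why the vector retriever can surface semantically related sentences that BM25 misses.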

## 4. Create the Ensemble retriever

https://api.python.langchain.com/en/latest/retrievers/langchain.retrievers.ensemble.EnsembleRetriever.html

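* retrievers: the list of retrievers to combine
* weights: one weight per retriever (by default, all retrievers get equal weights)
* id_key: the metadata attribute that holds the document's identity, used to merge duplicates returned by several retrievers (if omitted, page_content is used)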
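EnsembleRetriever merges the ranked lists with weighted Reciprocal Rank Fusion. As I understand the LangChain implementation, every retriever $i$ with weight $w_i$ adds $w_i / (\text{rank}_i(d) + c)$ to each document $d$ it returns, where ranks start at 1 and the constant $c$ defaults to 60; documents are matched across retrievers by id_key and sorted by the summed score:

$$
\text{score}(d) = \sum_{i \,:\, d \in R_i} \frac{w_i}{\text{rank}_i(d) + c}
$$

Here $R_i$ is the list returned by retriever $i$; a retriever that does not return $d$ contributes nothing to its score.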

In [6]:
retrievers = [bm25_retriever, chromadb_retriever]
retriever_weights = [0.4, 0.6]

ensemble_retriever = EnsembleRetriever(
    retrievers = retrievers,
    weights = retriever_weights,
    id_key = "source"
)

## 5. Test

In [7]:
# Utility function to print the list of ranked documents
def dump_doc_source(result_documents):
    for doc in result_documents:
        print(doc.metadata["source"])
    print("\n")

In [8]:
# Test queries
queries = ["rag is cheaper",
           "benefits of rag",
           "piece of cloth"
          ]

# change the query index for testing
ndx = 0
print("Input: ", queries[ndx], "\n")

# Dump the ranked list for BM25
print("BM25")
print("----")
results_bm25 = bm25_retriever.invoke(queries[ndx])
dump_doc_source(results_bm25)

# Dump the ranked list for ChromaDB
print("ChromaDB")
print("--------")
results_chromadb = chromadb_retriever.invoke(queries[ndx])
dump_doc_source(results_chromadb)

print("Ensemble Retriever")
print("------------------")
results = ensemble_retriever.invoke(queries[ndx])

dump_doc_source(results)

results

Input:  rag is cheaper 

BM25
----
doc-2
doc-4
doc-3


ChromaDB
--------
doc-2
doc-3
doc-0


Ensemble Retriever
------------------
doc-2
doc-3
doc-0
doc-4




[Document(page_content='RAG is easier than fine tuning', metadata={'source': 'doc-2'}),
 Document(page_content='Use a RAG to clean it', metadata={'source': 'doc-3'}),
 Document(page_content='RAG addresses hallucinations', metadata={'source': 'doc-0'}),
 Document(page_content='Retrieval Augmented Generation', metadata={'source': 'doc-4'})]
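To see where the ensemble order comes from, the sketch below recomputes it from the two ranked lists above. It assumes weighted Reciprocal Rank Fusion with ranks starting at 1 and the constant c = 60, which is how I understand EnsembleRetriever to score documents; the function name `weighted_rrf` is hypothetical, not a LangChain API. It works on the "source" ids rather than Document objects, mirroring id_key="source".

```python
from collections import defaultdict


def weighted_rrf(ranked_lists, weights, c=60):
    """Fuse ranked lists of document ids with weighted Reciprocal Rank Fusion."""
    scores = defaultdict(float)
    for doc_ids, weight in zip(ranked_lists, weights):
        for rank, doc_id in enumerate(doc_ids, start=1):  # ranks start at 1
            scores[doc_id] += weight / (rank + c)
    # Highest fused score first; an id returned by both retrievers accumulates both terms
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


# Ranked "source" ids copied from the BM25 and ChromaDB output above
bm25_ranking = ["doc-2", "doc-4", "doc-3"]
chromadb_ranking = ["doc-2", "doc-3", "doc-0"]

fused = weighted_rrf([bm25_ranking, chromadb_ranking], weights=[0.4, 0.6])
for doc_id, score in fused:
    print(doc_id, round(score, 5))
```

The fused order is doc-2, doc-3, doc-0, doc-4, matching the Ensemble Retriever output. doc-3 lands second even though neither retriever ranks it first, because it collects a contribution from both lists (0.4/63 + 0.6/62), while doc-0 and doc-4 are each returned by only one retriever.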