# Aim:
Assess the performance of different RAG methods in financial documents analysis (primary background investigation for consulting, stock purchase) 

# Assessment criteria:
1. relevance
2. length of retrieved context
3. speed
4. cost

# Methods to test:
1. Dense Embeddings
   1.1 parameters
   1.2 finetune embedding model (need GPU machine, too expensive for now)
2. ColBERT
4. Hybrid retriever and rerank
5. Knowledge Augmented Generation (KAG, need to build a domain-specific architecture from sratch)
6. Contextual retrieval preprocessing (use llm to search through all chunks, too expensive)

In [3]:
import pandas as pd
# https://huggingface.co/spaces/mteb/leaderboard
embed_dt = pd.read_csv('../data/tmpsrfsg8rr.csv')
embed_dt.head()

Unnamed: 0.1,Unnamed: 0,Rank (Borda),Model,Zero-shot,Number of Parameters,Embedding Dimensions,Max Tokens,Mean (Task),Mean (TaskType),Bitext Mining,Classification,Clustering,Instruction Retrieval,Multilabel Classification,Pair Classification,Reranking,Retrieval,STS
0,0,1,[Linq-Embed-Mistral](https://huggingface.co/Li...,99,7B,4096,32768,61.47,54.21,70.34,62.24,51.27,0.94,24.77,80.43,64.37,58.69,74.86
1,1,2,[gte-Qwen2-7B-instruct](https://huggingface.co...,-1,7B,3584,32768,62.51,56.0,73.92,61.55,53.36,4.94,25.48,85.13,65.55,60.08,73.98
2,2,3,[multilingual-e5-large-instruct](https://huggi...,99,560M,1024,514,63.23,55.17,80.13,64.94,51.54,-0.4,22.91,80.86,62.61,57.12,76.81
3,3,4,[SFR-Embedding-Mistral](https://huggingface.co...,96,7B,4096,32768,60.93,54.0,70.0,60.02,52.57,0.16,24.55,80.29,64.19,59.44,74.79
4,4,5,[GritLM-7B](https://huggingface.co/GritLM/Grit...,99,7B,4096,4096,60.93,53.83,70.53,61.83,50.48,3.45,22.77,79.94,63.78,58.31,73.33


In [4]:
embed_dt.sort_values(by=['Instruction Retrieval'], ascending=False).head(5)

Unnamed: 0.1,Unnamed: 0,Rank (Borda),Model,Zero-shot,Number of Parameters,Embedding Dimensions,Max Tokens,Mean (Task),Mean (TaskType),Bitext Mining,Classification,Clustering,Instruction Retrieval,Multilabel Classification,Pair Classification,Reranking,Retrieval,STS
32,32,33,[gte-Qwen1.5-7B-instruct](https://huggingface....,-1,7B,4096,32768,,,60.8,,52.98,5.36,23.45,,,,
1,1,2,[gte-Qwen2-7B-instruct](https://huggingface.co...,-1,7B,3584,32768,62.51,56.0,73.92,61.55,53.36,4.94,25.48,85.13,65.55,60.08,73.98
221,221,222,flan-t5-large,100,783M,Unknown,1024,,,,,,4.72,,,,,
4,4,5,[GritLM-7B](https://huggingface.co/GritLM/Grit...,99,7B,4096,4096,60.93,53.83,70.53,61.83,50.48,3.45,22.77,79.94,63.78,58.31,73.33
25,25,26,[NV-Embed-v1](https://huggingface.co/nvidia/NV...,92,7B,4096,32768,54.86,48.39,48.9,57.04,43.36,3.02,18.95,76.19,64.29,53.98,69.77


In [5]:
embed_dt['Number of Parameters'].unique()

array(['7B', '560M', '57B', 'Unknown', '1B', '559M', '494M', '568M',
       '572M', '305M', '278M', '567M', '117M', '118M', '359M', '470M',
       '435M', '471M', '107M', '335M', '109M', '33M', '137M', '434M',
       '108M', '125M', '124M', '22M', '129M', '32M', '19M', '17M', '30M',
       '2B', '103M', '427M', '404M', '15M', '110M', '29M', '7M', '35M',
       '3M', '102M', '11M', '135M', '2M', '162M', '98M', '9B', '66M',
       '3B', '11B', '6B', '8B', '306M', '149M', '281M', '31M', '4B',
       '272M', '74M', '823M', '326M', '783M', '353M', '122M', '248M',
       '24M'], dtype=object)

In [6]:
import re
import numpy as np
embed_dt['Number of Parameters'] = embed_dt['Number of Parameters'].replace('Unknown', '0')
embed_dt['Number of Parameters'] = embed_dt['Number of Parameters'].apply(lambda x: re.search('[\\d]+', x.replace('B', '000'))[0] if 'B' in x else re.search('[\\d]+', x)[0])
embed_dt['Number of Parameters'] = embed_dt['Number of Parameters'].astype(int)

In [5]:
embed_dt[embed_dt['Number of Parameters'].between(1, 100)].sort_values(by=['Instruction Retrieval'], ascending=False).head(5)

Unnamed: 0.1,Unnamed: 0,Rank (Borda),Model,Zero-shot,Number of Parameters,Embedding Dimensions,Max Tokens,Mean (Task),Mean (TaskType),Bitext Mining,Classification,Clustering,Instruction Retrieval,Multilabel Classification,Pair Classification,Reranking,Retrieval,STS
146,146,147,[ternary-weight-embedding](https://huggingface...,-1,98,1024,512,31.19,26.71,12.65,39.03,24.85,0.92,12.14,64.88,32.09,10.63,43.16
127,127,128,[potion-base-4M](https://huggingface.co/minish...,99,3,128,Infinite,37.86,32.29,14.59,43.71,32.9,0.61,11.69,70.8,38.54,27.23,50.57
138,138,139,[potion-base-2M](https://huggingface.co/minish...,99,2,64,Infinite,36.33,31.06,12.19,42.25,31.79,0.59,11.53,70.43,37.38,23.99,49.34
124,124,125,[potion-base-8M](https://huggingface.co/minish...,99,7,256,Infinite,38.6,32.84,16.11,44.5,33.09,0.24,11.63,71.08,39.08,28.63,51.2
130,130,131,[rubert-tiny2](https://huggingface.co/cointegr...,99,29,312,2048,34.88,30.29,22.65,41.17,29.97,-0.09,15.27,69.6,34.58,14.35,45.13


In [37]:
embed_mdls = list(map(lambda x: x.split(']')[0][1:], embed_dt.sort_values(by=['Instruction Retrieval'], ascending=False).head(2)['Model'].values))

In [38]:
embed_mdls
# considering performance on tasks including Instruction Retrieval, Retrieval and Reranking, pick gte-Qwen2-7B-instruct as embedding model. but its's too large.
# minishlab/potion-base-8M

['gte-Qwen1.5-7B-instruct', 'gte-Qwen2-7B-instruct']

In [1]:
import pickle
import contextlib
import pandas as pd

def dump_pickle(file, outdir):
    """pickle a file to output directory
    """
    f = open(outdir,"wb")
    pickle.dump(file, f)
    
def load_pickle(indir):
    """ load a pickle from indir
    """
    f = open(indir,"rb")
    return pickle.load(f)
    
def multiple_strreplace(string, replace_dic):
    for k,v in replace_dic.items():
        string = string.replace(k,v)
    return string

def parse_queries(qa_fp, replace_dic):
    qa = pd.read_csv(qa_fp)
    queries = list(map(lambda query: multiple_strreplace(query, replace_dic), qa['question'].values))
    return queries
    
def log(content, logpath):
      try:
        content = content.replace('<s>[INST] <<SYS>>', 'System role:'),
        content = content.replace('<</SYS>>', ''),
        content = content.replace('[/INST]', '\\n'),
      except:
          pass
      if os.path.exists(logpath):
          with open(logpath, 'a') as f:
              with contextlib.redirect_stdout(f):
                  print(content)
      else:
          with open(logpath, 'w') as f:
              with contextlib.redirect_stdout(f):
                 print(content)

In [2]:
#%pip install ragatouille

In [3]:
#%pip install -qU langchain_community pdfminer.six

## 1. Dense Embeddings

In [5]:
company_name = 'NVDIA'
year1, year2 = 2025, 2024
qa_fp = '../inputs/Q-A.csv'
replace_dic = {'{company_name}':'NVDIA',
              '{year1}':str(year1),
               '{year2}':str(year2)}
querys = parse_queries(qa_fp, replace_dic)
dump_pickle(querys, './querys.pck')

for q in querys:
    print(q)

What kind of products or services is NVDIA providing?
Who are the customers of NVDIA or what types of markets are NVDIA operating in?
Who are the competitors of NVDIA?
What are the risk factors and uncertainties that could affect the NVDIA's future performance?
What is the 2025 revenue of NVDIA?
What is the 2024 revenue of NVDIA?
What is the 2025 total liabilities?
What is the 2025 total shareholders' equity?
What is the 2025 total current assets?
What is the 2025 total current liabilities?
What is the 2025 gross margin?


In [6]:
import os
from dotenv import load_dotenv
#from huggingface_hub import login
from pathlib import Path
dotenv_path = Path('../keys/.env')
load_dotenv(dotenv_path=dotenv_path)
hf_token = os.getenv("HF_TOKEN")
os.environ['HUGGINGFACEHUB_API_TOKEN'] = hf_token

In [7]:
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

In [3]:
import os
os.listdir('../data')

['.DS_Store', 'nvda-20250126.pdf', '.ipynb_checkpoints']

In [5]:
'nvda-20250126.pdf'.endswith('pdf')

True

In [8]:
from langchain.document_loaders import UnstructuredPDFLoader, PyPDFLoader, PDFMinerLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain.vectorstores import Chroma

from langchain.llms import HuggingFaceEndpoint
from langchain.chains import RetrievalQA
from langchain.storage import InMemoryStore
from langchain.retrievers import ParentDocumentRetriever

In [9]:
querys = load_pickle('./querys.pck')

In [10]:
file_path = "../data/nvda-20250126.pdf"
#data = UnstructuredPDFLoader(file_path) version conflicts
#data = PyPDFLoader(file_path)
data = PDFMinerLoader(file_path, mode='page')
content = data.load()
print(len(content), len(content[0].page_content))

118 5167


In [9]:
parent_chunk_size = 800
child_chunk_size = 100

#embedding_model_name = 'minishlab/potion-base-8M'
embedding_model_name = "BAAI/bge-base-en-v1.5"
# https://huggingface.co/ng3owb/finance_embedding_8k to test
#embedding_model_name = "mixedbread-ai/mxbai-embed-large-v1"

In [10]:
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma

embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
)
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=parent_chunk_size,chunk_overlap=0)
child_splitter = RecursiveCharacterTextSplitter(chunk_size=child_chunk_size,chunk_overlap=0)
#texts = ["Hello, world!", "How are you?"]
#embed_test = embeddings.embed_documents(texts)
#import numpy as np
#np.shape(embed_test)
vectorstore = Chroma(embedding_function=embeddings)
store = InMemoryStore()
retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
    
)

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
retriever.add_documents(content,ids=None)

In [12]:
logfp = '../outputs/pdfminer_retrieval_log.txt'
log('', logfp)
log(f'parent_chunk_size:{parent_chunk_size}, child_chunk_size:{child_chunk_size}, embed_model:{embedding_model_name}', logfp)
for query in querys:
    relevant_context = retriever.get_relevant_documents(query)
    log(f"retrieve_instruction:{query}", logfp)
    for d in relevant_context:
        log(f"retrieved_content:{d.page_content}", logfp)
        log('', logfp)

  relevant_context = retriever.get_relevant_documents(query)


## 2. ColBERT 

In [11]:
max_document_length = 512
documents = list(map(lambda x: x.page_content, content))

In [12]:
from ragatouille import RAGPretrainedModel
# https://github.com/AnswerDotAI/RAGatouille
RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")

  from .autonotebook import tqdm as notebook_tqdm


[Mar 08, 02:11:12] Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...




In [13]:
RAG.index(
    collection=documents,
    index_name="NVDIA",
    max_document_length=max_document_length,
    split_documents=True,
)
clbt_retriever = RAG.as_langchain_retriever(k=5)

This is a behaviour change from RAGatouille 0.8.0 onwards.
This works fine for most users and smallish datasets, but can be considerably slower than FAISS and could cause worse results in some situations.
If you're confident with FAISS working on your machine, pass use_faiss=True to revert to the FAISS-using behaviour.
--------------------


[Mar 08, 02:11:13] #> Note: Output directory .ragatouille/colbert/indexes/NVDIA already exists


[Mar 08, 02:11:13] #> Will delete 10 files already at .ragatouille/colbert/indexes/NVDIA in 20 seconds...
[Mar 08, 02:11:34] [0] 		 #> Encoding 274 passages..


100%|█████████████████████████████████████████████| 9/9 [05:52<00:00, 39.17s/it]

[Mar 08, 02:17:27] [0] 		 avg_doclen_est = 261.6788330078125 	 len(local_sample) = 274





[Mar 08, 02:17:27] [0] 		 Creating 4,096 partitions.
[Mar 08, 02:17:27] [0] 		 *Estimated* 71,700 embeddings.
[Mar 08, 02:17:27] [0] 		 #> Saving the indexing plan to .ragatouille/colbert/indexes/NVDIA/plan.json ..
used 20 iterations (73.6211s) to cluster 68115 items into 4096 clusters
[0.026, 0.03, 0.029, 0.027, 0.029, 0.028, 0.027, 0.028, 0.027, 0.028, 0.026, 0.027, 0.028, 0.03, 0.028, 0.028, 0.025, 0.026, 0.028, 0.028, 0.028, 0.028, 0.026, 0.028, 0.026, 0.026, 0.029, 0.027, 0.029, 0.03, 0.027, 0.032, 0.028, 0.026, 0.028, 0.025, 0.028, 0.03, 0.027, 0.032, 0.027, 0.027, 0.027, 0.03, 0.028, 0.026, 0.026, 0.032, 0.032, 0.027, 0.025, 0.028, 0.029, 0.027, 0.026, 0.028, 0.034, 0.029, 0.034, 0.027, 0.026, 0.028, 0.028, 0.03, 0.03, 0.029, 0.031, 0.029, 0.025, 0.028, 0.029, 0.025, 0.027, 0.03, 0.027, 0.029, 0.029, 0.029, 0.027, 0.031, 0.031, 0.029, 0.028, 0.028, 0.028, 0.026, 0.027, 0.029, 0.028, 0.032, 0.029, 0.031, 0.027, 0.029, 0.028, 0.028, 0.031, 0.027, 0.028, 0.028, 0.028, 0.032, 0.028,

0it [00:00, ?it/s]

[Mar 08, 02:18:41] [0] 		 #> Encoding 274 passages..



  0%|                                                     | 0/9 [00:00<?, ?it/s][A
 11%|█████                                        | 1/9 [00:37<04:59, 37.42s/it][A
 22%|██████████                                   | 2/9 [01:22<04:54, 42.03s/it][A
 33%|███████████████                              | 3/9 [02:04<04:11, 41.96s/it][A
 44%|████████████████████                         | 4/9 [02:51<03:40, 44.12s/it][A
 56%|█████████████████████████                    | 5/9 [03:35<02:55, 43.98s/it][A
 67%|██████████████████████████████               | 6/9 [04:16<02:08, 42.89s/it][A
 78%|███████████████████████████████████          | 7/9 [04:54<01:22, 41.32s/it][A
 89%|████████████████████████████████████████     | 8/9 [05:29<00:39, 39.42s/it][A
100%|█████████████████████████████████████████████| 9/9 [05:52<00:00, 39.18s/it][A
1it [05:57, 357.07s/it]
100%|████████████████████████████████████████████| 1/1 [00:00<00:00, 402.10it/s]

[Mar 08, 02:24:38] #> Optimizing IVF to store map from centroids to list of pids..
[Mar 08, 02:24:38] #> Building the emb2pid mapping..
[Mar 08, 02:24:38] len(emb2pid) = 71700



100%|█████████████████████████████████████| 4096/4096 [00:01<00:00, 3339.45it/s]

[Mar 08, 02:24:40] #> Saved optimized IVF to .ragatouille/colbert/indexes/NVDIA/ivf.pid.pt
Done indexing!





In [16]:
logfp = '../outputs/pdfminer_retrieval_log.txt'
log('', logfp)
log(f'max_document_length:{max_document_length}, embed_model:colbertv2.0', logfp)
for query in querys:
    relevant_context = clbt_retriever.get_relevant_documents(query)
    log(f"retrieve_instruction:{query}", logfp)
    for d in relevant_context:
        log(f"retrieved_content:{d.page_content}", logfp)
        log('', logfp)

Loading searcher for index NVDIA for the first time... This may take a few seconds
[Mar 08, 00:51:25] #> Loading codec...
[Mar 08, 00:51:25] #> Loading IVF...
[Mar 08, 00:51:25] Loading segmented_lookup_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...




[Mar 08, 00:51:25] #> Loading doclens...


100%|███████████████████████████████████████████| 1/1 [00:00<00:00, 1187.52it/s]

[Mar 08, 00:51:25] #> Loading codes and residuals...



100%|████████████████████████████████████████████| 1/1 [00:00<00:00, 113.58it/s]

[Mar 08, 00:51:25] Loading filter_pids_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...





[Mar 08, 00:51:25] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
Searcher loaded!

#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==
#> Input: What kind of products or services is NVDIA providing?, 		 True, 		 None
#> Output IDs: torch.Size([32]), tensor([  101,     1,  2054,  2785,  1997,  3688,  2030,  2578,  2003,  1050,
        16872,  2401,  4346,  1029,   102,   103,   103,   103,   103,   103,
          103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
          103,   103])
#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])





## 3. Hybrid retriever + rerank
https://haystack.deepset.ai/blog/hybrid-retrieval
Popular rerank models include Cohere rerank, bge-reranker, among others.
bm25_retriever

https://superlinked.com/vectorhub/articles/optimizing-rag-with-hybrid-search-reranking

In [None]:
pip install langchain langchain-community rank_bm25 pypdf unstructured chromadb

In [18]:
pip install rank_bm25

Collecting rank_bm25
  Using cached rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Using cached rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank_bm25
Successfully installed rank_bm25-0.2.2
Note: you may need to restart the kernel to use updated packages.


In [6]:
chunk_size = 200
chunk_overlap = 30

In [14]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.retrievers import BM25Retriever
from langchain.llms import HuggingFaceHub
import torch
from transformers import ( AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline, )
from langchain import HuggingFacePipeline

from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
import os

chunk_size = 800
chunk_overlap = 30
splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                          chunk_overlap=chunk_overlap)
chunks = splitter.split_documents(content)

In [15]:
keyword_retriever = BM25Retriever.from_documents(chunks)
keyword_retriever.k = 5

ensemble_retriever = EnsembleRetriever(retrievers=[clbt_retriever,
                                                   keyword_retriever],
                                       weights=[0.3, 0.7])

In [17]:
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
rerank_model_name = "BAAI/bge-reranker-base"
rerank_model = HuggingFaceCrossEncoder(model_name=rerank_model_name)
compressor = CrossEncoderReranker(model=rerank_model, top_n=5)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=ensemble_retriever
)

In [11]:
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=vectorstore_retreiver
)

In [18]:
logfp = '../outputs/pdfminer_retrieval_log.txt'
log('', logfp)
log(f'chunk_size:{chunk_size}, chunk_overlap:{chunk_overlap}, max_document_length:{max_document_length}, embed_model:colbertv2.0, keyword_model:BM25Retriever, rerank_model:{rerank_model_name}', logfp)
for query in querys:
    relevant_context = compression_retriever.get_relevant_documents(query)
    log(f"retrieve_instruction:{query}", logfp)
    for d in relevant_context:
        log(f"retrieved_content:{d.page_content}", logfp)
        log('', logfp)

  relevant_context = compression_retriever.get_relevant_documents(query)


Loading searcher for index NVDIA for the first time... This may take a few seconds
[Mar 08, 02:28:16] #> Loading codec...
[Mar 08, 02:28:16] #> Loading IVF...
[Mar 08, 02:28:16] Loading segmented_lookup_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...




[Mar 08, 02:28:16] #> Loading doclens...


100%|████████████████████████████████████████████| 1/1 [00:00<00:00, 567.64it/s]

[Mar 08, 02:28:16] #> Loading codes and residuals...



100%|█████████████████████████████████████████████| 1/1 [00:00<00:00, 46.69it/s]

[Mar 08, 02:28:16] Loading filter_pids_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...





[Mar 08, 02:28:17] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
Searcher loaded!

#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==
#> Input: What kind of products or services is NVDIA providing?, 		 True, 		 None
#> Output IDs: torch.Size([32]), tensor([  101,     1,  2054,  2785,  1997,  3688,  2030,  2578,  2003,  1050,
        16872,  2401,  4346,  1029,   102,   103,   103,   103,   103,   103,
          103,   103,   103,   103,   103,   103,   103,   103,   103,   103,
          103,   103])
#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])



  incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask


In [19]:
parent_chunk_size = 800
child_chunk_size = 100
embedding_model_name = "BAAI/bge-base-en-v1.5"

In [21]:
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=parent_chunk_size,chunk_overlap=0)
child_splitter = RecursiveCharacterTextSplitter(chunk_size=child_chunk_size,chunk_overlap=0)
#texts = ["Hello, world!", "How are you?"]
#embed_test = embeddings.embed_documents(texts)
#import numpy as np
#np.shape(embed_test)
vectorstore = Chroma(embedding_function=embeddings)
store = InMemoryStore()
vectorstore_retreiver = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
    k=5
)
vectorstore_retreiver.add_documents(content,ids=None)

In [22]:
splitter = RecursiveCharacterTextSplitter(chunk_size=parent_chunk_size,chunk_overlap=20)
chunks = splitter.split_documents(content)
keyword_retriever = BM25Retriever.from_documents(chunks)
keyword_retriever.k = 5
ensemble_retriever = EnsembleRetriever(retrievers=[vectorstore_retreiver,
                                                   keyword_retriever],
                                       weights=[0.3, 0.7])

In [23]:
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
rerank_model_name = "BAAI/bge-reranker-base"
rerank_model = HuggingFaceCrossEncoder(model_name=rerank_model_name)
compressor = CrossEncoderReranker(model=rerank_model, top_n=5)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=vectorstore_retreiver
)

In [24]:
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=vectorstore_retreiver
)

In [25]:
logfp = '../outputs/retrieval_log.txt'
log('', logfp)
log(f'parent_chunk_size:{parent_chunk_size}, child_chunk_size:{child_chunk_size}, embed_model:{embedding_model_name}, keyword_model:BM25Retriever, rerank_model:{rerank_model_name}', logfp)
for query in querys:
    relevant_context = compression_retriever.get_relevant_documents(query)
    log(f"retrieve_instruction:{query}", logfp)
    for d in relevant_context:
        log(f"retrieved_content:{d.page_content}", logfp)
        log('', logfp)

## Step-1 Retrieval

In [None]:
os.environ["COHERE_API_KEY"] = "5uuX8mk9dhf9KHzw7vSDhQdXlV2x92MzELvJ972T"

In [None]:
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CohereRerank

In [None]:
from cohere import Client

In [None]:
co = Client(api_key = "5uuX8mk9dhf9KHzw7vSDhQdXlV2x92MzELvJ972T")

In [None]:
from typing import ForwardRef
from pydantic import BaseModel

class CustomCohereRerank(CohereRerank):
  class Config(BaseModel.Config):
    arbitrary_types_allowed = True

CustomCohereRerank.update_forward_refs()

In [None]:
compressor = CustomCohereRerank(client=co)

In [None]:
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever
)

## Step - 2 Augment

In [7]:
from langchain_core.prompts import ChatPromptTemplate

In [19]:
template = """
<|system|>>
You are an AI Assistant that follows instructions extremely well.
Please be truthful and give direct answers. Please tell 'I don't know' if user query is not in CONTEXT

CONTEXT: {context}
</s>
<|user|>
{query}
</s>
<|assistant|>
"""

In [11]:
template = """
<|system|>>
You are an AI Assistant that follows instructions extremely well.
Please be truthful and give direct answers. Please tell 'I don't know' if user query is not in CONTEXT

</s>
<|user|>
{query}
</s>
<|assistant|>
"""
query = "What is the revenue of nvdia in 2023?"

In [12]:
prompt = ChatPromptTemplate.from_template(template)

In [None]:
llm_chain = prompt | llm
print(llm_chain.invoke({"query": query}))

## Step-3 Generation

In [None]:
from langchain_core.output_parsers import StrOutputParser

In [None]:
from langchain_core.runnables import RunnablePassthrough

In [None]:
output_parser = StrOutputParser()

In [None]:
chain = (
    {"context": retriever, "query": RunnablePassthrough()}
    | prompt
    | model
    | output_parser
)

In [None]:
query = "Who is Rahul?"

In [None]:
response = chain.invoke(query)

In [None]:
print(response)

I do not have information about a specific person named rahul. please provide more context or information about rahul to help me identify who you are referring to.


In [None]:
print(chain.invoke("what is Tarun's role at AI Planet?"))

Tarun's role at AI Planet is "Developer Relations and Community Manager." (from the provided context)
