Load Env

In [205]:
from dotenv import load_dotenv

load_dotenv()

True

Get Embeddings

In [206]:
from langchain_openai import OpenAIEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
embedding = OpenAIEmbeddings(model="text-embedding-3-large", dimensions=256)


Create Retriever

In [207]:
import chromadb
from langchain_chroma import Chroma

chroma_client = chromadb.HttpClient(host="localhost", port=8000)
vector_store = Chroma(
    collection_name="chat_history_test",
    client=chroma_client,
    embedding_function=embedding,
)

top_k = 10
retriever = vector_store.as_retriever(
        search_type="similarity",
        search_kwargs={"k": top_k},
    )

Create Reranker

In [208]:
from pydantic import BaseModel, Field, HttpUrl
from typing import List, Optional
from langchain_core.documents import Document
import requests as request

class CustomReranker(BaseModel):
    """
    Custom Reranker class to rerank documents based on a query using an external API.
    """

    url: HttpUrl = Field(..., description="API endpoint URL for the reranker service")
    model: str = Field(..., description="Model identifier/name to use for reranking")
    api_key: Optional[str] = Field(
        None, description="API key for authentication, if required"
    )

    def rerank(self, query, documents):
        """
        Rerank the documents based on the provided query.
        Args:
            query (str): The query to use for reranking the documents.
            documents (List[Document]): The list of documents to rerank.
        Returns:
            List[Document]: The reranked list of documents.
        """
        API_HEADER = {
            "Authorization": f"Bearer {self.api_key}" if self.api_key else "",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "query": query,
            "documents": [doc.page_content for doc in documents],
        }
        response = request.post(url=self.url, headers=API_HEADER, json=payload)
        if response.status_code != 200:
            raise ValueError(f"Reranker API error: {response.text}")
        final_results =  response.json()
        print("Raw API Response:", final_results)

        if "results" in final_results:
            final_results = final_results["results"]
        elif "data" in final_results:
            final_results = final_results["data"]
        else:
            raise ValueError("Reranker API did not return 'results' or 'data' key")
        
        sorted_results = sorted(final_results, key=lambda x: x["relevance_score"], reverse=True)

        reranked_docs = [documents[result["index"]] for result in sorted_results]
        return reranked_docs

In [209]:
custom_reranker = CustomReranker(
    url="https://api.cohere.com/v2/rerank",
    model="rerank-v3.5",
    api_key="z7qJ9yO47aI6pqH9vqFWNjKK8GHQLglIE7QT5koV",  # Set your API key if required
)

Define LLM

In [210]:
from helpers.llm_integrations import get_llm

llm = get_llm(model="gpt-4o")

Create Contextualize Chain

In [211]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import (
    Runnable,
    RunnablePassthrough,
    RunnableParallel,
    chain,
)
from operator import itemgetter

contextualize_instructions = """Convert the latest user question into a standalone question given the chat history. Don't answer the question, return the question and nothing else (no descriptive text)."""
contextualize_prompt = ChatPromptTemplate.from_messages(
      [
          ("system", contextualize_instructions),
          ("placeholder", "{chat_history}"),
          ("human", "{question}"),
      ]
  )
contextualize_question = contextualize_prompt | llm | StrOutputParser()

@chain
def contextualize_if_needed(input_: dict) -> Runnable:
    if input_.get("chat_history"):
        return contextualize_question
    else:
        return RunnablePassthrough() | itemgetter("question")

Create QA Chain

In [212]:
instruction = "Answer the questions using the given context."

qa_instructions = instruction + """\n\n{context}."""
qa_prompt = ChatPromptTemplate.from_messages(
  [("system", qa_instructions), ("human", "{question}")]
)

def format_docs(docs):
    return "".join(doc.page_content for doc in docs)

formatted_prompt = {
    "question": itemgetter("question") | RunnablePassthrough(),
    "context": lambda x: format_docs(x["context"]),
} | RunnableParallel(prompt=qa_prompt, question=itemgetter("question"))

qa_chain = formatted_prompt | RunnableParallel(
    llm_result=itemgetter("prompt") | llm,
    question=itemgetter("question"),
    )

Create Retrieval Chain (Pass Custom Reranker Instead of Retriever)

In [None]:
from langchain_core.runnables import RunnableLambda

def rerank_docs(input: str) -> List[Document]:
    results = retriever.invoke(input)
    results = custom_reranker.rerank(input, results)
    return results

retrieve_docs_chain = itemgetter("question") | RunnableLambda(rerank_docs)
retriever = itemgetter("question") | retriever 

Token Usage Callback

In [214]:
from langchain.callbacks.base import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from typing import Any


class LLMResultHandler(BaseCallbackHandler):
    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        if response.generations[0][0].message.usage_metadata:
            token_usage = response.generations[0][0].message.usage_metadata
        else:
            usage = response.generations[0][0].message.response_metadata["token_usage"]
            token_usage = {
                "input_tokens": usage.prompt_tokens,
                "output_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            }
        self.response = token_usage

llm_result_handler = LLMResultHandler()

Langfuse Callback

In [215]:
from langfuse.callback import CallbackHandler

langfuse_args = {}
langfuse_handler = (
        CallbackHandler(
          **langfuse_args
        )
    )

Create Final Chain (Contextualize -> Retrieval -> Q&A)

In [216]:
final_chain = (
        RunnablePassthrough.assign(question=contextualize_if_needed)
        .assign(context=retrieve_docs_chain)
        .assign(answer=qa_chain)
    )

Invoke Chain

In [217]:
input = "what are the interest rates of fixed deposit?"
result = final_chain.invoke(
        {"question": input, "chat_history": []},
        config={
            "callbacks": [llm_result_handler, langfuse_handler]
        },
    )

answer = result["answer"]
source_documents = [
    {"page_content": doc.page_content, "source": doc.metadata["source"]}
    for doc in result["context"]
]

token_usage = llm_result_handler.response

output = {
    "answer": answer,
    "source_documents": source_documents,
    "token_usage": token_usage,
}

output

Raw API Response: {'id': 'f4e7a22a-58f7-4883-9e40-62063c5dfea1', 'results': [{'index': 9, 'relevance_score': 0.60245985}, {'index': 0, 'relevance_score': 0.5843666}, {'index': 4, 'relevance_score': 0.51866865}, {'index': 2, 'relevance_score': 0.51662046}, {'index': 5, 'relevance_score': 0.48162505}, {'index': 3, 'relevance_score': 0.45391592}, {'index': 6, 'relevance_score': 0.3288855}, {'index': 7, 'relevance_score': 0.31321016}, {'index': 8, 'relevance_score': 0.30059206}, {'index': 1, 'relevance_score': 0.13483313}], 'meta': {'api_version': {'version': '2'}, 'billed_units': {'search_units': 1}}}


{'answer': {'llm_result': AIMessage(content='The interest rates for fixed deposits at Brillar Bank are as follows:\n\n- 1 month: 2.15%\n- 2 - 3 months: 2.25%\n- 4 - 5 months: 2.30%\n- 6 months: 2.30%\n- 7 - 11 months: 2.35%\n- 12 - 60 months: 2.50%', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 86, 'prompt_tokens': 4358, 'total_tokens': 4444, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_a288987b44', 'finish_reason': 'stop', 'logprobs': None}, id='run-944de715-f2e4-415b-bd58-a25cae1694ba-0', usage_metadata={'input_tokens': 4358, 'output_tokens': 86, 'total_tokens': 4444, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}}),
  'question': 'what are the interest rates 