In [0]:
# %run /Workspace/Users/harshith.r@diggibyte.com/DBRX_dynamic_Class_RAG/load_components

## What is Retrieval Augmented Generation (RAG) for LLMs?

<img src="https://github.com/HarshithRL/DBRX_RAG/blob/main/Images_for_notebook/RAG.png?raw=true" width="700px" style="float: right" />

Retrieval-Augmented Generation (RAG) is a GenAI technique that combines a retrieval step with a generative model. At query time, passages relevant to the question are retrieved from your own data and passed to the model as context, so the model can answer from that data without being retrained or fine-tuned.

Grounding answers in retrieved context helps reduce hallucination, where the model generates inaccurate or unsupported content. This makes RAG a good fit for chatbots and Q&A systems in fields where answers must reflect specific, up-to-date information.

In simpler terms, RAG keeps the generated responses close to the provided context, which improves answer quality and keeps responses consistent across interactions. Response length and wording can still vary with the context that is retrieved.
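
To make the idea concrete, here is a minimal sketch of how retrieved passages might be combined with a question into a single prompt. The function name `build_prompt`, the instruction text, and the sample passages are hypothetical and chosen only for illustration; they are not part of this demo's code.

```python
from typing import Sequence


def build_prompt(instructions: str, passages: Sequence[str], question: str) -> str:
    """Join retrieved passages into one context block and append the question."""
    context = "\n\n".join(p.strip() for p in passages)
    return (
        f"{instructions.strip()}\n\n"
        f"Context:\n{context}\n\n"
        f"Question:\n{question}\n\n"
        "Answer:"
    )


# Hypothetical retrieved passages, for illustration only
passages = [
    "Passage one about the topic.",
    "Passage two with more detail.",
]
prompt = build_prompt(
    "Answer only from the provided context.",
    passages,
    "What does passage two add?",
)
print(prompt)
```

The model only sees what is placed in the prompt, so the quality of the retrieved passages largely determines the quality of the answer.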

### Vector Store & Vector Search

To give our LLM additional context, we need to find the documents or articles that are likely to contain the answer to the user's question.
A common solution is to deploy a vector database. Each document is converted into an embedding, a fixed-size vector that represents its content.<br/>
These vectors are then used to perform real-time similarity search during inference.
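
As a rough illustration of similarity search, the sketch below ranks a few hand-written vectors against a query vector. The three-dimensional vectors, the document ids, and the helper names `cosine_similarity` and `top_k` are all hypothetical; real embeddings come from an embeddings model and have hundreds or thousands of dimensions. Cosine similarity is used here for illustration, and a given vector store may use a different metric.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors; 0.0 if either has zero length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k(query_vec: Sequence[float],
          index: Dict[str, List[float]],
          k: int = 2) -> List[Tuple[str, float]]:
    """Return the k (doc_id, score) pairs most similar to the query, best first."""
    scored = [(doc_id, cosine_similarity(query_vec, vec))
              for doc_id, vec in index.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Hypothetical toy "embeddings", for illustration only
toy_index = {
    "doc_clusters": [0.9, 0.1, 0.0],
    "doc_billing": [0.0, 0.2, 0.9],
    "doc_jobs": [0.7, 0.6, 0.1],
}
query_vector = [1.0, 0.2, 0.0]

results = top_k(query_vector, toy_index, k=2)
for doc_id, score in results:
    print(doc_id, round(score, 3))
```

A vector database does the same ranking, but with index structures that avoid comparing the query against every stored vector.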

### Implementing RAG with Databricks AI Foundation models

In this demo, we will show you how to build and deploy a custom chatbot that answers questions over your own custom or private data.

As an example, we will specialize this chatbot to answer questions about Databricks, feeding databricks.com documentation articles to the model so that its answers are grounded in them.

Here is the flow we will implement:



<img src="https://github.com/databricks-demos/dbdemos-resources/blob/main/images/product/chatbot-rag/llm-rag-managed-flow-0.png?raw=true" style="margin-left: 10px"  width="1100px;">

In [0]:
# Importing necessary libraries
from PyPDF2 import PdfReader
from transformers import AutoTokenizer
from IPython.display import Markdown
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain.chains.question_answering import load_qa_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter


In [0]:
config = {
    "PROMPT_TEMPLATE" : """
        You are a Databricks assistant and an expert in artificial intelligence, generative AI in particular.

        Answer the question in as much detail as possible from the provided context, and present all the details in a structured way. If the answer is not in the
        provided context, just say 'answer is not available in the context'; do not provide a wrong answer. If the question is a yes/no question and the answer is not in the provided context, say 'No'.
        """
}

In [0]:
class RAGmodule:
    """
    Class for utilizing Retrieval-Augmented Generation (RAG) model in a chatbot application.
    """
    def __init__(self, pdf_path: str, tokenizer_model, embeddings_model, chat_model) -> None:
        """
        Initialize the RAGmodule.

        Parameters:
            pdf_path (str): Path to the PDF document.
            tokenizer_model: Model for tokenization.
            embeddings_model: Model for generating embeddings.
            chat_model: Model for chat interactions.
        """

        text = self.parse_text(pdf_path)  # Parse text from PDF
        chunks = self.tokenize(text, tokenizer_model)  # Split text into token-sized chunks
        self.vector_index = self.create_vector_index(chunks, embeddings_model)  # Create vector index
        self.chat_model = chat_model  # Assign chat model
    
    def parse_text(self, pdf_path: str) -> str:
        """
        Parse text from a PDF document.

        Parameters:
            pdf_path (str): Path to the PDF document.

        Returns:
            str: Parsed text from the PDF.
        """
        text = ""
        for page in PdfReader(pdf_path).pages:
            # Guard against pages with no extractable text (e.g. scanned images)
            text += page.extract_text() or ""
        return text
    
    def tokenize(self, text: str, tokenizer_model) -> list:
        """
        Split the text into overlapping chunks, measuring chunk length in tokens.

        Parameters:
            text (str): Input text to be split.
            tokenizer_model: Hugging Face tokenizer used to count tokens.

        Returns:
            list: List of Document chunks (chunk_size=500 tokens, chunk_overlap=50 tokens).
        """
        text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
            tokenizer_model,
            chunk_size=500,
            chunk_overlap=50
        )
        text = str(text)
        # create_documents wraps each chunk of text in a Document object
        chunks = text_splitter.create_documents([text])
        return chunks

    def create_vector_index(self, chunks, embeddings_model):
        """
        Create vector index from chunks using an embeddings model.

        Parameters:
            chunks: List of tokenized chunks.
            embeddings_model: Model for generating embeddings.
        """
        return FAISS.from_documents(chunks, embedding = embeddings_model)  # Create vector index

    def get_conve_chain(self):
        """
        Get conversation chain for chat interactions.
        """
        prompt_suffic = "Context:\n {context}?\n\nQuestion: \n{question}\n\n\nAnswer:"
        prompt_template_final = config['PROMPT_TEMPLATE'] + prompt_suffic
        prompt = PromptTemplate(
            template = prompt_template_final,
            input_variables = ["context", "question"],
            callbacks = [StrOutputParser]
        )
        return load_qa_chain(self.chat_model, chain_type = "stuff", prompt=prompt)  # Load conversation chain

    def query(self, user_query):
        """
        Query the chat model with user input.

        Parameters:
            user_query (str): User's query.

        Returns:
            Markdown: Response in Markdown format.
        """
        docs = self.vector_index.similarity_search(user_query)  # Search for similar documents
        chain = self.get_conve_chain()  # Get conversation chain
        response = chain({'input_documents': docs, "question": user_query}, return_only_outputs=True)  # Get response
        return Markdown(response['output_text'])  # Return response in Markdown format

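
The `tokenize` method above splits the text with `chunk_size=500` and `chunk_overlap=50`. The effect of these two parameters can be sketched with a plain sliding window over a list of tokens. This is a simplified stand-in, not how `RecursiveCharacterTextSplitter` works internally (it prefers to split at separators such as paragraph and line breaks); the helper name `chunk_tokens` and the tiny sizes are hypothetical.

```python
from typing import List, Sequence


def chunk_tokens(tokens: Sequence, chunk_size: int = 500, chunk_overlap: int = 50) -> List[list]:
    """Split tokens into windows of chunk_size that share chunk_overlap tokens."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")

    step = chunk_size - chunk_overlap  # how far each window advances
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(list(tokens[start:start + chunk_size]))
        if start + chunk_size >= len(tokens):
            break  # this window already reached the end of the input
    return chunks


# Small numbers so the overlap is easy to see
tokens = list(range(12))
chunks = chunk_tokens(tokens, chunk_size=5, chunk_overlap=2)
for chunk in chunks:
    print(chunk)
```

The overlap means a sentence that straddles a chunk boundary still appears whole in at least one chunk, at the cost of embedding some tokens twice.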