In [None]:
# %pip install gradio langchain langchain-community langchain-text-splitters chromadb ollama pypdf
# ollama pull llama3.2:latest
# ollama pull nomic-embed-text

In [None]:
import os
import ollama
from utils import timeit
import chromadb
import hashlib
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

In [74]:
class DocumentParser(object):
    """Load a single .txt or .pdf file and split it into small text chunks.

    The path is validated when the parser is constructed: a missing file, an
    unsupported extension or an oversized file raises ``ValueError`` right
    away instead of failing later inside the loaders.
    """

    # Extensions (compared case-insensitively) that load_document() can read.
    SUPPORTED_EXTENSIONS = ('.txt', '.pdf')
    # Files larger than this many bytes are rejected by _validate_path().
    MAX_FILE_SIZE_BYTES = 1000000
    # Settings passed to RecursiveCharacterTextSplitter in split_document().
    CHUNK_SIZE = 300
    CHUNK_OVERLAP = 0

    def __init__(self, file_path):
        self.file_path = file_path
        is_valid, message = self._validate_path()
        if not is_valid:
            # The validation result used to be discarded, so invalid inputs were
            # only caught (if at all) deep inside the loaders.
            raise ValueError(message)

    def _extension(self):
        """Return the lower-cased extension of ``self.file_path`` (e.g. '.pdf')."""
        return os.path.splitext(os.fspath(self.file_path))[1].lower()

    def _validate_path(self):
        """Check that the file exists, has a supported extension and is small enough.

        Returns:
            tuple: ``(True, None)`` when the file can be processed, otherwise
            ``(False, message)`` with a human-readable reason.
        """
        if not os.path.isfile(self.file_path):
            return False, f"The provided path {self.file_path} doesn't exist. Please recheck and provide the correct path."
        if self._extension() not in self.SUPPORTED_EXTENSIONS:
            return False, "Only .txt and .pdf extensions are supported. Please recheck and provide the correct path."
        file_size_bytes = os.path.getsize(self.file_path)
        if file_size_bytes > self.MAX_FILE_SIZE_BYTES:
            return False, f"Large File. Cannot process. Current file size: {file_size_bytes}>{self.MAX_FILE_SIZE_BYTES}"
        return True, None

    def load_document(self):
        """Load the file with the LangChain loader that matches its extension.

        Returns:
            Whatever the loader's ``load()`` returns (a list of LangChain documents).

        Raises:
            ValueError: if the extension is not supported.
        """
        extension = self._extension()
        if extension == '.txt':
            loader = TextLoader(self.file_path)
        elif extension == '.pdf':
            loader = PyPDFLoader(self.file_path)
        else:
            # Previously this fell through and crashed with UnboundLocalError on `loader`.
            raise ValueError(f"Unsupported file extension: {extension!r}")
        return loader.load()

    def split_document(self, documents):
        """Split loaded documents into chunks of at most CHUNK_SIZE characters."""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.CHUNK_SIZE,
            chunk_overlap=self.CHUNK_OVERLAP,
        )
        return text_splitter.split_documents(documents)

    @timeit
    def parse(self):
        """Load the file and return its text chunks."""
        documents = self.load_document()
        text_chunks = self.split_document(documents)
        return text_chunks

In [None]:
class Embedding(object):
    """Compute text embeddings with a local Ollama embedding model."""

    def __init__(self, model_name='nomic-embed-text'):
        self.model_name = model_name

    def embed(self, doc):
        """Embed ``doc`` and return the ``embeddings`` field of the Ollama response.

        This is a list of vectors rather than a single vector; callers that
        embed one string (e.g. Indexer) take element ``[0]``.
        """
        response = ollama.embed(model=self.model_name, input=doc)
        return response.embeddings


class VectorStore(object):
    """Persistent Chroma collection holding chunk texts and their embeddings."""

    def __init__(self, db_path="./chroma_db", collection_name="docrag"):
        # PersistentClient keeps the database on local disk under db_path so it can be reloaded later.
        self.database_client = chromadb.PersistentClient(path=db_path)
        self.collection_name = collection_name
        # Reuses the collection if it already exists; each logically separate store gets its own collection.
        self.collection = self.database_client.get_or_create_collection(name=collection_name)

    def add_chunk(self, chunk_id, chunk_text, chunk_embedding):
        """Store one chunk's text and embedding under ``chunk_id``."""
        self.collection.add(ids=chunk_id, embeddings=chunk_embedding, documents=chunk_text)

    def add_chunks(self, chunk_ids, chunk_texts, chunk_embeddings):
        """Store several chunks at once; the three sequences are parallel."""
        self.collection.add(ids=chunk_ids, embeddings=chunk_embeddings, documents=chunk_texts)

    def retrieve(self, query_embedding, top_k=5):
        """Return the texts of up to ``top_k`` stored chunks most similar to the query.

        Similarity uses the collection's configured distance metric (presumably
        Chroma's default, L2, since no metric is set here -- confirm).

        Args:
            query_embedding: query vector(s) accepted by Chroma's ``query``; only
                the result list for the first query is returned.
            top_k: maximum number of chunks to return.

        Returns:
            list[str]: matching chunk texts, or ``[]`` if the collection is empty
            or ``top_k`` is not positive.
        """
        available = self.collection.count()
        if available == 0 or top_k <= 0:
            return []
        # Never request more results than are stored.
        results = self.collection.query(
            query_embeddings=query_embedding,
            n_results=min(top_k, available),
            include=["documents"],
        )
        return results['documents'][0]

class Indexer(object):
    """Embed parsed document chunks and write them into a vector store."""

    def __init__(self, vector_store, embedder):
        self.vector_store = vector_store
        self.embedder = embedder

    def create_index(self, docs):
        """Embed each chunk in ``docs`` and store it, one chunk at a time.

        Each chunk is keyed by the MD5 hex digest of its text, so identical
        texts map to the same id.

        Args:
            docs: iterable of objects exposing ``page_content`` (such as the
                chunks returned by ``DocumentParser.parse``).
        """
        for piece in docs:
            content = piece.page_content
            # embed() returns a list of vectors; the first is this chunk's vector.
            vector = self.embedder.embed(content)[0]
            key = hashlib.md5(content.encode()).hexdigest()
            self.vector_store.add_chunk(key, content, vector)

        

In [90]:


class LLM(object):
    """Thin wrapper around ``ollama.chat`` for single-turn prompts."""

    def __init__(self, model_name='mistral'):
        self.model_name = model_name

    @timeit
    def call(self, user_message=None):
        """Send ``user_message`` (if truthy) as a user turn and return the reply text."""
        messages = self._add_message('user', user_message) if user_message else None
        reply = ollama.chat(model=self.model_name, messages=messages)
        return self._get_content_from_response(reply)

    def _get_content_from_response(self, response):
        """Extract the assistant's text from an ``ollama.chat`` response."""
        message = response['message']
        return message['content']

    def _add_message(self, role, message):
        """Build a new one-element chat history holding ``message`` under ``role``."""
        return [dict(role=role, content=message)]

class DocRAG(object):
    """End-to-end RAG pipeline over a single document.

    Construction parses and indexes the file into its own Chroma collection.
    ``chat`` then rewrites a question into a search query, retrieves matching
    chunks and asks the LLM to answer from them.
    """

    # Chroma collection names must start and end with an alphanumeric character
    # and contain only alphanumerics, '_' or '-'. Chroma has required 3-63
    # characters (newer releases may relax the upper bound); 63 is safe for both.
    _MAX_COLLECTION_NAME_LEN = 63

    def __init__(self, file_path, llm_model_name):
        self.file_path = file_path
        self.llm_model_name = llm_model_name
        self.llm = LLM(model_name=self.llm_model_name)
        self.document_parser = DocumentParser(file_path=self.file_path)
        self.embedder = Embedding()
        document_identifier = self._get_document_identifier()
        print("Creating collection ", document_identifier)
        self.collection_name = document_identifier
        self.vector_store = VectorStore(collection_name=self.collection_name)
        self.indexer = Indexer(vector_store=self.vector_store, embedder=self.embedder)
        self.process_document()

    def _get_document_identifier(self):
        """Build a Chroma-safe collection name for ``self.file_path``.

        The name is the sanitized file name plus a short hash of the file's
        contents. Two different uploads that share a file name therefore get
        separate collections instead of mixing their chunks, while re-uploading
        the same file reuses its existing collection.

        Raises:
            OSError: if the file cannot be read.
        """
        file_name = os.path.basename(self.file_path)
        # Any character Chroma does not accept (spaces, dots, brackets, non-ASCII) becomes '_';
        # leading '_'/'-' are dropped so the name starts with an alphanumeric character.
        slug = "".join(
            ch if ch.isascii() and (ch.isalnum() or ch in "_-") else "_"
            for ch in file_name
        ).lstrip("_-")
        if not slug:
            slug = "doc"
        digest = self._file_digest()[:8]
        # Leave room for "_" plus 8 hex characters; the hex suffix also guarantees an alphanumeric last character.
        slug = slug[: self._MAX_COLLECTION_NAME_LEN - 9]
        return f"{slug}_{digest}"

    def _file_digest(self):
        """Return the MD5 hex digest of the file's bytes (used only to name the collection)."""
        hasher = hashlib.md5()
        with open(self.file_path, "rb") as handle:
            for block in iter(lambda: handle.read(65536), b""):
                hasher.update(block)
        return hasher.hexdigest()

    def process_document(self):
        """Parse the document into chunks and index them into the vector store."""
        document_chunks = self.document_parser.parse()
        self.indexer.create_index(document_chunks)

    def generate_query(self, user_question):
        """Ask the LLM to rewrite ``user_question`` as a concise search query."""
        file_name = os.path.basename(self.file_path)
        user_message = f"""You are a query rewriting assistant. Your task is to rewrite the user's question into a search query that will retrieve the most relevant documents from a knowledge base. The name of the file is {file_name}
        Instructions:
        - Keep the query concise.
        - Use key concepts, keywords, and entities mentioned in the original question.
        - Do not include phrases like "Find documents about..." or "Search for...".
        - Just output the improved query.
        - Generate only one query dont use OR and AND connecters

        Original Question:
        {user_question}

        Rewritten Search Query:"""
        output = self.llm.call(user_message=user_message)
        return output

    def get_answer(self, query, user_question):
        """Retrieve context for ``query`` and ask the LLM to answer ``user_question`` from it."""
        context = self.query_db(query=query)
        print("Following context retrieved from database", context)
        # Join the chunks as plain text; interpolating the list directly would put its
        # Python repr (brackets, quotes, escaped newlines) into the prompt.
        context_text = "\n\n".join(context)
        message = f"""You are a helpful assistant. Use the provided context to answer the user's question.

        If the answer is not present in the context, say "I don't have enough information based on the provided documents."

        Context:
        {context_text}

        Question:
        {user_question}

        Answer:"""
        output = self.llm.call(user_message=message)
        return output

    def query_db(self, query):
        """Embed ``query`` and return the texts of the most similar stored chunks."""
        query_embedding = self.embedder.embed(query)
        context = self.vector_store.retrieve(query_embedding=query_embedding)
        return context

    def chat(self, user_question):
        """Answer ``user_question`` using retrieval over the indexed document."""
        query = self.generate_query(user_question=user_question)
        print(f"Querying database: {query}")
        answer = self.get_answer(query=query, user_question=user_question)
        return answer
    
class DocRAGApp():
    """Gradio-facing facade: index an uploaded document, then answer questions about it."""

    def __init__(self, model_name):
        self.model_name = model_name
        # Set by process_document(); stays None until a document has been indexed.
        self.doc_rag_pipeline = None

    def process_document(self, file_path):
        """Index the uploaded file and return a status message for the UI.

        Args:
            file_path: path (str or os.PathLike) of the uploaded file. Objects with
                a ``.name`` attribute holding the path (as some Gradio versions pass
                for ``gr.File``) are also accepted. ``None`` means nothing was
                uploaded and yields a prompt instead of an exception.
        """
        if file_path is None:
            return "Please upload a document first."
        if not isinstance(file_path, (str, os.PathLike)):
            # Tempfile-like upload objects carry the on-disk path in `.name`.
            file_path = getattr(file_path, "name", file_path)
        self.doc_rag_pipeline = DocRAG(
            file_path=file_path,
            llm_model_name=self.model_name
        )
        return "Documents successfully processed and indexed."

    def chat(self, user_question):
        """Answer ``user_question`` with the current pipeline, or explain why it cannot."""
        if self.doc_rag_pipeline is None:
            # Previously this raised AttributeError on None when no document was processed yet.
            return "Please upload and process a document before asking a question."
        if not user_question or not user_question.strip():
            return "Please enter a question."
        return self.doc_rag_pipeline.chat(user_question=user_question)


In [96]:
# NOTE(review): works here, but notebook convention is to keep all imports in the import cell at the top.
import gradio as gr

# Single app instance shared by both event handlers; it holds the currently indexed document.
doc_rag = DocRAGApp('llama3.2:latest')

# Components are created inside the Blocks context in display order, so statement order is the page layout.
with gr.Blocks() as demo:
    gr.Markdown("## DocRAG: Local Retrieval-Augmented Chatbot for your docs")

    # File picker and its "Process File" button are placed in one row.
    with gr.Row():
        file_input = gr.File(file_types=[".pdf", ".txt"], label="Upload Document", file_count="single")
        upload_button = gr.Button("Process File")
    
    upload_output = gr.Textbox(label="Upload Status")

    question_input = gr.Textbox(label="Ask a Question")
    ask_button = gr.Button("Get Answer")
    answer_output = gr.Textbox(label="Answer")

    # Upload handler: indexes the selected file and shows the returned status string.
    upload_button.click(fn=doc_rag.process_document, inputs=[file_input], outputs=[upload_output])

    # Ask handler: answers the typed question against the indexed document.
    ask_button.click(fn=doc_rag.chat, inputs=[question_input], outputs=[answer_output])

# Starts the local Gradio server; the URL is printed in the cell output.
demo.launch()

* Running on local URL:  http://127.0.0.1:7866

To create a public link, set `share=True` in `launch()`.


