In [1]:
import os
# Set the USER_AGENT environment variable
os.environ['USER_AGENT'] = 'DocumentQA/1.0'
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader, OnlinePDFLoader
from langchain_chroma import Chroma
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.chat_models import ChatOllama
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document


class DocumentProcessor:
    """Load uploaded files, split them into chunks and store them in a persistent Chroma DB.

    Also keeps ``uploaded_docs``: the base file names of ingested documents, in
    first-seen order, for display in the UI.
    """

    # Text-splitter settings applied before embedding.
    CHUNK_SIZE = 7500
    CHUNK_OVERLAP = 100

    def __init__(self, db_directory="persistent_vector_db"):
        # Initialize or load the persistent vector database from disk
        self.embeddings = OllamaEmbeddings(model="nomic-embed-text", show_progress=True)
        self.vector_db = Chroma(persist_directory=db_directory,
                                embedding_function=self.embeddings,
                                collection_name="persistent-rag-chroma")
        # Seed the list with documents already present in the persisted store.
        self.uploaded_docs = self._load_existing_docs()

    def _load_existing_docs(self):
        """Return the unique base names of documents already in the vector store.

        Names are taken from each chunk's ``source`` metadata; chunks without a
        source are listed once as ``"Unknown Document"``. Every file is stored as
        many chunks, so names are de-duplicated while keeping first-seen order.
        """
        # Only metadata is needed, so don't fetch the chunk texts.
        metadatas = self.vector_db.get(include=["metadatas"]).get("metadatas") or []
        names = {}  # dict used as an insertion-ordered set
        for metadata in metadatas:
            source = (metadata or {}).get("source")
            name = os.path.basename(source) if source else "Unknown Document"
            names[name] = None
        return list(names)

    def load_documents(self, files):
        """Load, split, embed and store the given uploaded files.

        Args:
            files: Iterable of uploads. Each item may be a path string (what
                Gradio's ``type="filepath"`` yields) or an object exposing the
                path as ``.name`` (tempfile wrappers). ``None`` or an empty
                iterable is a no-op.

        Files whose extension has no registered loader are skipped. File names
        are added to ``uploaded_docs`` only after the store write succeeds.
        """
        if not files:
            return

        # Every loader takes a path and returns a list of Documents.
        loaders = {
            'pdf': lambda path: PyPDFLoader(path).load(),
            'txt': self._load_text_file,
            'docx': self._load_docx_file,
            # Add more file type loaders here if needed
        }

        docs_list = []
        loaded_names = []
        for file in files:
            path = getattr(file, "name", file)
            ext = os.path.splitext(path)[1].lstrip(".").lower()
            loader = loaders.get(ext)
            if loader is None:
                continue  # unsupported file type
            docs_list.extend(self._to_document(item, path) for item in loader(path))
            loaded_names.append(os.path.basename(path))

        self._process_and_store_documents(docs_list)

        for name in loaded_names:
            if name not in self.uploaded_docs:
                self.uploaded_docs.append(name)

    @staticmethod
    def _to_document(item, source):
        """Coerce one loader result into a Document (Documents pass through unchanged)."""
        if isinstance(item, Document):
            return item
        if isinstance(item, dict):
            return Document(page_content=item.get("page_content", ""),
                            metadata=item.get("metadata") or {"source": source})
        return Document(page_content=str(item), metadata={"source": source})

    def _load_text_file(self, file_path):
        """Load a UTF-8 text file as a single Document."""
        with open(file_path, 'r', encoding='utf-8') as file:
            text = file.read()
        return [Document(page_content=text, metadata={"source": file_path})]

    def _load_docx_file(self, file_path):
        """Load a .docx file as a single Document holding its paragraph text."""
        return [Document(page_content=self._extract_docx_text(file_path),
                         metadata={"source": file_path})]

    @staticmethod
    def _extract_docx_text(file_path):
        """Return the text of a .docx file, one line per paragraph.

        A .docx is a ZIP archive whose main body is ``word/document.xml``; text
        lives in ``w:t`` elements grouped into ``w:p`` paragraphs. Only the main
        body is read. Paragraphs nested inside other paragraphs (e.g. text boxes)
        would have their text emitted twice.
        """
        import xml.etree.ElementTree as ET
        import zipfile

        w = "{http://schemas.openxmlformats.org/wordprocessingml/2006/main}"
        with zipfile.ZipFile(file_path) as archive:
            # NOTE(review): uploads are untrusted; ElementTree relies on expat's
            # entity-expansion protections — confirm the runtime's expat is recent.
            root = ET.fromstring(archive.read("word/document.xml"))
        return "\n".join(
            "".join(node.text or "" for node in para.iter(f"{w}t"))
            for para in root.iter(f"{w}p")
        )

    def _process_and_store_documents(self, docs_list):
        """Split documents into chunks and add the chunks to the vector store.

        Returns early when there is nothing to store, so no empty batch is sent
        to Chroma.
        """
        if not docs_list:
            return
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.CHUNK_SIZE,
            chunk_overlap=self.CHUNK_OVERLAP,
        )
        doc_splits = text_splitter.split_documents(docs_list)
        if doc_splits:
            # No explicit persist call; the persistent Chroma store is expected
            # to write through to persist_directory on its own.
            self.vector_db.add_documents(doc_splits)

    def get_retriever(self):
        """Return a retriever over the vector store."""
        return self.vector_db.as_retriever()

    def get_uploaded_docs(self):
        """Return the tracked document names, one per line."""
        return "\n".join(self.uploaded_docs)


class QuestionAnsweringSystem:
    """Answer questions from documents held by a DocumentProcessor.

    Retrieved chunks are formatted into the prompt as context and answered by
    a local Ollama chat model (``mistral``).
    """

    # Returned instead of running the chain when the question is blank.
    EMPTY_QUESTION_MESSAGE = "Please enter a question."

    def __init__(self):
        self.processor = DocumentProcessor()
        self.model_local = ChatOllama(model="mistral")
        self.prompt_template = """Answer the question based only on the following context:
        {context}
        Question: {question}
        """

    @staticmethod
    def _format_docs(docs):
        """Join the text of retrieved documents into one context string.

        Without this, the prompt would receive the repr of a list of Document
        objects (metadata included) rather than just their content.
        """
        return "\n\n".join(doc.page_content for doc in docs)

    def load_and_process(self, files):
        """Ingest uploaded files into the document store."""
        self.processor.load_documents(files)

    def answer_question(self, question):
        """Answer ``question`` using only the retrieved document context.

        A blank question returns ``EMPTY_QUESTION_MESSAGE`` without calling the
        retriever or the model.
        """
        if not question or not question.strip():
            return self.EMPTY_QUESTION_MESSAGE
        retriever = self.processor.get_retriever()
        after_rag_prompt = ChatPromptTemplate.from_template(self.prompt_template)
        after_rag_chain = (
            {
                "context": retriever | self._format_docs,
                "question": RunnablePassthrough(),
            }
            | after_rag_prompt
            | self.model_local
            | StrOutputParser()
        )
        return after_rag_chain.invoke(question)

    def get_uploaded_docs(self):
        """Return the newline-separated names of ingested documents."""
        return self.processor.get_uploaded_docs()


# Gradio Interfaces for Uploading Documents and Asking Questions
qa_system = QuestionAnsweringSystem()  # Single shared instance used by both callbacks


def upload_documents(files):
    """Gradio callback: ingest the selected files and report the document list.

    Returns a ``(status_message, uploaded_docs_text)`` pair for the
    ``[upload_output, document_list]`` outputs.
    """
    if not files:
        # Gradio may pass None/[] when Upload is clicked with nothing selected;
        # don't report success in that case.
        return "No files selected.", qa_system.get_uploaded_docs()
    qa_system.load_and_process(files)
    return "Documents uploaded and processed successfully!", qa_system.get_uploaded_docs()


def query_documents(question):
    """Gradio callback: answer ``question`` from the stored documents."""
    return qa_system.answer_question(question)

# Two-tab UI: one tab ingests files, the other answers questions about them.
with gr.Blocks() as iface:
    with gr.Tab("Upload Documents"):
        # Gradio expects file extensions with a leading dot (e.g. ".pdf").
        file_input = gr.File(label="Upload Documents", file_count="multiple",
                             file_types=[".pdf", ".txt", ".docx"])
        upload_button = gr.Button("Upload")
        upload_output = gr.Textbox(label="Upload Status")
        # Pre-populate with documents already present in the persistent store.
        document_list = gr.Textbox(label="Uploaded Documents", interactive=False,
                                   value=qa_system.get_uploaded_docs())
        upload_button.click(upload_documents, inputs=file_input,
                            outputs=[upload_output, document_list])

    with gr.Tab("Query Documents"):
        question_input = gr.Textbox(label="Question")
        query_button = gr.Button("Ask")
        query_output = gr.Textbox(label="Answer")
        query_button.click(query_documents, inputs=question_input, outputs=query_output)

# __name__ is "__main__" both in a notebook kernel and when run as a script,
# so this still launches there but not when the module is imported.
if __name__ == "__main__":
    iface.launch()


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


