In [15]:
# This notebook uses Ollama to download and run open-source models such as Llama and Mistral.
# Ollama can pull embedding models as well.
# On my PC these models took about 15 minutes on average to answer a user query.
# Running on an HPC cluster cut the average response time to about 2 minutes and makes larger models usable.

In [3]:
import os

import ollama

# A `!VAR=value` shell escape runs in a throwaway subshell, so it never changed any
# environment. Set the variable on this kernel's environment instead; it is inherited
# only by processes started from this kernel afterwards (e.g. `!ollama serve &`).
# NOTE(review): an Ollama server that is already running must be restarted with this
# variable in its own environment — confirm the server actually reads OLLAMA_ACCELERATE.
os.environ["OLLAMA_ACCELERATE"] = "1"


ProgressResponse(status='success', completed=None, total=None, digest=None)

In [None]:
# Pull the chat model used by the RAG chain from the Ollama registry.
ollama.pull(model="llama3.2")

In [9]:
# Pull the embedding model used to index PDF chunks.
ollama.pull(model="nomic-embed-text")

ProgressResponse(status='success', completed=None, total=None, digest=None)

In [11]:
import gradio as gr
import os
import re
from langchain_community.document_loaders import PyPDFLoader
from langchain_ollama import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_ollama.chat_models import ChatOllama
from langchain.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_core.output_parsers import StrOutputParser
from markdown import markdown

In [12]:
# Global variables
VECTOR_DB_NAME = "local-rag"  # Chroma collection name for the uploaded PDF's chunks
local_model = "llama3.2"  # Ollama chat model used for query rewriting and answering
vector_db = None  # assigned by create_vector_db() once a PDF has been processed
llm = ChatOllama(model=local_model)  # single chat-model instance shared by the chain helpers

In [13]:
# Function to load PDF
def load_pdf(file_path):
    """Load the PDF at ``file_path`` and return the documents produced by PyPDFLoader."""
    return PyPDFLoader(file_path=file_path).load()

# Function to split text
def split_text(data, chunk_size=1000, chunk_overlap=200):
    """Split loaded documents into overlapping chunks for embedding.

    ``chunk_size`` and ``chunk_overlap`` are forwarded unchanged to
    RecursiveCharacterTextSplitter.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    chunks = splitter.split_documents(data)
    return chunks

# Function to create vector database
def create_vector_db(chunks, persist_directory="./chroma_db7", embedding_model="nomic-embed-text"):
    """Embed ``chunks`` into a fresh Chroma collection and store it in the global ``vector_db``.

    Args:
        chunks: Documents produced by ``split_text``.
        persist_directory: Where Chroma stores the collection on disk. Pass ``None``
            for in-memory storage so embeddings are not saved to disk.
        embedding_model: Ollama embedding model used to embed the chunks.

    Returns:
        A status message for the UI. When ``chunks`` is empty, an error message is
        returned and ``vector_db`` is left untouched.
    """
    global vector_db
    if not chunks:
        return "Error: No text extracted from the PDF."

    # Chroma.from_documents() reuses and *appends to* an existing collection with the
    # same name, so chunks from previously uploaded PDFs (and duplicates from
    # re-uploads) would keep being retrieved. Drop the old collection first.
    vector_db = None
    Chroma(collection_name=VECTOR_DB_NAME, persist_directory=persist_directory).delete_collection()

    vector_db = Chroma.from_documents(
        documents=chunks,
        embedding=OllamaEmbeddings(model=embedding_model),
        collection_name=VECTOR_DB_NAME,
        persist_directory=persist_directory,
    )
    return "PDF processed successfully! You can now ask questions."

# Function to set up retriever
def get_retriever():
    """Build a multi-query retriever over the current vector database.

    Returns:
        A MultiQueryRetriever that asks ``llm`` for alternative phrasings of the
        question and searches ``vector_db`` with each of them, or ``None`` when no
        PDF has been processed yet.
    """
    if vector_db is None:
        return None

    # MultiQueryRetriever's default output parser treats every non-empty line of the
    # LLM output as a separate search query, so ask for the alternatives one per line
    # and nothing else; otherwise preambles like "Here are 2 versions:" become queries.
    query_prompt = ChatPromptTemplate.from_template(
        "Generate 2 alternative versions of the question below to improve retrieval "
        "from a vector database. Output only the alternative questions, one per line, "
        "with no numbering or extra text.\n"
        "Original: {question}"
    )

    return MultiQueryRetriever.from_llm(vector_db.as_retriever(), llm, prompt=query_prompt)

# Function to create RAG chain
def create_rag_chain():
    retriever = get_retriever()
    if not retriever:
        return None
    
    prompt = ChatPromptTemplate.from_template(
        """You are an AI assistant that answers questions based only on the given context.
        Provide a well-structured, coherent, and concise response.

        ### Context:
        {context}

        ### Question:
        {question}

        ### Answer:
        """
    )
    
    return (
        {"context": retriever, "question": RunnablePassthrough()} | prompt | llm | StrOutputParser()
    )

# Function to process user query
def process_query(question, chat_history):
    """Answer ``question`` with the RAG chain and append the exchange to the chat.

    Args:
        question: Text typed by the user.
        chat_history: Current chatbot value, a list of (question, answer) pairs;
            ``None`` is treated as an empty history.

    Returns:
        A tuple ``(textbox_value, chat_history)``. The textbox is cleared on
        success (or for a blank question) and holds an error message when the
        chain cannot run. A new history list is returned; the input is not mutated.
    """
    history = list(chat_history or [])

    # Don't send blank submissions through retrieval and generation.
    if not question or not question.strip():
        return "", history

    if not vector_db:
        return "Error: No vector database found. Please upload and process a PDF first.", history

    chain = create_rag_chain()
    if not chain:
        return "Error: Unable to initialize the RAG chain.", history

    response = chain.invoke(question)

    # The model sometimes echoes the "### Answer:" heading; keep only what follows it.
    answer_match = re.search(r"### Answer:\s*(.*)", response, re.DOTALL)
    answer = answer_match.group(1).strip() if answer_match else response.strip()

    history.append((question, markdown(answer)))
    return "", history

# Function to process PDF upload
def process_pdf(file):
    """Load, chunk and index an uploaded PDF, replacing any previous vector database.

    Args:
        file: The upload component's value (an object exposing the file path via
            ``.name``) or a plain path string; falsy when nothing is uploaded.

    Returns:
        A status message for the UI.
    """
    global vector_db
    if not file:
        return "Please upload a valid PDF file."

    # Reset the vector DB before adding new embeddings
    vector_db = None

    # Upload values expose the temp path as `.name`; a plain path string is used as-is.
    file_path = getattr(file, "name", file)
    try:
        data = load_pdf(file_path)
    except Exception as exc:
        # Surface unreadable/corrupt PDFs in the status box instead of failing the callback.
        return f"Error: Could not read the PDF ({exc})."
    chunks = split_text(data)

    return create_vector_db(chunks)

In [14]:
# Gradio UI
def gradio_ui():
    """Build the Gradio app: chat panel on the left, PDF upload and status on the right.

    Returns:
        The ``gr.Blocks`` demo (not yet launched).
    """
    with gr.Blocks(theme="soft") as demo:
        gr.Markdown("<h1 style='text-align: center; color: #4A90E2;'>📖 Conversational AI for PDFs</h1>")

        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="AI Chat")
                user_input = gr.Textbox(placeholder="Ask me a question...", label="Your Question")

                with gr.Row():
                    ask_button = gr.Button("🔍 Ask", variant="primary")
                    clear_button = gr.Button("🗑️ Clear")

            with gr.Column(scale=1):
                gr.Markdown("### Upload PDFs Here:")
                pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
                status = gr.Textbox(label="Status", interactive=False)

        # Define button actions
        ask_button.click(process_query, inputs=[user_input, chatbot], outputs=[user_input, chatbot])
        # Pressing Enter in the question box behaves like clicking "Ask".
        user_input.submit(process_query, inputs=[user_input, chatbot], outputs=[user_input, chatbot])
        clear_button.click(lambda: [], outputs=[chatbot])
        pdf_upload.change(process_pdf, inputs=[pdf_upload], outputs=[status])

    return demo

# Build the UI and serve it. share=True also opens a temporary public *.gradio.live
# tunnel, so anyone with the link can upload PDFs and query the model.
demo = gradio_ui()
demo.launch(debug=True, share=True)

Running on local URL:  http://127.0.0.1:7866
Running on public URL: https://7cb64de9ad943f2ffb.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2
Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2
Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2
Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2
Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2
Number of requested results 4 is greater than number of elements in index 2, updating n_results = 2


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7866 <> https://7cb64de9ad943f2ffb.gradio.live


