In [None]:
# Imports

import os
import shutil
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.document import Document
from langchain.vectorstores.chroma import Chroma

In [None]:
# Constants

# Directory where the Chroma vector store is persisted; clear_database() deletes it.
# Relative paths are resolved against the kernel's working directory.
CHROMA_PATH = "./chroma"
# Directory that PyPDFDirectoryLoader scans for input PDFs.
DATA_PATH = "./data"

In [None]:
# Clear existing Chroma database by removing the CHROMA_PATH directory

def clear_database():
  """Delete the persisted Chroma store at CHROMA_PATH, if one exists.

  When the directory is absent this is a no-op and nothing is printed.
  """
  if not os.path.exists(CHROMA_PATH):
    return
  # Remove the whole directory tree, then report what was removed.
  shutil.rmtree(CHROMA_PATH)
  print(f"Database at {CHROMA_PATH} has been cleared.")

# NOTE(review): this removes the store on every "Run All", so the
# "only add new chunks" check in the indexing cell presumably always starts
# from an empty database — confirm that a full rebuild per run is intended.
clear_database()

In [None]:
# Load PDF documents from the specified data path using the PyPDFDirectoryLoader.

# Read every PDF found under DATA_PATH into LangChain Document objects.
document_loader = PyPDFDirectoryLoader(DATA_PATH)
documents = document_loader.load()
# NOTE(review): presumably one Document per PDF page (the chunk-ID cell reads a
# "page" metadata key), so this count may be pages rather than files — confirm.
print(f"Loaded {len(documents)} documents.")

In [None]:
# Split documents into smaller chunks
# Chunk length limits; measured with `len` (see length_function below), i.e. in characters.
CHUNK_SIZE=800
# Number of characters shared between consecutive chunks.
CHUNK_OVERLAP=80

text_splitter = RecursiveCharacterTextSplitter(
  chunk_size=CHUNK_SIZE,
  chunk_overlap=CHUNK_OVERLAP,
  length_function=len,
  # Separators are matched as literal strings, not as regular expressions.
  is_separator_regex=False
)
# Each resulting chunk is a Document; later cells rely on its `metadata`
# carrying the "source" and "page" keys of the document it came from.
chunks = text_splitter.split_documents(documents)
print(f"Split documents into {len(chunks)} chunks.")

In [None]:
# Calculates unique chunk IDs based on the document source and page

def calculate_chunk_ids(chunks):
  """Assign every chunk an ID of the form "<source>:<page>:<index>".

  The ID is written in place to ``chunk.metadata["id"]``. ``index`` is the
  number of chunks already seen for the same source/page pair, starting at 0.

  The index is tracked per page (not only against the previous chunk), so
  IDs remain unique even when chunks of one page are not adjacent in
  ``chunks``. For input where same-page chunks are contiguous, the IDs are
  exactly those a simple running counter would produce.

  Args:
    chunks: Sized iterable of objects exposing a ``metadata`` dict. The
      "source" and "page" keys are read with ``.get``; a missing key is
      rendered as the string "None" in the ID.

  Returns:
    The same ``chunks`` object that was passed in, with metadata updated.
  """
  # Maps "<source>:<page>" -> next free chunk index on that page.
  next_index_by_page = {}

  for chunk in chunks:
    source = chunk.metadata.get("source")
    page = chunk.metadata.get("page")
    page_id = f"{source}:{page}"

    chunk_index = next_index_by_page.get(page_id, 0)
    next_index_by_page[page_id] = chunk_index + 1

    chunk.metadata["id"] = f"{page_id}:{chunk_index}"

  print(f"Calculated chunk IDs for {len(chunks)} chunks.")
  return chunks

In [None]:
from langchain_community.embeddings.ollama import OllamaEmbeddings

# Embedding function shared by the indexing cell and by query_rag(), which
# reopens the store with this same object; presumably the database must be
# rebuilt if the model is changed, since stored vectors come from this model.
# NOTE(review): "llama3" is also the generation model used in query_rag; the
# commented-out "nomic-embed-text" alternative is kept below — confirm which is intended.
# embeddings = OllamaEmbeddings(model="nomic-embed-text")
embeddings = OllamaEmbeddings(model="llama3")

In [None]:
# Add chunks to the Chroma database
# It only adds new chunks based on their IDs.

# Open (or create) the persisted Chroma store at CHROMA_PATH.
db = Chroma(
  persist_directory=CHROMA_PATH, 
  embedding_function=embeddings
)

# Tag every chunk with its "<source>:<page>:<index>" ID (stored in metadata["id"]).
chunks_with_ids = calculate_chunk_ids(chunks)
# include=[] requests no documents/embeddings/metadata; only the "ids" entry is used below.
existing_items = db.get(include=[])
existing_ids = set(existing_items["ids"])

print(f"Number of existing documents in DB: {len(existing_ids)}")

# Keep only chunks whose ID is not already stored, so re-running does not re-add them.
new_chunks = [chunk for chunk in chunks_with_ids if chunk.metadata["id"] not in existing_ids]

if new_chunks:
  print(f"👉 Adding new documents: {len(new_chunks)}")
  # Pass the IDs explicitly so the stored IDs match metadata["id"].
  new_chunk_ids = [chunk.metadata["id"] for chunk in new_chunks]
  db.add_documents(new_chunks, ids=new_chunk_ids)
  # NOTE(review): explicit persist() may be deprecated/unnecessary on newer Chroma
  # versions that persist automatically — confirm against the installed version.
  db.persist()
  print(f"✅ Successfully added {len(new_chunks)} new documents.")
else:
  print("✅ No new documents to add")

In [None]:
# Query RAG function

from langchain.prompts import ChatPromptTemplate
from langchain_community.llms.ollama import Ollama

# Prompt sent to the LLM. The {context} and {question} placeholders are filled
# in query_rag() with the retrieved chunk text and the user's query.
PROMPT_TEMPLATE = """
Answer the question based only on the following context:

{context}

---

Answer the question based on the above context: {question}
"""

# Query the Chroma DB with the given query text, retrieve the most relevant chunks, and generate a response using them as context

def query_rag(query_text: str, k: int = 5, model_name: str = "llama3") -> str:
  """Answer ``query_text`` using chunks retrieved from the Chroma store.

  Opens the store at CHROMA_PATH with the notebook-level ``embeddings``,
  retrieves the ``k`` most similar chunks, fills PROMPT_TEMPLATE with their
  text and the question, and sends the prompt to an Ollama model. The answer
  and the IDs of the chunks used are printed.

  Args:
    query_text: The question to answer.
    k: Number of chunks to retrieve as context (default 5).
    model_name: Name of the Ollama model used for generation
      (default "llama3").

  Returns:
    The model's response text (without the printed source list).
  """
  db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)

  # Each result is a (Document, score) pair; the score is not used here.
  results = db.similarity_search_with_score(query_text, k=k)
  context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)

  prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
  prompt = prompt_template.format(context=context_text, question=query_text)

  model = Ollama(model=model_name)
  response_text = model.invoke(prompt)

  # Chunk IDs come from metadata["id"]; None if a stored chunk has no ID.
  sources = [doc.metadata.get("id", None) for doc, _score in results]
  formatted_response = f"Response: {response_text}\nSources: {sources}"
  print(formatted_response)
  return response_text


In [None]:
# Example query; query_rag prints the answer with its sources, and the returned
# answer string is also shown as this cell's output.
query_rag("What is this document about?")