# Implementing Simple RAG using ChromaDB and Langchain

Source: [Medium](https://medium.com/@callumjmac/implementing-rag-in-langchain-with-chroma-a-step-by-step-guide-16fc21815339)

In [1]:
%pip install unstructured

Note: you may need to restart the kernel to use updated packages.


In [2]:
# Langchain dependencies
from langchain.document_loaders.pdf import PyPDFDirectoryLoader # Importing PDF loader from Langchain
from langchain.text_splitter import RecursiveCharacterTextSplitter # Importing text splitter from Langchain
from langchain_google_genai import GoogleGenerativeAIEmbeddings # Importing Google Generative AI (Gemini) embeddings
from langchain.schema import Document # Importing Document schema from Langchain
from langchain.vectorstores.chroma import Chroma # Importing Chroma vector store from Langchain
from dotenv import load_dotenv # Importing dotenv to get API key from .env file
from langchain.chat_models import ChatOpenAI # Import OpenAI LLM
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.text_splitter import PythonCodeTextSplitter



import os # Importing os module for operating system functionalities
import shutil # Importing shutil module for high-level file operations



**Document loaders from [Langchain](https://python.langchain.com/v0.1/docs/modules/data_connection/document_loaders/)**

In [3]:
repo_path = 'test/python-type-hinting-main'
db_name = repo_path.split('/')[-1]
db_name

'python-type-hinting-main'

## Load Documents

In [4]:
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import LanguageParser
from langchain.schema import Document
from langchain.text_splitter import Language

# Custom LanguageParser to catch decoding errors and report the problematic file
class DebugLanguageParser(LanguageParser):
    def lazy_parse(self, blob):
        try:
            # Attempt to read the file content
            code = blob.as_string()
        except UnicodeDecodeError as e:
            # Print the file causing the issue and the error details
            print(f"Encoding error in file: {blob.source}")
            print(f"Error details: {e}")
            return  # Skip this blob due to the encoding issue

        # The parser below is always constructed with an explicit language.
        # The original extension-based fallback referenced LANGUAGE_EXTENSIONS,
        # which is never imported in this notebook, so it is left out here.
        language = self.language

        if language is None:
            raise ValueError("Could not determine the language for the provided blob.")

        # Create a Document object manually
        yield Document(page_content=code, metadata={"source": blob.source})
        
        
def load_documents(document_path):
    

    # Use the custom parser to catch encoding issues
    document_loader = GenericLoader.from_filesystem(
        document_path,
        glob="**/*",
        suffixes=[".py"],
        parser=DebugLanguageParser(language=Language.PYTHON, parser_threshold=500)
    )
    
    return document_loader.load()

In [5]:
document = load_documents(repo_path)
len(document)

7
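
Before handing a repository to the loader, it can help to know up front which files would trigger the `UnicodeDecodeError` branch above. The following is a minimal standard-library sketch, not part of the original notebook; the function name `find_undecodable_files` and its parameters are hypothetical, and it assumes the files are expected to be UTF-8.

```python
from pathlib import Path


def find_undecodable_files(root: str, suffix: str = ".py", encoding: str = "utf-8") -> list[str]:
    """Return paths under `root` whose bytes cannot be decoded with `encoding`.

    Hypothetical helper for illustration; it only reads files and never modifies them.
    """
    bad_files = []
    for path in sorted(Path(root).rglob(f"*{suffix}")):
        if not path.is_file():
            continue
        try:
            # Decode the raw bytes explicitly so the check does not depend on the OS locale
            path.read_bytes().decode(encoding)
        except UnicodeDecodeError:
            bad_files.append(str(path))
    return bad_files
```

Calling `find_undecodable_files(repo_path)` would then list the files that the custom parser is going to skip.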

## Chunking

In RAG systems, “chunking” refers to the segmentation of input text into shorter and more meaningful units. This enables the system to efficiently pinpoint and retrieve relevant pieces of information. The quality of the chunks is **essential** to how effective your system will be.

The most **important** thing to consider when **deciding a chunking strategy** is the structure of the documents that you are loading into your vector database. If the documents contain similar-length paragraphs, it would be useful to consider this when determining the size of the chunk.

----

If your `chunk_size` is too large, then you will be inputting a lot of noisy, unwanted context into the final LLM query. Further, as LLMs are limited by their context window size, the larger the chunk, the fewer pieces of relevant information you can include as context to your query.
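
The trade-off can be made concrete with a little arithmetic. This is an illustrative sketch only: the function name and the budget of 4000 characters are assumptions, and real context windows are measured in tokens rather than characters.

```python
def max_chunks_in_budget(context_budget: int, chunk_size: int) -> int:
    """Number of full chunks of `chunk_size` characters that fit in `context_budget` characters."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return context_budget // chunk_size


# Hypothetical budget of 4000 characters reserved for retrieved context
for size in (200, 500, 2000):
    print(size, max_chunks_in_budget(4000, size))
# prints:
# 200 20
# 500 8
# 2000 2
```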

The `chunk_overlap` refers to intentionally duplicating tokens at the beginning and end of each adjacent chunk. This helps retain additional context from neighbouring chunks which can improve the quality of the information passed into the prompt.
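
To see what overlap means in practice, here is a naive fixed-window chunker written with plain Python. It is a simplified sketch, not how Langchain's splitters work (they split on separators first and then merge pieces); the function name `chunk_with_overlap` is hypothetical.

```python
def chunk_with_overlap(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    """Cut `text` into windows of `chunk_size` characters, each sharing
    `chunk_overlap` characters with the previous window."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in the range [0, chunk_size)")

    step = chunk_size - chunk_overlap  # how far the window advances each time
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the window already reached the end of the text
    return chunks


print(chunk_with_overlap("abcdefghij", chunk_size=4, chunk_overlap=2))
# prints: ['abcd', 'cdef', 'efgh', 'ghij']
```

Each chunk repeats the last two characters of its predecessor, which is the duplication described above.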

When deploying these systems to real-world applications, it is important to plot the distribution of text lengths in your data and to tune parameters such as `chunk_size` and `chunk_overlap` through experimentation.
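
A quick numeric summary is often enough to get started before plotting anything. The sketch below uses only the standard library; the helper name `summarize_lengths` is hypothetical and not part of the original notebook.

```python
import statistics


def summarize_lengths(texts: list[str]) -> dict[str, float]:
    """Summarize the character lengths of a list of texts."""
    lengths = sorted(len(text) for text in texts)
    if not lengths:
        raise ValueError("texts must not be empty")
    return {
        "count": len(lengths),
        "min": lengths[0],
        "median": statistics.median(lengths),
        "mean": statistics.mean(lengths),
        "max": lengths[-1],
    }
```

With the documents loaded earlier, this could be called as `summarize_lengths([doc.page_content for doc in document])`, and the same call on the resulting chunks shows how a given `chunk_size` reshapes the distribution.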

**Testing `PythonCodeTextSplitter`**

In [6]:
def split_text(documents: list[Document]):
    
    splitter = PythonCodeTextSplitter(
        chunk_size=200,    # Each chunk will be at most 200 characters
        chunk_overlap=100   # Overlap of up to 100 characters between chunks
    )

    # Split documents into smaller chunks using text splitter
    chunks = splitter.split_documents(documents)
    print(f"Split {len(documents)} documents into {len(chunks)} chunks.")

    # Print example of page content and metadata for a chunk
    document = chunks[0]
    print(document.page_content)
    print(document.metadata)

    return chunks # Return the list of split text chunks

## Create VectorDB (Chroma)

Experimentation with open-source models such as **Llama3** is encouraged (try the 8B parameter version first, especially if your system needs to be 100% local). If you are working on a very niche or nuanced use case, off-the-shelf embedding models may not be useful. In that case, you might want to investigate fine-tuning the embedding model on the domain data to improve the retrieval quality.
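
Retrieval quality depends on the embedding model because the vector store ranks chunks by how close their embeddings are to the query embedding. The sketch below illustrates that ranking with cosine similarity on made-up two-dimensional vectors. Cosine similarity is one common choice; the metric Chroma actually uses depends on how the collection is configured. All names and vectors here are hypothetical.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k(query_vec: list[float], chunk_vecs: dict[str, list[float]], k: int = 3) -> list[tuple[str, float]]:
    """Return the k chunk names whose vectors are most similar to the query vector."""
    scored = [(name, cosine_similarity(query_vec, vec)) for name, vec in chunk_vecs.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:k]


# Toy vectors standing in for real embeddings
toy_chunks = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
print([name for name, _ in top_k([1.0, 0.0], toy_chunks, k=2)])
# prints: ['a', 'c']
```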



In [7]:
# Path to the directory to save Chroma database
CHROMA_PATH = db_name

def save_to_chroma(chunks: list[Document]):
    """
    Save the given list of Document objects to a Chroma database.

    Args:
    chunks (list[Document]): List of Document objects representing text chunks to save.
    Returns:
    None
    """

    # Load environment variables
    load_dotenv()
    GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")

    # Ensure API key is loaded correctly
    if not GOOGLE_API_KEY:
        raise ValueError("Google API key not found in environment variables")

    # Clear out the existing database directory if it exists
    # (shutil.rmtree is synchronous, so no polling loop is needed afterwards)
    if os.path.exists(CHROMA_PATH):
        shutil.rmtree(CHROMA_PATH)  # Ensure all cached files are deleted

    os.makedirs(CHROMA_PATH, exist_ok=True)  # Create a fresh, empty directory

    # Create a new Chroma database from the documents using GEMINI embeddings
    
    try:
        db = Chroma.from_documents(
            chunks,
            GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=GOOGLE_API_KEY),
            persist_directory=CHROMA_PATH,
            # client_settings={"use_existing_db": False}  # Ensure a new database is created
        )
        # db.persist()  # Ensure the data is written properly
    except Exception as e:
        print(f"Failed to save to Chroma: {e}")
        return

    print(f"Saved {len(chunks)} chunks to {CHROMA_PATH}.")

In [8]:
def generate_data_store():
  """
  Function to generate vector database in chroma from documents.
  """
  documents = load_documents(repo_path) # Load documents from a source
  print("Document Loaded!")
  
  chunks = split_text(documents) # Split documents into manageable chunks
  print("Converted Document into Chunks!")
  
  save_to_chroma(chunks) # Save the processed data to a data store


# Load environment variables from a .env file
load_dotenv()
# Generate the data store
generate_data_store()


Document Loaded!
Split 7 documents into 91 chunks.
from dataclasses import dataclass
from typing import TypedDict


# Define a TypedDict to explicitly type the car_data dictionary
{'source': 'test/python-type-hinting-main/test.py'}
Converted Document into Chunks!
Saved 91 chunks to python-type-hinting-main.


## Query the VectorDB

In [9]:
query_text = "list down the class names and method names"

PROMPT_TEMPLATE = """
Answer the question based on the following context:
{context}
However, this context might contain incomplete statements, so infer the surrounding context where needed to provide insightful answers.

====

Answer the question based on the above context: {question}
"""

In [10]:
def query_rag(query_text):
  """
  Query a Retrieval-Augmented Generation (RAG) system using the Chroma database and Gemini.
  Args:
    - query_text (str): The text to query the RAG system with.
    
  Returns:
    - formatted_response (str): Formatted response including the generated text and sources.
    - response_text: The chat model's response message; the generated text is in its `.content` attribute.
  """
  
  load_dotenv()
  GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")
  
  # The embedding function must be the same one used to build the database
  embedding_function = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=GOOGLE_API_KEY)

  # Prepare the database
  db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)
  
  # Retrieving the context from the DB using similarity search
  results = db.similarity_search_with_relevance_scores(query_text, k=3)

  # Check if there are any matching results or if the relevance score is too low
  # if len(results) == 0 or results[0][1] < 0.7:
  #   print(f"Unable to find matching results.")

  # Combine context from matching documents
  context_text = "\n\n - -\n\n".join([doc.page_content for doc, _score in results])
  
  print("CONTEXT TEXT: \n", context_text)
 
  # Create prompt template using context and query text
  prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
  prompt = prompt_template.format(context=context_text, question=query_text)
  
  # Initialize the Gemini chat model
  model = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.5, max_tokens=1048, google_api_key=GOOGLE_API_KEY)

  # Generate response text based on the prompt
  response_text = model.invoke(prompt)  # model.predict() is deprecated
 
  # Get sources of the matching documents
  sources = [doc.metadata.get("source", None) for doc, _score in results]
 
  # Format and return response including generated text and sources
  formatted_response = f"\nResponse: {response_text.content}\nSources: {sources}"
  return formatted_response, response_text
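
The commented-out relevance check in `query_rag` can be turned into a small filtering step. The sketch below is one possible way to do it, assuming the results are `(document, score)` pairs like those returned by `similarity_search_with_relevance_scores`. The helper name is hypothetical, the 0.7 threshold is taken from the commented-out line, and plain strings stand in for `Document` objects. Unlike the original check, which only inspects the top result, this version filters every result.

```python
from typing import Any


def filter_by_relevance(results: list[tuple[Any, float]], min_score: float = 0.7) -> list[tuple[Any, float]]:
    """Keep only the (document, score) pairs whose score is at least `min_score`."""
    return [(doc, score) for doc, score in results if score >= min_score]


# Plain strings stand in for Document objects in this toy example
toy_results = [("chunk A", 0.82), ("chunk B", 0.71), ("chunk C", 0.40)]
relevant = filter_by_relevance(toy_results)
if not relevant:
    print("Unable to find matching results.")
else:
    print([doc for doc, _score in relevant])
# prints: ['chunk A', 'chunk B']
```

A threshold like this usually needs tuning per embedding model, since score ranges differ between models.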

In [11]:
formatted_response, response_text = query_rag(query_text)



CONTEXT TEXT: 
 # Lets create a class for our showroom

 - -

# Solution: Import classes used only for annotations when
# typing.TYPE_CHECKING is True (i.e. not at runtime).
if typing.TYPE_CHECKING:
    from b_classes import Showroom

 - -

# We can annotate types outside function definitions
car_models: list[str] = ["bmw", "mercedes", "ferrari"]
car_counts: list[int] = [1, 2, 54]


In [12]:
response_text.pretty_print()  # pretty_print() prints the message itself and returns None



The provided context only mentions one class:

* **Showroom:** This class is imported from the module `b_classes`. However, the context doesn't provide any information about its methods.

Therefore, we can only list the class name: **Showroom**.


In [13]:
with open('output2.md', 'w') as f:
    f.write(response_text.content)  # .content is a single string, so write() is enough