In [1]:
import argparse
import glob
import os
import re
import shutil
import sys

def _read_google_api_key(env_path: str = "./.env") -> str:
    """Return the Google API key stored in ``env_path``.

    The file may contain either just the bare key (the format this notebook
    originally assumed) or dotenv-style ``NAME=value`` lines. In the latter case
    the value of the ``GOOGLE_API_KEY`` line is returned. Reading a dotenv file
    verbatim would otherwise set the key to the literal text
    ``GOOGLE_API_KEY=...``, which the API would presumably reject.

    Raises:
        FileNotFoundError: If ``env_path`` does not exist.
    """
    with open(env_path, "r", encoding="utf-8") as key_file:
        contents = key_file.read().strip()
    for line in contents.splitlines():
        name, sep, value = line.partition("=")
        if sep and name.strip() == "GOOGLE_API_KEY":
            # Tolerate optional surrounding quotes, e.g. GOOGLE_API_KEY="abc".
            return value.strip().strip("\"'")
    # No dotenv-style line found: treat the whole file as the bare key.
    return contents


# An already-exported key wins; only fall back to the local key file when it is unset.
if "GOOGLE_API_KEY" not in os.environ:
    os.environ["GOOGLE_API_KEY"] = _read_google_api_key()

In [2]:
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
from langchain_community.document_loaders import TextLoader
import os
from tqdm import tqdm

DATA_PATH = "./data/"

def load_pdfs():
    """Load every PDF that sits directly inside DATA_PATH, one file at a time.

    Files are loaded individually rather than with a single directory-wide load.
    This gives a per-file tqdm progress bar, and one unreadable PDF does not abort
    the run: its error is printed and the remaining files are still loaded.

    Returns:
        A flat list of the Documents produced for every successfully loaded PDF.

    Raises:
        FileNotFoundError: If DATA_PATH does not exist.
    """
    if not os.path.exists(DATA_PATH):
        raise FileNotFoundError(f"The directory '{DATA_PATH}' does not exist.")

    documents = []
    # Sorted so the load order (and therefore progress output) is reproducible.
    for file_name in tqdm(sorted(os.listdir(DATA_PATH))):
        # Case-insensitive so files named e.g. "Notes.PDF" are not silently skipped.
        if not file_name.lower().endswith(".pdf"):
            continue
        try:
            # PyPDFDirectoryLoader's second argument is a *glob pattern*, not a file
            # name. Escape it so names containing "[", "*" or "?" match only
            # themselves. Staying on the directory loader (instead of switching
            # loaders) avoids any change to how metadata["source"] is spelled, which
            # calculate_chunk_ids uses to build the IDs already stored in Chroma.
            loader = PyPDFDirectoryLoader(DATA_PATH, glob=glob.escape(file_name))
            documents.extend(loader.load())
        except Exception as e:
            print(f"Failed to process {file_name}: {e}")
    return documents

def load_text(filename: str = "5400Notes.txt"):
    """Load one plain-text file from DATA_PATH with LangChain's TextLoader.

    Args:
        filename: File name inside DATA_PATH. The default is the notes file this
            notebook previously hard-coded, so ``load_text()`` behaves as before.

    Returns:
        The Documents returned by ``TextLoader.load()``.
    """
    # Build the path from DATA_PATH instead of repeating "./data/", so both loaders
    # follow the same configured directory. With the default arguments this yields
    # exactly "./data/5400Notes.txt", as before.
    return TextLoader(os.path.join(DATA_PATH, filename)).load()

In [3]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.schema.document import Document

def split_pdfs(
    documents: list[Document],
    chunk_size: int = 800,
    chunk_overlap: int = 80,
):
    """Split loaded documents into overlapping character chunks for embedding.

    Args:
        documents: Documents to split, e.g. the output of load_pdfs or load_text.
        chunk_size: Target maximum chunk length, measured with ``len``
            (i.e. in characters).
        chunk_overlap: Characters shared between consecutive chunks, to keep some
            context across chunk boundaries.

    Returns:
        A list of chunk Documents. calculate_chunk_ids later reads each chunk's
        ``source``/``page`` metadata, which the splitter carries over from the
        parent document.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )
    return text_splitter.split_documents(documents)

In [9]:
# Vector Database
from langchain_chroma import Chroma

# Directory of the persistent Chroma store, relative to the current working directory.
CHROMA_PATH = "chroma"

def calculate_chunk_ids(chunks):
    """Give every chunk a deterministic ``"<source>:<page>:<n>"`` ID in ``metadata["id"]``.

    ``n`` counts 0, 1, 2, ... across consecutive chunks sharing the same
    source/page pair. It restarts at 0 whenever that pair changes, e.g.
    ``"data/notes.pdf:6:2"`` is the third chunk of page 6 of data/notes.pdf.
    Chunks are updated in place and the same list is returned.
    """
    previous_page_key = None
    position_on_page = 0

    for chunk in tqdm(chunks):
        metadata = chunk.metadata
        page_key = "{}:{}".format(metadata.get("source"), metadata.get("page"))

        # Continue numbering on the same page; otherwise start over.
        position_on_page = position_on_page + 1 if page_key == previous_page_key else 0

        metadata["id"] = f"{page_key}:{position_on_page}"
        previous_page_key = page_key

    return chunks


def add_to_chroma(chunks: list[Document]):
    """Store chunks in the Chroma DB at CHROMA_PATH, skipping any already present.

    Each chunk first receives a deterministic ID from calculate_chunk_ids. Chunks
    whose ID is already stored are left out, so re-running the indexing cell only
    embeds new chunks. Existing/added counts are printed.

    Args:
        chunks: Chunk Documents, e.g. from split_pdfs. Their metadata is modified
            in place (an ``"id"`` key is added).
    """
    vector_store = Chroma(
        persist_directory=CHROMA_PATH,
        embedding_function=get_embedding_function(),
    )

    identified_chunks = calculate_chunk_ids(chunks)

    # include=[] requests no documents/embeddings/metadata; IDs are returned regardless.
    stored_ids = set(vector_store.get(include=[])["ids"])
    print(f"Number of existing documents in DB: {len(stored_ids)}")

    pending = [
        chunk
        for chunk in tqdm(identified_chunks)
        if chunk.metadata["id"] not in stored_ids
    ]

    if not pending:
        print("No new documents to add")
        return

    print(f"Adding new documents: {len(pending)}")
    vector_store.add_documents(
        pending, ids=[chunk.metadata["id"] for chunk in pending]
    )


def clear_database(path=None):
    """Delete an on-disk Chroma store directory if it exists; otherwise do nothing.

    Use this before re-indexing from scratch: add_to_chroma skips chunks whose IDs
    are already stored, so edited documents are otherwise not re-embedded.

    Args:
        path: Directory to delete. ``None`` (the default) means CHROMA_PATH, the
            store that add_to_chroma writes to.
    """
    # Resolved at call time so the function does not need CHROMA_PATH at definition time.
    target = CHROMA_PATH if path is None else path
    if os.path.exists(target):
        shutil.rmtree(target)

In [5]:
# from langchain_community.embeddings.ollama import OllamaEmbeddings
# from langchain_community.embeddings.bedrock import BedrockEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings

# Embedding Function: Used when creating the DB, or making a query.
def get_embedding_function():
    """Build the embedding model used both to populate the Chroma DB and to embed queries.

    Indexing and querying go through this single function so they use the same
    model. Earlier versions of this cell had commented-out Bedrock (AWS) and Ollama
    (``nomic-embed-text``) alternatives. Swapping models would presumably require
    rebuilding the store at CHROMA_PATH, since stored vectors came from this model.
    """
    return GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")

In [None]:
# Build/refresh the vector store: load the PDFs, chunk them, then store any new chunks.
# NOTE(review): load_text() is defined above, but its notes file is never indexed
# here — confirm that is intended.
pdfs = load_pdfs()
chunks = split_pdfs(pdfs)
add_to_chroma(chunks)

In [25]:
# NOTE(review): `query_rag` is not defined in any cell visible in this notebook. The
# saved output below presumably came from a since-deleted cell, so Restart & Run All
# would raise NameError here. TODO: restore its definition in a cell above this one.
query_rag("What's the differences between eq? and equal?")

'Based on the provided text, `eq?` checks for pointer equality at a low level, exposing implementation details of Racket and contracts.  `eqv?` (and therefore `equal?`) compares numbers considering both exactness and numerical equality, unlike `=` which converts inexact numbers to exact before comparing.  For characters, `eqv?` and `equal?` behave the same as `char=?`, which performs a case-sensitive comparison.  `char-ci=?` ignores case.  If `eq?` returns true, the two values behave identically in all respects.\n'

In [None]:
from langchain_community.llms.ollama import Ollama

# Grading prompt for the LLM judge used by query_and_validate. `{expected_response}`
# and `{actual_response}` are filled in with str.format; the judge's reply is then
# inspected for 'true' / 'false'.
EVAL_PROMPT = """
Expected Response: {expected_response}
Actual Response: {actual_response}
---
(Answer with 'true' or 'false') Does the actual response match the expected response? 
"""


def test_lambda_rules():
    """Check, via the LLM judge, the RAG answer on how to write a Racket lambda."""
    # NOTE(review): "(lambda())" has no body, which is not a valid Racket lambda;
    # a fuller expected answer such as "(lambda (x) body)" may grade more reliably.
    matches_expected = query_and_validate(
        expected_response="(lambda())",
        question="How to write a lambda in Racket?",
    )
    assert matches_expected


def test_ticket_to_ride_rules():
    """Check, via the LLM judge, that the RAG answer names ``eq?`` as similar to ``equal?``.

    NOTE(review): the name appears to be left over from an unrelated example; the
    test is about Racket equality predicates. It is kept unchanged so existing test
    selections by name still work.
    """
    matches_expected = query_and_validate(
        expected_response="eq?",
        question="What function is similar to equal?",
    )
    assert matches_expected


def _parse_verdict(evaluation: str) -> "bool | None":
    """Return the first standalone 'true'/'false' in ``evaluation`` as a bool, or None if absent."""
    # Whole-word and first-occurrence: a reply like "false - it is not true that..."
    # must count as false, and words such as "untrue" must not count at all.
    match = re.search(r"\b(true|false)\b", evaluation)
    if match is None:
        return None
    return match.group(1) == "true"


def query_and_validate(question: str, expected_response: str):
    """Ask the RAG pipeline ``question`` and have a local LLM judge the answer.

    The answer from ``query_rag`` (defined elsewhere) is inserted into EVAL_PROMPT
    together with ``expected_response``. The prompt is sent to the Ollama
    ``mistral`` model, so a local Ollama server with that model is assumed. The
    prompt is printed, followed by the judge's reply colored green (pass) or red
    (fail).

    Returns:
        True if the judge's first 'true'/'false' is 'true', False if it is 'false'.

    Raises:
        ValueError: If the judge's reply contains neither word.
    """
    response_text = query_rag(question)
    prompt = EVAL_PROMPT.format(
        expected_response=expected_response, actual_response=response_text
    )

    model = Ollama(model="mistral")
    evaluation_results_str = model.invoke(prompt)
    evaluation_results_str_cleaned = evaluation_results_str.strip().lower()

    print(prompt)

    verdict = _parse_verdict(evaluation_results_str_cleaned)
    if verdict is True:
        # Print response in green if it is correct.
        print("\033[92m" + f"Response: {evaluation_results_str_cleaned}" + "\033[0m")
        return True
    if verdict is False:
        # Print response in red if it is incorrect.
        print("\033[91m" + f"Response: {evaluation_results_str_cleaned}" + "\033[0m")
        return False
    raise ValueError(
        "Invalid evaluation result. Cannot determine if 'true' or 'false': "
        f"{evaluation_results_str_cleaned!r}"
    )

In [None]:
def cli(argv=None):
    """Command-line entry point: answer a single query with the RAG pipeline.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the default)
            means ``sys.argv[1:]``, exactly as before. Pass an explicit list when
            calling from a notebook, where sys.argv typically holds the Jupyter
            kernel's own launch flags.

    Returns:
        Whatever ``query_rag`` returns for the query. Previously this was discarded;
        it is now returned so callers (and notebook cells) can use it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("query_text", type=str, help="The query text.")
    args = parser.parse_args(argv)
    return query_rag(args.query_text)

In [None]:
# argparse reads sys.argv, which inside a Jupyter kernel holds the kernel's own launch
# flags (e.g. `-f <connection-file>`), so calling cli() here would exit with a usage
# error. Only dispatch to the CLI when this code runs outside a kernel, e.g. after
# exporting the notebook to a script.
if "ipykernel" not in sys.modules:
    cli()