In [1]:
%load_ext kedro.ipython

In [2]:
import json
import os
import glob
import re
from io import BytesIO
import uuid
from pathlib import Path

import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions
from kedro.config import OmegaConfigLoader
from kedro.framework.project import settings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from pypdf import PdfReader
from langchain.docstore.document import Document
from langchain_openai import OpenAIEmbeddings
from tqdm import tqdm

In [3]:
# Resolve the Kedro config source (settings.CONF_SOURCE) relative to the parent
# of the notebook's working directory, then load the credentials from it.
conf_path = str(Path.cwd().parent / settings.CONF_SOURCE)
conf_loader = OmegaConfigLoader(conf_source=conf_path)
credentials = conf_loader["credentials"]
credentials.keys()  # show only the credential *names*, never their values

dict_keys(['OPENAI_API_KEY'])

In [4]:
# Which OpenAI embedding model to use is a Kedro parameter.
embedding_model_name = catalog.load("params:embedding_model_name")

# The API key comes from the credentials loaded by the config loader.
OPENAI_API_KEY = credentials["OPENAI_API_KEY"]

# LangChain embedding wrapper, shared by the indexing and querying cells.
embedding_model = OpenAIEmbeddings(
    openai_api_key=OPENAI_API_KEY,
    model=embedding_model_name,
)

In [5]:
# List every dataset and parameter entry registered in the Kedro data catalog.
catalog.list()


[1m[[0m
    [32m'docs_dict'[0m,
    [32m'pdfs_dict'[0m,
    [32m'parameters'[0m,
    [32m'params:vector_db'[0m,
    [32m'params:vector_db.path'[0m,
    [32m'params:vector_db.collection_name'[0m,
    [32m'params:websites'[0m,
    [32m'params:pdfs_dir_path'[0m,
    [32m'params:splitter'[0m,
    [32m'params:splitter.chunk_size'[0m,
    [32m'params:splitter.chunk_overlap'[0m,
    [32m'params:splitter.separators'[0m,
    [32m'params:embedding_model_name'[0m
[1m][0m

In [6]:
# Vector-store settings: the on-disk location Chroma persists to and the
# name of the collection holding the PDF chunks.
db_params = catalog.load("params:vector_db")
db_path, collection_name = db_params["path"], db_params["collection_name"]

In [7]:
# Text-splitter settings, forwarded to RecursiveCharacterTextSplitter when
# the PDFs are chunked.
splitter_params = catalog.load("params:splitter")
chunk_size, chunk_overlap, separators = (
    splitter_params[key] for key in ("chunk_size", "chunk_overlap", "separators")
)

In [8]:
# Directory holding the PDFs to index. The next cell joins it onto ".." before
# globbing for "*.pdf", so it is presumably relative to the project root.
pdfs_dir_path = catalog.load("params:pdfs_dir_path")

In [9]:
# Every PDF in the configured directory, relative to the notebook's working dir.
# glob returns files in arbitrary (filesystem-dependent) order, so sort the list
# to make indexing order reproducible across machines and runs.
# NOTE: these exact strings become the "source" metadata stored in Chroma and are
# compared against it when looking for new PDFs, so keep their form stable.
pdfs_paths = sorted(glob.glob(os.path.join("..", pdfs_dir_path, "*.pdf")))
pdfs_paths

[1m[[0m[32m'../data/01_raw/pdfs/gdm---an-update-on-screening-diagnosis-and-follow-up-[0m[32m([0m[32mmay-2018[0m[32m)[0m[32m.pdf'[0m[1m][0m

In [10]:
def parse_pdf(file: BytesIO) -> list[str]:
    source = file
    pdf = PdfReader(source)
    output = []
    for page in pdf.pages:
        text = page.extract_text()
        # Merge hyphenated words
        text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text)
        # Fix newlines in the middle of sentences (use negative look behind and look ahead)
        text = re.sub(r"(?<!\n\s)\n(?!\n\s)", " ", text.strip())
        # Remove multiple newlines
        text = re.sub(r"\n\s*\n", "\n\n", text)
        output.append(text)

    return output, source


def text_to_docs(
    text: str | list[str],
    source: str,
    chunk_size: int,
    chunk_overlap: int,
    separators: list[str],
) -> list[Document]:
    """Wrap page texts in Documents and split them into overlapping chunks.

    Every page becomes a ``Document`` whose metadata records ``source`` and a
    1-based ``page`` number before the documents are handed to the splitter.

    Args:
        text: One string per page, or a single string treated as one page.
        source: Provenance stored under the ``"source"`` metadata key.
        chunk_size: Forwarded to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: Forwarded to ``RecursiveCharacterTextSplitter``.
        separators: Forwarded to ``RecursiveCharacterTextSplitter``.

    Returns:
        The list produced by ``splitter.split_documents``.
    """
    pages = [text] if isinstance(text, str) else text

    page_docs = [
        Document(page_content=content, metadata={"source": source, "page": page_no})
        for page_no, content in enumerate(pages, start=1)
    ]

    splitter = RecursiveCharacterTextSplitter(
        separators=separators,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(page_docs)


def pdfs_to_docs(
    pdfs_paths: list[str], chunk_size: int, chunk_overlap: int, separators: list[str]
) -> tuple[list[Document], list[dict]]:
    """Parse and chunk every PDF in ``pdfs_paths``.

    Args:
        pdfs_paths: PDF file paths; each path is passed to ``parse_pdf`` and
            becomes the ``source`` of the chunks produced from it.
        chunk_size: Splitter chunk size, forwarded to ``text_to_docs``.
        chunk_overlap: Splitter chunk overlap, forwarded to ``text_to_docs``.
        separators: Splitter separators, forwarded to ``text_to_docs``.

    Returns:
        ``(documents, documents_as_dicts)``: the chunk ``Document`` objects and
        the same chunks converted with ``dict()`` for saving as JSON. Both lists
        are empty when ``pdfs_paths`` is empty.
    """
    all_data_splits = []

    # The progress description never changes, so set it once up front.
    for pdf_path in tqdm(pdfs_paths, desc="Parsing and splitting PDF into document chunks"):
        pages, source = parse_pdf(pdf_path)
        all_data_splits.extend(
            text_to_docs(pages, source, chunk_size, chunk_overlap, separators)
        )

    # Convert to a JSON-serialisable format. Assumes dict(doc) yields the
    # Document's fields as (name, value) pairs with JSON-friendly values — TODO confirm.
    pdfs_dict = [dict(ds) for ds in all_data_splits]

    return all_data_splits, pdfs_dict

In [11]:
# Persistent Chroma client for the vector-store directory
# (db_path is resolved relative to the parent of the working directory).
client = chromadb.Client(
    Settings(
        is_persistent=True,
        persist_directory=str(Path(os.getcwd()).parent / db_path),
    )
)

# Names of the collections already in the store. Some chromadb versions
# (e.g. 0.6.x) return plain name strings from list_collections() while others
# return Collection objects, so accept either form.
collections = [getattr(c, "name", c) for c in client.list_collections()]
collections

[1m[[0m[32m'healthcare'[0m[1m][0m

In [12]:
# If the collection doesn't exist yet, create it and index every PDF.
if collection_name not in collections:
    print(
        f"Collection: {collection_name} does not exist. Creating collection and indexing all documents."
    )
    all_data_splits, pdfs_dict = pdfs_to_docs(
        pdfs_paths, chunk_size, chunk_overlap, separators
    )

    if all_data_splits:
        # Reuse the client opened above instead of letting LangChain open a
        # second client on the same persist directory.
        db = Chroma.from_documents(
            all_data_splits,
            embedding_model,
            collection_name=collection_name,
            client=client,
        )
    else:
        # Chroma cannot add an empty batch; nothing was extracted to index.
        print("No text chunks were extracted from the PDFs; nothing to index.")

    # Save the chunk dicts through the Kedro catalog's "pdfs_dict" dataset
    catalog.save("pdfs_dict", pdfs_dict)

In [13]:
# If the collection exists, check whether any PDFs are not indexed yet and,
# if so, add their chunks to the collection.
if collection_name in collections:
    print(
        f"Collection: {collection_name} already exists. Checking for new documents to index into collection."
    )
    collection = client.get_collection(name=collection_name)
    # Sources (PDF paths) already in the collection. Only metadata is needed,
    # so don't fetch the stored document texts; skip entries without a source.
    existing_metadatas = collection.get(include=["metadatas"])["metadatas"] or []
    sources = {
        metadata["source"]
        for metadata in existing_metadatas
        if metadata and "source" in metadata
    }

    # Keep only PDFs that do not already appear in the collection
    # (we do not want to index the same PDF twice)
    new_pdfs = [pdf_path for pdf_path in pdfs_paths if pdf_path not in sources]

    if new_pdfs:
        print(f"Indexing all {len(new_pdfs)} new documents into collection.")
        all_data_splits, new_pdfs_dict = pdfs_to_docs(
            new_pdfs, chunk_size, chunk_overlap, separators
        )

        # Load the JSON already saved so it can be extended with the new chunks;
        # start from an empty list if it cannot be loaded. Catch Exception (not a
        # bare except) so KeyboardInterrupt/SystemExit still propagate.
        try:
            pdfs_dict = catalog.load("pdfs_dict")
        except Exception as exc:
            print(f"Could not load existing pdfs_dict ({exc!r}); starting from an empty list.")
            pdfs_dict = []

        print(f"Before updating: {len(pdfs_dict)}")
        pdfs_dict.extend(new_pdfs_dict)
        print(f"After updating: {len(pdfs_dict)}")

        # Write the updated chunk list directly to the intermediate JSON file.
        # NOTE(review): the create-collection cell uses catalog.save("pdfs_dict", ...)
        # instead — confirm this path matches the "pdfs_dict" catalog entry.
        with open(
            os.path.join("..", "data", "02_intermediate", "pdfs.json"),
            "w",
            encoding="utf-8",
        ) as f:
            json.dump(pdfs_dict, f, ensure_ascii=False, indent=4)

        if all_data_splits:
            documents = [ds.page_content for ds in all_data_splits]
            metadatas = [ds.metadata for ds in all_data_splits]
            # Embed with the same LangChain embedding model used to create and
            # query the collection; embed_documents batches the API requests.
            embeddings = embedding_model.embed_documents(documents)
            ids = [str(uuid.uuid4()) for _ in documents]

            collection.add(
                documents=documents, embeddings=embeddings, metadatas=metadatas, ids=ids
            )
        else:
            # Chroma rejects an empty add; the new PDFs produced no text chunks.
            print("No text chunks were extracted from the new PDFs; nothing to index.")

    else:
        print("There are no new documents to index.")

Collection: healthcare already exists. Checking for new documents to index into collection.
There are no new documents to index.


In [None]:
# Sanity-check retrieval: wrap the existing collection in a LangChain vector
# store backed by the same client and embedding model, then fetch the 3 chunks
# most similar to a sample question.
query = "What complications can I prevent from regular checkups?"

db = Chroma(
    embedding_function=embedding_model,
    collection_name=collection_name,
    client=client,
)

docs = db.similarity_search(query=query, k=3)
docs