In [1]:
import os
import sys
import glob
from typing import List
from dotenv import load_dotenv
import argparse
from multiprocessing import Pool
from tqdm import tqdm
from langchain.vectorstores import Chroma

from langchain.document_loaders import (
    CSVLoader,
    PDFMinerLoader,
    TextLoader,
    UnstructuredHTMLLoader,
    UnstructuredWordDocumentLoader,
    Docx2txtLoader
)

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from db_config import CHROMA_SETTINGS


load_dotenv()

LOADER_MAPPING = {
    ".csv": (CSVLoader, {}),
    ".docx": (Docx2txtLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
}


def load_single_document(file_path: str) -> Document:
    ext = "." + file_path.rsplit(".", 1)[-1]
    if ext in LOADER_MAPPING:
        loader_class, loader_args = LOADER_MAPPING[ext]
        loader = loader_class(file_path, **loader_args)
        return loader.load()[0]

    raise ValueError(f"Unsupported file extension '{ext}'")


def load_documents(source_dir: str) -> List[Document]:
    all_files = []
    for ext in LOADER_MAPPING:
        all_files.extend(
            glob.glob(os.path.join(source_dir, f"**/*{ext}"), recursive=True)
        )

    with Pool(processes=os.cpu_count()) as pool:
        results = []
        with tqdm(total=len(all_files), desc='Loading documents', ncols=80) as pbar:
            for i, doc in enumerate(pool.imap_unordered(load_single_document, all_files)):
                results.append(doc)
                pbar.update()

    return results


def does_vectorstore_exist(persist_directory: str) -> bool:
    if os.path.exists(os.path.join(persist_directory, 'index')):
        if os.path.exists(os.path.join(persist_directory, 'chroma-collections.parquet')) and os.path.exists(os.path.join(persist_directory, 'chroma-embeddings.parquet')):
            list_index_files = glob.glob(os.path.join(persist_directory, 'index/*.bin'))
            list_index_files += glob.glob(os.path.join(persist_directory, 'index/*.pkl'))
            if len(list_index_files) > 3:
                return True
    return False


def main(collection):
    persist_directory = os.environ.get('PERSIST_DIRECTORY')
    source_directory = os.environ.get('SOURCE_DIRECTORY', 'source_documents')
    embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
    os.makedirs(source_directory, exist_ok=True)

    print(f"Loading documents from {source_directory}")
    chunk_size = 500
    chunk_overlap = 50
    documents = load_documents(source_directory)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    texts = text_splitter.split_documents(documents)
    print(f"Loaded {len(documents)} documents from {source_directory}")
    print(f"Split into {len(texts) if texts else 0} chunks of text (max. {chunk_size} characters each)")

    embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)

    if does_vectorstore_exist(persist_directory):
        print(f"Appending to existing vectorstore at {persist_directory}")
        db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS)
        db.add_documents(texts)
    else:
        print("Creating new vectorstore")
        db = Chroma.from_documents(texts, embeddings, collection_name=collection, persist_directory=persist_directory, client_settings=CHROMA_SETTINGS)

    db.persist()

    from chromaviz import visualize_collection
    visualize_collection(db._collection)


if __name__ == "__main__":
    try:
        parser = argparse.ArgumentParser()
        parser.add_argument("--collection", help="Saves the embedding in a collection name as specified")
        args = parser.parse_args()
        main(args.collection)
    except Exception as e:
        print(f"Error: {str(e)}")
        sys.exit(1)


usage: ipykernel_launcher.py [-h] [--collection COLLECTION]
ipykernel_launcher.py: error: unrecognized arguments: -f C:\Users\shaund\AppData\Roaming\jupyter\runtime\kernel-abb3ddea-8ad6-40b0-af3c-72e33845b8c6.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [2]:
from langchain.document_loaders import (
    CSVLoader,
    PDFMinerLoader,
    TextLoader,
    UnstructuredHTMLLoader,
    UnstructuredWordDocumentLoader,
    Docx2txtLoader
)

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from db_config import CHROMA_SETTINGS


load_dotenv()


# Map file extensions to document loaders and their arguments
LOADER_MAPPING = {
    ".csv": (CSVLoader, {}),
    ".docx": (Docx2txtLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
}


def load_single_document(file_path: str) -> Document:
    ext = "." + file_path.rsplit(".", 1)[-1]
    if ext in LOADER_MAPPING:
        loader_class, loader_args = LOADER_MAPPING[ext]
        loader = loader_class(file_path, **loader_args)
        return loader.load()[0]

    raise ValueError(f"Unsupported file extension '{ext}'")


In [3]:
def load_documents(source_dir: str) -> List[Document]:
    # Loads all documents from source documents directory
    all_files = []
    for ext in LOADER_MAPPING:
        all_files.extend(
            glob.glob(os.path.join(source_dir, f"**/*{ext}"), recursive=True)
        )
    return [load_single_document(file_path) for file_path in all_files]

In [None]:
persist_directory = os.environ.get('PERSIST_DIRECTORY')
source_directory = os.environ.get('SOURCE_DIRECTORY', 'source_documents')
embeddings_model_name = os.environ.get('EMBEDDINGS_MODEL_NAME')
os.makedirs(source_directory, exist_ok=True)
# Load documents and split in chunks
print(f"Loading documents from {source_directory}")
chunk_size = 500
chunk_overlap = 50
documents = load_documents(source_directory)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
texts = text_splitter.split_documents(documents)
print(f"Loaded {len(documents)} documents from {source_directory}")
print(f"Split into {len(texts)} chunks of text (max. {chunk_size} characters each)")

# Create embeddings
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)

# Create and store locally vectorstore
db = Chroma.from_documents(texts, embeddings, persist_directory=persist_directory, client_settings=CHROMA_SETTINGS)
db.persist()

# print(f"Number of generated embeddings: {embeddings}")




In [5]:
# Get all embeddings
db._collection.get(include=['embeddings'])

# Get embeddings by document_id


{'ids': ['9e04acba-37e8-11ee-a2fe-e0d04558c3ab',
  '9e04acbb-37e8-11ee-9e58-e0d04558c3ab',
  '9e04acbc-37e8-11ee-bf94-e0d04558c3ab',
  '9e04acbd-37e8-11ee-8a23-e0d04558c3ab',
  '9e04acbe-37e8-11ee-ac49-e0d04558c3ab',
  '9e04acbf-37e8-11ee-be16-e0d04558c3ab',
  '9e04acc0-37e8-11ee-b7a8-e0d04558c3ab',
  '9e04acc1-37e8-11ee-9d66-e0d04558c3ab',
  '9e04acc2-37e8-11ee-a029-e0d04558c3ab',
  '9e04acc3-37e8-11ee-9941-e0d04558c3ab',
  '9e04acc4-37e8-11ee-b95a-e0d04558c3ab',
  '9e04acc5-37e8-11ee-a28a-e0d04558c3ab',
  '9e04acc6-37e8-11ee-8481-e0d04558c3ab',
  '9e04acc7-37e8-11ee-a3f9-e0d04558c3ab',
  '9e04acc8-37e8-11ee-ba6d-e0d04558c3ab',
  '9e04acc9-37e8-11ee-b507-e0d04558c3ab',
  '9e04acca-37e8-11ee-8a82-e0d04558c3ab',
  '9e04accb-37e8-11ee-8f8b-e0d04558c3ab',
  '9e04accc-37e8-11ee-8f41-e0d04558c3ab',
  '9e04accd-37e8-11ee-b15c-e0d04558c3ab',
  '9e04acce-37e8-11ee-9355-e0d04558c3ab',
  '9e04accf-37e8-11ee-bfe2-e0d04558c3ab',
  '9e04acd0-37e8-11ee-8577-e0d04558c3ab',
  '9e04acd1-37e8-11ee-a343-

In [15]:
db._collection.get(ids=['9e04acbb-37e8-11ee-9e58-e0d04558c3ab',
  'a4538379-37eb-11ee-9e3f-e0d04558c3ab',
  'a453837a-37eb-11ee-b7dc-e0d04558c3ab',
  'a453837b-37eb-11ee-87d2-e0d04558c3ab',

  'a45383db-37eb-11ee-9e09-e0d04558c3ab'
], include=['embeddings'])

{'ids': ['9e04acbb-37e8-11ee-9e58-e0d04558c3ab',
  'a4538379-37eb-11ee-9e3f-e0d04558c3ab',
  'a453837a-37eb-11ee-b7dc-e0d04558c3ab',
  'a453837b-37eb-11ee-87d2-e0d04558c3ab',
  'a45383db-37eb-11ee-9e09-e0d04558c3ab'],
 'embeddings': None,
 'metadatas': None,
 'documents': None}

In [9]:
from chromaviz import visualize_collection
visualize_collection(db._collection)

In [11]:
print(db._collection.count())


650
