In [1]:
# !pip install chromadb sentence_transformers

In [2]:
from pathlib import Path

DATA_DIR = Path("../data")
SNAPSHOTS_DIR = DATA_DIR / "platform-docs-snapshots"
VERSIONS_DIR = DATA_DIR / "platform-docs-versions"

# hparams
chunk_size = 1024
chunk_overlap = chunk_size // 10
embedder_name = "BAAI/bge-small-en-v1.5" # https://huggingface.co/spaces/mteb/leaderboard

# Pre-processing

In [3]:
from langchain_community.document_loaders import UnstructuredMarkdownLoader, DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

docs = DirectoryLoader(VERSIONS_DIR, glob="[!.]*/[!.]*.md", loader_cls=TextLoader).load()
docs = [d for d in docs if Path(d.metadata['source']) != VERSIONS_DIR / "README.md"]

In [4]:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
docs_split = text_splitter.split_documents(docs)

In [5]:
def format_document_source(document):
    source = Path(document[0].metadata['source'])
    return f"{source.parts[-2]}: {source.stem}".replace("-", " ")

# Build index

In [6]:
import uuid
from tqdm.notebook import tqdm
import chromadb
from chromadb.utils import embedding_functions
import re

# Chroma uses all-MiniLM-L6-v2 by default
chroma_client = chromadb.PersistentClient()
 
def format_model_name(name):
    # chromaDB only allows these characters
    return re.sub(r'[^a-zA-Z0-9_-]', '_', name)

collection_name = f"DSA_{format_model_name(embedder_name)}"

try:
    collection = chroma_client.get_collection(collection_name)
except ValueError:
    print("Building collection...")
    embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(embedder_name=embedder_name)

    collection = chroma_client.create_collection(name=collection_name, embedding_function=embedding_fn)
    
    text = [d.page_content for d in docs_split]
    metadatas = [d.metadata for d in docs_split]
    ids = [uuid.uuid4().hex for _ in range(len(docs_split))]
    
    collection.add(
        documents=text,
        metadatas=metadatas,
        ids=ids,
    )

In [7]:
results = collection.query(
    query_texts=["How does the TikTok API work?"],
    n_results=5,
)

results

{'ids': [['beba62a4cde34a21bf0d417cb1e5de3b',
   '34c54707c3f0432b9115dffc66f20f27',
   '0ddfb6efa24f4ac7bbc0ea3fb20aa9d4',
   '90eb6b77692e45179702e3c8735f265d',
   '1cf2cf83ee5f4e91a18dc12bdd1e4478']],
 'distances': [[1.585018277168274,
   1.5988060235977173,
   1.6000548601150513,
   1.6002038717269897,
   1.6082261800765991]],
 'metadatas': [[{'source': '../data/platform-docs-versions/X_Twitter-API-Enterprise/Account Activity API.md'},
   {'source': '../data/platform-docs-versions/X_Twitter-API-Enterprise/Account Activity API.md'},
   {'source': '../data/platform-docs-versions/X_Twitter-API-Enterprise/Engagement API.md'},
   {'source': '../data/platform-docs-versions/X_Twitter-API-Enterprise/Account Activity API.md'},
   {'source': '../data/platform-docs-versions/X_Twitter-API-Enterprise/Account Activity API.md'}]],
 'embeddings': None,
 'documents': [["The Account Activity API has a number of benefits:\n\n1. **Speed**: we deliver data at the speed of Twitter.\n2. **Simplicity**: w