In [1]:
import duckdb
import os

if (ENV := os.getenv("ENV")) is None:
    raise RuntimeError("ENV must be set")
if (DUCKDB_PATH := os.getenv("DUCKDB_PATH")) is None:
    raise RuntimeError("DUCKDB_PATH must be set")
if (CHROMA_DB_DIR := os.getenv("CHROMA_DB_DIR")) is None:
    raise RuntimeError("CHROMA_DB_PATH must be set")
if (GEMINI_API_KEY := os.getenv("GEMINI_API_KEY")) is None:
    raise RuntimeError("GEMINI_API_KEY must be set")


In [2]:
from typing import Optional
from datetime import date

from pydantic import BaseModel

class RagRecord(BaseModel):
    doc_id: int
    document_title: str
    last_modified: date
    document_text: str
    ticket_id: Optional[str]
    priority: Optional[str]
    department: Optional[str]
    resolution_summary: Optional[str]
    rag_chunk: str

    def to_metadata(self) -> "ChunkMetadata":
        return ChunkMetadata(
            **self.model_dump(include=ChunkMetadata.model_fields.keys())
        )


class ChunkMetadata(BaseModel):
    doc_id: int
    document_title: str
    last_modified: date
    ticket_id: Optional[str]
    priority: Optional[str]
    department: Optional[str]
    resolution_summary: Optional[str]


In [3]:
import google.generativeai as genai


genai.configure(api_key=GEMINI_API_KEY)


def stream_rag_records(db_path: str, batch_size: int=10) -> list[RagRecord]:
    """
    Stream rows from the docs_for_rag table in DuckDB in fixed-size batches.

    The query cursor fetches up to `batch_size` rows at a time using `fetchmany`,
    which avoids loading the entire table into memory. Each batch is converted
    into a list of RagRecord objects and yielded to the caller.

    The `while True` loop terminates naturally when DuckDB returns an empty list,
    signalling that all rows have been consumed.
    """

    sql = f"SELECT * FROM {ENV}_marts.docs_for_rag"
    con = duckdb.connect(db_path, read_only=True)
    result = con.execute(sql)
    col_names = [c[0] for c in result.description]

    while True:
        rows = result.fetchmany(batch_size)
        if not rows:
            break

        yield [RagRecord(**dict(zip(col_names, row))) for row in rows]

    con.close()


def embed_texts(texts: list[str]) -> list[list[float]]:
    """Embed a batch of strings using Gemini embeddings."""
    
    response = genai.embed_content(
        model="gemini-embedding-001",
        content=texts
    )
    
    return response['embedding']


def chunk_text(text: str, max_words=500) -> list[str]:
    """
    Simple word-based chunker. Splits text into chunks of up to `max_words`
    while preserving word boundaries. Suitable for small datasets and demo RAG.
    """
    words = text.split()
    if not words:
        return []

    chunks = []
    for i in range(0, len(words), max_words):
        chunk_words = words[i : i + max_words]
        chunks.append(" ".join(chunk_words))

    return chunks
    

    
        

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import chromadb


client = chromadb.PersistentClient(path=os.getenv("CHROMA_DB_DIR"))


collection = client.get_or_create_collection(
    name="policy_rag",
    metadata={"hnsw:space": "cosine"} 
)


for records in stream_rag_records(DUCKDB_PATH):
    all_chunks = []
    all_ids = []
    all_metadatas = []

    for rec in records:
        chunks = chunk_text(rec.rag_chunk)

        for i, chunk in enumerate(chunks):
            meta = rec.to_metadata().model_dump(mode="json")
            meta["num_chunks"] = len(chunks)
            meta["chunk_index"] = i
            all_chunks.append(chunk)
            all_ids.append(f"{rec.doc_id}_{rec.ticket_id}_{i}")
            all_metadatas.append(meta)
            
    embeddings = embed_texts(all_chunks) 
    
    collection.add(
                ids=all_ids,
                embeddings=embeddings,
                documents=all_chunks,
                metadatas=all_metadatas
    )

    

In [5]:
# sanity check that number of records pulled from database equals number of entries in chroma.
# We'd expect this not to match if there were larger chunks, but in the provided demo data it should be the same

con = duckdb.connect(DUCKDB_PATH, read_only=True)
assert len(con.query(f"SELECT * FROM {ENV}_marts.docs_for_rag").fetchall()) == collection.count()

In [6]:
# Test metadata parsing/upload: Roll up data in mart and collection to check they are the same

from collections import Counter

import duckdb
import pandas as pd


def duckdb_rollup(db_path="/app/database/mydb.duckdb"):
    con = duckdb.connect(db_path, read_only=True)
    df = con.execute(
        f"""
        SELECT department, COUNT(*) AS duckdb_count
        FROM {ENV}_marts.docs_for_rag
        GROUP BY department
        ORDER BY duckdb_count DESC
        """
    ).df()
    con.close()
    return df


def chroma_rollup(collection, batch_size=500):
    dept_counter = Counter()
    offset = 0

    while True:
        res = collection.get(
            include=["metadatas"],
            limit=batch_size,
            offset=offset
        )

        metadatas = res.get("metadatas", [])
        if not metadatas:
            break

        for m in metadatas:
            dept = m.get("department")
            if dept is not None:
                dept_counter[dept] += 1

        offset += batch_size

    df = pd.DataFrame(
        [{"department": d, "chroma_count": c} for d, c in dept_counter.items()]
    ).sort_values("chroma_count", ascending=False)
    return df

from IPython.display import display, HTML


duckdb_df = duckdb_rollup()
chroma_df = chroma_rollup(collection)


html = f"""
<div style="display:flex; gap:40px;">
    <div>{duckdb_df.to_html(index=False)}</div>
    <div>{chroma_df.to_html(index=False)}</div>
</div>
"""

display(HTML(html))

department,duckdb_count
HR,5
Sales,3
Engineering,2
Finance,2
IT,2
Legal,1
Marketing,1
Security,1
Procurement,1
Compliance,1

department,chroma_count
HR,5
Sales,3
Finance,2
Engineering,2
IT,2
Legal,1
Compliance,1
DevOps,1
Procurement,1
Marketing,1


In [7]:
# Search function to test semantic search

def search(query: str, top_k: int = 1, where: dict | None = None) -> list[dict]:
    embedding = embed_texts([query])[0]

    
    result = collection.query(
        query_embeddings=[embedding],
        n_results=top_k,
        where=where
    )

    return {
        "documents": result["documents"][0],
        "metadatas": result["metadatas"][0],
        "distances": result["distances"][0],
        "ids": result["ids"][0],
    }

In [8]:
# Investigate unique values for filtering

con.query(f"select distinct department from {ENV}_marts.docs_for_rag").fetchall()

[('IT',),
 ('Compliance',),
 ('Sales',),
 ('Finance',),
 ('HR',),
 ('Security',),
 ('Engineering',),
 ('Procurement',),
 ('DevOps',),
 ('Legal',),
 ('Marketing',)]

In [9]:
search("hotel policy", 2)

{'documents': ['Policy "Employee Travel & Expense Guidelines" (1002) states: Employees are limited to $150/night for hotel stays. Receipts must be uploaded within 14 days of travel completion. Failure to submit on time will result in a 30-day reimbursement delay. Airfare is strictly Economy class. Related ticket TKT-4522 (MEDIUM, Finance) was resolved as: User queried the $150 hotel limit. Confirmed policy applies to all domestic travel.',
  'Policy "Employee Travel & Expense Guidelines" (1002) states: Employees are limited to $150/night for hotel stays. Receipts must be uploaded within 14 days of travel completion. Failure to submit on time will result in a 30-day reimbursement delay. Airfare is strictly Economy class. Related ticket TKT-4525 (HIGH, Sales) was resolved as: Disputed an Economy class booking for a long-haul flight. Policy confirmed no exceptions.'],
 'metadatas': [{'department': 'Finance',
   'last_modified': '2024-11-01',
   'ticket_id': 'TKT-4522',
   'document_title'

In [10]:
search("how much monthly paid time off do employees get", 1, {"department": "HR"})

{'documents': ['Policy "Paid Time Off (PTO) Accrual" (1007) states: Full-time employees accrue 1.67 days of PTO per month. There is a maximum rollover of 5 days into the next calendar year. Sick time is managed separately and does not count against this balance. Related ticket TKT-4529 (LOW, HR) was resolved as: Employee asked about maximum PTO rollover. Confirmed 5-day limit.'],
 'metadatas': [{'doc_id': 1007,
   'last_modified': '2024-10-01',
   'num_chunks': 1,
   'department': 'HR',
   'priority': 'LOW',
   'document_title': 'Paid Time Off (PTO) Accrual',
   'resolution_summary': 'Employee asked about maximum PTO rollover. Confirmed 5-day limit.',
   'ticket_id': 'TKT-4529',
   'chunk_index': 0}],
 'distances': [0.2757105827331543],
 'ids': ['1007_TKT-4529_0']}