In [5]:
import os
from dotenv import load_dotenv

# Required: load variables from a local .env file (values in the file override
# ones already present in the process environment), then read the key once
# into a Python variable so the check below has something real to test.
load_dotenv(override=True)
# No placeholder default here: a made-up key would pass the check below and
# only fail later, at the first API call. Surrounding whitespace is stripped
# so a blank or whitespace-only value counts as "not set".
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip()

# Paths / names
CHROMA_PATH = "data/chroma"          # your persistent Chroma directory
COLLECTION_NAME = "legal_cases"      # change if you used a different collection name

# Models
EMBED_MODEL = "text-embedding-3-large"  # must match what you used to build the DB
CHAT_MODEL = "gpt-4o-mini"              # any chat-capable model is fine

# Retrieval
TOP_K = 5

# Fail fast with an actionable message if no key is configured.
if not OPENAI_API_KEY:
    raise RuntimeError("Set OPENAI_API_KEY in your environment or in this cell.")
# Export the validated key so later cells can read it from os.environ.
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY



RuntimeError: Set OPENAI_API_KEY in your environment or in this cell.

In [None]:
# OpenAI SDK client, authenticated with the key stored in the environment.
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

# Persistent Chroma client rooted at CHROMA_PATH.
chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)

# Try a plain lookup first. If that fails for any reason, fall back to
# get_or_create_collection, passing the {"hnsw:space": "cosine"} metadata.
try:
    collection = chroma_client.get_collection(name=COLLECTION_NAME)
except Exception:
    collection = chroma_client.get_or_create_collection(
        name=COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"},
    )

print("Connected to Chroma collection:", collection.name)

In [None]:
def retrieve(query: str, top_k: int = TOP_K):
    """
    Fetch the chunks closest to `query` from the Chroma collection.

    The query is embedded with EMBED_MODEL through the OpenAI client and the
    resulting vector is passed to `collection.query`, asking for `top_k`
    results.

    Returns (context_text, citations_list):
      * context_text  - every retrieved chunk prefixed by its
        "[source#chunkN]" tag, joined with "---" divider lines.
      * citations_list - one tag per chunk; "(start-end)" is appended when
        the chunk metadata carries both `start_char` and `end_char`.
    """
    # Step 1: turn the query text into an embedding vector.
    embedding_response = client.embeddings.create(model=EMBED_MODEL, input=query)
    query_vector = embedding_response.data[0].embedding

    # Step 2: similarity search against the collection.
    res = collection.query(
        query_embeddings=[query_vector],
        n_results=top_k,
        include=["documents", "metadatas", "distances"],
    )

    def first_result_list(field):
        # Only one query vector is sent, so index 0 is used; a missing or
        # empty field is treated as an empty list.
        return (res.get(field) or [[]])[0]

    hits = zip(
        first_result_list("documents"),
        first_result_list("metadatas"),
        first_result_list("distances"),
    )

    # Step 3: assemble the tagged context blocks and the citation strings.
    context_parts = []
    citations = []
    for position, (text, meta, _distance) in enumerate(hits):
        meta = meta or {}
        origin = meta.get("source", "unknown")
        # Fall back to the position in the result list when no chunk index is stored.
        chunk_no = meta.get("chunk_index", position)
        label = f"[{origin}#chunk{chunk_no}]"
        context_parts.append(f"{label}\n{text}")

        span_start = meta.get("start_char")
        span_end = meta.get("end_char")
        has_span = span_start is not None and span_end is not None
        citations.append(f"{label}({span_start}-{span_end})" if has_span else label)

    return "\n\n---\n\n".join(context_parts), citations
