In [1]:
from google.genai import types

retry_config = types.HttpRetryOptions(
    attempts= 5,
    exp_base = 0.2,
    initial_delay = 0.5,
    http_status_codes = [500, 502, 503, 504],   
)

In [2]:
import os
from dotenv import load_dotenv

load_dotenv()

LLAMA_CLOUD_API_KEY = os.getenv("LLAMA_CLOUD_API_KEY")
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
#OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

print("API keys loaded" if LLAMA_CLOUD_API_KEY and GEMINI_API_KEY else "✗ Missing API keys")

API keys loaded


In [3]:
FILINGS= [
    {"id":"1","company":"Apple","year":"2023","path": r"C:\Users\rushy\Downloads\FINBOT\GenAI_FInBot\NOV_2023.pdf"},
    {"id":"2","company":"Tesla","year":"2023","path": r"C:\Users\rushy\Downloads\FINBOT\GenAI_FInBot\Tesla_2023.pdf"},
]

In [4]:
embedding_models = [
    #{"provider":"OpenAI","model_name":"text-embedding-3-small","api_key_env":"OPENAI_API_KEY","dimensions":1536},
    {"provider":"Google","model_name":"text-embedding-004","api_key_env":"GEMINI_API_KEY","dimensions":768},
    {"provider":"Local","model_name":"all-MiniLM-L6-v2","api_key_env":None,"dimensions":384},
    {"provider":"Local","model_name":"intfloat/e5-large-v2","api_key_env":None,"dimensions":1024},
]


Need to check
1. Latency p90
2. Recall p50
3. Semantic Quality


In [5]:
import torch
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())

2.5.1+cu121
12.1
True


Pipleline Skeleton
1. pages = extract_pages_text(pdf_paths)
2. tables = extract_tables(pdf_paths)
3. docs = build_docs(pages,tables)
4. chunks = chuk_docs(docs)
5. index(chuks,embedding_model_cfg)
6. retrieve(query) -> context
7. generate(context,query)

In [6]:
import fitz
import re
from pathlib import Path

def clean_text(text:str) -> str:
    return re.sub(r"[^a-zA-Z0-9\s]", "", text).strip()


def extract_pages_from_filings(filings):
    all_pages = []

    for filing in filings:
        doc = fitz.open(filing['path'])
        source_name = Path(filing['path']).name

        for page_idx,page in enumerate(doc,start=1):
            text = page.get_text("text")
            all_pages.append({
                "filing_id":filing['id'],
                "company": filing['company'],
                "year":filing['year'],
                "source": source_name,
                "page":page_idx,
                "text":clean_text(text),
            })

        doc.close()

    return all_pages



In [7]:
pages = extract_pages_from_filings(FILINGS)

print(pages[100])
print(len(pages))

{'filing_id': '2', 'company': 'Tesla', 'year': '2023', 'source': 'Tesla_2023.pdf', 'page': 21, 'text': 'Our information technology systems or data or those of our service providers or customers or users could be subject to cyber\nattacks or other security incidents which could result in data breaches intellectual property theft claims litigation regulatory \ninvestigations significant liability reputational damage and other adverse consequences \nWe continue to expand our information technology systems as our operations grow such as product data management procurement inventory \nmanagement production planning and execution sales service and logistics dealer management financial tax and regulatory compliance systems \nThis includes the implementation of new internally developed systems and the deployment of such systems in the US and abroad While we maintain \ninformation technology measures designed to protect us against intellectual property theft data breaches sabotage and other ext

In [8]:
import pdfplumber
import pandas as pd


def extract_tables_from_filings(filings,*,max_tables_per_page=None):
    all_tables = []
    for filing in filings:
        pdf_path = filing['path']
        source_name = Path(pdf_path).name

        with pdfplumber.open(pdf_path) as pdf:
            for page_idx,page in enumerate(pdf.pages,start=1):
                tables = page.extract_tables()

                if not tables:
                    continue

                if max_tables_per_page is not None:
                    tables = tables[:max_tables_per_page]

                for t_idx,table in enumerate(tables):
                    df = pd.DataFrame(table)

                    all_tables.append({
                        "filing_id":filing['id'],
                        "company":filing['company'],
                        "year":filing['year'],
                        "source":source_name,
                        "page":page_idx,
                        "table_id":f"{source_name}_p{page_idx}_t{t_idx}",
                        "df": df
                    })
    return all_tables

In [9]:
tables = extract_tables_from_filings(FILINGS)
len(tables)

151

In [10]:
tables[11]['page']

26

In [11]:
tables[10]['df']

Unnamed: 0,0,1,2,3,4,5,6
0,Products,$,108803.0,,"$ 114,728",$,105126.0
1,Services,60345,,,56054,47710,
2,Total gross margin,$,169148.0,,"$ 170,782",$,152836.0


In [12]:
def table_to_text(df):
    df = df.fillna("").astype(str)
    lines = []
    for row in df.values.tolist():
        row = [cell.strip() for cell in row if cell.strip()]
        lines.append(" | ".join(row))

    return "\n".join(lines)



table_text = table_to_text(tables[11]["df"])
print(table_text)


Products | 36.5 | % | 36.3 | % | 35.3 | %
Services | 70.8 | % | 71.7 | % | 69.7 | %
Total gross margin percentage | 44.1 | % | 43.3 | % | 41.8 | %


visit agian


In [13]:
import re

def extract_year_headers(page_text, max_years=6):
    """Find unique years like 2023, 2022, 2021 in the page text (keeps order)."""
    years = re.findall(r"\b(20\d{2})\b", page_text)
    seen = []
    for y in years:
        if y not in seen:
            seen.append(y)
    return seen[:max_years]

def attach_years_if_possible(df, years):
    """
    If we have years and the row looks like it has 3 numeric values,
    produce 'Label | 2023: x | 2022: y | 2021: z' style lines.
    Otherwise fallback to plain table_to_text().
    """
    df = df.fillna("").astype(str)

    # If no years, just return normal table text
    if not years:
        return table_to_text(df)

    lines = []
    for row in df.values.tolist():
        # keep non-empty cells
        row = [c.strip() for c in row if str(c).strip()]

        if not row:
            continue

        label = row[0]

        # Remove "$" and "%" tokens from values so we can pair cleanly with years
        vals = [c for c in row[1:] if c not in {"$", "%"}]

        # If values count matches number of years, pair them
        if len(vals) >= len(years) and len(years) >= 2:
            pairs = [f"{y}: {vals[i]}" for i, y in enumerate(years)]
            lines.append(label + " | " + " | ".join(pairs))
        else:
            # fallback: keep original row
            lines.append(" | ".join(row))

    return "\n".join(lines)

def build_table_docs(tables, pages):
    """
    tables: output of extract_tables_from_filings
    pages:  output of Step 1 (page-by-page text extraction)
            IMPORTANT: pages must contain dicts with keys: company, year, page, text
    returns: list of table documents (embedding-ready later)
    """
    # Create a quick lookup for page text by (company, year, page)
    page_text_lookup = {(p["company"], p["year"], p["page"]): p["text"] for p in pages}

    table_docs = []
    for t in tables:
        page_text = page_text_lookup.get((t["company"], t["year"], t["page"]), "")
        years = extract_year_headers(page_text)

        table_body = attach_years_if_possible(t["df"], years)

        table_docs.append({
            "id": t["table_id"],
            "type": "table",
            "filing_id": t["filing_id"],
            "company": t["company"],
            "year": t["year"],
            "source": t["source"],
            "page": t["page"],
            "years_detected": years,  # helpful for debugging
            "text": (
                f"[TABLE]\n"
                f"Company: {t['company']}\n"
                f"FilingYear: {t['year']}\n"
                f"Source: {t['source']}\n"
                f"Page: {t['page']}\n"
                f"Years: {', '.join(years) if years else 'N/A'}\n\n"
                f"{table_body}"
            )
        })

    return table_docs


In [14]:
table_docs = build_table_docs(tables, pages)

print("Total table docs:", len(table_docs))

# preview one
print("\n--- Sample Table Doc ---")
print("ID:", table_docs[10]["id"])
print("Meta:", table_docs[10]["company"], table_docs[10]["year"], "page", table_docs[10]["page"])
print(table_docs[10]["text"][:800])


Total table docs: 151

--- Sample Table Doc ---
ID: NOV_2023.pdf_p26_t0
Meta: Apple 2023 page 26
[TABLE]
Company: Apple
FilingYear: 2023
Source: NOV_2023.pdf
Page: 26
Years: 2023, 2022, 2021

Products | 2023: 108,803 | 2022: $ 114,728 | 2021: 105,126
Services | 2023: 60,345 | 2022: 56,054 | 2021: 47,710
Total gross margin | 2023: 169,148 | 2022: $ 170,782 | 2021: 152,836


In [15]:
#Cross page overlap

def add_cross_page_overlap(pages, tail_chars=400):
    """
    For each page, prepend last N chars from previous page text.
    Keeps metadata the same (source/page).
    """
    out = []
    prev_tail = ""
    for p in pages:
        merged_text = (prev_tail + " " + p["text"]).strip()
        prev_tail = p["text"][-tail_chars:] if p["text"] else ""
        out.append({**p, "text": merged_text})
    return out


In [16]:
# chunk page text
def chunk_text(text, chunk_size=1200, overlap=200):
    chunks = []
    start = 0
    n = len(text)

    while start < n:
        end = min(n, start + chunk_size)
        chunks.append(text[start:end])
        if end == n:
            break
        start += (chunk_size - overlap)

    return chunks


In [17]:
# chreate text chunks from pages
def build_text_chunks(pages, chunk_size=1200, overlap=200, tail_chars=400):
    pages_merged = add_cross_page_overlap(pages, tail_chars=tail_chars)

    text_chunks = []
    for p in pages_merged:
        if not p["text"]:
            continue

        splits = chunk_text(p["text"], chunk_size=chunk_size, overlap=overlap)
        for i, s in enumerate(splits):
            text_chunks.append({
                "id": f"{p['source']}_p{p['page']}_text_c{i}",
                "type": "text",
                "company": p["company"],
                "year": p["year"],
                "source": p["source"],
                "page": p["page"],
                "text": s
            })
    return text_chunks


In [18]:
#create table chunks

def build_table_chunks(table_docs, max_chars=3500):
    table_chunks = []
    for t in table_docs:
        txt = t["text"].strip()
        if not txt:
            continue

        # If table text is short enough, keep as one chunk
        if len(txt) <= max_chars:
            table_chunks.append({
                "id": t["id"],
                "type": "table",
                "company": t["company"],
                "year": t["year"],
                "source": t["source"],
                "page": t["page"],
                "text": txt
            })
            continue

        # Else split table by lines into groups
        lines = txt.splitlines()
        buf, part = [], 0
        cur_len = 0
        for line in lines:
            if cur_len + len(line) + 1 > max_chars and buf:
                table_chunks.append({
                    "id": f"{t['id']}_part{part}",
                    "type": "table",
                    "company": t["company"],
                    "year": t["year"],
                    "source": t["source"],
                    "page": t["page"],
                    "text": "\n".join(buf)
                })
                buf, cur_len = [], 0
                part += 1
            buf.append(line)
            cur_len += len(line) + 1

        if buf:
            table_chunks.append({
                "id": f"{t['id']}_part{part}",
                "type": "table",
                "company": t["company"],
                "year": t["year"],
                "source": t["source"],
                "page": t["page"],
                "text": "\n".join(buf)
            })

    return table_chunks


In [19]:
# final chunks

text_chunks = build_text_chunks(pages, chunk_size=1200, overlap=200, tail_chars=400)
table_chunks = build_table_chunks(table_docs, max_chars=3500)

chunks = text_chunks + table_chunks

print("Text chunks:", len(text_chunks))
print("Table chunks:", len(table_chunks))
print("Total chunks:", len(chunks))

print("\nSample text chunk:\n", chunks[0]["id"], "| page", chunks[0]["page"])
print(chunks[0]["text"][:400])

print("\nSample table chunk:\n", table_chunks[0]["id"], "| page", table_chunks[0]["page"])
print(table_chunks[0]["text"][:400])


Text chunks: 1345
Table chunks: 152
Total chunks: 1497

Sample text chunk:
 NOV_2023.pdf_p1_text_c0 | page 1
UNITED STATES
SECURITIES AND EXCHANGE COMMISSION
Washington DC 20549
FORM 10K
Mark One
 ANNUAL REPORT PURSUANT TO SECTION 13 OR 15d OF THE SECURITIES EXCHANGE ACT OF 1934
For the fiscal year ended September 30 2023
or
 TRANSITION REPORT PURSUANT TO SECTION 13 OR 15d OF THE SECURITIES EXCHANGE ACT OF 1934
For the transition period from              to             
Commission File Number 00136743
Ap

Sample table chunk:
 NOV_2023.pdf_p3_t0 | page 3
[TABLE]
Company: Apple
FilingYear: 2023
Source: NOV_2023.pdf
Page: 3
Years: 2023

Item 1. | Business | 1
Item 1A. | Risk Factors | 5
Item 1B. | Unresolved Staff Comments | 16
Item 1C. | Cybersecurity | 16
Item 2. | Properties | 17
Item 3. | Legal Proceedings | 17
Item 4. | Mine Safety Disclosures | 17


In [20]:
import os
import re
import chromadb

chroma_client = chromadb.PersistentClient(path="./chroma_db")  # folder created locally

def safe_name(s: str) -> str:
    s = s.lower()
    s = re.sub(r"[^a-z0-9_]+", "_", s)
    return s[:60]


In [21]:
import torch

_LOCAL_MODEL_CACHE = {}

def _get_local_model(model_name: str):
    if model_name in _LOCAL_MODEL_CACHE:
        return _LOCAL_MODEL_CACHE[model_name]
    from sentence_transformers import SentenceTransformer
    device = "cuda" if torch.cuda.is_available() else "cpu"
    m = SentenceTransformer(model_name, device=device)
    _LOCAL_MODEL_CACHE[model_name] = m
    return m

def _is_e5(model_name: str) -> bool:
    return "e5" in model_name.lower()

def embed_texts(texts, cfg, *, mode="doc"):
    provider = cfg["provider"].lower()

    # ---- Local (GPU) ----
    if provider == "local":
        model = _get_local_model(cfg["model_name"])
        if _is_e5(cfg["model_name"]):
            prefix = "query: " if mode == "query" else "doc: "
            texts = [prefix + t for t in texts]
        vecs = model.encode(texts, batch_size=64, normalize_embeddings=True, show_progress_bar=False)
        return vecs.tolist()


# ---- Google (Gemini) ----
    if provider == "google":
        from google import genai
        api_key = os.getenv(cfg["api_key_env"])
        if not api_key:
            raise ValueError(f"Missing env var: {cfg['api_key_env']}")
        client = genai.Client(api_key=api_key)
        out = []
        for t in texts:
            r = client.models.embed_content(model=cfg["model_name"], contents=t)
            out.append(r.embeddings[0].values)
        return out

    raise ValueError(f"Unknown provider: {cfg['provider']}")





    # ---- OpenAI ----
    '''
    if provider == "openai":
        from openai import OpenAI
        api_key = os.getenv(cfg["api_key_env"])
        if not api_key:
            raise ValueError(f"Missing env var: {cfg['api_key_env']}")
        client = OpenAI(api_key=api_key)
        resp = client.embeddings.create(model=cfg["model_name"], input=texts)
        return [d.embedding for d in resp.data]'''

    
    



In [22]:
def upsert_chunks_to_chroma(chunks, cfg, batch_size=64):
    # Collection name encodes provider+model+dim so no mismatch
    collection_name = safe_name(
        f"{cfg['provider']}_{cfg['model_name']}_{cfg['dimensions']}"
    )

    # Create or get collection (use cosine space)
    collection = chroma_client.get_or_create_collection(
        name=collection_name,
        metadata={"hnsw:space": "cosine"}
    )

    # Add in batches
    for i in range(0, len(chunks), batch_size):
        batch = chunks[i:i+batch_size]
        texts = [c["text"] for c in batch]
        ids = [c["id"] for c in batch]

        # store only JSON-serializable metadata
        metadatas = [{
            "type": c["type"],
            "company": c["company"],
            "year": str(c["year"]),
            "source": c["source"],
            "page": int(c["page"]),
        } for c in batch]

        embs = embed_texts(texts, cfg, mode="doc")

        # sanity check: dimension match
        got_dim = len(embs[0])
        if got_dim != cfg["dimensions"]:
            raise ValueError(
                f"Dim mismatch for {cfg['provider']} {cfg['model_name']}: "
                f"expected {cfg['dimensions']}, got {got_dim}"
            )

        collection.upsert(
            ids=ids,
            documents=texts,
            metadatas=metadatas,
            embeddings=embs
        )

    return collection_name


In [23]:
created = []
for cfg in embedding_models:
    name = upsert_chunks_to_chroma(chunks, cfg, batch_size=64)
    created.append(name)
    print("Indexed:", name)

print("\nAll collections created:")
for n in created:
    print(" -", n)


Indexed: google_text_embedding_004_768


  from .autonotebook import tqdm as notebook_tqdm


Indexed: local_all_minilm_l6_v2_384
Indexed: local_intfloat_e5_large_v2_1024

All collections created:
 - google_text_embedding_004_768
 - local_all_minilm_l6_v2_384
 - local_intfloat_e5_large_v2_1024


In [24]:
def get_collection_for_model(cfg):
    collection_name = safe_name(f"{cfg['provider']}_{cfg['model_name']}_{cfg['dimensions']}")
    return chroma_client.get_collection(name=collection_name)


In [25]:
def retrieve_topk_with_cfg(collection, query, cfg, k=8, where=None):
    q_emb = embed_texts([query], cfg, mode="query")[0]  # correct dims
    res = collection.query(
        query_embeddings=[q_emb],
        n_results=k,
        where=where,
        include=["documents", "metadatas", "distances"]
    )
    hits = []
    for _id, doc, meta, dist in zip(res["ids"][0], res["documents"][0], res["metadatas"][0], res["distances"][0]):
        hits.append({"id": _id, "text": doc, "meta": meta, "distance": dist})
    return hits


In [26]:
import re

def parse_text_chunk_id(chunk_id):
    """
    Extract (source, page, chunk_index) from ids like:
    Tesla_2023.pdf_p32_text_c3
    """
    m = re.match(r"(.+)_p(\d+)_text_c(\d+)$", chunk_id)
    if not m:
        return None
    return (m.group(1), int(m.group(2)), int(m.group(3)))

def expand_neighbors(collection, hits, window=1, max_neighbors=30):
    """
    For each hit, include +/- window neighboring chunks if they exist.
    Only applies to text chunks with _text_cX ids.
    """
    wanted_ids = set(h["id"] for h in hits)

    for h in hits:
        parsed = parse_text_chunk_id(h["id"])
        if not parsed:
            continue
        source, page, cidx = parsed
        for j in range(cidx - window, cidx + window + 1):
            if j < 0:
                continue
            nid = f"{source}_p{page}_text_c{j}"
            wanted_ids.add(nid)
            if len(wanted_ids) >= max_neighbors + len(hits):
                break

    # fetch neighbor docs by ids (get)
    got = collection.get(ids=list(wanted_ids), include=["documents", "metadatas"])
    out = []
    for _id, doc, meta in zip(got["ids"], got["documents"], got["metadatas"]):
        out.append({"id": _id, "text": doc, "meta": meta})
    return out


In [27]:
def dedupe_and_sort(hits):
    seen = {}
    for h in hits:
        seen[h["id"]] = h

    final = list(seen.values())
    final.sort(key=lambda x: (
        x["meta"].get("source", ""),
        int(x["meta"].get("page", 0)),
        x["meta"].get("type", ""),
        x["id"]
    ))
    return final


In [28]:
def build_context_pack(chunks):
    """
    Returns a single context string + a citation list.
    """
    blocks = []
    citations = []
    for c in chunks:
        src = c["meta"].get("source")
        page = c["meta"].get("page")
        ctype = c["meta"].get("type")
        blocks.append(
            f"[{ctype.upper()}] {src} (page {page})\n{c['text']}".strip()
        )
        citations.append({"id": c["id"], "source": src, "page": page, "type": ctype})

    context = "\n\n---\n\n".join(blocks)
    return context, citations


In [29]:
'''
cfg = embedding_models[1]  # Google text-embedding-004 (for example)
collection = get_collection_for_model(cfg)

question = "What was the total gross margin in 2023 vs 2022 vs 2021?"
hits = retrieve_topk(collection, question, k=8)

# expand neighbors (helps continuity)
expanded = expand_neighbors(collection, hits, window=1)

final_chunks = dedupe_and_sort(hits + expanded)
context, citations = build_context_pack(final_chunks)

print("Retrieved chunks:", len(final_chunks))
print("\n--- CONTEXT (first 1200 chars) ---\n")
print(context[:1200])

print("\n--- CITATIONS ---")
print(citations[:10])
'''

'\ncfg = embedding_models[1]  # Google text-embedding-004 (for example)\ncollection = get_collection_for_model(cfg)\n\nquestion = "What was the total gross margin in 2023 vs 2022 vs 2021?"\nhits = retrieve_topk(collection, question, k=8)\n\n# expand neighbors (helps continuity)\nexpanded = expand_neighbors(collection, hits, window=1)\n\nfinal_chunks = dedupe_and_sort(hits + expanded)\ncontext, citations = build_context_pack(final_chunks)\n\nprint("Retrieved chunks:", len(final_chunks))\nprint("\n--- CONTEXT (first 1200 chars) ---\n")\nprint(context[:1200])\n\nprint("\n--- CITATIONS ---")\nprint(citations[:10])\n'

In [30]:
cfg = embedding_models[1]
print(cfg)  # Google text-embedding-004
collection = get_collection_for_model(cfg)

question = "What was the total gross margin in 2023 vs 2022 vs 2021?"
hits = retrieve_topk_with_cfg(collection, question, cfg, k=8, where={"company":"Apple"})  

expanded = expand_neighbors(collection, hits, window=1)

final_chunks = dedupe_and_sort(hits + expanded)
context, citations = build_context_pack(final_chunks)

print("Retrieved chunks:", len(final_chunks))
print("\n--- CONTEXT (first 1200 chars) ---\n")
print(context[:1200])

print("\n--- CITATIONS ---")
print(citations[:10])


{'provider': 'Local', 'model_name': 'all-MiniLM-L6-v2', 'api_key_env': None, 'dimensions': 384}
Retrieved chunks: 11

--- CONTEXT (first 1200 chars) ---

[TEXT] NOV_2023.pdf (page 24)
Company repurchased 766 billion of its common stock and paid dividends and dividend equivalents of 150 billion
Macroeconomic Conditions
Macroeconomic conditions including inflation changes in interest rates and currency fluctuations have directly and indirectly impacted and could in the future
materially impact the Companys results of operations and financial condition
Apple Inc  2023 Form 10K  20 Segment Operating Performance
The following table shows net sales by reportable segment for 2023 2022 and 2021 dollars in millions
2023
Change
2022
Change
2021
Net sales by reportable segment
Americas

162560 
4

169658 
11 

153306 
Europe
94294 
1
95118 
7 
89307 
Greater China
72559 
2
74200 
9 
68366 
Japan
24257 
7
25977 
9
28482 
Rest of Asia Pacific
29615 
1 
29375 
11 
26356 
Total net sales

383285 
3



In [31]:
def build_rag_prompt(question, context):
    return f"""
You are a financial filing assistant.

RULES:
- Use ONLY the provided context.
- If the answer is not in the context, say: "I don't have enough information in the provided context."
- For every numeric claim, include a citation in this format: (Source: <file>, Page: <page>).
- Keep the answer concise and structured.

QUESTION:
{question}

CONTEXT:
{context}
""".strip()


In [47]:
import os
from google import genai

def generate_answer_gemini(question, context, model="gemini-2.0-flash"):
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("Missing GEMINI_API_KEY in environment variables")

    client = genai.Client(api_key=api_key)
    prompt = build_rag_prompt(question, context)

    resp = client.models.generate_content(
        model=model,
        contents=prompt
    )
    return resp.text


In [49]:
answer = generate_answer_gemini(question, context)
print(answer)


ClientError: 429 RESOURCE_EXHAUSTED. {'error': {'code': 429, 'message': 'You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. To monitor your current usage, head to: https://ai.dev/usage?tab=rate-limit. \n* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_requests, limit: 0, model: gemini-2.0-flash\n* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_requests, limit: 0, model: gemini-2.0-flash\n* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_input_token_count, limit: 0, model: gemini-2.0-flash\nPlease retry in 24.643963572s.', 'status': 'RESOURCE_EXHAUSTED', 'details': [{'@type': 'type.googleapis.com/google.rpc.Help', 'links': [{'description': 'Learn more about Gemini API quotas', 'url': 'https://ai.google.dev/gemini-api/docs/rate-limits'}]}, {'@type': 'type.googleapis.com/google.rpc.QuotaFailure', 'violations': [{'quotaMetric': 'generativelanguage.googleapis.com/generate_content_free_tier_requests', 'quotaId': 'GenerateRequestsPerDayPerProjectPerModel-FreeTier', 'quotaDimensions': {'location': 'global', 'model': 'gemini-2.0-flash'}}, {'quotaMetric': 'generativelanguage.googleapis.com/generate_content_free_tier_requests', 'quotaId': 'GenerateRequestsPerMinutePerProjectPerModel-FreeTier', 'quotaDimensions': {'location': 'global', 'model': 'gemini-2.0-flash'}}, {'quotaMetric': 'generativelanguage.googleapis.com/generate_content_free_tier_input_token_count', 'quotaId': 'GenerateContentInputTokensPerModelPerMinute-FreeTier', 'quotaDimensions': {'location': 'global', 'model': 'gemini-2.0-flash'}}]}, {'@type': 'type.googleapis.com/google.rpc.RetryInfo', 'retryDelay': '24s'}]}}

In [34]:
MODEL_RUNS = [
  {"name":"google_768", "cfg": embedding_models[0], "collection":"google_text_embedding_004_768"},
  {"name":"minilm_384", "cfg": embedding_models[1], "collection":"local_all_minilm_l6_v2_384"},
  {"name":"e5_1024",    "cfg": embedding_models[2], "collection":"local_intfloat_e5_large_v2_1024"},
]


In [35]:
EVAL_QUESTIONS = [
  {"id":"q1", "company":"Apple", "question":"What was total gross margin in 2023 vs 2022 vs 2021?"},
  {"id":"q2", "company":"Apple", "question":"What are the main risk factors mentioned about supply chain?"},
  {"id":"q3", "company":"Tesla", "question":"What does Tesla say about revenue recognition?"},
]



In [36]:
import json

def build_rag_json_prompt(question, context):
    schema = {
      "question": "string",
      "answer": "string",
      "citations": [{"source":"string", "page": "number", "type":"text|table"}],
      "support_level": "high|medium|low",
      "missing_info": "string|null"
    }
    return f"""
Return ONLY valid JSON matching this schema:
{json.dumps(schema, indent=2)}

Rules:
- Use ONLY the provided context.
- If not enough info, set support_level="low" and missing_info with what is missing.
- Every numeric claim must have citations with correct source+page.

QUESTION:
{question}

CONTEXT:
{context}
""".strip()


In [51]:
def generate_answer_gemini_json(question, context, model="gemini-2.5-pro"):
  
    import os
    from google import genai

    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("Missing GEMINI_API_KEY")

    client = genai.Client(api_key=api_key)
    prompt = build_rag_json_prompt(question, context)
    resp = client.models.generate_content(model=model, contents=prompt)
    txt = resp.text.strip()

    # best effort JSON parse
    try:
        return json.loads(txt)
    except Exception:
        return {
            "question": question,
            "answer": txt,
            "citations": [],
            "support_level": "low",
            "missing_info": "Model did not return valid JSON"
        }


In [52]:
import time
import numpy as np

def p50(xs): return float(np.percentile(xs, 50)) if xs else None
def p90(xs): return float(np.percentile(xs, 90)) if xs else None


In [53]:
import re

def estimate_citation_coverage(answer_text: str):
    # counts "(Source:" occurrences as citations
    cites = len(re.findall(r"\(Source:", answer_text))
    # rough sentence count
    sents = max(1, len(re.findall(r"[.!?]\s+", answer_text)) + 1)
    return min(1.0, cites / sents)

def support_level_to_score(level: str):
    level = (level or "").lower()
    return {"high": 1.0, "medium": 0.6, "low": 0.2}.get(level, 0.2)

def retrieval_score_from_hits(hits):
    # use best distance only; if missing, fallback
    dists = [h.get("distance") for h in hits if h.get("distance") is not None]
    if not dists:
        return 0.5
    best = min(dists)
    # turn distance into a bounded score (tunable)
    return float(1 / (1 + best))

def confidence_score(hits, gen_json):
    ans = gen_json.get("answer", "")
    cov = estimate_citation_coverage(ans)
    sup = support_level_to_score(gen_json.get("support_level"))

    rscore = retrieval_score_from_hits(hits)

    # weighted composite
    return round(0.45*rscore + 0.35*cov + 0.20*sup, 4)


In [54]:
from collections import deque

class RagChat:
    def __init__(self, max_turns=6):
        self.history = deque(maxlen=max_turns)

    def add_turn(self, q, a):
        self.history.append({"question": q, "answer": a})

    def history_block(self):
        if not self.history:
            return ""
        items = []
        for t in self.history:
            items.append(f"User: {t['question']}\nAssistant: {t['answer']}")
        return "\n\n".join(items)


In [55]:
def build_rag_json_prompt_with_history(question, context, history_text):
    base = build_rag_json_prompt(question, context)
    if history_text.strip():
        return f"HISTORY:\n{history_text}\n\n{base}"
    return base


In [56]:
def run_one_question(model_run, question_obj, *, k=8):
    cfg = model_run["cfg"]
    collection = chroma_client.get_collection(model_run["collection"])

    question = question_obj["question"]
    where = {"company": question_obj["company"]} if "company" in question_obj else None

    t0 = time.perf_counter()
    hits = retrieve_topk_with_cfg(collection, question, cfg, k=k, where=where)
    t1 = time.perf_counter()

    expanded = expand_neighbors(collection, hits, window=1)
    final_chunks = dedupe_and_sort(hits + expanded)
    context, citations = build_context_pack(final_chunks)

    t2 = time.perf_counter()
    gen = generate_answer_gemini_json(question, context)
    t3 = time.perf_counter()

    # compute confidence
    conf = confidence_score(hits, gen)

    return {
        "model": model_run["name"],
        "question_id": question_obj.get("id"),
        "company": question_obj.get("company"),
        "question": question,
        "retrieval_ms": round((t1 - t0) * 1000, 2),
        "generation_ms": round((t3 - t2) * 1000, 2),
        "end_to_end_ms": round((t3 - t0) * 1000, 2),
        "confidence": conf,
        "response_json": gen,
        "top_sources": list({(c["source"], c["page"]) for c in citations})[:6],
    }


In [57]:
import time
import numpy as np
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from google.genai.errors import ClientError

@retry(
    retry=retry_if_exception_type(ClientError),
    wait=wait_exponential(multiplier=1, min=4, max=60),
    stop=stop_after_attempt(5),
    reraise=True
)
def run_one_with_retry(m, q, k=8):
    return run_one_question(m, q, k=k)

def benchmark(models, questions, delay_between=2.0):
    results = []
    total = len(questions) * len(models)
    
    for i, q in enumerate(questions):
        for j, m in enumerate(models):
            idx = i * len(models) + j + 1
            print(f"[{idx}/{total}] {m['name']} | {q['question'][:50]}...")
            
            try:
                r = run_one_with_retry(m, q, k=8)
                results.append(r)
            except ClientError as e:
                print(f"  ⚠ Skipped after retries: {e}")
                continue
            
            # throttle to stay under RPM limits
            time.sleep(delay_between)

    # summarize latency percentiles by model
    summary = {}
    for m in models:
        name = m["name"]
        rows = [r for r in results if r["model"] == name]
        retr = [r["retrieval_ms"] for r in rows]
        gen  = [r["generation_ms"] for r in rows]
        e2e  = [r["end_to_end_ms"] for r in rows]
        conf = [r["confidence"] for r in rows]

        summary[name] = {
            "n": len(rows),
            "confidence_avg": round(float(np.mean(conf)), 4) if conf else None,
            "retrieval_p50_ms": p50(retr),
            "retrieval_p90_ms": p90(retr),
            "generation_p50_ms": p50(gen),
            "generation_p90_ms": p90(gen),
            "e2e_p50_ms": p50(e2e),
            "e2e_p90_ms": p90(e2e),
        }
    return results, summary

results, summary = benchmark(MODEL_RUNS, EVAL_QUESTIONS, delay_between=2.0)
print(summary)

[1/9] google_768 | What was total gross margin in 2023 vs 2022 vs 202...
  ⚠ Skipped after retries: 429 RESOURCE_EXHAUSTED. {'error': {'code': 429, 'message': 'You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. To monitor your current usage, head to: https://ai.dev/usage?tab=rate-limit. \n* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_requests, limit: 0, model: gemini-2.5-pro\n* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_requests, limit: 0, model: gemini-2.5-pro\n* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_input_token_count, limit: 0, model: gemini-2.5-pro\n* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_input_token_count, limit: 0, model: gemini-2.5-pro\nPlease retry in 27.645336177s.', 'stat

KeyboardInterrupt: 