# American Equity Generative AI Platform

## Program Load ✅
* This cell imports every package the program uses.

### Import Packages ✅

In [1]:
import pandas as pd
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
import streamlit as st
from openai import OpenAI
from sentence_transformers import SentenceTransformer
import faiss, numpy as np
import csv, sqlite3, textwrap, atexit
from datetime import datetime
from pathlib import Path
from collections import deque
import re
import json
from copy import deepcopy

### API Key for OpenAI ✅

In [2]:
# The SDK reads the key from the OPENAI_API_KEY environment variable when no
# api_key is passed. Never hardcode keys in a notebook: anyone with the file can
# use them. A key that has been committed should be revoked and rotated.
client = OpenAI()

### SQL Schema ✅
* TODO: move the explanatory comments from inside the code cell into markdown cells.

In [3]:
# This cell builds data/ae_sales.sqlite from the five source CSVs, then adds
# reporting views, a full-text index over interaction comments, and the
# NO_ACCOUNT placeholder policy.

# Input/output locations (relative to the working directory).
DATA_DIR = Path("data")
DB_PATH = "data/ae_sales.sqlite"

# Connection policy. True keeps one shared connection open for the rest of the
# app session; False closes it once the build finishes.
PERSIST_CONN = True
_GLOBAL_CON = None  # shared connection handle, created lazily

def _open_conn():
    """Open a new SQLite connection to DB_PATH with foreign-key checks enabled.

    ``check_same_thread=False`` allows the handle to be used from threads other
    than the one that created it. Callers must avoid concurrent use themselves.
    """
    connection = sqlite3.connect(DB_PATH, check_same_thread=False)
    # SQLite leaves FK enforcement off by default; it is a per-connection setting.
    connection.execute("PRAGMA foreign_keys = ON;")
    return connection

def get_conn():
    """Return the shared SQLite connection, (re)opening it when necessary.

    The first call opens it. Later calls probe it with ``SELECT 1``. If the
    probe fails for any reason (e.g. the handle was closed), a fresh connection
    replaces it. A failed first open is also retried once.
    """
    global _GLOBAL_CON
    try:
        if _GLOBAL_CON is not None:
            _GLOBAL_CON.execute("SELECT 1")  # raises once the handle is closed
        else:
            _GLOBAL_CON = _open_conn()
    except Exception:
        # Broken or closed handle: start over with a brand-new connection.
        _GLOBAL_CON = _open_conn()
    return _GLOBAL_CON

def get_cur():
    """Return a new cursor on the shared connection (reopened by get_conn if needed)."""
    conn = get_conn()
    return conn.cursor()

def _close_conn():
    """Close the shared connection (if any) and forget it.

    The global is cleared even if ``close()`` raises; the error still propagates.
    Calling this when nothing is open is a no-op.
    """
    global _GLOBAL_CON
    handle, _GLOBAL_CON = _GLOBAL_CON, None
    if handle is not None:
        handle.close()

# Make sure the handle is released when the interpreter exits.
atexit.register(_close_conn)

# --- Helpers ---
def parse_date(s):
    """Normalise a loosely formatted date into an ISO ``YYYY-MM-DD`` string.

    Formats are tried in order: ``m/d/yy``, ``m/d/yyyy``, ``yyyy-mm-dd``.
    Surrounding whitespace is ignored. Returns None for None, blank strings,
    the placeholders na / n/a / none / null (any case), or unparseable text.
    """
    if s is None:
        return None
    text = str(s).strip()
    if text == "" or text.lower() in {"na", "n/a", "none", "null"}:
        return None
    for pattern in ("%m/%d/%y", "%m/%d/%Y", "%Y-%m-%d"):
        try:
            parsed = datetime.strptime(text, pattern)
        except ValueError:
            continue
        return parsed.date().isoformat()
    return None

def to_int(s):
    """Coerce *s* to an int, or return None if it cannot be parsed.

    The text is parsed as a float first, so values like ``"12.0"`` or
    ``" 7.9 "`` are accepted. The fractional part is truncated toward zero.
    Returns None for None, blanks, non-numeric text, NaN and infinities.
    """
    try:
        return int(float(str(s).strip()))
    except (TypeError, ValueError, OverflowError):
        # ValueError: non-numeric text or NaN. OverflowError: +/-inf.
        # A bare `except:` would also swallow KeyboardInterrupt/SystemExit.
        return None

def to_float(s):
    """Coerce *s* to a float, or return None if it cannot be parsed.

    Surrounding whitespace is ignored. None, blanks and non-numeric text
    yield None.
    """
    try:
        return float(str(s).strip())
    except (TypeError, ValueError):
        # Narrow catch: a bare `except:` would also swallow KeyboardInterrupt/SystemExit.
        return None

def to_text(s):
    """Return ``str(s)``, passing None through unchanged (stored as SQL NULL)."""
    return str(s) if s is not None else None

# === Bootstrap DB (build/refresh) ===
# Module-level handles: `con` is the shared connection and `cur` is a cursor on
# it. Both stay bound for the rest of the session.
con = get_conn()
cur = get_cur()

# Recreate for a clean start
# sales_pipeline (the table holding the foreign keys) is dropped before the
# tables it references. With PRAGMA foreign_keys = ON, SQLite performs an
# implicit DELETE on a dropped table and would reject it if dependent rows
# still existed. The FTS table is rebuilt further down.
cur.executescript("""
DROP TABLE IF EXISTS sales_pipeline;
DROP TABLE IF EXISTS sales_teams;
DROP TABLE IF EXISTS products;
DROP TABLE IF EXISTS accounts;
DROP TABLE IF EXISTS interactions;
DROP TABLE IF EXISTS interactions_fts;
""")

# --- Schema ---
# One table per source CSV.
# - Dates are stored as ISO-8601 TEXT (values come from parse_date).
# - sales_teams is keyed by sales_agent (the name), which is the column that
#   sales_pipeline's foreign key references.
# - interactions has no declared primary key; its implicit rowid is reused as
#   the rowid of the FTS index built below.
cur.executescript("""
CREATE TABLE accounts (
    account_id         INTEGER PRIMARY KEY,
    account            TEXT,
    sector             TEXT,
    year_established   INTEGER,
    revenue            INTEGER,
    employees          INTEGER,
    office_location    TEXT,
    subsidiary_of      TEXT,
    propensity_to_buy  REAL
);

CREATE TABLE products (
    product_id  INTEGER PRIMARY KEY,
    product     TEXT,
    series      TEXT,
    sales_price INTEGER
);

CREATE TABLE sales_teams (
    sales_person_id INTEGER,
    sales_agent     TEXT PRIMARY KEY,
    manager         TEXT,
    regional_office TEXT
);

CREATE TABLE sales_pipeline (
    opportunity_id TEXT PRIMARY KEY,
    account_id     INTEGER,
    sales_agent    TEXT,
    product_id     INTEGER,
    product        TEXT,
    account        TEXT,
    deal_stage     TEXT,
    engage_date    TEXT,   -- ISO date string
    close_date     TEXT,   -- ISO date string
    close_value    INTEGER,
    FOREIGN KEY (account_id)  REFERENCES accounts(account_id),
    FOREIGN KEY (sales_agent) REFERENCES sales_teams(sales_agent),
    FOREIGN KEY (product_id)  REFERENCES products(product_id)
);

CREATE TABLE interactions (
    account_id     INTEGER,
    account_name   TEXT,
    contact_name   TEXT,
    activity_type  TEXT,
    status         TEXT,
    timestamp      TEXT,   -- ISO date string
    comment        TEXT
);

-- Helpful indexes
CREATE INDEX IF NOT EXISTS idx_sp_account_id ON sales_pipeline(account_id);
CREATE INDEX IF NOT EXISTS idx_sp_product_id ON sales_pipeline(product_id);
CREATE INDEX IF NOT EXISTS idx_sp_deal_stage ON sales_pipeline(deal_stage);
CREATE INDEX IF NOT EXISTS idx_interactions_account_id ON interactions(account_id);
""")

# --- Loaders (CSV -> typed rows -> executemany) ---
def load_accounts():
    """Insert every row of DATA_DIR/accounts.csv into the accounts table.

    Columns are matched by header name. Cells are coerced with
    to_int/to_text/to_float, so malformed values become NULL instead of
    aborting the load.
    """
    path = DATA_DIR / "accounts.csv"
    # "utf-8-sig" strips a leading byte-order mark (common in Excel exports).
    # With plain "utf-8" the first header name would carry a "\ufeff" prefix,
    # and row.get() for that column would return None on every row.
    # BOM-less files decode identically.
    with open(path, newline="", encoding="utf-8-sig") as f:
        r = csv.DictReader(f)
        rows = [(
            to_int(row.get("account_id")),
            to_text(row.get("account")),
            to_text(row.get("sector")),
            to_int(row.get("year_established")),
            to_int(row.get("revenue")),
            to_int(row.get("employees")),
            to_text(row.get("office_location")),
            to_text(row.get("subsidiary_of")),
            to_float(row.get("propensity_to_buy")),
        ) for row in r]
    get_cur().executemany("""
        INSERT INTO accounts(
            account_id, account, sector, year_established, revenue,
            employees, office_location, subsidiary_of, propensity_to_buy
        ) VALUES (?,?,?,?,?,?,?,?,?)
    """, rows)

def load_products():
    """Insert every row of DATA_DIR/products.csv into the products table.

    Columns are matched by header name. Malformed cells become NULL via the
    to_int/to_text coercers.
    """
    path = DATA_DIR / "products.csv"
    # "utf-8-sig" tolerates a leading BOM. Otherwise the first header gets a
    # "\ufeff" prefix and that column silently loads as NULL.
    with open(path, newline="", encoding="utf-8-sig") as f:
        r = csv.DictReader(f)
        rows = [(
            to_int(row.get("product_id")),
            to_text(row.get("product")),
            to_text(row.get("series")),
            to_int(row.get("sales_price")),
        ) for row in r]
    get_cur().executemany("""
        INSERT INTO products(product_id, product, series, sales_price)
        VALUES (?,?,?,?)
    """, rows)

def load_sales_teams():
    """Insert every row of DATA_DIR/sales_teams.csv into the sales_teams table.

    Columns are matched by header name. Malformed cells become NULL via the
    to_int/to_text coercers.
    """
    path = DATA_DIR / "sales_teams.csv"
    # "utf-8-sig" tolerates a leading BOM. Otherwise the first header gets a
    # "\ufeff" prefix and that column silently loads as NULL.
    with open(path, newline="", encoding="utf-8-sig") as f:
        r = csv.DictReader(f)
        rows = [(
            to_int(row.get("sales_person_id")),
            to_text(row.get("sales_agent")),
            to_text(row.get("manager")),
            to_text(row.get("regional_office")),
        ) for row in r]
    get_cur().executemany("""
        INSERT INTO sales_teams(sales_person_id, sales_agent, manager, regional_office)
        VALUES (?,?,?,?)
    """, rows)

def load_sales_pipeline():
    """Insert every row of DATA_DIR/sales_pipeline.csv into sales_pipeline.

    Columns are matched by header name. engage_date/close_date are normalised
    to ISO strings by parse_date, and malformed cells become NULL. Inserts are
    subject to the table's foreign keys, so the parent tables must be loaded first.
    """
    path = DATA_DIR / "sales_pipeline.csv"
    # "utf-8-sig" tolerates a leading BOM. Otherwise the first header gets a
    # "\ufeff" prefix and that column silently loads as NULL.
    with open(path, newline="", encoding="utf-8-sig") as f:
        r = csv.DictReader(f)
        rows = [(
            to_text(row.get("opportunity_id")),
            to_int(row.get("account_id")),
            to_text(row.get("sales_agent")),
            to_int(row.get("product_id")),
            to_text(row.get("product")),
            to_text(row.get("account")),
            to_text(row.get("deal_stage")),
            parse_date(row.get("engage_date")),
            parse_date(row.get("close_date")),
            to_int(row.get("close_value")),
        ) for row in r]
    get_cur().executemany("""
        INSERT INTO sales_pipeline(
            opportunity_id, account_id, sales_agent, product_id, product, account,
            deal_stage, engage_date, close_date, close_value
        ) VALUES (?,?,?,?,?,?,?,?,?,?)
    """, rows)

def load_interactions():
    """Insert every row of DATA_DIR/interactions.csv into the interactions table.

    Columns are matched by header name. The timestamp column goes through
    parse_date, so only the date part of recognised formats is kept and other
    values become NULL.
    """
    path = DATA_DIR / "interactions.csv"
    # "utf-8-sig" tolerates a leading BOM. Otherwise the first header gets a
    # "\ufeff" prefix and that column silently loads as NULL.
    with open(path, newline="", encoding="utf-8-sig") as f:
        r = csv.DictReader(f)
        rows = [(
            to_int(row.get("account_id")),
            to_text(row.get("account_name")),
            to_text(row.get("contact_name")),
            to_text(row.get("activity_type")),
            to_text(row.get("status")),
            parse_date(row.get("timestamp")),
            to_text(row.get("comment")),
        ) for row in r]
    get_cur().executemany("""
        INSERT INTO interactions(
            account_id, account_name, contact_name, activity_type, status, timestamp, comment
        ) VALUES (?,?,?,?,?,?,?)
    """, rows)

# --- Execute loads in FK-friendly order ---
# Parents (accounts, products, sales_teams) go in before sales_pipeline, whose
# foreign keys reference them. With PRAGMA foreign_keys = ON, a pipeline row
# pointing at a missing parent would be rejected. interactions declares no
# foreign keys and can go last.
load_accounts()
load_products()
load_sales_teams()
load_sales_pipeline()
load_interactions()

# --- Views and FTS for fast retrieval ---
# Views are dropped and recreated on every build. The bootstrap drop list only
# covers tables, and CREATE VIEW IF NOT EXISTS alone keeps whatever definition
# is already stored in the database file, so edits below would never apply.
#   v_pipeline_open  - pipeline rows whose deal_stage is Prospecting, Engaging,
#                      Open or Negotiating
#   v_bookings_month - Closed Won close_value summed per close month (YYYY-MM)
# interactions_fts is an FTS5 (porter-stemmed) copy of interactions keyed by
# the same rowid. It is filled once here; nothing in this cell keeps it in sync
# with later changes to interactions.
get_cur().executescript("""
DROP VIEW IF EXISTS v_pipeline_open;
DROP VIEW IF EXISTS v_bookings_month;

CREATE VIEW IF NOT EXISTS v_pipeline_open AS
SELECT *
FROM sales_pipeline
WHERE deal_stage IN ('Prospecting','Engaging','Open','Negotiating');

CREATE VIEW IF NOT EXISTS v_bookings_month AS
SELECT strftime('%Y-%m', close_date) AS yyyymm,
       SUM(close_value) AS bookings
FROM sales_pipeline
WHERE deal_stage = 'Closed Won' AND close_date IS NOT NULL
GROUP BY 1
ORDER BY 1;

CREATE VIRTUAL TABLE IF NOT EXISTS interactions_fts USING fts5(
    comment, account_id, activity_type, status, tokenize='porter'
);

INSERT INTO interactions_fts(rowid, comment, account_id, activity_type, status)
SELECT rowid, comment, account_id, activity_type, status
FROM interactions;
""")

# --- Account placeholder policy: backfill and triggers ---
# Policy:
#   * Engaging/Prospecting opportunities with no account_id are attached to a
#     placeholder account 0 named 'NO_ACCOUNT'.
#   * Any row that has an account_id but no account name gets the name copied
#     from accounts.
# Rows in other deal stages with a NULL account_id are left as-is.
# The script (1) inserts placeholder account 0 if missing, (2) backfills the
# rows already loaded, and (3) installs AFTER INSERT / AFTER UPDATE triggers
# that apply the same two rules to future writes.
# The triggers' own UPDATEs do not re-fire the AFTER UPDATE trigger while
# SQLite's recursive_triggers pragma is off (its default); nothing in this cell
# turns it on.
get_cur().executescript("""
INSERT INTO accounts (account_id, account, sector, year_established, revenue, employees, office_location, subsidiary_of, propensity_to_buy)
SELECT 0, 'NO_ACCOUNT', 'Unknown', NULL, NULL, NULL, NULL, NULL, NULL
WHERE NOT EXISTS (SELECT 1 FROM accounts WHERE account_id = 0);

UPDATE sales_pipeline
SET account_id = 0,
    account    = 'NO_ACCOUNT'
WHERE account_id IS NULL
  AND deal_stage IN ('Engaging','Prospecting');

UPDATE sales_pipeline
SET account = (SELECT a.account FROM accounts a WHERE a.account_id = sales_pipeline.account_id)
WHERE account IS NULL AND account_id IS NOT NULL;

DROP TRIGGER IF EXISTS sp_ai_account_policy;
CREATE TRIGGER sp_ai_account_policy
AFTER INSERT ON sales_pipeline
FOR EACH ROW
BEGIN
  UPDATE sales_pipeline
     SET account_id = 0,
         account    = 'NO_ACCOUNT'
   WHERE opportunity_id = NEW.opportunity_id
     AND NEW.deal_stage IN ('Engaging','Prospecting')
     AND NEW.account_id IS NULL;

  UPDATE sales_pipeline
     SET account = (SELECT a.account FROM accounts a WHERE a.account_id = NEW.account_id)
   WHERE opportunity_id = NEW.opportunity_id
     AND NEW.account IS NULL
     AND NEW.account_id IS NOT NULL;
END;

DROP TRIGGER IF EXISTS sp_au_account_policy;
CREATE TRIGGER sp_au_account_policy
AFTER UPDATE ON sales_pipeline
FOR EACH ROW
BEGIN
  UPDATE sales_pipeline
     SET account_id = 0,
         account    = 'NO_ACCOUNT'
   WHERE opportunity_id = NEW.opportunity_id
     AND NEW.deal_stage IN ('Engaging','Prospecting')
     AND NEW.account_id IS NULL;

  UPDATE sales_pipeline
     SET account = (SELECT a.account FROM accounts a WHERE a.account_id = NEW.account_id)
   WHERE opportunity_id = NEW.opportunity_id
     AND NEW.account IS NULL
     AND NEW.account_id IS NOT NULL;
END;
""")

# --- Quick summary check (optional) ---
# Map each check name to the number of pipeline rows still missing an account
# id / account name after the placeholder policy ran. Values are None when
# sales_pipeline is empty, because SUM over zero rows is NULL.
summary = {
    check_name: missing_count
    for check_name, missing_count in get_cur().execute("""
SELECT 'sp_null_account_id',   SUM(account_id IS NULL)
FROM sales_pipeline
UNION ALL
SELECT 'sp_null_account_text', SUM(account IS NULL)
FROM sales_pipeline
""").fetchall()
}

# Persist or close per policy.
# With PERSIST_CONN the shared connection stays open for later cells
# (sql_context, cache warming, etc.); otherwise it is released right away.
if not PERSIST_CONN:
    _close_conn()


### Load JSON Catalog ✅

In [4]:
SQL_DIR = Path("SQL")
SQL_CATALOG_PATH = Path("data/sql_catalog.json")

# in-memory global cache (None until the first load)
sql_catalog = None

def load_sql_catalog(force_reload=False):
    """
    Load the SQL catalog from JSON, restoring numpy embeddings.

    Each entry is a dict with at least ``name``, ``path`` and ``description``.
    An ``emb`` list, when present, is converted back to a float numpy array.
    If the JSON file does not exist yet, a starter catalog is built from the
    ``*.sql`` files in SQL_DIR (sorted by filename) and written to disk. The
    parent directory is created if needed.

    Args:
        force_reload: When True, ignore anything cached in memory and re-read disk.

    Returns:
        The catalog list, which is also stored in the module-level ``sql_catalog``.
    """
    global sql_catalog

    # If cached and not forced, return current version
    if sql_catalog is not None and not force_reload:
        return sql_catalog

    # Load from JSON on disk (JSON is UTF-8 by spec; don't rely on the locale)
    if SQL_CATALOG_PATH.exists():
        raw = json.loads(SQL_CATALOG_PATH.read_text(encoding="utf-8"))

        # Restore embedding arrays
        for item in raw:
            emb = item.get("emb")
            if isinstance(emb, list):
                item["emb"] = np.array(emb, dtype=float)

        sql_catalog = raw
        return sql_catalog

    # If the JSON doesn't exist, bootstrap a new one
    sql_catalog = [
        {
            "name": p.stem,
            "path": str(p),
            "description": f"TODO: describe what {p.name} does."
        }
        for p in sorted(SQL_DIR.glob("*.sql"))
    ]

    # Save initial file; create e.g. data/ first so a fresh checkout doesn't
    # fail with FileNotFoundError.
    SQL_CATALOG_PATH.parent.mkdir(parents=True, exist_ok=True)
    SQL_CATALOG_PATH.write_text(json.dumps(sql_catalog, indent=2), encoding="utf-8")
    return sql_catalog


### Caching ✅
* Lightweight Caches (Recent Turns, Q&A Cache)
* High-Value Caches (Frequent or FAST Path)

#### Lightweight Caches

In [5]:

# Rolling chat memory: only the newest 8 (role, text) turns are kept.
recent_turns = deque(maxlen=8)
# Semantic Q&A cache; once 500 entries exist the oldest is evicted.
qa_cache = deque(maxlen=500)

def my_embed(text: str, dim: int = 256) -> np.ndarray:
    """Return a deterministic hashed bag-of-words vector for *text*.

    Drop-in placeholder; swap in SentenceTransformer or OpenAI embeddings
    later. Each lower-cased word token is hashed (SHA-1) into one of *dim*
    buckets and counted.

    The previous single-number "embedding" made every cosine similarity equal
    1.0. A one-dimensional positive vector always points the same way, so the
    semantic cache matched every question. With buckets, identical texts still
    score 1.0, while texts sharing no tokens score near 0 (only hash
    collisions add overlap).

    Args:
        text: Text to embed.
        dim: Number of hash buckets (vector length).

    Returns:
        Float array of shape ``(dim,)``; all zeros when *text* has no word tokens.
    """
    import hashlib

    vec = np.zeros(dim, dtype=float)
    for token in re.findall(r"\w+", text.lower()):
        bucket = int(hashlib.sha1(token.encode("utf-8")).hexdigest(), 16) % dim
        vec[bucket] += 1.0
    return vec

def cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two vectors, returned as a Python float.

    Both inputs are cast to float. If either vector has zero length, the
    denominator becomes 1.0, so the result is 0.0 instead of NaN.
    """
    left = a.astype(float)
    right = b.astype(float)
    norm_product = np.linalg.norm(left) * np.linalg.norm(right)
    if not norm_product:
        norm_product = 1.0
    return float(np.dot(left, right) / norm_product)

def find_in_cache(user_text: str, threshold: float = 0.92, embed_fn=my_embed):
    """Return the cached answer whose question best matches *user_text*.

    Embeds *user_text* with *embed_fn* and compares it by cosine similarity
    against every cached question embedding. Returns the best entry's answer
    string when its similarity is >= *threshold*. Returns None when it isn't,
    or when the cache is empty. On ties the earliest cached entry wins.
    """
    if len(qa_cache) == 0:
        return None
    query_vec = embed_fn(user_text)
    best_entry, best_sim = None, None
    for entry in qa_cache:
        sim = cosine_sim(query_vec, entry["q_emb"])
        if best_sim is None or sim > best_sim:
            best_entry, best_sim = entry, sim
    return best_entry["answer"] if best_sim >= threshold else None

def add_to_cache(user_text: str, answer: str, embed_fn=my_embed):
    """Remember *answer* for *user_text* in the semantic Q&A cache.

    Stores the question text, its embedding (from *embed_fn*), the answer and
    a naive-UTC ISO timestamp with seconds precision. qa_cache is a bounded
    deque, so the oldest entry drops off once it is full.
    """
    entry = dict(
        q_text=user_text,
        q_emb=embed_fn(user_text),
        answer=answer,
        ts=datetime.utcnow().isoformat(timespec="seconds"),
    )
    qa_cache.append(entry)

#### High-Value Caches
* These queries are popular or go-to queries that we can help prepopulate.
* Are these similar to the FAST path?
* Do we want to have this be a table in a csv file rather than hard coded? 

In [6]:
def warm_cache_from_sqlite(limit: int = 50):
    """Pre-seed the Q&A cache with two canned answers built from SQLite.

    Adds (1) recent monthly bookings (at most 12 lines) under "What are recent
    bookings by month?" and (2) the latest interactions (at most 15 lines)
    under "What happened recently with interactions?". *limit* caps how many
    rows each query fetches. An answer is skipped when its query returns no rows.
    """
    # Use a fresh cursor from get_cur() instead of the module-level `cur`. That
    # cursor is tied to the connection opened at bootstrap and fails once that
    # connection is closed (PERSIST_CONN = False) or replaced by get_conn().
    cursor = get_cur()

    # Example: common pipeline questions -> prebuild snippets
    rows = cursor.execute("""
        SELECT yyyymm, bookings FROM v_bookings_month ORDER BY yyyymm DESC LIMIT ?
    """, (limit,)).fetchall()
    if rows:
        summary = "Recent bookings by month:\n" + "\n".join(f"- {y}:{v}" for y, v in rows[:12])
        add_to_cache("What are recent bookings by month?", summary)

    # Example: seed recent activity blurbs (use FTS if you like)
    rows2 = cursor.execute("""
        SELECT account_id, activity_type, status, substr(comment,1,160)
        FROM interactions
        ORDER BY date(timestamp) DESC, rowid DESC
        LIMIT ?
    """, (limit,)).fetchall()
    if rows2:
        blurb = "Recent interactions (latest):\n" + "\n".join(
            f"- acct {a} • {t} • {s} • {c or ''}" for a,t,s,c in rows2[:15]
        )
        add_to_cache("What happened recently with interactions?", blurb)

# Call once on app start (or behind a small toggle)
warm_cache_from_sqlite()


  "ts":     datetime.utcnow().isoformat(timespec="seconds"),


## Prescribed SQL Statements ✅
* User Input Cleaned - Rewrites the user's input as a clear request of 25 words or fewer so it can be matched to a canned SQL query.
* Embeddings & Regressions
* Load Canned SQL Statements Database
* Run SQL Statements

### User Input Cleaned ✅
* def clean_question

In [7]:
def clean_question(raw_question: str) -> str:
    """Rewrite a noisy question into a short, precise analytics request.

    Asks the fast model to restate *raw_question* in under 25 words about the
    American Equity sales data. If the model returns no text (for example, an
    incomplete response), the stripped original question is returned. An empty
    string is a useless, and likely rejected, input for the embedding step
    that follows.
    """
    prompt = (
        "Rewrite the user's question as a short, clear analytics request "
        "about the American Equity sales data. Keep it under 25 words.\n\n"
        f"User question: {raw_question}"
    )
    resp = client.responses.create(
        model="gpt-5-nano",
        input=prompt,
    )
    cleaned = (resp.output_text or "").strip()
    return cleaned or raw_question.strip()

### Embeddings & Regressions

#### Empties SQL Catalog Cache ✅

In [8]:
# Re-declare the catalog locations and reset the in-memory catalog to an empty
# list. rank_sql_queries reloads the catalog whenever it is empty.
SQL_DIR, SQL_CATALOG_PATH = Path("SQL"), Path("data/sql_catalog.json")

sql_catalog = list()


#### Loads Catalog ✅

In [9]:
def load_sql_catalog():
    """
    Load the SQL catalog file, restoring embeddings as numpy arrays.
    If it does not exist, bootstrap entries from filenames.

    The file is always re-read from disk (there is no in-memory short-circuit),
    so edits to the JSON are picked up. When the file is missing, one starter
    entry per ``*.sql`` file in SQL_DIR (sorted by name) is created with a TODO
    description and written to SQL_CATALOG_PATH. The parent directory is
    created if needed.

    Returns:
        The catalog list, which is also stored in the module-level ``sql_catalog``.
    """
    global sql_catalog

    if not SQL_CATALOG_PATH.exists():
        # Bootstrap so a first run doesn't die with FileNotFoundError.
        sql_catalog = [
            {
                "name": p.stem,
                "path": str(p),
                "description": f"TODO: describe what {p.name} does.",
            }
            for p in sorted(SQL_DIR.glob("*.sql"))
        ]
        SQL_CATALOG_PATH.parent.mkdir(parents=True, exist_ok=True)
        SQL_CATALOG_PATH.write_text(json.dumps(sql_catalog, indent=2), encoding="utf-8")
        return sql_catalog

    # Load from disk (JSON is UTF-8 by spec; don't depend on the locale)
    raw = json.loads(SQL_CATALOG_PATH.read_text(encoding="utf-8"))

    # restore embeddings from lists → numpy arrays
    for item in raw:
        emb = item.get("emb")
        if isinstance(emb, list):
            item["emb"] = np.array(emb, dtype=float)

    sql_catalog = raw
    return sql_catalog


#### Generates Vectors from User Text into Embedding Model ✅

In [10]:
def embed_text(text: str) -> np.ndarray:
    """Embed a single string with ``text-embedding-3-small``.

    Returns the embedding as a float numpy array. Raises TypeError before
    calling the API if *text* is not a ``str``.
    """
    if isinstance(text, str):
        response = client.embeddings.create(
            model="text-embedding-3-small",
            input=text,  # a single string (not a list) -> exactly one embedding back
        )
        first_vector = response.data[0].embedding
        return np.array(first_vector, dtype=float)
    raise TypeError(f"embed_text expected a string, got {type(text)}")

#### If the JSON catalog is updated, this creates an embedding for every entry that does not have one yet ✅

In [11]:
# Reload the catalog from disk, then embed any entry that has no numpy
# embedding yet (for example, entries newly added to the JSON).
sql_catalog = load_sql_catalog()

for item in sql_catalog:
    if isinstance(item.get("emb"), np.ndarray):
        continue
    item["emb"] = embed_text(item["description"])

# JSON cannot hold numpy arrays: write a deep copy with embeddings converted to
# plain lists, leaving the in-memory catalog's arrays untouched.
_catalog_for_disk = deepcopy(sql_catalog)
for item in _catalog_for_disk:
    if not isinstance(item.get("emb"), np.ndarray):
        continue
    item["emb"] = item["emb"].tolist()

SQL_CATALOG_PATH.write_text(json.dumps(_catalog_for_disk, indent=2))


697373

#### Creates a Similarity Score for each SQL ✅

In [12]:
def rank_sql_queries(cleaned_question: str):
    """Score every catalog entry against *cleaned_question* by cosine similarity.

    Loads the catalog first if the in-memory copy is empty. Every entry must
    carry an ``emb``. Returns a list of ``{"item": entry, "sim": float}`` dicts,
    most similar first; ties keep catalog order.
    """
    if not sql_catalog:
        load_sql_catalog()

    def _unit(vec):
        # Scale to unit length; a zero vector is divided by 1.0 (left as-is).
        return vec / (np.linalg.norm(vec) or 1.0)

    query_unit = _unit(embed_text(cleaned_question))
    results = [
        {
            "item": entry,
            "sim": float(np.dot(query_unit, _unit(np.array(entry["emb"], dtype=float)))),
        }
        for entry in sql_catalog
    ]
    return sorted(results, key=lambda result: result["sim"], reverse=True)

### Select Canned SQL Question ✅

In [13]:
def select_sql_for_question(raw_question: str):
    """Clean *raw_question*, rank the SQL catalog against it, and grade the match.

    Confidence tiers:
      * "high_confidence"   - best similarity >= 0.45 and it leads the
                              runner-up by >= 0.04
      * "medium_confidence" - otherwise, best similarity >= 0.30
      * "low_confidence"    - anything lower
    Returns a dict with the cleaned question, best/second matches, their
    similarities, the gap, the decision and the top five matches. Raises
    IndexError if the catalog is empty.
    """
    cleaned = clean_question(raw_question)
    ranked = rank_sql_queries(cleaned)

    best = ranked[0]
    second = None if len(ranked) <= 1 else ranked[1]

    sim_best = best["sim"]
    sim_second = 0.0 if second is None else second["sim"]
    gap = sim_best - sim_second

    # Suggested thresholds for similarity
    is_high = sim_best >= 0.45 and gap >= 0.04
    if is_high:
        decision = "high_confidence"
    elif sim_best >= 0.30:
        decision = "medium_confidence"
    else:
        decision = "low_confidence"

    return dict(
        raw_question=raw_question,
        cleaned_question=cleaned,
        best=best,
        second=second,
        sim_best=sim_best,
        sim_second=sim_second,
        gap=gap,
        decision=decision,
        all=ranked[:5],
    )


In [14]:
select_sql_for_question("Which parts of our sales organization and product lineup appear to generate most of our recent closed business?")

{'raw_question': 'Which parts of our sales organization and product lineup appear to generate most of our recent closed business?',
 'cleaned_question': 'Analyze American Equity sales data to identify which parts of the sales organization and which product lines generate the most recently closed business.',
 'best': {'item': {'name': 'product_price_range',
   'path': 'SQL/product_price_range.sql',
   'description': 'Show minimum, maximum, and average closed-won deal values for each product.',
   'emb': array([-0.01168524,  0.01348875,  0.00886725, ..., -0.01563041,
           0.0259505 ,  0.01396467])},
  'sim': 0.5011501039428546},
 'second': {'item': {'name': 'sales_pipeline_in-progress',
   'path': 'SQL/sales_pipeline_in-progress.sql',
   'description': 'List all active sales opportunities with stages and expected close dates.',
   'emb': array([-0.02232847,  0.01979916,  0.01882273, ...,  0.00361161,
           0.00200139,  0.00301458])},
  'sim': 0.47554058697666635},
 'sim_best':

### Routing logic ✅

#### High confidence
* sim_best >= 0.45 and gap to second is at least 0.01
* Auto-run the matched SQL

#### Medium confidence
* sim_best >= 0.30
* Suggest top match (and maybe second) to user, or still auto-run if you want to be aggressive

#### Low confidence
* sim_best < 0.30
* Do not trust catalog, fall back to “generate new SELECT” with a bigger model

In [15]:
def route_question(raw_question: str):
    """
    Decide whether to:
      - use a predefined SQL from the catalog
      - or fall back to generating a new SELECT for the full dataset

    Modes, from the top similarity and its lead over the runner-up:
      * "use_catalog_high_conf"   - sim_best >= 0.45 and gap >= 0.01
      * "use_catalog_medium_conf" - otherwise, sim_best >= 0.3
      * "generate_new_sql"        - anything lower
    Raises IndexError if the catalog is empty.
    """
    cleaned = clean_question(raw_question)
    ranked = rank_sql_queries(cleaned)

    best = ranked[0]
    second = ranked[1] if len(ranked) >= 2 else None
    sim_best = best["sim"]
    sim_second = second["sim"] if second is not None else 0.0
    gap = sim_best - sim_second

    mode = (
        "use_catalog_high_conf" if sim_best >= 0.45 and gap >= 0.01
        else "use_catalog_medium_conf" if sim_best >= 0.3
        else "generate_new_sql"
    )

    return dict(
        raw_question=raw_question,
        cleaned_question=cleaned,
        mode=mode,
        best=best,
        second=second,
        sim_best=sim_best,
        sim_second=sim_second,
        gap=gap,
        all=ranked[:5],
    )


### Execute Top Ranked SQL File ✅

In [16]:
def run_sql_from_catalog_item(item) -> pd.DataFrame:
    """Run the SQL file named by ``item["path"]`` and return its rows.

    The file's text is executed as one query on the shared connection from
    get_conn(), and the result is returned as a pandas DataFrame.
    """
    query = Path(item["path"]).read_text()
    return pd.read_sql_query(query, get_conn())

### Canned SQL Statements

### Run SQL Statements

## Bundle Input for Query
* Summarize Recent Turns - Collects the most recent user–assistant exchanges into a concise summary. This helps the model remember context from earlier in the chat without resending the entire conversation.
* Maybe Lookup SQL - Runs lightweight, rule-based SQL queries that extract just the data most relevant to the current question. It anchors the model’s reasoning in actual database values instead of relying solely on text patterns.
* Needs More Data - Applies heuristics to decide when the lightweight context isn’t enough (for example, the user requests detailed statistics or row-level data). It signals whether to escalate to a larger “full dataset” retrieval and the more capable model.
* Build Full Data Context - Generates compact summaries from larger database slices—counts, aggregates, or top records—without overwhelming the token budget. This ensures that heavy analytical questions can still be answered efficiently and accurately.
* Build Bundled Input - Combines everything—chat history, cached responses, SQL context, optional full data summaries, and the user’s question—into a single structured message for the LLM. This unified bundle gives the model the richest, most efficient context possible for consistent, data-grounded answers.
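* Build SQL Commentary Prompt - Builds a prompt asking the LLM to explain a canned SQL query and its results in plain business language for a sales user.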


In [17]:
# Model tiers: "fast" is low latency and cheap; "smart" gives higher quality
# for heavy queries.
MODELS = dict(
    fast="gpt-5-nano",
    smart="gpt-5",
)

def summarize_recent_turns(n: int = 6) -> str:
    """Render the last *n* chat turns as ``ROLE: text`` lines.

    Args:
        n: How many of the most recent (role, content) turns to include.
           Values <= 0 produce an empty string.

    Returns:
        Newline-joined lines (role upper-cased, content stripped), or "" when
        there is no history.
    """
    # Guard n <= 0 explicitly. Slicing with [-0:] returns the WHOLE list, so
    # without this n=0 would return every turn. A negative n would drop turns
    # from the front instead.
    if not recent_turns or n <= 0:
        return ""
    return "\n".join(
        f"{role.upper()}: {content.strip()}"
        for role, content in list(recent_turns)[-n:]
    )

def maybe_lookup_sql(user_text: str) -> str:
    """Return a small block of SQLite context picked by keywords in *user_text*.

    Tiny rule-based mapper; expand later with your flow chart.
      * "booking" / "revenue"      -> the 6 most recent months of bookings
      * "interaction" / "activity" -> the 10 latest interactions
    Matching is case-insensitive. Returns "" when no keyword matches.
    """
    text = user_text.lower()
    ctx  = []
    # Fresh cursor on the live connection (get_conn reopens it if needed). The
    # module-level `cur` from bootstrap fails once that connection is closed.
    cursor = get_cur()
    if "booking" in text or "revenue" in text:
        ctx.append("/* SQL: monthly bookings */")
        rows = cursor.execute("""
            SELECT yyyymm, bookings
            FROM v_bookings_month
            ORDER BY yyyymm DESC LIMIT 6
        """).fetchall()
        if rows:
            ctx.append("\n".join(f"{y}: {b}" for y,b in rows))
    if "interaction" in text or "activity" in text:
        ctx.append("/* SQL: last interactions */")
        rows = cursor.execute("""
            SELECT account_id, activity_type, status, substr(comment,1,120)
            FROM interactions
            ORDER BY date(timestamp) DESC, rowid DESC
            LIMIT 10
        """).fetchall()
        for a,t,s,c in rows:
            ctx.append(f"acct {a} • {t} • {s} • {c or ''}")
    return "\n".join(ctx)

def needs_more_data(user_text: str, sql_ctx: str, cached_ans: str | None) -> bool:
    """
    Heuristic: return True when the light bundle is likely insufficient.
    Tune these rules to your flow chart.
    """
    t = user_text.lower()
    wants_row_level = any(k in t for k in [
        "all rows", "full dataset", "entire dataset", "export", "everything",
        "per account breakdown", "per rep breakdown", "line item", "detail view"
    ])
    wants_stats = any(k in t for k in [
        "distribution", "histogram", "outlier", "correlation", "anova",
        "regression", "forecast", "time series", "per product over time"
    ])
    no_sql_found = not sql_ctx or len(sql_ctx.strip()) < 40
    no_cache = cached_ans is None

    # Trigger if they ask for heavy tasks, or our limited context is empty, or both
    return wants_row_level or wants_stats or (no_sql_found and no_cache)


def build_full_data_context(limit_per_table: int = 2000) -> str:
    """
    Build compact text summaries of the whole database for heavy questions.

    Rather than dumping millions of tokens, this emits up to five tagged
    sections separated by blank lines:
      1. row counts per table (always present)
      2. pipeline count and summed close_value per deal stage
      3. full monthly bookings history
      4. the newest *limit_per_table* interactions
      5. top 50 accounts by Closed Won value
    Sections 2-5 are omitted when their query returns no rows.
    """
    cursor = get_cur()
    sections = []

    def add_section(header, rows, render):
        # Append "header\n<one rendered line per row>" unless the query came back empty.
        if rows:
            sections.append(header + "\n" + "\n".join(render(*row) for row in rows))

    # 1) High level counts
    table_names = ("accounts", "products", "sales_teams", "sales_pipeline", "interactions")
    row_counts = {
        name: cursor.execute(f"SELECT COUNT(*) FROM {name}").fetchone()[0]
        for name in table_names
    }
    sections.append(
        "[full:row_counts]\n" + ", ".join(f"{name}={count}" for name, count in row_counts.items())
    )

    # 2) Pipeline stats by stage
    stage_sql = """
        SELECT deal_stage, COUNT(*) AS n, SUM(COALESCE(close_value,0)) AS value_sum
        FROM sales_pipeline
        GROUP BY 1 ORDER BY n DESC
    """
    add_section(
        "[full:pipeline_by_stage]",
        cursor.execute(stage_sql).fetchall(),
        lambda stage, n, total: f"{stage}: n={n}, sum={total}",
    )

    # 3) Bookings by month, full history
    bookings_sql = """
        SELECT yyyymm, bookings FROM v_bookings_month ORDER BY yyyymm
    """
    add_section(
        "[full:bookings_month_all]",
        cursor.execute(bookings_sql).fetchall(),
        lambda month, booked: f"{month}: {booked}",
    )

    # 4) Recent interactions, bigger window
    interactions_sql = """
        SELECT account_id, activity_type, status, substr(comment,1,160), timestamp
        FROM interactions
        ORDER BY date(timestamp) DESC, rowid DESC
        LIMIT ?
    """
    add_section(
        "[full:interactions_recent]",
        cursor.execute(interactions_sql, (limit_per_table,)).fetchall(),
        lambda acct, kind, status, comment, stamp: (
            f"acct {acct} • {kind} • {status} • {comment or ''} • {stamp or ''}"
        ),
    )

    # 5) Top accounts by closed-won value
    top_accounts_sql = """
        SELECT COALESCE(account,'NO_ACCOUNT') AS account, SUM(COALESCE(close_value,0)) AS val
        FROM sales_pipeline
        WHERE deal_stage='Closed Won'
        GROUP BY 1
        ORDER BY val DESC
        LIMIT 50
    """
    add_section(
        "[full:top_accounts_won]",
        cursor.execute(top_accounts_sql).fetchall(),
        lambda account, value: f"{account}: {value}",
    )

    return "\n\n".join(sections)


def build_bundled_input(user_text: str, use_full: bool = False) -> str:
    """Assemble the single prompt sent to the LLM for *user_text*.

    Sections, each omitted when empty and separated by blank lines:
      * recent chat context
      * a previously cached answer
      * rule-based SQL context
      * expanded full-dataset summaries (only when *use_full* is True)
    The question itself always comes last.
    """
    chat_ctx   = summarize_recent_turns()
    # find_in_cache returns the answer string or None; use it as-is. Indexing
    # with [0] would keep only the answer's first character on a hit and raise
    # TypeError (None[0]) on every cache miss.
    cached_ans = find_in_cache(user_text) if 'find_in_cache' in globals() else None
    sql_ctx    = maybe_lookup_sql(user_text)
    full_ctx   = build_full_data_context() if use_full else ""

    sections = []
    if chat_ctx:
        sections.append("### Recent chat context\n" + chat_ctx)
    if cached_ans:
        sections.append("### Previously answered (for reference)\n" + cached_ans)
    if sql_ctx:
        sections.append("### Data context (SQLite views)\n" + sql_ctx)
    if full_ctx:
        sections.append("### Expanded data context (full dataset summaries)\n" + full_ctx)

    sections.append("### Question\n" + user_text)
    return "\n\n".join(sections)


def build_sql_commentary_prompt(user_question: str,
                                sql_item: dict,
                                df: pd.DataFrame,
                                max_rows: int = 20) -> str:
    """
    Build a prompt that asks the LLM to explain what this canned SQL
    and its results mean for a sales user.

    Args:
        user_question: The user's original question.
        sql_item: Catalog entry. ``name`` defaults to "Unnamed query"; a missing
            or None ``description`` is treated as empty.
        df: Query result (may be None). Only the first *max_rows* rows are shown.
        max_rows: Row cap for the sample table.

    Returns:
        The stripped prompt text.
    """

    query_name = sql_item.get("name", "Unnamed query")
    # `or ""` also covers an explicit None description, which .strip() can't take.
    query_desc = (sql_item.get("description") or "").strip()

    sample = df.head(max_rows) if df is not None else None

    if sample is not None and not sample.empty:
        try:
            sample_markdown = sample.to_markdown(index=False)
        except ImportError:
            # DataFrame.to_markdown needs the optional `tabulate` package;
            # fall back to pandas' built-in plain-text table.
            sample_markdown = sample.to_string(index=False)
        cols = ", ".join(sample.columns)
    else:
        sample_markdown = "No rows returned."
        cols = "no columns"

    prompt = f"""
You are a sales analytics assistant for American Equity.

The user asked this question:
{user_question}

We matched their question to this canned SQL query:

Name: {query_name}
Description: {query_desc}

The query returned a table with columns:
{cols}

Here is a small sample of the results:
{sample_markdown}

Write a short explanation in plain business language for a sales person. Please:
- Explain what this query is doing and what the table represents.
- Call out any key numbers or patterns that stand out.
- Give 2 or 3 concrete insights or talking points they could share with their manager or a wholesaler.
- If the table is empty or very small, explain what that might mean.

Keep it concise, conversational, and non technical.
"""
    return prompt.strip()


def llm_summarize(prompt: str) -> str:
    """Send *prompt* to the "smart" model tier and return its text output.

    Uses MODELS["smart"] so the model choice lives in one place and is not
    duplicated as a literal here.
    """
    response = client.responses.create(
        model=MODELS["smart"],
        input=prompt
    )
    return response.output_text

def ensure_qa_history_table(conn):
    """Create the qa_history log table on *conn* if it is missing, then commit.

    Columns: id (autoincrement), ts, question, answer, route_mode,
    sql_item_name, sql_item_description. Safe to call repeatedly
    (CREATE TABLE IF NOT EXISTS).
    """
    ddl = """
        CREATE TABLE IF NOT EXISTS qa_history (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            ts TEXT,
            question TEXT,
            answer TEXT,
            route_mode TEXT,
            sql_item_name TEXT,
            sql_item_description TEXT
        )
        """
    conn.execute(ddl)
    conn.commit()

def save_qa(conn, question, answer, route_mode, sql_item=None):
    """Append one answered question to qa_history and commit.

    The table is created first if needed. The timestamp is naive UTC in ISO
    format with seconds precision. When *sql_item* is given, its ``name`` and
    ``description`` are recorded; otherwise both columns are NULL.
    """
    ensure_qa_history_table(conn)
    stamp = datetime.utcnow().isoformat(timespec="seconds")
    if sql_item:
        item_name, item_desc = sql_item.get("name"), sql_item.get("description")
    else:
        item_name = item_desc = None
    conn.execute(
        """
        INSERT INTO qa_history (ts, question, answer, route_mode, sql_item_name, sql_item_description)
        VALUES (?, ?, ?, ?, ?, ?)
        """,
        (stamp, question, answer, route_mode, item_name, item_desc),
    )
    conn.commit()

def get_past_answers_for_sql(conn, sql_item_name: str, limit: int = 5):
    """Return recent past Q&A rows that used the same canned SQL.

    Selects qa_history rows whose sql_item_name matches, newest first (ordered
    by the ts text), at most *limit* of them. Each row is a dict with keys
    ``ts``, ``question`` and ``answer``.
    """
    ensure_qa_history_table(conn)
    rows = conn.execute(
        """
        SELECT ts, question, answer
        FROM qa_history
        WHERE sql_item_name = ?
        ORDER BY ts DESC
        LIMIT ?
        """,
        (sql_item_name, limit),
    ).fetchall()
    return [dict(zip(("ts", "question", "answer"), row)) for row in rows]

def build_historical_comparison_prompt(
    user_question: str,
    sql_item: dict,
    current_df: pd.DataFrame | None,
    current_commentary: str | None,
    past_qas: list,
    max_rows: int = 10,
) -> str:

    sample = current_df.head(max_rows) if current_df is not None else None
    if sample is not None and not sample.empty:
        sample_markdown = sample.to_markdown(index=False)
    else:
        sample_markdown = "No rows returned."

    history_block = ""
    for qa in past_qas:
        history_block += f"- [{qa['ts']}] Question: {qa['question']}\n"
        history_block += f"  Answer snapshot:\n{qa['answer']}\n\n"

    prompt = f"""
You are a sales analytics assistant for American Equity.

The user just asked:
{user_question}

This was matched to canned SQL:
{sql_item.get('name')} - {sql_item.get('description', '').strip()}

Current result sample:
{sample_markdown}

Current commentary:
{current_commentary or '(no commentary yet)'}

Here are previous related questions and answers that used this same SQL:
{history_block or '(no prior related answers found.)'}

Please:
1. Compare the current result to the earlier answers. If numbers seem higher, lower, or different in structure, call that out.
2. Give 2 or 3 concise insights that place the current result in context of history.
3. Suggest one follow up question or analysis the user might ask next.

Keep it short and focused on business interpretation.
"""
    return prompt.strip()

## Router
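* `answer_question` returns a cached answer when one exists. Otherwise it sends a bundled prompt to the model, records the exchange in the recent-turn history, and caches the answer.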

In [18]:
def answer_question(user_text: str) -> str:
    """Answer ``user_text`` from the cache when possible, otherwise via the model.

    Flow:
      1. If ``find_in_cache`` returns something truthy, return it immediately
         (no turn is recorded in that case).
      2. Otherwise send ``build_bundled_input(user_text)`` to ``MODEL_ID``
         through the Responses API with ``store=True``.
      3. Append the user/assistant pair to ``recent_turns`` and cache the reply.

    NOTE(review): the chat UI reads the same ``find_in_cache`` result as
    ``cached[0]``, while this function returns it directly as the answer;
    confirm which shape ``find_in_cache`` actually produces.
    """
    hit = find_in_cache(user_text)
    if hit:
        return hit

    payload = build_bundled_input(user_text)
    reply = client.responses.create(
        model=MODEL_ID,
        input=payload,
        store=True,
    ).output_text

    # Record the exchange (user first, then assistant) and remember the reply.
    for role, text in (("user", user_text), ("assistant", reply)):
        recent_turns.append((role, text))
    add_to_cache(user_text, reply)
    return reply



## App UI

In [22]:
# --- Session state setup ------------------------------------------------------
# Streamlit re-runs this whole script on every interaction; st.session_state is
# what persists between runs.
#   messages          -> chat transcript; each item is a dict with "role",
#                        "type", "content" and, for tables, "data"
#   has_shown_welcome -> guards the one-time welcome message below
for _state_key, _state_default in (("messages", []), ("has_shown_welcome", False)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Seed the transcript with a welcome message only once per browser session.
if not st.session_state["has_shown_welcome"]:
    welcome_text = textwrap.dedent(
        """
        Welcome to the American Equity Sales Assistant.

        You can ask natural language questions and the app will:
        - Try to match your question to a curated SQL query from the catalog
        - Run that SQL against the warehouse
        - Summarize the results for you

        A few example questions:
        - "Can you give me a quick summary of activity and revenue by account?"
        - “Who are the customers most likely to buy from us?”
        - "Rank customers by a combination of total revenue and propensity to purchase."
        - “Show me the biggest deals we’ve won.”

        When you are ready, type your question in the chat box below.
        """
    ).strip()

    welcome_message = {"role": "assistant", "type": "text", "content": welcome_text}
    st.session_state["messages"].append(welcome_message)
    st.session_state["has_shown_welcome"] = True

st.title("American Equity Assistant")

# --- 1. Render history --------------------------------------------------------
# Replay the stored transcript so the conversation survives Streamlit reruns.
# Messages without a "type" are treated as plain text; table messages show
# their caption followed by the stored row records as a dataframe.
for past_msg in st.session_state.messages:
    msg_kind = past_msg.get("type", "text")
    with st.chat_message(past_msg["role"]):
        if msg_kind == "text":
            st.markdown(past_msg["content"])
        elif msg_kind == "table":
            st.markdown(past_msg["content"])
            st.dataframe(pd.DataFrame(past_msg["data"]))

def _append_history(role, content, kind="text", data=None):
    """Append one message to the session transcript so it is replayed on rerun."""
    entry = {"role": role, "type": kind, "content": content}
    if data is not None:
        entry["data"] = data
    st.session_state.messages.append(entry)


def _summarize_or_warn(prompt, what):
    """Run ``llm_summarize`` under a spinner; on failure warn and return None.

    Catching here (the UI boundary) keeps the already-rendered SQL result on
    screen and lets the turn still be saved to ``qa_history``. Previously an
    API error aborted the script before ``save_qa`` ran.
    """
    with st.spinner(f"Generating {what}..."):
        try:
            return llm_summarize(prompt)
        except Exception as err:  # UI boundary: surfaced to the user below
            st.warning(f"Could not generate {what} ({type(err).__name__}: {err}).")
            return None


# --- 2. New user input --------------------------------------------------------
user_text = st.chat_input("Ask a question about the sales data...")

if user_text:
    # Show user bubble for this turn and record it in the transcript.
    with st.chat_message("user"):
        st.markdown(user_text)
    _append_history("user", user_text)

    raw_question = user_text

    # --- 3. Build lightweight context ----------------------------------------
    # NOTE(review): chat_ctx, cached_ans and sql_ctx are only consumed by the
    # fallback branch below; if these lookups are expensive, consider moving
    # them into that branch.
    # NOTE(review): `cached[0]` assumes find_in_cache returns a sequence whose
    # first item is the answer, whereas answer_question() returns the same
    # value directly as the answer string; confirm which shape is correct.
    chat_ctx = summarize_recent_turns()
    cached = find_in_cache(user_text)
    cached_ans = cached[0] if cached else None
    sql_ctx = maybe_lookup_sql(user_text)

    # --- 4. Route question ----------------------------------------------------
    route = route_question(raw_question)

    # High or medium confidence match to canned SQL
    if route["mode"] in ("use_catalog_high_conf", "use_catalog_medium_conf"):
        sql_item = route["best"]["item"]
        # Read the name once with .get so a catalog item without one cannot
        # raise KeyError further down (the history lookup already guarded it).
        sql_name = sql_item.get("name")
        df = run_sql_from_catalog_item(sql_item)

        # Grab past Q&A for this canned SQL *before* saving the new one, so the
        # comparison below only sees earlier turns.
        conn = get_conn()
        past_qas = get_past_answers_for_sql(conn, sql_name, limit=5) if sql_name else []

        cleaned_q = route.get("cleaned_question", raw_question)
        header = f"**Question:** {cleaned_q}\n\n"

        medium_note = ""
        if route["mode"] == "use_catalog_medium_conf":
            medium_note = (
                "Your question did not fully match a predefined SQL query. "
                "We think this may be what you were asking.\n\n"
            )

        assistant_text = header + medium_note + f"Running canned SQL: `{sql_name}`"

        # Base SQL result bubble
        with st.chat_message("assistant"):
            st.markdown(assistant_text)
            if df is not None:
                st.dataframe(df)
            else:
                st.info("The query returned no rows.")

        # Save a table snapshot (first 20 rows) into the transcript.
        if df is not None:
            _append_history(
                "assistant",
                assistant_text,
                kind="table",
                data=df.head(20).to_dict("records"),
            )
        else:
            _append_history("assistant", assistant_text + "\n\n(No rows returned.)")

        # --- 4a. LLM commentary on the SQL results ---------------------------
        commentary = None
        if df is not None:
            commentary = _summarize_or_warn(
                build_sql_commentary_prompt(
                    user_question=raw_question,
                    sql_item=sql_item,
                    df=df,
                ),
                "commentary",
            )
            if commentary is not None:
                with st.chat_message("assistant"):
                    st.markdown("**AI commentary on these results**")
                    st.markdown(commentary)
                _append_history(
                    "assistant", "AI commentary on these results:\n\n" + commentary
                )

        # --- 4b. Historical comparison using past Q&A -----------------------
        hist_summary = None
        if df is not None and past_qas:
            hist_summary = _summarize_or_warn(
                build_historical_comparison_prompt(
                    user_question=raw_question,
                    sql_item=sql_item,
                    current_df=df,
                    current_commentary=commentary,
                    past_qas=past_qas,
                ),
                "history comparison",
            )
            if hist_summary is not None:
                with st.chat_message("assistant"):
                    st.markdown("**How this compares to previous answers**")
                    st.markdown(hist_summary)
                _append_history("assistant", "Comparison to history:\n\n" + hist_summary)

        # --- 4c. Save this Q&A into persistent history ----------------------
        # Runs even if an LLM step above failed, so the SQL answer is not lost.
        full_answer_text = assistant_text
        if commentary:
            full_answer_text += "\n\nAI commentary on these results:\n\n" + commentary
        if hist_summary:
            full_answer_text += "\n\nComparison to history:\n\n" + hist_summary

        save_qa(conn, raw_question, full_answer_text, route["mode"], sql_item)

    # Fallback: pure LLM analysis
    else:
        bundle = build_light_bundle(user_text, chat_ctx, sql_ctx, cached_ans)
        model_key = choose_model_key(use_full=False, needs_smart=False)

        answer = None
        with st.spinner("Thinking..."):
            try:
                answer = llm_call(model_key, bundle)
            except Exception as err:  # UI boundary: show the failure instead of a traceback
                st.error(f"Could not generate an answer ({type(err).__name__}: {err}).")

        if answer is not None:
            with st.chat_message("assistant"):
                st.markdown(answer)
            _append_history("assistant", answer)