<a href="https://colab.research.google.com/github/Renlim61/MVP_Product001_2025_Tier120pbc/blob/version-history/Phase1_RAG_MVP_v15.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

MVP Version 15:
STEP 5.0 – Model Registry (DB + helpers)

STEP 5.1 – ModelConfig & lookup helpers

In [None]:
# ============================================================
# CELL 0 / STEP 0 – Install & Imports
# ============================================================
# Run this once at the top of the notebook (Colab style).

%pip install -q faiss-cpu openai gradio PyPDF2 python-docx

import os
import io
import time
import json
import shutil
import pickle
import sqlite3
from uuid import uuid4
from datetime import datetime, timedelta, timezone
from typing import List, Dict, Any, Tuple, Optional

import numpy as np
import faiss
import gradio as gr

from PyPDF2 import PdfReader
from docx import Document as DocxDocument

# Global datetime import for the whole notebook
from datetime import datetime, timezone, timedelta

# Colab detection
try:
    import google.colab  # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    from google.colab import drive as colab_drive

# OpenAI client (v1 library)
from openai import OpenAI




In [None]:

# ============================================================
# CELL 1 / STEP 1 – Paths, Defaults, and OpenAI Client
# ============================================================

# Root folder for everything this notebook persists (indexes, SQLite DB, logs).
# Colab gets a fixed /content path; local runs use ./rag_mvp under the CWD.
if IN_COLAB:
    BASE_DIR = "/content/rag_mvp"
else:
    BASE_DIR = os.path.join(os.getcwd(), "rag_mvp")
os.makedirs(BASE_DIR, exist_ok=True)

# SQLite database file (documents, cohorts, users, chat history).
DB_PATH = os.path.join(BASE_DIR, "rag_documents.db")

# FAISS index files (<index_id>.faiss) and their pickled metadata (<index_id>.pkl).
INDEX_DIR = os.path.join(BASE_DIR, "indexes")
os.makedirs(INDEX_DIR, exist_ok=True)

# Append-only debug trace written by trace_log() (v15).
TRACE_LOG_PATH = os.path.join(BASE_DIR, "debug_trace_v15.log")


# Sliding-window chunking parameters, counted in whitespace-separated words.
CHUNK_SIZE, CHUNK_OVERLAP = 1000, 200

# Fallback model names used when a model field is left blank.
EMBED_MODEL_DEFAULT = "text-embedding-3-small"
CHAT_MODEL_DEFAULT = "gpt-4.1-mini"

def build_openai_client(api_key: str) -> OpenAI:
    """
    Create a fresh OpenAI (v1 SDK) client bound to ``api_key``.

    Raises:
        ValueError: if ``api_key`` is empty or None.
    """
    if api_key:
        return OpenAI(api_key=api_key)
    raise ValueError("OpenAI API key is required.")

def resolve_models(chat_model: Optional[str], embed_model: Optional[str]) -> Tuple[str, str]:
    """
    Resolve user-selected model names, falling back to the notebook defaults.

    Args:
        chat_model: Chat model name; None, "" or whitespace-only selects
            CHAT_MODEL_DEFAULT.
        embed_model: Embedding model name; None, "" or whitespace-only selects
            EMBED_MODEL_DEFAULT.

    Returns:
        ``(resolved_chat, resolved_embed)``, with surrounding whitespace stripped.
    """
    # `or ""` makes None (e.g. an untouched UI field) behave like a blank value
    # instead of raising AttributeError on `.strip()`.
    resolved_chat = (chat_model or "").strip() or CHAT_MODEL_DEFAULT
    resolved_embed = (embed_model or "").strip() or EMBED_MODEL_DEFAULT
    return resolved_chat, resolved_embed

def ensure_base_dirs():
    """Create BASE_DIR and INDEX_DIR on disk if they are missing (safe to repeat)."""
    for folder in (BASE_DIR, INDEX_DIR):
        os.makedirs(folder, exist_ok=True)




In [None]:


# ============================================================
# CELL 1.5 / STEP 1.5 – Validate OpenAI Key & Models
# ============================================================

def validate_openai_key_and_models(api_key: str, chat_model: str, embed_model: str) -> str:
    """
    Smoke-test an API key against the selected chat and embedding models.

    Sends one tiny chat completion (max 2 tokens) and one single-item
    embedding request. Any ``Exception`` raised along the way is turned into
    a "❌ Error validating key/models: ..." string. On success, returns a ✅
    line naming the resolved models.
    """
    if not api_key:
        return "❌ Please provide an OpenAI API key."

    try:
        client = build_openai_client(api_key)
        chat_name, embed_name = resolve_models(chat_model, embed_model)

        probe_messages = [
            {"role": "system", "content": "Model availability test."},
            {"role": "user", "content": "Respond with 'OK' only."},
        ]
        client.chat.completions.create(
            model=chat_name,
            messages=probe_messages,
            max_tokens=2,
            temperature=0.0,
        )
        client.embeddings.create(model=embed_name, input=["test"])
    except Exception as exc:
        return f"❌ Error validating key/models: {exc}"

    return f"✅ OpenAI key valid. Chat model: `{chat_name}`, Embed model: `{embed_name}`"



In [None]:

# ============================================================
# CELL 2 / STEP 2 – Document Loading & Chunking
# ============================================================

def load_pdf(file_bytes: bytes) -> str:
    """
    Extract the text of an in-memory PDF, joining pages with newlines.

    A page whose extraction raises, or yields nothing, contributes an empty
    string, so the page separators are still present.
    """
    def _page_text(page) -> str:
        try:
            return page.extract_text() or ""
        except Exception:
            return ""

    reader = PdfReader(io.BytesIO(file_bytes))
    return "\n".join(_page_text(pg) for pg in reader.pages)

def load_docx(file_bytes: bytes) -> str:
    """Extract paragraph text from an in-memory .docx file, one paragraph per line."""
    document = DocxDocument(io.BytesIO(file_bytes))
    paragraphs = [para.text for para in document.paragraphs]
    return "\n".join(paragraphs)

def load_txt(file_bytes: bytes, encoding: str = "utf-8") -> str:
    """Decode raw bytes as text; byte sequences invalid in ``encoding`` are dropped."""
    return str(file_bytes, encoding, "ignore")

def load_file_to_text(file_obj) -> Tuple[str, str]:
    """
    Read an uploaded file and return ``(text_content, original_filename)``.

    ``file_obj`` may be:
      - a filesystem path string (what ``gr.File(type="filepath")`` yields), or
      - any object exposing that path as ``.name`` (older Gradio behaviour).

    The loader is chosen by case-insensitive filename extension:
    .pdf, .docx or .txt. The file is read before the extension is checked.

    Raises:
        ValueError: if the object has no ``.name`` or the extension is unsupported.
        OSError: if the path cannot be opened.
    """
    path = file_obj if isinstance(file_obj, str) else getattr(file_obj, "name", None)
    if path is None:
        raise ValueError("Unsupported file object from uploader.")
    name = os.path.basename(path)

    with open(path, "rb") as handle:
        raw = handle.read()

    lowered = name.lower()
    if lowered.endswith(".pdf"):
        return load_pdf(raw), name
    if lowered.endswith(".docx"):
        return load_docx(raw), name
    if lowered.endswith(".txt"):
        return load_txt(raw), name
    raise ValueError(f"Unsupported file type for {name}")


def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """
    Split ``text`` into overlapping windows of whitespace-separated words.

    Each chunk holds up to ``chunk_size`` words, and consecutive chunks share
    ``overlap`` words. Iteration stops as soon as a window reaches the end of
    the text. This avoids a trailing chunk that would be entirely contained
    in the previous one.

    Args:
        text: Raw document text. Any whitespace (spaces, tabs, LF/CRLF/CR
            newlines) separates words. Runs of whitespace collapse to single
            spaces in the output.
        chunk_size: Maximum number of words per chunk; must be > 0.
        overlap: Words shared between neighbouring chunks; must satisfy
            ``0 <= overlap < chunk_size`` so the window always advances.

    Returns:
        List of chunk strings; empty when ``text`` contains no words.

    Raises:
        ValueError: on an invalid ``chunk_size`` / ``overlap`` combination
            (previously an infinite loop when ``overlap >= chunk_size``).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size "
            f"(got overlap={overlap}, chunk_size={chunk_size})"
        )

    # str.split() with no argument already splits on \r, \n and \r\n.
    words = text.split()
    step = chunk_size - overlap
    chunks: List[str] = []
    for start in range(0, len(words), step):
        end = start + chunk_size
        chunks.append(" ".join(words[start:end]))
        if end >= len(words):
            # This window already covers the tail; any later window would be
            # a strict subset of it.
            break
    return chunks




In [None]:

# ============================================================
# CELL 3 / STEP 3 – Embedding & FAISS Index Helpers
# ============================================================

def embed_texts(
    api_key: str,
    embed_model: str,
    texts: List[str],
    batch_size: int = 32,
) -> np.ndarray:
    """
    Embed ``texts`` with the OpenAI embeddings endpoint, ``batch_size`` per request.

    Args:
        api_key: OpenAI API key (validated by ``build_openai_client``).
        embed_model: Embedding model name; blank falls back to the default.
        texts: Strings to embed, in order.
        batch_size: Number of texts per API request; must be > 0.

    Returns:
        float32 array of shape (N, D) whose row i embeds ``texts[i]``.
        For empty ``texts`` the result has shape (0, 0), so it is still 2-D.

    Raises:
        ValueError: if ``batch_size`` is not positive or the API key is empty.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    client = build_openai_client(api_key)
    _, resolved_embed = resolve_models("", embed_model)

    if not texts:
        # np.array([]) would be 1-D and break callers that index shape[1].
        return np.zeros((0, 0), dtype="float32")

    vectors: List[List[float]] = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i : i + batch_size]
        resp = client.embeddings.create(model=resolved_embed, input=batch)
        # Each returned item carries its position in the request. Sorting on it
        # keeps row order aligned with `texts` regardless of response ordering.
        vectors.extend(d.embedding for d in sorted(resp.data, key=lambda d: d.index))

    return np.array(vectors, dtype="float32")

def build_faiss_index(vectors: np.ndarray) -> faiss.IndexFlatIP:
    """
    Build an exact inner-product FAISS index over row-normalised ``vectors``.

    Every row is L2-normalised before insertion, so inner-product scores equal
    cosine similarity. NOTE(review): query vectors presumably need the same
    normalisation at search time — confirm at the search call site.

    Args:
        vectors: Array-like of shape (N, D) with D > 0. It is converted to a
            C-contiguous float32 array, which is the layout faiss requires.

    Returns:
        A ``faiss.IndexFlatIP`` of dimension D holding N vectors.

    Raises:
        ValueError: if ``vectors`` is not 2-D or has zero columns (e.g. the
            (0, 0) array returned for an empty document).
    """
    arr = np.ascontiguousarray(vectors, dtype=np.float32)
    if arr.ndim != 2 or arr.shape[1] == 0:
        raise ValueError(f"vectors must be a 2-D (N, D) array with D > 0; got shape {arr.shape}")

    # The epsilon keeps all-zero rows from dividing by zero.
    norm = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-10
    normed = arr / norm
    index = faiss.IndexFlatIP(normed.shape[1])
    index.add(normed)
    return index

def save_index(index: faiss.IndexFlatIP, index_id: str):
    """Write ``index`` to ``INDEX_DIR/<index_id>.faiss`` via ``faiss.write_index``."""
    faiss.write_index(index, os.path.join(INDEX_DIR, f"{index_id}.faiss"))

def load_index(index_id: str) -> faiss.IndexFlatIP:
    """
    Read back the FAISS index stored under ``index_id``.

    Raises:
        FileNotFoundError: if ``INDEX_DIR/<index_id>.faiss`` does not exist.
    """
    index_path = os.path.join(INDEX_DIR, f"{index_id}.faiss")
    if os.path.exists(index_path):
        return faiss.read_index(index_path)
    raise FileNotFoundError(f"Index file not found: {index_path}")

def save_metadata(index_id: str, meta: Dict[str, Any]):
    """Pickle ``meta`` to ``INDEX_DIR/<index_id>.pkl``, next to the matching FAISS index."""
    meta_path = os.path.join(INDEX_DIR, f"{index_id}.pkl")
    with open(meta_path, "wb") as fh:
        pickle.dump(meta, fh)

def load_metadata(index_id: str) -> Dict[str, Any]:
    """
    Unpickle the metadata stored for ``index_id``.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only load files this
    notebook wrote itself via ``save_metadata``, never user-supplied .pkl files.

    Raises:
        FileNotFoundError: if ``INDEX_DIR/<index_id>.pkl`` is missing.
    """
    meta_path = os.path.join(INDEX_DIR, f"{index_id}.pkl")
    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"Metadata file not found: {meta_path}")
    with open(meta_path, "rb") as fh:
        meta = pickle.load(fh)
    return meta



In [None]:
# ============================================================
# CELL 4 / STEP 4 – SQLite Persistence for Documents & Cohorts
# ============================================================

def get_db_conn():
    """
    Open a new sqlite3 connection to DB_PATH, creating parent folders first.

    The caller owns the returned connection and must close it.
    """
    ensure_base_dirs()
    db_folder = os.path.dirname(DB_PATH)
    os.makedirs(db_folder, exist_ok=True)
    return sqlite3.connect(DB_PATH)


def ensure_docs_table():
    """
    Create the ``documents`` table unless it already exists.

    Each row describes one indexed upload:
      - its cohort;
      - the ``index_id`` that names its .faiss/.pkl files in INDEX_DIR;
      - the chunk count and the embedding model used;
      - a ``created_at`` timestamp.
    """
    documents_ddl = """
        CREATE TABLE IF NOT EXISTS documents (
            id              TEXT PRIMARY KEY,
            doc_name        TEXT NOT NULL,
            cohort_name     TEXT NOT NULL,
            index_id        TEXT NOT NULL,
            n_chunks        INTEGER NOT NULL,
            embed_model     TEXT NOT NULL,
            created_at      TEXT NOT NULL
        )
        """
    conn = get_db_conn()
    conn.cursor().execute(documents_ddl)
    conn.commit()
    conn.close()


def ensure_cohort_table():
    """
    Create both cohort tables if they are missing (safe to call repeatedly).

    Tables:
      - cohorts:     per-cohort metadata rows (schema required by the v15 tests).
      - cohort_docs: cohort_name -> doc_name mapping that the app reads and writes.
    """
    cohorts_ddl = """
        CREATE TABLE IF NOT EXISTS cohorts (
            id              INTEGER PRIMARY KEY AUTOINCREMENT,
            name            TEXT NOT NULL UNIQUE,
            description     TEXT,
            owner_user_id   TEXT,
            created_at      TEXT NOT NULL,
            updated_at      TEXT NOT NULL
        )
        """
    cohort_docs_ddl = """
        CREATE TABLE IF NOT EXISTS cohort_docs (
            cohort_name TEXT NOT NULL,
            doc_name    TEXT NOT NULL
        )
        """

    conn = get_db_conn()
    cur = conn.cursor()
    for ddl in (cohorts_ddl, cohort_docs_ddl):
        cur.execute(ddl)
    conn.commit()
    conn.close()


def list_cohorts() -> List[str]:
    """
    Return every distinct cohort name present in ``cohort_docs``, sorted ascending.

    Only the mapping table is consulted; the ``cohorts`` metadata table is not read.
    """
    ensure_cohort_table()
    conn = get_db_conn()
    cursor = conn.execute("SELECT DISTINCT cohort_name FROM cohort_docs ORDER BY cohort_name ASC")
    names = [name for (name,) in cursor.fetchall()]
    conn.close()
    return names


def list_docs_in_cohort(cohort_name: str) -> List[str]:
    """Return the doc names mapped to ``cohort_name`` in cohort_docs, sorted ascending (duplicates kept)."""
    ensure_cohort_table()
    conn = get_db_conn()
    rows = conn.execute(
        "SELECT doc_name FROM cohort_docs WHERE cohort_name = ? ORDER BY doc_name ASC",
        (cohort_name,),
    ).fetchall()
    conn.close()
    return [doc for (doc,) in rows]


def add_docs_to_cohort(cohort_name: str, doc_names: List[str]):
    """
    Link each of ``doc_names`` to ``cohort_name`` in the ``cohort_docs`` table.

    ``cohort_docs`` has no uniqueness constraint, and ``register_document``
    already inserts a mapping row for each new document. Pairs that already
    exist, or that repeat within ``doc_names``, are therefore skipped. This
    keeps documents from being listed twice.

    The connection is always closed, even if an insert fails.
    """
    ensure_cohort_table()
    conn = get_db_conn()
    try:
        cur = conn.cursor()
        for dn in doc_names:
            cur.execute(
                """
                INSERT INTO cohort_docs (cohort_name, doc_name)
                SELECT ?, ?
                WHERE NOT EXISTS (
                    SELECT 1 FROM cohort_docs WHERE cohort_name = ? AND doc_name = ?
                )
                """,
                (cohort_name, dn, cohort_name, dn),
            )
        conn.commit()
    finally:
        conn.close()


def rename_cohort(old_name: str, new_name: str):
    """
    Rename a cohort everywhere it is referenced by name, in one transaction.

    Updates:
      - cohort_docs.cohort_name
      - documents.cohort_name
      - cohort_meta (owner / sharing flags), when that table exists.

    Without the cohort_meta update, the renamed cohort would have no metadata
    row. ``list_cohorts_for_user`` lists such cohorts for every user, so a
    private cohort would silently become visible to all.

    If ``new_name`` already has a cohort_meta row (cohorts being merged), the
    target's metadata wins and the old row is dropped.

    The connection is always closed. On error nothing is committed.
    """
    ensure_docs_table()
    ensure_cohort_table()
    conn = get_db_conn()
    try:
        cur = conn.cursor()
        cur.execute(
            "UPDATE cohort_docs SET cohort_name = ? WHERE cohort_name = ?",
            (new_name, old_name),
        )
        cur.execute(
            "UPDATE documents SET cohort_name = ? WHERE cohort_name = ?",
            (new_name, old_name),
        )

        # cohort_meta is created lazily by a later cell, so only touch it if present.
        has_meta = cur.execute(
            "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'cohort_meta'"
        ).fetchone()
        if has_meta and new_name != old_name:
            cur.execute(
                "UPDATE OR IGNORE cohort_meta SET cohort_name = ? WHERE cohort_name = ?",
                (new_name, old_name),
            )
            # Drops the old row if the UPDATE was ignored because the target already had one.
            cur.execute("DELETE FROM cohort_meta WHERE cohort_name = ?", (old_name,))

        conn.commit()
    finally:
        conn.close()


def delete_cohort(cohort_name: str, reassign_to: Optional[str] = None) -> str:
    """
    Delete a cohort, or fold it into another cohort.

    Args:
        cohort_name: Cohort to remove.
        reassign_to: If truthy, every document and cohort_docs mapping of
            ``cohort_name`` moves to this cohort, and no index files are
            touched. If falsy, the cohort's documents are deleted from the DB,
            and their ``<index_id>.faiss`` / ``<index_id>.pkl`` files are
            removed from INDEX_DIR.

    Returns:
        A human-readable status message.

    Notes:
        - The reassign/delete choice is explicit to avoid orphaned documents.
        - Index files are removed only AFTER the DB transaction commits. A
          failed transaction therefore never leaves rows pointing at deleted
          files.
        - When a ``cohort_meta`` table exists, the cohort's ownership/sharing
          row is moved (reassign) or dropped (delete). A stale row would
          otherwise be inherited by any future cohort reusing the name, because
          ``set_cohort_owner`` only replaces ``owner_user_id`` on conflict.
    """
    ensure_docs_table()
    ensure_cohort_table()

    index_ids_to_remove: List[str] = []
    conn = get_db_conn()
    try:
        cur = conn.cursor()

        if reassign_to:
            # Just move docs
            cur.execute(
                "UPDATE documents SET cohort_name = ? WHERE cohort_name = ?",
                (reassign_to, cohort_name),
            )
            cur.execute(
                "UPDATE cohort_docs SET cohort_name = ? WHERE cohort_name = ?",
                (reassign_to, cohort_name),
            )
            msg = f"✅ Cohort '{cohort_name}' renamed/reassigned to '{reassign_to}'. No indexes deleted."
        else:
            cur.execute(
                "SELECT index_id FROM documents WHERE cohort_name = ?",
                (cohort_name,),
            )
            index_ids_to_remove = [row[0] for row in cur.fetchall()]
            cur.execute("DELETE FROM documents WHERE cohort_name = ?", (cohort_name,))
            cur.execute("DELETE FROM cohort_docs WHERE cohort_name = ?", (cohort_name,))
            msg = f"✅ Cohort '{cohort_name}' and its documents/indexes were deleted."

        has_meta = cur.execute(
            "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'cohort_meta'"
        ).fetchone()
        # Reassigning a cohort onto itself is a no-op; its metadata must survive.
        if has_meta and reassign_to != cohort_name:
            if reassign_to:
                # The target keeps its own metadata if it already has a row.
                cur.execute(
                    "UPDATE OR IGNORE cohort_meta SET cohort_name = ? WHERE cohort_name = ?",
                    (reassign_to, cohort_name),
                )
            cur.execute("DELETE FROM cohort_meta WHERE cohort_name = ?", (cohort_name,))

        conn.commit()
    finally:
        conn.close()

    # Filesystem cleanup only after the DB changes are durable.
    for index_id in index_ids_to_remove:
        for ext in ("faiss", "pkl"):
            path = os.path.join(INDEX_DIR, f"{index_id}.{ext}")
            if os.path.exists(path):
                os.remove(path)

    return msg


def register_document(
    doc_name: str,
    cohort_name: str,
    index_id: str,
    n_chunks: int,
    embed_model: str,
):
    """
    Record a newly indexed document and link it to its cohort.

    Inserts one ``documents`` row and one ``cohort_docs`` mapping row, and
    commits them together. The ``documents`` row gets a fresh UUID4 id and a
    UTC ISO-8601 ``created_at``.

    Returns:
        The generated document id (str).
    """
    # Local import: this cell runs before `dt` is bound (cell 4.5). A later
    # `import datetime` in the self-test cell would also shadow the top-level
    # `datetime` class, so neither global name is safe to rely on here.
    from datetime import datetime as _datetime, timezone as _timezone

    ensure_docs_table()
    ensure_cohort_table()
    doc_id = str(uuid4())
    created_at = _datetime.now(_timezone.utc).isoformat()

    conn = get_db_conn()
    try:
        cur = conn.cursor()
        cur.execute(
            """
            INSERT INTO documents (id, doc_name, cohort_name, index_id, n_chunks,
                                   embed_model, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            (doc_id, doc_name, cohort_name, index_id, n_chunks, embed_model, created_at),
        )
        cur.execute(
            "INSERT INTO cohort_docs (cohort_name, doc_name) VALUES (?, ?)",
            (cohort_name, doc_name),
        )
        conn.commit()
    finally:
        conn.close()
    return doc_id


def get_doc_index_id(doc_name: str, cohort_name: str) -> Optional[str]:
    """
    Look up the FAISS ``index_id`` for ``doc_name`` within ``cohort_name``.

    ``documents`` does not enforce unique (doc_name, cohort_name) pairs: each
    ``register_document`` call inserts a new row with a new id. When several
    rows match, the most recently registered one (latest ``created_at``) is
    returned, rather than an arbitrary one.

    Returns:
        The index id, or None if no such document is registered.
    """
    ensure_docs_table()
    conn = get_db_conn()
    try:
        row = conn.execute(
            """
            SELECT index_id
            FROM documents
            WHERE doc_name = ? AND cohort_name = ?
            ORDER BY created_at DESC
            LIMIT 1
            """,
            (doc_name, cohort_name),
        ).fetchone()
    finally:
        conn.close()
    return row[0] if row else None


def list_all_documents() -> List[Tuple[str, str, str]]:
    """List every registered document as ``(doc_name, cohort_name, created_at)``, newest first."""
    ensure_docs_table()
    conn = get_db_conn()
    cursor = conn.execute(
        "SELECT doc_name, cohort_name, created_at FROM documents ORDER BY created_at DESC"
    )
    documents = cursor.fetchall()
    conn.close()
    return documents
# ------------------------------------------------------------
# SCHEMA BOOTSTRAP – create the core tables as soon as this cell
# runs, so later cells and the v15 tests can assume they exist.
# ------------------------------------------------------------
for _create_schema in (ensure_docs_table, ensure_cohort_table):
    _create_schema()
del _create_schema


In [None]:
# ============================================================
# CELL 4.5 / STEP 4.5 – Users & Chat History (7-Day Retention)
# ============================================================

from dataclasses import dataclass
import datetime as dt

# ============================================================
# USER IDENTITY MODEL
# ============================================================

@dataclass
class SessionUser:
    username: str | None = None
    role: str = "anonymous"   # "anonymous", "user", "admin"

    @property
    def is_authenticated(self) -> bool:
        return self.username is not None

    @property
    def is_admin(self) -> bool:
        return self.role == "admin"


# MVP in-memory auth store (will be replaced later by ICAM/SSO).
# SECURITY: the fallback passwords below are demo-only defaults. Set
# RAG_ADMIN_PASSWORD / RAG_DEMO_PASSWORD in the environment to override them
# without editing (or committing) the notebook.
USERS = {
    "admin": {"password": os.environ.get("RAG_ADMIN_PASSWORD", "admin123"), "role": "admin"},
    "demo":  {"password": os.environ.get("RAG_DEMO_PASSWORD", "demo123"),  "role": "user"},
}


# ============================================================
# USERS TABLE
# ============================================================

def ensure_user_table():
    """Create the ``users`` table (user_id, display_name, role, created_at) if it is missing."""
    users_ddl = """
        CREATE TABLE IF NOT EXISTS users (
            user_id      TEXT PRIMARY KEY,
            display_name TEXT,
            role         TEXT,
            created_at   TEXT NOT NULL
        )
        """
    conn = get_db_conn()
    conn.cursor().execute(users_ddl)
    conn.commit()
    conn.close()


def upsert_user(user_id: str, role: str, display_name: Optional[str] = None):
    """
    Insert a user record, or refresh an existing one (ICAM-ready).

    On conflict, ``display_name`` and ``role`` are overwritten only with
    non-NULL values. Passing ``display_name=None`` therefore keeps the stored
    name. The original ``created_at`` is never changed.
    Empty/None ``user_id`` makes this a no-op.
    """
    if not user_id:
        return

    ensure_user_table()
    conn = get_db_conn()
    cur = conn.cursor()
    created = dt.datetime.now(dt.timezone.utc).isoformat()
    # Positional params: the four INSERT values, then the two COALESCE overrides.
    params = (user_id, display_name, role, created) + (display_name, role)

    cur.execute(
        """
        INSERT INTO users (user_id, display_name, role, created_at)
        VALUES (?, ?, ?, ?)
        ON CONFLICT(user_id) DO UPDATE SET
            display_name = COALESCE(?, users.display_name),
            role = COALESCE(?, users.role)
        """,
        params,
    )
    conn.commit()
    conn.close()


# ============================================================
# CHAT HISTORY TABLE
# ============================================================

def ensure_chat_history_table():
    """Create the ``chat_history`` table (one row per logged message/turn) if it is missing."""
    chat_history_ddl = """
        CREATE TABLE IF NOT EXISTS chat_history (
            id             INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id        TEXT,
            role           TEXT,
            cohort_name    TEXT,
            original_query TEXT,
            improved_query TEXT,
            which_prompt   TEXT,
            answer         TEXT,
            chat_model     TEXT,
            created_at     TEXT NOT NULL
        )
        """
    conn = get_db_conn()
    conn.cursor().execute(chat_history_ddl)
    conn.commit()
    conn.close()


# ============================================================
# 7-DAY RETENTION
# ============================================================

def prune_chat_history(days: int = 7):
    """
    Enforce the retention window: delete chat_history rows older than ``days`` days.

    The cutoff is compared as an ISO-8601 string. This relies on rows being
    written with ``dt.datetime.now(dt.timezone.utc).isoformat()``, which is
    what the writers in this cell do.
    """
    ensure_chat_history_table()
    oldest_allowed = (dt.datetime.now(dt.timezone.utc) - dt.timedelta(days=days)).isoformat()

    conn = get_db_conn()
    conn.cursor().execute("DELETE FROM chat_history WHERE created_at < ?", (oldest_allowed,))
    conn.commit()
    conn.close()


# ============================================================
# v14 BEHAVIOR: Save Chat Interaction (DETAILED LOGGING)
# ============================================================

def save_chat_interaction(
    user_id: str,
    role: str,
    cohort_name: str,
    original_query: str,
    improved_query: str,
    which_prompt: str,
    answer: str,
    chat_model: str,
):
    """
    Log one fully detailed chat turn (the v14 logging format, kept for v15).

    Runs the 7-day retention prune first, then inserts a single chat_history
    row stamped with the current UTC time. Falsy ``user_id``, ``role`` and
    ``cohort_name`` values are stored as NULL. All other fields are stored as
    given.
    """
    ensure_chat_history_table()
    prune_chat_history(days=7)

    conn = get_db_conn()
    cur = conn.cursor()
    row = (
        user_id or None,
        role or None,
        cohort_name or None,
        original_query,
        improved_query,
        which_prompt,
        answer,
        chat_model,
        dt.datetime.now(dt.timezone.utc).isoformat(),
    )
    cur.execute(
        """
        INSERT INTO chat_history (
            user_id, role, cohort_name, original_query, improved_query,
            which_prompt, answer, chat_model, created_at
        )
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        row,
    )
    conn.commit()
    conn.close()


# ============================================================
# v15 REQUIRED FUNCTIONS – FLEXIBLE SIGNATURES
# ============================================================

def _extract_user_id(user: Any) -> str:
    """
    Derive a stable user_id string from the various ways a "user" is passed around.

    The checks run in order, and the first truthy value found is returned as ``str``:
      1. None                      -> "anonymous"
      2. str                       -> the string itself ("" -> "anonymous")
      3. dict                      -> keys "username", "user_id", "name"
      4. any object's attributes   -> same names (e.g. SessionUser.username)
      5. mapping-like objects      -> same keys via ``obj[key]`` (e.g. sqlite3.Row)
    Falls back to "anonymous" when nothing usable is found.
    """
    if user is None:
        return "anonymous"

    # A bare string already is an identifier. Without this branch the mapping
    # probe below would run substring checks (`"name" in "bob"`) and then fail
    # on str indexing, silently mapping every string user to "anonymous".
    if isinstance(user, str):
        return user or "anonymous"

    candidate_keys = ("username", "user_id", "name")

    # dict-like
    if isinstance(user, dict):
        for key in candidate_keys:
            v = user.get(key)
            if v:
                return str(v)

    # attribute-based (SessionUser or other objects)
    for attr in candidate_keys:
        try:
            v = getattr(user, attr, None)
            if v:
                return str(v)
        except Exception:
            pass

    # mapping-like (sqlite3.Row, Mapping subclasses, ...). Membership is tested
    # against .keys() when available, because sqlite3.Row's `in` operator
    # iterates *values*, not column names.
    try:
        keys_fn = getattr(user, "keys", None)
        available = list(keys_fn()) if callable(keys_fn) else None
        for key in candidate_keys:
            present = (key in available) if available is not None else (key in user)
            if present:
                v = user[key]
                if v:
                    return str(v)
    except Exception:
        pass

    return "anonymous"


def _insert_chat_rows(rows: List[Tuple[Any, ...]]) -> None:
    """
    Insert complete chat_history rows in one transaction.

    Each tuple supplies, in order: user_id, role, cohort_name, original_query,
    improved_query, which_prompt, answer, chat_model, created_at.
    The connection is always closed, even if an insert fails.
    """
    conn = get_db_conn()
    try:
        conn.executemany(
            """
            INSERT INTO chat_history (
                user_id, role, cohort_name, original_query,
                improved_query, which_prompt, answer,
                chat_model, created_at
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            rows,
        )
        conn.commit()
    finally:
        conn.close()


def save_chat_history(
    messages=None,
    *args,
    user: Any = None,
    user_id: Optional[str] = None,
    cohort_name: str = "default",
    cohort: Optional[str] = None,
    **kwargs,
):
    """
    Persist chat history. The flexible signature is required by the v15 test suite.

    Two call styles are supported:

      A) Messages list, stored as one row per message:
             save_chat_history(messages=[{"role": ..., "content": ...}, ...],
                               user=..., cohort=...)
         Each message's "role" (default "user") and "content" (default "")
         go into the ``role`` / ``original_query`` columns.

      B) Single Q/A pair, used only when ``messages`` is None and
         ``question`` or ``answer`` is given:
             save_chat_history(user=..., cohort=..., question="Q",
                               answer="A", model_used="...")
         Stored as one row with role "user".

    Args:
        messages: Iterable of message dicts (style A), or None.
        *args: Accepted for compatibility and ignored.
        user: SessionUser / dict / other user object. Used to derive
            ``user_id`` when that is not given explicitly.
        user_id: Explicit user id; takes precedence over ``user``.
        cohort_name: Cohort to file the rows under.
        cohort: Alias for ``cohort_name``; wins when not None.
        **kwargs: ``question``, ``answer`` and ``model_used`` for style B.
            Other keys are ignored.

    Calls with neither messages nor question/answer write nothing. The 7-day
    retention prune still runs first.
    """
    ensure_chat_history_table()
    prune_chat_history(days=7)

    if cohort is not None:
        cohort_name = cohort
    if user_id is None:
        user_id = _extract_user_id(user)

    question = kwargs.get("question")
    answer = kwargs.get("answer")
    model_used = kwargs.get("model_used")
    now = dt.datetime.now(dt.timezone.utc).isoformat()

    # ---- Style B: single Q/A row ----
    if messages is None and (question is not None or answer is not None):
        _insert_chat_rows([
            (user_id, "user", cohort_name, question or "", None, None, answer or "", model_used, now)
        ])
        return

    # ---- Style A: messages list ----
    # `messages` can never arrive via **kwargs: a `messages=` keyword always
    # binds to the named parameter, so there is no kwargs fallback to check.
    if messages is None:
        return

    # Build every row before opening a connection, so a malformed message
    # fails fast without touching the database.
    rows = [
        (user_id, m.get("role", "user"), cohort_name, m.get("content", ""), None, None, None, None, now)
        for m in messages
    ]
    _insert_chat_rows(rows)



def load_chat_history(
    *args,
    **kwargs,
):
    """
    Return up to ``limit`` chat_history rows, oldest first (v15 test suite contract).

    ``limit`` is read from kwargs (default 50). Each item is a dict with keys
    "role", "content" (the stored ``original_query``) and "created_at".

    All positional args and any user/cohort keyword filters are accepted but
    ignored, for compatibility with the tests.
    NOTE(review): because filters are ignored, this returns every user's
    messages. Do not expose it directly to non-admin users.
    """
    ensure_chat_history_table()

    max_rows = int(kwargs.get("limit", 50))

    conn = get_db_conn()
    records = conn.execute(
        """
        SELECT role, original_query, created_at
        FROM chat_history
        ORDER BY created_at ASC
        LIMIT ?
        """,
        (max_rows,),
    ).fetchall()
    conn.close()

    field_names = ("role", "content", "created_at")
    return [dict(zip(field_names, rec)) for rec in records]



# ============================================================
# RECENT HISTORY (Admin screens / Debug UI)
# ============================================================

def get_recent_history(
    user_id: Optional[str] = None,
    cohort_name: Optional[str] = None,
    limit: int = 50,
) -> List[Dict[str, Any]]:
    """
    Fetch the newest chat_history rows, with every column, for admin/debug views.

    Args:
        user_id: If truthy, only rows for this user are returned.
        cohort_name: If truthy, only rows for this cohort are returned.
        limit: Maximum number of rows.

    Returns:
        Dicts keyed by column name, ordered by ``created_at`` descending.
    """
    ensure_chat_history_table()

    conn = get_db_conn()
    cur = conn.cursor()

    query = """
        SELECT user_id, role, cohort_name, original_query, improved_query,
               which_prompt, answer, chat_model, created_at
        FROM chat_history
    """
    # Only fixed column names are interpolated; values always go through `?`.
    active_filters = [
        (column, value)
        for column, value in (("user_id", user_id), ("cohort_name", cohort_name))
        if value
    ]
    if active_filters:
        query += " WHERE " + " AND ".join(f"{column} = ?" for column, _ in active_filters)
    query += " ORDER BY created_at DESC LIMIT ?"
    params = [value for _, value in active_filters] + [limit]

    cur.execute(query, params)
    rows = cur.fetchall()
    conn.close()

    columns = (
        "user_id", "role", "cohort_name", "original_query", "improved_query",
        "which_prompt", "answer", "chat_model", "created_at",
    )
    return [dict(zip(columns, row)) for row in rows]


# ============================================================
# ADMIN: LIST USERS
# ============================================================

def list_users() -> List[Tuple[str, str, str]]:
    """Return ``(user_id, display_name, role)`` for every known user, newest first."""
    ensure_user_table()
    conn = get_db_conn()
    users = conn.execute(
        "SELECT user_id, display_name, role FROM users ORDER BY created_at DESC"
    ).fetchall()
    conn.close()
    return users


In [None]:
# ============================================================
# CELL 4.6 / STEP 4.6 – Audit Log (v15)
# ============================================================

def ensure_audit_table():
    """Create the ``audit_log`` table (ts, username, role, action, details) if it is missing."""
    audit_ddl = """
        CREATE TABLE IF NOT EXISTS audit_log (
            id        INTEGER PRIMARY KEY AUTOINCREMENT,
            ts        TEXT NOT NULL,
            username  TEXT,
            role      TEXT,
            action    TEXT NOT NULL,
            details   TEXT
        )
        """
    conn = get_db_conn()
    conn.cursor().execute(audit_ddl)
    conn.commit()
    conn.close()


def log_audit(username: str, role: str, action: str, details: str = ""):
    """
    Append one event to the ``audit_log`` table.

    Args:
        username: Acting user; None/empty is stored as "".
        role: Acting user's role; None/empty is stored as "".
        action: Short event code, e.g. 'login', 'ask', 'admin_refresh', 'delete_cohort'.
        details: Free-form context; None/empty is stored as "".

    The row's ``ts`` is the current UTC time in ISO-8601 form.
    """
    ensure_audit_table()
    conn = get_db_conn()
    cur = conn.cursor()

    # `dt` is the module alias bound in cell 4.5, so a later bare
    # `import datetime` in the self-test cell cannot break this call.
    timestamp = dt.datetime.now(dt.timezone.utc).isoformat()
    record = (timestamp, username or "", role or "", action, details or "")

    cur.execute(
        """
        INSERT INTO audit_log (ts, username, role, action, details)
        VALUES (?, ?, ?, ?, ?)
        """,
        record,
    )
    conn.commit()
    conn.close()


def trace_log(message: str):
    """
    Append ``"<UTC ISO timestamp> <message>"`` as one line to TRACE_LOG_PATH.

    Intended for errors/warnings raised inside Gradio callbacks, where the UI
    may only show a generic 'Error'. Logging is best-effort:
      - if the timestamp cannot be produced, "UNKNOWN_TIME" is used;
      - if the write fails, the error and the line are printed instead.
    """
    try:
        stamp = dt.datetime.now(dt.timezone.utc).isoformat()
    except Exception:
        stamp = "UNKNOWN_TIME"

    entry = "{} {}\n".format(stamp, message)
    try:
        with open(TRACE_LOG_PATH, "a", encoding="utf-8") as log_file:
            log_file.write(entry)
    except Exception as err:
        # Never let logging itself crash the caller.
        print("TRACE_LOG_ERROR", err, entry)



In [None]:
# ============================================================
# STEP 4.7 – Cohort Ownership & Sharing (v15-safe)
# ============================================================

from typing import Optional, List, Dict, Any

def ensure_cohort_meta_table():
    """
    Create the ``cohort_meta`` table if it is missing.

    Columns:
      - cohort_name:   primary key.
      - owner_user_id: who created/owns the cohort.
      - is_shared:     0 = private to the owner (default), 1 = shared with all users.
      - created_ts:    creation timestamp text.
    """
    meta_ddl = """
        CREATE TABLE IF NOT EXISTS cohort_meta (
            cohort_name    TEXT PRIMARY KEY,
            owner_user_id  TEXT,
            is_shared      INTEGER DEFAULT 0,
            created_ts     TEXT
        )
        """
    conn = get_db_conn()
    conn.cursor().execute(meta_ddl)
    conn.commit()
    conn.close()


def _extract_username_from_user(user: SessionUser | dict | None) -> Optional[str]:
    """
    Best-effort username lookup.

    - SessionUser -> its ``username`` (may itself be None)
    - dict        -> first truthy value among "username", "user_id", "name", as str
    - anything else, including None -> None
    """
    if isinstance(user, SessionUser):
        return user.username
    if not isinstance(user, dict):
        return None

    for key in ("username", "user_id", "name"):
        try:
            candidate = user.get(key)
            if candidate:
                return str(candidate)
        except Exception:
            continue
    return None


def set_cohort_owner(cohort_name: str, user: SessionUser | dict | None):
    """
    Insert or update the ``cohort_meta`` row recording who owns ``cohort_name``.

    - owner_user_id: username extracted from ``user``, or "anonymous" if none.
    - A new row starts private (is_shared = 0) with created_ts = now (UTC).
    - For an existing row only owner_user_id is replaced; is_shared and
      created_ts are left untouched.
    An empty ``cohort_name`` is ignored.
    """
    if not cohort_name:
        return

    # Local import guarantees we get the datetime CLASS, not the module, even
    # if a later cell rebinds the global name `datetime`.
    from datetime import datetime, timezone

    ensure_cohort_meta_table()
    owner = _extract_username_from_user(user) or "anonymous"
    created = datetime.now(timezone.utc).isoformat()

    conn = get_db_conn()
    conn.cursor().execute(
        """
        INSERT INTO cohort_meta (cohort_name, owner_user_id, is_shared, created_ts)
        VALUES (?, ?, 0, ?)
        ON CONFLICT(cohort_name) DO UPDATE SET
            owner_user_id = excluded.owner_user_id
        """,
        (cohort_name, owner, created),
    )
    conn.commit()
    conn.close()


def cohort_exists(cohort_name: str) -> bool:
    """
    Report whether any ``cohort_docs`` row uses ``cohort_name``.

    The check is global, not per-user, so two users cannot create cohorts
    with confusingly identical names. An empty name is never considered to exist.
    """
    if not cohort_name:
        return False

    ensure_cohort_table()
    conn = get_db_conn()
    hit = conn.execute(
        "SELECT 1 FROM cohort_docs WHERE cohort_name = ? LIMIT 1",
        (cohort_name,),
    ).fetchone()
    conn.close()
    return hit is not None


def list_cohorts_for_user(user: SessionUser | dict | None) -> list[str]:
    """
    Return the cohort names visible to the given user, sorted by name.

    Visibility rules:
      - Admins (SessionUser.is_admin, or a dict with role == 'admin'):
        every cohort in cohort_docs.
      - Logged-in non-admins: cohorts they own (cohort_meta.owner_user_id),
        cohorts marked is_shared = 1, and cohorts that have no cohort_meta row.
      - Anonymous (no username could be extracted): shared cohorts and
        cohorts that have no cohort_meta row.

    An empty result is returned as-is. A user who can see nothing gets [],
    not the global list, because a global fallback would expose other users'
    private cohorts.

    If the query itself raises, this falls back to the global list_cohorts()
    so the UI keeps working. Calls, results and errors are written to the v15
    trace log.
    """
    trace_log(f"list_cohorts_for_user called with user={user!r}")

    ensure_cohort_table()
    ensure_cohort_meta_table()

    conn = get_db_conn()
    cur = conn.cursor()

    try:
        # Determine admin status and username for either supported input shape.
        is_admin = False
        username = None

        if isinstance(user, SessionUser):
            is_admin = user.is_admin
            username = user.username
        elif isinstance(user, dict):
            username = _extract_username_from_user(user)
            role = user.get("role")
            is_admin = (role == "admin")

        if is_admin:
            # Admin sees all distinct cohorts
            cur.execute(
                """
                SELECT DISTINCT cohort_name
                FROM cohort_docs
                ORDER BY cohort_name
                """
            )
        else:
            # Non-admin or anonymous
            if username:
                # Logged-in non-admin -> owner or shared
                cur.execute(
                    """
                    SELECT DISTINCT cd.cohort_name
                    FROM cohort_docs cd
                    LEFT JOIN cohort_meta cm
                      ON cd.cohort_name = cm.cohort_name
                    WHERE cm.owner_user_id = ?
                       OR cm.is_shared = 1
                       OR cm.cohort_name IS NULL   -- safety: cohorts without meta still appear
                    ORDER BY cd.cohort_name
                    """,
                    (username,),
                )
            else:
                # Anonymous -> only shared (or cohorts w/o meta as a fallback)
                cur.execute(
                    """
                    SELECT DISTINCT cd.cohort_name
                    FROM cohort_docs cd
                    LEFT JOIN cohort_meta cm
                      ON cd.cohort_name = cm.cohort_name
                    WHERE cm.is_shared = 1
                       OR cm.cohort_name IS NULL
                    ORDER BY cd.cohort_name
                    """
                )

        names = [r[0] for r in cur.fetchall()]

        # An empty list is a legitimate answer. Cohorts without metadata are
        # already included by the queries above, so a global fallback here
        # could only add other users' private cohorts.
        trace_log(f"list_cohorts_for_user returning {names}")
        return names

    except Exception as e:
        trace_log(f"list_cohorts_for_user ERROR: {e}")
        return list_cohorts()

    finally:
        # Single close point for both the success and the error path.
        conn.close()


# ------------------------------------------------------------
# UI HELPERS FOR COHORT LISTS (used in STEP 10)
# ------------------------------------------------------------

def get_cohorts_for_user(username: Optional[str]) -> List[str]:
    """
    Gradio-facing adapter that turns a bare username string into a
    SessionUser and asks list_cohorts_for_user() which cohorts it may see.

    A falsy username becomes an anonymous SessionUser. Otherwise the role is
    looked up in USERS and defaults to 'user'. Any exception is logged to the
    v15 trace file and answered with the global list_cohorts(), so the
    Refresh Cohorts button never hard-crashes the UI.
    """
    try:
        trace_log(f"get_cohorts_for_user called with username={username!r}")

        session_user = (
            SessionUser(username=username, role=USERS.get(username, {}).get("role", "user"))
            if username
            else SessionUser(username=None, role="anonymous")
        )

        visible = list_cohorts_for_user(session_user)
        trace_log(f"get_cohorts_for_user returning {visible}")
        return visible

    except Exception as e:
        trace_log(f"get_cohorts_for_user ERROR for username={username!r}: {e}")
        # Last-resort fallback so the UI doesn't hard-crash
        return list_cohorts()



def get_all_cohorts() -> List[List[str]]:
    """
    Admin-tab listing of every cohort with its owner.

    Returns one [cohort_name, owner] pair per distinct cohort in cohort_docs,
    ordered by name. The owner is 'unknown' when the cohort has no
    cohort_meta row.
    """
    ensure_cohort_table()
    ensure_cohort_meta_table()

    conn = get_db_conn()
    cur = conn.cursor()
    cur.execute(
        """
        SELECT cd.cohort_name,
               COALESCE(cm.owner_user_id, 'unknown') AS owner
        FROM cohort_docs cd
        LEFT JOIN cohort_meta cm
          ON cd.cohort_name = cm.cohort_name
        GROUP BY cd.cohort_name
        ORDER BY cd.cohort_name
        """
    )
    fetched = cur.fetchall()
    conn.close()

    table: List[List[str]] = []
    for cohort, owner_id in fetched:
        table.append([cohort, owner_id])
    return table


In [None]:
# ============================================================
# CELL 5.0 / STEP 5.0 – Model Registry Core (v15, with schema migration)
# ============================================================

import sqlite3
from typing import List, Dict, Any

def ensure_model_registry_table():
    """
    Make sure model_registry exists with the v15 column set.

    If a model_registry table already exists but lacks any of the v15
    columns, it is dropped (with a DEBUG print) and recreated empty. Dropping
    is acceptable for the MVP because no persistent custom models are relied
    on yet. The final CREATE is IF NOT EXISTS, so a table that already has
    the v15 columns is left untouched.
    """
    expected_columns = frozenset({
        "provider",
        "model_id",
        "display_name",
        "model_type",
        "enabled",
        "is_default",
        "cost_score",
        "latency_score",
        "max_context_tokens",
        "api_base",
        "notes",
    })

    conn = get_db_conn()
    cur = conn.cursor()

    cur.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name='model_registry'"
    )
    table_present = cur.fetchone() is not None

    if table_present:
        cur.execute("PRAGMA table_info(model_registry)")
        # Each PRAGMA table_info row is (cid, name, type, notnull, dflt_value, pk).
        present_columns = {info[1] for info in cur.fetchall()}
        if expected_columns - present_columns:
            print("DEBUG: model_registry schema mismatch detected. Dropping old table.")
            cur.execute("DROP TABLE IF EXISTS model_registry")
            conn.commit()

    cur.execute(
        """
        CREATE TABLE IF NOT EXISTS model_registry (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            provider TEXT,
            model_id TEXT,
            display_name TEXT,
            model_type TEXT,
            enabled INTEGER DEFAULT 1,
            is_default INTEGER DEFAULT 0,
            cost_score INTEGER DEFAULT 2,
            latency_score INTEGER DEFAULT 2,
            max_context_tokens INTEGER,
            api_base TEXT,
            notes TEXT
        )
        """
    )
    conn.commit()
    conn.close()


def seed_default_models():
    """
    Insert the built-in OpenAI defaults, but only when model_registry is empty.

    Existing rows (including disabled ones) are never modified. The seed adds
    one chat model (gpt-4.1-mini) and one embedding model
    (text-embedding-3-small), both enabled and flagged is_default = 1.
    """
    ensure_model_registry_table()
    conn = get_db_conn()
    cur = conn.cursor()

    cur.execute("SELECT COUNT(*) FROM model_registry")
    (existing_rows,) = cur.fetchone()

    if not existing_rows:
        # Column order must line up with the INSERT placeholders below.
        column_order = (
            "provider", "model_id", "display_name", "model_type",
            "enabled", "is_default", "cost_score", "latency_score",
            "max_context_tokens", "api_base", "notes",
        )
        seed_rows = [
            # Default Chat Model
            {
                "provider": "openai",
                "model_id": "gpt-4.1-mini",
                "display_name": "GPT-4.1 Mini",
                "model_type": "chat",
                "enabled": 1,
                "is_default": 1,
                "cost_score": 1,
                "latency_score": 1,
                "max_context_tokens": 128_000,
                "api_base": None,
                "notes": "Primary chat model",
            },
            # Default Embedding Model
            {
                "provider": "openai",
                "model_id": "text-embedding-3-small",
                "display_name": "text-embedding-3-small",
                "model_type": "embed",
                "enabled": 1,
                "is_default": 1,
                "cost_score": 1,
                "latency_score": 1,
                "max_context_tokens": None,
                "api_base": None,
                "notes": "Primary embedding model",
            },
        ]

        for params in (tuple(seed[col] for col in column_order) for seed in seed_rows):
            cur.execute(
                """
                INSERT INTO model_registry (
                    provider, model_id, display_name, model_type,
                    enabled, is_default, cost_score, latency_score,
                    max_context_tokens, api_base, notes
                )
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                params,
            )

        conn.commit()

    conn.close()


def list_models(model_type: str | None = None, only_enabled: bool = True) -> List[Dict[str, Any]]:
    """
    Query model_registry and return each matching row as a plain dict.

    Args:
        model_type: restrict to one type such as 'chat' or 'embed'. Any falsy
            value returns every type.
        only_enabled: when True (the default), keep only rows with enabled = 1.

    Returns:
        Dicts keyed by provider, model_id, display_name, model_type, enabled,
        is_default, cost_score, latency_score, max_context_tokens, api_base
        and notes. Rows are ordered by model_type, then default rows first,
        then model_id. enabled and is_default are the stored integers.
    """
    columns = (
        "provider", "model_id", "display_name", "model_type", "enabled",
        "is_default", "cost_score", "latency_score", "max_context_tokens",
        "api_base", "notes",
    )

    ensure_model_registry_table()
    conn = get_db_conn()
    cur = conn.cursor()

    query = """
        SELECT provider, model_id, display_name, model_type,
               enabled, is_default, cost_score, latency_score,
               max_context_tokens, api_base, notes
        FROM model_registry
        WHERE 1=1
    """
    bind: list[Any] = []

    if model_type:
        query += " AND model_type = ?"
        bind.append(model_type)
    if only_enabled:
        query += " AND enabled = 1"
    query += " ORDER BY model_type, is_default DESC, model_id"

    cur.execute(query, bind)
    fetched = cur.fetchall()
    conn.close()

    return [dict(zip(columns, values)) for values in fetched]
def load_model_registry() -> List[Dict[str, Any]]:
    """
    Return every model_registry row, enabled or not, as a dict. The built-in
    defaults are seeded first if the table is empty.

    Besides the stored columns, each dict carries two legacy aliases:
    'model' (a copy of model_id) and 'type' (model_type). The key
    'model_type' itself is not present; only 'type' is. Rows are ordered by
    model_type, then default rows first, then model_id.
    """
    # Seed before reading, otherwise a fresh database returns an empty registry.
    ensure_model_registry_table()
    seed_default_models()

    conn = get_db_conn()
    conn.row_factory = sqlite3.Row
    cur = conn.cursor()

    cur.execute(
        """
        SELECT
            id,
            provider,
            model_id,
            model_id AS model,          -- backward compatibility
            display_name,
            model_type AS type,         -- backward compatibility
            enabled,
            is_default,
            cost_score,
            latency_score,
            max_context_tokens,
            api_base,
            notes
        FROM model_registry
        ORDER BY model_type, is_default DESC, model_id
        """
    )

    registry = list(map(dict, cur.fetchall()))
    conn.close()
    return registry






In [None]:
# ============================================================
# CELL 5.1 / STEP 5.1 – ModelConfig & Lookup Helpers (v15)
# ============================================================

from dataclasses import dataclass
from typing import Optional, List, Dict, Any

@dataclass
class ModelConfig:
    model_id: str
    provider: str = "openai"
    display_name: str | None = None
    model_type: str = "chat"   # 'chat', 'embed', 'rerank', etc.
    is_default: bool = False
    enabled: bool = True
    cost_score: int = 2
    latency_score: int = 2
    max_context_tokens: int | None = None
    api_base: str | None = None
    notes: str | None = None

    def as_kwargs(self) -> dict:
        """
        Common kwargs we might pass into a client, e.g. OpenAI.
        For now this is mostly for future extensibility.
        """
        kw = {"model": self.model_id}
        if self.api_base:
            kw["api_base"] = self.api_base
        return kw


def load_model_configs(model_type: str, only_enabled: bool = True) -> list[ModelConfig]:
    """
    Fetch registry rows of one model_type via list_models() (STEP 5.0) and
    wrap each in a ModelConfig.

    The ordering from list_models() is kept: default rows first, then by
    model_id.
    """
    field_names = (
        "model_id", "provider", "display_name", "model_type", "is_default",
        "enabled", "cost_score", "latency_score", "max_context_tokens",
        "api_base", "notes",
    )
    return [
        ModelConfig(**{name: row[name] for name in field_names})
        for row in list_models(model_type=model_type, only_enabled=only_enabled)
    ]


def get_default_model(model_type: str) -> Optional[Dict[str, Any]]:
    """
    Pick the registry row to treat as the default for `model_type`.

    Only enabled rows are considered. The result is, in order of preference:
    the first row flagged is_default, the first enabled row in list_models()
    order, or None when no enabled model of that type exists. Used by the
    get_default_*_model_config() helpers.
    """
    candidates = list_models(model_type=model_type, only_enabled=True)
    flagged = (row for row in candidates if row.get("is_default"))
    return next(flagged, candidates[0] if candidates else None)


def get_default_chat_model_config() -> ModelConfig:
    """
    Resolve the chat model to use by default.

    Resolution order:
      1. The row chosen by get_default_model("chat"): the is_default row,
         otherwise the first enabled chat model.
      2. The first enabled chat model from load_model_configs("chat").
      3. A hard-coded GPT-4.1 Mini config. This is reached only when the
         registry has no enabled chat model, e.g. before the defaults have
         been seeded.
    """
    row = get_default_model("chat")
    if row:
        return ModelConfig(**{
            field: row[field]
            for field in (
                "model_id", "provider", "display_name", "model_type",
                "is_default", "enabled", "cost_score", "latency_score",
                "max_context_tokens", "api_base", "notes",
            )
        })

    enabled_chat = load_model_configs("chat", only_enabled=True)
    if enabled_chat:
        return enabled_chat[0]

    return ModelConfig(
        model_id="gpt-4.1-mini",
        provider="openai",
        display_name="GPT-4.1 Mini (fallback)",
        model_type="chat",
        is_default=True,
        enabled=True,
        cost_score=1,
        latency_score=1,
        max_context_tokens=128_000,
    )


def get_default_embed_model_config() -> ModelConfig:
    """
    Resolve the embedding model used when building and querying cohorts.

    Resolution order:
      1. The row chosen by get_default_model("embed").
      2. The first enabled embed model from load_model_configs("embed").
      3. A hard-coded text-embedding-3-small config. This is reached only
         when the registry has no enabled embedding model.
    """
    row = get_default_model("embed")
    if row:
        return ModelConfig(**{
            field: row[field]
            for field in (
                "model_id", "provider", "display_name", "model_type",
                "is_default", "enabled", "cost_score", "latency_score",
                "max_context_tokens", "api_base", "notes",
            )
        })

    enabled_embed = load_model_configs("embed", only_enabled=True)
    if enabled_embed:
        return enabled_embed[0]

    return ModelConfig(
        model_id="text-embedding-3-small",
        provider="openai",
        display_name="text-embedding-3-small (fallback)",
        model_type="embed",
        is_default=True,
        enabled=True,
        cost_score=1,
        latency_score=1,
    )
def list_chat_models() -> List[str]:
    """
    Return the model_id of every enabled chat model in registry order
    (default rows first). The registry defaults are seeded if it is empty.
    """
    chat_ids: List[str] = []
    for entry in load_model_registry():
        if entry["type"] != "chat" or not entry["enabled"]:
            continue
        chat_ids.append(entry["model_id"])
    return chat_ids



In [None]:
# ============================================================
# CELL 5.2 / STEP 5.2 – RAG Retrieval Over a Cohort (v15)
# ============================================================

def build_context_from_index(
    api_key: str,
    chat_model: str,
    embed_model: str,
    cohort_name: str,
    query: str,
    top_k: int = 5,
):
    """
    Retrieve the top_k most similar chunks for `query` across every document
    in `cohort_name`.

    Steps:
      - Look up (doc_name, index_id) for the cohort in the `documents` table.
      - Embed the query once with the resolved embedding model and
        L2-normalise it.
      - Search each document's FAISS index for its own top_k chunks.
      - Merge all hits and keep the global top_k by descending score.

    A document whose index or metadata cannot be loaded or searched (e.g. a
    missing file, or an index whose dimension does not match the query
    embedding) is skipped and logged to the trace log. It does not fail the
    whole query.

    Args:
        api_key: OpenAI API key used for the query embedding.
        chat_model: forwarded to resolve_models(); the resolved chat model is
            not used here.
        embed_model: requested embedding model, resolved via resolve_models().
        cohort_name: cohort whose documents are searched.
        query: user question. Blank queries return no context.
        top_k: number of chunks returned, and also the per-document search
            depth. Values below 1 return no context.

    Returns:
        (context, citations).
        - context joins the hits as "[rank] From <doc> (chunk #i, score=s)\\n<text>",
          separated by blank lines.
        - citations is a list of (doc_name, rank, score).
        Returns ("", []) when nothing is found.
    """
    # Avoid a FAISS assertion (k must be >= 1) and an embedding call on blank input.
    if top_k < 1 or not query or not query.strip():
        return "", []

    ensure_docs_table()
    _resolved_chat, resolved_embed = resolve_models(chat_model, embed_model)
    client = build_openai_client(api_key)

    # 1. Load all docs & their index references
    conn = get_db_conn()
    cur = conn.cursor()
    cur.execute(
        """
        SELECT doc_name, index_id
        FROM documents
        WHERE cohort_name = ?
        """,
        (cohort_name,),
    )
    doc_rows = cur.fetchall()
    conn.close()

    if not doc_rows:
        return "", []

    # 2. Embed the query once and unit-normalise it. This assumes the per-doc
    #    indexes are inner-product indexes over normalised vectors, so scores
    #    act like cosine similarity. TODO: confirm against build_faiss_index.
    q_embed_resp = client.embeddings.create(model=resolved_embed, input=[query])
    q_vec = np.array(q_embed_resp.data[0].embedding, dtype="float32")
    q_vec = q_vec / (np.linalg.norm(q_vec) + 1e-10)
    q_mat = q_vec[np.newaxis, :]

    all_hits = []  # (doc_name, chunk_idx, score, text)

    # 3. Similarity search in each doc index; one bad index must not sink the query.
    for doc_name, index_id in doc_rows:
        try:
            index = load_index(index_id)
            meta = load_metadata(index_id)  # {"chunks": [...]}
            scores, idxs = index.search(q_mat, top_k)
        except Exception as e:
            trace_log(
                f"build_context_from_index: skipping doc {doc_name!r} "
                f"(index {index_id}): {e}"
            )
            continue

        chunks = meta.get("chunks", [])
        for score, idx in zip(scores[0], idxs[0]):
            # FAISS returns -1 labels when an index holds fewer than top_k vectors.
            if idx < 0 or idx >= len(chunks):
                continue
            all_hits.append((doc_name, int(idx), float(score), chunks[idx]))

    if not all_hits:
        return "", []

    # 4. Keep the best top_k overall (higher score = more similar, per the
    #    inner-product assumption above).
    all_hits.sort(key=lambda hit: hit[2], reverse=True)
    top_hits = all_hits[:top_k]

    context_parts = []
    citations = []
    for rank, (doc_name, idx, score, text) in enumerate(top_hits, start=1):
        header = f"[{rank}] From {doc_name} (chunk #{idx}, score={score:.3f})"
        context_parts.append(header + "\n" + text)
        citations.append((doc_name, rank, score))

    return "\n\n".join(context_parts), citations


# ============================================================
# UPDATED answer_with_rag() — Now uses Routing Brain (v15)
# ============================================================

def answer_with_rag(
    api_key: str,
    chat_model: str,          # kept for backward compatibility
    embed_model: str,         # still used for embedding queries
    cohort_name: str,
    query: str,
    system_prompt: str = "",
    user_pref: str | None = None,  # NEW: allows override from UI
):
    """
    Answer `query` from the documents of `cohort_name` using retrieval plus
    the v15 routing brain.

    chat_model is only forwarded to build_context_from_index(). The model
    that actually answers is chosen by call_chat_model() from `user_pref`
    and the registry defaults.

    Returns a 3-tuple:
        - answer_markdown: the answer, a "Cited sources" list and the model used
        - raw_answer_text: the plain answer text ("" when no context was found)
        - model_used: the routed model_id ("N/A" when no context was found)
    """
    context, citations = build_context_from_index(
        api_key, chat_model, embed_model, cohort_name, query
    )
    if not context:
        return (
            "I could not find any context for this query in the selected cohort.",
            "",
            "N/A",
        )

    effective_system_prompt = system_prompt or (
        "You are a helpful assistant answering questions based on the provided context.\n"
        "If the answer cannot be found in the context, say you do not know."
    )

    user_message = (
        "Use ONLY the context below to answer the question.\n\n"
        "=== CONTEXT START ===\n"
        f"{context}\n"
        "=== CONTEXT END ===\n\n"
        f"QUESTION: {query}"
    )
    messages = [
        {"role": "system", "content": effective_system_prompt},
        {"role": "user", "content": user_message},
    ]

    # The routing brain picks the model; context length is passed as a hint.
    answer_text, raw_answer_text, model_used = call_chat_model(
        api_key=api_key,
        messages=messages,
        task_type="rag_answer",
        user_pref=user_pref,
        context_size=len(context),
    )

    source_lines = [
        f"- [{rank}] `{doc_name}` (score={score:.3f})\n"
        for doc_name, rank, score in citations
    ]
    md = "".join(
        [
            answer_text,
            "\n\n---\n\n**Cited sources:**\n",
            *source_lines,
            f"\n\n**Model Used:** `{model_used}`\n",
        ]
    )

    return md, raw_answer_text, model_used
def answer_question_over_cohort(api_key, username, cohort_name, question, model_id):
    """
    UI entry point for one question; it must be defined before STEP 10.

    Answers `question` over `cohort_name` using the registry's default
    embedding model, with `model_id` as both the legacy chat_model argument
    and the routing preference. `username` is only written to the trace log
    here.

    Returns (answer_markdown, raw_answer, model_used) from answer_with_rag().
    """
    trace_log(
        f"answer_question_over_cohort called user={username}, cohort={cohort_name}"
    )

    answer_md, raw_answer, used_model = answer_with_rag(
        api_key=api_key,
        chat_model=model_id,
        embed_model=get_default_embed_model_config().model_id,
        cohort_name=cohort_name,
        query=question,
        system_prompt="",
        user_pref=model_id,  # honoured by the router if it names an enabled chat model
    )
    return answer_md, raw_answer, used_model





In [None]:
# ============================================================
# CELL 6 / STEP 6 – Build Cohort from Uploaded Docs (v15)
# ============================================================
def build_cohort_from_files(
    api_key: str,
    embed_model: str,
    cohort_name: str,
    files: List[Any],
) -> str:
    """Ingest a list of uploaded files into a *single* cohort.

    Each file is processed independently by _ingest_uploaded_file(): load
    the text, chunk it, embed it, build and save a FAISS index plus its
    metadata, and register the document in SQLite. A failure on one file is
    recorded in the report and does not stop the others.

    Returns a multi-line, human-readable summary for the UI: an overall
    ✅/❌ line followed by one status line per file.
    """
    ensure_docs_table()
    ensure_cohort_table()
    os.makedirs(INDEX_DIR, exist_ok=True)

    ingested_docs = 0
    ingested_chunks = 0
    report: List[str] = []

    for file_obj in files:
        try:
            n_chunks, status = _ingest_uploaded_file(
                api_key=api_key,
                embed_model=embed_model,
                cohort_name=cohort_name,
                file_obj=file_obj,
            )
        except Exception as e:
            # Best-effort: record this file's failure and keep going.
            name = getattr(file_obj, "name", str(file_obj))
            report.append(f"❌ {name}: {e}")
            continue

        report.append(status)
        if n_chunks:
            ingested_docs += 1
            ingested_chunks += n_chunks

    if not ingested_docs:
        details = "\n".join(report)
        return "❌ No documents were successfully processed." + (f"\n{details}" if details else "")

    headline = (
        f"✅ Built cohort '{cohort_name}' with {ingested_docs} document(s) "
        f"and {ingested_chunks} total chunks."
    )
    return "\n".join([headline, *report])


def _ingest_uploaded_file(
    api_key: str,
    embed_model: str,
    cohort_name: str,
    file_obj: Any,
) -> tuple[int, str]:
    """Ingest one uploaded file and return (chunks_added, status_line).

    chunks_added is 0 when the file was skipped because it had no text or
    produced no chunks. Exceptions from loading, embedding, index I/O or
    registration propagate to the caller.
    """
    text, doc_name = load_file_to_text(file_obj)
    if not text or not text.strip():
        return 0, f"⚠️ {doc_name}: no extractable text, skipped."

    chunks = chunk_text(text)
    if not chunks:
        return 0, f"⚠️ {doc_name}: produced 0 chunks, skipped."

    vectors = embed_texts(api_key=api_key, embed_model=embed_model, texts=chunks)
    index = build_faiss_index(vectors)

    # Every document gets its own randomly named index + metadata pair.
    index_id = str(uuid4())
    save_index(index, index_id)
    save_metadata(
        index_id,
        {
            "cohort_name": cohort_name,
            "doc_name": doc_name,
            "chunks": chunks,
            "embed_model": embed_model,
        },
    )

    register_document(
        doc_name=doc_name,
        cohort_name=cohort_name,
        index_id=index_id,
        n_chunks=len(chunks),
        embed_model=embed_model,
    )
    return len(chunks), f"✅ {doc_name}: {len(chunks)} chunks embedded."


def build_cohort_index(
    api_key: str,
    cohort_name: str,
    files: List[Any],
    owner: Optional[str] = None,
) -> str:
    """
    High-level cohort builder used by the Gradio UI (v15).

      - Validates the API key, cohort name and file list.
      - Selects the default embedding model from the model registry.
      - Builds the cohort via build_cohort_from_files().
      - Records `owner` (or 'anonymous' when None) as the cohort owner, but
        only if at least one document was ingested.

    Ownership is skipped after a failed build because set_cohort_owner()
    overwrites the owner of an existing cohort. Without the skip, a failed
    upload to an existing cohort name would hand that cohort to the uploader.

    Returns the status message from build_cohort_from_files(), or a
    "❌ ..." validation message.
    """
    if not api_key or not api_key.strip():
        return "❌ OpenAI API key is required."

    # Strip once and use the same normalised name everywhere below.
    clean_name = (cohort_name or "").strip()
    if not clean_name:
        return "❌ Cohort name is required."

    if not files:
        return "❌ Please upload at least one file."

    embed_model = get_default_embed_model_config().model_id

    # Build FAISS indexes + metadata
    result_msg = build_cohort_from_files(
        api_key=api_key,
        embed_model=embed_model,
        cohort_name=clean_name,
        files=files,
    )

    # build_cohort_from_files() starts its "nothing ingested" message with "❌".
    if result_msg.startswith("❌"):
        return result_msg

    try:
        if owner:
            role = USERS.get(owner, {}).get("role", "user")
            user_obj = SessionUser(username=owner, role=role)
        else:
            user_obj = None

        set_cohort_owner(clean_name, user_obj)
    except Exception as e:
        # Ownership metadata is best-effort; never fail an upload over it.
        trace_log(f"build_cohort_index: set_cohort_owner error for {clean_name!r}: {e}")

    return result_msg


In [None]:
# ============================================================
# CELL 6.1 / STEP 6.1 – Routing Brain (v15)
# ============================================================
#
# Centralized model selection + wrapper for ALL chat LLM calls.
# Uses the model_registry table (STEP 5.0) and supports:
#   - task_type hints ("question_improve", "rag_answer", "summary", "admin")
#   - optional user_pref override (a specific model_id)
#   - dry_run flag for automated self-tests (no API call made)
#
# Functions:
#   - select_chat_model(task_type, context_size, user_pref)
#   - call_chat_model(api_key, messages, task_type, user_pref, context_size, dry_run)

from typing import List, Dict, Any, Tuple


def select_chat_model(
    task_type: str,
    context_size: int = 0,
    user_pref: str | None = None,
) -> str:
    """
    Choose the chat model_id for one LLM call.

    Precedence:
      1. `user_pref`, if it names an enabled chat model in the registry.
      2. The enabled chat model flagged is_default.
      3. The first enabled chat model, in registry order.
      4. "gpt-4.1-mini" when the registry has no enabled chat models.

    task_type and context_size are accepted for future routing heuristics
    and are currently ignored.
    """
    enabled = load_model_configs("chat", only_enabled=True)
    if not enabled:
        return "gpt-4.1-mini"

    if user_pref and any(cfg.model_id == user_pref for cfg in enabled):
        return user_pref

    # Stable sort moves default rows to the front while keeping registry
    # order, so [0] is the first default, or the first enabled model if none.
    return sorted(enabled, key=lambda cfg: not cfg.is_default)[0].model_id



def call_chat_model(
    api_key: str,
    messages: List[Dict[str, Any]],
    task_type: str,
    user_pref: str | None = None,
    context_size: int = 0,
    dry_run: bool = False,
    temperature: float = 0.2,
    max_tokens: int = 900,
) -> Tuple[str, str, str]:
    """
    Wrapper for ALL chat LLM calls in the app.

    Args:
        api_key:      OpenAI API key.
        messages:     Chat completion messages.
        task_type:    Semantic label for routing (e.g. 'rag_answer', 'question_improve').
        user_pref:    Optional explicit model_id override.
        context_size: Approximate size of the context (chars), passed to routing.
        dry_run:      If True, do NOT call the API (used by self-tests).
        temperature:  Sampling temperature (default 0.2, as before).
        max_tokens:   Completion token cap (default 900, as before).

    Returns:
        (answer_text, raw_answer_text, model_used). When the API returns no
        text content (e.g. a refusal), answer_text is the refusal message if
        the SDK provides one, otherwise "".
    """
    model_id = select_chat_model(
        task_type=task_type,
        context_size=context_size,
        user_pref=user_pref,
    )

    # For automated self-tests: don't hit the API
    if dry_run:
        dummy = f"[DRY RUN] task_type={task_type}, model_id={model_id}"
        return dummy, dummy, model_id

    client = build_openai_client(api_key)

    resp = client.chat.completions.create(
        model=model_id,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    message = resp.choices[0].message
    # message.content is None for refusals / tool calls; calling .strip() on it
    # would raise AttributeError. getattr keeps older SDKs without `refusal` working.
    content = message.content or getattr(message, "refusal", None) or ""
    answer_text = content.strip()
    # In this MVP, raw_answer_text == answer_text, but we keep both for future transforms
    raw_answer_text = answer_text

    return answer_text, raw_answer_text, model_id


In [None]:
# ============================================================
# CELL 7 / STEP 7 – Admin Helpers (Stats & Maintenance)
# ============================================================

def get_db_stats() -> str:
    """
    Return a Markdown summary of row counts: documents, distinct cohorts
    (from cohort_docs), users and chat_history records.

    The ensure_*_table() helpers are called first so every counted table
    exists before it is queried.
    """
    for ensure_table in (
        ensure_docs_table,
        ensure_cohort_table,
        ensure_user_table,
        ensure_chat_history_table,
    ):
        ensure_table()

    count_queries = (
        "SELECT COUNT(*) FROM documents",
        "SELECT COUNT(DISTINCT cohort_name) FROM cohort_docs",
        "SELECT COUNT(*) FROM users",
        "SELECT COUNT(*) FROM chat_history",
    )

    conn = get_db_conn()
    cur = conn.cursor()
    counts = []
    for sql in count_queries:
        cur.execute(sql)
        counts.append(cur.fetchone()[0])
    conn.close()

    n_docs, n_cohorts, n_users, n_chats = counts
    return (
        f"**DB Stats**\n\n"
        f"- Documents: {n_docs}\n"
        f"- Cohorts: {n_cohorts}\n"
        f"- Users: {n_users}\n"
        f"- Chat records (last 7 days enforced on write): {n_chats}\n"
    )

def describe_users() -> str:
    """
    Return a Markdown bullet list of the users from list_users().

    Each bullet shows the user id, display name and role. Placeholders are
    used when the display name or role is missing.
    """
    users = list_users()
    if not users:
        return "No users have been recorded yet."

    bullets = [
        f"- `{user_id}` – {display_name or '(no display name)'} – role: `{role or '(no role)'}`"
        for user_id, display_name, role in users
    ]
    return "\n".join(["**Known Users**\n", *bullets])

def describe_cohorts() -> str:
    """
    Returns a markdown summary of cohorts, including:
      - name
      - owner
      - visibility (private/shared)
      - document count

    Cohorts come from `cohort_docs`; owner/visibility come from a LEFT JOIN
    on `cohort_meta`. A cohort with no meta row is shown with owner `(none)`
    and visibility "private".

    Returns:
        Markdown text, or "No cohorts found." when `cohort_docs` is empty.
    """
    ensure_docs_table()
    ensure_cohort_table()
    ensure_cohort_meta_table()

    conn = get_db_conn()
    try:
        cur = conn.cursor()
        # DISTINCT is unnecessary alongside GROUP BY. Grouping on the full
        # COALESCE expressions (instead of the aliases, where `is_shared`
        # collides with cm.is_shared) makes the grouping explicit.
        cur.execute(
            """
            SELECT
                cd.cohort_name,
                COALESCE(cm.owner_user_id, '(none)') AS owner,
                COALESCE(cm.is_shared, 0)            AS is_shared,
                COUNT(DISTINCT cd.doc_name)          AS num_docs
            FROM cohort_docs cd
            LEFT JOIN cohort_meta cm
              ON cd.cohort_name = cm.cohort_name
            GROUP BY
                cd.cohort_name,
                COALESCE(cm.owner_user_id, '(none)'),
                COALESCE(cm.is_shared, 0)
            ORDER BY cd.cohort_name
            """
        )
        rows = cur.fetchall()
    finally:
        # Always release the connection, even if the query fails.
        conn.close()

    if not rows:
        return "No cohorts found."

    lines = ["**Cohorts**", ""]
    for name, owner, is_shared, num_docs in rows:
        share_label = "shared" if is_shared else "private"
        lines.append(
            f"- **{name}** — owner: `{owner}`, visibility: {share_label}, docs: {num_docs}"
        )

    return "\n".join(lines)




In [None]:
# ============================================================
# CELL 8 / STEP 8 – Prompt Coach (Optional Query Improvement) – v15
# ============================================================
#
# Uses the v15 Routing Brain (call_chat_model) instead of calling OpenAI directly.
# Signature is kept the same so STEP 10's on_improve_query(...) still works:
#     improve_query(api_key, chat_model, original_query)
#
# In a future step, we can optionally add a user-selected model override and
# pass it into call_chat_model(user_pref=...).

def improve_query(
    api_key: str,
    chat_model: str,       # kept for backward compatibility; routing ignores it
    original_query: str,
) -> str:
    """
    Prompt coach to re-write the user's query for better RAG retrieval.

    Behavior:
    - If the original query is None, empty, or whitespace, returns "" without
      calling the model.
    - Otherwise, uses the Routing Brain (task_type='question_improve') to pick
      an appropriate chat model and rewrite the question to be clearer, more
      explicit, and RAG-friendly.

    Args:
        api_key:        OpenAI API key forwarded to call_chat_model.
        chat_model:     unused; kept so existing callers keep working.
        original_query: the user's question.

    Returns:
        The improved query text, stripped of surrounding whitespace.
    """
    # Gradio textboxes can hand over None as well as "", so guard both.
    if not original_query or not original_query.strip():
        return ""

    system_prompt = (
        "You are a prompt coach helping the user improve questions for a RAG system. "
        "Rewrite the query to be explicit, concise, and focused on key details. "
        "Preserve the user's intent but remove ambiguity, vague pronouns, and "
        "unnecessary filler. Return ONLY the improved query text."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": original_query},
    ]

    # Use the v15 Routing Brain instead of calling OpenAI directly
    improved, _, model_used = call_chat_model(
        api_key=api_key,
        messages=messages,
        task_type="question_improve",
        user_pref=None,                     # (optional override will come from UI later)
        context_size=len(original_query),   # small, but available for routing heuristics
    )

    # For now we just return the improved text. If desired later, we can:
    # - log model_used to audit_log
    # - display which model did the improvement in the UI.
    return (improved or "").strip()


In [None]:
# ============================================================
# CELL 9 / STEP 9 – Chat History Viewer (User-Facing)
# ============================================================

def format_history_markdown(
    user_id: Optional[str],
    cohort_name: Optional[str],
    limit: int = 50,
) -> str:
    """
    Turn recent history into markdown for display.

    Args:
        user_id:     restrict to this user's rows (None = no user filter).
        cohort_name: restrict to this cohort's rows (None = no cohort filter).
        limit:       maximum number of rows requested from get_recent_history.

    Returns:
        Markdown text, or a "No chat history found ..." message when nothing
        matches. None fields are rendered as placeholders or empty text
        instead of "None" or a TypeError.
    """
    hist = get_recent_history(user_id=user_id, cohort_name=cohort_name, limit=limit)
    if not hist:
        return "No chat history found for the given filters (within retention window)."

    lines = [
        f"**Showing up to {limit} most recent interactions** "
        f"{'(filtered)' if user_id or cohort_name else ''}\n"
    ]

    for h in hist:
        # .get(...) mirrors how the History tab reads these dict rows.
        ts = h.get("created_at") or "(unknown)"
        u = h.get("user_id") or "(anonymous)"
        r = h.get("role") or "(none)"
        c = h.get("cohort_name") or "(none)"
        which = h.get("which_prompt") or "(unknown)"
        original = h.get("original_query") or ""
        improved = h.get("improved_query")

        lines.append(f"---\n**User:** `{u}`  |  **Role:** `{r}`  |  **Cohort:** `{c}`  |  **When:** {ts}")
        lines.append(f"**Prompt used:** `{which}`")
        lines.append(f"**Original query:**\n{original}\n")
        if improved:
            lines.append(f"**Improved query:**\n{improved}\n")
        lines.append("**Answer:**")
        # A None answer would make "\n".join raise TypeError.
        lines.append(h.get("answer") or "")
        lines.append("")

    return "\n".join(lines)


In [None]:
# ============================================================
# CELL 9.5 / STEP 9.5 – Identity & Admin Ops (v14)
# ============================================================

def authenticate(username: str, password: str) -> SessionUser | None:
    """
    MVP auth: checks credentials against the local USERS dict.

    Args:
        username: key looked up in USERS.
        password: plaintext password supplied at login.

    Returns:
        SessionUser(username=username, role=record["role"]) when the user
        exists and the password matches record["password"]; otherwise None.
    """
    import hmac  # constant-time comparison for secrets

    record = USERS.get(username)
    if not record:
        return None

    expected = record["password"]
    # Compare in constant time. compare_digest rejects non-ASCII str, so
    # compare UTF-8 bytes. A non-str value can never be a valid match here.
    if not isinstance(password, str) or not isinstance(expected, str):
        return None
    if not hmac.compare_digest(password.encode("utf-8"), expected.encode("utf-8")):
        return None
    return SessionUser(username=username, role=record["role"])


def authenticate_credentials(username: str, password: str):
    """
    Login helper used by the Gradio login logic in STEP 10.

    Delegates the credential check to authenticate(...). When that succeeds,
    the user is recorded in the `users` table via upsert_user. This is best
    effort: a DB error is written to trace_log and does not block the login.

    Returns:
        {"username": ..., "role": ...} on success, or None if the credentials
        are rejected.
    """
    session_user = authenticate(username, password)
    if not session_user:
        return None

    name = session_user.username
    role = session_user.role

    try:
        upsert_user(user_id=name, role=role, display_name=name)
    except Exception as exc:
        # Persisting the user is secondary; never fail a login over it.
        trace_log(f"authenticate_credentials upsert_user ERROR: {exc}")

    # The STEP 10 UI consumes a plain dict rather than the SessionUser.
    return {"username": name, "role": role}

def require_admin(user: SessionUser):
    """
    Guard for admin-only operations.

    Returns None when `user` is truthy and `user.is_admin` is truthy.

    Raises:
        PermissionError: for a falsy user (e.g. None) or a non-admin user.
    """
    if user and user.is_admin:
        return
    raise PermissionError("Admin privileges required for this action.")


def admin_delete_cohort(user: SessionUser, cohort_name: str) -> str:
    """
    Admin-only wrapper around delete_cohort().
    Uses the existing delete_cohort function from STEP 4.

    Args:
        user:        the session user; must pass require_admin().
        cohort_name: cohort to delete.

    Returns:
        The message from delete_cohort() on success, or a "❌ ..." message for
        an unauthorized user, a missing cohort name, or a failed delete. If the
        delete succeeds but writing the audit entry fails, a warning is appended
        to the success message instead of reporting the delete as failed.
    """
    try:
        require_admin(user)
    except PermissionError as e:
        return f"❌ Not authorized: {e}"

    if not cohort_name:
        return "❌ Please select a cohort to delete."

    try:
        # Use existing v13 delete_cohort logic (no reassignment in this MVP).
        msg = delete_cohort(cohort_name, reassign_to=None)
    except Exception as e:
        return f"❌ Error deleting cohort: {e}"

    # The deletion has already happened at this point. An audit failure must
    # not be presented as a failed delete, but it should stay visible.
    try:
        log_audit(user.username, user.role, "delete_cohort", f"cohort={cohort_name}")
    except Exception as e:
        trace_log(f"admin_delete_cohort log_audit ERROR: {e}")
        return f"{msg}\n\n⚠️ Audit log write failed: {e}"

    return msg

def admin_view_audit_log(user: SessionUser, limit: int = 50) -> str:
    """
    Admin-only view of recent audit log entries.

    Args:
        user:  the session user; must pass require_admin().
        limit: maximum number of rows to show, newest first (ORDER BY id DESC).

    Returns:
        A markdown bullet list, "No audit log entries." when the table is
        empty, or a "❌ Not authorized" message for non-admins.
    """
    try:
        require_admin(user)
    except PermissionError as e:
        return f"❌ Not authorized: {e}"

    ensure_audit_table()
    conn = get_db_conn()
    try:
        cur = conn.cursor()
        cur.execute(
            """
            SELECT ts, username, role, action, details
            FROM audit_log
            ORDER BY id DESC
            LIMIT ?
            """,
            (limit,),
        )
        rows = cur.fetchall()
    finally:
        # Always release the connection, even if the query fails.
        conn.close()

    if not rows:
        return "No audit log entries."

    lines = ["**Recent Audit Log Entries**\n"]
    for ts, username, role, action, details in rows:
        u = username or "-"
        r = role or "-"
        d = details or ""
        lines.append(f"- {ts} | user=`{u}` | role=`{r}` | action=`{action}` | {d}")

    return "\n".join(lines)


In [None]:
# ============================================================
# STEP 10 — Gradio App, Tabs & Startup (v15 Full Version)
# ============================================================

# Global state handles (will be initialized in build_interface)
# Placeholders only: build_interface() rebinds all three to gr.State objects
# (via `global`) before it calls the build_*_tab() functions that wire them
# into click handlers.
current_user_state = current_role_state = current_api_key_state = None


# -----------------------------
# SETUP & COHORTS TAB
# -----------------------------
def build_setup_tab():
    """
    Build the 'Setup & Cohorts' tab for the Gradio UI.

    Modes:
      1) Create new cohort
      2) Add files to existing cohort

    Features:
      - Prevent duplicate cohort names when creating new
      - Clear fields on successful build
      - Clear status when Refresh is clicked
      - Refresh list of cohorts visible to the current user
      - Show an overview list of cohorts

    Notes:
      - Must be called inside an active gr.Blocks / gr.Tab context: it creates
        components and wires their events directly.
      - The click handlers take the module-level `current_user_state` and
        `current_api_key_state` as inputs, so build_interface must assign those
        gr.State objects before this function runs.
      - Components render in creation order; reordering them changes the layout.
    """
    gr.Markdown("### Build or Manage a Cohort")

    # Choose what we are doing
    # (these label strings are compared literally inside _build below)
    action_radio = gr.Radio(
        label="Action",
        choices=["Create new cohort", "Add files to existing cohort"],
        value="Create new cohort",
    )

    # For creating a new cohort
    cohort_name = gr.Textbox(
        label="New Cohort Name",
        placeholder="e.g., USDA_WIC_Guidance",
    )

    # For adding to an existing cohort
    # Starts empty; filled by _refresh when "Refresh Your Cohorts" is clicked.
    existing_cohort_dropdown = gr.Dropdown(
        label="Existing Cohort (for Add mode)",
        choices=[],
        value=None,
        interactive=True,
    )

    file_uploader = gr.File(
        label="Upload Documents",
        file_count="multiple",
        type="filepath",  # filepaths go into build_cohort_index / load_file_to_text
    )

    build_btn = gr.Button("Build / Update Cohort Index")
    build_status = gr.Markdown()

    refresh_btn = gr.Button("Refresh Your Cohorts")

    # Overview list of cohorts (now interactive for visibility, but logically read-only)
    cohort_list = gr.Dropdown(
        label="Your Cohorts (overview)",
        choices=[],
        value=None,
        interactive=True,  # still just for viewing; no actions tied to it yet
    )

    # ---------- BUILD COHORT CALLBACK ----------

    def _build(new_name, files, existing_name, action, username, api_key):
        """
        Build or update a cohort index from uploaded files.

        Args (in the order wired in build_btn.click below):
          new_name:      text of the "New Cohort Name" box.
          files:         uploader value (type="filepath", file_count="multiple");
                         presumably a list of paths.
          existing_name: selected value of the existing-cohort dropdown.
          action:        selected radio label; a falsy value falls back to
                         "Create new cohort".
          username:      value of current_user_state.
          api_key:       value of current_api_key_state.

        Returns:
          (status_markdown,
           new_cohort_name_update,
           file_uploader_update)

        Input fields are left untouched on every error path and cleared only
        after build_cohort_index succeeds.
        """
        # An empty gr.update() leaves the component's current value as-is.
        keep_name = gr.update()
        keep_files = gr.update()

        trace_log(
            f"SETUP _build called user={username!r}, action={action!r}, "
            f"new_name={new_name!r}, existing_name={existing_name!r}, "
            f"files={len(files) if files else 0}"
        )

        # --- guard rails (do NOT clear fields on error) ---

        if not username:
            return "❌ Not logged in.", keep_name, keep_files

        if not api_key or not api_key.strip():
            return "❌ OpenAI API key is required.", keep_name, keep_files

        if not files:
            return "❌ Please upload at least one file.", keep_name, keep_files

        action = action or "Create new cohort"

        # Decide target cohort name based on mode
        if action == "Create new cohort":
            if not new_name or not new_name.strip():
                return "❌ New cohort name is required.", keep_name, keep_files
            target = new_name.strip()
        else:  # "Add files to existing cohort"
            if not existing_name:
                return "❌ Please select an existing cohort to add files to.", keep_name, keep_files
            target = existing_name

        # Check for duplicate when creating new
        # Exact, case-sensitive membership test. Assumes list_cohorts() returns
        # cohort-name strings (TODO confirm). If the lookup itself fails, the
        # check is skipped (existing = []) rather than blocking the build.
        if action == "Create new cohort":
            try:
                existing = list_cohorts()
            except Exception as e:
                trace_log(f"SETUP _build list_cohorts ERROR: {e}")
                existing = []

            if target in existing:
                msg = (
                    f"❌ A cohort named **{target}** already exists. "
                    "Please choose a different name or use 'Add files to existing cohort'."
                )
                trace_log(f"SETUP _build DUPLICATE new cohort={target!r}")
                # Keep name & files so user can adjust
                return msg, gr.update(value=target), keep_files

        # --- Happy path: build / update the cohort ---

        try:
            msg = build_cohort_index(
                api_key=api_key,
                cohort_name=target,
                files=files,
                owner=username,
            )
            trace_log(
                f"SETUP _build SUCCESS action={action!r}, target={target!r}: {msg}"
            )

            # On success: clear both the new_name field and the uploader
            clear_name = gr.update(value="")
            clear_files = gr.update(value=None)

            return msg, clear_name, clear_files

        except Exception as e:
            trace_log(f"SETUP _build ERROR action={action!r}, target={target!r}: {e}")
            return f"❌ Error: {e}", keep_name, keep_files

    # ---------- REFRESH COHORTS CALLBACK ----------

    def _refresh(username):
        """
        Refresh the list of cohorts visible to this user and
        clear the build status message.

        Returns:
          (existing_cohort_dropdown update, cohort_list update, build_status update).
          With no username, or on any error (which is logged), both dropdowns
          are emptied. The status is cleared in every case.
        """
        trace_log(f"SETUP _refresh called with username={username!r}")
        try:
            if not username:
                empty_choices = gr.update(choices=[], value=None)
                return (
                    empty_choices,  # existing_cohort_dropdown
                    empty_choices,  # cohort_list (overview)
                    gr.update(value=""),  # build_status
                )

            names = get_cohorts_for_user(username)
            trace_log(f"SETUP _refresh cohorts={names}")

            # For existing-cohort dropdown in Add mode
            # (no preselection, so the user must pick a target deliberately)
            existing_update = gr.update(choices=names, value=None)

            # For overview dropdown, set first cohort as selected (if any)
            if names:
                overview_update = gr.update(choices=names, value=names[0])
            else:
                overview_update = gr.update(choices=[], value=None)

            return (
                existing_update,   # existing_cohort_dropdown
                overview_update,   # cohort_list (overview)
                gr.update(value=""),  # build_status
            )
        except Exception as e:
            trace_log(f"SETUP _refresh ERROR for username={username!r}: {e}")
            empty_choices = gr.update(choices=[], value=None)
            return (
                empty_choices,              # existing
                empty_choices,              # overview
                gr.update(value=""),        # status
            )

    # ---------- WIRE BUTTONS ----------

    # _build returns 3 outputs
    # (the inputs list order must match _build's positional parameters)
    build_btn.click(
        _build,
        inputs=[
            cohort_name,
            file_uploader,
            existing_cohort_dropdown,
            action_radio,
            current_user_state,
            current_api_key_state,
        ],
        outputs=[build_status, cohort_name, file_uploader],
    )

    # _refresh returns 3 outputs:
    # existing_cohort_dropdown + cohort_list + build_status
    refresh_btn.click(
        _refresh,
        inputs=[current_user_state],
        outputs=[existing_cohort_dropdown, cohort_list, build_status],
    )

# -----------------------------
# ASK TAB (with prompt improve)
# -----------------------------
def build_ask_tab():
    """
    Build the 'Ask a Question' tab for the Gradio UI.

    Features:
      - Choose chat model (from model registry)
      - Improve Prompt
      - Original vs Improved prompt selector
      - Full RAG pipeline via answer_question_over_cohort()
      - Cohort refresh
      - Chat history saving (per user)

    Must be called inside an active gr.Blocks context after build_interface
    has assigned the module-level `current_user_state` and
    `current_api_key_state` gr.State objects that the click handlers read.
    """
    # Radio labels; _ask compares against the "improved" label literally.
    use_original_label = "Use original prompt"
    use_improved_label = "Use improved prompt"

    with gr.Tab("Ask a Question"):
        gr.Markdown("### Ask a Question Against a Cohort")

        # --- UI controls ---

        cohort_dropdown = gr.Dropdown(
            label="Select Cohort",
            choices=[],
            value=None,
            interactive=True,
        )

        model_dropdown = gr.Dropdown(
            label="Choose Chat Model",
            choices=list_chat_models(),  # read once, when the tab is built
            value=None,
            interactive=True,
        )

        question_box = gr.Textbox(
            label="Your Original Prompt / Question",
            placeholder="Ask something using the selected cohort...",
            lines=3,
        )

        improve_btn = gr.Button("✨ Improve Prompt")

        improved_box = gr.Textbox(
            label="Improved Prompt",
            placeholder="Improved version of your question will appear here...",
            lines=3,
        )

        prompt_choice = gr.Radio(
            label="Which prompt should be used for the answer?",
            choices=[use_original_label, use_improved_label],
            value=use_original_label,
        )

        ask_btn = gr.Button("Ask")
        ask_output = gr.Markdown()

        ask_refresh_btn = gr.Button("Refresh Cohorts")

        # ---------- PROMPT IMPROVER ----------

        def _improve_prompt(username, question, model_id, api_key):
            """
            Rewrite `question` with the Routing Brain.

            Returns:
              (improved_box value, prompt_choice update). On success the radio
              switches to the improved prompt. On any failure the message is
              shown in the improved box and the radio is forced back to the
              original prompt. Otherwise a radio still set to "improved" from
              an earlier success would make Ask send the error text as the query.
            """
            trace_log(
                f"_improve_prompt called user={username!r}, model_id={model_id!r}"
            )

            def _failed(message):
                # Surface the message, but make sure Ask falls back to the original.
                return message, gr.update(value=use_original_label)

            if not username:
                return _failed("⚠️ You must be logged in.")

            if not question or not question.strip():
                return _failed("⚠️ Enter a question to improve.")

            if not api_key or not api_key.strip():
                return _failed("⚠️ OpenAI API key required.")

            system_msg = (
                "You are a professional prompt engineer. "
                "Rewrite the user's question into a clearer, more actionable "
                "prompt suitable for retrieval-augmented generation. "
                "Do NOT invent new facts. Return only the improved prompt."
            )

            try:
                improved_text, _raw, used_model = call_chat_model(
                    api_key=api_key,
                    messages=[
                        {"role": "system", "content": system_msg},
                        {"role": "user", "content": question},
                    ],
                    # NOTE(review): improve_query() in STEP 8 routes with
                    # task_type="question_improve"; confirm which name the
                    # Routing Brain actually recognises.
                    task_type="prompt_improve",
                    user_pref=model_id,
                    context_size=len(question),
                )
                trace_log(f"_improve_prompt succeeded using model={used_model!r}")

                # Switch radio to "Use improved prompt"
                return improved_text, gr.update(value=use_improved_label)

            except Exception as e:
                trace_log(f"_improve_prompt ERROR: {e}")
                return _failed(f"❌ Error improving prompt: {e}")

        # ---------- ASK FUNCTION (RAG + history) ----------

        def _ask(
            username,
            cohort_name,
            model_id,
            original_prompt,
            improved_prompt,
            which_prompt,
            api_key,
        ):
            """
            Answer the chosen prompt over `cohort_name` and record it in history.

            The improved prompt is used only if the radio selects it and it is
            non-blank; otherwise the original prompt is used. A history-save
            failure is logged and does not hide the answer.

            Returns:
              Markdown for ask_output: the answer, or a warning/error message.
            """
            trace_log(
                f"_ask called user={username!r}, cohort={cohort_name!r}, "
                f"model_id={model_id!r}, which_prompt={which_prompt!r}"
            )

            if not username:
                return "⚠️ You must be logged in."

            if not api_key or not api_key.strip():
                return "⚠️ OpenAI API key required."

            if not cohort_name:
                return "⚠️ Please select a cohort."

            if not original_prompt or not original_prompt.strip():
                return "⚠️ Enter a question."

            # Decide which prompt to use
            if (
                which_prompt == use_improved_label
                and improved_prompt
                and improved_prompt.strip()
            ):
                final_query = improved_prompt.strip()
                trace_log("_ask using improved prompt")
            else:
                final_query = original_prompt.strip()
                trace_log("_ask using original prompt")

            trace_log(f"_ask effective_query={final_query!r}")

            # Run RAG pipeline
            try:
                answer_md, raw_answer, used_model = answer_question_over_cohort(
                    api_key=api_key,
                    username=username,
                    cohort_name=cohort_name,
                    question=final_query,
                    model_id=model_id,
                )
            except Exception as e:
                trace_log(f"_ask ERROR answer_question_over_cohort: {e}")
                return f"❌ Error while generating answer: {e}"

            # Save chat history (non-fatal on error)
            try:
                # IMPORTANT: these kwarg names must match save_chat_history()
                # NOTE(review): the History tab reads original_query /
                # improved_query / chat_model fields; confirm save_chat_history
                # maps question / model_used onto those columns.
                save_chat_history(
                    user_id=username,
                    cohort_name=cohort_name,
                    question=final_query,
                    answer=answer_md,
                    model_used=used_model,
                )
            except Exception as e:
                trace_log(f"_ask ERROR saving chat history: {e}")

            return answer_md

        # ---------- REFRESH COHORTS ----------

        def _ask_refresh(username):
            """
            Reload the cohorts visible to `username` and preselect the first.

            Returns a gr.update for cohort_dropdown. With no username, no
            cohorts, or an error (which is logged), the choices are emptied.
            """
            trace_log(f"ASK _refresh called username={username!r}")

            try:
                if not username:
                    return gr.update(choices=[], value=None)

                names = get_cohorts_for_user(username)
                trace_log(f"ASK _refresh found cohorts={names}")

                if names:
                    return gr.update(choices=names, value=names[0])
                else:
                    return gr.update(choices=[], value=None)

            except Exception as e:
                trace_log(f"ASK _refresh ERROR: {e}")
                return gr.update(choices=[], value=None)

        # ---------- WIRE CALLBACKS (these lines must stay AFTER the defs) ----------

        improve_btn.click(
            _improve_prompt,
            inputs=[current_user_state, question_box, model_dropdown, current_api_key_state],
            outputs=[improved_box, prompt_choice],
        )

        ask_btn.click(
            _ask,
            inputs=[
                current_user_state,
                cohort_dropdown,
                model_dropdown,
                question_box,
                improved_box,
                prompt_choice,
                current_api_key_state,
            ],
            outputs=[ask_output],
        )

        ask_refresh_btn.click(
            _ask_refresh,
            inputs=[current_user_state],
            outputs=[cohort_dropdown],
        )

# -----------------------------
# HISTORY TAB (simple view)
# -----------------------------
def build_history_tab():
    """
    Build the 'History' tab: a view-scope selector, a read-only table of
    recent Q&A rows, and a 'Refresh History' button that fills the table.

    Everyone sees their own rows (up to 50). An admin who selects
    "Org-wide (admin only)" sees all users' rows (up to 100).
    """
    with gr.Tab("History"):
        gr.Markdown("### Recent Q&A History")

        # Admins may widen the view to every user's history.
        scope_choice = gr.Radio(
            choices=["My history only", "Org-wide (admin only)"],
            value="My history only",
            label="View scope",
        )

        history_box = gr.Dataframe(
            headers=["Time (UTC)", "Cohort", "Question", "Model"],
            datatype=["str", "str", "str", "str"],
            interactive=False,
        )

        def _load_history(username: str, role: str, scope: str):
            """
            Return table rows [time, cohort, question, model] for history_box.

            Uses get_recent_history(). Non-admins (and admins in "My history
            only") are filtered to their own username. No username, or any
            error (which is logged), yields an empty table.
            """
            try:
                trace_log(
                    f"HISTORY _load_history called username={username!r}, "
                    f"role={role!r}, scope={scope!r}"
                )

                if not username:
                    return []

                org_wide = role == "admin" and scope == "Org-wide (admin only)"

                # Org-wide: drop the user filter and allow a larger window.
                history_rows = get_recent_history(
                    user_id=None if org_wide else username,
                    cohort_name=None,
                    limit=100 if org_wide else 50,
                )

                # Question column prefers original_query, then improved_query.
                return [
                    [
                        rec.get("created_at", "") or "",
                        rec.get("cohort_name", "") or "",
                        rec.get("original_query") or rec.get("improved_query") or "",
                        rec.get("chat_model", "") or "",
                    ]
                    for rec in history_rows
                ]

            except Exception as e:
                trace_log(f"HISTORY _load_history ERROR: {e}")
                return []

        # Load on button click to avoid auto-refresh issues
        load_btn = gr.Button("Refresh History")
        load_btn.click(
            _load_history,
            inputs=[current_user_state, current_role_state, scope_choice],
            outputs=[history_box],
        )


# -----------------------------
# ADMIN TAB (read-only)
# -----------------------------
def build_admin_tab():
    """
    Build the read-only 'Admin' tab: users, cohorts and the model registry,
    loaded on demand by the 'Refresh Admin Data' button.

    Data is returned only when the session role is "admin"; anyone else gets
    three empty tables. Each section loads independently, so an error in one
    (logged via trace_log) leaves the others populated.
    """
    with gr.Tab("Admin"):
        gr.Markdown("### Admin View")

        gr.Markdown("#### Registered Users")
        users_box = gr.Dataframe(
            headers=["Username", "Role"],
            interactive=False,
        )

        gr.Markdown("#### Cohorts")
        cohorts_box = gr.Dataframe(
            headers=["Cohort Name"],
            interactive=False,
        )

        gr.Markdown("#### Model Registry")
        models_box = gr.Dataframe(
            headers=["Provider", "Model ID", "Type", "Enabled", "Default"],
            interactive=False,
        )

        def _load_admin(username, role):
            """
            Return (users_table, cohorts_table, models_table) for the admin view.
            """
            trace_log(f"ADMIN _load_admin called username={username!r}, role={role!r}")
            if role != "admin":
                return [], [], []

            # Users: read through list_users(), whose rows are
            # (user_id, display_name, role), the same shape describe_users()
            # unpacks. Users are stored via upsert_user(user_id=...), so the old
            # raw `SELECT username, role FROM users` most likely hit a missing
            # column, which was swallowed and left this table empty.
            try:
                users_table = [
                    [user_id, user_role]
                    for user_id, _display_name, user_role in sorted(
                        list_users(), key=lambda row: str(row[0])
                    )
                ]
            except Exception as e:
                trace_log(f"ADMIN load users ERROR: {e}")
                users_table = []

            # Cohorts
            try:
                cohort_names = list_cohorts()
                cohorts_table = [[name] for name in cohort_names]
            except Exception as e:
                trace_log(f"ADMIN load cohorts ERROR: {e}")
                cohorts_table = []

            # Models
            try:
                models = load_model_registry()
                models_table = [
                    [
                        m["provider"],
                        m["model_id"],
                        m["type"],
                        "Yes" if m["enabled"] else "No",
                        "Yes" if m["is_default"] else "No",
                    ]
                    for m in models
                ]
            except Exception as e:
                trace_log(f"ADMIN load models ERROR: {e}")
                models_table = []

            return users_table, cohorts_table, models_table

        load_btn = gr.Button("Refresh Admin Data")
        load_btn.click(
            _load_admin,
            inputs=[current_user_state, current_role_state],
            outputs=[users_box, cohorts_box, models_box],
        )


# -----------------------------
# MAIN INTERFACE & STARTUP
# -----------------------------
def build_interface():
    """
    Build the full Gradio Blocks app for RAG MVP v15.

    Layout:
      - Login screen (username, password, OpenAI API key)
      - Main app with a "Setup & Cohorts" tab plus the Ask, History and
        Admin tabs created by their builder functions
      - Logout button that returns to the login screen

    Session state (``current_user_state``, ``current_role_state``,
    ``current_api_key_state``) is published as module-level globals. Tab
    builders reference these names when wiring their own events (e.g. the
    Admin tab's refresh button), so they must be assigned before any
    ``build_*_tab()`` call.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="RAG MVP v15") as demo:
        global current_user_state, current_role_state, current_api_key_state
        current_user_state = gr.State("")
        current_role_state = gr.State("")
        current_api_key_state = gr.State("")

        # -------- Login View --------
        with gr.Column(visible=True) as login_view:
            gr.Markdown("## RAG MVP v15 — Login")

            username_in = gr.Textbox(label="Username")
            password_in = gr.Textbox(label="Password", type="password")
            api_key_in = gr.Textbox(
                label="OpenAI API Key",
                placeholder="sk-...",
                type="password",
            )
            login_btn = gr.Button("Login")
            login_status = gr.Markdown()

        # -------- Main App View --------
        with gr.Column(visible=False) as app_view:
            # Top row: title + Logout button
            with gr.Row():
                gr.Markdown("## RAG MVP v15")
                logout_btn = gr.Button("Logout")

            # Second row: user label
            with gr.Row():
                gr.Markdown("Logged in as:")
                user_label = gr.Markdown()

            # Tabs
            with gr.Tabs():
                with gr.Tab("Setup & Cohorts"):
                    build_setup_tab()

                build_ask_tab()
                build_history_tab()
                build_admin_tab()

        # -------- Helpers --------
        def _update_user_label(username, role):
            """Markdown for the user label; empty when nobody is logged in."""
            if not username:
                return ""
            return f"**User:** {username} — **Role:** {role}"

        # -------- Login Logic --------
        def _login(username, password, api_key):
            """
            Authenticate the user and switch from the login view to the app.

            Returns a tuple matching ``login_btn``'s outputs:
            (login_status, login_view, app_view, current_user_state,
             current_role_state, current_api_key_state, user_label).
            """
            def _failure(message):
                # Stay on the login screen with all session state empty.
                return (
                    message,
                    gr.update(visible=True),
                    gr.update(visible=False),
                    "",
                    "",
                    "",
                    "",
                )

            try:
                user = authenticate_credentials(username, password)
            except Exception as e:
                trace_log(f"LOGIN ERROR auth: {e}")
                return _failure(f"❌ Login failed: {e}")

            if not user:
                trace_log(f"LOGIN failed for username={username!r}")
                return _failure("❌ Invalid username or password.")

            role = user.get("role", "user")
            trace_log(f"LOGIN success username={username!r}, role={role!r}")

            # Show app, hide login, set state. The API key is written here
            # explicitly: after a logout, the textbox may be unchanged, so its
            # .change event would not repopulate the state.
            return (
                f"✅ Logged in as **{username}** ({role})",
                gr.update(visible=False),            # hide login view
                gr.update(visible=True),             # show app view
                username,                            # current_user_state
                role,                                # current_role_state
                api_key or "",                       # current_api_key_state
                _update_user_label(username, role),  # user_label
            )

        # -------- Logout Logic --------
        def _logout():
            """
            Clear session state and credential inputs, then return to login.
            """
            trace_log("LOGOUT invoked")

            return (
                "",                         # login_status cleared
                gr.update(visible=True),    # show login view
                gr.update(visible=False),   # hide app view
                "",                         # current_user_state cleared
                "",                         # current_role_state cleared
                "",                         # current_api_key_state cleared
                "",                         # user_label cleared
                "",                         # password_in cleared
                "",                         # api_key_in cleared
            )

        # Wire login button
        login_btn.click(
            _login,
            inputs=[username_in, password_in, api_key_in],
            outputs=[
                login_status,
                login_view,
                app_view,
                current_user_state,
                current_role_state,
                current_api_key_state,
                user_label,
            ],
        )

        # Keep API key in state as the user edits it (not validated here)
        def _store_api_key(api_key):
            trace_log("API key updated in state")
            return api_key

        api_key_in.change(
            _store_api_key,
            inputs=[api_key_in],
            outputs=[current_api_key_state],
        )

        # Wire logout button
        logout_btn.click(
            _logout,
            inputs=[],
            outputs=[
                login_status,
                login_view,
                app_view,
                current_user_state,
                current_role_state,
                current_api_key_state,
                user_label,
                password_in,
                api_key_in,
            ],
        )

    return demo


# ---- Create & Launch the App ----
# Close any Gradio servers started by a previous run of this cell, so re-running
# it does not leave orphaned apps running and holding ports in the kernel.
gr.close_all()
demo = build_interface()
# In Colab: share=False is usually fine; set to True if you want a public link.
demo.launch(share=False)
