<a href="https://colab.research.google.com/github/Renlim61/MVP_Product001_2025_Tier120pbc/blob/version-history/Phase1_RAG_MVP_v14.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<details>
<summary><strong>📘 RAG MVP – Version 13 Overview (Click to Expand)</strong></summary>

### 🔑 Key Enhancements in v13

#### **1. Persistent Chat History (7-Day Retention)**
All user interactions are stored in SQLite with timestamps, user IDs, roles, original/improved queries, selected prompt type, RAG answer, and chat model.  
Automatic pruning removes entries older than 7 days to prevent growth.

#### **2. ICAM-Ready User & Role Scaffolding**
Users can now specify a **User ID** and **Role** (user/admin).  
This information is persisted, forming the basis for future role-based access control, audit logging, and permission policies.

#### **3. Robust Cohort-Based Document Ingestion**
- Reliable `filepath` uploads  
- PDF, DOCX, TXT support  
- Chunking, embedding, and FAISS indexing per document  
- Metadata and SQLite document tracking

Each document is independently searchable while belonging to a cohort.

#### **4. Key & Model Validation**
A dedicated button validates:
- API key correctness  
- Chat model availability  
- Embedding model availability  

Prevents misconfiguration before indexing or retrieval.

#### **5. Prompt Coach (Query Improvement)**
Automatically rewrites user queries to improve RAG retrieval quality.  
Users may choose **Original** or **Improved** versions before running RAG.

#### **6. Expanded Admin Dashboard**
Provides visibility into:
- Total docs  
- Total cohorts  
- Known users  
- Recent chat record count  
- User profiles and cohort inventory  

#### **7. Enhanced Chat History Viewer**
Filter by:
- User ID  
- Cohort  
- Max records  

Displays complete history with timestamps, queries, answers, and roles.

---

### 🎯 Summary
v13 elevates the MVP into a **persistent**, **auditable**, **multi-user**, and **cohort-aware** RAG system.  
It establishes the technical foundation for future ICAM controls, analytics, multi-model routing, and scalable enterprise features.

</details>



In [None]:
# ============================================================
# CELL 0 / STEP 0 – Install & Imports
# ============================================================
# Run this once at the top of the notebook (Colab style).

%pip install -q faiss-cpu openai gradio PyPDF2 python-docx

import os
import io
import time
import json
import shutil
import pickle
import sqlite3
from uuid import uuid4
from datetime import datetime, timedelta, timezone
from typing import List, Dict, Any, Tuple, Optional

import numpy as np
import faiss
import gradio as gr

from PyPDF2 import PdfReader
from docx import Document as DocxDocument

# Detect whether the notebook is running inside Google Colab.
try:
    import google.colab  # type: ignore
except ImportError:
    # The Colab package is not importable: treat this as a local run.
    IN_COLAB = False
else:
    IN_COLAB = True
    # The Drive helper is only imported when Colab itself is available.
    from google.colab import drive as colab_drive

# OpenAI client (v1 library)
from openai import OpenAI




In [None]:

# ============================================================
# CELL 1 / STEP 1 – Paths, Defaults, and OpenAI Client
# ============================================================

# Root folder for everything this notebook persists (indexes, DB, etc.).
if IN_COLAB:
    BASE_DIR = "/content/rag_mvp"
else:
    BASE_DIR = os.path.join(os.getcwd(), "rag_mvp")
os.makedirs(BASE_DIR, exist_ok=True)

# SQLite database file (saved documents, cohorts, users, chat history).
DB_PATH = os.path.join(BASE_DIR, "rag_documents.db")

# Folder holding the per-document FAISS indexes and their metadata files.
INDEX_DIR = os.path.join(BASE_DIR, "indexes")
os.makedirs(INDEX_DIR, exist_ok=True)

# Default chunking parameters used by chunk_text().
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200

# Fallback model names, used when the caller leaves a model field blank.
# Change these if you prefer different defaults.
EMBED_MODEL_DEFAULT = "text-embedding-3-small"
CHAT_MODEL_DEFAULT = "gpt-4.1-mini"

def build_openai_client(api_key: str) -> OpenAI:
    """
    Build a new OpenAI client from an API key.

    Surrounding whitespace (a common copy/paste artefact) is stripped before
    the key is handed to the client.

    Raises:
        ValueError: if ``api_key`` is empty, ``None``, or whitespace only.
    """
    cleaned_key = (api_key or "").strip()
    if not cleaned_key:
        raise ValueError("OpenAI API key is required.")
    return OpenAI(api_key=cleaned_key)

def resolve_models(chat_model: str, embed_model: str) -> Tuple[str, str]:
    """
    Resolve user selections or fall back to sensible defaults.

    Each name is stripped of surrounding whitespace. A blank or ``None``
    value is replaced by CHAT_MODEL_DEFAULT / EMBED_MODEL_DEFAULT.

    Returns:
        (chat_model_name, embed_model_name)
    """
    # `or ""` guards against None, which would otherwise crash on .strip().
    resolved_chat = (chat_model or "").strip() or CHAT_MODEL_DEFAULT
    resolved_embed = (embed_model or "").strip() or EMBED_MODEL_DEFAULT
    return resolved_chat, resolved_embed

def ensure_base_dirs():
    """Create BASE_DIR and INDEX_DIR if they are missing."""
    for directory in (BASE_DIR, INDEX_DIR):
        os.makedirs(directory, exist_ok=True)




In [None]:


# ============================================================
# CELL 1.5 / STEP 1.5 – Validate OpenAI Key & Models
# ============================================================

def validate_openai_key_and_models(api_key: str, chat_model: str, embed_model: str) -> str:
    """
    Smoke-test an API key together with the chosen chat and embedding models.

    Steps:
    - build a client from the key
    - send one minimal chat completion
    - send one minimal embedding request

    Never raises: every outcome, including any exception raised along the
    way, is reported as a human-readable status string.
    """
    if not api_key:
        return "‚ùå Please provide an OpenAI API key."

    try:
        client = build_openai_client(api_key)
        resolved_chat, resolved_embed = resolve_models(chat_model, embed_model)

        # Minimal chat request: only checks that the model answers at all.
        probe_messages = [
            {"role": "system", "content": "Model availability test."},
            {"role": "user", "content": "Respond with 'OK' only."},
        ]
        client.chat.completions.create(
            model=resolved_chat,
            messages=probe_messages,
            max_tokens=2,
            temperature=0.0,
        )

        # Minimal embedding request for the embedding model.
        client.embeddings.create(model=resolved_embed, input=["test"])
    except Exception as e:
        return f"‚ùå Error validating key/models: {e}"

    return f"‚úÖ OpenAI key valid. Chat model: `{resolved_chat}`, Embed model: `{resolved_embed}`"



In [None]:

# ============================================================
# CELL 2 / STEP 2 – Document Loading & Chunking
# ============================================================

def load_pdf(file_bytes: bytes) -> str:
    """
    Extract the text of every page of a PDF, joined with newlines.

    A page whose extraction raises contributes an empty string instead of
    aborting the whole document.
    """

    def _page_text(page) -> str:
        # Best effort per page: extraction failures yield no text.
        try:
            return page.extract_text() or ""
        except Exception:
            return ""

    pdf = PdfReader(io.BytesIO(file_bytes))
    return "\n".join([_page_text(page) for page in pdf.pages])

def load_docx(file_bytes: bytes) -> str:
    """Return the text of all paragraphs of a DOCX payload, one per line."""
    document = DocxDocument(io.BytesIO(file_bytes))
    paragraph_texts = [paragraph.text for paragraph in document.paragraphs]
    return "\n".join(paragraph_texts)

def load_txt(file_bytes: bytes, encoding: str = "utf-8") -> str:
    """Decode raw bytes to text, silently dropping undecodable bytes."""
    decoded = str(file_bytes, encoding, errors="ignore")
    return decoded

def load_file_to_text(file_obj) -> Tuple[str, str]:
    """
    Read an uploaded file and return its text.

    ``file_obj`` may be:
    - a string filepath (what gr.File(type="filepath") delivers), or
    - an object exposing the path through a ``.name`` attribute.

    The loader is chosen from the lower-cased file name suffix
    (.pdf, .docx, .txt).

    Returns:
      (text_content, original_filename)

    Raises:
      ValueError: if no path can be obtained or the suffix is unsupported.
    """
    path = file_obj if isinstance(file_obj, str) else getattr(file_obj, "name", None)
    if path is None:
        raise ValueError("Unsupported file object from uploader.")
    name = os.path.basename(path)

    with open(path, "rb") as handle:
        data = handle.read()

    # Dispatch on the suffix, case-insensitively.
    suffix_source = name.lower()
    if suffix_source.endswith(".pdf"):
        return load_pdf(data), name
    if suffix_source.endswith(".docx"):
        return load_docx(data), name
    if suffix_source.endswith(".txt"):
        return load_txt(data), name
    raise ValueError(f"Unsupported file type for {name}")


def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """
    Simple sliding-window chunking over whitespace-separated words.

    Consecutive chunks hold up to ``chunk_size`` words and share ``overlap``
    words. Words inside a chunk are re-joined with single spaces, so the
    original line breaks and spacing are not preserved.

    Args:
        text: the text to split.
        chunk_size: maximum number of words per chunk (must be > 0).
        overlap: words shared by neighbouring chunks
            (must satisfy 0 <= overlap < chunk_size).

    Returns:
        The list of chunks; empty when ``text`` contains no words.

    Raises:
        ValueError: if the window parameters cannot make progress.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")
    if overlap < 0 or overlap >= chunk_size:
        # overlap >= chunk_size would make the window stall or move backwards.
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size.")

    # str.split() with no argument splits on any whitespace run, including
    # "\r" and "\n", so no separate newline normalisation is needed.
    tokens = text.split()
    step = chunk_size - overlap

    chunks: List[str] = []
    for start in range(0, len(tokens), step):
        end = start + chunk_size
        chunks.append(" ".join(tokens[start:end]))
        if end >= len(tokens):
            # This window already reached the last word; a further window
            # would only repeat words contained in this chunk.
            break
    return chunks




In [None]:

# ============================================================
# CELL 3 / STEP 3 – Embedding & FAISS Index Helpers
# ============================================================

def embed_texts(
    api_key: str,
    embed_model: str,
    texts: List[str],
    batch_size: int = 32,
) -> np.ndarray:
    """
    Embed ``texts`` with the OpenAI embeddings endpoint.

    The texts are sent in slices of ``batch_size`` and the returned
    embeddings are collected in response order.

    Returns a float32 ndarray built from the collected embeddings
    (shape (N, D) for N embedded texts).
    """
    client = build_openai_client(api_key)
    # Only the embedding model matters here; the chat slot is left blank.
    _, resolved_embed = resolve_models("", embed_model)

    collected: List[List[float]] = []
    for offset in range(0, len(texts), batch_size):
        response = client.embeddings.create(
            model=resolved_embed,
            input=texts[offset : offset + batch_size],
        )
        collected.extend(item.embedding for item in response.data)

    return np.array(collected, dtype="float32")

def build_faiss_index(vectors: np.ndarray) -> faiss.IndexFlatIP:
    """
    Build a flat inner-product FAISS index over row-normalised vectors.

    Each row is divided by its L2 norm (plus a tiny epsilon that avoids a
    division by zero) before being added to the index.
    """
    row_norms = np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-10
    unit_rows = vectors / row_norms
    index = faiss.IndexFlatIP(unit_rows.shape[1])
    index.add(unit_rows)
    return index

def save_index(index: faiss.IndexFlatIP, index_id: str):
    """Write ``index`` to ``<INDEX_DIR>/<index_id>.faiss``."""
    target = os.path.join(INDEX_DIR, f"{index_id}.faiss")
    faiss.write_index(index, target)

def load_index(index_id: str) -> faiss.IndexFlatIP:
    """
    Read the FAISS index stored as ``<INDEX_DIR>/<index_id>.faiss``.

    Raises:
        FileNotFoundError: if the index file does not exist.
    """
    index_path = os.path.join(INDEX_DIR, f"{index_id}.faiss")
    if os.path.exists(index_path):
        return faiss.read_index(index_path)
    raise FileNotFoundError(f"Index file not found: {index_path}")

def save_metadata(index_id: str, meta: Dict[str, Any]):
    """Pickle ``meta`` to ``<INDEX_DIR>/<index_id>.pkl``."""
    target = os.path.join(INDEX_DIR, f"{index_id}.pkl")
    with open(target, "wb") as handle:
        pickle.dump(meta, handle)

def load_metadata(index_id: str) -> Dict[str, Any]:
    """
    Unpickle the metadata stored as ``<INDEX_DIR>/<index_id>.pkl``.

    Raises:
        FileNotFoundError: if the metadata file does not exist.
    """
    meta_path = os.path.join(INDEX_DIR, f"{index_id}.pkl")
    if os.path.exists(meta_path):
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only safe while INDEX_DIR holds files written by save_metadata().
        with open(meta_path, "rb") as handle:
            return pickle.load(handle)
    raise FileNotFoundError(f"Metadata file not found: {meta_path}")



In [None]:
# ============================================================
# CELL 4 / STEP 4 – SQLite Persistence for Documents & Cohorts
# ============================================================

def get_db_conn():
    """
    Open a new SQLite connection to DB_PATH.

    The base directories and the database's parent folder are created
    first if needed. The caller is responsible for closing the connection.
    """
    ensure_base_dirs()
    db_folder = os.path.dirname(DB_PATH)
    os.makedirs(db_folder, exist_ok=True)
    connection = sqlite3.connect(DB_PATH)
    return connection

def ensure_docs_table():
    """
    Create the ``documents`` table if it does not exist yet.

    One row per indexed document: id, doc_name, cohort_name, index_id,
    n_chunks, embed_model and created_at.
    """
    conn = get_db_conn()
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS documents (
                id              TEXT PRIMARY KEY,
                doc_name        TEXT NOT NULL,
                cohort_name     TEXT NOT NULL,
                index_id        TEXT NOT NULL,
                n_chunks        INTEGER NOT NULL,
                embed_model     TEXT NOT NULL,
                created_at      TEXT NOT NULL
            )
            """
        )
        conn.commit()
    finally:
        # Always release the connection, even when the statement fails.
        conn.close()

def ensure_cohort_table():
    """
    Create the ``cohort_docs`` table if it does not exist yet.

    It maps cohort_name -> doc_name (for easier listing).
    """
    conn = get_db_conn()
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS cohort_docs (
                cohort_name TEXT NOT NULL,
                doc_name    TEXT NOT NULL
            )
            """
        )
        conn.commit()
    finally:
        # Always release the connection, even when the statement fails.
        conn.close()

def list_cohorts() -> List[str]:
    """Return every distinct cohort name in ``cohort_docs``, sorted ascending."""
    ensure_cohort_table()
    conn = get_db_conn()
    try:
        rows = conn.execute(
            "SELECT DISTINCT cohort_name FROM cohort_docs ORDER BY cohort_name ASC"
        ).fetchall()
    finally:
        # Close the connection even if the query raises.
        conn.close()
    return [row[0] for row in rows]

def list_docs_in_cohort(cohort_name: str) -> List[str]:
    """
    Return the doc_name of every ``cohort_docs`` row for ``cohort_name``,
    sorted ascending. A document registered more than once appears once per
    row.
    """
    ensure_cohort_table()
    conn = get_db_conn()
    try:
        rows = conn.execute(
            "SELECT doc_name FROM cohort_docs WHERE cohort_name = ? ORDER BY doc_name ASC",
            (cohort_name,),
        ).fetchall()
    finally:
        # Close the connection even if the query raises.
        conn.close()
    return [row[0] for row in rows]

def add_docs_to_cohort(cohort_name: str, doc_names: List[str]):
    """
    Insert one ``cohort_docs`` row per entry of ``doc_names``.

    All rows are written in a single transaction; nothing is committed if
    any insert fails.
    """
    ensure_cohort_table()
    conn = get_db_conn()
    try:
        conn.executemany(
            "INSERT INTO cohort_docs (cohort_name, doc_name) VALUES (?, ?)",
            [(cohort_name, doc_name) for doc_name in doc_names],
        )
        conn.commit()
    finally:
        # Closing without a commit discards a partially applied batch.
        conn.close()

def rename_cohort(old_name: str, new_name: str):
    """
    Rename a cohort in every table that records it.

    Updates ``cohort_docs`` and ``documents`` and moves the ownership /
    sharing row in ``cohort_meta`` to the new name, so a renamed cohort
    keeps its owner and sharing flag. All changes are committed together.
    """
    # Every table touched below must exist, otherwise the UPDATE would fail.
    ensure_docs_table()
    ensure_cohort_table()
    ensure_cohort_meta_table()

    conn = get_db_conn()
    try:
        cur = conn.cursor()
        cur.execute(
            "UPDATE cohort_docs SET cohort_name = ? WHERE cohort_name = ?",
            (new_name, old_name),
        )
        cur.execute(
            "UPDATE documents SET cohort_name = ? WHERE cohort_name = ?",
            (new_name, old_name),
        )
        if new_name != old_name:
            # cohort_name is the primary key of cohort_meta. If new_name
            # already has a row (merge into an existing cohort) the move is
            # ignored and that existing row wins ...
            cur.execute(
                "UPDATE OR IGNORE cohort_meta SET cohort_name = ? WHERE cohort_name = ?",
                (new_name, old_name),
            )
            # ... and in that case the now-stale row of the old name is
            # removed (a no-op when the move above succeeded).
            cur.execute(
                "DELETE FROM cohort_meta WHERE cohort_name = ?",
                (old_name,),
            )
        conn.commit()
    finally:
        # Closing without a commit discards a half-applied rename.
        conn.close()

def delete_cohort(cohort_name: str, reassign_to: Optional[str] = None) -> str:
    """
    Delete a cohort. If reassign_to is provided, documents move there.
    Otherwise, documents are deleted (and their indexes removed).
    NOTE: This is intentionally explicit to avoid orphaned docs.

    The cohort's ``cohort_meta`` row is removed as well so no stale
    ownership record outlives the cohort. Index and metadata files are
    deleted only after the database changes have been committed.

    Returns:
        A human-readable status message.
    """
    ensure_docs_table()
    ensure_cohort_table()
    ensure_cohort_meta_table()

    files_to_remove: List[str] = []
    conn = get_db_conn()
    try:
        cur = conn.cursor()

        if reassign_to:
            # Just move docs
            cur.execute(
                "UPDATE documents SET cohort_name = ? WHERE cohort_name = ?",
                (reassign_to, cohort_name),
            )
            cur.execute(
                "UPDATE cohort_docs SET cohort_name = ? WHERE cohort_name = ?",
                (reassign_to, cohort_name),
            )
            msg = f"‚úÖ Cohort '{cohort_name}' renamed/reassigned to '{reassign_to}'. No indexes deleted."
        else:
            # Remember which index/metadata files belong to the cohort ...
            cur.execute(
                "SELECT index_id FROM documents WHERE cohort_name = ?",
                (cohort_name,),
            )
            for (index_id,) in cur.fetchall():
                files_to_remove.append(os.path.join(INDEX_DIR, f"{index_id}.faiss"))
                files_to_remove.append(os.path.join(INDEX_DIR, f"{index_id}.pkl"))

            # ... then drop the database rows.
            cur.execute("DELETE FROM documents WHERE cohort_name = ?", (cohort_name,))
            cur.execute("DELETE FROM cohort_docs WHERE cohort_name = ?", (cohort_name,))
            msg = f"‚úÖ Cohort '{cohort_name}' and its documents/indexes were deleted."

        if reassign_to != cohort_name:
            # The cohort no longer exists under this name; drop its metadata.
            # (Skipped when a cohort is "reassigned" to itself.)
            cur.execute("DELETE FROM cohort_meta WHERE cohort_name = ?", (cohort_name,))

        conn.commit()
    finally:
        # Closing without a commit discards a half-applied deletion.
        conn.close()

    # Remove files last, so a failed commit never leaves database rows
    # pointing at index files that are already gone.
    for path in files_to_remove:
        if os.path.exists(path):
            os.remove(path)

    return msg

def register_document(
    doc_name: str,
    cohort_name: str,
    index_id: str,
    n_chunks: int,
    embed_model: str,
):
    """
    Record a newly indexed document.

    Inserts one row into ``documents`` (with a fresh UUID and a UTC ISO
    timestamp) and one row into ``cohort_docs``, in a single transaction.

    Returns:
        The generated document id.
    """
    ensure_docs_table()
    ensure_cohort_table()

    doc_id = str(uuid4())
    created_at = datetime.now(timezone.utc).isoformat()

    conn = get_db_conn()
    try:
        cur = conn.cursor()
        cur.execute(
            """
            INSERT INTO documents (id, doc_name, cohort_name, index_id, n_chunks,
                                   embed_model, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            (doc_id, doc_name, cohort_name, index_id, n_chunks, embed_model, created_at),
        )
        cur.execute(
            "INSERT INTO cohort_docs (cohort_name, doc_name) VALUES (?, ?)",
            (cohort_name, doc_name),
        )
        conn.commit()
    finally:
        # Closing without a commit discards a half-written registration.
        conn.close()
    return doc_id

def get_doc_index_id(doc_name: str, cohort_name: str) -> Optional[str]:
    """
    Look up the index_id of a document by name within a cohort.

    Returns the index_id of the first matching ``documents`` row, or None
    when there is no match.
    """
    ensure_docs_table()
    conn = get_db_conn()
    try:
        row = conn.execute(
            """
            SELECT index_id
            FROM documents
            WHERE doc_name = ? AND cohort_name = ?
            """,
            (doc_name, cohort_name),
        ).fetchone()
    finally:
        # Close the connection even if the query raises.
        conn.close()
    return row[0] if row else None

def list_all_documents() -> List[Tuple[str, str, str]]:
    """
    Return list of (doc_name, cohort_name, created_at), newest first.
    """
    ensure_docs_table()
    conn = get_db_conn()
    try:
        rows = conn.execute(
            "SELECT doc_name, cohort_name, created_at FROM documents ORDER BY created_at DESC"
        ).fetchall()
    finally:
        # Close the connection even if the query raises.
        conn.close()
    return rows



In [None]:
# ============================================================
# CELL 4.5 / STEP 4.5 – Users & Chat History (7-Day Retention)
# ============================================================

from dataclasses import dataclass

@dataclass

# ============================================================
# USER IDENTITY MODEL (used by audit, cohorts, admin, etc.)
# ============================================================
class SessionUser:
    username: str | None = None
    role: str = "anonymous"   # "anonymous", "user", "admin"

    def __init__(self, username=None, role="anonymous"):
        self.username = username
        self.role = role

    @property
    def is_authenticated(self) -> bool:
        return self.username is not None

    @property
    def is_admin(self) -> bool:
        return self.role == "admin"


# Simple in-memory user store for MVP (replace with ICAM/SSO later).
# Maps username -> {"password", "role"}.
# NOTE(review): these are hard-coded plaintext demo credentials; they must
# not ship beyond a local demo. Replace with hashed secrets or real SSO.
USERS = {
    "admin": {"password": "admin123", "role": "admin"},
    "demo":  {"password": "demo123",  "role": "user"},
}


def ensure_user_table():
    """
    Create the ``users`` table if it does not exist yet.

    Columns: user_id (primary key), display_name, role, created_at.
    """
    conn = get_db_conn()
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS users (
                user_id     TEXT PRIMARY KEY,
                display_name TEXT,
                role        TEXT,
                created_at  TEXT NOT NULL
            )
            """
        )
        conn.commit()
    finally:
        # Always release the connection, even when the statement fails.
        conn.close()

def ensure_chat_history_table():
    """
    Create the ``chat_history`` table if it does not exist yet.

    One row per question/answer exchange, keyed by an auto-incrementing id
    and stamped with created_at.
    """
    conn = get_db_conn()
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS chat_history (
                id              INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id         TEXT,
                role            TEXT,
                cohort_name     TEXT,
                original_query  TEXT,
                improved_query  TEXT,
                which_prompt    TEXT,
                answer          TEXT,
                chat_model      TEXT,
                created_at      TEXT NOT NULL
            )
            """
        )
        conn.commit()
    finally:
        # Always release the connection, even when the statement fails.
        conn.close()

def upsert_user(user_id: str, role: str, display_name: Optional[str] = None):
    """
    Basic user scaffolding for future ICAM:
    - Inserts new user or updates role/display_name.

    On an existing user, a NULL ``display_name`` or ``role`` leaves the
    stored value untouched; ``created_at`` is never changed. Does nothing
    when ``user_id`` is empty.
    """
    if not user_id:
        return
    ensure_user_table()

    now = datetime.now(timezone.utc).isoformat()
    conn = get_db_conn()
    try:
        # `excluded.*` refers to the values of the attempted INSERT, so the
        # parameters do not have to be bound twice.
        conn.execute(
            """
            INSERT INTO users (user_id, display_name, role, created_at)
            VALUES (?, ?, ?, ?)
            ON CONFLICT(user_id) DO UPDATE SET
                display_name = COALESCE(excluded.display_name, users.display_name),
                role = COALESCE(excluded.role, users.role)
            """,
            (user_id, display_name, role, now),
        )
        conn.commit()
    finally:
        # Always release the connection, even when the statement fails.
        conn.close()

def prune_chat_history(days: int = 7):
    """
    Delete chat entries older than `days` days.

    The cutoff is "now (UTC) minus `days`", compared as an ISO-8601 string
    against the stored ``created_at`` values.
    """
    ensure_chat_history_table()
    cutoff_iso = (datetime.now(timezone.utc) - timedelta(days=days)).isoformat()

    conn = get_db_conn()
    try:
        conn.execute(
            "DELETE FROM chat_history WHERE created_at < ?",
            (cutoff_iso,),
        )
        conn.commit()
    finally:
        # Always release the connection, even when the statement fails.
        conn.close()

def save_chat_interaction(
    user_id: str,
    role: str,
    cohort_name: str,
    original_query: str,
    improved_query: str,
    which_prompt: str,
    answer: str,
    chat_model: str,
):
    """
    Save a single Q&A to chat_history and enforce 7-day retention.

    Entries older than 7 days are pruned before the new row is inserted.
    Empty ``user_id`` / ``role`` / ``cohort_name`` values are stored as NULL.
    """
    ensure_chat_history_table()
    prune_chat_history(days=7)

    record = (
        user_id or None,
        role or None,
        cohort_name or None,
        original_query,
        improved_query,
        which_prompt,
        answer,
        chat_model,
        datetime.now(timezone.utc).isoformat(),
    )

    conn = get_db_conn()
    try:
        conn.execute(
            """
            INSERT INTO chat_history (
                user_id, role, cohort_name, original_query, improved_query,
                which_prompt, answer, chat_model, created_at
            )
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            record,
        )
        conn.commit()
    finally:
        # Always release the connection, even when the insert fails.
        conn.close()

def get_recent_history(
    user_id: Optional[str] = None,
    cohort_name: Optional[str] = None,
    limit: int = 50,
) -> List[Dict[str, Any]]:
    """
    Return recent chat history, optionally filtered by user_id / cohort_name.

    Args:
        user_id: when non-empty, only rows of this user are returned.
        cohort_name: when non-empty, only rows of this cohort are returned.
        limit: maximum number of rows; coerced to int because UI widgets
            may deliver whole numbers as floats.

    Returns:
        A list of dicts (newest first), one per row, keyed by column name.
    """
    ensure_chat_history_table()

    columns = (
        "user_id",
        "role",
        "cohort_name",
        "original_query",
        "improved_query",
        "which_prompt",
        "answer",
        "chat_model",
        "created_at",
    )
    query = "SELECT " + ", ".join(columns) + " FROM chat_history"
    params: List[Any] = []
    conditions: List[str] = []

    if user_id:
        conditions.append("user_id = ?")
        params.append(user_id)
    if cohort_name:
        conditions.append("cohort_name = ?")
        params.append(cohort_name)

    if conditions:
        query += " WHERE " + " AND ".join(conditions)
    query += " ORDER BY created_at DESC LIMIT ?"
    params.append(int(limit))

    conn = get_db_conn()
    try:
        rows = conn.execute(query, tuple(params)).fetchall()
    finally:
        # Close the connection even if the query raises.
        conn.close()

    # Pair each selected column name with the value at the same position.
    return [dict(zip(columns, row)) for row in rows]

def list_users() -> List[Tuple[str, str, str]]:
    """
    Simple user listing for admin view.
    Returns list of (user_id, display_name, role), newest user first.
    """
    ensure_user_table()
    conn = get_db_conn()
    try:
        rows = conn.execute(
            "SELECT user_id, display_name, role FROM users ORDER BY created_at DESC"
        ).fetchall()
    finally:
        # Close the connection even if the query raises.
        conn.close()
    return rows



In [None]:
# ============================================================
# CELL 4.6 / STEP 4.6 – Audit Log (v14)
# ============================================================
from datetime import datetime, timezone

def ensure_audit_table():
    """
    Create an audit_log table if it doesn't already exist.

    Columns: id (auto-increment), ts, username, role, action, details.
    """
    conn = get_db_conn()
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS audit_log (
                id        INTEGER PRIMARY KEY AUTOINCREMENT,
                ts        TEXT NOT NULL,
                username  TEXT,
                role      TEXT,
                action    TEXT NOT NULL,
                details   TEXT
            )
            """
        )
        conn.commit()
    finally:
        # Always release the connection, even when the statement fails.
        conn.close()

def log_audit(username: str, role: str, action: str, details: str = ""):
    """
    Insert a row into the audit_log table.
    - ts: UTC ISO timestamp
    - username / role: may be None/empty for anonymous (stored as "")
    - action: short code, e.g. 'login', 'ask', 'admin_refresh', 'delete_cohort'
    - details: freeform string with context (None is stored as "")
    """
    ensure_audit_table()

    # Use timezone-aware UTC
    now = datetime.now(timezone.utc).isoformat()

    conn = get_db_conn()
    try:
        conn.execute(
            """
            INSERT INTO audit_log (ts, username, role, action, details)
            VALUES (?, ?, ?, ?, ?)
            """,
            (now, username or "", role or "", action, details or ""),
        )
        conn.commit()
    finally:
        # Always release the connection, even when the insert fails.
        conn.close()



In [None]:
# ============================================================
# CELL 4.7 / STEP 4.7 – Cohort Ownership & Sharing (v14)
# ============================================================

def ensure_cohort_meta_table():
    """
    Create the ``cohort_meta`` table if it does not exist yet.

    Metadata for cohorts:
      - owner_user_id: who created/owns the cohort
      - is_shared: 0 = private to owner, 1 = shared to all users
      - created_ts: UTC timestamp
    """
    conn = get_db_conn()
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS cohort_meta (
                cohort_name    TEXT PRIMARY KEY,
                owner_user_id  TEXT,
                is_shared      INTEGER DEFAULT 0,
                created_ts     TEXT
            )
            """
        )
        conn.commit()
    finally:
        # Always release the connection, even when the statement fails.
        conn.close()


def set_cohort_owner(cohort_name: str, user: SessionUser | None):
    """
    Register or update the owner of a cohort.
    For now:
      - owner_user_id = user.username (or 'anonymous' if not logged in)
      - is_shared = 0 by default (private)
    Later we can extend this to support an 'allow share' flag.

    When the cohort already has a metadata row, only owner_user_id is
    overwritten; the existing is_shared and created_ts are kept.
    Does nothing when ``cohort_name`` is empty.
    """
    if not cohort_name:
        return

    ensure_cohort_meta_table()

    owner = user.username if (user and user.username) else "anonymous"
    now = datetime.now(timezone.utc).isoformat()

    conn = get_db_conn()
    try:
        conn.execute(
            """
            INSERT INTO cohort_meta (cohort_name, owner_user_id, is_shared, created_ts)
            VALUES (?, ?, 0, ?)
            ON CONFLICT(cohort_name) DO UPDATE SET
                owner_user_id = excluded.owner_user_id
            """,
            (cohort_name, owner, now),
        )
        conn.commit()
    finally:
        # Always release the connection, even when the statement fails.
        conn.close()

def cohort_exists(cohort_name: str) -> bool:
    """
    Returns True if a cohort with this name already exists in cohort_docs.
    This is global (not per-user) to avoid confusing duplicate names.
    An empty name never exists.
    """
    if not cohort_name:
        return False

    ensure_cohort_table()
    conn = get_db_conn()
    try:
        row = conn.execute(
            "SELECT 1 FROM cohort_docs WHERE cohort_name = ? LIMIT 1",
            (cohort_name,),
        ).fetchone()
    finally:
        # Close the connection even if the query raises.
        conn.close()
    return row is not None


def list_cohorts_for_user(user: SessionUser | None) -> list[str]:
    """
    Return list of cohort names visible to the given user.

    Rules (v14):
      - Admins: all cohorts (global)
      - Non-admin:
          * Cohorts where they are owner (cohort_meta.owner_user_id)
          * Cohorts marked is_shared = 1 (future 'allow share' flag)
          * Cohorts that have no cohort_meta row at all
      - Anonymous: shared cohorts (is_shared = 1) and cohorts without a
        cohort_meta row

    An empty result is returned as an empty list. It is NOT replaced by the
    global cohort list, because that would expose other users' private
    cohorts to anyone who has none of their own.

    If anything goes wrong with the metadata logic, fall back to global list_cohorts().
    """
    ensure_cohort_table()
    ensure_cohort_meta_table()

    conn = get_db_conn()
    try:
        if user and getattr(user, "is_admin", False):
            # Admin: all distinct cohorts
            sql = """
                SELECT DISTINCT cohort_name
                FROM cohort_docs
                ORDER BY cohort_name
                """
            params: tuple = ()
        else:
            username = user.username if (user and user.username) else None
            if username:
                # Non-admin logged-in
                sql = """
                    SELECT DISTINCT cd.cohort_name
                    FROM cohort_docs cd
                    LEFT JOIN cohort_meta cm
                      ON cd.cohort_name = cm.cohort_name
                    WHERE cm.owner_user_id = ?
                       OR cm.is_shared = 1
                       OR cm.cohort_name IS NULL   -- safety: cohorts without meta still appear
                    ORDER BY cd.cohort_name
                    """
                params = (username,)
            else:
                # Anonymous: only explicitly shared cohorts (or cohorts without meta as fallback)
                sql = """
                    SELECT DISTINCT cd.cohort_name
                    FROM cohort_docs cd
                    LEFT JOIN cohort_meta cm
                      ON cd.cohort_name = cm.cohort_name
                    WHERE cm.is_shared = 1
                       OR cm.cohort_name IS NULL
                    ORDER BY cd.cohort_name
                    """
                params = ()

        rows = conn.execute(sql, params).fetchall()
        return [row[0] for row in rows]

    except Exception as e:
        # Documented best-effort fallback.
        # NOTE(review): this still exposes every cohort when the lookup
        # itself fails; confirm that is acceptable.
        print("DEBUG list_cohorts_for_user error:", e)
        return list_cohorts()
    finally:
        # Runs on every path, so the connection is always released.
        conn.close()



In [None]:
# ============================================================
# CELL 5 / STEP 5 – RAG Retrieval Over a Cohort
# ============================================================

def build_context_from_index(
    api_key: str,
    chat_model: str,
    embed_model: str,
    cohort_name: str,
    query: str,
    top_k: int = 5,
) -> Tuple[str, List[Tuple[str, int, float]]]:
    """
    Given a cohort and query:
    - Load all documents for that cohort
    - For each doc's index, perform similarity
    - Aggregate top_k results across docs

    Documents whose index or metadata cannot be loaded, or whose index
    dimension differs from the query embedding (for example because they
    were indexed with another embedding model), are skipped with a debug
    message instead of failing the whole query.

    Returns:
      - concatenated context string
      - list of (doc_name, rank, score) for citations
      ("" and [] when the cohort has no documents, nothing matched, or
      ``top_k`` is not positive.)
    """
    ensure_docs_table()
    # Only the embedding model is needed for retrieval.
    _, resolved_embed = resolve_models(chat_model, embed_model)
    client = build_openai_client(api_key)

    conn = get_db_conn()
    try:
        rows = conn.execute(
            """
            SELECT doc_name, index_id
            FROM documents
            WHERE cohort_name = ?
            """,
            (cohort_name,),
        ).fetchall()
    finally:
        # Close the connection even if the query raises.
        conn.close()

    # UI widgets may deliver whole numbers as floats; the search needs an int.
    top_k = int(top_k)
    if not rows or top_k <= 0:
        return "", []

    # Embed query once
    q_embed_resp = client.embeddings.create(model=resolved_embed, input=[query])
    q_vec = np.array(q_embed_resp.data[0].embedding, dtype="float32")
    q_vec = q_vec / (np.linalg.norm(q_vec) + 1e-10)
    # Single-row float32 matrix, the layout the index search expects.
    query_matrix = np.ascontiguousarray(q_vec[np.newaxis, :], dtype="float32")

    all_hits: List[Tuple[str, int, float, str]] = []  # (doc_name, idx, score, text)

    for doc_name, index_id in rows:
        try:
            index = load_index(index_id)
            meta = load_metadata(index_id)  # {"chunks": [...], "doc_name": ..., ...}
        except Exception as e:
            # Best effort: one unreadable document must not break retrieval.
            print(f"DEBUG build_context_from_index: skipping '{doc_name}': {e}")
            continue

        if index.d != query_matrix.shape[1]:
            # Searching with a vector of the wrong dimension would fail.
            print(
                f"DEBUG build_context_from_index: skipping '{doc_name}': "
                f"index dimension {index.d} != query dimension {query_matrix.shape[1]}"
            )
            continue

        chunks = meta.get("chunks", [])
        scores, idxs = index.search(query_matrix, top_k)

        for score, idx in zip(scores[0], idxs[0]):
            # Skip negative ids (no result in that slot) and ids that have
            # no matching stored chunk.
            if idx < 0 or idx >= len(chunks):
                continue
            all_hits.append((doc_name, int(idx), float(score), chunks[int(idx)]))

    if not all_hits:
        return "", []

    # Sort across docs by score desc
    all_hits.sort(key=lambda hit: hit[2], reverse=True)
    top_hits = all_hits[:top_k]

    context_parts = []
    citations = []
    for rank, (doc_name, idx, score, text) in enumerate(top_hits, start=1):
        header = f"[{rank}] From {doc_name} (chunk #{idx}, score={score:.3f})"
        context_parts.append(header + "\n" + text)
        citations.append((doc_name, rank, score))

    context = "\n\n".join(context_parts)
    return context, citations

def answer_with_rag(
    api_key: str,
    chat_model: str,
    embed_model: str,
    cohort_name: str,
    query: str,
    system_prompt: str = "",
) -> Tuple[str, str]:
    """
    Perform RAG retrieval and return (answer_markdown, raw_answer_text).

    The retrieved context and the question are sent to the chat model.
    ``answer_markdown`` is the answer followed by a "Cited sources" list;
    ``raw_answer_text`` is the answer alone. When no context is found, a
    fixed notice is returned together with an empty raw answer and the chat
    model is not called.
    """
    resolved_chat, resolved_embed = resolve_models(chat_model, embed_model)
    client = build_openai_client(api_key)

    context, citations = build_context_from_index(
        api_key, resolved_chat, resolved_embed, cohort_name, query
    )

    if not context:
        return (
            "I could not find any context for this query in the selected cohort.",
            "",
        )

    if not system_prompt:
        system_prompt = (
            "You are a helpful assistant answering questions based on the provided context. "
            "If the answer cannot be found in the context, say you do not know."
        )

    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": (
                "Use ONLY the context below to answer the question.\n\n"
                "=== CONTEXT START ===\n"
                f"{context}\n"
                "=== CONTEXT END ===\n\n"
                f"QUESTION: {query}"
            ),
        },
    ]

    resp = client.chat.completions.create(
        model=resolved_chat,
        messages=messages,
        temperature=0.2,
        max_tokens=800,
    )
    # The message content can be None; treat that as an empty answer
    # instead of crashing on .strip().
    answer_text = (resp.choices[0].message.content or "").strip()

    # Build markdown with citations
    citation_lines = [
        f"- [{rank}] `{doc_name}` (score={score:.3f})\n"
        for doc_name, rank, score in citations
    ]
    md = answer_text + "\n\n---\n\n**Cited sources:**\n" + "".join(citation_lines)

    return md, answer_text



In [None]:
# ============================================================
# CELL 6 / STEP 6 – Build Cohort from Uploaded Docs
# ============================================================

def build_cohort_from_files(
    api_key: str,
    embed_model: str,
    cohort_name: str,
    files: List[Any],
) -> str:
    """
    Upload and index one or more files into a cohort.

    Each file is loaded, chunked, embedded, stored in its own FAISS index
    (identified by a fresh UUID), and registered in the document table.

    Args:
        api_key: OpenAI API key used for embedding.
        embed_model: Embedding model name; also recorded in the metadata.
        cohort_name: Cohort the documents are registered under.
        files: Uploaded files. The UI passes plain filesystem paths
            (``gr.File(type="filepath")``); objects exposing ``.name`` are
            also accepted.

    Returns:
        A newline-joined status report with one line per file, or a short
        validation message when ``files`` or ``cohort_name`` is empty.
    """
    if not files:
        return "Please upload at least one file."

    if not cohort_name:
        return "Please provide a cohort name."

    ensure_docs_table()
    ensure_cohort_table()
    ensure_user_table()
    ensure_chat_history_table()

    status_lines = []
    for f in files:
        try:
            text, original_name = load_file_to_text(f)
        except Exception as e:
            # Path strings have no ``.name``; reading it unconditionally would
            # raise AttributeError here and abort the whole batch.
            label = getattr(f, "name", None) or str(f)
            status_lines.append(f"‚ùå {label}: error loading file ‚Äì {e}")
            continue

        chunks = chunk_text(text)
        if not chunks:
            status_lines.append(f"‚ö†Ô∏è {original_name}: no text found.")
            continue

        vectors = embed_texts(api_key, embed_model, chunks)
        index = build_faiss_index(vectors)
        index_id = str(uuid4())
        save_index(index, index_id)
        save_metadata(
            index_id,
            {
                "chunks": chunks,
                "doc_name": original_name,
                "cohort_name": cohort_name,
                "embed_model": embed_model,
                "created_at": datetime.now(timezone.utc).isoformat(),
            },
        )

        register_document(
            doc_name=original_name,
            cohort_name=cohort_name,
            index_id=index_id,
            n_chunks=len(chunks),
            embed_model=embed_model,
        )

        status_lines.append(
            f"‚úÖ Indexed `{original_name}` into cohort `{cohort_name}` with {len(chunks)} chunks."
        )

    if not status_lines:
        return "No files were successfully processed."
    return "\n".join(status_lines)



In [None]:
# ============================================================
# CELL 7 / STEP 7 – Admin Helpers (Stats & Maintenance)
# ============================================================

def get_db_stats() -> str:
    """
    Return a markdown summary of row counts in the main tables.

    Ensures the documents, cohort, user, and chat-history tables exist, then
    counts documents, distinct cohort names, users, and chat records.

    Returns:
        A markdown string headed ``**DB Stats**`` with one bullet per count.
    """
    ensure_docs_table()
    ensure_cohort_table()
    ensure_user_table()
    ensure_chat_history_table()

    conn = get_db_conn()
    try:
        cur = conn.cursor()

        def count(sql: str):
            # Run a single-value COUNT query and return that value.
            cur.execute(sql)
            return cur.fetchone()[0]

        n_docs = count("SELECT COUNT(*) FROM documents")
        n_cohorts = count("SELECT COUNT(DISTINCT cohort_name) FROM cohort_docs")
        n_users = count("SELECT COUNT(*) FROM users")
        n_chats = count("SELECT COUNT(*) FROM chat_history")
    finally:
        # Close even when a query fails so the connection is not leaked.
        conn.close()

    return (
        f"**DB Stats**\n\n"
        f"- Documents: {n_docs}\n"
        f"- Cohorts: {n_cohorts}\n"
        f"- Users: {n_users}\n"
        f"- Chat records (last 7 days enforced on write): {n_chats}\n"
    )

def describe_users() -> str:
    """
    Render the rows returned by ``list_users()`` as a markdown bullet list.

    Missing display names and roles are shown with placeholder text. Returns
    a fixed notice when there are no users.
    """
    known = list_users()
    if not known:
        return "No users have been recorded yet."

    bullets = [
        f"- `{user_id}` ‚Äì {display_name or '(no display name)'} ‚Äì role: `{role or '(no role)'}`"
        for user_id, display_name, role in known
    ]
    return "\n".join(["**Known Users**\n", *bullets])

def describe_cohorts() -> str:
    """
    Return a markdown summary of cohorts.

    Each bullet lists the cohort name, its owner (``(none)`` when no metadata
    row exists), its visibility (private/shared), and the number of distinct
    documents in the cohort.

    Returns:
        The markdown list, or ``"No cohorts found."`` when the query returns
        no rows.
    """
    ensure_docs_table()
    ensure_cohort_table()
    ensure_cohort_meta_table()

    query = (
        """
        SELECT DISTINCT
            cd.cohort_name,
            COALESCE(cm.owner_user_id, '(none)') AS owner,
            COALESCE(cm.is_shared, 0)            AS is_shared,
            COUNT(DISTINCT cd.doc_name)          AS num_docs
        FROM cohort_docs cd
        LEFT JOIN cohort_meta cm
          ON cd.cohort_name = cm.cohort_name
        GROUP BY cd.cohort_name, owner, is_shared
        ORDER BY cd.cohort_name
        """
    )

    conn = get_db_conn()
    try:
        cur = conn.cursor()
        cur.execute(query)
        rows = cur.fetchall()
    finally:
        # Close even when the query fails so the connection is not leaked.
        conn.close()

    if not rows:
        return "No cohorts found."

    lines = ["**Cohorts**", ""]
    for name, owner, is_shared, num_docs in rows:
        share_label = "shared" if is_shared else "private"
        lines.append(
            f"- **{name}** ‚Äî owner: `{owner}`, visibility: {share_label}, docs: {num_docs}"
        )

    return "\n".join(lines)




In [None]:
# ============================================================
# CELL 8 / STEP 8 – Prompt Coach (Optional Query Improvement)
# ============================================================

def improve_query(
    api_key: str,
    chat_model: str,
    original_query: str,
) -> str:
    """
    Prompt coach: rewrite the user's query for better RAG retrieval.

    Args:
        api_key: OpenAI API key used to build the client.
        chat_model: Requested chat model; normalised by ``resolve_models``.
        original_query: The question as typed by the user.

    Returns:
        The rewritten query, or ``""`` when ``original_query`` is missing or
        blank (no API call is made in that case) or when the model returns
        no text.
    """
    # Guard against None as well as blank input before touching the API.
    if not original_query or not original_query.strip():
        return ""

    resolved_chat, _ = resolve_models(chat_model, EMBED_MODEL_DEFAULT)
    client = build_openai_client(api_key)

    system_prompt = (
        "You are a prompt coach helping the user improve questions for a RAG system. "
        "Rewrite the query to be explicit, concise, and focused on key details. "
        "Return ONLY the improved query text."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": original_query},
    ]

    resp = client.chat.completions.create(
        model=resolved_chat,
        messages=messages,
        temperature=0.1,
        max_tokens=200,
    )
    # message.content can be None; fall back to an empty string.
    improved = (resp.choices[0].message.content or "").strip()
    return improved



In [None]:
# ============================================================
# CELL 9 / STEP 9 – Chat History Viewer (User-Facing)
# ============================================================

def format_history_markdown(
    user_id: Optional[str],
    cohort_name: Optional[str],
    limit: int = 50,
) -> str:
    """
    Render recent chat history as markdown.

    Records come from ``get_recent_history`` using the given filters. The
    output starts with a header line (tagged "(filtered)" when either filter
    is set), followed by one section per record showing user, role, cohort,
    timestamp, the prompt variant used, the queries, and the answer.
    """
    records = get_recent_history(user_id=user_id, cohort_name=cohort_name, limit=limit)
    if not records:
        return "No chat history found for the given filters (within retention window)."

    filter_note = "(filtered)" if user_id or cohort_name else ""
    out = [f"**Showing up to {limit} most recent interactions** {filter_note}\n"]

    for entry in records:
        when = entry["created_at"]
        who = entry["user_id"] or "(anonymous)"
        role = entry["role"] or "(none)"
        cohort = entry["cohort_name"] or "(none)"
        variant = entry["which_prompt"] or "(unknown)"

        section = [
            f"---\n**User:** `{who}`  |  **Role:** `{role}`  |  **Cohort:** `{cohort}`  |  **When:** {when}",
            f"**Prompt used:** `{variant}`",
            f"**Original query:**\n{entry['original_query']}\n",
        ]
        # The improved query is shown only when one was recorded.
        if entry["improved_query"]:
            section.append(f"**Improved query:**\n{entry['improved_query']}\n")
        section.extend(["**Answer:**", entry["answer"], ""])
        out.extend(section)

    return "\n".join(out)


In [None]:
# ============================================================
# CELL 9.5 / STEP 9.5 – Identity & Admin Ops (v14)
# ============================================================

def authenticate(username: str, password: str) -> SessionUser | None:
    """
    MVP auth: check credentials against the local ``USERS`` dict.

    Args:
        username: Key looked up in ``USERS``.
        password: Password supplied by the caller; compared with the
            ``"password"`` entry of the matching record.

    Returns:
        A ``SessionUser`` carrying the username and the record's ``"role"``,
        or ``None`` when the user is unknown or the password does not match.
    """
    import hmac  # local import: keeps this block self-contained

    record = USERS.get(username)
    if not record:
        return None

    expected = record["password"]
    if isinstance(password, str) and isinstance(expected, str):
        # Constant-time comparison avoids leaking how many leading characters
        # matched. Encode first: compare_digest only accepts ASCII-only str.
        matches = hmac.compare_digest(
            password.encode("utf-8"), expected.encode("utf-8")
        )
    else:
        # Non-string values keep the original plain equality semantics.
        matches = password == expected

    if not matches:
        return None
    return SessionUser(username=username, role=record["role"])

def require_admin(user: SessionUser):
    """
    Guard for admin-only actions.

    Returns normally when ``user`` is truthy and ``user.is_admin`` is truthy;
    otherwise raises ``PermissionError``.
    """
    allowed = bool(user) and user.is_admin
    if allowed:
        return
    raise PermissionError("Admin privileges required for this action.")


def admin_delete_cohort(user: SessionUser, cohort_name: str) -> str:
    """
    Delete a cohort on behalf of an admin and record the action.

    Checks admin rights with ``require_admin``, calls ``delete_cohort`` with
    no reassignment, then writes a ``delete_cohort`` audit entry. Every
    outcome is reported as a status string rather than an exception.
    """
    try:
        require_admin(user)
    except PermissionError as denied:
        return f"‚ùå Not authorized: {denied}"

    if not cohort_name:
        return "‚ùå Please select a cohort to delete."

    try:
        # No reassignment of documents in this MVP.
        outcome = delete_cohort(cohort_name, reassign_to=None)
        log_audit(user.username, user.role, "delete_cohort", f"cohort={cohort_name}")
    except Exception as err:
        return f"‚ùå Error deleting cohort: {err}"
    return outcome

def admin_view_audit_log(user: SessionUser, limit: int = 50) -> str:
    """
    Admin-only view of recent audit log entries.

    Args:
        user: Session user; must pass ``require_admin``.
        limit: Maximum number of rows to fetch, newest first (by ``id``).

    Returns:
        A markdown bullet list of entries, ``"No audit log entries."`` when
        the table is empty, or a "Not authorized" message for non-admins.
    """
    try:
        require_admin(user)
    except PermissionError as e:
        return f"‚ùå Not authorized: {e}"

    ensure_audit_table()
    conn = get_db_conn()
    try:
        cur = conn.cursor()
        cur.execute(
            """
        SELECT ts, username, role, action, details
        FROM audit_log
        ORDER BY id DESC
        LIMIT ?
        """,
            (limit,),
        )
        rows = cur.fetchall()
    finally:
        # Close even when the query fails so the connection is not leaked.
        conn.close()

    if not rows:
        return "No audit log entries."

    lines = ["**Recent Audit Log Entries**\n"]
    for ts, username, role, action, details in rows:
        u = username or "-"
        r = role or "-"
        d = details or ""
        lines.append(f"- {ts} | user=`{u}` | role=`{r}` | action=`{action}` | {d}")

    return "\n".join(lines)


In [None]:
# ============================================================
# CELL 10 / STEP 10 – Gradio App (Cohorts, Admin, History, Q&A) – v14
# ============================================================

def on_build_cohort(
    api_key,
    embed_model,
    cohort_name,
    files,
    current_user: SessionUser,
):
    """
    Gradio callback for the "Build / Update Cohort" button.

    Validates the inputs, rejects a cohort name that already exists, indexes
    the files, records the cohort owner (best effort), and writes a
    ``build_cohort`` audit entry.

    Returns a 3-tuple: status message, update for the file uploader, update
    for the cohort-name textbox. On a validation failure both widgets are
    left untouched; after a build both are cleared.
    """
    cohort_name = (cohort_name or "").strip()

    # First failing check wins; the name-collision check runs last because it
    # needs a non-empty name.
    problem = None
    if not api_key:
        problem = "‚ùå Please provide your OpenAI API key."
    elif not files:
        problem = "‚ùå Please upload one or more files."
    elif not cohort_name:
        problem = "‚ùå Please provide a cohort name."
    elif cohort_exists(cohort_name):
        problem = (
            f"‚ùå A cohort named '{cohort_name}' already exists. "
            "Please choose a different name."
        )

    if problem is not None:
        # Keep uploads and the typed name so the user can adjust and retry.
        return problem, gr.update(), gr.update()

    report = build_cohort_from_files(api_key, embed_model, cohort_name, files)

    # Ownership registration is best effort: a failure is printed, not raised.
    try:
        set_cohort_owner(cohort_name, current_user)
    except Exception as e:
        print("DEBUG set_cohort_owner error:", e)

    actor = current_user.username if current_user and current_user.username else ""
    actor_role = current_user.role if current_user else ""
    log_audit(
        username=actor,
        role=actor_role,
        action="build_cohort",
        details=f"cohort={cohort_name}",
    )

    return report, gr.update(value=None), gr.update(value="")




def on_refresh_cohorts_for_user(current_user=None):
    """
    Build a dropdown update listing the cohorts to offer the current user.

    Anything that is not a ``SessionUser`` is treated as a fresh default
    ``SessionUser()``. The list comes from ``list_cohorts_for_user``; when
    that returns nothing, or raises, the global ``list_cohorts()`` result is
    used instead (v13 behaviour).

    Returns a ``gr.update`` with the choices and the first name preselected,
    or empty choices with no selection when there are no cohorts.
    """
    user = current_user if isinstance(current_user, SessionUser) else SessionUser()

    try:
        # Empty per-user result falls through to the global list.
        names = list_cohorts_for_user(user) or list_cohorts()
    except Exception as e:
        print("DEBUG on_refresh_cohorts_for_user error:", e)
        names = list_cohorts()

    if names:
        return gr.update(choices=names, value=names[0])
    return gr.update(choices=[], value=None)


def on_validate_key(api_key, chat_model, embed_model):
    """Gradio callback: forward the inputs to ``validate_openai_key_and_models``."""
    report = validate_openai_key_and_models(api_key, chat_model, embed_model)
    return report

def on_improve_query(api_key, chat_model, original_query):
    """
    Gradio callback for the "Improve Question" button.

    Returns the result of ``improve_query`` when an API key is present,
    otherwise a reminder to provide one.
    """
    if api_key:
        return improve_query(api_key, chat_model, original_query)
    return "Please provide an API key first."

def on_ask_with_choice(
    api_key,
    user_id,
    user_role,
    chat_model,
    embed_model,
    cohort_name,
    original_query,
    improved_query,
    which_prompt,
):
    """
    Main Q&A handler.

    - Validates the API key, cohort, and question.
    - Uses the improved query only when ``which_prompt`` is "Improved" and
      the improved text is non-blank; otherwise uses the original query.
    - Upserts the user with the given role (scaffolding for ICAM).
    - Performs RAG, saves the interaction to chat history, and writes an
      ``ask`` audit entry (query truncated to 200 characters).

    Returns:
        ``(answer_markdown, raw_answer_text)``; on a validation failure the
        first element is an error message and the second is ``""``.
    """
    if not api_key:
        return "‚ùå Please provide your OpenAI API key.", ""

    if not cohort_name:
        return "‚ùå Please select a cohort.", ""

    # Treat a missing (None) question like a blank one instead of crashing
    # on None.strip().
    if not (original_query or "").strip():
        return "‚ùå Please enter a question.", ""

    # Choose which query to send; a None improved query counts as blank.
    query = original_query
    if which_prompt == "Improved" and (improved_query or "").strip():
        query = improved_query

    # Upsert user (scaffolding for ICAM)
    user_id_clean = (user_id or "").strip() or "anonymous"
    user_role_clean = (user_role or "").strip() or "user"
    upsert_user(user_id_clean, user_role_clean)

    # Call RAG
    answer_md, answer_text = answer_with_rag(
        api_key=api_key,
        chat_model=chat_model,
        embed_model=embed_model,
        cohort_name=cohort_name,
        query=query,
    )

    # Save chat history (both query variants are stored as received)
    save_chat_interaction(
        user_id=user_id_clean,
        role=user_role_clean,
        cohort_name=cohort_name,
        original_query=original_query,
        improved_query=improved_query,
        which_prompt=which_prompt,
        answer=answer_md,
        chat_model=chat_model or CHAT_MODEL_DEFAULT,
    )

    # Audit log entry for the question
    log_audit(
        username=user_id_clean,
        role=user_role_clean,
        action="ask",
        details=f"cohort={cohort_name}; which={which_prompt}; q={query[:200]}",
    )

    return answer_md, answer_text

def on_view_history(user_id, cohort_name, limit):
    """
    Gradio callback for the "Show History" button.

    Strips the optional user and cohort filters (empty input becomes
    ``None``), parses ``limit`` as an int with a fallback of 50, renders the
    history via ``format_history_markdown``, and writes a ``view_history``
    audit entry describing the filters used.
    """
    user_filter = user_id.strip() if user_id else None
    cohort_filter = cohort_name.strip() if cohort_name else None

    try:
        max_records = int(limit)
    except Exception:
        max_records = 50

    rendered = format_history_markdown(user_filter, cohort_filter, max_records)

    # The audit entry records the filter values only; the role is left blank
    # because this handler does not receive the session user.
    audit_details = (
        f"user_filter={user_filter or ''}; "
        f"cohort_filter={cohort_filter or ''}; "
        f"limit={max_records}"
    )
    log_audit(
        username=user_filter or "",
        role="",
        action="view_history",
        details=audit_details,
    )

    return rendered


def on_admin_refresh(current_user: SessionUser):
    """
    Admin overview (stats & inventory).

    Enforces the admin role, gathers row counts for documents, cohorts,
    users, chat history, and the audit log, renders the cohort and user
    summaries, and logs an ``admin_refresh`` audit event.

    Returns:
        ``(stats_md, cohorts_md, users_md)``. For non-admins the first
        element is a "Not authorized" message and the other two are ``""``.
    """
    try:
        require_admin(current_user)
    except PermissionError as e:
        return f"‚ùå Not authorized: {e}", "", ""

    ensure_docs_table()
    ensure_cohort_table()
    ensure_user_table()
    ensure_chat_history_table()
    ensure_audit_table()

    conn = get_db_conn()
    try:
        cur = conn.cursor()

        cur.execute("SELECT COUNT(*) FROM documents")
        n_docs = cur.fetchone()[0]

        cur.execute("SELECT COUNT(DISTINCT cohort_name) FROM cohort_docs")
        n_cohorts = cur.fetchone()[0]

        cur.execute("SELECT COUNT(*) FROM users")
        n_users = cur.fetchone()[0]

        cur.execute("SELECT COUNT(*) FROM chat_history")
        n_chats = cur.fetchone()[0]

        cur.execute("SELECT COUNT(*) FROM audit_log")
        n_audit = cur.fetchone()[0]
    finally:
        # Close even when a query fails so the connection is not leaked.
        conn.close()

    stats = (
        f"**DB Stats**\n\n"
        f"- Documents: {n_docs}\n"
        f"- Cohorts: {n_cohorts}\n"
        f"- Users: {n_users}\n"
        f"- Chat records (last 7 days enforced on write): {n_chats}\n"
        f"- Audit log entries: {n_audit}\n"
    )

    cohorts_md = describe_cohorts()
    users_md = describe_users()

    log_audit(current_user.username, current_user.role, "admin_refresh", "Refreshed admin overview")

    return stats, cohorts_md, users_md

def login_fn(username, password, current_user: SessionUser):
    """
    Login handler for v14.

    Strips and authenticates the credentials, writes a ``login`` or
    ``login_failed`` audit entry, and returns an 8-tuple of outputs:
    session user, status text, admin-panel visibility, admin-denied notice,
    user-id box, role dropdown, file uploader (cleared), and cohort-name box
    (cleared). ``current_user`` is accepted but not read.
    """
    denied_notice = (
        "‚ö†Ô∏è You do not have permission to access the admin dashboard. "
        "Log in as an admin user to view admin tools."
    )

    username = (username or "").strip()
    password = (password or "").strip()

    user = authenticate(username, password)

    if not user:
        log_audit(
            username=username,
            role="",
            action="login_failed",
            details="Invalid credentials",
        )
        # Reset to a default session and keep the admin tools hidden.
        return (
            SessionUser(),
            "‚ùå Invalid credentials. Try 'admin'/'admin123' or 'demo'/'demo123'.",
            gr.update(visible=False),
            gr.update(visible=True, value=denied_notice),
            gr.update(value=""),
            gr.update(value="user"),
            gr.update(value=None),
            gr.update(value=""),
        )

    log_audit(user.username, user.role, "login", "User logged in")
    status = f"‚úÖ Logged in as {user.username} (role: {user.role})"

    # Only the panel visibility, the notice, and the role value differ
    # between admin and non-admin logins.
    if user.is_admin:
        panel_update = gr.update(visible=True)
        notice_update = gr.update(visible=False)
        role_value = "admin"
    else:
        panel_update = gr.update(visible=False)
        notice_update = gr.update(visible=True, value=denied_notice)
        role_value = "user"

    return (
        user,
        status,
        panel_update,
        notice_update,
        gr.update(value=user.username),
        gr.update(value=role_value),
        gr.update(value=None),
        gr.update(value=""),
    )


# ---------------------- Build Gradio UI ----------------------

# Top-level UI definition. Components are created in layout order inside the
# Blocks context; event handlers are wired right after the widgets they use,
# except the login button, which is wired last because its outputs span
# several tabs.
with gr.Blocks(title="Phase 1 ‚Äì RAG MVP (v14)") as demo:
    user_session = gr.State(SessionUser())  # holds current SessionUser

    gr.Markdown(
        """
        # üìò Phase 1 ‚Äì RAG MVP (v14)

        **New in v14:**
        - ‚úÖ Audit logging to SQLite (`audit_log` table)
        - ‚úÖ Simple login with roles (`admin`, `demo`) using SessionUser
        - ‚úÖ Admin-only actions (admin tab is hidden unless logged in as admin)
        """
    )

    # ---- Login Tab ----
    with gr.Tab("0Ô∏è‚É£ Login"):
        gr.Markdown(
            "Use `admin/admin123` for admin, `demo/demo123` for user (MVP only)."
        )
        login_user = gr.Textbox(label="Username")
        login_pass = gr.Textbox(label="Password", type="password")
        login_btn = gr.Button("Login")
        login_status = gr.Markdown("Not logged in.")

    # ---- Setup & Cohorts ----
    with gr.Tab("1Ô∏è‚É£ Setup & Cohorts"):
        gr.Markdown("### OpenAI Setup & Build a Cohort")

        # The key and model boxes are also read by the Ask tab's handlers.
        api_key_box = gr.Textbox(
            label="OpenAI API Key",
            type="password",
            placeholder="sk-...",
        )
        chat_model_dd = gr.Textbox(
            label="Chat Model",
            value=CHAT_MODEL_DEFAULT,
            info="e.g., gpt-4.1, gpt-4.1-mini, gpt-4o, etc.",
        )
        embed_model_dd = gr.Textbox(
            label="Embedding Model",
            value=EMBED_MODEL_DEFAULT,
            info="e.g., text-embedding-3-small",
        )

        validate_btn = gr.Button("Validate Key & Models")
        validate_out = gr.Markdown()

        validate_btn.click(
            on_validate_key,
            inputs=[api_key_box, chat_model_dd, embed_model_dd],
            outputs=[validate_out],
        )

        gr.Markdown("---")

        cohort_name_box = gr.Textbox(
            label="Cohort Name",
            placeholder="e.g., WIC Policy Docs",
        )
        # type="filepath": handlers receive filesystem paths, not file objects.
        file_uploader = gr.File(
            label="Upload Files (PDF, DOCX, TXT)",
            file_count="multiple",
            type="filepath",
        )

        build_btn = gr.Button("Build / Update Cohort")
        build_status = gr.Markdown()

        # on_build_cohort returns updates for the uploader and name box as
        # well, so it can clear them after a build.
        build_btn.click(
            on_build_cohort,
            inputs=[api_key_box, embed_model_dd, cohort_name_box, file_uploader, user_session],
            outputs=[build_status, file_uploader, cohort_name_box],
        )




    # ---- Ask Your Documents ----
    with gr.Tab("2Ô∏è‚É£ Ask Your Documents"):
        gr.Markdown("### Ask Questions Against a Cohort")

        with gr.Row():
            # Both fields are editable and are also overwritten by login_fn.
            # NOTE(review): on_ask_with_choice records whatever is typed here,
            # not the session user - confirm this is acceptable for auditing.
            user_id_box = gr.Textbox(
                label="User ID",
                placeholder="your-username-or-email",
                value="user1",
            )
            user_role_dd = gr.Dropdown(
                label="Role (scaffolding for ICAM)",
                choices=["user", "admin"],
                value="user",
            )

            # Starts empty; populated by the refresh button below.
            ask_cohort_dd = gr.Dropdown(
                label="Select Cohort",
                choices=[],
            )


        refresh_cohorts_btn = gr.Button("üîÑ Refresh Cohort List")
        refresh_cohorts_btn.click(
            on_refresh_cohorts_for_user,
            inputs=[user_session],
            outputs=[ask_cohort_dd],
        )

        query_box = gr.Textbox(
            label="Original Question",
            lines=4,
            placeholder="Ask a question about your documents...",
        )
        improved_query_box = gr.Textbox(
            label="Improved Question (Prompt Coach Output)",
            lines=4,
            placeholder="Click 'Improve Question' or edit manually.",
        )
        prompt_choice = gr.Radio(
            label="Which query should be used for RAG?",
            choices=["Original", "Improved"],
            value="Original",
        )

        improve_btn = gr.Button("‚ú® Improve Question")
        improve_btn.click(
            on_improve_query,
            inputs=[api_key_box, chat_model_dd, query_box],
            outputs=[improved_query_box],
        )

        ask_btn = gr.Button("üí¨ Ask")
        answer_md = gr.Markdown()
        raw_answer_box = gr.Textbox(
            label="Raw Answer (for copy/print/export)",
            lines=6,
        )

        # Input order must match on_ask_with_choice's positional parameters.
        ask_btn.click(
            on_ask_with_choice,
            inputs=[
                api_key_box,
                user_id_box,
                user_role_dd,
                chat_model_dd,
                embed_model_dd,
                ask_cohort_dd,
                query_box,
                improved_query_box,
                prompt_choice,
            ],
            outputs=[answer_md, raw_answer_box],
        )

    # ---- Chat History ----
    with gr.Tab("3Ô∏è‚É£ Chat History"):
        gr.Markdown(
            "### View Chat History (7-Day Window)\n"
            "History is stored in SQLite with automatic removal of records older than 7 days."
        )

        hist_user_id = gr.Textbox(
            label="Filter by User ID (optional)",
            placeholder="Leave blank for all users",
        )
        hist_cohort = gr.Textbox(
            label="Filter by Cohort Name (optional)",
            placeholder="Leave blank for all cohorts",
        )
        # Free-text limit; on_view_history falls back to 50 when it is not
        # parseable as an int.
        hist_limit = gr.Textbox(
            label="Max Records",
            value="50",
        )
        hist_btn = gr.Button("üìú Show History")
        hist_md = gr.Markdown()

        hist_btn.click(
            on_view_history,
            inputs=[hist_user_id, hist_cohort, hist_limit],
            outputs=[hist_md],
        )

    # ---- Admin ----
    with gr.Tab("4Ô∏è‚É£ Admin"):
    # Panel hidden unless logged in as admin
    # (login_fn toggles this group's visibility; the tab's body is indented
    # two extra spaces relative to the other tabs.)
      admin_panel = gr.Group(visible=False)
      with admin_panel:
        gr.Markdown(
            "### Admin Overview (Stats, Inventory, & Audit Log)\n"
            "This tab is only visible when logged in as an admin."
        )

        admin_refresh_btn = gr.Button("üîÑ Refresh Admin Info")
        admin_stats_md = gr.Markdown()
        admin_cohorts_md = gr.Markdown()
        admin_users_md = gr.Markdown()

        # on_admin_refresh re-checks the admin role server-side, so hiding
        # the panel is not the only guard.
        admin_refresh_btn.click(
            on_admin_refresh,
            inputs=[user_session],
            outputs=[admin_stats_md, admin_cohorts_md, admin_users_md],
        )

        gr.Markdown("---")

        # Admin-only: delete cohort
        admin_refresh_cohorts_btn = gr.Button("üîÑ Refresh Cohort List (Admin)")
        admin_del_cohort_dd = gr.Dropdown(
            label="Cohort to delete (admin only)",
            choices=[],
        )
        admin_delete_cohort_btn = gr.Button("üóëÔ∏è Delete Selected Cohort")
        admin_delete_status = gr.Markdown()

        admin_refresh_cohorts_btn.click(
            on_refresh_cohorts_for_user,
            inputs=[user_session],
            outputs=[admin_del_cohort_dd],
        )

        # The lambda only swaps argument order: inputs arrive as
        # (cohort, user) but admin_delete_cohort takes (user, cohort).
        admin_delete_cohort_btn.click(
            fn=lambda cohort, user: admin_delete_cohort(user, cohort),
            inputs=[admin_del_cohort_dd, user_session],
            outputs=[admin_delete_status],
        )

        gr.Markdown("---")

        # Admin-only: view audit log
        audit_btn = gr.Button("üìã View Audit Log (last 50)")
        audit_md = gr.Markdown()

        audit_btn.click(
            fn=lambda user: admin_view_audit_log(user, limit=50),
            inputs=[user_session],
            outputs=[audit_md],
        )

    # NEW: message shown when user is not admin
    # NOTE(review): this component is created at the Blocks level, outside the
    # Admin tab's `with` block, so it presumably renders beneath the tab set
    # regardless of which tab is selected - confirm whether it was meant to
    # live inside the Admin tab.
    admin_denied_md = gr.Markdown(
        "‚ö†Ô∏è You do not have permission to access the admin dashboard. "
        "Log in as an admin user to view admin tools.",
        visible=True,
    )


    gr.Markdown("Built for fast iteration and future ICAM integration. ‚ö°")

    # Wire login button
    # Output order must match the 8-tuple returned by login_fn.
    login_btn.click(
          fn=login_fn,
          inputs=[login_user, login_pass, user_session],
          outputs=[
              user_session,
              login_status,
              admin_panel,
              admin_denied_md,
              user_id_box,
              user_role_dd,
              file_uploader,
              cohort_name_box,
          ],
    )


# Start the app when the cell runs; debug=True is passed to Gradio's launch().
demo.launch(debug=True)
