In [None]:
import os
import re
import string
import logging
import difflib
from typing import TypedDict, Optional, Tuple, Dict, Any, List
from collections import Counter

import pandas as pd
import psycopg2
import gradio as gr
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph, END
from langchain_core.runnables import RunnableLambda
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document

# ---------------- Gmail imports (NEW) ----------------
import base64
from email.mime.text import MIMEText
from typing import Optional as _Opt
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import build
# ----------------------------------------------------

# Root logger: timestamp | level | message, INFO and above.
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")

# ========== ENV ==========
# Pull settings from a .env file (python-dotenv) before reading them below.
load_dotenv()
# PostgreSQL connection settings; only the port has a default.
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
# Table name is interpolated directly into SQL text elsewhere in this file,
# so it must come from trusted configuration only.
SOURCE_TABLE = os.getenv("SOURCE_TABLE", "bi_dwh.prediction_agent_data")

# Groq chat model used for SQL generation and all agent prompts.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GROQ_MODEL = "meta-llama/llama-4-maverick-17b-128e-instruct"

# ---- Gmail env (NEW) ----
# OAuth client secrets file and the cached user token written after first login.
GMAIL_CREDENTIALS_FILE = os.getenv("GMAIL_CREDENTIALS_FILE", "credentials.json")
GMAIL_TOKEN_FILE = os.getenv("GMAIL_TOKEN_FILE", "token.json")
# "From" address for outgoing mail and the pre-filled "To" value in the UI.
SENDER_EMAIL = os.getenv("SENDER_EMAIL", "")
DEFAULT_TO_EMAIL = os.getenv("DEFAULT_TO_EMAIL", "")
# ------------------------

# Column names used when filtering SOURCE_TABLE by customer or by policy.
CUSTOMER_ID_COL = "customerid"
POLICY_NO_COL = "policy_no"

# ========== DB ==========
def fetch_customer_data() -> pd.DataFrame:
    """Read all rows of SOURCE_TABLE into a DataFrame with lower-cased column names.

    A new psycopg2 connection is opened from the DB_* settings and is closed
    again whether or not the query succeeds.
    """
    query = f"SELECT * FROM {SOURCE_TABLE}"
    connection = psycopg2.connect(
        host=DB_HOST,
        port=DB_PORT,
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
    )
    try:
        frame = pd.read_sql(query, connection)
    finally:
        connection.close()
    # Normalise header case so later lookups can use lower-case keys.
    frame.columns = frame.columns.str.lower()
    return frame

# LLM: shared Groq chat client; temperature=0 is passed for every call.
llm = ChatGroq(model=GROQ_MODEL, api_key=GROQ_API_KEY, temperature=0)

# Load table once (columns + sample fallback).
# NOTE: this runs a full-table SELECT at import time; the frame is reused for
# schema hints and for the "sample records" fallback context.
df_all_customers = fetch_customer_data()
all_cols = list(df_all_customers.columns)

# ========== Tiny schema RAG (embed only column names + one sample value) ==========
def build_schema_docs(df: pd.DataFrame) -> FAISS:
    """Build a FAISS index holding one short description per column of *df*.

    Each description names the column and SOURCE_TABLE and, when the column
    has a non-null value, appends that first value truncated to 120 chars.
    """

    def _describe(col) -> str:
        # Any failure while sampling is treated as "no sample available".
        try:
            first_value = df[col].dropna().iloc[0]
        except Exception:
            first_value = None
        text = f"{col}: column in {SOURCE_TABLE}"
        if first_value is None:
            return text
        return text + f" | sample_value={str(first_value)[:120]}"

    documents = [Document(page_content=_describe(col)) for col in df.columns]
    embedder = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
    return FAISS.from_documents(documents, embedder)

# Index the column descriptions once at import time; each lookup returns the
# 8 closest descriptions for the question being asked.
schema_vs = build_schema_docs(df_all_customers)
schema_retriever = schema_vs.as_retriever(search_kwargs={"k": 8})

def get_schema_snippets(question: str) -> str:
    """Return the column descriptions most relevant to *question*, one per line.

    Args:
        question: Free-text user question used as the retrieval query.

    Returns:
        The ``page_content`` of each retrieved document joined with newlines.
    """
    # Retrievers are Runnables: invoke() is the supported entry point and
    # replaces get_relevant_documents(), which emits a deprecation warning.
    docs = schema_retriever.invoke(question)
    return "\n".join(d.page_content for d in docs)

# ========== Lightweight in-process memory for follow-ups ==========
# Process-wide record of the most recent customer/policy lookup. It is a plain
# module-level dict, so every chat session served by this process shares it.
MEMORY: Dict[str, Any] = {
    "last_kind": None,        # "CUSTOMER" | "POLICY" | None
    "last_ident": None,       # the actual id string
    "last_context": None,     # flattened row context string
    "last_status": None,      # "renewed" | "not_renewed" | "unknown"
}

# Lower-case phrases that mark a message as referring to the previous lookup.
FOLLOWUP_HINTS = {
    "for this",
    "for that",
    "for above",
    "same",
    "this policy",
    "that policy",
    "this customer",
    "that customer",
    "for this one",
    "use the same",
    "continue with this",
}


def looks_like_followup(text: str) -> bool:
    """Return True when *text* contains any follow-up hint (case-insensitive substring)."""
    lowered = text.lower()
    for hint in FOLLOWUP_HINTS:
        if hint in lowered:
            return True
    return False

# ========== SQL helpers ==========
def normalize_policy_for_compare(p: str) -> str:
    """Canonicalise a policy number for equality comparison.

    Trims surrounding whitespace, drops one leading single or double quote,
    removes every space character and upper-cases the result.
    """
    cleaned = str(p).strip()
    if cleaned[:1] in ("'", '"'):
        cleaned = cleaned[1:]
    return cleaned.replace(" ", "").upper()

def extract_identifier(question: str) -> Tuple[Optional[str], Optional[str]]:
    """Pull a policy number or customer id out of a free-text question.

    Checks, in order:
      1. An explicit "policy" / "policy no" / "policy number" mention followed
         by a 6-40 character token that contains at least one digit.
      2. Any run of five or more digits (optionally preceded by "customer").
      3. Any long alphanumeric token containing a digit (treated as a policy).

    Args:
        question: The user's message.

    Returns:
        ``("POLICY", token)``, ``("CUSTOMER", digits)`` or ``(None, None)``
        when nothing identifier-like is found.
    """
    q = question.strip()

    # explicit "policy …"
    # The lookahead requires a digit inside the captured token. Without it the
    # pattern captured ordinary words ("this policy focusing on …" -> "focusing",
    # "policy number 123…" -> "number"), which produced bogus lookups and
    # prevented follow-up questions from reusing the remembered identifier.
    m_pol_explicit = re.search(
        r"policy(?:\s*(?:number|num|no))?\.?\s*#?:?\s*[\"']?"
        r"(?=[A-Za-z0-9\-]{0,39}\d)([A-Za-z0-9\-]{6,40})",
        q, flags=re.IGNORECASE
    )
    if m_pol_explicit:
        return "POLICY", m_pol_explicit.group(1)

    # customer id (≥5 consecutive digits) – checked BEFORE the fallback
    m_cust = re.search(r"(?:customer\s*)?(\d{5,})", q, flags=re.IGNORECASE)
    if m_cust:
        return "CUSTOMER", m_cust.group(1)

    # long-token fallback IF it contains at least one digit (looks number-ish)
    m_long = re.search(
        r"[\"']?([A-Za-z0-9\-]{10,40}\d[A-Za-z0-9\-]*)",
        q,
    )
    if m_long:
        return "POLICY", m_long.group(1)

    return None, None

def fuzzy_in(text: str, targets, cutoff=0.8) -> bool:
    """Return True if any word of *text* is similar enough to any target.

    Words are ``\\w+`` runs longer than two characters; targets of two
    characters or fewer are ignored. Similarity is difflib's
    SequenceMatcher ratio and must be at least *cutoff*.
    """
    candidates = (word for word in re.findall(r"\w+", text) if len(word) > 2)
    return any(
        len(target) > 2
        and difflib.SequenceMatcher(None, word, target).ratio() >= cutoff
        for word in candidates
        for target in targets
    )

def is_whole_word(word: str, text: str) -> bool:
    """Return True when *word* occurs in *text* delimited by word boundaries."""
    pattern = re.compile(rf"\b{re.escape(word)}\b")
    return bool(pattern.search(text))

def classify_intent(question: str) -> str:
    """Map a question to "GREETING", "EMAIL", "RECOMMENDATION" or "UNKNOWN".

    The question is trimmed, lower-cased and stripped of trailing
    punctuation. Routes are tried in the order listed below; a route wins
    when a word fuzzily matches one of its phrases or a phrase appears as
    a whole word/phrase in the question.
    """
    q = question.strip().lower().rstrip(string.punctuation)

    # (route, trigger phrases, fuzzy cutoff) in priority order.
    rules = (
        (
            "GREETING",
            ["hi", "hello", "hey", "what can you do", "who are you", "help"],
            0.8,
        ),
        (
            "EMAIL",
            [
                "email", "e-mail", "mail", "sms", "whatsapp", "draft",
                "compose", "send mail", "write an email", "write to",
                "message template", "text message"
            ],
            0.75,
        ),
        (
            "RECOMMENDATION",
            [
                "recommendation", "recommend", "strategy", "suggest",
                "advice", "plan", "retain", "how to retain",
                "next best action", "improve renewal", "reduce churn"
            ],
            0.75,
        ),
    )

    for route, phrases, cutoff in rules:
        if fuzzy_in(q, phrases, cutoff):
            return route
        for phrase in phrases:
            if is_whole_word(phrase, q):
                return route

    return "UNKNOWN"

def generate_sql_with_llm(question: str, kind: Optional[str], ident: Optional[str]) -> str:
    """Ask the LLM for a single SELECT against SOURCE_TABLE that answers *question*.

    Args:
        question: The user's message; also used to retrieve schema hints.
        kind: "CUSTOMER", "POLICY" or None (shown to the model as "UNKNOWN").
        ident: The identifier detected in the question, or None (shown as "").

    Returns:
        The first ``SELECT … ;`` span found in the model's reply, or the
        whole stripped reply when no such span is present.
    """
    schema_snippets = get_schema_snippets(question)
    ident_str = ident or ""
    kind_str = kind or "UNKNOWN"

    # SQL expression that strips a leading single quote and all spaces from the
    # policy column and upper-cases it.
    norm_sql_expr = f"UPPER(REPLACE(REGEXP_REPLACE({POLICY_NO_COL}::text, '^''', ''), ' ', ''))"

    prompt = f"""
You write a single valid PostgreSQL SELECT for table {SOURCE_TABLE}. No markdown.

Schema hints:
{schema_snippets}

Rules:
- Use only table {SOURCE_TABLE}.
- If the question is customer-specific -> filter as: WHERE {CUSTOMER_ID_COL} = '<id>'.
- If the question is policy-specific -> compare normalized policy numbers:
  Use: {norm_sql_expr} = '<NORMALIZED_ID>'  -- where NORMALIZED_ID strips leading quote and spaces then UPPER.
- Never do JOINs or subqueries. One SELECT only.
- Return *only* the final SQL (end with a semicolon).

Question: {question}
Kind detected: {kind_str}
Identifier: {ident_str}

Examples:
-- customer
SELECT * FROM {SOURCE_TABLE} WHERE {CUSTOMER_ID_COL} = '1467520';

-- policy (normalize compare)
SELECT * FROM {SOURCE_TABLE}
WHERE {norm_sql_expr} = '201140020123100158101000';
"""
    sql = llm.invoke(prompt).content.strip()
    # Keep only the first statement: from "SELECT" up to the first semicolon.
    # NOTE(review): the returned text is model output and is not validated here.
    m = re.search(r"(SELECT\s.+?;)", sql, flags=re.IGNORECASE | re.DOTALL)
    return m.group(1).strip() if m else sql

def execute_sql_fetch_one(sql: str) -> Optional[Dict[str, Any]]:
    """Run *sql* on a fresh connection and return the first row as a dict.

    Keys are the lower-cased result column names. Returns None when the
    query yields no row. The connection is always closed.

    NOTE(review): *sql* is executed verbatim with no parameters.
    """
    connection = psycopg2.connect(
        host=DB_HOST,
        port=DB_PORT,
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
    )
    try:
        with connection.cursor() as cursor:
            cursor.execute(sql)
            first = cursor.fetchone()
            if first is None:
                return None
            names = [column[0].lower() for column in cursor.description]
            return {name: value for name, value in zip(names, first)}
    finally:
        connection.close()

# NEW: fetch all rows
def execute_sql_fetch_all(sql: str) -> List[Dict[str, Any]]:
    """Run *sql* on a fresh connection and return every row as a dict.

    Keys are the lower-cased result column names. Returns an empty list
    when the query yields no rows. The connection is always closed.

    NOTE(review): *sql* is executed verbatim with no parameters.
    """
    connection = psycopg2.connect(
        host=DB_HOST,
        port=DB_PORT,
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
    )
    try:
        with connection.cursor() as cursor:
            cursor.execute(sql)
            fetched = cursor.fetchall()
            if not fetched:
                return []
            names = [column[0].lower() for column in cursor.description]
            return [dict(zip(names, record)) for record in fetched]
    finally:
        connection.close()

# NEW: aggregate churn reasons across rows
def aggregate_reasons(rows: List[Dict[str, Any]]) -> tuple[list[str], Dict[str, int]]:
    """Count churn reasons across *rows*.

    Reads "churn_main_reason" and "churn_top_3_reasons" from each row,
    splits their text on commas and semicolons, and tallies every
    non-empty trimmed piece.

    Returns:
        A pair ``(top3, counts)``: the three most frequent reasons and the
        full reason -> occurrence mapping.
    """
    tally = Counter()
    for record in rows:
        for column in ("churn_main_reason", "churn_top_3_reasons"):
            raw = record.get(column)
            if not raw:
                continue
            pieces = (piece.strip() for piece in str(raw).replace(";", ",").split(","))
            tally.update(piece for piece in pieces if piece)
    ranked = [name for name, _count in tally.most_common(3)]
    return ranked, dict(tally)

def infer_status_from_values(value: object) -> Optional[str]:
    """Translate a single column value into "renewed", "not_renewed" or None.

    The value is stringified, trimmed and lower-cased before being matched
    against the two fixed vocabularies below; anything else yields None.
    """
    token = str(value).strip().lower()
    renewed_tokens = ("renewed", "yes", "true")
    lapsed_tokens = ("not renewed", "no", "false", "churn", "not_renewed")
    if token in renewed_tokens:
        return "renewed"
    return "not_renewed" if token in lapsed_tokens else None

def infer_status_from_context(ctx: str) -> str:
    """Guess a status from flattened context text.

    Returns "renewed" when a renewed marker is present, or when the text
    mentions "renewed" without "not renewed"; otherwise "not_renewed" when
    a not-renewed marker or the word "churn" appears; otherwise "unknown".
    Matching is case-insensitive and the renewed checks take precedence.
    """
    text = ctx.lower()

    renewed_markers = (
        "is_churn: renewed",
        "predicted status: renewed",
        "status: renewed",
    )
    if any(marker in text for marker in renewed_markers):
        return "renewed"
    if "renewed" in text and "not renewed" not in text:
        return "renewed"

    lapsed_markers = (
        "is_churn: not renewed",
        "predicted status: not renewed",
        "status: not renewed",
        "churn",
    )
    if any(marker in text for marker in lapsed_markers):
        return "not_renewed"
    return "unknown"

# ========== State ==========
class AgentState(TypedDict, total=False):
    # State passed between graph nodes; total=False means every key is optional.
    # The user's message for this turn.
    question: str
    # Flattened row/sample text placed into the agent prompts.
    context: str
    # Result of classify_intent: "GREETING" | "EMAIL" | "RECOMMENDATION" | "UNKNOWN".
    route: str
    # "renewed" | "not_renewed" | "unknown".
    status: str
    # At most one of these two is non-None, depending on the identifier kind.
    customer_id: Optional[str]
    policy_no: Optional[str]
    # Final text returned to the chat UI.
    response: str

def unknown_node(state: AgentState) -> AgentState:
    """Graph node for unrecognised requests: store the fixed refusal text as the response."""
    refusal = "Sorry, I’m a retention-strategy agent and can’t answer that question."
    state["response"] = refusal
    return state

# ========== Build context (with memory support) ==========
# Columns that may carry a renewal status, in lookup priority order.
_STATUS_COLUMNS = ("is_churn", "predicted status", "predicted_status", "status")


def _sample_records_context(prefix: str) -> str:
    """Return *prefix* followed by up to three rows of the preloaded table."""
    sample = df_all_customers.head(3)
    return prefix + "; ".join(
        ", ".join(f"{c}: {sample.iloc[i][c]}" for c in all_cols)
        for i in range(min(3, len(sample)))
    )


def _remember_lookup(kind: str, ident: str, context: str, status: str) -> None:
    """Record the latest lookup in MEMORY so follow-up questions can reuse it."""
    MEMORY["last_kind"] = kind
    MEMORY["last_ident"] = ident
    MEMORY["last_context"] = context
    MEMORY["last_status"] = status


def _finish_lookup(state: AgentState, kind: str, ident: str, context: str, status: str) -> AgentState:
    """Store context/status on the state, persist them to MEMORY and return the state."""
    state["context"] = context
    state["status"] = status
    _remember_lookup(kind, ident, context, status)
    return state


def _build_policy_context(state: AgentState, ident: str) -> AgentState:
    """Fill context/status for a single policy using a deterministic SELECT *."""
    norm_val = normalize_policy_for_compare(ident)
    norm_expr = f"UPPER(REPLACE(REGEXP_REPLACE({POLICY_NO_COL}::text, '^''', ''), ' ', ''))"
    # The earlier guard compared this mixed-case expression against
    # sql.upper(), which can never contain it, so the LLM-generated SQL was
    # always replaced by exactly this query. Build it directly and skip the
    # wasted LLM round-trip; the executed SQL is unchanged.
    # Single quotes are doubled so the value stays a plain SQL string literal.
    safe_val = norm_val.replace("'", "''")
    sql = f"SELECT * FROM {SOURCE_TABLE}\nWHERE {norm_expr} = '{safe_val}';"
    logging.info("Policy SQL:\n%s", sql)

    row = execute_sql_fetch_one(sql)
    if not row:
        context = _sample_records_context(
            f"No exact row found for POLICY:{ident}. Sample records: "
        )
        return _finish_lookup(state, "POLICY", ident, context, infer_status_from_context(context))

    context = ", ".join(f"{k}: {row.get(k)}" for k in all_cols if k in row)

    # Prefer an explicit status column; fall back to scanning the context text.
    status = "unknown"
    for col in _STATUS_COLUMNS:
        if col in row:
            found = infer_status_from_values(row[col])
            if found:
                status = found
                break
    if status == "unknown":
        status = infer_status_from_context(context)

    return _finish_lookup(state, "POLICY", ident, context, status)


def _build_customer_context(state: AgentState, question: str, ident: str) -> AgentState:
    """Fill context/status for a customer by aggregating all of their rows."""
    sql = generate_sql_with_llm(question, kind="CUSTOMER", ident=ident)
    logging.info("Generated SQL:\n%s", sql)

    rows = execute_sql_fetch_all(sql)
    if not rows:
        context = _sample_records_context(
            f"No rows found for CUSTOMER:{ident}. Sample records: "
        )
        return _finish_lookup(state, "CUSTOMER", ident, context, infer_status_from_context(context))

    # Aggregate reasons across rows
    top3, counts = aggregate_reasons(rows)

    # Customer-level status: any not-renewed row wins; "renewed" only when
    # every recognised status value says so.
    found_statuses: List[str] = []
    for r in rows:
        for k in _STATUS_COLUMNS:
            if k in r:
                s = infer_status_from_values(r.get(k))
                if s:
                    found_statuses.append(s)

    if any(s == "not_renewed" for s in found_statuses):
        status = "not_renewed"
    elif found_statuses and all(s == "renewed" for s in found_statuses):
        status = "renewed"
    else:
        status = "unknown"

    # Compact context: aggregate + a few sample rows
    samples_str = "; ".join(
        ", ".join(f"{c}: {sr.get(c)}" for c in all_cols if c in sr)
        for sr in rows[:3]
    )
    context = (
        f"customer_id: {ident}; policies_found: {len(rows)}; "
        f"aggregated_top_3_reasons: {top3}; reason_counts: {counts}; "
        f"samples: {samples_str}"
    )
    return _finish_lookup(state, "CUSTOMER", ident, context, status)


def build_context_and_status(state: AgentState) -> AgentState:
    """Populate ``context`` and ``status`` on *state* for the current question.

    The identifier comes from ``state["customer_id"]`` / ``state["policy_no"]``
    when set (the router may have filled them from memory), otherwise from
    the question text.

    * No identifier: context is three sample rows; MEMORY is left untouched.
    * Policy: one row is fetched by normalized policy number.
    * Customer: all rows returned by the LLM-generated query are aggregated.

    When a lookup finds nothing, the context falls back to sample rows with a
    "not found" prefix. Every policy/customer lookup is persisted to MEMORY.

    Args:
        state: Graph state; must contain ``question``.

    Returns:
        The same state object with ``customer_id``, ``policy_no``, ``context``
        and ``status`` updated.
    """
    q = state["question"]

    # Prefer identifiers set by router_node (which may have reused memory)
    if state.get("customer_id"):
        kind, ident = "CUSTOMER", state["customer_id"]
    elif state.get("policy_no"):
        kind, ident = "POLICY", state["policy_no"]
    else:
        kind, ident = extract_identifier(q)

    state["customer_id"] = ident if kind == "CUSTOMER" else None
    state["policy_no"] = ident if kind == "POLICY" else None

    # No identifier → sample fallback; do not overwrite memory
    if kind is None or ident is None:
        context = _sample_records_context("Sample records: ")
        state["context"] = context
        state["status"] = infer_status_from_context(context)
        return state

    if kind == "POLICY":
        return _build_policy_context(state, ident)
    if kind == "CUSTOMER":
        return _build_customer_context(state, q, ident)
    return state

# ========== Agents ==========
def recommendation_agent(state: AgentState) -> str:
    """Produce a retention recommendation for the question in *state*.

    Returns a fixed "no action" sentence without calling the LLM when the
    status is "renewed" and a customer or policy identifier is set.
    Otherwise sends the prompt below, filled with ``state['context']`` and
    ``state['question']``, to the LLM and returns its stripped reply.
    """
    # Short-circuit: a specific customer/policy already predicted to renew.
    if state.get("status") == "renewed" and (state.get("customer_id") or state.get("policy_no")):
        return "No action is needed. This customer is predicted to renew their policy."
    prompt = f"""
You are a professional customer retention strategist for a car insurance company.

Your job is to analyze the situation based on the given context and provide a clear, helpful recommendation to improve customer retention.

Use the following internal guidelines while crafting your answer (DO NOT mention these in your response):

- Users may ask general or customer-specific questions.

- If the question is customer-specific or policy-specific:
  - Use is_churn to determine the customer’s or policy’s status.
  - If the status is "Not Renewed", provide a smart, personalized recommendation to retain the customer.
  - If the status is "Renewed", simply acknowledge that no action is needed.

- For customer or policy specific queries:
  - Carefully analyze the churn reasons provided, including both `churn_main_reason` and all values listed in `churn_top_3_reasons`.
  - Your recommendation must address all churn reasons mentioned, not just one. For each reason, include a relevant counter-strategy or benefit.
  - Additionally, consider the customer’s or policy’s purchase history, premium values, vehicle type, and policy type to enrich your recommendation.

- If the question is based on churn reasons:
  - Analyze the `churn_main_reason` and `churn_top_3_reasons` - give a relevant, reason-based retention strategy

- If mentioning discounts related, stay in practical ranges (10–30%) unless strong justification; never exceed 30%.

- Don't mention any `churned_customer_segment` related.

Style rules (apply to every answer):
- Deliver a concise, third-person strategy note—never an email.
- Do not address the customer directly (“you”, “we”, “please”).
- Do not ask for missing fields; work with what you have.
- Keep the tone professional, helpful, and concise.
- Format the recommendation in clear bullet points. Each point should focus on a distinct strategy or insight.

Context:
{state['context']}

Question:
{state['question']}

Return only the recommendation.
""".strip()
    return llm.invoke(prompt).content.strip()

def email_agent(state: AgentState) -> str:
    """Draft a retention email/SMS for the question in *state*.

    Returns a fixed "no email is needed" sentence without calling the LLM
    when the status is "renewed" and a customer or policy identifier is set.
    Otherwise sends the prompt below, filled with ``state['context']`` and
    ``state['question']``, to the LLM and returns its stripped reply.
    """
    # Short-circuit: a specific customer/policy already predicted to renew.
    if state.get("status") == "renewed" and (state.get("customer_id") or state.get("policy_no")):
        return "No email is needed because this customer is predicted to renew."
    prompt = f"""
You are an expert in customer retention communication for car insurance.

Write a concise, friendly, proactive email or SMS message.

Follow these internal rules (DO NOT mention them in the output):
- Do NOT mention churn predictions like "you may not renew" or "we noticed you won't renew".

- If the question is about a specific customer wise or policy wise:
    - Carefully review both `churn_main_reason` and all values in `churn_top_3_reasons`, and incorporate them into the message. If multiple reasons are given, the message should reflect a proactive solution or benefit for each one.
    - Use their `churn_main_reason`, `churn_top_3_reasons`, `premium_amount`, `policy_type`, `vehicle_type`, etc. to personalize the message.
    - Offer a relevant benefit like discounts, loyalty reward, etc.

- If the request is about churn_reason (`churn_main_reason`, `churn_top_3_reasons`) related (e.g., "Low Vehicle IDV", "Low discount with NCB", etc.):
    - Write a general reusable template for that reason.
    - Do NOT include any specific customer details (name, policy number, tenure, etc.).
    - Do NOT mention any specific vehicle make/model.
    - Keep the message applicable to all customers with that reason.
    - Provide a persuasive, benefit-focused message addressing the reason and offering a compelling renewal incentive.

- Just assume retention is needed and offer clear value (e.g., discounts, benefits).

- If mentioning discounts related, stay in practical ranges (10–30%) unless strong justification; never exceed 30%.

- Don't mention any `churned_customer_segment` related.

- Keep the tone friendly, persuasive, and action-oriented.

- End with a clear next step: e.g., renew link or support contact.

- Be brief and do not repeat context data unnecessarily.

Context:
{state['context']}

Query:
{state['question']}

Return only the finished message.
""".strip()
    return llm.invoke(prompt).content.strip()

def greeting_agent() -> str:
    """Ask the LLM for a short capability summary and return its stripped reply.

    The prompt is fixed; the user's message is not included.
    """
    prompt = """
You are a helpful assistant for a car insurance retention team.
Explain briefly what you can do: analyze customer contexts, provide renewal/retention recommendations, and draft concise email/SMS messages.
Keep it short and professional.
"""
    return llm.invoke(prompt).content.strip()

# ========== Router with follow-up memory ==========
def router_node(state: AgentState) -> AgentState:
    """Entry node: classify the question and, for agent routes, prepare context.

    Sets ``state["route"]``. Greeting and unknown routes return immediately.
    For the other routes the identifier is taken from the question, or from
    MEMORY when the question looks like a follow-up and names none; the
    state is then completed by build_context_and_status.
    """
    question = state["question"]
    route = classify_intent(question)
    state["route"] = route
    if route in {"GREETING", "UNKNOWN"}:
        return state

    kind, ident = extract_identifier(question)

    # Reuse the remembered identifier for follow-ups without an explicit id.
    nothing_found = kind is None or ident is None
    if nothing_found and looks_like_followup(question):
        if MEMORY.get("last_kind") and MEMORY.get("last_ident"):
            kind, ident = MEMORY["last_kind"], MEMORY["last_ident"]

    state["customer_id"] = ident if kind == "CUSTOMER" else None
    state["policy_no"] = ident if kind == "POLICY" else None

    return build_context_and_status(state)

def reco_node(state: AgentState) -> AgentState:
    """Graph node: store the recommendation agent's output as the response."""
    recommendation = recommendation_agent(state)
    state["response"] = recommendation
    return state

def email_node(state: AgentState) -> AgentState:
    """Graph node: store the email agent's output as the response."""
    draft = email_agent(state)
    state["response"] = draft
    return state

def greet_node(state: AgentState) -> AgentState:
    """Graph node: store the greeting agent's output as the response."""
    greeting = greeting_agent()
    state["response"] = greeting
    return state

# ========== Graph ==========
# One router node fans out to exactly one agent node, which then ends the run.
graph = StateGraph(AgentState)
graph.add_node("router", RunnableLambda(router_node))
graph.add_node("recommendation_agent", RunnableLambda(reco_node))
graph.add_node("email_agent", RunnableLambda(email_node))
graph.add_node("greeting_agent", RunnableLambda(greet_node))
graph.add_node("unknown_agent", RunnableLambda(unknown_node))

graph.set_entry_point("router")
# The mapping keys are the values classify_intent can return.
graph.add_conditional_edges(
    "router",
    lambda s: s["route"],
    {
        "RECOMMENDATION": "recommendation_agent",
        "EMAIL": "email_agent",
        "GREETING": "greeting_agent",
        "UNKNOWN": "unknown_agent",
    },
)
graph.add_edge("recommendation_agent", END)
graph.add_edge("email_agent", END)
graph.add_edge("greeting_agent", END)
graph.add_edge("unknown_agent", END)

# Compiled graph invoked once per chat message.
flow = graph.compile()

# ========== NEW: last email store + Gmail sender ==========
# Most recent email draft and route, shared by the chat handler (writer) and
# the "Send Mail" handler (reader). Module-level, so shared across all users
# of this process.
_LAST_EMAIL_SUBJ = ""
_LAST_EMAIL_BODY = ""
_LAST_ROUTE = "UNKNOWN"

def _guess_subject_from_body(body: str) -> str:
    # Minimal heuristic: first non-empty line, truncated
    for line in (body or "").splitlines():
        s = line.strip()
        if s:
            return s[:80]
    return "Policy Renewal Options"

# ========== UI glue ==========
# Example prompts shown in the chat UI. Each entry is a one-element list, the
# shape passed to gr.ChatInterface(examples=...). The quoted words
# ("Policy_no", "Customer_id", "Reason") are placeholders for the user to edit.
SUGGESTION_TEMPLATES = [
    ["Give a retention recommendation for policy \"Policy_no\"."],
    ["Give a retention recommendation for customer \"Customer_id\"."],
    ["Write an email for policy \"Policy_no\" focusing on renewal and value."],
    ["Write an email for customer \"Customer_id\" focusing on renewal and value."],
    ["Give a retention recommendation for the customers whose having \"Reason\" as a reason."],
    ["Write a mail draft for the customers whose having \"Reason\" as a reason."],
    # The last two rely on the follow-up memory rather than an explicit id.
    ["For this policy, write a retention email draft."],
    ["For this customer, write a retention email draft."]
]

def chat_respond(message: str, history: list) -> str:
    """Chat handler: run the agent graph for *message* and return its response.

    Also records the route and, for genuine email drafts, the draft body and
    a guessed subject so the "Send Mail" button can send them.

    Args:
        message: The user's chat message.
        history: Prior chat turns supplied by Gradio (unused).

    Returns:
        The agent's response text, "No response." when the graph produced
        none, or "Error: <exception>" when processing failed.
    """
    global _LAST_EMAIL_SUBJ, _LAST_EMAIL_BODY, _LAST_ROUTE
    try:
        out = flow.invoke({"question": message})
        # Preserve last draft if this is an email route
        _LAST_ROUTE = out.get("route", "UNKNOWN")
        if _LAST_ROUTE == "EMAIL":
            body = out.get("response", "")
            has_ident = bool(out.get("customer_id") or out.get("policy_no"))
            if out.get("status") == "renewed" and has_ident:
                # Same condition under which email_agent returns its fixed
                # "No email is needed…" notice. That notice is not a draft and
                # must not become sendable, so clear the stored email.
                _LAST_EMAIL_BODY = ""
                _LAST_EMAIL_SUBJ = ""
            else:
                _LAST_EMAIL_BODY = body
                _LAST_EMAIL_SUBJ = _guess_subject_from_body(body)
        return out.get("response", "No response.")
    except Exception as e:
        # UI boundary: keep the chat alive, but leave a traceback in the log.
        logging.exception("chat_respond failed")
        _LAST_ROUTE = "UNKNOWN"
        return f"Error: {e}"

# ========== Keep your ChatInterface exactly, but embed it in a Blocks wrapper (to show Send button) ==========
# Gradio layout: the chat interface plus a small panel for sending the most
# recent email draft through the Gmail API.
with gr.Blocks(title="Retention Assistant") as app:
    # Your original ChatInterface (unchanged)
    demo = gr.ChatInterface(
        fn=chat_respond,
        title="Retention Assistant",
        textbox=gr.Textbox(placeholder="Ask for a recommendation or request an email/SMS draft…"),
        retry_btn="Retry",
        undo_btn="Undo",
        clear_btn="Clear",
        examples=SUGGESTION_TEMPLATES,
        cache_examples=False
    )

    # Small send panel (NEW)
    with gr.Row():
        to_box = gr.Textbox(label="To", value=DEFAULT_TO_EMAIL, scale=4)
        send_btn = gr.Button("Send Mail", visible=True, scale=1)  # stays visible; guarded in handler
    send_status = gr.Markdown("")

    # Click handler (NEW)
    def _on_send(to_addr: str):
        """Send the last email draft to *to_addr* via Gmail and return a status line.

        Returns an explanatory message instead of sending when the last turn
        was not an email draft, the address or body is empty, or SENDER_EMAIL
        is not configured. Any failure while authorising or sending is
        reported as "Send failed: <exception>".
        """
        global _LAST_EMAIL_SUBJ, _LAST_EMAIL_BODY, _LAST_ROUTE
        if _LAST_ROUTE != "EMAIL":
            return "No email draft detected. Ask for an email draft first."
        to_addr = (to_addr or "").strip()
        if not to_addr:
            return "Please provide a 'To' address."
        if not _LAST_EMAIL_BODY:
            return "No email body available."

        # Check configuration before starting OAuth, so a missing sender does
        # not open a browser login that cannot lead to a sent message.
        sender = SENDER_EMAIL or ""
        if not sender:
            return "SENDER_EMAIL not set in environment."

        # Build Gmail service & send
        try:
            # Build service (first time triggers OAuth → creates token.json)
            scopes = ["https://www.googleapis.com/auth/gmail.send"]
            creds = None
            if os.path.exists(GMAIL_TOKEN_FILE):
                creds = Credentials.from_authorized_user_file(GMAIL_TOKEN_FILE, scopes)
            if not creds or not creds.valid:
                if creds and creds.expired and creds.refresh_token:
                    creds.refresh(Request())
                else:
                    # Named oauth_flow so it is not confused with the
                    # module-level `flow` (the compiled agent graph).
                    oauth_flow = InstalledAppFlow.from_client_secrets_file(GMAIL_CREDENTIALS_FILE, scopes)
                    try:
                        creds = oauth_flow.run_local_server(port=0)
                    except Exception:
                        logging.warning("Falling back to console OAuth flow...")
                        # NOTE(review): run_console() is absent from newer
                        # google-auth-oauthlib releases — confirm the installed
                        # version; if missing, the error surfaces as "Send failed".
                        creds = oauth_flow.run_console()
                # Cache the (new or refreshed) token for the next send.
                with open(GMAIL_TOKEN_FILE, "w", encoding="utf-8") as f:
                    f.write(creds.to_json())
            service = build("gmail", "v1", credentials=creds)

            msg = MIMEText(_LAST_EMAIL_BODY, _charset="utf-8")
            msg["to"] = to_addr
            msg["from"] = sender
            msg["subject"] = _LAST_EMAIL_SUBJ or "Policy Renewal Options"

            # The Gmail API expects the RFC 2822 message as URL-safe base64.
            raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
            sent = service.users().messages().send(userId="me", body={"raw": raw}).execute()
            return f"Sent. Gmail message id: {sent.get('id','')}"
        except Exception as e:
            logging.exception("Sending email failed")
            return f"Send failed: {e}"

    send_btn.click(_on_send, inputs=[to_box], outputs=[send_status])

# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    app.launch()

  df = pd.read_sql(f"SELECT * FROM {SOURCE_TABLE}", conn)
  embeds = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
2025-08-18 10:41:51,069 | INFO | Use pytorch device_name: cuda:0
2025-08-18 10:41:51,071 | INFO | Load pretrained SentenceTransformer: BAAI/bge-small-en-v1.5
2025-08-18 10:42:00,977 | INFO | Loading faiss with AVX2 support.
2025-08-18 10:42:01,408 | INFO | Successfully loaded faiss with AVX2 support.
2025-08-18 10:42:01,426 | INFO | Failed to load GPU Faiss: name 'GpuIndexIVFFlat' is not defined. Will not load constructor refs for GPU indexes. This is only an error if you're trying to use GPU Faiss.


Running on local URL:  http://127.0.0.1:7860


2025-08-18 10:42:02,440 | INFO | HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "
2025-08-18 10:42:02,472 | INFO | HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "
2025-08-18 10:42:02,573 | INFO | HTTP Request: GET http://127.0.0.1:7860/startup-events "HTTP/1.1 200 OK"
2025-08-18 10:42:02,941 | INFO | HTTP Request: HEAD http://127.0.0.1:7860/ "HTTP/1.1 200 OK"



To create a public link, set `share=True` in `launch()`.


2025-08-18 10:42:03,301 | INFO | HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
2025-08-18 10:42:03,324 | INFO | HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
  docs = schema_retriever.get_relevant_documents(question)
2025-08-18 10:45:05,973 | INFO | HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
2025-08-18 10:45:06,026 | INFO | Generated SQL:
SELECT churn_top_3_reasons FROM bi_dwh.prediction_agent_data 
WHERE UPPER(REPLACE(REGEXP_REPLACE(policy_no::text, '^''', ''), ' ', '')) = '201130140823100001101000';
2025-08-18 10:45:06,026 | INFO | Corrected SQL:
SELECT * FROM bi_dwh.prediction_agent_data
WHERE UPPER(REPLACE(REGEXP_REPLACE(policy_no::text, '^''', ''), ' ', '')) = '201130140823100001101000';
2025-08-18 10:45:07,639 | INFO | HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
2025-08-18 10:45:57,318 | INFO | HTTP Request: POST https://api.groq.com/openai/v1/chat/

In [None]:
import os
import re
import string
import logging
import difflib
from typing import TypedDict, Optional, Tuple, Dict, Any, List
from collections import Counter

import pandas as pd
import psycopg2
import gradio as gr
from dotenv import load_dotenv
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph, END
from langchain_core.runnables import RunnableLambda
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document

# Gmail imports
import base64
from email.mime.text import MIMEText
from typing import Optional as _Opt
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from googleapiclient.discovery import build

# Root logger: timestamp | level | message, INFO and above.
logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")

# ENV: pull settings from a .env file (python-dotenv) before reading them below.
load_dotenv()
# PostgreSQL connection settings; only the port has a default.
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
# Table name is interpolated directly into SQL text elsewhere in this file,
# so it must come from trusted configuration only.
SOURCE_TABLE = os.getenv("SOURCE_TABLE", "bi_dwh.prediction_agent_data")

# Groq chat model used by the shared LLM client.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
GROQ_MODEL = "meta-llama/llama-4-maverick-17b-128e-instruct"

# Gmail env: OAuth client secrets file, cached token file, sender and default recipient.
GMAIL_CREDENTIALS_FILE = os.getenv("GMAIL_CREDENTIALS_FILE", "credentials.json")
GMAIL_TOKEN_FILE = os.getenv("GMAIL_TOKEN_FILE", "token.json")
SENDER_EMAIL = os.getenv("SENDER_EMAIL", "")
DEFAULT_TO_EMAIL = os.getenv("DEFAULT_TO_EMAIL", "")

# Column names for the customer id and policy number in SOURCE_TABLE.
CUSTOMER_ID_COL = "customerid"
POLICY_NO_COL = "policy_no"

# DB
def fetch_customer_data() -> pd.DataFrame:
    """Read all rows of SOURCE_TABLE into a DataFrame with lower-cased column names.

    A new psycopg2 connection is opened from the DB_* settings and is closed
    again whether or not the query succeeds.
    """
    query = f"SELECT * FROM {SOURCE_TABLE}"
    connection = psycopg2.connect(
        host=DB_HOST,
        port=DB_PORT,
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
    )
    try:
        frame = pd.read_sql(query, connection)
    finally:
        connection.close()
    # Normalise header case so later lookups can use lower-case keys.
    frame.columns = frame.columns.str.lower()
    return frame

# LLM: shared Groq chat client; temperature=0 is passed for every call.
llm = ChatGroq(model=GROQ_MODEL, api_key=GROQ_API_KEY, temperature=0)

# Load table once (columns + sample fallback).
# NOTE: this runs a full-table SELECT at import time.
df_all_customers = fetch_customer_data()
all_cols = list(df_all_customers.columns)

# Tiny schema RAG (embed only column names + one sample value)
def build_schema_docs(df: pd.DataFrame) -> FAISS:
    """Build a FAISS index holding one short description per column of *df*.

    Each description names the column and SOURCE_TABLE and, when the column
    has a non-null value, appends that first value truncated to 120 chars.
    """

    def _describe(col) -> str:
        # Any failure while sampling is treated as "no sample available".
        try:
            first_value = df[col].dropna().iloc[0]
        except Exception:
            first_value = None
        text = f"{col}: column in {SOURCE_TABLE}"
        if first_value is None:
            return text
        return text + f" | sample_value={str(first_value)[:120]}"

    documents = [Document(page_content=_describe(col)) for col in df.columns]
    embedder = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
    return FAISS.from_documents(documents, embedder)

# Index the column descriptions once at import time; each lookup returns the
# 8 closest descriptions for the question being asked.
schema_vs = build_schema_docs(df_all_customers)
schema_retriever = schema_vs.as_retriever(search_kwargs={"k": 8})

def get_schema_snippets(question: str) -> str:
    """Return the column descriptions most relevant to ``question``.

    The descriptions come from the schema retriever and are joined with
    newlines so they can be pasted into an LLM prompt.
    """
    # Retrievers are Runnables: ``invoke`` replaces the deprecated
    # ``get_relevant_documents`` (which emits a deprecation warning on every call).
    docs = schema_retriever.invoke(question)
    return "\n".join(d.page_content for d in docs)

# Lightweight in-process memory for follow-ups.
# One module-level dict: build_context_and_status overwrites it after each identifier
# lookup, and router_node reads it when a follow-up question carries no identifier.
# NOTE(review): being module-global, it is presumably shared by every chat session
# served by this process — confirm that is acceptable.
MEMORY: Dict[str, Any] = {
    "last_kind": None,        # "CUSTOMER" | "POLICY" | None
    "last_ident": None,       # the actual id string
    "last_context": None,     # flattened row context string
    "last_status": None,      # "renewed" | "not_renewed" | "unknown"
}

# Lower-case substrings that mark a message as referring back to the previous lookup.
FOLLOWUP_HINTS = {
    "for this",
    "for that",
    "for above",
    "same",
    "this policy",
    "that policy",
    "this customer",
    "that customer",
    "for this one",
    "use the same",
    "continue with this",
}


def looks_like_followup(text: str) -> bool:
    """Return True when the lower-cased ``text`` contains any follow-up hint."""
    lowered = text.lower()
    for hint in FOLLOWUP_HINTS:
        if hint in lowered:
            return True
    return False

# SQL helpers
def normalize_policy_for_compare(p: str) -> str:
    """Return ``p`` in the canonical form used for policy-number comparison.

    The value is stringified and trimmed, one leading single or double quote
    is dropped, all spaces are removed and the result is upper-cased.
    """
    text = str(p).strip()
    if text[:1] in ("'", '"'):
        text = text[1:]
    return text.replace(" ", "").upper()

# Explicit "policy <token>" reference. The look-ahead demands at least one digit in
# the token, so ordinary words that follow "policy" ("renewal", "number", "holder")
# are never mistaken for a policy number. "policy no", "policy num" and
# "policy number" prefixes are all accepted.
_POLICY_EXPLICIT_RE = re.compile(
    r"policy(?:\s*(?:number|num|no))?\.?\s*#?:?\s*[\"']?"
    r"(?=[A-Za-z0-9\-]*\d)([A-Za-z0-9\-]{6,40})",
    re.IGNORECASE,
)
# Customer id: a run of five or more digits, optionally preceded by "customer".
_CUSTOMER_ID_RE = re.compile(r"(?:customer\s*)?(\d{5,})", re.IGNORECASE)
# Fallback: a long alphanumeric/hyphen token that contains a digit after its first 10 chars.
_LONG_TOKEN_RE = re.compile(r"[\"']?([A-Za-z0-9\-]{10,40}\d[A-Za-z0-9\-]*)")


def extract_identifier(question: str) -> Tuple[Optional[str], Optional[str]]:
    """Find a policy number or customer id in ``question``.

    Patterns are tried in priority order:

    1. an explicit ``policy ...`` reference whose token contains a digit,
    2. a run of five or more digits (treated as a customer id),
    3. any long token containing a digit (treated as a policy number).

    Returns:
        ``("POLICY", token)``, ``("CUSTOMER", digits)`` or ``(None, None)``
        when nothing identifier-like is present.
    """
    q = question.strip()

    m_pol_explicit = _POLICY_EXPLICIT_RE.search(q)
    if m_pol_explicit:
        return "POLICY", m_pol_explicit.group(1)

    m_cust = _CUSTOMER_ID_RE.search(q)
    if m_cust:
        return "CUSTOMER", m_cust.group(1)

    m_long = _LONG_TOKEN_RE.search(q)
    if m_long:
        return "POLICY", m_long.group(1)

    return None, None

def fuzzy_in(text: str, targets, cutoff=0.8) -> bool:
    """Return True when any word of ``text`` is similar enough to any target.

    Words are the ``\\w+`` runs of ``text``; words and targets of two characters
    or fewer are ignored. Similarity is ``difflib.SequenceMatcher.ratio`` and a
    pair matches when the ratio reaches ``cutoff``.
    """
    candidates = [token for token in re.findall(r"\w+", text) if len(token) > 2]
    return any(
        difflib.SequenceMatcher(None, token, target).ratio() >= cutoff
        for token in candidates
        for target in targets
        if len(target) > 2
    )

def is_whole_word(word: str, text: str) -> bool:
    """Return True when ``word`` occurs in ``text`` delimited by word boundaries.

    ``word`` is matched literally (regex metacharacters are escaped).
    """
    bounded = re.compile(r"\b" + re.escape(word) + r"\b")
    return bool(bounded.search(text))

def classify_intent(question: str) -> str:
    """Classify ``question`` as EMAIL, RECOMMENDATION, GREETING or UNKNOWN.

    The question is trimmed, lower-cased and stripped of trailing punctuation.
    A category matches when one of its phrases appears as a whole word/phrase
    or when any word of the question is fuzzily close to one of its phrases.

    Task intents are evaluated before GREETING. Otherwise an incidental "hi",
    "hello" or "help" in a real request (for example "Hi, write an email for
    policy X" or "help me retain customer 12345") would be answered with the
    capabilities blurb instead of the requested output.
    """
    q = question.strip().lower().rstrip(string.punctuation)

    greeting_phrases = ["hi", "hello", "hey", "what can you do", "who are you", "help"]
    email_phrases = [
        "email", "e-mail", "mail", "sms", "whatsapp", "draft",
        "compose", "send mail", "write an email", "write to",
        "message template", "text message"
    ]
    reco_phrases = [
        "recommendation", "recommend", "strategy", "suggest",
        "advice", "plan", "retain", "how to retain",
        "next best action", "improve renewal", "reduce churn"
    ]

    def mentions(phrases, cutoff: float) -> bool:
        # Fuzzy single-word match first, then exact whole-phrase match.
        return fuzzy_in(q, phrases, cutoff) or any(is_whole_word(p, q) for p in phrases)

    if mentions(email_phrases, 0.75):
        return "EMAIL"
    if mentions(reco_phrases, 0.75):
        return "RECOMMENDATION"
    if mentions(greeting_phrases, 0.8):
        return "GREETING"
    return "UNKNOWN"

def generate_sql_with_llm(question: str, kind: Optional[str], ident: Optional[str]) -> str:
    """Ask the LLM for one PostgreSQL SELECT against SOURCE_TABLE for ``question``.

    Args:
        question: The user's message; also used to retrieve schema hints.
        kind: "CUSTOMER", "POLICY" or None; shown to the model as "UNKNOWN" when None.
        ident: The extracted identifier; shown as an empty string when None.

    Returns:
        The first ``SELECT ... ;`` span found in the model reply, or the whole
        stripped reply when no such span exists. The text is not validated here.
    """
    schema_snippets = get_schema_snippets(question)
    ident_str = ident or ""
    kind_str = kind or "UNKNOWN"

    # SQL-side normalisation of the policy column: drop one leading single quote,
    # remove spaces, upper-case.
    norm_sql_expr = f"UPPER(REPLACE(REGEXP_REPLACE({POLICY_NO_COL}::text, '^''', ''), ' ', ''))"

    prompt = f"""
You write a single valid PostgreSQL SELECT for table {SOURCE_TABLE}. No markdown.

Schema hints:
{schema_snippets}

Rules:
- Use only table {SOURCE_TABLE}.
- If the question is customer-specific -> filter as: WHERE {CUSTOMER_ID_COL} = '<id>'.
- If the question is policy-specific -> compare normalized policy numbers:
  Use: {norm_sql_expr} = '<NORMALIZED_ID>'  -- where NORMALIZED_ID strips leading quote and spaces then UPPER.
- Never do JOINs or subqueries. One SELECT only.
- Return *only* the final SQL (end with a semicolon).

Question: {question}
Kind detected: {kind_str}
Identifier: {ident_str}

Examples:
-- customer
SELECT * FROM {SOURCE_TABLE} WHERE {CUSTOMER_ID_COL} = '1467520';

-- policy (normalize compare)
SELECT * FROM {SOURCE_TABLE}
WHERE {norm_sql_expr} = '201140020123100158101000';
"""
    sql = llm.invoke(prompt).content.strip()
    # Keep only the first "SELECT ... ;" span (non-greedy, may cross newlines) so
    # any surrounding chatter or code fences in the reply are discarded.
    m = re.search(r"(SELECT\s.+?;)", sql, flags=re.IGNORECASE | re.DOTALL)
    return m.group(1).strip() if m else sql

def execute_sql_fetch_one(sql: str) -> Optional[Dict[str, Any]]:
    """Run ``sql`` on a fresh connection and return its first row as a dict.

    Keys are the lower-cased result column names. Returns None when the query
    yields no row. The connection is always closed before returning.
    """
    connection = psycopg2.connect(
        host=DB_HOST,
        port=DB_PORT,
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
    )
    try:
        with connection.cursor() as cursor:
            cursor.execute(sql)
            first = cursor.fetchone()
            if first is not None:
                names = [column[0].lower() for column in cursor.description]
                return dict(zip(names, first))
            return None
    finally:
        connection.close()

def execute_sql_fetch_all(sql: str) -> List[Dict[str, Any]]:
    """Run ``sql`` on a fresh connection and return every row as a dict.

    Keys are the lower-cased result column names. Returns an empty list when
    the query yields no rows. The connection is always closed before returning.
    """
    connection = psycopg2.connect(
        host=DB_HOST,
        port=DB_PORT,
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
    )
    try:
        with connection.cursor() as cursor:
            cursor.execute(sql)
            fetched = cursor.fetchall()
            if not fetched:
                return []
            names = [column[0].lower() for column in cursor.description]
            records = []
            for record in fetched:
                records.append(dict(zip(names, record)))
            return records
    finally:
        connection.close()

def aggregate_reasons(rows: List[Dict[str, Any]]) -> tuple[list[str], Dict[str, int]]:
    """Count churn reasons across ``rows``.

    The ``churn_main_reason`` and ``churn_top_3_reasons`` values of every row
    are split on commas and semicolons; each non-empty, trimmed piece counts
    once per occurrence.

    Returns:
        ``(top3, counts)`` where ``top3`` lists the three most frequent reasons
        and ``counts`` maps every reason to its number of occurrences.
    """
    tally = Counter()
    for row in rows:
        for column in ("churn_main_reason", "churn_top_3_reasons"):
            raw = row.get(column)
            if not raw:
                continue
            pieces = (piece.strip() for piece in str(raw).replace(";", ",").split(","))
            tally.update(piece for piece in pieces if piece)
    most_frequent = [reason for reason, _count in tally.most_common(3)]
    return most_frequent, dict(tally)

def infer_status_from_values(value: object) -> Optional[str]:
    """Map a single status cell to "renewed", "not_renewed" or None.

    The value is stringified, trimmed and lower-cased before comparison;
    anything outside the two recognised label sets yields None.
    """
    label = str(value).strip().lower()
    renewed_labels = ("renewed", "yes", "true")
    lapsed_labels = ("not renewed", "no", "false", "churn", "not_renewed")
    if label in renewed_labels:
        return "renewed"
    return "not_renewed" if label in lapsed_labels else None

def infer_status_from_context(ctx: str) -> str:
    """Guess the renewal status from a flattened context string.

    Checks, in order: explicit "renewed" markers, a bare "renewed" with no
    "not renewed" anywhere, then explicit "not renewed" markers or any mention
    of "churn". Returns "renewed", "not_renewed" or "unknown".
    """
    lowered = ctx.lower()
    renewed_markers = (
        "is_churn: renewed",
        "predicted status: renewed",
        "status: renewed",
    )
    lapsed_markers = (
        "is_churn: not renewed",
        "predicted status: not renewed",
        "status: not renewed",
        "churn",
    )
    if any(marker in lowered for marker in renewed_markers):
        return "renewed"
    if "renewed" in lowered and "not renewed" not in lowered:
        return "renewed"
    if any(marker in lowered for marker in lapsed_markers):
        return "not_renewed"
    return "unknown"

# State
class AgentState(TypedDict, total=False):
    """State dictionary passed between the graph nodes.

    ``total=False`` makes every key optional, so nodes may receive a state
    that carries only ``question``.
    """
    question: str  # raw user message
    context: str  # flattened row/sample text inserted into the LLM prompts
    route: str  # label from classify_intent: GREETING | EMAIL | RECOMMENDATION | UNKNOWN
    status: str  # "renewed" | "not_renewed" | "unknown"
    customer_id: Optional[str]  # set when the request targets a customer, else None
    policy_no: Optional[str]  # set when the request targets a policy, else None
    response: str  # final text produced by the selected agent node

def unknown_node(state: AgentState) -> AgentState:
    """Graph node for unrecognised requests: store a fixed refusal as the response."""
    refusal = "Sorry, I’m a retention-strategy agent and can’t answer that question."
    state["response"] = refusal
    return state

# Build context (with memory support)
# Columns that may carry the renewal status, in lookup priority order.
_STATUS_COLUMNS = ("is_churn", "predicted status", "predicted_status", "status")


def _sample_records_context(prefix: str = "") -> str:
    """Return ``prefix`` + a "Sample records: ..." summary of up to three preloaded rows."""
    sample = df_all_customers.head(3)
    return prefix + "Sample records: " + "; ".join(
        ", ".join(f"{c}: {sample.iloc[i][c]}" for c in all_cols)
        for i in range(len(sample))
    )


def _remember_lookup(kind: str, ident: str, context: str, status: str) -> None:
    """Store the latest identifier lookup in MEMORY for follow-up questions."""
    MEMORY["last_kind"] = kind
    MEMORY["last_ident"] = ident
    MEMORY["last_context"] = context
    MEMORY["last_status"] = status


def _status_from_row(row: Dict[str, Any]) -> Optional[str]:
    """Return the first recognisable status found in the row's status columns, else None."""
    for col in _STATUS_COLUMNS:
        if col in row:
            found = infer_status_from_values(row[col])
            if found:
                return found
    return None


def _policy_sql_is_trustworthy(sql: str, norm_expr: str, norm_val: str) -> bool:
    """Return True when LLM SQL selects whole rows filtered by the normalised policy number."""
    upper_sql = sql.upper()
    return (
        re.match(r"\s*SELECT\s+\*\s+FROM\b", sql, flags=re.IGNORECASE) is not None
        and norm_expr.upper() in upper_sql
        # norm_val is already upper-case (see normalize_policy_for_compare).
        and f"'{norm_val}'" in upper_sql
    )


def build_context_and_status(state: AgentState) -> AgentState:
    """Fill ``state["context"]`` and ``state["status"]`` for the current question.

    The identifier comes from ``state`` (``customer_id`` / ``policy_no``) when
    already set, otherwise from ``extract_identifier``. Depending on the result:

    * no identifier: context is a three-row sample of the preloaded table;
    * POLICY: one row is fetched and flattened; status is read from the row's
      status columns, falling back to a text scan of the context;
    * CUSTOMER: all rows are fetched; churn reasons are aggregated and the
      status is "not_renewed" if any row says so, "renewed" if every
      recognised status is renewed, otherwise "unknown".

    Every identifier lookup (including ones that find no rows) is recorded in
    MEMORY so a follow-up question can reuse it.

    Returns:
        The same ``state`` dict, mutated in place.
    """
    q = state["question"]

    if state.get("customer_id"):
        kind, ident = "CUSTOMER", state["customer_id"]
    elif state.get("policy_no"):
        kind, ident = "POLICY", state["policy_no"]
    else:
        kind, ident = extract_identifier(q)

    state["customer_id"] = ident if kind == "CUSTOMER" else None
    state["policy_no"] = ident if kind == "POLICY" else None

    if kind is None or ident is None:
        context = _sample_records_context()
        state["context"] = context
        state["status"] = infer_status_from_context(context)
        return state

    sql = generate_sql_with_llm(q, kind=kind, ident=ident)
    logging.info("Generated SQL:\n%s", sql)

    if kind == "POLICY":
        norm_val = normalize_policy_for_compare(ident)
        guard_norm = f"UPPER(REPLACE(REGEXP_REPLACE({POLICY_NO_COL}::text, '^''', ''), ' ', ''))"
        # The previous check searched for the mixed-case expression inside
        # sql.upper(), which can never match, so the generated SQL was always
        # replaced. Compare case-insensitively, and also insist on "SELECT *"
        # and the expected literal: the context and status below need the
        # whole row for exactly this policy.
        if not _policy_sql_is_trustworthy(sql, guard_norm, norm_val):
            sql = f"SELECT * FROM {SOURCE_TABLE}\nWHERE {guard_norm} = '{norm_val}';"
            logging.info("Corrected SQL:\n%s", sql)

        row = execute_sql_fetch_one(sql)
        if not row:
            context = _sample_records_context(f"No exact row found for {kind}:{ident}. ")
            status = infer_status_from_context(context)
        else:
            context = ", ".join(f"{k}: {row.get(k)}" for k in all_cols if k in row)
            status = _status_from_row(row) or infer_status_from_context(context)

    elif kind == "CUSTOMER":
        # NOTE: the customer query is used exactly as the LLM produced it.
        rows = execute_sql_fetch_all(sql)
        if not rows:
            context = _sample_records_context(f"No rows found for {kind}:{ident}. ")
            status = infer_status_from_context(context)
        else:
            top3, counts = aggregate_reasons(rows)

            # Collect every recognisable status from every row and status column.
            found_statuses: List[str] = []
            for r in rows:
                for k in _STATUS_COLUMNS:
                    if k in r:
                        s = infer_status_from_values(r.get(k))
                        if s:
                            found_statuses.append(s)

            if any(s == "not_renewed" for s in found_statuses):
                status = "not_renewed"
            elif found_statuses and all(s == "renewed" for s in found_statuses):
                status = "renewed"
            else:
                status = "unknown"

            samples_str = "; ".join(
                ", ".join(f"{c}: {sr.get(c)}" for c in all_cols if c in sr)
                for sr in rows[:3]
            )
            context = (
                f"customer_id: {ident}; policies_found: {len(rows)}; "
                f"aggregated_top_3_reasons: {top3}; reason_counts: {counts}; "
                f"samples: {samples_str}"
            )

    else:
        # Unrecognised kind: leave context/status untouched.
        return state

    state["context"] = context
    state["status"] = status
    _remember_lookup(kind, ident, context, status)
    return state

# Agents
def recommendation_agent(state: AgentState) -> str:
    """Produce a retention recommendation for the question held in ``state``.

    When the status is "renewed" and a customer id or policy number is set, a
    fixed "no action" sentence is returned without calling the LLM. Otherwise
    the LLM is prompted with ``state['context']`` and ``state['question']``
    and its stripped reply is returned.
    """
    # Short-circuit: nothing to recommend for an identified customer/policy that renews.
    if state.get("status") == "renewed" and (state.get("customer_id") or state.get("policy_no")):
        return "No action is needed. This customer is predicted to renew their policy."
    prompt = f"""
You are a professional customer retention strategist for a car insurance company.

Your job is to analyze the situation based on the given context and provide a clear, helpful recommendation to improve customer retention.

Use the following internal guidelines while crafting your answer (DO NOT mention these in your response):

- Users may ask general or customer-specific questions.

- If the question is customer-specific or policy-specific:
  - Use is_churn to determine the customer’s or policy’s status.
  - If the status is "Not Renewed", provide a smart, personalized recommendation to retain the customer.
  - If the status is "Renewed", simply acknowledge that no action is needed.

- For customer or policy specific queries:
  - Carefully analyze the churn reasons provided, including both `churn_main_reason` and all values listed in `churn_top_3_reasons`.
  - Your recommendation must address all churn reasons mentioned, not just one. For each reason, include a relevant counter-strategy or benefit.
  - Additionally, consider the customer’s or policy’s purchase history, premium values, vehicle type, and policy type to enrich your recommendation.

- If the question is based on churn reasons:
  - Analyze the `churn_main_reason` and `churn_top_3_reasons` - give a relevant, reason-based retention strategy

- If mentioning discounts related, stay in practical ranges (10–30%) unless strong justification; never exceed 30%.

- Don't mention any `churned_customer_segment` related.

Style rules (apply to every answer):
- Deliver a concise, third-person strategy note—never an email.
- Do not address the customer directly (“you”, “we”, “please”).
- Do not ask for missing fields; work with what you have.
- Keep the tone professional, helpful, and concise.
- Format the recommendation in clear bullet points. Each point should focus on a distinct strategy or insight.

Context:
{state['context']}

Question:
{state['question']}

Return only the recommendation.
""".strip()
    return llm.invoke(prompt).content.strip()

def email_agent(state: AgentState) -> str:
    """Draft a retention email/SMS for the question held in ``state``.

    When the status is "renewed" and a customer id or policy number is set, a
    fixed "no email needed" sentence is returned without calling the LLM.
    Otherwise the LLM is prompted with ``state['context']`` and
    ``state['question']`` and its stripped reply is returned.
    """
    # Short-circuit: no outreach for an identified customer/policy that renews.
    if state.get("status") == "renewed" and (state.get("customer_id") or state.get("policy_no")):
        return "No email is needed because this customer is predicted to renew."
    prompt = f"""
You are an expert in customer retention communication for car insurance.

Write a concise, friendly, proactive email or SMS message.

Follow these internal rules (DO NOT mention them in the output):
- Do NOT mention churn predictions like "you may not renew" or "we noticed you won't renew".

- If the question is about a specific customer wise or policy wise:
    - Carefully review both `churn_main_reason` and all values in `churn_top_3_reasons`, and incorporate them into the message. If multiple reasons are given, the message should reflect a proactive solution or benefit for each one.
    - Use their `churn_main_reason`, `churn_top_3_reasons`, `premium_amount`, `policy_type`, `vehicle_type`, etc. to personalize the message.
    - Offer a relevant benefit like discounts, loyalty reward, etc.

- If the request is about churn_reason (`churn_main_reason`, `churn_top_3_reasons`) related (e.g., "Low Vehicle IDV", "Low discount with NCB", etc.):
    - Write a general reusable template for that reason.
    - Do NOT include any specific customer details (name, policy number, tenure, etc.).
    - Do NOT mention any specific vehicle make/model.
    - Keep the message applicable to all customers with that reason.
    - Provide a persuasive, benefit-focused message addressing the reason and offering a compelling renewal incentive.

- Just assume retention is needed and offer clear value (e.g., discounts, benefits).

- If mentioning discounts related, stay in practical ranges (10–30%) unless strong justification; never exceed 30%.

- Don't mention any `churned_customer_segment` related.

- Keep the tone friendly, persuasive, and action-oriented.

- End with a clear next step: e.g., renew link or support contact.

- Be brief and do not repeat context data unnecessarily.

Context:
{state['context']}

Query:
{state['question']}

Return only the finished message.
""".strip()
    return llm.invoke(prompt).content.strip()

def greeting_agent() -> str:
    """Ask the LLM for a short description of what this assistant can do."""
    instructions = """
You are a helpful assistant for a car insurance retention team.
Explain briefly what you can do: analyze customer contexts, provide renewal/retention recommendations, and draft concise email/SMS messages.
Keep it short and professional.
"""
    reply = llm.invoke(instructions)
    return reply.content.strip()

# Router with follow-up memory
def router_node(state: AgentState) -> AgentState:
    """Entry node: classify the question and prepare context for task routes.

    GREETING and UNKNOWN questions are returned with only ``route`` set. For
    the other routes the identifier is extracted from the question or, for a
    follow-up phrased without one, taken from MEMORY; then the context and
    status are built.
    """
    question = state["question"]
    route = classify_intent(question)
    state["route"] = route
    if route in {"GREETING", "UNKNOWN"}:
        return state

    kind, ident = extract_identifier(question)

    # Reuse the last identifier for follow-ups that carry no explicit id.
    missing_id = kind is None or ident is None
    if missing_id and looks_like_followup(question):
        remembered_kind = MEMORY.get("last_kind")
        remembered_ident = MEMORY.get("last_ident")
        if remembered_kind and remembered_ident:
            kind, ident = remembered_kind, remembered_ident

    state["customer_id"] = ident if kind == "CUSTOMER" else None
    state["policy_no"] = ident if kind == "POLICY" else None

    return build_context_and_status(state)

def reco_node(state: AgentState) -> AgentState:
    """Graph node: store the recommendation agent's output as the response."""
    recommendation = recommendation_agent(state)
    state["response"] = recommendation
    return state

def email_node(state: AgentState) -> AgentState:
    """Graph node: store the email agent's draft as the response."""
    draft = email_agent(state)
    state["response"] = draft
    return state

def greet_node(state: AgentState) -> AgentState:
    """Graph node: store the greeting agent's reply as the response."""
    greeting = greeting_agent()
    state["response"] = greeting
    return state

# Graph
# One node per handler; every node receives and returns an AgentState dict.
graph = StateGraph(AgentState)
graph.add_node("router", RunnableLambda(router_node))
graph.add_node("recommendation_agent", RunnableLambda(reco_node))
graph.add_node("email_agent", RunnableLambda(email_node))
graph.add_node("greeting_agent", RunnableLambda(greet_node))
graph.add_node("unknown_agent", RunnableLambda(unknown_node))

# The router runs first; its "route" value selects the single agent node to run next.
graph.set_entry_point("router")
graph.add_conditional_edges(
    "router",
    lambda s: s["route"],
    {
        "RECOMMENDATION": "recommendation_agent",
        "EMAIL": "email_agent",
        "GREETING": "greeting_agent",
        "UNKNOWN": "unknown_agent",
    },
)
# Every agent node is terminal.
graph.add_edge("recommendation_agent", END)
graph.add_edge("email_agent", END)
graph.add_edge("greeting_agent", END)
graph.add_edge("unknown_agent", END)

# Compiled graph; chat_respond calls flow.invoke({"question": ...}).
flow = graph.compile()

# Email subject/body extraction (NEW)
# Subject used when the draft yields no usable subject text.
SUBJECT_FALLBACK = "Policy Renewal Options"
# Matches a "Subject: ..." line (any case, optional leading whitespace).
SUBJECT_RE = re.compile(r"^\s*subject\s*:\s*(.+)$", re.IGNORECASE)


def extract_subject_and_clean_body(text: str) -> tuple[str, str]:
    """Split an LLM email draft into ``(subject, body)``.

    - If a line starts with 'Subject:' (any case), the text after ':' is the
      subject and that line is removed from the body. Only the first such line
      is consumed.
    - Otherwise the first non-empty line becomes the subject and is removed
      from the body, together with any blank lines before it.
    - On both paths the subject is trimmed of surrounding spaces and quotes,
      relieved of one leftover 'Subject:' prefix, and capped at 78 characters.
    - If no subject text remains, SUBJECT_FALLBACK is returned instead, so the
      subject is never empty.

    Returns:
        ``(subject, body)``; the body has leading blank lines and trailing
        whitespace removed. Empty input yields ``(SUBJECT_FALLBACK, "")``.
    """
    if not text:
        return SUBJECT_FALLBACK, ""

    def unquote(value: str) -> str:
        # Drop surrounding whitespace and wrapping single/double quotes.
        return value.strip().strip("'\"").strip()

    subject: Optional[str] = None
    body_lines: list[str] = []

    for line in text.splitlines():
        m = SUBJECT_RE.match(line)
        if m and subject is None:
            subject = m.group(1).strip()
            continue  # keep the subject line out of the body
        body_lines.append(line)

    if subject is None:
        # No explicit subject: promote the first non-empty line.
        for idx, line in enumerate(body_lines):
            if line.strip():
                subject = line.strip()
                body_lines = body_lines[idx + 1:]
                break

    subject = unquote(subject or "")
    if subject.lower().startswith("subject:"):
        subject = unquote(subject.split(":", 1)[1])

    # Cap to a safe header length, then fall back if nothing usable is left.
    subject = subject[:78].rstrip() or SUBJECT_FALLBACK

    # Strip leading blank lines from the body.
    while body_lines and not body_lines[0].strip():
        body_lines.pop(0)

    clean_body = "\n".join(body_lines).rstrip()
    return subject, clean_body

# NEW: last email store
# Most recent email draft and route. chat_respond writes these; the "Send Mail"
# handler reads them.
# NOTE(review): module-level, so presumably shared by all chat sessions in this
# process — confirm that is acceptable.
_LAST_EMAIL_SUBJ = ""
_LAST_EMAIL_BODY = ""
_LAST_ROUTE = "UNKNOWN"

# UI glue
# Example prompts passed to the chat interface; the quoted placeholders
# ("Policy_no", "Customer_id", "Reason") are meant to be replaced by the user.
SUGGESTION_TEMPLATES = [
    ["Give a retention recommendation for policy \"Policy_no\"."],
    ["Give a retention recommendation for customer \"Customer_id\"."],
    ["Write an email for policy \"Policy_no\" focusing on renewal and value."],
    ["Write an email for customer \"Customer_id\" focusing on renewal and value."],
    ["Give a retention recommendation for the customers whose having \"Reason\" as a reason."],
    ["Write a mail draft for the customers whose having \"Reason\" as a reason."],
    ["For this policy, write a retention email draft."],
    ["For this customer, write a retention email draft."]
]

def chat_respond(message: str, history: list) -> str:
    """Chat callback: run the graph on ``message`` and return the reply text.

    The route of the last turn is recorded, and for EMAIL turns the draft is
    split into subject and body and kept for the "Send Mail" button. Any
    exception is reported to the user as ``"Error: ..."``. ``history`` is
    accepted but not used.
    """
    global _LAST_EMAIL_SUBJ, _LAST_EMAIL_BODY, _LAST_ROUTE
    try:
        result = flow.invoke({"question": message})
        route = result.get("route", "UNKNOWN")
        _LAST_ROUTE = route
        if route == "EMAIL":
            draft = result.get("response", "") or ""
            subject, body = extract_subject_and_clean_body(draft)
            _LAST_EMAIL_SUBJ = subject if subject else SUBJECT_FALLBACK
            _LAST_EMAIL_BODY = body
        reply = result.get("response", "No response.")
    except Exception as exc:
        _LAST_ROUTE = "UNKNOWN"
        return f"Error: {exc}"
    return reply

# Gradio UI: a chat panel plus a "To" box and a button that mails the last email draft.
with gr.Blocks(title="Retention Assistant") as app:
    demo = gr.ChatInterface(
        fn=chat_respond,
        title="Retention Assistant",
        textbox=gr.Textbox(placeholder="Ask for a recommendation or request an email/SMS draft…"),
        retry_btn="Retry",
        undo_btn="Undo",
        clear_btn="Clear",
        examples=SUGGESTION_TEMPLATES,
        cache_examples=False
    )

    with gr.Row():
        to_box = gr.Textbox(label="To", value=DEFAULT_TO_EMAIL, scale=4)
        send_btn = gr.Button("Send Mail", visible=True, scale=1)
    send_status = gr.Markdown("")

    def _on_send(to_addr: str):
        """Send the most recent email draft through the Gmail API.

        Returns a human-readable status string for the UI. All configuration
        and input checks run before any OAuth work, so a missing
        ``SENDER_EMAIL`` no longer triggers an interactive authorisation first.
        """
        if _LAST_ROUTE != "EMAIL":
            return "No email draft detected. Ask for an email draft first."
        to_addr = (to_addr or "").strip()
        if not to_addr:
            return "Please provide a 'To' address."
        if not _LAST_EMAIL_BODY:
            return "No email body available."
        sender = SENDER_EMAIL or ""
        if not sender:
            return "SENDER_EMAIL not set in environment."

        try:
            scopes = ["https://www.googleapis.com/auth/gmail.send"]
            creds = None
            if os.path.exists(GMAIL_TOKEN_FILE):
                creds = Credentials.from_authorized_user_file(GMAIL_TOKEN_FILE, scopes)
            if not creds or not creds.valid:
                if creds and creds.expired and creds.refresh_token:
                    creds.refresh(Request())
                else:
                    # Named oauth_flow so it cannot be confused with the
                    # module-level compiled graph called ``flow``.
                    oauth_flow = InstalledAppFlow.from_client_secrets_file(
                        GMAIL_CREDENTIALS_FILE, scopes
                    )
                    try:
                        creds = oauth_flow.run_local_server(port=0)
                    except Exception:
                        logging.warning("Falling back to console OAuth flow...")
                        creds = oauth_flow.run_console()
                # Persist the new or refreshed token for the next send.
                with open(GMAIL_TOKEN_FILE, "w") as f:
                    f.write(creds.to_json())
            service = build("gmail", "v1", credentials=creds)

            # Use cleaned subject and body
            msg = MIMEText(_LAST_EMAIL_BODY, _subtype="plain", _charset="utf-8")
            msg["to"] = to_addr
            msg["from"] = sender
            msg["subject"] = _LAST_EMAIL_SUBJ or SUBJECT_FALLBACK

            raw = base64.urlsafe_b64encode(msg.as_bytes()).decode("utf-8")
            sent = service.users().messages().send(userId="me", body={"raw": raw}).execute()
            return f"Sent. Gmail message id: {sent.get('id','')}"
        except Exception as e:
            # UI boundary: log the traceback and report the failure to the user.
            logging.exception("Gmail send failed")
            return f"Send failed: {e}"

    send_btn.click(_on_send, inputs=[to_box], outputs=[send_status])

if __name__ == "__main__":
    app.launch()

  df = pd.read_sql(f"SELECT * FROM {SOURCE_TABLE}", conn)
  embeds = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
2025-08-20 14:20:44,933 | INFO | Use pytorch device_name: cuda:0
2025-08-20 14:20:44,934 | INFO | Load pretrained SentenceTransformer: BAAI/bge-small-en-v1.5
2025-08-20 14:20:50,865 | INFO | Loading faiss with AVX2 support.
2025-08-20 14:20:51,102 | INFO | Successfully loaded faiss with AVX2 support.
2025-08-20 14:20:51,115 | INFO | Failed to load GPU Faiss: name 'GpuIndexIVFFlat' is not defined. Will not load constructor refs for GPU indexes. This is only an error if you're trying to use GPU Faiss.


Running on local URL:  http://127.0.0.1:7860


2025-08-20 14:20:51,981 | INFO | HTTP Request: GET http://127.0.0.1:7860/startup-events "HTTP/1.1 200 OK"
2025-08-20 14:20:51,981 | INFO | HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "
2025-08-20 14:20:52,015 | INFO | HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "
2025-08-20 14:20:52,124 | INFO | HTTP Request: HEAD http://127.0.0.1:7860/ "HTTP/1.1 200 OK"



To create a public link, set `share=True` in `launch()`.


2025-08-20 14:20:52,945 | INFO | HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
2025-08-20 14:20:52,974 | INFO | HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
2025-08-20 14:21:15,767 | INFO | HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
  docs = schema_retriever.get_relevant_documents(question)
2025-08-20 14:21:36,344 | INFO | HTTP Request: POST https://api.groq.com/openai/v1/chat/completions "HTTP/1.1 200 OK"
2025-08-20 14:21:36,363 | INFO | Generated SQL:
SELECT claim_approval_rate FROM bi_dwh.prediction_agent_data 
WHERE UPPER(REPLACE(REGEXP_REPLACE(policy_no::text, '^''', ''), ' ', '')) = '201140020123704866505000';
2025-08-20 14:21:36,366 | INFO | Corrected SQL:
SELECT * FROM bi_dwh.prediction_agent_data
WHERE UPPER(REPLACE(REGEXP_REPLACE(policy_no::text, '^''', ''), ' ', '')) = '201140020123704866505000';
2025-08-20 14:21:39,208 | INFO | HTTP Request: POST https://api.groq.com/openai/v1/chat/