# NL→SQL Conversational Agent with Snowflake Cortex + Snowpark + LangGraph

This notebook demonstrates how to build an agentic conversational platform that:
- Converts natural language (NL) queries into SQL using Snowflake Cortex LLM
- Validates & executes queries on Snowflake (mocked locally here)
- Summarizes the query results (row count, plus max and average revenue when available)
- Uses LangGraph for state & memory management

In [1]:
import os, re, json, requests, sqlparse, pandas as pd
from typing import Dict, Optional
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.checkpoint.memory import MemorySaver

# -------------------------
# CONFIG
# -------------------------
# When True, the LLM call and the SQL execution are served by the local
# stand-ins defined in this notebook instead of calling out to Snowflake.
MOCK_MODE = True
# Connection settings are read from the environment; the placeholder
# defaults are what you get when the variables are unset.
SNOWFLAKE_ACCOUNT = os.environ.get("SNOWFLAKE_ACCOUNT", "<YOUR_ACCOUNT>")
CORTEX_BEARER_TOKEN = os.environ.get(
    "CORTEX_BEARER_TOKEN", "<YOUR_CORTEX_BEARER_TOKEN>"
)

# -------------------------
# MOCK DATA
# -------------------------
# Five sample orders used as the data source while MOCK_MODE is enabled.
mock_orders = pd.DataFrame(
    dict(
        order_id=list(range(1, 6)),
        customer_id=[10, 11, 10, 12, 13],
        order_date=pd.to_datetime(
            ["2024-09-01", "2024-08-15", "2024-07-20", "2024-06-05", "2024-05-30"]
        ),
        amount=[120.5, 45.0, 78.9, 150.0, 23.4],
    )
)

In [2]:
# Cortex / Mock LLM
def cortex_infer_rest(prompt: str, model: str = "mistral-large2") -> dict:
    """POST a single-turn chat prompt to the Cortex inference endpoint.

    Args:
        prompt: Text sent as the only ``user`` message in the payload.
        model: Model name placed in the request payload.

    Returns:
        The response body decoded from JSON.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``raise_for_status``).
    """
    endpoint = (
        f"https://{SNOWFLAKE_ACCOUNT}.snowflakecomputing.com"
        "/api/v2/cortex/inference:complete"
    )
    response = requests.post(
        endpoint,
        headers={
            "Authorization": f"Bearer {CORTEX_BEARER_TOKEN}",
            "Content-Type": "application/json",
        },
        json={
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
        },
        timeout=60,  # seconds
    )
    response.raise_for_status()
    return response.json()

def mock_llm_generate(prompt: str) -> dict:
    """Return a canned chat-completion-shaped reply, ignoring *prompt*.

    The reply has the ``choices[0].message.content`` layout that
    ``llm_node`` reads, and the content is a JSON document whose ``"sql"``
    key holds a fixed monthly-revenue query.
    """
    # Whitespace (leading newline, indentation, trailing spaces) is kept
    # exactly as a triple-quoted literal indented inside a function body.
    canned_sql = (
        "\n"
        "    SELECT DATE_TRUNC('month', order_date) AS year_month,\n"
        "           SUM(amount) AS total_revenue\n"
        "    FROM ORDERS\n"
        "    WHERE order_date >= DATEADD(month, -12, CURRENT_DATE())\n"
        "    GROUP BY 1\n"
        "    ORDER BY 1 DESC\n"
        "    LIMIT 1000\n"
        "    "
    )
    message = {"content": json.dumps({"sql": canned_sql})}
    return {"choices": [{"message": message}]}

In [3]:
# SQL Validator
def validate_sql(sql_text: str) -> str:
    """Check that *sql_text* is a single SELECT statement and pretty-print it.

    Surrounding whitespace is ignored. A ``LIMIT 1000`` clause is appended
    when the statement has no ``LIMIT`` keyword of its own.

    Args:
        sql_text: Candidate SQL, typically produced by the LLM.

    Returns:
        The statement re-indented by ``sqlparse`` with upper-case keywords.

    Raises:
        ValueError: If the text contains a semicolon or does not begin with
            the ``SELECT`` keyword.
    """
    # Generated SQL is often padded with newlines/indentation; judge the
    # statement itself rather than the whitespace around it.
    statement = sql_text.strip()
    if ";" in statement:
        raise ValueError("Semicolons not allowed")
    if not re.match(r"select\b", statement, flags=re.IGNORECASE):
        raise ValueError("Only SELECT allowed")
    # Match LIMIT as a whole word so identifiers such as ``credit_limit``
    # do not suppress the row cap.
    if not re.search(r"\blimit\b", statement, flags=re.IGNORECASE):
        statement += " LIMIT 1000"
    return sqlparse.format(statement, reindent=True, keyword_case="upper")

In [4]:
# SQL Executor
def execute_sql(sql_text: str) -> pd.DataFrame:
    """Run the query and return its result as a DataFrame.

    In mock mode *sql_text* is not interpreted: the function always returns
    total ``amount`` per calendar month of ``mock_orders`` as columns
    ``year_month`` and ``total_revenue``, newest month first.

    Raises:
        NotImplementedError: When ``MOCK_MODE`` is off, because real
            Snowflake execution is not wired in.
    """
    if not MOCK_MODE:
        # session.sql(sql_text).to_pandas()
        raise NotImplementedError("Snowflake execution not wired here.")

    orders = mock_orders.copy()
    # "YYYY-MM" strings sort chronologically, so a plain string sort works.
    orders["year_month"] = orders["order_date"].dt.to_period("M").astype(str)
    monthly = (
        orders.groupby("year_month", as_index=False)["amount"]
        .sum()
        .rename(columns={"amount": "total_revenue"})
    )
    return monthly.sort_values("year_month", ascending=False)

In [5]:
# Summarizer
def summarize(df: pd.DataFrame) -> str:
    """Build a one-line textual summary of a query result.

    Args:
        df: Result set to describe.

    Returns:
        Row count with max and mean of ``total_revenue`` when that column
        is present and the frame has rows; otherwise just the row count.
    """
    # max()/mean() over zero rows are NaN, which would leak "nan" into the
    # summary, so an empty result gets the generic row-count message.
    if df.empty or "total_revenue" not in df.columns:
        return f"Rows returned: {len(df)}"
    revenue = df["total_revenue"]
    return (
        f"Rows: {len(df)} | Max Revenue: {revenue.max()} "
        f"| Avg Revenue: {revenue.mean():.2f}"
    )

In [6]:
# Graph State
class AgentState(dict):
    """Dictionary-backed state passed between the graph nodes in this notebook.

    The annotations name the keys that the node functions read and write;
    the class itself is handed to ``StateGraph`` as the state schema.
    """

    # Natural-language request; llm_node interpolates it into the prompt.
    query: str
    # SQL text: set by llm_node, then replaced with the validated/formatted
    # form by validate_node.
    sql: Optional[str]
    # DataFrame stored by execute_node.
    result: Optional[pd.DataFrame]
    # Summary text stored by summarize_node.
    summary: Optional[str]
    # Conversation log; each node appends an AIMessage describing its step.
    messages: list

In [7]:
# LangGraph Nodes
def llm_node(state: AgentState) -> AgentState:
    """Ask the LLM (or its mock) for SQL answering ``state['query']``.

    The reply is expected to be a JSON object with an ``"sql"`` key. If it
    is not, the raw reply text is used as the SQL instead. The SQL is stored
    under ``state['sql']`` and a log message is appended to
    ``state['messages']``.

    Args:
        state: Graph state; reads ``query`` and ``messages``.

    Returns:
        The same state object, updated in place.
    """
    schema = "ORDERS(order_id, customer_id, order_date, amount)"
    prompt = f"You are a Snowflake SQL generator. Only return JSON: {{\"sql\":\"<SQL>\"}}. Use {schema}. Task: {state['query']}"
    resp = mock_llm_generate(prompt) if MOCK_MODE else cortex_infer_rest(prompt)
    content = resp["choices"][0]["message"]["content"]
    try:
        sql = json.loads(content)["sql"]
    except (ValueError, KeyError, TypeError):
        # Not the requested JSON envelope (invalid JSON, missing key, or a
        # non-object document): fall back to treating the reply as SQL.
        sql = content
    if not isinstance(sql, str):
        # e.g. {"sql": null} -- keep the raw reply so validation reports a
        # clear error instead of failing on a non-string.
        sql = content
    state["sql"] = sql
    state["messages"].append(AIMessage(content=f"Generated SQL: {sql}"))
    return state

def validate_node(state: AgentState) -> AgentState:
    """Replace ``state['sql']`` with its validated form and log the result.

    Any ``ValueError`` raised by ``validate_sql`` propagates to the caller.
    """
    checked_sql = validate_sql(state["sql"])
    state["sql"] = checked_sql
    note = AIMessage(content=f"Validated SQL: {checked_sql}")
    state["messages"].append(note)
    return state

def execute_node(state: AgentState) -> AgentState:
    """Execute ``state['sql']``, store the frame in ``state['result']`` and log the row count."""
    frame = execute_sql(state["sql"])
    row_count = len(frame)
    state["result"] = frame
    note = AIMessage(content=f"Executed SQL, got {row_count} rows")
    state["messages"].append(note)
    return state

def summarize_node(state: AgentState) -> AgentState:
    """Summarize ``state['result']`` into ``state['summary']`` and log the summary."""
    summary_text = summarize(state["result"])
    state["summary"] = summary_text
    note = AIMessage(content=f"Summary: {summary_text}")
    state["messages"].append(note)
    return state

In [8]:
# Build Graph
# Linear pipeline: llm -> validate -> execute -> summarize -> END.
workflow = StateGraph(AgentState)

_stages = (
    ("llm", llm_node),
    ("validate", validate_node),
    ("execute", execute_node),
    ("summarize", summarize_node),
)
for _stage_name, _stage_fn in _stages:
    workflow.add_node(_stage_name, _stage_fn)

workflow.set_entry_point("llm")

# Chain each stage to the next one, finishing at END.
_route = [_stage_name for _stage_name, _ in _stages] + [END]
for _src, _dst in zip(_route, _route[1:]):
    workflow.add_edge(_src, _dst)

# In-memory checkpointer attached when the graph is compiled.
memory = MemorySaver()
app = workflow.compile(checkpointer=memory)

In [9]:
# Run Conversation
# Use one string for both the state query and the logged user message so the
# printed conversation shows the request that was actually sent to the LLM.
user_query = "Show monthly revenue for last 12 months"
init_state = AgentState(
    query=user_query,
    sql=None,
    result=None,
    summary=None,
    messages=[HumanMessage(content=user_query)],
)
# The graph is compiled with a checkpointer, so every invocation must name
# the conversation thread its checkpoints belong to.
run_config = {"configurable": {"thread_id": "demo-conversation"}}
result = app.invoke(init_state, config=run_config)

print("---- Conversation ----")
for msg in result["messages"]:
    role = "USER" if isinstance(msg, HumanMessage) else "AGENT"
    print(f"{role}: {msg.content}")
print("\nFinal Summary:", result["summary"])
