# 06 — Text-to-SQL (Aligned Pipeline)

> **Course:** Introduction to LLM-Based Agents — SBBD 2025  
> **Notebook goal:** Provide a didactic, step-by-step Text-to-SQL pipeline consistent with the course notes:  
Intent Parsing → Schema Linking → Value Grounding → SQL Generation → Execution & Correction → Answer & Interactive Refinement.

**Important:** The notebook is structured to be read and *then* run block-by-block. Some cells include placeholders (e.g., connection string, `get_llm()`), which must be adapted before execution.

## Pipeline overview & checklist

1. **Tool registration (SQL) & schema snapshot** — register the SQL connection and serialize a compact view of the schema for prompting.  
2. **Intent Parsing** — extract a structured intent (task, entities, filters, time range, grouping, sorting, limit).  
3. **Schema Linking** — retrieve candidate tables/columns and validate a minimal subgraph (tables, columns, join path).  
4. **Value Grounding** — normalize temporal ranges and entity values to match the database contents.  
5. **SQL Generation** — produce a single, safe `SELECT` query for the target dialect (with `LIMIT` fallback).  
6. **Execution & Correction** — run the query; on error, feed the error back to the LLM and minimally correct the SQL.  
7. **Answer & Interactive Refinement** — convert rows into a concise natural-language answer; if empty, propose a clarification.

> The code is written with LangChain's `Runnable` API and a `get_llm()` factory to align with the course preferences.

In [None]:
# (Optional) Dependencies — keep commented unless the environment needs installs.
# !pip install langchain langchain-core langchain-community langchain-openai
# !pip install sqlalchemy
# !pip install faiss-cpu  # or 'chromadb' if preferred as retriever backend

In [None]:
from __future__ import annotations

# Core / typing
from typing import List, Dict, Any, Optional, TypedDict
import json
from dataclasses import dataclass
from datetime import date, datetime, timedelta

# LangChain core
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# LangChain community
from langchain_community.utilities import SQLDatabase
# If using FAISS for catalog retrieval:
# from langchain_community.vectorstores import FAISS
# from langchain_community.embeddings import HuggingFaceEmbeddings

# NOTE: The function get_llm() is expected to be provided by the course utilities.
# It should return a chat model compatible with LangChain Runnable (e.g., ChatOpenAI or ChatOllama).

In [4]:
# %load get_llm.py
import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_community.chat_models import ChatOllama

# Load environment variables from .env
load_dotenv()

def get_llm(provider: str = "openai"):
    """
    Return a language model instance configured for either OpenAI or Ollama.

    This function centralizes the initialization of chat-based LLMs so that 
    notebooks and applications can switch seamlessly between cloud-based models 
    (OpenAI) and local models (Ollama).

    Parameters
    ----------
    provider : str, optional
        The backend provider to use. Options are:
        - "openai": returns a ChatOpenAI instance (requires OPENAI_API_KEY in .env).
        - "ollama": returns a ChatOllama instance (requires Ollama installed locally).
        Default is "openai".

    Returns
    -------
    langchain_core.language_models.chat_models.BaseChatModel
        A chat model instance that can be invoked with messages.

    Examples
    --------
    Initialize an OpenAI model (requires API key):

    >>> llm = get_llm("openai")
    >>> llm.invoke("Hello, how are you?")

    Initialize a local Ollama model (e.g., Gemma2 2B):

    >>> llm = get_llm("ollama")
    >>> llm.invoke("Summarize the benefits of reinforcement learning.")
    """
    if provider == "openai":
        return ChatOpenAI(
            model="gpt-4o-mini",  # can also be "gpt-4.1" or "gpt-4o"
            temperature=0
        )
    elif provider == "ollama":
        return ChatOllama(
            model="gemma2:2b",   # replace with any local model installed in Ollama
            temperature=0
        )
    else:
        raise ValueError("Unsupported provider. Use 'openai' or 'ollama'.")


## 0) Tool registration (SQL) & schema snapshot

Provide a DB connection (`DB_URI`) and create a compact, text-based schema snapshot for prompting. The snapshot should include:
- table names, column names and types,
- primary/foreign keys and obvious relationships,
- a few example values for categorical columns (optional but recommended).
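
As a rough illustration of what such a snapshot could look like, the sketch below builds one with the standard-library `sqlite3` module against a throwaway in-memory database. The `customers` and `tickets` tables and the `snapshot_sqlite` helper are hypothetical, and the `PRAGMA` calls are SQLite-specific; other dialects would need `information_schema` or a SQLAlchemy inspector instead.

```python
import sqlite3

def snapshot_sqlite(conn: sqlite3.Connection, examples_per_col: int = 3) -> str:
    """Render tables, columns, keys and a few sample values as compact text."""
    lines = []
    tables = [r[0] for r in conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()]
    for t in tables:
        lines.append(f"TABLE {t}:")
        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
        for _, name, ctype, _, _, pk in conn.execute(f'PRAGMA table_info("{t}")').fetchall():
            line = f"  - {name} ({ctype})" + (" PK" if pk else "")
            if ctype.upper() == "TEXT":
                # Sample a few distinct values so the LLM sees the real spelling.
                vals = [r[0] for r in conn.execute(
                    f'SELECT DISTINCT "{name}" FROM "{t}" '
                    f'WHERE "{name}" IS NOT NULL ORDER BY 1 LIMIT ?',
                    (examples_per_col,),
                ).fetchall()]
                if vals:
                    line += f" e.g. {vals}"
            lines.append(line)
        # PRAGMA foreign_key_list rows: (id, seq, table, from, to, ...)
        for fk in conn.execute(f'PRAGMA foreign_key_list("{t}")').fetchall():
            lines.append(f"  FK {t}.{fk[3]} -> {fk[2]}.{fk[4]}")
    return "\n".join(lines)

# Hypothetical demo data, for illustration only.
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT);
CREATE TABLE tickets (
    id INTEGER PRIMARY KEY, status TEXT, closed_date TEXT,
    customer_id INTEGER REFERENCES customers(id));
INSERT INTO customers VALUES (1, 'Acme');
INSERT INTO tickets VALUES (1, 'closed', '2025-08-05', 1), (2, 'open', NULL, 1);
""")
schema_text = snapshot_sqlite(conn)
print(schema_text)
```

Sampling only `TEXT` columns is a simplification; a real snapshot would likely restrict examples to low-cardinality columns.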

In [5]:
# --- Configuration ---
DB_URI = "sqlite:///./tickets.sqlite"

# --- Connect (no execution at import time) ---
db = SQLDatabase.from_uri(DB_URI)

SCHEMA_MAX_COLS_PER_TABLE = 12
EXAMPLE_ROWS_PER_COLUMN = 3

def introspect_schema(db: SQLDatabase,
                      max_cols_per_table: int = SCHEMA_MAX_COLS_PER_TABLE,
                      examples_per_col: int = EXAMPLE_ROWS_PER_COLUMN) -> str:
    """Return a compact text representation of the schema for prompting.
    Implementation may query sqlite_master / information_schema under the hood through SQLDatabase.
    Keep it short and readable.
    """
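    # Sketch based on SQLAlchemy's inspector, which works across dialects.
    # NOTE: `db._engine` is a private SQLDatabase attribute; adjust if your LangChain version differs.
    from sqlalchemy import inspect as sa_inspect

    inspector = sa_inspect(db._engine)
    lines = []
    for t in sorted(db.get_usable_table_names()):
        lines.append(f"TABLE {t}:")
        pk_cols = set(inspector.get_pk_constraint(t).get("constrained_columns") or [])
        for c in inspector.get_columns(t)[:max_cols_per_table]:
            pk_tag = " PK" if c["name"] in pk_cols else ""
            lines.append(f"  - {c['name']} ({c['type']}){pk_tag}")
        for fk in inspector.get_foreign_keys(t):
            src, dst = ", ".join(fk["constrained_columns"]), ", ".join(fk["referred_columns"])
            lines.append(f"  FK ({src}) -> {fk['referred_table']} ({dst})")
        # Optionally: append up to `examples_per_col` sample values for categorical columns.
    return "\n".join(lines)  # Empty string when the database has no tables yet.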

# Example (leave commented until DB is available):
# schema_text = introspect_schema(db)
# print(schema_text)

## 1) Intent Parsing

Extract a structured intent from the user question. The model must return **only** the requested JSON schema.
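
Because the prompt asks for `null` on unknown fields, downstream steps may receive `None` where they expect a list or a dict. The sketch below shows one possible way to normalize a raw JSON reply before using it. The `coerce_intent` helper, the `INTENT_DEFAULTS` table, and the sample reply are hypothetical; the reply is hand-written for the question "How many tickets were closed last month?" and is not actual model output.

```python
import copy
import json
from typing import Any, Dict

# Assumed defaults; the limit of 100 mirrors the notebook's default_limit.
INTENT_DEFAULTS: Dict[str, Any] = {
    "task": "select", "entities": [], "measures": [], "filters": {},
    "time_range": {}, "group_by": [], "order_by": [], "limit": 100,
}

def coerce_intent(raw: str) -> Dict[str, Any]:
    """Parse the model's JSON and replace missing or null fields with defaults."""
    data = json.loads(raw)
    if not isinstance(data, dict):
        raise ValueError("intent must be a JSON object")
    intent: Dict[str, Any] = {}
    for key, default in INTENT_DEFAULTS.items():
        value = data.get(key)
        # deepcopy so callers never share the mutable default containers
        intent[key] = value if value is not None else copy.deepcopy(default)
    return intent

raw_reply = json.dumps({
    "task": "count", "entities": ["tickets"], "measures": None,
    "filters": {"status": "closed"}, "time_range": {"relative": "last_month"},
    "group_by": None, "order_by": None, "limit": None,
})
intent = coerce_intent(raw_reply)
print(intent)
```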

In [None]:
class IntentDict(TypedDict):
    task: str             # e.g., 'select', 'count', 'aggregate', 'topk'
    entities: List[str]   # conceptual entities (e.g., 'customers', 'tickets')
    measures: List[str]   # numeric fields or aggregations (e.g., 'sum(total)')
    filters: Dict[str, Any]
    time_range: Dict[str, Any]  # e.g., {'relative': 'last_month'} or {'start': '2025-06-01', 'end': '2025-06-30'}
    group_by: List[str]
    order_by: List[str]
    limit: int

intent_prompt = ChatPromptTemplate.from_messages([
    ("system",
     """Extract a database intent as strict JSON with keys:
     task, entities, measures, filters, time_range, group_by, order_by, limit.
     Return **ONLY** the JSON object, with null for unknown fields."""),
    ("user", "{question}")
])

# parse_intent = intent_prompt | get_llm().with_structured_output(IntentDict)  # To be enabled after get_llm()

# Example usage (keep commented):
# user_question = "How many tickets were closed last month?"
# intent: IntentDict = parse_intent.invoke({"question": user_question})
# print(intent)

## 2) Schema Linking

Create a textual catalog of `table.column – type – notes` entries, use embeddings to retrieve candidates, then validate a **minimal subgraph** (tables, columns, join path) via LLM.
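
To make the retrieval step concrete without installing an embedding model, the sketch below swaps in a much simpler technique: ranking catalog entries by token overlap with the question. It is a lexical stand-in for embedding similarity, not a replacement for it. The catalog entries and the `retrieve_candidates` helper are hypothetical.

```python
import re
from typing import List, Set

def tokenize(text: str) -> Set[str]:
    """Lowercase, split identifiers on underscores, keep alphanumeric tokens."""
    return set(re.findall(r"[a-z0-9]+", text.lower().replace("_", " ")))

def retrieve_candidates(question: str, catalog: List[str], k: int = 3) -> List[str]:
    """Return up to k catalog entries sharing the most tokens with the question."""
    q_tokens = tokenize(question)
    scored = [(len(q_tokens & tokenize(entry)), entry) for entry in catalog]
    scored = [pair for pair in scored if pair[0] > 0]   # drop entries with no overlap
    scored.sort(key=lambda pair: (-pair[0], pair[1]))   # best score first, then alphabetical
    return [entry for _, entry in scored[:k]]

# Hypothetical catalog in the `table.column - type - notes` layout.
catalog = [
    "tickets.id - INTEGER - primary key",
    "tickets.status - TEXT - open or closed",
    "tickets.closed_date - TEXT - ISO date the ticket was closed",
    "customers.name - TEXT - customer display name",
]
question = "How many tickets were closed last month?"
candidates = retrieve_candidates(question, catalog, k=3)
for c in candidates:
    print(c)
```

Token overlap misses synonyms ("resolved" vs. "closed"), which is the gap embeddings are meant to cover; the validator LLM then prunes whatever the retriever over-selects.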

In [None]:
def build_catalog_documents(schema_text: str) -> List[str]:
    """Split the schema snapshot into small, retrievable chunks (1 per column or section).
    Return a list of short strings.
    """
    if not schema_text:
        return []
    # Simple splitter (improve as needed):
    docs = [line.strip() for line in schema_text.splitlines() if line.strip()]
    return docs

def summarize_candidates(candidates: List[Any], max_len: int = 1200) -> str:
    """Concatenate top-k candidate snippets into a single string for the validator LLM."""
    text = "\n".join(getattr(c, "page_content", str(c)) for c in candidates)
    return text[:max_len]

link_prompt = ChatPromptTemplate.from_messages([
    ("system",
     """You are validating schema usage for Text-to-SQL.
Given the candidates below and the question, list ONLY:
- relevant tables,
- relevant columns per table,
- the minimal join path (as pairs table.col -> table.col),
in a compact JSON with keys {{tables, columns, joins}}. Return only JSON."""),
    ("user", "Question: {question}\n\nCandidates:\n{cands}")
])

# Example (commented until retriever is configured):
# embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# vs = FAISS.from_texts(build_catalog_documents(schema_text), embedder)
# schema_retriever = vs.as_retriever(search_kwargs={"k": 12})
# candidates = schema_retriever.invoke(user_question)  # retrievers are Runnables; get_relevant_documents is deprecated
# linked_schema = (link_prompt | get_llm()).invoke({"question": user_question, "cands": summarize_candidates(candidates)}).content
# print(linked_schema)

## 3) Value Grounding

Normalize temporal ranges and map entity names to existing values before generating SQL.
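
For entity values, one simple option is fuzzy string matching against the distinct values already stored in a column. The sketch below uses `difflib.get_close_matches` from the standard library; the `ground_value` helper, the status list, and the 0.6 cutoff are assumptions chosen for illustration.

```python
import difflib
from typing import List, Optional

def ground_value(mention: str, known_values: List[str], cutoff: float = 0.6) -> Optional[str]:
    """Map a user-supplied mention to the closest stored value, or None if nothing is close."""
    by_lower = {v.lower(): v for v in known_values}   # compare case-insensitively
    hits = difflib.get_close_matches(mention.lower(), list(by_lower), n=1, cutoff=cutoff)
    return by_lower[hits[0]] if hits else None

# Hypothetical distinct values sampled from a status column.
statuses = ["open", "closed", "pending"]
print(ground_value("Closed", statuses))
print(ground_value("close", statuses))
print(ground_value("archived", statuses))
```

Returning `None` rather than a weak match lets the pipeline ask the user for clarification instead of silently filtering on the wrong value.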

In [None]:
def month_bounds(ref: date) -> tuple[date, date]:
    start = ref.replace(day=1)
    if start.month == 12:
        end = start.replace(year=start.year+1, month=1, day=1) - timedelta(days=1)
    else:
        end = start.replace(month=start.month+1, day=1) - timedelta(days=1)
    return start, end

def normalize_time(time_range: Optional[Dict[str, Any]], ref_date: Optional[date] = None) -> Dict[str, str]:
    """Map relative time ranges to explicit YYYY-MM-DD bounds."""
    if ref_date is None:
        ref_date = date.today()
    if not time_range:
        return {}
    rel = str(time_range.get("relative") or "").lower()
    if rel == "last_month":
        # The day before the first of this month always falls in the previous month.
        last_month_end = ref_date.replace(day=1) - timedelta(days=1)
        start, end = month_bounds(last_month_end)
        return {"start": start.isoformat(), "end": end.isoformat()}
    # Accept explicit bounds only when both are present and non-null.
    if time_range.get("start") and time_range.get("end"):
        return {"start": str(time_range["start"]), "end": str(time_range["end"])}
    return {}

def lookup_entities(entities: List[str], db: Optional[SQLDatabase], linked_schema_json: Any) -> Dict[str, Any]:
    """Attempt to ground entity strings to actual DB values.
    Implementation here is a placeholder; typical approach:
      - For each entity-like field found in linked_schema_json, sample distinct values and fuzzy-match.
    """
    grounded = {}
    # Example stub — extend with real queries over `db`:
    for e in entities or []:
        grounded[e] = {"matched": e, "confidence": 0.5}
    return grounded

## 4) SQL Generation (guardrails)

Generate a single `SELECT` query for the chosen dialect, using only the linked schema and grounded values. Enforce safe defaults:
- **No** DDL/DML or multiple statements,
- add `LIMIT` if missing,
- include a short SQL comment with the intent summary.
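
Prompt instructions alone do not guarantee these defaults, so it can help to check the generated SQL in code before executing it. The following is a deliberately conservative sketch, not a SQL parser: the `enforce_guardrails` helper and its keyword list are assumptions, and it may reject valid queries whose string literals contain a semicolon or a blocked keyword.

```python
import re

# Assumed blocklist of statement keywords; extend per dialect.
FORBIDDEN = re.compile(
    r"\b(insert|update|delete|drop|alter|create|truncate|attach|pragma|grant)\b",
    re.IGNORECASE,
)

def strip_sql_comments(sql: str) -> str:
    """Remove /* ... */ and -- ... comments (naive: ignores string literals)."""
    sql = re.sub(r"/\*.*?\*/", " ", sql, flags=re.DOTALL)
    return re.sub(r"--[^\n]*", " ", sql)

def enforce_guardrails(sql: str, default_limit: int = 100) -> str:
    """Return a single read-only statement with a LIMIT, or raise ValueError."""
    # Keep a leading one-line comment (the intent summary) so it survives the checks.
    first_line = sql.lstrip().partition("\n")[0]
    header = first_line + "\n" if first_line.startswith("--") else ""
    body = strip_sql_comments(sql).strip().rstrip(";").strip()
    if not body:
        raise ValueError("empty statement")
    if ";" in body:
        raise ValueError("multiple statements are not allowed")
    if not re.match(r"(?is)^(select|with)\b", body):
        raise ValueError("only SELECT queries are allowed")
    if FORBIDDEN.search(body):
        raise ValueError("DDL/DML keyword found")
    if not re.search(r"\blimit\s+\d+", body, re.IGNORECASE):
        body += f" LIMIT {default_limit}"
    return header + body

def is_rejected(sql: str) -> bool:
    """Convenience wrapper for demos: True when the guard raises."""
    try:
        enforce_guardrails(sql)
        return False
    except ValueError:
        return True

safe = enforce_guardrails(
    "-- intent: count closed tickets\n"
    "SELECT COUNT(*) FROM tickets WHERE status = 'closed';"
)
print(safe)
print(is_rejected("DROP TABLE tickets"))
print(is_rejected("SELECT 1; DELETE FROM tickets"))
```

The `LIMIT` check is also approximate: a `LIMIT` inside a subquery would satisfy it even if the outer query is unbounded.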

In [None]:
def extract_sql_block(text: Any) -> str:
    """Extract the SQL code block from an LLM response; return raw SQL string.

    Accepts a plain string or a chat message exposing `.content` (what a chat model returns).
    """
    if text is None:
        return ""
    # Unwrap message objects so the function can sit directly after get_llm() in a chain.
    text = str(getattr(text, "content", text))
    fence = "```"
    if fence in text:
        # Splitting on the fence puts fenced content at the odd indices.
        for candidate in text.split(fence)[1::2]:
            # Strip a leading language identifier line (e.g., 'sql\n').
            candidate = candidate.split("\n", 1)[-1]
            if ";" in candidate or "SELECT" in candidate.upper():
                return candidate.strip()
    return text.strip()

sql_prompt = ChatPromptTemplate.from_messages([
    ("system",
     """You are a SQL generator. Follow STRICT rules:
- Use ONLY the provided linked schema: {linked}
- Respect grounded values: {grounded}
- Task intent (JSON): {intent}
- Dialect: {dialect}
- Generate exactly one SELECT statement (no DDL/DML, no multiple statements).
- If LIMIT is unspecified, add LIMIT {default_limit}.
- Add a leading SQL comment with a one-line intent summary."""),
    ("user", "{question}")
])

# generate_sql = sql_prompt | get_llm() | RunnableLambda(extract_sql_block)  # Enable after get_llm()

## 5) Execution & Correction loop

Run the SQL. On error, provide the error message to the LLM and request a **minimal correction** using the same linked schema and grounded values.
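
The control flow can be tried end to end without an LLM by substituting a stub for the repair step. The sketch below runs against an in-memory SQLite database and records each `(SQL, error)` pair. The `execute_with_repair` and `stub_fixer` names are hypothetical, and the stub's hard-coded column rename only stands in for what the model would be asked to do.

```python
import sqlite3
from typing import Callable, Dict, List, Optional, Tuple

History = List[Tuple[str, str]]

def execute_with_repair(
    conn: sqlite3.Connection,
    sql: str,
    fixer: Callable[[str, str], str],
    max_retries: int = 2,
) -> Tuple[str, List[Dict], Optional[str], History]:
    """Run sql; on failure ask fixer(sql, error) for a new query, up to max_retries times."""
    history: History = []
    for attempt in range(max_retries + 1):
        try:
            cur = conn.execute(sql)
            cols = [d[0] for d in cur.description]
            rows = [dict(zip(cols, r)) for r in cur.fetchall()]
            return sql, rows, None, history
        except sqlite3.Error as exc:
            err = str(exc)
            history.append((sql, err))          # keep the failed attempt for inspection
            if attempt == max_retries:
                return sql, [], err, history
            sql = fixer(sql, err)
    return sql, [], "unreachable", history

def stub_fixer(sql: str, error: str) -> str:
    """Stand-in for the LLM: repairs one known column-name mistake."""
    return sql.replace("closed_at", "closed_date") if "no such column" in error else sql

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tickets (id INTEGER PRIMARY KEY, status TEXT, closed_date TEXT)")
conn.executemany("INSERT INTO tickets VALUES (?, ?, ?)", [
    (1, "closed", "2025-08-05"), (2, "open", None), (3, "closed", "2025-08-22"),
])
bad_sql = ("SELECT COUNT(*) AS n FROM tickets "
           "WHERE closed_at BETWEEN '2025-08-01' AND '2025-08-31'")
final_sql, rows, err, history = execute_with_repair(conn, bad_sql, stub_fixer)
print(final_sql)
print(rows)
print(history)
```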

In [None]:
def run_sql(db: SQLDatabase, sql: str) -> tuple[list[dict], Optional[str]]:
    try:
        rows = db.run(sql)
        # By default `db.run` returns the result rendered as a string (e.g., "[(2,)]"),
        # not a list of dicts; parse it or query through SQLAlchemy if structured rows are needed.
        return rows, None
    except Exception as e:
        return [], str(e)

fix_prompt = ChatPromptTemplate.from_messages([
    ("system",
     """The previous SQL failed with the error below. Return a minimal corrected SQL.
- Keep the SAME intent ({intent}) and grounded values ({grounded}).
- Use ONLY the same linked schema: {linked}
- Target dialect: {dialect}
- Return a single SELECT statement.
Error: {error}"""),
    ("user", "Original SQL:\n{sql}")
])

def correction_loop(db: SQLDatabase, sql: str, linked: str, intent: Dict[str, Any],
                    grounded: Dict[str, Any], dialect: str = "PostgreSQL",
                    max_retries: int = 2, default_limit: int = 100) -> tuple[str, list[dict], Optional[str]]:
    rows, err = run_sql(db, sql)
    if not err:
        return sql, rows, None
    for _ in range(max_retries):
        # repaired = (fix_prompt | get_llm() | RunnableLambda(extract_sql_block)).invoke(
        #     {"error": err, "sql": sql, "linked": linked,
        #      "intent": json.dumps(intent), "grounded": json.dumps(grounded),
        #      "dialect": dialect, "default_limit": default_limit}
        # )
        repaired = sql  # placeholder to keep structure; replace with the line above.
        rows, err = run_sql(db, repaired)
        sql = repaired
        if not err:
            break
    return sql, rows, err

## 6) Answer & Interactive Refinement

Convert result rows into a concise natural-language answer; when the result is empty, propose a clarifying question.
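
Before calling the LLM, a deterministic fallback can cover the trivial cases (no rows, or a single scalar). The sketch below is one possible shape for it; the `draft_answer` helper and its wording are hypothetical, and it assumes rows arrive as a list of dicts.

```python
from typing import Any, Dict, List

def draft_answer(rows: List[Dict[str, Any]], grounded: Dict[str, Any]) -> str:
    """Produce a short answer, or a clarifying question when the result is empty."""
    window = grounded.get("time_range") or {}
    span = ""
    if "start" in window and "end" in window:
        span = f" between {window['start']} and {window['end']}"
    if not rows:
        return f"No matching rows were found{span}. Should the time window or filters be widened?"
    if len(rows) == 1 and len(rows[0]) == 1:
        # Single scalar result, e.g. a COUNT(*).
        (name, value), = rows[0].items()
        return f"{name} = {value}{span}."
    return f"{len(rows)} rows returned{span}; first row: {rows[0]}."

grounded_demo = {"time_range": {"start": "2025-08-01", "end": "2025-08-31"}}
print(draft_answer([{"n": 2}], grounded_demo))
print(draft_answer([], grounded_demo))
```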

In [None]:
answer_prompt = ChatPromptTemplate.from_messages([
    ("system",
     """Summarize the SQL result into a concise, neutral answer.
Include the time window and key filters if available.
If there are no rows, propose a short clarifying question."""),
    ("user", "Rows (JSON): {rows}\nIntent: {intent}\nGrounded: {grounded}")
])

# final_answer = (answer_prompt | get_llm())  # Enable after get_llm()

### (Optional) Create a tiny SQLite demo

This cell creates a miniature dataset for quick local tests. Keep it commented in production.

In [3]:
# import sqlite3
# conn = sqlite3.connect("./tickets.sqlite")
# cur = conn.cursor()
# cur.execute("CREATE TABLE IF NOT EXISTS tickets (id INTEGER PRIMARY KEY, status TEXT, closed_date TEXT);")
# cur.execute("DELETE FROM tickets;")
# cur.executemany("INSERT INTO tickets (id, status, closed_date) VALUES (?, ?, ?);", [
#     (1, "closed", "2025-08-05"),
#     (2, "open",   None),
#     (3, "closed", "2025-08-22"),
# ])
# conn.commit(); conn.close()

## End-to-end driver (commented)

The following cell shows the end-to-end sequence. Uncomment after wiring `get_llm()` and setting `DB_URI`.

In [None]:
# user_question = "How many tickets were closed last month?"
#
# # 0) Tool registration
# db = SQLDatabase.from_uri(DB_URI)
# schema_text = introspect_schema(db)
#
# # 1) Intent
# parse_intent = intent_prompt | get_llm().with_structured_output(IntentDict)
# intent: IntentDict = parse_intent.invoke({"question": user_question})
#
# # 2) Schema Linking
# # (build retriever)
# # embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# # vs = FAISS.from_texts(build_catalog_documents(schema_text), embedder)
# # schema_retriever = vs.as_retriever(search_kwargs={"k": 12})
# # candidates = schema_retriever.invoke(user_question)
# # linked_schema = (link_prompt | get_llm()).invoke({
# #     "question": user_question, "cands": summarize_candidates(candidates)
# # }).content  # keep the text, not the message object
#
# linked_schema = "{}"  # temporary if retriever is not configured
#
# # 3) Value Grounding
# grounded = {
#     "time_range": normalize_time(intent.get("time_range"), ref_date=date(2025, 9, 1)),
#     "entities": lookup_entities(intent.get("entities"), db if 'db' in locals() else None, linked_schema)
# }
#
# # 4) SQL Generation
# generate_sql = sql_prompt | get_llm() | RunnableLambda(extract_sql_block)
# sql = generate_sql.invoke({
#     "linked": linked_schema, "grounded": json.dumps(grounded),
#     "intent": json.dumps(intent), "dialect": "SQLite", "default_limit": 100,
#     "question": user_question
# })
#
# # 5) Execute & correct
# final_sql, rows, err = correction_loop(db, sql, linked_schema, intent, grounded, dialect="SQLite")
#
# # 6) Answer
# # final_answer_chain = answer_prompt | get_llm()
# # answer = final_answer_chain.invoke({
# #     "rows": json.dumps(rows), "intent": json.dumps(intent), "grounded": json.dumps(grounded)
# # })
# # print(answer.content)
#
# print("SQL:\n", final_sql)
# print("Rows:\n", rows)
# print("Error:\n", err)

## Notes & tips

- Keep the schema snapshot **short**; prune aggressively to minimize prompt tokens.  
- Persist intermediate artifacts (intent JSON, linked schema JSON, grounded values) for debugging.  
- In academic demos, add a short `-- intent:` comment line at the top of the generated SQL to aid students' understanding.  
- Cap retries in the correction loop and log `(SQL → error → fix)` history for transparency.