# Text-to-SQL Pipeline

## Learning Goals
- Understand the components of a **Text-to-SQL** pipeline.
- Use an LLM to generate SQL queries from natural language questions.
- Execute queries on a sample SQLite database.
- Implement a correction loop when SQL fails.
- Return results in natural language.

This notebook corresponds to Section *1.7 Text-to-SQL* in the lecture notes.

In [7]:
# %load get_llm.py
import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_community.chat_models import ChatOllama

# Load environment variables from .env
load_dotenv()

def get_llm(provider: str = "openai"):
    """
    Return a language model instance configured for either OpenAI or Ollama.

    This function centralizes the initialization of chat-based LLMs so that 
    notebooks and applications can switch seamlessly between cloud-based models 
    (OpenAI) and local models (Ollama).

    Parameters
    ----------
    provider : str, optional
        The backend provider to use. Options are:
        - "openai": returns a ChatOpenAI instance (requires OPENAI_API_KEY in .env).
        - "ollama": returns a ChatOllama instance (requires Ollama installed locally).
        Default is "openai".

    Returns
    -------
    langchain.chat_models.base.BaseChatModel
        A chat model instance that can be invoked with messages.

    Examples
    --------
    Initialize an OpenAI model (requires API key):

    >>> llm = get_llm("openai")
    >>> llm.invoke("Hello, how are you?")

    Initialize a local Ollama model (e.g., Gemma2 2B):

    >>> llm = get_llm("ollama")
    >>> llm.invoke("Summarize the benefits of reinforcement learning.")
    """
    if provider == "openai":
        return ChatOpenAI(
            model="gpt-4o-mini",  # can also be "gpt-4.1" or "gpt-4o"
            temperature=0
        )
    elif provider == "ollama":
        return ChatOllama(
            model="gemma2:2b",   # replace with any local model installed in Ollama
            temperature=0
        )
    else:
        raise ValueError("Unsupported provider. Use 'openai' or 'ollama'.")


In [8]:
import sqlite3
import pandas as pd

# Create a small in-memory database
conn = sqlite3.connect(":memory:")
cursor = conn.cursor()

# Example schema: customers and orders
cursor.execute("""
CREATE TABLE customers (
    id INTEGER PRIMARY KEY,
    name TEXT,
    country TEXT
)
""")

cursor.execute("""
CREATE TABLE orders (
    id INTEGER PRIMARY KEY,
    customer_id INTEGER,
    amount REAL,
    FOREIGN KEY(customer_id) REFERENCES customers(id)
)
""")

# Insert sample data
customers = [(1, "Alice", "Brazil"), (2, "Bob", "USA"), (3, "Charlie", "Brazil")]
orders = [(1, 1, 100.0), (2, 2, 200.0), (3, 1, 150.0)]
cursor.executemany("INSERT INTO customers VALUES (?, ?, ?)", customers)
cursor.executemany("INSERT INTO orders VALUES (?, ?, ?)", orders)
conn.commit()

## Step 1 — Inspect schema
The schema is retrieved with `PRAGMA table_info` so the LLM can be informed about available tables and columns.

In [9]:
def get_schema(cursor):
    schema = {}
    for table in ["customers", "orders"]:
        cursor.execute(f"PRAGMA table_info({table});")
        schema[table] = [row[1] for row in cursor.fetchall()]
    return schema

schema = get_schema(cursor)
print("Database schema:", schema)

Database schema: {'customers': ['id', 'name', 'country'], 'orders': ['id', 'customer_id', 'amount']}


## Step 2 — Prompt LLM to generate SQL
The LLM receives the schema and a question, and is instructed to output only SQL.

In [10]:
from langchain_core.prompts import ChatPromptTemplate

# Assume get_llm() was defined in Notebook 2
llm = get_llm("openai")  # or get_llm("ollama")

sql_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are an expert in SQL. Given a schema and a question, output only a SQL query."),
    ("human", "Schema: {schema}\nQuestion: {question}")
])

def generate_sql(question: str):
    chain = sql_prompt | llm
    return chain.invoke({"schema": schema, "question": question}).content

## Step 3 — Execute SQL with correction loop
The query is executed with pandas. If it fails, the error message is passed back to the LLM for correction.

In [11]:
def run_query(sql: str):
    try:
        return pd.read_sql_query(sql, conn)
    except Exception as e:
        return f"Execution error: {e}"

def text_to_sql(question: str):
    sql = generate_sql(question)
    print("Generated SQL:", sql)
    result = run_query(sql)
    if isinstance(result, str):  # error string
        correction_prompt = f"The SQL failed with error: {result}. Please suggest a corrected SQL query."
        sql = llm.invoke([("human", correction_prompt)]).content
        print("Corrected SQL:", sql)
        result = run_query(sql)
    return result

## Step 4 — Try examples
The pipeline is tested with natural language questions.

In [12]:
print(text_to_sql("List the names of all customers from Brazil."))
print(text_to_sql("What is the total amount of orders per customer?"))

Generated SQL: ```sql
SELECT name FROM customers WHERE country = 'Brazil';
```
Corrected SQL: The error you're encountering is due to the presence of the code block formatting (```sql) in your SQL query. SQL queries should not include these formatting markers. Here’s the corrected SQL query:

```sql
SELECT name FROM customers WHERE country = 'Brazil';
```

Make sure to run the query without the backticks and any additional formatting. Just use the SQL statement as it is shown above.
Execution error: Execution failed on sql 'The error you're encountering is due to the presence of the code block formatting (```sql) in your SQL query. SQL queries should not include these formatting markers. Here’s the corrected SQL query:

```sql
SELECT name FROM customers WHERE country = 'Brazil';
```

Make sure to run the query without the backticks and any additional formatting. Just use the SQL statement as it is shown above.': near "The": syntax error
Generated SQL: ```sql
SELECT c.id, c.name, SUM(o.

### Reflection
- Text-to-SQL pipelines require schema awareness, grounding values, and iterative correction.
- The LLM handles translation from natural language → SQL, but verification is needed.
- In real-world applications, more sophisticated schema linking and safety checks are necessary.

## Exercises
1. Add a new table `products` and extend the schema. Ask multi-table questions.
2. Force the LLM to output results in JSON instead of plain SQL.
3. Replace `get_llm(\"openai\")` with `get_llm(\"ollama\")` to run with a local model.