In [1]:
from langchain_community.utilities import SQLDatabase
from langchain_community.llms import Ollama
from langchain.chains import create_sql_query_chain
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
import re

In [2]:
db = SQLDatabase.from_uri("postgresql://postgres:postgres@localhost:5432/dvdrental")
print(db.get_usable_table_names())  # sanity-check


['actor', 'address', 'category', 'city', 'country', 'customer', 'film', 'film_actor', 'film_category', 'inventory', 'language', 'payment', 'rental', 'staff', 'store']


In [3]:
schema_descr = db.get_table_info()  # full ER-style description
print(schema_descr)


CREATE TABLE actor (
	actor_id SERIAL NOT NULL, 
	first_name VARCHAR(45) NOT NULL, 
	last_name VARCHAR(45) NOT NULL, 
	last_update TIMESTAMP WITHOUT TIME ZONE DEFAULT now() NOT NULL, 
	CONSTRAINT actor_pkey PRIMARY KEY (actor_id)
)

/*
3 rows from actor table:
actor_id	first_name	last_name	last_update
1	Penelope	Guiness	2013-05-26 14:47:57.620000
2	Nick	Wahlberg	2013-05-26 14:47:57.620000
3	Ed	Chase	2013-05-26 14:47:57.620000
*/


CREATE TABLE address (
	address_id SERIAL NOT NULL, 
	address VARCHAR(50) NOT NULL, 
	address2 VARCHAR(50), 
	district VARCHAR(20) NOT NULL, 
	city_id SMALLINT NOT NULL, 
	postal_code VARCHAR(10), 
	phone VARCHAR(20) NOT NULL, 
	last_update TIMESTAMP WITHOUT TIME ZONE DEFAULT now() NOT NULL, 
	CONSTRAINT address_pkey PRIMARY KEY (address_id), 
	CONSTRAINT fk_address_city FOREIGN KEY(city_id) REFERENCES city (city_id)
)

/*
3 rows from address table:
address_id	address	address2	district	city_id	postal_code	phone	last_update
1	47 MySakila Drive	None	Alberta	30

In [4]:
#  Prompt → produce a mini brief prompt that provides only neccessary tables and columns
from langchain_core.prompts import PromptTemplate

focused_schema_prompt = PromptTemplate.from_template(
    """
    You are a senior data engineer helping an LLM write SQL.

    INPUTS
    • QUESTION: {question}
    • RAW_SCHEMA: {schema}

    TASK
    1. Read the QUESTION first. List columns required by QUESTION. Decide which tables are **absolutely required**.
       • Pick the neccessary tables needed to answer the QUESTION.
       • Do NOT include tables whose columns will not appear in SELECT, WHERE, or JOIN.
    2. For each required table, output:
       - One-sentence purpose.
       - Bullet list of relevant columns:
         PK: primary key,FK: foreign key (mention referenced table), human-readable/business/description fields
    3. Then add:
       Necessary columns for query: <comma-separated list>
       Key relationships: <one short sentence per join path>

    STYLE
    • Use markdown bullets.
    • Keep the whole description under 500 words.
    """
)



In [5]:
from langchain.schema.runnable import RunnablePassthrough
from langchain_community.llms import Ollama

llm = Ollama(model="llama3.1", temperature =0) 

focused_schema_chain = (
    {"question": RunnablePassthrough(), "schema": RunnablePassthrough()}
    | focused_schema_prompt
    | llm
)


  llm = Ollama(model="llama3.1", temperature =0)


In [None]:
# focused schema chain working example
user_question = "Which three film categories brought in the highest total rental revenue across all stores in 2007?"
schema_text = db.get_table_info()           # your long DDL

mini_description = focused_schema_chain.invoke(
    {"question": user_question, "schema": schema_text}
)

print(mini_description)


**TASK**

## Step 1: Read the QUESTION and identify required tables

The QUESTION is:
"Which tables are absolutely required to answer this question?"

After reading the QUESTION, we can list the columns required by the QUESTION:

* `film_id`
* `category_id`

Based on these columns, we can decide which tables are **absolutely required**.

## Step 2: Identify necessary tables and their purposes

The necessary tables needed to answer the QUESTION are:

### film_category table
Purpose: Stores relationships between films and categories.
Relevant columns:
* `film_id` (PK): primary key referencing the film table, FK referencing the film table
* `category_id` (PK): primary key referencing the category table, FK referencing the category table

### film table
Purpose: Stores information about individual films.
Relevant columns:
* `film_id` (PK): primary key

### category table
Purpose: Stores categories of films.
Relevant columns:
* `category_id` (PK): primary key

## Step 3: Add necessary colum

In [7]:
# natural language → SQL (no execution)
write_query_chain  = create_sql_query_chain(llm=llm, db=db)



In [8]:
response = write_query_chain.invoke({"question": schema_descr + " How many films are in the catalogue?"})
response

'SELECT COUNT(film_id) FROM film'

In [21]:
# Helper that keeps the original question and mini-description
def build_sql(inputs):
    mini_desc = focused_schema_chain.invoke(
        {"question": inputs["question"], "schema": schema_descr}
    )
    prompt = mini_desc + "\n\n### Question: " + inputs["question"]
    sql_text = write_query_chain.invoke({"question": prompt})
    return {"question": inputs["question"], "sql_response": sql_text}

sql_builder = RunnableLambda(build_sql)


In [22]:
# Extract SQL, clean it, execute, return rows
def exec_sql(inputs):
    raw = inputs["sql_response"]

    # Priority 1 ─ fenced block ```sql … ```
    block = re.search(r"```sql\s*(.*?)\s*```", raw, re.I | re.S)
    if block:
        sql = block.group(1).strip()
    else:
        # Priority 2 ─ text after 'SQLQuery:' (may span lines)
        after = re.search(r"SQLQuery:\s*(.*)", raw, re.S)
        if after:
            sql = after.group(1).strip()
        else:
            #  last-chance fallback – grab the first SELECT/WITH until the first semicolon
            bare = re.search(r"(?i)\b(SELECT|WITH).*?;", raw, re.S)
            sql = bare.group(0).strip() if bare else ""

    # Remove any stray fences the regex missed
    sql = sql.replace("```", "").strip()

    # Run and handle errors gracefully
    try:
        rows = db.run(sql) if sql else []
    except Exception as e:
        rows = f"Execution error: {e}"

    return {
        "question": inputs["question"],
        "query": sql,
        "result": rows
    }

sql_executor = RunnableLambda(exec_sql)


In [24]:
sql = """
SELECT 
    c.name AS category_name, 
    SUM(p.amount) AS total_revenue
FROM 
    payment p
JOIN 
    rental r ON p.rental_id = r.rental_id
JOIN 
    inventory i ON r.inventory_id = i.inventory_id
JOIN 
    film f ON i.film_id = f.film_id
JOIN 
    film_category fc ON f.film_id = fc.film_id
JOIN 
    category c ON fc.category_id = c.category_id
WHERE 
    EXTRACT(YEAR FROM p.payment_date) = 2007
GROUP BY 
    c.name
ORDER BY 
    total_revenue DESC
LIMIT 3;
"""
rows = db.run(sql)
rows

"[('Sports', Decimal('4892.19')), ('Sci-Fi', Decimal('4336.01')), ('Animation', Decimal('4245.31'))]"

In [28]:
# # Prompt to return BOTH query & narrative answer
answer_prompt = PromptTemplate.from_template(
    """You are a helpful data analyst that accurately interprets the result without adding additional information.

Return your reply in **two parts**:

1. `Answer:` – a plain-English explanation of the result.
2. `SQL Query:` – show the exact query inside a ```sql fenced block.

Use the inputs below.

Question: {question}
SQL Query (raw string): {query}
SQL Result: {result}

Respond now."""
)



In [25]:
# 1️⃣1️⃣  Assemble full pipeline (question → answer)
full_chain = (
    sql_builder
    | sql_executor
    | answer_prompt
    | llm
    | StrOutputParser()
)


In [32]:
#  Ask something
user_q = "What are the names of the three film categories that brought in the highest total rental revenue across all stores in 2007?"
print(full_chain.invoke({"question": user_q}))


**Answer:** The three film categories that brought in the highest total rental revenue across all stores in 2007 are Sports, Sci-Fi, and Animation.

**SQL Query:**
```sql
SELECT 
    fc.category_id, 
    c.name AS category_name, 
    SUM(p.amount) AS total_revenue
FROM 
    payment p
JOIN 
    rental r ON p.rental_id = r.rental_id
JOIN 
    inventory i ON r.inventory_id = i.inventory_id
JOIN 
    film f ON i.film_id = f.film_id
JOIN 
    film_category fc ON f.film_id = fc.film_id
JOIN 
    category c ON fc.category_id = c.category_id
WHERE 
    EXTRACT(YEAR FROM p.payment_date) = 2007
GROUP BY 
    fc.category_id, c.name
ORDER BY 
    total_revenue DESC
LIMIT 3;
```
