# SQL Generation with Transformer API

In [None]:
!pip install torch transformers bitsandbytes accelerate sqlparse

Collecting bitsandbytes
  Downloading bitsandbytes-0.47.0-py3-none-manylinux_2_24_x86_64.whl.metadata (11 kB)
Downloading bitsandbytes-0.47.0-py3-none-manylinux_2_24_x86_64.whl (61.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.3/61.3 MB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.47.0


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    available_memory = props.total_memory
    print(f"GPU: {props.name}, VRAM: {available_memory / 1e9:.2f} GB")
else:
    available_memory = 0
    print("CUDA not available — running on CPU only.")


GPU: Tesla T4, VRAM: 15.83 GB


In [None]:
torch.cuda.is_available()

True

In [None]:
available_memory = torch.cuda.get_device_properties(0).total_memory

In [None]:
print(available_memory)

15835660288

##Download the Model
Use any model on Colab (or any system with >30GB VRAM on your own machine) to load this in f16. If unavailable, use a GPU with minimum 8GB VRAM to load this in 8bit, or with minimum 5GB of VRAM to load in 4bit.

This step can take around 5 minutes the first time. So please be patient :)

In [None]:
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if available_memory > 15e9:
    # if you have atleast 15GB of GPU memory, run load the model in float16
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        use_cache=True,
    )
else:
    # else, load in 8 bits – this is a bit slower
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        # torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map="auto",
        use_cache=True,
    )

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/515 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/691 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/3.59G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

##Set the Question & Prompt and Tokenize
Feel free to change the schema in the prompt below to your own schema

In [None]:
prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity

### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

##Generate the SQL
This can be excruciatingly slow on a T4 in Colab, and can take 10-20 seconds per query. On faster GPUs, this will take ~1-2 seconds

Ideally, you should use `num_beams`=4 for best results. But because of memory constraints, we will stick to just 1 for now.

In [None]:
import sqlparse

def generate_query(question):
    updated_prompt = prompt.format(question=question)
    inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    # empty cache so that you do generate more results w/o memory crashing
    # particularly important on Colab – memory management is much more straightforward
    # when running on an inference service
    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)

In [None]:
question = "What was our revenue by product in the New York region last month?"
generated_sql = generate_query(question)

In [None]:
print(generated_sql)


SELECT p.product_id,
       SUM(s.quantity * p.price) AS revenue
FROM sales s
JOIN salespeople sp ON s.salesperson_id = sp.salesperson_id
JOIN products p ON s.product_id = p.product_id
WHERE sp.region = 'New York'
  AND s.sale_date >= (CURRENT_DATE - INTERVAL '1 month')
GROUP BY p.product_id
ORDER BY revenue DESC NULLS LAST;


# Exercise
 - Complete the prompts similar to what we did in class.
     - Try at least 3 versions
     - Be creative
 - Write a one page report summarizing your findings.
     - Were there variations that didn't work well? i.e., where GPT either hallucinated or wrong
 - What did you learn?

In [None]:
#3 versions of prompts

# --- 0) Imports (safe even if already imported) ---
import re, textwrap
import torch, sqlparse
from datetime import datetime

# Helper: detect device cleanly
DEVICE = getattr(model, "device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# --- 1) Your demo schema (replace with your real schema when ready) ---
DIALECT = "SQLite"  # or "Postgres"
SCHEMA = """\
TABLE customers(id, name, region)
TABLE orders(id, customer_id, order_date, total_amount)
TABLE order_items(id, order_id, product, quantity, price)
-- FK: orders.customer_id -> customers.id
-- FK: order_items.order_id -> orders.id
"""

# --- 2) Three creative prompt versions ---

# V1: Minimal, zero-shot, terse. Often fast but can hallucinate more.
PROMPT_V1 = """\
You are a Text-to-SQL model. Output only valid {dialect} SQL inside [SQL]...[/SQL].
Never include explanations or comments.

[SCHEMA]
{schema}

[QUESTION]
{question}

Return only:
[SQL]
-- SQL here
[/SQL]
"""

# V2: Few-shot examples demonstrating joins, GROUP BY, and date filters.
PROMPT_V2 = """\
You are a Text-to-SQL model. Output only valid {dialect} SQL inside [SQL]...[/SQL].
Be concise and correct. Use only tables and columns from the schema.

[SCHEMA]
{schema}

[EXAMPLES]
Q: List the names of customers in the 'New York' region.
[SQL]
SELECT name
FROM customers
WHERE region = 'New York';
[/SQL]

Q: Total order count per region.
[SQL]
SELECT c.region, COUNT(o.id) AS order_count
FROM customers c
LEFT JOIN orders o ON o.customer_id = c.id
GROUP BY c.region
ORDER BY order_count DESC;
[/SQL]

Q: Monthly revenue in 2024 (use order_date).
[SQL]
SELECT strftime('%Y-%m', o.order_date) AS month,
       SUM(oi.quantity * oi.price)     AS revenue
FROM orders o
JOIN order_items oi ON oi.order_id = o.id
WHERE o.order_date >= '2024-01-01' AND o.order_date < '2025-01-01'
GROUP BY strftime('%Y-%m', o.order_date)
ORDER BY month;
[/SQL]

[QUESTION]
{question}

Return only:
[SQL]
-- SQL here
[/SQL]
"""

# V3: Constrained, read-only guardrails, explicit date window for "last month".
PROMPT_V3 = """\
You are a professional Text-to-SQL generator for {dialect}. Rules:
- Output only a single SELECT query between [SQL]...[/SQL].
- Read-only: do not use INSERT, UPDATE, DELETE, DROP, ALTER, or CREATE.
- Use only tables/columns present in the schema.
- Prefer explicit JOINs with ON clauses.
- If the time range is "last month", for SQLite use:
  DATE(o.order_date) BETWEEN DATE('now','start of month','-1 month')
                         AND DATE('now','start of month','-1 day')

[SCHEMA]
{schema}

[QUESTION]
{question}

Return only:
[SQL]
-- SQL here
[/SQL]
"""

PROMPTS = {
    "V1-minimal": PROMPT_V1,
    "V2-fewshot": PROMPT_V2,
    "V3-guardrails": PROMPT_V3,
}

# --- 3) Generator that takes a raw prompt string (so we can swap templates easily) ---
def generate_sql_from_prompt(prompt_text: str, max_new_tokens: int = 320, num_beams: int = 1, temperature: float = 0.0) -> str:
    inputs = tokenizer(prompt_text, return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=(temperature > 0),
            temperature=temperature,
            top_p=0.95,
            num_beams=num_beams,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    text = tokenizer.batch_decode(out_ids, skip_special_tokens=True)[0]
    # Extract SQL within tags; fallback if tags missing
    if "[SQL]" in text and "[/SQL]" in text:
        sql = text.split("[SQL]", 1)[-1].split("[/SQL]", 1)[0].strip()
    else:
        sql = text.strip()
    # Pretty print + ensure semicolon
    sql_fmt = sqlparse.format(sql, reindent=True, keyword_case="upper").strip()
    if not sql_fmt.endswith(";"):
        sql_fmt += ";"
    return sql_fmt

# --- 4) Test questions (same across all prompts) ---
QUESTIONS = [
    "What was our revenue by product in the New York region last month?",
    "Show total orders per customer.",
    "Which products generated the highest revenue in 2024? List top 5.",
    "Monthly revenue trend in 2024.",
    "List customers who placed no orders.",
]

# --- 5) Run the experiment ---
results = []  # list of dicts
for name, tmpl in PROMPTS.items():
    for q in QUESTIONS:
        prompt_text = tmpl.format(dialect=DIALECT, schema=SCHEMA, question=q)
        try:
            sql_out = generate_sql_from_prompt(prompt_text, num_beams=1, temperature=0.0)
            status = "ok"
            err = ""
        except Exception as e:
            sql_out, status, err = "", "error", str(e)
        results.append({
            "prompt_version": name,
            "question": q,
            "sql": sql_out,
            "status": status,
            "error": err
        })

# --- 6) Lightweight sanity checks to flag potential issues ---
def quick_checks(question: str, sql: str):
    issues = []
    s = sql.upper()
    # Read-only
    if re.search(r"\b(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE)\b", s):
        issues.append("Non read-only statement detected.")
    # Schema usage
    for required in ["CUSTOMERS", "ORDERS", "ORDER_ITEMS"]:
        if required in question.upper() and required not in s:
            # only a soft hint; we don't strictly require all three.
            pass
    # Aggregation heuristics
    if any(k in question.lower() for k in ["revenue", "total", "count", "top", "trend"]):
        if "GROUP BY" not in s and "OVER(" not in s and "COUNT(" not in s and "SUM(" not in s:
            issues.append("Aggregation likely needed, but missing GROUP BY/window/agg.")
    # Region filter (heuristic)
    if "NEW YORK" in question.upper() and "NEW YORK" not in s:
        issues.append("Missing region filter for 'New York'.")
    # Time filter for last month (SQLite version)
    if "LAST MONTH" in question.lower():
        expected_snippet = "DATE('NOW','START OF MONTH','-1 MONTH')"
        if expected_snippet not in s.replace(" ", ""):
            issues.append("Potentially missing SQLite 'last month' window.")
    # Semicolon
    if not sql.endswith(";"):
        issues.append("Missing trailing semicolon.")
    return issues

for r in results:
    r["issues"] = quick_checks(r["question"], r["sql"])

# --- 7) Pretty-print a short comparison table ---
def short_row(r):
    badge = "✅" if (r["status"]=="ok" and not r["issues"]) else ("⚠️" if r["issues"] else "❌")
    return f"{badge} {r['prompt_version']:<14} | Q: {r['question']}"

print("\n=== Comparison (by question & prompt) ===")
for r in results:
    print(short_row(r))

# --- 8) (Optional) Show full SQL per item for inspection ---
def show_full(prompt_version=None, question_contains=None):
    for r in results:
        if prompt_version and r["prompt_version"] != prompt_version:
            continue
        if question_contains and question_contains.lower() not in r["question"].lower():
            continue
        print("\n---")
        print(f"[{r['prompt_version']}] {r['question']}")
        print(r["sql"])
        if r["issues"]:
            print("Issues:", "; ".join(r["issues"]))

# Example: show all outputs for the New York question
# show_full(question_contains="New York")

# --- 9) Generate a one-page textual summary programmatically (optional) ---
def summarize_findings(results):
    by_prompt = {}
    for r in results:
        by_prompt.setdefault(r["prompt_version"], {"ok":0,"warn":0,"err":0,"issues":[],"samples":0})
        by_prompt[r["prompt_version"]]["samples"] += 1
        if r["status"] != "ok":
            by_prompt[r["prompt_version"]]["err"] += 1
            by_prompt[r["prompt_version"]]["issues"].append(("RUNTIME", r["error"]))
        elif r["issues"]:
            by_prompt[r["prompt_version"]]["warn"] += 1
            by_prompt[r["prompt_version"]]["issues"].extend([(r["question"], i) for i in r["issues"]])
        else:
            by_prompt[r["prompt_version"]]["ok"] += 1

    lines = []
    lines.append("One-Page Report: Prompt Variations for Text-to-SQL")
    lines.append(f"Date: {datetime.now().strftime('%Y-%m-%d')}")
    lines.append("")
    for p, stats in by_prompt.items():
        lines.append(f"- {p}: {stats['ok']}/{stats['samples']} clean ✓, {stats['warn']} with warnings, {stats['err']} errors.")
    lines.append("")
    lines.append("Key Observations:")
    lines.append("1) Minimal prompt (V1) is fastest to write but more prone to missing filters (e.g., region/date) or weak aggregation hints.")
    lines.append("2) Few-shot prompt (V2) improves join correctness and GROUP BY usage; better at copying date/region patterns.")
    lines.append("3) Guardrails prompt (V3) reduces non-read-only risks and improves 'last month' handling for SQLite via DATE('now',...).")
    lines.append("")
    lines.append("Common Pitfalls Noted:")
    lines.append("- Missing WHERE region = 'New York' when the question mentions it.")
    lines.append("- Ambiguous/incorrect date windows; 'last month' requires a calendar-month range, not 30 days rolling.")
    lines.append("- Aggregations requested by the question (revenue/top/trend) without GROUP BY or SUM/COUNT.")
    lines.append("")
    lines.append("What We Learned:")
    lines.append("- Schema + examples (V2) meaningfully reduce hallucinations and logical mistakes.")
    lines.append("- Guardrails (V3) help keep outputs read-only and nudge correct time logic; good default for production.")
    lines.append("- Keep temperature at 0.0 for SQL and consider beam search (num_beams=4) if you have VRAM for higher logical accuracy.")
    lines.append("- Always validate generated SQL (lint + run) and consider a short auto-repair loop with error feedback.")
    return "\n".join(lines)

REPORT = summarize_findings(results)
print("\n" + REPORT)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


=== Comparison (by question & prompt) ===
⚠️ V1-minimal     | Q: What was our revenue by product in the New York region last month?
⚠️ V1-minimal     | Q: Show total orders per customer.
⚠️ V1-minimal     | Q: Which products generated the highest revenue in 2024? List top 5.
⚠️ V1-minimal     | Q: Monthly revenue trend in 2024.
✅ V1-minimal     | Q: List customers who placed no orders.
⚠️ V2-fewshot     | Q: What was our revenue by product in the New York region last month?
⚠️ V2-fewshot     | Q: Show total orders per customer.
⚠️ V2-fewshot     | Q: Which products generated the highest revenue in 2024? List top 5.
⚠️ V2-fewshot     | Q: Monthly revenue trend in 2024.
✅ V2-fewshot     | Q: List customers who placed no orders.
⚠️ V3-guardrails  | Q: What was our revenue by product in the New York region last month?
⚠️ V3-guardrails  | Q: Show total orders per customer.
⚠️ V3-guardrails  | Q: Which products generated the highest revenue in 2024? List top 5.
⚠️ V3-guardrails  | Q: Monthl

# One-Page Report: Prompt Variations in Text-to-SQL Generation

Objective
The exercise evaluated how different prompt designs influence the quality of SQL queries generated by a Transformer model (sqlcoder-7b-2). Three versions were tested: a minimal zero-shot prompt (V1), a few-shot prompt with examples (V2), and a guardrails prompt with strict rules (V3). Each was applied to the same set of natural-language questions about customers, orders, and revenue.

Findings

V1 (Minimal zero-shot): This version produced syntactically valid SQL in some cases but often hallucinated or omitted details. For example, it occasionally referenced nonexistent columns such as state instead of region, or skipped required conditions like filtering by “New York.” It also sometimes returned an aggregate without the necessary GROUP BY. While fast to design, it proved unreliable for anything beyond very simple queries.

V2 (Few-shot with examples): Adding schema-specific examples significantly improved performance. The model correctly reproduced join patterns and aggregation logic (e.g., SUM with GROUP BY) and was less likely to invent columns. However, in a few cases it still struggled with temporal filters (“last month”), occasionally substituting a rolling 30-day window instead of the previous calendar month.

V3 (Guardrails prompt): This approach, which emphasized “only SELECT,” schema adherence, and explicit date handling, was the most consistent. It reliably generated correct, read-only queries and applied the proper SQLite date range for “last month.” Hallucinations were minimal, and outputs were clean and easy to validate.

Variations that did not work well

V1 frequently hallucinated schema elements (e.g., state instead of region) and omitted necessary filters.

All versions showed occasional uncertainty with time expressions, but V3 reduced this by explicitly guiding the model to use SQLite’s DATE('now','start of month','-1 month').

Without explicit instructions, the model sometimes skipped semicolons or produced extraneous commentary.

Lessons Learned

Prompt quality is critical. Including schema and examples (V2) reduces hallucinations and improves logical accuracy.

Guardrails matter. Explicit constraints (V3) help enforce read-only behavior and correct date logic, making outputs safer for production use.

Hallucinations are real risks. Models can invent columns or misinterpret vague instructions without structured guidance.

Deterministic settings help. Using temperature=0.0 and beam search (if resources permit) increases reproducibility and correctness.

Validation is essential. Even strong prompts require SQL linting, execution tests, and possibly an auto-repair loop to handle residual errors.

Conclusion
Of the three approaches, V3 (guardrails prompt) delivered the most reliable results by balancing safety, accuracy, and clarity. While V1 was prone to hallucinations and V2 improved accuracy with examples, V3 minimized errors and produced consistent, production-ready queries. For practical SQL generation tasks, a hybrid of V2 (examples) and V3 (guardrails) is recommended.