
# SQL Generation with Transformer API

In [None]:
!pip install torch transformers bitsandbytes accelerate sqlparse

  Attempting uninstall: nvidia-cusolver-cu12
    Found existing installation: nvidia-cusolver-cu12 11.6.3.83
    Uninstalling nvidia-cusolver-cu12-11.6.3.83:
      Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83
Successfully installed bitsandbytes-0.46.0 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
torch.cuda.is_available()

True

In [None]:
available_memory = torch.cuda.get_device_properties(0).total_memory

In [None]:
print(available_memory)

15828320256


## Download the Model
Use any GPU on Colab (or any machine of your own with >30GB of VRAM) to load this model in float16. If that isn't available, use a GPU with at least 8GB of VRAM to load it in 8-bit, or one with at least 5GB of VRAM to load it in 4-bit.

This step can take around 5 minutes the first time, so please be patient :)
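The loading cell below only implements the float16 and 8-bit paths. As a sketch of the full three-tier choice described above, here is a hypothetical `pick_precision` helper. Its thresholds are assumptions: it reuses the 15e9-byte cutoff from the loading cell rather than the >30GB figure, and reads the 8GB and 5GB minimums as decimal gigabytes. It only returns a label, so it runs without a GPU; the comments show how each label could map to `from_pretrained` arguments through transformers' `BitsAndBytesConfig`.

```python
# Hypothetical helper: choose a loading precision from total GPU memory in bytes.
# Thresholds are assumptions based on the notes above and the 15e9 check in the
# loading cell; adjust them for your hardware.
def pick_precision(total_memory_bytes):
    if total_memory_bytes > 15e9:
        return "float16"
    if total_memory_bytes >= 8e9:
        return "8bit"
    if total_memory_bytes >= 5e9:
        return "4bit"
    return None  # likely too little VRAM for this 7B model on GPU


# How each label could map to AutoModelForCausalLM.from_pretrained keyword
# arguments (kept as comments so this snippet runs without a GPU):
#   "float16": dict(torch_dtype=torch.float16)
#   "8bit":    dict(quantization_config=BitsAndBytesConfig(load_in_8bit=True))
#   "4bit":    dict(quantization_config=BitsAndBytesConfig(
#                  load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16))

print(pick_precision(15828320256))  # the T4 total_memory value printed above
```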

In [None]:
from transformers import BitsAndBytesConfig

model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if available_memory > 15e9:
    # If you have at least 15GB of GPU memory, load the model in float16
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        use_cache=True,
    )
else:
    # Otherwise, load in 8-bit with bitsandbytes; this is a bit slower.
    # Recent transformers versions expect quantization settings through
    # `quantization_config` instead of the deprecated `load_in_8bit` argument.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        device_map="auto",
        use_cache=True,
    )

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



## Set the Question & Prompt and Tokenize
Feel free to replace the schema in the prompt below with your own schema.

In [None]:
prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity

### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

## Generate the SQL
This can be excruciatingly slow on a T4 in Colab, taking 10-20 seconds per query. On faster GPUs, it takes ~1-2 seconds.

Ideally, you should use `num_beams=4` for best results, but because of memory constraints we stick with `num_beams=1` for now.

In [None]:
import sqlparse

def generate_query(question):
    updated_prompt = prompt.format(question=question)
    inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,  # greedy decoding
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Empty the CUDA cache so you can generate more results without running out of memory.
    # This is particularly important on Colab; memory management is much more
    # straightforward when running on an inference service.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    # The decoded text includes the prompt, so keep only what follows the final [SQL] tag
    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)
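`batch_decode` returns the prompt together with the completion, which is why `generate_query` keeps only the text after the last `[SQL]` tag. A slightly more defensive version of that step is sketched below as a hypothetical `extract_sql` helper. The `[/SQL]` closing tag and the cut at the first semicolon are assumptions that guard against the model continuing past the query; the notebook itself does not rely on them.

```python
def extract_sql(decoded_text, start_tag="[SQL]", end_tag="[/SQL]"):
    """Return the SQL that follows the last `start_tag` in decoded model output.

    Assumption: the model may emit `end_tag` or extra text after the first
    statement, so both are trimmed off.
    """
    sql = decoded_text.split(start_tag)[-1]
    sql = sql.split(end_tag)[0]
    # Keep only the first statement. A semicolon inside a string literal
    # would break this simple heuristic.
    if ";" in sql:
        sql = sql[: sql.index(";") + 1]
    return sql.strip()


decoded = "...prompt text...\n[SQL]\nSELECT COUNT(*) FROM sales;\n[/SQL] extra words"
print(extract_sql(decoded))  # SELECT COUNT(*) FROM sales;
```

The result can then be passed to `sqlparse.format(..., reindent=True)` exactly as in the cell above.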

In [None]:
question = "What was our revenue by product in the New York region last month?"
generated_sql = generate_query(question)

In [None]:
print(generated_sql)


SELECT p.product_id,
       SUM(s.quantity * p.price) AS revenue
FROM sales s
JOIN salespeople sp ON s.salesperson_id = sp.salesperson_id
JOIN products p ON s.product_id = p.product_id
WHERE sp.region = 'New York'
  AND s.sale_date >= (CURRENT_DATE - INTERVAL '1 month')
GROUP BY p.product_id
ORDER BY revenue DESC NULLS LAST;
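The generated query uses PostgreSQL-style date arithmetic (`INTERVAL '1 month'`), which SQLite, the database used in the exercises below, does not support. It also arguably filters on the month up to today rather than the previous calendar month, so generated SQL is worth checking before you trust it. A minimal sketch of a compile-only check, assuming SQLite as the target and a trimmed copy of the prompt's tables: `EXPLAIN` makes SQLite parse and plan the statement without reading any table data. The helper name `sqlite_compile_error` is hypothetical.

```python
import sqlite3

# Trimmed copy of the tables the generated query touches (types simplified).
SCHEMA = """
CREATE TABLE products (product_id INTEGER PRIMARY KEY, name TEXT, price REAL, quantity INTEGER);
CREATE TABLE salespeople (salesperson_id INTEGER PRIMARY KEY, name TEXT, region TEXT);
CREATE TABLE sales (sale_id INTEGER PRIMARY KEY, product_id INTEGER, customer_id INTEGER,
                    salesperson_id INTEGER, sale_date TEXT, quantity INTEGER);
"""


def sqlite_compile_error(sql, schema=SCHEMA):
    """Return None if SQLite can compile `sql` against `schema`, else the error text."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema)
        conn.execute("EXPLAIN " + sql)  # compile only; no table data is read
        return None
    except sqlite3.Error as err:
        return str(err)
    finally:
        conn.close()


# The generated query from above, with the ORDER BY clause dropped.
generated = """SELECT p.product_id, SUM(s.quantity * p.price) AS revenue
FROM sales s
JOIN salespeople sp ON s.salesperson_id = sp.salesperson_id
JOIN products p ON s.product_id = p.product_id
WHERE sp.region = 'New York'
  AND s.sale_date >= (CURRENT_DATE - INTERVAL '1 month')
GROUP BY p.product_id;"""

print(sqlite_compile_error(generated))  # an error message: INTERVAL is not SQLite syntax

# An SQLite-flavoured date filter compiles cleanly.
sqlite_version = generated.replace("(CURRENT_DATE - INTERVAL '1 month')", "date('now', '-1 month')")
print(sqlite_compile_error(sqlite_version))  # None
```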


# Exercise
 - Complete the prompts similar to what we did in class.
     - Try at least 3 versions
     - Be creative
 - Write a one page report summarizing your findings.
     - Were there variations that didn't work well, i.e., where GPT either hallucinated or was wrong?
 - What did you learn?

## Version 1
Enhanced Prompt Engineering

Steps in the cell below:
* Setting up SQLite Database
    * Creating Employees table
    * Inserting Sample Data

In [None]:
import re  # used by the rule-based query generators below
import sqlite3

conn = sqlite3.connect('employees.db')
cursor = conn.cursor()

cursor.execute('''
    CREATE TABLE IF NOT EXISTS employees (
        employee_id INTEGER PRIMARY KEY,
        first_name TEXT,
        last_name TEXT,
        department TEXT,
        salary INTEGER
    )
''')

cursor.executemany('''
    INSERT OR REPLACE INTO employees (employee_id, first_name, last_name, department, salary)
    VALUES (?, ?, ?, ?, ?)
''', [
    (1, 'John', 'Doe', 'Sales', 60000),
    (2, 'Jane', 'Smith', 'Engineering', 75000),
    (3, 'Bob', 'Johnson', 'Sales', 55000),
    (4, 'Alice', 'Brown', 'Marketing', 65000)
])
conn.commit()

Steps in the cell below:
* Rule-based SQL query generator supporting:
  * Show all employees in a department
  * List employees with a salary greater than 60000
  * Sort employees by salary
  * Show all employees

In [None]:
# Rule-based SQL query generator
def rule_based_sql_query(prompt):
    prompt = prompt.lower().strip()
    # Pattern: "Show all employees in [department]"
    if "show all employees in" in prompt:
        match = re.search(r"in\s+(\w+)", prompt, re.IGNORECASE)
        if match:
            department = match.group(1).capitalize()
            return f"SELECT * FROM employees WHERE department = '{department}';"
    # Pattern: "List employees with salary greater than [number]"
    elif "salary greater than" in prompt:
        match = re.search(r"salary greater than\s+(\d+)", prompt)
        if match:
            salary = match.group(1)
            return f"SELECT * FROM employees WHERE salary > {salary};"
    # Pattern: "Sort employees by salary [ascending/descending]"
    elif "sort by salary" in prompt:
        order = "DESC" if "descending" in prompt else "ASC"
        return f"SELECT * FROM employees ORDER BY salary {order};"
    # Pattern: "Show all employees"
    elif "show all employees" in prompt:
        return "SELECT * FROM employees;"
    # Reached when no pattern matched, or a matched pattern had nothing to extract
    return "Error: Unsupported prompt"

Steps in the cells below:
* Creating a function to execute SQL queries
* Testing example prompts


In [None]:
def execute_sql_query(query):
    try:
        cursor.execute(query)
        results = cursor.fetchall()
        columns = [description[0] for description in cursor.description]
        return {"columns": columns, "results": results}
    except sqlite3.Error as e:
        return f"SQL Error: {e}"

In [None]:
prompts = [
    "Show all employees in Sales",
    "List employees with salary greater than 60000",
    "Sort employees by salary in descending order",
    "Show all employees"
]

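# Use a distinct loop variable so the global `prompt` template used by generate_query isn't overwritten
for user_prompt in prompts:
    print(f"\nPrompt: {user_prompt}")
    query = rule_based_sql_query(user_prompt)
    print(f"Generated SQL Query: {query}")
    result = execute_sql_query(query)
    print(f"Result: {result}")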

conn.close()


Prompt: Show all employees in Sales
Generated SQL Query: SELECT * FROM employees WHERE department = 'Sales';
Result: {'columns': ['employee_id', 'first_name', 'last_name', 'department', 'salary'], 'results': [(1, 'John', 'Doe', 'Sales', 60000), (3, 'Bob', 'Johnson', 'Sales', 55000)]}

Prompt: List employees with salary greater than 60000
Generated SQL Query: SELECT * FROM employees WHERE salary > 60000;
Result: {'columns': ['employee_id', 'first_name', 'last_name', 'department', 'salary'], 'results': [(2, 'Jane', 'Smith', 'Engineering', 75000), (4, 'Alice', 'Brown', 'Marketing', 65000)]}

Prompt: Sort employees by salary in descending order
Generated SQL Query: Error: Unsupported prompt
Result: SQL Error: near "Error": syntax error

Prompt: Show all employees
Generated SQL Query: SELECT * FROM employees;
Result: {'columns': ['employee_id', 'first_name', 'last_name', 'department', 'salary'], 'results': [(1, 'John', 'Doe', 'Sales', 60000), (2, 'Jane', 'Smith', 'Engineering', 75000), (3,

## Version 2

SQLite database with a products table

Steps in the cell below:
* Setting up SQLite Database
  * Creating products table
  * Inserting Sample data
  

In [None]:
conn = sqlite3.connect('products.db')
cursor = conn.cursor()

cursor.execute('''
    CREATE TABLE IF NOT EXISTS products (
        product_id INTEGER PRIMARY KEY,
        product_name TEXT,
        category TEXT,
        price INTEGER,
        stock INTEGER
    )
''')

cursor.executemany('''
    INSERT OR REPLACE INTO products (product_id, product_name, category, price, stock)
    VALUES (?, ?, ?, ?, ?)
''', [
    (1, 'Laptop', 'Electronics', 1200, 50),
    (2, 'Smartphone', 'Electronics', 800, 100),
    (3, 'Desk Chair', 'Furniture', 150, 30),
    (4, 'Coffee Table', 'Furniture', 250, 20),
    (5, 'Headphones', 'Electronics', 100, 200)
])
conn.commit()

Steps in the cell below:
* Enhancing the rule-based SQL query generator
  * Show all products in a category
  * List products with a price greater than 500
  * Sort products by a column
  * Show all products
  * Show selected columns for products in Furniture
  * Show the average price in Electronics

In [None]:
def rule_based_sql_query(prompt):
    prompt = prompt.lower().strip()

    valid_columns = ['product_id', 'product_name', 'category', 'price', 'stock']

    if "show all products in" in prompt:
        match = re.search(r"in\s+(\w+)", prompt, re.IGNORECASE)
        if match:
            category = match.group(1).capitalize()
            return f"SELECT * FROM products WHERE category = '{category}';"

    elif "price greater than" in prompt:
        match = re.search(r"price greater than\s+(\d+)", prompt)
        if match:
            price = match.group(1)
            return f"SELECT * FROM products WHERE price > {price};"

    elif "sort by" in prompt:
        match = re.search(r"sort by\s+(\w+)\s*(ascending|descending)?", prompt)
        if match:
            column = match.group(1)
            order = "DESC" if match.group(2) and match.group(2).lower() == "descending" else "ASC"
            if column in valid_columns:
                return f"SELECT * FROM products ORDER BY {column} {order};"
            else:
                return f"Error: Invalid column '{column}'"

    elif "show all products" in prompt:
        return "SELECT * FROM products;"

    elif "show" in prompt and "for products in" in prompt:
        match_columns = re.search(r"show\s+([\w\s,]+)\s+for products in\s+(\w+)", prompt, re.IGNORECASE)
        if match_columns:
            columns = [col.strip() for col in match_columns.group(1).split(',')]
            category = match_columns.group(2).capitalize()
            if all(col in valid_columns for col in columns):
                return f"SELECT {', '.join(columns)} FROM products WHERE category = '{category}';"
            else:
                return f"Error: Invalid column(s) in {columns}"

    elif "average price in" in prompt:
        match = re.search(r"average price in\s+(\w+)", prompt, re.IGNORECASE)
        if match:
            category = match.group(1).capitalize()
            return f"SELECT AVG(price) AS average_price FROM products WHERE category = '{category}';"

    # Reached when no pattern matched, or a matched pattern had nothing to extract
    return "Error: Unsupported prompt"

Steps in the cell below:
* Function to execute SQL queries
* Testing sample prompts

In [None]:
def execute_sql_query(query):
    try:
        cursor.execute(query)
        results = cursor.fetchall()
        columns = [description[0] for description in cursor.description]
        return {"columns": columns, "results": results}
    except sqlite3.Error as e:
        return f"SQL Error: {e}"

prompts = [
    "Show all products in Electronics",
    "List products with price greater than 500",
    "Sort products by price in descending order",
    "Show all products",
    "Show product_name, price for products in Furniture",
    "Average price in Electronics"
]

for user_prompt in prompts:
    print(f"\nPrompt: {user_prompt}")
    query = rule_based_sql_query(user_prompt)
    print(f"Generated SQL Query: {query}")
    result = execute_sql_query(query)
    print(f"Result: {result}")

conn.close()


Prompt: Show all products in Electronics
Generated SQL Query: SELECT * FROM products WHERE category = 'Electronics';
Result: {'columns': ['product_id', 'product_name', 'category', 'price', 'stock'], 'results': [(1, 'Laptop', 'Electronics', 1200, 50), (2, 'Smartphone', 'Electronics', 800, 100), (5, 'Headphones', 'Electronics', 100, 200)]}

Prompt: List products with price greater than 500
Generated SQL Query: SELECT * FROM products WHERE price > 500;
Result: {'columns': ['product_id', 'product_name', 'category', 'price', 'stock'], 'results': [(1, 'Laptop', 'Electronics', 1200, 50), (2, 'Smartphone', 'Electronics', 800, 100)]}

Prompt: Sort products by price in descending order
Generated SQL Query: Error: Unsupported prompt
Result: SQL Error: near "Error": syntax error

Prompt: Show all products
Generated SQL Query: SELECT * FROM products;
Result: {'columns': ['product_id', 'product_name', 'category', 'price', 'stock'], 'results': [(1, 'Laptop', 'Electronics', 1200, 50), (2, 'Smartphone

## Version 3
SQLite database with two tables: Orders and Customers

Steps in the cell below:
* Setting up an SQLite database with two tables
  * Create the orders table
  * Create the customers table
  * Insert sample data
  

In [None]:
conn = sqlite3.connect('orders.db')
cursor = conn.cursor()

cursor.execute('''
    CREATE TABLE IF NOT EXISTS orders (
        order_id INTEGER PRIMARY KEY,
        customer_id INTEGER,
        product_name TEXT,
        amount INTEGER,
        order_date TEXT
    )
''')

cursor.execute('''
    CREATE TABLE IF NOT EXISTS customers (
        customer_id INTEGER PRIMARY KEY,
        customer_name TEXT,
        city TEXT
    )
''')

cursor.executemany('''
    INSERT OR REPLACE INTO orders (order_id, customer_id, product_name, amount, order_date)
    VALUES (?, ?, ?, ?, ?)
''', [
    (1, 1, 'Laptop', 1200, '2025-01-15'),
    (2, 1, 'Mouse', 25, '2025-02-10'),
    (3, 2, 'Smartphone', 800, '2025-03-05'),
    (4, 3, 'Headphones', 100, '2025-04-20'),
    (5, 2, 'Charger', 50, '2025-05-01')
])

cursor.executemany('''
    INSERT OR REPLACE INTO customers (customer_id, customer_name, city)
    VALUES (?, ?, ?)
''', [
    (1, 'Alice Smith', 'New York'),
    (2, 'Bob Johnson', 'Chicago'),
    (3, 'Carol Brown', 'San Francisco')
])
conn.commit()

Steps in the cell below:
* Enhancing the rule-based SQL query generator
  * Validate columns
  * Show all orders in a city (e.g., New York)
  * List orders with an amount greater than 500
  * Sort orders by amount
  * Show all orders
  * Show product_name, amount for orders in New York
  * Average order amount in San Francisco
  * Show orders with customer names


In [None]:
def rule_based_sql_query(prompt):
    prompt = prompt.lower().strip()

    valid_orders_columns = ['order_id', 'customer_id', 'product_name', 'amount', 'order_date']
    valid_customers_columns = ['customer_id', 'customer_name', 'city']

    if "show all orders in" in prompt:
        match = re.search(r"in\s+(\w+\s*\w*)", prompt, re.IGNORECASE)
        if match:
            city = match.group(1).title()
            return f"""
                SELECT orders.*
                FROM orders
                JOIN customers ON orders.customer_id = customers.customer_id
                WHERE customers.city = '{city}';
            """

    elif "amount greater than" in prompt:
        match = re.search(r"amount greater than\s+(\d+)", prompt)
        if match:
            amount = match.group(1)
            return f"SELECT * FROM orders WHERE amount > {amount};"

    elif "sort orders by" in prompt:
        match = re.search(r"sort orders by\s+(\w+)\s*(ascending|descending)?", prompt)
        if match:
            column = match.group(1)
            order = "DESC" if match.group(2) and match.group(2).lower() == "descending" else "ASC"
            if column in valid_orders_columns:
                return f"SELECT * FROM orders ORDER BY {column} {order};"
            else:
                return f"Error: Invalid column '{column}'"

    elif "show all orders" in prompt:
        return "SELECT * FROM orders;"

    elif "show" in prompt and "for orders in" in prompt:
        match_columns = re.search(r"show\s+([\w\s,]+)\s+for orders in\s+(\w+\s*\w*)", prompt, re.IGNORECASE)
        if match_columns:
            columns = [col.strip() for col in match_columns.group(1).split(',')]
            city = match_columns.group(2).title()

            if all(col in valid_orders_columns for col in columns):
                return f"""
                    SELECT {', '.join(columns)}
                    FROM orders
                    JOIN customers ON orders.customer_id = customers.customer_id
                    WHERE customers.city = '{city}';
                """
            else:
                return f"Error: Invalid column(s) in {columns}"


    elif "average order amount in" in prompt:
        match = re.search(r"average order amount in\s+(\w+\s*\w*)", prompt, re.IGNORECASE)
        if match:
            city = match.group(1).title()
            return f"""
                SELECT AVG(orders.amount) AS average_amount
                FROM orders
                JOIN customers ON orders.customer_id = customers.customer_id
                WHERE customers.city = '{city}';
            """

    elif "show orders with customer names" in prompt:
        return """
            SELECT orders.order_id, orders.product_name, orders.amount, orders.order_date, customers.customer_name
            FROM orders
            JOIN customers ON orders.customer_id = customers.customer_id;
        """

    # Reached when no pattern matched, or a matched pattern had nothing to extract
    return "Error: Unsupported prompt"

Steps in the cells below:
* Function to execute SQL queries
* Testing sample prompts

In [None]:
def execute_sql_query(query):
    try:
        cursor.execute(query)
        results = cursor.fetchall()
        columns = [description[0] for description in cursor.description]
        return {"columns": columns, "results": results}
    except sqlite3.Error as e:
        return f"SQL Error: {e}"

In [None]:
prompts = [
    "Show all orders in Chicago",
    "List orders with amount greater than 500",
    "Sort orders by amount in descending order",
    "Show all orders",
    "Show product_name, amount for orders in New York",
    "Average order amount in San Francisco",
    "Show orders with customer names"
]

for user_prompt in prompts:
    print(f"\nPrompt: {user_prompt}")
    query = rule_based_sql_query(user_prompt)
    print(f"Generated SQL Query: {query}")
    result = execute_sql_query(query)
    print(f"Result: {result}")

conn.close()


Prompt: Show all orders in Chicago
Generated SQL Query: 
                SELECT orders.* 
                FROM orders 
                JOIN customers ON orders.customer_id = customers.customer_id 
                WHERE customers.city = 'Chicago';
            
Result: {'columns': ['order_id', 'customer_id', 'product_name', 'amount', 'order_date'], 'results': [(3, 2, 'Smartphone', 800, '2025-03-05'), (5, 2, 'Charger', 50, '2025-05-01')]}

Prompt: List orders with amount greater than 500
Generated SQL Query: SELECT * FROM orders WHERE amount > 500;
Result: {'columns': ['order_id', 'customer_id', 'product_name', 'amount', 'order_date'], 'results': [(1, 1, 'Laptop', 1200, '2025-01-15'), (3, 2, 'Smartphone', 800, '2025-03-05')]}

Prompt: Sort orders by amount in descending order
Generated SQL Query: SELECT * FROM orders ORDER BY amount ASC;
Result: {'columns': ['order_id', 'customer_id', 'product_name', 'amount', 'order_date'], 'results': [(2, 1, 'Mouse', 25, '2025-02-10'), (5, 2, 'Charger'