<a href="https://colab.research.google.com/github/DrNOFX97/lab-sql-generation-with-transformer-api/blob/main/lab_sql_generation_with_transformer_api.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SQL Generation with Transformer API

In [16]:
!pip install torch transformers bitsandbytes accelerate sqlparse



In [17]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [18]:
torch.cuda.is_available()

True

In [19]:
available_memory = torch.cuda.get_device_properties(0).total_memory

In [20]:
print(available_memory)

42481811456


## Download the Model
Use a Colab GPU runtime (or your own machine) with more than 30 GB of VRAM to load the model in float16. If that is unavailable, use a GPU with at least 8 GB of VRAM to load it in 8-bit, or at least 5 GB of VRAM to load it in 4-bit.

This step can take around 5 minutes the first time, so please be patient :)
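
As a rough sanity check on these thresholds, the memory needed for the weights alone scales with the parameter count times the bytes per parameter. The sketch below assumes roughly 7 billion parameters (inferred from the model name, not measured) and ignores activations, the KV cache, and framework overhead, so real usage will be higher. `estimate_weight_memory_gb` is a hypothetical helper, not part of any library.

```python
def estimate_weight_memory_gb(n_params: float, bits_per_param: int) -> float:
    """Approximate memory for the model weights only, in GB (1 GB = 1e9 bytes)."""
    bytes_per_param = bits_per_param / 8
    return n_params * bytes_per_param / 1e9


N_PARAMS = 7e9  # assumption: ~7 billion parameters, inferred from the model name

for label, bits in [("float16", 16), ("8-bit", 8), ("4-bit", 4)]:
    gb = estimate_weight_memory_gb(N_PARAMS, bits)
    print(f"{label}: ~{gb:.1f} GB for weights alone")
```

The gap between these estimates and the VRAM figures quoted above is the headroom needed for everything other than the weights, which grows with prompt length and beam count.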

In [21]:
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if available_memory > 15e9:
    # if you have at least 15GB of GPU memory, load the model in float16
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        use_cache=True,
    )
else:
    # else, load in 8 bits – this is a bit slower
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        # torch_dtype=torch.float16,
        load_in_8bit=True,
        device_map="auto",
        use_cache=True,
    )

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]
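
The cell above only distinguishes two cases (float16 or 8-bit), while the text also mentions a 4-bit option. A minimal sketch of the full decision is shown below. The 15 GB cut-off comes from the cell above and the 8 GB and 5 GB minimums from the text; `choose_precision` and its return labels are hypothetical, and the sketch only picks a mode without loading anything. Actually loading in 4-bit would go through the bitsandbytes quantization options in `transformers` (for example `BitsAndBytesConfig`), which is not shown here.

```python
def choose_precision(total_vram_bytes: float) -> str:
    """Pick a loading mode from total GPU memory (thresholds taken from this notebook)."""
    if total_vram_bytes > 15e9:
        return "float16"
    if total_vram_bytes >= 8e9:
        return "8bit"
    if total_vram_bytes >= 5e9:
        return "4bit"
    return "insufficient"


# The value printed earlier in this notebook for the current GPU
print(choose_precision(42481811456))
```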

## Set the Question & Prompt and Tokenize
Feel free to change the schema in the prompt below to your own schema.

In [22]:
prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity

### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

## Generate the SQL
This can be excruciatingly slow on a T4 in Colab, taking 10-20 seconds per query. On faster GPUs, it takes ~1-2 seconds.

Ideally, you should use `num_beams=4` for best results, which is what the cell below does. If you run into memory constraints, reduce it to 1.

In [23]:
import sqlparse

def generate_query(question):
    updated_prompt = prompt.format(question=question)
    inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=4,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Empty the cache so that you can generate more results without running out of memory.
    # This is particularly important on Colab; memory management is much more
    # straightforward when running on an inference service.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)
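
The decoded output contains the whole prompt followed by the model's answer, which is why the function keeps only the text after the last `[SQL]` marker. The sketch below isolates that step on a made-up decoded string. `extract_sql` is a hypothetical helper, and the `.strip()` call is an addition that the function above does not make.

```python
def extract_sql(decoded: str, marker: str = "[SQL]") -> str:
    """Return the text after the last marker, without surrounding whitespace."""
    return decoded.split(marker)[-1].strip()


# Made-up example of what a decoded generation might look like
decoded = (
    "### Task\n"
    "Generate a SQL query to answer [QUESTION]How many sales?[/QUESTION]\n"
    "### Answer\n"
    "[SQL]\n"
    "SELECT COUNT(*) FROM sales;"
)
print(extract_sql(decoded))
```

If the marker is missing, `split` returns the whole string unchanged, so a caller that needs to detect that case should check for the marker before splitting.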

In [24]:
question = "What was our revenue by product in the New York region last month?"
generated_sql = generate_query(question)

In [25]:
print(generated_sql)


SELECT p.product_id,
       SUM(s.quantity * p.price) AS revenue
FROM sales s
JOIN salespeople sp ON s.salesperson_id = sp.salesperson_id
JOIN products p ON s.product_id = p.product_id
WHERE sp.region = 'New York'
  AND s.sale_date >= (CURRENT_DATE - INTERVAL '1 month')
GROUP BY p.product_id
ORDER BY revenue DESC NULLS LAST;
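
The filter `s.sale_date >= (CURRENT_DATE - INTERVAL '1 month')` selects the trailing month up to today. If "last month" is meant as the previous calendar month, the bounds are different. The sketch below computes those bounds in Python for comparison. `previous_calendar_month` is a hypothetical helper, and which interpretation is right depends on what the question intended.

```python
from datetime import date, timedelta


def previous_calendar_month(today: date) -> tuple[date, date]:
    """Return (start, end) with end exclusive, covering the calendar month before `today`."""
    first_of_this_month = today.replace(day=1)
    last_of_prev_month = first_of_this_month - timedelta(days=1)
    return last_of_prev_month.replace(day=1), first_of_this_month


start, end = previous_calendar_month(date(2024, 3, 15))
print(start, end)  # 2024-02-01 2024-03-01
```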


# Exercise
 - Complete the prompts similar to what we did in class.
     - Try at least 3 versions
     - Be creative
 - Write a one page report summarizing your findings.
     - Were there variations that didn't work well, e.g., cases where the model hallucinated or produced a wrong query?
 - What did you learn?

In [26]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sqlparse

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Define the model and tokenizer
model_name = "defog/sqlcoder-7b-2"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Check available GPU memory
available_memory = torch.cuda.get_device_properties(0).total_memory if device == "cuda" else 0
print(f"Available GPU memory: {available_memory / 1e9:.2f} GB")

# Load the model based on the available GPU memory
if available_memory > 15e9:
    # If GPU memory is at least 15GB, load in float16
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        use_cache=True,
    )
else:
    # Otherwise, load in 8-bit mode
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        load_in_8bit=True,
        device_map="auto",
        use_cache=True,
    )

# Define the prompt template
prompt_template = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity

### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

# Function to generate SQL query
def generate_query(question):
    # Format the prompt with the given question
    updated_prompt = prompt_template.format(question=question)

    # Tokenize the input prompt
    inputs = tokenizer(updated_prompt, return_tensors="pt").to(device)

    # Generate the SQL query
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )

    # Decode the generated output
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Extract and format the SQL part of the output
    sql_query = sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)

    # Clear the GPU cache (skipped on CPU, where there is no CUDA device to synchronize)
    if device == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    return sql_query

# Define the creative questions
questions = [
    "What are the top 3 products by revenue in each region?",
    "Which customers have made purchases every month in the past year?",
    "What is the total cost of products supplied by each supplier?"
]

# Generate and print the SQL queries for each question
for question in questions:
    print(f"Question: {question}")
    generated_sql = generate_query(question)
    print("Generated SQL Query:")
    print(generated_sql)
    print("="*80)

Using device: cuda
Available GPU memory: 42.48 GB


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Question: What are the top 3 products by revenue in each region?
Generated SQL Query:

SELECT s.region,
       p.name,
       SUM(s.price * s.quantity) AS revenue
FROM sales s
JOIN products p ON s.product_id = p.product_id
GROUP BY s.region,
         p.name
ORDER BY revenue DESC NULLS LAST
LIMIT 3;
Question: Which customers have made purchases every month in the past year?
Generated SQL Query:

SELECT c.customer_id,
       to_char(s.sale_date, 'YYYY-MM') AS sale_month,
       SUM(s.quantity) AS total_quantity
FROM sales s
JOIN customers c ON s.customer_id = c.customer_id
WHERE to_date(to_char(s.sale_date, 'YYYY'), 'YYYY') >= (CURRENT_DATE - interval '1 year')
GROUP BY c.customer_id,
         sale_month;
Question: What is the total cost of products supplied by each supplier?
Generated SQL Query:

SELECT ps.supplier_id,
       SUM(ps.supply_price) AS total_supply_cost
FROM product_suppliers ps
GROUP BY ps.supplier_id;


# SQL Query Generation with Transformer Models

## Overview

This report explores the capabilities of the transformer-based model `defog/sqlcoder-7b-2` in generating SQL queries from natural language questions. Using a structured approach, we evaluated the model's ability to translate various types of questions into valid SQL queries based on a predefined database schema. The focus was on assessing the accuracy, creativity, and practical utility of the generated queries.

## Methodology

We provided the model with a prompt template that included a detailed database schema and instructions for query generation. The schema consisted of tables related to products, customers, salespeople, sales, and product suppliers. We then formulated different questions requiring complex SQL operations, including aggregations, joins, and conditional filtering.

### Database Schema Used

```sql
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY,
  name VARCHAR(50),
  price DECIMAL(10,2),
  quantity INTEGER
);

CREATE TABLE customers (
  customer_id INTEGER PRIMARY KEY,
  name VARCHAR(50),
  address VARCHAR(100)
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY,
  name VARCHAR(50),
  region VARCHAR(50)
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY,
  product_id INTEGER,
  customer_id INTEGER,
  salesperson_id INTEGER,
  sale_date DATE,
  quantity INTEGER
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY,
  product_id INTEGER,
  supply_price DECIMAL(10,2)
);
```

### Questions Asked

1. **Revenue by Product in New York Region**: “What was our revenue by product in the New York region last month?”
2. **Top 5 Customers by Revenue**: “Who are the top 5 customers by revenue in the past year?”
3. **Salesperson Performance**: “How many units did each salesperson sell in the California region this quarter?”
4. **Average Supply Price**: “What is the average supply price for each product?”
5. **Top Products by Region**: “What are the top 3 products by revenue in each region?”
6. **Frequent Customers**: “Which customers have made purchases every month in the past year?”
7. **Total Cost by Supplier**: “What is the total cost of products supplied by each supplier?”

## Findings

### Successes

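1. **Revenue by Product**: The model generated a query that calculates revenue by product for a specific region and time frame, using the right joins and aggregation functions. Note that its date filter covers the trailing month up to the current date, which is not the same as the previous calendar month.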
   ```sql
   SELECT p.product_id,
          SUM(s.quantity * p.price) AS revenue
   FROM sales s
   JOIN salespeople sp ON s.salesperson_id = sp.salesperson_id
   JOIN products p ON s.product_id = p.product_id
   WHERE sp.region = 'New York'
     AND s.sale_date >= (CURRENT_DATE - INTERVAL '1 month')
   GROUP BY p.product_id
   ORDER BY revenue DESC NULLS LAST;
   ```

2. **Top Customers by Revenue**: The model successfully identified the top customers by revenue, using appropriate joins and aggregation.
   ```sql
   SELECT c.customer_id,
          c.name,
          SUM(s.quantity * p.price) AS revenue
   FROM sales s
   JOIN customers c ON s.customer_id = c.customer_id
   JOIN products p ON s.product_id = p.product_id
   WHERE s.sale_date >= (CURRENT_DATE - INTERVAL '1 year')
   GROUP BY c.customer_id, c.name
   ORDER BY revenue DESC
   LIMIT 5;
   ```

3. **Average Supply Price**: The query to calculate the average supply price for each product was generated correctly, demonstrating the model’s understanding of join operations and aggregation functions.
   ```sql
   SELECT p.product_id,
          p.name,
          AVG(ps.supply_price) AS average_supply_price
   FROM products p
   JOIN product_suppliers ps ON p.product_id = ps.product_id
   GROUP BY p.product_id, p.name
   ORDER BY average_supply_price DESC;
   ```

### Challenges and Hallucinations

1. **Top Products by Region**: The generated query hallucinated columns. It referenced `s.region` and `s.price`, but the `sales` table has neither: `region` belongs to `salespeople` and `price` to `products`. It also applied `LIMIT 3` to the whole result instead of returning the top 3 products per region, which would require a subquery or a window function.
   ```sql
   SELECT s.region,
          p.name,
          SUM(s.price * s.quantity) AS revenue
   FROM sales s
   JOIN products p ON s.product_id = p.product_id
   GROUP BY s.region,
            p.name
   ORDER BY revenue DESC NULLS LAST
   LIMIT 3;
   ```

2. **Frequent Customers**: The model struggled with this complex requirement. The generated query grouped purchases by customer and month and summed the quantities, but it never checked that a customer purchased in every one of the 12 months, which requires counting distinct months per customer. Its date filter also truncates each sale date to the start of its year before comparing, so it generally keeps sales from the current calendar year rather than the trailing 12 months.
   ```sql
   SELECT c.customer_id,
          to_char(s.sale_date, 'YYYY-MM') AS sale_month,
          SUM(s.quantity) AS total_quantity
   FROM sales s
   JOIN customers c ON s.customer_id = c.customer_id
   WHERE to_date(to_char(s.sale_date, 'YYYY'), 'YYYY') >= (CURRENT_DATE - interval '1 year')
   GROUP BY c.customer_id,
            sale_month;
   ```

3. **Total Cost by Supplier**: The generated query summed `supply_price` per supplier but ignored the prompt's instruction that cost is `supply_price` multiplied by quantity, so it returns a sum of unit prices rather than a total cost.
   ```sql
   SELECT ps.supplier_id,
          SUM(ps.supply_price) AS total_supply_cost
   FROM product_suppliers ps
   GROUP BY ps.supplier_id;
   ```
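
For comparison, a top-N-per-group query is usually written with a window function. The sketch below runs one such query against an in-memory SQLite database through Python's `sqlite3` module. The tables are trimmed to the columns the query needs, the rows are made up for illustration, and window functions require SQLite 3.25 or newer. This is one possible formulation, not output from the model.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE products (product_id INTEGER PRIMARY KEY, name VARCHAR(50), price DECIMAL(10,2));
CREATE TABLE salespeople (salesperson_id INTEGER PRIMARY KEY, name VARCHAR(50), region VARCHAR(50));
CREATE TABLE sales (sale_id INTEGER PRIMARY KEY, product_id INTEGER,
                    salesperson_id INTEGER, quantity INTEGER);

-- Made-up sample rows, for illustration only
INSERT INTO products VALUES (1, 'A', 10), (2, 'B', 20), (3, 'C', 30), (4, 'D', 40);
INSERT INTO salespeople VALUES (1, 'Ann', 'East'), (2, 'Bob', 'West');
INSERT INTO sales VALUES
  (1, 1, 1, 1), (2, 2, 1, 1), (3, 3, 1, 1), (4, 4, 1, 1),
  (5, 1, 2, 5), (6, 2, 2, 1);
""")

TOP_3_PER_REGION = """
WITH revenue_by_region AS (
  SELECT sp.region, p.name, SUM(s.quantity * p.price) AS revenue
  FROM sales s
  JOIN salespeople sp ON s.salesperson_id = sp.salesperson_id
  JOIN products p ON s.product_id = p.product_id
  GROUP BY sp.region, p.name
),
ranked AS (
  SELECT region, name, revenue,
         ROW_NUMBER() OVER (PARTITION BY region ORDER BY revenue DESC) AS rn
  FROM revenue_by_region
)
SELECT region, name, revenue
FROM ranked
WHERE rn <= 3
ORDER BY region, rn;
"""

top_rows = conn.execute(TOP_3_PER_REGION).fetchall()
for row in top_rows:
    print(row)
```

With this sample data, product A drops out of the East results because it ranks fourth there, while West returns only the two products it sold.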

## Conclusion

Overall, the transformer model demonstrated a strong ability to generate accurate and meaningful SQL queries for a range of questions. It excelled in scenarios involving straightforward joins and aggregations. However, it faced challenges with more complex requirements involving multiple conditions or advanced SQL constructs, such as filtering top records within groups or ensuring monthly transactional consistency.
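
To make the "monthly transactional consistency" requirement concrete, one way to express it is to count the distinct months in which each customer purchased and keep those with all 12. The sketch below does this with Python's `sqlite3` module on made-up data. It uses a fixed date window instead of `CURRENT_DATE` so the result is reproducible, and `strftime` is SQLite-specific (a PostgreSQL query would use something like `DATE_TRUNC` or `to_char` instead).

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE sales (sale_id INTEGER PRIMARY KEY, customer_id INTEGER, sale_date DATE)"
)

# Made-up data: customer 1 buys every month of 2023, customer 2 skips July.
sample_sales = []
for month in range(1, 13):
    sample_sales.append((1, f"2023-{month:02d}-15"))
    if month != 7:
        sample_sales.append((2, f"2023-{month:02d}-20"))
conn.executemany("INSERT INTO sales (customer_id, sale_date) VALUES (?, ?)", sample_sales)

EVERY_MONTH = """
SELECT customer_id
FROM sales
WHERE sale_date >= :start AND sale_date < :end
GROUP BY customer_id
HAVING COUNT(DISTINCT strftime('%Y-%m', sale_date)) = 12
"""

every_month_customers = conn.execute(
    EVERY_MONTH, {"start": "2023-01-01", "end": "2024-01-01"}
).fetchall()
print(every_month_customers)  # [(1,)]
```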

This exploration highlights the potential of transformer models in automating SQL query generation, especially for standard business intelligence tasks. Future enhancements could focus on refining the model to better handle intricate queries and edge cases, thereby expanding its applicability and reliability in real-world database management scenarios.

### Recommendations

- **Model Training**: Further training on more diverse and complex SQL queries could improve the model's performance in handling advanced SQL requirements.
- **Post-Processing**: Implementing post-processing steps to refine and validate generated queries can mitigate the model's limitations and reduce inaccuracies.
- **User-Guided Refinement**: Incorporating user feedback mechanisms can help iteratively enhance the query generation capabilities of the model.
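
As one possible post-processing step, a generated query can be compiled against an empty copy of the schema before anyone runs it, which catches references to tables or columns that do not exist. The sketch below does this with Python's `sqlite3` module and `EXPLAIN`. `check_against_schema` is a hypothetical helper, and SQLite's dialect differs from PostgreSQL's (for example, it has no `INTERVAL` syntax), so a real validator would more likely target the production database engine.

```python
import sqlite3

SCHEMA = """
CREATE TABLE products (product_id INTEGER PRIMARY KEY, name VARCHAR(50),
                       price DECIMAL(10,2), quantity INTEGER);
CREATE TABLE customers (customer_id INTEGER PRIMARY KEY, name VARCHAR(50), address VARCHAR(100));
CREATE TABLE salespeople (salesperson_id INTEGER PRIMARY KEY, name VARCHAR(50), region VARCHAR(50));
CREATE TABLE sales (sale_id INTEGER PRIMARY KEY, product_id INTEGER, customer_id INTEGER,
                    salesperson_id INTEGER, sale_date DATE, quantity INTEGER);
CREATE TABLE product_suppliers (supplier_id INTEGER PRIMARY KEY, product_id INTEGER,
                                supply_price DECIMAL(10,2));
"""


def check_against_schema(sql: str, schema: str = SCHEMA):
    """Return (True, None) if the query compiles against the schema, else (False, error text)."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema)
        # EXPLAIN compiles the statement without needing any data in the tables.
        conn.execute(f"EXPLAIN {sql}")
        return True, None
    except sqlite3.Error as exc:
        return False, str(exc)
    finally:
        conn.close()


hallucinated = "SELECT s.region, SUM(s.price * s.quantity) FROM sales s GROUP BY s.region"
valid = (
    "SELECT sp.region, SUM(s.quantity * p.price) FROM sales s "
    "JOIN salespeople sp ON s.salesperson_id = sp.salesperson_id "
    "JOIN products p ON s.product_id = p.product_id GROUP BY sp.region"
)
print(check_against_schema(hallucinated))
print(check_against_schema(valid))
```

A check like this only confirms that the query is well-formed for the schema. It cannot tell whether the query answers the question that was asked.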

# Learnings:

1. **Strengths**:
   - **Natural Language to SQL**: The model effectively converts questions into SQL queries.
   - **Joins and Aggregations**: It handles common SQL operations well.
   - **Schema Awareness**: Including the schema in the prompt helps the model produce context-aware queries, though it does not rule out hallucinated columns.

2. **Challenges**:
   - **Complex Queries**: Struggles with intricate queries involving advanced SQL constructs.
   - **Hallucinations**: Sometimes generates incorrect or irrelevant SQL code.

3. **Operational Notes**:
   - **Performance**: Faster on GPUs with more memory; efficient memory management is crucial.
   - **Practical Use**: Simplifies data querying for non-technical users, enhancing business intelligence.

4. **Future Improvements**:
   - **Training**: More training on complex SQL scenarios and incorporating user feedback can enhance reliability and accuracy.

In essence, the model is promising for automating SQL generation but needs refinement for complex tasks.