In [1]:
!pip install -q -U transformers bitsandbytes

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m30.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m27.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m29.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [1]:
from huggingface_hub import login
from google.colab import userdata

login(token = userdata.get("HF_TOKEN"))

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("sai-santhosh/text-2-sql-Llama-3.2-3B")
model = AutoModelForCausalLM.from_pretrained("sai-santhosh/text-2-sql-Llama-3.2-3B",load_in_4bit=True)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

In [15]:
from transformers import pipeline
import re
import textwrap

def get_sql_query(model,tokenizer,question,context):

    pipe = pipeline(
    "text-generation", model=model, tokenizer=tokenizer)

    system_prompt = """Provide the SQL query to the question based on the context in below format
                                                    1. SQL Query: start with ```sql
                                                    2. Explanation of the query: start with ```explanation

            make the explanation clear and detail in less than 75 words"""

    prompt = pipe.tokenizer.apply_chat_template(
        [   {"role": "system", "content" : system_prompt},
            {"role": "user", "content": f"Question: {question} Context: {context}"}],
        tokenize=False,
        add_generation_prompt=True,
    )


    outputs = pipe(
        prompt,
        max_length=350,
        clean_up_tokenization_spaces=True
    )

    generated_text = outputs[0]['generated_text']

    generated_text = generated_text.replace(system_prompt,"")
    sql_query_pattern = r"```sql\n(.*?)\n```"
    sql_query_match = re.search(sql_query_pattern, generated_text, re.DOTALL)
    sql_query = sql_query_match.group(1).strip() if sql_query_match else None

    explanation_pattern = r"\*\*Explanation\:\*\*(.*)"
    explanation_match = re.search(explanation_pattern, generated_text, re.DOTALL)
    explanation = explanation_match.group(1).strip() if explanation_match else None


    print("\n\nQuery:\n-----------\n")
    print("\n".join(textwrap.wrap(sql_query, width=100)))

    print("\n\nExplanation:\n-----------\n")
    print("\n".join(textwrap.wrap(explanation, width=100))) # Wrap text to limit width

In [16]:
# SELECT name FROM Aircraft WHERE distance > (SELECT AVG(distance) FROM Aircraft)  --->  right answer
question = "Show names for all aircrafts with distances more than the average."
context  = "CREATE TABLE Aircraft (name VARCHAR, distance INTEGER)"

In [17]:
get_sql_query(model,tokenizer,question,context)

Device set to use cuda:0




Query:
-----------

SELECT name FROM Aircraft WHERE distance > (SELECT AVG(distance) FROM Aircraft);


Explanation:
-----------

This SQL query calculates the average distance of all aircraft and then selects the names of
aircraft with distances greater than that average. It uses a subquery to calculate the average
distance, and the main query selects the `name` column from the `Aircraft` table where the
`distance` is greater than the calculated average.


In [18]:
question_1 = "Get the names of customers who spent more than the average amount."
context_1 = "CREATE TABLE Customers (name VARCHAR, spent INTEGER)"

get_sql_query(model,tokenizer,question_1,context_1)

Device set to use cuda:0




Query:
-----------

SELECT name  FROM Customers  WHERE spent > (SELECT AVG(spent) FROM Customers);


Explanation:
-----------

This SQL query selects the names of customers who have spent more than the average amount spent by
all customers. The subquery `(SELECT AVG(spent) FROM Customers)` calculates the average amount spent
by all customers, and the main query selects the names of customers whose spent amount is greater
than this average.


In [19]:
question_2 = "Find the highest salary from the Employees table."
context_2 = "CREATE TABLE Employees (name VARCHAR, salary INTEGER)"

get_sql_query(model,tokenizer,question_2,context_2)

Device set to use cuda:0




Query:
-----------

SELECT MAX(salary) AS highest_salary FROM Employees;


Explanation:
-----------

This SQL query finds the maximum salary from the `Employees` table. It uses the `MAX` function to
select the highest value from the `salary` column. The result is a single row with a single column,
`highest_salary`, which contains the highest salary from the table.


In [20]:
question_3 = "Show all orders that were placed by customers from New York."
context_3 = "CREATE TABLE Orders (order_id INTEGER, customer_id INTEGER, city VARCHAR)"

get_sql_query(model,tokenizer,question_3,context_3)

Device set to use cuda:0




Query:
-----------

SELECT * FROM Orders WHERE city = 'New York';


Explanation:
-----------

This SQL query selects all columns (`*`) from the `Orders` table where the `city` column is 'New
York'. This query is designed to retrieve all orders associated with customers from New York, based
on the specified table structure.


In [21]:
question_4 = "Find the names of customers who have placed orders with a total amount greater than the average order amount."
context_4 = "CREATE TABLE Customers (customer_id INTEGER, name VARCHAR); CREATE TABLE Orders (order_id INTEGER, customer_id INTEGER, amount INTEGER);"

get_sql_query(model,tokenizer,question_4,context_4)

Device set to use cuda:0




Query:
-----------

SELECT C.name FROM Customers C JOIN (   SELECT customer_id, SUM(amount) AS total_amount   FROM
Orders   GROUP BY customer_id ) AS T ON C.customer_id = T.customer_id WHERE T.total_amount > (SELECT
AVG(amount) FROM Orders);


Explanation:
-----------

This SQL query joins the `Customers` table with a subquery that calculates the total order amount
for each customer. It then selects the names of customers whose total order amount exceeds the
average order amount.


In [23]:
question_5 = "Get the names of products that were ordered by customers in New York who spent more than the average amount."
context_5 = "CREATE TABLE Customers (customer_id INTEGER, name VARCHAR, city VARCHAR); \
             CREATE TABLE Orders (order_id INTEGER, customer_id INTEGER, amount INTEGER); \
             CREATE TABLE Products (product_id INTEGER, name VARCHAR); CREATE TABLE Order_Items (order_id INTEGER, product_id INTEGER);"

get_sql_query(model,tokenizer,question_5,context_5)

Device set to use cuda:0




Query:
-----------

SELECT P.name FROM Products P JOIN Orders O ON P.product_id = ANY (SELECT product_id FROM
Order_Items WHERE order_id IN (SELECT order_id FROM Orders WHERE customer_id IN (SELECT customer_id
FROM Customers WHERE city = 'New York')));


Explanation:
-----------

This query joins the `Products` table with the `Orders` table on the condition that the `product_id`
in `Products` is present in the `product_id` of any order made by customers in New York. It then
filters these orders to include only those where the total amount spent is greater than the average
amount spent by customers in New York.
