In [1]:
!pip install torch transformers bitsandbytes accelerate sqlparse

Collecting bitsandbytes
  Downloading bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)
Downloading bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl (137.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.5/137.5 MB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.43.3


In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

In [3]:
torch.cuda.is_available()

True

In [4]:
# Total memory of GPU 0 in bytes. Despite the variable name, this is the card's
# capacity (`total_memory`), not the amount that is currently free.
gpu_properties = torch.cuda.get_device_properties(0)
available_memory = gpu_properties.total_memory

In [5]:
print(available_memory)

15835660288


# Download the model

In [6]:
# Load the tokenizer and the model for the SQLCoder checkpoint.
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Keyword arguments shared by both loading strategies.
# NOTE(review): `trust_remote_code = True` allows code shipped in the model
# repository to run locally -- confirm this checkpoint actually needs it.
load_kwargs = dict(
    trust_remote_code = True,
    device_map = "auto",
    use_cache = True,
)

if available_memory > 16e9:
  # More than 16 GB on the GPU: load the weights in half precision.
  load_kwargs["torch_dtype"] = torch.float16
else:
  # Otherwise fall back to 8-bit quantization via bitsandbytes. The bare
  # `load_in_8bit` argument of `from_pretrained` is deprecated; the setting is
  # passed through a `BitsAndBytesConfig` in `quantization_config` instead.
  load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit = True)

model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/515 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/691 [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/3.59G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

# Set the question and prompt, then tokenize

In [7]:
# Prompt template sent to the model. It contains a single `{question}`
# placeholder (used twice) that is filled in with `str.format`, followed by the
# database schema as DDL, and ends with the `[SQL]` marker after which the
# model is expected to write its answer.
prompt = """###task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

###Instructions
- If you cannot answer the question with the available database schema, return 'I DONT KNOW'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity

###Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER -- Current quantity in stock
);

CREATE TABLE customers (
  customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
  name VARCHAR(50), -- Name of the customer
  address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of the product sold
  customer_id INTEGER, -- ID of the customer who purchased the product
  salesperson_id INTEGER, -- ID of the salesperson who made the sale
  quantity INTEGER, -- Quantity of the product sold
  sale_date DATE -- Date the sale occurred
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id

###Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""


In [10]:
import sqlparse

def generate_query(question, max_new_tokens = 400):
  """Generate a formatted SQL query that answers `question`.

  The question is substituted into the module-level `prompt` template, the
  result is tokenized and passed to `model.generate` with greedy decoding
  (`do_sample = False`, `num_beams = 1`), and the generated completion is
  re-indented with `sqlparse`.

  Args:
    question: Natural-language question to insert into the prompt.
    max_new_tokens: Upper bound on the number of tokens to generate.

  Returns:
    The model's completion as a string, with surrounding whitespace removed
    and re-indented by `sqlparse.format`.
  """
  updated_prompt = prompt.format(question = question)
  # Put the inputs on whatever device the model lives on instead of assuming "cuda".
  model_inputs = tokenizer(updated_prompt, return_tensors = "pt").to(model.device)
  generated_ids = model.generate(
      **model_inputs,
      num_return_sequences = 1,
      eos_token_id = tokenizer.eos_token_id,
      pad_token_id = tokenizer.eos_token_id,
      max_new_tokens = max_new_tokens,
      do_sample = False,
      num_beams = 1,
  )
  # The returned sequences start with the prompt tokens; drop them by position
  # so that a "[SQL]" marker appearing inside the question cannot confuse the
  # extraction of the answer.
  prompt_length = model_inputs["input_ids"].shape[1]
  completion_ids = generated_ids[:, prompt_length:]
  outputs = tokenizer.batch_decode(completion_ids, skip_special_tokens = True)
  if torch.cuda.is_available():
    # Release cached GPU memory between calls.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
  return sqlparse.format(outputs[0].strip(), reindent = True)

In [11]:
question = "What was the highest quantity sold last month?"
generated_sql = generate_query(question)
print(generated_sql)


SELECT MAX(s.quantity) AS max_quantity
FROM sales s
WHERE s.sale_date >= (CURRENT_DATE - INTERVAL '1 month');


In [12]:
question = "Which salesperson sold the least amount of products last month?"
generated_sql = generate_query(question)
print(generated_sql)


SELECT s.salesperson_id,
       SUM(s.quantity) AS total_quantity
FROM sales s
WHERE s.sale_date >= (CURRENT_DATE - INTERVAL '1 month')
GROUP BY s.salesperson_id
ORDER BY total_quantity ASC
LIMIT 1


In [13]:
question = "What was our total revenue last month?"
generated_sql = generate_query(question)
print(generated_sql)


SELECT SUM(p.price * s.quantity) AS total_revenue
FROM sales s
JOIN products p ON s.product_id = p.product_id
WHERE s.sale_date >= (CURRENT_DATE - INTERVAL '1 month');
