<a href="https://colab.research.google.com/github/adityashinde0/SQLGenerator/blob/main/SQLGenerator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch transformers bitsandbytes accelerate sqlparse

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

In [None]:
# Check that PyTorch can see a CUDA GPU; later cells query CUDA device 0 and
# move tokenized inputs to "cuda".
torch.cuda.is_available()

In [None]:
# Total memory reported for CUDA device 0. The model-loading cell compares this
# value against 16e9 to decide between fp16 and 8-bit loading.
availabel_memory=torch.cuda.get_device_properties(0).total_memory

In [None]:
# Show the total GPU memory value detected above.
print(availabel_memory)

In [None]:
# Load the SQLCoder tokenizer and model. The loading mode depends on the total
# GPU memory measured earlier: above the 16e9 threshold the weights are loaded
# in float16, otherwise they are loaded with 8-bit quantization.
model_name='defog/sqlcoder-7b-2'
tokenizer=AutoTokenizer.from_pretrained(model_name)
if availabel_memory>16e9:
  model=AutoModelForCausalLM.from_pretrained(
      model_name,
      trust_remote_code=True,
      torch_dtype=torch.float16,
      device_map='auto',
      use_cache=True,
  )
else:
  # Request 8-bit weights through an explicit quantization config; passing
  # `load_in_8bit=True` straight to `from_pretrained` is deprecated.
  model=AutoModelForCausalLM.from_pretrained(
      model_name,
      trust_remote_code=True,
      quantization_config=BitsAndBytesConfig(load_in_8bit=True),
      device_map='auto',
      use_cache=True,
  )

In [None]:
# Prompt template sent to the model. The `{question}` placeholder appears twice
# and is filled with str.format() in generate_query. The template embeds the
# database schema and ends with an upper-case `[SQL]` marker, after which the
# model is expected to write its answer.
prompt="""### Task
Generate a SQL query to answer [question]{question}[/question]

### Instructions
- If you cannot answer the question with the available database schema, return "I do not know".
- Remember that revenue is price multiplied by quantity.
- Remember that cost is supply price multiplied by quantity.

### Database Schema
This query will run on a database whose schema is represented below:

CREATE TABLE products (
  Product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  Name VARCHAR(50),               -- Name of the product
  Price DECIMAL(10,2),            -- Price of each unit of the product
  Quantity INTEGER                -- Current quantity in stock
);

CREATE TABLE customers (
  Customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
  Name VARCHAR(50),                -- Name of the customer
  Address VARCHAR(50)              -- Mailing address of the customer
);

CREATE TABLE salespeople (
  Salespeople_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  Name VARCHAR(50),                   -- Name of the salesperson
  Region VARCHAR(50)                  -- Geographic sales region
);

CREATE TABLE sales (
  Sale_id INTEGER PRIMARY KEY,     -- Unique ID for each sale
  Product_id INTEGER,              -- ID of product sold
  Customer_id INTEGER,             -- ID of customer who made the purchase
  Salesperson_id INTEGER,          -- ID of salesperson who made the sale
  Quantity INTEGER                 -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  Supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  Product_id INTEGER,              -- Product ID supplied
  Supply_Price DECIMAL(10,2)       -- Unit price charged by supplier
);

-- Relationships:
-- sales.product_id joins with products.product_id
-- sales.customer_id joins with customers.customer_id
-- sales.salesperson_id joins with salespeople.salespeople_id
-- product_suppliers.product_id joins with products.product_id

### Answer
Given the database schema, here is the SQL query that answers [question]{question}[/question]
[SQL]
"""


In [None]:
import sqlparse

def generate_query(question):
  """Translate a natural-language question into a formatted SQL string.

  Fills the module-level ``prompt`` template with ``question``, runs
  deterministic decoding (no sampling, a single beam) with the module-level
  ``model`` and ``tokenizer``, and returns the text that follows the prompt's
  ``[SQL]`` marker, re-indented by sqlparse.

  Args:
    question: The question to answer, in plain language.

  Returns:
    The generated SQL (or whatever text the model produced after the
    ``[SQL]`` marker) as a re-indented string.
  """
  update_prompt=prompt.format(question=question)
  inputs=tokenizer(update_prompt,return_tensors="pt").to("cuda")
  generate_ids=model.generate(
      **inputs,
      num_return_sequences=1,
      eos_token_id=tokenizer.eos_token_id,
      max_new_tokens=400,
      do_sample=False,
      num_beams=1,
  )
  outputs=tokenizer.batch_decode(generate_ids,skip_special_tokens=True)

  # Release cached GPU memory between calls.
  torch.cuda.empty_cache()
  torch.cuda.synchronize()

  # The decoded text contains the prompt followed by the answer. The prompt
  # ends with an upper-case "[SQL]" marker and str.split is case-sensitive, so
  # the marker must be matched exactly to strip the prompt from the result.
  return sqlparse.format(outputs[0].split("[SQL]")[-1],reindent=True)

In [None]:
# Example 1: ask for the highest quantity sold.
# NOTE(review): the schema embedded in `prompt` has no date/time column, so
# "last month" cannot be resolved from it — inspect the output accordingly.
question="what was the highest quantitiy sold last month?"
generate_sql=generate_query(question)
print(generate_sql)

In [None]:
# Example 2: ask which salesperson sold the most products.
# NOTE(review): as above, the schema in `prompt` has no date/time column to
# support the "last month" filter.
question="which salesperson sold large amount of products last month?"
generate_sql=generate_query(question)
print(generate_sql)

In [None]:
# Example 3: revenue per product for a location. The prompt instructs the model
# that revenue is price multiplied by quantity.
# NOTE(review): the schema has no dedicated city column; the location would have
# to come from customers.Address or salespeople.Region — verify the output.
question="what was our revenue by the product in the newyork?"
generate_sql=generate_query(question)
print(generate_sql)