In [1]:
import sqlparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TRANSFORMERS_CACHE

In [2]:
print(torch.cuda.is_available())
torch.cuda.empty_cache()

available_memory = torch.cuda.get_device_properties(0).total_memory
print(available_memory)

print(TRANSFORMERS_CACHE)

True
17170956288
C:\Users\zly20\.cache\huggingface\hub


In [3]:
model_name = "Qwen/Qwen2.5-7B-Instruct"
cache_dir = "E:/Data File/transformers.cache"  # model cache directory
tokenizer = AutoTokenizer.from_pretrained(model_name,cache_dir=cache_dir)


model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    # torch_dtype=torch.float16,
    load_in_8bit=True,
    device_map="auto",
    use_cache=True,
    cache_dir=cache_dir
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- You do not need to generate your thought process but just the answer

For example:
  question: 'List all the businesses with more than 4.5 stars',
  answer: 'SELECT name FROM business WHERE rating > 4.5'

### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE business (
  bid INTEGER PRIMARY KEY AUTO_INCREMENT, -- Unique ID for each business
  business_id VARCHAR(255) NOT NULL UNIQUE, -- External business identifier
  name VARCHAR(255) NOT NULL, -- Name of the business
  full_address VARCHAR(255) NOT NULL, -- Full street address
  city VARCHAR(255) NOT NULL, -- City of the business
  latitude VARCHAR(255) NOT NULL, -- Latitude coordinate
  longitude VARCHAR(255) NOT NULL, -- Longitude coordinate
  review_count BIGINT NOT NULL, -- Number of reviews the business has received
  is_open TINYINT(1) NOT NULL, -- Indicates if the business is open (1) or closed (0)
  rating FLOAT DEFAULT NULL, -- Average rating of the business
  state VARCHAR(255) DEFAULT NULL -- State or province
);

CREATE TABLE category (
  id INTEGER PRIMARY KEY AUTO_INCREMENT, -- Unique ID for each category record
  business_id VARCHAR(255) NOT NULL, -- ID of the business this category belongs to
  category_name VARCHAR(255) NOT NULL, -- Name of the category
  FOREIGN KEY (business_id) REFERENCES business(business_id)
);

CREATE TABLE checkin (
  cid INTEGER PRIMARY KEY AUTO_INCREMENT, -- Unique ID for each check-in record
  business_id VARCHAR(255) NOT NULL, -- ID of the business where the check-in occurred
  count INTEGER DEFAULT NULL, -- Number of check-ins
  day VARCHAR(12) DEFAULT NULL, -- Day of the week the check-in occurred
  FOREIGN KEY (business_id) REFERENCES business(business_id)
);

CREATE TABLE neighborhood (
  id INTEGER PRIMARY KEY AUTO_INCREMENT, -- Unique ID for each neighborhood record
  business_id VARCHAR(255) NOT NULL, -- ID of the business in the neighborhood
  neighborhood_name VARCHAR(255) NOT NULL, -- Name of the neighborhood
  FOREIGN KEY (business_id) REFERENCES business(business_id)
);

CREATE TABLE review (
  rid INTEGER PRIMARY KEY AUTO_INCREMENT, -- Unique ID for each review
  business_id VARCHAR(255) NOT NULL, -- ID of the business being reviewed
  user_id VARCHAR(255) NOT NULL, -- ID of the user who wrote the review
  rating FLOAT DEFAULT NULL, -- Rating given in the review
  text LONGTEXT NOT NULL, -- Content of the review
  year INTEGER DEFAULT NULL, -- Year the review was posted
  month VARCHAR(10) DEFAULT NULL, -- Month the review was posted
  FOREIGN KEY (business_id) REFERENCES business(business_id),
  FOREIGN KEY (user_id) REFERENCES user(user_id)
);

CREATE TABLE tip (
  tip_id INTEGER PRIMARY KEY AUTO_INCREMENT, -- Unique ID for each tip
  business_id VARCHAR(255) NOT NULL, -- ID of the business the tip is about
  text LONGTEXT NOT NULL, -- Content of the tip
  user_id VARCHAR(255) NOT NULL, -- ID of the user who left the tip
  likes INTEGER NOT NULL, -- Number of likes the tip received
  year INTEGER DEFAULT NULL, -- Year the tip was posted
  month VARCHAR(10) DEFAULT NULL, -- Month the tip was posted
  FOREIGN KEY (business_id) REFERENCES business(business_id),
  FOREIGN KEY (user_id) REFERENCES user(user_id)
);

CREATE TABLE user (
  uid INTEGER PRIMARY KEY AUTO_INCREMENT, -- Unique ID for each user record
  user_id VARCHAR(255) NOT NULL UNIQUE, -- External user identifier
  name VARCHAR(255) NOT NULL -- Name of the user
);

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

In [5]:
def generate_query(question):
    updated_prompt = prompt.format(question=question)
    inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    # empty cache so that you do generate more results w/o memory crashing
    # particularly important on Colab – memory management is much more straightforward
    # when running on an inference service
    return sqlparse.format(outputs[0].split("[SQL]")[1].split("[/SQL]")[0], reindent=True)

In [6]:
question = "List all the reviews which rated a business less than 1"
generated_sql = generate_query(question)
print(generated_sql)




SELECT *
FROM review
WHERE rating < 1;

### Answer
