In [1]:
!pip install torch transformers bitsandbytes accelerate sqlparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

Collecting bitsandbytes
  Downloading bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)
Downloading bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl (122.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.4/122.4 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.44.1


In [2]:
# The model cell below sizes its loading strategy from the GPU's total memory, so a CUDA
# device is mandatory. Fail fast with a readable message instead of letting
# get_device_properties(0) raise an opaque error on CPU-only runtimes.
if not torch.cuda.is_available():
    raise RuntimeError("A CUDA GPU is required to run this notebook.")

available_memory = torch.cuda.get_device_properties(0).total_memory  # total bytes on GPU 0
print(available_memory)

15835660288


In [3]:
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# With at least 15 GB of GPU memory, load the weights in float16. Otherwise fall back to
# 8-bit weights via bitsandbytes (installed in the first cell) so the model still fits.
# `available_memory` is the total GPU memory in bytes, computed in the GPU-check cell.
if available_memory > 15e9:
    precision_kwargs = {"torch_dtype": torch.float16}
else:
    precision_kwargs = {"quantization_config": BitsAndBytesConfig(load_in_8bit=True)}

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="auto",
    use_cache=True,
    **precision_kwargs,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/515 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/691 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/3.59G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

In [None]:
# Prompt template for sqlcoder on the retail/accounting schema.
# str.format fills both {question} placeholders. The schema is written as valid DDL
# (no trailing comma before a closing parenthesis) so the model sees well-formed SQL.
prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that profit is revenue minus cost
- Remember that revenue is sale_price multiplied by quantity_sold
- Remember that cost is purchase_price multiplied by quantity_sold
### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE accounts (
    account_id INTEGER PRIMARY KEY, -- Unique ID for each account
    account_name VARCHAR(50), -- Name of the account
    account_type VARCHAR(50) -- Type of the account (e.g., Asset, Liability, Equity)
);

CREATE TABLE transactions (
    transaction_id INTEGER PRIMARY KEY, -- Unique ID for each transaction
    account_id INTEGER, -- ID of the account involved in the transaction
    transaction_date DATE, -- Date the transaction occurred
    amount DECIMAL(10,2), -- Amount of the transaction
    type VARCHAR(50) -- Type of the transaction (e.g., Debit, Credit)
);

CREATE TABLE products (
    product_id INTEGER PRIMARY KEY, -- Unique ID for each product
    name VARCHAR(50), -- Name of the product
    sale_price DECIMAL(10,2), -- Sale price of the product
    purchase_price DECIMAL(10,2) -- Purchase price of the product
);

CREATE TABLE sales (
    sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
    product_id INTEGER, -- ID of the product sold
    sale_date DATE, -- Date the sale occurred
    quantity_sold INTEGER -- Quantity of product sold
);

CREATE TABLE suppliers (
    supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
    name VARCHAR(50), -- Name of the supplier
    contact_info VARCHAR(100) -- Contact information of the supplier
);

CREATE TABLE product_suppliers (
    product_id INTEGER, -- Product ID supplied
    supplier_id INTEGER, -- Supplier ID who supplied the product
    purchase_price DECIMAL(10,2), -- Purchase price per unit charged by the supplier
    PRIMARY KEY (product_id, supplier_id),
    FOREIGN KEY (product_id) REFERENCES products(product_id),
    FOREIGN KEY (supplier_id) REFERENCES suppliers(supplier_id)
);
### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""

In [None]:
import sqlparse
# Beam search is memory-hungry: num_beams=4 ideally gives the best results, but the default
# stays at 2 because of memory constraints. Pass num_beams=4 if the GPU has headroom.
def generate_query(question, prompt_template=None, num_beams=2, max_new_tokens=400):
    """Translate a natural-language question into SQL with the loaded sqlcoder model.

    Args:
        question: The natural-language question.
        prompt_template: Template with a ``{question}`` placeholder. When None, the
            notebook-level ``prompt`` is used, looked up at call time, so reassigning
            ``prompt`` in a later cell takes effect without redefining this function.
        num_beams: Beam width passed to ``model.generate``.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The generated SQL, re-indented with ``sqlparse.format``.
    """
    template = prompt if prompt_template is None else prompt_template
    # Send inputs to wherever the model was placed rather than a hard-coded "cuda".
    inputs = tokenizer(template.format(question=question), return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=num_beams,
    )
    # The returned sequence starts with the prompt tokens. Decode only the continuation
    # instead of re-decoding everything and splitting on "[SQL]", which would lose text
    # if the model itself emitted that marker.
    prompt_length = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)

    # Release cached GPU blocks between calls; guarded so CPU-only runs do not fail.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return sqlparse.format(completion, reindent=True)

In [None]:
# Retail schema: ask which supplier supplied the most products (SQL is printed below).
question = "Which supplier provided the highest number of products?"
generated_sql = generate_query(question=question)

In [None]:
print(f"{generated_sql}")  # show the SQL generated for the supplier question


SELECT s.name,
       COUNT(ps.product_id) AS product_count
FROM suppliers s
JOIN product_suppliers ps ON s.supplier_id = ps.supplier_id
GROUP BY s.name
ORDER BY product_count DESC
LIMIT 1


In [None]:
# Retail schema: per-account balance up to a cut-off date.
question = "What is the balance of each account as of 2023-12-31?"
generated_sql = generate_query(question=question)
print(generated_sql)


SELECT a.account_id,
       a.account_name,
       a.account_type,
       SUM(t.amount) AS balance
FROM accounts a
JOIN transactions t ON a.account_id = t.account_id
WHERE t.transaction_date <= '2023-12-31'
GROUP BY a.account_id,
         a.account_name,
         a.account_type
ORDER BY a.account_type,
         a.account_id NULLS LAST;


In [None]:
# Retail schema: total purchase spend for 2024.
question = "What is total amount of money invested in the purchase of products in the year 2024?"
generated_sql = generate_query(question=question)
print(generated_sql)


SELECT SUM(p.purchase_price * s.quantity_sold) AS total_amount_spent
FROM products p
JOIN sales s ON p.product_id = s.product_id
WHERE EXTRACT(YEAR
              FROM s.sale_date) = 2024;


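## Medical Query

The cells below replace `prompt` with a healthcare-database template and reuse the already-loaded model to answer clinical questions.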

In [4]:
# Prompt template for sqlcoder on the healthcare schema.
# str.format fills both {question} placeholders. The schema is written as valid DDL
# (no trailing comma before a closing parenthesis) so the model sees well-formed SQL.
prompt = '''### Task
Generate an SQL query to answer the following question: [QUESTION]{question}[/QUESTION]

### Domain-Specific Notes (Healthcare)
- This query will interact with a healthcare database.
- Patient information is sensitive; ensure queries follow privacy guidelines and avoid returning personally identifiable information (PII) unnecessarily.
- Common metrics include patient counts, appointment statistics, treatment outcomes, and medication prescriptions.
- Health records often use dates (e.g., admission and discharge) and need accurate filtering.
- Clinical data may include attributes such as diagnosis codes (ICD-10), lab results, and prescribed medications.

### Database Schema
The healthcare database contains the following tables:

CREATE TABLE patients (
    patient_id INTEGER PRIMARY KEY,  -- Unique identifier for each patient
    first_name VARCHAR(50),  -- Patient's first name
    last_name VARCHAR(50),  -- Patient's last name
    date_of_birth DATE,  -- Patient's birth date
    gender VARCHAR(10),  -- Patient's gender
    address VARCHAR(100),  -- Patient's address (Avoid using in queries to protect privacy)
    phone_number VARCHAR(15)  -- Contact number (Avoid using in queries to protect privacy)
);

CREATE TABLE appointments (
    appointment_id INTEGER PRIMARY KEY,  -- Unique ID for each appointment
    patient_id INTEGER,  -- ID of the patient who booked the appointment
    doctor_id INTEGER,  -- ID of the doctor handling the appointment
    appointment_date DATE,  -- Date of the appointment
    purpose VARCHAR(100),  -- Reason for appointment (e.g., routine check-up, follow-up)
    status VARCHAR(20)  -- Status of the appointment (e.g., completed, cancelled, pending)
);

CREATE TABLE doctors (
    doctor_id INTEGER PRIMARY KEY,  -- Unique identifier for each doctor
    first_name VARCHAR(50),  -- Doctor's first name
    last_name VARCHAR(50),  -- Doctor's last name
    specialty VARCHAR(50)  -- Doctor's area of specialization (e.g., cardiology, pediatrics)
);

CREATE TABLE prescriptions (
    prescription_id INTEGER PRIMARY KEY,  -- Unique ID for each prescription
    patient_id INTEGER,  -- ID of the patient receiving the prescription
    doctor_id INTEGER,  -- ID of the doctor issuing the prescription
    medication_name VARCHAR(100),  -- Name of the medication prescribed
    dosage VARCHAR(50),  -- Dosage instructions
    prescription_date DATE  -- Date the prescription was issued
);

CREATE TABLE lab_results (
    result_id INTEGER PRIMARY KEY,  -- Unique ID for each lab result
    patient_id INTEGER,  -- ID of the patient for the lab test
    test_name VARCHAR(50),  -- Name of the lab test (e.g., blood test, MRI)
    test_date DATE,  -- Date the test was conducted
    result_value VARCHAR(50)  -- Result value (e.g., numerical values, positive/negative)
);

### Answer
Considering the above healthcare database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION].
[SQL]
'''

In [5]:
import sqlparse
# Same helper as in the retail section. It reads the notebook-level `prompt` at call time,
# so this redefinition is only needed when running the healthcare section on its own.
# num_beams=4 ideally gives the best results; the default stays at 2 because of memory limits.
def generate_query(question, prompt_template=None, num_beams=2, max_new_tokens=400):
    """Translate a natural-language question into SQL with the loaded sqlcoder model.

    Args:
        question: The natural-language question.
        prompt_template: Template with a ``{question}`` placeholder. When None, the
            notebook-level ``prompt`` is used, looked up at call time.
        num_beams: Beam width passed to ``model.generate``.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The generated SQL, re-indented with ``sqlparse.format``.
    """
    template = prompt if prompt_template is None else prompt_template
    # Send inputs to wherever the model was placed rather than a hard-coded "cuda".
    inputs = tokenizer(template.format(question=question), return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=num_beams,
    )
    # The returned sequence starts with the prompt tokens. Decode only the continuation
    # instead of re-decoding everything and splitting on "[SQL]".
    prompt_length = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)

    # Release cached GPU blocks between calls; guarded so CPU-only runs do not fail.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return sqlparse.format(completion, reindent=True)

In [7]:
# Healthcare schema: appointment count for one doctor over the last month.
question1 = "How many appointments did Dr. Mohan have in the last month?"
generated_sql = generate_query(question=question1)
print(generated_sql)


SELECT COUNT(a.appointment_id) AS number_of_appointments
FROM appointments a
JOIN doctors d ON a.doctor_id = d.doctor_id
WHERE d.first_name = 'Dr'
  AND d.last_name = 'Mohan'
  AND a.appointment_date >= (CURRENT_DATE - interval '1 month');


In [8]:
# Healthcare schema: patients prescribed medication after an abnormal lab result.
question2 = "Show a list of patients who were prescribed medications after a lab test result indicated abnormal levels."
generated_sql = generate_query(question=question2)
print(generated_sql)


SELECT p.first_name,
       p.last_name
FROM patients p
JOIN prescriptions pr ON p.patient_id = pr.patient_id
JOIN lab_results lr ON p.patient_id = lr.patient_id
WHERE lr.result_value = 'abnormal'
ORDER BY p.first_name,
         p.last_name NULLS LAST;


In [11]:
# Healthcare schema: prescription counts per cardiology doctor.
question3 = "Get a count of prescriptions written by each doctor in the cardiology department."
generated_sql = generate_query(question=question3)
print(generated_sql)


SELECT d.doctor_id,
       COUNT(p.prescription_id) AS number_of_prescriptions
FROM doctors d
JOIN prescriptions p ON d.doctor_id = p.doctor_id
WHERE d.specialty = 'cardiology'
GROUP BY d.doctor_id
ORDER BY number_of_prescriptions DESC NULLS LAST;


In [13]:
# Healthcare schema: upcoming appointments for patients older than 60.
question4 = "List all upcoming appointments for patients above the age of 60."
generated_sql = generate_query(question=question4)
print(generated_sql)


SELECT a.appointment_id,
       a.patient_id,
       p.first_name,
       p.last_name,
       a.appointment_date
FROM appointments a
JOIN patients p ON a.patient_id = p.patient_id
WHERE EXTRACT(YEAR
              FROM age(CURRENT_DATE, p.date_of_birth)) > 60
  AND a.appointment_date >= CURRENT_DATE
ORDER BY a.appointment_date NULLS LAST;


In [14]:
# Healthcare schema: average age of dermatology patients.
question5 = "What is the average age of patients who had appointments with doctors specializing in dermatology?"
generated_sql = generate_query(question=question5)
print(generated_sql)


SELECT AVG(EXTRACT(YEAR
                   FROM AGE(CURRENT_DATE, p.date_of_birth))) AS average_age
FROM patients p
JOIN appointments a ON p.patient_id = a.patient_id
JOIN doctors d ON a.doctor_id = d.doctor_id
WHERE d.specialty = 'dermatology';
