# Imports

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import datasets
import evaluate
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import sqlite3
import sqlparse
from tqdm.auto import tqdm
import time
import pickle

import psutil
import GPUtil

# Utils

In [2]:
def get_vm_usage_metrics():
    # CPU usage
    cpu_load = psutil.cpu_percent(interval=1, percpu=True)
    for id, load in enumerate(cpu_load):
        print(f"CPU {id} load: {load:.2f}")
    # RAM usage
    ram = psutil.virtual_memory()
    print(f"RAM Total: {ram.total/(1024**3):.2f} GB, Used: {(ram.used)/(1024**3):.2f} GB")
    # GPU
    if torch.cuda.is_available():
        gpus = GPUtil.getGPUs()
        for gpu in gpus:
            print(f"GPU {gpu.id} ({gpu.name}) load: {gpu.load*100}%")
            print(f"GPU {gpu.id} ({gpu.name}) VRAM Total: {gpu.memoryTotal} MB, Used {gpu.memoryUsed} MB")
    # Disk 
    disk = psutil.disk_usage('/')
    print(f"Disk Total: {disk.total/(1024**3):.2f} GB, Used: {(disk.used)/(1024**3):.2f} GB")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Device: {device}')
get_vm_usage_metrics()

Device: cuda
CPU 0 load: 4.00
CPU 1 load: 2.00
CPU 2 load: 0.00
CPU 3 load: 1.00
RAM Total: 27.41 GB, Used: 1.62 GB
GPU 0 (Tesla T4) load: 0.0%
GPU 0 (Tesla T4) VRAM Total: 16384.0 MB, Used 3.0 MB
Disk Total: 60.95 GB, Used: 35.58 GB


# Data

In [3]:
ds = datasets.load_dataset('gretelai/synthetic_text_to_sql', streaming=False)
ds_train, ds_test = ds['train'], ds['test']

ds_train

Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 100000
})

In [4]:
ds_train[0]

{'id': 5097,
 'domain': 'forestry',
 'domain_description': 'Comprehensive data on sustainable forest management, timber production, wildlife habitat, and carbon sequestration in forestry.',
 'sql_complexity': 'single join',
 'sql_complexity_description': 'only one join (specify inner, outer, cross)',
 'sql_task_type': 'analytics and reporting',
 'sql_task_type_description': 'generating reports, dashboards, and analytical insights',
 'sql_prompt': 'What is the total volume of timber sold by each salesperson, sorted by salesperson?',
 'sql_context': "CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');",
 'sql'

# Model

In [None]:
checkpoint = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype=torch.float16).to(device)
model.eval()

In [5]:
def construct_message(prompt, context):
    return [
        {"role": "system", "content": f"The user asks a question. Your task is to generate the SQL query to answer that question. Return SQL query only. The context of the question is the following: '{context}'"},
        {"role": "user", "content": prompt}
    ]

In [6]:
def generate_model_response_batch(messages_list, enable_thinking=True, max_new_tokens=512):
    texts = [
        tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking
        )
        for messages in messages_list
    ]

    model_inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        padding_side='left'
    ).to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens
    )

    responses = []
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids):
        # Slice to get only generated part
        output_only_ids = output_ids[len(input_ids):].tolist()

        # Try to find `</think>` (id 151668)
        try:
            index = len(output_only_ids) - output_only_ids[::-1].index(151668)
        except ValueError:
            index = 0

        if enable_thinking:
            thinking_content = tokenizer.decode(
                output_only_ids[:index],
                skip_special_tokens=True
            ).strip("\n")
            content = tokenizer.decode(
                output_only_ids[index:],
                skip_special_tokens=True
            ).strip("\n")
        else:
            thinking_content = None
            content = tokenizer.decode(
                output_only_ids,
                skip_special_tokens=True
            ).strip("\n")

        responses.append({
            'thinking_content': thinking_content,
            'content': content
        })

    return responses

# Evaluate

In [7]:
rouge = evaluate.load("rouge")

def normalize_sql(sql):
    return sqlparse.format(sql, reindent=True, keyword_case='upper').strip()

def compute_rouge(reference, prediction):
    result = rouge.compute(predictions=[prediction], references=[reference])
    return result['rougeL']

def evaluate_sql_response(reference, prediction, sql_context):
    # ROUGE-L
    rouge_score = compute_rouge(reference, prediction)
    
    # execution check
    try:
        conn = sqlite3.connect(":memory:")
        cursor = conn.cursor()
        
        cursor.executescript(sql_context)
        cursor.execute(reference)
        ref_result = cursor.fetchall()
        
        cursor.execute(prediction)
        model_result = cursor.fetchall()
        
        execution_match = ref_result == model_result
    except Exception:
        execution_match = False
    finally:
        conn.close()
    
    # final score
    if execution_match:
        final_score = 1.0
    else:
        final_score = 0.7 * rouge_score

    return {
        "rougeL": round(rouge_score, 4),
        "execution_match": execution_match,
        "final_score": final_score
    }

# Test

In [9]:
BATCH_SIZE = 32
ENABLE_THINKING = False
MAX_NEW_TOKENS = 512


prompts = [ds_test[id]['sql_prompt'] for id in range(len(ds_test))]
contexts = [ds_test[id]['sql_context'] for id in range(len(ds_test))]

responses = []
print(f"Start time: {time.ctime(time.time())}")
for i in tqdm(range(0, len(prompts), BATCH_SIZE)):
    batch_prompts = prompts[i : i + BATCH_SIZE]
    batch_contexts = contexts[i : i + BATCH_SIZE]

    messages_list = [
        construct_message(prompt=p, context=c)
        for p, c in zip(batch_prompts, batch_contexts)
    ]

    batch_responses = generate_model_response_batch(messages_list, enable_thinking=ENABLE_THINKING, max_new_tokens=MAX_NEW_TOKENS)

    responses.extend(batch_responses)

print(f"End time: {time.ctime(time.time())}")

Start time: Sun Oct 19 19:40:19 2025


  0%|          | 0/183 [00:00<?, ?it/s]

End time: Sun Oct 19 20:03:14 2025


In [10]:
references = [ds_test[id]['sql'] for id in range(len(ds_test))]
predictions = [responses[id]['content'] for id in range(len(ds_test))]

scores = [
    evaluate_sql_response(
        reference=reference,
        prediction=prediction,
        sql_context=context
    )
    for reference, prediction, context in tqdm(zip(references, predictions, contexts), total=len(ds_test))
]

print(f"Mean test set score: {np.mean([score['final_score'] for score in scores]):.3f}")

  0%|          | 0/5851 [00:00<?, ?it/s]

Mean test set score: 0.679


In [11]:
with open('test_eval_predictions.pkl', 'wb') as f:
    pickle.dump(predictions, f)

In [12]:
# with open('test_eval_predictions.pkl', 'rb') as f:
#     pred_test = pickle.load(f)

# pred_test[:10]

# 4-bit quantization

In [8]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)

checkpoint = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, 
    quantization_config=quantization_config,
    device_map='auto'
)

model.eval()

Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): Linear4bit(in_features=1024, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=1024, bias=False)
          (q_norm): Qwen3RMSNorm((128,), eps=1e-06)
          (k_norm): Qwen3RMSNorm((128,), eps=1e-06)
        )
        (mlp): Qwen3MLP(
          (gate_proj): Linear4bit(in_features=1024, out_features=3072, bias=False)
          (up_proj): Linear4bit(in_features=1024, out_features=3072, bias=False)
          (down_proj): Linear4bit(in_features=3072, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen3RMSNorm((1024,), eps=1e-06)
        (po

In [9]:
BATCH_SIZE = 32
ENABLE_THINKING = False
MAX_NEW_TOKENS = 512


prompts = [ds_test[id]['sql_prompt'] for id in range(len(ds_test))]
contexts = [ds_test[id]['sql_context'] for id in range(len(ds_test))]

responses = []
print(f"Start time: {time.ctime(time.time())}")
for i in tqdm(range(0, len(prompts), BATCH_SIZE)):
    batch_prompts = prompts[i : i + BATCH_SIZE]
    batch_contexts = contexts[i : i + BATCH_SIZE]

    messages_list = [
        construct_message(prompt=p, context=c)
        for p, c in zip(batch_prompts, batch_contexts)
    ]

    batch_responses = generate_model_response_batch(messages_list, enable_thinking=ENABLE_THINKING, max_new_tokens=MAX_NEW_TOKENS)

    responses.extend(batch_responses)

print(f"End time: {time.ctime(time.time())}")

Start time: Sun Oct 19 20:18:17 2025


  0%|          | 0/183 [00:00<?, ?it/s]

End time: Sun Oct 19 20:58:24 2025


In [10]:
references = [ds_test[id]['sql'] for id in range(len(ds_test))]
predictions = [responses[id]['content'] for id in range(len(ds_test))]

scores = [
    evaluate_sql_response(
        reference=reference,
        prediction=prediction,
        sql_context=context
    )
    for reference, prediction, context in tqdm(zip(references, predictions, contexts), total=len(ds_test))
]

print(f"Mean test set score: {np.mean([score['final_score'] for score in scores]):.3f}")

  0%|          | 0/5851 [00:00<?, ?it/s]

Mean test set score: 0.648
