# Imports

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import datasets
import evaluate
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig, EarlyStoppingCallback
from transformers.trainer_utils import get_last_checkpoint
from trl import SFTTrainer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate import cpu_offload
import sqlite3
import sqlparse
from tqdm.auto import tqdm
from datetime import datetime
import pickle
import wandb
import psutil
import GPUtil
import os
import gc
import math

import _config

In [2]:
ENABLE_THINKING = False

In [3]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

os.environ["WANDB_API_KEY"] = _config.WANDB_API_KEY
os.environ["WANDB_PROJECT"] = _config.WANDB_PROJECT

# Utils

## General

In [4]:
def get_vm_usage_metrics():
    # CPU usage
    cpu_load = psutil.cpu_percent(interval=1, percpu=True)
    for id, load in enumerate(cpu_load):
        print(f"CPU {id} load: {load:.2f}")
    # RAM usage
    ram = psutil.virtual_memory()
    print(f"RAM Total: {ram.total/(1024**3):.2f} GB, Used: {(ram.used)/(1024**3):.2f} GB")
    # GPU
    if torch.cuda.is_available():
        gpus = GPUtil.getGPUs()
        for gpu in gpus:
            print(f"GPU {gpu.id} ({gpu.name}) load: {gpu.load*100}%")
            print(f"GPU {gpu.id} ({gpu.name}) VRAM Total: {gpu.memoryTotal} MB, Used {gpu.memoryUsed} MB")
    # Disk 
    disk = psutil.disk_usage('/')
    print(f"Disk Total: {disk.total/(1024**3):.2f} GB, Used: {(disk.used)/(1024**3):.2f} GB")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Device: {device}')
get_vm_usage_metrics()

Device: cuda
CPU 0 load: 0.00
CPU 1 load: 2.00
CPU 2 load: 0.00
CPU 3 load: 0.00
RAM Total: 27.41 GB, Used: 1.48 GB
GPU 0 (Tesla T4) load: 0.0%
GPU 0 (Tesla T4) VRAM Total: 16384.0 MB, Used 3.0 MB
Disk Total: 60.95 GB, Used: 41.14 GB


In [5]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"Trainable params: {trainable_params} || All params: {all_param} || Trainable %: {100 * trainable_params / all_param:.2f}"
    )

## Model

In [6]:
def construct_message(prompt, context):
    return [
        {"role": "system", "content": f"The user asks a question. Your task is to generate the SQL query to answer that question. Return SQL query only. The context of the question is the following: '{context}'"},
        {"role": "user", "content": prompt}
    ]

In [7]:
def generate_model_response_batch(model, tokenizer, messages_list, enable_thinking=True, max_new_tokens=512):
    texts = [
        tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking
        )
        for messages in messages_list
    ]

    model_inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        padding_side='left'
    ).to(model.device)

    model.eval()
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens
    )

    responses = []
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids):
        # Slice to get only generated part
        output_only_ids = output_ids[len(input_ids):].tolist()

        # Try to find `</think>` (id 151668)
        try:
            index = len(output_only_ids) - output_only_ids[::-1].index(151668)
        except ValueError:
            index = 0

        if enable_thinking:
            thinking_content = tokenizer.decode(
                output_only_ids[:index],
                skip_special_tokens=True
            ).strip("\n")
            content = tokenizer.decode(
                output_only_ids[index:],
                skip_special_tokens=True
            ).strip("\n")
        else:
            thinking_content = None
            content = tokenizer.decode(
                output_only_ids,
                skip_special_tokens=True
            ).strip("\n")

        responses.append({
            'thinking_content': thinking_content,
            'content': content
        })

    return responses

## Formatting functions

In [8]:
# used for training
def construct_message_with_assistant_content(example):
    messages = construct_message(example['sql_prompt'], example['sql_context'])
    messages.append({
        'role': 'assistant',
        'content': example['sql']
    })
    return {'messages': messages}

In [9]:
def formatting_func(example, enable_thinking=ENABLE_THINKING):
    return tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
        add_generation_prompt=False, # no generation prompt during training
        enable_thinking=ENABLE_THINKING 
    )

## Evaluate

In [10]:
rouge = evaluate.load("rouge")

def normalize_sql(sql):
    return sqlparse.format(sql, reindent=True, keyword_case='upper').strip()

def compute_rouge(reference, prediction):
    result = rouge.compute(predictions=[prediction], references=[reference])
    return result['rougeL']

def evaluate_sql_response(reference, prediction, sql_context):
    # ROUGE-L
    rouge_score = compute_rouge(reference, prediction)
    
    # execution check
    try:
        conn = sqlite3.connect(":memory:")
        cursor = conn.cursor()
        
        cursor.executescript(sql_context)
        cursor.execute(reference)
        ref_result = cursor.fetchall()
        
        cursor.execute(prediction)
        model_result = cursor.fetchall()
        
        execution_match = ref_result == model_result
    except Exception:
        execution_match = False
    finally:
        conn.close()
    
    # final score
    if execution_match:
        final_score = 1.0
    else:
        final_score = 0.7 * rouge_score

    return {
        "rougeL": round(rouge_score, 4),
        "execution_match": execution_match,
        "final_score": final_score
    }

# Data

In [11]:
ds = datasets.load_dataset('gretelai/synthetic_text_to_sql', streaming=False)
ds_train, ds_test = ds['train'], ds['test']

split = ds_train.train_test_split(test_size=0.025, seed=42)
ds_train = split['train']
ds_valid = split['test']

ds_train

Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 97500
})

# QLoRA

In [None]:
print(len(ds_train), len(ds_valid), len(ds_test))

ds_train_with_assistant_content = ds_train.map(construct_message_with_assistant_content)
ds_valid_with_assistant_content = ds_valid.map(construct_message_with_assistant_content)

get_vm_usage_metrics()


bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

checkpoint = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map="auto",
    quantization_config=bnb_config
)

model.config.use_cache = False
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
model.enable_input_require_grads()




timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

# resuming the prev run
timestamp = '2025-10-23_05-59-42'
RUN_NAME = f'qlora-final-model-all-linear-r64-{timestamp}'
run_id = 'cj9mbl11'
wandb.init(
    project=os.environ["WANDB_PROJECT"],
    name=RUN_NAME,
    id=run_id,         # resume previous run if available
    resume="allow",    # allows resuming crashed run
)


RESUME_TRAINING = True
OUTPUT_DIR = "./qlora-final_model_all_linear_r64-output"
PER_DEVICE_BATCH_SIZE = 2  # higher values --> OOM

optimizer = 'paged_adamw_8bit'#'nadam'
effective_batch_size = 16
learning_rate = 1e-5
weight_decay = 0.0
betas = (0.9, 0.9999)
warmup_ratio = 0.2
epochs = 1
gradient_accumulation_steps = int(effective_batch_size / PER_DEVICE_BATCH_SIZE)
lora_r = 16*4
lora_alpha = 64*4
lora_dropout = 0.01


training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=learning_rate,
    optim=optimizer,   # better for 4-bit models
    num_train_epochs=epochs,
    weight_decay=weight_decay,
    lr_scheduler_type="cosine",
    warmup_ratio=warmup_ratio,
    save_strategy="steps",
    save_steps=gradient_accumulation_steps*5,
    save_total_limit=2,
    eval_strategy="steps",
    eval_steps=gradient_accumulation_steps*5,
    logging_strategy="steps",
    logging_steps=gradient_accumulation_steps*5,
    report_to=['wandb'],
    run_name=RUN_NAME,
    # bf16=True,
    fp16=True,
    # fp16_full_eval=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    max_grad_norm=1,
    load_best_model_at_end=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False}
)


peft_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules='all-linear'
)
# model.requires_grad_(False)                     # freeze base weights (precautionary)
model_peft = get_peft_model(model, peft_config) # inject a LoRA adapter
print_trainable_parameters(model_peft)

trainer = SFTTrainer(
    model=model_peft,
    train_dataset=ds_train_with_assistant_content,
    eval_dataset=ds_valid_with_assistant_content,
    formatting_func=formatting_func,
    args=training_args,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=25)]
)


# Training setup summary
dataset_size = len(ds_train_with_assistant_content)
steps_per_epoch = dataset_size // (PER_DEVICE_BATCH_SIZE * gradient_accumulation_steps)
total_steps = steps_per_epoch * epochs
warmup_steps = int(total_steps * warmup_ratio)

print("===== Training Setup Summary =====")
print(f"Num epochs:            {epochs}")
print(f"Effective batch size:  {effective_batch_size}")
print(f"Per-device batch size: {PER_DEVICE_BATCH_SIZE}")
print(f"Gradient accumulation: {gradient_accumulation_steps}")
print(f"Dataset size:          {dataset_size}")
print(f"Steps per epoch:       {steps_per_epoch}")
print(f"Total training steps:  {total_steps}")
print(f"Warmup steps:          {warmup_steps}")
print(f"Logging steps:         {training_args.logging_steps}")
print("===================================")
print(f"Start time: {datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")


# Training
last_checkpoint = None
if RESUME_TRAINING and os.path.isdir(OUTPUT_DIR):
    last_checkpoint = get_last_checkpoint(OUTPUT_DIR)

if last_checkpoint is not None:
    print(f"Resuming training from checkpoint: {last_checkpoint}")
    trainer.train(resume_from_checkpoint=last_checkpoint)
else:
    print("Starting fresh training run")
    trainer.train()

print(f"End time: {datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")


# WandB logging of eval metrics
for log in trainer.state.log_history:
    if 'eval_loss' in log:
        wandb.log({
            "eval_loss": log['eval_loss'],
            "eval_perplexity": math.exp(log['eval_loss']),
            "step": log['step'],
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "betas": betas,
            "warmup_ratio": warmup_ratio,
            "effective_batch_size": effective_batch_size,
            "optimizer": optimizer
        })

wandb.finish()  # finish the run 

97500 2500 5851
CPU 0 load: 0.00
CPU 1 load: 0.00
CPU 2 load: 0.00
CPU 3 load: 0.00
RAM Total: 27.41 GB, Used: 1.49 GB
GPU 0 (Tesla T4) load: 0.0%
GPU 0 (Tesla T4) VRAM Total: 16384.0 MB, Used 3.0 MB
Disk Total: 60.95 GB, Used: 40.97 GB


[34m[1mwandb[0m: Currently logged in as: [33molialeshka[0m ([33molialeshka-none[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Trainable params: 40370176 || All params: 416219136 || Trainable %: 9.70


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


===== Training Setup Summary =====
Num epochs:            1
Effective batch size:  16
Per-device batch size: 2
Gradient accumulation: 8
Dataset size:          97500
Steps per epoch:       6093
Total training steps:  6093
Warmup steps:          1218
Logging steps:         40
Start time: 2025-10-23_12-48-13
Resuming training from checkpoint: ./qlora-final_model_all_linear_r64-output/checkpoint-2680


Step,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
2720,0.4267,0.432508,0.426978,123655.0,0.884865
2760,0.4345,0.431839,0.438531,248878.0,0.884937
2800,0.4217,0.431473,0.430893,376696.0,0.884978
2840,0.4227,0.431305,0.42644,502510.0,0.88496
2880,0.4262,0.431643,0.428791,625112.0,0.884645
2920,0.429,0.430185,0.425335,751874.0,0.88484
2960,0.4227,0.429823,0.430036,878979.0,0.885062
3000,0.4265,0.429208,0.429833,1001192.0,0.885091
3040,0.4285,0.428892,0.429349,1126073.0,0.884924
3080,0.4249,0.428403,0.432246,1244931.0,0.885061


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [13]:
model_path = os.path.join(OUTPUT_DIR, 'final')
trainer.save_model(model_path)

## Test

In [12]:
torch.cuda.empty_cache()
get_vm_usage_metrics()

CPU 0 load: 0.00
CPU 1 load: 1.00
CPU 2 load: 1.00
CPU 3 load: 0.00
RAM Total: 27.41 GB, Used: 1.51 GB
GPU 0 (Tesla T4) load: 0.0%
GPU 0 (Tesla T4) VRAM Total: 16384.0 MB, Used 3.0 MB
Disk Total: 60.95 GB, Used: 41.14 GB


In [13]:
model_path = './qlora-final_model_all_linear_r64-output/final'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path).to(device)

print_trainable_parameters(model)
model.eval()

Trainable params: 0 || All params: 636420096 || Trainable %: 0.00


Qwen3ForCausalLM(
  (model): Qwen3Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-27): 28 x Qwen3DecoderLayer(
        (self_attn): Qwen3Attention(
          (q_proj): lora.Linear(
            (base_layer): Linear(in_features=1024, out_features=2048, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.01, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=1024, out_features=64, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=64, out_features=2048, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear(
            (base_layer): Linear(in_features=1024, out_features=1024, bias=False)
            (lora_dropout): ModuleDict(
              (default

In [None]:
BATCH_SIZE = 32
ENABLE_THINKING = False
MAX_NEW_TOKENS = 512


prompts = [ds_test[id]['sql_prompt'] for id in range(len(ds_test))]
contexts = [ds_test[id]['sql_context'] for id in range(len(ds_test))]

responses = []
print(f"Start time: {datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")
for i in tqdm(range(0, len(prompts), BATCH_SIZE)):
    batch_prompts = prompts[i : i + BATCH_SIZE]
    batch_contexts = contexts[i : i + BATCH_SIZE]

    messages_list = [
        construct_message(prompt=p, context=c)
        for p, c in zip(batch_prompts, batch_contexts)
    ]

    batch_responses = generate_model_response_batch(model, tokenizer, messages_list, enable_thinking=ENABLE_THINKING, max_new_tokens=MAX_NEW_TOKENS)

    responses.extend(batch_responses)

print(f"End time: {datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")

Start time: 2025-10-24_10-46-59


  0%|          | 0/183 [00:00<?, ?it/s]

In [None]:
references = [ds_test[id]['sql'] for id in range(len(ds_test))]
predictions = [responses[id]['content'] for id in range(len(ds_test))]

scores = [
    evaluate_sql_response(
        reference=reference,
        prediction=prediction,
        sql_context=context
    )
    for reference, prediction, context in tqdm(zip(references, predictions, contexts), total=len(ds_test))
]

In [16]:
print(f"Mean test set score: {np.mean([score['final_score'] for score in scores]):.3f}")

Mean test set score: 0.755
