# Imports

In [1]:
import numpy as np
import pandas as pd
import torch
import datasets
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    EarlyStoppingCallback
)
from transformers.trainer_utils import get_last_checkpoint
from trl import GRPOConfig, GRPOTrainer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import evaluate
import wandb
from datetime import datetime
import time
from tqdm.auto import tqdm
import sqlite3
import sqlparse
import _config

import os
import psutil
import GPUtil
import gc


# Set the verbosity to WARNING to suppress INFO messages
evaluate.logging.set_verbosity_warning()

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["WANDB_API_KEY"] = _config.WANDB_API_KEY
os.environ["WANDB_PROJECT"] = _config.WANDB_PROJECT

ENABLE_THINKING = False

2026-02-15 08:02:07.711658: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
[2026-02-15 08:02:14] INFO _client.py:1025: HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
[2026-02-15 08:02:14] INFO _client.py:1025: HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
[2026-02-15 08:02:14] INFO _client.py:1025: HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
[2026-02-15 08:02:14] INFO _client.py:1025: HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
[2026-02-15 08:02:15] INFO _client.py:1025: HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
[2026-02-15 08:02:15] INFO _client.py:1025: HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK

# Utils

In [2]:
def get_vm_usage_metrics():
    # CPU usage
    cpu_load = psutil.cpu_percent(interval=1, percpu=True)
    for id, load in enumerate(cpu_load):
        print(f"CPU {id} load: {load:.2f}")
    # RAM usage
    ram = psutil.virtual_memory()
    print(f"RAM Total: {ram.total/(1024**3):.2f} GB, Used: {(ram.used)/(1024**3):.2f} GB")
    # GPU
    if torch.cuda.is_available():
        gpus = GPUtil.getGPUs()
        for gpu in gpus:
            print(f"GPU {gpu.id} ({gpu.name}) load: {gpu.load*100}%")
            print(f"GPU {gpu.id} ({gpu.name}) VRAM Total: {gpu.memoryTotal} MB, Used {gpu.memoryUsed} MB")
    # Disk 
    disk = psutil.disk_usage('/')
    print(f"Disk Total: {disk.total/(1024**3):.2f} GB, Used: {(disk.used)/(1024**3):.2f} GB")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'Device: {device}')
get_vm_usage_metrics()

Device: cuda
CPU 0 load: 2.00
CPU 1 load: 3.00
CPU 2 load: 2.00
CPU 3 load: 0.00
RAM Total: 27.40 GB, Used: 2.29 GB
GPU 0 (Tesla T4) load: 0.0%
GPU 0 (Tesla T4) VRAM Total: 16384.0 MB, Used 3.0 MB
Disk Total: 60.95 GB, Used: 51.43 GB


In [3]:
def generate_model_response_batch(model, tokenizer, messages_list, enable_thinking=False, max_new_tokens=512):
    texts = [
        tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking
        )
        for messages in messages_list
    ]

    model_inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        padding_side='left'
    ).to(model.device)

    model.eval()
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens
    )

    responses = []
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids):
        # Slice to get only generated part
        output_only_ids = output_ids[len(input_ids):].tolist()

        # Try to find `</think>` (id 151668)
        try:
            index = len(output_only_ids) - output_only_ids[::-1].index(151668)
        except ValueError:
            index = 0

        if enable_thinking:
            thinking_content = tokenizer.decode(
                output_only_ids[:index],
                skip_special_tokens=True
            ).strip("\n")
            content = tokenizer.decode(
                output_only_ids[index:],
                skip_special_tokens=True
            ).strip("\n")
        else:
            thinking_content = None
            content = tokenizer.decode(
                output_only_ids,
                skip_special_tokens=True
            ).strip("\n")

        responses.append({
            'thinking_content': thinking_content,
            'content': content
        })

    return responses

# Data

In [4]:
ds = datasets.load_dataset('gretelai/synthetic_text_to_sql', streaming=False)
ds_train, ds_test = ds['train'], ds['test']

ds_train_subset = ds_train.train_test_split(test_size=0.1, seed=42)['test']
split = ds_train_subset.train_test_split(test_size=0.025, seed=42)
ds_train = split['train']
ds_valid = split['test']

ds_train

Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 9750
})

In [5]:
# datasets for GRPO must include a column "prompt"
def construct_message(example):
    return {"prompt": [
        {"role": "system", "content": f"The user asks a question. Your task is to generate the SQL query to answer that question. Return SQL query only. The context of the question is the following: '{example['sql_context']}'"},
        {"role": "user", "content": example['sql_prompt']}
    ]}
ds_train = ds_train.map(construct_message)
ds_valid = ds_valid.map(construct_message)

# rename the ground_truth column
ds_train = ds_train.rename_column("sql", "ground_truth")
ds_valid = ds_valid.rename_column("sql", "ground_truth")

ds_train

Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'ground_truth', 'sql_explanation', 'prompt'],
    num_rows: 9750
})

# Reward function

In [11]:
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "rm-output/best_model",
    dtype=torch.bfloat16,
    device_map="auto" if torch.cuda.is_available() else None
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", padding_side='left')


def get_reward_model_scores(completions, batch_size=8):
    scores = []
    for i in range(0, len(completions), batch_size):
        batch = completions[i:i+batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = reward_model(**inputs)
            logits = outputs.logits
            if logits.shape[-1] == 1:
                probs = torch.sigmoid(logits.squeeze(-1))
            else:
                probs = torch.softmax(logits, dim=-1)[:, 1]
            batch_scores = probs.cpu().float().tolist()
        scores.extend(batch_scores)
    return scores

[2026-02-15 08:03:25] INFO modeling.py:1004: We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


In [12]:
rouge = evaluate.load("rouge")

def normalize_sql(sql):
    return sqlparse.format(sql, reindent=True, keyword_case='upper').strip()

def compute_rouge(reference, prediction):
    result = rouge.compute(predictions=[prediction], references=[reference])
    return result['rougeL']


def evaluate_sql_response(reference, prediction, sql_context):
    # ROUGE-L
    rouge_score = compute_rouge(reference, prediction)

    for ref, pred,  in zip(reference, prediction):
        # execution check
        try:
            conn = sqlite3.connect(":memory:")
            cursor = conn.cursor()
            
            cursor.executescript(sql_context)
            cursor.execute(reference)
            ref_result = cursor.fetchall()
            
            cursor.execute(prediction)
            model_result = cursor.fetchall()
            
            execution_match = ref_result == model_result
        except Exception:
            execution_match = False
        finally:
            conn.close()
        
        # final score
        if execution_match:
            final_score = 1.0
        else:
            final_score = 0.7 * rouge_score

    return {
        "rougeL": round(rouge_score, 4),
        "execution_match": execution_match,
        "final_score": final_score
    }

In [13]:
def reward_func(completions, ground_truth, sql_context, **kwargs):
    scores = []
    rm_scores = []

    for reference, prediction, context in zip(ground_truth, completions, sql_context):
        result = evaluate_sql_response(
            reference=reference,
            prediction=prediction[0]["content"],
            sql_context=context
        )
        scores.append(result["final_score"])

    completion_texts = [c[0]["content"] for c in completions]
    rm_scores = get_reward_model_scores(completion_texts)

    combined_scores = [
        0.2 * rm + 0.8 * score
        for rm, score in zip(rm_scores, scores)
    ]
    return combined_scores

# Model

In [14]:
checkpoint = "Qwen/Qwen3-0.6B"

tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    # attn_implementation="sdpa",
    device_map="auto",
    dtype=torch.float16,
    # quantization_config=bnb_config
)

model.config.use_cache = False
model.gradient_checkpointing_enable()
# model = prepare_model_for_kbit_training(model)
model.enable_input_require_grads()

get_vm_usage_metrics()

[2026-02-15 08:03:32] INFO modeling.py:1004: We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


CPU 0 load: 0.00
CPU 1 load: 2.00
CPU 2 load: 1.00
CPU 3 load: 0.00
RAM Total: 27.40 GB, Used: 2.61 GB
GPU 0 (Tesla T4) load: 0.0%
GPU 0 (Tesla T4) VRAM Total: 16384.0 MB, Used 4007.0 MB
Disk Total: 60.95 GB, Used: 51.43 GB


# GRPO

In [15]:
torch.cuda.empty_cache()

timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
# timestamp = '2026-01-25_09-24-49'
RUN_NAME = f'grpo-rm-lr1e8-epochs1-{timestamp}'
OUTPUT_DIR = './grpo-rm-lr1e8-output'
RESUME_TRAINING = True

PER_DEVICE_BATCH_SIZE = 8
effective_batch_size = 16
epochs=1
learning_rate = 1e-8 # changed from 1e-5
warmup_ratio = 0.1
lora_r = 16*4
lora_alpha = 64*4
lora_dropout = 0.01

gradient_accumulation_steps = int(effective_batch_size / PER_DEVICE_BATCH_SIZE)

wandb.init(
    project=os.environ["WANDB_PROJECT"],
    name=RUN_NAME,
    # id='yrm8qwl9' ,         # resume previous run if available
    # resume="allow",    # allows resuming crashed run
)



training_args = GRPOConfig(
    output_dir=OUTPUT_DIR,
    
    chat_template_kwargs = {"enable_thinking": False},
    num_train_epochs=epochs,
    num_generations=8,
    # use_liger_kernel=True,
    
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=learning_rate,
    lr_scheduler_type="cosine",
    warmup_ratio=warmup_ratio,
    save_strategy="steps",
    save_steps=30,
    save_total_limit=2,
    eval_strategy="steps",
    eval_steps=30,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE*4,
    eval_accumulation_steps=1,
    # eval_kwargs={"num_generations": 1},
    num_generations_eval=1,
    logging_strategy="steps",
    logging_steps=30,
    report_to=['wandb'],
    run_name=RUN_NAME,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    max_grad_norm=1,
    load_best_model_at_end=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # generate_during_eval=True
)

peft_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules='all-linear'
)
# model.requires_grad_(False)                     # freeze base weights (precautionary)
model_peft = get_peft_model(model, peft_config) # inject a LoRA adapter

trainer = GRPOTrainer(
    processing_class=tokenizer,
    model=model_peft,
    args=training_args,
    reward_funcs=[reward_func],
    train_dataset=ds_train,
    eval_dataset=ds_valid,
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)


# Training setup summary
dataset_size = len(ds_train)
steps_per_epoch = dataset_size // (PER_DEVICE_BATCH_SIZE * gradient_accumulation_steps)
total_steps = steps_per_epoch * epochs
warmup_steps = int(total_steps * warmup_ratio)

print("===== Training Setup Summary =====")
print(f"Num epochs:            {epochs}")
print(f"Effective batch size:  {effective_batch_size}")
print(f"Per-device batch size: {PER_DEVICE_BATCH_SIZE}")
print(f"Gradient accumulation: {gradient_accumulation_steps}")
print(f"Dataset size:          {dataset_size}")
print(f"Steps per epoch:       {steps_per_epoch}")
print(f"Total training steps:  {total_steps}")
print(f"Warmup steps:          {warmup_steps}")
print(f"Logging steps:         {training_args.logging_steps}")
print("===================================")
print(f"Start time: {datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")


# Training
last_checkpoint = None
if RESUME_TRAINING and os.path.isdir(OUTPUT_DIR):
    last_checkpoint = get_last_checkpoint(OUTPUT_DIR)

if last_checkpoint is not None:
    print(f"Resuming training from checkpoint: {last_checkpoint}")
    trainer.train(resume_from_checkpoint=last_checkpoint)
else:
    print("Starting fresh training run")
    trainer.train()

print(f"End time: {datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")

[2026-02-15 08:03:35] INFO spawn.py:77: x86_64-linux-gnu-gcc -fno-strict-overflow -Wsign-compare -DNDEBUG -g -O2 -Wall -fPIC -c /tmp/tmp4djfpbrg/test.c -o /tmp/tmp4djfpbrg/test.o
[2026-02-15 08:03:35] INFO spawn.py:77: x86_64-linux-gnu-gcc /tmp/tmp4djfpbrg/test.o -laio -o /tmp/tmp4djfpbrg/a.out
/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
[2026-02-15 08:03:35] INFO spawn.py:77: x86_64-linux-gnu-gcc -fno-strict-overflow -Wsign-compare -DNDEBUG -g -O2 -Wall -fPIC -c /tmp/tmp8x8q3h_8/test.c -o /tmp/tmp8x8q3h_8/test.o
[2026-02-15 08:03:35] INFO spawn.py:77: x86_64-linux-gnu-gcc /tmp/tmp8x8q3h_8/test.o -L/usr -L/usr/lib64 -lcufile -o /tmp/tmp8x8q3h_8/a.out
/usr/bin/ld: cannot find -lcufile: No such file or directory
collect2: error: ld returned 1 exit status
The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation

===== Training Setup Summary =====
Num epochs:            1
Effective batch size:  16
Per-device batch size: 8
Gradient accumulation: 2
Dataset size:          9750
Steps per epoch:       609
Total training steps:  609
Warmup steps:          60
Logging steps:         30
Start time: 2026-02-15_08-03-36
Starting fresh training run


[2026-02-15 08:03:40] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:40] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:40] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:40] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:40] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:41] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:41] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:41] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:41] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:41] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:41] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:41] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:41] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:42] INFO rouge_scorer.py:83: Using default tok

Step,Training Loss,Validation Loss
30,-0.0145,0.0
60,0.0034,0.0
90,0.0098,0.0
120,0.0494,0.0
150,0.0041,0.0
180,-0.0041,0.0
210,0.0123,0.0
240,0.0196,0.0
270,0.0108,0.0
300,-0.0165,0.0


[2026-02-15 08:03:51] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:51] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:51] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:51] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:51] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:51] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:52] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:52] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:52] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:52] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:52] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:52] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:52] INFO rouge_scorer.py:83: Using default tokenizer.
[2026-02-15 08:03:52] INFO rouge_scorer.py:83: Using default tok

KeyboardInterrupt: 

In [None]:
# model.save_pretrained(f"{OUTPUT_DIR}/best_model")

# Test

In [None]:
OUTPUT_DIR = './sft-grpo-lr1e6-ngen4-output'
checkpoint = f"{OUTPUT_DIR}/checkpoint-4875/"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype=torch.float16).to(device)
model.eval()

ds = datasets.load_dataset('gretelai/synthetic_text_to_sql', streaming=False)
ds_train, ds_test = ds['train'], ds['test']


def construct_message(prompt, context):
    return [
        {"role": "system", "content": f"The user asks a question. Your task is to generate the SQL query to answer that question. Return SQL query only. The context of the question is the following: '{context}'"},
        {"role": "user", "content": prompt}
    ]

def generate_model_response_batch(messages_list, enable_thinking=True, max_new_tokens=512):
    texts = [
        tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=enable_thinking
        )
        for messages in messages_list
    ]

    model_inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        padding_side='left'
    ).to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens
    )

    responses = []
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids):
        # Slice to get only generated part
        output_only_ids = output_ids[len(input_ids):].tolist()

        # Try to find `</think>` (id 151668)
        try:
            index = len(output_only_ids) - output_only_ids[::-1].index(151668)
        except ValueError:
            index = 0

        if enable_thinking:
            thinking_content = tokenizer.decode(
                output_only_ids[:index],
                skip_special_tokens=True
            ).strip("\n")
            content = tokenizer.decode(
                output_only_ids[index:],
                skip_special_tokens=True
            ).strip("\n")
        else:
            thinking_content = None
            content = tokenizer.decode(
                output_only_ids,
                skip_special_tokens=True
            ).strip("\n")

        responses.append({
            'thinking_content': thinking_content,
            'content': content
        })

    return responses


rouge = evaluate.load("rouge")

def normalize_sql(sql):
    return sqlparse.format(sql, reindent=True, keyword_case='upper').strip()

def compute_rouge(reference, prediction):
    result = rouge.compute(predictions=[prediction], references=[reference])
    return result['rougeL']

def evaluate_sql_response(reference, prediction, sql_context):
    # ROUGE-L
    rouge_score = compute_rouge(reference, prediction)
    
    # execution check
    try:
        conn = sqlite3.connect(":memory:")
        cursor = conn.cursor()
        
        cursor.executescript(sql_context)
        cursor.execute(reference)
        ref_result = cursor.fetchall()
        
        cursor.execute(prediction)
        model_result = cursor.fetchall()
        
        execution_match = ref_result == model_result
    except Exception:
        execution_match = False
    finally:
        conn.close()
    
    # final score
    if execution_match:
        final_score = 1.0
    else:
        final_score = 0.7 * rouge_score

    return {
        "rougeL": round(rouge_score, 4),
        "execution_match": execution_match,
        "final_score": final_score
    }

In [None]:
BATCH_SIZE = 32
ENABLE_THINKING = False
MAX_NEW_TOKENS = 512


prompts = [ds_test[id]['sql_prompt'] for id in range(len(ds_test))]
contexts = [ds_test[id]['sql_context'] for id in range(len(ds_test))]

responses = []
print(f"Start time: {time.ctime(time.time())}")
for i in tqdm(range(0, len(prompts), BATCH_SIZE)):
    batch_prompts = prompts[i : i + BATCH_SIZE]
    batch_contexts = contexts[i : i + BATCH_SIZE]

    messages_list = [
        construct_message(prompt=p, context=c)
        for p, c in zip(batch_prompts, batch_contexts)
    ]

    batch_responses = generate_model_response_batch(messages_list, enable_thinking=ENABLE_THINKING, max_new_tokens=MAX_NEW_TOKENS)

    responses.extend(batch_responses)

print(f"End time: {time.ctime(time.time())}")

In [None]:
references = [ds_test[id]['sql'] for id in range(len(ds_test))]
predictions = [responses[id]['content'] for id in range(len(ds_test))]

scores = [
    evaluate_sql_response(
        reference=reference,
        prediction=prediction,
        sql_context=context
    )
    for reference, prediction, context in tqdm(zip(references, predictions, contexts), total=len(ds_test))
]

In [None]:
print(f"Mean test set score: {np.mean([score['final_score'] for score in scores]):.3f}")