In [1]:
import os, re
import numpy as np
from dotenv import load_dotenv, find_dotenv
from tqdm import tqdm

assert load_dotenv(find_dotenv())

from random import randint
from huggingface_hub import login
import sqlite3

# Login into Hugging Face Hub
hf_token = os.environ['HF_TOKEN'] 
login(hf_token)

from datasets import load_dataset


import torch
from transformers import (pipeline,
    AutoTokenizer, AutoModelForCausalLM,
    AutoModelForImageTextToText, BitsAndBytesConfig
)
from peft import LoraConfig, PeftModel
from trl import SFTConfig, SFTTrainer

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [2]:
# System message for the assistant 
system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."""

# User prompt that combines the user query and the schema
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
{context}
</SCHEMA>

<USER_QUERY>
{question}
</USER_QUERY>
"""
def create_conversation(sample):
  return {
    "messages": [
      # {"role": "system", "content": system_message},
      {"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])},
      {"role": "assistant", "content": sample["sql"]}
    ]
  }  

# Load dataset from the hub
dataset = load_dataset("philschmid/gretel-synthetic-text-to-sql", split="train")
dataset = dataset.shuffle().select(range(12500))

# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, remove_columns=dataset.features,batched=False)
# split dataset into 10,000 training samples and 2,500 test samples
dataset = dataset.train_test_split(test_size=2500/12500)

# Print formatted expected answer
idx = randint(0, 10000)
print(dataset["train"][idx]["messages"][1]["content"])


Map:   0%|          | 0/12500 [00:00<?, ? examples/s]

SELECT species_name, depth, habitat FROM (SELECT species_name, depth, habitat, MAX(depth) OVER (PARTITION BY ocean) AS max_depth FROM southern_ocean_depths WHERE ocean = 'Southern Ocean') t WHERE depth = max_depth;


In [3]:
# Hugging Face model id
model_id = "google/gemma-3-1b-pt" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`

# Select model class based on id
if model_id == "google/gemma-3-1b-pt":
    model_class = AutoModelForCausalLM
else:
    model_class = AutoModelForImageTextToText

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

# Define model init arguments
model_kwargs = dict(
    attn_implementation="flash_attention_2", # Use "eager" when NOT running on Ampere or newer GPU
    torch_dtype=torch_dtype, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
    bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)

# Load model and tokenizer
model = model_class.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it") # Load the Instruction Tokenizer to use the official Gemma template


In [4]:
ds_list = dataset["train"].to_list()

tokens = [
    len(
        tokenizer(
            training_sample["messages"][0]['content']
        )["input_ids"]
    )
    for training_sample in ds_list
]

print(f"Maximum Number of tokens in the training dataset: {max(tokens)}")
print(f"Mean Number of tokens in the training dataset: {np.mean(tokens)}")

Maximum Number of tokens in the training dataset: 560
Mean Number of tokens in the training dataset: 171.2988


In [5]:
for training_sample in ds_list[:2]:
    print(f"Query:\n{training_sample["messages"][0]['content']}")

Query:
Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
CREATE TABLE UnionInfo (UnionName VARCHAR(50), HeadquarterCountry VARCHAR(50), Members VARINT); INSERT INTO UnionInfo (UnionName, HeadquarterCountry, Members) VALUES ('UnionA', 'USA', 70000), ('UnionB', 'Canada', 45000), ('UnionC', 'Mexico', 60000);
</SCHEMA>

<USER_QUERY>
What are the names of unions and their respective headquarters' countries for unions with more than 50,000 members?
</USER_QUERY>

Query:
Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
CREATE TABLE waste_generation (state VARCHAR(20), year INT, quantity INT); CREATE TABLE recycling_rates (state VARCHAR(20), year INT, recycling_rate DECIMAL(5,2)); INSERT INTO waste_generation VALUES ('Texas

In [6]:
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"] # make sure to save the lm_head and embed_tokens as you train the special tokens
)

In [7]:
args = SFTConfig(
    output_dir="gemma-text-to-sql",         # directory to save and repository id
    max_seq_length=1024,                     # max input sequence length for model and packing of the dataset
    packing=True,                           # Groups multiple samples in the dataset into a single sequence
    num_train_epochs=3,                     # number of training epochs
    per_device_train_batch_size=1,          # batch size per device during training
    gradient_accumulation_steps=4,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    fp16=True if torch_dtype == torch.float16 else False,   # use float16 precision
    bf16=True if torch_dtype == torch.bfloat16 else False,   # use bfloat16 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    dataset_kwargs={
        "add_special_tokens": False, # We template with special tokens
        "append_concat_token": True, # Add EOS token as separator token between examples
    }
)

In [8]:
# Create Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    peft_config=peft_config,
    processing_class=tokenizer
)


Converting train dataset to ChatML:   0%|          | 0/10000 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/10000 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/10000 [00:00<?, ? examples/s]

Packing train dataset:   0%|          | 0/10000 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [9]:
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model()


It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `flash_attention_2`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
10,1.2843
20,0.7353
30,0.6008
40,0.5932
50,0.5542
60,0.535
70,0.5232
80,0.516
90,0.5211
100,0.5059


No files have been modified since last commit. Skipping to prevent empty commit.


In [10]:
# free the memory again
del model
del trainer
torch.cuda.empty_cache()

In [11]:
model_id_ft = "gemma-text-to-sql"

# Load Model with PEFT adapter
model = model_class.from_pretrained(
  model_id_ft,
  device_map="auto",
  torch_dtype=torch_dtype,
  attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_id_ft)

# Hugging Face model id
model_id = "google/gemma-3-1b-pt" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`

# Select model class based on id
if model_id == "google/gemma-3-1b-pt":
    model_class = AutoModelForCausalLM
else:
    model_class = AutoModelForImageTextToText

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

# Define model init arguments
model_kwargs = dict(
    attn_implementation="flash_attention_2", # Use "eager" when NOT running on Ampere or newer GPU
    torch_dtype=torch_dtype, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
    bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)

# Load model and tokenizer
model_og = model_class.from_pretrained(model_id, **model_kwargs)
tokenizer_og = AutoTokenizer.from_pretrained("google/gemma-3-1b-it") # Load the Instruction Tokenizer to use the official Gemma template

# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
pipe_og = pipeline("text-generation", model=model_og, tokenizer=tokenizer_og)

Device set to use cuda:0
Device set to use cuda:0


In [55]:
# Load a random sample from the test dataset
rand_idx = randint(0, len(dataset["test"]))
test_sample = dataset["test"][rand_idx]

user_request = test_sample["messages"][:1]
# Convert as test example into a prompt with the Gemma template
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(user_request, tokenize=False, add_generation_prompt=True)

# Generate our SQL query.
outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)

# Convert as test example into a prompt with the Gemma template
stop_token_ids_og = [tokenizer_og.eos_token_id, tokenizer_og.convert_tokens_to_ids("<end_of_turn>")]
prompt_og = pipe_og.tokenizer.apply_chat_template(user_request, tokenize=False, add_generation_prompt=True)
# Generate our SQL query.
outputs_og = pipe_og(prompt_og, max_new_tokens=256, do_sample=True, temperature=0.1, top_k=50, top_p=0.1, disable_compile=True)

# Extract the user query and original answer
print(f"Context:\n", re.search(r'<SCHEMA>\n(.*?)\n</SCHEMA>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Query:\n", re.search(r'<USER_QUERY>\n(.*?)\n</USER_QUERY>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print("\n###################################################################################\n""")
print(f"Original Answer:\n{test_sample['messages'][1]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
print(f"Original Model:\n{outputs_og[0]['generated_text'][len(prompt_og):].strip()}")

Context:
 CREATE TABLE urban_agriculture (country VARCHAR(255), crop VARCHAR(255)); INSERT INTO urban_agriculture (country, crop) VALUES ('Canada', 'Tomatoes'), ('Canada', 'Lettuce'), ('Canada', 'Cucumbers');
Query:
 List all the distinct crops grown in urban agriculture systems in Canada.

###################################################################################

Original Answer:
SELECT DISTINCT crop FROM urban_agriculture WHERE country = 'Canada'
Generated Answer:
SELECT DISTINCT crop FROM urban_agriculture WHERE country = 'Canada';
Original Model:
Given the <USER_QUERY>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<USER_QUERY>
List all the distinct crops grown in urban agriculture systems in Canada.
</USER_QUERY>

<SCHEMA>
CREATE TABLE urban_agriculture (country VARCHAR(255), crop VARCHAR(255)); INSERT INTO urban_agriculture (country, crop) VALUES ('Canada', 'Tomatoes'), ('Canada',

In [14]:
def validate_sql_syntax(cursor, query):
    try:
        cursor.execute(f"EXPLAIN {query}")  # Check if SQL is valid
        return True
    except sqlite3.Error as e:
        return f"❌ Invalid SQL: {e}"

In [56]:
query = "INSERT INTO audience (age, gender) VALUES + 1(35, 'Female');"
conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
validate_sql_syntax(cursor, query)

'❌ Invalid SQL: near "+": syntax error'

In [15]:
ds_test_list = dataset["test"].to_list()
errors = 0
conn = sqlite3.connect(":memory:")
cursor = conn.cursor()

for test_sample in ds_test_list:
    context = re.search(r'<SCHEMA>\n(.*?)\n</SCHEMA>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip()
    # Create an in-memory SQLite database
    sql_check = validate_sql_syntax(cursor, context)
    if not isinstance(sql_check, bool):
        errors+=1

In [16]:
print(f"Number of non executable Contexts: {errors} out of {len(ds_test_list)}.")

Number of non executable Contexts: 2086 out of 2500.


In [59]:
conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
responses = {
    "model_response": [],
    "syntax_error": []
}
for test_sample in tqdm(ds_test_list):
    user_request = test_sample['messages'][:1]
    user_request[0]['content']
    # Convert as test example into a prompt with the Gemma template
    stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
    prompt = pipe.tokenizer.apply_chat_template(user_request, tokenize=False, add_generation_prompt=True)

    # Generate our SQL query.
    outputs = pipe(
        prompt, max_new_tokens=256, do_sample=True, 
        temperature=0.1, top_k=50, top_p=0.1, 
        eos_token_id=stop_token_ids, 
        disable_compile=False
    )

    query = outputs[0]['generated_text'][len(prompt):].strip()
    responses['model_response'].append(query)
    valid_query = validate_sql_syntax(cursor, query)
    if isinstance(valid_query, str) and "syntax error" in valid_query:
        responses['syntax_error'].append(1)
    else:
        responses['syntax_error'].append(0)

100%|██████████| 2500/2500 [3:05:57<00:00,  4.46s/it]  


In [62]:
sum(responses["syntax_error"])

232