In [None]:
# Following the fine-tuning tutorial at https://ai.google.dev/gemma/docs/core/huggingface_text_finetune_qlora

In [1]:
# Install Pytorch & other libraries
%pip install "torch>=2.4.0" tensorboard

# Install Gemma release branch from Hugging Face
%pip install "transformers>=4.51.3"

# Install Hugging Face libraries
%pip install  --upgrade \
  "datasets==3.3.2" \
  "accelerate==1.4.0" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.45.3" \
  "trl==0.21.0" \
  "peft==0.14.0" \
  protobuf \
  sentencepiece

# COMMENT IN: if you are running on a GPU that supports BF16 data type and flash attn, such as NVIDIA L4 or NVIDIA A100
#% pip install flash-attn

Collecting datasets==3.3.2
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting accelerate==1.4.0
  Downloading accelerate-1.4.0-py3-none-any.whl.metadata (19 kB)
Collecting evaluate==0.4.3
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting bitsandbytes==0.45.3
  Downloading bitsandbytes-0.45.3-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting trl==0.21.0
  Downloading trl-0.21.0-py3-none-any.whl.metadata (11 kB)
Collecting peft==0.14.0
  Downloading peft-0.14.0-py3-none-any.whl.metadata (13 kB)
Collecting protobuf
  Downloading protobuf-6.32.1-cp39-abi3-manylinux2014_x86_64.whl.metadata (593 bytes)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets==3.3.2)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.3.2-py3-none-any.whl (485 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00

In [1]:
from google.colab import userdata
from huggingface_hub import login

# Login into Hugging Face Hub - if you don't have a token, sign up. It will be used to download the Gemma-3 model and also for uploading your trained adapter.
hf_token = userdata.get('HF_TOKEN') # If you are running inside a Google Colab
login(hf_token)

In [2]:
from datasets import load_dataset

# System message for the assistant
system_message = """You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA."""

# User prompt that combines the user query and the schema
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
{context}
</SCHEMA>

<USER_QUERY>
{question}
</USER_QUERY>
"""


In [3]:
system_message

'You are a text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.'

In [4]:
user_prompt

"Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.\n\n<SCHEMA>\n{context}\n</SCHEMA>\n\n<USER_QUERY>\n{question}\n</USER_QUERY>\n"

In [5]:
def create_conversation(sample):
  return {
    "messages": [
      # {"role": "system", "content": system_message},
      {"role": "user", "content": user_prompt.format(question=sample["sql_prompt"], context=sample["sql_context"])},
      {"role": "assistant", "content": sample["sql"]}
    ]
  }

# Load dataset from the hub
dataset = load_dataset("philschmid/gretel-synthetic-text-to-sql", split="train")
dataset = dataset.shuffle().select(range(12500))

In [6]:
# Convert dataset to OAI messages
dataset = dataset.map(create_conversation, remove_columns=dataset.features,batched=False)
# split dataset into 10,000 training samples and 2,500 test samples
dataset = dataset.train_test_split(test_size=2500/12500)

# Print formatted user prompt
print(dataset["train"][345]["messages"][1]["content"])

Map:   0%|          | 0/12500 [00:00<?, ? examples/s]

SELECT s.service_type, COUNT(*) OVER (PARTITION BY s.service_type) AS count_of_subscribers_with_service_type FROM services s JOIN subscribers sub ON s.service_type = sub.service_type;


In [7]:
dataset['test'][0]

{'messages': [{'content': "Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.\n\n<SCHEMA>\nCREATE TABLE refugee_support (id INT, organization VARCHAR(50), sector VARCHAR(50)); INSERT INTO refugee_support (id, organization, sector) VALUES (1, 'UNHCR', 'Refugee Agency'), (2, 'WFP', 'Food Assistance');\n</SCHEMA>\n\n<USER_QUERY>\nWhich organizations provided support in the 'refugee_support' table?\n</USER_QUERY>\n",
   'role': 'user'},
  {'content': 'SELECT DISTINCT organization FROM refugee_support;',
   'role': 'assistant'}]}

# Let's try to generate SQL with the original model and see what it does.

If there's a memory error, you might need to restart the kernel, but leave the output of the previous cell so you can compare.

In [None]:
"""
Ignore the following error thrown by this block!
AttributeError                            Traceback (most recent call last)
AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig, Gemma3ForCausalLM


In [None]:
model_id = "google/gemma-3-1b-pt"
print(model_id)
model = AutoModelForCausalLM.from_pretrained(
  model_id,
  device_map="auto",
  # torch_dtype=torch_dtype,
  torch_dtype=torch.bfloat16,
  attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

google/gemma-3-1b-pt


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

In [None]:
# Load the model and tokenizer into the pipeline
print(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)

# Extract the user query and original answer
print(f"Context:\n", re.search(r'<SCHEMA>\n(.*?)\n</SCHEMA>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Query:\n", re.search(r'<USER_QUERY>\n(.*?)\n</USER_QUERY>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Original Answer:\n{test_sample['messages'][1]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
# print(f"Generated Answer:\n{outputs[0]['generated_text']}")

Device set to use cuda:0


google/gemma-3-1b-pt
Context:
 CREATE TABLE IF NOT EXISTS cargo (id INT PRIMARY KEY, vessel_name VARCHAR(255), average_speed DECIMAL(5,2)); CREATE TABLE IF NOT EXISTS vessel_safety (id INT PRIMARY KEY, vessel_name VARCHAR(255), safety_inspection_date DATE);
Query:
 Create a view named 'vessel_summary' that contains the vessel_name, average_speed and safety_inspection_date
Original Answer:
CREATE VIEW vessel_summary AS SELECT cargo.vessel_name, cargo.average_speed, vessel_safety.safety_inspection_date FROM cargo INNER JOIN vessel_safety ON cargo.vessel_name = vessel_safety.vessel_name;
Generated Answer:
<?php
$sql = "SELECT vessel_name, average_speed, safety_inspection_date FROM vessel_summary";
$result = mysqli_query($conn, $sql);
$row = mysqli_fetch_array($result);
echo $row['vessel_name'];
echo $row['average_speed'];
echo $row['safety_inspection_date'];
?>
<?php
$sql = "SELECT vessel_name, average_speed, safety_inspection_date FROM vessel_summary WHERE vessel_name = 'A' AND safety_in

The model generates PHP code with embedded SQL queries instead of pure SQL. So let's train a model to teach it to generate SQL instead.

# Let's get read to train our model then!

In [None]:
# Hugging Face model id
model_id = "google/gemma-3-1b-pt" # or `google/gemma-3-4b-pt`, `google/gemma-3-12b-pt`, `google/gemma-3-27b-pt`

# Select model class based on id
if model_id == "google/gemma-3-1b-pt":
    model_class = AutoModelForCausalLM
    # model_class = Gemma3ForCausalLM
else:
    model_class = AutoModelForImageTextToText


In [10]:
# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] >= 8:
    torch_dtype = torch.bfloat16
else:
    torch_dtype = torch.float16

In [11]:
# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch_dtype, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

In [12]:
# BitsAndBytesConfig: Enables 4-bit quantization to reduce model size/memory usage
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=model_kwargs['torch_dtype'],
    bnb_4bit_quant_storage=model_kwargs['torch_dtype'],
)

In [13]:
from transformers import AutoTokenizer, Gemma3ForCausalLM
# Load model and tokenizer
model = model_class.from_pretrained(model_id, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it") # Load the Instruction Tokenizer to use the official Gemma template

config.json:   0%|          | 0.00/880 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/2.00G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

In [14]:
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"] # make sure to save the lm_head and embed_tokens as you train the special tokens
)

In [15]:
from trl import SFTConfig

args = SFTConfig(
    output_dir="gemma-text-to-sql-2",         # directory to save and repository id
    max_length=512,                         # max sequence length for model and packing of the dataset
    packing=False,                           # Groups multiple samples in the dataset into a single sequence
    num_train_epochs=1,                     # number of training epochs
    per_device_train_batch_size=1,          # batch size per device during training
    gradient_accumulation_steps=4,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_torch_fused",              # use fused adamw optimizer
    logging_steps=10,                       # log every 10 steps
    save_strategy="steps",                  # save checkpoint every 500 steps.
    save_steps=500,
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    fp16=True if torch_dtype == torch.float16 else False,   # use float16 precision
    bf16=True if torch_dtype == torch.bfloat16 else False,   # use bfloat16 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="cosine",           # use constant learning rate scheduler
    push_to_hub=True,                       # push model to hub
    report_to="tensorboard",                # report metrics to tensorboard
    dataset_kwargs={
        "add_special_tokens": False, # We template with special tokens
        "append_concat_token": True, # Add EOS token as separator token between examples
    }
)

In [16]:
model

Gemma3ForCausalLM(
  (model): Gemma3TextModel(
    (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 1152, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma3DecoderLayer(
        (self_attn): Gemma3Attention(
          (q_proj): Linear4bit(in_features=1152, out_features=1024, bias=False)
          (k_proj): Linear4bit(in_features=1152, out_features=256, bias=False)
          (v_proj): Linear4bit(in_features=1152, out_features=256, bias=False)
          (o_proj): Linear4bit(in_features=1024, out_features=1152, bias=False)
          (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
          (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
        )
        (mlp): Gemma3MLP(
          (gate_proj): Linear4bit(in_features=1152, out_features=6912, bias=False)
          (up_proj): Linear4bit(in_features=1152, out_features=6912, bias=False)
          (down_proj): Linear4bit(in_features=6912, out_features=1152, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_l

In [17]:
peft_config

LoraConfig(task_type='CAUSAL_LM', peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, inference_mode=False, r=16, target_modules='all-linear', exclude_modules=None, lora_alpha=16, lora_dropout=0.05, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=['lm_head', 'embed_tokens'], init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, eva_config=None, use_dora=False, layer_replication=None, runtime_config=LoraRuntimeConfig(ephemeral_gpu_offload=False), lora_bias=False)

In [18]:
from trl import SFTTrainer

# Create Trainer object
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset["train"],
    peft_config=peft_config,
    processing_class=tokenizer
)

Tokenizing train dataset:   0%|          | 0/10000 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/10000 [00:00<?, ? examples/s]

# Start training

Keep an eye on the loss. I observe that the loss stops reducing after less than 1 epoch. The model is saved every 500 steps so it is safe to abort the training as long as you have a checkpoint saved and the loss has stopped reducing.

You could always automate this with early stopping.

In [None]:
# # Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()


# # Save the final model again to the Hugging Face Hub
trainer.save_model()

In [20]:
4# free the memory again
del model
# del trainer
torch.cuda.empty_cache()

The part below helps to save the merged model. But it is very intensive so I skip.

In [None]:
# from peft import PeftModel

# # Load Model base model
# model = model_class.from_pretrained(model_id, low_cpu_mem_usage=True)

# # Merge LoRA and base model and save
# peft_model = PeftModel.from_pretrained(model, args.output_dir)
# merged_model = peft_model.merge_and_unload()
# merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

# processor = AutoTokenizer.from_pretrained(args.output_dir)
# # processor.save_pretrained("merged_model")

# Let's test out our model!

If memory hasn't fully freed up even after deleting the model, you might need to restart the notebook and run a few of the top cells again (HF_hub token, loading the dataset)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, BitsAndBytesConfig, Gemma3ForCausalLM

model_id = "asthagarg/gemma-text-to-sql-2" # replace "asthagarg" with your own id to access your model. If you leave it as-is, you'll access my model instead.
# model_id = "google/gemma-3-1b-pt"
print(model_id)
model = AutoModelForCausalLM.from_pretrained(
  model_id,
  device_map="auto",
  # torch_dtype=torch_dtype,
  torch_dtype=torch.bfloat16,
  attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [10]:
import torch
from transformers import pipeline

In [12]:
from random import randint
import re

# Load the model and tokenizer into the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Load a random sample from the test dataset
rand_idx = randint(0, len(dataset["test"])-1)
test_sample = dataset["test"][rand_idx]

# Convert as test example into a prompt with the Gemma template
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(test_sample["messages"][:1], tokenize=False, add_generation_prompt=True)
# prompt = test_sample["messages"][0]["content"]
# Generate our SQL query.
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, temperature=0.1, top_k=50, top_p=0.1, eos_token_id=stop_token_ids, disable_compile=True)

# Extract the user query and original answer
print(model_id)
print(f"Context:\n", re.search(r'<SCHEMA>\n(.*?)\n</SCHEMA>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Query:\n", re.search(r'<USER_QUERY>\n(.*?)\n</USER_QUERY>', test_sample['messages'][0]['content'], re.DOTALL).group(1).strip())
print(f"Original Answer:\n{test_sample['messages'][1]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")
# print(f"Generated Answer:\n{outputs[0]['generated_text']}")

Device set to use cuda:0


asthagarg/gemma-text-to-sql-2
Context:
 CREATE TABLE IF NOT EXISTS cargo (id INT PRIMARY KEY, vessel_name VARCHAR(255), average_speed DECIMAL(5,2)); CREATE TABLE IF NOT EXISTS vessel_safety (id INT PRIMARY KEY, vessel_name VARCHAR(255), safety_inspection_date DATE);
Query:
 Create a view named 'vessel_summary' that contains the vessel_name, average_speed and safety_inspection_date
Original Answer:
CREATE VIEW vessel_summary AS SELECT cargo.vessel_name, cargo.average_speed, vessel_safety.safety_inspection_date FROM cargo INNER JOIN vessel_safety ON cargo.vessel_name = vessel_safety.vessel_name;
Generated Answer:
CREATE VIEW vessel_summary AS SELECT vessel_name, AVG(average_speed) as avg_speed, DATEADD(month, DATEDIFF(month, 0, DATEADD(year, DATEDIFF(year, 0, GETDATE()), 0), DATEDIFF(month, 0, GETDATE())) AS safety_inspection_date FROM cargo GROUP BY vessel_name;


Voila! Our model now generates valid SQL.

As Chip Huyen argues in her book - use RAG for knowledge, finetuning for formatting.

Finetuning with just a small number of samples has helped adapt our model to generate valid SQL.

What's next? Learn to use VLLM so that you can self-host your model for fast, private inference.