In [3]:
from datasets import Dataset, load_dataset
import pandas as pd
import torch
import json
import os


In [4]:
from unsloth import FastLanguageModel

# Sequence length for SQL queries (longer for complex queries with context)
max_seq_length = 2048

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    max_seq_length=max_seq_length,
    dtype=None,
    load_in_4bit=True,
)

ðŸ¦¥ Unsloth: Will patch your computer to enable 2x faster free finetuning.
#### Unsloth: `hf_xet==1.1.10` and `ipykernel>6.30.1` breaks progress bars. Disabling for now in XET.
#### Unsloth: To re-enable progress bars, please downgrade to `ipykernel==6.30.1` or wait for a fix to
https://github.com/huggingface/xet-core/issues/526
INFO 11-03 02:12:20 [__init__.py:216] Automatically detected platform cuda.
ðŸ¦¥ Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.10.7: Fast Gemma3 patching. Transformers: 4.56.2. vLLM: 0.11.0.
   \\   /|    NVIDIA GeForce RTX 4060. Num GPUs = 1. Max memory: 7.996 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu128. CUDA: 8.9. CUDA Toolkit: 12.8. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3 does not support SDPA -

In [5]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 42,
    use_rslora = False,
    loftq_config = None,
)


Unsloth: Making `base_model.model.model.vision_tower.vision_model` require gradients


In [6]:
from unsloth.chat_templates import get_chat_template

tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma",
)

In [7]:
# Load text-to-SQL dataset
dataset_raw = load_dataset("gretelai/synthetic_text_to_sql", split="train")

print(f"Dataset: {dataset_raw}")
print(f"Number of examples: {len(dataset_raw)}")
print(f"Dataset features: {dataset_raw.features}")
print("\nFirst example:")
print(dataset_raw[0])

# Use first 1000 samples for training
dataset_raw = dataset_raw.select(range(250))
print(f"\nDataset reduced to: {len(dataset_raw)} samples")



Dataset: Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 100000
})
Number of examples: 100000
Dataset features: {'id': Value('int32'), 'domain': Value('string'), 'domain_description': Value('string'), 'sql_complexity': Value('string'), 'sql_complexity_description': Value('string'), 'sql_task_type': Value('string'), 'sql_task_type_description': Value('string'), 'sql_prompt': Value('string'), 'sql_context': Value('string'), 'sql': Value('string'), 'sql_explanation': Value('string')}

First example:
{'id': 5097, 'domain': 'forestry', 'domain_description': 'Comprehensive data on sustainable forest management, timber production, wildlife habitat, and carbon sequestration in forestry.', 'sql_complexity': 'single join', 'sql_complexity_description': 'only one join (specify inner, outer, cross)', 'sql_task_type': 'ana

In [8]:
# Convert to chat format for text-to-SQL
def convert_to_chat_format(item):
    """Convert text-to-SQL data to chat conversation format"""
    
    # Build system message with database context
    system_content = "You are an expert SQL assistant. Generate accurate SQL queries based on the given database context and natural language questions."
    
    # Build user message with context and question
    user_content = f"Database Context:\n{item['sql_context']}\n\nQuestion: {item['sql_prompt']}"
    
    # Assistant response is the SQL query
    assistant_content = item['sql']
    
    conversation = [
        {
            "role": "system",
            "content": system_content
        },
        {
            "role": "user",
            "content": user_content
        },
        {
            "role": "assistant",
            "content": assistant_content
        }
    ]
    
    return {"conversation": conversation}

# Convert all data to chat format
chat_data = [convert_to_chat_format(item) for item in dataset_raw]

print(f"Converted {len(chat_data)} conversations")
print("\nFirst conversation sample:")
print(json.dumps(chat_data[0], indent=2))



Converted 250 conversations

First conversation sample:
{
  "conversation": [
    {
      "role": "system",
      "content": "You are an expert SQL assistant. Generate accurate SQL queries based on the given database context and natural language questions."
    },
    {
      "role": "user",
      "content": "Database Context:\nCREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');\n\nQuestion: What is the total volume of timber sold by each salesperson, sorted by salesperson?"
    },
    {
      "role": "assistant",
      "content": "SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JO

In [9]:
# Create Dataset from chat data
dataset = Dataset.from_list(chat_data)

print(f"Dataset size: {len(dataset)}")
print("\nFirst formatted conversation:")
print(json.dumps(dataset[0]['conversation'], indent=2))


Dataset size: 250

First formatted conversation:
[
  {
    "content": "You are an expert SQL assistant. Generate accurate SQL queries based on the given database context and natural language questions.",
    "role": "system"
  },
  {
    "content": "Database Context:\nCREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');\n\nQuestion: What is the total volume of timber sold by each salesperson, sorted by salesperson?",
    "role": "user"
  },
  {
    "content": "SELECT salesperson_id, name, SUM(volume) as total_volume FROM timber_sales JOIN salesperson ON timber_sales.salesperson_id = salesperson.salesperson_id

In [10]:
# Format conversations with chat template
def formatting_prompts_func(examples):
    convos = examples["conversation"]
    texts = [
        tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
        for convo in convos
    ]
    return {"text": texts}

dataset = dataset.map(formatting_prompts_func, batched=True)
print("Formatted prompt sample:")
print(dataset[0]['text'][:500] + "...")

Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 250/250 [00:00<00:00, 13272.61 examples/s]

Formatted prompt sample:
<bos><start_of_turn>user
You are an expert SQL assistant. Generate accurate SQL queries based on the given database context and natural language questions. Database Context:
CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_...





In [None]:
from trl import SFTTrainer, SFTConfig

# Training configuration for text-to-SQL
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None,
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 50,
        num_train_epochs = 3,
        learning_rate = 2e-5,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 42,
        output_dir = "../models/text-to-sql",
        save_strategy = "epoch",
        save_total_limit = 2,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        report_to = "none",
    ),
)


Unsloth: Tokenizing ["text"] (num_proc=20): 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 250/250 [00:11<00:00, 22.29 examples/s]


In [12]:

# Train only on model responses (not user inputs or system prompts)
from unsloth.chat_templates import train_on_responses_only

trainer = train_on_responses_only(
    trainer,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
)

print("Training dataset preview:")
print(f"Total examples: {len(trainer.train_dataset)}")
print(f"Sample input_ids length: {len(trainer.train_dataset[0]['input_ids'])}")

Map (num_proc=20): 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 250/250 [00:00<00:00, 300.45 examples/s]

Training dataset preview:
Total examples: 250
Sample input_ids length: 272





In [13]:
# Verify training setup - check what parts are being trained
sample_idx = 10
print("Full prompt:")
print(tokenizer.decode(trainer.train_dataset[sample_idx]["input_ids"]))
print("\n" + "="*80 + "\n")
print("Only training on (labels != -100):")
print(tokenizer.decode([x if x != -100 else tokenizer.pad_token_id for x in trainer.train_dataset[sample_idx]["labels"]]).replace(tokenizer.pad_token, ""))

Full prompt:
<bos><bos><start_of_turn>user
You are an expert SQL assistant. Generate accurate SQL queries based on the given database context and natural language questions. Database Context:
CREATE TABLE farmers_india (id INT, name VARCHAR(255), district_id INT, age INT, income INT); INSERT INTO farmers_india (id, name, district_id, age, income) VALUES (1, 'Farmer A', 1, 45, 50000); CREATE TABLE districts_india (id INT, name VARCHAR(255), state VARCHAR(255)); INSERT INTO districts_india (id, name, state) VALUES (1, 'District A', 'Maharashtra');

Question: What is the average income of farmers in each district in India?<end_of_turn>
<start_of_turn>model
SELECT d.name, AVG(f.income) FROM farmers_india f JOIN districts_india d ON f.district_id = d.id GROUP BY d.name;<end_of_turn>



Only training on (labels != -100):
SELECT d.name, AVG(f.income) FROM farmers_india f JOIN districts_india d ON f.district_id = d.id GROUP BY d.name;<end_of_turn>



In [14]:
# Start training
print("ðŸš€ Starting training...")
print(f"Total steps: ~{len(dataset) * 3 // (2 * 4)} steps (3 epochs, batch_size=2, grad_accum=4)")
print(f"Training on {len(dataset)} text-to-SQL examples")
print("Expected training time: 2-4 hours depending on GPU\n")

trainer_stats = trainer.train()

print("\nâœ… Training completed!")
print(f"Final loss: {trainer_stats.training_loss:.4f}")


ðŸš€ Starting training...
Total steps: ~93 steps (3 epochs, batch_size=2, grad_accum=4)
Training on 250 text-to-SQL examples
Expected training time: 2-4 hours depending on GPU



==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 250 | Num Epochs = 3 | Total steps = 96
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 32,788,480 of 4,332,867,952 (0.76% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss
50,1.1271



âœ… Training completed!
Final loss: 0.7249


In [15]:
# Enable fast inference mode
FastLanguageModel.for_inference(model)

def generate_sql(db_context, question, max_new_tokens=512, temperature=0.3):
    """Generate SQL query from database context and natural language question"""
    
    conversation = [
        {
            "role": "system",
            "content": "You are an expert SQL assistant. Generate accurate SQL queries based on the given database context and natural language questions."
        },
        {
            "role": "user",
            "content": f"Database Context:\n{db_context}\n\nQuestion: {question}"
        }
    ]
    
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=0.9,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    
    response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()
    return response

# Test with sample database contexts
test_cases = [
    {
        "context": "CREATE TABLE employees (id INT, name VARCHAR(50), department VARCHAR(50), salary DECIMAL(10,2)); INSERT INTO employees VALUES (1, 'John', 'IT', 75000), (2, 'Jane', 'HR', 65000);",
        "question": "What is the average salary by department?"
    },
    {
        "context": "CREATE TABLE orders (order_id INT, customer_id INT, total DECIMAL(10,2), order_date DATE); CREATE TABLE customers (customer_id INT, name VARCHAR(50), city VARCHAR(50));",
        "question": "Find all customers from New York who made orders over $1000"
    },
    {
        "context": "CREATE TABLE products (product_id INT, name VARCHAR(100), category VARCHAR(50), price DECIMAL(10,2), stock INT);",
        "question": "Show me all products with low stock (less than 10 units)"
    },
]

print("ðŸ’¬ Testing Fine-tuned Text-to-SQL Model\n" + "="*80 + "\n")

for test in test_cases:
    print(f"DATABASE CONTEXT:\n{test['context']}\n")
    print(f"QUESTION: {test['question']}")
    sql_query = generate_sql(test['context'], test['question'])
    print(f"GENERATED SQL: {sql_query}")
    print("\n" + "-"*80 + "\n")


ðŸ’¬ Testing Fine-tuned Text-to-SQL Model

DATABASE CONTEXT:
CREATE TABLE employees (id INT, name VARCHAR(50), department VARCHAR(50), salary DECIMAL(10,2)); INSERT INTO employees VALUES (1, 'John', 'IT', 75000), (2, 'Jane', 'HR', 65000);

QUESTION: What is the average salary by department?
GENERATED SQL: SELECT department, AVG(salary) FROM employees GROUP BY department;

--------------------------------------------------------------------------------

DATABASE CONTEXT:
CREATE TABLE orders (order_id INT, customer_id INT, total DECIMAL(10,2), order_date DATE); CREATE TABLE customers (customer_id INT, name VARCHAR(50), city VARCHAR(50));

QUESTION: Find all customers from New York who made orders over $1000
GENERATED SQL: SELECT c.name FROM customers c JOIN orders o ON c.customer_id = o.customer_id WHERE c.city = 'New York' AND o.total > 1000;

--------------------------------------------------------------------------------

DATABASE CONTEXT:
CREATE TABLE products (product_id INT, name V

In [16]:
# Save the fine-tuned model
print("ðŸ’¾ Saving fine-tuned model...")

# Save LoRA adapter
model.save_pretrained("../models/text-to-sql-lora")
tokenizer.save_pretrained("../models/text-to-sql-lora")

print("âœ… LoRA adapter saved to ../models/text-to-sql-lora")


ðŸ’¾ Saving fine-tuned model...
âœ… LoRA adapter saved to ../models/text-to-sql-lora
