## **Problem 8: SQL**

# Part 2.

- Fine-tuning (warmup steps = 100)

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from datasets import Dataset
from peft import PeftModel
import json
import os
import re

# Base model id, directory holding the Spider JSON files, and Trainer output directory.
MODEL_ID = "google/gemma-2b-it"
DATA_DIR = "spider"
OUTPUT_DIR = "qlora_output"

# 4-bit quantization configuration: NF4 weights with double quantization,
# bfloat16 used as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Load the tokenizer and the quantized base model; device placement is left to
# the library via device_map="auto".
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto"
)

# Enable gradient checkpointing, then run PEFT's k-bit training preparation
# on the quantized model before adapters are attached.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
# Define LoRA configuration: rank 8, alpha 32, dropout 0.05, no bias terms,
# adapters on the attention (q/k/v/o) and MLP (gate/up/down) projection modules.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], 
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
# Wrap the model with the LoRA adapters and print the trainable/total parameter counts.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

def load_json_file(filepath):
    """Parse the UTF-8 encoded JSON file at ``filepath`` and return its content."""
    with open(filepath, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

def get_db_schema_for_prompt(db_id, tables_data):
    """Render the schema of database ``db_id`` as pseudo-DDL text for the prompt.

    ``tables_data`` is searched for the first entry whose ``'db_id'`` equals
    ``db_id``. Every column is emitted with the placeholder type ``TEXT``.

    Returns:
        A ``CREATE DATABASE`` line followed by one ``CREATE TABLE`` block per
        table, or ``""`` when no entry matches.
    """
    entry = None
    for candidate in tables_data:
        if candidate['db_id'] == db_id:
            entry = candidate
            break
    if not entry:
        return ""

    pieces = [f"CREATE DATABASE {db_id};\n"]
    for table_pos, table_name in enumerate(entry['table_names_original']):
        pieces.append(f"CREATE TABLE {table_name} (\n")
        column_lines = []
        # Each column entry is a (owning table index, column name) pair.
        for owner_pos, column_name in entry['column_names_original']:
            if owner_pos == table_pos:
                column_lines.append(f"   {column_name} TEXT")
        pieces.append(",\n".join(column_lines) + "\n);\n")
    return "".join(pieces)

# Prompt template
def create_prompt(question, schema):
    """Build the training prompt: instruction, schema section, question section,
    and a trailing ``### SQL Query:`` header after which the target SQL is appended.
    """
    instruction = "You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.\n\n"
    schema_section = f"### Database Schema:\n{schema}\n\n"
    question_section = f"### Question:\n{question}\n\n"
    answer_header = "### SQL Query:\n"
    return "".join([instruction, schema_section, question_section, answer_header])

# Load Spider training files and schema definitions
train_spider_data = load_json_file(os.path.join(DATA_DIR, "train_spider.json"))
train_others_data = load_json_file(os.path.join(DATA_DIR, "train_others.json"))
# Both training splits are concatenated into a single list of samples.
all_train_data_raw = train_spider_data + train_others_data 
tables_data = load_json_file(os.path.join(DATA_DIR, "tables.json"))

# Build one training string per sample: prompt (schema + question) followed
# directly by the gold SQL query.
# NOTE(review): no end-of-sequence token is appended after the target, so the
# training text itself carries no explicit stop signal - confirm this is intended.
# NOTE(review): the evaluation cell uses a different prompt (extra instruction
# line plus primary/foreign-key comments in the schema) - confirm the mismatch
# between training and evaluation prompts is intended.
train_samples = []
for sample in all_train_data_raw: 
    schema = get_db_schema_for_prompt(sample['db_id'], tables_data)
    prompt = create_prompt(sample['question'], schema)
    target = sample['query']
    full_text = prompt + target 
    train_samples.append({"text": full_text})

hf_dataset = Dataset.from_list(train_samples)

# Randomly select a subset of data for faster fine-tuning
# (shuffle with a fixed seed, then keep the first SUBSET_RATIO fraction).
SUBSET_RATIO = 0.3 
train_subset_size = int(SUBSET_RATIO * len(hf_dataset))
train_dataset_subset = hf_dataset.shuffle(seed=42).select(range(train_subset_size))

# Report the subset size and preview the first 500 characters of one sample.
print(f"\n[INFO] Using {len(train_dataset_subset)} samples ({SUBSET_RATIO*100:.0f}%) of the dataset for fine-tuning.")
print(f"Sample formatted text for fine-tuning:\n{train_dataset_subset[0]['text'][:500]}...")


def tokenize(example):
    """Tokenize the ``"text"`` field, truncating and padding to exactly 256 tokens.

    Uses the module-level ``tokenizer``.
    """
    # NOTE(review): prompts with large schemas may exceed 256 tokens, in which
    # case truncation could drop the SQL target at the end - confirm.
    encoded = tokenizer(
        example["text"],
        max_length=256,
        padding="max_length",
        truncation=True,
    )
    return encoded

# Tokenize the selected subset in batches; the raw "text" column is dropped.
tokenized_dataset = train_dataset_subset.map(tokenize, batched=True, remove_columns=["text"])

# Data collator for causal language modeling (mlm=False disables masked-LM masking)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Define training configuration
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # 2 x 8 = 16 sequences per optimizer step per device
    warmup_steps=100,
    learning_rate=2e-4,
    logging_dir="./logs",
    logging_steps=10,
    save_steps=500,
    fp16=False,
    bf16=True,
    save_total_limit=2,
    report_to="none",
    seed=42,
    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 so the integer division cannot raise TypeError.
    dataloader_num_workers=(os.cpu_count() or 1) // 2,
)

# training phase
# `processing_class` replaces the deprecated `tokenizer` argument of Trainer
# (the old spelling emits a FutureWarning pointing at this call).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    processing_class=tokenizer,
    data_collator=data_collator
)
trainer.train()

# Persist the trained adapter and the tokenizer side by side.
trainer.save_model("fine_tuned_gemma_spider_lora")
tokenizer.save_pretrained("fine_tuned_gemma_spider_lora")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 9,805,824 || all params: 2,515,978,240 || trainable%: 0.3897420034920493

[INFO] Using 2597 samples (30%) of the dataset for fine-tuning.
Sample formatted text for fine-tuning:
You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.

### Database Schema:
CREATE DATABASE store_1;
CREATE TABLE artists (
   id TEXT,
   name TEXT
);
CREATE TABLE sqlite_sequence (
   name TEXT,
   seq TEXT
);
CREATE TABLE albums (
   id TEXT,
   title TEXT,
   artist_id TEXT
);
CREATE TABLE employees (
   id TEXT,
   last_name TEXT,
   first_name TEXT,
   title TEXT,
   reports_to TEXT,
   birth_date TEXT,
   hire_date TE...


Map:   0%|          | 0/2597 [00:00<?, ? examples/s]

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
10,2.5053
20,1.5802
30,0.9898
40,0.7345
50,0.5347
60,0.4394
70,0.3327
80,0.2527
90,0.2258
100,0.168


('fine_tuned_gemma_spider_lora\\tokenizer_config.json',
 'fine_tuned_gemma_spider_lora\\special_tokens_map.json',
 'fine_tuned_gemma_spider_lora\\chat_template.jinja',
 'fine_tuned_gemma_spider_lora\\tokenizer.model',
 'fine_tuned_gemma_spider_lora\\added_tokens.json',
 'fine_tuned_gemma_spider_lora\\tokenizer.json')

In [3]:
# Evaluation setup: load the bfloat16 base model and attach the LoRA adapter.
OFFLOAD_DIR = "./offload"
os.makedirs(OFFLOAD_DIR, exist_ok=True)
DATA_DIR = "spider"
# Adapter checkpoint under the training output directory
# (presumably the last checkpoint of the run - verify it matches the run being evaluated).
MODEL_DIR = "qlora_output/checkpoint-489"

# Device the tokenized prompts are moved to before generation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# Load LoRA adapter weights
# NOTE(review): PEFT appears to read the keyword `offload_folder`; confirm that
# `offload_dir` is actually honoured here.
model = PeftModel.from_pretrained(
    base_model,
    MODEL_DIR,
    offload_dir=OFFLOAD_DIR
)
model.eval()
# The model is intentionally NOT moved with `model.to(device)`: it was placed by
# device_map="auto", and a dispatched (possibly offloaded) model must not be
# moved manually - doing so can fail or undo the offloading.

def load_json_file(filepath):
    """Parse the UTF-8 encoded JSON file at ``filepath`` and return its content."""
    with open(filepath, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

# Extract database schema 
def get_db_schema_for_prompt(db_id, tables_data):
    """Render the schema of database ``db_id`` as pseudo-DDL plus key comments.

    The first entry of ``tables_data`` whose ``'db_id'`` equals ``db_id`` is
    used. Every column is emitted with the placeholder type ``TEXT``. Non-empty
    ``'primary_keys'`` / ``'foreign_keys'`` lists are appended as ``--`` comment
    lines naming ``table.column``.

    Returns:
        The schema text, or ``""`` when no entry matches.
    """
    entry = next((item for item in tables_data if item['db_id'] == db_id), None)
    if not entry:
        return ""

    def describe(col_pos):
        # Column entries are (owning table index, column name) pairs.
        owner_pos, col_name = entry['column_names_original'][col_pos]
        return f"{entry['table_names_original'][owner_pos]}.{col_name}"

    parts = [f"CREATE DATABASE {db_id};\n"]
    for table_pos, table_name in enumerate(entry['table_names_original']):
        owned = [
            f"   {col_name} TEXT"
            for owner_pos, col_name in entry['column_names_original']
            if owner_pos == table_pos
        ]
        parts.append(f"CREATE TABLE {table_name} (\n")
        parts.append(",\n".join(owned))
        parts.append("\n);\n")

    primary = entry.get('primary_keys')
    if primary:
        parts.append("-- Primary Keys:\n")
        for col_pos in primary:
            parts.append(f"-- {describe(col_pos)} is PRIMARY KEY\n")

    foreign = entry.get('foreign_keys')
    if foreign:
        parts.append("-- Foreign Keys:\n")
        for link in foreign:
            # link[0] is the referencing column, link[1] the referenced column.
            source_pos = link[0]
            target_pos = link[1]
            parts.append(f"-- {describe(source_pos)} REFERENCES {describe(target_pos)}\n")
    return "".join(parts)

# Prompt template
def create_sql_generation_prompt(question, schema):
    """Build the inference prompt: instruction lines, schema section, question
    section, and a trailing ``### SQL Query:`` header for the model to complete.
    """
    sections = [
        "You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.\n",
        "Do not make up table or column names. Only use the ones that exist in the schema.\n\n",
        f"### Database Schema:\n{schema}\n\n",
        f"### Question:\n{question}\n\n",
        "### SQL Query:\n",
    ]
    return "".join(sections)

def extract_sql_from_model_output(generated_text, prompt_text):
    """Isolate the SQL statement from decoded model output.

    Args:
        generated_text: Decoded model output, normally the echoed prompt
            followed by the completion.
        prompt_text: The prompt that was fed to the model.

    Returns:
        The text before the first ``;`` of the completion, with Markdown code
        fences removed and surrounding whitespace stripped.
    """
    # Decoding does not always reproduce the prompt byte-for-byte (whitespace
    # may be added or dropped), in which case a plain str.replace silently
    # removes nothing and the whole prompt would be returned as "SQL".
    # Match the echoed prompt at the start while tolerating whitespace
    # differences between its whitespace-delimited pieces.
    flexible_prompt = r'\s*'.join(re.escape(piece) for piece in prompt_text.split())
    echoed = re.match(r'\s*' + flexible_prompt, generated_text)
    if echoed:
        cleaned = generated_text[echoed.end():]
    else:
        # No echoed prompt at the start: keep the original behaviour.
        cleaned = generated_text.replace(prompt_text, "")
    cleaned = cleaned.strip()
    cleaned = re.sub(r'```sql\s*', '', cleaned, flags=re.IGNORECASE)
    cleaned = re.sub(r'```\s*', '', cleaned, flags=re.IGNORECASE)
    # Only the first statement is kept.
    return cleaned.split(';')[0].strip()

def normalize_sql(sql):
    """Collapse whitespace runs to single spaces, trim the ends, and upper-case."""
    single_spaced = re.sub(r'\s+', ' ', sql)
    trimmed = single_spaced.strip()
    return trimmed.upper()

# Exact match comparison
def calculate_exact_match(pred, gold):
    """Return 1 when ``pred`` and ``gold`` are equal after ``normalize_sql``, else 0."""
    identical = normalize_sql(pred) == normalize_sql(gold)
    return int(identical)

# Jaccard similarity
def calculate_jaccard_similarity(pred, gold):
    """Jaccard similarity of the whitespace-separated token sets of the two
    normalized queries; 1.0 when both token sets are empty.
    """
    pred_tokens = set(normalize_sql(pred).split())
    gold_tokens = set(normalize_sql(gold).split())
    union = pred_tokens | gold_tokens
    if not union:
        return 1.0
    shared = pred_tokens & gold_tokens
    return len(shared) / len(union)

# Load evaluation data
dev_data = load_json_file(os.path.join(DATA_DIR, "dev.json"))
tables_data = load_json_file(os.path.join(DATA_DIR, "tables.json"))

# Evaluate the model
exact_matches = 0
jaccard_scores = 0
n_samples = len(dev_data)
n_errors = 0  # samples whose generation/scoring raised; they contribute 0 to both sums

for i, sample in enumerate(dev_data[:n_samples]):
    question = sample['question']
    db_id = sample['db_id']
    gold_sql = sample['query']
    schema = get_db_schema_for_prompt(db_id, tables_data)
    prompt = create_sql_generation_prompt(question, schema)

    try:
        # `input_ids` holds the full tokenizer output (ids and attention mask).
        input_ids = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            # Greedy decoding, at most 100 new tokens.
            outputs = model.generate(
                **input_ids,
                max_new_tokens=100,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )
        # The returned sequence starts with the prompt tokens; decode only the
        # newly generated part so extraction does not depend on the decoded
        # prompt matching the original string exactly.
        prompt_length = input_ids["input_ids"].shape[1]
        gen_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        generated_sql = extract_sql_from_model_output(gen_text, prompt)

        em = calculate_exact_match(generated_sql, gold_sql)
        js = calculate_jaccard_similarity(generated_sql, gold_sql)

        exact_matches += em
        jaccard_scores += js

        # Only the first 100 samples are printed in detail.
        if i < 100:
            print(f"\n--- Sample {i+1} ---")
            print(f"Database ID: {db_id}")
            print(f"Question: {question}")
            print(f"Generated SQL: {generated_sql}")
            print(f"Gold SQL: {gold_sql}")
            print(f"Exact Match (EM): {em:.4f}, Jaccard Score (JS): {js:.4f}")

    except Exception as e:
        # Best-effort evaluation: report the failure and continue.
        n_errors += 1
        print(f"Error on sample {i}: {e}")

# Final evaluation report
print("\nEvaluation Results (Fine-tuned Model):")
print(f"Samples evaluated: {n_samples}")
if n_errors:
    print(f"Samples that raised an error (scored as 0): {n_errors}")
# Guard against an empty dev set so the averages do not raise ZeroDivisionError.
denominator = max(n_samples, 1)
print(f"Average Exact Match (EM): {exact_matches / denominator:.4f}")
print(f"Average Jaccard Similarity (JS): {jaccard_scores / denominator:.4f}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


--- Sample 1 ---
Database ID: concert_singer
Question: How many singers do we have?
Generated SQL: SELECT count(*) FROM singer
Gold SQL: SELECT count(*) FROM singer
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 2 ---
Database ID: concert_singer
Question: What is the total number of singers?
Generated SQL: SELECT count(*) FROM singer
Gold SQL: SELECT count(*) FROM singer
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 3 ---
Database ID: concert_singer
Question: Show name, country, age for all singers ordered by age from the oldest to the youngest.
Generated SQL: SELECT name ,  country ,  age FROM singer ORDER BY age DESC
Gold SQL: SELECT name ,  country ,  age FROM singer ORDER BY age DESC
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 4 ---
Database ID: concert_singer
Question: What are the names, countries, and ages for every singer in descending order of age?
Generated SQL: SELECT name ,  country ,  age FROM singer ORDER BY age D

- Fine-tuning (warmup steps = 50)

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from datasets import Dataset
from peft import PeftModel
import json
import os
import re

# Base model id, directory holding the Spider JSON files, and Trainer output directory.
MODEL_ID = "google/gemma-2b-it"
DATA_DIR = "spider"
OUTPUT_DIR = "qlora_output"

# 4-bit quantization configuration: NF4 weights with double quantization,
# bfloat16 used as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Load the tokenizer and the quantized base model; device placement is left to
# the library via device_map="auto".
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto"
)

# Enable gradient checkpointing, then run PEFT's k-bit training preparation
# on the quantized model before adapters are attached.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
# Define LoRA configuration: rank 8, alpha 32, dropout 0.05, no bias terms,
# adapters on the attention (q/k/v/o) and MLP (gate/up/down) projection modules.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], 
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
# Wrap the model with the LoRA adapters and print the trainable/total parameter counts.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

def load_json_file(filepath):
    """Parse the UTF-8 encoded JSON file at ``filepath`` and return its content."""
    with open(filepath, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

def get_db_schema_for_prompt(db_id, tables_data):
    """Render the schema of database ``db_id`` as pseudo-DDL text for the prompt.

    ``tables_data`` is searched for the first entry whose ``'db_id'`` equals
    ``db_id``. Every column is emitted with the placeholder type ``TEXT``.

    Returns:
        A ``CREATE DATABASE`` line followed by one ``CREATE TABLE`` block per
        table, or ``""`` when no entry matches.
    """
    entry = None
    for candidate in tables_data:
        if candidate['db_id'] == db_id:
            entry = candidate
            break
    if not entry:
        return ""

    pieces = [f"CREATE DATABASE {db_id};\n"]
    for table_pos, table_name in enumerate(entry['table_names_original']):
        pieces.append(f"CREATE TABLE {table_name} (\n")
        column_lines = []
        # Each column entry is a (owning table index, column name) pair.
        for owner_pos, column_name in entry['column_names_original']:
            if owner_pos == table_pos:
                column_lines.append(f"   {column_name} TEXT")
        pieces.append(",\n".join(column_lines) + "\n);\n")
    return "".join(pieces)

# Prompt template
def create_prompt(question, schema):
    """Build the training prompt: instruction, schema section, question section,
    and a trailing ``### SQL Query:`` header after which the target SQL is appended.
    """
    instruction = "You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.\n\n"
    schema_section = f"### Database Schema:\n{schema}\n\n"
    question_section = f"### Question:\n{question}\n\n"
    answer_header = "### SQL Query:\n"
    return "".join([instruction, schema_section, question_section, answer_header])

# Load Spider training files and schema definitions
train_spider_data = load_json_file(os.path.join(DATA_DIR, "train_spider.json"))
train_others_data = load_json_file(os.path.join(DATA_DIR, "train_others.json"))
# Both training splits are concatenated into a single list of samples.
all_train_data_raw = train_spider_data + train_others_data 
tables_data = load_json_file(os.path.join(DATA_DIR, "tables.json"))

# Build one training string per sample: prompt (schema + question) followed
# directly by the gold SQL query.
# NOTE(review): no end-of-sequence token is appended after the target, so the
# training text itself carries no explicit stop signal - confirm this is intended.
# NOTE(review): the evaluation cell uses a different prompt (extra instruction
# line plus primary/foreign-key comments in the schema) - confirm the mismatch
# between training and evaluation prompts is intended.
train_samples = []
for sample in all_train_data_raw: 
    schema = get_db_schema_for_prompt(sample['db_id'], tables_data)
    prompt = create_prompt(sample['question'], schema)
    target = sample['query']
    full_text = prompt + target 
    train_samples.append({"text": full_text})

hf_dataset = Dataset.from_list(train_samples)

# Randomly select a subset of data for faster fine-tuning
# (shuffle with a fixed seed, then keep the first SUBSET_RATIO fraction).
SUBSET_RATIO = 0.3 
train_subset_size = int(SUBSET_RATIO * len(hf_dataset))
train_dataset_subset = hf_dataset.shuffle(seed=42).select(range(train_subset_size))

# Report the subset size and preview the first 500 characters of one sample.
print(f"\n[INFO] Using {len(train_dataset_subset)} samples ({SUBSET_RATIO*100:.0f}%) of the dataset for fine-tuning.")
print(f"Sample formatted text for fine-tuning:\n{train_dataset_subset[0]['text'][:500]}...")


def tokenize(example):
    """Tokenize the ``"text"`` field, truncating and padding to exactly 256 tokens.

    Uses the module-level ``tokenizer``.
    """
    # NOTE(review): prompts with large schemas may exceed 256 tokens, in which
    # case truncation could drop the SQL target at the end - confirm.
    encoded = tokenizer(
        example["text"],
        max_length=256,
        padding="max_length",
        truncation=True,
    )
    return encoded

# Tokenize the selected subset in batches; the raw "text" column is dropped.
tokenized_dataset = train_dataset_subset.map(tokenize, batched=True, remove_columns=["text"])

# Data collator for causal language modeling (mlm=False disables masked-LM masking)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Define training configuration
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # 2 x 8 = 16 sequences per optimizer step per device
    warmup_steps=50,
    learning_rate=2e-4,
    logging_dir="./logs",
    logging_steps=10,
    save_steps=500,
    fp16=False,
    bf16=True,
    save_total_limit=2,
    report_to="none",
    seed=42,
    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 so the integer division cannot raise TypeError.
    dataloader_num_workers=(os.cpu_count() or 1) // 2,
)

# training phase
# `processing_class` replaces the deprecated `tokenizer` argument of Trainer
# (the old spelling emits a FutureWarning pointing at this call).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    processing_class=tokenizer,
    data_collator=data_collator
)
trainer.train()

# Persist the trained adapter and the tokenizer side by side.
trainer.save_model("fine_tuned_gemma_spider_lora")
tokenizer.save_pretrained("fine_tuned_gemma_spider_lora")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 9,805,824 || all params: 2,515,978,240 || trainable%: 0.3897420034920493

[INFO] Using 2597 samples (30%) of the dataset for fine-tuning.
Sample formatted text for fine-tuning:
You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.

### Database Schema:
CREATE DATABASE store_1;
CREATE TABLE artists (
   id TEXT,
   name TEXT
);
CREATE TABLE sqlite_sequence (
   name TEXT,
   seq TEXT
);
CREATE TABLE albums (
   id TEXT,
   title TEXT,
   artist_id TEXT
);
CREATE TABLE employees (
   id TEXT,
   last_name TEXT,
   first_name TEXT,
   title TEXT,
   reports_to TEXT,
   birth_date TEXT,
   hire_date TE...


Map:   0%|          | 0/2597 [00:00<?, ? examples/s]

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
10,2.3838
20,1.2389
30,0.8108
40,0.5788
50,0.3846
60,0.3179
70,0.2388
80,0.193
90,0.1866
100,0.1379


('fine_tuned_gemma_spider_lora\\tokenizer_config.json',
 'fine_tuned_gemma_spider_lora\\special_tokens_map.json',
 'fine_tuned_gemma_spider_lora\\chat_template.jinja',
 'fine_tuned_gemma_spider_lora\\tokenizer.model',
 'fine_tuned_gemma_spider_lora\\added_tokens.json',
 'fine_tuned_gemma_spider_lora\\tokenizer.json')

In [2]:
# Evaluation setup: load the bfloat16 base model and attach the LoRA adapter.
OFFLOAD_DIR = "./offload"
os.makedirs(OFFLOAD_DIR, exist_ok=True)
DATA_DIR = "spider"
# Adapter checkpoint under the training output directory
# (presumably the last checkpoint of the run - verify it matches the run being evaluated).
MODEL_DIR = "qlora_output/checkpoint-489"

# Device the tokenized prompts are moved to before generation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# Load LoRA adapter weights
# NOTE(review): PEFT appears to read the keyword `offload_folder`; confirm that
# `offload_dir` is actually honoured here.
model = PeftModel.from_pretrained(
    base_model,
    MODEL_DIR,
    offload_dir=OFFLOAD_DIR
)
model.eval()
# The model is intentionally NOT moved with `model.to(device)`: it was placed by
# device_map="auto" (this cell's output reports parameters offloaded to the
# CPU), and a dispatched model must not be moved manually - doing so can fail
# or undo the offloading.

def load_json_file(filepath):
    """Parse the UTF-8 encoded JSON file at ``filepath`` and return its content."""
    with open(filepath, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

# Extract database schema 
def get_db_schema_for_prompt(db_id, tables_data):
    """Render the schema of database ``db_id`` as pseudo-DDL plus key comments.

    The first entry of ``tables_data`` whose ``'db_id'`` equals ``db_id`` is
    used. Every column is emitted with the placeholder type ``TEXT``. Non-empty
    ``'primary_keys'`` / ``'foreign_keys'`` lists are appended as ``--`` comment
    lines naming ``table.column``.

    Returns:
        The schema text, or ``""`` when no entry matches.
    """
    entry = next((item for item in tables_data if item['db_id'] == db_id), None)
    if not entry:
        return ""

    def describe(col_pos):
        # Column entries are (owning table index, column name) pairs.
        owner_pos, col_name = entry['column_names_original'][col_pos]
        return f"{entry['table_names_original'][owner_pos]}.{col_name}"

    parts = [f"CREATE DATABASE {db_id};\n"]
    for table_pos, table_name in enumerate(entry['table_names_original']):
        owned = [
            f"   {col_name} TEXT"
            for owner_pos, col_name in entry['column_names_original']
            if owner_pos == table_pos
        ]
        parts.append(f"CREATE TABLE {table_name} (\n")
        parts.append(",\n".join(owned))
        parts.append("\n);\n")

    primary = entry.get('primary_keys')
    if primary:
        parts.append("-- Primary Keys:\n")
        for col_pos in primary:
            parts.append(f"-- {describe(col_pos)} is PRIMARY KEY\n")

    foreign = entry.get('foreign_keys')
    if foreign:
        parts.append("-- Foreign Keys:\n")
        for link in foreign:
            # link[0] is the referencing column, link[1] the referenced column.
            source_pos = link[0]
            target_pos = link[1]
            parts.append(f"-- {describe(source_pos)} REFERENCES {describe(target_pos)}\n")
    return "".join(parts)

# Prompt template
def create_sql_generation_prompt(question, schema):
    """Build the inference prompt: instruction lines, schema section, question
    section, and a trailing ``### SQL Query:`` header for the model to complete.
    """
    sections = [
        "You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.\n",
        "Do not make up table or column names. Only use the ones that exist in the schema.\n\n",
        f"### Database Schema:\n{schema}\n\n",
        f"### Question:\n{question}\n\n",
        "### SQL Query:\n",
    ]
    return "".join(sections)

def extract_sql_from_model_output(generated_text, prompt_text):
    """Isolate the SQL statement from decoded model output.

    Args:
        generated_text: Decoded model output, normally the echoed prompt
            followed by the completion.
        prompt_text: The prompt that was fed to the model.

    Returns:
        The text before the first ``;`` of the completion, with Markdown code
        fences removed and surrounding whitespace stripped.
    """
    # Decoding does not always reproduce the prompt byte-for-byte (whitespace
    # may be added or dropped), in which case a plain str.replace silently
    # removes nothing and the whole prompt would be returned as "SQL".
    # Match the echoed prompt at the start while tolerating whitespace
    # differences between its whitespace-delimited pieces.
    flexible_prompt = r'\s*'.join(re.escape(piece) for piece in prompt_text.split())
    echoed = re.match(r'\s*' + flexible_prompt, generated_text)
    if echoed:
        cleaned = generated_text[echoed.end():]
    else:
        # No echoed prompt at the start: keep the original behaviour.
        cleaned = generated_text.replace(prompt_text, "")
    cleaned = cleaned.strip()
    cleaned = re.sub(r'```sql\s*', '', cleaned, flags=re.IGNORECASE)
    cleaned = re.sub(r'```\s*', '', cleaned, flags=re.IGNORECASE)
    # Only the first statement is kept.
    return cleaned.split(';')[0].strip()

def normalize_sql(sql):
    """Collapse whitespace runs to single spaces, trim the ends, and upper-case."""
    single_spaced = re.sub(r'\s+', ' ', sql)
    trimmed = single_spaced.strip()
    return trimmed.upper()

# Exact match comparison
def calculate_exact_match(pred, gold):
    """Return 1 when ``pred`` and ``gold`` are equal after ``normalize_sql``, else 0."""
    identical = normalize_sql(pred) == normalize_sql(gold)
    return int(identical)

# Jaccard similarity
def calculate_jaccard_similarity(pred, gold):
    """Jaccard similarity of the whitespace-separated token sets of the two
    normalized queries; 1.0 when both token sets are empty.
    """
    pred_tokens = set(normalize_sql(pred).split())
    gold_tokens = set(normalize_sql(gold).split())
    union = pred_tokens | gold_tokens
    if not union:
        return 1.0
    shared = pred_tokens & gold_tokens
    return len(shared) / len(union)

# Load evaluation data
dev_data = load_json_file(os.path.join(DATA_DIR, "dev.json"))
tables_data = load_json_file(os.path.join(DATA_DIR, "tables.json"))

# Evaluate the model
exact_matches = 0
jaccard_scores = 0
n_samples = len(dev_data)
n_errors = 0  # samples whose generation/scoring raised; they contribute 0 to both sums

for i, sample in enumerate(dev_data[:n_samples]):
    question = sample['question']
    db_id = sample['db_id']
    gold_sql = sample['query']
    schema = get_db_schema_for_prompt(db_id, tables_data)
    prompt = create_sql_generation_prompt(question, schema)

    try:
        # `input_ids` holds the full tokenizer output (ids and attention mask).
        input_ids = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            # Greedy decoding, at most 100 new tokens.
            outputs = model.generate(
                **input_ids,
                max_new_tokens=100,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )
        # The returned sequence starts with the prompt tokens; decode only the
        # newly generated part so extraction does not depend on the decoded
        # prompt matching the original string exactly.
        prompt_length = input_ids["input_ids"].shape[1]
        gen_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        generated_sql = extract_sql_from_model_output(gen_text, prompt)

        em = calculate_exact_match(generated_sql, gold_sql)
        js = calculate_jaccard_similarity(generated_sql, gold_sql)

        exact_matches += em
        jaccard_scores += js

        # Only the first 100 samples are printed in detail.
        if i < 100:
            print(f"\n--- Sample {i+1} ---")
            print(f"Database ID: {db_id}")
            print(f"Question: {question}")
            print(f"Generated SQL: {generated_sql}")
            print(f"Gold SQL: {gold_sql}")
            print(f"Exact Match (EM): {em:.4f}, Jaccard Score (JS): {js:.4f}")

    except Exception as e:
        # Best-effort evaluation: report the failure and continue.
        n_errors += 1
        print(f"Error on sample {i}: {e}")

# Final evaluation report
print("\nEvaluation Results (Fine-tuned Model):")
print(f"Samples evaluated: {n_samples}")
if n_errors:
    print(f"Samples that raised an error (scored as 0): {n_errors}")
# Guard against an empty dev set so the averages do not raise ZeroDivisionError.
denominator = max(n_samples, 1)
print(f"Average Exact Match (EM): {exact_matches / denominator:.4f}")
print(f"Average Jaccard Similarity (JS): {jaccard_scores / denominator:.4f}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.



--- Sample 1 ---
Database ID: concert_singer
Question: How many singers do we have?
Generated SQL: SELECT count(*) FROM singer
Gold SQL: SELECT count(*) FROM singer
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 2 ---
Database ID: concert_singer
Question: What is the total number of singers?
Generated SQL: SELECT count(*) FROM singer
Gold SQL: SELECT count(*) FROM singer
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 3 ---
Database ID: concert_singer
Question: Show name, country, age for all singers ordered by age from the oldest to the youngest.
Generated SQL: SELECT name ,  country ,  age FROM singer ORDER BY age DESC
Gold SQL: SELECT name ,  country ,  age FROM singer ORDER BY age DESC
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 4 ---
Database ID: concert_singer
Question: What are the names, countries, and ages for every singer in descending order of age?
Generated SQL: SELECT name ,  country ,  age FROM singer ORDER BY age D

- Fine-tuning (warmup steps = 75)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from datasets import Dataset
from peft import PeftModel
import json
import os
import re

# Base model id, directory holding the Spider JSON files, and Trainer output directory.
MODEL_ID = "google/gemma-2b-it"
DATA_DIR = "spider"
OUTPUT_DIR = "qlora_output"

# 4-bit quantization configuration: NF4 weights with double quantization,
# bfloat16 used as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Load the tokenizer and the quantized base model; device placement is left to
# the library via device_map="auto".
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto"
)

# Enable gradient checkpointing, then run PEFT's k-bit training preparation
# on the quantized model before adapters are attached.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
# Define LoRA configuration: rank 8, alpha 32, dropout 0.05, no bias terms,
# adapters on the attention (q/k/v/o) and MLP (gate/up/down) projection modules.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], 
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
# Wrap the model with the LoRA adapters and print the trainable/total parameter counts.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

def load_json_file(filepath):
    """Parse the UTF-8 encoded JSON file at ``filepath`` and return its content."""
    with open(filepath, mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

def get_db_schema_for_prompt(db_id, tables_data):
    """Render the schema of database ``db_id`` as pseudo-DDL text for the prompt.

    ``tables_data`` is searched for the first entry whose ``'db_id'`` equals
    ``db_id``. Every column is emitted with the placeholder type ``TEXT``.

    Returns:
        A ``CREATE DATABASE`` line followed by one ``CREATE TABLE`` block per
        table, or ``""`` when no entry matches.
    """
    entry = None
    for candidate in tables_data:
        if candidate['db_id'] == db_id:
            entry = candidate
            break
    if not entry:
        return ""

    pieces = [f"CREATE DATABASE {db_id};\n"]
    for table_pos, table_name in enumerate(entry['table_names_original']):
        pieces.append(f"CREATE TABLE {table_name} (\n")
        column_lines = []
        # Each column entry is a (owning table index, column name) pair.
        for owner_pos, column_name in entry['column_names_original']:
            if owner_pos == table_pos:
                column_lines.append(f"   {column_name} TEXT")
        pieces.append(",\n".join(column_lines) + "\n);\n")
    return "".join(pieces)

# Prompt template
def create_prompt(question, schema):
    """Build the training prompt: instruction, schema section, question section,
    and a trailing ``### SQL Query:`` header after which the target SQL is appended.
    """
    instruction = "You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.\n\n"
    schema_section = f"### Database Schema:\n{schema}\n\n"
    question_section = f"### Question:\n{question}\n\n"
    answer_header = "### SQL Query:\n"
    return "".join([instruction, schema_section, question_section, answer_header])

# Load Spider training files and schema definitions
train_spider_data = load_json_file(os.path.join(DATA_DIR, "train_spider.json"))
train_others_data = load_json_file(os.path.join(DATA_DIR, "train_others.json"))
# Both training splits are concatenated into a single list of samples.
all_train_data_raw = train_spider_data + train_others_data 
tables_data = load_json_file(os.path.join(DATA_DIR, "tables.json"))

# Build one training string per sample: prompt (schema + question) followed
# directly by the gold SQL query.
# NOTE(review): no end-of-sequence token is appended after the target, so the
# training text itself carries no explicit stop signal - confirm this is intended.
# NOTE(review): the evaluation cell uses a different prompt (extra instruction
# line plus primary/foreign-key comments in the schema) - confirm the mismatch
# between training and evaluation prompts is intended.
train_samples = []
for sample in all_train_data_raw: 
    schema = get_db_schema_for_prompt(sample['db_id'], tables_data)
    prompt = create_prompt(sample['question'], schema)
    target = sample['query']
    full_text = prompt + target 
    train_samples.append({"text": full_text})

hf_dataset = Dataset.from_list(train_samples)

# Randomly select a subset of data for faster fine-tuning
# (shuffle with a fixed seed, then keep the first SUBSET_RATIO fraction).
SUBSET_RATIO = 0.3 
train_subset_size = int(SUBSET_RATIO * len(hf_dataset))
train_dataset_subset = hf_dataset.shuffle(seed=42).select(range(train_subset_size))

# Report the subset size and preview the first 500 characters of one sample.
print(f"\n[INFO] Using {len(train_dataset_subset)} samples ({SUBSET_RATIO*100:.0f}%) of the dataset for fine-tuning.")
print(f"Sample formatted text for fine-tuning:\n{train_dataset_subset[0]['text'][:500]}...")


def tokenize(example):
    """Tokenize the ``"text"`` field, truncating and padding to exactly 256 tokens.

    Uses the module-level ``tokenizer``.
    """
    # NOTE(review): prompts with large schemas may exceed 256 tokens, in which
    # case truncation could drop the SQL target at the end - confirm.
    encoded = tokenizer(
        example["text"],
        max_length=256,
        padding="max_length",
        truncation=True,
    )
    return encoded

# Tokenize the selected subset in batches; the raw "text" column is dropped.
tokenized_dataset = train_dataset_subset.map(tokenize, batched=True, remove_columns=["text"])

# Data collator for causal language modeling (mlm=False disables masked-LM masking)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Define training configuration
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # 2 x 8 = 16 sequences per optimizer step per device
    warmup_steps=75,
    learning_rate=2e-4,
    logging_dir="./logs",
    logging_steps=10,
    save_steps=500,
    fp16=False,
    bf16=True,
    save_total_limit=2,
    report_to="none",
    seed=42,
    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 so the integer division cannot raise TypeError.
    dataloader_num_workers=(os.cpu_count() or 1) // 2,
)

# training phase
# `processing_class` replaces the deprecated `tokenizer` argument of Trainer
# (the old spelling emits a FutureWarning pointing at this call).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    processing_class=tokenizer,
    data_collator=data_collator
)
trainer.train()

# Persist the trained adapter and the tokenizer side by side.
trainer.save_model("fine_tuned_gemma_spider_lora")
tokenizer.save_pretrained("fine_tuned_gemma_spider_lora")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 9,805,824 || all params: 2,515,978,240 || trainable%: 0.3897420034920493

[INFO] Using 2597 samples (30%) of the dataset for fine-tuning.
Sample formatted text for fine-tuning:
You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.

### Database Schema:
CREATE DATABASE store_1;
CREATE TABLE artists (
   id TEXT,
   name TEXT
);
CREATE TABLE sqlite_sequence (
   name TEXT,
   seq TEXT
);
CREATE TABLE albums (
   id TEXT,
   title TEXT,
   artist_id TEXT
);
CREATE TABLE employees (
   id TEXT,
   last_name TEXT,
   first_name TEXT,
   title TEXT,
   reports_to TEXT,
   birth_date TEXT,
   hire_date TE...


Map:   0%|          | 0/2597 [00:00<?, ? examples/s]

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
10,1.2034
20,0.5547
30,0.4356
40,0.3069
50,0.2283
60,0.1958
70,0.1611
80,0.1366
90,0.1418
100,0.1151


('fine_tuned_gemma_spider_lora\\tokenizer_config.json',
 'fine_tuned_gemma_spider_lora\\special_tokens_map.json',
 'fine_tuned_gemma_spider_lora\\chat_template.jinja',
 'fine_tuned_gemma_spider_lora\\tokenizer.model',
 'fine_tuned_gemma_spider_lora\\added_tokens.json',
 'fine_tuned_gemma_spider_lora\\tokenizer.json')

In [3]:
# Evaluation setup: load the bfloat16 base model and attach the LoRA adapter.
OFFLOAD_DIR = "./offload"
os.makedirs(OFFLOAD_DIR, exist_ok=True)
DATA_DIR = "spider"
# Adapter checkpoint under the training output directory
# (presumably the last checkpoint of the run - verify it matches the run being evaluated).
MODEL_DIR = "qlora_output/checkpoint-489"

# Device the tokenized prompts are moved to before generation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# Load LoRA adapter weights
# NOTE(review): PEFT appears to read the keyword `offload_folder`; confirm that
# `offload_dir` is actually honoured here.
model = PeftModel.from_pretrained(
    base_model,
    MODEL_DIR,
    offload_dir=OFFLOAD_DIR
)
model.eval()
# The model is intentionally NOT moved with `model.to(device)`: it was placed by
# device_map="auto", and a dispatched (possibly offloaded) model must not be
# moved manually - doing so can fail or undo the offloading.

def load_json_file(filepath):
    """Parse the UTF-8 encoded JSON file at ``filepath`` and return its content."""
    with open(filepath, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

# Extract database schema 
def get_db_schema_for_prompt(db_id, tables_data):
    """Render the schema of database ``db_id`` as pseudo-DDL text for a prompt.

    ``tables_data`` is a list of schema dicts; the entry whose ``'db_id'``
    equals ``db_id`` is rendered as one ``CREATE DATABASE`` line, one
    ``CREATE TABLE`` stanza per table (every column is typed ``TEXT``), and,
    when present and non-empty, ``--`` comment sections listing primary and
    foreign keys. Returns ``""`` when no entry matches.
    """
    entry = next((db for db in tables_data if db['db_id'] == db_id), None)
    if not entry:
        return ""

    table_names = entry['table_names_original']

    def qualified(col_idx):
        # Resolve a global column index to "table.column".
        owner_idx, col_name = entry['column_names_original'][col_idx]
        return f"{table_names[owner_idx]}.{col_name}"

    lines = [f"CREATE DATABASE {db_id};"]
    for table_idx, table_name in enumerate(table_names):
        # Each column entry is (owning table index, column name).
        col_lines = [
            f"   {name} TEXT"
            for owner, name in entry['column_names_original']
            if owner == table_idx
        ]
        lines.append(f"CREATE TABLE {table_name} (")
        lines.append(",\n".join(col_lines))
        lines.append(");")

    primary_keys = entry.get('primary_keys')
    if primary_keys:
        lines.append("-- Primary Keys:")
        lines.extend(f"-- {qualified(idx)} is PRIMARY KEY" for idx in primary_keys)

    foreign_keys = entry.get('foreign_keys')
    if foreign_keys:
        lines.append("-- Foreign Keys:")
        for fk in foreign_keys:
            # fk[0] is the referencing column index, fk[1] the referenced one.
            lines.append(f"-- {qualified(fk[0])} REFERENCES {qualified(fk[1])}")

    return "\n".join(lines) + "\n"

# Prompt template
def create_sql_generation_prompt(question, schema):
    """Assemble the text-to-SQL prompt used at evaluation time.

    The prompt is an instruction preamble followed by ``### Database Schema:``,
    ``### Question:`` and an open ``### SQL Query:`` section that the model is
    expected to complete.
    """
    preamble = (
        "You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.\n"
        "Do not make up table or column names. Only use the ones that exist in the schema.\n\n"
    )
    sections = [
        preamble,
        f"### Database Schema:\n{schema}\n\n",
        f"### Question:\n{question}\n\n",
        "### SQL Query:\n",
    ]
    return "".join(sections)

def extract_sql_from_model_output(generated_text, prompt_text):
    """Return the first SQL statement the model produced after the prompt.

    Args:
        generated_text: Decoded model output, normally the prompt followed by
            the completion.
        prompt_text: The exact prompt string that was fed to the model.

    Returns:
        The completion with the prompt and Markdown code fences removed, cut
        at the first ``;`` and stripped of surrounding whitespace.
    """
    marker = "### SQL Query:"
    if prompt_text in generated_text:
        completion = generated_text.replace(prompt_text, "")
    elif marker in generated_text:
        # Decoding does not always reproduce the prompt byte-for-byte. Without
        # this fallback the whole prompt would survive and the split on ';'
        # below would return the prompt preamble instead of the query.
        completion = generated_text.partition(marker)[2]
    else:
        completion = generated_text

    # Drop ```sql / ``` fences the model may wrap around the query.
    completion = re.sub(r'```sql\s*', '', completion, flags=re.IGNORECASE)
    completion = re.sub(r'```\s*', '', completion)
    # Keep only the first statement.
    return completion.split(';')[0].strip()

def normalize_sql(sql):
    """Canonicalize a SQL string for comparison.

    Collapses every whitespace run to one space, trims the ends, removes
    trailing semicolons and upper-cases the result. Trailing semicolons are
    removed because predictions are cut at the first ';' before scoring, so a
    reference query ending in ';' could otherwise never match.
    """
    collapsed = re.sub(r'\s+', ' ', sql).strip()
    return collapsed.rstrip('; ').upper()

# Exact match comparison
def calculate_exact_match(pred, gold):
    """Return 1 when ``pred`` and ``gold`` are equal after normalization, else 0."""
    return int(normalize_sql(pred) == normalize_sql(gold))

# Jaccard similarity
def calculate_jaccard_similarity(pred, gold):
    """Return the Jaccard index of the whitespace-token sets of both queries.

    Both queries are normalized first. Two empty queries score 1.0.
    """
    pred_tokens = set(normalize_sql(pred).split())
    gold_tokens = set(normalize_sql(gold).split())
    union = pred_tokens | gold_tokens
    if not union:
        return 1.0
    return len(pred_tokens & gold_tokens) / len(union)

# Load evaluation data
dev_data = load_json_file(os.path.join(DATA_DIR, "dev.json"))
tables_data = load_json_file(os.path.join(DATA_DIR, "tables.json"))

# Evaluate the model
# Running totals over the dev set. A sample that raises contributes 0 to both
# sums and is tallied in n_errors so the final report can disclose it.
exact_matches = 0
jaccard_scores = 0
n_errors = 0
n_samples = len(dev_data)

for i, sample in enumerate(dev_data[:n_samples]):
    question = sample['question']
    db_id = sample['db_id']
    gold_sql = sample['query']
    schema = get_db_schema_for_prompt(db_id, tables_data)
    prompt = create_sql_generation_prompt(question, schema)

    try:
        # Despite its name this holds the full tokenizer output, which is
        # unpacked into generate() as keyword arguments.
        input_ids = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            # Greedy decoding: a single beam and no sampling.
            outputs = model.generate(
                **input_ids,
                max_new_tokens=100,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )
        # The decoded text is passed together with the prompt so the
        # extractor can strip the prompt again.
        gen_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_sql = extract_sql_from_model_output(gen_text, prompt)

        em = calculate_exact_match(generated_sql, gold_sql)
        js = calculate_jaccard_similarity(generated_sql, gold_sql)

        exact_matches += em
        jaccard_scores += js

        # Only the first 100 samples are printed in detail.
        if i < 100:
            print(f"\n--- Sample {i+1} ---")
            print(f"Database ID: {db_id}")
            print(f"Question: {question}")
            print(f"Generated SQL: {generated_sql}")
            print(f"Gold SQL: {gold_sql}")
            print(f"Exact Match (EM): {em:.4f}, Jaccard Score (JS): {js:.4f}")

    except Exception as e:
        # Best-effort evaluation: keep going, but remember the failure.
        n_errors += 1
        print(f"Error on sample {i}: {e}")

# Final evaluation report
# Averages stay relative to the full sample count (failures score 0); the
# divisor is clamped so an empty dev set does not raise ZeroDivisionError.
denominator = n_samples if n_samples else 1
print("\nEvaluation Results (Fine-tuned Model):")
print(f"Samples evaluated: {n_samples}")
if n_errors:
    print(f"Samples that raised an error (scored as 0): {n_errors}")
print(f"Average Exact Match (EM): {exact_matches / denominator:.4f}")
print(f"Average Jaccard Similarity (JS): {jaccard_scores / denominator:.4f}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


--- Sample 1 ---
Database ID: concert_singer
Question: How many singers do we have?
Generated SQL: SELECT count(*) FROM singer
Gold SQL: SELECT count(*) FROM singer
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 2 ---
Database ID: concert_singer
Question: What is the total number of singers?
Generated SQL: SELECT count(*) FROM singer
Gold SQL: SELECT count(*) FROM singer
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 3 ---
Database ID: concert_singer
Question: Show name, country, age for all singers ordered by age from the oldest to the youngest.
Generated SQL: SELECT name ,  country ,  age FROM singer ORDER BY age DESC LIMIT 1
Gold SQL: SELECT name ,  country ,  age FROM singer ORDER BY age DESC
Exact Match (EM): 0.0000, Jaccard Score (JS): 0.8333

--- Sample 4 ---
Database ID: concert_singer
Question: What are the names, countries, and ages for every singer in descending order of age?
Generated SQL: SELECT name ,  country ,  age FROM singer ORDER 

- Temperature = 0.5

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from datasets import Dataset
from peft import PeftModel
import json
import os
import re

# Fine-tuning configuration: base checkpoint, Spider data folder and the
# directory the Trainer writes its checkpoints to.
MODEL_ID = "google/gemma-2b-it"
DATA_DIR = "spider"
OUTPUT_DIR = "qlora_output"

# 4-bit quantization configuration
# NF4 quantization type with double quantization; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Load the tokenizer and the quantized base model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto"
)

# Enable gradient checkpointing, then prepare the quantized model for
# k-bit training before the LoRA adapters are attached.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
# Define LoRA configuration
# Rank 8, alpha 32, dropout 0.05, applied to the projection modules listed
# in target_modules.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], 
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
# Wrap the model with the adapters and report the trainable parameter count.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

def load_json_file(filepath):
    """Parse the UTF-8 encoded JSON file at ``filepath`` and return its content."""
    with open(filepath, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

def get_db_schema_for_prompt(db_id, tables_data):
    """Render the schema of database ``db_id`` as pseudo-DDL text for a prompt.

    The entry of ``tables_data`` whose ``'db_id'`` equals ``db_id`` is turned
    into one ``CREATE DATABASE`` line plus one ``CREATE TABLE`` stanza per
    table, with every column typed ``TEXT``. Returns ``""`` when no entry
    matches. Key constraints are not included in this variant.
    """
    match = None
    for candidate in tables_data:
        if candidate['db_id'] == db_id:
            match = candidate
            break
    if not match:
        return ""

    pieces = [f"CREATE DATABASE {db_id};\n"]
    for table_pos, table_name in enumerate(match['table_names_original']):
        column_defs = []
        # Each column entry is (owning table index, column name).
        for owner_pos, column_name in match['column_names_original']:
            if owner_pos == table_pos:
                column_defs.append(f"   {column_name} TEXT")
        pieces.append(
            f"CREATE TABLE {table_name} (\n" + ",\n".join(column_defs) + "\n);\n"
        )
    return "".join(pieces)

# Prompt template
def create_prompt(question, schema):
    """Assemble the training prompt that precedes the target SQL query.

    The prompt is an instruction sentence followed by ``### Database Schema:``,
    ``### Question:`` and an open ``### SQL Query:`` section.
    """
    sections = [
        "You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.\n\n",
        f"### Database Schema:\n{schema}\n\n",
        f"### Question:\n{question}\n\n",
        "### SQL Query:\n",
    ]
    return "".join(sections)

# Load Spider training files and schema definitions
train_spider_data = load_json_file(os.path.join(DATA_DIR, "train_spider.json"))
train_others_data = load_json_file(os.path.join(DATA_DIR, "train_others.json"))
all_train_data_raw = train_spider_data + train_others_data
tables_data = load_json_file(os.path.join(DATA_DIR, "tables.json"))

# Build one "prompt + gold SQL + EOS" string per training example.
# The EOS token is appended so the model sees an explicit end-of-query signal;
# without it nothing in the training text marks where the query ends.
eos_text = tokenizer.eos_token or ""
schema_cache = {}  # db_id -> rendered schema; many examples share a database
train_samples = []
for sample in all_train_data_raw:
    db_id = sample['db_id']
    if db_id not in schema_cache:
        schema_cache[db_id] = get_db_schema_for_prompt(db_id, tables_data)
    schema = schema_cache[db_id]
    prompt = create_prompt(sample['question'], schema)
    target = sample['query']
    full_text = prompt + target + eos_text
    train_samples.append({"text": full_text})

hf_dataset = Dataset.from_list(train_samples)

# Randomly select a subset of data for faster fine-tuning
# A fixed shuffle seed keeps the selected subset reproducible.
SUBSET_RATIO = 0.3
train_subset_size = int(SUBSET_RATIO * len(hf_dataset))
train_dataset_subset = hf_dataset.shuffle(seed=42).select(range(train_subset_size))

print(f"\n[INFO] Using {len(train_dataset_subset)} samples ({SUBSET_RATIO*100:.0f}%) of the dataset for fine-tuning.")
print(f"Sample formatted text for fine-tuning:\n{train_dataset_subset[0]['text'][:500]}...")


def tokenize(example):
    """Tokenize a batch of training texts to fixed-length 256-token sequences.

    Called through ``Dataset.map(..., batched=True)``, so ``example["text"]``
    is a list of strings. Sequences are padded to ``max_length`` and, when too
    long, truncated from the LEFT: the target SQL sits at the end of each
    text, and the default right-side truncation would cut it off for
    schema-heavy prompts, leaving nothing but prompt tokens to learn from.
    The tokenizer's original truncation side is restored afterwards.
    """
    previous_side = tokenizer.truncation_side
    tokenizer.truncation_side = "left"
    try:
        return tokenizer(example["text"], truncation=True, padding="max_length", max_length=256)
    finally:
        tokenizer.truncation_side = previous_side

# Tokenize the subset in batches and drop the raw "text" column so only the
# tokenizer outputs remain.
tokenized_dataset = train_dataset_subset.map(tokenize, batched=True, remove_columns=["text"])

# Data collator for causal language modeling
# mlm=False selects the causal (non-masked) language-modeling objective.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Define training configuration
# Effective batch size per device is 2 * 8 = 16 sequences per optimizer step.
# NOTE(review): os.cpu_count() can return None, which would make the
# dataloader_num_workers expression raise TypeError - confirm on the target host.
# NOTE(review): the evaluation cells load "checkpoint-489", which suggests the
# run ends before the first periodic save at step 500 - confirm how that
# checkpoint gets written.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    warmup_steps=100,
    learning_rate=2e-4, 
    logging_dir="./logs",
    logging_steps=10,
    save_steps=500,
    fp16=False, 
    bf16=True, 
    save_total_limit=2,
    report_to="none", 
    seed=42, 
    dataloader_num_workers=os.cpu_count() // 2,)

# training phase
# NOTE(review): the captured output shows a warning pointing at this Trainer(...)
# call; presumably about the `tokenizer` argument - confirm against the
# installed transformers version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)
trainer.train()

# Persist the trained model and the tokenizer side by side.
trainer.save_model("fine_tuned_gemma_spider_lora")
tokenizer.save_pretrained("fine_tuned_gemma_spider_lora")

In [4]:
# Evaluation setup: load the bf16 base model and attach the LoRA adapter
# stored in a Trainer checkpoint directory.
OFFLOAD_DIR = "./offload"  # scratch directory handed to PEFT below
os.makedirs(OFFLOAD_DIR, exist_ok=True)
DATA_DIR = "spider"
# Checkpoint folder under the training OUTPUT_DIR ("qlora_output").
# NOTE(review): the training cell also saves the final adapter to
# "fine_tuned_gemma_spider_lora" - confirm this checkpoint is the intended one.
MODEL_DIR = "qlora_output/checkpoint-489"

# Device used for the tokenized inputs (and for the explicit .to() below).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# Load LoRA adapter weights
# NOTE(review): confirm PeftModel.from_pretrained honours `offload_dir`
# (some PEFT versions read `offload_folder` instead).
model = PeftModel.from_pretrained(
    base_model,
    MODEL_DIR,
    offload_dir=OFFLOAD_DIR  
)
model.eval()
# NOTE(review): the base model was placed with device_map="auto"; an explicit
# .to(device) on a dispatched model may be redundant, or may be rejected when
# some layers are offloaded - confirm it is needed.
model = model.to(device)  

def load_json_file(filepath):
    """Parse the UTF-8 encoded JSON file at ``filepath`` and return its content."""
    with open(filepath, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

# Extract database schema 
def get_db_schema_for_prompt(db_id, tables_data):
    """Render the schema of database ``db_id`` as pseudo-DDL text for a prompt.

    ``tables_data`` is a list of schema dicts; the entry whose ``'db_id'``
    equals ``db_id`` is rendered as one ``CREATE DATABASE`` line, one
    ``CREATE TABLE`` stanza per table (every column is typed ``TEXT``), and,
    when present and non-empty, ``--`` comment sections listing primary and
    foreign keys. Returns ``""`` when no entry matches.
    """
    entry = next((db for db in tables_data if db['db_id'] == db_id), None)
    if not entry:
        return ""

    table_names = entry['table_names_original']

    def qualified(col_idx):
        # Resolve a global column index to "table.column".
        owner_idx, col_name = entry['column_names_original'][col_idx]
        return f"{table_names[owner_idx]}.{col_name}"

    lines = [f"CREATE DATABASE {db_id};"]
    for table_idx, table_name in enumerate(table_names):
        # Each column entry is (owning table index, column name).
        col_lines = [
            f"   {name} TEXT"
            for owner, name in entry['column_names_original']
            if owner == table_idx
        ]
        lines.append(f"CREATE TABLE {table_name} (")
        lines.append(",\n".join(col_lines))
        lines.append(");")

    primary_keys = entry.get('primary_keys')
    if primary_keys:
        lines.append("-- Primary Keys:")
        lines.extend(f"-- {qualified(idx)} is PRIMARY KEY" for idx in primary_keys)

    foreign_keys = entry.get('foreign_keys')
    if foreign_keys:
        lines.append("-- Foreign Keys:")
        for fk in foreign_keys:
            # fk[0] is the referencing column index, fk[1] the referenced one.
            lines.append(f"-- {qualified(fk[0])} REFERENCES {qualified(fk[1])}")

    return "\n".join(lines) + "\n"

# Prompt template
def create_sql_generation_prompt(question, schema):
    """Assemble the text-to-SQL prompt used at evaluation time.

    The prompt is an instruction preamble followed by ``### Database Schema:``,
    ``### Question:`` and an open ``### SQL Query:`` section that the model is
    expected to complete.
    """
    preamble = (
        "You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.\n"
        "Do not make up table or column names. Only use the ones that exist in the schema.\n\n"
    )
    sections = [
        preamble,
        f"### Database Schema:\n{schema}\n\n",
        f"### Question:\n{question}\n\n",
        "### SQL Query:\n",
    ]
    return "".join(sections)

def extract_sql_from_model_output(generated_text, prompt_text):
    """Return the first SQL statement the model produced after the prompt.

    Args:
        generated_text: Decoded model output, normally the prompt followed by
            the completion.
        prompt_text: The exact prompt string that was fed to the model.

    Returns:
        The completion with the prompt and Markdown code fences removed, cut
        at the first ``;`` and stripped of surrounding whitespace.
    """
    marker = "### SQL Query:"
    if prompt_text in generated_text:
        completion = generated_text.replace(prompt_text, "")
    elif marker in generated_text:
        # Decoding does not always reproduce the prompt byte-for-byte. Without
        # this fallback the whole prompt would survive and the split on ';'
        # below would return the prompt preamble instead of the query.
        completion = generated_text.partition(marker)[2]
    else:
        completion = generated_text

    # Drop ```sql / ``` fences the model may wrap around the query.
    completion = re.sub(r'```sql\s*', '', completion, flags=re.IGNORECASE)
    completion = re.sub(r'```\s*', '', completion)
    # Keep only the first statement.
    return completion.split(';')[0].strip()

def normalize_sql(sql):
    """Canonicalize a SQL string for comparison.

    Collapses every whitespace run to one space, trims the ends, removes
    trailing semicolons and upper-cases the result. Trailing semicolons are
    removed because predictions are cut at the first ';' before scoring, so a
    reference query ending in ';' could otherwise never match.
    """
    collapsed = re.sub(r'\s+', ' ', sql).strip()
    return collapsed.rstrip('; ').upper()

# Exact match comparison
def calculate_exact_match(pred, gold):
    """Return 1 when ``pred`` and ``gold`` are equal after normalization, else 0."""
    return int(normalize_sql(pred) == normalize_sql(gold))

# Jaccard similarity
def calculate_jaccard_similarity(pred, gold):
    """Return the Jaccard index of the whitespace-token sets of both queries.

    Both queries are normalized first. Two empty queries score 1.0.
    """
    pred_tokens = set(normalize_sql(pred).split())
    gold_tokens = set(normalize_sql(gold).split())
    union = pred_tokens | gold_tokens
    if not union:
        return 1.0
    return len(pred_tokens & gold_tokens) / len(union)

# Load evaluation data
dev_data = load_json_file(os.path.join(DATA_DIR, "dev.json"))
tables_data = load_json_file(os.path.join(DATA_DIR, "tables.json"))

# Evaluate the model
# Running totals over the dev set. A sample that raises contributes 0 to both
# sums and is tallied in n_errors so the final report can disclose it.
exact_matches = 0
jaccard_scores = 0
n_errors = 0
n_samples = len(dev_data)

for i, sample in enumerate(dev_data[:n_samples]):
    question = sample['question']
    db_id = sample['db_id']
    gold_sql = sample['query']
    schema = get_db_schema_for_prompt(db_id, tables_data)
    prompt = create_sql_generation_prompt(question, schema)

    try:
        # Despite its name this holds the full tokenizer output, which is
        # unpacked into generate() as keyword arguments.
        input_ids = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            # Sampling at temperature 0.5. No random seed is set in this
            # cell, so scores can differ from run to run.
            outputs = model.generate(
                **input_ids,
                max_new_tokens=100,
                num_beams=1,
                do_sample=True,
                temperature=0.5,
                pad_token_id=tokenizer.eos_token_id
            )
        # The decoded text is passed together with the prompt so the
        # extractor can strip the prompt again.
        gen_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_sql = extract_sql_from_model_output(gen_text, prompt)

        em = calculate_exact_match(generated_sql, gold_sql)
        js = calculate_jaccard_similarity(generated_sql, gold_sql)

        exact_matches += em
        jaccard_scores += js

        # Only the first 100 samples are printed in detail.
        if i < 100:
            print(f"\n--- Sample {i+1} ---")
            print(f"Database ID: {db_id}")
            print(f"Question: {question}")
            print(f"Generated SQL: {generated_sql}")
            print(f"Gold SQL: {gold_sql}")
            print(f"Exact Match (EM): {em:.4f}, Jaccard Score (JS): {js:.4f}")

    except Exception as e:
        # Best-effort evaluation: keep going, but remember the failure.
        n_errors += 1
        print(f"Error on sample {i}: {e}")

# Final evaluation report
# Averages stay relative to the full sample count (failures score 0); the
# divisor is clamped so an empty dev set does not raise ZeroDivisionError.
denominator = n_samples if n_samples else 1
print("\nEvaluation Results (Fine-tuned Model):")
print(f"Samples evaluated: {n_samples}")
if n_errors:
    print(f"Samples that raised an error (scored as 0): {n_errors}")
print(f"Average Exact Match (EM): {exact_matches / denominator:.4f}")
print(f"Average Jaccard Similarity (JS): {jaccard_scores / denominator:.4f}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some parameters are on the meta device because they were offloaded to the cpu.



--- Sample 1 ---
Database ID: concert_singer
Question: How many singers do we have?
Generated SQL: SELECT count(*) FROM singer
Gold SQL: SELECT count(*) FROM singer
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 2 ---
Database ID: concert_singer
Question: What is the total number of singers?
Generated SQL: SELECT count(*) FROM singer
Gold SQL: SELECT count(*) FROM singer
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 3 ---
Database ID: concert_singer
Question: Show name, country, age for all singers ordered by age from the oldest to the youngest.
Generated SQL: SELECT name ,  country ,  age FROM singer ORDER BY age DESC
Gold SQL: SELECT name ,  country ,  age FROM singer ORDER BY age DESC
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 4 ---
Database ID: concert_singer
Question: What are the names, countries, and ages for every singer in descending order of age?
Generated SQL: SELECT name ,  country ,  age FROM singer ORDER BY Age D

- Temperature = 1.0

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from datasets import Dataset
from peft import PeftModel
import json
import os
import re

# Fine-tuning configuration: base checkpoint, Spider data folder and the
# directory the Trainer writes its checkpoints to.
MODEL_ID = "google/gemma-2b-it"
DATA_DIR = "spider"
OUTPUT_DIR = "qlora_output"

# 4-bit quantization configuration
# NF4 quantization type with double quantization; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# Load the tokenizer and the quantized base model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto"
)

# Enable gradient checkpointing, then prepare the quantized model for
# k-bit training before the LoRA adapters are attached.
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)
# Define LoRA configuration
# Rank 8, alpha 32, dropout 0.05, applied to the projection modules listed
# in target_modules.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], 
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
# Wrap the model with the adapters and report the trainable parameter count.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

def load_json_file(filepath):
    """Parse the UTF-8 encoded JSON file at ``filepath`` and return its content."""
    with open(filepath, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

def get_db_schema_for_prompt(db_id, tables_data):
    """Render the schema of database ``db_id`` as pseudo-DDL text for a prompt.

    The entry of ``tables_data`` whose ``'db_id'`` equals ``db_id`` is turned
    into one ``CREATE DATABASE`` line plus one ``CREATE TABLE`` stanza per
    table, with every column typed ``TEXT``. Returns ``""`` when no entry
    matches. Key constraints are not included in this variant.
    """
    match = None
    for candidate in tables_data:
        if candidate['db_id'] == db_id:
            match = candidate
            break
    if not match:
        return ""

    pieces = [f"CREATE DATABASE {db_id};\n"]
    for table_pos, table_name in enumerate(match['table_names_original']):
        column_defs = []
        # Each column entry is (owning table index, column name).
        for owner_pos, column_name in match['column_names_original']:
            if owner_pos == table_pos:
                column_defs.append(f"   {column_name} TEXT")
        pieces.append(
            f"CREATE TABLE {table_name} (\n" + ",\n".join(column_defs) + "\n);\n"
        )
    return "".join(pieces)

# Prompt template
def create_prompt(question, schema):
    """Assemble the training prompt that precedes the target SQL query.

    The prompt is an instruction sentence followed by ``### Database Schema:``,
    ``### Question:`` and an open ``### SQL Query:`` section.
    """
    sections = [
        "You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.\n\n",
        f"### Database Schema:\n{schema}\n\n",
        f"### Question:\n{question}\n\n",
        "### SQL Query:\n",
    ]
    return "".join(sections)

# Load Spider training files and schema definitions
train_spider_data = load_json_file(os.path.join(DATA_DIR, "train_spider.json"))
train_others_data = load_json_file(os.path.join(DATA_DIR, "train_others.json"))
all_train_data_raw = train_spider_data + train_others_data
tables_data = load_json_file(os.path.join(DATA_DIR, "tables.json"))

# Build one "prompt + gold SQL + EOS" string per training example.
# The EOS token is appended so the model sees an explicit end-of-query signal;
# without it nothing in the training text marks where the query ends.
eos_text = tokenizer.eos_token or ""
schema_cache = {}  # db_id -> rendered schema; many examples share a database
train_samples = []
for sample in all_train_data_raw:
    db_id = sample['db_id']
    if db_id not in schema_cache:
        schema_cache[db_id] = get_db_schema_for_prompt(db_id, tables_data)
    schema = schema_cache[db_id]
    prompt = create_prompt(sample['question'], schema)
    target = sample['query']
    full_text = prompt + target + eos_text
    train_samples.append({"text": full_text})

hf_dataset = Dataset.from_list(train_samples)

# Randomly select a subset of data for faster fine-tuning
# A fixed shuffle seed keeps the selected subset reproducible.
SUBSET_RATIO = 0.3
train_subset_size = int(SUBSET_RATIO * len(hf_dataset))
train_dataset_subset = hf_dataset.shuffle(seed=42).select(range(train_subset_size))

print(f"\n[INFO] Using {len(train_dataset_subset)} samples ({SUBSET_RATIO*100:.0f}%) of the dataset for fine-tuning.")
print(f"Sample formatted text for fine-tuning:\n{train_dataset_subset[0]['text'][:500]}...")


def tokenize(example):
    """Tokenize a batch of training texts to fixed-length 256-token sequences.

    Called through ``Dataset.map(..., batched=True)``, so ``example["text"]``
    is a list of strings. Sequences are padded to ``max_length`` and, when too
    long, truncated from the LEFT: the target SQL sits at the end of each
    text, and the default right-side truncation would cut it off for
    schema-heavy prompts, leaving nothing but prompt tokens to learn from.
    The tokenizer's original truncation side is restored afterwards.
    """
    previous_side = tokenizer.truncation_side
    tokenizer.truncation_side = "left"
    try:
        return tokenizer(example["text"], truncation=True, padding="max_length", max_length=256)
    finally:
        tokenizer.truncation_side = previous_side

# Tokenize the subset in batches and drop the raw "text" column so only the
# tokenizer outputs remain.
tokenized_dataset = train_dataset_subset.map(tokenize, batched=True, remove_columns=["text"])

# Data collator for causal language modeling
# mlm=False selects the causal (non-masked) language-modeling objective.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Define training configuration
# Effective batch size per device is 2 * 8 = 16 sequences per optimizer step.
# NOTE(review): os.cpu_count() can return None, which would make the
# dataloader_num_workers expression raise TypeError - confirm on the target host.
# NOTE(review): the evaluation cells load "checkpoint-489", which suggests the
# run ends before the first periodic save at step 500 - confirm how that
# checkpoint gets written.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    warmup_steps=100,
    learning_rate=2e-4, 
    logging_dir="./logs",
    logging_steps=10,
    save_steps=500,
    fp16=False, 
    bf16=True, 
    save_total_limit=2,
    report_to="none", 
    seed=42, 
    dataloader_num_workers=os.cpu_count() // 2,)

# training phase
# NOTE(review): the captured output shows a warning pointing at this Trainer(...)
# call; presumably about the `tokenizer` argument - confirm against the
# installed transformers version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)
trainer.train()

# Persist the trained model and the tokenizer side by side.
trainer.save_model("fine_tuned_gemma_spider_lora")
tokenizer.save_pretrained("fine_tuned_gemma_spider_lora")

In [6]:
# Evaluation setup: load the bf16 base model and attach the LoRA adapter
# stored in a Trainer checkpoint directory.
OFFLOAD_DIR = "./offload"  # scratch directory handed to PEFT below
os.makedirs(OFFLOAD_DIR, exist_ok=True)
DATA_DIR = "spider"
# Checkpoint folder under the training OUTPUT_DIR ("qlora_output").
# NOTE(review): the training cell also saves the final adapter to
# "fine_tuned_gemma_spider_lora" - confirm this checkpoint is the intended one.
MODEL_DIR = "qlora_output/checkpoint-489"

# Device used for the tokenized inputs (and for the explicit .to() below).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# Load LoRA adapter weights
# NOTE(review): confirm PeftModel.from_pretrained honours `offload_dir`
# (some PEFT versions read `offload_folder` instead).
model = PeftModel.from_pretrained(
    base_model,
    MODEL_DIR,
    offload_dir=OFFLOAD_DIR  
)
model.eval()
# NOTE(review): the base model was placed with device_map="auto"; an explicit
# .to(device) on a dispatched model may be redundant, or may be rejected when
# some layers are offloaded - confirm it is needed.
model = model.to(device)  

def load_json_file(filepath):
    """Parse the UTF-8 encoded JSON file at ``filepath`` and return its content."""
    with open(filepath, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

# Extract database schema 
def get_db_schema_for_prompt(db_id, tables_data):
    """Render the schema of database ``db_id`` as pseudo-DDL text for a prompt.

    ``tables_data`` is a list of schema dicts; the entry whose ``'db_id'``
    equals ``db_id`` is rendered as one ``CREATE DATABASE`` line, one
    ``CREATE TABLE`` stanza per table (every column is typed ``TEXT``), and,
    when present and non-empty, ``--`` comment sections listing primary and
    foreign keys. Returns ``""`` when no entry matches.
    """
    entry = next((db for db in tables_data if db['db_id'] == db_id), None)
    if not entry:
        return ""

    table_names = entry['table_names_original']

    def qualified(col_idx):
        # Resolve a global column index to "table.column".
        owner_idx, col_name = entry['column_names_original'][col_idx]
        return f"{table_names[owner_idx]}.{col_name}"

    lines = [f"CREATE DATABASE {db_id};"]
    for table_idx, table_name in enumerate(table_names):
        # Each column entry is (owning table index, column name).
        col_lines = [
            f"   {name} TEXT"
            for owner, name in entry['column_names_original']
            if owner == table_idx
        ]
        lines.append(f"CREATE TABLE {table_name} (")
        lines.append(",\n".join(col_lines))
        lines.append(");")

    primary_keys = entry.get('primary_keys')
    if primary_keys:
        lines.append("-- Primary Keys:")
        lines.extend(f"-- {qualified(idx)} is PRIMARY KEY" for idx in primary_keys)

    foreign_keys = entry.get('foreign_keys')
    if foreign_keys:
        lines.append("-- Foreign Keys:")
        for fk in foreign_keys:
            # fk[0] is the referencing column index, fk[1] the referenced one.
            lines.append(f"-- {qualified(fk[0])} REFERENCES {qualified(fk[1])}")

    return "\n".join(lines) + "\n"

# Prompt template
def create_sql_generation_prompt(question, schema):
    """Assemble the text-to-SQL prompt used at evaluation time.

    The prompt is an instruction preamble followed by ``### Database Schema:``,
    ``### Question:`` and an open ``### SQL Query:`` section that the model is
    expected to complete.
    """
    preamble = (
        "You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.\n"
        "Do not make up table or column names. Only use the ones that exist in the schema.\n\n"
    )
    sections = [
        preamble,
        f"### Database Schema:\n{schema}\n\n",
        f"### Question:\n{question}\n\n",
        "### SQL Query:\n",
    ]
    return "".join(sections)

def extract_sql_from_model_output(generated_text, prompt_text):
    """Return the first SQL statement the model produced after the prompt.

    Args:
        generated_text: Decoded model output, normally the prompt followed by
            the completion.
        prompt_text: The exact prompt string that was fed to the model.

    Returns:
        The completion with the prompt and Markdown code fences removed, cut
        at the first ``;`` and stripped of surrounding whitespace.
    """
    marker = "### SQL Query:"
    if prompt_text in generated_text:
        completion = generated_text.replace(prompt_text, "")
    elif marker in generated_text:
        # Decoding does not always reproduce the prompt byte-for-byte. Without
        # this fallback the whole prompt would survive and the split on ';'
        # below would return the prompt preamble instead of the query.
        completion = generated_text.partition(marker)[2]
    else:
        completion = generated_text

    # Drop ```sql / ``` fences the model may wrap around the query.
    completion = re.sub(r'```sql\s*', '', completion, flags=re.IGNORECASE)
    completion = re.sub(r'```\s*', '', completion)
    # Keep only the first statement.
    return completion.split(';')[0].strip()

def normalize_sql(sql):
    """Canonicalize a SQL string for comparison.

    Collapses every whitespace run to one space, trims the ends, removes
    trailing semicolons and upper-cases the result. Trailing semicolons are
    removed because predictions are cut at the first ';' before scoring, so a
    reference query ending in ';' could otherwise never match.
    """
    collapsed = re.sub(r'\s+', ' ', sql).strip()
    return collapsed.rstrip('; ').upper()

# Exact match comparison
def calculate_exact_match(pred, gold):
    """Return 1 when ``pred`` and ``gold`` are equal after normalization, else 0."""
    return int(normalize_sql(pred) == normalize_sql(gold))

# Jaccard similarity
def calculate_jaccard_similarity(pred, gold):
    """Return the Jaccard index of the whitespace-token sets of both queries.

    Both queries are normalized first. Two empty queries score 1.0.
    """
    pred_tokens = set(normalize_sql(pred).split())
    gold_tokens = set(normalize_sql(gold).split())
    union = pred_tokens | gold_tokens
    if not union:
        return 1.0
    return len(pred_tokens & gold_tokens) / len(union)

# Load evaluation data
dev_data = load_json_file(os.path.join(DATA_DIR, "dev.json"))
tables_data = load_json_file(os.path.join(DATA_DIR, "tables.json"))

# Evaluate the model
# Running totals over the dev set. A sample that raises contributes 0 to both
# sums and is tallied in n_errors so the final report can disclose it.
exact_matches = 0
jaccard_scores = 0
n_errors = 0
n_samples = len(dev_data)

for i, sample in enumerate(dev_data[:n_samples]):
    question = sample['question']
    db_id = sample['db_id']
    gold_sql = sample['query']
    schema = get_db_schema_for_prompt(db_id, tables_data)
    prompt = create_sql_generation_prompt(question, schema)

    try:
        # Despite its name this holds the full tokenizer output, which is
        # unpacked into generate() as keyword arguments.
        input_ids = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            # Sampling at temperature 1.0. No random seed is set in this
            # cell, so scores can differ from run to run.
            outputs = model.generate(
                **input_ids,
                max_new_tokens=100,
                num_beams=1,
                do_sample=True,
                temperature=1.0,
                pad_token_id=tokenizer.eos_token_id
            )
        # The decoded text is passed together with the prompt so the
        # extractor can strip the prompt again.
        gen_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_sql = extract_sql_from_model_output(gen_text, prompt)

        em = calculate_exact_match(generated_sql, gold_sql)
        js = calculate_jaccard_similarity(generated_sql, gold_sql)

        exact_matches += em
        jaccard_scores += js

        # Only the first 100 samples are printed in detail.
        if i < 100:
            print(f"\n--- Sample {i+1} ---")
            print(f"Database ID: {db_id}")
            print(f"Question: {question}")
            print(f"Generated SQL: {generated_sql}")
            print(f"Gold SQL: {gold_sql}")
            print(f"Exact Match (EM): {em:.4f}, Jaccard Score (JS): {js:.4f}")

    except Exception as e:
        # Best-effort evaluation: keep going, but remember the failure.
        n_errors += 1
        print(f"Error on sample {i}: {e}")

# Final evaluation report
# Averages stay relative to the full sample count (failures score 0); the
# divisor is clamped so an empty dev set does not raise ZeroDivisionError.
denominator = n_samples if n_samples else 1
print("\nEvaluation Results (Fine-tuned Model):")
print(f"Samples evaluated: {n_samples}")
if n_errors:
    print(f"Samples that raised an error (scored as 0): {n_errors}")
print(f"Average Exact Match (EM): {exact_matches / denominator:.4f}")
print(f"Average Jaccard Similarity (JS): {jaccard_scores / denominator:.4f}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


--- Sample 1 ---
Database ID: concert_singer
Question: How many singers do we have?
Generated SQL: SELECT count(*) FROM singer
Gold SQL: SELECT count(*) FROM singer
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 2 ---
Database ID: concert_singer
Question: What is the total number of singers?
Generated SQL: SELECT count(*) FROM singer
Gold SQL: SELECT count(*) FROM singer
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 3 ---
Database ID: concert_singer
Question: Show name, country, age for all singers ordered by age from the oldest to the youngest.
Generated SQL: SELECT name ,  country ,  age FROM singer ORDER BY age DESC
Gold SQL: SELECT name ,  country ,  age FROM singer ORDER BY age DESC
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 4 ---
Database ID: concert_singer
Question: What are the names, countries, and ages for every singer in descending order of age?
Generated SQL: SELECT name ,  country ,  age FROM singer ORDER BY age D