## **Problem 8: SQL**

# Part 1.

In [1]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub.utils import disable_progress_bars
import re
import os

# Initial environment settings.
# NOTE(review): both variables are assigned after `import torch` and the
# transformers / huggingface_hub imports have already executed. Presumably
# TORCH_COMPILE_DISABLE turns off torch.compile and HF_HUB_DOWNLOAD_TIMEOUT is
# a download timeout in seconds -- confirm against the library docs, and
# confirm they are still honoured when set this late.
os.environ["TORCH_COMPILE_DISABLE"] = "1" 
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "900" 
disable_progress_bars()

# DATA_DIR is the folder the dataset files are read from further down;
# MODEL_ID is the identifier handed to both from_pretrained() calls below.
DATA_DIR = "spider" 
MODEL_ID = "google/gemma-2b-it"
print(f"Loading tokenizer and model: {MODEL_ID}...")

# 4-bit quantization configuration: NF4 quantization type with double
# quantization enabled and bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                 
    bnb_4bit_use_double_quant=True,    
    bnb_4bit_quant_type="nf4",         
    bnb_4bit_compute_dtype=torch.bfloat16, 
)

# Load the tokenizer and the quantized model. Any exception raised while
# loading is printed and the script stops.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config, 
        device_map="auto"               
    )
    # The model is only used for generation in this cell, never trained.
    model.eval() 
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    # NOTE(review): exit() is called without a status argument, so the failure
    # is presumably reported with a success exit status -- confirm whether a
    # non-zero status is wanted here.
    exit()

# Load JSON, SQL files and split SQL into individual queries
def load_json_file(filepath):
    """Parse the UTF-8 JSON document stored at *filepath* and return it.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The decoded JSON value.

    Raises:
        FileNotFoundError: If nothing exists at *filepath*.
    """
    if os.path.exists(filepath):
        with open(filepath, 'r', encoding='utf-8') as handle:
            return json.load(handle)
    raise FileNotFoundError(f"File not found: {filepath}")

def load_sql_file_as_list(filepath):
    """Read a SQL file and return its ';'-separated statements.

    `--` comments (to end of line) and `/* ... */` comments are removed
    first; each remaining piece is stripped and empty pieces are dropped.

    Args:
        filepath: Path of the SQL file to read (decoded as UTF-8).

    Returns:
        A list of statement strings without the terminating ';'.

    Raises:
        FileNotFoundError: If nothing exists at *filepath*.
    """
    if os.path.exists(filepath):
        with open(filepath, 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
        no_line_comments = re.sub(r'--.*$', '', raw_text, flags=re.MULTILINE)
        no_comments = re.sub(r'/\*.*?\*/', '', no_line_comments, flags=re.DOTALL)
        statements = []
        for piece in no_comments.split(';'):
            trimmed = piece.strip()
            if trimmed:
                statements.append(trimmed)
        return statements
    raise FileNotFoundError(f"File not found: {filepath}")

# Extract database schema for prompt creation
def get_db_schema_for_prompt(db_id, tables_data):
    """Render the schema of database *db_id* as DDL-like text for a prompt.

    The first entry of *tables_data* whose 'db_id' equals *db_id* is used.
    The text contains a CREATE DATABASE line, one CREATE TABLE statement per
    name in 'table_names_original' (every column is declared as TEXT), and,
    when the entry has non-empty 'primary_keys' / 'foreign_keys', comment
    lines describing them.

    Args:
        db_id: Identifier of the database to describe.
        tables_data: Iterable of schema dicts carrying 'db_id',
            'table_names_original' and 'column_names_original' (pairs of
            table index and column name), plus optional 'primary_keys'
            (column indices) and 'foreign_keys' (pairs of column indices).

    Returns:
        The schema text, or an empty string when no entry matches *db_id*.
    """
    entry = next((db for db in tables_data if db['db_id'] == db_id), None)
    if not entry:
        return ""

    def qualified(column_index):
        # Resolve a column index to "table.column" using the original names.
        owner_pos, column_name = entry['column_names_original'][column_index]
        owner_name = entry['table_names_original'][owner_pos]
        return f"{owner_name}.{column_name}"

    column_type = "TEXT"
    parts = [f"CREATE DATABASE {db_id};\n"]

    for table_pos, table_name in enumerate(entry['table_names_original']):
        parts.append(f"CREATE TABLE {table_name} (\n")
        column_lines = [
            f"   {column_name} {column_type}"
            for owner_pos, column_name in entry['column_names_original']
            if owner_pos == table_pos
        ]
        parts.append(",\n".join(column_lines))
        parts.append("\n);\n")

    if entry.get('primary_keys'):
        parts.append("-- Primary Keys:\n")
        for pk_index in entry['primary_keys']:
            parts.append(f"-- {qualified(pk_index)} is PRIMARY KEY\n")

    if entry.get('foreign_keys'):
        parts.append("-- Foreign Keys:\n")
        for link in entry['foreign_keys']:
            source_index, target_index = link[0], link[1]
            source = qualified(source_index)
            target = qualified(target_index)
            parts.append(f"-- {source} REFERENCES {target}\n")

    return "".join(parts)

# Prompt template
def create_sql_generation_prompt(question, schema):
    """Build the text-to-SQL prompt for one question.

    The prompt is a fixed instruction, then the schema under
    "### Database Schema:", the question under "### Question:", and a
    trailing "### SQL Query:" header.

    Args:
        question: Natural-language question to answer.
        schema: Schema text to embed in the prompt.

    Returns:
        The complete prompt string.
    """
    segments = [
        "You are a skilled SQL assistant. Given a SQL database schema and a natural language question, generate the correct SQL query.\n",
        "Do not make up table or column names. Only use the ones that exist in the schema.\n\n",
        "### Database Schema:\n",
        f"{schema}\n\n",
        "### Question:\n",
        f"{question}\n\n",
        "### SQL Query:\n",
    ]
    return "".join(segments)

def extract_sql_from_model_output(generated_text, prompt_text):
    """Pull the first SQL statement out of raw model output.

    Steps:
      1. Remove every occurrence of *prompt_text* from *generated_text*.
      2. If the remainder contains a closed Markdown code fence
         (``` or ```sql, case-insensitive), keep only the contents of the
         first fenced block, so prose before or after the fence is dropped.
         Otherwise just delete any stray fence markers (e.g. an opening
         fence whose closing fence was never generated).
      3. Keep the text before the first ';' and strip surrounding whitespace.

    Args:
        generated_text: Decoded model output, with or without the prompt.
        prompt_text: The prompt that was sent to the model.

    Returns:
        The extracted SQL string without a trailing ';' (may be empty).
    """
    cleaned_text = generated_text.replace(prompt_text, "").strip()

    fenced_block = re.search(
        r'```(?:sql)?\s*(.*?)```', cleaned_text, flags=re.IGNORECASE | re.DOTALL
    )
    if fenced_block:
        # Only the fenced code is SQL; anything outside it is commentary.
        cleaned_text = fenced_block.group(1)
    else:
        cleaned_text = re.sub(r'```(?:sql)?\s*', '', cleaned_text, flags=re.IGNORECASE)

    # split() returns the whole text when there is no ';', so no guard is
    # needed; stripping unconditionally also removes leftover newlines.
    return cleaned_text.split(';')[0].strip()

# Load dataset files
# Reads the two training JSON files (concatenated into all_train_data), the
# gold SQL file, the dev set and the table schemas from DATA_DIR, printing a
# count after each load. A missing file is reported and the script stops.
try:
    train_spider_data = load_json_file(os.path.join(DATA_DIR, "train_spider.json"))
    train_others_data = load_json_file(os.path.join(DATA_DIR, "train_others.json"))
    all_train_data = train_spider_data + train_others_data
    print(f"Loaded {len(all_train_data)} training samples.")

    # NOTE(review): the recorded output of this cell reports 8659 training
    # samples but only 2047 gold queries, so splitting train_gold.sql on ';'
    # does not produce one entry per sample -- confirm the file's actual
    # layout before relying on this list.
    train_gold_queries_raw = load_sql_file_as_list(os.path.join(DATA_DIR, "train_gold.sql"))
    print(f"Loaded {len(train_gold_queries_raw)} gold SQL queries.")

    dev_data = load_json_file(os.path.join(DATA_DIR, "dev.json"))
    print(f"Loaded {len(dev_data)} dev samples.")

    tables_data = load_json_file(os.path.join(DATA_DIR, "tables.json"))
    print(f"Loaded {len(tables_data)} table schemas.")

except FileNotFoundError as e:
    # The loaders raise FileNotFoundError with the offending path in the message.
    print(e)
    print(f"Please ensure all dataset files are in the specified DATA_DIR: '{DATA_DIR}' folder.")
    exit()

def normalize_sql(sql_query):
    """Return *sql_query* in a canonical form for comparison.

    Every run of whitespace becomes a single space, leading and trailing
    whitespace is removed, and the result is upper-cased.
    """
    single_spaced = re.sub(r'\s+', ' ', sql_query)
    return single_spaced.strip().upper()

# Exact match comparison
def calculate_exact_match(predicted_sql, gold_sql):
    """Return 1 when both queries are equal after normalize_sql(), else 0."""
    predicted_norm = normalize_sql(predicted_sql)
    gold_norm = normalize_sql(gold_sql)
    return int(predicted_norm == gold_norm)

# Jaccard similarity
def calculate_jaccard_similarity(predicted_sql, gold_sql):
    """Return the Jaccard index of the two queries' token sets.

    Each query is passed through normalize_sql() and split on whitespace;
    the score is |intersection| / |union| of the resulting token sets.
    When both token sets are empty the score is 1.0.
    """
    predicted_tokens = set(normalize_sql(predicted_sql).split())
    gold_tokens = set(normalize_sql(gold_sql).split())
    combined = predicted_tokens | gold_tokens
    if not combined:
        return 1.0
    shared = predicted_tokens & gold_tokens
    return len(shared) / len(combined)


# Evaluate the model
# For every dev sample: build the schema prompt, generate greedily, extract
# the SQL from the continuation, and accumulate exact-match and Jaccard
# scores. Samples that raise are reported, counted, and left out of the
# averages.
exact_matches = 0
jaccard_scores = 0.0
total_processed_samples = 0
failed_samples = 0

num_samples_to_test = len(dev_data)

for i, sample in enumerate(dev_data[:num_samples_to_test]):
    question = sample['question']
    db_id = sample['db_id']
    gold_sql = sample['query']

    schema = get_db_schema_for_prompt(db_id, tables_data)
    prompt = create_sql_generation_prompt(question, schema)

    generated_sql = ""
    try:
        # Despite its name this holds the full tokenizer output
        # (input_ids plus attention mask), unpacked into generate() below.
        input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_token_count = input_ids["input_ids"].shape[-1]
        with torch.no_grad():
            outputs = model.generate(
                **input_ids,
                max_new_tokens=100,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )

        # generate() returns the prompt tokens followed by the new tokens.
        # Decode only the continuation so prompt removal does not depend on
        # the decoded prompt being byte-identical to the original string.
        continuation_ids = outputs[0][prompt_token_count:]
        generated_text = tokenizer.decode(continuation_ids, skip_special_tokens=True)
        generated_sql = extract_sql_from_model_output(generated_text, prompt)

        em = calculate_exact_match(generated_sql, gold_sql)
        js = calculate_jaccard_similarity(generated_sql, gold_sql)

        exact_matches += em
        jaccard_scores += js
        total_processed_samples += 1

        # Only the first 100 samples are printed in detail.
        if i < 100:
            print(f"\n--- Sample {i+1} ---")
            print(f"Database ID: {db_id}")
            print(f"Question: {question}")
            print(f"Generated SQL: {generated_sql}")
            print(f"Gold SQL: {gold_sql}")
            print(f"Exact Match (EM): {em:.4f}, Jaccard Score (JS): {js:.4f}")

    except Exception as e:
        # Best-effort evaluation: keep going, but remember the failure so the
        # final report shows how many samples the averages leave out.
        failed_samples += 1
        print(f"Error processing sample {i} (Question: {question}, DB: {db_id}): {e}")

# Final evaluation report
print(f"\nFinal Evaluation Results for Base Model:")
if total_processed_samples > 0:
    avg_em = exact_matches / total_processed_samples
    avg_js = jaccard_scores / total_processed_samples
    print(f"Average Exact Match (EM): {avg_em:.4f}")
    print(f"Average Jaccard Similarity (JS): {avg_js:.4f}")
else:
    print("No samples evaluated.")
if failed_samples:
    print(f"Samples skipped due to errors (not included in averages): {failed_samples}")

print("\nBase model evaluation complete.")

Loading tokenizer and model: google/gemma-2b-it...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model loaded successfully.
Loaded 8659 training samples.
Loaded 2047 gold SQL queries.
Loaded 1034 dev samples.
Loaded 166 table schemas.

--- Sample 1 ---
Database ID: concert_singer
Question: How many singers do we have?
Generated SQL: SELECT COUNT(*) FROM singer
Gold SQL: SELECT count(*) FROM singer
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 2 ---
Database ID: concert_singer
Question: What is the total number of singers?
Generated SQL: SELECT COUNT(*) FROM singer
Gold SQL: SELECT count(*) FROM singer
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 3 ---
Database ID: concert_singer
Question: Show name, country, age for all singers ordered by age from the oldest to the youngest.
Generated SQL: SELECT s.Name, s.Country, s.Age
FROM singer s
ORDER BY s.Age DESC
Gold SQL: SELECT name ,  country ,  age FROM singer ORDER BY age DESC
Exact Match (EM): 0.0000, Jaccard Score (JS): 0.4286

--- Sample 4 ---
Database ID: concert_singer
Question: What are the