## **Problem 8: SQL**

# Part 4.

In [1]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub.utils import disable_progress_bars
import re
import os
from sentence_transformers import SentenceTransformer
import faiss 

# Runtime environment tweaks.
# NOTE(review): these variables are assigned after torch / transformers /
# huggingface_hub have already been imported above; if any of those libraries
# reads them only at import time, these assignments have no effect — TODO
# confirm and, if so, move them above the imports.
os.environ["TORCH_COMPILE_DISABLE"] = "1" 
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "900" 
# Silence huggingface_hub download progress bars.
disable_progress_bars()

# Spider data directory (dev.json and tables.json are read from it later),
# the generator LLM, and the sentence embedder used for schema retrieval.
DATA_DIR = "spider" 
MODEL_ID = "google/gemma-2b-it" 
EMBEDDER_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2" 

# bitsandbytes config: load weights in 4-bit NF4 with double (nested)
# quantization, and run the compute in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                 
    bnb_4bit_use_double_quant=True,    
    bnb_4bit_quant_type="nf4",         
    bnb_4bit_compute_dtype=torch.bfloat16, 
)

# Load the tokenizer and the quantized causal LM; device_map="auto" lets the
# library choose device placement. Any failure prints the error and aborts.
# NOTE(review): in a Jupyter kernel exit() may shut down / restart the kernel
# rather than just stopping this cell — consider raising instead.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config, 
        device_map="auto"               
    )
    # Inference only: switch the module to eval mode.
    model.eval() 
except Exception as e:
    print(f"Failed to load LLM: {e}")
    exit()

# Load the sentence embedder used to embed schema chunks and questions.
try:
    embedder = SentenceTransformer(EMBEDDER_MODEL_ID)
    embedder.eval() 
except Exception as e:
    print(f"Failed to load Embedder: {e}")
    exit()

def load_json_file(filepath):
    """Read and parse a UTF-8 JSON file.

    Args:
        filepath: path of the JSON file to read.

    Returns:
        The decoded JSON value.

    Raises:
        FileNotFoundError: if ``filepath`` does not exist.
    """
    if os.path.exists(filepath):
        with open(filepath, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)
        return payload
    raise FileNotFoundError(f"File not found: {filepath}")

def load_sql_file_as_list(filepath):
    """Read a UTF-8 ``.sql`` file and return its statements without comments.

    Comments (``-- ...`` to end of line and ``/* ... */``) are removed and the
    text is split on ``;``. Semicolons and comment markers that appear inside
    single-quoted literals or double-quoted identifiers are kept as text, so
    ``SELECT 'a;b'`` stays one statement. Empty statements are dropped and
    each statement is whitespace-stripped.

    Args:
        filepath: path of the SQL file to read.

    Returns:
        List of statement strings, without the terminating ``;``.

    Raises:
        FileNotFoundError: if ``filepath`` does not exist.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"File not found: {filepath}")
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()
    return _split_sql_statements(content)


def _split_sql_statements(text):
    """Split SQL text on top-level ``;``, dropping comments, respecting quotes."""
    statements = []
    current = []
    i = 0
    n = len(text)
    while i < n:
        ch = text[i]
        if ch in ("'", '"'):
            # Copy a quoted literal/identifier verbatim; a doubled quote
            # character is an escaped quote, not the end of the token.
            j = i + 1
            while j < n:
                if text[j] == ch:
                    if j + 1 < n and text[j + 1] == ch:
                        j += 2
                        continue
                    break
                j += 1
            current.append(text[i:j + 1])
            i = j + 1
        elif text.startswith('--', i):
            # Line comment: skip to (but keep) the newline.
            newline = text.find('\n', i)
            i = n if newline == -1 else newline
        elif text.startswith('/*', i):
            # Block comment; an unterminated one runs to end of input. Replace
            # it with a space so tokens on either side are not glued together.
            end = text.find('*/', i + 2)
            i = n if end == -1 else end + 2
            current.append(' ')
        elif ch == ';':
            statements.append(''.join(current))
            current = []
            i += 1
        else:
            current.append(ch)
            i += 1
    statements.append(''.join(current))
    return [s.strip() for s in statements if s.strip()]

def get_schema_chunks_for_db(db_id, tables_data, include_pk_fk=True):
    """Render every table of one database as a pseudo-DDL text chunk.

    Finds the first entry of ``tables_data`` whose ``'db_id'`` equals
    ``db_id``. For each name in its ``'table_names_original'``, it emits a
    ``CREATE TABLE`` statement listing that table's columns, all typed
    ``TEXT``. When ``include_pk_fk`` is true, SQL comment lines describing
    the primary keys and foreign-key references owned by the table follow
    the statement, primary keys first.

    Args:
        db_id: identifier of the database to render.
        tables_data: list of schema dicts, each with ``'db_id'``,
            ``'table_names_original'``, ``'column_names_original'``
            (``[table_index, column_name]`` pairs), and optional
            ``'primary_keys'`` / ``'foreign_keys'`` column-index data.
        include_pk_fk: whether to append key annotations.

    Returns:
        One ``{"text", "db_id", "table_name"}`` dict per table, in table
        order, or ``[]`` when ``db_id`` is not found.
    """
    db_found = next((db for db in tables_data if db['db_id'] == db_id), None)
    if not db_found:
        return []

    table_names = db_found['table_names_original']
    if not table_names:
        return []
    columns = db_found['column_names_original']

    # One pass over the columns, bucketed by owning table index; buckets for
    # indices that match no table (e.g. -1) are simply never read back.
    column_lines = {}
    for owner_idx, column_name in columns:
        column_lines.setdefault(owner_idx, []).append(f"   {column_name} TEXT")

    # Key annotations bucketed by owning table *name*. All primary-key lines
    # are queued before any foreign-key line, so per table PKs come first.
    key_notes = {}
    if include_pk_fk:
        for pk_idx in db_found.get('primary_keys', []):
            owner_idx, column_name = columns[pk_idx]
            owner = table_names[owner_idx]
            key_notes.setdefault(owner, []).append(
                f"\n-- {owner}.{column_name} is PRIMARY KEY"
            )
        for fk_pair in db_found.get('foreign_keys', []):
            owner_idx, column_name = columns[fk_pair[0]]
            owner = table_names[owner_idx]
            ref_idx, ref_column = columns[fk_pair[1]]
            key_notes.setdefault(owner, []).append(
                f"\n-- {owner}.{column_name} REFERENCES {table_names[ref_idx]}.{ref_column}"
            )

    chunks = []
    for table_idx, table_name in enumerate(table_names):
        ddl = (
            f"CREATE TABLE {table_name} (\n"
            + ",\n".join(column_lines.get(table_idx, []))
            + "\n);"
        )
        chunks.append({
            "text": ddl + "".join(key_notes.get(table_name, [])),
            "db_id": db_id,
            "table_name": table_name,
        })
    return chunks

def create_sql_generation_prompt(question, retrieved_schema_chunks_text):
    """Assemble the text-to-SQL prompt given to the LLM.

    The prompt has four sections separated by blank lines: the instruction
    paragraph, the retrieved schema under a "Relevant Database Schema
    Information" heading, the question, and a trailing empty
    "### SQL Query:" section that the model is expected to continue.

    Args:
        question: natural-language question.
        retrieved_schema_chunks_text: schema text to show the model.

    Returns:
        The prompt string.
    """
    sections = [
        "You are a skilled SQL assistant. Given the following SQL database schema information and a natural language question, generate the correct SQL query.\n"
        "Do not make up table or column names. Only use the ones that exist in the provided schema information.",
        f"### Relevant Database Schema Information:\n{retrieved_schema_chunks_text}",
        f"### Question:\n{question}",
        "### SQL Query:\n",
    ]
    return "\n\n".join(sections)

def extract_sql_from_model_output(generated_text, prompt_text):
    """Pull a single SQL statement out of a raw LLM completion.

    Steps:
      1. Remove every occurrence of ``prompt_text`` (a no-op when the caller
         already passes only the generated continuation).
      2. If the text contains a closed Markdown code fence (e.g. ```sql ...
         ```` ``` ````), keep only the first fenced body, so prose the model
         writes after the fence is dropped. Otherwise, remove stray/unclosed
         fences, including tags such as ``sqlite``.
      3. Cut at a line starting with ``###``, i.e. where the model starts
         continuing the prompt template (``### Question:`` ...).
      4. Keep only the text before the first ``;`` and strip whitespace.

    Args:
        generated_text: decoded model output.
        prompt_text: prompt that may be echoed at the start of the output.

    Returns:
        The extracted SQL string (possibly empty).
    """
    completion = generated_text.replace(prompt_text, "").strip()

    # The fence must be followed by a newline so an inline ```SELECT ...```
    # is not misread as a fence with language tag "SELECT".
    fenced = re.search(r'```[\w+-]*[ \t]*\n(.*?)```', completion, flags=re.DOTALL)
    if fenced:
        completion = fenced.group(1)
    else:
        completion = re.sub(r'```(?:sql\w*)?\s*', '', completion, flags=re.IGNORECASE)

    completion = re.split(r'\n\s*###', completion, maxsplit=1)[0]
    completion = completion.split(';', 1)[0]
    return completion.strip()

def normalize_sql(sql_query):
    """Canonicalise an SQL string for surface-form comparison.

    Pads parentheses, commas and comparison operators with spaces, collapses
    all whitespace runs to a single space, and removes trailing semicolons.
    The result is upper-cased. As a result, ``name, age`` / ``name ,  age``,
    ``count(*)`` / ``count( * )`` and ``... ;`` / ``...`` compare equal.
    Quoted text is treated like any other text (also upper-cased/padded),
    consistently for both sides of a comparison.

    Args:
        sql_query: SQL text.

    Returns:
        The normalised, upper-case string.
    """
    # Multi-character operators first so ">=" is not split into ">" "=".
    spaced = re.sub(r'(<>|!=|>=|<=|=|<|>|\(|\)|,)', r' \1 ', sql_query)
    collapsed = re.sub(r'\s+', ' ', spaced).strip()
    collapsed = re.sub(r'[\s;]+$', '', collapsed)
    return collapsed.upper()

def calculate_exact_match(predicted_sql, gold_sql):
    """Return 1 if both queries are equal after ``normalize_sql``, else 0."""
    same_query = normalize_sql(predicted_sql) == normalize_sql(gold_sql)
    return int(same_query)

def calculate_jaccard_similarity(predicted_sql, gold_sql):
    """Jaccard similarity of the whitespace-token sets of two normalised queries.

    Both queries go through ``normalize_sql`` and are split on whitespace.
    The result is |A ∩ B| / |A ∪ B|, defined as 1.0 when both token sets are
    empty.
    """
    predicted_tokens = set(normalize_sql(predicted_sql).split())
    gold_tokens = set(normalize_sql(gold_sql).split())
    union_size = len(predicted_tokens | gold_tokens)
    if not union_size:
        return 1.0
    return len(predicted_tokens & gold_tokens) / union_size

# Load the Spider dev split (questions + gold SQL) and the per-database
# schema descriptions; abort if either file is missing.
try:
    dev_data = load_json_file(os.path.join(DATA_DIR, "dev.json"))
    print(f"Loaded {len(dev_data)} dev samples.")

    tables_data = load_json_file(os.path.join(DATA_DIR, "tables.json"))
    print(f"Loaded {len(tables_data)} table schemas.")

except FileNotFoundError as e:
    print(f"[ERROR] {e}")
    exit()

# One text chunk per table, for every database in tables.json.
# NOTE(review): get_schema_chunks_for_db linearly re-searches tables_data for
# db_entry, so this loop is quadratic in the number of databases (negligible
# for the 166 schemas loaded here).
all_schema_chunks = []
for db_entry in tables_data:
    all_schema_chunks.extend(get_schema_chunks_for_db(db_entry['db_id'], tables_data, include_pk_fk=True)) # Use full schema for indexing

# Parallel lists: position i in each describes all_schema_chunks[i].
chunk_texts = [chunk['text'] for chunk in all_schema_chunks]
chunk_db_ids = [chunk['db_id'] for chunk in all_schema_chunks]
# NOTE(review): chunk_table_names is not read anywhere in the visible code.
chunk_table_names = [chunk['table_name'] for chunk in all_schema_chunks]
# Embed every chunk once; retrieve_relevant_schema_info later slices these
# rows per database.
chunk_embeddings = embedder.encode(chunk_texts, show_progress_bar=True, convert_to_tensor=True)

embedding_dimension = chunk_embeddings.shape[1]
# Global exact-L2 FAISS index over all chunks.
# NOTE(review): `index` is never searched in the visible code — retrieval
# builds a per-database index instead; confirm whether it is still needed.
index = faiss.IndexFlatL2(embedding_dimension) 
index.add(chunk_embeddings.cpu().numpy()) 

def retrieve_relevant_schema_info(question, db_id, top_k_tables=3):
    """Return the schema chunks of ``db_id`` most similar to ``question``.

    The candidate set is limited to the pre-embedded chunks (one per table)
    of ``db_id``. An exact L2 FAISS index is built over them, and the
    ``top_k_tables`` nearest chunks to the question embedding are returned,
    nearest first, joined by blank lines.

    ``k`` is capped at the number of tables in the database. This matters
    because FAISS pads missing results with label -1, and indexing
    ``db_specific_indices[-1]`` would silently repeat the last table in the
    prompt.

    Args:
        question: natural-language question to embed.
        db_id: database whose tables are candidates.
        top_k_tables: maximum number of table chunks to return.

    Returns:
        The selected chunk texts joined with ``"\\n\\n"``. If the database has
        no pre-embedded chunks, all of its freshly built chunks are joined
        instead, or ``""`` if it is unknown. Also ``""`` when
        ``top_k_tables <= 0``.
    """
    db_specific_indices = [i for i, id_val in enumerate(chunk_db_ids) if id_val == db_id]

    if not db_specific_indices:
        fallback_chunks = get_schema_chunks_for_db(db_id, tables_data, include_pk_fk=True)
        return "\n\n".join(chunk['text'] for chunk in fallback_chunks)

    k = min(top_k_tables, len(db_specific_indices))
    if k <= 0:
        return ""

    db_specific_embeddings = chunk_embeddings[db_specific_indices].cpu().numpy()
    db_index = faiss.IndexFlatL2(db_specific_embeddings.shape[1])
    db_index.add(db_specific_embeddings)
    question_embedding = embedder.encode(question, convert_to_tensor=True).cpu().numpy().reshape(1, -1)
    _, retrieved_local_indices = db_index.search(question_embedding, k)

    # Skip any -1 padding labels defensively, even though k is already capped.
    retrieved_chunks_text = [
        all_schema_chunks[db_specific_indices[local_idx]]['text']
        for local_idx in retrieved_local_indices[0]
        if local_idx >= 0
    ]
    return "\n\n".join(retrieved_chunks_text)

# Evaluate RAG text-to-SQL generation on the Spider dev samples.
exact_matches = 0
jaccard_scores = 0.0
total_processed_samples = 0
failed_samples = 0  # samples that raised; they are excluded from the averages

num_samples_to_test = len(dev_data)
num_samples_to_print = 100  # detailed per-sample output only for the first N samples

for i, sample in enumerate(dev_data[:num_samples_to_test]):
    question = sample['question']
    db_id = sample['db_id']
    gold_sql = sample['query']

    generated_sql = ""
    try:
        # Retrieval and prompt building are inside the try so one bad sample
        # (e.g. a retrieval failure) does not abort the whole evaluation.
        retrieved_schema_text = retrieve_relevant_schema_info(question, db_id, top_k_tables=3)
        prompt = create_sql_generation_prompt(question, retrieved_schema_text)

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

        # For a causal LM, generate() returns the prompt tokens followed by the
        # continuation. Decode only the continuation instead of relying on the
        # decoded prompt round-tripping exactly for string removal.
        prompt_length = inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        generated_sql = extract_sql_from_model_output(generated_text, prompt)

        em = calculate_exact_match(generated_sql, gold_sql)
        js = calculate_jaccard_similarity(generated_sql, gold_sql)

        exact_matches += em
        jaccard_scores += js
        total_processed_samples += 1

        if i < num_samples_to_print:
            print(f"\n--- Sample {i+1} (RAG) ---")
            print(f"Database ID: {db_id}")
            print(f"Question: {question}")
            print(f"Retrieved Schema:\n{retrieved_schema_text[:300]}...")
            print(f"Generated SQL: {generated_sql}")
            print(f"Gold SQL: {gold_sql}")
            print(f"Exact Match (EM): {em:.4f}, Jaccard Score (JS): {js:.4f}")

    except Exception as e:
        failed_samples += 1
        print(f"[ERROR] Processing sample {i} (Question: {question}, DB: {db_id}): {e}")

print("\n Final Evaluation Results for Model with RAG:")
if total_processed_samples > 0:
    avg_em = exact_matches / total_processed_samples
    avg_js = jaccard_scores / total_processed_samples
    print(f"Total samples evaluated: {total_processed_samples}")
    print(f"Average Exact Match (EM): {avg_em:.4f}")
    print(f"Average Jaccard Similarity (JS): {avg_js:.4f}")
    if failed_samples:
        print(f"Samples skipped due to errors (not in averages): {failed_samples}")
else:
    print("No samples were successfully evaluated.")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded 1034 dev samples.
Loaded 166 table schemas.


Batches:   0%|          | 0/28 [00:00<?, ?it/s]


--- Sample 1 (RAG) ---
Database ID: concert_singer
Question: How many singers do we have?
Retrieved Schema:
CREATE TABLE singer (
   Singer_ID TEXT,
   Name TEXT,
   Country TEXT,
   Song_Name TEXT,
   Song_release_year TEXT,
   Age TEXT,
   Is_male TEXT
);
-- singer.Singer_ID is PRIMARY KEY

CREATE TABLE singer_in_concert (
   concert_ID TEXT,
   Singer_ID TEXT
);
-- singer_in_concert.concert_ID is PRIMA...
Generated SQL: SELECT COUNT(*) FROM singer
Gold SQL: SELECT count(*) FROM singer
Exact Match (EM): 1.0000, Jaccard Score (JS): 1.0000

--- Sample 2 (RAG) ---
Database ID: concert_singer
Question: What is the total number of singers?
Retrieved Schema:
CREATE TABLE singer (
   Singer_ID TEXT,
   Name TEXT,
   Country TEXT,
   Song_Name TEXT,
   Song_release_year TEXT,
   Age TEXT,
   Is_male TEXT
);
-- singer.Singer_ID is PRIMARY KEY

CREATE TABLE singer_in_concert (
   concert_ID TEXT,
   Singer_ID TEXT
);
-- singer_in_concert.concert_ID is PRIMA...
Generated SQL: SELECT COUNT(*) F