In [34]:
import os
import re
import json
import sqlite3
import torch
import faiss
import numpy as np
from transformers import T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

In [14]:
DATABASE_DIR = 'data/evaluation'
TEST_FILE = os.path.join(DATABASE_DIR, 'test_spider.json')
MODEL_PATH = 'model/best_model.pt'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

if os.path.exists(MODEL_PATH):
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    print(f"Loaded fine-tuned model from {MODEL_PATH}")
else:
    print(f"Fine-tuned model not found at {MODEL_PATH}. Using base model.")

model.to(device)
model.eval()

Loaded fine-tuned model from model/best_model.pt


T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

In [6]:
embedding_model = SentenceTransformer('BAAI/bge-small-en-v1.5', device=device)
embedding_size = 384

In [7]:
def extract_schema(db_path, db_id):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [table[0] for table in cursor.fetchall()]
    
    table_schema_text = []
    
    for table in tables:
        cursor.execute(f"PRAGMA table_info({table});")
        columns_info = cursor.fetchall()
        
        primary_keys = [col[1] for col in columns_info if col[5] > 0]
        
        cursor.execute(f"PRAGMA foreign_key_list({table});")
        foreign_keys = cursor.fetchall()
        
        columns = []
        for col in columns_info:
            col_name = col[1]
            col_type = col[2].lower()
            
            pk_marker = " [PK]" if col_name in primary_keys else ""

            fk_info = ""
            for fk in foreign_keys:
                if col_name == fk[3]:
                    fk_table = fk[2]
                    fk_col = fk[4]
                    fk_info = f" [FK -> {fk_table}.{fk_col}]"
                    break
            
            columns.append(f"{col_name} ({col_type}){pk_marker}{fk_info}")
        
        if columns:
            cols_text = ", ".join(columns)
            table_schema_text.append(f"Table: {table} ({cols_text})")
    
    conn.close()
    
    schema_text = "\n".join(table_schema_text)
    return schema_text

In [8]:
index = faiss.IndexFlatL2(embedding_size)
db_ids = []
schema_texts = {}

In [25]:
for db_folder in os.listdir(DATABASE_DIR):
    db_path = os.path.join(DATABASE_DIR, db_folder, f"{db_folder}.sqlite")
    
    if os.path.exists(db_path):
        print(f"Processing database: {db_folder}")

        schema_text = extract_schema(db_path, db_folder)
        schema_texts[db_folder] = schema_text

        embedding = embedding_model.encode([schema_text])[0]
        embedding = np.array([embedding]).astype('float32')

        index.add(embedding)
        db_ids.append(db_folder)

print(f"Indexed {len(db_ids)} databases: {', '.join(db_ids)}")

Processing database: restaurants
Processing database: yelp
Processing database: geo
Processing database: academic
Processing database: imdb
Processing database: scholar
Indexed 6 databases: restaurants, yelp, geo, academic, imdb, scholar


In [26]:
def execute_sql(sql_query, db_path):
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute(sql_query)
        results = cursor.fetchall()
        conn.close()
        return True, results
    except Exception as e:
        return False, str(e)

In [27]:
def find_most_similar_schema(question):
    if len(db_ids) == 0:
        return None

    question_embedding = embedding_model.encode([question])[0]
    question_embedding = np.array([question_embedding]).astype('float32')

    D, I = index.search(question_embedding, 1)
    
    if I[0][0] < len(db_ids):
        return db_ids[I[0][0]]
    return None

In [28]:
def generate_sql(question, db_id=None):
    if db_id is None:
        db_id = find_most_similar_schema(question)
        if db_id is None:
            return "No database found", None, ""
    
    schema_text = schema_texts.get(db_id)
    if schema_text is None:
        return f"Database '{db_id}' not found", db_id, ""
    
    input_text = f"translate to SQL: {question} \n{schema_text}"
    
    input_ids = tokenizer(
        input_text,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_length=128,
            num_beams=5,
            early_stopping=True
        )

    sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    return sql_query, db_id, input_text

In [29]:
if not os.path.exists(TEST_FILE):
    print(f"Error: Test file {TEST_FILE} not found.")
else:
    with open(TEST_FILE, 'r') as f:
        temp_test_data = json.load(f)

print(f"Loaded {len(temp_test_data)} test examples.")

Loaded 1659 test examples.


In [30]:
db_dirs = [d for d in os.listdir(DATABASE_DIR) if os.path.isdir(os.path.join(DATABASE_DIR, d))]
db_dirs

['restaurants', 'yelp', 'geo', 'academic', 'imdb', 'scholar']

In [31]:
test_data = []

for each in temp_test_data:
    if each['db_id'] in db_dirs:
        test_data.append(each)

print(f"Loaded {len(test_data)} test examples.")

Loaded 1659 test examples.


In [None]:
def normalize_sql(sql):
    sql = re.sub(r"\s+", " ", sql).strip().lower()
    # Ensure operators are separated by a space
    sql = re.sub(r"(=|<=|>=|<>|<|>)", lambda m: f" {m.group(0)} ", sql)
    sql = re.sub(r"\s+", " ", sql).strip()
    return sql

In [32]:
exact_match = 0
execution_match = 0
execution_true = 0
simple_count = 0
simple_exact_match = 0
simple_execution_match = 0
complex_count = 0
complex_exact_match = 0
complex_execution_match = 0
rag_correct_db = 0

total = len(test_data)

In [None]:
for example in tqdm(test_data, desc="Evaluating"):
    question = example['question']
    true_db_id = example['db_id']
    true_sql = example['query']
    
    if true_db_id not in db_ids:
        print(f"Warning: Database '{true_db_id}' not in index. Skipping example.")
        continue

    pred_sql, pred_db_id, _ = generate_sql(question)

    if pred_db_id == true_db_id:
        rag_correct_db += 1

    if re.search(r'\bjoin\b', true_sql, re.IGNORECASE):
        complex_query = True
        complex_count += 1
    else:
        complex_query = False
        simple_count += 1

    if normalize_sql(pred_sql) == normalize_sql(true_sql):
        exact_match += 1
        if complex_query:
            complex_exact_match += 1
        else:
            simple_exact_match += 1

    db_path = os.path.join(DATABASE_DIR, true_db_id, f"{true_db_id}.sqlite")
    
    _, true_results = execute_sql(true_sql, db_path)
    pred_ok, pred_results = execute_sql(pred_sql, db_path)

    if pred_ok:
        execution_true += 1

        if pred_results == true_results:
            execution_match += 1
            if complex_query:
                complex_execution_match += 1
            else:
                simple_execution_match += 1

In [None]:
def safe_div(n, d):
    return n / d if d else 0

In [None]:
exact_match_acc = safe_div(exact_match, total)
execution_acc = safe_div(execution_match, total)
execution_true_acc = safe_div(execution_true, total)
simple_exact_acc = safe_div(simple_exact_match, simple_count)
simple_exec_acc = safe_div(simple_execution_match, simple_count)
complex_exact_acc = safe_div(complex_exact_match, complex_count)
complex_exec_acc = safe_div(complex_execution_match, complex_count)

In [None]:
print(f"Total Examples: {total}")
print(f"RAG DB Identification Accuracy: {rag_acc:.4f} ({rag_correct_db}/{total})\n")
print(f"Exact Match Accuracy:      {exact_match_acc:.4f}")
print(f"Execution Match Accuracy:  {execution_acc:.4f}")
print(f"Execution True Accuracy:   {execution_true_acc:.4f}\n")
print(f"Simple Count:              {simple_count}")
print(f"  • Simple Exact Match:    {simple_exact_acc:.4f}")
print(f"  • Simple Exec  Match:    {simple_exec_acc:.4f}")
print(f"Complex Count:             {complex_count}")
print(f"  • Complex Exact Match:   {complex_exact_acc:.4f}")
print(f"  • Complex Exec  Match:   {complex_exec_acc:.4f}")