Imports

In [49]:
import os
import sqlite3
import json
import numpy as np
import nltk
import re
from typing import List, Dict, Tuple
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from sklearn.feature_extraction.text import TfidfVectorizer
from fuzzywuzzy import fuzz
from openai import OpenAI
from dotenv import load_dotenv

Load Spider Dataset

In [50]:
database_dir_path = "/Users/hannahzhang/Desktop/spider_data"
SCHEMA_FILE = "/Users/hannahzhang/Desktop/spider_data/tables.json"

Spider Dataset Analysis

In [51]:
def list_databases(dir_path):
    """
    List all SQLite database files in the Spider dataset.
    """
    db_files = []
    db_paths = []
    for root, _, files in os.walk(dir_path):
        for file in files:
            if file.endswith(".sqlite"):
                db_files.append(file)
                db_paths.append(os.path.join(root, file))
    return db_files, db_paths

def connect_to_db(db_path):
    """
    Connect to a specific SQLite database.
    """
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"Database file not found: {db_path}")
    
    # Connect to the database
    conn = sqlite3.connect(db_path)
    return conn

def list_tables(conn):
    """
    List all tables in the connected SQLite database.
    """
    query = "SELECT name FROM sqlite_master WHERE type='table';"
    cursor = conn.cursor()
    cursor.execute(query)
    return [row[0] for row in cursor.fetchall()]

def preview_table(conn, table_name, limit=5):
    """
    Preview data from a specific table.
    """
    query = f"SELECT * FROM {table_name} LIMIT {limit};"
    cursor = conn.cursor()
    cursor.execute(query)
    columns = [desc[0] for desc in cursor.description]  # Column names
    rows = cursor.fetchall()  # Data rows
    return columns, rows

In [52]:
db_names, db_paths = list_databases(database_dir_path)
# print("Available Databases:")
# for idx, db_name in enumerate(db_names):
#     print(f"{idx + 1}: {db_name}")

# Select a database to open
db_index = int(input("\nEnter the number of the database to open: ")) - 1
db_path = db_paths[db_index]

# Connect to the database
conn = connect_to_db(db_path)
print(f"\nConnected to: {db_path}")

# List tables in the database
tables = list_tables(conn)
print("# Tables: ", len(tables))
print("\nTables in the database:")
for idx, table in enumerate(tables):
    print(f"{idx + 1}: {table}")

# Select a table to preview
table_index = int(input("\nEnter the number of the table to preview: ")) - 1
table_name = tables[table_index]

# Preview the selected table
print(f"\nPreviewing table: {table_name}")
columns, rows = preview_table(conn, table_name)
print("\nColumns:", columns)
print("\nRows:")
for row in rows:
    print(row)

# Close the connection
conn.close()
print("\nConnection closed.")

ValueError: invalid literal for int() with base 10: ''

Value Retrieval

In [53]:
class PSLsh:
    """Locality-Sensitive Hashing implementation for fast approximate nearest neighbor search."""
    def __init__(self, vectors, n_planes=10, n_tables=5, seed: int = 42):
        self.n_planes = n_planes
        self.n_tables = n_tables
        self.hash_tables = [{} for _ in range(n_tables)]
        self.random_planes = []
        
        np.random.seed(seed)
        
        # Generate random planes for each hash table
        for _ in range(n_tables):
            planes = np.random.randn(vectors.shape[1], n_planes)
            self.random_planes.append(planes)
            
        self.num_vectors = vectors.shape[0]
        self.vectors = vectors
        self.build_hash_tables()

    def build_hash_tables(self):
        """Build hash tables from input vectors."""
        for idx in range(self.num_vectors):
            vector = self.vectors[idx].toarray()[0]
            hashes = self.hash_vector(vector)
            for i, h in enumerate(hashes):
                if h not in self.hash_tables[i]:
                    self.hash_tables[i][h] = []
                self.hash_tables[i][h].append(idx)

    def hash_vector(self, vector):
        """Generate hash codes for a vector."""
        hashes = []
        for planes in self.random_planes:
            projections = np.dot(vector, planes)
            hash_code = ''.join(['1' if x > 0 else '0' for x in projections])
            hashes.append(hash_code)
        return hashes

    def query(self, vector):
        """Find candidate nearest neighbors for a query vector."""
        vector = vector.toarray()[0]  # Convert sparse matrix to 1D array
        hashes = self.hash_vector(vector)
        candidates = set()
        for i, h in enumerate(hashes):
            candidates.update(self.hash_tables[i].get(h, []))
        return candidates

In [54]:
with open(SCHEMA_FILE, 'r') as f:
        schemas = json.load(f)

In [55]:
class SpiderValueRetrieval:
    def __init__(self, spider_tables_path: str = SCHEMA_FILE, lsh_seed: int = 42):
        load_dotenv()
        self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

        # Load all Spider schemas
        print("DEBUG: Loading Spider schemas from:", spider_tables_path)
        with open(spider_tables_path, 'r') as f:
            self.schemas = json.load(f)
        
        # Create a mapping of db_id to schema
        self.db_schemas = {schema['db_id']: schema for schema in self.schemas}
        
        self.lemmatizer = WordNetLemmatizer()
        self.stop_words = set(stopwords.words('english'))
        
        # Initialize database-specific components
        self.column_indices = {}
        self.vectorizers = {}
        self.lsh_indices = {}
        
        # Build indices for each database
        self._build_indices()

    def process_schema(self, question: str, db_id: str) -> str:
        """Process schema with database context."""
        if db_id not in self.db_schemas:
            raise ValueError(f"Unknown database ID: {db_id}")
            
        # Get schema for this database
        schema = self.db_schemas[db_id]
        
        # Get schema relationships
        primary_keys, foreign_keys = self._get_schema_relationships(schema)
        
        # Process question with database context
        results = self.process_question(question, db_id)

    def _get_schema_relationships(self, schema: Dict) -> Tuple[List[str], List[str]]:
        """Extract primary and foreign keys for a specific database."""
        table_names = schema.get('table_names_original', [])
        column_names = schema.get('column_names_original', [])
        primary_keys = schema.get('primary_keys', [])
        foreign_keys = schema.get('foreign_keys', [])
        
        # Format primary keys
        formatted_pks = []
        for pk in primary_keys:
            table_idx, col_name = column_names[pk]
            if table_idx != -1:
                table_name = table_names[table_idx]
                formatted_pks.append(f"{table_name.lower()}.{col_name.lower()}")
        
        # Format foreign keys
        formatted_fks = []
        for fk in foreign_keys:
            fk_col = column_names[fk[0]]
            pk_col = column_names[fk[1]]
            fk_table = table_names[fk_col[0]]
            pk_table = table_names[pk_col[0]]
            formatted_fks.append(
                f"{fk_table.lower()}.{fk_col[1].lower()} = {pk_table.lower()}.{pk_col[1].lower()}"
            )
            
        return formatted_pks, formatted_fks

    def _parse_numeric_value(self, word: str) -> str:
        """Parse numeric values from words."""
        if 'billion' in word.lower():
            return '1000000000'
        elif 'million' in word.lower():
            return '1000000'
        return word
    
    def _build_indices(self):
        """Build all necessary indices for all databases."""
        for db_id, schema in self.db_schemas.items():
            # Build column index
            self.column_indices[db_id] = self._build_column_index(schema)
            
            # Build vectorizer and LSH
            terms = self._get_schema_terms(schema)
            vectorizer = TfidfVectorizer(analyzer='char', ngram_range=(1, 3), min_df=1, max_df=0.95)
            term_vectors = vectorizer.fit_transform(terms)
            
            self.vectorizers[db_id] = {
                'vectorizer': vectorizer,
                'terms': terms
            }
            
            self.lsh_indices[db_id] = {
                'lsh': PSLsh(term_vectors, n_planes=10, n_tables=5),
                'vectors': term_vectors
            }

    def _build_column_index(self, schema: Dict) -> Dict:
        """Build column index for a specific database schema."""
        column_index = {}
        table_names = schema.get('table_names_original', [])
        column_names = schema.get('column_names_original', [])
        column_types = schema.get('column_types', [])
        cleaned_column_names = schema.get('column_names', [])
        
        for (table_idx, col_name), (table_idx, clean_col_name), col_type in zip(column_names[1:], cleaned_column_names[1:], column_types[1:]):  # Skip first row (*) 
            if table_idx != -1:  # Skip table_idx == -1 which represents '*'
                table_name = table_names[table_idx].lower()
                qualified_name = f"{table_name}.{col_name.lower()}"
                
                column_index[qualified_name] = {
                    'table': table_name,
                    'column': col_name.lower(),
                    'type': col_type,
                    'words': clean_col_name.lower(),
                    'synonyms': self._get_column_synonyms(col_name)
                }
        
        return column_index

    def _split_column_name(self, column_name: str) -> List[str]:
        """Split column name into individual words."""
        words = re.sub('([A-Z][a-z]+)', r' \1', re.sub('([A-Z]+)', r' \1', column_name)).split()
        words.extend(column_name.split('_'))
        return [word.lower() for word in words if word]

    def _get_column_synonyms(self, column_name: str) -> List[str]:
        """Get synonyms for words in column name."""
        words = self._split_column_name(column_name)
        return list(set(words))  # For Spider, we'll just use the words themselves as synonyms

    def _get_schema_terms(self, schema: Dict) -> List[str]:
        """Get all terms from a specific database schema."""
        terms = []
        table_names = schema.get('table_names_original', [])
        column_names = schema.get('column_names_original', [])
        
        for idx, table in enumerate(table_names):
            table = table.lower()
            terms.append(table)
            
            # Add column terms
            table_columns = [(t_idx, col) for t_idx, col in column_names if t_idx == idx]
            for _, column in table_columns:
                terms.append(f"{table}.{column.lower()}")
                
        return terms

    def preprocess_text(self, text: str) -> List[str]:
        """Tokenize and lemmatize input text, removing stop words."""
        if not text:
            return []
            
        try:
            tokens = nltk.word_tokenize(str(text).lower())
            filtered_tokens = [word for word in tokens if word not in self.stop_words and word.isalnum()]
            lemmatized_tokens = [self.lemmatizer.lemmatize(token) for token in filtered_tokens]
            return lemmatized_tokens
        except Exception as e:
            print(f"Error in preprocessing text '{text}': {str(e)}")
            return []

    def _find_similar_words(self, word: str, db_id: str) -> List[Tuple[str, float]]:
        """Find similar words in the database-specific schema."""
        if not word:
            return []

        word = word.lower()
        matches = []
        
        # Direct matching with column names
        column_index = self.column_indices[db_id]
        for qualified_name, metadata in column_index.items():
            score = 0.0
            
            # Check exact matches in column words
            if word in metadata['words']:
                matches.append((qualified_name, 1.0))
                continue
            
            # Fuzzy match with column words
            for col_word in metadata['words']:
                ratio = fuzz.ratio(word, col_word) / 100.0
                if ratio > score:
                    score = ratio
            
            if score > 0.6:
                matches.append((qualified_name, score))

        # LSH-based matching as backup
        if len(matches) < 5:
            try:
                vectorizer = self.vectorizers[db_id]['vectorizer']
                terms = self.vectorizers[db_id]['terms']
                word_vector = vectorizer.transform([word])
                
                lsh = self.lsh_indices[db_id]['lsh']
                vectors = self.lsh_indices[db_id]['vectors']
                
                candidate_indices = lsh.query(word_vector)
                
                for idx in candidate_indices:
                    term = terms[idx]
                    if not any(term == match[0] for match in matches):
                        candidate_vector = vectors[idx].toarray()[0]
                        word_vector_array = word_vector.toarray()[0]
                        dist = np.linalg.norm(word_vector_array - candidate_vector)
                        sim = 1 / (1 + dist)
                        if sim > 0.5:
                            matches.append((term, sim * 0.8))
            except Exception as e:
                print(f"LSH matching failed: {e}")

        matches.sort(key=lambda x: x[1], reverse=True)
        return matches[:5]

    def _extract_keywords(self, question: str, db_id: str) -> Dict:
        """Extract keywords with database-specific context."""
        schema = self.db_schemas[db_id]
        primary_keys, foreign_keys = self._get_schema_relationships(schema)
        
        system_prompt = f"""Given this database schema:
        Tables: {schema['table_names_original']}
        Columns: {schema['column_names_original']}
        
        Primary Keys: {primary_keys}
        Foreign Keys: {foreign_keys}

        Extract relevant keywords, keyphrases, and numerical values from the question in JSON format."""

        response = self.client.chat.completions.create(
            model="gpt-4o",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Question: {question}"}
            ],
            functions=[
                {
                    "name": "extract_components",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "keywords": {"type": "array", "items": {"type": "string"}},
                            "keyphrases": {"type": "array", "items": {"type": "string"}},
                            "numerical_values": {"type": "array", "items": {"type": "string"}}
                        },
                        "required": ["keywords"]
                    }
                }
            ],
            function_call={"name": "extract_components"}
        )

        # Access the function_call attribute directly
        function_call = response.choices[0].message.function_call
        arguments = function_call.arguments
        extracted_info = json.loads(arguments)

        # Debugging statement
        # print("Extracted Info:", extracted_info)

        return extracted_info


    def process_question(self, question: str, db_id: str) -> Dict:
        """Process question with database context."""
        # Extract keywords using database-specific schema
        extracted_info = self._extract_keywords(question, db_id)
        
        # Process words
        words = []
        for key in ['keywords', 'keyphrases', 'numerical_values']:
            words.extend(extracted_info.get(key, []))
        
        # Debugging statement
        # print("Words Extracted:", words)

        processed_words = []
        for word in words:
            processed_words.extend(self.preprocess_text(word))
        
        processed_words = list(set(processed_words))
        
        # Debugging statement
        # print("Processed Words:", processed_words)

        # Find similar columns using database-specific indices
        similar_matches = {}
        for word in processed_words:
            similar_matches[word] = self._find_similar_words(word, db_id)
            # Debugging statement
            # print(f"Similar matches for '{word}': {similar_matches[word]}")
        
        return {
            "question": question,
            "extracted_info": extracted_info,
            "processed_words": processed_words,
            "similar_matches": similar_matches,
            "schema_relationships": {
                "primary_keys": self._get_schema_relationships(self.db_schemas[db_id])[0],
                "foreign_keys": self._get_schema_relationships(self.db_schemas[db_id])[1]
            }
        }


    def process_schema(self, question: str, db_id: str) -> str:
        """Process schema with database context."""
        if db_id not in self.db_schemas:
            raise ValueError(f"Unknown database ID: {db_id}")
            
        # Get schema for this database
        schema = self.db_schemas[db_id]
        
        # Get schema relationships
        primary_keys, foreign_keys = self._get_schema_relationships(schema)
        
        # Process question with database context
        results = self.process_question(question, db_id)
        
        table_columns = []
        relevant_primary_keys = []
        relevant_foreign_keys = []
        
        # Use database-specific indices and relationships
        for word, matches in results['similar_matches'].items():
            if matches:
                top_match = matches[0]
                if top_match[1] > 0.7:
                    if word in results['extracted_info'].get('numerical_values', []):
                        # Handle numerical values
                        value = self._parse_numeric_value(word)
                        table_columns.append(f"{top_match[0]} > {value}")
                    else:
                        table_columns.append(top_match[0])
        
        # Debugging statement
        print("Table Columns:", table_columns)

        # Get relevant tables for this database
        tables_needed = set()
        for link in table_columns:
            if '.' in link:
                tables_needed.add(link.split('.')[0].lower())
        
        # Add relevant primary keys
        for pk in primary_keys:
            table = pk.split('.')[0].lower()
            if table in tables_needed:
                relevant_primary_keys.append(pk)
                
        # Add relevant foreign keys
        for fk in foreign_keys:
            tables_in_fk = set(part.split('.')[0].lower() for part in fk.split(' = '))
            if tables_in_fk.intersection(tables_needed):
                relevant_foreign_keys.append(fk)
        
        schema_dict = {
            "table_columns": table_columns,
            "primary_keys": relevant_primary_keys,
            "foreign_keys": relevant_foreign_keys,
            "schema_links": table_columns  # Added for DIN SQL compatibility
        }
        
        # Debugging statement
        print("Schema Dict:", schema_dict)
        
        return str(schema_dict)


In [56]:
import time

def classification_prompt_maker(question, schema_links, db_id, spider_schemas):
    """Create classification prompt with Spider database context."""
    # Get specific database schema
    schema = next((s for s in spider_schemas if s['db_id'] == db_id), None)
    if not schema:
        raise ValueError(f"Unknown database ID: {db_id}")
    
    # Format schema info for the prompt
    schema_info = f"Tables: {schema['table_names_original']}\nColumns: {schema['column_names_original']}"
    
    instruction = """Given the database schema:
{schema_info}

Primary Keys: {primary_keys}
Foreign Keys: {foreign_keys}

Classify the question as:
- EASY: no JOIN and no nested queries needed
- NON-NESTED: needs JOIN but no nested queries
- NESTED: needs nested queries

Question: "{question}"
Schema Links: {schema_links}

Let's think step by step:"""

    # Get formatted keys
    def format_keys(schema):
        pks = []
        fks = []
        for pk in schema.get('primary_keys', []):
            table_idx, col_name = schema['column_names_original'][pk]
            if table_idx != -1:
                table_name = schema['table_names_original'][table_idx]
                pks.append(f"{table_name.lower()}.{col_name.lower()}")
        
        for fk in schema.get('foreign_keys', []):
            fk_col = schema['column_names_original'][fk[0]]
            pk_col = schema['column_names_original'][fk[1]]
            fk_table = schema['table_names_original'][fk_col[0]]
            pk_table = schema['table_names_original'][pk_col[0]]
            fks.append(f"{fk_table.lower()}.{fk_col[1].lower()} = {pk_table.lower()}.{pk_col[1].lower()}")
        
        return pks, fks

    primary_keys, foreign_keys = format_keys(schema)
    
    return instruction.format(
        schema_info=schema_info,
        primary_keys=primary_keys,
        foreign_keys=foreign_keys,
        question=question,
        schema_links=schema_links
    )

def process_question_classification(question, schema_links, db_id, spider_schemas):
    """Process question classification with Spider database context."""
    def extract_classification(text):
        print(f"Trying to extract classification from: {text}")
        text = text.upper()
        
        for class_type in ["EASY", "NON-NESTED", "NESTED"]:
            if class_type in text:
                return class_type
                
        patterns = ["LABEL:", "CLASSIFICATION:", "CAN BE CLASSIFIED AS"]
        for pattern in patterns:
            if pattern in text:
                parts = text.split(pattern)
                if len(parts) > 1:
                    result = parts[1].strip().strip('"').strip("'")
                    classification = result.split()[0].strip()
                    if classification in ["EASY", "NON-NESTED", "NESTED"]:
                        return classification
        
        return "NESTED"  # Default fallback

    classification = None
    attempts = 0
    while classification is None and attempts < 3:
        try:
            client = OpenAI()
            response = client.chat.completions.create(
                model="gpt-4",
                messages=[{
                    "role": "user", 
                    "content": classification_prompt_maker(
                        question=question,
                        schema_links=schema_links,
                        db_id=db_id,
                        spider_schemas=spider_schemas
                    )
                }],
                temperature=0.0,
                max_tokens=300
            )
            raw_response = response.choices[0].message.content
            # print("Raw response:", raw_response)
            classification = extract_classification(raw_response)
        except Exception as e:
            print(f"Error occurred: {str(e)}")
            time.sleep(3)
            attempts += 1
    
    final_class = classification if classification else "NESTED"
    return f'"{final_class}"'

def process_question_sql(question, predicted_class, schema_links, db_id, spider_schemas, max_retries=3):
    def extract_sql(text):
        if "SQL:" in text:
            return text.split("SQL:")[-1].strip()
        return text.strip()

    def make_spider_prompt(template_type):
        schema = next((s for s in spider_schemas if s['db_id'] == db_id), None)
        if not schema:
            raise ValueError(f"Unknown database ID: {db_id}")

        # Spider-specific examples
        examples = {
            "easy": '''Q: "How many clubs are there?"
Schema_links: [club.id]
SQL: SELECT COUNT(*) FROM club''',
            
            "medium": '''Q: "Show the names of all teams and their leagues."
Schema_links: [team.name, league.name]
A: Let's think step by step. We need to join teams with leagues.
SQL: SELECT team.name, league.name 
FROM team 
JOIN league ON team.league_id = league.id''',
            
            "hard": '''Q: "Find players who scored more goals than average."
Schema_links: [player.name, player.goals]
A: Let's think step by step:
1. Calculate average goals
2. Find players above average
SQL: SELECT name FROM player 
WHERE goals > (SELECT AVG(goals) FROM player)'''
        }

        template = examples[template_type]
        prompt = f"""Database Schema for {db_id}:
Tables: {schema['table_names_original']}
Columns: {schema['column_names_original']}

Example:
{template}

Generate SQL for:
Question: {question}
Schema_links: {schema_links}
"""
        return prompt

    for attempt in range(max_retries):
        try:
            if '"EASY"' in predicted_class:
                prompt = make_spider_prompt("easy")
            elif '"NON-NESTED"' in predicted_class:
                prompt = make_spider_prompt("medium")
            else:
                prompt = make_spider_prompt("hard")

            client = OpenAI()
            response = client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.0,
                max_tokens=500
            )
            sql = extract_sql(response.choices[0].message.content)
            break
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(3)
            else:
                sql = "SELECT"
    
    return sql

In [57]:
import time
from openai import OpenAI

def process_question_sql(question, predicted_class, schema_links, db_id, spider_schemas, max_retries=3):
    def extract_sql(text):
        print(f"\nTrying to extract SQL from: {text}")  # Debug print
        if not text:
            return "SELECT"
            
        markers = ["SQL:", "Query:", "QUERY:", "SQL Query:", "Final SQL:"]
        for marker in markers:
            if marker in text:
                parts = text.split(marker)
                if len(parts) > 1:
                    sql = parts[-1].strip()
                    print(f"Found SQL after {marker}: {sql}")  # Debug print
                    return sql
        print("No SQL marker found, returning full text")  # Debug print
        return text.strip()

    schema = next((s for s in spider_schemas if s['db_id'] == db_id), None)
    if not schema:
        raise ValueError(f"Unknown database ID: {db_id}")

    if '"EASY"' in predicted_class:
        print("EASY")
        for attempt in range(max_retries):
            try:
                SQL = GPT4_generation(easy_prompt_maker(
                    question=question,
                    schema_links=schema_links,
                    schema=schema
                ))
                if SQL:
                    SQL = extract_sql(SQL)
                    break
            except Exception as e:
                print(f"Attempt {attempt + 1} failed: {str(e)}")
                if attempt < max_retries - 1:
                    time.sleep(3)
                else:
                    SQL = "SELECT"
                    
    elif '"NON-NESTED"' in predicted_class:
        print("NON-NESTED")
        for attempt in range(max_retries):
            try:
                SQL = GPT4_generation(medium_prompt_maker(
                    question=question,
                    schema_links=schema_links,
                    schema=schema
                ))
                if SQL:
                    SQL = extract_sql(SQL)
                    break
            except Exception as e:
                print(f"Attempt {attempt + 1} failed: {str(e)}")
                if attempt < max_retries - 1:
                    time.sleep(3)
                else:
                    SQL = "SELECT"
                    
    else:
        print("NESTED")
        for attempt in range(max_retries):
            try:
                SQL = GPT4_generation(hard_prompt_maker(
                    question=question,
                    schema_links=schema_links,
                    schema=schema
                ))
                if SQL:
                    SQL = extract_sql(SQL)
                    break
            except Exception as e:
                print(f"Attempt {attempt + 1} failed: {str(e)}")
                if attempt < max_retries - 1:
                    time.sleep(3)
                else:
                    SQL = "SELECT"

    return SQL if SQL else "SELECT"

def GPT4_generation(prompt, max_retries=3):
    client = OpenAI()
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model="gpt-4", 
                messages=[{
                    "role": "system",
                    "content": "You are a database assistant that generates SQL queries based on questions about any database schema."},
                    {"role": "user", 
                    "content": prompt}],
                n = 1,
                stream = False,
                temperature=0.0,
                max_tokens=600,
                top_p = 1.0,
                frequency_penalty=0.0,
                presence_penalty=0.0
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(3)
            else:
                print("Max retries reached")
                return None
    return None

def easy_prompt_maker(question, schema_links, schema):
    """Create prompt for easy Spider questions."""
    prompt = f"""Database Schema:
Tables: {schema['table_names_original']}
Columns: {schema['column_names_original']}

Example:
Q: "How many clubs are there?"
Schema_links: [club.id]
SQL: SELECT COUNT(*) FROM club

Now generate SQL for:
Question: {question}
Schema_links: {schema_links}
SQL:"""
    return prompt

def medium_prompt_maker(question, schema_links, schema):
    """Create prompt for medium (non-nested) Spider questions."""
    prompt = f"""Database Schema:
Tables: {schema['table_names_original']}
Columns: {schema['column_names_original']}

Example:
Q: "Show the names of all teams and their leagues."
Schema_links: [team.name, league.name]
A: Let's think step by step. We need to join teams with leagues.
SQL: SELECT team.name, league.name 
FROM team 
JOIN league ON team.league_id = league.id

Now generate SQL for:
Question: {question}
Schema_links: {schema_links}
A: Let's think step by step."""
    return prompt

def hard_prompt_maker(question, schema_links, schema):
    """Create prompt for hard (nested) Spider questions."""
    prompt = f"""Database Schema:
Tables: {schema['table_names_original']}
Columns: {schema['column_names_original']}

Example:
Q: "Find players who scored more goals than average."
Schema_links: [player.name, player.goals]
A: Let's think step by step:
1. Calculate average goals
2. Find players above average
SQL: SELECT name FROM player 
WHERE goals > (SELECT AVG(goals) FROM player)

Now generate SQL for:
Question: {question}
Schema_links: {schema_links}
A: Let's think step by step."""
    return prompt

In [58]:
import time
from openai import OpenAI

def debugger(question: str, sql: str, db_id: str, spider_schemas: list):
    """Create debug prompt with Spider database context."""
    # Get specific database schema
    schema = next((s for s in spider_schemas if s['db_id'] == db_id), None)
    if not schema:
        raise ValueError(f"Unknown database ID: {db_id}")
    
    # Format relationships for better context
    def format_keys(schema):
        pks = []
        fks = []
        for pk in schema.get('primary_keys', []):
            table_idx, col_name = schema['column_names_original'][pk]
            if table_idx != -1:
                table_name = schema['table_names_original'][table_idx]
                pks.append(f"{table_name.lower()}.{col_name.lower()}")
        
        for fk in schema.get('foreign_keys', []):
            fk_col = schema['column_names_original'][fk[0]]
            pk_col = schema['column_names_original'][fk[1]]
            fk_table = schema['table_names_original'][fk_col[0]]
            pk_table = schema['table_names_original'][pk_col[0]]
            fks.append(f"{fk_table.lower()}.{fk_col[1].lower()} = {pk_table.lower()}.{pk_col[1].lower()}")
        
        return pks, fks

    primary_keys, foreign_keys = format_keys(schema)

    instruction = f"""#### Database Schema for {db_id}:
Tables: {schema['table_names_original']}
Columns: {schema['column_names_original']}
Primary Keys: {primary_keys}
Foreign Keys: {foreign_keys}

#### For the given question, check and fix the SQL query based on these rules:
1) Use the correct table and column names from the schema
2) Use proper JOIN conditions based on the foreign key relationships
3) Use DESC and DISTINCT when needed based on the question
4) Ensure GROUP BY statements include necessary columns
5) Verify SELECT statement columns match the question requirements
6) Remove redundant columns from GROUP BY
7) Use GROUP BY on one column when possible

Question: {question}
Original SQL: {sql}

Return the fixed SQL query:"""

    return instruction

def GPT4_debug(prompt):
    """Debug SQL using GPT-4."""
    client = OpenAI()
    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{
                "role": "system",
                "content": "You are a SQL expert focusing on fixing and optimizing SQL queries for SQLite."
            },
            {
                "role": "user",
                "content": prompt
            }],
            temperature=0.0,
            max_tokens=350,
            top_p=1.0,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            stop=["#", ";", "\n\n"]
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error in GPT4_debug: {str(e)}")
        return None

def refine_query(question: str, sql: str, db_id: str, spider_schemas: list) -> str:
    """Refine and debug the SQL query."""
    max_attempts = 3
    attempt = 0
    debugged_SQL = None
    
    while debugged_SQL is None and attempt < max_attempts:
        try:
            debug_prompt = debugger(question, sql, db_id, spider_schemas)
            debugged_SQL = GPT4_debug(debug_prompt)
            
            if debugged_SQL:
                # Clean up the response
                debugged_SQL = debugged_SQL.replace("\n", " ").strip()
                
                # Extract just the SQL if there's additional text
                if "sql" in debugged_SQL.lower():
                    parts = debugged_SQL.lower().split("sql")
                    debugged_SQL = parts[-1].strip()
                
                print(f"Refined SQL (Attempt {attempt + 1}):", debugged_SQL)
                return debugged_SQL
            
        except Exception as e:
            print(f"Error in refinement attempt {attempt + 1}: {str(e)}")
            time.sleep(3)
        
        attempt += 1
    
    # If all attempts fail, return original SQL
    return sql

In [59]:
class SpiderTester:

    def log(self, message: str):
        """Log message to file and console."""
        print(message)
        with open(self.output_file, 'a') as f:
            f.write(message + '\n')

    def __init__(self, spider_path: str = '/Users/hannahzhang/Desktop/spider_data', mode='dev'):
        """Initialize Spider tester.
        Args:
            spider_path: Path to Spider dataset
            mode: Which dataset to use ('dev', 'train_spider', or 'train_others')
        """
        self.spider_path = spider_path
        
        # Load schemas
        with open(os.path.join(spider_path, 'tables.json'), 'r') as f:
            self.spider_schemas = json.load(f)
            
        # note you must select correct data file based on mode (else it will say database not found)
        data_files = {
            'dev': 'dev.json',
            'train_spider': 'train_spider.json',
            'train_others': 'train_others.json'
        }
        
        if mode not in data_files:
            raise ValueError(f"Mode must be one of {list(data_files.keys())}")
            
        data_file = data_files[mode]
        print(f"Loading {data_file}...")
        
        # Load questions
        with open(os.path.join(spider_path, data_file), 'r') as f:
            self.test_data = json.load(f)
            
        print(f"Loaded {len(self.test_data)} questions")
            
        # Initialize value retrieval
        self.value_retrieval = SpiderValueRetrieval(
            spider_tables_path=os.path.join(spider_path, 'tables.json')
        )
        
        # Create output file
        self.output_file = os.path.join(spider_path, f'results_{mode}_{time.strftime("%Y%m%d_%H%M%S")}.txt')

    def run_test(self, num_questions: int = None):
        if num_questions is None:
            num_questions = len(self.test_data)
        else:
            num_questions = min(num_questions, len(self.test_data))

        processed = 0
        successful = 0
        
        self.log(f"\nTesting {num_questions} questions")
        
        for idx, test_case in enumerate(self.test_data[:num_questions]):
            question = test_case['question']
            db_id = test_case['db_id']
            
            self.log(f"\nQuestion {idx + 1}/{num_questions}")
            self.log(f"Question: {question}")
            self.log(f"Database: {db_id}")
                
            # try:
                # Process question
            schema_links = self.value_retrieval.process_schema(question, db_id)
            self.log(f"Schema Links: {schema_links}")
            
            # Get classification
            classification = process_question_classification(
                question=question,
                schema_links=schema_links,
                db_id=db_id,
                spider_schemas=self.spider_schemas
            )
            self.log(f"Classification: {classification}")
            
            # Generate SQL
            sql = process_question_sql(
                question=question,
                predicted_class=classification,
                schema_links=schema_links,
                db_id=db_id,
                spider_schemas=self.spider_schemas
            )

            print("Generated sql: ", sql)
            
            self.log(f"Generated SQL: {sql}")
            self.log(f"Ground Truth: {test_case['query']}")
            
            processed += 1
                    
            # except Exception as e:
            #     self.log(f"Error processing question: {str(e)}")
            #     continue

        self.log("\n=== Testing Summary ===")
        self.log(f"Total Questions Processed: {processed}")
        self.log(f"Successful: {successful}")
        if processed > 0:
            self.log(f"Success Rate: {(successful/processed)*100:.2f}%")

def main():
    # Choose which dataset to use HERE (dev, train_spider, or train_others)
    print("Available modes:")
    print("1. dev (1034 questions)")
    print("2. train_spider (7000 questions)")
    print("3. train_others (1659 questions)")
    
    mode = input("Choose mode (dev/train_spider/train_others) [default: dev]: ").strip() or 'dev'
    num_questions = input("How many questions to process? (press Enter for all): ").strip()
    
    tester = SpiderTester(mode=mode)
    tester.run_test(num_questions=int(num_questions) if num_questions else None)

if __name__ == "__main__":
    main()


Available modes:
1. dev (1034 questions)
2. train_spider (7000 questions)
3. train_others (1659 questions)
Loading dev.json...
Loaded 1034 questions
DEBUG: Loading Spider schemas from: /Users/hannahzhang/Desktop/spider_data/tables.json

Testing 1 questions

Question 1/1
Question: How many singers do we have?
Database: concert_singer


RateLimitError: Error code: 429 - {'error': {'message': 'You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors.', 'type': 'insufficient_quota', 'param': None, 'code': 'insufficient_quota'}}