Load Spider Dataset

In [2]:
import os
import sqlite3
import json

In [3]:
# Root of the local Spider dataset copy. Set the SPIDER_DATA_DIR environment
# variable to point at your own copy; the fallback is the original author's path.
database_dir_path = os.environ.get("SPIDER_DATA_DIR", "/Users/hannahzhang/Desktop/spider_data")
# Derived from the data directory so the two locations cannot drift apart.
SCHEMA_FILE = os.path.join(database_dir_path, "tables.json")

Spider Dataset Analysis

In [4]:
def list_databases(dir_path):
    """
    List all SQLite database files in the Spider dataset.

    Walks ``dir_path`` recursively and collects every file whose name ends in
    ``.sqlite``. Directories and file names are visited in sorted order so the
    result is identical on every run and filesystem; callers select a database
    by its position in the returned lists, so a stable order matters.

    Args:
        dir_path: Root directory to search.

    Returns:
        A ``(db_files, db_paths)`` tuple of parallel lists: the bare file
        names and the matching full paths. Both lists are empty when nothing
        matches.
    """
    db_files = []
    db_paths = []
    for root, dirs, files in os.walk(dir_path):
        # Sorting in place steers os.walk's (top-down) descent order.
        dirs.sort()
        for file in sorted(files):
            if file.endswith(".sqlite"):
                db_files.append(file)
                db_paths.append(os.path.join(root, file))
    return db_files, db_paths

def connect_to_db(db_path):
    """
    Open a connection to the SQLite database stored at ``db_path``.

    Args:
        db_path: Filesystem path of the database.

    Returns:
        The ``sqlite3`` connection object.

    Raises:
        FileNotFoundError: if nothing exists at ``db_path``. The check runs
            before connecting, so no new database file is created for a
            missing path.
    """
    if os.path.exists(db_path):
        return sqlite3.connect(db_path)
    raise FileNotFoundError(f"Database file not found: {db_path}")

def list_tables(conn):
    """
    Return the names of the tables recorded in the database's sqlite_master catalog.

    Args:
        conn: An open SQLite connection.

    Returns:
        A list of table names, in the order the catalog query yields them.
    """
    cur = conn.cursor()
    cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
    names = []
    for record in cur.fetchall():
        # Each result row has a single column: the table name.
        names.append(record[0])
    return names

def preview_table(conn, table_name, limit=5):
    """
    Preview data from a specific table.

    Args:
        conn: An open SQLite connection.
        table_name: Name of the table to read. It is treated as a single
            identifier and quoted, so names that are SQL keywords or contain
            spaces work; a schema-qualified name such as ``main.t`` is NOT
            split into schema and table.
        limit: Maximum number of rows to return (converted with ``int``).

    Returns:
        A ``(columns, rows)`` tuple: the list of column names and the list of
        row tuples fetched.

    Raises:
        sqlite3.OperationalError: if the table does not exist.
        TypeError, ValueError: if ``limit`` cannot be converted to an integer.
    """
    # Identifiers cannot be bound as parameters, so quote the name and double
    # any embedded quote characters; this also blocks SQL injection via the name.
    quoted_name = '"' + str(table_name).replace('"', '""') + '"'
    query = f"SELECT * FROM {quoted_name} LIMIT ?;"
    cursor = conn.cursor()
    cursor.execute(query, (int(limit),))
    columns = [desc[0] for desc in cursor.description]  # Column names
    rows = cursor.fetchall()  # Data rows
    return columns, rows

In [5]:
def _prompt_index(prompt, count):
    """Ask for a 1-based selection and return it as a 0-based index.

    Raises:
        ValueError: if there is nothing to choose from, the reply is not an
            integer, or it lies outside ``1..count``. Without this check an
            entry of 0 or a negative number would wrap around through
            Python's negative indexing and silently select the wrong item.
    """
    if count == 0:
        raise ValueError("Nothing to select from.")
    choice = int(input(prompt))
    if not 1 <= choice <= count:
        raise ValueError(f"Selection must be between 1 and {count}, got {choice}")
    return choice - 1


db_names, db_paths = list_databases(database_dir_path)
# print("Available Databases:")
# for idx, db_name in enumerate(db_names):
#     print(f"{idx + 1}: {db_name}")

# Select a database to open
db_index = _prompt_index("\nEnter the number of the database to open: ", len(db_paths))
db_path = db_paths[db_index]

# Connect to the database
conn = connect_to_db(db_path)
try:
    print(f"\nConnected to: {db_path}")

    # List tables in the database
    tables = list_tables(conn)
    print("# Tables: ", len(tables))
    print("\nTables in the database:")
    for idx, table in enumerate(tables):
        print(f"{idx + 1}: {table}")

    # Select a table to preview
    table_index = _prompt_index("\nEnter the number of the table to preview: ", len(tables))
    table_name = tables[table_index]

    # Preview the selected table
    print(f"\nPreviewing table: {table_name}")
    columns, rows = preview_table(conn, table_name)
    print("\nColumns:", columns)
    print("\nRows:")
    for row in rows:
        print(row)
finally:
    # Close the connection even when a prompt or query above fails.
    conn.close()
    print("\nConnection closed.")


Connected to: /Users/hannahzhang/Desktop/spider_data/database/browser_web/browser_web.sqlite
# Tables:  3

Tables in the database:
1: Web_client_accelerator
2: browser
3: accelerator_compatible_browser

Previewing table: accelerator_compatible_browser

Columns: ['accelerator_id', 'browser_id', 'compatible_since_year']

Rows:
(1, 1, 1995)
(1, 2, 1996)
(2, 3, 1996)
(2, 4, 2000)
(3, 1, 2005)

Connection closed.


Schema Formatting

In [7]:
def format_schema_for_gpt(schema):
    """
    Render one schema entry (a dict loaded from the schema JSON file) as text.

    The output starts with ``Database: <db_id>`` and then, for each table,
    lists its columns as ``- name (type)`` with a ``[Primary Key]`` marker
    where applicable, followed by a ``Foreign Keys:`` section containing only
    the relationships that belong to that table.

    Args:
        schema: Mapping with the keys ``db_id``, ``table_names_original``,
            ``column_names_original`` (``[table_index, column_name]`` pairs,
            table index -1 for the ``*`` entry), ``column_types``,
            ``primary_keys`` and ``foreign_keys`` (pairs of column indices).
            Missing keys are treated as empty.

    Returns:
        The formatted schema as a single newline-joined string, or
        ``"No schema available"`` when ``schema`` is empty or ``None``.
    """
    if not schema:
        return "No schema available"

    formatted_schema = []
    db_id = schema.get("db_id", "Unknown")
    formatted_schema.append(f"Database: {db_id}")

    table_names = schema.get("table_names_original", [])  # original table names stored in the database
    column_names = schema.get("column_names_original", [])  # original column names stored in the database
    column_types = schema.get("column_types", [])  # data type of each column
    primary_keys = schema.get("primary_keys", [])  # indices into column_names
    foreign_keys = schema.get("foreign_keys", [])  # pairs of indices into column_names

    # Collect primary-key column indices into a set for O(1) lookup. Nested
    # lists are flattened so composite keys written as [[1, 2], 5] also work.
    pk_indices = set()
    for pk in primary_keys:
        if isinstance(pk, (list, tuple)):
            pk_indices.update(pk)
        else:
            pk_indices.add(pk)

    # Map each table index to its columns, keeping the global column index so
    # primary keys can be matched without searching column_names again.
    table_columns = {i: [] for i in range(len(table_names))}
    for col_idx, (table_idx, column_name) in enumerate(column_names):
        # Skips the '*' entry (table index -1) and any out-of-range table index.
        if table_idx in table_columns:
            column_type = column_types[col_idx] if col_idx < len(column_types) else "unknown"
            table_columns[table_idx].append((col_idx, column_name, column_type))

    # Group foreign keys by the table owning the first column of each pair.
    # NOTE(review): assumes the first index is the referencing column - confirm
    # against the dataset documentation.
    table_fks = {i: [] for i in range(len(table_names))}
    for src_idx, dst_idx in foreign_keys:
        if not (0 <= src_idx < len(column_names) and 0 <= dst_idx < len(column_names)):
            continue
        src_table = column_names[src_idx][0]
        dst_table = column_names[dst_idx][0]
        # Both ends must belong to a real table; otherwise a target on the '*'
        # entry would be resolved to table_names[-1], i.e. the wrong table.
        if src_table in table_fks and dst_table in table_fks:
            table_fks[src_table].append((src_idx, dst_idx))

    for table_idx, table_name in enumerate(table_names):
        formatted_schema.append(f"\nTable: {table_name}")
        formatted_schema.append("Columns:")
        for col_idx, column_name, column_type in table_columns[table_idx]:
            pk_marker = " [Primary Key]" if col_idx in pk_indices else ""
            formatted_schema.append(f"- {column_name} ({column_type}){pk_marker}")

        # Only this table's foreign keys, not every relationship in the database.
        if table_fks[table_idx]:
            formatted_schema.append("Foreign Keys:")
            for src_idx, dst_idx in table_fks[table_idx]:
                src_table, src_col = column_names[src_idx]
                dst_table, dst_col = column_names[dst_idx]
                formatted_schema.append(
                    f"- {table_names[src_table]}.{src_col} = {table_names[dst_table]}.{dst_col}"
                )

    return "\n".join(formatted_schema)


# Load the schema
schema_hashmap = {}

try:
    with open(SCHEMA_FILE, 'r', encoding='utf-8') as f:
        schemas = json.load(f)
except (OSError, ValueError) as e:
    # Missing/unreadable file (OSError) or undecodable/invalid JSON (ValueError,
    # which includes json.JSONDecodeError). Continue with no schemas, but report
    # the cause so an empty result is not mistaken for an empty dataset.
    print(f"Could not load schemas from {SCHEMA_FILE}: {e}")
    schemas = []

# Collect the db_id of every schema that has one
db_ids = [schema.get('db_id') for schema in schemas if schema.get('db_id')]

print(f"Database IDs:")
print(db_ids)
print("=" * 200)

# Store each formatted schema in the hashmap with db_id as the key
# (entries without a db_id all share the "Unknown" key).
for schema in schemas:
    formatted_schema = format_schema_for_gpt(schema)
    db_id = schema.get("db_id", "Unknown")
    schema_hashmap[db_id] = formatted_schema

# Access the formatted schema of 'perpetrator' database by its db_id
print(f"Formatted schema for 'perpetrator' database:")
print(schema_hashmap.get("perpetrator", "Schema not found"))

Database IDs:
['perpetrator', 'college_2', 'flight_company', 'icfp_1', 'body_builder', 'storm_record', 'pilot_record', 'race_track', 'academic', 'department_store', 'music_4', 'insurance_fnol', 'cinema', 'decoration_competition', 'phone_market', 'store_product', 'assets_maintenance', 'student_assessment', 'dog_kennels', 'music_1', 'company_employee', 'farm', 'solvency_ii', 'city_record', 'swimming', 'flight_2', 'election', 'manufactory_1', 'debate', 'network_2', 'local_govt_in_alabama', 'climbing', 'e_learning', 'scientist_1', 'ship_1', 'entertainment_awards', 'allergy_1', 'imdb', 'products_for_hire', 'candidate_poll', 'chinook_1', 'flight_4', 'pets_1', 'dorm_1', 'journal_committee', 'flight_1', 'medicine_enzyme_interaction', 'local_govt_and_lot', 'station_weather', 'shop_membership', 'driving_school', 'concert_singer', 'music_2', 'sports_competition', 'railway', 'inn_1', 'museum_visit', 'browser_web', 'baseball_1', 'architecture', 'csu_1', 'tracking_orders', 'insurance_policies', 'gas