In [26]:
database_folder = "spider_data/test_database"
query_file_path = "spider_data/test.json"

schema_output_path = "prepare_data/test_schemas.csv"
final_result_path = "prepare_data/test_input.csv"

### Load full database schemas and tables

In [27]:
import os
import sqlite3
import pandas as pd


def gather_schemas(root_dir: str) -> pd.DataFrame:
    """
    Walk root_dir for .sqlite files. For each, connect and extract
    each table's CREATE statement. Returns a DataFrame with columns:
      - db_id : filename without .sqlite
      - table : table name
      - schema: CREATE TABLE ... statement
    """
    records = []

    for dirpath, _, filenames in os.walk(root_dir):
        for fn in filenames:
            if fn.lower().endswith(".sqlite"):
                path = os.path.join(dirpath, fn)
                db_id = os.path.splitext(fn)[0]
                # connect to sqlite
                conn = sqlite3.connect(path)
                try:
                    cursor = conn.cursor()
                    # get list of tables
                    cursor.execute(
                        "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
                    )
                    tables = [row[0] for row in cursor.fetchall()]

                    for table in tables:
                        # get CREATE statement
                        cursor.execute(
                            "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
                            (table,),
                        )
                        row = cursor.fetchone()
                        schema = row[0] if row and row[0] else ""
                        records.append(
                            {"db_id": db_id, "table": table, "schema": schema}
                        )
                finally:
                    conn.close()

    return pd.DataFrame(records, columns=["db_id", "table", "schema"])

In [28]:
schema_df = gather_schemas(database_folder)
schema_df

Unnamed: 0,db_id,table,schema
0,browser_web,Web_client_accelerator,"CREATE TABLE ""Web_client_accelerator"" (\n""id"" ..."
1,browser_web,browser,"CREATE TABLE ""browser"" (\n""id"" int,\n""name"" te..."
2,browser_web,accelerator_compatible_browser,"CREATE TABLE ""accelerator_compatible_browser"" ..."
3,musical,musical,"CREATE TABLE ""musical"" (\n""Musical_ID"" int,\n""..."
4,musical,actor,"CREATE TABLE ""actor"" (\n""Actor_ID"" int,\n""Name..."
...,...,...,...
1048,body_builder,people,"CREATE TABLE ""people"" (\n""People_ID"" int,\n""Na..."
1049,school_player,school,"CREATE TABLE ""school"" (\n""School_ID"" int,\n""Sc..."
1050,school_player,school_details,"CREATE TABLE ""school_details"" (\n""School_ID"" i..."
1051,school_player,school_performance,"CREATE TABLE ""school_performance"" (\n""School_I..."


In [29]:
schema_df.to_csv(schema_output_path)

In [30]:
schemas_map = {}
for _, row in schema_df.iterrows():
    schemas_map.setdefault(row.db_id, {})[row.table.lower()] = row.schema
schemas_map

{'browser_web': {'web_client_accelerator': 'CREATE TABLE "Web_client_accelerator" (\n"id" int,\n"name" text,\n"Operating_system" text,\n"Client" text,\n"Connection" text,\nprimary key("id")\n)',
  'browser': 'CREATE TABLE "browser" (\n"id" int,\n"name" text,\n"market_share" real,\nprimary key("id")\n)',
  'accelerator_compatible_browser': 'CREATE TABLE "accelerator_compatible_browser" (\n"accelerator_id" int,\n"browser_id" int,\n"compatible_since_year" int,\nprimary key("accelerator_id", "browser_id"),\nforeign key ("accelerator_id") references `Web_client_accelerator`("id"),\nforeign key ("browser_id") references `browser`("id")\n)'},
 'musical': {'musical': 'CREATE TABLE "musical" (\n"Musical_ID" int,\n"Name" text,\n"Year" int,\n"Award" text,\n"Category" text,\n"Nominee" text,\n"Result" text,\nPRIMARY KEY ("Musical_ID")\n)',
  'actor': 'CREATE TABLE "actor" (\n"Actor_ID" int,\n"Name" text,\n"Musical_ID" int,\n"Character" text,\n"Duration" text,\n"age" int,\nPRIMARY KEY ("Actor_ID"),\

### Parse queries

In [31]:
import re
from typing import List, Set

import sqlglot
from sqlglot import exp


def extract_tables(sql: str) -> List[str]:
    """
    Parse the given SQL and return a sorted list of all
    real table names referenced (joins, subqueries, CTE bodies, etc.),
    but *exclude* any CTE aliases.
    """
    # 1) Pull out the CTE block (if any) and collect its aliases
    cte_names: Set[str] = set()
    match = re.search(
        r"WITH\s+(.*?)\)\s*SELECT",
        sql,
        flags=re.IGNORECASE | re.DOTALL,
    )
    if match:
        cte_block = match.group(1)
        # find all "<name> AS (" inside that block
        found = re.findall(
            r"([A-Za-z_][A-Za-z0-9_]*)\s+AS\b", cte_block, flags=re.IGNORECASE
        )
        cte_names = set(found)

    # 2) Parse into an AST
    try:
        tree = sqlglot.parse_one(sql)
    except sqlglot.errors.ParseError as e:
        raise ValueError(f"Failed to parse SQL: {e}")

    # 3) Walk every Table node and collect its .this (the table identifier),
    #    unless it’s one of the CTE names we just saw.
    tables: Set[str] = set()
    for tbl in tree.find_all(exp.Table):
        name = tbl.name  # e.g. 'sales', 'customers', 'archived_sales'
        if name not in cte_names:
            tables.add(name)

    return sorted(tables)


def lookup_schemas(row):
    db = row.db_id
    tables = row.tables  # list of table names
    db_map = schemas_map.get(db, {})
    # collect schema for each table, skip if not found
    return [db_map.get(tbl.lower(), f"<no schema for {tbl}>") for tbl in tables]

In [32]:
import pandas as pd

df = pd.read_json(query_file_path)
df = (
    df.assign(
        question_number=df.index,
        tables=lambda df_: df_["query"].apply(extract_tables),
    )
    .assign(
        schemas=lambda df_: df_.apply(lookup_schemas, axis=1),
    )
    .loc[:, ["question_number", "question", "db_id", "tables", "schemas"]]
)
df

Unnamed: 0,question_number,question,db_id,tables,schemas
0,0,How many clubs are there?,soccer_3,[club],"[CREATE TABLE ""club"" (\n""Club_ID"" int,\n""Name""..."
1,1,Count the number of clubs.,soccer_3,[club],"[CREATE TABLE ""club"" (\n""Club_ID"" int,\n""Name""..."
2,2,List the name of clubs in ascending alphabetic...,soccer_3,[club],"[CREATE TABLE ""club"" (\n""Club_ID"" int,\n""Name""..."
3,3,"What are the names of clubs, ordered alphabeti...",soccer_3,[club],"[CREATE TABLE ""club"" (\n""Club_ID"" int,\n""Name""..."
4,4,What are the managers and captains of clubs?,soccer_3,[club],"[CREATE TABLE ""club"" (\n""Club_ID"" int,\n""Name""..."
...,...,...,...,...,...
2142,2142,Return the ids and details of staff who have a...,advertising_agencies,[staff],[CREATE TABLE `Staff` (\n`staff_id` INTEGER PR...
2143,2143,"What are the id, sic code and agency id of the...",advertising_agencies,"[clients, invoices, meetings]",[CREATE TABLE `Clients` (\n`client_id` INTEGER...
2144,2144,"Return the ids, sic codes, and agency ids of c...",advertising_agencies,"[clients, invoices, meetings]",[CREATE TABLE `Clients` (\n`client_id` INTEGER...
2145,2145,"List the start time, end time of each meeting,...",advertising_agencies,"[clients, meetings, staff, staff_in_meetings]",[CREATE TABLE `Clients` (\n`client_id` INTEGER...


In [33]:
df.to_csv(final_result_path, index=False)