In [12]:
# Which split to prepare; every path below is derived from this value.
dataset = "dev"

# "dev" reads databases from database/, any other value from test_database/.
if dataset == "dev":
    database_folder = "spider_data/database"
else:
    database_folder = "spider_data/test_database"

# Input: the question/query records for the chosen split.
query_file_path = f"spider_data/{dataset}.json"

# Outputs: the per-table schema dump and the final prepared question table.
schema_output_path = f"prepare_data/{dataset}_schemas.csv"
final_result_path = f"prepare_data/{dataset}_input.csv"

### Load full database schemas and tables

In [2]:
import os
import sys

# Make the "M-Schema" directory under the current working directory importable
# (presumably it provides the `schema_engine` module imported below — confirm).
# The membership check keeps re-runs of this cell from appending duplicates.
_mschema_dir = os.path.join(os.getcwd(), "M-Schema")
if _mschema_dir not in sys.path:
    sys.path.append(_mschema_dir)

In [3]:
import os
import sqlite3
import pandas as pd
from sqlalchemy import create_engine
from schema_engine import SchemaEngine


def _quote_identifier(name: str) -> str:
    """Return name as a double-quoted SQLite identifier (embedded quotes doubled)."""
    return '"' + name.replace('"', '""') + '"'


def _database_mschema(path: str, db_id: str) -> str:
    """Build the M-Schema text for the SQLite file at path, then release the engine."""
    engine = create_engine(f"sqlite:///{path}")
    try:
        return SchemaEngine(engine=engine, db_name=db_id).mschema.to_mschema()
    finally:
        # Return pooled connections so file handles do not pile up across databases.
        engine.dispose()


def gather_schemas(root_dir: str) -> pd.DataFrame:
    """
    Walk root_dir for .sqlite files and describe every user table in each.

    Returns a DataFrame with one row per (database, table) and columns:
      - db_id       : filename without .sqlite
      - table       : table name
      - schema      : CREATE TABLE ... statement (None when sqlite_master
                      stores no SQL for the table)
      - example_rows: "Table: <name>" followed by up to 3 rows of the table,
                      one tuple repr per line
      - mschema     : M-Schema text from SchemaEngine; it is built once per
                      database and repeated on each of that database's rows

    An empty DataFrame is returned when no .sqlite file is found.
    """
    records = []

    for dirpath, _, filenames in os.walk(root_dir):
        for fn in filenames:
            if not fn.lower().endswith(".sqlite"):
                continue

            path = os.path.join(dirpath, fn)
            db_id = os.path.splitext(fn)[0]

            conn = sqlite3.connect(path)
            try:
                cursor = conn.cursor()
                # Fetch name and CREATE statement together instead of one
                # extra query per table.
                cursor.execute(
                    "SELECT name, sql FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
                )
                tables = cursor.fetchall()
                if not tables:
                    # Nothing to record; the finally clause still closes conn.
                    continue

                # The M-Schema call takes no table argument, so compute it once
                # per database rather than once per table.
                mschema_str = _database_mschema(path, db_id)

                for table, create_sql in tables:
                    # Quote the identifier: names that are reserved words or
                    # contain spaces would otherwise make this query fail.
                    cursor.execute(
                        f"SELECT * FROM {_quote_identifier(table)} LIMIT 3;"
                    )
                    ex_rows_str = f"Table: {table}\n" + "\n".join(
                        str(ex_row) for ex_row in cursor.fetchall()
                    )
                    records.append(
                        {
                            "db_id": db_id,
                            "table": table,
                            "schema": create_sql if create_sql else None,
                            "example_rows": ex_rows_str,
                            "mschema": mschema_str,
                        }
                    )
            finally:
                conn.close()

    return pd.DataFrame(records)

In [4]:
# Extract one record per table from every .sqlite file under database_folder
# and cache the result as CSV so later cells can reload it without re-scanning.
schema_df = gather_schemas(database_folder)

# Create the output directory first so the write cannot fail after the scan.
os.makedirs(os.path.dirname(schema_output_path) or ".", exist_ok=True)

# index=False keeps the row index out of the file; otherwise it comes back as
# a stray "Unnamed: 0" column when the CSV is read again.
schema_df.to_csv(schema_output_path, index=False)
schema_df.head()

NameError: name 'gather_schemas' is not defined

In [13]:
import pandas as pd

# Reload the per-table schema records from the CSV at schema_output_path (the
# same path the extraction cell writes to), replacing any in-memory schema_df.
# NOTE(review): read_csv's default NA parsing turns cells such as "NA" or
# "null" into NaN; a table with such a name would break the `.lower()` call in
# the next cell — confirm no database has one.
schema_df = pd.read_csv(schema_output_path)

In [14]:
# Nested lookups keyed as db_id -> lower-cased table name -> text for that
# table. Table names are lower-cased here and again at lookup time, so table
# matching ignores case; db_id is used as-is.
schemas_map, example_rows_map, mschemas_map = {}, {}, {}

for _, row in schema_df.iterrows():
    table_key = row.table.lower()
    for target_map, column in (
        (schemas_map, "schema"),
        (example_rows_map, "example_rows"),
        (mschemas_map, "mschema"),
    ):
        target_map.setdefault(row.db_id, {})[table_key] = row[column]

### Parse queries

In [15]:
import re
from typing import List, Set

import sqlglot
from sqlglot import exp


def extract_tables(sql: str) -> List[str]:
    """
    Parse the given SQL and return a sorted list of all
    real table names referenced (joins, subqueries, CTE bodies, etc.),
    but *exclude* any CTE aliases.

    Names are returned as written in the query. CTE aliases are matched
    case-insensitively, and table nodes without a name are skipped.

    Raises:
        ValueError: if sqlglot cannot tokenize or parse the statement.
    """
    # 1) Parse into an AST. SqlglotError is the common base of ParseError and
    #    TokenError, so both kinds of failure surface as ValueError.
    try:
        tree = sqlglot.parse_one(sql)
    except sqlglot.errors.SqlglotError as e:
        raise ValueError(f"Failed to parse SQL: {e}") from e

    # 2) Collect CTE aliases from the AST itself. A text regex over the WITH
    #    block also picks up ordinary "<column> AS <alias>" pairs and would
    #    wrongly hide a real table that shares a name with such a column.
    cte_names: Set[str] = {
        cte.alias_or_name.lower() for cte in tree.find_all(exp.CTE)
    }

    # 3) Walk every Table node and keep its identifier unless it names a CTE.
    tables: Set[str] = set()
    for tbl in tree.find_all(exp.Table):
        name = tbl.name  # e.g. 'sales', 'customers', 'archived_sales'
        if name and name.lower() not in cte_names:
            tables.add(name)

    return sorted(tables)


def lookup_schemas(row):
    """
    Join the stored schema text of every table listed in row.tables.

    Looks in schemas_map under row.db_id, matching table names in lower case.
    A table with no entry contributes a "<no schema for ...>" placeholder
    instead of being dropped. Entries are separated by newlines.
    """
    per_table = schemas_map.get(row.db_id, {})
    pieces = []
    for name in row.tables:
        key = name.lower()
        if key in per_table:
            pieces.append(per_table[key])
        else:
            pieces.append(f"<no schema for {name}>")
    return "\n".join(pieces)


def lookup_example_rows(row):
    """
    Join the stored example-row text of every table listed in row.tables.

    Looks in example_rows_map under row.db_id, matching table names in lower
    case. A table with no entry contributes a "<no example rows for ...>"
    placeholder. Entries are separated by newlines.
    """
    per_table = example_rows_map.get(row.db_id, {})
    pieces = []
    for name in row.tables:
        key = name.lower()
        if key in per_table:
            pieces.append(per_table[key])
        else:
            pieces.append(f"<no example rows for {name}>")
    return "\n".join(pieces)


def lookup_mschemas(row):
    """
    Join the stored M-Schema text of every table listed in row.tables.

    Looks in mschemas_map under row.db_id, matching table names in lower case.
    A table with no entry contributes a "<no mschema for ...>" placeholder.
    Entries are separated by newlines.

    NOTE(review): if every table of a database maps to the same database-wide
    M-Schema text, a multi-table query repeats that text once per table —
    confirm this is intended.
    """
    per_table = mschemas_map.get(row.db_id, {})
    pieces = []
    for name in row.tables:
        key = name.lower()
        if key in per_table:
            pieces.append(per_table[key])
        else:
            pieces.append(f"<no mschema for {name}>")
    return "\n".join(pieces)

In [16]:
# Condition operators. The code below compares a condition's operator field
# against positions in this tuple (e.g. WHERE_OPS.index("like")), so the
# order of the entries is significant.
WHERE_OPS = (
    "not", "between", "=", ">", "<", ">=",
    "<=", "!=", "in", "like", "is", "exists",
)

# Aggregation functions, also referenced by position; "none" is index 0.
AGG_OPS = ("none", "max", "min", "count", "sum", "avg")


def has_agg(unit):
    """Return True when the unit's first element differs from the index of "none" in AGG_OPS."""
    no_agg_id = AGG_OPS.index("none")
    leading = unit[0]
    return leading != no_agg_id


def count_agg(units):
    """Count how many of the given units has_agg() reports as aggregated."""
    total = 0
    for unit in units:
        if has_agg(unit):
            total += 1
    return total


def count_component1(sql):
    """
    Count the first group of hardness components in a parsed query dict.

    One point each for a non-empty where, groupBy and orderBy and for a limit
    that is not None; one point per table unit beyond the first in the from
    clause; one point per "or" connector and one per condition whose operator
    is "like", across the from-conditions, where and having lists.
    """
    from_clause = sql["from"]
    where, having = sql["where"], sql["having"]
    conds = from_clause["conds"]
    table_units = from_clause["table_units"]

    # Clause presence: each True contributes 1 to the sum.
    count = sum(
        (
            len(where) > 0,
            len(sql["groupBy"]) > 0,
            len(sql["orderBy"]) > 0,
            sql["limit"] is not None,
        )
    )

    # Every table unit after the first stands for one JOIN.
    if len(table_units) > 0:
        count += len(table_units) - 1

    # Condition lists alternate unit, connector, unit, ...: odd slots hold
    # the connectors, even slots hold the condition units.
    connectors = conds[1::2] + where[1::2] + having[1::2]
    count += connectors.count("or")

    like_id = WHERE_OPS.index("like")
    for cond_unit in conds[::2] + where[::2] + having[::2]:
        if cond_unit[1] == like_id:
            count += 1

    return count


def get_nestedSQL(sql):
    """
    Collect the nested query dicts attached to a parsed query dict.

    Scans the condition units (even positions) of the from-conditions, where
    and having lists and keeps elements 3 and 4 of each unit when they are
    dicts, then appends the intersect, except and union entries that are not
    None, in that order.
    """
    nested = []
    cond_units = sql["from"]["conds"][::2] + sql["where"][::2] + sql["having"][::2]
    for cond_unit in cond_units:
        for operand in (cond_unit[3], cond_unit[4]):
            if type(operand) is dict:
                nested.append(operand)

    for set_op in ("intersect", "except", "union"):
        sub_query = sql[set_op]
        if sub_query is not None:
            nested.append(sub_query)
    return nested


def count_component2(sql):
    """Return how many nested queries get_nestedSQL() finds in the parsed query dict."""
    return len(get_nestedSQL(sql))


def count_others(sql):
    """
    Count the remaining hardness components of a parsed query dict.

    One point each when: more than one aggregated unit is found across select,
    where, groupBy, orderBy and having; more than one select item; more than
    one entry in the where list; more than one groupBy entry.

    NOTE(review): `having` is handed to count_agg unsliced while `where` is
    sliced to its even positions; kept exactly as found so existing hardness
    labels do not change — confirm intent before altering.
    """
    select_items = sql["select"][1]
    where = sql["where"]
    group_by = sql["groupBy"]
    order_by = sql["orderBy"]

    # Tally aggregated units clause by clause.
    agg_total = count_agg(select_items)
    agg_total += count_agg(where[::2])
    agg_total += count_agg(group_by)
    if len(order_by) > 0:
        order_units = order_by[1]
        # Elements 1 and 2 of each order unit, skipping falsy ones.
        first_operands = [unit[1] for unit in order_units if unit[1]]
        second_operands = [unit[2] for unit in order_units if unit[2]]
        agg_total += count_agg(first_operands + second_operands)
    agg_total += count_agg(sql["having"])

    # Each satisfied check is worth exactly one point.
    checks = (
        agg_total > 1,
        len(select_items) > 1,
        len(where) > 1,
        len(group_by) > 1,
    )
    return sum(checks)


def eval_hardness(sql: dict):
    """
    Label a parsed query dict as "easy", "medium", "hard" or "extra".

    The label is decided from three counts: count_component1 (clauses, joins,
    "or" and "like"), count_component2 (nested queries) and count_others
    (multiplicity checks). The tiers are tested in order and the first one
    that matches wins; anything that matches none of them is "extra".
    """
    comp1 = count_component1(sql)
    comp2 = count_component2(sql)
    others = count_others(sql)

    flat = comp2 == 0  # no nested queries at all

    if flat and comp1 <= 1 and others == 0:
        return "easy"

    if flat and ((others <= 2 and comp1 <= 1) or (comp1 <= 2 and others < 2)):
        return "medium"

    if (
        (flat and others > 2 and comp1 <= 2)
        or (flat and 2 < comp1 <= 3 and others <= 2)
        or (comp1 <= 1 and others == 0 and comp2 <= 1)
    ):
        return "hard"

    return "extra"

In [17]:
import pandas as pd

# Columns kept in the prepared question table, in output order.
output_columns = [
    "question_number",
    "question",
    "hardness",
    "db_id",
    "tables",
    "schemas",
    "example_rows",
    "mschemas",
]

df = (
    pd.read_json(query_file_path)
    # Per-question features computed from the loaded records.
    .assign(
        question_number=lambda loaded: loaded.index,
        tables=lambda loaded: loaded["query"].apply(extract_tables),
        hardness=lambda loaded: loaded["sql"].apply(eval_hardness),
    )
    # The lookups read the `tables` column, so they run in a second step.
    .assign(
        schemas=lambda enriched: enriched.apply(lookup_schemas, axis=1),
        example_rows=lambda enriched: enriched.apply(lookup_example_rows, axis=1),
        mschemas=lambda enriched: enriched.apply(lookup_mschemas, axis=1),
    )[output_columns]
)
df

Unnamed: 0,question_number,question,hardness,db_id,tables,schemas,example_rows,mschemas
0,0,How many singers do we have?,easy,concert_singer,[singer],"CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""Na...","Table: singer\n(1, 'Joe Sharp', 'Netherlands',...",【DB_ID】 concert_singer\n【Schema】\n# Table: mai...
1,1,What is the total number of singers?,easy,concert_singer,[singer],"CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""Na...","Table: singer\n(1, 'Joe Sharp', 'Netherlands',...",【DB_ID】 concert_singer\n【Schema】\n# Table: mai...
2,2,"Show name, country, age for all singers ordere...",medium,concert_singer,[singer],"CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""Na...","Table: singer\n(1, 'Joe Sharp', 'Netherlands',...",【DB_ID】 concert_singer\n【Schema】\n# Table: mai...
3,3,"What are the names, countries, and ages for ev...",medium,concert_singer,[singer],"CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""Na...","Table: singer\n(1, 'Joe Sharp', 'Netherlands',...",【DB_ID】 concert_singer\n【Schema】\n# Table: mai...
4,4,"What is the average, minimum, and maximum age ...",medium,concert_singer,[singer],"CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""Na...","Table: singer\n(1, 'Joe Sharp', 'Netherlands',...",【DB_ID】 concert_singer\n【Schema】\n# Table: mai...
...,...,...,...,...,...,...,...,...
1029,1029,What are the citizenships that are shared by s...,hard,singer,[singer],"CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""Na...","Table: singer\n(1, 'Liliane Bettencourt', 1944...",【DB_ID】 singer\n【Schema】\n# Table: main.singer...
1030,1030,How many available features are there in total?,easy,real_estate_properties,[Other_Available_Features],CREATE TABLE `Other_Available_Features` (\n`fe...,"Table: Other_Available_Features\n(2, 'Amenity'...",【DB_ID】 real_estate_properties\n【Schema】\n# Ta...
1031,1031,What is the feature type name of feature AirCon?,medium,real_estate_properties,"[Other_Available_Features, Ref_Feature_Types]",CREATE TABLE `Other_Available_Features` (\n`fe...,"Table: Other_Available_Features\n(2, 'Amenity'...",【DB_ID】 real_estate_properties\n【Schema】\n# Ta...
1032,1032,Show the property type descriptions of propert...,medium,real_estate_properties,"[Properties, Ref_Property_Types]",CREATE TABLE `Properties` (\n`property_id` INT...,"Table: Properties\n(1, 'House', '1991-06-21 23...",【DB_ID】 real_estate_properties\n【Schema】\n# Ta...


In [18]:
# Write the prepared question table to final_result_path; index=False leaves
# the row index out of the file (question_number already carries it).
# NOTE(review): the `tables` column holds Python lists, so it is presumably
# written in string form (e.g. "['singer']") — readers must parse it back.
df.to_csv(final_result_path, index=False)