### Dataset Exploration

In [10]:
import pandas as pd
from IPython.display import display, Markdown


def display_report(dfs: dict[str, pd.DataFrame], report_title: str = None):
    """
    Render several DataFrames in a notebook, one titled section per frame.

    Each section is shown as a level-3 Markdown heading, followed by the
    styled table and a horizontal rule.

    :param dfs: Mapping from section heading to the DataFrame shown under it.
    :param report_title: Optional level-2 heading shown once before all
        sections; nothing is shown when it is empty or None.
    """
    if report_title:
        display(Markdown(f"## {report_title}\n"))

    for heading, frame in dfs.items():
        display(Markdown(f"### {heading}\n"))

        # Table CSS: centered header cells and padded data cells.
        header_rule = {"selector": "th", "props": [("text-align", "center")]}
        cell_rule = {"selector": "td", "props": [("padding", "6px")]}
        styled_table = frame.style.set_table_styles([header_rule, cell_rule])

        display(styled_table)
        display(Markdown("---"))

In [11]:
import pandas as pd


def count_dbs_and_tables(table_file: str, output_file: str):
    """
    Flatten a tables JSON file into (db_id, table) rows and count them.

    The flattened rows are also written to ``output_file`` as CSV (without
    the index column).

    :param table_file: Path to a JSON file whose records have ``db_id`` and
        ``table_names`` fields.
    :param output_file: Path of the CSV file to write, with columns
        ``db_id`` and ``table``.
    :return: Tuple ``(num_dbs, num_tables, average_tables_per_db)``.
    """
    schema_records = pd.read_json(table_file)

    # One output row per entry of each record's `table_names` list.
    db_table_pairs = schema_records[["db_id", "table_names"]].explode("table_names")
    db_table_pairs = db_table_pairs.rename(columns={"table_names": "table"})
    db_table_pairs = db_table_pairs.reset_index(drop=True)
    db_table_pairs.to_csv(output_file, index=False)

    db_count = db_table_pairs["db_id"].nunique(dropna=False)
    table_count = len(db_table_pairs.index)
    return db_count, table_count, table_count / db_count


# Collect per-split statistics (databases, tables, question counts by hardness)
# and, for the test split, a few sample questions per hardness level.
report = {}
data = []
datasets = ["dev", "test"]
for dataset in datasets:
    # "dev" uses the default Spider paths; any other split uses the "test" ones.
    if dataset == "dev":
        db_root_dir = "spider_data/database"
        tables_json_path = "spider_data/tables.json"
    else:
        db_root_dir = "spider_data/test_database"
        tables_json_path = "spider_data/test_tables.json"
    input_file_path = f"prepare_data/{dataset}_input.csv"

    num_dbs, num_tables, average_tables_per_db = count_dbs_and_tables(
        table_file=tables_json_path,
        output_file=f"prepare_data/{dataset}_dbs_and_tables.csv",
    )

    # Gold queries, numbered by their row position in the split's JSON file.
    gold_df = pd.read_json(f"spider_data/{dataset}.json", orient="records")
    solution_df = gold_df.assign(
        question_number=gold_df.index,
        solution_query=gold_df["query"],
    )[["question_number", "solution_query"]]

    # The join matches the CSV's `question_number` column against the row
    # index of `solution_df`.
    prepared_df = pd.read_csv(input_file_path)
    question_df = prepared_df.join(
        solution_df, on="question_number", how="inner", lsuffix="_q", rsuffix="_s"
    ).reset_index(drop=True)[["question", "hardness", "solution_query"]]

    # Split the questions once per hardness level; reused for counts and samples.
    questions_by_hardness = {
        level: question_df.query(f"hardness == '{level}'")
        for level in ("easy", "medium", "hard", "extra")
    }

    total_questions = len(question_df)
    easy_questions = len(questions_by_hardness["easy"])
    medium_questions = len(questions_by_hardness["medium"])
    hard_questions = len(questions_by_hardness["hard"])
    extra_questions = len(questions_by_hardness["extra"])

    data.append(
        {
            "dataset": dataset,
            "num_dbs": num_dbs,
            "num_tables": num_tables,
            "average_tables_per_db": average_tables_per_db,
            "num_questions": total_questions,
            "num_easy_questions": easy_questions,
            "num_medium_questions": medium_questions,
            "num_hard_questions": hard_questions,
            "num_extra_questions": extra_questions,
        }
    )

    # Only the test split contributes sample questions to the report.
    if dataset == "test":
        report = {
            "Easy questions": questions_by_hardness["easy"].head(3),
            "Medium questions": questions_by_hardness["medium"].head(3),
            "Hard questions": questions_by_hardness["hard"].head(3),
            "Extra hard questions": questions_by_hardness["extra"].head(3),
        }


report["Overall"] = pd.DataFrame(data)
display_report(report, report_title="Dataset Exploration")

## Dataset Exploration


### Easy questions


Unnamed: 0,question,hardness,solution_query
0,How many clubs are there?,easy,SELECT count(*) FROM club
1,Count the number of clubs.,easy,SELECT count(*) FROM club
2,List the name of clubs in ascending alphabetical order.,easy,SELECT Name FROM club ORDER BY Name ASC


---

### Medium questions


Unnamed: 0,question,hardness,solution_query
4,What are the managers and captains of clubs?,medium,"SELECT Manager , Captain FROM club"
5,Return the managers and captains of all clubs.,medium,"SELECT Manager , Captain FROM club"
10,What is the name of the player with the highest earnings?,medium,SELECT Name FROM player ORDER BY Earnings DESC LIMIT 1


---

### Hard questions


Unnamed: 0,question,hardness,solution_query
14,What is the country of the player with the highest earnings among players that have more than 2 win counts?,hard,SELECT Country FROM player WHERE Wins_count > 2 ORDER BY Earnings DESC LIMIT 1
15,"Of players who have more than 2 wins, what is the country of the player who makes the most?",hard,SELECT Country FROM player WHERE Wins_count > 2 ORDER BY Earnings DESC LIMIT 1
22,Show names of clubs in descending order of average earnings of players belonging.,hard,SELECT T1.Name FROM club AS T1 JOIN player AS T2 ON T1.Club_ID = T2.Club_ID GROUP BY T1.Club_ID ORDER BY avg(T2.Earnings) DESC


---

### Extra hard questions


Unnamed: 0,question,hardness,solution_query
40,"List the id, first name and last name of the customers who both have placed more than 2 orders and have bought at least 3 items.",extra,"SELECT T1.customer_id , T1.customer_first_name , T1.customer_last_name FROM Customers AS T1 JOIN Orders AS T2 ON T1.customer_id = T2.customer_id GROUP BY T1.customer_id HAVING count(*) > 2 INTERSECT SELECT T1.customer_id , T1.customer_first_name , T1.customer_last_name FROM Customers AS T1 JOIN Orders AS T2 ON T1.customer_id = T2.customer_id JOIN Order_items AS T3 ON T2.order_id = T3.order_id GROUP BY T1.customer_id HAVING count(*) >= 3"
41,"What are the ids, first and last names of the customers who have ordered more than twice and have bought at least 3 items?",extra,"SELECT T1.customer_id , T1.customer_first_name , T1.customer_last_name FROM Customers AS T1 JOIN Orders AS T2 ON T1.customer_id = T2.customer_id GROUP BY T1.customer_id HAVING count(*) > 2 INTERSECT SELECT T1.customer_id , T1.customer_first_name , T1.customer_last_name FROM Customers AS T1 JOIN Orders AS T2 ON T1.customer_id = T2.customer_id JOIN Order_items AS T3 ON T2.order_id = T3.order_id GROUP BY T1.customer_id HAVING count(*) >= 3"
46,"Which customers did not make any orders? List the first name, middle initial and last name.",extra,"SELECT customer_first_name , customer_middle_initial , customer_last_name FROM Customers EXCEPT SELECT T1.customer_first_name , T1.customer_middle_initial , T1.customer_last_name FROM Customers AS T1 JOIN Orders AS T2 ON T1.customer_id = T2.customer_id"


---

### Overall


Unnamed: 0,dataset,num_dbs,num_tables,average_tables_per_db,num_questions,num_easy_questions,num_medium_questions,num_hard_questions,num_extra_questions
0,dev,166,876,5.277108,1034,248,446,174,166
1,test,206,1056,5.126214,2147,470,857,463,357


---

### Evaluation

**Prompt Overview**

In [12]:
from langchain_core.prompts import PromptTemplate

# Generic text-to-SQL prompt: instruction, question, schema and reference
# rows, ending with an instruction to output only the SQL statement.
COMMON_PROMPT_TEXT = """You are an expert in {dialect}. Your job is to read and understand the following [Database Schema] description, along with any [Reference Information], and then use your knowledge of {dialect} to generate an SQL statement that answers the [User Question]. Pay attention to the [Database Schema], only use tables and columns that are in the [Database Schema]. Avoid using any other tables or columns that are not in the [Database Schema].

[User Question]
{user_question}

[Database Schema]
{schema}

[Reference Information]
{example_rows}

ONLY OUTPUT THE SQL STATEMENT, NO OTHER TEXT.
"""

# XiYan-style prompt: the question appears both before the schema and again
# at the end, and the text ends with an opening ```sql fence.
XIYAN_EN_PROMPT_TEXT = """You are an expert in {dialect}. Your job is to read and understand the following 【Schema】 description, along with any 【Evidence】, and then use your knowledge of {dialect} to generate an SQL statement that answers the 【Question】.

【Question】
{user_question}

【Schema】
{schema}

【Evidence】
{example_rows}

【Question】
{user_question}
```sql
"""

common_prompt_template = PromptTemplate.from_template(COMMON_PROMPT_TEXT)
xiyan_en_prompt_template = PromptTemplate.from_template(XIYAN_EN_PROMPT_TEXT)

In [16]:
# Load the prepared dev-split inputs and preview the first three rows.
DEV_INPUT_CSV = "prepare_data/dev_input.csv"
input_df = pd.read_csv(DEV_INPUT_CSV)
input_df.head(3)

Unnamed: 0,question_number,question,hardness,db_id,tables,schemas,example_rows,mschemas
0,0,How many singers do we have?,easy,concert_singer,['singer'],"CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""Na...","Table: singer\n(1, 'Joe Sharp', 'Netherlands',...",【DB_ID】 concert_singer\n【Schema】\n# Table: mai...
1,1,What is the total number of singers?,easy,concert_singer,['singer'],"CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""Na...","Table: singer\n(1, 'Joe Sharp', 'Netherlands',...",【DB_ID】 concert_singer\n【Schema】\n# Table: mai...
2,2,"Show name, country, age for all singers ordere...",medium,concert_singer,['singer'],"CREATE TABLE ""singer"" (\n""Singer_ID"" int,\n""Na...","Table: singer\n(1, 'Joe Sharp', 'Netherlands',...",【DB_ID】 concert_singer\n【Schema】\n# Table: mai...


**Common prompt** (rendered for the first question in the dev input)

In [20]:
from IPython.display import Markdown


# Fill the common prompt with the first row of `input_df` and render it.
common_sample_row = input_df.iloc[0]
common_prompt_inputs = {
    "dialect": "sqlite",
    "user_question": common_sample_row["question"],
    "schema": common_sample_row["schemas"],
    "example_rows": common_sample_row["example_rows"],
}
Markdown(common_prompt_template.invoke(common_prompt_inputs).text)

You are an expert in sqlite. Your job is to read and understand the following [Database Schema] description, along with any [Reference Information], and then use your knowledge of sqlite to generate an SQL statement that answers the [User Question]. Pay attention to the [Database Schema], only use tables and columns that are in the [Database Schema]. Avoid using any other tables or columns that are not in the [Database Schema].

[User Question]
How many singers do we have?

[Database Schema]
CREATE TABLE "singer" (
"Singer_ID" int,
"Name" text,
"Country" text,
"Song_Name" text,
"Song_release_year" text,
"Age" int,
"Is_male" bool,
PRIMARY KEY ("Singer_ID")
)

[Reference Information]
Table: singer
(1, 'Joe Sharp', 'Netherlands', 'You', '1992', 52, 'F')
(2, 'Timbaland', 'United States', 'Dangerous', '2008', 32, 'T')
(3, 'Justin Brown', 'France', 'Hey Oh', '2013', 29, 'T')

ONLY OUTPUT THE SQL STATEMENT, NO OTHER TEXT.


In [21]:
# Fill the XiYan prompt with the first row of `input_df` and render it.
xiyan_sample_row = input_df.iloc[0]
xiyan_prompt_inputs = {
    "dialect": "sqlite",
    "user_question": xiyan_sample_row["question"],
    "schema": xiyan_sample_row["schemas"],
    "example_rows": xiyan_sample_row["example_rows"],
}
Markdown(xiyan_en_prompt_template.invoke(xiyan_prompt_inputs).text)

You are an expert in sqlite. Your job is to read and understand the following 【Schema】 description, along with any 【Evidence】, and then use your knowledge of sqlite to generate an SQL statement that answers the 【Question】.

【Question】
How many singers do we have?

【Schema】
CREATE TABLE "singer" (
"Singer_ID" int,
"Name" text,
"Country" text,
"Song_Name" text,
"Song_release_year" text,
"Age" int,
"Is_male" bool,
PRIMARY KEY ("Singer_ID")
)

【Evidence】
Table: singer
(1, 'Joe Sharp', 'Netherlands', 'You', '1992', 52, 'F')
(2, 'Timbaland', 'United States', 'Dangerous', '2008', 32, 'T')
(3, 'Justin Brown', 'France', 'Hey Oh', '2013', 29, 'T')

【Question】
How many singers do we have?
```sql


**Evaluation results**

In [19]:
from evaluation import evaluate_sql
import pandas as pd


def run_evaluate(model: str, dataset: str):
    """
    Evaluate one model's generated SQL on one dataset split.

    Prints a progress line, calls ``evaluate_sql`` with file paths derived
    from ``model`` and ``dataset``, and writes the result's ``details_df`` to
    ``evaluate_data/{model}_{dataset}_eval.csv`` (without the index column).

    :param model: Model identifier used in the inference and evaluation
        file names.
    :param dataset: Split name; ``"dev"`` selects ``spider_data/database``
        as the database root, any other value selects
        ``spider_data/test_database``.
    :return: Dict with ``model`` and ``dataset`` followed by the entries of
        the result's ``summary_data``.
    """
    print(f"Evaluating {model} on {dataset}")

    if dataset == "dev":
        database_root = "spider_data/database"
    else:
        database_root = "spider_data/test_database"

    result = evaluate_sql(
        input_file=f"prepare_data/{dataset}_input.csv",
        solution_file=f"spider_data/{dataset}_gold.sql",
        answer_file=f"inference_data/{model}_{dataset}_inf.txt",
        db_root=database_root,
    )
    result.details_df.to_csv(f"evaluate_data/{model}_{dataset}_eval.csv", index=False)

    # Identification fields first; summary entries follow (and win on key clash).
    summary_row = {"model": model, "dataset": dataset}
    summary_row.update(result.summary_data)
    return summary_row

In [21]:
# (model, split) pairs to evaluate, in the order they are run and reported.
eval_pairs = [
    ("llama32_3b_base", "dev"),
    ("llama32_3b_w_ex", "dev"),
    ("llama32_3b_w_re", "dev"),
    ("llama32_3b_all", "dev"),
    ("xiyansql_7b_16_base", "dev"),
    ("xiyansql_7b_16_w_ex", "dev"),
    ("xiyansql_7b_16_w_re", "dev"),
    ("xiyansql_7b_16_all", "dev"),
    ("llama32_3b_all", "test"),
    ("xiyansql_7b_16_all", "test"),
    ("xiyansql_7b_8_all", "test"),
]
eval_list = [
    {"model": model_name, "dataset": split_name}
    for model_name, split_name in eval_pairs
]

# One summary dict per evaluated configuration.
data = [run_evaluate(**config) for config in eval_list]

Evaluating llama32_3b_base on dev
Evaluating llama32_3b_w_ex on dev
Evaluating llama32_3b_w_re on dev
Evaluating llama32_3b_all on dev
Evaluating xiyansql_7b_16_base on dev
Evaluating xiyansql_7b_16_w_ex on dev
Evaluating xiyansql_7b_16_w_re on dev
Evaluating xiyansql_7b_16_all on dev
Evaluating llama32_3b_all on test
Evaluating xiyansql_7b_16_all on test
Evaluating xiyansql_7b_8_all on test


In [22]:
# Tabulate the summary dicts collected in `data` (one row per evaluated
# model/dataset pair) and show the table as the cell output.
df = pd.DataFrame(data)
df

Unnamed: 0,model,dataset,easy,medium,hard,extra,total
0,llama32_3b_base,dev,213/248 (85.89%),299/446 (67.04%),88/174 (50.57%),51/166 (30.72%),651/1034 (62.96%)
1,llama32_3b_w_ex,dev,208/248 (83.87%),307/446 (68.83%),82/174 (47.13%),58/166 (34.94%),655/1034 (63.35%)
2,llama32_3b_w_re,dev,212/248 (85.48%),309/446 (69.28%),89/174 (51.15%),58/166 (34.94%),668/1034 (64.6%)
3,llama32_3b_all,dev,214/248 (86.29%),320/446 (71.75%),87/174 (50.0%),57/166 (34.34%),678/1034 (65.57%)
4,xiyansql_7b_16_base,dev,240/248 (96.77%),397/446 (89.01%),156/174 (89.66%),127/166 (76.51%),920/1034 (88.97%)
5,xiyansql_7b_16_w_ex,dev,243/248 (97.98%),409/446 (91.7%),157/174 (90.23%),126/166 (75.9%),935/1034 (90.43%)
6,xiyansql_7b_16_w_re,dev,241/248 (97.18%),397/446 (89.01%),158/174 (90.8%),126/166 (75.9%),922/1034 (89.17%)
7,xiyansql_7b_16_all,dev,243/248 (97.98%),411/446 (92.15%),158/174 (90.8%),127/166 (76.51%),939/1034 (90.81%)
8,llama32_3b_all,test,398/470 (84.68%),583/857 (68.03%),269/463 (58.1%),150/357 (42.02%),1400/2147 (65.21%)
9,xiyansql_7b_16_all,test,441/470 (93.83%),709/857 (82.73%),364/463 (78.62%),259/357 (72.55%),1773/2147 (82.58%)


In [23]:
# NOTE(review): disabled scratch cell, not part of the evaluation above.
# When uncommented, it parses a sample SELECT with sqlglot and prints the DOT
# form of Column / Alias nodes among the second and third select expressions.
# Consider moving it to a separate scratch notebook.
# import sqlglot
# import sqlglot.expressions
# from sqlglot.expressions import Expression, Column


# tree = sqlglot.parse_one(
#     "SELECT a, a as a1, A.a,  max(a), max(a) as max_a, count(*), count(*) as count_a FROM A"
# )
# for expr in tree.expressions[1:3]:
#     if isinstance(expr, Column):
#         print(expr.to_dot())
#     if isinstance(expr, sqlglot.expressions.Alias):
#         print(expr.unalias().to_dot())