The dataset archive (`../data/spider.zip`) was downloaded from https://drive.google.com/uc?export=download&id=1TqleXec_OykOYFREKKtschzY29dUcVAQ

In [None]:
!unzip ../data/spider.zip

# Dataset

In [None]:
from pathlib import Path
import pandas as pd
import plotly.express as px

In [None]:
# Root folder of the extracted dataset, given relative to the notebook's working directory.
# Every later cell builds its file paths from this value.
dataset_path: Path = Path("../data/spider")

In [None]:
# Load the three JSON splits in a fixed order; each file maps onto the variable
# at the same position on the left-hand side.
split_files = ("train_spider.json", "train_others.json", "dev.json")
train_spider, train_others, dev = (pd.read_json(dataset_path / file_name) for file_name in split_files)

# Preview the first rows of the main training split.
train_spider.head()

In [None]:
# Report the number of examples in each split.
for label, frame in (
    ("train_spider: ", train_spider),
    ("train_others: ", train_others),
    ("dev: ", dev),
):
    print(label, len(frame))

In [None]:
# One histogram per split: number of tokens in each tokenized question.
splits = {"train_spider": train_spider, "train_others": train_others, "dev": dev}
for split_name, split_df in splits.items():
    question_lengths = [len(tokens) for tokens in split_df.question_toks]

    # Title centred horizontally and anchored near the top of the figure.
    centered_title = dict(
        text=f"Distribution of the number of tokens in {split_name} questions",
        y=0.95,
        x=0.5,
        xanchor="center",
        yanchor="top",
    )

    fig = px.histogram(question_lengths)
    fig.update_layout(showlegend=False, xaxis_title="Number of tokens", title=centered_title)
    fig.show()

In [None]:
# One histogram per split: number of tokens in each tokenized query.
splits = {"train_spider": train_spider, "train_others": train_others, "dev": dev}
for split_name, split_df in splits.items():
    query_lengths = [len(tokens) for tokens in split_df.query_toks]

    # Title centred horizontally and anchored near the top of the figure.
    centered_title = dict(
        text=f"Distribution of the number of tokens in {split_name} queries",
        y=0.95,
        x=0.5,
        xanchor="center",
        yanchor="top",
    )

    fig = px.histogram(query_lengths)
    fig.update_layout(showlegend=False, xaxis_title="Number of tokens", title=centered_title)
    fig.show()

In [None]:
# One chart per split: how many examples reference each database id.
splits = {"train_spider": train_spider, "train_others": train_others, "dev": dev}
for split_name, split_df in splits.items():
    fig = px.histogram(split_df, x="db_id")
    # Sort the x-axis categories in ascending order so the charts are comparable across splits.
    fig.update_xaxes(categoryorder="category ascending")

    # Title centred horizontally and anchored near the top of the figure.
    centered_title = dict(
        text=f"Databases used in {split_name}",
        y=0.95,
        x=0.5,
        xanchor="center",
        yanchor="top",
    )
    fig.update_layout(showlegend=False, title=centered_title)
    fig.show()

# Databases 

In [None]:
import sqlite3

In [None]:
# Collect every file matching database/<folder>/<file>.sqlite below the dataset root.
# glob() yields entries in whatever order the filesystem returns them, so sort the paths:
# this keeps downstream results (e.g. which database wins a tie for "most tables")
# identical across machines and re-runs.
databases_path = sorted((dataset_path / "database").glob("*/*.sqlite"))
print("Databases:", len(databases_path))

In [None]:
# Count the tables of every database (keyed by the .sqlite file name), then plot the distribution.
n_tables = {}
for db_path in databases_path:
    conn = sqlite3.connect(str(db_path))
    try:
        # Names starting with "sqlite_" are reserved for SQLite's own bookkeeping tables
        # (e.g. sqlite_sequence); exclude them so only schema tables are counted.
        (table_count,) = conn.execute(
            "SELECT count(*) FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
        ).fetchone()
    finally:
        # Close explicitly: `with sqlite3.connect(...)` only manages the transaction and
        # would leave one open handle per database file.
        conn.close()
    n_tables[db_path.name] = table_count

# Hand plotly a plain list rather than a dict view.
fig = px.histogram(list(n_tables.values()))
# Last expression of the cell: update_layout returns the figure, which the notebook then renders.
fig.update_layout(
    showlegend=False,
    xaxis_title="Number of tables",
    title={
        "text": "Distribution of the number of tables in the databases",
        "y": 0.95,
        "x": 0.5,
        "xanchor": "center",
        "yanchor": "top",
    },
)

In [None]:
# Mean number of tables per database.
table_counts = n_tables.values()
print(f"Average number of tables: {sum(table_counts) / len(table_counts)}")

# Database with the most tables (the first one encountered wins a tie).
largest_db = max(n_tables, key=n_tables.get)
print(f"Max number of tables: {largest_db} with {n_tables[largest_db]} tables")

In [None]:
# Collect the column count of every table across all databases, then plot the distribution.
n_columns = []
for db_path in databases_path:
    conn = sqlite3.connect(str(db_path))
    try:
        # Names starting with "sqlite_" are reserved for SQLite's own bookkeeping tables
        # (e.g. sqlite_sequence); skip them so only schema tables are measured.
        table_names = [
            table_name
            for (table_name,) in conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';"
            )
        ]
        for table_name in table_names:
            # Bind the table name as a parameter instead of formatting it into the SQL text,
            # so names containing quotes or other special characters cannot break the statement.
            (column_count,) = conn.execute(
                "SELECT COUNT(*) FROM pragma_table_info(?)", (table_name,)
            ).fetchone()
            n_columns.append(column_count)
    finally:
        # Close explicitly: `with sqlite3.connect(...)` only manages the transaction and
        # would leave one open handle per database file.
        conn.close()

fig = px.histogram(n_columns)
# Last expression of the cell: update_layout returns the figure, which the notebook then renders.
fig.update_layout(
    showlegend=False,
    xaxis_title="Number of columns",
    title={
        "text": "Distribution of the number of columns in the databases' tables",
        "y": 0.95,
        "x": 0.5,
        "xanchor": "center",
        "yanchor": "top",
    },
)