In [6]:
from common.common_utils import read_json
from glob import glob
import os

# table.json

In [3]:
train_table = read_json("data/tables.json")
print(len(train_table))

test_table = read_json("data/test_data/tables.json")
print(len(test_table))

166
206


In [9]:
# 看看路径数量是否匹配
train_paths = glob("data/database/*")
train_paths = [p for p in train_paths if os.path.isdir(p)]
print(len(train_paths))

# test
test_paths = glob("data/test_database/*")
test_paths = [p for p in test_paths if os.path.isdir(p)]
print(len(test_paths))

# 为什么要复制一遍？？

166
206


In [28]:
train_dbs = [os.path.basename(p) for p in train_paths]
test_dbs = [os.path.basename(p) for p in test_paths]

# in test not in train
not_in_train = [db for db in test_dbs if db not in train_dbs]
len(not_in_train)

40

In [31]:
print(sorted(not_in_train))

['aan_1', 'address_1', 'advertising_agencies', 'art_1', 'bakery_1', 'bbc_channels', 'bike_racing', 'boat_1', 'book_1', 'book_press', 'book_review', 'car_racing', 'car_road_race', 'club_leader', 'conference', 'country_language', 'cre_Doc_Workflow', 'cre_Doc_and_collections', 'cre_Students_Information_Systems', 'customers_and_orders', 'district_spokesman', 'e_commerce', 'government_shift', 'headphone_store', 'institution_sports', 'movie_2', 'online_exams', 'pilot_1', 'planet_1', 'real_estate_rentals', 'region_building', 'restaurant_bills', 'sing_contest', 'soccer_3', 'tv_shows', 'university_rank', 'vehicle_driver', 'vehicle_rent', 'video_game', 'warehouse_1']


# data

In [10]:
train_data = read_json("data/train_spider.json") + read_json("data/train_others.json")
len(train_data)

8659

In [12]:
train_data[0].keys()

dict_keys(['db_id', 'query', 'query_toks', 'query_toks_no_value', 'question', 'question_toks', 'sql'])

In [42]:
d= train_data[0]

In [45]:
print(d["question"])
print(d["query"])
print(d["sql"])

How many heads of the departments are older than 56 ?
SELECT count(*) FROM head WHERE age  >  56
{'from': {'table_units': [['table_unit', 1]], 'conds': []}, 'select': [False, [[3, [0, [0, 0, False], None]]]], 'where': [[False, 3, [0, [0, 10, False], None], 56.0, None]], 'groupBy': [], 'having': [], 'orderBy': [], 'limit': None, 'intersect': None, 'union': None, 'except': None}


# table

In [32]:
train_table = read_json("data/tables.json")
len(train_table)

166

In [37]:
print(train_table[0].keys())

dict_keys(['column_names', 'column_names_original', 'column_types', 'db_id', 'foreign_keys', 'primary_keys', 'table_names', 'table_names_original'])


# 统计

In [18]:
from collections import defaultdict
from tqdm import tqdm

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("LLMs/meta-llama/Meta-Llama-3-8B/")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [15]:
# count db_id
db_id_count = defaultdict(int)
for d in train_data:
    db_id_count[d["db_id"]] += 1
# sort by name
db_id_count = sorted(db_id_count.items(), key=lambda x: x[0])
# db_id_count

In [None]:
# 按照DB ID 统计问题和SQL的长度
import matplotlib.pyplot as plt

types = [d[0] for d in db_id_count]

fig, axes = plt.subplots(nrows=(len(types) + 4) // 5, ncols=5, figsize=(15, 5 * ((len(types) + 4) // 5)))
axes = axes.flatten()

for idx, tp in enumerate(tqdm(types)):
    tp_data = [d for d in train_data if d["db_id"] == tp] # db_id or difficulty

    lengths_question = []
    lengths_SQL = []

    for d in tp_data:
        lengths_question.append(len(tokenizer.tokenize(d['question'])))
        lengths_SQL.append(len(tokenizer.tokenize(d['query'])))

    ax = axes[idx]
    ax.hist(lengths_question, bins=20, color='skyblue', edgecolor='black', alpha=0.5, label='question')
    ax.hist(lengths_SQL, bins=20, color='orange', edgecolor='black', alpha=0.5, label='SQL') 
    ax.set_title(f'{tp}')
    ax.set_xlabel('Length of tokens')
    ax.set_ylabel('Frequency')
    ax.grid(True)
    ax.legend()

for ax in axes[idx+1:]:
    ax.axis('off')

plt.tight_layout()
# plt.show()
os.makedirs("img", exist_ok=True)
plt.savefig("img/train-db-q-sql-token-length.png")

In [25]:
!ls -lh img/train-db-q-sql-token-length.png

-rw-rw-r-- 1 xionggm xionggm 1.5M  6月 18 22:23 img/train-db-q-sql-token-length.png


# DB schema 打印

In [40]:
import sqlite3
from glob import glob
import os
from tqdm import tqdm

def get_tables(cursor):
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()
    # exclude sqlite_sequence
    tables = [table[0] for table in tables if table[0] != "sqlite_sequence"]
    return tables

def get_primary_key(cursor, table_name):
    cursor.execute(f"PRAGMA table_info(`{table_name}`)")
    columns = cursor.fetchall()
    primary_keys = [col[1] for col in columns if col[5] == 1]
    return primary_keys

def get_foreign_keys(cursor, table_name):
    cursor.execute(f"PRAGMA foreign_key_list(`{table_name}`)")
    foreign_keys = cursor.fetchall()
    return foreign_keys

def generate_markdown_table(cursor):
    tables = get_tables(cursor)
    markdown = "| Table | Primary Key | Foreign Key |\n"
    markdown += "| --- | --- | --- |\n"
    
    for table in sorted(tables):
        primary_keys = get_primary_key(cursor, table)
        foreign_keys = get_foreign_keys(cursor, table)

        primary_key_str = ", ".join(primary_keys)

        foreign_key_str = ""
        if foreign_keys:
            # fk[2]: table name, fk[3]: column name, fk[4]: parent column name
            foreign_key_str = ", ".join([f"{fk[3]} references {fk[2]}({fk[4]})" for fk in foreign_keys])

        markdown += f"| {table} | {primary_key_str} | {foreign_key_str} |\n"

    return markdown


In [41]:
db_path = 'data/database/academic/academic.sqlite'

conn = sqlite3.connect(db_path)
cursor = conn.cursor()
markdown_table = generate_markdown_table(cursor)
print(markdown_table)

| Table | Primary Key | Foreign Key |
| --- | --- | --- |
| author | aid |  |
| cite |  | citing references publication(pid), cited references publication(pid) |
| conference | cid |  |
| domain | did |  |
| domain_author | did | did references domain(did), aid references author(aid) |
| domain_conference | did | did references domain(did), cid references conference(cid) |
| domain_journal | did | did references domain(did), jid references journal(jid) |
| domain_keyword | did | did references domain(did), kid references keyword(kid) |
| domain_publication | did | did references domain(did), pid references publication(pid) |
| journal | jid |  |
| keyword | kid |  |
| organization | oid |  |
| publication | pid | cid references conference(cid), jid references journal(jid) |
| publication_keyword | kid | kid references keyword(kid), pid references publication(pid) |
| writes | aid | aid references author(aid), pid references publication(pid) |

