In [15]:
import argparse
import json
import os
import pickle
from pathlib import Path
import sqlite3
from tqdm import tqdm
import random

from utils.datasets.spider import load_tables

In [3]:

# Merge the two Spider training splits into a single JSON file.
spider_dir = "dataset/spider"
split1 = "train_spider.json"
split2 = "train_others.json"
total_train = []
for split_name in (split1, split2):
    # Open each split in a context manager so the handle is closed as soon as
    # the JSON has been parsed (json.load(open(...)) leaves it open).
    with open(os.path.join(spider_dir, split_name)) as f:
        total_train.extend(json.load(f))
with open(os.path.join(spider_dir, 'train_spider_and_others.json'), 'w') as f:
    json.dump(total_train, f)

# schema-linking between questions and databases for Spider
spider_dev = "dev.json"
spider_train = 'train_spider_and_others.json'
spider_table = 'tables.json'
spider_db = 'database'
#schema_linking_producer(spider_dev, spider_train, spider_table, spider_db, spider_dir)

In [12]:
# Load the dev (used as test) and merged training examples, plus the schemas.
dataset_dir = spider_dir
test = spider_dev
train = spider_train
table = spider_table
db = spider_db

# Read each JSON file inside a context manager so its handle is closed
# immediately after parsing instead of lingering in the kernel.
with open(os.path.join(dataset_dir, test)) as f:
    test_data = json.load(f)
with open(os.path.join(dataset_dir, train)) as f:
    train_data = json.load(f)

# load schemas
schemas, _ = load_tables([os.path.join(dataset_dir, table)])

In [21]:
# Copy every on-disk database into an in-memory SQLite database and attach the
# live in-memory connection to its schema object.
for db_id, schema in tqdm(schemas.items(), desc="DB connections"):
    sqlite_path = Path(dataset_dir) / db / db_id / f"{db_id}.sqlite"
    source: sqlite3.Connection = sqlite3.connect(str(sqlite_path))
    try:
        dest = sqlite3.connect(':memory:')
        dest.row_factory = sqlite3.Row
        source.backup(dest)
    finally:
        # `with sqlite3.connect(...)` only commits/rolls back the transaction;
        # it does not close the connection, so close the file-backed source
        # explicitly once its contents have been copied.
        source.close()
    schema.connection = dest

DB connections: 100%|██████████| 166/166 [00:00<00:00, 233.84it/s]


In [41]:
# Pick the database whose table definitions should be shown.
db_id = "academic"  # or any other db_id from your schemas
connection = schemas[db_id].connection
cursor = connection.cursor()

# sqlite_master stores the original CREATE statement of every table.
cursor.execute("""
    SELECT sql 
    FROM sqlite_master 
    WHERE type='table' AND sql IS NOT NULL
""")

# Emit each CREATE statement followed by a dashed separator line.
for result_row in cursor.fetchall():
    create_stmt = result_row[0]
    print(create_stmt)
    print("\n" + "-"*50 + "\n")

CREATE TABLE "author" (
"aid" int,
"homepage" text,
"name" text,
"oid" int,
primary key("aid")
)

--------------------------------------------------

CREATE TABLE "conference" (
"cid" int,
"homepage" text,
"name" text,
primary key ("cid")
)

--------------------------------------------------

CREATE TABLE "domain" (
"did" int,
"name" text,
primary key ("did")
)

--------------------------------------------------

CREATE TABLE "domain_author" (
"aid" int, 
"did" int,
primary key ("did", "aid"),
foreign key("aid") references `author`("aid"),
foreign key("did") references `domain`("did")
)

--------------------------------------------------

CREATE TABLE "domain_conference" (
"cid" int,
"did" int,
primary key ("did", "cid"),
foreign key("cid") references `conference`("cid"),
foreign key("did") references `domain`("did")
)

--------------------------------------------------

CREATE TABLE "journal" (
"homepage" text,
"jid" int,
"name" text,
primary key("jid")
)

----------------------------

In [37]:
# Example 1: Basic query for a specific database
db_id = "academic"  # or any other db_id from your schemas
connection = schemas[db_id].connection
cursor = connection.cursor()

# Execute a simple query
cursor.execute("SELECT * FROM author LIMIT 2")
rows = cursor.fetchall()
for row in rows:
    # Since we set row_factory = sqlite3.Row, we can access by column name.
    # The author table's columns are aid, homepage, name and oid, so label
    # and print the matching fields (the old line printed homepage as
    # "Singer ID" and aid as "Name").
    print(f"Author ID: {row['aid']}, Name: {row['name']}")

In [38]:
# Display the rows fetched by the author query (the recorded output is an empty list).
rows

[]

In [40]:
# Inspect the Schema object loaded for the race_track database.
schemas["race_track"]

Schema(db_id='race_track', tables=(Table(id=0, name=['race'], unsplit_name='race', orig_name='race', columns=[Column(id=1, table=..., name=['race', 'id'], unsplit_name='race id', orig_name='Race_ID', type='number', foreign_key_for=None), Column(id=2, table=..., name=['name'], unsplit_name='name', orig_name='Name', type='text', foreign_key_for=None), Column(id=3, table=..., name=['class'], unsplit_name='class', orig_name='Class', type='text', foreign_key_for=None), Column(id=4, table=..., name=['date'], unsplit_name='date', orig_name='Date', type='text', foreign_key_for=None), Column(id=5, table=..., name=['track', 'id'], unsplit_name='track id', orig_name='Track_ID', type='text', foreign_key_for=Column(id=6, table=Table(id=1, name=['track'], unsplit_name='track', orig_name='track', columns=[..., Column(id=7, table=..., name=['name'], unsplit_name='name', orig_name='Name', type='text', foreign_key_for=None), Column(id=8, table=..., name=['location'], unsplit_name='location', orig_name='