## Schema-Linking

### Citation
@article{pourreza2024dts, title={DTS-SQL: Decomposed Text-to-SQL with Small Large Language Models}, author={Pourreza, Mohammadreza and Rafiei, Davood}, journal={arXiv preprint arXiv:2402.01117}, year={2024} }

https://github.com/MohammadrezaPourreza/DTS-SQL

In [3]:
import json
import os
import re
import sqlite3

import numpy as np
import pandas as pd
from tqdm import tqdm

In [4]:
def get_all_table_names(db_uri: str) -> list[str]:
    """Return the names of all tables stored in a SQLite database.

    Args:
        db_uri: Path handed to ``sqlite3.connect``.

    Returns:
        The ``name`` of every ``sqlite_master`` row whose type is ``'table'``,
        in the order the query yields them.

    Raises:
        sqlite3.Error: If the file cannot be queried as a SQLite database.
    """
    conn = sqlite3.connect(db_uri)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        return [name for (name,) in cursor.fetchall()]
    finally:
        # Release the connection even when the query raises.
        conn.close()

In [5]:
def get_table_schema_with_samples(
    db_uri: str, table_name: str, sample_limit: int = 0, columns_description: dict[str, str] = {}
) -> str:
    """Render one table as a ``CREATE TABLE``-style string, optionally with sample rows.

    Each column line contains the column name and declared type, followed by
    ``PRIMARY KEY`` when the column is part of the table's primary key, a
    ``REFERENCES`` clause for every foreign key that starts at the column, and
    a ``-- '<description>'`` comment when ``columns_description`` has an entry
    for the column.

    Args:
        db_uri: Path handed to ``sqlite3.connect``.
        table_name: Table to describe. It is interpolated into the PRAGMA and
            SELECT statements inside backticks.
        sample_limit: Value used in the ``LIMIT`` clause of the sample query;
            with the default of 0 no sample rows are appended.
        columns_description: Mapping from column name to description text.
            It is only read, never mutated.

    Returns:
        The schema text, followed by a ``Sample rows from`` section when the
        sample query returned at least one row.
    """
    conn = sqlite3.connect(db_uri)
    try:
        cursor = conn.cursor()

        # Fetch table schema.
        # table_info rows are (cid, name, type, notnull, dflt_value, pk).
        cursor.execute(f"PRAGMA table_info(`{table_name}`);")
        columns = cursor.fetchall()
        # foreign_key_list rows are (id, seq, table, from, to, on_update, on_delete, match).
        cursor.execute(f"PRAGMA foreign_key_list(`{table_name}`);")
        foreign_keys = cursor.fetchall()

        # Construct CREATE TABLE statement
        schema_str = f"CREATE TABLE `{table_name}` (\n"
        for column in columns:
            column_name = column[1]
            data_type = column[2]
            schema_str += f"  {column_name} {data_type}"
            # The pk field is 0 for non-key columns and the 1-based position
            # inside the primary key otherwise. Reading it directly avoids
            # labelling columns that merely have an index, and catches
            # INTEGER PRIMARY KEY columns, which have no index of their own.
            if column[5] > 0:
                schema_str += " PRIMARY KEY"
            for foreign_key in foreign_keys:
                if column_name == foreign_key[3]:
                    referenced_table = foreign_key[2]
                    referenced_column = foreign_key[4]
                    if referenced_column is None:
                        # The target column is NULL when the constraint names
                        # only the parent table.
                        schema_str += f" REFERENCES {referenced_table}"
                    else:
                        schema_str += f" REFERENCES {referenced_table}({referenced_column})"
            if column_name in columns_description:
                schema_str += f" -- '{columns_description[column_name]}'"

            schema_str += ",\n"
        # Drop the separator left after the last column before closing.
        schema_str = schema_str.rstrip(",\n")
        schema_str += "\n);\n"

        cursor.execute(f"SELECT * FROM `{table_name}` LIMIT {sample_limit};")
        sample_rows = cursor.fetchall()

        if len(sample_rows) > 0:
            schema_str += f"Sample rows from `{table_name}`:\n"
            for row in sample_rows:
                formatted_row = ", ".join(str(item) for item in row)
                schema_str += f"{formatted_row}\n"

        return schema_str
    finally:
        # Release the connection even when one of the statements raises.
        conn.close()

In [6]:
def remove_spaces(text):
    """Collapse every run of whitespace in ``text`` into a single space.

    Leading and trailing whitespace is collapsed as well, not stripped.
    """
    whitespace_run = re.compile(r'\s+')
    return whitespace_run.sub(' ', text)

In [7]:
def load_descriptions(db_path: str, table_name: str) -> dict[str, str]:
    """Load per-column descriptions for a table from its description CSV.

    The CSV is read from ``{db_path}/database_description/{table_name}.csv``.

    Args:
        db_path: Directory that contains the ``database_description`` folder.
        table_name: Table whose description file should be read.

    Returns:
        A mapping from ``original_column_name`` to the whitespace-normalised
        ``column_description``, with `` has values: (...)`` appended when the
        row also has a ``value_description``. An empty dict is returned when
        the file is missing, cannot be parsed, or lacks any of the three
        columns read here.
    """
    csv_path = f"{db_path}/database_description/{table_name}.csv"
    if not os.path.exists(csv_path):
        return {}
    try:
        description_df = pd.read_csv(csv_path)
    except Exception:
        # Best effort: an unreadable description file means "no descriptions".
        return {}

    required_columns = {"original_column_name", "column_description", "value_description"}
    if not required_columns.issubset(description_df.columns):
        return {}

    columns_description = {}
    for _, row in description_df.iterrows():
        column_name = row["original_column_name"]
        if pd.isna(column_name) or pd.isna(row["column_description"]):
            continue
        # str() guards against cells pandas parsed as numbers, which the
        # regex-based remove_spaces cannot handle.
        description = remove_spaces(str(row["column_description"]))
        if pd.notna(row["value_description"]):
            description += f" has values: ({remove_spaces(str(row['value_description']))})"
        columns_description[column_name] = description
    return columns_description

In [None]:
def generate_schema_for_instance(row, base_databases_dir):
    """Build the schema text for the database referenced by a dataset row.

    Args:
        row: Mapping-like record; only ``row['db_id']`` is read.
        base_databases_dir: Directory holding one ``<db_id>`` sub-directory per
            database, each with a ``<db_id>.sqlite`` file.

    Returns:
        The schema strings of all tables in the database, separated by blank
        lines, with surrounding whitespace stripped. No sample rows are
        included.
    """
    db_id = row['db_id']

    # Set up database paths
    db_path = f"{base_databases_dir}/{db_id}"
    db_uri = f"{db_path}/{db_id}.sqlite"

    # Get table names and build schema
    table_names = get_all_table_names(db_uri)
    database_schema = ""

    for table_name in table_names:
        # load_descriptions expects the database directory, because it looks
        # for <db_path>/database_description/<table>.csv. Passing the bare
        # db_id would resolve relative to the current working directory.
        columns_description = load_descriptions(db_path, table_name)
        schema = get_table_schema_with_samples(db_uri, table_name, 0, columns_description)
        database_schema += schema + "\n"

    return database_schema.strip()

In [23]:
# Path of the dataset JSON file read with pd.read_json (a file path, despite the `_DIR` suffix).
BASE_DATASET_DIR = "../dataset/dev.json"
# Directory passed as `base_databases_dir` to generate_schema_for_instance.
# NOTE(review): the name is misspelled ("DABATASES"); left as-is because the schema loop cell references it.
BASE_DABATASES_DIR =  "../dataset/dev_databases/"
# File the export cell writes the records to.
OUTPUT_FILENAME = "dataset.json"

In [28]:
# Load the dataset JSON into a DataFrame.
df = pd.read_json(BASE_DATASET_DIR)
# NOTE(review): `row` is rebound by the loop in the next cell before it is read;
# this looks like leftover exploration. Confirm before removing.
row = df.iloc[0]

In [29]:
# Build each database's schema once and attach it to every question that uses it.
# generate_schema_for_instance only reads `db_id` from the row, so one call per
# distinct database gives the same result as one call per question.
schema_by_db_id = {}
for _, row in df.drop_duplicates(subset="db_id").iterrows():
    schema_by_db_id[row["db_id"]] = generate_schema_for_instance(row, BASE_DABATASES_DIR)

# Add the generated schema to a new field 'database_schema' on every row.
# Assigning the whole column at once keeps this cell safe to re-run.
df["database_schema"] = df["db_id"].map(schema_by_db_id)

In [30]:
# Spot-check the first record, including the newly added `database_schema` field.
df.iloc[0]

question_id                                                        0
db_id                                             california_schools
question           What is the highest eligible free rate for K-1...
evidence           Eligible free rate for K-12 = `Free Meal Count...
SQL                SELECT `Free Meal Count (K-12)` / `Enrollment ...
difficulty                                                    simple
database_schema    CREATE TABLE `frpm` (\n  CDSCode TEXT PRIMARY ...
Name: 0, dtype: object

In [31]:
# Specify the file name
file_name = OUTPUT_FILENAME

# Convert the DataFrame to a list with one dictionary per row
data = df.to_dict(orient="records")

# Write the records to a JSON file. The encoding is set explicitly so the
# output does not depend on the platform default.
with open(file_name, 'w', encoding="utf-8") as json_file:
    json.dump(data, json_file, indent=4)