## Schema-Linking

### Citation
@article{pourreza2024dts, title={DTS-SQL: Decomposed Text-to-SQL with Small Large Language Models}, author={Pourreza, Mohammadreza and Rafiei, Davood}, journal={arXiv preprint arXiv:2402.01117}, year={2024} }

In [1]:
import sqlite3
import numpy as np
import pandas as pd
import os
import re
from tqdm import tqdm

In [3]:
def get_all_table_names(db_uri: str) -> list[str]:
    conn = sqlite3.connect(db_uri)
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    table_names = cursor.fetchall()
    conn.close()
    return [table_name[0] for table_name in table_names]

In [4]:
def get_table_schema_with_samples(
    db_uri: str, table_name: str, sample_limit: int = 0, columns_description: dict[str, str] = {}
) -> str:
    conn = sqlite3.connect(db_uri)
    cursor = conn.cursor()

    # Fetch table schema
    cursor.execute(f"PRAGMA table_info(`{table_name}`);")
    columns = cursor.fetchall()
    cursor.execute(f"PRAGMA foreign_key_list(`{table_name}`);")
    foreign_keys = cursor.fetchall()
    cursor.execute(f"PRAGMA index_list(`{table_name}`);")
    primary_key_indices = cursor.fetchall()
    primary_key_columns = []

    for index_info in primary_key_indices:
        index_name = index_info[1]
        cursor.execute(f"PRAGMA index_info(`{index_name}`);")
        index_columns = cursor.fetchall()
        primary_key_columns.extend(column[2] for column in index_columns)

    # Construct CREATE TABLE statement
    schema_str = f"CREATE TABLE `{table_name}` (\n"
    for column in columns:
        column_name = column[1]
        data_type = column[2]
        schema_str += f"  {column_name} {data_type}"
        if column_name in primary_key_columns:
            schema_str += " PRIMARY KEY"
        for foreign_key in foreign_keys:
            if column_name == foreign_key[3]:
                schema_str += f" REFERENCES {foreign_key[2]}({foreign_key[4]})"
        if column_name in columns_description:
            schema_str += f" -- '{columns_description[column_name]}'"

        schema_str += ",\n"
    schema_str = schema_str.rstrip(",\n")
    schema_str += "\n);\n"

    
    cursor.execute(f"SELECT * FROM `{table_name}` LIMIT {sample_limit};")
    sample_rows = cursor.fetchall()

    if len(sample_rows) > 0:
        schema_str += f"Sample rows from `{table_name}`:\n"
        for row in sample_rows:
            formatted_row = ", ".join(str(item) for item in row)
            schema_str += f"{formatted_row}\n"

    conn.close()
    return schema_str

In [5]:
def remove_spaces(text):
  return re.sub(r'\s+', ' ', text)

In [6]:
def load_descriptions(db_path: str, table_name: str) -> list[str]:
    if not os.path.exists(f"{db_path}/database_description/{table_name}.csv"):
        return {}
    try:
        df = pd.read_csv(f"{db_path}/database_description/{table_name}.csv")
    except Exception:
        return {}
    if "column_description" not in df.columns or "value_description" not in df.columns:
        return {}
    columns_description = {}
    for index, row in df.iterrows():
        if np.nan != row["column_description"] and pd.notna(row["column_description"]):
            columns_description[row["original_column_name"]] = remove_spaces(row["column_description"])
            if np.nan != row["value_description"] and pd.notna(row["value_description"]):
                columns_description[row["original_column_name"]] += f" has values: ({remove_spaces(row['value_description'])})"
    return columns_description

In [11]:
BASE_DATASET_DIR = "../dataset/dev.json"
BASE_DABATASES_DIR =  "../dataset/dev_databases/"
OUTPUT_DIR = "predict_dev.json"

In [12]:
df = pd.read_json(BASE_DATASET_DIR)
row = df.iloc[0]

In [17]:
if __name__ == "__main__":
        db_id = row['db_id']
        query = row['SQL']
        question = row['question']
        if row['evidence'] != "" and row['evidence'] is not None:
            question += " Hint: " + row['evidence']
        db_uri = f"{BASE_DABATASES_DIR}/{db_id}/{db_id}.sqlite"
        db_path = f"{BASE_DABATASES_DIR}/{db_id}"
        table_names = get_all_table_names(db_uri)
        database_schema = ""
        for table_name in table_names:
            columns_description = load_descriptions(db_id, table_name)
            schema = get_table_schema_with_samples(db_uri, table_name, 0, columns_description)
            database_schema += schema + "\n"
        user_message = f"""Given the following SQL tables, your job is to determine the columns and tables that the question is referring to.
{database_schema}
####
Question: {question}
"""
        messages = [
            {"role": "user", "content": user_message.strip()}
        ]

In [18]:
messages

[{'role': 'user',
  'content': 'Given the following SQL tables, your job is to determine the columns and tables that the question is referring to.\nCREATE TABLE `frpm` (\n  CDSCode TEXT PRIMARY KEY REFERENCES schools(CDSCode),\n  Academic Year TEXT,\n  County Code TEXT,\n  District Code INTEGER,\n  School Code TEXT,\n  County Name TEXT,\n  District Name TEXT,\n  School Name TEXT,\n  District Type TEXT,\n  School Type TEXT,\n  Educational Option Type TEXT,\n  NSLP Provision Status TEXT,\n  Charter School (Y/N) INTEGER,\n  Charter School Number TEXT,\n  Charter Funding Type TEXT,\n  IRC INTEGER,\n  Low Grade TEXT,\n  High Grade TEXT,\n  Enrollment (K-12) REAL,\n  Free Meal Count (K-12) REAL,\n  Percent (%) Eligible Free (K-12) REAL,\n  FRPM Count (K-12) REAL,\n  Percent (%) Eligible FRPM (K-12) REAL,\n  Enrollment (Ages 5-17) REAL,\n  Free Meal Count (Ages 5-17) REAL,\n  Percent (%) Eligible Free (Ages 5-17) REAL,\n  FRPM Count (Ages 5-17) REAL,\n  Percent (%) Eligible FRPM (Ages 5-17) R