In [1]:
import gzip
import json
import pandas as pd

In [2]:
# open schemapile
with gzip.open("../../data/schemapile.json.gz", 'r') as f:
    schemapile = json.loads(f.read().decode('utf-8'))

In [3]:
# helper functions

def get_columns_and_types(table_name, cursor):
    columns_and_types = cursor.execute(f"PRAGMA table_info('{table_name}');").fetchall()
    return [column[1] for column in columns_and_types], [column[2] for column in columns_and_types]

def make_table_string(table_name, columns, sort=False):
    if sort:
        columns = sorted(columns)
    return table_name.lower()+"("+ ", ".join([column.lower() for column in columns])+")"

In [None]:
# collect foreign keys from schemapile

import os
llm_max_chars = 8000

foreign_keys_list = []
for schema_name in schemapile:
    tables = list(schemapile[schema_name]["TABLES"].keys())
    table_columns = {}
    table_types = {}

    # collect schema information for lookup
    for table_name in tables:
        columns_schemapile = schemapile[schema_name]["TABLES"][table_name]["COLUMNS"]
        columns = list(columns_schemapile.keys())
        types = [columns_schemapile[column]["TYPE"] for column in columns_schemapile]
        table_columns[table_name] = columns
        table_types[table_name] = types

    # collect foreign keys
    for table_name in tables:
        foreign_keys_by_id = {}
        foreign_keys_table = schemapile[schema_name]["TABLES"][table_name]["FOREIGN_KEYS"]
        for i, foreign_key in enumerate(foreign_keys_table):
            if len(foreign_key["COLUMNS"]) > 1:
                print(f"{schema_name}.{table_name} not including composite primary key: {foreign_key}")
                continue
                
            column = foreign_key["COLUMNS"][0]
            foreign_table = foreign_key["FOREIGN_TABLE"]
            referred_column = foreign_key["REFERRED_COLUMNS"][0]
            if foreign_table in table_columns and referred_column in table_columns[foreign_table] and column in table_columns[table_name]:  
                table_string = make_table_string(table_name, table_columns[table_name])
                foreign_table_string = make_table_string(foreign_table, table_columns[foreign_table])
                too_large_for_llm = len(table_string) + len(foreign_table_string) > llm_max_chars
                if too_large_for_llm:
                    print(f"{schema_name}.{table_name}/{foreign_table} schema too large for llm: {len(table_string)}+{len(foreign_table_string)}>{llm_max_chars}")
                    
                foreign_keys_by_id[i] = {"database": schema_name,
                                    "table": table_name, 
                                    "table_string": table_string,
                                    "table_columns": table_columns[table_name],
                                    "table_types": table_types[table_name],
                                    "column": column, 
                                    "foreign_table": foreign_table, 
                                    "foreign_table_string": foreign_table_string, 
                                    "foreign_table_columns": table_columns[foreign_table],
                                    "foreign_table_types": table_types[foreign_table],
                                    "referred_column": referred_column,
                                    "too_large_for_llm": too_large_for_llm}

        # don't include foreign key pairs if there are multiple between two tables   
        foreign_keys_by_table_combination = {}
        marked_for_deletion = []
        for key_id in list(foreign_keys_by_id.keys()):
            foreign_key_pair = foreign_keys_by_id[key_id]
            table_combination = (foreign_key_pair["table"], foreign_key_pair["foreign_table"])
            if table_combination in foreign_keys_by_table_combination:
                print(f"{foreign_key_pair['database']} multiple pairs removed for {table_combination}")
                marked_for_deletion.append(table_combination)
            elif foreign_key_pair["too_large_for_llm"]:
                print(f"{foreign_key_pair['database']} too large for llm {table_combination}")
                marked_for_deletion.append(table_combination)
            else:
                foreign_keys_by_table_combination[table_combination] = foreign_key_pair
                
        for table_combination in list(foreign_keys_by_table_combination.keys()):
            if table_combination in marked_for_deletion:
                del foreign_keys_by_table_combination[table_combination]

        foreign_keys_list.extend(list(foreign_keys_by_table_combination.values()))

In [6]:
len(foreign_keys_list)

666253

In [7]:
# deduplicate

table_combinations = set()
foreign_keys_list_deduped = []
for foreign_key_pair in foreign_keys_list:
    table_combination = (foreign_key_pair["table_string"],foreign_key_pair["foreign_table_string"], foreign_key_pair["column"], foreign_key_pair["referred_column"])
    if table_combination not in table_combinations:
        foreign_keys_list_deduped.append(foreign_key_pair)
        table_combinations.add(table_combination)

In [8]:
len(foreign_keys_list_deduped)

468770

In [9]:
# create training data for foreign key detection from schemapile
import json
alphabetically_sorted = False
instruction_pairs = []

question_template = "You are given the following SQL database tables: \n{tables}\nOutput a json string with the following schema {{table, column, referencedTable, referencedColumn}} that contains the foreign key relationship between the two tables." 
for foreign_key_pair in foreign_keys_list_deduped: 
    table = foreign_key_pair["table"].lower()
    foreign_table = foreign_key_pair["foreign_table"].lower()

    if alphabetically_sorted:
        prompt = question_template.format(tables="\n".join(sorted([foreign_key_pair["table_string"], foreign_key_pair["foreign_table_string"]])))
    else:
        prompt = question_template.format(tables="\n".join([foreign_key_pair["table_string"], foreign_key_pair["foreign_table_string"]]))
    ground_truth = json.dumps({"table": foreign_key_pair["table"].lower(), "column": foreign_key_pair["column"].lower(), "referencedTable": foreign_key_pair["foreign_table"].lower(), "referencedColumn": foreign_key_pair["referred_column"].lower()})
    
    instruction_pairs.append([{"content": prompt, "role": "user"}, {"content": ground_truth, "role": "assistant"}])

In [11]:
train = pd.DataFrame({"messages": instruction_pairs})

In [12]:
import os
output_path = "../../data/foreign_keys_instruction_data_schemapile"
os.makedirs(output_path, exist_ok=True)
train.reset_index().drop("index",axis=1).to_parquet(f"{output_path}/train.parquet")

In [None]:
from datasets import load_dataset
ds = load_dataset("parquet", data_files={'train': f"{output_path}/train.parquet"})

In [15]:
ds["train"]["messages"][0]

[{'content': 'You are given the following SQL database tables: \naddress(id, uuid, flat_buil_number, locality, city, pincode, state_id)\nstate(id, uuid, state_name)\nOutput a json string with the following schema {table, column, referencedTable, referencedColumn} that contains the foreign key relationship between the two tables.',
  'role': 'user'},
 {'content': '{"table": "address", "column": "state_id", "referencedTable": "state", "referencedColumn": "id"}',
  'role': 'assistant'}]