In [2]:
# helper functions

def get_columns_and_types(table_name, cursor):
    """Return the column names and declared column types of a SQLite table.

    Args:
        table_name: Name of the table to inspect.
        cursor: An open ``sqlite3`` cursor on the database that holds the table.

    Returns:
        A tuple ``(names, types)`` of two parallel lists, in the order in which
        ``PRAGMA table_info`` reports the columns.
    """
    # The table name is embedded as a string literal in the PRAGMA statement, so
    # any single quote inside it has to be doubled; otherwise a name such as
    # "it's" ends the literal early and the statement fails to parse.
    quoted_name = table_name.replace("'", "''")
    columns_and_types = cursor.execute(f"PRAGMA table_info('{quoted_name}');").fetchall()
    # Index 1 of each row is the column name, index 2 its declared type.
    return [column[1] for column in columns_and_types], [column[2] for column in columns_and_types]

def make_table_string(table_name, columns, sort=False):
    """Render a table as the lowercase string ``name(col1, col2, ...)``.

    Args:
        table_name: Name of the table.
        columns: Column names of the table.
        sort: When true, the columns are sorted (on their original spelling,
            before lowercasing) instead of keeping the given order.

    Returns:
        The lowercase schema string for the table.
    """
    ordered_columns = sorted(columns) if sort else columns
    lowered_columns = (name.lower() for name in ordered_columns)
    return f"{table_name.lower()}({', '.join(lowered_columns)})"

In [3]:
# collect foreign keys from databases

import os
import sqlite3

# Pairs whose two table strings together exceed this many characters are
# flagged with "too_large_for_llm" (they are still collected).
llm_max_chars = 8000

def collect_foreign_keys(databases_path):
    """Collect unambiguous single-column foreign keys from a directory of SQLite files.

    Every entry of ``databases_path`` is opened as a SQLite database. For each
    table, the foreign keys reported by ``PRAGMA foreign_key_list`` are kept
    unless

    * the key spans more than one column (composite key),
    * the referenced table, the referenced column or the referencing column is
      not found in the collected schema, or
    * there is more than one remaining key between the same pair of tables.

    Args:
        databases_path: Directory containing the SQLite database files.

    Returns:
        A list of dicts, one per kept foreign key, with the keys ``database``,
        ``table``, ``table_string``, ``table_columns``, ``table_types``,
        ``column``, ``foreign_table``, ``foreign_table_string``,
        ``foreign_table_columns``, ``foreign_table_types``, ``referred_column``
        and ``too_large_for_llm``.
    """
    foreign_keys_list = []
    for database in os.listdir(databases_path):
        con = sqlite3.connect(os.path.join(databases_path, database))
        try:
            cursor = con.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()
            table_columns = {}
            table_types = {}

            # collect schema information for lookup
            for table in tables:
                table_name = table[0]
                columns, types = get_columns_and_types(table_name, cursor)
                table_columns[table_name] = columns
                table_types[table_name] = types

            # collect foreign keys
            for table in tables:
                table_name = table[0]
                # Double single quotes so the name stays inside the string literal.
                quoted_table_name = table_name.replace("'", "''")
                cursor.execute(f"PRAGMA foreign_key_list('{quoted_table_name}');")
                foreign_keys_by_id = {}
                # Ids of keys already identified as composite. Remembering them
                # keeps a third (or later) column of the same key from being
                # re-added as if it were a single-column key.
                composite_key_ids = set()
                for fk_id, seq_id, foreign_table, column, referred_column, _, _, _ in cursor.fetchall():
                    if fk_id in composite_key_ids:
                        continue
                    if seq_id > 0 or fk_id in foreign_keys_by_id:
                        # A non-first column means the key is composite, even when
                        # its first column was rejected by the schema check below.
                        # (The message says "primary key" but refers to the
                        # composite foreign key; the text is kept unchanged.)
                        print(f"{database}.{table_name} removed composite primary key: {fk_id}.{seq_id}")
                        foreign_keys_by_id.pop(fk_id, None)
                        composite_key_ids.add(fk_id)
                    elif foreign_table in table_columns and referred_column in table_columns[foreign_table] and column in table_columns[table_name]:
                        table_string = make_table_string(table_name, table_columns[table_name])
                        foreign_table_string = make_table_string(foreign_table, table_columns[foreign_table])
                        too_large_for_llm = len(table_string) + len(foreign_table_string) > llm_max_chars
                        if too_large_for_llm:
                            print(f"{database}.{table_name}/{foreign_table} schema too large for llm: {len(table_string)}+{len(foreign_table_string)}>{llm_max_chars}")

                        foreign_keys_by_id[fk_id] = {"database": database,
                                            "table": table_name,
                                            "table_string": table_string,
                                            "table_columns": table_columns[table_name],
                                            "table_types": table_types[table_name],
                                            "column": column,
                                            "foreign_table": foreign_table,
                                            "foreign_table_string": foreign_table_string,
                                            "foreign_table_columns": table_columns[foreign_table],
                                            "foreign_table_types": table_types[foreign_table],
                                            "referred_column": referred_column,
                                            "too_large_for_llm": too_large_for_llm}

                # don't include foreign key pairs if there are multiple between two tables
                foreign_keys_by_table_combination = {}
                # Table pairs already dropped; without this a third key between
                # the same two tables would be added back.
                ambiguous_table_combinations = set()
                for foreign_key_pair in foreign_keys_by_id.values():
                    table_combination = (foreign_key_pair["table"], foreign_key_pair["foreign_table"])
                    if table_combination in ambiguous_table_combinations:
                        continue
                    if table_combination in foreign_keys_by_table_combination:
                        print(f"{foreign_key_pair['database']} multiple pairs removed for {table_combination}")
                        del foreign_keys_by_table_combination[table_combination]
                        ambiguous_table_combinations.add(table_combination)
                    else:
                        foreign_keys_by_table_combination[table_combination] = foreign_key_pair

                foreign_keys_list.extend(foreign_keys_by_table_combination.values())
        finally:
            # Release the database handle even if reading the schema fails.
            con.close()
    return foreign_keys_list

In [4]:
# Collect the foreign keys of the databases under "spider_dbs/"; the trailing
# expression displays how many were kept.
foreign_keys_list_spider = collect_foreign_keys("spider_dbs/")
len(foreign_keys_list_spider)

network_1.sqlite multiple pairs removed for ('Friend', 'Highschooler')
network_1.sqlite multiple pairs removed for ('Likes', 'Highschooler')
wedding.sqlite multiple pairs removed for ('wedding', 'people')
solvency_ii.sqlite multiple pairs removed for ('Assets_in_Events', 'Events')
twitter_1.sqlite multiple pairs removed for ('follows', 'user_profiles')
cre_Drama_Workshop_Groups.sqlite.Invoice_Items removed composite primary key: 0.1
network_2.sqlite multiple pairs removed for ('PersonFriend', 'Person')
dog_kennels.sqlite multiple pairs removed for ('Dogs', 'Owners')
flight_4.sqlite multiple pairs removed for ('routes', 'airports')
insurance_fnol.sqlite.First_Notification_of_Loss removed composite primary key: 0.1
soccer_1.sqlite multiple pairs removed for ('Player_Attributes', 'Player')
soccer_1.sqlite multiple pairs removed for ('Team_Attributes', 'Team')
student_transcripts_tracking.sqlite multiple pairs removed for ('Students', 'Addresses')
cre_Doc_Control_Systems.sqlite.Draft_Copie

555

In [5]:
# Collect the foreign keys of the databases under "bird_dbs/"; the trailing
# expression displays how many were kept.
foreign_keys_list_bird = collect_foreign_keys("bird_dbs/")
len(foreign_keys_list_bird)

professional_basketball.sqlite.coaches removed composite primary key: 0.1
professional_basketball.sqlite.draft removed composite primary key: 0.1
professional_basketball.sqlite.awards_coaches removed composite primary key: 0.1
professional_basketball.sqlite.players_teams removed composite primary key: 0.1
professional_basketball.sqlite.series_post removed composite primary key: 0.1
professional_basketball.sqlite.series_post removed composite primary key: 1.1
codebase_community.sqlite multiple pairs removed for ('postLinks', 'posts')
codebase_community.sqlite multiple pairs removed for ('posts', 'users')
language_corpus.sqlite multiple pairs removed for ('biwords', 'words')
image_and_language.sqlite.IMG_OBJ_ATT removed composite primary key: 0.1
image_and_language.sqlite.IMG_REL removed composite primary key: 0.1
image_and_language.sqlite.IMG_REL removed composite primary key: 1.1
movie_platform.sqlite multiple pairs removed for ('lists_users', 'lists')
hockey.sqlite.Coaches removed com

386

In [22]:
# Collect the foreign keys of the databases under "ctu_dbs/"; the trailing
# expression displays how many were kept.
foreign_keys_list_ctu = collect_foreign_keys("ctu_dbs/")
len(foreign_keys_list_ctu)

Biodegradability.sqlite multiple pairs removed for ('bond', 'atom')
AdventureWorks2014.sqlite multiple pairs removed for ('BillOfMaterials', 'Product')
AdventureWorks2014.sqlite multiple pairs removed for ('CurrencyRate', 'Currency')
AdventureWorks2014.sqlite multiple pairs removed for ('Product', 'UnitMeasure')
AdventureWorks2014.sqlite.SalesOrderDetail removed composite primary key: 0.1
AdventureWorks2014.sqlite multiple pairs removed for ('SalesOrderHeader', 'Address')
stats_CEB.sqlite multiple pairs removed for ('postLinks', 'posts')
stats_CEB.sqlite multiple pairs removed for ('posts', 'users')
CORA.sqlite multiple pairs removed for ('cites', 'paper')
WebKP.sqlite multiple pairs removed for ('cites', 'webpage')
Credit.sqlite multiple pairs removed for ('charge', 'member')
Grants.sqlite.institution_awards removed composite primary key: 0.1
mutagenesis.sqlite multiple pairs removed for ('bond', 'atom')
SAT.sqlite multiple pairs removed for ('succ', 'time')
UW_std.sqlite multiple pai

genes.sqlite multiple pairs removed for ('Interactions', 'Classification')
Elti.sqlite multiple pairs removed for ('brother', 'person')
Elti.sqlite multiple pairs removed for ('daughter', 'person')
Elti.sqlite multiple pairs removed for ('elti', 'person')
Elti.sqlite multiple pairs removed for ('father', 'person')
Elti.sqlite multiple pairs removed for ('husband', 'person')
Elti.sqlite multiple pairs removed for ('mother', 'person')
Elti.sqlite multiple pairs removed for ('sister', 'person')
Elti.sqlite multiple pairs removed for ('son', 'person')
Elti.sqlite multiple pairs removed for ('target', 'person')
Elti.sqlite multiple pairs removed for ('wife', 'person')
nations.sqlite multiple pairs removed for ('relation', 'country')
Dunur.sqlite multiple pairs removed for ('aunt', 'person')
Dunur.sqlite multiple pairs removed for ('brother', 'person')
Dunur.sqlite multiple pairs removed for ('daughter', 'person')
Dunur.sqlite multiple pairs removed for ('dunur', 'person')
Dunur.sqlite multi

877

In [23]:
# Use the CTU foreign keys as the working list for the cells below.
foreign_keys_list = foreign_keys_list_ctu

In [24]:
# deduplicate: keep only the first foreign key pair for every distinct
# combination of both table strings and both key columns

dedup_fields = ("table_string", "foreign_table_string", "column", "referred_column")
table_combinations = set()
foreign_keys_list_deduped = []
for foreign_key_pair in foreign_keys_list:
    table_combination = tuple(foreign_key_pair[field] for field in dedup_fields)
    if table_combination in table_combinations:
        # an identical pair was already kept
        continue
    table_combinations.add(table_combination)
    foreign_keys_list_deduped.append(foreign_key_pair)

In [25]:
# remove overlaps with eval dataset 

def clean_overlapping_foreign_key_pairs(foreign_key_pairs, foreign_key_pairs_others):
    """Drop foreign key pairs whose two table schemas also occur in another collection.

    Two pairs overlap when both the referencing table and the referenced table
    render to the same string with ``make_table_string(..., True)``, i.e. the
    comparison is lowercase and does not depend on column order. The key
    columns themselves are not compared.

    Args:
        foreign_key_pairs: Foreign key dicts to filter; each needs the keys
            ``table``, ``table_columns``, ``foreign_table`` and
            ``foreign_table_columns``.
        foreign_key_pairs_others: Foreign key dicts (same keys) that must not
            overlap with the result.

    Returns:
        A new list with the entries of ``foreign_key_pairs`` that do not
        overlap with any entry of ``foreign_key_pairs_others``, in their
        original order. The counts before and after are printed.
    """
    print(f"before: {len(foreign_key_pairs)}")

    def schema_signature(foreign_key_pair):
        # Sorted, lowercase schema strings of both tables of one pair.
        return (
            make_table_string(foreign_key_pair["table"], foreign_key_pair["table_columns"], True),
            make_table_string(foreign_key_pair["foreign_table"], foreign_key_pair["foreign_table_columns"], True),
        )

    # Build the signatures of the other collection once, so each pair needs only
    # a set lookup instead of a rescan (and re-rendering) of every other pair.
    signatures_in_others = {schema_signature(other) for other in foreign_key_pairs_others}

    foreign_key_pairs_cleaned = [
        foreign_key_pair
        for foreign_key_pair in foreign_key_pairs
        if schema_signature(foreign_key_pair) not in signatures_in_others
    ]

    print(f"after: {len(foreign_key_pairs_cleaned)}, removed {len(foreign_key_pairs)-len(foreign_key_pairs_cleaned)}")
    return foreign_key_pairs_cleaned

In [26]:
# Remove CTU foreign key pairs whose two table schemas also occur in Spider,
# so the resulting data does not overlap with it.
foreign_keys_list_deduped = clean_overlapping_foreign_key_pairs(foreign_keys_list_deduped, foreign_keys_list_spider)

before: 725
after: 706, removed 19


In [27]:
# create training data for foreign key detection from schemapile
import json
# When True, the two table strings are listed in sorted order in the prompt;
# otherwise the referencing table comes first.
alphabetically_sorted = False
instruction_pairs = []

question_template = "You are given the following SQL database tables: \n{tables}\nOutput a json string with the following schema {{table, column, referencedTable, referencedColumn}} that contains the foreign key relationship between the two tables."
for foreign_key_pair in foreign_keys_list_deduped:
    table = foreign_key_pair["table"].lower()
    foreign_table = foreign_key_pair["foreign_table"].lower()

    # Both tables of the pair, one schema string per line of the prompt.
    table_strings = [foreign_key_pair["table_string"], foreign_key_pair["foreign_table_string"]]
    if alphabetically_sorted:
        table_strings = sorted(table_strings)
    prompt = question_template.format(tables="\n".join(table_strings))

    # Expected answer: the lowercase key relationship as a JSON object.
    ground_truth = json.dumps({
        "table": table,
        "column": foreign_key_pair["column"].lower(),
        "referencedTable": foreign_table,
        "referencedColumn": foreign_key_pair["referred_column"].lower(),
    })

    user_message = {"content": prompt, "role": "user"}
    assistant_message = {"content": ground_truth, "role": "assistant"}
    instruction_pairs.append([user_message, assistant_message])

In [29]:
import os
import pandas as pd
# Write the instruction pairs as a single-column ("messages") parquet file.
output_path = "foreign_keys_instruction_data_ctu"
os.makedirs(output_path, exist_ok=True)
train = pd.DataFrame({"messages": instruction_pairs})
# Reset to a fresh default index without keeping the old one as a column.
train_out = train.reset_index(drop=True)
train_out.to_parquet(f"{output_path}/train.parquet")