In [None]:
import gzip
import json

# Read the gzip-compressed schemapile dump and parse it as JSON.
with gzip.open("../../data/schemapile.json.gz", 'r') as f:
    schemapile = json.loads(str(f.read(), 'utf-8'))

# Maps each schemapile database to the table strings of its tables;
# starts out empty here.
schemapile_table_strings = dict()

In [53]:
# helper functions

def get_columns_and_types(table_name, cursor):
    """Return the column names and declared column types of a SQLite table.

    Args:
        table_name: Name of the table to inspect.
        cursor: A sqlite3 cursor on the database that contains the table.

    Returns:
        A tuple ``(names, types)`` of two equally long lists, taken from
        fields 1 and 2 of each ``PRAGMA table_info`` row. Both lists are
        empty when the pragma returns no rows.
    """
    # Double embedded single quotes so that a table name such as "it's"
    # stays one string literal instead of breaking the statement.
    escaped_name = table_name.replace("'", "''")
    rows = cursor.execute(f"PRAGMA table_info('{escaped_name}');").fetchall()
    names = [row[1] for row in rows]
    types = [row[2] for row in rows]
    return names, types

def make_table_string(table_name, columns, sort=False):
    """Render a table as the lower-cased string ``name(col1, col2, ...)``.

    Args:
        table_name: Name of the table.
        columns: Column names of the table.
        sort: When true, the columns are sorted before they are lower-cased,
            so the order follows the original spelling.

    Returns:
        The table string.
    """
    ordered_columns = sorted(columns) if sort else columns
    column_list = ", ".join(name.lower() for name in ordered_columns)
    return f"{table_name.lower()}({column_list})"

# Build, for every schemapile database, the list of column-sorted table
# strings of its tables.
for database, database_schema in schemapile.items():
    sorted_table_strings = []
    for table_name, table_schema in database_schema["TABLES"].items():
        columns = list(table_schema["COLUMNS"].keys())
        sorted_table_strings.append(make_table_string(table_name, columns, True))
    schemapile_table_strings[database] = sorted_table_strings
        
def occurs_in_schemapile(table_string_sorted, foreign_table_string_sorted):
    """Check whether both table strings occur together in one schemapile database.

    Args:
        table_string_sorted: Table string of the referencing table.
        foreign_table_string_sorted: Table string of the referenced table.

    Returns:
        True if at least one entry of ``schemapile_table_strings`` contains
        both strings, otherwise False.
    """
    return any(
        table_string_sorted in table_strings and foreign_table_string_sorted in table_strings
        for table_strings in schemapile_table_strings.values()
    )

In [69]:
# collect foreign keys from databases

import os
import sqlite3

# Upper bound on rows * columns of a single table; collect_foreign_keys flags
# a table as too large for valentine when it exceeds this value.
valentine_max_values = 42*23255 # (largest table in valentine: miller2_vertical_70_ac5_ev_source.csv)
# Upper bound on the combined length (in characters) of the two table strings
# of a foreign key pair; longer pairs are flagged as too large for the LLM.
llm_max_chars = 8000

def collect_foreign_keys(databases_path):
    """Collect single-column foreign keys from the SQLite databases in a directory.

    Every entry of ``databases_path`` is opened as a SQLite database. For each
    table, the foreign keys reported by ``PRAGMA foreign_key_list`` are kept if

    * the key consists of exactly one column (composite keys are dropped),
    * the referenced table, the referenced column and the referencing column
      are all found in the schema information collected for that database, and
    * it is the only kept foreign key between its two tables.

    Args:
        databases_path: Directory that contains the SQLite database files.

    Returns:
        A list with one dict per kept foreign key. Each dict has the keys
        "database", "table", "table_string", "table_columns", "table_types",
        "column", "foreign_table", "foreign_table_string",
        "foreign_table_columns", "foreign_table_types", "referred_column",
        "occurs_in_schemapile", "too_large_for_valentine" and
        "too_large_for_llm".
    """

    def quote_identifier(name):
        # Double-quoted SQL identifier; embedded double quotes are doubled so
        # that unusual table names cannot break the statement.
        return '"' + name.replace('"', '""') + '"'

    def quote_literal(name):
        # Single-quoted SQL string literal with embedded single quotes doubled.
        return "'" + name.replace("'", "''") + "'"

    foreign_keys_list = []
    for database in os.listdir(databases_path):
        con = sqlite3.connect(f"{databases_path}/{database}")
        # try/finally: the connection is released even when a query fails.
        try:
            cursor = con.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()
            table_columns = {}
            table_types = {}
            table_too_large_for_valentine = {}

            # collect schema information for lookup
            for table in tables:
                table_name = table[0]
                columns, types = get_columns_and_types(table_name, cursor)
                table_columns[table_name] = columns
                table_types[table_name] = types
                rows = cursor.execute(f"SELECT COUNT(*) FROM {quote_identifier(table_name)};").fetchall()[0][0]
                table_too_large_for_valentine[table_name] = rows * len(columns) > valentine_max_values
                if table_too_large_for_valentine[table_name]:
                    print(f"{database}.{table_name} too large - columns: {len(columns)}, rows: {rows}")

            # collect foreign keys
            for table in tables:
                table_name = table[0]
                cursor.execute(f"PRAGMA foreign_key_list({quote_literal(table_name)});")
                foreign_keys_by_id = {}
                # Every foreign key id seen so far. A second row with the same
                # id means the key is composite. Remembering the id (instead
                # of only deleting the dict entry) keeps a third column, or a
                # column following an invalid first one, from being added
                # back as if it were a single-column key.
                seen_fk_ids = set()
                for fk_id, seq_id, foreign_table, column, referred_column, _, _, _ in cursor.fetchall():
                    if fk_id in seen_fk_ids:
                        print(f"{database}.{table_name} removed composite primary key: {fk_id}.{seq_id}")
                        foreign_keys_by_id.pop(fk_id, None)
                        continue
                    seen_fk_ids.add(fk_id)

                    # skip keys whose tables or columns are missing from the collected schema
                    if not (foreign_table in table_columns
                            and referred_column in table_columns[foreign_table]
                            and column in table_columns[table_name]):
                        continue

                    foreign_key_pair_occurs_in_schemapile = occurs_in_schemapile(
                        make_table_string(table_name, table_columns[table_name], True),
                        make_table_string(foreign_table, table_columns[foreign_table], True))
                    if foreign_key_pair_occurs_in_schemapile:
                        print(f"{database}.{table_name}/{foreign_table} occurs in schemapile")

                    table_string = make_table_string(table_name, table_columns[table_name])
                    foreign_table_string = make_table_string(foreign_table, table_columns[foreign_table])
                    too_large_for_llm = len(table_string) + len(foreign_table_string) > llm_max_chars
                    if too_large_for_llm:
                        print(f"{database}.{table_name}/{foreign_table} schema too large for llm: {len(table_string)}+{len(foreign_table_string)}>{llm_max_chars}")

                    foreign_keys_by_id[fk_id] = {"database": database,
                                                 "table": table_name,
                                                 "table_string": table_string,
                                                 "table_columns": table_columns[table_name],
                                                 "table_types": table_types[table_name],
                                                 "column": column,
                                                 "foreign_table": foreign_table,
                                                 "foreign_table_string": foreign_table_string,
                                                 "foreign_table_columns": table_columns[foreign_table],
                                                 "foreign_table_types": table_types[foreign_table],
                                                 "referred_column": referred_column,
                                                 "occurs_in_schemapile": foreign_key_pair_occurs_in_schemapile,
                                                 "too_large_for_valentine": table_too_large_for_valentine[table_name] or table_too_large_for_valentine[foreign_table],
                                                 "too_large_for_llm": too_large_for_llm}

                # don't include foreign key pairs if there are multiple between two tables
                foreign_keys_by_table_combination = {}
                # Remember every combination seen so that a third key between
                # the same two tables is dropped as well instead of re-added.
                seen_table_combinations = set()
                for foreign_key_pair in foreign_keys_by_id.values():
                    table_combination = (foreign_key_pair["table"], foreign_key_pair["foreign_table"])
                    if table_combination in seen_table_combinations:
                        print(f"{foreign_key_pair['database']} multiple pairs removed for {table_combination}")
                        foreign_keys_by_table_combination.pop(table_combination, None)
                    else:
                        seen_table_combinations.add(table_combination)
                        foreign_keys_by_table_combination[table_combination] = foreign_key_pair

                foreign_keys_list.extend(foreign_keys_by_table_combination.values())
        finally:
            con.close()
    return foreign_keys_list

In [None]:
# Collect the foreign key pairs of all databases found under spider_dbs/.
spider_foreign_keys = collect_foreign_keys("spider_dbs/")

In [None]:
# Collect the foreign key pairs of all databases found under bird_dbs/.
bird_foreign_keys = collect_foreign_keys("bird_dbs/")

In [None]:
# Collect the foreign key pairs of all databases found under ctu_dbs/.
ctu_foreign_keys = collect_foreign_keys("ctu_dbs/")

In [155]:
# filter foreign keys that should be excluded for valentine 

def print_filter_stats(foreign_key_pairs):
    """Print, for each exclusion flag, how many pairs have it set.

    One line per flag is printed in the form ``<flag>: <flagged>/<total>``.

    Args:
        foreign_key_pairs: List of foreign key pair dicts that carry the flags
            "occurs_in_schemapile", "too_large_for_valentine" and
            "too_large_for_llm".
    """
    total = len(foreign_key_pairs)
    for flag in ("occurs_in_schemapile", "too_large_for_valentine", "too_large_for_llm"):
        flagged = sum(pair[flag] for pair in foreign_key_pairs)
        print(f"{flag}: {flagged}/{total}")
    

def filter_foreign_key_pairs(foreign_key_pairs, valentine=False):
    """Drop the foreign key pairs that are flagged for exclusion.

    The flag statistics of the input are printed first, followed by the
    number of pairs that were kept.

    Args:
        foreign_key_pairs: List of foreign key pair dicts.
        valentine: When true, pairs flagged "too_large_for_valentine" are
            dropped in addition to those flagged "occurs_in_schemapile" or
            "too_large_for_llm".

    Returns:
        A new list with the remaining pairs in their original order.
    """
    print_filter_stats(foreign_key_pairs)

    filtered_list = []
    for pair in foreign_key_pairs:
        if pair["occurs_in_schemapile"]:
            continue
        if valentine and pair["too_large_for_valentine"]:
            continue
        if pair["too_large_for_llm"]:
            continue
        filtered_list.append(pair)

    print(f"kept: {len(filtered_list)}/{len(foreign_key_pairs)}")
    return filtered_list

In [None]:
# valentine is left at its default (False), so the too_large_for_valentine
# flag does not exclude any pair here.
spider_foreign_keys_filtered = filter_foreign_key_pairs(spider_foreign_keys)

In [None]:
# valentine is left at its default (False), so the too_large_for_valentine
# flag does not exclude any pair here.
bird_foreign_keys_filtered = filter_foreign_key_pairs(bird_foreign_keys)

In [None]:
# valentine is left at its default (False), so the too_large_for_valentine
# flag does not exclude any pair here.
ctu_foreign_keys_filtered = filter_foreign_key_pairs(ctu_foreign_keys)

In [164]:
import json
# Write the filtered spider foreign key pairs to a JSON file.
with open("foreign_keys_filtered_spider.json", "w+") as f:
    json.dump(spider_foreign_keys_filtered, f)

In [165]:
import json
# Write the filtered bird foreign key pairs to a JSON file.
with open("foreign_keys_filtered_bird.json", "w+") as f:
    json.dump(bird_foreign_keys_filtered, f)

In [None]:
import json
# Write the filtered ctu foreign key pairs to a JSON file.
with open("foreign_keys_filtered_ctu.json", "w+") as f:
    json.dump(ctu_foreign_keys_filtered, f)

In [None]:
# create evaluation data for valentine

def map_data_type(data_type):
    """Return the lower-cased form of a column type name."""
    lowered_type = data_type.lower()
    return lowered_type

# Optional:
# - make column names and table names lowercase
# - determine types and make type mapping

import csv
import pandas as pd

def create_valentine_data(databases_path, foreign_key_pairs, valentine_path):
    """Export foreign key pairs as valentine datasets.

    For every pair that is not flagged "too_large_for_valentine", a directory
    ``<valentine_path>/<database>_<table>_<foreign_table>`` is created (dots
    in the database file name are replaced by underscores, table names are
    lower-cased) and filled with

    * ``<database>_<table>_<foreign_table>_mapping.json`` - the expected match,
    * ``<table>_source.json`` / ``<foreign_table>_target.json`` - column name
      to declared type, both lower-cased,
    * ``<table>_source.csv`` / ``<foreign_table>_target.csv`` - the full table
      contents with lower-cased column headers.

    A pair whose directory cannot be created (for example because it already
    exists) is reported on stdout and skipped.

    Args:
        databases_path: Directory that contains the SQLite database files.
        foreign_key_pairs: Foreign key pair dicts with the keys
            "too_large_for_valentine", "database", "table", "column",
            "foreign_table", "referred_column", "table_columns",
            "table_types", "foreign_table_columns" and "foreign_table_types".
        valentine_path: Directory below which the datasets are written.
    """

    def quote_identifier(name):
        # Double-quoted SQL identifier; embedded double quotes are doubled so
        # that table names containing quote characters do not break the query.
        return '"' + name.replace('"', '""') + '"'

    for foreign_key_pair in foreign_key_pairs:
        if foreign_key_pair["too_large_for_valentine"]:
            continue

        database = foreign_key_pair['database'].replace(".","_")
        table = foreign_key_pair["table"].lower()
        column = foreign_key_pair["column"].lower()
        foreign_table = foreign_key_pair["foreign_table"].lower()
        referred_column = foreign_key_pair["referred_column"].lower()
        dataset_name = f"{database}_{table}_{foreign_table}"
        dataset_path = f"{valentine_path}/{dataset_name}"
        source_name = table+"_source"
        target_name = foreign_table+"_target"

        # ground truth: the single expected column match
        mapping = {"matches": [
            {
                "source_table": source_name,
                "source_column": column,
                "target_table": target_name,
                "target_column": referred_column
            }
        ]}

        source_schema = {column_name.lower(): {"type": column_type.lower()}
                         for column_name, column_type
                         in zip(foreign_key_pair["table_columns"], foreign_key_pair["table_types"])}
        target_schema = {column_name.lower(): {"type": column_type.lower()}
                         for column_name, column_type
                         in zip(foreign_key_pair["foreign_table_columns"], foreign_key_pair["foreign_table_types"])}

        # An existing directory means the dataset was already written (or its
        # name collides with another pair); report it and move on.
        try:
            os.makedirs(dataset_path)
        except OSError as e:
            print(e)
            continue

        with open(f"{dataset_path}/{dataset_name}_mapping.json", "w") as f:
            json.dump(mapping, f, indent=4)

        with open(f"{dataset_path}/{source_name}.json", "w") as f:
            json.dump(source_schema, f, indent=4)

        with open(f"{dataset_path}/{target_name}.json", "w") as f:
            json.dump(target_schema, f, indent=4)

        # Open the database only once it is actually needed and always close
        # it again, so that no connection is leaked per pair.
        con = sqlite3.connect(f"{databases_path}/{foreign_key_pair['database']}")
        try:
            exports = ((foreign_key_pair["table"], source_name),
                       (foreign_key_pair["foreign_table"], target_name))
            for original_table_name, file_stem in exports:
                table_df = pd.read_sql_query(f"SELECT * FROM {quote_identifier(original_table_name)}", con)
                table_df = table_df.rename(columns=str.lower)
                table_df.to_csv(f"{dataset_path}/{file_stem}.csv", index=False, header=True, quoting=csv.QUOTE_MINIMAL)
        finally:
            con.close()

In [None]:
# Export the filtered spider pairs as valentine datasets under valentine/datasets/spider/.
create_valentine_data("spider_dbs/", spider_foreign_keys_filtered, "valentine/datasets/spider/")

In [None]:
# Export the filtered bird pairs as valentine datasets under valentine/datasets/bird/.
create_valentine_data("bird_dbs/", bird_foreign_keys_filtered, "valentine/datasets/bird/")

In [None]:
# Export the filtered ctu pairs as valentine datasets under valentine/datasets/ctu/.
create_valentine_data("ctu_dbs/", ctu_foreign_keys_filtered, "valentine/datasets/ctu/")

In [None]:
# create prompts and reference responses for LLM matcher

import json
# Reload the filtered foreign key pairs from their JSON files. Each file is
# opened in a with-block so that its handle is closed right after parsing.
with open("foreign_keys_filtered_spider.json") as f:
    spider_foreign_keys_filtered = json.load(f)
with open("foreign_keys_filtered_bird.json") as f:
    bird_foreign_keys_filtered = json.load(f)
with open("foreign_keys_filtered_ctu.json") as f:
    ctu_foreign_keys_filtered = json.load(f)

In [None]:
def generate_prompts_and_responses(foreign_keys_filtered, alphabetically_sorted=False):
    """Build one LLM prompt and its ground-truth answer per foreign key pair.

    Args:
        foreign_keys_filtered: Foreign key pair dicts with the keys
            "database", "table", "table_string", "column", "foreign_table",
            "foreign_table_string" and "referred_column".
        alphabetically_sorted: When true, the two table strings are listed in
            sorted order in the prompt; otherwise the referencing table comes
            first.

    Returns:
        A dict keyed by ``<database>_<table>_<foreign_table>`` (dots in the
        database name replaced by underscores, table names lower-cased). Each
        value is ``{"prompt": ..., "foreign_key": ...}`` where "foreign_key"
        is a JSON string holding the lower-cased table and column names.
    """
    question_template = "You are given the following SQL database tables: \n{tables}\nOutput a json string with the following schema {{table, column, referencedTable, referencedColumn}} that contains the foreign key relationship between the two tables."

    prompts_ground_truth = {}
    for pair in foreign_keys_filtered:
        source_table = pair["table"].lower()
        referenced_table = pair["foreign_table"].lower()
        dataset_name = "_".join([pair["database"].replace(".", "_"), source_table, referenced_table])

        table_strings = [pair["table_string"], pair["foreign_table_string"]]
        if alphabetically_sorted:
            table_strings.sort()

        answer = {
            "table": source_table,
            "column": pair["column"].lower(),
            "referencedTable": referenced_table,
            "referencedColumn": pair["referred_column"].lower(),
        }
        prompts_ground_truth[dataset_name] = {
            "prompt": question_template.format(tables="\n".join(table_strings)),
            "foreign_key": json.dumps(answer),
        }

    return prompts_ground_truth

In [None]:
# alphabetically_sorted keeps its default (False): the referencing table is listed first in each prompt.
prompts_ground_truth_spider = generate_prompts_and_responses(spider_foreign_keys_filtered)

In [None]:
# alphabetically_sorted keeps its default (False): the referencing table is listed first in each prompt.
prompts_ground_truth_bird = generate_prompts_and_responses(bird_foreign_keys_filtered)

In [None]:
# alphabetically_sorted keeps its default (False): the referencing table is listed first in each prompt.
prompts_ground_truth_ctu = generate_prompts_and_responses(ctu_foreign_keys_filtered)

In [None]:
# Write the spider prompts and ground-truth answers to a JSON file.
with open("prompts_ground_truth_spider.json", "w+") as f:
    json.dump(prompts_ground_truth_spider, f)

In [None]:
# Write the bird prompts and ground-truth answers to a JSON file.
with open("prompts_ground_truth_bird.json", "w+") as f:
    json.dump(prompts_ground_truth_bird, f)

In [None]:
# Write the ctu prompts and ground-truth answers to a JSON file.
with open("prompts_ground_truth_ctu.json", "w+") as f:
    json.dump(prompts_ground_truth_ctu, f)