In [None]:
import os
import json
import pandas as pd
import sqlite3
import numpy as np
import csv
import shutil

In [None]:
# Per-database folders of table-metadata JSON files (one file per table).
DB_METADATAS_PATH = "Spider2/spider2-lite/resource/databases/sqlite/"

# Folder holding the SQLite databases and local-map.jsonl. The databases are
# not bundled: download them manually (see SPIDER2_LOCAL_DB_LINK) and unzip
# them here.
DATABASES_PATH = "Spider2/spider2-lite/resource/databases/spider2-localdb/"

# Root folder for the per-table CSV dumps (one sub-folder per database).
DB_CSVS_BASE_PATH = "Spider2/spider2-lite/resource/databases/csv_dbs"

# Spider2-lite task file, one JSON record per line.
EVALUATION_SET_PATH = "Spider2/spider2-lite/spider2-lite.jsonl"

# NOTE(review): the uuid/at query parameters look session-specific and may
# expire; refresh the link if the download stops working.
SPIDER2_LOCAL_DB_LINK = "https://drive.usercontent.google.com/download?id=1coEVsCZq-Xvj9p2TnhBFoFTsY-UoYGmG&export=download&authuser=0&confirm=t&uuid=e4894821-9b03-4a4a-b574-9e931c7f6497&at=AEz70l4CupjM1wWNkGFVtYAST2Xs%3A1743423729461"

# Gold execution results (CSV files) used as expected task outputs.
GOLD_RESULT_PATH = "Spider2/spider2-lite/evaluation_suite/gold/exec_result"

In [None]:
def get_local_tasks(tasks_jsonl=None):
    """Load the tasks whose ``instance_id`` starts with ``"local"``.

    :param tasks_jsonl: Path to a JSON-Lines task file (one JSON object per
        line). Falls back to ``EVALUATION_SET_PATH`` when falsy.
    :return: List of matching task records (dicts), in file order.
    """
    tasks_jsonl = tasks_jsonl or EVALUATION_SET_PATH
    local_tasks = []
    # Open the resolved path so an explicit argument is actually honoured.
    with open(tasks_jsonl, "r", encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            if not line:
                # Tolerate blank lines, e.g. a trailing newline at end of file.
                continue
            record = json.loads(line)
            # `or ""` also covers an explicit null instance_id.
            if str(record.get("instance_id") or "").startswith("local"):
                local_tasks.append(record)
    return local_tasks
    

In [None]:
def get_relevant_databases():
    """Return the set of database names referenced by the local-task map.

    ``<DATABASES_PATH>/local-map.jsonl`` is parsed as one JSON object whose
    values are database names; the keys are ignored here.
    """
    map_file = os.path.join(DATABASES_PATH, "local-map.jsonl")
    with open(map_file) as handle:
        task_to_db = json.load(handle)
    return {db for db in task_to_db.values()}

def get_task_expected_output(task_id, rows_limit=10, as_dict=False, gold_dir=None):
    """Load the gold execution-result table(s) for one task.

    Gold files are expected to be named ``<task_id>.csv`` or
    ``<task_id>_<suffix>.csv`` (alternative accepted answers); every match
    is returned, in sorted file-name order.

    :param task_id: Spider2 instance id, e.g. ``"local002"``.
    :param rows_limit: Keep only the first ``rows_limit`` rows when positive;
        zero or a negative value keeps every row.
    :param as_dict: Return each table as a list of row dicts instead of a DataFrame.
    :param gold_dir: Folder with the gold CSVs; defaults to ``GOLD_RESULT_PATH``.
    :return: Dict mapping file stem -> DataFrame (or list of records).
    """
    gold_dir = gold_dir or GOLD_RESULT_PATH
    result_mapping = {}
    for file_name in sorted(os.listdir(gold_dir)):
        stem, ext = os.path.splitext(file_name)
        if ext != ".csv":
            continue
        # Exact id or "<id>_<suffix>" only, so "local01" does not also
        # pick up "local010.csv".
        if stem != task_id and not stem.startswith(f"{task_id}_"):
            continue
        df = pd.read_csv(os.path.join(gold_dir, file_name))
        if rows_limit > 0:
            df = df.head(rows_limit)
        result_mapping[stem] = df.to_dict(orient="records") if as_dict else df
    return result_mapping

In [None]:
# Quick look at the gold output table(s) of a single task.
get_task_expected_output(task_id="local002")

In [None]:
def _quote_sqlite_identifier(name):
    """Quote ``name`` as an SQLite identifier, doubling embedded double quotes."""
    return '"' + str(name).replace('"', '""') + '"'


def _export_sqlite_tables(conn, metadata_dir, output_dir):
    """Export each table named by a ``*.json`` file in ``metadata_dir`` to ``output_dir/<table_name>.csv``."""
    for filename in sorted(os.listdir(metadata_dir)):
        if not filename.endswith(".json"):
            continue
        with open(os.path.join(metadata_dir, filename), "r", encoding="utf-8") as f:
            table_name = json.load(f)["table_name"]
        output_csv = os.path.join(output_dir, f"{table_name}.csv")
        # Quote the identifier so names with spaces or reserved words still work.
        query = f"SELECT * FROM {_quote_sqlite_identifier(table_name)};"
        pd.read_sql_query(query, conn).to_csv(output_csv, index=False)
        print(f"Exported {table_name} to {output_csv}")


def dump_tables_to_csv(folders_path, dbs_path, output_folder, force=False, db_set=None):
    """
    Export every table of the task databases from SQLite to CSV.

    The databases processed are the values of ``<dbs_path>/local-map.jsonl``
    (a JSON object mapping task id -> database name). For each database
    ``<db>``, every ``*.json`` file in ``<folders_path>/<db>`` must carry a
    ``"table_name"``. That table is read from ``<dbs_path>/<db>.sqlite`` and
    written to ``<output_folder>/<db>/<table_name>.csv``. A database whose
    ``.sqlite`` file or metadata folder is missing is skipped with a message.

    :param folders_path: Folder with one sub-folder of table-metadata JSON files per database.
    :param dbs_path: Folder with the ``<db>.sqlite`` files and ``local-map.jsonl``.
    :param output_folder: Root folder for the CSVs (one sub-folder per database).
    :param force: Delete an existing per-database output folder before exporting.
    :param db_set: Optional collection of database names; when non-empty, other databases are skipped.
    """
    with open(os.path.join(dbs_path, "local-map.jsonl"), encoding="utf-8") as f:
        task_db_map = json.load(f)
    for db_name in sorted(set(task_db_map.values())):
        if db_set and db_name not in db_set:
            continue
        print(db_name)
        db_folder_path = os.path.join(folders_path, db_name)
        db_path = os.path.join(dbs_path, f"{db_name}.sqlite")
        # sqlite3.connect() would silently create an empty file for a missing DB.
        if not os.path.isfile(db_path):
            print(f"Skipping {db_name}: database file not found at {db_path}")
            continue
        if not os.path.isdir(db_folder_path):
            print(f"Skipping {db_name}: no table metadata folder at {db_folder_path}")
            continue

        db_output_path = os.path.join(output_folder, db_name)
        if force and os.path.exists(db_output_path):
            shutil.rmtree(db_output_path)
        os.makedirs(db_output_path, exist_ok=True)

        print(db_path)
        conn = sqlite3.connect(db_path)
        try:
            _export_sqlite_tables(conn, db_folder_path, db_output_path)
        finally:
            conn.close()
            print("Database connection closed.")


In [None]:
def get_all_data_types(folders_path, db_set=None):
    """Collect the distinct lower-cased column types declared in table metadata.

    :param folders_path: Folder with one sub-folder of ``*.json`` table metadata per database.
    :param db_set: Databases to include. ``None`` (default) uses
        ``get_relevant_databases()``; an empty collection disables filtering.
    :return: Set of lower-cased ``column_types`` strings.
    """
    if db_set is None:
        db_set = get_relevant_databases()
    datatypes_set = set()
    for db_name in sorted(os.listdir(folders_path)):
        if db_set and db_name not in db_set:
            continue
        db_folder_path = os.path.join(folders_path, db_name)
        # Stray files next to the database folders are not databases.
        if not os.path.isdir(db_folder_path):
            continue
        for filename in sorted(os.listdir(db_folder_path)):
            if not filename.endswith(".json"):
                continue
            file_path = os.path.join(db_folder_path, filename)
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    table_metadata = json.load(f)
            except (OSError, ValueError) as e:
                # Report and continue instead of discarding everything collected so far.
                print(f"Skipping {file_path}: {e}")
                continue
            if not isinstance(table_metadata, dict):
                continue
            datatypes_set.update(
                str(col_type).lower()
                for col_type in table_metadata.get("column_types") or []
                if col_type is not None
            )
    return datatypes_set

In [None]:
def get_table_metadata(database_name, table_name):
    """Load ``<DB_METADATAS_PATH>/<database_name>/<table_name>.json``.

    Returns the parsed JSON. If the file cannot be opened or parsed, the
    error is printed and ``None`` is returned.
    """
    metadata_path = os.path.join(DB_METADATAS_PATH, database_name, f"{table_name}.json")
    metadata = None
    try:
        with open(metadata_path, "r", encoding="utf-8") as fh:
            metadata = json.load(fh)
    except Exception as err:
        print(err)
    return metadata

In [None]:
def get_database_table_names_list(database_name):
    """List the ``table_name`` field of every ``*.json`` metadata file of a database.

    Files are read from ``<DB_METADATAS_PATH>/<database_name>``. If the folder
    cannot be listed, or any metadata file cannot be read, the error is
    printed and ``None`` is returned.
    """
    metadata_dir = os.path.join(DB_METADATAS_PATH, database_name)
    try:
        names = []
        for entry in os.listdir(metadata_dir):
            if not entry.endswith(".json"):
                continue
            with open(os.path.join(metadata_dir, entry), "r", encoding="utf-8") as fh:
                names.append(json.load(fh)["table_name"])
        return names
    except Exception as err:
        print(err)
        return None

In [None]:
def get_table_dtype_map(database_name, table_name, sql=False, mappings_path="spider_dtype_mappings.json"):
    """Map each column of a table to its declared SQL type or to a pandas dtype.

    :param database_name: Database whose metadata folder holds the table.
    :param table_name: Table whose metadata JSON is read via ``get_table_metadata``.
    :param sql: If true, return the declared column types unchanged.
    :param mappings_path: JSON file mapping lower-cased declared types to pandas
        dtype names; types it does not list map to ``"object"``.
    :return: ``{column_name: type}``, or ``None`` when the table metadata or
        the mappings file cannot be loaded.
    """
    table_metadata = get_table_metadata(database_name, table_name)
    if not isinstance(table_metadata, dict):
        # get_table_metadata returns None for a missing/unreadable file;
        # calling .get() on it would raise AttributeError.
        return None
    dtype_map = dict(
        zip(
            table_metadata.get("column_names", []),
            table_metadata.get("column_types", []),
        )
    )
    if sql:
        return dtype_map
    try:
        with open(mappings_path, "r", encoding="utf-8") as f:
            mappings = json.load(f)
    except (OSError, ValueError) as e:
        print(e)
        return None
    # str() guards against null column types in the metadata.
    return {
        col: mappings.get(str(col_type).lower(), "object")
        for col, col_type in dtype_map.items()
    }

def get_database_dtype_map(database_name, sql=False):
    """Build ``{table_name: column dtype map}`` for every table of a database.

    Each value is whatever ``get_table_dtype_map`` returns for that table.
    An empty dict is returned when no table names can be listed.
    """
    table_names = get_database_table_names_list(database_name=database_name)
    result = {}
    for tname in table_names or []:
        result[tname] = get_table_dtype_map(database_name, tname, sql=sql)
    return result
    

In [None]:
def read_ground_truth_sql(task_id, sql_dir="Spider2/spider2-lite/evaluation_suite/gold/sql"):
    """Return the gold SQL text for a task, or ``None`` if no gold file exists.

    :param task_id: Task id such as ``"local002"``. A task record (dict) is
        also accepted, in which case its ``"instance_id"`` is used.
    :param sql_dir: Folder holding the ``<task_id>.sql`` files.
    :return: File contents as a string, or ``None``.
    """
    if isinstance(task_id, dict):
        # Building the path from a whole record would never match a file.
        task_id = task_id.get("instance_id")
    if not task_id:
        return None
    sql_path = os.path.join(sql_dir, f"{task_id}.sql")
    if not os.path.exists(sql_path):
        return None
    with open(sql_path, "r", encoding="utf-8") as f:
        return f.read()

In [None]:
# Example: per-table column dtype maps for one database.
get_database_dtype_map(database_name="E_commerce")

In [None]:
def load_csv_database(database_name, rows_limit=10, as_dict=False, csv_base_path=None):
    """
    Load a CSV-dumped database into a dict of ``{table_name: table}``.

    The database folder is looked up under ``csv_base_path`` in this order:
    the name as given, then with ``-`` replaced by ``_``, then with ``_``
    replaced by ``-``. Only the database name is rewritten, never the base
    path. Values are parsed with pandas' default type inference; metadata
    dtypes are not applied.

    :param database_name: Name of the database folder.
    :param rows_limit: Keep the first ``rows_limit`` rows when ``>= 0``;
        a negative value keeps every row.
    :param as_dict: Return each table as a list of row dicts instead of a DataFrame.
    :param csv_base_path: Root folder of the CSV dumps; defaults to ``DB_CSVS_BASE_PATH``.
    :return: Dict of tables keyed by CSV file stem, or ``None`` if no folder matches.
    """
    base_path = csv_base_path or DB_CSVS_BASE_PATH
    name_variants = [
        database_name,
        database_name.replace("-", "_"),
        database_name.replace("_", "-"),
    ]
    database_path = next(
        (
            candidate
            for candidate in (os.path.join(base_path, name) for name in name_variants)
            if os.path.isdir(candidate)
        ),
        None,
    )
    if database_path is None:
        print("Failed to get database")
        return None

    tables = {}
    for file_name in sorted(os.listdir(database_path)):
        if not file_name.endswith(".csv"):
            continue
        table_name = os.path.splitext(file_name)[0]
        df = pd.read_csv(os.path.join(database_path, file_name))
        if rows_limit >= 0:
            df = df.iloc[:rows_limit]
        tables[table_name] = df.to_dict(orient="records") if as_dict else df
    return tables

In [None]:
from collections import defaultdict

def make_header(db_name, work_dir=""):
    """Build the Python preamble for a generated task notebook.

    The preamble imports pandas and loads every table of ``db_name`` (all
    rows) through ``load_csv_database`` into a dict bound to a variable named
    after the database. It also defines ``OUTPUT_DIR`` as
    ``<work_dir>/output.csv``. The generated code assumes
    ``load_csv_database`` is already in scope where it runs.

    Database names are not always valid Python identifiers (e.g. they may
    contain ``-``), so the variable name is sanitised. The name passed to
    ``load_csv_database`` is emitted as a proper string literal.
    """
    import keyword  # only this helper needs it

    var_name = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch == "_" else "_" for ch in db_name
    )
    if not var_name.isidentifier() or keyword.iskeyword(var_name):
        # Empty, leading digit, or a reserved word such as "class".
        var_name = f"_{var_name}"
    output_path = f"{work_dir}/output.csv"
    header = f"""
import pandas as pd

{var_name} = dict()
for table, table_data in load_csv_database({db_name!r}, rows_limit=-1).items():
    {var_name}[table] = pd.DataFrame(table_data)
OUTPUT_DIR = {output_path!r}
"""
    return header

def make_llm_etl_dataset_df(local_task_list):
    """
    Create a DataFrame with one ETL-dataset row per Spider2-lite local task.

    Each row holds the following for one task:
    - the task question;
    - up to 10 sample rows of each input CSV table;
    - the gold output tables;
    - per-table column dtype maps;
    - the gold SQL;
    - the evaluation settings (``condition_cols``, ``ignore_order``, ``toks``)
      from ``spider2lite_eval.jsonl``.

    :param local_task_list: Task records, e.g. from ``get_local_tasks()``.
    :return: A pandas DataFrame with the ETL dataset.
    """
    with open(os.path.join(DATABASES_PATH, "local-map.jsonl"), encoding="utf-8") as f:
        task_db_map = json.load(f)
    with open("Spider2/spider2-lite/evaluation_suite/gold/spider2lite_eval.jsonl", "r", encoding="utf-8") as eval_file:
        # Skip blank lines (e.g. a trailing newline) instead of failing on them.
        groundtruth_records = [json.loads(line) for line in eval_file if line.strip()]
    groundtruth_data = {
        rec.get("instance_id"): rec
        for rec in groundtruth_records
        if str(rec.get("instance_id") or "").startswith("local")
    }

    # Per-database counter so each task of a database gets its own notebook folder.
    notebook_counter = defaultdict(int)
    dataset = []
    for intent_number, task in enumerate(local_task_list):
        spider_task_id = task["instance_id"]
        gt_data = groundtruth_data.get(spider_task_id, {})
        # Prefer the task->database map; fall back to the task's own "db" field.
        db_name = task_db_map.get(spider_task_id, task["db"])
        work_dir = f"dataset_{db_name}/notebook_{notebook_counter[db_name]}"
        notebook_counter[db_name] += 1
        dataset.append({
            "spider_task_id": spider_task_id,
            "nb_name": f"{work_dir}/annotated.ipynb",
            "work_dir": work_dir,
            "nb_header": "",
            "intent_number": intent_number,
            "intent": task["question"],
            "code": "-1",
            # NOTE(review): inputs are loaded by task["db"] while d_types uses
            # the mapped db_name; confirm both always name the same database.
            "inputs": load_csv_database(task["db"], 10, as_dict=True),
            "outputs": get_task_expected_output(spider_task_id, rows_limit=10, as_dict=True),
            "d_types": get_database_dtype_map(db_name),
            "db_name": task["db"],
            # Pass the id (not the whole record) so the gold .sql file is found.
            "ground_truth_sql": read_ground_truth_sql(spider_task_id),
            "condition_cols": gt_data.get("condition_cols", None),
            "ignore_order": gt_data.get("ignore_order", True),
            "toks": gt_data.get("toks", None),
            "external_knowledge": task.get("external_knowledge"),
        })
    return pd.DataFrame(dataset)
    

In [None]:
if __name__ == "__main__":
    # Databases referenced by the local tasks; only these are dumped to CSV.
    db_set = get_relevant_databases()
    local_task_list = get_local_tasks()
    dump_tables_to_csv(
        folders_path=DB_METADATAS_PATH,
        dbs_path=DATABASES_PATH,
        output_folder=DB_CSVS_BASE_PATH,
        force=False,
        db_set=db_set,
    )
    llm_etl_df = make_llm_etl_dataset_df(local_task_list)
    # Keep only rows whose database produced a non-empty dtype map.
    this_df = llm_etl_df.loc[llm_etl_df["d_types"] != {}]

In [None]:
# Persist the dataset in three formats; create the target folder first so the
# writers do not fail on a fresh checkout.
os.makedirs("datasets", exist_ok=True)
llm_etl_df.to_csv("datasets/spider2.single_intent.csv", index=False)
llm_etl_df.to_json("datasets/spider2.single_intent.jsonl", orient='records', lines=True)
llm_etl_df.to_pickle("datasets/spider2.single_intent.pickle")