In [None]:
import json
import numpy as np
import pandas as pd
import os
import shutil
import time


from midst_models.single_table_TabDDPM.complex_pipeline import (
    clava_clustering,
    clava_training,
    clava_load_pretrained,
    clava_synthesizing,
    load_configs,
)
from midst_models.single_table_TabDDPM.pipeline_modules import load_multi_table

# Root directories scanned by load_aux_dataframes; every sub-folder of each
# root is read for a 'train_with_id.csv' file.
AUX_DIRS = [
    '/home/hadrien/Documents/Phd/ensemble-mia/data/tabddpm_black_box/train',
    '/home/hadrien/Documents/Phd/ensemble-mia/data/tabsyn_black_box/train',
]


def deduplicate_transactions(df):
    """Return ``df`` restricted to one row per ``(trans_id, account_id)`` pair.

    The first occurrence of each pair is kept and the original index labels
    are preserved; the input frame itself is left untouched.
    """
    key_columns = ['trans_id', 'account_id']
    # True for every row whose key pair already appeared earlier in the frame.
    is_repeat = df.duplicated(subset=key_columns)
    return df[~is_repeat]
def load_aux_dataframes(train_dir):
    """Load and merge every ``train_with_id.csv`` found under ``train_dir``.

    Args:
        train_dir: A single directory path, or an iterable of directory
            paths. Every direct sub-folder of each directory is expected to
            contain a ``train_with_id.csv`` file.

    Returns:
        A single DataFrame holding the rows of all files, concatenated with a
        fresh index and then passed through ``deduplicate_transactions``.

    Raises:
        ValueError: If the given directories contain no sub-folder at all.
    """
    # Honour the argument instead of the module-level AUX_DIRS, and accept a
    # lone path as well as a collection so the parameter name stays accurate.
    if isinstance(train_dir, (str, os.PathLike)):
        root_dirs = [train_dir]
    else:
        root_dirs = list(train_dir)

    train_dfs = []
    for root_dir in root_dirs:
        # os.listdir returns entries in arbitrary order; sorting makes the
        # row order of the merged table reproducible across runs/machines.
        for sub_folder in sorted(os.listdir(root_dir)):
            file_path = os.path.join(root_dir, sub_folder, 'train_with_id.csv')
            train_dfs.append(pd.read_csv(file_path))

    if not train_dfs:
        raise ValueError(f"No 'train_with_id.csv' sub-folders found under {root_dirs}")

    return deduplicate_transactions(pd.concat(train_dfs, ignore_index=True))

# Auxiliary pool: all rows found under the directories listed in AUX_DIRS.
full_table = load_aux_dataframes(AUX_DIRS)

MODEL_NAME= 'tabddpm' #'tabsyn'
TRAIN_DIR= f'/home/hadrien/Documents/Phd/ensemble-mia/data/{MODEL_NAME}_black_box/train'
CHALLENGE_DIR_DEV= f'/home/hadrien/Documents/Phd/ensemble-mia/data/{MODEL_NAME}_black_box/dev'
CHALLENGE_DIR_FINAL= f'/home/hadrien/Documents/Phd/ensemble-mia/data/{MODEL_NAME}_black_box/final'

# TRAIN_DIR is currently not part of the challenge directories.
all_challenge_dirs= [CHALLENGE_DIR_DEV, CHALLENGE_DIR_FINAL]#, TRAIN_DIR]

# Collect the challenge points of every sub-folder of the dev and final sets.
all_challenge_tables= []

for challenge_dir in all_challenge_dirs:
    # Sorted so the row order of the merged table does not depend on the
    # arbitrary order in which os.listdir reports the folders.
    for sub_folder in sorted(os.listdir(challenge_dir)):
        file_path = os.path.join(challenge_dir, sub_folder, 'challenge_with_id.csv')
        df = pd.read_csv(file_path)
        all_challenge_tables.append(df)

all_challenge= pd.concat(all_challenge_tables, axis=0)
# Drop all observations from full_table that appear in all_challenge based on account_id and trans_id
full_table = full_table.drop_duplicates(subset=['account_id', 'trans_id'])
print(full_table.shape)
full_table = full_table[~full_table.set_index(['account_id', 'trans_id']).index.isin(
    all_challenge.set_index(['account_id', 'trans_id']).index
)]
print(full_table.shape)

# Persist both reference tables; create the target folder first so that
# to_csv does not fail on a fresh checkout where it does not exist yet.
shadow_ref_dir = '/home/hadrien/Documents/Phd/ensemble-mia/data/shadow_ref'
os.makedirs(shadow_ref_dir, exist_ok=True)

full_table.to_csv(os.path.join(shadow_ref_dir, 'full_table.csv'), index=False)
all_challenge.to_csv(os.path.join(shadow_ref_dir, 'all_challenge_tables.csv'), index=False)

config_file_path= 'midst_models/single_table_TabDDPM/configs'

# The reference tables are not modified inside the loop, so they are read
# once here instead of being re-parsed from disk on every iteration.
shadow_ref_dir = '/home/hadrien/Documents/Phd/ensemble-mia/data/shadow_ref'
df_all_challenge = pd.read_csv(os.path.join(shadow_ref_dir, 'all_challenge_tables.csv'))
full_table = pd.read_csv(os.path.join(shadow_ref_dir, 'full_table.csv'))

# One model is trained per iteration; i = 3..10 yields the workspace folders
# tabddpm_203 ... tabddpm_2010 under TRAIN_DIR.
for i in range(3,11):
    workspace= f'{TRAIN_DIR}/tabddpm_20{i}'
    # Create the new folder if it doesn't exist
    os.makedirs(workspace, exist_ok=True)

    # Training set: 2000 rows drawn from the challenge points plus 18000 rows
    # drawn from the auxiliary table, both without replacement.
    # NOTE(review): the draws are unseeded, so every run produces a different
    # training set -- pass random_state if reproducibility is required.
    challenge_sample= df_all_challenge.sample(n=2000, replace=False)
    aux_data= full_table.sample(n=18000, replace=False)

    df_train= pd.concat([challenge_sample, aux_data], axis=0)

    # Keep one copy with the identifiers and one copy without them.
    df_train.to_csv(os.path.join(workspace,'train_with_id.csv'), index=False)
    df_train.drop(columns=['account_id', 'trans_id']).to_csv(os.path.join(workspace,'train.csv'), index=False)


    # Copy the original config files to the new folder
    for config_name in ("trans.json", "dataset_meta.json", "trans_domain.json"):
        shutil.copy(os.path.join(config_file_path, config_name), workspace)

    # Modify the config file so that it points at this workspace
    with open(os.path.join(workspace, "trans.json"), "r") as file:
        trans_config = json.load(file)

    trans_config["general"]["data_dir"] = str(workspace)
    trans_config["general"]["workspace_dir"] = str(workspace)
    trans_config["general"]["test_data_dir"] = ""


    # Save the changes
    with open(os.path.join(workspace, "trans.json"), "w") as file:
        json.dump(trans_config, file, indent=4)

    # Load config
    config_path = f"{workspace}/trans.json"
    configs, save_dir = load_configs(config_path)

    # Display config
    json_str = json.dumps(configs, indent=4)
    print(json_str)

    # Load  dataset
    # In this step, we load the dataset according to the 'dataset_meta.json' file located in the data_dir.
    tables, relation_order, dataset_meta = load_multi_table(configs["general"]["data_dir"])
    print("")

    # Tables is a dictionary of the multi-table dataset
    print(
        "{} We show the keys of the tables dictionary below {}".format("=" * 20, "=" * 20)
    )
    print(list(tables.keys()))

    # Display important clustering parameters
    params_clustering = configs["clustering"]
    print("{} We show the clustering parameters below {}".format("=" * 20, "=" * 20))
    for key, val in params_clustering.items():
        print(f"{key}: {val}")
    print("")

    # Clustering on the multi-table dataset
    tables, all_group_lengths_prob_dicts = clava_clustering(
        tables, relation_order, save_dir, configs
    )

    # Display the "diffusion" section of the config (printed under the same
    # "sampling parameters" banner as the "sampling" section further down).
    params_sampling = configs["diffusion"]
    print(
        "{} We show the important sampling parameters below {}".format("=" * 20, "=" * 20)
    )
    for key, val in params_sampling.items():
        print(f"{key}: {val}")
    print("")

    # Launch training from scratch and report the wall-clock time it took

    t= time.time()
    models = clava_training(tables, relation_order, save_dir, configs)
    print(f"Training time: {time.time() - t} seconds")

    # Display important sampling parameters
    params_sampling = configs["sampling"]
    print(
        "{} We show the important sampling parameters below {}".format("=" * 20, "=" * 20)
    )
    for key, val in params_sampling.items():
        print(f"{key}: {val}")
    print("")

    # Generate the synthetic tables; sample_scale is 1 unless the config has
    # a "debug" section that supplies its own value.
    cleaned_tables, synthesizing_time_spent, matching_time_spent = clava_synthesizing(
        tables,
        relation_order,
        save_dir,
        all_group_lengths_prob_dicts,
        models,
        configs,
        sample_scale=1 if "debug" not in configs else configs["debug"]["sample_scale"],
    )

    # Cast int values that saved as string to int for further evaluation
    for key in cleaned_tables.keys():
        for col in cleaned_tables[key].columns:
            if cleaned_tables[key][col].dtype == "object":
                try:
                    cleaned_tables[key][col] = cleaned_tables[key][col].astype(int)
                except ValueError:
                    print(f"Column {col} cannot be converted to int.")
    # TODO: add a command that copies the synthetic data into the right folder
    #shutil.copy(os.path.join(save_dir, "trans","final","trans_synthetic.csv"), os.path.join(workspace, "trans_synthetic.csv"))

(722471, 10)
(715884, 10)
{
    "general": {
        "data_dir": "/home/hadrien/Documents/Phd/ensemble-mia/data/tabddpm_black_box/train/tabddpm_203",
        "exp_name": "train_1",
        "workspace_dir": "/home/hadrien/Documents/Phd/ensemble-mia/data/tabddpm_black_box/train/tabddpm_203",
        "sample_prefix": "",
        "test_data_dir": ""
    },
    "clustering": {
        "parent_scale": 1.0,
        "num_clusters": 50,
        "clustering_method": "both"
    },
    "diffusion": {
        "d_layers": [
            512,
            1024,
            1024,
            1024,
            1024,
            512
        ],
        "dropout": 0.0,
        "num_timesteps": 2000,
        "model_type": "mlp",
        "iterations": 200000,
        "batch_size": 4096,
        "lr": 0.0006,
        "gaussian_loss_type": "mse",
        "weight_decay": 1e-05,
        "scheduler": "cosine"
    },
    "classifier": {
        "d_layers": [
            128,
            256,
            512,
      