# ClavaDDPM: Multi-relational Data Synthesis with Cluster-guided Diffusion Models

Recent tabular data synthesis research has focused on single tables, while real-world applications often involve complex, interconnected tables. Existing methods for multi-relational data synthesis struggle with scalability and long-range dependencies. This paper introduces Cluster Latent Variable guided Denoising Diffusion Probabilistic Models (ClavaDDPM), using clustering labels to model inter-table relationships, particularly foreign key constraints. ClavaDDPM efficiently propagates latent variables across tables, capturing long-range dependencies. Evaluations show ClavaDDPM outperforms existing methods on multi-table data and remains competitive for single-table data.

The following sections walk through the implementation step by step: configuration, data loading, clustering, training, sampling, and evaluation.

## Imports and Setup

In this section, we import the libraries and modules used throughout the notebook: standard-library modules for file handling and JSON parsing (`os`, `shutil`, `json`), and ClavaDDPM-specific modules for configuration loading, data loading, clustering, training, synthesis, and evaluation. Numerical dependencies such as PyTorch and NumPy are used inside those modules rather than imported here directly.

In [None]:
import os
import shutil
import json

from complex_pipeline import clava_clustering, clava_training, clava_synthesizing, clava_eval, load_configs
from pipeline_modules import load_multi_table
from gen_single_report import gen_single_report
from report_utils import get_multi_metadata

## Load Configuration

In this section, we establish the setup for model training by loading the configuration file, which includes the necessary parameters and settings for the training process. The configuration file, stored in `json` format, is read and parsed into a dictionary.
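As a rough illustration of what loading such a file involves, the sketch below writes a small JSON config to a temporary file and parses it back into a dictionary. The helper name `read_json_config` and the `clustering` section are hypothetical; only the `general.data_dir` and `general.workspace_dir` keys appear in this notebook, and the repository's `load_configs` may do more than this (it also returns a save directory).

```python
import json
import os
import tempfile


def read_json_config(path):
    """Parse a JSON config file into a plain dictionary (hypothetical helper)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# Illustrative config; only the two `general` keys are taken from this notebook.
example = {
    "general": {"data_dir": "data/example", "workspace_dir": "workspace/example"},
    "clustering": {"num_clusters": 20},  # hypothetical section and value
}

with tempfile.TemporaryDirectory() as tmp:
    config_file = os.path.join(tmp, "example.json")
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(example, f, indent=4)
    loaded = read_json_config(config_file)

print(loaded["general"]["data_dir"])
```

Nested sections are then read with ordinary dictionary indexing, as in `configs['general']['data_dir']` later in this notebook.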

In [None]:
# Load config
### DEBUG: use debug configs to obtain the results faster ###
# config_path = 'configs/debug_movie_lens.json'
### DEBUG: use debug configs to obtain the results faster ###

config_path = 'configs/california_20_debug.json'
configs, save_dir = load_configs(config_path)

# Display config
json_str = json.dumps(configs, indent=4)
print(json_str)

## Data Loading and Preprocessing

In this section, we load and preprocess the dataset based on the configuration settings. We display the dataset's metadata and parent-child relationships to make its structure clear. We then run clustering as a preprocessing step, which prepares the tables for training the ClavaDDPM model.

In [None]:
# Load multi-table dataset
tables, relation_order, dataset_meta = load_multi_table(configs['general']['data_dir'])

# Tables is a dictionary of the multi-table dataset
print("{} We show the keys of the tables dictionary below {}".format("="*20, "="*20))
print(tables.keys())

# Relation order is the topological order of the multi-table dataset
print("{} We show the relation order below {}".format("="*20, "="*20))
print(relation_order)

# Visualize the parent-child relationship within the multi-table dataset
multi_meta = get_multi_metadata(tables, relation_order)
multi_meta.visualize()
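The relation order printed above is described as a topological order of the tables. As a minimal sketch of that idea, and not the repository's implementation, the example below orders tables so that every parent precedes its children. The `(parent, child)` pair format, the table names, and the helper `topological_table_order` are assumptions made for illustration.

```python
from graphlib import TopologicalSorter  # standard library, Python 3.9+

# Hypothetical schema: each pair is (parent, child); None marks a root table.
relations = [
    (None, "household"),
    ("household", "individual"),
    ("individual", "visit"),
]


def topological_table_order(pairs):
    """Return table names so that every parent precedes its children."""
    sorter = TopologicalSorter()
    for parent, child in pairs:
        if parent is None:
            sorter.add(child)          # root table: no prerequisites
        else:
            sorter.add(child, parent)  # child depends on its parent
    return list(sorter.static_order())


order = topological_table_order(relations)
print(order)
```

Processing tables in this order means a parent table is always available before any child table that must be conditioned on it.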

The paper introduces relation-aware clustering to model parent-child constraints and uses diffusion models for controlled tabular data synthesis. Specifically, a clustering model (a Gaussian mixture model in the paper) is fit on joined parent-child rows, and the resulting cluster labels serve as latent variables that augment each table. We run the clustering algorithm below to preprocess the data for training the ClavaDDPM model. The same step also estimates the empirical distribution of group sizes (the number of child rows per parent row), which is used later in the sampling process.
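The size distribution mentioned above can be pictured as an empirical distribution over the number of child rows attached to each parent row. The sketch below computes it for toy data. The helper `group_size_distribution`, the toy ids, and the choice to count childless parents as groups of size 0 are assumptions; the exact structure returned by `clava_clustering` is not shown in this notebook.

```python
from collections import Counter

# Hypothetical toy data: parent ids and the foreign key of each child row.
parent_ids = [1, 2, 3, 4]
child_foreign_keys = [1, 1, 2, 2, 4, 4, 4, 4]


def group_size_distribution(parent_ids, child_foreign_keys):
    """Map each group size to the fraction of parents with that many children."""
    children_per_parent = Counter(child_foreign_keys)
    # Assumption: parents without children count as groups of size 0.
    sizes = [children_per_parent.get(pid, 0) for pid in parent_ids]
    size_counts = Counter(sizes)
    total = len(parent_ids)
    return {size: count / total for size, count in sorted(size_counts.items())}


distribution = group_size_distribution(parent_ids, child_foreign_keys)
print(distribution)
```

At sampling time, a distribution of this kind can be used to decide how many child rows to generate for each synthetic parent row.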

In [None]:
# Clustering on the multi-table dataset
# updates the tables dictionary with augmented tables and computes group size distributions
tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, save_dir, configs)

## Model Training

<img src="assets/clavaDDPM.png" alt="ClavaDDPM Model Pipeline" width="960"/>

This section outlines the training process for the ClavaDDPM model. The diagram above, taken from the original paper, illustrates the main steps: (a) latent learning and table augmentation (steps 1-2), (b) training (steps 3-5), and (c) synthesis (steps 6-8). Specifically, the clustering process corresponds to latent learning and is used for table augmentation during training. The training code below trains the conditional diffusion models and the cluster classifier models. Subsequently, we implement the generation process, starting with sampling the table size and conducting conditional generation to satisfy the parent-child constraints (i.e., relation order).
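For readers new to diffusion models, the sketch below shows the textbook DDPM forward (noising) step for numerical features, $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, using only the standard library. It is not the repository's training code: the function names are hypothetical, the linear schedule values are the defaults from the original DDPM paper, and categorical columns would need a different noising process.

```python
import math
import random


def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances beta_1..beta_T."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alphas(betas):
    """alpha_bar_t is the running product of (1 - beta_s) for s <= t."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def q_sample(x0, t, alpha_bars, rng):
    """Draw x_t from q(x_t | x_0) for a list of numerical features."""
    scale_signal = math.sqrt(alpha_bars[t])
    scale_noise = math.sqrt(1.0 - alpha_bars[t])
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [scale_signal * x + scale_noise * e for x, e in zip(x0, noise)]
    return x_t, noise


rng = random.Random(0)
alpha_bars = cumulative_alphas(linear_beta_schedule(1000))
x0 = [0.5, -1.2, 2.0]
x_early, _ = q_sample(x0, 10, alpha_bars, rng)    # still close to x0
x_late, _ = q_sample(x0, 999, alpha_bars, rng)    # almost pure noise
print(x_early, x_late)
```

A denoising network is trained to predict `noise` from `x_t` and `t`; in a conditional model, it would additionally receive the conditioning signal (here, information derived from the parent table).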

Additionally, various callbacks are configured to monitor and save the model during training. The trainer is implemented using a custom PyTorch function, specifying parameters such as the number of epochs and checkpoints. The training process is then initiated, logging progress and completing the model's training.

In [None]:
# Relation order is the topological order of the multi-table dataset
print("{} We show the relation order again, each line indicates one conditional generative model {}".format("="*20, "="*20))
print(relation_order)

# Launch training or use the pre-trained models
tables, models = clava_training(tables, relation_order, save_dir, configs)

# if the training process takes too long, please run the following command to load pre-trained models and samples
# !unzip -o clavaDDPM_workspace.zip 

## Model Sampling

Upon completion of the training, the model is evaluated using the test dataset. To assess the model's performance, we first use the `clava_synthesizing` function to generate synthetic samples and showcase the results qualitatively. We initiate the generation process by sampling the table size (i.e., the number of rows per table) and performing conditional generation to meet the parent-child constraints (i.e., relation order). Quantitative evaluations will be conducted in the next section.
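The two-stage idea described above (sample a size, then generate rows conditioned on the parent) can be sketched as follows. This is an illustrative stand-in rather than what `clava_synthesizing` does internally: `sample_child_rows`, `placeholder_row`, the `parent_id` column, and the size probabilities are all hypothetical, and the placeholder generator replaces a trained conditional diffusion model.

```python
import random


def sample_child_rows(parent_ids, size_distribution, make_row, rng):
    """For each parent, sample a group size and emit that many child rows."""
    sizes = list(size_distribution.keys())
    weights = list(size_distribution.values())
    children = []
    for pid in parent_ids:
        n_children = rng.choices(sizes, weights=weights, k=1)[0]
        for _ in range(n_children):
            row = make_row(pid, rng)  # stand-in for the conditional sampler
            row["parent_id"] = pid    # foreign key references an existing parent
            children.append(row)
    return children


def placeholder_row(pid, rng):
    """Hypothetical generator; a real pipeline would sample from a trained model."""
    return {"value": rng.gauss(0.0, 1.0)}


rng = random.Random(42)
parents = [101, 102, 103]
size_distribution = {0: 0.25, 2: 0.5, 4: 0.25}  # illustrative probabilities
synthetic_children = sample_child_rows(parents, size_distribution, placeholder_row, rng)
print(len(synthetic_children))
```

Because every child row is created for a specific parent, the foreign key constraint holds by construction in this sketch.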

In [None]:
cleaned_tables, synthesizing_time_spent, matching_time_spent = clava_synthesizing(
    tables, 
    relation_order, 
    save_dir, 
    all_group_lengths_prob_dicts, 
    models,
    configs,
    sample_scale=1 if 'debug' not in configs else configs['debug']['sample_scale']
)

## Model Evaluation

In this step, we quantitatively evaluate the generated tabular data by computing metrics that measure how closely the synthetic data matches the real samples in the training set.

In particular, the critical multi-table metrics are as follows:

1. Pair-wise column correlation (k-hop): This metric measures the correlations between columns from tables at a distance k (e.g., 0-hop for columns within the same table, 1-hop for a column and a column from its parent or child table).

2. Average 2-way: This metric computes the average of all k-hop column-pair correlations, taking into account both short-range (k = 0) and longer-range (k > 0) dependencies.
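To make the 1-hop case concrete, the sketch below joins a child column to a parent column through the foreign key, computes the Pearson correlation on real and synthetic toy data, and compares the two. The column names, the helper functions, and the scoring rule `1 - |r_real - r_syn| / 2` are assumptions for illustration; the formula used by `clava_eval` is not shown in this notebook, and categorical columns would need a different association measure.

```python
import math


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length numeric sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


def one_hop_columns(parents, children, parent_col, child_col):
    """Join each child row to its parent and return the two aligned columns."""
    parent_lookup = {p["id"]: p[parent_col] for p in parents}
    xs = [parent_lookup[c["parent_id"]] for c in children]
    ys = [c[child_col] for c in children]
    return xs, ys


def correlation_similarity(real_corr, syn_corr):
    """Score in [0, 1]; 1 means identical correlations (assumed scoring rule)."""
    return 1.0 - abs(real_corr - syn_corr) / 2.0


# Hypothetical toy tables: the synthetic data reverses the real trend.
parents = [{"id": 1, "income": 10.0}, {"id": 2, "income": 20.0}, {"id": 3, "income": 30.0}]
real_children = [{"parent_id": 1, "age": 5.0}, {"parent_id": 2, "age": 7.0}, {"parent_id": 3, "age": 9.0}]
syn_children = [{"parent_id": 1, "age": 9.0}, {"parent_id": 2, "age": 7.0}, {"parent_id": 3, "age": 5.0}]

real_corr = pearson(*one_hop_columns(parents, real_children, "income", "age"))
syn_corr = pearson(*one_hop_columns(parents, syn_children, "income", "age"))
score = correlation_similarity(real_corr, syn_corr)
print(real_corr, syn_corr, score)
```

Averaging such scores over all column pairs at a given hop distance, and then over all hop distances, gives summary numbers in the spirit of the two metrics listed above.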

Please refer to the `ts-diff` notebook for the other metrics.



In [None]:
# Multi-table Evaluation
report = clava_eval(tables, save_dir, configs, relation_order, cleaned_tables)

In [None]:
# Print out the multi-table metrics
for key, val in report.items():
    if key in ['hop_relation', 'avg_scores', 'all_avg_score']:
        if isinstance(val, dict):
            print("{:20}".format(key))
            for k, v in val.items():
                print("{:20}: {}".format(k, v))
        else:
            print("{:20}: {}".format(key, val))

In [None]:
# Prepare the synthetic data for single-table metric evaluation
shutil.copy(os.path.join(configs['general']['data_dir'], 'dataset_meta.json'), os.path.join(save_dir, 'dataset_meta.json'))

for table_name in tables.keys():
    shutil.copy(os.path.join(configs['general']['workspace_dir'], table_name, '_final', f'{table_name}_synthetic.csv'), os.path.join(save_dir, f'{table_name}.csv'))
    shutil.copy(os.path.join(configs['general']['data_dir'], f'{table_name}_domain.json'), os.path.join(save_dir, f'{table_name}_domain.json'))

test_tables, _, _ = load_multi_table(save_dir)
real_tables, _, _ = load_multi_table(configs['general']['data_dir'])

# Single table metrics
for table_name in tables.keys():
    print(f'Generating report for {table_name}')
    real_data = real_tables[table_name]['df']
    syn_data = cleaned_tables[table_name]
    domain_dict = real_tables[table_name]['domain']

    if configs['general']['workspace_dir'] is not None:
        test_data = test_tables[table_name]['df']
    else:
        test_data = None

    gen_single_report(
        real_data, 
        syn_data,
        domain_dict,
        table_name,
        save_dir,
        alpha_beta_sample_size=200_000,
        test_data=test_data
    )

## References

**Pang, Wei, et al.** "ClavaDDPM: Multi-relational Data Synthesis with Cluster-guided Diffusion Models." *preprint* (2024).

**GitHub Repository:** [ClavaDDPM](https://github.com/weipang142857/ClavaDDPM)