In [1]:
from data_generator.main import get_real_data, generate_from_real_data
from data_generator.utils import plot_distribution_comparison
from methods.adf.main1 import adf_fairness_testing
from methods.utils import reformat_discrimination_results, convert_to_non_float_rows, compare_discriminatory_groups
import seaborn as sns
from matplotlib import pyplot as plt
import pandas as pd
from tqdm import tqdm


def get_groups(results_df_origin, data_obj, schema):
    """Convert raw fairness-testing results into discriminatory groups.

    Args:
        results_df_origin: Results table produced by ``adf_fairness_testing``.
        data_obj: Data object whose ``.dataframe`` is handed to
            ``reformat_discrimination_results``.
        schema: Schema passed to ``convert_to_non_float_rows``.

    Returns:
        ``(groups, nb_elements)``: the groups returned by
        ``reformat_discrimination_results`` and the sum of their
        ``group_size`` attributes (0 when there are no groups).
    """
    decoded_rows = convert_to_non_float_rows(results_df_origin, schema)
    groups = reformat_discrimination_results(decoded_rows, data_obj.dataframe)

    nb_elements = 0
    for group in groups:
        nb_elements += group.group_size
    return groups, nb_elements

def run_experiment(dataset_name, original_params, synth_params, nb_groups=100):
    """Compare the discriminatory groups ADF finds on real vs. synthetic data.

    The real dataset is loaded and tested with ``adf_fairness_testing``. A
    synthetic counterpart is generated with ``generate_from_real_data`` and
    tested the same way. Both result sets are turned into groups with
    ``get_groups`` and compared with ``compare_discriminatory_groups``.

    Args:
        dataset_name: Name passed to ``get_real_data`` and
            ``generate_from_real_data``.
        original_params: Keyword arguments for ``adf_fairness_testing`` on the
            real data.
        synth_params: Keyword arguments for ``adf_fairness_testing`` on the
            synthetic data.
        nb_groups: Forwarded to ``generate_from_real_data`` (previously
            hard-coded to 100).

    Returns:
        dict with keys ``metrics_origin``, ``metrics_synth``,
        ``comparison_results``, ``nb_elements`` and ``nb_elements_synth``.
    """
    # Get original data and run fairness testing on it
    data_obj, schema = get_real_data(dataset_name, use_cache=True)
    results_df_origin, metrics_origin = adf_fairness_testing(
        data_obj, **original_params
    )
    predefined_groups_origin, nb_elements = get_groups(results_df_origin, data_obj, schema)

    # Generate and test synthetic data
    data_obj_synth, schema_synth = generate_from_real_data(
        dataset_name, nb_groups=nb_groups, use_cache=True
    )
    results_df_synth, metrics_synth = adf_fairness_testing(
        data_obj_synth, **synth_params
    )
    # Decode synthetic results with the synthetic data object AND its schema.
    # Previously the real data_obj was paired with the (rebound) synthetic schema.
    predefined_groups_synth, nb_elements_synth = get_groups(
        results_df_synth, data_obj_synth, schema_synth
    )

    # Compare results
    comparison_results = compare_discriminatory_groups(predefined_groups_origin, predefined_groups_synth)

    return {
        'metrics_origin': metrics_origin,
        'metrics_synth': metrics_synth,
        'comparison_results': comparison_results,
        'nb_elements': nb_elements,
        'nb_elements_synth': nb_elements_synth
    }


def run_multiple_experiments(dataset_name, original_params, synth_params, num_runs=10):
    """Run ``run_experiment`` ``num_runs`` times and tabulate each comparison.

    Args:
        dataset_name: Dataset passed to every ``run_experiment`` call.
        original_params: ADF keyword arguments for the real data.
        synth_params: ADF keyword arguments for the synthetic data.
        num_runs: Number of repetitions.

    Returns:
        pandas DataFrame with one row per run and columns ``dataset``,
        ``run`` (0-based), ``coverage_ratio``, ``matched_groups``,
        ``total_groups``, ``matched_size`` and ``total_size``.
    """
    def summarize(run_idx):
        # Only the comparison section of the experiment result is tabulated.
        comparison = run_experiment(dataset_name, original_params, synth_params)['comparison_results']
        return {
            'dataset': dataset_name,
            'run': run_idx,
            'coverage_ratio': comparison['coverage_ratio'],
            'matched_groups': comparison['total_groups_matched'],
            'total_groups': comparison['total_original_groups'],
            'matched_size': comparison['total_matched_size'],
            'total_size': comparison['total_original_size'],
        }

    progress = tqdm(range(num_runs), desc=f"Running experiments for {dataset_name}")
    return pd.DataFrame([summarize(run_idx) for run_idx in progress])


def plot_experiment_results(results_df):
    """Draw a 2x2 summary figure of repeated real-vs-synthetic comparisons.

    Panels:
        top-left:     box plot of ``coverage_ratio`` per dataset;
        top-right:    ``matched_groups`` vs ``total_groups`` per dataset, with
                      a dashed y = x reference line;
        bottom-left:  mean ``matched_size`` and ``total_size`` per dataset;
        bottom-right: ``coverage_ratio`` against ``run`` per dataset.

    Args:
        results_df: DataFrame with columns ``dataset``, ``run``,
            ``coverage_ratio``, ``matched_groups``, ``total_groups``,
            ``matched_size`` and ``total_size`` (the shape returned by
            ``run_multiple_experiments``).

    Returns:
        The matplotlib Figure holding the four panels.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    # sort=False keeps datasets in first-appearance order, like Series.unique().
    per_dataset = list(results_df.groupby('dataset', sort=False))

    # Coverage ratio boxplot
    ax = axes[0, 0]
    sns.boxplot(data=results_df, x='dataset', y='coverage_ratio', ax=ax)
    ax.set(title='Coverage Ratio Distribution', xlabel='Dataset', ylabel='Coverage ratio')

    # Matched groups vs total groups
    ax = axes[0, 1]
    for dataset, dataset_data in per_dataset:
        ax.scatter(dataset_data['total_groups'], dataset_data['matched_groups'],
                   label=dataset, alpha=0.6)
    max_groups = results_df['total_groups'].max()
    # Reference line where matched_groups == total_groups.
    ax.plot([0, max_groups], [0, max_groups], 'k--', alpha=0.3)
    ax.set(title='Matched vs Total Groups', xlabel='Total original groups', ylabel='Matched groups')
    ax.legend()

    # Size comparison
    ax = axes[1, 0]
    results_df.groupby('dataset')[['matched_size', 'total_size']].mean().plot(kind='bar', ax=ax)
    ax.set(title='Average Matched vs Total Size', xlabel='Dataset', ylabel='Mean size')

    # Run variation: sort by run so each line connects runs in order,
    # regardless of the row order of results_df.
    ax = axes[1, 1]
    for dataset, dataset_data in per_dataset:
        ordered = dataset_data.sort_values('run')
        ax.plot(ordered['run'], ordered['coverage_ratio'], marker='o', label=dataset)
    ax.set(title='Coverage Ratio by Run', xlabel='Run', ylabel='Coverage ratio')
    ax.legend()

    # Lay out this figure explicitly rather than whatever pyplot considers "current".
    fig.tight_layout()
    return fig

from path import HERE
import sqlite3
from datetime import datetime

def init_db(db_path=None):
    """Open the experiments SQLite database, creating the ``experiments`` table if needed.

    Args:
        db_path: Database file path (str or path-like) or ``':memory:'``.
            Defaults to ``HERE.joinpath('experiments/baseline_exp/experiments.db')``,
            the previously hard-coded location. Missing parent directories are
            created.

    Returns:
        An open ``sqlite3.Connection``; the caller is responsible for closing it.

    Raises:
        sqlite3.Error: if the database cannot be opened or the table created
            (the connection is closed before re-raising).
    """
    from pathlib import Path

    if db_path is None:
        db_path = HERE.joinpath('experiments/baseline_exp/experiments.db')
    if str(db_path) != ':memory:':
        # sqlite3.connect raises "unable to open database file" if the directory is missing.
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)

    conn = sqlite3.connect(db_path)
    try:
        conn.execute('''
            CREATE TABLE IF NOT EXISTS experiments (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                dataset TEXT,
                method_name TEXT,
                run INTEGER,
                coverage_ratio REAL,
                matched_groups INTEGER,
                total_groups INTEGER,
                matched_size INTEGER,
                total_size INTEGER,
                timestamp DATETIME
            )
        ''')
        conn.commit()
    except BaseException:
        # Don't leak the handle when table creation fails.
        conn.close()
        raise
    return conn

def save_experiment(conn, dataset, method_name, run, result):
    """Insert one experiment outcome as a row of the ``experiments`` table and commit.

    Args:
        conn: Open sqlite3 connection whose database has the ``experiments`` table.
        dataset: Value stored in the ``dataset`` column.
        method_name: Value stored in the ``method_name`` column.
        run: Zero-based run index; stored one-based as ``run + 1``.
        result: Dict from ``run_experiment``; only ``result['comparison_results']``
            is read.
    """
    comparison = result['comparison_results']
    row = (
        dataset,
        method_name,
        run + 1,  # runs are persisted 1-based
        comparison['coverage_ratio'],
        comparison['total_groups_matched'],
        comparison['total_original_groups'],
        comparison['total_matched_size'],
        comparison['total_original_size'],
        str(datetime.now()),
    )
    insert_sql = '''
        INSERT INTO experiments (
            dataset, method_name, run, coverage_ratio, matched_groups,
            total_groups, matched_size, total_size, timestamp
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
    '''
    conn.execute(insert_sql, row)
    conn.commit()


Intel(R) Extension for Scikit-learn* enabled (https://github.com/uxlfoundation/scikit-learn-intelex)


In [2]:
# Run experiments: compare ADF discrimination results on real vs. synthetic data,
# storing one row per run in the experiments database and in `results_df`.
datasets = ['bank']
method_name = 'adf'
NUM_RUNS = 1  # a single 'bank' run took 12:38 in the recorded output
conn = init_db()
all_results = []  # one metrics record per (dataset, run)

# Identical ADF budgets for real and synthetic data; copied so either can be tuned on its own.
original_params = {'max_global': 5000, 'max_local': 2000, 'max_iter': 10, 'cluster_num': 50, 'max_runtime_seconds': 300}
synth_params = dict(original_params)

for dataset in datasets:
    for run in tqdm(range(NUM_RUNS), desc=f"Running experiments for {dataset}"):
        result = run_experiment(dataset, original_params, synth_params)
        save_experiment(conn, dataset, method_name, run, result)
        comparison = result['comparison_results']
        all_results.append({
            'dataset': dataset,
            'run': run,
            'coverage_ratio': comparison['coverage_ratio'],
            'matched_groups': comparison['total_groups_matched'],
            'total_groups': comparison['total_original_groups'],
            'matched_size': comparison['total_matched_size'],
            'total_size': comparison['total_original_size'],
        })

# Same columns as run_multiple_experiments, so plot_experiment_results(results_df) works.
results_df = pd.DataFrame(all_results)

Running experiments for bank:   0%|          | 0/1 [00:00<?, ?it/s]2025-03-01 15:32:37 - SingleTableSynthesizer - INFO - {'EVENT': 'Instance', 'TIMESTAMP': datetime.datetime(2025, 3, 1, 15, 32, 37, 476209), 'SYNTHESIZER CLASS NAME': 'GaussianCopulaSynthesizer', 'SYNTHESIZER ID': 'GaussianCopulaSynthesizer_1.18.0_590770198fc942a3969a51b395d84dff'}
2025-03-01 15:32:37 - SingleTableSynthesizer - INFO - {'EVENT': 'Fit', 'TIMESTAMP': datetime.datetime(2025, 3, 1, 15, 32, 37, 477283), 'SYNTHESIZER CLASS NAME': 'GaussianCopulaSynthesizer', 'SYNTHESIZER ID': 'GaussianCopulaSynthesizer_1.18.0_590770198fc942a3969a51b395d84dff', 'TOTAL NUMBER OF TABLES': 1, 'TOTAL NUMBER OF ROWS': 45211, 'TOTAL NUMBER OF COLUMNS': 17}
2025-03-01 15:32:37 - sdv.data_processing.data_processor - INFO - Fitting table  metadata
2025-03-01 15:32:37 - sdv.data_processing.data_processor - INFO - Fitting formatters for table 
2025-03-01 15:32:37 - sdv.data_processing.data_processor - INFO - Fitting constraints for table 


Fitting GaussianCopulaSynthesizer...


2025-03-01 15:32:41 - SingleTableSynthesizer - INFO - {'EVENT': 'Fit processed data', 'TIMESTAMP': datetime.datetime(2025, 3, 1, 15, 32, 41, 362375), 'SYNTHESIZER CLASS NAME': 'GaussianCopulaSynthesizer', 'SYNTHESIZER ID': 'GaussianCopulaSynthesizer_1.18.0_590770198fc942a3969a51b395d84dff', 'TOTAL NUMBER OF TABLES': 1, 'TOTAL NUMBER OF ROWS': 45211, 'TOTAL NUMBER OF COLUMNS': 17}
2025-03-01 15:32:41 - copulas.multivariate.gaussian - INFO - Fitting GaussianMultivariate(distribution="{'Attr1_T': <class 'copulas.univariate.beta.BetaUnivariate'>, 'Attr2_X': <class 'copulas.univariate.truncated_gaussian.TruncatedGaussian'>, 'Attr3_T': <class 'copulas.univariate.beta.BetaUnivariate'>, 'Attr4_T': <class 'copulas.univariate.beta.BetaUnivariate'>, 'Attr5_X': <class 'copulas.univariate.truncated_gaussian.TruncatedGaussian'>, 'Attr6_X': <class 'copulas.univariate.truncated_gaussian.TruncatedGaussian'>, 'Attr7_X': <class 'copulas.univariate.truncated_gaussian.TruncatedGaussian'>, 'Attr8_X': <class

Pre-computing attribute combinations...



Attribute combinations: 0it [00:00, ?it/s][A
Attribute combinations: 49it [00:00, 479.32it/s][A
Attribute combinations: 98it [00:00, 484.13it/s][A
Attribute combinations: 148it [00:00, 487.22it/s][A
Attribute combinations: 198it [00:00, 489.60it/s][A
Attribute combinations: 248it [00:00, 491.20it/s][A
Attribute combinations: 298it [00:00, 490.99it/s][A
Attribute combinations: 348it [00:00, 491.13it/s][A
Attribute combinations: 416it [00:00, 488.35it/s][A


Processing valid cases...



Processing case pairs:   0%|          | 0/789 [00:00<?, ?it/s][A
Processing case pairs:   4%|▍         | 31/789 [00:00<00:02, 302.71it/s][A
Processing case pairs:   9%|▉         | 70/789 [00:00<00:02, 347.64it/s][A
Processing case pairs:  14%|█▎        | 108/789 [00:00<00:01, 362.16it/s][A
Processing case pairs:  18%|█▊        | 145/789 [00:00<00:01, 361.07it/s][A
Processing case pairs:  24%|██▍       | 191/789 [00:00<00:01, 391.43it/s][A
Processing case pairs:  29%|██▉       | 231/789 [00:00<00:01, 390.86it/s][A
Processing case pairs:  34%|███▍      | 271/789 [00:00<00:01, 366.88it/s][A
Processing case pairs:  41%|████      | 322/789 [00:00<00:01, 404.90it/s][A
Processing case pairs:  48%|████▊     | 379/789 [00:00<00:00, 449.50it/s][A
Processing case pairs:  54%|█████▍    | 425/789 [00:01<00:00, 446.79it/s][A
Processing case pairs:  60%|█████▉    | 470/789 [00:01<00:00, 436.90it/s][A
Processing case pairs:  66%|██████▌   | 518/789 [00:01<00:00, 445.93it/s][A
Processing 

Fitting GaussianCopulaSynthesizer...


2025-03-01 15:38:22 - SingleTableSynthesizer - INFO - {'EVENT': 'Fit processed data', 'TIMESTAMP': datetime.datetime(2025, 3, 1, 15, 38, 22, 537226), 'SYNTHESIZER CLASS NAME': 'GaussianCopulaSynthesizer', 'SYNTHESIZER ID': 'GaussianCopulaSynthesizer_1.18.0_d956c0796dde45f18c98e8d272d49195', 'TOTAL NUMBER OF TABLES': 1, 'TOTAL NUMBER OF ROWS': 45211, 'TOTAL NUMBER OF COLUMNS': 17}
2025-03-01 15:38:22 - copulas.multivariate.gaussian - INFO - Fitting GaussianMultivariate(distribution="{'Attr1_T': <class 'copulas.univariate.beta.BetaUnivariate'>, 'Attr2_X': <class 'copulas.univariate.truncated_gaussian.TruncatedGaussian'>, 'Attr3_T': <class 'copulas.univariate.beta.BetaUnivariate'>, 'Attr4_T': <class 'copulas.univariate.beta.BetaUnivariate'>, 'Attr5_X': <class 'copulas.univariate.truncated_gaussian.TruncatedGaussian'>, 'Attr6_X': <class 'copulas.univariate.truncated_gaussian.TruncatedGaussian'>, 'Attr7_X': <class 'copulas.univariate.truncated_gaussian.TruncatedGaussian'>, 'Attr8_X': <class

Training GaussianCopulaSynthesizer on bank dataset...


2025-03-01 15:38:29 - SingleTableSynthesizer - INFO - {'EVENT': 'Fit processed data', 'TIMESTAMP': datetime.datetime(2025, 3, 1, 15, 38, 29, 190957), 'SYNTHESIZER CLASS NAME': 'GaussianCopulaSynthesizer', 'SYNTHESIZER ID': 'GaussianCopulaSynthesizer_1.18.0_d8ae9c9eb1314e4a81194262b0f49214', 'TOTAL NUMBER OF TABLES': 1, 'TOTAL NUMBER OF ROWS': 45211, 'TOTAL NUMBER OF COLUMNS': 17}
2025-03-01 15:38:29 - copulas.multivariate.gaussian - INFO - Fitting GaussianMultivariate(distribution="{'Attr1_T': <class 'copulas.univariate.beta.BetaUnivariate'>, 'Attr2_X': <class 'copulas.univariate.truncated_gaussian.TruncatedGaussian'>, 'Attr3_T': <class 'copulas.univariate.beta.BetaUnivariate'>, 'Attr4_T': <class 'copulas.univariate.beta.BetaUnivariate'>, 'Attr5_X': <class 'copulas.univariate.truncated_gaussian.TruncatedGaussian'>, 'Attr6_X': <class 'copulas.univariate.truncated_gaussian.TruncatedGaussian'>, 'Attr7_X': <class 'copulas.univariate.truncated_gaussian.TruncatedGaussian'>, 'Attr8_X': <class

Pre-computing attribute combinations...



Attribute combinations: 0it [00:00, ?it/s][A
Attribute combinations: 50it [00:00, 496.66it/s][A
Attribute combinations: 100it [00:00, 488.39it/s][A
Attribute combinations: 150it [00:00, 491.92it/s][A
Attribute combinations: 201it [00:00, 496.99it/s][A
Attribute combinations: 251it [00:00, 495.40it/s][A
Attribute combinations: 301it [00:00, 496.86it/s][A
Attribute combinations: 352it [00:00, 497.86it/s][A
Attribute combinations: 404it [00:00, 503.18it/s][A
Attribute combinations: 455it [00:00, 500.02it/s][A
Attribute combinations: 506it [00:01, 498.03it/s][A
Attribute combinations: 556it [00:01, 485.32it/s][A
Attribute combinations: 605it [00:01, 479.60it/s][A
Attribute combinations: 670it [00:01, 487.46it/s][A


Processing valid cases...



Processing case pairs:   0%|          | 0/530 [00:00<?, ?it/s][A
Processing case pairs:   8%|▊         | 44/530 [00:00<00:01, 430.25it/s][A
Processing case pairs:  17%|█▋        | 88/530 [00:00<00:01, 416.23it/s][A
Processing case pairs:  26%|██▌       | 137/530 [00:00<00:00, 445.05it/s][A
Processing case pairs:  36%|███▌      | 191/530 [00:00<00:00, 479.92it/s][A
Processing case pairs:  46%|████▌     | 244/530 [00:00<00:00, 497.11it/s][A
Processing case pairs:  55%|█████▌    | 294/530 [00:00<00:00, 491.83it/s][A
Processing case pairs:  65%|██████▍   | 344/530 [00:00<00:00, 494.36it/s][A
Processing case pairs:  76%|███████▌  | 404/530 [00:00<00:00, 526.80it/s][A
Processing case pairs:  88%|████████▊ | 466/530 [00:00<00:00, 555.50it/s][A
Processing case pairs: 100%|██████████| 530/530 [00:01<00:00, 500.85it/s][A
Running experiments for bank: 100%|██████████| 1/1 [12:38<00:00, 758.95s/it]
