In [54]:
import sys 
import pandas as pd
import os
import numpy as np
import copy
from itertools import product


%load_ext autoreload

# Enable autoreload
%autoreload 2
import matplotlib.pyplot as plt
import seaborn as sns


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
from evaluation_framework.utils_eval import generate_train_test_splits, generate_synthetic_datasets, synthetic_results, real_results
from run_experiment import generate_folders

from tabular_datasets.dataset import Dataset

# --- Experiment configuration -------------------------------------------------
dataset_name = "adult"
results_folder = "results"
split = 1      # index of the train split to load (real and synthetic)
lamda = 1.0    # encoded in synthetic file names as round(100 * lamda), e.g. "_lamda_100_"

# Methods to compare and their plot colors; the lists are parallel
# (synthetic_df_colors[k] is the color of generative_methods[k]).
generative_methods = ['tabfairgdt', 'tabular_argn', 'tab_fair_gan', 'cuts', 'fsmote', 'prefair']
synthetic_df_colors = ['blue', 'orange',  'red',  'green', 'brown', 'yellow']

# Methods whose synthetic files carry no "_lamda_" tag in their name.
no_lamda_methods = ['cuts', 'fsmote', 'prefair']

# --- Load one synthetic training split per method (in generative_methods order) ---
synthetic_dfs = []
for generative_method in generative_methods:
    synthetic_path, results_path_avg, results_path_std = generate_folders(dataset_name, generative_method, results_folder)

    experiment_name = (
        f"{generative_method}_0.json"
        if generative_method in no_lamda_methods
        else f"{generative_method}_lamda_{100*lamda:.0f}_0.json"
    )
    file_path = f'{synthetic_path}/split_{split}/{experiment_name}'
    print(file_path)

    synthetic_data = pd.read_json(file_path, orient='records', lines=True)
    synthetic_dfs.append(synthetic_data)

# Same list object as generative_methods; used as the per-dataset keys in the plot legend.
synthetic_df_names = generative_methods

# --- Real training split and column metadata -----------------------------------
dataset_loader = Dataset(dataset_name)

real_split_file_path = f'tabular_datasets/{dataset_name}/train_test_splits/train/train_{split}.json'
real_split_data = pd.read_json(real_split_file_path, orient='records', lines=True)

dtype_map = dataset_loader.dtype_map

results/adult/tabfairgdt/synthetic_train_splits/split_1/tabfairgdt_lamda_100_0.json
results/adult/tabular_argn/synthetic_train_splits/split_1/tabular_argn_lamda_100_0.json
results/adult/tab_fair_gan/synthetic_train_splits/split_1/tab_fair_gan_lamda_100_0.json
results/adult/cuts/synthetic_train_splits/split_1/cuts_0.json
results/adult/fsmote/synthetic_train_splits/split_1/fsmote_0.json
results/adult/prefair/synthetic_train_splits/split_1/prefair_0.json
tabular_datasets/adult/adult.json
Dataset adult has ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'native-country'] categorical and ['age', 'capital-gain', 'capital-loss', 'hours-per-week'] numerical columns.


In [56]:
def _category_counts(values, categories):
    """Return how often each entry of ``categories`` occurs in the Series ``values``.

    Categories absent from ``values`` get 0. A single ``value_counts`` pass
    replaces a per-category boolean mask plus ``unique()`` scan (O(n) instead
    of O(n * n_categories)).
    """
    counts = values.value_counts().to_dict()
    return [counts.get(cat, 0) for cat in categories]


def _group_separator_positions(n_categories, n_synthetic, bar_width):
    """Return x positions of the dashed lines framing each category's bar group.

    The first line sits just left of the first (real) bar; line ``i + 1`` sits
    just right of the last synthetic bar of category ``i``. Consecutive entries
    therefore bound category ``i``'s group.
    """
    first = -0.8 * bar_width
    rest = [i + n_synthetic * bar_width + 0.8 * bar_width for i in range(n_categories)]
    return [first] + rest


def plot_feature_distribution_protected(real_df, synthetic_dfs, synthetic_df_names, synthetic_df_colors, dtype_map, protected_attribute="sex", column='relationship', save_path=None):
    """Plot per-group category counts of ``column`` for real vs. synthetic data and save the figure.

    One panel is drawn per value of ``protected_attribute`` found in ``real_df``
    (pandas ``unique`` order, i.e. order of first appearance). In each panel,
    every category of ``column`` (union over all datasets, sorted) gets a group
    of bars: the real count (black, hatched) followed by one bar per synthetic
    dataset. The y-axis is log-scaled. The figure is written to ``save_path``
    and closed; nothing is displayed inline.

    Figure-specific emphasis (kept from the original layout): the first panel
    shades and bolds its last category, the second panel shades and bolds its
    first category; any further panels only bold their first category.

    Parameters
    ----------
    real_df : pandas.DataFrame
        Real data; must contain ``protected_attribute`` and ``column``.
    synthetic_dfs : list of pandas.DataFrame
        Synthetic datasets with the same two columns.
    synthetic_df_names : list of str
        Method key for each synthetic dataset; mapped to a display name via the
        table below (unknown keys are shown verbatim).
    synthetic_df_colors : list
        Bar color for each synthetic dataset (indexed like ``synthetic_dfs``).
    dtype_map : Mapping
        Column -> dtype mapping; only used to reject an unknown ``column``.
    protected_attribute : str
        Column whose values split the data into panels.
    column : str
        Column whose value distribution is plotted.
    save_path : str, optional
        Output image path. Defaults to ``plots/{column}_distribution_new.png``
        (for the default column this is the previous hard-coded path). The
        parent directory is created if missing.

    Raises
    ------
    KeyError
        If ``column`` is not in ``dtype_map``.
    """
    # Imported here so the function does not depend on a later notebook cell.
    import matplotlib.patches as mpatches

    method_names = {"real": r'Real Data', "tabfairgdt": r"$\mathbf{T}$$_{\mathrm{\mathbf{AB}}}$$\mathbf{F}$$_{\mathrm{\mathbf{AIR}}}$$\mathbf{GDT}$", "tabular_argn": r'TabularARGN', "tab_fair_gan": r'TabFairGAN', "cuts":r'CuTS', "fsmote":r'FSMOTE', "prefair": r'PreFair'}

    if column not in dtype_map:
        raise KeyError(column)

    if save_path is None:
        save_path = os.path.join("plots", f"{column}_distribution_new.png")

    bar_width = 0.12
    n_synthetic = len(synthetic_dfs)

    unique_vals = real_df[protected_attribute].unique()
    n_groups = len(unique_vals)

    # Keep the original two-panel layout; add panels instead of failing when there are more groups.
    ncols = max(2, n_groups)
    fig, axes = plt.subplots(nrows=1, ncols=ncols, figsize=(12 * ncols, 6), squeeze=False)
    axes = axes.ravel()

    # Legend entries: real data first, then one patch per synthetic dataset.
    handles = [mpatches.Patch(color='black', alpha=0.7, hatch='//', label='Real')]
    labels = ['Real']
    for k in range(n_synthetic):
        key = synthetic_df_names[k]
        display_name = method_names.get(key, key)
        handles.append(mpatches.Patch(color=synthetic_df_colors[k], alpha=1, label=display_name))
        labels.append(display_name)

    try:
        for j, (ax, group_value) in enumerate(zip(axes, unique_vals)):
            real_subgroup = real_df.loc[real_df[protected_attribute] == group_value, column]
            synthetic_subgroups = [df.loc[df[protected_attribute] == group_value, column] for df in synthetic_dfs]

            # Union of categories seen in any dataset, sorted for a consistent x order.
            all_categories = set(real_subgroup.unique())
            for subgroup in synthetic_subgroups:
                all_categories.update(subgroup.unique())
            all_categories = sorted(all_categories)
            positions = range(len(all_categories))

            separators = _group_separator_positions(len(all_categories), n_synthetic, bar_width)
            for x in separators:
                ax.axvline(x=x, color='gray', linestyle='--', alpha=0.6, linewidth=1)

            if j == 0:
                ax.set_ylabel('Num. Samples', fontsize=25)

            # Shade one category group; zorder=-1 keeps the shading behind the bars (zorder=3).
            if len(separators) >= 2:
                if j == 0:
                    ax.axvspan(separators[-2], separators[-1], alpha=0.1, color='red', label='Highlighted Area', zorder=-1)
                elif j == 1:
                    ax.axvspan(separators[0], separators[1], alpha=0.1, color='red', label='Highlighted Area', zorder=-1)

            ax.bar(positions, _category_counts(real_subgroup, all_categories), bar_width,
                   color='black', alpha=0.7, hatch='//', zorder=3)
            for k, subgroup in enumerate(synthetic_subgroups):
                offsets = [p + (k + 1) * bar_width for p in positions]
                ax.bar(offsets, _category_counts(subgroup, all_categories), bar_width,
                       color=synthetic_df_colors[k], alpha=1, zorder=3)

            # Tick at the centre of each group (real bar + n_synthetic bars).
            ax.set_xticks([p + n_synthetic * bar_width / 2 for p in positions])
            ax.set_xticklabels(all_categories, rotation=30, ha='center')

            ax.set_yscale('log')
            ax.set_title(f'Distribution of {column} : $\\bf{{{protected_attribute}={group_value}}}$', fontsize=30)
            ax.tick_params(axis='both', which='major', labelsize=20)
            ax.tick_params(axis='x', which='major', labelsize=25)

            tick_labels = ax.get_xticklabels()
            if tick_labels:
                tick_labels[-1 if j == 0 else 0].set_fontweight('bold')

        # Drop panels that received no group.
        for ax in axes[n_groups:]:
            fig.delaxes(ax)

        # Horizontal legend centred above the panels.
        fig.legend(handles=handles,
                   labels=labels,
                   loc='upper center',
                   fontsize=25,
                   bbox_to_anchor=(0.5, 1.1),
                   ncol=len(handles),
                   frameon=True,
                   borderaxespad=0.)

        fig.tight_layout()

        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


# Don't forget to import matplotlib.patches
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

# Render the relationship distribution split by sex; the function saves the figure
# to disk and closes it, so nothing is shown inline.
plot_feature_distribution_protected(
    real_df=real_split_data,
    synthetic_dfs=synthetic_dfs,
    synthetic_df_names=synthetic_df_names,
    synthetic_df_colors=synthetic_df_colors,
    dtype_map=dtype_map,
    protected_attribute="sex",
)