In [None]:
import sys 
sys.path.append("..")
from src.dataset import Dataset
import pandas as pd
import os

In [None]:
# Project-level Dataset wrapper for the "adult" dataset. Later cells read its
# `original_dataframe`, `original_dataframe_encoded` and `original_mappings`
# attributes and call `split_population`, `train_synthesizer` and `generate_data`.
adult_dataset_generator = Dataset("adult")

In [None]:
# Column(s) whose values define the sub-populations to split on.
protected_attributes = ["sex"]
# `split_dfs` is consumed below as a mapping {group key -> DataFrame};
# `additional_sizes` is not read again in the cells shown here.
split_dfs, additional_sizes = adult_dataset_generator.split_population(adult_dataset_generator.original_dataframe, protected_attributes)

# Text dump of the raw frame and of the frame exposed as `original_dataframe_encoded`.
# NOTE(review): `.head()` as the last expression would give a shorter, rich display.
print(adult_dataset_generator.original_dataframe)
print(adult_dataset_generator.original_dataframe_encoded)

In [None]:
# Show the entry stored under "race" in the generator's `original_mappings`
# (presumably the category encoding for that column -- confirm in src.dataset).
print(adult_dataset_generator.original_mappings["race"])

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# One colour per income category, in the same order as `income_categories`.
colors = ['#1f77b4', '#ff7f0e']
income_categories = ['<=50K', '>50K']

# Share (in percent) of every income category inside each split.
# Categories that do not occur in a split count as 0 %.
income_percentages = {}
for split_key, split_df in split_dfs.items():
    income_counts = split_df['income'].value_counts(normalize=True) * 100
    income_percentages[split_key] = [income_counts.get(cat, 0) for cat in income_categories]

# One stacked bar per split: total height = number of rows in the split,
# segments = rows belonging to each income category.
fig, ax = plt.subplots(figsize=(10, 6))
for split_index, (split_key, split_df) in enumerate(split_dfs.items()):
    # Convert the percentages back to absolute row counts for the bar heights.
    segment_heights = np.array(income_percentages[split_key]) / 100 * len(split_df)
    bottom = 0
    for category, color, height in zip(income_categories, colors, segment_heights):
        # Attach the legend label to the first split's segments only, so that every
        # income category yields exactly one legend entry that is explicitly tied to
        # the artist drawn in its colour (instead of relying on artist draw order).
        legend_kwargs = {'label': category} if split_index == 0 else {}
        ax.bar(split_key, height, bottom=bottom, color=color, **legend_kwargs)
        bottom += height

ax.set_xlabel(', '.join(protected_attributes))
ax.set_ylabel('Population Size')
ax.set_title('Population Sizes of Split DataFrames')
ax.tick_params(axis='x', labelrotation=45)

# Legend entries come from the labelled bar segments above.
ax.legend(loc='upper right')
fig.tight_layout()
plt.show()


In [None]:
# # Get the DataFrame with the maximum length
# max_length_df_key = max(split_dfs, key=lambda x: len(split_dfs[x]))
# print(max_length_df_key)
# # Retrieve the DataFrame using the key
# max_length_df = split_dfs[max_length_df_key]

# max_length_df_class_counts = max_length_df['income'].value_counts()

# max_length_df_majority_class = max_length_df_class_counts.idxmax()
# max_length_df_majority_class_count = max_length_df_class_counts[max_length_df_majority_class]

# target_count_all_protected = 2 * max_length_df_majority_class_count

# print(max_length_df_class_counts)

In [None]:
# # Print split information
# for split_key, split_df in split_dfs.items():
#     class_counts = split_df['income'].value_counts()
#     if len(class_counts) >= 2:  # Check if there are at least two classes
#         print(f"For split '{split_key}':")
#         for class_label, class_count in class_counts.items():
#             minority_class_count = class_count
#             imbalance = max_length_df_majority_class_count - minority_class_count
#             instances_needed = imbalance 
#             total = minority_class_count + instances_needed
#             print(f"  Class '{class_label}' has {class_count} instances, and {instances_needed} instances are needed for total to be {total}.")


In [None]:
_SAMPLING_METHODS = ("class_protected", "protected", "class", "same_class")


def _synthesize_rows(generator, train_df, size, generative_method, generative_seed):
    """Train a synthesizer on ``train_df`` and return ``size`` generated rows."""
    if generative_method == "tvae":
        # "tvae" is trained on the un-encoded frame and sampled without decoding.
        synthesizer = generator.train_synthesizer("tvae", train_df, encode=False, random_state=generative_seed)
        return generator.generate_data(synthesizer, num=size, name="tvae", decode=False, random_state=generative_seed)
    synthesizer = generator.train_synthesizer(generative_method, train_df, encode=True, random_state=generative_seed)
    return generator.generate_data(synthesizer, num=size, random_state=generative_seed)


def _rows_needed(split_df, sampling_method, reference_size, reference_counts, target_column):
    """Return ``(class_label, n_rows)`` pairs describing what one split is missing.

    ``class_label`` is ``None`` for the "protected" method, where rows are generated
    for the split as a whole rather than per class. ``n_rows`` may be <= 0, meaning
    nothing has to be generated for that entry.
    """
    if sampling_method == "protected":
        return [(None, reference_size - len(split_df))]

    class_counts = split_df[target_column].value_counts()

    if sampling_method == "class_protected":
        target_count = reference_counts.max()
        return [(label, int(target_count - count)) for label, count in class_counts.items()]

    if sampling_method == "class":
        target_count = class_counts.max()
        return [(label, int(target_count - count)) for label, count in class_counts.items()]

    # "same_class": find x such that (count + x) / (split_total + x) equals the class
    # share in the reference split, i.e.
    #   x = (ref_count * split_total - count * ref_total) / (ref_total - ref_count).
    # Integer arithmetic keeps the floor exact (a float quotient such as 2.999... would
    # otherwise truncate to 2).
    reference_total = int(reference_counts.sum())
    split_total = int(class_counts.sum())
    plan = []
    for label, count in class_counts.items():
        reference_count = int(reference_counts.get(label, 0))  # class may be absent from the reference
        missing = reference_count * split_total - int(count) * reference_total
        other_reference_rows = reference_total - reference_count
        if missing <= 0 or other_reference_rows == 0:
            # Already at/above the reference share, or the reference share is 100 %,
            # which cannot be reached by adding rows while other classes exist.
            plan.append((label, 0))
        else:
            plan.append((label, missing // other_reference_rows))
    return plan


def get_synthetic_splits(split_dfs, generative_method="cart", generative_seed=0, return_plot=False,
                         sampling_method="class_protected", dataset_generator=None,
                         target_column="income", protected_column="sex"):
    """Augment every split with synthetic rows according to ``sampling_method``.

    The largest split (by row count; the first one on ties) acts as the reference.

    Sampling methods:
        * ``"class_protected"`` -- every class of every split is topped up to the
          majority-class count of the reference split.
        * ``"protected"`` -- every split is topped up to the row count of the reference
          split; the synthesizer is trained on the whole split (target column included).
        * ``"class"`` -- every class is topped up to the majority-class count of its
          own split.
        * ``"same_class"`` -- rows are added per class so that the class share inside
          the split matches its share inside the reference split.

    For the per-class methods the synthesizer is trained on the rows of that class
    with ``target_column`` and ``protected_column`` removed; both columns are then
    written onto the generated rows (class label and split key respectively).

    Args:
        split_dfs: Mapping ``{split key -> DataFrame}``; the key is written into
            ``protected_column`` of the generated rows.
        generative_method: Synthesizer name passed to ``train_synthesizer``.
        generative_seed: Passed as ``random_state`` to training and generation.
        return_plot: When true, additionally return copies of all frames carrying a
            ``"method"`` column set to ``"real"`` or ``"synthetic"``.
        sampling_method: One of the four strategies listed above.
        dataset_generator: Object providing ``train_synthesizer`` and ``generate_data``.
            Defaults to the notebook-level ``adult_dataset_generator``.
        target_column: Name of the class column (default ``"income"``).
        protected_column: Name of the column the splits were built on (default ``"sex"``).

    Returns:
        ``augmented_dfs`` -- a list holding each original split followed by its
        synthetic frames -- or ``(augmented_dfs, augmented_dfs_plot)`` when
        ``return_plot`` is true.

    Raises:
        ValueError: If ``sampling_method`` is unknown or ``split_dfs`` is empty.
    """
    if sampling_method not in _SAMPLING_METHODS:
        raise ValueError(
            f"Unknown sampling_method {sampling_method!r}; expected one of {_SAMPLING_METHODS}"
        )
    if not split_dfs:
        raise ValueError("split_dfs must contain at least one split")

    generator = adult_dataset_generator if dataset_generator is None else dataset_generator

    # Reference split: the one with the most rows.
    reference_df = max(split_dfs.values(), key=len)
    reference_size = len(reference_df)
    reference_counts = reference_df[target_column].value_counts()

    augmented_dfs = []
    augmented_dfs_plot = []

    for split_key, split_df in split_dfs.items():
        augmented_dfs.append(split_df)
        if return_plot:
            augmented_dfs_plot.append(split_df.assign(method="real"))

        plan = _rows_needed(split_df, sampling_method, reference_size, reference_counts, target_column)
        for class_label, size in plan:
            if size <= 0:
                continue

            if class_label is None:
                # Whole-split generation: keep the target column so it is synthesized too.
                train_df = split_df.drop(columns=[protected_column])
            else:
                class_rows = split_df[split_df[target_column] == class_label]
                train_df = class_rows.drop(columns=[target_column, protected_column])

            synthetic = _synthesize_rows(generator, train_df, size, generative_method, generative_seed).copy()
            if class_label is not None:
                synthetic[target_column] = class_label
            synthetic[protected_column] = split_key

            augmented_dfs.append(synthetic)
            if return_plot:
                augmented_dfs_plot.append(synthetic.assign(method="synthetic"))

    if return_plot:
        return augmented_dfs, augmented_dfs_plot
    return augmented_dfs

In [None]:
# Unzip the split mapping into parallel tuples of keys and frames
# (neither name is read again in the cells shown here).
split_df_keys, split_df_vals = zip(*split_dfs.items())

# Augment the splits with CART-generated rows using the "same_class" strategy;
# return_plot=True also yields copies tagged with a "method" column for plotting.
augmented_dfs, augmented_dfs_plot = get_synthetic_splits(split_dfs, generative_method="cart", generative_seed=0, return_plot=True, sampling_method="same_class")


In [None]:
# Stack the real and synthetic frames into one dataset. The original row indices are
# kept (no ignore_index), so index labels may repeat across the stacked frames.
final_augmented_dataset = pd.concat(augmented_dfs)
final_augmented_dataset_plot = pd.concat(augmented_dfs_plot)
# Combined label such as "<=50K (real)" / ">50K (synthetic)" built from the two columns.
final_augmented_dataset_plot['income_method'] = final_augmented_dataset_plot['income'] + ' (' + final_augmented_dataset_plot['method'] + ')'



In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt


def plot_income_by_sex(data, ax, title, palette):
    """Draw a stacked count histogram of ``income`` per ``sex`` value on ``ax``."""
    sns.histplot(data=data, x='sex', hue='income', palette=palette, hue_order=['>50K', '<=50K'],
                 multiple="stack", ax=ax, discrete=True)
    ax.set_title(title)
    ax.set_xlabel('Gender')
    ax.set_ylabel('Count')


def annotate_stack_percentages(ax):
    """Label every non-empty bar segment with its share of the stack it belongs to.

    Segments are grouped into stacks by their left x-coordinate. Stack totals are
    accumulated in a single pass, and empty segments/stacks are skipped so that no
    division by zero (``nan%``) or overlapping ``0.0%`` label is produced.
    """
    stack_totals = {}
    for bar in ax.patches:
        stack_totals[bar.get_x()] = stack_totals.get(bar.get_x(), 0) + bar.get_height()

    for bar in ax.patches:
        total_height = stack_totals[bar.get_x()]
        if total_height == 0 or bar.get_height() == 0:
            continue
        percentage = (bar.get_height() / total_height) * 100
        ax.annotate(f'{percentage:.1f}%',
                    xy=(bar.get_x() + bar.get_width() / 2, bar.get_y() + bar.get_height() / 2),
                    xytext=(0, 0),  # no offset: the text is anchored at the segment centre
                    textcoords="offset points",
                    ha='center', va='bottom', fontsize=8)


# Side-by-side comparison (shared y axis) of the original and the augmented dataset.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 6), sharey=True)

# Two colours from Seaborn's 'husl' palette, one per income class.
palette = sns.color_palette("husl", 2)

gender_order = ['Female', 'Male']  # Adjust as per your actual category order

plot_income_by_sex(adult_dataset_generator.original_dataframe, axes[0], 'Adult', palette)
plot_income_by_sex(final_augmented_dataset_plot, axes[1], 'Adult augmented', palette)

# Adjust layout
plt.tight_layout()

for ax in axes:
    annotate_stack_percentages(ax)

plt.show()


In [None]:
# Cell output: (rows, columns) of the augmented dataset after concatenation.
final_augmented_dataset.shape