In [1]:
import seaborn as sns
import pandas as pd
from pathlib import Path
import sys
from logging import basicConfig, INFO
import ray
import pickle
import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt
import matplotlib as mpl

from simulation import *

%load_ext autoreload
%autoreload 2

mpl.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Palatino"],  # Or Times, Computer Modern
})

# Evaluation

In [8]:
import os
import pandas as pd

# Load every simulation result CSV found directly inside `subdir` and stack
# them into a single frame with a fresh 0..n-1 index.
subdir = Path('./sims_v4')

# Path.glob() yields entries in an arbitrary, filesystem-dependent order.
# Sorting makes the row order of the combined frame reproducible across runs
# and machines.
csv_paths = sorted(subdir.glob("*.csv"))
if not csv_paths:
    # Without this guard pd.concat([]) fails with a generic
    # "No objects to concatenate", which hides the real cause (wrong working
    # directory or missing simulation output).
    raise FileNotFoundError(f"No CSV files found in {subdir.resolve()}")

dataframes = [pd.read_csv(csv) for csv in csv_paths]
df = pd.concat(dataframes, ignore_index=True)

In [9]:
def add_util_perc(df: pd.DataFrame, col: str) -> pd.DataFrame:
    """Add ``f'{col}_perc_bias'``: ``col`` relative to the no-quota baseline.

    For every distinct ``gender_bias`` value, the rows whose ``quota`` equals
    ``QuotaType.NONE.name`` serve as the baseline. Each row with that bias is
    matched to the baseline row with the same ``id``, and its ``col`` value is
    divided by the baseline's ``col`` value.

    Args:
        df: Frame with at least the columns ``gender_bias``, ``quota``,
            ``id`` and ``col``. It is not modified.
        col: Name of the column to express relative to the baseline.

    Returns:
        A copy of ``df`` with the extra column ``f'{col}_perc_bias'``. Rows
        whose ``id`` has no baseline entry for their bias level get NaN.
    """
    result = df.copy()
    ratio_col = f'{col}_perc_bias'
    is_baseline = result['quota'] == QuotaType.NONE.name

    for bias_level in result['gender_bias'].unique():
        has_bias = result['gender_bias'] == bias_level
        # Baseline value of `col`, keyed by id, for this bias level only.
        baseline_by_id = result.loc[is_baseline & has_bias].set_index('id')[col]
        baseline_per_row = result['id'].map(baseline_by_id)
        # The right-hand side covers the whole frame; .loc aligns on the index
        # and writes only the rows belonging to this bias level.
        result.loc[has_bias, ratio_col] = result[col] / baseline_per_row

    return result

# Add a baseline-relative `<name>_perc_bias` column for each utility measure.
for util_col in ('total_util', 'g0_util', 'g1_util'):
    df = add_util_perc(df, util_col)

# Visualisation

In [11]:
def filter(x):
    """Bucket a numeric value into a similarity label.

    NOTE(review): this name shadows the builtin ``filter`` for the rest of
    the notebook session.

    Args:
        x: A number; the notebook applies this to the ``tvd`` column.

    Returns:
        ``'low'`` when ``x >= 0.8``, ``'high'`` when ``x <= 0.2``, and
        ``'medium'`` otherwise (including NaN, for which both comparisons
        are false).
    """
    if x >= 0.8:
        return 'low'
    if x <= 0.2:
        return 'high'
    return 'medium'

# Label every row with its similarity bucket. The column key keeps its
# existing spelling because the plotting code reads it as the hue variable.
df['similiarity'] = df['tvd'].apply(filter)

# Keep only the two extreme buckets; 'medium' rows are discarded.
not_medium = df['similiarity'].ne('medium')
df = df.loc[not_medium]

In [None]:
# Exclude the GTE30 quota from everything plotted below.
df = df[df['quota'].ne('GTE30')]

# Columns exposed as equality filters in the dropdown UI.
filters = ['n_positions', 'n_persons', 'total_cap']

# Columns selectable as the plotted quantity.
# NOTE(review): the `*_perc` columns without the `_bias` suffix are not
# created in the cells visible here; presumably they come from the CSVs —
# confirm.
targets = [
    'total_util_perc',
    'g0_util_perc',
    'g1_util_perc',
    'total_util_perc_bias',
    'g0_util_perc_bias',
    'g1_util_perc_bias',
]
# Alternative (raw utilities): targets = ['total_util', 'g0_util', 'g1_util']

def visualise(df, target):
    """Draw a split violin plot of ``target`` per quota and save it as a PDF.

    The violins are grouped by the ``quota`` column and split by the
    ``similiarity`` column. Known quota names on the x-axis are replaced by
    LaTeX labels; any other name is shown unchanged. The figure is written to
    ``./violin_plot.pdf`` on every call and then shown.

    Args:
        df: Frame containing ``quota``, ``similiarity`` and ``target``.
        target: Name of the column plotted on the y-axis.
    """
    quota_tex = {
        'EQU50': r'$= 50\%$',
        'GTE40': r'$\geq 40\%$',
        'GTE20': r'$\geq 20\%$',
        'NONE': r'$\geq0\%$',
        'PREF': r'$q_{\mathrm{pref}}$',
    }
    axis_font = 18

    ax = sns.violinplot(
        data=df,
        x='quota',
        y=target,
        hue='similiarity',
        palette="Set3",
        split=True,
        inner="quart",
        fill=True,
    )
    sns.despine(offset=10, trim=True)
    plt.legend(frameon=False, fontsize=12, title=r'$\mathrm{TVD}$', title_fontsize=14)

    # Swap the raw quota names on the x-axis for their LaTeX form.
    relabelled = []
    for tick in ax.get_xticklabels():
        raw_name = tick.get_text()
        relabelled.append(quota_tex.get(raw_name, raw_name))
    ax.set_xticklabels(relabelled, fontsize=axis_font)

    plt.ylabel(r"$\textrm{Relative efficiency}$", fontsize=axis_font)
    plt.xlabel(r"$\textrm{Quota}$", fontsize=axis_font)
    plt.yticks(fontsize=axis_font)
    plt.savefig('./violin_plot.pdf', format='pdf', bbox_inches='tight')
    plt.show()

def filter_and_visualize(**kwargs):
    """Filter the module-level ``df`` by the widget values and plot the result.

    Keyword Args:
        target: Column name passed on to ``visualise`` as the plotted quantity.
        tvd: Threshold for the ``tvd`` column; converted with ``float``.
        tvd_dir: ``'UP'`` keeps rows with ``tvd >= threshold``; any other
            value keeps rows with ``tvd <= threshold``.
        **others: Every remaining ``name=value`` pair keeps only the rows
            where column ``name`` equals ``value``.

    Prints a message instead of plotting when no rows remain.
    """
    target = kwargs.pop('target')
    tvd = kwargs.pop('tvd')
    tvd_dir = kwargs.pop('tvd_dir')

    selection = df.copy()
    threshold = float(tvd)
    tvd_mask = (
        selection['tvd'] >= threshold
        if tvd_dir == 'UP'
        else selection['tvd'] <= threshold
    )
    selection = selection[tvd_mask]

    # Whatever is left in kwargs is a column-name -> required-value filter.
    for column, wanted in kwargs.items():
        selection = selection[selection[column] == wanted]

    if selection.empty:
        print("No data matches the filter criteria.")
        return
    visualise(selection, target)

# One dropdown per equality-filter column; its options are the sorted
# distinct values present in the data.
dropdowns = {
    name: widgets.Dropdown(options=sorted(df[name].unique()), description=name)
    for name in filters
}

# Fixed-choice controls. Insertion order determines the left-to-right layout.
dropdowns.update({
    'target': widgets.Dropdown(options=targets, description='target'),
    'gender_bias': widgets.Dropdown(
        options=['NONE', 'LOW', 'MEDIUM', 'HIGH', 'VERY_HIGH', 'EXTREME'],
        description='gender bias',
    ),
    'tvd': widgets.Dropdown(
        options=['0.0', '0.1', '0.2', '0.3', '0.4', '0.5', '0.6', '0.7', '0.8', '0.9', '1.0'],
        description='tvd',
    ),
    'tvd_dir': widgets.Dropdown(options=['UP', 'DOWN'], description='tvd_dir'),
})

# Lay the controls out in a row and re-run the plot whenever one changes.
ui = widgets.HBox([*dropdowns.values()])
out = widgets.interactive_output(filter_and_visualize, dropdowns)

# Show the controls followed by the plot output area.
display(ui, out)

HBox(children=(Dropdown(description='n_positions', options=(np.int64(5), np.int64(10)), value=np.int64(5)), Dr…

Output()