In [17]:
import pickle
import sys
from logging import basicConfig, INFO
from pathlib import Path

import ipywidgets as widgets
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import ray
import seaborn as sns
from IPython.display import display

from simulation import *

# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to source files on disk are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Typeset all figure text with LaTeX in a serif face (matplotlib needs a working
# LaTeX installation when "text.usetex" is enabled).
mpl.rcParams["text.usetex"] = True
mpl.rcParams["font.family"] = "serif"
mpl.rcParams["font.serif"] = ["Palatino"]  # Or Times, Computer Modern

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Evaluation

In [18]:
import os
import pandas as pd

# Load every CSV found directly inside the results directory and stack them into a
# single frame with a fresh 0..n-1 index.
subdir = Path('./sims_v4')

# glob() yields paths in filesystem-dependent order; sorting keeps the row order of
# `df` (and anything derived from order of appearance) reproducible between runs.
csv_paths = sorted(subdir.glob("*.csv"))
if not csv_paths:
    # Without this guard pd.concat([]) fails with an opaque "No objects to concatenate".
    raise FileNotFoundError(f"No CSV files found in {subdir.resolve()}")

dataframes = [pd.read_csv(csv) for csv in csv_paths]
df = pd.concat(dataframes, ignore_index=True)

In [19]:
def add_util_perc(df: pd.DataFrame, col: str) -> pd.DataFrame:
    """Add ``<col>_perc_bias``: ``col`` relative to the no-quota value of the same run.

    For every distinct ``gender_bias`` value, the rows whose ``quota`` equals
    ``QuotaType.NONE.name`` act as the baseline. Each row with that bias gets its
    ``col`` value divided by the baseline ``col`` value that shares its ``id``.

    Args:
        df: Frame with at least the columns ``id``, ``quota``, ``gender_bias``
            and ``col``. It is not modified; the work happens on a copy.
        col: Name of the utility column to normalise.

    Returns:
        A copy of ``df`` with the extra column ``f'{col}_perc_bias'``. Rows whose
        ``id`` has no baseline row for their bias level end up as NaN.
    """
    result = df.copy()
    ratio_col = f'{col}_perc_bias'
    is_baseline = result['quota'] == QuotaType.NONE.name

    for bias_level in result['gender_bias'].unique():
        has_bias = result['gender_bias'] == bias_level
        # Baseline utility looked up by run id.
        # NOTE(review): assumes each id occurs at most once among the baseline rows
        # of a bias level - confirm against the simulation output.
        baseline_by_id = result[is_baseline & has_bias].set_index('id')[col]
        ratio = result[col] / result['id'].map(baseline_by_id)
        # Index alignment on .loc keeps only the ratios of rows with this bias level.
        result.loc[has_bias, ratio_col] = ratio

    return result

# Add a '<col>_perc_bias' column for the total utility and for each group's utility.
for util_col in ('total_util', 'g0_util', 'g1_util'):
    df = add_util_perc(df, util_col)

# Visualisation

In [None]:
# Columns offered as dropdown filters; each selection restricts the plotted rows
# to one value of that column.
filters = [
    'n_positions',
    'n_persons',
    'gender_bias',
    'total_cap',
    'alpha_caps',
    'alpha_prefs',
]

# Columns plotted on the y-axis (one hue per column).
# NOTE(review): add_util_perc() writes columns named '<col>_perc_bias', while these
# names end in '_perc' - confirm that the '_perc' columns come from the CSV files.
targets = [
    'total_util_perc',
    'g0_util_perc',
    'g1_util_perc',
]


def visualise(df: pd.DataFrame) -> None:
    """Scatter the columns listed in ``targets`` against ``tvd``, one facet per quota.

    The frame is melted to long form so that every column in the module-level
    ``targets`` list becomes one hue level. Each facet shows the raw points plus a
    robust regression line per hue. The figure is written to
    ``./scatter_plot.pdf`` and then shown.

    Args:
        df: Frame containing the columns ``tvd``, ``quota`` and every column
            named in ``targets``.
    """
    non_finite_value = 100  # high number for infinity.

    df_melted = df.melt(id_vars=["tvd", "quota"], value_vars=targets, var_name="target", value_name="value")
    non_finite = ~(np.isfinite(df_melted["tvd"]) & np.isfinite(df_melted["value"]))
    # Overwrite only the numeric coordinates. Assigning to whole rows would also
    # replace the 'quota' and 'target' labels with 100 and create a bogus facet/hue.
    df_melted.loc[non_finite, ["tvd", "value"]] = non_finite_value

    custom_titles = {'EQU50': r'Quota: $q = 50\%$',
                     'GTE40': r'Quota: $q\geq 40\%$',
                     'GTE30': r'Quota: $q\geq 30\%$',
                     'GTE20': r'Quota: $q\geq 20\%$',
                     'NONE': r'Quota: $q\geq0\%$',
                     'PREF': r'Quota: $q_{\mathrm{pref}}$'}

    custom_legend_labels = {
        "total_util_perc": "Total relative efficiency",
        "g0_util_perc": "Relative efficiency males",
        "g1_util_perc": "Relative efficiency females",
    }

    # Plot with FacetGrid: points first, then a regression line per hue on top.
    g = sns.FacetGrid(df_melted, col="quota", hue="target", palette="Set2", col_wrap=2, height=4, aspect=1, margin_titles=True)
    g.map(sns.scatterplot, "tvd", "value", alpha=0.5)
    g.map(sns.regplot, "tvd", "value", scatter=False, ci=None, robust=True)  # Regression line without scatter
    g.fig.subplots_adjust(hspace=0.4, wspace=0.3)

    # Clip the y-axis to the central 90% of the values, widened by a small factor,
    # so that a few extreme points do not flatten the rest of the plot.
    spread = 1.1
    y_min = df_melted['value'].quantile(0.05) / spread
    y_max = df_melted['value'].quantile(0.95) * spread
    g.set(ylim=(y_min, y_max))

    # Set custom titles for each subplot; quota names without a custom title keep
    # the plain "{col_name}" title.
    g.set_titles(template="{col_name}", size=16)
    for ax, col_value in zip(g.axes.flat, g.col_names):
        if col_value in custom_titles:
            ax.set_title(custom_titles[col_value], fontsize=16)

    # Set axis labels
    g.set_axis_labels(
        r"$\textrm{TVD}$", r"$\textrm{Relative Efficiency}$", fontsize=18
    )

    # Increase tick label sizes and limit the number of ticks per axis
    for ax in g.axes.flat:
        ax.tick_params(axis='x', labelsize=14)
        ax.tick_params(axis='y', labelsize=14)
        ax.xaxis.set_major_locator(plt.MaxNLocator(5))  # Adjust number of ticks
        ax.yaxis.set_major_locator(plt.MaxNLocator(5))  # Adjust number of ticks

    # Build the shared legend from the regression lines of the first facet. Handles
    # and labels are filtered together so each line keeps its own label, whatever
    # order matplotlib reports the scatter and line artists in.
    handles, labels = g.axes.flat[0].get_legend_handles_labels()
    line_entries = [
        (handle, custom_legend_labels.get(label, label))
        for handle, label in zip(handles, labels)
        if isinstance(handle, mpl.lines.Line2D)
    ]
    line_handles = [handle for handle, _ in line_entries]
    new_labels = [label for _, label in line_entries]
    g.fig.legend(
        line_handles,
        new_labels,
        title="",
        title_fontsize=14,
        fontsize=14,
        loc='lower center',
        bbox_to_anchor=(0.49, -0.05),
        ncol=3,
        frameon=False,
        handletextpad=0.5,
        columnspacing=0.8
    )
    sns.despine(offset=10, trim=True)
    plt.savefig('./scatter_plot.pdf', format='pdf', bbox_inches='tight')
    plt.show()


def filter_and_visualize(**kwargs):
    """Plot the rows of the module-level ``df`` that match every given filter.

    Args:
        **kwargs: ``column=value`` pairs; a row is kept only if it equals
            ``value`` in ``column`` for every pair.

    Prints a message instead of plotting when no row survives the filters.
    """
    selection = df.copy()
    for column, wanted in kwargs.items():
        matches = selection[column] == wanted
        selection = selection[matches]

    # Nothing left to draw: tell the user rather than plotting an empty grid.
    if selection.empty:
        print("No data matches the filter criteria.")
        return

    visualise(selection)

# One dropdown per filter column, listing that column's distinct values in sorted order.
dropdowns = dict(
    (name, widgets.Dropdown(options=sorted(df[name].unique()), description=name))
    for name in filters
)

# Put the controls in one row; interactive_output calls filter_and_visualize with
# the current dropdown values and captures what it renders.
ui = widgets.HBox([*dropdowns.values()])
out = widgets.interactive_output(filter_and_visualize, dropdowns)

# Show the controls followed by the plot output area.
display(ui, out)

HBox(children=(Dropdown(description='n_positions', options=(np.int64(5), np.int64(10)), value=np.int64(5)), Dr…

Output()