In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn import datasets as ds
import sys, os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from utils.fairness_utils import pareto_df

In [None]:
# Automatically reload imported modules (e.g. utils.fairness_utils) before code runs,
# so edits to the helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline
print(2)

In [None]:
# Results folder and the identifiers used to build the result-file names.
FOLDER = "./../results_Adultdrifted/"
DATASET_NAME = "ADULT"
NAME2 = "_ROAD_10_100_12_20"

# One CSV per test split: in-distribution, then the 2014 and 2015 shifted sets.
fROAD_in, fROAD2014, fROAD2015 = (
    "TEST_" + DATASET_NAME + NAME2 + suffix
    for suffix in ("Shifted_EO.csv", "2014Shifted_EO.csv", "2015Shifted_EO.csv")
)

# Curve colours shared by all methods.
palette = sns.color_palette("bright", 10)

# Per-method metadata: result file for each split, legend name and curve colour.
METHODS_DICT = {
    "fROAD": {
        "fname_in": fROAD_in,
        "fname_2014": fROAD2014,
        "fname_2015": fROAD2015,
        "name": "ROAD (Ours)",
        "color": palette[0],
    },
}

# Read every test split of every method and store it back in METHODS_DICT under a
# "df*" key derived from the "fname*" key: "fname_in" -> "df_in",
# "fname_2014" -> "df_2014", "fname_2015" -> "df_2015".
for method in METHODS_DICT:
    method_info = METHODS_DICT[method]
    for testfname in ("fname_in", "fname_2014", "fname_2015"):
        df_meth = pd.read_csv(FOLDER + method_info[testfname], index_col=0)

        # Hyper-parameters become short string labels (their first 4 characters).
        # NOTE(review): the truncation is lossy, e.g. 0.001 and 0.002 both map to "0.00".
        for hyper in ("lambda", "tau"):
            df_meth[hyper] = df_meth[hyper].astype(str).str[:4]
        df_meth["run_id"] = df_meth["run_id"].astype(str)

        # "Global EO" column is defined as Global FPR + Global FNR.
        df_meth["Global EO"] = df_meth["Global FPR"] + df_meth["Global FNR"]

        method_info["df" + testfname[len("fname"):]] = df_meth



In [None]:
    
# Human-readable axis labels, keyed by the metric column name in the results frames.
FEATURE_NAMES = dict([
    ("Global DI", "Global Fairness (DI)"),
    ("Global Acc", "Accuracy"),
    ("Global EO", "Global Fairness (EO)"),
    ("top1_DI", "Local Fairness (worst 1 DI)"),
    ("top3_DI", "Local Fairness (worst 3 DI)"),
    ("q_DI_0.8", "Local Fairness (0.8 quantile)"),
])



# Pareto curves

## 1. Accuracy vs. Local Fairness

### 1.a Demographic subgroups

In [None]:
%%time

#### WARNING/ Slow because of PAreto


# Pareto plot Global Acc * top1DI à iso Global DI
sns.set_style('whitegrid')

sns.set_palette("bright")


#### IMPORTANT: this defines the feature to "fix" and the corresponding range. E.g.: keeping only models 
# with GLOBAL DI between 0.0 and 0.05
dataf = "df_in"
iso_f = "Global DI"
GDIRANGE = 0.0, 1.

### FEATURES FOR X AND Y AXES
f1, f2 = "Global EO", "Global Acc"


legend_names = []
for key in METHODS_DICT:
    print( key)
    method = METHODS_DICT[key]
    try:
        df = method[dataf]
        color = method["color"]
    except KeyError:
        print('bug')
        continue
    
    ### PARETO FRONT
    dff = pareto_df(df, f1, f2)
    g = sns.lineplot(data=dff[dff["pareto"]==1], x=f1, y=f2, color=color, linewidth=3, label=method["name"])
    g2 = sns.scatterplot(data=dff[dff["pareto"]==1], x=f1, y=f2, color=color, legend=False, s=30, markers="o")
    
    legend_names.append(method["name"])

g.invert_xaxis()

plt.legend(fontsize=12, loc='lower left', frameon=True)



plt.xlabel(FEATURE_NAMES[f1], fontsize=17)
plt.ylabel(FEATURE_NAMES[f2], fontsize=17)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.ylim((0.79, 0.845))
plt.title(DATASET_NAME + " (In-distribution)", fontsize=17)
plt.tight_layout()


FNAME = "./results_img/"

plt.savefig(FNAME + "adultDRIFT_indistribution_BIS.pdf")


In [None]:
%%time

#### WARNING/ Slow because of PAreto


# Pareto plot Global Acc * top1DI à iso Global DI
sns.set_style('whitegrid')

sns.set_palette("bright")


#### IMPORTANT: this defines the feature to "fix" and the corresponding range. E.g.: keeping only models 
# with GLOBAL DI between 0.0 and 0.05
dataf = "df_2014"
iso_f = "Global DI"
GDIRANGE = 0.0, 1.

### FEATURES FOR X AND Y AXES
f1, f2 = "Global EO", "Global Acc"


legend_names = []
for key in METHODS_DICT:

    print( key)
    method = METHODS_DICT[key]
    try:
        df = method[dataf]
        color = method["color"]
    except KeyError:
        print('bug')
        continue
    
    ### PARETO FRONT
    dff = pareto_df(df, f1, f2)

    g = sns.lineplot(data=dff[dff["pareto"]==1], x=f1, y=f2, color=color, linewidth=3, label=method["name"])
    g2 = sns.scatterplot(data=dff[dff["pareto"]==1], x=f1, y=f2, color=color, legend=False, s=30, markers="o")
    
    legend_names.append(method["name"])

g.invert_xaxis()

plt.legend(fontsize=12, loc='lower left', frameon=True)

plt.xlabel(FEATURE_NAMES[f1], fontsize=17)
plt.ylabel(FEATURE_NAMES[f2], fontsize=17)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.ylim((0.63, 0.76))
plt.xlim((0.13, 0.00))
plt.title(DATASET_NAME + " (2014)", fontsize=17)
plt.tight_layout()


FNAME = "./results_img/"

plt.savefig(FNAME + "adultDRIFT_2014bis2.pdf")


In [None]:
%%time

#### WARNING/ Slow because of PAreto


# Pareto plot Global Acc * top1DI à iso Global DI
sns.set_style('whitegrid')

sns.set_palette("bright")


#### IMPORTANT: this defines the feature to "fix" and the corresponding range. E.g.: keeping only models 
# with GLOBAL DI between 0.0 and 0.05
dataf = "df_2015"
iso_f = "Global DI"
GDIRANGE = 0.0, 1.

### FEATURES FOR X AND Y AXES
f1, f2 = "Global EO", "Global Acc"


legend_names = []
for key in METHODS_DICT:
    print( key)
    method = METHODS_DICT[key]
    try:
        df = method[dataf]
        color = method["color"]
    except KeyError:
        print('bug')
        continue
        
    ### PARETO FRONT
    dff = pareto_df(df, f1, f2)
    g = sns.lineplot(data=dff[dff["pareto"]==1], x=f1, y=f2, color=color, linewidth=3, label=method["name"])
    g2 = sns.scatterplot(data=dff[dff["pareto"]==1], x=f1, y=f2, color=color, legend=False, s=30, markers="o")
    
    legend_names.append(method["name"])

g.invert_xaxis()

plt.legend(fontsize=12, loc='lower left', frameon=True)

plt.xlabel(FEATURE_NAMES[f1], fontsize=17)
plt.ylabel(FEATURE_NAMES[f2], fontsize=17)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.ylim((0.60, 0.77))
plt.title(DATASET_NAME + " (2015)", fontsize=17)
plt.tight_layout()


FNAME = "./results_img/"

plt.savefig(FNAME + "adultDRIFT_2015bis2.pdf")
