In [None]:
import os
from sklearn.metrics import roc_curve, auc
from matplotlib import pyplot as plt
import numpy as np

import torch
from torch import tensor
from utils.load_folktables import prepare_folktables_multattr
from humancompatible.train.benchmark.constraints.constraint_fns import *
from fairret.statistic import *
from utils.network import SimpleNet

In [None]:
# import numpy as np
# import pandas as pd
# from scipy.io.arff import loadarff

# raw_data = loadarff('utils/raw_data/dutch_census_2001.arff')
# df_data = pd.DataFrame(raw_data[0])
# df_data

This notebook evaluates the saved models trained on the Folktables data and plots their performance and fairness statistics, both overall and per sensitive group.

### **Preparation**

**Load the Folktables dataset for the selected state and prepare it for usage**

In [None]:
# Folktables prediction task and the US state (two-letter code) to evaluate on.
TASK, STATE = "income", "VA"

In [None]:
# Sensitive attribute(s) defining the groups; uncomment entries to combine
# several attributes into joint groups.
sens_cols = [
    "MAR",
    # "SEX",
    # "RAC1P",
]

# Options for the split helper, kept together so the call below stays short.
split_options = dict(
    state=STATE.upper(),
    random_state=42,
    onehot=False,
    download=False,
    sens_cols=sens_cols,
    # binarize=[None],
    stratify=True,
)

# Train / validation / test splits of features, labels, group indices and
# one-hot group membership, plus the ordering of the group keys.
(
    (X_train, X_val, X_test),
    (y_train, y_val, y_test),
    (group_ind_train, group_ind_val, group_ind_test),
    (group_onehot_train, group_onehot_val, group_onehot_test),
    group_order,
) = prepare_folktables_multattr(TASK, **split_options)

In [None]:
# Column sums of the one-hot group matrix -- presumably the number of training
# samples in each group (one column per group; TODO confirm layout).
group_onehot_train.sum(axis=0)

In [None]:
# Human-readable labels for the integer codes of each supported sensitive
# attribute; code 0 is labelled "OTHER".
group_codes = {
    "MAR": {0: "OTHER", 1: "Mar", 2: "Wid", 3: "Div", 4: "Sep", 5: "Nev"},
    "SEX": {0: "OTHER", 1: "M", 2: "F"},
    "RAC1P": {0: "OTHER", 1: "W", 2: "B", 3: "AI", 4: "AN", 5: "AIAN", 6: "A", 7: "PA", 8: "OT", 9: "TW"},
}

# Each key in `group_order` is an "_"-joined string of integer codes; the code
# at position k is decoded with the label table of sens_cols[k].
groups_sep = [list(map(int, key.split('_'))) for key in group_order]

group_names = []
for codes in groups_sep:
    labels = [group_codes[sens_cols[k]][code] for k, code in enumerate(codes)]
    group_names.append('_'.join(labels))

group_names

In [None]:
# Evaluation device. GPU use is opt-in: flip USE_CUDA to True to run on a GPU
# when one is available; otherwise everything runs on the CPU.
USE_CUDA = False
device = "cuda" if USE_CUDA and torch.cuda.is_available() else "cpu"

In [None]:
def _as_float_tensor(values):
    """Copy an array-like into a float32 tensor on the evaluation device."""
    return tensor(values, dtype=torch.float, device=device)


X_train_tensor, y_train_tensor = _as_float_tensor(X_train), _as_float_tensor(y_train)
X_test_tensor, y_test_tensor = _as_float_tensor(X_test), _as_float_tensor(y_test)

In [None]:
X_test_tensor.shape[0]  # number of test samples

**Load saved models**

In [None]:
from itertools import product

# Constraint name -> constraint bound. Each entry selects one results directory
#   ./utils/saved_models/<TASK>_<STATE>/<constraint>/<bound, e.g. 5E-02>/
# (the bound sub-directory is omitted when the bound is None).
constraints = {
    "abs_loss_equality": 0.05,
    # "unconstrained": 0.005,
    # "unconstrained": 0.03,
    # "abs_diff_pr": 0.05,
}

# Algorithm prefix of a saved file name (text before the first "_") -> display
# name. Files whose prefix is not listed here are skipped.
dict_alg_names = {
    "SGD": "SGD",
    "SGD+Reg": "SGD+Reg",
    "SSG": "SSw",
    "SSLALM": "SSLALM",
    "StochasticGhost": "Ghost",
    # "Adam": "Adam",
    # "fairret": "SGD-Fairret",
    # "TorchSSLALM": "SSLALM",
    # "TorchSSG": "SSG"
}

DATASET = TASK + "_" + STATE
FILE_EXT = ".pt"
loaded_models = []  # (file name, model) pairs

for constr, cb in constraints.items():
    bound_dir = f"{cb:.0E}" if cb is not None else ""
    # The trailing "" keeps a trailing path separator on the directory.
    directory_path = os.path.join("./utils/saved_models", DATASET, constr, bound_dir, "")
    print(f"Looking for models in: {directory_path}")
    try:
        file_list = os.listdir(directory_path)
    except FileNotFoundError:
        print("Not found")
        continue
    # os.listdir order is arbitrary; sort so the model order (and everything
    # derived from it downstream) is reproducible across runs.
    model_files = sorted(file for file in file_list if file.endswith(FILE_EXT))
    for model_file in model_files:
        if model_file.split("_")[0] not in dict_alg_names:
            continue
        model = SimpleNet(X_test.shape[1], 1, torch.float32).to(device)
        print(model_file)
        try:
            state_dict = torch.load(
                os.path.join(directory_path, model_file),
                weights_only=True,
                map_location=device,
            )
            model.load_state_dict(state_dict)
        except Exception as err:
            # Best effort: skip unreadable / incompatible checkpoints, but report
            # them instead of silently dropping them (and never trap Ctrl-C).
            print(f"  skipped {model_file}: {type(err).__name__}: {err}")
            continue
        # Evaluation mode, in case SimpleNet contains dropout / batch-norm layers.
        model.eval()
        loaded_models.append((model_file, model))


### **Evaluation**

**Calculate test-set statistics for the models (AUC, constraint satisfaction, loss, etc.) and aggregate them per algorithm:**

In [None]:
from utils.stats import make_pairwise_constraint_stats_table, aggregate_model_stats_table, make_groupwise_stats_table

In [None]:
# Evaluate on the test split: features, labels and per-group sample indices.
X, y, g = X_test_tensor, y_test_tensor, group_ind_test

In [None]:
# Per-model statistics on the whole evaluation set, averaged over the models
# (trials) belonging to each algorithm.
per_model_table = make_groupwise_stats_table(X, y, loaded_models)
full_statistics = (
    per_model_table
    .drop(columns='Model')
    .groupby('Algorithm')
    .agg('mean')
)

Groupwise statistics (for each group and for its complement), averaged over the models of each algorithm:

In [None]:
# For every group in `g`: per-algorithm mean statistics on that group's samples
# and on all remaining samples (the group's complement).
# To use the training split instead, first rebind
#   X, y, g = X_train_tensor, y_train_tensor, group_ind_train


def _mean_stats_by_algorithm(X_part, y_part):
    """Average the per-model statistics table over the models of each algorithm."""
    per_model = make_groupwise_stats_table(X_part, y_part, loaded_models)
    return per_model.drop('Model', axis=1).groupby('Algorithm').agg('mean')


all_rows = np.arange(len(X))
groupwise_stats = []
groupwise_stats_inv = []

for group_ind in g:
    groupwise_stats.append(_mean_stats_by_algorithm(X[group_ind], y[group_ind]))
    rest_ind = np.delete(all_rows, group_ind, 0)
    groupwise_stats_inv.append(_mean_stats_by_algorithm(X[rest_ind], y[rest_ind]))

In [None]:
import pandas as pd

# Per-group statistics, indexed by (group, Algorithm).
stats = pd.concat(groupwise_stats, keys=group_names, names=['group'])
# Statistics of each group's complement; only needed for the one-vs-rest
# variant of the gap metrics commented out below.
inv_stats = pd.concat(groupwise_stats_inv, keys=group_names, names=['group'])


def _sum_pairwise_gaps(table, groups, columns):
    """For each group, sum |table[group, c] - table[other, c]| over all other groups and columns c.

    ``table`` is indexed by (group, Algorithm); the result is a Series with the
    same (group, Algorithm) index so it can be assigned back as a column.
    """
    totals = {}
    for group in groups:
        total = 0
        for alt_group in groups:
            if alt_group == group:
                continue
            total += sum(
                abs(table.loc[group][col] - table.loc[alt_group][col]) for col in columns
            )
        totals[group] = total
    return pd.concat(totals, keys=groups, names=['group'])


gn = list(group_names)

# Separation: gaps in true- and false-positive rates.
sep = _sum_pairwise_gaps(stats, gn, ['tpr', 'fpr'])
# Independence: gaps in positive-prediction rate.
ind = _sum_pairwise_gaps(stats, gn, ['pr'])
# Sufficiency: gaps in precision (PPV) and false omission rate.
sf = _sum_pairwise_gaps(stats, gn, ['ppv', 'fomr'])

stats['Sp'] = sep
stats['Ind'] = ind
stats['Sf'] = sf

# One-vs-rest alternative (group vs. its complement):
# stats['Sp']= abs(stats['tpr'] - inv_stats['tpr']) + abs(stats['fpr'] - inv_stats['fpr'])
# stats['Ind'] = abs(stats['pr'] - inv_stats['pr'])
# stats['Sf'] = abs(stats['ppv'] - inv_stats['ppv']) + abs(stats['fomr'] - inv_stats['fomr'])

# Inaccuracy, so that "lower is better" holds for every plotted column.
stats['Ina'] = 1 - stats['acc']
stats

**Plots:**

In [None]:
from utils.plotting import spider_line
# Group whose per-algorithm profile is drawn. Group names depend on `sens_cols`
# (e.g. "Mar" only exists while sens_cols == ["MAR"]), so check it up front.
SPIDER_GROUP = "Mar"
if SPIDER_GROUP not in group_names:
    raise KeyError(f"SPIDER_GROUP={SPIDER_GROUP!r} is not in group_names={group_names}")
cr = stats.loc[SPIDER_GROUP]  # rows: algorithms, columns: statistics

f = spider_line(cr, yticks=np.arange(0, 0.5, 0.1))
f.set_figwidth(10)
f.set_figheight(10)

**Distribution of predictions by group:**

In [None]:
# Score the test split (swap in X_train_tensor / y_train_tensor for training data).
X, y = X_test_tensor, y_test_tensor

In [None]:
# Sigmoid scores per algorithm and per group index, pooled over all saved
# models (trials) of that algorithm. Algorithms are sorted so the dict order
# -- and hence the subplot order below -- is stable across kernel restarts.
predictions_by_alg = {
    alg: {} for alg in sorted({model_name.split("_")[0] for model_name, _ in loaded_models})
}

with torch.inference_mode():
    for i, group in enumerate(group_ind_test):
        for model_name, model in loaded_models:
            alg = model_name.split("_")[0]
            # reshape(-1) instead of squeeze(): a one-sample group must stay
            # 1-D, otherwise np.concatenate below rejects it.
            preds = torch.sigmoid(model(X[group])).cpu().numpy().reshape(-1)
            predictions_by_alg[alg].setdefault(i, []).append(preds)

for per_group in predictions_by_alg.values():
    for i, chunks in per_group.items():
        per_group[i] = np.concatenate(chunks)

In [None]:
def _predictions_frame(pred_by_group):
    """Long-format frame with one row per prediction and its group index."""
    pred_col = [p for group_preds in pred_by_group.values() for p in group_preds]
    group_col = [
        grp
        for grp, group_preds in pred_by_group.items()
        for _ in range(len(group_preds))
    ]
    return pd.DataFrame({'pred': pred_col, 'group': group_col})


pred_dfs = {alg: _predictions_frame(pred_dict) for alg, pred_dict in predictions_by_alg.items()}

In [None]:
import seaborn as sns

# One panel per algorithm (at least one); group colours are shared across panels.
n_panels = max(len(pred_dfs), 1)
fig, axs = plt.subplots(nrows=1, ncols=n_panels, squeeze=False)
group_palette = sns.color_palette("husl", len(group_names))

for ax, (alg, predictions) in zip(axs[0], pred_dfs.items()):
    # Label groups on a copy: rewriting pred_dfs in place made a second run of
    # this cell index group_names with names instead of integer indices.
    labelled = predictions.assign(
        group=predictions['group'].apply(lambda x: group_names[x])
    )
    sns.kdeplot(
        labelled,
        x='pred',
        hue='group',
        palette=group_palette,
        fill=True,
        alpha=0.1,
        bw_adjust=0.4,
        ax=ax,
        clip=[0, 1],
        common_norm=False)
    # 0.5 on the sigmoid scale: the default decision threshold.
    ax.vlines(0.5, 0., 10, ls='--', color='black')
    ax.set_xlim(-0.1, 1.1)
    ax.set_ylim(0, 10)
    ax.set_xlabel("Predictions", fontsize=20)
    ax.set_ylabel("Density", fontsize=20)
    ax.set_title(alg)

fig.set_figwidth(6 * n_panels)
fig.tight_layout()

### **Model plots**

**We choose one model per algorithm for the plots below**

For each algorithm, choose the saved model with the highest test-set AUC:

In [None]:
from sklearn.metrics import auc, roc_curve, accuracy_score

# Score the test split (swap in X_train_tensor / y_train_tensor for training data).
X = X_test_tensor
y = y_test_tensor
y_true = y.cpu().numpy()  # identical for every model, so convert once


def _alg_from_model_name(model_name):
    """Algorithm label: the file name with its "_trial..." part removed.

    Every "_trial" occurrence is dropped and the text after the last one is
    discarded; a name without "_trial" maps to "". Presumably files are named
    "<alg>_trial<k>.pt" -- TODO confirm against the training script.
    """
    return "".join(model_name.split("_trial")[:-1])


auc_records = []
for model_name, model in loaded_models:
    with torch.inference_mode():
        # roc_curve wants 1-D scores. Raw outputs are fine because AUC only
        # depends on the ranking of the scores.
        scores = model(X).cpu().numpy().ravel()
    fpr, tpr, _ = roc_curve(y_true, scores)
    auc_records.append(
        {'alg': _alg_from_model_name(model_name), 'auc': auc(fpr, tpr), 'model': model}
    )

model_df = pd.DataFrame(auc_records)

In [None]:
# For each algorithm keep the model with the highest test AUC. idxmax returns
# index *labels*, so select with .loc rather than positional .iloc.
models_idx = model_df.groupby('alg')['auc'].idxmax()
best_models = model_df.loc[models_idx].set_index('alg')['model'].to_dict()

#### Subgroup ROC

**ROC curves (TPR vs. FPR) per group, with the threshold that maximizes TPR − FPR marked**

In [None]:
def plot_roc_curve_pr(ax, predictions, targets, sensitive_value):
    """Draw one group's ROC curve on ``ax`` and mark its Youden-optimal point.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    predictions : scores for the group's samples (passed straight to ``roc_curve``).
    targets : binary labels for the same samples.
    sensitive_value : group identifier used in the legend labels.

    The marked point is the ROC threshold maximising TPR - FPR (Youden's J);
    its value is on the same scale as ``predictions``.
    """
    fpr, tpr, thresholds = roc_curve(targets, predictions)
    roc_auc = auc(fpr, tpr)
    ax.plot(fpr, tpr, label=f"group={sensitive_value}, AUC = {roc_auc:.2f}")

    best = np.argmax(tpr - fpr)
    best_threshold = thresholds[best]
    ax.scatter(
        fpr[best],
        tpr[best],
        label=f"Optimal Threshold {sensitive_value} {best_threshold:.2f}",
    )


# One figure per algorithm: ROC curve of its best model on each group.
for alg, model in best_models.items():
    f = plt.figure()
    f.set_figwidth(10)
    f.set_figheight(10)
    ax = f.subplots()
    ax.set_title(alg)
    with torch.inference_mode():
        for i, group in enumerate(group_ind_test):
            # Raw model outputs (no sigmoid, unlike the prediction-density plots):
            # AUC is unaffected, but the marked threshold is on this raw scale.
            predictions = model(X_test_tensor[group]).cpu().numpy().ravel()
            plot_roc_curve_pr(
                ax, predictions, y_test[group], sensitive_value=group_names[i]
            )
    # Shared decorations are drawn once per figure, so the legend lists the
    # chance line a single time rather than once per group.
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Random Classifier")
    ax.set_xlabel("False Positive Rate", fontsize=24)
    ax.set_ylabel("True Positive Rate", fontsize=24)
    ax.legend()