In [None]:
import os
from sklearn.metrics import roc_curve, auc
from matplotlib import pyplot as plt
import numpy as np

import torch
from torch import tensor
from utils.load_folktables import prepare_folktables_multattr
from humancompatible.train.fairness.constraints.constraint_fns import *
from fairret.statistic import *
from utils.network import SimpleNet

This notebook presents some useful plots based on the performance of the trained models.

### **Preparation**

**Load the Folktables dataset for the selected state and prepare it for usage**

In [None]:
# Folktables prediction task and state code. Both are passed to the data loader
# and, joined by "_", name the dataset directory that saved models are read from.
TASK = "income"
STATE = "OK"

In [None]:
# Sensitive attribute columns used to define the groups; the commented-out
# entries are alternatives that `group_codes` (defined further down) also covers.
sens_cols=[
    "MAR",
    # "SEX",
    # 'RAC1P',
    ]

# Load the data and unpack train/test features, labels and group information.
# `group_order` holds one underscore-separated code string per group (it is
# parsed with split('_') / int() when the group names are built).
# NOTE(review): the meaning of the individual return values is defined by the
# project helper — confirm the unpacking order against its signature.
(
    X_train,
    y_train,
    group_ind_train,
    group_onehot_train,
    sep_group_ind_train,
    X_test,
    y_test,
    group_ind_test,
    group_onehot_test,
    sep_group_ind_test,
    group_order
) = prepare_folktables_multattr(
    TASK,
    state=STATE.upper(),
    random_state=42,
    onehot=False,
    download=True,
    sens_cols=sens_cols,
    binarize=[None],
    stratify=False,
)

In [None]:
# Sum of the one-hot group matrix over axis 0 — presumably the number of
# training samples in each group (TODO confirm the matrix layout).
group_onehot_train.sum(axis=0)

In [None]:
# Human-readable label for every integer code of each supported sensitive attribute.
group_codes = {
    "MAR": {0: "OTHER", 1: "Mar", 2: "Wid", 3: "Div", 4: "Sep", 5:"Nev"},
    "SEX": {0: "OTHER", 1: "M", 2: "F"},
    "RAC1P": {0: "OTHER", 1: "W", 2: "B", 3: "AI", 4: "AN", 5: "AIAN", 6: "A", 7: "PA", 8: "OT", 9: "TW"}
}

# "1_2" -> [1, 2]: one integer code per sensitive attribute, in `sens_cols` order.
groups_sep = [list(map(int, label.split('_'))) for label in group_order]

# Translate each code through the table of its attribute and re-join with "_".
group_names = [
    '_'.join(
        group_codes[sens_cols[position]][code]
        for position, code in enumerate(codes)
    )
    for codes in groups_sep
]
group_names

In [None]:
# NOTE: the trailing `and False` makes the condition always false, so this
# always selects "cpu" even when CUDA is available.
device = "cuda" if torch.cuda.is_available() and False else "cpu"

In [None]:
# Convert features and labels of both splits to torch.float tensors on `device`.
X_train_tensor, y_train_tensor = (
    tensor(array, dtype=torch.float, device=device) for array in (X_train, y_train)
)

X_test_tensor, y_test_tensor = (
    tensor(array, dtype=torch.float, device=device) for array in (X_test, y_test)
)

In [None]:
# Size of the first dimension of the test feature tensor.
len(X_test_tensor) 

**Load saved models**

In [None]:
from itertools import product

# Constraint name -> constraint bound. Every active entry selects one directory
# of saved models; a bound of None means there is no per-bound sub-directory.
constraints = {
    # "loss_equality": 0.005,
    # "unconstrained": 0.005,
    # "unconstrained": 0.03,
    "abs_diff_pr": 0.05,
}

# Algorithm prefix of a checkpoint file name (the part before the first "_")
# -> display name. Checkpoints whose prefix is not listed here are ignored.
dict_alg_names = {
    "StochasticGhost": "Ghost",
    "SSLALM": "SSLALM",
    "SSG": "SSw",
    # "SGD": "SGD",
    "Adam": "Adam",
    "fairret": "SGD-Fairret",
    "TorchSSLALM": "SSLALM",
    "TorchSSG": "SSG"
}

DATASET = TASK + "_" + STATE
FILE_EXT = ".pt"
# (checkpoint file name, model) pairs for every checkpoint that loaded successfully.
loaded_models = []

for constr, cb in constraints.items():
    DIRECTORY_PATH = (
        "./utils/saved_models/" + DATASET + "/" + constr + "/" + ((f"{cb:.0E}" + "/") if cb is not None else '')
    )

    directory_path = DIRECTORY_PATH
    print(f"Looking for models in: {directory_path}")
    try:
        file_list = os.listdir(directory_path)
    except FileNotFoundError:
        print("Not found")
        continue

    model_files = [file for file in file_list if file.endswith(FILE_EXT)]
    for model_file in model_files:
        if model_file.split("_")[0] not in dict_alg_names:
            continue
        print(model_file)
        model = SimpleNet(X_test.shape[1], 1, torch.float32).to(device)
        try:
            state_dict = torch.load(
                directory_path + model_file, weights_only=True, map_location=device
            )
            model.load_state_dict(state_dict)
        except Exception as err:
            # Loading stays best-effort, but a skipped checkpoint is reported
            # rather than silently dropped. Unlike a bare `except`, this does
            # not trap KeyboardInterrupt / SystemExit.
            print(f"Skipping {model_file}: {type(err).__name__}: {err}")
            continue
        # The models are only evaluated in this notebook, never trained.
        model.eval()
        loaded_models.append((model_file, model))


### **Evaluation**

**Calculate statistics for the models (AUC, constraint satisfaction, loss, etc.) and aggregate them per algorithm:**

In [None]:
from utils.stats import make_pairwise_constraint_stats_table, aggregate_model_stats_table, make_groupwise_stats_table

**Train set**:

In [None]:
# Sort the (file name, model) pairs in place by file name; os.listdir returns
# entries in arbitrary order.
loaded_models.sort(key=lambda x: x[0])

In [None]:
def _mean_stats_per_algorithm(features, labels):
    """Build the stats table for `loaded_models` on (features, labels) and average it per algorithm.

    The per-model 'Model' column is dropped before grouping on 'Algorithm'.
    """
    per_model = make_groupwise_stats_table(features, labels, loaded_models)
    return per_model.drop('Model', axis=1).groupby('Algorithm').agg('mean')


# Reference values: statistics over the whole training set.
full_data_stats = _mean_stats_per_algorithm(X_train_tensor, y_train_tensor)

# The same statistics restricted to each group, in `group_ind_train` order.
groupwise_stats = []

for group_ind in group_ind_train:
    groupwise_stats.append(
        _mean_stats_per_algorithm(X_train_tensor[group_ind], y_train_tensor[group_ind])
    )

In [None]:
groupwise_dev = []

for group_stats in groupwise_stats:
    # Signed deviation of each group statistic from its full-data value.
    diff = (group_stats - full_data_stats).add_suffix('_dev')
    # Summary columns built from absolute deviations; 'Ina' is 1 - group accuracy.
    diff = diff.assign(
        Sp=diff['tpr_dev'].abs() + diff['fpr_dev'].abs(),
        Ind=diff['ppv_dev'].abs() + diff['fomr_dev'].abs(),
        Sf=diff['pr_dev'].abs(),
        Ina=1 - group_stats['acc'],
    )
    groupwise_dev.append(diff)


In [None]:
import pandas as pd
# Stack the per-group stats tables into one frame with an outer 'group' index level.
# Assumes `group_names` is in the same order as `group_ind_train` — TODO confirm.
stats = pd.concat(groupwise_stats, keys=group_names, names=['group'])
stats

In [None]:
import pandas as pd
# Stack the per-group deviation tables into one frame with an outer 'group' index level.
# Assumes `group_names` is in the same order as `group_ind_train` — TODO confirm.
con = pd.concat(groupwise_dev, keys=group_names, names=['group'])
con

In [None]:
from itertools import combinations
import pandas as pd

# One constraint-statistics table per unordered pair of groups.
bin_dfs = []

for group_idx_1, group_idx_2 in combinations(group_ind_train, 2):
    table = make_pairwise_constraint_stats_table(
        X_train_tensor[group_idx_1],
        y_train_tensor[group_idx_1],
        X_train_tensor[group_idx_2],
        y_train_tensor[group_idx_2],
        loaded_models,
    )
    # Index rows by the display name of the algorithm (prefix before the first "_").
    display_names = table["Algorithm"].apply(lambda x: dict_alg_names[x.split("_")[0]])
    bin_dfs.append(table.set_index(display_names).drop(columns="Algorithm"))

# Outer index level "constraint" numbers the group pairs in iteration order.
df_train = pd.concat(bin_dfs, axis=0, keys=range(len(bin_dfs)), names=["constraint"])

In [None]:
# Aggregate the pairwise tables per (constraint, algorithm): once as plain means,
# once as mean and standard deviation with the 'Algname' column removed.
agg_keys = ["constraint", "Algorithm"]

train_df = aggregate_model_stats_table(df_train, "mean", agg_cols=list(agg_keys))

train_df_std = aggregate_model_stats_table(
    df_train, ["mean", "std"], agg_cols=list(agg_keys)
).drop(columns="Algname")

In [None]:
# Mean statistics per (constraint, algorithm), computed on the training tensors.
train_df

**Plots:**

In [None]:
from utils.plotting import spider_line

# Keep the deviation rows of a single trained configuration and index them by
# group name (the 'group' column is kept as well).
cr = con.reset_index()
cr_alg = cr.loc[cr['Algorithm'] == 'TorchSSLALM_0.05'].set_index('group', drop=False)

f = spider_line(cr_alg, yticks=[0, 0.1, 0.2, 0.35])

**Distribution of predictions by group:**

In [None]:
# algorithm prefix -> {group position -> predicted probabilities of all models
# of that algorithm on the test samples of that group}.
# dict.fromkeys de-duplicates while keeping `loaded_models` order, so the
# algorithm order (and hence any plot built from it) is the same on every run;
# iterating a set of strings is not.
predictions_by_alg = {
    alg: {}
    for alg in dict.fromkeys(model_name.split("_")[0] for model_name, _ in loaded_models)
}

# Pure inference: no autograd bookkeeping needed.
with torch.inference_mode():
    for i, group in enumerate(group_ind_test):
        X_group = X_test_tensor[group]
        for model_name, model in loaded_models:
            alg = model_name.split("_")[0]
            # reshape(-1) always gives a 1-D array; squeeze() would return a 0-d
            # array for a one-sample group, which np.concatenate cannot handle.
            preds = torch.sigmoid(model(X_group)).detach().cpu().numpy().reshape(-1)
            predictions_by_alg[alg].setdefault(i, []).append(preds)

# Merge the per-model arrays of each (algorithm, group) into a single 1-D array.
for alg, preds_by_group in predictions_by_alg.items():
    for i in preds_by_group:
        preds_by_group[i] = np.concatenate(preds_by_group[i])

In [None]:
# algorithm -> long-format frame with one row per prediction:
# 'pred' is the predicted value, 'group' the key of the group it belongs to.
pred_dfs = {
    alg: pd.DataFrame(
        {
            'pred': [
                value
                for group_preds in pred_dict.values()
                for value in group_preds
            ],
            'group': [
                group
                for group, group_preds in pred_dict.items()
                for _ in range(len(group_preds))
            ],
        }
    )
    for alg, pred_dict in predictions_by_alg.items()
}

In [None]:
import seaborn as sns

# One panel per algorithm instead of a fixed number of three.
# squeeze=False keeps `axs` two-dimensional even for a single panel.
n_panels = max(1, len(pred_dfs))
fig, axs = plt.subplots(nrows=1, ncols=n_panels, squeeze=False)

for ax, (alg, predictions) in zip(axs[0], pred_dfs.items()):
    # Relabel on a copy: overwriting `pred_dfs` in place would make this cell
    # fail on a second run, because the group column would already hold names.
    labelled = predictions.assign(
        group=predictions["group"].apply(lambda x: group_names[x])
    )
    sns.kdeplot(
        data=labelled,
        x='pred',
        hue='group',
        # One colour per group actually present, not a hard-coded five.
        palette=sns.color_palette("husl", labelled["group"].nunique()),
        fill=True,
        alpha=0.1,
        bw_adjust=0.4,
        ax=ax,
        clip=[0,1],
        common_norm=False)
    # Decision threshold at probability 0.5.
    ax.vlines(0.5,0.,10, ls='--',color='black')
    ax.set_xlim(-0.1, 1.1)
    ax.set_ylim(0, 10)
    ax.set_xlabel("Predictions", fontsize=20)
    ax.set_ylabel("Density", fontsize=20)
    ax.set_title(alg)

# 10 inches per panel, i.e. the original width of 30 for three algorithms.
fig.set_figwidth(10 * n_panels)
fig.tight_layout()

### **Model plots**

**We choose one model per algorithm to make some useful plots**

For now, choose the model with the highest mean AUC, i.e. the highest value of the `select_by` column averaged over all group pairs in `df_train`:

In [None]:
# Distinct values of the 'Algorithm' index level of `df_train`.
df_train.index.get_level_values('Algorithm').unique()

In [None]:
# Column of `df_train` whose per-model mean decides which model is picked per algorithm.
select_by = "AUC_M"

In [None]:
# algorithm display name -> (file name, model) of its best-scoring checkpoint.
best_models = {}
algs = df_train.index.get_level_values('Algorithm').unique()
for alg in algs:
    # Mean of the selection metric per model, over all rows of this algorithm.
    score_per_model = (
        df_train.xs(alg, level=1)
        .reset_index()
        .loc[:, ['Model', select_by]]
        .groupby('Model')
        .mean()[select_by]
    )
    best_model_name = score_per_model.idxmax()
    candidates = [entry for entry in loaded_models if entry[0] == best_model_name]
    best_models[alg] = candidates[0]

#### Subgroup ROC

**TPR-FPR plot**

In [None]:
def plot_roc_curve_pr(ax, predictions, targets, sensitive_value):
    """Draw one group's ROC curve (FPR on x, TPR on y) on `ax`.

    The curve's legend entry carries the group identifier and the AUC. A
    scatter marker is added at the point where TPR - FPR is largest, labelled
    with the corresponding threshold.

    Args:
        ax: matplotlib axes to draw on.
        predictions: model scores for the group's samples.
        targets: true binary labels for the same samples.
        sensitive_value: group identifier used in the legend labels.
    """
    false_pos, true_pos, cutoffs = roc_curve(targets, predictions)
    area = auc(false_pos, true_pos)
    ax.plot(false_pos, true_pos, label=f"group={sensitive_value}, AUC = {area:.2f}")

    # Operating point with the largest gap between TPR and FPR.
    best = np.argmax(true_pos - false_pos)
    ax.scatter(
        false_pos[best],
        true_pos[best],
        label=f"Optimal Threshold {sensitive_value} {cutoffs[best]:.2f}",
    )


# One figure per algorithm, with one ROC curve per test group.
for alg, (model_name, model) in best_models.items():
    f, ax = plt.subplots(figsize=(10, 10))
    ax.set_title(alg)
    with torch.inference_mode():
        for i, group in enumerate(group_ind_test):
            predictions = model(X_test_tensor[group])
            plot_roc_curve_pr(
                ax, predictions, y_test[group], sensitive_value=i
            )
    # Reference diagonal, axis labels and legend are drawn once per figure;
    # inside the group loop they produced one duplicate "Random Classifier"
    # line and legend entry per group.
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Random Classifier")
    ax.set_xlabel("False Positive Rate", fontsize=24)
    ax.set_ylabel("True Positive Rate", fontsize=24)
    ax.legend()

**TNR-FNR plot**

In [None]:
def plot_roc_curve_nr(ax, predictions, targets, sensitive_value):
    """Draw one group's negative-rate curve (FNR on x, TNR on y) on `ax`.

    FNR and TNR are derived from the ROC curve as 1 - TPR and 1 - FPR. The
    curve, its AUC and the marker all use the same (FNR, TNR) coordinates.
    Previously the curve was drawn as (TNR, FNR) while the marker was placed
    at (FNR, TPR), so the marker did not lie on the curve.

    Args:
        ax: matplotlib axes to draw on.
        predictions: model scores for the group's samples.
        targets: true binary labels for the same samples.
        sensitive_value: group identifier used in the legend labels.
    """
    fpr, tpr, thresholds = roc_curve(targets, predictions)
    fnr = 1 - tpr
    tnr = 1 - fpr

    # fnr is monotonically non-increasing, which sklearn's auc accepts.
    roc_auc = auc(fnr, tnr)
    ax.plot(fnr, tnr, label=f"Sensitive={sensitive_value}, AUC = {roc_auc:.2f}")

    # Operating point with the largest gap between TNR and FNR.
    optimal_threshold_index = np.argmax(tnr - fnr)
    optimal_threshold = thresholds[optimal_threshold_index]
    ax.scatter(
        fnr[optimal_threshold_index],
        tnr[optimal_threshold_index],
        label=f"Optimal Threshold {sensitive_value} {optimal_threshold:.2f}",
    )


# One figure per algorithm, with one negative-rate curve per test group.
for alg, (model_name, model) in best_models.items():
    f, ax = plt.subplots(figsize=(10, 10))
    ax.set_title(alg)
    with torch.inference_mode():
        for i, group in enumerate(group_ind_test):
            predictions = model(X_test_tensor[group])
            plot_roc_curve_nr(
                ax, predictions, y_test[group], sensitive_value=i
            )
    # Reference diagonal, axis labels and legend are drawn once per figure;
    # inside the group loop they produced one duplicate "Random Classifier"
    # line and legend entry per group.
    ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Random Classifier")
    ax.set_xlabel("False Negative Rate", fontsize=24)
    ax.set_ylabel("True Negative Rate", fontsize=24)
    ax.legend()