In [None]:
import os
from sklearn.metrics import roc_curve, auc
from matplotlib import pyplot as plt
import numpy as np

import torch
from torch import tensor
from utils.load_folktables import prepare_folktables
from src.constraints.constraint_fns import *
from fairret.statistic import *
from utils.network import SimpleNet

This notebook presents some useful plots based on the performance of the trained models.

### **Preparation**

**Load the Folktables dataset for the selected state and prepare it for usage**

In [None]:
# TASK selects the Folktables prediction task ("employment" is the alternative
# task previously toggled here); STATE is the US state abbreviation whose data
# is loaded.
TASK, STATE = "income", "OK"

In [None]:
# prepare_folktables yields, for each split, the features, the labels and a pair
# of index collections (presumably the white / non-white rows, given the group
# labels assigned below -- confirm against utils.load_folktables).
(
    X_train, y_train, (w_idx_train, nw_idx_train),
    X_test, y_test, (w_idx_test, nw_idx_test),
) = prepare_folktables(
    TASK,
    state=STATE,
    random_state=42,
    make_unbalanced=False,
    onehot=False,
    download=True,
)

# Display names of the two sensitive groups, used in legends and marker colours.
sensitive_value_0, sensitive_value_1 = "white", "non-white"

In [None]:
# CUDA is deliberately switched off for this notebook: everything runs on the CPU
# regardless of whether a GPU is present. Flip the flag to opt in.
USE_CUDA = False
device = "cuda" if (USE_CUDA and torch.cuda.is_available()) else "cpu"

In [None]:
def _as_float_tensor(array):
    """Copy ``array`` into a ``torch.float`` tensor placed on ``device``."""
    return tensor(array, dtype=torch.float, device=device)


# Whole train / test splits as tensors.
X_train_tensor, y_train_tensor = _as_float_tensor(X_train), _as_float_tensor(y_train)
X_test_tensor, y_test_tensor = _as_float_tensor(X_test), _as_float_tensor(y_test)

# Per-group slices of the training split ("w" / "nw" index collections).
X_train_w, y_train_w = X_train_tensor[w_idx_train], y_train_tensor[w_idx_train]
X_train_nw, y_train_nw = X_train_tensor[nw_idx_train], y_train_tensor[nw_idx_train]

# Per-group slices of the test split.
X_test_w, y_test_w = X_test_tensor[w_idx_test], y_test_tensor[w_idx_test]
X_test_nw, y_test_nw = X_test_tensor[nw_idx_test], y_test_tensor[nw_idx_test]

In [None]:
def _positive_rate(labels):
    """Return the fraction of entries in ``labels`` equal to 1, as a 0-d tensor."""
    # Tensor.sum() reduces in one vectorised call; the builtin sum() would walk
    # the tensor element by element in Python, which is very slow on large splits.
    return (labels == 1).sum() / len(labels)


# For each split print the group sizes, then the share of positive labels,
# in the column order announced by the header line.
print("w, nw, total")
for split_name, (labels_w, labels_nw, labels_all) in (
    ("train", (y_train_w, y_train_nw, y_train_tensor)),
    ("test", (y_test_w, y_test_nw, y_test_tensor)),
):
    print(split_name)
    print(len(labels_w), len(labels_nw), len(labels_all))
    print(
        _positive_rate(labels_w),
        _positive_rate(labels_nw),
        _positive_rate(labels_all),
    )

**Load saved models**

In [None]:
# Saved checkpoints are looked up under
#   ./utils/saved_models/<task>_<state>/<constraint>/<loss bound in %.0e format>/
LOSS_BOUND = 0.005
DATASET = f"{TASK}_{STATE}"
constraint = "eq_loss"
DIRECTORY_PATH = f"./utils/saved_models/{DATASET}/{constraint}/{LOSS_BOUND:.0e}/"
FILE_EXT = ".pt"

In [None]:
# Build a list of (model name, model) pairs from every checkpoint in
# DIRECTORY_PATH. The model name is the file name with its trailing
# "_trial..." suffix removed, so all trials of one configuration share a name.
loaded_models = []
directory_path = DIRECTORY_PATH
# os.listdir returns entries in arbitrary order; sorting keeps the positions in
# `loaded_models` reproducible between runs and machines.
model_files = sorted(
    file for file in os.listdir(directory_path) if file.endswith(FILE_EXT)
)
for model_file in model_files:
    print(model_file, end="\r")
    model = SimpleNet(X_test.shape[1], 1, torch.float32).to(device)
    try:
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects -- only point this at checkpoints you trust.
        state_dict = torch.load(
            os.path.join(directory_path, model_file),
            weights_only=False,
            map_location=device,
        )
        model.load_state_dict(state_dict)
    except Exception as exc:
        # Still best-effort (an unreadable or incompatible checkpoint is
        # skipped), but report it instead of hiding it, and let
        # KeyboardInterrupt / SystemExit propagate.
        print(f"Skipping {model_file}: {type(exc).__name__}: {exc}")
        continue
    # Strip only the last "_trial..." suffix; a file name without one is kept
    # whole rather than collapsing to an empty name.
    model_name = model_file.rsplit("_trial", 1)[0]
    loaded_models.append((model_name, model))

### **Evaluation**

**Calculate train- and test-set statistics for the models (AUC, constraint satisfaction, loss, etc.) and aggregate them per algorithm:**

In [None]:
from utils.stats import make_model_stats_table
from utils.stats import aggregate_model_stats_table

**Train set**:

In [None]:
# Statistics for every loaded model on the training split, computed from the
# per-group features and labels.
res_df_train = make_model_stats_table(
    X_train_w,
    y_train_w,
    X_train_nw,
    y_train_nw,
    loaded_models,
)

# Two aggregations of the same table: mean only, and mean together with std.
train_df, train_df_std = (
    aggregate_model_stats_table(res_df_train, how) for how in ("mean", ["mean", "std"])
)
train_df_std

**Test set**:

In [None]:
# Statistics for every loaded model on the test split, computed from the
# per-group features and labels.
res_df_test = make_model_stats_table(
    X_test_w,
    y_test_w,
    X_test_nw,
    y_test_nw,
    loaded_models,
)

# Two aggregations of the same table: mean only, and mean together with std.
test_df, test_df_std = (
    aggregate_model_stats_table(res_df_test, how) for how in ("mean", ["mean", "std"])
)
test_df_std

**Plots:**

In [None]:
# Make sure an output folder ./plots/<algorithm>/<dataset>/ exists for every
# entry of the aggregated test table.
for model_name in test_df.index:
    # Names starting with "sslalm_mu0" are filed under "sslalm_aug"; any other
    # name is filed under the text before its first underscore.
    if model_name.startswith("sslalm_mu0"):
        alg_name = "sslalm_aug"
    else:
        alg_name = model_name.split("_")[0]
    plot_dir = os.path.dirname(f"./plots/{alg_name}/{DATASET}/")
    os.makedirs(plot_dir, exist_ok=True)

In [None]:
from utils.plotting import spider_line


# Call spider_line on the aggregated train table, then on the aggregated test
# table; `f` ends up holding the value returned for the test table.
for stats_table in (train_df, test_df):
    f = spider_line(stats_table)

**Distribution of predictions by group:**

In [None]:
# Pool sigmoid-transformed test outputs per model name and per group:
# predictions_0 holds outputs for X_test_w, predictions_1 for X_test_nw.
# Models sharing a name (e.g. several trials) are concatenated into one array.
chunks_0 = {}
chunks_1 = {}

with torch.no_grad():
    for model_name, model in loaded_models:
        # .cpu() keeps the conversion to NumPy valid if `device` is ever a GPU.
        preds_0 = torch.sigmoid(model(X_test_w)).detach().cpu().numpy()
        preds_1 = torch.sigmoid(model(X_test_nw)).detach().cpu().numpy()
        # setdefault replaces the former bare `except:` used to create missing
        # keys, which would also have masked unrelated errors.
        chunks_0.setdefault(model_name, []).append(preds_0)
        chunks_1.setdefault(model_name, []).append(preds_1)

predictions_0 = {name: np.concatenate(parts) for name, parts in chunks_0.items()}
predictions_1 = {name: np.concatenate(parts) for name, parts in chunks_1.items()}

In [None]:
import seaborn as sns

# One density plot per model name: the pooled prediction distributions of the
# two groups are overlaid, shown, and saved as ./plots/<algorithm>/<dataset>/dist.
for model_name in np.unique([name for name, _ in loaded_models]):
    fig, ax = plt.subplots()
    for group_predictions, group_label, group_color in (
        (predictions_0, sensitive_value_0, "blue"),
        (predictions_1, sensitive_value_1, "red"),
    ):
        sns.kdeplot(
            group_predictions[model_name].squeeze(),
            label=group_label,
            color=group_color,
            fill=True,
            bw_adjust=0.4,
            ax=ax,
        )
    ax.set_xlim(-0.1, 1.1)
    ax.set_ylim(0, 22)
    ax.set_xlabel("Predictions", fontsize=20)
    ax.set_ylabel("Density", fontsize=20)
    # The legend must be added before saving, otherwise the file on disk has none.
    ax.legend()

    alg_name = (
        "sslalm_aug"
        if model_name.startswith("sslalm_mu0")
        else model_name.split("_")[0]
    )
    # Create the target folder here as well so this cell does not depend on the
    # directory-creation cell having covered the same model names.
    os.makedirs(f"./plots/{alg_name}/{DATASET}/", exist_ok=True)
    # NOTE(review): model names that map to the same alg_name write to the same
    # file, so only the last one survives -- confirm this is intended.
    fig.savefig(f"./plots/{alg_name}/{DATASET}/dist")
    plt.show()

### **Model plots**

**We choose one model per algorithm to make some useful plots**

For now, choose the model with the highest mean AUC:

In [None]:
# Criterion used to pick one model per algorithm. The selection cell below
# handles "auc" (largest AUC_M) and "wd" (smallest Wd).
select_by = "auc"

In [None]:
# Pick one (model name, model) pair per algorithm from the per-model test
# statistics, according to `select_by`.
best_models = {}
algs = res_df_test.Algorithm.unique()
for alg in algs:
    alg_df = res_df_test[res_df_test.Algorithm == alg]
    if select_by == "auc":
        best_idx = alg_df.AUC_M.idxmax()
    elif select_by == "wd":
        best_idx = alg_df.Wd.idxmin()
    else:
        # Previously an unknown criterion silently reused whatever `model`
        # happened to be left in the namespace by an earlier cell.
        raise ValueError(
            f"Unknown select_by={select_by!r}; expected 'auc' or 'wd'."
        )
    # idxmax/idxmin return an index *label*; using it as a list position assumes
    # res_df_test has one row per entry of loaded_models, in the same order,
    # with a default 0..n-1 index -- TODO confirm in make_model_stats_table.
    best_models[alg] = loaded_models[best_idx]

#### Subgroup ROC

**TPR-FPR plot**

In [None]:
def plot_roc_curve_pr(ax, predictions, targets, sensitive_value):
    """Draw one group's ROC curve (TPR against FPR) on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        predictions: Scores for the group, passed to ``roc_curve`` as y_score.
        targets: True labels for the group, passed to ``roc_curve`` as y_true.
        sensitive_value: Group name used in the legend; the marker is blue when
            it equals ``sensitive_value_0`` and red otherwise.

    The curve's legend entry reports its area under the curve. A scatter marker
    is placed at the point where TPR - FPR is largest, and its legend entry
    reports the corresponding threshold.
    """
    fpr, tpr, thresholds = roc_curve(targets, predictions)
    area = auc(fpr, tpr)
    ax.plot(fpr, tpr, label=f"Sensitive={sensitive_value}, AUC = {area:.2f}")

    # Index of the operating point with the largest TPR - FPR gap.
    best = np.argmax(tpr - fpr)
    marker_color = "blue" if sensitive_value == sensitive_value_0 else "red"
    ax.scatter(
        fpr[best],
        tpr[best],
        c=marker_color,
        label=f"Optimal Threshold {sensitive_value} {thresholds[best]:.2f}",
    )


# One TPR/FPR figure per algorithm, with both groups' curves on the same axes.
for alg, (_best_name, model) in best_models.items():
    fig, ax = plt.subplots()
    ax.set_title(alg)
    with torch.inference_mode():
        # Loop-local names: the former `predictions_0` / `predictions_1` here
        # overwrote the per-model dicts built for the distribution plots, so
        # that cell could not be re-run after this one.
        # NOTE(review): these look like raw model outputs (the distribution cell
        # applies a sigmoid to them), so the reported thresholds would be on
        # that raw scale -- confirm.
        scores_w = model(X_test_w)
        scores_nw = model(X_test_nw)
        # ROC for the first sensitive group
        plot_roc_curve_pr(ax, scores_w, y_test_w, sensitive_value=sensitive_value_0)
        # ROC for the second sensitive group
        plot_roc_curve_pr(ax, scores_nw, y_test_nw, sensitive_value=sensitive_value_1)
        ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Random Classifier")
        ax.set_xlabel("False Positive Rate", fontsize=24)
        ax.set_ylabel("True Positive Rate", fontsize=24)
        ax.legend()

**TNR-FNR plot**

In [None]:
def plot_roc_curve_nr(ax, predictions, targets, sensitive_value):
    """Draw one group's negative-rate curve on ``ax``: FNR (y) against TNR (x).

    Args:
        ax: Matplotlib axes to draw on.
        predictions: Scores for the group, passed to ``roc_curve`` as y_score.
        targets: True labels for the group, passed to ``roc_curve`` as y_true.
        sensitive_value: Group name used in the legend; the marker is blue when
            it equals ``sensitive_value_0`` and red otherwise.

    The "AUC" in the legend is the area under the plotted FNR-vs-TNR curve.
    A scatter marker is placed where TNR - FNR is largest (algebraically the
    same quantity as TPR - FPR), and its legend entry reports the threshold.
    """
    fpr, tpr, thresholds = roc_curve(targets, predictions)
    # Negative rates are the complements of the positive rates.
    tnr, fnr = 1 - fpr, 1 - tpr
    area = auc(tnr, fnr)
    ax.plot(tnr, fnr, label=f"Sensitive={sensitive_value}, AUC = {area:.2f}")

    # Index of the operating point with the largest TNR - FNR gap.
    best = np.argmax(tnr - fnr)
    marker_color = "blue" if sensitive_value == sensitive_value_0 else "red"
    ax.scatter(
        tnr[best],
        fnr[best],
        c=marker_color,
        label=f"Optimal Threshold {sensitive_value} {thresholds[best]:.2f}",
    )


# One TNR/FNR figure per algorithm, with both groups' curves on the same axes.
for alg, (_best_name, model) in best_models.items():
    fig, ax = plt.subplots()
    ax.set_title(alg)
    with torch.inference_mode():
        # Loop-local names: the former `predictions_0` / `predictions_1` here
        # overwrote the per-model dicts built for the distribution plots, so
        # that cell could not be re-run after this one.
        scores_w = model(X_test_w)
        scores_nw = model(X_test_nw)
        # Curve for the first sensitive group
        plot_roc_curve_nr(ax, scores_w, y_test_w, sensitive_value=sensitive_value_0)
        # Curve for the second sensitive group
        plot_roc_curve_nr(ax, scores_nw, y_test_nw, sensitive_value=sensitive_value_1)
        ax.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Random Classifier")
        # plot_roc_curve_nr draws TNR on the x-axis and FNR on the y-axis; the
        # labels were previously the other way round.
        ax.set_xlabel("True Negative Rate", fontsize=24)
        ax.set_ylabel("False Negative Rate", fontsize=24)
        ax.legend()