In this notebook, we will try to integrate the **Maximum Subgroup Discrepancy** tool from `humancompatible.detect` into our **fairness-constrained training** routine.

**MSD** automatically detects subgroups (defined by arbitrary conjunctions of protected attribute values) whose positive-prediction rate is higher or lower than that of their complement.
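
To make this quantity concrete, here is a minimal brute-force sketch. It assumes the discrepancy of a subgroup is the absolute gap in positive-prediction rate between the subgroup and its complement. The helper names (`positive_rate_gap`, `brute_force_max_gap`) and the toy data are hypothetical, and the exhaustive enumeration is only an illustration; it is not how `humancompatible.detect` searches for subgroups.

```python
from itertools import combinations

import numpy as np


def positive_rate_gap(X, y_pred, cols):
    """Absolute gap in positive-prediction rate between a subgroup and its complement.

    The subgroup is the conjunction (logical AND) of the boolean columns `cols`.
    """
    member = X[:, cols].all(axis=1)
    if member.all() or not member.any():
        return 0.0  # empty subgroup or empty complement: nothing to compare
    return abs(y_pred[member].mean() - y_pred[~member].mean())


def brute_force_max_gap(X, y_pred, max_terms=2):
    """Enumerate all conjunctions of up to `max_terms` columns (illustration only)."""
    best_cols, best_gap = (), 0.0
    for k in range(1, max_terms + 1):
        for cols in combinations(range(X.shape[1]), k):
            gap = positive_rate_gap(X, y_pred, list(cols))
            if gap > best_gap:
                best_cols, best_gap = cols, gap
    return best_cols, best_gap


# toy data (made up): columns = [female, degree]; predictions favour degree holders
X = np.array([[1, 1], [1, 0], [0, 1], [0, 0], [1, 0], [0, 1]], dtype=bool)
y_pred = np.array([1, 0, 1, 0, 0, 1], dtype=bool)
best_cols, best_gap = brute_force_max_gap(X, y_pred)
print(best_cols, best_gap)
```

In this toy example the single column `degree` separates the predictions perfectly, so it is reported as the most discrepant subgroup. Enumeration grows combinatorially with the number of columns, which is why a dedicated detection tool is needed on real data.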

Our idea is to periodically detect new biases during the model training process, and dynamically add fairness constraints based on detected discrepancies.

In [None]:
# load and prepare data

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
import numpy as np
from folktables import ACSDataSource, ACSIncome, generate_categories

# load folktables data
data_source = ACSDataSource(survey_year="2018", horizon="1-Year", survey="person")
acs_data = data_source.get_data(states=["CA"], download=True)
definition_df = data_source.get_definitions(download=True)
categories = generate_categories(
    features=ACSIncome.features, definition_df=definition_df
)
df_feat, df_labels, _ = ACSIncome.df_to_pandas(
    acs_data, categories=categories, dummies=True
)

sens_feature_names = [
    "SCHL",
    "SEX",
    "RAC1P",
    "AGEP",
]  # leave OCCP out for now because it has 524 (!) values
sens_col_names = [
    col for col in df_feat.columns if col.startswith(tuple(sens_feature_names))
]

features = df_feat.drop(columns=sens_col_names).to_numpy(dtype="float")
groups = df_feat[sens_col_names].to_numpy(dtype="float")
labels = df_labels.to_numpy(dtype="float")

# split
indices = np.arange(len(features))
(
    X_train,
    X_test,
    y_train,
    y_test,
    groups_train,
    groups_test,
    indices_train,
    indices_test,
) = train_test_split(features, labels, groups, indices, test_size=0.2, random_state=42)
# scale
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# make into a pytorch dataset, remove the sensitive attribute
features_train = torch.tensor(X_train, dtype=torch.float32)
labels_train = torch.tensor(y_train, dtype=torch.float32)
sens_train = torch.tensor(groups_train)
dataset_train = torch.utils.data.TensorDataset(features_train, labels_train)

---

As the constraint, we use `NormLoss` from `fairret`, which penalizes the model based on the ratio between the value of a statistic $f$ for each group $s \in S$ and its overall value: $\sum_{s\in S}{\left|1-\frac{f(\theta, X_s, y_s)}{f(\theta, X, y)}\right|}$.
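
A hand computation of the formula above, assuming $f$ is the positive rate (the mean predicted probability) and that groups are given as a boolean membership matrix. The function name `norm_loss_by_hand` and the numbers are made up for illustration; `fairret`'s own implementation may differ in details.

```python
import numpy as np


def norm_loss_by_hand(probs, groups):
    """Sum over groups of |1 - (group positive rate) / (overall positive rate)|.

    probs:  shape (n,), predicted probabilities of the positive class
    groups: shape (n, k), boolean membership in each of k groups
    """
    overall = probs.mean()
    group_rates = np.array(
        [probs[groups[:, s]].mean() for s in range(groups.shape[1])]
    )
    return np.abs(1.0 - group_rates / overall).sum()


# toy batch (made up): two groups with positive rates 0.7 and 0.3
probs = np.array([0.8, 0.6, 0.2, 0.4])
groups = np.array([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=bool)
loss = norm_loss_by_hand(probs, groups)
print(loss)
```

Here the overall rate is 0.5 and the group rates are 0.7 and 0.3, so the loss is $|1-1.4| + |1-0.6| = 0.8$. The loss is zero only when every group has the same positive rate as the whole batch.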

**Every N epochs (or perhaps once the loss stops decreasing?), we run MSD and add a constraint for the detected subgroup and its complement.**
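
One possible schedule can be sketched as follows. It assumes the period between detection runs grows linearly with the number of constraints already added, that at most six constraints are kept, and that every run returns a new rule. The function name `msd_trigger_epochs` and its parameters are hypothetical.

```python
def msd_trigger_epochs(epochs, msd_interval, max_constraints=5):
    """Epochs at which a detection run would start, if every run adds one rule."""
    n_constraints = 0
    triggered = []
    for epoch in range(epochs):
        # the waiting period grows with the number of constraints already in place
        period = (1 + n_constraints) * msd_interval
        if epoch > 0 and epoch % period == 0 and n_constraints <= max_constraints:
            triggered.append(epoch)
            n_constraints += 1  # assumption: the run always returns a rule
    return triggered


schedule = msd_trigger_epochs(epochs=1000, msd_interval=50)
print(schedule)
```

With these settings, detection stops once six constraints are in place. If a run returns no rule, the constraint count stays the same and the next attempt happens at the next multiple of the current period.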

To run MSD efficiently, we need to get rid of the one-hot encoding, and do some minor transformations on the data. We will use a separate DataFrame for MSD inputs.

Issue: each time we add a constraint, we have to re-evaluate the sensitive attribute encodings.
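
The re-evaluation can be sketched as turning a rule (a list of one-hot column indices) into a two-column membership matrix: rule holds vs. complement. This is a NumPy illustration with made-up data; the helper name `rule_groups` is hypothetical.

```python
import numpy as np


def rule_groups(sens_onehot, conj_indices):
    """Return an (n, 2) boolean matrix: column 0 = rule holds, column 1 = complement."""
    holds = sens_onehot[:, conj_indices].all(axis=1, keepdims=True)
    return np.concatenate([holds, ~holds], axis=1)


# toy one-hot columns (made up): [SEX_Male, SEX_Female, RAC1P_White alone]
sens_onehot = np.array(
    [[1, 0, 1], [1, 0, 0], [0, 1, 1], [1, 0, 1]], dtype=bool
)
groups = rule_groups(sens_onehot, [0, 2])  # rule: SEX_Male AND RAC1P_White alone
print(groups)
```

Because every rule yields its own pair of groups, the membership matrix has to be recomputed per rule and per batch rather than stored once with the dataset.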

In [None]:
from sklearn.preprocessing import OrdinalEncoder, KBinsDiscretizer

# for MSD, we need to do a bit of additional preprocessing:
# - add negations for binary features as a separate column (already done via categories+get_dummies)
# - discretize continuous features

df_feat_msd, df_labels, _ = ACSIncome.df_to_pandas(
    acs_data, categories=categories, dummies=True
)
df_feat_msd, df_labels = (
    df_feat_msd[sens_col_names].iloc[indices_train],
    df_labels.iloc[indices_train],
)

cont_sens_feature_names = ["AGEP"]
ds = KBinsDiscretizer(n_bins=10, encode="onehot-dense")
cont_bins = ds.fit_transform(df_feat_msd[cont_sens_feature_names])
for fidx, cont_feature in enumerate(cont_sens_feature_names):
    for i in range(ds.n_bins_[fidx]):
        lb, ub = ds.bin_edges_[fidx][i], ds.bin_edges_[fidx][i + 1]
        df_feat_msd["_".join([cont_feature, str(lb), str(ub)])] = cont_bins[
            :, i
        ].astype(bool)

df_feat_msd.drop(columns=cont_sens_feature_names, inplace=True)
df_feat_msd = df_feat_msd.convert_dtypes(convert_boolean=True)

Conjuncts_kwargs = {"solver": "glpk", "time_limit": 600}

# cat_features = list(set(df_feat_msd.columns) - set(cont_feature_names))
# FEATURE_PROCESSING = {
#     "POBP": lambda x: int(x) // 100,  # group by continents + US
#     "OCCP": lambda x: int(x) // 100,
#     "PUMA": lambda x: int(x) // 100,
#     "POWPUMA": lambda x: int(x) // 1000,
# }
# oe = OrdinalEncoder()
# df_feat_msd[cat_features] = oe.fit_transform(df_feat_msd[cat_features])

MSD_kwargs = {
    "n_samples": 200,
}

df_feat_msd



Unnamed: 0,"SCHL_1 or more years of college credit, no degree",SCHL_12th grade - no diploma,SCHL_Associate's degree,SCHL_Bachelor's degree,SCHL_Doctorate degree,SCHL_GED or alternative credential,SCHL_Grade 1,SCHL_Grade 10,SCHL_Grade 11,SCHL_Grade 2,...,AGEP_17.0_23.0,AGEP_23.0_28.0,AGEP_28.0_32.0,AGEP_32.0_37.0,AGEP_37.0_42.0,AGEP_42.0_47.0,AGEP_47.0_52.0,AGEP_52.0_57.0,AGEP_57.0_63.0,AGEP_63.0_94.0
82350,True,False,False,False,False,False,False,False,False,False,...,True,False,False,False,False,False,False,False,False,False
23973,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,True,False,False,False,False,False
92330,False,False,False,True,False,False,False,False,False,False,...,False,False,False,True,False,False,False,False,False,False
115743,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,True,False,False
82575,False,False,False,False,False,False,False,True,False,False,...,False,False,False,False,False,False,False,True,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
119879,False,False,True,False,False,False,False,False,False,False,...,False,True,False,False,False,False,False,False,False,False
103694,False,False,False,False,False,False,False,False,False,False,...,False,False,False,True,False,False,False,False,False,False
131932,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,True,False,False
146867,False,False,False,False,False,False,False,False,False,False,...,False,False,False,False,True,False,False,False,False,False


In [None]:
from fairret.statistic import PositiveRate
from fairret.loss import NormLoss

dataset = torch.utils.data.TensorDataset(
    features_train, torch.tensor(df_feat_msd.to_numpy(dtype=bool)), labels_train
)

# we need bigger batch size to make sure that each batch contains members of each potential subgroup
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=96)

criterion = torch.nn.BCEWithLogitsLoss()

# MSD detects discrepancies in positive predictions
statistic = PositiveRate()
fair_criterion = NormLoss(statistic=statistic)
fair_crit_bound = 0.25

In [None]:
from torch.nn import Sequential

hsize1 = 32
hsize2 = 16
model_con = Sequential(
    torch.nn.Linear(features_train.shape[1], hsize1),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize1, hsize2),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize2, 1),
)

from humancompatible.train.optim import SSLALM_Adam

optimizer = SSLALM_Adam(
    params=model_con.parameters(),
    m=0,
    lr=0.05,
    dual_lr=0.05,
    dual_bound=50,
    rho=1.0,
    mu=2.0,
)

# add slack variables
slack_vars = torch.zeros(0, requires_grad=True)
# optimizer.add_param_group(param_group={"params": slack_vars, "name": "slack"})

epochs = 1000
msd_interval = 50

In [None]:
from humancompatible.detect import detect_and_score
from humancompatible.detect.helpers import report_subgroup_bias
from humancompatible.detect.methods.msd import get_conjuncts_MSD
import random
from tqdm import tqdm

constraints = []

for epoch in range(epochs):
    # clear epoch stats
    loss_log = []
    c_log = []
    duals_log = []

    for batch_input, batch_sens, batch_label in tqdm(dataloader, leave=False):
        # calculate constraints and constraint grads
        out = model_con(batch_input)
        c_log.append([])
        for j, conj_indices in enumerate(constraints):
            # one-hot evaluation of rule `j`
            sens_attrs = batch_sens[:, conj_indices]
            rule_eval = torch.prod(sens_attrs, dim=1, keepdim=True, dtype=bool)
            batch_rule_groups = torch.cat([rule_eval, ~rule_eval], dim=1)
            # evaluate fairret loss
            fair_loss = fair_criterion(out, batch_rule_groups)
            fair_constraint = torch.max(fair_loss - fair_crit_bound, torch.zeros(1))
            fair_constraint.backward(retain_graph=True)
            # dual update
            optimizer.dual_step(j, c_val=fair_constraint)
            optimizer.zero_grad()
            c_log[-1].append([fair_constraint.detach().item()])
        duals_log.append(optimizer._dual_vars.detach())

        # primal eval and update
        loss = criterion(out, batch_label)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_log.append(loss.detach().numpy())
        # slack variables must be non-negative. this is the "projection" step from the SSL-ALM paper
        with torch.no_grad():
            for s in slack_vars:
                if s < 0:
                    s.zero_()

    # diminish the dual step size
    # optimizer.dual_lr *= 0.98

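    # print epoch stats; `%` and `//` have equal precedence and bind left to right,
    # so this fires when `epoch % msd_interval` is 0 or 1
    if (epoch % msd_interval) // 2 == 0: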
        print(
            f"Epoch: {epoch}, "
            f"loss: {np.mean(loss_log)}, "
            f"constraints: {np.mean(c_log, axis=0)}, "
            f"dual: {np.mean(duals_log, axis=0)}"
        )

    if (
        epoch > 0
        and epoch % ((1 + len(constraints)) * msd_interval) == 0
        and len(constraints) <= 5
    ):
        ### find new MSD ###
        with torch.no_grad():
            indices = random.sample(range(len(df_feat_msd)), MSD_kwargs["n_samples"])
            MSD_sample_model_input = dataset[indices][0]
            MSD_sample_x = df_feat_msd.iloc[indices].to_numpy(dtype=bool)
            # get binary predictions on sample
            MSD_sample_y = (
                torch.nn.functional.sigmoid(model_con(MSD_sample_model_input)).squeeze()
                > 0.5
            ).numpy()
        if not np.any(MSD_sample_y):
            continue
        conj_indices = get_conjuncts_MSD(MSD_sample_x, MSD_sample_y, **Conjuncts_kwargs)
        print(f"Adding new rule: {df_feat_msd.columns[conj_indices].to_list()}")
        constraints.append(conj_indices)
        optimizer.add_constraint()


Epoch: 0, loss: 0.4456871747970581, constraints: [], dual: []
Epoch: 1, loss: 0.4195505678653717, constraints: [], dual: []
Epoch: 50, loss: 0.3493681848049164, constraints: [], dual: []
Adding new rule: ["SCHL_Bachelor's degree"]
Epoch: 51, loss: 0.46873563528060913, constraints: [[0.02995319]], dual: [2.4426827]
Epoch: 100, loss: 0.5981041193008423, constraints: [[0.00090436]], dual: [11.135156]
Adding new rule: ['AGEP_17.0_23.0']
Epoch: 101, loss: 0.6369485259056091, constraints: [[0.        ]
 [0.01366439]], dual: [11.134973   1.1143476]
Epoch: 150, loss: 0.6178159713745117, constraints: [[0.00038175]
 [0.00100983]], dual: [12.206237   6.0645304]
Adding new rule: ['SEX_Male', 'RAC1P_White alone']
Epoch: 151, loss: 0.6163669228553772, constraints: [[0.00018188]
 [0.00054146]
 [0.00025451]], dual: [12.220855    6.108724    0.02075543]
Epoch: 200, loss: 0.6117810606956482, constraints: [[0.00028794]
 [0.00034323]
 [0.00021217]], dual: [13.2265625   8.623331    0.70239264]
Epoch: 201, loss: 0.6182326078414917, constraints: [[0.00031957]
 [0.00027968]
 [0.00012087]], dual: [13.252546    8.646427    0.71226686]
Epoch: 250, loss: 0.6148971915245056, constraints: [[2.05031536e-04]
 [4.80884244e-04]
 [2.55993361e-05]], dual: [14.162274  10.552781   1.1865553]
Epoch: 251, loss: 0.6191721558570862, constraints: [[1.96570381e-04]
 [9.00425332e-04]
 [1.74309051e-05]], dual: [14.178393  10.626217   1.1879902]
Epoch: 300, loss: 0.6103061437606812, constraints: [[1.14646702e-04]
 [7.34636858e-04]
 [3.26486178e-05]], dual: [14.881484  12.41226    1.4319855]
Epoch: 301, loss: 0.6113454699516296, constraints: [[2.43223386e-04]
 [1.40388782e-04]
 [8.17946391e-05]], dual: [14.901716  12.423829   1.4386395]

KeyboardInterrupt: 

In [None]:
batch_sens[:, 0].shape

torch.Size([96])

In [None]:
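from humancompatible.detect.binarizer import Bin


def rule_to_indicator(rule, features):
    """
    Given a rule and a tensor of one-hot variables, creates a new indicator column denoting the rule.
    `features` must have shape (`batch_size`, `n_features`).
    """
    # NOTE: unfinished sketch. We assume each `feature_idx` points at the one-hot
    # column that already encodes the corresponding `Bin`, so the bin itself is unused.
    indicator = torch.ones(features.shape[0], dtype=torch.bool)
    for feature_idx, _bin in rule:
        # select the column (not the row) and AND it into the conjunction
        indicator &= features[:, feature_idx].bool()
    return indicator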

In [None]:
with torch.no_grad():
    f = model_con.forward(dataset[:][0])

In [None]:
torch.mean(((torch.nn.functional.sigmoid(f) > 0.5) == dataset[:][2]).to(torch.float16))

tensor(0.6255, dtype=torch.float16)

In [None]:
from fairret.statistic import PositiveRate

preds = torch.nn.functional.sigmoid(model_con(features_train))
pr = PositiveRate()
pr(preds, sens_train)

tensor([0.4769, 0.4559, 0.4392, 0.4757, 0.4919, 0.5193, 0.4537, 0.4199, 0.4385,
        0.4023, 0.3916, 0.4016, 0.4039, 0.4031, 0.4086, 0.4025, 0.4282, 0.4300,
        0.4723, 0.5128, 0.4259, 0.4359, 0.5164, 0.4432, 0.4597, 0.4572, 0.4796,
        0.5135, 0.4557, 0.4454, 0.4731, 0.4720, 0.4660, 0.4348, 0.4614, 0.4749],
       dtype=torch.float64, grad_fn=<IndexPutBackward0>)

In [None]:
fair_criterion(model_con(features_train), sens_train)

tensor(2.4357, dtype=torch.float64, grad_fn=<SumBackward0>)

In [None]:
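from matplotlib import pyplot as plt

# NOTE: `ep_c_log` is not defined anywhere above. As an assumption, we plot `c_log`
# instead, i.e. the per-batch constraint values of the last (possibly interrupted) epoch.
window_len = 5
c_mavg = np.array(
    [np.mean(c_log[i : i + window_len]) for i in range(0, len(c_log), window_len)]
)
plt.plot(c_mavg, lw=0.05)
plt.hlines(y=fair_crit_bound, xmin=0, xmax=len(c_mavg), colors="black", ls="--")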