This notebook will demonstrate how to use the **constrained training algorithms** implemented in this toolkit.

To train a network, instantiate an algorithm, passing to it the model, the dataset, a list of `FairnessConstraint`s and the algorithm's hyperparameters.

Load and prepare data from `folktables`:

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
from folktables import ACSDataSource, ACSIncome

# load folktables data
data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')
acs_data = data_source.get_data(states=["OK"], download=True)
features, labels, groups = ACSIncome.df_to_numpy(acs_data)
# split
X_train, X_test, y_train, y_test, groups_train, groups_test = train_test_split(
    features, labels, groups, test_size=0.2, random_state=42)
# scale
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# make into a pytorch dataset, remove the sensitive attribute (RAC1P)
features_train = torch.tensor(X_train, dtype=torch.float32)[:,:-1]
labels_train = torch.tensor(y_train,dtype=torch.float32)
dataset_train = torch.utils.data.TensorDataset(features_train, labels_train)

c_batch_size = 128
min_subgroup_size = c_batch_size
# For each subgroup, FairnessConstraint needs a list of indices of samples belonging to that subgroup
group_indices_train = [
    np.nonzero(groups_train == group_id)[0] for group_id in np.unique(groups_train)
    if np.count_nonzero(groups == group_id) > min_subgroup_size
]

# repeat for test set
features_test = torch.tensor(X_test, dtype=torch.float32)[:,:-1]
labels_test = torch.tensor(y_test,dtype=torch.float32)
dataset_test = torch.utils.data.TensorDataset(features_test, labels_test)
group_indices_test = [
    np.nonzero(groups_test == group_id)[0] for group_id in np.unique(groups_test)
    if np.count_nonzero(groups == group_id) > min_subgroup_size
]

print(f'Protected attribute: {ACSIncome.group}')
print(f'Number of subgroups considered: {len(group_indices_train)}')
print(f'Size of subgroups: {[len(g) for g in group_indices_train]}')

Protected attribute: RAC1P
Number of subgroups considered: 7
Size of subgroups: [10685, 794, 1213, 59, 259, 315, 997]


Let's say we want to add an equal loss constraint on the model.

We can do that by using the `FairnessConstraint` class, which will handle sampling (if possible, sampling an equal number of samples from each relevant subgroup in each minibatch), and passing it the function that will calculate the value of the constraint.

In [None]:
from itertools import combinations
from src.constraints.constraint import FairnessConstraint
from src.constraints.constraint_fns import abs_loss_equality, tpr_equality

# the protected attribute is "Race" (RAC1P)
# we put a pairwise constraint on each combination of subgroups
constraint_bound = 0.01
constraints = []
for gr1, gr2 in combinations(group_indices_train, 2):
    c = FairnessConstraint(
        dataset=dataset_train,
        group_indices=[gr1, gr2],
        # subtract bound from absolute loss difference to bring constraint to form $c \leq 0$
        # also implemented are fairret wrappers, e.g. equal TPR
        fn=lambda model, samples: abs_loss_equality(torch.nn.BCEWithLogitsLoss(), model, samples) - constraint_bound,
        batch_size=c_batch_size
    )

    constraints.append(c)

print(f'Number of constraints: {len(constraints)}')

Number of constraints: 21


In [None]:
# helper function to analyze model performance

def model_stats(model, features, labels, groups, constraints, constraint_bound):
    with torch.inference_mode():
        gr_ind = list(combinations(groups, 2))
        vals = []
        acc_dif = []
        for i, c in enumerate(constraints):
            idx1, idx2 = gr_ind[i]
            val = c.eval(model, [(features[idx1], labels[idx1]), (features[idx2], labels[idx2])]) + constraint_bound
            vals.append(val.numpy().item())

            logits1 = model(features[idx1])
            logits2 = model(features[idx2])
            outs1 = torch.nn.functional.sigmoid(logits1).numpy()
            outs2 = torch.nn.functional.sigmoid(logits2).numpy()
            preds1 = (outs1.T > 0.5).astype(float)
            preds2 = (outs2.T > 0.5).astype(float) 
            acc1 = np.mean(preds1 == labels[idx1].numpy())
            acc2 = np.mean(preds2 == labels[idx2].numpy())
            acc_dif.append(abs(acc1-acc2))

        print(f'constraints (should be <= {constraint_bound}):')
        print(np.round(vals, decimals=3))
        print(f'c mean: {np.mean(vals)}')
        print(f'c min: {np.min(vals)}')
        print(f'c max: {np.max(vals)}')
        print('---')

        logits = model(features)
        outs = torch.nn.functional.sigmoid(logits).numpy()
        preds = (outs.T > 0.5).astype(float)
        acc = np.sum(preds == labels.numpy())/len(labels)
        print(f'accuracy: {acc}')
        print('accuracy abs. difference:')
        print(np.round(acc_dif, decimals=3))
        print(f'acc abs dif mean: {np.mean(acc_dif)}')
        print(f'acc abs dif min: {np.min(acc_dif)}')
        print(f'acc abs dif max: {np.max(acc_dif)}')

---
---

For comparison, let us first train a model **without constraints**.

Define a model:

In [None]:
from torch.nn import Sequential
hsize1 = 64
hsize2 = 32
model_uncon = Sequential(
    torch.nn.Linear(features_train.shape[1], hsize1),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize1, hsize2),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize2, 1)
)

And start training:

In [None]:
from torch.optim import Adam

loader = torch.utils.data.DataLoader(dataset_train, batch_size=32, shuffle=True)
loss = torch.nn.BCEWithLogitsLoss()
optimizer = Adam(model_uncon.parameters())
epochs = 100

for epoch in range(epochs):
    losses = []
    for batch_feat, batch_label in loader:
        optimizer.zero_grad()

        logit = model_uncon(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit.squeeze(), batch_label)
        loss.backward()

        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")

Epoch: 0, loss: 0.4604055254187967
Epoch: 1, loss: 0.42776412623269217
Epoch: 2, loss: 0.42131632336947533
Epoch: 3, loss: 0.4187857337362532
Epoch: 4, loss: 0.41589927753167494
Epoch: 5, loss: 0.41376505203412045
Epoch: 6, loss: 0.41254269820638
Epoch: 7, loss: 0.4111959484206246
Epoch: 8, loss: 0.41077977089610485
Epoch: 9, loss: 0.40981156352375236
Epoch: 10, loss: 0.40897336109940496
Epoch: 11, loss: 0.40737362638381974
Epoch: 12, loss: 0.4065214486992253
Epoch: 13, loss: 0.4054052820096591
Epoch: 14, loss: 0.4055704634104456
Epoch: 15, loss: 0.40459772452179876
Epoch: 16, loss: 0.40259366256317924
Epoch: 17, loss: 0.40246547143241124
Epoch: 18, loss: 0.40104373710762176
Epoch: 19, loss: 0.40087503417661147
Epoch: 20, loss: 0.40067285518827184
Epoch: 21, loss: 0.3999340041607086
Epoch: 22, loss: 0.3995278507604131
Epoch: 23, loss: 0.3977312415039965
Epoch: 24, loss: 0.39830209341432365
Epoch: 25, loss: 0.3964472377140607
Epoch: 26, loss: 0.3967079176634018
Epoch: 27, loss: 0.395766

Let's now analyze how well the **unconstrained** model does in terms of constraints:

In [None]:
print('TRAIN')
model_stats(model_uncon, features_train, labels_train, group_indices_train, constraints, constraint_bound)

TRAIN
constraints (should be <= 0.01):
[0.091 0.004 0.041 0.089 0.166 0.056 0.087 0.132 0.002 0.075 0.035 0.044
 0.086 0.163 0.052 0.13  0.207 0.097 0.077 0.033 0.11 ]
c mean: 0.08464908174106053
c min: 0.0016490817070007324
c max: 0.2070177048444748
---
accuracy: 0.8178329728598339
accuracy abs. difference:
[0.054 0.007 0.004 0.052 0.092 0.029 0.047 0.049 0.002 0.039 0.024 0.003
 0.045 0.085 0.022 0.047 0.088 0.025 0.041 0.022 0.063]
acc abs dif mean: 0.04003946992993007
acc abs dif min: 0.0017165420188091085
acc abs dif max: 0.09241556550868668


In [None]:
print('TEST')
model_stats(model_uncon, features_test, labels_test, group_indices_test, constraints, constraint_bound)

TEST
constraints (should be <= 0.01):
[0.018 0.063 0.464 0.092 0.043 0.103 0.081 0.482 0.11  0.025 0.085 0.401
 0.029 0.106 0.166 0.371 0.507 0.567 0.135 0.195 0.06 ]
c mean: 0.195485953773771
c min: 0.01803898811340332
c max: 0.5669184923171997
---
accuracy: 0.7759486607142857
accuracy abs. difference:
[0.012 0.05  0.277 0.065 0.073 0.039 0.063 0.289 0.078 0.061 0.026 0.227
 0.015 0.124 0.089 0.212 0.351 0.316 0.139 0.104 0.035]
acc abs dif mean: 0.12595712940542286
acc abs dif min: 0.012362990026661413
acc abs dif max: 0.3505747126436781


---
---

Now let us train the same model with one of the **constrained** training algorithms:

In [None]:
from src.algorithms import SSLALM, SSG
from torch.nn import Sequential

hsize1 = 64
hsize2 = 32
model = Sequential(
    torch.nn.Linear(features_train.shape[1], hsize1),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize1, hsize2),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize2, 1)
)

optimizer = SSG(
    net=model,
    data=dataset_train,
    loss=torch.nn.BCEWithLogitsLoss(),
    constraints=constraints,
)

history = optimizer.optimize(
    max_runtime=180,
    batch_size=32,
    seed=42,
    ctol=1.0,
    f_stepsize_rule='dimin',
    f_stepsize=0.1,
    c_stepsize_rule='dimin',
    c_stepsize=0.1
    )

 0|1.0|    1|0.70689|[-0.004  0.002 -0.001  0.009  0.005 -0.007 -0.009  0.000  0.003  0.001 -0.005  0.003  0.004 -0.008
 0|0.707|    2|0.69183|[ 0.011  0.000 -0.007  0.005  0.014 -0.003 -0.004 -0.001  0.002  0.004 -0.006 -0.004  0.004  0.016
 0|0.577|    3|0.67469|[ 0.013 -0.000 -0.006  0.008  0.020  0.003  0.003  0.005  0.001  0.007 -0.010  0.004  0.001  0.006
 0|0.5|    4|0.68116|[-0.005 -0.008 -0.003  0.020  0.023 -0.007  0.009  0.007 -0.004  0.005 -0.003  0.003 -0.008  0.013
 0|0.447|    5|0.67820|[ 0.003 -0.001 -0.002 -0.006  0.016 -0.009  0.001  0.017  0.006 -0.005  0.005 -0.000  0.002  0.018
 0|0.408|    6|0.67011|[ 0.021  0.002 -0.008  0.004  0.027  0.013 -0.005  0.010 -0.008  0.021 -0.009  0.002 -0.002  0.016
 0|0.378|    7|0.65511|[ 0.019  0.005 -0.007  0.030  0.039 -0.002 -0.009  0.010  0.002  0.014 -0.004  0.018  0.001  0.012
 0|0.354|    8|0.65059|[ 0.013  0.005 -0.006  0.017  0.043  0.010  0.005  0.011 -0.008  0.003  0.010 -0.006 -0.002  0.026
 0|0.333|    9|0.66005|[ 0.0

In [None]:
print('TRAIN')
model_stats(model, features_train, labels_train, group_indices_train, constraints, constraint_bound)

TRAIN
constraints (should be <= 0.01):
[0.009 0.007 0.005 0.002 0.009 0.01  0.003 0.015 0.011 0.    0.    0.012
 0.008 0.002 0.003 0.004 0.014 0.015 0.011 0.011 0.001]
c mean: 0.007250950449988956
c min: 0.00020450353622436523
c max: 0.014968276023864746
---
accuracy: 0.7695527802972162
accuracy abs. difference:
[0.065 0.032 0.062 0.047 0.136 0.045 0.032 0.126 0.018 0.071 0.02  0.094
 0.014 0.103 0.012 0.108 0.197 0.106 0.089 0.002 0.091]
acc abs dif mean: 0.06999388490045871
acc abs dif min: 0.0016845904508893117
acc abs dif max: 0.19714823782620394


In [None]:
print('TEST')
model_stats(model, features_test, labels_test, group_indices_test, constraints, constraint_bound)

TEST
constraints (should be <= 0.01):
[0.008 0.002 0.027 0.016 0.001 0.013 0.011 0.035 0.024 0.007 0.004 0.025
 0.014 0.004 0.015 0.011 0.028 0.04  0.017 0.029 0.011]
c mean: 0.016335090001424152
c min: 0.0013061761856079102
c max: 0.03964132070541382
---
accuracy: 0.7684151785714286
accuracy abs. difference:
[0.064 0.008 0.094 0.015 0.101 0.052 0.072 0.158 0.079 0.038 0.012 0.086
 0.007 0.11  0.06  0.079 0.195 0.146 0.116 0.067 0.05 ]
acc abs dif mean: 0.07653972119131433
acc abs dif min: 0.006648863698294205
acc abs dif max: 0.1954022988505747
