This notebook will demonstrate how to use the **constrained training algorithms** implemented in this toolkit.

To train a network, instantiate an algorithm, passing to it the model, the dataset, the loss function and a list of `FairnessConstraint`s; then call its `optimize` method with the algorithm's hyperparameters.

Load and prepare data from `folktables`:

In [None]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
from folktables import ACSDataSource, ACSIncome

device = 'cpu'
torch.set_default_device(device)

# load folktables data
data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')
acs_data = data_source.get_data(states=["OK"], download=True)
features, labels, groups = ACSIncome.df_to_numpy(acs_data)
# split (fixed random_state so the train/test partition is reproducible)
X_train, X_test, y_train, y_test, groups_train, groups_test = train_test_split(
    features, labels, groups, test_size=0.2, random_state=42)
# scale: fit the scaler on the training split only, then apply it to the test split
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# make into a pytorch dataset, remove the sensitive attribute (RAC1P)
# NOTE(review): `[:, :-1]` assumes RAC1P is the last feature column - confirm against ACSIncome.features
features_train = torch.tensor(X_train, dtype=torch.float32)[:, :-1].to(device)
labels_train = torch.tensor(y_train, dtype=torch.float32).to(device)
dataset_train = torch.utils.data.TensorDataset(features_train, labels_train)

c_batch_size = 128
min_subgroup_size = c_batch_size
# Only keep subgroups whose TRAINING split holds more than `min_subgroup_size` samples:
# the constraints draw their minibatches from the training data, so the size of the
# subgroup in the full dataset is not the relevant quantity.
eligible_group_ids = [
    group_id for group_id in np.unique(groups_train)
    if np.count_nonzero(groups_train == group_id) > min_subgroup_size
]
# For each subgroup, FairnessConstraint needs a list of indices of samples belonging to that subgroup
group_indices_train = [
    np.nonzero(groups_train == group_id)[0] for group_id in eligible_group_ids
]

# repeat for test set
features_test = torch.tensor(X_test, dtype=torch.float32)[:, :-1].to(device)
labels_test = torch.tensor(y_test, dtype=torch.float32).to(device)
dataset_test = torch.utils.data.TensorDataset(features_test, labels_test)
# Reuse the group ids selected on the training split so that the i-th test subgroup is
# the same subgroup as the i-th training subgroup (the evaluation pairs subgroups with
# constraints by position). A selected subgroup with no test samples yields an empty
# index array instead of silently shifting the remaining subgroups.
group_indices_test = [
    np.nonzero(groups_test == group_id)[0] for group_id in eligible_group_ids
]

print(f'Protected attribute: {ACSIncome.group}')
print(f'Number of subgroups considered: {len(group_indices_train)}')
print(f'Size of subgroups: {[len(g) for g in group_indices_train]}')

Protected attribute: RAC1P
Number of subgroups considered: 6
Size of subgroups: [10685, 794, 1213, 259, 315, 997]


Let's say we want to add an equal-loss constraint on the model, i.e. require the loss to be (approximately) equal across subgroups.

We can do that by using the `FairnessConstraint` class, which will handle sampling (if possible, sampling an equal number of samples from each relevant subgroup in each minibatch), and passing it the function that will calculate the value of the constraint.

In [None]:
from itertools import combinations
from src.constraints.constraint import FairnessConstraint
from src.constraints.constraint_fns import abs_loss_equality, tpr_equality

# Protected attribute: "Race" (RAC1P).
# One pairwise constraint is created for every 2-combination of subgroups.
constraint_bound = 0.01


def _bounded_loss_gap(model, samples):
    """Absolute loss difference between the two sampled subgroups, minus the bound.

    Subtracting `constraint_bound` brings the constraint to the form $c \leq 0$.
    Also implemented are fairret wrappers, e.g. equal TPR.
    """
    loss_gap = abs_loss_equality(torch.nn.BCEWithLogitsLoss(), model, samples)
    return loss_gap - constraint_bound


constraints = [
    FairnessConstraint(
        dataset=dataset_train,
        group_indices=[first_group, second_group],
        fn=_bounded_loss_gap,
        batch_size=c_batch_size,
        device=device,
    )
    for first_group, second_group in combinations(group_indices_train, 2)
]

print(f'Number of constraints: {len(constraints)}')

Number of constraints: 15


In [None]:
# helper function to analyze model performance

def model_stats(model, features, labels, groups, constraints, constraint_bound):
    """Print constraint values and accuracy statistics of `model`.

    Args:
        model: callable mapping a feature tensor to logits.
        features: feature tensor; indexed with the entries of `groups`.
        labels: label tensor; indexed with the entries of `groups`.
        groups: list of per-subgroup sample index arrays. The i-th
            2-combination of this list is paired with `constraints[i]`.
        constraints: objects exposing `eval(model, samples)`, where `samples`
            is a list of two `(features, labels)` tuples.
        constraint_bound: value added back to each evaluated constraint
            before it is reported.

    Returns:
        None. All statistics are printed.
    """

    def _group_accuracy(sample_idx):
        # Fraction of samples in the subgroup whose thresholded prediction matches the label.
        group_logits = model(features[sample_idx])
        probabilities = torch.nn.functional.sigmoid(group_logits).cpu().numpy()
        hard_preds = (probabilities.T > 0.5).astype(float)
        return np.mean(hard_preds == labels[sample_idx].cpu().numpy())

    def _print_summary(tag, values):
        # Mean / min / max of a list of per-pair statistics.
        print(f'{tag} mean: {np.mean(values)}')
        print(f'{tag} min: {np.min(values)}')
        print(f'{tag} max: {np.max(values)}')

    with torch.inference_mode():
        group_pairs = list(combinations(groups, 2))
        constraint_values = []
        accuracy_gaps = []
        for pair_no, constraint in enumerate(constraints):
            first_idx, second_idx = group_pairs[pair_no]

            # Evaluate the constraint on the full subgroups and undo the bound shift.
            pair_samples = [
                (features[first_idx], labels[first_idx]),
                (features[second_idx], labels[second_idx]),
            ]
            shifted = constraint.eval(model, pair_samples) + constraint_bound
            constraint_values.append(shifted.cpu().numpy().item())

            first_acc = _group_accuracy(first_idx)
            second_acc = _group_accuracy(second_idx)
            accuracy_gaps.append(abs(first_acc - second_acc))

        print(f'constraints (should be <= {constraint_bound}):')
        print(np.round(constraint_values, decimals=3))
        _print_summary('c', constraint_values)
        print('---')

        # Overall accuracy on all provided samples.
        all_probs = torch.nn.functional.sigmoid(model(features)).cpu().numpy()
        all_preds = (all_probs.T > 0.5).astype(float)
        n_correct = np.sum(all_preds == labels.cpu().numpy())
        acc = n_correct/len(labels)
        print(f'accuracy: {acc}')
        print('accuracy abs. difference:')
        print(np.round(accuracy_gaps, decimals=3))
        _print_summary('acc abs dif', accuracy_gaps)

---
---

For comparison, let us first train a model **without constraints**.

Define a model:

In [None]:
from torch.nn import Sequential
hsize1, hsize2 = 64, 32
# Fully connected network: two hidden ReLU layers and a single-logit output.
uncon_layers = [
    torch.nn.Linear(features_train.shape[1], hsize1),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize1, hsize2),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize2, 1),
]
model_uncon = Sequential(*uncon_layers).to(device)

And start training:

In [None]:
from torch.optim import Adam

# Shuffling is switched off when `device` is 'cuda'.
# NOTE(review): presumably a workaround for the DataLoader's shuffling generator
# clashing with a CUDA default device - confirm.
loader = torch.utils.data.DataLoader(dataset_train, batch_size=32, shuffle=(device != 'cuda'))
# Keep the criterion under its own name so it is not shadowed by the per-batch loss value.
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = Adam(model_uncon.parameters())
epochs = 100

for epoch in range(epochs):
    losses = []
    for batch_feat, batch_label in loader:
        optimizer.zero_grad()

        logit = model_uncon(batch_feat)
        # squeeze only the trailing output dimension: a bare squeeze() would also drop
        # the batch dimension of a final batch with a single sample and make the
        # logits' shape disagree with the labels' shape.
        loss = loss_fn(logit.squeeze(-1), batch_label)
        loss.backward()

        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")

Epoch: 0, loss: 0.4642235327711595
Epoch: 1, loss: 0.4293003898513104
Epoch: 2, loss: 0.42335555399768054
Epoch: 3, loss: 0.4196424624283931
Epoch: 4, loss: 0.4175680346920022
Epoch: 5, loss: 0.41486163972876966
Epoch: 6, loss: 0.4139517352783254
Epoch: 7, loss: 0.41123612145228045
Epoch: 8, loss: 0.41036991043282406
Epoch: 9, loss: 0.40926071751995813
Epoch: 10, loss: 0.40827423497103155
Epoch: 11, loss: 0.4069685044565371
Epoch: 12, loss: 0.40704553817132755
Epoch: 13, loss: 0.40542702303667155
Epoch: 14, loss: 0.4030081082933715
Epoch: 15, loss: 0.402854029188997
Epoch: 16, loss: 0.40264014064866516
Epoch: 17, loss: 0.4018292943947017
Epoch: 18, loss: 0.4005603741388768
Epoch: 19, loss: 0.39991264747056576
Epoch: 20, loss: 0.39948308022160617
Epoch: 21, loss: 0.3977042373070227
Epoch: 22, loss: 0.3978979098277965
Epoch: 23, loss: 0.3971360599222992
Epoch: 24, loss: 0.3965825682306396
Epoch: 25, loss: 0.39528360883040087
Epoch: 26, loss: 0.3952076831566436
Epoch: 27, loss: 0.39478948

Let's now analyze how well the **unconstrained** model satisfies the constraints:

In [None]:
# Unconstrained model, evaluated on the training split.
print('TRAIN')
model_stats(
    model=model_uncon,
    features=features_train,
    labels=labels_train,
    groups=group_indices_train,
    constraints=constraints,
    constraint_bound=constraint_bound,
)

TRAIN
constraints (should be <= 0.01):
[0.087 0.018 0.09  0.172 0.056 0.069 0.003 0.085 0.031 0.072 0.154 0.039
 0.082 0.033 0.116]
c mean: 0.07374673883120218
c min: 0.0027457773685455322
c max: 0.1718178391456604
---
accuracy: 0.8260657224586618
accuracy abs. difference:
[0.059 0.004 0.074 0.099 0.014 0.055 0.015 0.041 0.045 0.07  0.096 0.01
 0.026 0.06  0.086]
acc abs dif mean: 0.05016161364761892
acc abs dif min: 0.003959985818891454
acc abs dif max: 0.09949120187772498


In [None]:
# Unconstrained model, evaluated on the held-out test split.
print('TEST')
model_stats(
    model=model_uncon,
    features=features_test,
    labels=labels_test,
    groups=group_indices_test,
    constraints=constraints,
    constraint_bound=constraint_bound,
)

TEST
constraints (should be <= 0.01):
[0.051 0.069 0.032 0.023 0.079 0.119 0.082 0.028 0.029 0.037 0.091 0.148
 0.054 0.111 0.057]
c mean: 0.06724767088890075
c min: 0.02253669500350952
c max: 0.14789214730262756
---
accuracy: 0.7915736607142857
accuracy abs. difference:
[0.036 0.023 0.009 0.05  0.038 0.059 0.045 0.015 0.002 0.014 0.074 0.061
 0.059 0.047 0.012]
acc abs dif mean: 0.03640103546705118
acc abs dif min: 0.0021533979352713617
acc abs dif max: 0.0738071478730088


---
---

Now let us train the same model with one of the **constrained** training algorithms:

In [None]:
from src.algorithms import SSLALM, SSG
from torch.nn import Sequential

hsize1, hsize2 = 64, 32
# Same architecture as the unconstrained model: two hidden ReLU layers, one output logit.
model = Sequential(*[
    torch.nn.Linear(features_train.shape[1], hsize1),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize1, hsize2),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize2, 1),
])

# The algorithm object receives the network, the data, the objective and the constraints.
optimizer = SSG(
    net=model,
    data=dataset_train,
    loss=torch.nn.BCEWithLogitsLoss(),
    constraints=constraints,
)

# Hyperparameters of the run, forwarded unchanged to `optimize`.
# NOTE(review): units and exact semantics (e.g. max_runtime, ctol, 'dimin') are
# defined by SSG.optimize - confirm against its documentation.
ssg_hyperparams = dict(
    max_runtime=180,
    batch_size=32,
    seed=42,
    device=device,
    ctol=1.0,
    f_stepsize_rule='dimin',
    f_stepsize=0.1,
    c_stepsize_rule='dimin',
    c_stepsize=0.1,
    verbose=False,
)
history = optimizer.optimize(**ssg_hyperparams)

In [None]:
# Constrained model, evaluated on the training split.
print('TRAIN')
model_stats(
    model=model,
    features=features_train,
    labels=labels_train,
    groups=group_indices_train,
    constraints=constraints,
    constraint_bound=constraint_bound,
)

TRAIN
constraints (should be <= 0.01):
[0.017 0.017 0.001 0.013 0.015 0.    0.015 0.003 0.002 0.015 0.003 0.002
 0.012 0.014 0.002]
c mean: 0.008758107821146647
c min: 0.0002809762954711914
c max: 0.0168074369430542
---
accuracy: 0.7457615293378915
accuracy abs. difference:
[0.085 0.067 0.077 0.144 0.064 0.018 0.008 0.059 0.021 0.01  0.077 0.003
 0.067 0.013 0.079]
acc abs dif mean: 0.052704355578546074
acc abs dif min: 0.0027039072700376643
acc abs dif max: 0.14368310418848562


In [None]:
# Constrained model, evaluated on the held-out test split.
print('TEST')
model_stats(
    model=model,
    features=features_test,
    labels=labels_test,
    groups=group_indices_test,
    constraints=constraints,
    constraint_bound=constraint_bound,
)

TEST
constraints (should be <= 0.01):
[0.019 0.002 0.026 0.004 0.022 0.017 0.045 0.015 0.003 0.028 0.002 0.02
 0.03  0.048 0.018]
c mean: 0.01985848347345988
c min: 0.0017260313034057617
c max: 0.04803180694580078
---
accuracy: 0.7452566964285714
accuracy abs. difference:
[0.086 0.02  0.088 0.061 0.087 0.066 0.175 0.026 0.001 0.108 0.041 0.067
 0.149 0.175 0.026]
acc abs dif mean: 0.0784395204252176
acc abs dif min: 0.0007811345451474994
acc abs dif max: 0.17542678822737567
