This notebook will demonstrate how to use the **constrained training algorithms** implemented in this toolkit, designed to augment your normal **PyTorch** training routine.

The algorithms implemented in the **humancompatible.train.dual_optim** subpackage all follow the same basic recipe.

1. Define your PyTorch optimizer like you would normally. Let's call it `opt`.
2. Define a dual "optimizer" from `humancompatible.train.dual_optim`; it will keep track of and update the dual variables. Let's call it `dual`.
3. In the training loop:
    - Compute the constraints and the objective function (loss);
    - Calculate the Lagrangian and update the dual variables using the `dual.forward_update` method;
    - Run a `backward` pass through the Lagrangian (instead of the loss);
    - Run `opt.step`.

Let's train a simple classification model, putting a constraint on the norm of each parameter tensor (every layer's weight matrix and bias vector separately).

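In the canonical form, the algorithm expects equality constraints of the form $c(\theta) = 0$; however, we can easily transform arbitrary inequality constraints to that form. For example, $g(\theta) \le b$ holds exactly when $\max(g(\theta) - b, 0) = 0$, which is the transformation used below.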

In [None]:
# load and prepare data

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
import numpy as np
from folktables import ACSDataSource, ACSIncome, generate_categories

# Work in double precision everywhere: the numpy arrays below are float64 and
# the model parameters created later pick up this default dtype.
torch.set_default_dtype(torch.double)


def _to_tensors(*arrays):
    """Convert each numpy array into a torch tensor, preserving order."""
    return tuple(torch.tensor(arr) for arr in arrays)


# --- load folktables data ---
data_source = ACSDataSource(survey_year="2018", horizon="1-Year", survey="person")
acs_data = data_source.get_data(states=["CA"], download=True)
definition_df = data_source.get_definitions(download=True)
categories = generate_categories(
    features=ACSIncome.features,
    definition_df=definition_df,
)
df_feat, df_labels, _ = ACSIncome.df_to_pandas(
    acs_data,
    categories=categories,
    dummies=True,
)

# The one-hot sex columns form the sensitive attribute; they are kept apart
# from the model's input features.
sens_cols = ["SEX_Female", "SEX_Male"]
groups = df_feat[sens_cols].to_numpy(dtype=float)
features = df_feat.drop(columns=sens_cols).to_numpy(dtype=float)
labels = df_labels.to_numpy(dtype=float)

# --- 70 % train / 30 % held out ---
X_train, X_test, y_train, y_test, groups_train, groups_test = train_test_split(
    features, labels, groups, test_size=0.3, random_state=42
)

# --- standardise with statistics fitted on the training split only ---
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# --- split the held-out part evenly into validation and test ---
X_val, X_test, y_val, y_test, groups_val, groups_test = train_test_split(
    X_test, y_test, groups_test, test_size=0.5, random_state=42
)

# --- wrap every split in PyTorch datasets ---
# `dataset_*` holds (features, labels); `dataset_*_s` additionally carries the
# sensitive attribute as its middle element.
features_train, labels_train, sens_train = _to_tensors(X_train, y_train, groups_train)
dataset_train = torch.utils.data.TensorDataset(features_train, labels_train)
dataset_train_s = torch.utils.data.TensorDataset(features_train, sens_train, labels_train)

features_val, labels_val, sens_val = _to_tensors(X_val, y_val, groups_val)
dataset_val = torch.utils.data.TensorDataset(features_val, labels_val)
dataset_val_s = torch.utils.data.TensorDataset(features_val, sens_val, labels_val)

features_test, labels_test, sens_test = _to_tensors(X_test, y_test, groups_test)
dataset_test = torch.utils.data.TensorDataset(features_test, labels_test)
dataset_test_s = torch.utils.data.TensorDataset(features_test, sens_test, labels_test)


In [None]:
import torch
from humancompatible.train.dual_optim import ALM, PBM
from humancompatible.train.dual_optim import MoreauEnvelope
from torch.nn import Sequential
# Fix the torch RNG so that weight initialisation is reproducible.
torch.manual_seed(0)

dataloader = torch.utils.data.DataLoader(dataset_train, batch_size=16, shuffle=False)

# Small MLP: input -> 12 -> 12 -> 1 logit.
hsize1 = 12
hsize2 = 12
model = Sequential(
    torch.nn.Linear(features_train.shape[1], hsize1),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize1, hsize2),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize2, 1),
)

# Primal optimizer: Adam wrapped in the toolkit's MoreauEnvelope.
optimizer = MoreauEnvelope(
    torch.optim.Adam(model.parameters(), lr=0.001),
)

# One constraint (and therefore one dual variable) per parameter tensor.
# The count is derived from the model rather than hard-coded, so changing the
# architecture above keeps the dual optimizer in sync.
m = len(list(model.parameters()))
dual = ALM(m=m, lr=0.1, momentum=0.5, dampening=0.5)

# Upper bound on the norm of each parameter tensor.
constraint_bounds = [1.0] * m
epochs = 10
criterion = torch.nn.BCEWithLogitsLoss()

# do a dummy backward pass
# NOTE(review): presumably this populates `.grad` on every parameter before the
# wrapped optimizer is first used - confirm against MoreauEnvelope.
model(dataset_train[0][0]).backward()
model.zero_grad()

In [None]:
# Constant used to turn `norm <= bound` into the equality `max(norm - bound, 0) = 0`.
# It is a plain constant, so it needs no gradient and is created once.
zero = torch.zeros(1)

for epoch in range(epochs):
    # Per-epoch logs; only the last entry of each is printed below.
    loss_log = []
    c_log = []
    duals_log = []
    for batch_input, batch_label in dataloader:
        # One constraint per parameter tensor: its norm must not exceed the bound.
        c_log.append([])
        constraints = []
        for i, param in enumerate(model.parameters()):
            norm = torch.linalg.norm(param)
            # convert constraint to equality: the violation is zero iff norm <= bound
            norm_viol = torch.max(norm - constraint_bounds[i], zero)
            constraints.append(norm_viol)
            c_log[-1].append(norm.detach().numpy())

        # Stack the shape-(1,) violations into a single vector of length m.
        constraints = torch.cat(constraints)
        batch_output = model(batch_input)
        bce_loss = criterion(batch_output, batch_label)

        # Build the Lagrangian and update the dual variables in one call, then
        # backpropagate through the Lagrangian instead of the plain loss.
        lag_loss = dual.forward_update(bce_loss, constraints)
        lag_loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        loss_log.append(bce_loss.detach().numpy())
        # Store a snapshot: appending `dual.duals` itself would alias one tensor
        # across all entries if the dual optimizer updates it in place.
        duals_log.append(dual.duals.detach().clone())

    print(
        f"Epoch: {epoch}, "
        f"loss: {loss_log[-1]}, "
        f"constraints: {c_log[-1]}, "
        f"dual: {duals_log[-1]}"
    )

Epoch: 0, loss: 0.4742393715781231, constraints: [array(0.99822405), array(0.99977582), array(0.98370329), array(0.99975662), array(1.00032864), array(0.27740726)], dual: tensor([0.8365, 0.1397, 4.4786, 0.0683, 0.3129, 0.0000])
Epoch: 1, loss: 0.46727603619805774, constraints: [array(0.99974732), array(0.99991013), array(0.97601688), array(0.99970829), array(0.99954425), array(0.4327701)], dual: tensor([0.9207, 0.1785, 4.4865, 0.0948, 0.3925, 0.0000])
Epoch: 2, loss: 0.4497933717844898, constraints: [array(1.0006221), array(0.99918517), array(0.98203093), array(0.99990305), array(0.99923101), array(0.61347269)], dual: tensor([0.9860, 0.2028, 4.4930, 0.1103, 0.4512, 0.0000])
Epoch: 3, loss: 0.4470780470827978, constraints: [array(0.99765493), array(0.99988715), array(0.98583056), array(0.99938058), array(0.99993007), array(0.78371915)], dual: tensor([1.0353, 0.2229, 4.4991, 0.1206, 0.5010, 0.0000])
Epoch: 4, loss: 0.43683666477725475, constraints: [array(0.99835512), array(0.99917183), 

The model is now trained subject to the constraints we set.

---
---

It is also possible to train subject to **stochastic constraints**. One of the main use-cases for that is **fairness**. Let's now train a fairness-constrained model on the `folktables` dataset.

To track the "fairness" of our model, we will track **Positive Rate** across groups. To calculate it, we will use the `fairret` package and its `NormLoss`.
The latter measures how far each group's statistic deviates, in relative terms, from the overall value, summed over all groups: $\sum_{s\in S}{|1-\frac{f(\theta, X_s, y_s)}{f(\theta, X, y)}|}$, where $S$ is the set of groups.

To make sure each batch contains representatives of each protected group, we can use the BalancedBatchSampler from the `fairness.utils` subpackage - a custom PyTorch `Sampler` which yields an equal number of samples from each subgroup in each batch.

In [None]:
from fairret.statistic import PositiveRate
from fairret.loss import NormLoss
from humancompatible.train.fairness.utils import BalancedBatchSampler

# Hidden-layer widths for the fairness experiments below.
hsize1, hsize2 = 32, 16

# Objective: binary cross-entropy on the logits.
criterion = torch.nn.BCEWithLogitsLoss()

# Fairness term: NormLoss built on the positive-rate statistic.
statistic = PositiveRate()
fair_criterion = NormLoss(statistic=statistic)

# Group-balanced batches of 128 samples.
# NOTE(review): the sampler is passed as `sampler=` (not `batch_sampler=`), so
# each item from the loader presumably carries an extra leading dimension of
# size 1 - the training loops call `.squeeze(0)` on it. Confirm.
fair_sampler = BalancedBatchSampler(group_onehot=sens_train, batch_size=128)
loader = torch.utils.data.DataLoader(dataset_train_s, sampler=fair_sampler)

First, let us train an unconstrained model:

In [None]:
from torch.optim import Adam
from torch.nn import Sequential

# Baseline: same architecture as the constrained models, trained on BCE only.
model_uncon = Sequential(
    torch.nn.Linear(features_train.shape[1], hsize1),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize1, hsize2),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize2, 1),
)

optimizer = Adam(model_uncon.parameters())

# Per-epoch histories of the loss and the fairness criterion.
train_losses = []
train_f = []
val_losses = []
val_f = []

for epoch in range(epochs):
    # The sensitive attribute in the middle slot is ignored during training.
    for batch_feat, _, batch_label in loader:
        optimizer.zero_grad()

        logit = model_uncon(batch_feat)
        loss = criterion(logit, batch_label)
        loss.backward()

        optimizer.step()

    # logging: evaluate on the full train and validation sets.
    # Each split is run through the model once and the logits are reused for
    # both the loss and the fairness criterion.
    with torch.no_grad():
        logits_train = model_uncon(features_train)
        loss_train = criterion(logits_train, labels_train)
        fair_train = fair_criterion(logits_train, sens_train)
        logits_val = model_uncon(features_val)
        loss_val = criterion(logits_val, labels_val)
        fair_val = fair_criterion(logits_val, sens_val)
        train_losses.append(loss_train)
        train_f.append(fair_train)
        val_losses.append(loss_val)
        val_f.append(fair_val)
    print(f"Epoch: {epoch}, loss: {loss_train} / {loss_val}, constraint: {fair_train} / {fair_val}")

Epoch: 0, loss: 0.3968636369462622 / 0.40470909243747405, constraint: 0.18165649822563323 / 0.18781794489482784
Epoch: 1, loss: 0.3906023357402648 / 0.40295611185247887, constraint: 0.16304437408194883 / 0.16775616049726738
Epoch: 2, loss: 0.3827542259947103 / 0.3991090992544509, constraint: 0.1830651758212012 / 0.18514395786567173
Epoch: 3, loss: 0.3790003210896969 / 0.3978286528551326, constraint: 0.17819499988186338 / 0.179512346701729
Epoch: 4, loss: 0.379721713846987 / 0.40115765726758496, constraint: 0.17223594705362688 / 0.17338911341025487
Epoch: 5, loss: 0.37080572509470167 / 0.39682120662030707, constraint: 0.18351866706274733 / 0.18596974220973295
Epoch: 6, loss: 0.3671832809863401 / 0.40050300410026235, constraint: 0.1963664795819673 / 0.19770557394897526
Epoch: 7, loss: 0.36277761952639254 / 0.40109646608534527, constraint: 0.19989953869687838 / 0.19996672677926564
Epoch: 8, loss: 0.36184892644534755 / 0.4030079680737153, constraint: 0.1893552547808227 / 0.192681738889383


In [None]:
# Positive rate per sensitive group for the unconstrained model.
pr = PositiveRate()
print('Positive rate by group:')
# Evaluation only: no_grad avoids recording an autograd graph for the
# full-dataset forward passes.
with torch.no_grad():
    # The statistic is applied to probabilities, hence the sigmoid on the logits.
    preds = torch.sigmoid(model_uncon(features_train))
    print(f'Train: {pr(preds, sens_train).numpy()}')
    preds_test = torch.sigmoid(model_uncon(features_test))
    print(f'Test: {pr(preds_test, sens_test).numpy()}')

Positive rate by group:
Train: [0.38232849 0.46760742]
Test: [0.38632841 0.46746498]


In [None]:
# BCE loss of the unconstrained model on the full train and test sets.
print("Loss:")
# Evaluation only: no_grad avoids recording an autograd graph for the
# full-dataset forward passes.
with torch.no_grad():
    print(f'Train: {criterion(model_uncon(features_train), labels_train).numpy()}')
    print(f'Test: {criterion(model_uncon(features_test), labels_test).numpy()}')

Loss:
Train: 0.335393091872077
Test: 0.4385469714010663


---
---

We see that unconstrained optimization leads to a large **discrepancy in the positive rate**. Say we want the fairret term to be no greater than 0.1. We can enforce this constraint with the **constrained** training algorithms.

First, we will try the **Augmented Lagrangian**. Natively, it only works with equality constraints, so we will introduce **slack variables**.

A note: stochastic constrained optimization algorithms benefit greatly from smoothing. We provide the `MoreauEnvelope` class that adds an L2 smoothing term to the model's loss function gradient (without spending any resources during the backward call).

In [None]:
from torch.nn import Sequential
from torch.optim import Adam
from humancompatible.train.dual_optim import ALM, MoreauEnvelope, PBM

# Same MLP architecture as the unconstrained baseline.
layers = [
    torch.nn.Linear(features_train.shape[1], hsize1),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize1, hsize2),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize2, 1),
]
model_con = Sequential(*layers)

# Primal optimizer (Adam wrapped in MoreauEnvelope) and the dual optimizer
# holding a single dual variable for the single fairness constraint.
optimizer = MoreauEnvelope(Adam(model_con.parameters(), lr=0.001))
dual = ALM(m=1, lr=0.01)

# One slack variable turns the inequality constraint into an equality; it is
# registered with the primal optimizer so it is trained alongside the weights.
slack_vars = torch.zeros(1).requires_grad_()
optimizer.add_param_group({"params": slack_vars})

# fairness constraint bound
fair_crit_bound = 0.1

In [None]:
# Per-epoch histories of the loss and the fairness criterion.
train_losses = []
train_f = []
val_losses = []
val_f = []

for epoch in range(epochs):

    for batch_input, batch_sens, batch_label in loader:
        # do forward pass
        batch_out = model_con(batch_input)
        loss = criterion(batch_out, batch_label)

        # evaluate fairness criterion (constraint)
        # squeeze(0) drops the leading size-1 dimension of the loaded batch
        fair_loss = fair_criterion(batch_out.squeeze(0), batch_sens.squeeze(0))
        # equality form of `fair_loss <= bound`: fair_loss - bound + slack = 0, slack >= 0
        fair_constraint = fair_loss - fair_crit_bound + slack_vars[0]

        # calculate lagrangian, update dual variables
        lagrangian = dual.forward_update(loss, fair_constraint.unsqueeze(0))

        # gradient, optimizer step
        lagrangian.backward()
        optimizer.step()
        optimizer.zero_grad()

        # set slacks to be non-negative (in-place projection onto [0, inf))
        with torch.no_grad():
            slack_vars.clamp_(min=0)

    # logging: evaluate on the full train and validation sets
    with torch.no_grad():
        logits_train = model_con(features_train)
        loss_train = torch.nn.functional.binary_cross_entropy_with_logits(logits_train, labels_train)
        fair_train = fair_criterion(logits_train, sens_train)
        # Evaluate the Lagrangian on the same constraint expression that is used
        # during training (bound and slack included), not on the raw fairret value.
        constraint_train = fair_train - fair_crit_bound + slack_vars[0]
        lagr_train = dual.forward(loss_train, constraint_train.unsqueeze(0))
        logits_val = model_con(features_val)
        loss_val = torch.nn.functional.binary_cross_entropy_with_logits(logits_val, labels_val)
        fair_val = fair_criterion(logits_val, sens_val)
        train_losses.append(loss_train)
        train_f.append(fair_train)
        val_losses.append(loss_val)
        val_f.append(fair_val)

    print(f"Epoch: {epoch}, loss: {loss_train} / {loss_val}, fairret loss: {fair_train} / {fair_val}, L: {lagr_train}")

Epoch: 0, loss: 0.40394780530396024 / 0.410732146277, fairret loss: 0.05203240917074947 / 0.061091899135758077, L: 0.41707331541453646
Epoch: 1, loss: 0.4040077916559369 / 0.41456315954936923, fairret loss: 0.044065317615095134 / 0.050434035002550726, L: 0.4239857269973955
Epoch: 2, loss: 0.4042374101783031 / 0.41510828319779564, fairret loss: 0.03651663806609562 / 0.043627017051205086, L: 0.42555957964784896
Epoch: 3, loss: 0.4017089130338191 / 0.41375565931964703, fairret loss: 0.0392035864604946 / 0.04589470945646934, L: 0.4288281278418492
Epoch: 4, loss: 0.40321648229099527 / 0.41545547866675675, fairret loss: 0.02132574583257052 / 0.030285629561043748, L: 0.4189467696237422
Epoch: 5, loss: 0.4018237916265805 / 0.4151791621663035, fairret loss: 0.03185733342310659 / 0.040023449363411956, L: 0.4268584228473389
Epoch: 6, loss: 0.4062603489590199 / 0.4195608239498349, fairret loss: 0.01627218377530759 / 0.024136046147144485, L: 0.4201776068471261
Epoch: 7, loss: 0.3984736191478028 / 0

In [None]:
# Positive rate per sensitive group for the constrained model.
pr = PositiveRate()
print('Positive rate by group:')
# Evaluation only: no_grad avoids recording an autograd graph for the
# full-dataset forward passes.
with torch.no_grad():
    # The statistic is applied to probabilities, hence the sigmoid on the logits.
    preds = torch.sigmoid(model_con(features_train))
    print(f'Train: {pr(preds, sens_train).numpy()}')
    preds = torch.sigmoid(model_con(features_test))
    print(f'Test: {pr(preds, sens_test).numpy()}')

Positive rate by group:
Train: [0.43184613 0.44203915]
Test: [0.43123685 0.44325591]


In [None]:
# Fairness criterion of the constrained model on the full train and test sets.
print('Fairness criterion value:')
# Evaluation only: no_grad avoids recording an autograd graph for the
# full-dataset forward passes.
with torch.no_grad():
    print(f'Train: {fair_criterion(model_con(features_train), sens_train).numpy()}')
    print(f'Test: {fair_criterion(model_con(features_test), sens_test).numpy()}')

Fairness criterion value:
Train: 0.023312999588212868
Test: 0.027467776537341893


In [None]:
# BCE loss of the constrained model on the full train and test sets.
print("Loss:")
# Evaluation only: no_grad avoids recording an autograd graph for the
# full-dataset forward passes.
with torch.no_grad():
    print(f'Train: {criterion(model_con(features_train), labels_train).numpy()}')
    print(f'Test: {criterion(model_con(features_test), labels_test).numpy()}')

Loss:
Train: 0.3926763912970974
Test: 0.4190584715397698


Curiously, in this case, the Augmented Lagrangian converges to a feasible local minimum that is not on the constraint boundary (the constraint value is strictly below the bound we set for it).

***
Now, let us try the **Penalty-Barrier** method. In contrast to the ALM, it can handle inequality constraints natively, so we have no need for slacks.

In [None]:
from torch.nn import Sequential
from torch.optim import Adam
from humancompatible.train.dual_optim import ALM, MoreauEnvelope, PBM

# Fresh model with the same MLP architecture as before.
layers = [
    torch.nn.Linear(features_train.shape[1], hsize1),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize1, hsize2),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize2, 1),
]
model_con = Sequential(*layers)

# Primal optimizer: Adam wrapped in MoreauEnvelope.
optimizer = MoreauEnvelope(Adam(model_con.parameters(), lr=0.001))

# Dual optimizer: Penalty-Barrier method with a single constraint.
# No slack variable is registered here, unlike in the ALM setup.
pbm_settings = dict(
    m=1,
    mu=0.1,
    penalty_update='dimin',
    lr=0.95,
    penalty_range=(0.001, 100),
    init_penalties=100.,
    dual_range=(0.01, 100),
    init_duals=0.01,
)
dual = PBM(**pbm_settings)

# fairness constraint bound
fair_crit_bound = 0.1

In [None]:
# Per-epoch histories of the loss and the fairness criterion.
train_losses = []
train_f = []
val_losses = []
val_f = []

for epoch in range(epochs):

    for batch_input, batch_sens, batch_label in loader:
        # forward pass and objective
        batch_out = model_con(batch_input)
        loss = criterion(batch_out, batch_label)

        # inequality constraint in the form `fair_loss - bound <= 0` (no slack)
        # squeeze(0) drops the leading size-1 dimension of the loaded batch
        fair_loss = fair_criterion(batch_out.squeeze(0), batch_sens.squeeze(0))
        fair_constraint = fair_loss - fair_crit_bound

        # calculate lagrangian, update dual variables
        lagrangian = dual.forward_update(loss, fair_constraint.unsqueeze(0))

        lagrangian.backward()
        optimizer.step()
        optimizer.zero_grad()

    # penalties are updated once per epoch, after all batches
    dual.update_penalties()

    # logging: evaluate on the full train and validation sets
    with torch.no_grad():
        logits_train = model_con(features_train)
        loss_train = torch.nn.functional.binary_cross_entropy_with_logits(logits_train, labels_train)
        fair_train = fair_criterion(logits_train, sens_train)
        # Evaluate the Lagrangian on the same constraint expression that is used
        # during training (bound subtracted), not on the raw fairret value.
        constraint_train = fair_train - fair_crit_bound
        lagr_train = dual.forward(loss_train, constraint_train.unsqueeze(0))
        logits_val = model_con(features_val)
        loss_val = torch.nn.functional.binary_cross_entropy_with_logits(logits_val, labels_val)
        fair_val = fair_criterion(logits_val, sens_val)
        train_losses.append(loss_train)
        train_f.append(fair_train)
        val_losses.append(loss_val)
        val_f.append(fair_val)

    print(f"Epoch: {epoch}, loss: {loss_train} / {loss_val}, fairret loss: {fair_train} / {fair_val}, L: {lagr_train}")

Epoch: 0, loss: 0.3921223925491205 / 0.4014447260266624, fairret loss: 0.1637256617563726 / 0.1680401242116627, L: 0.3953740697073134
Epoch: 1, loss: 0.3845667147927025 / 0.39732545991580837, fairret loss: 0.1558060491783967 / 0.1584644607588589, L: 0.39214009649027004
Epoch: 2, loss: 0.38050414751603734 / 0.396364260518223, fairret loss: 0.1278090867802667 / 0.13490724066253623, L: 0.39389339022260517
Epoch: 3, loss: 0.3792019705589975 / 0.3988856914466905, fairret loss: 0.10417460506117415 / 0.10891968051621037, L: 0.40010469344958915
Epoch: 4, loss: 0.38003440376419956 / 0.4015328275397177, fairret loss: 0.06694725931199819 / 0.07669726760012474, L: 0.40106554710017955
Epoch: 5, loss: 0.3805996948469289 / 0.4034603573921707, fairret loss: 0.06844513545425046 / 0.0776009863696131, L: 0.40969238675949393
Epoch: 6, loss: 0.38326046394857244 / 0.4065245130493526, fairret loss: 0.041615551456070765 / 0.050314030491271367, L: 0.40638157283123216
Epoch: 7, loss: 0.38385194130402844 / 0.406

In [None]:
# Positive rate per sensitive group for the constrained model.
pr = PositiveRate()
print('Positive rate by group:')
# Evaluation only: no_grad avoids recording an autograd graph for the
# full-dataset forward passes.
with torch.no_grad():
    # The statistic is applied to probabilities, hence the sigmoid on the logits.
    preds = torch.sigmoid(model_con(features_train))
    print(f'Train: {pr(preds, sens_train).numpy()}')
    preds = torch.sigmoid(model_con(features_test))
    print(f'Test: {pr(preds, sens_test).numpy()}')

Positive rate by group:
Train: [0.43673083 0.45237684]
Test: [0.43601392 0.45366248]


In [None]:
# Fairness criterion of the constrained model on the full train and test sets.
print('Fairness criterion value:')
# Evaluation only: no_grad avoids recording an autograd graph for the
# full-dataset forward passes.
with torch.no_grad():
    print(f'Train: {fair_criterion(model_con(features_train), sens_train).numpy()}')
    print(f'Test: {fair_criterion(model_con(features_test), sens_test).numpy()}')

Fairness criterion value:
Train: 0.0351606154662224
Test: 0.03963182594842718


In [None]:
# BCE loss of the constrained model on the full train and test sets.
print("Loss:")
# Evaluation only: no_grad avoids recording an autograd graph for the
# full-dataset forward passes.
with torch.no_grad():
    print(f'Train: {criterion(model_con(features_train), labels_train).numpy()}')
    print(f'Test: {criterion(model_con(features_test), labels_test).numpy()}')

Loss:
Train: 0.39351061534436116
Test: 0.422021726193791
