This notebook demonstrates how to use the **constrained training algorithms** implemented in this toolkit through a **PyTorch**-like API.

The algorithms implemented in the **humancompatible.train.torch** subpackage share a common workflow. Before training, you initialize the algorithm as you would a PyTorch optimizer. Then, at each training step, you:

1. Evaluate a constraint and compute its gradient
2. Call the `dual_step` function to update dual parameters and save the constraint gradient for the primal update
3. Call the `step` function to update the primal parameters (generally, model weights)
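
The three steps above can be sketched on a toy problem without any dependencies. The `ToyPrimalDual` class below is a hypothetical stand-in written for illustration only: it performs plain gradient descent-ascent on the Lagrangian, which is simpler than SSLALM, and its method signatures are assumptions rather than the toolkit's API. The problem is to minimize f(x) = (x - 2)^2 subject to the equality constraint c(x) = x - 1 = 0.

```python
class ToyPrimalDual:
    """Hypothetical stand-in for a constrained optimizer (not the toolkit's SSLALM)."""

    def __init__(self, x0, lr, dual_lr):
        self.x = x0          # primal variable (plays the role of the model weights)
        self.dual = 0.0      # dual multiplier of the single constraint
        self.lr = lr
        self.dual_lr = dual_lr
        self._c_grad = 0.0   # constraint gradient saved for the primal update

    def dual_step(self, c_val, c_grad):
        # gradient ascent on the multiplier; remember the constraint gradient
        self.dual += self.dual_lr * c_val
        self._c_grad = c_grad

    def step(self, f_grad):
        # gradient descent on the Lagrangian f(x) + dual * c(x)
        self.x -= self.lr * (f_grad + self.dual * self._c_grad)


opt = ToyPrimalDual(x0=0.0, lr=0.05, dual_lr=0.1)
for _ in range(500):
    # 1. evaluate the constraint c(x) = x - 1 and its gradient
    c_val, c_grad = opt.x - 1.0, 1.0
    # 2. update the dual multiplier and save the constraint gradient
    opt.dual_step(c_val, c_grad)
    # 3. update the primal variable with the gradient of f(x) = (x - 2)^2
    opt.step(f_grad=2.0 * (opt.x - 2.0))

print(opt.x, opt.dual)
```

With these step sizes the iterates settle near x = 1, the only feasible point, and the multiplier settles near 2, the value at which the objective gradient and the constraint gradient cancel.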

Let's try the Stochastic Smooth Linearized Augmented Lagrangian (SSLALM) algorithm on a constrained learning task: training a simple classification model with a constraint on the norm of each of its parameter tensors (the weights and biases of every layer).

In its canonical form, the algorithm expects equality constraints of the form c(x) = 0. However, an inequality constraint g(x) <= b is easily rewritten in that form as max(g(x) - b, 0) = 0, which is what we do below.
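
As a minimal sketch of that transformation, the helper below (the name `equality_violation` is made up for this example) maps an inequality `value <= bound` to a non-negative violation that is zero exactly when the inequality holds:

```python
def equality_violation(value, bound):
    """Return max(value - bound, 0): zero if and only if value <= bound."""
    return max(value - bound, 0.0)


print(equality_violation(0.8, 1.0))  # inequality satisfied, no violation
print(equality_violation(1.3, 1.0))  # inequality violated by roughly 0.3
```

Because the violation is flat wherever the inequality is satisfied, its gradient vanishes there, so the constraint only influences the updates while it is violated.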

In [None]:
# load and prepare data

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
import numpy as np
from folktables import ACSDataSource, ACSIncome, generate_categories

# load folktables data
data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')
acs_data = data_source.get_data(states=["VA"], download=True)
definition_df = data_source.get_definitions(download=True)
categories = generate_categories(features=ACSIncome.features, definition_df=definition_df)
df_feat, df_labels, _ = ACSIncome.df_to_pandas(acs_data,categories=categories, dummies=True)

sens_cols = ['SEX_Female', 'SEX_Male']
features = df_feat.drop(columns=sens_cols).to_numpy(dtype="float")
groups = df_feat[sens_cols].to_numpy(dtype="float")
labels = df_labels.to_numpy(dtype="float")
# split
X_train, X_test, y_train, y_test, groups_train, groups_test = train_test_split(
    features, labels, groups, test_size=0.2, random_state=42)
# scale
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# make into a pytorch dataset, remove the sensitive attribute
features_train = torch.tensor(X_train, dtype=torch.float32)
labels_train = torch.tensor(y_train,dtype=torch.float32)
sens_train = torch.tensor(groups_train)
dataset_train = torch.utils.data.TensorDataset(features_train, labels_train)

In [None]:
from humancompatible.train.algorithms.torch import SSLALM
import torch
from torch.nn import Sequential

dataloader = torch.utils.data.DataLoader(dataset_train, batch_size=16, shuffle=True)

hsize1 = 64
hsize2 = 32
model = Sequential(
    torch.nn.Linear(features_train.shape[1], hsize1),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize1, hsize2),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize2, 1)
)

m = len(list(model.parameters()))

optimizer = SSLALM(
    params=model.parameters(),
    m=m,
    lr=0.01,
    dual_lr=0.1
)
# bounds for the constraints: norm of each parameter tensor should be <= 1
constraint_bounds = [1.]*m

epochs = 10
criterion = torch.nn.BCEWithLogitsLoss()

In [None]:
for epoch in range(epochs):
    loss_log = []
    c_log = []
    slack_log = []
    duals_log = []
    for batch_input, batch_label in dataloader:
        # calculate constraints and constraint grads
        c_log.append([])
        for i, param in enumerate(model.parameters()):
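            # ord=2 gives the spectral norm for 2-D weights and the Euclidean norm for 1-D biases
            norm = torch.linalg.norm(param, ord=2)
            # rewrite norm <= bound as the equality max(norm - bound, 0) = 0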
            norm_viol = torch.max(
                norm 
                - constraint_bounds[i],
                torch.zeros(1)
            )
            norm_viol.backward()
            # for the Lagrangian family of algorithms, dual_step takes the
            # constraint index and its value, and updates the corresponding
            # dual multiplier
            optimizer.dual_step(i, c_val=norm_viol)
            optimizer.zero_grad()
            c_log[-1].append(norm.detach().numpy())
        
        # calculate loss and grad
        batch_output = model(batch_input)
        loss = criterion(batch_output, batch_label)
        loss.backward()
        loss_log.append(loss.detach().numpy())
        # clone so that later in-place updates of the dual variables cannot alter the log
        duals_log.append(optimizer._dual_vars.detach().clone())
        optimizer.step()
        optimizer.zero_grad()
    
    print(
        f"Epoch: {epoch}, "
        f"loss: {np.mean(loss_log)}, "
        f"constraints: {np.mean(c_log, axis=0)}, "
        f"dual: {np.mean(duals_log, axis=0)}"
    )

Epoch: 0, loss: 0.547227680683136, constraints: [0.91162425 0.2931389  0.9883125  0.4690594  0.89254874 0.04176571], dual: [0.1717298  0.         0.16387744 0.         0.20794065 0.        ]
Epoch: 1, loss: 0.44799599051475525, constraints: [0.9997784  0.7573364  0.999926   0.60372984 0.9997402  0.01519962], dual: [0.22261804 0.         0.24267721 0.         0.27600586 0.        ]
Epoch: 2, loss: 0.4300524890422821, constraints: [0.99943334 0.9939212  0.9994772  0.649871   0.9993597  0.02684791], dual: [0.25665787 0.02926148 0.29758924 0.         0.32820562 0.        ]
Epoch: 3, loss: 0.4226565361022949, constraints: [0.99923307 0.9998924  0.99918497 0.6757832  0.9990267  0.02891587], dual: [0.2852176  0.03667939 0.34263647 0.         0.37016106 0.        ]
Epoch: 4, loss: 0.4160747528076172, constraints: [0.9991638  0.9998352  0.99890137 0.6951654  0.9987838  0.03798234], dual: [0.30832767 0.04193355 0.37953085 0.         0.4063269  0.        ]
Epoch: 5, loss: 0.4118559658527374, cons

The model is now trained subject to the constraints we set.

---
---

It is also possible to train a network subject to **stochastic constraints**. One of the main use cases is **fairness**. Let's first train a network on the `folktables` dataset without constraints, so we can identify some biases:

Define a model:

In [None]:
from torch.nn import Sequential
hsize1 = 64
hsize2 = 32
model_uncon = Sequential(
    torch.nn.Linear(features_train.shape[1], hsize1),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize1, hsize2),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize2, 1)
)

Set up the data loader and the fairness criterion, then start training:

In [None]:
from fairret.statistic import PositiveRate
from fairret.loss import NormLoss

dataset = torch.utils.data.TensorDataset(features_train, sens_train, labels_train)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)

statistic = PositiveRate()
fair_criterion = NormLoss(statistic=statistic)
fair_crit_bound = 0.3

In [None]:
from torch.optim import Adam

optimizer = Adam(model_uncon.parameters())
epochs = 100

for epoch in range(epochs):
    losses = []
    fair = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()

        logit = model_uncon(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        loss.backward()

        optimizer.step()
        losses.append(loss.item())
        # track the fairness criterion for monitoring only; it is not optimized here
        with torch.no_grad():
            fair.append(fair_criterion(logit.detach(), batch_sens).item())
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}, fair: {np.mean(fair)}")

Epoch: 0, loss: 0.47336341450954306, fair: 0.3440811772483088
Epoch: 1, loss: 0.3969301203201557, fair: 0.3440811772483088
Epoch: 2, loss: 0.3833263358165478, fair: 0.3440811772483088
Epoch: 3, loss: 0.3736273459319411, fair: 0.3440811772483088
Epoch: 4, loss: 0.3659545639465595, fair: 0.3440811772483088
Epoch: 5, loss: 0.3573377646248916, fair: 0.3440811772483088
Epoch: 6, loss: 0.3477083052026814, fair: 0.3440811772483088
Epoch: 7, loss: 0.3387393544460165, fair: 0.3440811772483088
Epoch: 8, loss: 0.32858473226941864, fair: 0.3440811772483088
Epoch: 9, loss: 0.3172479239003412, fair: 0.3440811772483088
Epoch: 10, loss: 0.30479535460472107, fair: 0.3440811772483088
Epoch: 11, loss: 0.2953071137954449, fair: 0.3440811772483088
Epoch: 12, loss: 0.2825759767458357, fair: 0.3440811772483088
Epoch: 13, loss: 0.27013415735343405, fair: 0.3440811772483088
Epoch: 14, loss: 0.26126931391913316, fair: 0.3440811772483088
Epoch: 15, loss: 0.2519400581203658, fair: 0.3440811772483088
Epoch: 16, lo

In [None]:
from fairret.statistic import PositiveRate
from fairret.loss import NormLoss

preds = torch.nn.functional.sigmoid(model_uncon(features_train))
pr = PositiveRate()
pr(preds, sens_train)

tensor([0.3554, 0.5108], dtype=torch.float64, grad_fn=<IndexPutBackward0>)
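
To make these numbers concrete, here is a dependency-free sketch of what the two quantities measure. It rests on assumptions about fairret's defaults that should be checked against its documentation: that `PositiveRate` is the mean predicted probability within each sensitive group, and that `NormLoss` sums the absolute relative gaps between each group's rate and the overall rate. The function names `positive_rates` and `relative_gap_norm` are invented for this example.

```python
def positive_rates(probs, sens):
    """Mean predicted probability per group; sens holds one-hot group rows."""
    n_groups = len(sens[0])
    rates = []
    for g in range(n_groups):
        group_size = sum(row[g] for row in sens)
        group_mass = sum(p * row[g] for p, row in zip(probs, sens))
        rates.append(group_mass / group_size)
    return rates


def relative_gap_norm(probs, sens):
    """Sum over groups of |group rate / overall rate - 1| (assumed NormLoss behaviour)."""
    overall = sum(probs) / len(probs)
    return sum(abs(rate / overall - 1.0) for rate in positive_rates(probs, sens))


# four predictions, the first two in group 0 and the last two in group 1
probs = [0.2, 0.4, 0.6, 0.8]
sens = [[1, 0], [1, 0], [0, 1], [0, 1]]
print(positive_rates(probs, sens))
print(relative_gap_norm(probs, sens))
```

In this toy data the group rates are about 0.3 and 0.7 against an overall rate of 0.5, so each group deviates by 40% and the summed gap is about 0.8. A value of 0 would mean both groups receive positive predictions at the same rate.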

---
---

Now let us train the same model with one of the **constrained** training algorithms:

In [None]:
from fairret.statistic import PositiveRate
from fairret.loss import NormLoss

dataset = torch.utils.data.TensorDataset(features_train, sens_train, labels_train)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)

statistic = PositiveRate()
fair_criterion = NormLoss(statistic=statistic)
fair_crit_bound = 0.2

In [None]:
from torch.nn import Sequential
hsize1 = 64
hsize2 = 32
model_con = Sequential(
    torch.nn.Linear(features_train.shape[1], hsize1),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize1, hsize2),
    torch.nn.ReLU(),
    torch.nn.Linear(hsize2, 1)
)

optimizer = SSLALM(
    params=model_con.parameters(),
    m=1,
    lr=0.05,
    dual_lr=0.05
)

epochs = 100

In [None]:
for epoch in range(epochs):
    loss_log = []
    c_log = []
    duals_log = []
    for batch_input, batch_sens, batch_label in dataloader:
        # calculate constraints and constraint grads
        out = model_con(batch_input)
        fair_loss = fair_criterion(out, batch_sens)
        fair_constraint = torch.max(fair_loss -  fair_crit_bound, torch.zeros(1))
        fair_constraint.backward()
        optimizer.dual_step(0, c_val=fair_constraint)
        optimizer.zero_grad()

        c_log.append([fair_loss.detach().numpy()])
        # clone so that later in-place updates of the dual variables cannot alter the log
        duals_log.append(optimizer._dual_vars.detach().clone())
        # calculate loss and grad
        batch_output = model_con(batch_input)
        loss = criterion(batch_output, batch_label)
        loss.backward()
        loss_log.append(loss.detach().numpy())
        optimizer.step()
        optimizer.zero_grad()
    
    print(
        f"Epoch: {epoch}, "
        f"loss: {np.mean(loss_log)}, "
        f"constraints: {np.mean(c_log, axis=0)}, "
        f"dual: {np.mean(duals_log, axis=0)}"
    )

Epoch: 0, loss: 0.6634813547134399, constraints: [0.00981526], dual: [0.]
Epoch: 1, loss: 0.5301157832145691, constraints: [0.07230375], dual: [0.00784478]
Epoch: 2, loss: 0.42276254296302795, constraints: [0.16573656], dual: [0.13601923]
Epoch: 3, loss: 0.3991071283817291, constraints: [0.1685044], dual: [0.2906894]
Epoch: 4, loss: 0.3933959901332855, constraints: [0.14941347], dual: [0.4277189]
Epoch: 5, loss: 0.39016368985176086, constraints: [0.13931526], dual: [0.51401454]
Epoch: 6, loss: 0.3857656419277191, constraints: [0.13841809], dual: [0.59241724]
Epoch: 7, loss: 0.3842432200908661, constraints: [0.12440836], dual: [0.6724649]
Epoch: 8, loss: 0.38192689418792725, constraints: [0.12657558], dual: [0.740716]
Epoch: 9, loss: 0.37957674264907837, constraints: [0.1206379], dual: [0.81986153]
Epoch: 10, loss: 0.3767065703868866, constraints: [0.11726112], dual: [0.8704299]
Epoch: 11, loss: 0.37544986605644226, constraints: [0.11596639], dual: [0.9260662]
Epoch: 12, loss: 0.3748477

In [None]:
from fairret.statistic import PositiveRate
from fairret.loss import NormLoss

preds = torch.nn.functional.sigmoid(model_con(features_train))
pr = PositiveRate()
pr(preds, sens_train)

tensor([0.4479, 0.4525], dtype=torch.float64, grad_fn=<IndexPutBackward0>)

In [None]:
fair_criterion(model_con(features_train), sens_train)

tensor(0.0101, dtype=torch.float64, grad_fn=<SumBackward0>)