This notebook demonstrates the `FairnessConstraint` class and shows how to use it to **add a constraint**.

**Fairness constraints** typically involve numerous **protected groups**, so **sampling** data to evaluate them can be tricky. Ideally, each minibatch should **contain samples from each of the protected groups** ...

`FairnessConstraint` provides the functionality both to sample data for user-defined fairness constraints and to evaluate them.

A toy example: we have a dataset of 10 samples split into 2 groups, and we want to add an "equal loss" constraint on them.

```python
import torch
from src.constraints.constraint import FairnessConstraint
from src.constraints.constraint_fns import one_sided_loss_constr

# the toy dataset
features = torch.arange(0, 10)
# encode group membership of each sample
group_membership = torch.repeat_interleave(torch.tensor([0, 1]), 5)
labels = group_membership
dataset = torch.utils.data.TensorDataset(features, labels)

# For each subgroup, FairnessConstraint needs a list of indices of samples belonging to that subgroup
group_indices = [(group_membership == gr_idx).nonzero(as_tuple=True)[0] for gr_idx in group_membership.unique()]

c = FairnessConstraint(
    dataset=dataset,
    group_indices=group_indices,
    fn=one_sided_loss_constr,
    batch_size=2,
    seed=42
)
```
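For reference, the snippet below re-runs only the index construction from the cell above (no `FairnessConstraint` or `src` package needed) and prints `group_indices`, which holds one tensor of dataset indices per group:

```python
import torch

# Same toy data as above: samples 0-4 are in group 0, samples 5-9 in group 1
group_membership = torch.repeat_interleave(torch.tensor([0, 1]), 5)

# One index tensor per group, in the sorted order returned by unique() (0, then 1)
group_indices = [
    (group_membership == gr_idx).nonzero(as_tuple=True)[0]
    for gr_idx in group_membership.unique()
]

print(group_indices)
# [tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])]
```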

The constructor creates a `DataLoader` for each subgroup. Calling `sample_loader()` then returns a `[features, labels]` pair of tensors for each subgroup, with the specified batch size:

```python
c.sample_loader()
```

```
[[tensor([0]), tensor([0])], [tensor([8]), tensor([1])]]
```

When one of the dataloaders is exhausted, it is reset, just as in a normal PyTorch training loop.
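The per-group sampling and reset behavior described above can be approximated in plain PyTorch. This is a minimal sketch, not the library's implementation: `PerGroupSampler` and its `sample()` method are hypothetical names, and the sketch assumes one shuffled `DataLoader` per group, built from `torch.utils.data.Subset`, sharing a single seeded `torch.Generator`. It yields up to `batch_size` samples per group; how `FairnessConstraint` itself interprets `batch_size` is not shown in this notebook.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset


class PerGroupSampler:
    """Hypothetical sketch: one DataLoader per group, restarted when exhausted."""

    def __init__(self, dataset, group_indices, batch_size, seed=None):
        # Generator driving the shuffling; a fixed seed makes it reproducible
        generator = torch.Generator()
        if seed is not None:
            generator.manual_seed(seed)
        else:
            generator.seed()  # non-deterministic seed
        # One shuffled DataLoader over each group's subset of the dataset
        self.loaders = [
            DataLoader(
                Subset(dataset, idx.tolist()),
                batch_size=batch_size,
                shuffle=True,
                generator=generator,
            )
            for idx in group_indices
        ]
        self.iterators = [iter(loader) for loader in self.loaders]

    def sample(self):
        """Return one [features, labels] minibatch per group."""
        batches = []
        for i, loader in enumerate(self.loaders):
            try:
                batch = next(self.iterators[i])
            except StopIteration:
                # This group's loader ran out: start a fresh, reshuffled pass
                self.iterators[i] = iter(loader)
                batch = next(self.iterators[i])
            batches.append(batch)
        return batches


# Toy data as in the notebook: samples 0-4 in group 0, samples 5-9 in group 1
features = torch.arange(0, 10)
group_membership = torch.repeat_interleave(torch.tensor([0, 1]), 5)
dataset = TensorDataset(features, group_membership)
group_indices = [
    (group_membership == g).nonzero(as_tuple=True)[0]
    for g in group_membership.unique()
]

sampler = PerGroupSampler(dataset, group_indices, batch_size=2, seed=42)
# Each group has 5 samples, i.e. 3 batches of size <= 2, so the 4th call
# forces both loaders to restart
for _ in range(4):
    print(sampler.sample())
```

Keeping a separate iterator per group means a small group simply cycles more often than a large one, so every call still returns a batch for every group.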

What about the constraint itself?
The constraint formulation is defined by the `fn` argument. Calling the `eval(net, sample)` method of a `FairnessConstraint` simply passes these arguments, plus any keyword arguments, to the constraint function and returns the result. The constraint function is therefore expected to take the **model** and a **minibatch** in the format shown above as its arguments.
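To illustrate the expected signature, here is a hypothetical constraint function. `loss_gap_constr` is not part of the library, and its formula is an assumption, not necessarily what `one_sided_loss_constr` computes. It takes the model and the per-group minibatch list, computes a mean-squared-error loss for each of two groups, and returns `loss_0 - loss_1 - bound`, so a value <= 0 means the constraint holds. The model is assumed to map a single scalar feature to a single output.

```python
import torch
import torch.nn.functional as F


def loss_gap_constr(net, sample, loss_fn=F.mse_loss, bound=0.0):
    """Hypothetical one-sided loss-gap constraint for two groups.

    sample: list holding one [features, labels] minibatch per group.
    Returns loss(group 0) - loss(group 1) - bound; a value <= 0 means satisfied.
    """
    losses = []
    for features, labels in sample:
        # Assumes scalar integer features and a model mapping 1 input -> 1 output
        preds = net(features.float().unsqueeze(-1)).squeeze(-1)
        losses.append(loss_fn(preds, labels.float()))
    return losses[0] - losses[1] - bound


torch.manual_seed(0)
net = torch.nn.Linear(1, 1)
# A minibatch in the same layout as the sample_loader() output above
sample = [
    [torch.tensor([0, 3]), torch.tensor([0, 0])],
    [torch.tensor([8, 6]), torch.tensor([1, 1])],
]
value = loss_gap_constr(net, sample, bound=0.1)
value.backward()  # the result is differentiable w.r.t. the model parameters
print(value.item())
```

Extra keyword arguments such as `bound` are what `eval` forwards, per the description above, so with `fn=loss_gap_constr` a call like `c.eval(net, sample, bound=0.1)` would reach this function. Because the result is a differentiable tensor, it can be used directly in a penalty or Lagrangian term. A two-sided "equal loss" requirement, |loss_0 - loss_1| <= bound, can be expressed as two such one-sided constraints with the group order swapped.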