In this notebook, we'll show the use of the StackedLinearFractionalStatistic, which is a helpful class to combine multiple linear-fractional statistics. For example, the well-known fairness definition of *equalised odds* enforces equality in both the true positive rate and the false positive rate across sensitive groups. Hence, we can keep track of these statistics in a single vector-valued statistic:

In [25]:
from fairret.statistic import TruePositiveRate, FalsePositiveRate, StackedLinearFractionalStatistic

equalised_odds_stats = StackedLinearFractionalStatistic(TruePositiveRate(), FalsePositiveRate())

Let's quickly try it out...

In [26]:
import torch
torch.manual_seed(0)

feat = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
sens = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])
label = torch.tensor([[0.], [1.], [0.], [1.]])

from fairret.loss import NormLoss

norm_loss = NormLoss(equalised_odds_stats)

h_layer_dim = 16
lr = 1e-3
batch_size = 1024

def build_model():
    _model = torch.nn.Sequential(
        torch.nn.Linear(feat.shape[1], h_layer_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(h_layer_dim, 1)
    )
    _optimizer = torch.optim.Adam(_model.parameters(), lr=lr)
    return _model, _optimizer

from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(feat, sens, label)
dataloader = DataLoader(dataset, batch_size=batch_size)

Without fairret...

In [27]:
import numpy as np

nb_epochs = 100
model, optimizer = build_model()
for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()
                
        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        loss.backward()
                
        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")
    
pred = torch.sigmoid(model(feat))
eo_per_group = equalised_odds_stats(pred, sens, label)
absolute_diff = torch.abs(eo_per_group[:, 0] - eo_per_group[:, 1])

print(f"The TPR and FPR for group 0 are {eo_per_group[:, 0]}")
print(f"The TPR and FPR for group 1 are {eo_per_group[:, 1]}")
print(f"The absolute differences are {torch.abs(eo_per_group[:, 0] - eo_per_group[:, 1])}")

Epoch: 0, loss: 0.7091795206069946
Epoch: 1, loss: 0.7061765193939209
Epoch: 2, loss: 0.7033581733703613
Epoch: 3, loss: 0.7007156610488892
Epoch: 4, loss: 0.6982340812683105
Epoch: 5, loss: 0.6959078907966614
Epoch: 6, loss: 0.6937355995178223
Epoch: 7, loss: 0.6917158365249634
Epoch: 8, loss: 0.6898466944694519
Epoch: 9, loss: 0.6881252527236938
Epoch: 10, loss: 0.6865478754043579
Epoch: 11, loss: 0.6851094961166382
Epoch: 12, loss: 0.6838041543960571
Epoch: 13, loss: 0.6826250553131104
Epoch: 14, loss: 0.6815641522407532
Epoch: 15, loss: 0.6806124448776245
Epoch: 16, loss: 0.6797604560852051
Epoch: 17, loss: 0.6789975762367249
Epoch: 18, loss: 0.6783132553100586
Epoch: 19, loss: 0.6776963472366333
Epoch: 20, loss: 0.6771360039710999
Epoch: 21, loss: 0.6766215562820435
Epoch: 22, loss: 0.6761429309844971
Epoch: 23, loss: 0.6756909489631653
Epoch: 24, loss: 0.6752569675445557
Epoch: 25, loss: 0.6748337745666504
Epoch: 26, loss: 0.674415111541748
Epoch: 27, loss: 0.673996090888977
Epoc

With fairret...

In [28]:
import numpy as np

nb_epochs = 100
fairness_strength = 1
model, optimizer = build_model()
for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()
                
        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        loss += fairness_strength * norm_loss(logit, batch_sens, batch_label)
        loss.backward()
                
        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")
    
pred = torch.sigmoid(model(feat))
eo_per_group = equalised_odds_stats(pred, sens, label)
absolute_diff = torch.abs(eo_per_group[:, 0] - eo_per_group[:, 1])

print(f"The TPR and FPR for group 0 are {eo_per_group[:, 0]}")
print(f"The TPR and FPR for group 1 are {eo_per_group[:, 1]}")
print(f"The absolute differences are {torch.abs(eo_per_group[:, 0] - eo_per_group[:, 1])}")

Epoch: 0, loss: 0.8069422245025635
Epoch: 1, loss: 0.7932361960411072
Epoch: 2, loss: 0.7793688178062439
Epoch: 3, loss: 0.7653393149375916
Epoch: 4, loss: 0.7511466145515442
Epoch: 5, loss: 0.7367900013923645
Epoch: 6, loss: 0.7222684025764465
Epoch: 7, loss: 0.7075802683830261
Epoch: 8, loss: 0.693701446056366
Epoch: 9, loss: 0.7050990462303162
Epoch: 10, loss: 0.7107114195823669
Epoch: 11, loss: 0.7118827104568481
Epoch: 12, loss: 0.7095987200737
Epoch: 13, loss: 0.704603374004364
Epoch: 14, loss: 0.6974690556526184
Epoch: 15, loss: 0.6966591477394104
Epoch: 16, loss: 0.70110023021698
Epoch: 17, loss: 0.7034574151039124
Epoch: 18, loss: 0.7040092945098877
Epoch: 19, loss: 0.7029833793640137
Epoch: 20, loss: 0.7005670070648193
Epoch: 21, loss: 0.6969154477119446
Epoch: 22, loss: 0.6944364905357361
Epoch: 23, loss: 0.6974213719367981
Epoch: 24, loss: 0.6976830959320068
Epoch: 25, loss: 0.695574164390564
Epoch: 26, loss: 0.6945269703865051
Epoch: 27, loss: 0.6960493922233582
Epoch: 28,