In this notebook, we walk through a full example of how the `fairret` library can be used to train a PyTorch model with a fairness cost.

# Loading some data
To start, let's load some data where fair binary classification is desirable. We'll use the `folktables` [library](https://github.com/socialfoundations/folktables) and its example data from the 2018 [American Community Survey](https://www.census.gov/programs-surveys/acs) (ACS).

In [13]:
from folktables import ACSDataSource

data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')
# Despite the `ca_` prefix on the variable names, this loads the data for Alabama ("AL").
ca_data = data_source.get_data(states=["AL"], download=True)

In [14]:
from folktables import ACSIncome, generate_categories

definition_df = data_source.get_definitions(download=True)
categories = generate_categories(features=ACSIncome.features, definition_df=definition_df)

ca_features, ca_labels, _ = ACSIncome.df_to_pandas(ca_data, categories=categories, dummies=True)
ca_features.head()

Unnamed: 0,AGEP,WKHP,"COW_Employee of a private for-profit company or business, or of an individual, for wages, salary, or commissions","COW_Employee of a private not-for-profit, tax-exempt, or charitable organization",COW_Federal government employee,"COW_Local government employee (city, county, etc.)","COW_Self-employed in own incorporated business, professional practice or farm","COW_Self-employed in own not incorporated business, professional practice, or farm",COW_State government employee,COW_Working without pay in family business or farm,...,SEX_Male,RAC1P_Alaska Native alone,RAC1P_American Indian alone,"RAC1P_American Indian and Alaska Native tribes specified; or American Indian or Alaska Native, not specified and no other races",RAC1P_Asian alone,RAC1P_Black or African American alone,RAC1P_Native Hawaiian and Other Pacific Islander alone,RAC1P_Some Other Race alone,RAC1P_Two or More Races,RAC1P_White alone
0,18,21.0,True,False,False,False,False,False,False,False,...,False,False,False,False,False,True,False,False,False,False
1,53,40.0,False,False,True,False,False,False,False,False,...,True,False,False,False,False,False,False,False,False,True
2,41,40.0,True,False,False,False,False,False,False,False,...,True,False,False,False,False,False,False,False,False,True
3,18,2.0,False,False,False,False,False,True,False,False,...,False,False,False,False,False,False,False,False,False,True
4,21,50.0,False,False,True,False,False,False,False,False,...,True,False,False,False,False,False,False,False,False,True


To keep things simple for now, let's only consider two sensitive groups: *male* and *female*.

In [15]:
sens_cols = ['SEX_Female', 'SEX_Male']
feat = ca_features.drop(columns=sens_cols).to_numpy(dtype="float")
sens = ca_features[sens_cols].to_numpy(dtype="float")
label = ca_labels.to_numpy(dtype="float")

print(sens.mean(axis=0))

[0.47808514 0.52191486]


# A naive PyTorch pipeline

The `fairret` library treats sensitive features in the same way 'normal' features are treated in PyTorch: as (N x D) tensors, where N is the number of samples and D the dimensionality. In contrast to other fairness libraries you may have used, we can therefore just leave categorical sensitive features as one-hot encoded!

In [16]:
import torch
feat, sens, label = torch.tensor(feat).float(), torch.tensor(sens).float(), torch.tensor(label).float().squeeze()
print(f"Shape of the 'normal' features tensor: {feat.shape}")
print(f"Shape of the sensitive features tensor: {sens.shape}")
print(f"Shape of the labels tensor: {label.shape}")

Shape of the 'normal' features tensor: torch.Size([22268, 727])
Shape of the sensitive features tensor: torch.Size([22268, 2])
Shape of the labels tensor: torch.Size([22268])
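If a sensitive attribute is stored as integer category codes rather than as dummy columns, it can be brought into the same (N x D) layout with a one-hot encoding. The snippet below is a minimal sketch on made-up codes (the `group_codes` values are hypothetical and not taken from the ACS data):

```python
import torch
import torch.nn.functional as F

# Hypothetical integer group codes for six samples (values 0, 1 or 2).
group_codes = torch.tensor([0, 2, 1, 0, 2, 2])

# One-hot encode into an (N x D) float tensor, here with N = 6 and D = 3.
sens_onehot = F.one_hot(group_codes, num_classes=3).float()

print(sens_onehot.shape)
# The column means are the share of samples in each group,
# just like `sens.mean(axis=0)` earlier in this notebook.
print(sens_onehot.mean(dim=0))
```

Each row sums to one, so every sample belongs to exactly one group.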


In typical PyTorch fashion, let's now define a simple neural net with one hidden layer, an optimizer, and a DataLoader.

In [17]:
h_layer_dim = 64
lr = 1e-3
batch_size = 1024

model = torch.nn.Sequential(
    torch.nn.Linear(feat.shape[1], h_layer_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_layer_dim, 1)
)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(feat, sens, label)
dataloader = DataLoader(dataset, batch_size=batch_size)

Now, let's train it without doing any fairness adjustment...

In [18]:
import numpy as np

nb_epochs = 10

for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_target in dataloader:
        optimizer.zero_grad()

        # Forward pass: one logit per sample. batch_sens is not used yet.
        logit = model(batch_feat).squeeze()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_target)
        loss.backward()

        optimizer.step()
        losses.append(loss.item())
    # Report the mean batch loss for this epoch.
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")

Epoch: 0, loss: 0.6383606791496277
Epoch: 1, loss: 0.6034089841625907
Epoch: 2, loss: 0.5692692426117983
Epoch: 3, loss: 0.5326909720897675
Epoch: 4, loss: 0.4990630502050573
Epoch: 5, loss: 0.4717546647245234
Epoch: 6, loss: 0.45163395459001715
Epoch: 7, loss: 0.4370973435315219
Epoch: 8, loss: 0.426434107802131
Epoch: 9, loss: 0.4184045513922518


# Bias analysis in fairret

Can we detect any statistical disparities (biases) in the naive model?

The `fairret` library assesses these biases by comparing a (linear-fractional) statistic computed for each sensitive feature. In our example, these are the 'SEX_Female' and 'SEX_Male' features. As an example, let's look at the accuracy.

In [19]:
from fairret.statistic import Accuracy

statistic = Accuracy()

pred = torch.sigmoid(model(feat)).squeeze()
acc_per_group = statistic(pred, feat, sens, label)
absolute_diff = torch.abs(acc_per_group[0] - acc_per_group[1])

print(f"The {statistic.__class__.__name__} for group {sens_cols[0]} is {acc_per_group[0]}")
print(f"The {statistic.__class__.__name__} for group {sens_cols[1]} is {acc_per_group[1]}")
print(f"The absolute difference is {absolute_diff}")

The Accuracy for group SEX_Female is 0.7232051491737366
The Accuracy for group SEX_Male is 0.6954739093780518
The absolute difference is 0.027731239795684814
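To make the idea of a per-group statistic concrete, here is a minimal plain-PyTorch sketch of a group-wise accuracy on probabilistic predictions. The helper `soft_accuracy_per_group` and the toy tensors are illustrative assumptions; this is not the library's implementation and its values need not match `Accuracy()` exactly.

```python
import torch

def soft_accuracy_per_group(pred, sens, label):
    """Illustrative helper (not part of fairret): accuracy per sensitive feature."""
    # Probability mass that each prediction assigns to the true label, shape (N,).
    correct = pred * label + (1 - pred) * (1 - label)
    # Sum that mass within each group, then divide by the group size, shape (D,).
    return (sens * correct.unsqueeze(1)).sum(dim=0) / sens.sum(dim=0)

# Toy data: four samples, two one-hot encoded groups.
toy_pred = torch.tensor([0.9, 0.2, 0.6, 0.4])
toy_label = torch.tensor([1.0, 0.0, 0.0, 1.0])
toy_sens = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])

toy_acc = soft_accuracy_per_group(toy_pred, toy_sens, toy_label)
print(toy_acc)  # roughly [0.85, 0.40]
```

The result is a ratio of two sums over the samples, which is the sense in which such a statistic can be called linear-fractional.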


# Bias mitigation in fairret

To reduce the statistical disparity we found, we can use one of the fairrets implemented in the library. So that the loss quantifies bias with the same statistic we just measured (here, the accuracy), we pass the statistic object to the fairret loss.

In [20]:
from fairret.loss import NormLoss

norm_loss = NormLoss(statistic)
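For intuition only, a norm-based fairness penalty could look roughly like the sketch below: compute the statistic per group, compare it with the value over all samples, and take a norm of the gaps. The function `norm_gap_penalty`, its `p` parameter, and the toy tensors are hypothetical; the exact definition used by `NormLoss` may differ, so treat this as an assumption-laden illustration.

```python
import torch

def norm_gap_penalty(logit, sens, label, p=1):
    """Hypothetical penalty: p-norm of the gaps between group and overall accuracy."""
    pred = torch.sigmoid(logit)
    # Probability mass assigned to the true label of each sample.
    correct = pred * label + (1 - pred) * (1 - label)
    group_acc = (sens * correct.unsqueeze(1)).sum(dim=0) / sens.sum(dim=0)
    overall_acc = correct.mean()
    return torch.linalg.vector_norm(group_acc - overall_acc, ord=p)

# Toy data: four samples, two one-hot encoded groups.
toy_logit = torch.tensor([2.0, -1.0, 0.5, -0.5], requires_grad=True)
toy_label = torch.tensor([1.0, 0.0, 0.0, 1.0])
toy_sens = torch.tensor([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])

penalty = norm_gap_penalty(toy_logit, toy_sens, toy_label)
penalty.backward()  # the penalty is differentiable with respect to the logits
print(penalty.item(), toy_logit.grad)
```

Because the penalty is built from differentiable operations on the logits, it can simply be added to the usual classification loss before calling `backward()`.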

Let's train another model where we now add this loss term to the objective. 

**We only need to add one line of code to the standard PyTorch training loop!**

In [21]:
nb_epochs = 10
model = torch.nn.Sequential(
    torch.nn.Linear(feat.shape[1], h_layer_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_layer_dim, 1)
)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_target in dataloader:
        optimizer.zero_grad()

        logit = model(batch_feat).squeeze()
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_target)
        # The one added line: the fairret term, which also receives the sensitive features.
        loss += norm_loss(logit, batch_feat, batch_sens, batch_target)
        loss.backward()

        optimizer.step()
        losses.append(loss.item())
    # Report the mean batch loss (classification loss plus fairret term) for this epoch.
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}")

Epoch: 0, loss: 0.6934518353505568
Epoch: 1, loss: 0.6566014912995425
Epoch: 2, loss: 0.6313163394277747
Epoch: 3, loss: 0.606122615662488
Epoch: 4, loss: 0.5772159641439264
Epoch: 5, loss: 0.5526343502781608
Epoch: 6, loss: 0.534382092681798
Epoch: 7, loss: 0.517937421798706
Epoch: 8, loss: 0.5040952034971931
Epoch: 9, loss: 0.5034685148434206


Let's check the accuracy per group again...

In [22]:
pred = torch.sigmoid(model(feat)).squeeze()
acc_per_group = statistic(pred, feat, sens, label)

print(f"The {statistic.__class__.__name__} for group {sens_cols[0]} is {acc_per_group[0]}")
print(f"The {statistic.__class__.__name__} for group {sens_cols[1]} is {acc_per_group[1]}")
print(f"The absolute difference is {torch.abs(acc_per_group[0] - acc_per_group[1])}")

The Accuracy for group SEX_Female is 0.6595351099967957
The Accuracy for group SEX_Male is 0.6579050421714783
The absolute difference is 0.0016300678253173828


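With a simple change, the absolute difference between the group accuracies went from 2.77 percentage points (in the naive model) to 0.16 percentage points! Note that this came at a cost: the accuracy for both groups dropped from roughly 0.70-0.72 to about 0.66.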
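The arithmetic behind this comparison can be checked directly from the numbers printed in this notebook. The sketch below copies those values by hand; the share-weighted mean assumes that the two groups partition the samples and that each group's accuracy is an average over its members.

```python
# Per-group accuracies and group shares copied from the printed outputs above.
naive = {"SEX_Female": 0.7232051491737366, "SEX_Male": 0.6954739093780518}
fair = {"SEX_Female": 0.6595351099967957, "SEX_Male": 0.6579050421714783}
share = {"SEX_Female": 0.47808514, "SEX_Male": 0.52191486}

def gap_in_points(acc):
    """Absolute accuracy gap between the two groups, in percentage points."""
    return 100 * abs(acc["SEX_Female"] - acc["SEX_Male"])

def weighted_mean(acc):
    """Accuracy over all samples, weighting each group by its share."""
    return sum(share[group] * acc[group] for group in acc)

gap_naive, gap_fair = gap_in_points(naive), gap_in_points(fair)
print(f"gap: {gap_naive:.2f} -> {gap_fair:.2f} percentage points")  # 2.77 -> 0.16
print(f"share-weighted accuracy: {weighted_mean(naive):.3f} -> {weighted_mean(fair):.3f}")  # 0.709 -> 0.659
```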

We think this is only a small preview of how powerful this paradigm can be. 

Feel free to go back and try other statistics to compare across groups, or other fairret losses to minimize.

# Further examples

We plan to add more examples, including with multiple sensitive attributes. Check back later!