# Bias and Constrained Learning Homework

In this homework we'll extend the constrained learning framework we used for mitigating bias in class to handle more complex situations. Specifically, we'll look at the case where the output prediction is not binary. As usual with these homeworks, there are three different levels which build on each other, each one corresponding to an increasing grade:

- The basic version of this homework involves implementing code to measure fairness for multiclass classification, then measuring the results when training a regular, unfair classifier. This version is good for a C.
- The B version of the homework involves training a classifier with some fairness constraints.
- For an A, we'll look at a slightly more complicated approach to fair training.

First, we'll generate a dataset for which the sensitive attribute is binary and the output is multiclass.

```python
import numpy as np
import torch
from torch import nn, optim
```

```python
output_classes = 5

def generate_data():

    dataset_size = 10000
    dimensions = 40

    rng = np.random.default_rng()
    # Sensitive attribute: half 0s and half 1s, in random order
    A = np.concatenate((np.zeros(dataset_size // 2), np.ones(dataset_size // 2)))
    rng.shuffle(A)
    # Every feature has mean 0 for the A=0 group and mean 1 for the A=1 group
    X = rng.normal(loc=A[:,np.newaxis], scale=1, size=(dataset_size, dimensions))
    random_linear = np.array([
        -2.28156561, 0.24582547, -2.48926942, -0.02934924, 5.21382855, -1.08613209,
        2.51051602, 1.00773587, -2.10409448, 1.94385103, 0.76013416, -2.94430782,
        0.3289264, -4.35145624, 1.61342623, -1.28433588, -2.07859612, -1.53812125,
        0.51412713, -1.34310334, 4.67174476, 1.67269946, -2.07805413, 3.46667731,
        2.61486654, 1.75418209, -0.06773796, 0.7213423, 2.43896438, 1.79306807,
        -0.74610264, 2.84046827,  1.28779878, 1.84490263, 1.6949681, 0.05814582,
        1.30510732, -0.92332861,  3.00192177, -1.76077192
    ])
    good_score = (X @ random_linear) ** 2 / 2
    # Class boundaries at the 20/40/60/80% quantiles of the score
    qs = np.quantile(good_score, (np.array(range(1, output_classes))) / output_classes)
    Y = np.digitize(good_score, qs)

    return X, A, Y

X, A, Y = generate_data()
```

```python
# A: 10,000 observations with 0 or 1 for the sensitive attribute
# X: 10,000 observations with 40 features
# Y: 10,000 observations with outcomes
```

```python
print("Total:", [(Y == k).sum() for k in range(output_classes)])
print("A=0:", [((Y == k) & (A == 0)).sum() for k in range(output_classes)])
print("A=1:", [((Y == k) & (A == 1)).sum() for k in range(output_classes)])
```

```
Total: [2000, 2000, 2000, 2000, 2000]
A=0: [1412, 1295, 1119, 805, 369]
A=1: [588, 705, 881, 1195, 1631]
```


This last cell shows the total number of data points in each output category (it should be 2000 each) as well as a breakdown of each output category for the $A=0$ group and the $A=1$ group. Note that the $A=1$ group is much more likely to be assigned to the categories with higher index.
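Every class has exactly 2000 points because `generate_data` places the class boundaries at the 20%, 40%, 60% and 80% quantiles of `good_score`, so `np.digitize` splits the scores into five equal-sized bins. A minimal sketch of the same binning on a tiny, hand-checkable score array (`toy_score` is made up purely for illustration):

```python
import numpy as np

output_classes = 5
toy_score = np.arange(10, dtype=float)  # hypothetical scores 0, 1, ..., 9

# Cut points at the 1/5, 2/5, 3/5, 4/5 quantiles, as in generate_data
qs = np.quantile(toy_score, np.arange(1, output_classes) / output_classes)

# digitize assigns class k when qs[k-1] <= score < qs[k]
Y_toy = np.digitize(toy_score, qs)
counts = np.bincount(Y_toy, minlength=output_classes)

print(Y_toy)   # [0 0 1 1 2 2 3 3 4 4]
print(counts)  # [2 2 2 2 2]
```

With continuous scores such as `good_score`, ties at a cut point essentially never happen, so each of the five classes receives one fifth of the 10,000 points.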

## Fairness Definition (C)

Let's write some code to measure the _demographic parity_ of our classifier: $P(R = r \mid A = 0) = P(R = r \mid A = 1)$ for all possible output classes $0 \le r < K$, where $K$ is the number of output classes (`output_classes` in the code). In the function below,

- `R` is a matrix where each row represents a probability distribution over the classes `0` to `K - 1`. That is, `R` is the output of our neural network _after_ a softmax layer.
- `A` is a vector of sensitive attributes. Each element is either `0` or `1`.

This function should return an array of length `K` where each element represents a measure of bias for _one_ of the output classes. For demographic parity, the value at index `i` should be $P(R = i \mid A = 1) - P(R = i \mid A = 0)$. Because `R` holds probabilities rather than hard predictions, we estimate $P(R = i \mid A = a)$ as the average of column `i` of `R` over the rows where `A == a`.

```python
# R: output of the network after softmax, shape (n, K)
# A: sensitive attributes, shape (n,), each 0 or 1

def demographic_parity(R, A):
    # Split the rows of R by group
    R0 = R[A == 0]
    R1 = R[A == 1]

    # Averaging the softmax rows estimates P(R = i | A = a) for every class i
    R0_av = torch.mean(R0, dim=0)
    R1_av = torch.mean(R1, dim=0)

    # Entry i is P(R = i | A = 1) - P(R = i | A = 0)
    output = R1_av - R0_av

    return output
```
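One way to sanity-check this measure is a toy input whose answer can be worked out by hand. The snippet repeats the same computation so it runs on its own; `R_toy` and `A_toy` are made-up tensors. Group 0 puts all of its probability on class 0, while group 1 averages to 0.25 on class 0 and 0.75 on class 1, so the expected gap is $[-0.75, 0.75]$.

```python
import torch

def demographic_parity(R, A):
    # Average softmax rows per group: estimates P(R = i | A = a)
    R0_av = torch.mean(R[A == 0], dim=0)
    R1_av = torch.mean(R[A == 1], dim=0)
    return R1_av - R0_av

R_toy = torch.tensor([
    [1.0, 0.0],   # A = 0
    [1.0, 0.0],   # A = 0
    [0.0, 1.0],   # A = 1
    [0.5, 0.5],   # A = 1
])
A_toy = torch.tensor([0, 0, 1, 1])

gap = demographic_parity(R_toy, A_toy)
print(gap)  # gap is [-0.75, 0.75]
```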

Now we'll train a classifier on this dataset without any fairness constraints for comparison. This code is already complete.

```python
class MLP(nn.Module):

    def __init__(self):
        super().__init__()
        # 40 input features -> 256 hidden units -> 5 class logits
        self.model = nn.Sequential(
            nn.Linear(40, 256),
            nn.ReLU(),
            nn.Linear(256, 5)
        )

    def forward(self, x):
        return self.model(x)
```

```python
def train_unfair(lr=1e-1, epochs=200):

    network = MLP()
    loss = nn.CrossEntropyLoss()
    opt = optim.SGD(network.parameters(), lr=lr)
    data_in = torch.tensor(X).float()
    data_out = torch.tensor(Y)

    for i in range(epochs):
        preds = network(data_in)
        loss_val = loss(preds, data_out)
        opt.zero_grad()
        loss_val.backward()
        opt.step()

        if (i+1) % 20 == 0:
            acc = (preds.argmax(dim=1) == data_out).float().mean()
            probs = nn.functional.softmax(preds, dim=1)
            print("Epoch:", i, "Accuracy:", acc.item(), "Bias:", demographic_parity(probs, A))

    return network
```

```python
model_unfair = train_unfair(lr=5e-1, epochs=300)
```

```
Epoch: 19 Accuracy: 0.3312000036239624 Bias: tensor([-0.1802, -0.1489, -0.0848,  0.0227,  0.3911], grad_fn=<SubBackward0>)
Epoch: 39 Accuracy: 0.35249999165534973 Bias: tensor([-0.1962, -0.1617, -0.0908,  0.0460,  0.4027], grad_fn=<SubBackward0>)
Epoch: 59 Accuracy: 0.4052000045776367 Bias: tensor([-0.2005, -0.1661, -0.0885,  0.0488,  0.4063], grad_fn=<SubBackward0>)
Epoch: 79 Accuracy: 0.5095999836921692 Bias: tensor([-0.1952, -0.1621, -0.0844,  0.0367,  0.4050], grad_fn=<SubBackward0>)
Epoch: 99 Accuracy: 0.5853000283241272 Bias: tensor([-0.1943, -0.1599, -0.0809,  0.0367,  0.3984], grad_fn=<SubBackward0>)
Epoch: 119 Accuracy: 0.6283000111579895 Bias: tensor([-0.1953, -0.1574, -0.0780,  0.0405,  0.3903], grad_fn=<SubBackward0>)
Epoch: 139 Accuracy: 0.6560999751091003 Bias: tensor([-0.1960, -0.1546, -0.0768,  0.0435,  0.3839], grad_fn=<SubBackward0>)
Epoch: 159 Accuracy: 0.6787999868392944 Bias: tensor([-0.1967, -0.1517, -0.0767,  0.0464,  0.3788], grad_fn=<SubBackward0>)
```

```python
p = model_unfair(torch.tensor(X).float()).argmax(dim=1)
print("Total:", [(p == k).sum().item() for k in range(output_classes)])
print("A=0:", [((p == k) & (A == 0)).sum().item() for k in range(output_classes)])
print("A=1:", [((p == k) & (A == 1)).sum().item() for k in range(output_classes)])
```

```
Total: [2659, 2056, 1557, 2626, 1102]
A=0: [1472, 1314, 923, 1078, 213]
A=1: [1187, 742, 634, 1548, 889]
```


This classifier is probably not going to be _extremely_ accurate, but you should be able to see the bias from the dataset reflected here.
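The `Bias` values printed during training are computed from softmax probabilities, while the counts above come from hard `argmax` predictions, and parity in one does not imply parity in the other. A minimal sketch of the hard-prediction version: `hard_parity_gap` is a hypothetical helper (not part of the assignment) that uses `torch.bincount` to compute, for each class $k$, the fraction of the $A=1$ group predicted as $k$ minus the fraction of the $A=0$ group predicted as $k$.

```python
import torch

def hard_parity_gap(pred, A, num_classes):
    """Hypothetical helper: per-class gap in the fraction of each group
    assigned to that class by hard (argmax) predictions."""
    frac0 = torch.bincount(pred[A == 0], minlength=num_classes) / (A == 0).sum()
    frac1 = torch.bincount(pred[A == 1], minlength=num_classes) / (A == 1).sum()
    return frac1 - frac0

# Toy check: group 0 is always predicted as class 0, group 1 is split evenly
pred = torch.tensor([0, 0, 0, 1])
A_toy = torch.tensor([0, 0, 1, 1])
gap = hard_parity_gap(pred, A_toy, num_classes=2)
print(gap)  # gap is [-0.5, 0.5]
```

In the notebook this could be called as `hard_parity_gap(p, torch.as_tensor(A), output_classes)`. For class 4 in the counts above, that works out to $(889 - 213) / 5000 \approx 0.135$.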

## Fair Training (B)

Now we'll extend our fair training approach from the lab to the multiclass setting. Since we now have a bias measure for _each_ possible output class, we effectively have `output_classes` constraints to satisfy. We can handle this within our Lagrange multiplier framework by adding one multiplier per constraint. That is, our new learning problem is

$$
\arg\min_\beta \max_\lambda \left ( L(\beta) + \sum_i \lambda_i g_i(\beta) \right )
$$

$$
= \arg\min_\beta \max_\lambda \left ( L(\beta) + \sum_i \lambda_i \left ( P_\beta [ R = i \mid A = 1 ] - P_\beta [ R = i \mid A = 0 ] \right ) \right )
$$

Our `demographic_parity` function gives us the vector of constraint values $g_i(\beta)$, so all we need to do is replace the single parameter $\lambda$ from the lab with a vector $\lambda \in \mathbb{R}^K$ and then add the dot product of $\lambda$ with our demographic parity measure to the loss.
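A minimal sketch of just the penalty term, with made-up numbers: `g` stands in for a `demographic_parity` output and `lam` for the vector of multipliers; `torch.dot(lam, g)` is the same quantity as `(lam * bias).sum()` in `train_fair`. One property worth knowing: both $P(R = \cdot \mid A = 1)$ and $P(R = \cdot \mid A = 0)$ sum to one over the classes, so the entries of $g$ always sum to zero. Adding the same constant to every $\lambda_i$ therefore leaves the penalty unchanged, so only the differences between the multipliers matter.

```python
import torch

# Hypothetical per-class gaps g_i; like any demographic_parity output, they sum to zero
g = torch.tensor([-0.2, -0.1, 0.0, 0.1, 0.2])
lam = torch.tensor([0.5, 0.1, -0.3, 0.2, 0.4])

penalty = torch.dot(lam, g)            # sum_i lam_i * g_i
shifted = torch.dot(lam + 10.0, g)     # every multiplier shifted by the same constant
print(penalty.item(), shifted.item())  # equal up to floating-point rounding
```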

```python
def train_fair(lr=1e-1, lam_lr=1, epochs=200):

    network = MLP()
    # One Lagrange multiplier per output class
    lam = nn.Parameter(torch.zeros(output_classes))
    loss = nn.CrossEntropyLoss()
    opt = optim.SGD(network.parameters(), lr=lr)
    # maximize=True makes this optimizer take gradient *ascent* steps on lam
    lam_opt = optim.SGD([lam], lr=lam_lr, maximize=True)
    data_in = torch.tensor(X).float()
    data_out = torch.tensor(Y)

    for i in range(epochs):
        # Compute the Lagrangian loss L + sum_i lam_i * g_i
        preds = network(data_in)
        loss_val = loss(preds, data_out)
        # Softmax over dim=1, the class dimension, so each row is a distribution over classes
        probs = nn.functional.softmax(preds, dim=1)
        bias = demographic_parity(probs, A)
        loss_val = loss_val + (lam * bias).sum()

        opt.zero_grad()
        lam_opt.zero_grad()
        loss_val.backward()
        opt.step()      # descent step on the network weights
        lam_opt.step()  # ascent step on the multipliers

        if (i+1) % 20 == 0:
            acc = (preds.argmax(dim=1) == data_out).float().mean()
            print("Epoch:", i, "Accuracy:", acc.item(), "Bias:", demographic_parity(probs, A), "Lambda:", lam.max().item())

    return network
```

```python
model_fair = train_fair(lr=5e-1, lam_lr=3e-1, epochs=300)
```

```
Epoch: 19 Accuracy: 0.2897000014781952 Bias: tensor([-0.0191, -0.0090, -0.0130,  0.0196,  0.0215], grad_fn=<SubBackward0>) Lambda: 0.4671611785888672
Epoch: 39 Accuracy: 0.3239000141620636 Bias: tensor([-0.0452, -0.0360, -0.0182,  0.0100,  0.0894], grad_fn=<SubBackward0>) Lambda: 0.48258352279663086
Epoch: 59 Accuracy: 0.383899986743927 Bias: tensor([-0.0471, -0.0348, -0.0120,  0.0163,  0.0775], grad_fn=<SubBackward0>) Lambda: 0.5010895133018494
Epoch: 79 Accuracy: 0.47350001335144043 Bias: tensor([-0.0577, -0.0418, -0.0107,  0.0151,  0.0952], grad_fn=<SubBackward0>) Lambda: 0.5216808319091797
Epoch: 99 Accuracy: 0.5566999912261963 Bias: tensor([-0.0679, -0.0447, -0.0041,  0.0120,  0.1048], grad_fn=<SubBackward0>) Lambda: 0.6017270088195801
Epoch: 119 Accuracy: 0.6018000245094299 Bias: tensor([-0.0702, -0.0403,  0.0014,  0.0085,  0.1006], grad_fn=<SubBackward0>) Lambda: 0.6871916651725769
Epoch: 139 Accuracy: 0.6258000135421753 Bias: tensor([-0.0698, -0.0346,  0.0043,  0.0065,  0.0937]
```

```python
p = model_fair(torch.tensor(X).float()).argmax(dim=1)
print("Total:", [(p == k).sum().item() for k in range(output_classes)])
print("A=0:", [((p == k) & (A == 0)).sum().item() for k in range(output_classes)])
print("A=1:", [((p == k) & (A == 1)).sum().item() for k in range(output_classes)])
```

```
Total: [2084, 1611, 2614, 523, 3168]
A=0: [1408, 1199, 1484, 278, 631]
A=1: [676, 412, 1130, 245, 2537]
```


## Fair Training via KL-Divergence (A)

Let's look back at our definition of demographic parity for the multiclass setting: $P(R = r \mid A = 0) = P(R = r \mid A = 1)$ for all possible output classes $r$. We could also express this by asserting that $P(\cdot \mid A = 0)$ and $P(\cdot \mid A = 1)$ should be identical probability distributions. A natural measure of bias, then, is the KL-divergence between these two distributions, since KL-divergence measures how "different" two distributions are. That is, we'll now solve the problem

$$
\arg\min_\beta \max_\lambda \left ( L(\beta) + \lambda D_{\textrm{KL}} \left( P(\cdot \mid A = 0) \ \| \ P(\cdot \mid A = 1) \right) \right )
$$
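As a concrete illustration, here is a sketch that computes $D_{\textrm{KL}}(P(\cdot \mid A = 0) \,\|\, P(\cdot \mid A = 1))$ for two discrete distributions over the classes. It plugs in the per-group label counts from the dataset breakdown printed earlier (these describe `Y`, not model predictions), and cross-checks the hand-written formula against `torch.nn.functional.kl_div`, which expects the *second* distribution as log-probabilities in `input` and the first as `target`. `kl_divergence` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def kl_divergence(p, q):
    # D_KL(p || q) = sum_r p[r] * (log p[r] - log q[r]); assumes all entries > 0
    return torch.sum(p * (torch.log(p) - torch.log(q)))

# Per-group label counts from the dataset breakdown (Y, not predictions)
counts0 = torch.tensor([1412., 1295., 1119., 805., 369.])
counts1 = torch.tensor([588., 705., 881., 1195., 1631.])
p0 = counts0 / counts0.sum()   # P(Y = . | A = 0)
p1 = counts1 / counts1.sum()   # P(Y = . | A = 1)

kl = kl_divergence(p0, p1)
# F.kl_div(input, target) computes sum target * (log target - input), with input = log p1
kl_ref = F.kl_div(torch.log(p1), p0, reduction="sum")
print(kl.item(), kl_ref.item())
print(kl_divergence(p0, p0).item())  # 0.0 for identical distributions
```

Note that KL-divergence is not symmetric: swapping `p0` and `p1` generally gives a different value, so the order in the objective matters.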

However, this introduces a new complication. The KL-divergence is never negative, and it is zero only when the two distributions are identical (we proved this in our first homework of the semester). The gradient of the objective with respect to $\lambda$ is exactly this divergence, so gradient ascent can never decrease $\lambda$, and unless the two distributions match exactly, $\lambda$ will keep growing without bound. We can solve this by allowing a small deviation $\epsilon > 0$ in our constrained optimization problem:

$$
\begin{aligned}
\arg\min_\beta \ & L(\beta) \\
\text{s.t.} \ & D_{\textrm{KL}} \left( P(\cdot \mid A = 0) \ \| \ P(\cdot \mid A = 1) \right) \le \epsilon
\end{aligned}
$$

We can still represent this using a Lagrange multiplier:

$$
\arg\min_\beta \max_{\lambda \ge 0} \left ( L(\beta) + \lambda \left ( D_{\textrm{KL}} \left( P(\cdot \mid A = 0) \ \| \ P(\cdot \mid A = 1) \right) - \epsilon \right ) \right )
$$

Your task now is to represent this optimization problem in the code below. I've taken care of keeping $\lambda \ge 0$ for you (after each update, $\lambda$ is clamped at zero) since it's not something we've looked at in class.
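To see why the $\epsilon$ slack and the clamp matter, here is a minimal sketch of the $\lambda$ update on its own, in plain Python with a made-up sequence of divergence values. Each ascent step moves $\lambda$ by `lam_lr` times its gradient $D_{\textrm{KL}} - \epsilon$, and clamping afterwards keeps $\lambda \ge 0$. With $\epsilon = 0$ the step is never negative, so $\lambda$ only grows; with $\epsilon > 0$ it shrinks again once the divergence drops below $\epsilon$. `lambda_trajectory` is a hypothetical helper, not part of the assignment.

```python
def lambda_trajectory(divergences, epsilon, lam_lr=1.0):
    """Projected gradient ascent on lam for a fixed sequence of divergence values."""
    lam = 0.0
    history = []
    for d in divergences:
        lam = lam + lam_lr * (d - epsilon)  # ascent step: the gradient is D_KL - epsilon
        lam = max(lam, 0.0)                 # projection onto lam >= 0 (the clamp)
        history.append(lam)
    return history

# Hypothetical divergences: large at first, then below epsilon = 0.02
divs = [0.10, 0.06, 0.03, 0.01, 0.005, 0.005]

print(lambda_trajectory(divs, epsilon=0.0))   # never decreases
print(lambda_trajectory(divs, epsilon=0.02))  # rises, then falls back toward 0
```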

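```python
def train_kl(lr=1e-1, lam_lr=1, epochs=300, epsilon=0.1):

    network = MLP()
    lam = nn.Parameter(torch.tensor(0.0))
    loss = nn.CrossEntropyLoss()
    opt = optim.SGD(network.parameters(), lr=lr)
    lam_opt = optim.SGD([lam], lr=lam_lr, maximize=True)
    data_in = torch.tensor(X).float()
    data_out = torch.tensor(Y)

    for i in range(epochs):

        # Lagrangian: L(beta) + lam * (KL(P(. | A=0) || P(. | A=1)) - epsilon)
        preds = network(data_in)
        probs = nn.functional.softmax(preds, dim=1)
        # Class distribution of each group, estimated by averaging softmax rows
        p0 = probs[A == 0].mean(dim=0)
        p1 = probs[A == 1].mean(dim=0)
        # KL(p0 || p1) = sum_r p0[r] * (log p0[r] - log p1[r]); softmax keeps entries > 0
        kl_div = torch.sum(p0 * (torch.log(p0) - torch.log(p1)))
        loss_val = loss(preds, data_out) + lam * (kl_div - epsilon)

        opt.zero_grad()
        lam_opt.zero_grad()
        loss_val.backward()
        opt.step()
        lam_opt.step()
        # Project lam back onto lam >= 0
        with torch.no_grad():
            lam.clamp_(min=0)

        if (i+1) % 20 == 0:
            acc = (preds.argmax(dim=1) == data_out).float().mean()
            print("Epoch:", i, "Accuracy:", acc.item(), "Divergence:", kl_div.item(), "Lambda:", lam.item())

    return network
```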

```python
model = train_kl(lr=3e-1, lam_lr=1, epsilon=0.02)
```

```python
p = model(torch.tensor(X).float()).argmax(dim=1)
print("Total:", [(p == k).sum().item() for k in range(output_classes)])
print("A=0:", [((p == k) & (A == 0)).sum().item() for k in range(output_classes)])
print("A=1:", [((p == k) & (A == 1)).sum().item() for k in range(output_classes)])
```