# Bias and Constrained Learning Homework

In this homework we'll extend the constrained learning framework we used for mitigating bias in class to handle more complex situations. Specifically, we'll look at the case where the output prediction is not binary. As usual with these homeworks, there are three different levels which build on each other, each one corresponding to an increasing grade:

- The basic version of this homework involves implementing code to measure fairness for multiclass classification, then measuring the results when training a regular, unfair classifier. This version is good for a C.
- The B version of the homework involves training a classifier with some fairness constraints.
- For an A, we'll look at a slightly more complicated approach to fair training.

First, we'll generate a dataset for which the sensitive attribute is binary and the output is multiclass.

In [110]:
import numpy as np
import torch
from torch import nn, optim

In [111]:
output_classes = 5

def generate_data():

    dataset_size = 10000
    dimensions = 40

    rng = np.random.default_rng()
    A = np.concatenate((np.zeros(dataset_size // 2), np.ones(dataset_size // 2)))
    rng.shuffle(A)
    X = rng.normal(loc=A[:,np.newaxis], scale=1, size=(dataset_size, dimensions))
    random_linear = np.array([
        -2.28156561, 0.24582547, -2.48926942, -0.02934924, 5.21382855, -1.08613209,
        2.51051602, 1.00773587, -2.10409448, 1.94385103, 0.76013416, -2.94430782,
        0.3289264, -4.35145624, 1.61342623, -1.28433588, -2.07859612, -1.53812125,
        0.51412713, -1.34310334, 4.67174476, 1.67269946, -2.07805413, 3.46667731,
        2.61486654, 1.75418209, -0.06773796, 0.7213423, 2.43896438, 1.79306807,
        -0.74610264, 2.84046827,  1.28779878, 1.84490263, 1.6949681, 0.05814582,
        1.30510732, -0.92332861,  3.00192177, -1.76077192
    ])
    good_score = (X @ random_linear) ** 2 / 2
    qs = np.quantile(good_score, (np.array(range(1, output_classes))) / output_classes)
    Y = np.digitize(good_score, qs)

    return X, A, Y

X, A, Y = generate_data()

In [112]:
print("Total:", [(Y == k).sum() for k in range(output_classes)])
print("A=0:", [((Y == k) & (A == 0)).sum() for k in range(output_classes)])
print("A=1:", [((Y == k) & (A == 1)).sum() for k in range(output_classes)])

Total: [2000, 2000, 2000, 2000, 2000]
A=0: [1376, 1313, 1168, 783, 360]
A=1: [624, 687, 832, 1217, 1640]


This last cell shows the total number of data points in each output category (it should be 2000 each) as well as a breakdown of each output category for the $A=0$ group and the $A=1$ group. Note that the $A=1$ group is much more likely to be assigned to the categories with higher index.
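To make the imbalance concrete, the counts printed above can be turned into per-group class rates. The sketch below hard-codes the counts from this particular run (the generator is unseeded, so a rerun will give slightly different numbers); the variable names are illustrative and not part of the homework code.

```python
# Counts copied from the output above; rerunning generate_data() will change them slightly.
counts_a0 = [1376, 1313, 1168, 783, 360]
counts_a1 = [624, 687, 832, 1217, 1640]

# Empirical P(Y = k | A = a): divide each count by the size of its group.
rates_a0 = [c / sum(counts_a0) for c in counts_a0]
rates_a1 = [c / sum(counts_a1) for c in counts_a1]

# Per-class gap P(Y = k | A = 1) - P(Y = k | A = 0).
gaps = [r1 - r0 for r0, r1 in zip(rates_a0, rates_a1)]

for k, (r0, r1, g) in enumerate(zip(rates_a0, rates_a1, gaps)):
    print(f"class {k}: A=0 {r0:.4f}  A=1 {r1:.4f}  gap {g:+.4f}")
```

The gaps sum to zero because both rows are probability distributions, so a surplus in the high classes for $A=1$ has to be balanced by a deficit in the low ones.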

## Fairness Definition (C)

Let's write some code to measure the _demographic parity_ of our classifier: $P(R = r \mid A = 0) = P(R = r \mid A = 1)$ for all possible output classes $0 \le r < K$. In the functions below,

- `R` is a matrix where each row represents a probability distribution over the classes `0` to `K - 1`. That is, `R` is the output of our neural network _after_ a softmax layer.
- `A` is a vector of sensitive attributes. Each element is either `0` or `1`.

These functions should return an array of length `K` where each element of the array represents a measure of bias for _one_ of the output classes. For example, for demographic parity, the value in the output array at index `i` should be $P(R = i \mid A = 1) - P(R = i \mid A = 0)$.
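As a quick sanity check of this specification, here is a toy example with four data points and three classes. It is only a sketch: `dp_gap` is a hypothetical helper name, and the probabilities are made up so the group averages are easy to verify by hand.

```python
import torch

def dp_gap(R, A):
    """Per-class gap P(R = i | A = 1) - P(R = i | A = 0), estimated by
    averaging the predicted distributions within each group."""
    return R[A == 1].mean(dim=0) - R[A == 0].mean(dim=0)

# Four examples, three classes; each row is a (made-up) softmax output.
R = torch.tensor([
    [0.7, 0.2, 0.1],
    [0.5, 0.3, 0.2],
    [0.1, 0.3, 0.6],
    [0.3, 0.3, 0.4],
])
A = torch.tensor([0, 0, 1, 1])

# Group 0 averages to [0.6, 0.25, 0.15] and group 1 to [0.2, 0.3, 0.5].
gap = dp_gap(R, A)
print(gap)
```

Each entry is positive when the $A=1$ group receives more probability mass for that class, and the entries always sum to zero.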

In [119]:
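def demographic_parity(R, A):
   # Per-class gap in mean predicted probability between the two groups.
   # Note: this is group 0 minus group 1, the negation of the convention
   # stated in the text above. Only the sign differs; the magnitudes are the same.
   return torch.mean(R[A == 0], dim=0) - torch.mean(R[A == 1], dim=0)

def equalized_odds(R, A, Y):
   # Same gap, but restricted to examples whose true class is 1.
   return torch.mean(R[(A == 0) & (Y == 1)], dim=0) - torch.mean(R[(A == 1) & (Y == 1)], dim=0)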

# def predictive_parity(R, A, Y):
#    print("R[A = 0]:", R[A==0].size(), "R[A == 1]:", R[A==1].size())
#    print("R[A = 0][Y == 1]:", R[A==0][Y == 1].size(), "R[A == 1][Y == 1]:", R[A==1][Y == 1].size())
#    return torch.mean(R[A == 0], axis=0) - torch.mean(Y[(A == 1) & (R == 1)], axis=0)

Now we'll train a classifier on this dataset without any fairness constraints for comparison. This code is already complete.

In [120]:
class MLP(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(40, 256),
            nn.ReLU(),
            nn.Linear(256, 5)
        )

    def forward(self, x):
        return self.model(x)

In [131]:
def train_unfair(lr=1e-1, epochs=200, fairness=demographic_parity, attributes=A):
    
    network = MLP()
    loss = nn.CrossEntropyLoss()
    opt = optim.SGD(network.parameters(), lr=lr)
    data_in = torch.tensor(X).float()
    data_out = torch.tensor(Y)
    attributes = torch.tensor(attributes)
    
    for i in range(epochs):
        preds = network(data_in)
        loss_val = loss(preds, data_out)
        opt.zero_grad()
        loss_val.backward()
        opt.step()

        if (i+1) % 20 == 0:
            acc = (preds.argmax(dim=1) == data_out).float().mean()
            probs = nn.functional.softmax(preds, dim=1)
            if fairness == equalized_odds:
                bias = fairness(probs, attributes, data_out)
            else:
                bias = fairness(probs, attributes)


            print("Epoch:", i, "Accuracy:", acc.item(), "Bias:", ['%.4f' % b for b in bias])

    return network

In [132]:
print("Bias measure: demographic parity")
model_unfair_dp = train_unfair(lr=5e-1, epochs=300)

print("\nBias measure: equalized odds")
model_unfair_eq = train_unfair(lr=5e-1, epochs=300, fairness=equalized_odds)

Bias measure: demographic parity
Epoch: 19 Accuracy: 0.3240000009536743 Bias: ['0.1775', '0.1574', '0.1035', '-0.0441', '-0.3943']
Epoch: 39 Accuracy: 0.34940001368522644 Bias: ['0.1919', '0.1696', '0.0997', '-0.0665', '-0.3947']
Epoch: 59 Accuracy: 0.3919999897480011 Bias: ['0.1996', '0.1743', '0.0950', '-0.0703', '-0.3986']
Epoch: 79 Accuracy: 0.45730000734329224 Bias: ['0.1969', '0.1726', '0.0933', '-0.0552', '-0.4075']
Epoch: 99 Accuracy: 0.5371999740600586 Bias: ['0.1890', '0.1672', '0.0904', '-0.0415', '-0.4051']
Epoch: 119 Accuracy: 0.6010000109672546 Bias: ['0.1849', '0.1643', '0.0893', '-0.0446', '-0.3939']
Epoch: 139 Accuracy: 0.637499988079071 Bias: ['0.1830', '0.1612', '0.0907', '-0.0474', '-0.3875']
Epoch: 159 Accuracy: 0.6672000288963318 Bias: ['0.1818', '0.1583', '0.0932', '-0.0504', '-0.3829']
Epoch: 179 Accuracy: 0.7046999931335449 Bias: ['0.1745', '0.1505', '0.0962', '-0.0380', '-0.3831']
Epoch: 199 Accuracy: 0.761900007724762 Bias: ['0.1620', '0.1362', '0.0741', '-0.

In [133]:
# Use the model trained in the previous cell (there is no plain `model_unfair` variable).
p = model_unfair_dp(torch.tensor(X).float()).argmax(dim=1)
print("Total:", [(p == k).sum().item() for k in range(output_classes)])
print("A=0:", [((p == k) & (A == 0)).sum().item() for k in range(output_classes)])
print("A=1:", [((p == k) & (A == 1)).sum().item() for k in range(output_classes)])

Total: [2568, 1953, 1425, 3076, 978]
A=0: [1472, 1354, 726, 1263, 185]
A=1: [1096, 599, 699, 1813, 793]


This classifier is probably not going to be _extremely_ accurate, but you should be able to see the bias from the dataset reflected here.

## Fair Training (B)

Now we'll extend our fair training approach from the lab to the multiclass setting. Since we have a bias measure for _each_ possible output class, we essentially have `output_classes` constraints to satisfy. We can handle this within our Lagrange multiplier framework by adding one multiplier per constraint. That is, our new learning problem is

$$
\arg\min_\beta \max_\lambda \left ( L(\beta) + \sum_i \lambda_i g_i(\beta) \right )
$$

$$
= \arg\min_\beta \max_\lambda \left ( L(\beta) + \sum_i \lambda_i \left ( P_\beta [ R = i \mid A = 1 ] - P_\beta [ R = i \mid A = 0 ] \right ) \right )
$$

Our `demographic_parity` function gives us a vector representing $g_i(\beta)$, so all we need to do is replace our single parameter $\lambda$ from the lab with a vector, then compute the dot product of $\lambda$ with our demographic parity measure.
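The penalty term on its own can be sketched as follows. This is an illustration under assumptions, not the homework solution: the logits are made-up stand-ins for the network output, and `K`, `logits`, and `g` are names chosen for this example.

```python
import torch
from torch import nn

K = 3  # number of output classes in this toy example

# One multiplier per class constraint, initialised to zero.
lam = nn.Parameter(torch.zeros(K))

# Made-up logits standing in for the network output.
logits = torch.tensor([[2.0, 0.0, -1.0],
                       [1.0, 0.5, 0.0],
                       [-1.0, 0.0, 2.0],
                       [0.0, 0.5, 1.0]], requires_grad=True)
A = torch.tensor([0, 0, 1, 1])

probs = torch.softmax(logits, dim=1)
# Vector of constraint values g_i = P(R = i | A = 1) - P(R = i | A = 0).
g = probs[A == 1].mean(dim=0) - probs[A == 0].mean(dim=0)

penalty = torch.dot(lam, g)  # sum_i lam_i * g_i
penalty.backward()

# The gradient of the penalty with respect to lam is exactly g.
print(lam.grad)
print(g.detach())
```

Because the gradient with respect to $\lambda$ is just the vector of constraint values, gradient ascent on $\lambda$ grows each multiplier in proportion to how badly its constraint is violated, in either direction.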

In [134]:
def train_fair(lr=1e-1, lam_lr=1, epochs=200):
    
    network = MLP()
    lam = nn.Parameter(torch.zeros(output_classes))
    loss = nn.CrossEntropyLoss()
    opt = optim.SGD(network.parameters(), lr=lr)
    lam_opt = optim.SGD([lam], lr=lam_lr, maximize=True)
    data_in = torch.tensor(X).float()
    data_out = torch.tensor(Y)
    
    for i in range(epochs):
        # Compute the Lagrangian loss L + lam * g
        preds = network(data_in)
        loss_val = loss(preds, data_out)
        probs = nn.functional.softmax(preds, dim=1)
        bias = demographic_parity(probs, A)
        loss_val += (lam * bias).sum()
        
        opt.zero_grad()
        lam_opt.zero_grad()
        loss_val.backward()
        opt.step()
        lam_opt.step()

        if (i+1) % 20 == 0:
            acc = (preds.argmax(dim=1) == data_out).float().mean()
            # `bias` was computed from these same predictions above, so reuse it.
            print("Epoch:", i, "Accuracy:", acc.item(), "Bias:", ['%.4f' % b for b in bias], "Lambda:", lam.max().item())

    return network

In [135]:
model_fair = train_fair(lr=5e-1, lam_lr=3e-1, epochs=300)

Epoch: 19 Accuracy: 0.22579999268054962 Bias: ['-0.0419', '-0.0279', '-0.0149', '0.0269', '0.0578'] Lambda: 0.23635146021842957
Epoch: 39 Accuracy: 0.2705000042915344 Bias: ['-0.0286', '-0.0238', '-0.0123', '0.0060', '0.0587'] Lambda: 0.2752276360988617
Epoch: 59 Accuracy: 0.2896000146865845 Bias: ['-0.0368', '-0.0296', '-0.0146', '0.0101', '0.0709'] Lambda: 0.3086240589618683
Epoch: 79 Accuracy: 0.3675999939441681 Bias: ['-0.0449', '-0.0326', '-0.0114', '0.0091', '0.0797'] Lambda: 0.3257610499858856
Epoch: 99 Accuracy: 0.41429999470710754 Bias: ['-0.0432', '-0.0282', '-0.0044', '0.0022', '0.0736'] Lambda: 0.3727899193763733
Epoch: 119 Accuracy: 0.4334999918937683 Bias: ['-0.0429', '-0.0263', '-0.0014', '-0.0021', '0.0728'] Lambda: 0.45134952664375305
Epoch: 139 Accuracy: 0.45019999146461487 Bias: ['-0.0428', '-0.0243', '0.0008', '-0.0046', '0.0709'] Lambda: 0.5334129333496094
Epoch: 159 Accuracy: 0.45989999175071716 Bias: ['-0.0429', '-0.0220', '0.0027', '-0.0065', '0.0687'] Lambda: 0

In [75]:
p = model_fair(torch.tensor(X).float()).argmax(dim=1)
print("Total:", [(p == k).sum().item() for k in range(output_classes)])
print("A=0:", [((p == k) & (A == 0)).sum().item() for k in range(output_classes)])
print("A=1:", [((p == k) & (A == 1)).sum().item() for k in range(output_classes)])

Total: [1731, 1843, 1449, 1939, 3038]
A=0: [1140, 971, 625, 803, 1461]
A=1: [591, 872, 824, 1136, 1577]


## Fair Training via KL-Divergence (A)

Let's look back at our definition of demographic parity for the multiclass setting: $P(R = r \mid A = 0) = P(R = r \mid A = 1)$ for all possible output classes $r$. We could also express this by asserting that $P(\cdot \mid A = 0)$ and $P(\cdot \mid A = 1)$ should be identical probability distributions. A natural measure of bias, then, is the KL-divergence between these two distributions, since KL-divergence measures how "different" two distributions are. That is, we'll now solve the problem
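For intuition, the divergence can be computed directly on two small distributions. This is a minimal sketch assuming strictly positive probabilities (which softmax outputs always are); `kl_divergence` and the example vectors are hypothetical and chosen only for illustration.

```python
import torch

def kl_divergence(p, q):
    """D_KL(p || q) = sum_r p_r * log(p_r / q_r), for strictly positive p and q."""
    return torch.sum(p * torch.log(p / q))

p = torch.tensor([0.6, 0.25, 0.15])  # hypothetical P(. | A = 0)
q = torch.tensor([0.2, 0.30, 0.50])  # hypothetical P(. | A = 1)

kl_pq = kl_divergence(p, q).item()
kl_qp = kl_divergence(q, p).item()
kl_pp = kl_divergence(p, p).item()

print(kl_pq)  # positive, since the distributions differ
print(kl_qp)  # a different value: KL-divergence is not symmetric
print(kl_pp)  # zero for identical distributions
```

Since the whole expression is built from differentiable tensor operations, it can be used as a penalty term and backpropagated through in the same way as the per-class gaps.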

$$
\arg\min_\beta \max_\lambda \left ( L(\beta) + \lambda D_{\textrm{KL}} \left( P(\cdot \mid A = 0) \ \| \ P(\cdot \mid A = 1) \right) \right )
$$

However, this introduces a new complication. The KL-divergence is never negative and can only be zero if the two distributions are identical (we proved this in our first homework of the semester). That means there's no way for $\lambda$ to ever decrease, and it will just go up forever. We can solve this by allowing a small deviation in our constrained optimization problem:

$$
\begin{align}
\arg\min_\beta &\ L(\beta) \\
\text{s.t.} &\ D_{\textrm{KL}} \left( P(\cdot \mid A = 0) \ \| \ P(\cdot \mid A = 1) \right) \le \epsilon
\end{align}
$$

We can still represent this using a Lagrange multiplier:

$$
\arg\min_\beta \max_{\lambda \ge 0} \left ( L(\beta) + \lambda \left ( D_{\textrm{KL}} \left( P(\cdot \mid A = 0) \ \| \ P(\cdot \mid A = 1) \right) - \epsilon \right ) \right )
$$

Your task now is to represent this optimization problem in the code below. I've taken care of clipping $\lambda$ at zero (that is, keeping $\lambda \ge 0$) for you since it's not something we've looked at in class.
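The effect of the clipping step can be seen in isolation with a toy loop. In this sketch the divergence values are a made-up sequence standing in for what the network would produce, and the learning rate and `epsilon` are arbitrary choices for illustration.

```python
import torch
from torch import nn, optim

epsilon = 0.02
lam = nn.Parameter(torch.tensor(0.0))
lam_opt = optim.SGD([lam], lr=1.0, maximize=True)

# Hypothetical divergence values over five successive steps.
divergences = [0.05, 0.03, 0.01, 0.005, 0.0]

history = []
for d in divergences:
    penalty = lam * (torch.tensor(d) - epsilon)
    lam_opt.zero_grad()
    penalty.backward()
    lam_opt.step()            # ascent step: lam += lr * (d - epsilon)
    with torch.no_grad():
        lam.clamp_(min=0)     # project back onto lam >= 0
    history.append(lam.item())

print(history)
```

While the divergence exceeds `epsilon` the multiplier grows; once the divergence falls below `epsilon` it shrinks, and the clamp stops it from going negative, which would otherwise reward the network for making the two distributions differ.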

In [137]:
def train_kl(lr=1e-1, lam_lr=1, epochs=300, epsilon=0.1):
    
    network = MLP()
    lam = nn.Parameter(torch.tensor(0.0))
    loss = nn.CrossEntropyLoss()
    opt = optim.SGD(network.parameters(), lr=lr)
    lam_opt = optim.SGD([lam], lr=lam_lr, maximize=True)
    data_in = torch.tensor(X).float()
    data_out = torch.tensor(Y)
    
    for i in range(epochs):
        
        preds = network(data_in)
        probs = nn.functional.softmax(preds, dim=1)

        # P[r] estimates P(R = r | A = 0)
        P = torch.mean(probs[A == 0], dim=0)

        # Q[r] estimates P(R = r | A = 1)
        Q = torch.mean(probs[A == 1], dim=0)

        # D_KL(P || Q) = sum_r P[r] * log(P[r] / Q[r])
        kl_log = torch.log(P / Q)
        kl_div = torch.sum(P * kl_log)

        # Lagrangian: task loss plus lam times the constraint violation
        loss_val = loss(preds, data_out) + lam * (kl_div - epsilon)
        
        opt.zero_grad()
        lam_opt.zero_grad()
        loss_val.backward()
        opt.step()
        lam_opt.step()

        with torch.no_grad():
            lam.clamp_(min=0)

        if (i+1) % 20 == 0:
            acc = (preds.argmax(dim=1) == data_out).float().mean()
            print("Epoch:", i, "Accuracy:", acc.item(), "Divergence:", kl_div.item(), "Lambda:", lam.item())

    return network

In [138]:
model = train_kl(lr=3e-1, lam_lr=1, epsilon=0.02)

Epoch: 19 Accuracy: 0.3160000145435333 Divergence: 0.0017289668321609497 Lambda: 0.8826927542686462
Epoch: 39 Accuracy: 0.3174000084400177 Divergence: 0.0015972619876265526 Lambda: 1.212470531463623
Epoch: 59 Accuracy: 0.32179999351501465 Divergence: 0.0022437660954892635 Lambda: 1.5176353454589844
Epoch: 79 Accuracy: 0.328000009059906 Divergence: 0.003057121764868498 Lambda: 1.7916920185089111
Epoch: 99 Accuracy: 0.34700000286102295 Divergence: 0.004950245842337608 Lambda: 2.0657811164855957
Epoch: 119 Accuracy: 0.3714999854564667 Divergence: 0.01033196784555912 Lambda: 2.398014545440674
Epoch: 139 Accuracy: 0.38269999623298645 Divergence: 0.016513975337147713 Lambda: 2.826582193374634
Epoch: 159 Accuracy: 0.41290000081062317 Divergence: 0.024347763508558273 Lambda: 3.4310302734375
Epoch: 179 Accuracy: 0.428600013256073 Divergence: 0.019872542470693588 Lambda: 4.006704807281494
Epoch: 199 Accuracy: 0.45649999380111694 Divergence: 0.013929116539657116 Lambda: 4.508157730102539
Epoch: 2

In [139]:
p = model(torch.tensor(X).float()).argmax(dim=1)
print("Total:", [(p == k).sum().item() for k in range(output_classes)])
print("A=0:", [((p == k) & (A == 0)).sum().item() for k in range(output_classes)])
print("A=1:", [((p == k) & (A == 1)).sum().item() for k in range(output_classes)])

Total: [2311, 1872, 1897, 1118, 2802]
A=0: [1399, 961, 1089, 594, 957]
A=1: [912, 911, 808, 524, 1845]
