# Bias and Constrained Learning Homework

In this homework we'll extend the constrained learning framework we used for mitigating bias in class to handle more complex situations. Specifically, we'll look at the case where the output prediction is not binary. As usual with these homeworks, there are three different levels which build on each other, each one corresponding to an increasing grade:

- The basic version of this homework involves implementing code to measure fairness over multiclass classification, then measuring the results when training a regular, unfair classifier. This version is good for a C.
- The B version of the homework involves training a classifier with some fairness constraints.
- For an A, we'll look at a slightly more complicated approach to fair training.

First, we'll generate a dataset for which the sensitive attribute is binary and the output is multiclass.

In [1]:
import numpy as np
import torch
from torch import nn, optim

In [2]:
output_classes = 5

def generate_data():

    dataset_size = 10000
    dimensions = 40

    rng = np.random.default_rng()
    A = np.concatenate((np.zeros(dataset_size // 2), np.ones(dataset_size // 2)))
    rng.shuffle(A)
    X = rng.normal(loc=A[:,np.newaxis], scale=1, size=(dataset_size, dimensions))
    random_linear = np.array([
        -2.28156561, 0.24582547, -2.48926942, -0.02934924, 5.21382855, -1.08613209,
        2.51051602, 1.00773587, -2.10409448, 1.94385103, 0.76013416, -2.94430782,
        0.3289264, -4.35145624, 1.61342623, -1.28433588, -2.07859612, -1.53812125,
        0.51412713, -1.34310334, 4.67174476, 1.67269946, -2.07805413, 3.46667731,
        2.61486654, 1.75418209, -0.06773796, 0.7213423, 2.43896438, 1.79306807,
        -0.74610264, 2.84046827,  1.28779878, 1.84490263, 1.6949681, 0.05814582,
        1.30510732, -0.92332861,  3.00192177, -1.76077192
    ])
    good_score = (X @ random_linear) ** 2 / 2
    qs = np.quantile(good_score, (np.array(range(1, output_classes))) / output_classes)
    Y = np.digitize(good_score, qs)

    return X, A, Y

X, A, Y = generate_data()

In [4]:
print("Total:", [(Y == k).sum() for k in range(output_classes)])
print("A=0:", [((Y == k) & (A == 0)).sum() for k in range(output_classes)])
print("A=1:", [((Y == k) & (A == 1)).sum() for k in range(output_classes)])

Total: [2000, 2000, 2000, 2000, 2000]
A=0: [1393, 1301, 1086, 826, 394]
A=1: [607, 699, 914, 1174, 1606]


This last cell shows the total number of data points in each output category (it should be 2000 each) as well as a breakdown of each output category for the $A=0$ group and the $A=1$ group. Note that the $A=1$ group is much more likely to be assigned to the categories with higher index.
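
The equal class sizes follow from how the labels are built: the thresholds are the $k/K$ quantiles of the score, so each bin captures the same share of the points. The sketch below illustrates that idea without any dependencies. Here `statistics.quantiles` and `bisect_right` stand in for `np.quantile` and `np.digitize`, and the sample size, seed, and one-dimensional score are illustrative assumptions rather than the notebook's data.

```python
import random
from bisect import bisect_right
from statistics import quantiles

K = 5                      # number of output classes, as in the notebook
rng = random.Random(0)     # illustrative seed (assumption)

# Toy one-dimensional score, squared and halved like `good_score` above.
scores = [rng.gauss(0.0, 1.0) ** 2 / 2 for _ in range(1000)]

# K - 1 interior cut points at the 1/K, 2/K, ... quantiles.
cuts = quantiles(scores, n=K, method="inclusive")

# Label each score by the number of cut points at or below it.
labels = [bisect_right(cuts, s) for s in scores]
counts = [labels.count(k) for k in range(K)]
print(counts)
```

The counts come out balanced regardless of how skewed the score distribution is, which is why the group-wise breakdown is the interesting part of the output above.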

## Fairness Definition (C)

Let's write some code to measure the _demographic parity_ of our classifier: $P(R = r \mid A = 0) = P(R = r \mid A = 1)$ for all possible output classes $0 \le r < K$. In the function below,

- `R` is a matrix where each row represents a probability distribution over the classes `0` to `K - 1`. That is, `R` is the output of our neural network _after_ a softmax layer.
- `A` is a vector of sensitive attributes. Each element is either `0` or `1`.

This function should return an array of length `K` where each element represents a measure of bias for _one_ of the output classes. For demographic parity, the value in the output array at index `i` should be $P(R = i \mid A = 1) - P(R = i \mid A = 0)$.
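
As a concrete instance of that definition, here is a small hand-checkable example in plain Python. The helper names `group_mean` and `parity_gaps` and the toy probabilities are hypothetical, chosen only to make the arithmetic visible; this is a sketch of the quantity being measured, not the homework's tensor implementation.

```python
# Toy softmax outputs: 4 examples, K = 3 classes (values are illustrative).
R = [
    [0.7, 0.2, 0.1],   # A = 0
    [0.5, 0.3, 0.2],   # A = 0
    [0.2, 0.3, 0.5],   # A = 1
    [0.1, 0.2, 0.7],   # A = 1
]
A = [0, 0, 1, 1]

def group_mean(rows):
    """Average a list of probability vectors column by column."""
    return [sum(col) / len(rows) for col in zip(*rows)]

def parity_gaps(R, A):
    """Return P(R = i | A = 1) - P(R = i | A = 0) for every class i."""
    mean0 = group_mean([r for r, a in zip(R, A) if a == 0])
    mean1 = group_mean([r for r, a in zip(R, A) if a == 1])
    return [m1 - m0 for m0, m1 in zip(mean0, mean1)]

gaps = parity_gaps(R, A)
print([round(g, 3) for g in gaps])  # [-0.45, 0.0, 0.45]
```

Each group's averaged distribution sums to one, so the gaps always sum to zero: probability mass that one class gains for a group must be lost by another class.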

In [31]:
# R is output of nn after softmax
# A is sensitive attributes

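def demographic_parity(R, A):
    # Convert A to a tensor so the boolean masks index R directly.
    A = torch.as_tensor(A)

    # Split the rows of R by sensitive attribute.
    R0 = R[A == 0]
    R1 = R[A == 1]

    # Average predicted distribution per group, one entry per class.
    R0_av = torch.mean(R0, dim=0)
    R1_av = torch.mean(R1, dim=0)

    # Entry i is P(R = i | A = 1) - P(R = i | A = 0).
    return R1_av - R0_av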

Now we'll train a classifier on this dataset without any fairness constraints for comparison. This code is already complete.

In [32]:
class MLP(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(40, 256),
            nn.ReLU(),
            nn.Linear(256, 5)
        )

    def forward(self, x):
        return self.model(x)

In [36]:
def train_unfair(lr=1e-1, epochs=200):
    
    network = MLP()
    loss = nn.CrossEntropyLoss()
    opt = optim.SGD(network.parameters(), lr=lr)
    data_in = torch.tensor(X).float()
    data_out = torch.tensor(Y)
    
    for i in range(epochs):
        preds = network(data_in)
        loss_val = loss(preds, data_out)
        opt.zero_grad()
        loss_val.backward()
        opt.step()

        if (i+1) % 20 == 0:
            acc = (preds.argmax(dim=1) == data_out).float().mean()
            probs = nn.functional.softmax(preds, dim=1)
            print("Epoch:", i, "Accuracy:", acc.item(), "Bias:", demographic_parity(probs, A).tolist())

    return network

In [37]:
model_unfair = train_unfair(lr=5e-1, epochs=300)

Epoch: 19 Accuracy: 0.33660000562667847 Bias: [-0.17264802753925323, -0.15133246779441833, -0.0737253874540329, 0.022097811102867126, 0.37560802698135376]
Epoch: 39 Accuracy: 0.35109999775886536 Bias: [-0.19326931238174438, -0.1664833426475525, -0.07638069987297058, 0.049103572964668274, 0.387029767036438]
Epoch: 59 Accuracy: 0.38769999146461487 Bias: [-0.2048610895872116, -0.17327344417572021, -0.07213155925273895, 0.0608389675617218, 0.38942718505859375]
Epoch: 79 Accuracy: 0.43860000371932983 Bias: [-0.20534522831439972, -0.1728249490261078, -0.06989467144012451, 0.0529661625623703, 0.3950986862182617]
Epoch: 99 Accuracy: 0.5105000138282776 Bias: [-0.1989070326089859, -0.16608558595180511, -0.06473247706890106, 0.042642876505851746, 0.38708221912384033]
Epoch: 119 Accuracy: 0.5820000171661377 Bias: [-0.19589506089687347, -0.16244608163833618, -0.06198650598526001, 0.04614550620317459, 0.3741821050643921]
Epoch: 139 Accuracy: 0.6328999996185303 Bias: [-0.19359979033470154, -0.1592728

In [38]:
p = model_unfair(torch.tensor(X).float()).argmax(dim=1)
print("Total:", [(p == k).sum().item() for k in range(output_classes)])
print("A=0:", [((p == k) & (A == 0)).sum().item() for k in range(output_classes)])
print("A=1:", [((p == k) & (A == 1)).sum().item() for k in range(output_classes)])

Total: [2765, 2082, 1301, 2706, 1146]
A=0: [1452, 1365, 761, 1191, 231]
A=1: [1313, 717, 540, 1515, 915]


This classifier is probably not going to be _extremely_ accurate, but you should be able to see the bias from the dataset reflected here.
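
One way to check that claim is to turn the printed counts into per-class rate gaps and compare the labels with the predictions. The sketch below does this in plain Python using only numbers copied from the outputs above; the helper `count_gaps` is a hypothetical name. Note that these are gaps in hard (argmax) predictions, whereas the `Bias` values printed during training are computed from softmax probabilities.

```python
GROUP_SIZE = 5000  # each group has dataset_size // 2 points

def count_gaps(counts_a0, counts_a1, group_size=GROUP_SIZE):
    """Per-class difference in assignment rates, A = 1 minus A = 0."""
    return [(c1 - c0) / group_size for c0, c1 in zip(counts_a0, counts_a1)]

# Counts copied from the printed outputs above.
label_gaps = count_gaps([1393, 1301, 1086, 826, 394],
                        [607, 699, 914, 1174, 1606])
pred_gaps = count_gaps([1452, 1365, 761, 1191, 231],
                       [1313, 717, 540, 1515, 915])

print("labels", [round(g, 4) for g in label_gaps])
print("predictions", [round(g, 4) for g in pred_gaps])
# labels [-0.1572, -0.1204, -0.0344, 0.0696, 0.2424]
# predictions [-0.0278, -0.1296, -0.0442, 0.0648, 0.1368]
```

The signs agree class by class: the three lowest classes are assigned more often to the $A=0$ group and the two highest more often to the $A=1$ group, in both the labels and this run's predictions.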

## Fair Training (B)

Now we'll extend our fair training approach from the lab to the multiclass setting. Because we have a bias measure for _each_ possible output class, we essentially have `output_classes` constraints to satisfy. We can handle this within our Lagrange multiplier framework by adding one multiplier per constraint. That is, our new learning problem is

$$
\arg\min_\beta \max_\lambda \left ( L(\beta) + \sum_i \lambda_i g_i(\beta) \right )
$$

$$
= \arg\min_\beta \max_\lambda \left ( L(\beta) + \sum_i \lambda_i \left ( P_\beta [ R = i \mid A = 1 ] - P_\beta [ R = i \mid A = 0 ] \right ) \right )
$$

Our `demographic_parity` function gives us a vector representing $g_i(\beta)$, so now all we need to do is replace our single parameter $\lambda$ from the lab with a vector then compute the dot product of $\lambda$ with our demographic parity measure.
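
To make the multiplier update concrete: the gradient of the Lagrangian with respect to $\lambda_i$ is just $g_i(\beta)$, so a gradient ascent step moves each $\lambda_i$ in the direction of its own class's gap. The sketch below works one such step by hand in plain Python. The gap values, the step size `eta`, and the stand-in loss `base_loss` are made-up illustrative numbers, not values from this notebook.

```python
# Illustrative per-class gaps g_i = P(R = i | A = 1) - P(R = i | A = 0).
g = [-0.2, -0.1, 0.0, 0.1, 0.2]
lam = [0.0, 0.0, 0.0, 0.0, 0.0]   # one multiplier per class, started at zero
eta = 0.5                          # multiplier step size (assumption)
base_loss = 1.6                    # stand-in for the cross-entropy L(beta)

def lagrangian(base_loss, lam, g):
    """L(beta) + sum_i lam_i * g_i(beta)."""
    return base_loss + sum(l * gi for l, gi in zip(lam, g))

before = lagrangian(base_loss, lam, g)

# Gradient ascent on lam: the derivative w.r.t. lam_i is g_i.
lam = [l + eta * gi for l, gi in zip(lam, g)]

after = lagrangian(base_loss, lam, g)
print(lam)             # [-0.1, -0.05, 0.0, 0.05, 0.1]
print(after > before)  # True
```

After this first step from zero, each multiplier carries the sign of its gap, so every product $\lambda_i g_i$ is nonnegative and the network is penalized for gaps in either direction.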

In [41]:
def train_fair(lr=1e-1, lam_lr=1, epochs=200):
    
    network = MLP()
    lam = nn.Parameter(torch.zeros(output_classes))
    loss = nn.CrossEntropyLoss()
    opt = optim.SGD(network.parameters(), lr=lr)
    lam_opt = optim.SGD([lam], lr=lam_lr, maximize=True)
    data_in = torch.tensor(X).float()
    data_out = torch.tensor(Y)
    
    for i in range(epochs):
        # Compute the Lagrangian loss L + lam * g
        preds = network(data_in)
        loss_val = loss(preds, data_out)
        probs = nn.functional.softmax(preds, dim=1)
        bias = demographic_parity(probs, A)
        loss_val += (lam * bias).sum()
        
        opt.zero_grad()
        lam_opt.zero_grad()
        loss_val.backward()
        opt.step()
        lam_opt.step()

        if (i+1) % 20 == 0:
            acc = (preds.argmax(dim=1) == data_out).float().mean()
            probs = nn.functional.softmax(preds, dim=1)
            print("Epoch:", i, "Accuracy:", acc.item(), "Bias:", demographic_parity(probs, A).tolist(), "Lambda:", lam.max().item())

    return network

In [42]:
model_fair = train_fair(lr=5e-1, lam_lr=3e-1, epochs=300)

Epoch: 19 Accuracy: 0.2777000069618225 Bias: [-0.03208662569522858, -0.02460445463657379, -0.015192016959190369, 0.006254330277442932, 0.06562879681587219] Lambda: 0.43797749280929565
Epoch: 39 Accuracy: 0.3305000066757202 Bias: [-0.044726788997650146, -0.03772194683551788, -0.017514914274215698, 0.015473887324333191, 0.08448979258537292] Lambda: 0.4614896774291992
Epoch: 59 Accuracy: 0.3986999988555908 Bias: [-0.048048824071884155, -0.03726594150066376, -0.010533913969993591, 0.02355434000492096, 0.07229435443878174] Lambda: 0.4849717319011688
Epoch: 79 Accuracy: 0.49149999022483826 Bias: [-0.06287842243909836, -0.046243295073509216, -0.009594187140464783, 0.021553561091423035, 0.09716236591339111] Lambda: 0.5110504031181335
Epoch: 99 Accuracy: 0.5742999911308289 Bias: [-0.0725829005241394, -0.04754588007926941, -0.0022411197423934937, 0.017890378832817078, 0.10447946190834045] Lambda: 0.6065122485160828
Epoch: 119 Accuracy: 0.6227999925613403 Bias: [-0.07398375123739243, -0.041786938

In [43]:
p = model_fair(torch.tensor(X).float()).argmax(dim=1)
print("Total:", [(p == k).sum().item() for k in range(output_classes)])
print("A=0:", [((p == k) & (A == 0)).sum().item() for k in range(output_classes)])
print("A=1:", [((p == k) & (A == 1)).sum().item() for k in range(output_classes)])

Total: [2472, 1988, 936, 3501, 1103]
A=0: [1187, 940, 503, 1705, 665]
A=1: [1285, 1048, 433, 1796, 438]


## Fair Training via KL-Divergence (A)

Let's look back at our definition of demographic parity for the multiclass setting: $P(R = r \mid A = 0) = P(R = r \mid A = 1)$ for all possible output classes $r$. We could also express this by asserting that $P(\cdot \mid A = 0)$ and $P(\cdot \mid A = 1)$ should be identical probability distributions. A natural measure of bias then would be to compute the KL-divergence between these two distributions, since KL-divergence is a measure of how "different" two distributions are. That is, we'll now solve the problem

$$
\arg\min_\beta \max_\lambda \left ( L(\beta) + \lambda D_{\textrm{KL}} \left( P(\cdot \mid A = 0) \ \| \ P(\cdot \mid A = 1) \right) \right )
$$
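
For discrete distributions the divergence is $D_{\textrm{KL}}(P \ \| \ Q) = \sum_r P(r) \log \frac{P(r)}{Q(r)}$. The plain-Python sketch below evaluates it on made-up class distributions to show how it behaves as a bias measure; the function name `kl_divergence` and all the numbers are illustrative assumptions.

```python
import math

def kl_divergence(p, q):
    """D_KL(p || q) = sum_r p[r] * log(p[r] / q[r]), for strictly positive p and q."""
    return sum(pr * math.log(pr / qr) for pr, qr in zip(p, q))

# Illustrative class distributions for the two groups (made-up numbers).
p_a0 = [0.30, 0.25, 0.20, 0.15, 0.10]
p_a1 = [0.10, 0.15, 0.20, 0.25, 0.30]
uniform = [0.2] * 5

same = kl_divergence(p_a0, p_a0)
different = kl_divergence(p_a0, p_a1)
forward = kl_divergence(p_a0, uniform)
reverse = kl_divergence(uniform, p_a0)

print(same)              # 0.0 for identical distributions
print(different)         # positive when the groups are treated differently
print(forward, reverse)  # the two directions generally differ
```

The divergence is not symmetric in general, so the order of the arguments matters; the formulation here puts the $A = 0$ distribution first.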

However, this introduces a new complication. The KL-divergence is never negative and is zero only if the two distributions are identical (we proved this in our first homework of the semester). The gradient of the objective with respect to $\lambda$ is the divergence itself, so gradient ascent can never decrease $\lambda$, and it will just go up forever. We can solve this by allowing a small deviation in our constrained optimization problem:

$$
\begin{aligned}
\arg\min_\beta &\ L(\beta) \\
\text{s.t.} &\ D_{\textrm{KL}} \left( P(\cdot \mid A = 0) \ \| \ P(\cdot \mid A = 1) \right) \le \epsilon
\end{aligned}
$$

We can still represent this using a Lagrange multiplier:

$$
\arg\min_\beta \max_{\lambda \ge 0} \left ( L(\beta) + \lambda \left ( D_{\textrm{KL}} \left( P(\cdot \mid A = 0) \ \| \ P(\cdot \mid A = 1) \right) - \epsilon \right ) \right )
$$

Your task now is to represent this optimization problem in the code below. I've taken care of clamping $\lambda$ so it stays nonnegative for you since it's not something we've looked at in class.
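
The clamping amounts to projected gradient ascent: take the usual ascent step on $\lambda$, whose gradient is $D_{\textrm{KL}} - \epsilon$, then project back onto $\lambda \ge 0$. The sketch below traces that update in plain Python over a hypothetical sequence of divergence values; only `epsilon = 0.02` and the step size of 1 mirror the `train_kl` call in this notebook, and the divergences themselves are invented for illustration.

```python
epsilon = 0.02   # allowed divergence
lam_lr = 1.0     # multiplier step size
lam = 0.0

# Hypothetical divergences observed over successive training steps.
divergences = [0.06, 0.04] + [0.01] * 7

history = []
for kl in divergences:
    lam = lam + lam_lr * (kl - epsilon)  # ascent step: gradient is (kl - epsilon)
    lam = max(lam, 0.0)                  # projection onto lam >= 0
    history.append(round(lam, 4))

print(history)
# [0.04, 0.06, 0.05, 0.04, 0.03, 0.02, 0.01, 0.0, 0.0]
```

The multiplier grows while the divergence exceeds `epsilon`, shrinks once the constraint is satisfied, and stops at zero instead of going negative. Without the projection, a satisfied constraint would eventually earn a negative multiplier, which would reward divergence below `epsilon` rather than simply tolerating it.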

In [44]:
def train_kl(lr=1e-1, lam_lr=1, epochs=300, epsilon=0.1):
    
    network = MLP()
    lam = nn.Parameter(torch.tensor(0.0))
    loss = nn.CrossEntropyLoss()
    opt = optim.SGD(network.parameters(), lr=lr)
    lam_opt = optim.SGD([lam], lr=lam_lr, maximize=True)
    data_in = torch.tensor(X).float()
    data_out = torch.tensor(Y)
    
    for i in range(epochs):
        
        preds = network(data_in)
        probs = nn.functional.softmax(preds, dim=1)

        # P[r] = P(R = r | A = 0): mean predicted distribution over group 0
        P = torch.mean(probs[A == 0], dim=0)

        # Q[r] = P(R = r | A = 1): mean predicted distribution over group 1
        Q = torch.mean(probs[A == 1], dim=0)

        # D_KL(P || Q) = sum_r P[r] * log(P[r] / Q[r])
        kl_div = torch.sum(P * torch.log(P / Q))

        # Lagrangian: cross-entropy plus lam * (KL - epsilon)
        loss_val = loss(preds, data_out) + lam * (kl_div - epsilon)
        
        opt.zero_grad()
        lam_opt.zero_grad()
        loss_val.backward()
        opt.step()
        lam_opt.step()

        with torch.no_grad():
            lam.clamp_(min=0)

        if (i+1) % 20 == 0:
            acc = (preds.argmax(dim=1) == data_out).float().mean()
            print("Epoch:", i, "Accuracy:", acc.item(), "Divergence:", kl_div.item(), "Lambda:", lam.item())

    return network

In [45]:
model = train_kl(lr=3e-1, lam_lr=1, epsilon=0.02)

Epoch: 19 Accuracy: 0.31310001015663147 Divergence: 0.06018543243408203 Lambda: 0.9450139999389648
Epoch: 39 Accuracy: 0.33880001306533813 Divergence: 0.06900602579116821 Lambda: 1.2576714754104614
Epoch: 59 Accuracy: 0.3749000132083893 Divergence: 0.06540164351463318 Lambda: 1.568528652191162
Epoch: 79 Accuracy: 0.4171999990940094 Divergence: 0.06389471143484116 Lambda: 1.8622554540634155
Epoch: 99 Accuracy: 0.46720001101493835 Divergence: 0.06748935580253601 Lambda: 2.195268392562866
Epoch: 119 Accuracy: 0.5339000225067139 Divergence: 0.07535848766565323 Lambda: 2.6604878902435303
Epoch: 139 Accuracy: 0.5928000211715698 Divergence: 0.0779055804014206 Lambda: 3.2578675746917725
Epoch: 159 Accuracy: 0.6455000042915344 Divergence: 0.06048768013715744 Lambda: 3.7593770027160645
Epoch: 179 Accuracy: 0.6549000144004822 Divergence: 0.06961138546466827 Lambda: 4.4724955558776855
Epoch: 199 Accuracy: 0.6714000105857849 Divergence: 0.06775838136672974 Lambda: 4.896687030792236
Epoch: 219 Accur

In [46]:
p = model(torch.tensor(X).float()).argmax(dim=1)
print("Total:", [(p == k).sum().item() for k in range(output_classes)])
print("A=0:", [((p == k) & (A == 0)).sum().item() for k in range(output_classes)])
print("A=1:", [((p == k) & (A == 1)).sum().item() for k in range(output_classes)])

Total: [3582, 1840, 686, 1943, 1949]
A=0: [1381, 1138, 355, 1015, 1111]
A=1: [2201, 702, 331, 928, 838]
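
To compare the three runs on a common footing, the printed argmax counts can be summarized by their largest per-class rate gap and by the KL-divergence between the two groups' prediction distributions. The sketch below does that in plain Python with numbers copied from the outputs above; `summarize` and the run labels are hypothetical names. Keep in mind that these are hard predictions from single runs, while training constrains averaged softmax probabilities, and that in the visible part of the KL run's log the divergence stays above $\epsilon = 0.02$.

```python
import math

GROUP_SIZE = 5000  # points per group in the dataset above

# (A=0 counts, A=1 counts) copied from the printed outputs above.
runs = {
    "unconstrained": ([1452, 1365, 761, 1191, 231], [1313, 717, 540, 1515, 915]),
    "per-class multipliers": ([1187, 940, 503, 1705, 665], [1285, 1048, 433, 1796, 438]),
    "KL constraint": ([1381, 1138, 355, 1015, 1111], [2201, 702, 331, 928, 838]),
}

def summarize(counts_a0, counts_a1):
    """Return (largest absolute rate gap, D_KL(A=0 || A=1)) from argmax counts."""
    p0 = [c / GROUP_SIZE for c in counts_a0]
    p1 = [c / GROUP_SIZE for c in counts_a1]
    max_gap = max(abs(b - a) for a, b in zip(p0, p1))
    kl = sum(a * math.log(a / b) for a, b in zip(p0, p1))
    return max_gap, kl

summary = {name: summarize(*counts) for name, counts in runs.items()}
for name, (max_gap, kl) in summary.items():
    print(f"{name}: max gap = {max_gap:.4f}, KL = {kl:.4f}")
```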
