In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

from tqdm import tqdm

import fairness_metrics

In [2]:
class Trainer:
    """Streaming trainer: buffers incoming samples and takes one Adam step per full buffer.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose parameters are optimised.
    accloss : callable
        Called as ``accloss(y_hat, y)``; returns the accuracy term of the loss.
    fairloss : callable
        Called as ``fairloss(y_hat_1, y_hat_0)`` with the predictions of the
        ``a == 1`` and ``a == 0`` samples; returns the fairness term.
    N : int
        Minimum number of buffered samples before a training step is taken.
    Na : sequence of two numbers
        ``Na[0]`` / ``Na[1]`` are the minimum numbers of buffered samples with
        ``a == 0`` / ``a == 1`` required before a training step is taken.
    tester : callable
        Called as ``tester(model, regime)`` once at construction (with regime 1)
        and once at the end of every ``update`` call.
    regularizer : float
        Weight of the fairness term in the total loss.
    lr : float
        Learning rate passed to Adam.
    """

    def __init__(self, model, accloss, fairloss, N, Na, tester, regularizer=2, lr=1e-2):
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.accloss = accloss
        self.fairloss = fairloss
        self.tester = tester
        self.N = N
        self.Na = Na
        # Sample buffers; None means "empty".
        self.X = None
        self.y = None
        self.a = None
        self.regularizer = regularizer
        # Record the untrained model's performance as the first history entry.
        self.tester(self.model, 1)

    def update(self, X, y, a, regime):
        '''
        Append one batch to the buffer, train if the buffer is full, then evaluate.

        X, y, a: torch.Tensor
            X is stacked row-wise with ``torch.vstack``; y and a are
            concatenated with ``torch.hstack``.
        regime:
            Forwarded unchanged to the tester.
        '''
        # `is None` is an identity check; `== None` would go through
        # Tensor.__eq__ as soon as the buffer holds a tensor.
        if self.X is None:
            self.X = X
            self.y = y
            self.a = a
        else:
            self.X = torch.vstack((self.X, X))
            self.y = torch.hstack((self.y, y))
            self.a = torch.hstack((self.a, a))

        # Train only once the buffer is large enough AND both groups are
        # represented by at least Na[0] / Na[1] samples.
        enough_total = len(self.a) >= self.N
        if enough_total and ((1 - self.a).sum() >= self.Na[0]) and (self.a.sum() >= self.Na[1]):
            self.optimizer.zero_grad()
            y_hat = self.model(self.X)
            y_hat_1 = y_hat[self.a == 1]
            y_hat_0 = y_hat[self.a == 0]
            loss = self.accloss(y_hat, self.y) + self.regularizer * self.fairloss(y_hat_1, y_hat_0)
            loss.backward()
            self.optimizer.step()

            # Start a fresh buffer for the next step.
            self.X = None
            self.y = None
            self.a = None

        # Evaluate after every incoming batch, whether or not a step was taken.
        self.tester(self.model, regime)

class TrainerDebiased:
    """Streaming trainer whose accuracy loss is re-weighted per group.

    Works like a buffered one-step-per-buffer trainer, but the accuracy loss is
    computed separately on the ``a == 1`` and ``a == 0`` samples and the two
    terms are combined with correction factors ``delta_1`` / ``delta_0``.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose parameters are optimised.
    accloss : callable
        Called as ``accloss(y_hat_group, y_group)`` once per group.
    fairloss : callable
        Called as ``fairloss(y_hat_1, y_hat_0)``.
    N : int
        Minimum number of buffered samples before a training step is taken.
    Na : sequence of two numbers
        Minimum numbers of buffered ``a == 0`` (``Na[0]``) and ``a == 1``
        (``Na[1]``) samples required before a training step is taken.
    tester : callable
        Called as ``tester(model, regime)`` at construction (regime 1) and at
        the end of every ``update`` call.
    regularizer : float
        Weight of the fairness term in the total loss.
    lr : float
        Learning rate passed to Adam.
    """

    def __init__(self, model, accloss, fairloss, N, Na, tester, regularizer=2, lr=1e-2):
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.accloss = accloss
        self.fairloss = fairloss
        self.tester = tester
        self.N = N
        self.Na = Na
        # Sample buffers; None means "empty".
        self.X = None
        self.y = None
        self.a = None
        self.regularizer = regularizer
        # Record the untrained model's performance as the first history entry.
        self.tester(self.model, 1)

    def update(self, X, y, a, regime):
        '''
        Append one batch to the buffer, train if the buffer is full, then evaluate.

        X, y, a: torch.Tensor
        regime: forwarded unchanged to the tester.
        '''
        # `is None` is an identity check; `== None` would go through
        # Tensor.__eq__ as soon as the buffer holds a tensor.
        if self.X is None:
            self.X = X
            self.y = y
            self.a = a
        else:
            self.X = torch.vstack((self.X, X))
            self.y = torch.hstack((self.y, y))
            self.a = torch.hstack((self.a, a))

        # Train only once the buffer is large enough AND both groups are
        # represented by at least Na[0] / Na[1] samples.
        enough_total = len(self.a) >= self.N
        if enough_total and ((1 - self.a).sum() >= self.Na[0]) and (self.a.sum() >= self.Na[1]):
            self.optimizer.zero_grad()
            y_hat = self.model(self.X)
            in_group_1 = self.a == 1
            in_group_0 = self.a == 0
            y_hat_1 = y_hat[in_group_1]
            y_hat_0 = y_hat[in_group_0]
            y_1 = self.y[in_group_1]
            y_0 = self.y[in_group_0]

            N = len(self.a)
            N_1 = self.a.sum()
            N_0 = N - N_1
            # The guard above already guarantees N >= self.N, so the correction
            # factors are always applied here.
            # NOTE(review): the literal 2 presumably mirrors Na[1] == 2 as used
            # by the experiment -- confirm before running with a different Na.
            if N_1 == 2:
                delta_1 = N / (2 * (N - 1))
                delta_0 = N / (N - 1)
            else:
                delta_1 = N / (N - 1)
                delta_0 = N / (2 * (N - 1))
            weight_1 = delta_1 * N_1 / N
            weight_0 = delta_0 * N_0 / N

            accloss1 = self.accloss(y_hat_1, y_1)
            accloss0 = self.accloss(y_hat_0, y_0)
            loss = (weight_0 * accloss0 + weight_1 * accloss1) + self.regularizer * self.fairloss(y_hat_1, y_hat_0)
            loss.backward()
            self.optimizer.step()

            # Start a fresh buffer for the next step.
            self.X = None
            self.y = None
            self.a = None

        # Evaluate after every incoming batch, whether or not a step was taken.
        self.tester(self.model, regime)

In [3]:
class Tester:
    """Evaluates a model on a fixed test set and accumulates the history.

    Parameters
    ----------
    X : torch.Tensor
        Test inputs, passed to the model as-is.
    y1, y2 : torch.Tensor
        Test targets; ``y1`` is used when ``regime == 1``, ``y2`` otherwise.
    a : torch.Tensor
        Group indicator; predictions are split by ``a == 1`` / ``a == 0``.
    metrics : dict
        Maps a metric name to a callable ``metric(y_hat_1, y_hat_0)`` that
        returns a tensor.

    Attributes
    ----------
    MSEs : list
        One test MSE per call.
    results : dict
        ``results[name]`` holds one value of metric ``name`` per call.
    """

    def __init__(self, X, y1, y2, a, metrics):
        self.X = X
        self.y1 = y1
        self.y2 = y2
        self.a = a

        self.metrics = metrics
        self.results = {k: [] for k in metrics.keys()}
        self.MSEs = []

    def test(self, model, regime):
        """Evaluate ``model`` once and append the MSE and every metric to the history."""
        y = self.y1 if regime == 1 else self.y2
        # Evaluation never backpropagates (every value is detached below), so
        # skip building the autograd graph.
        with torch.no_grad():
            y_hat = model(self.X)
            MSE = ((y.flatten() - y_hat.flatten())**2).mean()
            # `[()]` turns the 0-d numpy array into a numpy scalar.
            self.MSEs.append(MSE.detach().numpy()[()])

            y_hat_1 = y_hat[self.a == 1]
            y_hat_0 = y_hat[self.a == 0]
            for m in self.metrics.keys():
                self.results[m].append(self.metrics[m](y_hat_1, y_hat_0).detach().numpy()[()])

    def __call__(self, model, regime):
        """Alias for ``test`` so an instance can be passed as a plain callable."""
        self.test(model, regime)

In [4]:
# Input dimension: k-1 random features plus the group indicator `a` as the last column.
k = 10
# Number of samples streamed to the trainers in one experiment run.
N_iter = 4000

In [5]:
# Seed PyTorch's global RNG so the random draws in the following cells are reproducible on a fresh kernel.
torch.manual_seed(0)

<torch._C.Generator at 0x1909babedd0>

In [6]:
# Two independent (k, 5) coefficient matrices, one per regime, with entries
# drawn uniformly from [-2, 2).
slopes1 = torch.rand((k, 5)) * 4 - 2
slopes2 = torch.rand((k, 5)) * 4 - 2

In [7]:
# Scratch check of the "randint == 0" indicator construction used for `a`.
# NOTE(review): this draws from the global RNG, so removing or re-running this
# cell changes every unseeded random draw that follows it.
(torch.randint(10,(1,))==0).double()

tensor([0.], dtype=torch.float64)

In [8]:
# `a` is set to 1 when torch.randint(pa_inverse, ...) returns 0, i.e. with probability 1/pa_inverse.
pa_inverse=50

In [9]:
def get_sample(k, regime):
    """Draw one training sample.

    Returns ``(X, y, a)`` where ``X`` has shape (1, k): k-1 uniform features
    followed by the group indicator ``a`` (shape (1,)). ``y`` is the largest
    entry of ``X @ slopes1`` when ``regime == 1`` and of ``X @ slopes2``
    otherwise.
    """
    features = torch.rand((1, k - 1))
    a = (torch.randint(pa_inverse, (1,)) == 0).float()
    # The indicator becomes the last input column.
    X = torch.cat((features, a.unsqueeze(1)), dim=1)
    slopes = slopes1 if regime == 1 else slopes2
    return X, (X @ slopes).max(), a

In [10]:
def build_test_set(k, N_test=1000):
    """Draw a test set of ``N_test`` samples.

    Returns ``(X, y1, y2, a)``: ``X`` has shape (N_test, k) with the group
    indicator ``a`` as its last column; ``y1`` / ``y2`` are the row-wise maxima
    of ``X @ slopes1`` / ``X @ slopes2``.
    """
    features = torch.rand((N_test, k - 1))
    a = (torch.randint(pa_inverse, (N_test,)) == 0).float()
    X = torch.cat((features, a.unsqueeze(1)), dim=1)
    y1 = torch.max(X @ slopes1, dim=1).values
    y2 = torch.max(X @ slopes2, dim=1).values
    return X, y1, y2, a

In [11]:
class NeuralNetwork(nn.Module):
    """One-hidden-layer regressor: Linear(k, 20) -> ReLU -> Linear(20, 1)."""

    def __init__(self, k):
        super().__init__()
        self.linear1 = nn.Linear(k, 20, bias=True)
        self.linear2 = nn.Linear(20, 1, bias=True)

    def forward(self, x):
        """Return the prediction for ``x``; the result is also kept in ``self.output``."""
        hidden = F.relu(self.linear1(x))
        self.output = self.linear2(hidden)
        return self.output

In [12]:
def mse(yhat, y):
    """Mean squared error between ``yhat`` and ``y`` after flattening both to 1-D."""
    residual = yhat.reshape(-1) - y.reshape(-1)
    return residual.pow(2).mean()

In [13]:
# Fairness metrics recorded by Tester; each value is called as
# metric(y_hat_1, y_hat_0) with the predictions of the a == 1 / a == 0 groups.
metrics = {
    # NOTE(review): the meaning of the two trailing None arguments is defined in
    # fairness_metrics.statistical_parity and is not visible here -- confirm.
    'SPD': lambda y1, y2: fairness_metrics.statistical_parity(y1.flatten(), y2.flatten(), None, None),
    'ED': fairness_metrics.energy_distance,
    'WD': fairness_metrics.wasserstein_distance
}

In [14]:
def test(N=4, seed=0):
    """Run one streaming experiment with three trainers fed the same samples.

    Parameters
    ----------
    N : int
        Minimum buffer size passed to every trainer.
    seed : int
        Seed set before each network is created, so all three start from
        identical weights.

    Returns
    -------
    tuple
        The three Tester objects (plain ED, debiased, full-bias variant), each
        holding the per-step evaluation history of its trainer.
    """
    # NOTE(review): the test set is drawn BEFORE any torch.manual_seed(seed)
    # call, so it depends on the global RNG state at call time, not on `seed`.
    test_set = build_test_set(k)
    tester_ED = Tester(*test_set, metrics)
    tester_ED_db = Tester(*test_set, metrics)
    tester_ED_fullbias = Tester(*test_set, metrics)

    # Re-seed before each model so the three networks are initialised identically.
    torch.manual_seed(seed)
    trainer_ED = Trainer(NeuralNetwork(k), mse, fairness_metrics.energy_distance, N, [2,2], tester_ED)
    torch.manual_seed(seed)
    trainer_ED_db = TrainerDebiased(NeuralNetwork(k), mse, fairness_metrics.energy_distance, N, [2,2], tester_ED_db)
    torch.manual_seed(seed)
    # NOTE(review): the variable name suggests a fully biased baseline, yet this
    # uses TrainerDebiased (only the fairness term is the biased variant) --
    # confirm this is intended.
    trainer_ED_fullbias = TrainerDebiased(NeuralNetwork(k), mse, fairness_metrics.energy_distance_biased, N, [2,2], tester_ED_fullbias)

    # Integer midpoint: comparing against the float N_iter/2 would never match
    # for an odd N_iter, and the regime would silently never switch.
    switch_at = N_iter // 2
    regime = 1
    for i in tqdm(range(N_iter)):
        sample = get_sample(k, regime)
        trainer_ED.update(*sample, regime)
        trainer_ED_db.update(*sample, regime)
        trainer_ED_fullbias.update(*sample, regime)
        # Samples drawn after iteration `switch_at` come from regime 2.
        if i == switch_at:
            regime = 2

    return (trainer_ED.tester, trainer_ED_db.tester, trainer_ED_fullbias.tester)

In [None]:
# One list of Tester histories per trainer variant; entry j comes from seed j.
# Results are appended run by run, so an interrupted loop keeps what has finished.
EDs, dbs, fullbiases, dbbigns = [], [], [], []
for i in range(5):
    ED, db, fullbias = test(seed=i)
    for history, tester in ((EDs, ED), (dbs, db), (fullbiases, fullbias)):
        history.append(tester)

100%|██████████████████████████████████████████████████████████████████████████████| 4000/4000 [09:09<00:00,  7.28it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 4000/4000 [09:01<00:00,  7.38it/s]
 46%|███████████████████████████████████▋                                          | 1831/4000 [04:18<05:00,  7.22it/s]

In [None]:
import seaborn as sns
import pandas as pd
import numpy as np

In [None]:

from matplotlib import rcParams
# Use a serif face for all figure text. With font.family set to 'serif',
# matplotlib picks the font from the 'font.serif' list, so Times has to be
# registered there (the 'font.sans-serif' list is not consulted for this family).
rcParams['font.family'] = 'serif'
# Put Times first and keep the existing fallbacks; filtering makes re-running the cell idempotent.
rcParams['font.serif'] = ['Times'] + [name for name in rcParams['font.serif'] if name != 'Times']

In [None]:
import pickle

In [None]:
import os

# np.save does not create missing parent directories, so make sure the output
# folder exists before writing (no-op when it is already there).
os.makedirs('dumps', exist_ok=True)
# Persist the regime coefficient matrices.
np.save('dumps/slope1.npy', slopes1.numpy())
np.save('dumps/slope2.npy', slopes2.numpy())

In [None]:
# Explicitly disable LaTeX text rendering for the figures below.
plt.rcParams['text.usetex'] = False

In [None]:
# For every trainer variant, plot the total test loss (MSE + 2 * ED) over time.
# The factor 2 matches the trainers' default `regularizer`. Each run is one
# column of the wide frame; melting to long form lets seaborn aggregate runs.
for testers, label in (
    (EDs, 'Debiased ED'),
    (dbs, 'Full Debias'),
    (fullbiases, 'Biased Loss and ED'),
):
    totals = [np.array(tester.MSEs) + 2*np.array(tester.results['ED']) for tester in testers]
    wide = pd.DataFrame(totals).T.reset_index()
    df = pd.melt(wide, id_vars='index')
    df.columns = ['Time horizon', 'b', 'Loss']
    sns.lineplot(data=df, x="Time horizon", y="Loss", label=label)


plt.yscale('log')
#plt.savefig('loss-comparison-bias.pdf')