In [1]:
%load_ext autoreload
%autoreload 2

import sys, os
import numpy as np
import math

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.utils.data as data

import math
import matplotlib.pyplot as plt

from argparse import Namespace

from fr_train import Generator, DiscriminatorF, DiscriminatorR, weights_init_normal, test_model
import pickle
import warnings
warnings.filterwarnings("ignore")

from aif360.datasets import AdultDataset, GermanDataset, BankDataset, CompasDataset, BinaryLabelDataset, CelebADataset
from aif360.metrics import ClassificationMetric
import pandas

from sklearn.preprocessing import scale, StandardScaler, MaxAbsScaler

min_max_scaler = MaxAbsScaler()
std_scaler = StandardScaler()

In [2]:
def train_model(train_tensors, val_tensors, test_tensors, train_opt, lambda_f, lambda_r, seed):
    """
      Trains FR-Train by using the classes in FRTrain_arch.py.
      
      Args:
        train_tensors: Training data.
        val_tensors: Clean validation data.
        test_tensors: Test data.
        train_opt: Options for the training. It currently contains size of validation set, 
                number of epochs, generator/discriminator update ratio, and learning rates.
        lambda_f: The tuning knob for L_2 (ref: FR-Train paper, Section 3.3).
        lambda_r: The tuning knob for L_3 (ref: FR-Train paper, Section 3.3).
        seed: An integer value for specifying torch random seed.
        
      Returns:
        Information about the tuning knobs (lambda_f, lambda_r),
        the test accuracy of the trained model, and disparate impact of the trained model.
    """
    
    XS_train = train_tensors.XS_train
    y_train = train_tensors.y_train
    s1_train = train_tensors.s1_train
    
    XS_val = val_tensors.XS_val
    y_val = val_tensors.y_val
    s1_val = val_tensors.s1_val
    
    XS_test = test_tensors.XS_test
    y_test = test_tensors.y_test
    s1_test = test_tensors.s1_test
    
    # Saves return values here
    test_result = [] 
    
    val = train_opt.val # Number of data points in validation set
    k = train_opt.k     # Update ratio of generator and discriminator (1:k training).
    n_epochs = train_opt.n_epochs  # Number of training epoch
    
    # Changes the input validation data to an appropriate shape for the training
    XSY_val = torch.cat([XS_val, y_val.reshape((y_val.shape[0], 1))], dim=1)  

    # The loss values of each component will be saved in the following lists. 
    # We can draw epoch-loss graph by the following lists, if necessary.
    g_losses =[]
    d_f_losses = []
    d_r_losses = []
    clean_test_result = []

    bce_loss = torch.nn.BCELoss()

    # Initializes generator and discriminator
    generator = Generator(input_size, latent_size)
    discriminator_F = DiscriminatorF()
    discriminator_R = DiscriminatorR(input_size, latent_size)

    # Initializes weights
    torch.manual_seed(seed)
    generator.apply(weights_init_normal)
    discriminator_F.apply(weights_init_normal)
    discriminator_R.apply(weights_init_normal)

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=train_opt.lr_g)
    optimizer_D_F = torch.optim.SGD(discriminator_F.parameters(), lr=train_opt.lr_f)
    optimizer_D_R = torch.optim.SGD(discriminator_R.parameters(), lr=train_opt.lr_r)

    XSY_val_data = XSY_val[:val]

    train_len = XS_train.shape[0]
    val_len = XSY_val.shape[0]

    # Ground truths using in Disriminator_R
    Tensor = torch.FloatTensor
    valid = Variable(Tensor(train_len, 1).fill_(1.0), requires_grad=False)
    generated = Variable(Tensor(train_len, 1).fill_(0.0), requires_grad=False)
    fake = Variable(Tensor(train_len, 1).fill_(0.0), requires_grad=False)
    clean = Variable(Tensor(val_len, 1).fill_(1.0), requires_grad=False)
    

    r_weight = torch.ones_like(y_train, requires_grad=False).float()
    r_ones = torch.ones_like(y_train, requires_grad=False).float()

    for epoch in range(n_epochs):

        # -------------------
        #  Forwards Generator
        # -------------------
        if epoch % k == 0 or epoch < 500:
            optimizer_G.zero_grad()

        gen_y = generator(XS_train)
        gen_data = torch.cat([XS_train, gen_y.reshape((gen_y.shape[0], 1))], dim=1)


        # -------------------------------
        #  Trains Fairness Discriminator
        # -------------------------------

        optimizer_D_F.zero_grad()
        
        # Discriminator_F tries to distinguish the sensitive groups by using the output of the generator.
        d_f_loss = bce_loss(discriminator_F(gen_y.detach()), s1_train)
        d_f_loss.backward()
        d_f_losses.append(d_f_loss)
        optimizer_D_F.step()
            
            
        # ---------------------------------
        #  Trains Robustness Discriminator
        # ---------------------------------

        optimizer_D_R.zero_grad()

        # Discriminator_R tries to distinguish whether the input is from the validation data or the generated data from generator.
        clean_loss =  bce_loss(discriminator_R(XSY_val_data), clean)
        poison_loss = bce_loss(discriminator_R(gen_data.detach()), fake)
        d_r_loss = 0.5 * (clean_loss + poison_loss)

        d_r_loss.backward()
        d_r_losses.append(d_r_loss)
        optimizer_D_R.step()

        
        # ---------------------
        #  Updates Generator
        # ---------------------


        if epoch < 500 :
            g_loss = 0.1 * bce_loss((F.tanh(gen_y)+1)/2, (y_train+1)/2)
            g_loss.backward()
            g_losses.append(g_loss)
            optimizer_G.step()
        elif epoch % k == 0:
            r_decision = discriminator_R(gen_data)
            r_gen = bce_loss(r_decision, generated)
            
            # ---------------------------------
            #  Re-weights using output of D_R
            # ---------------------------------
            if epoch % 100 == 0:
                loss_ratio = (g_losses[-1]/d_r_losses[-1]).detach()
                a = 1/(1+torch.exp(-(loss_ratio-3)))
                b = 1-a
                r_weight_tmp = r_decision.detach().squeeze()
                r_weight = a * r_weight_tmp + b * r_ones

            f_cost = F.binary_cross_entropy(discriminator_F(gen_y), s1_train, reduction="none").squeeze()
            g_cost = F.binary_cross_entropy_with_logits(gen_y.squeeze(), (y_train.squeeze()+1)/2, reduction="none").squeeze()

            f_gen = torch.mean(f_cost*r_weight)
            g_loss = (1-lambda_f-lambda_r) * torch.mean(g_cost*r_weight) - lambda_f * f_gen -  lambda_r * r_gen 

            g_loss.backward()
            optimizer_G.step()


        g_losses.append(g_loss)

        if epoch % 200 == 0:
            print(
                    "[Lambda: %1f] [Epoch %d/%d] [D_F loss: %f] [D_R loss: %f] [G loss: %f]"
                    % (lambda_f, epoch, n_epochs, d_f_losses[-1], d_r_losses[-1], g_losses[-1])
                )

#     torch.save(generator.state_dict(), './FR-Train_on_poi_synthetic.pth')
    tmp = test_model(generator, XS_test, y_test, s1_test)
    test_result.append([lambda_f, lambda_r, tmp[0].item(), tmp[1]])
    
    
    
    # TEST
    tp_priv, tn_priv, fp_priv, fn_priv, \
    tp_unpriv, tn_unpriv, fp_unpriv, fn_unpriv = 0, 0, 0, 0, 0, 0, 0, 0

    gen_y = generator(XS_train)
    
    priv_idx = (s1_test==1).squeeze()
    positive_idx = y_val==1

    latent_val = (H(x_val))
    pred_test = W(latent_val)

    h_priv = (H(x_val[priv_idx]))
    h_unpriv = (H(x_val[~priv_idx]))

    h_positive = (H(x_val[positive_idx]))

    test_lb_priv = y_val[priv_idx]
    test_lb_unpriv = y_val[~priv_idx]

    pred_priv = W(h_priv)
    pred_unpriv = W(h_unpriv)

    y_val = y_val.cpu().detach().numpy()
    test_lb_priv = test_lb_priv.cpu().detach().numpy()
    test_lb_unpriv = test_lb_unpriv.cpu().detach().numpy()

    try:
        pred_priv = pred_priv.argmax(1)
    except:
        pass
    try:
        pred_unpriv = pred_unpriv.argmax(1)
    except:
        pass


    tp_priv += sum(pred_priv[test_lb_priv == 1] == 1)
    fp_priv += sum(pred_priv[test_lb_priv == 0] == 1)
    tn_priv += sum(pred_priv[test_lb_priv == 0] == 0)
    fn_priv += sum(pred_priv[test_lb_priv == 1] == 0)

    tp_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 1)
    fp_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 1)
    tn_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 0)
    fn_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 0)

    tpr_overall = (tp_priv + tp_unpriv)/(tp_priv + tp_unpriv + fn_priv + fn_unpriv).float().item()
    tpr_priv = (tp_unpriv)/(tp_unpriv + fn_unpriv).float().item()
    tpr_unpriv = (tp_priv)/(tp_priv + fn_priv).float().item()

    fpr_overall = (fp_priv + fp_unpriv)/(tn_priv + tn_unpriv + fp_priv + fp_unpriv).float().item()
    fpr_unpriv = (fp_unpriv)/(tn_unpriv + fp_unpriv).float().item()
    fpr_priv = (fp_priv)/(tn_priv + fp_priv).float().item()

    acc_overall = (tp_priv + tn_priv + tp_unpriv + tn_unpriv)/(tp_priv + tn_priv + tp_unpriv + tn_unpriv + \
                                                              fp_priv + fn_priv + fp_unpriv + fn_unpriv).float().item()
    acc_priv = (tp_priv + tn_priv)/(tp_priv + tn_priv + fp_priv + fn_priv).float().item()
    acc_unpriv = (tp_unpriv + tn_unpriv)/(tp_unpriv + tn_unpriv + fp_unpriv + fn_unpriv).float().item()


    print()
    print('overall TPR : {0:.3f}'.format( tpr_overall))
    print('priv TPR : {0:.3f}'.format( tpr_priv))
    print('unpriv TPR : {0:.3f}'.format( tpr_unpriv))
    print('Eq. Opp : {0:.3f}'.format( abs(tpr_unpriv - tpr_priv)))
    print()
    print('overall FPR : {0:.3f}'.format( fpr_overall))
    print('priv FPR : {0:.3f}'.format( fpr_priv))
    print('unpriv FPR : {0:.3f}'.format( fpr_unpriv))
    print('diff FPR : {0:.3f}'.format( abs(fpr_unpriv-fpr_priv)))
    print()
    print('overall ACC : {0:.3f}'.format( acc_overall))
    print('priv ACC : {0:.3f}'.format( acc_priv))
    print('unpriv ACC : {0:.3f}'.format( acc_unpriv)) 
    print('diff ACC : {0:.3f}\n\n\n'.format( abs(acc_unpriv-acc_priv)))

    valid_pred = data_valid.copy(deepcopy=True)
    feature_size = valid_pred.features.shape[1]
    sens_loc = np.zeros(feature_size).astype(bool)
    sens_loc[sens_idx] = 1

    feature = valid_pred.features[:,~sens_loc] #data without sensitive
    feature = min_max_scaler.fit_transform(feature)

    valid_pred.labels = W((H(torch.tensor(feature).to(device)))).argmax(-1).cpu().numpy().reshape(-1,1)

    classified_metric = ClassificationMetric(data_valid,
                                                     valid_pred,
                                                     unprivileged_groups=unprivileged_groups,
                                                     privileged_groups=privileged_groups)


    print('balanced acc :' ,1/2*(classified_metric.true_positive_rate() + classified_metric.true_negative_rate()))
    print('disparate_impact :' ,classified_metric.disparate_impact())
    print('theil_index :' ,classified_metric.theil_index())
    print('statistical_parity_difference :' ,classified_metric.statistical_parity_difference())
    print('generalized_entropy_index : ', classified_metric.generalized_entropy_index())

    return test_result

In [3]:
np.random.seed(0)
#dataset = {'adult' : AdultDataset(), 'german' : GermanDataset(),'bank': BankDataset(),'compas' : CompasDataset(),'celeb': CelebADataset()}
dataset = {'adult' : AdultDataset(), 'german' : GermanDataset(),'bank': BankDataset(),'compas' : CompasDataset()}
dataset['german'].labels -= 1
sens_attr_dict = {'adult': ['sex', 'race'], 'german' : ['sex', 'age'], 'compas' : ['sex', 'race'], 'bank' : ['age'], 'celeb' : ['gender']}




In [5]:
data_name = 'adult'
protected_attribute_used = 1

In [6]:
if data_name == "adult":
    #dataset_orig = AdultDataset()
    # #dataset_orig = load_preproc_data_adult()
    if protected_attribute_used == 1:
        privileged_groups = [{'sex': 1}]
        unprivileged_groups = [{'sex': 0}]
        sens_attr = 'sex'
    else:
        privileged_groups = [{'race': 1}]
        unprivileged_groups = [{'race': 0}]
        sens_attr = 'race'

elif data_name == "german":
    #dataset_orig = GermanDataset()
    # #dataset_orig.labels = #dataset_orig.labels-1
    if protected_attribute_used == 1:
        privileged_groups = [{'sex': 1}]
        unprivileged_groups = [{'sex': 0}]
        sens_attr = 'sex'
    else:
        privileged_groups = [{'age': 1}]
        unprivileged_groups = [{'age': 0}]
        sens_attr = 'age'

elif data_name == "compas":
    #dataset_orig = CompasDataset()
    # #dataset_orig = load_preproc_data_compas()
    if protected_attribute_used == 1:
        privileged_groups = [{'sex': 1}]
        unprivileged_groups = [{'sex': 0}]
        sens_attr = 'sex'
    else:
        privileged_groups = [{'race': 1}]
        unprivileged_groups = [{'race': 0}]
        sens_attr = 'race'

elif data_name == "bank":
    #dataset_orig = BankDataset()
    if protected_attribute_used == 1:
        privileged_groups = [{'age': 1}]
        unprivileged_groups = [{'age': 0}]
        sens_attr = 'age'
    else:
        privileged_groups = [{'age': 1}]
        unprivileged_groups = [{'age': 0}]
        sens_attr = 'age'
        
elif data_name == "meps":
    #dataset_orig = MEPSDataset19()
    privileged_groups = [{'RACE': 1}]
    unprivileged_groups = [{'RACE': 0}]
    sens_attr = 'RACE'

In [7]:
min_max_scaler = MaxAbsScaler()
std_scaler = StandardScaler()

In [8]:
sens_idx = dataset[data_name].feature_names.index(sens_attr)
num_sens = len(np.unique(dataset[data_name].features[:, sens_idx]))

#bs = 4000
workers = 16
device = 'cuda:2'
os.environ['CUDA_VISIBLE_DEVICES']='2'
hist = {'loss':[], 'acc':[], 'val_loss':[], 'val_acc':[]}


feature_size = dataset[data_name].features.shape[1]
sens_loc = np.zeros(feature_size).astype(bool)
sens_loc[sens_idx] = 1

feature = dataset[data_name].features[:,~sens_loc] #data without sensitive
feature = min_max_scaler.fit_transform(feature)
# feature = std_scaler.fit_transform(feature)
dataset[data_name].features[:,~sens_loc] = feature

#self.sensitive = dataset[data_name].features[:,sens_loc].reshape(-1).astype(int)
#n_values = int(np.max(self.label) + 1)
#self.label = np.eye(n_values)[self.label.astype(int)].squeeze(1)
# dataset[data_name], _ = dataset[data_name].split([0.5], shuffle=True)
data_train, data_vt = dataset[data_name].split([0.7], shuffle=True)
data_valid, data_test = data_vt.split([0.5], shuffle=True)

In [9]:
sens_attr

'sex'

In [10]:
class Dataset(data.Dataset):
    def __init__(self, X, s, y):
        self.label = y
        self.feature = torch.cat([X, s], dim = 1)
        self.sensitive = s
        
    def __getitem__(self, idx):
        y = self.label[idx]
        x = self.feature[idx]
        a = self.sensitive[idx]
        
        return x, a, y
    
    def __len__(self):
        return len(self.label)
    


#### Poisoning

In [11]:
idx = np.where(data_train.features[:, sens_idx] == 1)[0]
poison_idx = np.random.permutation(idx)[:int(len(idx)/10)]

#label poisoning
data_train.labels[poison_idx] = 1 - data_train.labels[poison_idx]

In [12]:
# a namespace object which contains some of the hyperparameters
opt = Namespace(num_train=2000, num_val1=200, num_val2=500, num_test=1000)

# num_train = opt.num_train
# num_val1 = opt.num_val1
# num_val2 = opt.num_val2
# num_test = opt.num_test

# X = np.load('X_synthetic.npy') # Input features
# y = np.load('y_synthetic.npy') # Original labels
# y_poi = np.load('y_poi.npy') # Poisoned train labels
# s1 = np.load('s1_synthetic.npy') # Sensitive features



In [13]:
X_train = torch.FloatTensor(data_train.features[:, ~sens_loc])
y_train = torch.FloatTensor(data_train.labels)
s1_train = torch.FloatTensor(data_train.features[:, sens_idx]).view(-1,1)

X_val = torch.FloatTensor(data_valid.features[:, ~sens_loc])
y_val = torch.FloatTensor(data_valid.labels)
s1_val = torch.FloatTensor(data_valid.features[:, sens_idx]).view(-1,1)

X_test = torch.FloatTensor(data_test.features[:, ~sens_loc])
y_test = torch.FloatTensor(data_test.labels)
s1_test = torch.FloatTensor(data_test.features[:, sens_idx]).view(-1,1)

In [14]:
XS_train = torch.cat([X_train, s1_train.reshape((s1_train.shape[0], 1))], dim=1)
XS_val = torch.cat([X_val, s1_val.reshape((s1_val.shape[0], 1))], dim=1)
XS_test = torch.cat([X_test, s1_test.reshape((s1_test.shape[0], 1))], dim=1)

In [15]:
data_train = Dataset(X_train, s1_train, y_train)
data_valid = Dataset(X_val, s1_val, y_val)
data_test = Dataset(X_test, s1_test, y_test)

trainloader = torch.utils.data.DataLoader(
    data_train,
    batch_size=128,
    shuffle=True,
    )
validloader = torch.utils.data.DataLoader(
    data_valid,
    batch_size=128,
    shuffle=True,
    drop_last = True
    )
testloader = torch.utils.data.DataLoader(
    data_test,
    batch_size=128,
    shuffle=True,
    drop_last = True
    )

In [16]:
input_size = XS_train.shape[-1]
latent_size = 512
hidden_size = 512

In [37]:
#COMPAS setup
results_dict = {}

for poi_ratio in np.linspace(0.1, 0.5, 5):
    results_dict[poi_ratio] = {}
    results_dict[poi_ratio]['EqOdds'] =[]
    results_dict[poi_ratio]['Acc'] = []
    results_dict[poi_ratio]['DISP'] = []
    results_dict[poi_ratio]['EqOpp']= []
    results_dict[poi_ratio]['Acc_diff'] = []

    results_dict[poi_ratio]['total_tpr'] =[]
    results_dict[poi_ratio]['total_fpr'] =[]

    results_dict[poi_ratio]['total_tpr_priv'] =[]
    results_dict[poi_ratio]['total_tpr_unpriv'] =[]
    results_dict[poi_ratio]['total_fpr_priv'] =[]
    results_dict[poi_ratio]['total_fpr_unpriv'] =[]
    
    for repeat in range(5):
        data_train, data_vt = dataset[data_name].split([0.7], shuffle=True)
        data_valid, data_test = data_vt.split([0.5], shuffle=True)

        idx = np.where(data_train.features[:, sens_idx] == 1)[0]
        poison_idx = np.random.permutation(idx)[:int(len(idx) * poi_ratio)]

        #label poisoning
        data_train.labels[poison_idx] = 1 - data_train.labels[poison_idx]

        X_train = torch.FloatTensor(data_train.features[:, ~sens_loc])
        y_train = torch.FloatTensor(data_train.labels)
        s1_train = torch.FloatTensor(data_train.features[:, sens_idx]).view(-1,1)

        X_val = torch.FloatTensor(data_valid.features[:, ~sens_loc])
        y_val = torch.FloatTensor(data_valid.labels)
        s1_val = torch.FloatTensor(data_valid.features[:, sens_idx]).view(-1,1)

        X_test = torch.FloatTensor(data_test.features[:, ~sens_loc])
        y_test = torch.FloatTensor(data_test.labels)
        s1_test = torch.FloatTensor(data_test.features[:, sens_idx]).view(-1,1)

        XS_train = torch.cat([X_train, s1_train.reshape((s1_train.shape[0], 1))], dim=1)
        XS_val = torch.cat([X_val, s1_val.reshape((s1_val.shape[0], 1))], dim=1)
        XS_test = torch.cat([X_test, s1_test.reshape((s1_test.shape[0], 1))], dim=1)

        train_result = []
        train_tensors = Namespace(XS_train = XS_train, y_train = y_train, s1_train = s1_train)
        val_tensors = Namespace(XS_val = XS_val, y_val = y_val, s1_val = s1_val) 
        test_tensors = Namespace(XS_test = XS_test, y_test = y_test, s1_test = s1_test)

        train_opt = Namespace(val=len(y_val), n_epochs=50, k=1, lr_g=1e-3, lr_f=1e-4, lr_r=1e-4)
        seed = 1

        lambda_f_set = [0.1] # Lambda value for the fairness discriminator of FR-Train.
        lambda_r = 0.1 # Lambda value for the robustness discriminator of FR-Train.

        data_train = Dataset(X_train, s1_train, y_train)
        data_valid = Dataset(X_val, s1_val, y_val)
        data_test = Dataset(X_test, s1_test, y_test)

        trainloader = torch.utils.data.DataLoader(
            data_train,
            batch_size=128,
            shuffle=True,
            )
        validloader = torch.utils.data.DataLoader(
            data_valid,
            batch_size=128,
            shuffle=True,
            drop_last = True
            )
        testloader = torch.utils.data.DataLoader(
            data_test,
            batch_size=128,
            shuffle=True,
            drop_last = True
            )


        for lambda_f in lambda_f_set:
            train_model_adult(results_dict[poi_ratio], train_tensors, val_tensors, test_tensors, train_opt, lambda_f = lambda_f, lambda_r = lambda_r, seed = seed)



[Lambda: 0.100000] [Epoch 10/50] [D_F loss: 0.579390] [D_R loss: 0.224440] [G loss: 31.413197]
VALID DATA
overall TPR : 0.005
priv TPR : 0.000
unpriv TPR : 0.017
Eq. Opp : 0.017

overall FPR : 0.010
priv FPR : 0.021
unpriv FPR : 0.003
diff FPR : 0.018

overall ACC : 0.533
priv ACC : 0.608
unpriv ACC : 0.495
diff ACC : 0.113



TEST DATA
overall TPR : 0.010
priv TPR : 0.003
unpriv TPR : 0.031
Eq. Opp : 0.027

overall FPR : 0.010
priv FPR : 0.021
unpriv FPR : 0.003
diff FPR : 0.018

overall ACC : 0.552
priv ACC : 0.655
unpriv ACC : 0.504
diff ACC : 0.151



DIMP : 0.135
[Lambda: 0.100000] [Epoch 20/50] [D_F loss: 0.649528] [D_R loss: 0.812184] [G loss: 3.676873]
VALID DATA
overall TPR : 0.494
priv TPR : 0.532
unpriv TPR : 0.397
Eq. Opp : 0.136

overall FPR : 0.313
priv FPR : 0.291
unpriv FPR : 0.328
diff FPR : 0.037

overall ACC : 0.598
priv ACC : 0.590
unpriv ACC : 0.602
diff ACC : 0.012



TEST DATA
overall TPR : 0.519
priv TPR : 0.545
unpriv TPR : 0.440
Eq. Opp : 0.105

overall FPR : 

overall TPR : 0.532
priv TPR : 0.564
unpriv TPR : 0.454
Eq. Opp : 0.111

overall FPR : 0.237
priv FPR : 0.233
unpriv FPR : 0.239
diff FPR : 0.006

overall ACC : 0.658
priv ACC : 0.646
unpriv ACC : 0.665
diff ACC : 0.019



DIMP : 0.800
[Lambda: 0.100000] [Epoch 10/50] [D_F loss: 0.646568] [D_R loss: 0.375611] [G loss: 9.439654]
VALID DATA
overall TPR : 0.323
priv TPR : 0.286
unpriv TPR : 0.403
Eq. Opp : 0.117

overall FPR : 0.244
priv FPR : 0.379
unpriv FPR : 0.163
diff FPR : 0.216

overall ACC : 0.551
priv ACC : 0.527
unpriv ACC : 0.564
diff ACC : 0.037



TEST DATA
overall TPR : 0.329
priv TPR : 0.257
unpriv TPR : 0.496
Eq. Opp : 0.239

overall FPR : 0.188
priv FPR : 0.279
unpriv FPR : 0.128
diff FPR : 0.151

overall ACC : 0.599
priv ACC : 0.636
unpriv ACC : 0.579
diff ACC : 0.057



DIMP : 0.526
[Lambda: 0.100000] [Epoch 20/50] [D_F loss: 0.693814] [D_R loss: 1.033258] [G loss: 0.202464]
VALID DATA
overall TPR : 0.703
priv TPR : 0.740
unpriv TPR : 0.622
Eq. Opp : 0.118

overall FPR 

[Lambda: 0.100000] [Epoch 50/50] [D_F loss: 0.684816] [D_R loss: 0.781322] [G loss: 0.340575]
VALID DATA
overall TPR : 0.417
priv TPR : 0.423
unpriv TPR : 0.405
Eq. Opp : 0.018

overall FPR : 0.276
priv FPR : 0.273
unpriv FPR : 0.279
diff FPR : 0.006

overall ACC : 0.592
priv ACC : 0.609
unpriv ACC : 0.581
diff ACC : 0.028



TEST DATA
overall TPR : 0.441
priv TPR : 0.465
unpriv TPR : 0.375
Eq. Opp : 0.090

overall FPR : 0.269
priv FPR : 0.220
unpriv FPR : 0.300
diff FPR : 0.079

overall ACC : 0.597
priv ACC : 0.628
unpriv ACC : 0.582
diff ACC : 0.046



DIMP : 0.727
[Lambda: 0.100000] [Epoch 10/50] [D_F loss: 0.693330] [D_R loss: 0.461384] [G loss: 13.017981]
VALID DATA
overall TPR : 0.103
priv TPR : 0.114
unpriv TPR : 0.076
Eq. Opp : 0.038

overall FPR : 0.079
priv FPR : 0.045
unpriv FPR : 0.100
diff FPR : 0.055

overall ACC : 0.532
priv ACC : 0.604
unpriv ACC : 0.497
diff ACC : 0.107



TEST DATA
overall TPR : 0.073
priv TPR : 0.091
unpriv TPR : 0.033
Eq. Opp : 0.058

overall FPR : 

[Lambda: 0.100000] [Epoch 50/50] [D_F loss: 0.688837] [D_R loss: 0.737252] [G loss: 0.701038]
VALID DATA
overall TPR : 0.525
priv TPR : 0.597
unpriv TPR : 0.347
Eq. Opp : 0.249

overall FPR : 0.301
priv FPR : 0.206
unpriv FPR : 0.361
diff FPR : 0.155

overall ACC : 0.619
priv ACC : 0.622
unpriv ACC : 0.618
diff ACC : 0.004



TEST DATA
overall TPR : 0.515
priv TPR : 0.615
unpriv TPR : 0.316
Eq. Opp : 0.299

overall FPR : 0.269
priv FPR : 0.160
unpriv FPR : 0.338
diff FPR : 0.178

overall ACC : 0.633
priv ACC : 0.620
unpriv ACC : 0.640
diff ACC : 0.019



DIMP : 0.481
[Lambda: 0.100000] [Epoch 10/50] [D_F loss: 0.623766] [D_R loss: 0.677797] [G loss: 12.309790]
VALID DATA
overall TPR : 0.153
priv TPR : 0.099
unpriv TPR : 0.285
Eq. Opp : 0.186

overall FPR : 0.215
priv FPR : 0.398
unpriv FPR : 0.099
diff FPR : 0.299

overall ACC : 0.491
priv ACC : 0.476
unpriv ACC : 0.499
diff ACC : 0.023



TEST DATA
overall TPR : 0.125
priv TPR : 0.062
unpriv TPR : 0.301
Eq. Opp : 0.238

overall FPR : 

[Lambda: 0.100000] [Epoch 50/50] [D_F loss: 0.689116] [D_R loss: 0.752594] [G loss: 0.206800]
VALID DATA
overall TPR : 0.670
priv TPR : 0.656
unpriv TPR : 0.708
Eq. Opp : 0.052

overall FPR : 0.467
priv FPR : 0.574
unpriv FPR : 0.395
diff FPR : 0.179

overall ACC : 0.598
priv ACC : 0.532
unpriv ACC : 0.632
diff ACC : 0.100



TEST DATA
overall TPR : 0.663
priv TPR : 0.628
unpriv TPR : 0.744
Eq. Opp : 0.117

overall FPR : 0.520
priv FPR : 0.660
unpriv FPR : 0.429
diff FPR : 0.231

overall ACC : 0.566
priv ACC : 0.505
unpriv ACC : 0.599
diff ACC : 0.095



DIMP : 0.762
[Lambda: 0.100000] [Epoch 10/50] [D_F loss: 0.690773] [D_R loss: 0.319238] [G loss: 32.307858]
VALID DATA
overall TPR : 0.079
priv TPR : 0.110
unpriv TPR : 0.000
Eq. Opp : 0.110

overall FPR : 0.050
priv FPR : 0.000
unpriv FPR : 0.078
diff FPR : 0.078

overall ACC : 0.541
priv ACC : 0.584
unpriv ACC : 0.521
diff ACC : 0.064



TEST DATA
overall TPR : 0.062
priv TPR : 0.091
unpriv TPR : 0.000
Eq. Opp : 0.091

overall FPR : 

[Lambda: 0.100000] [Epoch 50/50] [D_F loss: 0.691099] [D_R loss: 0.847971] [G loss: 0.450386]
VALID DATA
overall TPR : 0.603
priv TPR : 0.649
unpriv TPR : 0.500
Eq. Opp : 0.149

overall FPR : 0.404
priv FPR : 0.385
unpriv FPR : 0.414
diff FPR : 0.028

overall ACC : 0.599
priv ACC : 0.568
unpriv ACC : 0.615
diff ACC : 0.047



TEST DATA
overall TPR : 0.610
priv TPR : 0.625
unpriv TPR : 0.569
Eq. Opp : 0.057

overall FPR : 0.363
priv FPR : 0.380
unpriv FPR : 0.352
diff FPR : 0.028

overall ACC : 0.625
priv ACC : 0.601
unpriv ACC : 0.637
diff ACC : 0.036



DIMP : 0.923
[Lambda: 0.100000] [Epoch 10/50] [D_F loss: 0.606798] [D_R loss: 2.057556] [G loss: 15.360923]
VALID DATA
overall TPR : 0.213
priv TPR : 0.096
unpriv TPR : 0.518
Eq. Opp : 0.422

overall FPR : 0.201
priv FPR : 0.422
unpriv FPR : 0.081
diff FPR : 0.341

overall ACC : 0.536
priv ACC : 0.554
unpriv ACC : 0.527
diff ACC : 0.027



TEST DATA
overall TPR : 0.221
priv TPR : 0.117
unpriv TPR : 0.491
Eq. Opp : 0.375

overall FPR : 

[Lambda: 0.100000] [Epoch 50/50] [D_F loss: 0.688069] [D_R loss: 0.942228] [G loss: 0.229839]
VALID DATA
overall TPR : 0.663
priv TPR : 0.718
unpriv TPR : 0.536
Eq. Opp : 0.182

overall FPR : 0.508
priv FPR : 0.528
unpriv FPR : 0.495
diff FPR : 0.033

overall ACC : 0.571
priv ACC : 0.497
unpriv ACC : 0.613
diff ACC : 0.116



TEST DATA
overall TPR : 0.690
priv TPR : 0.742
unpriv TPR : 0.564
Eq. Opp : 0.178

overall FPR : 0.546
priv FPR : 0.546
unpriv FPR : 0.547
diff FPR : 0.001

overall ACC : 0.559
priv ACC : 0.497
unpriv ACC : 0.591
diff ACC : 0.094



DIMP : 0.864
[Lambda: 0.100000] [Epoch 10/50] [D_F loss: 0.652719] [D_R loss: 0.567049] [G loss: 0.188373]
VALID DATA
overall TPR : 0.273
priv TPR : 0.266
unpriv TPR : 0.294
Eq. Opp : 0.028

overall FPR : 0.130
priv FPR : 0.156
unpriv FPR : 0.115
diff FPR : 0.041

overall ACC : 0.587
priv ACC : 0.627
unpriv ACC : 0.569
diff ACC : 0.057



TEST DATA
overall TPR : 0.266
priv TPR : 0.223
unpriv TPR : 0.354
Eq. Opp : 0.131

overall FPR : 0

[Lambda: 0.100000] [Epoch 50/50] [D_F loss: 0.688653] [D_R loss: 0.860867] [G loss: 0.214796]
VALID DATA
overall TPR : 0.490
priv TPR : 0.492
unpriv TPR : 0.486
Eq. Opp : 0.007

overall FPR : 0.390
priv FPR : 0.495
unpriv FPR : 0.326
diff FPR : 0.169

overall ACC : 0.556
priv ACC : 0.497
unpriv ACC : 0.590
diff ACC : 0.093



TEST DATA
overall TPR : 0.477
priv TPR : 0.480
unpriv TPR : 0.469
Eq. Opp : 0.011

overall FPR : 0.328
priv FPR : 0.438
unpriv FPR : 0.264
diff FPR : 0.174

overall ACC : 0.581
priv ACC : 0.526
unpriv ACC : 0.608
diff ACC : 0.082



DIMP : 0.826
[Lambda: 0.100000] [Epoch 10/50] [D_F loss: 0.691020] [D_R loss: 0.208459] [G loss: 30.432570]
VALID DATA
overall TPR : 0.002
priv TPR : 0.003
unpriv TPR : 0.000
Eq. Opp : 0.003

overall FPR : 0.006
priv FPR : 0.000
unpriv FPR : 0.010
diff FPR : 0.010

overall ACC : 0.537
priv ACC : 0.626
unpriv ACC : 0.492
diff ACC : 0.134



TEST DATA
overall TPR : 0.002
priv TPR : 0.003
unpriv TPR : 0.000
Eq. Opp : 0.003

overall FPR : 

[Lambda: 0.100000] [Epoch 50/50] [D_F loss: 0.688557] [D_R loss: 0.901415] [G loss: 0.533223]
VALID DATA
overall TPR : 0.654
priv TPR : 0.645
unpriv TPR : 0.675
Eq. Opp : 0.030

overall FPR : 0.553
priv FPR : 0.618
unpriv FPR : 0.510
diff FPR : 0.107

overall ACC : 0.544
priv ACC : 0.495
unpriv ACC : 0.569
diff ACC : 0.074



TEST DATA
overall TPR : 0.693
priv TPR : 0.689
unpriv TPR : 0.702
Eq. Opp : 0.012

overall FPR : 0.480
priv FPR : 0.563
unpriv FPR : 0.434
diff FPR : 0.129

overall ACC : 0.598
priv ACC : 0.547
unpriv ACC : 0.624
diff ACC : 0.077



DIMP : 0.892
[Lambda: 0.100000] [Epoch 10/50] [D_F loss: 0.610762] [D_R loss: 1.021801] [G loss: 16.013720]
VALID DATA
overall TPR : 0.247
priv TPR : 0.226
unpriv TPR : 0.299
Eq. Opp : 0.073

overall FPR : 0.224
priv FPR : 0.310
unpriv FPR : 0.170
diff FPR : 0.140

overall ACC : 0.535
priv ACC : 0.539
unpriv ACC : 0.532
diff ACC : 0.007



TEST DATA
overall TPR : 0.236
priv TPR : 0.182
unpriv TPR : 0.383
Eq. Opp : 0.202

overall FPR : 

In [50]:
#Adult setup
results_dict = {}

for poi_ratio in np.linspace(0.1, 0.5, 5):
    results_dict[poi_ratio] = {}
    results_dict[poi_ratio]['EqOdds'] =[]
    results_dict[poi_ratio]['Acc'] = []
    results_dict[poi_ratio]['DISP'] = []
    results_dict[poi_ratio]['EqOpp']= []
    results_dict[poi_ratio]['Acc_diff'] = []

    results_dict[poi_ratio]['total_tpr'] =[]
    results_dict[poi_ratio]['total_fpr'] =[]

    results_dict[poi_ratio]['total_tpr_priv'] =[]
    results_dict[poi_ratio]['total_tpr_unpriv'] =[]
    results_dict[poi_ratio]['total_fpr_priv'] =[]
    results_dict[poi_ratio]['total_fpr_unpriv'] =[]
    
    for repeat in range(5):
        data_train, data_vt = dataset[data_name].split([0.7], shuffle=True)
        data_valid, data_test = data_vt.split([0.5], shuffle=True)

        idx = np.where(data_train.features[:, sens_idx] == 1)[0]
        poison_idx = np.random.permutation(idx)[:int(len(idx) * poi_ratio)]

        #label poisoning
        data_train.labels[poison_idx] = 1 - data_train.labels[poison_idx]

        X_train = torch.FloatTensor(data_train.features[:, ~sens_loc])
        y_train = torch.FloatTensor(data_train.labels)
        s1_train = torch.FloatTensor(data_train.features[:, sens_idx]).view(-1,1)

        X_val = torch.FloatTensor(data_valid.features[:, ~sens_loc])
        y_val = torch.FloatTensor(data_valid.labels)
        s1_val = torch.FloatTensor(data_valid.features[:, sens_idx]).view(-1,1)

        X_test = torch.FloatTensor(data_test.features[:, ~sens_loc])
        y_test = torch.FloatTensor(data_test.labels)
        s1_test = torch.FloatTensor(data_test.features[:, sens_idx]).view(-1,1)

        XS_train = torch.cat([X_train, s1_train.reshape((s1_train.shape[0], 1))], dim=1)
        XS_val = torch.cat([X_val, s1_val.reshape((s1_val.shape[0], 1))], dim=1)
        XS_test = torch.cat([X_test, s1_test.reshape((s1_test.shape[0], 1))], dim=1)

        train_result = []
        train_tensors = Namespace(XS_train = XS_train, y_train = y_train, s1_train = s1_train)
        val_tensors = Namespace(XS_val = XS_val, y_val = y_val, s1_val = s1_val) 
        test_tensors = Namespace(XS_test = XS_test, y_test = y_test, s1_test = s1_test)

        train_opt = Namespace(val=len(y_val), n_epochs=100, k=1, lr_g=5e-4, lr_f=1e-5, lr_r=1e-5)
        seed = 1

        lambda_f_set = [0.3] # Lambda value for the fairness discriminator of FR-Train.
        lambda_r = 0.1 # Lambda value for the robustness discriminator of FR-Train.

        data_train = Dataset(X_train, s1_train, y_train)
        data_valid = Dataset(X_val, s1_val, y_val)
        data_test = Dataset(X_test, s1_test, y_test)

        trainloader = torch.utils.data.DataLoader(
            data_train,
            batch_size=128,
            shuffle=True,
            )
        validloader = torch.utils.data.DataLoader(
            data_valid,
            batch_size=128,
            shuffle=True,
            drop_last = True
            )
        testloader = torch.utils.data.DataLoader(
            data_test,
            batch_size=128,
            shuffle=True,
            drop_last = True
            )


        for lambda_f in lambda_f_set:
            train_model_adult(results_dict[poi_ratio], train_tensors, val_tensors, test_tensors, train_opt, lambda_f = lambda_f, lambda_r = lambda_r, seed = seed)



[Lambda: 0.300000] [Epoch 10/100] [D_F loss: 0.698578] [D_R loss: 0.908454] [G loss: 0.089615]
VALID DATA
overall TPR : 0.500
priv TPR : 0.351
unpriv TPR : 0.526
Eq. Opp : 0.175

overall FPR : 0.123
priv FPR : 0.163
unpriv FPR : 0.061
diff FPR : 0.102

overall ACC : 0.784
priv ACC : 0.739
unpriv ACC : 0.873
diff ACC : 0.134



TEST DATA
overall TPR : 0.503
priv TPR : 0.321
unpriv TPR : 0.534
Eq. Opp : 0.213

overall FPR : 0.117
priv FPR : 0.152
unpriv FPR : 0.062
diff FPR : 0.090

overall ACC : 0.788
priv ACC : 0.748
unpriv ACC : 0.869
diff ACC : 0.121



DIMP : 0.334
[Lambda: 0.300000] [Epoch 20/100] [D_F loss: 0.673877] [D_R loss: 0.866214] [G loss: 0.017246]
VALID DATA
overall TPR : 0.567
priv TPR : 0.442
unpriv TPR : 0.589
Eq. Opp : 0.147

overall FPR : 0.144
priv FPR : 0.203
unpriv FPR : 0.053
diff FPR : 0.150

overall ACC : 0.785
priv ACC : 0.732
unpriv ACC : 0.892
diff ACC : 0.161



TEST DATA
overall TPR : 0.596
priv TPR : 0.431
unpriv TPR : 0.624
Eq. Opp : 0.193

overall FPR :

TEST DATA
overall TPR : 0.616
priv TPR : 0.597
unpriv TPR : 0.620
Eq. Opp : 0.023

overall FPR : 0.146
priv FPR : 0.179
unpriv FPR : 0.091
diff FPR : 0.088

overall ACC : 0.794
priv ACC : 0.757
unpriv ACC : 0.873
diff ACC : 0.116



DIMP : 0.469
[Lambda: 0.300000] [Epoch 60/100] [D_F loss: 0.677994] [D_R loss: 0.931309] [G loss: -0.106376]
VALID DATA
overall TPR : 0.686
priv TPR : 0.612
unpriv TPR : 0.700
Eq. Opp : 0.088

overall FPR : 0.178
priv FPR : 0.232
unpriv FPR : 0.096
diff FPR : 0.135

overall ACC : 0.788
priv ACC : 0.747
unpriv ACC : 0.871
diff ACC : 0.124



TEST DATA
overall TPR : 0.686
priv TPR : 0.656
unpriv TPR : 0.692
Eq. Opp : 0.035

overall FPR : 0.187
priv FPR : 0.237
unpriv FPR : 0.105
diff FPR : 0.132

overall ACC : 0.781
priv ACC : 0.740
unpriv ACC : 0.867
diff ACC : 0.127



DIMP : 0.443
[Lambda: 0.300000] [Epoch 70/100] [D_F loss: 0.672008] [D_R loss: 0.715883] [G loss: -0.101417]
VALID DATA
overall TPR : 0.619
priv TPR : 0.638
unpriv TPR : 0.615
Eq. Opp : 0.023

[Lambda: 0.300000] [Epoch 100/100] [D_F loss: 0.662772] [D_R loss: 0.593504] [G loss: -0.035811]
VALID DATA
overall TPR : 0.645
priv TPR : 0.558
unpriv TPR : 0.660
Eq. Opp : 0.102

overall FPR : 0.148
priv FPR : 0.203
unpriv FPR : 0.055
diff FPR : 0.148

overall ACC : 0.800
priv ACC : 0.755
unpriv ACC : 0.899
diff ACC : 0.145



TEST DATA
overall TPR : 0.614
priv TPR : 0.512
unpriv TPR : 0.633
Eq. Opp : 0.121

overall FPR : 0.134
priv FPR : 0.186
unpriv FPR : 0.050
diff FPR : 0.136

overall ACC : 0.805
priv ACC : 0.758
unpriv ACC : 0.900
diff ACC : 0.142



DIMP : 0.318
[Lambda: 0.300000] [Epoch 10/100] [D_F loss: 0.702828] [D_R loss: 1.753486] [G loss: -0.183426]
VALID DATA
overall TPR : 0.597
priv TPR : 0.625
unpriv TPR : 0.592
Eq. Opp : 0.032

overall FPR : 0.084
priv FPR : 0.108
unpriv FPR : 0.046
diff FPR : 0.062

overall ACC : 0.837
priv ACC : 0.798
unpriv ACC : 0.916
diff ACC : 0.118



TEST DATA
overall TPR : 0.599
priv TPR : 0.600
unpriv TPR : 0.599
Eq. Opp : 0.001

overall FP

TEST DATA
overall TPR : 0.681
priv TPR : 0.604
unpriv TPR : 0.694
Eq. Opp : 0.090

overall FPR : 0.162
priv FPR : 0.221
unpriv FPR : 0.067
diff FPR : 0.154

overall ACC : 0.800
priv ACC : 0.752
unpriv ACC : 0.896
diff ACC : 0.143



DIMP : 0.346
[Lambda: 0.300000] [Epoch 50/100] [D_F loss: 0.685944] [D_R loss: 0.860595] [G loss: -0.018658]
VALID DATA
overall TPR : 0.711
priv TPR : 0.525
unpriv TPR : 0.746
Eq. Opp : 0.221

overall FPR : 0.190
priv FPR : 0.267
unpriv FPR : 0.064
diff FPR : 0.204

overall ACC : 0.785
priv ACC : 0.737
unpriv ACC : 0.886
diff ACC : 0.149



TEST DATA
overall TPR : 0.706
priv TPR : 0.528
unpriv TPR : 0.737
Eq. Opp : 0.209

overall FPR : 0.171
priv FPR : 0.248
unpriv FPR : 0.048
diff FPR : 0.201

overall ACC : 0.799
priv ACC : 0.747
unpriv ACC : 0.905
diff ACC : 0.158



DIMP : 0.254
[Lambda: 0.300000] [Epoch 60/100] [D_F loss: 0.687426] [D_R loss: 0.986413] [G loss: -0.018612]
VALID DATA
overall TPR : 0.703
priv TPR : 0.546
unpriv TPR : 0.732
Eq. Opp : 0.187

[Lambda: 0.300000] [Epoch 90/100] [D_F loss: 0.647508] [D_R loss: 0.794022] [G loss: -0.039837]
VALID DATA
overall TPR : 0.704
priv TPR : 0.574
unpriv TPR : 0.728
Eq. Opp : 0.154

overall FPR : 0.217
priv FPR : 0.316
unpriv FPR : 0.054
diff FPR : 0.263

overall ACC : 0.764
priv ACC : 0.698
unpriv ACC : 0.902
diff ACC : 0.204



TEST DATA
overall TPR : 0.727
priv TPR : 0.627
unpriv TPR : 0.745
Eq. Opp : 0.118

overall FPR : 0.219
priv FPR : 0.322
unpriv FPR : 0.054
diff FPR : 0.267

overall ACC : 0.767
priv ACC : 0.699
unpriv ACC : 0.907
diff ACC : 0.207



DIMP : 0.274
[Lambda: 0.300000] [Epoch 100/100] [D_F loss: 0.654826] [D_R loss: 0.788025] [G loss: -0.115382]
VALID DATA
overall TPR : 0.635
priv TPR : 0.595
unpriv TPR : 0.642
Eq. Opp : 0.047

overall FPR : 0.212
priv FPR : 0.302
unpriv FPR : 0.065
diff FPR : 0.237

overall ACC : 0.750
priv ACC : 0.680
unpriv ACC : 0.895
diff ACC : 0.214



TEST DATA
overall TPR : 0.667
priv TPR : 0.601
unpriv TPR : 0.679
Eq. Opp : 0.078

overall FP

TEST DATA
overall TPR : 0.646
priv TPR : 0.498
unpriv TPR : 0.670
Eq. Opp : 0.172

overall FPR : 0.183
priv FPR : 0.268
unpriv FPR : 0.045
diff FPR : 0.223

overall ACC : 0.775
priv ACC : 0.713
unpriv ACC : 0.906
diff ACC : 0.194



DIMP : 0.237
[Lambda: 0.300000] [Epoch 40/100] [D_F loss: 0.670807] [D_R loss: 1.142240] [G loss: -0.015756]
VALID DATA
overall TPR : 0.704
priv TPR : 0.569
unpriv TPR : 0.729
Eq. Opp : 0.160

overall FPR : 0.200
priv FPR : 0.281
unpriv FPR : 0.061
diff FPR : 0.220

overall ACC : 0.776
priv ACC : 0.722
unpriv ACC : 0.892
diff ACC : 0.170



TEST DATA
overall TPR : 0.704
priv TPR : 0.579
unpriv TPR : 0.724
Eq. Opp : 0.145

overall FPR : 0.192
priv FPR : 0.279
unpriv FPR : 0.051
diff FPR : 0.228

overall ACC : 0.783
priv ACC : 0.722
unpriv ACC : 0.910
diff ACC : 0.187



DIMP : 0.257
[Lambda: 0.300000] [Epoch 50/100] [D_F loss: 0.675710] [D_R loss: 0.956856] [G loss: 0.097758]
VALID DATA
overall TPR : 0.655
priv TPR : 0.500
unpriv TPR : 0.682
Eq. Opp : 0.182


[Lambda: 0.300000] [Epoch 80/100] [D_F loss: 0.664997] [D_R loss: 0.753572] [G loss: -0.038445]
VALID DATA
overall TPR : 0.714
priv TPR : 0.598
unpriv TPR : 0.732
Eq. Opp : 0.133

overall FPR : 0.200
priv FPR : 0.282
unpriv FPR : 0.067
diff FPR : 0.215

overall ACC : 0.778
priv ACC : 0.723
unpriv ACC : 0.898
diff ACC : 0.175



TEST DATA
overall TPR : 0.711
priv TPR : 0.640
unpriv TPR : 0.724
Eq. Opp : 0.085

overall FPR : 0.209
priv FPR : 0.293
unpriv FPR : 0.082
diff FPR : 0.211

overall ACC : 0.771
priv ACC : 0.712
unpriv ACC : 0.886
diff ACC : 0.174



DIMP : 0.341
[Lambda: 0.300000] [Epoch 90/100] [D_F loss: 0.672739] [D_R loss: 0.893999] [G loss: 0.091141]
VALID DATA
overall TPR : 0.734
priv TPR : 0.597
unpriv TPR : 0.755
Eq. Opp : 0.158

overall FPR : 0.232
priv FPR : 0.325
unpriv FPR : 0.082
diff FPR : 0.243

overall ACC : 0.759
priv ACC : 0.701
unpriv ACC : 0.884
diff ACC : 0.183



TEST DATA
overall TPR : 0.744
priv TPR : 0.636
unpriv TPR : 0.765
Eq. Opp : 0.129

overall FPR 

TEST DATA
overall TPR : 0.706
priv TPR : 0.500
unpriv TPR : 0.740
Eq. Opp : 0.240

overall FPR : 0.257
priv FPR : 0.392
unpriv FPR : 0.043
diff FPR : 0.349

overall ACC : 0.734
priv ACC : 0.650
unpriv ACC : 0.907
diff ACC : 0.257



DIMP : 0.185
[Lambda: 0.300000] [Epoch 30/100] [D_F loss: 0.683548] [D_R loss: 1.194999] [G loss: 0.020122]
VALID DATA
overall TPR : 0.685
priv TPR : 0.618
unpriv TPR : 0.697
Eq. Opp : 0.079

overall FPR : 0.246
priv FPR : 0.361
unpriv FPR : 0.063
diff FPR : 0.297

overall ACC : 0.737
priv ACC : 0.657
unpriv ACC : 0.901
diff ACC : 0.244



TEST DATA
overall TPR : 0.701
priv TPR : 0.645
unpriv TPR : 0.710
Eq. Opp : 0.066

overall FPR : 0.253
priv FPR : 0.362
unpriv FPR : 0.080
diff FPR : 0.282

overall ACC : 0.736
priv ACC : 0.661
unpriv ACC : 0.890
diff ACC : 0.229



DIMP : 0.300
[Lambda: 0.300000] [Epoch 40/100] [D_F loss: 0.660566] [D_R loss: 1.108666] [G loss: -0.075669]
VALID DATA
overall TPR : 0.699
priv TPR : 0.542
unpriv TPR : 0.726
Eq. Opp : 0.184


[Lambda: 0.300000] [Epoch 70/100] [D_F loss: 0.652026] [D_R loss: 1.008796] [G loss: -0.017855]
VALID DATA
overall TPR : 0.697
priv TPR : 0.646
unpriv TPR : 0.706
Eq. Opp : 0.060

overall FPR : 0.278
priv FPR : 0.388
unpriv FPR : 0.106
diff FPR : 0.282

overall ACC : 0.716
priv ACC : 0.642
unpriv ACC : 0.867
diff ACC : 0.225



TEST DATA
overall TPR : 0.701
priv TPR : 0.657
unpriv TPR : 0.709
Eq. Opp : 0.051

overall FPR : 0.280
priv FPR : 0.391
unpriv FPR : 0.104
diff FPR : 0.287

overall ACC : 0.715
priv ACC : 0.641
unpriv ACC : 0.868
diff ACC : 0.228



DIMP : 0.341
[Lambda: 0.300000] [Epoch 80/100] [D_F loss: 0.662038] [D_R loss: 0.784797] [G loss: -0.035144]
VALID DATA
overall TPR : 0.678
priv TPR : 0.586
unpriv TPR : 0.694
Eq. Opp : 0.108

overall FPR : 0.235
priv FPR : 0.347
unpriv FPR : 0.059
diff FPR : 0.288

overall ACC : 0.743
priv ACC : 0.666
unpriv ACC : 0.902
diff ACC : 0.236



TEST DATA
overall TPR : 0.684
priv TPR : 0.602
unpriv TPR : 0.698
Eq. Opp : 0.096

overall FPR

TEST DATA
overall TPR : 0.442
priv TPR : 0.340
unpriv TPR : 0.459
Eq. Opp : 0.120

overall FPR : 0.186
priv FPR : 0.266
unpriv FPR : 0.051
diff FPR : 0.215

overall ACC : 0.720
priv ACC : 0.648
unpriv ACC : 0.877
diff ACC : 0.229



DIMP : 0.261
[Lambda: 0.300000] [Epoch 20/100] [D_F loss: 0.669094] [D_R loss: 0.964165] [G loss: 0.135561]
VALID DATA
overall TPR : 0.585
priv TPR : 0.377
unpriv TPR : 0.622
Eq. Opp : 0.246

overall FPR : 0.214
priv FPR : 0.315
unpriv FPR : 0.054
diff FPR : 0.261

overall ACC : 0.736
priv ACC : 0.666
unpriv ACC : 0.881
diff ACC : 0.215



TEST DATA
overall TPR : 0.554
priv TPR : 0.439
unpriv TPR : 0.574
Eq. Opp : 0.135

overall FPR : 0.207
priv FPR : 0.300
unpriv FPR : 0.048
diff FPR : 0.252

overall ACC : 0.733
priv ACC : 0.660
unpriv ACC : 0.891
diff ACC : 0.231



DIMP : 0.244
[Lambda: 0.300000] [Epoch 30/100] [D_F loss: 0.678547] [D_R loss: 1.261220] [G loss: 0.001346]
VALID DATA
overall TPR : 0.681
priv TPR : 0.446
unpriv TPR : 0.723
Eq. Opp : 0.277



[Lambda: 0.300000] [Epoch 60/100] [D_F loss: 0.654119] [D_R loss: 0.997229] [G loss: -0.009756]
VALID DATA
overall TPR : 0.686
priv TPR : 0.698
unpriv TPR : 0.683
Eq. Opp : 0.015

overall FPR : 0.226
priv FPR : 0.322
unpriv FPR : 0.075
diff FPR : 0.247

overall ACC : 0.752
priv ACC : 0.680
unpriv ACC : 0.899
diff ACC : 0.219



TEST DATA
overall TPR : 0.696
priv TPR : 0.640
unpriv TPR : 0.705
Eq. Opp : 0.065

overall FPR : 0.232
priv FPR : 0.329
unpriv FPR : 0.078
diff FPR : 0.251

overall ACC : 0.750
priv ACC : 0.682
unpriv ACC : 0.891
diff ACC : 0.210



DIMP : 0.310
[Lambda: 0.300000] [Epoch 70/100] [D_F loss: 0.654219] [D_R loss: 1.039387] [G loss: 0.025359]
VALID DATA
overall TPR : 0.702
priv TPR : 0.583
unpriv TPR : 0.723
Eq. Opp : 0.139

overall FPR : 0.248
priv FPR : 0.377
unpriv FPR : 0.045
diff FPR : 0.332

overall ACC : 0.740
priv ACC : 0.655
unpriv ACC : 0.912
diff ACC : 0.258



TEST DATA
overall TPR : 0.699
priv TPR : 0.521
unpriv TPR : 0.729
Eq. Opp : 0.208

overall FPR 

TEST DATA
overall TPR : 0.740
priv TPR : 0.644
unpriv TPR : 0.758
Eq. Opp : 0.113

overall FPR : 0.338
priv FPR : 0.480
unpriv FPR : 0.112
diff FPR : 0.368

overall ACC : 0.681
priv ACC : 0.594
unpriv ACC : 0.859
diff ACC : 0.264



DIMP : 0.311
[Lambda: 0.300000] [Epoch 10/100] [D_F loss: 0.679601] [D_R loss: 0.880877] [G loss: 0.145924]
VALID DATA
overall TPR : 0.343
priv TPR : 0.073
unpriv TPR : 0.397
Eq. Opp : 0.324

overall FPR : 0.194
priv FPR : 0.312
unpriv FPR : 0.007
diff FPR : 0.306

overall ACC : 0.687
priv ACC : 0.594
unpriv ACC : 0.874
diff ACC : 0.280



TEST DATA
overall TPR : 0.403
priv TPR : 0.092
unpriv TPR : 0.459
Eq. Opp : 0.367

overall FPR : 0.203
priv FPR : 0.328
unpriv FPR : 0.006
diff FPR : 0.322

overall ACC : 0.700
priv ACC : 0.606
unpriv ACC : 0.893
diff ACC : 0.287



DIMP : 0.042
[Lambda: 0.300000] [Epoch 20/100] [D_F loss: 0.668534] [D_R loss: 0.925050] [G loss: 0.041468]
VALID DATA
overall TPR : 0.446
priv TPR : 0.337
unpriv TPR : 0.468
Eq. Opp : 0.132



[Lambda: 0.300000] [Epoch 50/100] [D_F loss: 0.674340] [D_R loss: 1.104668] [G loss: 0.016734]
VALID DATA
overall TPR : 0.555
priv TPR : 0.687
unpriv TPR : 0.532
Eq. Opp : 0.155

overall FPR : 0.267
priv FPR : 0.381
unpriv FPR : 0.074
diff FPR : 0.307

overall ACC : 0.689
priv ACC : 0.592
unpriv ACC : 0.898
diff ACC : 0.306



TEST DATA
overall TPR : 0.560
priv TPR : 0.665
unpriv TPR : 0.541
Eq. Opp : 0.125

overall FPR : 0.271
priv FPR : 0.394
unpriv FPR : 0.075
diff FPR : 0.319

overall ACC : 0.688
priv ACC : 0.587
unpriv ACC : 0.896
diff ACC : 0.310



DIMP : 0.321
[Lambda: 0.300000] [Epoch 60/100] [D_F loss: 0.670327] [D_R loss: 1.126671] [G loss: -0.077757]
VALID DATA
overall TPR : 0.655
priv TPR : 0.658
unpriv TPR : 0.654
Eq. Opp : 0.004

overall FPR : 0.284
priv FPR : 0.394
unpriv FPR : 0.099
diff FPR : 0.295

overall ACC : 0.701
priv ACC : 0.621
unpriv ACC : 0.873
diff ACC : 0.252



TEST DATA
overall TPR : 0.658
priv TPR : 0.637
unpriv TPR : 0.662
Eq. Opp : 0.024

overall FPR 

TEST DATA
overall TPR : 0.613
priv TPR : 0.630
unpriv TPR : 0.610
Eq. Opp : 0.020

overall FPR : 0.265
priv FPR : 0.382
unpriv FPR : 0.073
diff FPR : 0.309

overall ACC : 0.705
priv ACC : 0.616
unpriv ACC : 0.891
diff ACC : 0.275



DIMP : 0.310
[Lambda: 0.300000] [Epoch 100/100] [D_F loss: 0.639780] [D_R loss: 0.776854] [G loss: 0.000431]
VALID DATA
overall TPR : 0.629
priv TPR : 0.622
unpriv TPR : 0.631
Eq. Opp : 0.009

overall FPR : 0.309
priv FPR : 0.451
unpriv FPR : 0.086
diff FPR : 0.365

overall ACC : 0.676
priv ACC : 0.574
unpriv ACC : 0.880
diff ACC : 0.306



TEST DATA
overall TPR : 0.645
priv TPR : 0.679
unpriv TPR : 0.638
Eq. Opp : 0.041

overall FPR : 0.310
priv FPR : 0.442
unpriv FPR : 0.093
diff FPR : 0.349

overall ACC : 0.679
priv ACC : 0.583
unpriv ACC : 0.879
diff ACC : 0.297



DIMP : 0.326
[Lambda: 0.300000] [Epoch 10/100] [D_F loss: 0.679710] [D_R loss: 1.380780] [G loss: -0.048545]
VALID DATA
overall TPR : 0.427
priv TPR : 0.482
unpriv TPR : 0.418
Eq. Opp : 0.065

[Lambda: 0.300000] [Epoch 40/100] [D_F loss: 0.660264] [D_R loss: 0.801524] [G loss: 0.064935]
VALID DATA
overall TPR : 0.573
priv TPR : 0.605
unpriv TPR : 0.567
Eq. Opp : 0.038

overall FPR : 0.349
priv FPR : 0.514
unpriv FPR : 0.081
diff FPR : 0.432

overall ACC : 0.631
priv ACC : 0.511
unpriv ACC : 0.881
diff ACC : 0.369



TEST DATA
overall TPR : 0.577
priv TPR : 0.628
unpriv TPR : 0.568
Eq. Opp : 0.060

overall FPR : 0.344
priv FPR : 0.508
unpriv FPR : 0.083
diff FPR : 0.424

overall ACC : 0.637
priv ACC : 0.516
unpriv ACC : 0.884
diff ACC : 0.368



DIMP : 0.275
[Lambda: 0.300000] [Epoch 50/100] [D_F loss: 0.650732] [D_R loss: 1.093930] [G loss: 0.010090]
VALID DATA
overall TPR : 0.607
priv TPR : 0.650
unpriv TPR : 0.599
Eq. Opp : 0.051

overall FPR : 0.283
priv FPR : 0.409
unpriv FPR : 0.078
diff FPR : 0.332

overall ACC : 0.689
priv ACC : 0.593
unpriv ACC : 0.889
diff ACC : 0.296



TEST DATA
overall TPR : 0.587
priv TPR : 0.672
unpriv TPR : 0.573
Eq. Opp : 0.099

overall FPR :

TEST DATA
overall TPR : 0.691
priv TPR : 0.623
unpriv TPR : 0.704
Eq. Opp : 0.081

overall FPR : 0.386
priv FPR : 0.572
unpriv FPR : 0.065
diff FPR : 0.507

overall ACC : 0.634
priv ACC : 0.512
unpriv ACC : 0.898
diff ACC : 0.386



DIMP : 0.214
[Lambda: 0.300000] [Epoch 90/100] [D_F loss: 0.644574] [D_R loss: 0.815761] [G loss: -0.006734]
VALID DATA
overall TPR : 0.672
priv TPR : 0.658
unpriv TPR : 0.674
Eq. Opp : 0.016

overall FPR : 0.364
priv FPR : 0.556
unpriv FPR : 0.064
diff FPR : 0.492

overall ACC : 0.645
priv ACC : 0.516
unpriv ACC : 0.907
diff ACC : 0.391



TEST DATA
overall TPR : 0.718
priv TPR : 0.643
unpriv TPR : 0.731
Eq. Opp : 0.088

overall FPR : 0.402
priv FPR : 0.591
unpriv FPR : 0.077
diff FPR : 0.514

overall ACC : 0.628
priv ACC : 0.508
unpriv ACC : 0.890
diff ACC : 0.382



DIMP : 0.227
[Lambda: 0.300000] [Epoch 100/100] [D_F loss: 0.650681] [D_R loss: 0.722989] [G loss: -0.050138]
VALID DATA
overall TPR : 0.642
priv TPR : 0.608
unpriv TPR : 0.648
Eq. Opp : 0.04

[Lambda: 0.300000] [Epoch 30/100] [D_F loss: 0.677369] [D_R loss: 1.327799] [G loss: 0.163861]
VALID DATA
overall TPR : 0.567
priv TPR : 0.468
unpriv TPR : 0.584
Eq. Opp : 0.116

overall FPR : 0.338
priv FPR : 0.512
unpriv FPR : 0.055
diff FPR : 0.457

overall ACC : 0.638
priv ACC : 0.518
unpriv ACC : 0.893
diff ACC : 0.375



TEST DATA
overall TPR : 0.629
priv TPR : 0.476
unpriv TPR : 0.654
Eq. Opp : 0.179

overall FPR : 0.358
priv FPR : 0.532
unpriv FPR : 0.065
diff FPR : 0.466

overall ACC : 0.639
priv ACC : 0.525
unpriv ACC : 0.886
diff ACC : 0.362



DIMP : 0.191
[Lambda: 0.300000] [Epoch 40/100] [D_F loss: 0.667059] [D_R loss: 1.199313] [G loss: 0.040598]
VALID DATA
overall TPR : 0.650
priv TPR : 0.560
unpriv TPR : 0.665
Eq. Opp : 0.106

overall FPR : 0.378
priv FPR : 0.574
unpriv FPR : 0.058
diff FPR : 0.516

overall ACC : 0.629
priv ACC : 0.500
unpriv ACC : 0.900
diff ACC : 0.400



TEST DATA
overall TPR : 0.700
priv TPR : 0.546
unpriv TPR : 0.725
Eq. Opp : 0.179

overall FPR :

TEST DATA
overall TPR : 0.573
priv TPR : 0.570
unpriv TPR : 0.574
Eq. Opp : 0.004

overall FPR : 0.377
priv FPR : 0.572
unpriv FPR : 0.049
diff FPR : 0.523

overall ACC : 0.611
priv ACC : 0.474
unpriv ACC : 0.906
diff ACC : 0.432



DIMP : 0.194
[Lambda: 0.300000] [Epoch 80/100] [D_F loss: 0.647027] [D_R loss: 1.054087] [G loss: -0.006262]
VALID DATA
overall TPR : 0.629
priv TPR : 0.686
unpriv TPR : 0.620
Eq. Opp : 0.067

overall FPR : 0.353
priv FPR : 0.547
unpriv FPR : 0.050
diff FPR : 0.497

overall ACC : 0.642
priv ACC : 0.506
unpriv ACC : 0.921
diff ACC : 0.415



TEST DATA
overall TPR : 0.626
priv TPR : 0.609
unpriv TPR : 0.629
Eq. Opp : 0.019

overall FPR : 0.363
priv FPR : 0.555
unpriv FPR : 0.041
diff FPR : 0.514

overall ACC : 0.634
priv ACC : 0.503
unpriv ACC : 0.917
diff ACC : 0.415



DIMP : 0.188
[Lambda: 0.300000] [Epoch 90/100] [D_F loss: 0.664828] [D_R loss: 0.985620] [G loss: 0.019174]
VALID DATA
overall TPR : 0.643
priv TPR : 0.642
unpriv TPR : 0.644
Eq. Opp : 0.001


In [45]:
lambda_r

0.1

In [44]:
lambda_f_set

[0.3]

In [43]:
train_opt

Namespace(k=1, lr_f=1e-05, lr_g=0.005, lr_r=1e-05, n_epochs=100, val=6783)

In [51]:
results_dict

{0.1: {'EqOdds': [tensor(0.1152, device='cuda:0'),
   tensor(0.2263, device='cuda:0'),
   tensor(0.2571, device='cuda:0'),
   tensor(0.1932, device='cuda:0'),
   tensor(0.2559, device='cuda:0')],
  'Acc': [tensor(0.8044, device='cuda:0'),
   tensor(0.7897, device='cuda:0'),
   tensor(0.8047, device='cuda:0'),
   tensor(0.7879, device='cuda:0'),
   tensor(0.8045, device='cuda:0')],
  'DISP': [],
  'EqOpp': [tensor(0.0251, device='cuda:0'),
   tensor(0.0721, device='cuda:0'),
   tensor(0.1209, device='cuda:0'),
   tensor(0.0549, device='cuda:0'),
   tensor(0.1198, device='cuda:0')],
  'Acc_diff': [tensor(0.1301, device='cuda:0'),
   tensor(0.1388, device='cuda:0'),
   tensor(0.1418, device='cuda:0'),
   tensor(0.1355, device='cuda:0'),
   tensor(0.1319, device='cuda:0')],
  'total_tpr': [tensor(0.6086, device='cuda:0'),
   tensor(0.6979, device='cuda:0'),
   tensor(0.6140, device='cuda:0'),
   tensor(0.6653, device='cuda:0'),
   tensor(0.6579, device='cuda:0')],
  'total_fpr': [tensor(0.

In [100]:
#German setup
results_dict = {}

for poi_ratio in np.linspace(0.1, 0.5, 5):
    data_train, data_vt = dataset[data_name].split([0.7], shuffle=True)
    data_valid, data_test = data_vt.split([0.5], shuffle=True)

    idx = np.where(data_train.features[:, sens_idx] == 1)[0]
    poison_idx = np.random.permutation(idx)[:int(len(idx) * poi_ratio)]

    #label poisoning
    data_train.labels[poison_idx] = 1 - data_train.labels[poison_idx]

    X_train = torch.FloatTensor(data_train.features[:, ~sens_loc])
    y_train = torch.FloatTensor(data_train.labels)
    s1_train = torch.FloatTensor(data_train.features[:, sens_idx]).view(-1,1)

    X_val = torch.FloatTensor(data_valid.features[:, ~sens_loc])
    y_val = torch.FloatTensor(data_valid.labels)
    s1_val = torch.FloatTensor(data_valid.features[:, sens_idx]).view(-1,1)

    X_test = torch.FloatTensor(data_test.features[:, ~sens_loc])
    y_test = torch.FloatTensor(data_test.labels)
    s1_test = torch.FloatTensor(data_test.features[:, sens_idx]).view(-1,1)

    XS_train = torch.cat([X_train, s1_train.reshape((s1_train.shape[0], 1))], dim=1)
    XS_val = torch.cat([X_val, s1_val.reshape((s1_val.shape[0], 1))], dim=1)
    XS_test = torch.cat([X_test, s1_test.reshape((s1_test.shape[0], 1))], dim=1)

    train_result = []
    train_tensors = Namespace(XS_train = XS_train, y_train = y_train, s1_train = s1_train)
    val_tensors = Namespace(XS_val = XS_val, y_val = y_val, s1_val = s1_val) 
    test_tensors = Namespace(XS_test = XS_test, y_test = y_test, s1_test = s1_test)

    train_opt = Namespace(val=len(y_val), n_epochs=100, k=1, lr_g=5e-4, lr_f=1e-5, lr_r=1e-5)
    seed = 1

    lambda_f_set = [0.1] # Lambda value for the fairness discriminator of FR-Train.
    lambda_r = 0.2 # Lambda value for the robustness discriminator of FR-Train.

    results_dict[poi_ratio] = {}
    
    data_train = Dataset(X_train, s1_train, y_train)
    data_valid = Dataset(X_val, s1_val, y_val)
    data_test = Dataset(X_test, s1_test, y_test)

    trainloader = torch.utils.data.DataLoader(
        data_train,
        batch_size=128,
        shuffle=True,
        )
    validloader = torch.utils.data.DataLoader(
        data_valid,
        batch_size=128,
        shuffle=True,
        drop_last = True
        )
    testloader = torch.utils.data.DataLoader(
        data_test,
        batch_size=128,
        shuffle=True,
        drop_last = True
        )


    for lambda_f in lambda_f_set:
        train_model_adult(results_dict[poi_ratio], train_tensors, val_tensors, test_tensors, train_opt, lambda_f = lambda_f, lambda_r = lambda_r, seed = seed)
    


[Lambda: 0.100000] [Epoch 10/100] [D_F loss: 35.000000] [D_R loss: 56.391727] [G loss: 13.647821]
VALID DATA
overall TPR : 0.000
priv TPR : 0.000
unpriv TPR : 0.000
Eq. Opp : 0.000

overall FPR : 0.000
priv FPR : 0.000
unpriv FPR : 0.000
diff FPR : 0.000

overall ACC : 0.648
priv ACC : 0.718
unpriv ACC : 0.512
diff ACC : 0.206



TEST DATA
overall TPR : 0.000
priv TPR : 0.000
unpriv TPR : 0.000
Eq. Opp : 0.000

overall FPR : 0.000
priv FPR : 0.000
unpriv FPR : 0.000
diff FPR : 0.000

overall ACC : 0.719
priv ACC : 0.765
unpriv ACC : 0.628
diff ACC : 0.137



DIMP : nan
[Lambda: 0.100000] [Epoch 20/100] [D_F loss: 21.882446] [D_R loss: 55.871990] [G loss: 12.997374]
VALID DATA
overall TPR : 0.000
priv TPR : 0.000
unpriv TPR : 0.000
Eq. Opp : 0.000

overall FPR : 0.000
priv FPR : 0.000
unpriv FPR : 0.000
diff FPR : 0.000

overall ACC : 0.641
priv ACC : 0.701
unpriv ACC : 0.512
diff ACC : 0.189



TEST DATA
overall TPR : 0.000
priv TPR : 0.000
unpriv TPR : 0.000
Eq. Opp : 0.000

overall F

[Lambda: 0.100000] [Epoch 60/100] [D_F loss: 4.862841] [D_R loss: 4.169364] [G loss: 1.947477]
VALID DATA
overall TPR : 0.475
priv TPR : 0.562
unpriv TPR : 0.417
Eq. Opp : 0.146

overall FPR : 0.261
priv FPR : 0.277
unpriv FPR : 0.217
diff FPR : 0.060

overall ACC : 0.656
priv ACC : 0.640
unpriv ACC : 0.692
diff ACC : 0.052



TEST DATA
overall TPR : 0.450
priv TPR : 0.429
unpriv TPR : 0.462
Eq. Opp : 0.033

overall FPR : 0.205
priv FPR : 0.212
unpriv FPR : 0.182
diff FPR : 0.030

overall ACC : 0.688
priv ACC : 0.696
unpriv ACC : 0.667
diff ACC : 0.029



DIMP : 0.983
[Lambda: 0.100000] [Epoch 70/100] [D_F loss: 3.324680] [D_R loss: 3.711500] [G loss: 0.429208]
VALID DATA
overall TPR : 0.513
priv TPR : 0.562
unpriv TPR : 0.478
Eq. Opp : 0.084

overall FPR : 0.191
priv FPR : 0.206
unpriv FPR : 0.143
diff FPR : 0.063

overall ACC : 0.719
priv ACC : 0.714
unpriv ACC : 0.730
diff ACC : 0.015



TEST DATA
overall TPR : 0.429
priv TPR : 0.462
unpriv TPR : 0.409
Eq. Opp : 0.052

overall FPR :

[Lambda: 0.100000] [Epoch 10/100] [D_F loss: 30.000002] [D_R loss: 56.355026] [G loss: 22.871462]
VALID DATA
overall TPR : 0.000
priv TPR : 0.000
unpriv TPR : 0.000
Eq. Opp : 0.000

overall FPR : 0.000
priv FPR : 0.000
unpriv FPR : 0.000
diff FPR : 0.000

overall ACC : 0.695
priv ACC : 0.717
unpriv ACC : 0.639
diff ACC : 0.079



TEST DATA
overall TPR : 0.000
priv TPR : 0.000
unpriv TPR : 0.000
Eq. Opp : 0.000

overall FPR : 0.000
priv FPR : 0.000
unpriv FPR : 0.000
diff FPR : 0.000

overall ACC : 0.664
priv ACC : 0.681
unpriv ACC : 0.618
diff ACC : 0.063



DIMP : nan
[Lambda: 0.100000] [Epoch 20/100] [D_F loss: 28.959572] [D_R loss: 55.172455] [G loss: 7.361261]
VALID DATA
overall TPR : 0.000
priv TPR : 0.000
unpriv TPR : 0.000
Eq. Opp : 0.000

overall FPR : 0.000
priv FPR : 0.000
unpriv FPR : 0.000
diff FPR : 0.000

overall ACC : 0.664
priv ACC : 0.703
unpriv ACC : 0.568
diff ACC : 0.136



TEST DATA
overall TPR : 0.000
priv TPR : 0.000
unpriv TPR : 0.000
Eq. Opp : 0.000

overall FP

[Lambda: 0.100000] [Epoch 60/100] [D_F loss: 1.207181] [D_R loss: 4.210495] [G loss: 2.987177]
VALID DATA
overall TPR : 0.333
priv TPR : 0.417
unpriv TPR : 0.296
Eq. Opp : 0.120

overall FPR : 0.404
priv FPR : 0.437
unpriv FPR : 0.278
diff FPR : 0.159

overall ACC : 0.516
priv ACC : 0.490
unpriv ACC : 0.600
diff ACC : 0.110



TEST DATA
overall TPR : 0.256
priv TPR : 0.312
unpriv TPR : 0.222
Eq. Opp : 0.090

overall FPR : 0.306
priv FPR : 0.393
unpriv FPR : 0.138
diff FPR : 0.255

overall ACC : 0.547
priv ACC : 0.482
unpriv ACC : 0.667
diff ACC : 0.185



DIMP : 0.593
[Lambda: 0.100000] [Epoch 70/100] [D_F loss: 1.535565] [D_R loss: 3.740830] [G loss: 4.173036]
VALID DATA
overall TPR : 0.366
priv TPR : 0.467
unpriv TPR : 0.308
Eq. Opp : 0.159

overall FPR : 0.379
priv FPR : 0.455
unpriv FPR : 0.143
diff FPR : 0.312

overall ACC : 0.539
priv ACC : 0.478
unpriv ACC : 0.694
diff ACC : 0.216



TEST DATA
overall TPR : 0.275
priv TPR : 0.294
unpriv TPR : 0.261
Eq. Opp : 0.033

overall FPR :

In [101]:
results_dict

{0.1: {'EqOdds': tensor(0.1168, device='cuda:0'),
  'Acc': tensor(0.7578, device='cuda:0'),
  'DISP': tensor(0.9191, device='cuda:0', grad_fn=<DivBackward0>),
  'EqOpp': tensor(0.0471, device='cuda:0'),
  'Acc_diff': tensor(0.0544, device='cuda:0'),
  'total_tpr': tensor(0.3750, device='cuda:0'),
  'total_fpr': tensor(0.1146, device='cuda:0'),
  'total_tpr_priv': tensor(0.3529, device='cuda:0'),
  'total_tpr_unpriv': tensor(0.4000, device='cuda:0'),
  'total_fpr_priv': tensor(0.1364, device='cuda:0'),
  'total_fpr_unpriv': tensor(0.0667, device='cuda:0')},
 0.2: {'EqOdds': tensor(0.0835, device='cuda:0'),
  'Acc': tensor(0.6641, device='cuda:0'),
  'DISP': tensor(0.8909, device='cuda:0', grad_fn=<DivBackward0>),
  'EqOpp': tensor(0.0217, device='cuda:0'),
  'Acc_diff': tensor(0.0271, device='cuda:0'),
  'total_tpr': tensor(0.5135, device='cuda:0'),
  'total_fpr': tensor(0.2747, device='cuda:0'),
  'total_tpr_priv': tensor(0.5000, device='cuda:0'),
  'total_tpr_unpriv': tensor(0.5217, d

In [46]:
data_name

'adult'

In [52]:
with open(f'results/{data_name}/robust_fr_dict-3.pkl', 'wb') as f:
    pickle.dump(results_dict, f)

In [18]:
def train_model_adult(results_dict, train_tensors, val_tensors, test_tensors, train_opt, lambda_f, lambda_r, seed):
    """
      Trains FR-Train by using the classes in FRTrain_arch.py.
      
      Args:
        train_tensors: Training data.
        val_tensors: Clean validation data.
        test_tensors: Test data.
        train_opt: Options for the training. It currently contains size of validation set, 
                number of epochs, generator/discriminator update ratio, and learning rates.
        lambda_f: The tuning knob for L_2 (ref: FR-Train paper, Section 3.3).
        lambda_r: The tuning knob for L_3 (ref: FR-Train paper, Section 3.3).
        seed: An integer value for specifying torch random seed.
        
      Returns:
        Information about the tuning knobs (lambda_f, lambda_r),
        the test accuracy of the trained model, and disparate impact of the trained model.
    """
    # Initializes generator and discriminator
    Tensor = torch.FloatTensor

    XS_train = train_tensors.XS_train
    y_train = train_tensors.y_train
    s1_train = train_tensors.s1_train

    XS_val = val_tensors.XS_val
    y_val = val_tensors.y_val
    s1_val = val_tensors.s1_val

    XS_test = test_tensors.XS_test
    y_test = test_tensors.y_test
    s1_test = test_tensors.s1_test

    # Saves return values here
    test_result = [] 

    val = train_opt.val # Number of data points in validation set
    k = train_opt.k     # Update ratio of generator and discriminator (1:k training).
    n_epochs = train_opt.n_epochs  # Number of training epoch

    # The loss values of each component will be saved in the following lists. 
    # We can draw epoch-loss graph by the following lists, if necessary.
    g_losses =[]
    d_f_losses = []
    d_r_losses = []
    clean_test_result = []

    bce_loss = torch.nn.BCELoss()

    # Initializes generator and discriminator
    generator = Generator(input_size, latent_size).cuda()
    discriminator_F = DiscriminatorF().cuda()
    discriminator_R = DiscriminatorR(input_size, latent_size).cuda()

    # Initializes weights
    torch.manual_seed(seed)
    generator.apply(weights_init_normal)
    discriminator_F.apply(weights_init_normal)
    discriminator_R.apply(weights_init_normal)

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=train_opt.lr_g)
    optimizer_D_F = torch.optim.SGD(discriminator_F.parameters(), lr=train_opt.lr_f)
    optimizer_D_R = torch.optim.SGD(discriminator_R.parameters(), lr=train_opt.lr_r)

    for epoch in range(1, n_epochs + 1):
        cnt = 0
    #     for x, s, y, x_val, s_val, y_val in trainloader:
        generator.train()
        for x, s, y in trainloader:
            cnt += 1
            x, s, y = x.cuda(), s.cuda(), y.cuda()
            x_val, s_val, y_val = iter(validloader).next()
            x_val, s_val, y_val = x_val.cuda(), s_val.cuda(), y_val.cuda()
            # -------------------
            #  Forwards Generator
            # -------------------
            if cnt % k == 0 or epoch < 10:
                optimizer_G.zero_grad()

            gen_y = generator(x)
            gen_data = torch.cat([x, gen_y.reshape((gen_y.shape[0], 1))], dim=1)


            # -------------------------------
            #  Trains Fairness Discriminator
            # -------------------------------

            optimizer_D_F.zero_grad()

            # Discriminator_F tries to distinguish the sensitive groups by using the output of the generator.
            d_f_loss = bce_loss(discriminator_F(gen_y.detach()), s)
            d_f_loss.backward()
            d_f_losses.append(d_f_loss)
            optimizer_D_F.step()

            # ---------------------------------
            #  Trains Robustness Discriminator
            # ---------------------------------
            optimizer_D_R.zero_grad()

            XSY_val_data = torch.cat([x_val, y_val.view(-1,1)], dim = 1).cuda()
            fake = Variable(Tensor(len(x), 1).fill_(0.0), requires_grad=False).cuda()
            clean = Variable(Tensor(len(x_val), 1).fill_(1.0), requires_grad=False).cuda()

            # Discriminator_R tries to distinguish whether the input is from the validation data or the generated data from generator.
            clean_loss =  bce_loss(discriminator_R(XSY_val_data), clean)
            poison_loss = bce_loss(discriminator_R(gen_data.detach()), fake)
            d_r_loss = 0.5 * (clean_loss + poison_loss)

            d_r_loss.backward()
            d_r_losses.append(d_r_loss)
            optimizer_D_R.step()

            # ---------------------
            #  Updates Generator
            # ---------------------
            if epoch < 10 :
                g_loss = bce_loss((F.tanh(gen_y)+1)/2, y)
                g_loss.backward()
                g_losses.append(g_loss)
                optimizer_G.step()
            elif cnt % k == 0:
                r_decision = discriminator_R(gen_data)
                generated = Variable(Tensor(len(gen_data), 1).fill_(0.0), requires_grad=False).cuda()
                
                r_gen = bce_loss(r_decision, generated)

                # ---------------------------------
                #  Re-weights using output of D_R
                # ---------------------------------
#                 if epoch % 10 == 0:
                
                if epoch < 10:
                    r_weight = torch.ones_like(y, requires_grad=False).float().cuda()
                else:
                    r_ones = torch.ones_like(y, requires_grad=False).float().cuda()
                    loss_ratio = (g_losses[-1]/d_r_losses[-1]).detach()
                    a = 1/(1+torch.exp(-(loss_ratio-3)))
                    b = 1-a
                    r_weight_tmp = r_decision.detach().squeeze()
                    r_weight = a * r_weight_tmp + b * r_ones

                f_cost = F.binary_cross_entropy(discriminator_F(gen_y), s, reduction="none").squeeze()
                g_cost = F.binary_cross_entropy_with_logits(gen_y, y, reduction="none").squeeze()

                f_gen = torch.mean(f_cost*r_weight)
                g_loss = (1-lambda_f-lambda_r) * torch.mean(g_cost*r_weight) - lambda_f * f_gen -  lambda_r * r_gen 

                g_loss.backward()
                optimizer_G.step()

            g_losses.append(g_loss)

        if epoch % 10 == 0:
            print(
                    "[Lambda: %1f] [Epoch %d/%d] [D_F loss: %f] [D_R loss: %f] [G loss: %f]"
                    % (lambda_f, epoch, n_epochs, d_f_losses[-1], d_r_losses[-1], g_losses[-1])
                )

            tp_priv, tn_priv, fp_priv, fn_priv, \
            tp_unpriv, tn_unpriv, fp_unpriv, fn_unpriv = 0, 0, 0, 0, 0, 0, 0, 0

            for x_val, a_val, y_val in validloader:
                x_val, a_val, y_val = x_val.cuda(), a_val.cuda(), y_val.cuda()

                priv_idx = (a_val==1).squeeze()
                positive_idx = y_val==1

                pred_test = generator(x_val).view(-1)

                pred_test[pred_test>0] = 1
                pred_test[pred_test<=0] = 0

                test_lb_priv = y_val[priv_idx].view(-1)
                test_lb_unpriv = y_val[~priv_idx].view(-1)

                pred_priv = pred_test[priv_idx]
                pred_unpriv = pred_test[~priv_idx]

                y_val = y_val.cpu().detach().numpy()
                test_lb_priv = test_lb_priv.cpu().detach().numpy()
                test_lb_unpriv = test_lb_unpriv.cpu().detach().numpy()

                try:
                    pred_priv = pred_priv.argmax(1)
                except:
                    pass
                try:
                    pred_unpriv = pred_unpriv.argmax(1)
                except:
                    pass


                tp_priv += sum(pred_priv[test_lb_priv == 1] == 1)
                fp_priv += sum(pred_priv[test_lb_priv == 0] == 1)
                tn_priv += sum(pred_priv[test_lb_priv == 0] == 0)
                fn_priv += sum(pred_priv[test_lb_priv == 1] == 0)

                tp_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 1)
                fp_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 1)
                tn_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 0)
                fn_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 0)

            tpr_overall = (tp_priv + tp_unpriv)/(tp_priv + tp_unpriv + fn_priv + fn_unpriv).float().item()
            tpr_priv = (tp_unpriv)/(tp_unpriv + fn_unpriv).float().item()
            tpr_unpriv = (tp_priv)/(tp_priv + fn_priv).float().item()

            fpr_overall = (fp_priv + fp_unpriv)/(tn_priv + tn_unpriv + fp_priv + fp_unpriv).float().item()
            fpr_unpriv = (fp_unpriv)/(tn_unpriv + fp_unpriv).float().item()
            fpr_priv = (fp_priv)/(tn_priv + fp_priv).float().item()

            acc_overall = (tp_priv + tn_priv + tp_unpriv + tn_unpriv)/(tp_priv + tn_priv + tp_unpriv + tn_unpriv + \
                                                                      fp_priv + fn_priv + fp_unpriv + fn_unpriv).float().item()
            acc_priv = (tp_priv + tn_priv)/(tp_priv + tn_priv + fp_priv + fn_priv).float().item()
            acc_unpriv = (tp_unpriv + tn_unpriv)/(tp_unpriv + tn_unpriv + fp_unpriv + fn_unpriv).float().item()


            print("VALID DATA")
            print('overall TPR : {0:.3f}'.format( tpr_overall))
            print('priv TPR : {0:.3f}'.format( tpr_priv))
            print('unpriv TPR : {0:.3f}'.format( tpr_unpriv))
            print('Eq. Opp : {0:.3f}'.format( abs(tpr_unpriv - tpr_priv)))
            print()
            print('overall FPR : {0:.3f}'.format( fpr_overall))
            print('priv FPR : {0:.3f}'.format( fpr_priv))
            print('unpriv FPR : {0:.3f}'.format( fpr_unpriv))
            print('diff FPR : {0:.3f}'.format( abs(fpr_unpriv-fpr_priv)))
            print()
            print('overall ACC : {0:.3f}'.format( acc_overall))
            print('priv ACC : {0:.3f}'.format( acc_priv))
            print('unpriv ACC : {0:.3f}'.format( acc_unpriv)) 
            print('diff ACC : {0:.3f}\n\n\n'.format( abs(acc_unpriv-acc_priv)))
            
            tp_priv, tn_priv, fp_priv, fn_priv, \
            tp_unpriv, tn_unpriv, fp_unpriv, fn_unpriv = 0, 0, 0, 0, 0, 0, 0, 0
            
            y_hat1_priv, y_hat1_unpriv, priv, unpriv = [], [],[], []

            
            generator.eval()
            for x_val, a_val, y_val in testloader:
                x_val, a_val, y_val = x_val.cuda(), a_val.cuda(), y_val.cuda()

                priv_idx = (a_val==1).squeeze()
                positive_idx = y_val==1

                pred_test = generator(x_val).view(-1)

                pred_test[pred_test>0] = 1
                pred_test[pred_test<=0] = 0

                test_lb_priv = y_val[priv_idx].view(-1)
                test_lb_unpriv = y_val[~priv_idx].view(-1)

                pred_priv = pred_test[priv_idx]
                pred_unpriv = pred_test[~priv_idx]

                y_val = y_val.cpu().detach().numpy()
                test_lb_priv = test_lb_priv.cpu().detach().numpy()
                test_lb_unpriv = test_lb_unpriv.cpu().detach().numpy()

                try:
                    pred_priv = pred_priv.argmax(1)
                except:
                    pass
                try:
                    pred_unpriv = pred_unpriv.argmax(1)
                except:
                    pass


                tp_priv += sum(pred_priv[test_lb_priv == 1] == 1)
                fp_priv += sum(pred_priv[test_lb_priv == 0] == 1)
                tn_priv += sum(pred_priv[test_lb_priv == 0] == 0)
                fn_priv += sum(pred_priv[test_lb_priv == 1] == 0)

                tp_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 1)
                fp_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 1)
                tn_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 0)
                fn_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 0)
                
                y_hat1_priv.append(pred_priv)
                y_hat1_unpriv.append(pred_unpriv)
                priv.append(priv_idx)
                unpriv.append(~priv_idx)

            tpr_overall = (tp_priv + tp_unpriv)/(tp_priv + tp_unpriv + fn_priv + fn_unpriv).float().item()
            tpr_priv = (tp_unpriv)/(tp_unpriv + fn_unpriv).float().item()
            tpr_unpriv = (tp_priv)/(tp_priv + fn_priv).float().item()

            fpr_overall = (fp_priv + fp_unpriv)/(tn_priv + tn_unpriv + fp_priv + fp_unpriv).float().item()
            fpr_unpriv = (fp_unpriv)/(tn_unpriv + fp_unpriv).float().item()
            fpr_priv = (fp_priv)/(tn_priv + fp_priv).float().item()

            acc_overall = (tp_priv + tn_priv + tp_unpriv + tn_unpriv)/(tp_priv + tn_priv + tp_unpriv + tn_unpriv + \
                                                                      fp_priv + fn_priv + fp_unpriv + fn_unpriv).float().item()
            acc_priv = (tp_priv + tn_priv)/(tp_priv + tn_priv + fp_priv + fn_priv).float().item()
            acc_unpriv = (tp_unpriv + tn_unpriv)/(tp_unpriv + tn_unpriv + fp_unpriv + fn_unpriv).float().item()
            
            y_hat1_priv = sum(torch.cat(y_hat1_priv)).float()/sum(torch.cat(priv))
            y_hat1_unpriv = sum(torch.cat(y_hat1_unpriv)).float()/sum(torch.cat(unpriv))
            
            dimp = min(y_hat1_priv/y_hat1_unpriv, y_hat1_unpriv/y_hat1_priv)

            print("TEST DATA")
            print('overall TPR : {0:.3f}'.format( tpr_overall))
            print('priv TPR : {0:.3f}'.format( tpr_priv))
            print('unpriv TPR : {0:.3f}'.format( tpr_unpriv))
            print('Eq. Opp : {0:.3f}'.format( abs(tpr_unpriv - tpr_priv)))
            print()
            print('overall FPR : {0:.3f}'.format( fpr_overall))
            print('priv FPR : {0:.3f}'.format( fpr_priv))
            print('unpriv FPR : {0:.3f}'.format( fpr_unpriv))
            print('diff FPR : {0:.3f}'.format( abs(fpr_unpriv-fpr_priv)))
            print()
            print('overall ACC : {0:.3f}'.format( acc_overall))
            print('priv ACC : {0:.3f}'.format( acc_priv))
            print('unpriv ACC : {0:.3f}'.format( acc_unpriv)) 
            print('diff ACC : {0:.3f}\n\n\n'.format( abs(acc_unpriv-acc_priv)))
            
            print('DIMP : {0:.3f}'.format( dimp)) 
            
        

    results_dict['EqOdds'].append(abs(fpr_unpriv-fpr_priv) + abs(tpr_unpriv - tpr_priv))
    results_dict['Acc'].append( acc_overall)
    results_dict['EqOpp'].append( abs(tpr_unpriv - tpr_priv))
    results_dict['Acc_diff'].append( abs(acc_unpriv-acc_priv))

    results_dict['total_tpr'].append(tpr_overall)
    results_dict['total_fpr'].append(fpr_overall)

    results_dict['total_tpr_priv'].append(tpr_priv)
    results_dict['total_tpr_unpriv'].append(tpr_unpriv)
    results_dict['total_fpr_priv'].append(fpr_priv)
    results_dict['total_fpr_unpriv'].append(fpr_unpriv)


        #     torch.save(generator.state_dict(), './FR-Train_on_poi_synthetic.pth')

#     tmp = test_model(generator, XS_test, y_test, s1_test)
#     test_result.append([lambda_f, lambda_r, tmp[0].item(), tmp[1]])
    
#     return test_result

In [77]:
Tensor = torch.FloatTensor

XS_train = train_tensors.XS_train
y_train = train_tensors.y_train
s1_train = train_tensors.s1_train

XS_val = val_tensors.XS_val
y_val = val_tensors.y_val
s1_val = val_tensors.s1_val

XS_test = test_tensors.XS_test
y_test = test_tensors.y_test
s1_test = test_tensors.s1_test

# Saves return values here
test_result = [] 

val = train_opt.val # Number of data points in validation set
k = train_opt.k     # Update ratio of generator and discriminator (1:k training).
n_epochs = train_opt.n_epochs  # Number of training epoch

# Changes the input validation data to an appropriate shape for the training
XSY_val = torch.cat([XS_val, y_val.reshape((y_val.shape[0], 1))], dim=1)  

# The loss values of each component will be saved in the following lists. 
# We can draw epoch-loss graph by the following lists, if necessary.
g_losses =[]
d_f_losses = []
d_r_losses = []
clean_test_result = []

bce_loss = torch.nn.BCELoss()

# Initializes generator and discriminator
generator = Generator(input_size, latent_size).cuda()
discriminator_F = DiscriminatorF().cuda()
discriminator_R = DiscriminatorR(input_size, latent_size).cuda()

# Initializes weights
torch.manual_seed(seed)
generator.apply(weights_init_normal)
discriminator_F.apply(weights_init_normal)
discriminator_R.apply(weights_init_normal)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=train_opt.lr_g)
optimizer_D_F = torch.optim.SGD(discriminator_F.parameters(), lr=train_opt.lr_f)
optimizer_D_R = torch.optim.SGD(discriminator_R.parameters(), lr=train_opt.lr_r)

In [78]:

# valid = Variable(Tensor(train_len, 1).fill_(1.0), requires_grad=False)
# generated = Variable(Tensor(train_len, 1).fill_(0.0), requires_grad=False)
# fake = Variable(Tensor(train_len, 1).fill_(0.0), requires_grad=False)
# clean = Variable(Tensor(val_len, 1).fill_(1.0), requires_grad=False)


# r_weight = torch.ones_like(y_train, requires_grad=False).float()
# r_ones = torch.ones_like(y_train, requires_grad=False).float()

for epoch in range(1, n_epochs):
#     for x, s, y, x_val, s_val, y_val in trainloader:
    for x, s, y in trainloader:
        x, s, y = x.cuda(), s.cuda(), y.cuda()
        x_val, s_val, y_val = iter(validloader).next()
        x_val, s_val, y_val = x_val.cuda(), s_val.cuda(), y_val.cuda()
        # -------------------
        #  Forwards Generator
        # -------------------
        if epoch % k == 0 or epoch < 10:
            optimizer_G.zero_grad()

        gen_y = generator(x)
        gen_data = torch.cat([x, gen_y.reshape((gen_y.shape[0], 1))], dim=1)


        # -------------------------------
        #  Trains Fairness Discriminator
        # -------------------------------

        optimizer_D_F.zero_grad()

        # Discriminator_F tries to distinguish the sensitive groups by using the output of the generator.
        d_f_loss = bce_loss(discriminator_F(gen_y.detach()), s)
        d_f_loss.backward()
        d_f_losses.append(d_f_loss)
        optimizer_D_F.step()

        # ---------------------------------
        #  Trains Robustness Discriminator
        # ---------------------------------
        optimizer_D_R.zero_grad()

        XSY_val_data = torch.cat([x_val, y_val.view(-1,1)], dim = 1).cuda()
        fake = Variable(Tensor(len(x), 1).fill_(0.0), requires_grad=False).cuda()
        clean = Variable(Tensor(len(x_val), 1).fill_(1.0), requires_grad=False).cuda()
        
        # Discriminator_R tries to distinguish whether the input is from the validation data or the generated data from generator.
        clean_loss =  bce_loss(discriminator_R(XSY_val_data), clean)
        poison_loss = bce_loss(discriminator_R(gen_data.detach()), fake)
        d_r_loss = 0.5 * (clean_loss + poison_loss)

        d_r_loss.backward()
        d_r_losses.append(d_r_loss)
        optimizer_D_R.step()

        # ---------------------
        #  Updates Generator
        # ---------------------
        if epoch < 10 :
            g_loss = 0.1 * bce_loss((F.tanh(gen_y)+1)/2, y)
            g_loss.backward()
            g_losses.append(g_loss)
            optimizer_G.step()
        elif epoch % k == 0:
            generated = Variable(Tensor(len(x), 1).fill_(0.0), requires_grad=False).cuda()
            
            r_decision = discriminator_R(gen_data)
            r_gen = bce_loss(r_decision, generated)

            # ---------------------------------
            #  Re-weights using output of D_R
            # ---------------------------------
#             if epoch % 10 == 0:
            r_ones = torch.ones_like(y, requires_grad=False).float().cuda()

            loss_ratio = (g_losses[-1]/d_r_losses[-1]).detach()
            a = 1/(1+torch.exp(-(loss_ratio-3)))
            b = 1-a
            r_weight_tmp = r_decision.detach().squeeze()
            r_weight = a * r_weight_tmp + b * r_ones

            f_cost = F.binary_cross_entropy(discriminator_F(gen_y), s, reduction="none").squeeze()
            g_cost = F.binary_cross_entropy_with_logits(gen_y.squeeze(), (y.squeeze()+1)/2, reduction="none").squeeze()

            f_gen = torch.mean(f_cost*r_weight)
            g_loss = (1-lambda_f-lambda_r) * torch.mean(g_cost*r_weight) - lambda_f * f_gen -  lambda_r * r_gen 

            g_loss.backward()
            optimizer_G.step()


        g_losses.append(g_loss)

    if epoch % 5 == 0:
        print(
                "[Lambda: %1f] [Epoch %d/%d] [D_F loss: %f] [D_R loss: %f] [G loss: %f]"
                % (lambda_f, epoch, n_epochs, d_f_losses[-1], d_r_losses[-1], g_losses[-1])
            )

        tp_priv, tn_priv, fp_priv, fn_priv, \
        tp_unpriv, tn_unpriv, fp_unpriv, fn_unpriv = 0, 0, 0, 0, 0, 0, 0, 0

        for x_val, a_val, y_val in validloader:
            x_val, a_val, y_val = x_val.cuda(), a_val.cuda(), y_val.cuda()

            priv_idx = (a_val==1).squeeze()
            positive_idx = y_val==1

            pred_test = generator(x_val).view(-1)
            
            pred_test[pred_test>0] = 1
            pred_test[pred_test<=0] = 0

            test_lb_priv = y_val[priv_idx].view(-1)
            test_lb_unpriv = y_val[~priv_idx].view(-1)

            pred_priv = pred_test[priv_idx]
            pred_unpriv = pred_test[~priv_idx]

            y_val = y_val.cpu().detach().numpy()
            test_lb_priv = test_lb_priv.cpu().detach().numpy()
            test_lb_unpriv = test_lb_unpriv.cpu().detach().numpy()

            try:
                pred_priv = pred_priv.argmax(1)
            except:
                pass
            try:
                pred_unpriv = pred_unpriv.argmax(1)
            except:
                pass


            tp_priv += sum(pred_priv[test_lb_priv == 1] == 1)
            fp_priv += sum(pred_priv[test_lb_priv == 0] == 1)
            tn_priv += sum(pred_priv[test_lb_priv == 0] == 0)
            fn_priv += sum(pred_priv[test_lb_priv == 1] == 0)

            tp_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 1)
            fp_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 1)
            tn_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 0)
            fn_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 0)

        tpr_overall = (tp_priv + tp_unpriv)/(tp_priv + tp_unpriv + fn_priv + fn_unpriv).float().item()
        tpr_priv = (tp_unpriv)/(tp_unpriv + fn_unpriv).float().item()
        tpr_unpriv = (tp_priv)/(tp_priv + fn_priv).float().item()

        fpr_overall = (fp_priv + fp_unpriv)/(tn_priv + tn_unpriv + fp_priv + fp_unpriv).float().item()
        fpr_unpriv = (fp_unpriv)/(tn_unpriv + fp_unpriv).float().item()
        fpr_priv = (fp_priv)/(tn_priv + fp_priv).float().item()

        acc_overall = (tp_priv + tn_priv + tp_unpriv + tn_unpriv)/(tp_priv + tn_priv + tp_unpriv + tn_unpriv + \
                                                                  fp_priv + fn_priv + fp_unpriv + fn_unpriv).float().item()
        acc_priv = (tp_priv + tn_priv)/(tp_priv + tn_priv + fp_priv + fn_priv).float().item()
        acc_unpriv = (tp_unpriv + tn_unpriv)/(tp_unpriv + tn_unpriv + fp_unpriv + fn_unpriv).float().item()


        print()
        print('overall TPR : {0:.3f}'.format( tpr_overall))
        print('priv TPR : {0:.3f}'.format( tpr_priv))
        print('unpriv TPR : {0:.3f}'.format( tpr_unpriv))
        print('Eq. Opp : {0:.3f}'.format( abs(tpr_unpriv - tpr_priv)))
        print()
        print('overall FPR : {0:.3f}'.format( fpr_overall))
        print('priv FPR : {0:.3f}'.format( fpr_priv))
        print('unpriv FPR : {0:.3f}'.format( fpr_unpriv))
        print('diff FPR : {0:.3f}'.format( abs(fpr_unpriv-fpr_priv)))
        print()
        print('overall ACC : {0:.3f}'.format( acc_overall))
        print('priv ACC : {0:.3f}'.format( acc_priv))
        print('unpriv ACC : {0:.3f}'.format( acc_unpriv)) 
        print('diff ACC : {0:.3f}\n\n\n'.format( abs(acc_unpriv-acc_priv)))

    #     torch.save(generator.state_dict(), './FR-Train_on_poi_synthetic.pth')

    
tmp = test_model(generator, XS_test, y_test, s1_test)
test_result.append([lambda_f, lambda_r, tmp[0].item(), tmp[1]])


[Lambda: 0.150000] [Epoch 5/50] [D_F loss: 0.522253] [D_R loss: 50.000000] [G loss: 3.333333]

overall TPR : 0.004
priv TPR : 0.000
unpriv TPR : 0.004
Eq. Opp : 0.004

overall FPR : 0.004
priv FPR : 0.006
unpriv FPR : 0.002
diff FPR : 0.004

overall ACC : 0.748
priv ACC : 0.684
unpriv ACC : 0.882
diff ACC : 0.198



[Lambda: 0.150000] [Epoch 10/50] [D_F loss: 0.877271] [D_R loss: 48.840309] [G loss: -30.585560]

overall TPR : 0.005
priv TPR : 0.004
unpriv TPR : 0.006
Eq. Opp : 0.002

overall FPR : 0.009
priv FPR : 0.010
unpriv FPR : 0.007
diff FPR : 0.003

overall ACC : 0.748
priv ACC : 0.684
unpriv ACC : 0.880
diff ACC : 0.196





KeyboardInterrupt: 

In [49]:
class Generator(nn.Module):
    """FR-Train generator (classifier).
    
    This class is for defining the structure of FR-Train generator (classifier). 
    (ref: FR-Train paper, Section 3)

    Attributes:
        model: A model consisting of torch components.
    """
    
    def __init__(self, input_size, latent_size):
        """Initializes Generator with torch components."""
        
        super(Generator, self).__init__()

    
        def block(in_feat, out_feat, normalize=True):
            """Defines a block with torch components.
            
                Args:
                    in_feat: An integer value for the size of the input feature.
                    out_feat: An integer value for the size of the output feature.
                    normalize: A boolean indicating whether normalization is needed.
                    
                Returns:
                    The stacked layer.
            """
            
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(input_size, latent_size, normalize=True),
            *block(latent_size, latent_size, normalize=True),
            *block(latent_size, latent_size, normalize=True),
#             nn.Linear(32, 1),
#             nn.Tanh()
            nn.Linear(latent_size,1),
        )

    def forward(self, input_data):
        """Defines a forward operation of the model.
        
        Args: 
            input_data: The input data.
            
        Returns:
            The predicted label (y_hat) for the given input data.
        """
        
        output = self.model(input_data)
        return output


class DiscriminatorF(nn.Module):
    """FR-Train fairness discriminator.
    
    This class is for defining structure of FR-Train fairness discriminator. 
    (ref: FR-Train paper, Section 3)

    Attributes:
        model: A model consisting of torch components.
    """
    
    def __init__(self):
        """Initializes DiscriminatorF with torch components."""
        
        super(DiscriminatorF, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(1,1),
            nn.Sigmoid(),
        )

    def forward(self, input_data):
        """Defines a forward operation of the model.
        
        Args: 
            input_data: The input data.
            
        Returns:
            The predicted sensitive attribute for the given input data.
        """
        
        predicted_z = self.model(input_data)
        return predicted_z
    

class DiscriminatorR(nn.Module):
    """FR-Train robustness discriminator.
    
    This class is for defining the structure of FR-Train robustness discriminator. 
    (ref: FR-Train paper, Section 3)

    Attributes:
        model: A model consisting of torch components.
    """
    
    def __init__(self, input_size, latent_size):
        """Initializes DiscriminatorR with torch components."""
        
        super(DiscriminatorR, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(input_size + 1, latent_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(latent_size, latent_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(latent_size, 1),
            nn.Sigmoid(),
        )

    def forward(self, input_data):
        """Defines a forward operation of the model.
        
        Args: 
            input_data: The input data.
        
        Returns:
            The predicted indicator (whether the input data is clean or poisoned) 
            for the given input data.
        """
        
        validity = self.model(input_data)
        return validity