In [63]:
%load_ext autoreload
%autoreload 2

import sys, os
import numpy as np
import math

from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.utils.data as data

import math
import matplotlib.pyplot as plt

from argparse import Namespace

from fr_train import Generator, DiscriminatorF, DiscriminatorR, weights_init_normal, test_model
import pickle
import warnings
warnings.filterwarnings("ignore")

from aif360.datasets import AdultDataset, GermanDataset, BankDataset, CompasDataset, BinaryLabelDataset, CelebADataset
from aif360.metrics import ClassificationMetric
import pandas

from sklearn.preprocessing import scale, StandardScaler, MaxAbsScaler

min_max_scaler = MaxAbsScaler()
std_scaler = StandardScaler()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [64]:
def train_model(train_tensors, val_tensors, test_tensors, train_opt, lambda_f, lambda_r, seed):
    """
      Trains FR-Train by using the classes in FRTrain_arch.py.
      
      Args:
        train_tensors: Training data.
        val_tensors: Clean validation data.
        test_tensors: Test data.
        train_opt: Options for the training. It currently contains size of validation set, 
                number of epochs, generator/discriminator update ratio, and learning rates.
        lambda_f: The tuning knob for L_2 (ref: FR-Train paper, Section 3.3).
        lambda_r: The tuning knob for L_3 (ref: FR-Train paper, Section 3.3).
        seed: An integer value for specifying torch random seed.
        
      Returns:
        Information about the tuning knobs (lambda_f, lambda_r),
        the test accuracy of the trained model, and disparate impact of the trained model.
    """
    
    XS_train = train_tensors.XS_train
    y_train = train_tensors.y_train
    s1_train = train_tensors.s1_train
    
    XS_val = val_tensors.XS_val
    y_val = val_tensors.y_val
    s1_val = val_tensors.s1_val
    
    XS_test = test_tensors.XS_test
    y_test = test_tensors.y_test
    s1_test = test_tensors.s1_test
    
    # Saves return values here
    test_result = [] 
    
    val = train_opt.val # Number of data points in validation set
    k = train_opt.k     # Update ratio of generator and discriminator (1:k training).
    n_epochs = train_opt.n_epochs  # Number of training epoch
    
    # Changes the input validation data to an appropriate shape for the training
    XSY_val = torch.cat([XS_val, y_val.reshape((y_val.shape[0], 1))], dim=1)  

    # The loss values of each component will be saved in the following lists. 
    # We can draw epoch-loss graph by the following lists, if necessary.
    g_losses =[]
    d_f_losses = []
    d_r_losses = []
    clean_test_result = []

    bce_loss = torch.nn.BCELoss()

    # Initializes generator and discriminator
    generator = Generator(input_size, latent_size)
    discriminator_F = DiscriminatorF()
    discriminator_R = DiscriminatorR(input_size, latent_size)

    # Initializes weights
    torch.manual_seed(seed)
    generator.apply(weights_init_normal)
    discriminator_F.apply(weights_init_normal)
    discriminator_R.apply(weights_init_normal)

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=train_opt.lr_g)
    optimizer_D_F = torch.optim.SGD(discriminator_F.parameters(), lr=train_opt.lr_f)
    optimizer_D_R = torch.optim.SGD(discriminator_R.parameters(), lr=train_opt.lr_r)

    XSY_val_data = XSY_val[:val]

    train_len = XS_train.shape[0]
    val_len = XSY_val.shape[0]

    # Ground truths using in Disriminator_R
    Tensor = torch.FloatTensor
    valid = Variable(Tensor(train_len, 1).fill_(1.0), requires_grad=False)
    generated = Variable(Tensor(train_len, 1).fill_(0.0), requires_grad=False)
    fake = Variable(Tensor(train_len, 1).fill_(0.0), requires_grad=False)
    clean = Variable(Tensor(val_len, 1).fill_(1.0), requires_grad=False)
    

    r_weight = torch.ones_like(y_train, requires_grad=False).float()
    r_ones = torch.ones_like(y_train, requires_grad=False).float()

    for epoch in range(n_epochs):

        # -------------------
        #  Forwards Generator
        # -------------------
        if epoch % k == 0 or epoch < 500:
            optimizer_G.zero_grad()

        gen_y = generator(XS_train)
        gen_data = torch.cat([XS_train, gen_y.reshape((gen_y.shape[0], 1))], dim=1)


        # -------------------------------
        #  Trains Fairness Discriminator
        # -------------------------------

        optimizer_D_F.zero_grad()
        
        # Discriminator_F tries to distinguish the sensitive groups by using the output of the generator.
        d_f_loss = bce_loss(discriminator_F(gen_y.detach()), s1_train)
        d_f_loss.backward()
        d_f_losses.append(d_f_loss)
        optimizer_D_F.step()
            
            
        # ---------------------------------
        #  Trains Robustness Discriminator
        # ---------------------------------

        optimizer_D_R.zero_grad()

        # Discriminator_R tries to distinguish whether the input is from the validation data or the generated data from generator.
        clean_loss =  bce_loss(discriminator_R(XSY_val_data), clean)
        poison_loss = bce_loss(discriminator_R(gen_data.detach()), fake)
        d_r_loss = 0.5 * (clean_loss + poison_loss)

        d_r_loss.backward()
        d_r_losses.append(d_r_loss)
        optimizer_D_R.step()

        
        # ---------------------
        #  Updates Generator
        # ---------------------


        if epoch < 500 :
            g_loss = 0.1 * bce_loss((F.tanh(gen_y)+1)/2, (y_train+1)/2)
            g_loss.backward()
            g_losses.append(g_loss)
            optimizer_G.step()
        elif epoch % k == 0:
            r_decision = discriminator_R(gen_data)
            r_gen = bce_loss(r_decision, generated)
            
            # ---------------------------------
            #  Re-weights using output of D_R
            # ---------------------------------
            if epoch % 100 == 0:
                loss_ratio = (g_losses[-1]/d_r_losses[-1]).detach()
                a = 1/(1+torch.exp(-(loss_ratio-3)))
                b = 1-a
                r_weight_tmp = r_decision.detach().squeeze()
                r_weight = a * r_weight_tmp + b * r_ones

            f_cost = F.binary_cross_entropy(discriminator_F(gen_y), s1_train, reduction="none").squeeze()
            g_cost = F.binary_cross_entropy_with_logits(gen_y.squeeze(), (y_train.squeeze()+1)/2, reduction="none").squeeze()

            f_gen = torch.mean(f_cost*r_weight)
            g_loss = (1-lambda_f-lambda_r) * torch.mean(g_cost*r_weight) - lambda_f * f_gen -  lambda_r * r_gen 

            g_loss.backward()
            optimizer_G.step()


        g_losses.append(g_loss)

        if epoch % 200 == 0:
            print(
                    "[Lambda: %1f] [Epoch %d/%d] [D_F loss: %f] [D_R loss: %f] [G loss: %f]"
                    % (lambda_f, epoch, n_epochs, d_f_losses[-1], d_r_losses[-1], g_losses[-1])
                )

#     torch.save(generator.state_dict(), './FR-Train_on_poi_synthetic.pth')
    tmp = test_model(generator, XS_test, y_test, s1_test)
    test_result.append([lambda_f, lambda_r, tmp[0].item(), tmp[1]])
    
    
    
    # TEST
    tp_priv, tn_priv, fp_priv, fn_priv, \
    tp_unpriv, tn_unpriv, fp_unpriv, fn_unpriv = 0, 0, 0, 0, 0, 0, 0, 0

    gen_y = generator(XS_train)
    
    priv_idx = (s1_test==1).squeeze()
    positive_idx = y_val==1

    latent_val = (H(x_val))
    pred_test = W(latent_val)

    h_priv = (H(x_val[priv_idx]))
    h_unpriv = (H(x_val[~priv_idx]))

    h_positive = (H(x_val[positive_idx]))

    test_lb_priv = y_val[priv_idx]
    test_lb_unpriv = y_val[~priv_idx]

    pred_priv = W(h_priv)
    pred_unpriv = W(h_unpriv)

    y_val = y_val.cpu().detach().numpy()
    test_lb_priv = test_lb_priv.cpu().detach().numpy()
    test_lb_unpriv = test_lb_unpriv.cpu().detach().numpy()

    try:
        pred_priv = pred_priv.argmax(1)
    except:
        pass
    try:
        pred_unpriv = pred_unpriv.argmax(1)
    except:
        pass


    tp_priv += sum(pred_priv[test_lb_priv == 1] == 1)
    fp_priv += sum(pred_priv[test_lb_priv == 0] == 1)
    tn_priv += sum(pred_priv[test_lb_priv == 0] == 0)
    fn_priv += sum(pred_priv[test_lb_priv == 1] == 0)

    tp_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 1)
    fp_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 1)
    tn_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 0)
    fn_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 0)

    tpr_overall = (tp_priv + tp_unpriv)/(tp_priv + tp_unpriv + fn_priv + fn_unpriv).float().item()
    tpr_priv = (tp_unpriv)/(tp_unpriv + fn_unpriv).float().item()
    tpr_unpriv = (tp_priv)/(tp_priv + fn_priv).float().item()

    fpr_overall = (fp_priv + fp_unpriv)/(tn_priv + tn_unpriv + fp_priv + fp_unpriv).float().item()
    fpr_unpriv = (fp_unpriv)/(tn_unpriv + fp_unpriv).float().item()
    fpr_priv = (fp_priv)/(tn_priv + fp_priv).float().item()

    acc_overall = (tp_priv + tn_priv + tp_unpriv + tn_unpriv)/(tp_priv + tn_priv + tp_unpriv + tn_unpriv + \
                                                              fp_priv + fn_priv + fp_unpriv + fn_unpriv).float().item()
    acc_priv = (tp_priv + tn_priv)/(tp_priv + tn_priv + fp_priv + fn_priv).float().item()
    acc_unpriv = (tp_unpriv + tn_unpriv)/(tp_unpriv + tn_unpriv + fp_unpriv + fn_unpriv).float().item()


    print()
    print('overall TPR : {0:.3f}'.format( tpr_overall))
    print('priv TPR : {0:.3f}'.format( tpr_priv))
    print('unpriv TPR : {0:.3f}'.format( tpr_unpriv))
    print('Eq. Opp : {0:.3f}'.format( abs(tpr_unpriv - tpr_priv)))
    print()
    print('overall FPR : {0:.3f}'.format( fpr_overall))
    print('priv FPR : {0:.3f}'.format( fpr_priv))
    print('unpriv FPR : {0:.3f}'.format( fpr_unpriv))
    print('diff FPR : {0:.3f}'.format( abs(fpr_unpriv-fpr_priv)))
    print()
    print('overall ACC : {0:.3f}'.format( acc_overall))
    print('priv ACC : {0:.3f}'.format( acc_priv))
    print('unpriv ACC : {0:.3f}'.format( acc_unpriv)) 
    print('diff ACC : {0:.3f}\n\n\n'.format( abs(acc_unpriv-acc_priv)))

    valid_pred = data_valid.copy(deepcopy=True)
    feature_size = valid_pred.features.shape[1]
    sens_loc = np.zeros(feature_size).astype(bool)
    sens_loc[sens_idx] = 1

    feature = valid_pred.features[:,~sens_loc] #data without sensitive
    feature = min_max_scaler.fit_transform(feature)

    valid_pred.labels = W((H(torch.tensor(feature).to(device)))).argmax(-1).cpu().numpy().reshape(-1,1)

    classified_metric = ClassificationMetric(data_valid,
                                                     valid_pred,
                                                     unprivileged_groups=unprivileged_groups,
                                                     privileged_groups=privileged_groups)


    print('balanced acc :' ,1/2*(classified_metric.true_positive_rate() + classified_metric.true_negative_rate()))
    print('disparate_impact :' ,classified_metric.disparate_impact())
    print('theil_index :' ,classified_metric.theil_index())
    print('statistical_parity_difference :' ,classified_metric.statistical_parity_difference())
    print('generalized_entropy_index : ', classified_metric.generalized_entropy_index())

    return test_result

In [65]:
np.random.seed(0)
#dataset = {'adult' : AdultDataset(), 'german' : GermanDataset(),'bank': BankDataset(),'compas' : CompasDataset(),'celeb': CelebADataset()}
dataset = {'adult' : AdultDataset(), 'german' : GermanDataset(),'bank': BankDataset(),'compas' : CompasDataset()}
dataset['german'].labels -= 1
sens_attr_dict = {'adult': ['sex', 'race'], 'german' : ['sex', 'age'], 'compas' : ['sex', 'race'], 'bank' : ['age'], 'celeb' : ['gender']}




In [66]:
data_name = 'german'
protected_attribute_used = 1

In [67]:
if data_name == "adult":
    #dataset_orig = AdultDataset()
    # #dataset_orig = load_preproc_data_adult()
    if protected_attribute_used == 1:
        privileged_groups = [{'sex': 1}]
        unprivileged_groups = [{'sex': 0}]
        sens_attr = 'sex'
    else:
        privileged_groups = [{'race': 1}]
        unprivileged_groups = [{'race': 0}]
        sens_attr = 'race'

elif data_name == "german":
    #dataset_orig = GermanDataset()
    # #dataset_orig.labels = #dataset_orig.labels-1
    if protected_attribute_used == 1:
        privileged_groups = [{'sex': 1}]
        unprivileged_groups = [{'sex': 0}]
        sens_attr = 'sex'
    else:
        privileged_groups = [{'age': 1}]
        unprivileged_groups = [{'age': 0}]
        sens_attr = 'age'

elif data_name == "compas":
    #dataset_orig = CompasDataset()
    # #dataset_orig = load_preproc_data_compas()
    if protected_attribute_used == 1:
        privileged_groups = [{'sex': 1}]
        unprivileged_groups = [{'sex': 0}]
        sens_attr = 'sex'
    else:
        privileged_groups = [{'race': 1}]
        unprivileged_groups = [{'race': 0}]
        sens_attr = 'race'

elif data_name == "bank":
    #dataset_orig = BankDataset()
    if protected_attribute_used == 1:
        privileged_groups = [{'age': 1}]
        unprivileged_groups = [{'age': 0}]
        sens_attr = 'age'
    else:
        privileged_groups = [{'age': 1}]
        unprivileged_groups = [{'age': 0}]
        sens_attr = 'age'
        
elif data_name == "meps":
    #dataset_orig = MEPSDataset19()
    privileged_groups = [{'RACE': 1}]
    unprivileged_groups = [{'RACE': 0}]
    sens_attr = 'RACE'

In [68]:
min_max_scaler = MaxAbsScaler()
std_scaler = StandardScaler()

In [69]:
sens_idx = dataset[data_name].feature_names.index(sens_attr)
num_sens = len(np.unique(dataset[data_name].features[:, sens_idx]))

#bs = 4000
workers = 16
device = 'cuda:2'
os.environ['CUDA_VISIBLE_DEVICES']='2'
hist = {'loss':[], 'acc':[], 'val_loss':[], 'val_acc':[]}


feature_size = dataset[data_name].features.shape[1]
sens_loc = np.zeros(feature_size).astype(bool)
sens_loc[sens_idx] = 1

feature = dataset[data_name].features[:,~sens_loc] #data without sensitive
feature = min_max_scaler.fit_transform(feature)
# feature = std_scaler.fit_transform(feature)
dataset[data_name].features[:,~sens_loc] = feature

#self.sensitive = dataset[data_name].features[:,sens_loc].reshape(-1).astype(int)
#n_values = int(np.max(self.label) + 1)
#self.label = np.eye(n_values)[self.label.astype(int)].squeeze(1)
# dataset[data_name], _ = dataset[data_name].split([0.5], shuffle=True)
data_train, data_vt = dataset[data_name].split([0.7], shuffle=True)
data_valid, data_test = data_vt.split([0.5], shuffle=True)

In [70]:
sens_attr

'sex'

#### Poisoning

In [71]:
idx = np.where(data_train.features[:, sens_idx] == 1)[0]
poison_idx = np.random.permutation(idx)[:int(len(idx)/10)]

#label poisoning
data_train.labels[poison_idx] = 1 - data_train.labels[poison_idx]

In [72]:
# a namespace object which contains some of the hyperparameters
opt = Namespace(num_train=2000, num_val1=200, num_val2=500, num_test=1000)

# num_train = opt.num_train
# num_val1 = opt.num_val1
# num_val2 = opt.num_val2
# num_test = opt.num_test

# X = np.load('X_synthetic.npy') # Input features
# y = np.load('y_synthetic.npy') # Original labels
# y_poi = np.load('y_poi.npy') # Poisoned train labels
# s1 = np.load('s1_synthetic.npy') # Sensitive features



In [73]:
X_train = torch.FloatTensor(data_train.features[:, ~sens_loc])
y_train = torch.FloatTensor(data_train.labels)
s1_train = torch.FloatTensor(data_train.features[:, sens_idx]).view(-1,1)

X_val = torch.FloatTensor(data_valid.features[:, ~sens_loc])
y_val = torch.FloatTensor(data_valid.labels)
s1_val = torch.FloatTensor(data_valid.features[:, sens_idx]).view(-1,1)

X_test = torch.FloatTensor(data_test.features[:, ~sens_loc])
y_test = torch.FloatTensor(data_test.labels)
s1_test = torch.FloatTensor(data_test.features[:, sens_idx]).view(-1,1)

In [74]:
XS_train = torch.cat([X_train, s1_train.reshape((s1_train.shape[0], 1))], dim=1)
XS_val = torch.cat([X_val, s1_val.reshape((s1_val.shape[0], 1))], dim=1)
XS_test = torch.cat([X_test, s1_test.reshape((s1_test.shape[0], 1))], dim=1)

In [75]:
input_size = XS_train.shape[-1]
latent_size = 512
hidden_size = 512

In [12]:
# Orig code

X = torch.FloatTensor(X)
y = torch.FloatTensor(y)
y_poi = torch.FloatTensor(y_poi)
s1 = torch.FloatTensor(s1)

X_train = X[:num_train - num_val1]
y_train = y_poi[:num_train - num_val1] # Poisoned label
s1_train = s1[:num_train - num_val1]

X_val = X[num_train: num_train + num_val1]
y_val = y[num_train: num_train + num_val1]
s1_val = s1[num_train: num_train + num_val1]

# Currently not used
# X_val2 = X[num_train + num_val1 : num_train + num_val1 + num_val2]
# y_val2 = y[num_train + num_val1 : num_train + num_val1 + num_val2]
# s1_val2 = s1[num_train + num_val1 : num_train + num_val1 + num_val2]

X_test = X[num_train + num_val1 + num_val2 : num_train + num_val1 + num_val2 + num_test]
y_test = y[num_train + num_val1 + num_val2 : num_train + num_val1 + num_val2 + num_test]
s1_test = s1[num_train + num_val1 + num_val2 : num_train + num_val1 + num_val2 + num_test]

XS_train = torch.cat([X_train, s1_train.reshape((s1_train.shape[0], 1))], dim=1)
XS_val = torch.cat([X_val, s1_val.reshape((s1_val.shape[0], 1))], dim=1)
XS_test = torch.cat([X_test, s1_test.reshape((s1_test.shape[0], 1))], dim=1)

NameError: name 'X' is not defined

In [76]:
class Dataset(data.Dataset):
    def __init__(self, X, s, y):
        self.label = y
        self.feature = torch.cat([X, s], dim = 1)
        self.sensitive = s
        
    def __getitem__(self, idx):
        y = self.label[idx]
        x = self.feature[idx]
        a = self.sensitive[idx]
        
        return x, a, y
    
    def __len__(self):
        return len(self.label)
    
data_train = Dataset(X_train, s1_train, y_train)
data_valid = Dataset(X_val, s1_val, y_val)
data_test = Dataset(X_test, s1_test, y_test)

trainloader = torch.utils.data.DataLoader(
    data_train,
    batch_size=128,
    shuffle=True,
    )
validloader = torch.utils.data.DataLoader(
    data_valid,
    batch_size=128,
    shuffle=True,
    drop_last = True
    )
testloader = torch.utils.data.DataLoader(
    data_test,
    batch_size=128,
    shuffle=True,
    drop_last = True
    )

In [59]:
train_result = []
train_tensors = Namespace(XS_train = XS_train, y_train = y_train, s1_train = s1_train)
val_tensors = Namespace(XS_val = XS_val, y_val = y_val, s1_val = s1_val) 
test_tensors = Namespace(XS_test = XS_test, y_test = y_test, s1_test = s1_test)

train_opt = Namespace(val=len(y_val), n_epochs=100, k=2, lr_g=1e-5, lr_f=1e-5, lr_r=1e-5)
seed = 1

lambda_f_set = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.52] # Lambda value for the fairness discriminator of FR-Train.
lambda_r = 0.2 # Lambda value for the robustness discriminator of FR-Train.

In [None]:
results_dict, train_tensors, val_tensors, test_tensors, train_opt, lambda_f, lambda_r, seed

In [20]:
train_opt.n_epochs = 50

for lambda_f in lambda_f_set:
    train_model_adult(results_dict, m train_tensors, val_tensors, test_tensors, train_opt, lambda_f = lambda_f, lambda_r = lambda_r, seed = seed)
    


TypeError: train_model_adult() missing 1 required positional argument: 'train_opt'

In [44]:
#COMPAS setup
results_dict = {}

for poi_ratio in np.linspace(0.1, 0.5, 5):
    data_train, data_vt = dataset[data_name].split([0.7], shuffle=True)
    data_valid, data_test = data_vt.split([0.5], shuffle=True)

    idx = np.where(data_train.features[:, sens_idx] == 1)[0]
    poison_idx = np.random.permutation(idx)[:int(len(idx) * poi_ratio)]

    #label poisoning
    data_train.labels[poison_idx] = 1 - data_train.labels[poison_idx]

    X_train = torch.FloatTensor(data_train.features[:, ~sens_loc])
    y_train = torch.FloatTensor(data_train.labels)
    s1_train = torch.FloatTensor(data_train.features[:, sens_idx]).view(-1,1)

    X_val = torch.FloatTensor(data_valid.features[:, ~sens_loc])
    y_val = torch.FloatTensor(data_valid.labels)
    s1_val = torch.FloatTensor(data_valid.features[:, sens_idx]).view(-1,1)

    X_test = torch.FloatTensor(data_test.features[:, ~sens_loc])
    y_test = torch.FloatTensor(data_test.labels)
    s1_test = torch.FloatTensor(data_test.features[:, sens_idx]).view(-1,1)

    XS_train = torch.cat([X_train, s1_train.reshape((s1_train.shape[0], 1))], dim=1)
    XS_val = torch.cat([X_val, s1_val.reshape((s1_val.shape[0], 1))], dim=1)
    XS_test = torch.cat([X_test, s1_test.reshape((s1_test.shape[0], 1))], dim=1)

    train_result = []
    train_tensors = Namespace(XS_train = XS_train, y_train = y_train, s1_train = s1_train)
    val_tensors = Namespace(XS_val = XS_val, y_val = y_val, s1_val = s1_val) 
    test_tensors = Namespace(XS_test = XS_test, y_test = y_test, s1_test = s1_test)

    train_opt = Namespace(val=len(y_val), n_epochs=200, k=1, lr_g=1e-3, lr_f=1e-4, lr_r=1e-4)
    seed = 1

    lambda_f_set = [0.1] # Lambda value for the fairness discriminator of FR-Train.
    lambda_r = 0.1 # Lambda value for the robustness discriminator of FR-Train.

    results_dict[poi_ratio] = {}
    
    data_train = Dataset(X_train, s1_train, y_train)
    data_valid = Dataset(X_val, s1_val, y_val)
    data_test = Dataset(X_test, s1_test, y_test)

    trainloader = torch.utils.data.DataLoader(
        data_train,
        batch_size=128,
        shuffle=True,
        )
    validloader = torch.utils.data.DataLoader(
        data_valid,
        batch_size=128,
        shuffle=True,
        drop_last = True
        )
    testloader = torch.utils.data.DataLoader(
        data_test,
        batch_size=128,
        shuffle=True,
        drop_last = True
        )


    for lambda_f in lambda_f_set:
        train_model_adult(results_dict[poi_ratio], train_tensors, val_tensors, test_tensors, train_opt, lambda_f = lambda_f, lambda_r = lambda_r, seed = seed)
    


[Lambda: 0.100000] [Epoch 10/200] [D_F loss: 0.694809] [D_R loss: 0.588161] [G loss: 4.766986]
VALID DATA
overall TPR : 0.224
priv TPR : 0.258
unpriv TPR : 0.148
Eq. Opp : 0.110

overall FPR : 0.139
priv FPR : 0.129
unpriv FPR : 0.145
diff FPR : 0.016

overall ACC : 0.571
priv ACC : 0.576
unpriv ACC : 0.569
diff ACC : 0.008



TEST DATA
overall TPR : 0.202
priv TPR : 0.231
unpriv TPR : 0.145
Eq. Opp : 0.086

overall FPR : 0.103
priv FPR : 0.060
unpriv FPR : 0.131
diff FPR : 0.071

overall ACC : 0.562
priv ACC : 0.588
unpriv ACC : 0.548
diff ACC : 0.041



DIMP : 0.538
[Lambda: 0.100000] [Epoch 20/200] [D_F loss: 0.690723] [D_R loss: 1.004347] [G loss: 0.204525]
VALID DATA
overall TPR : 0.716
priv TPR : 0.729
unpriv TPR : 0.690
Eq. Opp : 0.039

overall FPR : 0.495
priv FPR : 0.508
unpriv FPR : 0.487
diff FPR : 0.021

overall ACC : 0.602
priv ACC : 0.573
unpriv ACC : 0.617
diff ACC : 0.044



TEST DATA
overall TPR : 0.705
priv TPR : 0.715
unpriv TPR : 0.685
Eq. Opp : 0.030

overall FPR :

[Lambda: 0.100000] [Epoch 160/200] [D_F loss: 0.680010] [D_R loss: 0.543620] [G loss: 0.279422]
VALID DATA
overall TPR : 0.577
priv TPR : 0.616
unpriv TPR : 0.496
Eq. Opp : 0.119

overall FPR : 0.334
priv FPR : 0.324
unpriv FPR : 0.340
diff FPR : 0.016

overall ACC : 0.625
priv ACC : 0.600
unpriv ACC : 0.639
diff ACC : 0.039



TEST DATA
overall TPR : 0.594
priv TPR : 0.603
unpriv TPR : 0.577
Eq. Opp : 0.025

overall FPR : 0.325
priv FPR : 0.291
unpriv FPR : 0.347
diff FPR : 0.057

overall ACC : 0.636
priv ACC : 0.651
unpriv ACC : 0.628
diff ACC : 0.023



DIMP : 0.879
[Lambda: 0.100000] [Epoch 170/200] [D_F loss: 0.678357] [D_R loss: 0.649717] [G loss: 0.227475]
VALID DATA
overall TPR : 0.528
priv TPR : 0.583
unpriv TPR : 0.411
Eq. Opp : 0.172

overall FPR : 0.282
priv FPR : 0.214
unpriv FPR : 0.325
diff FPR : 0.111

overall ACC : 0.632
priv ACC : 0.633
unpriv ACC : 0.631
diff ACC : 0.002



TEST DATA
overall TPR : 0.551
priv TPR : 0.615
unpriv TPR : 0.423
Eq. Opp : 0.192

overall FPR

[Lambda: 0.100000] [Epoch 110/200] [D_F loss: 0.684285] [D_R loss: 0.628715] [G loss: 0.234564]
VALID DATA
overall TPR : 0.505
priv TPR : 0.582
unpriv TPR : 0.328
Eq. Opp : 0.254

overall FPR : 0.302
priv FPR : 0.326
unpriv FPR : 0.288
diff FPR : 0.037

overall ACC : 0.612
priv ACC : 0.532
unpriv ACC : 0.651
diff ACC : 0.119



TEST DATA
overall TPR : 0.484
priv TPR : 0.539
unpriv TPR : 0.362
Eq. Opp : 0.177

overall FPR : 0.249
priv FPR : 0.289
unpriv FPR : 0.224
diff FPR : 0.065

overall ACC : 0.629
priv ACC : 0.571
unpriv ACC : 0.661
diff ACC : 0.091



DIMP : 0.846
[Lambda: 0.100000] [Epoch 120/200] [D_F loss: 0.683957] [D_R loss: 0.640857] [G loss: 0.236264]
VALID DATA
overall TPR : 0.706
priv TPR : 0.730
unpriv TPR : 0.647
Eq. Opp : 0.084

overall FPR : 0.506
priv FPR : 0.617
unpriv FPR : 0.446
diff FPR : 0.171

overall ACC : 0.588
priv ACC : 0.488
unpriv ACC : 0.636
diff ACC : 0.148



TEST DATA
overall TPR : 0.706
priv TPR : 0.710
unpriv TPR : 0.696
Eq. Opp : 0.014

overall FPR

[Lambda: 0.100000] [Epoch 60/200] [D_F loss: 0.688775] [D_R loss: 0.720247] [G loss: 0.958846]
VALID DATA
overall TPR : 0.543
priv TPR : 0.588
unpriv TPR : 0.427
Eq. Opp : 0.161

overall FPR : 0.309
priv FPR : 0.314
unpriv FPR : 0.306
diff FPR : 0.008

overall ACC : 0.623
priv ACC : 0.582
unpriv ACC : 0.642
diff ACC : 0.060



TEST DATA
overall TPR : 0.482
priv TPR : 0.507
unpriv TPR : 0.424
Eq. Opp : 0.083

overall FPR : 0.314
priv FPR : 0.279
unpriv FPR : 0.337
diff FPR : 0.058

overall ACC : 0.590
priv ACC : 0.603
unpriv ACC : 0.583
diff ACC : 0.020



DIMP : 0.795
[Lambda: 0.100000] [Epoch 70/200] [D_F loss: 0.688516] [D_R loss: 0.917322] [G loss: 0.227753]
VALID DATA
overall TPR : 0.592
priv TPR : 0.627
unpriv TPR : 0.504
Eq. Opp : 0.123

overall FPR : 0.374
priv FPR : 0.398
unpriv FPR : 0.360
diff FPR : 0.037

overall ACC : 0.610
priv ACC : 0.563
unpriv ACC : 0.633
diff ACC : 0.070



TEST DATA
overall TPR : 0.567
priv TPR : 0.593
unpriv TPR : 0.504
Eq. Opp : 0.089

overall FPR :

[Lambda: 0.100000] [Epoch 10/200] [D_F loss: 0.583197] [D_R loss: 1.624131] [G loss: 8.654215]
VALID DATA
overall TPR : 0.185
priv TPR : 0.132
unpriv TPR : 0.304
Eq. Opp : 0.172

overall FPR : 0.173
priv FPR : 0.296
unpriv FPR : 0.092
diff FPR : 0.204

overall ACC : 0.537
priv ACC : 0.548
unpriv ACC : 0.530
diff ACC : 0.018



TEST DATA
overall TPR : 0.237
priv TPR : 0.183
unpriv TPR : 0.366
Eq. Opp : 0.183

overall FPR : 0.174
priv FPR : 0.301
unpriv FPR : 0.092
diff FPR : 0.209

overall ACC : 0.551
priv ACC : 0.566
unpriv ACC : 0.543
diff ACC : 0.023



DIMP : 0.422
[Lambda: 0.100000] [Epoch 20/200] [D_F loss: 0.693887] [D_R loss: 0.961105] [G loss: 0.965591]
VALID DATA
overall TPR : 0.552
priv TPR : 0.542
unpriv TPR : 0.574
Eq. Opp : 0.032

overall FPR : 0.361
priv FPR : 0.416
unpriv FPR : 0.324
diff FPR : 0.092

overall ACC : 0.599
priv ACC : 0.580
unpriv ACC : 0.610
diff ACC : 0.030



TEST DATA
overall TPR : 0.539
priv TPR : 0.572
unpriv TPR : 0.460
Eq. Opp : 0.113

overall FPR :

[Lambda: 0.100000] [Epoch 160/200] [D_F loss: 0.679706] [D_R loss: 0.676911] [G loss: 0.211069]
VALID DATA
overall TPR : 0.517
priv TPR : 0.587
unpriv TPR : 0.358
Eq. Opp : 0.229

overall FPR : 0.337
priv FPR : 0.286
unpriv FPR : 0.371
diff FPR : 0.085

overall ACC : 0.597
priv ACC : 0.577
unpriv ACC : 0.608
diff ACC : 0.032



TEST DATA
overall TPR : 0.521
priv TPR : 0.603
unpriv TPR : 0.323
Eq. Opp : 0.281

overall FPR : 0.305
priv FPR : 0.270
unpriv FPR : 0.328
diff FPR : 0.057

overall ACC : 0.613
priv ACC : 0.566
unpriv ACC : 0.637
diff ACC : 0.071



DIMP : 0.622
[Lambda: 0.100000] [Epoch 170/200] [D_F loss: 0.678166] [D_R loss: 0.616668] [G loss: 0.188007]
VALID DATA
overall TPR : 0.656
priv TPR : 0.708
unpriv TPR : 0.537
Eq. Opp : 0.171

overall FPR : 0.448
priv FPR : 0.410
unpriv FPR : 0.473
diff FPR : 0.063

overall ACC : 0.599
priv ACC : 0.569
unpriv ACC : 0.616
diff ACC : 0.047



TEST DATA
overall TPR : 0.654
priv TPR : 0.723
unpriv TPR : 0.488
Eq. Opp : 0.235

overall FPR

TEST DATA
overall TPR : 0.569
priv TPR : 0.550
unpriv TPR : 0.612
Eq. Opp : 0.062

overall FPR : 0.384
priv FPR : 0.538
unpriv FPR : 0.281
diff FPR : 0.257

overall ACC : 0.594
priv ACC : 0.525
unpriv ACC : 0.632
diff ACC : 0.107



DIMP : 0.739
[Lambda: 0.100000] [Epoch 110/200] [D_F loss: 0.681453] [D_R loss: 0.706545] [G loss: 0.189696]
VALID DATA
overall TPR : 0.725
priv TPR : 0.722
unpriv TPR : 0.732
Eq. Opp : 0.010

overall FPR : 0.577
priv FPR : 0.696
unpriv FPR : 0.492
diff FPR : 0.204

overall ACC : 0.550
priv ACC : 0.436
unpriv ACC : 0.612
diff ACC : 0.175



TEST DATA
overall TPR : 0.705
priv TPR : 0.700
unpriv TPR : 0.715
Eq. Opp : 0.015

overall FPR : 0.591
priv FPR : 0.628
unpriv FPR : 0.566
diff FPR : 0.062

overall ACC : 0.550
priv ACC : 0.513
unpriv ACC : 0.571
diff ACC : 0.058



DIMP : 0.957
[Lambda: 0.100000] [Epoch 120/200] [D_F loss: 0.686425] [D_R loss: 0.847187] [G loss: 0.233886]
VALID DATA
overall TPR : 0.536
priv TPR : 0.553
unpriv TPR : 0.490
Eq. Opp : 0.063

In [81]:
#Adult setup
results_dict = {}

for poi_ratio in np.linspace(0.1, 0.5, 5):
    data_train, data_vt = dataset[data_name].split([0.7], shuffle=True)
    data_valid, data_test = data_vt.split([0.5], shuffle=True)

    idx = np.where(data_train.features[:, sens_idx] == 1)[0]
    poison_idx = np.random.permutation(idx)[:int(len(idx) * poi_ratio)]

    #label poisoning
    data_train.labels[poison_idx] = 1 - data_train.labels[poison_idx]

    X_train = torch.FloatTensor(data_train.features[:, ~sens_loc])
    y_train = torch.FloatTensor(data_train.labels)
    s1_train = torch.FloatTensor(data_train.features[:, sens_idx]).view(-1,1)

    X_val = torch.FloatTensor(data_valid.features[:, ~sens_loc])
    y_val = torch.FloatTensor(data_valid.labels)
    s1_val = torch.FloatTensor(data_valid.features[:, sens_idx]).view(-1,1)

    X_test = torch.FloatTensor(data_test.features[:, ~sens_loc])
    y_test = torch.FloatTensor(data_test.labels)
    s1_test = torch.FloatTensor(data_test.features[:, sens_idx]).view(-1,1)

    XS_train = torch.cat([X_train, s1_train.reshape((s1_train.shape[0], 1))], dim=1)
    XS_val = torch.cat([X_val, s1_val.reshape((s1_val.shape[0], 1))], dim=1)
    XS_test = torch.cat([X_test, s1_test.reshape((s1_test.shape[0], 1))], dim=1)

    train_result = []
    train_tensors = Namespace(XS_train = XS_train, y_train = y_train, s1_train = s1_train)
    val_tensors = Namespace(XS_val = XS_val, y_val = y_val, s1_val = s1_val) 
    test_tensors = Namespace(XS_test = XS_test, y_test = y_test, s1_test = s1_test)

    train_opt = Namespace(val=len(y_val), n_epochs=100, k=1, lr_g=1e-4, lr_f=1e-5, lr_r=1e-5)
    seed = 1

    lambda_f_set = [0.2] # Lambda value for the fairness discriminator of FR-Train.
    lambda_r = 0.2 # Lambda value for the robustness discriminator of FR-Train.

    results_dict[poi_ratio] = {}
    
    data_train = Dataset(X_train, s1_train, y_train)
    data_valid = Dataset(X_val, s1_val, y_val)
    data_test = Dataset(X_test, s1_test, y_test)

    trainloader = torch.utils.data.DataLoader(
        data_train,
        batch_size=128,
        shuffle=True,
        )
    validloader = torch.utils.data.DataLoader(
        data_valid,
        batch_size=128,
        shuffle=True,
        drop_last = True
        )
    testloader = torch.utils.data.DataLoader(
        data_test,
        batch_size=128,
        shuffle=True,
        drop_last = True
        )


    for lambda_f in lambda_f_set:
        train_model_adult(results_dict[poi_ratio], train_tensors, val_tensors, test_tensors, train_opt, lambda_f = lambda_f, lambda_r = lambda_r, seed = seed)
    


TypeError: __init__() takes 1 positional argument but 3 were given

In [84]:
#German setup
results_dict = {}

for poi_ratio in np.linspace(0.1, 0.5, 5):
    data_train, data_vt = dataset[data_name].split([0.7], shuffle=True)
    data_valid, data_test = data_vt.split([0.5], shuffle=True)

    idx = np.where(data_train.features[:, sens_idx] == 1)[0]
    poison_idx = np.random.permutation(idx)[:int(len(idx) * poi_ratio)]

    #label poisoning
    data_train.labels[poison_idx] = 1 - data_train.labels[poison_idx]

    X_train = torch.FloatTensor(data_train.features[:, ~sens_loc])
    y_train = torch.FloatTensor(data_train.labels)
    s1_train = torch.FloatTensor(data_train.features[:, sens_idx]).view(-1,1)

    X_val = torch.FloatTensor(data_valid.features[:, ~sens_loc])
    y_val = torch.FloatTensor(data_valid.labels)
    s1_val = torch.FloatTensor(data_valid.features[:, sens_idx]).view(-1,1)

    X_test = torch.FloatTensor(data_test.features[:, ~sens_loc])
    y_test = torch.FloatTensor(data_test.labels)
    s1_test = torch.FloatTensor(data_test.features[:, sens_idx]).view(-1,1)

    XS_train = torch.cat([X_train, s1_train.reshape((s1_train.shape[0], 1))], dim=1)
    XS_val = torch.cat([X_val, s1_val.reshape((s1_val.shape[0], 1))], dim=1)
    XS_test = torch.cat([X_test, s1_test.reshape((s1_test.shape[0], 1))], dim=1)

    train_result = []
    train_tensors = Namespace(XS_train = XS_train, y_train = y_train, s1_train = s1_train)
    val_tensors = Namespace(XS_val = XS_val, y_val = y_val, s1_val = s1_val) 
    test_tensors = Namespace(XS_test = XS_test, y_test = y_test, s1_test = s1_test)

    train_opt = Namespace(val=len(y_val), n_epochs=200, k=1, lr_g=1e-3, lr_f=1e-4, lr_r=1e-4)
    seed = 1

    lambda_f_set = [0.1] # Lambda value for the fairness discriminator of FR-Train.
    lambda_r = 0.1 # Lambda value for the robustness discriminator of FR-Train.

    results_dict[poi_ratio] = {}
    
    data_train = Dataset(X_train, s1_train, y_train)
    data_valid = Dataset(X_val, s1_val, y_val)
    data_test = Dataset(X_test, s1_test, y_test)

    trainloader = torch.utils.data.DataLoader(
        data_train,
        batch_size=128,
        shuffle=True,
        )
    validloader = torch.utils.data.DataLoader(
        data_valid,
        batch_size=128,
        shuffle=True,
        drop_last = True
        )
    testloader = torch.utils.data.DataLoader(
        data_test,
        batch_size=128,
        shuffle=True,
        drop_last = True
        )


    for lambda_f in lambda_f_set:
        train_model_adult(results_dict[poi_ratio], train_tensors, val_tensors, test_tensors, train_opt, lambda_f = lambda_f, lambda_r = lambda_r, seed = seed)
    


[Lambda: 0.100000] [Epoch 10/200] [D_F loss: 30.000002] [D_R loss: 52.477028] [G loss: 33.262245]
VALID DATA
overall TPR : 0.000
priv TPR : 0.000
unpriv TPR : 0.000
Eq. Opp : 0.000

overall FPR : 0.000
priv FPR : 0.000
unpriv FPR : 0.000
diff FPR : 0.000

overall ACC : 0.680
priv ACC : 0.678
unpriv ACC : 0.684
diff ACC : 0.006



TEST DATA
overall TPR : 0.000
priv TPR : 0.000
unpriv TPR : 0.000
Eq. Opp : 0.000

overall FPR : 0.000
priv FPR : 0.000
unpriv FPR : 0.000
diff FPR : 0.000

overall ACC : 0.711
priv ACC : 0.719
unpriv ACC : 0.692
diff ACC : 0.027



DIMP : nan
[Lambda: 0.100000] [Epoch 20/200] [D_F loss: 1.978784] [D_R loss: 1.269961] [G loss: 0.630586]
VALID DATA
overall TPR : 0.385
priv TPR : 0.500
unpriv TPR : 0.345
Eq. Opp : 0.155

overall FPR : 0.292
priv FPR : 0.266
unpriv FPR : 0.360
diff FPR : 0.094

overall ACC : 0.609
priv ACC : 0.613
unpriv ACC : 0.600
diff ACC : 0.013



TEST DATA
overall TPR : 0.371
priv TPR : 0.333
unpriv TPR : 0.391
Eq. Opp : 0.058

overall FPR 

[Lambda: 0.100000] [Epoch 160/200] [D_F loss: 0.735269] [D_R loss: 1.075231] [G loss: -0.163821]
VALID DATA
overall TPR : 0.537
priv TPR : 0.333
unpriv TPR : 0.621
Eq. Opp : 0.287

overall FPR : 0.322
priv FPR : 0.367
unpriv FPR : 0.222
diff FPR : 0.144

overall ACC : 0.633
priv ACC : 0.629
unpriv ACC : 0.641
diff ACC : 0.012



TEST DATA
overall TPR : 0.667
priv TPR : 0.846
unpriv TPR : 0.565
Eq. Opp : 0.281

overall FPR : 0.304
priv FPR : 0.258
unpriv FPR : 0.400
diff FPR : 0.142

overall ACC : 0.688
priv ACC : 0.694
unpriv ACC : 0.674
diff ACC : 0.020



DIMP : 0.638
[Lambda: 0.100000] [Epoch 170/200] [D_F loss: 0.674397] [D_R loss: 0.850940] [G loss: -0.124231]
VALID DATA
overall TPR : 0.450
priv TPR : 0.385
unpriv TPR : 0.481
Eq. Opp : 0.097

overall FPR : 0.352
priv FPR : 0.400
unpriv FPR : 0.250
diff FPR : 0.150

overall ACC : 0.586
priv ACC : 0.563
unpriv ACC : 0.634
diff ACC : 0.071



TEST DATA
overall TPR : 0.629
priv TPR : 0.714
unpriv TPR : 0.571
Eq. Opp : 0.143

overall F

[Lambda: 0.100000] [Epoch 110/200] [D_F loss: 0.660626] [D_R loss: 0.555011] [G loss: 0.091700]
VALID DATA
overall TPR : 0.196
priv TPR : 0.143
unpriv TPR : 0.219
Eq. Opp : 0.076

overall FPR : 0.268
priv FPR : 0.315
unpriv FPR : 0.179
diff FPR : 0.136

overall ACC : 0.539
priv ACC : 0.512
unpriv ACC : 0.595
diff ACC : 0.084



TEST DATA
overall TPR : 0.326
priv TPR : 0.400
unpriv TPR : 0.286
Eq. Opp : 0.114

overall FPR : 0.176
priv FPR : 0.161
unpriv FPR : 0.217
diff FPR : 0.056

overall ACC : 0.656
priv ACC : 0.667
unpriv ACC : 0.632
diff ACC : 0.035



DIMP : 0.691
[Lambda: 0.100000] [Epoch 120/200] [D_F loss: 0.676087] [D_R loss: 0.636975] [G loss: 7.919653]
VALID DATA
overall TPR : 0.317
priv TPR : 0.308
unpriv TPR : 0.321
Eq. Opp : 0.014

overall FPR : 0.253
priv FPR : 0.281
unpriv FPR : 0.200
diff FPR : 0.081

overall ACC : 0.609
priv ACC : 0.588
unpriv ACC : 0.651
diff ACC : 0.063



TEST DATA
overall TPR : 0.342
priv TPR : 0.385
unpriv TPR : 0.320
Eq. Opp : 0.065

overall FPR

[Lambda: 0.100000] [Epoch 60/200] [D_F loss: 0.980054] [D_R loss: 0.720563] [G loss: 0.472730]
VALID DATA
overall TPR : 0.343
priv TPR : 0.263
unpriv TPR : 0.438
Eq. Opp : 0.174

overall FPR : 0.226
priv FPR : 0.246
unpriv FPR : 0.179
diff FPR : 0.068

overall ACC : 0.656
priv ACC : 0.691
unpriv ACC : 0.596
diff ACC : 0.096



TEST DATA
overall TPR : 0.475
priv TPR : 0.200
unpriv TPR : 0.640
Eq. Opp : 0.440

overall FPR : 0.227
priv FPR : 0.217
unpriv FPR : 0.250
diff FPR : 0.033

overall ACC : 0.680
priv ACC : 0.741
unpriv ACC : 0.558
diff ACC : 0.183



DIMP : 0.682
[Lambda: 0.100000] [Epoch 70/200] [D_F loss: 0.833967] [D_R loss: 0.863072] [G loss: 1.720533]
VALID DATA
overall TPR : 0.441
priv TPR : 0.333
unpriv TPR : 0.562
Eq. Opp : 0.229

overall FPR : 0.277
priv FPR : 0.281
unpriv FPR : 0.267
diff FPR : 0.015

overall ACC : 0.648
priv ACC : 0.688
unpriv ACC : 0.583
diff ACC : 0.104



TEST DATA
overall TPR : 0.475
priv TPR : 0.278
unpriv TPR : 0.636
Eq. Opp : 0.359

overall FPR :

[Lambda: 0.100000] [Epoch 10/200] [D_F loss: 28.566610] [D_R loss: 52.603184] [G loss: 43.777225]
VALID DATA
overall TPR : 0.000
priv TPR : 0.000
unpriv TPR : 0.000
Eq. Opp : 0.000

overall FPR : 0.000
priv FPR : 0.000
unpriv FPR : 0.000
diff FPR : 0.000

overall ACC : 0.727
priv ACC : 0.741
unpriv ACC : 0.702
diff ACC : 0.039



TEST DATA
overall TPR : 0.000
priv TPR : 0.000
unpriv TPR : 0.000
Eq. Opp : 0.000

overall FPR : 0.000
priv FPR : 0.000
unpriv FPR : 0.000
diff FPR : 0.000

overall ACC : 0.656
priv ACC : 0.652
unpriv ACC : 0.667
diff ACC : 0.014



DIMP : nan
[Lambda: 0.100000] [Epoch 20/200] [D_F loss: 1.697839] [D_R loss: 1.206487] [G loss: 2.122984]
VALID DATA
overall TPR : 0.333
priv TPR : 0.500
unpriv TPR : 0.200
Eq. Opp : 0.300

overall FPR : 0.250
priv FPR : 0.250
unpriv FPR : 0.250
diff FPR : 0.000

overall ACC : 0.633
priv ACC : 0.613
unpriv ACC : 0.667
diff ACC : 0.054



TEST DATA
overall TPR : 0.349
priv TPR : 0.385
unpriv TPR : 0.333
Eq. Opp : 0.051

overall FPR 

[Lambda: 0.100000] [Epoch 160/200] [D_F loss: 0.804397] [D_R loss: 1.000725] [G loss: -0.140533]
VALID DATA
overall TPR : 0.541
priv TPR : 0.571
unpriv TPR : 0.522
Eq. Opp : 0.050

overall FPR : 0.440
priv FPR : 0.491
unpriv FPR : 0.353
diff FPR : 0.138

overall ACC : 0.555
priv ACC : 0.512
unpriv ACC : 0.625
diff ACC : 0.113



TEST DATA
overall TPR : 0.413
priv TPR : 0.467
unpriv TPR : 0.387
Eq. Opp : 0.080

overall FPR : 0.390
priv FPR : 0.387
unpriv FPR : 0.400
diff FPR : 0.013

overall ACC : 0.539
priv ACC : 0.538
unpriv ACC : 0.543
diff ACC : 0.005



DIMP : 0.903
[Lambda: 0.100000] [Epoch 170/200] [D_F loss: 0.712114] [D_R loss: 0.595027] [G loss: -0.090835]
VALID DATA
overall TPR : 0.514
priv TPR : 0.500
unpriv TPR : 0.524
Eq. Opp : 0.024

overall FPR : 0.398
priv FPR : 0.459
unpriv FPR : 0.281
diff FPR : 0.178

overall ACC : 0.578
priv ACC : 0.537
unpriv ACC : 0.652
diff ACC : 0.116



TEST DATA
overall TPR : 0.455
priv TPR : 0.667
unpriv TPR : 0.375
Eq. Opp : 0.292

overall F

[Lambda: 0.100000] [Epoch 110/200] [D_F loss: 0.692701] [D_R loss: 0.867965] [G loss: 0.779713]
VALID DATA
overall TPR : 0.531
priv TPR : 0.375
unpriv TPR : 0.583
Eq. Opp : 0.208

overall FPR : 0.375
priv FPR : 0.424
unpriv FPR : 0.267
diff FPR : 0.158

overall ACC : 0.602
priv ACC : 0.578
unpriv ACC : 0.658
diff ACC : 0.080



TEST DATA
overall TPR : 0.550
priv TPR : 0.571
unpriv TPR : 0.538
Eq. Opp : 0.033

overall FPR : 0.398
priv FPR : 0.450
unpriv FPR : 0.286
diff FPR : 0.164

overall ACC : 0.586
priv ACC : 0.547
unpriv ACC : 0.667
diff ACC : 0.120



DIMP : 0.799
[Lambda: 0.100000] [Epoch 120/200] [D_F loss: 0.690243] [D_R loss: 0.930588] [G loss: 1.159359]
VALID DATA
overall TPR : 0.545
priv TPR : 0.429
unpriv TPR : 0.577
Eq. Opp : 0.148

overall FPR : 0.463
priv FPR : 0.500
unpriv FPR : 0.387
diff FPR : 0.113

overall ACC : 0.539
priv ACC : 0.522
unpriv ACC : 0.579
diff ACC : 0.057



TEST DATA
overall TPR : 0.591
priv TPR : 0.600
unpriv TPR : 0.586
Eq. Opp : 0.014

overall FPR

In [23]:
results_dict

{0.1: {'EqOdds': tensor(0.0716, device='cuda:0'),
  'Acc': tensor(0.5312, device='cuda:0'),
  'DISP': tensor(0.7317, device='cuda:0', grad_fn=<DivBackward0>),
  'EqOpp': tensor(0.0345, device='cuda:0'),
  'Acc_diff': tensor(0.0959, device='cuda:0'),
  'total_tpr': tensor(0.1098, device='cuda:0'),
  'total_fpr': tensor(0.0985, device='cuda:0'),
  'total_tpr_priv': tensor(0.1006, device='cuda:0'),
  'total_tpr_unpriv': tensor(0.1351, device='cuda:0'),
  'total_fpr_priv': tensor(0.1215, device='cuda:0'),
  'total_fpr_unpriv': tensor(0.0845, device='cuda:0')},
 0.2: {'EqOdds': tensor(0.1850, device='cuda:0'),
  'Acc': tensor(0.5279, device='cuda:0'),
  'DISP': tensor(0.4840, device='cuda:0', grad_fn=<DivBackward0>),
  'EqOpp': tensor(0.0830, device='cuda:0'),
  'Acc_diff': tensor(0.0441, device='cuda:0'),
  'total_tpr': tensor(0.1214, device='cuda:0'),
  'total_fpr': tensor(0.1134, device='cuda:0'),
  'total_tpr_priv': tensor(0.0973, device='cuda:0'),
  'total_tpr_unpriv': tensor(0.1803, d

In [45]:
with open(f'results/{data_name}/robust_fr_dict.pkl', 'wb') as f:
    pickle.dump(results_dict, f)

In [82]:
def train_model_adult(results_dict, train_tensors, val_tensors, test_tensors, train_opt, lambda_f, lambda_r, seed):
    """
      Trains FR-Train by using the classes in FRTrain_arch.py.
      
      Args:
        train_tensors: Training data.
        val_tensors: Clean validation data.
        test_tensors: Test data.
        train_opt: Options for the training. It currently contains size of validation set, 
                number of epochs, generator/discriminator update ratio, and learning rates.
        lambda_f: The tuning knob for L_2 (ref: FR-Train paper, Section 3.3).
        lambda_r: The tuning knob for L_3 (ref: FR-Train paper, Section 3.3).
        seed: An integer value for specifying torch random seed.
        
      Returns:
        Information about the tuning knobs (lambda_f, lambda_r),
        the test accuracy of the trained model, and disparate impact of the trained model.
    """
    # Initializes generator and discriminator
    Tensor = torch.FloatTensor

    XS_train = train_tensors.XS_train
    y_train = train_tensors.y_train
    s1_train = train_tensors.s1_train

    XS_val = val_tensors.XS_val
    y_val = val_tensors.y_val
    s1_val = val_tensors.s1_val

    XS_test = test_tensors.XS_test
    y_test = test_tensors.y_test
    s1_test = test_tensors.s1_test

    # Saves return values here
    test_result = [] 

    val = train_opt.val # Number of data points in validation set
    k = train_opt.k     # Update ratio of generator and discriminator (1:k training).
    n_epochs = train_opt.n_epochs  # Number of training epoch

    # The loss values of each component will be saved in the following lists. 
    # We can draw epoch-loss graph by the following lists, if necessary.
    g_losses =[]
    d_f_losses = []
    d_r_losses = []
    clean_test_result = []

    bce_loss = torch.nn.BCELoss()

    # Initializes generator and discriminator
    generator = Generator(input_size, latent_size).cuda()
    discriminator_F = DiscriminatorF().cuda()
    discriminator_R = DiscriminatorR(input_size, latent_size).cuda()

    # Initializes weights
    torch.manual_seed(seed)
    generator.apply(weights_init_normal)
    discriminator_F.apply(weights_init_normal)
    discriminator_R.apply(weights_init_normal)

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=train_opt.lr_g)
    optimizer_D_F = torch.optim.SGD(discriminator_F.parameters(), lr=train_opt.lr_f)
    optimizer_D_R = torch.optim.SGD(discriminator_R.parameters(), lr=train_opt.lr_r)

    for epoch in range(1, n_epochs + 1):
        cnt = 0
    #     for x, s, y, x_val, s_val, y_val in trainloader:
        generator.train()
        for x, s, y in trainloader:
            cnt += 1
            x, s, y = x.cuda(), s.cuda(), y.cuda()
            x_val, s_val, y_val = iter(validloader).next()
            x_val, s_val, y_val = x_val.cuda(), s_val.cuda(), y_val.cuda()
            # -------------------
            #  Forwards Generator
            # -------------------
            if cnt % k == 0 or epoch < 10:
                optimizer_G.zero_grad()

            gen_y = generator(x)
            gen_data = torch.cat([x, gen_y.reshape((gen_y.shape[0], 1))], dim=1)


            # -------------------------------
            #  Trains Fairness Discriminator
            # -------------------------------

            optimizer_D_F.zero_grad()

            # Discriminator_F tries to distinguish the sensitive groups by using the output of the generator.
            d_f_loss = bce_loss(discriminator_F(gen_y.detach()), s)
            d_f_loss.backward()
            d_f_losses.append(d_f_loss)
            optimizer_D_F.step()

            # ---------------------------------
            #  Trains Robustness Discriminator
            # ---------------------------------
            optimizer_D_R.zero_grad()

            XSY_val_data = torch.cat([x_val, y_val.view(-1,1)], dim = 1).cuda()
            fake = Variable(Tensor(len(x), 1).fill_(0.0), requires_grad=False).cuda()
            clean = Variable(Tensor(len(x_val), 1).fill_(1.0), requires_grad=False).cuda()

            # Discriminator_R tries to distinguish whether the input is from the validation data or the generated data from generator.
            clean_loss =  bce_loss(discriminator_R(XSY_val_data), clean)
            poison_loss = bce_loss(discriminator_R(gen_data.detach()), fake)
            d_r_loss = 0.5 * (clean_loss + poison_loss)

            d_r_loss.backward()
            d_r_losses.append(d_r_loss)
            optimizer_D_R.step()

            # ---------------------
            #  Updates Generator
            # ---------------------
            if epoch < 10 :
                g_loss = bce_loss((F.tanh(gen_y)+1)/2, y)
                g_loss.backward()
                g_losses.append(g_loss)
                optimizer_G.step()
            elif cnt % k == 0:
                r_decision = discriminator_R(gen_data)
                generated = Variable(Tensor(len(gen_data), 1).fill_(0.0), requires_grad=False).cuda()
                
                r_gen = bce_loss(r_decision, generated)

                # ---------------------------------
                #  Re-weights using output of D_R
                # ---------------------------------
#                 if epoch % 10 == 0:
                
                if epoch < 10:
                    r_weight = torch.ones_like(y, requires_grad=False).float().cuda()
                else:
                    r_ones = torch.ones_like(y, requires_grad=False).float().cuda()
                    loss_ratio = (g_losses[-1]/d_r_losses[-1]).detach()
                    a = 1/(1+torch.exp(-(loss_ratio-3)))
                    b = 1-a
                    r_weight_tmp = r_decision.detach().squeeze()
                    r_weight = a * r_weight_tmp + b * r_ones

                f_cost = F.binary_cross_entropy(discriminator_F(gen_y), s, reduction="none").squeeze()
                g_cost = F.binary_cross_entropy_with_logits(gen_y, y, reduction="none").squeeze()

                f_gen = torch.mean(f_cost*r_weight)
                g_loss = (1-lambda_f-lambda_r) * torch.mean(g_cost*r_weight) - lambda_f * f_gen -  lambda_r * r_gen 

                g_loss.backward()
                optimizer_G.step()

            g_losses.append(g_loss)

        if epoch % 10 == 0:
            print(
                    "[Lambda: %1f] [Epoch %d/%d] [D_F loss: %f] [D_R loss: %f] [G loss: %f]"
                    % (lambda_f, epoch, n_epochs, d_f_losses[-1], d_r_losses[-1], g_losses[-1])
                )

            tp_priv, tn_priv, fp_priv, fn_priv, \
            tp_unpriv, tn_unpriv, fp_unpriv, fn_unpriv = 0, 0, 0, 0, 0, 0, 0, 0

            for x_val, a_val, y_val in validloader:
                x_val, a_val, y_val = x_val.cuda(), a_val.cuda(), y_val.cuda()

                priv_idx = (a_val==1).squeeze()
                positive_idx = y_val==1

                pred_test = generator(x_val).view(-1)

                pred_test[pred_test>0] = 1
                pred_test[pred_test<=0] = 0

                test_lb_priv = y_val[priv_idx].view(-1)
                test_lb_unpriv = y_val[~priv_idx].view(-1)

                pred_priv = pred_test[priv_idx]
                pred_unpriv = pred_test[~priv_idx]

                y_val = y_val.cpu().detach().numpy()
                test_lb_priv = test_lb_priv.cpu().detach().numpy()
                test_lb_unpriv = test_lb_unpriv.cpu().detach().numpy()

                try:
                    pred_priv = pred_priv.argmax(1)
                except:
                    pass
                try:
                    pred_unpriv = pred_unpriv.argmax(1)
                except:
                    pass


                tp_priv += sum(pred_priv[test_lb_priv == 1] == 1)
                fp_priv += sum(pred_priv[test_lb_priv == 0] == 1)
                tn_priv += sum(pred_priv[test_lb_priv == 0] == 0)
                fn_priv += sum(pred_priv[test_lb_priv == 1] == 0)

                tp_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 1)
                fp_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 1)
                tn_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 0)
                fn_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 0)

            tpr_overall = (tp_priv + tp_unpriv)/(tp_priv + tp_unpriv + fn_priv + fn_unpriv).float().item()
            tpr_priv = (tp_unpriv)/(tp_unpriv + fn_unpriv).float().item()
            tpr_unpriv = (tp_priv)/(tp_priv + fn_priv).float().item()

            fpr_overall = (fp_priv + fp_unpriv)/(tn_priv + tn_unpriv + fp_priv + fp_unpriv).float().item()
            fpr_unpriv = (fp_unpriv)/(tn_unpriv + fp_unpriv).float().item()
            fpr_priv = (fp_priv)/(tn_priv + fp_priv).float().item()

            acc_overall = (tp_priv + tn_priv + tp_unpriv + tn_unpriv)/(tp_priv + tn_priv + tp_unpriv + tn_unpriv + \
                                                                      fp_priv + fn_priv + fp_unpriv + fn_unpriv).float().item()
            acc_priv = (tp_priv + tn_priv)/(tp_priv + tn_priv + fp_priv + fn_priv).float().item()
            acc_unpriv = (tp_unpriv + tn_unpriv)/(tp_unpriv + tn_unpriv + fp_unpriv + fn_unpriv).float().item()


            print("VALID DATA")
            print('overall TPR : {0:.3f}'.format( tpr_overall))
            print('priv TPR : {0:.3f}'.format( tpr_priv))
            print('unpriv TPR : {0:.3f}'.format( tpr_unpriv))
            print('Eq. Opp : {0:.3f}'.format( abs(tpr_unpriv - tpr_priv)))
            print()
            print('overall FPR : {0:.3f}'.format( fpr_overall))
            print('priv FPR : {0:.3f}'.format( fpr_priv))
            print('unpriv FPR : {0:.3f}'.format( fpr_unpriv))
            print('diff FPR : {0:.3f}'.format( abs(fpr_unpriv-fpr_priv)))
            print()
            print('overall ACC : {0:.3f}'.format( acc_overall))
            print('priv ACC : {0:.3f}'.format( acc_priv))
            print('unpriv ACC : {0:.3f}'.format( acc_unpriv)) 
            print('diff ACC : {0:.3f}\n\n\n'.format( abs(acc_unpriv-acc_priv)))
            
            tp_priv, tn_priv, fp_priv, fn_priv, \
            tp_unpriv, tn_unpriv, fp_unpriv, fn_unpriv = 0, 0, 0, 0, 0, 0, 0, 0
            
            y_hat1_priv, y_hat1_unpriv, priv, unpriv = [], [],[], []

            
            generator.eval()
            for x_val, a_val, y_val in testloader:
                x_val, a_val, y_val = x_val.cuda(), a_val.cuda(), y_val.cuda()

                priv_idx = (a_val==1).squeeze()
                positive_idx = y_val==1

                pred_test = generator(x_val).view(-1)

                pred_test[pred_test>0] = 1
                pred_test[pred_test<=0] = 0

                test_lb_priv = y_val[priv_idx].view(-1)
                test_lb_unpriv = y_val[~priv_idx].view(-1)

                pred_priv = pred_test[priv_idx]
                pred_unpriv = pred_test[~priv_idx]

                y_val = y_val.cpu().detach().numpy()
                test_lb_priv = test_lb_priv.cpu().detach().numpy()
                test_lb_unpriv = test_lb_unpriv.cpu().detach().numpy()

                try:
                    pred_priv = pred_priv.argmax(1)
                except:
                    pass
                try:
                    pred_unpriv = pred_unpriv.argmax(1)
                except:
                    pass


                tp_priv += sum(pred_priv[test_lb_priv == 1] == 1)
                fp_priv += sum(pred_priv[test_lb_priv == 0] == 1)
                tn_priv += sum(pred_priv[test_lb_priv == 0] == 0)
                fn_priv += sum(pred_priv[test_lb_priv == 1] == 0)

                tp_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 1)
                fp_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 1)
                tn_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 0)
                fn_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 0)
                
                y_hat1_priv.append(pred_priv)
                y_hat1_unpriv.append(pred_unpriv)
                priv.append(priv_idx)
                unpriv.append(~priv_idx)

            tpr_overall = (tp_priv + tp_unpriv)/(tp_priv + tp_unpriv + fn_priv + fn_unpriv).float().item()
            tpr_priv = (tp_unpriv)/(tp_unpriv + fn_unpriv).float().item()
            tpr_unpriv = (tp_priv)/(tp_priv + fn_priv).float().item()

            fpr_overall = (fp_priv + fp_unpriv)/(tn_priv + tn_unpriv + fp_priv + fp_unpriv).float().item()
            fpr_unpriv = (fp_unpriv)/(tn_unpriv + fp_unpriv).float().item()
            fpr_priv = (fp_priv)/(tn_priv + fp_priv).float().item()

            acc_overall = (tp_priv + tn_priv + tp_unpriv + tn_unpriv)/(tp_priv + tn_priv + tp_unpriv + tn_unpriv + \
                                                                      fp_priv + fn_priv + fp_unpriv + fn_unpriv).float().item()
            acc_priv = (tp_priv + tn_priv)/(tp_priv + tn_priv + fp_priv + fn_priv).float().item()
            acc_unpriv = (tp_unpriv + tn_unpriv)/(tp_unpriv + tn_unpriv + fp_unpriv + fn_unpriv).float().item()
            
            y_hat1_priv = sum(torch.cat(y_hat1_priv)).float()/sum(torch.cat(priv))
            y_hat1_unpriv = sum(torch.cat(y_hat1_unpriv)).float()/sum(torch.cat(unpriv))
            
            dimp = min(y_hat1_priv/y_hat1_unpriv, y_hat1_unpriv/y_hat1_priv)

            print("TEST DATA")
            print('overall TPR : {0:.3f}'.format( tpr_overall))
            print('priv TPR : {0:.3f}'.format( tpr_priv))
            print('unpriv TPR : {0:.3f}'.format( tpr_unpriv))
            print('Eq. Opp : {0:.3f}'.format( abs(tpr_unpriv - tpr_priv)))
            print()
            print('overall FPR : {0:.3f}'.format( fpr_overall))
            print('priv FPR : {0:.3f}'.format( fpr_priv))
            print('unpriv FPR : {0:.3f}'.format( fpr_unpriv))
            print('diff FPR : {0:.3f}'.format( abs(fpr_unpriv-fpr_priv)))
            print()
            print('overall ACC : {0:.3f}'.format( acc_overall))
            print('priv ACC : {0:.3f}'.format( acc_priv))
            print('unpriv ACC : {0:.3f}'.format( acc_unpriv)) 
            print('diff ACC : {0:.3f}\n\n\n'.format( abs(acc_unpriv-acc_priv)))
            
            print('DIMP : {0:.3f}'.format( dimp)) 
            
        

            results_dict['EqOdds'] =abs(fpr_unpriv-fpr_priv) + abs(tpr_unpriv - tpr_priv)
            results_dict['Acc'] = acc_overall
            results_dict['DISP'] = dimp
            results_dict['EqOpp']= abs(tpr_unpriv - tpr_priv)
            results_dict['Acc_diff'] = abs(acc_unpriv-acc_priv)

            results_dict['total_tpr'] =tpr_overall
            results_dict['total_fpr'] =fpr_overall

            results_dict['total_tpr_priv'] =tpr_priv
            results_dict['total_tpr_unpriv'] =tpr_unpriv
            results_dict['total_fpr_priv'] =fpr_priv
            results_dict['total_fpr_unpriv'] =fpr_unpriv


        #     torch.save(generator.state_dict(), './FR-Train_on_poi_synthetic.pth')

#     tmp = test_model(generator, XS_test, y_test, s1_test)
#     test_result.append([lambda_f, lambda_r, tmp[0].item(), tmp[1]])
    
#     return test_result

In [77]:
Tensor = torch.FloatTensor

XS_train = train_tensors.XS_train
y_train = train_tensors.y_train
s1_train = train_tensors.s1_train

XS_val = val_tensors.XS_val
y_val = val_tensors.y_val
s1_val = val_tensors.s1_val

XS_test = test_tensors.XS_test
y_test = test_tensors.y_test
s1_test = test_tensors.s1_test

# Saves return values here
test_result = [] 

val = train_opt.val # Number of data points in validation set
k = train_opt.k     # Update ratio of generator and discriminator (1:k training).
n_epochs = train_opt.n_epochs  # Number of training epoch

# Changes the input validation data to an appropriate shape for the training
XSY_val = torch.cat([XS_val, y_val.reshape((y_val.shape[0], 1))], dim=1)  

# The loss values of each component will be saved in the following lists. 
# We can draw epoch-loss graph by the following lists, if necessary.
g_losses =[]
d_f_losses = []
d_r_losses = []
clean_test_result = []

bce_loss = torch.nn.BCELoss()

# Initializes generator and discriminator
generator = Generator(input_size, latent_size).cuda()
discriminator_F = DiscriminatorF().cuda()
discriminator_R = DiscriminatorR(input_size, latent_size).cuda()

# Initializes weights
torch.manual_seed(seed)
generator.apply(weights_init_normal)
discriminator_F.apply(weights_init_normal)
discriminator_R.apply(weights_init_normal)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=train_opt.lr_g)
optimizer_D_F = torch.optim.SGD(discriminator_F.parameters(), lr=train_opt.lr_f)
optimizer_D_R = torch.optim.SGD(discriminator_R.parameters(), lr=train_opt.lr_r)

In [78]:

# valid = Variable(Tensor(train_len, 1).fill_(1.0), requires_grad=False)
# generated = Variable(Tensor(train_len, 1).fill_(0.0), requires_grad=False)
# fake = Variable(Tensor(train_len, 1).fill_(0.0), requires_grad=False)
# clean = Variable(Tensor(val_len, 1).fill_(1.0), requires_grad=False)


# r_weight = torch.ones_like(y_train, requires_grad=False).float()
# r_ones = torch.ones_like(y_train, requires_grad=False).float()

for epoch in range(1, n_epochs):
#     for x, s, y, x_val, s_val, y_val in trainloader:
    for x, s, y in trainloader:
        x, s, y = x.cuda(), s.cuda(), y.cuda()
        x_val, s_val, y_val = iter(validloader).next()
        x_val, s_val, y_val = x_val.cuda(), s_val.cuda(), y_val.cuda()
        # -------------------
        #  Forwards Generator
        # -------------------
        if epoch % k == 0 or epoch < 10:
            optimizer_G.zero_grad()

        gen_y = generator(x)
        gen_data = torch.cat([x, gen_y.reshape((gen_y.shape[0], 1))], dim=1)


        # -------------------------------
        #  Trains Fairness Discriminator
        # -------------------------------

        optimizer_D_F.zero_grad()

        # Discriminator_F tries to distinguish the sensitive groups by using the output of the generator.
        d_f_loss = bce_loss(discriminator_F(gen_y.detach()), s)
        d_f_loss.backward()
        d_f_losses.append(d_f_loss)
        optimizer_D_F.step()

        # ---------------------------------
        #  Trains Robustness Discriminator
        # ---------------------------------
        optimizer_D_R.zero_grad()

        XSY_val_data = torch.cat([x_val, y_val.view(-1,1)], dim = 1).cuda()
        fake = Variable(Tensor(len(x), 1).fill_(0.0), requires_grad=False).cuda()
        clean = Variable(Tensor(len(x_val), 1).fill_(1.0), requires_grad=False).cuda()
        
        # Discriminator_R tries to distinguish whether the input is from the validation data or the generated data from generator.
        clean_loss =  bce_loss(discriminator_R(XSY_val_data), clean)
        poison_loss = bce_loss(discriminator_R(gen_data.detach()), fake)
        d_r_loss = 0.5 * (clean_loss + poison_loss)

        d_r_loss.backward()
        d_r_losses.append(d_r_loss)
        optimizer_D_R.step()

        # ---------------------
        #  Updates Generator
        # ---------------------
        if epoch < 10 :
            g_loss = 0.1 * bce_loss((F.tanh(gen_y)+1)/2, y)
            g_loss.backward()
            g_losses.append(g_loss)
            optimizer_G.step()
        elif epoch % k == 0:
            generated = Variable(Tensor(len(x), 1).fill_(0.0), requires_grad=False).cuda()
            
            r_decision = discriminator_R(gen_data)
            r_gen = bce_loss(r_decision, generated)

            # ---------------------------------
            #  Re-weights using output of D_R
            # ---------------------------------
#             if epoch % 10 == 0:
            r_ones = torch.ones_like(y, requires_grad=False).float().cuda()

            loss_ratio = (g_losses[-1]/d_r_losses[-1]).detach()
            a = 1/(1+torch.exp(-(loss_ratio-3)))
            b = 1-a
            r_weight_tmp = r_decision.detach().squeeze()
            r_weight = a * r_weight_tmp + b * r_ones

            f_cost = F.binary_cross_entropy(discriminator_F(gen_y), s, reduction="none").squeeze()
            g_cost = F.binary_cross_entropy_with_logits(gen_y.squeeze(), (y.squeeze()+1)/2, reduction="none").squeeze()

            f_gen = torch.mean(f_cost*r_weight)
            g_loss = (1-lambda_f-lambda_r) * torch.mean(g_cost*r_weight) - lambda_f * f_gen -  lambda_r * r_gen 

            g_loss.backward()
            optimizer_G.step()


        g_losses.append(g_loss)

    if epoch % 5 == 0:
        print(
                "[Lambda: %1f] [Epoch %d/%d] [D_F loss: %f] [D_R loss: %f] [G loss: %f]"
                % (lambda_f, epoch, n_epochs, d_f_losses[-1], d_r_losses[-1], g_losses[-1])
            )

        tp_priv, tn_priv, fp_priv, fn_priv, \
        tp_unpriv, tn_unpriv, fp_unpriv, fn_unpriv = 0, 0, 0, 0, 0, 0, 0, 0

        for x_val, a_val, y_val in validloader:
            x_val, a_val, y_val = x_val.cuda(), a_val.cuda(), y_val.cuda()

            priv_idx = (a_val==1).squeeze()
            positive_idx = y_val==1

            pred_test = generator(x_val).view(-1)
            
            pred_test[pred_test>0] = 1
            pred_test[pred_test<=0] = 0

            test_lb_priv = y_val[priv_idx].view(-1)
            test_lb_unpriv = y_val[~priv_idx].view(-1)

            pred_priv = pred_test[priv_idx]
            pred_unpriv = pred_test[~priv_idx]

            y_val = y_val.cpu().detach().numpy()
            test_lb_priv = test_lb_priv.cpu().detach().numpy()
            test_lb_unpriv = test_lb_unpriv.cpu().detach().numpy()

            try:
                pred_priv = pred_priv.argmax(1)
            except:
                pass
            try:
                pred_unpriv = pred_unpriv.argmax(1)
            except:
                pass


            tp_priv += sum(pred_priv[test_lb_priv == 1] == 1)
            fp_priv += sum(pred_priv[test_lb_priv == 0] == 1)
            tn_priv += sum(pred_priv[test_lb_priv == 0] == 0)
            fn_priv += sum(pred_priv[test_lb_priv == 1] == 0)

            tp_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 1)
            fp_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 1)
            tn_unpriv += sum(pred_unpriv[test_lb_unpriv == 0] == 0)
            fn_unpriv += sum(pred_unpriv[test_lb_unpriv == 1] == 0)

        tpr_overall = (tp_priv + tp_unpriv)/(tp_priv + tp_unpriv + fn_priv + fn_unpriv).float().item()
        tpr_priv = (tp_unpriv)/(tp_unpriv + fn_unpriv).float().item()
        tpr_unpriv = (tp_priv)/(tp_priv + fn_priv).float().item()

        fpr_overall = (fp_priv + fp_unpriv)/(tn_priv + tn_unpriv + fp_priv + fp_unpriv).float().item()
        fpr_unpriv = (fp_unpriv)/(tn_unpriv + fp_unpriv).float().item()
        fpr_priv = (fp_priv)/(tn_priv + fp_priv).float().item()

        acc_overall = (tp_priv + tn_priv + tp_unpriv + tn_unpriv)/(tp_priv + tn_priv + tp_unpriv + tn_unpriv + \
                                                                  fp_priv + fn_priv + fp_unpriv + fn_unpriv).float().item()
        acc_priv = (tp_priv + tn_priv)/(tp_priv + tn_priv + fp_priv + fn_priv).float().item()
        acc_unpriv = (tp_unpriv + tn_unpriv)/(tp_unpriv + tn_unpriv + fp_unpriv + fn_unpriv).float().item()


        print()
        print('overall TPR : {0:.3f}'.format( tpr_overall))
        print('priv TPR : {0:.3f}'.format( tpr_priv))
        print('unpriv TPR : {0:.3f}'.format( tpr_unpriv))
        print('Eq. Opp : {0:.3f}'.format( abs(tpr_unpriv - tpr_priv)))
        print()
        print('overall FPR : {0:.3f}'.format( fpr_overall))
        print('priv FPR : {0:.3f}'.format( fpr_priv))
        print('unpriv FPR : {0:.3f}'.format( fpr_unpriv))
        print('diff FPR : {0:.3f}'.format( abs(fpr_unpriv-fpr_priv)))
        print()
        print('overall ACC : {0:.3f}'.format( acc_overall))
        print('priv ACC : {0:.3f}'.format( acc_priv))
        print('unpriv ACC : {0:.3f}'.format( acc_unpriv)) 
        print('diff ACC : {0:.3f}\n\n\n'.format( abs(acc_unpriv-acc_priv)))

    #     torch.save(generator.state_dict(), './FR-Train_on_poi_synthetic.pth')

    
tmp = test_model(generator, XS_test, y_test, s1_test)
test_result.append([lambda_f, lambda_r, tmp[0].item(), tmp[1]])


[Lambda: 0.150000] [Epoch 5/50] [D_F loss: 0.522253] [D_R loss: 50.000000] [G loss: 3.333333]

overall TPR : 0.004
priv TPR : 0.000
unpriv TPR : 0.004
Eq. Opp : 0.004

overall FPR : 0.004
priv FPR : 0.006
unpriv FPR : 0.002
diff FPR : 0.004

overall ACC : 0.748
priv ACC : 0.684
unpriv ACC : 0.882
diff ACC : 0.198



[Lambda: 0.150000] [Epoch 10/50] [D_F loss: 0.877271] [D_R loss: 48.840309] [G loss: -30.585560]

overall TPR : 0.005
priv TPR : 0.004
unpriv TPR : 0.006
Eq. Opp : 0.002

overall FPR : 0.009
priv FPR : 0.010
unpriv FPR : 0.007
diff FPR : 0.003

overall ACC : 0.748
priv ACC : 0.684
unpriv ACC : 0.880
diff ACC : 0.196





KeyboardInterrupt: 

In [83]:
class Generator(nn.Module):
    """FR-Train generator (classifier).
    
    This class is for defining the structure of FR-Train generator (classifier). 
    (ref: FR-Train paper, Section 3)

    Attributes:
        model: A model consisting of torch components.
    """
    
    def __init__(self, input_size, latent_size):
        """Initializes Generator with torch components."""
        
        super(Generator, self).__init__()

    
        def block(in_feat, out_feat, normalize=True):
            """Defines a block with torch components.
            
                Args:
                    in_feat: An integer value for the size of the input feature.
                    out_feat: An integer value for the size of the output feature.
                    normalize: A boolean indicating whether normalization is needed.
                    
                Returns:
                    The stacked layer.
            """
            
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(input_size, latent_size, normalize=False),
            *block(latent_size, latent_size, normalize=False),
            *block(latent_size, latent_size, normalize=False),
#             nn.Linear(32, 1),
#             nn.Tanh()
            nn.Linear(latent_size,1),
        )

    def forward(self, input_data):
        """Defines a forward operation of the model.
        
        Args: 
            input_data: The input data.
            
        Returns:
            The predicted label (y_hat) for the given input data.
        """
        
        output = self.model(input_data)
        return output


class DiscriminatorF(nn.Module):
    """FR-Train fairness discriminator.
    
    This class is for defining structure of FR-Train fairness discriminator. 
    (ref: FR-Train paper, Section 3)

    Attributes:
        model: A model consisting of torch components.
    """
    
    def __init__(self):
        """Initializes DiscriminatorF with torch components."""
        
        super(DiscriminatorF, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(1,1),
            nn.Sigmoid(),
        )

    def forward(self, input_data):
        """Defines a forward operation of the model.
        
        Args: 
            input_data: The input data.
            
        Returns:
            The predicted sensitive attribute for the given input data.
        """
        
        predicted_z = self.model(input_data)
        return predicted_z
    

class DiscriminatorR(nn.Module):
    """FR-Train robustness discriminator.
    
    This class is for defining the structure of FR-Train robustness discriminator. 
    (ref: FR-Train paper, Section 3)

    Attributes:
        model: A model consisting of torch components.
    """
    
    def __init__(self, input_size, latent_size):
        """Initializes DiscriminatorR with torch components."""
        
        super(DiscriminatorR, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(input_size + 1, latent_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(latent_size, latent_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(latent_size, 1),
            nn.Sigmoid(),
        )

    def forward(self, input_data):
        """Defines a forward operation of the model.
        
        Args: 
            input_data: The input data.
        
        Returns:
            The predicted indicator (whether the input data is clean or poisoned) 
            for the given input data.
        """
        
        validity = self.model(input_data)
        return validity