In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
%load_ext autoreload
%autoreload 2
%matplotlib inline

import argparse
import numpy as np
import os
from sklearn.metrics import normalized_mutual_info_score
from sklearn.model_selection import train_test_split
    
    
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import torch
from torch import nn
import torch.nn.functional as F

from dataloader import mnist_usps, mnist_reverse, FaceLandmarksDataset, TabDataset
from eval import predict, cluster_accuracy, balance, calc_FID
from utils import set_seed, AverageMeter, target_distribution, aff, inv_lr_scheduler
import argparse

# from MulticoreTSNE import MulticoreTSNE as TSNE

import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE

from tqdm import tqdm


from aif360.metrics import ClassificationMetric
from aif360.datasets import AdultDataset, GermanDataset, BankDataset, CompasDataset, BinaryLabelDataset, CelebADataset, MEPSDataset19

In [2]:
args = argparse.ArgumentParser(description='Process some integers.')
args.bs = 256
args.test_interval = 200
args.lr = 1e-3

args.num_iter = 1000
args.num_sens = 1
data_name = 'adult'

In [3]:
from sklearn.preprocessing import normalize
from torch.utils import data

class TabDataset(data.Dataset):
    def __init__(self, dataset, sens_idx):
        self.label = dataset.labels.squeeze(-1).astype(int)
        
        self.feature_size = dataset.features.shape[1]
        sens_loc = np.zeros(self.feature_size).astype(bool)
        if isinstance(sens_idx, list):
            for sens in sens_idx:
                sens_loc[sens] = 1
        else:
            sens_loc[sens_idx] = 1

        self.feature = dataset.features[:,~sens_loc] #data without sensitive
        self.feature = normalize(self.feature)
        
        self.sensitive = dataset.features[:,sens_loc]
        self.enc = dict()
        for i, idx in enumerate(np.unique(self.sensitive, axis = 0)):
            self.enc[str(idx)] = i   
            
    def __getitem__(self, idx):
        y = self.label[idx]
        x = self.feature[idx]
        a = self.enc[str(self.sensitive[idx])]
        
        return x, a, y
    
 
    def __len__(self):
        return len(self.label)

In [4]:
dataset_orig = AdultDataset()
sens_idx = dataset_orig.feature_names.index('sex')

# dataset_orig = CompasDataset()
# sens_idx = dataset_orig.feature_names.index('race')

data_train, data_vt = dataset_orig.split([0.7], shuffle=True)
data_valid, data_test = data_vt.split([0.5], shuffle=True)


d_train = TabDataset(data_train, sens_idx)
v_train = TabDataset(data_valid, sens_idx)
t_train = TabDataset(data_test, sens_idx)

trainloader = torch.utils.data.DataLoader(
    d_train,
    batch_size=args.bs,
    shuffle=True,
    num_workers = 16
    )

validloader = torch.utils.data.DataLoader(
    v_train,
    batch_size=args.bs,
    shuffle=True,
    num_workers = 16
    )

testloader = torch.utils.data.DataLoader(
    t_train,
    batch_size=args.bs,
    shuffle=True,
    num_workers = 16
    )




In [5]:
import torch.nn.init as init

def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        init.normal(m.weight.data, 1.0, 0.02)
        init.constant(m.bias.data, 0.0)
        
class Encoder_tab(nn.Module):
    def __init__(self, input_dim, latent_dim=32):
        super(Encoder_tab, self).__init__()
        self.linear1 = nn.Linear(input_dim, 128)
        self.linear2 = nn.Linear(128, 64)
        
        self.linear3_1 = nn.Linear(64, latent_dim)
        self.linear3_2 = nn.Linear(64, latent_dim)

        self.relu = nn.ReLU()
        
        for m in self.children():
            weights_init_kaiming(m)

    def reparameterize(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = std.data.new(std.size()).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu

    def forward(self, x):
        
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        
        mu, logvar = self.linear3_1(x), self.linear3_2(x)
        
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

class Decoder_tab(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(Decoder_tab, self).__init__()
        self.linear1 = nn.Linear(latent_dim, 64)
        self.linear2 = nn.Linear(64, 64)
        self.linear3 = nn.Linear(64, 128)
        self.linear4 = nn.Linear(128, input_dim)

        self.relu = nn.ReLU()
        
        for m in self.children():
            weights_init_kaiming(m)

    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.relu(self.linear3(x))
        x = self.relu(self.linear4(x))
        return x
    
    
class MLP(nn.Module):
    def __init__(self, input_dim = 32, hidden_dim = 128):
        super(MLP, self).__init__()
        self.input_dim = input_dim 
        self.dense1 = nn.Linear(input_dim, hidden_dim)
        self.dense2 = nn.Linear(hidden_dim, hidden_dim)
        
        for m in self.children():
            weights_init_kaiming(m)
        
    def forward(self, x):
        x = torch.relu(self.dense1(x))
        x = self.dense2(x)
        return x

    
class Classifier(nn.Module):
    def __init__(self, input_dim = 32, hidden_dim = 128, output_dim = 1):
        super(Classifier, self).__init__()
        self.input_dim = input_dim
        self.relu = nn.LeakyReLU(0.2)
        self.dense1 = nn.Linear(input_dim, hidden_dim)
        self.dense2 = nn.Linear(hidden_dim, output_dim)
        
        for m in self.children():
            weights_init_kaiming(m)
        
    def forward(self, x):
        x = self.relu(self.dense1(x))
        x = self.dense2(x)
        return x

def permute_dims(z):
    assert z.dim() == 2
    B, _ = z.size()
    perm_z = []
    for z_j in z.split(1, 1):
        perm = torch.randperm(B).to(z.device)
        perm_z_j = z_j[perm]
        perm_z.append(perm_z_j)

    return torch.cat(perm_z, 1)

### CMI Constraint

In [13]:
input_dim = dataset_orig.features.shape[1] - 1
latent_dim = 6
z_dim = [2,2,2]

def CI_loss_v2(cls, z_y, z_r, s_batch):
    
    z_y_repeat = z_y.unsqueeze(1).expand(-1, z_y.shape[0], -1) # N x N (replica) x D
    z_r_repeat = z_r.unsqueeze(0).expand(z_y.shape[0], -1, -1) # N (replica) x N  x D

    z_yr = torch.cat([z_y_repeat, z_r_repeat], dim = -1).view(z_y.shape[0] **2, -1) # N^2 x 2D
    p_y = torch.sigmoid(cls(z_yr).view(z_y.shape[0], z_y.shape[0], -1)) # N x N x 1
    p_y_agg = p_y.mean(0)  # mean over different z_y
    
    H_y_cond_z = -(p_y_agg * torch.log(p_y_agg + 1e-7) + (1-p_y_agg) * torch.log(1-p_y_agg + 1e-7)).mean()
    
    s_idx = s_batch.view(-1) == 1
    p_a = s_idx.float().mean()

    p_ya1 = p_y[s_idx].mean(0) * torch.log(p_y[s_idx].mean(0) + 1e-7) + \
                (1-p_y[s_idx].mean(0)) * torch.log(1-p_y[s_idx].mean(0) + 1e-7)
    
    p_ya0 = p_y[~s_idx].mean(0) * torch.log(p_y[~s_idx].mean(0) + 1e-7) + \
                (1-p_y[~s_idx].mean(0)) * torch.log(1-p_y[~s_idx].mean(0) + 1e-7)
    
    H_y_cond_za = -(p_a * p_ya1 + (1 - p_a) * p_ya0).mean()
    
    return H_y_cond_z, H_y_cond_za


epochs = 50
criterion = nn.BCEWithLogitsLoss()
lambda_ci = 0.8

encoder = Encoder_tab(input_dim, latent_dim).cuda()
decoder = Decoder_tab(input_dim, latent_dim).cuda()
cls_y = Classifier(input_dim = 4, hidden_dim = 16, output_dim =1).cuda()
cls_a = Classifier(input_dim = 4, hidden_dim = 16, output_dim =1).cuda()

param_lst = list()
param_lst += list(encoder.parameters()) + list(decoder.parameters())
param_lst += list(cls_y.parameters()) + list(cls_a.parameters())                                       

optimizer = torch.optim.Adam(param_lst, lr = 5e-4, weight_decay = 1e-5)

H_z_1, H_za = list(), list()
H_z_2, H_zs = list(), list()
                                               
for epoch in (range(epochs + 1)):
    encoder.train()
    decoder.train()
    cls_y.train()
    cls_a.train()

    loss_hist = 0
    recon_loss_hist = 0
    cls_loss_hist = 0
    kld_loss_hist = 0
    loss_ci_hist = 0
    loss_cov_hist = 0
    cnt = 0
    
    for x_batch, s_batch, y_batch in trainloader:
        x_batch = x_batch.cuda().float()
        s_batch, y_batch = s_batch.cuda().view(-1,1).float(), y_batch.cuda().view(-1,1).float()
        
        z, mu, logvar = encoder(x_batch)
        
        z = z.split(2, dim = -1) # [z_x, z_y, z_r, z_a]
        z_y, z_r, z_a = z[0], z[1], z[2]
        recon = decoder(torch.cat(z, -1))
        
        pred_y = cls_y(torch.cat([z_y, z_r], dim =-1))
        pred_a = cls_a(torch.cat([z_a, z_r], dim =-1))

        recon_loss = F.l1_loss(recon, x_batch)
        cls_loss = criterion(pred_y, y_batch) + criterion(pred_a, s_batch)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1))

        loss_elbo = 1e2 * cls_loss + kld_loss
        
        H_y_cond_z_1, H_y_cond_za = CI_loss_v2(cls_y, z_y, z_r, (pred_a>=0).int().detach())
        H_y_cond_z_2, H_y_cond_zs = CI_loss_v2(cls_a, z_a, z_r, (pred_y>=0).int().detach())
        loss_ci = 0.5 * (H_y_cond_z_1 + H_y_cond_z_2 - H_y_cond_za - H_y_cond_zs)
        
        H_z_1.append(H_y_cond_z_1.item())
        H_za.append(H_y_cond_za.item())
        H_z_2.append(H_y_cond_z_2.item())
        H_zs.append(H_y_cond_zs.item())
        

        loss = loss_elbo 
        if epoch > 5:
            loss += lambda_ci * loss_ci
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        loss_hist += loss.item()
        recon_loss_hist += recon_loss.item()
        cls_loss_hist += cls_loss.item()
        kld_loss_hist += kld_loss.item()
        loss_ci_hist += loss_ci.item()
#         loss_cov_hist += loss_cov.item()
        cnt += 1

    diag = "[TRAIN] epoch : [{}/{}]\n".format(epoch, epochs)
    diag += "Loss : {:.3f}\n".format(loss_hist/cnt)
    diag += "recon_loss : {:.3f}, cls_loss : {:.3f}, kld_loss : {:.3f}\n".format(recon_loss_hist/cnt, cls_loss_hist/cnt, kld_loss_hist/cnt)
    diag += "CI_loss : {:.3f}, COV_loss : {:.3f}\n".format(loss_ci_hist/cnt, loss_cov_hist/cnt)
    print(diag)

    
    
encoder.eval()
args.lr = 1e-4
epochs = 10

for z_input in ['yr', 'y', 'a', 'r', 'yra']:
    print("INPUT : ", z_input)
    z_idx = []
    
    if 'y' in z_input:
        z_idx.append(0)
    if 'a' in z_input:
        z_idx.append(2)
    if 'r' in z_input:
        z_idx.append(1)
        
    print("index : ", z_idx)
    
    cls = Classifier(len(z_input) * 2).cuda()
    optimizer = torch.optim.Adam(cls.parameters(), lr = args.lr, weight_decay = 1e-5)

    for epoch in (range(epochs + 1)):
        cls.train()
        for x_batch, s_batch, y_batch in trainloader:
            x_batch, s_batch, y_batch = x_batch.cuda().float(), s_batch.cuda().float(), y_batch.cuda().float().view(-1,1)

            z, _, _ = encoder(x_batch)
            z = z.split(2, dim = 1)
            pred = cls(torch.cat([z[idx] for idx in z_idx], dim = -1))
            loss = criterion(pred, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    if epoch % 5 == 0:
        cls.eval()
        with torch.no_grad():
            pred_lst, y_lst, a_lst = [], [], []
            for x_batch, s_batch, y_batch in validloader:
                x_batch, s_batch, y_batch = x_batch.cuda().float(), s_batch.cuda().float().view(-1,1), y_batch.cuda().float().view(-1,1)

                z, _, _ = encoder(x_batch)
                z = z.split(2, dim = 1)
                z_y, z_r, z_a = z[0], z[1], z[2]
                pred = cls(torch.cat([z[idx] for idx in z_idx], dim = -1))

                pred[pred>=0] = 1
                pred[pred<0] = 0

                pred_lst.append(pred.detach())
                y_lst.append(y_batch.detach())
                a_lst.append(s_batch.detach())

            pred_lst = torch.cat(pred_lst).cpu().numpy()
            y_lst = torch.cat(y_lst).cpu().numpy()
            a_lst = torch.cat(a_lst).cpu().numpy()

            acc = (pred_lst == y_lst).mean()
            acc_priv = (pred_lst[a_lst == 1] == y_lst[a_lst == 1]).mean()
            acc_unpriv = (pred_lst[a_lst == 0] == y_lst[a_lst == 0]).mean()
            priv_idx = a_lst == 1
            pos_idx = y_lst == 1

            tpr = (pred_lst[pos_idx] == 1).mean()
            tpr_priv = (pred_lst[pos_idx*priv_idx] == 1).mean()
            tpr_unpriv = (pred_lst[pos_idx*~priv_idx] == 1).mean()

            fpr = (pred_lst[~pos_idx] == 1).mean()
            fpr_priv = (pred_lst[~pos_idx*priv_idx] == 1).mean()
            fpr_unpriv = (pred_lst[~pos_idx*~priv_idx] == 1).mean()

            DP = abs((pred_lst[priv_idx]==1).mean() - (pred_lst[~priv_idx]==1).mean())
            EOP =abs(tpr_priv-tpr_unpriv)
            EOD =abs(tpr_priv-tpr_unpriv) + abs(fpr_priv-fpr_unpriv)

            diag = "ACC : {:.3f}, ACC_priv : {:.3f}, ACC_unpriv : {:.3f}".format(acc, acc_priv, acc_unpriv)
            print(diag)

            diag = "TPR : {:.3f}, TPR_priv : {:.3f}, TPR_unpriv : {:.3f}".format(tpr, tpr_priv, tpr_unpriv)
            print(diag)

            diag = "FPR : {:.3f}, FPR_priv : {:.3f}, FPR_unpriv : {:.3f}".format(fpr, fpr_priv, fpr_unpriv)
            print(diag)

            diag = "DP : {:.3f}, EOP : {:.3f}, EOd : {:.3f}".format(DP, EOP, EOD)
            print(diag)

            
            

[TRAIN] epoch : [0/50]
Loss : 130.708
recon_loss : 0.314, cls_loss : 1.289, kld_loss : 1.785
CI_loss : 0.000, COV_loss : 0.000

[TRAIN] epoch : [1/50]
Loss : 118.870
recon_loss : 0.290, cls_loss : 1.159, kld_loss : 2.974
CI_loss : nan, COV_loss : 0.000

[TRAIN] epoch : [2/50]
Loss : 114.381
recon_loss : 0.291, cls_loss : 1.112, kld_loss : 3.174
CI_loss : 0.001, COV_loss : 0.000

[TRAIN] epoch : [3/50]
Loss : 104.916
recon_loss : 0.316, cls_loss : 1.009, kld_loss : 4.062
CI_loss : 0.005, COV_loss : 0.000

[TRAIN] epoch : [4/50]
Loss : 92.870
recon_loss : 0.364, cls_loss : 0.878, kld_loss : 5.091
CI_loss : 0.017, COV_loss : 0.000

[TRAIN] epoch : [5/50]
Loss : 87.241
recon_loss : 0.364, cls_loss : 0.822, kld_loss : 5.078
CI_loss : 0.024, COV_loss : 0.000

[TRAIN] epoch : [6/50]
Loss : 82.836
recon_loss : 0.378, cls_loss : 0.777, kld_loss : 5.104
CI_loss : 0.025, COV_loss : 0.000

[TRAIN] epoch : [7/50]
Loss : 80.585
recon_loss : 0.388, cls_loss : 0.755, kld_loss : 5.102
CI_loss : 0.026, 

### CI Constraint

In [22]:
input_dim = dataset_orig.features.shape[1] - 1
z_dim = [2,2,2]
latent_dim = sum(z_dim)

epochs = 100
criterion = nn.BCEWithLogitsLoss()
lambda_fair = 2e1

# encoder = [Encoder_tab(input_dim, sub_dim).cuda() for sub_dim in  z_dim]
encoder = Encoder_tab(input_dim, latent_dim).cuda()
decoder = Decoder_tab(input_dim, latent_dim).cuda()

cls_y = Classifier(input_dim = 4, hidden_dim = 16, output_dim =1).cuda()
cls_a = Classifier(input_dim = 4, hidden_dim = 16, output_dim =1).cuda()

param_lst = list()
param_lst += list(encoder.parameters()) + list(decoder.parameters())
param_lst += list(cls_y.parameters()) + list(cls_a.parameters())                                       

optimizer = torch.optim.Adam(param_lst, lr = 5e-4, weight_decay = 1e-5)
                                               
dis = Classifier(input_dim = 2, hidden_dim = 16, output_dim =1).cuda()
optimizer_dis = torch.optim.Adam(dis.parameters(), lr = 1e-4, weight_decay = 1e-5)
    
for epoch in (range(epochs + 1)):
    encoder.train()
    decoder.train()
    cls_y.train()
    cls_a.train()
#     [mlp.train() for mlp in encoder]
    for x_batch, s_batch, y_batch in trainloader:
        x_batch = x_batch.cuda().float()
        s_batch, y_batch = s_batch.cuda().view(-1,1).float(), y_batch.cuda().view(-1,1).float()
        
        z, mu, logvar = encoder(x_batch)
        z_y, z_r, z_a = z.split(2, dim = -1)
        recon = decoder(z)
        pred_y = cls_y(torch.cat([z_y, z_r], dim =-1))
        pred_a = cls_a(torch.cat([z_a, z_r], dim =-1))

        recon_loss = F.l1_loss(recon, x_batch)
        cls_loss = criterion(pred_y, y_batch) + criterion(pred_a, s_batch)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1))

        loss_elbo = 1e2 * cls_loss + kld_loss
        
        ### Agg prob
        z_y_repeat = z_y.unsqueeze(1).expand(-1, z_a.shape[0], -1)
        z_a_repeat = z_a.unsqueeze(1).expand(-1, z_a.shape[0], -1)
        z_r_repeat = z_r.unsqueeze(0).expand(z_a.shape[0], -1, -1)

        z_yr = torch.cat([z_y_repeat, z_r_repeat], dim = -1).view(z_a.shape[0] **2, -1)
        z_ar = torch.cat([z_a_repeat, z_r_repeat], dim = -1).view(z_a.shape[0] **2, -1)
        p_y_agg = cls_y(z_yr).view(z_a.shape[0], z_a.shape[0], -1).mean(0)
        p_a_agg = cls_a(z_ar).view(z_a.shape[0], z_a.shape[0], -1).mean(0)
        
        loss_ci = criterion(dis(torch.cat([p_y_agg, p_a_agg], dim = 1)), torch.zeros_like(y_batch))
        
        loss = loss_elbo + lambda_fair * (loss_ci)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        z, mu, logvar = encoder(x_batch)
        z_y, z_r, z_a = z.split(2, dim = -1)

        ### Agg prob
        z_y_repeat = z_y.unsqueeze(1).expand(-1, z_a.shape[0], -1)
        z_a_repeat = z_a.unsqueeze(1).expand(-1, z_a.shape[0], -1)
        z_r_repeat = z_r.unsqueeze(0).expand(z_a.shape[0], -1, -1)

        z_yr = torch.cat([z_y_repeat, z_r_repeat], dim = -1).view(z_a.shape[0] **2, -1)
        z_ar = torch.cat([z_a_repeat, z_r_repeat], dim = -1).view(z_a.shape[0] **2, -1)
        p_y_agg = cls_y(z_yr).view(z_a.shape[0], z_a.shape[0], -1).mean(0)
        p_a_agg = cls_a(z_ar).view(z_a.shape[0], z_a.shape[0], -1).mean(0)
        
        dis_loss = 0.5* (criterion(dis(torch.cat([p_y_agg, p_a_agg], dim = 1)), torch.ones_like(y_batch))\
                    + criterion(dis(torch.cat([p_y_agg, permute_dims(p_a_agg)], dim = 1)), torch.zeros_like(y_batch)))

        optimizer_dis.zero_grad()
        dis_loss.backward()
        optimizer_dis.step()
        
        
    diag = "[TRAIN] epoch : [{}/{}]\n".format(epoch, epochs)
    diag += "Loss : {:.3f}\n".format(loss.item())
    diag += "recon_loss : {:.3f}, cls_loss : {:.3f}, kld_loss : {:.3f}\n".format(recon_loss.item(), cls_loss.item(), kld_loss.item())
    diag += "CI_loss : {:.3f}, \n".format(loss_ci.item())
    print(diag)


[TRAIN] epoch : [0/100]
Loss : 157.003
recon_loss : 0.763, cls_loss : 1.291, kld_loss : 2.402
CI_loss : 1.276, 

[TRAIN] epoch : [1/100]
Loss : 147.563
recon_loss : 0.767, cls_loss : 1.206, kld_loss : 2.950
CI_loss : 1.201, 

[TRAIN] epoch : [2/100]
Loss : 136.528
recon_loss : 0.699, cls_loss : 1.120, kld_loss : 3.676
CI_loss : 1.041, 

[TRAIN] epoch : [3/100]
Loss : 118.428
recon_loss : 0.672, cls_loss : 0.939, kld_loss : 5.676
CI_loss : 0.941, 

[TRAIN] epoch : [4/100]
Loss : 111.171
recon_loss : 0.687, cls_loss : 0.891, kld_loss : 6.849
CI_loss : 0.763, 

[TRAIN] epoch : [5/100]
Loss : 94.987
recon_loss : 0.659, cls_loss : 0.744, kld_loss : 6.915
CI_loss : 0.684, 

[TRAIN] epoch : [6/100]
Loss : 112.136
recon_loss : 0.597, cls_loss : 0.903, kld_loss : 6.572
CI_loss : 0.764, 

[TRAIN] epoch : [7/100]
Loss : 97.201
recon_loss : 0.587, cls_loss : 0.764, kld_loss : 6.462
CI_loss : 0.717, 

[TRAIN] epoch : [8/100]
Loss : 93.904
recon_loss : 0.613, cls_loss : 0.736, kld_loss : 6.788
CI_lo

[TRAIN] epoch : [73/100]
Loss : 83.402
recon_loss : 0.604, cls_loss : 0.643, kld_loss : 5.309
CI_loss : 0.688, 

[TRAIN] epoch : [74/100]
Loss : 79.167
recon_loss : 0.640, cls_loss : 0.601, kld_loss : 5.361
CI_loss : 0.687, 

[TRAIN] epoch : [75/100]
Loss : 82.444
recon_loss : 0.615, cls_loss : 0.632, kld_loss : 5.311
CI_loss : 0.696, 

[TRAIN] epoch : [76/100]
Loss : 88.425
recon_loss : 0.622, cls_loss : 0.691, kld_loss : 5.329
CI_loss : 0.698, 

[TRAIN] epoch : [77/100]
Loss : 90.786
recon_loss : 0.608, cls_loss : 0.714, kld_loss : 5.394
CI_loss : 0.702, 

[TRAIN] epoch : [78/100]
Loss : 79.849
recon_loss : 0.604, cls_loss : 0.608, kld_loss : 5.360
CI_loss : 0.685, 

[TRAIN] epoch : [79/100]
Loss : 85.796
recon_loss : 0.617, cls_loss : 0.668, kld_loss : 5.165
CI_loss : 0.693, 

[TRAIN] epoch : [80/100]
Loss : 81.612
recon_loss : 0.632, cls_loss : 0.626, kld_loss : 5.313
CI_loss : 0.684, 

[TRAIN] epoch : [81/100]
Loss : 76.477
recon_loss : 0.607, cls_loss : 0.574, kld_loss : 5.231
CI

In [23]:
encoder.eval()
# [mlp.eval() for mlp in encoder]

args.lr = 1e-3
epochs = 20

for z_input in ['yr', 'y', 'a', 'r', 'yra']:
    print("INPUT : ", z_input)
    z_idx = []
    
    if 'y' in z_input:
        z_idx.append(0)
    if 'a' in z_input:
        z_idx.append(2)
    if 'r' in z_input:
        z_idx.append(1)
        
    print("index : ", z_idx)
    
    cls = Classifier(len(z_input) * 2).cuda()
    optimizer = torch.optim.Adam(cls.parameters(), lr = args.lr, weight_decay = 1e-5)

    for epoch in (range(epochs + 1)):
        cls.train()
        for x_batch, s_batch, y_batch in trainloader:
            x_batch, s_batch, y_batch = x_batch.cuda().float(), s_batch.cuda().float(), y_batch.cuda().float().view(-1,1)

            z, _, _ = encoder(x_batch)
            z = z.split(2, dim = -1)
            pred = cls(torch.cat([z[idx] for idx in z_idx], dim = -1))
            loss = criterion(pred, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    if epoch % 5 == 0:
        cls.eval()
        with torch.no_grad():
            pred_lst, y_lst, a_lst = [], [], []
            for x_batch, s_batch, y_batch in validloader:
                x_batch, s_batch, y_batch = x_batch.cuda().float(), s_batch.cuda().float().view(-1,1), y_batch.cuda().float().view(-1,1)

                z, _, _ = encoder(x_batch)
                z = z.split(2, dim = -1)
                z_y, z_r, z_a = z[0], z[1], z[2]
                pred = cls(torch.cat([z[idx] for idx in z_idx], dim = -1))

                pred[pred>=0] = 1
                pred[pred<0] = 0

                pred_lst.append(pred.detach())
                y_lst.append(y_batch.detach())
                a_lst.append(s_batch.detach())

            pred_lst = torch.cat(pred_lst).cpu().numpy()
            y_lst = torch.cat(y_lst).cpu().numpy()
            a_lst = torch.cat(a_lst).cpu().numpy()

            acc = (pred_lst == y_lst).mean()
            acc_priv = (pred_lst[a_lst == 1] == y_lst[a_lst == 1]).mean()
            acc_unpriv = (pred_lst[a_lst == 0] == y_lst[a_lst == 0]).mean()
            priv_idx = a_lst == 1
            pos_idx = y_lst == 1

            tpr = (pred_lst[pos_idx] == 1).mean()
            tpr_priv = (pred_lst[pos_idx*priv_idx] == 1).mean()
            tpr_unpriv = (pred_lst[pos_idx*~priv_idx] == 1).mean()

            fpr = (pred_lst[~pos_idx] == 1).mean()
            fpr_priv = (pred_lst[~pos_idx*priv_idx] == 1).mean()
            fpr_unpriv = (pred_lst[~pos_idx*~priv_idx] == 1).mean()

            DP = abs((pred_lst[priv_idx]==1).mean() - (pred_lst[~priv_idx]==1).mean())
            EOP =abs(tpr_priv-tpr_unpriv)
            EOD =abs(tpr_priv-tpr_unpriv) + abs(fpr_priv-fpr_unpriv)

            diag = "ACC : {:.3f}, ACC_priv : {:.3f}, ACC_unpriv : {:.3f}".format(acc, acc_priv, acc_unpriv)
            print(diag)

            diag = "TPR : {:.3f}, TPR_priv : {:.3f}, TPR_unpriv : {:.3f}".format(tpr, tpr_priv, tpr_unpriv)
            print(diag)

            diag = "FPR : {:.3f}, FPR_priv : {:.3f}, FPR_unpriv : {:.3f}".format(fpr, fpr_priv, fpr_unpriv)
            print(diag)

            diag = "DP : {:.3f}, EOP : {:.3f}, EOd : {:.3f}".format(DP, EOP, EOD)
            print(diag)

            

INPUT :  yr
index :  [0, 1]
ACC : 0.849, ACC_priv : 0.817, ACC_unpriv : 0.915
TPR : 0.619, TPR_priv : 0.639, TPR_unpriv : 0.506
FPR : 0.075, FPR_priv : 0.102, FPR_unpriv : 0.031
DP : 0.184, EOP : 0.133, EOd : 0.203
INPUT :  y
index :  [0]
ACC : 0.798, ACC_priv : 0.757, ACC_unpriv : 0.884
TPR : 0.303, TPR_priv : 0.311, TPR_unpriv : 0.259
FPR : 0.038, FPR_priv : 0.040, FPR_unpriv : 0.035
DP : 0.065, EOP : 0.053, EOd : 0.058
INPUT :  a
index :  [2]
ACC : 0.812, ACC_priv : 0.772, ACC_unpriv : 0.896
TPR : 0.551, TPR_priv : 0.610, TPR_unpriv : 0.224
FPR : 0.102, FPR_priv : 0.154, FPR_unpriv : 0.017
DP : 0.256, EOP : 0.386, EOd : 0.524
INPUT :  r
index :  [1]
ACC : 0.848, ACC_priv : 0.813, ACC_unpriv : 0.918
TPR : 0.644, TPR_priv : 0.663, TPR_unpriv : 0.537
FPR : 0.085, FPR_priv : 0.118, FPR_unpriv : 0.032
DP : 0.198, EOP : 0.126, EOd : 0.212
INPUT :  yra
index :  [0, 2, 1]
ACC : 0.848, ACC_priv : 0.816, ACC_unpriv : 0.915
TPR : 0.636, TPR_priv : 0.659, TPR_unpriv : 0.510
FPR : 0.082, FPR_pri

### COV approximation constraint

In [34]:
# New version

input_dim = dataset_orig.features.shape[1] - 1
latent_dim = 6
z_dim = [2,2,2]

# Take cls_ent as both cls_y and cls_a
# New version
def cov_loss(pred_y, pred_a):
    p_y = torch.sigmoid(pred_y)
    p_a = torch.sigmoid(pred_a)
#     return((p_y * p_a).mean() - p_y.mean() * p_a.mean()) ** 2
    return abs((p_y * p_a).mean() - p_y.mean() * p_a.mean())


def CI_loss_v2(cls, z_y, z_r, s_batch):
    
    z_y_repeat = z_y.unsqueeze(1).expand(-1, z_y.shape[0], -1) # N x N (replica) x D
    z_r_repeat = z_r.unsqueeze(0).expand(z_y.shape[0], -1, -1) # N (replica) x N  x D

    z_yr = torch.cat([z_y_repeat, z_r_repeat], dim = -1).view(z_y.shape[0] **2, -1) # N^2 x 2D
    p_y = torch.sigmoid(cls(z_yr).view(z_y.shape[0], z_y.shape[0], -1)) # N x N x 1
    p_y_agg = p_y.mean(0)  # mean over different z_y
    
    H_y_cond_z = -(p_y_agg * torch.log(p_y_agg + 1e-7) + (1-p_y_agg) * torch.log(1-p_y_agg + 1e-7)).mean()
    
    s_idx = s_batch.view(-1) == 1
    p_a = s_idx.float().mean()

    p_ya1 = p_y[s_idx].mean(0) * torch.log(p_y[s_idx].mean(0) + 1e-7) + \
                (1-p_y[s_idx].mean(0)) * torch.log(1-p_y[s_idx].mean(0) + 1e-7)
    
    p_ya0 = p_y[~s_idx].mean(0) * torch.log(p_y[~s_idx].mean(0) + 1e-7) + \
                (1-p_y[~s_idx].mean(0)) * torch.log(1-p_y[~s_idx].mean(0) + 1e-7)
    
    H_y_cond_za = -(p_a * p_ya1 + (1 - p_a) * p_ya0).mean()
    
#     return H_y_cond_z - H_y_cond_za
    return H_y_cond_z, H_y_cond_za


epochs = 50
criterion = nn.BCEWithLogitsLoss()
lambda_ci = 0.5

encoder = Encoder_tab(input_dim, latent_dim).cuda()
decoder = Decoder_tab(input_dim, latent_dim).cuda()
cls_y = Classifier(input_dim = 4, hidden_dim = 16, output_dim =1).cuda()
cls_a = Classifier(input_dim = 4, hidden_dim = 16, output_dim =1).cuda()

param_lst = list()
param_lst += list(encoder.parameters()) + list(decoder.parameters())
param_lst += list(cls_y.parameters()) + list(cls_a.parameters())                                       

optimizer = torch.optim.Adam(param_lst, lr = 5e-4, weight_decay = 1e-4)

H_z_1, H_za = list(), list()
H_z_2, H_zs = list(), list()
                                               
for epoch in (range(epochs + 1)):
    encoder.train()
    decoder.train()
    cls_y.train()
    cls_a.train()

    loss_hist = 0
    recon_loss_hist = 0
    cls_loss_hist = 0
    kld_loss_hist = 0
    loss_ci_hist = 0
    loss_cov_hist = 0
    cnt = 0
    
    for x_batch, s_batch, y_batch in trainloader:
        x_batch = x_batch.cuda().float()
        s_batch, y_batch = s_batch.cuda().view(-1,1).float(), y_batch.cuda().view(-1,1).float()
        
        z, mu, logvar = encoder(x_batch)
        
        z = z.split(2, dim = -1) # [z_x, z_y, z_r, z_a]
        z_y, z_r, z_a = z[0], z[1], z[2]
        recon = decoder(torch.cat(z, -1))
        
        pred_y = cls_y(torch.cat([z_y, z_r], dim =-1))
        pred_a = cls_a(torch.cat([z_a, z_r], dim =-1))

        recon_loss = F.l1_loss(recon, x_batch)
        cls_loss = criterion(pred_y, y_batch) + criterion(pred_a, s_batch)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1))

#         loss_elbo = recon_loss + 1e2 * cls_loss + kld_loss
        loss_elbo = 1e2 * cls_loss + kld_loss
        
        z_y_repeat = z_y.unsqueeze(1).expand(-1, z_y.shape[0], -1) # N x N (replica) x D
        z_a_repeat = z_a.unsqueeze(1).expand(-1, z_y.shape[0], -1) # N x N (replica) x D
        z_r_repeat = z_r.unsqueeze(0).expand(z_y.shape[0], -1, -1) # N (replica) x N  x D

        # N^2 x 2D
        z_yr = torch.cat([z_y_repeat, z_r_repeat], dim = -1).view(z_y.shape[0] **2, -1)
        z_ar = torch.cat([z_a_repeat, z_r_repeat], dim = -1).view(z_y.shape[0] **2, -1)

        p_y = (cls_y(z_yr).view(z_y.shape[0], z_y.shape[0], -1)) # N x N x 1
        p_y_agg = p_y.mean(0)  # mean over different z_y

        p_a = (cls_a(z_ar).view(z_y.shape[0], z_y.shape[0], -1)) # N x N x 1
        p_a_agg = p_a.mean(0)  # mean over different z_y
        loss = loss_elbo 

        optimizer.zero_grad()
        loss.backward(retain_graph = True)
        
        for param in cls_a.parameters():
            param.requires_grad=False
        for param in cls_y.parameters():
            param.requires_grad=False
            
        loss_fair = lambda_ci * cov_loss(pred_y, pred_a)
        loss_fair.backward()
        
        for param in cls_a.parameters():
            param.requires_grad=True
        for param in cls_y.parameters():
            param.requires_grad=True
            
        optimizer.step()
        
        loss_hist += loss.item()
        recon_loss_hist += recon_loss.item()
        cls_loss_hist += cls_loss.item()
        kld_loss_hist += kld_loss.item()
        loss_ci_hist += loss_fair.item()
#         loss_cov_hist += loss_cov.item()
        cnt += 1

    diag = "[TRAIN] epoch : [{}/{}]\n".format(epoch, epochs)
    diag += "Loss : {:.3f}\n".format(loss_hist/cnt)
    diag += "recon_loss : {:.3f}, cls_loss : {:.3f}, kld_loss : {:.3f}\n".format(recon_loss_hist/cnt, cls_loss_hist/cnt, kld_loss_hist/cnt)
    diag += "CI_loss : {:.3f}".format(loss_ci_hist/cnt)
    print(diag)

    
    
encoder.eval()
args.lr = 1e-4
epochs = 10

for z_input in ['yr', 'y', 'a', 'r', 'yra']:
    print("INPUT : ", z_input)
    z_idx = []
    
    if 'y' in z_input:
        z_idx.append(0)
    if 'a' in z_input:
        z_idx.append(2)
    if 'r' in z_input:
        z_idx.append(1)
        
    print("index : ", z_idx)
    
    cls = Classifier(len(z_input) * 2).cuda()
    optimizer = torch.optim.Adam(cls.parameters(), lr = args.lr, weight_decay = 1e-5)

    for epoch in (range(epochs + 1)):
        cls.train()
        for x_batch, s_batch, y_batch in trainloader:
            x_batch, s_batch, y_batch = x_batch.cuda().float(), s_batch.cuda().float(), y_batch.cuda().float().view(-1,1)

            z, _, _ = encoder(x_batch)
            z = z.split(2, dim = 1)
            pred = cls(torch.cat([z[idx] for idx in z_idx], dim = -1))
            loss = criterion(pred, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    if epoch % 5 == 0:
        cls.eval()
        with torch.no_grad():
            pred_lst, y_lst, a_lst = [], [], []
            for x_batch, s_batch, y_batch in validloader:
                x_batch, s_batch, y_batch = x_batch.cuda().float(), s_batch.cuda().float().view(-1,1), y_batch.cuda().float().view(-1,1)

                z, _, _ = encoder(x_batch)
                z = z.split(2, dim = 1)
                z_y, z_r, z_a = z[0], z[1], z[2]
                pred = cls(torch.cat([z[idx] for idx in z_idx], dim = -1))

                pred[pred>=0] = 1
                pred[pred<0] = 0

                pred_lst.append(pred.detach())
                y_lst.append(y_batch.detach())
                a_lst.append(s_batch.detach())

            pred_lst = torch.cat(pred_lst).cpu().numpy()
            y_lst = torch.cat(y_lst).cpu().numpy()
            a_lst = torch.cat(a_lst).cpu().numpy()

            acc = (pred_lst == y_lst).mean()
            acc_priv = (pred_lst[a_lst == 1] == y_lst[a_lst == 1]).mean()
            acc_unpriv = (pred_lst[a_lst == 0] == y_lst[a_lst == 0]).mean()
            priv_idx = a_lst == 1
            pos_idx = y_lst == 1

            tpr = (pred_lst[pos_idx] == 1).mean()
            tpr_priv = (pred_lst[pos_idx*priv_idx] == 1).mean()
            tpr_unpriv = (pred_lst[pos_idx*~priv_idx] == 1).mean()

            fpr = (pred_lst[~pos_idx] == 1).mean()
            fpr_priv = (pred_lst[~pos_idx*priv_idx] == 1).mean()
            fpr_unpriv = (pred_lst[~pos_idx*~priv_idx] == 1).mean()

            DP = abs((pred_lst[priv_idx]==1).mean() - (pred_lst[~priv_idx]==1).mean())
            EOP =abs(tpr_priv-tpr_unpriv)
            EOD =abs(tpr_priv-tpr_unpriv) + abs(fpr_priv-fpr_unpriv)

            diag = "ACC : {:.3f}, ACC_priv : {:.3f}, ACC_unpriv : {:.3f}".format(acc, acc_priv, acc_unpriv)
            print(diag)

            diag = "TPR : {:.3f}, TPR_priv : {:.3f}, TPR_unpriv : {:.3f}".format(tpr, tpr_priv, tpr_unpriv)
            print(diag)

            diag = "FPR : {:.3f}, FPR_priv : {:.3f}, FPR_unpriv : {:.3f}".format(fpr, fpr_priv, fpr_unpriv)
            print(diag)

            diag = "DP : {:.3f}, EOP : {:.3f}, EOd : {:.3f}".format(DP, EOP, EOD)
            print(diag)

            
            

[TRAIN] epoch : [0/50]
Loss : 126.927
recon_loss : 0.556, cls_loss : 1.263, kld_loss : 0.608
CI_loss : 0.001
[TRAIN] epoch : [1/50]
Loss : 117.475
recon_loss : 0.572, cls_loss : 1.154, kld_loss : 2.079
CI_loss : 0.002
[TRAIN] epoch : [2/50]
Loss : 107.933
recon_loss : 0.600, cls_loss : 1.048, kld_loss : 3.152
CI_loss : 0.007
[TRAIN] epoch : [3/50]
Loss : 97.850
recon_loss : 0.617, cls_loss : 0.938, kld_loss : 4.054
CI_loss : 0.015
[TRAIN] epoch : [4/50]
Loss : 93.313
recon_loss : 0.601, cls_loss : 0.893, kld_loss : 3.995
CI_loss : 0.020
[TRAIN] epoch : [5/50]
Loss : 90.768
recon_loss : 0.577, cls_loss : 0.870, kld_loss : 3.796
CI_loss : 0.022
[TRAIN] epoch : [6/50]
Loss : 88.510
recon_loss : 0.556, cls_loss : 0.846, kld_loss : 3.879
CI_loss : 0.022
[TRAIN] epoch : [7/50]
Loss : 86.661
recon_loss : 0.544, cls_loss : 0.828, kld_loss : 3.891
CI_loss : 0.022
[TRAIN] epoch : [8/50]
Loss : 84.562
recon_loss : 0.537, cls_loss : 0.805, kld_loss : 4.099
CI_loss : 0.022
[TRAIN] epoch : [9/50]
Lo