In [1]:
import torch
import random
from seeker.random import RandomSelectPairSeeker, RandomSelectSeeker, RangeGenSeeker, DistributionGenSeeker
from seeker.gradiant_based import WhiteboxSeeker, BlackboxSeeker, FoolSeeker
from utils import UnfairMetric, load_model
from data import adult, german, loans_default
from train_dnn import get_data
from models.model import MLP
from distances.normalized_mahalanobis_distances import ProtectedSEDistances
from distances.sensitive_subspace_distances import LogisticRegSensitiveSubspace
from distances.binary_distances import BinaryDistance
from IPython.display import display

%load_ext autoreload
%autoreload 2

In [2]:
# data
data_name = 'adult'
use_sensitive_attr = True
sensitive_vars = ['sex_Male']
# model
rand_seed = 0
trainer_name = 'std'
rho=0.1
_note = ''
note=_note if trainer_name == 'std' else _note + f'rho={rho}'
# others
device = 'cpu'

In [3]:
# Map each supported dataset name to its module, then pick the configured one.
data_choices = dict(adult=adult, german=german, loans_default=loans_default)
data = data_choices[data_name]
data_gen = data.Generator(use_sensitive_attr, sensitive_vars, device)

# Build the dataset and train/test loaders for this seed, then propagate the
# use_sensitive_attr flag onto the dataset object.
dataset, train_dl, test_dl = get_data(data, rand_seed, sensitive_vars=sensitive_vars)
dataset.use_sensitive_attr = use_sensitive_attr
in_dim, out_dim = dataset.dim_feature(), 2  # two output classes

all_X, all_y = dataset.get_all_data(), dataset.labels

In [4]:
# 4-layer MLP (norm=False). load_model presumably restores saved weights into
# `model` in place, selected by dataset, trainer, sensitive config, seed and note.
model = MLP(in_dim, out_dim, data_gen=data_gen, n_layers=4, norm=False)
load_model(
    model,
    data_name,
    trainer_name,
    use_sensitive_attr=use_sensitive_attr,
    sensitive_vars=sensitive_vars,
    id=rand_seed,
    note=note,
)

In [5]:
# Candidate input-space metrics plus the output-space (label) metric.
distance_x_Causal = ProtectedSEDistances()
distance_x_LR = LogisticRegSensitiveSubspace()
distance_y = BinaryDistance()

# Fit exactly one input metric depending on whether the sensitive attribute is used.
if not use_sensitive_attr:
    # The LR metric is fitted on the sensitive columns sliced out of dataset.data.
    sensitive_ = dataset.data[:, dataset.sensitive_idxs]
    distance_x_LR.fit(
        dataset.get_all_data(),
        data_gen=data_gen,
        data_SensitiveAttrs=sensitive_,
    )
    chosen_dx = distance_x_LR
else:
    distance_x_Causal.fit(
        num_dims=dataset.dim_feature(),
        data_gen=data_gen,
        sensitive_idx=dataset.sensitive_idxs,
    )
    chosen_dx = distance_x_Causal

In [6]:
# Step 10 units along each coordinate axis from the origin and record the
# distance chosen_dx assigns to that step. epsilon is the reciprocal of the
# smallest non-zero axis distance; axes with zero distance are skipped.
origin = torch.zeros(chosen_dx.num_dims)
axis_steps = 10 * torch.diag(torch.ones_like(origin))
axis_dist = torch.zeros_like(origin)
for k, step in enumerate(axis_steps):
    axis_dist[k] = chosen_dx(origin, origin + step)
epsilon = (1 / axis_dist[axis_dist != 0].min()).item()
epsilon

99997992.0

In [7]:
# NOTE(review): this overrides the probed epsilon from the previous cell
# (which printed 99997992.0) with a hard-coded round value; confirm it is
# still appropriate if the data, model or metric changes.
epsilon = 1e8
unfair_metric = UnfairMetric(
    dx=chosen_dx,
    dy=distance_y,
    epsilon=epsilon,
)

In [8]:
def show_result(result):
    """Display the outcome of a seeker's ``seek`` call.

    Parameters
    ----------
    result : sequence
        ``(pair, n_query)`` or ``(pair, n_query, n_iters)``. ``pair`` is the
        found input pair (rendered via ``data_gen.feature_dataframe``), or
        ``None`` when the search found nothing.

    Prints ``n_iters`` when a third element is present, then displays either
    the pair as a feature DataFrame followed by the query count, or
    ``'not found'``.
    """
    pair, n_query = result[0], result[1]
    if len(result) == 3:
        print(f'n_iters = {result[2]}')
    # Use identity, not `!= None`: for array-like pairs `!=` is elementwise and
    # the resulting array cannot be used directly as an `if` condition.
    if pair is not None:
        display(data_gen.feature_dataframe(data=pair), n_query)
    else:
        display('not found')

In [9]:
# NOTE(review): this cell is disabled (every line is commented out); the
# tables shown below it are left over from an earlier execution.
# random.seed(422)
# torch.manual_seed(422)

# select_seeker = RandomSelectSeeker(model=model, unfair_metric=unfair_metric, data=dataset.get_all_data(), data_gen=data_gen)
# for _ in range(3):
#     show_result(select_seeker.seek(max_query=1e5))

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,18,0,0,10,15,1,0,4,11,2,2
1,18,1,0,10,15,1,0,4,11,2,2


6314

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,18,0,0,9,10,1,0,4,12,3,2
1,18,1,0,9,10,1,0,4,12,3,2


4928

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,22,0,0,13,16,1,0,4,11,1,2
1,22,1,0,13,16,1,0,4,11,1,2


4030

In [10]:
# NOTE(review): this cell is disabled (every line is commented out); the
# tables shown below it are left over from an earlier execution.
# random.seed(422)
# torch.manual_seed(422)

# distribution_seeker = DistributionGenSeeker(model=model, unfair_metric=unfair_metric, data_gen=data_gen)
# for _ in range(3):
#     show_result(distribution_seeker.seek(max_query=1e5))

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,20,0,0,12,13,1,1,2,3,1,5
1,20,1,0,12,13,1,1,2,3,1,5


2

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,17,0,0,13,10,1,1,2,10,3,2
1,17,1,0,13,10,1,1,2,10,3,2


512

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,19,0,0,12,15,1,1,2,11,0,2
1,19,1,0,12,15,1,1,2,11,0,2


2372

In [11]:
# Range-based generator seeker: three searches from a fixed seed, each with
# max_query=1e5.
random.seed(422)
torch.manual_seed(422)

range_seeker = RangeGenSeeker(
    model=model,
    unfair_metric=unfair_metric,
    data_gen=data_gen,
)
for _ in range(3):
    result = range_seeker.seek(max_query=1e5)
    show_result(result)

In [None]:
# White-box seeker (from seeker.gradiant_based): three labelled attempts.
# NOTE(review): seeded with 42 whereas the other seeker cells use 422 —
# confirm this is intentional. The saved output of this cell ends in a
# KeyboardInterrupt, so the last recorded run was stopped by hand.
random.seed(42)
torch.manual_seed(42)

test_seeker = WhiteboxSeeker(model=model, unfair_metric=unfair_metric, data_gen=data_gen)
seek_kwargs = dict(origin_lr=0.01, max_query=1e5, lamb=1)
for attempt in range(3):
    display(f'try: {attempt}')
    show_result(test_seeker.seek(**seek_kwargs))

'try: 0'

restart 57
restart 194
restart 289
restart 372
restart 476
restart 601
restart 800
restart 1000
restart 1114
restart 1259
restart 1338
restart 1562
restart 1651
restart 1796
restart 1964
restart 2050
restart 2178
restart 2236
restart 2405
restart 2582
restart 2636
restart 2788
restart 2916
restart 3063
restart 3155
restart 3250
restart 3317
restart 3495
restart 3611
restart 3740
restart 3934
restart 4032
restart 4116
restart 4168
restart 4373
restart 4492
restart 4618
restart 4829
restart 4962
restart 5124
restart 5235
restart 5293
restart 5469
restart 5599
restart 5738
restart 5921
restart 5977
restart 6128
restart 6189
restart 6241
restart 6326
restart 6507
restart 6669
restart 6824
restart 6940
restart 6984
restart 7082
restart 7221
restart 7372
restart 7592
restart 7677
restart 7929
restart 7980
restart 8278
restart 8464
restart 8739
restart 8829
restart 8890
restart 9001
restart 9172
restart 9385
restart 9516
restart 9727
restart 9785
restart 9978
restart 10124
restart 10214
resta

KeyboardInterrupt: 

In [None]:
# Black-box seeker (easy=False): three labelled attempts from a fixed seed.
# NOTE(review): the saved output below lists loans_default-style columns
# (AGE, BILL_AMT1, ..., SEX) even though data_name is 'adult' in the config
# cell, so that output is stale and should be regenerated.
random.seed(422)
torch.manual_seed(422)

test_seeker = BlackboxSeeker(
    model=model,
    unfair_metric=unfair_metric,
    data_gen=data_gen,
    easy=False,
)
seek_kwargs = dict(origin_lr=0.1, max_query=1e5, lamb=1)
for attempt in range(3):
    display(f'try: {attempt}')
    show_result(test_seeker.seek(**seek_kwargs))

'try: 0'

Unnamed: 0,AGE,BILL_AMT1,BILL_AMT2,BILL_AMT3,BILL_AMT4,BILL_AMT5,BILL_AMT6,EDUCATION,LIMIT_BAL,MARRIAGE,...,PAY_4,PAY_5,PAY_6,PAY_AMT1,PAY_AMT2,PAY_AMT3,PAY_AMT4,PAY_AMT5,PAY_AMT6,SEX
0,39,343103,615390,490990,-140161,58479,127880,5,959925,0,...,8,2,4,458887,0,621449,235788,371211,459875,1
1,39,343103,615390,490990,-140161,58479,127880,5,959925,0,...,8,2,4,458887,0,621449,235788,371211,459875,0


771

'try: 1'

restart 500
restart 954
restart 1829
restart 2847
restart 3537
restart 4318
restart 5145
restart 6070
restart 6994
restart 7962
restart 8791
restart 9853


Unnamed: 0,AGE,BILL_AMT1,BILL_AMT2,BILL_AMT3,BILL_AMT4,BILL_AMT5,BILL_AMT6,EDUCATION,LIMIT_BAL,MARRIAGE,...,PAY_4,PAY_5,PAY_6,PAY_AMT1,PAY_AMT2,PAY_AMT3,PAY_AMT4,PAY_AMT5,PAY_AMT6,SEX
0,73,148434,853021,1232103,101027,560519,-9642,3,73263,3,...,6,7,0,0,866764,530701,334426,385705,11913,1
1,73,148434,853021,1232103,101027,560519,-9642,3,73263,3,...,6,7,0,0,866764,530701,334426,385705,11913,0


10625

'try: 2'

Unnamed: 0,AGE,BILL_AMT1,BILL_AMT2,BILL_AMT3,BILL_AMT4,BILL_AMT5,BILL_AMT6,EDUCATION,LIMIT_BAL,MARRIAGE,...,PAY_4,PAY_5,PAY_6,PAY_AMT1,PAY_AMT2,PAY_AMT3,PAY_AMT4,PAY_AMT5,PAY_AMT6,SEX
0,35,917845,-17755,717130,74667,593015,-156212,5,579146,1,...,1,-1,3,299253,22619,786796,278335,73552,350130,1
1,35,917845,-17755,717130,74667,593015,-156212,5,579146,1,...,1,-1,3,299253,22619,786796,278335,73552,350130,0


586

In [None]:
# NOTE(review): this cell is disabled (every line is commented out). The
# output below it is left over from an earlier execution and lists
# loans_default-style columns (AGE, BILL_AMT1, ..., SEX) rather than the
# 'adult' dataset selected in the config cell.
# random.seed(422)
# torch.manual_seed(422)

# test_seeker = BlackboxSeeker(model=model, unfair_metric=unfair_metric, data_gen=data_gen, easy=True)
# for i in range(3):
#     display(f'try: {i}')
#     show_result(test_seeker.seek(origin_lr=0.1, max_query=1e5, lamb=1))

'try: 0'

Unnamed: 0,AGE,BILL_AMT1,BILL_AMT2,BILL_AMT3,BILL_AMT4,BILL_AMT5,BILL_AMT6,EDUCATION,LIMIT_BAL,MARRIAGE,...,PAY_4,PAY_5,PAY_6,PAY_AMT1,PAY_AMT2,PAY_AMT3,PAY_AMT4,PAY_AMT5,PAY_AMT6,SEX
0,39,342241,616685,494856,-137462,54623,133511,5,954246,0,...,8,2,4,458949,0,621126,235486,370927,459574,1
1,39,342241,616685,494856,-137462,54623,133511,5,954246,0,...,8,2,4,458949,0,621126,235486,370927,459574,0


122

'try: 1'

restart 134
restart 263
restart 404
restart 551
restart 683
restart 821


Unnamed: 0,AGE,BILL_AMT1,BILL_AMT2,BILL_AMT3,BILL_AMT4,BILL_AMT5,BILL_AMT6,EDUCATION,LIMIT_BAL,MARRIAGE,...,PAY_4,PAY_5,PAY_6,PAY_AMT1,PAY_AMT2,PAY_AMT3,PAY_AMT4,PAY_AMT5,PAY_AMT6,SEX
0,22,685407,351786,1377380,828498,513092,556237,4,147035,0,...,7,2,5,269479,916466,0,532021,76678,103760,1
1,22,685407,351786,1377380,828498,513092,556237,4,147035,0,...,7,2,5,269479,916466,0,532021,76678,103760,0


950

'try: 2'

restart 139
restart 280
restart 420
restart 557
restart 697


Unnamed: 0,AGE,BILL_AMT1,BILL_AMT2,BILL_AMT3,BILL_AMT4,BILL_AMT5,BILL_AMT6,EDUCATION,LIMIT_BAL,MARRIAGE,...,PAY_4,PAY_5,PAY_6,PAY_AMT1,PAY_AMT2,PAY_AMT3,PAY_AMT4,PAY_AMT5,PAY_AMT6,SEX
0,73,124237,840553,1221937,86514,543569,-8541,3,65594,3,...,6,8,0,6366,893519,530794,343000,388366,4826,1
1,73,124237,840553,1221937,86514,543569,-8541,3,65594,3,...,6,8,0,6366,893519,530794,343000,388366,4826,0


822