In [1]:
from train_dnn import get_data
from data import adult, compas, bank, german
from models.trainer import STDTrainer
from models.model import MLP
from seeker.random import RandomSelectPairSeeker, RandomSelectSeeker, RangeGenSeeker, DistributionGenSeeker
from seeker.gradiant_based import WhiteboxSeeker, BlackboxSeeker, FoolSeeker
from distances.normalized_mahalanobis_distances import SquaredEuclideanDistance, ProtectedSEDistances
from distances.sensitive_subspace_distances import LogisticRegSensitiveSubspace
from distances.binary_distances import BinaryDistance
from utils import UnfairMetric
import torch
import random

# Enable IPython's autoreload extension in mode 2 so that edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Data set selection and whether the protected attribute is exposed as a feature.
data = bank
use_protected_attr = True

# Seed both RNGs before the model is constructed so that this cell is
# repeatable on a fresh kernel, in line with the seeding done by the seeker
# cells further down.
random.seed(422)
torch.manual_seed(422)

# NOTE(review): the meaning of the positional `0` and of ['age'] is defined by
# train_dnn.get_data; ['age'] is presumably the protected attribute list — confirm.
dataset, train_dl, test_dl = get_data(data, 0, ['age'])
dataset.use_protected_attr = use_protected_attr
feature_dim = dataset.dim_feature()
output_dim = 2

model = MLP(input_size=feature_dim, output_size=output_dim)

# Keep the original preference for the second GPU, but fall back to the first
# GPU or the CPU so the cell does not fail on hosts with fewer devices.
if torch.cuda.device_count() > 1:
    train_device = 'cuda:1'
elif torch.cuda.is_available():
    train_device = 'cuda:0'
else:
    train_device = 'cpu'

trainer = STDTrainer(model, train_dl, test_dl, device=train_device, epochs=50, lr=1e-3)
trainer.train()

100%|██████████| 50/50 [00:54<00:00,  1.09s/it]


Train Accuracy: 0.8905385732650757
Test Accuracy: 0.8893066048622131


In [3]:
# Build a copy of the full data set in which only the first protected column
# is shifted by a constant 15; every other column is left as it is.
all_X = dataset.data
protected_col = dataset.protected_idxs[0]
all_X_conter = all_X.clone()
all_X_conter[:, protected_col] = all_X_conter[:, protected_col] + 15

# Predict on the CPU for both the original and the shifted inputs.
model.to('cpu')
all_pred = model.get_prediction(all_X)
all_pred_conter = model.get_prediction(all_X_conter)

# Count the rows whose prediction differs between the two versions.
flipped = all_pred != all_pred_conter
n_unfair = flipped.sum().item()
d_len = len(all_pred)
print(f'unfair ratio: {n_unfair/d_len} ({n_unfair}/{d_len})')

unfair ratio: 0.008205967574262901 (371/45211)


In [4]:
# Generator of the selected data module, configured with the data set's
# protected column indices and the protected-attribute switch chosen above.
data_gen = data.Generator(
    sensitive_columns=dataset.protected_idxs,
    include_protected_feature=use_protected_attr,
)

In [5]:
# Input-space distance (fitted below) and output-space distance used by the
# fairness metric.
distance_x_Causal = ProtectedSEDistances()
distance_y = BinaryDistance()

# When the protected attribute is part of the model input, its column indices
# are handed to the distance; otherwise no column is marked as protected.
protected_idx_for_dx = dataset.protected_idxs if use_protected_attr else []
distance_x_Causal.fit(
    num_dims=dataset.dim_feature(),
    data_gen=data_gen,
    protected_idx=protected_idx_for_dx,
)

# Distance actually used by the rest of the notebook.
chosen_dx = distance_x_Causal

In [6]:
# Probe the chosen input distance along each coordinate axis: start from the
# all-zero point and step 10 units along one axis at a time.
x = torch.zeros(chosen_dx.num_dims)
pert = torch.diag(torch.ones_like(x)) * 10
g = torch.zeros_like(x)
for i, axis_step in enumerate(pert):
    g[i] = chosen_dx(x, x + axis_step)

# epsilon is the reciprocal of the smallest non-zero distance measured above.
nonzero_g = g[g != 0]
epsilon = (1 / nonzero_g.min()).item()
print(g, epsilon)

tensor([0.0000e+00, 8.2420e-09, 2.6015e-02, 1.0000e+02, 1.0000e+02, 1.0000e+02,
        1.1111e-01, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02,
        1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02,
        1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02,
        1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 8.2645e-01,
        1.3151e-04, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.3223e-03]) 121330144.0


In [7]:
# Value handed to the metric as `epsilon`; this assignment replaces whatever
# `epsilon` held before this cell ran.
epsilon = 1e8
unfair_metric = UnfairMetric(
    dx=chosen_dx,
    dy=distance_y,
    epsilon=epsilon,
)

In [8]:
def show_result(result):
    """Display the outcome of one seeker run.

    `result` is read positionally: ``result[0]`` is the pair that was found
    (``None`` when there is none) and ``result[1]`` is the value shown next to
    it (bound to ``n_query`` here). When `result` has exactly three entries,
    ``result[2]`` is printed as ``n_iters``.

    A found pair is rendered through ``data_gen.feature_dataframe``; otherwise
    the string ``'not found'`` is displayed.
    """
    pair, n_query = result[0], result[1]
    if len(result) == 3:
        print(f'n_iters = {result[2]}')
    # Identity test on purpose: `pair != None` is dispatched to the pair's own
    # `__ne__`, which for array-like objects compares element-wise and then
    # has no single truth value.
    if pair is not None:
        display(data_gen.feature_dataframe(data=pair), n_query)
    else:
        display('not found')

In [9]:
# Seed torch's and Python's RNGs before running the searches.
torch.manual_seed(422)
random.seed(422)

# Seeker that is given the full data set (`all_X`) in addition to the generator.
select_seeker = RandomSelectSeeker(
    model=model,
    unfair_metric=unfair_metric,
    data=all_X,
    data_gen=data_gen,
)
for _ in range(3):
    found = select_seeker.seek(max_query=1e6)
    show_result(found)

Unnamed: 0,age,balance,campaign,day,default,housing,loan,month,pdays,previous,contact,education,job,marital,poutcome
0,46,429,2,26,0,1,0,4,369,2,0,0,0,0,2
1,45,429,2,26,0,1,0,4,369,2,0,0,0,0,2


3996

Unnamed: 0,age,balance,campaign,day,default,housing,loan,month,pdays,previous,contact,education,job,marital,poutcome
0,81,122,1,23,0,0,0,4,-1,0,1,0,5,1,3
1,82,122,1,23,0,0,0,4,-1,0,1,0,5,1,3


6238

Unnamed: 0,age,balance,campaign,day,default,housing,loan,month,pdays,previous,contact,education,job,marital,poutcome
0,30,255,1,27,0,0,0,1,-1,0,0,1,0,2,3
1,29,255,1,27,0,0,0,1,-1,0,0,1,0,2,3


10274

In [10]:
# Seed torch's and Python's RNGs before running the searches.
torch.manual_seed(422)
random.seed(422)

# Seeker built from the model, the fairness metric and the data generator only.
distribution_seeker = DistributionGenSeeker(
    model=model,
    unfair_metric=unfair_metric,
    data_gen=data_gen,
)
for _ in range(3):
    found = distribution_seeker.seek(max_query=1e6)
    show_result(found)

Unnamed: 0,age,balance,campaign,day,default,housing,loan,month,pdays,previous,contact,education,job,marital,poutcome
0,34,149,1,15,0,0,0,5,-1,9,0,2,4,2,3
1,35,149,1,15,0,0,0,5,-1,9,0,2,4,2,3


41354

Unnamed: 0,age,balance,campaign,day,default,housing,loan,month,pdays,previous,contact,education,job,marital,poutcome
0,56,378,1,20,0,0,0,2,91,0,0,2,4,2,3
1,55,378,1,20,0,0,0,2,91,0,0,2,4,2,3


9908

Unnamed: 0,age,balance,campaign,day,default,housing,loan,month,pdays,previous,contact,education,job,marital,poutcome
0,46,-19,2,30,0,0,0,11,343,0,0,1,7,1,3
1,45,-19,2,30,0,0,0,11,343,0,0,1,7,1,3


15776

In [11]:
# Seed torch's and Python's RNGs before running the searches.
torch.manual_seed(422)
random.seed(422)

# Seeker built from the model, the fairness metric and the data generator only.
range_seeker = RangeGenSeeker(
    model=model,
    unfair_metric=unfair_metric,
    data_gen=data_gen,
)
for _ in range(3):
    found = range_seeker.seek(max_query=1e6)
    show_result(found)

Unnamed: 0,age,balance,campaign,day,default,housing,loan,month,pdays,previous,contact,education,job,marital,poutcome
0,34,14195,32,24,1,1,1,7,656,232,0,2,9,2,2
1,33,14195,32,24,1,1,1,7,656,232,0,2,9,2,2


7872

Unnamed: 0,age,balance,campaign,day,default,housing,loan,month,pdays,previous,contact,education,job,marital,poutcome
0,27,8662,32,3,0,0,1,4,342,221,0,0,4,0,1
1,28,8662,32,3,0,0,1,4,342,221,0,0,4,0,1


472

Unnamed: 0,age,balance,campaign,day,default,housing,loan,month,pdays,previous,contact,education,job,marital,poutcome
0,36,13924,27,31,1,1,1,4,395,236,1,1,0,2,2
1,35,13924,27,31,1,1,1,4,395,236,1,1,0,2,2


17398

In [15]:
# Seed torch's and Python's RNGs (this cell uses 42) before running the searches.
torch.manual_seed(42)
random.seed(42)

test_seeker = WhiteboxSeeker(
    model=model,
    unfair_metric=unfair_metric,
    data_gen=data_gen,
)
# Three attempts, each announced with its index before the result is shown.
for i in range(3):
    display(f'try: {i}')
    found = test_seeker.seek(origin_lr=0.01, max_query=1e6, lamb=1)
    show_result(found)

'try: 0'

tensor([[0.8831, 0.9150, 0.3871, 0.0000, 0.0000, 1.0000, 0.9667, 0.0000, 0.0000,
         0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
         1.0000, 0.0000, 0.8182, 0.9415, 0.0000, 0.0000, 1.0000, 0.0000, 0.1309]])
------------g----------------
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
-----------------------------
tensor([[nan, nan, nan, 1., 0., 0., nan, nan, 1., 0., 0., 0., nan, 1., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., nan, 1., 0., 0., nan, nan, 1., 0., 0., 0., nan]],
       grad_fn=<DivBackward0>)
------------g----------------
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]])
---------------

KeyboardInterrupt: 

In [12]:
# Seed torch's and Python's RNGs before running the searches.
torch.manual_seed(422)
random.seed(422)

# Black-box seeker with `easy` switched off.
test_seeker = BlackboxSeeker(
    model=model,
    unfair_metric=unfair_metric,
    data_gen=data_gen,
    easy=False,
)
# Three attempts, each announced with its index before the result is shown.
for i in range(3):
    display(f'try: {i}')
    found = test_seeker.seek(origin_lr=0.1, max_query=1e6, lamb=1)
    show_result(found)

'try: 0'

tensor([[0.2987, 0.3713, 0.5484, 0.0000, 1.0000, 0.0000, 0.3333, 0.0000, 0.0000,
         0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 1.0000, 0.7273, 0.9966, 0.0000, 1.0000, 0.0000, 0.0000, 0.0327]])
new g
tensor([[4.8700e+01, 3.2873e+04, 3.5000e+01, 0.0000e+00, 1.0000e+00, 0.0000e+00,
         1.1000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 9.0000e+00,
         8.6800e+02, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 9.0000e+00]],
       grad_fn=<AddBackward0>)
tensor([[3.3300e+01, 3.2873e+04, 3.5000e+01, 0.0000e+00, 1.0000e+00, 0.0000e+00,
         1.1000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,

tensor([[4.1000e+01, 4.3888e+04, 3.5000e+01, 0.0000e+00, 1.0000e+00, 0.0000e+00,
         1.1000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 9.0000e+00,
         8.6800e+02, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 9.0000e+00]],
       grad_fn=<AddBackward0>)
tensor([[4.1000e+01, 2.1858e+04, 3.5000e+01, 0.0000e+00, 1.0000e+00, 0.0000e+00,
         1.1000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 9.0000e+00,
         8.6800e+02, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 9.0000e

KeyboardInterrupt: Interrupted by user

In [None]:
# Seed torch's and Python's RNGs before running the searches.
torch.manual_seed(422)
random.seed(422)

# Black-box seeker with `easy` switched on.
test_seeker = BlackboxSeeker(
    model=model,
    unfair_metric=unfair_metric,
    data_gen=data_gen,
    easy=True,
)
# Three attempts, each announced with its index before the result is shown.
for i in range(3):
    display(f'try: {i}')
    found = test_seeker.seek(origin_lr=0.1, max_query=1e6, lamb=1)
    show_result(found)

'try: 0'

restart 63


Unnamed: 0,juv_fel_count,juv_misd_count,juv_other_count,priors_count,race_Black,score_text,sex_Male,age_catagory,crime_charge_degree
0,0,0,6,1,1,0,1,2,0
1,0,0,6,1,0,0,1,2,0


131

'try: 1'

restart 59
restart 115
restart 172
restart 225
restart 280
restart 342
restart 412
restart 480


Unnamed: 0,juv_fel_count,juv_misd_count,juv_other_count,priors_count,race_Black,score_text,sex_Male,age_catagory,crime_charge_degree
0,3,0,5,8,0,0,1,2,1
1,3,0,5,8,1,0,1,2,1


541

'try: 2'

restart 70
restart 137
restart 208
restart 264
restart 317
restart 379
restart 433
restart 491
restart 558
restart 619
restart 676
restart 738


Unnamed: 0,juv_fel_count,juv_misd_count,juv_other_count,priors_count,race_Black,score_text,sex_Male,age_catagory,crime_charge_degree
0,2,12,0,13,0,0,0,2,0
1,2,12,0,13,1,0,0,2,0


795