In [1]:
import torch
import random
from seeker.random import RandomSelectPairSeeker, RandomSelectSeeker, RangeGenSeeker, DistributionGenSeeker
from seeker.gradiant_based import WhiteboxSeeker, BlackboxSeeker, FoolSeeker
from utils import UnfairMetric, load_model, get_L_matrix
from data import adult
from train_dnn import get_data
from models.model import MLP, RandomForest
from distances.normalized_mahalanobis_distances import SquaredEuclideanDistance, ProtectedSEDistances
from distances.sensitive_subspace_distances import LogisticRegSensitiveSubspace
from distances.binary_distances import BinaryDistance
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from IPython.display import display

%load_ext autoreload
%autoreload 2

In [2]:
# Experiment configuration: seed, protected attribute(s) and checkpoint tag.
rand_seed = 0
# Assigned to `dataset.use_protected_attr` below and also passed to the data
# generator, to `load_model` and used to pick the distance-fitting branch.
use_protected_attr = True
protected_vars = ['race_White']
# protected_vars = ['sex_Male', 'race_White']
# `mark` is passed to `load_model` as `remark`; presumably it selects a
# checkpoint variant (see the alternatives below) -- TODO confirm.
mark = ''
# mark = 'epoch30'
# mark = 'epoch50'

dataset, train_dl, test_dl = get_data(adult, rand_seed, protected_vars=protected_vars)
# The flag is set before the feature dimension is queried; presumably
# `dim_feature()` depends on it -- TODO confirm.
dataset.use_protected_attr = use_protected_attr
in_dim = dataset.dim_feature()
out_dim = 2  # presumably the two label classes of the Adult task -- TODO confirm

# prepare data
all_X, all_y = dataset.get_all_data(), dataset.labels

# Generator built from the protected column indices; it is reused by the
# distance fitting, the result display helper and the seekers.
adult_gen = adult.Generator(sensitive_columns=dataset.protected_idxs, include_protected_feature=use_protected_attr)

In [3]:
# Choose which classifier to audit. The matching trainer name is forwarded to
# `load_model` together with the dataset name, seed and protected-attribute
# settings defined in the configuration cell.
model_name = 'MLP'
# model_name = 'RandomForest'

if model_name == 'MLP':
    model = MLP(in_dim, out_dim)
    trainer_name = 'STDTrainer'
elif model_name == 'RandomForest':
    model = RandomForest()
    trainer_name = 'RandomForestTrainer'
else:
    # Fail here with a clear message rather than with a NameError on `model`
    # (or a stale `model` from a previous run) in the call below.
    raise ValueError(f"unsupported model_name: {model_name!r}")

load_model(
    model, model_name, 'adult', trainer_name,
    use_protected_attr=use_protected_attr,
    protected_vars=protected_vars, id=rand_seed, remark=mark,
)

In [4]:
# prepare distances
# Two candidate input-space distances (dx) and one output-space distance (dy)
# are instantiated; the squared-Euclidean variant is currently disabled.
# distance_x_NSE = SquaredEuclideanDistance()
distance_x_Causal = ProtectedSEDistances()
distance_x_LR = LogisticRegSensitiveSubspace()
distance_y = BinaryDistance()

# distance_x_NSE.fit(num_dims=dataset.dim_feature(), data_gen=adult_gen)
if use_protected_attr:
    # Both distances receive the protected column indices directly.
    distance_x_Causal.fit(num_dims=dataset.dim_feature(), data_gen=adult_gen, protected_idx=dataset.protected_idxs)
    distance_x_LR.fit(all_X, adult_gen, protected_idxs=dataset.protected_idxs)
else:
    # No protected indices are given to the causal distance; the logistic
    # regression distance instead gets the protected columns sliced out of
    # `dataset.data` (presumably because `all_X` no longer contains them --
    # TODO confirm).
    sensitive_ = dataset.data[:, dataset.protected_idxs]
    distance_x_Causal.fit(num_dims=dataset.dim_feature(), data_gen=adult_gen, protected_idx=[])
    distance_x_LR.fit(all_X, adult_gen, data_SensitiveAttrs=sensitive_)

# Input-space distance used for the epsilon calibration and the unfairness metric.
chosen_dx = distance_x_Causal

In [5]:
# Probe the chosen input distance along every coordinate axis: starting from
# the all-zero point, shift one feature at a time by 10 and record the distance
# between the original and the shifted point.
x = torch.zeros(chosen_dx.num_dims)
pert = torch.diag(torch.ones_like(x)) * 10
g = torch.zeros_like(x)
for i, axis_step in enumerate(pert):
    g[i] = chosen_dx(x, x + axis_step)

# epsilon is the reciprocal of the smallest non-zero per-axis distance; axes
# whose distance is exactly zero are left out of the minimum.
nonzero_dists = g[g != 0]
epsilon = (1 / nonzero_dists.min()).item()
epsilon


99997992.0

Once epsilon is set, it determines how much each dimension is allowed to change.

In [6]:
# The larger epsilon is, the smaller dx must be for a pair to count as unfair,
# i.e. the stricter the criterion.
# epsilon = 1e10
# epsilon = 9e9
unfair_metric = UnfairMetric(dx=chosen_dx, dy=distance_y, epsilon=epsilon)

In [7]:
def show_result(result):
    """Display the outcome of one seeker run.

    Args:
        result: Indexable sequence of length 2 or 3. ``result[0]`` is the pair
            that was found (``None`` when nothing was found), ``result[1]`` is
            shown next to the pair as the query count, and an optional
            ``result[2]`` is printed as the iteration count.
    """
    pair, n_query = result[0], result[1]
    if len(result) == 3:
        print(f'n_iters = {result[2]}')
    # Identity check: `!= None` would dispatch to the pair's own `__ne__`,
    # which for array-like objects is element-wise and not a plain bool.
    if pair is not None:
        display(adult_gen.feature_dataframe(data=pair), n_query)
    else:
        display('not found')

In [8]:
# random.seed(422)
# torch.manual_seed(422)

# select_seeker = RandomSelectPairSeeker(model=model, unfair_metric=unfair_metric, data=all_X)
# for _ in range(3):
#     show_result(select_seeker.seek(dx_constraint=True, max_query=1e6))

In [9]:
# Re-seed Python's and torch's global RNGs so this cell starts from a fixed
# random state regardless of what ran before it.
random.seed(422)
torch.manual_seed(422)

# Seeker constructed from the full data matrix and the data generator.
select_seeker = RandomSelectSeeker(
    model=model,
    unfair_metric=unfair_metric,
    data=all_X,
    data_gen=adult_gen,
)
# Three consecutive searches (max_query=1e6 each), each shown as it finishes.
for _ in range(3):
    outcome = select_seeker.seek(max_query=1e6)
    show_result(outcome)

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,32,0,0,13,40,1,0,2,2,5,2
1,32,1,0,13,40,1,0,2,2,5,2


1016

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,44,0,1902,14,40,1,1,2,3,0,5
1,44,0,1902,14,40,0,1,2,3,0,5


24

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,46,0,0,12,40,1,1,2,6,0,2
1,46,0,0,12,40,0,1,2,6,0,2


160

In [10]:
# Re-seed Python's and torch's global RNGs so this cell starts from a fixed
# random state regardless of what ran before it.
random.seed(422)
torch.manual_seed(422)

# Seeker constructed from the data generator only (no data matrix is passed).
distribution_seeker = DistributionGenSeeker(
    model=model,
    unfair_metric=unfair_metric,
    data_gen=adult_gen,
)
# Three consecutive searches (max_query=1e6 each), each shown as it finishes.
for _ in range(3):
    outcome = distribution_seeker.seek(max_query=1e6)
    show_result(outcome)

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,64,0,0,14,40,0,1,5,9,5,5
1,64,0,0,14,40,1,1,5,9,5,5


240

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,63,0,0,10,50,1,0,2,3,0,1
1,63,0,0,10,50,0,0,2,3,0,1


42

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,52,0,0,12,40,1,1,2,6,0,2
1,52,0,0,12,40,0,1,2,6,0,2


42

In [11]:
# random.seed(422)
# torch.manual_seed(422)

# range_seeker = RangeGenSeeker(model=model, unfair_metric=unfair_metric, data_gen=adult_gen)
# for _ in range(3):
#     show_result(range_seeker.seek(max_query=1e6))

In [12]:
# Re-seed Python's and torch's global RNGs so this cell starts from a fixed
# random state regardless of what ran before it.
random.seed(422)
torch.manual_seed(422)

test_seeker = WhiteboxSeeker(
    model=model,
    unfair_metric=unfair_metric,
    data_gen=adult_gen,
)
# Three labelled attempts with origin_lr=0.01, max_query=1e6 and lamb=1.
for i in range(3):
    display(f'try: {i}')
    outcome = test_seeker.seek(origin_lr=0.01, max_query=1e6, lamb=1)
    show_result(outcome)

'try: 0'

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,40,53,1941,6,6,0,0,5,13,0,3
1,40,53,1941,6,6,1,0,5,13,0,3


142

'try: 1'

restart 482


Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,79,2228,847,2,76,0,0,5,2,4,6
1,79,2228,847,2,76,1,0,5,2,4,6


700

'try: 2'

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,82,7100,259,6,99,0,0,4,8,0,2
1,82,7100,259,6,99,1,0,4,8,0,2


202

In [13]:
# Re-seed Python's and torch's global RNGs so this cell starts from a fixed
# random state regardless of what ran before it.
random.seed(422)
torch.manual_seed(422)

test_seeker = BlackboxSeeker(
    model=model,
    unfair_metric=unfair_metric,
    data_gen=adult_gen,
    easy=False,
)
# Three labelled attempts with origin_lr=0.1, max_query=1e6 and lamb=1.
for i in range(3):
    display(f'try: {i}')
    outcome = test_seeker.seek(origin_lr=0.1, max_query=1e6, lamb=1)
    show_result(outcome)

'try: 0'

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,39,35,2225,6,6,0,0,5,13,0,3
1,39,35,2225,6,6,1,0,5,13,0,3


1177

'try: 1'

restart 747


Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,79,2200,920,2,76,0,0,5,2,4,6
1,79,2200,920,2,76,1,0,5,2,4,6


2174

'try: 2'

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,82,7094,260,6,99,0,0,4,8,0,2
1,82,7094,260,6,99,1,0,4,8,0,2


1261

In [14]:
# random.seed(422)
# torch.manual_seed(422)

# test_seeker = BlackboxSeeker(model=model, unfair_metric=unfair_metric, data_gen=adult_gen, easy=True)
# # show_result(white_seeker.seek())
# for i in range(3):
#     display(f'try: {i}')
#     show_result(test_seeker.seek(origin_lr=1, max_query=1e6, lamb=1))

In [15]:
# Re-seed Python's and torch's global RNGs so this cell starts from a fixed
# random state regardless of what ran before it.
random.seed(422)
torch.manual_seed(422)

test_seeker = FoolSeeker(
    model=model,
    unfair_metric=unfair_metric,
    data_gen=adult_gen,
    easy=True,
)

# Three labelled attempts with origin_lr=1, max_query=1e6 and lamb=1.
for i in range(3):
    display(f'try: {i}')
    outcome = test_seeker.seek(origin_lr=1, max_query=1e6, lamb=1)
    show_result(outcome)

'try: 0'

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,17,2319,267,2,1,0,0,5,13,0,3
1,17,2319,267,2,1,1,0,5,13,0,3


107

'try: 1'

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,67,7954,0,5,92,0,0,1,4,3,6
1,67,7953,0,5,92,0,0,1,4,3,6


110

'try: 2'

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,17,3317,0,1,1,1,0,2,1,2,3
1,17,3316,0,1,1,1,0,2,1,2,3


114

In [16]:
# random.seed(422)
# torch.manual_seed(422)

# test_seeker = FoolSeeker(model=model, unfair_metric=unfair_metric, data_gen=adult_gen, easy=False)

# for i in range(3):
#     display(f'try: {i}')
#     show_result(test_seeker.seek(origin_lr=1, max_query=1e6, lamb=1))