In [1]:
import torch
import random
from seeker import White_seeker, Random_select_seeker, Random_gen_seeker, Black_seeker
from utils import Unfair_metric, load_model, get_L_matrix
from data import adult
from train_dnn import get_data
from dnn_models.model import MLP, NormMLP
from distances.normalized_mahalanobis_distances import SquaredEuclideanDistance, ProtectedSEDistances
from distances.sensitive_subspace_distances import LogisticRegSensitiveSubspace
from distances.binary_distances import BinaryDistance
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from IPython.display import display

%load_ext autoreload
%autoreload 2

In [2]:
# Experiment configuration: seed / model id, and which attributes are treated as protected.
rand_seed = 0
use_protected_attr = True
protected_vars = ['race_White']
# protected_vars = ['sex_Male', 'race_White']

# Dataset plus train/test loaders for the adult data.
dataset, train_dl, test_dl = get_data(
    adult,
    rand_seed,
    protected_vars=protected_vars,
)
# The flag is set before dim_feature() is queried, as in the original ordering.
dataset.use_protected_attr = use_protected_attr
# in_dim / out_dim are the two sizes handed to the MLP constructor further down.
in_dim, out_dim = dataset.dim_feature(), 2

# prepare data
all_X = dataset.get_all_data()
all_y = dataset.labels

# Generator object that the distances and seekers below receive as `data_gen`.
adult_gen = adult.Adult_gen(
    sensitive_columns=dataset.protected_idxs,
    include_protected_feature=use_protected_attr,
)

In [3]:
# Instantiate the classifier and hand it to load_model together with the run
# identifiers (presumably this restores saved weights in place -- the return
# value is not used; confirm against utils.load_model).
model = MLP(in_dim, out_dim)
load_model(
    model,
    'MLP',
    'adult',
    'STDTrainer',
    use_protected_attr=use_protected_attr,
    protected_vars=protected_vars,
    id=rand_seed,
)

# Disabled alternative: the NormMLP variant, which additionally takes the data range.
# model = NormMLP(in_dim, out_dim, adult_gen.get_range('data'))
# load_model(model, 'NormMLP', 'adult', 'STDTrainer', use_protected_attr=use_protected_attr,
#            protected_vars=protected_vars, id=rand_seed)

# all_pred = model.get_prediction(all_X)

In [4]:
# prepare distances
# Three candidate input-space distances (dx) and one output-space distance (dy).
distance_x_NSE = SquaredEuclideanDistance()
distance_x_Causal = ProtectedSEDistances()
distance_x_LR = LogisticRegSensitiveSubspace()
distance_y = BinaryDistance()

# Fit the input-space distances; distance_y is used as constructed (no fit call).
n_features = dataset.dim_feature()
distance_x_NSE.fit(
    num_dims=n_features,
    data_gen=adult_gen,
)
distance_x_Causal.fit(
    num_dims=n_features,
    data_gen=adult_gen,
    protected_idx=dataset.protected_idxs,
)
distance_x_LR.fit(
    all_X,
    adult_gen,
    protected_idxs=dataset.protected_idxs,
)

In [5]:
def rand_gen():
    """Draw one random individual as a dict of feature values keyed by feature name.

    Numeric features are sampled uniformly from fixed inclusive integer ranges,
    ``capital_gain`` / ``capital_loss`` are always 0, the two binary flags are
    0 or 1, and the categorical features are integer codes ``0 .. n_levels-1``.

    Returns:
        dict: feature name -> int, with keys in a fixed order.
    """
    sample = {}

    # Numeric features (randint bounds are inclusive on both ends).
    sample['age'] = random.randint(15, 60)
    sample['capital_gain'] = 0
    sample['capital_loss'] = 0
    sample['education_num'] = random.randint(1, 15)
    sample['hours_per_week'] = random.randint(10, 50)

    # Binary flags.
    for flag in ('race_white', 'sex_male'):
        sample[flag] = random.choice([0, 1])

    # Categorical features, each drawn as an integer code below its level count.
    level_counts = (
        ('marital_status', 7),
        ('occupation', 14),
        ('relationship', 6),
        ('workclass', 7),
    )
    for name, n_levels in level_counts:
        sample[name] = random.choice(list(range(n_levels)))

    return sample

def perturb_pair(x, pert_features, pert_func):
    """Build a two-element tensor per feature: the original value and its counterpart.

    Args:
        x: dict mapping feature name -> scalar value.
        pert_features: feature name, or collection of feature names, to perturb
            in the second element of the pair.
        pert_func: callable applied to the original value (passed as a 0-dim
            tensor) to produce the perturbed value.

    Returns:
        dict: feature name -> tensor ``[v, v]`` for untouched features, or
        ``[v, pert_func(v)]`` for features listed in ``pert_features``.
    """
    # A bare string would turn the membership test below into a substring
    # check (e.g. 'a' in 'ab'), silently perturbing the wrong features.
    if isinstance(pert_features, str):
        pert_features = [pert_features]

    pair = dict()
    for k, v in x.items():
        pair[k] = torch.tensor([v, v])
        if k in pert_features:
            pair[k][1] = pert_func(pair[k][0])
    return pair

# Seed Python's RNG before sampling so that `x` -- and every later cell that
# builds a pair from it -- is the same on each Restart & Run All.
random.seed(rand_seed)
x = rand_gen()
# pair = adult.generate_from_origin(**perturb_pair(x, ['sex_male', 'race_white'], lambda x:1-x))
# adult.get_original_feature(pair)

In [6]:
# Pair made of the sampled individual and a copy whose capital_loss is raised by 1.
perturbed_inputs = perturb_pair(x, ['capital_loss'], lambda value: 1 + value)
pair = adult.generate_from_origin(**perturbed_inputs)
# Last expression of the cell: shown as the cell output.
adult.get_original_feature(pair)

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,38.0,0.0,0.0,12.0,27.0,0.0,1.0,4.0,11.0,0.0,1.0
1,38.0,0.0,1.0,12.0,27.0,0.0,1.0,4.0,11.0,0.0,1.0


In [7]:
# dist = distance_x_LR(all_X[0], all_X[0] + torch.eye(all_X.shape[1]), itemwise_dist=False).squeeze()
# print(dist)
# print(dist[dataset.protected_idxs])

In [8]:
# 1/1e-10

In [9]:
# Distance from the first sample to each copy of it shifted by +1 along one
# coordinate (one row of the identity matrix per coordinate).
origin = all_X[0]
unit_steps = torch.eye(all_X.shape[1])
dist = distance_x_Causal(origin, origin + unit_steps, itemwise_dist=False).squeeze()
print(dist)
# Entries at the protected indices only.
print(dist[dataset.protected_idxs])

tensor([1.8765e-04, 1.0001e-10, 5.2702e-08, 4.4444e-03, 1.0412e-04, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00])
tensor([0.])


In [10]:
# x = rand_gen()
# for key in x:
#     pair = adult.generate_from_origin(**perturb_pair(x, key, x[key]+1))
#     print(distance_x_NSE(pair[0], pair[1]), 1 / distance_x_NSE(pair[0], pair[1]))

In [11]:
# The larger epsilon is, the smaller dx must be for a pair to count as unfair,
# i.e. the stricter the criterion.
epsilon = 1e8
# epsilon = 9e9
unfair_metric = Unfair_metric(
    dx=distance_x_Causal,
    dy=distance_y,
    epsilon=epsilon,
)

In [12]:
# L = get_L_matrix(all_X, all_pred, distance_x_Causal, distance_y)
# n_unfair = torch.sum(L>epsilon).item()

# total_pairs = all_X.shape[0]**2

# unfair_ratio = n_unfair/total_pairs
# unfair_ratio

In [13]:
# 1/unfair_ratio

In [14]:
# p_success = lambda n: 1 - (1 - unfair_ratio)**n
# for i in range(10):
#     if p_success(10**i) > 1/100:
#         print(10**i, p_success(10**i))

In [15]:
def show_result(result):
    """Display the outcome of one seeker run.

    Args:
        result: ``(pair, n_query)`` tuple. When ``pair`` is ``None`` the text
            ``not found`` is printed; otherwise the pair is converted with
            ``adult.get_original_feature`` and displayed together with
            ``n_query``.
    """
    pair, n_query = result
    # Identity check rather than `!= None`: array-like objects evaluate `!=`
    # element-wise, which makes the truth test ambiguous or raises.
    if pair is None:
        print('not found')
        return
    display(adult.get_original_feature(pair), n_query)

In [16]:
# random.seed(422)
# torch.manual_seed(422)

# select_seeker = Random_select_seeker(model=model, unfair_metric=unfair_metric, data=all_X)
# for _ in range(5):
#     show_result(select_seeker.seek(dx_constraint=True, max_query=1e6))

In [17]:
# random.seed(422)
# torch.manual_seed(422)

# gen_seeker = Random_gen_seeker(model=model, unfair_metric=unfair_metric, data_gen=adult_gen)
# for _ in range(3):
#     show_result(gen_seeker.seek(by_range=True, max_query=1e6))

In [18]:
# random.seed(422)
# torch.manual_seed(422)

# for _ in range(3):
#     show_result(gen_seeker.seek(by_range=False, max_query=1e6))

In [19]:
# Show the 'data' range reported by the generator. In the recorded output this
# is a two-row tensor -- presumably per-feature lower bounds followed by upper
# bounds; TODO confirm against Adult_gen.get_range.
adult_gen.get_range('data')

tensor([[1.7000e+01, 0.0000e+00, 0.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.0000e+01, 9.9999e+04, 4.3560e+03, 1.6000e+01, 9.9000e+01, 1.0000e+00,
         1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
         1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
         1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
         1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
         1.0000e+00, 1.0000e+00, 1.0000

In [27]:
# Fix both RNGs so the white-box search below starts from the same state each run.
seed = 422
random.seed(seed)
torch.manual_seed(seed)

white_seeker = White_seeker(
    model=model,
    unfair_metric=unfair_metric,
    data_gen=adult_gen,
)
# show_result(white_seeker.seek())
# NOTE(review): the recorded run of this cell aborted with
# `Exception: debug checkpoint` raised from seeker.py inside seek().
for _ in range(5):
    found = white_seeker.seek(origin_lr=100, max_query=1e6)
    show_result(found)

step change 1 tensor([[ 1.1462, -0.0491, -0.1528,  0.8518,  1.1826,  4.1826, -0.5449, -1.8152,
          1.1397,  5.9937,  2.5178,  2.4288,  3.3635, -0.0302,  1.4144, -1.5988,
          2.2114,  2.7636,  1.8868,  4.1073,  1.5916, -1.0322, -0.1536,  0.4494,
          0.2265,  1.7411,  1.5950, -2.0282,  5.4223,  2.4362,  4.5005,  4.3263,
         -0.4180,  0.3718,  0.1715,  0.4385,  2.8141, -1.1200,  1.7192,  0.8406,
          0.9724]]) tensor(1.5950) tensor(0.3718)
step change 2 tensor([[ 8.3672e+01, -4.9054e+03, -6.6566e+02,  1.2777e+01,  1.1589e+02,
          4.1826e+00, -5.4490e-01, -1.8152e+00,  1.1397e+00,  5.9937e+00,
          2.5178e+00,  2.4288e+00,  3.3635e+00, -3.0197e-02,  1.4144e+00,
         -1.5988e+00,  2.2114e+00,  2.7636e+00,  1.8868e+00,  4.1073e+00,
          1.5916e+00, -1.0322e+00, -1.5357e-01,  4.4942e-01,  2.2651e-01,
          1.7411e+00,  1.5950e+00, -2.0282e+00,  5.4223e+00,  2.4362e+00,
          4.5005e+00,  4.3263e+00, -4.1802e-01,  3.7176e-01,  1.7146e-01,

Traceback (most recent call last):
  File "/data/liuyuanhao/anaconda3/envs/torch1.9/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_15319/590663850.py", line 7, in <module>
    show_result(white_seeker.seek(origin_lr=100, max_query=1e6))
  File "/data/liuyuanhao/projects/Unfairness_prove/seeker.py", line 223, in seek
    raise Exception('debug checkpoint')
Exception: debug checkpoint

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/data/liuyuanhao/anaconda3/envs/torch1.9/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 2120, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
  File "/data/liuyuanhao/anaconda3/envs/torch1.9/lib/python3.9/site-packages/IPython/core/ultratb.py", line 1435, in structured_traceback
    return FormattedTB.structured_traceback(
  File "/data/liuyu

In [30]:
# Fix both RNGs so the black-box search below starts from the same state each run.
seed = 422
random.seed(seed)
torch.manual_seed(seed)

black_seeker = Black_seeker(
    model=model,
    unfair_metric=unfair_metric,
    data_gen=adult_gen,
)
# NOTE(review): the recorded run shows four results and then a KeyboardInterrupt,
# i.e. the fifth search was stopped by hand.
for _ in range(5):
    found = black_seeker.seek(max_query=1e6)
    show_result(found)

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,90.0,9449.0,0.0,16.0,99.0,1.0,1.0,4.0,7.0,1.0,2.0
1,90.0,9446.0,0.0,16.0,99.0,1.0,1.0,4.0,7.0,1.0,2.0


1251

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,90.0,9450.0,0.0,16.0,99.0,1.0,1.0,4.0,7.0,1.0,2.0
1,90.0,9447.0,0.0,16.0,99.0,1.0,1.0,4.0,7.0,1.0,2.0


1086

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,90.0,9450.0,0.0,16.0,99.0,1.0,1.0,4.0,7.0,1.0,2.0
1,90.0,9447.0,0.0,16.0,99.0,1.0,1.0,4.0,7.0,1.0,2.0


754

Unnamed: 0,age,capital-gain,capital-loss,education-num,hours-per-week,race_White,sex_Male,marital-status,occupation,relationship,workclass
0,90.0,9449.0,0.0,16.0,99.0,1.0,1.0,4.0,7.0,1.0,2.0
1,90.0,9446.0,0.0,16.0,99.0,1.0,1.0,4.0,7.0,1.0,2.0


1418

KeyboardInterrupt: 