#### This notebook demonstrates the use of the SenSR algorithm to learn a fair classifier.
[SenSR](https://arxiv.org/pdf/1907.00020.pdf) is an in-processing technique that learns a classifier that is fair in the sense that its performance is invariant under certain perturbations to the features. For example, the performance of a resume screening system should be invariant under changes to the name of the applicant or switching the gender pronouns. This notebook reproduces the Adult experiments in [this paper](https://arxiv.org/pdf/1907.00020.pdf).

In [8]:
!pip install -e git+https://github.com/LisaKouts/AIF360.git@master#egg=aif360

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Obtaining aif360 from git+https://github.com/LisaKouts/AIF360.git@master#egg=aif360
  Updating c:\users\lucp11124\documents\github\sensitive-subspace-robustness\src\aif360 clone (to revision master)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Installing collected packages: aif360
  Attempting uninstall: aif360
    Found existing installation: aif360 0.5.0
    Uninstalling aif360-0.5.0:
      Successfully uninstalled aif360-0.5.0
  Running setup.py develop for aif360
Successfully installed aif360-0.5.0


  Running command git fetch -q --tags
  Running command git reset --hard -q 27a06ab4b09a02283910b970feebacaa77b6dc2e


In [9]:
# Load all necessary packages
from aif360.datasets import BinaryLabelDataset, AdultDataset
from aif360.metrics import BinaryLabelDatasetMetric, ClassificationMetric

from sklearn.preprocessing import StandardScaler

from IPython.display import Markdown, display

from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.decomposition import TruncatedSVD

import utils
import SenSR
import numpy as np

Instructions for updating:
non-resource variables are not supported in the long term


#### Load dataset and set options

In [16]:
# Get the dataset and split into train and test
dataset_orig = AdultDataset()

# We do not use these features. Note, we use the continuous version of education, i.e. `education-num`, so we drop the categorical versions of education
drop_features = [
    'education=10th',
    'education=11th',
    'education=12th',
    'education=1st-4th',
    'education=5th-6th',
    'education=7th-8th',
    'education=9th',
    'education=Assoc-acdm',
    'education=Assoc-voc',
    'education=Bachelors',
    'education=Doctorate',
    'education=HS-grad',
    'education=Masters',
    'education=Preschool',
    'education=Prof-school',
    'education=Some-college', 
    'native-country=Cambodia',
    'native-country=Canada',
    'native-country=China',
    'native-country=Columbia',
    'native-country=Cuba',
    'native-country=Dominican-Republic',
    'native-country=Ecuador',
    'native-country=El-Salvador',
    'native-country=England',
    'native-country=France',
    'native-country=Germany',
    'native-country=Greece',
    'native-country=Guatemala',
    'native-country=Haiti',
    'native-country=Holand-Netherlands',
    'native-country=Honduras',
    'native-country=Hong',
    'native-country=Hungary',
    'native-country=India',
    'native-country=Iran',
    'native-country=Ireland',
    'native-country=Italy',
    'native-country=Jamaica',
    'native-country=Japan',
    'native-country=Laos',
    'native-country=Mexico',
    'native-country=Nicaragua',
    'native-country=Outlying-US(Guam-USVI-etc)',
    'native-country=Peru',
    'native-country=Philippines',
    'native-country=Poland',
    'native-country=Portugal',
    'native-country=Puerto-Rico',
    'native-country=Scotland',
    'native-country=South',
    'native-country=Taiwan',
    'native-country=Thailand',
    'native-country=Trinadad&Tobago',
    'native-country=United-States',
    'native-country=Vietnam',
    'native-country=Yugoslavia']

drop_features_indices = [dataset_orig.feature_names.index(feat) for feat in drop_features]

dataset_orig.features = np.delete(dataset_orig.features, drop_features_indices, axis = 1)
dataset_orig.feature_names = [feat for feat in dataset_orig.feature_names if feat not in drop_features]



In [19]:
len(dataset_orig.feature_names)

41

In [10]:


# we will standardize continous features
continous_features = ['age', 'education-num', 'capital-gain', 'capital-loss', 'hours-per-week']
continous_features_indices = [dataset_orig.feature_names.index(feat) for feat in continous_features]

# get a 80%/20% train/test split
dataset_orig_train, dataset_orig_test = dataset_orig.split([0.8], shuffle=True)

X_train = dataset_orig_train.features
# normalize continuous features
SS = StandardScaler().fit(X_train[:, continous_features_indices])
X_train[:, continous_features_indices] = SS.transform(X_train[:, continous_features_indices])
# remove sex and race as predictive features
X_train = np.delete(X_train, [dataset_orig_train.feature_names.index(feat) for feat in ['sex', 'race']], axis = 1)

X_test = dataset_orig_test.features
# normalize continuous features
X_test[:, continous_features_indices] = SS.transform(X_test[:, continous_features_indices])
# remove sex and race as predictive features
X_test = np.delete(X_test, [dataset_orig_test.feature_names.index(feat) for feat in ['sex', 'race']], axis = 1)

y_train = dataset_orig_train.labels
y_test = dataset_orig_test.labels

one_hot = OneHotEncoder(sparse=False)
one_hot.fit(y_train.reshape(-1,1))
y_train = one_hot.transform(y_train.reshape(-1,1))
y_test = one_hot.transform(y_test.reshape(-1,1))

y_sex_train = dataset_orig_train.features[:, dataset_orig_train.feature_names.index('sex')]
y_sex_test = dataset_orig_test.features[:, dataset_orig_test.feature_names.index('sex')]

one_hot.fit(y_sex_train.reshape(-1,1))
y_sex_train = one_hot.transform(y_sex_train.reshape(-1,1))
y_sex_test = one_hot.transform(y_sex_test.reshape(-1,1))

privileged_groups = [{'sex': 1}]
unprivileged_groups = [{'sex': 0}]



In [11]:
# print out some labels, names, etc.
display(Markdown("#### Training Dataset shape"))
print(dataset_orig_train.features.shape)
display(Markdown("#### Favorable and unfavorable labels"))
print(dataset_orig_train.favorable_label, dataset_orig_train.unfavorable_label)
display(Markdown("#### Protected attribute names"))
print(dataset_orig_train.protected_attribute_names)
display(Markdown("#### Privileged and unprivileged protected attribute values"))
print(dataset_orig_train.privileged_protected_attributes, 
      dataset_orig_train.unprivileged_protected_attributes)
display(Markdown("#### Dataset feature names"))
print(dataset_orig_train.feature_names)

#### Training Dataset shape

(36176, 41)


#### Favorable and unfavorable labels

1.0 0.0


#### Protected attribute names

['race', 'sex']


#### Privileged and unprivileged protected attribute values

[array([1.]), array([1.])] [array([0.]), array([0.])]


#### Dataset feature names

['age', 'education-num', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'workclass=Federal-gov', 'workclass=Local-gov', 'workclass=Private', 'workclass=Self-emp-inc', 'workclass=Self-emp-not-inc', 'workclass=State-gov', 'workclass=Without-pay', 'marital-status=Divorced', 'marital-status=Married-AF-spouse', 'marital-status=Married-civ-spouse', 'marital-status=Married-spouse-absent', 'marital-status=Never-married', 'marital-status=Separated', 'marital-status=Widowed', 'occupation=Adm-clerical', 'occupation=Armed-Forces', 'occupation=Craft-repair', 'occupation=Exec-managerial', 'occupation=Farming-fishing', 'occupation=Handlers-cleaners', 'occupation=Machine-op-inspct', 'occupation=Other-service', 'occupation=Priv-house-serv', 'occupation=Prof-specialty', 'occupation=Protective-serv', 'occupation=Sales', 'occupation=Tech-support', 'occupation=Transport-moving', 'relationship=Husband', 'relationship=Not-in-family', 'relationship=Other-relative', 'relationship=Own-child', 

#### Metric for original training data

In [13]:
# Metric for the original dataset
metric_orig_train = BinaryLabelDatasetMetric(dataset_orig_train, 
                                             unprivileged_groups=unprivileged_groups,
                                             privileged_groups=privileged_groups)
display(Markdown("#### Original training dataset"))
print("Train set: Difference in mean outcomes between unprivileged and privileged groups = %f" % metric_orig_train.mean_difference())
metric_orig_test = BinaryLabelDatasetMetric(dataset_orig_test, 
                                             unprivileged_groups=unprivileged_groups,
                                             privileged_groups=privileged_groups)
print("Test set: Difference in mean outcomes between unprivileged and privileged groups = %f" % metric_orig_test.mean_difference())

#### Original training dataset

Train set: Difference in mean outcomes between unprivileged and privileged groups = -0.198379
Test set: Difference in mean outcomes between unprivileged and privileged groups = -0.200991


### Learn baseline classifier

In [14]:
weights, train_logits, test_logits  = SenSR.train_nn(X_train, y_train, X_test = X_test, y_test = y_test, n_units=[], l2_reg=0., batch_size=1000, epoch=5000, verbose=True)


Epoch 0 train accuracy 0.683326
Epoch 0 test accuracy 0.687894

Epoch 10 train accuracy 0.683934
Epoch 10 test accuracy 0.687231

Epoch 20 train accuracy 0.686007
Epoch 20 test accuracy 0.68513

Epoch 30 train accuracy 0.687445
Epoch 30 test accuracy 0.686235

Epoch 40 train accuracy 0.689380
Epoch 40 test accuracy 0.687009

Epoch 50 train accuracy 0.691895
Epoch 50 test accuracy 0.688778

Epoch 60 train accuracy 0.695406
Epoch 60 test accuracy 0.692427

Epoch 70 train accuracy 0.699718
Epoch 70 test accuracy 0.694306

Epoch 80 train accuracy 0.704915
Epoch 80 test accuracy 0.699502

Epoch 90 train accuracy 0.710886
Epoch 90 test accuracy 0.706247

Epoch 100 train accuracy 0.716829
Epoch 100 test accuracy 0.712438

Epoch 110 train accuracy 0.722136
Epoch 110 test accuracy 0.717081

Epoch 120 train accuracy 0.728853
Epoch 120 test accuracy 0.721172

Epoch 130 train accuracy 0.735930
Epoch 130 test accuracy 0.729353

Epoch 140 train accuracy 0.741016
Epoch 140 test accuracy 0.734218

Ep

In [8]:
dataset_nodebiasing_train = dataset_orig_train.copy()
dataset_nodebiasing_train.labels = np.argmax(train_logits,axis = 1)

dataset_nodebiasing_test = dataset_orig_test.copy()
dataset_nodebiasing_test.labels = np.argmax(test_logits,axis = 1)

In [9]:
def compute_gap_RMS(data_set):
    TPR = -1*data_set.false_negative_rate_difference()
    TNR = -1*data_set.false_positive_rate_difference()    

    return np.sqrt(1/2*(TPR**2 + TNR**2)), max(np.abs(TPR), np.abs(TNR))

In [10]:
# Metrics for the dataset from plain model (without debiasing)
privileged_groups = [{'sex': 1}]
unprivileged_groups = [{'sex': 0}]

display(Markdown("#### Plain model - without debiasing - dataset metrics"))
metric_dataset_nodebiasing_train = BinaryLabelDatasetMetric(dataset_nodebiasing_train, 
                                             unprivileged_groups=unprivileged_groups,
                                             privileged_groups=privileged_groups)

print("Train set: Difference in mean outcomes between unprivileged and privileged groups = %f" % metric_dataset_nodebiasing_train.mean_difference())

metric_dataset_nodebiasing_test = BinaryLabelDatasetMetric(dataset_nodebiasing_test, 
                                             unprivileged_groups=unprivileged_groups,
                                             privileged_groups=privileged_groups)

print("Test set: Difference in mean outcomes between unprivileged and privileged groups = %f" % metric_dataset_nodebiasing_test.mean_difference())

display(Markdown("#### Plain model - without debiasing - classification metrics"))
classified_metric_nodebiasing_test = ClassificationMetric(dataset_orig_test, 
                                                 dataset_nodebiasing_test,
                                                 unprivileged_groups=unprivileged_groups,
                                                 privileged_groups=privileged_groups)
print("Test set: Classification accuracy = %f" % classified_metric_nodebiasing_test.accuracy())
TPR = classified_metric_nodebiasing_test.true_positive_rate()
TNR = classified_metric_nodebiasing_test.true_negative_rate()
bal_acc_nodebiasing_test = 0.5*(TPR+TNR)

gap_rms, max_gap = compute_gap_RMS(classified_metric_nodebiasing_test)
print("Test set: gap rms sex = %f" % gap_rms)
print("Test set: max gap rms sex = %f" % max_gap)
print("Test set: Balanced TPR = %f" % bal_acc_nodebiasing_test)

privileged_groups = [{'race': 1}]
unprivileged_groups = [{'race': 0}]

classified_metric_nodebiasing_test = ClassificationMetric(dataset_orig_test, 
                                                 dataset_nodebiasing_test,
                                                 unprivileged_groups=unprivileged_groups,
                                                 privileged_groups=privileged_groups)

gap_rms, max_gap = compute_gap_RMS(classified_metric_nodebiasing_test)
print("Test set: gap rms race = %f" % gap_rms)
print("Test set: max gap rms race = %f" % max_gap)


#### Plain model - without debiasing - dataset metrics

Train set: Difference in mean outcomes between unprivileged and privileged groups = -0.300755
Test set: Difference in mean outcomes between unprivileged and privileged groups = -0.302490


#### Plain model - without debiasing - classification metrics

Test set: Classification accuracy = 0.807629
Test set: gap rms sex = 0.154314
Test set: max gap rms sex = 0.196815
Test set: Balanced TPR = 0.814855
Test set: gap rms race = 0.057446
Test set: max gap rms race = 0.078686


### Apply in-processing algorithm based on adversarial learning
#### SenSR$_0$

In [11]:
# get sensitive directions
weights, train_logits, test_logits  = SenSR.train_nn(X_train, y_sex_train, X_test = X_test, y_test = y_sex_test, n_units=[], l2_reg=1., batch_size=5000, epoch=5000, verbose=True)

sensitive_directions = []
sensitive_directions.append(weights[0].T)

sensitive_directions = np.vstack(sensitive_directions)
tSVD = TruncatedSVD(n_components=2)
tSVD.fit(sensitive_directions)
sensitive_directions = tSVD.components_


Epoch 0 train accuracy 0.650994
Epoch 0 test accuracy 0.658375

Epoch 10 train accuracy 0.651906
Epoch 10 test accuracy 0.657933

Epoch 20 train accuracy 0.657296
Epoch 20 test accuracy 0.660807

Epoch 30 train accuracy 0.662631
Epoch 30 test accuracy 0.666335

Epoch 40 train accuracy 0.667109
Epoch 40 test accuracy 0.669541

Epoch 50 train accuracy 0.670868
Epoch 50 test accuracy 0.672637

Epoch 60 train accuracy 0.677060
Epoch 60 test accuracy 0.676727

Epoch 70 train accuracy 0.681676
Epoch 70 test accuracy 0.680265

Epoch 80 train accuracy 0.684800
Epoch 80 test accuracy 0.682587

Epoch 90 train accuracy 0.689886
Epoch 90 test accuracy 0.685904

Epoch 100 train accuracy 0.694336
Epoch 100 test accuracy 0.689552

Epoch 110 train accuracy 0.697957
Epoch 110 test accuracy 0.692869

Epoch 120 train accuracy 0.700998
Epoch 120 test accuracy 0.696628

Epoch 130 train accuracy 0.705669
Epoch 130 test accuracy 0.701271

Epoch 140 train accuracy 0.709512
Epoch 140 test accuracy 0.705694

E


Epoch 1230 train accuracy 0.733339
Epoch 1230 test accuracy 0.732338

Epoch 1240 train accuracy 0.732897
Epoch 1240 test accuracy 0.732228

Epoch 1250 train accuracy 0.733090
Epoch 1250 test accuracy 0.732449

Epoch 1260 train accuracy 0.733698
Epoch 1260 test accuracy 0.733002

Epoch 1270 train accuracy 0.733284
Epoch 1270 test accuracy 0.732891

Epoch 1280 train accuracy 0.733145
Epoch 1280 test accuracy 0.732449

Epoch 1290 train accuracy 0.733284
Epoch 1290 test accuracy 0.732117

Epoch 1300 train accuracy 0.734638
Epoch 1300 test accuracy 0.733997

Epoch 1310 train accuracy 0.734058
Epoch 1310 test accuracy 0.733002

Epoch 1320 train accuracy 0.733505
Epoch 1320 test accuracy 0.732228

Epoch 1330 train accuracy 0.733449
Epoch 1330 test accuracy 0.732228

Epoch 1340 train accuracy 0.733228
Epoch 1340 test accuracy 0.732449

Epoch 1350 train accuracy 0.733201
Epoch 1350 test accuracy 0.732228

Epoch 1360 train accuracy 0.733698
Epoch 1360 test accuracy 0.73267

Epoch 1370 train acc


Epoch 2420 train accuracy 0.735025
Epoch 2420 test accuracy 0.733554

Epoch 2430 train accuracy 0.733947
Epoch 2430 test accuracy 0.733002

Epoch 2440 train accuracy 0.732924
Epoch 2440 test accuracy 0.731012

Epoch 2450 train accuracy 0.733422
Epoch 2450 test accuracy 0.732228

Epoch 2460 train accuracy 0.733477
Epoch 2460 test accuracy 0.732228

Epoch 2470 train accuracy 0.733864
Epoch 2470 test accuracy 0.73267

Epoch 2480 train accuracy 0.733892
Epoch 2480 test accuracy 0.732781

Epoch 2490 train accuracy 0.734058
Epoch 2490 test accuracy 0.733002

Epoch 2500 train accuracy 0.733919
Epoch 2500 test accuracy 0.732228

Epoch 2510 train accuracy 0.733809
Epoch 2510 test accuracy 0.733223

Epoch 2520 train accuracy 0.733228
Epoch 2520 test accuracy 0.731786

Epoch 2530 train accuracy 0.733560
Epoch 2530 test accuracy 0.73267

Epoch 2540 train accuracy 0.734085
Epoch 2540 test accuracy 0.733333

Epoch 2550 train accuracy 0.734140
Epoch 2550 test accuracy 0.733112

Epoch 2560 train accu


Epoch 3610 train accuracy 0.733284
Epoch 3610 test accuracy 0.732228

Epoch 3620 train accuracy 0.731929
Epoch 3620 test accuracy 0.731786

Epoch 3630 train accuracy 0.733781
Epoch 3630 test accuracy 0.732228

Epoch 3640 train accuracy 0.733671
Epoch 3640 test accuracy 0.732449

Epoch 3650 train accuracy 0.733007
Epoch 3650 test accuracy 0.732117

Epoch 3660 train accuracy 0.733145
Epoch 3660 test accuracy 0.732449

Epoch 3670 train accuracy 0.733864
Epoch 3670 test accuracy 0.733223

Epoch 3680 train accuracy 0.733256
Epoch 3680 test accuracy 0.732338

Epoch 3690 train accuracy 0.733836
Epoch 3690 test accuracy 0.731122

Epoch 3700 train accuracy 0.733090
Epoch 3700 test accuracy 0.731122

Epoch 3710 train accuracy 0.733698
Epoch 3710 test accuracy 0.732449

Epoch 3720 train accuracy 0.732952
Epoch 3720 test accuracy 0.732338

Epoch 3730 train accuracy 0.732786
Epoch 3730 test accuracy 0.732559

Epoch 3740 train accuracy 0.733311
Epoch 3740 test accuracy 0.733002

Epoch 3750 train ac


Epoch 4790 train accuracy 0.734085
Epoch 4790 test accuracy 0.733554

Epoch 4800 train accuracy 0.733062
Epoch 4800 test accuracy 0.732117

Epoch 4810 train accuracy 0.734279
Epoch 4810 test accuracy 0.732781

Epoch 4820 train accuracy 0.733754
Epoch 4820 test accuracy 0.732228

Epoch 4830 train accuracy 0.732482
Epoch 4830 test accuracy 0.731564

Epoch 4840 train accuracy 0.732648
Epoch 4840 test accuracy 0.732781

Epoch 4850 train accuracy 0.734970
Epoch 4850 test accuracy 0.733665

Epoch 4860 train accuracy 0.732593
Epoch 4860 test accuracy 0.731233

Epoch 4870 train accuracy 0.734610
Epoch 4870 test accuracy 0.733002

Epoch 4880 train accuracy 0.734776
Epoch 4880 test accuracy 0.734439

Epoch 4890 train accuracy 0.734362
Epoch 4890 test accuracy 0.733665

Epoch 4900 train accuracy 0.733615
Epoch 4900 test accuracy 0.732781

Epoch 4910 train accuracy 0.734334
Epoch 4910 test accuracy 0.732781

Epoch 4920 train accuracy 0.733477
Epoch 4920 test accuracy 0.732117

Epoch 4930 train ac

In [12]:
# apply SenSR_0
weights, train_logits, test_logits  = SenSR.train_fair_nn(
    X_train, 
    y_train, 
    sensitive_directions, 
    X_test=X_test, 
    y_test=y_test, 
    n_units = [], 
    lr=0.001, 
    batch_size=5000, 
    epoch=15000, 
    verbose=True, 
    l2_reg=0., 
    subspace_epoch=10, 
    subspace_step=.1, 
    eps=None, 
    full_step=-1)

Epoch 0 train accuracy 0.591951; lambda is 2.000000
Epoch 0 test accuracy 0.594472
Epoch 10 train accuracy 0.612157; lambda is 2.000000
Epoch 10 test accuracy 0.614483
Epoch 20 train accuracy 0.639854; lambda is 2.000000
Epoch 20 test accuracy 0.645992
Epoch 30 train accuracy 0.672693; lambda is 2.000000
Epoch 30 test accuracy 0.672858
Epoch 40 train accuracy 0.697156; lambda is 2.000000
Epoch 40 test accuracy 0.698397
Epoch 50 train accuracy 0.723084; lambda is 2.000000
Epoch 50 test accuracy 0.721172
Epoch 60 train accuracy 0.739337; lambda is 2.000000
Epoch 60 test accuracy 0.738087
Epoch 70 train accuracy 0.752799; lambda is 2.000000
Epoch 70 test accuracy 0.751465
Epoch 80 train accuracy 0.760566; lambda is 2.000000
Epoch 80 test accuracy 0.759204
Epoch 90 train accuracy 0.766426; lambda is 2.000000
Epoch 90 test accuracy 0.765948
Epoch 100 train accuracy 0.769356; lambda is 2.000000
Epoch 100 test accuracy 0.769486
Epoch 110 train accuracy 0.773337; lambda is 2.000000
Epoch 110 t

Epoch 960 train accuracy 0.802913; lambda is 2.000000
Epoch 960 test accuracy 0.802101
Epoch 970 train accuracy 0.802665; lambda is 2.000000
Epoch 970 test accuracy 0.801879
Epoch 980 train accuracy 0.802554; lambda is 2.000000
Epoch 980 test accuracy 0.801879
Epoch 990 train accuracy 0.802665; lambda is 2.000000
Epoch 990 test accuracy 0.802432
Epoch 1000 train accuracy 0.802969; lambda is 2.000000
Epoch 1000 test accuracy 0.802653
Epoch 1010 train accuracy 0.802748; lambda is 2.000000
Epoch 1010 test accuracy 0.801879
Epoch 1020 train accuracy 0.802831; lambda is 2.000000
Epoch 1020 test accuracy 0.801769
Epoch 1030 train accuracy 0.802969; lambda is 2.000000
Epoch 1030 test accuracy 0.802432
Epoch 1040 train accuracy 0.802913; lambda is 2.000000
Epoch 1040 test accuracy 0.802101
Epoch 1050 train accuracy 0.803107; lambda is 2.000000
Epoch 1050 test accuracy 0.802543
Epoch 1060 train accuracy 0.803356; lambda is 2.000000
Epoch 1060 test accuracy 0.803206
Epoch 1070 train accuracy 0.8

Epoch 1910 train accuracy 0.806452; lambda is 2.000000
Epoch 1910 test accuracy 0.804422
Epoch 1920 train accuracy 0.806700; lambda is 2.000000
Epoch 1920 test accuracy 0.804643
Epoch 1930 train accuracy 0.805982; lambda is 2.000000
Epoch 1930 test accuracy 0.804201
Epoch 1940 train accuracy 0.805705; lambda is 2.000000
Epoch 1940 test accuracy 0.803648
Epoch 1950 train accuracy 0.806009; lambda is 2.000000
Epoch 1950 test accuracy 0.803648
Epoch 1960 train accuracy 0.806949; lambda is 2.000000
Epoch 1960 test accuracy 0.805196
Epoch 1970 train accuracy 0.806756; lambda is 2.000000
Epoch 1970 test accuracy 0.804975
Epoch 1980 train accuracy 0.806452; lambda is 2.000000
Epoch 1980 test accuracy 0.804754
Epoch 1990 train accuracy 0.805788; lambda is 2.000000
Epoch 1990 test accuracy 0.804091
Epoch 2000 train accuracy 0.806203; lambda is 2.000000
Epoch 2000 test accuracy 0.804422
Epoch 2010 train accuracy 0.807004; lambda is 2.000000
Epoch 2010 test accuracy 0.804865
Epoch 2020 train accu

Epoch 2860 train accuracy 0.807253; lambda is 2.000000
Epoch 2860 test accuracy 0.804865
Epoch 2870 train accuracy 0.807115; lambda is 2.000000
Epoch 2870 test accuracy 0.804754
Epoch 2880 train accuracy 0.807087; lambda is 2.000000
Epoch 2880 test accuracy 0.804533
Epoch 2890 train accuracy 0.807364; lambda is 2.000000
Epoch 2890 test accuracy 0.804975
Epoch 2900 train accuracy 0.807640; lambda is 2.000000
Epoch 2900 test accuracy 0.805638
Epoch 2910 train accuracy 0.807696; lambda is 2.000000
Epoch 2910 test accuracy 0.805638
Epoch 2920 train accuracy 0.807778; lambda is 2.000000
Epoch 2920 test accuracy 0.80597
Epoch 2930 train accuracy 0.808248; lambda is 2.000000
Epoch 2930 test accuracy 0.806081
Epoch 2940 train accuracy 0.808055; lambda is 2.000000
Epoch 2940 test accuracy 0.80586
Epoch 2950 train accuracy 0.807806; lambda is 2.000000
Epoch 2950 test accuracy 0.805638
Epoch 2960 train accuracy 0.807143; lambda is 2.000000
Epoch 2960 test accuracy 0.804754
Epoch 2970 train accura

Epoch 3810 train accuracy 0.808414; lambda is 2.000000
Epoch 3810 test accuracy 0.806302
Epoch 3820 train accuracy 0.807944; lambda is 2.000000
Epoch 3820 test accuracy 0.805749
Epoch 3830 train accuracy 0.806922; lambda is 2.000000
Epoch 3830 test accuracy 0.804312
Epoch 3840 train accuracy 0.807419; lambda is 2.000000
Epoch 3840 test accuracy 0.805638
Epoch 3850 train accuracy 0.808082; lambda is 2.000000
Epoch 3850 test accuracy 0.806191
Epoch 3860 train accuracy 0.808718; lambda is 2.000000
Epoch 3860 test accuracy 0.806965
Epoch 3870 train accuracy 0.808552; lambda is 2.000000
Epoch 3870 test accuracy 0.806523
Epoch 3880 train accuracy 0.807308; lambda is 2.000000
Epoch 3880 test accuracy 0.805307
Epoch 3890 train accuracy 0.805899; lambda is 2.000000
Epoch 3890 test accuracy 0.803648
Epoch 3900 train accuracy 0.806977; lambda is 2.000000
Epoch 3900 test accuracy 0.804975
Epoch 3910 train accuracy 0.808663; lambda is 2.000000
Epoch 3910 test accuracy 0.806412
Epoch 3920 train accu

Epoch 4740 train accuracy 0.808359; lambda is 2.000000
Epoch 4740 test accuracy 0.806191
Epoch 4750 train accuracy 0.808304; lambda is 2.000000
Epoch 4750 test accuracy 0.806191
Epoch 4760 train accuracy 0.807226; lambda is 2.000000
Epoch 4760 test accuracy 0.80586
Epoch 4770 train accuracy 0.807143; lambda is 2.000000
Epoch 4770 test accuracy 0.805528
Epoch 4780 train accuracy 0.808138; lambda is 2.000000
Epoch 4780 test accuracy 0.806412
Epoch 4790 train accuracy 0.808580; lambda is 2.000000
Epoch 4790 test accuracy 0.806965
Epoch 4800 train accuracy 0.807723; lambda is 2.000000
Epoch 4800 test accuracy 0.80597
Epoch 4810 train accuracy 0.806894; lambda is 2.000000
Epoch 4810 test accuracy 0.804865
Epoch 4820 train accuracy 0.806562; lambda is 2.000000
Epoch 4820 test accuracy 0.804754
Epoch 4830 train accuracy 0.806230; lambda is 2.000000
Epoch 4830 test accuracy 0.804422
Epoch 4840 train accuracy 0.807696; lambda is 2.000000
Epoch 4840 test accuracy 0.805417
Epoch 4850 train accura

Epoch 5670 train accuracy 0.808027; lambda is 2.000000
Epoch 5670 test accuracy 0.806412
Epoch 5680 train accuracy 0.808055; lambda is 2.000000
Epoch 5680 test accuracy 0.806081
Epoch 5690 train accuracy 0.807391; lambda is 2.000000
Epoch 5690 test accuracy 0.805638
Epoch 5700 train accuracy 0.807336; lambda is 2.000000
Epoch 5700 test accuracy 0.805086
Epoch 5710 train accuracy 0.807834; lambda is 2.000000
Epoch 5710 test accuracy 0.805638
Epoch 5720 train accuracy 0.807585; lambda is 2.000000
Epoch 5720 test accuracy 0.804865
Epoch 5730 train accuracy 0.807613; lambda is 2.000000
Epoch 5730 test accuracy 0.805528
Epoch 5740 train accuracy 0.808027; lambda is 2.000000
Epoch 5740 test accuracy 0.805528
Epoch 5750 train accuracy 0.806396; lambda is 2.000000
Epoch 5750 test accuracy 0.804091
Epoch 5760 train accuracy 0.806866; lambda is 2.000000
Epoch 5760 test accuracy 0.804865
Epoch 5770 train accuracy 0.807889; lambda is 2.000000
Epoch 5770 test accuracy 0.805086
Epoch 5780 train accu

Epoch 6600 train accuracy 0.807143; lambda is 2.000000
Epoch 6600 test accuracy 0.804422
Epoch 6610 train accuracy 0.808110; lambda is 2.000000
Epoch 6610 test accuracy 0.806081
Epoch 6620 train accuracy 0.808359; lambda is 2.000000
Epoch 6620 test accuracy 0.806191
Epoch 6630 train accuracy 0.807585; lambda is 2.000000
Epoch 6630 test accuracy 0.805528
Epoch 6640 train accuracy 0.807419; lambda is 2.000000
Epoch 6640 test accuracy 0.805417
Epoch 6650 train accuracy 0.807474; lambda is 2.000000
Epoch 6650 test accuracy 0.805528
Epoch 6660 train accuracy 0.807530; lambda is 2.000000
Epoch 6660 test accuracy 0.805417
Epoch 6670 train accuracy 0.807696; lambda is 2.000000
Epoch 6670 test accuracy 0.805528
Epoch 6680 train accuracy 0.808359; lambda is 2.000000
Epoch 6680 test accuracy 0.807297
Epoch 6690 train accuracy 0.808331; lambda is 2.000000
Epoch 6690 test accuracy 0.806744
Epoch 6700 train accuracy 0.807308; lambda is 2.000000
Epoch 6700 test accuracy 0.804865
Epoch 6710 train accu

Epoch 7520 train accuracy 0.808829; lambda is 2.000000
Epoch 7520 test accuracy 0.807739
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.398301; subspace 0.866774; full 0.866774
Epoch 7530 train accuracy 0.808635; lambda is 2.000000
Epoch 7530 test accuracy 0.808071
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.387676; subspace 0.832573; full 0.832573
Epoch 7540 train accuracy 0.807861; lambda is 2.000000
Epoch 7540 test accuracy 0.806744
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.405326; subspace 0.832988; full 0.832988
Epoch 7550 train accuracy 0.808635; lambda is 2.000000
Epoch 7550 test accuracy 0.806633
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.404415; subspace 0.807641; full 0.807641
Epoch 7560 train accuracy 0.808221; lambda is 2.000000
Epoch 7560 test accuracy 0.80597
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.391372; subspace 0.762855; full 0.762855
Epoch 

Epoch 7940 train accuracy 0.782984; lambda is 2.000000
Epoch 7940 test accuracy 0.779878
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.465308; subspace 0.465420; full 0.465420
Epoch 7950 train accuracy 0.783896; lambda is 2.000000
Epoch 7950 test accuracy 0.780431
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.471151; subspace 0.471264; full 0.471264
Epoch 7960 train accuracy 0.784228; lambda is 2.000000
Epoch 7960 test accuracy 0.78021
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.449065; subspace 0.449175; full 0.449175
Epoch 7970 train accuracy 0.783979; lambda is 2.000000
Epoch 7970 test accuracy 0.779989
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.456607; subspace 0.456744; full 0.456744
Epoch 7980 train accuracy 0.783260; lambda is 2.000000
Epoch 7980 test accuracy 0.779989
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.432273; subspace 0.432371; full 0.432371
Epoch 

Epoch 8360 train accuracy 0.789120; lambda is 2.000000
Epoch 8360 test accuracy 0.783085
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.449412; subspace 0.449495; full 0.449495
Epoch 8370 train accuracy 0.789175; lambda is 2.000000
Epoch 8370 test accuracy 0.783416
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.440781; subspace 0.440864; full 0.440864
Epoch 8380 train accuracy 0.788982; lambda is 2.000000
Epoch 8380 test accuracy 0.783527
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.438150; subspace 0.438234; full 0.438234
Epoch 8390 train accuracy 0.788899; lambda is 2.000000
Epoch 8390 test accuracy 0.783858
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.429501; subspace 0.429593; full 0.429593
Epoch 8400 train accuracy 0.789037; lambda is 2.000000
Epoch 8400 test accuracy 0.78419
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.448537; subspace 0.448608; full 0.448608
Epoch 

Epoch 8780 train accuracy 0.792161; lambda is 2.000000
Epoch 8780 test accuracy 0.789939
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.432912; subspace 0.432940; full 0.432940
Epoch 8790 train accuracy 0.792133; lambda is 2.000000
Epoch 8790 test accuracy 0.789718
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.433434; subspace 0.433487; full 0.433487
Epoch 8800 train accuracy 0.792548; lambda is 2.000000
Epoch 8800 test accuracy 0.790492
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.437198; subspace 0.437249; full 0.437249
Epoch 8810 train accuracy 0.792327; lambda is 2.000000
Epoch 8810 test accuracy 0.789829
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.446145; subspace 0.446193; full 0.446193
Epoch 8820 train accuracy 0.792492; lambda is 2.000000
Epoch 8820 test accuracy 0.79016
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.436007; subspace 0.436074; full 0.436074
Epoch 

Epoch 9200 train accuracy 0.793792; lambda is 2.000000
Epoch 9200 test accuracy 0.793256
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.432997; subspace 0.433044; full 0.433044
Epoch 9210 train accuracy 0.794289; lambda is 2.000000
Epoch 9210 test accuracy 0.793919
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.423560; subspace 0.423587; full 0.423587
Epoch 9220 train accuracy 0.794953; lambda is 2.000000
Epoch 9220 test accuracy 0.794362
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.436946; subspace 0.437025; full 0.437025
Epoch 9230 train accuracy 0.795920; lambda is 2.000000
Epoch 9230 test accuracy 0.794362
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.431908; subspace 0.431947; full 0.431947
Epoch 9240 train accuracy 0.796031; lambda is 2.000000
Epoch 9240 test accuracy 0.79414
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.427820; subspace 0.427884; full 0.427884
Epoch 

Epoch 9620 train accuracy 0.797496; lambda is 2.000000
Epoch 9620 test accuracy 0.794914
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.431021; subspace 0.431051; full 0.431051
Epoch 9630 train accuracy 0.797855; lambda is 2.000000
Epoch 9630 test accuracy 0.794804
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.434372; subspace 0.434421; full 0.434421
Epoch 9640 train accuracy 0.797496; lambda is 2.000000
Epoch 9640 test accuracy 0.794583
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.436929; subspace 0.436945; full 0.436945
Epoch 9650 train accuracy 0.797661; lambda is 2.000000
Epoch 9650 test accuracy 0.794693
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.430822; subspace 0.430866; full 0.430866
Epoch 9660 train accuracy 0.797192; lambda is 2.000000
Epoch 9660 test accuracy 0.793809
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.438269; subspace 0.438314; full 0.438314
Epoch

Epoch 10040 train accuracy 0.797772; lambda is 2.000000
Epoch 10040 test accuracy 0.794362
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.424464; subspace 0.424521; full 0.424521
Epoch 10050 train accuracy 0.797910; lambda is 2.000000
Epoch 10050 test accuracy 0.794362
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.424722; subspace 0.424771; full 0.424771
Epoch 10060 train accuracy 0.798574; lambda is 2.000000
Epoch 10060 test accuracy 0.794914
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.429268; subspace 0.429312; full 0.429312
Epoch 10070 train accuracy 0.798601; lambda is 2.000000
Epoch 10070 test accuracy 0.794914
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.420350; subspace 0.420370; full 0.420370
Epoch 10080 train accuracy 0.798297; lambda is 2.000000
Epoch 10080 test accuracy 0.794693
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.428564; subspace 0.428589; full 0.42

Epoch 10450 train accuracy 0.799348; lambda is 2.000000
Epoch 10450 test accuracy 0.795688
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.435992; subspace 0.436025; full 0.436025
Epoch 10460 train accuracy 0.799541; lambda is 2.000000
Epoch 10460 test accuracy 0.795688
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.430533; subspace 0.430581; full 0.430581
Epoch 10470 train accuracy 0.799320; lambda is 2.000000
Epoch 10470 test accuracy 0.795799
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.432900; subspace 0.432950; full 0.432950
Epoch 10480 train accuracy 0.799237; lambda is 2.000000
Epoch 10480 test accuracy 0.795135
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.426628; subspace 0.426685; full 0.426685
Epoch 10490 train accuracy 0.800592; lambda is 2.000000
Epoch 10490 test accuracy 0.795578
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.428498; subspace 0.428577; full 0.42

Epoch 10860 train accuracy 0.800509; lambda is 2.000000
Epoch 10860 test accuracy 0.795688
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.423742; subspace 0.423795; full 0.423795
Epoch 10870 train accuracy 0.799735; lambda is 2.000000
Epoch 10870 test accuracy 0.795688
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.430403; subspace 0.430454; full 0.430454
Epoch 10880 train accuracy 0.800398; lambda is 2.000000
Epoch 10880 test accuracy 0.795578
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.439519; subspace 0.439554; full 0.439554
Epoch 10890 train accuracy 0.800592; lambda is 2.000000
Epoch 10890 test accuracy 0.795578
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.437445; subspace 0.437478; full 0.437478
Epoch 10900 train accuracy 0.799099; lambda is 2.000000
Epoch 10900 test accuracy 0.795357
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.429896; subspace 0.429966; full 0.42

Epoch 11270 train accuracy 0.800094; lambda is 2.000000
Epoch 11270 test accuracy 0.795578
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.418313; subspace 0.418365; full 0.418365
Epoch 11280 train accuracy 0.799873; lambda is 2.000000
Epoch 11280 test accuracy 0.795357
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.439921; subspace 0.439978; full 0.439978
Epoch 11290 train accuracy 0.800149; lambda is 2.000000
Epoch 11290 test accuracy 0.795467
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.440088; subspace 0.440110; full 0.440110
Epoch 11300 train accuracy 0.799458; lambda is 2.000000
Epoch 11300 test accuracy 0.795578
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.434910; subspace 0.434958; full 0.434958
Epoch 11310 train accuracy 0.798933; lambda is 2.000000
Epoch 11310 test accuracy 0.795357
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.426814; subspace 0.426842; full 0.42

Epoch 11680 train accuracy 0.801006; lambda is 2.000000
Epoch 11680 test accuracy 0.796352
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.435350; subspace 0.435354; full 0.435354
Epoch 11690 train accuracy 0.800149; lambda is 2.000000
Epoch 11690 test accuracy 0.795688
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.438087; subspace 0.438213; full 0.438213
Epoch 11700 train accuracy 0.800370; lambda is 2.000000
Epoch 11700 test accuracy 0.796241
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.431470; subspace 0.431492; full 0.431492
Epoch 11710 train accuracy 0.800453; lambda is 2.000000
Epoch 11710 test accuracy 0.79613
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.429668; subspace 0.429717; full 0.429717
Epoch 11720 train accuracy 0.800592; lambda is 2.000000
Epoch 11720 test accuracy 0.79613
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.437613; subspace 0.437644; full 0.4376

Epoch 12090 train accuracy 0.800094; lambda is 2.000000
Epoch 12090 test accuracy 0.796241
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.424363; subspace 0.424412; full 0.424412
Epoch 12100 train accuracy 0.800840; lambda is 2.000000
Epoch 12100 test accuracy 0.796462
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.435725; subspace 0.435745; full 0.435745
Epoch 12110 train accuracy 0.801006; lambda is 2.000000
Epoch 12110 test accuracy 0.796683
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.425741; subspace 0.425788; full 0.425788
Epoch 12120 train accuracy 0.799873; lambda is 2.000000
Epoch 12120 test accuracy 0.79613
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.432666; subspace 0.432704; full 0.432704
Epoch 12130 train accuracy 0.799569; lambda is 2.000000
Epoch 12130 test accuracy 0.79602
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.439677; subspace 0.439690; full 0.4396

Epoch 12500 train accuracy 0.800066; lambda is 2.000000
Epoch 12500 test accuracy 0.797789
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.424877; subspace 0.424895; full 0.424895
Epoch 12510 train accuracy 0.799569; lambda is 2.000000
Epoch 12510 test accuracy 0.796904
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.433021; subspace 0.433039; full 0.433039
Epoch 12520 train accuracy 0.800785; lambda is 2.000000
Epoch 12520 test accuracy 0.797789
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.443605; subspace 0.443645; full 0.443645
Epoch 12530 train accuracy 0.801697; lambda is 2.000000
Epoch 12530 test accuracy 0.797678
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.440261; subspace 0.440303; full 0.440303
Epoch 12540 train accuracy 0.800177; lambda is 2.000000
Epoch 12540 test accuracy 0.797125
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.438239; subspace 0.438276; full 0.43

Epoch 12910 train accuracy 0.800094; lambda is 2.000000
Epoch 12910 test accuracy 0.796573
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.432944; subspace 0.433021; full 0.433021
Epoch 12920 train accuracy 0.799845; lambda is 2.000000
Epoch 12920 test accuracy 0.796573
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.447388; subspace 0.447417; full 0.447417
Epoch 12930 train accuracy 0.799762; lambda is 2.000000
Epoch 12930 test accuracy 0.796352
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.416196; subspace 0.416217; full 0.416217
Epoch 12940 train accuracy 0.800260; lambda is 2.000000
Epoch 12940 test accuracy 0.796683
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.431044; subspace 0.431105; full 0.431105
Epoch 12950 train accuracy 0.801227; lambda is 2.000000
Epoch 12950 test accuracy 0.797678
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.428965; subspace 0.428992; full 0.42

Epoch 13320 train accuracy 0.799348; lambda is 2.000000
Epoch 13320 test accuracy 0.79613
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.449172; subspace 0.449206; full 0.449206
Epoch 13330 train accuracy 0.799762; lambda is 2.000000
Epoch 13330 test accuracy 0.79613
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.439811; subspace 0.439857; full 0.439857
Epoch 13340 train accuracy 0.800177; lambda is 2.000000
Epoch 13340 test accuracy 0.796683
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.426031; subspace 0.426063; full 0.426063
Epoch 13350 train accuracy 0.799928; lambda is 2.000000
Epoch 13350 test accuracy 0.796683
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.448991; subspace 0.449042; full 0.449042
Epoch 13360 train accuracy 0.800149; lambda is 2.000000
Epoch 13360 test accuracy 0.796904
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.421534; subspace 0.421574; full 0.4215

Epoch 13730 train accuracy 0.799541; lambda is 2.000000
Epoch 13730 test accuracy 0.79613
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.424236; subspace 0.424258; full 0.424258
Epoch 13740 train accuracy 0.800343; lambda is 2.000000
Epoch 13740 test accuracy 0.797347
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.436952; subspace 0.437005; full 0.437005
Epoch 13750 train accuracy 0.800370; lambda is 2.000000
Epoch 13750 test accuracy 0.797347
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.441808; subspace 0.441820; full 0.441820
Epoch 13760 train accuracy 0.801200; lambda is 2.000000
Epoch 13760 test accuracy 0.797789
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.433878; subspace 0.433890; full 0.433890
Epoch 13770 train accuracy 0.800702; lambda is 2.000000
Epoch 13770 test accuracy 0.797457
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.437587; subspace 0.437633; full 0.437

Epoch 14140 train accuracy 0.799707; lambda is 2.000000
Epoch 14140 test accuracy 0.796683
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.422496; subspace 0.422516; full 0.422516
Epoch 14150 train accuracy 0.799818; lambda is 2.000000
Epoch 14150 test accuracy 0.796573
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.433313; subspace 0.433340; full 0.433340
Epoch 14160 train accuracy 0.800094; lambda is 2.000000
Epoch 14160 test accuracy 0.796573
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.448082; subspace 0.448095; full 0.448095
Epoch 14170 train accuracy 0.800011; lambda is 2.000000
Epoch 14170 test accuracy 0.797015
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.431093; subspace 0.431180; full 0.431180
Epoch 14180 train accuracy 0.799514; lambda is 2.000000
Epoch 14180 test accuracy 0.796904
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.429309; subspace 0.429319; full 0.42

Epoch 14550 train accuracy 0.799458; lambda is 2.000000
Epoch 14550 test accuracy 0.795799
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.423278; subspace 0.423306; full 0.423306
Epoch 14560 train accuracy 0.799265; lambda is 2.000000
Epoch 14560 test accuracy 0.795578
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.423559; subspace 0.423601; full 0.423601
Epoch 14570 train accuracy 0.799237; lambda is 2.000000
Epoch 14570 test accuracy 0.795467
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.431735; subspace 0.431781; full 0.431781
Epoch 14580 train accuracy 0.799596; lambda is 2.000000
Epoch 14580 test accuracy 0.79613
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.425087; subspace 0.425104; full 0.425104
Epoch 14590 train accuracy 0.800232; lambda is 2.000000
Epoch 14590 test accuracy 0.797236
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.437555; subspace 0.437601; full 0.437

Epoch 14960 train accuracy 0.800813; lambda is 2.000000
Epoch 14960 test accuracy 0.797678
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.425050; subspace 0.425083; full 0.425083
Epoch 14970 train accuracy 0.800232; lambda is 2.000000
Epoch 14970 test accuracy 0.796904
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.447379; subspace 0.447432; full 0.447432
Epoch 14980 train accuracy 0.798574; lambda is 2.000000
Epoch 14980 test accuracy 0.795025
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.431755; subspace 0.431763; full 0.431763
Epoch 14990 train accuracy 0.798740; lambda is 2.000000
Epoch 14990 test accuracy 0.795467
FAILED attacks: subspace 0; full 0; Nans after attack 0
Loss clean 0.422409; subspace 0.422587; full 0.422587

Final train accuracy 0.800343
Final test accuracy 0.796573


In [13]:
dataset_debiasing_train = dataset_orig_train.copy()
dataset_debiasing_train.labels = np.argmax(train_logits,axis = 1)

dataset_debiasing_test = dataset_orig_test.copy()
dataset_debiasing_test.labels = np.argmax(test_logits,axis = 1)

In [14]:
# Metrics for the dataset from plain model (without debiasing)
# parameters from paper but with more epochs
privileged_groups = [{'sex': 1}]
unprivileged_groups = [{'sex': 0}]

display(Markdown("#### Plain model - without debiasing - dataset metrics"))
print("Train set: Difference in mean outcomes between unprivileged and privileged groups = %f" % metric_dataset_nodebiasing_train.mean_difference())
print("Test set: Difference in mean outcomes between unprivileged and privileged groups = %f" % metric_dataset_nodebiasing_test.mean_difference())

# Metrics for the dataset from model with debiasing
display(Markdown("#### Model - with debiasing - dataset metrics"))
metric_dataset_debiasing_train = BinaryLabelDatasetMetric(dataset_debiasing_train, 
                                             unprivileged_groups=unprivileged_groups,
                                             privileged_groups=privileged_groups)

print("Train set: Difference in mean outcomes between unprivileged and privileged groups = %f" % metric_dataset_debiasing_train.mean_difference())

metric_dataset_debiasing_test = BinaryLabelDatasetMetric(dataset_debiasing_test, 
                                             unprivileged_groups=unprivileged_groups,
                                             privileged_groups=privileged_groups)

print("Test set: Difference in mean outcomes between unprivileged and privileged groups = %f" % metric_dataset_debiasing_test.mean_difference())



display(Markdown("#### Plain model - without debiasing - classification metrics"))
print("Test set: Classification accuracy = %f" % classified_metric_nodebiasing_test.accuracy())

privileged_groups = [{'sex': 1}]
unprivileged_groups = [{'sex': 0}]

classified_metric_nodebiasing_test = ClassificationMetric(dataset_orig_test, 
                                                 dataset_nodebiasing_test,
                                                 unprivileged_groups=unprivileged_groups,
                                                 privileged_groups=privileged_groups)

TPR = classified_metric_nodebiasing_test.true_positive_rate()
TNR = classified_metric_nodebiasing_test.true_negative_rate()
bal_acc_nodebiasing_test = 0.5*(TPR+TNR)

gap_rms, max_gap = compute_gap_RMS(classified_metric_nodebiasing_test)
print("Test set: gap rms sex = %f" % gap_rms)
print("Test set: max gap rms sex = %f" % max_gap)
print("Test set: Balanced TPR = %f" % bal_acc_nodebiasing_test)

privileged_groups = [{'race': 1}]
unprivileged_groups = [{'race': 0}]

classified_metric_nodebiasing_test = ClassificationMetric(dataset_orig_test, 
                                                 dataset_nodebiasing_test,
                                                 unprivileged_groups=unprivileged_groups,
                                                 privileged_groups=privileged_groups)

gap_rms, max_gap = compute_gap_RMS(classified_metric_nodebiasing_test)
print("Test set: gap rms race = %f" % gap_rms)
print("Test set: max gap rms race = %f" % max_gap)




display(Markdown("#### Model - with debiasing - classification metrics"))
privileged_groups = [{'sex': 1}]
unprivileged_groups = [{'sex': 0}]

classified_metric_debiasing_test = ClassificationMetric(dataset_orig_test, 
                                                 dataset_debiasing_test,
                                                 unprivileged_groups=unprivileged_groups,
                                                 privileged_groups=privileged_groups)
print("Test set: Classification accuracy = %f" % classified_metric_debiasing_test.accuracy())
TPR = classified_metric_debiasing_test.true_positive_rate()
TNR = classified_metric_debiasing_test.true_negative_rate()
bal_acc_debiasing_test = 0.5*(TPR+TNR)

gap_rms, max_gap = compute_gap_RMS(classified_metric_debiasing_test)
print("Test set: gap rms sex = %f" % gap_rms)
print("Test set: max gap rms sex = %f" % max_gap)
print("Test set: Balanced TPR = %f" % bal_acc_debiasing_test)

privileged_groups = [{'race': 1}]
unprivileged_groups = [{'race': 0}]
classified_metric_debiasing_test = ClassificationMetric(dataset_orig_test, 
                                                 dataset_debiasing_test,
                                                 unprivileged_groups=unprivileged_groups,
                                                 privileged_groups=privileged_groups)
gap_rms, max_gap = compute_gap_RMS(classified_metric_debiasing_test)
print("Test set: gap rms race = %f" % gap_rms)
print("Test set: max gap rms race = %f" % max_gap)

#### Plain model - without debiasing - dataset metrics

Train set: Difference in mean outcomes between unprivileged and privileged groups = -0.300755
Test set: Difference in mean outcomes between unprivileged and privileged groups = -0.302490


#### Model - with debiasing - dataset metrics

Train set: Difference in mean outcomes between unprivileged and privileged groups = -0.146503
Test set: Difference in mean outcomes between unprivileged and privileged groups = -0.140729


#### Plain model - without debiasing - classification metrics

Test set: Classification accuracy = 0.807629
Test set: gap rms sex = 0.154314
Test set: max gap rms sex = 0.196815
Test set: Balanced TPR = 0.814855
Test set: gap rms race = 0.057446
Test set: max gap rms race = 0.078686


#### Model - with debiasing - classification metrics

Test set: Classification accuracy = 0.796573
Test set: gap rms sex = 0.049541
Test set: max gap rms sex = 0.060443
Test set: Balanced TPR = 0.792238
Test set: gap rms race = 0.033212
Test set: max gap rms race = 0.046837


#### SenSR

In [14]:
#get sensitive directions
weights, train_logits, test_logits  = SenSR.train_nn(X_train, y_sex_train, X_test = X_test, y_test = y_sex_test, n_units=[], l2_reg=1., batch_size=5000, epoch=5000, verbose=True)

sensitive_directions = []
sensitive_directions.append(weights[0].T)

sensitive_directions = np.vstack(sensitive_directions)
tSVD = TruncatedSVD(n_components=2)
tSVD.fit(sensitive_directions)
sensitive_directions = tSVD.components_


Epoch 0 train accuracy 0.354701
Epoch 0 test accuracy 0.352018

Epoch 10 train accuracy 0.356138
Epoch 10 test accuracy 0.355998

Epoch 20 train accuracy 0.356912
Epoch 20 test accuracy 0.357656

Epoch 30 train accuracy 0.358460
Epoch 30 test accuracy 0.358209

Epoch 40 train accuracy 0.359731
Epoch 40 test accuracy 0.359536

Epoch 50 train accuracy 0.362136
Epoch 50 test accuracy 0.361747

Epoch 60 train accuracy 0.364265
Epoch 60 test accuracy 0.364179

Epoch 70 train accuracy 0.367582
Epoch 70 test accuracy 0.368049

Epoch 80 train accuracy 0.370816
Epoch 80 test accuracy 0.371144

Epoch 90 train accuracy 0.374077
Epoch 90 test accuracy 0.375567

Epoch 100 train accuracy 0.378804
Epoch 100 test accuracy 0.378662

Epoch 110 train accuracy 0.383946
Epoch 110 test accuracy 0.384411

Epoch 120 train accuracy 0.388728
Epoch 120 test accuracy 0.388944

Epoch 130 train accuracy 0.394947
Epoch 130 test accuracy 0.393809

Epoch 140 train accuracy 0.402604
Epoch 140 test accuracy 0.400663

E

In [15]:
# apply SenSR

weights, train_logits, test_logits  = SenSR.train_fair_nn(
    X_train, 
    y_train, 
    sensitive_directions, 
    X_test=X_test, 
    y_test=y_test, 
    n_units = [], 
    lr=0.001, 
    batch_size=5000, 
    epoch=15000, 
    verbose=True, 
    l2_reg=0., 
    lamb_init=2., 
    subspace_epoch=15, 
    subspace_step=1, 
    eps=.001, 
    full_step=.0001, 
    full_epoch=25)

Epoch 0 train accuracy 0.503884; lambda is 2.000000
Epoch 0 test accuracy 0.50901
Epoch 10 train accuracy 0.535368; lambda is 2.000000
Epoch 10 test accuracy 0.537092
Epoch 20 train accuracy 0.562678; lambda is 2.000000
Epoch 20 test accuracy 0.56628
Epoch 30 train accuracy 0.588744; lambda is 2.000000
Epoch 30 test accuracy 0.593256
Epoch 40 train accuracy 0.613014; lambda is 2.000000
Epoch 40 test accuracy 0.616915
Epoch 50 train accuracy 0.638583; lambda is 2.000000
Epoch 50 test accuracy 0.641238
Epoch 60 train accuracy 0.664981; lambda is 2.000000
Epoch 60 test accuracy 0.667993
Epoch 70 train accuracy 0.685297; lambda is 2.000000
Epoch 70 test accuracy 0.685462
Epoch 80 train accuracy 0.699312; lambda is 2.000000
Epoch 80 test accuracy 0.701935
Epoch 90 train accuracy 0.710590; lambda is 2.000000
Epoch 90 test accuracy 0.713433
Epoch 100 train accuracy 0.718633; lambda is 2.000000
Epoch 100 test accuracy 0.720619
Epoch 110 train accuracy 0.724936; lambda is 2.000000
Epoch 110 tes

In [15]:
dataset_debiasing_train = dataset_orig_train.copy()
dataset_debiasing_train.labels = np.argmax(train_logits,axis = 1)

dataset_debiasing_test = dataset_orig_test.copy()
dataset_debiasing_test.labels = np.argmax(test_logits,axis = 1)

In [16]:
# Metrics for the dataset from plain model (without debiasing)
privileged_groups = [{'sex': 1}]
unprivileged_groups = [{'sex': 0}]

display(Markdown("#### Plain model - without debiasing - dataset metrics"))
print("Train set: Difference in mean outcomes between unprivileged and privileged groups = %f" % metric_dataset_nodebiasing_train.mean_difference())
print("Test set: Difference in mean outcomes between unprivileged and privileged groups = %f" % metric_dataset_nodebiasing_test.mean_difference())

# Metrics for the dataset from model with debiasing
display(Markdown("#### Model - with debiasing - dataset metrics"))
metric_dataset_debiasing_train = BinaryLabelDatasetMetric(dataset_debiasing_train, 
                                             unprivileged_groups=unprivileged_groups,
                                             privileged_groups=privileged_groups)

print("Train set: Difference in mean outcomes between unprivileged and privileged groups = %f" % metric_dataset_debiasing_train.mean_difference())

metric_dataset_debiasing_test = BinaryLabelDatasetMetric(dataset_debiasing_test, 
                                             unprivileged_groups=unprivileged_groups,
                                             privileged_groups=privileged_groups)

print("Test set: Difference in mean outcomes between unprivileged and privileged groups = %f" % metric_dataset_debiasing_test.mean_difference())



display(Markdown("#### Plain model - without debiasing - classification metrics"))
print("Test set: Classification accuracy = %f" % classified_metric_nodebiasing_test.accuracy())
privileged_groups = [{'sex': 1}]
unprivileged_groups = [{'sex': 0}]

classified_metric_nodebiasing_test = ClassificationMetric(dataset_orig_test, 
                                                 dataset_nodebiasing_test,
                                                 unprivileged_groups=unprivileged_groups,
                                                 privileged_groups=privileged_groups)

TPR = classified_metric_nodebiasing_test.true_positive_rate()
TNR = classified_metric_nodebiasing_test.true_negative_rate()
bal_acc_nodebiasing_test = 0.5*(TPR+TNR)

gap_rms, max_gap = compute_gap_RMS(classified_metric_nodebiasing_test)
print("Test set: gap rms sex = %f" % gap_rms)
print("Test set: max gap rms sex = %f" % max_gap)
print("Test set: Balanced TPR = %f" % bal_acc_nodebiasing_test)

privileged_groups = [{'race': 1}]
unprivileged_groups = [{'race': 0}]

classified_metric_nodebiasing_test = ClassificationMetric(dataset_orig_test, 
                                                 dataset_nodebiasing_test,
                                                 unprivileged_groups=unprivileged_groups,
                                                 privileged_groups=privileged_groups)

gap_rms, max_gap = compute_gap_RMS(classified_metric_nodebiasing_test)
print("Test set: gap rms race = %f" % gap_rms)
print("Test set: max gap rms race = %f" % max_gap)



display(Markdown("#### Model - with debiasing - classification metrics"))
classified_metric_debiasing_test = ClassificationMetric(dataset_orig_test, 
                                                 dataset_debiasing_test,
                                                 unprivileged_groups=unprivileged_groups,
                                                 privileged_groups=privileged_groups)
print("Test set: Classification accuracy = %f" % classified_metric_debiasing_test.accuracy())
TPR = classified_metric_debiasing_test.true_positive_rate()
TNR = classified_metric_debiasing_test.true_negative_rate()
bal_acc_debiasing_test = 0.5*(TPR+TNR)

gap_rms, max_gap = compute_gap_RMS(classified_metric_debiasing_test)
print("Test set: gap rms sex = %f" % gap_rms)
print("Test set: max gap rms sex = %f" % max_gap)
print("Test set: Balanced TPR = %f" % bal_acc_debiasing_test)

privileged_groups = [{'race': 1}]
unprivileged_groups = [{'race': 0}]
classified_metric_debiasing_test = ClassificationMetric(dataset_orig_test, 
                                                 dataset_debiasing_test,
                                                 unprivileged_groups=unprivileged_groups,
                                                 privileged_groups=privileged_groups)

print("Test set: gap rms race = %f" % gap_rms)
print("Test set: max gap rms race = %f" % max_gap)

#### Plain model - without debiasing - dataset metrics

Train set: Difference in mean outcomes between unprivileged and privileged groups = -0.300755
Test set: Difference in mean outcomes between unprivileged and privileged groups = -0.302490


#### Model - with debiasing - dataset metrics

Train set: Difference in mean outcomes between unprivileged and privileged groups = -0.146503
Test set: Difference in mean outcomes between unprivileged and privileged groups = -0.140729


#### Plain model - without debiasing - classification metrics

Test set: Classification accuracy = 0.807629
Test set: gap rms sex = 0.154314
Test set: max gap rms sex = 0.196815
Test set: Balanced TPR = 0.814855
Test set: gap rms race = 0.057446
Test set: max gap rms race = 0.078686


#### Model - with debiasing - classification metrics

Test set: Classification accuracy = 0.796573
Test set: gap rms sex = 0.033212
Test set: max gap rms sex = 0.046837
Test set: Balanced TPR = 0.792238
Test set: gap rms race = 0.033212
Test set: max gap rms race = 0.046837
