# Reproduction of Adult dataset experiments

In this notebook we reproduce the results from Table 2 of the DECAF paper. We compare several ways of generating debiased data with the DECAF model against synthetic data generated by the benchmark models GAN, WGAN-GP and FairGAN. As described in the paper, we run all experiments (as implemented in this notebook) 10 times and average the results.

In [1]:
from sklearn.metrics import precision_score, recall_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

from data import load_adult, preprocess_adult
from metrics import DP, FTU
from train import train_decaf, train_fairgan, train_vanilla_gan, train_wgan_gp


## Loading data

In [2]:
dataset = load_adult()
dataset.head()

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
0,39,State-gov,77516,Bachelors,13,Never-married,Adm-clerical,Not-in-family,White,Male,2174,0,40,United-States,<=50K
1,50,Self-emp-not-inc,83311,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,White,Male,0,0,13,United-States,<=50K
2,38,Private,215646,HS-grad,9,Divorced,Handlers-cleaners,Not-in-family,White,Male,0,0,40,United-States,<=50K
3,53,Private,234721,11th,7,Married-civ-spouse,Handlers-cleaners,Husband,Black,Male,0,0,40,United-States,<=50K
4,28,Private,338409,Bachelors,13,Married-civ-spouse,Prof-specialty,Wife,Black,Female,0,0,40,Cuba,<=50K


Next, preprocess the data to make it suitable for training models.

In [3]:
dataset = preprocess_adult(dataset)
dataset.head()

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
0,0.30137,0.833333,0.043338,0.0,0.8,0.333333,0.615385,0.6,0.0,1.0,0.02174,0.0,0.397959,0.0,1.0
1,0.452055,0.166667,0.047277,0.0,0.8,0.0,0.307692,0.4,0.0,1.0,0.0,0.0,0.122449,0.0,1.0
2,0.287671,0.0,0.137244,0.2,0.533333,0.166667,0.461538,0.6,0.0,1.0,0.0,0.0,0.397959,0.0,1.0
3,0.493151,0.0,0.150212,0.133333,0.4,0.0,0.461538,0.4,1.0,1.0,0.0,0.0,0.397959,0.0,1.0
4,0.150685,0.0,0.220703,0.0,0.8,0.0,0.384615,0.0,1.0,0.0,0.0,0.0,0.397959,0.3,1.0
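
`preprocess_adult` comes from the local `data` module and its code is not shown here. The output above is consistent with every column being rescaled to the range [0, 1], with categorical columns first mapped to numeric codes. The following is a minimal sketch of that idea, not the actual implementation: `minmax_encode` is a hypothetical helper, and the category order it uses (order of first appearance) is an assumption that will generally differ from the mapping used in the notebook.

```python
import pandas as pd

def minmax_encode(df):
    """Integer-code non-numeric columns, then min-max scale every column to [0, 1]."""
    out = df.copy()
    for col in out.columns:
        if not pd.api.types.is_numeric_dtype(out[col]):
            # Assumption: codes follow order of first appearance
            out[col] = pd.factorize(out[col])[0]
        out[col] = out[col].astype(float)
        span = out[col].max() - out[col].min()
        # Constant columns are mapped to 0 to avoid dividing by zero
        out[col] = (out[col] - out[col].min()) / span if span > 0 else 0.0
    return out

# Toy frame with one numeric and one categorical column
toy = pd.DataFrame({'age': [39, 50, 28],
                    'sex': ['Male', 'Male', 'Female']})
encoded = minmax_encode(toy)
print(encoded)
```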


Split the dataset into train and test folds. The test fold contains 2000 rows, and the split is stratified on `income`.

In [4]:
# Split data into train and testing sets
dataset_train, dataset_test = train_test_split(dataset, test_size=2000,
                                               stratify=dataset['income'])

print('Size of train set:', len(dataset_train))
print('Size of test set:', len(dataset_test))

Size of train set: 28162
Size of test set: 2000


### Defining the DAG

We need to define a DAG that captures the biases in the dataset. As described in the DECAF paper, a causal discovery algorithm would normally be used for this. In this notebook we simply copy the DAG described in the Zhang et al. paper, which is also the one used in the DECAF paper.

In [5]:
# Define DAG for Adult dataset
dag = [
    # Edges from race
    ['race', 'occupation'],
    ['race', 'income'],
    ['race', 'hours-per-week'],
    ['race', 'education'],
    ['race', 'marital-status'],

    # Edges from age
    ['age', 'occupation'],
    ['age', 'hours-per-week'],
    ['age', 'income'],
    ['age', 'workclass'],
    ['age', 'marital-status'],
    ['age', 'education'],
    ['age', 'relationship'],
    
    # Edges from sex
    ['sex', 'occupation'],
    ['sex', 'marital-status'],
    ['sex', 'income'],
    ['sex', 'workclass'],
    ['sex', 'education'],
    ['sex', 'relationship'],
    
    # Edges from native country
    ['native-country', 'marital-status'],
    ['native-country', 'hours-per-week'],
    ['native-country', 'education'],
    ['native-country', 'workclass'],
    ['native-country', 'income'],
    ['native-country', 'relationship'],
    
    # Edges from marital status
    ['marital-status', 'occupation'],
    ['marital-status', 'hours-per-week'],
    ['marital-status', 'income'],
    ['marital-status', 'workclass'],
    ['marital-status', 'relationship'],
    ['marital-status', 'education'],
    
    # Edges from education
    ['education', 'occupation'],
    ['education', 'hours-per-week'],
    ['education', 'income'],
    ['education', 'workclass'],
    ['education', 'relationship'],
    
    # All remaining edges
    ['occupation', 'income'],
    ['hours-per-week', 'income'],
    ['workclass', 'income'],
    ['relationship', 'income'],
]

def dag_to_idx(df, dag):
    """Convert columns in a DAG to the corresponding indices."""

    dag_idx = []
    for edge in dag:
        dag_idx.append([df.columns.get_loc(edge[0]), df.columns.get_loc(edge[1])])

    return dag_idx

# Convert the DAG to one that can be provided to the DECAF model
dag_seed = dag_to_idx(dataset, dag)
print(dag_seed)

[[8, 6], [8, 14], [8, 12], [8, 3], [8, 5], [0, 6], [0, 12], [0, 14], [0, 1], [0, 5], [0, 3], [0, 7], [9, 6], [9, 5], [9, 14], [9, 1], [9, 3], [9, 7], [13, 5], [13, 12], [13, 3], [13, 1], [13, 14], [13, 7], [5, 6], [5, 12], [5, 14], [5, 1], [5, 7], [5, 3], [3, 6], [3, 12], [3, 14], [3, 1], [3, 7], [6, 14], [12, 14], [1, 14], [7, 14]]
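
Because the edge list is copied by hand, it can be worth checking that it really is acyclic before passing it to the model. The sketch below uses Kahn's algorithm on a list of `[parent, child]` pairs. `is_acyclic` and `example_edges` are hypothetical names introduced for illustration and are not part of the DECAF code.

```python
from collections import defaultdict, deque

def is_acyclic(edges):
    """Return True if the directed graph given as [parent, child] pairs has no cycle."""
    children = defaultdict(list)
    in_degree = defaultdict(int)
    nodes = set()
    for parent, child in edges:
        children[parent].append(child)
        in_degree[child] += 1
        nodes.update((parent, child))

    # Start from nodes without incoming edges and peel them off one by one
    queue = deque(n for n in nodes if in_degree[n] == 0)
    visited = 0
    while queue:
        node = queue.popleft()
        visited += 1
        for child in children[node]:
            in_degree[child] -= 1
            if in_degree[child] == 0:
                queue.append(child)

    # All nodes get visited exactly when the graph contains no cycle
    return visited == len(nodes)

example_edges = [['sex', 'occupation'], ['occupation', 'income'], ['sex', 'income']]
print(is_acyclic(example_edges))
print(is_acyclic(example_edges + [['income', 'sex']]))
```

The same function accepts either the named edge list `dag` or the index-based `dag_seed`, since it only requires hashable node labels.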


We also need to define the edges to remove from the DAG in order to meet the fairness criteria described in the paper: fairness through unawareness (FTU), demographic parity (DP) and conditional fairness (CF).

In [6]:
def create_bias_dict(df, edge_map):
    """
    Convert the given edge tuples to a bias dict used for generating
    debiased synthetic data.
    """
    bias_dict = {}
    for key, val in edge_map.items():
        bias_dict[df.columns.get_loc(key)] = [df.columns.get_loc(f) for f in val]
    
    return bias_dict

# Bias dictionary to satisfy FTU
bias_dict_ftu = create_bias_dict(dataset, {'income': ['sex']})
print('Bias dict FTU:', bias_dict_ftu)

# Bias dictionary to satisfy DP
bias_dict_dp = create_bias_dict(dataset, {'income': [
    'occupation', 'hours-per-week', 'marital-status', 'education', 'sex',
    'workclass', 'relationship']})
print('Bias dict DP:', bias_dict_dp)

# Bias dictionary to satisfy CF
bias_dict_cf = create_bias_dict(dataset, {'income': [
    'marital-status', 'sex']})
print('Bias dict CF:', bias_dict_cf)

Bias dict FTU: {14: [9]}
Bias dict DP: {14: [6, 12, 5, 3, 9, 1, 7]}
Bias dict CF: {14: [5, 9]}
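
Each bias dictionary maps a target column index to the parent indices whose edges into that target should be dropped. To make the effect concrete, the sketch below filters the edges into `income` (index 14, taken from the printed `dag_seed`) with each dictionary. This is a conceptual illustration only: `remaining_edges` is a hypothetical helper, and how `train_decaf` applies `biased_edges` internally is not shown in this notebook.

```python
def remaining_edges(dag_idx, bias_dict):
    """Drop every [parent, child] edge that the bias dict marks for removal."""
    removed = {(parent, child)
               for child, parents in bias_dict.items()
               for parent in parents}
    return [edge for edge in dag_idx if (edge[0], edge[1]) not in removed]

# Edges into income (index 14), copied from the printed dag_seed
income_edges = [[8, 14], [0, 14], [9, 14], [13, 14], [5, 14],
                [3, 14], [6, 14], [12, 14], [1, 14], [7, 14]]

ftu_edges = remaining_edges(income_edges, {14: [9]})
cf_edges = remaining_edges(income_edges, {14: [5, 9]})
dp_edges = remaining_edges(income_edges, {14: [6, 12, 5, 3, 9, 1, 7]})
print(len(ftu_edges), len(cf_edges), len(dp_edges))
print(dp_edges)
```

With the DP dictionary only the edges from `race` (8), `age` (0) and `native-country` (13) into `income` remain.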


## Experiments

We have loaded and preprocessed the data and are ready to run the experiments. For each experiment we train a generative model, sample synthetic data from it, and then obtain metrics by training a downstream multi-layer perceptron on the synthetic data and evaluating it on the test fold created in the previous section. We use the MLP model from `sklearn` with default parameters, which matches the settings described in Appendix D of the paper.

In [7]:
def eval_model(dataset_train, dataset_test):
    """Train a downstream MLP on dataset_train and return metrics on dataset_test."""

    # Separate the features from the income target
    X_train, y_train = dataset_train.drop(columns=['income']), dataset_train['income']
    X_test, y_test = dataset_test.drop(columns=['income']), dataset_test['income']

    # Downstream classifier with sklearn's default parameters
    clf = MLPClassifier()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    # Predictive performance; note that AUROC is computed from hard labels
    precision = precision_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred)
    auroc = roc_auc_score(y_test, y_pred)

    # Fairness metrics of the trained classifier on the test features
    dp = DP(clf, X_test)
    ftu = FTU(clf, X_test)

    return {'precision': precision, 'recall': recall, 'auroc': auroc,
            'dp': dp, 'ftu': ftu}
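
`DP` and `FTU` are imported from the local `metrics` module, whose code is not shown here. As a rough sketch of what such metrics typically compute, assume a binary protected column (for example `sex`) and a classifier with a `predict` method. Demographic parity is then the gap in positive-prediction rate between the two groups, and FTU is the gap in positive-prediction rate when the protected column alone is forced to 0 or to 1. The names `dp_sketch`, `ftu_sketch`, `ToyClassifier` and `protected_idx` are hypothetical, and the real functions may differ in detail.

```python
import numpy as np

def dp_sketch(clf, X, protected_idx):
    """Absolute gap in positive-prediction rate between the two protected groups."""
    pred = clf.predict(X)
    group = X[:, protected_idx] == 1
    return float(abs(pred[group].mean() - pred[~group].mean()))

def ftu_sketch(clf, X, protected_idx):
    """Absolute gap in positive-prediction rate with the protected column forced to 0 vs. 1."""
    X0, X1 = X.copy(), X.copy()
    X0[:, protected_idx] = 0.0
    X1[:, protected_idx] = 1.0
    return float(abs(clf.predict(X0).mean() - clf.predict(X1).mean()))

class ToyClassifier:
    """Stand-in for the trained MLP: predicts 1 when x0 + 0.5 * protected > 0.75."""
    def predict(self, X):
        return (X[:, 0] + 0.5 * X[:, 1] > 0.75).astype(float)

# Column 0 is an ordinary feature, column 1 is the binary protected attribute
X = np.array([[0.2, 0.0],
              [0.9, 0.0],
              [0.4, 1.0],
              [0.8, 1.0]])
clf = ToyClassifier()
print('DP :', dp_sketch(clf, X, protected_idx=1))
print('FTU:', ftu_sketch(clf, X, protected_idx=1))
```

For both quantities, a value of 0 means the classifier's output does not differ across the protected attribute in the sense being measured.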

### Original dataset

As a baseline, we first train the downstream model on the original dataset.

In [8]:
eval_model(dataset_train, dataset_test)

{'precision': 0.8734977862112587,
 'recall': 0.9194407456724367,
 'auroc': 0.7589171599848128,
 'dp': 0.187192618718257,
 'ftu': 0.035499999999999976}

In the following sections we train various models in order to reproduce the results from Table 2 of the DECAF paper.
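
Each model section shows the output of a single run, while the reported numbers are averages over 10 runs. A minimal sketch of how such averaging could be organised is given here. `average_runs` and `fake_run` are hypothetical helpers: `fake_run` only returns noisy dummy metrics so that the example is self-contained and does not train anything.

```python
import random
from statistics import mean, stdev

def average_runs(run_once, n_runs=10):
    """Call run_once(seed) n_runs times and return (mean, std) per metric."""
    results = [run_once(seed) for seed in range(n_runs)]
    return {key: (mean(r[key] for r in results), stdev(r[key] for r in results))
            for key in results[0]}

def fake_run(seed):
    """Stand-in for one train-and-evaluate cycle; returns noisy dummy metrics."""
    rng = random.Random(seed)
    return {'precision': 0.8 + rng.uniform(-0.01, 0.01),
            'dp': 0.1 + rng.uniform(-0.02, 0.02)}

summary = average_runs(fake_run)
for name, (avg, std) in summary.items():
    print(f'{name}: {avg:.3f} +/- {std:.3f}')
```

In this notebook, `run_once` could wrap one training and evaluation step, for example `eval_model(train_decaf(dataset_train, dag_seed), dataset_test)`, since `eval_model` already returns a dictionary of metrics.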

### GAN

In [9]:
synth_data = train_vanilla_gan(dataset_train)
synth_data.head()

0 [D loss: 0.000225, acc.: 100.00%] [G loss: 29.921358]
1 [D loss: 0.000011, acc.: 100.00%] [G loss: 53.890396]
2 [D loss: 0.000001, acc.: 100.00%] [G loss: 71.640816]
3 [D loss: 0.000001, acc.: 100.00%] [G loss: 94.739578]
4 [D loss: 0.000007, acc.: 100.00%] [G loss: 102.364059]
5 [D loss: 0.000000, acc.: 100.00%] [G loss: 121.709381]
6 [D loss: 0.000001, acc.: 100.00%] [G loss: 130.072754]
7 [D loss: 0.000000, acc.: 100.00%] [G loss: 143.746796]
8 [D loss: 0.000000, acc.: 100.00%] [G loss: 154.271729]
9 [D loss: 0.000036, acc.: 100.00%] [G loss: 152.255554]
10 [D loss: 0.000000, acc.: 100.00%] [G loss: 169.953400]
11 [D loss: 0.000000, acc.: 100.00%] [G loss: 164.079071]
12 [D loss: 0.000000, acc.: 100.00%] [G loss: 186.576538]
13 [D loss: 0.000001, acc.: 100.00%] [G loss: 178.147430]
14 [D loss: 0.000002, acc.: 100.00%] [G loss: 176.505493]
15 [D loss: 0.000000, acc.: 100.00%] [G loss: 194.578262]
16 [D loss: 0.000000, acc.: 100.00%] [G loss: 188.591339]
17 [D loss: 0.000000, acc.: 100.00%] [G loss: 208.471817]
18 [D loss: 0.000000, acc.: 100.00%] [G loss: 204.191177]
19 [D loss: 0.000000, acc.: 100.00%] [G loss: 225.689285]
20 [D loss: 0.000000, acc.: 100.00%] [G loss: 227.029846]
21 [D loss: 0.000000, acc.: 100.00%] [G loss: 244.247330]
22 [D loss: 0.000656, acc.: 100.00%] [G loss: 229.735611]
23 [D loss: 0.000000, acc.: 100.00%] [G loss: 229.075867]
24 [D loss: 0.000000, acc.: 100.00%] [G loss: 238.755997]
25 [D loss: 0.000000, acc.: 100.00%] [G loss: 243.529373]
26 [D loss: 0.000000, acc.: 100.00%] [G loss: 244.447342]
27 [D loss: 0.000007, acc.: 100.00%] [G loss: 262.647095]
28 [D loss: 0.000000, acc.: 100.00%] [G loss: 271.352539]
29 [D loss: 0.000000, acc.: 100.00%] [G loss: 269.649933]
30 [D loss: 0.000000, acc.: 100.00%] [G loss: 258.833191]
31 [D loss: 0.000000, acc.: 100.00%] [G loss: 274.367767]
32 [D loss: 0.000000, acc.: 100.00%] [G loss: 293.894714]
33 [D loss: 0.000000, acc.: 100.00%] [G loss: 264.538696]
34 [D loss: 0.000000, acc.: 100.00%] [G loss: 304.273743]
35 [D loss: 0.000000, acc.: 100.00%] [G loss: 263.536804]
36 [D loss: 0.000000, acc.: 100.00%] [G loss: 279.251465]
37 [D loss: 0.000000, acc.: 100.00%] [G loss: 282.881561]
38 [D loss: 0.000000, acc.: 100.00%] [G loss: 274.675812]
39 [D loss: 0.000009, acc.: 100.00%] [G loss: 325.794495]
40 [D loss: 0.000000, acc.: 100.00%] [G loss: 316.917847]
41 [D loss: 0.000000, acc.: 100.00%] [G loss: 292.149261]
42 [D loss: 0.000000, acc.: 100.00%] [G loss: 306.318481]
43 [D loss: 0.000000, acc.: 100.00%] [G loss: 311.565857]
44 [D loss: 0.000000, acc.: 100.00%] [G loss: 313.791809]
45 [D loss: 0.000000, acc.: 100.00%] [G loss: 334.200684]
46 [D loss: 0.000000, acc.: 100.00%] [G loss: 326.521027]
47 [D loss: 0.000000, acc.: 100.00%] [G loss: 316.824432]
48 [D loss: 0.000000, acc.: 100.00%] [G loss: 316.359558]
49 [D loss: 0.000000, acc.: 100.00%] [G loss: 322.851257]
100%|██████████| 50/50 [07:01<00:00,  8.43s/it]
Synthetic data generation: 100%|██████████| 221/221 [00:02<00:00, 96.37it/s]


Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
0,0.001865,0.333333,0.006739,0.133333,0.533333,0.0,0.153846,0.0,0.25,0.0,-0.005399,0.007568,0.010692,0.775,1.0
1,0.001958,1.0,0.006785,0.4,0.2,0.5,0.769231,0.6,1.0,0.0,-0.005418,0.007586,0.010761,0.3,1.0
2,0.003629,0.5,0.007598,0.933333,0.0,0.333333,1.0,0.2,0.5,1.0,-0.005579,0.008428,0.011905,0.5,0.0
3,-0.005068,0.166667,0.001333,0.666667,0.4,0.833333,0.076923,0.2,1.0,1.0,-0.002291,0.006366,0.005182,0.075,0.0
4,-0.003677,0.5,0.003354,0.266667,0.0,0.666667,0.461538,0.4,0.5,1.0,-0.003649,0.006657,0.006434,0.2,1.0


In [10]:
eval_model(synth_data, dataset_test)

{'precision': 0.7497343251859724,
 'recall': 0.9394141145139814,
 'auroc': 0.49681549099193045,
 'dp': 0.015758018290198073,
 'ftu': 0.027999999999999914}

### WGAN-GP

In [11]:
synth_data = train_wgan_gp(dataset_train)
synth_data.head()

Epoch: 0 | disc_loss: 0.7107614278793335 | gen_loss: -0.029137199744582176
Epoch: 1 | disc_loss: 0.3214643895626068 | gen_loss: 0.0006649568094871938
Epoch: 2 | disc_loss: 0.06803859770298004 | gen_loss: -0.01034797914326191
Epoch: 3 | disc_loss: 0.3769150376319885 | gen_loss: -0.000423622434027493
Epoch: 4 | disc_loss: 0.6953859329223633 | gen_loss: 0.02071164920926094
Epoch: 5 | disc_loss: 1.1991130113601685 | gen_loss: 0.07307153940200806
Epoch: 6 | disc_loss: 0.18188615143299103 | gen_loss: 0.027091026306152344
Epoch: 7 | disc_loss: 0.8241662979125977 | gen_loss: 0.054655563086271286
Epoch: 8 | disc_loss: 0.00047989562153816223 | gen_loss: 0.05843598023056984
Epoch: 9 | disc_loss: 4.592607498168945 | gen_loss: 0.04178837686777115
Epoch: 10 | disc_loss: 1.2506260871887207 | gen_loss: 0.04159500077366829
Epoch: 11 | disc_loss: 0.1962561011314392 | gen_loss: 0.10931268334388733
Epoch: 12 | disc_loss: 0.003349415957927704 | gen_loss: 0.04899400845170021
Epoch: 13 | disc_loss: -0.05251096189022064 | gen_loss: 0.06490091979503632
Epoch: 14 | disc_loss: -0.061463210731744766 | gen_loss: 0.0759325847029686
Epoch: 15 | disc_loss: -0.03381440043449402 | gen_loss: 0.09221752732992172
Epoch: 16 | disc_loss: -0.06201101467013359 | gen_loss: 0.09386336803436279
Epoch: 17 | disc_loss: -0.05195339024066925 | gen_loss: 0.06919596344232559
Epoch: 18 | disc_loss: -0.07568386942148209 | gen_loss: 0.09850286692380905
Epoch: 19 | disc_loss: -0.07193577289581299 | gen_loss: 0.10080032050609589
Epoch: 20 | disc_loss: 0.15693318843841553 | gen_loss: 0.07675142586231232
Epoch: 21 | disc_loss: -0.06660167127847672 | gen_loss: 0.027654679492115974
Epoch: 22 | disc_loss: -0.051281657069921494 | gen_loss: 0.1011628583073616
Epoch: 23 | disc_loss: -0.012521311640739441 | gen_loss: 0.08838582038879395
Epoch: 24 | disc_loss: 0.03849164396524429 | gen_loss: 0.11807141453027725
Epoch: 25 | disc_loss: -0.018544167280197144 | gen_loss: 0.10829099267721176
Epoch: 26 | disc_loss: -0.08056782931089401 | gen_loss: 0.10898259282112122
Epoch: 27 | disc_loss: -0.0723012164235115 | gen_loss: 0.12350492924451828
Epoch: 28 | disc_loss: 0.04208127409219742 | gen_loss: 0.14741067588329315
Epoch: 29 | disc_loss: 0.5813738703727722 | gen_loss: 0.10455050319433212
Epoch: 30 | disc_loss: -0.08876527100801468 | gen_loss: 0.09489351511001587
Epoch: 31 | disc_loss: -0.07724887132644653 | gen_loss: 0.11249829828739166
Epoch: 32 | disc_loss: -0.09586035460233688 | gen_loss: 0.13413295149803162
Epoch: 33 | disc_loss: -0.08317115902900696 | gen_loss: 0.13778020441532135
Epoch: 34 | disc_loss: -0.09284806996583939 | gen_loss: 0.1256261169910431
Epoch: 35 | disc_loss: -0.09428723156452179 | gen_loss: 0.14041867852210999
Epoch: 36 | disc_loss: -0.08539609611034393 | gen_loss: 0.12367858737707138
Epoch: 37 | disc_loss: -0.09341955184936523 | gen_loss: 0.14552739262580872
Epoch: 38 | disc_loss: -0.09714009612798691 | gen_loss: 0.12519752979278564
Epoch: 39 | disc_loss: -0.1005224883556366 | gen_loss: 0.11282118409872055
Epoch: 40 | disc_loss: 0.2775387167930603 | gen_loss: 0.11147604882717133
Epoch: 41 | disc_loss: -0.09395670890808105 | gen_loss: 0.13231906294822693
Epoch: 42 | disc_loss: -0.056405067443847656 | gen_loss: 0.13196110725402832
Epoch: 43 | disc_loss: -0.04155673086643219 | gen_loss: 0.09076317399740219
Epoch: 44 | disc_loss: -0.10144925117492676 | gen_loss: 0.13050000369548798
Epoch: 45 | disc_loss: -0.09317844361066818 | gen_loss: 0.12488552927970886
Epoch: 46 | disc_loss: -0.10353392362594604 | gen_loss: 0.13082721829414368
Epoch: 47 | disc_loss: -0.09945110976696014 | gen_loss: 0.14057742059230804
Epoch: 48 | disc_loss: -0.08630061894655228 | gen_loss: 0.010011428967118263
Epoch: 49 | disc_loss: -0.10573866218328476 | gen_loss: 0.1532163769006729
100%|██████████| 50/50 [02:27<00:00,  2.94s/it]
Synthetic data generation: 100%|██████████| 57/57 [00:00<00:00, 64.69it/s]


Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
0,0.532502,0.0,0.117296,0.133333,0.6,0.0,0.384615,0.2,0.5,0.0,0.043905,0.015043,0.410408,0.4,0.0
1,0.463003,1.0,0.093717,0.666667,0.933333,0.833333,0.692308,0.8,0.5,1.0,0.039477,0.023471,0.378923,0.225,1.0
2,0.526523,0.5,0.11604,0.266667,0.466667,0.333333,0.076923,0.4,0.0,0.0,0.062467,0.068638,0.400123,0.575,0.0
3,0.281195,1.0,0.074402,0.0,1.0,0.0,0.076923,0.6,1.0,0.0,0.012299,0.037618,0.311223,0.225,1.0
4,0.326141,0.166667,0.075756,0.133333,0.2,0.5,0.076923,0.6,0.0,0.0,0.024199,0.040427,0.323382,0.85,1.0


In [12]:
eval_model(synth_data, dataset_test)



{'precision': 0.7603550295857988,
 'recall': 0.855525965379494,
 'auroc': 0.5211364766656507,
 'dp': 0.13635587174308428,
 'ftu': 0.07699999999999996}

### FairGAN

In [13]:
synth_data = train_fairgan(dataset_train)
synth_data.head()

Pretrain_Epoch:0, trainLoss:0.036737, validLoss:0.019772, validReverseLoss:0.000000
Pretrain_Epoch:1, trainLoss:0.019831, validLoss:0.019710, validReverseLoss:0.000000
Pretrain_Epoch:2, trainLoss:0.019880, validLoss:0.019720, validReverseLoss:0.000000
Pretrain_Epoch:3, trainLoss:0.019934, validLoss:0.019821, validReverseLoss:0.000000
Pretrain_Epoch:4, trainLoss:0.019971, validLoss:0.019784, validReverseLoss:0.000000
Pretrain_Epoch:5, trainLoss:0.019995, validLoss:0.019787, validReverseLoss:0.000000
Pretrain_Epoch:6, trainLoss:0.019959, validLoss:0.019801, validReverseLoss:0.000000
Pretrain_Epoch:7, trainLoss:0.019956, validLoss:0.019740, validReverseLoss:0.000000
Pretrain_Epoch:8, trainLoss:0.013213, validLoss:0.007593, validReverseLoss:0.000000
Pretrain_Epoch:9, trainLoss:0.007500, validLoss:0.007036, validReverseLoss:0.000000
Pretrain_Epoch:10, trainLoss:0.003250, validLoss:0.002439, validReverseLoss:0.000000
Pretrain_Epoch:11, trainLoss:0.002496, validLoss:0.002376, validReverseLoss



Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
0,0.0,0.369636,0.429304,0.568245,0.522931,0.404446,0.595719,0.0,0.528495,0.0,0.0,0.691437,0.0,0.582464,1.0
1,0.513192,0.168188,0.50007,0.927547,0.0,0.549738,0.56506,0.0,0.711314,0.0,0.219476,0.433434,0.0,0.0,1.0
2,1.287684,0.025203,0.42249,1.111596,0.0,0.45821,0.021143,0.0,1.612983,0.0,0.0,0.689238,0.0,0.0,1.0
3,0.0,0.041467,0.41066,0.737548,0.0,0.0,0.61455,0.0,0.71277,0.0,0.0,0.52003,0.0,0.229288,1.0
4,0.0,0.175688,0.0,0.529809,0.0,0.21812,0.360827,0.300408,0.614899,0.0,0.277231,0.262351,0.107554,0.506268,1.0


In [14]:
eval_model(synth_data, dataset_test)

{'precision': 0.7911561001598295,
 'recall': 0.988681757656458,
 'auroc': 0.600766581639474,
 'dp': 0.00811793616182721,
 'ftu': 0.04049999999999998}

### DECAF

#### DECAF-ND

In [15]:
synth_data = train_decaf(dataset_train, dag_seed)
synth_data.head()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name          | Type             | Params
---------------------------------------------------
0 | generator     | Generator_causal | 134 K 
1 | discriminator | Discriminator    | 43.6 K
---------------------------------------------------
178 K     Trainable params
225       Non-trainable params
178 K     Total params
0.713     Total estimated model params size (MB)


Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1.],
        [0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 0., 1., 0.

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
26022,0.234127,6.765107e-07,0.139488,0.177453,0.53994,0.1179136,0.298036,0.441479,1.06363e-11,1.0,2.539986e-06,2.790788e-14,0.541281,2.434887e-15,1.0
22685,0.350065,5.415779e-06,0.19456,0.133822,0.510079,0.0269873,0.514205,0.15822,2.9201129999999997e-20,0.0,1.781848e-09,2.888478e-12,0.604431,9.084453e-10,1.0
1786,0.107019,2.445179e-07,0.180215,0.237115,0.548,0.2801915,0.295019,0.134571,0.0,0.0,2.228014e-13,3.178187e-08,0.350311,2.274254e-21,0.0
16922,0.586856,0.5697811,0.02312,0.191823,0.550194,0.005259033,0.18994,0.436645,6.327144e-37,1.0,0.1146319,3.676456e-10,0.270717,6.91213e-21,0.0
7506,0.419983,0.005354655,0.131415,0.865792,0.541115,3.556299e-08,0.743885,0.452908,2.516105e-19,1.0,4.340899e-13,2.103943e-10,0.61936,2.585483e-24,1.0


In [16]:
eval_model(synth_data, dataset_test)

{'precision': 0.8512110726643599,
 'recall': 0.8189081225033289,
 'auroc': 0.6935906074364033,
 'dp': 0.28775632,
 'ftu': 0.16600001}

#### DECAF-FTU

In [17]:
synth_data = train_decaf(dataset_train, dag_seed, biased_edges=bias_dict_ftu)
synth_data.head()


Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1.],
        [0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 0., 1., 0.

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
26022,0.060165,8.846032e-09,0.082856,0.119036,0.543828,0.50617,0.590801,0.623375,0.9983974,0.0,5.064419e-05,0.0002998538,0.262476,0.825726,1.0
22685,0.550186,0.002414228,0.181615,0.422031,0.512948,0.001664,0.087264,0.442757,0.9984341,1.0,4.15655e-07,2.2156350000000003e-17,0.25191,6.448075e-05,1.0
1786,0.0644,5.409464e-08,0.023076,0.104712,0.944537,0.000219,0.092931,0.424156,7.557856e-21,1.0,0.0001682411,2.4581020000000002e-18,0.628548,1.151245e-19,0.0
16922,0.100146,0.002944645,0.027178,0.205918,0.543854,0.325511,0.838662,0.639171,4.482559e-31,1.0,2.821697e-11,1.048198e-16,0.391728,4.622066e-15,1.0
7506,0.406931,0.003207172,0.077683,0.123296,0.541885,0.009738,0.60041,0.417331,2.2430619999999997e-38,1.0,0.0007695793,0.0005475842,0.670021,3.697422e-16,1.0


In [18]:
eval_model(synth_data, dataset_test)

{'precision': 0.827037773359841,
 'recall': 0.8308921438082557,
 'auroc': 0.6533978791330435,
 'dp': 0.12550032,
 'ftu': 0.038999975}

#### DECAF-CF

In [19]:
synth_data = train_decaf(dataset_train, dag_seed, biased_edges=bias_dict_cf)
synth_data.head()


Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1.],
        [0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 0., 1., 0.

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
26022,0.376979,0.04417319,0.107903,0.090297,0.888707,0.000257,0.435016,0.445857,0.01714575,1.0,2.6016e-12,2.7170280000000003e-22,0.381479,2.89304e-25,0.0
22685,0.183172,3.874241e-05,0.251086,0.167791,0.551216,0.317544,0.56624,0.879151,7.999757999999999e-21,0.0,2.024597e-14,1.63542e-18,0.286153,0.03298963,1.0
1786,0.228797,0.002407223,0.113237,0.215229,0.54073,0.002845,0.101234,0.450681,1.984082e-25,1.0,1.606261e-08,3.207013e-16,0.376381,6.985810000000001e-17,1.0
16922,0.206276,3.136597e-08,0.067481,0.16858,0.54568,0.334844,0.283045,0.868826,0.0,0.0,1.503925e-09,9.618725999999998e-19,0.320836,1.6472360000000003e-17,0.0
7506,0.122383,4.905114e-07,0.17216,0.244046,0.536079,5.7e-05,0.446719,0.315273,1.85129e-06,1.0,1.990528e-06,6.445155999999999e-19,0.286599,0.6302889,0.0


In [20]:
eval_model(synth_data, dataset_test)

{'precision': 0.7640449438202247,
 'recall': 0.9507323568575233,
 'auroc': 0.5325950940914123,
 'dp': 0.003698945,
 'ftu': 0.014999986}

#### DECAF-DP

In [21]:
synth_data = train_decaf(dataset_train, dag_seed, biased_edges=bias_dict_dp)
synth_data.head()


Initialised adjacency matrix as parsed:
 Parameter containing:
tensor([[0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 1., 0., 1.],
        [0., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 0., 1., 0.

Unnamed: 0,age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income
26022,0.219547,8.501428e-05,0.091582,0.222513,0.547538,0.322064,0.255209,0.234307,1.432031e-21,0.0,2.047043e-11,2.8756380000000003e-22,0.367917,3.855668e-06,1.0
22685,0.072794,4.437362e-09,0.135453,0.261167,0.52264,0.289771,0.316758,0.631737,2.910277e-16,0.0,3.202769e-13,1.715337e-08,0.43741,4.89496e-06,1.0
1786,0.581513,0.003996564,0.065573,0.264535,0.97089,0.005916,0.275685,0.426376,0.0,1.0,0.2092007,1.484834e-09,0.358271,5.869144e-11,1.0
16922,0.255456,1.248933e-08,0.038852,0.0768,0.539664,0.344737,0.704153,0.893334,0.000100593,0.0,7.739571e-08,1.118384e-15,0.527627,5.437344e-11,1.0
7506,0.071769,0.0004970749,0.050992,0.157355,0.539611,0.281271,0.706125,0.113897,8.705056e-19,0.0,0.0003678382,1.118564e-15,0.474221,0.08006825,1.0


In [22]:
eval_model(synth_data, dataset_test)

{'precision': 0.7598488936859147,
 'recall': 0.9374167776298269,
 'auroc': 0.521921240220536,
 'dp': 0.0094688535,
 'ftu': 0.036499977}