In [7]:
import os
import csv
import math
import itertools
import pickle

import pandas as pd
pd.set_option('display.max_columns', None)
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn import preprocessing
from scipy import sparse
import networkx as nx

from mbi import (
    Dataset,
    FactoredInference,
    Domain,
    LocalInference,
    MixtureInference,
    PublicInference,
)

In [8]:
# Load the preprocessed Adult records together with their attribute-domain spec.
data = Dataset.load(
    "./data/adult_processed.csv",
    "./data/adult_processed.json",
)
domain = data.domain
# Exact record count. NOTE(review): this value is not noised, yet it is passed
# as `total` to the inference engine -- confirm the count is treated as public.
total = len(data.df)

In [None]:
# adapted from https://github.com/ryan112358/private-pgm/blob/master/examples/adult_example.py
#
# Take noisy (Laplace) measurements of every one-way marginal plus the pairwise
# marginals listed in `cliques`.
#
# Privacy budget: epsilon is split EVENLY across all measurements, i.e.
# len(data.domain) one-way marginals + len(cliques) two-way marginals. The
# one-way marginals therefore receive len(domain) / (len(domain) + len(cliques))
# of epsilon -- 15/41 (~37%) for the 15 Adult attributes and 26 cliques here --
# not "half" as the upstream example's comments say. All cliques are pairwise.

cliques = [
    ('workclass', 'occupation'),
    ('workclass', 'fnlwgt'),
    ('marital.status', 'income'),
    ('fnlwgt', 'occupation'),
    ('fnlwgt', 'race'),
    ('fnlwgt', 'income'),
    ('education', 'education.num'),
    ('fnlwgt', 'capital.loss'),
    ('race', 'native.country'),
    ('relationship', 'sex'),
    ('marital.status', 'relationship'),
    ('occupation', 'sex'),
    ('fnlwgt', 'relationship'),
    ('fnlwgt', 'marital.status'),
    ('fnlwgt', 'sex'),
    ('fnlwgt', 'education'),
    ('age', 'capital.gain'),
    ('capital.gain', 'income'),
    ('relationship', 'income'),
    ('fnlwgt', 'native.country'),
    ('fnlwgt', 'education.num'),
    ('fnlwgt', 'hours.per.week'),
    ('marital.status', 'sex'),
    ('age', 'fnlwgt'),
    ('age', 'capital.loss'),
    ('fnlwgt', 'capital.gain'),
]

np.random.seed(0)

epsilon = 1.0
epsilon_split = epsilon / (len(data.domain) + len(cliques))
# Laplace scale per measurement. The factor 2.0 is presumably the L1 sensitivity
# of a count vector under replace-one-record neighbours (one count drops by 1,
# another rises by 1) -- confirm this is the intended neighbouring relation.
sigma = 2.0 / epsilon_split

# One-way marginals first, then the pairwise cliques. This order fixes the
# sequence of draws from the seeded global RNG, so keep it stable.
projections = [(col,) for col in data.domain] + list(cliques)

measurements = []
for proj in projections:
    x = data.project(proj).datavector()
    y = x + np.random.laplace(loc=0, scale=sigma, size=x.size)
    I = sparse.eye(x.size)
    # NOTE(review): `sigma` is the Laplace *scale*; if mbi reads this slot as a
    # standard deviation, the true value is sqrt(2) * sigma. Every measurement
    # shares the same sigma, so their relative weighting is unaffected.
    measurements.append((I, y, sigma, proj))

In [None]:
# Fit a graphical model to the noisy measurements with mbi's FactoredInference.
# SLOW: the recorded run took 924m15s (~15.4 hours) for 500 iterations. The log
# reports a total clique size of ~3.6e8, and runtime appears to track it (the
# 14-clique spanning-tree run later in this notebook, ~1.0e7, took ~17 min).
# NOTE(review): consider reloading ./model/adult_synth.pkl instead of refitting.
engine = FactoredInference(domain, log=True, iters=500)
model = engine.estimate(measurements, total=total)

Total clique size: 355979712
iteration		time		l1_loss		l2_loss		feasibility
0.00		0.00		10720602.35		11210028.55		0.00
50.00		4179.25		10711458.22		10714903.93		0.00
100.00		9873.43		10710113.53		10710655.33		0.00
150.00		15668.32		10709301.19		10708655.66		0.00
200.00		21444.27		10708738.82		10707145.16		0.00
250.00		27191.46		10708346.61		10705960.43		0.00
300.00		32957.98		10708067.44		10704995.58		0.00
350.00		38627.04		10707870.26		10704198.27		0.00
400.00		44133.61		10707734.69		10703559.30		0.00
450.00		49872.76		10707623.52		10703052.92		0.00


In [7]:
# Persist the fitted model so the ~15 h fit does not have to be repeated.
# The pickle was ~5GB in the recorded run. Pickle protocol 4+ is needed for
# individual objects larger than 4GB, so request the highest available protocol
# instead of relying on the interpreter's default.
# NOTE: only unpickle this file from a trusted source -- pickle.load can execute
# arbitrary code.
os.makedirs('./model', exist_ok=True)  # open() fails if the directory is missing
with open('./model/adult_synth.pkl', 'wb') as f:
    pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Draw 30,000 synthetic records from the fitted model (about 45 s in the
# recorded run) and show the resulting DataFrame.
synth = model.synthetic_data(rows=30_000)
sdf = synth.df
sdf

Unnamed: 0,age,workclass,fnlwgt,education,education.num,marital.status,occupation,relationship,race,sex,capital.gain,capital.loss,hours.per.week,native.country,income
0,6,4,12445,15,9,0,1,3,4,1,0,0,9,39,0
1,12,4,19268,11,8,0,10,1,4,1,0,0,44,39,0
2,3,4,14421,7,11,0,7,4,1,1,0,0,39,30,0
3,13,6,9822,11,8,0,12,1,4,0,0,0,39,39,0
4,13,4,6249,11,8,0,10,4,4,0,0,0,76,39,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29995,25,7,7321,12,13,1,1,5,4,1,0,0,49,39,1
29996,13,2,9056,15,9,1,3,0,4,0,0,0,54,39,0
29997,18,4,7533,11,8,0,3,1,4,0,0,0,39,39,1
29998,36,4,7284,12,13,1,1,0,4,0,0,0,0,39,1


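# Without heuristic

The cliques below are 14 pairwise marginals that form a spanning tree over the 15 attributes (the resulting model is saved as `adult_synth_mst.pkl`). The run above used 26 cliques.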

In [3]:
# Same measurement procedure as the heuristic run, but with a smaller clique set:
# 14 pairwise cliques that form a spanning tree over the 15 attributes (every
# attribute except education.num and marital.status hangs off fnlwgt).
#
# Privacy budget: epsilon is split EVENLY across all measurements, i.e.
# len(data.domain) one-way marginals + len(cliques) two-way marginals. Here the
# one-way marginals get 15/29 (~52%) of epsilon. All cliques are pairwise --
# there are no 3-way marginals.

cliques = [
    ('age', 'fnlwgt'),
    ('workclass', 'fnlwgt'),
    ('fnlwgt', 'hours.per.week'),
    ('fnlwgt', 'occupation'),
    ('fnlwgt', 'education'),
    ('fnlwgt', 'relationship'),
    ('fnlwgt', 'sex'),
    ('fnlwgt', 'native.country'),
    ('fnlwgt', 'race'),
    ('fnlwgt', 'capital.gain'),
    ('fnlwgt', 'income'),
    ('fnlwgt', 'capital.loss'),
    ('education', 'education.num'),
    ('marital.status', 'relationship'),
]

np.random.seed(0)

epsilon = 1.0
epsilon_split = epsilon / (len(data.domain) + len(cliques))
# Laplace scale per measurement. The factor 2.0 is presumably the L1 sensitivity
# of a count vector under replace-one-record neighbours (one count drops by 1,
# another rises by 1) -- confirm this is the intended neighbouring relation.
sigma = 2.0 / epsilon_split

# One-way marginals first, then the pairwise cliques. This order fixes the
# sequence of draws from the seeded global RNG, so keep it stable.
projections = [(col,) for col in data.domain] + list(cliques)

measurements = []
for proj in projections:
    x = data.project(proj).datavector()
    y = x + np.random.laplace(loc=0, scale=sigma, size=x.size)
    I = sparse.eye(x.size)
    # NOTE(review): `sigma` is the Laplace *scale*; if mbi reads this slot as a
    # standard deviation, the true value is sqrt(2) * sigma. Every measurement
    # shares the same sigma, so their relative weighting is unaffected.
    measurements.append((I, y, sigma, proj))

In [None]:
# Fit the graphical model to the spanning-tree measurements.
# The recorded run took 17m1s for 500 iterations (total clique size ~1.0e7).
engine = FactoredInference(domain, iters=500, log=True)
model = engine.estimate(measurements, total=total)

Total clique size: 10283068
iteration		time		l1_loss		l2_loss		feasibility
0.00		0.00		10315003.89		11036888.71		0.00
50.00		78.15		10305806.70		10309525.21		0.00
100.00		184.64		10304066.27		10304786.66		0.00
150.00		292.16		10303055.12		10302005.12		0.00
200.00		394.16		10302460.60		10299956.93		0.00
250.00		499.22		10302081.52		10298375.76		0.00
300.00		604.60		10301856.07		10297234.22		0.00
350.00		708.26		10301713.65		10296451.32		0.00
400.00		808.18		10301621.99		10295921.65		0.00
450.00		915.11		10301553.73		10295560.93		0.00


In [None]:
# Persist the fitted spanning-tree model.
# TODO: record the file size. It was not noted; the heuristic model, whose total
# clique size is ~35x larger, produced a ~5GB pickle.
# Request the highest pickle protocol so objects over 4GB are supported
# regardless of the interpreter's default.
# NOTE: only unpickle this file from a trusted source -- pickle.load can execute
# arbitrary code.
os.makedirs('./model', exist_ok=True)  # open() fails if the directory is missing
with open('./model/adult_synth_mst.pkl', 'wb') as f:
    pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Draw 30,000 synthetic records from the spanning-tree model (about 13 s in the
# recorded run) and show the resulting DataFrame.
synth = model.synthetic_data(rows=30_000)
sdf = synth.df
sdf

Unnamed: 0,age,workclass,fnlwgt,education,education.num,marital.status,occupation,relationship,race,sex,capital.gain,capital.loss,hours.per.week,native.country,income
0,56,4,677,15,9,0,1,3,4,1,0,15,14,39,1
1,13,6,2215,15,9,0,0,1,4,0,0,0,39,39,0
2,10,4,4434,15,9,1,3,0,4,0,0,0,49,39,0
3,29,4,16475,15,9,0,5,4,3,0,0,0,39,30,0
4,10,4,9546,15,9,1,4,0,4,1,0,0,49,39,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29995,50,4,20118,15,9,0,10,4,4,0,0,0,39,39,0
29996,18,4,20671,15,9,1,3,0,4,1,0,0,44,39,1
29997,20,4,12861,9,12,1,4,0,4,0,0,0,59,39,0
29998,19,4,21145,8,9,0,10,1,4,0,0,0,34,36,0
