In [11]:
# Load the data and libraries
import pandas as pd
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt

def laplace_mech(v, sensitivity, epsilon):
    """Laplace mechanism: return v plus Laplace(0, sensitivity / epsilon) noise."""
    noise_scale = sensitivity / epsilon
    return v + np.random.laplace(loc=0, scale=noise_scale)

def gaussian_mech(v, sensitivity, epsilon, delta):
    """Gaussian mechanism: return v plus N(0, sigma^2) noise, where
    sigma = sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon.

    NOTE(review): the textbook analysis of this calibration assumes
    epsilon < 1 — confirm callers respect that.
    """
    sigma = sensitivity * np.sqrt(2 * np.log(1.25 / delta)) / epsilon
    noise = np.random.normal(loc=0, scale=sigma)
    return v + noise

def gaussian_mech_vec(vec, sensitivity, epsilon, delta):
    """Gaussian mechanism applied element-wise.

    Each element of vec gets its own independent N(0, sigma^2) draw, with
    sigma = sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon.
    Returns a new list; vec itself is not modified.
    """
    def _perturb(value):
        sigma = sensitivity * np.sqrt(2 * np.log(1.25 / delta)) / epsilon
        return value + np.random.normal(loc=0, scale=sigma)

    return list(map(_perturb, vec))

def pct_error(orig, priv):
    """Absolute error of priv relative to orig, expressed as a percentage."""
    abs_err = np.abs(orig - priv)
    return abs_err / orig * 100.0

# Fetch the Adult dataset and strip the personally identifying columns
# before anything else in the notebook touches it.
adult = (
    pd.read_csv('https://github.com/jnear/cs3110-data-privacy/raw/main/homework/adult_with_pii.csv')
    .drop(columns=['Name', 'DOB', 'SSN', 'Zip'])
)

In [12]:
def two_way_marginal(col1, col2, epsilon, data=None):
    """Differentially private two-way marginal over (col1, col2).

    Counts every observed (col1, col2) combination and adds Laplace noise
    (sensitivity 1) to each count via laplace_mech. Negative noisy counts are
    clipped to 0, and the counts are normalized into a probability distribution.

    Args:
        col1, col2: column names to cross-tabulate.
        epsilon: privacy parameter passed to laplace_mech for every count.
        data: DataFrame to read from; defaults to the notebook-global `adult`.

    Returns:
        DataFrame with columns [col1, col2, 'probability'], one row per observed
        combination, with probabilities summing to 1 (empty if data is empty).

    NOTE(review): value_counts only lists combinations that occur in the data,
    so the set of returned rows is itself data-dependent and is not protected
    by the noise.
    """
    if data is None:
        data = adult
    counts = data[[col1, col2]].value_counts()
    noisy = counts.apply(lambda c: laplace_mech(c, 1, epsilon)).clip(lower=0)

    total = noisy.sum()
    if total > 0:
        probs = noisy / total
    elif len(noisy) > 0:
        # Every noisy count was clipped to 0: dividing would give NaN and break
        # later sampling, so fall back to a uniform distribution (post-processing).
        probs = pd.Series(1.0 / len(noisy), index=noisy.index)
    else:
        probs = noisy.astype(float)
    return probs.to_frame(name='probability').reset_index()

In [13]:
def _sample_conditional(given_values, given_col, target_col, marginal):
    """Draw one target_col value per entry of given_values from P(target_col | given_col) in marginal."""
    target_values = marginal[target_col].to_numpy()
    result = np.empty(len(given_values), dtype=target_values.dtype)
    # Distribution of target_col alone, used when a given value has no mass
    # in this marginal (absent combination or noisy counts clipped to 0).
    fallback = marginal.groupby(target_col)['probability'].sum()
    for value in pd.unique(given_values):
        mask = given_values == value
        rows = marginal[marginal[given_col] == value]
        weights = rows['probability'].to_numpy(dtype=float)
        total = weights.sum()
        if total > 0:
            choices, probs = rows[target_col].to_numpy(), weights / total
        else:
            choices, probs = fallback.index.to_numpy(), fallback.to_numpy() / fallback.sum()
        result[mask] = np.random.choice(choices, size=int(mask.sum()), p=probs)
    return result


def general_syn_dat(n, cols, epsilon):
    """Generate n rows of differentially private synthetic data over cols.

    Builds a noisy two-way marginal for each adjacent pair of columns,
    (cols[0], cols[1]), (cols[1], cols[2]), ..., and samples along that chain:
      * (cols[0], cols[1]) pairs are drawn jointly from the first marginal;
      * each later column is drawn conditioned on the previously generated
        column, using that pair's marginal.
    This keeps the pairwise correlation between adjacent columns.

    The privacy budget is split evenly across the len(cols) - 1 marginals.
    Each marginal reads every row, so by sequential composition the total
    cost is epsilon.

    Args:
        n: number of synthetic rows to generate.
        cols: column names (at least two), in chain order.
        epsilon: total privacy budget.

    Returns:
        DataFrame with one column per entry of cols, in the same order.

    Raises:
        ValueError: if fewer than two columns are given.
    """
    if len(cols) < 2:
        raise ValueError('general_syn_dat needs at least two columns')

    eps_per_marginal = epsilon / (len(cols) - 1)
    marginals = [two_way_marginal(a, b, eps_per_marginal) for a, b in zip(cols, cols[1:])]

    # Sample whole rows of the first marginal so cols[0] and cols[1] stay correlated.
    first = marginals[0]
    row_idx = np.random.choice(len(first), size=n, p=first['probability'].to_numpy())
    syn_dat = {
        cols[0]: first[cols[0]].to_numpy()[row_idx],
        cols[1]: first[cols[1]].to_numpy()[row_idx],
    }

    # Each further column depends on the one generated just before it.
    for prev_col, next_col, marginal in zip(cols[1:], cols[2:], marginals[1:]):
        syn_dat[next_col] = _sample_conditional(syn_dat[prev_col], prev_col, next_col, marginal)

    return pd.DataFrame(syn_dat)

In [14]:
adult.keys()  # Show the column names to help choose test columns

Index(['Age', 'Workclass', 'fnlwgt', 'Education', 'Education-Num',
       'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex',
       'Capital Gain', 'Capital Loss', 'Hours per week', 'Country', 'Target'],
      dtype='object')

In [15]:
adult.loc[:, 'Target'].value_counts()  # Category frequencies for Target, used to pick test categories

<=50K    24720
>50K      7841
Name: Target, dtype: int64

In [16]:
adult.loc[:, 'Education-Num'].value_counts()  # Category frequencies for Education-Num, used to pick test categories

9     10501
10     7291
13     5355
14     1723
11     1382
7      1175
12     1067
6       933
4       646
15      576
5       514
8       433
16      413
3       333
2       168
1        51
Name: Education-Num, dtype: int64

In [17]:
adult.loc[:, 'Marital Status'].value_counts()  # Category frequencies for Marital Status, used to pick test categories

Married-civ-spouse       14976
Never-married            10683
Divorced                  4443
Separated                 1025
Widowed                    993
Married-spouse-absent      418
Married-AF-spouse           23
Name: Marital Status, dtype: int64

In [30]:
# TEST CASES
# cols[0] must be numeric (its mean is compared). cats[i] is a category of
# cols[i + 1] whose row count is compared between real and synthetic data.
#
# Other configurations that have been used:
# cols = ['Age', 'Workclass', 'Occupation', 'Education']
# cats = ['State-gov', 'Sales', 'Bachelors']
#
# cols = ['Age', 'Race', 'Hours per week', 'Relationship']
# cats = ['Black', 40, 'Husband']

cols = ['Age', 'Target', 'Education-Num', 'Marital Status']
cats = ['>50K', 15, 'Divorced']

SIZE = len(adult)  # one synthetic row per real row (32561 for this dataset)
EPSILON = 1.0
SEED = 42
# Seeded so Restart & Run All reproduces the outputs; never seed noise for a real DP release.
np.random.seed(SEED)

synthetic_data = general_syn_dat(SIZE, cols, EPSILON)


def category_count(df, col, cat):
    """Number of rows of df whose value in column col equals cat."""
    return int((df[col] == cat).sum())


def assert_ratio(synth, real, tol, label):
    """Assert that synth / real lies strictly within (1 - tol, 1 + tol)."""
    assert real != 0, f'{label}: real value is 0, choose a category that occurs in the data'
    ratio = synth / real
    assert 1 - tol < ratio < 1 + tol, f'{label}: synthetic/real = {ratio:.3f}, outside ±{tol:.0%}'


synth1 = synthetic_data[cols[0]].mean()
synth2, synth3, synth4 = (category_count(synthetic_data, col, cat) for col, cat in zip(cols[1:], cats))

real1 = adult[cols[0]].mean()
real2, real3, real4 = (category_count(adult, col, cat) for col, cat in zip(cols[1:], cats))

print(f'{real1:.2f}', real2, real3, real4)

print(f'{synth1:.2f}', synth2, synth3, synth4)

assert_ratio(synth1, real1, 0.05, f'mean of {cols[0]}')
for col, cat, synth, real in zip(cols[1:], cats, (synth2, synth3, synth4), (real2, real3, real4)):
    assert_ratio(synth, real, 0.1, f'{col} == {cat!r}')

38.58 3124 15217 13193
38.45 3202 15373 13250
