In [4]:
import numpy as np
import pandas as pd
import random
import estimator as est

np.set_printoptions(precision=3)

d = 10
dist_type = "uni"
est_type = "sw"
N=100000
total_repeat = 1

filename = 'data/eps_grid.csv'
temp = pd.read_csv(filename)
eps_grid = np.asarray(temp['eps_grid'])

filename = 'data/%s_q_%d.csv' %(dist_type, d)
temp = pd.read_csv(filename)
x_grid = np.asarray(temp['x_grid'])
x_q = np.asarray(temp['x_q'])

def WassDist(p1,p2):
    print(p1)
    cmf1 = Pmf2Cmf(p1)
    print(p1)
    cmf2 = Pmf2Cmf(p2)
    dist = np.sum(np.abs(cmf1-cmf2))
    return dist

def Pmf2Cmf(pmf):
    cmf = pmf
    for i in range(len(pmf)):
        cmf[i] = cmf[i-1]+cmf[i]
    return cmf

# generated data samples
idx_original = random.choices(np.arange(d),x_q,k=N)
sample_original = x_grid[idx_original]
print('%d data samples generated with %s distribution.\n' %(N, dist_type))

var_aaa = np.zeros(len(eps_grid))
var_est = np.zeros(len(eps_grid))

distance_aaa = np.zeros(len(eps_grid))
distance_est = np.zeros(len(eps_grid))

for i in range(len(eps_grid)):
    eps = eps_grid[i]
    print('eps=%.2f'%(eps))

    filename = 'data/%s_%s_M_%.2f_%d.csv'%(est_type,dist_type,eps,d)
    temp = pd.read_csv(filename)
    M_aaa = np.asarray(temp)[:,1:]

    filename = 'data/%s_M_%.2f_%d.csv'%(est_type,eps,d)
    temp = pd.read_csv(filename)
    M_est = np.asarray(temp)[:,1:]

    filename = 'data/%s_a_%.2f_%d.csv'%(est_type, eps,d)
    temp = pd.read_csv(filename)
    a_grid = np.asarray(temp['a_grid'])

    # compute the variances
    var_aaa[i] = est.M2Var(M_aaa,x_grid,x_grid,x_q)
    var_est[i] = est.M2Var(M_est,a_grid,x_grid,x_q)
    print('compute variances for eps=%.2f.'%(eps))

    # generate random pools for each value
    rand_num_aaa = np.zeros((N,d))
    rand_num_est = np.zeros((N,d))

    for j in range(d):
        rand_num_aaa[:,j] = random.choices(x_grid,M_aaa[:,j],k=N)
        rand_num_est[:,j] = random.choices(a_grid,M_est[:,j],k=N)
    print("random num pools generated.")

    temp1,temp2 = 0,0
    for k in range(total_repeat):
        idx_noise = random.choices(range(N),k=N)
        idx_pair = list(zip(idx_noise,idx_original))

        sample_perturbed = rand_num_aaa[tuple(zip(*idx_pair))]
        _,counts = np.unique(sample_perturbed,return_counts=True)
        x_q_aaa = counts/N
        temp1 = temp1 + WassDist(x_q_aaa,x_q)
        print('q(x)=', x_q)

        sample_perturbed = rand_num_est[tuple(zip(*idx_pair))]
        _,counts = np.unique(sample_perturbed,return_counts=True)
        x_q_est = est.EM(M_est,counts,eps)
        temp2 = temp2 + WassDist(x_q_est,x_q)

    distance_aaa[i] = temp1/total_repeat
    distance_est[i] = temp2/total_repeat
    print('compute average WassDist, total repeat=%d.\n' %(total_repeat))

print(distance_aaa)
print(distance_est)
results ={'eps_grid': eps_grid,
          'aaa_dist': distance_aaa, 
          '%s_dist'%(est_type):distance_est,
          'aaa_var': var_aaa,
          '%s_var'%(est_type): var_est} 
filename = 'data/results_%s_%s_%d' %(est_type,dist_type,d)
pd.DataFrame(results).to_csv(filename)
print('complete.')

100000 data samples generated with uni distribution.

eps=0.50
compute variances for eps=0.50.
random num pools generated.
[0.104 0.103 0.091 0.102 0.093 0.101 0.102 0.095 0.098 0.111]
[0.215 0.318 0.408 0.511 0.604 0.705 0.807 0.902 1.    1.111]
q(x)= [0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.  1.1]
[[0.082 0.097 0.101 0.101 0.105 0.116 0.087 0.092 0.1   0.118]]
[[0.164 0.195 0.202 0.203 0.21  0.232 0.175 0.183 0.2   0.236]]
compute average WassDist, total repeat=1.

eps=1.00
compute variances for eps=1.00.
random num pools generated.
[0.095 0.096 0.108 0.101 0.095 0.101 0.108 0.102 0.095 0.099]
[0.194 0.29  0.398 0.499 0.595 0.695 0.803 0.905 1.    1.099]
q(x)= [ 8.9 10.5 12.5 15.  18.1 21.9 26.5 32.  38.5 46.1]
[[0.1   0.103 0.097 0.092 0.103 0.102 0.1   0.107 0.096 0.1  ]]
[[0.199 0.207 0.194 0.185 0.206 0.204 0.2   0.215 0.191 0.199]]
compute average WassDist, total repeat=1.

eps=1.50
compute variances for eps=1.50.
random num pools generated.
[0.098 0.103 0.097 0.1   0.1   0.107 0.097 