In [251]:
import pandas as pd
import numpy as np
dataset = pd.read_csv("list_attr_celeba.txt", sep = ' ', header=1, skiprows = 0)

data = np.vstack((dataset['Male'], dataset["Young"], dataset["Mustache"], (-dataset["No_Beard"]),\
                                                                dataset["Bald"]))

# np.where(data > 0, data, 0)
data=np.maximum(data, 0)
print("0:Male, 1:Young, 2: Mustache, 3:Beard, 4:Bald")
print(data.sum(axis=1))

0:Male, 1:Young, 2: Mustache, 3:Beard, 4:Bald
[ 84434 156734   8417  33441   4547]


In [252]:
data_ = data[:,np.where(data[0,:]==0)[0]]
print(data_.shape)
data_1 = data_[:,np.where(data_[1,:]==0)[0]]
print(data_1.shape)
data_2 = data_1[:,np.where(data_1[2,:]==0)[0]]
print(data_2.shape)
data_3 = data_2[:,np.where(data_2[3,:]==0)[0]]
print(data_3.shape)
print(len(np.where(data_3[4,:]==0)[0]))

(5, 118165)
(5, 14878)
(5, 14876)
(5, 14835)
14831


In [253]:
data.shape

(5, 202599)

In [46]:
import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

## Training

In [60]:


n_steps = int(data.shape[1]/1000)



# clear the param store in case we're in a REPL
pyro.clear_param_store()


def model(data):
    # sample f from the beta prior
    laten_sex = pyro.param('latent_sex', torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    
    laten_mus_1 = pyro.param("latent_female_mustache", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_mus_2 = pyro.param("latent_male_mustache", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_mus = [laten_mus_1,laten_mus_2]
    
    laten_beard_1 = pyro.param("latent_no_mustache_beard", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_beard_2 = pyro.param("latent_mustache_beard", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_beard = [laten_beard_1,laten_beard_2]

    
    laten_young = pyro.param('latent_young', torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    
    laten_bald_1 = pyro.param("latent_female_old_bald", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_bald_2 = pyro.param("latent_female_young_bald", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_bald_3 = pyro.param("latent_male_old_bald", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_bald_4 = pyro.param("latent_male_young_bald", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_bald = [[laten_bald_1,laten_bald_2],[laten_bald_3, laten_bald_4]]
    
    
    # loop over the observed data
    for i in pyro.plate("data_loop", data.shape[1]):
        # observe datapoint i using the bernoulli likelihood
        sex = pyro.sample('sex_{}'.format(i), dist.Bernoulli(laten_sex),obs = data[0,i])        
        young = pyro.sample('young_{}'.format(i), dist.Bernoulli(laten_young),obs = data[1,i])        

        mus = pyro.sample("mustache_{}".format(i), dist.Bernoulli(laten_mus[sex.long()]), obs=data[2,i])
        beard = pyro.sample("beard_{}".format(i), dist.Bernoulli(laten_beard[mus.long()]), obs=data[3,i])
        
        
        bald = pyro.sample("bald_{}".format(i), dist.Bernoulli(laten_bald[sex.long()][young.long()]), \
                                                             obs=data[4,i])
        


def guide(data):
    pass

store_p = []                       
# setup the optimizer
adam_params = {"lr": 0.1, "betas": (0.80, 0.999)}
optimizer = Adam(adam_params)
pyro.clear_param_store()
# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# do gradient steps
for step in range(n_steps):
    rand_list = np.random.randint(low = 0, high=data.shape[1], size=1000)
#     svi.step(data[:,rand_list.tolist()])
    svi.step(torch.tensor(data[:,rand_list.tolist()]).float())
    if(step == 50):
        adam_params = {"lr": 0.01, "betas": (0.80, 0.999)}
        optimizer = Adam(adam_params)
    if(step == 100):
        adam_params = {"lr": 0.0005, "betas": (0.80, 0.999)}
        optimizer = Adam(adam_params)
    if(step %1 ==0):
        print('step: {}'.format(step))
#         for i in pyro.params.param_store.keys():
#             store_p[i] = pyro.params(i)
        p = pyro.get_param_store()
        store_p.append(p) 
        for (k,v) in p.items():
            print(k+": {:.3f}".format(v.item())) 




step: 0
latent_sex: 0.475
latent_female_mustache: 0.475
latent_male_mustache: 0.475
latent_no_mustache_beard: 0.475
latent_mustache_beard: 0.525
latent_young: 0.525
latent_female_old_bald: 0.475
latent_female_young_bald: 0.475
latent_male_old_bald: 0.475
latent_male_young_bald: 0.475
step: 1
latent_sex: 0.451
latent_female_mustache: 0.450
latent_male_mustache: 0.450
latent_no_mustache_beard: 0.450
latent_mustache_beard: 0.550
latent_young: 0.550
latent_female_old_bald: 0.450
latent_female_young_bald: 0.450
latent_male_old_bald: 0.450
latent_male_young_bald: 0.450
step: 2
latent_sex: 0.431
latent_female_mustache: 0.426
latent_male_mustache: 0.426
latent_no_mustache_beard: 0.426
latent_mustache_beard: 0.574
latent_young: 0.574
latent_female_old_bald: 0.425
latent_female_young_bald: 0.426
latent_male_old_bald: 0.426
latent_male_young_bald: 0.426
step: 3
latent_sex: 0.415
latent_female_mustache: 0.402
latent_male_mustache: 0.402
latent_no_mustache_beard: 0.402
latent_mustache_beard: 0.597


step: 29
latent_sex: 0.424
latent_female_mustache: 0.095
latent_male_mustache: 0.128
latent_no_mustache_beard: 0.137
latent_mustache_beard: 0.898
latent_young: 0.785
latent_female_old_bald: 0.094
latent_female_young_bald: 0.095
latent_male_old_bald: 0.139
latent_male_young_bald: 0.098
step: 30
latent_sex: 0.426
latent_female_mustache: 0.091
latent_male_mustache: 0.125
latent_no_mustache_beard: 0.136
latent_mustache_beard: 0.902
latent_young: 0.784
latent_female_old_bald: 0.091
latent_female_young_bald: 0.091
latent_male_old_bald: 0.136
latent_male_young_bald: 0.095
step: 31
latent_sex: 0.424
latent_female_mustache: 0.087
latent_male_mustache: 0.123
latent_no_mustache_beard: 0.135
latent_mustache_beard: 0.905
latent_young: 0.783
latent_female_old_bald: 0.087
latent_female_young_bald: 0.087
latent_male_old_bald: 0.133
latent_male_young_bald: 0.091
step: 32
latent_sex: 0.421
latent_female_mustache: 0.084
latent_male_mustache: 0.120
latent_no_mustache_beard: 0.134
latent_mustache_beard: 0.

step: 58
latent_sex: 0.423
latent_female_mustache: 0.040
latent_male_mustache: 0.098
latent_no_mustache_beard: 0.130
latent_mustache_beard: 0.946
latent_young: 0.775
latent_female_old_bald: 0.040
latent_female_young_bald: 0.040
latent_male_old_bald: 0.113
latent_male_young_bald: 0.047
step: 59
latent_sex: 0.426
latent_female_mustache: 0.039
latent_male_mustache: 0.098
latent_no_mustache_beard: 0.130
latent_mustache_beard: 0.947
latent_young: 0.775
latent_female_old_bald: 0.039
latent_female_young_bald: 0.039
latent_male_old_bald: 0.113
latent_male_young_bald: 0.047
step: 60
latent_sex: 0.434
latent_female_mustache: 0.039
latent_male_mustache: 0.098
latent_no_mustache_beard: 0.131
latent_mustache_beard: 0.949
latent_young: 0.775
latent_female_old_bald: 0.038
latent_female_young_bald: 0.039
latent_male_old_bald: 0.112
latent_male_young_bald: 0.046
step: 61
latent_sex: 0.433
latent_female_mustache: 0.038
latent_male_mustache: 0.099
latent_no_mustache_beard: 0.131
latent_mustache_beard: 0.

step: 87
latent_sex: 0.427
latent_female_mustache: 0.024
latent_male_mustache: 0.098
latent_no_mustache_beard: 0.127
latent_mustache_beard: 0.963
latent_young: 0.773
latent_female_old_bald: 0.024
latent_female_young_bald: 0.024
latent_male_old_bald: 0.109
latent_male_young_bald: 0.032
step: 88
latent_sex: 0.428
latent_female_mustache: 0.023
latent_male_mustache: 0.098
latent_no_mustache_beard: 0.127
latent_mustache_beard: 0.963
latent_young: 0.773
latent_female_old_bald: 0.023
latent_female_young_bald: 0.023
latent_male_old_bald: 0.108
latent_male_young_bald: 0.032
step: 89
latent_sex: 0.424
latent_female_mustache: 0.023
latent_male_mustache: 0.098
latent_no_mustache_beard: 0.128
latent_mustache_beard: 0.963
latent_young: 0.776
latent_female_old_bald: 0.023
latent_female_young_bald: 0.023
latent_male_old_bald: 0.108
latent_male_young_bald: 0.031
step: 90
latent_sex: 0.415
latent_female_mustache: 0.023
latent_male_mustache: 0.097
latent_no_mustache_beard: 0.128
latent_mustache_beard: 0.

step: 116
latent_sex: 0.420
latent_female_mustache: 0.016
latent_male_mustache: 0.100
latent_no_mustache_beard: 0.124
latent_mustache_beard: 0.965
latent_young: 0.771
latent_female_old_bald: 0.016
latent_female_young_bald: 0.016
latent_male_old_bald: 0.120
latent_male_young_bald: 0.025
step: 117
latent_sex: 0.420
latent_female_mustache: 0.016
latent_male_mustache: 0.100
latent_no_mustache_beard: 0.125
latent_mustache_beard: 0.965
latent_young: 0.769
latent_female_old_bald: 0.016
latent_female_young_bald: 0.016
latent_male_old_bald: 0.120
latent_male_young_bald: 0.025
step: 118
latent_sex: 0.424
latent_female_mustache: 0.016
latent_male_mustache: 0.101
latent_no_mustache_beard: 0.125
latent_mustache_beard: 0.965
latent_young: 0.768
latent_female_old_bald: 0.016
latent_female_young_bald: 0.016
latent_male_old_bald: 0.121
latent_male_young_bald: 0.024
step: 119
latent_sex: 0.420
latent_female_mustache: 0.015
latent_male_mustache: 0.101
latent_no_mustache_beard: 0.125
latent_mustache_beard

step: 145
latent_sex: 0.419
latent_female_mustache: 0.012
latent_male_mustache: 0.100
latent_no_mustache_beard: 0.132
latent_mustache_beard: 0.972
latent_young: 0.771
latent_female_old_bald: 0.012
latent_female_young_bald: 0.012
latent_male_old_bald: 0.111
latent_male_young_bald: 0.024
step: 146
latent_sex: 0.413
latent_female_mustache: 0.012
latent_male_mustache: 0.099
latent_no_mustache_beard: 0.132
latent_mustache_beard: 0.972
latent_young: 0.769
latent_female_old_bald: 0.012
latent_female_young_bald: 0.012
latent_male_old_bald: 0.110
latent_male_young_bald: 0.024
step: 147
latent_sex: 0.406
latent_female_mustache: 0.011
latent_male_mustache: 0.098
latent_no_mustache_beard: 0.132
latent_mustache_beard: 0.972
latent_young: 0.770
latent_female_old_bald: 0.012
latent_female_young_bald: 0.012
latent_male_old_bald: 0.110
latent_male_young_bald: 0.024
step: 148
latent_sex: 0.405
latent_female_mustache: 0.011
latent_male_mustache: 0.098
latent_no_mustache_beard: 0.131
latent_mustache_beard

step: 174
latent_sex: 0.406
latent_female_mustache: 0.009
latent_male_mustache: 0.098
latent_no_mustache_beard: 0.130
latent_mustache_beard: 0.973
latent_young: 0.772
latent_female_old_bald: 0.009
latent_female_young_bald: 0.009
latent_male_old_bald: 0.108
latent_male_young_bald: 0.022
step: 175
latent_sex: 0.404
latent_female_mustache: 0.009
latent_male_mustache: 0.098
latent_no_mustache_beard: 0.130
latent_mustache_beard: 0.974
latent_young: 0.772
latent_female_old_bald: 0.009
latent_female_young_bald: 0.009
latent_male_old_bald: 0.109
latent_male_young_bald: 0.022
step: 176
latent_sex: 0.404
latent_female_mustache: 0.009
latent_male_mustache: 0.098
latent_no_mustache_beard: 0.131
latent_mustache_beard: 0.973
latent_young: 0.772
latent_female_old_bald: 0.009
latent_female_young_bald: 0.009
latent_male_old_bald: 0.111
latent_male_young_bald: 0.022
step: 177
latent_sex: 0.411
latent_female_mustache: 0.009
latent_male_mustache: 0.098
latent_no_mustache_beard: 0.131
latent_mustache_beard

In [61]:

all_param = pyro.get_param_store()
print(all_param)
all_param.save('./pyro_params')

<pyro.params.param_store.ParamStoreDict object at 0x7fdb16885130>


## Testing

In [62]:
graph_params = pyro.params.param_store.ParamStoreDict()
graph_params.load('./pyro_params')

In [63]:
for (k,v) in graph_params.items():
    print(k+": {:.3f}".format(v.item())) 

latent_sex: 0.402
latent_female_mustache: 0.007
latent_male_mustache: 0.100
latent_no_mustache_beard: 0.133
latent_mustache_beard: 0.971
latent_young: 0.776
latent_female_old_bald: 0.008
latent_female_young_bald: 0.007
latent_male_old_bald: 0.121
latent_male_young_bald: 0.021


In [116]:
from pyro import poutine
from pyro.infer import config_enumerate, infer_discrete

def determined_model():
    laten_sex = graph_params.get_param('latent_sex')

    laten_mus_1 = graph_params.get_param("latent_female_mustache")
    laten_mus_2 = graph_params.get_param("latent_male_mustache")
    laten_mus = [laten_mus_1, laten_mus_2]

    laten_beard_1 = graph_params.get_param("latent_no_mustache_beard")
    laten_beard_2 = graph_params.get_param("latent_mustache_beard")
    laten_beard = [laten_beard_1, laten_beard_2]

    laten_young = graph_params.get_param('latent_young')

    laten_bald_1 = graph_params.get_param("latent_female_old_bald")
    laten_bald_2 = graph_params.get_param("latent_female_young_bald")
    laten_bald_3 = graph_params.get_param("latent_male_old_bald")
    laten_bald_4 = graph_params.get_param("latent_male_young_bald")
    laten_bald = [[laten_bald_1, laten_bald_2], [laten_bald_3, laten_bald_4]]

    with pyro.plate("a_plate", size=1, dim=-2):
        sex = pyro.sample('sex', dist.Bernoulli(laten_sex))
        young = pyro.sample('young', dist.Bernoulli(laten_young))
        #         with pyro.plate("b_plate", size=1, dim=-1):

        mus = pyro.sample("mustache", dist.Bernoulli(laten_mus[sex.long()]))
        beard = pyro.sample("beard", dist.Bernoulli(laten_beard[mus.long()]))

        bald = pyro.sample("bald", dist.Bernoulli(laten_bald[sex.long()][young.long()]))

        
def make_log_joint(model):

    def _log_joint(data, *args, **kwargs):

        conditioned_model = poutine.condition(model, data=data)

        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)

        return trace.log_prob_sum()

    return _log_joint

 

scale_log_joint = make_log_joint(determined_model)

print(scale_log_joint({\
                       "sex": torch.tensor(0.0),
                       "young": torch.tensor(1.),
                       "mustache": torch.tensor(0.0),
                       "beard": torch.tensor(1.0),
                       "bald":torch.tensor(.0)\
}).exp())


tensor(0.0608, grad_fn=<ExpBackward>)


### Compute joint probability learned by Pyro

In [65]:
from itertools import product

In [66]:
perm_array = np.array(list(product([0, 1], repeat=5))).astype(float)

def joint_prob_graph(arr):
    prob = scale_log_joint({\
                       "sex": torch.tensor(arr[0]),
                       "young": torch.tensor(arr[1]),
                       "mustache": torch.tensor(arr[2]),
                       "beard": torch.tensor(arr[3]),
                       "bald":torch.tensor(arr[4])\
                           }).exp()
    return prob.data.numpy()

In [67]:
perm_array

array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 1.],
       [0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 1.],
       [0., 0., 1., 1., 0.],
       [0., 0., 1., 1., 1.],
       [0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 1.],
       [0., 1., 0., 1., 0.],
       [0., 1., 0., 1., 1.],
       [0., 1., 1., 0., 0.],
       [0., 1., 1., 0., 1.],
       [0., 1., 1., 1., 0.],
       [0., 1., 1., 1., 1.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 1.],
       [1., 0., 0., 1., 0.],
       [1., 0., 0., 1., 1.],
       [1., 0., 1., 0., 0.],
       [1., 0., 1., 0., 1.],
       [1., 0., 1., 1., 0.],
       [1., 0., 1., 1., 1.],
       [1., 1., 0., 0., 0.],
       [1., 1., 0., 0., 1.],
       [1., 1., 0., 1., 0.],
       [1., 1., 0., 1., 1.],
       [1., 1., 1., 0., 0.],
       [1., 1., 1., 0., 1.],
       [1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1.]])

In [68]:
learned_joint_prob_array = []

for arr in perm_array:
    p = joint_prob_graph(arr)
    learned_joint_prob_array.append(p)

learned_joint_prob_array = np.array(learned_joint_prob_array)
learned_joint_prob_array

array([1.14329420e-01, 8.71653105e-04, 1.75486776e-02, 1.33791979e-04,
       2.77513555e-05, 2.11577695e-07, 9.42341962e-04, 7.18446135e-06,
       3.96290158e-01, 2.93919126e-03, 6.08274600e-02, 4.51143020e-04,
       9.61921180e-05, 7.13434404e-07, 3.26635827e-03, 2.42258140e-05,
       6.17586492e-02, 8.53769816e-03, 9.47947277e-03, 1.31047033e-03,
       2.25985611e-04, 3.12409186e-05, 7.67370531e-03, 1.06083570e-03,
       2.38604020e-01, 5.00798401e-03, 3.66238630e-02, 7.68686632e-04,
       8.73093497e-04, 1.83250822e-05, 2.96472956e-02, 6.22257675e-04])

### Joint probability derived from dataset 

In [256]:

print(data.shape)
ori_joint_prob_array = []

for arr in perm_array:
    p = ((data.T == arr).sum(axis=1) == 5).mean()
    ori_joint_prob_array.append(p)

ori_joint_prob_array = np.array(ori_joint_prob_array)
ori_joint_prob_array

(5, 202599)


array([7.32037177e-02, 1.97434341e-05, 2.02370199e-04, 0.00000000e+00,
       4.93585852e-06, 0.00000000e+00, 4.93585852e-06, 0.00000000e+00,
       5.09262139e-01, 6.41661607e-05, 4.78778276e-04, 0.00000000e+00,
       4.93585852e-06, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
       8.62985503e-02, 1.04344049e-02, 3.10317425e-02, 4.33861964e-03,
       5.57752013e-04, 6.91020193e-05, 1.78332568e-02, 2.38401966e-03,
       1.52745078e-01, 1.75222977e-03, 8.60764367e-02, 2.54690300e-03,
       4.83714135e-04, 3.94868681e-05, 1.93683088e-02, 7.94673221e-04])

In [257]:
np.mean((ori_joint_prob_array - learned_joint_prob_array)**2)**0.5

0.030382503760488736

In [265]:
data_idx = np.array([21539, 181360, 83908, 181011, 114272, 161427, 199746, 10031, 49592, 157897, 38620, 22196, 86403, 197756, 161481, 9002, 85259, 104188, 196177, 141496, 16843, 73087, 4804, 94515, 19522, 152652, 115100, 70592, 70578, 68415, 12675, 133427, 3151, 135102, 114110, 135382, 95740, 140859, 81399, 86284, 38691, 101071, 134440, 32192, 122815, 110197, 115163, 20599, 173023, 84235, 83572, 144698, 36380, 21443, 172706, 35416, 49685, 98247, 198043, 2056, 149020, 60144, 18131, 163224, 132132, 71205, 92310, 70023, 66795, 72539, 73993, 7068, 82911, 40982, 124610, 2081, 155809, 56274, 54108, 6079, 6626, 47890, 112368, 24497, 62146, 157042, 190450, 46127, 4325, 38917, 172109, 34773, 91410, 179918, 16196, 73302, 35816, 202003, 178307, 96677, 64352, 35413, 69707, 50007, 58694, 10980, 53950, 119681, 61876, 18611, 138406, 102704, 192599, 23571, 179517, 165757, 189680, 187864, 93090, 176340, 97407, 77200, 63449, 110102, 132454, 19095, 22905, 48899, 176957, 160693, 3227, 195579, 2685, 171494, 147815, 39802, 158382, 162579, 169245, 81481, 110847, 41481, 95272, 21859, 33339, 19524, 102600, 89032, 42824, 108359, 14826, 191968, 5898, 97611, 20204, 49843, 183820, 146558, 161901, 63851, 169444, 180797, 103866, 89113, 200329, 135682, 120079, 168734, 158342, 30020, 112569, 176595, 106439, 147433, 177094, 8351, 158338, 82116, 137144, 65549, 149, 170654, 22912, 67141, 177730, 166608, 146454, 49624, 156186, 192849, 28845, 55732, 191800, 182427, 118991, 5354, 86976, 54022, 196323, 51717, 179005, 143484, 168164, 71254, 143524, 111918, 112604, 87430, 89956, 173529, 96549, 81757, 136462, 86539, 98169, 174233, 85220, 187089, 3318, 184071, 118731, 45358, 46335, 59164, 149529, 196200, 165574, 14287, 167086, 82770, 80299, 46710, 202564, 103927, 170894, 175381, 47169, 149139, 136400, 54921, 123457, 88074, 154662, 179170, 115389, 24816, 176803, 77192, 57601, 84959, 181524, 114881, 41041, 2213, 168550, 150623, 185164, 107555, 127818, 98730, 154693, 183876, 132467, 86714, 201642, 123159, 33618, 182085, 18770, 110178, 67712, 133876, 121568, 148456, 119559, 182076, 69138, 27851, 21238, 180492, 14264, 164472, 120201, 172330, 19203, 30539, 24828, 148707, 191196, 26873, 174138, 44288, 12769, 14462, 100529, 20338, 15574, 56350, 120361, 109870, 129212, 96511, 65695, 184940, 35737, 45567, 41037, 189776, 126533, 171305, 71365, 129825, 810, 174784, 14805, 143004, 95049, 147460, 11425, 5977, 72303, 33141, 121656, 88165, 7664, 126751, 115293, 184446, 190754, 144306, 194823, 98668, 16934, 47406, 141502, 75770, 83890, 125673, 36265, 102105, 130543, 201240, 83075, 155446, 169892, 193042, 84897, 106273, 152741, 968, 56995, 81246, 82688, 193764, 109834, 188245, 138403, 157939, 107529, 123156, 160631, 177280, 135429, 150521, 181613, 21186, 127417, 108386, 3179, 49229, 130064, 127774, 9992, 107809, 184363, 12668, 60031, 115176, 162698, 37445, 28045, 66728, 41301, 113402, 94151, 154582, 184672, 127703, 127777, 94748, 154428, 143721, 120889, 17460, 105160, 200238, 48760, 153916, 154587, 49636, 159870, 79559, 122870, 156862, 88399, 33335, 112468, 199535, 35909, 181262, 94978, 57736, 118260, 41059, 124797, 70375, 80077, 16972, 46931, 170694, 195221, 62568, 93324, 166380, 112115, 109720, 13299, 84853, 123889, 131104, 15445, 30163, 30551, 14506, 120820, 201418, 117097, 37613, 77712, 130428, 87121, 23923, 123665, 50986, 132961, 101187, 20496, 197054, 10300, 2325, 155264, 48543, 101560, 188814, 176550, 149088, 91143, 110917, 172466, 131650, 148782, 200681, 15471, 20968, 74521, 129253, 62964, 61640, 146384, 127719, 81422, 37412, 104021, 129474, 184792, 143219, 65961, 181632, 52879, 10256, 86975, 170096, 45678, 57816, 15027, 148154, 22703, 147108, 21318, 9124, 23093, 93067, 132064, 102302, 11180, 124403, 46662, 2045, 78063, 161104, 174124, 195846, 35758, 155471, 11182, 184258, 67621, 139012, 79125, 62695, 172380, 12891, 189138, 119310, 167162, 8518, 46575, 69677, 183860, 72308, 55931, 2628, 13090, 150721, 134847, 157642, 18652, 190635, 70386, 70828, 132802, 29995, 95859, 91129, 77539, 122995, 16253, 13610, 101683, 67890, 78268, 15953, 20249, 186992, 152836, 22263, 158978, 185524, 170146, 65750, 133589, 28574, 72509, 73457, 115196, 22023, 76739, 172958, 30575, 25595, 114143, 74363, 116644, 159784, 164254, 92803, 46686, 33041, 31223, 94687, 26475, 141529, 52249, 36643, 89058, 27050, 100696, 12670, 157137, 192660, 198155, 68189, 139838, 41399, 27964, 143128, 72166, 138882, 116957, 73398, 97526, 75224, 172477, 36482, 142827, 6437, 193092, 29338, 135043, 158579, 103204, 116119, 126485, 75344, 160612, 180627, 177720, 86946, 107091, 14880, 47093, 61884, 190753, 111831, 48198, 110277, 125619, 123065, 17459, 110051, 148026, 115789, 178140, 62409, 23900, 9430, 108284, 96014, 40852, 58229, 87511, 79141, 4530, 118430, 98578, 165141, 47474, 96746, 78635, 100422, 119960, 145751, 158022, 169564, 111467, 177399, 102572, 11247, 122539, 21684, 78069, 149657, 48073, 61868, 26799, 11883, 118926, 69067, 138389, 164806, 15818, 23119, 197995, 146170, 94737, 132346, 92392, 57747, 154938, 34965, 186974, 88863, 67047, 171082, 77792, 69194, 94836, 191022, 37399, 15330, 52420, 201158, 23691, 117771, 149730, 94017, 46226, 47186, 164579, 189552, 193148, 11207, 148541, 183172, 157900, 175947, 21722, 12931, 136280, 118171, 116795, 23705, 146962, 55815, 81405, 130295, 200187, 167629, 8120, 37321, 53913, 117807, 201273, 87667, 102515, 171221, 153686, 79456, 23170, 144408, 118432, 156262, 174457, 58871, 44630, 159059, 159033, 198327, 161747, 188828, 105211, 116878, 201161, 163931, 94446, 30750, 25992, 44657, 135495, 119619, 93350, 49480, 29073, 140680, 129789, 73766, 12411, 176916, 87627, 77229, 1969, 61759, 183091, 181256, 49190, 187665, 85448, 93247, 100676, 113960, 48958, 39724, 200720, 137799, 179954, 200645, 175698, 87655, 29371, 22309, 182684, 74541, 56067, 68356, 169060, 80782, 22495, 100515, 5154, 155055, 7922, 82249, 42351, 181246, 175180, 111484, 134274, 78121, 32558, 60221, 63398, 30199, 195246, 163153, 940, 98560, 128195, 111775, 32495, 42514, 185558, 61942, 53547, 58346, 42923, 67654, 9350, 137034, 90299, 35046, 157235, 107073, 42096, 132069, 22310, 66320, 163402, 80630, 163372, 181458, 194625, 193908, 101572, 85541, 121226, 200525, 80176, 112310, 15875, 181720, 30532, 16522, 20799, 199375, 14407, 96321, 5949, 202452, 74658, 63078, 165905, 112, 4054, 27094, 172154, 94441, 178801, 191503, 152385, 108224, 44798, 141255, 99739, 83018, 49333, 191459, 186963, 20975, 110996, 8527, 28107, 81883, 84237, 79681, 119794, 112865, 188053, 149417, 119612, 49348, 183604, 7524, 142047, 91937, 197068, 162978, 9498, 152793, 60575, 179715, 132062, 42080, 115089, 120580, 28808, 50975, 96924, 167948, 168584, 97492, 106527, 198241, 90830, 149293, 38301, 27950, 73608, 130802, 189464, 5263, 146120, 148576, 179933, 83536, 58715, 65355, 192947, 132782, 25185, 110008, 84537, 108584, 176066, 78141, 77994, 50122, 59743, 29231, 190607, 18143, 153768, 13825, 185543, 32110, 154267, 113565, 189936, 117786, 186184, 93085, 110878, 57902, 152756, 63598, 55715, 33336, 173602, 3009, 32128, 172025, 74404, 21596, 188760, 194435, 26959, 167864, 139811, 50771, 28402, 174428, 4346, 99548, 39234, 153756, 165957, 87775, 82207, 127675, 86387, 62110, 100738, 120739, 183457, 168730, 114921, 183286, 457, 148296, 171077, 11936, 63075, 128208, 72406, 106310, 194269, 182833, 121817, 133687, 39066, 96681, 90240, 138587, 36516, 138074, 39382, 113396, 62535, 36261, 101511, 165005, 113723, 191906, 91397, 9061, 73232, 170262, 173904, 111738, 66429, 15255])
data_select = data[:,data_idx]

data_select[0,:] = 1
data_select[1,:] = 1
data_select[2,:] = 1


print(data_select.shape)
select_ori_joint_prob_array = []

for arr in perm_array:
    p = ((data_select.T == arr).sum(axis=1) == 5).mean()
    select_ori_joint_prob_array.append(p)

select_ori_joint_prob_array = np.array(select_ori_joint_prob_array)

def prob_edited_attr(prob_a):
#     return prob_a[np.where(perm_array[:,1] == 1 )[0]]
    return prob_a[np.where(np.logical_and(perm_array[:,0] == 1 , perm_array[:,1] == 1, perm_array[:,2] == 1))[0]]
# print(prob_edited_attr(generated_joint_prob_array))
# prob_edited_attr(generated_joint_prob_array).sum()

select_ori_conditioned_prob_array = prob_edited_attr(select_ori_joint_prob_array) / prob_edited_attr(select_ori_joint_prob_array).sum()
ori_conditoned_prob_array = prob_edited_attr(ori_joint_prob_array) / prob_edited_attr(ori_joint_prob_array).sum()

np.mean((ori_conditoned_prob_array - select_ori_conditioned_prob_array)**2)**0.5

(5, 1000)


0.37114874211652576

In [None]:


ori_joint_prob_array = np.array(ori_joint_prob_array)

[np.where(np.logical_and(perm_array[:,0] == 0 , perm_array[:,1] == 0))[0]]

### Compare the conditioned probability from the generated dataset

In [238]:
res = np.load("./ori_list.npy").T
print(res.shape)
# att_name = np.array(["Bald",
# "Bangs",
# "Black_Hair",
# "Blond_Hair",
# "Brown_Hair",
# "Bushy_Eyebrows",
# "Eyeglasses",
# "Male",
# "Mouth_Slightly_Open",
# "Mustache",
# "No_Beard",
# "Pale_Skin",
# "Young"])
# att_name_list = np.array([7,12,9,10,0])
# print(att_name[att_name_list])

# res = res.squeeze()
# data = []
# for i in range(att_name_list.shape[0]):
#     if(att_name[att_name_list[i]] == "No_Beard"):
#         data.append(-1*res[:,i])
#     else:
#         data.append(res[:,i])
data = np.array(res).astype(int)
data = np.maximum(data, 0)
print(data)

perm_array = np.array(list(product([0, 1], repeat=5)))
generated_joint_prob_array = []

for arr in perm_array:
    p = ((data.T == arr).sum(axis=1) == 5).mean()
    generated_joint_prob_array.append(p)

generated_joint_prob_array = np.array(generated_joint_prob_array)
generated_joint_prob_array

(5, 5000)
[[0 0 0 ... 0 0 0]
 [1 1 1 ... 1 1 1]
 [0 0 0 ... 0 0 0]
 [1 1 1 ... 1 1 1]
 [0 0 0 ... 0 0 0]]


array([0.   , 0.   , 0.077, 0.   , 0.   , 0.   , 0.   , 0.   , 0.   ,
       0.   , 0.477, 0.   , 0.   , 0.   , 0.001, 0.   , 0.   , 0.   ,
       0.111, 0.022, 0.   , 0.   , 0.019, 0.002, 0.   , 0.   , 0.262,
       0.006, 0.   , 0.   , 0.023, 0.   ])

In [241]:
def prob_edited_attr(prob_a):
#     return prob_a[np.where(perm_array[:,1] == 1 )[0]]
    return prob_a[np.where(np.logical_and(perm_array[:,0] == 0 , perm_array[:,1] == 0))[0]]
print(prob_edited_attr(generated_joint_prob_array))
prob_edited_attr(generated_joint_prob_array).sum()

[0.    0.    0.077 0.    0.    0.    0.    0.   ]


0.077

In [227]:
generated_conditoned_prob_array = prob_edited_attr(generated_joint_prob_array) / prob_edited_attr(generated_joint_prob_array).sum()
ori_conditoned_prob_array = prob_edited_attr(ori_joint_prob_array) / prob_edited_attr(ori_joint_prob_array).sum()

In [231]:
ori_conditoned_prob_array

array([9.96840973e-01, 2.68853341e-04, 2.75574674e-03, 0.00000000e+00,
       6.72133351e-05, 0.00000000e+00, 6.72133351e-05, 0.00000000e+00])

In [229]:
ori_conditoned_prob_array 

array([9.96840973e-01, 2.68853341e-04, 2.75574674e-03, 0.00000000e+00,
       6.72133351e-05, 0.00000000e+00, 6.72133351e-05, 0.00000000e+00])

In [230]:
np.mean((generated_conditoned_prob_array - ori_conditoned_prob_array)**2)**0.5

0.08296898687815281