In [1]:
import pandas as pd
import numpy as np
dataset = pd.read_csv("list_attr_celeba.txt", sep = ' ', header=1, skiprows = 0)

data = np.vstack((dataset['Male'], dataset["Young"], dataset["Mustache"], (-dataset["No_Beard"]),\
                                                                dataset["Bald"]))

# np.where(data > 0, data, 0)
data=np.maximum(data, 0)
print("0:Male, 1:Young, 2: Mustache, 3:Beard, 4:Bald")
print(data.sum(axis=1))

0:Male, 1:Young, 2: Mustache, 3:Beard, 4:Bald
[ 84434 156734   8417  33441   4547]


In [2]:
data_ = data[:,np.where(data[0,:]==0)[0]]
print(data_.shape)
data_1 = data_[:,np.where(data_[1,:]==0)[0]]
print(data_1.shape)
data_2 = data_1[:,np.where(data_1[2,:]==0)[0]]
print(data_2.shape)
data_3 = data_2[:,np.where(data_2[3,:]==0)[0]]
print(data_3.shape)
print(len(np.where(data_3[4,:]==0)[0]))

(5, 118165)
(5, 14878)
(5, 14876)
(5, 14835)
14831


In [3]:
data.shape

(5, 202599)

In [4]:
import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

n_steps = 200



# clear the param store in case we're in a REPL
pyro.clear_param_store()


def model(data):
    # sample f from the beta prior
    laten_sex = pyro.param('latent_sex', torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    
    laten_mus_1 = pyro.param("latent_female_mustache", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_mus_2 = pyro.param("latent_male_mustache", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_mus = [laten_mus_1,laten_mus_2]
    
    laten_beard_1 = pyro.param("latent_no_mustache_beard", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_beard_2 = pyro.param("latent_mustache_beard", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_beard = [laten_beard_1,laten_beard_2]

    
    laten_young = pyro.param('latent_young', torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    
    laten_bald_1 = pyro.param("latent_female_old_bald", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_bald_2 = pyro.param("latent_female_young_bald", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_bald_3 = pyro.param("latent_male_old_bald", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_bald_4 = pyro.param("latent_male_young_bald", torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    laten_bald = [[laten_bald_1,laten_bald_2],[laten_bald_3, laten_bald_4]]
    
    
    # loop over the observed data
    for i in pyro.plate("data_loop", data.shape[1]):
        # observe datapoint i using the bernoulli likelihood
        sex = pyro.sample('sex_{}'.format(i), dist.Bernoulli(laten_sex),obs = data[0,i])        
        young = pyro.sample('young_{}'.format(i), dist.Bernoulli(laten_young),obs = data[1,i])        

        mus = pyro.sample("mustache_{}".format(i), dist.Bernoulli(laten_mus[sex.long()]), obs=data[2,i])
        beard = pyro.sample("beard_{}".format(i), dist.Bernoulli(laten_beard[mus.long()]), obs=data[3,i])
        
        
        bald = pyro.sample("bald_{}".format(i), dist.Bernoulli(laten_bald[sex.long()][young.long()]), \
                                                             obs=data[4,i])
        


def guide(data):
    pass

store_p = []                       
# setup the optimizer
adam_params = {"lr": 0.1, "betas": (0.80, 0.999)}
optimizer = Adam(adam_params)
pyro.clear_param_store()
# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# do gradient steps
for step in range(n_steps):
    rand_list = np.random.randint(low = 0, high=data.shape[1], size=1000)
#     svi.step(data[:,rand_list.tolist()])
    svi.step(torch.tensor(data[:,rand_list.tolist()]).float())
    if(step == 60):
        adam_params = {"lr": 0.01, "betas": (0.80, 0.999)}
        optimizer = Adam(adam_params)
    if(step == 120):
        adam_params = {"lr": 0.001, "betas": (0.80, 0.999)}
        optimizer = Adam(adam_params)
    if(step %1 ==0):
        print('step: {}'.format(step))
#         for i in pyro.params.param_store.keys():
#             store_p[i] = pyro.params(i)
        p = pyro.get_param_store()
        store_p.append(p) 
        for (k,v) in p.items():
            print(k+": {:.3f}".format(v.item())) 




step: 0
latent_sex: 0.475
latent_female_mustache: 0.475
latent_male_mustache: 0.475
latent_no_mustache_beard: 0.475
latent_mustache_beard: 0.525
latent_young: 0.525
latent_female_old_bald: 0.475
latent_female_young_bald: 0.475
latent_male_old_bald: 0.475
latent_male_young_bald: 0.475
step: 1
latent_sex: 0.452
latent_female_mustache: 0.450
latent_male_mustache: 0.450
latent_no_mustache_beard: 0.450
latent_mustache_beard: 0.550
latent_young: 0.550
latent_female_old_bald: 0.450
latent_female_young_bald: 0.450
latent_male_old_bald: 0.450
latent_male_young_bald: 0.450
step: 2
latent_sex: 0.433
latent_female_mustache: 0.426
latent_male_mustache: 0.426
latent_no_mustache_beard: 0.426
latent_mustache_beard: 0.574
latent_young: 0.574
latent_female_old_bald: 0.426
latent_female_young_bald: 0.426
latent_male_old_bald: 0.425
latent_male_young_bald: 0.426
step: 3
latent_sex: 0.417
latent_female_mustache: 0.402
latent_male_mustache: 0.402
latent_no_mustache_beard: 0.403
latent_mustache_beard: 0.597


step: 29
latent_sex: 0.430
latent_female_mustache: 0.095
latent_male_mustache: 0.125
latent_no_mustache_beard: 0.139
latent_mustache_beard: 0.902
latent_young: 0.780
latent_female_old_bald: 0.098
latent_female_young_bald: 0.095
latent_male_old_bald: 0.130
latent_male_young_bald: 0.100
step: 30
latent_sex: 0.416
latent_female_mustache: 0.091
latent_male_mustache: 0.122
latent_no_mustache_beard: 0.138
latent_mustache_beard: 0.906
latent_young: 0.780
latent_female_old_bald: 0.094
latent_female_young_bald: 0.091
latent_male_old_bald: 0.128
latent_male_young_bald: 0.096
step: 31
latent_sex: 0.408
latent_female_mustache: 0.088
latent_male_mustache: 0.120
latent_no_mustache_beard: 0.136
latent_mustache_beard: 0.910
latent_young: 0.778
latent_female_old_bald: 0.090
latent_female_young_bald: 0.088
latent_male_old_bald: 0.127
latent_male_young_bald: 0.093
step: 32
latent_sex: 0.400
latent_female_mustache: 0.085
latent_male_mustache: 0.119
latent_no_mustache_beard: 0.135
latent_mustache_beard: 0.

step: 58
latent_sex: 0.403
latent_female_mustache: 0.040
latent_male_mustache: 0.100
latent_no_mustache_beard: 0.128
latent_mustache_beard: 0.952
latent_young: 0.776
latent_female_old_bald: 0.043
latent_female_young_bald: 0.040
latent_male_old_bald: 0.102
latent_male_young_bald: 0.049
step: 59
latent_sex: 0.395
latent_female_mustache: 0.039
latent_male_mustache: 0.100
latent_no_mustache_beard: 0.128
latent_mustache_beard: 0.952
latent_young: 0.775
latent_female_old_bald: 0.042
latent_female_young_bald: 0.039
latent_male_old_bald: 0.102
latent_male_young_bald: 0.048
step: 60
latent_sex: 0.391
latent_female_mustache: 0.039
latent_male_mustache: 0.099
latent_no_mustache_beard: 0.128
latent_mustache_beard: 0.953
latent_young: 0.773
latent_female_old_bald: 0.041
latent_female_young_bald: 0.038
latent_male_old_bald: 0.102
latent_male_young_bald: 0.047
step: 61
latent_sex: 0.388
latent_female_mustache: 0.038
latent_male_mustache: 0.098
latent_no_mustache_beard: 0.128
latent_mustache_beard: 0.

step: 87
latent_sex: 0.419
latent_female_mustache: 0.024
latent_male_mustache: 0.097
latent_no_mustache_beard: 0.131
latent_mustache_beard: 0.963
latent_young: 0.778
latent_female_old_bald: 0.025
latent_female_young_bald: 0.023
latent_male_old_bald: 0.104
latent_male_young_bald: 0.032
step: 88
latent_sex: 0.427
latent_female_mustache: 0.023
latent_male_mustache: 0.097
latent_no_mustache_beard: 0.131
latent_mustache_beard: 0.963
latent_young: 0.779
latent_female_old_bald: 0.024
latent_female_young_bald: 0.023
latent_male_old_bald: 0.104
latent_male_young_bald: 0.032
step: 89
latent_sex: 0.433
latent_female_mustache: 0.023
latent_male_mustache: 0.097
latent_no_mustache_beard: 0.131
latent_mustache_beard: 0.963
latent_young: 0.778
latent_female_old_bald: 0.024
latent_female_young_bald: 0.023
latent_male_old_bald: 0.104
latent_male_young_bald: 0.032
step: 90
latent_sex: 0.438
latent_female_mustache: 0.023
latent_male_mustache: 0.097
latent_no_mustache_beard: 0.131
latent_mustache_beard: 0.

step: 116
latent_sex: 0.443
latent_female_mustache: 0.016
latent_male_mustache: 0.097
latent_no_mustache_beard: 0.141
latent_mustache_beard: 0.965
latent_young: 0.768
latent_female_old_bald: 0.017
latent_female_young_bald: 0.016
latent_male_old_bald: 0.117
latent_male_young_bald: 0.027
step: 117
latent_sex: 0.441
latent_female_mustache: 0.016
latent_male_mustache: 0.097
latent_no_mustache_beard: 0.141
latent_mustache_beard: 0.965
latent_young: 0.770
latent_female_old_bald: 0.016
latent_female_young_bald: 0.016
latent_male_old_bald: 0.116
latent_male_young_bald: 0.027
step: 118
latent_sex: 0.436
latent_female_mustache: 0.016
latent_male_mustache: 0.096
latent_no_mustache_beard: 0.140
latent_mustache_beard: 0.965
latent_young: 0.771
latent_female_old_bald: 0.016
latent_female_young_bald: 0.015
latent_male_old_bald: 0.116
latent_male_young_bald: 0.026
step: 119
latent_sex: 0.429
latent_female_mustache: 0.015
latent_male_mustache: 0.096
latent_no_mustache_beard: 0.139
latent_mustache_beard

step: 145
latent_sex: 0.446
latent_female_mustache: 0.012
latent_male_mustache: 0.101
latent_no_mustache_beard: 0.126
latent_mustache_beard: 0.973
latent_young: 0.773
latent_female_old_bald: 0.012
latent_female_young_bald: 0.012
latent_male_old_bald: 0.108
latent_male_young_bald: 0.023
step: 146
latent_sex: 0.442
latent_female_mustache: 0.012
latent_male_mustache: 0.102
latent_no_mustache_beard: 0.127
latent_mustache_beard: 0.973
latent_young: 0.772
latent_female_old_bald: 0.012
latent_female_young_bald: 0.012
latent_male_old_bald: 0.108
latent_male_young_bald: 0.023
step: 147
latent_sex: 0.430
latent_female_mustache: 0.011
latent_male_mustache: 0.103
latent_no_mustache_beard: 0.129
latent_mustache_beard: 0.973
latent_young: 0.771
latent_female_old_bald: 0.012
latent_female_young_bald: 0.011
latent_male_old_bald: 0.109
latent_male_young_bald: 0.023
step: 148
latent_sex: 0.412
latent_female_mustache: 0.011
latent_male_mustache: 0.104
latent_no_mustache_beard: 0.130
latent_mustache_beard

step: 174
latent_sex: 0.415
latent_female_mustache: 0.009
latent_male_mustache: 0.103
latent_no_mustache_beard: 0.125
latent_mustache_beard: 0.974
latent_young: 0.769
latent_female_old_bald: 0.010
latent_female_young_bald: 0.009
latent_male_old_bald: 0.103
latent_male_young_bald: 0.021
step: 175
latent_sex: 0.417
latent_female_mustache: 0.009
latent_male_mustache: 0.103
latent_no_mustache_beard: 0.126
latent_mustache_beard: 0.974
latent_young: 0.772
latent_female_old_bald: 0.010
latent_female_young_bald: 0.009
latent_male_old_bald: 0.103
latent_male_young_bald: 0.021
step: 176
latent_sex: 0.419
latent_female_mustache: 0.009
latent_male_mustache: 0.102
latent_no_mustache_beard: 0.126
latent_mustache_beard: 0.974
latent_young: 0.775
latent_female_old_bald: 0.010
latent_female_young_bald: 0.009
latent_male_old_bald: 0.103
latent_male_young_bald: 0.021
step: 177
latent_sex: 0.418
latent_female_mustache: 0.009
latent_male_mustache: 0.102
latent_no_mustache_beard: 0.126
latent_mustache_beard

In [5]:
from pyro import poutine
from pyro.infer import config_enumerate, infer_discrete

def determined_model():
    
    laten_sex = pyro.param('latent_sex')
    
    laten_mus_1 = pyro.param("latent_female_mustache")
    laten_mus_2 = pyro.param("latent_male_mustache")
    laten_mus = [laten_mus_1,laten_mus_2]
    
    laten_beard_1 = pyro.param("latent_no_mustache_beard")
    laten_beard_2 = pyro.param("latent_mustache_beard")
    laten_beard = [laten_beard_1,laten_beard_2]

    
    laten_young = pyro.param('latent_young')
    
    laten_bald_1 = pyro.param("latent_female_old_bald")
    laten_bald_2 = pyro.param("latent_female_young_bald")
    laten_bald_3 = pyro.param("latent_male_old_bald")
    laten_bald_4 = pyro.param("latent_male_young_bald")
    laten_bald = [[laten_bald_1,laten_bald_2],[laten_bald_3, laten_bald_4]]
    

    with pyro.plate("a_plate", size=1, dim=-2):

        
        sex = pyro.sample('sex', dist.Bernoulli(laten_sex))        
        young = pyro.sample('young', dist.Bernoulli(laten_young))        
#         with pyro.plate("b_plate", size=1, dim=-1):

        mus = pyro.sample("mustache", dist.Bernoulli(laten_mus[sex.long()]))
        beard = pyro.sample("beard", dist.Bernoulli(laten_beard[mus.long()]))


        bald = pyro.sample("bald", dist.Bernoulli(laten_bald[sex.long()][young.long()]))

        
def make_log_joint(model):

    def _log_joint(data, *args, **kwargs):

        conditioned_model = poutine.condition(model, data=data)

        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)

        return trace.log_prob_sum()

    return _log_joint

 

scale_log_joint = make_log_joint(determined_model)

print(scale_log_joint({\
                       "sex": torch.tensor(0.0),
                       "young": torch.tensor(0.),
                       "mustache": torch.tensor(0.0),
                       "beard": torch.tensor(.0),
                       "bald":torch.tensor(.0)\
}).exp())


tensor(0.1136, grad_fn=<ExpBackward>)


In [6]:
from itertools import product

In [9]:
perm_array = np.array(list(product([0, 1], repeat=5))).astype(float)

def joint_prob_graph(arr):
    prob = scale_log_joint({\
                       "sex": torch.tensor(arr[0]),
                       "young": torch.tensor(arr[1]),
                       "mustache": torch.tensor(arr[2]),
                       "beard": torch.tensor(arr[3]),
                       "bald":torch.tensor(arr[4])\
                           }).exp()
    return prob.data.numpy()

In [10]:
prob_array = []

for arr in perm_array:
    p = joint_prob_graph(arr)
    prob_array.append(p)

prob_array = np.array(prob_array)
prob_array

array([1.13584166e-01, 9.29917576e-04, 1.75871010e-02, 1.43986216e-04,
       2.63495161e-05, 2.15724418e-07, 9.49200690e-04, 7.77113955e-06,
       3.79276066e-01, 2.82775077e-03, 5.87262001e-02, 4.37842175e-04,
       8.79853341e-05, 6.55988128e-07, 3.16953600e-03, 2.36309609e-05,
       6.72165909e-02, 9.18747971e-03, 1.04076564e-02, 1.42256741e-03,
       2.16826589e-04, 2.96368778e-05, 7.81084357e-03, 1.06762283e-03,
       2.49527842e-01, 5.41273994e-03, 3.86362950e-02, 8.38095724e-04,
       8.04924354e-04, 1.74603610e-05, 2.89961588e-02, 6.28982583e-04])

In [11]:

all_param = pyro.get_param_store()
print(all_param)
all_param.save('./pyro_params')

<pyro.params.param_store.ParamStoreDict object at 0x0000029F48F84160>


In [12]:

pyro.clear_param_store()
graph_params = pyro.params.param_store.ParamStoreDict()
graph_params.load('./pyro_params')


In [13]:
for (k,v) in graph_params.items():
    print(k+": {:.3f}".format(v.item())) 

latent_sex: 0.422
latent_female_mustache: 0.007
latent_male_mustache: 0.094
latent_no_mustache_beard: 0.134
latent_mustache_beard: 0.973
latent_young: 0.769
latent_female_old_bald: 0.008
latent_female_young_bald: 0.007
latent_male_old_bald: 0.120
latent_male_young_bald: 0.021
