In [1]:
import autograd.numpy as np
import autograd.scipy as sp
from autograd.scipy import special
from autograd import grad, hessian, hessian_vector_product
import matplotlib.pyplot as plt
from copy import deepcopy

from scipy import optimize

from valez_finite_VI_lib import initialize_parameters, generate_data, compute_elbo
from generic_optimization_lib import unpack_params, pack_params

In [2]:
def check_approx_eq(x, y, tol=1e-12):
    """Report whether x and y agree everywhere to within an absolute tolerance.

    Args:
        x, y: operands whose difference ``x - y`` is defined (scalars or arrays).
        tol: strict upper bound on the largest absolute difference.

    Returns:
        True when the largest value of ``|x - y|`` is strictly below ``tol``.
    """
    worst_gap = np.max(np.abs(x - y))
    return worst_gap < tol

In [3]:
# Experiment setup: fix the seed, choose the model constants, generate data,
# and build the initial packed parameter vector used by the cells below.
np.random.seed(12321) # seed chosen because VI works well with it

alpha = 10 # IBP parameter

Num_samples = 2000 # sample size (N)
D = 2 # dimension
# so X will be an N x D matrix

# Passed to generate_data and later stored as DataSet.sigmas['A'].
# NOTE(review): presumably the prior scale of the feature matrix A -- confirm.
sigma_A = 100

# Original note calls this the variance of the noise.
# NOTE(review): confirm whether the library treats it as a variance or a std.
sigma_eps = .1

K_inf = 3 # take to be large for a good approximation to the IBP (only 3 here)

# Data generation comes right after seeding; keep the order of these calls
# fixed, since any random draws they make would otherwise change.
Pi, Z, mu, A, X = generate_data(Num_samples, D, K_inf, sigma_A, sigma_eps, alpha)

K_approx = deepcopy(K_inf) # variational truncation; equal to K_inf here

tau, nu, phi_mu, phi_var = initialize_parameters(Num_samples, D, K_approx)
# Entries of nu that are >= 0.9 or <= 0.1 go through np.round; entries in
# between are kept as they are. Both comparisons are inclusive, so an entry
# exactly equal to 0.1 or 0.9 contributes to both terms.
# NOTE(review): nu_init is not used in this cell -- params below is packed
# from nu, not nu_init. Confirm which one was intended.
nu_init = np.round(nu * (nu >= 0.9) + nu * (nu <= 0.1)) + nu * (nu >= 0.1) * (nu <= 0.9)
# Pack copies of the variational parameters into the single object handed to
# the objective and the optimizer.
params = pack_params(deepcopy(tau), deepcopy(phi_mu), deepcopy(phi_var), deepcopy(nu))
# Independent copy kept as the optimizer's starting point.
params_init = deepcopy(params)


AttributeError: module 'autograd.scipy.special' has no attribute 'logit'

In [4]:
class DataSet(object):
    """Hold the observations and fixed model settings behind a one-argument objective.

    ``wrapped_kl`` takes only the packed parameters, so it can be handed
    directly to autograd and to ``scipy.optimize.minimize``.
    """

    def __init__(self, X, K_approx, alpha, sigma_eps, sigma_A):
        """Store the data and the constants needed to evaluate the ELBO.

        Args:
            X: observation matrix; ``X.shape[0]`` is recorded as N and
                ``X.shape[1]`` as D.
            K_approx: truncation level, recorded as K.
            alpha: forwarded unchanged to ``compute_elbo``.
            sigma_eps: stored as ``sigmas['eps']``.
            sigma_A: stored as ``sigmas['A']``.
        """
        self.X = X
        self.K_approx = K_approx
        self.alpha = alpha
        self.data_shape = {'D': X.shape[1], 'N': X.shape[0], 'K': K_approx}
        self.sigmas = {'eps': sigma_eps, 'A': sigma_A}

    def wrapped_kl(self, params, verbose=False):
        """Return the negated ELBO for a packed parameter object.

        Args:
            params: packed parameters in the layout ``unpack_params`` expects.
            verbose: when true, print the objective value before returning it.

        Returns:
            ``-1`` times the first element returned by ``compute_elbo``.
        """
        tau, phi_mu, phi_var, nu = unpack_params(
            params, self.data_shape['K'], self.data_shape['D'], self.data_shape['N'])
        elbo = compute_elbo(tau, nu, phi_mu, phi_var, self.X, self.sigmas, self.alpha)[0]
        neg_elbo = -1 * elbo
        if verbose:
            # print(...) with one argument behaves the same on Python 2 and 3.
            # The old statement form `print -1 * elbo` is parsed by Python 3 as
            # a subtraction from the print function and raises TypeError.
            print(neg_elbo)
        return neg_elbo
        


In [6]:
# Bind the data and fixed model settings once, then evaluate the objective at
# the packed parameters; the value is shown as the cell output.
data_set = DataSet(
    X=X,
    K_approx=K_approx,
    alpha=alpha,
    sigma_eps=sigma_eps,
    sigma_A=sigma_A,
)
data_set.wrapped_kl(params)

4858169.677175656

In [8]:
# Build autograd derivatives of the objective with respect to the packed
# parameters: the gradient and a Hessian-vector product.
get_kl_grad = grad(data_set.wrapped_kl)
get_kl_hvp = hessian_vector_product(data_set.wrapped_kl)

# Smoke test: evaluate both at the current parameters, using the gradient
# itself as the direction for the Hessian-vector product.
kl_grad = get_kl_grad(params)
kl_hvp = get_kl_hvp(params, kl_grad)

# print(...) calls rather than print statements, so the cell also runs under
# Python 3; with a single argument the Python 2 output is unchanged.
print(kl_grad)
print(kl_hvp)

[-1708.20647253   173.71298395 -2738.82010207 ...,    35.92961572
    60.1122894      3.58939   ]
[ -4534516.47123467   1739212.97942215 -10179841.1094394  ...,
  12136800.26632632  13880040.31309656   3663558.89338671]


In [9]:
# Minimise the objective with the trust-region Newton-CG method, supplying
# the autograd gradient and Hessian-vector product. The objective is wrapped
# so that each evaluation prints its value (verbose=True).
vb_opt = optimize.minimize(
    lambda packed: data_set.wrapped_kl(packed, verbose=True),
    params_init,
    method='trust-ncg',
    jac=get_kl_grad,
    hessp=get_kl_hvp,
    tol=1e-6,
    options=dict(maxiter=100, disp=True, gtol=1e-6))



4858169.67718
4507580.4666
3866133.91997
2821738.85538
1683994.95868
1572440.16592
1385585.86464
1058080.17789
713115.448692
673867.254049
652979.423718
72339145292.3
525871.998052
31082012003.7
482562.78367
714152.736789
458236.23201
429368.983794
119967645.478
408306.904637
386204.196115
169542542.653
367182.946403
343507.209254
859416108.077
332680.279592
324780.309084
307485.813777
283213.501556
282515.817251
286629.181175
275846.041605
266747.367467
261015.709604
253622.764708
1499278.5144
245355.467608
237294.314446
231402.208132
223678.672573
220216.055715
221615.341947
217173.519537
215506.618202
211280.097351
207516.384078
206734.667234
205353.683161
204039.906242
198806.266592
208424.378821
196616.541605
195293.590273
192710.009355
191000.225662
186758.43645
185592.614838
181473.428782
179311.92317
176507.956275
171955.664224
168771.582119
164748.648141
163546.096531
160510.48432
156485.361121
153460.53815
150639.49267
153115.803271
150547.302024
150160.572957
149916.470282
1

In [18]:
# Unpack the optimizer's solution (this rebinds tau, phi_mu, phi_var and nu)
# and print the fitted phi_mu, transposed, next to the A returned by
# generate_data for a side-by-side comparison.
tau, phi_mu, phi_var, nu = unpack_params(vb_opt.x, D=D, K_approx=K_approx, Num_samples=Num_samples)
# print(...) calls rather than print statements, so the cell also runs under
# Python 3; with a single argument the Python 2 output is unchanged.
print(phi_mu.transpose())
print(A)

[[ 7.27434354 -2.006652  ]
 [ 8.3221905  -1.20240182]
 [ 8.55562752 -0.59554079]]
[[ 13.17491681  -7.83796064]
 [  0.1543092    0.97045299]
 [ 10.94654306   3.54649255]]
