Permalink
Fetching contributors…
Cannot retrieve contributors at this time
106 lines (91 sloc) 4.66 KB
from __future__ import absolute_import
from __future__ import print_function
import matplotlib.pyplot as plt
import autograd.numpy as np
import autograd.scipy.stats.norm as norm
from autograd.misc.optimizers import adam, sgd
# same BBSVI function!
from black_box_svi import black_box_variational_inference
if __name__ == '__main__':
    # Compare standard-gradient (adam) and natural-gradient (sgd) black-box
    # variational inference on the same posterior, for several step sizes,
    # and plot the ELBO trace of every run.
    from functools import partial

    # Specify an inference problem by its unnormalized log-density.
    # It's difficult to see the benefit in low dimensions.
    # Model parameters are a mean and a log_sigma.
    np.random.seed(42)
    obs_dim = 20
    Y = np.random.randn(obs_dim, obs_dim).dot(np.random.randn(obs_dim))

    def log_density(x, t):
        """Return the unnormalized log-density for every row of ``x``.

        ``x`` is indexed as a 2-D array: its first ``obs_dim`` columns are
        ``mu`` and the remaining columns are ``log_sigma`` (each row is
        presumably one Monte Carlo sample -- confirm against
        black_box_variational_inference). ``log_sigma`` is scored under an
        N(0, 1.35) prior and the fixed observations ``Y`` under
        N(mu, exp(log_sigma)). ``t`` is accepted but unused.

        Returns a 1-D array with one log-density value per row.
        """
        mu, log_sigma = x[:, :obs_dim], x[:, obs_dim:]
        sigma_density = np.sum(norm.logpdf(log_sigma, 0, 1.35), axis=1)
        mu_density = np.sum(norm.logpdf(Y, mu, np.exp(log_sigma)), axis=1)
        return sigma_density + mu_density

    # Build variational objective.
    D = obs_dim * 2  # dimension of our posterior
    objective, gradient, unpack_params = \
        black_box_variational_inference(log_density, D, num_samples=2000)

    # Define the natural gradient.
    # The natural gradient of the ELBO is the gradient of the ELBO,
    # preconditioned by the inverse Fisher information matrix. For a
    # diagonal Gaussian, the Fisher is a diagonal matrix that is a simple
    # function of the variance. Intuitively, the statistical distance
    # created by perturbing the mean of an independent Gaussian is
    # determined by how wide the distribution is along that dimension ---
    # the wider the distribution, the less sensitive the statistical
    # distance is to perturbations of the mean; the narrower the
    # distribution, the more the statistical distance changes when you
    # perturb the mean (imagine an extremely narrow Gaussian --- basically
    # a spike. The KL between this Gaussian and a Gaussian $\epsilon$ away
    # in location can be big --- moving the Gaussian could significantly
    # reduce overlap in support, which corresponds to a greater
    # statistical distance).
    #
    # When we want to move in directions of steepest ascent, we multiply by
    # the inverse Fisher --- that way we make quicker progress when the
    # variance is wide, and we scale down our step size when the variance
    # is small (which leads to more robust/less chaotic ascent).
    def fisher_diag(lam):
        """Return the diagonal Fisher information for variational params ``lam``.

        The entries for the means are exp(-2 * log_sigma) (i.e. 1 / sigma^2).
        Every log-std entry is the constant 2. The means come first, matching
        the ``[mean, log_std]`` layout of the initial parameters below.
        """
        _, log_sigma = unpack_params(lam)
        return np.concatenate([np.exp(-2. * log_sigma),
                               np.ones(len(log_sigma)) * 2])

    def natural_gradient(lam, i):
        """Return the gradient preconditioned by the inverse diagonal Fisher.

        This is an elementwise rescale, so it is basically free.
        """
        return (1. / fisher_diag(lam)) * gradient(lam, i)

    def optimize_and_lls(optfun, num_iters=200):
        """Run ``optfun`` from a fixed initialization and record the ELBO.

        ``optfun`` is called as ``optfun(init_params, callback=cb,
        num_iters=num_iters)``, the calling convention of autograd's
        optimizers. The callback records ``-objective(params, t)`` (the
        lower bound) at every iteration and prints it every 50 iterations.

        Returns a numpy array of the recorded ELBO values.
        """
        elbos = []

        def callback(params, t, g):
            elbo_val = -objective(params, t)
            elbos.append(elbo_val)
            if t % 50 == 0:
                print("Iteration {} lower bound {}".format(t, elbo_val))

        init_mean = -1 * np.ones(D)
        init_log_std = -5 * np.ones(D)
        init_var_params = np.concatenate([init_mean, init_log_std])
        # Only the ELBO trace is used; the optimized parameters are discarded.
        optfun(init_var_params, callback=callback, num_iters=num_iters)
        return np.array(elbos)

    # Let's optimize this with a few different step sizes. partial() binds
    # the current step_size immediately, instead of having the optimizer
    # close over the loop variable.
    elbo_lists = []
    step_sizes = [.1, .25, .5]
    for step_size in step_sizes:
        # Optimize with standard gradient + adam.
        standard_lls = optimize_and_lls(
            partial(adam, gradient, step_size=step_size))
        # Optimize with natural gradient + sgd; mass=.001 means effectively
        # no momentum.
        natural_lls = optimize_and_lls(
            partial(sgd, natural_gradient, step_size=step_size, mass=.001))
        elbo_lists.append((standard_lls, natural_lls))

    # Visually compare the ELBO.
    plt.figure(figsize=(12, 8))
    colors = ['b', 'k', 'g']
    for col, ss, (stand_lls, nat_lls) in zip(colors, step_sizes, elbo_lists):
        plt.plot(np.arange(len(stand_lls)), stand_lls,
                 '--', label="standard (adam, step-size = %2.2f)" % ss,
                 alpha=.5, c=col)
        plt.plot(np.arange(len(nat_lls)), nat_lls, '-',
                 label="natural (sgd, step-size = %2.2f)" % ss, c=col)

    # Zoom the y-axis around the best natural-gradient ELBO across every
    # step size, rather than only the trace of whichever run happened to
    # finish last.
    all_natural = np.concatenate([nat for _, nat in elbo_lists])
    best_elbo = all_natural.max()
    llrange = best_elbo - all_natural.min()
    plt.ylim((best_elbo - llrange * .1, best_elbo + 10))
    plt.xlabel("optimization iteration")
    plt.ylabel("ELBO")
    plt.legend(loc='lower right')
    plt.title("%d dimensional posterior" % D)
    plt.show()