In [151]:
import sys
sys.path.insert(0, '../')

import math
import os
import numpy as np
from scipy.special import gammaln
import torch
import beer

import warnings
#warnings.filterwarnings('ignore')


# For plotting.
from bokeh.io import show, output_notebook, export_png
from bokeh.plotting import figure, gridplot
from bokeh.models import Arrow, OpenHead, NormalHead, VeeHead
from bokeh.events import Tap, Pan, PanEnd, PanStart
from bokeh.application.handlers import FunctionHandler
from bokeh.application import Application
from bokeh.models import ColumnDataSource
from bokeh.models.widgets import Button, RadioButtonGroup, Dropdown, Slider
from bokeh.layouts import layout
from bokeh.palettes import viridis


output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


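# Model

Each of the $K$ categorical distributions has natural parameters $\psi_k \in \mathbb{R}^{D-1}$. These are generated from a latent vector $h_k$ through the subspace matrix $W$, the global mean $\mu$ and the precision $\tau$. The third line is the marginal obtained by integrating out $h_k$. The $D$-th category is the reference category.

\begin{align}
p(h_k) &= \mathcal{N}\big(h_k | 0, I\big)  \\
p(\psi_k | W, \mu, h_k, \tau) &= \mathcal{N}\big(\psi_k | W^T h_k + \mu, \frac{1}{\tau} I \big) \\
p(\psi_k | W, \mu, \tau) &= \mathcal{N}\big(\psi_k | \mu, W^T W + \frac{1}{\tau} I \big) \\
\pi_{k,i} &= \frac{e^{\psi_{k,i}}}{1 + \sum_{j=1}^{D-1} e^{\psi_{k,j}}}, \quad i = 1, \dots, D-1 \\
\pi_{k,D} &= \frac{1}{1 + \sum_{j=1}^{D-1} e^{\psi_{k,j}}}
\end{align}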


In [98]:
# Ground truth: sample the natural parameters psi_k of `nsamples` categorical
# models from p(psi | W, mu, tau) = N(mu, W^T W + tau^-1 I), then map them to the
# probabilities of the first two categories (the third is the reference).

# Seed every RNG used in the notebook so that Restart & Run All is reproducible
# (np.random also drives the draws in `generate_data`).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

x_range = y_range = (-5, 5)
nsamples = 100  # number of categorical models (rows of `psis` / `pi`)
mu = np.array([1, -.5])
weights = np.array([[1., 2.]])
W = weights
tau = 10
C = W.T @ W + (1./tau) * np.identity(2)
L = np.linalg.cholesky(C)
psis = mu + np.random.randn(nsamples, 2) @ L.T
exp_psis = np.exp(psis)
pi = exp_psis / (1 + exp_psis.sum(axis=1)[:, None])

fig1 = figure(width=400, height=400, x_range=x_range, y_range=y_range)
fig1.circle(psis[:, 0], psis[:, 1], alpha=.5)

# Probability simplex: points must lie below the line pi_1 + pi_2 = 1.
fig2 = figure(width=400, height=400, x_range=(0, 1), y_range=(0, 1))
fig2.line([1, 0], [0, 1], line_width=2, color='black')
fig2.circle(pi[:, 0], pi[:, 1], alpha=.5)

show(gridplot([[fig1, fig2]]))

In [99]:
def generate_data(pi, nsamples=100):
    """Draw categorical samples for each model and return their statistics.

    Args:
        pi: probabilities of the first D-1 categories, one row per model
            (a single 1-D row is also accepted). The last category receives
            the remaining mass ``1 - sum(row)``.
        nsamples: number of draws per model.

    Returns:
        float32 tensor of shape (K, D). For each of the K models it holds the
        counts of the first D-1 categories followed by the total count
        ``nsamples``.
    """
    pi = np.atleast_2d(pi)
    pi = np.c_[pi, 1 - pi.sum(axis=-1)]
    ncategories = pi.shape[-1]

    # Count the draws of each category directly. The draws happen in the same
    # order as before (one np.random.choice call per model). The last
    # (reference) category is dropped because it is implied by the total count.
    stats = np.stack([
        np.bincount(np.random.choice(ncategories, size=nsamples, p=probs),
                    minlength=ncategories)[:-1]
        for probs in pi
    ])
    stats = torch.from_numpy(stats).float()
    counts = torch.full((len(stats), 1), float(nsamples), dtype=stats.dtype)
    return torch.cat([stats, counts], dim=-1)

In [100]:
class CategoricalLogLikekilhood:
    """Categorical log-likelihood expressed in terms of natural parameters.

    With D categories the natural parameters are psi in R^{D-1}, and the last
    category is the reference:
        pi_i = exp(psi_i) / (1 + sum_j exp(psi_j))   for i < D,
        pi_D = 1 / (1 + sum_j exp(psi_j)).

    ``nparams`` is a 2-D tensor with one row of psi per model. ``stats_counts``
    is a 2-D tensor whose rows are ``[s_1, ..., s_{D-1}, N]``: the counts of
    the first D-1 categories followed by the total number of observations N.
    This is the layout returned by ``generate_data``.
    """

    # Relative jitter added to the diagonal of the Hessian.
    _HESSIAN_JITTER = 1e-6
    # Smoothing constant used by ``argmax`` to avoid log(0).
    _ARGMAX_EPS = 1e-6

    @staticmethod
    def _augment(nparams):
        """Prepend the reference category's natural parameter (0) to each row."""
        zeros = nparams.new_zeros(nparams.shape[:-1] + (1,))
        return torch.cat([zeros, nparams], dim=-1)

    def _mean(self, nparams):
        """Probabilities pi_1..pi_{D-1}, computed with a numerically stable softmax."""
        return torch.softmax(self._augment(nparams), dim=-1)[..., 1:]

    def __call__(self, nparams, stats_counts):
        """Per-model log-likelihood psi.s - N log(1 + sum_j exp(psi_j)).

        The multinomial coefficient is not included. Returns a 1-D tensor
        with one value per row of ``nparams``.
        """
        stats = stats_counts[:, :-1]
        counts = stats_counts[:, -1]
        # log(1 + sum exp psi) == logsumexp([0, psi]); this form cannot overflow.
        lnorm = torch.logsumexp(self._augment(nparams), dim=-1)
        return (nparams * stats).sum(dim=-1) - counts * lnorm

    def argmax(self, stats_counts):
        """Smoothed maximum-likelihood natural parameters for each row.

        Returns log(pi_i + eps) - log(max(eps, pi_D)), where pi_i = s_i / N and
        pi_D = 1 - sum_i pi_i.
        """
        stats = stats_counts[:, :-1]
        counts = stats_counts[:, -1]
        pi = stats / counts[:, None]
        # Probability of the reference (last) category, floored at eps.
        pi_last = (1 - pi.sum(dim=1)).clamp(min=self._ARGMAX_EPS)
        return torch.log(self._ARGMAX_EPS + pi) - torch.log(pi_last)[:, None]

    def hessian(self, nparams, stats_counts, mode='diagonal'):
        """Hessian of the log-likelihood with respect to the natural parameters.

        H = -N (diag(pi) - pi pi^T), with a small relative jitter on the diagonal.

        Args:
            nparams: natural parameters, one row per model.
            stats_counts: statistics and total counts (see class docstring).
            mode: ``'full'`` returns the (K, D-1, D-1) matrices,
                ``'diagonal'`` their (K, D-1) diagonals, and ``'scalar'`` the
                (K,) mean of each diagonal.

        Raises:
            ValueError: if ``mode`` is not one of the modes above.
        """
        counts = stats_counts[:, -1]
        mean = self._mean(nparams)
        # Jitter only the diagonal so that the full Hessian stays symmetric.
        cov = torch.diag_embed((1 + self._HESSIAN_JITTER) * mean)
        cov = cov - mean[:, :, None] * mean[:, None, :]
        hessians = -counts[:, None, None] * cov
        if mode == 'full':
            return hessians
        # clone(): return an independent tensor rather than a view.
        diagonals = torch.diagonal(hessians, dim1=-2, dim2=-1).clone()
        if mode == 'diagonal':
            return diagonals
        if mode == 'scalar':
            return diagonals.mean(dim=-1)
        raise ValueError(f'Unknown Hessian approximation mode: "{mode}"')


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
CategoricalLogLikelihood = CategoricalLogLikekilhood

# Inference

In [153]:
dim = 2      # size of the natural-parameter vectors (length of the global mean)
s_dim = 1    # number of rows of the subspace matrix

# Initial float32 values for the subspace, global mean, precision and scale.
# These are plain factory calls, so no autograd context is needed.
M = torch.zeros(s_dim, dim, dtype=torch.float32)
global_mean = torch.zeros(dim, dtype=torch.float32)
prec = torch.tensor(1., dtype=torch.float32)
scale = torch.ones(s_dim, dtype=torch.float32)

model = beer.models.GeneralizedSubspaceModel.create(
    CategoricalLogLikekilhood(), M, global_mean, prec, scale,
    hessian_type='full',
    noise_std=0.1,
    prior_strength=1,
).double()

# 1000 draws per model, converted to double to match the model's precision.
data = generate_data(pi, nsamples=1000).double()

In [154]:
n_epochs = 100

# Optimize the global mean, the precision and the subspace as three separate
# parameter groups. This grouping is used instead of
# model.mean_field_factorization(), whose result was previously computed and discarded.
params = [[model.mean], [model.precision], [model.subspace]]
optim = beer.BayesianModelOptimizer(params, lrate=1.)

# Report the ELBO per observation. data[:, -1] holds each model's total count.
total_counts = float(torch.sum(data[:, -1]))

epochs = list(range(1, n_epochs + 1))
elbos = []
for epoch in epochs:
    optim.init_step()
    elbo = beer.evidence_lower_bound(model, data, max_iter=200, conv_threshold=1e-5)
    elbo.backward()
    elbos.append(float(elbo) / total_counts)
    optim.step()

fig = figure(width=400, height=400)
fig.xaxis.axis_label = 'epoch'
fig.yaxis.axis_label = 'ELBO / observation'
fig.line(epochs, elbos)
show(fig)

In [155]:
# Example anisotropic 2-D Gaussian (variance .1 along x, 5 along y).
# NOTE(review): both names are reassigned in the next cell before they are
# used, so these values currently have no effect.
mean = np.zeros(2, dtype=int)
cov = np.diag([.1, 5.])

def det_jacobian(x, y):
    """Absolute Jacobian determinant of the polar -> Cartesian change of variables.

    With x = r cos(a) and y = r sin(a), |d(x, y)/d(r, a)| = r = sqrt(x^2 + y^2).
    A density over (x, y) evaluated at these points becomes a density over
    (r, a) once multiplied by this factor. Its reciprocal, 1/r, is the
    determinant of the inverse map and would be wrong here.

    Args:
        x, y: Cartesian coordinates (scalars or broadcastable arrays).

    Returns:
        sqrt(x^2 + y^2), broadcast over ``x`` and ``y``.
    """
    return np.hypot(x, y)

def plot_normal_polar(fig_mag, fig_angle, mean, cov, r_range, a_range, res=1000, r_legend='', a_legend='', **kwargs):
    """Plot the radial and angular marginals of a 2-D Gaussian.

    The density N(x | mean, cov) is evaluated on a regular (r, a) grid with
    x = (r cos a, r sin a) and multiplied by ``det_jacobian`` to change
    variables. It is then summed over a (resp. r) to get the marginal over r
    (resp. a). Each marginal is normalized with a Riemann sum so that it
    integrates to ~1 over its range.

    Args:
        fig_mag: bokeh figure receiving the radial marginal, plotted against r.
        fig_angle: bokeh figure receiving the angular marginal, plotted against
            the grid index 0..res-1 rather than the angle value.
        mean: length-2 mean of the Gaussian.
        cov: 2x2 covariance matrix.
        r_range: (min, max) of the radius grid.
        a_range: (min, max) of the angle grid, in radians.
        res: number of grid points along each axis.
        r_legend, a_legend: legend labels of the two curves.
        **kwargs: forwarded to both ``line`` calls.
    """
    prec = np.linalg.inv(cov)
    r = np.linspace(r_range[0], r_range[1], res)
    alpha = np.linspace(a_range[0], a_range[1], res)
    # Grid axis 0 runs over the angle, axis 1 over the radius.
    R, A = np.meshgrid(r, alpha)
    X, Y = R * np.cos(A), R * np.sin(A)

    # Full quadratic form -1/2 (x - m)^T P (x - m). The cross term contributes
    # -P01 * dx * dy, not +P01 * dx * dy.
    diff = np.stack([X - mean[0], Y - mean[1]], axis=-1)
    log_Z = -.5 * np.einsum('...i,ij,...j->...', diff, prec, diff)
    # Constant terms (log-det, 2*pi) cancel in the normalizations below.
    # Subtracting the max avoids underflow of exp() for very peaked densities.
    Z = np.exp(log_Z - log_Z.max()) * det_jacobian(X, Y)

    dr = r_range[1] - r_range[0]
    r_vals = Z.sum(axis=0)
    r_vals /= r_vals.sum() * (dr / Z.shape[1])
    fig_mag.line(r, r_vals, legend=r_legend, **kwargs)

    dalpha = a_range[1] - a_range[0]
    alpha_vals = Z.sum(axis=1)
    alpha_vals /= alpha_vals.sum() * (dalpha / Z.shape[0])
    fig_angle.line(range(len(alpha_vals)), alpha_vals, legend=a_legend, **kwargs)

In [156]:
# Compare priors (blue) and variational posteriors (red) of the model
# parameters with the true values used to generate the data (green):
# the global mean mu, the precision tau, and the subspace row W in polar
# coordinates (magnitude r and angle alpha).


def _style_panel(fig, xlabel, xticks, show_yaxis=False):
    """Apply the look shared by the four panels: no grid, 16pt x-axis text."""
    fig.xgrid.visible = False
    fig.ygrid.visible = False
    fig.yaxis.visible = show_yaxis
    fig.xaxis.ticker = xticks
    fig.xaxis.major_label_text_font_size = '16pt'
    fig.xaxis.axis_label_text_font_size = '16pt'
    fig.min_border_left = 20
    fig.min_border_right = 20
    fig.xaxis.axis_label = xlabel


def _style_legend(fig):
    """Place the legend top-left with 16pt labels (call after adding glyphs)."""
    fig.legend.location = 'top_left'
    fig.legend.label_text_font_size = '16pt'


def _gamma_pdf(t, shape, rate):
    """Density of Gamma(shape, rate) at ``t`` (rate parameterization)."""
    return np.exp(shape * np.log(rate) + (shape - 1) * np.log(t) - rate * t - gammaln(shape))


def _angle_index(angle, res):
    """Index of ``angle`` on a ``res``-point grid covering [0, 2*pi)."""
    return int(res * np.mod(angle, 2 * np.pi) / (2 * np.pi))


# Subspace: prior and posterior parameters.
p_M, p_U = model.subspace.prior.to_std_parameters()
p_U_diag = p_U.diag().numpy()
p_M, p_U = p_M.numpy(), p_U.numpy()
M, U = model.subspace.posterior.to_std_parameters()
U_diag = U.diag().numpy()
M, U = M.numpy(), U.numpy()

# Global mean: the covariances are expanded into 2x2 matrices for plotting.
p_mean, p_cov = model.mean.prior.to_std_parameters()
p_cov = p_cov * torch.eye(2, dtype=p_cov.dtype)
mean, cov = model.mean.posterior.to_std_parameters()
cov = cov * torch.eye(2, dtype=cov.dtype)

# --- Panel: global mean mu ---
fig_mean = figure(width=350, height=350, x_range=(-2, 2), y_range=(-2, 2))
_style_panel(fig_mean, 'μ', [-1, 0, 1], show_yaxis=True)
fig_mean.yaxis.ticker = [-1, 0, 1]
fig_mean.yaxis.major_label_text_font_size = '16pt'
fig_mean.yaxis.axis_label_text_font_size = '16pt'

plotting.plot_normal(fig_mean, p_mean.numpy(), p_cov.numpy(), n_std_dev=2, fill_alpha=.15,
                     line_alpha=0., color='blue', legend='p(μ)')
plotting.plot_normal(fig_mean, mean.numpy(), cov.numpy(), n_std_dev=2, fill_alpha=.2,
                     line_alpha=0., color='red', legend='q(μ)')
fig_mean.cross(mu[0], mu[1], color='green', size=10, line_width=2, legend='μ*')
_style_legend(fig_mean)

# --- Panels: magnitude r and angle alpha of W ---
r_range = (0.01, 3)
res = 1000
fig_mag = figure(width=350, height=350, x_range=(0, r_range[1]), y_range=(0, 10))
_style_panel(fig_mag, 'r', [0, 1, 2, 3])

fig_angle = figure(width=350, height=350, x_range=(0, res), y_range=(0, 5))
_style_panel(fig_angle, 'α', [0, res // 4, res // 2, (3 * res) // 4, res])
# The angular marginals below use a_range=(0, 2*pi), so grid index i is the
# angle 2*pi*i/res. The tick labels must follow that convention.
fig_angle.xaxis.major_label_overrides = {0: '0', res // 4: 'π/2', res // 2: 'π',
                                         (3 * res) // 4: '3π/2', res: '2π'}

# True direction of W and its opposite. p(psi | W, mu, tau) depends on W only
# through W^T W, so both directions are marked.
true_angle = np.arctan2(weights[0, 1], weights[0, 0])
angle = _angle_index(true_angle, res)
neg_angle = _angle_index(true_angle + np.pi, res)
y = np.linspace(0, 10, 100)
fig_angle.line(np.zeros(100) + angle, y, line_width=2, alpha=.75, color='green',
               line_dash='dashed', legend='±α*')
fig_angle.line(np.zeros(100) + neg_angle, y, line_width=2, alpha=.75, color='green',
               line_dash='dashed')

true_r = np.linalg.norm(weights[0])
fig_mag.line(np.zeros(100) + true_r, np.linspace(0, 20, 100), line_width=2, alpha=.75,
             color='green', line_dash='dashed', legend='r*')

# Prior / posterior of the row W are drawn as isotropic Gaussians whose
# variance is the first diagonal entry of U.
p_W_cov = np.eye(2) * p_U_diag[0]
plot_normal_polar(fig_mag, fig_angle, p_M[0], p_W_cov, r_range=r_range,
                  a_range=(0, 2 * np.pi), res=res, color='blue', alpha=.75,
                  line_width=2, a_legend='p(α)', r_legend='p(r)')
q_W_cov = np.eye(2) * U_diag[0]
plot_normal_polar(fig_mag, fig_angle, M[0], q_W_cov, r_range=(0.01, 10),
                  a_range=(0, 2 * np.pi), res=res, color='red', alpha=.75,
                  line_width=2, a_legend='q(α)', r_legend='q(r)')
_style_legend(fig_angle)
_style_legend(fig_mag)

# --- Panel: precision tau (Gamma prior / posterior) ---
p_shape, p_rate = model.precision.prior.to_std_parameters()
p_shape, p_rate = p_shape.numpy(), p_rate.numpy()
shape, rate = model.precision.posterior.to_std_parameters()
shape, rate = shape.numpy(), rate.numpy()

t_range = (0.01, 12)
t = np.linspace(t_range[0], t_range[1], 10000)
p_tau = _gamma_pdf(t, p_shape, p_rate)
q_tau = _gamma_pdf(t, shape, rate)

fig_tau = figure(width=350, height=350, x_range=(0, t_range[1]), y_range=(0, 2.5))
_style_panel(fig_tau, 'τ', [0, 5, 10])
fig_tau.line(np.zeros(100) + tau, np.linspace(0, 2.5, 100), line_width=2, alpha=.75,
             color='green', line_dash='dashed', legend='τ*')
fig_tau.line(t, p_tau, line_width=2, alpha=.75, color='blue', legend='p(τ)')
fig_tau.line(t, q_tau, line_width=2, alpha=.75, color='red', legend='q(τ)')
_style_legend(fig_tau)

fig = gridplot([[fig_mean, fig_tau], [fig_mag, fig_angle]], toolbar_location=None)

# Set to a relative file name (e.g. 'gsm_M100_N1000.png') to also save the figure.
OUTPATH = None
if OUTPATH is not None:
    export_png(fig, filename=OUTPATH)
show(fig)

In [96]:
# Per-model posteriors over psi_k returned by model.latent_posteriors (red,
# 1-std ellipses) against the true psi_k sampled at the start (blue) and the
# per-model argmax estimates from the counts (green crosses).
# NOTE(review): assumes model.llh_func is the CategoricalLogLikekilhood passed
# to create(); confirm.
posteriors = model.latent_posteriors(data=data, conv_threshold=1e-8)
means = posteriors['params_mean']
covs = posteriors['params_cov']
max_psis = model.llh_func.argmax(data).numpy()

fig1 = figure(width=400, height=400, x_range=x_range, y_range=y_range)
for post_mean, post_cov in zip(means.numpy(), covs.numpy()):
    plotting.plot_normal(fig1, post_mean, post_cov, n_std_dev=1, fill_alpha=.1, color='red')
fig1.circle(psis[:, 0], psis[:, 1], alpha=.5, color='blue')
fig1.cross(max_psis[:, 0], max_psis[:, 1], color='green', size=7)
show(fig1)