In [94]:
import sys
sys.path.insert(0, '../')

import math
import pickle
import numpy as np
import torch
import beer

import warnings
warnings.filterwarnings('ignore')


# For plotting.
from bokeh.io import show, output_notebook, push_notebook
from bokeh.plotting import figure, gridplot
from bokeh.models import Arrow, OpenHead, NormalHead, VeeHead
from bokeh.events import Tap, Pan, PanEnd, PanStart
from bokeh.application.handlers import FunctionHandler
from bokeh.application import Application
from bokeh.models import ColumnDataSource, LabelSet
from bokeh.models.widgets import Button, RadioButtonGroup, Dropdown, Slider
from bokeh.layouts import layout
output_notebook()

# Convenience functions for plotting.
import plotting

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [95]:
with open('timit_hmm_e10.mdl', 'rb') as f:
    hmm = pickle.load(f)
phone_mixture_set = hmm.modelset.original_modelset.modelsets[1]
len(phone_mixture_set.modelset)

564

In [96]:
hmm.start_pdf

{'sil': 0,
 'aa': 5,
 'ae': 8,
 'ah': 11,
 'ao': 14,
 'aw': 17,
 'ax': 20,
 'ay': 23,
 'b': 26,
 'ch': 29,
 'cl': 32,
 'd': 35,
 'dh': 38,
 'dx': 41,
 'eh': 44,
 'el': 47,
 'en': 50,
 'epi': 53,
 'er': 56,
 'ey': 59,
 'f': 62,
 'g': 65,
 'hh': 68,
 'ih': 71,
 'ix': 74,
 'iy': 77,
 'jh': 80,
 'k': 83,
 'l': 86,
 'm': 89,
 'n': 92,
 'ng': 95,
 'ow': 98,
 'oy': 101,
 'p': 104,
 'r': 107,
 's': 110,
 'sh': 113,
 't': 116,
 'th': 119,
 'uh': 122,
 'uw': 125,
 'v': 128,
 'vcl': 131,
 'w': 134,
 'y': 137,
 'z': 140,
 'zh': 143}

In [97]:
nphones = 47
nstates = 3
ncomps = 4
dim = 39 

means_precisions = phone_mixture_set.modelset.means_precisions

stats = []
for i in range(0, nphones):
    hmm_stats = []
    for j in range(0, nstates * ncomps):
        idx = i * (nstates * ncomps) + j
        g_stats = means_precisions[idx].posterior.natural_parameters - means_precisions[idx].prior.natural_parameters
        g_stats[-2] = -2 * g_stats[-2]
        hmm_stats.append(g_stats[None, :-1])
    stats.append(torch.cat(hmm_stats, dim=-1))
    
stats_counts = torch.cat(stats, dim=0).double()

In [167]:
llh = PhoneLogLikelihood(nphones, dim)
r_stats_counts = llh._reshape(stats_counts)
stats, counts = llh._extract_stats_counts(r_stats_counts)
stats_counts.shape, r_stats_counts.shape, stats.shape, counts.shape

(torch.Size([47, 948]),
 torch.Size([47, 12, 79]),
 torch.Size([47, 12, 78]),
 torch.Size([47, 12]))

In [165]:
ncomps * nstates * (2 * dim + 1)

948

In [98]:
class PhoneLogLikelihood:
    
    def __init__(self, nphones, dim):
        self.nphones = nphones
        self.dim = dim
        
    def get_counts(self, stats_counts):
        r_stats_counts = self._reshape(stats_counts, has_counts=True)
        _, counts = self._extract_stats_counts(r_stats_counts)
        return counts
        
    def __call__(self, psi, stats_counts, norm=False):
        r_psi = self._reshape(psi, has_counts=False)
        r_stats_counts = self._reshape(stats_counts, has_counts=True)
        stats, counts = self._extract_stats_counts(r_stats_counts)
        np1 = r_psi[:, :, :self.dim].exp()
        np2 = r_psi[:, :, self.dim:] 
        nparams = torch.cat([np1, np2], dim=-1)

        precs = np1
        means = np2 / np1
        log_norm = counts * torch.sum(.5 * np1 * (np2 ** 2) - precs.log(), dim=-1)
        log_norm += .5 * dim * math.log(2 * math.pi)

        retval = (nparams * stats).sum(dim=(1, 2)) - log_norm.sum(dim=-1)
        if norm:
            retval /= counts.sum()
        return retval
        
    def _extract_stats_counts(self, stats_counts):
        counts = stats_counts[:, :, -1]
        stats = stats_counts[:, :, :-1]
        return stats, counts

    def _reshape(self, array, has_counts=True):
        size = self.dim * 2 + 1 if has_counts else self.dim * 2
        return array.reshape(self.nphones, -1, size)
        
    def argmax(self, stats_counts):
        stats_counts = self._reshape(stats_counts, has_counts=True)
        stats, counts = self._extract_stats_counts(stats_counts)
        means = stats[:, :, dim:] / counts[:, :, None]
        variances = -2 * (stats[:, :, :dim] / counts[:, :, None]) - means**2
        precs = 1/variances
        psi = torch.cat([precs.log(), precs * means], dim=-1)
        return psi.reshape(self.nphones, -1)
    
    def hessian(self, psi, stats_counts, mode='diagonal'):
        r_psi = self._reshape(psi, has_counts=False)
        r_stats_counts = self._reshape(stats_counts, has_counts=True)
        stats, counts = self._extract_stats_counts(r_stats_counts)
        psi1, psi2 = r_psi[:, :, :self.dim], r_psi[:, :, self.dim:]
        e_psi1 = psi1.exp()
        T1, T2 = stats[:, :, :self.dim], stats[:, :, self.dim:]
        H_psi1 = T1 * e_psi1 - counts[:, :, None] * (psi2**2) / (2. * e_psi1)
        H_psi2 = - counts[:, :, None] / e_psi1
        H_diag = torch.cat([H_psi1, H_psi2], dim=-1) 
        H_diag = H_diag.reshape(nphones, -1)
        
        if mode == 'full':
            raise NotImplementedError('full Hessian matrix computation')
        elif mode == 'diagonal':
            return H_diag
        elif mode == 'scalar':
            return H_diag.mean(dim=-1)
        else:
            raise ValueError(f'Unknown Hessian approximation mode: "{mode}"')

In [99]:
llh = PhoneLogLikelihood(nphones, dim)
amax = llh.argmax(stats_counts)
avg_prec = (1/amax.var(dim=0)).max() / amax.shape[-1]

In [138]:
obs_dim = nstates * ncomps * 2 * dim
s_dim = 100
with torch.no_grad():
    M = torch.zeros(s_dim, obs_dim) 
    global_mean = torch.zeros(obs_dim)

model_hps3 = beer.models.GeneralizedSubspaceModel.create(
    PhoneLogLikelihood(nphones, dim), 
    M, 
    global_mean, 
    1,
    1.,
    prior_strength=1,
    hyper_prior_strength=1e-3,
    hessian_type='diagonal',
    noise_std=0.1
).double()

model_hps2 = beer.models.GeneralizedSubspaceModel.create(
    PhoneLogLikelihood(nphones, dim), 
    M, 
    global_mean, 
    1.,
    1.,
    prior_strength=1,
    hyper_prior_strength=1e-2,
    hessian_type='diagonal',
    noise_std=0.1
).double()

model_hps1 = beer.models.GeneralizedSubspaceModel.create(
    PhoneLogLikelihood(nphones, dim), 
    M, 
    global_mean, 
    1.,
    1.,
    prior_strength=1,
    hyper_prior_strength=1e-1,
    hessian_type='diagonal',
    noise_std=0.1
).double()

model_hps0 = beer.models.GeneralizedSubspaceModel.create(
    PhoneLogLikelihood(nphones, dim), 
    M, 
    global_mean, 
    1.,
    1.,
    prior_strength=1,
    hyper_prior_strength=1,
    hessian_type='diagonal',
    noise_std=0.1
).double()

#optim_hps0 = beer.BayesianModelOptimizer([[model_hps0.mean, model_hps0.subspace, model_hps0.precision]], lrate=1.)
optim_hps0 = beer.BayesianModelOptimizer(model_hps0.mean_field_factorization(), lrate=1.)
optim_hps1 = beer.BayesianModelOptimizer(model_hps1.mean_field_factorization(), lrate=1.)
optim_hps2 = beer.BayesianModelOptimizer(model_hps2.mean_field_factorization(), lrate=1.)
optim_hps3 = beer.BayesianModelOptimizer(model_hps3.mean_field_factorization(), lrate=1.)


elbos_hps0 = []
elbos_hps1 = []
elbos_hps2 = []
elbos_hps3 = []
nepochs = 1000
epochs = list(range(nepochs))
for i in range(1, nepochs + 1):
    optim_hps0.init_step()
    optim_hps1.init_step()
    optim_hps2.init_step()
    optim_hps3.init_step()
    elbo_hps0 = beer.evidence_lower_bound(model_hps0, stats_counts.double())
    elbo_hps1 = beer.evidence_lower_bound(model_hps1, stats_counts.double())
    elbo_hps2 = beer.evidence_lower_bound(model_hps2, stats_counts.double())
    elbo_hps3 = beer.evidence_lower_bound(model_hps3, stats_counts.double())
    elbo_hps0.backward()
    elbo_hps1.backward()
    elbo_hps2.backward()
    elbo_hps3.backward()
    optim_hps0.step()
    optim_hps1.step()
    optim_hps2.step()
    optim_hps3.step()
    elbos_hps0.append(float(elbo_hps0) / float(model_hps0.llh_func.get_counts(stats_counts).sum()))
    elbos_hps1.append(float(elbo_hps1) / float(model_hps0.llh_func.get_counts(stats_counts).sum()))
    elbos_hps2.append(float(elbo_hps2) / float(model_hps0.llh_func.get_counts(stats_counts).sum()))
    elbos_hps3.append(float(elbo_hps3) / float(model_hps0.llh_func.get_counts(stats_counts).sum()))

In [139]:
source = ColumnDataSource({
    'epoch': epochs,
    'elbo_hps0': elbos_hps0,
    'elbo_hps1': elbos_hps1,
    'elbo_hps2': elbos_hps2,
    'elbo_hps3': elbos_hps3,
})

fig = figure(width=500, height=400)
fig.line(source=source, x='epoch', y='elbo_hps0', color='grey', 
         alpha=.75, line_width=2, legend='γ=1')
fig.line(source=source, x='epoch', y='elbo_hps1', color='blue', 
         alpha=.75, line_width=2, legend='γ=1e-1')
fig.line(source=source, x='epoch', y='elbo_hps2', color='green', 
         alpha=.75, line_width=2, legend='γ=1e-2')
fig.line(source=source, x='epoch', y='elbo_hps3', color='red', 
         alpha=.75, line_width=2, legend='γ=1e-3')
fig.legend.location = 'bottom_right'
fig.legend.label_text_font_size = '16pt'
fig.xaxis.major_label_text_font_size = '16pt'
fig.xaxis.axis_label_text_font_size = '16pt'
fig.yaxis.major_label_text_font_size = '16pt'
fig.yaxis.axis_label_text_font_size = '16pt'
fig.min_border_left = 20
fig.min_border_right = 20
fig.xgrid.visible = False
fig.ygrid.visible = False
fig.xaxis.axis_label = 'epoch'
fig.yaxis.axis_label = 'ELBO'
fig.xaxis.ticker = [0,  500, 1000]
show(fig)

In [153]:
def variability(model):
    scales = np.array(model.scale.expected_value())
    scales.sort()
    inv_scales = 1/scales
    #inv_scales /= inv_scales.sum()
    return inv_scales


fig = figure(width=400, height=400)
fig.line(range(1, 100 + 1), variability(model_hps0), 
         color='grey', alpha=.75, line_width=2, legend='γ=1')
fig.line(range(1, 100 + 1), variability(model_hps1), 
         color='blue', alpha=.75,line_width=2, legend='γ=1e-1')
fig.line(range(1, 100 + 1), variability(model_hps2), 
         color='green', alpha=.75,line_width=2, legend='γ=1e-2')
fig.line(range(1, 100 + 1), variability(model_hps3), 
         color='red', alpha=.75, line_width=2, legend='γ=1e-3')
fig.legend.label_text_font_size = '16pt'
fig.xaxis.major_label_text_font_size = '16pt'
fig.xaxis.axis_label_text_font_size = '16pt'
fig.yaxis.major_label_text_font_size = '16pt'
fig.yaxis.axis_label_text_font_size = '16pt'
fig.min_border_left = 20
fig.min_border_right = 20
fig.xgrid.visible = False
fig.ygrid.visible = False
fig.xaxis.axis_label = 'subspace dimension'
fig.yaxis.axis_label = 'E[|w^2]'
show(fig)

In [152]:
M = model_hps3.subspace.posterior.mean.numpy()
width, height = M.T.shape
fig = figure(width=1000, height=400, x_range=(0, width), y_range=(0, height))
fig.image(image=[M], x=0, y=0, dw=width, dh=height)
show(fig)

In [176]:
C = np.zeros((1100, 1100))
(C.size * C.itemsize) / 1_000_000

9.68

In [87]:
fig = figure(width=400, height=400)#, y_axis_type='log')
fig.line(range(1, len(scales) + 1), inv_scales_ninfp, color='red', line_width=2)
fig.line(range(1, len(scales) + 1), inv_scales_infp, color='blue', line_width=2)
fig.line(range(1, len(scales) + 1), inv_scales_infp_h1e2, color='green', line_width=2)
fig.line(range(1, len(scales) + 1), inv_scales_infp_h1, color='grey', line_width=2)
show(fig)

NameError: name 'scales' is not defined

In [67]:
#inv_scales_ninfp = inv_scales
#inv_scales_infp = inv_scales
#inv_scales_infp_h1e2 = inv_scales
inv_scales_infp_h1 = inv_scales

NameError: name 'inv_scales' is not defined

In [103]:
cache = model_hps3.latent_posteriors(stats_counts)
means = cache['latent_means'].numpy()

In [104]:
phone_id = dict(**hmm.start_pdf)
del phone_id['sil']
phone_id = [(key, val) for key, val in phone_id.items()]
phones = [phone for phone, _ in sorted(phone_id, key=lambda x: x[1])]
phones

vowels_semivowels = [
    phones.index('aa'),
    phones.index('ae'),
    phones.index('ah'),
    phones.index('ao'),
    phones.index('aw'),
    phones.index('ax'),
    phones.index('ay'),
    phones.index('eh'),
    phones.index('er'),
    phones.index('ey'),
    phones.index('ih'),
    phones.index('ix'),
    phones.index('iy'),
    phones.index('ow'),
    phones.index('oy'),
    phones.index('uh'),
    phones.index('uw'),
    phones.index('el'),
    phones.index('l'),
    phones.index('r'),
    phones.index('w'),
    phones.index('y'),
]

nasals_flaps = [
    phones.index('en'),
    phones.index('m'),
    phones.index('n'),
    phones.index('ng'),
    phones.index('dx')
]

strong_fricatives = [
    phones.index('s'),
    phones.index('z'),
    phones.index('sh'),
    phones.index('zh'),
    phones.index('ch'),
    phones.index('jh')
]

weak_fricatives = [
    phones.index('v'),
    phones.index('f'),
    phones.index('dh'),
    phones.index('th'),
    phones.index('hh')
]

stops = [
    phones.index('b'),
    phones.index('d'),
    phones.index('g'),
    phones.index('p'),
    phones.index('t'),
    phones.index('k')
]

closures = [
    phones.index('vcl'),
    phones.index('cl'),
    phones.index('epi'),
]

In [105]:
W = model_hps3.subspace.expected_value().numpy()
np.linalg.norm(W, axis=1).argsort()[::-1]

array([31, 42, 58, 87, 41, 44, 11, 10,  1, 93, 30, 12, 48, 52, 39, 74, 91,
       62, 56, 53, 71, 68, 65, 66, 67, 69, 70, 73, 72, 76, 77, 75, 13, 92,
       54, 14, 55, 26, 27, 57, 94, 96, 78, 95, 79,  0, 98, 28, 59, 15, 97,
       16, 99, 80, 18, 89, 81, 17, 40, 32, 20, 84, 35, 34, 64, 29, 61, 86,
        5, 19, 22,  4, 83, 88, 21, 82,  3, 38,  2, 36, 33,  7, 85, 37, 90,
       60, 25,  8,  9, 23,  6, 63, 24, 43, 45, 46, 51, 50, 47, 49])

In [107]:
l_cov = cache['latent_cov'].numpy()
l_means = cache['latent_means'].numpy()

In [108]:
phones = []
for key in hmm.start_pdf:
    if key != 'sil':
        phones.append(key)
        
idx0 = 31
idx1 = 42


fig = figure(width=800, height=800)

# vowels/semi-vowels
source_v_sv = ColumnDataSource(data={
    'x': means[vowels_semivowels, idx0],
    'y': means[vowels_semivowels, idx1],
    'names': np.array(phones)[vowels_semivowels]
})
for i in vowels_semivowels:
    mean = np.array([l_means[i, idx0], l_means[i, idx1]])
    cov = np.zeros((2, 2))
    cov [0, 0] = l_cov[idx0, idx0]
    cov [1, 0] = l_cov[idx0, idx1]
    cov [0, 1] = l_cov[idx0, idx1]
    cov [1, 1] = l_cov[idx1, idx1]
    plotting.plot_normal(fig, mean, cov, alpha=.1, n_std_dev=1, color='red')
fig.cross(source=source_v_sv, x='x', y='y', color='red', legend='vowels/semi-vowels')
labels = LabelSet(source=source_v_sv, x='x', y='y', text='names',
                  x_offset=5, y_offset=5,  render_mode='canvas', text_color='red')
fig.add_layout(labels)

# nasals/flaps
source_nsl_flps = ColumnDataSource(data={
    'x': means[nasals_flaps, idx0],
    'y': means[nasals_flaps, idx1],
    'names': np.array(phones)[nasals_flaps]
})
for i in nasals_flaps:
    mean = np.array([l_means[i, idx0], l_means[i, idx1]])
    cov = np.zeros((2, 2))
    cov [0, 0] = l_cov[idx0, idx0]
    cov [1, 0] = l_cov[idx0, idx1]
    cov [0, 1] = l_cov[idx0, idx1]
    cov [1, 1] = l_cov[idx1, idx1]
    plotting.plot_normal(fig, mean, cov, alpha=.1, n_std_dev=1, color='green')
fig.cross(source=source_nsl_flps, x='x', y='y', color='green', legend='nasal/flaps')
labels = LabelSet(x='x', y='y', text='names',
                  x_offset=5, y_offset=5, source=source_nsl_flps, render_mode='canvas', text_color='green')
fig.add_layout(labels)

# strong fricatives
source_s_frct = ColumnDataSource(data={
    'x': means[strong_fricatives, idx0],
    'y': means[strong_fricatives, idx1],
    'names': np.array(phones)[strong_fricatives]
})
for i in strong_fricatives:
    mean = np.array([l_means[i, idx0], l_means[i, idx1]])
    cov = np.zeros((2, 2))
    cov [0, 0] = l_cov[idx0, idx0]
    cov [1, 0] = l_cov[idx0, idx1]
    cov [0, 1] = l_cov[idx0, idx1]
    cov [1, 1] = l_cov[idx1, idx1]
    plotting.plot_normal(fig, mean, cov, alpha=.1, n_std_dev=1, color='blue')
fig.cross(source=source_s_frct, x='x', y='y', color='blue', legend='strong fricatives')
labels = LabelSet(source=source_s_frct, x='x', y='y', text='names',
                  x_offset=5, y_offset=5, render_mode='canvas', text_color='blue')
fig.add_layout(labels)

# weak fricatives
source_w_frct = ColumnDataSource(data={
    'x': means[weak_fricatives, idx0],
    'y': means[weak_fricatives, idx1],
    'names': np.array(phones)[weak_fricatives]
})
for i in weak_fricatives:
    mean = np.array([l_means[i, idx0], l_means[i, idx1]])
    cov = np.zeros((2, 2))
    cov [0, 0] = l_cov[idx0, idx0]
    cov [1, 0] = l_cov[idx0, idx1]
    cov [0, 1] = l_cov[idx0, idx1]
    cov [1, 1] = l_cov[idx1, idx1]
    plotting.plot_normal(fig, mean, cov, alpha=.1, n_std_dev=1, color='orange')
fig.cross(source=source_w_frct, x='x', y='y', color='orange', legend='weak fricatives')
labels = LabelSet(source=source_w_frct, x='x', y='y', text='names',
                  x_offset=5, y_offset=5, render_mode='canvas', text_color='orange')
fig.add_layout(labels)

# stops
source_stop = ColumnDataSource(data={
    'x': l_means[stops, idx0],
    'y': l_means[stops, idx1],
    'names': np.array(phones)[stops]
})
for i in stops:
    mean = np.array([l_means[i, idx0], l_means[i, idx1]])
    cov = np.zeros((2, 2))
    cov [0, 0] = l_cov[idx0, idx0]
    cov [1, 0] = l_cov[idx0, idx1]
    cov [0, 1] = l_cov[idx0, idx1]
    cov [1, 1] = l_cov[idx1, idx1]
    plotting.plot_normal(fig, mean, cov, alpha=.1, n_std_dev=1, color='black')
fig.cross(source=source_stop, x='x', y='y', color='black', legend='stops')
labels = LabelSet(source=source_stop, x='x', y='y', text='names',
                  x_offset=5, y_offset=5, render_mode='canvas', text_color='black')
fig.add_layout(labels)

# closures
source_cls = ColumnDataSource(data={
    'x': means[closures, idx0],
    'y': means[closures, idx1],
    'names': np.array(phones)[closures]
})
for i in closures:
    mean = np.array([l_means[i, idx0], l_means[i, idx1]])
    cov = np.zeros((2, 2))
    cov [0, 0] = l_cov[idx0, idx0]
    cov [1, 0] = l_cov[idx0, idx1]
    cov [0, 1] = l_cov[idx0, idx1]
    cov [1, 1] = l_cov[idx1, idx1]
    plotting.plot_normal(fig, mean, cov, alpha=.1, n_std_dev=1, color='grey')
fig.cross(source=source_cls, x='x', y='y', color='grey', legend='closures')
labels = LabelSet(source=source_cls, x='x', y='y', text='names',
                  x_offset=5, y_offset=5, render_mode='canvas', text_color='grey')
fig.add_layout(labels)
fig.xgrid.visible = False
fig.ygrid.visible = False
fig.legend.location = 'bottom_right'
fig.legend.label_text_font_size = '16pt'
fig.xaxis.major_label_text_font_size = '16pt'
fig.xaxis.axis_label_text_font_size = '16pt'
fig.yaxis.major_label_text_font_size = '16pt'
fig.yaxis.axis_label_text_font_size = '16pt'
fig.min_border_left = 20
fig.min_border_right = 20


show(fig)

In [107]:
phones = []
for key in hmm.start_pdf:
    if key != 'sil':
        phones.append(key)
phones

['aa',
 'ae',
 'ah',
 'ao',
 'aw',
 'ax',
 'ay',
 'b',
 'ch',
 'cl',
 'd',
 'dh',
 'dx',
 'eh',
 'el',
 'en',
 'epi',
 'er',
 'ey',
 'f',
 'g',
 'hh',
 'ih',
 'ix',
 'iy',
 'jh',
 'k',
 'l',
 'm',
 'n',
 'ng',
 'ow',
 'oy',
 'p',
 'r',
 's',
 'sh',
 't',
 'th',
 'uh',
 'uw',
 'v',
 'vcl',
 'w',
 'y',
 'z',
 'zh']

In [32]:
class NormalLogLikekilhood:
    
    def __call__(self, psi_params, stats, counts):
        nparams = torch.zeros_like(psi_params)
        nparams[:, 0] = psi_params[:, 0]
        nparams[:, -1] = -torch.exp(psi_params[:, -1])
        return (nparams * stats).sum(dim=-1) - counts * self.log_norm(nparams)
    
    def log_norm(self, nparams):
        precs = 2 * torch.exp(nparams[:, -1])
        means = nparams[:, 0] / precs
        return .5 * (means ** 2) * precs - .5 * torch.log(precs) + math.log(2 * math.pi)
    
    def argmax(self, stats, counts):
        mT = stats / counts[:, None]
        means = mT[:, 0]
        precs = 1/(mT[:, -1] - means**2)
        return torch.cat([
            (precs * means)[:, None], 
            torch.log(.5 * precs)[:, None]], 
        dim=-1)

    def hessian(self, psi_params, stats, counts, mode='diagonal'):
        H = torch.zeros(len(psi_params), 2 * 2, dtype=psi_params.dtype, 
                        device=psi_params.device)
        
        e_psi1 = torch.exp(- psi_params[:, -1])
        H[:, 0] = -.5 * counts * e_psi1
        H[:, 1] = .5 * counts * psi_params[:, 0] * e_psi1
        H[:, 2] = .5 * counts * psi_params[:, 0] * e_psi1 
        H[:, 3] = -stats[:, 1] / e_psi1 -.25 * counts * ((psi_params[:, 0])**2) * e_psi1
        hessians = H.reshape(len(psi_params), 2, 2) 
        if mode == 'full':
            return hessians
        elif mode == 'diagonal':
            idxs = tuple(range(hessians.shape[-1]))
            return hessians[:, idxs, idxs]
        elif mode == 'scalar':
            dim = hessians.shape[-1]
            idxs = tuple(range(dim))
            return hessians[:, idxs, idxs].mean(dim=-1)
        else:
            raise ValueError(f'Unknown Hessian approximation mode: "{mode}"')

In [33]:
class CategoricalTrueModel:
    
    def __init__(self, noise_samples=15):
        self.scale = 1
        self.weights = np.array([[0, 1]])
        self.mean = np.zeros((1, 2))
        self.noise = np.random.randn(noise_samples, 1) 
        self.noise_samples = noise_samples
    
    @property
    def subspace(self):
        weights = np.array(self.weights)
        #weights = weights / np.sqrt(np.sum(weights ** 2))
        h = np.linspace(-20, 20, 1000)
        return h[:, None] * weights[0] + self.mean
    
    @property
    def samples(self):
        return self.noise @ (self.scale * self.weights) + self.mean
    
    def link(self, x):
        return np.exp(x) / (1 +  np.exp(x).sum(axis=1)[:, None])
    
    def resample(self):
        self.noise = np.random.randn(self.noise_samples, 1) 
        
    def generate_data(self, N):
        pi = self.link(self.samples)
        K = self.samples.shape[0]
        pi = np.c_[pi, 1 - pi.sum(axis=-1)]
        data = [np.random.choice(pi.shape[-1], size=N, p=pi[i])
                for i in range(pi.shape[0])]
        data = torch.LongTensor(data)
        
        # Compute the accumulated sufficient statistics of the 
        # samples for each model.
        T = torch.cat([
            beer.utils.onehot(data[i], max_label=3, 
                              dtype=data.dtype, device=data.device)[:, :-1].sum(dim=0)[None]
            for i in range(data.shape[0])
        ]).float()
        counts = torch.FloatTensor([N] * len(T))

        T = torch.cat([T, counts[:, None]], dim=-1)
        return T
    
    
class CategoricalLogLikekilhood:
    
    def __call__(self, nparams, stats, counts):
        return (nparams * stats).sum(dim=-1) - counts * self.log_norm(nparams)
    
    def log_norm(self, nparams):
        return torch.log(1 + nparams.exp().sum(dim=-1))
    
    def argmax(self, stats, counts):
        pi = stats / counts[:, None]
        pi_K = pi.sum(dim=1)
        eps = torch.tensor(1e-6, dtype=stats.dtype,
                            device=stats.device)
        remainder = torch.max(eps, 1 - pi_K)
        return torch.log(eps + pi) - torch.log(remainder)[:, None]
    
    def hessian(self, nparams, stats, counts, mode='diagonal'):
        #mean = self.mean(nparams)
        mean = torch.exp(nparams) / (1 +  torch.exp(nparams).sum(dim=1)[:, None])
        I = torch.eye(nparams.shape[-1], dtype=nparams.dtype,
                      device=nparams.device)
        hessians = (1e-6 + I)[None] * mean[:, :, None]
        hessians -= (mean[:, :, None] * mean[:, None, :])
        hessians *= -counts[:, None, None]
        if mode == 'full':
            return hessians
        elif mode == 'diagonal':
            idxs = tuple(range(hessians.shape[-1]))
            return hessians[:, idxs, idxs]
        elif mode == 'scalar':
            dim = hessians.shape[-1]
            idxs = tuple(range(dim))
            return hessians[:, idxs, idxs].mean(dim=-1)
        else:
            raise ValueError(f'Unknown Hessian approximation mode: "{mode}"')

In [4]:
class Inference:
    
    def __init__(self, true_model, llh, data_size=100):
        self.true_model = true_model
        self.data_size = data_size
        self.model = None
        self.model_scalar = self.create_model(llh, 0)
        self.model_diagonal = self.create_model(llh, 1)
        self.model_full = self.create_model(llh, 2)
        self.data = self.generate_data()
        
    def create_model(self, llh, hessian_type=0):
        if hessian_type == 0:
            htype = 'scalar'
        elif hessian_type == 1:
            htype = 'diagonal'
        else:
            htype = 'full'
        dim = 2
        s_dim = 1
        with torch.no_grad():
            M = torch.zeros(s_dim, dim) 
            global_mean = torch.zeros(dim)
        model = beer.models.GeneralizedSubspaceModel.create(
            llh, 
            M, 
            global_mean, 
            hessian_type=htype,
            noise_std=0.1
        )
        return model
        
    @property
    def ml_points(self):
        return self.model_full.llh_func.argmax(self.data[:, :-1], self.data[:, -1]).numpy()
    
    @property
    def subspace_scalar(self):
        weights = self.model_scalar.subspace.expected_value().numpy()
        weights = weights / np.sqrt(np.sum(weights ** 2))
        mean = self.model_scalar.mean.expected_value().numpy()
        h = np.linspace(-100, 100, 1000)
        return h[:, None] * weights[0] + mean
    
    @property
    def subspace_diagonal(self):
        weights = self.model_diagonal.subspace.expected_value().numpy()
        weights = weights / np.sqrt(np.sum(weights ** 2))
        mean = self.model_diagonal.mean.expected_value().numpy()
        h = np.linspace(-100, 100, 1000)
        return h[:, None] * weights[0] + mean
    
    @property
    def subspace_full(self):
        weights = self.model_full.subspace.expected_value().numpy()
        weights = weights / np.sqrt(np.sum(weights ** 2))
        mean = self.model_full.mean.expected_value().numpy()
        h = np.linspace(-100, 100, 1000)
        return h[:, None] * weights[0] + mean

    def generate_data(self):
        true_model = self.true_model
        N = self.data_size
        return self.true_model.generate_data(N)

    def resample(self):
        self.data = self.generate_data()
        
    def run(self, callback=None):
        optim = beer.BayesianModelOptimizer([self.model_scalar.mean_field_factorization()[0] + \
                                             self.model_diagonal.mean_field_factorization()[0] +
                                             self.model_full.mean_field_factorization()[0]], lrate=1.)
        for i in range(50):
            optim.init_step()
            elbo_scalar = beer.evidence_lower_bound(self.model_scalar, self.data)
            elbo_scalar.backward()
            elbo_diagonal = beer.evidence_lower_bound(self.model_diagonal, self.data)
            elbo_diagonal.backward()
            elbo_full = beer.evidence_lower_bound(self.model_full, self.data)
            elbo_full.backward()
            optim.step()
            if callback is not None:
                norm = float(self.data[:, -1].sum())
                callback(float(elbo_scalar) / norm, float(elbo_diagonal) / norm, float(elbo_full) / norm)

## Example of subspace

In [5]:
class Animation:
    
    def __init__(self, model, inference, param_plot):
        self.model = model
        self.inference = inference
        self.mean_source = ColumnDataSource()
        self.subspace_source = ColumnDataSource()
        self.infer_subspace_source = ColumnDataSource()
        self.samples_source = ColumnDataSource()
        self.data_source = ColumnDataSource()
        self.elbo_source = ColumnDataSource()
        self.elbo_scalar = []
        self.elbo_diag = []
        self.elbo_full = []
        self.epochs = []
        self._current_weights = self.model.weights
        self.param_plot = param_plot
        self.update()
        self.update_inference()
        
    def update_inference(self):
        infer_subspace_full = self.inference.subspace_full
        infer_subspace_diagonal = self.inference.subspace_diagonal
        infer_subspace_scalar = self.inference.subspace_scalar
        infer_subspace_full_mean = self.model.link(infer_subspace_full)
        infer_subspace_diagonal_mean = self.model.link(infer_subspace_diagonal)
        infer_subspace_scalar_mean = self.model.link(infer_subspace_scalar)
        data = self.inference.ml_points
        data_mean = self.model.link(data)
        self.infer_subspace_source.data = {
            'subspace_full_η1': infer_subspace_full[:, 0],
            'subspace_full_η2': infer_subspace_full[:, 1],
            'subspace_full_π1': infer_subspace_full_mean[:, 0],
            'subspace_full_π2': infer_subspace_full_mean[:, 1],
            'subspace_diagonal_η1': infer_subspace_diagonal[:, 0],
            'subspace_diagonal_η2': infer_subspace_diagonal[:, 1],
            'subspace_diagonal_π1': infer_subspace_diagonal_mean[:, 0],
            'subspace_diagonal_π2': infer_subspace_diagonal_mean[:, 1],
            'subspace_scalar_η1': infer_subspace_scalar[:, 0],
            'subspace_scalar_η2': infer_subspace_scalar[:, 1],
            'subspace_scalar_π1': infer_subspace_scalar_mean[:, 0],
            'subspace_scalar_π2': infer_subspace_scalar_mean[:, 1],
            
        }
        self.data_source.data = {
            'samples_η1': data[:, 0],
            'samples_η2': data[:, 1],
            'samples_π1': data_mean[:, 0],
            'samples_π2': data_mean[:, 1],
        }
        self.elbo_source.data = {
            'epochs': self.epochs,
            'elbo_scalar': self.elbo_scalar,
            'elbo_diag': self.elbo_diag,
            'elbo_full': self.elbo_full,
        }
        
    def update(self):
        mean = self.model.link(self.model.mean)
        subspace = self.model.link(self.model.subspace)
        samples = self.model.link(self.model.samples)
        self.mean_source.data = {
            'mean_η1': self.model.mean[:, 0],
            'mean_η2': self.model.mean[:, 1],
            'mean_π1': mean[:, 0],
            'mean_π2': mean[:, 1]
        }
        self.subspace_source.data = {
            'subspace_η1': self.model.subspace[:, 0],
            'subspace_η2': self.model.subspace[:, 1],
            'subspace_π1': subspace[:, 0],
            'subspace_π2': subspace[:, 1],
        }
        self.samples_source.data = {
            'samples_η1': self.model.samples[:, 0],
            'samples_η2': self.model.samples[:, 1],
            'samples_π1': samples[:, 0],
            'samples_π2': samples[:, 1],
        }
        
    def on_tap(self, event):
        self.model.mean = np.array([[event.x, event.y]])
        self.update()
        
    def on_pan(self, event):
        dx = 0.01
        angle = -event.delta_x * dx * np.pi
        rotation = np.array([[np.cos(angle), -np.sin(angle)], 
                             [np.sin(angle), np.cos(angle)]])
        self.model.weights = self._current_weights @ rotation 
        self.update()
    
    def on_end_pan(self, event):
        self._current_weights = self.model.weights
        
    def resample(self):
        self.model.resample()
        self.update()
        
    def change_sample_size(self, attr, old, new):
        self.model.noise_samples = new
        self.model.resample()
        self.update()
        
    def train(self):
        def callback(elbo_scalar, elbo_diag, elbo_full):
            self.elbo_scalar.append(elbo_scalar)
            self.elbo_diag.append(elbo_diag)
            self.elbo_full.append(elbo_full)
            self.epochs.append(len(self.epochs))
            self.update_inference()
        self.inference.run(callback=callback)
        self.update_inference()
    
    def resample_data(self):
        self.inference.resample()
        self.update_inference()
        
    def change_data_size(self, attr, old, new):
        self.inference.data_size = new
        self.inference.resample()
        self.update_inference()
        
    def reinit(self):
        self.elbo_scalar = []
        self.elbo_diag = []
        self.elbo_full = []
        self.epochs = []
        llh = self.inference.model_scalar.llh_func
        self.inference.model_scalar = self.inference.create_model(llh, 0)
        self.inference.model_diagonal = self.inference.create_model(llh, 1)
        self.inference.model_full = self.inference.create_model(llh, 2)
        self.update_inference()
    
    def spread(self, attr, old, new): 
        self.model.scale = new
        self.update()
        
    def __call__(self, doc):
        x_plot_range=(-10, 10)
        y_plot_range=(-1, 1)
        
        
        ##########################################################################
        # True Model plots
        
        # Natural parameters plot.
        fig1 = figure(width=400, height=400, x_range=x_plot_range, 
                      y_range=y_plot_range)
        fig1.xaxis.axis_label = 'ψ1'
        fig1.xaxis.major_label_text_font_size = '12pt'
        fig1.xaxis.axis_label_text_font_size = '15pt'
        fig1.yaxis.axis_label = 'ψ2'
        fig1.yaxis.major_label_text_font_size = '12pt'
        fig1.yaxis.axis_label_text_font_size = '15pt'
        fig1.toolbar.active_drag = None
        fig1.toolbar_location = None
        fig1.line(source=self.subspace_source, x='subspace_η1', y='subspace_η2')
        fig1.cross(source=self.mean_source, x='mean_η1', y='mean_η2', size=10, 
                   line_width=2, color='red')
        fig1.circle(source=self.samples_source, x='samples_η1', y='samples_η2',
                    alpha=.3, color='blue')

        # Expected parameters plot.
        fig2 = self.param_plot()
        fig2.min_border_top = 20
        fig2.xaxis.axis_label_text_font_size = '15pt'
        fig2.xaxis.major_label_text_font_size = '12pt'
        fig2.yaxis.axis_label_text_font_size = '15pt'
        fig2.yaxis.major_label_text_font_size = '12pt'
        fig2.circle(source=self.samples_source, x='samples_π1', y='samples_π2', alpha=.3, 
                    color='blue')
        fig2.line(source=self.subspace_source, x='subspace_π1', y='subspace_π2')
        fig2.cross(source=self.mean_source, x='mean_π1', y='mean_π2', size=10, 
                   line_width=2, color='red')

        fig1.on_event(Tap, self.on_tap)
        fig1.on_event(Pan, self.on_pan)
        fig1.on_event(PanEnd, self.on_end_pan)
        
        # Interface.
        resample_button = Button(label='Resample', button_type='success')
        resample_button.on_click(self.resample)
        menu = [('Full', 'full'), ('Diagonal', 'diagonal'), ('Scalar', 'scalar')]
        hessian_select = Dropdown(label="Hessian matrix", button_type="success", menu=menu)
        sample_size_slider = Slider(start=10, end=100, value=20, step=1, title="number of samples")
        sample_size_slider.on_change('value', self.change_sample_size)
        
        ##########################################################################
        # Inference plots.
        
        # Natural parameters plot.
        fig3 = figure(width=400, height=400, x_range=x_plot_range, 
                      y_range=y_plot_range)
        fig3.xaxis.axis_label = 'ψ1'
        fig3.xaxis.major_label_text_font_size = '12pt'
        fig3.xaxis.axis_label_text_font_size = '15pt'
        fig3.yaxis.axis_label = 'ψ2'
        fig3.yaxis.major_label_text_font_size = '12pt'
        fig3.yaxis.axis_label_text_font_size = '15pt'
        fig3.toolbar.active_drag = None
        fig3.toolbar_location = None
        fig3.line(source=self.subspace_source, x='subspace_η1', y='subspace_η2', color='black', line_width=4, alpha=.3)
        fig3.line(source=self.infer_subspace_source, x='subspace_full_η1', y='subspace_full_η2', line_width=2, color='red')
        fig3.line(source=self.infer_subspace_source, x='subspace_diagonal_η1', y='subspace_diagonal_η2', line_width=2, color='blue')
        fig3.line(source=self.infer_subspace_source, x='subspace_scalar_η1', y='subspace_scalar_η2', line_width=2, color='green')
        fig3.circle(source=self.data_source, x='samples_η1', y='samples_η2',
                      alpha=.3, color='black')

        # Expected parameters plot.
        fig4 = self.param_plot()
        fig4.min_border_top = 20
        fig4.xaxis.axis_label_text_font_size = '15pt'
        fig4.xaxis.major_label_text_font_size = '12pt'
        fig4.yaxis.axis_label_text_font_size = '15pt'
        fig4.yaxis.major_label_text_font_size = '12pt'
        fig4.line(source=self.subspace_source, x='subspace_π1', y='subspace_π2', color='black', line_width=4, alpha=.3)
        fig4.circle(source=self.data_source, x='samples_π1', y='samples_π2', alpha=.3, 
                    color='black')
        fig4.line(source=self.infer_subspace_source, x='subspace_full_π1', y='subspace_full_π2', color='red', line_width=2, alpha=.5, legend='H_full')
        fig4.line(source=self.infer_subspace_source, x='subspace_diagonal_π1', y='subspace_diagonal_π2', color='blue', line_width=2, alpha=.5, legend='H_diag')
        fig4.line(source=self.infer_subspace_source, x='subspace_scalar_π1', y='subspace_scalar_π2', color='green', line_width=2, alpha=.5, legend='H_scalar')
        fig4.legend.location = 'top_right'
        #fig4.cross(source=self.mean_source, x='mean_π1', y='mean_π2', size=10, 
        #           line_width=2, color='red')
        
        fig5 = figure(width=400, height=400)
        fig5.xaxis.axis_label = 'epoch'
        fig5.xaxis.major_label_text_font_size = '12pt'
        fig5.xaxis.axis_label_text_font_size = '15pt'
        fig5.yaxis.axis_label = 'ELBO'
        fig5.yaxis.major_label_text_font_size = '12pt'
        fig5.yaxis.axis_label_text_font_size = '15pt'
        #fig5.toolbar.active_drag = None
        #fig5.toolbar_location = None
        fig5.line(source=self.elbo_source, x='epochs', y='elbo_scalar', color='green', line_width=2, alpha=.3)
        fig5.line(source=self.elbo_source, x='epochs', y='elbo_diag', color='blue', line_width=2, alpha=.3)
        fig5.line(source=self.elbo_source, x='epochs', y='elbo_full', color='red', line_width=2, alpha=.3)
        
        # Inference model interface
        resample_data_button = Button(label='Resample data', button_type='success')
        resample_data_button.on_click(self.resample_data)
        data_size_slider = Slider(start=10, end=1000, value=100, step=1, title='number of samples per model (i.e. "words per documents")')
        data_size_slider.on_change('value', self.change_data_size)
        spread_slider = Slider(start=1, end=10, value=1, step=.5, title="subspace scaling")
        spread_slider.on_change('value', self.spread)
        hessian_select = Dropdown(label="Hessian matrix", button_type="success", menu=menu)
        sample_size_slider = Slider(start=10, end=100, value=20, step=1, title='number of models (i.e. "number of documents")')
        sample_size_slider.on_change('value', self.change_sample_size)
        train_button = Button(label='Train', button_type='success')
        train_button.on_click(self.train)
        reinit_button = Button(label='Initialize model', button_type='warning')
        reinit_button.on_click(self.reinit)
        
        l = layout([
            [[sample_size_slider, data_size_slider, resample_button, spread_slider], 
            #[[fig1, fig2]], [resample_data_button, train_button]],
            [[fig1, fig2], [fig3, fig4, fig5]], [resample_data_button, train_button, reinit_button]],
        ])
        
        doc.add_root(l)

In [6]:
def parameter_plot_categorical():
    fig = figure(width=400, height=400, x_range=(0, 1), y_range=(0, 1))
    fig.line([1, 0], [0, 1], line_width=2, color='black')
    fig.xaxis.axis_label = 'π1'
    fig.yaxis.axis_label = 'π2'
    fig.toolbar_location = None
    fig.toolbar.active_drag = None
    return fig

def parameter_plot_normal():
    fig = figure(width=400, height=400, x_range=(-5, 5), 
                y_range=(0, 20))
    fig.xaxis.axis_label = 'μ'
    fig.yaxis.axis_label = 'λ'
    fig.toolbar_location = None
    fig.toolbar.active_drag = None
    return fig
    


true_model = NormalTrueModel(noise_samples=20)
inference = Inference(true_model, NormalLogLikekilhood())
handler = FunctionHandler(Animation(true_model, inference, parameter_plot_normal))

#true_model = CategoricalTrueModel(noise_samples=20)
#inference = Inference(true_model, CategoricalLogLikekilhood())
#handler = FunctionHandler(Animation(true_model, inference, parameter_plot_categorical))

app = Application(handler)
show(app, notebook_url="localhost:8888")

In [54]:
inference.run()

RuntimeError: The size of tensor a (2) must match the size of tensor b (20) at non-singleton dimension 1

In [52]:
T = inference.data
model = inference.model_full
params = model.llh_func.argmax(T[:, :-1], T[:, -1])
hessians = model.llh_func.hessian(params, T[:, :-1], T[:, -1], mode='full')
idxs = tuple(range(2))
hessians[:, idxs, idxs].min(dim=-1)[0]

tensor([ -1548.0605,   -208.7843, -13125.5518,  -1907.1984,  -1951.8234,
          -494.6999,  -4803.7285,   -430.4739,  -2989.7507, -18189.8555,
         -1167.2666,  -7906.4854,   -243.4854,  -1887.6621,  -2705.0090,
         -3577.3328,   -183.0521,   -199.0589, -28110.0430,  -6489.3906])

In [87]:
model = inference.model_scalar
weights = model.subspace.expected_value()
mean = model.mean.expected_value()
idx = 18
h_k = model.latent_posteriors(T)[idx].mean
h_k @ weights + mean, inference.ml_points[idx]

NameError: name 'T' is not defined

In [260]:
w = weights.numpy()
w = w / np.sqrt(np.sum(w ** 2))
np.linalg.norm(w)

1.0

In [10]:
model = inference.model_full
model.mean.posterior._log_norm()

tensor(62011.0977)

## Generate synthetic data

In [86]:
# Draw 100 values for each model.
N = 100
#true_model = CategoricalModel()
pi = true_model.link(true_model.samples)
K = true_model.samples.shape[0]
pi = np.c_[pi, 1 - pi.sum(axis=-1)]
data = [np.random.choice(pi.shape[-1], size=N, p=pi[i])
        for i in range(pi.shape[0])]
data = torch.LongTensor(data)

# Compute the accumulated sufficient statistics of the 
# samples for each model.
T = torch.cat([
    beer.utils.onehot(data[i], max_label=pi.shape[-1], 
                      dtype=data.dtype, device=data.device)[:, :-1].sum(dim=0)[None]
    for i in range(data.shape[0])
]).float()
counts = torch.FloatTensor([N] * len(T))

T = torch.cat([T, counts[:, None]], dim=-1)
print('# models:', K), T.shape

ValueError: probabilities are not non-negative

# Likelihood function

In [109]:
class CategoricalLogLikekilhood:
    
    def __call__(self, nparams, stats, counts):
        return (nparams * stats).sum(dim=-1) - counts * self.log_norm(nparams)
    
    def log_norm(self, nparams):
        return torch.log(1 + nparams.exp().sum(dim=-1))
    
    def argmax(self, stats, counts):
        pi = stats / counts[:, None]
        pi_K = pi.sum(dim=1)
        eps = torch.tensor(1e-6, dtype=stats.dtype,
                            device=stats.device)
        remainder = torch.max(eps, 1 - pi_K)
        return torch.log(eps + pi) - torch.log(remainder)[:, None]
    
    def mean(self, nparams):
        norm = (1 +  np.exp(nparams).sum(dim=-1))
        return nparams.exp() / norm[:, None]
    
    def hessian(self, nparams, counts, mode='diagonal'):
        mean = self.mean(nparams)
        I = torch.eye(nparams.shape[-1], dtype=nparams.dtype,
                      device=nparams.device)
        hessians = (1e-6 + I)[None] * mean[:, :, None]
        hessians -= (mean[:, :, None] * mean[:, None, :])
        hessians *= -counts[:, None, None]
        if mode == 'full':
            return hessians
        elif mode == 'diagonal':
            idxs = tuple(range(hessians.shape[-1]))
            return hessians[:, idxs, idxs]
        elif mode == 'scalar':
            dim = hessians.shape[-1]
            idxs = tuple(range(dim))
            return hessians[:, idxs, idxs].mean(dim=-1)
        else:
            raise ValueError(f'Unknown Hessian approximation mode: "{mode}"')
            
l = CategoricalLogLikekilhood()
opt = l.argmax(T[:, :-1], T[:, -1])
l.hessian(opt, T[:, -1])

tensor([[-0.0001, -5.6402],
        [-2.9101, -4.7502],
        [-0.0001, -5.6402],
        [-0.0001, -1.9602],
        [-0.0001, -4.7502],
        [-0.0001, -3.8402],
        [-0.0001, -1.9602],
        [-4.7501, -7.3602],
        [-0.0001, -3.8402],
        [-0.9901, -1.9602],
        [-0.0001, -3.8402],
        [-1.9601, -3.8402],
        [-0.0001, -1.9602],
        [-0.0001, -4.7502],
        [-0.0001, -5.6402],
        [-0.0001, -1.9602],
        [-0.0001, -3.8402],
        [-0.0001, -2.9102],
        [-0.9901, -5.6402],
        [-0.0001, -6.5102]])

## Prior/posterior distributions for the inference

In [115]:
dim = 2
s_dim = 3
M = torch.zeros(s_dim, dim) + math.sqrt(2)
#M = torch.from_numpy(true_model.weights).float()
global_mean = torch.zeros(dim)

model = beer.models.GeneralizedSubspaceModel.create(CategoricalLogLikekilhood(), M, global_mean, hessian_type='scalar')
model = model.double()

optim = beer.BayesianModelOptimizer([[model.subspace], [model.mean]], lrate=1.)

for i in range(200):
    optim.init_step()
    elbo = beer.evidence_lower_bound(model, T.double())
    print(f'{float(elbo) / float(T[:, -1].mean()):.3f}')
    elbo.backward()
    optim.step()

-20.228
-6.445
-5.043
-4.897
-4.821
-4.769
-4.735
-4.713
-4.698
-4.689
-4.681
-4.677
-4.672
-4.670
-4.667
-4.666
-4.663
-4.662
-4.660
-4.659
-4.657
-4.656
-4.654
-4.653
-4.651
-4.651
-4.649
-4.648
-4.646
-4.645
-4.644
-4.643
-4.641
-4.641
-4.639
-4.638
-4.637
-4.636
-4.635
-4.634
-4.633
-4.632
-4.631
-4.630
-4.629
-4.628
-4.627
-4.626
-4.625
-4.625
-4.623
-4.623
-4.622
-4.621
-4.620
-4.620
-4.619
-4.618
-4.617
-4.617
-4.616
-4.615
-4.614
-4.614
-4.613
-4.613
-4.612
-4.612
-4.611
-4.610
-4.609
-4.609
-4.608
-4.608
-4.607
-4.607
-4.606
-4.606
-4.605
-4.605
-4.604
-4.604
-4.603
-4.603
-4.602
-4.602
-4.602
-4.601
-4.601
-4.600
-4.600
-4.600
-4.599
-4.599
-4.598
-4.598
-4.598
-4.597
-4.597
-4.597
-4.596
-4.596
-4.596
-4.596
-4.595
-4.595
-4.594
-4.594
-4.594
-4.594
-4.593
-4.593
-4.593
-4.593
-4.592
-4.592
-4.592
-4.592
-4.591
-4.591
-4.591
-4.591
-4.591
-4.590
-4.590
-4.590
-4.590
-4.590
-4.589
-4.589
-4.589
-4.589
-4.589
-4.589
-4.588
-4.588
-4.588
-4.588
-4.588
-4.588
-4.587
-4.587
-4.58

In [11]:
model.subspace.posterior.moments(), moments_matrixnormal(model.subspace.posterior)

((tensor([[-0.0535,  1.2347]], dtype=torch.float64),
  tensor([[1.5361]], dtype=torch.float64)),
 (tensor([[-0.0535,  1.2347]], dtype=torch.float64),
  tensor([[1.5361]], dtype=torch.float64)))

In [10]:
l_posts = model.optimal_latent_posteriors(T.double())
opt = model.llh_func.argmax(T[:, :-1].double(), T[:, -1].double())
hessians = model.llh_func.hessian(opt, T[:, -1].double())
W = model.subspace.expected_value()
m = model.mean.expected_value()
opt @ W.t(), (opt - m)[0], hessians[0] @ (opt-m)[0]

fig = figure(width=400, height=400, x_range=(-5, 5), y_range=(-5, 5))
#fig.circle(opt[:, 0].numpy(), opt[:, 1].numpy())
for i in range(10):
    if i == 8 or True:
        fig.circle(opt[i:i+1, 0].numpy(), opt[i:i+1, 1].numpy())
        y = (W.t() @ l_posts[i].mean + m ).numpy()
        print(l_posts[i].cov)
        fig.circle(y[0], y[1], color='green')
show(fig)

tensor([[0.0721]], dtype=torch.float64)
tensor([[0.0492]], dtype=torch.float64)
tensor([[0.0460]], dtype=torch.float64)
tensor([[0.0378]], dtype=torch.float64)
tensor([[0.0381]], dtype=torch.float64)
tensor([[0.0390]], dtype=torch.float64)
tensor([[0.0609]], dtype=torch.float64)
tensor([[0.0385]], dtype=torch.float64)
tensor([[0.0397]], dtype=torch.float64)
tensor([[0.0417]], dtype=torch.float64)


In [18]:
W = model.subspace.expected_value()
l_posts = model.optimal_latent_posteriors(T.double())
l_posts

[NormalFullCovariancePrior(mean=tensor([-0.2671], dtype=torch.float64), cov=tensor([[0.0307]], dtype=torch.float64)),
 NormalFullCovariancePrior(mean=tensor([0.8820], dtype=torch.float64), cov=tensor([[0.0350]], dtype=torch.float64)),
 NormalFullCovariancePrior(mean=tensor([-1.2305], dtype=torch.float64), cov=tensor([[0.0519]], dtype=torch.float64)),
 NormalFullCovariancePrior(mean=tensor([-1.2958], dtype=torch.float64), cov=tensor([[0.0587]], dtype=torch.float64)),
 NormalFullCovariancePrior(mean=tensor([0.4197], dtype=torch.float64), cov=tensor([[0.0299]], dtype=torch.float64)),
 NormalFullCovariancePrior(mean=tensor([0.1172], dtype=torch.float64), cov=tensor([[0.0292]], dtype=torch.float64)),
 NormalFullCovariancePrior(mean=tensor([0.1530], dtype=torch.float64), cov=tensor([[0.0293]], dtype=torch.float64)),
 NormalFullCovariancePrior(mean=tensor([0.8916], dtype=torch.float64), cov=tensor([[0.0351]], dtype=torch.float64)),
 NormalFullCovariancePrior(mean=tensor([1.1981], dtype=torch.

In [208]:
A[0] @ x[0]
A[1] @ x[1]

tensor([-0.1655,  0.0510])

In [151]:
H = torch.cat([pdf.mean[None, :] for pdf in l_posts], dim=0)
HH = torch.cat([pdf.cov[None, :, :] for pdf in l_posts], dim=0)
HH.shape, H.shape

(torch.Size([15, 3, 3]), torch.Size([15, 3]))

In [163]:
stats  = T[:, :-1]
counts = T[:, -1]

opt = model.llh_func.argmax(stats, counts)
hessians = model.llh_func.hessian(opt)

In [164]:
U = model.subspace.posterior.cov

WHW = trace_array(hessians)[:, None, None] * U + W @ hessians @ W.t()
m_opt = (nparams - m[None])
Hm = (hessians * m_opt[:, None, :]).sum(dim=-1)
counts = stats.sum(dim=-1)
return torch.cat([
    (W @  Hm.t()).reshape(-1, len(hessians)).t(),
    -.5 * N * WHW.reshape(len(hessians), -1)
], dim=-1)

NameError: name 'nparams' is not defined

In [45]:
llh, acc_stats = model.expected_log_likelihood(T)
print(llh.sum() - 0 * latent_kl_div(model.latent_prior, model.latent_posteriors))

for i in range(10):
    prior_nparams = model.latent_prior.natural_parameters
    for i, latent_posterior in enumerate(model.latent_posteriors):
        grad = prior_nparams + acc_stats[i] - latent_posterior.natural_parameters
        latent_posterior.natural_parameters = latent_posterior.natural_parameters + 1 * grad

    llh, acc_stats = model.expected_log_likelihood(T)
    print(llh.sum() - 0 *latent_kl_div(model.latent_prior, model.latent_posteriors))

AttributeError: 'SubspaceCategorical' object has no attribute 'latent_posteriors'

In [332]:
nparams = inverse_link(T / N)
log_likelihood(nparams, T, N)

tensor([-102.1675,  -86.9528,  -98.6099,  -97.6429,  -99.7344,  -96.2011,
         -99.7344, -100.8999,  -97.3253,  -97.1200,  -99.5243,  -96.5711,
         -96.8510,  -99.8975,  -91.6465])

In [666]:
l_moments = [moments_normal(pdf) 
             for pdf in model.latent_posteriors]
H = torch.cat([l_m[0].reshape(-1, 1) 
               for l_m in l_moments], dim=-1)
HH = torch.cat([moments[1].reshape(-1, 1) 
                for moments in l_moments], dim=-1)
m, mm = moments_normal(model.mean_posterior)
W, WW = moments_matrixnormal(model.subspace_posterior)

nparams = (W.t() @ H + m[:, None]).t()
hessians = hessian(nparams)
W, WW = moments_matrixnormal(model.subspace_posterior)
U = (WW - W @ W.t()) / W.shape[-1]

def latent_posterior_accumulate(m, W, WW, hessians):
    U = (WW - W @ W.t()) / W.shape[-1]
    WHW = trace_array(hessians)[:, None, None] * U + W @ hessians @ W.t()
    Hm = hessians @ m
    
    return torch.cat([
        (W @ (T - Hm).t()).reshape(-1, len(hessians)).t(),
        -.5 * WHW.reshape(len(hessians), -1)
    ], dim=-1)

latent_posterior_accumulate(m, W, WW, hessians).shape

tensor([[0.3179, 0.3827],
        [0.3187, 0.3802],
        [0.3143, 0.3931],
        [0.3164, 0.3870],
        [0.3168, 0.3858],
        [0.3172, 0.3846],
        [0.3169, 0.3854],
        [0.3160, 0.3882],
        [0.3167, 0.3861],
        [0.3164, 0.3869],
        [0.3164, 0.3872],
        [0.3157, 0.3893],
        [0.3156, 0.3893],
        [0.3145, 0.3926],
        [0.3170, 0.3854]])


torch.Size([15, 2])

In [1188]:
W = np.array([[[1, 2], [2, 1]], [[2, 3], [3, 2]]])
x = np.array([[1, 2], [3, 4]])
W.shape, x.shape
(W * x[:, None, :]).sum(axis=-1), W[0] @ x[0], W[1] @ x[1]

(array([[ 5,  4],
        [18, 17]]), array([5, 4]), array([18, 17]))

## Likelihood function

In [77]:
class CategoricalLogLikelihood:
    
    def __init__(self, dim):
        self.dim = dim
        
    @staticmethod
    def sufficient_statistics(data):
        beer.utils.onehot
        
    def log_norm(self, natural_params):
        return torch.log(1 + torch.exp(natural_params).sum())
        
    def __call__(natural_params, s_stats):
        return natural_params @ s_stats - log_norm(natural_params)
    
llh = CategoricalLogLikelihood(data.shape[-1])
llh.log_norm()

TypeError: log_norm() missing 1 required positional argument: 'natural_params'

In [39]:
data.sum(dim=1)

tensor([0.9900, 1.0000, 1.0100, 1.0100, 1.0100, 1.0000, 1.0300, 1.0700, 0.9900,
        1.0100, 0.9900, 1.0000, 1.0200, 1.0000, 0.8700])

In [None]:
samples = np.random.multivariate_normal([2, 3], [[2, .75], [.75, 1]], 100)
X = torch.from_numpy(samples).float()


epochs = 100
lrate = 1.

optim = beer.BayesianModelCoordinateAscentOptimizer(model.mean_field_groups)
elbos = []
for epoch in range(epochs):
    optim.zero_grad()
    elbo = beer.evidence_lower_bound(model, X, datasize=len(X))
    elbo.natural_backward()
    elbos.append(float(elbo) / len(X))
    optim.step()

fig = figure(title='ELBO', width=400, height=400, x_axis_label='step',
              y_axis_label='ln p(X)')
fig.line(range(len(elbos)), elbos)
fig.legend.location = 'bottom_right'

show(fig)

In [None]:
mean, precision = model.mean_precision.expected_value()
mean = mean.numpy()
cov = precision.inverse().numpy()
data = X.numpy()

fig = figure(width=400, height=400,
    x_range=(mean[0] - 5, mean[0] + 5),
    y_range=(mean[1] - 5, mean[1] + 5)
)
fig.circle(data[:, 0], data[:, 1])
plotting.plot_normal(fig, mean, cov, alpha=.5, )

show(fig)

In [None]:
model.mean_precision.posterior.strength 

In [None]:
model.mean_precision.posterior.strength = 1.

In [None]:
mean, precision = model.mean_precision.expected_value()
mean = mean.numpy()
cov = precision.inverse().numpy()
data = X.numpy()

fig = figure(width=400, height=400,
    x_range=(mean[0] - 5, mean[0] + 5),
    y_range=(mean[1] - 5, mean[1] + 5)
)
fig.circle(data[:, 0], data[:, 1])
plotting.plot_normal(fig, mean, cov, alpha=.5, )

show(fig)

In [None]:
model.mean_precision.posterior.strength 

In [None]:
class Normal(beer.BayesianModel):
    
    def __init__(self, prior_mean_precision, posterior_mean_precision):
        super().__init__()
        self.mean_precision = beer.BayesianParameter(prior_mean_precision, 
                                                     posterior_mean_precision)
        
    @property
    def dim(self):
        mean, _ = self.mean_precision.expected_value()
        return len(mean)
    
    def sufficient_statistics(self, data):
        return torch.cat([
            -.5 * data**2,
            data,
            -.5 * torch.ones(data.size(0), 1, dtype=data.dtype, device=data.device),
            .5 * torch.ones(data.size(0), 1, dtype=data.dtype, device=data.device),
        ], dim=-1)
    
    def mean_field_factorization(self):
        return [[self.mean_precision]]
        
    def expected_log_likelihood(self, stats):            
        nparams = self.mean_precision.expected_natural_parameters()
        return stats @ nparams  -.5 * self.dim * math.log(2 * math.pi)
    
    def accumulate(self, stats):
        return {self.mean_precision: stats.sum(dim=0)}

In [None]:
prior_mean_precision = beer.NormalGammaPrior(
    torch.zeros(2), 
    torch.tensor(1.), 
    torch.tensor(1.), 
    torch.ones(2)
)
posterior_mean_precision = beer.NormalGammaPrior(
    torch.zeros(2), 
    torch.tensor(1.), 
    torch.tensor(1.), 
    torch.ones(2)
)
model = Normal(prior_mean_precision, posterior_mean_precision)

In [None]:
samples = np.random.multivariate_normal([2, 3], [[2, .75], [.75, 1]], 100)
X = torch.from_numpy(samples).float()


epochs = 100
lrate = 1.

optim = beer.BayesianModelCoordinateAscentOptimizer(model.mean_field_groups)
elbos = []
for epoch in range(epochs):
    optim.zero_grad()
    elbo = beer.evidence_lower_bound(model, X, datasize=len(X))
    elbo.natural_backward()
    elbos.append(float(elbo) / len(X))
    optim.step()

fig = figure(title='ELBO', width=400, height=400, x_axis_label='step',
              y_axis_label='ln p(X)')
fig.line(range(len(elbos)), elbos)
fig.legend.location = 'bottom_right'

show(fig)

In [None]:
model.mean_precision.posterior.to_std_parameters()

In [None]:
mean, precision = model.mean_precision.expected_value()
print(precision)
mean = mean.numpy()
cov = (1 / precision).diag().numpy()
data = X.numpy()

fig = figure(width=400, height=400,
    x_range=(mean[0] - 5, mean[0] + 5),
    y_range=(mean[1] - 5, mean[1] + 5)
)
fig.circle(data[:, 0], data[:, 1])
plotting.plot_normal(fig, mean, cov, alpha=.5, )

show(fig)

In [None]:
model.mean_precision.posterior.strength

In [None]:
model.mean_precision.posterior.strength = .5

In [None]:
mean, precision = model.mean_precision.expected_value()
print(precision)
mean = mean.numpy()
cov = (1 / precision).diag().numpy()
data = X.numpy()

fig = figure(width=400, height=400,
    x_range=(mean[0] - 5, mean[0] + 5),
    y_range=(mean[1] - 5, mean[1] + 5)
)
fig.circle(data[:, 0], data[:, 1])
plotting.plot_normal(fig, mean, cov, alpha=.5, )

show(fig)

In [None]:
model.mean_precision.posterior

In [None]:
class Normal(beer.BayesianModel):
    
    def __init__(self, prior_mean_precision, posterior_mean_precision):
        super().__init__()
        self.mean_precision = beer.BayesianParameter(prior_mean_precision, 
                                                     posterior_mean_precision)
        
    @property
    def dim(self):
        mean, _ = self.mean_precision.expected_value()
        return len(mean)
    
    def sufficient_statistics(self, data):
        dim = data.shape[1]
        return torch.cat([
            -.5 * torch.sum(data**2, dim=-1).reshape(-1, 1),
            data,
            -.5 * torch.ones(data.size(0), 1, dtype=data.dtype, device=data.device),
            .5 * dim * torch.ones(data.size(0), 1, dtype=data.dtype, device=data.device),
        ], dim=-1)
    
    def mean_field_factorization(self):
        return [[self.mean_precision]]
        
    def expected_log_likelihood(self, stats):  
        nparams = self.mean_precision.expected_natural_parameters()
        return stats @ nparams  -.5 * self.dim * math.log(2 * math.pi)
    
    def accumulate(self, stats):
        return {self.mean_precision: stats.sum(dim=0)}

In [None]:
prior_mean_precision = beer.IsotropicNormalGammaPrior(
    torch.zeros(2), 
    torch.tensor(1.), 
    torch.tensor(1.), 
    torch.tensor(1.)
)
posterior_mean_precision = beer.IsotropicNormalGammaPrior(
    torch.zeros(2), 
    torch.tensor(1.), 
    torch.tensor(1.), 
    torch.tensor(1.)
)
model = Normal(prior_mean_precision, posterior_mean_precision)

In [None]:
samples = np.random.multivariate_normal([2, 3], [[2., 0], [0, 2.]], 50)
X = torch.from_numpy(samples).float()
epochs = 10
lrate = 1.

optim = beer.BayesianModelCoordinateAscentOptimizer(model.mean_field_groups)
elbos = []
for epoch in range(epochs):
    optim.zero_grad()
    elbo = beer.evidence_lower_bound(model, X, datasize=len(X))
    elbo.natural_backward()
    elbos.append(float(elbo) / len(X))
    optim.step()

fig = figure(title='ELBO',width=400, height=400, x_axis_label='step',
              y_axis_label='ln p(X)')
fig.line(range(len(elbos)), elbos)
fig.legend.location = 'bottom_right'

show(fig)

In [None]:
mean, scale, shape, rate = model.mean_precision.posterior.to_std_parameters()
mean, scale, shape, rate

In [None]:
torch.log(torch.tensor(480.3992)), torch.digamma(torch.tensor(51.)) 

In [None]:
torch.sum(mean * mean)

In [None]:
model.mean_precision.posterior.expected_sufficient_statistics()

In [None]:
51 / 517

In [None]:
mean, precision = model.mean_precision.expected_value()
print(precision)
mean = mean.numpy()
cov = ((1 / precision) * torch.eye(2)).numpy() 
data = X.numpy()

fig = figure(width=400, height=400,
    x_range=(mean[0] - 5, mean[0] + 5),
    y_range=(mean[1] - 5, mean[1] + 5)
)
fig.circle(data[:, 0], data[:, 1])
plotting.plot_normal(fig, mean, cov, alpha=.5, )

show(fig)

In [None]:
model.mean_precision.posterior.strength

In [None]:
model.mean_precision.posterior.strength = 1.

In [None]:
mean, precision = model.mean_precision.expected_value()
print(precision)
mean = mean.numpy()
cov = ((1 / precision) * torch.eye(2)).numpy()
data = X.numpy()

fig = figure(width=400, height=400,
    x_range=(mean[0] - 5, mean[0] + 5),
    y_range=(mean[1] - 5, mean[1] + 5)
)
fig.circle(data[:, 0], data[:, 1])
plotting.plot_normal(fig, mean, cov, alpha=.5, )

show(fig)

In [None]:
torch.ones(1).min()