In [None]:
# Add the path of the beer source code to the PYTHONPATH.
from collections import defaultdict
import random
import sys
sys.path.insert(0, '../')
from beer import __init__

import copy

import beer
import numpy as np
import torch

# For plotting.
from bokeh.io import show, output_notebook, export_png
from bokeh.plotting import figure, gridplot
from bokeh.models import LinearAxis, Range1d
output_notebook()

# Convenience functions for plotting.
import plotting

### Set GPU

In [None]:
# Make GPU 3 the current CUDA device. The index is hard-coded: change
# "cuda:3" to a device that exists on the machine running this notebook.
torch.cuda.set_device("cuda:3")
# Echo the index of the current device as the cell output.
torch.cuda.current_device()
# Alternative left by the author (not used in this notebook): build an
# explicit device object and move tensors/modules onto it.
# device = torch.device("cuda:5")
# xxx.to(device)

### Synthetic Data

In [None]:
import synthetic_data

# Seed every random number generator imported by this notebook so that
# "Restart & Run All" reproduces the same data set and the same training run.
# NOTE(review): which generator `synthetic_data` draws from is not visible
# here -- confirm that one of these seeds covers it.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# `data` is used below as a 2-D numpy array (one observation per row) and
# `states` as a numpy array that is compared against the decoded state path.
data, states = synthetic_data.generate_sequential_data()

### Construct Graph
This graph describes the allowed transitions between the hidden states.

In [None]:
graph = beer.graph.Graph()

# The entry (s0) and exit (s4) states are created first and without a pdf:
# they are non-emitting.
s0 = graph.add_state()
s4 = graph.add_state()
graph.start_state, graph.end_state = s0, s4

# Three emitting states, bound to pdf 0, 1 and 2 respectively.
s1, s2, s3 = [graph.add_state(pdf_id=i) for i in range(3)]

# Arcs are added in this exact order, each with the default weight (1).
arcs = [
    (s0, s1),              # entry
    (s1, s1), (s1, s2),    # self-loop / forward
    (s2, s2), (s2, s3),
    (s3, s3), (s3, s1),    # self-loop / back to the first emitting state
    (s1, s4), (s2, s4), (s3, s4),  # every emitting state may exit
]
for src, dest in arcs:
    graph.add_arc(src, dest)

# NOTE(review): presumably rescales the arc weights into probabilities --
# confirm against beer.graph.Graph.normalize.
graph.normalize()
graph

In [None]:
# Compile the graph; the compiled object is what `beer.HMM.create` receives
# in the pre-training section below.
cgraph = graph.compile()
# Show the compiled graph's final log-probabilities as the cell output.
cgraph.final_log_probs

### Pretrain HMM 

In [None]:
# Global mean and covariance matrix of the data: every Normal component of
# the set is created from these same statistics.
data_mean = torch.from_numpy(np.mean(data, axis=0)).float()
data_var = torch.from_numpy(np.cov(data.T)).float()

# In the visible notebook only the length of `transitions` is used (as the
# number of components); `init_states` / `final_states` are not read again.
trans_mat = np.array([
    [.5, .5, 0],
    [0, .5, .5],
    [.5, 0, .5],
])
init_states = torch.LongTensor([0])
final_states = torch.LongTensor([2])
transitions = torch.from_numpy(trans_mat).float()

# HMM (full cov).
modelset = beer.NormalSet.create(
    data_mean,
    data_var,
    size=len(transitions),
    prior_strength=1.,
    noise_std=0,
    cov_type='full',
)
hmm_full = beer.HMM.create(cgraph, modelset)
# Double precision, on the current CUDA device.
hmm_full = hmm_full.double().cuda()

hmm_full

In [None]:
# Inspect the Bayesian parameters of the Normal set (cell output only).
modelset.bayesian_parameters

In [None]:
# Inspect the conjugate Bayesian parameters of the Normal set (cell output only).
modelset.conjugate_bayesian_parameters

In [None]:
# Pre-train the HMM on the raw data with the conjugate VB optimizer.
epochs = 30
lrate = 1.
X = torch.from_numpy(data).cuda()
nframes = len(X)

optim = beer.VBConjugateOptimizer(hmm_full.mean_field_factorization(), lrate)

# One ELBO value per epoch, divided by the number of data points.
elbos = []

for epoch in range(epochs):
    optim.init_step()
    elbo = beer.evidence_lower_bound(
        hmm_full, X, datasize=nframes, viterbi=False
    )
    elbo.backward()
    elbos.append(float(elbo) / nframes)
    optim.step()

In [None]:
# Learning curve of the HMM pre-training: one ELBO value (divided by the
# number of data points) per epoch. Title and axis labels are set so the
# figure can be read on its own.
fig = figure(title='HMM pre-training', x_axis_label='epoch',
             y_axis_label='ELBO per data point')
fig.line(list(range(len(elbos))), elbos)
show(fig)

In [None]:
# Scatter plot of the first two dimensions of the raw observations.
fig = figure(width=250, height=250)
fig.circle(x=data[:, 0], y=data[:, 1], alpha=.1)
# Optional overlay of the pre-trained HMM (disabled by the author):
# plotting.plot_hmm(fig, hmm_full, alpha=.3, colors=['blue', 'red', 'green'])
show(fig)

### Train VAE

In [None]:
# Encoder and decoder are residual feed-forward networks created with the
# same hyper-parameters (dim_in=2, 2 blocks of width 2).
# NOTE(review): presumably 2 is both the data and the latent dimension --
# confirm before changing either.
encoder = beer.nnet.ResidualFeedForwardNet(dim_in=2, nblocks=2, block_width=2)
decoder = beer.nnet.ResidualFeedForwardNet(dim_in=2, nblocks=2, block_width=2)
# The pre-trained HMM is handed to the VAE together with the two networks;
# the whole model is cast to double precision and moved to the CUDA device.
vae = beer.VAE(hmm_full, encoder, decoder).double().cuda()

In [None]:
# Number of VAE training epochs; lower this value for a quick test run.
epochs = 1000
# The conjugate optimizer is created with a learning rate of 0 (below); the
# training loop switches it to `prior_lrate` once `e` reaches this epoch.
update_prior_after_epoch = 50
prior_lrate = 1.
cjg_optim = beer.VBConjugateOptimizer(vae.mean_field_factorization(), lrate=0)
# Adam over everything returned by vae.parameters().
std_optim = torch.optim.Adam(vae.parameters(), lr=1e-3)
# Wrapper that drives both optimizers through a single init_step()/step().
optim = beer.VBOptimizer(cjg_optim, std_optim)

In [None]:
# Train the VAE. Every `report_every` epochs the ELBO (divided by the number
# of data points) is recorded and a diagnostic figure is shown: posterior
# means of the latent variables with the current prior HMM overlaid.
report_every = 50

elbos = []
for e in range(epochs):
    optim.init_step()
    elbo = beer.evidence_lower_bound(vae, X, nsamples=5, kl_weight=0.8)
    elbo.backward()
    optim.step()

    # Warm-up is over: the conjugate optimizer (created with lrate=0) uses
    # `prior_lrate` from the next step on.
    if e >= update_prior_after_epoch:
        cjg_optim.lrate = prior_lrate

    if e % report_every == 0:
        elbos.append(float(elbo) / len(X))

        # Diagnostics only: no gradient is needed for this forward pass, so
        # do not let autograd record it.
        with torch.no_grad():
            post = vae.posteriors(X)
            m = post.params.mean.data.clone().cpu().numpy()

        # The epoch number in the title tells the successive figures apart.
        fig = figure(title='epoch {}'.format(e), width=250, height=250)
        fig.circle(m[:, 0], m[:, 1], alpha=.1)
        plotting.plot_hmm(fig, vae.prior, alpha=.3, colors=['blue', 'red', 'green'])
        show(fig)

In [None]:
# Learning curve of the VAE training. The training loop records the ELBO
# every 50 epochs starting at epoch 0, so the i-th record belongs to epoch
# 50 * i; plot against the epoch rather than against the record index.
epochs_logged = [50 * i for i in range(len(elbos))]
fig = figure(title='VAE training', x_axis_label='epoch',
             y_axis_label='ELBO per data point')
fig.line(epochs_logged, elbos)
show(fig)
# Last recorded value, shown as the cell output.
elbos[-1]

In [None]:
# One sample per data point from the encoder posterior, with the prior HMM
# drawn on top.
post = vae.posteriors(X)
samples = post.sample(1)
m = samples.data.clone().cpu().numpy().reshape(-1, 2)
# m = post.params.mean.data.clone().cpu().numpy()

fig = figure(title='', width=250, height=250)
fig.circle(x=m[:, 0], y=m[:, 1], alpha=.1)
# NOTE(review): if `vae.prior` is a torch.nn.Module, `.cpu()` moves it in
# place, so the prior stays on the CPU for all later cells (the decoding cell
# feeds it CPU tensors) -- confirm this is intended.
plotting.plot_hmm(fig, vae.prior.cpu(), alpha=.3, colors=['blue', 'red', 'green'])
show(fig)

In [None]:
# Means of the encoder posterior for every data point.
post = vae.posteriors(X)
m = post.params.mean.data.clone().cpu().numpy()

fig = figure(title='', width=250, height=250)
fig.circle(x=m[:, 0], y=m[:, 1], alpha=.1)
show(fig)

In [None]:
# One sample per data point drawn from the encoder posterior.
post = vae.posteriors(X)
samples = post.sample(1)
m = samples.data.clone().cpu().numpy().reshape(-1, 2)

fig = figure(title='', width=250, height=250)
fig.circle(x=m[:, 0], y=m[:, 1], alpha=.1)
show(fig)

In [None]:
# Sample the latent variables, pass them through `vae.pdfs` and plot the
# means of the resulting distributions.
# NOTE(review): `vae.pdfs` is presumably the decoder side of the VAE -- confirm.
post = vae.posteriors(X)
h = post.sample(1).reshape(-1, 2)
pdf = vae.pdfs(h)
m = pdf.params.mean.data.clone().cpu().numpy()

fig = figure(title='', width=250, height=250)
fig.circle(x=m[:, 0], y=m[:, 1], alpha=.1)
show(fig)

In [None]:
# Same pipeline as the previous figure, but plot one sample drawn from the
# distributions returned by `vae.pdfs` instead of their means.
post = vae.posteriors(X)
h = post.sample(1).reshape(-1, 2)
pdf = vae.pdfs(h)
generated = pdf.sample(1)
m = generated.data.clone().cpu().numpy().reshape(-1, 2)

fig = figure(title='', width=250, height=250)
fig.circle(x=m[:, 0], y=m[:, 1], alpha=.1)
show(fig)

In [None]:
# Decode a state sequence from the posterior means. `post` is whatever the
# last executed plotting cell left behind, and the means are copied to the
# CPU before being passed to the prior.
latent_means = post.params.mean.clone().cpu()
state_seq = vae.prior.decode(latent_means)

# Element-wise comparison with the reference labels (cell output).
# NOTE(review): the modulo 3 presumably folds the decoded state index onto
# the three labels used by `states` -- confirm.
true_states = torch.from_numpy(states)
torch.eq(state_seq % 3, true_states)

### Save Model

In [None]:
# Save only the VAE's state dict (not the model object) to the current
# working directory.
torch.save(vae.state_dict(), 'hmm-vae-cuda.pkl')

### Load Model

In [None]:
# Restore the weights into the existing `vae` object. `map_location='cpu'`
# deserialises every tensor on the CPU first, so the checkpoint can be read
# on a machine that lacks the GPU it was saved from; load_state_dict then
# copies the values into the model's own tensors on their current device.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
vae.load_state_dict(torch.load('hmm-vae-cuda.pkl', map_location='cpu'))