#### This file implements the classic coupled sticks instantiation of the HDP-HMM (Teh's construction)

In [1]:
import pyro
import torch
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import MCMC, NUTS

import sys

In [2]:
sys.path.append('code')

from generative_process import HMM

In [3]:
def _stick_breaking(fractions):
    """Convert stick-breaking fractions into weights along the last dimension.

    weights[..., k] = fractions[..., k] * prod_{l < k} (1 - fractions[..., l])

    Only differentiable torch ops are used, so gradients flow back to
    `fractions`.
    """
    leftover = torch.cumprod(1 - fractions, dim=-1)
    # Shift right by one so entry k holds the stick length remaining *before*
    # the k-th break (the empty product for k = 0 is 1).
    leftover = torch.cat(
        [torch.ones_like(leftover[..., :1]), leftover[..., :-1]], dim=-1
    )
    return fractions * leftover


def model(T, ncat=10, gamma_init=1, alpha_init=1, alpha_0_init=1, state_prior=None):
    """Truncated HDP-HMM with coupled stick-breaking weights.

    Sample sites: "phi", "beta_prime", "pi_prime_<j>" for j in range(ncat),
    and "state_<t>" / "observation_<t>" for every time step.

    Args:
        T: number of time steps. Step 0 is always sampled, so the returned
            list has max(T, 1) entries.
        ncat: truncation level; used as the size of every parameter, i.e.
            both the number of states and the number of observation symbols.
        gamma_init: initial value for every entry of the "gamma2" parameter.
        alpha_init: initial value for every entry of the "alpha" parameter.
        alpha_0_init: initial value for "alpha_0_1" and "alpha_0_2".
        state_prior: optional initial value for the "prior_states" parameter
            (probabilities of the first state); uniform when None.

    Returns:
        List of the sampled observation values, one per time step.
    """
    positive = dist.constraints.positive

    # gamma1 only for learning, always set to 1 in the HDP literature
    gamma1 = pyro.param("gamma1", torch.ones(ncat), constraint=positive)

    # gamma2 corresponds to gamma in the HDP literature
    gamma2 = pyro.param("gamma2", torch.ones(ncat) * gamma_init, constraint=positive)

    # alpha are the parameters from which the atoms phi are drawn
    alpha = pyro.param("alpha", torch.ones((ncat, ncat)) * alpha_init, constraint=positive)

    # alpha_0 parameterizes pi
    alpha_0_1 = pyro.param("alpha_0_1", torch.ones(ncat) * alpha_0_init, constraint=positive)
    alpha_0_2 = pyro.param("alpha_0_2", torch.ones(ncat) * alpha_0_init, constraint=positive)

    # prior over the first state: the supplied value, or uniform
    if state_prior is not None:
        # as_tensor accepts both lists and existing tensors
        prior_init = torch.as_tensor(state_prior)
    else:
        prior_init = torch.ones(ncat) / ncat
    prior_states = pyro.param(
        "prior_states", prior_init, constraint=dist.constraints.interval(0, 1)
    )

    # draw atoms phi
    phi = pyro.sample("phi", dist.Dirichlet(alpha))

    # draw beta_primes and turn them into the shared weights beta
    beta_prime = pyro.sample("beta_prime", dist.Beta(gamma1, gamma2))
    beta = _stick_breaking(beta_prime)

    # Beta parameters of the per-state sticks pi_prime
    exponent1 = alpha_0_1 * beta
    # cumsum(beta) - beta is the sum of beta[l] over l < k.
    # NOTE(review): the original author was unsure about the summation range;
    # the HDP construction is usually written with l <= k -- confirm.
    exponent2 = alpha_0_2 * (1 - (torch.cumsum(beta, dim=-1) - beta))

    # one row of transition weights per source state; Categorical(probs=...)
    # normalizes them, so the truncated rows need not sum to exactly 1
    pi = torch.stack([
        _stick_breaking(pyro.sample("pi_prime_" + str(j), dist.Beta(exponent1, exponent2)))
        for j in range(ncat)
    ])

    # first state and first observation
    state = pyro.sample("state_0", dist.Categorical(probs=prior_states))
    # NOTE(review): phi[:, state] selects a *column* of phi, while
    # Dirichlet(alpha) normalizes each *row*; phi[state] may be intended --
    # confirm against the data-generating HMM.
    observations = [
        pyro.sample("observation_0", dist.Categorical(probs=phi[:, state]))
    ]

    for t in range(1, T):
        # transition, then emit
        state = pyro.sample("state_" + str(t), dist.Categorical(probs=pi[state]))
        observations.append(
            pyro.sample("observation_" + str(t), dist.Categorical(probs=phi[:, state]))
        )

    # return only after the whole sequence has been generated
    return observations


In [4]:
# make data!

# both matrices below are ns x no = 5 x 5
ns = 5
no = ns

# observations are either the state itself or its neighbors with lower probabilities
# (each row sums to 1)
phi = torch.tensor([[0.70, 0.15, 0.00, 0.00, 0.15],
                    [0.15, 0.70, 0.15, 0.00, 0.00],
                    [0.00, 0.15, 0.70, 0.15, 0.00],
                    [0.00, 0.00, 0.15, 0.70, 0.15],
                    [0.15, 0.00, 0.00, 0.15, 0.70]])

# each row puts 0.40 on staying and 0.60 on moving to the next index (cyclic)
pi = torch.tensor([[0.40, 0.60, 0.00, 0.00, 0.00],
                   [0.00, 0.40, 0.60, 0.00, 0.00],
                   [0.00, 0.00, 0.40, 0.60, 0.00],
                   [0.00, 0.00, 0.00, 0.40, 0.60],
                   [0.60, 0.00, 0.00, 0.00, 0.40]])

# all mass on the first state; the entries must sum to 1 to form a valid
# probability vector
prior_states = torch.tensor([1.00, 0.00, 0.00, 0.00, 0.00])

In [5]:
# length of the time series to simulate
T = 10

# build the simulator, hand it the matrices and prior defined for the data,
# then run it for T steps (results are read back from `flow` afterwards)
flow = HMM()
flow.set_parameters(pi, phi, prior_states)
flow.simulate_timeseries(T)

In [6]:
# Convert the simulated observations to a tensor once, rather than rebuilding
# it on every loop iteration, then key each time step by its sample-site name.
observed_series = torch.tensor(flow.observations)
data_dict = {"observation_" + str(t): observed_series[t] for t in range(T)}

def conditioned_model(T):
    """Run `model` for T steps with the sites listed in `data_dict` conditioned on their stored values."""
    # No debug printing here: this function is what the NUTS kernel wraps, so
    # it is invoked repeatedly during sampling.
    return poutine.condition(model, data=data_dict)(T)

In [7]:
num_samples = 100
warmup_steps = 20
num_chains = 2

# see https://pyro.ai/examples/mcmc.html
nuts_kernel = NUTS(conditioned_model)
mcmc = MCMC(
    nuts_kernel,
    num_samples=num_samples,
    warmup_steps=warmup_steps,
    num_chains=num_chains,
)
# MCMC.run forwards its arguments to the kernel's model. conditioned_model
# takes only T (the model and data are already bound inside it), so passing
# anything else raises a TypeError.
mcmc.run(T)
mcmc.summary(prob=0.5)

Warmup [1]:   0%|          | 0/120 [00:00, ?it/s]

Warmup [2]:   0%|          | 0/120 [00:00, ?it/s]

KeyboardInterrupt: 