In [1]:
import torch
import pyro

In [2]:
def link_function_1(theta, X):
    """First link stage: map the parameters ``theta`` to the intermediate ``K``.

    Currently an identity placeholder: ``theta`` is returned unchanged (the
    very same object, no copy).

    Args:
        theta: parameter value(s) to pass through.
        X: accepted for the call signature but not used by this implementation.

    Returns:
        ``theta`` itself.
    """
    return theta

def link_function_2(K):
    """Second link stage: map the intermediate ``K`` to the likelihood mean ``k``.

    Currently an identity placeholder: ``K`` is returned unchanged (the very
    same object, no copy).

    Args:
        K: intermediate value produced by the first link stage.

    Returns:
        ``K`` itself.
    """
    return K

In [3]:
def model(data):
    """Pyro model: standard-normal prior on a 4-vector ``theta``, unit-variance
    Gaussian likelihood on ``data`` whose mean is ``theta`` passed through the
    two link functions.

    Args:
        data: observed tensor; ``data.shape[0]`` sets the size of the ``'data'``
            plate.  NOTE(review): with the link functions as visible here
            (identity), the likelihood mean has the prior's length (4), so
            ``data`` presumably has to be broadcastable against that -- confirm.

    Returns:
        The sampled ``theta`` tensor.
    """
    n_params = 4

    # priors
    prior = pyro.distributions.Normal(torch.zeros(n_params), torch.ones(n_params))
    theta = pyro.sample('theta', prior)

    # likelihood
    with pyro.plate('data', data.shape[0]):
        # theta -> K -> k, evaluated inside the plate context
        k = link_function_2(link_function_1(theta, data))
        likelihood = pyro.distributions.Normal(k, 1)
        pyro.sample('obs', likelihood, obs=data)

    return theta

In [4]:
# Number of samples requested from each inference method below.
n_samples = 1_000
# Observations fed to the model: a 1-D float tensor with four entries.
data = torch.tensor([0.0, 100.0, 0.0, 100.0])

# Importance Sampling

In [5]:
# Start from a clean parameter store and a fixed seed so this cell is
# reproducible when run on its own.
pyro.clear_param_store()
pyro.set_rng_seed(0)
# Importance sampling with no explicit guide.
# NOTE(review): per the Pyro docs, guide=None means proposals are drawn from the
# model's prior -- confirm for the installed Pyro version.
importance = pyro.infer.Importance(model, guide=None, num_samples=n_samples)

print("doing importance sampling...")
# Run the sampler on `data` and wrap the weighted draws in an empirical marginal.
# NOTE(review): no `sites` argument is given, so this is presumably the marginal
# over the model's return value (theta) -- confirm.
emp_marginal = pyro.infer.EmpiricalMarginal(importance.run(data))

posterior_mean = emp_marginal.mean
posterior_std_dev = emp_marginal.variance.sqrt()

# report results
# NOTE(review): the recorded output of this cell shows a standard deviation of
# order 1e-16 and a mean that disagrees with the NUTS/HMC cells (~[0, 50, 0, 50]).
# That looks like importance-weight degeneracy (one draw carrying almost all the
# weight); consider checking the effective sample size before trusting these.
print(posterior_mean)
print(posterior_std_dev)
print("done.")

doing importance sampling...
tensor([0.4249, 3.0748, 0.1199, 1.4926])
tensor([4.3226e-16, 1.9483e-16, 9.8360e-17, 1.6342e-17])
done.


# MCMC (NUTS)

In [6]:
# Reset Pyro state and seed so this cell is reproducible when run on its own.
pyro.clear_param_store()
pyro.set_rng_seed(0)

# NUTS kernel over `model`, with JIT compilation enabled.
nuts_kernel = pyro.infer.NUTS(model, jit_compile=True)
# warmup_steps is left at its default; the recorded progress bar shows 2000
# total iterations for the 1000 requested samples.
mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=n_samples)
mcmc.run(data)
samples = mcmc.get_samples()
# Per-component mean / standard deviation over the sample axis (axis 0).
samples_mean = samples['theta'].mean(axis=0)
samples_std_dev = samples['theta'].std(axis=0)

print(samples_mean)
print(samples_std_dev)

  result = torch.tensor(0.0, device=self.device)
Sample: 100%|██████████| 2000/2000 [00:05, 371.64it/s, step size=6.95e-01, acc. prob=0.933]

tensor([ 4.4310e-02,  5.0033e+01, -1.5608e-02,  5.0019e+01])
tensor([0.6693, 0.6948, 0.7145, 0.6786])





# HMC

In [7]:
# Reset Pyro state and seed so this cell is reproducible when run on its own.
pyro.clear_param_store()
pyro.set_rng_seed(0)

# HMC kernel with an explicit initial step size and 4 steps per
# trajectory. NOTE(review): the recorded progress bar reports step size 8.04e-01,
# so the step size is apparently adapted away from 0.0855 during warm-up -- confirm.
hmc_kernel = pyro.infer.HMC(model, jit_compile=True, step_size=0.0855, num_steps=4)
# NOTE: `mcmc`, `samples`, `samples_mean` and `samples_std_dev` reuse the names
# assigned in the NUTS cell, so the NUTS results are overwritten here.
mcmc = pyro.infer.MCMC(hmc_kernel, num_samples=n_samples, warmup_steps=100)
mcmc.run(data)
samples = mcmc.get_samples()

# Per-component mean / standard deviation over the sample axis (axis 0).
samples_mean = samples['theta'].mean(axis=0)
samples_std_dev = samples['theta'].std(axis=0)

print(samples_mean)
print(samples_std_dev)

  result = torch.tensor(0.0, device=self.device)
Sample: 100%|██████████| 1100/1100 [00:00, 1395.59it/s, step size=8.04e-01, acc. prob=0.941]

tensor([ 1.7838e-02,  5.0043e+01, -4.4580e-02,  4.9956e+01])
tensor([0.7182, 0.6969, 0.6691, 0.7018])





# SMC

In [25]:
# Sequential Monte Carlo: reset the Pyro parameter store and fix the RNG seed
# before defining the model and guide classes used by the particle filter.
pyro.clear_param_store()
pyro.set_rng_seed(0)


class Model(object):
    """Model half of the model/guide pair handed to ``pyro.infer.SMCFilter``.

    Keeps a step counter ``self.t`` (created in ``init``) that is used to give
    each time step's sample site a unique name.
    """

    def __init__(self):
        # Not read by any method of this class.
        self.theta_dim = 4

    def init(self, state, initial):
        """Reset the step counter and draw the initial ``state['theta']``.

        Args:
            state: mutable mapping; receives the key ``'theta'``.
            initial: mean of the initial Normal; ``len(initial)`` sets the
                length of the unit scale vector.
        """
        print(initial)
        self.t = 0
        init_dist = pyro.distributions.Normal(initial, torch.ones(len(initial)))
        state['theta'] = pyro.sample('theta_init', init_dist)

    def step(self, state, y=None):
        """Advance one time step and register the site ``theta_<t>`` with ``obs=y``.

        The current ``state['theta']`` is passed through both link functions to
        obtain the mean of a unit-variance Normal.

        NOTE(review): ``state['theta']`` is overwritten with the value returned
        by the observed sample call, which is presumably ``y`` itself whenever
        ``y`` is given -- confirm this is the intended state update.

        Returns:
            The new ``state['theta']``.
        """
        self.t += 1
        k = link_function_2(link_function_1(state['theta'], y))
        site_name = "theta_{}".format(self.t)
        state['theta'] = pyro.sample(site_name, pyro.distributions.Normal(k, 1), obs=y)
        return state['theta']

class Guide(object):
    """Guide half of the model/guide pair handed to ``pyro.infer.SMCFilter``.

    Uses the same site names as ``Model`` (``'theta_init'`` and ``theta_<t>``),
    but never writes to ``state`` and never passes ``obs``.
    """

    def __init__(self, model):
        # The model reference is stored but not read by the methods below.
        self.model = model
        # Not read by any method of this class.
        self.theta_dim = 4

    def init(self, state, initial):
        """Reset the step counter and sample ``'theta_init'`` around ``initial``.

        The drawn value is not stored in ``state``.
        """
        print(initial)
        self.t = 0
        init_dist = pyro.distributions.Normal(initial, torch.ones(len(initial)))
        pyro.sample('theta_init', init_dist)

    def step(self, state, y=None):
        """Advance one time step and sample the site ``theta_<t>``.

        The mean is ``state['theta']`` passed through both link functions; the
        scale is 1.  ``y`` is only forwarded to ``link_function_1``.
        """
        self.t += 1
        k = link_function_2(link_function_1(state['theta'], y))
        site_name = "theta_{}".format(self.t)
        pyro.sample(site_name, pyro.distributions.Normal(k, 1))

In [35]:
# Re-seed so the particle-filter run is reproducible when this cell is run alone.
pyro.set_rng_seed(0)

# NOTE(review): `model` (and `data` below) rebind names that earlier cells use
# for the `model(data)` function and the 1-D observation tensor. Re-running the
# importance-sampling / NUTS / HMC cells after this one would pick up these
# objects instead -- consider distinct names.
model = Model()
guide = Guide(model)

smc = pyro.infer.SMCFilter(model, guide, num_particles=100, max_plate_nesting=0)

# Four rows, each holding one value repeated 100 times (shape (4, 100)).
# NOTE(review): the width 100 equals num_particles above -- presumably so the
# values line up with the particle dimension; confirm.
data = torch.tensor([[0.]*100, [30.]*100, [0.]*100, [10.]*100])

smc.init(initial=torch.tensor([0.]*100))

# Row 0 of `data` is skipped: filtering steps start from the second row.
for y in data[1:, :]:
    print(y)
    smc.step(y)

# Summarise the empirical distribution of state['theta'] after the last step.
print("At final time step:")
theta = smc.get_empirical()["theta"]
print("mean: {}".format(theta.mean))
print("std: {}".format(theta.variance ** 0.5))


tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0.])
tensor([30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30., 30.,
        30., 30., 30., 30., 30., 30., 30