In [38]:
import pyro
import pyro.distributions as dist
import torch
import pandas as pd
import plotly.express as px

from pyro.infer import MCMC, NUTS, HMC
from pyro.infer.mcmc.util import initialize_model, summary

In [90]:
def generate(num_predictors, num_samples, beta_scale=0.25, predictor_low=0.0, predictor_high=5.0):
    """Simulate data from a Poisson log-linear regression (no extra-Poisson noise).

    Draws an intercept plus ``num_predictors`` slopes from ``Normal(0, beta_scale)``,
    draws every predictor value i.i.d. from ``Uniform(predictor_low, predictor_high)``,
    and samples counts as ``accidents ~ Poisson(exp(betas @ X))``.

    Args:
        num_predictors: number of covariates, excluding the intercept.
        num_samples: number of simulated observations.
        beta_scale: prior standard deviation of every coefficient (default 0.25).
        predictor_low: lower bound of the uniform predictor distribution (default 0).
        predictor_high: upper bound of the uniform predictor distribution (default 5).

    Returns:
        betas: tensor of shape (num_predictors + 1,), intercept first.
        X: design matrix of shape (num_predictors + 1, num_samples); row 0 is all ones.
        accidents: tensor of shape (num_samples,) holding the simulated counts.
    """
    # TODO: add observation-level noise (e.g. a lognormal random effect on the rate,
    # which would make this the intended Poisson-lognormal model).
    num_coefs = num_predictors + 1
    betas = pyro.sample('betas', dist.Normal(torch.zeros(num_coefs), beta_scale * torch.ones(num_coefs)))

    with pyro.plate('gen_data_plate', num_samples):
        # The inner plate is nested inside the data plate, so the draw has shape
        # (num_predictors, num_samples) -- one row per covariate.
        with pyro.plate('gen_predictors_plate', num_predictors):
            pred = pyro.sample('gen_predictors', dist.Uniform(predictor_low, predictor_high))
        # Prepend a row of ones so betas[0] acts as the intercept.
        X = torch.cat((torch.ones(1, num_samples), pred), 0)
        thetas = betas @ X  # log-rate per observation, shape (num_samples,)
        accidents = pyro.sample('gen_accidents', dist.Poisson(torch.exp(thetas)))
    return betas, X, accidents

def prelim_model(num_predictors, num_observations, predictors, data, prior_scale=10.0):
    """Pyro model: Poisson log-linear regression with a wide Normal prior (no extra noise).

    Args:
        num_predictors: number of covariates, excluding the intercept.
        num_observations: size of the observation plate; must equal the number of
            columns of ``predictors`` and the length of ``data``.
        predictors: design matrix, assumed to have shape
            (num_predictors + 1, num_observations) with an intercept row of ones,
            as produced by ``generate``.
        data: observed counts conditioned on at the ``accidents`` site, or ``None``
            to sample counts instead of conditioning.
        prior_scale: standard deviation of the Normal(0, .) prior placed on every
            coefficient (default 10, i.e. weakly informative).

    Returns:
        The value at the ``accidents`` sample site (equal to ``data`` when observed).
    """
    # TODO: add observation-level noise (e.g. a lognormal random effect on the rate).
    num_coefs = num_predictors + 1
    betas = pyro.sample('betas', dist.Normal(torch.zeros(num_coefs), prior_scale * torch.ones(num_coefs)))
    thetas = betas @ predictors  # log-rate per observation
    with pyro.plate('observation_plate', num_observations):
        accidents = pyro.sample('accidents', dist.Poisson(torch.exp(thetas)), obs=data)
    return accidents
    

In [None]:
num_samples = 1000        # number of simulated observations
num_predictors = 10       # covariates, excluding the intercept
# The sampling cell passes `num_observations` to the model, but no cell defined it.
# Define it here, equal to the simulated data size, so the notebook runs on a fresh kernel.
num_observations = num_samples

rng_seed = 0
pyro.set_rng_seed(rng_seed)  # make the simulated data and the MCMC draws reproducible

In [91]:
betas, X, obs = generate(num_predictors, num_samples)
# Size the model's observation plate from the simulated data so the two always agree.
num_observations = obs.shape[0]

num_posterior_draws = 1000
num_warmup = 1000
nuts_kernel = NUTS(prelim_model, jit_compile=True)
mcmc = MCMC(nuts_kernel, num_samples=num_posterior_draws, warmup_steps=num_warmup,
            num_chains=1, mp_context="spawn")
# MCMC.run returns None; the posterior draws are read back with get_samples().
mcmc.run(num_predictors, num_observations, X, obs)
posterior = mcmc.get_samples()


  result = torch.tensor(0., device=self.device)
Sample: 100%|██████████| 2000/2000 [02:07, 15.64it/s, step size=1.22e-01, acc. prob=0.935]


In [92]:
mcmc.summary(prob=0.95)

# Parameter-recovery check: compare the generating coefficients with the posterior,
# instead of printing the raw tensors separately.
beta_draws = mcmc.get_samples()["betas"].detach()
lower, upper = torch.quantile(beta_draws, torch.tensor([0.025, 0.975], dtype=beta_draws.dtype), dim=0)
recovery = pd.DataFrame({
    "true": betas.detach().numpy(),
    "posterior_mean": beta_draws.mean(dim=0).numpy(),
    "lower_2.5%": lower.numpy(),
    "upper_97.5%": upper.numpy(),
})
recovery["covered"] = recovery["true"].between(recovery["lower_2.5%"], recovery["upper_97.5%"])

# Summarise the observed counts rather than dumping every value.
print(f"obs: n={obs.numel()}, mean={obs.float().mean().item():.2f}, max={obs.max().item():.0f}")
recovery


                mean       std    median      2.5%     97.5%     n_eff     r_hat
  betas[0]      0.31      0.13      0.31      0.07      0.58    612.22      1.00
  betas[1]     -0.34      0.02     -0.34     -0.38     -0.31   1283.14      1.00
  betas[2]     -0.28      0.02     -0.28     -0.31     -0.25   1156.22      1.00
  betas[3]     -0.12      0.01     -0.12     -0.15     -0.09   1301.30      1.00
  betas[4]      0.38      0.02      0.38      0.35      0.41   1212.24      1.00
  betas[5]     -0.03      0.01     -0.03     -0.06     -0.00   1458.41      1.00
  betas[6]     -0.12      0.01     -0.12     -0.14     -0.08   1318.94      1.00
  betas[7]     -0.23      0.02     -0.23     -0.25     -0.20   1564.98      1.00
  betas[8]      0.27      0.02      0.27      0.24      0.30   1159.60      1.00
  betas[9]     -0.13      0.02     -0.13     -0.16     -0.10   1232.97      1.00
 betas[10]      0.49      0.02      0.49      0.46      0.53   1632.94      1.00

Number of divergences: 0
t