In [5]:
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch
import pandas as pd
import plotly.express as px

from pyro.infer import MCMC, NUTS, HMC
from pyro.infer.mcmc.util import initialize_model, summary

In [15]:
def generate(num_predictors, num_samples):
    """Generates samples from a simple poisson lognormal model (without noise)"""
    betas = pyro.sample('betas', dist.Normal(torch.zeros(num_predictors+1), 0.5 * torch.ones(num_predictors+1)))
    with pyro.plate('gen_data_plate', num_samples):
        noise = pyro.sample('noise', dist.Normal(0, 0.025))
        pred = pyro.sample('gen_predictors', dist.Uniform(torch.zeros((num_predictors, num_samples)), torch.ones((num_predictors, num_samples)) * 5))
        X = torch.cat((torch.ones(1, num_samples), pred), 0)
        thetas = betas @ X + noise
        accidents = pyro.sample('gen_accidents', dist.Poisson(torch.exp(thetas)))
    return betas, X, accidents

def prelim_model(num_predictors, num_observations, predictors, data):
    """Specifies the prior for a simple poisson lognormal model (without noise)"""
    sigma_squared = 1.0/pyro.sample('tau_squared', dist.Chi2(10))
    betas = pyro.sample('betas', dist.Normal(torch.zeros(num_predictors+1), 10 * torch.ones(num_predictors+1)))
    with pyro.plate('observation_plate', num_observations):
        noise = pyro.sample('noise', dist.Normal(0, 1))
        thetas = betas @ predictors + noise
        accidents = pyro.sample('accidents', dist.Poisson(torch.exp(thetas)), obs=data)
    return accidents
    

In [16]:
num_samples = 1000
num_predictors = 10

In [17]:
betas, X, obs = generate(num_predictors, num_samples)
nuts_kernel = NUTS(prelim_model, jit_compile=True)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=1000, num_chains=1, mp_context = "spawn")
mcmc.run(num_predictors, num_samples, X, obs)


  result = torch.tensor(0., device=self.device)
Sample: 100%|██████████| 2000/2000 [06:14,  5.34it/s, step size=6.49e-02, acc. prob=0.922]


In [18]:
print(betas)
mcmc.summary(prob=0.95)
print(obs)

tensor([ 0.6999, -0.1162,  0.3481, -0.4646,  0.2658, -0.0985, -0.0074, -0.7256,
         0.4661,  0.0450, -0.1655])

                   mean       std    median      2.5%     97.5%     n_eff     r_hat
  tau_squared      9.99      4.32      9.51      2.30     18.19   1599.79      1.00
     betas[0]      0.01      0.27      0.00     -0.56      0.48    306.27      1.00
     betas[1]     -0.09      0.04     -0.09     -0.15     -0.02    262.11      1.00
     betas[2]      0.40      0.04      0.40      0.32      0.46    328.25      1.00
     betas[3]     -0.50      0.04     -0.50     -0.57     -0.43    311.73      1.00
     betas[4]      0.33      0.03      0.33      0.26      0.39    411.64      1.00
     betas[5]     -0.11      0.04     -0.11     -0.18     -0.04    405.07      1.00
     betas[6]      0.06      0.03      0.06     -0.01      0.12    346.84      1.02
     betas[7]     -0.80      0.04     -0.80     -0.87     -0.72    305.95      1.00
     betas[8]      0.53      0.03      0.53