# 6 Interactive posterior predictive checks

In [5]:
import os
from ipywidgets import interact, FloatSlider

import torch
import pyro
import pyro.distributions as D
import pyro.poutine as poutine
from pyro.infer.resampler import Resampler
import matplotlib.pyplot as plt

In [6]:
def model(T: int=1000, data=None):
    """Random-walk model: Student-t process noise plus Gaussian measurement noise.

    Args:
        T: number of time steps (size of both the ``dt`` and ``t`` plates).
        data: optional observed series passed as ``obs`` to the likelihood;
            when ``None`` the observations are sampled instead.

    Returns:
        The value of the ``obs`` sample site.
    """
    # Global hyperparameters, each with a LogNormal(0, 1) prior.
    nu, proc_scale, meas_scale = (
        pyro.sample(site, D.LogNormal(0, 1)) for site in ("df", "pscale", "mscale")
    )

    # Latent trend: cumulative sum of Student-t increments along the last dim.
    with pyro.plate('dt', T):
        increments = pyro.sample('process_noise', D.StudentT(nu, 0, proc_scale))
    trend = pyro.deterministic('trend', increments.cumsum(-1))

    # Noisy observations of the trend.
    with pyro.plate('t', T):
        observed = pyro.sample('obs', D.Normal(trend, meas_scale), obs=data)
    return observed

In [7]:
def plot_trajectory(df=1.0, pscale=1.0, mscale=1.0):
    """Sample and plot one ``obs`` trajectory from ``model`` with fixed hyperparameters.

    The RNG is reseeded on every call, so repeated calls with the same
    arguments reproduce the same trajectory.

    Args:
        df, pscale, mscale: values pinned onto the sample sites of the same
            names via ``poutine.condition``.
    """
    pyro.set_rng_seed(12345)
    fixed_sites = {
        name: torch.as_tensor(value)
        for name, value in (('df', df), ('pscale', pscale), ('mscale', mscale))
    }
    conditioned_model = poutine.condition(model, fixed_sites)
    trajectory = conditioned_model()

    fig, ax = plt.subplots(figsize=(8, 4))
    fig.patch.set_color('white')
    ax.plot(trajectory)
    ax.set_xlabel('time')
    ax.set_ylabel('obs')

In [8]:
# Sliders for the three fixed hyperparameters of plot_trajectory.
# The trailing ';' suppresses the repr of the function returned by interact.
interact(
    plot_trajectory,
    df=FloatSlider(value=1.0, min=0.01, max=10.0),
    pscale=FloatSlider(value=0.1, min=0.01, max=1.0),
    mscale=FloatSlider(value=1.0, min=0.01, max=10.0),
);

interactive(children=(FloatSlider(value=1.0, description='df', max=10.0, min=0.01), FloatSlider(value=0.1, des…

<function __main__.plot_trajectory(df=1.0, pscale=1.0, mscale=1.0)>

In [9]:
def model2(T=1000, data=None, df0=0, df1=1, p0=0, p1=1, m0=0, m1=1):
    """Random-walk model like ``model``, but with tunable LogNormal hyperpriors.

    Args:
        T: number of time steps (size of both the ``dt`` and ``t`` plates).
        data: optional observed series for the ``obs`` site; ``None`` samples it.
        df0, df1: loc and scale of the LogNormal prior on the Student-t ``df``.
        p0, p1: loc and scale of the LogNormal prior on ``pscale`` (process noise).
        m0, m1: loc and scale of the LogNormal prior on ``mscale`` (measurement noise).

    Returns:
        The value of the ``obs`` sample site.
    """
    df = pyro.sample('df', D.LogNormal(df0, df1))
    # The scale of the pscale prior is p1 (not df1); otherwise the p1 slider
    # would have no effect on the samples.
    pscale = pyro.sample('pscale', D.LogNormal(p0, p1))
    mscale = pyro.sample('mscale', D.LogNormal(m0, m1))

    # Latent trend: cumulative sum of Student-t increments along the last dim.
    with pyro.plate('dt', T):
        process_noise = pyro.sample('process_noise', D.StudentT(df, 0, pscale))
    trend = pyro.deterministic('trend', process_noise.cumsum(-1))
    # Gaussian measurement noise around the trend.
    with pyro.plate('t', T):
        return pyro.sample('obs', D.Normal(trend, mscale), obs=data)

In [11]:
def plot_trajectories(**kwargs):
    """Sample and plot 20 ``obs`` trajectories from ``model2``.

    All keyword arguments are forwarded unchanged to ``model2`` (for example
    the hyperprior parameters ``df0``, ``df1``, ``p0``, ``p1``, ``m0``, ``m1``).
    The RNG is reseeded on every call so that identical arguments reproduce
    identical plots.
    """
    pyro.set_rng_seed(12345)
    # Batch 20 independent trajectories along dim -2.
    with pyro.plate('trajectories', 20, dim=-2):
        samples = model2(**kwargs)

    fig, ax = plt.subplots(figsize=(8, 4))
    fig.patch.set_color('white')
    # samples is expected to be (20, T); transpose so each trajectory is one line.
    ax.plot(samples.T)
    ax.set_xlabel('time')
    ax.set_ylabel('obs')

In [12]:
# Sliders for the LogNormal hyperprior locs (…0) and scales (…1) of model2.
# The trailing ';' suppresses the repr of the function returned by interact.
interact(
    plot_trajectories,
    df0=FloatSlider(value=0.0, min=-5, max=5),
    df1=FloatSlider(value=1.0, min=0.1, max=10),
    p0=FloatSlider(value=0.0, min=-5, max=5),
    p1=FloatSlider(value=1.0, min=0.1, max=10),
    m0=FloatSlider(value=0.0, min=-5, max=5),
    m1=FloatSlider(value=1.0, min=0.1, max=10),
);

interactive(children=(FloatSlider(value=0.0, description='df0', max=5.0, min=-5.0), FloatSlider(value=1.0, des…

<function __main__.plot_trajectories(**kwargs)>

In [13]:
def make_partial_model(df0, df1, p0, p1, m0, m1):
    """Build a zero-argument Pyro model over only the three global hyperparameters.

    The returned model samples ``df``, ``pscale`` and ``mscale`` from LogNormal
    priors. The site names must match the sites of the same names in ``model``
    so that a ``Resampler`` can condition the simulator on them and reweight
    cached samples by their prior density.

    Args:
        df0, df1: loc and scale of the LogNormal prior on ``df``.
        p0, p1: loc and scale of the LogNormal prior on ``pscale`` (process noise).
        m0, m1: loc and scale of the LogNormal prior on ``mscale`` (measurement noise).

    Returns:
        A callable taking no arguments that performs the three sample statements.
    """
    def partial_model():
        # Sample parameters from the prior.
        pyro.sample("df", D.LogNormal(df0, df1))
        pyro.sample("pscale", D.LogNormal(p0, p1))  # process noise
        pyro.sample("mscale", D.LogNormal(m0, m1))  # measurement noise
    return partial_model

In [14]:
# Draw a large cache of simulations once from a deliberately diffuse guide
# (LogNormal(0, 10) on every hyperparameter). Later calls to resampler.sample
# reweight and resample this cache instead of re-running the simulator.
# Seed first, as the other cells do, so the cache is reproducible.
pyro.set_rng_seed(12345)
partial_guide = make_partial_model(0, 10, 0, 10, 0, 10)
resampler = Resampler(partial_guide, model, num_guide_samples=10000)

In [15]:
def plot_resampled(df0, df1, p0, p1, m0, m1):
    """Plot 20 cached trajectories resampled toward new LogNormal hyperpriors.

    Builds a partial model with the given locs/scales, asks the module-level
    ``resampler`` for 20 samples under it, and plots their ``obs`` series.

    Args:
        df0, df1, p0, p1, m0, m1: loc/scale pairs forwarded to ``make_partial_model``.
    """
    target = make_partial_model(df0, df1, p0, p1, m0, m1)
    obs_paths = resampler.sample(target, num_samples=20)["obs"]

    fig, ax = plt.subplots(figsize=(8, 5))
    fig.patch.set_color("white")
    # Transpose so each resampled trajectory is drawn as its own line.
    ax.plot(obs_paths.T)
    ax.set_xlabel("time")
    ax.set_ylabel("obs")

In [16]:
# Sliders for the hyperprior locs (…0) and scales (…1) used to resample the cache.
# The trailing ';' suppresses the repr of the function returned by interact.
interact(
    plot_resampled,
    df0=FloatSlider(value=0.0, min=-5, max=5),
    df1=FloatSlider(value=1.0, min=0.1, max=10),
    p0=FloatSlider(value=0.0, min=-5, max=5),
    p1=FloatSlider(value=1.0, min=0.1, max=10),
    m0=FloatSlider(value=0.0, min=-5, max=5),
    m1=FloatSlider(value=1.0, min=0.1, max=10),
);

interactive(children=(FloatSlider(value=0.0, description='df0', max=5.0, min=-5.0), FloatSlider(value=1.0, des…

<function __main__.plot_resampled(df0, df1, p0, p1, m0, m1)>