In [1]:
# Soft-reset the interactive namespace (-s) without prompting for confirmation (-f)
%reset -s -f

In [None]:
import os
from functools import partial
import logging

import numpy as np
from numpy.random import binomial
from numpy.random import normal

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch 
from torch import nn
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.nn import PyroSample
from pyro.nn import PyroModule
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import SVI
from pyro.infer import Trace_ELBO
from pyro.infer import Predictive

In [None]:
# Render matplotlib figures directly in the notebook output
%matplotlib inline

In [None]:
# The notebook targets the Pyro 1.3 API; fail fast on any other release.
assert pyro.__version__.startswith('1.3')
# Turn on Pyro's runtime validation checks (a single call is sufficient).
pyro.enable_validation(True)
# Fix the seed so repeated runs of the notebook are comparable.
pyro.set_rng_seed(42)


### Generate a set of randomly distributed features

In [None]:
# Target variable: 1000 draws from a normal distribution (mean 0.7, std 0.2)
y = normal(loc=0.7, scale=0.2, size=1000)

In [None]:
# Let's create an artificial dataset from randomly sampled points.
# Each feature is an independent 0/1 column; the number is its success probability.
# A list of pairs (not a dict) keeps the draw order F1..F10 explicit.
feature_probs = [
    ("F1", 0.1), ("F2", 0.2), ("F3", 0.1), ("F4", 0.3), ("F5", 0.5),
    ("F6", 0.2), ("F7", 0.3), ("F8", 0.5), ("F9", 0.1), ("F10", 0.1),
]
# Size every column from the target so features and EXPLAIN can never
# disagree in length when the sample count is changed.
dataset_columns = {name: binomial(1, prob, len(y)) for name, prob in feature_probs}
dataset_columns["EXPLAIN"] = y
data_df = pd.DataFrame(dataset_columns)
data_df

In [None]:
# Make the categorical data plottable: reshape to long format (one row per
# observation/feature pair) and keep only the rows where the feature is set.
melted_df = (
    data_df
    .melt(id_vars="EXPLAIN")
    .loc[lambda frame: frame["value"] != 0, ["EXPLAIN", "variable"]]
)

In [None]:
# Boxen plot of EXPLAIN, one category per feature (rows where that feature is set)
sns.catplot(
    data=melted_df,
    x="variable",
    y="EXPLAIN",
    kind="boxen",
    palette="dark",
)

In [None]:
# Model dimensions and the number of SVI steps
INPUT_DIM = 10
OUTPUT_DIM = 1
ITERATIONS = 1500

# Make the data PyTorch objects. The feature columns are derived from
# INPUT_DIM so the width of X always matches the model's input size.
feature_names = ["F{}".format(idx) for idx in range(1, INPUT_DIM + 1)]
X = torch.tensor(data_df[feature_names].values, dtype=torch.float)
y = torch.tensor(data_df.EXPLAIN.values, dtype=torch.float)

X

In [None]:
class BayesianRegression(PyroModule):
    """
    Linear regression whose weight, bias and noise scale are random variables.

    The weight has a Normal(0, 1) prior, the bias a Normal(0, 10) prior and
    the observation noise ``sigma`` a Uniform(0, 10) prior.
    """

    def __init__(self, in_features, out_features):
        """Build the linear layer and replace its parameters with priors."""
        super().__init__()
        weight_shape = [out_features, in_features]
        bias_shape = [out_features]
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(
            dist.Normal(0., 1.).expand(weight_shape).to_event(2)
        )
        self.linear.bias = PyroSample(
            dist.Normal(0., 10.).expand(bias_shape).to_event(1)
        )

    def forward(self, x, y=None):
        """
        Sample the noise scale, compute the regression mean for ``x`` and
        score (or sample, when ``y`` is None) the "obs" site.

        Returns the regression mean with its trailing dimension squeezed.
        """
        noise_scale = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        likelihood = dist.Normal(mean, noise_scale)
        # rows of x are treated as conditionally independent observations
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", likelihood, obs=y)
        return mean

In [None]:
# instantiate learning objects: the model, an automatic diagonal-normal guide
# over it, and the SVI driver that optimises the ELBO with Adam
bayes_reg_model = BayesianRegression(INPUT_DIM, OUTPUT_DIM)
guide = AutoDiagonalNormal(bayes_reg_model)

adam_opt = optim.Adam({"lr": 0.03})
svi = SVI(model=bayes_reg_model, guide=guide, optim=adam_opt, loss=Trace_ELBO())

In [None]:
# Start from an empty parameter store, then run the optimisation.
pyro.clear_param_store()
for i in range(ITERATIONS):
    # one inference step on the full dataset
    loss = svi.step(X, y)
    # report the per-observation loss only on every 100th step
    if i % 100:
        continue
    print("[iteration {}] loss: {}".format(i, loss / len(X)))

In [None]:
# inspect learned parameters
# (gradients are switched off on the guide first; nothing below trains it)
guide.requires_grad_(False)
print("Learned parameters:")
for param_name in pyro.get_param_store().keys():
    print(param_name, pyro.param(param_name))

print("Guide parameters:")
quartiles = guide.quantiles([0.25, 0.5, 0.75])
print(quartiles)

In [None]:
def summary(samples):
    """
    Summarise sampled values per site.

    :param dict samples: maps a site name to a tensor whose first dimension
        indexes the draws.
    :returns: dict mapping each site name to a dict with the keys ``"mean"``,
        ``"std"``, ``"5%"`` and ``"95%"``, each reduced over dimension 0.
    """
    site_stats = {}
    for site, draws in samples.items():
        num_draws = len(draws)
        # kthvalue ranks are 1-based; with fewer than 20 draws the 5% rank
        # would truncate to 0 and raise, so clamp both ranks to at least 1.
        low_rank = max(int(num_draws * 0.05), 1)
        high_rank = max(int(num_draws * 0.95), 1)
        site_stats[site] = {
            "mean": torch.mean(draws, 0),
            "std": torch.std(draws, 0),
            "5%": draws.kthvalue(low_rank, dim=0)[0],
            "95%": draws.kthvalue(high_rank, dim=0)[0],
        }
    return site_stats


In [None]:
# Draw 800 samples through the fitted guide and keep the weight site, the
# "obs" site and the model's return value, then summarise each of them.
predict = Predictive(
    bayes_reg_model,
    guide=guide,
    num_samples=800,
    return_sites=("linear.weight", "obs", "_RETURN"),
)
samples = predict(X)
pred_summary = summary(samples)

In [None]:
# Collect the features, the summaries of the regression mean ("_RETURN") and
# of the "obs" site, and the observed target in a single table.
mu = pred_summary["_RETURN"]
y_hat = pred_summary["obs"]
predictions = pd.DataFrame({
    # feature columns F1..F10, copied over unchanged
    **{name: data_df[name] for name in ["F{}".format(k) for k in range(1, 11)]},
    "mu_mean": mu["mean"],
    "mu_p95": mu["95%"],
    "mu_p5": mu["5%"],
    "y_mean": y_hat["mean"],
    "y_p95": y_hat["95%"],
    "y_p5": y_hat["5%"],
    "Y": y,
})
predictions

In [None]:
# Long format again: one row per (observation, feature) pair, restricted to
# the rows where the feature is set, so predictions can be plotted per feature.
melted_pred = pd.melt(
    predictions,
    id_vars=["mu_mean", "y_mean", "Y"],
    value_vars=["F{}".format(i+1) for i in range(10)],
)
active_rows = melted_pred[melted_pred["value"] != 0]
melted_mu_pred_df = active_rows[["mu_mean", "variable"]]
melted_y_pred_df = active_rows[["y_mean", "variable"]]

# violin plots of the mean prediction and the mean sampled observation
sns.catplot(data=melted_mu_pred_df, x="variable", y="mu_mean",
            kind="violin", bw=.2, palette="Set2")
sns.catplot(data=melted_y_pred_df, x="variable", y="y_mean",
            kind="violin", bw=.2, palette="Set2")

In [None]:
# plot uncertainty
# Flatten the sampled weights to one row per draw. X was built from the
# columns F1..F10 in order, so the coefficients of F1, F2 and F3 sit at
# positions 0, 1 and 2 (zero-based), not 1, 2 and 3.
weights = samples["linear.weight"]
weights = weights.reshape(weights.shape[0], -1)
gamma_1 = weights[:, 0]
gamma_12 = weights[:, 0] + weights[:, 1]
gamma_123 = weights[:, 0] + weights[:, 1] + weights[:, 2]

fig = plt.figure(figsize=(10,6))
sns.distplot(gamma_1, kde_kws={"label": "Feat 1"})
sns.distplot(gamma_12, kde_kws={"label": "Intersection 1,2"})
sns.distplot(gamma_123, kde_kws={"label": "Intersection 1,2,3"})
fig.suptitle("Density of Fit")

## Let's make the inference more explicit

In [None]:
# write model without PyroModule usage
def model(F1, F2, F3, EXPLAIN):
    """
    Linear model of EXPLAIN on F1, F2, F3 and all of their interactions.

    Every coefficient has an explicit prior: Normal(0, 10) for the intercept,
    Normal(0, 1) for each main effect and interaction term, and Uniform(0, 10)
    for the observation noise. All four inputs must have the same length.
    """
    n_obs = len(F1)
    assert n_obs == len(F2) == len(F3) == len(EXPLAIN)

    intercept = pyro.sample("f", dist.Normal(0., 10.))
    # the seven slope coefficients share one prior and are drawn in this order
    slope = {}
    for site in ("bF1", "bF2", "bF3", "bF12", "bF13", "bF23", "bF123"):
        slope[site] = pyro.sample(site, dist.Normal(0., 1.))
    noise_scale = pyro.sample("sigma", dist.Uniform(0., 10.))

    mean = (
        intercept
        + slope["bF1"]*F1 + slope["bF2"]*F2 + slope["bF3"]*F3
        + slope["bF12"]*F1*F2 + slope["bF13"]*F1*F3 + slope["bF23"]*F2*F3
        + slope["bF123"]*F1*F2*F3
    )
    # observations are conditionally independent given the coefficients
    with pyro.plate("data", n_obs):
        pyro.sample("obs", dist.Normal(mean, noise_scale), obs=EXPLAIN)
        
def guide(F1, F2, F3, EXPLAIN):
    """
    Hand-written variational family for ``model``: one independent Normal per
    latent site, with learnable location (and, except for sigma, scale).

    The data arguments are not used here; they are accepted so the guide can
    be called with the same arguments as ``model``.
    """
    f_loc = pyro.param("f_loc", torch.tensor(0.))
    f_scale = pyro.param("f_scale", torch.tensor(1.),
                         constraint=constraints.positive)
    sigma_loc = pyro.param("sigma_loc", torch.tensor(1.),
                           constraint=constraints.positive)
    weights_loc = pyro.param("weights_loc", torch.randn(7))
    weights_scale = pyro.param("weights_scale", torch.ones(7),
                               constraint=constraints.positive)

    # Site names and order must match the model's latent sites.
    pyro.sample("f", dist.Normal(f_loc, f_scale))
    coefficient_sites = ("bF1", "bF2", "bF3", "bF12", "bF13", "bF23", "bF123")
    for idx, site in enumerate(coefficient_sites):
        pyro.sample(site, dist.Normal(weights_loc[idx], weights_scale[idx]))
    # NOTE(review): the model's prior for sigma is Uniform(0, 10) but this
    # Normal is unbounded; a draw outside (0, 10) would presumably fail
    # validation. Kept as-is to preserve the fit; confirm whether a
    # positive-support family is wanted.
    pyro.sample("sigma", dist.Normal(sigma_loc, torch.tensor(0.05)))
    # The regression mean is deliberately not computed here: the guide only
    # proposes latent values, and the result was never used.
    
    

In [None]:
def summary(samples):
    """
    Summarise the fit: for every site, a table with the mean, the standard
    deviation and the 5/25/50/75/95 percentiles of its sampled values.
    """
    quantiles = [.05, 0.25, 0.5, 0.75, 0.95]
    kept_columns = ["mean", "std", "5%", "25%", "50%", "75%", "95%"]
    return {
        site_name: pd.DataFrame(values).describe(percentiles=quantiles).T[kept_columns]
        for site_name, values in samples.items()
    }

In [None]:
# SVI over the hand-written model/guide pair
svi = SVI(model, guide, optim.Adam({'lr': 0.05}), loss=Trace_ELBO())

# X holds the columns F1..F10 in order, so F1, F2, F3 are columns 0, 1, 2
f1, f2, f3 = X[:, 0], X[:, 1], X[:, 2]

pyro.clear_param_store()
# Step budget for this fit; the loop below must use this value rather than
# the ITERATIONS constant of the earlier PyroModule fit.
ITERATION = 500

for i in range(ITERATION):
    elbo = svi.step(f1, f2, f3, y)
    if i % 100 == 0:
        print("[STEP {}] Elbo loss: {}".format(i, elbo))