# Bayesian Kriging to Predict Ground Water Levels

In [None]:
from __future__ import absolute_import, division, print_function

import os
import json
import pyro
import torch
import logging
import numpy as np
import pandas as pd
import seaborn as sns
import pyro.optim as optim
import pyro.contrib.gp as gp
import matplotlib.pyplot as plt
import pyro.distributions as dist

from torch.distributions import constraints

from functools import partial
from pyro.infer.mcmc import NUTS
from pyro.infer.mcmc.api import MCMC
from mpl_toolkits.mplot3d import Axes3D
from pyro.contrib.autoguide import AutoMultivariateNormal
from pyro.infer import EmpiricalMarginal, SVI, Trace_ELBO, JitTrace_ELBO

pyro.set_rng_seed(0)

In [None]:
%matplotlib inline
# Emit bare log messages at INFO level so training progress shows up in the notebook.
logging.basicConfig(format="%(message)s", level=logging.INFO)

# Enable validation checks
pyro.enable_validation(True)
# True when a `CI` environment variable is set; the SVI cells use it to run only a couple of steps.
smoke_test = "CI" in os.environ
# The notebook is written against the Pyro 0.4.1 API; stop early on any other version.
assert pyro.__version__.startswith("0.4.1")

In [None]:
pyro.set_rng_seed(1)

In [None]:
# Use the GPU only when one is actually available: a hard-coded True makes the
# later `.cuda()` calls fail on CPU-only machines.
use_cuda = torch.cuda.is_available()

### Helper Functions

In [None]:
def pairwise_distances(x, y=None):
    """Return the matrix of squared Euclidean distances between rows.

    Entry ``[i, j]`` is ``||x[i] - y[j]||^2``, obtained from the expansion
    ``||x||^2 + ||y||^2 - 2 * <x, y>``.  When ``y`` is omitted the distances
    are taken between the rows of ``x`` itself.

    Args:
        x: 2-D tensor with one point per row.
        y: Optional 2-D tensor with one point per row.

    Returns:
        A tensor with one row per row of ``x`` and one column per row of
        ``y`` (or of ``x`` when ``y`` is ``None``).
    """
    other = x if y is None else y

    sq_norm_rows = (x ** 2).sum(1).view(-1, 1)
    sq_norm_cols = (other ** 2).sum(1).view(1, -1)
    cross_terms = torch.mm(x, torch.transpose(other, 0, 1))

    # Rounding in the expansion can produce tiny negative values; floor at 0.
    return torch.clamp(sq_norm_rows + sq_norm_cols - 2.0 * cross_terms, min=0.0)

In [None]:
def summary(samples):
    """Summarise the samples of every site with mean, std and percentiles.

    Args:
        samples: Mapping from site name to array-like samples accepted by
            ``pandas.DataFrame``.

    Returns:
        A dict mapping each site name to a DataFrame with the columns
        ``mean``, ``std``, ``5%``, ``25%``, ``50%``, ``75%`` and ``95%``
        (one row per column of the site's sample frame).
    """
    quantiles = [.05, 0.25, 0.5, 0.75, 0.95]
    columns = ["mean", "std", "5%", "25%", "50%", "75%", "95%"]

    return {
        name: pd.DataFrame(draws).describe(percentiles=quantiles).transpose()[columns]
        for name, draws in samples.items()
    }

In [None]:
def visualize_data(
    plot_observed_data=False,
    plot_predictions=False,
    n_prior_samples=0,
    model=None,
    kernel=None,
    n_test=500,
):
    """Draw observations, model predictions and/or prior samples on a new figure.

    A fresh figure with a single 3-D subplot is created on every call.  Each
    of the three sections below is optional and controlled by its own
    argument.  Nothing is returned.

    Args:
        plot_observed_data: When true, scatter the module-level well data
            (``XW``/``YW``, "x" markers) and farm data (``XF``/``YF``, "o"
            markers).
        plot_predictions: When true, evaluate ``model`` on a test grid and
            plot its mean with a two-sigma band.
        n_prior_samples: Number of prior draws to plot; nothing is drawn
            when it is 0.
        model: Callable GP model.  Its type is compared against
            ``gp.models.VariationalSparseGP`` to choose the call signature
            and the attribute holding the noise.
        kernel: Object whose ``forward`` gives the prior covariance over the
            test grid; only used for the prior samples.
        n_test: Number of grid points along each of the two input axes.

    NOTE(review): the prediction and prior branches look like a partial port
    of a one-dimensional example and probably do not run as written -- see
    the inline notes.  Confirm before relying on them.
    """

    ax = plt.figure().add_subplot(111, projection='3d')
    
    if plot_observed_data:
        ax.scatter(XW[:, 0], XW[:, 1], YW, marker="x")
        ax.scatter(XF[:, 0], XF[:, 1], YF, marker="o")
        
    if plot_predictions:
        # Grid over [0, 5] x [0, 10] (presumably latitude x longitude -- confirm).
        Xtest_ltd = torch.linspace(0, 5, n_test)
        Xtest_lng = torch.linspace(0, 10, n_test)
        
        Xtest_ltd, Xtest_lng = np.meshgrid(Xtest_ltd, Xtest_lng)
        
        Xtest_ltd = np.expand_dims(Xtest_ltd, -1)
        Xtest_lng = np.expand_dims(Xtest_lng, -1)
        
        # Flatten the grid into one row per grid point with two coordinates.
        # NOTE(review): unlike the prior branch below, Xtest is not converted
        # to a tensor here, yet it is passed to the model and `.numpy()` is
        # later called on it.
        Xtest = np.concatenate([Xtest_ltd, Xtest_lng], -1).reshape((-1, 2))
        
        with torch.no_grad():
            if type(model) == gp.models.VariationalSparseGP:
                mean, cov = model(Xtest, full_cov=True)
            else:
                mean, cov = model(Xtest, full_cov=True, noiseless=False)
                
        # NOTE(review): looks like leftover debugging output.
        print(mean, cov)
                
        sd = cov.diag().sqrt()  # standard deviation at each input point x
        # NOTE(review): these calls go through `plt`, not the 3-D `ax`, and
        # Xtest has two columns -- confirm this is the intended plot.
        plt.plot(Xtest.numpy(), mean.numpy(), "r", lw=2)  # plot the mean
        plt.fill_between(
            Xtest,  # plot the two-sigma uncertainty about the mean
            (mean - 2.0 * sd).numpy(),
            (mean + 2.0 * sd).numpy(),
            color="C0",
            alpha=0.3,
        )
        
    if n_prior_samples > 0:  # plot samples from the GP prior
        # Same grid construction as in the prediction branch.
        Xtest_ltd = torch.linspace(0, 5, n_test)
        Xtest_lng = torch.linspace(0, 10, n_test)
        
        Xtest_ltd, Xtest_lng = np.meshgrid(Xtest_ltd, Xtest_lng)
        
        Xtest_ltd = np.expand_dims(Xtest_ltd, -1)
        Xtest_lng = np.expand_dims(Xtest_lng, -1)
        
        Xtest = np.concatenate([Xtest_ltd, Xtest_lng], -1).reshape((-1, 2))
        Xtest = torch.tensor(Xtest)
        
        # The noise term is read from a different attribute depending on the model type.
        noise = (
            model.noise
            if type(model) != gp.models.VariationalSparseGP
            else model.likelihood.variance
        )
        # NOTE(review): after the reshape Xtest has n_test * n_test rows, but
        # the noise diagonal and the zero mean below are sized n_test -- the
        # shapes appear inconsistent.
        cov = kernel.forward(Xtest) + noise.expand(n_test).diag()
        samples = dist.MultivariateNormal(
            torch.zeros(n_test), covariance_matrix=cov
        ).sample(sample_shape=(n_prior_samples,))
        plt.plot(Xtest.numpy(), samples.numpy().T, lw=2, alpha=0.4)

#     plt.xlim(-0.5, 5.5)

In [None]:
def visualize_posterior(samples):
    """Plot the marginal posterior density of every sampled site.

    The sites are laid out on the smallest square grid of subplots that can
    hold them, one density plot per site, titled with the site name.

    Args:
        samples: Mapping from site name to the samples drawn for that site.

    Returns:
        None.  Nothing is drawn when ``samples`` is empty.
    """
    import math

    sites = list(samples.keys())
    if not sites:
        # A 0x0 subplot grid is not valid, so there is nothing to draw.
        return

    r = int(math.ceil(math.sqrt(len(sites))))
    # squeeze=False keeps `axs` a 2-D array even for a single-panel grid.
    fig, axs = plt.subplots(nrows=r, ncols=r, figsize=(12, 10), squeeze=False)
    fig.suptitle("Marginal Posterior Density", fontsize=16)

    flat_axes = axs.reshape(-1)
    # zip stops at the last site, so spare panels never index past `sites`.
    for site, ax in zip(sites, flat_axes):
        sns.distplot(samples[site], ax=ax)
        ax.set_title(site)

    # The square grid can hold more panels than there are sites; hide the rest.
    for ax in flat_axes[len(sites):]:
        ax.set_visible(False)

### Loading Data

In [None]:
# Load the raw observations; the file is decoded as ISO-8859-1.
data = pd.read_csv("data/sample_data.csv", encoding="ISO-8859-1")

# Split the rows by their `type` column into wells and farms.
data_wells = data[data.type == "well"]
data_farms = data[data.type == "farm"]

# Wells: (latitude, longitude) inputs and observed values as float tensors.
XW = torch.FloatTensor(data_wells[["latitude", "longitude"]].values)
YW = torch.FloatTensor(data_wells["observation"].values)

# Farms: same layout as the wells.
XF = torch.FloatTensor(data_farms[["latitude", "longitude"]].values)
YF = torch.FloatTensor(data_farms["observation"].values)

In [None]:
# visualize_data(plot_observed_data=True)
# plt.show()

In [None]:
# Move all four data tensors to the GPU when CUDA use was requested.
if use_cuda:
    XW, YW, XF, YF = (tensor.cuda() for tensor in (XW, YW, XF, YF))

# Bayesian Kernel Regression

## Defining the Model

In [None]:
def model_kr(XW, YW, use_cuda=False):
    """Kernel-regression model for the well observations.

    Priors: ``delta ~ LogNormal(1, 0.5)``, ``theta ~ LogNormal(0, 0.5)``,
    ``mu ~ Normal(0, 2)`` and ``sigma ~ Gamma(1, 1)``.  The mean at each row
    of ``XW`` is ``mu`` minus ``delta`` times the sum over farms of
    ``YF * exp(-pairwise_distances(XW, XF) / theta)``, where ``XF`` and
    ``YF`` are the module-level farm tensors.

    Args:
        XW: Well input tensor; also fixes the tensor type of the prior
            constants.
        YW: Observed well values that the ``"obs"`` site is conditioned on.
        use_cuda: Accepted for call compatibility; not read by the function.

    Returns:
        The value of the ``"obs"`` sample site.
    """
    def const(value):
        # One-element constant with the same tensor type as XW.
        return torch.Tensor([value]).type_as(XW)

    delta = pyro.sample("delta", dist.LogNormal(const(1.0), 0.5))
    theta = pyro.sample("theta", dist.LogNormal(const(0.0), 0.5))
    mu = pyro.sample("mu", dist.Normal(0, const(2.0)))
    sigma = pyro.sample("sigma", dist.Gamma(const(1.0), const(1.0)))

    # Each farm's observation is weighted by a kernel of its distance to the well.
    farm_weights = torch.exp(-pairwise_distances(XW, XF) / theta)
    farm_effect = (YF * farm_weights).sum(1) * delta
    mean = mu - farm_effect

    with pyro.plate("data", len(YW)):
        return pyro.sample("obs", dist.Normal(mean, sigma), obs=YW)

In [None]:
def guide_kr(XW, YW):
    """Mean-field variational guide for ``model_kr``.

    Every latent site of the model (``delta``, ``theta``, ``mu``, ``sigma``)
    gets its own independent factor.  The positive-valued sites use
    log-normal factors so that the guide never proposes values outside the
    support of the model's priors.

    Args:
        XW: Well input tensor; only used to pick the tensor type of the
            variational parameters.
        YW: Observed well values; unused, present so that the guide has the
            same signature as the model.
    """
    # Location parameters live on the whole real line and are therefore left
    # unconstrained; only the scale parameters must stay positive.
    mu_delta = pyro.param("mu_delta", torch.ones(1).type_as(XW))
    mu_theta = pyro.param("mu_theta", torch.zeros(1).type_as(XW))

    sg_delta = pyro.param("sg_delta", torch.ones(1).type_as(XW), constraint=constraints.positive)
    sg_theta = pyro.param("sg_theta", torch.ones(1).type_as(XW), constraint=constraints.positive)

    # delta and theta have log-normal priors, i.e. strictly positive support.
    pyro.sample("delta", dist.LogNormal(mu_delta, sg_delta))
    pyro.sample("theta", dist.LogNormal(mu_theta, sg_theta))

    mu_mu = pyro.param("mu_mu", torch.randn(1).type_as(XW))
    sg_mu = pyro.param("sg_mu", torch.ones(1).type_as(XW), constraint=constraints.positive)
    pyro.sample("mu", dist.Normal(mu_mu, sg_mu))

    # The model also samples `sigma` (Gamma prior, positive support); the
    # guide must provide a matching site for it.
    mu_sigma = pyro.param("mu_sigma", torch.zeros(1).type_as(XW))
    sg_sigma = pyro.param("sg_sigma", torch.ones(1).type_as(XW), constraint=constraints.positive)
    pyro.sample("sigma", dist.LogNormal(mu_sigma, sg_sigma))

In [None]:
def predict_kr(XW, posterior_samples):
    """Yield one posterior-predictive draw at ``XW`` per posterior sample.

    Args:
        XW: Input tensor at which to predict.
        posterior_samples: Mapping with the keys ``"delta"``, ``"theta"``,
            ``"mu"`` and ``"sigma"``; each value holds one entry per
            posterior sample and may be a tensor or a NumPy array.

    Yields:
        A tensor sampled from ``Normal(mean, sigma)`` with one value per row
        of ``XW``, of the same tensor type as ``XW``.
    """
    def like_xw(values):
        # Samples may arrive as NumPy arrays (e.g. after `.cpu().numpy()`);
        # bring them to XW's tensor type before mixing them with tensors.
        return torch.as_tensor(values).type_as(XW)

    site_names = ("delta", "theta", "mu", "sigma")
    draws = zip(*(like_xw(posterior_samples[name]) for name in site_names))

    # The distances do not depend on the posterior sample: compute them once.
    farm_sq_dist = pairwise_distances(XW, XF)

    for delta, theta, mu, sigma in draws:
        mean = mu - (YF * torch.exp(-farm_sq_dist / theta)).sum(1) * delta
        yield dist.Normal(mean, sigma).sample()

## Inference

#### HMC

In [None]:
# NUTS kernel over the kernel-regression model; `use_cuda` is bound up front
# so that `run` only needs the data tensors.
nuts_kernel = NUTS(partial(model_kr, use_cuda=use_cuda))

# 2000 warm-up steps, then 1000 samples.
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=2000)
mcmc_run = mcmc.run(XW, YW)

# Convert every site's samples to a CPU NumPy array for the summary and plots.
hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

In [None]:
for site, values in summary(hmc_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")

In [None]:
visualize_posterior(hmc_samples)

#### SVI

In [None]:
# Variational inference for the kernel-regression model with its hand-written guide.
svi = SVI(
    partial(model_kr, use_cuda=use_cuda), guide_kr, optim.Adam({"lr": 0.005}), loss=JitTrace_ELBO(), num_samples=1000
)

# Start from an empty parameter store so earlier runs do not leak in.
pyro.clear_param_store()
# Only two steps are taken when running as a smoke test.
num_iters = 10000 if not smoke_test else 2
for i in range(num_iters):
    elbo = svi.step(XW, YW)
    # Log the value returned by the step every 500 iterations.
    if i % 500 == 0:
        logging.info("Elbo loss: {}".format(elbo))

## Prediction

In [None]:
results = torch.stack(list(predict_kr(XW, hmc_samples)))

In [None]:
# Matplotlib needs host-side NumPy arrays; the tensors may live on the GPU.
# `.cpu()` is a no-op for tensors that are already on the CPU.
XW_ = XW.cpu().numpy()
YW_ = YW.cpu().numpy()

XF_ = XF.cpu().numpy()
YF_ = YF.cpu().numpy()

ax = plt.figure(figsize=(15, 15)).add_subplot(111, projection='3d')

# Observed wells ("x") and farms ("^").
ax.scatter(XW_[:, 0], XW_[:, 1], YW_, marker="x")
ax.scatter(XF_[:, 0], XF_[:, 1], YF_, marker="^")

# Mean and standard deviation of the predictive draws at every well.
results_mean = results.mean(dim=0).cpu().numpy()
results_std = results.std(dim=0).cpu().numpy()

ax.scatter(XW_[:, 0], XW_[:, 1], results_mean, marker="o", color="red", s=15)

# One-standard-deviation band around the predictive mean.
ax.scatter(XW_[:, 0], XW_[:, 1], results_mean - results_std, marker="_", color="green", s=15)
ax.scatter(XW_[:, 0], XW_[:, 1], results_mean + results_std, marker="_", color="green", s=15)

plt.savefig("predictions.png", dpi=240)
plt.show()

In [None]:
results.mean(0)

In [None]:
YW

# Gaussian Processes

## Defining the Model

In [None]:
def model_gp(XW, YW):
    """Gaussian-process model for the well observations.

    Priors: ``delta ~ LogNormal(1, 0.5)``, ``theta_f ~ LogNormal(0, 0.5)``,
    ``theta_w ~ LogNormal(0, 0.5)`` and ``mu ~ Normal(0, 1)``.  The wells are
    jointly multivariate normal with mean
    ``mu - delta * sum_farms(YF * exp(-pairwise_distances(XW, XF) / theta_f))``
    and covariance ``exp(-pairwise_distances(XW, XW) / theta_w)``; ``XF`` and
    ``YF`` are the module-level farm tensors.

    Args:
        XW: Well input tensor; also fixes the tensor type of the prior
            constants.
        YW: Observed well values that the ``"obs"`` site is conditioned on.

    Returns:
        The value of the ``"obs"`` sample site.
    """
    mu_delta = torch.Tensor([1.0]).type_as(XW)

    delta = pyro.sample("delta", dist.LogNormal(mu_delta, 0.5))

    mu_theta_f = torch.Tensor([0.0]).type_as(XW)
    mu_theta_w = torch.Tensor([0.0]).type_as(XW)

    theta_f = pyro.sample("theta_f", dist.LogNormal(mu_theta_f, 0.5))
    theta_w = pyro.sample("theta_w", dist.LogNormal(mu_theta_w, 0.5))

    sigma_mu = torch.Tensor([1.0]).type_as(XW)
    mu = pyro.sample("mu", dist.Normal(0, sigma_mu))

    sigma = torch.exp(-pairwise_distances(XW, XW) / theta_w)
    mean = mu - delta * (YF * torch.exp(-pairwise_distances(XW, XF) / theta_f)).sum(1)

    # The multivariate normal treats the whole vector of wells as one event,
    # so the site must not sit inside a plate over the wells: the plate would
    # broadcast the distribution and count the joint likelihood once per well.
    y = pyro.sample("obs", dist.MultivariateNormal(mean, sigma), obs=YW)

    return y

In [None]:
def guide_gp(XW, YW):
    """Mean-field variational guide for ``model_gp``.

    Each latent site of the model (``delta``, ``theta_f``, ``theta_w``,
    ``mu``) gets an independent factor: log-normal for the positive-valued
    sites and normal for ``mu``.

    Args:
        XW: Well input tensor; only used to pick the tensor type of the
            variational parameters.
        YW: Observed well values; unused, present so that the guide has the
            same signature as the model.
    """
    # Location parameters of a (log-)normal are unconstrained reals.  A
    # positivity constraint is not needed and an initial value of 0 would lie
    # outside the constrained support, so only the scales are constrained.
    mu_delta = pyro.param("mu_delta", torch.Tensor([1.0]).type_as(XW))
    sg_delta = pyro.param("sg_delta", torch.Tensor([1.0]).type_as(XW), constraint=constraints.positive)

    pyro.sample("delta", dist.LogNormal(mu_delta, sg_delta))

    mu_theta_f = pyro.param("mu_theta_f", torch.Tensor([0.0]).type_as(XW))
    mu_theta_w = pyro.param("mu_theta_w", torch.Tensor([0.0]).type_as(XW))

    sg_theta_f = pyro.param("sg_theta_f", torch.Tensor([1.0]).type_as(XW), constraint=constraints.positive)
    sg_theta_w = pyro.param("sg_theta_w", torch.Tensor([1.0]).type_as(XW), constraint=constraints.positive)

    pyro.sample("theta_f", dist.LogNormal(mu_theta_f, sg_theta_f))
    pyro.sample("theta_w", dist.LogNormal(mu_theta_w, sg_theta_w))

    mu_mu = pyro.param("mu_mu", torch.randn(1).type_as(XW))
    sg_mu = pyro.param("sg_mu", torch.Tensor([1.0]).type_as(XW), constraint=constraints.positive)

    pyro.sample("mu", dist.Normal(mu_mu, sg_mu))

In [None]:
def predict_gp(XW, posterior_samples):
    """Yield one joint predictive draw at ``XW`` per posterior sample.

    All computations are done in NumPy on the CPU.

    Args:
        XW: Input tensor at which to predict.
        posterior_samples: Mapping with the keys ``"delta"``, ``"theta_f"``,
            ``"theta_w"`` and ``"mu"``, each holding one entry per posterior
            sample.

    Yields:
        A NumPy array drawn from the multivariate normal whose mean and
        covariance are built from the current posterior sample.
    """
    site_names = ("delta", "theta_f", "theta_w", "mu")
    draws = zip(*[posterior_samples[name] for name in site_names])

    # Quantities that do not depend on the posterior sample.
    well_sq_dist = pairwise_distances(XW).cpu().numpy()
    farm_sq_dist = pairwise_distances(XW, XF).cpu().numpy()
    farm_obs = YF.cpu().numpy()

    for delta, theta_f, theta_w, mu in draws:
        covariance = np.exp(-well_sq_dist / theta_w)
        farm_effect = (farm_obs * np.exp(-farm_sq_dist / theta_f)).sum(1)

        yield np.random.multivariate_normal(mu - delta * farm_effect, covariance)

## Inference

#### HMC

In [None]:
# Reuse cached posterior samples when they exist; otherwise run NUTS.
try:
    with open("data/gp-samples.json", "r") as f:
        hmc_samples = {k: np.array(v) for k, v in json.load(f).items()}

except (IOError, OSError, ValueError):
    # The cache is missing, unreadable or not valid JSON: sample from scratch.
    nuts_kernel = NUTS(model_gp)

    mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=4000)
    mcmc_run = mcmc.run(XW, YW)

    hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

In [None]:
# Convert the arrays to nested lists so that `json.dump` can write them, then
# store them in the file the HMC cell tries to read back.
hmc_samples_ = {k: v.tolist() for k, v in hmc_samples.items()}
with open("data/gp-samples.json", "w") as f:
    json.dump(hmc_samples_, f)

In [None]:
for site, values in summary(hmc_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")

In [None]:
visualize_posterior(hmc_samples)

In [None]:
results = np.array(list(predict_gp(XW, hmc_samples)))

In [None]:
# Host-side NumPy copies for matplotlib.  They are defined here because this
# cell runs before any other cell that creates them.
XW_ = XW.cpu().numpy()
YW_ = YW.cpu().numpy()

XF_ = XF.cpu().numpy()
YF_ = YF.cpu().numpy()

ax = plt.figure(figsize=(15, 15)).add_subplot(111, projection='3d')

# Observed wells ("x") and farms ("^").
ax.scatter(XW_[:, 0], XW_[:, 1], YW_, marker="x")
ax.scatter(XF_[:, 0], XF_[:, 1], YF_, marker="^")

# Mean and standard deviation of the predictive draws at every well.
results_mean = results.mean(axis=0)
results_std = results.std(axis=0)

ax.scatter(XW_[:, 0], XW_[:, 1], results_mean, marker="o", color="red", s=15)

# One-standard-deviation band around the predictive mean.
ax.scatter(XW_[:, 0], XW_[:, 1], results_mean - results_std, marker="_", color="green", s=15)
ax.scatter(XW_[:, 0], XW_[:, 1], results_mean + results_std, marker="_", color="green", s=15)

plt.savefig("predictions.png", dpi=240)
plt.show()

#### SVI

In [None]:
# Variational inference for the GP model with its hand-written guide.
svi = SVI(
    model_gp, guide_gp, optim.Adagrad({"lr": 0.01}), loss=JitTrace_ELBO(), num_samples=1000
)

# Start from an empty parameter store; the guide reuses parameter names
# (e.g. "mu_delta") that the kernel-regression guide also registers.
pyro.clear_param_store()
# Only two steps are taken when running as a smoke test.
num_iters = 10000 if not smoke_test else 2
for i in range(num_iters):
    elbo = svi.step(XW, YW)
    # Log the value returned by the step every 500 iterations.
    if i % 500 == 0:
        logging.info("Elbo loss: {}".format(elbo))

In [None]:
svi_diagnorm_posterior = svi.run(XW, YW)

In [None]:
# Latent sites of the GP model to extract from the SVI posterior.
sites = ["delta", "theta_f", "theta_w", "mu"]

# For each site, take the support of its empirical marginal and convert it
# to a CPU NumPy array (presumably one entry per posterior draw -- confirm).
svi_samples = {
    site: EmpiricalMarginal(svi_diagnorm_posterior, sites=site)
    .enumerate_support()
    .detach()
    .cpu()
    .numpy()
    for site in sites
}

# Print the per-site summary statistics.
for site, values in summary(svi_samples).items():
    print("Site: {}".format(site))
    print(values, "\n")

In [None]:
results = np.array(list(predict_gp(XW, svi_samples)))

In [None]:
# CPU NumPy copies of the data tensors, used by the plotting cells.
XW_ = XW.cpu().numpy()
YW_ = YW.cpu().numpy()

XF_ = XF.cpu().numpy()
YF_ = YF.cpu().numpy()

In [None]:
# 3-D scatter of the observations together with the predictive mean and a
# one-standard-deviation band at every well.
ax = plt.figure(figsize=(15, 15)).add_subplot(111, projection='3d')

# Observed wells ("x") and farms ("^").
ax.scatter(XW_[:, 0], XW_[:, 1], YW_, marker="x")
ax.scatter(XF_[:, 0], XF_[:, 1], YF_, marker="^")

# for result in results[:100]:
#     ax.scatter(XW[:, 0], XW[:, 1], result, marker=".", color="red")

# Statistics across the predictive draws (first axis of `results`).
results_mean = results.mean(axis=0)
results_std = results.std(axis=0)

ax.scatter(XW_[:, 0], XW_[:, 1], results_mean, marker="o", color="red", s=15)

# Lower and upper ends of the band, drawn as "_" markers.
ax.scatter(XW_[:, 0], XW_[:, 1], results_mean - results_std, marker="_", color="green", s=15)
# for i, point in enumerate(XW.numpy().tolist()):
#     ax.plot(point + [results_mean[i] + results_std[i]], point + [results_mean[i] - results_std[i]])
ax.scatter(XW_[:, 0], XW_[:, 1], results_mean + results_std, marker="_", color="green", s=15)
    
# NOTE: the other prediction plots save to the same file name, so each run
# of a plotting cell overwrites the previous image.
plt.savefig("predictions.png", dpi=240)
plt.show()