# Predicting Ground Water Levels with Kernel Regression

In [21]:
from __future__ import absolute_import, division, print_function

import os
import json
import pyro
import torch
import logging
import numpy as np
import pandas as pd
import seaborn as sns
import pyro.optim as optim
import pyro.contrib.gp as gp
import matplotlib.pyplot as plt
import pyro.distributions as dist
import matplotlib.animation as animation

from torch.distributions import constraints

from functools import partial
from pyro.infer.mcmc import NUTS
from pyro.infer.mcmc.api import MCMC
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import Image, Video
from pyro.contrib.autoguide import AutoMultivariateNormal
from pyro.infer import EmpiricalMarginal, SVI, Trace_ELBO, JitTrace_ELBO

# Seed Pyro's random number generation (a later cell re-seeds with 1).
pyro.set_rng_seed(0)

In [2]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline
# Log bare messages (no level/name prefix) at INFO and above.
logging.basicConfig(format="%(message)s", level=logging.INFO)

# Enable validation checks
pyro.enable_validation(True)
# True whenever a "CI" variable is present in the environment.
smoke_test = "CI" in os.environ
# Fail fast unless the installed Pyro version string starts with 0.4.1.
assert pyro.__version__.startswith("0.4.1")

In [3]:
# Re-seed with 1; this replaces the seed of 0 set right after the imports.
pyro.set_rng_seed(1)

In [4]:
# When CUDA is available, make newly created tensors default to CUDA floats.
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

## Helper Functions

In [5]:
def pairwise_distances(x, y=None):
    """Squared Euclidean distance between every row of ``x`` and every row of ``y``.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b`` so the whole
    matrix comes from a single matrix product.

    Args:
        x: 2-D tensor with one point per row.
        y: optional 2-D tensor with one point per row; when omitted the
            distances are taken between the rows of ``x`` itself.

    Returns:
        Tensor whose ``[i, j]`` entry is the squared distance between
        ``x[i]`` and ``y[j]``, clamped at zero so round-off can never
        produce a negative value.
    """
    x_sq = x.pow(2).sum(dim=1).unsqueeze(1)

    if y is None:
        other_t = x.t()
        other_sq = x_sq.t()
    else:
        other_t = y.t()
        other_sq = y.pow(2).sum(dim=1).unsqueeze(0)

    sq_dist = x_sq + other_sq - 2.0 * torch.mm(x, other_t)
    return torch.clamp(sq_dist, min=0.0)

In [6]:
def summary(samples):
    """Tabulate descriptive statistics for every sample site.

    Args:
        samples: mapping from site name to that site's draws, in any form
            ``pandas.DataFrame`` accepts.

    Returns:
        Dict mapping each site name to a DataFrame with one row per column
        of the site's draws and the columns mean, std and the 5/25/50/75/95
        percentiles.
    """
    quantiles = [.05, 0.25, 0.5, 0.75, 0.95]
    columns = ["mean", "std", "5%", "25%", "50%", "75%", "95%"]

    return {
        site_name: pd.DataFrame(values).describe(percentiles=quantiles).T[columns]
        for site_name, values in samples.items()
    }

In [7]:
def visualize_posterior(samples):
    """Plot the marginal posterior density of every sample site.

    One subplot is drawn per site on the smallest square grid that has room
    for all of them, so trailing axes may stay empty.

    Args:
        samples: mapping from site name to that site's posterior draws.
    """
    import math

    sites = list(samples.keys())
    if not sites:
        # Nothing to draw; a grid with zero rows cannot be created.
        return

    r = int(math.ceil(math.sqrt(len(sites))))
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid, where
    # plt.subplots would otherwise hand back a bare Axes without reshape().
    fig, axs = plt.subplots(nrows=r, ncols=r, figsize=(15, 13), squeeze=False)
    fig.suptitle("Marginal Posterior Density", fontsize=16)

    # zip stops after the last site and leaves any surplus axes untouched.
    for site, ax in zip(sites, axs.reshape(-1)):
        sns.distplot(samples[site], ax=ax)
        ax.set_title(site)

## Loading Data

In [8]:
# Load the raw observations; rows are tagged as either "well" or "farm".
data = pd.read_csv("data/sample_data.csv", encoding="ISO-8859-1")

data_wells = data.loc[data["type"] == "well"]
data_farms = data.loc[data["type"] == "farm"]

# Collect well coordinates and observations one timestep at a time, in the
# order the timesteps first appear in the file.
XW = []
YW = []
for t in data_wells["timestep"].unique():
    data_ = data_wells.loc[data_wells["timestep"] == t]

    XW += [data_[["latitude", "longitude"]].values]
    YW += [data_["observation"].values]

# Only the first timestep's coordinates are kept.
# NOTE(review): presumably every timestep lists the same wells in the same
# order - confirm against the data.
XW = XW[0]

# Farms are not split by timestep.
XF = data_farms[["latitude", "longitude"]].values
YF = data_farms["observation"].values

In [9]:
# Animate the well observations over time on top of the static farm layout.
plt.clf()
fig = plt.figure(figsize=(10, 10), dpi=100)

plt.ion()

# Farms are drawn once as small green squares and never change.
plt.scatter(XF[:, 0], XF[:, 1], marker="s", s=7, color="lightgreen")

# Wells start out black; their colour array is replaced on every frame.
scat = plt.scatter(XW[:, 0], XW[:, 1], marker="s", s=20, c=[(0, 0, 0, 1)] * len(XW))
label = plt.text(0, 0, '', fontsize=12)

# One row of grey levels per timestep: 1 - |observation| / 15, capped at 1.
colors = np.array([[min(1 - abs(x) / 15, 1) for x in obs] for obs in YW])


def update_plot(i, scat):
    """Recolour the wells for timestep ``i`` and update the text label."""
    scat.set_array(colors[i])
    label.set_text(["Sp", "Su", "Fa", "Wi"][i % 4])
    return scat,


# Frames index timesteps (rows of `colors`). XW holds the well coordinates,
# so len(XW) is the number of wells and is the wrong frame count.
anim = animation.FuncAnimation(fig, update_plot, frames=range(len(colors)), fargs=(scat,), interval=1000)

plt.gray()
plt.close()

<Figure size 432x288 with 0 Axes>

In [10]:
# Write the animation to an MP4 at one frame per second (the logged output
# below shows matplotlib's FFMpegWriter doing the encoding).
anim.save("includes/data-animation.mp4", fps=1)

Animation.save using <class 'matplotlib.animation.FFMpegWriter'>
MovieWriter.run: running command: ['ffmpeg', '-f', 'rawvideo', '-vcodec', 'rawvideo', '-s', '1000x1000', '-pix_fmt', 'rgba', '-r', '1', '-loglevel', 'error', '-i', 'pipe:', '-vcodec', 'h264', '-pix_fmt', 'yuv420p', '-y', 'includes/data-animation.mp4']


In [11]:
# Display the saved animation file in the notebook output.
Video("includes/data-animation.mp4")

In [20]:
# Convert the well and farm arrays to torch tensors, rebinding the same names.
XW, YW, XF, YF = (torch.tensor(values) for values in (XW, YW, XF, YF))

  """Entry point for launching an IPython kernel.
  
  after removing the cwd from sys.path.
  """


## Defining the Model

### Generative Model
---
**Farm Factor**
\begin{align*}
    \ln(\delta) \sim \mathcal{N}(1.0, 0.5)
\end{align*}

**Distance Factors**
\begin{align*}
    \ln(\theta_w) \sim \mathcal{N}(0.0, 0.5) \\
    \ln(\theta_f) \sim \mathcal{N}(0.0, 0.5)
\end{align*}

**Variance**
\begin{align*}
    \sigma^2 \sim \text{Gam}(1.0, 1.0)
\end{align*}

**Seasonal Factors**
For season $s \in \mathcal{S}$
\begin{align*}
    \gamma_s \sim \mathcal{N}(0.0, 1.0)
\end{align*}

**Base Water Levels**

The base water levels are modeled as a simple AR(1) process, with each season contributing its own drift term. The details are as follows:

\begin{align*}
    \mu_0 \sim \mathcal{N}(\gamma_{s_0}, 1.0)
\end{align*}
For $t = 1 \dots T$, we specify
\begin{align*}
    \mu_{t} \sim \mathcal{N}(\mu_{t - 1} + \gamma_{s_t}, 1.0)
\end{align*}

**Likelihood**

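For $t = 0 \dots T$, we specify
\begin{align*}
    \mathbf{y}_t \sim \mathcal{N}\left(\mu_t - \delta \cdot K_{\theta_f}(X_{t,w}, X_{t,f})\, \mathbf{y}_f,\ \sigma^2\right)
\end{align*}
where $K_{\theta_f}(X_{t,w}, X_{t,f})_{ij} = \exp\left(-\lVert x_{w,i} - x_{f,j} \rVert^2 / \theta_f\right)$ is the kernel between well $i$ and farm $j$, and $\mathbf{y}_f$ holds the farm observations. In the GP variant the noise term $\sigma^2$ is replaced by the covariance matrix with entries $\exp\left(-\lVert x_{w,i} - x_{w,j} \rVert^2 / \theta_w\right)$.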

---

<img src="includes/hmm-model.png" alt="drawing" width="600"/>

In [13]:
def model(XW, YW, gp=False):
    """Generative model for the well observations.

    Farm coordinates and observations are read from the module-level
    ``XF`` and ``YF``.

    Args:
        XW: well coordinates, one row per well.
        YW: well observations, indexed as ``YW[timestep][well]``.
        gp: when True, the wells at one timestep are jointly Gaussian with
            covariance ``exp(-d^2 / theta_w)``; when False, every well gets
            independent Gaussian noise governed by the ``sigma`` site.
    """
    assert not torch._C._get_tracing_state()

    delta = pyro.sample("delta", dist.LogNormal(1.0, 0.5))

    if gp:
        theta_w = pyro.sample("theta_w", dist.LogNormal(0.0, 0.5))
        # In GP mode `sigma` is the full well-by-well covariance matrix.
        sigma = torch.exp(-pairwise_distances(XW, XW) / theta_w)
    else:
        sigma = pyro.sample("sigma", dist.Gamma(1.0, 1.0))

    theta_f = pyro.sample("theta_f", dist.LogNormal(0.0, 0.5))

    # Kernel-weighted sum of the farm observations, one value per well.
    farm_factor = delta * (YF * torch.exp(-pairwise_distances(XW, XF) / theta_f)).sum(1)

    sf = pyro.sample("sf", dist.Normal(torch.zeros(4), 1.0))

    data_plate = pyro.plate("data", len(YW[0]))

    mu = 0
    for i in pyro.plate("sequence", len(YW)):
        # The base level drifts by the factor of the current season (i % 4).
        mu = pyro.sample(
            "mu_{}".format(i), dist.Normal(mu + sf[i % 4], 1.0)
        )

        mean = mu - farm_factor
        if gp:
            # The multivariate normal already covers all wells as one event,
            # so it is scored once, outside the per-well plate. Inside the
            # plate its log-density would be repeated for every well.
            pyro.sample(
                "obs_{}".format(i), dist.MultivariateNormal(mean, sigma), obs=YW[i]
            )
        else:
            # Without this branch the data never enters the model when
            # gp=False and `sigma` is sampled but unused.
            # NOTE(review): `sigma` is treated as a variance here, matching
            # the sigma^2 ~ Gam(1, 1) prior in the model description and its
            # use as a covariance in `predict`; confirm.
            with data_plate:
                pyro.sample(
                    "obs_{}".format(i), dist.Normal(mean, sigma.sqrt()), obs=YW[i]
                )

In [14]:
def predict(XW, samples, gp=False):
    """Draw posterior-predictive well observations, one array per posterior draw.

    Reads the module-level ``XF``, ``YF`` (farm data) and ``YW`` (only for
    the number of timesteps).

    Args:
        XW: tensor of well coordinates, one row per well.
        samples: mapping from site name to an array of posterior draws,
            indexed by draw along the first axis.
        gp: selects which sites are read. True uses ``theta_w`` to build a
            full covariance matrix; False uses ``sigma`` as an independent
            per-well noise variance.

    Yields:
        One array per posterior draw, with one row per timestep and one
        column per well.
    """
    pdf = pairwise_distances(XW, XF).cpu().numpy()
    YF_ = YF.cpu().numpy()

    delta = samples["delta"]
    theta_f = samples["theta_f"]
    if gp:
        pdx = pairwise_distances(XW).cpu().numpy()
        theta_w = samples["theta_w"]
    else:
        # "sigma" is only a sample site in the non-GP model, so reading it
        # unconditionally would raise KeyError for GP samples.
        sigma = samples["sigma"]

    n_steps = len(YW)
    # Rows index posterior draws, columns index timesteps.
    mu = np.stack([samples["mu_{}".format(t)] for t in range(n_steps)], axis=1)

    for i in range(len(delta)):
        # Neither the farm term nor the noise model depends on the timestep,
        # so both are computed once per posterior draw.
        farm_effect = delta[i] * (YF_ * np.exp(-pdf / theta_f[i])).sum(1)
        if gp:
            cov = np.exp(-pdx / theta_w[i])
        else:
            # A scalar is not a valid `cov` for multivariate_normal. Drawing
            # independent normals with std sqrt(sigma) is the same as using
            # the covariance sigma * I.
            noise_std = np.sqrt(sigma[i])

        draws = []
        for t in range(n_steps):
            mean = mu[i, t] - farm_effect
            if gp:
                draws.append(np.random.multivariate_normal(mean, cov))
            else:
                draws.append(np.random.normal(mean, noise_std))

        yield np.array(draws)

### Inference

In [15]:
# Passed as the `gp` flag of the model and used to pick the samples file name.
use_gp = False

In [16]:
# Cached posterior draws live in a per-variant JSON file under data/.
samples_file = "data/{}.json".format("gp-samples" if use_gp else "kr-samples")

In [17]:
# Reuse cached posterior samples when they can be read; otherwise run NUTS.
try:
    with open(samples_file, "r") as f:
        samples = {k: np.array(v) for k, v in json.load(f).items()}

except (IOError, OSError, ValueError):
    # Only a missing/unreadable cache file or invalid JSON triggers sampling.
    # Anything else (KeyboardInterrupt, programming errors) now propagates
    # instead of silently starting a long MCMC run.
    nuts_kernel = NUTS(partial(model, gp=use_gp))

    mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=400)
    mcmc_run = mcmc.run(XW, YW)

    samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

In [18]:
# Persist the samples as plain lists so they are JSON-serialisable. Write to
# `samples_file`, the same path the loading cell reads, so the GP and
# kernel-regression variants never overwrite each other's cache.
samples_ = {k: v.tolist() for k, v in samples.items()}
with open(samples_file, "w") as f:
    json.dump(samples_, f)

In [19]:
# Print the per-site statistics table, one site after another.
site_tables = summary(samples)
for site in site_tables:
    print("Site: {}".format(site))
    print(site_tables[site], "\n")

Site: delta
       mean       std        5%      25%       50%       75%       95%
0  0.692065  0.002864  0.687882  0.68975  0.691992  0.693799  0.696731 

Site: theta_f
       mean       std        5%       25%       50%       75%       95%
0  0.598756  0.003512  0.593951  0.596369  0.598588  0.600489  0.605166 

Site: sigma
       mean       std        5%      25%       50%       75%       95%
0  0.491833  0.014983  0.469867  0.47993  0.492256  0.500444  0.518304 

Site: sf
       mean       std        5%       25%       50%       75%       95%
0 -1.349452  0.297369 -1.853333 -1.539048 -1.346878 -1.126683 -0.915593
1 -1.744000  0.284042 -2.154995 -1.951336 -1.748563 -1.583667 -1.215060
2  2.565167  0.434421  1.815618  2.313591  2.557347  2.854837  3.225456
3 -0.016686  0.283367 -0.450356 -0.234661 -0.037776  0.184439  0.441216 

Site: mu_0
       mean       std        5%       25%       50%       75%       95%
0 -5.219291  0.126555 -5.423017 -5.312674 -5.208899 -5.134357 -5.017868 

