# Predicting Ground Water Levels with Kernel Regression

In [1]:
from __future__ import absolute_import, division, print_function

import os
import json
import pyro
import torch
import pickle
import logging
import numpy as np
import pandas as pd
import seaborn as sns
import pyro.optim as optim
import pyro.contrib.gp as gp
import matplotlib.pyplot as plt
import pyro.distributions as dist
import matplotlib.animation as animation

from torch.distributions import constraints

from functools import partial
from pyro.infer.mcmc import NUTS
from pyro.infer.mcmc.api import MCMC
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import Image, Video
from pyro.contrib.autoguide import AutoMultivariateNormal
from pyro.infer import EmpiricalMarginal, SVI, Trace_ELBO, JitTrace_ELBO

pyro.set_rng_seed(0)

In [2]:
%matplotlib inline
# Print log records as bare messages, at INFO level and above.
logging.basicConfig(format="%(message)s", level=logging.INFO)

# Enable validation checks
pyro.enable_validation(True)
# True whenever a "CI" variable exists in the environment (its value is ignored).
smoke_test = "CI" in os.environ
# Fail fast when the installed Pyro is not a 0.4.1 release; the cells below
# were written against that version's API.
assert pyro.__version__.startswith("0.4.1")

In [3]:
pyro.set_rng_seed(1)

In [4]:
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

## Helper Functions

In [5]:
def pairwise_distances(x, y=None):
    """Return the matrix of squared Euclidean distances between rows.

    Args:
        x: 2-D tensor whose rows are points.
        y: optional 2-D tensor of points with as many columns as ``x``.
            When omitted, distances are taken between the rows of ``x``
            itself.

    Returns:
        A ``(len(x), len(y))`` tensor whose ``(i, j)`` entry is
        ``||x[i] - y[j]||^2`` (no square root is taken), clamped below
        at zero.
    """
    other = x if y is None else y

    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>, evaluated for all pairs
    # at once through broadcasting and a single matrix product.
    row_sq = (x ** 2).sum(1).view(-1, 1)
    col_sq = (other ** 2).sum(1).view(1, -1)
    cross = torch.mm(x, torch.transpose(other, 0, 1))

    sq_dist = row_sq + col_sq - 2.0 * cross

    # Floating-point cancellation can leave tiny negative entries.
    return torch.clamp(sq_dist, 0.0, np.inf)

In [6]:
def summary(samples):
    """Tabulate descriptive statistics for every sample site.

    Args:
        samples: mapping from site name to the draws for that site, in
            any form accepted by ``pd.DataFrame``.

    Returns:
        Dict mapping each site name to a DataFrame with one row per
        column of the site's draws and the columns ``mean``, ``std``,
        ``5%``, ``25%``, ``50%``, ``75%`` and ``95%``.
    """
    quantiles = [.05, 0.25, 0.5, 0.75, 0.95]
    wanted = ["mean", "std", "5%", "25%", "50%", "75%", "95%"]

    return {
        name: pd.DataFrame(draws).describe(percentiles=quantiles).transpose()[wanted]
        for name, draws in samples.items()
    }

In [7]:
def visualize_posterior(samples):
    """Plot the marginal posterior density of every sample site.

    The sites are laid out on the smallest square grid of subplots that
    can hold them, one site per subplot, in the iteration order of
    ``samples``. Surplus subplots are left empty.

    Args:
        samples: mapping from site name to the draws for that site, in
            a form accepted by ``sns.distplot``.

    Returns:
        None. The figure is created on the current matplotlib backend.
    """
    import math

    sites = list(samples.keys())
    if not sites:
        # A 0x0 grid is not a valid subplot layout, so there is nothing to draw.
        return

    side = int(math.ceil(math.sqrt(len(sites))))

    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid; without it
    # a single site yields a bare Axes object that has no `reshape`.
    fig, axs = plt.subplots(
        nrows=side, ncols=side, figsize=(15, 13), squeeze=False
    )
    fig.suptitle("Marginal Posterior Density", fontsize=16)

    for site, ax in zip(sites, axs.reshape(-1)):
        sns.distplot(samples[site], ax=ax)
        ax.set_title(site)

## Defining the Model

### Generative Model
---
**Farm Factor**
\begin{align*}
    \ln(\delta) \sim \mathcal{N}(1.0, 0.5)
\end{align*}

**Distance Factors**
\begin{align*}
    \ln(\theta_w) \sim \mathcal{N}(0.0, 0.5) \\
    \ln(\theta_f) \sim \mathcal{N}(0.0, 0.5)
\end{align*}

**Observation Noise Scale**
\begin{align*}
    \sigma \sim \text{Gam}(1.0, 1.0)
\end{align*}

**Seasonal Factors**
For season $s \in \mathcal{S}$
\begin{align*}
    \gamma_s \sim \mathcal{N}(0.0, 1.0)
\end{align*}

**Base Water Levels**

The base water levels are modeled as a simple AR(1) process with a seasonal drift term, specified as follows:

\begin{align*}
    \mu_0 \sim \mathcal{N}(\gamma_{s_0}, 1.0) \\
\end{align*}
For $t = 1 \dots T$, we specify
\begin{align*}
    \mu_{t} \sim \mathcal{N}(\mu_{t - 1} + \gamma_{s_t}, 1.0)
\end{align*}

**Likelihood**

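For $t = 0 \dots T$ and each well $w$, with $d(\cdot, \cdot)$ the squared Euclidean distance between a well and a farm, $y_{t,f}$ the observation at farm $f$, and $\sigma$ the observation noise scale, we specify
\begin{align*}
    \mathbf{y}_{t,w} \sim \mathcal{N}\left(\mu_t - \delta \sum_{f} y_{t,f} \exp\left(-d(X_{t,w}, X_{t,f})\ /\ \theta_f\right),\ \sigma\right)
\end{align*}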

---

<img src="includes/hmm-model.png" alt="drawing" width="600"/>

In [8]:
def model(XW, YW, YF, WF_distances, gp=False):
    """Pyro model of well water levels driven by nearby farm activity.

    A scalar base level ``mu_t`` evolves over time as a random walk with a
    seasonal drift; each step's well observations are centred on that base
    level minus a farm effect, which is a sum of farm values weighted by
    ``exp(-distance / theta_f)`` and scaled by ``delta``.

    Args:
        XW: per-time-step well coordinates; only read when ``gp`` is true.
        YW: per-time-step well observations; ``len(YW)`` sets the number of
            time steps.
        YF: per-time-step farm values that are multiplied with the kernel
            weights.
        WF_distances: per-time-step well-to-farm distances; the weighted
            product is summed over dim 1 (presumably the farm axis -- confirm
            against the caller).
        gp: when true, observations at a step are jointly
            ``MultivariateNormal`` with covariance
            ``exp(-pairwise_distances(XW[t], XW[t]) / theta_w)``; otherwise
            they are independent ``Normal`` with a shared scale ``sigma``.
            NOTE(review): this parameter shadows the module imported as
            ``gp`` (``pyro.contrib.gp``) inside the function.

    Sample sites: ``delta``, ``theta_f``, ``sf``, ``mu_<t>``, ``obs_<t>``,
    plus ``theta_w`` (GP variant) or ``sigma`` (non-GP variant).
    """
    # Relies on a private torch hook; asserts no tracing state is active.
    assert not torch._C._get_tracing_state()

    # Farm factor.
    delta = pyro.sample("delta", dist.LogNormal(1.0, 0.5))

    # Exactly one of theta_w / sigma is sampled, depending on the variant.
    if gp:
        theta_w = pyro.sample("theta_w", dist.LogNormal(0.0, 0.5))    
    else:
        sigma = pyro.sample("sigma", dist.Gamma(1.0, 1.0))
    
    # Length scale of the well-to-farm kernel.
    theta_f = pyro.sample("theta_f", dist.LogNormal(0.0, 0.5))
    
    # One additive seasonal factor per season; time steps cycle through them.
    n_seasons = 3
    sf = pyro.sample("sf", dist.Normal(torch.zeros(n_seasons), 1.0))

    # NOTE(review): this plate is never entered below (its only use is
    # commented out); the per-step plates "data_<t>" are used instead.
    data_plate = pyro.plate("data", len(YW[0]))
        
    # Base level before the first step; replaced by the sampled mu_t each step.
    mu = 0
    for t in pyro.markov(range(len(YW))):
        if gp:
            # In the GP variant `sigma` is a covariance matrix built from the
            # well-to-well distances rather than a scalar scale.
            sigma = torch.exp(-pairwise_distances(XW[t], XW[t]) / theta_w)
                
        # Previous base level plus this step's seasonal factor, unit scale.
        mu = pyro.sample(
            "mu_{}".format(t), dist.Normal(mu + sf[t % n_seasons], 1.0)
        )
        
        # Base level minus the kernel-weighted sum of farm values.
        mean = mu - delta * (YF[t] * torch.exp(-WF_distances[t] / theta_f)).sum(1)
        
        if gp:
            pyro.sample(
                "obs_{}".format(t), dist.MultivariateNormal(mean, sigma), obs=YW[t]
            )
        else:
            # A separate plate per step, sized by that step's observations.
            with pyro.plate("data_{}".format(t), len(YW[t])):
#             with data_plate:
                pyro.sample(
                    "obs_{}".format(t), dist.Normal(mean, sigma), obs=YW[t]
                )

In [30]:
def predict(XW, XF, YF, samples, gp=False):
    """Draw posterior-predictive well observations for every time step.

    For each time step and each posterior sample, the predictive mean is
    rebuilt the same way as in ``model`` (base level minus the
    kernel-weighted farm effect) and one noisy observation vector is drawn.

    Args:
        XW: per-time-step well coordinates (torch tensors).
        XF: per-time-step farm coordinates (torch tensors).
        YF: per-time-step farm values (torch tensors); ``len(YF)`` sets the
            number of time steps.
        samples: mapping from site name to posterior draws. Must contain
            ``delta``, ``theta_f`` and ``mu_<t>`` for every step, plus
            ``theta_w`` when ``gp`` is true or ``sigma`` otherwise.
        gp: when true, draw jointly from a multivariate normal whose
            covariance is ``exp(-well_to_well_distance / theta_w)``;
            otherwise draw independent normals with scale ``sigma``.

    Returns:
        A list with one array per time step; each array stacks one draw per
        posterior sample along its first axis.
    """
    delta = samples["delta"]
    theta_f = samples["theta_f"]

    # The two model variants sample different sites: only the GP variant has
    # "theta_w" and only the non-GP variant has "sigma".
    if gp:
        theta_w = samples["theta_w"]
    else:
        sigma = samples["sigma"]

    # Rearrange the per-step draws so the posterior-sample index comes first.
    mu = np.array(
        list(zip(*[samples["mu_{}".format(i)] for i in range(len(YF))]))
    )

    predictions = []
    for t in range(len(YF)):
        farm_values = YF[t].cpu().numpy()
        well_farm = pairwise_distances(XW[t], XF[t]).cpu().numpy()
        if gp:
            well_well = pairwise_distances(XW[t]).cpu().numpy()

        draws = []
        for i in range(len(delta)):
            mean = mu[i, t] - delta[i] * (
                farm_values * np.exp(-well_farm / theta_f[i])
            ).sum(1)

            if gp:
                # The kernel matrix is a covariance, so the draw has to be
                # joint; treating its entries as elementwise standard
                # deviations would not match the model's likelihood.
                cov = np.exp(-well_well / theta_w[i])
                draws.append(np.random.multivariate_normal(mean, cov))
            else:
                draws.append(np.random.normal(mean, sigma[i]))

        predictions.append(np.array(draws))

    return predictions

In [99]:
vals = list(predict(XW_r, XF_r, YF_r, samples))

In [100]:
vals = [x.mean(0) for x in vals]

In [103]:
vals[7]

array([-34.04215076, -32.16421496, -40.46766094, -33.21190748,
       -30.72401199, -34.06916408, -34.60869695, -35.92405091,
       -37.24654335, -37.10247457, -35.52356699, -35.22256515,
       -33.73864122, -32.91149726, -35.33260415, -30.35373844,
       -33.57927149, -35.2426191 , -37.01480677, -34.21382468,
       -35.21576545, -34.21334367, -35.9983931 , -36.81179993,
       -34.49353128, -30.97596572, -31.5107213 , -30.92422419,
       -36.18588612])

In [104]:
YW_r[7]

tensor([[ -48.0700],
        [  -9.1000],
        [  -6.4500],
        [ -17.6000],
        [ -19.2000],
        [ -25.1500],
        [ -27.2400],
        [ -30.1400],
        [ -19.7000],
        [ -62.9000],
        [ -33.8000],
        [ -54.3400],
        [  -9.4800],
        [ -59.3500],
        [ -17.7900],
        [-116.0000],
        [ -42.3800],
        [ -56.2300],
        [ -11.9600],
        [ -10.1000],
        [ -23.6200],
        [ -58.3600],
        [ -35.6300],
        [ -10.8000],
        [ -20.6500],
        [ -13.6000],
        [ -31.1300],
        [ -45.1000],
        [ -55.9500]], dtype=torch.float64)

In [90]:
XW_r[0][13]

tensor([26.2417, 71.5125], dtype=torch.float64)

In [112]:
XW_r[9][-6]

tensor([26.4917, 71.4875], dtype=torch.float64)

In [109]:
list(zip(vals[9], YW_r[9][:, 0].cpu().numpy()))

[(-37.85402207593267, -47.5),
 (-33.796590227095365, -7.48),
 (-35.82105790583092, -8.46),
 (-28.69468870735836, -16.86),
 (-36.180533210067715, -18.66),
 (-34.54529134760516, -27.1),
 (-32.11504513310349, -38.2),
 (-34.814679529821426, -27.59),
 (-32.13170505063718, -31.48),
 (-33.18659539250355, -36.6),
 (-30.36204796858018, -23.21),
 (-35.240600699269706, -33.8),
 (-30.60654545548461, -96.66),
 (-33.99185626099613, -54.6),
 (-33.06156273940035, -9.3),
 (-32.545093419191005, -60.96),
 (-36.7138313979355, -17.94),
 (-34.24093436774197, -11.6),
 (-30.27672985189944, -8.0),
 (-35.70648081151071, -23.0),
 (-27.727470908460806, -68.5),
 (-35.45589645161391, -21.4),
 (-42.757901745158954, -10.4),
 (-33.707175380736125, -20.5),
 (-31.330760264190914, -13.9),
 (-32.72187281194161, -32.9),
 (-27.10393087564684, -43.2),
 (-30.06646489546819, -57.2)]

In [83]:
y.shape

torch.Size([29, 1])

In [58]:
vals = vals.transpose(1, 0, 2)

AttributeError: 'list' object has no attribute 'transpose'

In [59]:
for i in range(len(vals)):
    print(vals.mean(0)[i], y[i])

AttributeError: 'list' object has no attribute 'mean'

In [None]:
vals.mean(0)

## Loading Data

In [None]:
# Load the sample data set and split its rows by the `type` column into
# well rows and farm rows.
data = pd.read_csv("data/sample-data/data.csv", encoding="ISO-8859-1")

data_wells = data[data.type == "well"]
data_farms = data[data.type == "farm"]

# Collect well coordinates and observations per time step, in the order the
# time steps first appear in the file.
XW, YW = [], []
for t in data_wells["timestep"].unique():
    data_ = data_wells[data_wells["timestep"] == t]

    XW.append(data_[["latitude", "longitude"]].values)
    YW.append(data_["observation"].values)
    
# Keep only the first time step's coordinates
# (presumably the wells do not move in the sample data -- confirm).
XW = XW[0]

# Farm coordinates and observations are taken across all farm rows at once,
# without grouping by time step.
XF = data_farms[["latitude", "longitude"]].values
YF = data_farms["observation"].values

In [None]:
plt.clf()
fig = plt.figure(figsize=(10, 10), dpi=100)

plt.ion()

# Farms are drawn once as small light-green squares.
plt.scatter(XF[:, 0], XF[:, 1], marker="s", s=7, color="lightgreen")

# Wells start out black and are recoloured on every frame by update_plot.
scat = plt.scatter(XW[:, 0], XW[:, 1], marker="s", s=20, c=[(0, 0, 0, 1)] * len(XW))
label = plt.text(0, 0, '', fontsize=12)

# One row of colour values per time step: 1 - |observation| / 15, capped at 1.
colors = []
for obs in YW:
    colors.append([min(1 - abs(x) / 15, 1) for x in obs])

colors = np.array(colors)

def update_plot(i, scat):
    """Recolour the well markers for frame ``i`` and update the season label."""
    scat.set_array(colors[i])
    label.set_text(["Sp", "Su", "Fa", "Wi"][i % 4])
    return scat,

# There is one frame per time step, i.e. per row of `colors`. XW holds the
# per-well coordinates here, so its length is the number of wells and must
# not be used as the frame count.
anim = animation.FuncAnimation(fig, update_plot, frames=range(len(colors)), fargs=(scat,), interval=1000)

plt.gray()
plt.close()

In [None]:
anim.save("includes/sample-data-animation.mp4", fps=1)

In [None]:
Video("includes/sample-data-animation.mp4")

In [None]:
XW = torch.tensor(XW)
YW = torch.tensor(YW)

XF = torch.tensor(XF)
YF = torch.tensor(YF)

In [None]:
timesteps = len(YW)

XW = XW.repeat(timesteps, 1, 1)

YF = YF.repeat(timesteps, 1, 1)
XF = XF.repeat(timesteps, 1, 1)

## Loading Real Data

In [91]:
# The file holds four objects pickled back to back; they are read in the
# order farm coordinates, farm values, well coordinates, well observations,
# and every element of each object is converted to a NumPy array.
# NOTE(review): pickle.load executes arbitrary code from the file -- only use
# this with a dataset file from a trusted source.
with open("data/dataset.pkl", "rb") as f:
    XF_r = [np.array(x) for x in pickle.load(f)]
    YF_r = [np.array(x) for x in pickle.load(f)]
                        
    XW_r = [np.array(x) for x in pickle.load(f)]
    YW_r = [np.array(x) for x in pickle.load(f)]

In [92]:
plt.clf()
fig = plt.figure(figsize=(10, 10), dpi=100)

plt.ion()

# Initial frame: farms as small light-green squares, wells as black squares,
# both placed at their first-time-step coordinates.
scat_f = plt.scatter(XF_r[0][:, 0], XF_r[0][:, 1], marker="s", s=7, color="lightgreen")

scat_w = plt.scatter(XW_r[0][:, 0], XW_r[0][:, 1], marker="s", s=20, c=[(0, 0, 0, 1)] * len(XW_r[0]))
label = plt.text(0, 0, '', fontsize=12)

# Disabled colour precomputation carried over from the sample-data animation;
# the colours are computed per frame inside update_plot instead.
# colors = []
# for obs in YW:
# #     min_v = min(obs)
# #     max_v = max(obs)
# #     colors.append([max((x - min_v) / (max_v - min_v), 0.1) for x in obs])
#     colors.append([min(1 - abs(x) / 15, 1) for x in obs])

# colors = np.array(colors)

def update_plot(i, scat_w, scat_f):
    """Move both scatters to time step ``i`` and recolour the wells.

    Well colour values are ``1 - |observation| / 50``, capped at 1, using
    the first component of each entry of ``YW_r[i]``.
    """
    scat_w.set_offsets(XW_r[i])
    scat_w.set_array(np.array([min(1 - abs(x[0]) / 50, 1) for x in YW_r[i]]))
    
    scat_f.set_offsets(XF_r[i])
    return scat_w, scat_f

# One frame per entry of XW_r, shown at one-second intervals.
anim = animation.FuncAnimation(fig, update_plot, frames=range(len(XW_r)), fargs=(scat_w, scat_f), interval=1000)

# Switch the default colormap to grayscale, then close the figure.
plt.gray()
plt.close()

<Figure size 432x288 with 0 Axes>

In [93]:
anim.save("includes/data-animation.mp4", fps=1)

Animation.save using <class 'matplotlib.animation.FFMpegWriter'>
MovieWriter.run: running command: ['ffmpeg', '-f', 'rawvideo', '-vcodec', 'rawvideo', '-s', '1000x1000', '-pix_fmt', 'rgba', '-r', '1', '-loglevel', 'error', '-i', 'pipe:', '-vcodec', 'h264', '-pix_fmt', 'yuv420p', '-y', 'includes/data-animation.mp4']


In [108]:
Video("includes/data-animation.mp4")

In [None]:
26.4917, 71.4875]

In [94]:
# Convert every per-time-step array to a torch tensor.
XF_r = [torch.tensor(x) for x in XF_r]
# Only the first element along the leading axis of each farm-value array is
# kept (presumably dropping a singleton leading dimension -- confirm against
# the dataset layout).
YF_r = [torch.tensor(x[0]) for x in YF_r]

XW_r = [torch.tensor(x) for x in XW_r]
YW_r = [torch.tensor(x) for x in YW_r]

## Inference

In [15]:
use_gp = False

In [None]:
samples_file = "data/sample-data/" + ("gp-samples" if use_gp else "kr-samples") + ".json"

In [53]:
WF_distances = [pairwise_distances(XW_r[i], XF_r[i]) for i in range(len(YW_r))]

In [54]:
# Bind the precomputed well-to-farm distances and the model variant, so the
# kernel only needs the remaining positional arguments (XW, YW, YF) at run time.
nuts_kernel = NUTS(partial(model, WF_distances=WF_distances, gp=use_gp), max_plate_nesting=1)

# 400 warm-up iterations followed by 100 retained samples.
mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=400)
mcmc_run = mcmc.run(XW_r, YW_r, YF_r)

# Move every site's draws to the CPU as NumPy arrays for summarising and saving.
samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

sample: 100%|██████████| 500/500 [01:19<00:00,  6.27it/s, step size=3.15e-01, acc. prob=0.923]


In [18]:
samples_ = {k: v.tolist() for k, v in samples.items()}
with open("data/real-data/kr-samples.json", "w") as f:
    json.dump(samples_, f)

In [None]:
WF_distances = [pairwise_distances(XW[i], XF[i]) for i in range(len(YW))]

In [None]:
nuts_kernel = NUTS(partial(model, WF_distances=WF_distances, gp=use_gp))

mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=400)
mcmc_run = mcmc.run(XW, YW, YF)

samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

In [None]:
samples_ = {k: v.tolist() for k, v in samples.items()}
with open("data/kr-samples2.json", "w") as f:
    json.dump(samples_, f)

In [96]:
with open("data/real-data/kr-samples.json", "r") as f:
        samples = {k: np.array(v) for k, v in json.load(f).items()}

In [None]:
# Reuse cached posterior samples when the cache file exists and parses;
# otherwise rerun inference on the sample data.
try:
    with open(samples_file, "r") as f:
        samples = {k: np.array(v) for k, v in json.load(f).items()}

except (IOError, ValueError):
    # IOError covers a missing or unreadable cache file and ValueError covers
    # malformed JSON; anything else is a real error and should surface.
    # The model needs the well-to-farm distances and the farm values, so
    # build and pass them here exactly as the explicit inference cells do.
    WF_distances = [pairwise_distances(XW[i], XF[i]) for i in range(len(YW))]

    nuts_kernel = NUTS(partial(model, WF_distances=WF_distances, gp=use_gp))

    mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=400)
    mcmc_run = mcmc.run(XW, YW, YF)

    samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

In [98]:
for site, values in summary(samples).items():
    print("Site: {}".format(site))
    print(values, "\n")

Site: delta
       mean       std        5%       25%       50%       75%       95%
0  0.000219  0.000072  0.000126  0.000165  0.000201  0.000247  0.000356 

Site: sigma
        mean       std         5%        25%        50%        75%        95%
0  23.180259  0.131626  22.958813  23.097648  23.183979  23.267464  23.403579 

Site: theta_f
       mean       std        5%       25%       50%       75%       95%
0  0.000265  0.000086  0.000132  0.000203  0.000262  0.000314  0.000401 

Site: sf
       mean       std        5%       25%       50%       75%       95%
0 -3.138915  0.499943 -4.043172 -3.460290 -3.067657 -2.799999 -2.419570
1 -1.147099  0.480017 -2.015271 -1.437748 -1.138173 -0.858489 -0.361831
2  1.178620  0.531882  0.373820  0.843479  1.159863  1.544441  1.960773 

Site: mu_0
        mean       std        5%        25%        50%        75%        95%
0 -21.714631  0.541971 -22.63309 -22.088746 -21.655123 -21.364281 -20.847652 

Site: mu_1
        mean       std         5%  

In [None]:
for site, values in summary(samples).items():
    print("Site: {}".format(site))
    print(values, "\n")