In [None]:
import torch
import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist
import arviz
import os
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import pyro.contrib.examples.util  # patches torchvision
from pyro.contrib.examples.util import MNIST
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def f(x):
    """Test objective sin(20x) + 2*cos(14x) - 2*sin(6x), applied elementwise to tensor ``x``."""
    fast_term = torch.sin(20 * x)
    mid_term = 2 * torch.cos(14 * x)
    slow_term = 2 * torch.sin(6 * x)
    return fast_term + mid_term - slow_term

In [None]:
# Initial design: five evenly spaced inputs on [-1, 1] and their objective values.
X = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
y = f(X)
y

In [None]:
# RBF kernel with log-normal hyperpriors on variance and lengthscale; the GP's
# observation noise is fixed to a tiny constant (1e-4) rather than learned.
kernel = gp.kernels.RBF(input_dim=1)
kernel.variance = pyro.nn.PyroSample(dist.LogNormal(-1.0, 1.0))
kernel.lengthscale = pyro.nn.PyroSample(dist.LogNormal(0.0, 2.0))
gpr = gp.models.GPRegression(X, y, kernel, noise=torch.tensor(1e-4))

In [None]:
def f_model(x, gp_model=None):
    """Posterior-predictive model of the objective at inputs ``x``.

    Intended to be run under ``pyro.infer.Predictive`` with hyperparameter draws
    from MCMC; the kernel's "kernel.variance" / "kernel.lengthscale" sample
    sites are then conditioned on those draws.

    Args:
        x: query inputs, passed straight to the GP's forward call.
        gp_model: GP regression module to predict with. Defaults to the
            notebook-level ``gpr`` so existing ``f_model(x)`` calls keep working.

    Sample sites:
        "loc": predictive mean (Delta).
        "var": marginal predictive variance; ``noiseless=False`` so it includes
               the observation noise (Delta).
        "f":   one draw from Normal(loc, sqrt(var)).
    """
    model = gpr if gp_model is None else gp_model
    # Re-assert the same hyperpriors used when the kernel was built.
    model.kernel.variance = pyro.nn.PyroSample(dist.LogNormal(-1.0, 1.0))
    model.kernel.lengthscale = pyro.nn.PyroSample(dist.LogNormal(0.0, 2.0))
    loc, var = model(x, full_cov=False, noiseless=False)
    pyro.sample("loc", dist.Delta(loc))
    pyro.sample("var", dist.Delta(var))
    # dist.Normal is parameterized by the standard deviation, not the variance.
    pyro.sample("f", dist.Normal(loc, var.sqrt()))

In [None]:
# MCMC configuration for the NUTS run below.
C = 4    # number of chains
W = 100  # warmup steps per chain
SEED = 0

# Seed the RNGs (pyro.set_rng_seed covers torch, numpy and `random`) so the
# stochastic analysis is repeatable.
# NOTE(review): multi-chain MCMC runs in worker processes; confirm per-chain
# seeding is fully deterministic in your pyro version.
pyro.set_rng_seed(SEED)

In [None]:
# Sample the GP hyperparameters with NUTS: C chains, W warmup steps, 500 retained draws per chain.
nuts_kernel = pyro.infer.NUTS(gpr.model, jit_compile=True)
mcmc = pyro.infer.MCMC(
    nuts_kernel,
    warmup_steps=W,
    num_samples=500,
    num_chains=C,
)
mcmc.run()

In [None]:
data = arviz.from_pyro(mcmc)
# 95% credible intervals reported as highest-density intervals (HDI).
summary = arviz.summary(data, hdi_prob=0.95)
# Rich HTML table instead of a plain-text print of the DataFrame.
display(summary)
# Trailing `;` suppresses the repr of the returned axes array.
arviz.plot_posterior(data, hdi_prob=0.95);

In [None]:
# All retained draws with chains flattened: one tensor per site, draws stacked on dim 0.
posterior_samples = mcmc.get_samples(group_by_chain=False)

In [None]:
# Single query input at which to inspect the posterior predictive.
x = torch.tensor([-0.25])

In [None]:
# Run f_model once per posterior hyperparameter draw, evaluated at the query point x.
posterior_predictive = pyro.infer.Predictive(
    f_model,
    posterior_samples=posterior_samples,
)(x)


In [None]:
# Predictive mean at x, one entry per posterior hyperparameter draw.
posterior_predictive["loc"]

In [None]:
def _gp_predictive_model(gp_model, x_new):
    """Posterior-predictive sample sites of ``gp_model`` at ``x_new``.

    Meant to run under ``pyro.infer.Predictive`` with MCMC hyperparameter draws,
    which condition the kernel's hyperparameter sample sites. Records "loc"
    (predictive mean), "var" (marginal predictive variance; ``noiseless=False``
    so observation noise is included) and one draw "f".
    """
    loc, var = gp_model(x_new, full_cov=False, noiseless=False)
    pyro.sample("loc", dist.Delta(loc))
    pyro.sample("var", dist.Delta(var))
    # dist.Normal takes a standard deviation, not a variance.
    pyro.sample("f", dist.Normal(loc, var.sqrt()))


def algo1(X, y, XNew, T, C=4, W=100):
    """Sequential GP-based minimization of the objective ``f`` over a candidate grid.

    Each of the ``T`` iterations:
      1. fits the GP hyperparameters to the data gathered so far with NUTS
         (``C`` chains, ``W`` warmup steps, one retained draw per chain);
      2. evaluates the GP posterior predictive on ``XNew`` for each draw,
         conditioned on the current data;
      3. picks the candidate with the smallest chain-averaged draw of f,
         evaluates ``f`` there and appends the new point to the data.

    Args:
        X, y: initial 1-D inputs and objective values.
        XNew: 1-D tensor of candidate inputs.
        T: number of acquisition iterations.
        C: number of MCMC chains (= hyperparameter draws per iteration).
        W: warmup steps per chain.

    Returns:
        (fs_list, mean_list, var_list, min_list), each of length ``T``: the
        averaged predictive draws, predictive means and predictive variances on
        ``XNew``, and the ``[x, f(x)]`` pair (0-d tensors) chosen each iteration.
    """
    kernel = gp.kernels.RBF(input_dim=1)
    kernel.variance = pyro.nn.PyroSample(dist.LogNormal(-1.0, 1.0))
    kernel.lengthscale = pyro.nn.PyroSample(dist.LogNormal(0.0, 2.0))
    X_k, y_k = X, y
    fs_list, mean_list, var_list, min_list = [], [], [], []
    for _ in range(T):
        # Refit to all data gathered so far (the kernel object is shared across iterations).
        gpr = gp.models.GPRegression(X_k, y_k, kernel, noise=torch.tensor(1e-4))
        nuts_kernel = pyro.infer.NUTS(gpr.model, jit_compile=True)
        mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=1,
                               num_chains=C, warmup_steps=W)
        mcmc.run()
        posterior_samples = mcmc.get_samples()
        # Pass this iteration's GP explicitly. The notebook-level f_model reads
        # the global `gpr`, which is fit to the initial (X, y) only, so it would
        # ignore every point acquired so far.
        posterior_predictive = pyro.infer.Predictive(
            _gp_predictive_model, posterior_samples=posterior_samples)(gpr, XNew)
        fs = posterior_predictive["f"].mean(dim=0)
        fs_list.append(fs)
        mean_list.append(posterior_predictive["loc"].mean(dim=0))
        var_list.append(posterior_predictive["var"].mean(dim=0))
        idx = torch.argmin(fs)
        X_min = XNew[idx]
        y_min = f(X_min)
        min_list.append([X_min, y_min])
        X_k = torch.cat((X_k, X_min.reshape(1)))
        y_k = torch.cat((y_k, y_min.reshape(1)))
        pyro.clear_param_store()
    return fs_list, mean_list, var_list, min_list

In [None]:
# Candidate grid the GP is queried on, and the number of acquisition rounds.
XNew = torch.linspace(-1.0, 1.0, 200)
T = 30

In [None]:
%%capture
fs_list, mean_list, var_list, min_list = algo1(X, y, XNew, T)

In [None]:
# Objective value f(x) at the point acquired in each algo1 iteration, plus the
# best (lowest) value found so far.
acquired_values = np.array([float(y_val) for _, y_val in min_list])
fig, ax = plt.subplots()
ax.plot(acquired_values, marker="o", label="f at acquired point")
ax.plot(np.minimum.accumulate(acquired_values), label="best so far")
ax.set(xlabel="iteration", ylabel="f(x)", title="algo1: objective values of acquired points")
ax.legend()
plt.show()

In [None]:

def vae_model(x):
    """Generative model with a Bayesian MLP decoder (2 -> 100 -> 400 -> 784).

    All decoder weights and biases get independent N(0, 1) priors. For each of
    the ``x.shape[0]`` data points, a 2-D latent code is drawn from N(0, I),
    decoded, and the data point's 784 values are scored as independent
    Bernoulli observations.

    Args:
        x: batch of observations whose rows reshape to 784 values each
           (presumably flattened 28x28 MNIST images — confirm).

    NOTE(review): the weight sites here ("mw1", "mb1", ...) are named
    differently from those sampled in ``vae_guide`` ("gw1", "gb1", ...). An SVI
    guide is expected to cover the model's latent sites by name.
    """
    w1 = pyro.sample("mw1", dist.Normal(0, 1).expand([2, 100]).to_event(2))
    b1 = pyro.sample("mb1", dist.Normal(0, 1).expand([100]).to_event(1))
    w2 = pyro.sample("mw2", dist.Normal(0, 1).expand([100, 400]).to_event(2))
    b2 = pyro.sample("mb2", dist.Normal(0, 1).expand([400]).to_event(1))
    w3 = pyro.sample("mw3", dist.Normal(0, 1).expand([400, 784]).to_event(2))
    b3 = pyro.sample("mb3", dist.Normal(0, 1).expand([784]).to_event(1))
    with pyro.plate("data", x.shape[0]):
        z = pyro.sample("latent", dist.Normal(0, 1).expand([2]).to_event(1))
        h1 = torch.relu(z @ w1 + b1)
        h2 = torch.relu(h1 @ w2 + b2)
        pixel_logits = h2 @ w3 + b3
        # Same binarization as before: the int cast truncates toward zero, so for
        # [0, 1]-scaled pixels only exact 1.0 survives. The result is cast back
        # to floating point because Bernoulli log_prob expects float values.
        # TODO confirm: a threshold such as (x > 0.5) may be what was intended.
        obs = x.reshape(-1, 784).to(torch.int).to(pixel_logits.dtype)
        # Parameterize by logits: sigmoid probabilities saturate to exactly 0/1.
        pyro.sample("obs", dist.Bernoulli(logits=pixel_logits).to_event(1), obs=obs)

In [None]:
def vae_guide(x):
    """Intended variational guide (encoder) for ``vae_model``.

    Maps each input row (784 values) through an MLP 784 -> 400 -> 100 and two
    2-D heads, giving the loc and scale of the Normal over the "latent" site.

    NOTE(review): as written, this guide cannot be trained or paired with ``vae_model``:
      * The encoder weights are sampled from fixed N(0, 1) distributions. There
        are no ``pyro.param`` / ``pyro.module`` parameters, so SVI has nothing
        to optimize.
      * The weight sites are named "gw*"/"gb*", while ``vae_model`` samples
        "mw*"/"mb*". Only the "latent" site matches between model and guide.
      * ``torch.relu`` on ``z_loc`` forces the latent mean to be >= 0.
      * ``exp(relu(...))`` forces ``z_scale >= 1``.
    TODO confirm intent (e.g. make encoder weights parameters, drop the relus on the heads).
    """
    # Encoder weights/biases: layer 784 -> 400.
    w1=pyro.sample("gw1", dist.Normal(0, 1).expand([784, 400]).to_event(2))
    b1=pyro.sample("gb1", dist.Normal(0, 1).expand([400]).to_event(1))
    # Layer 400 -> 100.
    w2=pyro.sample("gw2", dist.Normal(0, 1).expand([400, 100]).to_event(2))
    b2=pyro.sample("gb2", dist.Normal(0, 1).expand([100]).to_event(1))
    # Head for the latent loc: 100 -> 2.
    w3=pyro.sample("gw3", dist.Normal(0, 1).expand([100, 2]).to_event(2))
    b3=pyro.sample("gb3", dist.Normal(0, 1).expand([2]).to_event(1))
    # Head for the latent scale: 100 -> 2.
    w4=pyro.sample("gw4", dist.Normal(0, 1).expand([100, 2]).to_event(2))
    b4=pyro.sample("gb4", dist.Normal(0, 1).expand([2]).to_event(1))
    # Flatten each input to 784 values; the plate size below uses the reshaped row count.
    x = x.reshape(-1, 784)
    h1 = torch.relu((x @ w1) +b1)
    h2 = torch.relu((h1 @ w2) + b2)
    z_loc = torch.relu((h2 @ w3) + b3)
    # exp of a non-negative value, so every scale is >= 1.
    z_scale = torch.exp(torch.relu(h2 @ w4 + b4))
    with pyro.plate("data", x.shape[0]):
        pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))