In [None]:
%load_ext autoreload
%autoreload 2


import os,sys
#sys.path.append(os.path.join(os.path.dirname("src_py/models/pivae.py"), '../'))
#print(sys.path)
sys.path.append("../src_py")


In [None]:
from tqdm import tqdm, trange
import cmdstanpy
import pandas as pd
import pickle
from models.phi import PHI
from models.vae import VAE
import torch
import numpy as np
import matplotlib.pyplot as plt 
import matplotlib as mpl

In [None]:
# cmdstanpy.install_cmdstan()

In [None]:
# Figure styling for all plots in this notebook.
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases dropped the old names, so 'seaborn-paper' alone fails on
# current matplotlib. Prefer the new name, fall back for older versions.
_paper_style = ('seaborn-v0_8-paper'
                if 'seaborn-v0_8-paper' in plt.style.available
                else 'seaborn-paper')
plt.style.use(_paper_style)
mpl.rc('axes.spines', right=False, top=False)
mpl.rc('axes', labelsize=20)
mpl.rc('xtick', labelsize=16, top=False)
mpl.rc('xtick.minor', visible=False)
mpl.rc('ytick', labelsize=16, right=False)
mpl.rc('ytick.minor', visible=False)
mpl.rc('savefig', bbox='tight', format='pdf')
mpl.rc('figure', figsize=(10, 10))
mpl.rc('legend', fontsize=16)

In [None]:
# Restore the trained phi network from its pickled state dict.
# NOTE(review): pickle.load can execute arbitrary code stored in the file --
# only load checkpoints from a trusted source.
with open('2d_gp_phi.pkl', 'rb') as f:
    weights_phi = pickle.load(f)

phi = PHI(
    2,
    alpha=1.0,
    n_centers=1000,
    hidden_dim1=100,
    hidden_dim2=100,
    out_dims=50,
)
phi.load_state_dict(weights_phi)
phi.eval()

In [None]:
# Restore the trained decoder weights; only `vae.decoder` is loaded here.
# NOTE(review): pickle.load can execute arbitrary code stored in the file --
# only load checkpoints from a trusted source.
with open('2d_gp_vae.pkl', 'rb') as f:
    weights_vae = pickle.load(f)

vae = VAE(
    input_dim=50,
    hidden_dim1=64,
    hidden_dim2=32,
    latent_dim=20,
)
vae.decoder.load_state_dict(weights_vae)
vae.decoder.eval()

In [None]:
# data for inference
df_inf = pd.read_csv('2d_gp_inf_data.csv')
x_inf = df_inf[['x1', 'x2']].to_numpy()   # 2-D inputs, shape (n, 2)
y_inf = df_inf[['y']].to_numpy()          # targets, shape (n, 1)

# Randomly mark ~20% of the inference points as "observed". The generator is
# seeded so the observation mask -- and therefore the posterior fitted on it --
# is reproducible across Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)
n_inf = y_inf.shape[0]                    # derived from the data instead of hard-coding 900
size = (n_inf, 1)
proba_0 = 0.8                             # resulting array will have 80% of zeros
idx = rng.choice([0, 1], size=size, p=[proba_0, 1 - proba_0])
ll_idx = np.where(idx)                    # (row_indices, col_indices); col_indices are all 0

# x is 2-D, so the horizontal axis here is the point index, not x itself.
fig, ax = plt.subplots()
ax.plot(y_inf, color='blue', alpha=0.5)
ax.scatter(ll_idx[0], y_inf[ll_idx], marker='+', color='red', alpha=0.5, s=100)
ax.set_xlabel('Index of inference point')
ax.set_ylabel('$y=f(x)$');

In [None]:
# Inputs for pivae_2d.stan.
# Decoder weights are transposed so each W_k is (in_features, out_features)
# (assuming these are torch nn.Linear layers, which store (out, in)). Biases
# are 1-D, so they are passed as-is: Tensor.T on non-2-D tensors is deprecated.
# NOTE(review): p / p1 / p2 / beta_dim repeat the VAE(...) / PHI(...)
# constructor arguments used earlier and must stay in sync with them and with
# the Stan program's declarations.
n_inf = y_inf.shape[0]

# phi is only used for a forward pass here; no gradients are needed.
with torch.no_grad():
    phi_x = phi(torch.tensor(x_inf).float()).detach().numpy()

stan_data = {'p': 20,
             'p1': 64,
             'p2': 32,
             'n': n_inf,
             'W1': weights_vae['linear1.weight'].T.numpy(),
             'B1': weights_vae['linear1.bias'].numpy(),
             'W2': weights_vae['linear2.weight'].T.numpy(),
             'B2': weights_vae['linear2.bias'].numpy(),
             'W3': weights_vae['out.weight'].T.numpy(),
             'B3': weights_vae['out.bias'].numpy(),
             'beta_dim': 50,
             'phi_x': phi_x,
             'y': y_inf.reshape(n_inf,),
             'll_len': ll_idx[0].shape[0],
             # Stan arrays are 1-based; numpy indices are 0-based.
             'll_idxs': ll_idx[0] + 1}

In [None]:
# Compile (or reuse the cached build of) the Stan program for the model.
stan_file = 'pivae_2d.stan'
sm = cmdstanpy.CmdStanModel(stan_file=stan_file)

In [None]:
# Run MCMC: 4 chains x (500 warmup + 2000 sampling) iterations.
# A fixed seed makes the draws -- and every summary computed from them --
# reproducible across Restart & Run All.
fit = sm.sample(data=stan_data, iter_sampling=2000, iter_warmup=500,
                chains=4, seed=1234)

In [None]:
# Collect all Stan variables from the fit, then tabulate the `y2` draws.
# cmdstanpy puts draws along the first axis, so each row is one draw
# (assuming y2 is a vector in the Stan program, each column is one element).
out = fit.stan_variables()
y2_draws = out['y2']
df = pd.DataFrame(y2_draws)

In [None]:
# NOTE(review): disabled plotting cell (true curve, observations, 95% band).
# Before re-enabling, check the last line: it plots the posterior *median*
# multiplied by 2 but labels it 'Posterior mean' -- the scaling and/or the
# label look wrong and should be confirmed.
# mpl.rc('figure', figsize=(40, 10))
# fig = plt.figure()
# ax = fig.add_subplot(111)
# ax.plot(y_inf, color='black', label='True')
# ax.scatter(ll_idx[0], y_inf[ll_idx], s=46,label = 'Observations')
# ax.fill_between(range(0,900), df.quantile(0.025).to_numpy(), df.quantile(0.975).to_numpy(),
#                     facecolor="blue",
#                     color='blue', 
#                     alpha=0.2, label = '95% Credible Interval')
# ax.plot(df.median().to_numpy()*2, color='red', alpha=0.7, label = 'Posterior mean')