In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

import inspect
import pandas as pd
import autograd.numpy as np
import matplotlib.pyplot as plt

from utils.games import WetChicken2D
from utils.models import BNN_LV, BayesianModel, SamplerModel
from utils.training import HMC

# Import helpers for building Weights & Biases callbacks:
from utils.training import build_wb_callback_postpred, build_wb_callback_plotfunc


In [None]:
# Simulate the Wet Chicken 2D environment to collect data:
env = WetChicken2D(L=5, W=3, max_steps=20, seed=207)


def no_action_policy(state):
    """Return the "do nothing" action (0, 0) for any state.

    The `state` argument is accepted so the function matches the policy
    call signature (one state in, one action out) but is otherwise ignored.
    """
    return (0, 0)


random_policy = None  # If the policy is None, the simulator chooses an action at random.

# Run 100 episodes under the "do nothing" policy.
env.run(episodes=100, progress=100, policy=no_action_policy)


In [None]:
# Extract the transition dataset from the environment that was run in the previous cell.
# NOTE(review): presumably a pandas DataFrame (the next cell selects columns by name
# and calls `.to_numpy()` on the result) — confirm against `WetChicken2D`.
transition_dataset = env.extract_transition_dataset()
# Bare expression on the last line so the notebook renders the table as the cell output.
transition_dataset


In [None]:
# Split the transition table into model inputs (start state + action) and targets (resulting state).
input_columns = ['start_x', 'start_y', 'action_x', 'action_y']
target_columns = ['result_x', 'result_y']

inputs_frame = transition_dataset[input_columns]
targets_frame = transition_dataset[target_columns]
X_train = inputs_frame.to_numpy()
Y_train = targets_frame.to_numpy()

# 1-D grid of 100 evenly spaced points on [-6, 6].
# NOTE(review): only referenced by the commented-out W&B callback further down, and it is
# 1-D whereas X_train has 4 columns — confirm this is the intended test input.
X_test = np.linspace(-6, 6, 100)


In [None]:
# Configure a Bayesian Neural Network with Latent Variable (BNN_LV):
L = 1  # Number of latent inputs (length of the 'gamma' list below).
N, M = X_train.shape  # N training rows, M input features.
_, K = Y_train.shape  # K output dimensions.
gamma = 1.0  # Standard deviation of noise for each latent input.
sigma = 1.0  # Standard deviation of noise on each model output.

# Network architecture:
architecture = dict(
    input_n=M,  # 4 inputs.
    output_n=K,  # 2 outputs.
    hidden_layers=[20, 20],
    biases=[1, 1, 1],
    activations=['relu', 'relu', 'linear'],
    gamma=[gamma] * L,  # One noise scale per latent input.
    sigma=[sigma] * K,  # One noise scale per output.
    seed=207,
)

# Build the network from the architecture spec:
bnn_lv = BNN_LV(architecture=architecture)

# Total number of weights in the network:
D = bnn_lv.D

# # Train network to get MLE estimate as starting point for sampler:
# bnn_lv.fit(X_train, Y_train, step_size=0.01, max_iteration=5000, check_point=500, regularization_coef=None)


In [None]:
# Settings for the Bayesian model (with a posterior on W and Z):
bayesian_model_settings = {
    'X': X_train,
    'Y': Y_train,
    'nn': bnn_lv,
    'prior_weights_mean': 0,
    'prior_weights_stdev': 5.0,
    'prior_latents_mean': 0,
    'prior_latents_stdev': gamma,  # Same scale as the latent-input noise.
    # NOTE(review): likelihood_stdev (0.25) differs from output_noise_stdev (sigma) —
    # confirm against BayesianModel that these are meant to be distinct quantities.
    'likelihood_stdev': 0.25,
    'output_noise_stdev': sigma,
    'label': 'Wet Chicken',
}
bayesian_model = BayesianModel(**bayesian_model_settings)

# Wrap the model so that it takes a single input (`samples`) that stores both W and Z:
sampler_model = SamplerModel(bayesian_model)

# Show the wrapped model's summaries:
sampler_model.display()
sampler_model.describe()
sampler_model.info()


In [None]:
# Log-posterior over the stacked (W, Z) samples; this is the target density for HMC:
log_posterior = sampler_model.log_posterior

# Use the network's current weights as the starting point for the sampler.
# NOTE: the `bnn_lv.fit(...)` call earlier in the notebook is commented out, so unless it is
# re-enabled these are the network's initial weights, not a fitted MLE estimate.
mle_weights = bnn_lv.get_weights()

# Concatenate starting values for W and Z into a single init vector:
W_init = mle_weights.reshape(1, -1)  # All network weights as a single row.
# One row per training point with L latent inputs each, started at the latent prior mean (0).
# Using L (rather than a hard-coded 1) keeps this consistent with the 'gamma' list in `architecture`.
Z_init = np.zeros((N, L))
position_init = sampler_model.stack(W_init, Z_init)

# # Build a callback that produces a scatter plot using W&B built-in functions:
# wb_callback_postpred = build_wb_callback_postpred(sampler_model, x_data=X_test, interval=200)

# Define W&B settings:
wb_settings = {
    'entity' : 'gpestre',
    'project' : 'am207',
    'group' : 'chicken_hmc',
    'name' : 'chicken_hmc_v1',
    # Describes this run; previously a stale note copied from a different (toy dataset) experiment.
    'notes' : 'HMC on Wet Chicken 2D transition dataset.',
    'progress' : 10,
    'base_path' : '../data/',
    'filename' : 'temp_hmc_state.json',
    'archive' : {  # Manually archive info about network and priors.
        'architecture' : architecture,
        'N' : N,
        'M' : M,
        'K' : K,
        'L' : L,
        'D' : D,
        'gamma' : gamma,
        'sigma' : sigma,
        'position_init' : position_init,
    },
    #'callback' : [wb_callback_postpred],
}

# Sample from HMC:
hmc = HMC(
    log_target_func=log_posterior, position_init=position_init,
    total_samples=100, burn_in=0.5, thinning_factor=1,
    leapfrog_steps=10, step_size=1e-5, mass=1.0, random_seed=207,
    progress=5, wb_settings=wb_settings,
)
hmc.sample()
