In [48]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import torch
from hydra.utils import instantiate
from hydra import initialize, compose
import hydra

import wandb

from data.dataManager import DataManager
from model.modelCreator import ModelCreator
from omegaconf import OmegaConf
from scripts.run import setup_model, load_model_instance

from utils.plots import vae_plots
from utils.rbm_plots import plot_rbm_histogram, plot_rbm_params, plot_forward_output_v2

from scripts.run import set_device

In [None]:
# Clear the global Hydra instance before initializing again (presumably so
# this cell can be re-run in the same kernel), then compose config.yaml from
# the "config" directory.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path="config")
config = compose(config_name="config.yaml")

# The composed config is converted to a plain container (with interpolations
# resolved) before being handed to wandb. The run is created with
# mode='disabled'.
wandb.init(
    tags=[config.data.dataset_name],
    project=config.wandb.project,
    entity=config.wandb.entity,
    config=OmegaConf.to_container(config, resolve=True),
    mode='disabled',
)

In [None]:
# Toggle: True -> build via setup_model(config); False -> load via
# load_model_instance(config). The result is bound to the notebook-level name
# `self`, which the following cells use.
new_model = False
self = setup_model(config) if new_model else load_model_instance(config)
# Optional: self.model = self.model.double()  # sets all model parameters to float64


In [None]:
# Run evaluation on the validation loader. Later cells read
# self.prior_samples, self.post_samples and self.post_logits — presumably
# populated by this call (TODO confirm). The meaning of the second positional
# argument (0) is defined by evaluate_vae and is not visible here.
self.evaluate_vae(self.data_mgr.val_loader, 0)

In [None]:
# A notebook cell only displays its last expression, so two bare `.shape`
# lines would show only the second. Return both as a tuple:
# (prior_samples shape, post_samples shape).
self.prior_samples.shape, self.post_samples.shape

In [None]:
def sigmoid_C_k(weights_ax, weights_bx, weights_cx,
                pa_state, pb_state, pc_state, bias_x) -> torch.Tensor:
    """Sigmoid of the summed linear inputs from partitions a, b and c into x.

    Evaluates ``sigmoid(pa_state @ weights_ax + pb_state @ weights_bx
    + pc_state @ weights_cx + bias_x)`` and returns the result detached from
    the autograd graph.

    :param weights_ax (torch.Tensor) : (n_nodes_a, n_nodes_x)
    :param weights_bx (torch.Tensor) : (n_nodes_b, n_nodes_x)
    :param weights_cx (torch.Tensor) : (n_nodes_c, n_nodes_x)
    :param pa_state (torch.Tensor) : (batch_size, n_nodes_a)
    :param pb_state (torch.Tensor) : (batch_size, n_nodes_b)
    :param pc_state (torch.Tensor) : (batch_size, n_nodes_c)
    :param bias_x (torch.Tensor) : (n_nodes_x)
    :return (torch.Tensor) : (batch_size, n_nodes_x), detached
    """
    # One linear contribution per source partition; summed left to right.
    from_a = pa_state @ weights_ax
    from_b = pb_state @ weights_bx
    from_c = pc_state @ weights_cx
    pre_activation = from_a + from_b + from_c + bias_x
    return pre_activation.sigmoid().detach()

In [None]:
# Only the first `batch_size_tr` rows are compared.
batch_rows = self._config.data.batch_size_tr

# p0: first 302 columns of the posterior samples.
# p1..p3: posterior logits split at columns 302 and 604.
# NOTE(review): 302/604 are hard-coded partition boundaries — confirm they
# match the model's partition size.
# NOTE(review): p1..p3 come from post_logits yet are used below as states and
# later compared with sigmoid outputs — confirm post_logits is the intended
# quantity.
p0 = self.post_samples[:batch_rows, :302].to(self._device)
p1, p2, p3 = (
    self.post_logits[:batch_rows, cols].to(self._device)
    for cols in (slice(None, 302), slice(302, 604), slice(604, None))
)

# Prior coupling weights, keyed by the pair of partitions they connect.
W01, W02, W03, W12, W13, W23 = (
    self.model.prior.weight_dict[pair]
    for pair in ('01', '02', '03', '12', '13', '23')
)
# Transposes, computed once, for using a coupling in the reverse direction.
W12_T, W13_T, W23_T = W12.T, W13.T, W23.T

# Prior biases of partitions 1-3.
b1, b2, b3 = (self.model.prior.bias_dict[key] for key in ('1', '2', '3'))

# For each partition, the sigmoid activation computed from the other three
# (weights and states are passed in matching order).
p3_ans = sigmoid_C_k(W03, W13, W23, p0, p1, p2, b3)
p2_ans = sigmoid_C_k(W02, W12, W23_T, p0, p1, p3, b2)
p1_ans = sigmoid_C_k(W01, W12_T, W13_T, p0, p2, p3, b1)

In [None]:
# Histogram of the per-row mean squared error between each p_k and the
# corresponding p_k_ans (mean taken over dim=1).
plt.hist((p1 - p1_ans).pow(2).mean(dim=1).cpu(), label='p1')
plt.hist((p2 - p2_ans).pow(2).mean(dim=1).cpu(), label='p2', alpha=0.7)
plt.hist((p3 - p3_ans).pow(2).mean(dim=1).cpu(), label='p3', alpha=0.5)
plt.xlabel("Mean Squared Error")
plt.ylabel("Frequency")
# The label= arguments above are only rendered once a legend is drawn;
# without it the three overlaid histograms cannot be told apart.
plt.legend()
plt.show()