In [None]:
import os, sys
# Put the directory two levels up on the import path so that project-local
# packages (src, stan, etc.) can be imported from this notebook.
sys.path.append(os.path.abspath("../.."))


import tensorflow as tf
import bayesflow as bf
import numpy as np
import matplotlib.pyplot as plt


from tensorflow.keras.layers import Dense, Lambda
from tensorflow.keras.models import Sequential

from bayesflow.trainers import Trainer
from bayesflow.amortizers import AmortizedPosterior
from bayesflow.networks import InvertibleNetwork
from bayesflow.summary_networks import DeepSet, HierarchicalNetwork, SetTransformer

from src.networks import AmortizedMixture, AmortizedPosteriorMixture
from src.models.MixtureNormal import model, modelFixedContext, configurator, constrain_parameters, unconstrain_parameters, constrained_parameter_names, generate_fixed_dataset
from amortizer import amortizer
from cmdstanpy import CmdStanModel
from logging import getLogger

# Disable the "cmdstanpy" logger so its messages do not end up in the cell outputs.
stan_logger = getLogger(name="cmdstanpy")
stan_logger.disabled = True

In [None]:
# Stan program of the normal mixture model; the path is relative to this notebook's directory.
STAN_FILE = "../../stan/mixture-normal.stan"
stan_model = CmdStanModel(stan_file=STAN_FILE)

In [None]:
# Ground-truth parameter vector (a single row) on the unconstrained scale,
# plus its constrained counterpart, which is what gets plotted later on.
true_params_unconstrained = np.array([[0, -0.5, -2.0, 0.45, 0.05]])
true_params = constrain_parameters(true_params_unconstrained)

# Simulate one data set from the fixed-context model at the true parameters.
fixed_context_model = modelFixedContext(n_obs=200, n_rep=5)
df = fixed_context_model.simulator(true_params_unconstrained)
df['prior_draws'] = true_params_unconstrained
# NOTE(review): the context below says 100 while the simulator and the Stan data
# use n_obs=200 -- confirm that this mismatch is intended.
df['sim_non_batchable_context'] = np.array([100, 5])

# Reorder the data points by their mean, taken over everything except the last
# entry of the final axis (presumably the last entry is not an observation -- verify).
replicates = df['sim_data'][:, :, :-1]
ind = np.argsort(np.mean(replicates, axis=-1)[0])
df['sim_data'] = df['sim_data'][:, ind]

# Turn the raw simulation output into the input dictionaries of the networks.
df = configurator(df)

# Quantities of the (single) data set that are reused for Stan and for plotting.
observables = df['posterior_inputs']['summary_conditions'][0, ..., 0]
means = np.array(np.mean(observables, axis=-1))
latents = np.array(df['mixture_inputs']['latents'][0, 0])

# Data passed to the Stan model.
stan_df = dict(
    n_obs=200,
    n_cls=3,
    n_rep=5,
    y=observables,
    mu_prior=[-1.5, 0, 1.5],
    mixture_prior=[2, 2, 2],
)

In [None]:
# Draw 4000 samples from the amortizer; it returns the parameter draws and the
# class-membership draws, of which the first (and only) data set is kept.
posterior_batch, membership_batch = amortizer.sample(df, n_samples=4000)
bf_unconstrained_posterior = posterior_batch[0]
bf_class_membership = membership_batch[0]

# Map the parameter draws to the constrained scale for comparison with Stan.
bf_posterior = constrain_parameters(bf_unconstrained_posterior)

In [None]:
# Fit the same data with Stan as the reference posterior.
# Chains and draws per chain are spelled out instead of relying on library
# defaults: 4 x 1000 = 4000 draws, the same number that is requested from
# BayesFlow (n_samples=4000) and assumed when the class-membership losses
# are computed. The fixed seed makes Stan's draws repeatable for a given data set.
stan_fit = stan_model.sample(
    data=stan_df,
    chains=4,
    iter_sampling=1000,
    seed=2023,
    show_progress=False,
)
print(stan_fit.diagnose())

In [None]:
# Draws of the variables "p" and "mu" as a plain array (rows are draws).
stan_posterior = np.array(stan_fit.draws_pd(vars=["p", "mu"]))

In [None]:
# if desired, test how the mixture network performs with the parameter samples from Stan
# df['mixture_inputs']['parameters'] = np.expand_dims(unconstrain_parameters(stan_posterior), axis=0).astype(np.float32)
# bf_class_membership = amortizer.amortized_mixture.sample(df['mixture_inputs'])
# bf_class_membership = bf_class_membership[0]

In [None]:
# Marginal posteriors: BayesFlow vs. Stan, one panel per constrained parameter.
# The figure size is passed to `subplots` directly: assigning
# plt.rcParams['figure.figsize'] after the figure has been created does not
# resize it and only changes the default for figures created later.
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(8, 6))
axs = axs.flatten()

# Shared bin edges per parameter, spanning the range of the BayesFlow draws
# (padded by 0.1 on both sides), so the two histograms are directly comparable.
bins = [
    np.linspace(np.min(bf_posterior[:, i]) - 0.1, np.max(bf_posterior[:, i]) + 0.1, 36)
    for i in range(len(constrained_parameter_names))
]

for i, par in enumerate(constrained_parameter_names):
    ax = axs[i]
    ax.hist(bf_posterior[:, i], bins=bins[i], alpha=0.5, density=True, label="BayesFlow")
    ax.hist(stan_posterior[:, i], bins=bins[i], alpha=0.5, density=True, label="Stan")
    # Mark the data-generating value on the x-axis.
    ax.scatter(true_params[0, i], 0, color="red", label="Truth")
    ax.set_title(par)

axs[0].legend()
fig.tight_layout()

In [None]:
stan_class_membership = stan_fit.stan_variables()["class_membership"]


def _membership_loss(class_membership):
    """Evaluate the mixture network's loss for a set of class-membership draws.

    The latents of the data set are repeated once per draw. The number of
    repetitions is read from the leading axis of ``class_membership`` rather
    than hard-coded, so the computation keeps working when the number of
    posterior draws changes.

    Parameters
    ----------
    class_membership : array-like whose leading axis indexes posterior draws.

    Returns
    -------
    Whatever ``amortizer.amortized_mixture.loss`` returns for these inputs.
    """
    n_draws = np.shape(class_membership)[0]
    tiled_latents = np.tile(np.expand_dims(latents, 0), (n_draws, 1, 1))
    return amortizer.amortized_mixture.loss(tiled_latents, class_membership)


losses = {
    "stan": _membership_loss(stan_class_membership),
    "bf": _membership_loss(bf_class_membership),
}

In [None]:
def _plot_membership(ax, x, draws, label=None, **band_kwargs):
    """Draw the median and the central 95% interval of membership draws.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    x : x-coordinates, one per data point.
    draws : array indexed as (draw, data point); it is summarised over axis 0.
    label : optional legend label of the median line.
    **band_kwargs : extra keyword arguments for ``fill_between`` (e.g. ``color``).
        When no colour is given, the band colour comes from the colour cycle.
    """
    line_kwargs = {} if label is None else {"label": label}
    ax.plot(x, np.median(draws, axis=0), **line_kwargs)
    lower, upper = np.quantile(draws, q=[0.025, 0.975], axis=0)
    ax.fill_between(x, lower, upper, alpha=0.5, **band_kwargs)


N_CLS = 3  # number of mixture classes; the right-hand column has one panel per class

# The figure size is passed to `subplots` directly: assigning
# plt.rcParams['figure.figsize'] after the figure has been created does not
# resize it and only changes the default for figures created later.
fig, axs = plt.subplots(3, 2, figsize=(8, 10))

# Left column: the data, then all classes overlaid for each method.
axs[0,0].hist(means, bins=30, density=True)
axs[0,0].set_title("Observations")
axs[1,0].set_title("BayesFlow")
axs[2,0].set_title("Stan")

for cls in range(N_CLS):
    class_label = "P(z={})".format(cls+1)
    bf_probs = bf_class_membership[...,cls]
    stan_probs = stan_class_membership[...,cls]

    # Left column: one curve per class; only the Stan panel carries the legend labels.
    _plot_membership(axs[1,0], means, bf_probs)
    _plot_membership(axs[2,0], means, stan_probs, label=class_label)

    # Right column: one panel per class, BayesFlow and Stan overlaid.
    axs[cls,1].set_title(class_label)
    _plot_membership(axs[cls,1], means, bf_probs, label="BayesFlow", color="blue")
    _plot_membership(axs[cls,1], means, stan_probs, label="Stan", color="orange")

axs[-1,0].set_xlabel(r"$\bar{x}$")
axs[-1,1].set_xlabel(r"$\bar{x}$")

axs[-1,0].legend()
axs[-1,1].legend()
fig.tight_layout()
