In [None]:
import numpy as np
import pandas as pd
import bayesflow as bf
import seaborn as sns
import matplotlib.pyplot as plt
from numba import njit
from sklearn.metrics import r2_score

import sys
sys.path.append("../../src/")
from helpers import truncnorm_better, random_num_obs
from simulator import *

## Constants

In [16]:
# Toggle for the "Train Model" section: True runs online training,
# False skips it and only plots the trainer's existing loss history.
TRAIN_NETWORK = True

In [2]:
# LaTeX labels for the model parameters, in the order handed to the prior.
PARAM_NAMES = [
    r"$v$",
    r"$a_0$",
    r"$\lambda$",
    r"$\tau$",
]

## Generative Model

In [3]:
# Prior over the model parameters. `param_names` attaches the labels above;
# `inference_net` later sizes its output from `prior.param_names`.
prior = bf.simulation.Prior(
    param_names=PARAM_NAMES,
    prior_fun=sample_cddm_prior,
)

In [4]:
# Non-batchable context: a single value for the whole batch. `random_num_obs`
# presumably draws the number of observations per simulated data set.
context = bf.simulation.ContextGenerator(
    non_batchable_context_fun=random_num_obs,
)

In [5]:
# Simulator wrapping the experiment-level sampler; it receives the context
# generated above.
likelihood = bf.simulation.Simulator(
    context_generator=context,
    simulator_fun=sample_cddm_experiment,
)

In [None]:
# Full generative model (prior + simulator). Its `name` is reused by the
# trainer to build the checkpoint directory.
model = bf.simulation.GenerativeModel(
    name="hyperbolic_cddm_no_constraint",
    simulator=likelihood,
    prior=prior,
)

### Prior Push-Forward Checks

In [7]:
# Small batch of simulations for the visual check below
# (one per panel of the 2 x 5 grid).
example_sim = model(
    batch_size=10,
)

In [None]:
# Visual prior push-forward check: one histogram per simulated data set in
# `example_sim`, laid out on a 2 x 5 grid.
f, axarr = plt.subplots(2, 5, figsize=(12, 4))

# zip stops at the shorter of the two, so a batch smaller than the grid
# leaves empty panels instead of raising IndexError.
for sim, ax in zip(example_sim["sim_data"], axarr.flat):
    sns.histplot(sim, color="maroon", alpha=0.75, ax=ax)
    sns.despine(ax=ax)
    ax.set_ylabel("")
    ax.set_yticks([])

# Label only the bottom row. Using the grid shape instead of a hard-coded
# flat index keeps this correct if the layout changes.
for ax in axarr[-1]:
    ax.set_xlabel("Simulated RTs (seconds)")

f.tight_layout()

### Configurator

In [9]:
# Per-parameter location / scale used to standardize the prior draws in
# `configure_input`. Presumably ordered like PARAM_NAMES (v, a_0, lambda, tau).
PRIOR_MEANS = [
    2.4,  # v
    2.4,  # a_0
    0.4,  # lambda
    0.6,  # tau
]
PRIOR_STDS = [
    1.8,  # v
    1.2,  # a_0
    0.2,  # lambda
    0.4,  # tau
]

In [10]:
def configure_input(forward_dict, prior_means=None, prior_stds=None):
    """Convert a BayesFlow forward dict into the inputs expected by the amortizer.

    Parameters
    ----------
    forward_dict : dict
        Output of the generative model. Reads:

        - ``"sim_data"``: simulated data, indexed as ``(batch, num_obs)``.
        - ``"sim_non_batchable_context"``: a single number-of-observations
          value that is broadcast to every batch element.
        - ``"prior_draws"``: parameter draws, presumably
          ``(batch, num_params)`` (TODO confirm).
    prior_means, prior_stds : array-like, optional
        Per-parameter location and scale used to standardize the draws.
        Default to the module-level ``PRIOR_MEANS`` / ``PRIOR_STDS``. They are
        looked up at call time, so editing those constants is picked up.

    Returns
    -------
    dict
        All arrays are float32:

        - ``parameters``: standardized prior draws.
        - ``direct_conditions``: ``sqrt(num_obs)``, shape ``(batch, 1)``.
        - ``summary_conditions``: data with a trailing feature axis,
          shape ``(batch, num_obs, 1)``.
    """
    if prior_means is None:
        prior_means = PRIOR_MEANS
    if prior_stds is None:
        prior_stds = PRIOR_STDS

    data = forward_dict["sim_data"]

    # One shared context value for the whole batch -> one condition per element.
    vec_num_obs = forward_dict["sim_non_batchable_context"] * np.ones((data.shape[0], 1))

    # Standardize first, then cast. Casting the draws to float32 before
    # subtracting a list (which NumPy turns into a float64 array) silently
    # promotes the result back to float64.
    params = np.asarray(forward_dict["prior_draws"])
    standardized = (params - np.asarray(prior_means)) / np.asarray(prior_stds)

    out_dict = dict(
        parameters=standardized.astype(np.float32),
        direct_conditions=np.sqrt(vec_num_obs).astype(np.float32),
        summary_conditions=data[:, :, None].astype(np.float32),
    )
    return out_dict

## Neural Approximator

In [11]:
# Permutation-invariant summary network over the trials. input_dim=1 matches
# the single trailing feature axis added in `configure_input`.
summary_net = bf.networks.SetTransformer(
    summary_dim=16,
    input_dim=1,
)

In [12]:
# Invertible network with one output per prior parameter. Coupling layers
# have no kernel regularization and no dropout.
inference_net = bf.networks.InvertibleNetwork(
    coupling_settings={
        "dense_args": {"kernel_regularizer": None},
        "dropout": False,
    },
    num_params=len(prior.param_names),
)

In [13]:
# Amortized posterior combining the inference network with the learned
# summary statistics (positional order: inference net, then summary net).
amortizer = bf.amortizers.AmortizedPosterior(
    inference_net,
    summary_net,
)

In [None]:
# Trainer wiring together simulator, amortizer and input configurator.
# Checkpoints are written to a directory named after the generative model.
trainer = bf.trainers.Trainer(
    amortizer=amortizer,
    generative_model=model,
    checkpoint_path=f"../../checkpoints/{model.name}",
    configurator=configure_input,
)

## Train Model

In [None]:
# Train from scratch with online simulation when TRAIN_NETWORK is set.
# Otherwise reuse whatever state the trainer already holds.
# NOTE(review): the non-training path assumes the Trainer restored a loss
# history from `checkpoint_path`; confirm a checkpoint exists before running
# with TRAIN_NETWORK = False.
if TRAIN_NETWORK:
    history = trainer.train_online(
        epochs=50,
        iterations_per_epoch=1000,
        batch_size=32,
    )

# Both paths plot the same loss history, so plot once after optional training
# instead of duplicating the call in an if/else.
f = bf.diagnostics.plot_losses(trainer.loss_history.get_plottable())

## Validation