In [1]:
import numpy as np
import bayesflow as bf
import seaborn as sns
import matplotlib.pyplot as plt

import sys
sys.path.append("../../src/")
from helpers import CollapsingDDM, CollapsingCDM, NeuralApproximator
from configurations import model_configs

  from tqdm.autonotebook import tqdm


## Hyperbolic No Constraints

In [2]:
model = CollapsingDDM(model_configs['hyperbolic_ddm_no_contraint'])
approximator = NeuralApproximator(model)

INFO:root:Performing 2 pilot runs with the hyperbolic_ddm_no_contraint model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 4)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 500, 1)
INFO:root:No optional prior non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional simulation batchable context provided.
INFO:root:Initialized empty loss history.
INFO:root:Initialized networks from scratch.
INFO:root:Performing a consistency check with provided components...
INFO:root:Done.


In [3]:
x = model.generate(32)
x

{'prior_non_batchable_context': None,
 'prior_batchable_context': None,
 'prior_draws': array([[0.51280495, 2.87141887, 0.55983656, 0.3261128 ],
        [1.92104049, 3.62998897, 0.57294957, 0.79213613],
        [0.00432108, 2.21203561, 0.17212376, 0.50059106],
        [0.73075403, 2.97793513, 1.5788366 , 0.23272776],
        [2.70410705, 3.07394869, 0.24913623, 0.6753023 ],
        [1.33933693, 2.40991492, 0.64573787, 0.76050904],
        [0.29221349, 3.87913817, 0.45228386, 0.58908516],
        [2.05755392, 3.8454628 , 1.55184831, 0.53480157],
        [1.63110555, 2.32653641, 0.46367394, 0.19715717],
        [1.37630669, 1.93854151, 0.71584598, 0.15453883],
        [1.76173513, 2.8357104 , 1.65538962, 0.7904198 ],
        [1.56689361, 1.5412705 , 0.21823664, 0.66981107],
        [1.28665239, 1.63729011, 0.43186277, 0.88515835],
        [0.45994901, 3.90166041, 1.90866062, 0.36888064],
        [2.07681422, 3.41660605, 0.7223208 , 0.21301526],
        [2.97958982, 1.76135192, 0.63737016

### Prior Push Forward Checks

In [None]:
%%time
example_sim = model.generate(batch_size=32)

In [None]:
f, axarr = plt.subplots(2, 5, figsize=(12, 4))
for i, ax in enumerate(axarr.flat):
    sns.histplot(example_sim["sim_data"][i, : , 0], color="maroon", alpha=0.75, ax=ax)
    sns.despine(ax=ax)
    ax.set_ylabel("")
    ax.set_yticks([])
    if i > 4:
        ax.set_xlabel("Simulated RTs (seconds)")
f.tight_layout()

### Train Model

In [None]:
history = approximator.run()

### Validation

In [15]:
val_sim = model.generate(1000)
true_params = val_sim["prior_draws"]

In [16]:
val_data = model.configure(val_sim)

In [17]:
post_samples = approximator.amortizer.sample(val_data, n_samples=2000)

In [18]:
pred_params = post_samples * model.prior_stds + model.prior_means

In [None]:
f = bf.diagnostics.plot_recovery(
    pred_params, true_params,
    param_names=model.prior.param_names,
    uncertainty_agg=None
)

## Ndt Constraint

In [None]:
model = CollapsingDDM(model_configs['hyperbolic_ndt_contraint'])
approximator = NeuralApproximator(model)

In [None]:
history = approximator.run()

In [5]:
val_sim = model.generate(1000)
true_params = val_sim["prior_draws"]

In [6]:
val_data = model.configure(val_sim)

In [7]:
post_samples = approximator.amortizer.sample(val_data, n_samples=2000)

In [8]:
pred_params = post_samples * model.prior_stds + model.prior_means

In [None]:
f = bf.diagnostics.plot_recovery(
    pred_params, true_params,
    param_names=model.prior.param_names,
    uncertainty_agg=None
)

## Exponential no constraint

In [None]:
model = CollapsingDDM(model_configs['exponential_no_contraint'])
approximator = NeuralApproximator(model)

### Prior Push Forward Check

In [None]:
%%time
example_sim = model.generate(batch_size=32)

In [None]:
f, axarr = plt.subplots(2, 5, figsize=(12, 4))
for i, ax in enumerate(axarr.flat):
    sns.histplot(example_sim["sim_data"][i, : , 0], color="maroon", alpha=0.75, ax=ax)
    sns.despine(ax=ax)
    ax.set_ylabel("")
    ax.set_yticks([])
    if i > 4:
        ax.set_xlabel("Simulated RTs (seconds)")
f.tight_layout()

### Train Model

In [None]:
history = approximator.run(75)

### Validation

In [7]:
val_sim = model.generate(1000)
true_params = val_sim["prior_draws"]

In [8]:
val_data = model.configure(val_sim)

In [9]:
post_samples = approximator.amortizer.sample(val_data, n_samples=2000)

In [10]:
pred_params = post_samples * model.prior_stds + model.prior_means

In [None]:
f = bf.diagnostics.plot_recovery(
    pred_params, true_params,
    param_names=model.prior.param_names,
    uncertainty_agg=None
)

## Exponential no constraint

In [None]:
model = CollapsingDDM(model_configs['exponential_no_contraint'])
approximator = NeuralApproximator(model)

### Prior Push Forward Check

In [None]:
%%time
example_sim = model.generate(batch_size=32)

In [None]:
f, axarr = plt.subplots(2, 5, figsize=(12, 4))
for i, ax in enumerate(axarr.flat):
    sns.histplot(example_sim["sim_data"][i, : , 0], color="maroon", alpha=0.75, ax=ax)
    sns.despine(ax=ax)
    ax.set_ylabel("")
    ax.set_yticks([])
    if i > 4:
        ax.set_xlabel("Simulated RTs (seconds)")
f.tight_layout()

### Train Model

In [None]:
history = approximator.run(75)

In [None]:
f = bf.diagnostics.plot_losses(history)

### Validation

In [15]:
val_sim = model.generate(1000)
true_params = val_sim["prior_draws"]

In [16]:
val_data = model.configure(val_sim)

In [17]:
post_samples = approximator.amortizer.sample(val_data, n_samples=2000)

In [18]:
pred_params = post_samples * model.prior_stds + model.prior_means

In [None]:
f = bf.diagnostics.plot_recovery(
    pred_params, true_params,
    param_names=model.prior.param_names,
    uncertainty_agg=None
)

## Exponential ndt constraint

In [None]:
model = CollapsingDDM(model_configs['exponential_ndt_contraint'])
approximator = NeuralApproximator(model)

In [None]:
%%time
example_sim = model.generate(batch_size=32)

In [None]:
f, axarr = plt.subplots(2, 5, figsize=(12, 4))
for i, ax in enumerate(axarr.flat):
    sns.histplot(example_sim["sim_data"][i, : , 0], color="maroon", alpha=0.75, ax=ax)
    sns.despine(ax=ax)
    ax.set_ylabel("")
    ax.set_yticks([])
    if i > 4:
        ax.set_xlabel("Simulated RTs (seconds)")
f.tight_layout()

In [None]:
history = approximator.run(100)

In [None]:
f = bf.diagnostics.plot_losses(history)

### Validation

In [None]:
val_sim = model.generate(1000)
true_params = val_sim["prior_draws"]

In [None]:
val_data = model.configure(val_sim)

In [None]:
post_samples = approximator.amortizer.sample(val_data, n_samples=2000)

In [None]:
pred_params = post_samples * model.prior_stds + model.prior_means

In [None]:
f = bf.diagnostics.plot_recovery(
    pred_params, true_params,
    param_names=model.prior.param_names,
    uncertainty_agg=None
)