In [None]:
import numpy as np
import os, sys
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
from pathlib import Path
from numba import njit
import seaborn as sns

sys.path.append(os.path.abspath(os.path.join('../src')))
from varying_drift_diffusion import *
from motion_simulation import *
from accumulators import *
from threshold_dynamics import *

# bayesflow
sys.path.append(os.path.abspath(os.path.join('../../BayesFlow')))
from bayesflow.networks import InvariantNetwork, InvertibleNetwork
from bayesflow.amortizers import SingleModelAmortizer
from bayesflow.trainers import ParameterEstimationTrainer
from bayesflow.diagnostics import *

from tensorflow.keras.layers import Dense, GRU, LSTM, Conv1D, MultiHeadAttention, GlobalAveragePooling1D
from tensorflow.keras.models import Sequential
from tensorflow.python.keras.utils.np_utils import to_categorical

In [None]:
a = 2
c = 1

# Pass the parameters defined above instead of repeating them as literals,
# so editing `a`/`c` actually changes the plotted bound.
a_lower, a_upper = linear_collapsing_bound(a, c)
# Index the x-axis by the length of the returned arrays rather than assuming 1e4 steps.
steps = np.arange(len(a_lower))
plt.plot(steps, a_lower, label="lower")
plt.plot(steps, a_upper, label="upper")
plt.xlabel("Time step")
plt.ylabel("Boundary")
plt.legend()
sns.despine()

In [None]:
# Hyperbolic collapse: the fraction t / (t + c) grows from 0 toward 1, so the
# lower bound rises from 0 toward a while the upper bound falls from a toward 0.
a = 1.5
c = 5
max_iter = 1e4
t = 0.001 * np.arange(max_iter)  # step index scaled by 0.001 (presumably seconds at 1 ms steps)
collapse_frac = t / (t + c)
a_lower = a * collapse_frac
a_upper = a - a * collapse_frac

# Only the first 4000 steps are shown.
n_show = 4000
shown_steps = np.arange(n_show)
plt.plot(shown_steps, a_upper[:n_show], label="upper")
plt.plot(shown_steps, a_lower[:n_show], label="lower")
plt.legend()
sns.despine()

In [None]:
# Exponential collapse toward the midpoint a/2: the half-gap between the
# bounds starts at a/2 and shrinks by a factor exp(-tau * t).
a = 0.5
tau = 0.1
max_iter = 1e4
t = 0.001 * np.arange(max_iter)
half_gap = (a/2) * np.exp(-tau*t)
a_upper = a/2 + half_gap
a_lower = a/2 - half_gap

step_axis = np.arange(max_iter)
plt.plot(step_axis, a_upper, label="upper")
plt.plot(step_axis, a_lower, label="lower")
plt.legend()
sns.despine()

In [None]:
# Bounds produced by the project's exponential_collapsing_bound helper (from ../src).
a = 2
tau = 1
a_lower, a_upper = exponential_collapsing_bound(a, tau)

time_steps = np.arange(1e4)
for curve, curve_label in ((a_upper, "upper"), (a_lower, "lower")):
    plt.plot(time_steps, curve, label=curve_label)
plt.legend()
sns.despine()

In [None]:
# Weibull-shaped collapse: both bounds move inward by (a/2 - a_0) scaled by the
# Weibull CDF 1 - exp(-(t/lambd)**k), ending 2 * a_0 apart around a/2.
a = 1
a_0 = 0.1

lambd = 1
k = 2

max_iter = 1e4
t = 0.001 * np.arange(max_iter)

weibull_cdf = 1 - np.exp(-(t/lambd)**k)
a_lower = weibull_cdf * ((a/2) - a_0)
a_upper = a - a_lower

n_show = 10000
plt.plot(np.arange(n_show), a_upper[:n_show], label="upper")
plt.plot(np.arange(n_show), a_lower[:n_show], label="lower")
plt.legend()
sns.despine()

In [None]:
# gpu_devices = tf.config.experimental.list_physical_devices('GPU')
# for device in gpu_devices: tf.config.experimental.set_memory_growth(device, True)

In [None]:
# Reload imported modules (e.g. the ../src helpers) automatically before each cell runs,
# so edits to those files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Print floating-point arrays in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)

## Constants

In [None]:
# simulation: number of simulated data sets and observations per data set
N_SIM, N_OBS = 500, 100

# bayesflow: inferred parameters and training / sampling settings
PARAM_NAMES = ["a", "ndt", "bias", "kappa"]
N_PARAMS = len(PARAM_NAMES)  # dimensionality of the inferred parameter vector
N_EPOCHS, ITER_PER_EPOCH, BATCH_SIZE = 50, 1000, 32
N_SAMPLES = 2000  # posterior draws per data set

## Simulator Test

In [None]:
# Smoke test: simulate one experiment with hand-picked parameters.
n_obs = 100
a, ndt, bias, kappa = 3.0, 0.2, 0.5, 5
theta = np.array([a, ndt, bias, kappa])

# Ten motion amplitudes (five negative, five positive), each repeated ten times -> 100 entries.
unique_motions = np.array(
    [-0.725, -0.675, -0.625, -0.575, -0.525, 0.525, 0.575, 0.625, 0.675, 0.725],
    dtype=np.float32,
)
amplitude = np.repeat(unique_motions, 10)
motion_set, condition = motion_experiment_manual(1, amplitude, 1)

rt, resp = var_dm_simulator(theta, 1, motion_set)


In [None]:
%time
p_, x_ = var_dm_batch_simulator(32, 100)
x_.shape

## BayesFlow

In [None]:
class CustomSummary(tf.keras.Model):
    """Summary network: an ``InvariantNetwork`` followed by a ``Dense`` projection.

    Parameters
    ----------
    meta_inv : dict
        Configuration forwarded unchanged to ``InvariantNetwork``.
    n_out : int, default 10
        Number of units of the final ``Dense`` layer (no activation given, so
        Keras uses a linear output), i.e. the size of the summary vector.
    """

    def __init__(self, meta_inv, n_out=10):
        super().__init__()
        # NOTE: TF object-based checkpoints track variables by attribute name;
        # keep `inv` / `out` unchanged or previously saved weights will likely not restore.
        self.inv = InvariantNetwork(meta_inv)
        self.out = Dense(n_out)

    def call(self, x):
        """Embed ``x`` with the invariant network, then project to ``n_out`` values."""
        embedded = self.inv(x)
        return self.out(embedded)

In [None]:
# Settings for the InvariantNetwork wrapped by CustomSummary.
sum_meta = dict(
    n_dense_s1=2,
    n_dense_s2=2,
    n_dense_s3=2,
    n_equiv=2,
    dense_s1_args={'activation': 'relu', 'units': 32},
    dense_s2_args={'activation': 'relu', 'units': 32},
    dense_s3_args={'activation': 'relu', 'units': 32},
)


def _coupling_subnet_args():
    """Return a fresh settings dict for one coupling subnetwork (s or t)."""
    return {
        'units': [128, 128],
        'activation': 'elu',
        'initializer': 'glorot_uniform',
    }


# Invertible inference network: the s- and t-subnetworks share one architecture
# but each receives its own, independent dict/list objects.
inf_meta = {
    'n_coupling_layers': 4,
    's_args': _coupling_subnet_args(),
    't_args': _coupling_subnet_args(),
    'alpha': 1.9,
    'permute': True,
    'use_act_norm': True,
    'n_params': N_PARAMS,
}

# Build both networks and combine them in a single-model amortizer.
inference_net = InvertibleNetwork(inf_meta)
summary_net = CustomSummary(sum_meta)
amortizer = SingleModelAmortizer(inference_net, summary_net)

In [None]:
# Learning-rate decay: start at 5e-4 and multiply by 0.99 every 1000 steps
# (staircase=True applies the decay in discrete jumps).
learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.0005,
    decay_steps=1000,
    decay_rate=0.99,
    staircase=True,
)

# Trainer bound to the amortizer and simulator. The training cells below are
# commented out, so the weights used later presumably come from the checkpoint
# at checkpoint_path — confirm it exists before running the evaluation cells.
trainer = ParameterEstimationTrainer(
    network=amortizer,
    generative_model=var_dm_batch_simulator,
    learning_rate=learning_rate,
    checkpoint_path='../src/selected_checkpoints/time_var_dm2',
    clip_value=3,
    max_to_keep=5,
)

In [None]:
# Optional: replace the decaying schedule with Adam at a fixed learning rate of 7e-5
# trainer.optimizer = tf.keras.optimizers.Adam(0.00007)

In [None]:
# # %%time
# # # online training
# losses = trainer.train_online(5, ITER_PER_EPOCH, BATCH_SIZE, n_obs=N_OBS)

## Parameter Recovery

In [None]:
# Simulate N_SIM data sets (true parameters p_, data x_) for the recovery check.
p_, x_ = var_dm_batch_simulator(
    n_sim=N_SIM,
    n_obs=N_OBS,
)

In [None]:
# Posterior draws for every simulated data set, summarized by their mean over
# axis 1 (assumes samples is (n_sim, n_samples, n_params) — TODO confirm).
samples = amortizer.sample(x_, n_samples=N_SAMPLES)
param_means = np.mean(samples, axis=1)

In [None]:
# Recovery plot: posterior means against the true generating parameters.
f = true_vs_estimated(
    theta_true=p_,
    theta_est=param_means,
    param_names=PARAM_NAMES,
    dpi=300,
    figsize=(20, 6),
    font_size=16,
)

## Simulation-Based Calibration

In [None]:
# Simulate data sets for simulation-based calibration.
n_sbc = 5000              # number of simulated data sets
n_post_samples_sbc = 250  # posterior draws per data set
params, sim_data = var_dm_batch_simulator(n_sim=n_sbc, n_obs=N_OBS)

In [None]:
# Amortized inference, run over the data sets in 10 chunks.
# np.array_split (unlike tf.split) does not require the number of data sets
# to be divisible by the number of chunks.
n_chunks = 10
param_samples = np.concatenate(
    [amortizer.sample(chunk, n_post_samples_sbc)
     for chunk in np.array_split(sim_data, n_chunks, axis=0)],
    axis=0,
)

In [None]:
# Rank-plot: SBC rank histograms (23 bins) per parameter; for a calibrated
# posterior these are expected to look uniform.
f = plot_sbc(
    param_samples,
    params,
    param_names=PARAM_NAMES,
    figsize=(24, 8),
    bins=23,
)

## Bayesian Eye Chart

In [None]:
# Simulation
true_params, sim_data = var_dm_batch_simulator(N_SIM, N_OBS)

# Amortized inference in 10 chunks; np.array_split (unlike tf.split) tolerates
# N_SIM values that are not divisible by the number of chunks.
n_chunks = 10
param_samples = np.concatenate(
    [amortizer.sample(chunk, N_SAMPLES)
     for chunk in np.array_split(sim_data, n_chunks, axis=0)],
    axis=0,
)

In [None]:
### Posterior z-score
# Compute posterior means, stds and variances over the draw axis (axis 1)
post_means = param_samples.mean(1)
post_stds = param_samples.std(1)
post_vars = param_samples.var(1)

# Posterior z-score: signed distance of the posterior mean from the true value,
# measured in posterior standard deviations
post_z_score = (post_means - true_params) / post_stds

### Posterior contraction, i.e., 1 - post_var / prior_var
# Uniform prior bounds, one entry per parameter in PARAM_NAMES order.
# NOTE(review): these must mirror the simulator's prior — update both together.
prior_a = (0.5, 0.1, 0.2, 0.0)  # lower bound of uniform prior
prior_b = (3.0, 0.5, 0.8, 10.0)  # upper bound of uniform prior
assert len(prior_a) == len(prior_b) == len(PARAM_NAMES), "prior bounds must match PARAM_NAMES"

# Variance of U(lo, hi) is (hi - lo)^2 / 12
prior_vars = np.array([(hi - lo)**2 / 12 for lo, hi in zip(prior_a, prior_b)])
post_cont = 1 - post_vars / prior_vars

# Plotting time: one panel per parameter, contraction (x) against z-score (y)
n_panels = len(PARAM_NAMES)
f, axarr = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4))
for i, (p, ax) in enumerate(zip(PARAM_NAMES, np.atleast_1d(axarr).flat)):
    ax.scatter(post_cont[:, i], post_z_score[:, i], color='#8f2727', alpha=0.7)
    ax.set_title(p, fontsize=20)
    sns.despine(ax=ax)
    ax.set_xlim([-0.1, 1.05])
    ax.set_ylim([-3.5, 3.5])
    ax.grid(color='black', alpha=0.1)
    ax.set_xlabel('Posterior contraction', fontsize=14)
    # All panels sit in a single row, so only the left-most one gets a y-label.
    if i == 0:
        ax.set_ylabel('Posterior z-score', fontsize=14)
f.tight_layout()

## Posterior Retrodictive Checks

### Empirical Data Preparation

In [None]:
# read data: ../data/single_sub_data.csv relative to the current working directory
directory = str(Path().absolute())
path = str(Path(directory).parent / 'data' / 'single_sub_data.csv')
# Pass the path (not an open handle) so NumPy opens and closes the file itself;
# the original `open(path, 'rb')` handle was never closed.
data = np.loadtxt(path, delimiter=",", skiprows=1)  # skiprows=1 skips the header row

# subset data: keep rows where columns 1 and 2 both equal 1
# (column meaning not visible here — presumably a subject/session selector; confirm)
data_subset = data[(data[:, 1] == 1) & (data[:, 2] == 1)]

In [None]:
# One-hot encode the amplitude (column 4) of every retained row.
amplitude = data_subset[:, 4]
condition = get_hot_encoded_amplitude(amplitude)

# Assemble one data set for the amortizer: column 6 (used as RT in the
# retrodictive checks), column 5, then the one-hot condition columns.
final_data = np.column_stack((data_subset[:, 6], data_subset[:, 5], condition))

# Add a leading batch axis of size 1.
final_data = final_data[np.newaxis, ...]

final_data.shape

### Amortized Inference

In [None]:
# Posterior draws for the empirical data set, shown as a pairwise scatter matrix.
samples = amortizer.sample(final_data, n_samples=N_SAMPLES)
posterior_df = pd.DataFrame(samples, columns=PARAM_NAMES)
sns.pairplot(posterior_df)

### Plotting

In [None]:
# Last three columns of the empirical data, and posterior-predictive
# simulations generated from them and the posterior draws.
emp_data = data_subset[:, -3:]
pred_data = var_dm_pp_check(emp_data, samples)

In [None]:
# Spot check: RT quantiles (column 0) of the first posterior-predictive data set.
tmp = pred_data[0, :]
quantiles = [0.1, 0.3, 0.5, 0.7, 0.9]
np.quantile(tmp[:, 0], q=quantiles)

In [None]:
# Posterior-predictive RT quantiles for every simulated data set and amplitude level.
unique_amplitude  = np.round(np.unique(amplitude), 3)  # np.unique already returns sorted values
n_sim             = samples.shape[0]
rt_quantiles      = [0.1, 0.3, 0.5, 0.7, 0.9]
# Size by the actual number of amplitude levels instead of assuming 10;
# np.empty would otherwise leave unfilled rows holding garbage.
pred_rt_quantiles = np.empty((n_sim, len(unique_amplitude), len(rt_quantiles)))

for sim in range(n_sim):
    # Round the simulated amplitudes to 3 decimals, matching both unique_amplitude
    # and the empirical-data comparison; exact float equality against rounded
    # values can otherwise select no rows and make np.quantile fail.
    sim_amplitude = np.round(pred_data[sim, :, 2], 3)
    for i, amp in enumerate(unique_amplitude):
        tmp_data = pred_data[sim, sim_amplitude == amp, :]
        pred_rt_quantiles[sim, i] = np.quantile(tmp_data[:, 0], rt_quantiles)

In [None]:
# Collapse the simulation axis into the 2.5% / 50% / 97.5% quantiles.
# NOTE: this overwrites pred_rt_quantiles, so the cell is not idempotent —
# re-run the per-simulation quantile cell before running it again.
pred_rt_quantiles = np.quantile(pred_rt_quantiles, q=[0.025, 0.5, 0.975], axis=0)
pred_rt_quantiles.shape

In [None]:
# Empirical RT quantiles (column 2) per amplitude level (column 0, rounded to 3 decimals).
# Sized by the number of amplitude levels rather than a hard-coded 10, so no
# uninitialized np.empty rows can end up in the plot.
emp_rt_quantiles = np.empty((len(unique_amplitude), len(rt_quantiles)))
emp_amplitude = np.round(emp_data[:, 0], 3)
for i, amp in enumerate(unique_amplitude):
    tmp_data = emp_data[emp_amplitude == amp, 2]
    emp_rt_quantiles[i] = np.quantile(tmp_data, rt_quantiles)


In [None]:
# Expected shape: (number of amplitude levels, number of RT quantiles)
np.shape(emp_rt_quantiles)

In [None]:
# Posterior retrodictive check: predicted vs. empirical RT quantiles per amplitude level.
x_pos = np.arange(len(unique_amplitude))
fig, ax = plt.subplots()
for i in range(len(rt_quantiles)):
    # One color per RT quantile so predicted and empirical curves can be matched;
    # only the first set of curves is labelled so the legend has no duplicates.
    color = f"C{i}"
    first = i == 0
    # Row 1 of pred_rt_quantiles is the 0.5 quantile over simulations, i.e. the median.
    ax.plot(x_pos, pred_rt_quantiles[1, :, i], color=color, linestyle='dashed',
            label="Predicted median" if first else None)
    ax.fill_between(x_pos, pred_rt_quantiles[0, :, i], pred_rt_quantiles[2, :, i],
                    color=color, alpha=0.2,
                    label="Predictive uncertainty" if first else None)
    ax.plot(x_pos, emp_rt_quantiles[:, i], color=color, linestyle="solid",
            label="Empirical quantiles" if first else None)
ax.set_xticks(x_pos)
ax.set_xticklabels(unique_amplitude, rotation=45)
ax.set_xlabel("Amplitude")
ax.set_ylabel("RT quantile")
ax.legend()
sns.despine(ax=ax)
