In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import re
import pandas as pd
import talib

np.set_printoptions(suppress=True)

import sys, os
sys.path.append("../src")

from helpers import build_distance_matrix
from macro_models import batched_gaussian_process
from priors import diffusion_prior, length_scale_prior
from micro_models import dynamic_batch_diffusion, fast_dm_simulate, diffusion_trial
from networks_2 import DynamicGaussianNetwork
from context import generate_design_matrix
from transformations import unscale_z, scale_z

In [2]:
import tensorflow_probability as tfp
tfd = tfp.distributions
tfpl = tfp.layers

from tensorflow.keras.layers import GRU, Dense, LSTM
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.utils import to_categorical

from tqdm.notebook import tqdm
from functools import partial

In [3]:
# Enable on-demand GPU memory allocation. TensorFlow requires the memory-growth
# setting to be identical on all GPUs, so apply it to every device; the loop is a
# no-op on CPU-only machines (indexing [0] would raise IndexError there).
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, enable=True)
print(physical_devices)

# To force CPU execution instead, hide all GPUs before any op runs:
# tf.config.set_visible_devices([], 'GPU')
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [4]:
# --- Simulation configuration ---
T = 3200                             # trials per simulated time series
N_PARAMS = 6                         # dynamic DDM parameters per trial
DIST_MAT = build_distance_matrix(T)  # passed to batched_gaussian_process in generator_fun
AMPLITUDES = [0.15, 0.15, 0.15, 0.15, 0.1, 0.05]  # `amplitudes` argument of the GP prior

# --- Training / inference configuration ---
N_SAMPLES = 2000  # posterior draws per empirical data set
BATCH_SIZE = 16   # simulated data sets per offline chunk
TEST_SIZE = 10    # simulated data sets used for validation
N_CHUNKS = 4000   # offline chunks to presimulate
EPOCHS = 100

Metal device set to: Apple M2

systemMemory: 24.00 GB
maxCacheSize: 8.00 GB



2022-11-07 10:08:40.064278: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-11-07 10:08:40.064724: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [5]:
# Location/scale constants for z-standardizing the macro (eta) and micro (theta_t)
# parameters via scale_z / unscale_z.
MACRO_MEAN = [5.0] * 6
MACRO_STD = [2.8] * 6

MICRO_MEAN = [1.3] * 5 + [0.3]
MICRO_STD = [1.0] * 4 + [0.75, 0.25]

In [6]:
# Plot colors: empirical data, the neural (dynamic DDM) estimates, and the fast-dm comparison.
EMPIRIC_COLOR, NEURAL_COLOR, COMPARISON_COLOR = '#1F1F1F', '#852626', '#133a76'

In [7]:
def generator_fun(batch_size):
    """Simulate a batch of dynamic-DDM data sets with standardized parameters.

    Draws initial parameters and GP length scales from the priors, evolves the
    parameters over T trials with a batched Gaussian process, simulates response
    times under a random design, and appends the one-hot encoded design.

    Parameters
    ----------
    batch_size : int
        Number of simulated data sets.

    Returns
    -------
    eta_z : np.ndarray, float32
        Length scales standardized with MACRO_MEAN / MACRO_STD.
    theta_t_z : np.ndarray, float32
        Dynamic parameters standardized with MICRO_MEAN / MICRO_STD.
    x : np.ndarray, float32
        Simulated RTs concatenated with the one-hot condition along the last axis.
    """
    # Number of experimental conditions; matches `n_cond` given to the prior.
    n_conditions = N_PARAMS - 2

    theta0 = diffusion_prior(batch_size, n_cond=n_conditions)
    eta = length_scale_prior(batch_size, N_PARAMS)
    theta_t = batched_gaussian_process(theta0, DIST_MAT, eta, amplitudes=AMPLITUDES)
    context = generate_design_matrix(batch_size, T)

    rt = dynamic_batch_diffusion(theta_t, context)
    # Pin num_classes: without it to_categorical infers the class count from the
    # batch, so a batch lacking the highest condition would yield fewer features.
    context_onehot = to_categorical(context[:, :, np.newaxis], num_classes=n_conditions)
    x = np.concatenate((rt, context_onehot), axis=-1)

    eta_z = scale_z(eta, MACRO_MEAN, MACRO_STD)
    theta_t_z = scale_z(theta_t, MICRO_MEAN, MICRO_STD)

    return eta_z.astype(np.float32), theta_t_z.astype(np.float32), x.astype(np.float32)

In [8]:
# Quick sanity check of the simulator output shapes.
eta_z, theta_t_z, x = generator_fun(10)
for arr in (eta_z, theta_t_z, x):
    print(arr.shape)

(10, 6)
(10, 3200, 6)
(10, 3200, 5)


In [None]:
# f, axarr = plt.subplots(10, 6, figsize=(25, 30))
# time = np.arange(1, theta_t.shape[1]+1)
# for j in range(10):
#     for i in range(6):
#         ax = axarr[j, i]
#         ax.plot(time, theta_t[j, :, i], label='True', color='black', linestyle='dashed')
#         sns.despine(ax=ax)
#         ax.legend()
#         ax.grid(alpha=0.3)
# f.tight_layout()

In [None]:
def presimulate_data(path_to_data='../data/offline_data_new', n_chunks=None, batch_size=None,
                     generator=None, overwrite=False):
    """Simulate training chunks and save them in the layout ChunkLoader reads.

    Writes, for every chunk number ``n``:
        <path_to_data>/data/x_<n>.npy
        <path_to_data>/parameters/eta/eta_params_<n>.npy
        <path_to_data>/parameters/theta/theta_params_<n>.npy

    Parameters
    ----------
    path_to_data : str
        Root directory; sub-directories are created if missing.
    n_chunks : int, optional
        Number of chunks (default: N_CHUNKS).
    batch_size : int, optional
        Data sets per chunk (default: BATCH_SIZE).
    generator : callable, optional
        ``generator(batch_size) -> (eta_z, theta_t_z, x)`` (default: generator_fun).
    overwrite : bool
        If False (default), chunks whose three files already exist are skipped, so
        an interrupted run can simply be restarted (this replaces the hard-coded
        ``n > 261`` resume offset).
    """
    n_chunks = N_CHUNKS if n_chunks is None else n_chunks
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    generator = generator_fun if generator is None else generator

    data_dir = os.path.join(path_to_data, 'data')
    eta_dir = os.path.join(path_to_data, 'parameters', 'eta')
    theta_dir = os.path.join(path_to_data, 'parameters', 'theta')
    for directory in (data_dir, eta_dir, theta_dir):
        os.makedirs(directory, exist_ok=True)

    for n in range(n_chunks):
        x_path = os.path.join(data_dir, f'x_{n}.npy')
        eta_path = os.path.join(eta_dir, f'eta_params_{n}.npy')
        theta_path = os.path.join(theta_dir, f'theta_params_{n}.npy')
        if not overwrite and all(os.path.exists(p) for p in (x_path, eta_path, theta_path)):
            continue
        eta_z, theta_t_z, x = generator(batch_size)
        np.save(x_path, x)
        np.save(eta_path, eta_z)
        np.save(theta_path, theta_t_z)

In [None]:
# presimulate_data()

In [None]:
class ChunkLoader:
    """Iterator over presimulated training chunks stored as ``.npy`` files.

    Reads three parallel directories below ``path_to_data``: ``data/`` (network
    inputs), ``parameters/eta/`` and ``parameters/theta/``. Files are matched
    across directories by the integer formed from the digits in their names
    (e.g. ``x_3.npy`` <-> ``eta_params_3.npy`` <-> ``theta_params_3.npy``).
    Files not ending in ``.npy`` (e.g. ``.DS_Store``) are ignored.

    Each ``next()`` returns ``(eta, theta, x)`` for one chunk, visiting chunks in
    an order drawn with ``np.random.shuffle``. After the last chunk the order is
    reshuffled and iteration stops, so the same instance can be looped over once
    per epoch.

    Raises
    ------
    ValueError
        If the three directories do not contain the same chunk numbers (inputs
        and parameters would otherwise be silently misaligned).
    """

    DATA_DIR = 'data'
    ETA_DIR = 'parameters/eta'
    THETA_DIR = 'parameters/theta'

    def __init__(self, path_to_data):
        self.path_to_data = path_to_data
        self.data_list = self._list_chunks(self.DATA_DIR)
        self.eta_list = self._list_chunks(self.ETA_DIR)
        self.theta_list = self._list_chunks(self.THETA_DIR)

        chunk_ids = [[self._chunk_number(f) for f in files]
                     for files in (self.data_list, self.eta_list, self.theta_list)]
        if not chunk_ids[0] == chunk_ids[1] == chunk_ids[2]:
            raise ValueError(
                "Chunk numbers differ between '{}', '{}' and '{}' in {!r}".format(
                    self.DATA_DIR, self.ETA_DIR, self.THETA_DIR, path_to_data))

        self.num_batches = len(self.data_list)
        self._reset()

    @staticmethod
    def _chunk_number(file_name):
        """Integer formed by all digits in ``file_name`` (``'x_12.npy'`` -> 12)."""
        return int(re.sub(r'\D', '', file_name))

    def _list_chunks(self, sub_dir):
        """``.npy`` files in ``sub_dir``, sorted by chunk number."""
        files = [f for f in os.listdir(os.path.join(self.path_to_data, sub_dir)) if f.endswith('.npy')]
        return sorted(files, key=self._chunk_number)

    def _reset(self):
        """Draw a fresh random visiting order and rewind to its start."""
        self.indices = list(range(self.num_batches))
        np.random.shuffle(self.indices)
        self.current_index = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.current_index >= self.num_batches:
            # Epoch finished: reshuffle so the next loop over this object starts fresh.
            self._reset()
            raise StopIteration
        idx = self.indices[self.current_index]
        self.current_index += 1
        batch_x = np.load(os.path.join(self.path_to_data, self.DATA_DIR, self.data_list[idx]))
        batch_eta_params = np.load(os.path.join(self.path_to_data, self.ETA_DIR, self.eta_list[idx]))
        batch_theta_params = np.load(os.path.join(self.path_to_data, self.THETA_DIR, self.theta_list[idx]))
        return batch_eta_params, batch_theta_params, batch_x

In [None]:
# Shuffled, re-iterable loader over the presimulated training chunks.
loader = ChunkLoader(path_to_data='../data/offline_data_new')

## Network

In [10]:
# Architecture of the joint macro/micro posterior network.
network_settings = dict(
    embedding_gru_units=256,
    embedding_lstm_units=256,
    dense_pre_args=dict(units=256, activation='selu', kernel_initializer='lecun_normal'),
    dense_micro_args=dict(units=128, activation='selu', kernel_initializer='lecun_normal'),
    dense_macro_args=dict(units=128, activation='selu', kernel_initializer='lecun_normal'),
    macro_lstm_units=128,
    n_micro_params=6,
    n_macro_params=6,
)
network = DynamicGaussianNetwork(network_settings)

In [None]:
# NOTE: passed to epoch_trainer, which iterates over the whole loader regardless.
steps_per_epoch = 500

# Staircase decay: the learning rate is multiplied by 0.8 every 2500 optimizer steps.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.001, decay_steps=2500, decay_rate=0.8, staircase=True,
)
optimizer = tf.keras.optimizers.Adam(lr_schedule)

In [None]:
def nll(y_true, y_pred):
    """Mean negative log-likelihood of ``y_true`` under the predicted distribution.

    ``y_pred`` must provide ``log_prob`` (the distributions returned by the network).
    """
    neg_log_lik = -y_pred.log_prob(y_true)
    return tf.reduce_mean(neg_log_lik)

In [None]:
def epoch_trainer(generator, network, optimizer, steps_per_epoch, p_bar, epoch=None, sub_batch_size=8):
    """Run one training epoch over every chunk yielded by ``generator``.

    For each chunk a random subset of data sets (without replacement) is passed
    through the network, and the sum of the macro (eta) and micro (theta_t)
    negative log-likelihoods is minimized with one optimizer step.

    Parameters
    ----------
    generator : iterable
        Yields ``(eta_z, theta_t_z, data)`` arrays, e.g. a ``ChunkLoader``.
    network : callable
        Called as ``network(data, eta_z)``; returns two objects with ``log_prob``.
    optimizer : tf.keras optimizer
    steps_per_epoch : int
        Unused: the epoch always consumes the whole generator. Kept so existing
        calls keep working.
    p_bar : tqdm progress bar
        Updated once per chunk.
    epoch : int, optional
        Shown in the progress-bar postfix when given (previously read from a
        global ``ep``).
    sub_batch_size : int
        Data sets drawn from each chunk, capped at the chunk size (default 8).

    Returns
    -------
    list of float
        Total loss of every step.
    """
    losses = []
    for step, batch in enumerate(generator):
        eta_z, theta_t_z, data = batch

        # Subsample from the actual chunk size instead of assuming 16 data sets per chunk.
        n_available = data.shape[0]
        idx = np.random.choice(n_available, size=min(sub_batch_size, n_available), replace=False)

        with tf.GradientTape() as tape:
            # Forward pass
            eta_z_hat, theta_t_z_hat = network(data[idx], eta_z[idx])

            # Joint loss: macro NLL + micro NLL
            loss_eta = nll(eta_z[idx], eta_z_hat)
            loss_theta = nll(theta_t_z[idx], theta_t_z_hat)
            total_loss = loss_eta + loss_theta

        # One step backprop
        grads = tape.gradient(total_loss, network.trainable_variables)
        optimizer.apply_gradients(zip(grads, network.trainable_variables))
        losses.append(total_loss.numpy())

        # Update progress bar
        epoch_str = "" if epoch is None else "Ep: {},".format(epoch)
        p_bar.set_postfix_str(epoch_str + "Step {},Loss.Macro: {:.3f},Loss.Micro: {:.3f},Loss.Avg: {:.3f}"
                              .format(step, loss_eta.numpy(), loss_theta.numpy(), np.mean(losses)))
        p_bar.update(1)
    return losses

In [None]:
# Train for EPOCHS full passes over the offline chunks; checkpoint after every epoch.
losses = []
for ep in range(1, EPOCHS + 1):
    with tqdm(total=loader.num_batches, desc=f'Training Epoch {ep}') as p_bar:
        loss_ep = epoch_trainer(loader, network, optimizer, steps_per_epoch, p_bar)
    losses.append(loss_ep)
    network.save_weights('../trained_networks/gp_ddm_3200_joint_offline_factorized_new')

## Validation

In [11]:
# Restore the checkpoint that the training loop saves at the end of every epoch.
network.load_weights('../trained_networks/gp_ddm_3200_joint_offline_factorized_new')

NotFoundError: Unsuccessful TensorSliceReader constructor: Failed to find any matching files for ../trained_networks/gp_ddm_3200_joint_offline_factorized_new

In [None]:
# Fresh simulated validation data sets with known ground truth.
eta_z_test, theta_t_z_test, x_test = generator_fun(batch_size=TEST_SIZE)

In [None]:
# Sample the approximate posterior for the test sets; the second argument is
# presumably the number of draws (the sample axis 0 is averaged over below).
eta_z_pred, theta_t_z_pred = network.sample_n(x_test, 100)

In [None]:
# Back-transform ground truth and posterior draws from the z-scale.
eta_test, eta_pred = [unscale_z(z, MACRO_MEAN, MACRO_STD) for z in (eta_z_test, eta_z_pred)]
theta_t_test, theta_t_pred = [unscale_z(z, MICRO_MEAN, MICRO_STD) for z in (theta_t_z_test, theta_t_z_pred)]

In [None]:
# Posterior mean and standard deviation over the sample axis (axis 0).
theta_t_pred_mean, theta_t_pred_std = (
    theta_t_pred.numpy().mean(axis=0),
    theta_t_pred.numpy().std(axis=0),
)

In [None]:
# One row per simulated test set, one column per dynamic parameter:
# true trajectory vs. posterior mean +/- std_mul standard deviations.
f, axarr = plt.subplots(TEST_SIZE, 6, figsize=(25, 30))
std_mul = 1
time = np.arange(1, theta_t_test.shape[1] + 1)
for j, row in enumerate(axarr):
    for i, ax in enumerate(row):
        mean = theta_t_pred_mean[j, :, i]
        spread = std_mul * theta_t_pred_std[j, :, i]
        ax.plot(time, theta_t_test[j, :, i], label='True', color='black', linestyle='dashed')
        ax.plot(time, mean, label='Pred', lw=2, color='#8c6eb5')
        ax.fill_between(time, mean + spread, mean - spread, color='#8c6eb5', alpha=0.3)
        sns.despine(ax=ax)
        ax.legend()
        ax.grid(alpha=0.3)
f.tight_layout()

## Fit on empirical data (subject 10)

In [None]:
# Reload the trained checkpoint (same path the training loop writes) before fitting empirical data.
network.load_weights('../trained_networks/gp_ddm_3200_joint_offline_factorized_new')

In [None]:
# Read the lexical-decision data and build the network input for subject 10.
data = pd.read_csv('../data/data_lexical_decision.csv', sep=',', header=0)

# .copy() so the sign flip below edits this subject's own frame; chained
# assignment on a slice of `data` is unreliable (no effect under copy-on-write).
person_data = data[data.id == 10].copy()

# Code error responses (acc == 0) as negative RTs.
person_data.loc[person_data.acc == 0, 'rt'] *= -1

# Network input (1, n_trials, 1 + n_conditions): RT followed by the one-hot
# stimulus type, float32 like the simulated data from generator_fun. num_classes is
# pinned so a missing condition cannot shrink the feature dimension. Kept as a NumPy
# array so later cells can slice it and call .argmax on it.
rt = person_data.rt.to_numpy()[np.newaxis, :, np.newaxis]
stim_type = person_data.stim_type.to_numpy()[np.newaxis, :, np.newaxis] - 1
context = to_categorical(stim_type, num_classes=N_PARAMS - 2)
x_nn = np.concatenate((rt, context), axis=-1).astype(np.float32)

x_nn.shape

In [None]:
# Sample the approximate posterior for subject 10; the second argument is
# presumably the number of draws (N_SAMPLES from the config cell).
post_eta_z, post_theta_t_z = network.sample_n(x_nn, N_SAMPLES)

In [None]:
# Back-transform the empirical posterior draws from the z-scale.
post_eta, post_theta_t = (
    unscale_z(post_eta_z, MACRO_MEAN, MACRO_STD),
    unscale_z(post_theta_t_z, MICRO_MEAN, MICRO_STD),
)

In [None]:
# Read fast-dm parameter estimates and order them by data-set (subject) number.
# sep=r'\s+' replaces the deprecated delim_whitespace=True (same parsing).
fast_dm_params = pd.read_csv('../data/parameters_full_ddm_error_coding_cs.lst',
                             encoding='iso-8859-1', header=0, sep=r'\s+')
# Raw-string regex: '(\d+)' is an invalid escape sequence in a normal string.
fast_dm_params['dataset'] = fast_dm_params['dataset'].str.extract(r'(\d+)', expand=False).astype(int)
fast_dm_params = (
    fast_dm_params[['dataset', 'v_1', 'v_2', 'v_3', 'v_4', 'a', 't0', 'sv', 'st0']]
    .sort_values('dataset')
    .reset_index(drop=True)
    .to_numpy()[:, 1:]  # drop 'dataset'; remaining columns: v_1..v_4, a, t0, sv, st0
)
fast_dm_params.shape

In [None]:
# Simulate RTs from the fast-dm estimates of subject 10 (row 9 of the sorted
# estimates, assuming data sets are numbered consecutively from 1).
# NOTE(review): other cells take the condition from `stim_type`; confirm that the
# `context` column encodes the same 1-based condition.
context = person_data.context.to_numpy() - 1
pred_rt_fast_dm = fast_dm_simulate(fast_dm_params[9], context)

pred_rt_fast_dm.shape

In [None]:
def pr_check(emp_data, post_theta_t, n_sim, sma_period=5):
    """Posterior re-simulation check for one subject.

    Re-simulates the subject's RT series ``n_sim`` times, each from one posterior
    draw of the dynamic parameters, and smooths the empirical and simulated
    absolute RTs with a simple moving average (``talib.SMA``).

    Parameters
    ----------
    emp_data : pandas.DataFrame
        One subject's trials; the ``stim_type`` (1-based condition) and ``rt``
        columns are used.
    post_theta_t : array-like, shape (n_draws, n_trials, n_params)
        Unscaled posterior draws. Per trial, parameter ``stim_type - 1`` and the
        parameters at indices 4 and 5 are passed to ``diffusion_trial``.
    n_sim : int
        Number of re-simulations (1 <= n_sim <= n_draws). Draws are evenly
        spaced over the posterior sample axis.
    sma_period : int
        Moving-average window.

    Returns
    -------
    pred_rt, sma_pred_rt : np.ndarray, shape (n_sim, n_trials)
    emp_rt, sma_emp_rt : np.ndarray, shape (n_trials,)

    Raises
    ------
    ValueError
        If ``n_sim`` is not in ``[1, n_draws]``.
    """
    # Experimental condition per trial (0-based) and empirical absolute RTs.
    context = emp_data.stim_type.values - 1
    emp_rt = np.abs(emp_data.rt.values)
    sma_emp_rt = talib.SMA(emp_rt, timeperiod=sma_period)

    # np.asarray so integer-array indexing also works for tf tensors.
    post_theta_t = np.asarray(post_theta_t)
    n_draws = post_theta_t.shape[0]
    if not 0 < n_sim <= n_draws:
        raise ValueError("n_sim must be in [1, {}], got {}".format(n_draws, n_sim))
    # Evenly spaced draws based on the actual number of draws (not the global
    # N_SAMPLES); integer arithmetic yields exactly n_sim in-range indices.
    idx = np.arange(n_sim) * n_draws // n_sim
    theta = post_theta_t[idx]

    n_obs = emp_rt.shape[0]
    pred_rt = np.zeros((n_sim, n_obs))
    sma_pred_rt = np.zeros((n_sim, n_obs))
    for sim in range(n_sim):
        sim_rt = np.zeros(n_obs)
        for t in range(n_obs):
            # Condition-specific parameter plus the two shared ones (indices 4, 5).
            sim_rt[t] = diffusion_trial(theta[sim, t, context[t]], theta[sim, t, 4], theta[sim, t, 5])
        abs_rt = np.abs(sim_rt)
        pred_rt[sim] = abs_rt
        sma_pred_rt[sim] = talib.SMA(abs_rt, timeperiod=sma_period)

    return pred_rt, sma_pred_rt, emp_rt, sma_emp_rt

In [None]:
# Posterior re-simulation (via pr_check) for every subject, plus 95% bands and
# medians of the smoothed simulated RTs across simulations.
# NOTE(review): N_SUBS, N_SIM and N_OBS are not defined anywhere in this notebook;
# define them in the config cell before running (presumably N_OBS == T).
# NOTE(review): post_theta_t comes from x_nn, which holds subject 10 only (batch
# size 1), so indexing its subject axis with `sub` assumes a multi-subject fit — TODO confirm.
pred_rt_neural = np.zeros((N_SUBS, N_SIM, N_OBS))
sma_pred_rt_neural = np.zeros((N_SUBS, N_SIM, N_OBS))

pred_rt_quantiles = np.zeros((N_SUBS, 2, N_OBS))
pred_rt_medians = np.zeros((N_SUBS, N_OBS))

emp_rt =  np.zeros((N_SUBS, N_OBS))
sma_emp_rt =  np.zeros((N_SUBS, N_OBS))

for sub in range(N_SUBS):
    # predict RTs
    # NOTE: this overwrites the global `person_data` (subject 10) from an earlier cell.
    person_data = data[data.id == sub+1]
    pred_rt_neural[sub], sma_pred_rt_neural[sub], emp_rt[sub], sma_emp_rt[sub] = pr_check(person_data, post_theta_t[:, sub, :, :], N_SIM)
    # compute RT quantiles (2.5% / 97.5% across simulations) and the median
    pred_rt_quantiles[sub] = np.quantile(sma_pred_rt_neural[sub], [0.025, 0.975], axis=0)
    pred_rt_medians[sub] = np.median(sma_pred_rt_neural[sub], axis=0)
    print("Sub nr. {} is predicted".format(sub+1))

In [None]:
# Refit on the data without the last `horizon` trials; these posteriors seed the forecast.
horizon=800
emp_data_horizon = x_nn[:, :N_OBS-horizon, :]

# inference on restricted data
# NOTE(review): `ids` and N_SUBS are not defined in this notebook, and x_nn holds a
# single subject, so emp_data_horizon[i:i+1] is empty for i > 0 — TODO confirm intent.
post_eta_z_horizon = np.zeros((N_SAMPLES, N_SUBS, N_OBS-horizon, N_PARAMS))
post_theta_t_z_horizon = np.zeros((N_SAMPLES, N_SUBS, N_OBS-horizon, N_PARAMS))
for i in range(len(ids)):
    post_eta_z_horizon[:, i:i+1, :, :], post_theta_t_z_horizon[:, i:i+1, :, :] = network.sample_n(emp_data_horizon[i:i+1], N_SAMPLES)
    print("Sub nr. {} is fitted".format(i+1))

In [None]:
# Unscale the restricted-fit posterior at its last trial; it seeds the forecast.
post_eta_last = unscale_z(post_eta_z_horizon[:, :, -1, :], MACRO_MEAN, MACRO_STD)
# The micro z-scale constants are MICRO_MEAN / MICRO_STD (MICRO_MEANS / MICRO_STDS are undefined).
post_theta_last = unscale_z(post_theta_t_z_horizon[:, :, -1, :], MICRO_MEAN, MICRO_STD)

In [None]:
# Thin to N_SIM evenly spaced posterior draws. Integer arithmetic yields exactly
# N_SIM in-range indices (the float arange stopping at N_SAMPLES-1 returns only
# N_SIM-1 indices when N_SIM == N_SAMPLES).
idx = np.arange(N_SIM) * N_SAMPLES // N_SIM
post_eta_last_select = post_eta_last[idx]
post_theta_last_select = post_theta_last[idx]
post_eta_last_select.shape

In [None]:
# generate dynamic parameters and simulate RTs
# For each subject and each thinned draw, extend the last-trial parameters over
# `horizon` trials with random_walk, simulate RTs, and smooth them (SMA, window 5).
# NOTE(review): `random_walk` is not defined or imported in this notebook — TODO add.
pred_rt_horizon = np.zeros((N_SIM, N_SUBS, horizon, 1))
sma_pred_rt_horizon = np.zeros((N_SIM, N_SUBS, horizon, 1))

# Condition index of each forecast trial, recovered from the one-hot columns of x_nn.
# Requires x_nn to be a NumPy array (a tf.Tensor has no .argmax method).
context = x_nn[:, :, 1:].argmax(axis=2)[:, T-horizon:]

for sub in range(N_SUBS):
    for i in range(N_SIM):
        pred_theta_t = random_walk(post_theta_last_select[i, sub:sub+1], post_eta_last_select[i, sub:sub+1], horizon)
        pred_rt_horizon[i, sub:sub+1] = np.abs(dynamic_batch_diffusion(pred_theta_t, context[sub:sub+1]).astype(np.float32))
        sma_pred_rt_horizon[i, sub, :, 0] = talib.SMA(pred_rt_horizon[i, sub, :, 0], timeperiod=5)

    print("Sub nr. {} is predicted".format(sub+1))

In [None]:
# Across-simulation (axis 0) 95% band and median of the smoothed forecasts.
pred_rt_horizon_quantiles = np.quantile(sma_pred_rt_horizon, [0.025, 0.975], axis=0)
pred_rt_horizon_medians = np.median(sma_pred_rt_horizon, axis=0)

In [None]:
def reorderLegend2(ax=None, order=None, unique=False):
    """Redraw the legend of ``ax`` with sorted (optionally de-duplicated) entries.

    Entries are sorted alphabetically by label. If ``order`` is given they are
    then stably re-sorted so the labels listed in ``order`` come first, in that
    order, followed by all other labels.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Defaults to the current axes.
    order : sequence of str, optional
        Desired label order; need not list every label.
    unique : bool
        Keep only the first handle for each label.

    Returns
    -------
    (handles, labels) : tuple of tuples
        The entries passed to ``ax.legend`` (both empty if the axes has none).
    """
    if ax is None:
        ax = plt.gca()
    handles, labels = ax.get_legend_handles_labels()
    entries = sorted(zip(labels, handles), key=lambda entry: entry[0])

    if order is not None:
        rank = {label: pos for pos, label in enumerate(order)}
        entries = sorted(entries, key=lambda entry: rank.get(entry[0], np.inf))

    if unique:
        # Keep the first handle seen for every label.
        seen = set()
        deduped = []
        for label, handle in entries:
            if label not in seen:
                seen.add(label)
                deduped.append((label, handle))
        entries = deduped

    labels = tuple(label for label, _ in entries)
    handles = tuple(handle for _, handle in entries)
    ax.legend(handles, labels,
              fontsize=16, loc='upper right')
    return handles, labels

In [None]:
# One figure per subject: smoothed RT series (re-simulation + multi-horizon
# forecast) next to the marginal RT distributions.
horizon = 800
for sub in range(N_SUBS):
    # Session/block boundaries come from this subject's own trials, not from
    # whatever `person_data` an earlier cell left behind.
    sub_data = data[data.id == sub + 1]

    f, ax = plt.subplots(1, 2, figsize=(18, 8),
                         gridspec_kw={'width_ratios': [6, 1]})
    axrr = ax.flat
    # plot empiric and predicted response times series
    time = np.arange(N_OBS)
    fit_end = N_OBS - horizon  # first trial of the forecast window

    axrr[0].plot(time, sma_emp_rt[sub], color=EMPIRIC_COLOR, lw=1.4, alpha=0.8, label='SMA5: Empiric')
    axrr[0].plot(time[:fit_end], pred_rt_medians[sub, :fit_end], color=NEURAL_COLOR, lw=1.4, label='SMA5: Post. re-simulation median', alpha=0.8)
    axrr[0].plot(time[fit_end:], pred_rt_horizon_medians[sub], color="#b35032", lw=1.4, label='SMA5: Multi-horizon predictive median', alpha=0.6)
    axrr[0].fill_between(time[fit_end:], pred_rt_horizon_quantiles[0, sub, :, 0], pred_rt_horizon_quantiles[1, sub, :, 0], color="#b35032", linewidth=0, alpha=0.4, label='Multi-horizon predictive 95% CI')
    axrr[0].fill_between(time[:fit_end], pred_rt_quantiles[sub, 0, :fit_end], pred_rt_quantiles[sub, 1, :fit_end], color=NEURAL_COLOR, linewidth=0, alpha=0.5, label='Post. re-simulation 95% CI')

    # Session changes (solid) and block changes (dotted); flatnonzero yields plain ints.
    for idx in np.flatnonzero(sub_data.session.diff().values == 1):
        axrr[0].axvline(idx, color='black', linestyle='solid', lw=1.5, alpha=0.7)
    for idx in np.flatnonzero(sub_data.block.diff().values == 1):
        axrr[0].axvline(idx, color='black', linestyle='dotted', lw=1.5, alpha=0.4)
    sns.despine(ax=axrr[0])
    axrr[0].set_ylabel('RT(s)', fontsize=18, rotation=0, labelpad=40)
    axrr[0].set_xlabel('\nTrial', fontsize=18)
    axrr[0].tick_params(axis='both', which='major', labelsize=16)

    # Figure legend collects the entries of the time-series axis drawn so far.
    f.legend(fontsize=16, loc='center',
             bbox_to_anchor=(0.5, -0.05), ncol=3)

    axrr[0].grid(False)
    axrr[0].set_xticks([1, 800, 1600, 2400, 3200])

    # plot empiric and predicted response time dist
    # NOTE(review): pred_rt_fast_dm is computed for subject 10 only in an earlier
    # cell; indexing it with `sub` assumes a per-subject array — TODO confirm.
    plt.setp(ax, ylim=(0, 1.5))
    sns.histplot(y=np.abs(emp_rt[sub]), fill=True, color=EMPIRIC_COLOR, alpha=0.8, label="Empiric", ax=axrr[1], stat="density", bins=250, linewidth=0)
    sns.kdeplot(y=np.abs(pred_rt_fast_dm[sub]), fill=True, color=COMPARISON_COLOR, alpha=0.3, label="Fast-dm", ax=axrr[1], linewidth=3.5)
    sns.kdeplot(y=pred_rt_neural[sub].flatten(), fill=True, color=NEURAL_COLOR, alpha=0.3, label="Dynamic DDM", ax=axrr[1], linewidth=3.5)

    axrr[1].legend(fontsize=16)
    axrr[1].set_xlabel('', fontsize=18)
    axrr[1].tick_params(axis='both', which='major', labelsize=16)
    axrr[1].set_yticklabels('')
    axrr[1].set_xticklabels('')
    axrr[1].xaxis.set_ticks([])
    axrr[1].yaxis.set_ticks([])
    axrr[1].get_xaxis().set_visible(False)
    for line in axrr[1].get_lines():
        line.set_alpha(1)
    sns.despine(ax=axrr[1], bottom=True)

    axrr[0].annotate('Re-simulation',
                     xy=(0.38, 1), xytext=(0, 20),
                     xycoords=('axes fraction', 'figure fraction'),
                     textcoords='offset points',
                     size=20, ha='center', va='top', weight="bold")

    axrr[0].annotate('Prediction',
                     xy=(0.84, 1), xytext=(0, 20),
                     xycoords=('axes fraction', 'figure fraction'),
                     textcoords='offset points',
                     size=20, ha='center', va='top', weight="bold")

    axrr[0].tick_params(length=8)

    plt.subplots_adjust(wspace=0.05)
    f.tight_layout()
    f.savefig("../plots/rt_time_series_sub_{}.pdf".format(sub + 1), dpi=300, bbox_inches='tight')

In [None]:
def plot_dynamic_posteriors(dynamic_posterior, fast_dm_params, par_labels, par_names,
                            ground_truths=None, subject=None):
    """Plot the posterior of six dynamic parameters for one data set against fast-dm.

    Draws the posterior mean +/- one standard deviation per parameter, the static
    fast-dm estimate (with its inter-trial variability band for the drift rates
    and the fifth parameter), optional ground truths, and vertical reference lines,
    then saves the figure to ``../plots/param_dynamic_sub_<subject + 1>.pdf``.

    Parameters
    ----------
    dynamic_posterior : array-like, shape (n_draws, n_trials, 6)
        Posterior draws; averaged over axis 0.
    fast_dm_params : array-like
        fast-dm estimates in the column order v_1..v_4, a, t0, sv, st0.
    par_labels, par_names : sequences of str
        Subplot title parts, one per parameter.
    ground_truths : array-like, shape (n_trials, 6), optional
        True trajectories (e.g. for simulated data).
    subject : int, optional
        0-based subject index used in the file name. When omitted, the global
        ``sub`` is used (the previous behavior).

    Returns
    -------
    matplotlib.figure.Figure
    """
    if subject is None:
        # Backward compatibility: the original always read the loop variable `sub`.
        subject = sub

    # np.asarray so this also accepts tf tensors (which have no .mean/.std methods).
    dynamic_posterior = np.asarray(dynamic_posterior)
    means = dynamic_posterior.mean(axis=0)
    stds = dynamic_posterior.std(axis=0)

    # y-limits: range of the posterior mean plus a per-parameter margin (six parameters).
    y_margin = np.array([1, 1, 1, 1, 0.2, 0.05])
    upper_y_ax = means.max(axis=0) + y_margin
    lower_y_ax = means.min(axis=0) - y_margin

    # Trial axis taken from the posterior itself instead of the global x_nn.
    n_trials = means.shape[0]
    time = np.arange(n_trials)

    f, axarr = plt.subplots(2, 3, figsize=(18, 8))
    f.subplots_adjust(hspace=0.5)
    for i, ax in enumerate(axarr.flat):
        ax.plot(time, means[:, i], color=NEURAL_COLOR, label='Post. mean')
        ax.fill_between(time, means[:, i] + stds[:, i], means[:, i] - stds[:, i],
                        color=NEURAL_COLOR, alpha=0.6, linewidth=0, label='Post. std. deviation')

        if ground_truths is not None:
            ax.plot(time, ground_truths[:, i], color='black', linestyle='dashed', label='True Dynamic', lw=2)
        sns.despine(ax=ax)

        if i == 0:
            ax.set_xlabel('Trial', fontsize=18)
            ax.set_ylabel("Parameter value", fontsize=18)

        ax.set_title(par_labels[i] + ' ({})'.format(par_names[i]), fontsize=20)
        ax.set_xticks([1, 800, 1600, 2400, 3200])
        ax.tick_params(axis='both', which='major', labelsize=16)
        ax.set_ylim(lower_y_ax[i], upper_y_ax[i])
        ax.grid(False)

        # Vertical reference lines: solid every 800 trials, dotted every 100
        # (presumably session and block boundaries — confirm against the design).
        for idx in np.arange(799, 2400, 800):
            ax.axvline(idx, color='black', linestyle='solid', lw=1.5, alpha=0.5)
        for idx in np.arange(99, 3100, 100):
            ax.axvline(idx, color='black', linestyle='dotted', lw=1.5, alpha=0.4)

        # Static fast-dm estimate as a horizontal line.
        ax.plot(time, np.repeat(fast_dm_params[i], n_trials), color=COMPARISON_COLOR, alpha=1, label='Fast-dm estimate', lw=2.5)
        if i <= 3:
            # Drift rates: +/- sv (column 6).
            ax.fill_between(time, fast_dm_params[i] - fast_dm_params[6], fast_dm_params[i] + fast_dm_params[6],
                            color=COMPARISON_COLOR, alpha=0.3, linewidth=0, label='Fast-dm inter-trial variability')
        elif i == 5:
            # Sixth parameter: +/- st0 / 2 (column 7).
            ax.fill_between(time, fast_dm_params[i] - fast_dm_params[7] / 2, fast_dm_params[i] + fast_dm_params[7] / 2,
                            color=COMPARISON_COLOR, alpha=0.3, linewidth=0, label='Fast-dm inter-trial variability')

        # Figure legend built from the first panel only (all panels share labels).
        if i == 0:
            f.legend(fontsize=16, loc='center',
                     bbox_to_anchor=(0.5, -0.05), ncol=4)

    f.tight_layout()
    f.savefig("../plots/param_dynamic_sub_{}.pdf".format(subject + 1), dpi=300, bbox_inches="tight")
    return f

In [None]:
# Save one dynamic-posterior figure per subject (the loop variable must stay `sub`,
# since plot_dynamic_posteriors reads it for the file name).
for sub in range(N_SUBS):
    sub_posterior = post_theta_t[:, sub, :, :]
    plot_dynamic_posteriors(sub_posterior, fast_dm_params[sub], PARAM_LABELS, PARAM_NAMES)
    print(f"Sub {sub + 1} is finished")

In [None]:
# Group-level comparison: across subjects, mean +/- SD of the per-subject posterior-mean
# trajectories versus mean +/- SD of the static fast-dm estimates.
post_mean_per_sub = np.asarray(post_theta_t).mean(axis=0)  # average over posterior draws
neural_means = post_mean_per_sub.mean(axis=0)
neural_stds = post_mean_per_sub.std(axis=0)

fast_dm_means = fast_dm_params.mean(axis=0)
fast_dm_sd = fast_dm_params.std(axis=0)

sigma_factors = [1]
alphas = [0.6]

time = np.arange(N_OBS)
n_trials = len(time)  # fast-dm lines use the same trial axis as the neural curves
f, axarr = plt.subplots(2, 3, figsize=(18, 8))
f.subplots_adjust(hspace=0.5)
for i, ax in enumerate(axarr.flat):
    ax.plot(time, neural_means[:, i], color=NEURAL_COLOR, label='Average post. mean')
    for sigma_factor, alpha in zip(sigma_factors, alphas):
        ci_upper = neural_means[:, i] + sigma_factor * neural_stds[:, i]
        ci_lower = neural_means[:, i] - sigma_factor * neural_stds[:, i]
        ax.fill_between(time, ci_upper, ci_lower, color=NEURAL_COLOR, alpha=alpha, linewidth=0, label='Std. deviation post. mean')
    sns.despine(ax=ax)

    if i == 0:
        ax.set_xlabel('Trial', fontsize=18)
        ax.set_ylabel("Parameter value", fontsize=18)

    ax.set_title(PARAM_LABELS[i] + ' ({})'.format(PARAM_NAMES[i]), fontsize=20)
    ax.set_xticks([1, 800, 1600, 2400, 3200])
    ax.tick_params(axis='both', which='major', labelsize=16)
    ax.grid(False)

    # Vertical reference lines: solid every 800 trials, dotted every 100.
    for idx in np.arange(799, 2400, 800):
        ax.axvline(idx, color='black', linestyle='solid', lw=1.5, alpha=0.5)
    for idx in np.arange(99, 3100, 100):
        ax.axvline(idx, color='black', linestyle='dotted', lw=1.5, alpha=0.4)

    # horizontal fast-dm params
    ax.plot(time, np.repeat(fast_dm_means[i], n_trials), color=COMPARISON_COLOR, alpha=1, label='Average Fast-dm estimate', lw=2.5)
    ax.fill_between(time, fast_dm_means[i] - fast_dm_sd[i], fast_dm_means[i] + fast_dm_sd[i], color=COMPARISON_COLOR, alpha=0.3, linewidth=0, label='Std. deviation Fast-dm estimate')

    # Figure legend built from the first panel only (all panels share labels).
    if i == 0:
        f.legend(fontsize=16, loc='center',
                 bbox_to_anchor=(0.5, -0.05), fancybox=False, shadow=False, ncol=4)

f.tight_layout()