In [None]:
import numpy as np
import pandas as pd
from numba import njit
from scipy.stats import norm, halfnorm, uniform
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm

# Get rid of annoying tf warning
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import bayesflow as beef
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Bidirectional
from tensorflow.keras.models import Sequential
from keras.utils import to_categorical
from sklearn.metrics import r2_score

import sys
sys.path.append("../src/")
from priors import sample_eta, sample_random_walk
from model import sample_softmax_rl
from context import generate_context
from configurator import configure_input
from helpers import softmax

In [None]:
# Automatically re-import edited modules (e.g. the ../src helpers added to sys.path above)
# before each cell runs, so changes to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Print floats in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)

# Seed for the notebook-level random generator. None draws fresh OS entropy on every
# run (previous behaviour); set an int to make draws from RNG reproducible.
# NOTE(review): sample_eta / sample_random_walk / the simulator are called without this
# generator in this notebook, so seeding RNG alone may not make simulations
# reproducible — confirm against ../src.
SEED = None
RNG = np.random.default_rng(SEED)

In [None]:
# Optional GPU setup: let TensorFlow allocate GPU memory on demand instead of reserving
# it all up front. Off by default. Looping over the detected devices (rather than
# indexing the first one) makes this a no-op on CPU-only machines.
# NOTE(review): TensorFlow only accepts this before the GPUs are initialised, so keep
# this cell ahead of any network construction.
ENABLE_GPU_MEMORY_GROWTH = False
if ENABLE_GPU_MEMORY_GROWTH:
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, enable=True)
    print(tf.config.list_physical_devices('GPU'))

## Constants

In [None]:
# Set to False to skip online training and use the trainer's existing loss history instead.
TRAIN_NETWORK: bool = True

In [None]:
# Parameter names and LaTeX labels, one per column of a theta trajectory:
# two learning rates (selected / unselected) and a sensitivity (tau).
THETA_NAMES = ("Learning rate", "Learning rate", "Sensitivity")
THETA_LABELS = (r"$\alpha_{\text{selected}}$", r"$\alpha_{\text{unselected}}$", r"$\tau$")
# Hyperparameters eta: one random-walk transition scale per theta component
# (assumed to follow the THETA order — TODO confirm against src/priors.py).
# The two learning-rate scales get distinct labels so they can be told apart in plots.
ETA_NAMES = ("Transition scale", "Transition scale", "Transition scale")
ETA_LABELS = (
    r"$\sigma_{\alpha_{\text{selected}}}$",
    r"$\sigma_{\alpha_{\text{unselected}}}$",
    r"$\sigma_{\tau}$",
)

# plotting
FONT_SIZE_1 = 22  # panel titles
FONT_SIZE_2 = 18  # axis labels
FONT_SIZE_3 = 16  # tick labels

# pyplot.rcParams is the same object as matplotlib.rcParams, so no extra import is needed.
plt.rcParams['font.serif'] = "Palatino"
plt.rcParams['font.family'] = "serif"

# analysis
NUM_VALIDATION_SIM = 400
NUM_SAMPLES = 500
NUM_RESIM = 100

## Data Preparation

In [None]:
# Empirical data set; NUM_SUBJECTS counts the distinct values of its `participant` column.
data = pd.read_csv("../data/empiric_data.csv")
NUM_SUBJECTS = np.unique(data["participant"]).size

## Exemplar Trajectory

In [None]:
# Draw one set of hyperparameters, turn it into a parameter trajectory, and plot each
# parameter over time in its own panel.
eta = sample_eta()
theta = sample_random_walk(eta)
time = np.arange(theta.shape[0])

fig, axarr = plt.subplots(1, 3, figsize=(14, 4))
panels = zip(axarr.flat, THETA_NAMES, THETA_LABELS)
for i, (ax, name, label) in enumerate(panels):
    ax.grid(alpha=0.5)
    ax.plot(time, theta[:, i], color='maroon')
    ax.set_title(f'{name} ({label})', fontsize=FONT_SIZE_1)
    ax.tick_params(axis='both', which='major', labelsize=FONT_SIZE_3)
    ax.set_xlabel("Time step", fontsize=FONT_SIZE_2)
# Only the leftmost panel carries a y-axis label.
axarr.flat[0].set_ylabel("Parameter value", fontsize=FONT_SIZE_2)

sns.despine()
fig.tight_layout()

## Generative Model

### Prior

In [None]:
# Two-level prior: sample_eta draws the hyperparameters, and sample_random_walk
# (presumably fed each eta draw, as in the exemplar cell) produces the local trajectory.
prior = beef.simulation.TwoLevelPrior(local_prior_fun=sample_random_walk, hyper_prior_fun=sample_eta)

In [None]:
# Per-component location/scale constants for the global (eta) and local (theta) prior draws.
# NOTE(review): presumably used to standardize parameters for the networks and assumed to be
# in THETA/ETA order — confirm; they are not referenced in the code visible here.
GLOBAL_PRIOR_MEAN = np.array((0.02, 0.02, 0.8))
GLOBAL_PRIOR_STD = np.array((0.01, 0.01, 0.6))
LOCAL_PRIOR_MEAN = np.array((0.4, 0.4, 5.7))
LOCAL_PRIOR_STD = np.array((0.25, 0.25, 7.6))

### Context

In [87]:
# Batchable simulation context comes from generate_context (../src/context.py).
context = beef.simulation.ContextGenerator(batchable_context_fun=generate_context)

### Likelihood

In [88]:
# Likelihood: sample_softmax_rl simulates data, drawing its context from the generator above.
likelihood = beef.simulation.Simulator(context_generator=context, simulator_fun=sample_softmax_rl)

### Joint Generative Model (Prior + Likelihood)

In [89]:
# Combine prior and likelihood into the generative model the trainer simulates from.
model = beef.simulation.TwoLevelGenerativeModel(
    name="non-stationary_softmax_rl", prior=prior, simulator=likelihood
)

INFO:root:Performing 2 pilot runs with the non-stationary_softmax_rl model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 200, 3)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 200)
INFO:root:Shape of hyper_prior_draws batch after 2 pilot simulations: (batch_size = 2, 3)
INFO:root:Shape of local_prior_draws batch after 2 pilot simulations: (batch_size = 2, 200, 3)
INFO:root:No shared_prior_draws provided.
INFO:root:Could not determine shape of simulation batchable context. Type appears to be non-array: <class 'list'>,                                    so make sure your input configurator takes care of that!
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional prior non-batchable context provided.


## Neural Approximator

In [90]:
# Hidden-unit sizes for the three LSTM layers of the summary network, plus keyword
# arguments forwarded verbatim to the trainer.
approximator_settings = dict(
    lstm1_hidden_units=512,
    lstm2_hidden_units=256,
    lstm3_hidden_units=128,
    trainer=dict(max_to_keep=1, default_lr=5e-4, memory=False),
)

In [91]:
# Two-stage summary network. Stage one stacks two bidirectional LSTMs that return full
# sequences; stage two is one bidirectional LSTM that returns only its final output.
# NOTE(review): how HierarchicalNetwork routes data levels through the stages is defined
# by bayesflow and not visible here.
summary_network = beef.networks.HierarchicalNetwork([
    Sequential([
        Bidirectional(LSTM(approximator_settings[units_key], return_sequences=True))
        for units_key in ("lstm1_hidden_units", "lstm2_hidden_units")
    ]),
    Sequential([
        Bidirectional(LSTM(approximator_settings["lstm3_hidden_units"])),
    ]),
])

In [92]:
def _make_amortized_posterior(num_params: int, num_coupling_layers: int):
    """Build an AmortizedPosterior around an InvertibleNetwork.

    Both posterior levels share the same coupling settings and differ only in
    dimensionality and depth.

    Args:
        num_params: Number of parameters the flow models.
        num_coupling_layers: Number of coupling layers in the invertible network.

    Returns:
        A new ``beef.amortizers.AmortizedPosterior`` (a fresh settings dict per call).
    """
    return beef.amortizers.AmortizedPosterior(
        beef.networks.InvertibleNetwork(
            num_params=num_params,
            num_coupling_layers=num_coupling_layers,
            coupling_settings={
                "dense_args": dict(kernel_regularizer=None),
                "dropout": False,
                "coupling_design": 'interleaved'
            }
        )
    )


# Local posterior over the per-time-step parameters theta (one per THETA_NAMES entry).
local_network = _make_amortized_posterior(num_params=len(THETA_NAMES), num_coupling_layers=8)
# Global posterior over the hyperparameters eta (one per ETA_NAMES entry).
global_network = _make_amortized_posterior(num_params=len(ETA_NAMES), num_coupling_layers=6)

In [93]:
# Two-level amortizer combining the local (theta) and global (eta) posterior networks
# with the summary network.
amortizer = beef.amortizers.TwoLevelAmortizedPosterior(
    local_amortizer=local_network,
    global_amortizer=global_network,
    summary_net=summary_network,
)

# Directory the trainer checkpoints to (and presumably restores from when it exists).
CHECKPOINT_PATH = "../checkpoints/three_alt_full_feed"

trainer = beef.trainers.Trainer(
    amortizer=amortizer,
    generative_model=model,
    configurator=configure_input,
    # Index (not .get) so a missing "trainer" entry fails with a clear KeyError instead of
    # "argument after ** must be a mapping, not NoneType".
    **approximator_settings["trainer"],
    checkpoint_path=CHECKPOINT_PATH,
)

INFO:root:Initialized empty loss history.
INFO:root:Initialized networks from scratch.
INFO:root:Performing a consistency check with provided components...
INFO:root:Done.


## Training (or Loading the Loss History)

In [None]:
# Either run online training, or (TRAIN_NETWORK is False) take the loss history the
# trainer already holds, in plottable form.
history = (
    trainer.train_online(epochs=200, iterations_per_epoch=1000, batch_size=32)
    if TRAIN_NETWORK
    else trainer.loss_history.get_plottable()
)

In [None]:
# Plot the loss history from the previous cell (fresh training or the trainer's stored history).
loss_plot = beef.diagnostics.plot_losses(history)

## Validation

In [94]:
# Simulate one data set (batch size 1) and inspect the summary-network input produced by
# the configurator (previously displayed as (1, 200, 6)). The result is stored in a named
# variable: assigning to `_` shadows IPython's last-output cache, and IPython then stops
# updating `_` for the rest of the session.
x = model(1)
x_configured = configure_input(x)
assert x_configured['summary_conditions'].shape[0] == 1, "expected a batch of one simulation"
x_configured['summary_conditions'].shape

(1, 200, 6)