In [None]:
import numpy as np
import pandas as pd
from numba import njit
from scipy.stats import norm, halfnorm, uniform
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm

# Get rid of annoying tf warning
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import bayesflow as beef
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Bidirectional
from tensorflow.keras.models import Sequential
from keras.utils import to_categorical
from sklearn.metrics import r2_score

import sys
sys.path.append("../src/")
from priors import sample_eta, sample_theta_t
from context import generate_context
from likelihood import sample_softmax_rl
from configurator import configure_input

In [None]:
# Reload edited modules (e.g. the ../src helpers imported above) before each
# cell executes, so changes to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Suppress scientific notation when printing numpy floats
np.set_printoptions(suppress=True)

# Fixed seed so that draws from RNG are reproducible across kernel restarts.
# NOTE(review): the prior/context/likelihood helpers imported from ../src may
# use their own random state; this seed only pins RNG itself.
SEED = 42
RNG = np.random.default_rng(SEED)

In [None]:
# True: run online training below. False: skip train_online() and only read
# the trainer's existing loss history (presumably restored from its checkpoint).
TRAIN_NETWORK: bool = True

In [None]:
# (display name, LaTeX symbol) for each local parameter, in column order of theta.
_THETA_SPEC = (
    ("Learning rate", r"$\alpha$"),
    ("Sensitivity", r"$\tau$"),
)
THETA_NAMES = tuple(name for name, _ in _THETA_SPEC)
THETA_LABELS = tuple(label for _, label in _THETA_SPEC)

# Plotting font sizes: panel titles, axis labels, tick labels (respectively).
FONT_SIZE_1, FONT_SIZE_2, FONT_SIZE_3 = 22, 18, 16

# Serif typography for all figures. Palatino is preferred; the fallbacks keep
# text serif (and avoid repeated "findfont" warnings) on systems where it is
# not installed -- DejaVu Serif ships with matplotlib itself.
matplotlib.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Palatino", "Palatino Linotype", "DejaVu Serif"],
})

In [None]:
# Visual sanity check of the prior: draw one set of hyperparameters, derive the
# local parameter trajectories from it, and plot one panel per parameter.
eta = sample_eta()
theta = sample_theta_t(eta)
time = np.arange(theta.shape[0])

fig, axarr = plt.subplots(1, 2, figsize=(14, 4))
for col, (panel, name, label) in enumerate(zip(axarr.flat, THETA_NAMES, THETA_LABELS)):
    panel.grid(alpha=0.5)
    panel.plot(time, theta[:, col], color='maroon')
    panel.set_title(f'{name} ({label})', fontsize=FONT_SIZE_1)
    panel.tick_params(axis='both', which='major', labelsize=FONT_SIZE_3)
    panel.set_xlabel("Time step", fontsize=FONT_SIZE_2)
# Shared y-axis meaning, so only the leftmost panel carries the label.
axarr[0].set_ylabel("Parameter value", fontsize=FONT_SIZE_2)

sns.despine()
fig.tight_layout()

## Generative Model

### Prior

In [None]:
# Two-level prior: sample_eta draws the hyperparameters, and sample_theta_t
# draws the local parameters (one row per time step) given those hyperparameters.
prior = beef.simulation.TwoLevelPrior(hyper_prior_fun=sample_eta, local_prior_fun=sample_theta_t)

### Context

In [None]:
# Context generator wrapping generate_context as a batchable context function.
context_gen = beef.simulation.ContextGenerator(batchable_context_fun=generate_context)

In [None]:
# Draw one context and one (eta, theta) pair for quick interactive inspection.
# NOTE(review): none of these values are used by the cells that follow here.
context_batch = context_gen(1)
context = context_batch['batchable_context'][0]
eta = sample_eta()
theta = sample_theta_t(eta)

### Likelihood

In [None]:
# Likelihood: sample_softmax_rl simulates data, with contexts supplied by context_gen.
likelihood = beef.simulation.Simulator(
    context_generator=context_gen,
    simulator_fun=sample_softmax_rl,
)

### Full Generative Model (Prior + Simulator)

In [None]:
# Complete generative model: the two-level prior feeding the likelihood simulator.
model = beef.simulation.TwoLevelGenerativeModel(
    name="non-stationary_rl",
    simulator=likelihood,
    prior=prior,
)

## Neural Approximator

In [None]:
# Hidden sizes for the three summary LSTMs, plus keyword arguments that are
# unpacked into the Trainer constructor.
approximator_settings = dict(
    lstm1_hidden_units=512,
    lstm2_hidden_units=256,
    lstm3_hidden_units=128,
    trainer=dict(max_to_keep=1, default_lr=5e-4, memory=False),
)

In [None]:
# Stage 1: two stacked bidirectional LSTMs that return the full output sequence.
sequence_stage = Sequential(
    [
        Bidirectional(LSTM(approximator_settings[key], return_sequences=True))
        for key in ("lstm1_hidden_units", "lstm2_hidden_units")
    ]
)
# Stage 2: one bidirectional LSTM returning only its final output
# (Keras default return_sequences=False).
summary_stage = Sequential(
    [Bidirectional(LSTM(approximator_settings["lstm3_hidden_units"]))]
)
summary_network = beef.networks.HierarchicalNetwork([sequence_stage, summary_stage])

In [None]:
def build_posterior_network(num_params, num_coupling_layers):
    """Return an AmortizedPosterior wrapping a freshly built InvertibleNetwork.

    The local and global posteriors use the same coupling configuration; keeping
    it in one place stops the two copies from drifting apart, and a new settings
    dict is created per call so the two networks never share a mutable dict.

    Args:
        num_params: dimensionality of the parameter vector the flow models.
        num_coupling_layers: number of coupling layers in the invertible network.
    """
    coupling_settings = {
        "dense_args": dict(kernel_regularizer=None),
        "dropout": False,
        # NOTE(review): in recent BayesFlow versions `coupling_design` is a
        # top-level InvertibleNetwork argument; placed inside coupling_settings
        # it may be silently ignored. Confirm against the installed version.
        "coupling_design": 'interleaved',
    }
    return beef.amortizers.AmortizedPosterior(
        beef.networks.InvertibleNetwork(
            num_params=num_params,
            num_coupling_layers=num_coupling_layers,
            coupling_settings=coupling_settings,
        )
    )


# Local posterior: one dimension per local parameter (see THETA_NAMES).
local_network = build_posterior_network(num_params=len(THETA_NAMES), num_coupling_layers=8)
# Global posterior, presumably over the hyperparameters drawn by sample_eta.
# NOTE(review): 4 is assumed to equal the length of sample_eta()'s output -- confirm.
global_network = build_posterior_network(num_params=4, num_coupling_layers=6)

In [None]:
# Joint amortizer: shared summary network feeding the local and global posteriors.
amortizer = beef.amortizers.TwoLevelAmortizedPosterior(
    summary_net=summary_network,
    local_amortizer=local_network,
    global_amortizer=global_network,
)

# Trainer options (max_to_keep, default_lr, memory) come from approximator_settings.
trainer_settings = approximator_settings.get("trainer")
trainer = beef.trainers.Trainer(
    amortizer=amortizer,
    generative_model=model,
    configurator=configure_input,
    checkpoint_path="../checkpoints/reversal_learning",
    **trainer_settings,
)

In [None]:
# Online-training budget.
N_EPOCHS = 200
ITERATIONS_PER_EPOCH = 1000
BATCH_SIZE = 32

if not TRAIN_NETWORK:
    # No training: use whatever loss history the trainer already holds.
    history = trainer.loss_history.get_plottable()
else:
    history = trainer.train_online(
        epochs=N_EPOCHS,
        iterations_per_epoch=ITERATIONS_PER_EPOCH,
        batch_size=BATCH_SIZE,
    )

In [None]:
# Plot the loss history, either returned by train_online() or read back from
# the trainer when TRAIN_NETWORK is False.
loss_plot = beef.diagnostics.plot_losses(history)