In [None]:
import numpy as np
import pandas as pd
from numba import njit
from scipy.stats import norm, halfnorm, uniform
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from functools import partial

# Get rid of annoying tf warning
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import bayesflow as bf
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Bidirectional
from tensorflow.keras.models import Sequential
from keras.utils import to_categorical
from sklearn.metrics import r2_score

import sys
sys.path.append("../src/")
from model import generative_model
from configurator import configure_input

In [None]:
%load_ext autoreload
%autoreload 2
# Suppress scientific notation for floats
np.set_printoptions(suppress=True)
# Configure rng
RNG = np.random.default_rng()

In [None]:
# Optional: let TensorFlow grow GPU memory on demand instead of reserving the
# whole device up front (useful on shared GPUs). Off by default, which matches
# the previously commented-out snippet.
# TensorFlow raises RuntimeError if memory growth is set after the GPU runtime
# has been initialised, so keep this cell before any model is built.
# Looping over every visible GPU also works on CPU-only machines, where the
# old `physical_devices[0]` would have raised IndexError.
ENABLE_GPU_MEMORY_GROWTH = False
if ENABLE_GPU_MEMORY_GROWTH:
    gpu_devices = tf.config.list_physical_devices('GPU')
    for gpu in gpu_devices:
        tf.config.experimental.set_memory_growth(gpu, True)
    print(gpu_devices)

## Constants

In [None]:
# True: train the amortizer online in the "Training" section.
# False: skip training and use trainer.loss_history instead (presumably
# restored from the checkpoint directory the trainer is pointed at).
TRAIN_NETWORK: bool = True

In [None]:
# Display names and LaTeX labels for the three parameter groups
# (see the "Parameters" section below):
#   theta – dynamic parameters: learning rate and memory weighting,
#   eta   – transition scales of the two theta parameters,
#   kappa – static parameters: memory decay and memory capacity.
THETA_NAMES = ("Learning rate", "Memory contribution")
THETA_LABELS = (r"$\alpha$", r"$p$")
# Each theta parameter has its own transition scale. The names must differ so
# that anything titled or keyed by name can tell sigma_alpha from sigma_p.
ETA_NAMES = ("Transition scale (learning rate)", "Transition scale (memory contribution)")
ETA_LABELS = (r"$\sigma_{\alpha}$", r"$\sigma_{p}$")
KAPPA_NAMES = ("Memory decay", "Memory capacity")
KAPPA_LABELS = (r"$\phi$", r"$c$")

# Prior means / standard deviations per parameter group.
# NOTE(review): not used in this section; presumably consumed downstream
# (e.g. for standardisation) — confirm against the rest of the notebook.
THETA_PRIOR_MEAN = np.array([0.5, 0.5])
THETA_PRIOR_STD = np.array([0.3, 0.3])
# Eta moments come from a half-normal(loc=0, scale=ETA_PRIOR_SCALE) prior,
# rounded to two decimals. They are scalars shared by both eta parameters.
ETA_PRIOR_SCALE = 0.05
_eta_prior = halfnorm(loc=0, scale=ETA_PRIOR_SCALE)
ETA_PRIOR_MEAN = np.round(_eta_prior.mean(), decimals=2)
ETA_PRIOR_STD = np.round(_eta_prior.std(), decimals=2)
KAPPA_PRIOR_MEAN = np.array([0.5, 4.7])
KAPPA_PRIOR_STD = np.array([0.3, 1.0])


## Parameters

$\alpha \rightarrow$ Learning rate [0, 1] dynamic

$\tau \rightarrow$ Inverse temperature, $\tau \in [0, \infty)$; fixed at 10

$\phi \rightarrow$ Memory decay [0, 1] static

$w \rightarrow$ Memory contribution, $w = p \cdot \min\left(1, \frac{C}{n_S}\right)$

$p \rightarrow$ Initial memory weighting [0, 1] dynamic

$C \rightarrow$ Memory capacity, $C \in [0, \infty)$, usually between 5 and 9; static

$n_S \rightarrow$ Set size in current block

$\gamma \rightarrow$ Perseveration [0, 1]; not used, because $\alpha$ is already dynamic

## Context

Column 1: Stimulus [0, 5]

Column 2: Correct response [0, 2]

Column 3: Block id [1, 14]

Column 4: Set size [3, 6]

## Generative Model

In [None]:
%%time
forward_dict = generative_model(32)
_ = configure_input(forward_dict)

## Neural Approximator

In [None]:
# Hyperparameters for every component of the pipeline. Nested dicts are
# unpacked as keyword arguments further down:
#   "trainer"                   -> bf.trainers.Trainer
#   "*_amortizer_settings"      -> bf.networks.InvertibleNetwork
approximator_settings = dict(
    # Hidden units of the three bidirectional LSTM layers in the summary network
    lstm1_hidden_units=512,
    lstm2_hidden_units=256,
    lstm3_hidden_units=128,
    # Trainer options
    trainer=dict(max_to_keep=1, default_lr=5e-4, memory=False),
    # Invertible networks for the local and global posterior
    local_amortizer_settings=dict(num_coupling_layers=8, coupling_design='interleaved'),
    global_amortizer_settings=dict(num_coupling_layers=6, coupling_design='interleaved'),
)

In [None]:
def _bilstm(units, return_sequences=False):
    """Wrap an LSTM with ``units`` hidden units in a Bidirectional layer."""
    return Bidirectional(LSTM(units, return_sequences=return_sequences))


# Hierarchical summary network built from two Keras stages:
#   stage 1 - two stacked bidirectional LSTMs that keep the whole sequence
#             (return_sequences=True), so the next layer still sees every step;
#   stage 2 - one bidirectional LSTM that returns only its final output.
# NOTE(review): how HierarchicalNetwork routes data between the stages is
# defined by BayesFlow — confirm against its documentation.
summary_network = bf.networks.HierarchicalNetwork([
    Sequential([
        _bilstm(approximator_settings["lstm1_hidden_units"], return_sequences=True),
        _bilstm(approximator_settings["lstm2_hidden_units"], return_sequences=True),
    ]),
    Sequential([
        _bilstm(approximator_settings["lstm3_hidden_units"]),
    ]),
])

In [None]:
# Posterior networks for the two levels of the model:
#   local  - invertible network over the dynamic theta parameters;
#   global - invertible network over the eta (transition scale) and
#            kappa (static) parameters.
# Sizes are derived from the name tuples in the "Constants" section instead of
# hard-coding 2 and 2+2, so they stay in sync if a parameter is added/removed.
# NOTE(review): assumes generative_model emits exactly these parameters
# (theta locally, eta + kappa globally) — confirm in ../src/model.py.
NUM_LOCAL_PARAMS = len(THETA_NAMES)
NUM_GLOBAL_PARAMS = len(ETA_NAMES) + len(KAPPA_NAMES)

# Index with [...] (not .get) so a missing settings key raises a clear KeyError
# instead of "argument after ** must be a mapping, not NoneType".
local_network = bf.amortizers.AmortizedPosterior(
    bf.networks.InvertibleNetwork(
        num_params=NUM_LOCAL_PARAMS,
        **approximator_settings["local_amortizer_settings"]
    )
)
global_network = bf.amortizers.AmortizedPosterior(
    bf.networks.InvertibleNetwork(
        num_params=NUM_GLOBAL_PARAMS,
        **approximator_settings["global_amortizer_settings"]
    )
)

In [None]:
# Combine the shared summary network with the local and global posterior
# networks into a single two-level amortizer.
amortizer = bf.amortizers.TwoLevelAmortizedPosterior(
    summary_net=summary_network,
    local_amortizer=local_network,
    global_amortizer=global_network,
)

# Trainer wiring: simulations come from generative_model and are formatted by
# configure_input; checkpoints are kept under ../checkpoints/ns_rlwm.
# NOTE(review): the TRAIN_NETWORK=False branch of the training cell reads
# trainer.loss_history, so the trainer presumably restores an existing
# checkpoint from this path — confirm.
trainer = bf.trainers.Trainer(
    amortizer=amortizer,
    generative_model=generative_model,
    configurator=configure_input,
    checkpoint_path="../checkpoints/ns_rlwm",
    **approximator_settings.get("trainer"),
)

## Training

In [None]:
%%time
if TRAIN_NETWORK:
    history = trainer.train_online(
        epochs=75, 
        iterations_per_epoch=1000, 
        batch_size=16
    )
else:
    history = trainer.loss_history.get_plottable()

In [None]:
# Plot the loss history produced by the training cell (freshly trained or
# taken from trainer.loss_history). The return value is kept in `f`;
# presumably a matplotlib Figure (e.g. for f.savefig(...)) — confirm.
f = bf.diagnostics.plot_losses(history)