## Disable absl warnings

In [None]:
import absl.logging

# Only surface absl messages at ERROR severity or above; warnings are silenced.
absl.logging.set_verbosity(absl.logging.ERROR)

## Google Colab setup

In [None]:
import sys

from google.colab import drive

# Mount Google Drive so the thesis project directory below is reachable.
drive.mount('/content/drive')

# Make the project directory importable (presumably where models/, custom_datasets,
# belief_utils, ... live — confirm). Guarded so that re-running this cell does not
# keep appending duplicate entries to sys.path.
PROJECT_ROOT = "/content/drive/MyDrive/Thesis"
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

## Training
> ! Training takes a very long time unless it is parallelized across multiple notebooks.

In [None]:
import jax
from models.disrnn_utils import create_disrnn_train_state, disrnn_value_trainstep
from models.gru_utils import create_gru_train_state, gru_value_trainstep
from models.rnn_utils import train_model
from torch.utils.data import DataLoader
from custom_datasets import custom_collate, from_disk
from itertools import product

In [None]:
SEQ_LENGTH = 1_000
OUT_DIM = 1
IN_DIM = 2

LEARNING_RATE = 0.001
BATCH_SIZE = 10
NUM_EPOCHS = 5
STOP_TRAINING = 3

VALUERNN_HIDDEN_SIZES = [50, 20, 10, 5, 4, 3, 2]
DISRNN_HIDDEN_SIZES = [10]
DISRNN_KL_LOSSES = [0.01, 0.0075, 0.005]

# Omission probabilities to train on; each one selects the dataset stored at
# data/belief_<tag> (presumably the reward-omission probability — confirm).
OMISSON_PROBABILITIES = [0, 0.1]

# Common prefix of every save_path handed to train_model below.
MODEL_DIR = "data/models/belief_models"


def _tag(value):
    """Return ``value`` as a filename tag with the decimal point removed (0.001 -> "0001")."""
    return str(value).replace('.', '')


def _make_dataloader(om_prob):
    """Load the "MyStarkweather" belief dataset for ``om_prob`` and wrap it in a DataLoader.

    The loader uses BATCH_SIZE so it always agrees with the batch size the train
    states are created with. drop_last=True discards a trailing partial batch
    (presumably because the states are built for a fixed batch size — confirm).
    """
    dataset = from_disk("MyStarkweather", f"data/belief_{_tag(om_prob)}")
    return DataLoader(dataset,
                      batch_size=BATCH_SIZE,
                      drop_last=True,
                      collate_fn=custom_collate)


# Load each dataset once and share its loader across every model trained on it,
# instead of re-reading it from disk for each configuration.
train_dataloaders = {om_prob: _make_dataloader(om_prob) for om_prob in OMISSON_PROBABILITIES}

# Value RNNs (GRU-based)
for om_prob, hidden_size in product(OMISSON_PROBABILITIES, VALUERNN_HIDDEN_SIZES):
    print("Training on Hidden Size: ", hidden_size)

    # Same PRNG seed for every configuration, so runs differ only in hyperparameters.
    master_key = jax.random.PRNGKey(0)
    state = create_gru_train_state(master_key,
                                   learning_rate=LEARNING_RATE,
                                   hidden_size=hidden_size,
                                   batch_size=BATCH_SIZE,
                                   seq_length=SEQ_LENGTH,
                                   out_dim=OUT_DIM,
                                   in_dim=IN_DIM)

    save_path = f"{MODEL_DIR}/vrnn_{_tag(LEARNING_RATE)}_{hidden_size}_{_tag(om_prob)}"
    trained_model_state, training_metrics = train_model(state,
                                                        train_dataloaders[om_prob],
                                                        train_step_fun=gru_value_trainstep,
                                                        num_epochs=NUM_EPOCHS,
                                                        stop_training=STOP_TRAINING,
                                                        print_every_other=1,
                                                        save_path=save_path)

# DisRNNs
for om_prob, hidden_size, kl_loss in product(OMISSON_PROBABILITIES, DISRNN_HIDDEN_SIZES, DISRNN_KL_LOSSES):
    print(f"Training om_prob: {om_prob}, hidden_size: {hidden_size}, kl_loss: {kl_loss}")

    # Same PRNG seed for every configuration, so runs differ only in hyperparameters.
    master_key = jax.random.PRNGKey(0)
    state = create_disrnn_train_state(master_rng_key=master_key,
                                      learning_rate=LEARNING_RATE,
                                      hidden_size=hidden_size,
                                      batch_size=BATCH_SIZE,
                                      seq_length=SEQ_LENGTH,
                                      in_dim=IN_DIM,
                                      out_dim=OUT_DIM,
                                      update_mlp_shape=[5, 5, 5],
                                      choice_mlp_shape=[2, 2],
                                      kl_loss_factor=kl_loss,
                                      )

    # NOTE(review): DisRNN checkpoints reuse the "vrnn_" prefix and differ from the
    # value-RNN checkpoints only by the trailing KL tag. Kept unchanged because
    # decode_available_models() presumably parses these names — confirm before renaming.
    save_path = f"{MODEL_DIR}/vrnn_{_tag(LEARNING_RATE)}_{hidden_size}_{_tag(om_prob)}_{_tag(kl_loss)}"
    trained_model_state, training_metrics = train_model(state,
                                                        train_dataloaders[om_prob],
                                                        train_step_fun=disrnn_value_trainstep,
                                                        num_epochs=NUM_EPOCHS,
                                                        stop_training=STOP_TRAINING,
                                                        print_every_other=1,
                                                        save_path=save_path)

## Analysis

In [None]:
from belief_utils import decode_available_models
from belief_analyze import calc_mse, multi_value_analyze, plot_mses

# Evaluate every model configuration reported by decode_available_models() and
# plot the resulting MSEs. The first argument to multi_value_analyze selects the
# setting to evaluate on (presumably omission probability 0.0 — confirm).
analysis_om_prob = 0.0

model_confs = decode_available_models()
trial_dict = multi_value_analyze(analysis_om_prob, model_confs)

mses = calc_mse(trial_dict)
plot_mses(mses, model_confs)

In [None]:
from belief_utils import calc_rpe_groups, extract_last_vals
from belief_analyze import plot_compare_rpes

# Group the RPEs (presumably reward-prediction errors) of each analyzed model;
# the result keeps the same keys as trial_dict.
rpegroup_dict = {model_key: calc_rpe_groups(trials) for model_key, trials in trial_dict.items()}

plot_compare_rpes(extract_last_vals(rpegroup_dict))