## Imports

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

from utils.data_utils import (
    Mode,
    get_features,
    get_labels,
    get_mixed_trials_data,
    normalize_train_labels,
    normalize_val_labels,
)

# Fix the randomness so runs are reproducible: seed the global RNGs used by
# Keras/TensorFlow and force deterministic TF ops (these can be slower).
SEED = 33
tf.keras.utils.set_random_seed(SEED)
tf.config.experimental.enable_op_determinism()

## Read data

In [None]:
# Sizes passed to `get_features` in the next cell. The names suggest agents
# per split and trials per agent — confirm against utils.data_utils.
N_TRAIN_AGENT = 30_000
N_VAL_AGENT = 3_000
NUM_TRIAL = 500

train_file = "train_file.csv"
val_file = "val_file.csv"

# Load both splits as raw DataFrames (train first, then validation).
train_data, val_data = (pd.read_csv(path) for path in (train_file, val_file))

## Process data

In [None]:
# NOTE(review): `mode` is not assigned anywhere in this notebook. It must be set
# to a `Mode` member (imported in the first cell) before this cell runs — TODO
# confirm which member is intended.
all_train_features = get_features(train_data, N_TRAIN_AGENT, NUM_TRIAL, mode=mode)
all_val_features = get_features(val_data, N_VAL_AGENT, NUM_TRIAL, mode=mode)

# Bring every split to a single trial length (padding if necessary).
target_trial = 500
all_trials = [target_trial]
train_features, val_features = (
    get_mixed_trials_data(split_features, all_trials, mode=mode)
    for split_features in (all_train_features, all_val_features)
)

# Labels: normalize_train_labels also returns the per-label scalers, which are
# then reused to normalize the validation labels.
train_name_to_labels = get_labels(train_data, mode)
normalized_train_labels, name_to_scaler = normalize_train_labels(train_name_to_labels)

val_name_to_labels = get_labels(val_data, mode)
normalized_val_labels = normalize_val_labels(val_name_to_labels, name_to_scaler)

# One network output per label name.
output_dim = len(val_name_to_labels)
print(train_features.shape, output_dim)

## Model Training

### Hyperparameter search (Bayesian optimization)

In [None]:
# `hp` was used below without being imported (NameError on a fresh kernel).
from hyperopt import hp

from model import get_gru_model
from utils.hypetune import bayesian_search

# Search space handed to `bayesian_search`. Plain values stay fixed; entries
# built with `hp.*` are hyperopt stochastic expressions sampled per trial.
# TODO: move to a config file.
param_grid = {
    # Input shape comes from the padded training features.
    "input_x": train_features.shape[1],
    "input_y": train_features.shape[2],
    "output_dim": output_dim,
    # hp.randint("units", 128) samples an integer in [0, 128), so units spans [64, 191].
    "units": 64 + hp.randint("units", 128),
    "learning_rate": 3e-4,
    "dropout": hp.uniform("dropout", 0.15, 0.25),
    "dropout1": hp.uniform("dropout1", 0.01, 0.1),
    "dropout2": hp.uniform("dropout2", 0.01, 0.05),
    "epochs": 25,
    "batch_size": 256,
}

# Search over `param_grid`, scoring candidates on the validation split.
best_model, best_params = bayesian_search(
    get_gru_model,
    param_grid,
    train_features,
    normalized_train_labels,
    val_features,
    normalized_val_labels,
)

print(f"Found best parameters {best_params}")

### Train the best model

In [None]:
from keras.callbacks import EarlyStopping

# Stop once val_loss has not improved for 10 epochs and roll back to the best weights.
callbacks = [EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True)]

# Train the model returned by the search further, with up to 200 epochs.
# `batch_size` was not defined anywhere in the notebook; use the same fixed
# batch size as the search space.
history = best_model.fit(
    train_features,
    normalized_train_labels,
    epochs=200,
    batch_size=param_grid["batch_size"],
    callbacks=callbacks,
    validation_data=(val_features, normalized_val_labels),
    verbose=2,
)

## Model Evaluation

### Prepare test data

In [None]:
TEST_FILE = "test_file.csv"
# TODO(review): placeholder — set to the number of agents in TEST_FILE.
N_TEST_AGENT = 3000

test_data = pd.read_csv(TEST_FILE)
raw_test_features = get_features(test_data, N_TEST_AGENT, NUM_TRIAL, mode=mode)
# Apply the same trial padding as train/val so the inputs match the shape the
# model was built for (input_x/input_y come from the padded train features).
all_test_features = get_mixed_trials_data(raw_test_features, all_trials, mode=mode)

test_name_to_labels = get_labels(test_data, mode)
# Normalize with the scalers returned by normalize_train_labels (train split).
normalized_test_labels = normalize_val_labels(test_name_to_labels, name_to_scaler)

print(all_test_features.shape, normalized_test_labels.shape)

### Predict parameters

In [None]:
# Model outputs for the test set, in the normalized label space used for training.
all_prediction = best_model.predict(x=all_test_features)

### Plot parameter recovery

In [None]:
from utils.data_utils import get_recovered_parameters
from utils.plotting import plot_recovery

# Combine the true test labels with the model's predictions. The scalers are
# passed along, presumably to undo label normalization — confirm in
# utils.data_utils. (Previously referenced the undefined name `prediction`.)
all_test_param = get_recovered_parameters(
    name_to_scaler, test_name_to_labels, all_prediction
)

# One recovery plot per parameter of interest.
for param_name in ("alpha", "beta"):
    plot_recovery(all_test_param, param_name)