### Import libraries

In [1]:
import os
import numpy as np
import pandas as pd 
import qlearner as ql
import torch
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import torch.nn as nn
import data
import rnn
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import seaborn as sns

### Set theme for plots

In [3]:
# Seaborn theme: 'ticks' style with enlarged fonts and thicker lines.
# Axis labels and both tick-label sets share one font size.
label_size = 20
rc_defaults = {
    'figure.titlesize': 24,
    'axes.labelsize': label_size,
    'xtick.labelsize': label_size,
    'ytick.labelsize': label_size,
    'lines.linewidth': 3,
}
sns.set_theme(style='ticks', rc=rc_defaults)

### Simulate data

In [2]:
# Number of agents
train_nagents = 2000
test_nagents = 100

# Number of trials per agent
ntrials = 1000

# Number of trials per block
block_size = 100


def _simulate_agents(nagents, columns):
    """Simulate `nagents` 4-armed Q-learning agents, one dataframe per agent.

    Reads the module-level `ntrials` and `block_size` settings.

    Parameters
    ----------
    nagents : int
        Number of agents to simulate; agent ids run from 0 to nagents - 1.
    columns : list of str
        Columns requested from `QLearner4Armed.format_df` for each agent.

    Returns
    -------
    list
        One formatted dataframe per agent, in agent-id order.
    """
    frames = []
    for agent in range(nagents):
        # Set parameters.
        # NOTE: np.random.choice draws from NumPy's global RNG, so the order of
        # these calls is kept exactly as in the original per-agent loop.
        alpha_gen = ql.ParamGeneratorStat(name='alpha', func=np.random.uniform, low=0, high=1,
                                          ntrials=ntrials,
                                          max_switch=np.random.choice([2, 3, 4]),
                                          prob_to_switch=0.005)
        beta_gen = ql.ParamGeneratorNonStat(name='beta', func=np.random.uniform, low=0, high=10,
                                            ntrials=ntrials,
                                            max_switch=np.random.choice([2, 3, 4]),
                                            prob_to_switch=0.005)

        # Simulate data for agent
        qlearner = ql.QLearner4Armed(agent, alpha_gen=alpha_gen, beta_gen=beta_gen,
                                     ntrials=ntrials, block_size=block_size).simulate()

        # Keep only the requested columns
        frames.append(qlearner.format_df(columns=columns))
    return frames


# Make sure the output directory exists before any CSV is written
os.makedirs('data', exist_ok=True)

# Simulate training data (includes block index and binned parameters)
train_data = _simulate_agents(train_nagents,
                              ['agent', 'trial', 'block', 'action', 'reward',
                               'alpha', 'beta', 'alpha_bin', 'beta_bin'])

# Concatenate training data into one dataframe with a fresh 0..n-1 index
df_train_data = pd.concat(train_data).reset_index(drop=True)

# Save data
fname = os.path.join('data', 'synth_train_4armed.csv')
df_train_data.to_csv(fname, index=False)

# Simulate test data; the agent count comes from test_nagents so it stays in
# sync with the evaluation arrays that are sized by the same setting
test_data = _simulate_agents(test_nagents,
                             ['agent', 'trial', 'action', 'reward', 'alpha', 'beta'])

# Concatenate test data into one dataframe with a fresh 0..n-1 index
df_test_data = pd.concat(test_data).reset_index(drop=True)

# Save data
fname = os.path.join('data', 'synth_test_4armed.csv')
df_test_data.to_csv(fname, index=False)

# Inspect test data: parameter trajectories of the first 10 agents
for param in ['alpha', 'beta']:
    sns.relplot(data=df_test_data[df_test_data['agent'].isin(np.arange(10))],
                kind='line', x='trial', y=param, col='agent')

### Train model

In [None]:
# Build the labeled dataset from the simulated training CSV
label_columns = ['action', 'alpha_bin', 'beta_bin', 'alpha', 'beta']
train_csv = os.path.join('data', 'synth_train_4armed.csv')
ds = data.LabeledDataset(label_columns, path=train_csv)

# Random 80/20 split into training and validation subsets
train_ds, val_ds = random_split(ds, [0.8,  0.2])

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GRU model; input and output sizes are derived from the dataset
gru_kwargs = {
    'input_size': ds.nactions+1,
    'hidden_size': 32,
    'alpha_embedding_size': ds.nbins_alpha,
    'beta_embedding_size': ds.nbins_beta,
    'output_size': ds.nactions,
    'dropout': 0.2,
}
model = rnn.GRU(**gru_kwargs)

# Data loaders for training and validation data (one item per batch, fixed order)
loader_kwargs = {'shuffle': False, 'batch_size': 1}
train_loader = DataLoader(train_ds, **loader_kwargs)
val_loader = DataLoader(val_ds, **loader_kwargs)

# Train the model for 5 epochs under the run name 'synth_trnn_4armed'
model, train_loss, val_loss = rnn.training_loop(
    model, device, train_loader, val_loader, 'synth_trnn_4armed', nepochs=5)

### Evaluate model

In [None]:
# Create DL dataset
ds = data.LabeledDataset(['action', 'alpha', 'beta'], df=df_test_data)

# Instantiate Data Loader for test data.
# Fixed order and batch_size=1: batch i fills entry/row i of the arrays below.
test_loader = DataLoader(ds, shuffle=False, batch_size=1)

# Instantiate RNN model (no dropout at inference).
# NOTE(review): the embedding sizes are hard-coded to 5; they presumably have to
# match ds.nbins_alpha / ds.nbins_beta used at training time -- confirm.
model = rnn.GRU(input_size=ds.nactions+1,
                hidden_size=32,
                alpha_embedding_size=5,
                beta_embedding_size=5,
                output_size=ds.nactions,
                dropout=0)

# Load checkpoint for model inference (tensors are mapped onto the CPU)
cp = torch.load(os.path.join('checkpoint', 'synth_trnn_4armed_train.pth'),
                map_location='cpu')
print(f"Number of epochs in training: {cp['epoch']}")

# Assign model state
model.load_state_dict(cp['model_state'])

# Change to evaluation mode
model.eval();

# Initialize loss dictionary (one value per test agent)
loss = {'action': {'name': 'BCE', 'values': np.zeros(test_nagents)},
        'alpha': {'name': 'MSE alpha', 'values': np.zeros(test_nagents)},
        'beta': {'name': 'MSE beta', 'values': np.zeros(test_nagents)}}

# Initialize model predictions
y_alpha_all = np.zeros((test_nagents, ntrials))
y_beta_all = np.zeros((test_nagents, ntrials))

# Loss modules are stateless, so build them once instead of once per batch
bce_loss = nn.BCELoss()
mse_loss = nn.MSELoss()

# Evaluate model on test data.
# Gradients are never used here, so disable autograd to avoid building a
# computation graph for every forward pass.
with torch.no_grad():
    for i, (X, y_true) in enumerate(test_loader):
        # Forward pass
        y_action, _, _, y_alpha, y_beta, _, _ = model(X)

        # Compute loss
        # NB: tolerance (1e-10) and division by 10 are based on the original implementation
        loss['action']['values'][i] = bce_loss(y_action[:, :, 0]+1e-10, y_true[0][:, :, 0]).item()
        loss['alpha']['values'][i] = mse_loss(y_alpha, y_true[1]).item()
        loss['beta']['values'][i] = mse_loss(y_beta/10, y_true[2]/10).item()

        # Save predictions for plotting
        y_alpha_all[i, :] = y_alpha.cpu().detach().numpy().flatten()
        y_beta_all[i, :] = y_beta.cpu().detach().numpy().flatten()

# Print losses as mean +/- standard deviation over the test set
for entry in loss.values():
    print(f"tRNN {entry['name']} loss: {entry['values'].mean():.5f} +/- {entry['values'].std():.5f}")

### Plot parameter recovery

In [None]:
# Ground-truth parameters, tagged as 'Target'
targets = df_test_data[['agent', 'trial', 'alpha', 'beta']].assign(data='Target')

# Model estimates laid out on the same (agent, trial) rows, tagged as 't-RNN'
predictions = df_test_data[['agent', 'trial']].assign(data='t-RNN',
                                                      alpha=y_alpha_all.flatten(),
                                                      beta=y_beta_all.flatten())

# Stack targets and predictions into one long dataframe
df_recovery = pd.concat([targets, predictions])

# Only the first three agents are plotted
first_agents = df_recovery[df_recovery['agent'].isin(np.arange(3))]

# One row of panels per parameter, each with its own y-axis range
for param, y_range in (('alpha', (-0.2, 1.4)), ('beta', (-0.1, 10.1))):
    g = sns.relplot(data=first_agents,
                    kind='line', x='trial', y=param, col='agent', hue='data',
                    height=1.5, aspect=2.2)
    g.set(xlim=(0, ntrials), ylim=y_range)