### Import libraries

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import os
import data
import rnn
import warnings
warnings.filterwarnings('ignore')

### Select compute device (GPU if available)

In [2]:
# Prefer the default CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
# NOTE(review): the cells shown here never use `device` — the checkpoint is loaded
# with map_location='cpu' and inference runs on the CPU. Confirm whether this is intended.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Load test data and trained model

In [3]:
# Locations of the synthetic test set and the trained tRNN checkpoint,
# relative to the notebook's working directory.
TEST_DATA_PATH = os.path.join('data', 'synth_test.csv')
CHECKPOINT_PATH = os.path.join('checkpoint', 'synth_trnn_train.pth')

# Test set carrying the 'action', 'alpha' and 'beta' targets.
ds = data.LabeledDataset(['action', 'alpha', 'beta'], path=TEST_DATA_PATH)

# One sample per batch, in file order, so per-sample losses can be indexed later.
test_loader = DataLoader(ds, batch_size=1, shuffle=False)

# GRU architecture. These hyperparameters must match those used in training:
# load_state_dict (strict by default) raises if parameter names/shapes differ.
# NOTE(review): input_size is nactions + 1 — presumably the action encoding plus
# one extra input feature; confirm against rnn.GRU / data.LabeledDataset.
model = rnn.GRU(
    input_size=ds.nactions + 1,
    hidden_size=32,
    alpha_embedding_size=5,
    beta_embedding_size=5,
    output_size=ds.nactions,
    dropout=0,
)

# Restore the trained weights onto the CPU, whatever device they were saved from.
# NOTE(review): torch.load is pickle-based — only load checkpoints from trusted sources.
cp = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(cp['model_state'])

# Evaluation mode: layers such as dropout switch to their inference behaviour.
model.eval()

print(f"Number of epochs in training: {cp['epoch']}")

Number of epochs in training: 10


### Evaluate model on test data

In [4]:
# Per-sample test losses for each scored model head. test_loader is built with
# batch_size=1, so each slot holds the loss of a single test sample.
# NOTE(review): alpha is scored with BCELoss below, so it is labelled 'BCE alpha'
# (it used to say 'MSE alpha', which misreported the metric). If MSE was the
# intended metric, change the criterion rather than the label.
n_samples = len(test_loader)
loss = {'action': {'name': 'BCE', 'values': np.zeros(n_samples)},
        'alpha': {'name': 'BCE alpha', 'values': np.zeros(n_samples)},
        'beta': {'name': 'MSE beta', 'values': np.zeros(n_samples)}}

# Loss criteria hold no state between calls, so build them once, not per sample.
bce = nn.BCELoss()
mse = nn.MSELoss()

# Evaluate model on test data. Nothing here calls backward(), so disable autograd
# to avoid recording a computation graph for every forward pass.
with torch.no_grad():
    for i, (X, y_true) in enumerate(test_loader):
        # Forward pass; only the action, alpha and beta outputs are scored.
        y_action, _, _, y_alpha, y_beta, _, _ = model(X)

        # The 1e-10 offset keeps the action prediction from being exactly 0.
        loss['action']['values'][i] = bce(y_action[:, :, 0] + 1e-10, y_true[0][:, :, 0]).item()
        loss['alpha']['values'][i] = bce(y_alpha, y_true[1]).item()
        # Dividing both prediction and target by 10 scales the MSE by 1/100.
        loss['beta']['values'][i] = mse(y_beta / 10, y_true[2] / 10).item()

# Report mean +/- (population) standard deviation across test samples.
for key in loss:
    print(f"tRNN {loss[key]['name']} loss: {loss[key]['values'].mean():.5f} +/- {loss[key]['values'].std():.5f}")

tRNN BCE loss: 0.41755 +/- 0.16004
tRNN MSE alpha loss: 0.70647 +/- 0.04917
tRNN MSE beta loss: 0.18445 +/- 0.21062
