### Import libraries

In [1]:
# Standard library
import os
import warnings

# Third-party
import torch
from torch.utils.data import DataLoader, random_split

# Project-local modules (provide data.LabeledDataset, rnn.GRU and rnn.training_loop)
import data
import rnn

# Silence all warnings for the rest of the session.
# NOTE(review): this is a blanket filter, installed only after the imports above
# (warnings raised while importing `data`/`rnn` are still shown). It also hides
# deprecation and numerical warnings from torch — consider narrowing it to the
# specific category/module that is actually noisy.
warnings.filterwarnings('ignore')

### Set device for analysis

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Load training data and split into train and validation sets

In [3]:
# Create the labelled dataset from the synthetic training CSV, requesting the
# action column plus the alpha/beta label columns.
ds = data.LabeledDataset(['action', 'alpha_bin', 'beta_bin', 'alpha', 'beta'],
                         path=os.path.join('data', 'synth_trainset_20.csv'))

# Split 80/20 into training and validation sets. Use a dedicated, seeded
# generator: without one, random_split draws from the unseeded global RNG and
# the split (and therefore every downstream result) changes on each kernel restart.
SPLIT_SEED = 42
train_ds, val_ds = random_split(ds, [0.8, 0.2],
                                generator=torch.Generator().manual_seed(SPLIT_SEED))

### Train and validate RNN model

In [4]:
# Hyperparameters for this run (named here rather than inlined as magic numbers).
HIDDEN_SIZE = 32
DROPOUT = 0.2
NEPOCHS = 10
RUN_NAME = 'synth_trnn'  # passed to rnn.training_loop; presumably a run/checkpoint name — confirm
TRAIN_SEED = 42

# Seed the global torch RNG (all devices) before building the model, so weight
# initialisation, dropout masks and the training-loader shuffle order are
# reproducible across kernel restarts.
torch.manual_seed(TRAIN_SEED)

# Instantiate RNN model.
# NOTE(review): input_size is nactions + 1 — document what the extra input channel is.
model = rnn.GRU(input_size=ds.nactions + 1,
                hidden_size=HIDDEN_SIZE,
                alpha_embedding_size=ds.nbins_alpha,
                beta_embedding_size=ds.nbins_beta,
                output_size=ds.nactions,
                dropout=DROPOUT)

# Instantiate data loaders for training and validation data. The training set is
# reshuffled every epoch (drawn from the seeded global RNG) so the model does not
# see the samples in the same fixed order each epoch. Validation order presumably
# does not affect the averaged validation loss, so it stays fixed.
train_loader = DataLoader(train_ds, shuffle=True, batch_size=1)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=1)

# Train RNN model. training_loop returns the model plus training and validation
# loss values. `device` is handed to it; it presumably moves the model and
# batches there — confirm in rnn.training_loop.
model, train_loss, val_loss = rnn.training_loop(model, device, train_loader, val_loader,
                                                RUN_NAME, nepochs=NEPOCHS)

loss BCE action: 0.6666927747428417
loss CE alpha: 0.32022438757121563
loss CE beta: 0.32461406849324703
loss MSE alpha: 0.8068663602753077
loss MSE beta: 2.4148902175948024
Step 1, Train Loss 4.533287808677414, Val Loss 5.289989732205868
loss BCE action: 0.6396625638008118
loss CE alpha: 0.32214922085404396
loss CE beta: 0.32170724496245384
loss MSE alpha: 0.7475513247009076
loss MSE beta: 2.3444723011925817
Step 2, Train Loss 4.375542655510799, Val Loss 5.149811643641442
loss BCE action: 0.5989931523799896
loss CE alpha: 0.3224239610135555
loss CE beta: 0.32248274981975555
loss MSE alpha: 0.6801362964033615
loss MSE beta: 2.2954571917653084
Step 3, Train Loss 4.219493351381971, Val Loss 4.971934332192177
loss BCE action: 0.5271183643490076
loss CE alpha: 0.321669640019536
loss CE beta: 0.3252326771616936
loss MSE alpha: 0.6210815255853959
loss MSE beta: 2.231542559340596
Step 4, Train Loss 4.026644766456229, Val Loss 4.804092554375529
loss BCE action: 0.4895580792799592
loss CE alpha