In [1]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from dynamic_model.dataset import load_dataset
from dynamic_model.dataset import DynamicsModelDataset
from dynamic_model.model import DynamicsLookAheadModel
from dynamic_model.train import train_model

In [2]:
# Fix random seeds up front. Weight initialisation, DataLoader shuffling and
# dropout are all stochastic, so without this every Restart & Run All trains a
# different model.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# Prefer the first GPU when one is visible; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print(f"Using Device: {device}")

# Let cuDNN benchmark candidate algorithms and cache the fastest one.
# This trades bit-exact reproducibility for speed, even with the seeds above.
torch.backends.cudnn.benchmark = True

Using Device: cuda:0


## Load Dataset

In [3]:
# Location of the saved sample file and the keys of the input/output arrays in it.
dataset_input_key = 'merged_input'
dataset_output_key = 'merged_output'
dataset_file_path = 'tmp/ib-out/ib-samples-la.npy'

# The model below is constructed with this fixed batch_size, so (presumably)
# every batch it sees must be full. That is why drop_last=True is kept for
# BOTH loaders.
batch_size = 8000
params = {'batch_size': batch_size,
          'shuffle': True,
          # load_dataset is given `device` and presumably places the tensors
          # there. CUDA tensors can't be handed to worker subprocesses, so
          # loading stays in the main process. Confirm before raising this.
          'num_workers': 0,
          'drop_last': True}

# Evaluation must see the same samples in the same order every epoch.
# With shuffle=True and drop_last=True, a split whose size is not a multiple of
# batch_size would drop a *different* random remainder each epoch. That makes
# the per-epoch test loss incomparable, so do not shuffle validation data.
val_params = {**params, 'shuffle': False}

# load_dataset prints the (samples, time steps, features) shape of each array
# it returns (see the output below).
train_dataset, val_dataset = load_dataset(file_path=dataset_file_path,
                                          input_key=dataset_input_key,
                                          output_key=dataset_output_key,
                                          dataset_class=DynamicsModelDataset,
                                          device=device)
train_loader = DataLoader(train_dataset, **params)
val_loader = DataLoader(val_dataset, **val_params)

(672000, 39, 9)
(672000, 10, 6)
(288000, 39, 9)
(288000, 10, 6)


## Define model

In [4]:
# Read the model dimensions off the training dataset so they always match the
# sample file loaded above. Only the hidden width is chosen by hand.
num_of_features = train_dataset.get_features_size()
seq_len = train_dataset.get_seq_len()
hidden_size = 32
out_size = train_dataset.get_output_feature_size()
look_ahead = train_dataset.get_look_ahead_size()

# NOTE(review): this reports 30, while the raw input array loaded above has 39
# time steps. Presumably the remainder is the look-ahead window; confirm in
# DynamicsModelDataset.
print(seq_len)

model = DynamicsLookAheadModel(
    features=num_of_features,
    hidden_size=hidden_size,
    out_size=out_size,
    # Must equal the DataLoader batch size; both loaders use drop_last=True so
    # no short batch ever reaches the model.
    batch_size=batch_size,
    seq_len=seq_len,
    look_ahead=look_ahead,
    n_layers=1,
    # NOTE(review): if dropout_p is forwarded to a torch RNN's `dropout` argument,
    # it has no effect with a single layer. Check dynamic_model.model.
    dropout_p=0.7,
).to(device=device)

30


In [5]:
# Train on train_loader and report the val_loader loss after every epoch.
# NOTE: the log below shows ~5.7 s per epoch, so 10000 epochs is roughly 16 hours.
# The saved run was interrupted manually after epoch 8. Consider lowering
# n_epochs or adding early stopping / checkpointing before a full re-run.
train_model(
    model=model,
    train_loader=train_loader,
    test_loader=val_loader,
    learning_rate=1e-3,
    n_epochs=10000,
)

Untrained test
--------
Test loss: 13532.750569661459

Epoch 0
---------
Train loss: 13444.266624813989
Test loss: 13189.661295572916
Epoch time: epoch_time = 5.810s
Epoch 1
---------
Train loss: 13147.892054966518
Test loss: 12914.208468967014
Epoch time: epoch_time = 5.703s
Epoch 2
---------
Train loss: 12875.033238002232
Test loss: 12637.397135416666
Epoch time: epoch_time = 5.646s
Epoch 3
---------
Train loss: 12601.80267624628
Test loss: 12375.182996961805
Epoch time: epoch_time = 5.701s
Epoch 4
---------
Train loss: 12345.951892671132
Test loss: 12125.686469184027
Epoch time: epoch_time = 5.664s
Epoch 5
---------
Train loss: 12100.652704148066
Test loss: 11885.61083984375
Epoch time: epoch_time = 5.679s
Epoch 6
---------
Train loss: 11864.35804966518
Test loss: 11653.949435763889
Epoch time: epoch_time = 5.721s
Epoch 7
---------
Train loss: 11636.098051525298
Test loss: 11430.122992621527
Epoch time: epoch_time = 5.682s
Epoch 8
---------
Train loss: 11415.47705078125
Test loss: 1

KeyboardInterrupt: 