In [1]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from dynamic_model.dataset import DynamicsModelDataset
from dynamic_model.model import DynamicsModel
from dynamic_model.train import train_model




In [2]:
# Fix the global RNGs so the train/validation split (sklearn's train_test_split draws from
# NumPy's global RNG when random_state is None), DataLoader shuffling and (typically) weight
# initialisation are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and every CUDA device's generator

# Prefer the first GPU when one is visible; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print(f"Using Device: {device}")
# Let cuDNN benchmark candidate kernels and cache the fastest one per input shape.
# Note: benchmarking may select non-deterministic algorithms, so GPU runs can still
# differ slightly even with the seeds above.
torch.backends.cudnn.benchmark = True

Using Device: cuda:0


## Load Dataset

In [3]:

def load_dataset(file_path, input_key, output_key, test_size=0.3, device=device, random_state=None):
    """Load paired input/target arrays from a ``.npy`` file and split them into train/validation datasets.

    Args:
        file_path: Path to a ``.npy`` file whose contents, after ``[()]`` unwrapping, can be
            indexed by ``input_key`` and ``output_key`` (presumably a dict saved with
            ``np.save`` -- confirm against the producer of the file).
        input_key: Key of the model inputs inside the stored object.
        output_key: Key of the model targets inside the stored object.
        test_size: Fraction (float) or absolute count (int) of samples held out for
            validation; forwarded to ``sklearn.model_selection.train_test_split``.
        device: Device forwarded to ``DynamicsModelDataset``. The default is the notebook's
            ``device`` as it was when this cell was executed (defaults bind at ``def`` time).
        random_state: Seed forwarded to ``train_test_split``. ``None`` (the default) keeps
            the previous behaviour of drawing from NumPy's global RNG; pass an int for a
            reproducible split independent of global state.

    Returns:
        ``(train_dataset, validation_dataset)`` as ``DynamicsModelDataset`` instances.
    """
    # SECURITY: allow_pickle=True unpickles the file contents, which can execute arbitrary
    # code. Only load files produced by a trusted pipeline.
    dataset = np.load(file_path, allow_pickle=True)
    # ``[()]`` unwraps a 0-d (object) array into the stored Python object.
    samples = dataset[()]
    x = samples[input_key]
    y = samples[output_key]
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=test_size, random_state=random_state
    )
    train_dataset = DynamicsModelDataset(x_train, y_train, device)
    validation_dataset = DynamicsModelDataset(x_test, y_test, device)
    return train_dataset, validation_dataset


# Location of the sample file (relative to the notebook's working directory) and the two
# keys that select the model inputs and targets from the object stored in it.
dataset_file_path = 'tmp/ib-out/ib-samples.npy'
dataset_input_key = 'merged_input'
dataset_output_key = 'merged_output'




In [4]:

batch_size = 4000
# Shared loader settings. drop_last=True keeps every batch at exactly `batch_size` samples;
# presumably required because the model below is constructed with a fixed batch_size
# (confirm in DynamicsModel).
params = {'batch_size': batch_size,
          'shuffle': True,
          'num_workers': 0,
          'drop_last': True}
# Validation must not shuffle. With shuffle=True plus drop_last=True, a different random
# subset of the validation split is dropped every epoch (whenever its size is not a
# multiple of batch_size). That makes the reported test loss noisy and not comparable
# across epochs.
val_params = {**params, 'shuffle': False}

train_dataset, val_dataset = load_dataset(file_path=dataset_file_path,
                                          input_key=dataset_input_key,
                                          output_key=dataset_output_key)
train_loader = DataLoader(train_dataset, **params)
val_loader = DataLoader(val_dataset, **val_params)

# With drop_last=True, a validation split smaller than one batch yields zero batches,
# leaving the test loss undefined.
assert len(val_loader) > 0, (
    f"validation split has fewer than batch_size={batch_size} samples; "
    "lower batch_size or raise test_size"
)

In [5]:
# Read the model dimensions from the training dataset so they always match the data.
num_of_features = train_dataset.get_features_size()
seq_len = train_dataset.get_seq_len()
hidden_size = 32  # passed to DynamicsModel as its hidden_size
out_size = train_dataset.get_output_feature_size()

# NOTE(review): batch_size is baked into the model, presumably for fixed-size hidden state;
# this is why the DataLoaders use drop_last=True. Confirm in DynamicsModel.
model = DynamicsModel(
    features=num_of_features,
    hidden_size=hidden_size,
    out_size=out_size,
    batch_size=batch_size,
    seq_len=seq_len,
).to(device)

In [None]:
# The recorded run below took ~6.6-6.8 s per epoch on the GPU, i.e. roughly 22 minutes total.
# NOTE(review): the captured output shows a warning raised at F.mse_loss before the first
# test loss. If it is PyTorch's "target size different to input size" warning, the loss is
# being broadcast. Check the model output shape against the target shape.
n_epochs = 200
train_model(
    model=model,
    train_loader=train_loader,
    test_loader=val_loader,
    n_epochs=n_epochs,
)

Untrained test
--------


  return F.mse_loss(input, target, reduction=self.reduction)


Test loss: 13199.896267361111

Epoch 0
---------
Train loss: 12872.911797337278
Test loss: 12489.693372938367
Epoch time: epoch_time = 6.753s
Epoch 1
---------
Train loss: 12219.658757858728
Test loss: 11907.002549913195
Epoch time: epoch_time = 6.641s
Epoch 2
---------
Train loss: 11671.171580297707
Test loss: 11382.214626736111
Epoch time: epoch_time = 6.819s
Epoch 3
---------
Train loss: 11156.507298215607
Test loss: 10873.40296766493
Epoch time: epoch_time = 6.570s
Epoch 4
---------
Train loss: 10680.396386140903
Test loss: 10413.439371744791
Epoch time: epoch_time = 6.652s
Epoch 5
---------
Train loss: 10253.113778198964
Test loss: 10014.389797634549
Epoch time: epoch_time = 6.615s
Epoch 6
---------
Train loss: 9860.352151904586
Test loss: 9628.198092990451
Epoch time: epoch_time = 6.712s
Epoch 7
---------
Train loss: 9503.916616586539
Test loss: 9285.29387749566
Epoch time: epoch_time = 6.550s
Epoch 8
---------
Train loss: 9186.511279585799
Test loss: 8976.188917371961
Epoch time