In [1]:
import numpy as np
import matplotlib.pyplot as plt

import sys

if '..' not in sys.path:
    sys.path.append('..')

from data import ThreeBalls2DFreeFallDataset

import torch
from torch.utils.data import DataLoader

import reservoirpy as rpy
from reservoirpy.nodes import Input, Reservoir, Ridge, ReLU, ESN
from reservoirpy.observables import mse

from itertools import product

SEED = 42

# Silence reservoirpy's progress output.
rpy.verbosity(0)
# Seed reservoirpy, numpy's global RNG and torch explicitly so that
# Restart & Run All reproduces the same reservoir weights and data handling.
rpy.set_seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr in the notebook

In [2]:
def _plot_first_trajectories(X_np, y_np, pred, n_samples):
    """Plot ground truth vs. prediction for up to 10 samples of one batch on two 2x5 grids.

    The first figure plots the odd-indexed components of each series against their
    position (titled 'Y-coordinates over time'). The second plots odd- vs. even-indexed
    components.

    Args:
        X_np: squeezed input batch as a numpy array; row ``i`` belongs to sample ``i``.
        y_np: squeezed target batch as a numpy array.
        pred: model output for ``X_np``.
        n_samples: number of samples in the batch; at most 10 are plotted.
    """
    n_rows, n_cols = 2, 5
    fig_y, axs_y = plt.subplots(n_rows, n_cols, figsize=(20, 10))
    _, axs_xy = plt.subplots(n_rows, n_cols, figsize=(20, 10))
    # Title the first figure itself. plt.title() would target the *current* axes,
    # which after the second plt.subplots() call is the last subplot of the second figure.
    fig_y.suptitle('Y-coordinates over time')

    pred_np = np.asarray(pred).squeeze()
    # Never index past the end of a short batch.
    for count in range(min(n_rows * n_cols, n_samples)):
        # Both series are prefixed with elements [-3:-1] of the sample's input row.
        # NOTE(review): presumably the last observed (x, y) point -- confirm input layout.
        start = X_np[count][-3:-1]
        gr = np.insert(y_np[count].reshape(1, -1)[0], 0, start)
        pr = np.insert(pred_np[count].reshape(1, -1)[0], 0, start)

        ax_y = axs_y[count // n_cols][count % n_cols]
        ax_y.plot(gr[1::2], label='Ground truth')
        ax_y.plot(pr[1::2], label='Predicted')
        ax_y.legend()

        ax_xy = axs_xy[count // n_cols][count % n_cols]
        ax_xy.plot(gr[0::2], gr[1::2], label='Ground truth')
        ax_xy.plot(pr[0::2], pr[1::2], label='Predicted')
        ax_xy.legend()

    plt.show()


def calculate_test_loss(model, dataloader, dataset, visualize_first_10_trajectories=True):
    """Compute the test loss of ``model`` over ``dataloader`` and optionally plot sample trajectories.

    The loss is sqrt(sum_b mse_b * |b| / len(dataset)): each batch's ``mse`` is
    weighted by its size, the sum is averaged over the dataset, and the square
    root is taken.

    Args:
        model: object exposing ``run(array) -> array`` (e.g. a reservoirpy model).
        dataloader: iterable of ``(X, y)`` torch tensor batches.
        dataset: the dataset behind ``dataloader``; only ``len(dataset)`` is used.
        visualize_first_10_trajectories: if True, plot up to 10 samples of the
            first batch using the same predictions that entered the loss.

    Returns:
        The test loss as a numpy float (nan if ``dataset`` is empty).
    """
    weighted_errors = []
    first_batch = None
    for X, y in dataloader:
        X_np = X.squeeze().numpy()
        y_np = y.squeeze().numpy()
        # Run the model exactly once per batch. Re-running a stateful reservoir later
        # would continue from its leftover state and plot different predictions
        # than the ones the loss was computed on.
        pred = model.run(X_np)
        weighted_errors.append(mse(pred, y_np) * len(y))
        if first_batch is None:
            first_batch = (X_np, y_np, pred, len(y))

    test_loss = (np.sum(weighted_errors) / len(dataset)) ** 0.5
    print('Test loss: ', test_loss)

    if visualize_first_10_trajectories and first_batch is not None:
        _plot_first_trajectories(*first_batch)

    return test_loss

In [None]:
# Load the three-balls 2-D free-fall dataset, hold out 20% for testing,
# and wrap both splits in unshuffled mini-batch loaders.
batch_size = 32
DATA_PATH = '../data/raw/three-balls-2d-free-fall'

train_data, test_data = ThreeBalls2DFreeFallDataset.train_test_split(DATA_PATH, test_frac=0.2)

train_dataloader, test_dataloader = (
    DataLoader(dataset=split, batch_size=batch_size)
    for split in (train_data, test_data)
)