In [1]:
import torch

from dem_sim.datasets import SampleDataset, StepDataset
from dem_sim.utils import periodic_difference

data_dir = '/Users/aronjansen/Documents/grainsData/raw/'
filename = 'simState_path_sampling_5000_graphs_reformatted.hdf5'
sample_dataset = SampleDataset(data_dir + filename)
step_dataset = StepDataset(sample_dataset)

In [2]:
pos_diffs = [
    periodic_difference(sample_dataset[sample].positions[:-1], sample_dataset[sample].positions[1:], 
                        torch.unsqueeze(sample_dataset[sample].domain[1:], dim=1))
    for sample in range(len(sample_dataset))]
vel_diffs = [
    sample_dataset[sample].velocities[:-1][:3] - sample_dataset[sample].velocities[1:][:3] 
    for sample in range(len(sample_dataset))]
vel_ang_diffs = [
    sample_dataset[sample].velocities[:-1][3:] - sample_dataset[sample].velocities[1:][3:]
    for sample in range(len(sample_dataset))]

In [3]:
def get_rmse(diffs):
    # here mean is taken over the nodes, features (i.e. components) and timesteps.
    # the first two we want, the latter we do at the end, so as not to make steps in longer samples count less
    tot_sq_diffs = sum(diff.shape[0] * (diff**2).mean() for diff in diffs)
    mse = tot_sq_diffs / len(step_dataset)
    return mse.sqrt()

pos_rmse = get_rmse(pos_diffs)
vel_rmse = get_rmse(vel_diffs)
vel_ang_rmse = get_rmse(vel_ang_diffs)

In [4]:
pos_rmse, vel_rmse, vel_ang_rmse

(tensor(0.0016), tensor(0.0017), tensor(0.0176))