In [12]:
import numpy as np
import torch
from torch.utils.data import Dataset

In [53]:
class PackedTrajectoryDataset(Dataset):
    """Variable-length trajectories stored back to back in one packed array.

    Trajectory ``k`` of the packed array occupies columns
    ``offsets[k]:offsets[k + 1]`` (axis 1); this dataset exposes the subset of
    trajectories selected by ``indices``, in that order.

    Indexing returns a zero-padded batch of shape
    ``(n_samples, n_variables, max_length)``: ``n_samples == 1`` for a scalar
    index, otherwise one row per requested position, and ``max_length`` is the
    longest trajectory in that request. Padding is zeros, so callers that need
    the true lengths should consult ``traj_lengths``.

    Parameters
    ----------
    packed_trajectories : array-like, 2-D
        Packed data; axis 0 indexes variables, axis 1 the concatenated steps.
    offsets : array-like of int
        Trajectory boundaries into axis 1 of ``packed_trajectories``.
    indices : sequence of int
        Positions in ``offsets`` (each ``< len(offsets) - 1``) of the
        trajectories to expose.
    dtype : torch.dtype, optional
        Storage/output dtype. Defaults to torch's default float dtype, which is
        what the legacy ``torch.Tensor`` constructor produced. Pass e.g.
        ``torch.float64`` to avoid float32 rounding of large values.
    """

    def __init__(self, packed_trajectories, offsets, indices, dtype=None):
        if dtype is None:
            dtype = torch.get_default_dtype()
        # torch.tensor always copies, so later edits to the source array cannot leak in.
        self.packed_trajectories = torch.tensor(np.asarray(packed_trajectories), dtype=dtype)
        self.n_variables = self.packed_trajectories.shape[0]
        offsets = np.asarray(offsets)
        # Explicit integer dtype: np.array([]) would be float64 and unusable as an index.
        indices = np.asarray(indices, dtype=np.int64)
        self.len = len(indices)
        self.start_offsets = offsets[indices]
        self.end_offsets = offsets[indices + 1]
        self.traj_lengths = self.end_offsets - self.start_offsets

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return the requested trajectories as a zero-padded batch tensor.

        ``idx`` may be a Python/NumPy integer or 0-d tensor (batch of one), a
        slice, or a sequence/array of integers. Out-of-range positions raise
        ``IndexError`` from NumPy indexing.
        """
        if isinstance(idx, slice):
            idx = range(*idx.indices(self.len))
        # atleast_1d turns any scalar index (int, np.integer, 0-d tensor) into a batch of one.
        idx = np.atleast_1d(np.asarray(idx))
        if idx.size == 0:
            # An empty list becomes a float array; make it a valid (empty) index.
            idx = idx.astype(np.int64)
        n_samples = len(idx)
        traj_lengths = self.traj_lengths[idx]
        max_length = int(traj_lengths.max()) if n_samples else 0

        result = torch.zeros(n_samples, self.n_variables, max_length, dtype=self.packed_trajectories.dtype)
        for row, pos in enumerate(idx):
            start, end = int(self.start_offsets[pos]), int(self.end_offsets[pos])
            result[row, :, :end - start] = self.packed_trajectories[:, start:end]
        return result

In [4]:
# Packed archive produced upstream; np.load on an .npz returns an NpzFile whose
# member arrays are read when accessed by key in the cells below.
DATA_FILE = 'data_packed.npz'
data = np.load(DATA_FILE)
print(data.files)

['data_names', 'label_names', 'lattice_names', 'offsets', 'labels', 'traj_data', 'config_indices']


In [6]:
traj = data['traj_data']
offsets = data['offsets']
# Trajectory k is later sliced as traj[:, offsets[k]:offsets[k + 1]]; fail fast
# here if the offsets cannot describe such a packing. (Compare neighbours
# directly rather than via np.diff, which wraps around for unsigned offsets.)
assert offsets.ndim == 1 and offsets.size >= 1, "offsets must be a non-empty 1-D array"
assert np.all(offsets[1:] >= offsets[:-1]), "offsets must be non-decreasing"
assert 0 <= offsets[0] and offsets[-1] <= traj.shape[1], "offsets exceed the packed axis of traj_data"

In [8]:
# Variable names stored in the archive; presumably data_names[i] labels row i
# of traj_data — TODO confirm against the code that wrote the file.
data_names = data["data_names"]
print(data_names)

['wait_times' 'jump_lengths' 'times' 'sq_disp' 'distinct_sites' 'sites']


In [9]:
# Select the model inputs by name instead of by hard-coded row position
# (previously rows [2, 3, 4]), so a reordered archive cannot silently feed the
# wrong variables. Assumes data_names[i] labels row i of traj — TODO confirm.
INPUT_NAMES = ('times', 'sq_disp', 'distinct_sites')
traj_input = traj[[list(data_names).index(name) for name in INPUT_NAMES]]

In [108]:
def check_correctness(n_indices, packed=None, offs=None, seed=None):
    """Compare PackedTrajectoryDataset items against direct slices of the packed array.

    Draws ``n_indices`` distinct trajectory positions at random, wraps them in
    a PackedTrajectoryDataset, and checks that item ``k`` reproduces
    ``packed[:, offs[j]:offs[j + 1]]`` for the ``k``-th drawn position ``j``.
    The reported discrepancy also includes any rounding from the dataset's
    storage dtype (torch's default float dtype unless configured otherwise).

    Parameters
    ----------
    n_indices : int
        Number of distinct trajectories to check; must not exceed ``len(offs) - 1``.
    packed : array, optional
        Packed array sliced along axis 1; defaults to the notebook's ``traj_input``.
    offs : array, optional
        Trajectory offsets; defaults to the notebook's ``offsets``.
    seed : int, optional
        Seed for the random draw, for reproducible runs.

    Returns
    -------
    float
        The maximum absolute element-wise discrepancy found (also printed).

    Raises
    ------
    AssertionError
        If an item's shape differs from its reference slice.
    """
    if packed is None:
        packed = traj_input
    if offs is None:
        offs = offsets
    rng = np.random.default_rng(seed)
    indices = rng.choice(len(offs) - 1, size=n_indices, replace=False)
    dataset = PackedTrajectoryDataset(packed, offs, indices)
    max_err = 0.0
    for seq_idx, offset_idx in enumerate(indices):
        # Take the single batch row explicitly: .squeeze() would also drop the
        # time axis of a length-1 trajectory, and (n_vars,) - (n_vars, 1) then
        # broadcasts into a grid of unrelated differences.
        item = dataset[seq_idx][0].numpy()
        ref = np.asarray(packed[:, offs[offset_idx]:offs[offset_idx + 1]])
        assert item.shape == ref.shape, (
            f"shape mismatch for trajectory {offset_idx}: {item.shape} vs {ref.shape}")
        if ref.size:  # .max() of an empty (zero-length trajectory) array would raise
            max_err = max(max_err, float(np.abs(item - ref).max()))
    print(f"maximum discrepancy: {max_err:.1e}")
    return max_err

In [118]:
# Check 10,000 randomly drawn trajectories; the trailing ';' keeps the cell's
# output to the printed summary line.
check_correctness(n_indices=10_000);

maximum discrepancy: 1.0e+00
