In [None]:
import time
from pathlib import Path

import numpy as np
import tonic
import tonic.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler

### Parameters

In [None]:
# Which dataset to load; one of: POKERDVS, NMNIST, NCARS, DVSGesture
dataset_name = 'POKERDVS'
download_dataset = True     # fetch the dataset files before parsing them
first_saccade_only = False  # N-MNIST only: keep just the first of its 3 saccades (~100 ms each)

# Percentage of each split to keep after shuffling.
# Values <= 0 or >= 100 keep the full split.
subsample = 100

# NOTE(review): the two settings below are not used in this part of the notebook;
# presumably they configure later feature-extraction cells -- TODO document there.
spatial_histograms = True
K = 10

### Load dataset

In [None]:
# Load the train/test splits of the dataset selected by `dataset_name`.
# Every sample is passed through a ToRatecodedFrame transform
# (frame_time=5000 -- presumably microseconds, TODO confirm against tonic docs;
# polarities merged into a single channel).
start_time = time.time()
transform = transforms.Compose([transforms.ToRatecodedFrame(frame_time=5000, merge_polarities=True)])

if dataset_name == 'NCARS': # 304 x 240
    train_set = tonic.datasets.NCARS(save_to='./data', train=True, download=download_dataset, transform=transform)
    test_set = tonic.datasets.NCARS(save_to='./data', train=False, download=download_dataset, transform=transform)
elif dataset_name == 'POKERDVS': # 35 x 35
    train_set = tonic.datasets.POKERDVS(save_to='./data', train=True, download=download_dataset, transform=transform)
    test_set = tonic.datasets.POKERDVS(save_to='./data', train=False, download=download_dataset, transform=transform)
elif dataset_name == "DVSGesture": # 128 x 128
    train_set = tonic.datasets.DVSGesture(save_to='./data', train=True, download=download_dataset, transform=transform)
    test_set = tonic.datasets.DVSGesture(save_to='./data', train=False, download=download_dataset, transform=transform)
elif dataset_name == 'NMNIST': # 34 x 34
    train_set = tonic.datasets.NMNIST(save_to='./data/nmnist', train=True, download=download_dataset, first_saccade_only=first_saccade_only, transform=transform)
    test_set = tonic.datasets.NMNIST(save_to='./data/nmnist', train=False, download=download_dataset, first_saccade_only=first_saccade_only, transform=transform)
else:
    # Fail fast with a clear message instead of a NameError on `train_set` below.
    raise ValueError(
        "unknown dataset_name %r; expected one of 'NCARS', 'POKERDVS', 'DVSGesture', 'NMNIST'"
        % (dataset_name,)
    )

# Position of each field in the dataset's `ordering` string
# (str.find returns -1 if the field is absent).
x_index = train_set.ordering.find('x')
y_index = train_set.ordering.find('y')
t_index = train_set.ordering.find('t')

In [None]:
# Take a (reproducible) random subset of each split.
# `subsample` is a percentage: for 0 < subsample < 100 only the first
# ceil(subsample% of the split) shuffled indices are kept; otherwise the whole
# split is used, just in shuffled order.
SUBSET_SEED = 0  # fixed seed so Restart & Run All picks the same subset every time
subset_rng = np.random.default_rng(SUBSET_SEED)

train_index = subset_rng.permutation(len(train_set))
test_index = subset_rng.permutation(len(test_set))

if 0 < subsample < 100:
    print("Taking %s%% of the dataset" % subsample)

    # number of samples to keep from each split (rounded up, so never 0 for a non-empty split)
    train_samples = int(np.ceil(subsample * len(train_set) / 100))
    test_samples = int(np.ceil(subsample * len(test_set) / 100))

    # keep the first N indices of each shuffled permutation
    train_index = train_index[:train_samples]
    test_index = test_index[:test_samples]

In [None]:
# custom sampler for torch dataloader
class custom_sampler(Sampler):
    """Sampler that yields a fixed, caller-supplied sequence of indices.

    No reordering is done here: the DataLoader visits samples in exactly
    the order given by `indices`.

    Arguments:
        indices (sequence): dataset indices to visit, in visiting order.
    """

    def __init__(self, indices):
        self.indices = indices
        # length is captured once, at construction time
        self.num_samples = len(self.indices)

    def __iter__(self):
        yield from self.indices

    def __len__(self):
        return self.num_samples

In [None]:
# Wrap each split in a DataLoader that visits exactly the (shuffled, possibly
# subsampled) indices chosen above. Ordering is owned by the sampler; PyTorch
# treats `sampler` and `shuffle=True` as mutually exclusive, so shuffle stays False.
trainloader = DataLoader(train_set, sampler=custom_sampler(train_index), shuffle=False)
testloader = DataLoader(test_set, sampler=custom_sampler(test_index), shuffle=False)