# Neuromorphic Datasets

In [None]:
!pip install spikingjelly

## N-MNIST
Output.shape = [2, 34, 34]

Targetnum = 10

In [None]:
import torch
from spikingjelly.datasets.n_mnist import NMNIST

# ---- configuration ---------------------------------------------------------
Dataset_path = '/home/mrc/Datasets/N_MNIST/'
Batch_size = 128
Workers = 4
Targetnum = 10   # number of classes
Timestep = 4     # passed as frames_number: frames produced per sample

# ---- datasets ----------------------------------------------------------------
# Both splits share the same frame-integration settings; only `train` differs.
_frame_kwargs = dict(root=Dataset_path, data_type='frame', frames_number=Timestep, split_by='number')
Train_data = NMNIST(train=True, **_frame_kwargs)
Test_data = NMNIST(train=False, **_frame_kwargs)

# ---- loaders -----------------------------------------------------------------
# Training shuffles and drops the ragged final batch; evaluation keeps order and
# every sample.
_loader_kwargs = dict(batch_size=Batch_size, num_workers=Workers, pin_memory=True)
train_data_loader = torch.utils.data.DataLoader(Train_data, shuffle=True, drop_last=True, **_loader_kwargs)
test_data_loader = torch.utils.data.DataLoader(Test_data, shuffle=False, drop_last=False, **_loader_kwargs)

## DVS-Gesture
Output.shape = [2, 128, 128]

Targetnum = 11

In [None]:
import torch
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture

# ---- configuration ---------------------------------------------------------
Dataset_path = '/home/mrc/Datasets/DVS_Gesture/'
Batch_size = 16
Workers = 4
Targetnum = 11   # number of gesture classes
Timestep = 16    # passed as frames_number: frames produced per sample

# ---- datasets ----------------------------------------------------------------
# The two splits differ only in the `train` flag.
_frame_kwargs = dict(root=Dataset_path, data_type='frame', frames_number=Timestep, split_by='number')
Train_data = DVS128Gesture(train=True, **_frame_kwargs)
Test_data = DVS128Gesture(train=False, **_frame_kwargs)

# ---- loaders -----------------------------------------------------------------
# Shuffle and drop the incomplete last batch only for training.
_loader_kwargs = dict(batch_size=Batch_size, num_workers=Workers, pin_memory=True)
train_data_loader = torch.utils.data.DataLoader(Train_data, shuffle=True, drop_last=True, **_loader_kwargs)
test_data_loader = torch.utils.data.DataLoader(Test_data, shuffle=False, drop_last=False, **_loader_kwargs)

## CIFAR10-DVS
Output.shape = [2, 128, 128]

Targetnum = 10

In [None]:
import os

import torch
from spikingjelly.datasets import split_to_train_test_set
from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS

Dataset_path = '/home/mrc/Datasets/DVS_CIFAR10/'
Batch_size = 128
Workers = 4
Targetnum = 10
Timestep = 4

# CIFAR10DVS is constructed without a train/test flag, so the full frame dataset is
# split 90/10 here. The two splits are cached on disk, keyed by Timestep, so later
# runs skip the split.
Dts_cache = os.path.join(Dataset_path, 'dts_cache')
train_set_pth = os.path.join(Dts_cache, f'train_set_{Timestep}.pt')
test_set_pth = os.path.join(Dts_cache, f'test_set_{Timestep}.pt')

if os.path.exists(train_set_pth) and os.path.exists(test_set_pth):
    # The cache holds pickled dataset-split objects, not tensors. torch >= 2.6
    # defaults to weights_only=True, which refuses to unpickle them, so full
    # unpickling is requested explicitly (the argument needs torch >= 1.13).
    # This is acceptable only because this cell wrote these files itself;
    # never point it at an untrusted file.
    Train_data = torch.load(train_set_pth, weights_only=False)
    Test_data = torch.load(test_set_pth, weights_only=False)
else:
    origin_set = CIFAR10DVS(root=Dataset_path, data_type='frame', frames_number=Timestep, split_by='number')

    # The third argument is the class count; use Targetnum so the two cannot drift apart.
    Train_data, Test_data = split_to_train_test_set(0.9, origin_set, Targetnum)
    os.makedirs(Dts_cache, exist_ok=True)  # no separate exists() check needed
    torch.save(Train_data, train_set_pth)
    torch.save(Test_data, test_set_pth)

# Training shuffles and drops the ragged final batch; evaluation keeps every sample.
_loader_kwargs = dict(batch_size=Batch_size, num_workers=Workers, pin_memory=True)
train_data_loader = torch.utils.data.DataLoader(Train_data, shuffle=True, drop_last=True, **_loader_kwargs)
test_data_loader = torch.utils.data.DataLoader(Test_data, shuffle=False, drop_last=False, **_loader_kwargs)

## N-Caltech101
Output.shape = [2, 180, 240]

Targetnum = 101

In [None]:
import os

import torch
from spikingjelly.datasets import split_to_train_test_set
from spikingjelly.datasets.n_caltech101 import NCaltech101

Dataset_path = '/home/mrc/Datasets/N_Caltech101/'
Batch_size = 128
Workers = 4
Targetnum = 101
Timestep = 16

# NCaltech101 is constructed without a train/test flag, so the full frame dataset is
# split 90/10 here. The two splits are cached on disk, keyed by Timestep, so later
# runs skip the split.
Dts_cache = os.path.join(Dataset_path, 'dts_cache')
train_set_pth = os.path.join(Dts_cache, f'train_set_{Timestep}.pt')
test_set_pth = os.path.join(Dts_cache, f'test_set_{Timestep}.pt')

if os.path.exists(train_set_pth) and os.path.exists(test_set_pth):
    # The cache holds pickled dataset-split objects, not tensors. torch >= 2.6
    # defaults to weights_only=True, which refuses to unpickle them, so full
    # unpickling is requested explicitly (the argument needs torch >= 1.13).
    # This is acceptable only because this cell wrote these files itself;
    # never point it at an untrusted file.
    Train_data = torch.load(train_set_pth, weights_only=False)
    Test_data = torch.load(test_set_pth, weights_only=False)
else:
    origin_set = NCaltech101(root=Dataset_path, data_type='frame', frames_number=Timestep, split_by='number')

    # The third argument is the class count; use Targetnum so the two cannot drift apart.
    Train_data, Test_data = split_to_train_test_set(0.9, origin_set, Targetnum)
    os.makedirs(Dts_cache, exist_ok=True)  # no separate exists() check needed
    torch.save(Train_data, train_set_pth)
    torch.save(Test_data, test_set_pth)

# Training shuffles and drops the ragged final batch; evaluation keeps every sample.
_loader_kwargs = dict(batch_size=Batch_size, num_workers=Workers, pin_memory=True)
train_data_loader = torch.utils.data.DataLoader(Train_data, shuffle=True, drop_last=True, **_loader_kwargs)
test_data_loader = torch.utils.data.DataLoader(Test_data, shuffle=False, drop_last=False, **_loader_kwargs)

## ES-ImageNet_Spikingjelly
Output.shape = [2, 224, 224] (raw frames are [2, 256, 256]; `SpikeCrop` removes a 16-pixel border from each side)

Targetnum = 1000

In [None]:
import torch
from spikingjelly.datasets.es_imagenet import ESImageNet

# ES-ImageNet (SpikingJelly loader) settings.
Dataset_path = '/ssd/Datasets/DVS_ImageNet/'
Batch_size, Workers = 128, 12   # DataLoader batch size / number of worker processes
Targetnum = 1000                # number of classes
Timestep = 8                    # passed as frames_number: frames produced per sample

class SpikeCrop(torch.nn.Module):
    """Crop a fixed border from the last two (spatial) dims of a frame tensor/array.

    With the default ``margin=16`` a 256x256 frame becomes 224x224, because
    (256 - 224) // 2 == 16 pixels are removed on every side. Leading dims
    (e.g. time and channel) are left untouched.

    Args:
        margin: pixels removed from each side of the last two dims. ``0`` returns
            the input uncropped.

    Raises:
        ValueError: if ``margin`` is negative.
    """

    def __init__(self, margin=16):
        super().__init__()
        if margin < 0:
            raise ValueError(f'margin must be non-negative, got {margin}')
        # Attribute name kept from the original implementation for compatibility.
        self.len = margin

    def forward(self, img):
        # Explicit end indices instead of `-m`: `[0:-0]` would yield an empty slice.
        # `...` keeps all leading dims, so 3-D and 4-D inputs both work.
        m = self.len
        h, w = img.shape[-2], img.shape[-1]
        return img[..., m:h - m, m:w - m]

# Training and test frames go through the same default SpikeCrop border crop.
transform_train = SpikeCrop()
transform_test = SpikeCrop()

# Both splits share the frame-integration settings; they differ in `train` and transform.
_frame_kwargs = dict(root=Dataset_path, data_type='frame', frames_number=Timestep, split_by='number')
Train_data = ESImageNet(train=True, transform=transform_train, **_frame_kwargs)
Test_data = ESImageNet(train=False, transform=transform_test, **_frame_kwargs)

# Training shuffles and drops the ragged final batch; evaluation keeps order and every sample.
_loader_kwargs = dict(batch_size=Batch_size, num_workers=Workers, pin_memory=True)
train_data_loader = torch.utils.data.DataLoader(Train_data, shuffle=True, drop_last=True, **_loader_kwargs)
test_data_loader = torch.utils.data.DataLoader(Test_data, shuffle=False, drop_last=False, **_loader_kwargs)

## ES-ImageNet_Official
Output.shape = [2, 224, 224]

Targetnum = 1000

In [None]:
import torch
from es_imagenet import ES_Imagenet_Official

# ---- configuration ---------------------------------------------------------
Dataset_path = '/ssd/Datasets/DVS_ImageNet/'
Batch_size = 128
Workers = 12
Targetnum = 1000
Timestep = 8   # NOTE(review): not passed to ES_Imagenet_Official below; confirm it matches that loader's frame count

# ---- datasets ----------------------------------------------------------------
# The official release is read from its extracted sub-directory. Dataset_path is
# rebound to that sub-directory, so later cells see the extended path.
Dataset_path += 'extract/ES-imagenet-0.18/'
Train_data = ES_Imagenet_Official(data_set_path=Dataset_path, mode='train')
Test_data = ES_Imagenet_Official(data_set_path=Dataset_path, mode='test')

# ---- loaders -----------------------------------------------------------------
# Training shuffles and drops the ragged final batch; evaluation keeps every sample.
_loader_kwargs = dict(batch_size=Batch_size, num_workers=Workers, pin_memory=True)
train_data_loader = torch.utils.data.DataLoader(Train_data, shuffle=True, drop_last=True, **_loader_kwargs)
test_data_loader = torch.utils.data.DataLoader(Test_data, shuffle=False, drop_last=False, **_loader_kwargs)

## Show DVS-Image

In [None]:
import numpy as np
from spikingjelly.datasets import play_frame

data = Train_data

# Set to a class index to only display samples of that class (rejection sampling,
# so rare or absent classes make this slow or never finish); None = any class.
imgtype = None

random_integer = np.random.randint(0, len(data))
# random_integer = 845295   # uncomment to inspect a fixed sample

# Read the sample once: every `data[idx]` access reloads (and re-transforms) it.
img, label = data[random_integer]
while imgtype is not None and label != imgtype:
    random_integer = np.random.randint(0, len(data))
    img, label = data[random_integer]

play_frame(img, save_gif_to='test.gif')
print(f'len(data) = {len(data)}')
print(f'Display number = {random_integer}')
print(f'img.shape = {img.shape}, img.type = {label}')
print(f'img.max() = {img.max()}, img.min() = {img.min()}\n')

# Statistics for each index along dim 1, pooled over dim 0 (labels match the slice taken).
for i in range(img.shape[1]):
    channel = img[:, i]
    print(f'img[:, {i}].mean() = {channel.mean()}, img[:, {i}].var() = {channel.var()}')

## Speed Test

In [None]:
import time
from tqdm import tqdm

data_loader = train_data_loader

# Time one full pass over the loader, including worker start-up. perf_counter is a
# monotonic, high-resolution clock meant for measuring elapsed intervals.
img = None
start_time = time.perf_counter()
for img, label in tqdm(data_loader):
    pass
elapsed = time.perf_counter() - start_time

if img is None:
    # Guard: an empty loader never binds `img`.
    print('data_loader yielded no batches')
else:
    print(f'img.shape = {img.shape}')
print(f'Time used: {elapsed:.5f} s')