In [None]:
import os

import numpy as np
import torch
from PIL import Image
from torchvision import datasets
from tqdm import tqdm

In [None]:
# Root of the I2E-converted ImageNet; later cells read its `train` and `val` sub-folders.
# Set the I2E_DATASET_PATH environment variable to relocate it without editing this cell.
# os.path.join(..., '') guarantees a trailing separator, because later code appends
# sub-folder names to this string directly.
dataset_path = os.path.join(os.environ.get('I2E_DATASET_PATH') or '/ssd/Datasets/I2E_ImageNet/', '')
batch_size = 128  # samples per DataLoader batch
workers = 16      # DataLoader worker processes

In [None]:
class I2E_NpzFolder(datasets.DatasetFolder):
    """``DatasetFolder`` whose samples are ``.npz`` files (one sub-folder per class).

    By default each sample is the ``'arr_0'`` array of its file, returned as a
    float32 tensor. ``'arr_0'`` is the name ``np.savez`` gives its first
    positional array; presumably that is how the files were written (confirm).

    Args:
        root: directory containing one sub-folder per class.
        loader: optional callable ``path -> sample``. When ``None`` (the default)
            the built-in ``.npz`` reader is used.
        extensions: file suffixes to index (matched with ``str.endswith``).
        transform / target_transform: optional callables applied to the sample
            and to the class index, respectively.
        is_valid_file: passed through to ``DatasetFolder``.
    """

    def __init__(self, root, loader=None, extensions=('npz',), transform=None, target_transform=None, is_valid_file=None):
        # Immutable tuple default instead of a shared mutable list. torchvision
        # matches suffixes with str.endswith, so the dot-less 'npz' still selects *.npz.
        super().__init__(root, loader, extensions, transform, target_transform, is_valid_file)

    @staticmethod
    def _load_npz(path):
        """Read the ``'arr_0'`` array of the ``.npz`` file at ``path`` as a float32 tensor."""
        # np.load on an .npz returns an NpzFile that keeps the file open; the
        # context manager closes it deterministically instead of relying on GC.
        with np.load(path) as archive:
            return torch.from_numpy(archive['arr_0']).float()

    def __getitem__(self, index):
        """Return ``(sample, target)`` for ``index``, with any transforms applied."""
        path, target = self.samples[index]
        # Honour a user-supplied loader; previously the argument was silently ignored.
        sample = self.loader(path) if self.loader is not None else self._load_npz(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

In [None]:
# One dataset per split, rooted at <dataset_path>/train and <dataset_path>/val.
# os.path.join works whether or not dataset_path ends with a separator,
# unlike plain string concatenation.
train_dataset = I2E_NpzFolder(root=os.path.join(dataset_path, 'train'))
val_dataset = I2E_NpzFolder(root=os.path.join(dataset_path, 'val'))
print(f'len(train_dataset): {len(train_dataset)}, len(val_dataset): {len(val_dataset)}')

In [None]:
# Both splits share one loader configuration. Iteration follows file order
# (shuffle=False); the visible cells only aggregate statistics, where order is irrelevant.
loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=workers)
train_loader = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, **loader_kwargs)

In [None]:
# Per-sample, per-time-step mean activation over the training set: each batch
# is averaged over axes 2-4, keeping (batch, axis 1). The layout is presumably
# [B, T, C, H, W], giving one rate per time step; TODO confirm.
item = 0          # running count of samples seen
all_spike = []    # per-batch [B, T] rate tensors, concatenated below
for images, target in tqdm(train_loader):
    batch_rate = images.mean(dim=[2, 3, 4])
    all_spike.append(batch_rate)
    item += len(images)
all_spike = torch.cat(all_spike, dim=0)

print(all_spike.shape, item)
# Mean/std across samples for each time step, then over everything.
print(all_spike.mean(0), all_spike.std(0))
print(all_spike.mean(), all_spike.std())

In [None]:
def imgpand(img):                       # [T, 2, H, W] -> [T, 3, H, W]
    """Append one all-zero channel so a 2-channel clip can be rendered as RGB.

    ``img`` is converted with ``np.array`` and must be 4-D, laid out as
    [T, C, H, W]. The zero channel is added after the existing channels
    (axis 1), giving [T, C + 1, H, W]. The padding uses ``np.zeros``'s
    default float64 dtype, so the result follows NumPy's promotion of the
    input dtype with float64.
    """
    arr = np.array(img)
    assert arr.ndim == 4
    frames, _, height, width = arr.shape
    blank = np.zeros((frames, 1, height, width))
    return np.concatenate((arr, blank), axis=1)

In [None]:
# Export one random training sample as an animated GIF (one frame per entry
# along axis 0) for a visual sanity check, then print its per-channel mean.
# NOTE: the draw is unseeded, so a different sample is picked on every run.
item = np.random.randint(0, len(train_dataset))
img = imgpand(np.array(train_dataset[item][0]))   # third channel is all-zero padding

# Iterate over every frame instead of a hard-coded 8, so clips of any length work.
# Clip before the uint8 cast so values outside [0, 1] saturate instead of wrapping.
images = [
    Image.fromarray(np.clip(frame * 255, 0, 255).transpose(1, 2, 0).astype(np.uint8))
    for frame in img
]
images[0].save('test.gif', save_all=True, append_images=images[1:], duration=200, loop=0)
print(img.mean((0, 2, 3)))