# Dataset Fusion

## ImageNet and ES-ImageNet
Output shapes (per sample):
- Frame: `[2, 224, 224]`
- Fig: `[3, 224, 224]`

Number of classes (`Targetnum`): 1000

In [None]:
import torch
from es_imagenet import ImageNet_Fusion
from torchvision import transforms

Data_dvs_path = '/home/mrc/Datasets/DVS_ImageNet/'  # root of the ES-ImageNet data
Data_fig_path = '/home/mrc/Datasets/ImageNet/'      # root of the RGB ImageNet data
Batch_size = 128
Workers = 12        # DataLoader worker processes per loader
Targetnum = 1000    # number of classes; not referenced in this cell
Timestep = 8        # not referenced in this cell


# Standard ImageNet per-channel mean / std, shared by both image pipelines.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

# Training preprocessing for the RGB image: random-resized crop + horizontal flip.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
    # Lighting(0.1),  # commented out; `Lighting` is not imported in this cell
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Evaluation preprocessing: resize (shorter side) to 256, then center-crop 224.
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Point at the extracted ES-ImageNet subfolder. The configured root already ends
# in '/', so plain concatenation produced '.../DVS_ImageNet//extract/...'; strip
# it first. The trailing '/' is kept in case ImageNet_Fusion appends file names
# directly to this string (NOTE(review): confirm against ImageNet_Fusion).
Data_dvs_path = Data_dvs_path.rstrip('/') + '/extract/ES-imagenet-0.18/'

# Each sample yields (event frames, RGB image, label) -- see the unpacking in the
# speed-test cell. fig=True / frame=True request both modalities.
Train_data = ImageNet_Fusion(
    train=True, fig=True, frame=True,
    data_dvs_path=Data_dvs_path, data_fig_path=Data_fig_path,
    transform=transform_train,
)
Test_data = ImageNet_Fusion(
    train=False, fig=True, frame=True,
    data_dvs_path=Data_dvs_path, data_fig_path=Data_fig_path,
    transform=transform_test,
)

# Training: shuffle and drop the last partial batch so every batch is full-size.
train_data_loader = torch.utils.data.DataLoader(
    dataset=Train_data,
    batch_size=Batch_size,
    shuffle=True,
    num_workers=Workers,
    pin_memory=True,
    drop_last=True,
)

# Evaluation: fixed order and keep the last partial batch so no sample is skipped.
test_data_loader = torch.utils.data.DataLoader(
    dataset=Test_data,
    batch_size=Batch_size,
    shuffle=False,
    num_workers=Workers,
    pin_memory=True,
    drop_last=False,
)

## Fusion Image Display

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from spikingjelly.datasets import play_frame

data = Test_data
random_integer = np.random.randint(0, len(data))

# imgtype = 10
# while data[random_integer][2] != imgtype:
#     random_integer = np.random.randint(0, len(data))

# Fetch the sample ONCE. Every `data[idx]` goes back to the dataset (read +
# transform), and reusing one tuple guarantees that the frame, image and label
# printed below all describe the same sample.
img_frame, img_tensor, img_label = data[random_integer]
img_fig = img_tensor.permute(1, 2, 0).numpy()  # CHW -> HWC for matplotlib
# plt.imsave('test.png', img_fig)

play_frame(img_frame, save_gif_to='test.gif')
print(f'len(data) = {len(data)}')
print(f'Display number = {random_integer}\n')
print('Frame data:\n')
print(f'img_frame.shape = {img_frame.shape}, img_frame.type = {img_label}')
print(f'img_frame.max() = {img_frame.max()}, img_frame.min() = {img_frame.min()}\n')

# Per-channel statistics along axis 1.
# NOTE(review): spikingjelly's play_frame expects [T, 2, H, W], so axis 1 is
# presumably the 2 event channels -- confirm.
for i in range(img_frame.shape[1]):
    frame_channel = img_frame[:, i]
    print(f'img_frame[:, {i}].mean() = {frame_channel.mean()}, '
          f'img_frame[:, {i}].var() = {frame_channel.var()}')

# The dataset transform (ToTensor + Normalize with the ImageNet mean/std) leaves
# values well outside [0, 1], which imshow would clip. Undo the normalization
# for display only; the statistics printed below still describe the normalized data.
display_mean = np.array((0.485, 0.456, 0.406))
display_std = np.array((0.229, 0.224, 0.225))
plt.imshow(np.clip(img_fig * display_std + display_mean, 0.0, 1.0))
plt.show()  # `plt.plot()` with no arguments draws nothing
print('\nFig data:\n')
print(f'img_fig.shape = {img_fig.shape}, img_fig.type = {img_label}')
print(f'img_fig.max() = {img_fig.max()}, img_fig.min() = {img_fig.min()}\n')

# Per-channel (last axis, HWC layout after the permute above) statistics.
for i in range(img_fig.shape[2]):
    fig_channel = img_fig[:, :, i]
    print(f'img_fig[:, :, {i}].mean() = {fig_channel.mean()}, '
          f'img_fig[:, :, {i}].var() = {fig_channel.var()}')

## Speed Test

In [None]:
import time
from tqdm import tqdm

data_loader = train_data_loader

# Time one full pass over the loader (loading + transforms only, no model).
# Keep the last batch so its shapes can be reported afterwards.
last_batch = None
start_time = time.perf_counter()  # monotonic clock, meant for measuring intervals
for batch in tqdm(data_loader):
    last_batch = batch
elapsed = time.perf_counter() - start_time

if last_batch is None:
    # Without this guard an empty loader would leave img_frame / img_fig unbound,
    # or silently print stale values left over from an earlier cell.
    print('data_loader yielded no batches')
else:
    img_frame, img_fig, label = last_batch
    print(f'img_frame.shape = {img_frame.shape}')
    print(f'img_fig.shape = {img_fig.shape}')
print(f'Time used: {elapsed:.5f} s')