In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
sys.path.append('../')

import numpy as np
import torch
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import v2
from spikingjelly.activation_based import functional

In [None]:
# Data source used for evaluation:
#   True  -> pre-generated baseline I2E datasets (.npz files on disk)
#   False -> I2E datasets generated in real time from the raw CIFAR-10 images
use_baseline_datasets = False
# Checkpoint variant to evaluate (pick by index):
#   'baseline' = baseline-I, 'augmentation' = baseline-II, 'transfer' = transfer-I
# NOTE: this name shadows the builtin `type`; it is kept because the checkpoint path is built from it.
type = ('baseline', 'augmentation', 'transfer')[0]

In [None]:
# Dataset locations. Set the CIFAR10_ROOT / I2E_CIFAR10_ROOT environment variables to
# point at other copies instead of editing the notebook; the defaults are the original paths.
# os.path.join(..., '') guarantees a trailing separator, so sub-directory names can be
# appended by plain string concatenation.
ori_dataset_path = os.path.join(os.environ.get('CIFAR10_ROOT', '/ssd/Datasets/CIFAR10/'), '')
dataset_path_i2e = os.path.join(os.environ.get('I2E_CIFAR10_ROOT', '/ssd/Datasets/I2E-CIFAR10/'), '')
batch_size = 128  # validation batch size
workers = 8       # number of DataLoader worker processes

In [None]:
class I2E_NpzFolder(datasets.DatasetFolder):
    """DatasetFolder over pre-generated I2E samples stored as ``.npz`` files.

    Files are discovered with torchvision's ``DatasetFolder`` layout
    ``root/<class_dir>/<file>``. Each sample is the ``arr_0`` array of its
    ``.npz`` file, returned as a float32 tensor. The label is the integer
    value of the file's parent directory name, so class directories must be
    named with integers (e.g. ``root/3/xxx.npz`` has label 3).

    Args:
        root: Dataset root directory (e.g. the ``val`` split folder).
        loader: Forwarded to ``DatasetFolder``; not used for reading, because
            ``__getitem__`` loads the ``.npz`` files itself.
        extensions: Accepted filename suffixes, forwarded to ``DatasetFolder``.
            A tuple (not a list) so it is never shared-and-mutable and is
            accepted by every torchvision version.
        transform: Optional callable applied to the sample tensor.
        target_transform: Optional callable applied to the integer label.
        is_valid_file: Optional file filter, forwarded to ``DatasetFolder``.
    """

    def __init__(self, root, loader=None, extensions=('npz',), transform=None, target_transform=None, is_valid_file=None):
        super().__init__(root, loader, extensions, transform, target_transform, is_valid_file)

    def __getitem__(self, index):
        """Return ``(sample, target)`` for the ``index``-th discovered file.

        The class index stored in ``self.samples`` is ignored; the target is
        re-derived from the integer parent-directory name instead.
        """
        path, _ = self.samples[index]
        # NOTE(review): presumably done because DatasetFolder numbers classes by
        # *sorted string* name ('10' sorts before '2'); the numeric folder name is
        # used as the label. os.path handles platform-specific separators, unlike
        # splitting on '/'.
        target = int(os.path.basename(os.path.dirname(path)))
        # Use the NpzFile as a context manager so its file handle is closed right
        # away instead of whenever the object happens to be garbage-collected.
        with np.load(path) as npz_file:
            sample = torch.from_numpy(npz_file['arr_0']).float()
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target


# Evaluation-time preprocessing applied to raw CIFAR-10 PIL images.
transform_test = v2.Compose(
    transforms=[
        v2.PILToTensor(),                             # PIL image -> tensor
        v2.Resize(size=(128, 128)),                   # resize to 128 x 128
        v2.ToDtype(dtype=torch.float32, scale=True),  # cast to float32, rescaling values to [0, 1]
    ]
)


# Pick the network variant and the validation set that match the chosen data source.
# Only `resnet18` is needed from the network modules, so import it explicitly instead
# of star-importing the whole module into the notebook namespace.
if use_baseline_datasets:
    # Pre-generated baseline I2E samples (.npz) read from disk; no transform is applied.
    from Network.MSNet import resnet18
    val_dataset = I2E_NpzFolder(root=os.path.join(dataset_path_i2e, 'val'))
else:
    # Raw CIFAR-10 test images; NOTE(review): the I2E representation is presumably
    # generated in real time by the MSNet_I2E network variant — confirm.
    from Network.MSNet_I2E import resnet18
    val_dataset = datasets.CIFAR10(root=ori_dataset_path, train=False, download=True, transform=transform_test)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)

In [None]:
# Checkpoint for the selected experiment variant.
resume = f'./models/cifar10_{type}.pth.tar'
# NOTE(review): `ratio` and `shuffle` are forwarded to the project's resnet18; their
# meaning is defined in Network.MSNet / Network.MSNet_I2E.
model = resnet18(num_classes=10, ratio=0.07, shuffle=[4, 5, 6, 7, 0, 1, 2, 3])
# Put every SpikingJelly layer of the network into multi-step ('m') mode.
functional.set_step_mode(model, step_mode='m')
model = torch.nn.DataParallel(model).cuda()

print(f"=> loading checkpoint '{resume}'")
# Deserialize onto the CPU first: a checkpoint saved from another GPU index would otherwise
# fail to load when only CUDA device 0 is visible. load_state_dict then copies the tensors
# into the CUDA-resident model.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(resume, map_location='cpu')
best_acc1 = checkpoint['best_acc1']
# NOTE(review): keys presumably carry the 'module.' prefix of a DataParallel-wrapped
# model, since load_state_dict is strict and `model` is wrapped above — confirm.
model.load_state_dict(checkpoint['state_dict'])
print(f"=> loaded checkpoint '{resume}' (epoch {checkpoint['epoch']}) best_acc1: {best_acc1:.3f}")

In [None]:
# Evaluate on the validation set, accumulating a confusion matrix whose rows are the
# true labels and whose columns are the predicted labels.
Confusion_Matrix = torch.zeros((10, 10))
num_classes = Confusion_Matrix.shape[0]
model.eval()
with torch.no_grad():
    for img, label in tqdm(val_loader):
        # non_blocking copies can overlap with host work when the loader uses pinned memory.
        img = img.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)
        out_fr = model(img)
        # NOTE(review): assumes the model returns per-sample class scores with classes
        # on dim 1 — confirm for the multi-step output of the network.
        guess = out_fr.argmax(1)
        # Count every (true, predicted) pair of the batch with one bincount instead of a
        # per-sample Python loop that forces a GPU->CPU sync for each element.
        pair_index = label * num_classes + guess
        counts = torch.bincount(pair_index, minlength=num_classes * num_classes)
        Confusion_Matrix += counts.view(num_classes, num_classes).to(device='cpu', dtype=Confusion_Matrix.dtype)
        # Reset the stateful SpikingJelly modules so state does not carry over into the next batch.
        functional.reset_net(model)
# Overall accuracy = correctly classified (diagonal) / all evaluated samples.
correct = Confusion_Matrix.diag().sum()
acc = correct / Confusion_Matrix.sum()
print(f'Confusion_Matrix = \n{Confusion_Matrix}')
print(f'acc = {acc}')