In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
sys.path.append('../')

import torch
from tqdm import tqdm
from Network.MSNet import *
from spikingjelly.activation_based import functional
from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS
from spikingjelly.datasets import split_to_train_test_set

In [None]:
# Which pretrained checkpoint variant to evaluate. Valid choices:
#   'baseline'   -> baseline model
#   'transferI'  -> transfer-I
#   'transferII' -> transfer-II
# NOTE: this name shadows the Python builtin `type`; it is kept as-is because the
# checkpoint-loading cell formats it into the checkpoint file name.
type = 'baseline'

In [None]:
# Root of the CIFAR10-DVS dataset. Set the DVS_CIFAR10_ROOT environment variable to
# point elsewhere instead of editing this cell. os.path.join(..., '') guarantees a
# trailing separator, because the cache-path cell appends 'dts_cache/' by concatenation.
dataset_path = os.path.join(os.environ.get('DVS_CIFAR10_ROOT', '/ssd/Datasets/DVS_CIFAR10/'), '')
batch_size = 16   # samples per DataLoader batch
workers = 8       # DataLoader worker processes
timestep = 8      # frames per sample (passed to CIFAR10DVS as frames_number)

In [None]:
# Cache the frame-integrated train/val split so the slow preprocessing runs only once.
# The cache files are keyed on `timestep` only: delete `dts_cache/` if the split
# ratio or frame settings below are changed.
# os.path.join works whether or not dataset_path ends with a separator; the final ''
# keeps the trailing separator so Dts_cache has the same value as before.
Dts_cache = os.path.join(dataset_path, 'dts_cache', '')
train_set_pth = os.path.join(Dts_cache, f'train_set_{timestep}.pt')
test_set_pth = os.path.join(Dts_cache, f'test_set_{timestep}.pt')

if os.path.exists(train_set_pth) and os.path.exists(test_set_pth):
    # SECURITY: weights_only=False unpickles arbitrary Python objects (the cache holds
    # pickled dataset objects, not plain tensors). Only load cache files you created.
    train_dataset = torch.load(train_set_pth, weights_only=False)
    val_dataset = torch.load(test_set_pth, weights_only=False)
else:
    origin_set = CIFAR10DVS(root=dataset_path, data_type='frame', frames_number=timestep, split_by='number')

    # 0.9 train ratio; 10 = number of classes.
    train_dataset, val_dataset = split_to_train_test_set(0.9, origin_set, 10)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(Dts_cache, exist_ok=True)
    torch.save(train_dataset, train_set_pth)
    torch.save(val_dataset, test_set_pth)

train_data_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=True,
    drop_last=True
)

# Deterministic order and no dropped samples, so every validation example is scored.
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
    pin_memory=True,
    drop_last=False
)

In [None]:
# Build the MSNet ResNet-18 for the selected variant and restore its trained weights.
resume = f'./models/dvscifar10_{type}.pth.tar'
# ratio / shuffle are MSNet-specific hyperparameters; they must match the values the
# checkpoint was trained with.
model = resnet18(num_classes=10, ratio=0.07, shuffle=[4, 5, 6, 7, 0, 1, 2, 3])
# SpikingJelly multi-step mode ('m') for every module in the network.
functional.set_step_mode(model, step_mode='m')
model = torch.nn.DataParallel(model).cuda()

print(f"=> loading checkpoint '{resume}'")
# map_location='cpu': a checkpoint saved from a GPU index that is not visible here
# (only device 0 is exposed) would otherwise fail to deserialize. load_state_dict
# copies the tensors onto the model's device afterwards.
checkpoint = torch.load(resume, map_location='cpu')
best_acc1 = checkpoint['best_acc1']
# NOTE(review): the model is DataParallel-wrapped, so state_dict keys are presumably
# 'module.'-prefixed in the checkpoint — confirm if loading fails on key mismatch.
model.load_state_dict(checkpoint['state_dict'])
print(f"=> loaded checkpoint '{resume}' (epoch {checkpoint['epoch']}) best_acc1: {best_acc1:.3f}")

In [None]:
# Evaluate on the validation split and accumulate a confusion matrix
# (rows = true label, columns = predicted label), then report overall accuracy.
num_classes = 10
Confusion_Matrix = torch.zeros((num_classes, num_classes))
model.eval()
with torch.no_grad():
    for img, label in tqdm(val_loader):
        img = img.cuda(non_blocking=True)
        try:
            out_fr = model(img)
        finally:
            # The network is stateful across forward calls (hence reset_net); reset even
            # if the forward is interrupted, so a re-run of this cell starts clean.
            functional.reset_net(model)
        guess = out_fr.argmax(1).cpu()
        # Vectorized update: encode each (label, guess) pair as one flat index and count
        # them with a single bincount, instead of a GPU->CPU sync per sample.
        pair_idx = label.cpu() * num_classes + guess
        Confusion_Matrix += torch.bincount(
            pair_idx, minlength=num_classes * num_classes
        ).reshape(num_classes, num_classes)
acc = Confusion_Matrix.diag().sum() / Confusion_Matrix.sum()
print(f'Confusion_Matrix = \n{Confusion_Matrix}')
print(f'acc = {acc}')