In [None]:
import os
import torch
import time
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from tqdm import tqdm
from scipy.io import savemat
from torch.cuda import amp
from spikingjelly.activation_based import neuron, functional, layer, surrogate

# Expose only GPU 0 to this kernel. This only takes effect if it runs before the CUDA
# context is created, so keep it in this first cell.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [None]:
# ---- Schedule / optimizer hyper-parameters ----------------------------------------
Begin_epoch = 0        # > 0 resumes from Savemodel_path + f'epoch{Begin_epoch-1}{Name_suffix}.h5'
Max_epoch = 128        # also the cosine-annealing period (T_max) of the LR scheduler
Learning_rate = 1e-4
Weight_decay = 0
Momentum = 0.9
Top_k = 5              # NOTE(review): not referenced in the cells shown here
AMP = True             # when True a GradScaler is created for mixed precision

# ---- Data / model ------------------------------------------------------------------
# Set the CIFAR10_PATH environment variable to point at the dataset on other machines;
# without it the original author's path is used.
Dataset_path = os.environ.get('CIFAR10_PATH', '/home/mrc/Datasets/CIFAR10/')
Batch_size = 128
Workers = 8
Targetnum = 10         # number of output classes
Timestep = 4           # number of SNN simulation time steps T

T_train = None
Test_every_iteration = None
Name_suffix = '_BF_step2'   # suffix shared by every checkpoint / record file name
Savemodel_path = './savemodels/'
Record_path = './recorddata/'

# makedirs(exist_ok=True) also creates missing parent directories and is safe on re-run,
# unlike a bare os.mkdir guarded by an exists() check.
os.makedirs(Savemodel_path, exist_ok=True)
os.makedirs(Record_path, exist_ok=True)

In [None]:
_seed_ = 2023
torch.manual_seed(_seed_)
np.random.seed(_seed_)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Per-channel normalization shared by both splits.
# NOTE(review): these look like the common ImageNet statistics rather than CIFAR-10's own;
# keep them unchanged so inputs match what the saved checkpoints were trained with.
_norm_mean = (0.485, 0.456, 0.406)
_norm_std = (0.229, 0.224, 0.225)

# Training augmentation, applied in this order: reflect-pad by 4, random horizontal flip,
# random 32x32 crop; then tensor conversion + normalization.
transform_train = transforms.Compose([
    transforms.Pad(4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Evaluation: no augmentation, only tensor conversion + normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

Train_data = datasets.CIFAR10(root=Dataset_path, train=True, transform=transform_train, download=True)
Test_data = datasets.CIFAR10(root=Dataset_path, train=False, transform=transform_test, download=True)


def _make_loader(dataset, shuffle, drop_last):
    """Wrap `dataset` in a DataLoader using the notebook-wide batch and worker settings."""
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=Batch_size,
        shuffle=shuffle,
        num_workers=Workers,
        pin_memory=True,
        drop_last=drop_last,
    )


# Training shuffles and drops the ragged final batch; evaluation keeps every test sample
# in a fixed order.
train_data_loader = _make_loader(Train_data, shuffle=True, drop_last=True)
test_data_loader = _make_loader(Test_data, shuffle=False, drop_last=False)

In [None]:
def IFNode(v_threshold=1.0):
    """Factory for the integrate-and-fire neuron placed after each hidden layer of MyNet.

    Builds a spikingjelly ``neuron.IFNode`` with an arctan surrogate gradient, reset
    potential ``v_reset=0.0`` and a detached reset path.

    Args:
        v_threshold: firing threshold (default 1.0).

    Returns:
        A new ``neuron.IFNode`` instance.
    """
    node_kwargs = dict(
        surrogate_function=surrogate.ATan(),
        v_threshold=v_threshold,
        v_reset=0.0,
        detach_reset=True,
    )
    return neuron.IFNode(**node_kwargs)

class MyNet(nn.Module):
    """Spiking CNN for 32x32 RGB images: 8 conv layers followed by a 2-layer spiking MLP.

    Every conv/linear layer is bias-free; all but the final Linear are followed by an
    integrate-and-fire neuron from ``IFNode()``.

    Args:
        T: number of simulation time steps. In multi-step mode the static input image is
            repeated T times along a new leading time axis.
        num_classes: size of the output layer.

    Forward (the mode is read from ``self.linear1.step_mode``):
        'm' (multi-step): x [N, 3, 32, 32] -> [T, N, ...] internally -> output averaged
            over the T steps, shape [N, num_classes].
        's' (single-step): x [N, 3, 32, 32] -> output of this one step, [N, num_classes].
            The caller drives the time loop (and presumably resets neuron state between
            samples).
    """

    def __init__(self, T=4, num_classes=10):
        super().__init__()
        self.T = T
        self.numclasses = num_classes

        # Comments give each layer's output spatial size for a 32x32 input.
        self.conv1 = layer.Conv2d(3, 16, kernel_size=3, padding=1, stride=1, bias=False)     # 32 * 32
        self.sn1 = IFNode()

        self.conv2 = layer.Conv2d(16, 32, kernel_size=3, padding=1, stride=2, bias=False)    # 16 * 16
        self.sn2 = IFNode()

        self.conv3 = layer.Conv2d(32, 64, kernel_size=3, padding=1, stride=2, bias=False)    # 8 * 8
        self.sn3 = IFNode()

        self.conv4 = layer.Conv2d(64, 64, kernel_size=3, padding=1, stride=1, bias=False)    # 8 * 8
        self.sn4 = IFNode()

        self.conv5 = layer.Conv2d(64, 128, kernel_size=3, padding=1, stride=2, bias=False)   # 4 * 4
        self.sn5 = IFNode()

        self.conv6 = layer.Conv2d(128, 128, kernel_size=3, padding=1, stride=1, bias=False)  # 4 * 4
        self.sn6 = IFNode()

        self.conv7 = layer.Conv2d(128, 256, kernel_size=3, padding=1, stride=1, bias=False)  # 4 * 4
        self.sn7 = IFNode()

        self.conv8 = layer.Conv2d(256, 256, kernel_size=3, padding=1, stride=1, bias=False)  # 4 * 4
        self.sn8 = IFNode()

        # 4*4*256 assumes the 32x32 input size noted above.
        self.linear1 = layer.Linear(4*4*256, 256, bias=False)
        self.sn9 = IFNode()

        self.linear2 = layer.Linear(256, self.numclasses, bias=False)

    def forward(self, x):
        # The whole network is switched together (functional.set_step_mode), so linear1's
        # step_mode is representative of every layer.
        multi_step = self.linear1.step_mode == 'm'

        if multi_step:
            # [N, C, H, W] -> [T, N, C, H, W]: present the same static image at every step.
            # Single-step mode must NOT do this: its layers expect plain [N, C, H, W].
            x = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)

        x = self.sn1(self.conv1(x))
        x = self.sn2(self.conv2(x))
        x = self.sn3(self.conv3(x))
        x = self.sn4(self.conv4(x))
        x = self.sn5(self.conv5(x))
        x = self.sn6(self.conv6(x))
        x = self.sn7(self.conv7(x))
        x = self.sn8(self.conv8(x))

        # Flatten C*H*W, keeping [T, N] (multi-step) or [N] (single-step) in front.
        x = torch.flatten(x, 2 if multi_step else 1)
        x = self.sn9(self.linear1(x))
        x = self.linear2(x)

        # Average over the time axis only when one exists; in single-step mode dim 0 is
        # the batch and must not be reduced.
        return x.mean(0) if multi_step else x

# Build the spiking CNN with Timestep simulation steps and Targetnum output classes.
net = MyNet(T=Timestep, num_classes=Targetnum)

In [None]:
# Wrap for multi-GPU, move to CUDA, run every spikingjelly layer in multi-step mode, and
# use the cupy kernels for the IF neurons.
net = nn.DataParallel(net).cuda()
functional.set_step_mode(net, step_mode='m')
functional.set_backend(net, 'cupy', neuron.IFNode)

# Optionally resume from the checkpoint of epoch Begin_epoch-1 together with the best test
# accuracy recorded so far.
max_test_acc = 0.
if Begin_epoch != 0:
    resume_ckpt = Savemodel_path + f'epoch{Begin_epoch-1}{Name_suffix}.h5'
    net.load_state_dict(torch.load(resume_ckpt))
    max_test_acc = np.load(Savemodel_path + f'max_acc{Name_suffix}.npy').item()

# Mixed-precision gradient scaler (None when AMP is off) and metric history buffers.
scaler = amp.GradScaler() if AMP else None
Test_top1, Test_topk, Test_lossall = [], [], []
Epoch_list, Iteration_list = [], []

In [None]:
# Evaluation loss (nn.MSELoss was an earlier alternative).
criterion_test = nn.CrossEntropyLoss()

# SGD with momentum (Adam was an earlier alternative). 'initial_lr' is stored on the param
# group because the scheduler below is created with last_epoch = Begin_epoch - 1, which
# needs it when resuming (Begin_epoch > 0).
optimizer = torch.optim.SGD(
    [{'params': net.parameters(), 'initial_lr': Learning_rate}],
    lr=Learning_rate,
    momentum=Momentum,
    weight_decay=Weight_decay,
)

# Cosine decay from Learning_rate down to 0 over Max_epoch epochs (a linear LambdaLR
# decay, 1 - epoch/Max_epoch, was an earlier alternative).
# NOTE(review): on resume only the network weights are reloaded, not the optimizer state,
# so the resumed LR trajectory may differ from an uninterrupted run — confirm before
# resuming training.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=Max_epoch, eta_min=0, last_epoch=Begin_epoch - 1)

In [None]:
# Load the weights saved under the 'max_acc' checkpoint name for evaluation.
# (An alternative checkpoint, 'net_with_BN_Free.h5', was previously loaded here instead.)
best_ckpt_file = f'{Savemodel_path}max_acc{Name_suffix}.h5'
net.load_state_dict(torch.load(best_ckpt_file))
functional.reset_net(net)  # start inference from a freshly reset neuron state

In [None]:
# Evaluate on the test set and accumulate a confusion matrix:
# rows = true label, columns = predicted label (argmax of the network output).
Confusion_Matrix = torch.zeros((Targetnum, Targetnum))
net.eval()
with torch.no_grad():
    for img, label in tqdm(test_data_loader):
        img = img.cuda()
        label = label.cuda()
        out_fr = net(img)
        guess = out_fr.argmax(1)
        # Count every (label, guess) pair with one bincount per batch. The previous
        # per-sample Python loop indexed with CUDA scalars, forcing two GPU syncs per
        # sample; this does one device->host copy per batch.
        pair_index = label * Targetnum + guess
        counts = torch.bincount(pair_index, minlength=Targetnum * Targetnum)
        Confusion_Matrix += counts.view(Targetnum, Targetnum).cpu()
        functional.reset_net(net)  # neurons are stateful: reset between batches
acc = Confusion_Matrix.diag().sum() / Confusion_Matrix.sum()
Confusion_Matrix, acc

In [None]:
from utils import *
from Test_fr import *

In [None]:
# Per-layer description handed to all_hardware_test, in MyNet's forward order:
# [[input H, input W], stride, 1] for conv1..conv8, and [] for linear1 / linear2.
# The sizes and strides match MyNet's convolutions. The trailing 1 presumably is the
# padding (every conv uses padding=1) — NOTE(review): confirm against all_hardware_test.
param = [
    [[32, 32], 1, 1],  # conv1
    [[32, 32], 2, 1],  # conv2 (downsamples to 16x16)
    [[16, 16], 2, 1],  # conv3 (downsamples to 8x8)
    [[8, 8], 1, 1],    # conv4
    [[8, 8], 2, 1],    # conv5 (downsamples to 4x4)
    [[4, 4], 1, 1],    # conv6
    [[4, 4], 1, 1],    # conv7
    [[4, 4], 1, 1],    # conv8
    [],                # linear1
    [],                # linear2
]

# Run the hardware-oriented evaluation provided by utils / Test_fr.
# NOTE(review): argument meanings below are inferred from names only — confirm in
# all_hardware_test.
all_hardware_test(
    net,                        # the DataParallel-wrapped SNN loaded above
    param,                      # per-layer list from the previous cell
    Test_data,
    test_data_loader,
    neuron_type=neuron.IFNode,
    n_clusters=15,
    bit=8,
)