detach_reset 参数的作用是决定在重置神经元状态时是否切断计算图的连接。这对反向传播非常重要。
当 detach_reset=True 时，神经元发放脉冲后进行重置所用到的脉冲张量会先被 detach（从计算图中分离）。这样前向传播的数值完全不变，但反向传播时梯度不会再沿着“重置”这条路径回传，只会通过充电和发放路径传播。这种做法可以减少反向传播的计算量，在实践中通常也有助于训练更稳定、收敛更快。

In [52]:
import datetime
import os
import time
import torch
from torch.utils.data import DataLoader

import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter
import sys
from torch.cuda import amp
import smodels
import argparse
from spikingjelly.clock_driven import functional
from spikingjelly.datasets import cifar10_dvs
import math
from tqdm.notebook import tqdm  # tqdm 进度条显示,这个更好看

import random

import numpy as np

# Single source of truth for every RNG seed used in this notebook.
_seed_ = 2020

# Seed the Python, NumPy and PyTorch generators from the same value so that
# changing `_seed_` changes all of them together.
random.seed(_seed_)
np.random.seed(_seed_)
torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
torch.cuda.manual_seed_all(_seed_)

# Ask cuDNN for deterministic algorithms and turn off its benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [2]:
def split_to_train_test_set(train_ratio: float, origin_dataset: torch.utils.data.Dataset, num_classes: int, random_split: bool = False):
    '''
    Split a labelled dataset into train and test subsets, class by class.

    The split is stratified: ``train_ratio`` is applied to every class
    separately, so each class keeps the same train/test proportion.

    :param train_ratio: fraction of each class that goes into the train set;
        the per-class train count is ``ceil(class_size * train_ratio)``
    :type train_ratio: float
    :param origin_dataset: the origin dataset; every item must be indexable
        with the label at position ``1``
    :type origin_dataset: torch.utils.data.Dataset
    :param num_classes: total classes number, e.g., ``10`` for the MNIST dataset
    :type num_classes: int
    :param random_split: If ``False``, the leading ``train_ratio`` part of the
        samples of each class (in dataset order) is put into the train set and
        the rest into the test set. If ``True``, the samples of each class are
        shuffled first, so only the ratio is guaranteed. The randomness is
        controlled by ``numpy.random.seed``
    :type random_split: bool
    :return: a tuple ``(train_set, test_set)`` of ``torch.utils.data.Subset``
    :rtype: tuple
    '''
    # label_idx[c] collects the dataset positions of every sample with label c.
    label_idx = [[] for _ in range(num_classes)]
    for i, item in enumerate(origin_dataset):
        y = item[1]
        # Array / tensor labels are converted to a plain Python scalar so they
        # can be used as a list index.
        if isinstance(y, (np.ndarray, torch.Tensor)):
            y = y.item()
        label_idx[y].append(i)

    train_idx = []
    test_idx = []
    for class_indices in label_idx:
        if random_split:
            # In-place shuffle; classes are shuffled in ascending label order.
            np.random.shuffle(class_indices)
        pos = math.ceil(len(class_indices) * train_ratio)
        train_idx.extend(class_indices[:pos])
        test_idx.extend(class_indices[pos:])

    return torch.utils.data.Subset(origin_dataset, train_idx), torch.utils.data.Subset(origin_dataset, test_idx)

In [37]:
# Command-line style configuration for the CIFAR10-DVS training run.
# Inside the notebook the parser is fed an empty argument list, so every
# option takes its default; the model and connect function are set by hand below.
parser = argparse.ArgumentParser(description='Classify CIFAR10-DVS')
parser.add_argument('-T', default=20, type=int, help='simulating time-steps')
parser.add_argument('-device', default='cuda:0', help='device')
parser.add_argument('-b', default=16, type=int, help='batch size')
parser.add_argument('-epochs', default=64, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-j', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-data_dir', type=str, default='/raid/wfang/datasets/CIFAR10DVS')
parser.add_argument('-out_dir', type=str,default = r"../save_models",help='root dir for saving logs and checkpoint')

# parser.add_argument('-resume', default=r'E:\mycode\jupyter\0.SNN\test_git2\sew_resnet\save_models',
#                     type=str, help='resume from the checkpoint path')

parser.add_argument('-resume',type=str, help='resume from the checkpoint path')

parser.add_argument('-amp', action='store_true', help='automatic mixed precision training')

parser.add_argument('-opt', type=str, help='use which optimizer. SGD or Adam', default='SGD')
parser.add_argument('-lr', default=0.1, type=float, help='learning rate')
parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
parser.add_argument('-lr_scheduler', default='CosALR', type=str, help='use which schedule. StepLR or CosALR')
# NOTE(review): the value is also embedded in the run-directory name, so the
# float type is kept as-is to avoid renaming existing StepLR runs.
parser.add_argument('-step_size', default=32, type=float, help='step_size for StepLR')
parser.add_argument('-gamma', default=0.1, type=float, help='gamma for StepLR')
parser.add_argument('-T_max', default=64, type=int, help='T_max for CosineAnnealingLR')
parser.add_argument('-model', type=str)
parser.add_argument('-cnf', type=str)
parser.add_argument('-T_train', default=None, type=int)
parser.add_argument('-dts_cache', type=str, default='./dts_cache')

# Empty list: ignore the notebook kernel's own sys.argv and use defaults only.
args = parser.parse_args([])
print(args)
# Select the network and its connect function for this experiment.
args.model = "SEWResNet"
args.cnf = 'ADD'
print(args)

Namespace(T=20, device='cuda:0', b=16, epochs=64, j=4, data_dir='/raid/wfang/datasets/CIFAR10DVS', out_dir='../save_models', resume=None, amp=False, opt='SGD', lr=0.1, momentum=0.9, lr_scheduler='CosALR', step_size=32, gamma=0.1, T_max=64, model=None, cnf=None, T_train=None, dts_cache='./dts_cache')
Namespace(T=20, device='cuda:0', b=16, epochs=64, j=4, data_dir='/raid/wfang/datasets/CIFAR10DVS', out_dir='../save_models', resume=None, amp=False, opt='SGD', lr=0.1, momentum=0.9, lr_scheduler='CosALR', step_size=32, gamma=0.1, T_max=64, model='SEWResNet', cnf='ADD', T_train=None, dts_cache='./dts_cache')


In [16]:
# Look up the model constructor by name in the project-local `smodels` module
# and build the network with the configured `cnf` argument.
model_builder = smodels.__dict__[args.model]
net = model_builder(args.cnf)
print(net)
# Move the network to the target device; left as the last expression of the
# cell so the notebook also displays the module.
net.to(args.device)

ResNetN(
  (conv): Sequential(
    (0): Sequential(
      (0): SeqToANNContainer(
        (0): Conv2d(2, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): LIFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=torch, tau=2.0
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
    )
    (1): SEWBlock(
      (conv): Sequential(
        (0): Sequential(
          (0): SeqToANNContainer(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): LIFNode(
            v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=torch, tau=2.0
            (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
          )
        )
        (1): Sequential(
          

ResNetN(
  (conv): Sequential(
    (0): Sequential(
      (0): SeqToANNContainer(
        (0): Conv2d(2, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): LIFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=torch, tau=2.0
        (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
      )
    )
    (1): SEWBlock(
      (conv): Sequential(
        (0): Sequential(
          (0): SeqToANNContainer(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): LIFNode(
            v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=torch, tau=2.0
            (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
          )
        )
        (1): Sequential(
          

In [19]:
# Optimizer selection: only 'SGD' and 'Adam' are supported, anything else is
# rejected before an optimizer is built.
optimizer = None
if args.opt not in ('SGD', 'Adam'):
    raise NotImplementedError(args.opt)
if args.opt == 'Adam':
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
else:
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)

# Learning-rate schedule selection: 'StepLR' or 'CosALR' (cosine annealing).
lr_scheduler = None
if args.lr_scheduler not in ('StepLR', 'CosALR'):
    raise NotImplementedError(args.lr_scheduler)
if args.lr_scheduler == 'CosALR':
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.T_max)
else:
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

In [21]:
# train_set_pth = os.path.join(args.dts_cache, f'train_set_{args.T}.pt')
# test_set_pth = os.path.join(args.dts_cache, f'test_set_{args.T}.pt')
# if os.path.exists(train_set_pth) and os.path.exists(test_set_pth):
#     train_set = torch.load(train_set_pth)
#     test_set = torch.load(test_set_pth)
# else:
#     origin_set = cifar10_dvs.CIFAR10DVS(root=args.data_dir, data_type='frame', frames_number=args.T, split_by='number')

#     train_set, test_set = split_to_train_test_set(0.9, origin_set, 10)
#     if not os.path.exists(args.dts_cache):
#         os.makedirs(args.dts_cache)
#     torch.save(train_set, train_set_pth)
#     torch.save(test_set, test_set_pth)

# train_data_loader = DataLoader(
#     dataset=train_set,
#     batch_size=args.b,
#     shuffle=True,
#     num_workers=args.j,
#     drop_last=True,
#     pin_memory=True)

# test_data_loader = DataLoader(
#     dataset=test_set,
#     batch_size=args.b,
#     shuffle=False,
#     num_workers=args.j,
#     drop_last=False,
#     pin_memory=True)

FileNotFoundError: [WinError 3] 系统找不到指定的路径。: '/raid/wfang/datasets/CIFAR10DVS\\download'

In [24]:
# CIFAR10-DVS data loaders (10 classes).
def get_cifar10dvs(batch_size = 16,num_workers = 4,T = 20,
                   train_path = r'E:\mycode\jupyter\0.SNN\data\CIFAR10DVS\split\t20\train0.8.pth', 
                   test_path = r'E:\mycode\jupyter\0.SNN\data\CIFAR10DVS\split\t20\test0.2.pth'):
    '''
    Load pre-split train/test datasets from disk and wrap them in loaders.

    :param batch_size: batch size used by both loaders
    :param num_workers: number of worker processes used by both loaders
    :param T: not used by this function; the default paths point at a ``t20`` split
    :param train_path: file holding the serialized train dataset
    :param test_path: file holding the serialized test dataset
    :return: a tuple ``(train_loader, test_loader)``; the train loader shuffles
        and drops the last incomplete batch, the test loader does neither
    :rtype: tuple

    NOTE(review): ``weights_only=False`` makes ``torch.load`` unpickle arbitrary
    objects, so only point these paths at files you created yourself.
    '''
    # Load both datasets before building any loader.
    datasets = [torch.load(path, weights_only=False) for path in (train_path, test_path)]
    # Settings shared by the two loaders.
    common = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(datasets[0], shuffle=True, drop_last=True, **common)
    test_loader = DataLoader(datasets[1], shuffle=False, drop_last=False, **common)
    return train_loader, test_loader

# Build the loaders with the configured batch size and worker count; the
# dataset paths fall back to the function's defaults.
train_loader, test_loader = get_cifar10dvs(batch_size=args.b, num_workers=args.j)

# A gradient scaler is only created for mixed-precision training.
scaler = amp.GradScaler() if args.amp else None

# Starting values; replaced when a checkpoint is resumed.
start_epoch, max_test_acc = 0, 0

In [30]:
# Resume: restore the model / optimizer / scheduler state and the bookkeeping
# counters from a previously saved checkpoint file.
if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    for restorable, state_key in ((net, 'net'), (optimizer, 'optimizer'), (lr_scheduler, 'lr_scheduler')):
        restorable.load_state_dict(checkpoint[state_key])
    # The stored epoch is the last finished one, so continue with the next.
    start_epoch = 1 + checkpoint['epoch']
    max_test_acc = checkpoint['max_test_acc']

In [42]:
# The run directory name encodes the main hyper-parameters of the experiment.
run_name = f'{args.model}_{args.cnf}_T_{args.T}_T_train_{args.T_train}_{args.opt}_lr_{args.lr}_'
out_dir = os.path.join(args.out_dir, run_name)
out_dir

'../save_models\\SEWResNet_ADD_T_20_T_train_None_SGD_lr_0.1_'

In [43]:
# Append the scheduler settings (and an AMP marker) to the run directory name.
if args.lr_scheduler not in ('CosALR', 'StepLR'):
    raise NotImplementedError(args.lr_scheduler)
if args.lr_scheduler == 'StepLR':
    out_dir += f'StepLR_{args.step_size}_{args.gamma}'
else:
    out_dir += f'CosALR_{args.T_max}'

out_dir += '_amp' if args.amp else ''

In [44]:
# Display the run directory name after the scheduler / AMP suffix was appended.
out_dir

'../save_models\\SEWResNet_ADD_T_20_T_train_None_SGD_lr_0.1_CosALR_64'

In [45]:
# Create the run directory. If it already exists, an earlier run used the same
# configuration, so continuing is only allowed when resuming from a checkpoint.
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print(f'Mkdir {out_dir}.')
else:
    print(out_dir)
    # Explicit raise instead of `assert`: assertions are stripped under
    # `python -O`, which would silently disable this guard.
    if args.resume is None:
        raise FileExistsError(
            f'{out_dir} already exists; pass -resume to continue that run '
            f'or choose a different -out_dir.')

Mkdir ../save_models\SEWResNet_ADD_T_20_T_train_None_SGD_lr_0.1_CosALR_64.


In [46]:
# Checkpoints go to a sibling directory named after the run directory.
pt_dir = f'{out_dir}_pt'
pt_dir_missing = not os.path.exists(pt_dir)
if pt_dir_missing:
    os.makedirs(pt_dir)
    print(f'Mkdir {pt_dir}.')
pt_dir

Mkdir ../save_models\SEWResNet_ADD_T_20_T_train_None_SGD_lr_0.1_CosALR_64_pt.


'../save_models\\SEWResNet_ADD_T_20_T_train_None_SGD_lr_0.1_CosALR_64_pt'

In [47]:
# Store the full configuration next to the logs.
args_path = os.path.join(out_dir, 'args.txt')
with open(args_path, mode='w', encoding='utf-8') as args_file:
    args_file.write(str(args))

# TensorBoard writer; purge_step is set to the epoch the run (re)starts from.
log_dir = os.path.join(out_dir, 'logs')
writer = SummaryWriter(log_dir, purge_step=start_epoch)

In [53]:
# Main loop: one training pass and one evaluation pass per epoch, followed by
# TensorBoard logging, checkpointing and a progress report.
for epoch in range(start_epoch, args.epochs):
    start_time = time.time()

    # ---- training pass ----
    net.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    datas = tqdm(iter(train_loader),file=sys.stdout)
    for frame, label in datas:
        optimizer.zero_grad()
        frame = frame.float().to(args.device)

        if args.T_train:
            # Randomly keep T_train distinct indices along dim 1 (presumably
            # the time axis), sorted so their original order is preserved.
            sec_list = np.random.choice(frame.shape[1], args.T_train, replace=False)
            sec_list.sort()
            frame = frame[:, sec_list]

        label = label.to(args.device)
        if args.amp:
            # Mixed precision: forward under autocast, scaled backward/step.
            with amp.autocast():
                out_fr = net(frame)
                loss = F.cross_entropy(out_fr, label)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out_fr = net(frame)
            loss = F.cross_entropy(out_fr, label)
            loss.backward()
            optimizer.step()

        # Accumulate sample-weighted loss and the number of correct predictions.
        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        # Reset the network state before the next batch.
        functional.reset_net(net)
    train_loss /= train_samples
    train_acc /= train_samples

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('train_acc', train_acc, epoch)
    lr_scheduler.step()

    # ---- evaluation pass ----
    net.eval()
    test_loss = 0
    test_acc = 0
    test_samples = 0
    with torch.no_grad():
        for frame, label in test_loader:
            frame = frame.float().to(args.device)
            label = label.to(args.device)
            out_fr = net(frame)
            loss = F.cross_entropy(out_fr, label)

            test_samples += label.numel()
            test_loss += loss.item() * label.numel()
            test_acc += (out_fr.argmax(1) == label).float().sum().item()
            functional.reset_net(net)

    test_loss /= test_samples
    test_acc /= test_samples
    writer.add_scalar('test_loss', test_loss, epoch)
    writer.add_scalar('test_acc', test_acc, epoch)

    # ---- checkpointing ----
    save_max = False
    if test_acc > max_test_acc:
        max_test_acc = test_acc
        save_max = True

    checkpoint = {
        'net': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
        'max_test_acc': max_test_acc
    }

    # The best checkpoint is only rewritten on a new best test accuracy; the
    # latest checkpoint is rewritten every epoch.
    if save_max:
        torch.save(checkpoint, os.path.join(pt_dir, 'checkpoint_max.pth'))

    torch.save(checkpoint, os.path.join(pt_dir, 'checkpoint_latest.pth'))

    # ---- progress report ----
    for item in sys.argv:
        print(item, end=' ')
    print('')
    print(args)
    print(out_dir)
    total_time = time.time() - start_time
    # Estimated finish time: the current epoch is already done, so only
    # (args.epochs - epoch - 1) epochs remain, each assumed to take as long as
    # this one.
    remaining_epochs = args.epochs - epoch - 1
    finish_at = datetime.datetime.now() + datetime.timedelta(seconds=total_time * remaining_epochs)
    print(f'epoch={epoch}, train_loss={train_loss}, train_acc={train_acc}, test_loss={test_loss}, test_acc={test_acc}, max_test_acc={max_test_acc}, total_time={total_time}, escape_time={finish_at.strftime("%Y-%m-%d %H:%M:%S")}')

  0%|          | 0/500 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x00000278FEE6CA40>
Traceback (most recent call last):
  File "D:\software\anaconda\envs\my_torch\Lib\site-packages\torch\utils\data\dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "D:\software\anaconda\envs\my_torch\Lib\site-packages\torch\utils\data\dataloader.py", line 1562, in _shutdown_workers
    if self._persistent_workers or self._workers_status[worker_id]:
                                   ^^^^^^^^^^^^^^^^^^^^
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute '_workers_status'


KeyboardInterrupt: 