In [None]:
# imports
# python core imports
from venv import create
from typing import Callable, Dict, Optional, Tuple
from abc import abstractmethod
import struct
import os
from concurrent.futures import ThreadPoolExecutor
import time
import math
from re import M
import shutil
import argparse
import logging
import sys

# python external modules
import scipy.io
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
import pandas as pd
from posixpath import split

# torch imports
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda import amp
from torchvision.datasets import DatasetFolder, utils
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

# spikingjelly imports
from spikingjelly.clock_driven import functional, surrogate, layer
# LIFNode lives in clock_driven.neuron; the bare name `neuron` below is the
# event-driven module (Tempotron), so the clock-driven one gets its own alias.
from spikingjelly.clock_driven import neuron as clock_neuron
import spikingjelly.event_driven.neuron as neuron
import spikingjelly.event_driven.encoding as encoding
from datasets.__init__ import *

In [None]:
# config variables
np_savez = np.savez_compressed
_seed_ = 2020
torch.manual_seed(_seed_)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(_seed_)

# logging setup for production only
# logging.basicConfig(filename=f'logs/runlog.log', level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(message)s')
# logger = logging.getLogger()
# sys.stderr.write = logger.error
# sys.stdout.write = logger.info

In [None]:
# model
class VotingLayer(nn.Module):
    """Average non-overlapping groups of ``voter_num`` neurons into one score.

    A ``[N, L]`` input becomes ``[N, L // voter_num]``: every output value is the
    mean of ``voter_num`` consecutive inputs.
    """

    def __init__(self, voter_num: int):
        super().__init__()
        # kernel == stride -> windows do not overlap
        self.voting = nn.AvgPool1d(kernel_size=voter_num, stride=voter_num)

    def forward(self, x: torch.Tensor):
        # AvgPool1d pools over the last axis of a [N, C, L] tensor, so insert a
        # singleton channel axis, pool, then drop it again.
        with_channel = torch.unsqueeze(x, 1)
        pooled = self.voting(with_channel)
        return torch.squeeze(pooled, 1)


# In[7]:


class PythonNet(nn.Module):
    """Convolutional spiking network for 2-channel event frames, with population voting.

    Five Conv-BN-LIF blocks (``conv3x3``), each followed by a 2x2 max-pool, feed a
    two-layer spiking classifier. Its ``num_classes * voter_num`` output neurons are
    averaged in groups of ``voter_num`` by ``VotingLayer``.

    Args:
        channels: width of every Conv2d layer.
        num_classes: number of classes voted for. The default of 2 reproduces the
            20 output neurons / 10 voters this model always used.
        voter_num: output neurons per class (default 10).

    Note:
        The ``channels * 4 * 4`` classifier input follows from five 2x halvings of a
        128x128 frame (128 / 2**5 = 4). Other spatial sizes will not fit the Linear layer.
    """

    def __init__(self, channels: int, num_classes: int = 2, voter_num: int = 10):
        super().__init__()
        conv = []
        conv.extend(PythonNet.conv3x3(2, channels))
        conv.append(nn.MaxPool2d(2, 2))

        for _ in range(4):
            conv.extend(PythonNet.conv3x3(channels, channels))
            conv.append(nn.MaxPool2d(2, 2))

        self.conv = nn.Sequential(*conv)
        # clock_neuron (spikingjelly.clock_driven.neuron) provides LIFNode; the
        # event-driven `neuron` module only offers Tempotron.
        self.fc = nn.Sequential(
            nn.Flatten(),
            layer.Dropout(0.5),
            nn.Linear(channels * 4 * 4, channels * 2 * 2, bias=False),
            clock_neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True),
            layer.Dropout(0.5),
            nn.Linear(channels * 2 * 2, num_classes * voter_num, bias=False),
            clock_neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True)
        )

        self.vote = VotingLayer(voter_num)

    def forward(self, x: torch.Tensor):
        """Run every time-step through the network and return the mean vote per class.

        Args:
            x: 5-D frame tensor. Axes 0 and 1 are swapped so that the time axis comes
               first; presumably the DataLoader yields [N, T, 2, H, W] (confirm).

        Returns:
            Tensor of shape [N, num_classes]: voted output averaged over time-steps.
        """
        x = x.permute(1, 0, 2, 3, 4)
        out_spikes = self.vote(self.fc(self.conv(x[0])))
        for t in range(1, x.shape[0]):
            out_spikes += self.vote(self.fc(self.conv(x[t])))
        return out_spikes / x.shape[0]

    @staticmethod
    def conv3x3(in_channels: int, out_channels):
        """Return [Conv2d 3x3 (same padding, no bias), BatchNorm2d, LIF neuron] layers."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            clock_neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True)
        ]

In [None]:
# argument parsing
# Hyper-parameters for the DVS conv-SNN experiment. The option names match the
# upstream command-line script, so they keep their single-dash style.
snnparser = argparse.ArgumentParser(description='Classify DVS128 Gesture')
snnparser.add_argument('-T', default=16, type=int, help='simulating time-steps')
snnparser.add_argument('-device', default='cuda:0', help='device')
snnparser.add_argument('-b', default=1, type=int, help='batch size')
snnparser.add_argument('-epochs', default=64, type=int, metavar='N', help='number of total epochs to run')
snnparser.add_argument('-j', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
snnparser.add_argument('-channels', default=128, type=int, help='channels of Conv2d in SNN')
snnparser.add_argument('-data_dir', type=str, default='./', help='root dir')
snnparser.add_argument('-out_dir', type=str, default='./output', help='root dir for saving logs and checkpoint')

snnparser.add_argument('-lr', default=0.001, type=float, help='learning rate')
snnparser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
snnparser.add_argument('-lr_scheduler', default='CosALR', type=str, help='use which schedule. StepLR or CosALR')
# StepLR's step_size is an epoch count, so parse it as an int.
snnparser.add_argument('-step_size', default=32, type=int, help='step_size for StepLR')
snnparser.add_argument('-gamma', default=0.1, type=float, help='gamma for StepLR')
snnparser.add_argument('-T_max', default=32, type=int, help='T_max for CosineAnnealingLR')

# Parse an explicit empty argv. Inside Jupyter, sys.argv holds the kernel's own
# flags, which this parser would reject.
args = snnparser.parse_args([])

In [None]:
# Build the DVS model, optimiser, LR schedule, data loaders and AMP scaler, plus
# the run directory that receives args.txt, checkpoints and TensorBoard logs.
net = PythonNet(channels=args.channels)
net.to(args.device)

optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)

# Build the scheduler named by -lr_scheduler; sched_tag records it in the run-dir name.
if args.lr_scheduler == 'StepLR':
    # int(): guard against -step_size having been parsed as a float.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(args.step_size), gamma=args.gamma)
    sched_tag = f'StepLR_{int(args.step_size)}'
elif args.lr_scheduler == 'CosALR':
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.T_max)
    sched_tag = f'CosALR_{args.T_max}'
else:
    raise ValueError(f"unknown -lr_scheduler {args.lr_scheduler!r}; expected 'StepLR' or 'CosALR'")

train_set = FYPDataset(args.data_dir, train=True, data_type='frame', split_by='number', frames_number=args.T)
test_set = FYPDataset(args.data_dir, train=False, data_type='frame', split_by='number', frames_number=args.T)

train_data_loader = DataLoader(
    dataset=train_set,
    batch_size=args.b,
    shuffle=True,
    num_workers=args.j,
    drop_last=True,
    pin_memory=True)

test_data_loader = DataLoader(
    dataset=test_set,
    batch_size=args.b,
    shuffle=False,
    num_workers=args.j,
    drop_last=False,
    pin_memory=True)

# Loss scaler used with amp.autocast in the training loop.
scaler = amp.GradScaler()

start_epoch = 0
max_test_acc = 0

# Encode the main hyper-parameters in the run directory name.
out_dir = os.path.join(
    args.out_dir,
    f'T_{args.T}_b_{args.b}_c_{args.channels}_SGD_lr_{args.lr}_{sched_tag}_amp',
)

if not os.path.exists(out_dir):
    # makedirs also creates args.out_dir itself when it is missing (os.mkdir would fail).
    os.makedirs(out_dir)
    print(f'Mkdir {out_dir}.')

with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
    args_txt.write(str(args))

writer = SummaryWriter(os.path.join(out_dir, 'dvsg_logs'), purge_step=start_epoch)

In [None]:
if __name__ == '__main__':
    # Train / evaluate PythonNet for args.epochs epochs with mixed precision and an
    # MSE-on-one-hot loss. Logs to TensorBoard and checkpoints every epoch.
    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()

        # ---- training ----
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for frame, label in train_data_loader:
            optimizer.zero_grad()
            frame = frame.float().to(args.device)
            label = label.to(args.device)
            with amp.autocast():
                out_fr = net(frame)
                # Take the one-hot width from the network output so target and
                # prediction always agree in shape. A hard-coded class count (11)
                # disagreed with PythonNet's 20 / 10 = 2 voted outputs.
                label_onehot = F.one_hot(label, out_fr.shape[-1]).float()
                loss = F.mse_loss(out_fr, label_onehot)
            # GradScaler scales the loss before backward and unscales before stepping.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            train_acc += (out_fr.argmax(1) == label).float().sum().item()

            # Clear the spiking neurons' internal state before the next batch.
            functional.reset_net(net)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)
        lr_scheduler.step()

        # ---- evaluation ----
        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for frame, label in test_data_loader:
                frame = frame.float().to(args.device)
                label = label.to(args.device)
                out_fr = net(frame)
                label_onehot = F.one_hot(label, out_fr.shape[-1]).float()
                loss = F.mse_loss(out_fr, label_onehot)

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)

        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        # Include the AMP scaler state so a resumed run continues with the same loss scale.
        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict(),
            'scaler': scaler.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        print(args)
        print(f'epoch={epoch}, train_loss={train_loss}, train_acc={train_acc}, test_loss={test_loss}, test_acc={test_acc}, max_test_acc={max_test_acc}, total_time={time.time() - start_time}')

In [None]:
# trying using event parameter
# neural network parameters
class Net(nn.Module):
    """Single Tempotron layer mapping Gaussian-encoded inputs to per-class outputs.

    Args:
        m: number of encoder neurons per input value. The Tempotron sees
            ``in_features * m`` inputs.
        T: number of simulation time-steps passed to the Tempotron.
        in_features: number of raw values per sample before encoding. The default
            480 * 640 keeps the previous hard-coded size.
            NOTE(review): the training loop flattens whole frame tensors, so this
            presumably has to equal that flattened size; confirm against FYPDataset.
        num_classes: number of Tempotron output neurons (default 10).
    """

    def __init__(self, m, T, in_features: int = 480 * 640, num_classes: int = 10):
        super().__init__()
        self.tempotron = neuron.Tempotron(in_features * m, num_classes, T)

    def forward(self, x: torch.Tensor):
        # 'v_max' return type; the training loop compares the result with the
        # Tempotron's v_threshold, presumably the peak membrane potential per output.
        return self.tempotron(x, 'v_max')

In [None]:
# Tempotron experiment: configuration, Gaussian-tuning encoder, data loaders,
# model and optimiser. This cell re-binds `args`, `writer`, the dataset/loader
# names and `net` used by the DVS cells above.
parser = argparse.ArgumentParser(description='spikingjelly Tempotron MNIST Training')

parser.add_argument('--device', default='cuda:0', help='运行的设备，例如“cpu”或“cuda:0”\n Device, e.g., "cpu" or "cuda:0"')

parser.add_argument('--dataset-dir', default='./', help='保存MNIST数据集的位置，例如“./”\n Root directory for saving MNIST dataset, e.g., "./"')
parser.add_argument('--log-dir', default='./logs', help='保存tensorboard日志文件的位置，例如“./”\n Root directory for saving tensorboard logs, e.g., "./"')
parser.add_argument('--model-output-dir', default='./', help='模型保存路径，例如“./”\n Model directory for saving, e.g., "./"')

parser.add_argument('-b', '--batch-size', default=64, type=int, help='Batch 大小，例如“64”\n Batch size, e.g., "64"')
parser.add_argument('-T', '--timesteps', default=100, type=int, dest='T', help='仿真时长，例如“100”\n Simulating timesteps, e.g., "100"')
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, metavar='LR', help='学习率，例如“1e-3”\n Learning rate, e.g., "1e-3": ', dest='lr')
# (--tau is not exposed: Net builds its Tempotron without passing any time constants.)
parser.add_argument('-N', '--epoch', default=100, type=int, help='训练epoch，例如“100”\n Training epoch, e.g., "100"')
parser.add_argument('-m', default=16, type=int, help='使用高斯调谐曲线编码每个像素点使用的神经元数量，例如“16”\n input neuron number for encoding a piexl in GaussianTuning encoder, e.g., "16"')

# Parse an explicit empty argv. Inside Jupyter, sys.argv holds the kernel's own flags.
args = parser.parse_args([])
device = args.device

data_dir = args.dataset_dir
log_dir = args.log_dir
model_output_dir = args.model_output_dir

batch_size = args.batch_size
T = args.T
learning_rate = args.lr
train_epoch = args.epoch
m = args.m

# NOTE(review): the encoder range is [0, 1]. If FYPDataset frames hold raw event
# counts, values may fall outside it; confirm, or normalise the frames first.
encoder = encoding.GaussianTuning(n=1, m=m, x_min=torch.zeros(size=[1]).to(device), x_max=torch.ones(size=[1]).to(device))

writer = SummaryWriter(log_dir)

train_set = FYPDataset(data_dir, train=True, data_type='frame', split_by='number', frames_number=T)
test_set = FYPDataset(data_dir, train=False, data_type='frame', split_by='number', frames_number=T)

train_data_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)

# Evaluate on the held-out split, so test accuracy is not measured on training data.
test_data_loader = DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False
)

net = Net(m, T).to(device)

optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)

# Global batch counter (x-axis of the per-batch TensorBoard scalar) and best test accuracy.
train_times = 0
max_test_accuracy = 0

In [None]:
if __name__ == '__main__':
    # Train the Tempotron for train_epoch epochs, evaluating on the test loader
    # after each one, then save the whole module.
    for epoch in range(train_epoch):
        print("Epoch {}:".format(epoch))
        print("Training...")
        net.train()
        train_correct_sum = 0
        train_sum = 0
        for img, label in train_data_loader:
            label = label.to(device)
            # Flatten each sample and add a singleton axis for the encoder:
            # [batch_size, 1, num_values].
            img = img.reshape(img.shape[0], -1).unsqueeze(1)
            # Gaussian-tuning encode, then flatten per sample into the Tempotron input
            # vector (presumably num_values * m long; confirm against Net's in_features).
            in_spikes = encoder.encode(img.to(device), T)
            in_spikes = in_spikes.view(in_spikes.shape[0], -1)

            optimizer.zero_grad()

            v_max = net(in_spikes)
            # Class count taken from the network output rather than a separate constant.
            loss = neuron.Tempotron.mse_loss(v_max, net.tempotron.v_threshold, label, v_max.shape[1])
            loss.backward()
            optimizer.step()

            correct = (v_max.argmax(dim=1) == label).float()
            train_correct_sum += correct.sum().item()
            train_sum += label.numel()
            writer.add_scalar('train_batch_acc', correct.mean().item(), train_times)

            train_times += 1
        # Needed by the epoch summary print below. NaN when the loader yielded no
        # batch (drop_last=True with fewer samples than batch_size).
        train_accuracy = train_correct_sum / train_sum if train_sum else float('nan')

        print("Testing...")
        net.eval()
        with torch.no_grad():
            correct_num = 0
            img_num = 0
            for img, label in test_data_loader:
                img = img.reshape(img.shape[0], -1).unsqueeze(1)

                in_spikes = encoder.encode(img.to(device), T)
                in_spikes = in_spikes.view(in_spikes.shape[0], -1)
                v_max = net(in_spikes)
                correct_num += (v_max.argmax(dim=1) == label.to(device)).float().sum().item()
                img_num += img.shape[0]
            test_accuracy = correct_num / img_num if img_num else float('nan')
            writer.add_scalar('test_accuracy', test_accuracy, epoch)
            max_test_accuracy = max(max_test_accuracy, test_accuracy)
        print("Epoch {}: train_acc = {}, test_acc={}, max_test_acc={}, train_times={}".format(epoch, train_accuracy, test_accuracy, max_test_accuracy, train_times))
        print()

    # Save the model. torch.save(net, ...) pickles the whole module, so loading it
    # requires the Net class to be importable.
    torch.save(net, os.path.join(model_output_dir, "eventdata.ckpt"))
    writer.close()
    # Load the model back with:
    # net = torch.load(os.path.join(model_output_dir, "eventdata.ckpt"))