In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
from utils import A_cluster


In [38]:
class config:
    input = 700
    output = 20
    hid = 128    # RC Neurons
    thr = 0.5
    decay = 0.5
    rst = 0.05
    
    N_hid = hid
    p_in = 0.2        # ratio of inhibitory neurons
    gamma = 1.0       # shape factor of gamma distribution
    binary = False    # binary matrix of reservoir A
    net_type = 'BAC'  # type of reservoir connection topology
                      # 'ER',  # Erdos-Renyi Random Network
                      # 'ERC', # Clusters of Erdos-Renyi Networks
                      # 'BA',  # Barabasi-Albert Network
                      # 'BAC', # Clusters of Barabasi-Albert networks
                      # 'WS',  # Watts Strogatz small world networks
                      # 'WSC', # Clusters of Watts Strogatz small world networks
                      # 'RAN', # random network
                      # 'DTW', # Developmental Time Window for multi-cluster small-world network
    noise = True      # add noise in A
    noise_str = 0.05  # noise strength
    p_ER = 0.2        # connection probability when creating edges, for ER and WS graphs
    m_BA = 3          # number of edges to attach from a new node to existing nodes
    k = 5             # number of clusters in A
    R = 0.2           # distance factor when deciding connections in random network
    scale = False     # rescale matrix A with spectral radius
    
    
    batch = 32
    epoch = 50
    lr = 0.01
    device = torch.device('cuda:0')
    

In [39]:
from spikingjelly.datasets.shd import SpikingHeidelbergDigits

SHD_train = SpikingHeidelbergDigits('D:\Ph.D\Research\SNN-SRT数据\SHD', train=True, data_type='frame', frames_number=20, split_by='number')
SHD_test = SpikingHeidelbergDigits('D:\Ph.D\Research\SNN-SRT数据\SHD', train=False, data_type='frame', frames_number=20, split_by='number')
train_loader = torch.utils.data.DataLoader(dataset=SHD_train, batch_size=config.batch, shuffle=True, drop_last=False, num_workers=0)
test_loader = torch.utils.data.DataLoader(dataset=SHD_test, batch_size=config.batch, shuffle=False, drop_last=False, num_workers=0)

The directory [D:\Ph.D\Research\SNN-SRT数据\SHD\extract] for saving extracted files already exists.
SpikingJelly will not check the data integrity of extracted files.
If extracted files are not integrated, please delete [D:\Ph.D\Research\SNN-SRT数据\SHD\extract] manually, then SpikingJelly will re-extract files from [D:\Ph.D\Research\SNN-SRT数据\SHD\download].
The directory [D:\Ph.D\Research\SNN-SRT数据\SHD\frames_number_20_split_by_number] already exists.
The directory [D:\Ph.D\Research\SNN-SRT数据\SHD\extract] for saving extracted files already exists.
SpikingJelly will not check the data integrity of extracted files.
If extracted files are not integrated, please delete [D:\Ph.D\Research\SNN-SRT数据\SHD\extract] manually, then SpikingJelly will re-extract files from [D:\Ph.D\Research\SNN-SRT数据\SHD\download].
The directory [D:\Ph.D\Research\SNN-SRT数据\SHD\frames_number_20_split_by_number] already exists.


In [40]:
class ActFun(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(0).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        temp = abs(input - 0) < 0.5 # lens
        return grad_input * temp.float()

act_fun = ActFun.apply

In [41]:
def mem_update(input, mem, spk, thr, decay, rst):
    mem = rst * spk + mem * decay * (1-spk) + input
    spike = act_fun(mem - thr)
    return mem, spike

class RC(nn.Module):
    def __init__(self) -> None:
        super(RC, self).__init__()
        input = config.input
        hid = config.hid
        out = config.output
        self.fc_in = nn.Linear(input, hid)
        self.A = nn.Parameter(torch.tensor(A_cluster(config)), requires_grad=False) # 邻接矩阵
        self.fc_out = nn.Linear(hid, out)
        
    def forward(self, input):
        batch, time_step, in_dim = input.shape
        hid_mem = torch.empty(batch, config.hid).uniform_(0, 0.1).to('cuda')
        hid_spk = sum_spk = torch.zeros(batch, config.hid).to('cuda')
        for t in range(time_step):
            x = self.fc_in(input[:,t,:])
            x = x @ self.A
            hid_mem, hid_spk = mem_update(x, hid_mem, hid_spk, config.thr, config.decay, config.rst)
            sum_spk += hid_spk
        sum_spk /= time_step
        out = self.fc_out(sum_spk)
        return out


In [42]:
def train(model, optimizer, criterion, num_epochs, train_loader, test_loader, device):
    train_accs = []
    best_acc = 75
    for epoch in range(num_epochs):
        now = time.time()
        for i, (samples, labels) in enumerate(train_loader):
            samples = samples.requires_grad_().to(device)
            labels = labels.long().to(device)
            optimizer.zero_grad()
            outputs = model(samples.to(device))
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        tr_acc = test(model, train_loader)
        ts_acc = test(model, test_loader)
        if ts_acc > best_acc and tr_acc > 75:
            best_acc = ts_acc
        train_accs.append(tr_acc)
        res_str = 'epoch: ' + str(epoch) \
                    + ' Loss: ' + str(loss.item()) \
                    + '. Tr Acc: ' + str(tr_acc)   \
                    + '. Ts Acc: ' + str(ts_acc)   \
                    + '. Time:' + str(time.time()-now)
        print(res_str)
    return train_accs

def test(model, dataloader):
    correct, total = 0, 0
    for images, labels in dataloader:
        outputs = model(images.to(config.device))
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted.cpu() == labels.long().cpu()).sum()
    accuracy = 100. * correct.numpy() / total
    return accuracy

In [43]:
model = RC().to('cuda')
model.fc_in.requires_grad_ = False
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.lr)
acc = train(model, optimizer, criterion, config.epoch, train_loader, test_loader, 'cuda')
accuracy = test(model, test_loader)

epoch: 0 Loss: 1.9954992532730103. Tr Acc: 44.89946051986268. Ts Acc: 41.386925795053. Time:49.33854937553406
epoch: 1 Loss: 1.3443876504898071. Tr Acc: 49.26434526728789. Ts Acc: 46.775618374558306. Time:49.394436836242676
epoch: 2 Loss: 1.545304298400879. Tr Acc: 55.11280039234919. Ts Acc: 42.40282685512368. Time:47.19324040412903
epoch: 3 Loss: 0.8963672518730164. Tr Acc: 57.93281020107896. Ts Acc: 51.63427561837456. Time:48.78626227378845
epoch: 4 Loss: 1.4052256345748901. Tr Acc: 66.25796959293771. Ts Acc: 56.22791519434629. Time:50.378888845443726
epoch: 5 Loss: 1.1851226091384888. Tr Acc: 67.0426679744973. Ts Acc: 59.80565371024735. Time:46.271888732910156
epoch: 6 Loss: 0.8978437185287476. Tr Acc: 63.48700343305542. Ts Acc: 57.72968197879859. Time:49.56532597541809
epoch: 7 Loss: 1.0156340599060059. Tr Acc: 63.92839627268269. Ts Acc: 54.15194346289753. Time:47.26534295082092
epoch: 8 Loss: 0.8282561302185059. Tr Acc: 69.47032859244727. Ts Acc: 56.97879858657244. Time:53.0813541

In [37]:
from thop import profile

model = RC().to('cuda')
params = sum(param.numel() for param in model.parameters() if param.requires_grad)
print("Parameters: {:.2f}K".format(params / 1e3))

input = torch.randn(1, 20, 700).cuda() #.to(config.device)
flops, params = profile(model.cuda(), inputs=(input,))
print("FLOPS: {:.2f}M".format(flops / 1e6))
print("Parameters: {:.2f}K".format(params / 1e3))

Parameters: 92.31K
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
FLOPS: 1.79M
Parameters: 92.31K
