<a href="https://colab.research.google.com/github/TK-brsq/Research/blob/main/Han2020.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

RMP-SNN: Residual Membrane Potential Neuron for Enabling Deeper High-Accuracy and Low-Latency Spiking Neural Network (Han, Srinivasan & Roy, CVPR 2020)

# Import and Load Data

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.functional as F
import torch.utils as utils
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms, datasets

from tqdm import tqdm
import numpy as np

In [None]:
!pip install snntorch

Collecting snntorch
  Downloading snntorch-0.9.1-py2.py3-none-any.whl.metadata (16 kB)
Collecting nir (from snntorch)
  Downloading nir-1.0.4-py3-none-any.whl.metadata (5.8 kB)
Collecting nirtorch (from snntorch)
  Downloading nirtorch-1.0-py3-none-any.whl.metadata (3.6 kB)
Downloading snntorch-0.9.1-py2.py3-none-any.whl (125 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.3/125.3 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading nir-1.0.4-py3-none-any.whl (18 kB)
Downloading nirtorch-1.0-py3-none-any.whl (13 kB)
Installing collected packages: nir, nirtorch, snntorch
Successfully installed nir-1.0.4 nirtorch-1.0 snntorch-0.9.1


In [None]:
import snntorch as snn
import snntorch.utils as sutils
import snntorch.functional as sF

In [None]:
# Convert images to tensors, then normalise each RGB channel with mean 0.5 and std 0.25.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25))])

train = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train, batch_size=64, shuffle=True)
# Evaluation metrics do not depend on sample order, so keep the test set in a
# fixed order: it makes every evaluation pass reproducible batch by batch.
test_loader = DataLoader(test, batch_size=64, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 69974681.39it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


# VGG

In [None]:
# Bias-free VGG-style CNN for 3x32x32 inputs.
# Layout: 32P-64P-128+128P-2048-2048-10 (P = 2x2 average pooling).
# The positional indices of this Sequential are used elsewhere in the notebook
# (forward hooks, state-dict keys), so the layer order must not change.
cnn = nn.Sequential(
    # block 1: 3 -> 32 channels, 32x32 -> 16x16
    nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(kernel_size=2),
    # block 2: 32 -> 64 channels, 16x16 -> 8x8
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(kernel_size=2),
    # block 3: 64 -> 128 -> 128 channels, 8x8 -> 4x4
    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
    nn.ReLU(),
    nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(kernel_size=2),
    # classifier: 128 * 4 * 4 = 2048 features
    nn.Flatten(),
    nn.Dropout(p=0.4),
    nn.Linear(in_features=2048, out_features=2048, bias=False),
    nn.ReLU(),
    nn.Dropout(p=0.4),
    nn.Linear(in_features=2048, out_features=2048, bias=False),
    nn.ReLU(),
    nn.Dropout(p=0.4),
    nn.Linear(in_features=2048, out_features=10, bias=False),
)
# Experiment log; bracketed pairs look like [train accuracy %, test accuracy %].
#decay=2e-4, mile=[2,4,6,7], epochs=8, [83,612, 78.42]
#p = 0.5, decay=2e-4, mile=[2,4,6,7,8,9], epochs=10, [76.814, 77.18]
#vgg7, p=0.4, decay=1e-4, mile=[2,4,6,8,9], epochs=10, [81.656, 79.33]
#32P-64P-128+128P-2048-2048-10,p=0.4, decay=1e-4, exp(0.8), epochs=10, [82.364, 78.55]

In [None]:
# Most recent output of every hooked layer, keyed by the name given to named_hook.
activation = {}


def named_hook(name):
    """Create a forward hook that records a layer's output in `activation[name]`.

    The entry is replaced on every forward pass, so after a loop over many
    batches only the output of the last batch remains.

    Args:
        name: key under which the output is stored in the `activation` dict.

    Returns:
        A callable with the (module, inputs, output) signature of a forward hook.
    """
    def hook(_module, _inputs, output):
        # Detach from autograd and move to host memory before converting to NumPy.
        host_copy = output.detach().cpu()
        activation[name] = host_copy.numpy()

    return hook

In [None]:
# Classification loss for the CNN (takes raw logits and integer class labels).
criteria = nn.CrossEntropyLoss()
# Adam at its default learning rate, with L2 weight decay.
optimizer = optim.Adam(params=cnn.parameters(), weight_decay=1e-4)
# The learning rate is multiplied by 0.8 on every scheduler.step() call.
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

In [None]:
epochs = 10
# Fall back to the CPU when no GPU is available instead of failing outright.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn.to(device)

# Dataset sizes drive the reported averages, so nothing is hard-coded to CIFAR-10.
n_train = len(train_loader.dataset)
n_test = len(test_loader.dataset)

for epoch in range(epochs):
    tr_loss = 0
    tr_correct = 0
    cnn.train()
    for data, target in tqdm(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = cnn(data)
        loss = criteria(out, target)
        loss.backward()
        optimizer.step()

        # `criteria` yields the mean over the batch; weight it by the batch size
        # so that dividing by the sample count gives a true per-sample average.
        tr_loss += loss.item() * data.size(0)
        _, pred = out.max(1)
        tr_correct += (pred == target).sum().item()
    scheduler.step()

    ts_loss = 0
    ts_correct = 0
    cnn.eval()
    with torch.no_grad():
        for data, target in tqdm(test_loader):
            data, target = data.to(device), target.to(device)
            out = cnn(data)

            ts_loss += criteria(out, target).item() * data.size(0)
            _, pred = out.max(1)
            ts_correct += (pred == target).sum().item()

    print(f'Epoch : {epoch}')
    print(f'train : {tr_loss/n_train}, {100*tr_correct/n_train}%\ttest : {ts_loss/n_test}, {100*ts_correct/n_test}%')

# Record activations of the trained network on the training set.
# - eval mode: Dropout must be off, otherwise the recorded activations are
#   rescaled/zeroed versions of what the deployed network produces.
# - the hooks are removed afterwards so that no later forward pass (e.g. on the
#   test set) overwrites the recorded values.
# Note: named_hook keeps only the most recent batch per layer, so what ends up in
# `activation` is the last training batch seen by this loop.
cnn.eval()
hook_handles = [
    layer.register_forward_hook(named_hook(f'{idx}ReLU'))
    for idx, layer in enumerate(cnn)
    if isinstance(layer, nn.ReLU)
]
# cnn[19] is the final Linear layer (the logits); the key name is kept as-is.
hook_handles.append(cnn[19].register_forward_hook(named_hook('ReLU19')))
try:
    with torch.no_grad():
        for data, target in train_loader:
            cnn(data.to(device))
finally:
    for handle in hook_handles:
        handle.remove()

100%|██████████| 782/782 [00:22<00:00, 34.42it/s]
100%|██████████| 157/157 [00:03<00:00, 44.17it/s]


Epoch : 0
train : 0.025018361573219298, 40.898%	test : 0.020043831557035447, 53.91%


100%|██████████| 782/782 [00:21<00:00, 37.12it/s]
100%|██████████| 157/157 [00:04<00:00, 31.62it/s]


Epoch : 1
train : 0.018751052986383437, 57.668%	test : 0.016441240674257278, 63.49%


100%|██████████| 782/782 [00:24<00:00, 32.04it/s]
100%|██████████| 157/157 [00:02<00:00, 53.59it/s]


Epoch : 2
train : 0.015386248117685318, 65.664%	test : 0.01471907422542572, 67.75%


100%|██████████| 782/782 [00:21<00:00, 35.68it/s]
100%|██████████| 157/157 [00:02<00:00, 56.41it/s]


Epoch : 3
train : 0.013279960027933121, 70.458%	test : 0.012736618757247924, 71.66%


100%|██████████| 782/782 [00:21<00:00, 37.18it/s]
100%|██████████| 157/157 [00:02<00:00, 53.81it/s]


Epoch : 4
train : 0.011801668327450752, 73.772%	test : 0.012071572357416153, 73.0%


100%|██████████| 782/782 [00:21<00:00, 37.06it/s]
100%|██████████| 157/157 [00:03<00:00, 49.76it/s]


Epoch : 5
train : 0.010523487982153893, 76.362%	test : 0.010956252211332321, 76.26%


100%|██████████| 782/782 [00:20<00:00, 38.49it/s]
100%|██████████| 157/157 [00:04<00:00, 36.53it/s]


Epoch : 6
train : 0.00971093220293522, 78.6%	test : 0.010674732801318169, 76.71%


100%|██████████| 782/782 [00:21<00:00, 35.69it/s]
100%|██████████| 157/157 [00:03<00:00, 49.09it/s]


Epoch : 7
train : 0.008977935937643051, 79.974%	test : 0.010577488261461258, 77.1%


100%|██████████| 782/782 [00:20<00:00, 38.59it/s]
100%|██████████| 157/157 [00:03<00:00, 43.87it/s]


Epoch : 8
train : 0.008303053728044033, 81.568%	test : 0.010182826688885689, 78.03%


100%|██████████| 782/782 [00:20<00:00, 38.23it/s]
100%|██████████| 157/157 [00:03<00:00, 42.86it/s]

Epoch : 9
train : 0.007867997389733792, 82.364%	test : 0.00987781553864479, 78.55%





In [None]:
# Bring the model back to host memory before writing anything to disk.
device = torch.device('cpu')
cnn.to(device)

# Persist the trained weights.
cnn_weights = cnn.state_dict()
torch.save(cnn_weights, 'hen2020.pth')

# `activation` is a plain dict, so NumPy stores it as a pickled object array;
# reading it back therefore needs allow_pickle=True.
np.save('activation_hen.npy', activation)

# SNN and Conversion

In [None]:
# Spiking mirror of `cnn`: every ReLU becomes a snn.Leaky neuron with beta=1 and
# the Dropout layers are dropped. A final Leaky is appended after the last Linear;
# it is built without init_hidden, unlike all the others.
scnn_seq = nn.Sequential(
    # block 1
    nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False),
    snn.Leaky(beta=1, init_hidden=True),
    nn.AvgPool2d(kernel_size=2),
    # block 2
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
    snn.Leaky(beta=1, init_hidden=True),
    nn.AvgPool2d(kernel_size=2),
    # block 3
    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
    snn.Leaky(beta=1, init_hidden=True),
    nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
    snn.Leaky(beta=1, init_hidden=True),
    nn.AvgPool2d(kernel_size=2),
    # classifier
    nn.Flatten(),
    nn.Linear(in_features=2048, out_features=2048, bias=False),
    snn.Leaky(beta=1, init_hidden=True),
    nn.Linear(in_features=2048, out_features=2048, bias=False),
    snn.Leaky(beta=1, init_hidden=True),
    nn.Linear(in_features=2048, out_features=10, bias=False),
    snn.Leaky(beta=1),
)

In [None]:
class SCNN(nn.Module):
    """Spiking version of the notebook's `cnn`, written as an explicit module.

    The layer shapes follow `cnn` one-to-one (32P-64P-128+128P-2048-2048-10,
    P = 2x2 average pooling) so that the trained CNN weights can be copied in.
    Each ReLU is replaced by a `snn.Leaky` neuron with beta=1; Dropout is omitted.

    lif1..lif6 are built with init_hidden=True and their output is used directly
    as the next layer's input. lif7 is built without it and returns a
    (spikes, membrane potential) pair.
    """

    def __init__(self):
        super(SCNN, self).__init__()
        # block 1: 3 -> 32 channels, then 2x2 average pooling
        self.conv1 = nn.Conv2d(3, 32, 3, 1, 1, bias=False)
        self.lif1 = snn.Leaky(1, init_hidden=True)
        self.pool1 = nn.AvgPool2d(2)
        # block 2: 32 -> 64 channels, then 2x2 average pooling
        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.lif2 = snn.Leaky(1, init_hidden=True)
        self.pool2 = nn.AvgPool2d(2)
        # block 3: 64 -> 128 -> 128 channels, then 2x2 average pooling
        self.conv3 = nn.Conv2d(64, 128, 3, 1, 1, bias=False)
        self.lif3 = snn.Leaky(1, init_hidden=True)
        self.conv4 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
        self.lif4 = snn.Leaky(1, init_hidden=True)
        self.pool3 = nn.AvgPool2d(2)
        # classifier: a 32x32 input leaves 128 * 4 * 4 = 2048 features
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(2048, 2048, bias=False)
        self.lif5 = snn.Leaky(1, init_hidden=True)
        self.fc2 = nn.Linear(2048, 2048, bias=False)
        self.lif6 = snn.Leaky(1, init_hidden=True)
        self.fc3 = nn.Linear(2048, 10, bias=False)
        self.lif7 = snn.Leaky(1)

    def forward(self, x):
        """Advance the network by one timestep.

        Args:
            x: input batch for this timestep (the notebook feeds 3x32x32 images).

        Returns:
            (spk, mem): output of the last Leaky layer for this timestep.
        """
        x = self.pool1(self.lif1(self.conv1(x)))
        x = self.pool2(self.lif2(self.conv2(x)))
        x = self.lif3(self.conv3(x))
        x = self.pool3(self.lif4(self.conv4(x)))
        x = self.flatten(x)
        x = self.lif5(self.fc1(x))
        x = self.lif6(self.fc2(x))
        spk, mem = self.lif7(self.fc3(x))
        return spk, mem

scnn_module = SCNN()

In [None]:
#seq 2 module
cnn_sd = torch.load('hen2020.pth', weights_only=True)

# Copy the CNN weights into the module-style SNN. The source keys are the
# positions of the weight-carrying layers inside `cnn` (an nn.Sequential):
#   conv layers at 0, 3, 6, 8 and linear layers at 13, 16, 19.
# load_state_dict raises a size-mismatch error if SCNN's layer shapes differ
# from those of `cnn`, so the two architectures must agree.
scnn_sd = {
    'conv1.weight': cnn_sd['0.weight'],
    'conv2.weight': cnn_sd['3.weight'],
    'conv3.weight': cnn_sd['6.weight'],
    'conv4.weight': cnn_sd['8.weight'],
    'fc1.weight': cnn_sd['13.weight'],
    'fc2.weight': cnn_sd['16.weight'],
    'fc3.weight': cnn_sd['19.weight'],
}

# strict=False: the Leaky neurons' own entries (threshold, beta, ...) are not
# part of the CNN checkpoint and keep their current values.
scnn_module.load_state_dict(scnn_sd, strict=False)

_IncompatibleKeys(missing_keys=['lif1.threshold', 'lif1.graded_spikes_factor', 'lif1.reset_mechanism_val', 'lif1.beta', 'lif2.threshold', 'lif2.graded_spikes_factor', 'lif2.reset_mechanism_val', 'lif2.beta', 'lif3.threshold', 'lif3.graded_spikes_factor', 'lif3.reset_mechanism_val', 'lif3.beta', 'lif4.threshold', 'lif4.graded_spikes_factor', 'lif4.reset_mechanism_val', 'lif4.beta', 'lif5.threshold', 'lif5.graded_spikes_factor', 'lif5.reset_mechanism_val', 'lif5.beta', 'lif6.threshold', 'lif6.graded_spikes_factor', 'lif6.reset_mechanism_val', 'lif6.beta', 'lif7.threshold', 'lif7.graded_spikes_factor', 'lif7.reset_mechanism_val', 'lif7.beta'], unexpected_keys=[])

In [None]:
#seq 2 seq
cnn_sd = torch.load('hen2020.pth', weights_only=True)

# Start with every CNN entry whose key also exists in the SNN's state dict ...
snn_keys = set(scnn_seq.state_dict())
scnn_sd = {key: tensor for key, tensor in cnn_sd.items() if key in snn_keys}
# ... then place the linear layers explicitly: without the Dropout layers their
# positions differ between the two Sequentials. This also replaces any
# same-named entry picked up by the filter above.
scnn_sd.update({
    '12.weight': cnn_sd['13.weight'],
    '14.weight': cnn_sd['16.weight'],
    '16.weight': cnn_sd['19.weight'],
})

# strict=False: the Leaky neurons' own entries are absent from the CNN checkpoint.
scnn_seq.load_state_dict(scnn_sd, strict=False)

_IncompatibleKeys(missing_keys=['1.threshold', '1.graded_spikes_factor', '1.reset_mechanism_val', '1.beta', '4.threshold', '4.graded_spikes_factor', '4.reset_mechanism_val', '4.beta', '7.threshold', '7.graded_spikes_factor', '7.reset_mechanism_val', '7.beta', '9.threshold', '9.graded_spikes_factor', '9.reset_mechanism_val', '9.beta', '13.threshold', '13.graded_spikes_factor', '13.reset_mechanism_val', '13.beta', '15.threshold', '15.graded_spikes_factor', '15.reset_mechanism_val', '15.beta', '17.threshold', '17.graded_spikes_factor', '17.reset_mechanism_val', '17.beta'], unexpected_keys=[])

In [None]:
# Write the sequential SNN's current state dict to disk; a later cell reloads
# this file before running the IF-based threshold balancing.
torch.save(scnn_seq.state_dict(), 'hen2020_snn.pth')

# Threshold Balancing By ReLU

In [None]:
# The activations were stored as a pickled dict inside a 0-d object array;
# allow_pickle=True is safe here only because this notebook wrote the file itself.
activation_obj = np.load('activation_hen.npy', allow_pickle=True)
activation = activation_obj.item()

# One threshold per recorded layer, in recording order: the 99.9th percentile of
# its activations. The maximum is computed only to be printed alongside it.
thresholds = []
for layer_name, layer_output in activation.items():
    peak = np.max(layer_output)
    p999 = np.percentile(layer_output, q=99.9)
    thresholds.append(p999)
    print(layer_name, peak, p999)

1ReLU 2.015399 1.3645718150138977
3ReLU 3.5909607 1.7925383200645553
6ReLU 2.3062818 0.9410146411657363
8ReLU 1.7592487 0.9739432749748534
13ReLU 3.0144727 2.0458919336795867
16ReLU 6.80631 5.817738733291929
linear 11.839191 11.79015914821625


In [None]:
#Threshold balancing by ReLU
#seq to module
# Only the four convolutional neurons (lif1..lif4) receive a balanced threshold;
# lif5..lif7 (thresholds[4..6]) are deliberately left at their default value.
for position, lif_name in enumerate(('lif1', 'lif2', 'lif3', 'lif4')):
    getattr(scnn_module, lif_name).threshold = torch.tensor(thresholds[position])

In [None]:
#Threshold balancing by ReLU
#seq to seq
# Walk the *sequential* SNN (an nn.Sequential is iterable; the SCNN module is not)
# and give the k-th Leaky neuron the k-th recorded threshold. `thresholds` is
# ordered by recording order, not by layer index, hence the pairing by position.
lif_layers = [layer for layer in scnn_seq if isinstance(layer, snn.Leaky)]
if len(lif_layers) != len(thresholds):
    raise ValueError(
        f'expected one threshold per Leaky layer, got {len(thresholds)} thresholds '
        f'for {len(lif_layers)} layers'
    )
for lif_layer, threshold in zip(lif_layers, thresholds):
    lif_layer.threshold = torch.tensor(threshold)

# Threshold Balancing by IF

In [None]:
# Restore the sequential SNN from the state dict saved earlier, so the IF-based
# balancing below starts from that saved state.
scnn_sd = torch.load('hen2020_snn.pth', weights_only=True)
# Last expression of the cell: its result (the key-matching report) is displayed.
scnn_seq.load_state_dict(scnn_sd, strict=False)

<All keys matched successfully>

In [None]:
# Draw a single batch of 256 training images to calibrate the thresholds on.
# NOTE(review): shuffle=True with no seed set means a different batch (and thus
# different thresholds) on every run -- confirm whether that is intended.
TB_loader = DataLoader(train, batch_size=256, shuffle=True)
# Rebinds the notebook-level `data` / `target` names to this calibration batch.
data, target = next(iter(TB_loader))

In [None]:
steps = 8


def balance_thresholds_by_if(net, batch, steps, q=99.9):
    """Set every Leaky layer's threshold from the input it receives, layer by layer.

    The same `batch` is presented at each of `steps` timesteps. The network is
    walked front to back; whenever a `snn.Leaky` layer is reached, the values
    arriving at it over all timesteps are pooled, its threshold is set to their
    q-th percentile, and only then is the layer run to produce the spikes that
    feed the following layers.

    Args:
        net: an iterable of layers (here the nn.Sequential SNN).
        batch: input tensor presented unchanged at every timestep.
        steps: number of simulated timesteps.
        q: percentile (0-100) of the pooled inputs used as threshold.

    Returns:
        Tensor of the network's per-timestep outputs, stacked along a new
        leading time dimension.
    """
    with torch.no_grad():
        # Start from cleared neuron state.
        sutils.reset(net)
        # Per-timestep signal entering the current layer.
        signal = [batch for _ in range(steps)]
        for layer in net:
            if isinstance(layer, snn.Leaky):
                pooled = torch.stack(signal)
                p_q = np.percentile(pooled.cpu().numpy(), q=q)
                layer.threshold = torch.tensor(p_q)
                del pooled  # the stacked copy is large for the early conv layers

                spikes = []
                # Timesteps are fed in order so the neuron state evolves as in a real run.
                for step_input in signal:
                    out = layer(step_input)
                    # A Leaky without init_hidden returns (spk, mem); keep the spikes.
                    spikes.append(out[0] if isinstance(out, tuple) else out)
                signal = spikes
            else:
                signal = [layer(step_input) for step_input in signal]
    return torch.stack(signal)


out_stack = balance_thresholds_by_if(scnn_seq, data, steps)
print(out_stack.shape)

# Result

In [None]:
# Rebinds `criteria` (the CNN's CrossEntropyLoss until now) to snntorch's
# ce_count_loss; the evaluation cell applies it to the stacked per-timestep
# output spikes.
criteria = sF.ce_count_loss()

In [None]:
epochs = 1  # not used below: the evaluation is a single pass over the test set
timesteps = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

scnn_module.to(device)
scnn_module.eval()
n_eval = len(test_loader.dataset)
with torch.no_grad():
    loss = 0
    correct = 0
    for data, target in tqdm(test_loader):
        data, target = data.to(device), target.to(device)
        # Every batch starts from cleared neuron state.
        sutils.reset(scnn_module)

        # Present the same images at each timestep and collect the output spikes.
        spk_step = []
        for step in range(timesteps):
            spk, mem = scnn_module(data)
            spk_step.append(spk)
        spk_batch = torch.stack(spk_step)

        # The loss is a batch mean (default reduction); weight it by the batch
        # size so the final division by the sample count is a per-sample average.
        loss += criteria(spk_batch, target).item() * data.size(0)
        # Predicted class = output neuron with the highest spike count over time.
        _, pred = spk_batch.sum(0).max(1)
        correct += (pred == target).sum().item()

    print(f'\n{loss/n_eval},\t{100*correct/n_eval}%')

# Result log: [loss, test accuracy %]
#by ReLU
#99.9, steps=32, [, 72.02], no threshold balancing for the neurons after the linear layers
#99.9, steps=32, [, 10]

#by IF
#max, steps=32, balancing done without running the for loop, [, 15.17]

100%|██████████| 157/157 [00:23<00:00,  6.70it/s]


0.1995559422492981,	15.17%



