<a href="https://colab.research.google.com/github/TK-brsq/Research/blob/main/Reuckauer2017.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Conversion of Continuous-Valued Deep Network to Efficient Event-Driven Network for Image Classification**

# 1. Import

In [None]:
!pip install snntorch

Collecting snntorch
  Downloading snntorch-0.9.1-py2.py3-none-any.whl.metadata (16 kB)
Collecting nir (from snntorch)
  Downloading nir-1.0.4-py3-none-any.whl.metadata (5.8 kB)
Collecting nirtorch (from snntorch)
  Downloading nirtorch-1.0-py3-none-any.whl.metadata (3.6 kB)
Downloading snntorch-0.9.1-py2.py3-none-any.whl (125 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.3/125.3 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading nir-1.0.4-py3-none-any.whl (18 kB)
Downloading nirtorch-1.0-py3-none-any.whl (13 kB)
Installing collected packages: nir, nirtorch, snntorch
Successfully installed nir-1.0.4 nirtorch-1.0 snntorch-0.9.1


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils as utils
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets, transforms

from tqdm import tqdm
import numpy as np

import snntorch as snn
import snntorch.functional as sF
import snntorch.utils as sutils

# 2. Data

In [None]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25))])

train = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train, batch_size=64, shuffle=True)
test_loader = DataLoader(test, batch_size=64, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 68753722.88it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


# 3. CNN

In [None]:
cnn = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1, bias=False),
    nn.ReLU(),
    nn.Conv2d(32, 32, 3, padding=1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(2),
    nn.Conv2d(32, 64, 3, padding=1, bias=False),
    nn.ReLU(),
    nn.Conv2d(64, 64, 3, padding=1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(2),
    nn.Flatten(),
    nn.Dropout(p=0.4),
    nn.Linear(4096, 512, bias=False),
    nn.ReLU(),
    nn.Dropout(p=0.4),
    nn.Linear(512, 10, bias=False)
)
#BN=True, p=0.25, miles=[4,6,7,8,9], gamma=0.5, [88.22, 82.86]
#BN=False, p=0.4, decay=1e-4, Exponetial gamma=0.75, [81.113, 77.75]

In [None]:
optimizer = optim.Adam(cnn.parameters(), weight_decay=1e-4)
scheduler = lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.75)
criteria = nn.CrossEntropyLoss()

In [None]:
#activation
activation = {}
def named_hook(name):
    def hook(module, input, output):
        activation[name] = output.detach().cpu().numpy()
    return hook

In [None]:
epochs = 10

for epoch in range(epochs):
    cnn.train()
    tr_loss = 0
    tr_correct = 0
    for data, target in tqdm(train_loader):
        optimizer.zero_grad()
        out = cnn(data)
        loss = criteria(out, target)
        loss.backward()
        optimizer.step()

        tr_loss += loss.item()
        _, pred = out.max(1)
        tr_correct += (pred == target).sum().item()
    scheduler.step()

    #activation archive
    if epoch == epochs - 1:
        for idx, layer in enumerate(cnn):
            if isinstance(layer, nn.ReLU):
                layer.register_forward_hook(named_hook(f'ReLU{idx}'))
        cnn[15].register_forward_hook(named_hook('ReLU15'))
        with torch.no_grad():
            for data, target in train_loader:
                out = cnn(data)

    cnn.eval()
    ts_loss = 0
    ts_correct = 0
    with torch.no_grad():
        for data, target in tqdm(test_loader):
            out = cnn(data)

            ts_loss += criteria(out, target).item()
            _, pred = out.max(1)
            ts_correct += (pred == target).sum().item()

    print(f'\nepoch : {epoch+1}\ntrain = {tr_loss/50000}, {tr_correct/500}%\t test = {ts_loss/10000}, {ts_correct/100}%\n')

100%|██████████| 782/782 [04:18<00:00,  3.02it/s]
100%|██████████| 157/157 [00:20<00:00,  7.77it/s]



epoch : 1
train = 0.0236057322371006, 45.506%	 test = 0.018710120767354965, 57.97%



100%|██████████| 782/782 [04:22<00:00,  2.98it/s]
100%|██████████| 157/157 [00:21<00:00,  7.43it/s]



epoch : 2
train = 0.017440081950426102, 60.842%	 test = 0.015148442205786705, 66.01%



100%|██████████| 782/782 [04:58<00:00,  2.62it/s]
100%|██████████| 157/157 [00:30<00:00,  5.13it/s]



epoch : 3
train = 0.01460567600607872, 67.086%	 test = 0.013816140526533127, 69.84%



100%|██████████| 782/782 [05:16<00:00,  2.47it/s]
100%|██████████| 157/157 [00:33<00:00,  4.76it/s]



epoch : 4
train = 0.012863774104118347, 71.464%	 test = 0.01243299924135208, 72.75%



100%|██████████| 782/782 [05:14<00:00,  2.48it/s]
100%|██████████| 157/157 [00:33<00:00,  4.67it/s]



epoch : 5
train = 0.011505562909245492, 74.238%	 test = 0.011733125907182694, 74.45%



100%|██████████| 782/782 [05:32<00:00,  2.35it/s]
100%|██████████| 157/157 [00:40<00:00,  3.89it/s]



epoch : 6
train = 0.010594038138389588, 76.318%	 test = 0.011203959861397743, 74.81%



100%|██████████| 782/782 [06:52<00:00,  1.90it/s]
100%|██████████| 157/157 [00:44<00:00,  3.54it/s]



epoch : 7
train = 0.009824658742547036, 78.078%	 test = 0.010647248828411102, 76.41%



100%|██████████| 782/782 [07:35<00:00,  1.72it/s]
100%|██████████| 157/157 [00:43<00:00,  3.59it/s]



epoch : 8
train = 0.00921361991584301, 79.55%	 test = 0.0105768824249506, 76.54%



100%|██████████| 782/782 [08:17<00:00,  1.57it/s]
100%|██████████| 157/157 [00:45<00:00,  3.46it/s]



epoch : 9
train = 0.00885140183210373, 80.374%	 test = 0.01034541222155094, 77.34%



100%|██████████| 782/782 [08:56<00:00,  1.46it/s]
100%|██████████| 157/157 [00:42<00:00,  3.65it/s]


epoch : 10
train = 0.008457711308896542, 81.116%	 test = 0.010303056800365447, 77.75%






In [None]:
torch.save(cnn.state_dict(), 'reuckauer2017.pth')
np.save('activation.npy', activation)

# 4. SNN and Conversion

In [None]:
scnn = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1, bias=False),
    snn.Leaky(1,init_hidden=True),
    nn.Conv2d(32, 32, 3, padding=1, bias=False),
    snn.Leaky(1, init_hidden=True),
    nn.AvgPool2d(2),
    nn.Conv2d(32, 64, 3, padding=1, bias=False),
    snn.Leaky(1, init_hidden=True),
    nn.Conv2d(64, 64, 3, padding=1, bias=False),
    snn.Leaky(1, init_hidden=True),
    nn.AvgPool2d(2),
    nn.Flatten(),
    nn.Linear(4096, 512, bias=False),
    snn.Leaky(1, init_hidden=True),
    nn.Linear(512, 10, bias=False),
    snn.Leaky(1)
)

In [None]:
cnn_sd = torch.load('reuckauer2017.pth', weights_only=True)

scnn_sd = {}
scnn_sd['0.weight'] = cnn_sd['0.weight']
scnn_sd['2.weight'] = cnn_sd['2.weight']
scnn_sd['5.weight'] = cnn_sd['5.weight']
scnn_sd['7.weight'] = cnn_sd['7.weight']
scnn_sd['11.weight'] = cnn_sd['12.weight']
scnn_sd['13.weight'] = cnn_sd['15.weight']

scnn.load_state_dict(scnn_sd, strict=False)

_IncompatibleKeys(missing_keys=['1.threshold', '1.graded_spikes_factor', '1.reset_mechanism_val', '1.beta', '3.threshold', '3.graded_spikes_factor', '3.reset_mechanism_val', '3.beta', '6.threshold', '6.graded_spikes_factor', '6.reset_mechanism_val', '6.beta', '8.threshold', '8.graded_spikes_factor', '8.reset_mechanism_val', '8.beta', '12.threshold', '12.graded_spikes_factor', '12.reset_mechanism_val', '12.beta', '14.threshold', '14.graded_spikes_factor', '14.reset_mechanism_val', '14.beta'], unexpected_keys=[])

# 5. Threshold Balancing or Weight Normalization

In [None]:
activation_obj = np.load('activation.npy', allow_pickle=True)
activation = activation_obj.item()

thresholds = []
for k, v in activation.items():
    p999 = np.percentile(v, q=99.9)
    thresholds.append(p999)

In [None]:
scnn[1].threshold = torch.tensor(thresholds[1])
scnn[3].threshold = torch.tensor(thresholds[2]/thresholds[1])
scnn[6].threshold = torch.tensor(thresholds[3]/thresholds[2])
scnn[8].threshold = torch.tensor(thresholds[4]/thresholds[3])
scnn[12].threshold = torch.tensor(thresholds[5]/thresholds[4])
scnn[14].threshold = torch.tensor(thresholds[0]/thresholds[5])

# 6. Result

In [None]:
criteria = sF.ce_count_loss()

In [None]:
epochs = 1
timesteps = 64
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#scnn.to(device)

scnn.eval()
with torch.no_grad():
    loss = 0
    correct = 0
    for data, target in tqdm(test_loader):
        #data, target = data.to(device), target.to(device)
        sutils.reset(scnn)

        spk_step = []
        for step in range(timesteps):
            spk, mem = scnn(data)
            spk_step.append(spk)
        spk_batch = torch.stack(spk_step)

        loss += criteria(spk_batch, target).item()
        _, pred = spk_batch.sum(0).max(1)
        correct += (pred == target).sum().item()

        del spk_batch, spk_step, data, target

    print(f'\n{loss/10000},\t{correct/100}%')

#steps=32, BN=False, WN=False, [, 28.38]
#steps=32, BN=False, TB=True(data-based by ReLU), [, 77.75]->[, 75.0]
#steps=64, same, [, 77.75]->[, 77.2]

100%|██████████| 157/157 [34:42<00:00, 13.26s/it]


0.016365691620111464,	77.2%





# 7. Firing rate

In [None]:
spk_count = dict()
def spike_hook(name, init_hidden):
    def hook(module, input, output):
        if name not in spk_count:
            if init_hidden:
                spk_count[name] = torch.zeros_like(output)
            else: spk_count[name] = torch.zeros_like(output[0])
        if init_hidden:
            spk_count[name] += output
        else: spk_count[name] += output[0]
    return hook

for idx, layer in enumerate(scnn):
    if isinstance(layer, snn.Leaky):
        init_hidden = layer.init_hidden
        layer.register_forward_hook(spike_hook(f'IF{idx}', init_hidden))

In [None]:
data, target = next(iter(test_loader))

timesteps = 8
with torch.no_grad():
    for step in range(timesteps):
        spk, mem = scnn(data)

In [None]:
for k, v in spk_count.items():
    print(k, v.shape)

IF1 torch.Size([64, 32, 32, 32])
IF3 torch.Size([64, 32, 32, 32])
IF6 torch.Size([64, 64, 16, 16])
IF8 torch.Size([64, 64, 16, 16])
IF12 torch.Size([64, 512])
IF14 torch.Size([64, 10])


In [None]:
spk_layer = torch.sum(spk_count['IF12'], 1) / 512
rate = torch.mean(spk_layer)
print(spk_layer, rate) #timestepsで割るのを忘れずに
#1neuronあたり0.01, dataによって大きく異なる, 0も多い, steps=8

tensor([0.0156, 0.0430, 0.0352, 0.0859, 0.0000, 0.1328, 0.0820, 0.0469, 0.0273,
        0.1211, 0.1719, 0.3242, 0.1172, 0.1250, 0.0156, 0.0000, 0.3867, 0.2188,
        0.0469, 0.0977, 0.0742, 0.0195, 0.2422, 0.0000, 0.0039, 0.0000, 0.0000,
        0.3477, 0.0195, 0.0000, 0.0000, 0.0000, 0.0586, 0.2461, 0.0000, 0.0000,
        0.0039, 0.0000, 0.1055, 0.0117, 0.1953, 0.0000, 0.3438, 0.0156, 0.0352,
        0.0078, 0.0977, 0.0000, 0.0352, 0.2227, 0.0078, 0.2422, 0.0547, 0.0352,
        0.0352, 0.0039, 0.0820, 0.0859, 0.0000, 0.0508, 0.0000, 0.3008, 0.0469,
        0.0273]) tensor(0.0805)


In [None]:
print(spk_count['IF14']) #timestepsで割るのを忘れずに

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0