<a href="https://colab.research.google.com/github/TKph/Research/blob/main/SNN_with_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install snntorch

Collecting snntorch
  Downloading snntorch-0.9.1-py2.py3-none-any.whl (125 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.3/125.3 kB[0m [31m931.9 kB/s[0m eta [36m0:00:00[0m
Collecting nir (from snntorch)
  Downloading nir-1.0.4-py3-none-any.whl (18 kB)
Collecting nirtorch (from snntorch)
  Downloading nirtorch-1.0-py3-none-any.whl (13 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.1.0->snntorch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.1.0->snntorch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.1.0->snntorch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.1.0->snntorch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.wh

In [None]:
import numpy as np
import matplotlib as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import snntorch
from snntorch import spikegen
from snntorch import utils
from snntorch import functional as SF

**Data** — CIFAR-100: 32×32 RGB images in 100 classes (50,000 training images).

In [None]:
import torchvision
from torchvision.datasets import CIFAR100

In [None]:
batch_size = 128

# Without a transform CIFAR100 yields PIL images, which the DataLoader's default
# collate cannot batch. ToTensor converts each image to a float tensor in [0, 1]
# with layout (channel, h, w) = (3, 32, 32).
transform = torchvision.transforms.ToTensor()

trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform) # raw trainset.data: (data,h,w,channel)=(50000,32,32,3)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
# Evaluation does not depend on sample order; keep it deterministic.
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:10<00:00, 15849437.31it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
(50000, 32, 32, 3)
Files already downloaded and verified


**Model**

In [None]:
# Shared by every Leaky layer in Net: the membrane decay factor and the
# surrogate gradient used to backpropagate through the spike non-linearity.
beta, spike_grad = 0.5, snntorch.surrogate.fast_sigmoid(slope=25)

In [None]:
class Net(nn.Module):
  """Convolutional spiking network: (conv -> max-pool -> LIF) x 2, then fc -> LIF.

  Each call to ``forward`` simulates ONE time step. The Leaky layers are built
  with ``init_hidden=True`` so each layer keeps its own membrane potential and
  carries it over to the next call; ``snntorch.utils.reset(net)`` clears that
  state before a new input is simulated (as ``forward_timestep`` does).
  """

  def __init__(self):
    super().__init__()

    self.conv1 = nn.Conv2d(3, 12, 5) #inchannel, outchannel, kernelsize
    self.lif1 = snntorch.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
    self.conv2 = nn.Conv2d(12, 64, 5)
    self.lif2 = snntorch.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True)
    # For 32x32 input: conv(5) -> 28, pool(2) -> 14, conv(5) -> 10, pool(2) -> 5
    self.fc1 = nn.Linear(64*5*5, 100)
    # output=True makes the last layer return (spikes, membrane) instead of spikes only.
    self.lif3 = snntorch.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)

  def forward(self, x):
    """Advance the network by one time step.

    Args:
      x: input batch; the layer sizes above assume (batch, 3, 32, 32).

    Returns:
      (spk3, mem3): spikes and membrane potential of the 100-unit output layer.
    """
    cur1 = F.max_pool2d(self.conv1(x), 2)
    spk1 = self.lif1(cur1)
    cur2 = F.max_pool2d(self.conv2(spk1), 2)
    spk2 = self.lif2(cur2)
    # Flatten per sample; the batch size comes from the input, not a global.
    cur3 = self.fc1(spk2.flatten(start_dim=1))
    spk3, mem3 = self.lif3(cur3)
    return spk3, mem3

net = Net()

In [None]:
def forward_timestep(net, num_steps, data):
  """Run `net` on the same `data` for `num_steps` consecutive time steps.

  The network's hidden state is cleared with ``utils.reset(net)`` first, and
  ``net(data)`` is expected to return a ``(spikes, membrane)`` pair per step.

  Returns:
    (spikes, membranes): the per-step outputs stacked along a new leading
    time dimension of length `num_steps`.
  """
  utils.reset(net)

  spk_steps, mem_steps = [], []
  for _ in range(num_steps):
    spk_t, mem_t = net(data)
    spk_steps += [spk_t]
    mem_steps += [mem_t]

  return torch.stack(spk_steps, dim=0), torch.stack(mem_steps, dim=0)

# Smoke test on one training batch. spk and mem are stacked time-first, so each
# has a leading dimension of 10 (time steps) followed by Net's output dims.
data, targets = next(iter(train_loader))
spk, mem = forward_timestep(net, 10, data)

**Loss function, metrics and optimizer**

In [None]:
# Rate-coded cross-entropy loss from snntorch.functional.
loss_fn = SF.ce_rate_loss()
# Accuracy of the still-untrained network on the batch simulated in the previous cell.
acc = SF.accuracy_rate(spk, targets)
# `lr` is Adam's learning-rate keyword (`ler` raised a TypeError).
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.999))

In [None]:
def data_accuracy(loader, net, sum_steps):
  """Return the accuracy of `net` over every batch in `loader`.

  Args:
    loader: iterable of (data, targets) batches.
    net: the spiking network, simulated with ``forward_timestep``.
    sum_steps: number of time steps to simulate per batch.

  Returns:
    Sample-weighted accuracy: each batch's ``SF.accuracy_rate`` is weighted by
    its batch size, ``spk.size(1)`` (``spk`` is stacked time-first).

  Raises:
    ValueError: if `loader` yields no batches.

  The network is put in eval mode for the evaluation and then restored to
  whatever mode it was in before, so calling this mid-training does not leave
  the model stuck in eval mode.
  """
  was_training = net.training
  net.eval()
  correct = 0
  total = 0
  try:
    with torch.no_grad():
      for data, targets in loader:
        spk, _ = forward_timestep(net, sum_steps, data)
        n_batch = spk.size(1)
        correct += SF.accuracy_rate(spk, targets) * n_batch
        total += n_batch
  finally:
    net.train(was_training)

  if total == 0:
    raise ValueError("data_accuracy: loader yielded no batches")
  return correct / total

**Training**

In [None]:
# Training bookkeeping: number of passes over train_loader, lists for
# recording loss/accuracy during training, and a step counter.
epochs = 1
loss_hist, acc_hist = [], []
counter = 0

for epoch in range(epochs):
  for data, targets in train_loader: