<a href="https://colab.research.google.com/github/Kazuto-Takahashi/Research/blob/main/IFneuron.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm

import math

In [None]:
!pip install snntorch

Collecting snntorch
  Downloading snntorch-0.9.1-py2.py3-none-any.whl.metadata (16 kB)
Collecting nir (from snntorch)
  Downloading nir-1.0.4-py3-none-any.whl.metadata (5.8 kB)
Collecting nirtorch (from snntorch)
  Downloading nirtorch-1.0-py3-none-any.whl.metadata (3.6 kB)
Downloading snntorch-0.9.1-py2.py3-none-any.whl (125 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.3/125.3 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading nir-1.0.4-py3-none-any.whl (18 kB)
Downloading nirtorch-1.0-py3-none-any.whl (13 kB)
Installing collected packages: nir, nirtorch, snntorch
Successfully installed nir-1.0.4 nirtorch-1.0 snntorch-0.9.1


In [None]:
import snntorch as snn
import snntorch.functional as SF

# Neuron Definition

In [None]:
class surrogate_activation(torch.autograd.Function):
  """Heaviside spike function with a surrogate gradient.

  Forward: emits 1 where ``input >= thresh`` and 0 elsewhere, with the same
  dtype as ``input``.
  Backward: replaces the step's derivative (zero almost everywhere) with
  ``1 / (SLOPE * |input - thresh| + 1)``, which peaks at the firing threshold.
  """

  # Steepness of the surrogate; larger values concentrate gradient near thresh.
  SLOPE = 25.0

  @staticmethod
  def forward(ctx, input, thresh):
    """
    Args:
      input: membrane potential tensor.
      thresh: firing threshold (a Python number in this notebook).
    Returns:
      Binary spike tensor with the same shape and dtype as ``input``.
    """
    ctx.save_for_backward(input)
    ctx.thresh = thresh
    # Cast the boolean mask instead of torch.where(..., 0., 1.): the
    # Python-scalar branches produced float64 spikes (see preview output),
    # which up-cast the membrane and clash with float32 Linear weights.
    return (input >= thresh).to(input.dtype)

  @staticmethod
  def backward(ctx, grad_output):
    """Return the gradient w.r.t. ``input``; ``thresh`` gets no gradient."""
    input, = ctx.saved_tensors
    # Centre the surrogate on the threshold, where the spike decision is made,
    # rather than on a membrane potential of 0.
    grad = 1.0 / (surrogate_activation.SLOPE * torch.abs(input - ctx.thresh) + 1.0)
    return grad_output * grad, None  # this None corresponds to forward's thresh

In [None]:
class IF(nn.Module):
  """Integrate-and-fire neuron with soft reset, stateful across calls.

  Each call multiplies the stored membrane potential by ``leak``, adds the
  input ``x``, and emits a spike (via ``surrogate_activation``) where the
  potential reaches ``thresh``. It then subtracts ``thresh`` from the
  potential of neurons that fired.

  Args:
    thresh: firing threshold.
    leak: multiplicative decay of the stored membrane potential applied
      before integrating new input; 1.0 (default) means no leak (pure IF).

  The membrane potential persists between calls so a time-step loop can
  integrate over time. Call ``reset()`` before each new, independent input
  sequence (e.g. each batch); otherwise the autograd graph attached to the
  membrane keeps growing across batches.
  """

  def __init__(self, thresh = 1.0, leak = 1.0):
    super(IF, self).__init__()
    self.mem = None  # membrane potential; created lazily on first forward
    self.thresh = thresh
    self.leak = leak

  def reset(self):
    """Forget the membrane potential (and the autograd history attached to it)."""
    self.mem = None

  def forward(self, x):
    """Integrate one time step of input ``x`` and return the binary spikes."""
    # Start from rest on first use, or when the input shape changes
    # (e.g. a smaller final DataLoader batch) instead of failing to broadcast.
    if self.mem is None or self.mem.shape != x.shape:
      self.mem = torch.zeros_like(x)

    self.mem = self.leak * self.mem + x
    spike = surrogate_activation.apply(self.mem, self.thresh)
    # Soft reset: subtract the threshold (not a hard-coded 1) where fired.
    self.mem = self.mem - spike * self.thresh
    return spike

# Preview: spike train of a single IF neuron

In [None]:
# Drive one IF neuron per element with a constant input for 4 time steps and
# display the spike train (rows = time steps, columns = neurons).
preview_current = torch.tensor([0.3, 0.4, 0.5])
if_neuron = IF()

spike_pot = [if_neuron(preview_current) for _ in range(4)]
torch.stack(spike_pot)

[tensor([0., 0., 0.], dtype=torch.float64,
       grad_fn=<surrogate_activationBackward>)]
[tensor([0., 0., 0.], dtype=torch.float64,
       grad_fn=<surrogate_activationBackward>), tensor([0., 0., 1.], dtype=torch.float64,
       grad_fn=<surrogate_activationBackward>)]
[tensor([0., 0., 0.], dtype=torch.float64,
       grad_fn=<surrogate_activationBackward>), tensor([0., 0., 1.], dtype=torch.float64,
       grad_fn=<surrogate_activationBackward>), tensor([0., 1., 0.], dtype=torch.float64,
       grad_fn=<surrogate_activationBackward>)]
[tensor([0., 0., 0.], dtype=torch.float64,
       grad_fn=<surrogate_activationBackward>), tensor([0., 0., 1.], dtype=torch.float64,
       grad_fn=<surrogate_activationBackward>), tensor([0., 1., 0.], dtype=torch.float64,
       grad_fn=<surrogate_activationBackward>), tensor([1., 0., 1.], dtype=torch.float64,
       grad_fn=<surrogate_activationBackward>)]


In [None]:
# Surrogate-gradient sanity check: run a single IF time step, use the total
# spike count as the loss, and print the gradient that reaches the input.
probe_input = torch.tensor([0.3, 0.4, 0.5], requires_grad=True)
probe_neuron = IF()

spike_total = probe_neuron(probe_input).sum()
spike_total.backward()
print(probe_input.grad)

tensor([0.1176, 0.0909, 0.0741])


# MNIST classification with a two-layer spiking MLP

In [None]:
from torchvision import datasets, transforms

In [None]:
# Image -> tensor, then normalise with mean 0.5 and std 0.25.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.25,)),
])

# Training split first, then the test split (both downloaded if missing).
train, test = (
    datasets.MNIST('./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

tr_loader = DataLoader(train, batch_size=64, shuffle=True)
ts_loader = DataLoader(test, batch_size=64, shuffle=False)

In [None]:
class Model(nn.Module):
  """Two-layer spiking MLP: Flatten -> Linear -> IF -> Linear -> IF.

  One call to ``forward`` is one time step. Both IF layers keep their
  membrane potential between calls, so call ``reset()`` before feeding a new,
  independent batch.

  Args:
    n_in: flattened input size (default 784 = 28 * 28).
    n_hidden: width of the hidden spiking layer (default 784).
    n_out: number of output neurons (default 10).
  """

  def __init__(self, n_in=784, n_hidden=784, n_out=10):
    super(Model, self).__init__()
    self.flatten = nn.Flatten()
    self.linear1 = nn.Linear(n_in, n_hidden)
    self.act1 = IF()
    self.linear2 = nn.Linear(n_hidden, n_out)
    self.act2 = IF()

  def reset(self):
    """Clear the membrane potential of every IF layer in the model."""
    for module in self.modules():
      if isinstance(module, IF):
        module.mem = None

  def forward(self, x):
    """Run one time step and return the output-layer spikes."""
    x = self.flatten(x)
    x = self.act1(self.linear1(x))
    return self.act2(self.linear2(x))

In [None]:
# Network, Adam optimiser (its default learning rate written out explicitly)
# and snnTorch's spike-count cross-entropy loss.
model = Model()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = SF.ce_count_loss()

In [None]:
# NOTE: do not re-create `model` here. `optimizer` (previous cell) holds that
# model's parameters; a fresh Model() would never be updated by optimizer.step().

epochs = 1
steps = 1  # simulation time steps per input batch


def reset_if_states(net):
  """Clear the membrane potential of every IF layer in ``net``.

  Without this, membrane tensors and their autograd graph carry over from
  the previous batch. Backward then needs retain_graph=True, each iteration
  gets slower as the graph grows, and a differently sized batch (e.g. a
  smaller last batch) no longer matches the stored membrane shape.
  """
  for module in net.modules():
    if isinstance(module, IF):
      module.mem = None


for epoch in range(epochs):
  model.train()
  for idx, (data, target) in enumerate(tqdm(tr_loader)):
    reset_if_states(model)  # each batch is an independent input sequence
    optimizer.zero_grad()
    # Stack per-step output spikes along a leading time dimension.
    out_stc = torch.stack([model(data) for _ in range(steps)])
    loss = criterion(out_stc, target)
    loss.backward()
    optimizer.step()

    if idx % 50 == 0:
      print(f'idx: {idx}, loss:{loss.item()}')
  print(f'idx: {idx}, Loss: {loss.item()}')

4it [00:00, 35.25it/s]

idx: 0, loss:2.3025851249694824


52it [00:04,  7.30it/s]

idx: 50, loss:2.288417339324951


101it [00:16,  3.60it/s]

idx: 100, loss:2.3633053302764893


122it [00:23,  5.26it/s]


KeyboardInterrupt: 