In [6]:
import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, functional, neuron
from spikingjelly.activation_based.model import train_classify

In [3]:
# Simulation length (number of single-step calls below) and per-step width.
T, N = 8, 1

def element_wise_add(x, y):
    return x + y

# Wrap an IF neuron in a recurrent container that merges inputs via
# element_wise_add. Per the SpikingJelly docs, the previous output is fed back.
# That is why, in the printed output, the spike persists after the single pulse.
net = layer.ElementWiseRecurrentContainer(neuron.IFNode(), element_wise_add)
print(net)

# A single input pulse at step 0; every later step receives zero input.
x = torch.zeros(T, N)
x[0] = 1.5

# Iterating the tensor walks its first dimension, i.e. x[0] ... x[T-1].
for t, x_t in enumerate(x):
    print(t, f'x[t]={x_t}, s[t]={net(x_t)}')

# Reset the network's internal state before it is reused.
functional.reset_net(net)

ElementWiseRecurrentContainer(
  element-wise function=<function element_wise_add at 0x7f1dd62af5b0>, step_mode=s
  (sub_module): IFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch
    (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
  )
)
0 x[t]=tensor([1.5000]), s[t]=tensor([1.])
1 x[t]=tensor([0.]), s[t]=tensor([1.])
2 x[t]=tensor([0.]), s[t]=tensor([1.])
3 x[t]=tensor([0.]), s[t]=tensor([1.])
4 x[t]=tensor([0.]), s[t]=tensor([1.])
5 x[t]=tensor([0.]), s[t]=tensor([1.])
6 x[t]=tensor([0.]), s[t]=tensor([1.])
7 x[t]=tensor([0.]), s[t]=tensor([1.])


In [5]:
# A 3->16 channel, shape-preserving 3x3 convolution followed by a SynapseFilter
# (tau=100). The filter carries state across calls, hence "stateful".
stateful_conv = nn.Sequential(
    layer.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1),
    layer.SynapseFilter(tau=100.0),
)

In [1]:
import torch
import torch.nn as nn
from spikingjelly.activation_based import surrogate, neuron, functional
from spikingjelly.activation_based.model import spiking_resnet

# Spiking ResNet-18 built from IF neurons with an ATan surrogate gradient.
s_resnet18 = spiking_resnet.spiking_resnet18(pretrained=False, spiking_neuron=neuron.IFNode, surrogate_function=surrogate.ATan(), detach_reset=True)

print(f's_resnet18={s_resnet18}')

# Inference-only demo: torch.no_grad() disables autograd but does NOT stop
# BatchNorm from updating its running statistics in training mode. Switch to
# eval() so the random input below does not pollute the BN buffers.
s_resnet18.eval()
with torch.no_grad():
    T = 4  # length of the leading (time) dimension of x_seq
    N = 1  # number of samples per step
    x_seq = torch.rand([T, N, 3, 224, 224])
    # Multi-step mode: the whole [T, N, C, H, W] sequence is processed in one call.
    functional.set_step_mode(s_resnet18, 'm')
    try:
        y_seq = s_resnet18(x_seq)
        print(f'y_seq.shape={y_seq.shape}')
    finally:
        # Always reset the network's internal state, even if the forward pass fails,
        # so later cells do not start from stale neuron state.
        functional.reset_net(s_resnet18)

s_resnet18=SpikingResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False, step_mode=s)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
  (sn1): IFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
    (surrogate_function): ATan(alpha=2.0, spiking=True)
  )
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, step_mode=s)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn1): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bi

In [7]:
class MyTrainer(train_classify.Trainer):
    """Trainer that additionally understands the Adamax optimizer."""

    def set_optimizer(self, args, parameters):
        """Build the optimizer named by ``args.opt``.

        Names starting with ``"adamax"`` (case-insensitive) produce
        ``torch.optim.Adamax`` configured with ``args.lr`` and
        ``args.weight_decay``. Every other name is handed to the base class's
        ``set_optimizer`` unchanged.

        Args:
            args: parsed options; reads ``opt``, ``lr`` and ``weight_decay``.
            parameters: parameters (or parameter groups) to optimize.

        Returns:
            The constructed optimizer.
        """
        if not args.opt.lower().startswith("adamax"):
            # Not ours: defer to the base Trainer implementation.
            return super().set_optimizer(args, parameters)
        return torch.optim.Adamax(parameters, lr=args.lr, weight_decay=args.weight_decay)