<a href="https://colab.research.google.com/github/Kazuto-Takahashi/Research/blob/main/Spiking_Xception.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install spikingjelly

Collecting spikingjelly
  Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl.metadata (15 kB)
Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl (437 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m437.6/437.6 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: spikingjelly
Successfully installed spikingjelly-0.0.0.0.14


In [5]:
import torch
import torch.nn as nn

import spikingjelly
from spikingjelly.activation_based import neuron, layer as Snn, functional as SF

from tqdm import tqdm

In [23]:
class SepConv(nn.Module):
    def __init__(self, inc, outc, stride=1):
        super(SepConv, self).__init__()
        self.stride = stride
        self.layer = nn.Sequential(
            Snn.Conv2d(inc, inc, 3, self.stride, 1, groups=inc, bias=False),
            neuron.IFNode(),
            Snn.Conv2d(inc, outc, 1, bias=False)
        )
    def forward(self, x):
        x = self.layer(x)
        return x

class BasicBlock(nn.Module):
    def __init__(self, inc, outc, lif=True):
        super(BasicBlock, self).__init__()
        self.down_sample = True if inc != outc else False
        self.stride = 2 if self.down_sample else 1
        self.conv1x1 = Snn.Conv2d(inc, outc, 2, 2, bias=False)
        layer = []
        layer.append(neuron.IFNode()) if lif else None
        layer.append(SepConv(inc, outc, self.stride))
        layer.append(Snn.BatchNorm2d(outc))
        layer.append(neuron.IFNode())
        layer.append(SepConv(outc, outc))
        layer.append(Snn.BatchNorm2d(outc))
        self.layer = nn.Sequential(*layer)

    def forward(self, x):
        out = self.layer(x)
        if self.down_sample:
            x = self.conv1x1(x)
        out += x
        return out

class S_Xception(nn.Module):
    def __init__(self, T):
        super(S_Xception, self).__init__()
        self.T = T
        self.first = nn.Sequential(
            Snn.Conv2d(3, 32, 3, 2, 1, bias=False),
            Snn.BatchNorm2d(32),
            neuron.IFNode()
        )
        self.block1 = BasicBlock(32, 64, False)
        self.block2 = BasicBlock(64, 64)
        self.block3 = BasicBlock(64, 64)
        self.block4 = BasicBlock(64, 128)
        self.last = nn.Sequential(
            SepConv(128, 256),
            Snn.BatchNorm2d(256),
            neuron.IFNode(),
            Snn.AdaptiveAvgPool2d((1, 1)),
            Snn.Flatten(),
            Snn.Linear(256, 64)
        )
        SF.set_step_mode(self, 'm')
    def forward(self, x):# N, C, H, W -> T, N, D
        SF.reset_net(self)
        x = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        x = self.first(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.last(x)
        return x

In [24]:
#preview
x = torch.rand(4, 3, 32, 32)
model = S_Xception(2)
total_params = sum(p.numel() for p in model.parameters())
print(total_params)
out = model(x)
print(out.shape)

178816
torch.Size([2, 4, 64])
