<a href="https://colab.research.google.com/github/Kazuto-Takahashi/Research/blob/main/Spiking_Xception.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install spikingjelly

Collecting spikingjelly
  Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl.metadata (15 kB)
Downloading spikingjelly-0.0.0.0.14-py3-none-any.whl (437 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m437.6/437.6 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: spikingjelly
Successfully installed spikingjelly-0.0.0.0.14


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler as lrs
from torch.utils.data import DataLoader, ConcatDataset
import spikingjelly
from spikingjelly.activation_based import neuron, layer as snn, functional as SF

import torchvision
from torchvision import datasets
from torchvision.transforms import v2 as TF

from tqdm import tqdm

# Model

In [None]:
class SepConv(nn.Module):
    """Separable convolution with a spiking neuron between its two stages.

    The block is a 3x3 depthwise convolution (``groups=inc``, padding 1, the
    given ``stride``), an ``IFNode``, and a bias-free 1x1 pointwise convolution
    mapping ``inc`` channels to ``outc`` channels.

    Args:
        inc: number of input channels.
        outc: number of output channels.
        stride: stride of the depthwise convolution (default 1).
    """

    def __init__(self, inc, outc, stride=1):
        super().__init__()
        self.stride = stride
        depthwise = snn.Conv2d(inc, inc, kernel_size=3, stride=stride,
                               padding=1, groups=inc, bias=False)
        spike = neuron.IFNode()
        pointwise = snn.Conv2d(inc, outc, kernel_size=1, bias=False)
        # Sub-module order (depthwise, neuron, pointwise) fixes the state_dict keys.
        self.layer = nn.Sequential(depthwise, spike, pointwise)

    def forward(self, x):
        """Apply depthwise conv -> IFNode -> pointwise conv to ``x``."""
        return self.layer(x)

class BasicBlock(nn.Module):
    """Residual block made of two separable convolutions, each followed by BatchNorm.

    Main branch: ``[IFNode] -> SepConv(inc, outc, stride) -> BN -> IFNode ->
    SepConv(outc, outc) -> BN``. When ``inc != outc`` the block down-samples
    (stride 2 in the first ``SepConv``) and the shortcut is projected with a
    bias-free 2x2, stride-2 convolution; otherwise the shortcut is the identity.

    Args:
        inc: number of input channels.
        outc: number of output channels.
        lif: if True, an ``IFNode`` is placed in front of the first ``SepConv``.

    Note:
        The projection is only created for down-sampling blocks, so identity
        blocks no longer carry an unused weight. Checkpoints saved with the
        older layout contain an extra ``conv1x1.weight`` entry for such blocks
        and need ``load_state_dict(..., strict=False)``.
    """

    def __init__(self, inc, outc, lif=True):
        super().__init__()

        self.down_sample = inc != outc
        self.stride = 2 if self.down_sample else 1
        # Attribute name kept for compatibility; the kernel is 2x2 with stride 2.
        # Identity blocks never call it, so do not allocate parameters for them.
        self.conv1x1 = snn.Conv2d(inc, outc, 2, 2, bias=False) if self.down_sample else None

        layer = []
        if lif:
            layer.append(neuron.IFNode())
        layer.append(SepConv(inc, outc, self.stride))
        layer.append(snn.BatchNorm2d(outc))
        layer.append(neuron.IFNode())
        layer.append(SepConv(outc, outc))
        layer.append(snn.BatchNorm2d(outc))
        self.layer = nn.Sequential(*layer)

    def forward(self, x):
        """Return ``layer(x) + shortcut(x)``."""
        out = self.layer(x)
        if self.down_sample:
            # Project the input so its shape matches the main branch.
            x = self.conv1x1(x)
        out += x
        return out

class S_Xception(nn.Module):
    """Small Xception-style spiking network with a 10-way linear head.

    Layout: a stride-2 3x3 conv stem (3 -> 32 channels) with BatchNorm and an
    ``IFNode``; four ``BasicBlock`` stages (32->64, 64->64, 64->64, 64->128);
    and a head of ``SepConv(128, 256)``, BatchNorm, ``IFNode``, global average
    pooling, flatten and ``Linear(256, 10)``. All sub-modules are switched to
    step mode ``'m'`` at construction time.

    Args:
        T: number of time steps the input is repeated for.
    """

    def __init__(self, T=4):
        super().__init__()
        self.T = T

        stem = [
            snn.Conv2d(3, 32, 3, 2, 1, bias=False),
            snn.BatchNorm2d(32),
            neuron.IFNode(),
        ]
        self.first = nn.Sequential(*stem)

        self.block1 = BasicBlock(32, 64, False)
        self.block2 = BasicBlock(64, 64)
        self.block3 = BasicBlock(64, 64)
        self.block4 = BasicBlock(64, 128)

        head = [
            SepConv(128, 256),
            snn.BatchNorm2d(256),
            neuron.IFNode(),
            snn.AdaptiveAvgPool2d((1, 1)),
            snn.Flatten(),
            snn.Linear(256, 10),
        ]
        self.last = nn.Sequential(*head)

        SF.set_step_mode(self, 'm')

    def forward(self, x):
        """Run the network on a batch and average the outputs over time.

        ``x`` is expected as (N, C, H, W). It is copied ``T`` times along a new
        leading axis, passed through every stage, and the result is averaged
        over that leading axis, so the returned tensor no longer has a time axis.
        ``SF.reset_net`` is called first on every forward pass.
        """
        SF.reset_net(self)
        seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        stages = (self.first, self.block1, self.block2, self.block3, self.block4, self.last)
        for stage in stages:
            seq = stage(seq)
        return seq.mean(0)

In [None]:
# Build the network with T=4 and print how many parameter elements it holds in total.
model = S_Xception(4)
params = sum(map(torch.Tensor.numel, model.parameters()))
print(params)

# Utils

In [None]:
class DataAugmentation:
    """Random augmentation pipeline that yields two views of the same input.

    The pipeline is: random resized crop to 32 (scale 0.36-1), horizontal flip
    (p=0.5), colour jitter applied with p=0.8, random grayscale (p=0.2), then
    conversion to a float32 image scaled by ``ToDtype(..., scale=True)``.

    Args:
        device: optional device to remember on the instance. When omitted it is
            auto-detected (CUDA if available, otherwise CPU) instead of being
            read from a notebook-level global, so the class can be instantiated
            before the training cell has run. The transforms do not use it.
    """

    def __init__(self, device=None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        color_jitter = TF.ColorJitter(0.8, 0.8, 0.8, 0.2)
        self.tf = TF.Compose([
            TF.RandomResizedCrop(32, (0.36, 1)),
            TF.RandomHorizontalFlip(p=0.5),
            TF.RandomApply([color_jitter], p=0.8),
            TF.RandomGrayscale(p=0.2),
            TF.ToImage(),
            TF.ToDtype(torch.float32, scale=True)
        ])

    def __call__(self, x):
        """Return a pair of independently augmented versions of ``x``."""
        return self.tf(x), self.tf(x)

In [None]:
def get_loader(data='cifar10', split='train', batch_size=128, DA=False):
    tf = DataAugmentation() if DA else TF.Compose([TF.ToImage(), TF.ToDtype(torch.float32, scale=True)])
    if data == 'cifar10':
        match split:
            case 'train':
                data = datasets.CIFAR10('./data', train=True, transform=tf, download=True)
            case 'test':
                data = datasets.CIFAR10('./data', train=False, transform=tf, download=True)
            case 'all':
                train = datasets.CIFAR10('./data', train=True, transform=tf, download=True)
                test = datasets.CIFAR10('./data', train=False, transform=tf, download=True)
                data = ConcatDataset([train, test])
    elif data == 'stl10':
        match split:
            case 'train':
                data = datasets.STL10('./data', split='train', transform=tf, download=True)
            case 'test':
                data = datasets.STL10('./data', split='test', transform=tf, download=True)
            case 'all':
                data = datasets.STL10('./data', split='unlabeled', transform=tf, download=True)
    else:
        print(f'{data} is not supported >_<. cifar10 or stl10 is supported')
    loader = DataLoader(data, batch_size, shuffle=True, drop_last=True, num_workers=2)
    return loader

In [None]:
def train_(loader, model, optimizer, scheduler, criterion, device):
    """Train ``model`` for one pass over ``loader`` and step the scheduler once.

    Args:
        loader: iterable of ``(data, target)`` batches.
        model: network to optimise; put into training mode here.
        optimizer: optimizer stepped after every batch.
        scheduler: learning-rate scheduler stepped once after the last batch.
        criterion: loss function called as ``criterion(out, target)``.
        device: device each batch is moved to.

    Returns:
        ``(running_loss, correct)``: the sum of the per-batch loss values and
        the number of samples whose arg-max prediction equals the target.
    """
    model.train()
    loss_sum, n_correct = 0, 0
    for inputs, labels in tqdm(loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Accumulate plain Python numbers so no graph is kept alive.
        loss_sum += batch_loss.item()
        hits = logits.argmax(1) == labels
        n_correct += hits.sum().item()
    scheduler.step()
    return loss_sum, n_correct

# Training

In [None]:
# Instantiate everything needed for training.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#------------------------#
loader = get_loader('cifar10', split='train', batch_size=128, DA=False)
# The loader drops the last incomplete batch, so count only the samples that are
# actually seen in one epoch; len(loader.dataset) would understate the accuracy.
N = len(loader) * loader.batch_size
#------------------------#
model = S_Xception(4).to(device)
#------------------------#
start_epoch = 0
epochs = 8
optimizer = optim.SGD(model.parameters(), lr=0.2)
# One cosine half-period over the whole run (train_ steps the scheduler once per epoch).
# CosineAnnealingLR takes no `total_iters` argument; passing it raises a TypeError.
scheduler1 = lrs.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.1)
#------------------------#
criterion = nn.CrossEntropyLoss()
# Disabled experiment-tracking template, kept as an inert string literal.
'''
wandb.login()
run = wandb.init(
    project = 'name',
    config = {
        'Architecture': 'x',
        'optim': 'Adam(1e-3)',
        'sche1': 'x',
        'sche2': 'x',
        'sche': 'x',
        'criterion': 'x',
        'Data': 'Cifar10',
        'else': 'x'
    }
)
'''
# Train; `loss` is the sum of the per-batch losses returned by train_.
for epoch in range(start_epoch, epochs):
    loss, correct = train_(loader, model, optimizer, scheduler1, criterion, device)
    print(f'Epoch: {epoch} | loss: {loss} | acc: {correct*100/N}%')

#save_checkpoint('!!!', simclr, optimizer, scheduler)

Files already downloaded and verified


100%|██████████| 390/390 [00:32<00:00, 12.07it/s]


Epoch: 0 | loss: 732.2418007850647 | acc: 28.482%


100%|██████████| 390/390 [00:31<00:00, 12.34it/s]


Epoch: 1 | loss: 611.673143029213 | acc: 41.548%


100%|██████████| 390/390 [00:31<00:00, 12.23it/s]


Epoch: 2 | loss: 556.4308239221573 | acc: 47.2%


100%|██████████| 390/390 [00:32<00:00, 12.07it/s]


Epoch: 3 | loss: 519.8156929016113 | acc: 50.92%


100%|██████████| 390/390 [00:31<00:00, 12.37it/s]


Epoch: 4 | loss: 490.36688554286957 | acc: 54.042%


100%|██████████| 390/390 [00:31<00:00, 12.34it/s]


Epoch: 5 | loss: 463.4905755519867 | acc: 56.954%


100%|██████████| 390/390 [00:32<00:00, 12.15it/s]


Epoch: 6 | loss: 444.6516178846359 | acc: 58.702%


100%|██████████| 390/390 [00:31<00:00, 12.33it/s]

Epoch: 7 | loss: 425.8461194038391 | acc: 60.96%





In [None]:
#SResNet14(4) | Cifar10 | sf=default | lr=2e-3 | MultiStep([4, 6], 0.75) | 1min/epoch | 0 1077, 49.232% -> 1 757, 65.62% -->> 4 392, 82.526% -->> 9 150, 93.22% (test: 79.65%)
#SResNet14_group2(4) downsample(stride=2) | Cifar10 | lr=default | 40s/epoch | 33% -->> 10ep, 69.9%
# Control experiment vs. the above, group=1 | 33% -->> 73%
# Control experiment vs. the above, stride=3 | 30% -->> 73%
# Control experiment vs. the above, half the channels | 29% -->> 63%

#11/11
# ResNet()