In [1]:
import argparse
import datetime
import os
import sys
import time

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
from spikingjelly import visualizing
from spikingjelly.activation_based import neuron, functional, surrogate, layer
from torch.cuda import amp
from torch.utils.tensorboard import SummaryWriter

In [2]:
class CSNN(nn.Module):
    """Convolutional spiking neural network for single-channel 28x28 images.

    The same static image is presented at each of ``T`` time-steps. The output
    spikes of the last IF layer are averaged over time, giving one firing rate
    per class (10 classes).

    NOTE: the IF neurons keep membrane state between calls. Reset the network
    (e.g. ``functional.reset_net(net)``) after each batch.

    Args:
        T: number of simulation time-steps.
        channels: number of feature channels in both convolutional layers.
        use_cupy: if True, switch the spikingjelly layers to the ``cupy`` backend.
    """

    def __init__(self, T: int, channels: int, use_cupy=False):
        super().__init__()
        self.T = T

        # Multi-step layers: every tensor below is a [T, N, ...] sequence.
        self.conv_fc = nn.Sequential(
            # [T, N, 1, 28, 28] -> [T, N, channels, 28, 28]
            layer.Conv2d(1, channels, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(channels),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(2, 2),  # -> [T, N, channels, 14, 14]

            layer.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(channels),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(2, 2),  # -> [T, N, channels, 7, 7]

            layer.Flatten(),  # -> [T, N, channels * 7 * 7]
            # Fully connected: channels*7*7 features -> channels*4*4 (= 16*channels) features.
            layer.Linear(channels * 7 * 7, channels * 4 * 4, bias=False),
            neuron.IFNode(surrogate_function=surrogate.ATan()),

            layer.Linear(channels * 4 * 4, 10, bias=False),  # -> [T, N, 10] class scores
            neuron.IFNode(surrogate_function=surrogate.ATan()),
        )
        # 'm' (multi-step): each layer processes the whole time sequence in one call.
        functional.set_step_mode(self, step_mode='m')

        if use_cupy:
            functional.set_backend(self, backend='cupy')

    def forward(self, x: torch.Tensor):
        """Run the network on a batch of static images.

        Args:
            x: image batch, expected shape [N, 1, 28, 28].

        Returns:
            Output firing rates averaged over the T time-steps, shape [N, 10].
        """
        # Present the same image at every time-step: [N, C, H, W] -> [T, N, C, H, W].
        x_seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        x_seq = self.conv_fc(x_seq)
        fr = x_seq.mean(0)
        return fr

In [None]:
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this Fashion-MNIST CSNN script."""
    parser = argparse.ArgumentParser(description="LIF FMNIST Training")
    # NOTE(review): "FMIST" looks like a typo of "FMNIST". It is kept as the default
    # so data already downloaded under this path is reused.
    parser.add_argument("--data_dir", type=str, default="data/FMIST")
    parser.add_argument("--device", type=str, default=('cuda' if torch.cuda.is_available() else 'cpu'))
    parser.add_argument("-T", default=1, type=int, help="simulating time-steps")
    parser.add_argument("-b", default=64, type=int, help="batch size")
    parser.add_argument(
        "-epochs",
        default=3,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument(
        "-j",
        default=4,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 4)",
    )
    parser.add_argument(
        "-out-dir",
        type=str,
        default="./logs",
        help="root dir for saving logs and checkpoint",
    )
    parser.add_argument("-resume", type=str, help="resume from the checkpoint path")
    parser.add_argument(
        "-amp", action="store_true", help="automatic mixed precision training"
    )
    parser.add_argument(
        "-opt",
        type=str,
        choices=["sgd", "adam"],
        default="adam",
        help="use which optimizer. SGD or Adam",
    )
    parser.add_argument("-momentum", default=0.9, type=float, help="momentum for SGD")
    parser.add_argument("-lr", default=1e-3, type=float, help="learning rate")
    # NOTE(review): currently unused. CSNN is built from IFNode neurons, which have no tau.
    parser.add_argument(
        "-tau", default=2.0, type=float, help="parameter tau of LIF neuron"
    )
    return parser


def main():
    """Parse options, then build the CSNN model and the Fashion-MNIST data loaders.

    TODO: the training / evaluation loop is not implemented yet. The model and
    loaders below are constructed but not used.
    """
    # parse_known_args ignores unrecognized argv entries instead of exiting.
    # An IPython/Jupyter kernel is started with such extra arguments.
    args, _ = _build_arg_parser().parse_known_args()

    net = CSNN(T=args.T, channels=32, use_cupy=False)
    net.to(args.device)

    # Page-locked host memory only speeds up host -> CUDA copies. Recent PyTorch
    # versions also warn when pin_memory=True and no accelerator is available.
    pin_memory = str(args.device).startswith("cuda")

    train_dataset = torchvision.datasets.FashionMNIST(
        root=args.data_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )

    test_dataset = torchvision.datasets.FashionMNIST(
        root=args.data_dir,
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )

    train_data_loader = data.DataLoader(
        dataset=train_dataset,
        batch_size=args.b,
        shuffle=True,  # reshuffle the data at the start of every epoch
        drop_last=True,  # discard the final incomplete batch
        num_workers=args.j,
        pin_memory=pin_memory,
    )
    test_data_loader = data.DataLoader(
        dataset=test_dataset,
        batch_size=args.b,
        shuffle=False,
        drop_last=False,  # evaluate on every test sample
        num_workers=args.j,
        pin_memory=pin_memory,
    )


if __name__ == "__main__":
    main()

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FMIST/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26.4M/26.4M [00:07<00:00, 3.50MB/s]


Extracting data/FMIST/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FMIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FMIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29.5k/29.5k [00:00<00:00, 105kB/s]


Extracting data/FMIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FMIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FMIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4.42M/4.42M [00:02<00:00, 2.01MB/s]


Extracting data/FMIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FMIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FMIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5.15k/5.15k [00:00<00:00, 11.9MB/s]


Extracting data/FMIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FMIST/FashionMNIST/raw

