In [1]:
import numpy as np

# NumPy 호환성 패치 (NumPy 1.20+ 대응)
if not hasattr(np, 'int'):
    np.int = int
if not hasattr(np, 'float'):
    np.float = float
if not hasattr(np, 'bool'):
    np.bool = bool
if not hasattr(np, 'complex'):
    np.complex = complex
if not hasattr(np, 'object'):
    np.object = object

  if not hasattr(np, 'object'):


In [2]:
import torch
import torch.nn as nn
import spikingjelly
from spikingjelly.activation_based import neuron, functional, surrogate, layer, encoding


In [3]:
nn.Sequential( # stack layer
    nn.Flatten(),
    nn.Linear(28 * 28, 10, bias=False), # 학습 X, (input 수, output 수) => (input, 10)
    nn.Softmax()
    )

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=10, bias=False)
  (2): Softmax(dim=None)
)

In [4]:
tau = 2.0
net = nn.Sequential(
    layer.Flatten(),
    layer.Linear(28 * 28, 10, bias=False),
    neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
    )

In [5]:
lr = 1e-3

# Use Adam optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
# Use PoissonEncoder
encoder = encoding.PoissonEncoder()

mnist

In [6]:
import os
import time
import argparse
import sys
import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.cuda import amp
from torch.utils.tensorboard import SummaryWriter
import torchvision
import numpy as np

from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer

In [7]:
class SNN(nn.Module):
    def __init__(self, tau):
        super().__init__()

        self.layer = nn.Sequential(
            layer.Flatten(),
            layer.Linear(28 * 28, 10, bias=False),
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),
        )

    def forward(self, x: torch.Tensor):
        return self.layer(x)

In [8]:
parser = argparse.ArgumentParser(description="LIF MNIST Training")
parser.add_argument("-T", default=100, type=int, help="simulating time-steps")
parser.add_argument("-device", default="cuda:0", help="device")
parser.add_argument("-b", default=64, type=int, help="batch size")
parser.add_argument(
        "-epochs",
        default=100,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
parser.add_argument(
        "-j",
        default=4,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 4)",
    )
parser.add_argument("-data-dir", type=str, help="root dir of MNIST dataset")
parser.add_argument(
        "-out-dir",
        type=str,
        default="./logs",
        help="root dir for saving logs and checkpoint",
    )
parser.add_argument("-resume", type=str, help="resume from the checkpoint path")
parser.add_argument(
        "-amp", action="store_true", help="automatic mixed precision training"
    )
parser.add_argument(
        "-opt",
        type=str,
        choices=["sgd", "adam"],
        default="adam",
        help="use which optimizer. SGD or Adam",
    )
parser.add_argument("-momentum", default=0.9, type=float, help="momentum for SGD")
parser.add_argument("-lr", default=1e-3, type=float, help="learning rate")
parser.add_argument(
        "-tau", default=2.0, type=float, help="parameter tau of LIF neuron"
    )

args = parser.parse_args([]) # 오류 : Jupyter Notebook에서 argparse 사용 시, 빈 리스트를 전달하여 기본값을 사용하도록 설정
print(args)

Namespace(T=100, device='cuda:0', b=64, epochs=100, j=4, data_dir=None, out_dir='./logs', resume=None, amp=False, opt='adam', momentum=0.9, lr=0.001, tau=2.0)


In [9]:
net = SNN(tau=args.tau)
print(net)

SNN(
  (layer): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1, step_mode=s)
    (1): Linear(in_features=784, out_features=10, bias=False)
    (2): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
  )
)


In [10]:
net.to(args.device)

SNN(
  (layer): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1, step_mode=s)
    (1): Linear(in_features=784, out_features=10, bias=False)
    (2): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
  )
)

In [None]:
train_dataset = torchvision.datasets.MNIST(
        root=args.data_dir if args.data_dir else './data',
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )
test_dataset = torchvision.datasets.MNIST(
        root=args.data_dir if args.data_dir else './data',
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )

train_data_loader = data.DataLoader(
        dataset=train_dataset, # 어느 데이터 불러올지
        batch_size=args.b,
        shuffle=True, # epoch마다 데이터 섞기
        drop_last=True, # 나누어떨어지지 않는 마지막 batch 버리기
        num_workers=args.j,# subprocess 수
        pin_memory=True, # cpu 공간 고정. gpu로 옮길 때 빠르게 옮길 수 있도록 도와줌
    )
test_data_loader = data.DataLoader(
        dataset=test_dataset,
        batch_size=args.b,
        shuffle=False,
        drop_last=False,
        num_workers=args.j,
        pin_memory=True,
    )

In [15]:
scaler = None

In [None]:
if args.amp: # gpu 속도 향상
    scaler = amp.GradScaler()

In [17]:
start_epoch = 0
max_test_acc = -1

In [18]:
optimizer = None
if args.opt == "sgd":
        optimizer = torch.optim.SGD(
            net.parameters(), lr=args.lr, momentum=args.momentum
        )
elif args.opt == "adam":
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
else:
        raise NotImplementedError(args.opt)

In [None]:
if args.resume: # 학습 재개. 체크포인트 데이터 불러오기
        checkpoint = torch.load(args.resume, map_location="cpu")
        net.load_state_dict(checkpoint["net"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        start_epoch = checkpoint["epoch"] + 1
        max_test_acc = checkpoint["max_test_acc"]

In [None]:
out_dir = os.path.join(args.out_dir, f"T{args.T}_b{args.b}_{args.opt}_lr{args.lr}")
''' 파일 경로 합치기. 
붙일 경로가 /로 시작하면 앞에 경로 버리기
/가 없거나 앞의 경로에 붙어있다면 그냥 경로 잇기'''

In [None]:
if args.amp:
    out_dir += "_amp" # 경로 이름 끝에 _amp 붙이기

In [None]:
if not os.path.exists(out_dir): # out_dir이 존재하지 않으면 만들기
    os.makedirs(out_dir)
    print(f"Mkdir {out_dir}.")

Mkdir ./logs/T100_b64_adam_lr0.001.


In [23]:
with open(os.path.join(out_dir, "args.txt"), "w", encoding="utf-8") as args_txt:
    args_txt.write(str(args))

In [None]:
writer = SummaryWriter(out_dir, purge_step=start_epoch) # 로그 저장하는 파일 생성
with open(os.path.join(out_dir, "args.txt"), "w", encoding="utf-8") as args_txt:
        args_txt.write(str(args))
        args_txt.write("\n")
        args_txt.write(" ".join(sys.argv))

In [25]:
encoder = encoding.PoissonEncoder()