<a href="https://colab.research.google.com/github/TK-brsq/Research/blob/main/Simple_Conversion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from tqdm import tqdm

In [2]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.25, 0.25, 0.25))])

train = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(dataset=train, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test, batch_size=64, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:15<00:00, 10948466.77it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


CNN

In [None]:
cnn = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding='same', bias=False),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.AvgPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same', bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.AvgPool2d(2),
    nn.Flatten(),
    nn.Dropout(p=0.2),
    nn.Linear(8*8*64, 1024, bias=False),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(1024, 10, bias=False)
)
#4layer(2conv(3->32->64), 2fc(4096->1024->10), Avgpool), epochs=4, weight_decay=1e-5 acc=[77.32, 71.89]
#上のモデルに次を加える. Linearの前にDropout(p=0.2), scheduler=MultiStep(mile=[4, 7], gamma0.5), epochs=8, acc=[79.29, 75.20]
#さらにbias=False, mile=[4, 6, 7], acc=[87.78, 78.14]

Sequential(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same, bias=False)
  (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ReLU()
  (7): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (8): Flatten(start_dim=1, end_dim=-1)
  (9): Dropout(p=0.2, inplace=False)
  (10): Linear(in_features=4096, out_features=1024, bias=False)
  (11): ReLU()
  (12): Dropout(p=0.2, inplace=False)
  (13): Linear(in_features=1024, out_features=10, bias=False)
)


In [None]:
optimizer = optim.Adam(params=cnn.parameters(), weight_decay=1e-5)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[4, 6, 7], gamma=0.5)
criteria = nn.CrossEntropyLoss()

In [None]:
epochs = 8
for epoch in range(epochs):
    cnn.train()
    trloss = 0
    trcorrect = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        out = cnn(data)
        loss = criteria(out, target)
        loss.backward()
        optimizer.step()

        trloss += loss.item()
        _, pred = torch.max(out, 1)
        trcorrect += (pred == target).sum().item()
    scheduler.step()

    cnn.eval()
    tsloss = 0
    tscorrect = 0
    for data, target in test_loader:
        out = cnn(data)
        loss = criteria(out, target)

        tsloss += loss.item()
        _, pred = out.max(1)
        tscorrect += (pred == target).sum().item()

    N = len(train_loader.dataset)
    n = len(test_loader.dataset)
    print(f'epoch: {epoch+1}\n train : loss={trloss/N:.4f}, acc={trcorrect*100/N:.2f}% \t test : loss={tsloss/n:.4f}, acc={tscorrect*100/n:.2f}%\n')

epoch: 1
 train : loss=0.0202, acc=53.26% 	 test : loss=0.0161, acc=63.36%

epoch: 2
 train : loss=0.0151, acc=65.86% 	 test : loss=0.0136, acc=69.87%

epoch: 3
 train : loss=0.0127, acc=71.20% 	 test : loss=0.0127, acc=71.72%

epoch: 4
 train : loss=0.0112, acc=74.77% 	 test : loss=0.0118, acc=73.89%

epoch: 5
 train : loss=0.0087, acc=80.33% 	 test : loss=0.0106, acc=76.61%

epoch: 6
 train : loss=0.0078, acc=82.35% 	 test : loss=0.0106, acc=76.70%

epoch: 7
 train : loss=0.0064, acc=85.89% 	 test : loss=0.0102, acc=78.25%

epoch: 8
 train : loss=0.0055, acc=87.78% 	 test : loss=0.0102, acc=78.14%



# ここから

SNN

In [3]:
!pip install snntorch

Collecting snntorch
  Downloading snntorch-0.9.1-py2.py3-none-any.whl.metadata (16 kB)
Collecting nir (from snntorch)
  Downloading nir-1.0.4-py3-none-any.whl.metadata (5.8 kB)
Collecting nirtorch (from snntorch)
  Downloading nirtorch-1.0-py3-none-any.whl.metadata (3.6 kB)
Downloading snntorch-0.9.1-py2.py3-none-any.whl (125 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.3/125.3 kB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading nir-1.0.4-py3-none-any.whl (18 kB)
Downloading nirtorch-1.0-py3-none-any.whl (13 kB)
Installing collected packages: nir, nirtorch, snntorch
Successfully installed nir-1.0.4 nirtorch-1.0 snntorch-0.9.1


In [4]:
import snntorch as snn
import snntorch.utils as sutils
import snntorch.functional as sF

In [5]:
scnn = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding='same', bias=False),
    nn.BatchNorm2d(32),
    snn.Leaky(beta=1.0, init_hidden=True),
    nn.AvgPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same', bias=False),
    nn.BatchNorm2d(64),
    snn.Leaky(beta=1.0, init_hidden=True),
    nn.AvgPool2d(2), #[batch, channels, 16, 16] -> [batch, channels, 8, 8]
    nn.Flatten(), # -> [batch, channels*8*8]
    nn.Linear(8*8*64, 1024, bias=False),
    snn.Leaky(beta=1.0, init_hidden=True),
    nn.Linear(1024, 10, bias=False), # -> [batch, channels*8*8]
    snn.Leaky(beta=1.0)
)

In [6]:
#torch.save(cnn.state_dict(), 'simple_cnn.pth')
cnn_state_dict = torch.load('simple_cnn.pth', weights_only=True)

conversion

In [7]:
#conversion
scnn_state_dict = {}
scnn_state_dict = {k: v for k, v in cnn_state_dict.items() if k in scnn.state_dict()}
scnn_state_dict['9.weight'] = cnn_state_dict['10.weight']
scnn_state_dict['11.weight'] = cnn_state_dict['13.weight']

scnn.load_state_dict(scnn_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['2.threshold', '2.graded_spikes_factor', '2.reset_mechanism_val', '2.beta', '6.threshold', '6.graded_spikes_factor', '6.reset_mechanism_val', '6.beta', '10.threshold', '10.graded_spikes_factor', '10.reset_mechanism_val', '10.beta', '12.threshold', '12.graded_spikes_factor', '12.reset_mechanism_val', '12.beta'], unexpected_keys=[])

In [8]:
criterion = sF.ce_count_loss()
#acc = sF.accuracy_rate()

In [10]:
batch = 64
epochs = 1
timesteps = 256
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
scnn.to(device)

for epoch in range(epochs):
    scnn.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(train_loader):
            data, target = data.to(device), target.to(device)
            sutils.reset(scnn)
            spk_step = []
            for step in range(timesteps):
                spk, mem = scnn(data) #len(spk)= 64(batch)
                spk_step.append(spk)
                del spk, mem
            spk_batch = torch.stack(spk_step) #spk_batch.shape = [timesteps, batch, class]

            loss += criterion(spk_batch, target).item()
            _, pred = spk_batch.sum(dim=0).max(1)
            correct += (pred == target).sum().item()

            del spk_batch, spk_step
            torch.cuda.empty_cache()
        #n = len(test_loader.dataset)
        N = len(train_loader.dataset)
        print(loss/N, correct*100/N)
#timesteps = 4, [68.39 ,62.58]
#timesteps = 16, [74,368, 67.42]
#timesteps = 64, [74.944, 67.73]
#timesteps = 128, [75.018, ]
#timesteps = 256, [75.042]
#cnn = [87, 78]

100%|██████████| 782/782 [06:45<00:00,  1.93it/s]

0.29229616631150246 75.042



