In [1]:
from DataLoader import MyOwnDataloader
from pycocotools.coco import COCO

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt

In [2]:
dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [13]:

class Net(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, timesteps: int):
        super(Net, self).__init__()
        self.timesteps = timesteps

        self.conv1 = nn.Conv2d(in_channels, 64,
                               kernel_size=7,
                               padding=3,
                               # no bias because it is not bio-plausible (and hard to impl in neuromorphic hardware)
                               bias=True,
                               dilation=1,
                               stride=2)
        self.spike1 = snn.Leaky(
            beta=0.5, spike_grad=surrogate.fast_sigmoid(slope=25), init_hidden=False)

        # residual block 2
        self.conv2 = nn.Conv2d(64, 64,
                               kernel_size=3,
                               padding=1,
                               # no bias because it is not bio-plausible (and hard to impl in neuromorphic hardware)
                               bias=True,
                               stride=2)
        self.spike2 = snn.Leaky(
            beta=0.5, spike_grad=surrogate.fast_sigmoid(slope=25), init_hidden=False)

        self.conv3 = nn.Conv2d(64, 128,
                               kernel_size=3,
                               padding=1,
                               bias=True,
                               stride=2)
        self.spike3 = snn.Leaky(
            beta=0.5, spike_grad=surrogate.fast_sigmoid(slope=25), init_hidden=False)

        # residual block 3
        self.conv4 = nn.Conv2d(128, 256,
                               kernel_size=3,
                               padding=1,
                               # no bias because it is not bio-plausible (and hard to impl in neuromorphic hardware)
                               bias=True,
                               stride=2)
        self.spike4 = snn.Leaky(
            beta=0.5, spike_grad=surrogate.fast_sigmoid(slope=25), init_hidden=False)

        self.conv5 = nn.Conv2d(256, 512,
                               kernel_size=3,
                               padding=1,
                               bias=True,
                               stride=2)
        self.spike5 = snn.Leaky(
            beta=0.5, spike_grad=surrogate.fast_sigmoid(slope=25), init_hidden=False)

        # classifying layers
        self.avg_pool = nn.AdaptiveAvgPool2d((512, 10))

        self.flat = nn.Flatten()
        self.dropout = nn.Dropout()
        self.fc = nn.Linear(512, out_channels, bias=True)
        self.fc_spike = snn.Leaky(beta=0.5, spike_grad=surrogate.fast_sigmoid(
            slope=25), init_hidden=False, output=True)

        # self.final = nn.Linear(128, out_channels, bias=True)

    def forward(self, inputs):
        # resets every LIF neurons
        mem_spike1 = self.spike1.init_leaky()
        mem_spike2 = self.spike2.init_leaky()
        mem_spike3 = self.spike3.init_leaky()
        mem_spike4 = self.spike4.init_leaky()
        mem_spike5 = self.spike5.init_leaky()

        mem_fc_spike = self.fc_spike.init_leaky()

        # mem accumulator to get the prediction
        accumulator = []

        for k in range(self.timesteps):
            x = inputs[k, :, :, :]
            x = F.max_pool2d(self.conv1(x), 2)
            x, mem_spike1 = self.spike1(x, mem_spike1)

            x = F.max_pool2d(self.conv2(x), 2)
            x, mem_spike2 = self.spike2(x, mem_spike2)

            x = F.max_pool2d(self.conv3(x), 2)
            x, mem_spike3 = self.spike3(x, mem_spike3)

            x = F.max_pool2d(self.conv4(x), 2)
            x, mem_spike4 = self.spike4(x, mem_spike4)

            x = self.conv5(x)
            x, mem_spike5 = self.spike5(x, mem_spike5)

            x = self.avg_pool(x)

            # classifier
            # x = self.flat(x)
            x = self.dropout(x)
            x = self.fc(x)
            x, mem_fc_spike = self.fc_spike(x, mem_fc_spike)

            x = self.final(x)

            accumulator.append(mem_fc_spike)

        return accumulator


In [14]:
dataDir='/media/gamedisk/COCO_dataset/'
val='val2017'
train = 'train2017'

val_annFile='{}/annotations/instances_{}.json'.format(dataDir,val)
train_annFile='{}/annotations/instances_{}.json'.format(dataDir,train) 
# Batch size
batch_size = 16



classes = {
    "bird": 1,
    "cat": 2,
    "dog": 3,
    "horse": 4,
    "sheep": 5,
    "cow": 6,
    "elephant": 7,
    "bear": 8,
    "zebra": 9,
    "giraffe": 10
}


coco = COCO(val_annFile)
val_loader = MyOwnDataloader(dataDir = dataDir, dataType = val,
                     annFile = val_annFile, classes = classes, train_batch_size=batch_size, classifier=True)
valid_dl = val_loader.concat_datasets()


# coco = COCO(train_annFile)
# train_loader = MyOwnDataloader(dataDir = dataDir, dataType = train,
#                      annFile = train_annFile, classes = classes, train_batch_size=batch_size)
# train_dl = val_loader.concat_datasets()



loading annotations into memory...
Done (t=0.22s)
creating index...
index created!
loading annotations into memory...
Done (t=0.38s)
creating index...
index created!


In [15]:
net = Net(3,10, 100)
net = net.to(device)

In [16]:
data, targets = next(iter(valid_dl))


In [18]:
loss_fn = SF.ce_rate_loss()
net.train()
for data, target in tqdm(valid_dl):
    label= int(target[0]['labels'][0])
    imgs = list(img.to(device) for img in data)
    imgs = torch.stack(imgs)
    spk_rec, mem_rec = net(imgs)
    loss_val = loss_fn(spk_rec, target)
    print(loss_val)



  0%|          | 0/69 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (512x10 and 512x10)

In [None]:
def batch_accuracy(train_loader, net, num_steps):
  with torch.no_grad():
    total = 0
    acc = 0
    net.eval()

    train_loader = iter(train_loader)
    for data, targets in train_loader:
      data = data.to(device)
      targets = targets.to(device)
      spk_rec, _ = forward_pass(net, num_steps, data)

      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
      total += spk_rec.size(1)

  return acc/total

In [None]:
test_acc = batch_accuracy(test_loader, net, num_steps)


In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.999))
num_epochs = 10
test_acc_hist = []

# training loop
for epoch in range(num_epochs):

    avg_loss = backprop.BPTT(net, train_loader, optimizer=optimizer, criterion=loss_fn,
                            num_steps=num_steps, time_var=False, device=device)

    print(f"Epoch {epoch}, Train Loss: {avg_loss.item():.2f}")

    # Test set accuracy
    test_acc = batch_accuracy(train_loader, net, num_steps)
    test_acc_hist.append(test_acc)

    print(f"Epoch {epoch}, Test Acc: {test_acc * 100:.2f}%\n")