# Convolutional SNN
### **Classifying Fashion-MNIST with Convolutional SNN**

This tutorial goes over how to train a convolutional spiking neural network (CSNN) on the Fashion-MNIST dataset and deploy on HiAER Spike using our conversion pipline.

### **Define a CSNN**
To build a CSNN with PyTorch, we can use snnTorch, SpikingJelly or other deep learning frameworks that are based on PyTorch. Currently, our conversion pipline supports snnTorch and SpikingJelly. In this tutorial, we will be using SpikingJelly.

Install the PyPi distribution of SpikingJelly

Import necessary libraries from SpikingJelly and PyTorch

In [1]:
from spikingjelly.activation_based import neuron, functional, surrogate, layer
import torch 
import os
import numpy as np
from skimage import io, transform
import tifffile as tiff
from torchvision.transforms import v2
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

### **Model Architecture**
Using SpikingJelly, we can define a CSNN with the architecture of 8C3-BN-6272FC10
- 8C3: a 3x3 convolutional kernel with 8 channels
- BN: batch normalization layer 
- 6272FC10: the fully connected output layer 
 
#### **Surrogate Function**
SpikingJelly and snnTorch both use backpropagation through time to train the spiking neural networks. However, because of the non-differentiability of spikes, surrogate gradients are used in place of the Heaviside function in the backward pass

In [2]:
class model(nn.Module): 
    def __init__(self, channels=8): 
        super().__init__()
        #first block
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 2, kernel_size=1, padding=1, bias=False)
        self.lif1 = neuron.IFNode(surrogate_function=surrogate.ATan())
        self.conv2 = nn.Conv2d(in_channels = 2, out_channels = 12, kernel_size=3, padding=1, bias=False)
        self.lif2 = neuron.IFNode(surrogate_function=surrogate.ATan())
        self.conv3 = nn.Conv2d(in_channels = 12, out_channels = 12, kernel_size=3, padding=1, bias=False)
        self.lif3 = neuron.IFNode(surrogate_function=surrogate.ATan())
        
        #downsample
        self.conv4 = nn.Conv2d(in_channels = 12,out_channels = 24, kernel_size=3, stride=2, padding=1, bias=False)
        #upsample
        self.conv5 = nn.ConvTranspose2d(in_channels = 24,out_channels = 12, kernel_size=2, stride=2, bias=False)
        
        self.conv6 = nn.Conv2d(in_channels = 24,out_channels = 12, kernel_size = 3, padding=1, bias=False)
        self.lif4 = neuron.IFNode(surrogate_function=surrogate.ATan())
        self.conv7 = nn.Conv2d(in_channels = 12,out_channels = 12, kernel_size = 3, padding=1, bias=False)
        self.lif5 = neuron.IFNode(surrogate_function=surrogate.ATan())
        self.conv8 = nn.Conv2d(in_channels=12, out_channels= 2, kernel_size=1, bias=False)
        self.lif6 = neuron.IFNode(surrogate_function=surrogate.ATan())
        
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.lif1(x)
        x = self.conv2(x)
        x = self.lif2(x)
        x = self.conv3(x)
        x = self.lif3(x)
        #crop
        y = v2.CenterCrop(size=54)(x)
        #downsample
        x = self.conv4(x)
        x = v2.CenterCrop(size=27 )(x)
        z = self.conv5(x)
        #print(y.size())
        #print(z.size())
        x = torch.concatenate([y,z], axis=1)
        #print(x.size())
        x = self.conv6(x)
        x = self.lif4(x)
        x = self.conv7(x)
        x = self.lif5(x)
        x = self.conv8(x)
        x = self.lif6(x)
        
        return x

In [3]:
#Initiate the Network
net = model()

### **Setting up the MNIST Dataset**

In [4]:
class EMDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, root_dir, transform=None, test=False):
        """
        Arguments:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform        
        if test:
            self.img_trn = tiff.imread(os.path.join(root_dir,'test-volume.tif'))
            self.msk_trn = tiff.imread(os.path.join(root_dir,'test-labels.tif'))
        else:
            self.img_trn = tiff.imread(os.path.join(root_dir,'train-volume.tif'))
            self.msk_trn = tiff.imread(os.path.join(root_dir,'train-labels.tif'))
        expand = np.zeros((len(self.msk_trn),2,512,512))
        for idx,mask in enumerate(self.msk_trn):
            expand[idx,0,:,:] = mask==0
            expand[idx,1,:,:] = mask==255
        self.msk_trn = expand
        self.msk_trn.astype(float)
        print(np.shape(self.msk_trn))

    def __len__(self):
        return len(self.img_trn)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self.img_trn[idx]
        mask = self.msk_trn[idx]
        #extract the channels
        
        
        sample = {'image': image, 'mask': mask}

        if self.transform:
            sample = self.transform(sample)

        return sample

In [5]:
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, mask = sample['image'], sample['mask']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))
        msk = transform.resize(mask,  (2, new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        #landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'mask': mask}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, mask = sample['image'], sample['mask']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        top = torch.randint(low = 0, high = h - new_h + 1, size=(1,1))
        left = torch.randint(low = 0, high= w - new_w + 1, size=(1,1))

        image = image[top: top + new_h,
                      left: left + new_w]

        mask = mask[:,top: top + new_h,
                      left: left + new_w]

        return {'image': image, 'mask': mask}
    
class flipHorizontal(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self):
        return

    def __call__(self, sample):
        image, mask = sample['image'], sample['mask']

        
        # Random horizontal flipping
        rand_int = torch.randint(low = 0, high = 1, size=(1,1))
        if rand_int > 0.5:
            image = F.hflip(image)
            mask = F.hflip(mask)

        return {'image': image, 'mask': mask}
    
class flipVertical(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self):
        return

    def __call__(self, sample):
        image, mask = sample['image'], sample['mask']

        
        # Random horizontal flipping
        rand_int = torch.randint(low = 0, high = 1, size=(1,1))
        if rand_int > 0.5:
            image = F.vflip(image)
            mask = F.vflip(mask)

        return {'image': image, 'mask': mask}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, mask = sample['image'], sample['mask']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image[None,:,:]
        #image = image[None,:,:]
        #image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'mask': torch.from_numpy(mask)}

In [6]:
em_dataset = EMDataset(root_dir='/Users/gweneverefrank/code/hs_api/examples/ISBI-2012-challenge',  transform=transforms.Compose([
                                               Rescale(192),
                                               RandomCrop(64),
                                               ToTensor(),
                                               flipHorizontal(),
                                               flipVertical()
                                           ]))

em_dataset_test = EMDataset(root_dir='/Users/gweneverefrank/code/hs_api/examples/ISBI-2012-challenge',  transform=transforms.Compose([
                                               Rescale(192),
                                               RandomCrop(64),
                                               ToTensor()
                                           ]),test=True)


for i, sample in enumerate(em_dataset):
    print(i, sample['image'].shape, sample['mask'].shape)

(30, 2, 512, 512)
(30, 2, 512, 512)
0 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
1 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
2 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
3 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
4 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
5 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
6 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
7 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
8 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
9 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
10 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
11 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
12 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
13 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
14 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
15 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
16 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
17 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
18 torch.Size([1, 64, 64]) torch.Size([2, 64, 64])
19 to

In [7]:
train_loader = DataLoader(em_dataset, batch_size=1,
                        shuffle=True, num_workers=0)

test_loader = DataLoader(em_dataset_test, batch_size=1,
                        shuffle=True, num_workers=0)


for i, sample in enumerate(train_loader):
    print(i, sample['image'].shape, sample['mask'].shape)
    print('type: '+str(sample['image'].type())+' '+str(sample['mask'].type()))

0 torch.Size([1, 1, 64, 64]) torch.Size([1, 2, 64, 64])
type: torch.DoubleTensor torch.DoubleTensor
1 torch.Size([1, 1, 64, 64]) torch.Size([1, 2, 64, 64])
type: torch.DoubleTensor torch.DoubleTensor
2 torch.Size([1, 1, 64, 64]) torch.Size([1, 2, 64, 64])
type: torch.DoubleTensor torch.DoubleTensor
3 torch.Size([1, 1, 64, 64]) torch.Size([1, 2, 64, 64])
type: torch.DoubleTensor torch.DoubleTensor
4 torch.Size([1, 1, 64, 64]) torch.Size([1, 2, 64, 64])
type: torch.DoubleTensor torch.DoubleTensor
5 torch.Size([1, 1, 64, 64]) torch.Size([1, 2, 64, 64])
type: torch.DoubleTensor torch.DoubleTensor
6 torch.Size([1, 1, 64, 64]) torch.Size([1, 2, 64, 64])
type: torch.DoubleTensor torch.DoubleTensor
7 torch.Size([1, 1, 64, 64]) torch.Size([1, 2, 64, 64])
type: torch.DoubleTensor torch.DoubleTensor
8 torch.Size([1, 1, 64, 64]) torch.Size([1, 2, 64, 64])
type: torch.DoubleTensor torch.DoubleTensor
9 torch.Size([1, 1, 64, 64]) torch.Size([1, 2, 64, 64])
type: torch.DoubleTensor torch.DoubleTensor


### **Training the SNN**
Since we are using a static image dataset, we will first encode the image into spikes using the rate encoding function from spikingjelly. With rate encoding, the input feature determines the firing frequency and the neuron that fries the most is selected as the predicted class.  

In [8]:
from spikingjelly.activation_based import encoding
import time

In [9]:
#Setting up the encoder and the time steps
encoder = encoding.PoissonEncoder()
num_steps = 40

#Define training parameters
epochs = 40
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

#Copy netowrk to device 
net.to(device)

#Define optimizer, scheduler and the loss function
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
loss_fun = torch.nn.BCEWithLogitsLoss()

In [10]:
for epoch in range(epochs):
    start_time = time.time()
    net.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for sample in train_loader:
        img = sample['image']
        mask = sample['mask']
        optimizer.zero_grad()
        img = img.to(device)
        mask = mask.to(device)
        #label_onehot = torch.nn.functional.one_hot(label, 10).float()
        out_fr = 0.
        for t in range(num_steps):
            #print(img.size())
            encoded_img = encoder(img)
            #print(encoded_img.size())
            out_fr += net(encoded_img.float())
        out_fr = out_fr/num_steps
        #print('outputsize')
        #print(out_fr.size())
        #out_fr = out_fr[None, :, :, :]
        out_fr = F.interpolate(out_fr, (64,64))
        #out_fr = torch.softmax(out_fr,dim=1).float()
        loss = loss_fun(out_fr.float(), mask.float())
        loss.backward()
        optimizer.step()

        train_samples += 1
        train_loss += loss.item()
        class_labels = mask.argmax(1)
        output_labels = out_fr.argmax(1)
        train_acc += (output_labels == class_labels).float().sum().item() / (64**2)

        #reset the membrane protential after each input image
        functional.reset_net(net)

    train_time = time.time()
    train_speed = train_samples / (train_time - start_time)
    train_loss /= train_samples
    train_acc /= train_samples
    print('finished a bout')
    
    lr_scheduler.step()
        
    net.eval()
    test_loss = 0
    test_acc = 0
    test_samples = 0

    with torch.no_grad():
        for sample in test_loader:
            img = sample['image']
            mask = sample['mask']
            img = img.to(device)
            mask = mask.to(device)
            out_fr = 0.   
            for t in range(num_steps):
                #print(img.size())
                encoded_img = encoder(img)
                #print(encoded_img.size())
                out_fr += net(encoded_img.float())
            out_fr = out_fr/num_steps
            #print(out_fr.size())
            #out_fr = out_fr[None, :, :, :]
            out_fr = F.interpolate(out_fr, (64,64))
            loss = loss_fun(out_fr, mask)

            
            
            test_samples += 1
            test_loss += loss.item()
            test_class_labels = mask.argmax(1)
            test_output_labels = out_fr.argmax(1)
            test_acc += (test_output_labels == test_class_labels).float().sum().item() / (64**2)
            functional.reset_net(net)


    test_time = time.time()
    test_speed = test_samples / (test_time - train_time)
    test_loss /= test_samples
    test_acc /= test_samples

    print(f'epoch = {epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}')
    print(f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s')

finished a bout
epoch = 0, train_loss = 0.6876, train_acc = 0.4770, test_loss = 0.6661, test_acc = 0.7501
train speed = 5.9364 images/s, test speed = 12.9118 images/s
finished a bout
epoch = 1, train_loss = 0.6461, train_acc = 0.7494, test_loss = 0.6376, test_acc = 0.7345
train speed = 5.9349 images/s, test speed = 12.8414 images/s
finished a bout
epoch = 2, train_loss = 0.6119, train_acc = 0.7980, test_loss = 0.6202, test_acc = 0.7762
train speed = 5.9337 images/s, test speed = 12.8454 images/s
finished a bout
epoch = 3, train_loss = 0.6286, train_acc = 0.7536, test_loss = 0.6318, test_acc = 0.7449
train speed = 5.9708 images/s, test speed = 12.6790 images/s
finished a bout
epoch = 4, train_loss = 0.6148, train_acc = 0.7818, test_loss = 0.6132, test_acc = 0.7850
train speed = 5.8946 images/s, test speed = 12.8548 images/s
finished a bout
epoch = 5, train_loss = 0.6240, train_acc = 0.7619, test_loss = 0.6354, test_acc = 0.7370
train speed = 5.8511 images/s, test speed = 12.9549 images/

### **Converting the trained SNN to HiAER Spike Format**

In [12]:
from hs_api.converter import CRI_Converter, Quantize_Network, BN_Folder
from hs_api.api import CRI_network
# import hs_bridge #Uncomment when running on FPGA

#Fold the BN layer 
bn = BN_Folder() 
net_bn = bn.fold(net)

#Weight, Bias Quantization 
qn = Quantize_Network() 
net_quan = qn.quantize(net_bn)

#Set the parameters for conversion
input_layer = 0 #first pytorch layer that acts as synapses
output_layer = 4 #last pytorch layer that acts as synapses
input_shape = (1, 28, 28)
backend = 'spikingjelly'
v_threshold = qn.v_threshold

cn = CRI_Converter(num_steps = num_steps, 
                   input_layer = input_layer, 
                   output_layer = output_layer, 
                   input_shape = input_shape,
                   backend=backend,
                   v_threshold = v_threshold)
cn.layer_converter(net_quan)

TypeError: Quantize_Network.__init__() missing 1 required positional argument: 'w_alpha'

### **Initiate the HiAER Spike SNN**

In [None]:
config = {}
config['neuron_type'] = "I&F"
config['global_neuron_params'] = {}
config['global_neuron_params']['v_thr'] = int(quan_fun.v_threshold)
    
# #Uncomment this to create a network running on the FPGA
# hardwareNetwork = CRI_network(dict(cri_convert.axon_dict),
#                               connections=dict(cri_convert.neuron_dict),
#                               config=config,target='CRI', 
#                               outputs = cri_convert.output_neurons,
#                               coreID=1)

softwareNetwork = CRI_network(dict(cri_convert.axon_dict),
                              connections=dict(cri_convert.neuron_dict),
                              config=config,target='simpleSim', 
                              outputs = cri_convert.output_neurons,
                              coreID=1)

### **Deploying the SNN on HiAER Spike**

run_sw and run_hw are two helper functions for running the spiking neural network 

In [None]:
def Run_sw(self,inputList,softwareNetwork):
    predictions = []
    total_time_cri = 0
    #each image
    for currInput in tqdm(inputList):
        #reset the membrane potential to zero
        softwareNetwork.simpleSim.initialize_sim_vars(len(self.neuron_dict))
        spikeRate = [0]*10
        #each time step
        for slice in currInput:
            start_time = time.time()
            swSpike = softwareNetwork.step(slice, membranePotential=False)

            end_time = time.time()
            total_time_cri = total_time_cri + end_time-start_time
            for spike in swSpike:
                spikeIdx = int(spike) - self.bias_start_idx 
                try: 
                    if spikeIdx >= 0: 
                        spikeRate[spikeIdx] += 1 
                except:
                    print("SpikeIdx: ", spikeIdx,"\n SpikeRate:",spikeRate)
        predictions.append(spikeRate.index(max(spikeRate)))
    print(f"Total simulation execution time: {total_time_cri:.5f} s")
    return(predictions)

In [None]:
def run_CRI_hw(self,inputList,hardwareNetwork):
    predictions = []
    #each image
    total_time_cri = 0
    for currInput in tqdm(inputList):
        #initiate the hardware for each image
        hs_bridge.FPGA_Execution.fpga_controller.clear(len(self.neuron_dict), False, 0)  ##Num_neurons, simDump, coreOverride
        spikeRate = [0]*10
        #each time step
        for slice in tqdm(currInput):
            start_time = time.time()
            hwSpike, latency, hbmAcc = hardwareNetwork.step(slice, membranePotential=False)
            print(f'hwSpike: {hwSpike}\n. latency : {latency}\n. hbmAcc:{hbmAcc}')
            end_time = time.time()
            total_time_cri = total_time_cri + end_time-start_time
            for spike in hwSpike:
                # print(int(spike))
                spikeIdx = int(spike) - self.bias_start_idx 
                try: 
                    if spikeIdx >= 0: 
                        spikeRate[spikeIdx] += 1 
                except:
                    print("SpikeIdx: ", spikeIdx,"\n SpikeRate:",spikeRate)
        predictions.append(spikeRate.index(max(spikeRate))) 
    print(f"Total execution time CRIFPGA: {total_time_cri:.5f} s")
    return(predictions)

In [None]:
cri_convert.bias_start_idx = int(cri_convert.output_neurons[0])
loss_fun = nn.MSELoss()
start_time = time.time()
test_loss = 0
test_acc = 0
test_samples = 0
num_batches = 0

RUN_HARDWARE = False #Set to True if running on FPGA

for img, label in tqdm(test_loader):
    cri_input = cri_convert.input_converter(img)
    output = None
    if RUN_HARDWARE:
        output = torch.tensor(run_CRI_hw(cri_input,hardwareNetwork), dtype=float)
    else:
        output = torch.tensor(run_CRI_sw(cri_input,softwareNetwork), dtype=float)
    loss = loss_fun(output, label)
    test_samples += label.numel()
    test_loss += loss.item() * label.numel()
    test_acc += (output == label).float().sum().item()
    num_batches += 1
test_time = time.time()
test_speed = test_samples / (test_time - start_time)
test_loss /= test_samples
test_acc /= test_samples

print(f'test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}')
print(f'test speed ={test_speed: .4f} images/s')