# Install slayerPytorch

In [None]:
# Fetch the slayerPytorch sources and install ninja (presumably needed to
# compile the package's C++/CUDA extensions in the next cell — confirm).
!git clone https://www.github.com/bamsumit/slayerPytorch
!pip install ninja
# Stop the kernel; presumably to force a runtime restart so the fresh
# install is picked up — continue with the next cell afterwards.
exit()

In [None]:
# Build and install the package from the cloned repository; it presumably
# provides the `slayerSNN` module imported in the configuration cell.
%cd slayerPytorch/
!python setup.py install
# Stop the kernel again (presumably so the restarted runtime sees the new package).
exit()

# SNN configuration

In [None]:
import sys, os
CURRENT_TEST_DIR = os.getcwd()
sys.path.append(CURRENT_TEST_DIR + "/../../src")

from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML
import torch
from torch.utils.data import Dataset, DataLoader
import slayerSNN as snn
#from learningStats import learningStats
from IPython.display import HTML
import time
import shutil

def save_ckp(state, is_best_loss, is_best_acc, checkpoint_dir, best_model_dir):
    """Save a training checkpoint and optionally copy it as a "best" model.

    The checkpoint is always written to ``<checkpoint_dir>/checkpoint.pt``.
    It is then copied to ``<best_model_dir>/best_model.pt`` when
    ``is_best_loss`` is true and to ``<best_model_dir>/best_acc.pt`` when
    ``is_best_acc`` is true.

    Args:
        state: Object handed to ``torch.save`` (``load_ckp`` expects a dict
            with ``'state_dict'``, ``'optimizer'`` and ``'epoch'`` keys).
        is_best_loss: Also copy the checkpoint to ``best_model.pt``.
        is_best_acc: Also copy the checkpoint to ``best_acc.pt``.
        checkpoint_dir: Directory for the rolling checkpoint. Paths are
            joined with ``os.path.join``, so a trailing separator is optional.
        best_model_dir: Directory for the best-model copies (same rule).
    """
    f_path = os.path.join(checkpoint_dir, 'checkpoint.pt')
    torch.save(state, f_path)
    # Copy the file just written rather than re-serialising ``state``.
    if is_best_loss:
        shutil.copyfile(f_path, os.path.join(best_model_dir, 'best_model.pt'))
    if is_best_acc:
        shutil.copyfile(f_path, os.path.join(best_model_dir, 'best_acc.pt'))
        
def load_ckp(checkpoint_fpath, model, optimizer, map_location=None):
    """Restore model and optimizer state from a saved checkpoint.

    Args:
        checkpoint_fpath: Path of a file written with ``torch.save`` holding a
            dict with ``'state_dict'``, ``'optimizer'`` and ``'epoch'`` keys.
        model: Module whose ``load_state_dict`` receives ``'state_dict'``.
        optimizer: Optimizer whose ``load_state_dict`` receives ``'optimizer'``.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to load
            a checkpoint saved on a GPU onto a machine without one. ``None``
            (default) keeps ``torch.load``'s default behaviour.

    Returns:
        ``(model, optimizer, epoch)``; model and optimizer are the same
        objects passed in, updated in place.

    Raises:
        KeyError: if one of the three expected keys is missing.
    """
    checkpoint = torch.load(checkpoint_fpath, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']

# Human-readable names of the 11 gesture classes. The list position presumably
# matches the numeric class label (0-10) read by IBMGestureDataset — confirm
# against the dataset's label files.
actionName = [
    'hand_clapping', 'right_hand_wave', 'left_hand_wave',
    'right_arm_clockwise', 'right_arm_counter_clockwise',
    'left_arm_clockwise', 'left_arm_counter_clockwise',
    'arm_roll', 'air_drums', 'air_guitar',
    'other_gestures',
]


# Define dataset module
class IBMGestureDataset(Dataset):
    """Gesture dataset stored as one ``<index>.npy`` spike file per sample.

    ``sampleFile`` is a whitespace-separated text file with one row per
    sample: column 0 is the numeric file index, column 1 the class label.
    """

    def __init__(self, datasetPath, sampleFile, samplingTime, sampleLength):
        """
        Args:
            datasetPath: Directory containing the ``<index>.npy`` files
                (joined with ``os.path.join``; trailing separator optional).
            sampleFile: Text file listing ``<index> <label>`` rows.
            samplingTime: Time-bin width passed to ``toSpikeTensor``.
            sampleLength: Total sample duration; the spike tensor gets
                ``int(sampleLength / samplingTime)`` time bins.
        """
        self.path = datasetPath
        # ndmin=2 keeps a one-row sample file 2-D, so samples[index, 0] and
        # samples.shape[0] still index rows rather than columns.
        self.samples = np.loadtxt(sampleFile, ndmin=2).astype('int')
        self.samplingTime = samplingTime
        self.nTimeBins = int(sampleLength / samplingTime)

    def __getitem__(self, index):
        """Return ``(inputSpikes, desiredClass, classLabel)`` for row ``index``.

        ``inputSpikes`` is produced by ``toSpikeTensor`` from a zero tensor of
        shape (2, 128, 128, nTimeBins); ``desiredClass`` is a one-hot tensor
        of shape (11, 1, 1, 1); ``classLabel`` is the integer label.
        """
        # Read input file index and label
        inputIndex = self.samples[index, 0]
        classLabel = self.samples[index, 1]
        # Read input spikes and bin them into a dense tensor
        spikeFile = os.path.join(self.path, str(inputIndex.item()) + '.npy')
        inputSpikes = snn.io.readNpSpikes(spikeFile).toSpikeTensor(
            torch.zeros((2, 128, 128, self.nTimeBins)),
            samplingTime=self.samplingTime)
        # Create one-hot encoded desired matrix
        desiredClass = torch.zeros((11, 1, 1, 1))
        desiredClass[classLabel, ...] = 1

        return inputSpikes, desiredClass, classLabel

    def __len__(self):
        """Number of rows in the sample file."""
        return self.samples.shape[0]
		
# Define the network
class Network(torch.nn.Module):
    """Spiking convolutional classifier built from slayerSNN Loihi layers.

    Stage sequence (shape comments in ``forward`` are height, width, channels):
    pool(4) -> conv 2->16 (5x5) -> pool(2) -> conv 16->32 (3x3) -> pool(2)
    -> dense 8*8*32->512 -> dense 512->11. Every stage's output goes through
    ``self.slayer.spikeLoihi`` and then ``self.slayer.delayShift(..., 1)``.
    One ``slayer.dropout(0.1)`` module is reused before conv1, conv2 and fc1.

    NOTE: attribute names become ``state_dict`` keys and their creation order
    fixes the order of ``parameters()``; renaming or reordering them breaks
    loading previously saved model/optimizer checkpoints.
    """
    def __init__(self, netParams):
        """Create the slayer block and all layers.

        Args:
            netParams: slayerSNN parameter object; must provide the
                ``'neuron'`` and ``'simulation'`` entries read below.
        """
        super(Network, self).__init__()
        # initialize slayer
        slayer = snn.loihi(netParams['neuron'], netParams['simulation'])
        self.slayer = slayer
        # define network functions
        self.conv1 = slayer.conv(2, 16, 5, padding=2, weightScale=10)
        self.conv2 = slayer.conv(16, 32, 3, padding=1, weightScale=50)
        self.pool1 = slayer.pool(4)
        self.pool2 = slayer.pool(2)
        self.pool3 = slayer.pool(2)
        # 8*8*32 assumes a 128x128 input downsampled 4*2*2 = 16x by the pools.
        self.fc1   = slayer.dense((8*8*32), 512)
        self.fc2   = slayer.dense(512, 11)
        self.drop  = slayer.dropout(0.1)

    def forward(self, spikeInput):
        """Run all stages over the input spike train.

        Args:
            spikeInput: Input spike tensor. Assumed layout
                (batch, 2, 128, 128, time) — TODO confirm; fc1's input size
                and the shape comments below depend on it.

        Returns:
            Spike tensor produced by the fc2 stage (11 output neurons).
        """
        # delayShift(spike, 1) after each stage presumably adds a one-step
        # delay between layers (as on Loihi) — confirm in the slayerSNN docs.
        spike = self.slayer.spikeLoihi(self.pool1(spikeInput )) # 32, 32, 2
        spike = self.slayer.delayShift(spike, 1)

        spike = self.drop(spike)
        spike = self.slayer.spikeLoihi(self.conv1(spike)) # 32, 32, 16
        spike = self.slayer.delayShift(spike, 1)

        spike = self.slayer.spikeLoihi(self.pool2(spike)) # 16, 16, 16
        spike = self.slayer.delayShift(spike, 1)

        spike = self.drop(spike)
        spike = self.slayer.spikeLoihi(self.conv2(spike)) # 16, 16, 32
        spike = self.slayer.delayShift(spike, 1)

        spike = self.slayer.spikeLoihi(self.pool3(spike)) #  8,  8, 32
        # Flatten everything between batch and the last (time) axis into one
        # channel axis so the dense layer sees 8*8*32 inputs per time step.
        spike = spike.reshape((spike.shape[0], -1, 1, 1, spike.shape[-1]))
        spike = self.slayer.delayShift(spike, 1)

        spike = self.drop(spike)
        spike = self.slayer.spikeLoihi(self.fc1  (spike)) # 512
        spike = self.slayer.delayShift(spike, 1)

        spike = self.slayer.spikeLoihi(self.fc2  (spike)) # 11
        spike = self.slayer.delayShift(spike, 1)

        return spike

if __name__ == '__main__':
    # Network/neuron configuration from the slayerPytorch IBM gesture example.
    netParams = snn.params('slayerPytorch/exampleLoihi/03_IBMGesture/network.yaml')

    # Everything runs on the default CUDA device. For multiple GPUs, the net
    # could instead be wrapped in torch.nn.DataParallel(..., device_ids=[2, 3]).
    device = torch.device('cuda')

    # Model, spike-count loss and optimizer (torch.optim.Adam with the same
    # arguments is an alternative). The optimizer's state, including its
    # learning rate, is replaced by the checkpoint below.
    net = Network(netParams).to(device)
    error = snn.loss(netParams, snn.loihi).to(device)
    optimizer = snn.utils.optim.Nadam(net.parameters(), lr=0.01, amsgrad=True)

    # Test split: 1450 time bins at a sampling time of 1.0, one sample per batch.
    testingSet = IBMGestureDataset(
        datasetPath="drive/My Drive/Test/",
        sampleFile="drive/My Drive/Test/test.txt",
        samplingTime=1.0,
        sampleLength=1450,
    )
    testLoader = DataLoader(
        dataset=testingSet,
        batch_size=1,
        shuffle=False,
        num_workers=1,
    )

    # Restore trained weights and optimizer state.
    ckp_path = "drive/My Drive/best/check.pt"
    net, optimizer, start_epoch = load_ckp(ckp_path, net, optimizer)

    # Accuracy/loss accumulator used by the evaluation cells.
    stats = snn.utils.stats()

# Uniform noise injection

In [None]:
def update(x, y, t, A, s=1):
    """Write ``t`` into every neighbour of cell (x, y) of the 2-D grid ``A``.

    All cells within Chebyshev distance ``s`` of (x, y), except (x, y) itself,
    are set to ``t``; positions outside the grid are skipped (including
    the first/last row and column as valid positions).

    Args:
        x: Row index of the centre cell.
        y: Column index of the centre cell.
        t: Value to store (the event's time bin in the noise filters).
        A: 2-D array or nested list, modified in place.
        s: Neighbourhood radius; defaults to 1 (the 3x3 neighbourhood) so
            callers may omit it.
    """
    n_rows = len(A)
    for i in range(max(x - s, 0), min(x + s + 1, n_rows)):
        n_cols = len(A[i])
        for j in range(max(y - s, 0), min(y + s + 1, n_cols)):
            if i != x or j != y:
                A[i][j] = t

# Robustness test: add one fixed, sparse uniform-noise pattern to every test
# sample at several amplitudes, remove isolated events with a
# background-activity filter, and print test accuracy per amplitude.
# Uses net, error, stats, testLoader and device from the configuration cell.
NUM_TIME_BINS = 1450   # must match testingSet's sampleLength / samplingTime
NOISE_DENSITY = 0.01   # fraction of (polarity, y, x, t) cells that receive noise
FILTER_WINDOW = 5      # drop an event if its neighbourhood was last active >= this many bins earlier
FILTER_RADIUS = 1      # neighbourhood radius passed to update()

np.random.seed(7)
mask = np.where(np.random.uniform(size=(1, 2, 128, 128, NUM_TIME_BINS)) < NOISE_DENSITY, 1, 0)
np.random.seed(1)
uniform = np.random.uniform(size=(1, 2, 128, 128, NUM_TIME_BINS))
# noise carries a leading batch dimension of 1, so this assumes testLoader
# uses batch_size=1.
noise = torch.Tensor(mask * abs(uniform))
amp = [1, 0.85, 0.70, 0.55, 0.40, 0.25, 0.10]

net.eval()
for a in amp:
    stats.testing.reset()
    tSt = datetime.now()
    print(tSt, a)
    # Testing loop.
    for spikes, target, label in testLoader:
        spikes = torch.clamp(spikes + a * noise, 0, 1)
        # Move time to the front with permute: (1, 2, H, W, T) -> (T, H, W, 2).
        # torch.reshape would only reinterpret memory and scramble the axes.
        events = spikes[0].permute(3, 1, 2, 0).contiguous()
        last_active = np.zeros((128, 128))
        # nonzero() yields indices in row-major order, so events are visited
        # in increasing time bin, which the filter relies on.
        for t, x, y, p in events.nonzero().numpy():
            update(x, y, t, last_active, FILTER_RADIUS)
            if t - last_active[x][y] >= FILTER_WINDOW:
                events[t, x, y, p] = 0.
        # Undo the permutation (it is its own inverse) and restore the batch axis.
        spikes = events.permute(3, 1, 2, 0).unsqueeze(0).contiguous()
        # Inference only: keep autograd bookkeeping off for the forward pass.
        with torch.no_grad():
            spikes = spikes.to(device)
            target = target.to(device)
            output = net.forward(spikes)
            loss = error.numSpikes(output, target)
        # Stats updating
        stats.testing.correctSamples += torch.sum(snn.predict.getClass(output) == label).data.item()
        stats.testing.numSamples += len(label)
        stats.testing.lossSum += loss.cpu().data.item()
    tEnd = datetime.now()
    print(tEnd)
    print(a, stats.testing.accuracy())

# Normal (Gaussian) noise injection

In [None]:
def update(x, y, t, A, s=1):
    """Write ``t`` into every neighbour of cell (x, y) of the 2-D grid ``A``.

    All cells within Chebyshev distance ``s`` of (x, y), except (x, y) itself,
    are set to ``t``; positions outside the grid are skipped (including
    the first/last row and column as valid positions).

    Args:
        x: Row index of the centre cell.
        y: Column index of the centre cell.
        t: Value to store (the event's time bin in the noise filters).
        A: 2-D array or nested list, modified in place.
        s: Neighbourhood radius; defaults to 1 (the 3x3 neighbourhood) so
            callers may omit it.
    """
    n_rows = len(A)
    for i in range(max(x - s, 0), min(x + s + 1, n_rows)):
        n_cols = len(A[i])
        for j in range(max(y - s, 0), min(y + s + 1, n_cols)):
            if i != x or j != y:
                A[i][j] = t

# Robustness test: add one fixed, sparse half-normal noise pattern (|N(0, 1)|,
# clamped to [0, 1] after scaling) to every test sample at several
# amplitudes, remove isolated events with a background-activity filter, and
# print test accuracy per amplitude.
# Uses net, error, stats, testLoader and device from the configuration cell.
NUM_TIME_BINS = 1450   # must match testingSet's sampleLength / samplingTime
NOISE_DENSITY = 0.01   # fraction of (polarity, y, x, t) cells that receive noise
FILTER_WINDOW = 5      # drop an event if its neighbourhood was last active >= this many bins earlier
FILTER_RADIUS = 1      # neighbourhood radius passed to update()

np.random.seed(7)
mask = np.where(np.random.uniform(size=(1, 2, 128, 128, NUM_TIME_BINS)) < NOISE_DENSITY, 1, 0)
np.random.seed(1)
normal = np.random.normal(size=(1, 2, 128, 128, NUM_TIME_BINS))
# noise carries a leading batch dimension of 1, so this assumes testLoader
# uses batch_size=1.
noise = torch.Tensor(mask * abs(normal))
amp = [1, 0.85, 0.70, 0.55, 0.40, 0.25, 0.10]

net.eval()
for a in amp:
    stats.testing.reset()
    tSt = datetime.now()
    print(tSt, a)
    # Testing loop.
    for spikes, target, label in testLoader:
        spikes = torch.clamp(spikes + a * noise, 0, 1)
        # Move time to the front with permute: (1, 2, H, W, T) -> (T, H, W, 2).
        # torch.reshape would only reinterpret memory and scramble the axes.
        events = spikes[0].permute(3, 1, 2, 0).contiguous()
        last_active = np.zeros((128, 128))
        # nonzero() yields indices in row-major order, so events are visited
        # in increasing time bin, which the filter relies on.
        for t, x, y, p in events.nonzero().numpy():
            update(x, y, t, last_active, FILTER_RADIUS)
            if t - last_active[x][y] >= FILTER_WINDOW:
                events[t, x, y, p] = 0.
        # Undo the permutation (it is its own inverse) and restore the batch axis.
        spikes = events.permute(3, 1, 2, 0).unsqueeze(0).contiguous()
        # Inference only: keep autograd bookkeeping off for the forward pass.
        with torch.no_grad():
            spikes = spikes.to(device)
            target = target.to(device)
            output = net.forward(spikes)
            loss = error.numSpikes(output, target)
        # Stats updating
        stats.testing.correctSamples += torch.sum(snn.predict.getClass(output) == label).data.item()
        stats.testing.numSamples += len(label)
        stats.testing.lossSum += loss.cpu().data.item()
    tEnd = datetime.now()
    print(tEnd)
    print(a, stats.testing.accuracy())