# Install SlayerPytorch on Colab
After the installation steps, the runtime needs to be restarted:
```
exit()
```
This restarts the runtime without deleting any files, and the runtime starts again automatically. If you press "Run all", execution is not interrupted and continues to the end.

In [None]:
# Clone the SLAYER PyTorch repository and install the ninja build tool.
!git clone https://github.com/bamsumit/slayerPytorch
!pip install ninja
# Restart the runtime (files on disk are kept) so the new installation is picked up.
exit()

In [None]:
# Install the cloned repository through its setup.py, then restart the runtime again.
%cd slayerPytorch/
!python setup.py install
exit()

Run the test suite to verify that the installation succeeded.

In [None]:
# Run the repository's unit tests to check the installation.
# NOTE: this leaves the working directory in slayerPytorch/test/.
%cd slayerPytorch/test/
!python -m  unittest

# Get the Dataset
The archive is read from Google Drive, which must be mounted at `/content/drive`. Be sure to change the path to match your setup.

In [None]:
!unzip '/content/drive/My Drive/IBM_Gestures.zip' -d /content/slayerPytorch/exampleLoihi/03_IBMGesture/
%cd ./exampleLoihi/03_IBMGesture/

# Configure the SNN

In [None]:
import sys, os
# Make the repository's src/ directory importable (presumably so that
# `slayerSNN` resolves from the source tree). The relative path assumes the
# current working directory is slayerPytorch/exampleLoihi/03_IBMGesture.
CURRENT_TEST_DIR = os.getcwd()
sys.path.append(CURRENT_TEST_DIR + "/../../src")

from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import HTML
import torch
from torch.utils.data import Dataset, DataLoader
import slayerSNN as snn
#from learningStats import learningStats
from IPython.display import HTML  # NOTE(review): duplicates the HTML import above; harmless
import time
import shutil

def save_ckp(state, is_best_loss, is_best_acc, checkpoint_dir, best_model_dir):
    """Save ``state`` as the latest checkpoint and copy it to the "best" slots.

    The latest state always goes to ``<checkpoint_dir>/checkpoint.pt``. When the
    corresponding flag is true, that file is also copied to
    ``<best_model_dir>/best_model.pt`` (best loss) and/or
    ``<best_model_dir>/best_acc.pt`` (best accuracy).

    Args:
        state: any object accepted by ``torch.save``.
        is_best_loss: copy the checkpoint to ``best_model.pt`` if true.
        is_best_acc: copy the checkpoint to ``best_acc.pt`` if true.
        checkpoint_dir: directory for ``checkpoint.pt``; created if missing.
            A trailing separator is optional.
        best_model_dir: directory for the best-model copies; created if
            missing when a copy is made.
    """
    # os.makedirs('') raises, so an empty dir (meaning "cwd") is left as is.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    f_path = os.path.join(checkpoint_dir, 'checkpoint.pt')
    torch.save(state, f_path)

    if (is_best_loss or is_best_acc) and best_model_dir:
        os.makedirs(best_model_dir, exist_ok=True)
    if is_best_loss:
        shutil.copyfile(f_path, os.path.join(best_model_dir, 'best_model.pt'))
    if is_best_acc:
        shutil.copyfile(f_path, os.path.join(best_model_dir, 'best_acc.pt'))
        
def load_ckp(checkpoint_fpath, model, optimizer, map_location=None):
    """Restore model and optimizer state from a checkpoint file.

    The file must contain a dict with keys ``'state_dict'`` (model state),
    ``'optimizer'`` (optimizer state) and ``'epoch'``.

    Args:
        checkpoint_fpath: path of the checkpoint written with ``torch.save``.
        model: module whose state is restored in place.
        optimizer: optimizer whose state is restored in place.
        map_location: forwarded to ``torch.load``. When None, tensors are
            loaded onto the device of the model's first parameter (CPU if the
            model has none). This lets a checkpoint saved on a GPU runtime be
            opened on a CPU-only one.

    Returns:
        ``(model, optimizer, epoch)``, where ``epoch`` is the stored value.
    """
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else 'cpu'
    checkpoint = torch.load(checkpoint_fpath, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']

# Human-readable names of the 11 gesture classes; the list index is the
# integer class label used by the dataset (e.g. actionName[label]).
actionName = [
    'hand_clapping', 'right_hand_wave', 'left_hand_wave',
    'right_arm_clockwise', 'right_arm_counter_clockwise',
    'left_arm_clockwise', 'left_arm_counter_clockwise',
    'arm_roll', 'air_drums', 'air_guitar',
    'other_gestures',
]

# Define dataset module
class IBMGestureDataset(Dataset):
    """Gesture samples stored as one ``<index>.npy`` spike file per sample.

    Args:
        datasetPath: directory holding the ``<index>.npy`` files (a trailing
            separator is optional).
        sampleFile: whitespace-separated text file with one sample per row:
            column 0 is the sample index (file stem), column 1 the class label.
        samplingTime: length of one time bin, in the same unit as
            ``sampleLength``.
        sampleLength: total duration of a sample; the number of time bins is
            ``int(sampleLength / samplingTime)``.

    Each item is ``(inputSpikes, desiredClass, classLabel)``. ``inputSpikes``
    is a ``(2, 128, 128, nTimeBins)`` spike tensor. ``desiredClass`` is an
    ``(11, 1, 1, 1)`` one-hot tensor. ``classLabel`` is the integer label.
    """

    def __init__(self, datasetPath, sampleFile, samplingTime, sampleLength):
        self.path = datasetPath
        # ndmin=2 keeps a (rows, 2) shape even when the file has a single row;
        # otherwise loadtxt returns a 1-D array and indexing/len() break.
        self.samples = np.loadtxt(sampleFile, ndmin=2).astype('int')
        self.samplingTime = samplingTime
        self.nTimeBins = int(sampleLength / samplingTime)

    def __getitem__(self, index):
        # Read input index and label.
        inputIndex = self.samples[index, 0]
        classLabel = self.samples[index, 1]
        # Read the input events and bin them into a dense spike tensor.
        spikeFile = os.path.join(self.path, str(inputIndex.item()) + '.npy')
        inputSpikes = snn.io.readNpSpikes(spikeFile).toSpikeTensor(
            torch.zeros((2, 128, 128, self.nTimeBins)),
            samplingTime=self.samplingTime,
        )
        # Create the one-hot encoded desired-class tensor.
        desiredClass = torch.zeros((11, 1, 1, 1))
        desiredClass[classLabel, ...] = 1

        return inputSpikes, desiredClass, classLabel

    def __len__(self):
        return self.samples.shape[0]
		
# Define the network
class Network(torch.nn.Module):
    """Spiking CNN for 11 gesture classes, built from SLAYER Loihi layers.

    Feature sizes (height, width, channels), assuming a 2 x 128 x 128 input:

        pool 4        -> 32 x 32 x  2
        conv 5 (16)   -> 32 x 32 x 16
        pool 2        -> 16 x 16 x 16
        conv 3 (32)   -> 16 x 16 x 32
        pool 2        ->  8 x  8 x 32   (flattened to 8*8*32 = 2048)
        dense         -> 512
        dense         -> 11

    Every stage is followed by ``spikeLoihi`` and ``delayShift(..., 1)``.
    Dropout is applied to the inputs of conv1, conv2 and fc1.

    Args:
        netParams: SLAYER parameter object whose 'neuron' and 'simulation'
            sections configure the Loihi layer factory.
    """

    def __init__(self, netParams):
        super().__init__()
        # Keep this registration order: optimizer state dicts refer to
        # parameters by position, so reordering would break old checkpoints.
        self.slayer = snn.loihi(netParams['neuron'], netParams['simulation'])
        layers = self.slayer

        self.conv1 = layers.conv(2, 16, 5, padding=2, weightScale=10)
        self.conv2 = layers.conv(16, 32, 3, padding=1, weightScale=50)
        self.pool1 = layers.pool(4)
        self.pool2 = layers.pool(2)
        self.pool3 = layers.pool(2)
        self.fc1 = layers.dense((8 * 8 * 32), 512)
        self.fc2 = layers.dense(512, 11)
        self.drop = layers.dropout(0.1)

    def forward(self, spikeInput):
        """Propagate an input spike tensor and return the output-layer spikes."""
        layers = self.slayer

        def fire(x):
            # Loihi spiking nonlinearity followed by a delay shift of 1.
            return layers.delayShift(layers.spikeLoihi(x), 1)

        x = fire(self.pool1(spikeInput))        # 32, 32,  2
        x = fire(self.conv1(self.drop(x)))      # 32, 32, 16
        x = fire(self.pool2(x))                 # 16, 16, 16
        x = fire(self.conv2(self.drop(x)))      # 16, 16, 32

        # Flatten between spiking and delaying so the dense layer receives an
        # (N, -1, 1, 1, T) tensor (presumably batch and time on the outer axes).
        x = layers.spikeLoihi(self.pool3(x))    #  8,  8, 32
        x = x.reshape((x.shape[0], -1, 1, 1, x.shape[-1]))
        x = layers.delayShift(x, 1)

        x = fire(self.fc1(self.drop(x)))        # 512
        return fire(self.fc2(x))                # 11



In [None]:
if __name__ == '__main__':
    # Network, neuron and simulation hyper-parameters.
    netParams = snn.params('network.yaml')

    # Define the cuda device to run the code on.
    device = torch.device('cuda')
    # deviceIds = [2, 3]

    # Create network instance.
    net = Network(netParams).to(device)
    # net = torch.nn.DataParallel(Network(netParams).to(device), device_ids=deviceIds)

    # Create snn loss instance.
    error = snn.loss(netParams, snn.loihi).to(device)

    # Define optimizer module.
    # optimizer = torch.optim.Adam(net.parameters(), lr = 0.01, amsgrad = True)
    optimizer = snn.utils.optim.Nadam(net.parameters(), lr=0.01, amsgrad=True)

    # Dataset and dataLoader instances.
    trainingSet = IBMGestureDataset(datasetPath="Train/",
                                    sampleFile="Train/train.txt",
                                    samplingTime=1.0,
                                    sampleLength=1450)
    trainLoader = DataLoader(dataset=trainingSet, batch_size=4, shuffle=True, num_workers=1)

    testingSet = IBMGestureDataset(datasetPath="Test/",
                                   sampleFile="Test/test.txt",
                                   samplingTime=1.0,
                                   sampleLength=1450)
    # Accumulated test accuracy/loss do not depend on sample order, so the
    # test set is not shuffled.
    testLoader = DataLoader(dataset=testingSet, batch_size=4, shuffle=False, num_workers=1)

    # Index of the training sample shown by the visualization cell.
    i = 0

    # Learning stats instance.
    stats = snn.utils.stats()

    # First epoch of the training loop; replaced when a checkpoint is loaded.
    start_epoch = 0

# Visualize the input spikes of one training sample (index `i`).

In [None]:


# Show one training sample (index `i`) as an animation of its spike events.
# Named sampleSpikes rather than `input` to avoid shadowing the builtin.
sampleSpikes, _, sampleLabel = trainingSet[i]
print(actionName[sampleLabel])
# Reshape to the (2, 128, 128, time) layout before converting to events.
events = snn.io.spikeArrayToEvent(sampleSpikes.reshape((2, 128, 128, -1)).cpu().numpy())
# Static alternative: snn.io.showTD(events)
anim = snn.io.animTD(events)
HTML(anim.to_jshtml())


# Checkpoint loading
Run this cell if you wish to resume training from a previously saved checkpoint.

In [None]:

# Resume from the latest checkpoint if one exists. Without the check, a fresh
# "Run all" would crash here before training ever starts.
ckp_path = "../../../drive/My Drive/checkpoint/checkpoint.pt"
if os.path.isfile(ckp_path):
    net, optimizer, start_epoch = load_ckp(ckp_path, net, optimizer)
    print('Resuming from epoch', start_epoch)
else:
    print('No checkpoint found at', ckp_path, '- starting from scratch.')


# Training loop
Train the SNN and save a checkpoint after every epoch.

In [None]:
# Checkpoint directory on Google Drive (relative to exampleLoihi/03_IBMGesture).
ckpDir = '../../../drive/My Drive/checkpoint/'

for epoch in range(start_epoch, start_epoch + 100):
    tSt = datetime.now()

    # Training loop.
    net.train()
    for i, (spikeIn, target, label) in enumerate(trainLoader, 0):
        # Move the input and target to the compute device.
        spikeIn = spikeIn.to(device)
        target = target.to(device)

        # Forward pass (call the module, not .forward(), so hooks run).
        output = net(spikeIn)

        # Gather the training stats.
        stats.training.correctSamples += torch.sum(snn.predict.getClass(output) == label).item()
        stats.training.numSamples += len(label)
        # Calculate loss.
        loss = error.numSpikes(output, target)
        # Reset gradients, backpropagate and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Gather training loss stats.
        stats.training.lossSum += loss.item()
        # Display training stats.
        stats.print(epoch, i, (datetime.now() - tSt).total_seconds())

    # Testing loop: same steps as training, without backpropagation or weight
    # updates. The forward pass is inside no_grad so no autograd graph is
    # built during evaluation.
    net.eval()
    with torch.no_grad():
        for i, (spikeIn, target, label) in enumerate(testLoader, 0):
            spikeIn = spikeIn.to(device)
            target = target.to(device)
            output = net(spikeIn)
            stats.testing.correctSamples += torch.sum(snn.predict.getClass(output) == label).item()
            stats.testing.numSamples += len(label)
            loss = error.numSpikes(output, target)
            stats.testing.lossSum += loss.item()
            stats.print(epoch, i)
    stats.update()

    # Save the latest state every epoch. Both "best" copies are selected on
    # the test set (loss and accuracy) so they reflect generalization.
    checkpoint = {
        'epoch': epoch + 1,
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    save_ckp(checkpoint, stats.testing.bestLoss, stats.testing.bestAccuracy, ckpDir, ckpDir)