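# Implementing the Nordland Dataset

Train a spiking CNN on 100 'summer' images from the Nordland dataset (`Labels='VPR'`, 100 output classes), then evaluate it on the summer, fall, and winter traversals.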

In [15]:
import sys
sys.path.append('../')
import os
import torch
from network_parser import parse
from datasets import loadMNIST, loadCIFAR10, loadFashionMNIST, loadNMNIST_Spiking 
import logging
import cnns
import functions.loss_f as loss_f
import numpy as np
from datetime import datetime
import time
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils import clip_grad_value_
import global_v as glv

import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt
import argparse

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as P
import yaml

# Best-so-far trackers, initialised to sentinel values (not referenced in the
# cells shown in this notebook).
max_accuracy, min_loss = 0, 1000

from tqdm import tqdm, trange

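## Nordland data loader (works locally only)

The next cell imports `pyRC` from a machine-specific relative path (`../../../ActiveAI/pyRC/`), so it only runs where that checkout exists.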

In [2]:
sys.path.append('../../../ActiveAI/pyRC/') # local only! 
import pyRC.datasets.nordland as Nordland

In [3]:
# Load the network/training configuration from YAML. The file is re-read on
# every run, so the path prefix below is applied exactly once per execution.
CONFIG_PATH = '../Networks/Nordland_CNN.yaml'
with open(CONFIG_PATH) as config_file:
    params = yaml.full_load(config_file)

# Prefix '../' so data_path resolves from this notebook's directory
# (presumably the YAML path is relative to the parent directory — confirm).
params['Network']['data_path'] = '../' + params['Network']['data_path']

In [4]:
# Build one training loader and three evaluation loaders (summer / fall /
# winter), all with 100 images resized to 64x32 and batch size from the config.
# Only the training loader is shuffled: SGD benefits from a fresh batch order
# each pass, while evaluation accuracy does not depend on order and a fixed
# order keeps runs reproducible. (Previously the flags were the other way round.)
# NOTE: test_loaderS uses the same mode and nImages as train_loader, so it
# presumably serves the training images themselves.
nBatch = params['Network']['batch_size']
loader_kwargs = dict(Labels='VPR', nImages=100, nBatch=nBatch, width=64, height=32)

train_loader = Nordland.get(mode='summer', shuffle=True,  **loader_kwargs)
test_loaderS = Nordland.get(mode='summer', shuffle=False, **loader_kwargs)
test_loaderF = Nordland.get(mode='fall',   shuffle=False, **loader_kwargs)
test_loaderW = Nordland.get(mode='winter', shuffle=False, **loader_kwargs)

In [5]:
# Shape of one training image with a leading channel axis prepended — the same
# expression is used as the input shape when building cnns.Network below.
first_image = train_loader.dataset[0][0]
first_image[None].shape

torch.Size([1, 32, 64])

In [6]:
dtype = torch.float32

# Prefer the GPU but fall back to the CPU. Previously `device` was only
# assigned when CUDA was available, so glv.init raised NameError elsewhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("selected device: ", device)

# Register dtype, device, number of time steps and tau_s with the project's
# global-settings module before any network is built.
glv.init(dtype, device, params['Network']['n_steps'], params['Network']['tau_s'])

selected device:  cuda


In [7]:
# n_steps: length of the trailing time axis each input is repeated over;
# n_class: number of output classes (size of the targets' second axis).
network_section = params['Network']
n_steps, n_class = network_section['n_steps'], network_section['n_class']


In [8]:
# Build the network for inputs shaped like one dataset sample with a channel
# axis prepended, then move it to the selected device.
# NOTE(review): the printed FC_1 weight is 300x640 (= 40*4*4, what a 28x28
# input yields), not 40*5*13; this is presumably why later cells crop inputs
# to 28x28 — confirm against the YAML layer config.
input_shape = list(train_loader.dataset[0][0].unsqueeze(0).shape)
network = cnns.Network(params['Network'], params['Layers'], input_shape)
net = network.to(device)
net

Network Structure:
conv_1
[1, 32, 64]
[15, 28, 60]
[15, 1, 5, 5, 1]
-----------------------------------------
pooling_1
[15, 28, 60]
[15, 14, 30]
[1, 1, 2, 2, 1]
-----------------------------------------
conv_2
[15, 14, 30]
[40, 10, 26]
[40, 15, 5, 5, 1]
-----------------------------------------
pooling_2
[40, 10, 26]
[40, 5, 13]
[1, 1, 2, 2, 1]
-----------------------------------------
linear
FC_1
[40, 5, 13]
[300, 1, 1]
[300, 640]
-----------------------------------------
linear
output
[300, 1, 1]
[100, 1, 1]
[100, 300]
-----------------------------------------
-----------------------------------------




Network(
  (my_parameters): ParameterList(
      (0): Parameter containing: [torch.cuda.FloatTensor of size 15x1x5x5x1 (GPU 0)]
      (1): Parameter containing: [torch.cuda.FloatTensor of size 40x15x5x5x1 (GPU 0)]
      (2): Parameter containing: [torch.cuda.FloatTensor of size 300x640 (GPU 0)]
      (3): Parameter containing: [torch.cuda.FloatTensor of size 100x300 (GPU 0)]
  )
)

In [9]:
# Sanity check on one training batch: add a channel axis, repeat each image
# along a new trailing time axis of length n_steps, crop to 28x28 and move to
# the device, printing the shapes before and after.
batch_x, batch_label = next(iter(train_loader))
x = batch_x.unsqueeze(1)
label = batch_label.unsqueeze(1)
print(x.shape, label.shape)

targets = torch.zeros((label.shape[0], n_class, 1, 1, n_steps), dtype=dtype).to(device)
if x.dim() < 5:
    x = x.unsqueeze(-1).repeat(1, 1, 1, 1, n_steps)
X = x[:, :, :28, :28].to(device).type(dtype)
print(X.shape, targets.shape)

torch.Size([5, 1, 32, 64]) torch.Size([5, 1])
torch.Size([5, 1, 28, 28, 5]) torch.Size([5, 100, 1, 1, 5])


In [10]:
# Push the sanity-check batch through the net once (epoch 0, third argument
# True as in the training loop), then pick, per sample, the output with the
# largest spike count summed over the time axis.
outputs = net(X, 0, True)
spike_counts = (
    outputs.sum(dim=4)
    .squeeze(-1)
    .squeeze(-1)
    .detach()
    .cpu()
    .numpy()
)
predicted = spike_counts.argmax(axis=1)

In [11]:
# Target spike train for the "kernel" loss: a fixed on/off pattern covering
# exactly n_steps time steps, then filtered through loss_f.psp.
if n_steps >= 10:
    spike_pattern = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
else:
    spike_pattern = [0, 1, 1, 1, 1]
# Tile, then truncate, so any n_steps works. The old `.repeat(int(n_steps/k))`
# produced a train of the wrong length (and a failing .view) whenever n_steps
# was not a multiple of the pattern length. Results are unchanged for multiples.
n_repeats = -(-n_steps // len(spike_pattern))  # ceiling division
desired_spikes = torch.tensor(spike_pattern).repeat(n_repeats)[:n_steps]
desired_spikes = desired_spikes.view(1, 1, 1, 1, n_steps).to(device)
print('>> before', desired_spikes)
desired_spikes = loss_f.psp(desired_spikes, params['Network']).view(1, 1, 1, n_steps)
print('>> after', desired_spikes)

>> before tensor([[[[[0, 1, 1, 1, 1]]]]], device='cuda:0')
>> after tensor([[[[0.0000, 0.3333, 0.5556, 0.7037, 0.8025]]]], device='cuda:0')


In [12]:
# Pull the run settings out of the YAML 'Network' section.
network_config = params['Network']
n_steps, n_class, batch_size = (
    network_config[key] for key in ('n_steps', 'n_class', 'batch_size')
)

# Spike-based loss module and AdamW optimiser over the network's parameters
# (betas spelled out explicitly).
err = loss_f.SpikeLoss(network_config).to(device)
opti = torch.optim.AdamW(
    net.get_parameters(),
    lr=network_config['lr'],
    betas=(0.9, 0.999),
)

In [23]:
# Running counters for the training loop. They are initialised once here and
# never reset inside the loop, so accuracy/loss are cumulative over all passes.
train_loss = correct = total = total_accuracy = 0

In [24]:
# Train for 10 outer epochs, each made of 100 full passes over train_loader.
# `epoch` (0-9) is what net.forward receives. The counters from the previous
# cell are not reset, so `total_accuracy` is cumulative over every pass.

# The "count" loss needs the last layer's config; this name was previously
# undefined here and raised NameError.
layers_config = params['Layers']


def _spike_loss(outputs, labels, targets):
    """Build the target signal selected by network_config['loss'] and return the loss.

    outputs: network output for the batch.
    labels:  class indices, already on `device`.
    targets: zero-initialised buffer of shape (B, n_class, 1, 1, n_steps),
             filled in place for the "kernel" loss.
    Raises Exception if the configured loss name is not recognised.
    """
    loss_name = network_config['loss']
    if loss_name == "count":
        desired_count = network_config['desired_count']
        undesired_count = network_config['undesired_count']
        count_targets = torch.ones((outputs.shape[0], outputs.shape[1], 1, 1), dtype=dtype).to(device) * undesired_count
        for i in range(len(labels)):
            count_targets[i, labels[i], ...] = desired_count
        return err.spike_count(outputs, count_targets, network_config, layers_config[list(layers_config.keys())[-1]])
    if loss_name == "kernel":
        targets.zero_()
        for i in range(len(labels)):
            targets[i, labels[i], ...] = desired_spikes
        return err.spike_kernel(outputs, targets, network_config)
    if loss_name == "softmax":
        return err.spike_soft_max(outputs, labels)
    raise Exception('Unrecognized loss function.')


for epoch in range(10):
    # Time the whole outer epoch (previously reset on every inner pass).
    start_time = time.time()
    for _ in tqdm(range(100), desc='accuracy so far:' + str(total_accuracy)):
        for inputs, labels in train_loader:
            targets = torch.zeros((labels.shape[0], n_class, 1, 1, n_steps), dtype=dtype).to(device)

            # Preprocess: add a channel axis, repeat over n_steps time steps, crop.
            if len(inputs.shape) < 5:
                inputs = inputs.unsqueeze(1).unsqueeze_(-1).repeat(1, 1, 1, 1, n_steps)
                # NOTE(review): this 28x28 crop keeps only part of the 32x64 image;
                # it appears to match FC_1's 640 (= 40*4*4) inputs — confirm in the YAML.
                inputs = inputs[:, :, :28, :28, :]
            labels = labels.to(device)
            # Tensor.type() returns a new tensor; the result used to be discarded.
            inputs = inputs.to(device).type(dtype)

            outputs = net.forward(inputs, epoch, True)
            loss = _spike_loss(outputs, labels, targets)

            opti.zero_grad()
            loss.backward()
            # Rescale gradients so their global L2 norm is at most 1, guarding
            # against exploding gradients.
            clip_grad_norm_(net.get_parameters(), 1)
            opti.step()
            net.weight_clipper()  # project-specific weight constraint after each step

            # Predicted class = output with the most spikes summed over time.
            # torch.argmax returns the first maximal index, like np.argmax.
            spike_counts = outputs.detach().sum(dim=4).flatten(1)
            predicted = spike_counts.argmax(dim=1)

            train_loss += torch.sum(loss).item()
            total += len(labels)
            correct += (predicted == labels).sum().item()

        total_accuracy = correct / total
        total_loss = train_loss / total
    end_time = time.time() - start_time  # seconds for this outer epoch

accuracy so far:0: 100%|██████████| 100/100 [00:50<00:00,  1.99it/s]
accuracy so far:0.6885: 100%|██████████| 100/100 [00:49<00:00,  2.00it/s]
accuracy so far:0.7564: 100%|██████████| 100/100 [00:49<00:00,  2.01it/s]
accuracy so far:0.8048666666666666: 100%|██████████| 100/100 [00:46<00:00,  2.13it/s]
accuracy so far:0.844075: 100%|██████████| 100/100 [00:49<00:00,  2.01it/s]
accuracy so far:0.8703: 100%|██████████| 100/100 [00:50<00:00,  1.97it/s]
accuracy so far:0.8908666666666667: 100%|██████████| 100/100 [00:49<00:00,  2.03it/s]
accuracy so far:0.9059: 100%|██████████| 100/100 [00:49<00:00,  2.03it/s]
accuracy so far:0.9174625: 100%|██████████| 100/100 [00:49<00:00,  2.04it/s]
accuracy so far:0.9265222222222222: 100%|██████████| 100/100 [00:48<00:00,  2.06it/s]


In [25]:
# Cumulative training accuracy: the counters are initialised once before the
# training loop and never reset, so this averages over every training pass.
total_accuracy

0.93375

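# Testing

Evaluate the trained network on the summer, fall, and winter loaders. The summer loader is built from the same 100 'summer' images as the training loader, so its accuracy measures fit to the training set rather than generalisation.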

In [32]:
# Accuracy per season. Uses its own counters so the training counters
# (total / correct / total_accuracy) are not clobbered if training is re-run.
# `epoch` is whatever the training cell left behind; what net.forward does
# with it at eval time is not visible here — confirm.
# NOTE: test_loaderS serves the same 'summer' images used for training.
with torch.no_grad():
    for tests in [test_loaderS, test_loaderF, test_loaderW]:  # summer, fall, winter
        test_total = 0
        test_correct = 0
        for inputs, labels in tests:
            if len(inputs.shape) < 5:
                inputs = inputs.unsqueeze(1).unsqueeze_(-1).repeat(1, 1, 1, 1, n_steps)
                inputs = inputs[:, :, :28, :28, :]  # same crop as in training
            labels = labels.to(device)
            # Tensor.type() is not in-place; assign the result.
            inputs = inputs.to(device).type(dtype)
            outputs = net.forward(inputs, epoch, False)
            # Class with the most spikes over time (first maximum on ties).
            predicted = outputs.sum(dim=4).flatten(1).argmax(dim=1)
            test_total += len(labels)
            test_correct += (predicted == labels).sum().item()
        # Computed once per loader; NaN rather than a stale value if it is empty.
        test_accuracy = test_correct / test_total if test_total else float('nan')
        print(test_accuracy)

1.0
0.56
0.08
