### Imports

In [None]:

from __future__ import annotations
from time import time
from BrisbaneVPRDatasetTime import BrisbaneVPRDatasetTime
from VPRNetwork import VPRNetwork
from plotting import plot_confusion_matrix, plot_match_images, plot_gps
from constants import brisbane_event_traverses, brisbane_event_traverses_aliases
from utils import get_short_traverse_name

import os, sys
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from matplotlib import animation
from scipy import signal
from datetime import datetime

import numpy as np

import lava.lib.dl.slayer as slayer
from lava.lib.dl.slayer.classifier import Rate


In [None]:
# Training settings
# NOTE: a stray `from cgi import test` (an editor auto-import) used to sit here.
# Nothing in this notebook uses `test`, and the `cgi` module was removed from the
# standard library in Python 3.13, so the import could only break this cell.
epochs = 60       # number of passes over the training set
batch_size = 10   # batch size shared by the train and test DataLoaders

# Network settings
input_size = 34   # passed to VPRNetwork; logged as "Input size = {}x{}"
threshold = 1.0   # passed to VPRNetwork as `threshold`

# Data settings
# Entry 0 of brisbane_event_traverses is the training traverse, entry 3 the test traverse
train_traverse, test_traverse = brisbane_event_traverses[0], brisbane_event_traverses[3]
# Aliases looked up by short traverse name (written to the results log)
train_name, test_name = (
    brisbane_event_traverses_aliases[get_short_traverse_name(traverse)]
    for traverse in (train_traverse, test_traverse)
)
num_places = 77
# Earlier note kept from the original: place_gap was once 164/num_places
# (the streams run for approximately 164 seconds)
start_dist, place_gap = 200, 100
samples_per_sec, place_duration = 1000, 2
max_spikes = None

# Largest |matched index - query index| that still counts as a correct match
match_tolerance = 0

# Sequencer settings
sequence_length = 3

# Plot settings
redo_plot = False
vmin, vmax = 0, 0.07

def transpose( matrix):
    """Return the transpose of a row-major 2-D sequence as a list of lists.

    The number of columns is taken from the first row. An empty `matrix`
    yields an empty list.
    """
    row_count = len(matrix)
    if row_count == 0:
        return []

    column_count = len(matrix[0])
    flipped = []
    for col in range(column_count):
        # Gather element `col` of every row into one output row
        new_row = []
        for row in range(row_count):
            new_row.append(matrix[row][col])
        flipped.append(new_row)
    return flipped


# Make a folder for the trained network
trained_folder = 'Trained'
os.makedirs(trained_folder, exist_ok=True)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing when the network is moved to `device`.
print(torch.cuda.is_available())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#---------------- Create the Network -----------------#

# Create the network
net = VPRNetwork(input_size, num_places, threshold=threshold).to(device)


In [None]:
# Load the data
# Both datasets share the same place/sampling settings; the test set is also
# given the training set's place_locations through `training_locations`.
training_set = BrisbaneVPRDatasetTime(
    train_traverse,
    train=True,
    place_duration=place_duration,
    place_gap=place_gap,
    num_places=num_places,
    start_dist=start_dist,
    samples_per_sec=samples_per_sec,
    max_spikes=max_spikes,
)
testing_set = BrisbaneVPRDatasetTime(
    test_traverse,
    train=False,
    place_duration=place_duration,
    training_locations=training_set.place_locations,
    place_gap=place_gap,
    num_places=num_places,
    start_dist=start_dist,
    samples_per_sec=samples_per_sec,
    max_spikes=max_spikes,
)

# Shuffled loaders used during the training/testing epochs
train_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testing_set, batch_size=batch_size, shuffle=True)

In [None]:
# Every run writes into its own timestamped folder under ../results
time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
results_path = "./../results/" + time_stamp
# makedirs (rather than mkdir) also creates the parent ../results folder when it
# does not exist yet, and does not fail if this cell is re-run within the same second
os.makedirs(results_path, exist_ok=True)


# Statistics collector shared with the training assistant
stats = slayer.utils.LearningStats()

# SpikeRate loss configured with true_rate=0.2 and false_rate=0.03, summed over the batch
error = slayer.loss.SpikeRate(
    true_rate=0.2,
    false_rate=0.03,
    reduction='sum',
).to(device)

# Adam optimizer over all network parameters
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# Training assistant: wraps the network, loss, optimizer and stats, and uses
# the rate classifier's `predict` to turn output spikes into predictions
assistant = slayer.utils.Assistant(
    net,
    error,
    optimizer,
    stats,
    classifier=Rate.predict,
)


# Checkpoint that the evaluation cell loads back from results_path
checkpoint_path = results_path + '/network.pt'

if redo_plot:
    # Quick mode: a single pass over the training data, no per-epoch testing
    for inputs, label in train_loader: # training loop
        assistant.train(inputs, label)
    # Without this save, quick mode never produced network.pt and the evaluation
    # cell failed when it tried to load the checkpoint.
    torch.save(net.state_dict(), checkpoint_path)

else:
    for epoch in range(epochs):

        for inputs, label in train_loader: # training loop
            assistant.train(inputs, label)
        print(f'\r[Epoch {epoch:2d}/{epochs}] {stats}', end='')

        for inputs, label in test_loader: # testing loop
            assistant.test(inputs, label)
        print(f'\r[Epoch {epoch:2d}/{epochs}] {stats}', end='')

        if epoch%20 == 19: # cleanup display
            print('\r', ' '*len(f'\r[Epoch {epoch:2d}/{epochs}] {stats}'))
            stats_str = str(stats).replace("| ", "\n")
            print(f'[Epoch {epoch:2d}/{epochs}]\n{stats_str}')

        # Keep only the weights from the best testing accuracy so far
        if stats.testing.best_accuracy:
            torch.save(net.state_dict(), checkpoint_path)
        stats.update()
        stats.save(results_path + '/')
        net.grad_flow(results_path + '/')

In [None]:
# Load the best network saved during training
# map_location keeps the load working when the checkpoint was written on a
# different device than the one in use now
net.load_state_dict(torch.load(results_path + '/network.pt', map_location=device))
net.export_hdf5(results_path + '/network.net')

# Get the output for the input to each place
# Unshuffled loader, so entry k of `rate`/`labels` is the k-th test sample
test_loader2  = DataLoader(dataset=testing_set , batch_size=batch_size, shuffle=False)
rate = []
labels = []
# Pure inference: switch to eval mode and skip autograd bookkeeping
net.eval()
with torch.no_grad():
    for inputs, label in test_loader2:
        output = net(inputs.to(device)) # Get network output
        rate.extend(Rate.confidence(output).detach().cpu().numpy()) # Get the firing rates for each place
        labels.extend(label.detach().cpu().numpy()) # Get place labels

#print(rate)
# # Get rates in percentages of total
# for i in range(num_places):
#     sum = np.sum(rate[i])
#     if sum != 0:
#         rate[i] = np.divide(rate[i],sum)

def _best_matches(scores):
    """Pick the highest-scoring place for each of the first `num_places` queries.

    Reads the notebook-level settings `num_places` and `match_tolerance`.

    Args:
        scores: Indexable by query index; `scores[q]` holds the scores of
            query `q` against every place.

    Returns:
        A tuple `(matches, annotations, num_correct)`:
        `matches[q]` is the argmax of `scores[q]`;
        `annotations` is a num_places x num_places grid of strings with an
        'x' at `[matches[q]][q]` and '' elsewhere;
        `num_correct` counts queries with `|matches[q] - q| <= match_tolerance`.
    """
    matches = []
    annotations = [['' for _ in range(num_places)] for _ in range(num_places)]
    num_correct = 0
    for qry_index in range(num_places):
        max_idx = np.argmax(scores[qry_index])
        matches.append(max_idx)
        annotations[max_idx][qry_index] = 'x'
        if abs(max_idx - qry_index) <= match_tolerance:
            num_correct += 1
    return matches, annotations, num_correct


# Make confusion matrix annotations
# `accuracy` is still a raw count of correct matches at this point
matches, annotations, accuracy = _best_matches(rate)


#--------------- Apply a sequencer ------------------#
# Convolving with an identity kernel sums the scores along the diagonal, i.e.
# over neighbouring (query, place) pairs; dividing by sequence_length averages them
seq_kernel = np.identity(sequence_length)
conv = signal.convolve2d(rate, seq_kernel, mode='same')/sequence_length
print(np.shape(conv))


# Make confusion matrix annotations
# `accuracy_s` is likewise a raw count of correct matches
matches_with_seq, annotations_s, accuracy_s = _best_matches(conv)



In [None]:

#--------------- Save Results ------------------#

# Make new folder for results
from tracemalloc import start
from plotting import plot_gps

# Stack the per-sample arrays into one ndarray before building the tensor:
# torch.tensor on a list of ndarrays is slow and emits a warning. Each
# conversion is also done once instead of twice.
temp_rate = torch.tensor(np.asarray(rate))
temp_rate = temp_rate.reshape(temp_rate.shape[0], -1)
temp_conv = torch.tensor(np.asarray(conv))
temp_conv = temp_conv.reshape(temp_conv.shape[0], -1)
# Normalise each row by its sum; the 1e-6 avoids a division by zero for all-zero rows
confidence = temp_rate / (torch.sum(temp_rate, dim=1, keepdim=True) + 1e-6)
confidence_seq = temp_conv / (torch.sum(temp_conv, dim=1, keepdim=True) + 1e-6)


# Save query and match images
images_path = results_path + "/matched_images"
images_path_seq = results_path + "/matched_images_seq"
# exist_ok=True lets this cell be re-run without failing on the existing folders
os.makedirs(images_path, exist_ok=True)
os.makedirs(images_path_seq, exist_ok=True)

# One set of images for the raw matches, one for the sequencer matches
plot_match_images(images_path + "/", matches, training_set.place_images, testing_set.place_images)
plot_match_images(images_path_seq + "/", matches_with_seq, training_set.place_images, testing_set.place_images)

# Save the confusion matrices
confusion_path = results_path + "/confusion_matrices"
# exist_ok=True lets this cell be re-run without failing on the existing folder
os.makedirs(confusion_path, exist_ok=True)
# Transpose so the first index matches the annotation grids, which are filled
# as [matched index][query index]
confidence = transpose(confidence)
confidence_seq = transpose(confidence_seq)
output_path = confusion_path + "/confusion_matrix" 
output_path_s = confusion_path + "/confusion_matrix_seq" 
plot_confusion_matrix(confidence, labels, annotations, output_path, vmin, vmax)
plot_confusion_matrix(confidence_seq, labels, annotations_s, output_path_s, vmin, vmax)

# Save GPS map: place locations of the training and testing sets
gps_path = f"{results_path}/gps_locations"
plot_gps(
    gps_path,
    training_set.place_locations,
    testing_set.place_locations,
)

# Save test settings and accuracy
# Derive the fractions from the match lists instead of dividing the running
# counts in place, so re-running this cell cannot divide by num_places twice.
accuracy = sum(1 for qry, match in enumerate(matches) if abs(match - qry) <= match_tolerance) / num_places
accuracy_s = sum(1 for qry, match in enumerate(matches_with_seq) if abs(match - qry) <= match_tolerance) / num_places


In [None]:

# Write the run settings and results to log.txt inside the results folder
log_path = results_path + "/log.txt"
# The positional arguments below must stay in the same order as the {} fields
log_string  = """
---- DATA SETTINGS ---- 
Training datasets = {}
Testing datasets = {} 
# Places = {}
Start dist = {} [m]
Place gap = {} [m]
Place duration = {} [s]
Samples per second = {} 
Max spikes per sample = {}
---- NETWORK SETTINGS ---- 
Input size = {}x{} 
Threshold = {}
---- TRAINING SETTINGS ---- 
Epochs = {}
Batch size = {}
---- SEQUENCER SETTINGS ----
Sequence lengtth = {} 
---- RESULTS ----
Match tolerance = {}
Accuracy = {}
Accuracy (sequencer) = {}
""".format(train_name,test_name,num_places,start_dist,place_gap,place_duration,
                samples_per_sec,max_spikes,input_size,input_size,threshold,epochs,
                batch_size,sequence_length,match_tolerance,accuracy,accuracy_s)
# The context manager closes the file even if the write raises
with open(log_path,'w') as log_file:
    log_file.write(log_string)

print(f"The accuracy of the network is: {accuracy}")
print(f"The accuracy with a sequencer is: {accuracy_s}")