%%% Imports

In [1]:

from __future__ import annotations
from time import time
from QCRVPRDataset import QCRVPRDataset, QCRVPRSyncDataset
from VPRNetwork import VPRNetwork
from plotting import plot_confusion_matrix, plot_match_images, plot_gps
from constants import brisbane_event_traverses, qcr_traverses, brisbane_event_traverses_aliases
from utils import get_short_traverse_name

import os, sys
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from matplotlib import animation
from scipy import signal
from datetime import datetime

import numpy as np

import lava.lib.dl.slayer as slayer
from lava.lib.dl.slayer.classifier import Rate

from constants import synced_times

In [2]:
# Training settings
epochs = 50
batch_size = 10

# Reproducibility: seeds weight initialisation and DataLoader shuffling.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# Network settings
input_size = 34
threshold = 1

# Data settings
train_traverse = qcr_traverses[15]
test_traverse = qcr_traverses[14]
train_name = get_short_traverse_name(train_traverse)
test_name = get_short_traverse_name(test_traverse)
samples_per_sec = 1000
place_duration = 5  # [s]
max_spikes = 2762   # max spikes per sample

use_pre_synced_times = True
incorporate_speed = True  # only takes effect when use_pre_synced_times is True

match_tolerance = 0

if use_pre_synced_times:
    # One place per pre-synchronised timestamp of the training traverse.
    num_places = len(synced_times[train_traverse])
    print(num_places)
else:
    # Manual place selection, consumed by the QCRVPRDataset branch of the
    # data-loading cell and by the log. These must be defined for that path to run.
    num_places = 82
    start_time = 0  # [s]
    place_gap = 2   # [s]; the streams run for approximately 164 seconds

# Sequencer settings
sequence_length = 3

# Plot settings
redo_plot = False
vmin = 0
vmax = 0.19

def transpose(matrix):
    """Return the transpose of a row-indexable 2-D matrix as a list of lists.

    The number of output rows is taken from ``len(matrix[0])``. Every input row
    must therefore be at least that long. An empty matrix yields ``[]``.
    """
    if not len(matrix):
        return []
    n_rows = len(matrix)
    n_cols = len(matrix[0])
    transposed = []
    for col in range(n_cols):
        transposed.append([matrix[row][col] for row in range(n_rows)])
    return transposed


# Use the GPU when one is available. Otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
print(torch.cuda.is_available())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#---------------- Create the Network -----------------#

# Create the network: `input_size` inputs and `num_places` outputs
# (presumably one output per place; argmax over them picks the match later).
net = VPRNetwork(input_size, num_places, threshold=threshold).to(device)


22
True


In [3]:
# Load the data.
# Keyword arguments shared by every dataset constructor below.
common_kwargs = dict(place_duration=place_duration, samples_per_sec=samples_per_sec,
                     max_spikes=max_spikes, subselect_num=input_size)

if use_pre_synced_times:
    training_set = QCRVPRSyncDataset(train_traverse, train=True, **common_kwargs)
    test_kwargs = dict(common_kwargs)
    if incorporate_speed:
        # Hand the test set the training duration so it can account for speed.
        test_kwargs['training_duration'] = training_set.training_duration
    testing_set = QCRVPRSyncDataset(test_traverse, train=False, **test_kwargs)
else:
    # Manually spaced places; speed is not incorporated on this path.
    manual_kwargs = dict(common_kwargs, place_gap=place_gap, num_places=num_places, start_time=start_time)
    training_set = QCRVPRDataset(train_traverse, train=True, **manual_kwargs)
    testing_set = QCRVPRDataset(test_traverse, train=False, **manual_kwargs)

train_loader = DataLoader(dataset=training_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=testing_set, batch_size=batch_size, shuffle=True)

Loading training event streams ...


  0%|          | 0/1 [00:00<?, ?it/s]

Duration: 267.00s (which is 51609087 events)
[1648433034.33, 1648433048.06, 1648433063.64, 1648433077.47, 1648433091.34, 1648433105.65, 1648433120.19, 1648433133.24, 1648433144.55, 1648433151.05, 1648433163.63, 1648433177.62, 1648433193.01, 1648433206.73, 1648433220.38, 1648433232.07, 1648433242.7, 1648433253.0, 1648433263.31, 1648433272.6, 1648433282.51, 1648433291.85]


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  event_stream['x'].loc[small_filt0x] = i
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  event_stream['y'].loc[small_filt0y] = i


1.936287
1.229313
1.448187
1.965396
1.139994
2.895014
0.543805
0.989688
1.351823
0.905764
4.065003
1.459383
1.543933
1.935831
0.866582
1.54055
1.028766
0.827075
3.939881
1.651236
0.490808
0.148144
average spikes = 2762.0
The number of training substreams is: 22


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  chopped_stream['t'] -= chop_start


  0%|          | 0/1 [00:00<?, ?it/s]

Duration: 64.50s (which is 39592357 events)
[1648432912.91, 1648432915.03, 1648432918.1, 1648432921.15, 1648432924.06, 1648432927.01, 1648432930.13, 1648432933.15, 1648432936.1, 1648432939.18, 1648432942.01, 1648432945.04, 1648432948.16, 1648432951.02, 1648432954.11, 1648432957.06, 1648432960.14, 1648432963.08, 1648432966.15, 1648432969.01, 1648432972.12, 1648432975.02]
Place duration 1.2078650594411442
Speed ratio 1


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  event_stream['x'].loc[small_filt0x] = i
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  event_stream['y'].loc[small_filt0y] = i


0.348232
0.328269
0.48903
0.364701
0.258425
0.505897
0.312104
0.314162
0.678405
0.247867
0.849393
0.303723
0.2881
0.536441
0.304136
0.481545
0.318251
0.191217
0.687149
0.429144
0.553609
0.375824
average spikes = 2762.0
The number of testing substreams is: 22


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  chopped_stream['t'] -= chop_start


In [4]:
# Make a folder for all results. makedirs (rather than mkdir) also creates the
# parent ./../results directory on a fresh checkout.
time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
results_path = "./../results/" + time_stamp
os.makedirs(results_path, exist_ok=True)

# Print a multi-line stats summary every `display_every` epochs.
display_every = 20

# Define an optimiser
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# Spike-rate loss (lava-dl SpikeRate). true_rate / false_rate are the target
# output rates for the labelled output and for the remaining outputs.
error = slayer.loss.SpikeRate(true_rate=0.2, false_rate=0.03, reduction='sum').to(device)

# The training assistant runs train/test steps and records loss/accuracy in
# `stats`, classifying outputs with Rate.predict.
stats = slayer.utils.LearningStats()
assistant = slayer.utils.Assistant(net, error, optimizer, stats, classifier=slayer.classifier.Rate.predict)


if redo_plot:
    # Quick path: a single training pass with no per-epoch evaluation.
    for spikes, label in train_loader:
        assistant.train(spikes, label)
    # The evaluation cell loads network.pt from results_path, so save it here too.
    torch.save(net.state_dict(), results_path + '/network.pt')

else:
    for epoch in range(epochs):

        for spikes, label in train_loader:  # training loop
            assistant.train(spikes, label)
        print(f'\r[Epoch {epoch:2d}/{epochs}] {stats}', end='')

        for spikes, label in test_loader:  # testing loop
            assistant.test(spikes, label)
        print(f'\r[Epoch {epoch:2d}/{epochs}] {stats}', end='')

        if epoch % display_every == display_every - 1:  # cleanup display
            print('\r', ' '*len(f'\r[Epoch {epoch:2d}/{epochs}] {stats}'))
            stats_str = str(stats).replace("| ", "\n")
            print(f'[Epoch {epoch:2d}/{epochs}]\n{stats_str}')

        # Keep the checkpoint with the best test accuracy seen so far.
        if stats.testing.best_accuracy:
            torch.save(net.state_dict(), results_path + '/network.pt')
        stats.update()
        stats.save(results_path + '/')
        net.grad_flow(results_path + '/')

[Epoch  0/50] Train loss =     0.52116                        accuracy = 0.04545 | Test  loss =     7.59850                        accuracy = 0.04545
Accuracy was none
[Epoch  1/50] Train loss =     0.36457 (min =     0.52116)    accuracy = 0.13636 (max = 0.04545) | Test  loss =     7.50261 (min =     7.59850)    accuracy = 0.18182 (max = 0.04545)
Accuracy was none
[Epoch  2/50] Train loss =     0.26325 (min =     0.36457)    accuracy = 0.59091 (max = 0.13636) | Test  loss =     2.81389 (min =     7.50261)    accuracy = 0.40909 (max = 0.18182)
Accuracy was none
[Epoch  3/50] Train loss =     0.25944 (min =     0.26325)    accuracy = 0.72727 (max = 0.59091) | Test  loss =     3.09574 (min =     2.81389)    accuracy = 0.40909 (max = 0.40909)
Accuracy was none
[Epoch  4/50] Train loss =     0.22586 (min =     0.25944)    accuracy = 0.86364 (max = 0.72727) | Test  loss =     5.47842 (min =     2.81389)    accuracy = 0.40909 (max = 0.40909)
Accuracy was none
[Epoch  5/50] Train loss =     0

In [5]:
# Import the best network saved during training. map_location lets a checkpoint
# written on a GPU be reloaded on whatever `device` is in use.
net.load_state_dict(torch.load(results_path + '/network.pt', map_location=device))
net.export_hdf5(results_path + '/network.net')


def match_places(scores, n_places, tolerance):
    """Match each query place to its highest-scoring reference place.

    Args:
        scores: sequence where ``scores[q]`` is the score vector over reference
            places for query ``q``.
        n_places: number of queries (rows of ``scores``) to evaluate.
        tolerance: a match counts as correct when ``abs(match - q) <= tolerance``.

    Returns:
        ``(matches, annotations, n_correct)``:
        - ``matches[q]`` is the argmax reference index for query ``q``.
        - ``annotations`` is an ``n_places x n_places`` grid with ``'x'`` at
          ``[match][q]``, used for the confusion-matrix plot.
        - ``n_correct`` is the number of queries matched within tolerance.
    """
    matches = []
    annotations = [['' for _ in range(n_places)] for _ in range(n_places)]
    n_correct = 0
    for query in range(n_places):
        best = np.argmax(scores[query])
        matches.append(best)
        annotations[best][query] = 'x'
        # NOTE(review): assumes query q's ground truth is reference place q
        # (the evaluation loader is unshuffled). TODO confirm against `labels`.
        if abs(best - query) <= tolerance:
            n_correct += 1
    return matches, annotations, n_correct


# Get the output for the input to each place. shuffle=False keeps row q of
# `rate` aligned with test sample q.
test_loader2 = DataLoader(dataset=testing_set, batch_size=batch_size, shuffle=False)
rate = []
labels = []
net.eval()  # inference mode, regardless of how training left the network
with torch.no_grad():  # evaluation needs no autograd graph
    for spikes, label in test_loader2:
        output = net(spikes.to(device))  # Get network output
        rate.extend(Rate.confidence(output).cpu().numpy())  # per-place confidences
        labels.extend(label.cpu().numpy())  # Get place labels

# Single-frame matching. `accuracy` holds the COUNT of correct matches here.
matches, annotations, accuracy = match_places(rate, num_places, match_tolerance)


#--------------- Apply a sequencer ------------------#
# Convolving with an identity kernel sums each score with its neighbours along
# the diagonal, i.e. over `sequence_length` consecutive query/reference pairs.
seq_kernel = np.identity(sequence_length)
conv = signal.convolve2d(rate, seq_kernel, mode='same')
print(np.shape(conv))

# Sequence-filtered matching. `accuracy_s` is also a count at this point.
matches_with_seq, annotations_s, accuracy_s = match_places(conv, num_places, match_tolerance)


(22, 22)


In [6]:

#--------------- Save Results ------------------#

# # Save query and match images
# images_path = results_path + "/matched_images"
# images_path_seq = results_path + "/matched_images_seq"
# os.mkdir(images_path)
# os.mkdir(images_path_seq)
# plot_match_images(images_path + "/", matches, training_set.place_images, testing_set.place_images)
# plot_match_images(images_path_seq + "/", matches_with_seq, training_set.place_images, testing_set.place_images)

# Save the confusion matrices. The transposed copies get new names so that
# re-running this cell does not transpose `rate`/`conv` a second time.
# exist_ok keeps the cell re-runnable.
confusion_path = results_path + "/confusion_matrices"
os.makedirs(confusion_path, exist_ok=True)
rate_t = transpose(rate)
conv_t = transpose(conv)
output_path = confusion_path + "/confusion_matrix"
output_path_s = confusion_path + "/confusion_matrix_seq"
plot_confusion_matrix(rate_t, labels, annotations, output_path, vmin, vmax)
plot_confusion_matrix(conv_t, labels, annotations_s, output_path_s, vmin, vmax)

# # Save GPS map
# gps_path = results_path + "/gps_locations"
# plot_gps(gps_path, training_set.place_locations, testing_set.place_locations)

# Accuracy is the fraction of queries whose best match lies within
# `match_tolerance`. It is recomputed from the match lists instead of dividing
# `accuracy` in place, so re-running this cell cannot divide twice.
accuracy = sum(1 for q, m in enumerate(matches) if abs(m - q) <= match_tolerance) / num_places
accuracy_s = sum(1 for q, m in enumerate(matches_with_seq) if abs(m - q) <= match_tolerance) / num_places


In [7]:

log_path = results_path + "/log.txt"

# Speed is only incorporated on the pre-synced data path (see the data-loading
# cell), so log the setting that actually took effect.
speed_used = bool(incorporate_speed and use_pre_synced_times)

data_lines = [
    "---- DATA SETTINGS ---- ",
    f"Training datasets = {train_name}",
    f"Testing datasets = {test_name} ",
    f"# Places = {num_places}",
]
if use_pre_synced_times:
    data_lines.append("Using presynced times = True")
else:
    data_lines += [
        f"Start time = {start_time} [s]",
        f"Place gap = {place_gap} [s]",
    ]

log_lines = data_lines + [
    f"Place duration = {place_duration} [s]",
    f"Samples per second = {samples_per_sec} ",
    f"Max spikes per sample = {max_spikes}",
    f"Incorporate speed = {speed_used}",
    "---- NETWORK SETTINGS ---- ",
    f"Input size = {input_size}x{input_size} ",
    f"Threshold = {threshold}",
    "---- TRAINING SETTINGS ---- ",
    f"Epochs = {epochs}",
    f"Batch size = {batch_size}",
    "---- SEQUENCER SETTINGS ----",
    f"Sequence length = {sequence_length} ",
    "---- RESULTS ----",
    f"Match tolerance = {match_tolerance}",
    f"Accuracy = {accuracy}",
    f"Accuracy (sequencer) = {accuracy_s}",
]
# Same layout as the original template: a leading newline, every line indented
# by four spaces, and a trailing four-space indent.
log_string = "\n    " + "\n    ".join(log_lines) + "\n    "

with open(log_path, 'w') as f:
    f.write(log_string)

print(f"The accuracy of the network is: {accuracy}")
print(f"The accuracy with a sequencer is: {accuracy_s}")

The accuracy of the network is: 0.5909090909090909
The accuracy with a sequencer is: 0.9545454545454546
