In [1]:
import numpy as np
import matplotlib
import matplotlib.cm as cmap
import time
import os.path
import scipy 
import pickle
from struct import unpack
from brian2 import *

In [2]:

#------------------------------------------------------------------------------ 
# functions
#------------------------------------------------------------------------------     
def get_labeled_data(picklename, bTrain = True):
    """Read input-vector (image) and target class (label, 0-9) and return
       it as list of tuples.
    """
    if os.path.isfile('%s.pickle' % picklename):
        with open('%s.pickle' % picklename, 'rb') as f:
            data = pickle.load(f)
    else:
        # Open the images with gzip in read binary mode
        if bTrain:
            images = open(MNIST_data_path + 'train-images.idx3-ubyte','rb')
            labels = open(MNIST_data_path + 'train-labels.idx1-ubyte','rb')
        else:
            images = open(MNIST_data_path + 't10k-images.idx3-ubyte','rb')
            labels = open(MNIST_data_path + 't10k-labels.idx1-ubyte','rb')
        # Get metadata for images
        images.read(4)  # skip the magic_number
        number_of_images = unpack('>I', images.read(4))[0]
        rows = unpack('>I', images.read(4))[0]
        cols = unpack('>I', images.read(4))[0]
        # Get metadata for labels
        labels.read(4)  # skip the magic_number
        N = unpack('>I', labels.read(4))[0]
        if number_of_images != N:
            raise Exception('number of labels did not match the number of images')
        # Get the data
        x = np.zeros((N, rows, cols), dtype=np.uint8)  # Initialize numpy array
        y = np.zeros((N, 1), dtype=np.uint8)  # Initialize numpy array
        for i in range(N):
            if i % 1000 == 0:
                print("i: %i" % i)
            x[i] = [[unpack('>B', images.read(1))[0] for unused_col in range(cols)]  for unused_row in range(rows) ]
            y[i] = unpack('>B', labels.read(1))[0]
        data = {'x': x, 'y': y, 'rows': rows, 'cols': cols}
        pickle.dump(data, open("%s.pickle" % picklename, "wb"))
    return data

def get_recognized_number_ranking(assignments, spike_rates):
    summed_rates = [0] * 10
    num_assignments = [0] * 10
    for i in range(10):
        num_assignments[i] = len(np.where(assignments == i)[0])
        if num_assignments[i] > 0:
            summed_rates[i] = np.sum(spike_rates[assignments == i]) / num_assignments[i]
    return np.argsort(summed_rates)[::-1]

def get_new_assignments(result_monitor, input_numbers):
    print (result_monitor.shape)
    assignments = np.ones(n_e) * -1 # initialize them as not assigned
    input_nums = np.asarray(input_numbers)
    maximum_rate = [0] * n_e    
    for j in range(10):
        num_inputs = len(np.where(input_nums == j)[0])
        if num_inputs > 0:
            rate = np.sum(result_monitor[input_nums == j], axis = 0) / num_inputs
        for i in range(n_e):
            if rate[i] > maximum_rate[i]:
                maximum_rate[i] = rate[i]
                assignments[i] = j 
    return assignments

In [3]:
MNIST_data_path = '/home/asif/lava-dl/Tasks/Task_7/Unsupervised_STDP/data/'
data_path = '/home/asif/lava-dl/Tasks/Task_8/Brian2STDPMNIST/activity/'
training_ending = '10000'
testing_ending = '10000'
start_time_training = 0
end_time_training = int(training_ending)
start_time_testing = 0
end_time_testing = int(testing_ending)

n_e = 400
n_input = 784
ending = ''

In [4]:
print('load MNIST')
training = get_labeled_data(MNIST_data_path + 'training')
testing = get_labeled_data(MNIST_data_path + 'testing', bTrain = False)

load MNIST


In [5]:
print ('load results')
training_result_monitor = np.load(data_path + 'resultPopVecs' + training_ending + ending + '.npy')
training_input_numbers = np.load(data_path + 'inputNumbers' + training_ending + '.npy')
testing_result_monitor = np.load(data_path + 'resultPopVecs' + testing_ending + '.npy')
testing_input_numbers = np.load(data_path + 'inputNumbers' + testing_ending + '.npy')
print (training_result_monitor.shape)

load results
(10000, 400)


In [6]:
print('get assignments')
test_results = np.zeros((10, end_time_testing-start_time_testing))
test_results_max = np.zeros((10, end_time_testing-start_time_testing))
test_results_top = np.zeros((10, end_time_testing-start_time_testing))
test_results_fixed = np.zeros((10, end_time_testing-start_time_testing))
assignments = get_new_assignments(training_result_monitor[start_time_training:end_time_training], 
                                  training_input_numbers[start_time_training:end_time_training])

get assignments
(10000, 400)


In [8]:
print(assignments)
counter = 0 
num_tests = end_time_testing / 10000
sum_accurracy = [0] * int(num_tests)
while (counter < num_tests):
    end_time = min(end_time_testing, 10000*(counter+1))
    start_time = 10000*counter
    test_results = np.zeros((10, end_time-start_time))
    print( 'calculate accuracy for sum')
    for i in range(end_time - start_time):
        test_results[:,i] = get_recognized_number_ranking(assignments, 
                                                          testing_result_monitor[i+start_time,:])
    difference = test_results[0,:] - testing_input_numbers[start_time:end_time]
    correct = len(np.where(difference == 0)[0])
    incorrect = np.where(difference != 0)[0]
    sum_accurracy[counter] = correct/float(end_time-start_time) * 100
    print ('Sum response - accuracy: ', sum_accurracy[counter], ' num incorrect: ', len(incorrect))
    counter += 1
print ('Sum response - accuracy --> mean: ', np.mean(sum_accurracy),  '--> standard deviation: ', np.std(sum_accurracy))


show()

[7. 3. 2. 2. 1. 5. 6. 2. 5. 0. 0. 3. 1. 4. 0. 2. 4. 7. 6. 3. 0. 0. 6. 8.
 8. 2. 0. 2. 3. 1. 6. 6. 7. 5. 2. 9. 6. 0. 2. 3. 7. 2. 0. 3. 4. 3. 6. 2.
 2. 2. 2. 8. 6. 6. 2. 6. 4. 7. 4. 0. 0. 8. 9. 8. 8. 2. 8. 6. 8. 9. 2. 3.
 9. 2. 8. 1. 6. 7. 7. 6. 3. 5. 9. 8. 9. 6. 3. 6. 0. 4. 4. 1. 6. 2. 4. 6.
 9. 3. 0. 5. 7. 0. 3. 3. 8. 3. 8. 1. 6. 8. 9. 9. 3. 2. 6. 5. 0. 3. 7. 1.
 5. 6. 0. 0. 9. 2. 2. 5. 8. 3. 7. 8. 0. 6. 9. 0. 3. 9. 4. 4. 0. 4. 9. 4.
 4. 8. 4. 0. 9. 3. 2. 7. 3. 3. 0. 9. 9. 0. 7. 0. 9. 6. 5. 0. 9. 8. 1. 4.
 7. 7. 3. 2. 7. 8. 7. 1. 2. 7. 2. 4. 8. 3. 0. 1. 7. 9. 5. 8. 4. 5. 7. 6.
 7. 2. 6. 5. 2. 9. 0. 6. 0. 0. 9. 2. 8. 3. 0. 3. 0. 4. 4. 0. 8. 6. 6. 3.
 9. 0. 8. 7. 8. 5. 2. 0. 4. 6. 9. 9. 2. 3. 9. 8. 6. 1. 3. 5. 6. 6. 9. 8.
 5. 3. 3. 2. 1. 9. 3. 6. 5. 2. 0. 8. 8. 9. 5. 0. 8. 3. 3. 1. 5. 1. 5. 0.
 8. 2. 5. 9. 9. 4. 0. 2. 1. 3. 9. 6. 8. 5. 7. 7. 2. 4. 6. 0. 0. 3. 4. 2.
 5. 8. 5. 6. 2. 8. 0. 0. 2. 7. 4. 2. 3. 4. 8. 8. 0. 3. 0. 9. 5. 7. 5. 5.
 2. 6. 5. 0. 2. 8. 2. 4. 7. 1. 7. 8. 9. 4. 6. 6. 7.