In [None]:
import os.path
import pickle as pickle
import time
from struct import unpack

import matplotlib.cm as cmap
import matplotlib.pyplot as plt
import numpy as np
import scipy
from brian2 import *
import brian2
from brian2tools import *

Functions

In [None]:
def get_matrix_from_file(fileName):
    """Load a sparse connection file into a dense (n_src, n_tgt) weight matrix.

    Population sizes are inferred from the four characters just before the
    4-character extension, e.g. ``.../XeAe.npy``:
      * source letter 'X'                          -> n_input rows
      * otherwise source type 'e' -> n_e rows, anything else -> n_i rows
      * target type 'e' -> n_e columns, anything else -> n_i columns

    Each row of the file is read as (source index, target index, weight); an
    empty array of shape (0,) yields an all-zero matrix.

    Args:
        fileName: path to the ``.npy`` file.

    Returns:
        numpy.ndarray of shape (n_src, n_tgt).
    """
    ext_len = 4  # characters in '.npy'
    if fileName[-4 - ext_len] == 'X':
        n_src = n_input
    elif fileName[-3 - ext_len] == 'e':
        n_src = n_e
    else:
        n_src = n_i
    n_tgt = n_e if fileName[-1 - ext_len] == 'e' else n_i

    readout = np.load(fileName)
    print(readout.shape, fileName)

    value_arr = np.zeros((n_src, n_tgt))
    if readout.shape != (0,):
        src_idx = np.int32(readout[:, 0])
        tgt_idx = np.int32(readout[:, 1])
        value_arr[src_idx, tgt_idx] = readout[:, 2]
    return value_arr

In [None]:
def update_2d_input_weights(im, fig):
    """Refresh an existing weight image with the current XeAe weights.

    Args:
        im: image artist returned by ``plot_2d_input_weights``.
        fig: figure that owns ``im``; its canvas is redrawn.

    Returns:
        The same ``im`` object, with its data replaced.
    """
    im.set_array(get_2d_input_weights())
    fig.canvas.draw()
    return im

In [None]:
def get_2d_input_weights():
    """Tile the XeAe input weights into one 2-D image for display.

    Each Ae neuron's incoming weight vector (length ``n_input``) is reshaped
    into an ``n_in_sqrt x n_in_sqrt`` patch. Patch (row_tile, col_tile) shows
    neuron ``row_tile + col_tile * n_e_sqrt``, so neurons fill the grid
    column by column.

    Returns:
        numpy.ndarray of shape (n_e_sqrt * n_in_sqrt, n_e_sqrt * n_in_sqrt).
    """
    syn = connections['XeAe']
    side_e = int(np.sqrt(n_e))
    side_in = int(np.sqrt(n_input))

    # Dense (pre, post) view of the sparse synapse weights.
    dense = np.zeros((n_input, n_e))
    dense[syn.i, syn.j] = syn.w

    tiled = np.zeros((side_e * side_in, side_e * side_in))
    for row_tile in range(side_e):
        rows = slice(row_tile * side_in, (row_tile + 1) * side_in)
        for col_tile in range(side_e):
            cols = slice(col_tile * side_in, (col_tile + 1) * side_in)
            neuron = row_tile + col_tile * side_e
            tiled[rows, cols] = dense[:, neuron].reshape((side_in, side_in))
    return tiled

In [None]:
def plot_2d_input_weights():
    """Create figure ``fig_num`` showing the tiled XeAe input weights.

    Reads the notebook-level ``fig_num`` (figure number) and ``wmax_ee``
    (upper colour limit).

    Returns:
        (image artist, figure) -- pass both to ``update_2d_input_weights``.
    """
    name = 'XeAe'
    weights = get_2d_input_weights()
    fig = figure(fig_num, figsize = (18, 18))
    # Pass the colormap by name: matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9, while a name string works in every version.
    im2 = imshow(weights, interpolation = "nearest", vmin = 0, vmax = wmax_ee, cmap = 'hot_r')
    colorbar(im2)
    title('weights of connection ' + name)
    fig.canvas.draw()
    return im2, fig

In [None]:
def plot_performance(fig_num):
    """Create an initially all-zero classification-performance line plot.

    One point is reserved per evaluation window of ``update_interval``
    examples, i.e. ``int(num_examples / update_interval)`` points.

    Args:
        fig_num: figure number to draw into.

    Returns:
        (line artist, zero-filled performance array, fig_num + 1, figure).
    """
    n_points = int(num_examples / update_interval)
    performance = np.zeros(n_points)
    fig = figure(fig_num, figsize = (5, 5))
    ax = fig.add_subplot(111)
    line, = ax.plot(range(0, n_points), performance)
    ylim(ymax = 100)
    title('Classification performance')
    fig.canvas.draw()
    return line, performance, fig_num + 1, fig

In [None]:
def get_new_assignments(result_monitor, input_numbers, num_classes=10):
    """Assign each neuron the class it responded to most strongly.

    For every class c, the mean response of each neuron over the examples
    labelled c is computed. A neuron takes the class with the highest mean;
    ties keep the smaller class. Neurons whose mean response never exceeds
    zero keep the default label 0.

    Args:
        result_monitor: array of shape (n_examples, n_neurons) with the spike
            counts recorded for each example.
        input_numbers: sequence of n_examples true labels, aligned with the
            rows of ``result_monitor``.
        num_classes: labels 0 .. num_classes-1 are considered.

    Returns:
        float array of length n_neurons with the assigned label per neuron.
    """
    responses = np.asarray(result_monitor)
    labels = np.asarray(input_numbers)  # list -> ndarray for boolean masking
    n_neurons = responses.shape[1]
    assignments = np.zeros(n_neurons)
    maximum_rate = np.zeros(n_neurons)
    for digit in range(num_classes):
        mask = labels == digit
        count = np.count_nonzero(mask)
        if count == 0:
            # No examples of this class in the window: skip it instead of
            # comparing against an undefined (or previous class's) rate vector.
            continue
        rate = responses[mask].sum(axis=0) / count
        better = rate > maximum_rate
        maximum_rate[better] = rate[better]
        assignments[better] = digit
    return assignments

In [None]:
def save_connections(ending=''):
    """Save every connection listed in ``save_conns`` as a sparse (i, j, w) array.

    Writes ``<data_path>weights/<connName><ending>.npy`` holding a float array
    of shape (n_synapses, 3). Its columns are presynaptic index, postsynaptic
    index and weight, which is the layout ``get_matrix_from_file`` reads.

    Args:
        ending: suffix appended to the file name (e.g. the example count).
    """
    print('save connections')
    for connName in save_conns:
        conn = connections[connName]
        # In Python 3 ``zip`` is a lazy iterator; np.save would pickle it as a
        # 0-d object array that np.load cannot read back by default. Build a
        # real numeric (n, 3) array instead.
        connListSparse = np.column_stack((conn.i[:], conn.j[:], conn.w[:]))
        print(connListSparse.shape)
        np.save(data_path + 'weights/' + connName + ending, connListSparse)

In [None]:
def save_theta(ending=''):
    """Write each excitatory population's adaptive threshold ``theta`` to disk.

    For every name P in ``population_names`` the value of
    ``neuron_groups[P + 'e'].theta`` is saved to
    ``<data_path>weights/theta_<P><ending>.npy``.

    Args:
        ending: suffix appended to the file name (e.g. the example count).
    """
    print('save theta')
    for pop_name in population_names:
        out_path = data_path + 'weights/theta_' + pop_name + ending
        theta = neuron_groups[pop_name + 'e'].theta
        np.save(out_path, theta)

In [None]:
def get_labeled_data(picklename, bTrain = True, data_dir=None):
    """Load MNIST images and labels, caching them in ``<picklename>.pickle``.

    If the cache file exists it is loaded directly. Otherwise the raw IDX
    files are parsed and the result is pickled to the cache path.

    Args:
        picklename: path prefix of the cache file (``.pickle`` is appended).
        bTrain: if True read the ``train-*`` IDX files, else the ``t10k-*`` ones.
        data_dir: directory prefix holding the IDX files; defaults to the
            notebook-level ``MNIST_data_path``. It is string-concatenated with
            the file name, so it should end with a path separator.

    Returns:
        dict with ``'x'`` (uint8 array, shape (N, rows, cols)), ``'y'`` (uint8
        array, shape (N, 1)), ``'rows'`` and ``'cols'``.

    Raises:
        Exception: if the image and label files disagree on the sample count.
    """
    cache_path = '%s.pickle' % picklename
    if os.path.isfile(cache_path):
        # NOTE: unpickling can execute arbitrary code; only load caches that
        # this notebook wrote. latin1 lets caches written by Python 2 load.
        with open(cache_path, 'rb') as f:
            return pickle.load(f, encoding='latin1')

    if data_dir is None:
        data_dir = MNIST_data_path
    prefix = 'train' if bTrain else 't10k'
    with open(data_dir + prefix + '-images.idx3-ubyte', 'rb') as images:
        with open(data_dir + prefix + '-labels.idx1-ubyte', 'rb') as labels:
            # IDX headers are big-endian uint32 values; skip the magic number.
            images.read(4)
            number_of_images, rows, cols = unpack('>III', images.read(12))
            labels.read(4)
            N = unpack('>I', labels.read(4))[0]
            if number_of_images != N:
                raise Exception('number of labels did not match the number of images')
            # Pixels and labels follow the headers as consecutive unsigned bytes;
            # .copy() turns the read-only buffer views into writable arrays.
            x = np.frombuffer(images.read(N * rows * cols), dtype=np.uint8).reshape((N, rows, cols)).copy()
            y = np.frombuffer(labels.read(N), dtype=np.uint8).reshape((N, 1)).copy()

    data = {'x': x, 'y': y, 'rows': rows, 'cols': cols}
    with open(cache_path, 'wb') as f:
        pickle.dump(data, f)
    return data

In [None]:
def normalize_weights():
    """Rescale excitatory->excitatory connections to a fixed total input.

    For every connection whose name has 'e' at positions 1 and 3 (e.g.
    'XeAe'), each postsynaptic neuron's incoming weights are scaled so that
    they sum to ``weight['ee_input']``. Neurons whose incoming weights sum to
    zero are left at zero rather than turned into NaN by a 0/0 division.
    """
    for connName in connections:
        if connName[1] == 'e' and connName[3] == 'e':
            syn = connections[connName]
            post_idx = np.asarray(syn.j[:])
            w = np.asarray(syn.w[:], dtype=float)
            n_post = len(syn.target)
            # Total incoming weight per postsynaptic neuron.
            col_sums = np.bincount(post_idx, weights=w, minlength=n_post)
            factors = np.zeros(n_post)
            nonzero = col_sums != 0
            factors[nonzero] = weight['ee_input'] / col_sums[nonzero]
            syn.w = w * factors[post_idx]

In [None]:
def get_recognized_number_ranking(assignments, spike_rates):
    """Rank the ten digits by the mean response of the neurons assigned to them.

    Args:
        assignments: numpy array with the digit (0-9) assigned to each neuron.
        spike_rates: numpy array of per-neuron responses for one example.

    Returns:
        numpy array of the digits 0-9, ordered from highest to lowest mean
        response; digits with no assigned neurons score 0.
    """
    mean_response = []
    for digit in range(10):
        members = assignments == digit
        n_members = len(np.where(members)[0])
        if n_members > 0:
            mean_response.append(np.sum(spike_rates[members]) / n_members)
        else:
            mean_response.append(0)
    return np.argsort(mean_response)[::-1]

Configuration and simulation

In [None]:
# Paths to the MNIST files and to where weights/activity are written.
MNIST_data_path = './MINST/'   # NOTE(review): folder is spelled 'MINST'; confirm it matches the on-disk directory
data_path = './'

# Seed numpy and Brian2's own generator. Random numbers drawn by Brian2 code
# (e.g. the 'rand()' delay initialisation and the Poisson input) may not come
# from numpy's global RNG alone, so seed both for reproducible runs.
np.random.seed(0)
brian2.seed(0)

In [None]:
# Load the MNIST training and test sets and report how long each took.
# get_labeled_data reads '<name>.pickle' if it exists, otherwise it parses the
# raw IDX files found under MNIST_data_path.
start = time.time()
training = get_labeled_data(MNIST_data_path + 'training')
end = time.time()
print('time needed to load training set:', end - start)

start = time.time()
testing = get_labeled_data(MNIST_data_path + 'testing', bTrain = False)
end = time.time()
print('time needed to load test set:', end - start)

In [None]:
# ---- Training configuration ------------------------------------------------
weight_path = './test_random/'            # directory holding the initial weight files
num_examples = 60000 * 3                  # three passes over the 60000 training images; lower for quick runs
do_plot_performance = True
record_spikes = True                      # NOTE(review): not read by any cell in this notebook
ee_STDP_on = True                         # NOTE(review): not read; XeAe is always built with STDP
# Training run: the loop checkpoints weights/theta and does not save test activity.
test_mode = False

n_input = 784                             # 28 x 28 pixels per image
n_e = 400                                 # excitatory (Ae) neurons
n_i = n_e                                 # same number of inhibitory (Ai) neurons

single_example_time = 0.35 * second       # presentation time per image
resting_time = 0.15 * second              # silent period after each image

# Membrane parameters: _e = excitatory, _i = inhibitory.
v_rest_e = -65. * mV
v_rest_i = -60. * mV
v_reset_e = -65. * mV
v_reset_i = -45. * mV
v_thresh_e = -52. * mV
v_thresh_i = -40. * mV
refrac_e = 5. * ms
refrac_i = 2. * ms

weight = {}
delay = {}

# Target summed input weight per Ae neuron (enforced by normalize_weights)
# and the (min, max) synaptic delay range used for Xe->Ae.
weight['ee_input'] = 78.
delay['ee_input'] = (0*ms, 10*ms)
delay['ei_input'] = (0*ms, 5*ms)           # NOTE(review): not used by any cell here

# Poisson rate = pixel / 8 * input_intensity Hz. The loop raises the intensity
# when too few spikes occur and resets it to start_input_intensity afterwards.
input_intensity = 2.
start_input_intensity = input_intensity

# STDP trace time constants and learning rates for XeAe.
tc_pre_ee = 20*ms
tc_post_1_ee = 20*ms
tc_post_2_ee = 40*ms
nu_ee_pre =  0.0001                        # learning rate applied on presynaptic spikes
nu_ee_post = 0.01                          # learning rate applied on postsynaptic spikes
wmax_ee = 1.0                              # weights are clipped to [0, wmax_ee]

save_conns = ['XeAe']                      # connections written by save_connections
population_names = ['A']                   # populations whose theta is saved
population_names = ['A']

In [None]:
# Bookkeeping intervals, measured in presented examples.
update_interval = 10000            # re-assign neuron labels and evaluate accuracy
save_connections_interval = 10000  # checkpoint XeAe weights and theta to disk
weight_update_interval = 100       # refresh the weight-image figure

Neuron groups

In [None]:
# Excitatory (Ae) neurons: conductance-based leaky integrate-and-fire with an
# adaptive threshold offset `theta` and a `timer` used in the threshold test.
# NOTE(review): timer grows at 0.1 s per second of simulated time, so
# `timer > refrac_e` (5 ms) only holds ~50 ms after a reset -- confirm intended.
neuron_eqs_e = '''
        dv/dt = ((v_rest_e - v) + g_e*(-v) + g_i*(-100.*mV - v) ) / (100*ms)  : volt (unless refractory)
        dg_e/dt = -g_e/(1.0*ms)                                    : 1
        dg_i/dt = -g_i/(2.0*ms)                                    : 1
        dtheta/dt = -theta/(1e7*ms)                                : volt
        dtimer/dt = 0.1                                            : second
'''

# Inhibitory (Ai) neurons: the same conductance model without adaptation.
neuron_eqs_i = '''
        dv/dt = ((v_rest_i - v) +  g_e*(-v) + g_i*(-85.*mV - v)) / (10*ms)  : volt (unless refractory)
        dg_e/dt = -g_e/(1.0*ms)                                    : 1
        dg_i/dt = -g_i/(2.0*ms)                                    : 1
'''

tc_theta = 1e7 * ms          # NOTE(review): not referenced; the equation above hard-codes 1e7*ms
theta_plus_e = 0.05 * mV     # threshold increase added to theta on every Ae spike

# Ae fires when v exceeds v_thresh_e shifted by (theta - 20 mV) and the timer
# has run past refrac_e; Ai uses a fixed threshold.
v_thresh_e_str = '(v > (theta - 20.0*mV + v_thresh_e)) and (timer > refrac_e)'
v_thresh_i_str = 'v > v_thresh_i'

# Ae reset: restore v, raise the adaptive threshold, restart the timer.
scr_e = 'v = v_reset_e; theta += theta_plus_e; timer = 0*ms'
v_reset_i_str = 'v=v_reset_i'

neuron_groups = {}

neuron_groups['Ae'] = NeuronGroup(n_e, neuron_eqs_e, threshold= v_thresh_e_str, refractory= refrac_e, reset= scr_e, method='euler')
neuron_groups['Ai'] = NeuronGroup(n_i, neuron_eqs_i, threshold= v_thresh_i_str, refractory= refrac_i, reset= v_reset_i_str, method='euler')

# Start both populations 40 mV below their resting potentials.
neuron_groups['Ae'].v = v_rest_e - 40. * mV
neuron_groups['Ai'].v = v_rest_i - 40. * mV

# Initial theta of 20 mV cancels the -20 mV offset in the Ae threshold expression.
neuron_groups['Ae'].theta = np.ones((n_e)) * 20.0*mV

Create the recurrent connections between the excitatory and inhibitory populations

In [None]:
connections = {}
rate_monitors = {}
spike_monitors = {}
spike_counters = {}
input_groups = {}
result_monitor = np.zeros((update_interval,n_e))    # per-example Ae spike counts for the current window

# Triplet STDP for the plastic input connection XeAe.
# `post2before` keeps the slow postsynaptic trace from *before* it is reset,
# so potentiation really depends on the previous postsynaptic spike.
eqs_stdp_ee = '''
post2before                            : 1
dpre/dt   =   -pre/(tc_pre_ee)         : 1 (event-driven)
dpost1/dt  = -post1/(tc_post_1_ee)     : 1 (event-driven)
dpost2/dt  = -post2/(tc_post_2_ee)     : 1 (event-driven)
'''
# Presynaptic spike: set the pre trace and *depress* by the fast post trace.
eqs_stdp_pre_ee = 'pre = 1.; w = clip(w - nu_ee_pre * post1, 0, wmax_ee)'
# Postsynaptic spike: potentiate by pre trace x old slow post trace, then reset the post traces.
eqs_stdp_post_ee = 'post2before = post2; w = clip(w + nu_ee_post * pre * post2before, 0, wmax_ee); post1 = 1.; post2 = 1.'

# Static recurrent connections Ae->Ai (excitatory) and Ai->Ae (inhibitory).
# A Synapses object only affects its target through on_pre, so each one adds
# its weight to the matching postsynaptic conductance.
recurrent_on_pre = {'AeAi': 'g_e_post += w', 'AiAe': 'g_i_post += w'}
for conn_name, on_pre in recurrent_on_pre.items():
    weightMatrix = get_matrix_from_file(weight_path + conn_name + '.npy')
    syn = Synapses(neuron_groups[conn_name[:2]], neuron_groups[conn_name[2:]],
                   model='w : 1', on_pre=on_pre, on_post='')
    syn.connect(True)   # all-to-all; pairs absent from the file get weight 0
    syn.w = weightMatrix[syn.i, syn.j]
    connections[conn_name] = syn

In [None]:
print('create monitors for A')
# NOTE(review): population rate monitors store a sample for every simulation
# step; over num_examples presentations these arrays become very large.
rate_monitors['Ae'] = PopulationRateMonitor(neuron_groups['Ae'])
rate_monitors['Ai'] = PopulationRateMonitor(neuron_groups['Ai'])

# The training loop only reads cumulative per-neuron counts from this monitor,
# so it does not store individual spike times (record=False); `count` is still kept.
spike_counters['Ae'] = SpikeMonitor(neuron_groups['Ae'], record=False)

# Full spike rasters, used by the final plots.
spike_monitors['Ae'] = SpikeMonitor(neuron_groups['Ae'])
spike_monitors['Ai'] = SpikeMonitor(neuron_groups['Ai'])

Create the input population and its connections to the network

In [None]:
# Poisson input layer: one unit per pixel, silent until the loop sets its rates.
input_groups['Xe'] = PoissonGroup(n_input, rates=0*Hz)
rate_monitors['Xe'] = PopulationRateMonitor(input_groups['Xe'])

In [None]:
print('create connections between X and A')

# Plastic, all-to-all input connection from the Poisson layer to Ae.
weightMatrix = get_matrix_from_file(weight_path + 'XeAe.npy')
model = 'w : 1' + eqs_stdp_ee
# Each presynaptic spike must excite its target (g_e_post += w) as well as
# apply the STDP presynaptic update; without it Ae receives no input drive.
pre = 'g_e_post += w; ' + eqs_stdp_pre_ee
post = eqs_stdp_post_ee

connections['XeAe'] = Synapses(input_groups['Xe'], neuron_groups['Ae'], model=model, on_pre=pre, on_post=post)
minDelay = delay['ee_input'][0]
maxDelay = delay['ee_input'][1]
deltaDelay = maxDelay - minDelay

connections['XeAe'].connect(True)  # all-to-all connection
# Uniform random delay in [minDelay, maxDelay). The expression is evaluated by
# Brian2, presumably resolving minDelay/deltaDelay from this notebook namespace.
connections['XeAe'].delay = 'minDelay + rand() * deltaDelay'
connections['XeAe'].w = weightMatrix[connections['XeAe'].i, connections['XeAe'].j]

In [None]:
# Collect every neuron group, synapse set and monitor into one Network.
net = Network()
for obj_dict in (neuron_groups, input_groups, connections, rate_monitors, spike_monitors, spike_counters):
    net.add(*obj_dict.values())

In [None]:
# Per-example bookkeeping for the training loop.
previous_spike_count = np.zeros(n_e)                # cumulative Ae spike counts after the previous example
assignments = np.zeros(n_e)                         # digit currently assigned to each Ae neuron
input_numbers = [0 for _ in range(num_examples)]    # true label of every presented example
outputNumbers = np.zeros((num_examples, 10))        # predicted digit ranking per example

In [None]:
# Create the live figures: the tiled XeAe weights (figure `fig_num`) and the
# classification-performance plot (the next figure number). Afterwards
# `fig_num` holds the next unused figure number.
fig_num = 1

input_weight_monitor, fig_weights = plot_2d_input_weights()
fig_num += 1
    
performance_monitor, performance, fig_num, fig_performance = plot_performance(fig_num)

# Silence the input layer and run for zero time -- presumably so Brian2
# builds/initialises the network before the training loop starts.
input_groups['Xe'].rates = 0 * Hz  

net.run(0*second)

In [None]:
# Training loop. Each image is shown for `single_example_time`, followed by
# `resting_time` of silence. If the Ae layer fires fewer than 5 spikes, the
# same image is shown again with `input_intensity` raised by 1.

# `test_mode` may not have been defined by an earlier cell; default to
# training behaviour (periodic weight/theta checkpoints) if it is missing.
test_mode = globals().get('test_mode', False)
n_train = training['x'].shape[0]     # examples wrap around for multiple passes

j = 0
while j < int(num_examples):
    normalize_weights()
    # (28, 28) image -> (784,) Poisson rates: pixel / 8 * input_intensity Hz,
    # i.e. a pixel of 255 maps to 63.75 Hz at the default intensity of 2.
    spike_rates = training['x'][j % n_train, :, :].reshape((n_input)) / 8. * input_intensity
    input_groups['Xe'].rates = spike_rates * Hz
    print('run example number:', j+1, 'of', int(num_examples))

    net.run(single_example_time, report='text')

    if j % update_interval == 0 and j > 0:
        # Re-label every Ae neuron by the digit it responded to most over the
        # previous `update_interval` examples.
        assignments = get_new_assignments(result_monitor[:], input_numbers[j-update_interval : j])

    if j % weight_update_interval == 0:
        update_2d_input_weights(input_weight_monitor, fig_weights)

    if j % save_connections_interval == 0 and j > 0 and not test_mode:
        save_connections(str(j))
        save_theta(str(j))

    # The counter is cumulative; the difference is this example's spike count.
    current_spike_count = np.asarray(spike_counters['Ae'].count[:]) - previous_spike_count
    previous_spike_count = np.copy(spike_counters['Ae'].count[:])

    if np.sum(current_spike_count) < 5:
        # Too little activity: rest, then retry this example (j unchanged) with
        # a higher intensity (+1 adds up to ~32 Hz at pixel value 255).
        input_intensity += 1
        input_groups['Xe'].rates = 0 * Hz
        net.run(resting_time)
    else:
        result_monitor[j % update_interval, :] = current_spike_count
        input_numbers[j] = training['y'][j % n_train][0]
        outputNumbers[j, :] = get_recognized_number_ranking(assignments, result_monitor[j % update_interval, :])

        if j % 100 == 0 and j > 0:
            print('runs done:', j, 'of', int(num_examples))
        if j % update_interval == 0 and j > 0 and do_plot_performance:
            # Accuracy (%) over the previous window: top-ranked digit vs. label.
            evaluation = j // update_interval
            window = slice(j - update_interval, j)
            correct = np.count_nonzero(outputNumbers[window, 0] == np.asarray(input_numbers[window]))
            if evaluation < len(performance):
                performance[evaluation] = correct / float(update_interval) * 100
                performance_monitor.set_ydata(performance)
                fig_performance.canvas.draw()
            print('Classification performance', performance[:evaluation + 1])
        # Silence the input and let the network relax before the next image.
        input_groups['Xe'].rates = 0 * Hz
        net.run(resting_time)
        input_intensity = start_input_intensity
        j += 1

In [None]:
# Persist the final state: trained theta and XeAe weights in training mode,
# recorded activity and labels in test mode.
# `test_mode` may not have been defined by an earlier cell; default to training.
test_mode = globals().get('test_mode', False)
print('save results')
if test_mode:
    np.save(data_path + 'activity/resultPopVecs' + str(num_examples), result_monitor)
    np.save(data_path + 'activity/inputNumbers' + str(num_examples), input_numbers)
else:
    save_theta()
    save_connections()

In [None]:
# Diagnostic figures: population rates, spike rasters, per-neuron spike counts,
# the learned XeAe weights, and the weights/delays of every connection.
# Every figure number comes from the running `fig_num` counter, so no figure
# created earlier in the notebook is reused and drawn over.
if rate_monitors:
    plt.figure(fig_num)
    fig_num += 1
    for i, name in enumerate(rate_monitors):
        plt.subplot(len(rate_monitors), 1, 1 + i)
        plt.plot(rate_monitors[name].t / second, rate_monitors[name].rate / Hz, '.')
        plt.title('Rates of population ' + name)

if spike_monitors:
    plt.figure(fig_num)
    fig_num += 1
    for i, name in enumerate(spike_monitors):
        plt.subplot(len(spike_monitors), 1, 1 + i)
        plt.plot(spike_monitors[name].t / ms, spike_monitors[name].i, '.')
        plt.title('Spikes of population ' + name)

if spike_counters:
    plt.figure(fig_num)
    fig_num += 1
    plt.plot(spike_counters['Ae'].count[:])
    plt.title('Spike count of population Ae')

plot_2d_input_weights()   # draws into figure `fig_num`
fig_num += 1

# Synaptic weights, then synaptic delays: one subplot per connection.
for attr in ('w', 'delay'):
    plt.figure(fig_num)
    fig_num += 1
    for i, conn_name in enumerate(('XeAe', 'AeAi', 'AiAe')):
        plt.subplot(3, 1, i + 1)
        brian_plot(getattr(connections[conn_name], attr))

plt.ioff()
plt.show()


In [None]:
# Spike counts recorded in row 0 of the current window.
spike_rates = result_monitor[0 % update_interval][:]

In [None]:
np.shape(spike_rates)

In [None]:
# Number of Ae neurons currently assigned to each digit 0-9.
num_assignments = [int(np.count_nonzero(assignments == digit)) for digit in range(10)]

In [None]:
# Mean response of the neurons assigned to digit 0; 0 when none are assigned
# (as in get_recognized_number_ranking) instead of dividing by zero.
n_digit0 = np.count_nonzero(assignments == 0)
np.sum(spike_rates[assignments == 0]) / n_digit0 if n_digit0 > 0 else 0

In [None]:
np.shape(spike_rates[assignments == 0])