In [None]:
%matplotlib notebook
from brian2 import *
from matplotlib.pyplot import *
import numpy as np
import random as rp
import time
import copy
from data_generator import *

In [None]:
start_scope()

# Leaky integrator shared by every layer: v relaxes towards the input I with time constant tau.
# NOTE(review): `th` is declared but no threshold/reset in this notebook reads it.
eqs = '''
dv/dt = (I-v)/tau : 1
I : 1
tau : second
th : 1
'''

# STDP parameters for the E -> I synapses (weights are clipped to [wmin, wmax])
taupre = taupost = 100*ms
wmax = 2
wmin = 0.001

Apre = 0.005
Apost = -Apre*taupre/taupost*1.05  # depression 5% stronger than potentiation (taupre == taupost)

# Seed the random number generator used for the random initial weights below, so that
# re-running the notebook gives the same network. Change SEED to explore other initialisations.
SEED = 42
seed(SEED)

# Layers: S = input layer (fires at v > 0.5), E and I = 64-neuron layers (fire at v > 1);
# presumably excitatory / inhibitory, judging by the synapses below.
S = NeuronGroup(256, eqs, threshold='v>0.5', reset='v = 0', method='exact')
E = NeuronGroup(64, eqs, threshold='v>1', reset='v = 0', method='exact')
# NOTE(review): this group is named `I`, which is also a variable inside `eqs`; the
# equations should resolve `I` to the group's own variable, but Brian2 may warn about it.
I = NeuronGroup(64, eqs, threshold='v>1', reset='v = 0', method='exact')

# S -> E: all-to-all, static weights
StE = Synapses(S, E, 'w : 1', on_pre='v_post += w')
StE.connect(p=1)

# E -> I: one-to-one (i == j), plastic via pair-based STDP with event-driven traces
EtI = Synapses(E, I,'''
             w : 1
             dapre/dt = -apre/taupre : 1 (event-driven)
             dapost/dt = -apost/taupost : 1 (event-driven)
             ''',
             on_pre='''
             v_post += w
             apre += Apre
             w = clip(w+apost, wmin, wmax)
             ''',
             on_post='''
             apost += Apost
             w = clip(w+apre, wmin, wmax)
             ''')
EtI.connect(condition='i==j')

# I -> E: each I neuron resets every E neuron except its own partner (i != j) to 0.
# NOTE: on_pre never reads w, so ItE.w has no effect on the dynamics.
ItE = Synapses(I, E, 'w : 1', on_pre='v_post = 0')
ItE.connect(condition='i!=j')

S.tau = 50*ms
E.tau = 100*ms
I.tau = 100*ms

# E and I receive no external input; they are driven only through synapses
E.I = 0
I.I = 0

# Random initial weights (reproducible thanks to seed(SEED) above)
StE.w = 0.01*rand(len(StE.w))
EtI.w = 1.5*rand(len(EtI.w))
ItE.w = 0.20*rand(len(ItE.w))

# NOTE: M is not passed to `net` below, so it records nothing when the notebook runs `net`.
# Adding it would store v for all 256 inputs at every time step (large memory cost).
M = StateMonitor(S, 'v', record=True)
spikemonS = SpikeMonitor(S)
spikemonE = SpikeMonitor(E)
spikemonI = SpikeMonitor(I)

net = Network(S, E, I, StE, EtI, ItE, spikemonE, spikemonI, spikemonS)
print("done")

In [None]:
def run_data(data, duration=33*ms, gain=3):
    """Present a stimulus sequence to the input layer S, one frame at a time.

    For every frame, the input ``I`` of each S neuron is set to ``gain`` times
    the first component (``[0]``) of the corresponding frame entry. The whole
    network ``net`` is then simulated for ``duration``. ``S.I`` is left at the
    last frame's values when the function returns.

    Parameters
    ----------
    data : sequence of frames
        Presumably the output of ``data_generator``. Each frame needs one entry
        per S neuron, and each entry must support ``[0]``.
        NOTE(review): confirm with data_generator what component 0 represents.
    duration : Quantity, optional
        Simulated time per frame (default 33 ms, the previous hard-coded value).
    gain : float, optional
        Scale applied to each input value (default 3, the previous hard-coded value).
    """
    for frame in data:
        # (was `f[0]*3for f ...`: a number glued to a keyword is deprecated syntax)
        S.I = [entry[0] * gain for entry in frame]
        net.run(duration)

def train(epochs, samples_per_epoch):
    """Train the network on stimuli whose direction alternates every epoch.

    Each epoch does four things:

    1. Generate ``samples_per_epoch`` frames with ``data_generator``. The flags
       start at left=1, right=0 and both are negated after every epoch.
    2. Present the frames through ``run_data``.
    3. Silence the input layer.
    4. Run the network for another 150 ms.

    The wall-clock duration of the whole loop is printed at the end.

    Parameters
    ----------
    epochs : int
        Number of stimulus + rest cycles.
    samples_per_epoch : int
        Number of frames generated per epoch.
    """
    left, right = 1, 0
    t_begin = time.time()

    for _epoch in range(epochs):
        frames, _labels = data_generator(samples_per_epoch, pixels=256, objects=1,
                                         label_delay=0, noise=0.00,
                                         left=left, right=right, loop_around=True)
        run_data(frames)

        # Rest period: no input, so membrane potentials relax towards zero
        S.I = 0
        net.run(150*ms)

        # Flip the stimulus direction for the next epoch
        left, right = not left, not right

    elapsed = time.time() - t_begin
    print(elapsed, "s")

# Train the network: alternating left/right stimuli, N_EPOCHS epochs of SAMPLES_PER_EPOCH frames
N_EPOCHS = 40
SAMPLES_PER_EPOCH = 15
train(N_EPOCHS, SAMPLES_PER_EPOCH)

# Learned E -> I weights (EtI is one-to-one, so one synapse per E/I pair)
figure(2)
plot(EtI.w, range(len(EtI.w)), '.k')
xlabel('weight w')
ylabel('synapse index')
title('E -> I weights after training')

# I -> E weights. NOTE: ItE's on_pre ('v_post = 0') never reads w, so these are
# the random initial values and do not influence the dynamics.
figure(7)
plot(ItE.w, range(len(ItE.w)), '.k')
xlabel('weight w')
ylabel('synapse index')
title('I -> E weights (not used by on_pre)')

# Optional: input-layer (S) spike raster
#figure(3)
#plot(spikemonS.t/ms, spikemonS.i, '.k')

# Spike rasters recorded during training
figure(5)
plot(spikemonE.t/ms, spikemonE.i, '.k')
xlabel('time (ms)')
ylabel('E neuron index')
title('E spikes during training')

figure(6)
plot(spikemonI.t/ms, spikemonI.i, '.k')
xlabel('time (ms)')
ylabel('I neuron index')
title('I spikes during training')

In [None]:
# Turn off learning: the E -> I STDP rule only grows its traces by Apre/Apost, so zeroing
# them stops new weight changes. Traces left over from training still decay out.
Apre = 0
Apost = 0

# Record the spikes of every E neuron while a left-moving stimulus is presented
spikemon_left = SpikeMonitor(E)
net.add(spikemon_left)

# pixels=256 matches the input layer S and the data used in train()
data, answers = data_generator(100, pixels=256, objects=1, label_delay=0, noise=0.00, left=1, right=0, loop_around=True)
figure(10)
# imshow/show come from `from matplotlib.pyplot import *`; `plt` is never imported explicitly
imgplot = imshow(data)
show()

run_data(data)

# Snapshot ALL left-phase spikes (the old [0:-1] slice silently dropped the last one)
left_t = copy.deepcopy(spikemon_left.t[:])
left_i = copy.deepcopy(spikemon_left.i[:])

# Silence the input before the rest period, as train() does between epochs;
# otherwise the last left frame keeps driving the network for the whole 150 ms.
S.I = 0
net.run(150*ms)

# Record the spikes of every E neuron while a right-moving stimulus is presented
spikemon_right = SpikeMonitor(E)
net.add(spikemon_right)

data, answers = data_generator(100, pixels=256, objects=1, label_delay=0, noise=0.00, left=0, right=1, loop_around=True)
figure(11)
imgplot = imshow(data)
show()

run_data(data)

right_t = copy.deepcopy(spikemon_right.t[:])
right_i = copy.deepcopy(spikemon_right.i[:])

In [None]:
# Label neurons depending on difference in spiking rates
def label_neurons(n, left_i, right_i):
    """Partition neuron indices 0..n-1 by which stimulus made them spike more.

    Parameters
    ----------
    n : int
        Number of neurons. Every index in ``left_i``/``right_i`` must be a valid
        index into an array of length ``n``.
    left_i, right_i : iterable of int
        Neuron index of every spike recorded during the left / right stimulus.

    Returns
    -------
    tuple of three lists of int
        ``(left, right, neutral)``: ascending indices of neurons that spiked more
        often for the left stimulus, more often for the right stimulus, or
        equally often (including never).
    """
    def _tally(indices):
        # Per-neuron spike count; plain numpy indexing, one increment per spike
        counts = np.zeros(n)
        for idx in indices:
            counts[idx] += 1
        return counts

    preference = _tally(left_i) - _tally(right_i)

    favours_left = preference > 0
    favours_right = preference < 0
    undecided = ~(favours_left | favours_right)

    return (np.flatnonzero(favours_left).tolist(),
            np.flatnonzero(favours_right).tolist(),
            np.flatnonzero(undecided).tolist())

# Classify every E neuron by its preferred stimulus direction.
# len(E) keeps this in sync with the size of the E group (previously a hard-coded 64).
left_neurons, right_neurons, neutral_neurons = label_neurons(len(E), left_i, right_i)
print("Left neurons", left_neurons)
print("Right neurons", right_neurons)
print("Useless neurons", neutral_neurons)


In [None]:
# Output layer: two readout neurons, index 0 for "left" and index 1 for "right"
OUT = NeuronGroup(2, eqs, threshold='v>1', reset='v = 0', method='exact')
OUT.tau = 100*ms

EtOUT = Synapses(E, OUT, 'w : 1', on_pre='v_post += w')

# Each labelled E neuron feeds the readout matching its preferred direction;
# neutral neurons stay unconnected. All these synapses share one fixed weight.
for readout, sources in ((0, left_neurons), (1, right_neurons)):
    EtOUT.connect(i=sources, j=readout)
EtOUT.w = 0.6

net.add(OUT, EtOUT)

In [None]:
# Run test with learning switched off
Apre = 0
Apost = 0

# Let activity left over from the previous cell die out with NO input before recording
# (S.I still holds the last frame presented there).
S.I = 0
net.run(150*ms)

spikemonOUT = SpikeMonitor(OUT)
net.add(spikemonOUT)

# Phase 1: left-moving stimulus, then an input-free rest period.
# The input must be cleared BEFORE the rest run (as in train()); the previous version
# cleared it afterwards, so the last frame drove the network for the whole 150 ms.
data, answers = data_generator(100, pixels=256, objects=1, label_delay=0, noise=0.00, left=1, right=0, loop_around=True)
run_data(data)

S.I = 0
net.run(150*ms)
t_switch = net.t  # boundary between the left and right test phases

# Phase 2: right-moving stimulus
data, answers = data_generator(100, pixels=256, objects=1, label_delay=0, noise=0.00, left=0, right=1, loop_around=True)
run_data(data)

figure(20)
plot(spikemonOUT.t/ms, spikemonOUT.i, '.k')
xlabel('time (ms)')
ylabel('output neuron (0 = left, 1 = right)')
title('Output spikes during test')

# Total spikes per output. All spikes are counted; the old [0:-1] slice dropped the last one.
out_i = np.asarray(spikemonOUT.i[:])
l_spike = int(np.count_nonzero(out_i == 0))
r_spike = int(np.count_nonzero(out_i != 0))
print(l_spike, r_spike)

# Per-phase breakdown. Totals mix both phases; a working classifier fires output 0
# mostly in phase 1 and output 1 mostly in phase 2.
in_left_phase = np.asarray(spikemonOUT.t[:] < t_switch)
print("left stimulus  (out0, out1):", np.bincount(out_i[in_left_phase], minlength=2))
print("right stimulus (out0, out1):", np.bincount(out_i[~in_left_phase], minlength=2))