In [2]:
from brian2 import *
from time import clock
from get_MNIST import *
import matplotlib.pyplot as plt

#MNIST_data_path = '../data/'
#training = get_labeled_data(MNIST_data_path + 'training.pickle', bTrain=True)
#X = [training['x'][i, :, :].reshape(784) for i in np.arange(500)]
#y = [training['y'][i] for i in np.arange(500)]

% matplotlib notebook
import matplotlib.patches as mpatches
from sklearn.datasets import load_digits

classes = 3
mnist = load_digits(classes)

# set all samples' intensity to [0 ... 1]
X,y = mnist.data, mnist.target
X /= np.max(X)


# normalizing all input samples: sum(Intensity) per sample = 20
X_norm = np.zeros((len(y), 64))
for i in np.arange(len(y)):
    X_norm[i] = X[i] / np.sum(X[i]) * 20
    
X = X_norm / np.max(X_norm)


start_scope()
#BrianLogger.suppress_hierarchy('brian2.codegen')
prefs.codegen.target = "numpy"

# function to visualise the connectivity of Synapses S
def visualise_connectivity(S):
    Ns = len(S.source)
    Nt = len(S.target)
    figure(figsize=(10, 4))
    subplot(121)
    plot(zeros(Ns), arange(Ns), 'ok', ms=10)
    plot(ones(Nt), arange(Nt), 'ok', ms=10)
    for i, j in zip(S.i, S.j):
        plot([0, 1], [i, j], '-k')
    xticks([0, 1], ['Source', 'Target'])
    ylabel('Neuron index')
    xlim(-0.1, 1.1)
    ylim(-1, max(Ns, Nt))
    subplot(122)
    plot(S.i, S.j, 'ok')
    xlim(-1, Ns)
    ylim(-1, Nt)
    xlabel('Source neuron index')
    ylabel('Target neuron index')

In [5]:
time_step = 1*ms #timestep for simulation

n_input_width = 8 #
n_input_height = 8 #
n_hidden_width = 8 #
n_hidden_height = 8 #
lr=0.1 # learning rate for forward connections
lr_neg=1.0 # learning rate for negative intrinsic connections
runtime = 100*ms # simulation time per one image sample
tau = 10*ms # tau for neuron's voltage 'v'
tau_I = 15*ms # tau for neuron's current 'I'
threshold = 3.5 # theshold for neuron's voltage to generate a spike
c = 1. # coefficient to calculate difference between rate and average rate: difference = rate - c * average_rate
coef = 1. #actually don't remember about this parameter, but if it is equal to one, it doesn't change anything
alpha = 30*ms #tau for computing exponential mean of spike-array to calculate rate 'a'
beta= 500*ms #tau for computing exponential mean of spike-array to calculate average_rate 'theta'

In [6]:
# equations for input poisson neurons (added rates variable for each neuron, no voltage and current)
eqs_input_neuron = '''
da/dt = -a/alpha : 1
dtheta/dt = -theta/beta : 1
diff = a - c*theta : 1
rates : Hz
train : 1
'''
# equations for neurons of hidden layer
eqs_neuron = '''
dv/dt = (V0-v+I)/tau : 1 (unless refractory)
dI/dt = -I/(tau_I) : 1
da/dt = -a/alpha : 1
dtheta/dt = -theta/beta : 1
V0 :1
diff = a - c*theta : 1
train : 1
'''

# each forward connection synapse has it's' own weight 'w'
eqs_syn = '''
w : 1
'''

# each negative connection synapse has it's own weight 'w'
eqs_syn_neg = '''
w_neg: 1
'''

# equations that describe changes if presynaptic spike of the forward synapse occures
eqs_pre = '''
I_post += w
a_pre += 1*ms/alpha
theta_pre += 1*ms/beta
'''
# equations that describe changes if postsynaptic spike of the forward synapse occures
eqs_post = '''
w = clip(w + train_pre * (-0.001 + lr*diff_pre), 0, wmax)
'''

# equations that describe changes if postsynaptic spike of the forward synapse occures
eqs_last_post = '''
w = clip(w + train_pre * (-0.001 + lr*diff_pre), 0, wmax)
a_post += 1*ms/alpha
theta_post += 1*ms/beta
'''

#w = clip(w +(-lr*log(1+exp(-3*(w+0.1)*diff_pre))+0.05)*(wmax-w), 0, wmax)

# equations that describe changes if presunaptic spike of the negative synapse occures
eqs_neg_pre = '''
I_post -= w_neg
w_neg = clip(w_neg + train_pre * (lr_n*diff_post), 0, wmax)
'''

In [26]:
class SpikingNN():
    def __init__(self, n_input_height = 8, n_input_width = 8, n_hidden_width = n_input_width, n_hidden_height = n_input_height, 
                 lr=0.1, lr_neg=0.1, runtime = 100*ms,
                 tau = 10*ms, tau_I = 15*ms, threshold = 3.5, c = 1., coef = 1., alpha = 60*ms, 
                 beta= 500*ms, eqs_input_neuron = eqs_input_neuron, eqs_neuron = eqs_neuron, 
                 eqs_syn = eqs_syn, eqs_syn_neg = eqs_syn_neg, eqs_pre = eqs_pre, eqs_post = eqs_post, eqs_neg_pre = eqs_neg_pre):
        
        # set all parameters to the object variables
        self.n_input_width = n_input_width
        self.n_input_height = n_input_height
        self.n_hidden_width = n_hidden_width
        self.n_hidden_height = n_hidden_height
        self.lr = lr
        self.lr_neg = lr_neg
        self.threshold = threshold
        self.c = c
        
        self.eqs_input_neuron = eqs_input_neuron
        self.eqs_neuron = eqs_neuron
        self.eqs_syn = eqs_syn
        self.eqs_syn_neg = eqs_syn_neg
        self.eqs_pre = eqs_pre
        self.eqs_post = eqs_post
        self.eqs_neg_pre = eqs_neg_pre
        
        self.wmax = 1.
        self.coef = coef
        self.runtime = runtime*self.coef
        self.tau = tau*self.coef
        self.tau_I = tau_I*self.coef
        self.alpha = alpha*self.coef
        self.beta = beta*self.coef
        
        # array of predictions
        self.pred = []
        
        
    def architecture(self):
        
        # Create Neurons and Connections between them.
        # P -> InputConnection -> G -> OutputConnection -> H (N = num classes) -> PruneConnection -> R  
        #                         |
        #                 NegativeConnection
        #                         |
        #                         G
        # R is needed for H's a and theta calculations
        
        P = NeuronGroup(self.n_input_width * self.n_input_height, self.eqs_input_neuron, method=linear, 
                        threshold='rand()<rates*dt', dt = time_step)
        P.a = np.random.random(self.n_input**2)
        P.theta = np.random.random(self.n_input**2)
        
        # train is parameter that allows synaptic weights to change (if train==1) or disables weight change (if train==0)
        P.train = np.ones_like(P.train)
        
        G = NeuronGroup(self.n_hidden**2, self.eqs_neuron, method=linear, threshold='v > hold', reset='v=0', refractory=2*ms, 
                        dt=time_step)
        G.train = np.ones_like(G.train)
        InputConnection = Synapses(P, G, self.eqs_syn, on_pre=self.eqs_pre, on_post=self.eqs_post, dt=time_step)
        
        InputConnection.connect()
        InputConnection.w = np.random.random(self.n_input*self.n_input*self.n_hidden*self.n_hidden)
        
        NegativeConnection = Synapses(G, G, self.eqs_syn_neg, on_pre=self.eqs_neg_pre, dt=time_step)
        NegativeConnection.connect(condition='i != j')
        NegativeConnection.w_neg = np.ones((self.n_hidden*self.n_hidden-1)*self.n_hidden*self.n_hidden) * 0.5
        
        H = NeuronGroup(classes, self.eqs_neuron, method=linear, threshold='v > 3.5', reset='v=0', refractory=2*ms, dt=time_step)
        H.train = np.ones_like(H.train)
        
        OutputConnection = Synapses(G, H, self.eqs_syn, on_pre=self.eqs_pre, on_post=self.eqs_post, dt=time_step)
        OutputConnection.connect()
        
        R = NeuronGroup(1, self.eqs_neuron, method=linear, threshold='v > 3.5', reset='v=0', refractory=2*ms, dt=time_step)
        R.train = np.ones_like(R.train)
        
        PruneConnection = Synapses(H, R, self.eqs_syn, on_pre=self.eqs_pre, on_post=self.eqs_post, dt=time_step)
        PruneConnection.connect()
        
        return P, G, InputConnection, NegativeConnection, H, OutputConnection, R, PruneConnection
    
    def pp(self, s=0, d=0, w = 1, h = 1):
        return [mpatches.Rectangle((i+d*(self.n_input+1), j+s*(self.n_input+1)), 1, 1) 
                for j in np.arange(h-1, -1, -1) for i in np.arange(w)]

    def create_canvas(self):
        back_color = "yellow"
        width, height = self.n_input*self.n_hidden + self.n_hidden-1, self.n_input*self.n_hidden + self.n_hidden-1

        self.fig, (self.ax, self.ax2, self.ax3) = plt.subplots(1, 3, figsize=(15,5))
        self.ax.set(xlim=[0, width], ylim=[0, height]) # Or use "ax.axis([x0,x1,y0,y1])"
        self.ax2.set(xlim=[0, self.n_hidden], ylim=[0, self.n_hidden]) # Or use "ax.axis([x0,x1,y0,y1])"
        self.ax3.set(xlim=[0, self.n_input], ylim=[0, self.n_input]) # Or use "ax.axis([x0,x1,y0,y1])"
        self.fig.canvas.draw()

        self.patches = []
        self.patches2 = []
        self.patches3 = []
        self.patches = [self.pp((self.n_hidden-1-s),d, self.n_input, self.n_input) for s in np.arange(self.n_hidden) 
                        for d in np.arange(self.n_hidden)]
        self.patches2 = self.pp(0,0, self.n_hidden, self.n_hidden)
        self.patches3 = self.pp(0,0, self.n_input, self.n_input)
        
        for s in np.arange(len(self.patches)):
            for j in np.arange(len(self.patches[0])):
                self.ax.add_artist(self.patches[s][j])
                
        for j in np.arange(len(self.patches2)):
            self.ax2.add_artist(self.patches2[j])
            
        for j in np.arange(len(self.patches3)):
            self.ax3.add_artist(self.patches3[j])
        
    def update(self, P, G, InputConnection, NegativeConnection,  H, OutputConnection, R, PruneConnection, k, rate):
        
        lr = self.lr
        lr_n = self.lr_neg
        hold = self.threshold
        c = self.c
        tau = self.tau
        tau_I = self.tau_I
        alpha = self.alpha 
        beta = self.beta
        wmax = self.wmax
        
        if k % 1 == 0:
            print k

        P.rates = 250 * self.coef * rate * Hz
        run(self.runtime)
        rate_z = np.zeros(n_input*n_input)
        P.rates = rate_z*Hz
        run(self.runtime*0.3)

        colors = np.zeros((self.n_hidden*self.n_hidden, self.n_input*self.n_input))
        for s in np.arange(len(self.patches)):
            for j in np.arange(len(self.patches[0])):
                colors[s][j]=InputConnection.w[s+j*self.n_hidden**2]
                self.patches[s][j].set_facecolor(np.array([colors[s][j], colors[s][j], colors[s][j]]))
        
        colors_neg = np.zeros((self.n_hidden*self.n_hidden))
        
        for j in np.arange(self.n_hidden**2):        
            colors_neg[j]=NegativeConnection.w_neg[j*(self.n_hidden**2-1)]
            self.patches2[j].set_facecolor(np.array([colors_neg[j], colors_neg[j], colors_neg[j]]))
            
        for j in np.arange(len(self.patches3)):
            self.patches3[j].set_facecolor(np.array([rate[j], rate[j], rate[j]]))

        self.fig.canvas.blit(self.ax.bbox)
        self.fig.canvas.blit(self.ax2.bbox)
        self.fig.canvas.blit(self.ax3.bbox)
        
        self.fig.canvas.draw()
    
    def train(self, X, y, epochs_train, epochs_test, show = True, classifier = False, train = True, 
              pretrain = False, cheat = False):
        
        if cheat:
            P, G, InputConnection, NegativeConnection, H, OutputConnection, R, PruneConnection = self.architecture()
            
            dig_y = y[:self.n_hidden**2]
            dig_X = X[:self.n_hidden**2]
            for s in np.arange(self.n_hidden**2):
                for j in np.arange(self.n_input**2):
                    InputConnection.w[s+j*self.n_hidden**2] = dig_X[s][j]
            
            NegativeConnection.w_neg = np.ones_like(NegativeConnection.w_neg)
            if show:
                plt.show()
            
            self.P = P
            self.G = G
            self.InputConnection = InputConnection
            self.NegativeConnection = NegativeConnection
            self.OutputConnection = OutputConnection
            self.H = H
            self.R = R
            self.PruneConnection = PruneConnection
            
            colors = np.zeros((self.n_hidden*self.n_hidden, self.n_input*self.n_input))
            for s in np.arange(len(self.patches)):
                for j in np.arange(len(self.patches[0])):
                    colors[s][j]=InputConnection.w[s+j*self.n_hidden**2]
                    self.patches[s][j].set_facecolor(np.array([colors[s][j], colors[s][j], colors[s][j]]))

            self.fig.canvas.blit(self.ax.bbox)
        
            self.fig.canvas.draw()
            
        if train:
            if pretrain:
                P = self.P
                G = self.G
                InputConnection = self.InputConnection
                NegativeConnection = self.NegativeConnection
                OutputConnection = self.OutputConnection
                H = self.H
                R = self.R
                PruneConnection = self.PruneConnection
                
                P.train = np.ones_like(P.train)
                G.train = np.ones_like(G.train)
                H.train = np.ones_like(H.train)
                R.train = np.ones_like(R.train)
            
            else:
                P, G, InputConnection, NegativeConnection, H, OutputConnection, R, PruneConnection = self.architecture()

            k = 0

            if show:
                plt.show()

            while k < epochs_train:
                rate =  X[k%(X.shape[0])].reshape(self.n_input**2)
                self.update(P, G, InputConnection, NegativeConnection,  H, OutputConnection, R, PruneConnection, k, rate)
                k += 1

            self.P = P
            self.G = G
            self.InputConnection = InputConnection
            self.NegativeConnection = NegativeConnection
            self.OutputConnection = OutputConnection
            self.H = H
            self.R = R
            self.PruneConnection = PruneConnection
        
        if classifier:
            P, G, InputConnection, NegativeConnection, H, OutputConnection, R, PruneConnection = self.create_classifier(
                P, G, InputConnection, NegativeConnection, H, OutputConnection, R, PruneConnection, X, y)

            SpikeM = SpikeMonitor(H)
            self.SpikeMonitor = SpikeM
            StateM = StateMonitor(H, ['v', 'I', 'a'], record=True)
            self.StateMonitor = StateM
            StateMG = StateMonitor(G, ['v', 'I', 'a'], record=True)
            self.StateMonitorG = StateMG

            lr = self.lr
            lr_n = self.lr_neg
            hold = self.threshold
            c = self.c
            tau = self.tau
            tau_I = self.tau_I
            alpha = self.alpha 
            beta = self.beta
            wmax = self.wmax

            k=0
            print 'testing'
            while k < epochs_test:

                rate =  X[(k)%(X.shape[0])].reshape(self.n_input**2)
                P.rates = 250 * self.coef * rate * Hz
                start = clock()
                run(self.runtime)
                end = clock()
                
                self.pred.append(np.argmax(self.StateMonitor.a[:,-1]))
                accuracy = self.accuracy(y[:len(self.pred)], self.pred)
                
                if k % 1 == 0:
                    print k, accuracy, '%'
                    
                print end - start, 'runtime for 100*ms classification test'
                
                P.rates = np.zeros_like(P.rates)
                
                G.v = np.zeros_like(G.v)
                G.I = np.zeros_like(G.I)
                
                H.a = np.zeros_like(H.a)
                H.theta = np.zeros_like(H.theta)
                #H.diff = np.zeros_like(H.diff)
                H.v = np.zeros_like(H.v)
                H.I = np.zeros_like(H.I)
                # run(self.runtime)
                k += 1
                
        self.P = P
        self.G = G
        self.InputConnection = InputConnection
        self.NegativeConnection = NegativeConnection
        self.OutputConnection = OutputConnection
        self.H = H
        self.R = R
        self.PruneConnection = PruneConnection
            
    def create_classifier(self, P, G, InputConnection, NegativeConnection, H,
                         OutputConnection, R, PruneConnection, X, y):
        
        xx = []

        for j in np.arange(np.max(y)+1):
            xs = []
            for i in np.arange(len(y)):
                if y[i] == j:
                    xs.append(X[i])
            xx.append(xs)
            
        
        x_mean = []
        for i in np.arange(len(xx)):
            x_mean.append(np.mean(np.array(xx[i]), axis=0))

        colors = np.zeros((self.n_hidden*self.n_hidden, self.n_input*self.n_input))
        for s in np.arange(len(self.patches)):
            for j in np.arange(len(self.patches[0])):
                colors[s][j]=net.InputConnection.w[s+j*self.n_hidden**2]
                
        scals = []

        for k in np.arange(np.max(y)+1):
            scal = []
            for i in np.arange(self.n_hidden*self.n_hidden):
                svertka = []
                for j in np.arange(self.n_input*self.n_input):
                    svertka.append(colors[i][j]*x_mean[k][j])
                svertka = np.array(svertka)
                scal.append(np.sum(svertka))
            scals.append(scal)
            

        scals = np.array(scals)
        
        targets = np.argmax(scals, axis = 0)
        self.targets = targets
        
        P.train = np.zeros_like(P.train)
        G.train = np.zeros_like(G.train)
        H.train = np.zeros_like(H.train)
        R.train = np.zeros_like(R.train)
        OutputConnection.w = np.zeros_like(OutputConnection.w)
        
        #CHEEEEAAAAT
        #targets = y
        #self.targets = targets
        for s in np.arange(self.n_hidden**2):
            target = targets[s]
            OutputConnection.w[int(target + s * (np.max(y)+1))] = 1.0
            
        return P, G, InputConnection, NegativeConnection, H, OutputConnection, R, PruneConnection
        
    def accuracy(self, y, pred):
        accuracy = 0

        for a, b in zip(y, pred):
            if a == b:
                accuracy += 1

        accuracy /= float(len(pred))
        return round(accuracy*100, 2)
        

    
        

In [30]:
lr_neg = 0.01
lr = 0.01
c = 1.
holds = np.arange(1, 25, 5)
for h in holds:
    net = SpikingNN(n_hidden=3, lr = lr, lr_neg=lr_neg, c=c, alpha=30*ms, threshold=h, runtime=100*ms)
    print 'hold', h
            
    net.create_canvas()

    net.train(X, y, epochs_train=50, epochs_test=15, train=True, classifier=True, pretrain=False, cheat=False)

hold 1


<IPython.core.display.Javascript object>



0
1
2
3
4
5
6
7
8
9
10
11
12
13


KeyboardInterrupt: 

In [28]:
figure(999)
subplot(211)
#for i in np.arange(len(net.StateMonitorG.I)):
for i in np.arange(5):
    plt.plot(net.StateMonitorG.t/ms, net.StateMonitorG.I[i], label = str(i))
#plt.plot(net.StateMonitor.t/ms, net.StateMonitor.a[4] - net.StateMonitor.a[1], label = str(i))

subplot(212)
for i in np.arange(5):
    plt.plot(net.StateMonitorG.t/ms, net.StateMonitorG.a[i], label = str(i))

for k in np.arange(0, 100*15, 100):
    axvline(k, ls='--', c='r', lw=3)
plt.xlabel('Time (ms)')
plt.ylabel('Neuron index')
plt.legend(loc='best')
plt.show()

<IPython.core.display.Javascript object>

AttributeError: SpikingNN instance has no attribute 'StateMonitorG'

In [10]:
from sklearn.metrics import confusion_matrix
import itertools

acc = net.accuracy(y[:len(net.pred)], net.pred)
print acc

cm = confusion_matrix(y[:len(net.pred)],net.pred)
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
figure(9999)
imshow(cm, cmap=plt.cm.Blues, interpolation='nearest')

thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    plt.text(j, i, round(cm[i, j], 2), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")
        
plt.ylabel('True label')
plt.xlabel('Predicted label')

colorbar()
show()

20.0


<IPython.core.display.Javascript object>

In [11]:
print len(net.patches)
print len(net.patches[0])

9
64


In [12]:
it = net.SpikeMonitor.it
window = 60
rates_list = []

for i,times in zip(it[0],it[1]):
    
    rates = np.zeros(3)
    for j in np.arange(len(rates)):
        if i == j:
            for tt in it[1]:
                d = times/ms - tt/ms
                if d >= 0:
                    if d < window:
                        rates[j] += 1
                else:
                    break
    rates_list.append(rates)
             
 

In [13]:
print y[:10]
figure(1000)
plt.plot(net.SpikeMonitor.t/ms, rl[0])
for k in np.arange(0, 2000, 200):
    axvline(k, ls='--', c='r', lw=3)
plt.xlabel('Time (ms)')
plt.ylabel('Neuron index')
plt.show()

[0 1 2 3 4 0 1 2 3 4]


<IPython.core.display.Javascript object>

NameError: name 'rl' is not defined

In [16]:
xx = []

for j in np.arange(np.max(y)+1):
    xs = []
    for i in np.arange(len(y)):
        if y[i] == j:
            xs.append(X[i])
    xx.append(xs)

In [17]:
x_mean = []
for i in np.arange(len(xx)):
    x_mean.append(np.mean(np.array(xx[i]), axis=0))
print len(x_mean), len(x_mean[0])

5 64


In [18]:
colors = np.zeros((net.n_hidden*net.n_hidden, net.n_input*net.n_input))
for s in np.arange(len(net.patches)):
    for j in np.arange(len(net.patches[0])):
        colors[s][j]=net.InputConnection.w[s+j*net.n_input**2]


IndexError: index 576 is out of bounds for axis 0 with size 576

In [19]:
scals = []

for k in np.arange(np.max(y)+1):
    scal = []
    for i in np.arange(n_hidden*n_hidden):
        svertka = []
        for j in np.arange(n_input*n_input):
            svertka.append(colors[i][j]*x_mean[k][j])
        svertka = np.array(svertka)
        scal.append(np.sum(svertka))
    scals.append(scal)




IndexError: index 9 is out of bounds for axis 0 with size 9

In [20]:
figure(222)
subplot(141)
imshow(np.array(x_mean[0]).reshape(8, 8), cmap ='bone', interpolation='None')
plt.colorbar()

subplot(142)
imshow(np.array(x_mean[1]).reshape(8, 8), cmap ='bone', interpolation='None')
plt.colorbar()

subplot(143)
imshow(np.array(x_mean[2]).reshape(8, 8), cmap ='bone', interpolation='None')
plt.colorbar()

subplot(144)
imshow(np.array(x_mean[3]).reshape(8, 8), cmap ='bone', interpolation='None')
plt.colorbar()

show()

<IPython.core.display.Javascript object>

In [21]:
scals = np.array(scals)
scals.shape
print np.argmax(scals, axis = 0).reshape(8,8)

print np.round(np.amax(scals, axis = 0).reshape(8,8), 2)

ValueError: attempt to get argmax of an empty sequence