In [None]:
# Mount Google Drive as local drive
# (mounted at /content/drive; the MNIST paths used later assume the data
# lives in '/content/drive/My Drive/MNIST').

# Load the Drive helper and mount
from google.colab import drive

# This will prompt for authorization.
drive.mount('/content/drive')

In [None]:
# List the MNIST directory to check the IDX files (train-*/t10k-* images and
# labels) opened by read() are present.
!ls '/content/drive/My Drive/MNIST'

In [3]:
# numba provides the @jit decorator used to compile the simulation kernels.
!pip install numba

In [None]:
# -*- coding: utf-8 -*-
import os
import struct
import numpy as np
import matplotlib.pyplot as plt
import time
from numba import jit


def read(dataset="training", path="."):
    """
    Load one split of the MNIST data set from its IDX files.

    Inputs: dataset: "training" (train-*-ubyte files) or "testing" (t10k-*-ubyte files)
            path: directory containing the extracted IDX files
    Returns: img: numpy.uint8 array of shape (n, rows, cols) with the pixel data
             lbl: numpy.int8 array of shape (n,) with the label of each image
    Raises: ValueError if `dataset` is not "training"/"testing" or a file's IDX
            magic number is not the expected one (2049 labels, 2051 images).
    """
    file_names = {
        "training": ('train-images.idx3-ubyte', 'train-labels.idx1-ubyte'),
        "testing": ('t10k-images.idx3-ubyte', 't10k-labels.idx1-ubyte'),
    }
    if dataset not in file_names:
        # Previously an unknown name fell through and crashed later with an
        # UnboundLocalError on fname_img/fname_lbl.
        raise ValueError(f"dataset must be 'training' or 'testing', got {dataset!r}")
    img_name, lbl_name = file_names[dataset]
    fname_img = os.path.join(path, img_name)
    fname_lbl = os.path.join(path, lbl_name)

    # Label file: big-endian header (magic, count) followed by one byte per label.
    with open(fname_lbl, 'rb') as flbl:
        magic, _num = struct.unpack(">II", flbl.read(8))
        if magic != 2049:
            raise ValueError(f"{fname_lbl}: bad IDX label magic number {magic}")
        lbl = np.fromfile(flbl, dtype=np.int8)

    # Image file: big-endian header (magic, count, rows, cols) then raw pixels.
    with open(fname_img, 'rb') as fimg:
        magic, _num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        if magic != 2051:
            raise ValueError(f"{fname_img}: bad IDX image magic number {magic}")
        img = np.fromfile(fimg, dtype=np.uint8).reshape(len(lbl), rows, cols)

    return img, lbl

def read_mnist(path='/content/drive/My Drive/MNIST',mode='training'):
    """
    Read one MNIST split and prepare it for the network.

    Inputs: path: directory where the MNIST IDX files are stored (download the
                  dataset from http://yann.lecun.com/exdb/mnist/ and extract it
                  into a folder)
            mode: 'training'/'testing'
    Returns: images: float64 matrix with 784 rows (flattened 28*28 pixels) and one
                     COLUMN per sample, pixel values scaled to [0, 1]
             labels: column vector of shape (n, 1) with the label of each sample
    Note: samples are columns here; callers transpose to get one sample per row.
    """
    img, label = read(dataset=mode, path=path)
    n_samples, rows, cols = img.shape
    # Scale in float32, exactly as the former per-image loop did, so values are
    # bit-identical; done in one vectorized pass instead of n Python iterations.
    scaled = img.reshape(n_samples, rows * cols).astype(np.float32) / 255
    # Store as a C-ordered float64 (pixels x samples) matrix, the same layout
    # the original np.zeros-based construction produced.
    images = np.ascontiguousarray(scaled.T, dtype=np.float64)
    labels = label.reshape(-1, 1)
    return images, labels

@jit(nopython=True)
def image2spik(img_in):
    """
    Convert an image into Bernoulli (Poisson-like) spike trains.

    Inputs: img_in: a single flattened image of length 784; each entry is used
            as a firing rate, so a pixel spikes in a 1 ms bin with probability
            img_in[pixel] * 0.001
    Returns: int64 matrix with one row per pixel (784) and 351 columns: spike
             trains covering t = 0..350 ms with a 1 ms time step
    """
    timeStepS = 0.001
    durationS = 0.35

    # Integer step count instead of a float np.arange, whose length depends on
    # rounding of durationS + timeStepS; 350 steps plus the t = 0 bin = 351.
    numSteps = int(durationS / timeStepS + 0.5) + 1
    numTrains = img_in.shape[0]

    spikes = np.zeros((numTrains, numSteps), dtype=np.int64)
    # Seed once per call: each call is still reproducible, but every pixel now
    # gets its own random draws. Re-seeding inside the loop gave all pixels the
    # identical random vector, so every train spiked in lock-step.
    np.random.seed(0)
    for train in range(numTrains):
        vt = np.random.random(numSteps)
        spikes[train, :] = (img_in[train] * timeStepS) > vt
    return spikes

@jit(nopython=True)
def updateweights(spike_e,x_pre1,W_xe,Wmax,eta):
    """
    Apply one STDP step to the input->excitatory weights, in place.

    For every excitatory neuron k with spike_e[k] != 0, column k of W_xe is
    updated by spike_e[k] * eta * x_pre1 * (Wmax - W_xe[:, k]) ** (1/5) and
    clipped to [0, Wmax].

    Inputs: spike_e: vector of length N_e; non-zero for neurons that fired
            x_pre1: vector of length Ninp (presynaptic trace minus its target)
            W_xe: (Ninp, N_e) weight matrix; column k holds neuron k's inputs
            Wmax: maximum weight
            eta: learning rate
    Returns: W_xe (the same array object, modified in place)
    """
    for k in range(len(spike_e)):
        # A silent neuron's update term is zero, so its column is unchanged;
        # skipping it also avoids a NaN from a negative base in np.power.
        if spike_e[k] == 0:
            continue
        column = W_xe[:, k] + spike_e[k] * eta * x_pre1 * np.power(Wmax - W_xe[:, k], 1 / 5)
        # Clip to [0, Wmax] (was a hard-coded 1); keeps Wmax - W >= 0 for the
        # fractional power on the next update.
        W_xe[:, k] = np.minimum(np.maximum(column, 0.0), Wmax)
    return W_xe


@jit(nopython=True)
def updateparameters(num_neurons,img,spikes_inp,spikes_e,spikes_i,dt,R,Vrest_e,Vrest_i,
                     Vreset_e,Vreset_i,Vthresh_e,Vthresh_i,Vspike,refrac_e,refrac_i,
                     tau_M_e,tau_M_i,tau_ge,tau_gi,Wmax,tau_xpre,eta,alph,mu,xtar,
                     tau_theta,theta_inc,a_xe,a_ei,a_ie,W_xe,W_ei,W_ie,ge_xe,ge_ei,
                     ge_ie,I_xe,I_ei,I_ie,V_e,r_e,theta_e,V_i,r_i,VeArray,ViArray,
                     IxeArray,IeiArray,IieArray,x_pre):
    """
    Simulate one training sample and apply STDP to W_xe.

    Each presentation runs time steps 1..350 of spikes_inp (row 0 is unused).
    If afterwards no excitatory neuron fired at least 5 times, the image is
    multiplied by marr[max spike count] * (63.75 + 32) / 63.75, re-encoded with
    image2spik and presented again, for at most 10 presentations. spikes_e is
    not cleared between presentations; rows 1..350 are overwritten.

    Arrays passed in (spikes_e, spikes_i, r_e, r_i, W_xe, ...) may be modified in
    place; the updated values are also returned.
    alph, mu and the *Array recording buffers are accepted but not used.

    NOTE(review): refrac_e/refrac_i arrive in seconds (5e-3 / 2e-3) but r_e/r_i
    are decremented by 1 per time step, so the refractory period lasts a single
    step. Converting to steps (refrac / dt) would change tuned dynamics -
    confirm intent before changing.

    Returns: spikes_e, spikes_i, W_xe, W_ei, W_ie, ge_xe, ge_ei, ge_ie,
             I_xe, I_ei, I_ie, V_e, r_e, theta_e, V_i, r_i,
             VeArray, ViArray, IxeArray, IeiArray, IieArray, x_pre
    """
    WN=0
    # Input-scaling factors indexed by the best neuron's spike count.
    marr=np.array([5,3,2,1,1,0])
    while (np.sum(spikes_e,axis=0).max())<5:
        WN+=1
        if WN>10:
            break
        for i in range(1, 351):
            spike_e=np.zeros(num_neurons,dtype=np.float64)
            spike_i=np.zeros(num_neurons,dtype=np.float64)

            # Synapse update - input->exc layer.
            ge_xe = ge_xe + spikes_inp[i,:]
            I_xe = np.multiply(np.dot(ge_xe,W_xe),a_xe)
            ge_xe = (1-dt/tau_ge)*ge_xe
            I_e = I_xe + I_ie

            # Neuron update - exc layer.
            V_e = V_e + dt/tau_M_e*(Vrest_e - V_e + R*I_e)
            refractory_e = np.where(r_e>0)[0]
            V_e[refractory_e] = Vreset_e
            r_e[refractory_e] = r_e[refractory_e] - 1

            # Compute the firing set once: re-testing after V_e is set to Vspike
            # dropped the spike whenever Vthresh_e + theta_e exceeded Vspike.
            fired_e = np.where(V_e>=(Vthresh_e + theta_e))[0]
            V_e[fired_e] = Vspike
            r_e[fired_e] = refrac_e
            spike_e[fired_e] = 1
            spikes_e[i, :] = spike_e

            # SFA - exc layer: increment and decay of the adaptive threshold.
            theta_e = theta_e + theta_inc*spike_e
            theta_e = (1-dt/tau_theta)*theta_e

            # Synapse update - exc->inh layer.
            ge_ei = ge_ei + spike_e
            I_ei = np.multiply(np.dot(ge_ei,W_ei),a_ei)
            ge_ei = (1-dt/tau_ge)*ge_ei

            # Neuron update - inh layer.
            V_i = V_i + dt/tau_M_i*(Vrest_i - V_i + R*I_ei)
            refractory_i = np.where(r_i>0)[0]
            V_i[refractory_i] = Vreset_i
            r_i[refractory_i] = r_i[refractory_i] - 1

            fired_i = np.where(V_i>=Vthresh_i)[0]
            V_i[fired_i] = Vspike
            r_i[fired_i] = refrac_i
            spike_i[fired_i] = 1
            spikes_i[i, :] = spike_i

            # Synapse update - inh->exc layer. Decays with tau_gi (was tau_ge),
            # matching updateparameterstest so training and testing share dynamics.
            ge_ie = ge_ie + spike_i
            I_ie = np.multiply(np.dot(ge_ie,W_ie),(-a_ie))
            ge_ie = (1-dt/tau_gi)*ge_ie

            # Presynaptic trace: exponential decay, then jump towards 1 on input spikes.
            x_pre = x_pre + dt/tau_xpre*(-x_pre)
            x_pre = x_pre + np.multiply((1-x_pre),spikes_inp[i,:])
            if np.sum(spike_e)>0:
                x_pre1 = x_pre - xtar
                W_xe = updateweights(spike_e,x_pre1,W_xe,Wmax,eta)

        best = np.int64(np.sum(spikes_e,axis=0).max())
        if best>4:
            break
        # Too few spikes: boost the input and re-encode it for another presentation.
        img = img*marr[best]*(63.75+32)/63.75
        spikes_inp = image2spik(img)
        spikes_inp = spikes_inp.T
    return spikes_e,spikes_i,\
           W_xe,W_ei,W_ie,\
           ge_xe,ge_ei,ge_ie,\
           I_xe,I_ei,I_ie,V_e,\
           r_e,theta_e,V_i,r_i,\
           VeArray,ViArray,IxeArray,IeiArray,\
           IieArray,x_pre

@jit(nopython=True)
def updateparameterstest(num_neurons,img,spikes_inp,spikes_e,spikes_i,dt,R,Vrest_e,Vrest_i,
                         Vreset_e,Vreset_i,Vthresh_e,Vthresh_i,Vspike,refrac_e,refrac_i,
                         tau_M_e,tau_M_i,tau_ge,tau_gi,Wmax,tau_xpre,eta,alph,mu,xtar,
                         tau_theta,theta_inc,a_xe,a_ei,a_ie,W_xe,W_ei,W_ie,ge_xe,ge_ei,
                         ge_ie,I_xe,I_ei,I_ie,V_e,r_e,theta_e,V_i,r_i,VeArray,ViArray,
                         IxeArray,IeiArray,IieArray,x_pre):
    """
    Simulate one test sample (same dynamics as updateparameters, no STDP).

    Each presentation runs time steps 1..350 of spikes_inp (row 0 is unused).
    If afterwards no excitatory neuron fired at least 5 times, the image is
    multiplied by marr[max spike count] * (63.75 + 32) / 63.75, re-encoded with
    image2spik and presented again, for at most 30 presentations.
    W_xe is not modified and x_pre is returned unchanged; the adaptive
    threshold theta_e still increments and decays here.
    Wmax, tau_xpre, eta, alph, mu, xtar and the *Array buffers are unused.

    NOTE(review): refrac_e/refrac_i arrive in seconds but r_e/r_i are
    decremented by 1 per step, so refractoriness lasts a single step - confirm.

    Returns: spikes_e, spikes_i, W_xe, W_ei, W_ie, ge_xe, ge_ei, ge_ie,
             I_xe, I_ei, I_ie, V_e, r_e, theta_e, V_i, r_i,
             VeArray, ViArray, IxeArray, IeiArray, IieArray, x_pre
    """
    WN=0
    # Input-scaling factors indexed by the best neuron's spike count.
    marr=np.array([5,3,2,1,1,0])
    while (np.sum(spikes_e,axis=0).max())<5:
        WN+=1
        if WN>30:
            break
        for i in range(1, 351):
            spike_e=np.zeros(num_neurons,dtype=np.float64)
            spike_i=np.zeros(num_neurons,dtype=np.float64)

            # Synapse update - input->exc layer.
            ge_xe = ge_xe + spikes_inp[i,:]
            I_xe = np.multiply(np.dot(ge_xe,W_xe),a_xe)
            ge_xe = (1-dt/tau_ge)*ge_xe
            I_e = I_xe + I_ie

            # Neuron update - exc layer.
            V_e = V_e + dt/tau_M_e*(Vrest_e - V_e + R*I_e)
            refractory_e = np.where(r_e>0)[0]
            V_e[refractory_e] = Vreset_e
            r_e[refractory_e] = r_e[refractory_e] - 1

            # Compute the firing set once: re-testing after V_e is set to Vspike
            # dropped the spike whenever Vthresh_e + theta_e exceeded Vspike.
            fired_e = np.where(V_e>=(Vthresh_e + theta_e))[0]
            V_e[fired_e] = Vspike
            r_e[fired_e] = refrac_e
            spike_e[fired_e] = 1
            spikes_e[i, :] = spike_e

            # SFA - exc layer: increment and decay of the adaptive threshold.
            theta_e = theta_e + theta_inc*spike_e
            theta_e = (1-dt/tau_theta)*theta_e

            # Synapse update - exc->inh layer.
            ge_ei = ge_ei + spike_e
            I_ei = np.multiply(np.dot(ge_ei,W_ei),a_ei)
            ge_ei = (1-dt/tau_ge)*ge_ei

            # Neuron update - inh layer.
            V_i = V_i + dt/tau_M_i*(Vrest_i - V_i + R*I_ei)
            refractory_i = np.where(r_i>0)[0]
            V_i[refractory_i] = Vreset_i
            r_i[refractory_i] = r_i[refractory_i] - 1

            fired_i = np.where(V_i>=Vthresh_i)[0]
            V_i[fired_i] = Vspike
            r_i[fired_i] = refrac_i
            spike_i[fired_i] = 1
            spikes_i[i, :] = spike_i

            # Synapse update - inh->exc layer.
            ge_ie = ge_ie + spike_i
            I_ie = np.multiply(np.dot(ge_ie,W_ie),(-a_ie))
            ge_ie = (1-dt/tau_gi)*ge_ie

        best = np.int64(np.sum(spikes_e,axis=0).max())
        if best>4:
            break
        # Too few spikes: boost the input and re-encode it for another presentation.
        img = img*marr[best]*(63.75+32)/63.75
        spikes_inp = image2spik(img)
        spikes_inp = spikes_inp.T
    return spikes_e,spikes_i,\
           W_xe,W_ei,W_ie,\
           ge_xe,ge_ei,ge_ie,\
           I_xe,I_ei,I_ie,V_e,\
           r_e,theta_e,V_i,r_i,\
           VeArray,ViArray,IxeArray,IeiArray,\
           IieArray,x_pre


class SNN:

#     SNN is the main class that creates the spiking neural network.
#        Inputs: number of neurons
#        Return: a SNN class object

#     Methods: SNN class has following methods:

#     generalparameters( ): initializes general parameters. These parameters remain constant during training and testing.
#                           Inputs: input dimension

#     networkparameters( ): initializes network parameters. These parameters change during training and testing.

#     createweights( ): creates synaptic weights.

#     resetparameters( ): resets network parameters after each training/testing sample

#     training( ): trains the network using STDP learning.
#                  Inputs: trainIndex: index of training images (an array of n indexes)
#                          images: MNIST images
#                          labels: MNIST labels
#                          epoch: number of epochs to train the data

#     testing( ): tests the pre-trained network.
#                 Inputs: testIndex: index of training images (an array of n indexes)
#                       testimages: MNIST images
#                       testlabels: MNIST labels
#                 Returns: predictedClass: a matrix with n rows and 2 columns
#                                          where n is number of samples and
#                                          column 1 is predicted labels and
#                                          column 2 is actual labels

#     plotweights( ): plots the receptive field of the neurons and saves it as a figure.
#                     Inputs: name: name of the plot

#     saveweights( ): saves the trained network weights and neuron labels
#                     Inputs: filename: name of the file to save the data. the filename should have .npz extension

#     loadweights( ): loads the trained network weights and neuron labels
#                     Inputs: filename: name of the file where the data is saved after training


    def __init__(self,num_neurons):
       self.neuron=np.zeros((10,num_neurons),dtype=np.float64)
       self.neuronLabels=np.zeros((1,num_neurons),dtype=np.float64)
       self.num_neurons=num_neurons
       self.general_parameters=dict()
       self.network_parameters=dict()


    def generalparameters(self,Ninp):
        self.general_parameters['Ninp'] = Ninp
        self.general_parameters['N_e'] = self.num_neurons
        self.general_parameters['N_i'] = self.num_neurons

        ## LIF parameters
        self.general_parameters['dt'] = np.float64(0.001)
        self.general_parameters['R']=np.float64(10e6)
        self.general_parameters['Vrest_e'] = np.float64(-65e-3)
        self.general_parameters['Vrest_i'] = np.float64(-60e-3)
        self.general_parameters['Vreset_e'] = np.float64(-65e-3)
        self.general_parameters['Vreset_i'] = np.float64(-45e-3)
        self.general_parameters['Vthresh_e'] = np.float64(-52e-3)
        self.general_parameters['Vthresh_i'] = np.float64(-40e-3)
        self.general_parameters['Vspike'] = np.float64(20e-3); # set by me.
        self.general_parameters['refrac_e'] = np.float64(5e-3)
        self.general_parameters['refrac_i'] = np.float64(2e-3)
        self.general_parameters['tau_M_e'] = np.float64(100e-3)
        self.general_parameters['tau_M_i'] = np.float64(10e-3)

        ## synapse parameters
        self.general_parameters['tau_ge'] = np.float64(5e-3)
        self.general_parameters['tau_gi'] = np.float64(20e-3)

        ## STDP parameters
        self.general_parameters['Wmax'] = np.float64(1)
        self.general_parameters['tau_xpre'] = np.float64(5e-3)
        self.general_parameters['eta'] = np.float64(0.01)
        self.general_parameters['alph']=np.float64(0.01)
        self.general_parameters['mu'] = np.float64(0.2)
        self.general_parameters['xtar'] = np.float64(0.4)

        ## sfa parameters
        self.general_parameters['tau_theta'] = np.float64(1e7*1e-3)
        self.general_parameters['theta_inc'] = np.float64(0.05e-3)

        # current amplitudes ## set by me.
        self.general_parameters['a_xe'] = np.float64(0.05e-9) # Adjust according to input.
        self.general_parameters['a_ei'] = np.float64(25e-9) # This current is somewhat big.
        self.general_parameters['a_ie'] = np.float64(100e-9) # This current is set to much bigger than other currents.


    def createweights(self):
        np.random.seed(0)
        self.W_xe = np.random.random((self.general_parameters['Ninp'], self.general_parameters['N_e'])).astype(np.float64)*0.2 + 0.4; #randomly initialized to small weights around 0.5.
        # Exc-Inh weights - This is in the case that N_e = N_i;
        self.W_ei = np.eye(self.general_parameters['N_e'], dtype=np.float64)  #id matrix
        self.W_ie = 1-np.eye(self.general_parameters['N_e'],dtype=np.float64); # id complementary matrix

    def saveweights(self,filename='/content/drive/My Drive/MNIST/saved_values.npz'):
        np.savez(filename,W_xe=self.W_xe,neuronLabels=self.neuronLabels)

    def loadweights(self,filename='/content/drive/My Drive/MNIST/saved_values.npz'):
        values = np.load(filename)
        self.W_xe=values['W_xe']
        self.neuronLabels=values['neuronLabels']



    def networkparameters(self):
        # These are the parameters unique to the network.
        self.network_parameters['N_inp'] = self.general_parameters['Ninp']
        self.network_parameters['N_e'] = self.num_neurons
        self.network_parameters['N_i'] = self.num_neurons


        ## Vectors of neuron parameters.
        self.network_parameters['ge_xe'] = np.zeros((self.network_parameters['N_inp']),dtype=np.float64)
        self.network_parameters['ge_ei'] = np.zeros((self.network_parameters['N_e']),dtype=np.float64)
        self.network_parameters['ge_ie'] = np.zeros((self.network_parameters['N_i']),dtype=np.float64)

        self.network_parameters['spikes_e'] = np.zeros((351, self.network_parameters['N_e']),dtype=np.float64)
        self.network_parameters['spikes_i'] = np.zeros((351, self.network_parameters['N_i']),dtype=np.float64)

        self.network_parameters['I_xe'] = np.zeros((self.network_parameters['N_e']),dtype=np.float64)
        self.network_parameters['I_ei'] = np.zeros((self.network_parameters['N_i']),dtype=np.float64)
        self.network_parameters['I_ie'] = np.zeros((self.network_parameters['N_e']),dtype=np.float64)
        self.network_parameters['V_e'] = np.zeros((self.network_parameters['N_e']),dtype=np.float64)+self.general_parameters['Vrest_e']
        self.network_parameters['r_e'] = np.zeros((self.network_parameters['N_e']),dtype=np.float64)
        self.network_parameters['theta_e'] = np.zeros((self.network_parameters['N_e']),dtype=np.float64)
        self.network_parameters['V_i'] = np.zeros((self.network_parameters['N_i']),dtype=np.float64) +self.general_parameters['Vrest_i']
        self.network_parameters['r_i'] = np.zeros((self.network_parameters['N_i']),dtype=np.float64)

        ## vectors of output record
        self.network_parameters['VeArray'] = np.zeros((351, self.network_parameters['N_e']),dtype=np.float64)+self.general_parameters['Vrest_e']
        self.network_parameters['ViArray'] = np.zeros((351, self.network_parameters['N_i']),dtype=np.float64)+self.general_parameters['Vrest_i']
        self.network_parameters['IxeArray'] = np.zeros((351, self.network_parameters['N_e']),dtype=np.float64)
        self.network_parameters['IeiArray'] =np.zeros((351, self.network_parameters['N_i']),dtype=np.float64)
        self.network_parameters['IieArray'] = np.zeros((351, self.network_parameters['N_e']),dtype=np.float64)

        ## STDP parameters
        self.network_parameters['x_pre'] = np.zeros((self.network_parameters['N_inp']),dtype=np.float64)


    def resetparameters(self):
        ## Vectors of neuron parameters.
        self.network_parameters['ge_xe'] = np.zeros((self.general_parameters['Ninp']),dtype=np.float64)
        self.network_parameters['ge_ei'] = np.zeros((self.network_parameters['N_e']),dtype=np.float64)
        self.network_parameters['ge_ie'] = np.zeros((self.network_parameters['N_i']),dtype=np.float64)

        self.network_parameters['I_xe'] = np.zeros((self.network_parameters['N_e']),dtype=np.float64)
        self.network_parameters['I_ei'] = np.zeros((self.network_parameters['N_i']),dtype=np.float64)
        self.network_parameters['I_ie'] = np.zeros((self.network_parameters['N_e']),dtype=np.float64)
        self.network_parameters['V_e'] = np.zeros((self.network_parameters['N_e']),dtype=np.float64)+self.general_parameters['Vrest_e']
        self.network_parameters['r_e'] = np.zeros((self.network_parameters['N_e']),dtype=np.float64)
        self.network_parameters['V_i'] = np.zeros((self.network_parameters['N_i']),dtype=np.float64) +self.general_parameters['Vrest_i']
        self.network_parameters['r_i'] = np.zeros((self.network_parameters['N_i']),dtype=np.float64)



    def training(self, trainIndex, images, labels, current_outer_epoch): 
        self.neuron = np.zeros((10, self.num_neurons), dtype=np.int64)
        # pred1 = np.zeros((trainIndex.shape[0], 10), dtype=np.int64)
        # predictedClass1 = np.zeros((trainIndex.shape[0], 2), dtype=np.int64)

        print(f'Starting training pass for Epoch {current_outer_epoch}...')

        trainCount = 0
        self.neuronLabels = np.zeros(self.num_neurons, dtype=np.int64)
        trainIndexes = trainIndex
        start = time.time()

        for c in trainIndexes:
            trainCount += 1
            if np.remainder(trainCount, 10) == 0:
                end = time.time()
                print('Epoch:', current_outer_epoch, 'done:', trainCount, '/', len(trainIndex), 'time elapsed: %0.2f seconds' % (end - start))

            self.resetparameters()
            spikes_e = np.zeros((351, self.network_parameters['N_e']), dtype=np.float64)
            spikes_i = np.zeros((351, self.network_parameters['N_e']), dtype=np.float64)
            img = images[c, :]
            spikes_inp = image2spik(img)
            spikes_inp = spikes_inp.T

            spikes_e, spikes_i, \
            self.W_xe, self.W_ei, self.W_ie, \
            self.network_parameters['ge_xe'], self.network_parameters['ge_ei'], self.network_parameters['ge_ie'], \
            self.network_parameters['I_xe'], self.network_parameters['I_ei'], self.network_parameters['I_ie'], \
            self.network_parameters['V_e'], self.network_parameters['r_e'], self.network_parameters['theta_e'], \
            self.network_parameters['V_i'], \
            self.network_parameters['r_i'], self.network_parameters['VeArray'], self.network_parameters['ViArray'], \
            self.network_parameters['IxeArray'], \
            self.network_parameters['IeiArray'], self.network_parameters['IieArray'], self.network_parameters['x_pre'] \
                = updateparameters(self.num_neurons, img, spikes_inp, spikes_e, spikes_i, \
                                   self.general_parameters['dt'], self.general_parameters['R'], \
                                   self.general_parameters['Vrest_e'], self.general_parameters['Vrest_i'], \
                                   self.general_parameters['Vreset_e'], \
                                   self.general_parameters['Vreset_i'], self.general_parameters['Vthresh_e'], \
                                   self.general_parameters['Vthresh_i'], self.general_parameters['Vspike'], \
                                   self.general_parameters['refrac_e'], self.general_parameters['refrac_i'], \
                                   self.general_parameters['tau_M_e'], self.general_parameters['tau_M_i'], \
                                   self.general_parameters['tau_ge'], self.general_parameters['tau_gi'], \
                                   self.general_parameters['Wmax'], self.general_parameters['tau_xpre'], \
                                   self.general_parameters['eta'], self.general_parameters['alph'], \
                                   self.general_parameters['mu'], self.general_parameters['xtar'], \
                                   self.general_parameters['tau_theta'], self.general_parameters['theta_inc'], \
                                   self.general_parameters['a_xe'], self.general_parameters['a_ei'], \
                                   self.general_parameters['a_ie'], self.W_xe, self.W_ei, self.W_ie, \
                                   self.network_parameters['ge_xe'], self.network_parameters['ge_ei'], \
                                   self.network_parameters['ge_ie'], \
                                   self.network_parameters['I_xe'], self.network_parameters['I_ei'], \
                                   self.network_parameters['I_ie'], \
                                   self.network_parameters['V_e'], self.network_parameters['r_e'], \
                                   self.network_parameters['theta_e'], self.network_parameters['V_i'], \
                                   self.network_parameters['r_i'], self.network_parameters['VeArray'], \
                                   self.network_parameters['ViArray'], self.network_parameters['IxeArray'], \
                                   self.network_parameters['IeiArray'], self.network_parameters['IieArray'], \
                                   self.network_parameters['x_pre'])

            spikeRate = (spikes_e.sum(axis=0)) / 0.35
            for numNeuron in list(np.where(spikes_e.sum(axis=0))[0]):
                self.neuron[labels[c], numNeuron] = self.neuron[labels[c], numNeuron] + spikeRate[numNeuron]

        for nl in range(self.num_neurons):
            idx = self.neuron[:, nl].argmax()
            self.neuronLabels[nl] = idx


    def testing(self,testIndexes,testimages,testlabels):
        self.neuron = np.zeros((10, self.num_neurons),dtype=np.int64)
        pred = np.zeros((testIndexes.shape[0], 10),dtype=np.int64);
        predictedClass = np.zeros((testIndexes.shape[0], 2),dtype=np.int64);
        testCount=0
        print('Testing.....')
        start=time.time()
        for c in testIndexes:
            testCount+=1
            if np.remainder(testCount,10)==0:
              end = time.time()
              print('done:',testCount,'/',len(testIndexes),'time elapsed: %0.2f seconds'%(end - start),)
            self.resetparameters()
            spikes_e=np.zeros((351, self.network_parameters['N_e']),dtype=np.float64)
            spikes_i=np.zeros((351, self.network_parameters['N_e']),dtype=np.float64)
            img=testimages[c,:]
            spikes_inp = image2spik(img)
            spikes_inp = spikes_inp.T
            spikes_e,spikes_i,\
            self.W_xe,self.W_ei,self.W_ie,\
            self.network_parameters['ge_xe'],self.network_parameters['ge_ei'],self.network_parameters['ge_ie'],\
            self.network_parameters['I_xe'],self.network_parameters['I_ei'],self.network_parameters['I_ie'],\
            self.network_parameters['V_e'],self.network_parameters['r_e'],self.network_parameters['theta_e'],self.network_parameters['V_i'],\
            self.network_parameters['r_i'],self.network_parameters['VeArray'],self.network_parameters['ViArray'],self.network_parameters['IxeArray'],\
            self.network_parameters['IeiArray'],self.network_parameters['IieArray'],self.network_parameters['x_pre']\
            =updateparameterstest(self.num_neurons,img,spikes_inp,spikes_e,spikes_i,\
            self.general_parameters['dt'],self.general_parameters['R'],self.general_parameters['Vrest_e'],self.general_parameters['Vrest_i'],self.general_parameters['Vreset_e'],\
            self.general_parameters['Vreset_i'],self.general_parameters['Vthresh_e'],self.general_parameters['Vthresh_i'],self.general_parameters['Vspike'],\
            self.general_parameters['refrac_e'],self.general_parameters['refrac_i'],self.general_parameters['tau_M_e'],self.general_parameters['tau_M_i'],\
            self.general_parameters['tau_ge'],self.general_parameters['tau_gi'],self.general_parameters['Wmax'],self.general_parameters['tau_xpre'],\
            self.general_parameters['eta'], self.general_parameters['alph'],self.general_parameters['mu'],self.general_parameters['xtar'],\
            self.general_parameters['tau_theta'],self.general_parameters['theta_inc'],self.general_parameters['a_xe'],self.general_parameters['a_ei'],\
            self.general_parameters['a_ie'],self.W_xe,self.W_ei,self.W_ie,\
            self.network_parameters['ge_xe'],self.network_parameters['ge_ei'],self.network_parameters['ge_ie'],\
            self.network_parameters['I_xe'],self.network_parameters['I_ei'],self.network_parameters['I_ie'],\
            self.network_parameters['V_e'],self.network_parameters['r_e'],self.network_parameters['theta_e'],self.network_parameters['V_i'],\
            self.network_parameters['r_i'],self.network_parameters['VeArray'],self.network_parameters['ViArray'],self.network_parameters['IxeArray'],\
            self.network_parameters['IeiArray'],self.network_parameters['IieArray'],self.network_parameters['x_pre'])

            spikeRate = (spikes_e.sum(axis=0))/0.35
            for i in range(10):
                for j in list(np.where(self.neuronLabels==i)[0]):

                    pred[testCount-1,i] = pred[testCount-1,i] + spikeRate[j]

            maxIdx = pred[testCount-1,:].argmax()
            predictedClass[testCount-1,:] = [maxIdx,testlabels[c][0]]

        return predictedClass

    def plotweights(self, name='receptivefields.png'):
        plt.ioff() 

        num_neurons = self.num_neurons
        if num_neurons == 0:
            print("Warning: No neurons to plot receptive fields for.")
            return

        n_cols = int(np.ceil(np.sqrt(num_neurons)))
        n_rows = int(np.ceil(num_neurons / n_cols))

        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 1.5, n_rows * 1.5))
        axes = axes.flatten()

        for k in range(num_neurons):
            axes[k].imshow(self.W_xe[:, k].reshape(28, 28), cmap=plt.get_cmap('Greys'))
            axes[k].axis('off')

        for k in range(num_neurons, len(axes)):
            fig.delaxes(axes[k])

        plt.tight_layout()
        plt.savefig(name)
        plt.close(fig)

def main():
    """
    Train and evaluate the SNN on MNIST stored in Google Drive.

    Per epoch: one STDP training pass, accuracy on the training subset, a
    save/load round-trip of the weights, accuracy on the test subset (appended
    to Output.txt) and a receptive-field plot.
    """
    data_dir = '/content/drive/My Drive/MNIST'
    output_file_path = os.path.join(data_dir, "Output.txt")
    with open(output_file_path, "w"):
        pass  # truncate the log from any previous run
    print(f"Cleared previous content in {output_file_path}")

    images, labels = read_mnist(path=data_dir, mode='training')
    testimages, testlabels = read_mnist(path=data_dir, mode='testing')
    # read_mnist returns one sample per column; the network indexes rows.
    images = images.T
    testimages = testimages.T

    num_neurons = 200
    train_samples = 500
    test_samples = 1000
    epochs = 10

    my_network = SNN(num_neurons)
    my_network.generalparameters(images.shape[1])
    my_network.createweights()
    my_network.networkparameters()

    np.random.seed(34)
    # replace=False: draw distinct images (sampling with replacement produced
    # duplicate training and test samples).
    trainIndex = np.random.choice(len(labels), train_samples, replace=False)
    testIndex = np.random.choice(len(testlabels), test_samples, replace=False)

    weights_file = os.path.join(data_dir, 'saved_values.npz')
    # Name reflects the actual configuration (was a stale "N150TS2000E10").
    plot_name = os.path.join(data_dir, f'receptivefieldsN{num_neurons}TS{train_samples}E{epochs}TA.png')

    for e in range(epochs):
        print(f'--- Starting Epoch {e} ---')
        my_network.training(trainIndex, images, labels, e)

        predictedClass = my_network.testing(trainIndex, images, labels)
        acc1 = (np.int64(predictedClass[:, 0] == predictedClass[:, 1]).mean())*100
        print('Training accuracy=', acc1)

        # Round-trip through disk so the test below uses the persisted weights.
        my_network.saveweights(filename=weights_file)
        my_network.loadweights(filename=weights_file)

        predictedClass = my_network.testing(testIndex, testimages, testlabels)
        acc = (np.int64(predictedClass[:, 0] == predictedClass[:, 1]).mean())*100
        print('Test Accuracy=', acc)

        with open(output_file_path, "a") as text_file:
            text_file.write(f"Epoch {e}: Accuracy = {acc}\n")

        my_network.plotweights(name=plot_name)

    print("\nTraining and testing complete. All epoch accuracies logged in Output.txt.")


if __name__ == '__main__':
    main()