In [297]:
from brian2 import *

In [298]:
####################################################################################################
# Simulation parameters
dt = 1e-4                                   # Sampling timestep [s]; only used to build t below
t = np.arange(0, 200000) * dt               # Time grid [s]; only its last value t[-1] is used
num_receptors = 101                         # Number of receptor neurons
speed = 80                                  # Bright-spot speed: receptor neurons passed per second
####################################################################################################

# The bright spot sweeps receptor 0 -> num_receptors-1 -> 0 -> ... (bouncing at both ends),
# and each receptor emits one spike at the moment the spot reaches it.
timestep = 1.0 / (speed - 1)                # Time between spikes of neighbouring receptors
# NOTE(review): 1/(speed-1) rather than 1/speed yields `speed` spikes in a closed 1 s window
# (fence-post count); confirm this is the intended meaning of `speed`.

spikes = [[] for _ in range(num_receptors)]  # Per-receptor spike-time lists for the SNN
plot_t = []                                 # Spike times, in firing order
plot_n = []                                 # Receptor id of each spike in plot_t
currtime = 0                                # Time of the next spike
neuronid = 0                                # Receptor that fires the next spike
stepdir = 1                                 # +1: spot moving right, -1: moving left

while currtime < t[-1]:
    plot_t.append(currtime)
    plot_n.append(neuronid)
    spikes[neuronid].append(currtime)
    # Bounce: reverse direction when the spot sits on the wall it is heading into.
    if (neuronid == 0 and stepdir == -1) or (neuronid == num_receptors - 1 and stepdir == 1):
        stepdir = -stepdir
    neuronid += stepdir
    currtime += timestep

In [299]:
def initNetwork(spikeArr=None, timeArr=None):
    """Build the direction-selective SNN and return ``(network, output, generator)``.

    Wiring (Synapses convention: ``i`` indexes the source group, ``j`` the target):
      * ``generator``: one spike source per receptor (``num_receptors`` cells).
      * ``relay`` (N-1 cells): relay ``j`` is excited by receptor ``j+1`` and inhibited
        through ``inhib[j]``; every relay drives output neuron 0.
      * ``relay1`` (N-1 cells): relay1 ``j`` is excited by receptor ``j`` and inhibited
        through ``inhib[j+1]``; every relay1 drives output neuron 1.
      * ``inhib`` (N cells): one-to-one copy of the receptors, so inhibition reaches the
        relays only after passing through an extra LIF stage.
    ``runNet`` reports output index 1 as "Right" and any other index as "Left".

    Parameters
    ----------
    spikeArr, timeArr : optional
        Initial receptor indices and spike times for the generator. They default to the
        notebook globals ``plot_n`` and ``plot_t*second`` (the original behaviour).

    Notes
    -----
    Brian2 looks up external identifiers in equations and ``on_pre`` code when ``run()``
    is called, using the frame that calls it. That frame is ``runNet``, where the
    constants below (locals of this function) are not visible. Every group and synapse
    therefore receives them as an explicit ``namespace``, so the network no longer depends
    on same-named globals that happen to exist in the kernel.
    """
    start_scope()

    # Default stimulus: the receptor spike train generated earlier in the notebook.
    if spikeArr is None:
        spikeArr = plot_n
    if timeArr is None:
        timeArr = plot_t*second

    inputSize = num_receptors
    generator = SpikeGeneratorGroup(inputSize, spikeArr, timeArr)

    # Leaky integrate-and-fire neuron driven by an exponentially decaying synaptic current.
    tau = 20*ms
    tau_syn = 2*ms
    I_weight = 3*nA
    threshold = 'v > -55*mV'
    v_reset = 'v = -80*mV'
    v_rest = -70*mV
    R = 100*Mohm
    equ = '''
    dv/dt = -(v - v_rest)/tau + R*I_syn/tau : volt
    dI_syn/dt = -I_syn/tau_syn : ampere
    '''
    # Explicit namespace shared by every object (see Notes in the docstring).
    namespace = {'tau': tau, 'tau_syn': tau_syn, 'I_weight': I_weight,
                 'v_rest': v_rest, 'R': R}

    def make_group(n):
        """Create an LIF population of size n that starts at the resting potential."""
        group = NeuronGroup(n, equ, threshold=threshold, reset=v_reset,
                            method='exact', namespace=namespace)
        group.v = v_rest
        return group

    def make_synapses(source, target, on_pre, condition):
        """Create Synapses source->target with the given on_pre code and connect rule."""
        syn = Synapses(source, target, on_pre=on_pre, namespace=namespace)
        syn.connect(condition=condition)
        return syn

    # Objects are created in the same order as before, so auto-generated names are unchanged.
    relay = make_group(inputSize - 1)
    inpToRel = make_synapses(generator, relay, 'I_syn += I_weight', 'j+1 == i')

    relay1 = make_group(inputSize - 1)
    inpToRel1 = make_synapses(generator, relay1, 'I_syn += I_weight', 'j == i')

    inhib = make_group(inputSize)
    inpToIn = make_synapses(generator, inhib, 'I_syn += I_weight', 'j == i')

    inToRelay = make_synapses(inhib, relay, 'I_syn -= I_weight', 'j == i')
    inToRelay1 = make_synapses(inhib, relay1, 'I_syn -= I_weight', 'j+1 == i')

    output = make_group(2)
    # Weaker weight: a single relay spike alone is not meant to make an output neuron fire.
    relayToOut = make_synapses(relay, output, 'I_syn += I_weight/3', 'j == 0')
    relay1ToOut = make_synapses(relay1, output, 'I_syn += I_weight/3', 'j == 1')

    network = Network(generator, relay, inpToRel, relay1, inpToRel1, inhib, inpToIn,
                      inToRelay, inToRelay1, output, relayToOut, relay1ToOut)
    return (network, output, generator)
    
    

In [300]:
# Build the network once and snapshot its initial state under the name "init";
# runNet() restores this snapshot before replaying each new stimulus.
network, output, generator = initNetwork()
network.store(name="init")

In [301]:
def runNet(network, output, generator, spikeArr, timeArr, runTime = 1000*ms):
    """Restore the stored network, replay a stimulus and count output spikes per direction.

    Parameters
    ----------
    network : Network
        Network previously snapshotted with ``network.store("init")``.
    output : NeuronGroup
        Output layer; spikes from index 1 are counted as "Right", all others as "Left".
    generator : SpikeGeneratorGroup
        Input group whose spikes are replaced via ``set_spikes(spikeArr, timeArr)``.
    spikeArr, timeArr :
        Neuron indices and spike times (with units) for the generator.
    runTime : Quantity, optional
        Simulated duration, default 1000 ms.

    Returns
    -------
    tuple of int
        ``(rightSpikes, leftSpikes)``, the same counts that are printed.
    """
    network.restore("init")
    generator.set_spikes(spikeArr, timeArr)
    outputMonitor = SpikeMonitor(output)
    network.add(outputMonitor)
    try:
        network.run(runTime)
        spikeIdx = np.asarray(outputMonitor.i)
    finally:
        # Detach the temporary monitor even if run() raises; otherwise it would stay
        # attached to the network across later calls.
        network.remove(outputMonitor)

    rightSpikes = int(np.count_nonzero(spikeIdx == 1))
    leftSpikes = int(spikeIdx.size) - rightSpikes
    print("Right direction:", rightSpikes, "Left directions:", leftSpikes)
    return rightSpikes, leftSpikes

In [307]:
# Replay the full receptor spike train generated above through the network.
# Smaller hand-made stimuli that are useful for quick sanity checks:
#   right to left: spikeArr = [4, 3, 2, 1, 0]; timeArr = [10, 20, 30, 40, 50]*ms
#   left to right: spikeArr = [0, 1, 2, 3, 4]; timeArr = [100, 120, 130, 140, 150]*ms
timeArr = plot_t*second
spikeArr = plot_n

runNet(network, output, generator, spikeArr=spikeArr, timeArr=timeArr)

Right direction: 19 Left directions: 0
