## Import Necessary Libraries

In [None]:
import pandas as pd
import plotly.express as px
import torch
from pymonntorch import Behavior, SynapseGroup, Network, NeuronGroup, Recorder, EventRecorder

from requiredBehaviours import *

## Plotting Functions

In [2]:
def add_rectangles(fig, signal_orders, iter_duration, rest_duration, total_iter, y0=0, y1=1):
    """Shade the stimulus window of every iteration on a raster figure.

    Iteration ``i`` starts at ``i * iter_duration``; the shaded span runs from
    that start to ``2 * rest_duration`` before the end of the iteration. The
    fill colour encodes the signal presented in that iteration: ``'#9467bd'``
    when ``signal_orders[i]`` is truthy, ``'orange'`` otherwise.

    Parameters
    ----------
    fig : plotly Figure
        Figure annotated in place (one ``fig.add_vrect`` call per iteration).
    signal_orders : sequence
        Per-iteration flag choosing the colour; needs at least ``total_iter`` entries.
    iter_duration : int
        Length of one iteration, in the same units as the figure's x axis.
    rest_duration : int
        Rest length; ``2 * rest_duration`` is left unshaded at each iteration's end.
    total_iter : int
        Number of iterations to shade.
    y0, y1 : float, optional
        Vertical extent in ``add_vrect``'s default "y domain" units
        (0 = bottom of the plot area, 1 = top). Defaults span the full height;
        the previous hard-coded ``"14"`` stretched the shape far above the plot.
    """
    for i in range(total_iter):
        color = '#9467bd' if signal_orders[i] else 'orange'
        x_start = i * iter_duration
        x_end = x_start + iter_duration - 2 * rest_duration
        # Pass numbers, not strings: the x axis of the raster plots is numeric.
        fig.add_vrect(
            x0=x_start, x1=x_end,
            y0=y0, y1=y1,
            fillcolor=color, opacity=0.25,
            layer="below", line_width=0,
        )
    


def plot_network(net, signal_orders, iter_duration, rest_duration, total_iter):
    """Display spike raster plots for the three recorded populations.

    Spike times and neuron indices are read with ``net['spikes.t', k]`` and
    ``net['spikes.i', k]`` for ``k`` = 0, 1, 2, and one scatter figure is shown
    per population (inhibitory, source excitatory, destination excitatory).
    NOTE(review): ``k`` presumably follows the creation order of the groups
    carrying ``EventRecorder(['spikes'])`` — confirm against pymonntorch's
    ``Network.__getitem__``.

    The destination raster is additionally shaded with ``add_rectangles`` so
    its spikes can be compared against the stimulus presented in each
    iteration; ``signal_orders``, ``iter_duration``, ``rest_duration`` and
    ``total_iter`` are forwarded to it unchanged.
    """

    def _raster(index, title):
        # One marker per spike: x = time step, y = neuron index.
        spikes = pd.DataFrame({'t': net['spikes.t', index], 'i': net['spikes.i', index]})
        fig = px.scatter(spikes, x='t', y='i', title=title)
        fig.update_traces(marker=dict(size=2, opacity=0.5))
        return fig

    inh_fig = _raster(0, "Raster Plot of Inhibitory Population")
    exc_fig = _raster(1, "Raster Plot of Source Excitatory Population")
    exc_dst_fig = _raster(2, "Raster Plot of Destination Excitatory Population")
    add_rectangles(exc_dst_fig, signal_orders, iter_duration, rest_duration, total_iter)

    for fig in (inh_fig, exc_fig, exc_dst_fig):
        fig.show()

## Define Network Parameters

In [7]:
# Seed torch's global RNG up front so the random input currents (and any later
# torch-based random initialisation) are identical on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Input-current parameters handed to InputGenerator below.
MEAN, STD, THRESHOLD = 25.0, 6.0, 0.5

# Stimulus timing; TOTAL_DURATION is the number of iterations passed to simulate().
SIGNAL_DURATION = 100
REST = 100
SIGNAL_REPEAT = 5
ITER_NO = 11
TARGET_NEURON_SIZE = 20
# One iteration = SIGNAL_REPEAT * (SIGNAL_DURATION + REST) + REST; the plotting
# cells recover that per-iteration length as TOTAL_DURATION // ITER_NO.
TOTAL_DURATION = ((SIGNAL_DURATION + REST) * SIGNAL_REPEAT + REST) * ITER_NO
inputs = InputGenerator(MEAN, STD, THRESHOLD)

# Per-population parameters for the Synapse behaviour
# (the meaning of 'D' is defined in requiredBehaviours — TODO document it here).
INHIBITORY_D = {'D': 1.0}
EXCITATORY_D = {'D': 0.70}
DST_D = {'D': 0.78}
TRACE_PARAMS = {'tau': 10.}
STDP_PARAMS = {'a_plus': 0.0380, 'a_minus': 0.0330, 'w_max': 1.10, 'w_min': 0.0, 'enable_soft_bound': True}
iSTDP_PARAMS = {'lr': 0.000024, 'freq': 5.0}
# Same bounds as STDP_PARAMS; passed as `clip_params` for every connection below.
CLIP_PARAMS = {'w_min': 0.0, 'w_max': 1.10}
KWTA_PARAMS = {'k': 5}

# NOTE(review): INHIBITORY_NEURON_SIZE, EXCITATORY_NEURON_SIZE and the *_CONFIG
# dicts are not defined in this notebook; they presumably come from
# `from requiredBehaviours import *` — consider defining them here explicitly.
INH_ZERO_INPUT = {"I": inputs.get_zero_input(INHIBITORY_NEURON_SIZE, TOTAL_DURATION)}
# NOTE(review): the meaning of the trailing 78 is not visible here — give it a named constant once confirmed.
EXC_RANDOM_INPUT = {"I": inputs.get_random_signals(ITER_NO, SIGNAL_DURATION, SIGNAL_REPEAT, REST, EXCITATORY_NEURON_SIZE, 78)}
TARGET_ZERO_INPUT = {"I": inputs.get_zero_input(TARGET_NEURON_SIZE, TOTAL_DURATION)}

## Network Structure

<div style="text-align: center;"><img src="schematic.png" alt="schematic" width="600"/></div>

## Unsupervised Learning with STDP

In [14]:
# STDP run: an inhibitory population, a source excitatory population driven by
# EXC_RANDOM_INPUT, and a destination excitatory population with k-WTA.
# Behaviour keys are kept exactly as they were; they presumably fix the
# execution order inside pymonntorch.
net_stdp = Network(settings=dict(device="cpu", dtype=torch.float32))

N_i = NeuronGroup(
    net=net_stdp,
    tag='Inhibitory_Population',
    size=INHIBITORY_NEURON_SIZE,
    behavior={
        1: LIF(**INH_CONFIG),
        2: Input(**INH_ZERO_INPUT),
        3: Synapse(**INHIBITORY_D),
        5: EventRecorder(['spikes']),
    },
)

N_e_src = NeuronGroup(
    net=net_stdp,
    tag='Excitatory_Population_Source',
    size=EXCITATORY_NEURON_SIZE,
    behavior={
        1: LIF(**SRC_CONFIG),
        2: Input(**EXC_RANDOM_INPUT),
        3: Synapse(**EXCITATORY_D),
        4: EventRecorder(['spikes']),
    },
)

N_e_dst = NeuronGroup(
    net=net_stdp,
    tag='Excitatory_Population_Dest',
    size=TARGET_NEURON_SIZE,
    behavior={
        1: LIF(**DST_CONFIG),
        2: Input(**TARGET_ZERO_INPUT),
        3: Synapse(**DST_D),
        4: KWTA(**KWTA_PARAMS),
        5: EventRecorder(['spikes']),
    },
)


In [15]:
# Connection spec consumed by Simulator. Each entry names a src/dst population
# index (presumably into the population lists passed to Simulator), the
# learning rule class, its parameters and the weight-clipping bounds.
# Inhibitory -> excitatory links use anti_STDP in this run.
connections = dict(
    same=dict(
        exc=[dict(src=0, dst=1, learning_rule=STDP,
                  learning_params=STDP_PARAMS, clip_params=CLIP_PARAMS)],
        inh=[dict(src=0, dst=0, learning_rule=STDP,
                  learning_params=STDP_PARAMS, clip_params=CLIP_PARAMS)],
    ),
    different=dict(
        exc_inh=[dict(src=0, dst=0, learning_rule=STDP,
                      learning_params=STDP_PARAMS, clip_params=CLIP_PARAMS)],
        inh_exc=[dict(src=0, dst=0, learning_rule=anti_STDP,
                      learning_params=STDP_PARAMS, clip_params=CLIP_PARAMS)],
    ),
)

excitatory_groups = [N_e_src, N_e_dst]
inhibitory_groups = [N_i]
simulate = Simulator(
    net_stdp,
    excitatory_groups,
    inhibitory_groups,
    connections=connections,
    trace_params=TRACE_PARAMS,
)

net_stdp = simulate.simulate(TOTAL_DURATION)

Network['Network'](Neurons: tensor(145)|3 groups, Synapses: tensor(7625)|4 groups){}
NeuronGroup['Inhibitory_Population', 'NeuronGroup', 'ng'](25){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]),)3:Synapse(D=1.0,)5:EventRecorder(variables=None,gap_width=0,max_length=None,auto_annotate=True,tag=None,arg_0=['spikes'],)}
NeuronGroup['Excitatory_Population_Source', 'NeuronGroup', 'ng'](100){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [26.4481, 19.5639,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000, 28.0499, 27.4424,  ...,  0.0000, 26.1878, 20.1689],
        ...,
        [ 0.0000,  0.0000,  0.0

In [16]:
plot_network(
    net_stdp,
    signal_orders=inputs.signal_orders,
    iter_duration=TOTAL_DURATION // ITER_NO,
    rest_duration=REST,
    total_iter=ITER_NO,
)

## Unsupervised Learning with iSTDP

In [8]:
# iSTDP run: same three populations as the STDP run, built on a fresh network.
# Behaviour keys are kept exactly as they were; they presumably fix the
# execution order inside pymonntorch.
net_istdp = Network(settings=dict(device="cpu", dtype=torch.float32))

N_i_istdp = NeuronGroup(
    net=net_istdp,
    tag='Inhibitory_Population',
    size=INHIBITORY_NEURON_SIZE,
    behavior={
        1: LIF(**INH_CONFIG),
        2: Input(**INH_ZERO_INPUT),
        3: Synapse(**INHIBITORY_D),
        5: EventRecorder(['spikes']),
    },
)

N_e_src_istdp = NeuronGroup(
    net=net_istdp,
    tag='Excitatory_Population_Source',
    size=EXCITATORY_NEURON_SIZE,
    behavior={
        1: LIF(**SRC_CONFIG),
        2: Input(**EXC_RANDOM_INPUT),
        3: Synapse(**EXCITATORY_D),
        4: EventRecorder(['spikes']),
    },
)

N_e_dst_istdp = NeuronGroup(
    net=net_istdp,
    tag='Excitatory_Population_Dest',
    size=TARGET_NEURON_SIZE,
    behavior={
        1: LIF(**DST_CONFIG),
        2: Input(**TARGET_ZERO_INPUT),
        3: Synapse(**DST_D),
        4: KWTA(**KWTA_PARAMS),
        5: EventRecorder(['spikes']),
    },
)


In [9]:
# Connection spec for the iSTDP run. It is identical to the STDP run except
# that inhibitory -> excitatory links learn with iSTDP and iSTDP_PARAMS.
connections_istdp = dict(
    same=dict(
        exc=[dict(src=0, dst=1, learning_rule=STDP,
                  learning_params=STDP_PARAMS, clip_params=CLIP_PARAMS)],
        inh=[dict(src=0, dst=0, learning_rule=STDP,
                  learning_params=STDP_PARAMS, clip_params=CLIP_PARAMS)],
    ),
    different=dict(
        exc_inh=[dict(src=0, dst=0, learning_rule=STDP,
                      learning_params=STDP_PARAMS, clip_params=CLIP_PARAMS)],
        inh_exc=[dict(src=0, dst=0, learning_rule=iSTDP,
                      learning_params=iSTDP_PARAMS, clip_params=CLIP_PARAMS)],
    ),
)

excitatory_groups_istdp = [N_e_src_istdp, N_e_dst_istdp]
inhibitory_groups_istdp = [N_i_istdp]
simulate_istdp = Simulator(
    net_istdp,
    excitatory_groups_istdp,
    inhibitory_groups_istdp,
    connections=connections_istdp,
    trace_params=TRACE_PARAMS,
)

net_istdp = simulate_istdp.simulate(TOTAL_DURATION)

Network['Network'](Neurons: tensor(145)|3 groups, Synapses: tensor(7625)|4 groups){}
NeuronGroup['Inhibitory_Population', 'NeuronGroup', 'ng'](25){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]),)3:Synapse(D=1.0,)5:EventRecorder(variables=None,gap_width=0,max_length=None,auto_annotate=True,tag=None,arg_0=['spikes'],)}
NeuronGroup['Excitatory_Population_Source', 'NeuronGroup', 'ng'](100){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [26.4481, 19.5639,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000, 28.0499, 27.4424,  ...,  0.0000, 26.1878, 20.1689],
        ...,
        [ 0.0000,  0.0000,  0.0

In [10]:
plot_network(
    net_istdp,
    signal_orders=inputs.signal_orders,
    iter_duration=TOTAL_DURATION // ITER_NO,
    rest_duration=REST,
    total_iter=ITER_NO,
)