In [1]:
import torch
import plotly.express as px
from pymonntorch import Behavior, SynapseGroup, Network, NeuronGroup, Recorder, EventRecorder
from requiredBehaviours import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def count_spikes(net):
    """Count how many spikes every neuron emitted, aligned to neuron index.

    The recorded spike indices are read through ``net['spikes.i', k]``;
    ``k = 0`` is reported as the inhibitory population and ``k = 1`` as the
    excitatory one.

    ``torch.unique(..., return_counts=True)`` only reports neurons that fired
    at least once, so its counts cannot be plotted against the neuron index:
    every silent neuron shifts the following entries. ``torch.bincount`` is
    used instead, so entry ``j`` of each result is the spike count of neuron
    ``j`` and silent neurons are reported as 0.

    Args:
        net: object supporting ``net['spikes.i', k]`` for ``k`` in ``{0, 1}``.
            If it also exposes ``net.NeuronGroups[k].size``, the result is
            padded with zeros up to that size.

    Returns:
        Tuple ``(inh_spikes, exc_spikes)`` of 1-D integer tensors.
    """
    def _dense_counts(population):
        # Flatten and cast: bincount requires a 1-D tensor of non-negative integers.
        fired = torch.as_tensor(net['spikes.i', population]).reshape(-1).to(torch.long)
        # NOTE(review): assumes net.NeuronGroups[k] is the group whose recorder
        # backs net['spikes.i', k] -- confirm if groups without a recorder exist.
        groups = getattr(net, 'NeuronGroups', None)
        if groups is not None and len(groups) > population:
            size = int(getattr(groups[population], 'size', 0))
        else:
            # Size unknown: the result still covers indices up to the last firing neuron.
            size = 0
        return torch.bincount(fired, minlength=size)

    return _dense_counts(0), _dense_counts(1)

def plot_balanced_network(net, iter):
    """Display raster plots and firing-rate bar charts for both populations.

    Four plotly figures are shown, in this order: inhibitory raster,
    excitatory raster, inhibitory firing-rate bars, excitatory firing-rate bars.

    Args:
        net: network whose recorded events are read through
            ``net['spikes.t', k]`` and ``net['spikes.i', k]`` (``k = 0`` is
            plotted as inhibitory, ``k = 1`` as excitatory).
        iter: value the spike counts from ``count_spikes`` are divided by to
            obtain the plotted fire rate.
    """
    inh_counts, exc_counts = count_spikes(net)

    # Raster plots: one marker per recorded (iteration, neuron index) event.
    raster_titles = (
        (0, "Raster Plot of Inhibitory Population"),
        (1, "Raster Plot of Excitatory Population"),
    )
    rasters = []
    for population, title in raster_titles:
        raster = px.scatter(
            x=net['spikes.t', population],
            y=net['spikes.i', population],
            title=title,
        )
        raster.update_layout(xaxis_title="iteration", yaxis_title="index")
        raster.update_traces(marker=dict(size=3, opacity=0.5))
        rasters.append(raster)

    # Bar charts: spike counts scaled by the number passed as `iter`.
    bar_titles = (
        (inh_counts, "Bar Plot of Inhibitory Neurons Fire Rate"),
        (exc_counts, "Bar Plot of Excitatory Neurons Fire Rate"),
    )
    rate_bars = []
    for counts, title in bar_titles:
        bars = px.bar(y=counts / iter, title=title)
        bars.update_layout(xaxis_title="index", yaxis_title="fire rate")
        rate_bars.append(bars)

    # Rasters first, then bar charts.
    for figure in rasters + rate_bars:
        figure.show()


In [3]:
# Arguments handed to InputGenerator below.
MEAN = 60.0
STD = 20.0
THRESHOLD = 0.5

# One network per experiment, both created with the same CPU / float32 settings.
net_stdp = Network(settings=dict(device="cpu", dtype=torch.float32))
net_istdp = Network(settings=dict(device="cpu", dtype=torch.float32))

inputs = InputGenerator(MEAN, STD, THRESHOLD)

# Keyword arguments unpacked into the behaviours / simulator in later cells.
NET_PARAMS = dict(D=0.3)
TRACE_PARAMS = dict(tau=10.0)
STDP_PARAMS = dict(a_plus=0.01, a_minus=0.01)
iSTDP_PARAMS = dict(lr=0.01, freq=5.0)
CLIP_PARAMS = dict(w_min=0.0, w_max=float("inf"))  # no upper bound on weights

# NOTE(review): the third argument (78 / 42) is presumably a random seed -- confirm
# against InputGenerator.get_random_input.
INH_RANDOM_INPUT = dict(I=inputs.get_random_input(INHIBITORY_NEURON_SIZE, ITER, 78))
EXC_RANDOM_INPUT = dict(I=inputs.get_random_input(EXCITATORY_NEURON_SIZE, ITER, 42))

In [4]:
# Neuron groups of the STDP network. Behaviours are keyed 1..n in list order.
N_i = NeuronGroup(
    net=net_stdp,
    tag='Inhibitory_Population',
    size=INHIBITORY_NEURON_SIZE,
    behavior=dict(enumerate([
        LIF(**INH_CONFIG),
        Input(**INH_RANDOM_INPUT),
        Synapse(**NET_PARAMS),
        Recorder(['I']),            # only the inhibitory group records 'I'
        EventRecorder(['spikes']),
    ], start=1)),
)

N_e = NeuronGroup(
    net=net_stdp,
    tag='Excitatory_Population',
    size=EXCITATORY_NEURON_SIZE,
    behavior=dict(enumerate([
        LIF(**SRC_CONFIG),
        Input(**EXC_RANDOM_INPUT),
        Synapse(**NET_PARAMS),
        EventRecorder(['spikes']),
    ], start=1)),
)

In [5]:
# Connection table for the STDP run: every pathway connects group 0 to group 0 of
# the respective population lists and learns with STDP_PARAMS; only the 'inh_exc'
# pathway uses anti_STDP instead of STDP.
balanced_connections = {
    "same": {
        pathway: [{
            "src": 0,
            "dst": 0,
            "learning_rule": STDP,
            "learning_params": STDP_PARAMS,
        }]
        for pathway in ("exc", "inh")
    },
    "different": {
        pathway: [{
            "src": 0,
            "dst": 0,
            "learning_rule": rule,
            "learning_params": STDP_PARAMS,
        }]
        for pathway, rule in (("exc_inh", STDP), ("inh_exc", anti_STDP))
    },
}

simulate = Simulator(
    net_stdp,
    [N_e],
    [N_i],
    connections=balanced_connections,
    trace_params=TRACE_PARAMS,
    clip_params=CLIP_PARAMS,
)

# The simulator's return value replaces the network handle used by later cells.
net_stdp = simulate.simulate(ITER)

Network['Network'](Neurons: tensor(125)|2 groups, Synapses: tensor(15625)|4 groups){}
NeuronGroup['Inhibitory_Population', 'NeuronGroup', 'ng'](25){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[ 0.0000,  0.0000,  0.0000,  ..., 47.0101,  0.0000,  0.0000],
        [ 0.0000,  0.0000, 61.5113,  ..., 75.7066,  0.0000, 56.3861],
        [88.7338, 37.2081,  0.0000,  ...,  0.0000, 67.9018, 75.9950],
        ...,
        [ 0.0000, 37.5195,  0.0000,  ..., 42.5209,  0.0000, 17.8967],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, 65.5090,  ...,  0.0000,  0.0000,  0.0000]]),)3:Synapse(D=0.3,)4:Recorder(arg_0=['I'],)5:EventRecorder(arg_0=['spikes'],)}
NeuronGroup['Excitatory_Population', 'NeuronGroup', 'ng'](100){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[78.2707, 63.1605,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000, 62.3728,  0.0000,  ...,  0.0000, 49.369

In [6]:
# Raster plots and firing-rate bars for the STDP network.
plot_balanced_network(net=net_stdp, iter=ITER)

In [7]:
# Neuron groups of the iSTDP network. Behaviours are keyed 1..n in list order.
N_i_istdp = NeuronGroup(
    net=net_istdp,
    tag='Inhibitory_Population',
    size=INHIBITORY_NEURON_SIZE,
    behavior=dict(enumerate([
        LIF(**INH_CONFIG),
        Input(**INH_RANDOM_INPUT),
        Synapse(**NET_PARAMS),
        Recorder(['I']),            # only the inhibitory group records 'I'
        EventRecorder(['spikes']),
    ], start=1)),
)

N_e_istdp = NeuronGroup(
    net=net_istdp,
    tag='Excitatory_Population',
    size=EXCITATORY_NEURON_SIZE,
    behavior=dict(enumerate([
        LIF(**SRC_CONFIG),
        Input(**EXC_RANDOM_INPUT),
        Synapse(**NET_PARAMS),
        EventRecorder(['spikes']),
    ], start=1)),
)

In [8]:
# Connection table for the iSTDP run: every pathway connects group 0 to group 0 of
# the respective population lists; only the 'inh_exc' pathway differs, learning
# with iSTDP / iSTDP_PARAMS instead of STDP / STDP_PARAMS.
balanced_connections = {
    "same": {
        pathway: [{
            "src": 0,
            "dst": 0,
            "learning_rule": STDP,
            "learning_params": STDP_PARAMS,
        }]
        for pathway in ("exc", "inh")
    },
    "different": {
        pathway: [{
            "src": 0,
            "dst": 0,
            "learning_rule": rule,
            "learning_params": params,
        }]
        for pathway, rule, params in (
            ("exc_inh", STDP, STDP_PARAMS),
            ("inh_exc", iSTDP, iSTDP_PARAMS),
        )
    },
}

simulate = Simulator(
    net_istdp,
    [N_e_istdp],
    [N_i_istdp],
    connections=balanced_connections,
    trace_params=TRACE_PARAMS,
    clip_params=CLIP_PARAMS,
)

# The simulator's return value replaces the network handle used by later cells.
net_istdp = simulate.simulate(ITER)

Network['Network'](Neurons: tensor(125)|2 groups, Synapses: tensor(15625)|4 groups){}
NeuronGroup['Inhibitory_Population', 'NeuronGroup', 'ng'](25){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[ 0.0000,  0.0000,  0.0000,  ..., 47.0101,  0.0000,  0.0000],
        [ 0.0000,  0.0000, 61.5113,  ..., 75.7066,  0.0000, 56.3861],
        [88.7338, 37.2081,  0.0000,  ...,  0.0000, 67.9018, 75.9950],
        ...,
        [ 0.0000, 37.5195,  0.0000,  ..., 42.5209,  0.0000, 17.8967],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, 65.5090,  ...,  0.0000,  0.0000,  0.0000]]),)3:Synapse(D=0.3,)4:Recorder(arg_0=['I'],)5:EventRecorder(arg_0=['spikes'],)}
NeuronGroup['Excitatory_Population', 'NeuronGroup', 'ng'](100){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[78.2707, 63.1605,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000, 62.3728,  0.0000,  ...,  0.0000, 49.369

In [9]:
# Raster plots and firing-rate bars for the iSTDP network.
plot_balanced_network(net=net_istdp, iter=ITER)

In [10]:
# Recorded spike events per iteration for recorder index 1 of the iSTDP network
# (the one plot_balanced_network labels as excitatory).
# fire_rate_time is a (iterations, counts) pair. Iterations in which nothing fired
# are absent from it, so the line is drawn straight across such gaps.
fire_rate_time = torch.unique(net_istdp['spikes.t', 1], return_counts=True)
fig = px.line(
    x=fire_rate_time[0],
    y=fire_rate_time[1],
    title="Number of Neurons that fired a spike in each iteration",
).update_layout(
    xaxis_title="iteration",
    yaxis_title="number of neurons that fired",
)
fig.show()