In [1]:
import torch
import plotly.express as px
from pymonntorch import Behavior, SynapseGroup, Network, NeuronGroup, Recorder, EventRecorder

  from .autonotebook import tqdm as notebook_tqdm


## Plotting Functions

In [2]:
def count_spikes(net):
    """
    Count the recorded spikes of every neuron in both populations.

    Args:
        net: object indexable as ``net['spikes.i', k]``, yielding the neuron
            indices of all spikes stored by the k-th spike recorder.

    Returns:
        tuple: ``(inh_spikes, exc_spikes)`` taken from recorder 0 and
        recorder 1 respectively. Each tensor holds one count per *distinct*
        recorded index, in ascending index order; a neuron that never fired
        has no entry at all.
    """
    inh_spikes, exc_spikes = (
        torch.unique(net['spikes.i', recorder], return_counts=True)[1]
        for recorder in (0, 1)
    )
    return inh_spikes, exc_spikes

def get_population_activity(net):
    """
    Count how many spikes recorder 1 stored in each iteration.

    Args:
        net: object indexable as ``net['spikes.t', 1]``, yielding the
            iteration of every spike stored by the second spike recorder.

    Returns:
        tuple: ``(iterations, counts)`` - the sorted distinct iterations in
        which at least one spike was recorded and the number of spikes in
        each of them. Iterations without any spike are absent.
    """
    spike_times = net['spikes.t', 1]
    iterations, spikes_per_iteration = spike_times.unique(return_counts=True)
    return iterations, spikes_per_iteration

def plot_balanced_network(net, iter):
    """
    Show the five summary figures of a simulated two-population network.

    Figures are displayed in this order: inhibitory raster, excitatory
    raster, inhibitory fire-rate bars, excitatory fire-rate bars, and the
    per-iteration spike count taken from recorder 1.

    Args:
        net: simulated network; recorder 0 is plotted as the inhibitory
            population and recorder 1 as the excitatory one.
        iter (int): number of simulated iterations, used to turn spike
            counts into fire rates. (The name shadows the builtin inside
            this function; it is kept because it is part of the signature.)
    """
    inh_counts, exc_counts = count_spikes(net)
    activity_iterations, activity_counts = get_population_activity(net)

    # Raster plots: one dot per recorded spike.
    raster_figs = [
        px.scatter(x=net['spikes.t', 0], y=net['spikes.i', 0],
                   title="Raster Plot of Inhibitory Population"),
        px.scatter(x=net['spikes.t', 1], y=net['spikes.i', 1],
                   title="Raster Plot of Excitatory Population"),
    ]
    for fig in raster_figs:
        fig.update_layout(xaxis_title="iteration", yaxis_title="index")
        fig.update_traces(marker=dict(size=3, opacity=0.5))

    # Fire rate = spike count divided by the number of iterations.
    rate_figs = [
        px.bar(y=inh_counts / iter, title="Bar Plot of Inhibitory Neurons Fire Rate"),
        px.bar(y=exc_counts / iter, title="Bar Plot of Excitatory Neurons Fire Rate"),
    ]
    for fig in rate_figs:
        fig.update_layout(xaxis_title="index", yaxis_title="fire rate")

    activity_fig = px.line(
        x=activity_iterations,
        y=activity_counts,
        title="Number of Neurons that fired a spike in each iteration",
    )
    activity_fig.update_layout(xaxis_title="iteration", yaxis_title="number of neurons that fired")

    for fig in (*raster_figs, *rate_figs, activity_fig):
        fig.show()

## Connection Schemes

In [3]:
# Names of the supported connectivity schemes; the Synapse behavior compares
# its `scheme` parameter against these values.
FULL, FCP, FNPP = (
    "Full",
    "Fixed Coupling Probability",
    "Fixed Number of Presynaptic Partners",
)

## LIF Model

In [4]:
class LIF(Behavior):
    """
    Leaky integrate-and-fire neuron behavior.

    Membrane dynamics:
        tau * dv/dt = (v_rest - v) + R * I
    A neuron whose voltage is at or above `threshold` emits a spike and its
    voltage is set to `v_reset`.

    Args:
        tau (float): time constant of voltage decay.
        v_rest (float): voltage at rest.
        v_reset (float): value of voltage reset.
        threshold (float): the voltage threshold.
        R (float): the resistance of the membrane potential.
    """

    def initialize(self, neurons):
        """
        Tag the behavior and create the state variables of the population.

        Args:
            neurons (NeuronGroup): the neural population.
        """
        self.add_tag("LIF")
        # Makes the behavior parameters readable as neurons.<name> below.
        self.set_parameters_as_variables(neurons)
        neurons.spikes = neurons.vector(mode="zeros")
        # Every neuron starts at the resting potential.
        neurons.v = neurons.vector(mode="ones") * neurons.v_rest

    def _dv_dt(self, neurons):
        """
        Right-hand side of the voltage equation, not yet divided by tau.

        Args:
            neurons (NeuronGroup): the neural population.

        Returns:
            (v_rest - v) + R * I for every neuron.
        """
        leak = neurons.v_rest - neurons.v
        drive = neurons.R * neurons.I
        return leak + drive

    def _fire(self, neurons):
        """
        Mark the neurons at or above threshold as spiking and reset them.

        Args:
            neurons (NeuronGroup): the neural population.
        """
        fired = neurons.v >= neurons.threshold
        neurons.spikes = fired
        neurons.v[fired] = neurons.v_reset

    def forward(self, neurons):
        """
        One simulation step: detect spikes first, then integrate the voltage.

        The voltage is advanced by dv/dt with an implicit step size of 1.

        Args:
            neurons (NeuronGroup): the neural population.
        """
        self._fire(neurons)
        voltage_step = self._dv_dt(neurons) / neurons.tau
        neurons.v += voltage_step



## Synapse

In [5]:
class Synapse(Behavior):
    """
    Build the afferent weight matrices of a population and inject the
    presynaptic spikes into its input current.

    Args:
        scheme (str): connectivity scheme, one of FULL, FCP or FNPP.
            Defaults to FULL.
        mode (str): initialisation mode forwarded to ``SynapseGroup.matrix``.
            Defaults to "uniform".
        coef (float): factor the initial weights are multiplied by.
            Defaults to 1.
        scale (float): offset added to the weights after `coef`.
            Defaults to 0.
        density (float): for FCP it is forwarded to ``SynapseGroup.matrix``;
            for FNPP it is the fraction of presynaptic neurons every
            postsynaptic neuron stays connected to (required, in [0, 1]).
            Defaults to None.
    """

    def _full_scheme(self, neurons):
        """
        Create a weight matrix for every afferent synapse group.

        Args:
            neurons (NeuronGroup): the post synaptic neural population.
        """
        for s in neurons.afferent_synapses['All']:
            s.W = s.matrix(mode=self.mode)
            s.W = self.coef * s.W + self.scale

    def _fixed_couple_probability(self, neurons):
        """
        Create weight matrices with `density` forwarded to the initialiser.

        Args:
            neurons (NeuronGroup): the post synaptic neural population.
        """
        for s in neurons.afferent_synapses['All']:
            s.W = s.matrix(mode=self.mode, density=self.density)
            s.W = self.coef * s.W + self.scale

    def _fixed_pre_partners(self, neurons):
        """
        Give every postsynaptic neuron the same number of presynaptic partners.

        Starting from the full scheme, ``int(num_pre * (1 - density))``
        distinct presynaptic neurons are drawn for each postsynaptic neuron
        and their weights are set to zero.

        Args:
            neurons (NeuronGroup): the post synaptic neural population.

        Raises:
            ValueError: if `density` is missing or outside [0, 1].
        """
        if self.density is None or not 0.0 <= self.density <= 1.0:
            raise ValueError(
                f"'{FNPP}' needs a density in [0, 1], got {self.density!r}"
            )
        self._full_scheme(neurons)
        for s in neurons.afferent_synapses['All']:
            # Rows of W are postsynaptic neurons, columns presynaptic ones.
            num_post_neurons, num_pre_neurons = s.W.shape[0], s.W.shape[1]
            neurons_to_drop = int(num_pre_neurons * (1 - self.density))
            for neuron_id in range(num_post_neurons):
                # A random permutation yields distinct indices covering the
                # whole presynaptic range, so exactly `neurons_to_drop`
                # connections are removed in every row. (randint excludes its
                # upper bound and samples with replacement, which left the
                # last neuron always connected and the partner count varying.)
                dropped = torch.randperm(num_pre_neurons, device=s.W.device)[:neurons_to_drop]
                s.W[neuron_id, dropped] = 0.0

    def _set_synapse_weights(self, neurons):
        """
        Set synapse weights according to the configured scheme.

        Args:
            neurons (NeuronGroup): the post synaptic neural population.

        Raises:
            Exception: if the scheme is none of FULL, FCP and FNPP.
        """
        if self.scheme == FULL:
            self._full_scheme(neurons)
        elif self.scheme == FCP:
            self._fixed_couple_probability(neurons)
        elif self.scheme == FNPP:
            self._fixed_pre_partners(neurons)
        else:
            raise Exception("Invalide scheme!")

    def initialize(self, neurons):
        """
        Read the synapse parameters and build the afferent weights.

        Args:
            neurons (NeuronGroup): the post synaptic neural population.
        """
        self.add_tag('Synapse')
        self.scheme = self.parameter('scheme', FULL)
        self.mode = self.parameter('mode', "uniform")
        self.coef = self.parameter('coef', 1.)
        self.scale = self.parameter('scale', 0.0)
        self.density = self.parameter('density', None)
        self._set_synapse_weights(neurons)

    def _get_presynaptic_inputs(self, synapse):
        """
        Calculate the input one synapse group delivers to its destination.

        Args:
            synapse (SynapseGroup): the connections between src and dst neurons.

        Returns:
            The product of the weight matrix and the source spike vector,
            i.e. one summed weight per destination neuron.
        """
        spikes = synapse.src.spikes.float()
        return torch.matmul(synapse.W, spikes)

    def forward(self, neurons):
        """
        Apply excitatory and inhibitory synaptic input to the population.

        Synapse groups tagged 'GLUTAMATE' add to the input current, groups
        tagged 'GABA' subtract from it.

        Args:
            neurons (NeuronGroup): the post synaptic neural population.
        """
        for s in neurons.afferent_synapses['GLUTAMATE']:
            neurons.I += self._get_presynaptic_inputs(s)

        for s in neurons.afferent_synapses['GABA']:
            neurons.I -= self._get_presynaptic_inputs(s)


## Input

In [6]:
class Input(Behavior):
    """
    Feed a pre-computed external current into the population.

    Args:
        I: currents indexed by iteration; entry ``iteration - 1`` is used in
            each step (presumably the iteration counter starts at 1 - confirm
            against pymonntorch). Defaults to None.
    """

    @staticmethod
    def _reset_inputs(neurons):
        """Replace the input current of the population with a zero vector."""
        neurons.I = neurons.vector(mode="zeros")

    def initialize(self, neurons):
        """
        Store the external current and start from a zero input.

        Args:
            neurons (NeuronGroup): the neural population.
        """
        self.add_tag('Input')
        self.input = self.parameter('I', None)
        self._reset_inputs(neurons)

    def forward(self, neurons):
        """
        Discard the previous input and load the current of this iteration.

        Args:
            neurons (NeuronGroup): the neural population.
        """
        step = neurons.iteration - 1
        self._reset_inputs(neurons)
        neurons.I += self.input[step]

## Set Parameters

In [60]:
# Seed PyTorch's global RNG so the random draws below are repeatable.
torch.manual_seed(42)
net = Network(settings={"device": "cpu", "dtype":torch.float32})
TOTAL_NEURON_SIZE = 100
# 20% of the neurons are inhibitory. The excitatory population takes the
# remainder, so the two sizes add up to TOTAL_NEURON_SIZE for every total
# (truncating 0.2 * N and 0.8 * N independently can lose a neuron).
INHIBITORY_NEURON_SIZE = int(0.2 * TOTAL_NEURON_SIZE)
EXCITATORY_NEURON_SIZE = TOTAL_NEURON_SIZE - INHIBITORY_NEURON_SIZE
ITER = 1000
# Mean and standard deviation of the external current of excitatory neurons.
MEAN, STD = 9., 3.
# One row of external current per iteration, one column per neuron.
EXCITATORY_INPUT = {"I": torch.normal(MEAN, STD, size=(ITER, EXCITATORY_NEURON_SIZE))}
# Inhibitory neurons receive no external current.
INHIBITORY_INPUT = {"I": torch.zeros(ITER, INHIBITORY_NEURON_SIZE)}
# LIF parameters of the excitatory population.
EXC_CONFIG = {
    "v_reset" : -65.0,
    "v_rest": -65.0,
    "tau" : 10.,
    "R" : 2.,
    "threshold" : -55.,
}
# LIF parameters of the inhibitory population (currently the same values).
INH_CONFIG = {
    "v_reset" : -65.0,
    "v_rest": -65.0,
    "tau" : 10.,
    "R" : 2.,
    "threshold" : -55.,
}

In [61]:
# Parameters shared by the Synapse behavior of both populations.
SYNAPSE_CONFIG = dict(
    scheme=FNPP,
    mode="normal(mean=6, std=1)",
    density=0.1,
)

In [62]:

# Inhibitory population. The integer keys of `behavior` presumably set the
# execution order of the behaviors - confirm against the pymonntorch docs.
# This group is created first; the plotting helpers read its spike recorder
# with index 0.
N_i = NeuronGroup(net=net, tag='Inhibitory_Population', size=INHIBITORY_NEURON_SIZE, behavior={
        1: LIF(**INH_CONFIG),
        2: Input(**INHIBITORY_INPUT),
        3: Synapse(**SYNAPSE_CONFIG),
        5: EventRecorder(['spikes'])
    })

# Excitatory population; the plotting helpers read its spike recorder with
# index 1. NOTE(review): the recorder uses key 4 here but key 5 above.
N_e = NeuronGroup(net=net, tag='Excitatory_Population', size=EXCITATORY_NEURON_SIZE, behavior={
        1: LIF(**EXC_CONFIG),
        2: Input(**EXCITATORY_INPUT),
        3: Synapse(**SYNAPSE_CONFIG),
        4: EventRecorder(['spikes'])
    })

In [63]:
# Connect the two populations in all four directions. Synapse.forward adds
# the input of groups tagged "GLUTAMATE" and subtracts that of groups tagged
# "GABA", so connections leaving N_e excite and connections leaving N_i inhibit.
SynapseGroup(net=net, src=N_e, dst=N_i, tag="GLUTAMATE")
SynapseGroup(net=net, src=N_e, dst=N_e, tag="GLUTAMATE")
SynapseGroup(net=net, src=N_i, dst=N_e, tag="GABA")
SynapseGroup(net=net, src=N_i, dst=N_i, tag="GABA")

SynapseGroup['GABA', 'SynapseGroup', 'syn', 'Inhibitory_Population => Inhibitory_Population'](D20xS20){}

In [64]:
# Initialise the network (presumably this runs every behavior's `initialize`
# - confirm against pymonntorch), then simulate ITER iterations.
net.initialize()
net.simulate_iterations(ITER)

Network['Network'](Neurons: tensor(100)|2 groups, Synapses: tensor(10000)|4 groups){}
NeuronGroup['Inhibitory_Population', 'NeuronGroup', 'ng'](20){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]),)3:Synapse(scheme=Fixed Number of Presynaptic Partners,mode=normal(mean=6, std=1),density=0.1,)5:EventRecorder(variables=None,gap_width=0,max_length=None,auto_annotate=True,tag=None,arg_0=['spikes'],)}
NeuronGroup['Excitatory_Population', 'NeuronGroup', 'ng'](80){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[14.7807, 13.4619, 11.7022,  ...,  7.0748,  2.3808,  6.7476],
        [ 9.0326,  7.9838,  4.9780,  ...,  7.7474,  5.4136, 11.4370],
        [ 3.2983,  9.6857,  9.0746,  ...,  9

181.90383911132812

In [65]:
# Show the raster, fire-rate and population-activity figures of the run;
# ITER converts spike counts into fire rates.
plot_balanced_network(net, ITER)