# Neuronal Populations

In [1]:
import torch
import plotly.express as px
from pymonntorch import Behavior, SynapseGroup, Network, NeuronGroup, Recorder, EventRecorder

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Connectivity schemes accepted by the Synapse behavior's ``scheme`` parameter.
FULL, FCP, FNPP = (
    "Full",  # dense weight matrix between source and destination populations
    "Fixed Coupling Probability",  # sparsity controlled by the ``density`` parameter
    "Fixed Number of Presynaptic Partners",  # ``density`` = partners kept per postsynaptic neuron
)

## LIF Model

In [4]:
class LIF(Behavior):
    """
    Leaky integrate-and-fire neuron dynamics.

    The membrane potential follows
        tau * dv/dt = v_rest - v + R * I,
    advanced by one explicit Euler step (step size 1) per iteration; any neuron
    whose potential is >= threshold fires and is set back to v_reset.

    Args:
        tau (float): membrane time constant of the voltage decay.
        v_rest (float): resting potential.
        v_reset (float): potential a neuron is set to right after it fires.
        threshold (float): potential at or above which a neuron fires.
        R (float): membrane resistance.
    """

    def initialize(self, neurons):
        """
        Expose the behavior's parameters on the population and set initial state.

        Args:
            neurons (NeuronGroup): the neural population.
        """
        self.add_tag("LIF")
        self.set_parameters_as_variables(neurons)
        # Every neuron starts at its resting potential with no spike recorded.
        ones = neurons.vector(mode="ones")
        neurons.v = ones * neurons.v_rest
        neurons.spikes = neurons.vector(mode="zeros")

    def _dv_dt(self, neurons):
        """
        Evaluate the right-hand side of the LIF equation, i.e. tau * dv/dt.

        Args:
            neurons (NeuronGroup): the neural population.
        Returns:
            v_rest - v + R * I (the caller divides by tau).
        """
        leak = neurons.v_rest - neurons.v
        drive = neurons.R * neurons.I
        return leak + drive

    def _fire(self, neurons):
        """
        Detect spikes and reset the neurons that fired.

        Stores the mask ``v >= threshold`` in ``neurons.spikes`` and sets the
        potential of the masked neurons to v_reset.

        Args:
            neurons (NeuronGroup): the neural population.
        """
        fired = neurons.v >= neurons.threshold
        neurons.spikes = fired
        neurons.v[fired] = neurons.v_reset

    def forward(self, neurons):
        """
        Advance the population by one iteration.

        Spikes are evaluated on the potential left by the previous iteration,
        before this iteration's voltage update is applied.

        Args:
            neurons (NeuronGroup): the neural population.
        """
        self._fire(neurons)
        step = self._dv_dt(neurons) / neurons.tau
        neurons.v += step



## Synapse

In [3]:
class Synapse(Behavior):
    """
    Synaptic connectivity and current injection for a postsynaptic population.

    On initialization, the weight matrix ``W`` of every afferent synapse group
    is built according to ``scheme``. On every step, the weighted presynaptic
    spikes of synapse groups tagged ``GLUTAMATE`` are added to ``neurons.I``
    and those of groups tagged ``GABA`` are subtracted from it.

    Args:
        scheme (str): connectivity scheme, one of FULL, FCP or FNPP.
        mode (str): initialization mode forwarded to ``SynapseGroup.matrix``.
        coef (float): multiplicative factor applied to the raw weights.
        scale (float): additive offset applied to the raw weights.
        density: for FCP, forwarded to ``SynapseGroup.matrix``; for FNPP, the
            number of presynaptic partners each postsynaptic neuron keeps.
    """

    def _scaled_matrix(self, synapse, **matrix_kwargs):
        """Return ``coef * synapse.matrix(mode=mode, **matrix_kwargs) + scale``."""
        raw = synapse.matrix(mode=self.mode, **matrix_kwargs)
        return self.coef * raw + self.scale

    def _full_scheme(self, neurons):
        """Assign a scaled weight matrix to every afferent synapse group."""
        for s in neurons.afferent_synapses['All']:
            s.W = self._scaled_matrix(s)

    def _fixed_couple_probability(self, neurons):
        """Assign scaled weights whose sparsity is set by ``density``."""
        for s in neurons.afferent_synapses['All']:
            s.W = self._scaled_matrix(s, density=self.density)

    def _fixed_pre_partners(self, neurons):
        """
        Keep exactly ``density`` randomly chosen presynaptic partners for each
        postsynaptic neuron (or all of them if fewer exist) and zero the rest.

        Raises:
            ValueError: if ``density`` is missing or negative.
        """
        if self.density is None or self.density < 0:
            raise ValueError(
                f"scheme {FNPP!r} requires a non-negative 'density' "
                f"(number of presynaptic partners), got {self.density!r}"
            )
        for s in neurons.afferent_synapses['All']:
            s.W = self._scaled_matrix(s)
            # W rows index postsynaptic neurons, columns presynaptic ones.
            num_post_neurons, num_pre_neurons = s.W.shape[0], s.W.shape[1]
            num_keep = min(int(self.density), num_pre_neurons)
            for neuron_id in range(num_post_neurons):
                # randperm samples without replacement over ALL presynaptic
                # indices, so exactly num_pre - num_keep distinct weights are
                # zeroed (randint with an exclusive high of num_pre-1 both
                # skipped the last neuron and could repeat indices).
                dropped = torch.randperm(num_pre_neurons, device=s.W.device)[num_keep:]
                s.W[neuron_id, dropped] = 0.0

    def _set_synapse_weights(self, neurons):
        """
        Set synapse weights according to the configured scheme.

        Args:
            neurons (NeuronGroup): the neural population.
        Raises:
            ValueError: if ``scheme`` is not one of FULL, FCP or FNPP.
        """
        if self.scheme == FULL:
            self._full_scheme(neurons)
        elif self.scheme == FCP:
            self._fixed_couple_probability(neurons)
        elif self.scheme == FNPP:
            self._fixed_pre_partners(neurons)
        else:
            raise ValueError(f"Invalid scheme: {self.scheme!r}")

    def initialize(self, neurons):
        """
        Read synapse parameters and build the afferent weight matrices.

        Args:
            neurons (NeuronGroup): the neural population.
        """
        self.add_tag('Synapse')
        self.scheme = self.parameter('scheme', FULL)
        self.mode = self.parameter('mode', "uniform")
        self.coef = self.parameter('coef', 1.)
        self.scale = self.parameter('scale', 0.0)
        self.density = self.parameter('density', None)
        self._set_synapse_weights(neurons)

    def _get_presynaptic_inputs(self, synapse):
        """
        Compute the weighted presynaptic spike input of one synapse group.

        Args:
            synapse (SynapseGroup): the connections between src and dst neurons.
        Returns:
            ``W @ spikes`` with the source spikes cast to float.
        """
        spikes = synapse.src.spikes.float()
        return torch.matmul(synapse.W, spikes)

    def forward(self, neurons):
        """
        Apply excitatory (GLUTAMATE) and inhibitory (GABA) synaptic input.

        Groups are looked up by tag with an empty default, so a population
        with only one kind of afferent synapse does not raise KeyError.

        Args:
            neurons (NeuronGroup): the post synaptic neural population.
        """
        for s in neurons.afferent_synapses.get('GLUTAMATE', []):
            neurons.I += self._get_presynaptic_inputs(s)

        for s in neurons.afferent_synapses.get('GABA', []):
            neurons.I -= self._get_presynaptic_inputs(s)


## Input

In [None]:
class Input(Behavior):
    """
    Inject a precomputed external current into a neural population.

    Args:
        I: indexable sequence of currents; element ``iteration - 1`` is added
            to the freshly zeroed ``neurons.I`` on each step. When omitted
            (``None``), the population receives no external current instead
            of failing on the first step.
    """

    @staticmethod
    def _reset_inputs(neurons):
        """Set the input current of every neuron to zero."""
        neurons.I = neurons.vector(mode="zeros")

    def initialize(self, neurons):
        """
        Store the input sequence and zero the population's input current.

        Args:
            neurons (NeuronGroup): the neural population.
        """
        self.add_tag('Input')
        self.input = self.parameter('I', None)
        self._reset_inputs(neurons)

    def forward(self, neurons):
        """
        Zero ``neurons.I`` and add this iteration's external current, if any.

        Args:
            neurons (NeuronGroup): the neural population.
        """
        # I is rebuilt from zero every step so currents never accumulate.
        self._reset_inputs(neurons)
        if self.input is None:
            return
        # NOTE(review): assumes iteration counting starts at 1 — with 0 this
        # would silently read the last element; TODO confirm.
        neurons.I += self.input[neurons.iteration - 1]