## Install Necessary Libraries

In [None]:
!pip install pymonntorch



## Define Behaviors

In [None]:
import torch
import torchvision
import torchvision.datasets as datasets
import pandas as pd
import plotly.express as px
from pymonntorch import Behavior, SynapseGroup, Network, NeuronGroup, Recorder, EventRecorder
import copy


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Keyword arguments for the Izhikevich behavior of the excitatory population.
# "u" is the initial value of the membrane recovery variable.
EXC_CONFIG = {
    "v_reset": -65.0,
    "threshold": 30.0,
    "a": 0.02,
    "b": 0.2,
    "d": 8.0,
    "u": -8.0,
}

# Same parameter set for the inhibitory population; only "a" and "d" differ.
INH_CONFIG = {
    "v_reset": -65.0,
    "threshold": 30.0,
    "a": 0.1,
    "b": 0.2,
    "d": 2.0,
    "u": -8.0,
}

# Time constants handed to the Trace behavior, kept separately for
# STDP-driven and iSTDP-driven connections.
STDP_TRACE_PARAMS = {"tau_plus": 20.0, "tau_minus": 20.0}
iSTDP_TRACE_PARAMS = {"tau_plus": 20.0, "tau_minus": 20.0}




"""
Implementation of Izhikevich neuron model.
"""


class Izhikevich(Behavior):
    """
    Izhikevich spiking neuron model, advanced by one explicit step per iteration.

    Dynamics:
        dv/dt = 0.04 * v**2 + 5 * v + 140 - u + I
        du/dt = a * (b * v - u)
        if v >= threshold: v <- v_reset and u <- u + d

    Args:
        v_reset (float): Voltage assigned after a spike; also the initial voltage.
        threshold (float): Voltage at or above which a neuron spikes.
        a (float): Time scale of the recovery variable.
        b (float): Sensitivity of the recovery variable to the subthreshold fluctuations of v.
        d (float): After-spike increment of the recovery variable.
        u (float): Initial value of the recovery variable.
    """

    def initialize(self, neurons):
        """Copy the parameters onto the population and create the state vectors."""
        super().initialize(neurons)
        self.add_tag("Izhikevich")
        self.set_parameters_as_variables(neurons)

        # Expand a, b and d to one entry per neuron so they can be mask-indexed.
        neurons.a *= neurons.vector(mode='ones')
        neurons.b *= neurons.vector(mode='ones')
        neurons.d *= neurons.vector(mode='ones')

        # Every neuron starts at the reset voltage with recovery value u.
        neurons.v = neurons.vector(mode="ones") * neurons.v_reset
        neurons.membrane_recovery = neurons.vector(mode="ones") * neurons.u
        neurons.spikes = neurons.vector(mode="zeros")

    def _dv_dt(self, neurons):
        """Return the voltage derivative for the current state and input current."""
        voltage = neurons.v
        recovery = neurons.membrane_recovery
        return 0.04 * voltage**2 + 5*voltage + 140 - recovery + neurons.I

    def _du_dt(self, neurons):
        """Return the derivative of the recovery variable for the current state."""
        drive = neurons.b * neurons.v - neurons.membrane_recovery
        return neurons.a * drive

    def _fire(self, neurons):
        """Flag neurons at or above threshold, then apply the after-spike reset."""
        fired = neurons.v >= neurons.threshold
        neurons.spikes = fired
        neurons.v[fired] = neurons.v_reset
        neurons.membrane_recovery[fired] += neurons.d[fired]

    def forward(self, neurons):
        """Detect spikes first, then step v and afterwards the recovery variable."""
        self._fire(neurons)
        # The recovery update below deliberately sees the already-updated voltage.
        neurons.v += self._dv_dt(neurons)
        neurons.membrane_recovery += self._du_dt(neurons)



"""
Implementation of Synapses.
"""
class Synapse(Behavior):
    """
    Neuron-group behavior that owns the afferent weights (and optional
    transmission delays) and adds the synaptic current to ``neurons.I``.

    Note that two different things are called ``D`` here: ``neurons.D`` is the
    connection density used when drawing weights, while ``synapse.D`` holds the
    per-connection delays.

    Parameters read in ``initialize``: ``coef`` (default 1.), ``has_delay``
    (default False), ``min_delay`` (default 1) and ``max_delay`` (default 20).
    """

    def _initialize_delays(self, neurons):
        """Draw an integer delay in [min_delay, max_delay) for every afferent connection."""
        for syn in neurons.afferent_synapses['All']:
            delays = torch.randint(low=self.min_delay, high=self.max_delay, size=syn.matrix_dim())
            syn.D = delays.to(DEVICE)

    def _set_synapse_weights(self, neurons):
        """
        Set synapse weights.
        Args:
            neurons (NeuronGroup): the neural population whose afferent synapses are initialized.
        """
        for syn in neurons.afferent_synapses['All']:
            base_weights = syn.matrix(mode='uniform', density=neurons.D)
            syn.W = base_weights * self.coef

    def initialize(self, neurons):
        """
        Set synapse parameters.
        Args:
            neurons (NeuronGroup): the neural population.
        """
        super().initialize(neurons)
        self.add_tag('Synapse')
        self.set_parameters_as_variables(neurons)

        self.coef = self.parameter('coef', 1.)
        self._set_synapse_weights(neurons)

        self.has_delay = self.parameter("has_delay", False)
        if not self.has_delay:
            return
        self.min_delay = self.parameter('min_delay', 1)
        self.max_delay = self.parameter('max_delay', 20)
        self._initialize_delays(neurons)

    def _get_presynaptic_inputs(self, synapse):
        """
        Calculate presynaptic inputs of population.
        Args:
            synapse (SynapseGroup): the connections between src and dst neurons.
        """
        if not self.has_delay:
            # No delays: a plain matrix-vector product with the current spikes.
            current_spikes = synapse.src.spikes.float()
            return torch.matmul(synapse.W, current_spikes)

        # With delays, each connection looks up the source spike from its own
        # delay slot in the source population's axon history.
        delayed_spikes = synapse.src.axon.get_spikes(synapse.D)
        return torch.sum(synapse.W * delayed_spikes, dim=1)

    def forward(self, neurons):
        """
        Implementation of both Excitatory and Inhibitory synaptic connections.
        Args:
            neurons (NeuronGroup): the post synaptic neural population.
        """
        afferents = neurons.afferent_synapses

        # Synapse groups tagged GLUTAMATE add current ...
        for excitatory in afferents.get('GLUTAMATE', []):
            neurons.I += self._get_presynaptic_inputs(excitatory)

        # ... and synapse groups tagged GABA subtract it.
        for inhibitory in afferents.get('GABA', []):
            neurons.I -= self._get_presynaptic_inputs(inhibitory)




class Input(Behavior):
    """
    Resets ``neurons.I`` every iteration and then adds the next externally
    supplied input.

    Args:
        I: An iterator that yields one input per iteration (default None).
    """

    @staticmethod
    def _reset_inputs(neurons):
        """Overwrite the input current with a fresh zero vector."""
        neurons.I = neurons.vector(mode="zeros")

    def initialize(self, neurons):
        """Store the input iterator and start from a zero current."""
        super().initialize(neurons)
        self.add_tag('Input')
        self.input = self.parameter('I', None)
        self._reset_inputs(neurons)

    def forward(self, neurons):
        """Clear the current, then add the next value drawn from the iterator."""
        self._reset_inputs(neurons)
        external_current = next(self.input)
        neurons.I += external_current






class Axon(Behavior):
    """
    Keeps a rolling boolean history of the population's spikes so that
    synapses can read delayed spikes through ``neurons.axon``.

    Args:
        max_delay: Length passed to ``neurons.vector_buffer`` (default None).
    """

    def initialize(self, neurons):
        """Allocate the spike history and register this behavior as ``neurons.axon``."""
        super().initialize(neurons)
        self.add_tag("Axon")
        self.max_delay = self.parameter('max_delay', None)
        self.spike_history = neurons.vector_buffer(self.max_delay, dtype=torch.bool)
        neurons.axon = self

    def get_spikes(self, delay):
        """Gather history entries along dimension 0 using ``delay`` as the index tensor."""
        return torch.gather(self.spike_history, 0, delay)

    def forward(self, neurons):
        """Roll the history buffer and insert the current spikes."""
        updated_history = neurons.buffer_roll(mat=self.spike_history, new=neurons.spikes)
        self.spike_history = updated_history



"""
Implementation of synaptic trace.
"""

class Trace(Behavior):
    """
    Maintains one spike trace per source neuron (``synapse.src_trace``) and one
    per destination neuron (``synapse.dst_trace``).

    The behavior's parameters are copied onto the synapse; ``tau_plus`` is used
    for the source trace and ``tau_minus`` for the destination trace.
    """

    def initialize(self, synapse):
        """
        Set trace parameters.
        Args:
            synapse (SynapseGroup): the connections between src and dst neurons.
        """
        super().initialize(synapse)
        self.add_tag("Trace")
        self.set_parameters_as_variables(synapse)
        synapse.src_trace = synapse.src.vector(mode="zeros")
        synapse.dst_trace = synapse.dst.vector(mode="zeros")

    @staticmethod
    def _get_trace_change(trace, tau, n):
        """
        Return one step of the trace dynamics dx/dt = -x/tau + spikes.
        Args:
            trace: current trace values of the population.
            tau: time constant of the decay.
            n (NeuronGroup): population whose ``spikes`` feed the trace.
        """
        decay = -1 * trace/tau
        return decay + n.spikes

    def _update_spike_trace(self, synapse):
        """
        Single step of spike trace dynamics.
        Args:
            synapse (SynapseGroup): the connection between src and dst neurons.
        """
        src_change = self._get_trace_change(synapse.src_trace, synapse.tau_plus, synapse.src)
        synapse.src_trace += src_change

        dst_change = self._get_trace_change(synapse.dst_trace, synapse.tau_minus, synapse.dst)
        synapse.dst_trace += dst_change

    def forward(self, synapse):
        """Advance both traces by one iteration."""
        self._update_spike_trace(synapse)




class STDP(Behavior):
    """
    Spike-Timing Dependent Plasticity (STDP) rule for simple connections.
    Note: The implementation uses local variables (spike trace), i.e. it reads
    ``src_trace`` and ``dst_trace`` from the synapse.
    Args:
        a_plus (float): Coefficient for the positive weight change. The default is None.
        a_minus (float): Coefficient for the negative weight change. The default is None.
        w_max (float): Upper bound used by the soft bound. The default is None.
        w_min (float): Lower bound used by the soft bound. The default is None.
        enable_soft_bound (bool): Scale the coefficients by the distance of W
            to its bounds. The default is False.
    """

    def initialize(self, synapse):
        """Read the learning parameters."""
        super().initialize(synapse)
        self.add_tag("STDP")
        self.enable_soft_bound = self.parameter('enable_soft_bound', False)
        self.w_max = self.parameter("w_max", None)
        self.w_min = self.parameter("w_min", None)
        self.a_plus = self.parameter("a_plus", None)
        self.a_minus = self.parameter("a_minus", None)

    def compute_coefs(self, synapse):
        """Return the (potentiation, depression) coefficients for this step."""
        if not self.enable_soft_bound:
            return self.a_plus, self.a_minus

        # Soft bound: potentiation shrinks near w_max, depression near w_min.
        potentiation = (self.w_max - synapse.W) * self.a_plus
        depression = (synapse.W - self.w_min) * self.a_minus
        return potentiation, depression

    def compute_dw(self, s):
        """Return the weight change: potentiation term minus depression term."""
        potentiation, depression = self.compute_coefs(s)
        # A source spike depresses in proportion to the destination trace.
        dw_minus = torch.outer(s.dst_trace, s.src.spikes) * depression
        # A destination spike potentiates in proportion to the source trace.
        dw_plus = torch.outer(s.dst.spikes, s.src_trace) * potentiation
        return dw_plus - dw_minus

    def forward(self, synapse):
        """Apply the weight change in place."""
        synapse.W += self.compute_dw(synapse)


class Anti_STDP(STDP):
    """STDP with the sign of the weight change inverted."""

    def forward(self, synapse):
        """Subtract the STDP weight change from the weights."""
        # compute_dw() already obtains the coefficients itself; the earlier
        # stand-alone compute_coefs() call discarded its result.
        synapse.W += (-1) * self.compute_dw(synapse)




"""
Reward-modulated Spike-Timing Dependent Plasticity (RSTDP) rule for simple connections.
"""

class RSTDP(STDP):
    """
    Reward-modulated STDP: the STDP change is accumulated in a decaying
    per-connection variable ``synapse.c`` and the weights change by ``c``
    scaled with ``synapse.network.dopamine_concentration``.

    Args (in addition to those of STDP):
        tau_c (float): Decay time constant of ``c``. The default is None.
        initial_c (float): Initial value of every entry of ``c``. The default is 0.0.
    """

    def initialize(self, synapse):
        """Read the parameters and allocate ``synapse.c``."""
        super().initialize(synapse)
        self.tau_c = self.parameter('tau_c', None)
        self.initial_c = self.parameter('initial_c', 0.0)
        # Allocate through the synapse group, as the weights are, instead of a
        # bare torch.ones(): that always lives on the CPU and cannot be combined
        # with tensors of a network running on another device.
        synapse.c = synapse.matrix(mode="ones") * self.initial_c

    def forward(self, synapse):
        """Decay ``c``, add the STDP change, then update W scaled by dopamine."""
        stdp = self.compute_dw(synapse)
        synapse.c += (-synapse.c / self.tau_c) + stdp
        # dopamine_concentration is read from the network object; it is not
        # set by this class.
        synapse.W += synapse.c * synapse.network.dopamine_concentration





"""
Implementation of inhibitory Spike-Time Dependent Plasticity (iSTDP).
"""


class iSTDP(Behavior):
    """
    Inhibitory STDP. On every iteration the weights change by:
        lr * (dst_trace - alpha)   for connections whose source neuron spiked,
        lr * src_trace             for connections whose destination neuron spiked,
    where lr is the learning rate and alpha is the depression factor
        alpha = 2 * freq * tau_plus / 1000
    The freq parameter acts as a target firing rate. The learning rule implements a form of
    homeostatic plasticity that stabilizes the postsynaptic firing rate.
    Args:
        lr (float): Learning rate.
        freq (float): Constant that determines post-synaptic firing rate.
    """

    def calculate_alpha(self, synapse):
        """Return the depression factor from ``freq`` and the synapse's ``tau_plus``."""
        # NOTE(review): the division by 1000 looks like a ms -> s conversion; confirm.
        return 2 * self.freq * synapse.tau_plus / 1000

    def initialize(self, synapse):
        """Read ``lr`` and ``freq`` and precompute alpha (needs ``synapse.tau_plus``)."""
        super().initialize(synapse)
        self.add_tag("iSTDP")
        self.lr = self.parameter('lr', None)
        self.freq = self.parameter('freq', None)
        self.alpha = self.calculate_alpha(synapse)

    def get_weight_changes(self, s):
        """Return the summed weight change caused by source and destination spikes."""
        shifted_dst_trace = s.dst_trace - self.alpha
        pre_spike_changes = self.lr * torch.outer(shifted_dst_trace, s.src.spikes)
        post_spike_changes = self.lr * torch.outer(s.dst.spikes, s.src_trace)
        return pre_spike_changes + post_spike_changes

    def forward(self, synapse):
        """Apply the weight change in place."""
        delta_w = self.get_weight_changes(synapse)
        synapse.W += delta_w






class WeightClip(Behavior):
    """
    Keeps the synaptic weights inside [w_min, w_max].

    Args:
        w_min: Lower bound, must be non-negative (default 0).
        w_max: Upper bound, must be larger than w_min (default 1).
    """

    def initialize(self, synapse):
        """Read the bounds and check that they form a valid range."""
        super().initialize(synapse)
        self.w_min = self.parameter('w_min', 0)
        self.w_max = self.parameter('w_max', 1)
        lower, upper = self.w_min, self.w_max
        assert lower >= 0 and lower < upper, "Invalid weight range!"

    def forward(self, synapse):
        """Replace W by a copy clamped to the configured bounds."""
        synapse.W = torch.clamp(synapse.W, min=self.w_min, max=self.w_max)





class KWTA(Behavior):
    """
    K-winners-take-all for spiking neurons.

    When more than k neurons are at or above threshold, only those with the
    largest margin above threshold keep their voltage; the other neurons that
    were about to spike are set to v_reset. Neurons below threshold are left
    untouched. Neurons tied with the k-th largest margin are all kept, so more
    than k neurons can survive on ties.

    Args:
        k (int): Number of winners. The default is None.
    """

    def initialize(self, neurons):
        """Read the number of winners."""
        super().initialize(neurons)
        self.k = self.parameter('k', None)

    def forward(self, neurons):
        """Reset the voltage of supra-threshold neurons that are not among the winners."""
        will_spike = neurons.v >= neurons.threshold

        # Nothing to suppress when at most k neurons are about to spike.
        if torch.sum(will_spike) <= self.k:
            return

        # Margin above threshold; zero for neurons that are not about to spike.
        will_spike_v = will_spike * (neurons.v - neurons.threshold)
        k_values, _ = torch.topk(will_spike_v, self.k)
        min_value = k_values.min()

        # Restrict the reset to neurons that were about to spike. Comparing
        # the margin alone would also match every sub-threshold neuron (margin
        # 0) and wrongly overwrite its voltage.
        losers = will_spike & (will_spike_v < min_value)
        neurons.v[losers] = neurons.v_reset





class SpikeRecorder(Behavior):
    """
    Copies the population's spikes into a caller-supplied array while the
    iteration counter is inside [start_time, finish_time).

    Args:
        start_time (int): First iteration that is recorded.
        finish_time (int): First iteration that is no longer recorded.
        array: Indexable as ``array[:, column]``; column ``i`` receives the
            spikes of iteration ``start_time + i``, so it needs at least
            ``finish_time - start_time`` columns.
    """

    def initialize(self, neurons):
        """Read the recording window and the output array."""
        super().initialize(neurons)
        self.add_tag('Recorder')
        self.start_time = self.parameter('start_time', None)
        self.finish_time = self.parameter('finish_time', None)
        self.spikes_array = self.parameter('array', None)

    def forward(self, neurons):
        """Store the current spikes in the column that belongs to this iteration."""
        if (self.start_time <= neurons.iteration) and (self.finish_time > neurons.iteration):
            # Offset from the window start. The former `iteration % start_time`
            # divided by zero for start_time == 0 and wrapped around whenever
            # the window was longer than start_time.
            index = neurons.iteration - self.start_time
            self.spikes_array[:, index] = neurons.spikes




class InputGenerator:
    """
    Factory for per-iteration input current generators.

    Args:
        mu: Stored as ``mean``; not read by the generators below.
        std: Stored as ``std``; not read by the generators below.
        threshold: Stored as ``threshold``; not read by the generators below.
        device: Device on which the yielded tensors are created.
    """

    def __init__(self, mu, std, threshold, device):
        self.mean = mu
        self.std = std
        self.threshold = threshold
        self.device = device
        self.signals = None

    def get_random_input(self, population_size, duration, seed):
        """
        Yield ``duration`` vectors of length ``population_size`` in which one
        randomly chosen entry is 20.0 and all others are 0.

        Args:
            population_size (int): Length of each yielded vector.
            duration (int): Number of vectors to yield.
            seed (int): Seed of the private random generator; equal seeds give
                equal sequences.
        """
        # A private generator keeps the sequence reproducible without
        # reseeding torch's global RNG from inside a lazily started generator,
        # and makes it independent of any other random draws in between.
        rng = torch.Generator()
        rng.manual_seed(seed)
        for _ in range(duration):
            target = int(torch.randint(low=0, high=population_size, size=(1,), generator=rng))
            stimulus = torch.zeros(population_size, device=self.device)
            stimulus[target] = 20.0
            yield stimulus

    def get_zero_input(self, population_size, duration):
        """Yield ``duration`` all-zero vectors of length ``population_size``."""
        for _ in range(duration):
            yield torch.zeros(population_size, device=self.device)






class Simulator:
    """
    Creates the SynapseGroups described by a connection map and runs the network.

    Args:
        net: The network the populations belong to.
        excitatory_pops (list): Excitatory populations, addressed by index.
        inhibitory_pops (list): Inhibitory populations, addressed by index.
        connections (dict): Connection lists under ``['same']['inh']``,
            ``['same']['exc']``, ``['different']['exc_inh']`` and
            ``['different']['inh_exc']``. Every entry holds the keys ``src``,
            ``dst``, ``learning_rule``, ``learning_params`` and ``clip_params``.
        trace_params_stdp (dict): Trace parameters for STDP-derived learning rules.
        trace_params_istdp (dict): Trace parameters for every other learning rule.
    """

    def __init__(self, net, excitatory_pops:list, inhibitory_pops:list, connections:dict,
                                trace_params_stdp:dict, trace_params_istdp:dict):

        self.excitatory_pops = excitatory_pops
        self.inhibitory_pops = inhibitory_pops
        self.connections = connections
        self.trace_params_stdp = trace_params_stdp
        self.trace_params_istdp = trace_params_istdp
        self.net = net

    def add_coonections(self, src_populations:list, dst_populations:list, connection_maps:list,
                        connection_tag:str):
        """
        Create one SynapseGroup per entry of ``connection_maps``.

        Args:
            src_populations (list): Populations indexed by each entry's ``src``.
            dst_populations (list): Populations indexed by each entry's ``dst``.
            connection_maps (list): Connection descriptions (see class docstring).
            connection_tag (str): Tag given to the created SynapseGroups.
        """
        for connection in connection_maps:
            src_pop = src_populations[connection['src']]
            dst_pop = dst_populations[connection['dst']]
            learning_rule, learning_params = connection['learning_rule'], connection['learning_params']
            clip_params = connection['clip_params']
            if learning_rule is None:
                # Static connection: no trace, learning or clipping behavior.
                behavior = {}
            else:
                # learning_rule is the rule's class, not an instance, so the
                # check must be issubclass(); isinstance() was always False and
                # silently gave every rule the iSTDP trace parameters.
                is_stdp_rule = isinstance(learning_rule, type) and issubclass(learning_rule, STDP)
                trace = self.trace_params_stdp if is_stdp_rule else self.trace_params_istdp
                behavior={
                    1: Trace(**trace),
                    2: learning_rule(**learning_params) ,
                    3: WeightClip(**clip_params)
                }
            SynapseGroup(net=self.net, src=src_pop, dst=dst_pop, tag=connection_tag, behavior=behavior)

    def set_coonections(self):
        """Create all four connection families with their synapse tags."""
        self.add_coonections(self.inhibitory_pops, self.inhibitory_pops, self.connections['same']['inh'], 'GABA')
        self.add_coonections(self.excitatory_pops, self.excitatory_pops, self.connections['same']['exc'], 'GLUTAMATE')
        self.add_coonections(self.excitatory_pops, self.inhibitory_pops, self.connections['different']['exc_inh'], 'GLUTAMATE')
        self.add_coonections(self.inhibitory_pops, self.excitatory_pops, self.connections['different']['inh_exc'], 'GABA')

    def simulate(self, iter):
        """
        Build the connections, initialize the network and run it.

        Args:
            iter (int): Number of iterations to simulate.

        Returns:
            The simulated network.
        """
        # Each call creates the SynapseGroups again, so call this once per network.
        self.set_coonections()
        self.net.initialize()
        self.net.simulate_iterations(iter)
        return self.net





## Plotting Functions

In [None]:
def rasterPlot(exc_records, inh_records, renderer='colab'):
    """
    Show a raster plot of recorded spikes.

    The two records are stacked along dimension 0 (excitatory rows first) and
    every non-zero entry becomes one point, with its column index on the x axis
    and its row index on the y axis.

    Args:
        exc_records (torch.Tensor): Spike record of the excitatory population.
        inh_records (torch.Tensor): Spike record of the inhibitory population;
            must match ``exc_records`` in every dimension except dimension 0.
        renderer (str): Plotly renderer used to display the figure. The default
            is 'colab', the value that used to be hard-coded.
    """
    all_records = torch.cat((exc_records, inh_records), dim=0)
    rows, cols = torch.where(all_records)
    # Hand plotly host-side arrays so records kept on another device plot too.
    raster_fig = px.scatter(x=cols.cpu().numpy(), y=rows.cpu().numpy(), title="Raster Plot of Neurons")
    raster_fig.show(renderer=renderer)

## Network Configuration

In [None]:
# Stored by InputGenerator; the two generator methods used below do not read them.
MEAN = 10.
STD = 5.
THRESHOLD = 0.2

# Number of input vectors each generator can yield; also passed to simulate().
TOTAL_DURATION = 5000

# Population sizes: 80% excitatory, 20% inhibitory.
TOTAL_NEURON_SIZE = 1000
EXC_NEURON_SIZE = int(TOTAL_NEURON_SIZE * 0.8)
INH_NEURON_SIZE = int(TOTAL_NEURON_SIZE * 0.2)

# Delay range shared by the Axon buffers and the Synapse delay sampling.
MIN_DELAY = 1
MAX_DELAY = 20
AXON = {"max_delay": MAX_DELAY}

# Synapse behavior parameters; 'D' is the connection density used for the weights.
INHIBITORY_SYNAPSE = {
    'D': 0.125,
    'has_delay': True,
    'min_delay': MIN_DELAY,
    'max_delay': MAX_DELAY,
}
EXCITATORY_SYNAPSE = {
    'D': 0.1,
    'has_delay': True,
    'min_delay': MIN_DELAY,
    'max_delay': MAX_DELAY,
}

# Learning-rule and weight-clipping parameters.
STDP_PARAMS = {
    'a_plus': 0.1,
    'a_minus': 0.12,
    'w_max': 2.0,
    'w_min': 0.0,
    'enable_soft_bound': False,
}
iSTDP_PARAMS = {'lr': 0.020, 'freq': 10.}
CLIP_PARAMS = {'w_min': 0.0, 'w_max': 2.0}

# Input sources: no external drive for the inhibitory population, one randomly
# chosen neuron driven per iteration (seed 42) for the excitatory population.
inputs = InputGenerator(MEAN, STD, THRESHOLD, DEVICE)
INH_ZERO_INPUT = {"I": inputs.get_zero_input(INH_NEURON_SIZE, TOTAL_DURATION)}
EXC_RANDOM_INPUT = {"I": inputs.get_random_input(EXC_NEURON_SIZE, TOTAL_DURATION, 42)}

In [None]:
# Spike-record buffers: one row per neuron, 1000 columns each.
EXC_1 = torch.zeros((EXC_NEURON_SIZE, 1000))
INH_1 = torch.zeros((INH_NEURON_SIZE, 1000))

EXC_100 = torch.zeros((EXC_NEURON_SIZE, 1000))
INH_100 = torch.zeros((INH_NEURON_SIZE, 1000))

EXC_3600 = torch.zeros((EXC_NEURON_SIZE, 1000))
INH_3600 = torch.zeros((INH_NEURON_SIZE, 1000))

In [None]:
# SpikeRecorder parameters: record iterations 1000 (inclusive) to 2000
# (exclusive) into the EXC_1 / INH_1 buffers.
RECORDER_1_EXC = {
    'start_time': 1 * 1000,
    'finish_time': 2 * 1000,
    'array': EXC_1,
}
RECORDER_1_INH = {
    'start_time': 1 * 1000,
    'finish_time': 2 * 1000,
    'array': INH_1,
}

## Network Structure

In [None]:
# Connection map consumed by Simulator. "src" and "dst" are indices into the
# population lists handed to the Simulator; "same" connects a population type
# to itself, "different" connects excitatory and inhibitory populations.
connections = {
    "same": {
        # excitatory -> excitatory, learned with STDP
        "exc": [
            {
                "src": 0,
                "dst": 0,
                "learning_rule": STDP,
                "learning_params": STDP_PARAMS,
                "clip_params": CLIP_PARAMS,
            },
        ],
        # no inhibitory -> inhibitory connections
        "inh": [],
    },
    "different": {
        # excitatory -> inhibitory, learned with STDP
        "exc_inh": [
            {
                "src": 0,
                "dst": 0,
                "learning_rule": STDP,
                "learning_params": STDP_PARAMS,
                "clip_params": CLIP_PARAMS,
            },
        ],
        # inhibitory -> excitatory, learned with iSTDP
        "inh_exc": [
            {
                "src": 0,
                "dst": 0,
                "learning_rule": iSTDP,
                "learning_params": iSTDP_PARAMS,
                "clip_params": CLIP_PARAMS,
            },
        ],
    },
}


In [None]:
# Build the network, attach both populations and run the simulation.
# NOTE(review): behaviors are presumably executed in ascending key order and
# 'DxS' presumably means (destination x source) weight matrices — confirm
# against the pymonntorch documentation.
net = Network(device=DEVICE, synapse_mode='DxS')

# Inhibitory population: no external input, spikes recorded into INH_1.
N_i = NeuronGroup(net=net, tag='Inhibitory_Population', size=INH_NEURON_SIZE, behavior={
      1: Izhikevich(**INH_CONFIG),
      2: Axon(**AXON),
      3: Input(**INH_ZERO_INPUT),
      4: Synapse(**INHIBITORY_SYNAPSE),
      5: SpikeRecorder(**RECORDER_1_INH)
})

# Excitatory population: random single-neuron input, spikes recorded into EXC_1.
N_e = NeuronGroup(net=net, tag='Excitatory_Population', size=EXC_NEURON_SIZE, behavior={
      1: Izhikevich(**EXC_CONFIG),
      2: Axon(**AXON),
      3: Input(**EXC_RANDOM_INPUT),
      4: Synapse(**EXCITATORY_SYNAPSE),
      5: SpikeRecorder(**RECORDER_1_EXC)
})

# The Simulator creates the SynapseGroups from `connections`, initializes the
# network and runs it for TOTAL_DURATION iterations.
simulate_net = Simulator(net, [N_e], [N_i], connections=connections,
                      trace_params_stdp=STDP_TRACE_PARAMS, trace_params_istdp=iSTDP_TRACE_PARAMS)
net = simulate_net.simulate(TOTAL_DURATION)

Network['Network_1', 'Network'](Neurons: tensor(1000)|2 groups, Synapses: tensor(960000)|3 groups){}
NeuronGroup['Inhibitory_Population', 'NeuronGroup', 'ng'](200){1:Izhikevich(v_reset=-65.0,threshold=30.0,a=0.1,b=0.2,d=2.0,u=-8.0,)2:Axon(max_delay=20,)3:Input(I=<generator object InputGenerator.get_zero_input at 0x7aafa1007290>,)4:Synapse(D=0.125,has_delay=True,min_delay=1,max_delay=20,)5:SpikeRecorder(start_time=1000,finish_time=2000,array=tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]),)}
NeuronGroup['Excitatory_Population', 'NeuronGroup', 'ng'](800){1:Izhikevich(v_reset=-65.0,threshold=30.0,a=0.02,b=0.2,d=8.0,u=-8.0,)2:Axon(max_delay=20,)3:Input(I=<generator object InputGenerator.get_random_input at 0x7aafa1007370>,)4:Synapse(D=0.1,has_delay=True,min_delay=1,max_delay=20,)5:Spik

In [None]:
# Raster plot of the spikes written into EXC_1 / INH_1 by the SpikeRecorder
# behaviors (iterations 1000 to 1999).
rasterPlot(EXC_1, INH_1)