## Import Necessary Libraries

In [None]:
import torch
import plotly.express as px
from pymonntorch import Behavior, SynapseGroup, Network, NeuronGroup, Recorder, EventRecorder
from requiredBehaviours import *

## Plotting Functions

In [2]:
def add_rectangles(fig, signal_orders, iter_duration, rest_duration, total_iter, y0=0, y1=1):
    """
    Shade each iteration's time window on a figure according to its signal.

    One vertical band is added per iteration ``i`` in ``range(total_iter)``.
    The band starts at ``i * iter_duration`` and ends ``2 * rest_duration``
    steps before the next iteration starts. It is purple (``'#9467bd'``) when
    ``signal_orders[i]`` is truthy and orange otherwise.

    Args:
        fig: Plotly figure (anything providing ``add_vrect``).
        signal_orders (Sequence): Per-iteration signal indicator; only its
            truthiness is used.
        iter_duration (int): Number of time steps per iteration.
        rest_duration (int): Rest length; ``2 * rest_duration`` steps at the
            end of every iteration are left unshaded.
        total_iter (int): Number of iterations to shade; must not exceed
            ``len(signal_orders)``.
        y0 (float, optional): Lower edge of each band as a fraction of the
            y-axis height. Defaults to 0 (bottom of the plot).
        y1 (float, optional): Upper edge of each band as a fraction of the
            y-axis height. Defaults to 1 (top of the plot).
    """
    for i in range(total_iter):
        color = '#9467bd' if signal_orders[i] else 'orange'
        x_start = i * iter_duration
        x_end = x_start + iter_duration - 2 * rest_duration
        # add_vrect interprets y0/y1 relative to the y-axis domain (0..1 = full
        # height), so values larger than 1 would draw the band above the plot.
        fig.add_vrect(
            x0=x_start, x1=x_end,
            y0=y0, y1=y1,
            fillcolor=color, opacity=0.25,
            layer="below", line_width=0,
        )
    


def plot_network(net, dopamine, signal_orders, iter_duration, rest_duration, total_iter):
    """
    Display spike raster plots of the four populations and the dopamine trace.

    Spike times and neuron indices are read with ``net['spikes.t', k]`` and
    ``net['spikes.i', k]`` for ``k`` in 0..3. They are labelled as the
    inhibitory, source excitatory, destination 0 and destination 1
    populations. NOTE(review): this presumably matches the order in which the
    neuron groups were created; confirm if groups are added or reordered.
    The two destination plots are additionally shaded per iteration with
    ``add_rectangles``.

    Args:
        net (Network): Simulated network whose groups record ``spikes`` events.
        dopamine (array-like or torch.Tensor): Dopamine concentration per step.
        signal_orders (Sequence): Per-iteration signal indicator forwarded to
            ``add_rectangles``.
        iter_duration (int): Number of time steps per iteration.
        rest_duration (int): Rest length forwarded to ``add_rectangles``.
        total_iter (int): Number of iterations to shade.
    """
    import pandas as pd  # not imported explicitly in the notebook's import cell

    def _as_array(values):
        # Give pandas/plotly plain numpy arrays instead of (possibly GPU) tensors.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return values

    # (recorder index, plot title, shade iterations?) in display order.
    populations = (
        (0, "Raster Plot of Inhibitory Population", False),
        (1, "Raster Plot of Source Excitatory Population", False),
        (2, "Raster Plot of Destination Excitatory Population 0", True),
        (3, "Raster Plot of Destination Excitatory Population 1", True),
    )
    for index, title, shade_iterations in populations:
        spikes_df = pd.DataFrame({
            't': _as_array(net['spikes.t', index]),
            'i': _as_array(net['spikes.i', index]),
        })
        fig = px.scatter(spikes_df, x='t', y='i', title=title)
        fig.update_traces(marker=dict(size=2, opacity=0.5))
        if shade_iterations:
            add_rectangles(fig, signal_orders, iter_duration, rest_duration, total_iter)
        fig.show()

    dopamine_fig = px.line(_as_array(dopamine), title="Dopamine")
    dopamine_fig.update_layout(xaxis_title="iteration", yaxis_title="Dopamine")
    dopamine_fig.show()

## Define Reward/Punishment Function

In [3]:
class Payoff(Behavior):
    """
    Network behavior that computes the payoff (reward/punishment) every iteration.

    The payoff is ``signal_value`` times the spike-count difference between the
    two destination populations, divided by ``len(spikes)`` of the first
    population in the comparison. The comparison is ordered so the payoff is
    positive when the population matching the current signal fires more:
    signal 0 favors ``Excitatory_Population_Dest_0``; any other signal favors
    ``Excitatory_Population_Dest_1``.

    Args:
        payoff (float, optional): Initial value stored on ``network.payoff``.
            Defaults to 0.0.
        signal_value (float, optional): Scale applied to the normalized
            spike-count difference. Defaults to 0.1.
        signal_no (Sequence): Signal identifier for every iteration, indexed
            with ``network.iteration - 1``. Required.

    Raises:
        ValueError: If ``signal_no`` is not provided.
    """

    def initialize(self, network):
        """
        Set the initial payoff value and read the behavior parameters.

        Args:
            network (Network): Network object.
        """
        self.add_tag("Payoff")
        network.payoff = self.parameter('payoff', 0.0)
        self.signal_value = self.parameter('signal_value', 0.1)
        self.signal_no = self.parameter('signal_no', None)
        if self.signal_no is None:
            # Fail at setup rather than with an opaque TypeError on the first forward().
            raise ValueError("Payoff requires a 'signal_no' parameter with one signal per iteration.")

    def _calculate_reward_punishment(self, ng_0, ng_1):
        """
        Return ``signal_value * (sum(ng_0.spikes) - sum(ng_1.spikes)) / len(ng_0.spikes)``.

        Args:
            ng_0 (NeuronGroup): Population whose spikes count positively.
            ng_1 (NeuronGroup): Population whose spikes count negatively.
        """
        # NOTE(review): assumes `spikes` holds one entry per neuron, so the
        # division normalizes by the size of ng_0.
        spike_diff_prop = (torch.sum(ng_0.spikes) - torch.sum(ng_1.spikes)) / len(ng_0.spikes)
        return self.signal_value * spike_diff_prop

    def _get_reward_punishment(self, network):
        """
        Compute the payoff for the current iteration from the active signal.

        Args:
            network (Network): Network containing both destination populations.
        """
        ng0 = network.find_objects(key="Excitatory_Population_Dest_0")[0]
        ng1 = network.find_objects(key="Excitatory_Population_Dest_1")[0]
        # NOTE(review): assumes network.iteration is 1 on the first forward() call.
        if self.signal_no[network.iteration - 1] == 0:
            return self._calculate_reward_punishment(ng0, ng1)
        return self._calculate_reward_punishment(ng1, ng0)

    def forward(self, network):
        """
        Update ``network.payoff`` for the current iteration.

        Args:
            network (Network): Network object.
        """
        network.payoff = self._get_reward_punishment(network)
        
       

## Define `Dopamine` Behavior

In [4]:
class Dopamine(Behavior):
    """
    Network behavior that tracks the extracellular dopamine concentration.

    The concentration ``d`` follows ``dd/dt = -d / tau_dopamine + payoff(t)``,
    integrated with one explicit Euler step of size 1 per iteration. The payoff
    is read from ``network.payoff``, so a payoff behavior must already have
    set it when this behavior initializes.

    Args:
        tau_dopamine (float, optional): Dopamine decay time constant.
            Defaults to 1.0.
        initial_dopamine_concentration (float, optional): Initial concentration.
            Defaults to the value of ``network.payoff`` at initialization.
        array (torch.Tensor, optional): Preallocated buffer that receives the
            concentration after every iteration at index
            ``network.iteration - 1``. It is written in place, so reusing one
            buffer for several networks overwrites earlier traces. When None
            (the default), no trace is recorded.
    """

    def initialize(self, network):
        """
        Set the initial dopamine concentration, by default from the current payoff.

        Args:
            network (Network): Network object.
        """
        self.add_tag("Dopamine")

        network.tau_dopamine = self.parameter("tau_dopamine", 1.0)
        self.dopamine_array = self.parameter('array', None)
        network.dopamine_concentration = self.parameter(
            "initial_dopamine_concentration", network.payoff
        )

    def forward(self, network):
        """
        Advance the dopamine concentration by one step:

        dd/dt = -d/tau_d + payoff(t).

        Args:
            network (Network): Network object.
        """
        dd_dt = (
            -(network.dopamine_concentration / network.tau_dopamine) + network.payoff
        )
        # Euler step with dt = 1 (one simulation iteration).
        network.dopamine_concentration += dd_dt
        # Recording is optional: only write when a buffer was supplied.
        if self.dopamine_array is not None:
            self.dopamine_array[network.iteration - 1] = network.dopamine_concentration

## Define Network Parameters

In [293]:
# Settings passed straight to InputGenerator below.
MEAN, STD, THRESHOLD = 25.0, 6.0, 0.5 
# Schedule lengths; they combine into TOTAL_DURATION, the number of iterations simulate() runs.
SIGNAL_DURATION = 100
REST = 0
SIGNAL_REPEAT = 20
ITER_NO = 2
TARGET_NEURON_SIZE = 10
# Each of the ITER_NO segments lasts (SIGNAL_DURATION + REST) * SIGNAL_REPEAT + REST steps.
TOTAL_DURATION = ((SIGNAL_DURATION + REST) * SIGNAL_REPEAT + REST) * ITER_NO
# One slot per simulation step; Dopamine.forward writes into this tensor in place.
# NOTE(review): DOPAMINE_PARAMS (hence this single tensor) is reused by both networks
# below, so the second simulation overwrites the first network's dopamine trace.
DOPAMINE_ARRAY = torch.zeros(TOTAL_DURATION)
inputs = InputGenerator(MEAN, STD, THRESHOLD)
# Parameters for Synapse(D=...) of each population.
INHIBITORY_D = {'D':1.0}
EXCITATORY_D = {'D':0.70}
DST_D = {'D':0.75} 
TRACE_PARAMS = {'tau': 10.}
RSTDP_PARAMS = {'a_plus': 0.110, 'a_minus': 0.0710, 'w_max':2.0, 'w_min':0.0, "initial_c":0.0, 'tau_c':1500., 'enable_soft_bound':True} 
DOPAMINE_PARAMS = {'tau_dopamine': 10., "initial_dopamine_concentration": 5.0, "array":DOPAMINE_ARRAY}
STDP_PARAMS = {'a_plus': 0.0380, 'a_minus': 0.0330, 'w_max':1.0, 'w_min':0.0,'enable_soft_bound':True}
iSTDP_PARAMS = {'lr': 0.01819, 'freq': 10.}
CLIP_PARAMS = {'w_min': 0.0, 'w_max':1.0}
RSTDP_CLIP_PARAMS = {'w_min': 0.0, 'w_max':2.0}
# NOTE(review): INHIBITORY_NEURON_SIZE / EXCITATORY_NEURON_SIZE are not defined in this cell;
# presumably they come from `from requiredBehaviours import *` — confirm.
INH_ZERO_INPUT = {"I": inputs.get_zero_input(INHIBITORY_NEURON_SIZE, TOTAL_DURATION)}
EXC_RANDOM_INPUT = {"I": inputs.get_random_signals(ITER_NO, SIGNAL_DURATION, SIGNAL_REPEAT, REST, EXCITATORY_NEURON_SIZE, 42)}
TARGET_ZERO_INPUT = {"I": inputs.get_zero_input(TARGET_NEURON_SIZE, TOTAL_DURATION)}
# NOTE(review): random_signal_no is presumably populated by the get_random_signals call
# above, so keep this line after it.
PAYOFF_PARAMS = {'signal_no': inputs.random_signal_no, 'signal_value':3.5}

## Network Structure

<div style="text-align: center;"><img src="RSTDP_schematic.png" alt="schematic" width="600"/></div>

## R-STDP with STDP

In [204]:
# Network-level behaviors: Payoff sets network.payoff, which Dopamine reads in both
# initialize() and forward(); presumably the lower behavior key (1) runs first.
net_stdp = Network(settings={"device": "cpu", "dtype":torch.float32}, behavior={
    1: Payoff(**PAYOFF_PARAMS),
    2: Dopamine(**DOPAMINE_PARAMS)
})


# NOTE(review): the creation order of the four groups below presumably fixes the
# recorder indices plot_network reads (0 inhibitory, 1 source, 2 dest 0, 3 dest 1).
N_i_stdp = NeuronGroup(net=net_stdp, tag='Inhibitory_Population', size=INHIBITORY_NEURON_SIZE, behavior={
        1: LIF(**INH_CONFIG),
        2: Input(**INH_ZERO_INPUT),
        3: Synapse(**INHIBITORY_D),
        # NOTE(review): key 5 here vs 4 in the other groups; presumably only relative order matters.
        5: EventRecorder(['spikes'])
    })


N_e_src_stdp = NeuronGroup(net=net_stdp, tag='Excitatory_Population_Source', size=EXCITATORY_NEURON_SIZE, behavior={
        1: LIF(**SRC_CONFIG),
        2: Input(**EXC_RANDOM_INPUT),
        3: Synapse(**EXCITATORY_D),
        4: EventRecorder(['spikes'])
    })

# The tags 'Excitatory_Population_Dest_0' / '_1' are looked up by Payoff via find_objects.
N_e_dst_1_stdp = NeuronGroup(net=net_stdp, tag='Excitatory_Population_Dest_0', size=TARGET_NEURON_SIZE, behavior={
        1: LIF(**DST_CONFIG),
        2: Input(**TARGET_ZERO_INPUT),
        3: Synapse(**DST_D),
        4: EventRecorder(['spikes'])
    })

N_e_dst_2_stdp = NeuronGroup(net=net_stdp, tag='Excitatory_Population_Dest_1', size=TARGET_NEURON_SIZE, behavior={
        1: LIF(**DST_CONFIG),
        2: Input(**TARGET_ZERO_INPUT),
        3: Synapse(**DST_D),
        4: EventRecorder(['spikes'])
    })


In [205]:
# Connectivity for the R-STDP + STDP variant.
# NOTE(review): "src"/"dst" presumably index the excitatory list ([source, dest 0, dest 1])
# and the inhibitory list ([inhibitory]) passed to Simulator below — confirm in Simulator.
connections_stdp = {
    "same":{
        # Presumably source -> dest 0 and source -> dest 1, both learned with R-STDP.
        "exc":[{
            "src":0,
            "dst":1,
            "learning_rule":RSTDP,
            "learning_params":RSTDP_PARAMS,
            "clip_params": RSTDP_CLIP_PARAMS
        }, {
            "src":0,
            "dst":2,
            "learning_rule":RSTDP,
            "learning_params":RSTDP_PARAMS,
            "clip_params": RSTDP_CLIP_PARAMS
        }],
        "inh":[{
            "src":0,
            "dst":0,
            "learning_rule":STDP,
            "learning_params":STDP_PARAMS,
            "clip_params": CLIP_PARAMS
        }]
    },
    "different":{
        "exc_inh":[{
            "src":0,
            "dst":0,
            "learning_rule":STDP,
            "learning_params":STDP_PARAMS,
            "clip_params": CLIP_PARAMS
        }],
        # The only entry that differs from connections_istdp: anti_STDP here, iSTDP there.
        "inh_exc":[{
            "src":0,
            "dst":0,
            "learning_rule":anti_STDP,
            "learning_params":STDP_PARAMS,
            "clip_params": CLIP_PARAMS
        }]
    }
}
simulate_stdp = Simulator(net_stdp, [N_e_src_stdp, N_e_dst_1_stdp, N_e_dst_2_stdp], [N_i_stdp], connections=connections_stdp,
                      trace_params=TRACE_PARAMS)

# simulate() is expected to return the simulated network (plot_network indexes it later).
net_stdp = simulate_stdp.simulate(TOTAL_DURATION)

Network['Network'](Neurons: tensor(145)|4 groups, Synapses: tensor(7625)|5 groups){1:Payoff(signal_no=tensor([1., 1., 1.,  ..., 0., 0., 0.]),signal_value=3.5,)2:Dopamine(tau_dopamine=10.0,initial_dopamine_concentration=5.0,array=tensor([0., 0., 0.,  ..., 0., 0., 0.]),)}
NeuronGroup['Inhibitory_Population', 'NeuronGroup', 'ng'](25){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]),)3:Synapse(D=1.0,)5:EventRecorder(variables=None,gap_width=0,max_length=None,auto_annotate=True,tag=None,arg_0=['spikes'],)}
NeuronGroup['Excitatory_Population_Source', 'NeuronGroup', 'ng'](100){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[ 0.0000,  0.0000, 21.3009,  ..., 24.8847, 25.1241, 23.2734]

In [206]:
# Run before the iSTDP simulation, which reuses and overwrites DOPAMINE_ARRAY in place.
plot_network(
    net_stdp,
    dopamine=DOPAMINE_ARRAY,
    signal_orders=inputs.signal_orders,
    iter_duration=TOTAL_DURATION // ITER_NO,
    rest_duration=REST,
    total_iter=ITER_NO,
)

## R-STDP with iSTDP

In [294]:
# Network-level behaviors: Payoff sets network.payoff, which Dopamine reads in both
# initialize() and forward(); presumably the lower behavior key (1) runs first.
# DOPAMINE_PARAMS carries the same DOPAMINE_ARRAY used by net_stdp, so this run overwrites it.
net_istdp = Network(settings={"device": "cpu", "dtype":torch.float32}, behavior={
    1: Payoff(**PAYOFF_PARAMS),
    2: Dopamine(**DOPAMINE_PARAMS)
})


# NOTE(review): the creation order of the four groups below presumably fixes the
# recorder indices plot_network reads (0 inhibitory, 1 source, 2 dest 0, 3 dest 1).
N_i = NeuronGroup(net=net_istdp, tag='Inhibitory_Population', size=INHIBITORY_NEURON_SIZE, behavior={
        1: LIF(**INH_CONFIG),
        2: Input(**INH_ZERO_INPUT),
        3: Synapse(**INHIBITORY_D),
        # NOTE(review): key 5 here vs 4 in the other groups; presumably only relative order matters.
        5: EventRecorder(['spikes'])
    })


N_e_src = NeuronGroup(net=net_istdp, tag='Excitatory_Population_Source', size=EXCITATORY_NEURON_SIZE, behavior={
        1: LIF(**SRC_CONFIG),
        2: Input(**EXC_RANDOM_INPUT),
        3: Synapse(**EXCITATORY_D),
        4: EventRecorder(['spikes'])
    })

# The tags 'Excitatory_Population_Dest_0' / '_1' are looked up by Payoff via find_objects.
N_e_dst_1 = NeuronGroup(net=net_istdp, tag='Excitatory_Population_Dest_0', size=TARGET_NEURON_SIZE, behavior={
        1: LIF(**DST_CONFIG),
        2: Input(**TARGET_ZERO_INPUT),
        3: Synapse(**DST_D),
        4: EventRecorder(['spikes'])
    })

N_e_dst_2 = NeuronGroup(net=net_istdp, tag='Excitatory_Population_Dest_1', size=TARGET_NEURON_SIZE, behavior={
        1: LIF(**DST_CONFIG),
        2: Input(**TARGET_ZERO_INPUT),
        3: Synapse(**DST_D),
        4: EventRecorder(['spikes'])
    })


In [295]:
# Connectivity for the R-STDP + iSTDP variant.
# NOTE(review): "src"/"dst" presumably index the excitatory list ([source, dest 0, dest 1])
# and the inhibitory list ([inhibitory]) passed to Simulator below — confirm in Simulator.
connections_istdp = {
    "same":{
        # Presumably source -> dest 0 and source -> dest 1, both learned with R-STDP.
        "exc":[{
            "src":0,
            "dst":1,
            "learning_rule":RSTDP,
            "learning_params":RSTDP_PARAMS,
            "clip_params": RSTDP_CLIP_PARAMS
        }, {
            "src":0,
            "dst":2,
            "learning_rule":RSTDP,
            "learning_params":RSTDP_PARAMS,
            "clip_params": RSTDP_CLIP_PARAMS
        }],
        "inh":[{
            "src":0,
            "dst":0,
            "learning_rule":STDP,
            "learning_params":STDP_PARAMS,
            "clip_params": CLIP_PARAMS
        }]
    },
    "different":{
        "exc_inh":[{
            "src":0,
            "dst":0,
            "learning_rule":STDP,
            "learning_params":STDP_PARAMS,
            "clip_params": CLIP_PARAMS
        }],
        # The only entry that differs from connections_stdp: iSTDP here, anti_STDP there.
        "inh_exc":[{
            "src":0,
            "dst":0,
            "learning_rule":iSTDP,
            "learning_params":iSTDP_PARAMS,
            "clip_params": CLIP_PARAMS
        }]
    }
}
simulate_istdp = Simulator(net_istdp, [N_e_src, N_e_dst_1, N_e_dst_2], [N_i], connections=connections_istdp,
                      trace_params=TRACE_PARAMS)

# simulate() is expected to return the simulated network (plot_network indexes it later).
net_istdp = simulate_istdp.simulate(TOTAL_DURATION)

Network['Network'](Neurons: tensor(145)|4 groups, Synapses: tensor(7625)|5 groups){1:Payoff(signal_no=tensor([1., 1., 1.,  ..., 0., 0., 0.]),signal_value=3.5,)2:Dopamine(tau_dopamine=10.0,initial_dopamine_concentration=5.0,array=tensor([0., 0., 0.,  ..., 0., 0., 0.]),)}
NeuronGroup['Inhibitory_Population', 'NeuronGroup', 'ng'](25){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]),)3:Synapse(D=1.0,)5:EventRecorder(variables=None,gap_width=0,max_length=None,auto_annotate=True,tag=None,arg_0=['spikes'],)}
NeuronGroup['Excitatory_Population_Source', 'NeuronGroup', 'ng'](100){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[ 0.0000,  0.0000, 21.3009,  ..., 24.8847, 25.1241, 23.2734]

In [296]:
# DOPAMINE_ARRAY now holds the iSTDP run's trace (the earlier STDP trace was overwritten).
plot_network(
    net_istdp,
    dopamine=DOPAMINE_ARRAY,
    signal_orders=inputs.signal_orders,
    iter_duration=TOTAL_DURATION // ITER_NO,
    rest_duration=REST,
    total_iter=ITER_NO,
)