## Reward-Modulated Spike-Timing-Dependent Plasticity (R-STDP)

In [1]:
import torch
import pandas as pd
import plotly.express as px
from pymonntorch import Behavior, SynapseGroup, Network, NeuronGroup, Recorder, EventRecorder
from  Behaviour import *

### Implementation

In [2]:
class STDP(Behavior):
    """
    Spike-Timing Dependent Plasticity (STDP) rule for simple connections.

    The weight change is built only from quantities local to the synapse: the
    spikes and spike traces of the source and destination groups.

    Per step:
        potentiation = outer(dst.spikes, src.trace) * coef_plus
        depression   = outer(dst.trace, src.spikes) * coef_minus
        dW           = potentiation - depression

    Without soft bounds, coef_plus = a_plus and coef_minus = a_minus.
    With soft bounds, coef_plus = a_plus * (w_max - W) and
    coef_minus = a_minus * (W - w_min), so updates shrink as W nears a bound.
    This class does not clip W. NOTE(review): hard bounds are presumably
    applied elsewhere (see the clip_params used when wiring connections).

    Args:
        a_plus (float): Coefficient for the positive weight change. The default is None.
        a_minus (float): Coefficient for the negative weight change. The default is None.
        w_max (float): Maximum value of synaptic strength (used by the soft bound). The default is None.
        w_min (float): Minimum value of synaptic strength (used by the soft bound). The default is None.
        enable_soft_bound (boolean): If true, the soft-bound scaling is applied. The default is False.
    """

    def initialize(self, synapse):
        self.add_tag("STDP")
        self.enable_soft_bound = self.parameter('enable_soft_bound', False)
        self.w_max = self.parameter("w_max", None)
        self.w_min = self.parameter("w_min", None)
        self.a_plus = self.parameter("a_plus", None)
        self.a_minus = self.parameter("a_minus", None)

    def compute_coefs(self, synapse):
        """Return the (potentiation, depression) coefficients for the current weights."""
        if not self.enable_soft_bound:
            return self.a_plus, self.a_minus
        # Soft bound: scale each learning rate by the distance to its bound.
        potentiation_coef = self.a_plus * (self.w_max - synapse.W)
        depression_coef = self.a_minus * (synapse.W - self.w_min)
        return potentiation_coef, depression_coef

    def compute_dw(self, s):
        """Return the STDP weight change; torch.outer gives shape (len(dst), len(src))."""
        plus_coef, minus_coef = self.compute_coefs(s)
        pre, post = s.src, s.dst
        # Post spike while the pre trace is high -> pre fired first -> potentiate.
        potentiation = torch.outer(post.spikes, pre.trace) * plus_coef
        # Pre spike while the post trace is high -> post fired first -> depress.
        depression = torch.outer(post.trace, pre.spikes) * minus_coef
        return potentiation - depression

    def forward(self, synapse):
        delta_w = self.compute_dw(synapse)
        synapse.W += delta_w


class Anti_STDP(STDP):
    """
    Anti-Hebbian Spike-Timing Dependent Plasticity rule for simple connections.

    Applies the negated STDP update: potentiation and depression swap roles
    relative to the parent class. Like STDP, it uses only local variables
    (spikes and spike traces). Arguments are the same as the parent class.
    """

    def forward(self, synapse):
        # compute_dw already evaluates compute_coefs internally; no separate call is needed.
        synapse.W -= self.compute_dw(synapse)


class RSTDP(STDP):
    """
    Reward-modulated Spike-Timing Dependent Plasticity (R-STDP) rule for simple connections.

    Each synapse keeps an eligibility value ``synapse.c``. Every step, ``c``
    decays toward 0 with time constant ``tau_c`` and integrates the STDP weight
    change from the parent class:

        c += -c / tau_c + dW_stdp

    The weights are then changed by ``c`` scaled by the network-wide dopamine
    level ``network.dopamine_concentration`` (set by the Dopamine behavior):

        W += c * dopamine_concentration

    Args:
        tau_c (float): Decay time constant of c. Required; there is no usable default.
        initial_c (float): Initial value of c. The default is 0.
        All STDP arguments (a_plus, a_minus, w_max, w_min, enable_soft_bound) are also accepted.

    Raises:
        ValueError: if ``tau_c`` is not provided.
    """

    def initialize(self, synapse):
        super().initialize(synapse)
        self.tau_c = self.parameter('tau_c', None)
        self.initial_c = self.parameter('initial_c', 0.0)
        if self.tau_c is None:
            # forward() divides by tau_c; fail early with a clear message instead.
            raise ValueError("RSTDP requires the 'tau_c' parameter (decay time constant of c).")
        # Allocate c on the network's device/dtype so it can be combined with W and
        # the STDP update in forward(). getattr() falls back to torch's defaults
        # (the previous behaviour) if the network does not expose these settings.
        network = synapse.network
        synapse.c = torch.ones(
            *synapse.matrix_dim(),
            device=getattr(network, "device", None),
            dtype=getattr(network, "def_dtype", None),
        ) * self.initial_c

    def forward(self, synapse):
        stdp = self.compute_dw(synapse)
        # Leaky integration of the eligibility trace (unit Euler step).
        synapse.c += (-synapse.c / self.tau_c) + stdp
        # Dopamine gates which eligible synapses actually change.
        synapse.W += synapse.c * synapse.network.dopamine_concentration


## Payoff Behaviour

In [3]:
class Payoff(Behavior):
    """
    Reward/punishment (payoff) signal driven by the two destination populations.

    At iteration t, ``signal_no[t - 1]`` selects which destination group should
    be the more active one:
      - 0: group tagged "Excitatory_Population_Dest_0",
      - anything else: group tagged "Excitatory_Population_Dest_1".

    The payoff is
        signal_value * (spikes(selected) - spikes(other)) / size(selected).
    It is positive (reward) when the selected group fires more and negative
    (punishment) otherwise. The result is stored on ``network.payoff``.

    Args:
        payoff (float, optional): Initial payoff value. The default is 0.0.
        signal_value (float, optional): Scale factor applied to the normalized
            spike-count difference. The default is 0.1.
        signal_no (indexable): Id of the presented signal, indexed by
            ``network.iteration - 1``. Required for forward().
    """

    def initialize(self, network):
        """
        Set the initial payoff value and read the parameters.

        Args:
            network (Network): Network object.
        """
        self.add_tag("Payoff")
        network.payoff = self.parameter('payoff', 0.0)
        self.signal_value = self.parameter('signal_value', 0.1)
        self.signal_no = self.parameter('signal_no', None)
        # The destination groups are found on the first forward() call and cached
        # there. They may not exist yet when network behaviors are initialized,
        # and searching every step is wasted work.
        self._dest_groups = None

    def _calculate_reward_punishment(self, ng_0, ng_1):
        """Return signal_value * (spikes(ng_0) - spikes(ng_1)) / size(ng_0)."""
        spike_diff_prop = (torch.sum(ng_0.spikes) - torch.sum(ng_1.spikes)) / len(ng_0.spikes)
        return self.signal_value * spike_diff_prop

    def _get_dest_groups(self, network):
        """Return (Dest_0 group, Dest_1 group), looking them up only once."""
        if self._dest_groups is None:
            self._dest_groups = (
                network.find_objects(key="Excitatory_Population_Dest_0")[0],
                network.find_objects(key="Excitatory_Population_Dest_1")[0],
            )
        return self._dest_groups

    def _get_reward_punishment(self, network):
        """Compute the payoff for the signal presented at the current iteration."""
        ng0, ng1 = self._get_dest_groups(network)
        if self.signal_no[network.iteration - 1] == 0:
            return self._calculate_reward_punishment(ng0, ng1)
        return self._calculate_reward_punishment(ng1, ng0)

    def forward(self, network):
        """
        Update ``network.payoff`` (reward/punishment) for the current iteration.

        Args:
            network (Network): Network object.
        """
        network.payoff = self._get_reward_punishment(network)
        
       

## Dopamine

In [4]:
class Dopamine(Behavior):
    """
    Compute the extracellular dopamine concentration of the network.

    Integrates dd/dt = -d / tau_dopamine + payoff(t) with one unit Euler step
    per network iteration. It reads ``network.payoff`` (set by the Payoff
    behavior) and writes ``network.dopamine_concentration``.

    Args:
        tau_dopamine (float, optional): Dopamine decay time constant. The default is 1.0.
        initial_dopamine_concentration (float, optional): Initial dopamine
            concentration. The default is the value of ``network.payoff`` at
            initialization, so ``network.payoff`` must already exist then.
        array (tensor, optional): Buffer that records the concentration; the
            value after iteration t is stored at index t - 1. The default is
            None, meaning nothing is recorded.
    """

    def initialize(self, network):
        """
        Set the initial dopamine concentration (by default, the current payoff value).

        Args:
            network (Network): Network object.
        """
        self.add_tag("Dopamine")

        network.tau_dopamine = self.parameter("tau_dopamine", 1.0)
        self.dopamine_array = self.parameter('array', None)
        network.dopamine_concentration = self.parameter(
            "initial_dopamine_concentration", network.payoff
        )

    def forward(self, network):
        """
        Advance the dopamine concentration by one step:

        dd/dt = -d/tau_d + payoff(t).

        Args:
            network (Network): Network object.
        """
        dd_dt = (
            -(network.dopamine_concentration / network.tau_dopamine) + network.payoff
        )
        network.dopamine_concentration += dd_dt
        # Recording is optional: 'array' defaults to None.
        if self.dopamine_array is not None:
            self.dopamine_array[network.iteration - 1] = network.dopamine_concentration

## Plotting Functions

In [5]:
def add_rectangles(fig, signal_orders, iter_duration, rest_duration, total_iter, y0=0, y1=14):
    """
    Shade one vertical band per iteration on a plotly figure, colored by signal id.

    Band i spans x in [i * iter_duration, i * iter_duration + iter_duration - 2 * rest_duration].
    It is purple ('#9467bd') when ``signal_orders[i]`` is truthy and orange otherwise.

    Args:
        fig: Figure exposing ``add_vrect`` (e.g. a plotly Figure).
        signal_orders: Indexable of per-iteration signal ids; at least ``total_iter`` long.
        iter_duration: Length of one iteration in time steps.
        rest_duration: Rest length; twice this is trimmed from the end of each band.
        total_iter: Number of iterations (bands) to draw.
        y0, y1: Vertical extent passed straight to ``fig.add_vrect``.
            The defaults reproduce the previously hard-coded 0..14.
    """
    for i in range(total_iter):
        color = '#9467bd' if signal_orders[i] else 'orange'
        x_start = i * iter_duration
        x_end = x_start + iter_duration - 2 * rest_duration
        # Pass coordinates as numbers; the x axis (time) is numeric.
        fig.add_vrect(
            x0=x_start, x1=x_end,
            y0=y0, y1=y1,
            fillcolor=color, opacity=0.25,
            layer="below", line_width=0,
        )
    


def _raster_figure(net, group_index, title):
    """Build a styled spike raster scatter plot from the recorded spikes at ``group_index``."""
    df = pd.DataFrame({'t': net['spikes.t', group_index], 'i': net['spikes.i', group_index]})
    fig = px.scatter(df, x='t', y='i', title=title)
    fig.update_traces(marker=dict(size=2, opacity=0.5))
    return fig


def plot_network(net, dopamine, signal_orders, iter_duration, rest_duration, total_iter):
    """
    Show raster plots of all four populations and the dopamine trace.

    NOTE(review): recorder indices 0..3 are assumed to be, in order, the
    inhibitory, source, destination-0 and destination-1 groups (their
    creation order) — confirm if groups are added or reordered.

    Args:
        net: Simulated network with spike recorders.
        dopamine: 1-D sequence of dopamine values to plot.
        signal_orders, iter_duration, rest_duration, total_iter: Forwarded to
            ``add_rectangles`` to shade the signal windows on the destination rasters.
    """
    inh_fig = _raster_figure(net, 0, "Raster Plot of Inhibitory Population")
    exc_fig = _raster_figure(net, 1, "Raster Plot of Source Excitatory Population")
    # Distinct titles so the two destination populations can be told apart.
    exc_fig_1 = _raster_figure(net, 2, "Raster Plot of Destination Excitatory Population 0")
    exc_fig_2 = _raster_figure(net, 3, "Raster Plot of Destination Excitatory Population 1")
    dopamine_fig = px.line(dopamine, title="Dopamine")
    dopamine_fig.update_layout(xaxis_title="iteration", yaxis_title="Dopamine")
    for fig in (exc_fig_1, exc_fig_2):
        add_rectangles(fig, signal_orders, iter_duration, rest_duration, total_iter)
    for fig in (inh_fig, exc_fig, exc_fig_1, exc_fig_2, dopamine_fig):
        fig.show()

## Network Structure

<div style="text-align: center;"><img src="schematic_RSTDP.png" alt="schematic" width="600"/></div>

## Define Network

In [10]:
# Arguments passed to InputGenerator (their exact meaning is defined in the Behaviour module).
MEAN, STD, THRESHOLD = 25.0, 6.0, 0.5 
SIGNAL_DURATION = 100
REST = 0
SIGNAL_REPEAT = 25
ITER_NO = 2
TARGET_NEURON_SIZE = 10
# Total number of simulation steps: ITER_NO iterations, each made of SIGNAL_REPEAT
# signals of SIGNAL_DURATION steps, with REST steps between them and at the end.
TOTAL_DURATION = ((SIGNAL_DURATION + REST) * SIGNAL_REPEAT + REST) * ITER_NO
# One slot per simulation step; the Dopamine behavior writes slot (iteration - 1).
DOPAMINE_ARRAY = torch.zeros(TOTAL_DURATION)
# NOTE(review): InputGenerator, INHIBITORY_NEURON_SIZE and EXCITATORY_NEURON_SIZE are not
# defined in this notebook; presumably they come from `from Behaviour import *`.
inputs = InputGenerator(MEAN, STD, THRESHOLD)
# Parameters for each group's Synapse behavior (meaning of 'D' defined in Behaviour — TODO confirm).
INHIBITORY_D = {'D':1.0}
EXCITATORY_D = {'D':0.70}
DST_D = {'D':0.75} 
TRACE_PARAMS = {'tau': 10.}
# R-STDP for source -> destination connections: soft-bounded in [0, 2], eligibility time constant 1500.
RSTDP_PARAMS = {'a_plus': 0.110, 'a_minus': 0.0710, 'w_max':2.0, 'w_min':0.0, "initial_c":0.0, 'tau_c':1500., 'enable_soft_bound':True} 
DOPAMINE_PARAMS = {'tau_dopamine': 10., "initial_dopamine_concentration": 5.0, "array":DOPAMINE_ARRAY}
# Plain STDP (and Anti-STDP) for the remaining connections: soft-bounded in [0, 1].
STDP_PARAMS = {'a_plus': 0.0380, 'a_minus': 0.0330, 'w_max':1.0, 'w_min':0.0,'enable_soft_bound':True}
# Clip ranges; these match the w_min/w_max of the corresponding learning rules.
CLIP_PARAMS = {'w_min': 0.0, 'w_max':1.0}
RSTDP_CLIP_PARAMS = {'w_min': 0.0, 'w_max':2.0}
# Only the source population receives external input; the other groups get zero input.
INH_ZERO_INPUT = {"I": inputs.get_zero_input(INHIBITORY_NEURON_SIZE, TOTAL_DURATION)}
# NOTE(review): the trailing 42 looks like a random seed — confirm in Behaviour.
EXC_RANDOM_INPUT = {"I": inputs.get_random_signals(ITER_NO, SIGNAL_DURATION, SIGNAL_REPEAT, REST, EXCITATORY_NEURON_SIZE, 42)}
TARGET_ZERO_INPUT = {"I": inputs.get_zero_input(TARGET_NEURON_SIZE, TOTAL_DURATION)}
# NOTE(review): inputs.random_signal_no is presumably populated by get_random_signals above,
# so this line must stay after it. Payoff indexes it by (iteration - 1).
PAYOFF_PARAMS = {'signal_no': inputs.random_signal_no, 'signal_value':3.5}

In [11]:
# Network-level behaviors. Dopamine's default initial concentration reads network.payoff,
# so Payoff is registered first (behavior keys presumably set the execution order).
net = Network(settings={"device": "cpu", "dtype":torch.float32}, behavior={
    1: Payoff(**PAYOFF_PARAMS),
    2: Dopamine(**DOPAMINE_PARAMS)
})


# NOTE(review): the creation order below (inhibitory, source, dest 0, dest 1) presumably
# fixes the recorder indices 0..3 that plot_network reads via net['spikes.t', i] — keep it.
# LIF configs INH_CONFIG / EXC_CONFIG are presumably provided by the Behaviour module.
N_i_stdp = NeuronGroup(net=net, tag='Inhibitory_Population', size=INHIBITORY_NEURON_SIZE, behavior={
        1: LIF(**INH_CONFIG),
        2: Input(**INH_ZERO_INPUT),
        3: Synapse(**INHIBITORY_D),
        5: EventRecorder(['spikes'])
    })


# The only group driven by external (random) input.
N_e_src_stdp = NeuronGroup(net=net, tag='Excitatory_Population_Source', size=EXCITATORY_NEURON_SIZE, behavior={
        1: LIF(**EXC_CONFIG),
        2: Input(**EXC_RANDOM_INPUT),
        3: Synapse(**EXCITATORY_D),
        4: EventRecorder(['spikes'])
    })

# The tags 'Excitatory_Population_Dest_0' / '_1' are looked up by the Payoff behavior;
# renaming them breaks the reward computation.
N_e_dst_1_stdp = NeuronGroup(net=net, tag='Excitatory_Population_Dest_0', size=TARGET_NEURON_SIZE, behavior={
        1: LIF(**EXC_CONFIG),
        2: Input(**TARGET_ZERO_INPUT),
        3: Synapse(**DST_D),
        4: EventRecorder(['spikes'])
    })

N_e_dst_2_stdp = NeuronGroup(net=net, tag='Excitatory_Population_Dest_1', size=TARGET_NEURON_SIZE, behavior={
        1: LIF(**EXC_CONFIG),
        2: Input(**TARGET_ZERO_INPUT),
        3: Synapse(**DST_D),
        4: EventRecorder(['spikes'])
    })


## Set Connections and Simulate

In [12]:
# Connection specification consumed by Simulator (defined outside this notebook).
# NOTE(review): "src"/"dst" appear to be positions in the excitatory list
# [N_e_src_stdp, N_e_dst_1_stdp, N_e_dst_2_stdp] and the inhibitory list [N_i_stdp]
# passed to Simulator below. The reported synapse count (7625 = 2*100*10 + 25*25 + 2*100*25)
# is consistent with that reading — confirm against Simulator.
connections_stdp = {
    "same":{
        # Source -> destination 0 and source -> destination 1: reward-modulated plasticity.
        "exc":[{
            "src":0,
            "dst":1,
            "learning_rule":RSTDP,
            "learning_params":RSTDP_PARAMS,
            "clip_params": RSTDP_CLIP_PARAMS
        }, {
            "src":0,
            "dst":2,
            "learning_rule":RSTDP,
            "learning_params":RSTDP_PARAMS,
            "clip_params": RSTDP_CLIP_PARAMS
        }],
        # Recurrent inhibitory -> inhibitory connection with plain STDP.
        "inh":[{
            "src":0,
            "dst":0,
            "learning_rule":STDP,
            "learning_params":STDP_PARAMS,
            "clip_params": CLIP_PARAMS
        }]
    },
    "different":{
        # Excitatory -> inhibitory with plain STDP.
        "exc_inh":[{
            "src":0,
            "dst":0,
            "learning_rule":STDP,
            "learning_params":STDP_PARAMS,
            "clip_params": CLIP_PARAMS
        }],
        # Inhibitory -> excitatory with anti-Hebbian STDP.
        "inh_exc":[{
            "src":0,
            "dst":0,
            "learning_rule":Anti_STDP,
            "learning_params":STDP_PARAMS,
            "clip_params": CLIP_PARAMS
        }]
    }
}
simulate = Simulator(net, [N_e_src_stdp, N_e_dst_1_stdp, N_e_dst_2_stdp], [N_i_stdp], connections=connections_stdp,
                      trace_params=TRACE_PARAMS)

# Runs TOTAL_DURATION steps (matching the length of DOPAMINE_ARRAY) and rebinds `net`
# to the network object returned by Simulator.simulate.
net = simulate.simulate(TOTAL_DURATION)

Network['Network'](Neurons: tensor(145)|4 groups, Synapses: tensor(7625)|5 groups){1:Payoff(signal_no=tensor([1., 1., 1.,  ..., 0., 0., 0.]),signal_value=3.5,)2:Dopamine(tau_dopamine=10.0,initial_dopamine_concentration=5.0,array=tensor([0., 0., 0.,  ..., 0., 0., 0.]),)}
NeuronGroup['Inhibitory_Population', 'NeuronGroup', 'ng'](25){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]),)3:Synapse(D=1.0,)5:EventRecorder(variables=None,gap_width=0,max_length=None,auto_annotate=True,tag=None,arg_0=['spikes'],)}
NeuronGroup['Excitatory_Population_Source', 'NeuronGroup', 'ng'](100){1:LIF(v_reset=-65.0,v_rest=-65.0,tau=10.0,R=2.0,threshold=-55.0,)2:Input(I=tensor([[ 0.0000,  0.0000, 21.3009,  ..., 24.8847, 25.1241, 23.2734]

## Check The Results

In [14]:
# Visualize the rasters (signal windows shaded on the destination plots) and the dopamine trace.
plot_network(
    net,
    DOPAMINE_ARRAY,
    inputs.signal_orders,
    iter_duration=TOTAL_DURATION // ITER_NO,
    rest_duration=REST,
    total_iter=ITER_NO,
)

## Reference

### E. M. Izhikevich, “Solving the Distal Reward Problem through Linkage of STDP and Dopamine Signaling,” Cerebral Cortex, vol. 17, no. 10, pp. 2443–2452, 2007.