To do:
- Implement the nondimensional model.

- Implement the multi-layer network in Nengo, where spikes in the first layer induce waves in the subsequent layer.

- Once the multi-layer network works, start implementing the learning rules.

In [33]:
import numpy as np
import matplotlib.pyplot as plt
import math
from scipy.spatial.distance import pdist, squareform
from IPython.display import clear_output
%matplotlib inline

import nengo
from nengo.params import Parameter, NumberParam, NdarrayParam
from nengo.neurons import settled_firingrate   

class CustomLIF(nengo.neurons.NeuronType):
    """Spiking neuron layer with adaptive thresholds and lateral coupling.

    On every time step each neuron's voltage leaks toward zero with a unit
    time constant. It is driven by the previous step's spikes through the
    intra-layer adjacency matrix ``Ss`` and receives uniform random noise.
    A neuron spikes when its voltage exceeds its own adaptive threshold.
    The input current ``J`` supplied by nengo is not used by ``step_math``.

    Parameters
    ----------
    Ss : ndarray
        Intra-layer adjacency matrix. ``spikes @ Ss`` is the lateral drive
        added to each neuron's voltage, so it is expected to be
        (num, num).
    num : int
        Number of neurons in the layer.
    tau_v : float
        Membrane time constant. Only used by the rate approximations
        (``max_rates_intercepts`` and ``rates``); ``step_math`` uses a
        unit time constant.
    piT_th : float
        Rate at which a non-spiking neuron's threshold relaxes toward
        ``piV_th``.
    piTV_plus : float
        Rate at which the threshold rises for a neuron that spiked on the
        previous step.
    piV_th : float
        Resting threshold value.
    piV_reset : ndarray
        Per-neuron voltage that a neuron is reset to after spiking.
    nxs : ndarray
        Neuron positions; columns 0 and 1 are used as x/y for plotting.
    Vt : float
        Voltage scale; the noise term is divided by it.
    tau_ref : float
        Refractory period, in seconds. It is written into the
        ``refractory_time`` state on a spike, but it does not hold the
        voltage.
    min_voltage : float
        Voltages below this value are clipped to it after each step.
    amplitude : float
        Output value written for each spiking neuron (0 otherwise).
    plot : bool
        If True (the default), draw all neurons and highlight the firing
        ones on every time step. This is very slow for long simulations.
    """

    probeable = ("spikes", "voltage", "refractory_time", "threshold")

    min_voltage = NumberParam("min_voltage", high=0)
    tau_v = NumberParam("tau_v", low=0, low_open=True)
    tau_ref = NumberParam("tau_ref", low=0)
    amplitude = NumberParam("amplitude", low=0, low_open=True)
    num = NumberParam("num")
    Ss = NdarrayParam("Ss")  # intra-layer adjacency matrix
    piT_th = NumberParam("piT_th")
    piTV_plus = NumberParam("piTV_plus")
    piV_th = NumberParam("piV_th")
    piV_reset = NdarrayParam("piV_reset")  # per-neuron reset voltages
    nxs = NdarrayParam("nxs")  # neuron positions (used for plotting)
    Vt = NumberParam("Vt")

    def __init__(
        self,
        Ss,
        num,
        tau_v,
        piT_th,
        piTV_plus,
        piV_th,
        piV_reset,
        nxs,
        Vt,
        tau_ref=0.002,
        min_voltage=0,
        amplitude=1,
        plot=True,
    ):
        super().__init__()
        self.tau_v = tau_v
        self.tau_ref = tau_ref
        self.amplitude = amplitude
        self.min_voltage = min_voltage
        self.num = num
        self.Ss = Ss
        self.nxs = nxs
        self.piT_th = piT_th
        self.piTV_plus = piTV_plus
        self.piV_th = piV_th
        self.piV_reset = piV_reset
        self.Vt = Vt
        # Plain (non-Parameter) flag: only controls per-step visualization.
        self.plot = bool(plot)

    def gain_bias(self, max_rates, intercepts):
        """Return unit gains and zero biases for all ``num`` neurons.

        ``max_rates`` and ``intercepts`` are ignored.
        """
        gain = np.ones((self.num,))
        bias = np.zeros((self.num,))
        return gain, bias

    def max_rates_intercepts(self, gain, bias):
        """Compute the inverse of gain_bias using the LIF rate equation.

        Warns (does not raise) if any resulting max rate is non-finite.
        """
        import warnings

        intercepts = (1 - bias) / gain
        max_rates = 1.0 / (
            self.tau_ref - self.tau_v * np.log1p(1.0 / (gain * (intercepts - 1) - 1))
        )
        if not np.all(np.isfinite(max_rates)):
            warnings.warn(
                "Non-finite values detected in `max_rates`; this "
                "probably means that `gain` was too small."
            )
        return max_rates, intercepts

    def rates(self, x, gain, bias):
        """Approximate firing rates with the standard LIF rate equation.

        This is the LIF rate curve (with ``tau_v`` as the membrane time
        constant), not the rate of this model's adaptive-threshold dynamics.
        Neurons with input current ``J <= 1`` get a rate of 0.
        """
        J = self.current(x, gain, bias)
        out = np.zeros_like(J)
        above = J > 1
        out[above] = self.amplitude / (
            self.tau_ref + self.tau_v * np.log1p(1.0 / (J[above] - 1))
        )
        return out

    def step_math(self, dt, J, spiked, voltage, refractory_time, threshold):
        """Advance the layer by one time step of length ``dt``, in place.

        ``J`` is accepted to satisfy the operator interface but is unused.
        On entry ``spiked`` holds the previous step's output; it is
        overwritten with this step's spikes.
        """
        # Refractory timers are counted down and restarted on spikes, but
        # they do not gate the voltage update below.
        refractory_time -= dt

        # Voltage: leak toward 0, lateral drive from the previous step's
        # spikes through this layer's own adjacency matrix, and uniform
        # noise in [0, 3 / Vt).
        U = np.matmul(spiked, self.Ss)
        eta = 3 * np.random.rand(self.num) / self.Vt
        dV = -voltage + U + eta
        voltage[:] += dV * dt

        # Threshold: neurons that did not spike on the previous step relax
        # toward piV_th; neurons that did are pushed up by piTV_plus.
        dTh = (
            self.piT_th * (self.piV_th - threshold) * (1 - spiked)
            + self.piTV_plus * spiked
        )
        threshold[:] += dTh * dt

        # Neurons whose voltage exceeds their threshold spike; the output
        # is `amplitude` for them and 0 for everyone else.
        spiked_mask = voltage > threshold
        spiked[:] = spiked_mask * self.amplitude

        if self.plot:
            self._plot_wave(spiked_mask)

        # Clip low voltages to min_voltage, then reset the voltages of
        # spiking neurons and restart their refractory timers.
        voltage[voltage < self.min_voltage] = self.min_voltage
        voltage[spiked_mask] = self.piV_reset[spiked_mask]
        refractory_time[spiked_mask] = self.tau_ref

    def _plot_wave(self, spiked_mask):
        """Scatter all neuron positions in black and firing ones in red."""
        positions = self.nxs
        plt.scatter(positions[:, 0], positions[:, 1], color="k")
        plt.title("Ret-Wave")
        plt.scatter(positions[spiked_mask, 0], positions[spiked_mask, 1], color="r")
        plt.show()
        clear_output(wait=True)

In [34]:
from nengo.builder.operator import Operator

class SimCustomLIF(Operator):
    """Operator that advances a custom neuron model by one step.

    Each simulator step calls ``neurons.step_math(dt, J, output, *states)``.

    Parameters
    ----------
    neurons : NeuronType
        Neuron model exposing a ``step_math`` method.
    J : Signal
        Input current signal; this operator only reads it.
    output : Signal
        Neuron output signal; this operator sets it.
    states : list, optional
        Additional state signals, passed to ``step_math`` after ``output``
        and set by it.
    tag : str, optional
        Label used when debugging.

    Attributes
    ----------
    J : Signal
        The input current signal.
    neurons : NeuronType
        The neuron model whose ``step_math`` is called.
    output : Signal
        The neuron output signal.
    states : list
        The additional neuron state signals.
    tag : str or None
        Label used when debugging.

    Notes
    -----
    1. sets ``[output] + states``
    2. incs ``[]``
    3. reads ``[J]``
    4. updates ``[]``
    """

    def __init__(self, neurons, J, output, states=None, tag=None):
        super().__init__(tag=tag)
        self.neurons = neurons

        extra_states = states if states is not None else []
        self.sets = [output] + extra_states
        self.incs = []
        self.reads = [J]
        self.updates = []

    @property
    def J(self):
        """Input current signal (first and only read)."""
        return self.reads[0]

    @property
    def output(self):
        """Neuron output signal (first set)."""
        return self.sets[0]

    @property
    def states(self):
        """Additional state signals (every set after the output)."""
        return self.sets[1:]

    def _descstr(self):
        parts = (self.neurons, self.J, self.output)
        return "%s, %s, %s" % parts

    def make_step(self, signals, dt, rng):
        current = signals[self.J]
        out = signals[self.output]
        state_arrays = list(map(lambda sig: signals[sig], self.states))

        def step_simcustomlif():
            self.neurons.step_math(dt, current, out, *state_arrays)

        return step_simcustomlif

In [35]:
from nengo.builder import Builder
from nengo.builder.operator import Copy
from nengo.builder.signal import Signal
from nengo.rc import rc


@Builder.register(CustomLIF)
def build_customlif(model, neuron_type, neurons):
    """Build a `CustomLIF` neuron population into ``model``.

    Creates one signal per extra neuron state ("voltage", "refractory_time",
    "threshold"), each with shape ``neurons.size_in``, and stores them in
    ``model.sig[neurons]``. It then adds a `SimCustomLIF` operator that
    reads ``model.sig[neurons]["in"]`` and sets
    ``model.sig[neurons]["out"]`` plus the three state signals.

    Parameters
    ----------
    model : Model
        The model to build into.
    neuron_type : CustomLIF
        Neuron type to build.
    neurons : Neurons
        The neuron population object corresponding to the neuron type.

    Notes
    -----
    Does not modify ``model.params[]`` and can therefore be called
    more than once with the same `CustomLIF` instance.
    """
    neuron_signals = model.sig[neurons]
    state_names = ("voltage", "refractory_time", "threshold")

    # NOTE(review): the signals are created without an initial_value
    # (presumably all zeros), so the adaptive threshold starts at 0 rather
    # than at piV_th. Confirm that this initial condition is intended.
    for state_name in state_names:
        neuron_signals[state_name] = Signal(
            shape=neurons.size_in, name="%s.%s" % (neurons, state_name)
        )

    operator = SimCustomLIF(
        neurons=neuron_type,
        J=neuron_signals["in"],
        output=neuron_signals["out"],
        states=[neuron_signals[state_name] for state_name in state_names],
    )
    model.add_op(operator)

In [36]:
from nengo.utils.matplotlib import rasterplot
# import nengo_loihi

# System constants
num = 1632

# Per-neuron reset voltage: 0 plus non-negative noise on the activity field
v_reset = 0 + 0.1 * np.random.randn(num) ** 2

# Spatial parameters: neurons scattered uniformly in a sqR x sqR square
sqR = 28
nx = sqR * np.random.rand(num, 2)

# Adjacency kernel: excitation (weight ai) for D < ri, exponentially
# decaying inhibition (weight ao, length scale lam) for D > ro
ri, ro, lam, ai, ao = 3, 6, 10, 30, 10
D = squareform(pdist(nx))
S = ai * (D < ri) - (ao * (D > ro) * np.exp(-D / lam))
S = S - np.diag(np.diag(S))  # remove self-connections

# Dynamics parameters
tau_v, tau_th, th_plus, v_th = 1, 30, 9, 1

# Nondimensionalization scales (voltage, time, length)
Vt, Tt, Lt = ao, tau_v, sqR

# Nondimensional groups
piTV_plus = th_plus * Tt / Vt
piT_th = Tt / tau_th

piV_th = v_th / Vt
piV_reset = v_reset / Vt
piV_ai = ai / Vt

piL_ri = ri / Lt
piL_ro = ro / Lt
piL_lam = lam / Lt

# Nondimensional spatial adjacency matrix. Because Vt == ao, this equals
# S / Vt up to floating-point rounding.
nxs = nx / Lt
Ds = squareform(pdist(nxs))  # nondimensional distance matrix
Ss = piV_ai * (Ds < piL_ri) - (Ds > piL_ro) * np.exp(-Ds / piL_lam)
Ss = Ss - np.diag(np.diag(Ss))  # adjacency matrix between retina neurons

model = nengo.Network(label='2D Representation', seed=10)
process = nengo.processes.WhiteNoise(
    dist=nengo.dists.Gaussian(0, .01), seed=1)  # not used in this cell
with model:
    # Both layers get the nondimensional adjacency matrix Ss, consistent
    # with the other nondimensional parameters passed to CustomLIF.
    a = nengo.Ensemble(
        num,
        dimensions=2,
        neuron_type=CustomLIF(
            Ss, num, tau_v, piT_th, piTV_plus, piV_th, piV_reset, nxs, Vt
        ),
    )
    b = nengo.Ensemble(
        num,
        dimensions=2,
        neuron_type=CustomLIF(
            Ss, num, tau_v, piT_th, piTV_plus, piV_th, piV_reset, nxs, Vt
        ),
    )
#     conn = nengo.Connection(a, b, learning_rule_type=nengo.PES(learning_rate=2e-4))
    spikes_probe = nengo.Probe(a.neurons, 'spikes')
    voltage_probe = nengo.Probe(a.neurons, 'voltage')
    threshold_probe = nengo.Probe(a.neurons, 'threshold')
    b_spikes = nengo.Probe(b.neurons, 'spikes')
    b_voltage = nengo.Probe(b.neurons, 'voltage')
with nengo.Simulator(model) as sim:
    sim.run(.2)

# Custom LIF neurons
# plt.figure(figsize=(12, 6))
# plt.plot(sim.trange(), sim.data[voltage_probe])
# plt.xlabel('time [s]')
# plt.ylabel('voltage')

# plt.figure(figsize=(12, 6))
# plt.plot(sim.trange(), sim.data[threshold_probe])
# plt.xlabel('time [s]')
# plt.ylabel('threshold voltage')

# plt.figure(figsize=(12, 6))
# rasterplot(sim.trange(), sim.data[spikes_probe])
# plt.xlabel('time [s]')
# plt.ylabel('Neuron number')

KeyboardInterrupt: 