Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 16 additions & 48 deletions docs/tutorials/neurocog/integration.md
Original file line number Diff line number Diff line change
@@ -1,25 +1,14 @@
# Numerical Integration

In constructing one's own biophysical models, particularly those of phenomena that change with time, ngc-learn offers
useful flexible tools for numerical integration that facilitate an easier time in constructing your own components that
play well with the library's simulation backend. Knowing how things work beyond Euler integration -- the base/default
form of integration often employed by ngc-learn -- might be useful for constructing and simulating dynamics more
accurately (often at the cost of additional computational time).
In constructing one's own biophysical models, particularly those of phenomena that change with time, ngc-learn offers useful flexible tools for numerical integration that facilitate an easier time in constructing your own components that play well with the library's simulation backend. Knowing how things work beyond Euler integration -- the base/default form of integration often employed by ngc-learn -- might be useful for constructing and simulating dynamics more accurately (often at the cost of additional computational time).

## Euler Integration

Euler integration is very simple (and fast) way of using the ordinary differential equations you typically define for
the cellular dynamics of various components in ngc-learn (which typically get called in any component's `advance_state()` command).
Euler integration is very simple (and fast) way of using the ordinary differential equations you typically define for the cellular dynamics of various components in ngc-learn (which typically get called in any component's `advance_state()` command).

While utilizing the numerical integrator will depend on your component's design and the (biophysical) elements you wish
to model, let's observe ngc-learn's base backend utilities (its integration backend `ngclearn.utils.diffeq`) in the
context of numerically integrating a simple differential equation; specifically the autonomous (linear) ordinary
differential equation (ODE): $\frac{\partial y(t)}{\partial t} = y(t)$. The analytic solution to this equation is
also simple -- it is $y(t) = e^{t}$.
While utilizing the numerical integrator will depend on your component's design and the (biophysical) elements you wish to model, let's observe ngc-learn's base backend utilities (its integration backend `ngclearn.utils.diffeq`) in the context of numerically integrating a simple differential equation; specifically the autonomous (linear) ordinary differential equation (ODE): $\frac{\partial y(t)}{\partial t} = y(t)$. The analytic solution to this equation is also simple -- it is $y(t) = e^{t}$.

If you have defined your differential equation $\frac{\partial y(t)}{\partial t}$ in a rather simple format[^1], you
can write the following code to examine how Euler integration approximates the analytical solution (in this example,
we examine just two different step sizes, i.e., `dt = 0.1` and `dt = 0.09`)
If you have defined your differential equation $\frac{\partial y(t)}{\partial t}$ in a rather simple format[^1], you can write the following code to examine how Euler integration approximates the analytical solution (in this example, we examine just two different step sizes, i.e., `dt = 0.1` and `dt = 0.09`)

```python
from jax import numpy as jnp, random, jit, nn
Expand Down Expand Up @@ -82,30 +71,13 @@ which should yield you a plot like the one below:

<img src="../../images/tutorials/neurocog/euler_integration.jpg" width="500" />

Notice how the integration constant `dt` (or $\Delta t$) chosen affects the approximation of ngc-learn's Euler
integrator and typically, when constructing your biophysical models, you will need to think about this constant in
the context of your simulation time-scale and what you intend to model. Note that, in many biophysical component cells,
you will have an integration time constant of some form, i.e., a $\tau$, that you can control, allowing you to fix
your `dt` to your simulated time-scale (say to a value like `dt = 1` millisecond) while tuning/altering your time
constant $\tau$ (since the differential equation will be weighted by $\frac{\Delta t}{\tau}$).
Notice how the integration constant `dt` (or $\Delta t$) chosen affects the approximation of ngc-learn's Euler integrator and typically, when constructing your biophysical models, you will need to think about this constant in the context of your simulation time-scale and what you intend to model. Note that, in many biophysical component cells, you will have an integration time constant of some form, i.e., a $\tau$, that you can control, allowing you to fix your `dt` to your simulated time-scale (say to a value like `dt = 1` millisecond) while tuning/altering your time constant $\tau$ (since the differential equation will be weighted by $\frac{\Delta t}{\tau}$).

## Higher-Order Forms of (Explicit) Integration

Notably, ngc-learn has built-in several forms of (explicit) numerical integration beyond the Euler method, such as a
second order Runge-Kutta (RK-2) method (also known as the midpoint method) and 4th-order Runge-Kutta (RK-4) method or
an error-predictor method such as Heun's method (also known as the trapezoid method). These forms of integration might
be useful particularly if a cell or plastic synaptic component you might be writing follows dynamics that are more
nonlinear or biophysically complex (requiring a higher degree of simulation accuracy). For instance, ngc-learn's
in-built cell components, particularly those of higher biophysical complexity -- like the [Izhikevich cell](ngclearn.components.neurons.spiking.izhikevichCell) or
the [FitzhughNagumo cell](ngclearn.components.neurons.spiking.fitzhughNagumoCell) -- contain argument flags for switching their simulation steps to use RK-2.

To illustrate the value of higher-order numerical integration methods, let us examine a simple polynomial equation
(thus nonlinear) that is further non-autonomous, i.e., it is a function of the time variable $t$ itself. A possible set
of dynamics in this case might be: $\frac{\partial y(t)}{\partial t} = -2 t^3 + 12 t^2 - 20 t + 8.5$ which has the
analytic solution $y(t) = -(1/2) t^4 + 4 t^3 - 10 t^2 + 8.5 t + C$ (where we will set $C = 1$). You can write code
like below, importing from `ngclearn.utils.diffeq.ode_utils` the Euler routine (`step_euler`), the RK-2 routine
(`step_rk2`), the RK-4 routine (`step_rk4`), and Heun's method (`step_heun`), and compare how these methods
approximate the nonlinear dynamics inherent to our constructed $\frac{\partial y(t)}{\partial t}$ ODE below:
Notably, ngc-learn has built-in several forms of (explicit) numerical integration beyond the Euler method, such as a second order Runge-Kutta (RK-2) method (also known as the midpoint method) and 4th-order Runge-Kutta (RK-4) method or an error-predictor method such as Heun's method (also known as the trapezoid method). These forms of integration might be useful particularly if a cell or plastic synaptic component you might be writing follows dynamics that are more nonlinear or biophysically complex (requiring a higher degree of simulation accuracy). For instance, ngc-learn's in-built cell components, particularly those of higher biophysical complexity -- like the [Izhikevich cell](ngclearn.components.neurons.spiking.izhikevichCell) or the [FitzhughNagumo cell](ngclearn.components.neurons.spiking.fitzhughNagumoCell) -- contain argument flags for switching their simulation steps to use RK-2.

To illustrate the value of higher-order numerical integration methods, let us examine a simple polynomial equation (thus nonlinear) that is further non-autonomous, i.e., it is a function of the time variable $t$ itself. A possible set of dynamics in this case might be: $\frac{\partial y(t)}{\partial t} = -2 t^3 + 12 t^2 - 20 t + 8.5$ which has the analytic solution $y(t) = -(1/2) t^4 + 4 t^3 - 10 t^2 + 8.5 t + C$ (where we will set $C = 1$). You can write code like below, importing from `ngclearn.utils.diffeq.ode_utils` the Euler routine (`step_euler`), the RK-2 routine (`step_rk2`), the RK-4 routine (`step_rk4`), and Heun's method (`step_heun`), and compare how these methods approximate the nonlinear dynamics inherent to our constructed $\frac{\partial y(t)}{\partial t}$ ODE below:

```python
from jax import numpy as jnp, random, jit, nn
Expand Down Expand Up @@ -176,15 +148,11 @@ which should yield you a plot like the one below:

<img src="../../images/tutorials/neurocog/ode_method_comparison.jpg" width="500" />

As you might observe, RK-4 give the best approximation of the solution. In addition, when the integration step size is
held constant, Euler integration does quite poorly over just a few steps while RK-2 and Heun's method do much better at
approximating the analytical equation. In the end, the type of numerical integration method employed can matter
depending on the ODE(s) you use in modeling, particularly if you seek higher accuracy for more nonlinear dynamics like
in our example above.

[^1]: The format expected by ngc-learn's backend is that the differential equation provides a functional API/form
like so: for instance `dy/dt = diff_eqn(t, y(t), params)`, representing
$\frac{\partial \mathbf{y}(t, \text{params})}{\partial t}$, noting that you can name your 3-argument function (and
its arguments) anything you like. Your function does not need to use all of the arguments (i.e., `t`, `y`, or
`params`, the last of which is a tuple containing any fixed constants your equation might need) to produce its
output. Finally, this function should only return the value(s) for `dy/dt` (vectors/matrices of values).
As you might observe, RK-4 give the best approximation of the solution. In addition, when the integration step size is held constant, Euler integration does quite poorly over just a few steps while RK-2 and Heun's method do much better at approximating the analytical equation. In the end, the type of numerical integration method employed can matter depending on the ODE(s) you use in modeling, particularly if you seek higher accuracy for more nonlinear dynamics like in our example above.

[^1]: The format expected by ngc-learn's backend is that the differential equation
provides a functional API/form like so: for instance `dy/dt = diff_eqn(t, y(t), params)`,
representing $\frac{\partial \mathbf{y}(t, \text{params})}{\partial t}$,
noting that you can name your 3-argument function (and its arguments) anything you like.
Your function does not need to use all of the arguments (i.e., `t`, `y`, or `params`, the last of which is a tuple containing any fixed constants your equation might need) to produce its output.
Finally, this function should only return the value(s) for `dy/dt` (vectors/matrices of values).
44 changes: 30 additions & 14 deletions ngclearn/components/neurons/graded/leakyNoiseCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,23 @@ class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell

The specific differential equation that characterizes this cell is (for adjusting x) is:

| tau_x * dx/dt = -x + j_rec + j_in + sqrt(2 alpha (sigma_rec)^2) * eps
| tau_x * dx/dt = -x + j_rec + j_in + sqrt(2 alpha (sigma_pre)^2) * eps; and,
| r = f(x) + (eps * sigma_post).
| where j_in is the set of incoming input signals
| and j_rec is the set of recurrent input signals
| and eps is a sample of unit Gaussian noise, i.e., eps ~ N(0, 1)
| and f(x) is the rectification function
| and sigma_pre is the pre-rectification noise applied to membrane x
| and sigma_post is the post-rectification noise applied to rates f(x)

| --- Cell Input Compartments: ---
| j_input - input (bottom-up) electrical/stimulus current (takes in external signals)
| j_recurrent - recurrent electrical/stimulus pressure
| --- Cell State Compartments ---
| x - noisy rate activity / current value of state
| --- Cell Output Compartments: ---
| r - post-rectified activity, i.e., fx(x) = relu(x)
| r - post-rectified activity, e.g., fx(x) = relu(x)
| r_prime - post-rectified temporal derivative, e.g., dfx(x) = d_relu(x)

Args:
name: the string name of this cell
Expand All @@ -53,19 +58,23 @@ class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell
:Note: setting the integration type to the midpoint method will increase the accuracy of the estimate of
the cell's evolution at an increase in computational cost (and simulation time)

sigma_rec: noise scaling factor / standard deviation (Default: 1)
sigma_pre: pre-rectification noise scaling factor / standard deviation (Default: 0.1)

sigma_post: post-rectification noise scaling factor / standard deviation (Default: 0.)

leak_scale: degree to which membrane leak should be scaled (Default: 1)
"""

# Define Functions
def __init__(
self, name, n_units, tau_x, act_fx="relu", integration_type="euler", batch_size=1, sigma_rec=1.,
leak_scale=1., shape=None, **kwargs
self, name, n_units, tau_x, act_fx="relu", integration_type="euler", batch_size=1, sigma_pre=0.1,
sigma_post=0.1, leak_scale=1., shape=None, **kwargs
):
super().__init__(name, **kwargs)


self.tau_x = tau_x
self.sigma_rec = sigma_rec ## a "resistance" scaling factor
self.sigma_pre = sigma_pre ## a pre-rectification scaling factor
self.sigma_post = sigma_post ## a post-rectification scaling factor
self.leak_scale = leak_scale ## the leak scaling factor (most appropriate default is 1)

## integration properties
Expand All @@ -89,13 +98,17 @@ def __init__(
self.j_input = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current
self.j_recurrent = Compartment(restVals, display_name="Recurrent Stimulus Current", units="mA") # electrical current
self.x = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity
self.r = Compartment(restVals, display_name="Rectified Rate Activity") # rectified output
self.r = Compartment(restVals, display_name="(Rectified) Rate Activity") # rectified output
self.r_prime = Compartment(restVals, display_name="Derivative of rate activity")

@compilable
def advance_state(self, t, dt):
### run a step of integration over neuronal dynamics
## run a step of integration over neuronal dynamics
### Note: self.fx is the "rectifier" (rectification function)
key, skey = random.split(self.key.get(), 2)
eps_pre = random.normal(skey, shape=self.x.get().shape) ## pre-rectifier distributional noise
key, skey = random.split(self.key.get(), 2)
eps = random.normal(skey, shape=self.x.get().shape) ## sample of unit distributional noise
eps_post = random.normal(skey, shape=self.x.get().shape) ## post-rectifier distributional noise

#x = _run_cell(dt, self.j_input.get(), self.j_recurrent.get(), self.x.get(), eps, self.tau_x, self.sigma_rec, integType=self.intgFlag)
_step_fns = {
Expand All @@ -104,14 +117,16 @@ def advance_state(self, t, dt):
2: step_rk4,
}
_step_fn = _step_fns[self.intgFlag] #_step_fns.get(self.intgFlag, step_euler)
params = (self.j_input.get(), self.j_recurrent.get(), eps, self.tau_x, self.sigma_rec, self.leak_scale)
params = (self.j_input.get(), self.j_recurrent.get(), eps_pre, self.tau_x, self.sigma_pre, self.leak_scale)
_, x = _step_fn(0., self.x.get(), _dfz, dt, params) ## update state activation dynamics
r = self.fx(x) ## calculate rectified / post-activation function value(s)
r = self.fx(x) + (eps_post * self.sigma_post) ## calculate (rectified) activity rates; f(x)
r_prime = self.dfx(x) ## calculate local deriv of activity rates; f'(x)

## set compartments to next state values in accordance with dynamics
self.key.set(key)
self.key.set(key) ## carry noise key over transition (to next state of component)
self.x.set(x)
self.r.set(r)
self.r_prime.set(r_prime)

@compilable
def reset(self):
Expand All @@ -123,6 +138,7 @@ def reset(self):
self.j_recurrent.set(restVals)
self.x.set(restVals)
self.r.set(restVals)
self.r_prime.set(restVals)

@classmethod
def help(cls): ## component help function
Expand All @@ -142,7 +158,7 @@ def help(cls): ## component help function
"n_units": "Number of neuronal cells to model in this layer",
"batch_size": "Batch size dimension of this component",
"tau_x": "State time constant",
"sigma_rec": "The non-zero degree/scale of noise to inject into this neuron"
"sigma_pre": "The non-zero degree/scale of (pre-rectification) noise to inject into this neuron"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
Expand Down
2 changes: 1 addition & 1 deletion ngclearn/components/neurons/graded/rateCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def advance_state(self, dt):
## self.pressure <-- "top-down" expectation / contextual pressure
## self.current <-- "bottom-up" data-dependent signal
dfx_val = self.dfx(z)
j = _modulate(j, dfx_val)
j = _modulate(j, dfx_val) ## TODO: make this optional (for NGC circuit dynamics)
j = j * self.resist_scale
tmp_z = _run_cell(
dt, j, j_td, z, self.tau_m, leak_gamma=self.priorLeakRate, integType=self.intgFlag,
Expand Down
12 changes: 10 additions & 2 deletions ngclearn/components/synapses/denseSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,20 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable
p_conn: probability of a connection existing (default: 1.); setting
this to < 1 and > 0. will result in a sparser synaptic structure
(lower values yield sparse structure)

mask: if non-None, a (multiplicative) mask is applied to this synaptic weight matrix
"""

def __init__(
self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs
self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., mask=None, batch_size=1,
**kwargs
):
super().__init__(name, **kwargs)

self.batch_size = batch_size
self.mask = 1.
if mask is not None:
self.mask = mask

## Synapse meta-parameters
self.shape = shape
Expand Down Expand Up @@ -79,7 +85,9 @@ def __init__(

@compilable
def advance_state(self):
self.outputs.set((jnp.matmul(self.inputs.get(), self.weights.get()) * self.resist_scale) + self.biases.get())
weights = self.weights.get()
weights = weights * self.mask
self.outputs.set((jnp.matmul(self.inputs.get(), weights) * self.resist_scale) + self.biases.get())

@compilable
def reset(self):
Expand Down
Loading
Loading