diff --git a/docs/tutorials/neurocog/integration.md b/docs/tutorials/neurocog/integration.md
index 7e761208..95320c6f 100644
--- a/docs/tutorials/neurocog/integration.md
+++ b/docs/tutorials/neurocog/integration.md
@@ -1,25 +1,14 @@
# Numerical Integration
-In constructing one's own biophysical models, particularly those of phenomena that change with time, ngc-learn offers
-useful flexible tools for numerical integration that facilitate an easier time in constructing your own components that
-play well with the library's simulation backend. Knowing how things work beyond Euler integration -- the base/default
-form of integration often employed by ngc-learn -- might be useful for constructing and simulating dynamics more
-accurately (often at the cost of additional computational time).
+In constructing one's own biophysical models, particularly those of phenomena that change with time, ngc-learn offers useful, flexible tools for numerical integration that make it easier to construct your own components that play well with the library's simulation backend. Knowing how things work beyond Euler integration -- the base/default form of integration employed by ngc-learn -- can be useful for constructing and simulating dynamics more accurately (often at the cost of additional computational time).
## Euler Integration
-Euler integration is very simple (and fast) way of using the ordinary differential equations you typically define for
-the cellular dynamics of various components in ngc-learn (which typically get called in any component's `advance_state()` command).
+Euler integration is a very simple (and fast) way of using the ordinary differential equations you typically define for the cellular dynamics of various components in ngc-learn (these equations typically get called within a component's `advance_state()` command).
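+
+In other words, a single Euler step advances the state according to $y(t + \Delta t) \approx y(t) + \Delta t \frac{\partial y(t)}{\partial t}$.
+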
-While utilizing the numerical integrator will depend on your component's design and the (biophysical) elements you wish
-to model, let's observe ngc-learn's base backend utilities (its integration backend `ngclearn.utils.diffeq`) in the
-context of numerically integrating a simple differential equation; specifically the autonomous (linear) ordinary
-differential equation (ODE): $\frac{\partial y(t)}{\partial t} = y(t)$. The analytic solution to this equation is
-also simple -- it is $y(t) = e^{t}$.
+While how you utilize the numerical integrator will depend on your component's design and the (biophysical) elements you wish to model, let's observe ngc-learn's base backend utilities (its integration backend `ngclearn.utils.diffeq`) in the context of numerically integrating a simple differential equation: specifically, the autonomous (linear) ordinary differential equation (ODE) $\frac{\partial y(t)}{\partial t} = y(t)$. The analytic solution to this equation is also simple -- it is $y(t) = e^{t}$.
-If you have defined your differential equation $\frac{\partial y(t)}{\partial t}$ in a rather simple format[^1], you
-can write the following code to examine how Euler integration approximates the analytical solution (in this example,
-we examine just two different step sizes, i.e., `dt = 0.1` and `dt = 0.09`)
+If you have defined your differential equation $\frac{\partial y(t)}{\partial t}$ in a rather simple format[^1], you can write the following code to examine how Euler integration approximates the analytical solution (in this example, we examine just two different step sizes, i.e., `dt = 0.1` and `dt = 0.09`):
```python
from jax import numpy as jnp, random, jit, nn
@@ -82,30 +71,13 @@ which should yield you a plot like the one below:
-Notice how the integration constant `dt` (or $\Delta t$) chosen affects the approximation of ngc-learn's Euler
-integrator and typically, when constructing your biophysical models, you will need to think about this constant in
-the context of your simulation time-scale and what you intend to model. Note that, in many biophysical component cells,
-you will have an integration time constant of some form, i.e., a $\tau$, that you can control, allowing you to fix
-your `dt` to your simulated time-scale (say to a value like `dt = 1` millisecond) while tuning/altering your time
-constant $\tau$ (since the differential equation will be weighted by $\frac{\Delta t}{\tau}$).
+Notice how the chosen integration constant `dt` (or $\Delta t$) affects the quality of the approximation produced by ngc-learn's Euler integrator; typically, when constructing your biophysical models, you will need to think about this constant in the context of your simulation time-scale and what you intend to model. Note that, in many biophysical component cells, you will have an integration time constant of some form, i.e., a $\tau$, that you can control, allowing you to fix your `dt` to your simulated time-scale (say, to a value like `dt = 1` millisecond) while tuning/altering your time constant $\tau$ (since the differential equation will be weighted by $\frac{\Delta t}{\tau}$).
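+
+As a minimal sketch of this pattern -- the leaky dynamics $\tau \frac{\partial x(t)}{\partial t} = -x(t) + j$ here, along with the names `dfx`, `j`, and `tau`, are purely illustrative -- you might pin `dt` to your simulation time-scale and treat $\tau$ as the constant to tune:
+
+```python
+from jax import numpy as jnp
+from ngclearn.utils.diffeq.ode_utils import step_euler
+
+def dfx(t, x, params): ## tau * dx/dt = -x + j, i.e., dx/dt = (-x + j)/tau
+    j, tau = params
+    return (-x + j) / tau ## the Euler update is thus effectively weighted by dt/tau
+
+x = jnp.zeros((1, 1)) ## state starts at rest
+j = jnp.ones((1, 1)) ## a fixed, illustrative input current
+dt, tau = 1., 20. ## dt fixed to a 1 ms time-scale; tau is what gets tuned
+for _ in range(100): ## simulate 100 ms of leaky integration
+    _, x = step_euler(0., x, dfx, dt, (j, tau))
+```
+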
## Higher-Order Forms of (Explicit) Integration
-Notably, ngc-learn has built-in several forms of (explicit) numerical integration beyond the Euler method, such as a
-second order Runge-Kutta (RK-2) method (also known as the midpoint method) and 4th-order Runge-Kutta (RK-4) method or
-an error-predictor method such as Heun's method (also known as the trapezoid method). These forms of integration might
-be useful particularly if a cell or plastic synaptic component you might be writing follows dynamics that are more
-nonlinear or biophysically complex (requiring a higher degree of simulation accuracy). For instance, ngc-learn's
-in-built cell components, particularly those of higher biophysical complexity -- like the [Izhikevich cell](ngclearn.components.neurons.spiking.izhikevichCell) or
-the [FitzhughNagumo cell](ngclearn.components.neurons.spiking.fitzhughNagumoCell) -- contain argument flags for switching their simulation steps to use RK-2.
-
-To illustrate the value of higher-order numerical integration methods, let us examine a simple polynomial equation
-(thus nonlinear) that is further non-autonomous, i.e., it is a function of the time variable $t$ itself. A possible set
-of dynamics in this case might be: $\frac{\partial y(t)}{\partial t} = -2 t^3 + 12 t^2 - 20 t + 8.5$ which has the
-analytic solution $y(t) = -(1/2) t^4 + 4 t^3 - 10 t^2 + 8.5 t + C$ (where we will set $C = 1$). You can write code
-like below, importing from `ngclearn.utils.diffeq.ode_utils` the Euler routine (`step_euler`), the RK-2 routine
-(`step_rk2`), the RK-4 routine (`step_rk4`), and Heun's method (`step_heun`), and compare how these methods
-approximate the nonlinear dynamics inherent to our constructed $\frac{\partial y(t)}{\partial t}$ ODE below:
+Notably, ngc-learn has several forms of (explicit) numerical integration built in beyond the Euler method, such as a second-order Runge-Kutta (RK-2) method (also known as the midpoint method), a fourth-order Runge-Kutta (RK-4) method, and a predictor-corrector method such as Heun's method (also known as the trapezoid method). These forms of integration might be particularly useful if a cell or plastic synaptic component you are writing follows dynamics that are more nonlinear or biophysically complex (requiring a higher degree of simulation accuracy). For instance, ngc-learn's in-built cell components, particularly those of higher biophysical complexity -- like the [Izhikevich cell](ngclearn.components.neurons.spiking.izhikevichCell) or the [FitzhughNagumo cell](ngclearn.components.neurons.spiking.fitzhughNagumoCell) -- contain argument flags for switching their simulation steps to use RK-2.
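+
+As a brief, hypothetical usage sketch -- here borrowing ngc-learn's graded `LeakyNoiseCell`, and assuming `"midpoint"` is the string value its `integration_type` argument accepts for RK-2 -- switching a component's integration scheme looks like:
+
+```python
+from ngclearn.components.neurons.graded.leakyNoiseCell import LeakyNoiseCell
+
+## the same cell, but its advance_state() step now evolves x via midpoint (RK-2) integration
+cell = LeakyNoiseCell(name="z0", n_units=64, tau_x=10., integration_type="midpoint")
+```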
+
+To illustrate the value of higher-order numerical integration methods, let us examine a simple polynomial (and thus nonlinear) equation that is further non-autonomous, i.e., it is a function of the time variable $t$ itself. A possible set of dynamics in this case might be: $\frac{\partial y(t)}{\partial t} = -2 t^3 + 12 t^2 - 20 t + 8.5$, which has the analytic solution $y(t) = -(1/2) t^4 + 4 t^3 - 10 t^2 + 8.5 t + C$ (where we will set $C = 1$). You can write code like that below, importing from `ngclearn.utils.diffeq.ode_utils` the Euler routine (`step_euler`), the RK-2 routine (`step_rk2`), the RK-4 routine (`step_rk4`), and Heun's method (`step_heun`), and compare how these methods approximate the nonlinear dynamics inherent to our constructed $\frac{\partial y(t)}{\partial t}$ ODE:
```python
from jax import numpy as jnp, random, jit, nn
@@ -176,15 +148,11 @@ which should yield you a plot like the one below:
-As you might observe, RK-4 give the best approximation of the solution. In addition, when the integration step size is
-held constant, Euler integration does quite poorly over just a few steps while RK-2 and Heun's method do much better at
-approximating the analytical equation. In the end, the type of numerical integration method employed can matter
-depending on the ODE(s) you use in modeling, particularly if you seek higher accuracy for more nonlinear dynamics like
-in our example above.
-
-[^1]: The format expected by ngc-learn's backend is that the differential equation provides a functional API/form
-like so: for instance `dy/dt = diff_eqn(t, y(t), params)`, representing
-$\frac{\partial \mathbf{y}(t, \text{params})}{\partial t}$, noting that you can name your 3-argument function (and
-its arguments) anything you like. Your function does not need to use all of the arguments (i.e., `t`, `y`, or
-`params`, the last of which is a tuple containing any fixed constants your equation might need) to produce its
-output. Finally, this function should only return the value(s) for `dy/dt` (vectors/matrices of values).
+As you might observe, RK-4 gives the best approximation of the solution. In addition, when the integration step size is held constant, Euler integration does quite poorly over just a few steps, while RK-2 and Heun's method do much better at approximating the analytical equation. This matches what the theory of these methods predicts: Euler integration is first-order accurate in the step size, RK-2 and Heun's method are second-order accurate, and RK-4 is fourth-order accurate. In the end, the type of numerical integration method employed can matter depending on the ODE(s) you use in modeling, particularly if you seek higher accuracy for more nonlinear dynamics like in our example above.
+
+[^1]: The format expected by ngc-learn's backend is that the differential equation
+    provides a functional API/form such as `dy/dt = diff_eqn(t, y(t), params)`,
+ representing $\frac{\partial \mathbf{y}(t, \text{params})}{\partial t}$,
+ noting that you can name your 3-argument function (and its arguments) anything you like.
+ Your function does not need to use all of the arguments (i.e., `t`, `y`, or `params`, the last of which is a tuple containing any fixed constants your equation might need) to produce its output.
+ Finally, this function should only return the value(s) for `dy/dt` (vectors/matrices of values).
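+    As a minimal, illustrative definition in this format (the name `dfy` is arbitrary; the polynomial is
+    the non-autonomous ODE from the section above):
+    ```python
+    def dfy(t, y, params): ## only t is used; y and params are ignored
+        return -2 * t**3 + 12 * t**2 - 20 * t + 8.5
+    ```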
diff --git a/ngclearn/components/neurons/graded/leakyNoiseCell.py b/ngclearn/components/neurons/graded/leakyNoiseCell.py
index 85c4cd03..9ccf4f8e 100755
--- a/ngclearn/components/neurons/graded/leakyNoiseCell.py
+++ b/ngclearn/components/neurons/graded/leakyNoiseCell.py
@@ -22,10 +22,14 @@ class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell
- The specific differential equation that characterizes this cell is (for adjusting x) is:
+ The specific differential equation that characterizes this cell (for adjusting x) is:
- | tau_x * dx/dt = -x + j_rec + j_in + sqrt(2 alpha (sigma_rec)^2) * eps
+ | tau_x * dx/dt = -x + j_rec + j_in + sqrt(2 alpha (sigma_pre)^2) * eps; and,
+ | r = f(x) + (eps * sigma_post).
| where j_in is the set of incoming input signals
| and j_rec is the set of recurrent input signals
| and eps is a sample of unit Gaussian noise, i.e., eps ~ N(0, 1)
+ | and f(x) is the rectification function
+ | and sigma_pre is the pre-rectification noise applied to membrane x
+ | and sigma_post is the post-rectification noise applied to rates f(x)
| --- Cell Input Compartments: ---
| j_input - input (bottom-up) electrical/stimulus current (takes in external signals)
@@ -33,7 +37,8 @@ class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell
| --- Cell State Compartments ---
| x - noisy rate activity / current value of state
| --- Cell Output Compartments: ---
- | r - post-rectified activity, i.e., fx(x) = relu(x)
+ | r - post-rectified activity, e.g., fx(x) = relu(x)
+ | r_prime - derivative of the rectification function evaluated at x, e.g., dfx(x) = d_relu(x)
Args:
name: the string name of this cell
@@ -53,19 +58,23 @@ class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell
:Note: setting the integration type to the midpoint method will increase the accuracy of the estimate of
the cell's evolution at an increase in computational cost (and simulation time)
- sigma_rec: noise scaling factor / standard deviation (Default: 1)
+ sigma_pre: pre-rectification noise scaling factor / standard deviation (Default: 0.1)
+
+ sigma_post: post-rectification noise scaling factor / standard deviation (Default: 0.1)
+
+ leak_scale: degree to which membrane leak should be scaled (Default: 1)
"""
- # Define Functions
def __init__(
- self, name, n_units, tau_x, act_fx="relu", integration_type="euler", batch_size=1, sigma_rec=1.,
- leak_scale=1., shape=None, **kwargs
+ self, name, n_units, tau_x, act_fx="relu", integration_type="euler", batch_size=1, sigma_pre=0.1,
+ sigma_post=0.1, leak_scale=1., shape=None, **kwargs
):
super().__init__(name, **kwargs)
self.tau_x = tau_x
- self.sigma_rec = sigma_rec ## a "resistance" scaling factor
+ self.sigma_pre = sigma_pre ## a pre-rectification scaling factor
+ self.sigma_post = sigma_post ## a post-rectification scaling factor
self.leak_scale = leak_scale ## the leak scaling factor (most appropriate default is 1)
## integration properties
@@ -89,13 +98,17 @@ def __init__(
self.j_input = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current
self.j_recurrent = Compartment(restVals, display_name="Recurrent Stimulus Current", units="mA") # electrical current
self.x = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity
- self.r = Compartment(restVals, display_name="Rectified Rate Activity") # rectified output
+ self.r = Compartment(restVals, display_name="(Rectified) Rate Activity") # rectified output
+ self.r_prime = Compartment(restVals, display_name="Derivative of Rate Activity")
@compilable
def advance_state(self, t, dt):
- ### run a step of integration over neuronal dynamics
- key, skey = random.split(self.key.get(), 2)
- eps = random.normal(skey, shape=self.x.get().shape) ## sample of unit distributional noise
+ ## run a step of integration over neuronal dynamics
+ ## Note: self.fx is the "rectifier" (rectification function)
+ key, skey = random.split(self.key.get(), 2)
+ eps_pre = random.normal(skey, shape=self.x.get().shape) ## pre-rectifier distributional noise
+ key, skey = random.split(key, 2) ## re-split the carried key so eps_post is sampled independently of eps_pre
+ eps_post = random.normal(skey, shape=self.x.get().shape) ## post-rectifier distributional noise
- #x = _run_cell(dt, self.j_input.get(), self.j_recurrent.get(), self.x.get(), eps, self.tau_x, self.sigma_rec, integType=self.intgFlag)
_step_fns = {
@@ -104,14 +117,16 @@ def advance_state(self, t, dt):
2: step_rk4,
}
_step_fn = _step_fns[self.intgFlag] #_step_fns.get(self.intgFlag, step_euler)
- params = (self.j_input.get(), self.j_recurrent.get(), eps, self.tau_x, self.sigma_rec, self.leak_scale)
+ params = (self.j_input.get(), self.j_recurrent.get(), eps_pre, self.tau_x, self.sigma_pre, self.leak_scale)
_, x = _step_fn(0., self.x.get(), _dfz, dt, params) ## update state activation dynamics
- r = self.fx(x) ## calculate rectified / post-activation function value(s)
+ r = self.fx(x) + (eps_post * self.sigma_post) ## rectified activity rates f(x) plus scaled post-rectification noise
+ r_prime = self.dfx(x) ## calculate local deriv of activity rates; f'(x)
## set compartments to next state values in accordance with dynamics
- self.key.set(key)
+ self.key.set(key) ## carry noise key over transition (to next state of component)
self.x.set(x)
self.r.set(r)
+ self.r_prime.set(r_prime)
@compilable
def reset(self):
@@ -123,6 +138,7 @@ def reset(self):
self.j_recurrent.set(restVals)
self.x.set(restVals)
self.r.set(restVals)
+ self.r_prime.set(restVals)
@classmethod
def help(cls): ## component help function
@@ -142,7 +158,7 @@ def help(cls): ## component help function
"n_units": "Number of neuronal cells to model in this layer",
"batch_size": "Batch size dimension of this component",
"tau_x": "State time constant",
- "sigma_rec": "The non-zero degree/scale of noise to inject into this neuron"
+ "sigma_pre": "The non-zero degree/scale of (pre-rectification) noise to inject into this neuron"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py
index f70b0f52..3cf50a22 100755
--- a/ngclearn/components/neurons/graded/rateCell.py
+++ b/ngclearn/components/neurons/graded/rateCell.py
@@ -226,7 +226,7 @@ def advance_state(self, dt):
## self.pressure <-- "top-down" expectation / contextual pressure
## self.current <-- "bottom-up" data-dependent signal
dfx_val = self.dfx(z)
- j = _modulate(j, dfx_val)
+ j = _modulate(j, dfx_val) ## TODO: make this optional (for NGC circuit dynamics)
j = j * self.resist_scale
tmp_z = _run_cell(
dt, j, j_td, z, self.tau_m, leak_gamma=self.priorLeakRate, integType=self.intgFlag,
diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py
index 977f2464..d7980c4c 100755
--- a/ngclearn/components/synapses/denseSynapse.py
+++ b/ngclearn/components/synapses/denseSynapse.py
@@ -36,14 +36,20 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable
p_conn: probability of a connection existing (default: 1.); setting
this to < 1 and > 0. will result in a sparser synaptic structure
(lower values yield sparse structure)
+
+ mask: if non-None, a (multiplicative) mask is applied to this synaptic weight matrix each time it is used (e.g., a binary matrix of the same shape can enforce a fixed sparse connectivity pattern)
"""
def __init__(
- self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs
+ self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., mask=None, batch_size=1,
+ **kwargs
):
super().__init__(name, **kwargs)
self.batch_size = batch_size
+ self.mask = mask if mask is not None else 1. ## multiplicative weight mask (1. = no masking)
## Synapse meta-parameters
self.shape = shape
@@ -79,7 +85,9 @@ def __init__(
@compilable
def advance_state(self):
- self.outputs.set((jnp.matmul(self.inputs.get(), self.weights.get()) * self.resist_scale) + self.biases.get())
+ weights = self.weights.get()
+ weights = weights * self.mask
+ self.outputs.set((jnp.matmul(self.inputs.get(), weights) * self.resist_scale) + self.biases.get())
@compilable
def reset(self):
diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py
index f0814443..1f6c9a07 100644
--- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py
+++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py
@@ -86,7 +86,7 @@ def _enforce_constraints(W, w_bound, is_nonnegative=True):
"""
_W = W
if w_bound > 0.:
- if is_nonnegative == True:
+ if is_nonnegative:
_W = jnp.clip(_W, 0., w_bound)
else:
_W = jnp.clip(_W, -w_bound, w_bound)
@@ -173,7 +173,10 @@ def __init__(
prior=("constant", 0.), w_decay=0., sign_value=1., optim_type="sgd", pre_wght=1., post_wght=1.,
p_conn=1., resist_scale=1., batch_size=1, **kwargs
):
- super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, batch_size=batch_size, **kwargs)
+ super().__init__(
+ name, shape=shape, weight_init=weight_init, bias_init=bias_init, resist_scale=resist_scale, p_conn=p_conn,
+ batch_size=batch_size, **kwargs
+ )
if w_decay > 0.:
prior = ('l2', w_decay)
@@ -243,19 +246,20 @@ def calc_update(self):
post = self.post.get()
weights = self.weights.get()
biases = self.biases.get()
- opt_params = self.opt_params.get()
+ #opt_params = self.opt_params.get()
## calculate synaptic update values
dWeights, dBiases = HebbianSynapse._compute_update(
- self.w_bound, self.is_nonnegative, self.sign_value, self.prior_type, self.prior_lmbda, self.pre_wght, self.post_wght,
- pre, post, weights
+ self.w_bound, self.is_nonnegative, self.sign_value, self.prior_type, self.prior_lmbda, self.pre_wght,
+ self.post_wght, pre, post, weights
)
self.dWeights.set(dWeights)
self.dBiases.set(dBiases)
+ #self.opt_params.set(opt_params)
@compilable
- def evolve(self):
+ def evolve(self, dt):
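+ ## note: dt is not used directly by this Hebbian update (presumably accepted for evolve() signature uniformity)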
# Get the variables
pre = self.pre.get()
post = self.post.get()
@@ -268,6 +272,7 @@ def evolve(self):
self.w_bound, self.is_nonnegative, self.sign_value, self.prior_type, self.prior_lmbda, self.pre_wght, self.post_wght,
pre, post, weights
)
+
## conduct a step of optimization - get newly evolved synaptic weight value matrix
- if self.bias_init != None:
+ if self.bias_init is not None:
opt_params, [weights, biases] = self.opt(opt_params, [weights, biases], [dWeights, dBiases])
diff --git a/ngclearn/utils/metric_utils.py b/ngclearn/utils/metric_utils.py
index e5a61eb4..f91dda5b 100755
--- a/ngclearn/utils/metric_utils.py
+++ b/ngclearn/utils/metric_utils.py
@@ -308,7 +308,7 @@ def measure_CatNLL(p, x, offset=1e-7, preserve_batch=False):
nll = jnp.mean(nll)
- return nll #tf.reduce_mean(nll)
+ return nll
-@jit
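+## note: preserve_batch is a Python bool that changes the output's shape, so it must be marked static for jit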
+@partial(jit, static_argnums=[2])
def measure_RMSE(mu, x, preserve_batch=False):
"""
Measures root mean squared error (RMSE). Note: If batch is preserved, this returns a column vector where each
@@ -328,7 +328,7 @@ def measure_RMSE(mu, x, preserve_batch=False):
mse = measure_MSE(mu, x, preserve_batch=preserve_batch)
return jnp.sqrt(mse) ## sqrt(MSE) is the root-mean-squared-error
-@jit
+@partial(jit, static_argnums=[2])
def measure_MSE(mu, x, preserve_batch=False):
"""
Measures mean squared error (MSE), or the negative Gaussian log likelihood with variance of 1.0. Note: If batch
@@ -352,7 +352,7 @@ def measure_MSE(mu, x, preserve_batch=False):
mse = jnp.mean(mse) # this is proper mse
return mse
-@jit
+@partial(jit, static_argnums=[2])
def measure_MAE(shift, x, preserve_batch=False):
"""
Measures mean absolute error (MAE), or the negative Laplacian log likelihood with scale of 1.0. Note: If batch
@@ -376,7 +376,7 @@ def measure_MAE(shift, x, preserve_batch=False):
mae = jnp.mean(mae) # this is proper mae
return mae
-@jit
+@partial(jit, static_argnums=[3])
def measure_BCE(p, x, offset=1e-7, preserve_batch=False): #1e-10
"""
Calculates the negative Bernoulli log likelihood or binary cross entropy (BCE). Note: If batch is preserved,