diff --git a/CITATION.cff b/CITATION.cff
index e7243de9..ee939cc7 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -6,10 +6,12 @@ authors:
orcid: https://orcid.org/0000-0002-2590-1310
- family-names: Gebhardt
given-names: William
+ orcid: https://orcid.org/0009-0008-7456-6556
- family-names: Mali
given-names: Ankur
+ orcid: https://orcid.org/0000-0001-5813-3584
title: "ngc-learn"
-version: 1.0.0
+version: 3.0.0
identifiers:
- type: doi
value: 10.5281/zenodo.6605728
diff --git a/README.md b/README.md
index 26d37e56..accc2543 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,7 @@ ngc-learn requires:
1) Python (>=3.10)
2) NumPy (>=1.22.0)
3) SciPy (>=1.7.0)
-4) ngcsimlib (>=1.0.1), (visit official page here)
+4) ngcsimlib (>=3.0.0), (visit official page here)
5) JAX (>=0.4.28) (to enable GPU use, make sure to install one of the CUDA variants)
---
-ngc-learn 2.0.3 and later require Python 3.10 or newer as well as ngcsimlib >=1.0.1.
+ngc-learn 3.0.0 and later require Python 3.10 or newer as well as ngcsimlib >=3.0.0.
ngc-learn's plotting capabilities (routines within `ngclearn.utils.viz`) require
Matplotlib (>=3.8.0) and imageio (>=2.31.5) and both plotting and density estimation
tools (routines within ``ngclearn.utils.density``) will require Scikit-learn (>=0.24.2).
@@ -75,7 +75,7 @@ Python 3.11.4 (main, MONTH DAY YEAR, TIME) [GCC XX.X.X] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import ngclearn
>>> ngclearn.__version__
-'2.0.3'
+'3.0.0'
```
Note: For access to the previous Tensorflow-2 version of ngc-learn (of
@@ -122,7 +122,7 @@ $ python install -e .
**Version:**
-2.0.3
+3.0.0
Author:
Alexander G. Ororbia II
diff --git a/docs/images/museum/harmonium/rbm_arch.jpg b/docs/images/museum/harmonium/rbm_arch.jpg
new file mode 100755
index 00000000..9da2f79c
Binary files /dev/null and b/docs/images/museum/harmonium/rbm_arch.jpg differ
diff --git a/docs/images/museum/harmonium/rbm_recon.jpg b/docs/images/museum/harmonium/rbm_recon.jpg
new file mode 100644
index 00000000..4bbce883
Binary files /dev/null and b/docs/images/museum/harmonium/rbm_recon.jpg differ
diff --git a/docs/images/museum/harmonium/receptive_fields.jpg b/docs/images/museum/harmonium/receptive_fields.jpg
new file mode 100644
index 00000000..e0e5465a
Binary files /dev/null and b/docs/images/museum/harmonium/receptive_fields.jpg differ
diff --git a/docs/images/museum/harmonium/samples_0.jpg b/docs/images/museum/harmonium/samples_0.jpg
new file mode 100755
index 00000000..0055fd0b
Binary files /dev/null and b/docs/images/museum/harmonium/samples_0.jpg differ
diff --git a/docs/images/museum/harmonium/samples_1.jpg b/docs/images/museum/harmonium/samples_1.jpg
new file mode 100755
index 00000000..f9310647
Binary files /dev/null and b/docs/images/museum/harmonium/samples_1.jpg differ
diff --git a/docs/images/museum/harmonium/samples_2.jpg b/docs/images/museum/harmonium/samples_2.jpg
new file mode 100755
index 00000000..0f132374
Binary files /dev/null and b/docs/images/museum/harmonium/samples_2.jpg differ
diff --git a/docs/images/tutorials/neurocog/gmm_fit.jpg b/docs/images/tutorials/neurocog/gmm_fit.jpg
new file mode 100644
index 00000000..64a51715
Binary files /dev/null and b/docs/images/tutorials/neurocog/gmm_fit.jpg differ
diff --git a/docs/images/tutorials/neurocog/gmm_samples.jpg b/docs/images/tutorials/neurocog/gmm_samples.jpg
new file mode 100644
index 00000000..3264dbde
Binary files /dev/null and b/docs/images/tutorials/neurocog/gmm_samples.jpg differ
diff --git a/docs/index.rst b/docs/index.rst
index 969753d7..1b710677 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -5,10 +5,7 @@
Welcome to ngc-learn's documentation!
=====================================
-**ngc-learn** is a Python library for building, simulating, and analyzing
-biomimetic computational models, arbitrary predictive processing/coding models,
-and spiking neural networks. This toolkit is built on top of
-`JAX `_ and is distributed under the 3-Clause BSD license.
+**ngc-learn** is a Python library for building, simulating, and analyzing biomimetic and NeuroAI computational models, arbitrary predictive processing/coding models, spiking neural networks, and general dynamical systems. This toolkit is built on top of `JAX `_ and is distributed under the 3-Clause BSD license.
.. toctree::
:maxdepth: 1
@@ -23,6 +20,7 @@ and spiking neural networks. This toolkit is built on top of
tutorials/intro
tutorials/theory
+ tutorials/configuration/index
tutorials/index
tutorials/neurocog/index
@@ -52,9 +50,10 @@ and spiking neural networks. This toolkit is built on top of
.. toctree::
:maxdepth: 1
- :caption: Papers that use NGC-Learn
+ :caption: NGC-Learn Papers & Media
ngclearn_papers
+ ngclearn_talks
Indices and tables
==================
diff --git a/docs/installation.md b/docs/installation.md
index 64bcc5c1..01474aa1 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -1,65 +1,41 @@
# Installation
-**ngc-learn** officially supports Linux on Python 3. It can be run with or
-without a GPU.
+**ngc-learn** officially supports Linux on Python 3. It can be run with or without a GPU.
-Setup: ngc-learn,
-in its entirety (including its supporting utilities),
-requires that you ensure that you have installed the following base dependencies in
-your system. Note that this library was developed and tested on Ubuntu 22.04 (and earlier versions on 18.04/20.04).
-Specifically, ngc-learn requires:
+Setup: NGC-Learn, in its entirety (including its supporting utility sub-packages), requires that you have installed the following base dependencies on your system. Note that this library was developed and tested on Ubuntu 22.04 (with much earlier versions developed on Ubuntu 18.04/20.04).
+Specifically, NGC-Learn requires:
* Python (>=3.10)
-* ngcsimlib (>=1.0.1), (official page)
+* ngcsimlib (>=3.0.0), (official page)
* NumPy (>=1.22.0)
* SciPy (>=1.7.0)
* JAX (>= 0.4.28; and jaxlib>=0.4.28)
* Matplotlib (>=3.8.0), (for `ngclearn.utils.viz`)
* Scikit-learn (>=1.6.1), (for `ngclearn.utils.patch_utils` and `ngclearn.utils.density`)
-Note that the above requirements are taken care of if one installs ngc-learn
-through either `pip`. One can either install the CPU version of ngc-learn (if no JAX is
-pre-installed or only the CPU version of JAX is installed currently) via
+Note that the above requirements are taken care of if one installs NGC-Learn through `pip`. One can either install the CPU version of NGC-Learn (if no JAX is pre-installed or only the CPU version of JAX is currently installed) via:
```console
$ pip install ngclearn
```
-or install the GPU version of ngc-learn by first installing the
-CUDA 12
-version of JAX before running the above pip command.
+or install the GPU version of NGC-Learn by first installing the CUDA 12 version of JAX before running the above pip command.
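+
+For reference, at the time of writing, the CUDA 12 variant of JAX can typically be installed with a command along the lines of the following (please consult the official JAX installation page for the exact, up-to-date command for your system, as it may change):
+
+```console
+$ pip install -U "jax[cuda12]"
+```
+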
-Alternatively, one may locally, step-by-step (see below), install and setup
-ngc-learn from source after pulling from the repo.
+Alternatively, one may locally, step-by-step (see below), install and setup NGC-Learn from source after pulling from the repo.
-Note that installing the official pip package without any form of JAX installed
-on your system will default to downloading the CPU version of ngc-learn (see
-below for installing the GPU version).
+Note that installing the official pip package without any form of JAX installed on your system will default to downloading the CPU version of NGC-Learn (see below for installing the GPU version).
## Install from Source
-0. Install ngc-sim-lib first (as an editable install); visit the repo
-https://github.com/NACLab/ngc-sim-lib for details.
+1. Install NGC-Sim-Lib first (as an editable install); visit the repo https://github.com/NACLab/ngc-sim-lib for details.
-1. Clone the ngc-learn repository:
+2. Clone the NGC-Learn repository:
```console
$ git clone https://github.com/NACLab/ngc-learn.git
$ cd ngc-learn
```
-2. (Optional; only for GPU version) Install JAX for either CUDA 12 , depending
- on your system setup. Follow the
- installation instructions
- on the official JAX page to properly install the CUDA 11 or 12 version.
+3. (Optional; only for GPU version) Install JAX for CUDA 12, depending on your system setup. Follow the installation instructions on the official JAX page to properly install the CUDA 12 version.
-
-
-3. Install the ngc-learn package via:
+4. Install the NGC-Learn package via:
```console
$ pip install .
```
@@ -68,22 +44,21 @@ or, to install as an editable install for development, run:
$ pip install -e .
```
-If the installation was successful, you should see the following if you test
-it against your Python interpreter, i.e., run the $ python command
-and complete the following sequence of steps as depicted in the screenshot below:
-
+If the installation was successful, you should see the following when you test it against your Python interpreter, i.e., run the `$ python` command and complete the sequence of steps depicted in the console snippet below:
```console
Python 3.11.4 (main, MONTH DAY YEAR, TIME) [GCC XX.X.X] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import ngclearn
>>> ngclearn.__version__
-'2.0.3'
+'3.0.0'
```
+
diff --git a/docs/modeling/neurons.md b/docs/modeling/neurons.md
index 4babf8f7..36f7c2b1 100644
--- a/docs/modeling/neurons.md
+++ b/docs/modeling/neurons.md
@@ -86,6 +86,22 @@ and `dmu` is the first derivative with respect to the mean parameter.
:noindex:
```
+#### Bernoulli Error Cell
+
+This cell is (currently) fixed to be a (factorized) multivariate Bernoulli cell.
+Concretely, this cell implements compartments/mechanics to facilitate Bernoulli
+log likelihood error calculations.
+
+```{eval-rst}
+.. autoclass:: ngclearn.components.BernoulliErrorCell
+ :noindex:
+
+ .. automethod:: advance_state
+ :noindex:
+ .. automethod:: reset
+ :noindex:
+```
+
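+As a rough, purely illustrative sketch (and not this component's actual implementation), the quantity such a cell is concerned with -- the log likelihood of binary data under a factorized multivariate Bernoulli with probability parameter `p` -- can be written as:
+
+```python
+from jax import numpy as jnp
+
+def bernoulli_log_likelihood(x, p, eps=1e-7):
+    ## sum of per-dimension Bernoulli log likelihoods (factorized form)
+    p = jnp.clip(p, eps, 1. - eps)  # clip for numerical stability
+    return jnp.sum(x * jnp.log(p) + (1. - x) * jnp.log(1. - p), axis=-1)
+```
+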
## Spiking Neurons
These neuronal cells exhibit dynamics that involve emission of discrete action
@@ -117,10 +133,42 @@ negative pressure on the membrane potential values at `t`).
:noindex:
```
+### The IF (Integrate-and-Fire) Cell
+
+This cell (the simple "integrator") models dynamics over the voltage `v`. Note that `thr` is used as the membrane potential threshold and no adaptive threshold mechanics are implemented for this cell model.
+(This cell is primarily a faster, convenience formulation that omits the leak element of the LIF.)
+
+```{eval-rst}
+.. autoclass:: ngclearn.components.IFCell
+ :noindex:
+
+ .. automethod:: advance_state
+ :noindex:
+ .. automethod:: reset
+ :noindex:
+```
+
+### The Winner-Take-All (WTAS) Cell
+
+This cell models dynamics over the voltage `v` as a simple instantaneous
+softmax function of the electrical current input, where only a single
+unit -- the one that wins the competition across the group of neuronal units
+within this component -- emits a pulse/spike.
+
+```{eval-rst}
+.. autoclass:: ngclearn.components.WTASCell
+ :noindex:
+
+ .. automethod:: advance_state
+ :noindex:
+ .. automethod:: reset
+ :noindex:
+```
+
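+As a rough, purely illustrative sketch of this idea (and not the `WTASCell`'s actual implementation), the instantaneous winner-take-all dynamics described above can be thought of as follows:
+
+```python
+from jax import numpy as jnp, nn
+
+def wta_step(current):
+    ## instantaneous softmax "voltage" over the injected electrical currents
+    v = nn.softmax(current)
+    ## only the unit that wins the competition emits a pulse/spike
+    s = (v >= jnp.max(v, axis=-1, keepdims=True)).astype(jnp.float32)
+    return v, s
+```
+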
### The LIF (Leaky Integrate-and-Fire) Cell
This cell (the "leaky integrator") models dynamics over the voltage `v`
-and threshold shift `thrTheta` (a homeostatic variable). Note that `thr`
+and threshold shift `thr_theta` (a homeostatic variable). Note that `thr`
is used as a baseline level for the membrane potential threshold while
-`thrTheta` is treated as a form of short-term plasticity (full
-threshold is: `thr + thrTheta(t)`).
+`thr_theta` is treated as a form of short-term plasticity (full
+threshold is: `thr + thr_theta(t)`).
diff --git a/docs/modeling/synapses.md b/docs/modeling/synapses.md
index 470446e9..6b881f0c 100644
--- a/docs/modeling/synapses.md
+++ b/docs/modeling/synapses.md
@@ -1,17 +1,7 @@
# Synapses
-The synapse is a key building block for connecting/wiring together the various
-component cells that one would use for characterizing a biomimetic neural system.
-These particular objects are meant to perform, per simulated time step, a
-specific type of transformation -- such as a linear transform or a
-convolution -- utilizing their underlying synaptic parameters.
-Most times, a synaptic cable will be represented by a set of matrices (or filters)
-that are used to conduct a projection of an input signal (a value presented to a
-pre-synaptic/input compartment) resulting in an output signal (a value that
-appears within one of its post-synaptic compartments). Notably, a synapse component is
-typically associated with a local plasticity rule, e.g., a Hebbian-type
-update, that either is triggered online, i.e., at some or all simulation time
-steps, or by integrating a differential equation, e.g., via eligibility traces.
+The synapse is a key building block for connecting/wiring together the various component cells that one would use for characterizing a biomimetic neural system. These particular objects are meant to perform, per simulated time step, a specific type of transformation -- such as a linear transform or a convolution -- utilizing their underlying synaptic parameters. Most times, a synaptic cable will be represented by a set of matrices (or filters) that are used to conduct a projection of an input signal (a value presented to a pre-synaptic/input compartment) resulting in an output signal (a value that appears within one of its post-synaptic compartments). There are three general groupings of synaptic components in ngc-learn: 1) non-plastic static synapses (these only perform fixed transformations of input signals); 2) non-plastic dynamic synapses (these perform time-varying, input-dependent transformations of input signals); and 3) plastic synapses, whose efficacies undergo long-term evolution (i.e., learning).
+Notably, plastic synapse components are typically associated with a local plasticity rule, e.g., a Hebbian-type update, that either is triggered online, i.e., at some or all simulation time steps, or by integrating a differential equation, e.g., via eligibility traces.
## Non-Plastic Synapse Types
@@ -74,6 +64,20 @@ This (chemical) synapse performs a linear transform of its input signals. Note t
:noindex:
```
+### Double-Exponential Synapse
+
+This (chemical) synapse performs a linear transform of its input signals. Note that this synapse is "dynamic" in the sense that its efficacies are a function of their pre-synaptic inputs; there is no inherent form of long-term plasticity in this base implementation. Synaptic strength values can be viewed as being filtered/smoothened through a double-exponential (i.e., difference of two exponentials) kernel.
+
+```{eval-rst}
+.. autoclass:: ngclearn.components.DoubleExpSynapse
+ :noindex:
+
+ .. automethod:: advance_state
+ :noindex:
+ .. automethod:: reset
+ :noindex:
+```
+
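+As a rough illustrative sketch (with hypothetical time-constant names, not this component's actual parameters), the "difference of two exponentials" kernel alluded to above has the following shape:
+
+```python
+from jax import numpy as jnp
+
+def double_exp_kernel(t, tau_rise=1., tau_decay=5.):
+    ## difference-of-exponentials conductance kernel; assumes tau_rise < tau_decay
+    return jnp.exp(-t / tau_decay) - jnp.exp(-t / tau_rise)
+```
+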
### Alpha Synapse
This (chemical) synapse performs a linear transform of its input signals. Note that this synapse is "dynamic" in the sense that its efficacies are a function of their pre-synaptic inputs; there is no inherent form of long-term plasticity in this base implementation. Synaptic strength values can be viewed as being filtered/smoothened through a kernel that models more realistic rise and fall times of synaptic conductance.
diff --git a/docs/museum/event_stdp_patches.md b/docs/museum/event_stdp_patches.md
new file mode 100644
index 00000000..f0759e27
--- /dev/null
+++ b/docs/museum/event_stdp_patches.md
@@ -0,0 +1,13 @@
+# Event-based Spike-Timing-Dependent Plasticity (Tavanaei et al.; 2018)
+
+In this exhibit, we create, simulate, and visualize the internally acquired receptive fields of the spiking neural
+network (SNN) trained via event-based spike-timing-dependent plasticity (EV-STDP) over image patches. This
+reproduces the SNN model originally proposed in (Tavanaei et al., 2018) [1].
+
+The model code for this exhibit can be found
+[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/evstdp_patches).
+
+
+## References
+[1] Tavanaei, Amirhossein, Timothée Masquelier, and Anthony Maida. "Representation learning using event-based
+STDP." Neural Networks 105 (2018): 294-303.
\ No newline at end of file
diff --git a/docs/museum/harmonium.md b/docs/museum/harmonium.md
new file mode 100644
index 00000000..a197d409
--- /dev/null
+++ b/docs/museum/harmonium.md
@@ -0,0 +1,361 @@
+# Harmoniums and Contrastive Divergence (Hinton; 1999)
+
+In NGC-Learn, it is possible to construct other forms of learning from the very base learning/plasticity components
+already in-built into the base library. Notably, a class of learning and inference systems that adapt through a process
+known as contrastive Hebbian learning (CHL) can be constructed and simulated with ngc-learn.
+
+In this walkthrough, we will design a simple Harmonium, also known as the restricted Boltzmann machine (RBM). We will
+specifically focus on learning its synaptic connections with an algorithmic recipe known as contrastive divergence (CD),
+which can be considered to be a stochastic form of CHL. After going through this exhibit, you will:
+
+1. Learn how to construct an `NGCGraph` that emulates the structure of an RBM and adapt the NGC settling process to
+ calculate approximate synaptic weight gradients in accordance to contrastive divergence.
+2. Simulate fantasized image samples using the block Gibbs sampler implicitly defined by the negative phase graph.
+
+Note that the folders of interest to this walkthrough are:
++ `ngc-museum/exhibits/harmonium/`: this contains the necessary simulation scripts (which can be found
+ [here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/harmonium));
++ `ngc-museum/data/mnist/`: this contains the zipped copy of the MNIST digit image arrays
+
+## On the Harmonium Probabilistic Graphical Model
+
+A harmonium is a generative model implemented as a stochastic, two-layer neural system (a type of probabilistic graphical
+model; PGM) that attempts to learn a probability distribution over sensory input $\mathbf{x}$, i.e., the goal of a
+harmonium is to learn $p(\mathbf{x})$, the underlying probability/likelihood of a given (training) dataset.
+Fundamentally, the approach to estimating $p(\mathbf{x})$ carried out by a harmonium is the optimization
+of an energy function $E(\mathbf{x})$ (a concept motivated by statistical mechanics), where the system searches for an
+internal configuration, i.e., the values of its synapses, that assigns low energy (values) to sample patterns that come
+from the true data distribution $p(\mathbf{x})$ and high energy (values) to patterns that do not (or those that do not
+come from the training dataset).
+
+```{eval-rst}
+.. table::
+ :align: center
+
+ +-----------------------------------------------------------------+
+ | .. image:: ../images/museum/harmonium/rbm_arch.jpg |
+ | :scale: 65% |
+ | :align: center |
+ +-----------------------------------------------------------------+
+```
+
+The most common, simplest harmonium is one where input nodes (one per dimension of the data observation space) are
+modeled as binary/Boolean sensors -- or "visible units" $\mathbf{z}^0$ (observed variables) that are clamped to actual
+data patterns -- connected to a layer of (stochastic) binary latent feature detectors -- or "hidden units"
+$\mathbf{z}^1$ (unobserved or latent variables). Notably, the synaptic connections between the latent and visible units
+are symmetric. Furthermore, as a result of a key restriction imposed on the harmonium's network structure, i.e., no
+lateral connections between the neurons/units within $\mathbf{z}^0$ as well as those within $\mathbf{z}^1$, computing
+the latent and visible states is as straightforward as the following:
+
+$$
+\begin{aligned}
+p(\mathbf{z}^1 | \mathbf{z}^0) &= \text{sigmoid}(\mathbf{W} \cdot \mathbf{z}^0 + \mathbf{b}),
+\; \mathbf{z}^1 \sim p(\mathbf{z}^1 | \mathbf{z}^0) \\
+p(\mathbf{z}^0 | \mathbf{z}^1) &= \text{sigmoid}(\mathbf{W}^T \cdot \mathbf{z}^1 + \mathbf{c}),
+\; \mathbf{z}^0 \sim p(\mathbf{z}^0 | \mathbf{z}^1)
+\end{aligned}
+$$
+
+where $\mathbf{c}$ is the visible bias vector, $\mathbf{b}$ is the latent bias vector,
+and $\mathbf{W}$ is the synaptic weight matrix that connects $\mathbf{z}^0$ to $\mathbf{z}^1$ (and its transpose
+$\mathbf{W}^T$ is used to make predictions of the input itself). Note that $\cdot$ means matrix/vector multiplication
+and $\sim$ denotes that we would sample from a probability (vector). In the above harmonium's case, samples will be
+drawn treating the conditionals such as $p(\mathbf{z}^1 | \mathbf{z}^0)$ as multivariate Bernoulli distributions.
+$\mathbf{z}^0$ would typically be clamped/set to the actual sensory input data $\mathbf{x}$.
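+
+As a minimal sketch of the two conditionals above (written in plain JAX, outside of NGC-Learn's component system, and assuming a weight matrix `W` of shape `(obs_dim, hid_dim)` as in the exhibit code further below), sampling could look like:
+
+```python
+from jax import numpy as jnp, random, nn
+
+def sample_hiddens(key, z0, W, b):
+    ## p(z1 | z0) = sigmoid(z0 . W + b); then draw a Bernoulli sample
+    p_z1 = nn.sigmoid(jnp.matmul(z0, W) + b)
+    return p_z1, random.bernoulli(key, p_z1).astype(jnp.float32)
+
+def sample_visibles(key, z1, W, c):
+    ## p(z0 | z1) = sigmoid(z1 . W^T + c); then draw a Bernoulli sample
+    p_z0 = nn.sigmoid(jnp.matmul(z1, W.T) + c)
+    return p_z0, random.bernoulli(key, p_z0).astype(jnp.float32)
+```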
+
+The energy function of the harmonium's joint configuration $(\mathbf{z}^0,\mathbf{z}^1)$ (similar to that of a Hopfield
+network) is specified as follows:
+
+$$
+E(\mathbf{z}^0,\mathbf{z}^1) = -\sum_i \mathbf{c}_i \mathbf{z}^0_i -
+\sum_j \mathbf{b}_j \mathbf{z}^1_j - \sum_i \sum_j \mathbf{z}^0_i \mathbf{W}_{ij} \mathbf{z}^1_j .
+$$
+
+Notice that, in the equation above, we sum over vector dimension indices, e.g., $\mathbf{z}^0_i$ retrieves the $i$th
+scalar element of (vector) $\mathbf{z}^0$ while $\mathbf{W}_{ij}$ retrieves the scalar element at position $(i,j)$
+within matrix $\mathbf{W}$. With this energy function, one can write out the probability that a harmonium PGM assigns
+to a data point as:
+
+$$
+p(\mathbf{z}^0 = \mathbf{x}) = \frac{1}{Z} \exp( -E(\mathbf{z}^0,\mathbf{z}^1) )
+$$
+
+where $Z$ is the normalizing constant (or, in statistical mechanics, the partition function) needed to obtain
+proper probability values[^1].
+When one works through the derivation of the gradient of the log probability $\log p(\mathbf{x})$ with respect to the
+synapses such as $\mathbf{W}$, one obtains a (contrastive) Hebbian-like update rule as follows:
+
+$$
+\Delta \mathbf{W}_{ij} = <\mathbf{z}^0_i \mathbf{z}^1_j>_{data} - <\mathbf{z}^0_i \mathbf{z}^1_j>_{model}
+$$
+
+where the angle brackets $< >$ tell us that we need to take the expectation of the values within the brackets under a
+certain distribution (such as the data distribution denoted by the subscript $data$). The above rule can also be
+considered to be a stochastic form of a general recipe known as contrastive Hebbian learning (CHL) [4].
+
+Technically, to compute the update above, obtaining the first term
+$<\mathbf{z}^0_i \mathbf{z}^1_j>_{data}$ is easy since we only need to take the product of a data point and its
+corresponding latent state under the harmonium. However, obtaining the second term
+$<\mathbf{z}^0_i \mathbf{z}^1_j>_{model}$ is very costly, since we would need to
+initialize the value of $\mathbf{z}^0$ to a random initial state and then run a (block) Gibbs sampler for many
+iterations to accurately approximate the second term. Fortunately, it was shown in work such as [3] that learning a
+harmonium is still possible by replacing the term $<\mathbf{z}^0_i \mathbf{z}^1_j>_{model}$ with
+$<\mathbf{z}^0_i \mathbf{z}^1_j>_{recon}$, which is simply computed by using the first term's latent state
+$\mathbf{z}^1$ to reconstruct the input and then using this reconstruction once more in order to obtain its
+corresponding binary latent state. This is known as "contrastive divergence" (CD-1), and, although this approximation
+has been shown to not actually follow the gradient of any known objective function, it works well in practice when
+learning a harmonium-based generative model. Finally, the vectorized form of the CD-1 update is:
+
+$$
+\Delta \mathbf{W} = \Big[ (\mathbf{z}^0_{pos})^T \cdot \mathbf{z}^1_{pos} \Big] - \Big[ (\mathbf{z}^0_{neg})^T \cdot \mathbf{z}^1_{neg} \Big]
+$$
+
+where the first term (in brackets) is labeled as the "positive phase" (or the positive, data-dependent statistics --
+where $\mathbf{z}^0_{pos}$ denotes the positive phase sample of $\mathbf{z}^0$) while the second term is labeled as the
+"negative phase" (or the negative, data-independent statistics -- where $\mathbf{z}^0_{neg}$ denotes the negative phase
+sample of $\mathbf{z}^0$). Note that simpler rules of a similar form can be worked out for the latent/visible bias
+vectors as well.
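+
+As a minimal, batch-level sketch of the vectorized CD-1 statistics above (the bias rules below follow the "simpler rules of a similar form" just mentioned; the Model Museum's `Harmonium` class houses the actual implementation):
+
+```python
+from jax import numpy as jnp
+
+def cd1_updates(z0_pos, z1_pos, z0_neg, z1_neg):
+    ## positive-phase minus negative-phase outer-product statistics
+    dW = jnp.matmul(z0_pos.T, z1_pos) - jnp.matmul(z0_neg.T, z1_neg)
+    db = jnp.sum(z1_pos - z1_neg, axis=0, keepdims=True)  # latent bias update
+    dc = jnp.sum(z0_pos - z0_neg, axis=0, keepdims=True)  # visible bias update
+    return dW, db, dc
+```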
+
+In NGC-Learn, to simulate the above harmonium PGM and its CD-1 update, we will model the positive and negative phases
+as simulated co-models, each responsible for producing the relevant statistics that we will require in order to adjust
+synapses. Additionally, we will find that we can further re-purpose the created co-models to construct a block Gibbs
+sampler for confabulating "fantasized"
+data patterns from a harmonium that has been fit to data.
+
+
+## Boltzmann Machines: Positive and Negative Co-Models
+
+We begin by first specifying the structure of the harmonium system that we would like to simulate. In NGC shorthand,
+the above positive and negative phase graphs would simply be (under one complete generative model):
+
+```
+z0 -(z0-z1)-> z1
+z1 -(z1-z0)-> z0
+Note: z1-z0 = (z0-z1)^T (transpose-tied synapses)
+```
+
+In order to construct the desired harmonium, particularly the structure needed to implement CD-1, we will need to break
+up the model into its key "phases", i.e., a positive phase and a negative phase. We will model each phase as its own
+simulated nodes-and-cables structure within one single model context, allowing us to craft a general approach that
+permits a CD-based learning. Notably, we will use the negative-phase co-model to emulate the crucial MCMC sampling step
+to synthesize data from the trained RBM.
+
+Building the positive phase of our harmonium can be done as follows:
+
+```python
+with Context("Circuit") as self.circuit:
+ ## set up positive-phase graph
+ self.z0 = BernoulliStochasticCell("z0", n_units=obs_dim, is_stoch=False)
+ self.z1 = BernoulliStochasticCell("z1", n_units=hid_dim, key=subkeys[0])
+
+ self.W1 = HebbianSynapse(
+ "W1", shape=(obs_dim, hid_dim), eta=0., weight_init=dist.gaussian(mean=0., std=sigma),
+ bias_init=dist.constant(value=0.), w_bound=0., optim_type="sgd", sign_value=1., key=subkeys[1]
+ )
+    ## wire up z0 to z1 via synaptic projection W1
+ self.z0.s >> self.W1.inputs
+ self.W1.outputs >> self.z1.inputs
+```
+
+To gather the rest of the statistics that we require, we will need to build the negative phase of our model (which is
+responsible for "dreaming up" or "confabulating" data samples from its internal model of the world). Constructing the
+negative-phase co-model, under the same model `Context` above can be done as follows:
+
+```python
+ ## set up negative-phase graph
+ self.z0neg = BernoulliStochasticCell("z0neg", n_units=obs_dim, key=subkeys[3])
+ self.z1neg = BernoulliStochasticCell("z1neg", n_units=hid_dim, key=subkeys[4])
+
+ self.E1 = DenseSynapse( ## E1 = W1.T
+ "E1", shape=(hid_dim, obs_dim), weight_init=dist.gaussian(mean=0., std=sigma),
+ bias_init=dist.constant(value=0.), resist_scale=1., key=subkeys[2]
+ )
+ self.E1.weights.set(self.W1.weights.get().T)
+ self.V1 = HebbianSynapse( ## V1 = W1
+ "V1", shape=(obs_dim, hid_dim), eta=0., weight_init=dist.gaussian(mean=0., std=sigma),
+ bias_init=None, w_bound=0., optim_type="sgd", sign_value=1., key=subkeys[1]
+ )
+ self.V1.weights.set(self.W1.weights.get())
+ self.V1.biases.set(self.W1.biases.get())
+
+ ## wire up z1 to z0(neg) via E1=(W1)^T, and z0(neg) to z1(neg) via V1=W1
+ self.z1.s >> self.E1.inputs
+ self.E1.outputs >> self.z0neg.inputs
+ self.z0neg.p >> self.V1.inputs ## drive hiddens by probs of visibles
+ self.V1.outputs >> self.z1neg.inputs
+```
+
+The above chunk of code effectively sets up the propagation of information from the latent neurons `z1` back down to
+`z0` (obtaining the negative phase values of `z0`, i.e., `z0neg`) and then the propagation of the reconstructed values
+back up to `z1` one last time (obtaining the negative phase values of `z1`, i.e., `z1neg`).
+
+To build a CHL-based form of plasticity, allowing us to realize the CD-1 learning process, we will then need to wire up a
+set of 2-factor Hebbian rules like so:
+
+```python
+ ## set up contrastive Hebbian learning rule (pos-stats - neg-stats)
+ self.z0.s >> self.W1.pre ## positive-phase pre-synaptic term
+ self.z1.p >> self.W1.post ## positive-phase post-synaptic term
+ self.z0neg.p >> self.V1.pre ## negative-phase pre-synaptic term
+    self.z1neg.p >> self.V1.post ## negative-phase post-synaptic term
+```
+
+The results of these two Hebbian rules are then used in an exhibit-specific function (`_update_via_CHL()`) written in
+the [`Harmonium` class](https://github.com/NACLab/ngc-museum/blob/v3/exhibits/harmonium/harmonium.py).
+While we observe that our "negative phase" co-model allows us to emulate the CD learning recipe[^2], technically, the
+negative phase of a harmonium should be run for a very large number of steps (approaching infinity) in order to obtain a
+proper sample from the PGM's equilibrium/steady state distribution. However, this would be extremely costly to simulate
+and, as early studies [3] observed, often only a few or even a single step of this Markov chain proved to work quite
+well, approximating the contrastive divergence objective (the learning algorithm's namesake) instead of direct
+maximum likelihood.
+
+Note that the full code, containing the snippets above, can be found in the Model Museum `Harmonium` model structure
+class. One could further generalize our CD-1 framework to variations such as "persistent" CD (where, instead of
+running `z1` back down through the `E1` synapses, we inject random noise to sample the harmonium's latent prior),
+or even an algorithm known as parallel tempering, where we would maintain multiple co-models and draw samples from
+all of them to obtain negative-phase statistics.
+
+Finally, within the `Harmonium` class, we have written a routine for drawing samples from the model directly, i.e., we
+implement a block Gibbs sampler in order to synthesize data from the RBM's current set of parameters.
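+
+As a minimal sketch of what such a block Gibbs sampler does (reusing the hypothetical `sample_hiddens`/`sample_visibles` helpers from the earlier sketch; the exhibit's `sample_harmonium.py` script is the actual routine):
+
+```python
+from jax import random
+
+def block_gibbs(key, z0_init, W, b, c, n_steps=200, thin=20):
+    ## alternate between the two conditionals, keeping every `thin`-th fantasy
+    z0, samples = z0_init, []
+    for step in range(n_steps):
+        key, k1, k2 = random.split(key, 3)
+        _, z1 = sample_hiddens(k1, z0, W, b)
+        p_z0, z0 = sample_visibles(k2, z1, W, c)
+        if (step + 1) % thin == 0:
+            samples.append(p_z0)  # collect the fantasized (probability) pattern
+    return samples
+```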
+
+## Using the Harmonium to Dream Up Handwritten Digits
+
+We finally take the harmonium that we have constructed above and fit it to some MNIST digits. Specifically, we will
+leverage the [Harmonium](https://github.com/NACLab/ngc-museum/blob/v3/exhibits/harmonium/harmonium.py) model in the Model Museum since it implements all of the above core mechanisms (and
+more) internally. In the script `sim_harmonium.py`, you will find a general training process that will fit our harmonium to
+the MNIST database (unzip the file `mnist.zip` in the `ngc-museum/exhibits/data/` directory if you have not already)
+by cycling through it several times, saving the final
+(best) resulting model to disk within the `exp/` sub-directory. Go ahead and execute the training process as follows:
+
+```console
+$ python sim_harmonium.py
+```
+
+which will fit/adapt your harmonium to MNIST. This should produce per-training iteration output, printed to I/O,
+similar to the following:
+
+```console
+--- Initial RBM Synaptic Stats ---
+W1: min -0.0494 ; max 0.0445 mu -0.0000 ; norm 4.4734
+b1: min -4.0000 ; max -4.0000 mu -4.0000 ; norm 64.0000
+c0: min -11.6114 ; max 0.0635 mu -3.8398 ; norm 135.2238
+-1| Test: err(X) = 54.3889
+0| Test: |d.E(X)| = 16.8070 err(X) = 46.8236; Train: err(X) = 52.7418
+1| Test: |d.E(X)| = 27.1183 err(X) = 36.8690; Train: err(X) = 41.3630
+2| Test: |d.E(X)| = 13.7855 err(X) = 31.8582; Train: err(X) = 34.5511
+3| Test: |d.E(X)| = 9.0927 err(X) = 28.6253; Train: err(X) = 30.4615
+4| Test: |d.E(X)| = 5.8375 err(X) = 26.2317; Train: err(X) = 27.6882
+5| Test: |d.E(X)| = 5.3187 err(X) = 24.3207; Train: err(X) = 25.5485
+6| Test: |d.E(X)| = 3.7614 err(X) = 22.8012; Train: err(X) = 23.8361
+7| Test: |d.E(X)| = 2.2589 err(X) = 21.6163; Train: err(X) = 22.4523
+8| Test: |d.E(X)| = 3.2040 err(X) = 20.5934; Train: err(X) = 21.3355
+9| Test: |d.E(X)| = 2.4215 err(X) = 19.7679; Train: err(X) = 20.4297
+10| Test: |d.E(X)| = 1.5725 err(X) = 19.0672; Train: err(X) = 19.6835
+11| Test: |d.E(X)| = 0.5418 err(X) = 18.4881; Train: err(X) = 19.0372
+...
+
+...
+91| Test: |d.E(X)| = 0.4870 err(X) = 11.0443; Train: err(X) = 10.9832
+92| Test: |d.E(X)| = 0.0390 err(X) = 11.0118; Train: err(X) = 10.9820
+93| Test: |d.E(X)| = 0.5127 err(X) = 11.0013; Train: err(X) = 10.9586
+94| Test: |d.E(X)| = 1.9180 err(X) = 10.9874; Train: err(X) = 10.9312
+95| Test: |d.E(X)| = 0.0258 err(X) = 10.9906; Train: err(X) = 10.9274
+96| Test: |d.E(X)| = 0.4760 err(X) = 10.9712; Train: err(X) = 10.8940
+97| Test: |d.E(X)| = 0.6038 err(X) = 10.9589; Train: err(X) = 10.8960
+98| Test: |d.E(X)| = 0.2870 err(X) = 10.9563; Train: err(X) = 10.8727
+99| Test: |d.E(X)| = 1.6622 err(X) = 10.9347; Train: err(X) = 10.8671
+--- Final RBM Synaptic Stats ---
+W1: min -1.8648 ; max 1.3757 mu -0.0012 ; norm 70.6230
+b1: min -7.5815 ; max 0.2337 mu -2.3395 ; norm 53.3993
+c0: min -11.6316 ; max -2.4227 mu -5.3259 ; norm 161.5646
+```
+
+You will find, after the training script has finished executing, several outputs in the `exp/filters/` model
+sub-directory that is created for you. Concretely, you will find a grid-plot of the (first `100` of the) harmonium's
+acquired filters (or "receptive fields"), much as we did for the sparse coding exhibit, that will look similar to
+the following:
+
+```{eval-rst}
+.. image:: ../images/museum/harmonium/receptive_fields.jpg
+   :align: center
+```
+Interestingly enough, we see that our harmonium/RBM has extracted what appears to be rough stroke features, which is
+what it uses when sampling its binary latent feature detectors to compose final synthesized image patterns (each
+binary feature detector serves as a Boolean function that emits a decision of `1` if the filter is to be used and a `0`
+if not). In particular, we note that the filters that our harmonium has acquired are a bit more prominent due
+to the fact that our exhibit employs some weight decay (specifically, Gaussian/L2 decay -- with intensity
+`l2_lambda=0.01` -- applied to the `W1` synaptic matrix of our RBM).
+Weight decay of this form is particularly useful to not only mitigate the harmonium's overfitting to its training
+data but also to ensure that the Markov chain inherent to its negative-phase mixes more effectively [5] (which ensures
+better-quality samples from the block Gibbs sampler, which we will use next).
+
+Finally, you will also find in the `exp/filters/` model sub-folder another grid-plot containing some (about `100`) of
+the RBM's reconstructions of held-out development data. This plot should look similar to the one below:
+
+```{eval-rst}
+.. image:: ../images/museum/harmonium/rbm_recon.jpg
+   :align: center
+```
+
+### Sampling the Harmonium
+
+Once the training process has completed, you can then run the following to sample from the trained model using block Gibbs
+sampling:
+
+```console
+$ python sample_harmonium.py
+```
+
+which will take your trained harmonium's negative-phase co-model and use it to synthesize some digit patterns. You
+should see inside the `exp/samples/` sub-directory three sample-image grids (i.e., `samples_0.jpg`, `samples_1.jpg`,
+and `samples_2.jpg`) similar to what is shown below:
+
+```{eval-rst}
+.. image:: ../images/museum/harmonium/samples_0.jpg
+ :width: 30%
+.. image:: ../images/museum/harmonium/samples_1.jpg
+ :width: 30%
+.. image:: ../images/museum/harmonium/samples_2.jpg
+ :width: 30%
+```
+
+Furthermore, you will see three corresponding GIFs that have been generated for you that visualize how each of the
+three simulated sampling Markov chains change with time (i.e., these are the files: `markov_chain_0.gif`,
+`markov_chain_1.gif`, and `markov_chain_2.gif`).
+
+
+
+It is important to understand that the three grids of samples shown above come from particular points in the block
+Gibbs sampling process.
+(Note that one reads these sample grid plots left-column to right-column, and top-row to bottom-row; this way of
+reading the plot follows the ordering of samples extracted from the specific Markov chain sequence.)
+Note that, although each chain is run for many total steps, the `sample_harmonium.py` script "thins" out each Markov
+chain by only pulling out a fantasized pattern every `20` steps (further "burning" in each chain before collecting
+samples). Each chain is merely initialized with random Bernoulli noise. Note that higher-quality samples can be
+obtained if one modifies the earlier harmonium to learn with persistent CD or parallel tempering.
+
+### Final Notes
+
+The harmonium that we have built in this exhibit is a classical Bernoulli harmonium/RBM, which is a neural PGM that
+assumes that the input data features are binary in nature. If one wants to model data that is continuous/real-valued,
+then the harmonium model above would need to be extended to utilize visible units that follow a continuous
+distribution; for instance, if one modeled a multivariate Gaussian distribution, this would yield a Gaussian restricted
+Boltzmann machine (GRBM).
+
+
+## References
+[1] Smolensky, P. "Information Processing in Dynamical Systems: Foundations of Harmony Theory" (Chapter 6). Parallel
+distributed processing: explorations in the microstructure of cognition 1 (1986).
+[2] Hinton, Geoffrey. Products of Experts. International conference on artificial neural networks (1999).
+[3] Hinton, Geoffrey E. "Training products of experts by maximizing contrastive likelihood." Technical Report, Gatsby
+computational neuroscience unit (1999).
+[4] Movellan, Javier R. "Contrastive Hebbian learning in the continuous Hopfield model." Connectionist models. Morgan
+Kaufmann, 1991. 10-17.
+[5] Hinton, Geoffrey E. "A practical guide to training restricted Boltzmann machines." Neural networks: Tricks of the
+trade. Springer, Berlin, Heidelberg, 2012. 599-619.
+
+
+[^1]: In fact, it is intractable to compute the partition function $Z$ for any reasonably-sized harmonium; fortunately,
+we will not need to calculate $Z$ in order to learn and sample from a Harmonium.
+[^2]: In general, CD-1 means contrastive divergence where the negative phase is only run for one single step, i.e.,
+`K=1`. The more general form of CD is known as CD-K, the K-step CD algorithm where `K > 1`. (Sometimes, CD-1 is
+referred to as just "CD".)
diff --git a/docs/museum/index.rst b/docs/museum/index.rst
index 25f75dfc..a0b58557 100644
--- a/docs/museum/index.rst
+++ b/docs/museum/index.rst
@@ -5,17 +5,26 @@
Model Exhibits
==============
-Models are presented in ngc-learn's model museum in the form of "exhibits",
-which are effectively model-specific walkthroughs and analyses, based on the
-relevant, referenced publicly available ngc-learn simulation code.
+Models are presented in ngc-learn's model museum in the form of "exhibits", which are effectively model-specific
+walkthroughs and analyses, based on the relevant, referenced publicly available ngc-learn simulation code. (Note that
+there are more model exhibits in the actual `museum repository `_ than the number
+of detailed walkthroughs presented in the table of contents below.)
.. toctree::
:maxdepth: 1
- :caption: Neuromimetic Models
+ :caption: Neuroscience Models
- pcn_discrim
sparse_coding
+ pc_rao_ballard1999
snn_dc
+ event_stdp_patches
+ rl_snn
+
+.. toctree::
+ :maxdepth: 1
+ :caption: NeuroAI / Neuro-mimetic Models
+
+ pcn_discrim
snn_bfa
+ harmonium
sindy
- rl_snn
diff --git a/docs/museum/model_museum.md b/docs/museum/model_museum.md
index bd154524..c03cb2d8 100644
--- a/docs/museum/model_museum.md
+++ b/docs/museum/model_museum.md
@@ -1,20 +1,9 @@
# The Model Museum
-There is an ever-growing galaxy of neurobiological models and credit assignment
-processes [1, 2]. One of ngc-learn's aims, in the spirit of scientific
-reproducibility, is to capture a snapshot of as many of these
-biomimetic/neuro-mimetic models as possible, in the form of a digital "museum".
-This museum is further designed with the notion of exhibits and exhibitors,
-aiding to facilitate credit assignment and respectful citation to the ideas and
-the work of those that have helped to lay the foundations for the progress
-observed today. Recently, we have separated out the model museum into its own
-particular maintained repository called
-[ngc-museum](https://github.com/NACLab/ngc-museum), where you can find and
-access/run historical models and agents built with ngc-learn to perform
-different experimental tasks.
+There is an ever-growing galaxy of neurobiological models and credit assignment processes [1, 2]. One of ngc-learn's aims, in the spirit of scientific reproducibility, is to capture a snapshot of as many of these biomimetic/neuro-mimetic models as possible, in the form of a digital "museum".
+This museum is further designed with the notion of exhibits and exhibitors, aiding to facilitate credit assignment and respectful citation to the ideas and the work of those that have helped to lay the foundations for the progress observed today. Recently, we have separated out the model museum into its own particular maintained repository called [ngc-museum](https://github.com/NACLab/ngc-museum), where you can find and access/run historical models and agents built with ngc-learn to perform different experimental tasks.
-Please refer to the [table of contents](../museum/index.rst) for walkthroughs on
-using and running various historical models in the museum.
+Please refer to the [table of contents](../museum/index.rst) for walkthroughs and guidance on using and running various historical model exhibits in the museum.
## References
[1] Ororbia, Alexander G. "Brain-inspired machine intelligence: A survey
diff --git a/docs/museum/pc_rao_ballard1999.md b/docs/museum/pc_rao_ballard1999.md
new file mode 100644
index 00000000..cfa6b921
--- /dev/null
+++ b/docs/museum/pc_rao_ballard1999.md
@@ -0,0 +1,12 @@
+# Hierarchical Predictive Coding (Rao & Ballard; 1999)
+
+In this exhibit, we create, simulate, and visualize the internally acquired receptive fields of the predictive coding
+model originally proposed in (Rao & Ballard, 1999) [1].
+
+The model code for this exhibit can be found
+[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/pc_recon).
+
+
+## References
+[1] Rao, Rajesh PN, and Dana H. Ballard. "Predictive coding in the visual cortex: a functional interpretation of
+some extra-classical receptive-field effects." Nature neuroscience 2.1 (1999): 79-87.
\ No newline at end of file
diff --git a/docs/museum/pcn_discrim.md b/docs/museum/pcn_discrim.md
index 85cc3756..66d9642b 100644
--- a/docs/museum/pcn_discrim.md
+++ b/docs/museum/pcn_discrim.md
@@ -1,4 +1,4 @@
-# Discriminative Predictive Coding
+# Discriminative Predictive Coding (Whittington & Bogacz; 2017)
In this exhibit, we will see how a classifier can be created based on
predictive coding. This exhibit model effectively reproduces some of the results
diff --git a/docs/museum/rl_snn.md b/docs/museum/rl_snn.md
index 24d11412..6e1c2876 100644
--- a/docs/museum/rl_snn.md
+++ b/docs/museum/rl_snn.md
@@ -1,13 +1,10 @@
-# Reinforcement Learning through a Spiking Controller
+# Reinforcement Learning through a Spiking Controller (Chevtchenko et al.; 2020)
-In this exhibit, we will see how to construct a simple biophysical model for
-reinforcement learning with a spiking neural network and modulated
-spike-timing-dependent plasticity.
-This model incorporates a mechanisms from several different models, including
-the constrained RL-centric SNN of [1] as well as the simplifications
-made with respect to the model of [2]. The model code for this
-exhibit can be found
-[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/rl_snn).
+In this exhibit, we will see how to construct a simple biophysical model for reinforcement learning with a spiking
+neural network and modulated spike-timing-dependent plasticity.
+This model incorporates mechanisms from several different models, including the constrained RL-centric SNN of
+[1] as well as some simplifications of the structures used within the SNN of [2]. The model code for this
+exhibit can be found [here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/rl_snn).
## Modeling Operant Conditioning through Modulation
@@ -123,10 +120,7 @@ RL-SNN model:
## References
-[1] Chevtchenko, Sérgio F., and Teresa B. Ludermir. "Learning from sparse
-and delayed rewards with a multilayer spiking neural network." 2020 International
-Joint Conference on Neural Networks (IJCNN). IEEE, 2020.
-[2] Diehl, Peter U., and Matthew Cook. "Unsupervised learning of digit
-recognition using spike-timing-dependent plasticity." Frontiers in computational
-neuroscience 9 (2015): 99.
-
+[1] Chevtchenko, Sérgio F., and Teresa B. Ludermir. "Learning from sparse and delayed rewards with a multilayer
+spiking neural network." 2020 International Joint Conference on Neural Networks (IJCNN). IEEE, 2020.
+[2] Diehl, Peter U., and Matthew Cook. "Unsupervised learning of digit recognition using spike-timing-dependent
+plasticity." Frontiers in computational neuroscience 9 (2015): 99.
diff --git a/docs/museum/sindy.md b/docs/museum/sindy.md
index 04426d70..9245382d 100644
--- a/docs/museum/sindy.md
+++ b/docs/museum/sindy.md
@@ -1,14 +1,4 @@
-
-
-# Sparse Identification of Non-linear Dynamical Systems (SINDy)
+# Sparse Identification of Non-linear Dynamical Systems (SINDy; Brunton et al.; 2016)
In this section, we will study, create, simulate, and visualize a model known as the sparse identification of non-linear dynamical systems (SINDy) [1], implementing it in NGC-Learn and JAX. After going through this demonstration, you will:
diff --git a/docs/museum/snn_bfa.md b/docs/museum/snn_bfa.md
index 3e62dec1..2a83795b 100644
--- a/docs/museum/snn_bfa.md
+++ b/docs/museum/snn_bfa.md
@@ -1,8 +1,8 @@
-# Spiking Neural Networks: Learning with Broadcast Feedback Alignment
+# Spiking Neural Networks: Learning with Broadcast Feedback Alignment (Samadi et al.; 2017)
In this exhibit, we will see how one can train a spiking neural network model
using surrogate functions and a credit assignment scheme called broadcast
-feedback alignment (BFA) [1].
+feedback alignment (BFA) [1].
This exhibit model effectively reproduces some of the results
reported (Samadi et al., 2017) [1]. The model code for this
exhibit can be found
diff --git a/docs/museum/snn_dc.md b/docs/museum/snn_dc.md
index c42c1a45..4bb3b32a 100755
--- a/docs/museum/snn_dc.md
+++ b/docs/museum/snn_dc.md
@@ -1,4 +1,4 @@
-# The Diehl and Cook Spiking Neuronal Network
+# The Diehl and Cook Spiking Neuronal Network (Diehl & Cook; 2015)
In this exhibit, we will see how a spiking neural network model that adapts
its synaptic efficacies via spike-timing-dependent plasticity can be created.
@@ -313,24 +313,8 @@ neuroscience 9 (2015): 99.
[^1]: Note that the `LIFCell` is not the same as ngc-learn's
-[sLIFCell](ngclearn.components.neurons.spiking.sLIFCell), which is a particular
-cell that simplifies the spiking dynamics greatly and is not meant to operate
-in the negative milliVolt range like the `LIFCell` does.
-[^2]: While both forms of modeling electrical current are easily doable in
- ngc-learn, the `DC_SNN` exhibit model opts for the second approach for simplicity
- and additional simulation speed.
-[^3]: Trace components have also been used in the `DC_SNN` exhibit model, specifically
-those built with the [variable trace](ngclearn.components.other.varTrace) component.
-Note that the variable trace effectively applies a low-pass filter iteratively
-to the spikes produced by a spike train.
-[^4]: In the NAC group's
-experience, observing the mean and Frobenius norm of synaptic values can be a
-useful starting point for determining unhealthy behavior or some degenerate cases
-in the context of spiking neural network credit assignment.
-[^5]: To load in the exact synaptic efficacies we obtained to get the images
-above, you can unzip the folder `dcsnn_syn.zip`, which contains all of the
-model's numpy array values, and simply copy all of the compressed numpy arrays
-into your `exp/snn_stdp/custom/` folder, which is where ngc-learn/ngc-sim-lib
-look for pre-trained value arrays when loading in a previously constructed model.
-Once you do this, running `analyze_dcsnn.py` with the same arguments as above
-should produce plots/images much like those in this walkthrough.
+[sLIFCell](ngclearn.components.neurons.spiking.sLIFCell), which is a particular cell that simplifies the spiking dynamics greatly and is not meant to operate in the negative milliVolt range like the `LIFCell` does.
+[^2]: While both forms of modeling electrical current are easily doable in NGC-Learn, the `DC_SNN` exhibit model opts for the second approach for simplicity and additional simulation speed.
+[^3]: Trace components have also been used in the `DC_SNN` exhibit model, specifically those built with the [variable trace](ngclearn.components.other.varTrace) component. Note that the variable trace effectively applies a low-pass filter iteratively to the spikes produced by a spike train.
+[^4]: In the NAC group's experience, observing the mean and Frobenius norm of synaptic values can be a useful starting point for determining unhealthy behavior or some degenerate cases in the context of spiking neural network credit assignment.
+[^5]: To load in the exact synaptic efficacies we obtained to get the images above, you can unzip the folder `dcsnn_syn.zip`, which contains all of the model's numpy array values, and simply copy all of the compressed numpy arrays into your `exp/snn_stdp/custom/` folder, which is where ngc-learn/ngc-sim-lib look for pre-trained value arrays when loading in a previously constructed model. Once you do this, running `analyze_dcsnn.py` with the same arguments as above should produce plots/images much like those in this walkthrough.
diff --git a/docs/museum/sparse_coding.md b/docs/museum/sparse_coding.md
index d36dbf50..3802a713 100755
--- a/docs/museum/sparse_coding.md
+++ b/docs/museum/sparse_coding.md
@@ -1,86 +1,59 @@
-# Sparse Coding and Iterative Thresholding
+# Sparse Coding and Iterative Thresholding (Olshausen & Field; 1996)
-In this exhibit, we create, simulate, and visualize the
-internally acquired filters/atoms of variants of a sparse coding system based
-on the classical model proposed by (Olshausen & Field, 1996) [1].
+In this exhibit, we create, simulate, and visualize the internally acquired filters/atoms of variants of a sparse coding system based on the classical model proposed by (Olshausen & Field, 1996) [1].
After going through this demonstration, you will:
-1. Learn how to build a 2-layer sparse coding model of natural image patterns,
-using the original dataset used in [1].
-2. Visualize the acquired filters of the learned dictionary models and examine
-the results of imposing a kurtotic prior as well as a thresholding function
-over latent codes.
+1. Learn how to build a 2-layer sparse coding model of natural image patterns, using the original dataset used in [1].
+2. Visualize the acquired filters of the learned dictionary models and examine the results of imposing a kurtotic prior as well as a thresholding function over latent codes.
-The model code for this
-exhibit can be found
-[here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/olshausen_sc).
+The model code for this exhibit can be found [here](https://github.com/NACLab/ngc-museum/tree/main/exhibits/olshausen_sc).
-Note: You will need to unzip the data arrays in `exhibits/data/natural_scenes.zip`
-to the folder `exhibits/data/` to work through this exhibit.
+Note: You will need to unzip the data arrays in `exhibits/data/natural_scenes.zip` to the folder `exhibits/data/` to work through this exhibit.
## On Dictionary Learning
-Dictionary learning poses a very interesting question for statistical learning:
-can we extract "feature detectors" from a given database (or collection of patterns)
-such that only a few of these detectors play a role in reconstructing any given,
-original pattern/data point?
-The aim of dictionary learning is to acquire or learn a matrix, also called the
-"dictionary", which is meant to contain "atoms" or basic elements inside this dictionary
-(such as simple fundamental features such as the basic strokes/curves/edges
-that compose handwritten digits or characters). Several atoms (or rows of this
-matrix) inside the dictionary can then be linearly combined to reconstruct a
-given input signal or pattern. A sparse dictionary model is able to reconstruct
-input patterns with as few of these atoms as possible. Typical sparse dictionary
-or coding models work with an over-complete spanning set, or, in other words,
-a latent dimensionality (which one could think of as the number of neurons
-in a single latent state node of an ngc-learn system) that is greater than the
-dimensionality of the input itself.
-
-From a neurobiological standpoint, sparse coding emulates a fundamental property
-of neural populations -- the activities among a neural population are sparse where,
-within a period of time, the number of total active neurons (those that are firing)
-is smaller than the total number of neurons in the population itself. When sensory
-inputs are encoded within this population, different subsets (which might overlap) of
-neurons activate to represent different inputs (one way to view this is that they
-"fight" or compete for the right to activate in response to different stimuli).
-Classically, it was shown in [1] that a sparse coding model trained on natural
-image patches learned within its dictionary non-orthogonal filters that resembled
-receptive fields of simple-cells (found in the visual cortex).
+Dictionary learning poses a very interesting question for statistical learning: can we extract "feature detectors" from a given database (or collection of patterns) such that only a few of these detectors play a role in reconstructing any given, original pattern/data point?
+The aim of dictionary learning is to acquire or learn a matrix, also called the "dictionary", which is meant to contain "atoms" or basic elements inside this dictionary (such as simple fundamental features such as the basic strokes/curves/edges that compose handwritten digits or characters). Several atoms (or rows of this matrix) inside the dictionary can then be linearly combined to reconstruct a given input signal or pattern. A sparse dictionary model is able to reconstruct input patterns with as few of these atoms as possible. Typical sparse dictionary or coding models work with an over-complete spanning set, or, in other words, a latent dimensionality (which one could think of as the number of neurons in a single latent state node of an ngc-learn system) that is greater than the dimensionality of the input itself.
+
+From a neurobiological standpoint, sparse coding emulates a fundamental property of neural populations -- the activities among a neural population are sparse where, within a period of time, the number of total active neurons (those that are firing) is smaller than the total number of neurons in the population itself. When sensory inputs are encoded within this population, different subsets (which might overlap) of neurons activate to represent different inputs (one way to view this is that they "fight" or compete for the right to activate in response to different stimuli).
+Classically, it was shown in [1] that a sparse coding model trained on natural image patches learned within its dictionary non-orthogonal filters that resembled receptive fields of simple-cells (found in the visual cortex).
## Constructing a Sparse Coding System
-To build a sparse coding model, we can manually craft a model using ngc-learn's
-nodes-and-cables system. First, we specify the underlying generative model we
-aim to emulate. Formally, we seek to optimize a set of latent codes according
-to the following differential equation:
+To build a sparse coding model, we can manually craft a model using ngc-learn's nodes-and-cables system. First, we specify the underlying generative model we aim to emulate. Formally, we seek to optimize a set of latent codes according to the following differential equation:
$$
\tau_m \frac{\partial \mathbf{z}_t}{\partial t} =
\big(\mathbf{W}^T \cdot \mathbf{e}(t) \big) + \lambda \Omega\big(\mathbf{z}(t)\big)
$$
-where $\tau_m$ is the latent code time constant and the error neurons $\mathbf{e}(t)$
-at the sensory input layer made at time $t$ are specified as:
+where the above is also referred to as the E-step (since the optimization carried out for most sparse coding models is done within the framework of expectation-maximization -- the E-step refers to updates to the latent variables whereas the M-step refers to updates to the synaptic/dictionary parameters). Here, $\tau_m$ is the latent code time constant and the error neurons $\mathbf{e}(t)$ at the sensory input layer at time $t$ are specified as:
$$
\mathbf{e}(t) = -\big(\mathbf{o}_t - (\mathbf{W} \cdot \mathbf{z}(t)) \big)
$$
-where we see that we aim to learn a two-layer generative system that specifically
-imposes a prior distribution `p(z)` over the latent feature detectors (via the
-constraint function $ \Omega\big(\mathbf{z}(t)\big) $ ) that we hope
-to extract in node `z`. Note that this two-layer model (or single latent-variable layer
-model) could either be the linear generative model from [1] or one similar to the
-model learned through ISTA [2] if a (soft) thresholding function is used instead.
+where we see that we aim to learn a two-layer generative system that specifically imposes a prior distribution `p(z)` over the latent feature detectors (via the constraint function $ \Omega\big(\mathbf{z}(t)\big) $ ) that we hope to extract in node `z`. Note that this two-layer model (or single latent-variable layer model) could either be the linear generative model from [1] or one similar to the model learned through ISTA [2] if a (soft) thresholding function is used instead.
+
+Furthermore, the synaptic weight updates (the M-step) to our sparse coding model generally adhere to the following differential equation:
+
+$$
+\tau_m \frac{\partial \mathbf{W}}{\partial t} = -\mathbf{W} + \big(\mathbf{e}(t) \cdot (\mathbf{z}(t))^T \big)
+$$
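+
+To make the E/M split concrete, below is a minimal NumPy sketch of these two updates under explicit Euler integration. It is illustrative only: the error sign convention is simplified relative to the equations above (ngc-learn folds such sign flips into the error cell and the synapse's `sign_value`), a Laplace-style sparsity gradient stands in for the Cauchy prior, and the `-W` decay term of the M-step is omitted.
+
+```python
+import numpy as np
+
+def e_step(o, W, z, lam=0.14, tau_m=20., dt=1.):
+    e = o - z @ W                       # sensory prediction error (sign convention simplified here)
+    dz = (e @ W.T) - lam * np.sign(z)   # Laplace-style sparsity (sub)gradient, standing in for the prior
+    return z + (dt / tau_m) * dz        # one explicit Euler step on the latent code
+
+def m_step(o, W, z, eta=1e-2):
+    e = o - z @ W
+    return W + eta * (z.T @ e)          # 2-factor Hebbian-style dictionary update (decay term omitted)
+
+## toy usage: 10 whitened patches of dimension 256, 100 dictionary atoms
+rng = np.random.default_rng(0)
+o = rng.standard_normal((10, 256))
+W = 0.1 * rng.standard_normal((100, 256))
+z = np.zeros((10, 100))
+for _ in range(300):        # T E-steps of inference...
+    z = e_step(o, W, z)
+W = m_step(o, W, z)         # ...followed by one M-step
+```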
-Constructing the above system for (Olshausen & Field, 1996) is done, much
-like we do in the `SparseCoding` agent constructor in the model museum exhibit
-code, as follows:
+Constructing the above system for (Olshausen & Field, 1996) is done, much like we do in the `SparseCoding` agent constructor in the model museum exhibit code, as follows:
```python
-from ngcsimlib.context import Context
-from ngclearn.components import GaussianErrorCell as ErrorCell, RateCell, HebbianSynapse, StaticSynapse
+from ngclearn.utils.io_utils import makedir
+from ngclearn.utils.viz.synapse_plot import visualize
+from jax import numpy as jnp, random, jit
+from ngclearn import Context, MethodProcess, JointProcess
+from ngclearn.components.neurons.graded.rateCell import RateCell
+from ngclearn.components.synapses.denseSynapse import DenseSynapse
+from ngclearn.components.synapses.hebbian.hebbianSynapse import HebbianSynapse
+from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell as ErrorCell
from ngclearn.utils.model_utils import normalize_matrix
+from ngclearn.utils.distribution_generator import DistributionGenerator as dist
in_dim = ...  # dimension of patch data
hid_dim = ...  # number of atoms in the dictionary matrix
@@ -88,174 +61,127 @@ dt = 1. # ms
T = 300 # ms # (OR) number of E-steps to take during inference
# ---- build a sparse coding linear generative model with a Cauchy prior ----
with Context("Circuit") as circuit:
- z1 = RateCell("z1", n_units=hid_dim, tau_m=20., act_fx="identity",
- prior=("cauchy", 0.14), integration_type="euler")
+ z1 = RateCell(
+ "z1", n_units=hid_dim, tau_m=20, act_fx="identity", prior=("cauchy", 0.14), integration_type="euler"
+ )
e0 = ErrorCell("e0", n_units=in_dim)
- W1 = HebbianSynapse("W1", shape=(hid_dim, in_dim),
- eta=1e-2, wInit=("fan_in_gaussian", 0., 1.),
- bInit=None, w_bound=0., optim_type="sgd", signVal=-1.)
- E1 = StaticSynapse("E1", shape=(in_dim, hid_dim),
- wInit=("uniform", -0.2, 0.2), Rscale=1.)
+ W1 = HebbianSynapse(
+ "W1", shape=(hid_dim, in_dim), eta=1e-2, weight_init=dist.fan_in_gaussian(), bias_init=None, w_bound=0., optim_type="sgd", sign_value=-1.
+ )
+ E1 = DenseSynapse( ## E1 = (W1)^T
+ "E1", shape=(in_dim, hid_dim), weight_init=dist.uniform(-0.2, 0.2),
+ resist_scale=1.
+ )
+ E1.weights.set(W1.weights.get().T)
+
## wire z1.zF to e0.mu via W1
- W1.inputs << z1.zF
- e0.mu << W1.outputs
- ## wire e0.dmu to z1.j
- E1.inputs << e0.dmu
- z1.j << E1.outputs
- ## Setup W1 for its 2-factor Hebbian update
- W1.pre << z1.zF
- W1.post << e0.dmu
-
- reset_cmd, reset_args = circuit.compile_by_key(
- W1, E1, z1, e0,
- compile_key="reset")
- advance_cmd, advance_args = circuit.compile_by_key(
- W1, E1, z1, e0,
- compile_key="advance_state")
- evolve_cmd, evolve_args = circuit.compile_by_key(W1, compile_key="evolve")
+ z1.zF >> W1.inputs
+ W1.outputs >> e0.mu
+ ## wire e0.dmu back up to z1.j via E1 (for E-step)
+ e0.dmu >> E1.inputs
+ E1.outputs >> z1.j
+
+ ## Setup W1 for its 2-factor Hebbian update (for M-step)
+ z1.zF >> W1.pre
+ e0.dmu >> W1.post
+
+ ## Inference process
+ advance = (MethodProcess(name="advance")
+ >> W1.advance_state
+ >> E1.advance_state
+ >> z1.advance_state
+ >> e0.advance_state)
+ ## Reset-to-baseline process
+ reset = (MethodProcess(name="reset")
+ >> W1.reset
+ >> E1.reset
+ >> z1.reset
+ >> e0.reset)
+ ## Learning process
+ evolve = (MethodProcess(name="evolve")
+ >> W1.evolve)
```
-Notice that, in our model `circuit`, we have taken care to set the `.param_axis`
-variable to be equal to `1` -- this will, whenever we call `apply_constraints()`,
-tell the NGC system to normalize the Euclidean norm of the columns
-of the dictionary matrix to be equal to a value of one. This is a particularly
-important constraint to apply to sparse coding models as this prevents the
-trivial solution of simply growing out
-the magnitude of the dictionary synapses to solve the underlying constrained
-optimization problem (and, in general, constraining the rows or
-columns of generative models helps to facilitate a more stable training process).
-This norm constraint is configured in the agent constructor's dynamic
-compile function, specifically in the snippet below:
+There is one important co-routine that we also need to include for our sparse coding `circuit`, one that must run after each update to the synapses -- synaptic weight normalization. Specifically, we want to normalize the Euclidean norm of the columns of the dictionary matrix to be equal to a value of one.
+
+This is a particularly important constraint to apply to sparse coding models as this prevents the trivial solution of simply growing out the magnitude of the dictionary synapses to solve the underlying constrained optimization problem (and, in general, constraining the rows or columns of generative models helps to facilitate a more stable training process). This norm constraint can be simply written as below:
```python
-@Context.dynamicCommand
def norm():
- W1.weights.set(normalize_matrix(W1.weights.value, 1., order=2, axis=1))
+ W1.weights.set(normalize_matrix(W1.weights.get(), 1., order=2, axis=1))
```
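+
+For reference, the effect of this co-routine expressed in plain NumPy (illustrative only; within the exhibit, ngc-learn's `normalize_matrix` performs this for you, and the small epsilon below is just a numerical-safety addition):
+
+```python
+import numpy as np
+
+W = np.random.randn(100, 256)                                       # stand-in dictionary matrix
+W = W / (np.linalg.norm(W, ord=2, axis=1, keepdims=True) + 1e-8)    # rescale each atom to unit Euclidean norm
+```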
-To build the version of our model (the ISTA model) using a thresholding function,
-instead of using a factorial prior over the latents, we can write the following:
+To build the version of our model (the ISTA model) using a thresholding function, instead of using a factorial prior over the latents, we can write the following:
```python
# ---- build a sparse coding generative model w/ a thresholding function ----
with Context("Circuit") as circuit:
- z1 = RateCell("z1", n_units=hid_dim, tau_m=20., act_fx="identity",
- threshold=("soft_threshold", 5e-3), integration_type="euler")
+ z1 = RateCell(
+ "z1", n_units=hid_dim, tau_m=20, act_fx="identity", threshold=("soft_threshold", 5e-3), integration_type="euler"
+ )
e0 = ErrorCell("e0", n_units=in_dim)
- W1 = HebbianSynapse("W1", shape=(hid_dim, in_dim),
- eta=1e-2, wInit=("fan_in_gaussian", 0., 1.),
- bInit=None, w_bound=0., optim_type="sgd", signVal=-1.)
- E1 = StaticSynapse("E1", shape=(in_dim, hid_dim),
- wInit=("uniform", -0.2, 0.2), Rscale=1.)
+ W1 = HebbianSynapse(
+ "W1", shape=(hid_dim, in_dim), eta=1e-2, weight_init=dist.fan_in_gaussian(), bias_init=None, w_bound=0., optim_type="sgd", sign_value=-1.
+ )
+ E1 = DenseSynapse(
+ "E1", shape=(in_dim, hid_dim), weight_init=dist.uniform(-0.2, 0.2),
+ resist_scale=1.
+ )
+ E1.weights.set(W1.weights.get().T)
## ...rest of the code is the same as the Cauchy prior model...
```
-Note that the above two models are built and configured for you in the
-Model Museum, in the `museum/exhibits/olshausen_sc/sparse_coding.py`
-agent constructor, which internally implements the model contexts depicted above
-as well as the necessary task-specific functions needed to reproduce the
-correct experimental setup (these get compiled in the constructor's
-`dynamic()` method. For both the Cauchy prior model of [1]
-and the iterative thresholding model of [2], we track, in the
-training script `train_patch_sc.py`, various dictionary synaptic
-statistics and a measurement of the model reconstruction loss. The
-reconstruction loss is a key part of the objective that both models
-optimize, i.e., both SC models effectively optimize an
-energy function that is a sum of its reconstruction error of its sensory
-input and the sparsity of its single latent state layer `z1`).
+Note that the above two models are built and configured for you in the Model Museum, in the `museum/exhibits/olshausen_sc/sparse_coding.py` agent constructor, which internally implements the model contexts depicted above as well as the necessary task-specific functions needed to reproduce the correct experimental setup (these get compiled in the constructor's `dynamic()` method). For both the Cauchy prior model of [1] and the iterative thresholding model of [2], we track, in the training script `train_patch_sc.py`, various dictionary synaptic statistics and a measurement of the model reconstruction loss. The reconstruction loss is a key part of the objective that both models optimize, i.e., both SC models effectively optimize an energy function that is a sum of the reconstruction error of the sensory input and the sparsity of the single latent state layer `z1`.
## Learning Latent Feature Detectors
-We will now simulate the learning of feature detectors using the two
-sparse coding models specified above. The code provided in
-`train_patch_sc.py` will execute a simulation of the above
-two models on the natural images found in `exhibits/data/natural_scenes.zip`),
-which is a dataset composed of several images of the American Northwest.
-
-First, navigate to the `exhibits/` directory to access the example/demonstration
-code and further enter the `exhibits/data/` sub-folder. Unzip the file
-`natural_scenes.zip` to create one more sub-folder that contains two numpy arrays,
-the first labeled `natural_scenes/raw_dataX.npy` and another labeled as
-`natural_scenes/dataX.npy`. The first one contains the original, `512 x 512` raw pixel
-image arrays (flattened) while the second contains the pre-processed, whitened/normalized
-(and flattened) image data arrays (these are the pre-processed image patterns used
-in [1]). You will, in this demonstration, only be working with `natural_scenes/dataX.npy`.
+We will now simulate the learning of feature detectors using the two sparse coding models specified above. The code provided in `train_patch_sc.py` will execute a simulation of the above two models on the natural images found in `exhibits/data/natural_scenes.zip`, which is a dataset composed of several images of the American Northwest.
+
+First, navigate to the `exhibits/` directory to access the example/demonstration code and further enter the `exhibits/data/` sub-folder. Unzip the file `natural_scenes.zip` to create one more sub-folder that contains two numpy arrays, the first labeled `natural_scenes/raw_dataX.npy` and another labeled as `natural_scenes/dataX.npy`. The first one contains the original, `512 x 512` raw pixel image arrays (flattened) while the second contains the pre-processed, whitened/normalized (and flattened) image data arrays (these are the pre-processed image patterns used in [1]). You will, in this demonstration, only be working with `natural_scenes/dataX.npy`.
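+
+If you want to peek at the data directly (the training script handles patch sampling for you), a minimal, hypothetical NumPy sketch might look like the following; note that the `512 x 512` shape assumed for the whitened images is carried over from the description of the raw arrays and may differ.
+
+```python
+import numpy as np
+
+X = np.load("natural_scenes/dataX.npy")     # whitened/normalized images, one flattened image per row
+img = X[0].reshape(512, 512)                 # assumed shape (matches the raw images described above)
+patch = img[:16, :16].reshape(1, -1)         # a single (flattened) 16 x 16 patch
+```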
Two (raw) images sampled from the original dataset (`raw_dataX.npy`) are shown below:
| | |
|---|---|
|  |  |
-With the data unpacked and ready, we can now run the training process in
-the model exhibit by either executing its Python simulation script like so:
+With the data unpacked and ready, we can now run the training process in the model exhibit by either executing its Python simulation script like so:
```console
$ python train_patch_sc.py --dataX="$DATA_DIR/dataX.npy" \
--n_iter=200 --model_type="sc_cauchy"
```
-or simply running the convenience Bash script `$ ./sim.sh` (which cleans up the model
-experimental output folder each time you call the training script in order
-to reduce memory clutter on your system). Running either the Python or Bash
-script will then train a sparse coding model with a Cauchy prior on `16 x 16`
-pixel patches from the natural image dataset in [1].[^1] After the simulation
-terminates, i.e., once `200` iterations/passes through the data have been made,
-you will notice in the `exp/filters/` sub-directory a visual plot
-of your trained model's filters which should look like the one below:
+or simply running the convenience Bash script `$ ./sim.sh` (which cleans up the model experimental output folder each time you call the training script in order to reduce memory clutter on your system). Running either the Python or Bash script will then train a sparse coding model with a Cauchy prior on `16 x 16` pixel patches from the natural image dataset in [1].[^1] After the simulation terminates, i.e., once `200` iterations/passes through the data have been made, you will notice in the `exp/filters/` sub-directory a visual plot of your trained model's filters which should look like the one below:
-If you modify either the Bash script or Python script call to use
-with a different model argument like so:
+If you modify either the Bash script or the Python script call to use a different model argument like so:
```console
$ python train_patch_sc.py --dataX="$DATA_DIR/dataX.npy" \
--n_iter=200 --model_type="sc_ista"
```
-you will now train your sparse coding using a latent soft-thresholding function
-(emulating ISTA). After this simulated training process ends, you should see
-in your `exp/filters/` sub-directory a filter plot like the one below:
+you will now train your sparse coding model using a latent soft-thresholding function (emulating ISTA). After this simulated training process ends, you should see in your `exp/filters/` sub-directory a filter plot like the one below:
-The filter plots, notably, visually indicate that the dictionary atoms in both
-sparse coding systems learned to function as edge detectors, each tuned to
-a particular position, orientation, and frequency. These learned feature detectors,
-as discussed in [1], appear to behave similar to the primary visual area (V1)
-neurons of the cerebral cortex in the brain. In the end, even though the edge
-detectors learned by both our models qualitatively appear to be similar,
-we should note that the latent codes (when inferring them given sensory input)
-for the model that used the thresholding function will ultimately sparser
-(given the direct clamping to zero values it imposes mathematically).
-Furthermore, the filters for the model with thresholding appear to smoother
-and with fewer occurrences of less-than-useful slots than the Cauchy model
-(or filters that did not appear to extract any particularly interpretable
-features).
+The filter plots, notably, visually indicate that the dictionary atoms in both sparse coding systems learned to function as edge detectors, each tuned to a particular position, orientation, and frequency. These learned feature detectors, as discussed in [1], appear to behave similarly to the primary visual area (V1) neurons of the cerebral cortex in the brain. In the end, even though the edge detectors learned by both our models qualitatively appear to be similar, we should note that the latent codes (when inferring them given sensory input) for the model that used the thresholding function will ultimately be sparser (given the direct clamping to zero values it imposes mathematically).
+Furthermore, the filters for the model with thresholding appear to be smoother, with fewer occurrences of less-than-useful slots (i.e., filters that did not appear to extract any particularly interpretable features) than those of the Cauchy model.
### Computing Hardware Note:
-This tutorial was tested and run on an `Ubuntu 22.04.2 LTS` operating system
-using an `NVIDIA GeForce RTX 2070` GPU with `CUDA Version: 12.1`
-(`Driver Version: 530.41.03`). Note that the times reported in any tutorial
-screenshot/console snippets were produced on this system.
+This tutorial was tested and run on an `Ubuntu 22.04.2 LTS` operating system using an `NVIDIA GeForce RTX 2070` GPU with `CUDA Version: 12.1` (`Driver Version: 530.41.03`). Note that the times reported in any tutorial screenshot/console snippets were produced on this system.
## References
[1] Olshausen, B., Field, D. Emergence of simple-cell receptive field properties
diff --git a/docs/ngclearn_papers.md b/docs/ngclearn_papers.md
index 637f8a76..04460db5 100644
--- a/docs/ngclearn_papers.md
+++ b/docs/ngclearn_papers.md
@@ -1,31 +1,19 @@
-# List of Talks Related to NGC-Learn/Sim-Lib
-
-The following is a list of talks, including conference/symposium tech-talks,
-given about NGC-Learn and/or NGC-Sim-Lib:
-
-1. Gebhardt, W. "An introduction to ngc-learn: A computational neuroscience library." 4th Applied Active Inference Symposium (2023). URL: https://www.youtube.com/watch?v=ynekj8F4zeY
-
# List of Papers and Publications
-The following is a list of current papers that use ngc-learn (this list will be
-actively updated as we discover others that use ngc-learn):
+The following is a list of current papers that use NGC-Learn (this list will be actively updated as we discover others that use NGC-Learn):
-1. Ororbia, A., and Kifer, D. "The neural coding framework for learning generative models". Nature Communications 13, 2064 (2022).
+1. Ororbia, A., and Kifer, D. "The neural coding framework for learning generative models." Nature Communications 13, 2064 (2022).
-2. Ororbia, A., and Mali, A. "Backprop-free reinforcement learning with active neural generative coding". Proceedings of the AAAI Conference on Artificial intelligence (2022).
+2. Ororbia, A., and Mali, A. "Backprop-free reinforcement learning with active neural generative coding." Proceedings of the AAAI Conference on Artificial Intelligence (2022).
-3. Ororbia, A. "Spiking neural predictive coding for continual learning from data streams." arXiv preprint arXiv:1908.08655 (2019).
+3. Ororbia, A. "Spiking neural predictive coding for continual learning from data streams." Neurocomputing 544: 126292 (2023).
-4. Ororbia, A, and Kelly, M. Alex. "CogNGen: Constructing the kernel of a hyperdimensional predictive processing cognitive architecture."
-Proceedings of the Annual Meeting of the Cognitive Science Society (CogSci), Volume 44 (2022).
+4. Ororbia, A, and Kelly, M. Alex. "CogNGen: constructing the kernel of a hyperdimensional predictive processing cognitive architecture." Proceedings of the Annual Meeting of the Cognitive Science Society (CogSci), Volume 44 (2022).
-5. Ororbia, A., and Kelly, M. Alex. "“Learning using a hyperdimensional predictive processing cognitive architecture." 15th International Conference on Artificial General Intelligence (AGI) (2022).
+5. Ororbia, A., and Kelly, M. Alex. "Learning using a hyperdimensional predictive processing cognitive architecture." 15th International Conference on Artificial General Intelligence (AGI) (2022).
-6. Ororbia, A., Mali, A., Kifer, D., and Giles, C. L. "Lifelong neural predictive coding: Learning cumulatively online without
-forgetting." Thirty-sixth Conference on Neural Information Processing Systems (NeurIPS) (2022).
+6. Ororbia, A., Mali, A., Kifer, D., & Giles, C. L. "Lifelong neural predictive coding: Learning cumulatively online without forgetting." Thirty-sixth Conference on Neural Information Processing Systems (NeurIPS) (2022).
-7. Gebhardt, W., and Ororbia, A. "Time-integrated spike-timing-dependent-plasticity." arXiv preprint arXiv:2407.10028 (2024).
+7. Ororbia, A., Friston, K., and Rao, R. P. N. "Meta-representational predictive coding: Biomimetic self-supervised learning." arXiv preprint arXiv:2503.21796 (2025).
-Note: Please let us know if your work uses ngc-learn so we can update this page to accurately track
-ngc-learn's use and include your work in the accumulating body of work in predictive processing
-and/or brain-inspired computational modeling.
+Note: Please let us know if your work uses NGC-Learn so we can update this page to accurately track NGC-Learn's use and include your work in the space of computational neuroscience, NeuroAI, and/or brain-inspired computational modeling.
diff --git a/docs/ngclearn_talks.md b/docs/ngclearn_talks.md
new file mode 100644
index 00000000..421da729
--- /dev/null
+++ b/docs/ngclearn_talks.md
@@ -0,0 +1,13 @@
+# Talks and Media Related to NGC-Learn
+
+The following is a list of talks and any media related to NGC-Learn:
+
+1. "NGC-Learn V3: A Fast, Modular, Computational Neuroscience Library". William Gebhardt (NAC Lab). Link (Youtube Video) (2025)
+
+2. "An Introduction to NGC-Learn: A Computational Neuroscience Library". William Gebhardt (NAC Lab). Link (Youtube Video) (2024)[^1]
+
+Keep an eye on the [NAC Lab YouTube channel](https://www.youtube.com/@TheNACLab/featured) for additional future videos featuring tutorials, educational material, and research related to NGC-Learn.
+
+Note: Please let us know if you give any tutorials/talks on work that makes use of NGC-Learn so we can update this page to make these useful educational materials available to a wider audience.
+
+[^1]: Note that this talk is related to NGC-Learn (v2)/NGC-Sim-Lib (v1).
\ No newline at end of file
diff --git a/docs/source/ngclearn.commands.rst b/docs/source/ngclearn.commands.rst
deleted file mode 100644
index 7b0c40c1..00000000
--- a/docs/source/ngclearn.commands.rst
+++ /dev/null
@@ -1,10 +0,0 @@
-ngclearn.commands package
-=========================
-
-Module contents
----------------
-
-.. automodule:: ngclearn.commands
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/ngclearn.components.lava.neurons.rst b/docs/source/ngclearn.components.lava.neurons.rst
deleted file mode 100644
index 9126f5e4..00000000
--- a/docs/source/ngclearn.components.lava.neurons.rst
+++ /dev/null
@@ -1,21 +0,0 @@
-ngclearn.components.lava.neurons package
-========================================
-
-Submodules
-----------
-
-ngclearn.components.lava.neurons.LIFCell module
------------------------------------------------
-
-.. automodule:: ngclearn.components.lava.neurons.LIFCell
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: ngclearn.components.lava.neurons
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/ngclearn.components.lava.rst b/docs/source/ngclearn.components.lava.rst
deleted file mode 100644
index 6b8be426..00000000
--- a/docs/source/ngclearn.components.lava.rst
+++ /dev/null
@@ -1,31 +0,0 @@
-ngclearn.components.lava package
-================================
-
-Subpackages
------------
-
-.. toctree::
- :maxdepth: 4
-
- ngclearn.components.lava.neurons
- ngclearn.components.lava.synapses
- ngclearn.components.lava.traces
-
-Submodules
-----------
-
-ngclearn.components.lava.monitor module
----------------------------------------
-
-.. automodule:: ngclearn.components.lava.monitor
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: ngclearn.components.lava
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/ngclearn.components.lava.synapses.rst b/docs/source/ngclearn.components.lava.synapses.rst
deleted file mode 100644
index 2babbb40..00000000
--- a/docs/source/ngclearn.components.lava.synapses.rst
+++ /dev/null
@@ -1,37 +0,0 @@
-ngclearn.components.lava.synapses package
-=========================================
-
-Submodules
-----------
-
-ngclearn.components.lava.synapses.hebbianSynapse module
--------------------------------------------------------
-
-.. automodule:: ngclearn.components.lava.synapses.hebbianSynapse
- :members:
- :undoc-members:
- :show-inheritance:
-
-ngclearn.components.lava.synapses.staticSynapse module
-------------------------------------------------------
-
-.. automodule:: ngclearn.components.lava.synapses.staticSynapse
- :members:
- :undoc-members:
- :show-inheritance:
-
-ngclearn.components.lava.synapses.traceSTDPSynapse module
----------------------------------------------------------
-
-.. automodule:: ngclearn.components.lava.synapses.traceSTDPSynapse
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: ngclearn.components.lava.synapses
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/ngclearn.components.lava.traces.rst b/docs/source/ngclearn.components.lava.traces.rst
deleted file mode 100644
index e2dbe697..00000000
--- a/docs/source/ngclearn.components.lava.traces.rst
+++ /dev/null
@@ -1,21 +0,0 @@
-ngclearn.components.lava.traces package
-=======================================
-
-Submodules
-----------
-
-ngclearn.components.lava.traces.gatedTrace module
--------------------------------------------------
-
-.. automodule:: ngclearn.components.lava.traces.gatedTrace
- :members:
- :undoc-members:
- :show-inheritance:
-
-Module contents
----------------
-
-.. automodule:: ngclearn.components.lava.traces
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/ngclearn.components.neurons.graded.rst b/docs/source/ngclearn.components.neurons.graded.rst
index d62a5b7e..9eb532c8 100644
--- a/docs/source/ngclearn.components.neurons.graded.rst
+++ b/docs/source/ngclearn.components.neurons.graded.rst
@@ -28,6 +28,14 @@ ngclearn.components.neurons.graded.laplacianErrorCell module
:undoc-members:
:show-inheritance:
+ngclearn.components.neurons.graded.leakyNoiseCell module
+--------------------------------------------------------
+
+.. automodule:: ngclearn.components.neurons.graded.leakyNoiseCell
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
ngclearn.components.neurons.graded.rateCell module
--------------------------------------------------
diff --git a/docs/source/ngclearn.components.rst b/docs/source/ngclearn.components.rst
index e3209782..c822f3bf 100644
--- a/docs/source/ngclearn.components.rst
+++ b/docs/source/ngclearn.components.rst
@@ -8,7 +8,6 @@ Subpackages
:maxdepth: 4
ngclearn.components.input_encoders
- ngclearn.components.lava
ngclearn.components.neurons
ngclearn.components.other
ngclearn.components.synapses
@@ -16,14 +15,6 @@ Subpackages
Submodules
----------
-ngclearn.components.base\_monitor module
-----------------------------------------
-
-.. automodule:: ngclearn.components.base_monitor
- :members:
- :undoc-members:
- :show-inheritance:
-
ngclearn.components.jaxComponent module
---------------------------------------
@@ -32,14 +23,6 @@ ngclearn.components.jaxComponent module
:undoc-members:
:show-inheritance:
-ngclearn.components.monitor module
-----------------------------------
-
-.. automodule:: ngclearn.components.monitor
- :members:
- :undoc-members:
- :show-inheritance:
-
Module contents
---------------
diff --git a/docs/source/ngclearn.rst b/docs/source/ngclearn.rst
index 814817bb..15e44840 100644
--- a/docs/source/ngclearn.rst
+++ b/docs/source/ngclearn.rst
@@ -7,7 +7,6 @@ Subpackages
.. toctree::
:maxdepth: 4
- ngclearn.commands
ngclearn.components
ngclearn.modules
ngclearn.operations
diff --git a/docs/source/ngclearn.utils.density.rst b/docs/source/ngclearn.utils.density.rst
index 40134e2a..09f8a461 100644
--- a/docs/source/ngclearn.utils.density.rst
+++ b/docs/source/ngclearn.utils.density.rst
@@ -4,10 +4,34 @@ ngclearn.utils.density package
Submodules
----------
-ngclearn.utils.density.gmm module
----------------------------------
+ngclearn.utils.density.bernoulliMixture module
+----------------------------------------------
-.. automodule:: ngclearn.utils.density.gmm
+.. automodule:: ngclearn.utils.density.bernoulliMixture
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+ngclearn.utils.density.exponentialMixture module
+------------------------------------------------
+
+.. automodule:: ngclearn.utils.density.exponentialMixture
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+ngclearn.utils.density.gaussianMixture module
+---------------------------------------------
+
+.. automodule:: ngclearn.utils.density.gaussianMixture
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+ngclearn.utils.density.mixture module
+-------------------------------------
+
+.. automodule:: ngclearn.utils.density.mixture
:members:
:undoc-members:
:show-inheritance:
diff --git a/docs/source/ngclearn.utils.feature_dictionaries.rst b/docs/source/ngclearn.utils.feature_dictionaries.rst
new file mode 100644
index 00000000..bc9daa64
--- /dev/null
+++ b/docs/source/ngclearn.utils.feature_dictionaries.rst
@@ -0,0 +1,21 @@
+ngclearn.utils.feature\_dictionaries package
+============================================
+
+Submodules
+----------
+
+ngclearn.utils.feature\_dictionaries.polynomialLibrary module
+-------------------------------------------------------------
+
+.. automodule:: ngclearn.utils.feature_dictionaries.polynomialLibrary
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: ngclearn.utils.feature_dictionaries
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/ngclearn.utils.masks.rst b/docs/source/ngclearn.utils.masks.rst
new file mode 100644
index 00000000..17721150
--- /dev/null
+++ b/docs/source/ngclearn.utils.masks.rst
@@ -0,0 +1,21 @@
+ngclearn.utils.masks package
+============================
+
+Submodules
+----------
+
+ngclearn.utils.masks.multiblock2d module
+----------------------------------------
+
+.. automodule:: ngclearn.utils.masks.multiblock2d
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: ngclearn.utils.masks
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/source/ngclearn.utils.optim.rst b/docs/source/ngclearn.utils.optim.rst
index 20145f98..547b6209 100644
--- a/docs/source/ngclearn.utils.optim.rst
+++ b/docs/source/ngclearn.utils.optim.rst
@@ -12,6 +12,14 @@ ngclearn.utils.optim.adam module
:undoc-members:
:show-inheritance:
+ngclearn.utils.optim.nag module
+-------------------------------
+
+.. automodule:: ngclearn.utils.optim.nag
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
ngclearn.utils.optim.optim\_utils module
----------------------------------------
diff --git a/docs/source/ngclearn.utils.rst b/docs/source/ngclearn.utils.rst
index b442e626..9f4e4809 100644
--- a/docs/source/ngclearn.utils.rst
+++ b/docs/source/ngclearn.utils.rst
@@ -10,12 +10,22 @@ Subpackages
ngclearn.utils.analysis
ngclearn.utils.density
ngclearn.utils.diffeq
+ ngclearn.utils.feature_dictionaries
+ ngclearn.utils.masks
ngclearn.utils.optim
ngclearn.utils.viz
Submodules
----------
+ngclearn.utils.JaxProcessesMixin module
+---------------------------------------
+
+.. automodule:: ngclearn.utils.JaxProcessesMixin
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
ngclearn.utils.data\_loader module
----------------------------------
@@ -24,18 +34,18 @@ ngclearn.utils.data\_loader module
:undoc-members:
:show-inheritance:
-ngclearn.utils.io\_utils module
--------------------------------
+ngclearn.utils.distribution\_generator module
+---------------------------------------------
-.. automodule:: ngclearn.utils.io_utils
+.. automodule:: ngclearn.utils.distribution_generator
:members:
:undoc-members:
:show-inheritance:
-ngclearn.utils.jaxProcess module
---------------------------------
+ngclearn.utils.io\_utils module
+-------------------------------
-.. automodule:: ngclearn.utils.jaxProcess
+.. automodule:: ngclearn.utils.io_utils
:members:
:undoc-members:
:show-inheritance:
@@ -56,6 +66,14 @@ ngclearn.utils.model\_utils module
:undoc-members:
:show-inheritance:
+ngclearn.utils.patch module
+---------------------------
+
+.. automodule:: ngclearn.utils.patch
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
ngclearn.utils.patch\_utils module
----------------------------------
@@ -72,14 +90,6 @@ ngclearn.utils.surrogate\_fx module
:undoc-members:
:show-inheritance:
-ngclearn.utils.weight\_distribution module
-------------------------------------------
-
-.. automodule:: ngclearn.utils.weight_distribution
- :members:
- :undoc-members:
- :show-inheritance:
-
Module contents
---------------
diff --git a/docs/source/ngclearn.utils.viz.rst b/docs/source/ngclearn.utils.viz.rst
index 0a48f7a8..4c926118 100644
--- a/docs/source/ngclearn.utils.viz.rst
+++ b/docs/source/ngclearn.utils.viz.rst
@@ -4,6 +4,22 @@ ngclearn.utils.viz package
Submodules
----------
+ngclearn.utils.viz.compartment\_plot module
+-------------------------------------------
+
+.. automodule:: ngclearn.utils.viz.compartment_plot
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+ngclearn.utils.viz.compartment\_raster module
+---------------------------------------------
+
+.. automodule:: ngclearn.utils.viz.compartment_raster
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
ngclearn.utils.viz.dim\_reduce module
-------------------------------------
diff --git a/docs/tutorials/configuration/compartments.md b/docs/tutorials/configuration/compartments.md
new file mode 100644
index 00000000..d2dfed25
--- /dev/null
+++ b/docs/tutorials/configuration/compartments.md
@@ -0,0 +1,44 @@
+# Compartments
+
+Within NGC-Sim-Lib, the global state serves as the backbone of any given model.
+This global state is essentially the culmination of all of the dynamic or changing parts of the model itself. Each
+value that builds this state is stored within a special "container" that helps track these changes over time -- this
+is referred to as a `Compartment`.
+
+## Practical Information
+
+Practically, when working with compartments, there are a few simple things to keep in mind despite the fact that most
+of NGC-Sim-Lib's primary operation is behind-the-scenes bookkeeping. The two main points to note are:
+1. Each compartment holds a value and, thus, setting a compartment with `myCompartment = newValue` will not function as
+ intended since this will overwrite the Python object, i.e., the compartment with `newValue`. Instead, it is
+ important to make use of the `.set()` method to update the value stored inside a compartment so
+ `myCompartment = newValue` becomes `myCompartment.set(newValue)`.
+2. In order to retrieve a value from a compartment, use `myCompartment.get()`. These methods of getting and setting
+ data inside a compartment are important to use when both working with and designing a multi-compartment component
+ (i.e., `Component`).
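+
+A minimal sketch of this get/set pattern, using a synapse component as it appears elsewhere in these docs (constructor arguments may differ slightly across versions):
+
+```python
+from ngclearn import Context
+from ngclearn.components.synapses.denseSynapse import DenseSynapse
+from ngclearn.utils.distribution_generator import DistributionGenerator as dist
+
+with Context("demo") as ctx:
+    W = DenseSynapse("W", shape=(4, 3), weight_init=dist.uniform(-0.2, 0.2))
+    value = W.weights.get()        # read the value tracked by the `weights` compartment
+    W.weights.set(value * 0.5)     # write a new value back (never `W.weights = value`)
+```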
+
+## Technical Information
+
+The following sections are devoted to more technical information regarding how a compartment functions
+within the broader scope of NGC-Sim-Lib and, furthermore, explain how to leverage this information.
+
+### How Data is Stored (Within a Model Context)
+
+The data stored inside of a compartment is not actually physically stored within a compartment. Instead, it is stored
+inside of the global state and each compartment effectively holds the path or `key` to the right spot in the global
+state, allowing it to pull out a specific piece of information. As such, it is technically possible to manipulate the
+value of a compartment without actually touching the compartment object itself within any given component. By default,
+compartments have in-built safeguards in order to prevent this from happening accidentally; however, addressing
+the compartment's value within the global state directly has no such safeguards.
+
+### What is "Targeting"?
+
+As discussed in the model building section, there is a notion of "wiring" together different compartments of different
+components -- this is at the core of NGC-Learn's and NGC-Sim-Lib's "nodes-and-cables system". These wires are created
+through the concept of "targeting", which is, in essence, just the updating of the path stored within a compartment
+using the path of a different compartment. This means that, if the targeted compartment goes to retrieve the value
+stored within it, it will actually retrieve the value of a different compartment (as dictated by the target). When a
+compartment is in this state -- where it is targeting another compartment -- it is set to read-only, which only means that
+it cannot modify a different compartment.
+
+
diff --git a/docs/tutorials/configuration/compiling.md b/docs/tutorials/configuration/compiling.md
new file mode 100644
index 00000000..7f4aad06
--- /dev/null
+++ b/docs/tutorials/configuration/compiling.md
@@ -0,0 +1,98 @@
+# Compiling
+
+The term "compiling" for NGC-Sim-Lib refers to an automatic step that happens
+inside of a context that produces a transformed method for all of its
+components. This step is the most complicated part of the library and, in
+general, does not need to be touched or interacted with. Nevertheless, this
+section will cover most of the steps that the NGC-Sim-Lib compilation process
+does at a high level. This section contains advanced technical/developer-level
+information: there is an expectation that the reader has an understanding of
+Python abstract syntax trees (ASTs), Python namespaces, and how to
+dynamically compile Python code and execute it.
+
+## The Decorator
+
+In NGC-Sim-Lib, there is a decorator marked as `@compilable` which is used to
+add a flag to methods that the user wants to compile. On its own, this will not
+do anything; however, this decorator lets the parser distinguish between methods
+that should be compiled and methods that should be ignored.
+
+## The Step-by-Step NGC-Sim-Lib Parsing Process
+
+The process starts by telling the parser to compile a specific object.
+
+### Step 1: Compile Children
+
+The first step to compile any object is to make sure that all of the
+"compilable" objects of the top level object are compiled. As a
+result, NGC-Sim-Lib will loop through the whole object and will compile
+each part that it finds that is flagged as compilable (via the decorators
+mentioned above) and is, furthermore, an instance of a class.
+
+### Step 2: Extract Methods to Compile
+
+While the parser is looping through all of the parts of the top-level object, it
+is also extracting the methods embedded in the object that are flagged as
+compilable (with the decorator above). NGC-Sim-Lib stores them for later;
+however, this lets the parser only loop over the object once.
+
+### Step 3: Parse Each Method
+
+As each method is its own entry-point into the transformer, this step will run
+for each method in the top-level object.
+
+### Step 3a: Set up a Transformer
+
+This step sets up a `ContextTransformer`, which further makes use of a
+`ast.NodeTransformer`, and will convert methods from class methods (with the use
+of `self`), as well as other methods that need to be removed / ignored, into
+their more context-friendly counterparts.
+
+### Step 3b: Transform the Function
+
+There are quite a few pieces of common Python that need to be transformed. This
+step happens with the overall goal of replacing all object-focused parts with a
+more global view. This means that a compartment's `.get` and `.set` calls are
+replaced with direct setting and getting from the global state, based on the
+compartment's target. This also means that all temporally constant values --
+such as `batch_size` -- are moved into the globals space for that specific file
+and ultimately replaced with the naming convention of `object_path_constant`.
+One more key step that is performed is to ensure that there is no branching in
+the code. Specifically, if there is a branch, i.e., an if-statement, NGC-Sim-Lib
+will evaluate it and only keep the branch it will traverse down. This means that
+there cannot be any branch logic based on inputs or computed values (this is a
+common restriction for just-in-time compiling).
+
+### Step 3c: Parse Sub-Methods
+
+Since it is possible to have other class methods that are not marked as
+entry-points for compilation but still need to be compiled, as step 3b happens,
+NGC-Sim-Lib tracks all of the sub-methods required. Notably, this step goes
+through and repeats steps 3a and 3b for each of the (sub-)methods with a naming
+convention similar to the temporally constant values for each method.
+
+### Step 3d: Compile the Abstract Syntax Tree (AST)
+
+Once we have all of the namespace and globals needed to execute the
+properly-transformed method, the method is compiled with Python and finally
+executed.
+
+### Step 3e: Binding
+
+The final step per method is to bind each compiled method to its original; this
+replaces each method with an object which, when called, will act like the
+normal, uncompiled version but has the addition of the `.compiled` attribute.
+This attribute contains all of the compiled information to be used later (for
+model / system simulation). This crucially allows for the end user to
+call `myComponent.myMethod.compiled()` and have it run. The exact type for
+a `compiled` value can be found
+in `ngcsimlib._src.parser.utils:CompiledMethod`.
+
+### Step 4: Finishing Up / Final Processing
+
+Some objects, such as the processes, entail additional steps to modify
+themselves or their compiled methods in order to align themselves with needed
+functionality. However, this operation/functionality is found within each
+class's expanded `compile` method and should be referred to by looking at those
+methods specifically.
+
diff --git a/docs/tutorials/configuration/components.md b/docs/tutorials/configuration/components.md
new file mode 100644
index 00000000..c72f149d
--- /dev/null
+++ b/docs/tutorials/configuration/components.md
@@ -0,0 +1,32 @@
+# Components
+
+Living one step above compartments in the NGC-Learn dynamical systems hierarchy rests the component.
+A component (`ngcsimlib.Component`) holds a collection of both temporally constant values as well as dynamic (time-evolving)
+compartments. In addition, they are the core place where the logic governing the dynamics of a system is
+defined. Generally, components serve as the building blocks that are to be reused multiple times
+when constructing a complete model of a dynamical system.
+
+## Temporally Constant versus Dynamic Compartments
+
+One important distinction that needs to be highlighted within a component is the
+difference between a temporally constant value and a dynamic (time-varying) compartment.
+Compartments themselves house values that change over time and, generally, they will have the
+type `ngcsimlib.Compartment`; note that compartments are to be used to track the internal values
+of a component. These internal values can be ones such as inputs, decaying values, counters, etc.
+The second kind of value found within a component is known as a temporally constant value; these
+are values (e.g., hyper-parameters, structural parameters, etc.) that will remain fixed
+within the constructed dynamical system model. These types of values tend to include common configuration
+and meta-parameter settings, such as matrix shapes and coefficients.
+
+## Defining Compilable Methods
+
+Inside of a component, it is expected that there will be methods defined that govern the
+temporal dynamics of the system component. These compilable methods are decorated
+with `@compilable` and are defined like any other regular (Python) method. Within a compilable
+method, there will be access to `self`, which means that, to reference a compartment's
+value, one must write out such a call as: `self.myCompartment.get()`. The only requirement is
+that any method that is decorated cannot have a return value; values should be stored
+inside their respective compartments (by making an appeal to their respective set routine, i.e.,
+`self.myCompartment.set(value)`). In an external (compilation) step, outside of the developer's
+definition of a component, an NGC-Sim-Lib transformer will change/convert all of these (decorated)
+methods into ones that function with the rest of the NGC-Sim-Lib back-end.
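+
+As a rough, hypothetical sketch of what such a component might look like -- note that the `Component`/`Compartment` constructor signatures and the import location of `@compilable` shown here are assumptions for illustration, not documented API:
+
+```python
+from ngcsimlib import Component, Compartment, compilable  # import paths are assumed
+
+class LeakyAccumulator(Component):
+    def __init__(self, name, n_units, decay=0.9, **kwargs):
+        super().__init__(name, **kwargs)
+        self.decay = decay               # temporally constant value (a hyper-parameter)
+        self.inputs = Compartment(0.0)   # dynamic, time-evolving compartments
+        self.accum = Compartment(0.0)
+
+    @compilable
+    def advance_state(self, dt):
+        ## no return value -- results are written back into a compartment
+        new_val = self.accum.get() * self.decay + self.inputs.get() * dt
+        self.accum.set(new_val)
+```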
diff --git a/docs/tutorials/configuration/context.md b/docs/tutorials/configuration/context.md
new file mode 100644
index 00000000..1e8e06b2
--- /dev/null
+++ b/docs/tutorials/configuration/context.md
@@ -0,0 +1,67 @@
+# Contexts
+
+Contexts, in NGC-Sim-Lib, are the top-level containers that hold everything used to
+define a model / dynamical system. On their own, contexts have no runtime logic;
+they rely on their internal processes and components to build a complete, working model.
+
+## Defining a Context
+
+To define a context (`ngcsimlib.Context`), NGC-Sim-Lib leverages the `with` block; this
+means that to create a new context, simply start with the statement
+`with Context("myContext") as ctx:` and a new context will be created.
+(Important Note: names are unique; if a context is created with the same name,
+they will be the same context and, thus, there might be conflicts).
+A defined context does not do anything on its own.
+
+## Adding Components
+
+To add components to a context, simply initialize components while inside
+the `with` block of the context. Any component defined while inside this block
+will automatically be added to the context object.
+
+## Wiring Components
+
+Inside of a model / dynamical system, components will need to pass data to one
+another; this is configured within the context. To connect the compartments of
+two components, follow the pattern: `mySource.output >> myDestination.input`,
+where `output` and `input` are compartments inside their respective components.
+This format will ensure that, when processes are being run, the value will
+flow properly from component to component.
+
+### Operators
+
+There is a special type of wire called an operator; this performs a simple
+operation on the compartment values as the data flows from one component to
+another. Generally, these are use for simple mathematical operations, such as
+negation `Negate(mySource.output) >> myDestination.input` or the summation of
+multiple compartments into
+one `Summation(mySource1.output, mySource2.output, ...) >> myDestination.input`.
+Note that operators can be chained, so it would be possible to negate one or
+more of the inputs that flow into the summation.
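+
+Putting the wiring pattern together, a minimal sketch using two graded-neuron components (constructor arguments follow ngc-learn's current components and may vary; the operator import is omitted since its exact path depends on your version):
+
+```python
+from ngclearn import Context
+from ngclearn.components.neurons.graded.rateCell import RateCell
+from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell
+
+with Context("wiring_demo") as ctx:
+    z = RateCell("z", n_units=10, tau_m=20., act_fx="identity")
+    e = GaussianErrorCell("e", n_units=10)
+    z.zF >> e.mu   # z's rate output flows into e's prediction (mu) compartment
+    # an operator (e.g., Summation(...) >> e.mu) could wrap one or more sources here
+```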
+
+## Adding Processes
+
+To add processes to a context, simply initialize the process and add all of its
+steps while inside the `with`-block of the process.
+
+## Exiting the `with` block
+
+When the context exits the `with`-block, it will re-compile the entire model.
+Behind the scenes, this is calling `recompile` on the context
+itself; it is possible to manually trigger the recompile step, but doing so can
+break certain connections (between components/compartments), so use this
+functionality sparingly.
+
+## Saving and Loading
+
+The context's one unique job is the handling of the "saving" (serialization) and
+"loading" (de-serialization) of models to disk. By default, calling
+`save_to_json(...)` will create the correct file structure as well as the core files
+needed to load the context in the future. To load / de-serialize a model,
+calling `Context.load(...)` will load the context in from a directory; something
+important to note is that loading in a context entails effectively
+recreating the components with their initial values using their arguments as well as
+keyword arguments (excluding those that cannot be serialized). This means that,
+if you have a trained model, ensure that your components have a save method
+defined that will handle the saving and loading of all values within their compartments.
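+
+A short, hedged sketch of the save/load round-trip (whether `save_to_json` is called on the context object, and its exact arguments, are assumptions here):
+
+```python
+from ngclearn import Context
+
+with Context("demo") as ctx:
+    pass  # ... components and processes would be defined here ...
+
+ctx.save_to_json("exp/demo")           # serialize the context to a directory
+restored = Context.load("exp/demo")    # later: rebuild the context from disk
+```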
+
diff --git a/docs/tutorials/configuration/global_state.md b/docs/tutorials/configuration/global_state.md
new file mode 100644
index 00000000..1f78ac6e
--- /dev/null
+++ b/docs/tutorials/configuration/global_state.md
@@ -0,0 +1,52 @@
+# The Global State
+
+Since NGC-Sim-Lib is a simulation library focused on temporal models and dynamical
+systems, i.e., models that change over time, it is foundational that all models
+(and their respective elements) have some concept of a "state". These states
+might be comprised of a single value that changes/evolves or of a complex set of values
+that, when combined all together, make up the full dynamical system that underwrites the
+final model. In both cases, these sets of values are stored in what is known as the
+global state.
+
+## Interacting with the Global State
+
+Since the global state will contain a large amount of information describing a given
+model, there will be a need to facilitate interaction with and modification of the values
+contained within the global state. In most use-cases, this is not done directly. The
+most common way to interact with the global state is through the use of the state-manager.
+The state-manager exists to provide a set of helper methods for interacting with the
+global state itself. Note that, although the manager is there to assist you, it will not stop
+you from changing the state (or "breaking" the state). When changing the state -- beyond
+setting it through the specification of processes -- be careful to not add or remove
+anything that is needed for your actual model.
+
+### Adding New Fields to the Global State
+
+If you are new to using NGC-Sim-Lib and looking for a way to add values to the
+global state directly and explicitly, stop for a moment and reconsider. Unless
+you know exactly what you are doing (i.e., doing core development), it is strongly
+advised to not manually add values to the global state; instead, work through the
+mechanisms afforded by compartments and/or components, as these are built to afford you the
+most common ways for adding fields to the global state itself. The dynamical systems
+semantics inherent to compartments and components is meant to ensure carefully-constrained
+design and simulation of flexible models.
+
+If you actually intend to manually and directly add values to the global state itself, it
+is done through the use of the `add_key` method. This will create the appropriate key in
+the global state for the given path and name; furthermore, its value can be retrieved
+with `from_key` calls. This value, however, is not linked to a compartment and, therefore,
+will be hard to get working properly in the compiled methods without some specific references.
+Please take extra care when working directly and explicitly with the global state.
+
+### Getting the Current State
+
+To get the current state, simply call `global_state_manager.state`; this will give
+you a (shallow) copy of the current state, which means that any modifications made to it will
+not be reflected in the global state.
+
+### Updating the Global State
+
+To manually update the global state after modifying a local copy, write an overriding
+assignment: `global_state_manager.state = new_state`. This will update the state via the
+`.update` call to its underlying dictionaries, which means that a partial state will still update correctly.
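+
+Putting these pieces together, an illustrative (and hedged) sketch -- the import location of the state-manager and the key shown are assumptions:
+
+```python
+from ngcsimlib import global_state_manager  # assumed import location
+
+snapshot = global_state_manager.state       # shallow copy of the current global state
+snapshot["Circuit/z1/accum"] = 0.0          # edit the copy (key format is illustrative only)
+global_state_manager.state = snapshot       # write back (applied via .update underneath)
+```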
+
diff --git a/docs/tutorials/configuration/index.rst b/docs/tutorials/configuration/index.rst
new file mode 100644
index 00000000..e9ccc8ca
--- /dev/null
+++ b/docs/tutorials/configuration/index.rst
@@ -0,0 +1,28 @@
+.. ngc-learn documentation master file, created by
+ sphinx-quickstart on Wed Apr 20 02:52:17 2022.
+ Note - This file needs to at least contain a root `toctree` directive.
+
+Configuration Basics
+====================
+
+This set of guides provides information about the fundamental building blocks that characterize NGC-Learn as well
+as the operation of NGC-Learn's back-end, NGC-Sim-Lib.
+For end-users (experimentalists, engineers), the sections under "Building Blocks" will be most informative. For
+developers, the sections under "Development Information" are recommended, particularly for advanced use-cases and
+low-level development.
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Building Blocks
+
+ context
+ components
+ compartments
+ processes
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Development Information
+
+ global_state
+ compiling
\ No newline at end of file
diff --git a/docs/tutorials/configuration/processes.md b/docs/tutorials/configuration/processes.md
new file mode 100644
index 00000000..a7b9b8f5
--- /dev/null
+++ b/docs/tutorials/configuration/processes.md
@@ -0,0 +1,90 @@
+# Processes
+
+Processes in NGC-Sim-Lib offer a central way of defining a specific transition to be
+taken within a given model (this effectively sets up the behavior of the state-machine
+that defines the desired dynamical system one wants to simulate). In effect, processes
+take in any number of compilable methods across any number of
+components; they work to produce a single top-level method and a varying number of
+sub-methods needed to execute the entire chain of compilable methods in one (single) step.
+This is ultimately done to interface nicely with just-in-time (JIT) compilers, such as
+the one inherent to JAX, and to minimize the amount of read and write calls done across
+a chain of methods.
+
+## Building the (Command) Chain
+
+Building the chain that a process will use is done through an iterative process. Once
+the process object is created, steps are added using either `.then()` or `>>`.
+As an example:
+
+```
+myProcess.then(myCompA.forward).then(myCompB.forward).then(myCompA.evolve).then(myCompB.evolve)
+```
+
+or
+
+```
+myProcess >> myCompA.forward >> myCompB.forward >> myCompA.evolve >> myCompB.evolve
+```
+
+In both cases, this process will chain the four methods together into a single
+step, only updating the final state after all steps are complete.
+
+## Types of Processes
+
+There are two types of processes: the above example uses what is
+referred to as a `MethodProcess` -- these are used to chain together any
+compilable methods from any number of different components. The second
+type of process, called a `JointProcess` in NGC-Sim-Lib, is used to chain
+together entire processes.
+JointProcesses are especially useful if there are multiple method processes that
+need to be called but different orders of the processes are needed at different
+times. These allow for the specification of complex events / behaviors in a
+dynamical system that one will simulate.
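+
+For instance, a hedged sketch of a `JointProcess` that chains two `MethodProcess` objects built as in the example above (the exact `JointProcess` construction may differ from what is shown here):
+
+```
+forward = MethodProcess(name="forward") >> myCompA.forward >> myCompB.forward
+learn = MethodProcess(name="learn") >> myCompA.evolve >> myCompB.evolve
+train_step = JointProcess(name="train_step") >> forward >> learn
+```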
+
+## Extra Elements
+
+There are a few extra methods that come standard with each process type which can
+be useful for both regular operation as well as debugging.
+
+### Viewing the Compiled Method
+
+Behind the scenes, a process is transforming and compiling down all of the steps
+used to build it; this means that the exact code it is running to do its
+set of calculations will ultimately not be what the user wrote. To allow for
+the end user to view and make sure that the two pieces of code -- theirs and
+the compiled version -- are equivalent (and yielding expected behavior), every
+process has a `view_compiled_method()` call which can be used after the (final) model
+is compiled. This call will return the code (block) that it will be running as a
+string. There will be some stark differences between the produced/generated code and
+the code in the (Python) components used to build the steps. Please refer to the
+compiling page for a more in-depth guide to comparing the outputs between these
+two stages of code.
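+
+For example (assuming `myProcess` from the example above has already been compiled within its context):
+
+```
+print(myProcess.view_compiled_method())   # prints the generated source as a string
+```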
+
+### Needed Keywords
+
+Since some methods will require external values such as `t` (for time) or `dt`
+(for integration time / the temporal delta) for a given execution, a process
+will also track all the keyword arguments that are needed to run their compiled
+process. To view which keywords a given process is expecting, one may use the
+command: `get_keywords()`.
+This is mostly used for debugging and/or as a sanity check.
+
+### Packing Keywords
+
+To add onto the needed keywords, the process also provides an interface to
+produce the keywords needed to run in the form of two methods. The first method
+is `pack_keywords(...)`; this method packs together a single row of values that
+are needed to run a single execution (step) of the process. The first argument is
+`row_seed`, which is a seed that is to be passed to all of the keyword
+generators (only needed if generators are being used).
+The second set of arguments are keyword arguments that are either constant,
+such as `dt=0.1`, or generators, such as `lambda row_seed: 0.1 * row_seed`.
+The second method for generating the keywords for a process is with `pack_rows(...)`.
+This method will create many sets of keywords that are needed to run multiple
+iterations of the process. Note that the arguments are slightly different: first,
+it now utilizes a `length` argument to indicate the number of rows being produced and,
+second, it features a `seed_generator` that is used to generate the seed of each row
+(for instance, to have only even seed values: `seed_generator = lambda x: 2 * x`); if
+the generator is `None`, then `seed_generator = lambda x: x` is used.
+After this, the same keyword arguments to define the needed parameters are used as in `pack_keywords`.
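+
+A hedged sketch of these calls, with `myProcess` as above (argument names follow the description here and may differ slightly in practice):
+
+```
+needed = myProcess.get_keywords()        # e.g., a collection containing "t" and "dt"
+one_step = myProcess.pack_keywords(0, dt=0.1, t=lambda row_seed: row_seed * 0.1)
+many_steps = myProcess.pack_rows(length=100, seed_generator=lambda x: 2 * x,
+                                 dt=0.1, t=lambda row_seed: row_seed * 0.1)
+```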
+
diff --git a/docs/tutorials/foundations.md b/docs/tutorials/foundations.md
deleted file mode 100644
index 15887da3..00000000
--- a/docs/tutorials/foundations.md
+++ /dev/null
@@ -1,18 +0,0 @@
-# Foundational Elements
-
-In this set of tutorials/walkthroughs, we go through the some of the core elements
-and mechanisms underlying ngc-learn in order understand how its simulation
-scheme (and the nodes-and-cables system) works and to help in writing your
-own custom elements.
-
-The foundational walkthroughs are organized as follows:
-1. [Using Model Contexts](../tutorials/foundations/contexts.md): This lesson goes
- the fundamentals of the primary simulation construct you need to set up models, the
- (simulation) context.
-2. [Understanding Commands](../tutorials/foundations/commands.md): This lesson will
- walk you through the basics of a command -- an essential part of building a
- simulation controller in ngc-learn and ngcsimlib -- and offer some useful
- points for designing new ones.
-3. [Operations](../tutorials/foundations/operations.md): Here, the basics
- of bundle rules, a commonly use mechanism for crafting complex biophysical
- systems, will be presented.
diff --git a/docs/tutorials/foundations/commands.md b/docs/tutorials/foundations/commands.md
deleted file mode 100755
index e3fff8f8..00000000
--- a/docs/tutorials/foundations/commands.md
+++ /dev/null
@@ -1,137 +0,0 @@
-# Understanding Commands
-
-## Overview
-Commands are one of the central pillars of
-ngcsimlib, the dependency
-library that drives ngc-learn's simulation backend.
-In general, commands provide the instructions and logic for what each component
-should be doing at any given time. In addition, they are the normal way that an
-outside user would interact with ngc-learn models. Commands live inside a model's
-controller and are generally made with the `add_command` method.
-
-## Abstract Command
-Contained within ngcsimlib is an abstract class for every command included in
-ngcsimlib. It is strongly recommended that custom commands are built using this
-base class (but there is nothing enforcing this inside of ngcsimlib).
-
-At its base the abstract command forces two things: firstly, the constructor
-for the base class requires a list of components, and a list of attributes that
-each component should have. Secondly, all commands must implement their
-`__call__` command, taking in only `*args` and `**kwargs`.
-
-## Constructing Commands
-It is common that commands will need to have values passed into them to control
-their internal behavior, such as a value to clamp, or a flag for freezing
-synaptic weight values.
-To do this, we introduce the notion of binding keywords to commands.
-Specifically, commands will take strings in during their construction and then
-look for those strings when called inside the list of keyword arguments in order
-to get their arguments.
-
-## Calling Commands
-When commands are called, they will take in only `*args` and `**kwargs`.
-While custom commands can break this by adding in additional arguments
-without any problem, it is not recommended to do this as multiple instances
-of a command with different parameters will then use the same keyword for their
-call.
-
-## Creating Custom Commands
-It is recommended that all custom commands inherit from the base class
-provided within ngcsimlib. This provides a good starting point for designing a
-component that will seamlessly interact with ngcsimlib's internal simulation mechanics.
-These mechanics, which characterize the core operation of a simulation controller,
-entail that, for each command supplied to a controller, a command will call the
-same function with the same parameters on each component provided
-to that very command. It is also expected that there is error handling within the
-constructor to catch as many runtime errors as possible. Note that base
-command class provides a list to check required calls such as `reset` or `evolve`.
-
-It is important to note that, if commands are going to be constructed via a
-controller, they should have keyword arguments with default values that
-error out on bad input instead of positional arguments.
-
-## Example Command (reset)
-
-Below, we present the key bits of source code that characterize a reset command
--- a very commonly used, built-in command for models designed in ngc-learn -- and
-its internal operation:
-
-```python
-from ngcsimlib.commands import Command
-from ngcsimlib.utils import extract_args
-from ngcsimlib.logger import warn, error
-
-class Reset(Command):
- def __init__(self, components=None, reset_name=None, command_name=None,
- **kwargs):
- super().__init__(components=components, command_name=command_name,
- required_calls=['reset'])
- if reset_name is None:
- error(self.name, "requires a \'reset_name\' to bind to for construction")
- self.reset_name = reset_name
-
- def __call__(self, *args, **kwargs):
- try:
- vals = extract_args([self.reset_name], *args, **kwargs)
- except RuntimeError:
- warn(self.name, ",", self.reset_name,
- "is missing from keyword arguments and no positional arguments were provided")
- return
-
- if vals[self.reset_name]:
- for component in self.components:
- self.components[component].reset()
-```
-
-## Custom Command Template
-
-Here, we show the generic command template which shows how one would go about
-designing the key operational bits that make up a useful command.
-
-```python
-from ngcsimlib.commands.command import Command
-from ngcsimlib.utils import extract_args
-from ngcsimlib.logger import error
-
-
-class CustomCommand(Command):
- def __init__(self, components=None, BINDING_VALUE=None, ADDITIONAL_INPUT=None, command_name=None,
- **kwargs):
- super().__init__(components=components, command_name=None, required_calls=['CUSTOM_CALL'])
- # Make sure additional input is passed in
- if ADDITIONAL_INPUT is None:
- error(self.name, "requires a \'ADDITIONAL_INPUT\' for construction")
-
- # Make sure command is bound to a value
- if BINDING_VALUE is None:
- error(self.name, "requires a \'BINDING_VALUE\' to bind to for construction")
-
- self.BOUND_VALUE = BINDING_VALUE
- self.ADDITION_VALUE = ADDITIONAL_INPUT
-
- def __call__(self, *args, **kwargs):
- # Extract the bound value from the arguments
- try:
- vals = extract_args([self.BOUND_VALUE], *args, **kwargs)
- except RuntimeError:
- error(self.name, ",", str(self.BOUND_VALUE), "is missing from keyword arguments or a positional "
- "arguments can be provided")
-
- #Use extracted value to call a method on each component
- for component in self.components:
- self.components[component].CUSTOM_CALL(self.ADDITION_VALUE, vals[self.BOUND_VALUE])
-```
-
-## Notes
-All components added to commands must have a `name` attribute and the word
-`name` is automatically appended to any provided list of required attributes
-to the base class constructor.
-
-As all built-in commands use `extract_args` when called with a controller via
-`myController.COMMAND(ARGUMENT)`, there is no need to use keywords as it will
-use `args` if there are no keyword arguments. (Keywords will still work, however.)
-
-When commands are constructed via a controller, they are also provided with the
-keyword arguments `controller` and `command_name`. It is not recommended to
-use these for any core logic (just use them for error messages), unless
-it using them is absolutely essential in achieving the desired functionality.
diff --git a/docs/tutorials/foundations/contexts.md b/docs/tutorials/foundations/contexts.md
deleted file mode 100644
index b025cd8d..00000000
--- a/docs/tutorials/foundations/contexts.md
+++ /dev/null
@@ -1,70 +0,0 @@
-# What are Contexts
-
-A context in ngclearn is a container that holds all the information for your model and can be used as an access point to
-reference different models in a multi-model system. Some of the information that contexts hold is all the components
-defined in the context, all the wiring information for each of the components, as well as all the commands defined on
-the context through various means.
-
-## How to make a Context
-
-To make a context first import it from ngclearn with `from ngclearn import Context`. This will give you access to not
-only the constructor for new contexts but also the ability to get previously defined contexts. The general use case for
-this is
-
-```python
-from ngclearn import Context
-
-with Context("Model1") as model1:
- pass
-```
-
-This will make a context named "Model1" and also drops you into a with block where you can define the various parts of
-the model. The call `Context("Model1")` will always return the same context. So if there is already a model with that
-name defined earlier in the code this instance of `model1` will have all the same object defined previously.
-
-## Adding Components
-
-The best way to add components to a context is by using components that have implemented the `MetaComponent` metaclass.
-In ngclearn the base `Component` class does this. If using these components all that is needed to have them added to
-the context is calling their constructors inside a with block of the context. For example
-
-```python
-from ngclearn import Context
-from ngclearn.components import LIFCell
-
-with Context("Model1") as model1:
- z1 = LIFCell("z1", n_units=10, tau_m=100)
-```
-
-## Creating Cables
-
-To add connections between components and their compartments in a model we do that also in a context. Just like with
-components there are no special actions that need to be taken to add them beyond doing so in a with block. To connect
-to compartments the `<<` operator is used following the outline of `destination << source`. For example
-
-```python
-with model1:
- w1.inputs << z1.s
-```
-
-## Dynamic Commands
-
-When building models it can be desirable to use the same training and testing scrips while having commands do different
-actions. For example if two different models had different clamp procedures to set inputs and labels it is possible to
-dynamically add a generic clamp command to each model and call them the same way despite them doing different things.
-As an example
-```python
-with model1:
- @model1.dynamicCommand
- def clamp(inputs, labels):
- z0.inputs.clamp(inputs)
- z2.labels.clamp(labels)
-
-with model2:
- @model2.dynamicCommand
- def clamp(inputs, labels):
- z0.inputs.clamp(inputs)
- z0_p.inputs.clamp(inputs)
- z2.labels.clamp(labels)
-```
-In both these cases later we can just call clamp and each one will call their own version of the clamp command.
diff --git a/docs/tutorials/foundations/monitors.md b/docs/tutorials/foundations/monitors.md
deleted file mode 100644
index f835a59d..00000000
--- a/docs/tutorials/foundations/monitors.md
+++ /dev/null
@@ -1,72 +0,0 @@
-# NGC Monitors
-
-NGC-monitors are a way of storing a rolling window of compartment values
-automatically. Their intended purpose is not to be used "inside" of a model but
-just as an auxiliary way to view the internal state of the model even when it is
-compiled. A monitor will track the last `n` values it has observed within the
-compartment with the oldest value being at `index=0` and the newest being at
-`index=n-1`.
-
-## Building a Monitor
-
-Monitors are constructed exactly like regular components are for general models.
-To use one, simply import the monitor `from ngclearn.components import Monitor`. Now,
-inside of your model, build it like a regular component:
-
-```python
-with Context("model") as model:
- M = Monitor("M", default_window_length=100)
-```
-
-## Watching compartments
-
-There are then two key ways of watching compartments, the first way looks similar
-to the wiring paradigm found in connecting standard ngclearn components
-together. The primary difference is that connecting compartments to the monitor
-does not require a compartment, they are wired directly into the `Monitor`
-following the pattern below:
-
-```python
- M << z0.s
-```
-
-This will wire the spike output of `z0` into the monitor with a view window
-length of the `default_window_length`. In the event that you want a view window
-that is not the default viewing length, you can use the `watch()` method
-instead as in below:
-
-```python
- M.watch(z0.s, customWindowLength)
-```
-
-There is no limit to the number of compartments that a monitor can watch or the
-length of the window that it can store. However, as it is constantly shifting
-values, tracking large matrices, such as those containing synapse values
-over many timesteps, may get expensive.
-
-For the monitor to run during your `advance_state` and `reset` calls, make sure
-to add to the monitor to the list of components to compile. Currently,
-monitors do not work with non-compiled methods
-(This is a planned feature for future developments of ngc-learn).
-
-## Extracting Values
-
-To look at the currently stored window of any compartment being tracked, there
-are two methods available to you. The first method requires that you have
-access to the compartment that the monitor is watching. To read out the
-monitors values, you can call:
-
-```python
-M.view(z0.s)
-```
-
-In the event that you do not have access to the compartment, all of the stored
-values can be found via the path using the following:
-
-```python
-M.get_store("path/to/compartment")
-```
-
-The stored windows are kept in a tree of dictionaries, where each node is a part
-of the path and the leaves are compartment objects holding the
-tracked value windows.
\ No newline at end of file
diff --git a/docs/tutorials/foundations/operations.md b/docs/tutorials/foundations/operations.md
deleted file mode 100755
index 7584d2d2..00000000
--- a/docs/tutorials/foundations/operations.md
+++ /dev/null
@@ -1,60 +0,0 @@
-# The Basics of Operations
-
-## What are Operations?
-
-The underlying method for passing data/signals from component to component inside of
-contexts is through the use of cables. A large amount of the time compartments will have a single cable being passed
-connected to it that overwrites the previous value in that compartment. However, there are times when this is not the
-case and then cable operations must be used.
-
-## Built-in operations
-
-By default, ngclearn comes with four operations defined, `overwrite`, `negate`, `summation` and `add`. Of these four
-operations the default one used by all cables is the overwrite operation. This operations will take the value of its
-source compartment and place it into the destination compartment overwriting the value currently there. The negate
-operation has a similar effect as the overwrite operation with the added functionality of applying the `-` operation to
-the value being transmitted. The summation operation takes in any number of source compartments and sums together all
-their values and overwrites the previous value with the sum. Finally, the add operation does the same thing as the
-summation operation but instead adds the sum to the previous value instead of overwriting it.
-
-## Building Custom Operation
-
-At its core, an operation is a static method that does all the runtime logic of the operation with the source
-compartments, and a resolver that does clean up and assignment of the output of the operation to the destination
-compartment.
-
-> General Form of an Operation:
-> ```python
-> class operationName(BaseOp):
-> @staticmethod
-> def operation(*sources):
-> #Runtime Logic
-> return computed_value
-> ```
-
-> Example Operation (Summation)
-> ```python
-> class summation(BaseOp):
-> @staticmethod
-> def operation(*sources):
-> s = None
-> for source in sources:
-> if s is None:
-> s = source
-> else:
-> s += source
-> return s
-> ```
-
-## Notes
-
-- Every cable coming into or out of a compartment can have a different operation.
-
-- The order of these operations should be the order they are wired in, but this is not guaranteed.
-
-- Only the logic that exists in the static method `operation` is used for a compiled operation, all logic existing in an overwritten resolve method is not captured.
-
-- Some operations have a flag of `is_compilable` set to false. This is checked during compile to flag if the model can be compiled.
-
-- Operations can be nested so `summaion(negate(c1), c2)` would be a valid operation and will work while compiled
-
diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst
index f1640834..5f02878c 100644
--- a/docs/tutorials/index.rst
+++ b/docs/tutorials/index.rst
@@ -2,41 +2,15 @@
sphinx-quickstart on Wed Apr 20 02:52:17 2022.
Note - This file needs to at least contain a root `toctree` directive.
-Tutorial Contents
-=================
+Modeling Basics
+===============
-Lessons/tutorials go through the very basics of constructing a dynamical system in
-ngc-learn, core elements and tools of neurocognitive modeling using ngc-learn's
-in-built components and simulation tools, and finally providing foundational insights
-into how ngc-learn and its backend, ngc-sim-lib, work (particularly with respect
-to configuration).
+Lessons/tutorials go through the very basics of constructing a dynamical system in NGC-Learn, core elements and tools of neurocognitive modeling using NGC-Learn's in-built components and simulation tools, and finally providing foundational insights into how NGC-Learn and its backend, NGC-Sim-Lib, work (particularly with respect to configuration).
.. toctree::
- :maxdepth: 1
- :caption: I. Modeling Basics
+ :maxdepth: 2
+ :caption: Table of Contents
model_basics/configuration
- model_basics/json_modules
model_basics/model_building
model_basics/evolving_synapses
-
-.. toctree::
- :maxdepth: 1
- :caption: II. NGC-Learn/Sim-Lib Foundations
-
- foundations
- foundations/contexts
- foundations/commands
- foundations/operations
- foundations/monitors
-
-.. toctree::
- :maxdepth: 1
- :caption: III. NGC-Lava: Support for Loihi 2 Transfer
-
- lava/introduction
- lava/setup
- lava/lava_components
- lava/lava_context
- lava/hebbian_learning
- lava/monitors
diff --git a/docs/tutorials/intro.md b/docs/tutorials/intro.md
index 5651a453..e835819c 100755
--- a/docs/tutorials/intro.md
+++ b/docs/tutorials/intro.md
@@ -1,55 +1,27 @@
# Introduction
-ngc-learn is a general-purpose library for modeling biomimetic/neuro-mimetic
-complex systems. While the library is designed to provide flexibility on the
-experimenter/designer side -- allowing one to design their own dynamics and
-evolutionary processes -- at its foundation are a few standard components, the
-basic modeling nodes for simulating some common biophysical systems computationally,
-that useful to know in getting started and quickly building some classical/historical
-models. If you are interested in knowing some of the neurophysiological theory
-behind ngc-learn's design philosophy, [this section](../tutorials/theory) might
-be of interest.
+NGC-Learn is a general-purpose library for modeling complex dynamical systems, particularly those that are useful for
+computational neuroscience, neuroscience-motivated artificial intelligence (NeuroAI), and brain-inspired computing.
+
+While the library is designed to provide flexibility on the experimenter/designer side -- allowing one to develop their
+own dynamics and evolutionary processes -- at its foundation are a few standard components. These are basic modeling
+nodes for simulating some common biophysical systems computationally, which are useful to know when getting started and
+for quickly building some classical/historical models. If you are interested in knowing some of the neurophysiological
+theory behind NGC-Learn's design philosophy, [this section](../tutorials/theory) might be of interest.
-Specifically, to make best use of ngc-learn, it is important to get the
-hang of its "nodes-and-cables system" (as it was historically referred to) in
-order to build simulation objects. This set of tutorials will walk through,
-step-by-step, the key aspects of the library you need to know so you can build
-and run simulations of computational biophysical models. In addition, we
-provide walkthroughs of some of the central mechanisms underlying
-ngcsimlib, the simulation
-dependency library that drives ngc-learn; these are particularly useful for not
-only understanding why and how things are done by ngc-learn's simulation
-backend but also for those who want to design new, custom extensions of ngc-learn
-either for their own research or to contribute to the development of the main library.
+Specifically, to make best use of NGC-Learn, it is important to get the hang of its "nodes-and-cables system" (the
+historical name for its backend engine) in order to build simulation objects. This set of tutorials will walk you
+through, step-by-step, the key aspects of the library that you will need to know so that you can build
+and run simulations of computational biophysical models. In addition, we provide walkthroughs of some of the central
+mechanisms underlying NGC-Sim-Lib, the simulation dependency
+library that drives NGC-Learn; these lessons are particularly useful not only for understanding why and how things are
+done by NGC-Learn's simulation backend engine but also for those who want to design new, custom extensions of NGC-Learn
+either for their own research or to help contribute to the development of the main library.
## Organization of Tutorials
-The core tutorials and lessons for using ngc-learn can be found [here, in the
-tutorial table of contents](../tutorials/index.rst) and go through: the basic
-configuration and use of ngc-learn and ngc-sim-lib to construct simulations
-of dynamical systems, the essentials of neurocognitive modeling (such as
-building and analyzing neuronal dynamics and synaptic plasticity), as well
-as the coverage of some key foundational ideas/tools worth knowing about
-ngc-learn (and its backend, ngc-sim-lib) particularly to facilitate easier
-debugging, experimental configuration, and advanced model tools like `bundle rules`.
-
-
+The core tutorials and lessons for using NGC-Learn can be found [here, in the modeling basics table of contents](../tutorials/index.rst); these go through the basic configuration and use of NGC-Learn (and NGC-Sim-Lib) to
+construct simulations of basic dynamical systems.
+More advanced tutorials related to the essentials of neurocognitive modeling -- such as building and analyzing
+neuroscience models of neuronal dynamics and synaptic plasticity -- can be found [here, in the neurocognitive modeling
+table of contents](../tutorials/neurocog/index.rst).
diff --git a/docs/tutorials/lava/hebbian_learning.md b/docs/tutorials/lava/hebbian_learning.md
deleted file mode 100644
index c45a01b6..00000000
--- a/docs/tutorials/lava/hebbian_learning.md
+++ /dev/null
@@ -1,312 +0,0 @@
-# Training a Spiking Network On-chip
-
-In this tutorial we will build generate a simple dataset (consisting of binary
-patterns of X's, O's, and T's) and train a model in the Loihi simulator
-using Hebbian learning in the form of trace-based spike-timing-dependent
-plasticity.
-
-## Setting up ngc-learn
-
-The first step of this project consist of setting up the configuration file for
-ngc-learn. Create a folder in your project directory root called `json_files`
-and
-then create a `config.json` configuration inside of that folder.
-
-Now for this project we will not be loading anything dynamically, so we can
-simply add:
-
-```json
-{
- "modules": {
- "module_path": null
- }
-}
-```
-
-The above configuration will skip the dynamic loading of modules, which is
-important for Lava-based model transference and simulation.
-
-Next, in order to run code with the Loihi simulator, the base version of numpy
-needs to be used instead of JAX's wrapped numpy (which is what ngc-learn resorts
-to by default). To change all of ngc-learn over to using the base version of
-numpy, simply add the following to your configuration:
-
-```json
-"packages": {
- "use_base_numpy": true
-}
-```
-
-Now your project is configured for ngc-lava and Lava usage and we can move on
-to data generation.
-
-## Generating Data
-
-For this project we will be using three different patterns to train a simple
-biophysical spiking neural network; the data will simply consist of binary
-image patterns of either an `X`, `O`, and a `T`. To create the file needed to
-generate these patterns, create a Python script named `data_generator.py` in
-your project root. Next, we will import `numpy` and `random` and
-define the following three generator methods:
-
-```python
-from ngclearn import numpy as np
-
-
-def make_X(size):
- X = np.zeros((size, size))
- for i in range(0, size):
- X[i, i] = np.random.uniform(0.75, 1)
- X[i, size - 1 - i] = np.random.uniform(0.75, 1)
- return X
-
-
-def make_O(size):
- O = np.zeros((size, size))
- for i in range(0, (size // 2) - 1):
- O[1 + i, (size // 2) - 1 - i] = np.random.uniform(0.75, 1)
- O[1 + i, (size // 2) + i] = np.random.uniform(0.75, 1)
- O[(size // 2) + i, 1 + i] = np.random.uniform(0.75, 1)
- O[(size // 2) + i, size - 2 - i] = np.random.uniform(0.75, 1)
- return O
-
-
-def make_T(size):
- T = np.zeros((size, size))
- T[1, 1:size - 1] = np.random.uniform(0.75, 1, (1, size - 2))
- for i in range(2, size - 1):
- T[i, (size // 2) - 1: (size // 2) + 1] = np.random.uniform(0.75, 1,
- (1, 2))
- return T
-```
-
-Each of these methods will create a pattern of the desired size and shape.
-
-## Building the Model
-
-Found below is all of the imports that will be needed to run the model we desire
-in Lava:
-
-```python
-from ngclava import LavaContext
-from ngclearn import numpy as np
-from ngclearn.components.lava import LIFCell, GatedTrace, TraceSTDPSynapse, StaticSynapse, Monitor
-import ngclearn.utils.viz as viz_utils
-import ngclearn.utils.weight_distribution as dist
-from data_generator import make_X, make_O, make_T
-```
-
-To start off building this model, we will define all of the hyperparameters
-needed to create the necessary model components:
-
-```python
-# Training Params
-epochs = 35
-view_length = 200
-rest_length = 1000
-
-# Model Params
-n_in = 64 # Input layer size
-n_hid = 25 # Hidden layer size
-dt = 1. # ms # integration time constant
-np.random.seed(42) ## seed the internal numpy calls
-```
-
-After this we will create the lava context, the components, as well as the
-wiring:
-
-```python
-with LavaContext("Model") as model:
- z0 = LIFCell("z0", n_units=n_in, thr_theta_init=dist.constant(0.), dt=dt,
- tau_m=50., v_decay=0., tau_theta=500.,
- refract_T=0.) ## IF cell
- z1e = LIFCell("z1e", n_units=n_hid,
- thr_theta_init=dist.uniform(amin=-2, amax=2.),
- dt=dt, tau_m=100., tau_theta=500.) ## excitatory LIF cell
- z1i = LIFCell("z1i", n_units=n_hid,
- thr_theta_init=dist.uniform(amin=-2, amax=2.),
- dt=dt, tau_m=100., thr=-40., v_rest=-60., v_reset=-45.,
- theta_plus=0.) ## inhibitory LIF cell
-
- tr0 = GatedTrace("tr0", n_units=n_in, dt=dt, tau_tr=20.)
- tr1 = GatedTrace("tr1", n_units=n_hid, dt=dt, tau_tr=20.)
-
- W1 = TraceSTDPSynapse("W1", weight_init=dist.uniform(amin=0, amax=0.3),
- shape=(n_in, n_hid), dt=dt, Aplus=0.011,
- Aminus=0.0011,
- preTrace_target=0.055)
- W1ie = StaticSynapse("W1ie", weight_init=dist.hollow(120.),
- shape=(n_hid, n_hid), dt=dt)
- W1ei = StaticSynapse("W1ei", weight_init=dist.eye(22.5),
- shape=(n_hid, n_hid), dt=dt)
-
- M = Monitor("M", default_window_length=view_length)
-
- ## wire z0 to z1e via W1 and z1i to z1e via W1ie
- W1.inputs << z0.s
- W1ie.inputs << z1i.s
-
- z1e.j_exc << W1.outputs
- z1e.j_inh << W1ie.outputs
-
- # wire z1e to z1i via W1ie
- W1ei.inputs << z1e.s
- z1i.j_exc << W1ei.outputs
-
- # wire cells z0 and z1e to their respective traces
- tr0.inputs << z0.s
- tr1.inputs << z1e.s
-
- # wire relevant compartment statistics to synaptic cable W1 (for STDP update)
- W1.x_pre << tr0.trace
- W1.pre << z0.s
- W1.x_post << tr1.trace
- W1.post << z1e.s
-
- # set up monitoring of z1e's spike output
- M << z1e.s
-```
-
-After the components have been set up, we have to "lag out" the synapses that
-will cause recurrent (locking) problems when running on the Loihi2. This will
-cause each of these synapses to run one time-step behind and fixes many
-recurrency
-issues (as described [here](lava_context.md)).
-
-```python
- model.set_lag('W1')
- model.set_lag('W1ie')
- model.set_lag('W1ei')
-```
-
-Now that the model is all set up, we have to tell the Lava compiler to actually
-build all the Lava objects with the following:
-
-```python
- model.rebuild_lava()
-```
-
-This line will stop the automatic build of components when leaving this
-with-block and provides access to all of the Lava components inside of this
-with-block.
-
-Next, we set up two methods, a `clamp` method to set the input data and
-`viz` to visualize all of the different receptive fields of our model:
-
-```python
- lz0, lW1 = model.get_lava_components('z0', 'W1')
-
-
- @model.dynamicCommand
- def clamp(x):
- model.pause()
- lz0.j_exc.set(x)
-
-
- @model.dynamicCommand
- def viz():
- viz_utils.synapse_plot.visualize([lW1.weights.get()], [(8, 8)], "lava_fields")
-```
-
-## Running The Model
-
-Now that everything is set up to build the runtime and start training the model
-inside of the Loihi simulator. To set up the runtime we call the following:
-
-```python
- model.set_up_runtime("z0", rest_image=np.zeros((1, 64)))
-```
-
-This will set up a runtime with `z0` as the root node and also uses a resting
-image of all zeros to allow the system to return to its resting state.
-
-Now the training loop will be as follows:
-
-```python
-with model.runtime:
- for i in range(epochs):
- print(f"\rStarting Epoch: {i}", end="")
- X = np.reshape(make_X(8), (1, 64))
- O = np.reshape(make_O(8), (1, 64))
- T = np.reshape(make_T(8), (1, 64))
-
- model.view(X, view_length)
- model.rest(rest_length)
-
- model.view(O, view_length)
- model.rest(rest_length)
-
- model.view(T, view_length)
- model.rest(rest_length)
- print("\nDone Training")
-
-```
-
-## Evaluating the On-Chip Trained Model
-
-The code above will work to train the model on a Loihi neuromorphic chip, but,
-currently, we do not have a way of viewing how effective the model learned
-really is. To set up this evaluation, we can call
-the `viz` method defined above to view the receptive fields that our spiking
-model has acquired:
-
-```python
- model.viz()
-```
-
-Running this should produce a set of receptive fields that look like the
-following:
-
-
-
-While viewing the receptive fields qualitatively tells us that our spiking
-model has trained, we may also want to view the
-[raster plots](ngclearn.utils.viz.raster) -- visual depictions of the
-underlying spike patterns acquired in the hidden layer of our model -- for each
-of our three image patterns (as they are fed into our trained model). To do
-this, we will make use of the monitor we defined above in the following manner:
-
-```python
- ## Turning off learning
- lW1.eta.set(np.array([0]))
-
- model.view(np.reshape(make_T(8), (1, 64)), view_length)
- model.write_to_ngc()
- spikes = M.view(z1e.s)
- viz_utils.raster.create_raster_plot(spikes, tag="T", plot_fname="raster_T")
- model.rest(rest_length)
- print("Done T")
-
- model.view(np.reshape(make_X(8), (1, 64)), view_length)
- model.write_to_ngc()
- spikes = M.view(z1e.s)
- viz_utils.raster.create_raster_plot(spikes, tag="X", plot_fname="raster_X")
- model.rest(rest_length)
- print("Done X")
-
- model.view(np.reshape(make_O(8), (1, 64)), view_length)
- model.write_to_ngc()
- spikes = M.view(z1e.s)
- viz_utils.raster.create_raster_plot(spikes, tag="O", plot_fname="raster_O")
- model.rest(rest_length)
- print("Done O")
-```
-
-The above should result in raster plots where the spikes correspond to the
-receptive fields of each trained letter pattern. Specifically, you should see
-that the top left field is `N0` and the bottom right is `N24`. Your raster plots
-should look like the ones below:
-
-
-
-
-
-Finally to save the model to disk, you can call the following:
-
-```python
- model.save_to_json(".", model_name="trained")
-```
-
-which will save your on-chip trained Loihi model to disk for later use.
-
-
diff --git a/docs/tutorials/lava/introduction.md b/docs/tutorials/lava/introduction.md
deleted file mode 100644
index 994e3809..00000000
--- a/docs/tutorials/lava/introduction.md
+++ /dev/null
@@ -1,37 +0,0 @@
-# Blending ngc-learn and lava-nc
-
-The subpackage of ngclearn known as ngc-lava is an interfacing layer between
-ngclearn's components and contexts and lava-nc's models and processes. In this
-package, there is the introduction of the `LavaContext`, a subclass of the default
-ngclearn `Context`. This context has all the same functionality as the base
-ngclearn context but adds the ability to convert lava compatible components into
-their Lava process and model automatically and on-the-fly. This allows for the
-development and testing of models inside ngclearn prior to their deployment onto
-a Loihi neuromorphic chip without needing to translate between the two models
-written across the two different Python libraries.
-
-## Some Cautionary Notes
-
-- For the best experience in training models in ngclearn, Python version `>=3.10`
- should be used. However, much of lava is written to be used in Python `3.8` and,
- because of this, there are some flags and functionality that cannot be used in Lava
- components directly. It is for this reason that ngc-learn has several
- in-built "lava components", i.e., those in `ngclearn.components.lava` that
- are meant to directly interact with ngc-lava; other components (such as those
- (`ngclearn.components.neurons` or `ngclearn.components.synapses`) are not likely
- to work and, when writing your own custom ngc-lava components, we recommend
- that you use those in the `ngclearn.components.lava` subpackage as starting
- points to see what design patterns will work with Lava.
-- As of right now, all of ngc-lava is built using the Loihi2 configuration and
- Loihi1 is not actively supported. Loihi1 might still work but nothing is
- guaranteed nor has been tested by the ngc-learn dev team.
-
-## Table of Contents
-1. [Setting up ngc-lava](setup.md): A brief overview of how to set up
- ngc-lava
-2. [Lava components](lava_components.md): An overview of lava components in ngclearn and
- how to make custom ones
-3. [Lava Context](lava_context.md): An overview of the Lava context and building
- models for Lava
-4. [On-Chip Hebbian Learning](hebbian_learning.md): A walkthrough for getting a simple
- hebbian learning model setup
diff --git a/docs/tutorials/lava/lava_components.md b/docs/tutorials/lava/lava_components.md
deleted file mode 100644
index 79411db1..00000000
--- a/docs/tutorials/lava/lava_components.md
+++ /dev/null
@@ -1,29 +0,0 @@
-# Lava Components
-
-Inside ngc-learn, there is a wide variety of components with which biophysical
-models can be built. Unfortunately, many of those components are not compatible
-with Lava and the loihi2. Therefore, ngc-learn supports several in-built
-components that are Lava-compliant; many of the components that are compatible
-to you can be found in `ngclearn.components.lava`.
-
-## What Makes an ngc-learn Component Compatible
-
-For components to be compatible with Lava, there are a few key rules that must
-be followed:
-- Lava Components can not make use of JAX's random or JAX's `nn` libraries
-- Lava Components must import numpy from ngclearn and not JAX (there is a flag
- in the configuration file to control JAX's numpy versus base numpy)
-- Lava Components cannot take in any runtime arguments to their `advance_state` method
-- Lava Components cannot take in any runtime arguments or compartments to their
- `reset` method(s)
-
-## Mapping Methods -- Going from ngc-learn to Lava
-
-There are two methods that are mapped to their lava processes; these include the
-`reset` method and the `advance_state` method. The reset method is just mapped
-to the lava components and can be called on them without any issue. The
-`advance_state` method is mapped to the `run_spk` method and is called during
-the runtime loops in Lava. It is important to note that the methods that are
-actually mapped are the pure methods passed into the resolvers that
-decorate the ngc-learn `reset` and `advance_state` methods, not the
-`reset` and `advance_state` methods themselves.
\ No newline at end of file
diff --git a/docs/tutorials/lava/lava_context.md b/docs/tutorials/lava/lava_context.md
deleted file mode 100644
index cb7888ff..00000000
--- a/docs/tutorials/lava/lava_context.md
+++ /dev/null
@@ -1,103 +0,0 @@
-# The Lava Context
-
-The lava context, i.e, the `LavaContext`, serves as the core to ngc-lava as well
-as the main workhorse of all of its features. Since it is a subclass of the
-default ngc-learn context, we will only be covering the new Lava-specific
-features here.
-
-## Building Lava Components
-
-The Lava context generally keeps track of two sets of components -- the ngclearn
-components and the Lava components. However, due to the nature of the lava
-components themselves, they must be built once the model is fixed and cannot be
-built on-the-fly. Due to this fact, the building of the lava components must
-be triggered before they can be used. Nevertheless, there are a few ways to trigger the
-building of the Lava components. It is important to note that only the latest set
-of components can be used for methods like clamping and running. This will
-affect all dynamically compiled methods.
-
-### Events that Trigger a Rebuild
-
-- When a `LavaContext` is first constructed via: `with LavaContext("model") as model:`
- leaving the context block will trigger a rebuild
-- Calling `with model.updater:` will rebuild the lava components upon leaving the
- with-block
-- Calling `model.rebuild_lava()` will rebuild the lava components even if it is
- still inside a with-block. However, by default, it will stop the with-block
- from recompiling upon exiting as doing so would overwrite the previously built
- model components.
-
-### Events That Will Not Trigger a Rebuild
-
-Simply calling `with model` will not trigger a rebuild upon exiting since this is
-where additional dynamic method can be defined as well as reference sub-models
-while not triggering a complete rebuild of the Lava components each time.
-
-## The Runtime
-
-Inside of Lava, there is an internal runtime that is controlling the simulator
-for the loihi2. This runtime must be started in order to act upon Lava components,
-such as clamping values to their compartments as well as probing information
-about the model. To help simplify this, the `LavaContext` comes with a built-in
-runtime manager. To gain access to the ngc-lava runtime manager, first call
-`model.set_up_runtime()`. Note that the `set_up_runtime` method takes two
-arguments. The first is the root Lava component name to be used to start the
-runtime -- this is how Lava knows what component it will need to simulate. The
-second argument is the "rest" image -- the "rest" image is used to allow the
-dynamical system that is your model to return to its reset state while
-receiving no input (this is akin to allowing a biophysical neural system to relax
-to its resting potential state). This can be left as `None` and doing so will
-skip this functionality. Note that this method does not actually start
-the runtime, it just configures everything. It is important to observe that a
-clamp method fitting the signature `clamp(x) -> None` needs to be defined in
-order to use certain runtime methods as defined below.
-
-### Runtime Methods
-
-- `with model.runtime`: The lava runtime will exist for the duration of this
- with block.
-- `model.start_runtime()`: This starts the runtime without the management of
- automatically stopping it later.
-- `model.pause()`: Pauses the runtime, allowing for values to be read and set.
-- `model.stop()`: Stops the runtime, runtimes can not be restarted once they are
- stopped.
-- `model.run(t)`: Runs the runtime, for `t` time steps. Will automatically pause
- upon completion.
-- `model.view(x, t)`: First calls `model.clamp(x)` and then runs the runtime for
- `t` steps. Will automatically pause upon completion.
-- `model.rest(t)`: First calls `model.clamp(rest_image)` and then runs the
- runtime for `t` steps. Will automatically pause upon completion. If a reset
- image was not supplied, this runtime method will not be available.
-
-## Additional Utility Methods
-
-### Using Lags with: `set_lag(component_name, status=True)`
-
-In Lava, it is easy to lock your system if there is recurrence in your model.
-The Lava context allows for you to temporally "lag" the values emitted by
-specific components, delaying their executation with respect to the previous
-time-step.
-
-By default, the process pattern for a mapped Lava component is:
-`Receive values -> Process values -> Emit values`
-
-A lagged Lava component will follow the pattern:
-`Emit values -> Receive values -> Process Values`
-
-Example:
-> There is a model that has the wiring pattern of `Z0 -> W1 -> Z1 -> W1`
-> Here we can see that in order for Z1 to emit values it relies on the values
-> emitted by W1. But W1 also relies on values emitted from Z1. So if we lag
-> W1 it will emit last timesteps value at the start of the loop and then wait
-> for the new values meaning that the value emitted by W1 will be delayed by a
-> timestep, but it will no longer lock Z1 from running.
-
-### `write_to_ngc()`
-
-This method is designed to copy the current state of the Lava model into the
-ngc-learn model. This will do a one-to-one mapping of all of thecomponents and
-their values from Lava to ngclearn. It is important to point out that this must
-be done inside of a runtime. This is critical for saving since, in order to save
-an on-chip-trained model, it must first be written back to ngc-lava/learn
-and then to disk. By default, this is called by `model.save_to_json` if called
-inside a runtime.
diff --git a/docs/tutorials/lava/monitors.md b/docs/tutorials/lava/monitors.md
deleted file mode 100644
index b5d82243..00000000
--- a/docs/tutorials/lava/monitors.md
+++ /dev/null
@@ -1,17 +0,0 @@
-# Monitors
-
-While lava does have its own version of monitors, ngclearn offers an
-in-built version for convenience. It is
-recommended that you use the ngclearn monitors as they have expanded
-functionality and are designed to interact with the Lava components well. For an
-overview of/details on how monitors work please see
-[this](../foundations/monitors.md). The
-only difference is that Lava has its own monitor found
-in `ngclearn.components.lava`.
-
-## Sharp Edges and Bits
-
-- Due to the fact that a Lava component of the monitor must be built, it has to
- be defined inside the `LavaContext`.
-- To view the values found within the monitor via the `view()` and `get_path()`
- methods, `model.write_to_ngc()` must be called to refresh the values.
\ No newline at end of file
diff --git a/docs/tutorials/lava/setup.md b/docs/tutorials/lava/setup.md
deleted file mode 100644
index 542b479e..00000000
--- a/docs/tutorials/lava/setup.md
+++ /dev/null
@@ -1,26 +0,0 @@
-# Setting Up ngc-lava
-
-Setting up ngc-lava is fairly straightforward. The only part that takes some
-time is the setting up of the lava environment itself.
-
-## Installation and Setup Steps
-
-1. To set up and use ngc-lava, first download lava-nc
- found [here](https://lava-nc.org/lava/notebooks/in_depth/tutorial01_installing_lava.html).
-2. Install ngc-learn via pip `pip install ngc-learn`
-3. Clone ngc-lava and add it as a project source
- ```bash
- git clone https://github.com/NACLab/ngc-lava.git
- pip install -e ngc-lava
- ```
-4. To compile for lava, Jax must be turned off; to do this, set the flag
- `packages/use_base_numpy` to `true` in the ngc-learn
- `config.json` file. If you do not have a `config.json` file written, the
- script below will make one for you and add the needed Lava configuration flag:
-
- ```bash
- mkdir json_files
- touch json_files/config.json
- echo "{\n \"packages\": {\n \"use_base_numpy\": true \n }\n}" > json_files/config.json
- ```
-5. You are now set up and ready to use ngc-lava.
\ No newline at end of file
diff --git a/docs/tutorials/model_basics/configuration.md b/docs/tutorials/model_basics/configuration.md
index 480fea98..37de90b0 100644
--- a/docs/tutorials/model_basics/configuration.md
+++ b/docs/tutorials/model_basics/configuration.md
@@ -1,39 +1,29 @@
-# Lesson 1: Configuring ngcsimlib
+# Lesson 1: Configuring NGC-Sim-Lib
## Basics
-There are various global configurations that can be made to ngcsimlib, the systems simulation backend for ngc-learn. These include the ability to point to custom locations for the `json_modules` files as well as setting up the logger. In both of these cases, the configuration will generally persist between different models that might be loaded and, thus, it will need to exist outside of the scope of the model's files. To solve this problem, ngcsimlib provides `config.json` as well as the `--config` flag mechanisms.
+There are various global configurations that can be made to NGC-Sim-Lib, the systems simulation backend for NGC-Learn.
+The primary use-case for a configuration file is to modify the library's built-in logger. Generally, to control the
+configuration, run any script (that uses NGC-Learn) with the flag `--config="path/to/your/config.json"`.
-The `config.json` file contains one large json object with sections set up for different parts of the configuration, broke up into sub-objects. There is no limit to the size or the number of these objects, meaning that the user is free to define and use them as they so choose. However, there are some general design principals that govern ngcsimlib that are worth knowing about. Specifically, this mechanism will not configure any parts of individual models. `config.json` configurations should be used to select/generally set up experiments and control global level flags and not to set hyperparameters for models.
-
-## Built-in Configurations
-
-There are a couple configurations that ngcsimlib will look for while it is initializing. Specifically `modules` and `logging`. While neither of these is needed to get up and running some aspects of ngcsimlib, useful debugging tools such as logging to files and more verbosity are locked behind flags set up here.
-
-### Modules
-
-The modules configuration only contains one value, `module_path`. This value is the location of the `modules.json`, the model-level/experiments-level configuration file one should be setting up when building their experiments. For additional information for configuring this file please
-see modules.json.
-
-> Example Modules
->
-> ```json
-> {
-> "modules": {
-> "module_path": "custom/path/to/json/files/modules.json"
-> }
-> }
-> ```
+The `config.json` file contains one large JSON object with sections set up for different parts of the configuration,
+broken up into sub-objects. There is no limit to the size or to the number of these objects; this means that the user
+is free to define and use them as they so choose.
### Logging
-The logging configuration mechanism sets up and controls the instance of the python logger built into ngcsimlib. This mechanism (or JSON section) has three values found within it. Specifically, `logging_level`, `logging_file`, and `hide_console`. The logging levels are the same ones built into the python logger and the value words used are either the standard Python string representation of the level or the numeric equivalent. The `logging file`, if defined, is a file that the logger will append all logging messages to for a more permanent history of all messages. Finally, `hide console`, if set to true, will hide all logging output to the console.
+The logging configuration mechanism sets up and controls the instance of the Python logger built into NGC-Sim-Lib. This
+mechanism (or JSON section) has three values found within it: `logging_level`, `logging_file`, and
+`hide_console`. The logging levels are the same ones built into the Python logger and the values used are either
+the standard Python string representation of the level or its numeric equivalent. `logging_file`, if defined, is
+a file to which the logger will append all logging messages in order to facilitate a more permanent history of all
+messages. Finally, `hide_console`, if set to true, will hide all logging output to the console.
> Default Config
> ```json
> {
> "logging": {
-> "logging_level": "WARNING",
+> "logging_level": "ERROR",
> "hide_console": false
> }
> }
@@ -52,21 +42,27 @@ The logging configuration mechanism sets up and controls the instance of the pyt
## Using a Configuration
-To use a configuration, there are a few options. The first option is to simply use the configuration as a python dictionary. This is done by importing the `get_config` method from `ngcsimlib.configManager` and providing the name of the configuration section to the method.
+To use a configuration, there are a few options. The first option is to simply use the configuration as a Python
+dictionary. This is done by importing the `get_config` method from `ngclearn` and providing the name of the
+configuration section to the method.
> Example get_config
>```python
->from ngcsimlib.configManager import get_config
+>from ngclearn import get_config
>
>loggerConfig = get_config("logger")
>level = loggerConfig['logging_level']
>```
-The other way you can access a configuration is through a provided namespace. This makes use of python's `SimpleNamespace` to map all the dictionary's key values to properties of an object to be used. One important note about namespaces is that, unlike a python dictionary where the `get` method can be provided a default value for missing keys, namespaces do not have this functionality. Therefore, if keys are missing it has the potential to cause errors. Below is an example of how one could use the namespace for logging configuration.
+The other way you can access a configuration is through a provided namespace. This makes use of Python's
+`SimpleNamespace` to map all of the dictionary's keys and values to properties of an object that is to be used. One
+important note about namespaces is that, unlike a Python dictionary where the `get` method can be provided a default
+value for missing keys, namespaces do not have this functionality. Therefore, missing keys can
+cause errors. Below is an example of how one could use the namespace for a logging configuration.
> Example provide_namespace
> ```python
-> from ngcsimlib.configManager import provide_namespace
+> from ngclearn import provide_namespace
>
> loggerConfig = provide_namespace("logger")
> level = loggerConfig.logging_level
diff --git a/docs/tutorials/model_basics/evolving_synapses.md b/docs/tutorials/model_basics/evolving_synapses.md
index 8cf92f8b..e68ef1de 100755
--- a/docs/tutorials/model_basics/evolving_synapses.md
+++ b/docs/tutorials/model_basics/evolving_synapses.md
@@ -1,27 +1,24 @@
-# Lesson 4: Evolving Synaptic Efficacies
+# Lesson 3: Evolving Synaptic Efficacies
-In this tutorial, we will extend a controller with three components,
-two cell components connected with a synaptic cable component, to incorporate a
-basic a two-factor Hebbian adjustment process.
+In this tutorial, we will extend a model context/controller with three components, two cell components connected with a
+synaptic cable component, to incorporate a basic two-factor Hebbian adjustment process.
## Adding a Learnable Synapse to a Multi-Component System
-Let us start by building a controller similar to previous lessons with the one
-exception that now we will trigger the synaptic connection between `a` and `b`
-to adapt via a simple 2-factor Hebbian rule. This Hebbian rule will require us
-to wire the output compartment of `a` to the pre-synaptic compartment of the
-synapse `Wab` and the output compartment of `b` to the post-synaptic
-compartment of `Wab`. This will wire in the two relevant factors needed to
+Create a Python script/file named `run_lesson3.py` in which to place the Python code below.
+Let us start by building a controller/model-context similar to previous lessons with the one exception that now we will
+trigger the synaptic connection between `a` and `b` to adapt via a simple 2-factor Hebbian rule. This Hebbian rule will
+require us to wire the output compartment of `a` to the pre-synaptic compartment of the synapse `Wab` and the output
+compartment of `b` to the post-synaptic compartment of `Wab`. This will wire in the two relevant factors needed to
compute a simple Hebbian adjustment.
We do this specifically as follows:
```python
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
-from ngclearn.utils import JaxProcess
+from ngclearn import Context, MethodProcess
from ngclearn.components import HebbianSynapse, RateCell
-import ngclearn.utils.weight_distribution as dist
+from ngclearn.utils.distribution_generator import DistributionGenerator as dist
## create seeding keys
dkey = random.PRNGKey(1234)
@@ -29,62 +26,57 @@ dkey, *subkeys = random.split(dkey, 6)
## create simple system with only one F-N cell
with Context("Circuit") as circuit:
- a = RateCell(name="a", n_units=1, tau_m=0., act_fx="identity", key=subkeys[0])
- b = RateCell(name="b", n_units=1, tau_m=0., act_fx="identity", key=subkeys[1])
-
- Wab = HebbianSynapse(
- name="Wab", shape=(1, 1), eta=1., sign_value=-1., weight_init=dist.constant(value=1.),
- w_bound=0., key=subkeys[3]
- )
-
- # wire output compartment (rate-coded output zF) of RateCell `a` to input compartment of HebbianSynapse `Wab`
- Wab.inputs << a.zF
- # wire output compartment of HebbianSynapse `Wab` to input compartment (electrical current j) RateCell `b`
- b.j << Wab.outputs
-
- # wire output compartment (rate-coded output zF) of RateCell `a` to presynaptic compartment of HebbianSynapse `Wab`
- Wab.pre << a.zF
- # wire output compartment (rate-coded output zF) of RateCell `b` to postsynaptic compartment of HebbianSynapse `Wab`
- Wab.post << b.zF
-
- ## create and compile core simulation commands
- evolve_process = (JaxProcess()
- >> Wab.evolve)
- circuit.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
-
- advance_process = (JaxProcess()
- >> a.advance_state)
- circuit.wrap_and_add_command(jit(advance_process.pure), name="advance")
-
- reset_process = (JaxProcess()
- >> a.reset)
- circuit.wrap_and_add_command(jit(reset_process.pure), name="reset")
-
- ## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
+ a = RateCell(name="a", n_units=1, tau_m=0., act_fx="identity", key=subkeys[0])
+ b = RateCell(name="b", n_units=1, tau_m=0., act_fx="identity", key=subkeys[1])
+
+ Wab = HebbianSynapse(
+ name="Wab", shape=(1, 1), eta=1., sign_value=-1., weight_init=dist.constant(value=1.),
+ w_bound=0., key=subkeys[3]
+ )
+
+ # wire output compartment (rate-coded output zF) of RateCell `a` to input compartment of HebbianSynapse `Wab`
+ a.zF >> Wab.inputs
+ # wire output compartment of HebbianSynapse `Wab` to input compartment (electrical current j) RateCell `b`
+ Wab.outputs >> b.j
+
+ # wire output compartment (rate-coded output zF) of RateCell `a` to presynaptic compartment of HebbianSynapse `Wab`
+ a.zF >> Wab.pre
+ # wire output compartment (rate-coded output zF) of RateCell `b` to postsynaptic compartment of HebbianSynapse `Wab`
+ b.zF >> Wab.post
+
+ ## create and compile core simulation commands
+    evolve = (MethodProcess("evolve")
+              >> Wab.evolve)
+
+ advance = (MethodProcess("advance")
+ >> a.advance_state)
+
+ reset = (MethodProcess("reset")
+ >> a.reset)
+
+## set up non-compiled utility commands
+def clamp(x):
a.j.set(x)
```
-Now with our simple system above created, we will now run a simple sequence
-of one-dimensional "spike" data through it and evolve the synapse every time
-step like so:
+Now, with our simple system above created, we will run a simple sequence of one-dimensional "spike" data through it
+and evolve the synapse every time step like so:
```python
## run some data through the dynamical system
x_seq = jnp.asarray([[1, 1, 0, 0, 1]], dtype=jnp.float32)
-circuit.reset()
+reset.run()
print("{}: Wab = {}".format(-1, Wab.weights.value))
for ts in range(x_seq.shape[1]):
x_t = jnp.expand_dims(x_seq[0,ts], axis=0) ## get data at time t
- circuit.clamp(x_t)
- circuit.advance(t=ts*1., dt=1.)
- circuit.evolve(t=ts*1., dt=1.)
- print(" {}: input = {} ~> Wab = {}".format(ts, x_t, Wab.weights.value))
+ clamp(x_t)
+ advance.run(t=ts*1., dt=1.)
+ evolve.run(t=ts*1., dt=1.)
+ print(" {}: input = {} ~> Wab = {}".format(ts, x_t, Wab.weights.get()))
```
-Your code should produce the same output (towards the bottom):
+After running `run_lesson3.py`, your code should produce (printed to the console) the same output as below:
```console
-1: Wab = [[1.]]
@@ -95,14 +87,11 @@ Your code should produce the same output (towards the bottom):
4: input = [1.] ~> Wab = [[8.]]
```
-Notice that for every non-spike (a value of `0`), the synaptic value remains
-the same (because the product of a pre-synaptic value of `0` with a post-synaptic
-value of anything -- in this case, also a `0` -- is simply `0`, meaning no
-change will be applied to the synapse). For every spike (a value of `1`), we
-get a synaptic change equal to `dW = input * (Wab * input)`; so for the
-first time-step, the weight will change according to
-`W = W + eta * dW = W + dW` and `dW = 1 * (1 * 1) = 1`, whereas, for the
-second time-step, `W` will be increased by `dW = 1 * (2 * 1) = 2` (yielding a
- new synaptic strength of `W = 4`).
-
-You have now created your first plastic, evolving neuronal system.
+Notice that for every non-spike (a value of `0`), the synaptic value remains the same (because the product of a
+pre-synaptic value of `0` with a post-synaptic value of anything -- in this case, also a `0` -- is simply `0`, meaning
+that no change will be applied to the synapse). For every spike (a value of `1`), we get a synaptic change equal to
+`dW = input * (Wab * input)`; so for the first time-step, the weight will change according to
+`W = W + eta * dW = W + dW` and `dW = 1 * (1 * 1) = 1`, whereas, for the second time-step, `W` will be increased by
+`dW = 1 * (2 * 1) = 2` (yielding a new synaptic strength of `W = 4`).
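+
+If you want to double-check this arithmetic, the weight trajectory above can be reproduced with a few lines of plain
+Python (a minimal sketch of the update rule as described in this lesson, not ngc-learn's internal implementation):
+
+```python
+W, eta = 1., 1.
+for x in [1, 1, 0, 0, 1]:  ## the spike sequence used above
+    dW = x * (W * x)  ## pre * post, where post = Wab * input here
+    W = W + eta * dW
+    print(W)  ## prints 2.0, 4.0, 4.0, 4.0, 8.0
+```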
+
+As per the above, you have now created your first plastic, evolving neuronal system!
diff --git a/docs/tutorials/model_basics/json_modules.md b/docs/tutorials/model_basics/json_modules.md
deleted file mode 100644
index 54f8a81f..00000000
--- a/docs/tutorials/model_basics/json_modules.md
+++ /dev/null
@@ -1,154 +0,0 @@
-# Lesson 2: Configuring with the modules.json File
-
-## Basic Usage:
-
-The basic usage for the `modules.json` file is to provide ngclearn with a list of modules to import and associated
-classes that are needed to build the models it will be loading. If there is a need to use the imported
-modules outside of these cases, use `ngcsimlib.utils.load_attribute` and the loaded
-attribute will be returned.
-
-By default, ngcsimlib, the backend
-dependency of ngc-learn, looks for `json_files/modules.json` in your project path.
-However, this can be changed inside the
-configuration file. In
-the event that this
-file is missing, ngcsimlib will not break but its ability to load saved models will be limited.
-
-## Motivation
-
-The motivation behind the use of `modules.json` versus the registering all the
-various parts of the model at the top of the file is reusability. When all the
-parts have to be registered/imported at the top of every test file, or be placed into specific locations can be limiting
-and slows down development. With a single project wide modules file all loaded models can look there to load components.
-This also allows for components to be saved in humanreadable formats not as a pickled object as we can save and load all
-the relevant class information from the class name and the modules file.
-
-## Structure
-
-A complete schema for the modules file can be found in `modules.schema`
-
-The general structure of the modules file can be thought of as a transformation
-of python import statements to JSON objects. Take the following example:
-
-```python
-from ngclearn.commands import AdvanceState as advance
-```
-
-In this statement we are importing a command from ngcsimlib and aliasing it to the
-word "advance". Now we will transform this into JSON for the modules file. First,
-we take the top level module that we are importing from, in this case
-`ngcsimlib.commands`; this the absolute path to the location of this module. Next,
-we look at the name of what we are importing here: `AdvanceState`. Finally, we
-look at the keyword since this import is being assigned to `advance`. We then
-take these three parts and combine them into the following JSON object:
-
-```json
- {
- "absolute_path": "ngclearn.commands",
- "attributes": [
- {
- "name": "AdvanceState",
- "keywords": [
- "advance"
- ]
- }
- ]
-}
-```
-
-Now there are a few additional things that this JSON formulation of an import
-allows us to do. Primarily, it allows for multiple keywords for a single import
-to be defined. This if we wanted to use `advance` and `adv` all we would do is
-change the keyword line to `"keywords": ["advance", "adv"]`. In addition, we are able
-to specify more than one attribute to import from a single top level module
-such as also importing the evolve command.
-
-```json
- {
- "absolute_path": "ngcsimlib.commands",
- "attributes": [
- {
- "name": "AdvanceState",
- "keywords": [
- "advance",
- "adv"
- ]
- },
- {
- "name": "Evolve"
- }
- ]
-}
-```
-
-Now you might notice above that, when importing the evolve attribute, no
-keywords were given. This means that, in order to add an evolve command to
-the controller, the whole name will need to be given. There is one caveat to
-this scheme though; it is case-insensitive by default, meaning that both
-`Evolve` and `evolve` are valid ways to using this import.
-
-## Example Transformations
-
-Below are some additional examples to help with transitioning from python
-header import statements to JSON configuration.
-
-> Case 1
-> Python:
-> ```python
-> from ngcsimlib.commands import AdvanceState as advance, Evolve, Multiclamp as mClamp
-> ```
-> Json:
-> ```json
-> [
-> {
-> "absolute_path": "ngcsimlib.commands",
-> "attributes": [
-> {
-> "name": "AdvanceState",
-> "keywords": ["advance"]
-> },
-> {
-> "name": "Evolve"
-> },
-> {
-> "name": "Multiclamp",
-> "keywords": "mClamp"
-> }
-> ]
-> }
-> ]
-> ```
-
-> Case 2
-> Python
-> ```python
-> from ngclearn.commands import AdvanceState as advance
-> from ngclearn.operations import summation as summ, overwrite
-> ```
->
-> Json
-> ```json
-> [
-> {
-> "absolute_path": "ngclearn.commands",
-> "attributes": [
-> {
-> "name": "AdvanceState",
-> "keywords": ["advance"]
-> }
-> ]
-> },
-> {
-> "absolute_path": "ngclearn.operations",
-> "attributes": [
-> {
-> "name": "summation",
-> "keywords": ["summ"]
-> },
-> {
-> "name": "overwrite"
-> }
-> ]
-> }
-> ]
-> ```
diff --git a/docs/tutorials/model_basics/model_building.md b/docs/tutorials/model_basics/model_building.md
index e5cff8e6..34c1be48 100755
--- a/docs/tutorials/model_basics/model_building.md
+++ b/docs/tutorials/model_basics/model_building.md
@@ -1,19 +1,19 @@
-# Lesson 3: Building a Model
+# Lesson 2: Building a Model
-In this tutorial, we will build a simple model made up of three components:
-two simple graded cells that are connected by one synaptic cable.
+In this tutorial, we will build a simple model made up of three components: two simple graded cells that are connected
+by a single synaptic cable.
## Instantiating the Dynamical System as a Context
-While building our dynamical system we will set up a Context and then add the three different components to it.
+Create a file named `run_lesson2.py` in which to place the Python code you will write below.
+While building our dynamical system, we will set up a `Context` and then add the three different components to it,
+like so:
```python
from jax import numpy as jnp, random
-from ngclearn import Context
-from ngclearn.utils import JaxProcess
-from ngcsimlib.compilers.process import Process
+from ngclearn import Context, MethodProcess
from ngclearn.components import RateCell, HebbianSynapse
-import ngclearn.utils.weight_distribution as dist
+from ngclearn.utils.distribution_generator import DistributionGenerator as dist
## create seeding keys
dkey = random.PRNGKey(1234)
@@ -23,94 +23,67 @@ dkey, *subkeys = random.split(dkey, 4)
with Context("model") as model:
a = RateCell(name="a", n_units=1, tau_m=0., act_fx="identity", key=subkeys[0])
b = RateCell(name="b", n_units=1, tau_m=20., act_fx="identity", key=subkeys[1])
- Wab = HebbianSynapse(
- name="Wab", shape=(1, 1), weight_init=dist.constant(value=1.), key=subkeys[2]
- )
+ Wab = HebbianSynapse(name="Wab", shape=(1, 1), weight_init=dist.constant(value=1.), key=subkeys[2])
```
-Next, we will want to wire together the three components we have embedded into
-our model, connecting `a` to node `b` through synaptic cable `Wab`. In
-other words, this means that the output compartment of `a` must be wired to the
-input compartment of transformation `Wab` and the output compartment of `Wab`
-must be wired to the input compartment of `b`. In code, this is done as follows:
+Next, we will want to wire together the three components we have embedded into our model, connecting `a` to node `b`
+through synaptic cable `Wab`. In other words, this means that the output compartment of `a` (which, if one checks
+the documentation for `a`, turns out to be `.zF`) must be wired to the input compartment of transformation `Wab`
+(i.e., `.inputs`) and the output compartment of `Wab` (i.e., `.outputs`) must be wired to the input compartment
+of `b` (i.e., `.j`). In code, this is done (within the `Context`-block) as follows:
```python
- ## wire a to w_ab and wire w_ab to b
- Wab.inputs << a.zF
- b.j << Wab.outputs
+ ## wire a to w_ab and wire w_ab to b (a -> Wab -> b)
+ a.zF >> Wab.inputs
+ Wab.outputs >> b.j
```
-Finally, to make our dynamical system do something for each step of simulated
-time, we must append a few basic commands
-(see [Understanding Commands](../foundations/commands.md) to the context.
-The commands we will want, as implied by our JSON configuration that we put
-together at the start of this tutorial, include a `reset` (which will
-initialize the compartments within each node to their resting values,
-i.e., generally zero, if they have them -- this will only end up affecting
-nodes `a` and `b` since a basic synapse component like `Wab` does not have a
-base/resting value), an `advance` (which moves all the nodes one step
-forward in time according to their compartments' ODEs), and `clamp` (which will
-allow us to insert data into particular nodes).
-This is simply done with the use of the following convenience function calls:
-
-
-
+Finally, to make our dynamical system do something for each step of simulated time, we must append a few basic
+processes (see [Understanding Processes](../configuration/processes.md)) to the context.
+The commands that we will (in general) want include a `reset` (which will initialize the compartments within
+each node to their "resting" values, i.e., generally zero, if they have them), an `advance` (which moves all the
+nodes one step forward in time according to their compartments' differential equations/internal dynamics), and
+`clamp` (which will allow us to insert data into particular nodes).
+This is simply done by writing the following next (within the `Context`-block):
```python
## configure desired commands for simulation object
- reset_process = (JaxProcess()
- >> a.reset
- >> Wab.reset
- >> b.reset)
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ reset = (MethodProcess("reset")
+ >> a.reset
+ >> Wab.reset
+ >> b.reset)
- advance_process = (JaxProcess()
- >> a.advance_state
- >> Wab.advance_state
- >> b.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance")
+ advance = (MethodProcess("advance")
+ >> a.advance_state
+ >> Wab.advance_state
+ >> b.advance_state)
- ## set up clamp as a non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- a.j.set(x)
+## set up clamp as a non-compiled utility command (outside the context-block)
+def clamp(x):
+ a.j.set(x) ## injects value/tensor x into compartment .j of component a
```
-## Running the Dynamical System's Controller
+## Running the Dynamical System
-With our simple 3-component dynamical system built, we may now run it on a
-simple sequence of one-dimensional real-valued numbers:
+With our simple 3-component dynamical system built, we may now apply and run it on a simple sequence of
+one-dimensional real-valued numbers:
```python
## run some data through our simple dynamical system
x_seq = jnp.asarray([[1., 2., 3., 4., 5.]], dtype=jnp.float32)
-model.reset()
+reset.run()
for ts in range(x_seq.shape[1]):
x_t = jnp.expand_dims(x_seq[0, ts], axis=0) ## get data at time ts
- model.clamp_data(x_t)
- model.advance(t=ts * 1., dt=1.)
+ clamp(x_t)
+ advance.run(t=ts * 1., dt=1.)
## naively extract simple statistics at time ts and print them to I/O
- a_out = a.zF
- b_out = b.zF
+ a_out = a.zF.get()
+ b_out = b.zF.get()
print(" {}: a.zF = {} ~> b.zF = {}".format(ts, a_out, b_out))
```
-and, assuming you place your code above in a Python script
-(e.g., `run_lesson2.py`), we should obtain output in your terminal as below:
+and, when running your Python script (i.e., `run_lesson2.py`), you should obtain output in your terminal as below:
```console
$ python run_lesson2.py
@@ -121,24 +94,17 @@ $ python run_lesson2.py
4: a.zF = [5.] ~> b.zF = [[0.75]]
```
-The simple 3-component system simulated above merely transforms the input
-sequence into another time-evolving series. For the curious, in your code above,
-you modeled a very simple non-leaky integration of cell `b` injected with some
-value produced by `a` (since `Wab = 1`, the synapses had no effect and merely
-copies the value along). While node `a` is always clamped to a value as per the
-clamp command call we constructed and call above (even though its time constant
-was `tau_m = 0` ms, meaning that it reduces to a stateless "feedforward" cell),
-b had a time constant you set to `tau_m = 20` ms. This means, as can be confirmed
-by inspecting the API for `RateCell`, with your integration time constant
-`dt = 1` ms:
-
-1. at time step `ts = 0`, the value clamped to `a`, i.e., `1`, was multiplied by
- `1/20 = 0.05` and then added `b`'s internal state (which started at the value
- of `0` through the reset command called before the for-loop);
-2. at step `ts = 1`, the value clamped to `a`, i.e., `2`, was multiplied by
- `0.05` (yielding `0.1`) and then added to `b`'s current state -- meaning that
- the new state becomes `0.05 + 0.1 = 0.15`;
-3. at `ts = 2`, a value `3` is clamped to `a`, which is then multiplied by `0.05`
- to yield `0.15` and then added to `b`'s current state -- meaning that the new
- state is `0.15 + 0.15 = 0.3`
- and so on and so forth (`b` acts like a non-decaying recurrently additive state).
+The simple 3-component system simulated above merely transforms the input sequence into another time-evolving series.
+For the curious, in your code above, you modeled a very simple non-leaky integration of cell `b` injected with a
+value produced by `a` (since `Wab = 1`, the synapse had no effect and merely copied the value along). While node
+`a` is always clamped to a value via the clamp command we constructed and called above (even though its
+time constant was `tau_m = 0` ms, meaning that it reduces to a stateless "feedforward" cell), `b` had a time constant
+you set to `tau_m = 20` ms. This means, as can be confirmed by inspecting the API for `RateCell`, that with your
+integration time constant `dt = 1` ms (a short arithmetic sketch of these steps follows the list below):
+
+1. at time step `ts = 0`, the value clamped to `a`, i.e., `1`, was multiplied by `1/20 = 0.05` and then added to
+   `b`'s internal state (which started at the value of `0` due to the reset command called before the for-loop);
+2. at step `ts = 1`, the value clamped to `a`, i.e., `2`, was multiplied by `0.05` (yielding `0.1`) and then added
+ to `b`'s current state -- meaning that the new state becomes `0.05 + 0.1 = 0.15`;
+3. at `ts = 2`, a value `3` is clamped to `a`, which is then multiplied by `0.05` to yield `0.15` and then added to
+ `b`'s current state -- meaning that the new state is `0.15 + 0.15 = 0.3` and so on and so forth (`b` acts like a
+ non-decaying recurrently additive state).
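+
+A quick back-of-the-envelope check of this integration (a sketch of the arithmetic described above, not ngc-learn's
+internal integrator) reproduces the values of `b.zF` printed earlier:
+
+```python
+b_state, dt, tau_m = 0., 1., 20.
+for x in [1., 2., 3., 4., 5.]:  ## the clamped input sequence
+    b_state = b_state + (x / tau_m) * dt  ## one non-leaky Euler integration step
+    print(b_state)  ## ~ 0.05, 0.15, 0.3, 0.5, 0.75
+```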
diff --git a/docs/tutorials/neurocog/adex_cell.md b/docs/tutorials/neurocog/adex_cell.md
index 4a685488..f4ddc79a 100755
--- a/docs/tutorials/neurocog/adex_cell.md
+++ b/docs/tutorials/neurocog/adex_cell.md
@@ -22,9 +22,7 @@ AdEx cell amounts to the following:
from jax import numpy as jnp, random, jit
import numpy as np
-from ngclearn.utils.model_utils import scanner
-from ngcsimlib.context import Context
-from ngclearn.utils import JaxProcess
+from ngclearn import Context, MethodProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.adExCell import AdExCell
@@ -46,20 +44,15 @@ with Context("Model") as model:
intrinsic_mem_thr=-55., v_thr=5., v_rest=-72., v_reset=-75., a=0.1,
b=0.75, v0=v0, w0=w0, integration_type="euler", key=subkeys[0]
)
-
## create and compile core simulation commands
- advance_process = (JaxProcess()
+ advance_process = (MethodProcess("advance_proc")
>> cell.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance")
-
- reset_process = (JaxProcess()
+ reset_process = (MethodProcess("reset_proc")
>> cell.reset)
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
-
- ## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- cell.j.set(x)
+
+## set up non-compiled utility commands
+def clamp(x):
+ cell.j.set(x)
```
In effect, the AdEx two-dimensional differential equation system [1]-[2] offers
@@ -109,19 +102,19 @@ i_app = 19. ## electrical current to inject into AdEx cell
data = jnp.asarray([[i_app]], dtype=jnp.float32)
time_span = []
-model.reset()
+reset_process.run()
t = 0.
for ts in range(T):
x_t = data
## pass in t and dt and run step forward of simulation
- model.clamp(x_t)
- model.advance(t=t, dt=dt)
+ clamp(x_t)
+ advance_process.run(t=t, dt=dt) # run one step of dynamics
t = t + dt
## naively extract simple statistics at time ts and print them to I/O
- v = cell.v.value
- w = cell.w.value
- s = cell.s.value
+ v = cell.v.get()
+ w = cell.w.get()
+ s = cell.s.get()
curr_in.append(data)
mem_rec.append(v)
recov_rec.append(w)
@@ -150,26 +143,27 @@ recov_rec = np.squeeze(np.asarray(recov_rec))
spk_rec = np.squeeze(np.asarray(spk_rec))
# Plot the AdEx cell trajectory
-cell_tag = "RS"
n_plots = 1
fig, ax = plt.subplots(1, n_plots, figsize=(5*n_plots,5))
ax_ptr = ax
-ax_ptr.set(xlabel='Time', ylabel='Voltage (v)',
- title="AdEx ({}) Voltage Dynamics".format(cell_tag))
+ax_ptr.set(
+ xlabel='Time', ylabel='Voltage (v)', title="AdEx Voltage Dynamics"
+)
v = ax_ptr.plot(time_span, mem_rec, color='C0')
ax_ptr.legend([v[0]],['v'])
plt.tight_layout()
-plt.savefig("{0}".format("adex_v_plot.jpg".format(cell_tag.lower())))
+plt.savefig("adex_v_plot.jpg")
fig, ax = plt.subplots(1, n_plots, figsize=(5*n_plots,5))
ax_ptr = ax
-ax_ptr.set(xlabel='Time', ylabel='Recovery (w)',
- title="AdEx ({}) Recovery Dynamics".format(cell_tag))
+ax_ptr.set(
+ xlabel='Time', ylabel='Recovery (w)', title="AdEx Recovery Dynamics"
+)
w = ax_ptr.plot(time_span, recov_rec, color='C1', alpha=.5)
ax_ptr.legend([w[0]],['w'])
plt.tight_layout()
-plt.savefig("{0}".format("adex_w_plot.jpg".format(cell_tag.lower())))
+plt.savefig("adex_w_plot.jpg")
plt.close()
```
@@ -194,27 +188,6 @@ however, one could configure it to use the midpoint method for integration
by setting its argument `integration_type = rk2` in cases where more
accuracy in the dynamics is needed (at the cost of additional computational time).
-## Optional: Setting Up The Components with a JSON Configuration
-
-While you are not required to create a JSON configuration file for ngc-learn,
-to get rid of the warning that ngc-learn will throw at the start of your
-program's execution (indicating that you do not have a configuration set up yet),
-all you need to do is create a sub-directory for your JSON configuration
-inside of your project code's directory, i.e., `json_files/modules.json`.
-Inside the JSON file, you would write the following:
-
-```json
-[
- {"absolute_path": "ngclearn.components",
- "attributes": [
- {"name": "AdExCell"}]
- },
- {"absolute_path": "ngcsimlib.operations",
- "attributes": [
- {"name": "overwrite"}]
- }
-]
-```
## References
diff --git a/docs/tutorials/neurocog/density_modeling.md b/docs/tutorials/neurocog/density_modeling.md
new file mode 100644
index 00000000..4c930167
--- /dev/null
+++ b/docs/tutorials/neurocog/density_modeling.md
@@ -0,0 +1,164 @@
+# Density Modeling and Analysis
+
+NGC-Learn offers some support for density modeling/estimation, which can be particularly useful for analyzing the internal properties of a neuronal model's self-organized cell populations (e.g., how the model's distributed representations might cluster into distinct groups/categories) or for drawing samples from the underlying generative model implied by a particular neuronal structure (e.g., sampling a trained predictive coding generative model).
+Particularly, within `ngclearn.utils.density`, one can find implementations of mixture models -- such as mixtures of Bernoullis, Gaussians, and exponentials -- which might be employed to carry out such tasks. In this small lesson, we will demonstrate how to set up a Gaussian mixture model (GMM), fit it to some synthetic latent code data, and plot out the distribution it learns overlaid over the data samples as well as examine the kinds of patterns one may sample from the learnt GMM.
+
+## Setting Up a Gaussian Mixture Model (GMM)
+
+Let's say you have a two-dimensional dataset of neural code vectors collected from another model you have simulated -- here, we will artificially synthesize this kind of data from an "unobserved" trio of multivariate Gaussians (as was done in the t-SNE tutorial) and pretend that this is a set of collected vector measurements. Furthermore, suppose you decide -- after considering that your data might follow a multi-modal distribution (and reasonably assuming that multivariate Gaussians can capture most of its inherent structure/shape) -- that you want to fit a GMM to these codes so that you can later sample from their underlying multi-modal distribution.
+
+The following Python code will employ ngc-learn's in-built GMM density estimator (and also sets up the synthetic data generator):
+
+```python
+from jax import numpy as jnp, random
+from ngclearn.utils.density.gaussianMixture import GaussianMixture as GMM ## pull out density estimator
+
+def gen_data(dkey, n_samp_per_mode): ## data generator (or proxy stochastic data generating process)
+ scale = 0.3
+ mu1 = jnp.asarray([[2.1, 3.2]]) * scale
+ cov1 = jnp.eye(mu1.shape[1]) * 0.78 * scale * 0.5
+ mu2 = jnp.asarray([[-2.8, 2.0]]) * scale
+ cov2 = jnp.eye(mu2.shape[1]) * 0.52 * scale * 0.5
+ mu3 = jnp.asarray([[1.2, -2.7]]) * scale
+ cov3 = jnp.eye(mu3.shape[1]) * 1.2 * scale * 0.5
+ params = (mu1,cov1 ,mu2,cov2,mu3,cov3)
+
+ dkey, *subkeys = random.split(dkey, 7)
+ samp1 = random.multivariate_normal(subkeys[0], mu1, cov1, shape=(n_samp_per_mode,))
+ samp2 = random.multivariate_normal(subkeys[1], mu2, cov2, shape=(n_samp_per_mode,))
+ samp3 = random.multivariate_normal(subkeys[2], mu3, cov3, shape=(n_samp_per_mode,))
+ X = jnp.concatenate((samp1, samp2, samp3), axis=0)
+ y1 = jnp.ones((n_samp_per_mode, 3)) * jnp.asarray([[1., 0., 0.]])
+ y2 = jnp.ones((n_samp_per_mode, 3)) * jnp.asarray([[0., 1., 0.]])
+ y3 = jnp.ones((n_samp_per_mode, 3)) * jnp.asarray([[0., 0., 1.]])
+ lab = jnp.concatenate((y1, y2, y3), axis=0) ## one-hot codes
+
+ ## shuffle the data
+ ptrs = random.permutation(subkeys[3], X.shape[0])
+ X = X[ptrs, :]
+ lab = lab[ptrs, :]
+
+ return X, lab, params
+
+## set up the GMM density estimator
+key = random.PRNGKey(69)
+dkey, *skey = random.split(key, 3)
+X, y, params = gen_data(key, n_samp_per_mode=200) ## X is your "vector dataset"
+
+n_iter = 100 ## maximum number of iterations to fit GMM to data
+n_components = 3 ## number of mixture components w/in GMM
+model = GMM(K=n_components, max_iter=n_iter, key=skey[0])
+model.init(X) ## initialize the GMM to dataset X
+```
+
+The above will construct a GMM with three components (or latent variables of its own) and be configured to use a maximum of `100` iterations to fit itself to data. Note that the call to `init()` will "shape" the GMM according to the dimensionality of the data and pre-initialize its parameters (i.e., choosing random data vectors to initialize its means).
+
+To fit the GMM itself to your dataset `X`, you will then write the following:
+
+```python
+## estimate GMM parameters over dataset via E-M
+model.fit(X, tol=1e-3, verbose=True) ## set verbose to `False` to silence the fitting process
+```
+
+which should print to I/O something akin to:
+
+```console
+0: Mean-diff = 1.4147894382476807 log(p(X)) = -1706.0753173828125 nats
+1: Mean-diff = 0.14663299918174744 log(p(X)) = -1386.569091796875 nats
+2: Mean-diff = 0.18331432342529297 log(p(X)) = -1359.6962890625 nats
+3: Mean-diff = 0.17693905532360077 log(p(X)) = -1309.736083984375 nats
+4: Mean-diff = 0.1494818776845932 log(p(X)) = -1250.130615234375 nats
+5: Mean-diff = 0.11344392597675323 log(p(X)) = -1221.0008544921875 nats
+6: Mean-diff = 0.07362686842679977 log(p(X)) = -1204.680419921875 nats
+7: Mean-diff = 0.03828870505094528 log(p(X)) = -1192.706298828125 nats
+8: Mean-diff = 0.025705577805638313 log(p(X)) = -1188.51123046875 nats
+9: Mean-diff = 0.021316207945346832 log(p(X)) = -1187.055908203125 nats
+10: Mean-diff = 0.019372563809156418 log(p(X)) = -1186.157470703125 nats
+11: Mean-diff = 0.018868334591388702 log(p(X)) = -1185.443115234375 nats
+...
+
+...
+46: Mean-diff = 0.017377303913235664 log(p(X)) = -1062.2596435546875 nats
+47: Mean-diff = 0.007906327955424786 log(p(X)) = -1060.440185546875 nats
+48: Mean-diff = 0.003615213558077812 log(p(X)) = -1060.09130859375 nats
+49: Mean-diff = 0.0016773870447650552 log(p(X)) = -1060.0233154296875 nats
+50: Mean-diff = 0.0007852672133594751 log(p(X)) = -1060.0093994140625 nats
+Converged after 51 iterations.
+```
+
+In the above instance, notice that our GMM converged early, reaching a good, stable log likelihood in `51` iterations. We can further calculate our final model's log likelihood over the dataset `X` with the following in-built function:
+
+```python
+# Calculate the GMM log likelihood
+_, logPX = model.calc_log_likelihood(X) ## 1st output is log-likelihood per data pattern
+print(f"log[p(X)] = {logPX} nats")
+```
+
+which will print out the following:
+
+```console
+log[p(X)] = -1060.006591796875 nats
+```
+
+(Note that this value agrees with the log-likelihood reached at the final iteration of the fitting trace above; a measurement taken before calling `.fit()` would be substantially lower, closer to the value shown at iteration `0`.)
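+
+To make that comparison concrete, one could bracket the original fitting call with two measurements -- one taken right after `init()` and one taken after `.fit()` -- as in the small sketch below (which reuses only the calls already shown above):
+
+```python
+_, logPX_before = model.calc_log_likelihood(X)  ## measured right after model.init(X), before fitting
+model.fit(X, tol=1e-3, verbose=False)
+_, logPX_after = model.calc_log_likelihood(X)  ## measured after fitting
+print(f"pre-fit log[p(X)] = {logPX_before} nats | post-fit log[p(X)] = {logPX_after} nats")
+```
+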
+Now, to check whether our GMM actually captures the underlying multi-modal distribution of our dataset, we may visualize the final GMM with the following plotting code:
+
+```python
+import matplotlib.pyplot as plt
+x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
+y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
+xx, yy = jnp.meshgrid(jnp.linspace(x_min, x_max, 100), jnp.linspace(y_min, y_max, 100))
+Xspace = jnp.c_[xx.ravel(), yy.ravel()]
+Z, _ = model.calc_log_likelihood(Xspace) # Get log likelihood (LL)
+Z = -Z ## flip sign of LL (to get negative LL)
+Z = Z.reshape(xx.shape)
+
+plt.figure(figsize=(8, 6))
+plt.scatter(X[:, 0], X[:, 1], c="blue", s=10, alpha=0.7, label='Latent Codes')
+plt.contour(xx, yy, Z, levels=jnp.logspace(0, 2, 12), cmap='viridis', alpha=0.8)
+plt.colorbar(label='Negative Log Likelihood')
+
+plt.title('GMM Distribution Plot')
+plt.xlabel('Latent Dimension 1')
+plt.ylabel('Latent Dimension 2')
+plt.legend()
+plt.grid(True)
+plt.savefig("gmm_fit.jpg") #plt.show()
+
+plt.close()
+```
+
+which should produce a plot similar to the one below:
+
+
+
+
+To draw samples from our fitted/learnt GMM, we may next call its in-built synthesizing routine as follows:
+
+```python
+## Examine GMM samples
+Xs = model.sample(n_samples=200 * 3) ## draw 600 samples from fitted GMM
+```
+
+and then visualize the collected batch of samples with the following plotting code:
+
+```python
+
+plt.figure(figsize=(8, 6))
+plt.scatter(Xs[:, 0], Xs[:, 1], c="green", s=10, alpha=0.7, label='Sample Points')
+plt.contour(xx, yy, Z, levels=jnp.logspace(0, 2, 12), cmap='viridis', alpha=0.8)
+plt.colorbar(label='Negative Log-Likelihood')
+plt.title('GMM Samples')
+plt.xlabel('Latent Dimension 1')
+plt.ylabel('Latent Dimension 2')
+plt.grid(True) #plt.show()
+plt.savefig("gmm_samples.jpg")
+
+plt.close()
+```
+
+which will produce a plot similar to the one below:
+
+
+
+Notice that the green-colored data points roughly adhere to the contours of the GMM distribution and look much like the original (blue-colored) dataset we collected. In this example scenario, we see that we can successfully learn the density of our latent code dataset, facilitating some level of downstream distributional analysis and generative model sampling.
diff --git a/docs/tutorials/neurocog/dynamic_synapses.md b/docs/tutorials/neurocog/dynamic_synapses.md
index bc708264..d4e9b902 100644
--- a/docs/tutorials/neurocog/dynamic_synapses.md
+++ b/docs/tutorials/neurocog/dynamic_synapses.md
@@ -3,7 +3,7 @@
In this lesson, we will study dynamic synapses, or synaptic cable components in
ngc-learn that evolve on fast time-scales in response to their pre-synaptic inputs.
These types of chemical synapse components are useful for modeling time-varying
-conductance which ultimately drives eletrical current input into neuronal units
+conductance which ultimately drives electrical current input into neuronal units
(such as spiking cells). Here, we will learn how to build three important types of dynamic synapses in
ngc-learn -- the exponential, the alpha, and the double-exponential synapse -- and visualize
the time-course of their resulting conductances. In addition, we will then
@@ -22,19 +22,16 @@ value matrices we might initially employ (as in synapse components such as the
[DenseSynapse](ngclearn.components.synapses.denseSynapse)).
Building a dynamic synapse can be done by importing the [exponential synapse](ngclearn.components.synapses.exponentialSynapse),
-the [double-exponential synapse](ngclearn.components.synapses.doubleExpSynapse), or the [alpha synapse](ngclearn.components.synapses.alphaSynapse) from ngc-learn's in-built components and setting them up within a model context for easy analysis. Go ahead and create a Python script named `probe_synapses.py` to place
+the [double-exponential synapse](ngclearn.components.synapses.doubleExpSynapse), or the [alpha synapse](ngclearn.components.synapses.alphaSynapse) from ngc-learn's in-built components and setting them up within a model context for easy analysis. Go ahead and create a Python script named `probe_dynamic_synapses.py` to place
the code you will write within.
-For the first part of this lesson, we will import all three dynamic synpapse models and compare their behavior.
+For the first part of this lesson, we will import all three dynamic synapse models and compare their behavior.
This can be done as follows (using the meta-parameters we provide in the code block below to ensure reasonable dynamics):
```python
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
-from ngclearn.components import ExponentialSynapse, AlphaSynapse, DoupleExpSynapse
-
-from ngcsimlib.compilers.process import Process
-from ngcsimlib.context import Context
-import ngclearn.utils.weight_distribution as dist
+from ngclearn import Context, MethodProcess
+from ngclearn.components import ExponentialSynapse, AlphaSynapse, DoubleExpSynapse
+from ngclearn.utils.distribution_generator import DistributionGenerator
dkey = random.PRNGKey(1234) ## creating seeding keys for synapses
@@ -46,29 +43,27 @@ T = 8. # ms ## total duration time
with Context("dual_syn_system") as ctx:
Wexp = ExponentialSynapse( ## exponential dynamic synapse
name="Wexp", shape=(1, 1), tau_decay=3., g_syn_bar=1., syn_rest=0., resist_scale=1.,
- weight_init=dist.constant(value=1.), key=subkeys[0]
+ weight_init=DistributionGenerator.constant(value=1.), key=subkeys[0]
)
Walpha = AlphaSynapse( ## alpha dynamic synapse
name="Walpha", shape=(1, 1), tau_decay=1., g_syn_bar=1., syn_rest=0., resist_scale=1.,
- weight_init=dist.constant(value=1.), key=subkeys[0]
+ weight_init=DistributionGenerator.constant(value=1.), key=subkeys[0]
)
- Wexp2 = DoupleExpSynapse(
+ Wexp2 = DoubleExpSynapse(
name="Wexp2", shape=(1, 1), tau_rise=1., tau_decay=3., g_syn_bar=1., syn_rest=0., resist_scale=1.,
- weight_init=dist.constant(value=1.), key=subkeys[0]
+ weight_init=DistributionGenerator.constant(value=1.), key=subkeys[0]
)
## set up basic simulation process calls
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> Wexp.advance_state
>> Walpha.advance_state
>> Wexp2.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> Wexp.reset
>> Walpha.reset
>> Wexp2.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
```
where we notice in the above we have instantiated three different kinds of chemical synapse components
@@ -90,7 +85,7 @@ $$
$$
where the conductance (for a post-synaptic unit) output of this synapse is driven by a sum over all of its incoming
-pre-synaptic spikes; this ODE means that pre-synaptic spikes are filtered via an expoential kernel (i.e., a low-pass filter).
+pre-synaptic spikes; this ODE means that pre-synaptic spikes are filtered via an exponential kernel (i.e., a low-pass filter).
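+
+As a quick intuition check, the conductance filtering described above can be sketched in a few lines of plain Python
+(a toy Euler integration of the conductance alone, with a made-up spike time -- not ngc-learn's internal integrator):
+
+```python
+tau_decay, g_syn_bar, dt = 3., 1., 0.1  ## decay constant (ms), conductance increment, step size (ms)
+g_syn = 0.
+for step in range(80):  ## simulate 8 ms of time
+    spike = 1. if step == 20 else 0.  ## a single pre-synaptic spike arriving at t = 2 ms
+    g_syn = g_syn + dt * (-g_syn / tau_decay) + g_syn_bar * spike  ## exponential decay plus jump on spike
+```
+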
On the other hand, for the alpha synapse, the dynamics adhere to the following coupled set of ODEs:
$$
@@ -100,7 +95,7 @@ $$
where $h_{\text{syn}}(t)$ is an intermediate variable that operates in service of driving the conductance variable $g_{\text{syn}}(t)$ itself.
The double-exponential (or difference of exponentials) synapse model looks similar to the alpha synapse except that the
-rise and fall/decay of its condutance dynamics are set separately using two different time constants, i.e., $\tau_{\text{rise}}$ and $\tau_{\text{decay}}$,
+rise and fall/decay of its conductance dynamics are set separately using two different time constants, i.e., $\tau_{\text{rise}}$ and $\tau_{\text{decay}}$,
as follows:
$$
@@ -128,7 +123,7 @@ time_span = []
g = []
ga = []
gexp2 = []
-ctx.reset()
+reset_process.run()
Tsteps = int(T/dt) + 1
for t in range(Tsteps):
s_t = jnp.zeros((1, 1))
@@ -136,21 +131,23 @@ for t in range(Tsteps):
s_t = jnp.ones((1, 1))
Wexp.inputs.set(s_t)
Walpha.inputs.set(s_t)
- Wexp.v.set(Wexp.v.value * 0)
+ Wexp.v.set(Wexp.v.get() * 0)
Wexp2.inputs.set(s_t)
- Walpha.v.set(Walpha.v.value * 0)
- Wexp2.v.set(Wexp2.v.value * 0)
- ctx.run(t=t * dt, dt=dt)
-
- print(f"\r g = {Wexp.g_syn.value} ga = {Walpha.g_syn.value} gexp2 = {Wexp2.g_syn.value}", end="")
- g.append(Wexp.g_syn.value)
- ga.append(Walpha.g_syn.value)
+ Walpha.v.set(Walpha.v.get() * 0)
+ Wexp2.v.set(Wexp2.v.get() * 0)
+ advance_process.run(t=t * dt, dt=dt)
+
+ print(f"\r g = {Wexp.g_syn.get()} ga = {Walpha.g_syn.get()} gexp2 = {Wexp2.g_syn.get()}", end="")
+ g.append(Wexp.g_syn.get())
+ ga.append(Walpha.g_syn.get())
+ gexp2.append(Wexp2.g_syn.get())
time_span.append(t) #* dt)
print()
g = jnp.squeeze(jnp.concatenate(g, axis=1))
g = g/jnp.amax(g)
ga = jnp.squeeze(jnp.concatenate(ga, axis=1))
ga = ga/jnp.amax(ga)
+gexp2 = jnp.squeeze(jnp.concatenate(gexp2, axis=1))
gexp2 = gexp2/jnp.amax(gexp2)
```
@@ -195,6 +192,9 @@ ax.grid(which="major")
fig.savefig("alpha_syn.jpg")
plt.close()
+## ---- plot the double-exponential synapse conductance time-course ----
+fig, ax = plt.subplots()
+
gvals = ax.plot(time_span, gexp2, '-', color='tab:blue')
#plt.xticks(time_span, time_labs)
ax.set_xticks(time_ticks, time_labs)
@@ -207,7 +207,7 @@ plt.close()
```
which should produce and save three plots to disk. You can then compare and contrast the plots of the
-expoential, alpha synapse, and double-exponential conductance trajectories:
+exponential, alpha synapse, and double-exponential conductance trajectories:
```{eval-rst}
.. table::
@@ -222,7 +222,7 @@ expoential, alpha synapse, and double-exponential conductance trajectories:
Note that the alpha synapse (right figure) would produce a more realistic fit to recorded synaptic currents (as it attempts to model
the rise and fall of current in a less simplified manner) at the cost of extra compute, given it uses two ODEs to
-emulate condutance, as opposed to the faster yet less-biophysically-realistic exponential synapse (left figure).
+emulate conductance, as opposed to the faster yet less-biophysically-realistic exponential synapse (left figure).
## Excitatory-Inhibitory Driven Dynamics
@@ -243,13 +243,10 @@ We will specifically model the excitatory and inhibitory conductance changes usi
```python
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
+from ngclearn import Context, MethodProcess
+from ngclearn.operations import Summation
from ngclearn.components import ExponentialSynapse, PoissonCell, LIFCell
-from ngclearn.operations import summation
-
-from ngcsimlib.compilers.process import Process
-from ngcsimlib.context import Context
-import ngclearn.utils.weight_distribution as dist
+from ngclearn.utils.distribution_generator import DistributionGenerator
## create seeding keys
dkey = random.PRNGKey(1234)
@@ -287,39 +284,36 @@ with Context("ei_snn") as ctx:
pre_inh = PoissonCell("pre_inh", n_units=n_inh, target_freq=inh_freq, key=subkeys[1]) ## pre-syn inhibitory group
Wexc = ExponentialSynapse( ## dynamic synapse between excitatory group and LIF
name="Wexc", shape=(n_exc,1), tau_decay=tau_syn_exc, g_syn_bar=g_e_bar, syn_rest=E_rest_exc, resist_scale=1./g_L,
- weight_init=dist.constant(value=1.), key=subkeys[2]
+ weight_init=DistributionGenerator.constant(value=1.), key=subkeys[2]
)
Winh = ExponentialSynapse( ## dynamic synapse between inhibitory group and LIF
name="Winh", shape=(n_inh, 1), tau_decay=tau_syn_inh, g_syn_bar=g_i_bar, syn_rest=E_rest_inh, resist_scale=1./g_L,
- weight_init=dist.constant(value=1.), key=subkeys[2]
+ weight_init=DistributionGenerator.constant(value=1.), key=subkeys[2]
)
post_exc = LIFCell( ## post-syn LIF cell
"post_exc", n_units=1, tau_m=tau_m, resist_m=1., thr=v_thr, v_rest=v_rest, conduct_leak=1., v_reset=-75.,
tau_theta=0., theta_plus=0., refract_time=2., key=subkeys[3]
)
- Wexc.inputs << pre_exc.outputs
- Winh.inputs << pre_inh.outputs
- Wexc.v << post_exc.v ## couple voltage to exc synapse
- Winh.v << post_exc.v ## couple voltage to inh synapse
- post_exc.j << summation(Wexc.i_syn, Winh.i_syn) ## sum together excitatory & inhibitory pressures
+ pre_exc.outputs >> Wexc.inputs
+ pre_inh.outputs >> Winh.inputs
+ post_exc.v >> Wexc.v ## couple voltage to exc synapse
+ post_exc.v >> Winh.v ## couple voltage to inh synapse
+ Summation(Wexc.i_syn, Winh.i_syn) >> post_exc.j ## sum together excitatory & inhibitory pressures
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> pre_exc.advance_state
>> pre_inh.advance_state
>> Wexc.advance_state
>> Winh.advance_state
>> post_exc.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> pre_exc.reset
>> pre_inh.reset
>> Wexc.reset
>> Winh.reset
>> post_exc.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
```
### Examining the Simple Spiking Circuit's Behavior
@@ -331,18 +325,18 @@ volts = []
time_span = []
spikes = []
-ctx.reset()
+reset_process.run()
pre_exc.inputs.set(jnp.ones((1, n_exc)))
pre_inh.inputs.set(jnp.ones((1, n_inh)))
-post_exc.v.set(post_exc.v.value * 0 - 65.) ## initial condition for LIF is -65 mV
-volts.append(post_exc.v.value)
+post_exc.v.set(post_exc.v.get() * 0 - 65.) ## initial condition for LIF is -65 mV
+volts.append(post_exc.v.get())
time_span.append(0.)
Tsteps = int(T/dt) + 1
for t in range(1, Tsteps):
- ctx.run(t=t * dt, dt=dt)
- print(f"\r v {post_exc.v.value}", end="")
- volts.append(post_exc.v.value)
- spikes.append(post_exc.s.value)
+ advance_process.run(t=t * dt, dt=dt)
+ print(f"\r v {post_exc.v.get()}", end="")
+ volts.append(post_exc.v.get())
+ spikes.append(post_exc.s.get())
time_span.append(t) #* dt)
print()
volts = jnp.squeeze(jnp.concatenate(volts, axis=1))
@@ -384,9 +378,7 @@ ax.grid()
fig.savefig("ei_circuit_dynamics.jpg")
```
-which should produce a figure depicting dynamics similar to the one below. Black tick
-marks indicate post-synaptic pulses whereas the horizontal dashed blue shows the LIF unit's
-voltage threshold.
+which should produce a figure depicting dynamics similar to the one below. Black tick marks indicate post-synaptic pulses whereas the horizontal dashed blue line shows the LIF unit's voltage threshold.
```{eval-rst}
diff --git a/docs/tutorials/neurocog/error_cell.md b/docs/tutorials/neurocog/error_cell.md
index 04368d5d..b7fbdb11 100644
--- a/docs/tutorials/neurocog/error_cell.md
+++ b/docs/tutorials/neurocog/error_cell.md
@@ -60,8 +60,8 @@ The code you would write amounts to the below:
```python
from jax import numpy as jnp, jit
-from ngcsimlib.context import Context
-from ngclearn.utils import JaxProcess
+
+from ngclearn import Context, MethodProcess
## import model-specific mechanisms
from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell
@@ -71,32 +71,29 @@ T = 5 ## number time steps to simulate
with Context("Model") as model:
cell = GaussianErrorCell("z0", n_units=3)
- advance_process = (JaxProcess()
+ advance_process = (MethodProcess("advance_proc")
>> cell.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance")
-
- reset_process = (JaxProcess()
+ reset_process = (MethodProcess("reset_proc")
>> cell.reset)
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
-
- @Context.dynamicCommand
- def clamp(x, y):
- ## error cells have two key input compartments; a "mu" and a "target"
- cell.mu.set(x)
- cell.target.set(y)
+## set up non-compiled utility commands
+def clamp(x, y):
+ ## error cells have two key input compartments; a "mu" and a "target"
+ cell.mu.set(x)
+ cell.target.set(y)
+
guess = jnp.asarray([[-1., 1., 1.]], jnp.float32) ## the produced guess or prediction
answer = jnp.asarray([[1., -1., 1.]], jnp.float32) ## what we wish the guess had been
-model.reset()
+reset_process.run()
for ts in range(T):
- model.clamp(guess, answer)
- model.advance(t=ts * 1., dt=dt)
+ clamp(guess, answer)
+ advance_process.run(t=ts * 1., dt=dt)
## extract compartment values of interest
- dmu = cell.dmu.value
- dtarget = cell.dtarget.value
- loss = cell.L.value
+ dmu = cell.dmu.get()
+ dtarget = cell.dtarget.get()
+ loss = cell.L.get()
## print compartment values to I/O
print("{} | dmu: {} dtarget: {} loss: {} ".format(ts, dmu, dtarget, loss))
```
diff --git a/docs/tutorials/neurocog/fitzhugh_nagumo_cell.md b/docs/tutorials/neurocog/fitzhugh_nagumo_cell.md
index 6575fc06..5faed24b 100644
--- a/docs/tutorials/neurocog/fitzhugh_nagumo_cell.md
+++ b/docs/tutorials/neurocog/fitzhugh_nagumo_cell.md
@@ -17,8 +17,7 @@ single component system made up of the Fitzhugh-Nagumo (`F-N`) cell.
from jax import numpy as jnp, random, jit
import numpy as np
-from ngcsimlib.context import Context
-from ngclearn.utils import JaxProcess
+from ngclearn import Context, MethodProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell
@@ -36,46 +35,26 @@ w0 = -0.16983366 ## initial recovery value (for reset condition)
## create simple system with only one F-N cell
with Context("Model") as model:
- cell = FitzhughNagumoCell("z0", n_units=1, tau_w=tau_w, alpha=alpha, beta=beta,
- gamma=gamma, v0=v0, w0=w0, integration_type="euler")
+ cell = FitzhughNagumoCell(
+ "z0", n_units=1, tau_w=tau_w, alpha=alpha, beta=beta, gamma=gamma, v0=v0, w0=w0, integration_type="euler"
+ )
## create and compile core simulation commands
- advance_process = (JaxProcess()
+ advance_process = (MethodProcess("advance")
>> cell.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance")
- reset_process = (JaxProcess()
+ reset_process = (MethodProcess("reset")
>> cell.reset)
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
- ## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- cell.j.set(x)
+## set up non-compiled utility commands
+def clamp(x):
+ cell.j.set(x)
```
-In effect, the FitzHugh–Nagumo `F-N` two-dimensional differential
-equation system (developed by [1] and [2]) is
-a useful simplification of the more intricate Hodgkin–Huxley (H-H) squid axon
-model, attempting to extract some of the benefits of its more detailed modeling
-of the spiking cellular activation and deactivation dynamics (specifically
-attempting to isolate the properties related to sodium/potassium ion flow
-from cellular properties of excitation and propagation). Notably, the `F-N`
-cell models membrane potential `v` with a cubic function (which facilitates
-self-excitation through positive feedback) in tandem with a recovery variable `w`
-that provides a slower form of negative feedback. The linear dynamics that govern
-`w` are controlled by (dimensionless) coefficients `alpha` and `beta`, which
-control its shift and scale, respectively (another factor `gamma` is introduced
-in our implementation, which divides the cubic term in the voltage dynamics, but
-generally this can usually be set to either a value of `1` or `3` as in [1]).
-The value `tau_w` controls the time constant for the recovery variable (and,
-technically, ngc-learn implements `tau_m` to control the membrane potential,
-but this is default set to `1` since [1] and [2] typically only use a time
-constant for the recovery variable).
-
-The initial conditions for the voltage (i.e., `v0`) and the recovery (i.e., `w0`)
-have been set to particular interesting values above for the demonstration
-purposes of this tutorial but, by default, are `0` in the `F-N` cell component.
+In effect, the FitzHugh–Nagumo `F-N` two-dimensional differential equation system (developed by [1] and [2]) is a useful simplification of the more intricate Hodgkin–Huxley (H-H) squid axon model, attempting to extract some of the benefits of its more detailed modeling of the spiking cellular activation and deactivation dynamics (specifically attempting to isolate the properties related to sodium/potassium ion flow from cellular properties of excitation and propagation). Notably, the `F-N` cell models membrane potential `v` with a cubic function (which facilitates self-excitation through positive feedback) in tandem with a recovery variable `w` that provides a slower form of negative feedback. The linear dynamics that govern `w` are controlled by (dimensionless) coefficients `alpha` and `beta`, which control its shift and scale, respectively (another factor `gamma` is introduced in our implementation, which divides the cubic term in the voltage dynamics, but generally this can usually be set to either a value of `1` or `3` as in [1]).
+The value `tau_w` controls the time constant for the recovery variable (and, technically, ngc-learn implements `tau_m` to control the membrane potential, but this is set to `1` by default since [1] and [2] typically only use a time constant for the recovery variable).
+
+The initial conditions for the voltage (i.e., `v0`) and the recovery (i.e., `w0`) have been set to particular interesting values above for the demonstration purposes of this tutorial but, by default, are `0` in the `F-N` cell component.
Formally, the core dynamics of the `F-N` can be written out as follows:
@@ -84,24 +63,13 @@ $$
\tau_w \frac{\partial \mathbf{w}_t}{\partial t} &= \mathbf{v}_t + a - b\mathbf{w}_t
$$
-where $a$ and $b$ are factors that drive the recovery variable's dynamics
-(shift and scaling, respectively), $R$ is the membrane resistance, $\tau_m$ is the
-membrane time constant, and $\tau_w$ is the recovery time constant ($g$ is a
-dividing constant meant to dampen the effects of the cubic term, but is generally
-set to $g = 1$ to adhere to [1] and [2])
+where $a$ and $b$ are factors that drive the recovery variable's dynamics (shift and scaling, respectively), $R$ is the membrane resistance, $\tau_m$ is the membrane time constant, and $\tau_w$ is the recovery time constant ($g$ is a dividing constant meant to dampen the effects of the cubic term, but is generally set to $g = 1$ to adhere to [1] and [2])
### Simulating a FitzHugh–Nagumo Neuronal Cell
-Given that we have a single-cell dynamical system set up as above, we can next
-write some code for visualizing how the `F-N` node's membrane potential and
-coupled recovery variable evolve with time (specifically over a period of about
-`200` milliseconds). We will, much as we did with the leaky integrators in
-prior tutorials, inject an electrical current `j` into the `F-N` cell (this time
-just a constant current value of `0.23` amperes) and observe how the cell
-produces action potentials.
-Specifically, we can plot the neuron's voltage `v` and recovery variable `w`
-as follows:
+Given that we have a single-cell dynamical system set up as above, we can next write some code for visualizing how the `F-N` node's membrane potential and coupled recovery variable evolve with time (specifically over a period of about `200` milliseconds). We will, much as we did with the leaky integrators in prior tutorials, inject an electrical current `j` into the `F-N` cell (this time just a constant current value of `0.23` amperes) and observe how the cell produces action potentials.
+Specifically, we can plot the neuron's voltage `v` and recovery variable `w` as follows:
```python
curr_in = []
@@ -119,26 +87,26 @@ time_span = np.linspace(0, 200, num=T)
dt = time_span[1] - time_span[0] # ~ 0.13342228152101404 ms
time_span = []
-model.reset()
+reset_process.run()
t = 0.
for ts in range(T):
x_t = data
## pass in t and dt and run step forward of simulation
- model.clamp(x_t)
- model.advance(t=t, dt=dt)
+ clamp(x_t)
+ advance_process.run(t=t, dt=dt)
t = t + dt
## naively extract simple statistics at time ts and print them to I/O
- v = cell.v.value
- w = cell.w.value
- s = cell.s.value
+ v = cell.v.get()
+ w = cell.w.get()
+ s = cell.s.get()
curr_in.append(data)
mem_rec.append(v)
recov_rec.append(w)
spk_rec.append(s)
## print stats to I/O (overriding previous print-outs to reduce clutter)
print("\r {}: s {} ; v {} ; w {}".format(ts, s, v, w), end="")
- time_span.append((ts)*dt)
+ time_span.append(ts * dt)
print()
import matplotlib #.pyplot as plt
@@ -169,38 +137,12 @@ plt.tight_layout()
plt.savefig("{0}".format("fncell_plot.jpg"))
```
-You should get a plot that depicts the evolution of the voltage and recovery,
-i.e., saved as `fncell_plot.jpg` locally to disk, like the one below:
+You should get a plot that depicts the evolution of the voltage and recovery, i.e., saved as `fncell_plot.jpg` locally to disk, like the one below:
-A useful note is that the `F-N` above used Euler integration to step through its
-dynamics (this is the default/base routine for all cell components in ngc-learn);
-however, one could configure it to use the midpoint method for integration
-by setting its argument `integration_type = rk2` in cases where more
-accuracy in the dynamics is needed (at the cost of additional computational time).
-
-## Optional: Setting Up The Components with a JSON Configuration
-
-While you are not required to create a JSON configuration file for ngc-learn,
-to get rid of the warning that ngc-learn will throw at the start of your
-program's execution (indicating that you do not have a configuration set up yet),
-all you need to do is create a sub-directory for your JSON configuration
-inside of your project code's directory, i.e., `json_files/modules.json`.
-Inside the JSON file, you would write the following:
-
-```json
-[
- {"absolute_path": "ngclearn.components",
- "attributes": [
- {"name": "FitzHughNagumoCell"}]
- },
- {"absolute_path": "ngcsimlib.operations",
- "attributes": [
- {"name": "overwrite"}]
- }
-]
-```
+A useful note is that the `F-N` above used Euler integration to step through its dynamics (this is the default/base routine for all cell components in ngc-learn); however, one could configure it to use the midpoint method for integration by setting its argument `integration_type = rk2` in cases where more accuracy in the dynamics is needed (at the cost of additional computational time).
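+
+For instance, the only change needed is the `integration_type` argument of the cell's constructor (the same setup as above, shown here merely as a sketch of the variation):
+
+```python
+    cell = FitzhughNagumoCell(
+        "z0", n_units=1, tau_w=tau_w, alpha=alpha, beta=beta, gamma=gamma, v0=v0, w0=w0, integration_type="rk2"
+    )
+```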
+
## References
diff --git a/docs/tutorials/neurocog/hebbian.md b/docs/tutorials/neurocog/hebbian.md
index 8e67754c..6c6afbed 100644
--- a/docs/tutorials/neurocog/hebbian.md
+++ b/docs/tutorials/neurocog/hebbian.md
@@ -21,30 +21,29 @@ Specifically, we will zoom in on two particular code snippets from
below:
```python
-Wab = HebbianSynapse(name="Wab", shape=(1, 1), eta=1., signVal=-1.,
- wInit=("constant", 1., None), w_bound=0., key=subkeys[3])
+Wab = HebbianSynapse(
+    name="Wab", shape=(1, 1), eta=1., sign_value=-1., weight_init=dist.constant(value=1.), w_bound=0., key=subkeys[3]
+)
# wire output compartment (rate-coded output zF) of RateCell `a` to input compartment of HebbianSynapse `Wab`
-Wab.inputs << a.zF
+a.zF >> Wab.inputs
# wire output compartment of HebbianSynapse `Wab` to input compartment (electrical current j) RateCell `b`
-b.j << Wab.outputs
+Wab.outputs >> b.j
# wire output compartment (rate-coded output zF) of RateCell `a` to presynaptic compartment of HebbianSynapse `Wab`
-Wab.pre << a.zF
+a.zF >> Wab.pre
# wire output compartment (rate-coded output zF) of RateCell `b` to postsynaptic compartment of HebbianSynapse `Wab`
-Wab.post << b.zF
+b.zF >> Wab.post
```
as well as (a bit later in the model construction code):
```python
-evolve_process = (JaxProcess()
-                  >> a.evolve)
+evolve_process = (MethodProcess("evolve")
+                  >> Wab.evolve)
-circuit.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
-advance_process = (JaxProcess()
+advance_process = (MethodProcess("advance")
>> a.advance_state)
-circuit.wrap_and_add_command(jit(advance_process.pure), name="advance")
```
Notice that beyond wiring component `a`'s values into the synapse `Wab`'s input compartment
@@ -54,7 +53,7 @@ post-synaptic compartment `Wab.post`. These compartments are specifically
used in `Wab`'s `evolve` call and are not strictly required to be exactly
the same as its input and output compartments. Note that, if one wanted `pre`
and `post` to be exactly identical to `inputs` and `outputs`, one would simply need
-to write `Wab.pre << Wab.inputs` and `Wab.post << Wab.outputs` in place
+to write `Wab.inputs >> Wab.pre` and `Wab.outputs >> Wab.post` in place
of the pre- and post-synaptic compartment calls above.
The above snippets highlight two key aspects of functionality that a synapse
diff --git a/docs/tutorials/neurocog/hodgkin_huxley_cell.md b/docs/tutorials/neurocog/hodgkin_huxley_cell.md
index e055b5c5..56a17cba 100755
--- a/docs/tutorials/neurocog/hodgkin_huxley_cell.md
+++ b/docs/tutorials/neurocog/hodgkin_huxley_cell.md
@@ -1,16 +1,12 @@
# Lecture 2E: The Hodgkin-Huxley Cell
-In this tutorial, we will study/setup one of the most important biophysical
-neuronal models in computational neuroscience -- the Hodgkin-Huxley (H-H) spiking
-cell model.
+In this tutorial, we will study/set up one of the most important and sophisticated biophysical neuronal models in computational neuroscience -- the Hodgkin-Huxley (H-H) spiking cell model.
## Using and Probing the H-H Cell
-Go ahead and make a new folder for this study and create a Python script,
-i.e., `run_hhcell.py`, to write your code for this part of the tutorial.
+Go ahead and make a new folder for this study and create a Python script, i.e., `run_hhcell.py`, to write your code for this part of the tutorial.
-Now let's set up the controller for this lesson's simulation and construct a
-single component system made up of an H-H cell.
+Now let's set up the controller for this lesson's simulation and construct a single component system made up of an H-H cell.
### Instantiating the H-H Neuronal Cell
@@ -22,9 +18,7 @@ H-H cell amounts to the following:
from jax import numpy as jnp, random, jit
import numpy as np
-from ngclearn.utils.model_utils import scanner
-from ngcsimlib.context import Context
-from ngclearn.utils import JaxProcess
+from ngclearn import Context, MethodProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell
@@ -52,18 +46,15 @@ with Context("Model") as model:
)
## create and compile core simulation commands
- advance_process = (JaxProcess()
+ advance_process = (MethodProcess("advance")
>> cell.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance")
- reset_process = (JaxProcess()
+ reset_process = (MethodProcess("reset")
>> cell.reset)
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
- ## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- cell.j.set(x)
+## set up non-compiled utility commands
+def clamp(x):
+ cell.j.set(x)
```
Notably, the H-H model is a four-dimensional differential equation system, invented in 1952
@@ -86,15 +77,12 @@ $$
\frac{\partial \mathbf{h}_t}{\partial t} &= \alpha_h(\mathbf{v}_t) * (1 - \mathbf{h}_t) - \beta_h(\mathbf{v}_t) * \mathbf{h}_t
$$
-where we observe that the above four-dimensional set of dynamics is composed of nonlinear ODEs. Notice that, in each gate or channel probability ODE, there are two generator functions (each of which is a function of the membrane potential $\mathbf{v}_t$) that produces the necessary dynamic coefficients at time $t$; $\alpha_x(\mathbf{v}_t)$ and $\beta_x(\mathbf{v}_t)$ produce different biopphysical weighting values depending on which channel $x = \{n, m, h\}$ they are related to.
+where we observe that the above four-dimensional set of dynamics is composed of nonlinear ODEs. Notice that, in each gate or channel probability ODE, there are two generator functions (each of which is a function of the membrane potential $\mathbf{v}_t$) that produce the necessary dynamic coefficients at time $t$; $\alpha_x(\mathbf{v}_t)$ and $\beta_x(\mathbf{v}_t)$ produce different biophysical weighting values depending on which channel $x = \{n, m, h\}$ they are related to.
Note that, in ngc-learn's implementation of the H-H cell model, most of the core coefficients have been generally set according to Hodgkin and Huxley's 1952 work but can be configured by the experimenter to obtain different kinds of behavior/dynamics.
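+
+For concreteness, below is a small sketch of what one such generator-function pair looks like, using the classic rate functions from Hodgkin and Huxley's 1952 paper for the potassium activation gate $\mathbf{n}$ (written in the shifted-voltage convention of the original work); ngc-learn's internal parameterization may differ, so treat this purely as an illustration of the form of $\alpha_x(\mathbf{v}_t)$ and $\beta_x(\mathbf{v}_t)$:
+
+```python
+from jax import numpy as jnp
+
+def alpha_n(v):  ## opening rate for the potassium activation gate n (classic 1952 form)
+    return 0.01 * (10. - v) / (jnp.exp((10. - v) / 10.) - 1.)
+
+def beta_n(v):  ## closing rate for the potassium activation gate n (classic 1952 form)
+    return 0.125 * jnp.exp(-v / 80.)
+
+def dn_dt(n, v):  ## the n-gate ODE written out above
+    return alpha_n(v) * (1. - n) - beta_n(v) * n
+```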
### Simulating the H-H Neuronal Cell
-To see how the H-H cell works, we next write some code for visualizing how
-the node's membrane potential and core related gates/channels evolve with time
-(over a period of about `200` milliseconds). We will inject a square input pulse current
-into our H-H cell (specifically into its `j` compartment) and observe how the cell behaves in response.
+To see how the H-H cell works, we next write some code for visualizing how the node's membrane potential and core related gates/channels evolve with time (over a period of about `200` milliseconds). We will inject a square input pulse current into our H-H cell (specifically into its `j` compartment) and observe how the cell behaves in response.
Specifically, we simulate the injection of this kind of current via the code below:
```python
@@ -110,17 +98,17 @@ v = []
n = []
m = []
h = []
-model.reset()
+reset_process.run()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- model.clamp(x_t)
- model.run(t=ts * dt, dt=dt)
- outs.append(a.s.value)
- n.append(cell.n.value[0, 0])
- m.append(cell.m.value[0, 0])
- h.append(cell.h.value[0, 0])
- v.append(cell.v.value[0, 0])
- print(f"\r {ts} v = {cell.v.value}", end="")
+ clamp(x_t)
+ advance_process.run(t=ts * dt, dt=dt)
+ outs.append(cell.s.get())
+ n.append(cell.n.get()[0, 0])
+ m.append(cell.m.get()[0, 0])
+ h.append(cell.h.get()[0, 0])
+ v.append(cell.v.get()[0, 0])
+ print(f"\r {ts} v = {cell.v.get()}", end="")
time_span.append(ts*dt)
outs = jnp.concatenate(outs, axis=1)
v = jnp.array(v)
@@ -128,8 +116,7 @@ time_span = jnp.array(time_span)
outs = jnp.squeeze(outs)
```
-and we can plot the dynamics of the neuron's voltage `v` and its three gate/channel
-variables, `h`, `m`, and `n`, with the following:
+and we can plot the dynamics of the neuron's voltage `v` and its three gate/channel variables, `h`, `m`, and `n`, with the following:
```python
import matplotlib.pyplot as plt
@@ -159,9 +146,7 @@ plt.savefig("{0}".format("hh_plot.jpg"))
plt.close()
```
-You should get a compound plot that depict the evolution of the H-H cell's voltage
-and channel/gate variables, i.e., saved as `hh_plot.jpg` locally to
-disk, like the one below:
+You should get a compound plot that depicts the evolution of the H-H cell's voltage and channel/gate variables, i.e., saved as `hh_plot.jpg` locally to disk, like the one below:
```{eval-rst}
.. table::
@@ -174,38 +159,11 @@ disk, like the one below:
+--------------------------------------------------------+
```
-A useful note is that the H-H cell above used Euler integration to step through its
-dynamics (this is the default/base routine for all cell components in ngc-learn).
-However, one could configure the cell to use the midpoint method for integration
-by setting its argument `integration_type = rk2` or the Runge-Kutta fourth-order
-routine via `integration_type=rk4` for cases where, at the cost of increased
-compute time, more accurate dynamics are possible.
-
-## Optional: Setting Up The Components with a JSON Configuration
-
-While you are not required to create a JSON configuration file for ngc-learn,
-to get rid of the warning that ngc-learn will throw at the start of your
-program's execution (indicating that you do not have a configuration set up yet),
-all you need to do is create a sub-directory for your JSON configuration
-inside of your project code's directory, i.e., `json_files/modules.json`.
-Inside the JSON file, you would write the following:
-
-```json
-[
- {"absolute_path": "ngclearn.components",
- "attributes": [
- {"name": "HodgkinHuxleyCell"}]
- },
- {"absolute_path": "ngcsimlib.operations",
- "attributes": [
- {"name": "overwrite"}]
- }
-]
-```
+A useful note is that the H-H cell above used Euler integration to step through its dynamics (this is the default/base routine for all cell components in ngc-learn).
+However, one could configure the cell to use the midpoint method for integration by setting its argument `integration_type="rk2"` or the Runge-Kutta fourth-order routine via `integration_type="rk4"` for cases where, at the cost of increased compute time, more accurate dynamics are possible.
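+
+As a quick (hypothetical) sketch of what this looks like, one could re-instantiate the cell from the snippet above with the alternative integration flag passed in (the other constructor arguments are assumed to match whatever you used earlier):
+
+```python
+## switch the H-H cell over to fourth-order Runge-Kutta integration
+cell = HodgkinHuxleyCell("z0", n_units=1, integration_type="rk4", key=subkeys[0])
+```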
+
## References
-[1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description
-of membrane current and its application to conduction and excitation in nerve."
-The Journal of physiology 117.4 (1952): 500.
+[1] Hodgkin, Alan L., and Andrew F. Huxley. "A quantitative description of membrane current and its application to conduction and excitation in nerve." The Journal of physiology 117.4 (1952): 500.
diff --git a/docs/tutorials/neurocog/index.rst b/docs/tutorials/neurocog/index.rst
index 326591c2..b702a535 100644
--- a/docs/tutorials/neurocog/index.rst
+++ b/docs/tutorials/neurocog/index.rst
@@ -5,31 +5,20 @@
Neurocognitive Modeling Lessons
===============================
-A central motivation for using ngc-learn is to flexibly build computational
-models of neuronal information processing, dynamics, and credit
-assignment (as well as design one's own custom instantiations of their
-mathematical formulations and ideas). In this set of tutorials, we will go
-through the central basics of using ngc-learn's in-built biophysical components,
-also called "cells" and "synapses", to craft and simulate adaptive neural systems
-and biophysical computational models.
+A central motivation for using ngc-learn is to flexibly build computational models of neuronal information processing,
+dynamics, and credit assignment (as well as design custom instantiations of one's own mathematical formulations and
+ideas). In this set of tutorials, we will go through the central basics of using ngc-learn's in-built biophysical
+components, also called "cells" and "synapses", to craft and simulate adaptive neural systems and biophysical
+computational models.
-Usefully, ngc-learn starts with a collection of cells -- those that are partitioned
-into those that are graded / real-valued (`ngclearn.components.neurons.graded`)
-and those that spike (`ngclearn.components.neurons.spiking`). In addition,
-ngc-learn supports another collection called synapses -- generally, those that
-adapt (or "learn") with biological credit assignment building blocks
-(such as those in `ngclearn.components.synapses.hebbian`) such as
-spike-timing-dependent plasticity and multi-factor rules. With the in-built,
-standard cells and synapses in these two
-core collections, you can readily construct a wide variety of models, recovering
-many classical ones previously proposed in computational neuroscience
-and brain-inspired computing researach (many of these kinds of models are available
-for external download in the `Model Museum `_).
-
-While the reader is free to jump into any one self-contained tutorial in any
-order based on their needs, we organize, within each topic, the lessons starting
-from more basic, foundational modeling modules and library tools and sequentially
-work towards more advanced concepts.
+Usefully, ngc-learn starts with a collection of cells -- those that are partitioned into those that are graded /
+real-valued (`ngclearn.components.neurons.graded`) and those that spike (`ngclearn.components.neurons.spiking`). In
+addition, ngc-learn supports another collection called synapses -- generally, those that adapt (or "learn") with
+biological credit assignment building blocks (such as those in `ngclearn.components.synapses.hebbian`) such as
+spike-timing-dependent plasticity and multi-factor rules. With the in-built, standard cells and synapses in these two
+core collections, you can readily construct a wide variety of models, recovering many classical ones previously
+proposed in computational neuroscience and brain-inspired computing research (many of these kinds of models are
+available for external download in the `Model Museum `_).
.. toctree::
:maxdepth: 1
@@ -73,3 +62,4 @@ work towards more advanced concepts.
plotting
metrics
integration
+ density_modeling
diff --git a/docs/tutorials/neurocog/input_cells.md b/docs/tutorials/neurocog/input_cells.md
index 1a58adac..c39c6ca7 100644
--- a/docs/tutorials/neurocog/input_cells.md
+++ b/docs/tutorials/neurocog/input_cells.md
@@ -39,8 +39,7 @@ spike train over $100$ steps in time as follows:
```python
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
-from ngclearn.utils import JaxProcess
+from ngclearn import Context, MethodProcess
from ngclearn.utils.viz.raster import create_raster_plot
## import model-specific mechanisms
@@ -56,27 +55,24 @@ T = 100 ## number time steps to simulate
with Context("Model") as model:
cell = BernoulliCell("z0", n_units=10, key=subkeys[0])
- advance_process = (JaxProcess()
+ advance_process = (MethodProcess("advance_proc")
>> cell.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance")
- reset_process = (JaxProcess()
+ reset_process = (MethodProcess("reset_proc")
>> cell.reset)
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
-
- @Context.dynamicCommand
- def clamp(x):
- cell.inputs.set(x)
+def clamp(x):
+ cell.inputs.set(x)
+
probs = jnp.asarray([[0.8, 0.2, 0., 0.55, 0.9, 0, 0.15, 0., 0.6, 0.77]], dtype=jnp.float32)
spikes = []
-model.reset()
+reset_process.run()
for ts in range(T):
- model.clamp(probs)
- model.advance(t=ts * 1., dt=dt)
+ clamp(probs)
+ advance_process.run(t=ts * 1., dt=dt)
- s_t = cell.outputs.value
+ s_t = cell.outputs.get()
spikes.append(s_t)
spikes = jnp.concatenate(spikes, axis=0)
create_raster_plot(spikes, plot_fname="input_cell_raster.jpg")
@@ -121,7 +117,7 @@ and by replacing the line that has the `BernoulliCell` call with the
following line instead:
```python
-cell = PoissonCell("z0", n_units=10, max_freq=63.75, key=subkeys[0])
+cell = PoissonCell("z0", n_units=10, target_freq=63.75, key=subkeys[0])
```
Running the code with the two above small modifications will
@@ -149,12 +145,12 @@ mu = 0.
probs = jnp.asarray([[1.]],dtype=jnp.float32)
for _ in range(n_trials):
spikes = []
- model.reset()
+ reset_process.run()
for ts in range(T):
- model.clamp(probs)
- model.advance(t=ts*1., dt=dt)
+ clamp(probs)
+ advance_process.run(t=ts * 1., dt=dt)
- s_t = cell.outputs.value
+ s_t = cell.outputs.get()
spikes.append(s_t)
count = jnp.sum(jnp.concatenate(spikes, axis=0))
mu += count
diff --git a/docs/tutorials/neurocog/integration.md b/docs/tutorials/neurocog/integration.md
index db3fa1ca..95320c6f 100644
--- a/docs/tutorials/neurocog/integration.md
+++ b/docs/tutorials/neurocog/integration.md
@@ -1,32 +1,14 @@
# Numerical Integration
-In constructing one's own biophysical models, particularly those of phenomena
-that change with time, ngc-learn offers useful flexible tools for numerical
-integration that facilitate an easier time in constructing your own components
-that play well with the library's simulation backend. Knowing how things work
-beyond Euler integration -- the base/default form of integration often employed
-by ngc-learn -- might be useful for constructing and simulating dynamics more
-accurately (often at the cost of additional computational time).
+In constructing one's own biophysical models, particularly those of phenomena that change with time, ngc-learn offers useful flexible tools for numerical integration that facilitate an easier time in constructing your own components that play well with the library's simulation backend. Knowing how things work beyond Euler integration -- the base/default form of integration often employed by ngc-learn -- might be useful for constructing and simulating dynamics more accurately (often at the cost of additional computational time).
## Euler Integration
-Euler integration is very simple (and fast) way of using the ordinary differential
-equations you typically define for the cellular dynamics of various components
-in ngc-learn (which typically get called in any component's `AdvanceState()`
-command).
+Euler integration is a very simple (and fast) way of using the ordinary differential equations you typically define for the cellular dynamics of various components in ngc-learn (which generally get called in any component's `advance_state()` command).
-While utilizing the numerical integrator will depend on your component's design
-and the (biophysical) elements you wish to model, let's observe ngc-learn's
-base backend utilities (its integration backend `ngclearn.utils.diffeq`) in
-the context of numerically integrating a simple
-differential equation; specifically the autonomous (linear) ordinary differential equation (ODE):
-$\frac{\partial y(t)}{\partial t} = y(t)$. The analytic
-solution to this equation is also simple -- it is $y(t) = e^{t}$.
+While utilizing the numerical integrator will depend on your component's design and the (biophysical) elements you wish to model, let's observe ngc-learn's base backend utilities (its integration backend `ngclearn.utils.diffeq`) in the context of numerically integrating a simple differential equation; specifically the autonomous (linear) ordinary differential equation (ODE): $\frac{\partial y(t)}{\partial t} = y(t)$. The analytic solution to this equation is also simple -- it is $y(t) = e^{t}$.
-If you have defined your differential equation $\frac{\partial y(t)}{\partial t}$
-in a rather simple format[^1], you can write the following code to examine how
-Euler integration approximates the analytical solution (in this example, we
-examine just two different step sizes, i.e., `dt = 0.1` and `dt = 0.09`)
+If you have defined your differential equation $\frac{\partial y(t)}{\partial t}$ in a rather simple format[^1], you can write the following code to examine how Euler integration approximates the analytical solution (in this example, we examine just two different step sizes, i.e., `dt = 0.1` and `dt = 0.09`)
```python
from jax import numpy as jnp, random, jit, nn
@@ -89,41 +71,13 @@ which should yield you a plot like the one below:
-Notice how the integration constant `dt` (or $\Delta t$) chosen affects the approximation of ngc-learn's
-Euler integrator and typically, when constructing your biophysical models, you
-will need to think about this constant in the context of your simulation time-scale
-and what you intend to model. Note that, in many biophysical component cells,
-you will have an integration time constant of some form, i.e., a $\tau$, that you
-can control, allowing you to fix your `dt` to your simulated time-scale
-(say to a value like `dt = 1` millisecond) while tuning/altering your
-time constant $\tau$ (since the differential equation will be weighted
-by $\frac{\Delta t}{\tau}$).
+Notice how the integration constant `dt` (or $\Delta t$) chosen affects the approximation of ngc-learn's Euler integrator and typically, when constructing your biophysical models, you will need to think about this constant in the context of your simulation time-scale and what you intend to model. Note that, in many biophysical component cells, you will have an integration time constant of some form, i.e., a $\tau$, that you can control, allowing you to fix your `dt` to your simulated time-scale (say to a value like `dt = 1` millisecond) while tuning/altering your time constant $\tau$ (since the differential equation will be weighted by $\frac{\Delta t}{\tau}$).
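+
+To make the $\frac{\Delta t}{\tau}$ weighting concrete, here is a minimal sketch (not tied to any particular ngc-learn component) of a leaky ODE, $\tau \frac{\partial v}{\partial t} = -v + j$, stepped forward with Euler integration; fixing `dt` and enlarging `tau` slows the per-step change, while shrinking `tau` speeds it up:
+
+```python
+from jax import numpy as jnp
+
+def euler_leaky_step(v, j, dt=1., tau=20.):  ## one Euler step of tau * dv/dt = -v + j
+    dv_dt = -v + j
+    return v + (dt / tau) * dv_dt  ## the update is weighted by dt / tau
+
+v = jnp.zeros((1, 1))
+for _ in range(5):  ## drive the state towards its steady state (j = 1)
+    v = euler_leaky_step(v, j=jnp.ones((1, 1)))
+print(v)
+```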
## Higher-Order Forms of (Explicit) Integration
-Notably, ngc-learn has built-in several forms of (explicit) numerical integration beyond
-the Euler method, such as a second order Runge-Kutta (RK-2) method (also known as
-the midpoint method) and 4th-order Runge-Kutta (RK-4) method or an error-predictor method such as Heun's method
-(also known as the trapezoid method). These forms of integration might be useful particularly
-if a cell or plastic synaptic component you might be writing follows dynamics
-that are more nonlinear or biophysically complex (requiring a higher degree
-of simulation accuracy). For instance, ngc-learn's in-built cell components,
-particularly those of higher biophysical complexity -- like the
-[Izhikevich cell](ngclearn.components.neurons.spiking.izhikevichCell) or the
-[FitzhughNagumo cell](ngclearn.components.neurons.spiking.fitzhughNagumoCell) --
-contain argument flags for switching their simulation steps to use RK-2.
-
-To illustrate the value of higher-order numerical integration methods, let us
-examine a simple polynomial equation (thus nonlinear) that is further
-non-autonomous, i.e., it is a function of the time variable $t$ itself. A
-possible set of dynamics in this case might be:
-$\frac{\partial y(t)}{\partial t} = -2 t^3 + 12 t^2 - 20 t + 8.5$ which
-has the analytic solution $y(t) = -(1/2) t^4 + 4 t^3 - 10 t^2 + 8.5 t + C$ (
-where we will set $C = 1$). You can write code like below, importing from
-`ngclearn.utils.diffeq.ode_utils` the Euler routine (`step_euler`),
-the RK-2 routine (`step_rk2`), the RK-4 routine (`step_rk4`), and Heun's method (`step_heun`), and compare
-how these methods approximate the nonlinear dynamics inherent to our
-constructed $\frac{\partial y(t)}{\partial t}$ ODE below:
+Notably, ngc-learn has built-in several forms of (explicit) numerical integration beyond the Euler method, such as a second-order Runge-Kutta (RK-2) method (also known as the midpoint method), a fourth-order Runge-Kutta (RK-4) method, and an error-predictor method such as Heun's method (also known as the trapezoid method). These forms of integration might be useful particularly if a cell or plastic synaptic component you might be writing follows dynamics that are more nonlinear or biophysically complex (requiring a higher degree of simulation accuracy). For instance, ngc-learn's in-built cell components, particularly those of higher biophysical complexity -- like the [Izhikevich cell](ngclearn.components.neurons.spiking.izhikevichCell) or the [FitzhughNagumo cell](ngclearn.components.neurons.spiking.fitzhughNagumoCell) -- contain argument flags for switching their simulation steps to use RK-2.
+
+To illustrate the value of higher-order numerical integration methods, let us examine a simple polynomial equation (thus nonlinear) that is further non-autonomous, i.e., it is a function of the time variable $t$ itself. A possible set of dynamics in this case might be: $\frac{\partial y(t)}{\partial t} = -2 t^3 + 12 t^2 - 20 t + 8.5$ which has the analytic solution $y(t) = -(1/2) t^4 + 4 t^3 - 10 t^2 + 8.5 t + C$ (where we will set $C = 1$). You can write code like below, importing from `ngclearn.utils.diffeq.ode_utils` the Euler routine (`step_euler`), the RK-2 routine (`step_rk2`), the RK-4 routine (`step_rk4`), and Heun's method (`step_heun`), and compare how these methods approximate the nonlinear dynamics inherent to our constructed $\frac{\partial y(t)}{\partial t}$ ODE below:
```python
from jax import numpy as jnp, random, jit, nn
@@ -194,12 +148,7 @@ which should yield you a plot like the one below:
-As you might observe, RK-4 gives the best approximation of the solution. In addition,
-when the integration step size is held constant, Euler integration
-does quite poorly over just a few steps while RK-2 and Heun's method do much better
-at approximating the analytical equation. In the end, the type of numerical integration method employed can
-matter depending on the ODE(s) you use in modeling, particularly if you seek higher accuracy
-for more nonlinear dynamics like in our example above.
+As you might observe, RK-4 gives the best approximation of the solution. In addition, when the integration step size is held constant, Euler integration does quite poorly over just a few steps while RK-2 and Heun's method do much better at approximating the analytical equation. In the end, the type of numerical integration method employed can matter depending on the ODE(s) you use in modeling, particularly if you seek higher accuracy for more nonlinear dynamics like in our example above.
[^1]: The format expected by ngc-learn's backend is that the differential equation
provides a functional API/form like so: for instance `dy/dt = diff_eqn(t, y(t), params)`,
diff --git a/docs/tutorials/neurocog/izhikevich_cell.md b/docs/tutorials/neurocog/izhikevich_cell.md
index 6d1449a6..bdbdc742 100644
--- a/docs/tutorials/neurocog/izhikevich_cell.md
+++ b/docs/tutorials/neurocog/izhikevich_cell.md
@@ -19,8 +19,7 @@ single component system made up of the Izhikevich (`IZH`) cell.
from jax import numpy as jnp, random, jit
import numpy as np
-from ngcsimlib.context import Context
-from ngclearn.utils import JaxProcess
+from ngclearn import Context, MethodProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.izhikevichCell import IzhikevichCell
@@ -39,43 +38,25 @@ coupling_factor = 0.2
## create simple system with only one Izh Cell
with Context("Model") as model:
- cell = IzhikevichCell("z0", n_units=1, tau_w=tau_w, v_reset=v_reset,
- w_reset=w_reset, coupling_factor=coupling_factor,
- integration_type="euler", v0=v0, w0=w0, key=subkeys[0])
+ cell = IzhikevichCell(
+ "z0", n_units=1, tau_w=tau_w, v_reset=v_reset, w_reset=w_reset, coupling_factor=coupling_factor,
+ integration_type="euler", v0=v0, w0=w0, key=subkeys[0]
+ )
## create and compile core simulation commands
- advance_process = (JaxProcess()
+ advance_process = (MethodProcess("advance")
>> cell.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance")
- reset_process = (JaxProcess()
+ reset_process = (MethodProcess("reset")
>> cell.reset)
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
- ## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- cell.j.set(x)
+## set up non-compiled utility commands
+def clamp(x):
+ cell.j.set(x)
```
-The Izhikevich `IZH`, much like the FitzHugh–Nagumo cell covered in
-[a different lesson](../neurocog/fitzhugh_nagumo_cell.md), is a two-dimensional
-differential equation system (developed in [1]) that attempts to (approximately)
-model spiking cellular activation and deactivation dynamics. Notably, the `IZH`
-cell models membrane potential `v` (using a squared term) jointly with a
-recovery variable `w` (which is meant to provide a slower form of negative feedback).
-In his model, Izhikevich introduced four important control factors/coefficients,
-the choices of values for each changes the behavior of the neuronal model and
-thus recovering dynamics of different classes of neurons found in the brain.
-Several of these control factors have been renamed and/or mapped to more
-explicit descriptors in ngc-learn (for example, Izhikevich's original factor
-`a` has been mapped to `a = 1/tau_w` allowing the user to define the time
-constant for the recovery variable much in the same manner as the
-FitzHugh–Nagumo cell). Also like the FitzHugh–Nagumo cell, the Izhikevich model
-contains configurable initial conditions for its voltage (i.e., `v0`) and
-recovery values (i.e., `w0`), which we see have been set to interesting values
-for the purposes of this lesson (these are actually the default values of
-the Izhikevich component, i.e., `v0=-65` and `w0=-14`).
+The Izhikevich `IZH`, much like the FitzHugh–Nagumo cell covered in [a different lesson](../neurocog/fitzhugh_nagumo_cell.md), is a two-dimensional differential equation system (developed in [1]) that attempts to (approximately) model spiking cellular activation and deactivation dynamics. Notably, the `IZH` cell models membrane potential `v` (using a squared term) jointly with a recovery variable `w` (which is meant to provide a slower form of negative feedback).
+In his model, Izhikevich introduced four important control factors/coefficients, where the choice of values for each changes the behavior of the neuronal model, thus recovering the dynamics of different classes of neurons found in the brain. Several of these control factors have been renamed and/or mapped to more explicit descriptors in ngc-learn (for example, Izhikevich's original factor `a` has been mapped to `a = 1/tau_w`, allowing the user to define the time constant for the recovery variable much in the same manner as the FitzHugh–Nagumo cell). Also like the FitzHugh–Nagumo cell, the Izhikevich model contains configurable initial conditions for its voltage (i.e., `v0`) and recovery values (i.e., `w0`), which we see have been set to interesting values for the purposes of this lesson (these are actually the default values of the Izhikevich component, i.e., `v0=-65` and `w0=-14`).
Formally, the core dynamics of the `IZH` can be written out as follows:
@@ -84,22 +65,12 @@ $$
\tau_w \frac{\partial \mathbf{w}_t}{\partial t} &= b \mathbf{v}_t - \mathbf{w}_t
$$
-where $b$ is the coupling factor, $R$ is the membrane resistance, $\tau_m$ is the
-membrane time constant, and $\tau_w$ is the recovery time constant (technically,
-$\tau_m = 1$, $R = 1$, and $\tau_w = 1/a$ to get to the perspective originally
-put forth in [1]).
+where $b$ is the coupling factor, $R$ is the membrane resistance, $\tau_m$ is the membrane time constant, and $\tau_w$ is the recovery time constant (technically, $\tau_m = 1$, $R = 1$, and $\tau_w = 1/a$ to get to the perspective originally put forth in [1]).
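+
+As a rough illustration (using the standard textbook form of Izhikevich's voltage equation rather than ngc-learn's exact internals), a single Euler step of these two ODEs could be sketched as follows, with `b` the coupling factor and `tau_w` the recovery time constant:
+
+```python
+from jax import numpy as jnp
+
+def izh_euler_step(v, w, j, dt=0.1, b=0.2, tau_w=50.):
+    dv_dt = 0.04 * (v ** 2) + 5. * v + 140. - w + j  ## quadratic voltage dynamics
+    dw_dt = (b * v - w) / tau_w  ## slow recovery feedback
+    return v + dt * dv_dt, w + dt * dw_dt
+
+v, w = jnp.asarray([[-65.]]), jnp.asarray([[-14.]])  ## the default v0 and w0 noted above
+v, w = izh_euler_step(v, w, j=jnp.asarray([[10.]]))
+```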
### Simulating an Izhikevich Neuronal Cell
-Given the single-cell dynamical system we set up above, we finally write
-some code that uses and visualizes the flow of the `IZH` cell's membrane
-potential and coupled recovery variable (specifically over a period of about
-`200` milliseconds). We will, much as we did with the leaky integrators in
-prior tutorials, inject an electrical current `j` into the `IZH` cell -- this
-time with a constant current value of `10` amperes -- and observe how the cell
-produces action potentials.
-Specifically, we can plot the `IZH` neuron's voltage `v` and recovery variable `w`
-in the following manner:
+Given the single-cell dynamical system we set up above, we finally write some code that uses and visualizes the flow of the `IZH` cell's membrane potential and coupled recovery variable (specifically over a period of about `200` milliseconds). We will, much as we did with the leaky integrators in prior tutorials, inject an electrical current `j` into the `IZH` cell -- this time with a constant current value of `10` amperes -- and observe how the cell produces action potentials.
+Specifically, we can plot the `IZH` neuron's voltage `v` and recovery variable `w` in the following manner:
```python
curr_in = []
@@ -114,19 +85,19 @@ i_app = 10. # 0.23 ## electrical current to inject into F-N cell
data = jnp.asarray([[i_app]], dtype=jnp.float32)
time_span = []
-model.reset()
+reset_process.run()
t = 0.
for ts in range(T):
x_t = data
## pass in t and dt and run step forward of simulation
- model.clamp(x_t)
- model.advance(t=t, dt=dt)
+ clamp(x_t)
+ advance_process.run(t=t, dt=dt)
t = t + dt
## naively extract simple statistics at time ts and print them to I/O
- v = cell.v.value
- w = cell.w.value
- s = cell.s.value
+ v = cell.v.get()
+ w = cell.w.get()
+ s = cell.s.get()
curr_in.append(data)
mem_rec.append(v)
recov_rec.append(w)
@@ -153,8 +124,9 @@ n_plots = 1
fig, ax = plt.subplots(1, n_plots, figsize=(5*n_plots,5))
ax_ptr = ax
-ax_ptr.set(xlabel='Time', ylabel='Voltage (v), Recovery (w)',
- title="Izhikevich (RS) Voltage/Recovery Dynamics")
+cell_tag = "RS"  ## neuronal type tag for the plot title (regular spiking, the default configuration)
+ax_ptr.set(
+    xlabel='Time', ylabel='Voltage (v), Recovery (w)', title=f"Izhikevich ({cell_tag}) Voltage/Recovery Dynamics"
+)
v = ax_ptr.plot(time_span, mem_rec, color='C0')
w = ax_ptr.plot(time_span, recov_rec, color='C1', alpha=.5)
@@ -164,22 +136,12 @@ plt.tight_layout()
plt.savefig("{0}".format("izhcell_plot.jpg"))
```
-You should get a plot that depicts the evolution of the voltage and recovery of
-the Izhikevich cell, i.e., saved as `izhcell_plot.jpg` locally to disk, much
-like the one below:
+You should get a plot that depicts the evolution of the voltage and recovery of the Izhikevich cell, i.e., saved as `izhcell_plot.jpg` locally to disk, much like the one below:
-The plot above, which you can modify slightly yourself to include the neuronal
-type tag "RS" like we do, actually depicts the dynamics for a specific type of spiking
-neuron called the "regular spiking" (RS) neuron (also the default configuration
-for ngc-learn's neuronal cell implementation), which is only one of several
-kinds of neurons you can emulate with Izhikevich's dynamics implemented in
-ngc-learn. Try modifying the exposed Izhikevich cell hyper-parameters above
-and setting them to particular values (such as those noted in the
-component's documentation) to recreate other possible neuron types. For
-example, to obtain a "fast spiking" (FS) neuronal cell, all you would need to
-do is modify the recovery variable's time constant like so:
+The plot above, which you can modify slightly yourself to include the neuronal type tag "RS" like we do, actually depicts the dynamics for a specific type of spiking neuron called the "regular spiking" (RS) neuron (also the default configuration for ngc-learn's neuronal cell implementation), which is only one of several kinds of neurons you can emulate with Izhikevich's dynamics implemented in
+ngc-learn. Try modifying the exposed Izhikevich cell hyper-parameters above and setting them to particular values (such as those noted in the component's documentation) to recreate other possible neuron types. For example, to obtain a "fast spiking" (FS) neuronal cell, all you would need to do is modify the recovery variable's time constant like so:
```python
## FS cell configuration values
@@ -189,15 +151,11 @@ w_reset = 8. ## ngc-learn default
coupling_factor = 0.2 ## ngc-learn default
```
-to obtain a voltage/recovery dynamics plot like so (if you also modify the
-plot title of the plotting code accordingly):
+to obtain a voltage/recovery dynamics plot like so (if you also modify the plot title of the plotting code accordingly):
-Three other well-known classes of neural behaviors are possible to easily simulate
-under the following hyper-parameter configurations (which produce the array
-of three plots similar to those shown near the bottom of this lesson),
-by simplifying modifying hyper-parameters according to the following:
+Three other well-known classes of neural behavior can be easily simulated (producing the array of three plots shown near the bottom of this lesson) by simply modifying the hyper-parameters as follows:
1. Chattering (CH) neurons:
```python
@@ -222,8 +180,7 @@ w_reset = 2.
coupling_factor = 0.25
```
-The above three hyper-parameter settings produce, from top-to-bottom, the
-plots shown below (from left-to-right):
+The above three hyper-parameter settings produce, from top-to-bottom, the plots shown below (from left-to-right):
```{eval-rst}
@@ -237,27 +194,6 @@ plots shown below (from left-to-right):
+-------------------------------------------------------+-------------------------------------------------------+--------------------------------------------------------+
```
-## Optional: Setting Up The Components with a JSON Configuration
-
-While you are not required to create a JSON configuration file for ngc-learn,
-to get rid of the warning that ngc-learn will throw at the start of your
-program's execution (indicating that you do not have a configuration set up yet),
-all you need to do is create a sub-directory for your JSON configuration
-inside of your project code's directory, i.e., `json_files/modules.json`.
-Inside the JSON file, you would write the following:
-
-```json
-[
- {"absolute_path": "ngclearn.components",
- "attributes": [
- {"name": "IzhikevichCell"}]
- },
- {"absolute_path": "ngcsimlib.operations",
- "attributes": [
- {"name": "overwrite"}]
- }
-]
-```
## References
diff --git a/docs/tutorials/neurocog/lif.md b/docs/tutorials/neurocog/lif.md
index 48485da9..82a08030 100755
--- a/docs/tutorials/neurocog/lif.md
+++ b/docs/tutorials/neurocog/lif.md
@@ -1,31 +1,15 @@
# Lecture 2B: The Leaky Integrate-and-Fire Cell
-The leaky integrate-and-fire (LIF) cell component in ngc-learn is a stepping
-stone towards working with more biophysical intricate cell components when crafting
-your neuronal circuit models. This
-[cell](ngclearn.components.neurons.spiking.LIFCell) is markedly different from the
-[simplified LIF](ngclearn.components.neurons.spiking.sLIFCell) in both its
-implemented dynamics as well as what modeling routines that it offers, including
-the fact that it does not offer implicit fixed lateral inhibition like the
-`SLIF` does (one would need to explicitly model the lateral inhibition as a
-separate population of `LIF` cells, as we do in the
-[Diehl and Cook model museum spiking network](../../museum/snn_dc.md)). Furthermore,
-using this neuronal cell is a useful transition to using the more complicated and
-biophysically more accurate neuronal models such as the
-[adaptive exponential integrator cell](ngclearn.components.neurons.spiking.adExCell)
-or the
-[Izhikevich cell](ngclearn.components.neurons.spiking.izhikevichCell).
+The leaky integrate-and-fire (LIF) cell component in ngc-learn is a stepping stone towards working with more biophysically intricate cell components when crafting your neuronal circuit models. This [cell](ngclearn.components.neurons.spiking.LIFCell) is markedly different from the [simplified LIF](ngclearn.components.neurons.spiking.sLIFCell) in both its implemented dynamics and the modeling routines that it offers, including the fact that it does not offer implicit fixed lateral inhibition like the `SLIF` does (one would need to explicitly model the lateral inhibition as a separate population of `LIF` cells, as we do in the [Diehl and Cook model museum spiking network](../../museum/snn_dc.md)). Furthermore, using this neuronal cell is a useful transition to using the more complicated and biophysically more accurate neuronal models such as the [adaptive exponential integrator cell](ngclearn.components.neurons.spiking.adExCell) or the [Izhikevich cell](ngclearn.components.neurons.spiking.izhikevichCell).
## Instantiating the LIF Neuronal Cell
-To implement a single-component dynamical system made up of a single LIF
-cell, you would write code akin to the following:
+To implement a single-component dynamical system made up of a single LIF cell, you would write code akin to the following:
```python
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
-from ngclearn.utils import JaxProcess
+from ngclearn import Context, MethodProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.LIFCell import LIFCell
from ngclearn.utils.viz.spike_plot import plot_spiking_neuron
@@ -42,35 +26,27 @@ tau_m = 100.
## create simple system with only one AdEx
with Context("Model") as model:
- cell = LIFCell("z0", n_units=1, tau_m=tau_m, resist_m=tau_m/dt, thr=V_thr,
- v_rest=V_rest, v_reset=-60., tau_theta=300., theta_plus=0.05,
- refract_time=2., key=subkeys[0])
+ cell = LIFCell(
+ "z0", n_units=1, tau_m=tau_m, resist_m=tau_m/dt, thr=V_thr, v_rest=V_rest, v_reset=-60., tau_theta=300.,
+ theta_plus=0.05, refract_time=2., key=subkeys[0]
+ )
## create and compile core simulation commands
- advance_process = (JaxProcess()
+ advance_process = (MethodProcess("advance")
>> cell.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance")
- reset_process = (JaxProcess()
+ reset_process = (MethodProcess("reset")
>> cell.reset)
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
- ## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- cell.j.set(x)
+## set up non-compiled utility commands
+def clamp(x):
+ cell.j.set(x)
```
## Simulating the LIF on Stepped Constant Electrical Current
-Given our single-LIF dynamical system above, let us write some code to use
-our `LIF` node and visualize the resultant spiking pattern super-imposed
-over its membrane (voltage) potential by feeding
-into it a step current, where the electrical current `j` starts at $0$ then
-switches to $0.3$ at $t = 10$ ms (much as we did for the `SLIF` component
-in the previous lesson). We craft the simulation portion of our code like so:
-
+Given our single-LIF dynamical system above, let us write some code to use our `LIF` node and visualize the resultant spiking pattern super-imposed over its membrane (voltage) potential by feeding into it a step current, where the electrical current `j` starts at $0$ then switches to $0.3$ at $t = 10$ ms (much as we did for the `SLIF` component in the previous lesson). We craft the simulation portion of our code like so:
```python
# create a synthetic electrical step current
@@ -80,14 +56,14 @@ curr_in = []
mem_rec = []
spk_rec = []
-model.reset()
+reset_process.run()
for ts in range(current.shape[1]):
j_t = jnp.expand_dims(current[0,ts], axis=0) ## get data at time ts
- model.clamp(j_t)
- model.advance(t=ts*1., dt=dt)
+ clamp(j_t)
+ advance_process.run(t=ts*1., dt=dt)
## naively extract simple statistics at time ts and print them to I/O
- v = cell.v.value
- s = cell.s.value
+ v = cell.v.get()
+ s = cell.s.get()
curr_in.append(j_t)
mem_rec.append(v)
spk_rec.append(s)
@@ -95,43 +71,24 @@ for ts in range(current.shape[1]):
print()
```
-Then, we can plot the input current, the neuron's voltage `v`, and its output
-spikes as follows:
+Then, we can plot the input current, the neuron's voltage `v`, and its output spikes as follows:
```python
import numpy as np
curr_in = np.squeeze(np.asarray(curr_in))
mem_rec = np.squeeze(np.asarray(mem_rec))
spk_rec = np.squeeze(np.asarray(spk_rec))
-plot_spiking_neuron(curr_in, mem_rec, spk_rec, None, dt, thr_line=V_thr, min_mem_val=V_rest-1.,
- max_mem_val=V_thr+2., spike_loc=V_thr, spike_spr=0.5, title="LIF-Node: Constant Electrical Input", fname="lif_plot.jpg")
+plot_spiking_neuron(
+ curr_in, mem_rec, spk_rec, None, dt, thr_line=V_thr, min_mem_val=V_rest-1., max_mem_val=V_thr+2., spike_loc=V_thr,
+ spike_spr=0.5, title="LIF-Node: Constant Electrical Input", fname="lif_plot.jpg"
+)
```
which should produce the following plot (saved to disk):
-As we might observe, the LIF operates very differently from the SLIF, notably
-that its dynamics live in the different space of values (one aspect of the
-SLIF is that its dynamics are effectively normalized/configured to live
-a non-negative membrane potential number space), specifically values that
-are a bit better aligned with those observed in experimental neuroscience.
-While more biophysically more accurate, the `LIF` typically involves consideration
-of multiple additional hyper-parameters/simulation coefficients, including
-the resting membrane potential value `v_rest` and the reset membrane value
-`v_reset` (upon occurrence of a spike/emitted action potential); the `SLIF`,
-in contrast, assumed a `v_reset = v_reset = 0.`. Note that the `LIF`'s
-`tau_theta` and `theta_plus` coefficients govern its particular adaptive threshold,
-which is a particular increment variable (one per cell in the `LIF` component)
-that gets adjusted according to its own dynamics and added to the fixed constant
-threshold `thr`, i.e., the threshold that a cell's membrane potential must
-exceed for a spike to be emitted.
-
-The `LIF` cell component is particularly useful when more flexibility is required/
-desired in setting up neuronal dynamics, particularly when attempting to match
-various mathematical models that have been proposed in computational neuroscience.
-This benefit comes at the greater cost of additional tuning and experimental planning,
-whereas the `SLIF` can be a useful go-to initial spiking cell for building certain spiking
-models such as those proposed in machine intelligence research (we demonstrate
-one such use-case in the context of the
-[feedback alignment-trained spiking network](../../museum/snn_bfa.md) that we offer in the model museum).
+As we might observe, the LIF operates very differently from the SLIF, notably in that its dynamics live in a different space of values (one aspect of the SLIF is that its dynamics are effectively normalized/configured to live in a non-negative membrane potential number space), specifically values that are a bit better aligned with those observed in experimental neuroscience. While biophysically more accurate, the `LIF` typically involves consideration of multiple additional hyper-parameters/simulation coefficients, including
+the resting membrane potential value `v_rest` and the reset membrane value `v_reset` (upon occurrence of a spike/emitted action potential); the `SLIF`, in contrast, assumed `v_rest = v_reset = 0.`. Note that the `LIF`'s `tau_theta` and `theta_plus` coefficients govern its particular adaptive threshold, which is an increment variable (one per cell in the `LIF` component) that gets adjusted according to its own dynamics and added to the fixed constant threshold `thr`, i.e., the threshold that a cell's membrane potential must exceed for a spike to be emitted.
+
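+As a conceptual sketch of the adaptive threshold described above (illustrative only, not ngc-learn's exact update rule), the per-cell increment `theta` might decay towards zero with time constant `tau_theta` while being bumped by `theta_plus` whenever that cell spikes, with the effective firing threshold being `thr + theta`:
+
+```python
+from jax import numpy as jnp
+
+def adapt_threshold(theta, s, dt, tau_theta=300., theta_plus=0.05):
+    ## leaky decay plus a spike-triggered increment
+    return theta + (-theta / tau_theta) * dt + s * theta_plus
+
+theta = jnp.zeros((1, 1))
+spike = jnp.ones((1, 1))  ## suppose the cell emitted a spike at this step
+theta = adapt_threshold(theta, spike, dt=1.)
+## the membrane potential would then need to exceed (thr + theta) for the next spike
+```
+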
+The `LIF` cell component is particularly useful when more flexibility is required/desired in setting up neuronal dynamics, particularly when attempting to match various mathematical models that have been proposed in computational neuroscience. This benefit comes at the greater cost of additional tuning and experimental planning, whereas the `SLIF` can be a useful go-to initial spiking cell for building certain spiking models such as those proposed in machine intelligence research (we demonstrate one such use-case in the context of the [feedback alignment-trained spiking network](../../museum/snn_bfa.md) that we offer in the model museum).
diff --git a/docs/tutorials/neurocog/metrics.md b/docs/tutorials/neurocog/metrics.md
index aea77da6..e872e23c 100644
--- a/docs/tutorials/neurocog/metrics.md
+++ b/docs/tutorials/neurocog/metrics.md
@@ -1,26 +1,11 @@
# Metrics and Measurement Functions
-Inside of `ngclearn.utils.metric_utils`, ngc-learn offers metrics and measurement
-utility functions that can be quite useful when building neurocognitive models using
-ngc-learn's node-and-cables system for specific tasks. While this utilities
-sub-module will not always contain every possible function you might need,
-given that measurements are often dependent on the task the experimenter wants
-to conduct, there are several commonly-used ones drawn from machine intelligence
-and computational neuroscience that are (jit-i-fied) in-built to ngc-learn you
-can readily use.
-In this small lesson, we will briefly examine two examples of importing such
-functions and examine what they do.
+Inside of `ngclearn.utils.metric_utils`, ngc-learn offers metrics and measurement utility functions that can be quite useful when building neurocognitive models using ngc-learn's node-and-cables system for specific tasks. While this utilities sub-module will not always contain every possible function you might need, given that measurements are often dependent on the task the experimenter wants to conduct, there are several commonly-used ones drawn from machine intelligence and computational neuroscience that are (jit-i-fied) in-built to ngc-learn and that you can readily use.
+In this small lesson, we will briefly examine two examples of importing such functions and examine what they do.
## Measuring Task-Level Quantities
-For many tasks that you might be interested in, a useful measurement
-is the performance of the model in some supervised learning context. For example,
-you might want to measure a model's accuracy on a classification task. To do so,
-assuming we have some model outputs extracted from a model that you have constructed
-elsewhere -- say a matrix of scores `Y_scores` -- and a target set of predictions
-that you are testing against -- such as `Y_labels` (in one-hot binary encoded form )
--- then you can write some code to compute the accuracy, mean squared error (MSE),
-and categorical log likelihood (Cat-NLL), like so:
+For many tasks that you might be interested in, a useful measurement is the performance of the model in some supervised learning context. For example, you might want to measure a model's accuracy on a classification task. To do so, assuming we have some model outputs extracted from a model that you have constructed elsewhere -- say a matrix of scores `Y_scores` -- and a target set of labels that you are testing against -- such as `Y_labels` (in one-hot binary encoded form) -- then you can write some code to compute the accuracy, mean squared error (MSE), and categorical log likelihood (Cat-NLL), like so:
```python
from jax import numpy as jnp
@@ -55,24 +40,18 @@ and you should obtain the following in I/O like so:
> Cat-NLL = 4.003
```
-Notice that we imported the utility function `softmax` from
-`ngclearn.utils.model_utils` to convert our raw theoretical model scores to
-probability values so that using `measure_CatNLL()` makes sense (as this
-assumes the model scores are normalized probability values).
+Notice that we imported the utility function `softmax` from `ngclearn.utils.model_utils` to convert our raw theoretical model scores to
+probability values so that using `measure_CatNLL()` makes sense (as this assumes the model scores are normalized probability values).
## Measuring Some Model Statistics
-In some cases, you might be interested in measuring certain statistics
-related to aspects of a model that you construct. For example, you might have
-collected a (binary) spike train produced by one of the internal neuronal layers
-of your ngc-learn-simulated spiking neural network and want to compute the
-firing rates and Fano factors associated with each neuron. Doing so with
-ngc-learn utility functions would entail writing something like:
+In some cases, you might be interested in measuring certain statistics related to properties of a model that you construct. For example, you might have collected a (binary) spike train produced by one of the internal neuronal layers of your ngc-learn-simulated spiking neural network and want to compute the firing rates and Fano factors associated with each neuron. Doing so with ngc-learn utility functions would entail writing something like:
```python
from jax import numpy as jnp
from ngclearn.utils.metric_utils import measure_fanoFactor, measure_firingRate
+## let's create a fake synthetic spike train for 3 neurons (one per column)
spikes = jnp.asarray([[0., 0., 0.],
[0., 0., 1.],
[0., 1., 0.],
@@ -92,6 +71,7 @@ spikes = jnp.asarray([[0., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]], dtype=jnp.float32)
+## measure the firing rates and Fano factors of the 3 neurons
fr = measure_firingRate(spikes, preserve_batch=True)
fano = measure_fanoFactor(spikes, preserve_batch=True)
@@ -106,8 +86,4 @@ which should result in the following to be printed to I/O:
> Fano Factor = [[0.8888888 0.77777773 0.55555546]]
```
-The Fano factor is a useful secondary statistic for characterizing the
-variable of a neuronal spike train -- as we see in the measurement above,
-the first and second neurons have a higher Fano factor (given they are
-more irregular in their spiking patterns) whereas the third neuron is far more
-regular in its spiking pattern and thus has a lower Fano factor.
+The Fano factor is a useful secondary statistic for characterizing the variability of a neuronal spike train -- as we see in the measurement above, the first and second neurons have a higher Fano factor (given they are more irregular in their spiking patterns) whereas the third neuron is far more regular in its spiking pattern and thus has a lower Fano factor.
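+
+As a rough sketch of what this measurement boils down to (assuming, per the usual definition, that the Fano factor is computed per neuron as the variance of its spike counts divided by their mean; the exact implementation inside `measure_fanoFactor` may differ), one could compute it by hand like so:
+
+```python
+from jax import numpy as jnp
+
+def fano_factor(spikes, eps=1e-7):  ## spikes: (T, n_neurons) binary spike matrix
+    mu = jnp.mean(spikes, axis=0, keepdims=True)  ## mean activity per neuron
+    var = jnp.var(spikes, axis=0, keepdims=True)  ## variance of activity per neuron
+    return var / (mu + eps)  ## higher values indicate more irregular spiking
+
+spikes = jnp.asarray([[0., 0., 1.],
+                      [0., 1., 0.],
+                      [1., 0., 1.],
+                      [0., 1., 1.]], dtype=jnp.float32)
+print(fano_factor(spikes))
+```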
diff --git a/docs/tutorials/neurocog/mod_stdp.md b/docs/tutorials/neurocog/mod_stdp.md
index 3a76de37..f705f935 100755
--- a/docs/tutorials/neurocog/mod_stdp.md
+++ b/docs/tutorials/neurocog/mod_stdp.md
@@ -1,52 +1,28 @@
# Lecture 4D: Reward-Modulated Spike-Timing-Dependent Plasticity
-In this lesson, we will build on the notions of spike-timing-dependent
-plasticity (STDP), covered [earlier here](../neurocog/stdp.md), to construct
-an important form of biological credit assignment in spiking neural networks
-known as reward-modulated STDP (sometimes abbreviated to R-STDP). Specifically,
-we will simulate and plot the underlying plasticity dynamics associated with
-this form of change in synaptic efficacy, specifically studying two in-built
-schemes of STDP: modulated STDP (MSTDP) and modulated STDP with eligibility
-traces (MSTDP-ET).
+In this lesson, we will build on the notions of spike-timing-dependent plasticity (STDP), covered [earlier here](../neurocog/stdp.md), to construct an important form of biological credit assignment in spiking neural networks known as reward-modulated STDP (sometimes abbreviated to R-STDP). Specifically, we will simulate and plot the underlying plasticity dynamics associated with this form of change in synaptic efficacy, studying two in-built schemes of STDP: modulated STDP (MSTDP) and modulated STDP with eligibility traces (MSTDP-ET).
## Probing Modulated STDP and Eligibility Traces
-Go ahead and make a new folder for this study and create a Python script,
-i.e., `run_reward_stdp.py`, to write your code for this part of the tutorial.
+Go ahead and make a new folder for this study and create a Python script, i.e., `run_reward_stdp.py`, to write your code for this part of the tutorial.
-Much as we did in the STDP lesson, we will build a 3-component dynamical system
--- two spiking neurons (represented by traces) that are connected with a single
-synapse -- but, this time, we will simulate three variations of this system in
-parallel. Each one of these variants will evolve its single synapse according
-to a different condition of STDP:
+Much as we did in the STDP lesson, we will build a 3-component dynamical system -- two spiking neurons (represented by traces) that are connected with a single synapse -- but, this time, we will simulate three variations of this system in parallel. Each one of these variants will evolve its single synapse according to a different condition of STDP:
1. the first one will change its synapse's strength in accordance with trace-based STDP;
2. the second one will change its synapse's strength via modulated STDP (MSTDP); and,
3. the third and final one will change its synapse's strength via modulated STDP
equipped with an eligibility trace (MSTDP-ET).
-The second and third model above will make use of ngc-learn's in-built
-[MSTDPETSynapse](ngclearn.components.synapses.modulated.MSTDPETSynapse), which
-is an STDP cable component that sub-classes the `TraceSTDPSynapse` cable component
-and will offer the additional machinery we will need to carry out modulated
-forms of STDP.
-All three of these variant STDP-evolved systems will make use of the same set
-of variable traces (the `VarTrace` object introduced in the previous STDP lesson),
-and we will control the spike trains by providing a specific set of pre-synaptic
-spike times and a corresponding set of post-synaptic spike times ( both in
-milliseconds). Furthermore, we will insert a convenience cell in-built
-to ngc-learn called the `RewardErrorCell`, which is generally use to produce
-what is known in neuroscience literature as "reward prediction error" (RPE).
-
-Writing the above three parallel single synapse systems, including meta-parameters
-and the required compiled simulation and dynamic commands, can be done as follows:
+The second and third model above will make use of ngc-learn's in-built [MSTDPETSynapse](ngclearn.components.synapses.modulated.MSTDPETSynapse), which is an STDP cable component that sub-classes the `TraceSTDPSynapse` cable component and will offer the additional machinery we will need to carry out modulated forms of STDP.
+All three of these variant STDP-evolved systems will make use of the same set of variable traces (the `VarTrace` object introduced in the previous STDP lesson), and we will control the spike trains by providing a specific set of pre-synaptic spike times and a corresponding set of post-synaptic spike times (both in milliseconds). Furthermore, we will insert a convenience cell in-built to ngc-learn called the `RewardErrorCell`, which is generally used to produce what is known in the neuroscience literature as "reward prediction error" (RPE).
+
+Writing the above three parallel single synapse systems, including meta-parameters and the required compiled simulation and dynamic commands, can be done as follows:
```python
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
-from ngclearn.utils import JaxProcess
+
+from ngclearn import Context, MethodProcess
## import model-specific mechanisms
-from ngclearn.components import (TraceSTDPSynapse, MSTDPETSynapse,
- RewardErrorCell, VarTrace)
-import ngclearn.utils.weight_distribution as dist
+from ngclearn.components import (TraceSTDPSynapse, MSTDPETSynapse, RewardErrorCell, VarTrace)
+from ngclearn.utils.distribution_generator import DistributionGenerator
## create seeding keys (JAX-style)
dkey = random.PRNGKey(231)
@@ -55,142 +31,95 @@ dkey, *subkeys = random.split(dkey, 2)
dt = 1. # ms # integration time constant
T_max = 200 ## number time steps to simulate
tau_pre = tau_post = 20. # ms
-tau_elg = 25.
+tau_elg = 25. # ms ## eligibility trace time constant
Aplus = Aminus = 1. ## in ngc-learn, Aplus/Aminus are magnitudes (signs are handled internally)
gamma = 0.2
gamma_0 = 0.2/tau_elg
with Context("Model") as model:
W_stdp = TraceSTDPSynapse( ## reward-STDP (RSTDP)
- "W1_stdp", shape=(1, 1), eta=gamma, A_plus=Aplus, A_minus=Aminus,
- weight_init=dist.constant(value=0.2), key=subkeys[0])
+ "W1_stdp", shape=(1, 1), eta=gamma, A_plus=Aplus, A_minus=Aminus,
+ weight_init=DistributionGenerator.constant(value=0.2), key=subkeys[0]
+ )
W_mstdp = MSTDPETSynapse( ## reward-STDP (RSTDP)
- "W1_rstdp", shape=(1, 1), eta=gamma, A_plus=Aplus, A_minus=Aminus,
- tau_elg=0., weight_init=dist.constant(value=0.2), key=subkeys[0])
+ "W1_rstdp", shape=(1, 1), eta=gamma, A_plus=Aplus, A_minus=Aminus, tau_elg=0.,
+ weight_init=DistributionGenerator.constant(value=0.2), key=subkeys[0]
+ )
W_mstdpet = MSTDPETSynapse( ## reward-STDP w/ eligibility traces
- "W_mstdpet", shape=(1, 1), eta=gamma_0, A_plus=Aplus, A_minus=Aminus,
- tau_elg=tau_elg, weight_init=dist.constant(value=0.2), key=subkeys[0])
+ "W_mstdpet", shape=(1, 1), eta=gamma_0, A_plus=Aplus, A_minus=Aminus, tau_elg=tau_elg,
+ weight_init=DistributionGenerator.constant(value=0.2), key=subkeys[0]
+ )
## set up pre- and -post synaptic trace variables
tr0 = VarTrace("tr0", n_units=1, tau_tr=tau_pre, a_delta=Aplus)
tr1 = VarTrace("tr1", n_units=1, tau_tr=tau_post, a_delta=Aminus)
rpe = RewardErrorCell("r", n_units=1, alpha=0.)
- evolve_process = (JaxProcess()
+ evolve_process = (MethodProcess("evolve")
>> W_stdp.evolve
>> W_mstdp.evolve
>> W_mstdpet.evolve)
- model.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
- advance_process = (JaxProcess()
+ advance_process = (MethodProcess("advance")
>> tr0.advance_state
>> tr1.advance_state
>> rpe.advance_state
>> W_stdp.advance_state
>> W_mstdp.advance_state
>> W_mstdpet.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance")
- reset_process = (JaxProcess()
+ reset_process = (MethodProcess("reset")
>> W_stdp.reset
>> W_mstdp.reset
>> W_mstdpet.reset
>> rpe.reset
>> tr0.reset
- >> tr1.reset
- )
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
-
- @Context.dynamicCommand
- def clamp_spikes(f_j, f_i):
- tr0.inputs.set(f_j)
- tr1.inputs.set(f_i)
-
- @Context.dynamicCommand
- def clamp_stdp_stats(f_j, f_i, trace_j, trace_i):
- W_stdp.preSpike.set(f_j)
- W_stdp.postSpike.set(f_i)
- W_stdp.preTrace.set(trace_j)
- W_stdp.postTrace.set(trace_i)
-
- @Context.dynamicCommand
- def clamp_mstdp_stats(f_j, f_i, trace_j, trace_i, reward):
- W_mstdp.preSpike.set(f_j)
- W_mstdp.postSpike.set(f_i)
- W_mstdp.preTrace.set(trace_j)
- W_mstdp.postTrace.set(trace_i)
- W_mstdp.modulator.set(reward)
-
- @Context.dynamicCommand
- def clamp_mstdpet_stats(f_j, f_i, trace_j, trace_i, reward):
- W_mstdpet.preSpike.set(f_j)
- W_mstdpet.postSpike.set(f_i)
- W_mstdpet.preTrace.set(trace_j)
- W_mstdpet.postTrace.set(trace_i)
- W_mstdpet.modulator.set(reward)
+ >> tr1.reset)
+
+## set up some utility functions for the model context
+def clamp_spikes(f_j, f_i):
+ tr0.inputs.set(f_j)
+ tr1.inputs.set(f_i)
+
+def clamp_stdp_stats(f_j, f_i, trace_j, trace_i):
+ W_stdp.preSpike.set(f_j)
+ W_stdp.postSpike.set(f_i)
+ W_stdp.preTrace.set(trace_j)
+ W_stdp.postTrace.set(trace_i)
+
+def clamp_mstdp_stats(f_j, f_i, trace_j, trace_i, reward):
+ W_mstdp.preSpike.set(f_j)
+ W_mstdp.postSpike.set(f_i)
+ W_mstdp.preTrace.set(trace_j)
+ W_mstdp.postTrace.set(trace_i)
+ W_mstdp.modulator.set(reward)
+
+def clamp_mstdpet_stats(f_j, f_i, trace_j, trace_i, reward):
+ W_mstdpet.preSpike.set(f_j)
+ W_mstdpet.postSpike.set(f_i)
+ W_mstdpet.preTrace.set(trace_j)
+ W_mstdpet.postTrace.set(trace_i)
+ W_mstdpet.modulator.set(reward)
```
-Given our three parallel models constructed above, we ready to write some code
-to use our simulation setup. Before we do, however, notice that we have
-configured the simulation
-time `T_max` to be `200` milliseconds (ms), the integration time constant
-`dt` to be `1` ms, and the time constant for both our pre-synaptic and
-post-synaptic spiking neuron traces to be `20` ms. Two final points to notice
-about the models we have constructed above are:
-1. the RPE cell `rpe` has been configured to only output a given clamped
- reward signal via `alpha = 0`; this, according to the internal design of the
- `RewardErrorCell`, just effectively shuts off the moving average prediciton
- of reward signals that the cell encounters over time (in most practical
- cases, you will not want this set to zero as we are often interested in
- the difference between a reward prediction and a target reward value);
-2. for the third model, the MSTDP-ET model, we have configured an eligibility
- trace to be used by setting the eligibility time constant `tau_elg` to be
- be non-zero, i.e., it was set to `25` ms. An eligibility trace, in the context
- STDP/Hebbian synaptic updates, simply another set of dynamics (i.e., another
- ordinary differential equation) that we maintain as STDP synaptic updates
- are computed.
-
-With respect to the second point made about eligibility traces, formally, we note
-that under MSTDP-ET, instead of computing a trace-based STDP update at
-each and every single time step `t` and updating the synapses immediately,
-we first aggregate each STDP into another variable (the eligibility) according
-to the following ODE:
+Given our three parallel models constructed above, we are ready to write some code to use our simulation setup. Before we do, however, notice that we have configured the simulation time `T_max` to be `200` milliseconds (ms), the integration time constant `dt` to be `1` ms, and the time constant for both our pre-synaptic and post-synaptic spiking neuron traces to be `20` ms. Two final points to notice about the models that we have constructed above are:
+1. the RPE cell `rpe` has been configured to only output a given clamped reward signal via `alpha = 0`; this, according to the internal design of the `RewardErrorCell`, effectively shuts off the moving-average prediction of reward signals that the cell encounters over time (in most practical cases, you will not want this set to zero, as we are often interested in the difference between a reward prediction and a target reward value; see the short sketch after this list);
+2. for the third model, the MSTDP-ET model, we have configured an eligibility trace to be used by setting the eligibility time constant `tau_elg` to be non-zero, i.e., it was set to `25` ms. An eligibility trace, in the context of STDP/Hebbian synaptic updates, is simply another set of dynamics (i.e., another ordinary differential equation) that we maintain as STDP synaptic updates are computed.
+
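With respect to the first point, the following minimal sketch (hypothetical and for intuition only; the actual `RewardErrorCell` internals may differ) shows why a moving-average reward predictor with `alpha = 0` reduces to simply passing the clamped reward through as the "prediction error":

```python
from jax import numpy as jnp

def rpe_step(mu, reward, alpha):
    """Hypothetical sketch of a reward-prediction-error (RPE) update; `mu` is a
    moving-average prediction of the reward (not the literal RewardErrorCell code)."""
    rpe = reward - mu                   ## error = target reward minus current prediction
    mu = mu + alpha * (reward - mu)     ## moving-average update of the prediction
    return mu, rpe

mu = jnp.zeros((1, 1))
mu, rpe = rpe_step(mu, reward=jnp.ones((1, 1)), alpha=0.)  ## alpha = 0 => mu never moves
print(rpe)  ## => [[1.]] -- the RPE is just the clamped reward itself
```
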
+With respect to the second point made above about eligibility traces, formally, we note that, under MSTDP-ET, instead of computing a trace-based STDP update at each and every time step `t` and updating the synapses immediately, we first aggregate each STDP update into another variable (the "eligibility") according to the following ODE:
$$
-\tau_{elg} \frac{\partial \mathbf{E}_{ij}}{\partial t} = -\mathbf{E}_{ij} +
-\beta \frac{\partial \mathbf{W}_{ij}}{\partial t}
+\tau_{elg} \frac{\partial \mathbf{E}_{ij}}{\partial t} = -\mathbf{E}_{ij} + \beta \frac{\partial \mathbf{W}_{ij}}{\partial t}
$$
-where $i$ denotes the index of the post-synpatic spiking neuron (which emits
-a spike we label as $f_i$) and $j$ denotes the index of the pre-synaptic
-spiking neuron (which emits a spike we label as $f_j$), $\mathbf{W}_{ij}$ is
-the synapse that connects neuron $j$ to $i$, $\mathbf{E}_{ij}$ is the eligibility
-trace we maintain for synapse $\mathbf{W}_{ij}$, and $\beta$ is control factor
-(typically set to one) for scaling the magnitude of the STDP update's effect.
-Finally, note that $\frac{\partial \mathbf{W}_{ij}}{\partial t}$ is the actual
-synaptic update produced by our trace-based STDP at time $t$.
-
-Given the idea of the eligibility trace above, and how our RPE cell has been
-configured, we can write down simply what kind of synaptic update ]
-$\Delta \mathbf{W}_{ij}(t)$ that each of
-our three dynamical systems will yield once we simulate them.
-1. Trace-based STDP will produce an update to the synapse according to the combined
- products of a paired pre-synaptic trace and post-synaptic spike (long-term
- potentiation) and a paired pre-synaptic spike and post-synaptic trace
- (long-term depression), i.e,
- $\Delta \mathbf{W}_{ij}(t) = \gamma \frac{\partial \mathbf{W}_{ij}}{\partial t}$;
-2. MSTDP -- the second/middle model with the `MSTDPETSynapse` with its
- `tau_elg = 0` -- will produce a modulated update to the synapse at each
- time step as follows:
- $\Delta \mathbf{W}_{ij}(t) = \gamma r(t) \frac{\partial \mathbf{W}_{ij}}{\partial t}$;
-3. MSTDP-ET -- the third and final model that uses an eligiblity trace -- will
- produce a modulated update to the synapse at each time step via:
- $\Delta \mathbf{W}_{ij}(t) = \gamma r(t) \mathbf{E}_{ij}(t)$.
-Note that $r(t)$ is the reward administered at each time step `t` and $\gamma$
-is just an additional dampening factor to control how much of the STDP update
-is applied at each time step (i.e., a global learning rate).
-
-Armed with our knowledge of the plasticity dynamics above, we next write down
-what we want our model simulations to do:
+where $i$ denotes the index of the post-synaptic spiking neuron (which emits a spike we label as $f_i$) and $j$ denotes the index of the pre-synaptic spiking neuron (which emits a spike we label as $f_j$), $\mathbf{W}_{ij}$ is the synapse that connects neuron $j$ to $i$, $\mathbf{E}_{ij}$ is the eligibility trace we maintain for synapse $\mathbf{W}_{ij}$, and $\beta$ is a control factor (typically set to one) for scaling the magnitude of the STDP update's effect.
+Finally, note that $\frac{\partial \mathbf{W}_{ij}}{\partial t}$ is the actual synaptic update produced by our trace-based STDP at time $t$.
+
+Given the idea of the eligibility trace explained above, as well as how our RPE cell has been configured, we can simply write down the kind of synaptic update $\Delta \mathbf{W}_{ij}(t)$ that each of our three dynamical systems will yield once we simulate them (a short plain-JAX sketch of these rules follows the list below).
+1. Trace-based STDP will produce an update to the synapse according to the combined products of a paired pre-synaptic trace and post-synaptic spike (long-term potentiation) and a paired pre-synaptic spike and post-synaptic trace (long-term depression), i.e., $\Delta \mathbf{W}_{ij}(t) = \gamma \frac{\partial \mathbf{W}_{ij}}{\partial t}$;
+2. MSTDP -- the second/middle model with the `MSTDPETSynapse` with its `tau_elg = 0` -- will produce a modulated update to the synapse at each time step as follows: $\Delta \mathbf{W}_{ij}(t) = \gamma r(t) \frac{\partial \mathbf{W}_{ij}}{\partial t}$;
+3. MSTDP-ET -- the third and final model that uses an eligibility trace -- will produce a modulated update to the synapse at each time step via: $\Delta \mathbf{W}_{ij}(t) = \gamma r(t) \mathbf{E}_{ij}(t)$. Note that $r(t)$ is the reward administered at each time step `t` and $\gamma$ is just an additional dampening factor to control how much of the STDP update is applied at each time step (i.e., a global learning rate).
+
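To make the three rules above concrete, here is a minimal plain-JAX sketch (illustrative only; the in-built `TraceSTDPSynapse` and `MSTDPETSynapse` components carry out these computations internally) of the Euler-integrated eligibility trace and the three resulting weight increments, using the symbols and meta-parameters defined above:

```python
from jax import numpy as jnp

dt, tau_elg, gamma = 1., 25., 0.2  ## same meta-parameters as in the model above

def plasticity_step(E, dW_stdp, r_t):
    """Illustrative sketch (not the components' literal internals): one step of the
    eligibility trace ODE plus the three update rules described above."""
    ## Euler step of: tau_elg dE/dt = -E + beta * dW/dt  (with beta = 1)
    E = E + (-E + dW_stdp) * (dt / tau_elg)
    dW_trace_stdp = gamma * dW_stdp       ## 1) trace-based STDP
    dW_mstdp = gamma * r_t * dW_stdp      ## 2) MSTDP: reward-modulated STDP
    dW_mstdpet = gamma * r_t * E          ## 3) MSTDP-ET: reward-modulated eligibility trace
    return E, dW_trace_stdp, dW_mstdp, dW_mstdpet

E = jnp.zeros((1, 1))  ## eligibility trace for our single synapse
E, dW1, dW2, dW3 = plasticity_step(E, dW_stdp=jnp.ones((1, 1)) * 0.05, r_t=1.)
```
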
+Armed with our knowledge of the plasticity dynamics above, we next write down what we want our model simulations to do:
```python
# synthetic spike times of pre and post synaptic neurons
@@ -208,7 +137,7 @@ elg_vals = []
W_stdp_vals = []
W_mstdp_vals = []
W_mstdpet_vals = []
-model.reset()
+reset_process.run()
for i in range(T_max):
f_j = jnp.zeros((1, 1)) ## pre-syn spike
if (i * dt) in spike_times_pre:
@@ -222,33 +151,29 @@ for i in range(T_max):
reward = -reward
rpe.reward.set(reward)
- model.clamp_spikes(f_j, f_i) ## clamp pre/post spikes to traces
- model.advance(t=i * dt, dt=dt)
+ clamp_spikes(f_j, f_i) ## clamp pre/post spikes to traces
+ advance_process.run(t=i * dt, dt=dt)
- model.clamp_stdp_stats(f_j, f_i, tr0.trace.value, tr1.trace.value)
- model.clamp_mstdp_stats(
- f_j, f_i, tr0.trace.value, tr1.trace.value, rpe.reward.value)
- model.clamp_mstdpet_stats(
- f_j, f_i, tr0.trace.value, tr1.trace.value, rpe.reward.value)
- model.evolve(t=i * dt, dt=dt)
+ clamp_stdp_stats(f_j, f_i, tr0.trace.get(), tr1.trace.get())
+ clamp_mstdp_stats(f_j, f_i, tr0.trace.get(), tr1.trace.get(), rpe.reward.get())
+ clamp_mstdpet_stats(f_j, f_i, tr0.trace.get(), tr1.trace.get(), rpe.reward.get())
+ evolve_process.run(t=i * dt, dt=dt)
## record statistics for plotting
pre_spikes.append(jnp.squeeze(f_j))
post_spikes.append(jnp.squeeze(f_i))
r_vals.append(jnp.squeeze(reward))
- tr0_vals.append(jnp.squeeze(tr0.trace.value))
- tr1_vals.append(jnp.squeeze(-tr1.trace.value))
- dWstdp_vals.append(jnp.squeeze(W_stdp.dWeights.value))
- elg_vals.append(jnp.squeeze(W_mstdpet.eligibility.value))
- W_stdp_vals.append(jnp.squeeze(W_stdp.weights.value))
- W_mstdp_vals.append(jnp.squeeze(W_mstdp.weights.value))
- W_mstdpet_vals.append(jnp.squeeze(W_mstdpet.weights.value))
+ tr0_vals.append(jnp.squeeze(tr0.trace.get()))
+ tr1_vals.append(jnp.squeeze(-tr1.trace.get()))
+ dWstdp_vals.append(jnp.squeeze(W_stdp.dWeights.get()))
+ elg_vals.append(jnp.squeeze(W_mstdpet.eligibility.get()))
+ W_stdp_vals.append(jnp.squeeze(W_stdp.weights.get()))
+ W_mstdp_vals.append(jnp.squeeze(W_mstdp.weights.get()))
+ W_mstdpet_vals.append(jnp.squeeze(W_mstdpet.weights.get()))
t_vals.append(i * dt)
```
-which will run all three models simultaneously for `200` simulated milliseconds
-and collect statistics of interest. We may then finally make several plots of what happens under each STDP mode
-(reproducing some key results in [1]. First, we will plot the resulting synaptic magnitude over time, like so:
+which will run all three of our models simultaneously for `200` simulated milliseconds and collect statistics of interest. We may then finally make several plots of what happens under each STDP mode (reproducing some key results in [1]). First, we will plot the resulting synaptic magnitude over time, like so:
```python
import matplotlib.pyplot as plt
@@ -282,23 +207,13 @@ ax3.grid()
fig1.savefig("modstdp_syn_dynamics.jpg")
```
-which should produce a plot like the one below:
+which should produce a plot like the one below:
-Notice, first, that the middle plot for MSTDP (the red curve in the middle plot)
-essentially mimics the update produced by STDP for the first `100` ms and then
-flips (becomes a mirror image) of the STDP trajectory; this is due to the fact
-that, as you can see in the code you wrote earlier for the spike train simulation,
-the reward signal changes sign after `100` ms and since MSTDP is effectively
-the product of the reward and an STDP synaptic update the sign of the synaptic
-change will flip as well. Finally, notice that the MSTDP-ET yields a
-smoothened change in synaptic efficacy (the blue curve in the bottom plot);
-this is due to the eligibility trace leakily integrating the STDP updates
-over time (and ultimately multiplying the trace by the reward at time `t`).
+Notice, first, that MSTDP (the red curve in the middle plot) essentially mimics the update produced by STDP for the first `100` ms and then flips (becomes a mirror image of the STDP trajectory); this is because, as you can see in the code you wrote earlier for the spike train simulation, the reward signal changes sign after `100` ms and, since MSTDP is effectively the product of the reward and an STDP synaptic update, the sign of the synaptic change flips as well. Finally, notice that MSTDP-ET yields a smoothed change in synaptic efficacy (the blue curve in the bottom plot); this is due to the eligibility trace leakily integrating the STDP updates over time (and ultimately multiplying the trace by the reward at time `t`).
-We will then plot the dynamics of important compartments the drive the operation
-of the various STDP models with the following code block:
+We will then plot the dynamics of the important compartments that drive the operation of the various STDP models with the following code block:
```python
## create STDP synaptic dynamics plots (figure 1)
@@ -352,27 +267,8 @@ which should yield the following component dynamics plot:
-This plot usefully breaks down the plasticity dynamics of all three STDP models
-into the core component dynamics. The top two plots illustrate the emissions
-of spikes over time (the pre-synaptic spike plot followed by the post-synaptic
-spike plot) while underneath these -- the third plot -- is a visualization
-of these spikes respective traces multiplied by their corresponding sign
-that is used in STDP, i.e., the blue pre-synaptic curve is positive as it
-represents synaptic potentiation over time (pre occurs before post) while the
-orange post-synaptic curve is negative as it represents synaptic depression
-over time (post occurs after pre). The teal curve in the fourth plot
-illustrates what kind of updates that typical trace-based STDP would produce,
-in the absence of a reward signal, whereas the yellow-ish/goldenrod curve
-underneath shows the eligibilty trace that smoothens out the more pulse-like
-adjustments that STDP yields. In the very bottom plot, we see the red piecewise
-function that characterizes our reward signal -- for the first `100` ms it
-is simply one whereas for the last `200` ms it is negative one. In general,
-one will not likely have access to a clean dense reward in most control
-problems, i.e., the reward signal is typically sparse, which will mean that
-modulated STDP updates will only occur when the signal is non-zero; this is
-the advantage that MSTDP-ET offers over MSTDP as the synaptic change
-dynamics persist (yet decay) in between reward presentation times and thus
-MSTDP-ET will be more effective in cases when the reward signal is delayed.
+This plot usefully breaks down the plasticity dynamics of all three STDP models into their core component dynamics. The top two plots illustrate the emissions of spikes over time (the pre-synaptic spike plot followed by the post-synaptic spike plot) while underneath these -- the third plot -- is a visualization of these spikes' respective traces multiplied by their corresponding sign as used in STDP, i.e., the blue pre-synaptic curve is positive as it represents synaptic potentiation over time (pre occurs before post) while the orange post-synaptic curve is negative as it represents synaptic depression over time (post occurs after pre). The teal curve in the fourth plot illustrates the kind of updates that typical trace-based STDP would produce, in the absence of a reward signal, whereas the yellow-ish/goldenrod curve underneath shows the eligibility trace that smooths out the more pulse-like adjustments that STDP yields. In the very bottom plot, we see the red piecewise function that characterizes our reward signal -- for the first `100` ms it is simply one whereas for the last `100` ms it is negative one. In general, one will not likely have access to a clean, dense reward in most control problems, i.e., the reward signal is typically sparse, which means that modulated STDP updates will only occur when the signal is non-zero; this is the advantage that MSTDP-ET offers over MSTDP as the synaptic change dynamics persist (yet decay) in between reward presentation times and, thus, MSTDP-ET will be more effective in cases where the reward signal is delayed.
## References
diff --git a/docs/tutorials/neurocog/plotting.md b/docs/tutorials/neurocog/plotting.md
index b105b06f..f48845b2 100644
--- a/docs/tutorials/neurocog/plotting.md
+++ b/docs/tutorials/neurocog/plotting.md
@@ -1,24 +1,12 @@
# Plotting and Visualization
-While writing one's own custom task-specific matplotlib visualization code
-might be needed for specific experimental setups, there are several useful tools
-already in-built to ngc-learn, organized under the package sub-directory
-`ngclearn.utils.viz`, including utilities for generating raster plots and
-synaptic receptive field views (useful for biophysical models such as spiking
-neural networks) as well as t-SNE plots of model latent codes. While the other
-lesson/tutorials demonstrate some of these useful routines (e.g., raster plots
-for spiking neuronal cells), in this small lesson, we will demonstrate how to
-produce a t-SNE plot using ngc-learn's in-built tool.
+While writing one's own custom task-specific matplotlib visualization code might be needed for specific experimental setups, there are several useful tools already in-built to ngc-learn, organized under the package sub-directory `ngclearn.utils.viz`, including utilities for generating raster plots and synaptic receptive field views (useful for biophysical models such as spiking neural networks) as well as t-SNE plots of model latent codes. While the other lessons/tutorials demonstrate some of these useful routines (e.g., raster plots for spiking neuronal cells), in this small lesson, we will demonstrate how to produce a t-SNE plot using ngc-learn's in-built tool.
## Generating a t-SNE Plot
-Let's say you have a labeled five-dimensional (5D) dataset -- which we will
-synthesize artificially in this lesson from an "unobserved" trio of multivariate
-Gaussians -- and wanted to visualize these "model outputs" and their
-corresponding labels in 2D via ngc-learn's in-built t-SNE.
+Let's say you have a labeled five-dimensional (5D) dataset -- which we will artificially synthesize in this lesson from an "unobserved" trio of multivariate Gaussians -- and that you wanted to visualize these "model outputs" and their corresponding labels in 2D via ngc-learn's in-built t-SNE.
-The following bit of Python code will do this for you (including the artificial
-data generator):
+The following bit of Python code will do this for you (including setting up the data generator):
```python
from jax import numpy as jnp, random
@@ -26,7 +14,7 @@ from ngclearn.utils.viz.dim_reduce import extract_tsne_latents, plot_latents
dkey = random.PRNGKey(1234)
-def gen_data(dkey, N): ## artificial data generator (or proxy model)
+def gen_data(dkey, N): ## data generator (or proxy stochastic data generating process)
mu1 = jnp.asarray([[2.1, 3.2, 0.6, -4., -2.]])
cov1 = jnp.eye(5) * 0.78
mu2 = jnp.asarray([[-1.8, 0.2, -0.1, 1.99, 1.56]])
@@ -59,6 +47,4 @@ which should produce a plot, i.e., `codes.jpg`, similar to the one below:
-In this example scenario, we see that we can successfully map the 5D model output
-data to a plottable 2D space, facilitating some level of downstream qualitative
-interpretation of the model.
+In this example scenario, we see that we can successfully map the 5D model output data to a plottable 2D space, facilitating some level of downstream qualitative interpretation of the model.
diff --git a/docs/tutorials/neurocog/rate_cell.md b/docs/tutorials/neurocog/rate_cell.md
index 56afb789..f6554116 100644
--- a/docs/tutorials/neurocog/rate_cell.md
+++ b/docs/tutorials/neurocog/rate_cell.md
@@ -1,7 +1,6 @@
# Lecture 3A: The Rate Cell Model
-Graded neurons are one of the main classes/collections of cell components in ngc-learn. These specifically offer cell models that operate under real-valued dynamics -- in other words, they do not spike or use discrete pulse-like values in their operation. These are useful for building biophysical systems that evolve under continuous, time-varying dynamics, e.g., continuous-time recurrent neural networks, various kinds of predictive coding circuit models, as well as for continuous components in discrete systems, e.g. electrical
-current differential equations in spiking networks.
+Graded neurons are one of the main classes/collections of cell components in ngc-learn. These specifically offer cell models that operate under real-valued dynamics -- in other words, they do not spike or use discrete pulse-like values in their operation. These are useful for building biophysical systems that evolve under continuous, time-varying dynamics, e.g., continuous-time recurrent neural networks, various kinds of predictive coding circuit models, as well as for continuous components in discrete systems, e.g. electrical current differential equations in spiking networks.
In this tutorial, we will study one of ngc-learn's workhorse in-built graded cell components, the rate cell ([RateCell](ngclearn.components.neurons.graded.rateCell)).
@@ -9,15 +8,12 @@ In this tutorial, we will study one of ngc-learn's workhorse in-built graded cel
### Instantiating the Rate Cell
-Let's go ahead and set up the controller for this lesson's simulation,
-where we will a dynamical system with only a single component,
-specifically the rate-cell (RateCell). Let's start with the file's header
-(or import statements):
+Let's go ahead and set up the controller for this lesson's simulation, where we will build a dynamical system with only a single component, specifically the rate-cell (RateCell). Let's start with the file's header (or import statements):
```python
from jax import numpy as jnp, random, jit
-from ngclearn.utils import JaxProcess
-from ngcsimlib.context import Context
+
+from ngclearn import Context, MethodProcess
## import model-specific elements
from ngclearn.components.neurons.graded.rateCell import RateCell
```
@@ -36,91 +32,67 @@ gamma = 1.
with Context("Model") as model: ## model/simulation definition
## instantiate components (like cells)
- cell = RateCell("z0", n_units=1, tau_m=tau_m, act_fx=act_fx,
- prior=("gaussian", gamma), integration_type="euler", key=subkeys[0])
+ cell = RateCell(
+ "z0", n_units=1, tau_m=tau_m, act_fx=act_fx, prior=("gaussian", gamma), integration_type="euler",
+ key=subkeys[0]
+ )
## instantiate desired core commands that drive the simulation
- advance_process = (JaxProcess()
+ advance_process = (MethodProcess("advance")
>> cell.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance")
-
- reset_process = (JaxProcess()
+ reset_process = (MethodProcess("reset")
>> cell.reset)
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
- ## instantiate some non-jitted dynamic utility commands
- @Context.dynamicCommand
- def clamp(x):
- cell.j.set(x)
+## instantiate utility commands
+def clamp(x):
+ cell.j.set(x)
```
-A notable argument to the rate-cell, beyond some of its differential equation
-constants (`tau_m` and `gamma`), is its activation function choice (default is
-the `identity`), which we have chosen to be a discrete pulse emitting function
-known as the `unit_threshold` (which outputs a value of one for any input that
-exceeds the threshold of one and zero for anything else).
+A notable argument to the rate-cell, beyond some of its differential equation constants (`tau_m` and `gamma`), is its activation function choice (default is the `identity`), which we have chosen to be a discrete pulse emitting function known as the `unit_threshold` (which outputs a value of one for any input that exceeds the threshold of one and zero for anything else).
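
For intuition, the behavior described above amounts to something like the following thresholding function (a sketch only; ngc-learn's actual `unit_threshold` implementation may differ in detail):

```python
from jax import numpy as jnp

def unit_threshold(x):  ## sketch of the described behavior, not the library routine itself
    return (x > 1.).astype(jnp.float32)  ## 1 for any input exceeding a threshold of one, else 0

print(unit_threshold(jnp.asarray([[0.5, 1.006]])))  ## => [[0. 1.]]
```
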
-Mathematically, under the hood, a rate-cell evolves according to the
-ordinary differential equation (ODE):
+Mathematically, under the hood, a rate-cell evolves according to the ordinary differential equation (ODE):
$$
\tau_m \frac{\partial \mathbf{z}}{\partial t} =
-\gamma \text{prior}\big(\mathbf{z}\big) + (\mathbf{x} + \mathbf{x}_{td})
$$
-where $\mathbf{x}$ is external input signal and $\mathbf{x}_{td}$ (default
-value is zero) is an optional additional input pressure signal (`td` stands for "top-down",
-its name motivated by predictive coding literature).
-A good way to understand this equation is in the context of two examples:
-1. in a biophysically more realistic spiking network, $\mathbf{x}$ is the
-total electrical input into the cell from multiple injections produced
-by transmission across synapses ($\mathbf{x}_{td} = 0$)) and the $\text{prior}$
-is set to `gaussian` ($\gamma = 1$), yielding the equation
-$\tau_m \frac{\partial \mathbf{z}}{\partial t} = -\mathbf{z} + \mathbf{x}$ for
-a simple model of synaptic conductance, and
-2. in a predictive coding circuit, $\mathbf{x}$ is the sum of input projections
-(or messages) passed from a "lower" layer/group of neurons while $\mathbf{x}_{td}$
-is set to be the sum of (top-down) pressures produced by an "upper" layer/group
-such as the value of a pair of nearby error neurons multiplied by $-1$.[^1] In
-this example, $0 \leq \gamma \leq 1$ and $\text{prior}$ could be set to one
-of any kind of kurtotic distribution to induce a soft form of sparsity in
-the dynamics, e.g., such as "cauchy" for the Cauchy distribution.
+where $\mathbf{x}$ is the external input signal and $\mathbf{x}_{td}$ (default value is zero) is an optional additional input pressure signal (`td` stands for "top-down", its name motivated by the predictive coding literature).
+A good way to understand this equation is in the context of two examples:
+1. in a biophysically more realistic spiking network, $\mathbf{x}$ is the total electrical input into the cell from multiple injections produced by transmission across synapses ($\mathbf{x}_{td} = 0$) and the $\text{prior}$ is set to `gaussian` ($\gamma = 1$), yielding the equation $\tau_m \frac{\partial \mathbf{z}}{\partial t} = -\mathbf{z} + \mathbf{x}$ for a simple model of synaptic conductance (a short hand-rolled sketch of this special case follows this list), and
+2. in a predictive coding circuit, $\mathbf{x}$ is the sum of input projections (or messages) passed from a "lower" layer/group of neurons while $\mathbf{x}_{td}$ is set to be the sum of (top-down) pressures produced by an "upper" layer/group such as the value of a pair of nearby error neurons multiplied by $-1$.[^1] In this example, $0 \leq \gamma \leq 1$ and $\text{prior}$ could be set to any kind of kurtotic distribution to induce a soft form of sparsity in the dynamics, e.g., "cauchy" for the Cauchy distribution.
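
To see the first example's special case play out numerically, here is a short hand-rolled Euler step of the ODE above (a sketch for intuition only; the `RateCell` component performs this integration for you), assuming a `gaussian` prior, $\gamma = 1$, $\mathbf{x}_{td} = 0$, and illustrative values for `tau_m` and `dt`:

```python
from jax import numpy as jnp

tau_m, dt = 10., 1.  ## illustrative constants (ms)

def euler_rate_step(z, x, x_td=0.):
    """One Euler step of: tau_m dz/dt = -z + (x + x_td), i.e., the gaussian-prior,
    gamma = 1 special case above (sketch only)."""
    dz_dt = -z + (x + x_td)
    return z + dz_dt * (dt / tau_m)

z = jnp.zeros((1, 1))
for _ in range(50):  ## clamp a constant input of 1 and integrate
    z = euler_rate_step(z, x=jnp.ones((1, 1)))
print(z)  ## z leaks toward the clamped input value of 1 over time
```
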
### Simulating a Rate Cell
-Given our single rate-cell dynamical system above, let us write some code to use
-our `Rate` node and visualize its dynamics by feeding
-into it a pulse current (a piecewise input function that is an alternating
-sequence of intervals of where nothing is input and others where a non-zero
-value is input) for a small period of time (`dt * T = 1 * 210` ms). Specifically,
-we can plot the input current, the neuron's linear rate activity `z` and its
-nonlinear activity `phi(z)` as follows:
+Given our single rate-cell dynamical system above, let us write some code to use our `RateCell` node and visualize its dynamics by feeding into it a pulse current (a piecewise input function that is an alternating sequence of intervals where nothing is input and intervals where a non-zero value is input) for a small period of time (`dt * T = 1 * 210` ms). Specifically, we can plot the input current, the neuron's linear rate activity `z`, and its nonlinear activity `phi(z)` as follows:
```python
# create a synthetic electrical pulse current
-current = jnp.concatenate((jnp.zeros((1,10)),
- jnp.ones((1,50)) * 1.006,
- jnp.zeros((1,50)),
- jnp.ones((1,50)) * 1.006,
- jnp.zeros((1,50))), axis=1)
+current = jnp.concatenate(
+ (jnp.zeros((1,10)),
+ jnp.ones((1,50)) * 1.006,
+ jnp.zeros((1,50)),
+ jnp.ones((1,50)) * 1.006,
+ jnp.zeros((1,50))), axis=1
+)
lin_out = []
nonlin_out = []
t_values = []
-model.reset()
+reset_process.run()
t = 0.
for ts in range(current.shape[1]):
j_t = jnp.expand_dims(current[0,ts], axis=0) ## get data at time ts
- model.clamp(j_t)
- model.advance(t=ts*1., dt=dt)
+ clamp(j_t)
+ advance_process.run(t=ts*1., dt=dt)
t_values.append(t)
- t += dt
+ t += dt ## advance time forward by dt milliseconds
## naively extract simple statistics at time ts and print them to I/O
- linear_z = cell.z.value
- nonlinear_z = cell.zF.value
+ linear_z = cell.z.get()
+ nonlinear_z = cell.zF.get()
lin_out.append(linear_z)
nonlin_out.append(nonlinear_z)
print("\r {}: s {} ; v {}".format(ts, linear_z, nonlinear_z), end="")
@@ -148,10 +120,11 @@ ax.grid()
fig.savefig("rate_cell_integration.jpg")
```
-which should yield a dynamics plot similar to the one below:
+which should yield a dynamics plot similar to the one below:
+
[^1]: [Error neurons](ngclearn.components.neurons.graded.gaussianErrorCell)
produce this kind of "top-down" value, which is technically the first derivative
diff --git a/docs/tutorials/neurocog/short_term_plasticity.md b/docs/tutorials/neurocog/short_term_plasticity.md
index b225f3c5..c669c5bb 100755
--- a/docs/tutorials/neurocog/short_term_plasticity.md
+++ b/docs/tutorials/neurocog/short_term_plasticity.md
@@ -1,69 +1,28 @@
# Lecture 4E: Short-Term Plasticity
-In this lesson, we will study how short-term plasticity (STP) [1] dynamics
--- where synaptic efficacy is cast in terms of the history of presynaptic activity --
-using ngc-learn's in-built `STPDenseSynapse`.
-Specifically, we will study how a dynamic synapse may be constructed and
-examine what short-term depression (STD) and short-term facilitation
-(STF) dominated configurations of an STP synapse look like.
+In this lesson, we will study short-term plasticity (STP) [1] dynamics -- where synaptic efficacy is cast in terms of the history of pre-synaptic activity -- using ngc-learn's in-built `STPDenseSynapse`. Specifically, we will study how a dynamic synapse may be constructed and examine what short-term depression (STD) and short-term facilitation (STF) dominated configurations of an STP synapse look like.
## Probing Short-Term Plasticity
-Go ahead and make a new folder for this study and create a Python script,
-i.e., `run_shortterm_plasticity.py`, to write your code for this part of the
-tutorial.
+Go ahead and make a new folder for this study and create a Python script, i.e., `run_shortterm_plasticity.py`, to write your code for this part of the tutorial.
-We will write a 3-component dynamical system that connects a Poisson input
-encoding cell to a leaky integrate-and-fire (LIF) cell via a single dynamic
-synapse that evolves according to STP. We will first write our
-simulation of this dynamic synapse from the perspective of STF-dominated
-dynamics, plotting out the results under two different Poisson spike trains
-with different spiking frequencies. Then, we will modify our simulation
-to emulate dynamics from an STD-dominated perspective.
+We will write a 3-component dynamical system that connects a Poisson input encoding cell to a leaky integrate-and-fire (LIF) cell via a single dynamic synapse that evolves according to STP. We will first write our simulation of this dynamic synapse from the perspective of STF-dominated dynamics, plotting out the results under two different Poisson spike trains with different spiking frequencies. Then, we will modify our simulation to emulate dynamics from an STD-dominated perspective.
### Starting with Facilitation-Dominated Dynamics
-One experimental goal with using a "dynamic synapse" [1] is often to computationally
-model the fact that synaptic efficacy (strength/conductance magnitude) is
-not a fixed quantity -- even in cases where long-term adaptation/learning is
-absent -- and instead a time-varying property that depends on a fixed
-quantity of biophysical resources. Specifically, biological neuronal networks,
-synaptic signaling (or communication of information across synaptic connection
-pathways) consumes some quantity of neurotransmitters -- STF results from an
-influx of calcium into an axon terminal of a pre-synaptic neuron (after
-emission of a spike pulse) whereas STD occurs after a depletion of
-neurotransmitters that is consumed by the act of synaptic signaling at the axon
-terminal of a pre-synaptic neuron. Studies of cortical neuronal regions have
-empirically found that some areas are STD-dominated, STF-dominated, or exhibit
-some mixture of the two.
-
-Ultimately, the above means that, in the context of spiking cells, when a
-pre-synaptic neuron emits a pulse, this act will affect the relative magnitude
-of the synapse's efficacy. In some cases, this will result in an increase
-(facilitation) and, in others, this will result in a decrease (depression)
-that lasts over a short period of time (several hundreds to thousands of
-milliseconds in many instances).
-As a result of considering synapses to have a dynamic nature to them, both over
-short and long time-scales, plasticity can now be thought of as a stimulus and
-resource-dependent quantity, reflecting an important biophysical aspect that
-affects how neuronal systems adapt and generalize given different kinds of
-sensory stimuli.
-
-Writing our STP dynamic synapse can be done by importing
-[STPDenseSynapse](ngclearn.components.synapses.STPDenseSynapse)
-from ngc-learn's in-built components and using it to wire the output
-spike compartment of the `PoissonCell` to the input electrical current
-compartment of the `LIFCell`. This can be done as follows (using the
-meta-parameters we provide in the code block below to ensure
-STF-dominated dynamics):
+One experimental goal with using a "dynamic synapse" [1] is often to computationally model the fact that synaptic efficacy (strength/conductance magnitude) is not a fixed quantity -- even in cases where long-term adaptation/learning is absent -- but is instead a time-varying property that depends on a fixed quantity of biophysical resources. Specifically, in biological neuronal networks, synaptic signaling (or the communication of information across synaptic connection pathways) consumes some quantity of neurotransmitters -- STF results from an influx of calcium into an axon terminal of a pre-synaptic neuron (after emission of a spike pulse) whereas STD occurs after the pool of neurotransmitters has been depleted by the act of synaptic signaling at the axon terminal of a pre-synaptic neuron. Studies of cortical neuronal regions have empirically found that some areas are STD-dominated, STF-dominated, or exhibit some mixture of the two.
+
+Ultimately, the above means that, in the context of spiking cells, when a pre-synaptic neuron emits a pulse, this act will affect the relative magnitude of the synapse's efficacy. In some cases, this will result in an increase (facilitation) and, in others, this will result in a decrease (depression) that lasts over a short period of time (several hundreds to thousands of milliseconds in many instances). As a result of considering synapses to have a dynamic nature to them, both over short and long time-scales, plasticity can now be thought of as a stimulus and resource-dependent quantity, reflecting an important biophysical aspect that affects how neuronal systems adapt and generalize given different kinds of sensory stimuli.
+
+Writing our STP dynamic synapse can be done by importing [STPDenseSynapse](ngclearn.components.synapses.STPDenseSynapse) from ngc-learn's in-built components and using it to wire the output spike compartment of the `PoissonCell` to the input electrical current compartment of the `LIFCell`. This can be done as follows (using the meta-parameters we provide in the code block below to ensure STF-dominated dynamics):
```python
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
-from ngclearn.utils import JaxProcess
+
+from ngclearn import Context, MethodProcess
## import model-specific mechanisms
from ngclearn.components import PoissonCell, STPDenseSynapse, LIFCell
-import ngclearn.utils.weight_distribution as dist
+from ngclearn.utils.distribution_generator import DistributionGenerator
## create seeding keys (JAX-style)
dkey = random.PRNGKey(231)
@@ -88,49 +47,40 @@ plot_fname = "{}Hz_stp_{}.jpg".format(firing_rate_e, tag)
with Context("Model") as model:
W = STPDenseSynapse(
- "W", shape=(1, 1), weight_init=dist.constant(value=2.5),
- resources_init=dist.constant(value=Rval), tau_f=tau_f, tau_d=tau_d,
- key=subkeys[0]
+ "W", shape=(1, 1), weight_init=DistributionGenerator.constant(value=2.5),
+ resources_init=DistributionGenerator.constant(value=Rval), tau_f=tau_f, tau_d=tau_d, key=subkeys[0]
)
z0 = PoissonCell("z0", n_units=1, target_freq=firing_rate_e, key=subkeys[0])
z1 = LIFCell(
- "z1", n_units=1, tau_m=tau_m, resist_m=(tau_m / dt) * R_m, v_rest=-60.,
- v_reset=-70., thr=-50., tau_theta=0., theta_plus=0., refract_time=0.
+ "z1", n_units=1, tau_m=tau_m, resist_m=(tau_m / dt) * R_m, v_rest=-60., v_reset=-70., thr=-50., tau_theta=0.,
+ theta_plus=0., refract_time=0.
)
- W.inputs << z0.outputs ## z0 -> W
- z1.j << W.outputs ## W -> z1
+ z0.outputs >> W.inputs ## z0 -> W
+ W.outputs >> z1.j ## W -> z1
- advance_process = (JaxProcess()
+ advance_process = (MethodProcess("advance")
>> z0.advance_state
>> W.advance_state
>> z1.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance")
- reset_process = (JaxProcess()
+ reset_process = (MethodProcess("reset")
>> z0.reset
>> z1.reset
>> W.reset)
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
- @Context.dynamicCommand
- def clamp(obs):
- z0.inputs.set(obs)
+## set up some utility functions for the model context
+def clamp(obs):
+ z0.inputs.set(obs)
```
-Notice that the `STPDenseSynapse` has two important time constants to configure;
-`tau_f` ($\tau_f$), the facilitation time constant, and `tau_d` ($\tau_d$), the
-depression time constant. In effect, it is these two constants that you will
-want to set to obtain different desired behavior from this in-built dynamic
-synapse:
+Notice that the `STPDenseSynapse` has two important time constants to configure: `tau_f` ($\tau_f$), the facilitation time constant, and `tau_d` ($\tau_d$), the depression time constant. In effect, it is these two constants that you will want to set to obtain different desired behavior from this in-built dynamic synapse:
1. setting $\tau_f > \tau_d$ will result in STF-dominated behavior; whereas
2. setting $\tau_f < \tau_d$ will produce STD-dominated behavior.
-Note that setting $\tau_d = 0$ will result in short-term depression being turned off
-completely (and $\tau_f = 0$ disables STF).
+Note that setting $\tau_d = 0$ will result in short-term depression being turned off completely (and $\tau_f = 0$ disables STF).
-Formally, given the time constants above the dynamics of the `STPDenseSynapse`
-operate according to the following coupled ordinary differential equations (ODEs):
+Formally, given the time constants above, the dynamics of the `STPDenseSynapse` operate according to the following coupled ordinary differential equations (ODEs):
$$
\tau_f \frac{\partial u_j(t)}{\partial t} &= -u_j(t) + N_R \big(1 - u_j(t)\big) s_j(t) \\
@@ -144,25 +94,11 @@ W^{dyn}_{ij}(t + \Delta t) = \Big( W^{max}_{ij} u_j(t + \Delta t) x_j(t) s_j(t)
+ W^{dyn}_{ij} (1 - s_j(t))
$$
-where $N_R$ represents an increment produced by a pre-synaptic spike $\mathbf{s}_j(t)$
-(and in essence, the neurotransmitter resources available to yield facilitation),
-$W^{max}_{ij}$ denotes the absolute synaptic efficacy (or maximum response
-amplitude of this synapse in the case of a complete release of all
-neurotransmitters; $x_j(t) = u_j(t) = 1$) of the connection between pre-synaptic
-neuron $j$ and post-synaptic neuron $i$, and $W^{dyn}_{ij}(t)$ is the value
-of the dynamic synapse's efficacy at time `t`.
-$\mathbf{x}_j$ is a variable (which lies in the range of $[0,1]$) that indicates
-the fraction of (neurotransmitter) resources available after a depletion of the
-neurotransmitter resource pool. $\mathbf{u}_j$, on the hand,
-represents the neurotransmitter "release probability", or the fraction of available
-resources ready for the dynamic synapse's use.
+where $N_R$ represents an increment produced by a pre-synaptic spike $\mathbf{s}_j(t)$ (and in essence, the neurotransmitter resources available to yield facilitation), $W^{max}_{ij}$ denotes the absolute synaptic efficacy (or maximum response amplitude of this synapse in the case of a complete release of all neurotransmitters; $x_j(t) = u_j(t) = 1$) of the connection between pre-synaptic neuron $j$ and post-synaptic neuron $i$, and $W^{dyn}_{ij}(t)$ is the value of the dynamic synapse's efficacy at time `t`. $\mathbf{x}_j$ is a variable (which lies in the range of $[0,1]$) that indicates the fraction of (neurotransmitter) resources available after a depletion of the neurotransmitter resource pool. $\mathbf{u}_j$, on the other hand, represents the neurotransmitter "release probability", or the fraction of available resources ready for the dynamic synapse's use.
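
As a quick numerical sketch of the facilitation side of these dynamics (illustrative only; the `STPDenseSynapse` computes all of this internally, the constants below are made up, and the depression dynamics for $x_j$ are omitted here by simply holding $x_j$ fixed at one):

```python
from jax import numpy as jnp

dt, tau_f, N_R, W_max = 1., 750., 0.1, 2.5  ## illustrative constants (not the script's exact values)

def stf_step(u, W_dyn, s_pre, x=1.):
    """Sketch of one facilitation step: Euler-integrate the u ODE above and then
    apply the dynamic-efficacy update (x, the resource fraction, is frozen at 1 here)."""
    du_dt = -u + N_R * (1. - u) * s_pre            ## tau_f du/dt = -u + N_R (1 - u) s_j(t)
    u = u + du_dt * (dt / tau_f)
    W_dyn = (W_max * u * x * s_pre) + W_dyn * (1. - s_pre)  ## efficacy only changes on a pre-spike
    return u, W_dyn

u, W_dyn = jnp.zeros((1, 1)), jnp.zeros((1, 1))
u, W_dyn = stf_step(u, W_dyn, s_pre=jnp.ones((1, 1)))  ## a single pre-synaptic spike arrives
```
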
### Simulating and Visualizing STF
-Now that we understand the basics of how an ngc-learn STP synapse works, we can next
-try it out on a simple pre-synaptic Poisson spike train. Writing out the
-simulated input Poisson spike train and our STP model's processing of this
-data can be done as follows:
+Now that we understand the basics of how an ngc-learn STP synapse works, we can next try it out on a simple pre-synaptic Poisson spike train. Writing out the simulated input Poisson spike train and our STP model's processing of this data can be done as follows:
```python
t_vals = []
@@ -170,26 +106,27 @@ u_vals = []
x_vals = []
W_vals = []
num_z1_spikes = 0.
-model.reset()
+reset_process.run()
obs = jnp.asarray([[1.]])
ts = 1.
ptr = 0 # spike time pointer
for i in range(T_max):
- model.clamp(obs)
- model.advance(t=dt * ts, dt=dt)
- u = jnp.squeeze(W.u.value)
- x = jnp.squeeze(W.x.value)
- Wexc = jnp.squeeze(W.Wdyn.value)
- s0 = jnp.squeeze(W.inputs.value)
- s1 = jnp.squeeze(z1.s.value)
+ clamp(obs)
+ advance_process.run(t=dt * ts, dt=dt)
+ u = jnp.squeeze(W.u.get())
+ x = jnp.squeeze(W.x.get())
+ Wexc = jnp.squeeze(W.Wdyn.get())
+ s0 = jnp.squeeze(W.inputs.get())
+ s1 = jnp.squeeze(z1.s.get())
num_z1_spikes = s1 + num_z1_spikes
u_vals.append(u)
x_vals.append(x)
W_vals.append(Wexc)
t_vals.append(ts)
- print("{}| u: {} x: {} W: {} pre: {} post {}".format(ts, u, x, Wexc, s0, s1))
+ print("\r{}| u: {} x: {} W: {} pre: {} post {}".format(ts, u, x, Wexc, s0, s1), end="")
ts += dt
ptr += 1
+print()
print("Number of z1 spikes = ",num_z1_spikes)
u_vals = jnp.squeeze(jnp.asarray(u_vals))
@@ -197,8 +134,7 @@ x_vals = jnp.squeeze(jnp.asarray(x_vals))
t_vals = jnp.squeeze(jnp.asarray(t_vals))
```
-We may then plot out the result of the STF-dominated dynamics we
-simulate above with the following code:
+We may then plot out the result of the STF-dominated dynamics we simulate above with the following code:
```python
import matplotlib.pyplot as plt
@@ -235,13 +171,13 @@ ax2.grid()
fig1.savefig(plot_fname)
```
-Under the `2` Hertz Poisson spike train set up above, the plotting
-code should produce (and save to disk) the following:
+Under the `2` Hertz Poisson spike train set up above, the plotting code should produce (and save to disk) the following:
-Note that, if you change the frequency of the input Poisson spike train to `20`
-Hertz instead, like so:
+where we also observe that about `3` spikes/pulses were emitted by the post-synaptic neuron over the course of this
+simulation.
+Note that, if you change the frequency of the input Poisson spike train to `20` Hertz instead, like so:
```python
firing_rate_e = 20 ## Hz (of Poisson input train)
@@ -251,16 +187,13 @@ and re-run your simulation script, you should obtain the following:
-Notice that increasing the frequency in which the pre-synaptic spikes occur
-results in more volatile dynamics with respect to the effective synaptic
-efficacy over time.
+where we further observe that about `68` spikes/pulses were emitted by the post-synaptic neuron over the course of this
+simulation.
+In general, notice that increasing the frequency at which the pre-synaptic spikes occur results in more volatile dynamics with respect to the effective synaptic efficacy over time.
### Depression-Dominated Dynamics
-With your code above, it's simple to reconfigure the model to emulate
-the opposite of STF dominated dynamics, i.e., short-term depression (STD)
-dominated dynamics.
-Modify your meta-parameter values like so:
+With the code you have written above, it's simple to reconfigure the model to emulate the opposite of STF-dominated dynamics, i.e., short-term depression (STD)-dominated dynamics. To do so, you will need to modify your meta-parameter values like so:
```python
firing_rate_e = 2 ## Hz (of Poisson input train)
@@ -274,17 +207,13 @@ and re-run your script to obtain an output akin to the following:
-Now, modify your meta-parameters one last time to use a higher-frequency
-input spike train, i.e., `firing_rate_e = 20 ## Hz`, to obtain a plot similar
-to the one below:
+which, after running the script, will print out that the post-synaptic neuron spiked about `3` times. Now, modify your meta-parameters one last time to use a higher-frequency input spike train, i.e., `firing_rate_e = 20 ## Hz`, to obtain a plot similar to the one below:
-You have now successfully simulated a dynamic synapse in ngc-learn across
-several different Poisson input train frequencies under both STF and
-STD-dominated regimes. In more complex biophysical models, it could prove useful
-to consider combining STP with other forms of long-term experience-dependent
-forms of synaptic adaptation, such as spike-timing-dependent plasticity.
+where the script will further print out that the post-synaptic neuron spiked only a single time.
+
+You have now successfully simulated a dynamic synapse in ngc-learn across several different Poisson input train frequencies under both STF- and STD-dominated regimes. In more complex biophysical models, it could prove useful to consider combining STP with other forms of long-term, experience-dependent synaptic adaptation, such as [spike-timing-dependent plasticity](stdp.md).
## References
diff --git a/docs/tutorials/neurocog/simple_leaky_integrator.md b/docs/tutorials/neurocog/simple_leaky_integrator.md
index ec8d485e..1aa6643e 100644
--- a/docs/tutorials/neurocog/simple_leaky_integrator.md
+++ b/docs/tutorials/neurocog/simple_leaky_integrator.md
@@ -1,24 +1,18 @@
# Lecture 2A: The Simplified Leaky Integrator Cell
-In this tutorial, we will study one of ngc-learn's (simplest) in-built leaky
-integrator components, the simplified leaky integrate-and-fire (SLIF).
+In this tutorial, we will study one of ngc-learn's (simplest) in-built neuronal cell components, the simplified leaky integrate-and-fire (SLIF).
## Creating and Using a Leaky Integrator
### Instantiating the Leaky Integrate-and-Fire Cell
-With our JSON configuration in place, go ahead and create a Python script,
-i.e., `run_slif.py`, to write your code for this part of the tutorial.
-
-Now let's go ahead and set up the controller/context for this lesson's simulation,
-where we will a dynamical system with only a single component,
-specifically the simplified LIF (sLIF), like so:
+Start by creating a Python script, i.e., `run_slif.py`, to write your code for this part of the tutorial.
+Now let's go ahead and set up the controller/model-context for this lesson's simulation, where we will build a dynamical system with only a single component, specifically the simplified LIF (sLIF). Write code to do this like so:
```python
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
-from ngclearn.utils import JaxProcess
+from ngclearn import Context, MethodProcess
## import model-specific mechanisms
from ngclearn.components.neurons.spiking.sLIFCell import SLIFCell
from ngclearn.utils.viz.spike_plot import plot_spiking_neuron
@@ -36,57 +30,28 @@ tau_m = R_m * C ## membrane time constant
## create simple system with only one sLIF
with Context("Model") as model:
- cell = SLIFCell("z0", n_units=1, tau_m=tau_m, resist_m=R_m, thr=V_thr,
- refract_time=ref_T, key=subkeys[0])
+ cell = SLIFCell("z0", n_units=1, tau_m=tau_m, resist_m=R_m, thr=V_thr, refract_time=ref_T, key=subkeys[0])
## set up core commands that drive the simulation
- advance_process = (JaxProcess()
+ advance_process = (MethodProcess("advance")
>> cell.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance")
- reset_process = (JaxProcess()
+ reset_process = (MethodProcess("reset")
>> cell.reset)
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
- ## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- cell.j.set(x)
+## set up non-compiled utility commands
+def clamp(x):
+ cell.j.set(x)
```
-This node has quite a few compartments and constants but only a handful are important
-for understanding how this model governs spiking/firing rates within its simulation window.
-Specifically, in this lesson, we will focus on
-its electrical current `j` (formally labeled here as $\mathbf{j}_t$),
-its voltage `v` (formally labeled: $\mathbf{v}_t$), its spike emission
-(or action potential) `s` (formally $\mathbf{s}_t$), and its refractory
-variable/marker (formally $\mathbf{r}_t$). The subscript $t$ indicates
-that this compartment variable takes on a certain value at a certain time step
-$t$ and we will refer to the ngc-learn context's integration time constant,
-the amount of time we move forward by, as $\Delta t$. The constants or
-hyper-parameters we will be most interested in are the cell's membrane resistance
-`R_m` (formally $R$ with its capacitance $C$ implied), its membrane time
-constant `tau_m` (formally $\tau_m$), its refractory period time
-`refract_T` (formally $T_{ref}$), and its voltage threshold `v_thr`
-(formally $V_thr$). (There are other constants inherent to the
-sLIF, but these are sufficient for this exercise.)
-
-Later on, towards the end of this tutorial, we provide some theoretical
-exposition and explanation of the above constants/compartments
-(see `On the Dynamics of Leaky Integrators`); for practical
-purposes we will now move on to using your `sLIF` node in a simple simulation
-to illustrate some of its dynamics.
+This node has quite a few compartments and constants but only a handful are important for understanding how this model governs spiking/firing rates during a controller's simulation window. Specifically, in this lesson, we will focus on its electrical current `j` (formally labeled here as $\mathbf{j}_t$), its voltage `v` (formally labeled: $\mathbf{v}_t$), its spike emission (or action potential) `s` (formally $\mathbf{s}_t$), and its refractory variable/marker (formally $\mathbf{r}_t$). The subscript $t$ indicates that this compartment variable takes on a certain value at a certain time step $t$ and we will refer to the ngc-learn controller's integration time constant, the amount of time we move forward by, as $\Delta t$. The constants or hyper-parameters we will be most interested in are the cell's membrane resistance `R_m` (formally $R$ with its capacitance $C$ implied), its membrane time constant `tau_m` (formally $\tau_m$), its refractory period time `refract_T` (formally $T_{ref}$), and its voltage threshold `v_thr` (formally $V_{thr}$). (There are other constants inherent to the sLIF, but these are sufficient for this exercise.)
+
+Later on, towards the end of this tutorial, we provide some theoretical exposition and explanation of the above constants/compartments
+(see `On the Dynamics of Leaky Integrators`); for practical purposes, we will now move on to using your `sLIF` node in a simple simulation to illustrate some of its dynamics.
### Simulating a Leaky Integrator
-
-Given our single-cell dynamical system above, let us write some code to use
-our `sLIF` node and visualize its spiking pattern by feeding
-into it a step current, where the electrical current `j` starts at $0$ then
-switches to $0.3$ at $t = 10$ (ms). Specifically, we can plot the input current,
-the neuron's voltage `v`, and its output spikes as follows:
+Given our single-cell dynamical system above, let us write some code to use our `sLIF` node and visualize its spiking pattern by feeding into it a step current, where the electrical current `j` starts at $0$ then switches to $0.3$ at $t = 10$ (ms). Specifically, we can plot the input current, the neuron's voltage `v`, and its output spikes as follows:
```python
# create a synthetic electrical step current
@@ -96,81 +61,65 @@ curr_in = []
mem_rec = []
spk_rec = []
-model.reset()
+reset_process.run()
for ts in range(current.shape[1]):
j_t = jnp.expand_dims(current[0,ts], axis=0) ## get data at time ts
- model.clamp(j_t)
- model.advance(t=ts*1., dt=dt)
+ clamp(j_t)
+ advance_process.run(t=ts*1., dt=dt)
## naively extract simple statistics at time ts and print them to I/O
- v = cell.v.value
- s = cell.s.value
+ v = cell.v.get()
+ s = cell.s.get()
curr_in.append(j_t)
mem_rec.append(v)
spk_rec.append(s)
- print(" {}: s {} ; v {}".format(ts, s, v), end="")
+ print(f"\r{ts}: s {s} ; v {v}", end="")
print()
import numpy as np
curr_in = np.squeeze(np.asarray(curr_in))
mem_rec = np.squeeze(np.asarray(mem_rec))
spk_rec = np.squeeze(np.asarray(spk_rec))
-plot_spiking_neuron(curr_in, mem_rec, spk_rec, None, dt, thr_line=V_thr, min_mem_val=0.,
- max_mem_val=1.3, title="SLIF-Node: Constant Electrical Input",
- fname="lif_plot.jpg")
+plot_spiking_neuron(
+ curr_in, mem_rec, spk_rec, None, dt, thr_line=V_thr, min_mem_val=0., max_mem_val=1.3,
+ title="SLIF-Node: Constant Electrical Input", fname="lif_plot.jpg"
+)
```
-which produces the following plot (saved as `lif_plot.jpg` locally to disk):
+which produces the following plot (saved as `lif_plot.jpg` locally to disk):
-where we see that, given a build-up over time in the neuron's membrane potential
-(since the current is constant and non-zero after $10$ ms), a spike is emitted
-once the value of the membrane potential exceeds the threshold (indicated by
-the dashed horizontal line in the middle plot) $V_{thr} = 1$.
-Notice that if we play with the value of `ref_T` (the refactory period $T_{ref}$)
-and change it to something like `ref_T = 10 * dt` (ten times the integration time
-constant), we get the following neuronal dynamics plot:
+where we see that, given a build-up over time in the neuron's membrane potential (since the current is constant and non-zero after $10$ ms), a spike is emitted once the value of the membrane potential exceeds the threshold (indicated by the dashed horizontal line in the middle plot) $V_{thr} = 1$. Notice that if we play with the value of `ref_T` (the refractory period $T_{ref}$) and change it to something like `ref_T = 10 * dt` (ten times the integration time constant), we get the following neuronal dynamics plot:
-where we see that after the LIF neuron fires, it remains stuck at its resting
-potential for a period of $0.01$ ms (the short flat periods in the red curve
-starting after the first spike).
+where we see that after the LIF neuron fires, it remains stuck at its resting potential for a period of $0.01$ ms (the short flat periods in the red curve starting after the first spike).
## On the Dynamics of Leaky Integrators
-Now let us unpack this component by first defining the relevant compartments:
+Now let us unpack this component by first defining the relevant compartments:
-+ $\mathbf{j}_t$: the current electrical current of the neurons within this node
- (note that this current could be the summation of multiple step/pointwise
- current sources or be the current sample of an electrical current, itself
- modeled by a differential equation);
++ $\mathbf{j}_t$: the present electrical current of the neurons within this node (note that this current could be the summation of multiple step/pointwise current sources or be the current sample of an electrical current, itself modeled by a differential equation);
+ $\mathbf{v}_t$: the current membrane potential of the neurons within this node;
-+ $\mathbf{s}_t$: the current recording/reading of any spikes produced by this
- node's neurons;
-+ $\mathbf{r}_t$: the current value of the absolute refractory variables - this
- accumulates with time (and forces neurons to rest)
++ $\mathbf{s}_t$: the current recording/reading of any spikes produced by this node's neurons;
++ $\mathbf{r}_t$: the current value of the absolute refractory variables -- this accumulates with time (and forces neurons to rest);
and finally the constants:
-+ $V_{thr}$: threshold that a neuron's membrane potential must overcome before
- a spike is transmitted;
++ $V_{thr}$: threshold that a neuron's membrane potential must overcome before a spike is transmitted;
+ $\Delta t$: the integration time constant, on the order of milliseconds (ms);
+ $R$: the neural (cell) membrane resistance, on the order of mega Ohms ($M \Omega$);
+ $C$: the neural (cell) membrane capacitance, on the order of picofarads ($pF$);
-+ $\tau_{m}$: membrane potential time constant (also $\tau_{m} = R * C$ -
- resistance times capacitance);
++ $\tau_{m}$: membrane potential time constant (also $\tau_{m} = R * C$, or resistance times capacitance);
+ $T_{ref}$: the length of a neuron's absolute refractory period.
-With above defined, we can now explicitly lay out the underlying (linear) ordinary
-differential equation that the `sLIF` evolves according to:
+With the above defined, we can now explicitly lay out the underlying (linear) ordinary differential equation that the `sLIF` evolves according to:
$$
\tau_m \frac{\partial \mathbf{v}_t}{\partial t} = (-\mathbf{v}_t + R \mathbf{j}_t), \; \mbox{where, } \tau_m = R C
$$
-and with some simple mathematical manipulations (leveraging the method of finite differences),
-we can derive the Euler integrator employed by the `sLIF` as follows:
+and with some simple mathematical manipulations (leveraging the method of finite differences), we can derive the Euler integrator employed by the `sLIF` as follows:
$$
\tau_m \frac{\partial \mathbf{v}_t}{\partial t} &= (-\mathbf{v}_t + R \mathbf{j}_t) \\
@@ -178,32 +127,24 @@ $$
\mathbf{v}_{t + \Delta t} &= \mathbf{v}_t + (-\mathbf{v}_t + R \mathbf{j}_t) \frac{\Delta t}{\tau_m }
$$
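+
+To make this integration routine concrete, below is a minimal sketch of the Euler update derived above (written in plain Python; the values chosen for `R`, `tau_m`, `dt`, and the input current are illustrative assumptions of this sketch, not values taken from the `sLIF` component itself):
+
+```python
+## minimal sketch of the Euler integration of the sLIF membrane potential
+R = 5.1        ## membrane resistance (assumed value)
+tau_m = 10.    ## membrane time constant (ms); in the sLIF, tau_m = R * C
+dt = 0.1       ## integration time constant (ms)
+
+v = 0.0        ## membrane potential, starting at rest (0 mV)
+j = 0.3        ## constant input electrical current
+for _ in range(1000):                    ## simulate 100 ms of dynamics
+    v = v + (-v + R * j) * (dt / tau_m)  ## Euler step of tau_m dv/dt = -v + R*j
+print(v)  ## approaches the steady-state value R * j in the absence of spiking/resetting
+```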
-where we see that above integration tells us that the membrane potential of this node varies
-over time as a function of the sum of its input electrical current $\mathbf{j}_t$
-(multiplied by the cell membrane resistance) and a leak (or decay) $-\mathbf{v}_t$
-modulated by the integration time constant divided by the membrane time constant.
-The `sLIF` allows you to control the value of $\tau_m$ directly (hence why we
-calculated $\tau_m$ externally via our chosen $R$ and $C$; other neuronal cells
-allow you to change $\tau_m$ via $R$ and $C$).
-
-
-
+in this walkthrough.)
+-->
-In effect, given the above, every time the `sLIF`'s `.advanceState()` function is
-called within a simulation controller context (`Context()`), the above Euler integration of
-the membrane potential differential equation is happening each time step. Knowing this,
-the last item required to understand ngc-learn's `sLIF` node's computation is
-related to its spike $\mathbf{s}_t$. The spike reading is computed simply by
-comparing the current membrane potential $\mathbf{v}_t$ to the constant threshold
-defined by $V_{thr}$ according to the following piecewise function:
+In effect, given the above, every time the `sLIF`'s `advance_state()` function is called within a simulation context (`Context()`), the above Euler integration of the membrane potential differential equation is carried out for that time step. Knowing this, the last item required to understand ngc-learn's `sLIF` node's computation is related to its spike $\mathbf{s}_t$. The spike reading is computed simply by comparing the current membrane potential $\mathbf{v}_t$ to the constant threshold defined by $V_{thr}$ according to the following piecewise function:
$$
\mathbf{s}_{t, i} = \begin{cases}
@@ -212,52 +153,8 @@ $$
\end{cases}
$$
-where we see that if the $i$th neuron's membrane potential exceeds the threshold
-$V_{thr}$, then a voltage spike is emitted. After a spike is emitted, the $i$th
-neuron within the node needs to be reset to its resting potential and this is done
-with the final compartment that we mentioned, i.e., the refractory
-variable $\mathbf{r}_t$.
-The refractory variable $\mathbf{r}_t$ is important for hyperpolarizing the
-$i$th neuron back to its resting potential (establishing a critical reset mechanism
--- otherwise, the neuron would fire out of control after overcoming its
-threshold) and reducing the amount of spikes generated over time. This reduction
-is one of the key factors behind the power efficiency of biological neuronal systems.
-Another aspect of ngc-learn's refractory variable is the temporal length of the reset itself,
-which is controlled by the $T_{ref}$ (`T_ref`) constant -- this yields what is known as the
-absolute refractory period, or the interval of time at which a second action potential
-absolutely cannot be initiated. If $T_{ref}$ is set to be greater than
-zero, then the $i$th neuron that fires will be forced to remain at its resting
-potential of zero for the duration of this refractory period.
-
-Note that the reason the `sLIF` contains simplified in its name is that its
-internal dynamics and parameterization have been drastically simplified in
-comparison to ngc-learn's more standard `LIF` component. Furthermore, the
-`sLIF` operates assuming a resting membrane potential of `0` (milliVolts) whereas,
-for more intricate leaky integrator models, the resting potential is often
-negative, requiring a different and more careful setting of hyper-parameters
-(such as the voltage threshold). Nevertheless, although `sLIF` is a simpler
-model, it can be used as a rational first step for crafting very useful spiking
-neural networks and offers other aspects of functionality not used in this tutorial,
-such as adaptive threshold functionality and fast approximate lateral inhibition/recurrence.
-
-## Optional: Setting Up The Components with a JSON Configuration
-
-While you are not required to create a JSON configuration file for ngc-learn,
-to get rid of the warning that ngc-learn will throw at the start of your
-program's execution (indicating that you do not have a configuration set up yet),
-all you need to do is create a sub-directory for your JSON configuration
-inside of your project code's directory, i.e., `json_files/modules.json`.
-Inside the JSON file, you would write the following:
-
-```json
-[
- {"absolute_path": "ngclearn.components",
- "attributes": [
- {"name": "SLIFCell"}]
- },
- {"absolute_path": "ngcsimlib.operations",
- "attributes": [
- {"name": "overwrite"}]
- }
-]
-```
+where we see that if the $i$th neuron's membrane potential exceeds the threshold $V_{thr}$, then a voltage spike is emitted. After a spike is emitted, the $i$th neuron within the node needs to be reset to its resting potential and this is done with the final compartment that we mentioned, i.e., the refractory variable $\mathbf{r}_t$.
+The refractory variable $\mathbf{r}_t$ is important for hyperpolarizing the $i$th neuron back to its resting potential (establishing a critical reset mechanism -- otherwise, the neuron would fire out of control after overcoming its threshold) and for reducing the number of spikes generated over time. This reduction is one of the key factors behind the power efficiency of biological neuronal systems. Another aspect of ngc-learn's refractory variable is the temporal length of the reset itself, which is controlled by the $T_{ref}$ (`T_ref`) constant -- this yields what is known as the absolute refractory period, or the interval of time during which a second action potential absolutely cannot be initiated. If $T_{ref}$ is set to be greater than zero, then the $i$th neuron that fires will be forced to remain at its resting potential of zero for the duration of this refractory period.
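+
+As a rough illustration of how the threshold check, reset, and refractory period interact with the Euler step sketched earlier, consider the following sketch (again in plain Python with assumed constants; the actual `sLIF` component implements this logic internally and in vectorized form):
+
+```python
+V_thr, T_ref = 1.0, 1.0       ## spike threshold and absolute refractory period (ms) -- assumed values
+R, tau_m, dt = 5.1, 10., 0.1  ## resistance, membrane time constant, integration step (ms)
+
+v, r = 0.0, 0.0               ## membrane potential and refractory timer
+n_spikes = 0
+for _ in range(1000):                            ## simulate 100 ms of dynamics
+    if r > 0.:                                   ## neuron is refractory: hold it at rest
+        v, r = 0.0, max(r - dt, 0.0)
+        s = 0.0
+    else:
+        v = v + (-v + R * 0.3) * (dt / tau_m)    ## Euler step under a constant current of 0.3
+        s = 1.0 if v > V_thr else 0.0            ## piecewise spike function from above
+        if s > 0.:                               ## spike emitted: reset v and start the refractory clock
+            v, r = 0.0, T_ref
+    n_spikes += s
+print(n_spikes)                                  ## total number of spikes over the window
+```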
+
+Note that the reason that the `sLIF` contains "simplified" in its name is that its internal dynamics and parameterization have been drastically simplified in comparison to ngc-learn's more standard `LIF` component. Furthermore, the `sLIF` operates assuming a resting membrane potential of `0` (millivolts) whereas, for more intricate leaky integrator models, the resting potential is often negative, requiring a different and more careful setting of hyper-parameters (such as the voltage threshold). Nevertheless, although `sLIF` is a simpler model, it can be used as a reasonable first step for crafting very useful spiking neural networks and offers other aspects of functionality not used in this tutorial, such as adaptive threshold functionality and fast approximate lateral inhibition/recurrence.
+
diff --git a/docs/tutorials/neurocog/stdp.md b/docs/tutorials/neurocog/stdp.md
index b8e889a0..ea65cc41 100755
--- a/docs/tutorials/neurocog/stdp.md
+++ b/docs/tutorials/neurocog/stdp.md
@@ -1,40 +1,23 @@
# Lecture 4C: Spike-Timing-Dependent Plasticity
-In the context of spiking neuronal networks, one of the most important forms
-of adaptation that is often simulated is that of spike-timing-dependent
-plasticity (STDP). In this lesson, we will setup and use one
-of ngc-learn's standard in-built STDP-based components, visualizing the
-changes in synaptic efficacy that it produces in the context of
-pre-synaptic and post-synaptic variable traces.
+In the context of spiking neuronal networks, one of the most important forms of adaptation that is often simulated is that of spike-timing-dependent plasticity (STDP). In this lesson, we will set up and use one of ngc-learn's standard in-built STDP-based components, visualizing the changes in synaptic efficacy that it produces in the context of pre-synaptic and post-synaptic variable traces.
## Probing Spike-Timing-Dependent Plasticity
-Go ahead and make a new folder for this study and create a Python script,
-i.e., `run_trstdp.py`, to write your code for this part of the tutorial.
+Go ahead and make a new folder for this study and create a Python script, i.e., `run_trstdp.py`, to write your code for this part of the tutorial.
-Now let's set up the model for this lesson's simulation and construct a
-3-component system made up of two variable traces (`VarTrace`) connected by
-one single synapse that is capable of producing changes in connection strength
-in accordance with STDP, specifically with a form of the update rule known
-as [trace-based STDP](ngclearn.components.synapses.hebbian.traceSTDPSynapse).
-Note that the trace components do not really do
-anything meaningful unless they receive some input and we will provide
-carefully controlled input spike values in order to control their behavior
-so as to see how STDP responds to the relative temporal ordering of a pre- and
-post-synaptic spike, where the time of spikes is approximated by the
-corresponding pre- and post-synaptic traces (which decay exponentially with time
-in the absence of input).
+Now let's set up the model for this lesson's simulation and construct a 3-component system made up of two variable traces (`VarTrace`) connected by a single synapse that is capable of producing changes in connection strength in accordance with STDP, specifically with a form of the update rule known as [trace-based STDP](ngclearn.components.synapses.hebbian.traceSTDPSynapse). Note that the trace components do not really do anything meaningful unless they receive some input. Therefore, we will provide carefully controlled input spike values so as to see how STDP responds to the relative temporal ordering of a pre- and post-synaptic spike, where the timing of the spikes is approximated by the corresponding pre- and post-synaptic traces (which decay exponentially with time in the absence of input).
-Writing the above 3-component system can be in the following manner:
+Writing the above 3-component system can be done in the following manner:
```python
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
-from ngclearn.utils import JaxProcess
+
+from ngclearn import Context, MethodProcess
## import model-specific mechanisms
from ngclearn.components.other.varTrace import VarTrace
from ngclearn.components.synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse
-import ngclearn.utils.weight_distribution as dist
+from ngclearn.utils.distribution_generator import DistributionGenerator
## create seeding keys (JAX-style)
dkey = random.PRNGKey(231)
@@ -46,55 +29,43 @@ T_max = 100 ## number time steps to simulate
with Context("Model") as model:
tr0 = VarTrace("tr0", n_units=1, tau_tr=8., a_delta=1.)
tr1 = VarTrace("tr1", n_units=1, tau_tr=8., a_delta=1.)
- W = TraceSTDPSynapse("W1", shape=(1, 1), eta=0., A_plus=1., A_minus=0.8,
- weight_init=dist.uniform(0.0, 0.3), key=subkeys[0])
+ W = TraceSTDPSynapse(
+ "W1", shape=(1, 1), eta=0., A_plus=1., A_minus=0.8,
+ weight_init=DistributionGenerator.uniform(low=0.0, high=0.3), key=subkeys[0]
+ )
# wire only relevant compartments to synaptic cable W for demo purposes
- W.preTrace << tr0.trace
- # self.W1.preSpike << self.z0.outputs ## we disable this as we will manually
- ## insert a binary value (for a spike)
- W.postTrace << tr1.trace
- # self.W1.postSpike << self.z1e.s ## we disable this as we will manually
- ## insert a binary value (for a spike)
-
- evolve_process = (JaxProcess()
+ tr0.trace >> W.preTrace
+ # self.z0.outputs >> self.W1.preSpike ## we disable this as we will manually
+ ## insert a binary value (for a spike) in this tutorial
+ tr1.trace >> W.postTrace
+ # self.z1e.s >> self.W1.postSpike ## we disable this as we will manually
+ ## insert a binary value (for a spike) in this tutorial
+
+ evolve_synapse = (MethodProcess("evolve")
>> W.evolve)
- model.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
-
- advance_process = (JaxProcess()
- >> tr0.advance_state
- >> tr1.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance_traces")
- reset_process = (JaxProcess()
- >> tr0.reset
- >> tr1.reset
- >> W.reset)
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ advance_traces = (MethodProcess("advance")
+ >> tr0.advance_state
+ >> tr1.advance_state
+ >> W.advance_state)
+ reset = (MethodProcess("reset")
+ >> tr0.reset
+ >> tr1.reset
+ >> W.reset)
- @Context.dynamicCommand
- def clamp_synapse(pre_spk, post_spk):
- W.preSpike.set(pre_spk)
- W.postSpike.set(post_spk)
+## set up some utility functions for the model context
+def clamp_synapse(pre_spk, post_spk):
+ W.preSpike.set(pre_spk)
+ W.postSpike.set(post_spk)
-
- @Context.dynamicCommand
- def clamp_traces(pre_spk, post_spk):
- tr0.inputs.set(pre_spk)
- tr1.inputs.set(post_spk)
+def clamp_traces(pre_spk, post_spk):
+ tr0.inputs.set(pre_spk)
+ tr1.inputs.set(post_spk)
```
-With our carefully constructed STDP-adapted model above, we can then simulate
-the changes to synaptic efficacy that it would produce as a function of
-the distance between and order between a pre- and a post-synaptic binary spike.
-Notice that in the above model, we have set the global learning rate `eta` to
-zero, which will prevent the `TraceSTDPSynapse` from actually adjusting
-its internal matrix of synaptic weight values using the updates produced by
-STDP -- this means our synapses are held fixed throughout this particular
-demonstration. Our goal is to produce an approximation of the theoretical synaptic
-strength adjustment curve dictated by STDP; this can be done using the
-code below:
+With our carefully constructed STDP-adapted model above, we can then simulate the changes to synaptic efficacy that it would produce as a function of the distance and ordering between a pre-synaptic and a post-synaptic binary spike. Notice that in the above model, we have set the global learning rate `eta` to zero, which will prevent the `TraceSTDPSynapse` from actually adjusting its internal matrix of synaptic weight values using the updates produced by STDP -- this means our synapses are held fixed throughout this particular demonstration. Our goal is to produce an approximation of the theoretical synaptic strength adjustment curve dictated by STDP; this can be done using the code below:
```python
t_values = []
@@ -118,16 +89,13 @@ for i in range(T_max+1):
_pre_trig = jnp.ones((1,1))
_post_trig = jnp.zeros((1,1))
ts = 0.
- model.clamp_traces(pre_spk, post_spk)
- model.advance_traces(t=dt * i, dt=dt)
+ clamp_traces(pre_spk, post_spk)
+ advance_traces.run(t=dt * i, dt=dt)
## get STDP update
- W.preSpike.set(_pre_trig)
- W.postSpike.set(_post_trig)
- W.preTrace.set(tr0.trace.value)
- W.postTrace.set(tr1.trace.value)
- model.evolve(t=dt * i, dt=dt)
- dW = W.dWeights.value
+ clamp_synapse(_pre_trig, _post_trig)
+ evolve_synapse.run(t=dt * i, dt=dt)
+ dW = W.dWeights.get()
dW_vals.append(dW)
if i >= int(T_max/2):
t_values.append(ts)
@@ -165,40 +133,12 @@ which should produce a plot similar to the one in the left-hand side below:
+------------------------------------------------------------+----------------------------------------------------------------+
```
-where we have provided a marked-up image of the STDP experimental data produced
-and visualized in the classical work done by Bi and Poo in 1998 [1].
-We remark that our approximate STDP synaptic change curve does not perfectly
-match/fit that of [1] perfectly by any means but does capture the
-general trend and form of the long-term potentiation arc (the roughly
-negative exponential curve to the right-hand side of zero) and the long-term
-depression curve (the flipped exponential-like function to the left-hand
-side of zero). Ultimately, a synaptic component like the `TraceSTDPSynapse`
-can be quite useful for constructing spiking neural network architectures
-that learn in a biologically-plausible fashion since this rule, as seen by the
-above simulation usage, solely depends on information that is locally
-available at the pre-synaptic neuron (its spike and a single trace that
-tracks its temporal spiking history) and the post-synaptic neuron
-(its own spike as well as a trace that tracks its spike history). Notably,
-traced-based STDP can be an effective way of adapting the synapses of
-biophysically more accurate computational models, such as those that balance
-excitatory and inhibitory pressures produced by laterally-wired populations of
-leaky integrator neurons, e.g., the
-[Diehl and Cook spiking architecture](../../museum/snn_dc) we study in the model
-museum in more detail.
-
-### Other Forms of Spike-Timing-Dependent Plasticity
-Finally, beyond trace-based STDP, there are other types of STDP in-built to
-ngc-learn, such as event-driven post-synaptic STDP
-([eventSTDPSynapse](ngclearn.components.synapses.hebbian.eventSTDPSynapse)), that
-you can experiment with and use in your model building and simulation projects.
-You can learn more about these in the ngc-learn
-[modeling API](../../modeling/components.md).
-Beyond this, the ngc-learn dev team is always busy behind the scenes
-constructing more standard computational neuroscience building blocks and
-synaptic plasticity rules; so keep an eye out for future incoming developments!
+Notice that, for the above visual, we have also provided a marked-up image of the STDP experimental data produced and visualized in the classical work done by Bi and Poo in 1998 [1]. We remark that our approximate STDP synaptic change curve does not match/fit that of [1] perfectly by any means; however, it does capture the general trend and form of the long-term potentiation arc (the roughly negative exponential curve to the right-hand side of zero) and the long-term depression curve (the flipped exponential-like function to the left-hand side of zero). Ultimately, a synaptic component like the `TraceSTDPSynapse` can be quite useful for constructing spiking neural network architectures that learn in a biologically-plausible fashion given that this rule, as seen in the above simulation, solely depends on information that is locally available at the pre-synaptic neuron (its spike and a single trace that tracks its temporal spiking history) and the post-synaptic neuron (its own spike as well as a trace that tracks its spike history). Notably, trace-based STDP can be an effective way of adapting the synapses of biophysically more accurate computational models, such as those that balance excitatory and inhibitory pressures produced by laterally-wired populations of leaky integrator neurons, e.g., the [Diehl and Cook spiking architecture](../../museum/snn_dc) that we study in more detail within the context of a model museum exhibit.
+
+### Other Forms of Spike-Timing-Dependent Plasticity
+Finally, beyond trace-based STDP, there are other types of STDP built into ngc-learn, such as event-driven post-synaptic STDP ([eventSTDPSynapse](ngclearn.components.synapses.hebbian.eventSTDPSynapse)), which you can experiment with and use in your model building and simulation projects. You can learn more about these and other related biologically-plausible learning rules in the ngc-learn [modeling API](../../modeling/components.md) (specifically in the "Synapses" subsection page).
+Beyond this, the ngc-learn dev team is always busy behind the scenes constructing more standard computational neuroscience building blocks and synaptic plasticity rules; so keep an eye out for future incoming developments!
## References
-[1] Bi, Guo-qiang, and Mu-ming Poo. "Synaptic modifications in cultured
-hippocampal neurons: dependence on spike timing, synaptic strength, and
-postsynaptic cell type." Journal of neuroscience 18.24 (1998).
+[1] Bi, Guo-qiang, and Mu-ming Poo. "Synaptic modifications in cultured hippocampal neurons: dependence on spike timing, synaptic strength, and postsynaptic cell type." Journal of neuroscience 18.24 (1998).
diff --git a/docs/tutorials/neurocog/traces.md b/docs/tutorials/neurocog/traces.md
index 4d338f65..da038db7 100755
--- a/docs/tutorials/neurocog/traces.md
+++ b/docs/tutorials/neurocog/traces.md
@@ -1,29 +1,17 @@
# Lecture 1B: Trace Variables and Filtering
-Traces represent one very important component tool in ngc-learn as these are
-often, in biophysical model simulations, used to produce real-valued
-representations of often discrete-valued patterns, e.g., spike vectors within
-a spike train, that can facilitate mechanisms such as online biological credit
-assignment. In this lesson, we will observe how one of ngc-learn's core
-trace components -- the `VarTrace` -- operates.
+Traces represent one very important component tool in ngc-learn as they are often used, in biophysical model simulations, to produce real-valued representations of discrete-valued patterns, e.g., spike vectors within a spike train, that can facilitate mechanisms such as online biological credit assignment. In this lesson, we will observe how one of ngc-learn's core trace components -- the `VarTrace` -- operates.
## Setting Up a Variable Trace for a Poisson Spike Train
-To observe the value of a variable trace, we will pair it to another in-built
-ngc-component; the `PoissonCell`, which will be configured to emit spikes
-approximately at `63.75` Hertz (yielding a fairly sparse spike train). This means
-we will construct a two-component dynamical system, where the input
-compartment `outputs` of the `PoissonCell` will be wired directly into the
-`inputs` compartment of the `VarTrace`. Note that a `VarTrace` has an `inputs`
-compartment -- which is where raw signals typically go into -- and a `trace`
-output compartment -- which is where filtered signal values/by-products are emitted from.
+To observe the value of a variable trace, we will pair it with another in-built ngc-learn component: the `PoissonCell`, which will be configured to emit spikes at approximately `63.75` Hertz (yielding a fairly sparse spike train). This means we will construct a two-component dynamical system, where the output compartment `outputs` of the `PoissonCell` will be wired directly into the `inputs` compartment of the `VarTrace`. Note that a `VarTrace` has an `inputs` compartment -- into which raw signals typically flow -- and a `trace` output compartment -- from which filtered signal values/by-products are emitted.
The code below will instantiate the paired Poisson cell and corresponding variable trace:
```python
from jax import numpy as jnp, random, jit
-from ngclearn.utils import JaxProcess
-from ngcsimlib.context import Context
+
+from ngclearn import Context, MethodProcess
## import model-specific mechanisms
from ngclearn.components.input_encoders.poissonCell import PoissonCell
from ngclearn.components.other.varTrace import VarTrace
@@ -37,30 +25,24 @@ with Context("Model") as model:
trace = VarTrace("tr0", n_units=1, tau_tr=30., a_delta=0.5)
## wire up cell z0 to trace tr0
- trace.inputs << cell.outputs
+ cell.outputs >> trace.inputs
- advance_process = (JaxProcess()
+ advance_process = (MethodProcess("advance")
>> cell.advance_state
>> trace.advance_state)
- model.wrap_and_add_command(jit(advance_process.pure), name="advance")
- reset_process = (JaxProcess()
+ reset_process = (MethodProcess("reset")
>> cell.reset
>> trace.reset)
- model.wrap_and_add_command(jit(reset_process.pure), name="reset")
-
- @Context.dynamicCommand
- def clamp(x):
- cell.inputs.set(x)
+## set up some utility functions for the model context
+def clamp(x):
+ cell.inputs.set(x)
```
## Running the Paired Cell-Trace System
-We can then run the above two-component dynamical system by injecting a fixed
-(valid) probability value into the Poisson input encoder and then record the
-resulting spikes and trace values. We will do this for `T = 200` milliseconds (ms)
-with the code below:
+We can then run the above two-component dynamical system by injecting a fixed (valid) probability value into the Poisson input encoder and then record the resulting spikes and trace values. We will do this for `T = 200` milliseconds (ms) with the code below:
```python
dt = 1. # ms # integration time constant
@@ -70,22 +52,21 @@ probs = jnp.asarray([[0.35]],dtype=jnp.float32)
time_span = []
spikes = []
traceVals = []
-model.reset()
+reset_process.run()
for ts in range(T):
- model.clamp(probs)
- model.advance(t=ts*1., dt=dt)
+ clamp(probs)
+ advance_process.run(t=ts*1., dt=dt)
- print("{} {}".format(cell.outputs.value, trace.trace.value), end="")
- spikes.append( cell.outputs.value )
- traceVals.append( trace.trace.value )
+ print(f"\r{cell.outputs.get()} {trace.trace.get()}", end="")
+ spikes.append( cell.outputs.get() )
+ traceVals.append( trace.trace.get() )
time_span.append(ts * dt)
print()
spikes = jnp.concatenate(spikes,axis=0)
traceVals = jnp.concatenate(traceVals,axis=0)
```
-We can plot the above simulation's trace outputs with the discrete spikes
-super-imposed at their times of occurrence with the code below:
+We can plot the above simulation's trace outputs with the discrete spikes super-imposed at their times of occurrence with the code below:
```python
import matplotlib #.pyplot as plt
@@ -100,8 +81,7 @@ stat = jnp.where(spikes > 0.)
indx = (stat[0] * 1. - 1.).tolist()
spk = ax.vlines(x=indx, ymin=0.985, ymax=1.05, colors='black', ls='-', lw=5)
-ax.set(xlabel='Time (ms)', ylabel='Trace Output',
- title='Variable Trace of Poisson Spikes')
+ax.set(xlabel='Time (ms)', ylabel='Trace Output', title='Variable Trace of Poisson Spikes')
#ax.legend([zTr[0],spk[0]],['z','phi(z)'])
ax.grid()
fig.savefig("poisson_trace.jpg")
@@ -111,29 +91,16 @@ to get the following output saved to disk:
-Notice that every time a spike is produced by the Poisson encoding cell, the trace
-increments by `0.5` -- the result of the `a_delta` hyper-parameter we set when
-crafting the model and simulation object -- and then exponentially decays in
-the absence of a spike (with the time constant of `tau_tr = 30` milliseconds).
+Notice that every time a spike is produced by the Poisson encoding cell, the trace increments by `0.5` -- the result of the `a_delta` hyper-parameter we set when crafting the model and simulation object -- and then exponentially decays in the absence of a spike (with the time constant of `tau_tr = 30` milliseconds).
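+
+A conceptual sketch of this "increment-then-decay" behavior is given below (this is not the `VarTrace` implementation itself; the exact decay factor and update form here are assumptions made purely for illustration):
+
+```python
+from jax import numpy as jnp
+
+tau_tr, a_delta, dt = 30., 0.5, 1.   ## trace time constant (ms), increment size, step size (ms)
+
+def update_trace(trace, spike):
+    ## exponentially decay the trace, then bump it by a_delta whenever a spike arrives
+    decayed = trace * jnp.exp(-dt / tau_tr)
+    return decayed + a_delta * spike
+
+tr = jnp.zeros((1, 1))
+for s in [1., 0., 0., 1., 0.]:       ## a toy spike pattern
+    tr = update_trace(tr, jnp.asarray([[s]]))
+    print(tr)
+```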
-The variable trace can be further configured to filter signals in different ways
-if desired; specifically by manipulating its `decay_type` and `a_delta` arguments.
-Notably, if a piecewise-gated variable trace is desired (a very common choice
-in some neuronal circuit models), then all one would have to do is set `a_delta = 0`,
-yielding the following line in the model creation code earlier in this tutorial:
+The variable trace can be further configured to filter signals in different ways if desired; specifically by manipulating its `decay_type` and `a_delta` arguments. Notably, if a piecewise-gated variable trace is desired (a very common choice in some neuronal circuit models), then all one would have to do is set `a_delta = 0`, yielding the following line in the model creation code earlier in this tutorial:
```python
trace = VarTrace("tr0", n_units=1, tau_tr=30., a_delta=0., decay_type="exp")
```
-Running the same code from before but with the above alteration would yield the
-plot below:
+Running the same code from before but with the above alteration would yield the plot below:
-Notice that, this time, when a spike is emitted from the Poisson cell, the trace
-is "clamped" to the value of one and then exponentially decays. Such a trace
-configuration is useful if one requires the maintained trace to never increase
-beyond a value of one, preventing divergence or run-away values if a spike train
-is particularly dense and yielding friendlier values for biological learning
-rules.
+Notice that, this time, when a spike is emitted from the Poisson cell, the trace is "clamped" to the value of one and then exponentially decays. Such a trace configuration is useful if one requires the maintained trace to never increase beyond a value of one, preventing divergence or run-away values if a spike train is particularly dense and yielding friendlier values for biological learning rules.
diff --git a/docs/tutorials/theory.md b/docs/tutorials/theory.md
index 49bd9a6f..1622a6bf 100755
--- a/docs/tutorials/theory.md
+++ b/docs/tutorials/theory.md
@@ -1,48 +1,15 @@
# Theory and Design Motivation
## Cable Theory and Neural Compartments
-At its core, part of ngc-learn's internal design is inspired by (neural)
-cable theory ,
-where neuronal units, which are arranged in complex connectivity structures, are viewed
-as performing dendritic calculations (of varying complexity). In essence, a particular
-neuron integrates information from different input signal sources (for example,
-signals produced by other neurons), in often highly nonlinear ways through a
-complex dendritic tree.
+At its core, part of NGC-Learn's internal design is inspired by (neural) cable theory and neuronal compartment models [1], where neuronal units, which are arranged in complex connectivity structures, are viewed as performing dendritic calculations (of varying complexity). In essence, a particular neuron integrates information from different input signal sources (for example, signals produced by other neurons), in often highly nonlinear ways through a complex dendritic tree.
-Although modeling a complete neuronal system through the lens of cable theory is
-complex and intricate in of itself, ngc-learn is built with this direction in
-mind. ngc-learn starts with with the idea that a neuron (or a cluster of them)
-can be viewed as a node or nodal component -- specifically a type of "cell"
-component (in ngc-learn, many of these are component classes that end with the
-suffix `Cell`) -- and each bundle of synapses that connects pairs of nodes can
-be viewed as a cable -- specifically a "synapse" component (these component
-classes usually end with the suffix `Synapse` or `SynapticCable`)-- that performs
-some sort of transformation of its pre-synaptic signal (also treated as another
-component in terms of abstract simulation) and often differentiated by its form
-of plasticity. See the [Neurons](../modeling/neurons) specification for the base available
-neuronal cells and the [Synapses](../modeling/synapses) specification for the base available
-synaptic cables. Note that these two types of nodal components can be combined
-with other types such as [Input Encoders](../modeling/input_encoders) and [Operations](../modeling/other_ops) to build
-gradually more complex dynamical biomimetic/neuro-mimetic systems.
+Although modeling a complete neuronal system through the lens of cable theory and compartmental structures is complex and intricate in and of itself, NGC-Learn is built with this direction in mind. NGC-Learn starts with the idea that a neuron (or a cluster of them) can be viewed as a node or nodal component -- specifically a type of "cell" component (in NGC-Learn, many of these are component classes that end with the suffix `Cell`). Each bundle of synapses that connects pairs of nodes can
+be viewed as a cable -- specifically a "synapse" component (these component classes usually end with the suffix `Synapse` or `SynapticCable`) -- which performs some sort of transformation of its pre-synaptic signal (also treated as another component in terms of abstract simulation); a synaptic bundle in NGC-Learn is often differentiated by its form of plasticity. See the [Neurons](../modeling/neurons) specification for the base available neuronal cells and the [Synapses](../modeling/synapses) specification for the base available synaptic cables. Note that these two types of nodal components can be combined with other types such as [Input Encoders](../modeling/input_encoders) and [Operations](../modeling/other_ops) to build gradually more complex dynamical biomimetic/neuro-mimetic/NeuroAI systems.
-Each neuronal cell component/node has multiple, different (named) "compartments",
-which are regions or slots within the node that other nodes can deposit
-information/signals into. These compartments allow a node to collect information
-from many different connected/related nodes and then decide how to combine these
-different signals in order calculate its own output activity (either in the form
-of a rate-coded firing rate or binary spikes) using the integration logic defined
-within its own specific `advanceState()` function. When a biomimetic system,
-composed of many of these nodes/components, is simulated over a period of time
-(processing some form of sensory input), its underlying simulation object
-(the `Controller`) calls the `advanceState()` routine of each constituent node,
-shifting that nodes internal time by one discrete step. The order in which the
-node `advanceState()` routines are called is governed by "run cycles", which are
-defined by the experimenter at the object initialization of the controller. For
-example, a user might want one set of nodes to first execute their internal step
-logic before another set is able to -- this could be done by specifying two
-distinct cycles in the order desired.
+Each neuronal cell component/node has multiple, different (named) "compartments", which are regions or slots within the node that other nodes can deposit information/signals into. These compartments allow a node to collect information from many different connected/related nodes and then decide how to combine these different signals in order to calculate its own output activity (either in the form of a rate-coded firing rate or binary spikes) using the integration logic defined within its own specific `advance_state()` function. When a biomimetic system, composed of many of these nodes/components, is simulated over a period of time (processing some form of sensory input), its underlying simulation object (the `Context` controller) calls the `advance_state()` routine of each constituent node, shifting that node's internal time by one discrete step. The order in which the node `advance_state()` routines are called is governed by "run cycles", which are defined by the experimenter at the initialization of the controller. For example, a user might want one set of nodes to first execute their internal step logic before another set is able to -- this could be done by specifying two distinct cycles in the order desired; a minimal sketch of expressing such an ordering is given below.
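+
+To make this concrete, the sketch below wires two simple components together and chains their `advance_state` calls, loosely following the `Context`/`MethodProcess` patterns used throughout the tutorials (the specific components, wiring, and ordering here are illustrative assumptions rather than a prescribed recipe):
+
+```python
+from ngclearn import Context, MethodProcess
+from ngclearn.components.other.varTrace import VarTrace
+
+with Context("Demo") as model:
+    a = VarTrace("a", n_units=1, tau_tr=30., a_delta=0.5)
+    b = VarTrace("b", n_units=1, tau_tr=30., a_delta=0.5)
+    a.trace >> b.inputs              ## deposit a's filtered output into b's input compartment
+    ## a's advance_state executes before b's within each simulated step
+    advance = (MethodProcess("advance")
+               >> a.advance_state
+               >> b.advance_state)
+    reset = (MethodProcess("reset")
+             >> a.reset
+             >> b.reset)
+```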
-As a result, many nodes, and the synaptic cables that connect them, result in
-a simulated biomimetic system where each node is itself, in general, treated as
-a stateful computation even if we are processing inherently non-temporal data
-such as static images.
+As a result, many nodes, and the synaptic cables that connect them, result in a simulated biomimetic system where each node is itself, in general, treated as a stateful computation even if we are processing inherently non-temporal data such as static images.
+
+## References
+
+[1] Talevi, Alan, and Carolina Leticia Bellera. "Compartmental pharmacokinetic models." In ADME Processes in Pharmaceutical Sciences: Dosage, Design, and Pharmacotherapy, pp. 173-192. Cham: Springer Nature Switzerland, 2024.
diff --git a/history.txt b/history.txt
index 4ee30278..e03221af 100644
--- a/history.txt
+++ b/history.txt
@@ -19,8 +19,7 @@ History
* NGCGraph .compile() further tweaked to use an injection/clamping look-up
system to allow for dynamic changes to occur w/in a static graph compiled
simulated NGC system
- * Cable API slightly modified to increase flexiblity (demonstrations and
- tests modified to reflect updated API)
+ * Cable API slightly modified to increase flexibility (demonstrations and tests modified to reflect updated API)
* Demonstration 6 released showcasing how to use ngc-learn to construct/fit a
restricted Boltzmann machine
@@ -81,7 +80,15 @@ History
* basic unit-tests (pytest framework) integrated to support dev
* includes support for Intel's lava-nc emulator (several spiking/stdp components that play with ngc-lava)
-2.0.3
-— — — — — — — — -
- * Minor patch to point / depend on minor-patched ngcsimlib 1.0.1 (nudge to minor patched release)
- * Added wrapper `inverse_sigmoid` for original `inverse_logistic` routine in model_utils (for convenience)
+ 3.0.0
+ — — — — — — — — -
+ * revisions made / upgrades applied to framework/simulation back-end to integrate major version v2 of ngc-sim-lib
+ * new harmonium/RBM model-museum exhibit written and tutorial integrated
+ * clean-up of utils and new integration of mixture models for utils.density (Gaussian, Bernoulli, & exponential mixtures)
+ * addition of new BernoulliErrorCell (binary cross-entropy node); added leakyNoiseCell to support continuous-time RNNs
+ * model museum (ngc-museum) and tutorials updated to reflect newest ngc-sim-lib format
+ * clean-up/upgrade of docs to reflect new v3 version (and patches)
+ * all model-museum (standard/main) exhibits revised/updated to operate with new v3 ngclearn / v2 ngcsimlib
+ * integration/addition of RL-SNN model in model-museum
+ * integration of full dynamics synapses -- alpha, exponential, and double-exponential synaptic cables
+ * new metrics/clean-up of metrics in utils.metric_utils (e.g., KL divs, etc.)
diff --git a/ngclearn/__init__.py b/ngclearn/__init__.py
index e457cb52..210f5121 100644
--- a/ngclearn/__init__.py
+++ b/ngclearn/__init__.py
@@ -1,16 +1,13 @@
import sys
-import subprocess
import pkg_resources
from pkg_resources import get_distribution
-#from pathlib import Path
-#from sys import argv
__version__ = get_distribution('ngclearn').version
if sys.version_info.minor < 10:
import warnings
warnings.warn(
- "Running ngclearn and jax in a python version prior to 3.10 may have unintended consequences. Compatability "
+ "Running ngclearn and jax in a python version prior to 3.10 may have unintended consequences. Compatibility "
"with python 3.8 is maintained to allow for lava-nc components and should only be used with those")
#required = {'ngcsimlib', 'jax', 'jaxlib'} ## list of core ngclearn dependencies
@@ -31,33 +28,17 @@
import ngcsimlib
-from ngcsimlib.context import Context
-from ngcsimlib.component import Component
+from ngclearn.utils import JointProcess, MethodProcess
+from ngcsimlib.context import Context, ContextObjectTypes
+from ngcsimlib import Component
from ngcsimlib.compartment import Compartment
-from ngcsimlib.resolver import resolver
-from ngcsimlib import utils as sim_utils
+from ngcsimlib import logger, get_config, provide_namespace
+from ngcsimlib.parser import compilable
+from ngcsimlib.operations import Summation, Product
-from ngclearn.utils.jaxProcess import JaxProcess
-from ngcsimlib.compilers.process import transition, Process
-
-
-from ngcsimlib import configure, preload_modules
-from ngcsimlib import logger
if not Path(argv[0]).name == "sphinx-build" or Path(argv[0]).name == "build.py":
if "readthedocs" not in argv[0]: ## prevent readthedocs execution of preload
+ from ngcsimlib import configure
configure()
logger.init_logging()
- from ngcsimlib.configManager import get_config
- pkg_config = get_config("packages")
- if pkg_config is not None:
- use_base_numpy = pkg_config.get("use_base_numpy", False)
- if use_base_numpy:
- import numpy as numpy
- else:
- from jax import numpy
- else:
- from jax import numpy
-
-
- preload_modules()
diff --git a/ngclearn/commands/__init__.py b/ngclearn/commands/__init__.py
deleted file mode 100644
index 74eb06b3..00000000
--- a/ngclearn/commands/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from ngcsimlib.commands import *
diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py
index af856c1a..d8c4dc67 100644
--- a/ngclearn/components/__init__.py
+++ b/ngclearn/components/__init__.py
@@ -2,6 +2,7 @@
## point to rate-coded cell component types
from .neurons.graded.rateCell import RateCell
+from .neurons.graded.leakyNoiseCell import LeakyNoiseCell
from .neurons.graded.gaussianErrorCell import GaussianErrorCell
from .neurons.graded.laplacianErrorCell import LaplacianErrorCell
from .neurons.graded.bernoulliErrorCell import BernoulliErrorCell
@@ -39,7 +40,7 @@
from .synapses.hebbian.BCMSynapse import BCMSynapse
from .synapses.STPDenseSynapse import STPDenseSynapse
from .synapses.exponentialSynapse import ExponentialSynapse
-from .synapses.doubleExpSynapse import DoupleExpSynapse
+from .synapses.doubleExpSynapse import DoubleExpSynapse
from .synapses.alphaSynapse import AlphaSynapse
## point to convolutional component types
@@ -55,10 +56,7 @@
from .synapses.modulated.MSTDPETSynapse import MSTDPETSynapse
from .synapses.modulated.REINFORCESynapse import REINFORCESynapse
-## point to monitors
-from .monitor import Monitor
-
-## point to patched component types
+## point to patched component types
from .synapses.patched.patchedSynapse import PatchedSynapse
from .synapses.patched.staticPatchedSynapse import StaticPatchedSynapse
from .synapses.patched.hebbianPatchedSynapse import HebbianPatchedSynapse
diff --git a/ngclearn/components/base_monitor.py b/ngclearn/components/base_monitor.py
deleted file mode 100644
index 8d7c71d0..00000000
--- a/ngclearn/components/base_monitor.py
+++ /dev/null
@@ -1,330 +0,0 @@
-import json
-
-from ngclearn import Component, Compartment, transition
-from ngclearn import numpy as np
-from ngcsimlib.utils import get_current_path
-from ngcsimlib.logger import warn, critical
-
-import matplotlib.pyplot as plt
-
-
-class Base_Monitor(Component):
- """
- An abstract base for monitors for both ngclearn and ngclava. Compartments
- wired directly into this component will have their value tracked during
- `advance_state` loops automatically.
-
- Note the monitor only works for compiled methods currently
-
-
- Using default window length:
- myMonitor << myComponent.myCompartment
-
- Using custom window length:
- myMonitor.watch(myComponent.myCompartment, customWindowLength)
-
- To get values out of the monitor either path to the stored value
- directly, or pass in a compartment directly. All
- paths are the same as their local path variable.
-
- Using a compartment:
- myMonitor.view(myComponent.myCompartment)
-
- Using a path:
- myMonitor.get_store(myComponent.myCompartment.path).value
-
- There can only be one monitor in existence at a time due to the way it
- interacts with resolvers and the compilers
- for ngclearn.
-
- Args:
- name: The name of the component.
-
- default_window_length: The default window length.
- """
- auto_resolve = False
-
- @staticmethod
- def build_reset(component):
- return Base_Monitor.reset(component)
-
- @staticmethod
- def build_advance_state(component):
- return Base_Monitor.record(component)
-
- @staticmethod
- def _record_internal(compartments):
- """
- A method to build the method to advance the stored values.
-
- Args:
- compartments: A list of compartments to store values
-
- Returns: The method to advance the stored values.
-
- """
- critical(
- "build_advance() is not defined on this monitor, use either the "
- "monitor found in ngclearn.components or "
- "ngclearn.components.lava (If using lava)")
-
- @transition(None, True)
- @staticmethod
- def reset(component):
- """
- A method to build the method to reset the stored values.
- Args:
- component: The component to resolve
-
- Returns: the reset resolver
- """
- output_compartments = []
- compartments = []
- for comp in component.compartments:
- output_compartments.append(comp.split("/")[-1] + "*store")
- compartments.append(comp.split("/")[-1])
-
- @staticmethod
- def _reset(**kwargs):
- return_vals = []
- for comp in compartments:
- current_store = kwargs[comp + "*store"]
- return_vals.append(np.zeros(current_store.shape))
- return return_vals if len(compartments) > 1 else return_vals[0]
-
- # pure func, output compartments, args, params, input compartments
- return _reset, output_compartments, [], [], output_compartments
-
- @transition(None, True)
- @staticmethod
- def record(component):
- output_compartments = []
- compartments = []
- for comp in component.compartments:
- output_compartments.append(comp.split("/")[-1] + "*store")
- compartments.append(comp.split("/")[-1])
-
- _advance = component._record_internal(compartments)
-
- return _advance, output_compartments, [], [], compartments + output_compartments
-
- def __init__(self, name, default_window_length=100, **kwargs):
- super().__init__(name, **kwargs)
- self.store = {}
- self.compartments = []
- self._sources = []
- self.default_window_length = default_window_length
-
- def __lshift__(self, other):
- if isinstance(other, Compartment):
- self.watch(other, self.default_window_length)
- else:
- warn("Only Compartments can be monitored not", type(other))
-
- def watch(self, compartment, window_length):
- """
- Sets the monitor to watch a specific compartment, for a specified
- window length.
-
- Args:
- compartment: the compartment object to monitor
-
- window_length: the window length
- """
- cs, end = self._add_path(compartment.path)
-
- if hasattr(compartment.value, "dtype"):
- dtype = compartment.value.dtype
- else:
- dtype = type(compartment.value)
-
- if hasattr(compartment.value, "shape"):
- shape = compartment.value.shape
- else:
- shape = (1,)
- new_comp = Compartment(np.zeros(shape, dtype=dtype))
- new_comp_store = Compartment(np.zeros((window_length, *shape), dtype=dtype))
-
- comp_key = "*".join(compartment.path.split("/"))
- store_comp_key = comp_key + "*store"
-
- new_comp._setup(self, comp_key)
- new_comp_store._setup(self, store_comp_key)
-
- new_comp << compartment
-
- cs[end] = new_comp_store
- setattr(self, comp_key, new_comp)
- setattr(self, store_comp_key, new_comp_store)
- self.compartments.append(new_comp.path)
- self._sources.append(compartment)
- # self._update_resolver()
-
- def halt(self, compartment):
- """
- Stops the monitor from watching a specific compartment. It is important
- to note that it does not stop previously compiled methods. It does not
- remove it from the stored values, so it can still be viewed.
- Args:
- compartment: The compartment object to stop watching
- """
- if compartment not in self._sources:
- return
-
- comp_key = "*".join(compartment.path.split("/"))
- store_comp_key = comp_key + "*store"
-
- self.compartments.remove(getattr(self, comp_key).path)
- self._sources.remove(compartment)
-
- delattr(self, comp_key)
- delattr(self, store_comp_key)
- self._update_resolver()
-
- def halt_all(self):
- """
- Stops the monitor from watching all compartments.
- """
- for compartment in self._sources:
- self.halt(compartment)
-
- # def _update_resolver(self):
- # output_compartments = []
- # compartments = []
- # for comp in self.compartments:
- # output_compartments.append(comp.split("/")[-1] + "*store")
- # compartments.append(comp.split("/")[-1])
- #
- # args = []
- # parameters = []
- #
- # add_component_resolver(self.__class__.__name__, "advance_state",
- # (self.build_advance(compartments),
- # output_compartments))
- # add_resolver_meta(self.__class__.__name__, "advance_state",
- # (args, parameters,
- # compartments + [o for o in output_compartments],
- # False))
-
- # add_component_resolver(self.__class__.__name__, "reset", (
- # self.build_reset(compartments), output_compartments))
- # add_resolver_meta(self.__class__.__name__, "reset",
- # (args, parameters, [o for o in output_compartments],
- # False))
-
- def _add_path(self, path):
- _path = path.split("/")[1:]
- end = _path.pop(-1)
-
- current_store = self.store
- for p in _path:
- if p not in current_store.keys():
- current_store[p] = {}
- current_store = current_store[p]
-
- return current_store, end
-
- def view(self, compartment):
- """
- Gets the value associated with the specified compartment
-
- Args:
- compartment: The compartment to extract the stored value of
-
- Returns: The stored value, None if not monitoring that compartment
-
- """
- _path = compartment.path.split("/")[1:]
- store = self.get_store(_path)
- return store.value if store is not None else store
-
- def get_store(self, path):
- current_store = self.store
- for p in path:
- if p not in current_store.keys():
- return None
- current_store = current_store[p]
- return current_store
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".json"
- _dict = {"sources": {}, "stores": {}}
- for key in self.compartments:
- n = key.split("/")[-1]
- _dict["sources"][key] = self.__dict__[n].value.shape
- _dict["stores"][key + "*store"] = self.__dict__[
- n + "*store"].value.shape
-
- with open(file_name, "w") as f:
- json.dump(_dict, f)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".json"
- with open(file_name, "r") as f:
- vals = json.load(f)
-
- for comp_path, shape in vals["stores"].items():
- compartment_path = comp_path.split("/")[-1]
- new_path = get_current_path() + "/" + "/".join(
- compartment_path.split("*")[-3:-1])
-
- cs, end = self._add_path(new_path)
-
- new_comp = Compartment(np.zeros(shape))
- new_comp._setup(self, compartment_path)
-
- cs[end] = new_comp
- setattr(self, compartment_path, new_comp)
-
- for comp_path, shape in vals['sources'].items():
- compartment_path = comp_path.split("/")[-1]
- new_comp = Compartment(np.zeros(shape))
- new_comp._setup(self, compartment_path)
-
- setattr(self, compartment_path, new_comp)
- self.compartments.append(new_comp.path)
-
- # self._update_resolver()
-
- def make_plot(self, compartment, ax=None, ylabel=None, xlabel=None, title=None, n=None, plot_func=None):
- vals = self.view(compartment)
-
- if n is None:
- n = vals.shape[2]
- if title is None:
- title = compartment.name.split("/")[0] + " " + compartment.display_name
-
- if ylabel is None:
- _ylabel = compartment.units
- elif ylabel:
- _ylabel = ylabel
- else:
- _ylabel = None
-
- if xlabel is None:
- _xlabel = "Time Steps"
- elif xlabel:
- _xlabel = xlabel
- else:
- _xlabel = None
-
- if ax is None:
- _ax = plt
- _ax.title(title)
- if _ylabel:
- _ax.ylabel(_ylabel)
- if _xlabel:
- _ax.xlabel(_xlabel)
- else:
- _ax = ax
- _ax.set_title(title)
- if _ylabel:
- _ax.set_ylabel(_ylabel)
- if _xlabel:
- _ax.set_xlabel(_xlabel)
-
- if plot_func is None:
- for k in range(n):
- _ax.plot(vals[:, 0, k])
- else:
- plot_func(vals[:, :, 0:n], ax=_ax)
diff --git a/ngclearn/components/input_encoders/__init__.py b/ngclearn/components/input_encoders/__init__.py
index b779226e..5d14d2ec 100644
--- a/ngclearn/components/input_encoders/__init__.py
+++ b/ngclearn/components/input_encoders/__init__.py
@@ -2,3 +2,4 @@
from .poissonCell import PoissonCell
from .latencyCell import LatencyCell
from .phasorCell import PhasorCell
+
diff --git a/ngclearn/components/input_encoders/bernoulliCell.py b/ngclearn/components/input_encoders/bernoulliCell.py
index d240de64..87b965fc 100755
--- a/ngclearn/components/input_encoders/bernoulliCell.py
+++ b/ngclearn/components/input_encoders/bernoulliCell.py
@@ -1,12 +1,9 @@
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
-from ngcsimlib.logger import info, warn
-
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+import jax
+from typing import Union
class BernoulliCell(JaxComponent):
"""
@@ -29,51 +26,33 @@ class BernoulliCell(JaxComponent):
batch_size: batch size dimension of this cell (Default: 1)
"""
- def __init__(self, name, n_units, batch_size=1, **kwargs):
- super().__init__(name, **kwargs)
- #super(BernoulliCell, self).__init__(name, **kwargs)
+ def __init__(self, name: str, n_units: int, batch_size: int = 1, key: Union[jax.Array, None] = None, **kwargs):
+ super().__init__(name=name, key=key)
## Layer Size Setup
- self.batch_size = batch_size
- self.n_units = n_units
+ self.batch_size = Compartment(batch_size)
+ self.n_units = Compartment(n_units)
- # Compartments (state of the cell, parameters, will be updated through stateless calls)
- restVals = jnp.zeros((self.batch_size, self.n_units))
+ restVals = jnp.zeros((batch_size, n_units))
self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
- @transition(output_compartments=["outputs", "tols", "key"])
- @staticmethod
- def advance_state(t, key, inputs, tols):
- ## NOTE: should `inputs` be checked if bounded to [0,1]?
- # print(key)
- # print(t)
- # print(inputs.shape)
- # print(tols.shape)
- # print("-----")
- key, *subkeys = random.split(key, 3)
- outputs = random.bernoulli(subkeys[0], p=inputs).astype(jnp.float32)
- # Updates time-of-last-spike (tols) variable:
- # output = s = binary spike vector
- # tols = current time-of-last-spike variable
- tols = (1. - outputs) * tols + (outputs * t)
- return outputs, tols, key
-
- @transition(output_compartments=["inputs", "outputs", "tols"])
- @staticmethod
- def reset(batch_size, n_units):
- restVals = jnp.zeros((batch_size, n_units))
- return restVals, restVals, restVals
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name, key=self.key.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.key.set(data['key'])
+ @compilable
+ def advance_state(self, t):
+ key, subkey = random.split(self.key.get(), 2)
+ self.outputs.set(random.bernoulli(subkey, p=self.inputs.get()).astype(jnp.float32))
+ self.tols.set((1. - self.outputs.get()) * self.tols.get() + (self.outputs.get() * t))
+ self.key.set(key)
+
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size.get(), self.n_units.get()))
+        ## NOTE: self.inputs may not carry the `targeted` field at this point; only
+        ## clear the inputs when the field exists and the compartment is not targeted
+        if hasattr(self.inputs, "targeted") and not self.inputs.targeted:
+            self.inputs.set(restVals)
+ self.outputs.set(restVals)
+ self.tols.set(restVals)
@classmethod
def help(cls): ## component help function
@@ -101,22 +80,9 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
X = BernoulliCell("X", 9)
- print(X)
+
+ X.batch_size.set(10)
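
Taken together, the BernoulliCell changes above show the new component pattern: hyper-parameters become Compartments, transitions become `@compilable` instance methods, and state flows through `.get()`/`.set()`. A minimal usage sketch in that style, mirroring the `__main__` block above; whether `@compilable` methods can be invoked eagerly like this, or must first be compiled into a process, depends on the ngcsimlib 3.x runtime and is an assumption here:

from jax import numpy as jnp, random
from ngcsimlib.context import Context
from ngclearn.components.input_encoders.bernoulliCell import BernoulliCell

with Context("demo") as ctx:
    cell = BernoulliCell("z0", n_units=9, key=random.PRNGKey(0))
    cell.reset()                              ## zero out inputs/outputs/tols
    cell.inputs.set(jnp.ones((1, 9)) * 0.5)   ## per-unit spike probabilities in [0, 1]
    for t in range(5):
        cell.advance_state(t=t * 1.)          ## sample Bernoulli spikes at time t (ms)
    print(cell.outputs.get())                 ## latest binary spike vector
    print(cell.tols.get())                    ## per-unit time-of-last-spike (ms)
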
diff --git a/ngclearn/components/input_encoders/latencyCell.py b/ngclearn/components/input_encoders/latencyCell.py
index c7343cfa..c0708e3d 100755
--- a/ngclearn/components/input_encoders/latencyCell.py
+++ b/ngclearn/components/input_encoders/latencyCell.py
@@ -1,16 +1,13 @@
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit
from functools import partial
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
-from ngcsimlib.logger import info, warn
-
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+import jax
+from typing import Union
from ngclearn.utils.model_utils import clamp_min, clamp_max
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
@partial(jit, static_argnums=[5])
def _calc_spike_times_linear(data, tau, thr, first_spk_t, num_steps=1.,
@@ -146,90 +143,79 @@ class LatencyCell(JaxComponent):
batch_size: batch size dimension of this cell (Default: 1)
"""
- # Define Functions
def __init__(
- self, name, n_units, tau=1., threshold=0.01, first_spike_time=0., linearize=False, normalize=False,
- clip_spikes=False, num_steps=1., batch_size=1, **kwargs
+ self, name: str, n_units: int, tau: float = 1., threshold: float = 0.01, first_spike_time: float = 0.,
+ linearize: bool = False, normalize: bool = False, clip_spikes: bool = False, num_steps: float = 1.,
+ batch_size: int = 1, key: Union[jax.Array, None] = None, **kwargs
):
- super().__init__(name, **kwargs)
+ super().__init__(name=name, key=key)
## latency meta-parameters
- self.first_spike_time = first_spike_time
- self.tau = tau
- self.threshold = threshold
- self.linearize = linearize
- self.clip_spikes = clip_spikes
+ self.first_spike_time = Compartment(first_spike_time)
+ self.tau = Compartment(tau)
+ self.threshold = Compartment(threshold)
+ self.linearize = Compartment(linearize)
+ self.clip_spikes = Compartment(clip_spikes)
## normalize latency code s.t. final spike(s) occur w/in num_steps
- self.normalize = normalize
- self.num_steps = num_steps
+ self.normalize = Compartment(normalize)
+ self.num_steps = Compartment(num_steps)
## Layer Size Setup
- self.batch_size = batch_size
- self.n_units = n_units
+ self.batch_size = Compartment(batch_size)
+ self.n_units = Compartment(n_units)
## Compartment setup
- restVals = jnp.zeros((self.batch_size, self.n_units))
+ restVals = jnp.zeros((batch_size, n_units))
self.inputs = Compartment(restVals, display_name="Input Stimulus") # input compartment
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
self.mask = Compartment(restVals, display_name="Spike Time Mask")
self.clip_mask = Compartment(restVals, display_name="Clip Mask")
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
self.targ_sp_times = Compartment(restVals, display_name="Target Spike Time", units="ms")
- #self.reset()
- @transition(output_compartments=["targ_sp_times", "clip_mask"])
- @staticmethod
- def calc_spike_times(
- linearize, tau, threshold, first_spike_time, num_steps, normalize, clip_spikes, inputs
- ):
- ## would call this function before processing a spike train (at start)
- data = inputs
- if clip_spikes:
- clip_mask = (data < threshold) * 1. ## find values under threshold
+ @compilable
+ def calc_spike_times(self):
+ if self.clip_spikes.get():
+ self.clip_mask.set((self.inputs.get() < self.threshold.get()) * 1.)
else:
- clip_mask = data * 0. ## all values allowed to fire spikes
- if linearize: ## linearize spike time calculation
- stimes = _calc_spike_times_linear(data, tau, threshold,
- first_spike_time,
- num_steps, normalize)
- targ_sp_times = stimes #* calcEvent + targ_sp_times * (1. - calcEvent)
- else: ## standard nonlinear spike time calculation
- stimes = _calc_spike_times_nonlinear(data, tau, threshold,
- first_spike_time,
- num_steps=num_steps,
- normalize=normalize)
- targ_sp_times = stimes #* calcEvent + targ_sp_times * (1. - calcEvent)
- return targ_sp_times, clip_mask
-
- @transition(output_compartments=["outputs", "tols", "mask", "targ_sp_times", "key"])
- @staticmethod
- def advance_state(t, dt, key, inputs, mask, clip_mask, targ_sp_times, tols):
- key, *subkeys = random.split(key, 2)
- data = inputs ## get sensory pattern data / features
- spikes, spk_mask = _extract_spike(targ_sp_times, t, mask) ## get spikes at t
-
- # Updates time-of-last-spike (tols) variable:
- # output = s = binary spike vector
- # tols = current time-of-last-spike variable
- tols = (1. - spikes) * tols + (spikes * t)
-
- spikes = spikes * (1. - clip_mask)
- return spikes, tols, spk_mask, targ_sp_times, key
-
- @transition(output_compartments=["inputs", "outputs", "tols", "mask", "clip_mask", "targ_sp_times"])
- @staticmethod
- def reset(batch_size, n_units):
- restVals = jnp.zeros((batch_size, n_units))
- return (restVals, restVals, restVals, restVals, restVals, restVals)
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name, key=self.key.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.key.set(data['key'])
+ self.clip_mask.set(self.inputs.get() * 0.)
+
+ if self.linearize.get():
+ self.targ_sp_times.set(
+ _calc_spike_times_linear(self.inputs.get(),
+ self.tau.get(),
+ self.threshold.get(),
+ self.first_spike_time.get(),
+ self.num_steps.get(),
+ self.normalize.get()))
+ else:
+ self.targ_sp_times.set(
+ _calc_spike_times_nonlinear(self.inputs.get(),
+ self.tau.get(),
+ self.threshold.get(),
+ self.first_spike_time.get(),
+ self.num_steps.get(),
+ self.normalize.get()))
+
+
+ @compilable
+ def advance_state(self, t):
+ spikes, spike_mask = _extract_spike(self.targ_sp_times.get(), t, self.mask.get())
+ self.tols.set((1. - spikes) * self.tols.get() + (spikes * t))
+ self.outputs.set(spikes * (1. - self.clip_mask.get()))
+ self.mask.set(spike_mask)
+
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size.get(), self.n_units.get()))
+ # BUG: `self.inputs` may not carry a `targeted` field at this point
+ # NOTE: as a quick workaround, only clear `inputs` when the field exists and is False
+ if hasattr(self.inputs, "targeted") and not self.inputs.targeted:
+     self.inputs.set(restVals)
+ self.outputs.set(restVals)
+ self.tols.set(restVals)
+ self.mask.set(restVals)
+ self.clip_mask.set(restVals)
+ self.targ_sp_times.set(restVals)
@classmethod
def help(cls): ## component help function
@@ -266,22 +252,10 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
X = LatencyCell("X", 9)
print(X)
+ print(X.calc_spike_times.compiled.code)
+ print(X.advance_state.compiled.code)
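
The printed `*.compiled.code` attributes above suggest each `@compilable` method carries its generated source. The intended call order for this encoder, carried over from the old transition logic, is: reset, clamp an input, compute the target spike times once, then advance across the window. A rough sketch under the same eager-call assumption as the BernoulliCell example earlier:

from jax import numpy as jnp
from ngcsimlib.context import Context
from ngclearn.components.input_encoders.latencyCell import LatencyCell

with Context("demo") as ctx:
    enc = LatencyCell("lat0", n_units=4, tau=1., threshold=0.01, num_steps=20., normalize=True)
    enc.reset()
    enc.inputs.set(jnp.asarray([[0.9, 0.5, 0.1, 0.0]]))
    enc.calc_spike_times()              ## fill targ_sp_times / clip_mask from the clamped input
    for step in range(20):
        enc.advance_state(t=step * 1.)  ## a unit spikes once t reaches its target spike time
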
diff --git a/ngclearn/components/input_encoders/phasorCell.py b/ngclearn/components/input_encoders/phasorCell.py
index 9eaa16a7..ada0ddc8 100755
--- a/ngclearn/components/input_encoders/phasorCell.py
+++ b/ngclearn/components/input_encoders/phasorCell.py
@@ -1,13 +1,11 @@
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
-from ngcsimlib.logger import info, warn
-
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+import jax
+from typing import Union
+from ngcsimlib.logger import info, warn
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
class PhasorCell(JaxComponent):
"""
@@ -33,9 +31,9 @@ class PhasorCell(JaxComponent):
batch_size: batch size dimension of this cell (Default: 1)
"""
- # Define Functions
def __init__(
- self, name, n_units, target_freq=63.75, batch_size=1, disable_phasor=False, **kwargs):
+ self, name, n_units, target_freq=63.75, batch_size=1, disable_phasor=False, **kwargs
+ ):
super().__init__(name, **kwargs)
## Phasor meta-parameters
@@ -44,7 +42,7 @@ def __init__(
## Layer Size Setup
self.batch_size = batch_size
self.n_units = n_units
- _key, *subkey = random.split(self.key.value, 3)
+ _key, *subkey = random.split(self.key.get(), 3)
self.key.set(_key)
## Compartment setup
restVals = jnp.zeros((self.batch_size, self.n_units))
@@ -62,7 +60,7 @@ def __init__(
# alpha = ((random.normal(subkey, self.angles.value.shape) * (jnp.sqrt(target_freq) / target_freq)) + 1)
# beta = random.poisson(subkey, lam=target_freq, shape=self.angles.value.shape) / target_freq
- self.base_scale = random.poisson(subkey[0], lam=target_freq, shape=self.angles.value.shape) / target_freq
+ self.base_scale = random.poisson(subkey[0], lam=target_freq, shape=self.angles.get().shape) / target_freq
self.disable_phasor = disable_phasor
def validate(self, dt=None, **validation_kwargs):
@@ -86,21 +84,27 @@ def validate(self, dt=None, **validation_kwargs):
)
return valid
- @transition(output_compartments=["outputs", "tols", "key", "angles"])
- @staticmethod
- def advance_state(t, dt, target_freq, key, inputs, angles, tols, base_scale, disable_phasor):
+ # @transition(output_compartments=["outputs", "tols", "key", "angles"])
+ # @staticmethod
+ @compilable
+ def advance_state(self, t, dt):
+
+ inputs = self.inputs.get()
+ angles = self.angles.get()
+ tols = self.tols.get()
+
ms_per_second = 1000 # ms/s
- events_per_ms = target_freq / ms_per_second # e/s s/ms -> e/ms
+ events_per_ms = self.target_freq / ms_per_second # e/s s/ms -> e/ms
ms_per_event = 1 / events_per_ms # ms/e
time_step_per_event = ms_per_event / dt # ms/e * ts/ms -> ts / e
angle_per_event = 2 * jnp.pi # rad / e
angle_per_timestep = angle_per_event / time_step_per_event # rad / e
# * e/ts -> rad / ts
- key, *subkey = random.split(key, 3)
+ key, *subkey = random.split(self.key.get(), 3)
# scatter = random.uniform(subkey, angles.shape, minval=0.5,
# maxval=1.5) * base_scale
- scatter = ((random.normal(subkey[0], angles.shape) * 0.2) + 1) * base_scale
+ scatter = ((random.normal(subkey[0], angles.shape) * 0.2) + 1) * self.base_scale
scattered_update = angle_per_timestep * scatter
scaled_scattered_update = scattered_update * inputs
@@ -109,27 +113,30 @@ def advance_state(t, dt, target_freq, key, inputs, angles, tols, base_scale, dis
updated_angles = jnp.where(updated_angles > angle_per_event,
updated_angles - angle_per_event,
updated_angles)
- if disable_phasor:
+ if self.disable_phasor:
outputs = inputs + 0
tols = tols * (1. - outputs) + t * outputs
- return outputs, tols, key, updated_angles
+ self.outputs.set(outputs)
+ self.tols.set(tols)
+ self.key.set(key)
+ self.angles.set(updated_angles)
- @transition(output_compartments=["inputs", "outputs", "tols", "angles", "key"])
- @staticmethod
- def reset(batch_size, n_units, key, target_freq):
- restVals = jnp.zeros((batch_size, n_units))
- key, *subkey = random.split(key, 3)
- return restVals, restVals, restVals, restVals, key
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name, key=self.key.value)
+ # @transition(output_compartments=["inputs", "outputs", "tols", "angles", "key"])
+ # @staticmethod
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size, self.n_units))
+ key, *subkey = random.split(self.key.get(), 3)
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.key.set(data['key'])
+ # BUG: `self.inputs` may not carry a `targeted` field at this point
+ # NOTE: as a quick workaround, only clear `inputs` when the field exists and is False
+ if hasattr(self.inputs, "targeted") and not self.inputs.targeted:
+     self.inputs.set(restVals)
+ self.outputs.set(restVals)
+ self.tols.set(restVals)
+ self.angles.set(restVals)
+ self.key.set(key)
@classmethod
def help(cls): ## component help function
@@ -157,19 +164,4 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if
- Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
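
As a quick sanity check on the unit conversions in the refactored `advance_state`, worked through by hand with the defaults from this file (dt = 1 ms, target_freq = 63.75 Hz):

# assuming dt = 1 ms and the default target_freq = 63.75 Hz
target_freq = 63.75                      # events per second
events_per_ms = target_freq / 1000.      # 0.06375 events/ms
ms_per_event = 1. / events_per_ms        # ~15.69 ms between events
angle_per_timestep = 2. * 3.14159 / (ms_per_event / 1.)   # ~0.40 rad advanced per step

So a unit driven at full input sweeps 2*pi in roughly 16 one-millisecond steps, i.e. about 63.75 phasor events per simulated second, matching the target rate.
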
diff --git a/ngclearn/components/input_encoders/poissonCell.py b/ngclearn/components/input_encoders/poissonCell.py
index 5f385951..810776ab 100644
--- a/ngclearn/components/input_encoders/poissonCell.py
+++ b/ngclearn/components/input_encoders/poissonCell.py
@@ -1,12 +1,11 @@
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
-from ngcsimlib.logger import info, warn
+import jax
+from typing import Union
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngcsimlib import deprecate_args
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
class PoissonCell(JaxComponent):
"""
@@ -32,8 +31,11 @@ class PoissonCell(JaxComponent):
"""
@deprecate_args(max_freq="target_freq")
- def __init__(self, name, n_units, target_freq=63.75, batch_size=1, **kwargs):
- super().__init__(name, **kwargs)
+ def __init__(
+ self, name: str, n_units: int, target_freq: float = 63.75, batch_size: int = 1,
+ key: Union[jax.Array, None] = None, **kwargs
+ ):
+ super().__init__(name=name, key=key)
## Constrained Bernoulli meta-parameters
self.target_freq = target_freq ## maximum frequency (in Hertz/Hz)
@@ -48,55 +50,25 @@ def __init__(self, name, n_units, target_freq=63.75, batch_size=1, **kwargs):
self.outputs = Compartment(restVals, display_name="Spikes") # output compartment
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") # time of last spike
- def validate(self, dt=None, **validation_kwargs):
- valid = super().validate(**validation_kwargs)
- if dt is None:
- warn(f"{self.name} requires a validation kwarg of `dt`")
- return False
- ## check for unstable combinations of dt and target-frequency meta-params
- events_per_timestep = (dt/1000.) * self.target_freq ## compute scaled probability
- if events_per_timestep > 1.:
- valid = False
- warn(
- f"{self.name} will be unable to make as many temporal events as "
- f"requested! ({events_per_timestep} events/timestep) Unstable "
- f"combination of dt = {dt} and target_freq = {self.target_freq} "
- f"being used!"
- )
- return valid
-
- @transition(output_compartments=["outputs", "tols", "key"])
- @staticmethod
- def advance_state(t, dt, target_freq, key, inputs, tols):
- key, *subkeys = random.split(key, 2)
- pspike = inputs * (dt / 1000.) * target_freq
- eps = random.uniform(subkeys[0], inputs.shape, minval=0., maxval=1.,
+ @compilable
+ def advance_state(self, t, dt):
+ key, subkey = random.split(self.key.get(), 2)
+ pspike = self.inputs.get() * (dt / 1000.) * self.target_freq
+ eps = random.uniform(subkey, self.inputs.get().shape, minval=0., maxval=1.,
dtype=jnp.float32)
- outputs = (eps < pspike).astype(jnp.float32)
-
- # Updates time-of-last-spike (tols) variable:
- # output = s = binary spike vector
- # tols = current time-of-last-spike variable
- tols = (1. - outputs) * tols + (outputs * t)
- return outputs, tols, key
-
- @transition(output_compartments=["inputs", "outputs", "tols"])
- @staticmethod
- def reset(batch_size, n_units):
- restVals = jnp.zeros((batch_size, n_units))
- return restVals, restVals, restVals
-
- def save(self, directory, **kwargs):
- target_freq = (self.target_freq if isinstance(self.target_freq, float)
- else jnp.ones([[self.target_freq]]))
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name, key=self.key.value, target_freq=target_freq)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.key.set(data['key'])
- self.target_freq = data['target_freq']
+
+ self.outputs.set((eps < pspike).astype(jnp.float32))
+ self.tols.set((1. - self.outputs.get()) * self.tols.get() + (self.outputs.get() * t))
+ self.key.set(key)
+
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size, self.n_units))
+ # BUG: `self.inputs` may not carry a `targeted` field at this point
+ # NOTE: as a quick workaround, only clear `inputs` when the field exists and is False
+ if hasattr(self.inputs, "targeted") and not self.inputs.targeted:
+     self.inputs.set(restVals)
+ self.outputs.set(restVals)
+ self.tols.set(restVals)
@classmethod
def help(cls): ## component help function
@@ -126,22 +98,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if
- Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
-
if __name__ == '__main__':
from ngcsimlib.context import Context
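
The refactored PoissonCell keeps the same scaled-Bernoulli approximation of a Poisson process: the per-step spike probability is pspike = inputs * (dt / 1000.) * target_freq. With dt = 1 ms and the default target_freq = 63.75 Hz, a unit clamped to 1.0 fires with probability 0.06375 per step, i.e. about 63.75 expected spikes per simulated second. Note that the old `validate` guard, which warned whenever (dt / 1000.) * target_freq exceeded 1, is removed in this diff, so that stability bound is no longer checked for the caller.
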
diff --git a/ngclearn/components/jaxComponent.py b/ngclearn/components/jaxComponent.py
index 0488c47c..8cffa49a 100755
--- a/ngclearn/components/jaxComponent.py
+++ b/ngclearn/components/jaxComponent.py
@@ -1,8 +1,13 @@
import time
+
+from typing import Union, Dict, Any
+import jax
+from jax import numpy as jnp
from jax import random
-#from ngclearn import resolver, Component, Compartment
-from ngcsimlib.component import Component
from ngcsimlib.compartment import Compartment
+from ngcsimlib import Component
+from ngclearn.utils import tensorstats
+
class JaxComponent(Component):
"""
@@ -14,12 +19,56 @@ class JaxComponent(Component):
key: PRNG key to control determinism of any underlying random values
associated with this cell
- directory: string indicating directory on disk to save component parameter
- values to
"""
- def __init__(self, name, key=None, directory=None, **kwargs):
- super().__init__(name, **kwargs)
- self.directory = directory
+ def __init__(self, name: str, key: Union[jax.Array, None] = None):
+ super().__init__(name)
self.key = Compartment(
random.PRNGKey(time.time_ns()) if key is None else key)
+
+ def save(self, directory: str):
+ """
+ The default save method for JaxComponents; it stores the values of all
+ non-targeted (non-wired) compartments in a single .npz file.
+
+ Args:
+ directory: the directory in which to save the .npz file
+ """
+ file_name = directory + "/" + self.name + ".npz"
+ data = {}
+ for comp_name, comp in self.compartments:
+ if not comp.targeted and comp.auto_save:
+ data[comp_name] = comp.get()
+ jnp.savez(file_name, **data)
+
+
+ def load(self, directory: str):
+ """
+ The default load method for JaxComponents; it is the counterpart to the
+ default save method. If save is overridden, this method will generally
+ need to be overridden as well.
+
+ Args:
+ directory: the directory from which to load the .npz file
+ """
+ file_name = directory + "/" + self.name + ".npz"
+ data = jnp.load(file_name)
+ for comp_name, comp in self.compartments:
+ d = data.get(comp_name, None)
+ if d is not None:
+ comp.set(d)
+
+ def __repr__(self):
+ comps = [varname for varname in dir(self) if isinstance(getattr(self, varname), Compartment)]
+ maxlen = max(len(c) for c in comps) + 5
+ lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
+ for c in comps:
+ stats = tensorstats(getattr(self, c).get())
+ if stats is not None:
+ line = [f"{k}: {v}" for k, v in stats.items()]
+ line = ", ".join(line)
+ else:
+ line = "None"
+ lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
+ return lines
+
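
A minimal sketch of the new default persistence path provided by `JaxComponent`; the directory name is hypothetical, `cell` stands for any JaxComponent instance (e.g. the BernoulliCell from the earlier sketch), and it is assumed `self.compartments` iterates as `(name, compartment)` pairs exactly as the `save`/`load` bodies above do:

import os

os.makedirs("params", exist_ok=True)
cell.save("params")   ## writes params/<cell.name>.npz from non-targeted, auto_save compartments
cell.load("params")   ## restores every compartment whose name is found in that archive
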
diff --git a/ngclearn/components/lava/__init__.py b/ngclearn/components/lava/__init__.py
deleted file mode 100644
index 962f843a..00000000
--- a/ngclearn/components/lava/__init__.py
+++ /dev/null
@@ -1,11 +0,0 @@
-## lava-compliant neuronal cells
-from .neurons.LIFCell import LIFCell
-## lava-compliant synapses
-from .synapses.staticSynapse import StaticSynapse
-from .synapses.traceSTDPSynapse import TraceSTDPSynapse
-from .synapses.hebbianSynapse import HebbianSynapse
-## Lava-compliant encoders/traces
-from .traces.gatedTrace import GatedTrace
-
-#monitor
-from .monitor import Monitor
\ No newline at end of file
diff --git a/ngclearn/components/lava/monitor.py b/ngclearn/components/lava/monitor.py
deleted file mode 100644
index aaabf8f8..00000000
--- a/ngclearn/components/lava/monitor.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from ngclearn.components.base_monitor import Base_Monitor
-
-
-class Monitor(Base_Monitor):
- """
- A numpy implementation of `Base_Monitor`. Designed to be used with all lava compatible ngclearn components
- """
- auto_resolve = False
-
-
- @staticmethod
- def build_advance(compartments):
- @staticmethod
- def _advance(**kwargs):
- return_vals = []
- for comp in compartments:
- new_val = kwargs[comp]
- current_store = kwargs[comp + "*store"]
- current_store[:-1] = current_store[1:]
- current_store[-1] = new_val
- return_vals.append(current_store)
- return return_vals if len(compartments) > 1 else return_vals[0]
-
- return _advance
-
- @staticmethod
- def build_advance_state(component):
- return super().build_advance_state(component)
-
- @staticmethod
- def build_reset(component):
- return super().build_reset(component)
diff --git a/ngclearn/components/lava/neurons/LIFCell.py b/ngclearn/components/lava/neurons/LIFCell.py
deleted file mode 100644
index e0ba3641..00000000
--- a/ngclearn/components/lava/neurons/LIFCell.py
+++ /dev/null
@@ -1,177 +0,0 @@
-from ngclearn import numpy as jnp
-from ngcsimlib.logger import info, warn
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
-from ngclearn.utils import tensorstats
-
-class LIFCell(Component): ## Lava-compliant leaky integrate-and-fire cell
- """
- A spiking cell based on (leaky) integrate-and-fire (LIF) neuronal dynamics.
- Note that this cell can be readily configured to pure integrate-and-fire
- dynamics as needed. Note that dynamics in this Lava-compliant cell are
- hard-coded to move according to Euler integration.
-
- The specific differential equation that characterize this cell
- is (for adjusting v, given current j, over time) is:
-
- | tau_m * dv/dt = gamma_d * (v_rest - v) + j * R
- | where R is the membrane resistance and v_rest is the resting potential
- | gamma_d is voltage decay -- 1 recovers LIF dynamics and 0 recovers IF dynamics
-
- | --- Cell Input Compartments: (Takes wired-in signals) ---
- | j_exc - excitatory electrical input
- | j_inh - inhibitory electrical input
- | --- Cell Output Compartments: (These signals are generated) ---
- | v - membrane potential/voltage state
- | s - emitted binary spikes/action potentials
- | rfr - (relative) refractory variable state
- | thr_theta - homeostatic/adaptive threshold increment state
-
- Args:
- name: the string name of this cell
-
- n_units: number of cellular entities (neural population size)
-
- dt: integration time constant (ms)
-
- tau_m: cell membrane time constant
-
- thr_theta_init: initialization kernel for threshold increment variable
-
- resist_m: membrane resistance value (Default: 1)
-
- thr: base value for adaptive thresholds that govern short-term
- plasticity (in milliVolts, or mV)
-
- v_rest: membrane resting potential (in mV)
-
- v_reset: membrane reset potential (in mV) -- upon occurrence of a spike,
- a neuronal cell's membrane potential will be set to this value
-
- v_decay: decay factor applied to voltage leak (Default: 1.); setting this
- to 0 mV results in pure integrate-and-fire (IF) dynamics
-
- tau_theta: homeostatic threshold time constant
-
- theta_plus: physical increment to be applied to any threshold value if
- a spike was emitted
-
- refract_time: relative refractory period time (ms; Default: 1 ms)
-
- thr_theta0: (DEPRECATED) initial conditions for voltage threshold
- """
-
- # Define Functions
- def __init__(self, name, n_units, dt, tau_m, thr_theta_init=None, resist_m=1.,
- thr=-52., v_rest=-65., v_reset=-60., v_decay=1., tau_theta=1e7,
- theta_plus=0.05, refract_time=5., thr_theta0=None, **kwargs):
- super().__init__(name, **kwargs)
-
- ## Cell dynamics setup
- self.dt = dt
- self.tau_m = tau_m ## membrane time constant
- self.R_m = resist_m ## resistance value
- if kwargs.get("R_m") is not None:
- warn("The argument `R_m` being used is deprecated.")
- self.Rscale = kwargs.get("R_m")
- self.v_rest = v_rest # mV
- self.v_reset = v_reset # mV (milli-volts)
- self.v_decay = v_decay
- ## basic asserts to prevent neuronal dynamics breaking...
- assert (self.v_decay * self.dt / self.tau_m) <= 1.
- assert self.R_m > 0.
- self.tau_theta = tau_theta ## threshold time constant # ms (0 turns off)
- self.theta_plus = theta_plus ## threshold increment
- self.refract_T = refract_time ## refractory period # ms
- self.thr = thr ## (fixed) base value for threshold # mV
- self.thr_theta_init = thr_theta_init
- self.thr_theta0 = thr_theta0 ## initial jittered adaptive threshold values
-
- ## Component size setup
- self.batch_size = 1
- self.n_units = n_units
-
- ## Compartment setup
- restVals = jnp.zeros((self.batch_size, self.n_units))
- self.j_exc = Compartment(restVals)
- self.j_inh = Compartment(restVals)
- self.v = Compartment(restVals + self.v_rest)
- self.s = Compartment(restVals)
- self.rfr = Compartment(restVals + self.refract_T)
- self.thr_theta = Compartment(None)
-
- if thr_theta0 is not None:
- warn("The argument `thr_theta0` being used is deprecated.")
- self._init(thr_theta0)
- else:
- if self.thr_theta_init is None:
- info(self.name, "is using default threshold variable initializer!")
- self.thr_theta_init = {"dist": "constant", "value": 0.}
- thr_theta0 = initialize_params(None, self.thr_theta_init, (1, self.n_units))
- self._init(thr_theta0)
-
- def _init(self, thr_theta0):
- self.thr_theta.set(thr_theta0)
-
- @transition(output_compartments=["v", "s", "rfr", "thr_theta"])
- @staticmethod
- def advance_state(dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T, thr, tau_theta,
- theta_plus, j_exc, j_inh, v, s, rfr, thr_theta):
- #j = j * (tau_m/dt) ## scale electrical current
- j = j_exc - j_inh ## sum the excitatory and inhibitory input channels
- mask = (rfr >= refract_T) * 1. #numpy.greater_equal(rfr, refract_T) * 1.
- ## update voltage / membrane potential
- ### note: the ODE is a bit differently formulated here than usual
- dv_dt = (v_rest - v) * v_decay * (dt/tau_m) + ((j * R_m) * mask)
- v = v + dv_dt ### hard-coded Euler integration
- ## obtain action potentials/spikes
- s = (v > (thr + thr_theta)) * 1. #numpy.greater_equal(v, thr + thr_theta) * 1.
- ## update refractory variables
- rfr = (rfr + dt) * (1. - s)
- ## perform hyper-polarization of neuronal cells
- v = v * (1. - s) + s * v_reset
- ## update adaptive threshold variables
- theta_decay = jnp.exp(-dt/tau_theta)
- thr_theta = thr_theta * theta_decay + s * theta_plus
- ## update time-of-last-spike
- #tols = (1. - s) * tols + (s * t)
- return v, s, rfr, thr_theta #, tols
-
- @transition(output_compartments=["j_exc", "j_inh", "v", "s", "rfr"])
- @staticmethod
- def reset(batch_size, n_units, v_rest, refract_T):
- restVals = jnp.zeros((batch_size, n_units))
- j_exc = restVals #+ 0
- j_inh = restVals #+ 0
- v = restVals + v_rest
- s = restVals #+ 0
- rfr = restVals + refract_T
- return j_exc, j_inh, v, s, rfr #, tols
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name,
- threshold_theta=self.thr_theta.value)
-
- def load(self, directory, seeded=False, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self._init( data['threshold_theta'] )
-
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
diff --git a/ngclearn/components/lava/neurons/__init__.py b/ngclearn/components/lava/neurons/__init__.py
deleted file mode 100644
index e28ed0f8..00000000
--- a/ngclearn/components/lava/neurons/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .LIFCell import LIFCell
diff --git a/ngclearn/components/lava/synapses/__init__.py b/ngclearn/components/lava/synapses/__init__.py
deleted file mode 100644
index bd7f9ea3..00000000
--- a/ngclearn/components/lava/synapses/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .staticSynapse import StaticSynapse
-from .hebbianSynapse import HebbianSynapse
-from .traceSTDPSynapse import TraceSTDPSynapse
diff --git a/ngclearn/components/lava/synapses/hebbianSynapse.py b/ngclearn/components/lava/synapses/hebbianSynapse.py
deleted file mode 100644
index c06a3792..00000000
--- a/ngclearn/components/lava/synapses/hebbianSynapse.py
+++ /dev/null
@@ -1,159 +0,0 @@
-from ngclearn import numpy as jnp
-from ngcsimlib.logger import info, warn
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
-from ngclearn.utils import tensorstats
-
-class HebbianSynapse(Component): ## Lava-compliant Hebbian synapse
- """
- A synaptic cable that adjusts its efficacies via a two-factor Hebbian adjustment rule. This is a Lava-compliant
- synaptic cable that adjusts with a hard-coded form of (stochastic) gradient ascent.
-
- | --- Synapse Input Compartments: (Takes wired-in signals) ---
- | inputs - input (pre-synaptic) stimulus
- | --- Synaptic Plasticity Input Compartments: (Takes in wired-in signals) ---
- | pre - pre-synaptic signal to drive first term of Hebbian update
- | post - post-synaptic signal to drive 2nd term of Hebbian update
- | eta - global learning rate (unidimensional/scalar value)
- | --- Synapse Output Compartments: (These signals are generated) ---
- | outputs - transformed (post-synaptic) signal
- | weights - current value matrix of synaptic efficacies (this is post-update if eta > 0)
-
- Args:
- name: the string name of this cell
-
- dt: integration time constant (ms)
-
- resist_scale: a fixed scaling factor to apply to synaptic transform
- (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b
-
- weight_init: a kernel to drive initialization of this synaptic cable's values;
- typically a tuple with 1st element as a string calling the name of
- initialization to use
-
- shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
- with number of inputs by number of outputs)
-
- eta: global learning rate
-
- w_decay: degree to which (L2) synaptic weight decay is applied to the
- computed Hebbian adjustment (Default: 0); note that decay is not
- applied to any configured biases
-
- w_bound: maximum weight to softly bound this cable's value matrix to; if
- set to 0, then no synaptic value bounding will be applied
-
- weights: matrix of synaptic weight values to initialize this synapse
- component to
-
- Rscale: DEPRECATED argument (maps to resist_scale)
- """
-
- # Define Functions
- def __init__(self, name, dt, resist_scale=1., weight_init=None, shape=None,
- eta=0., w_decay=0., w_bound=1., weights=None, **kwargs):
- super().__init__(name, **kwargs)
-
- ## synaptic plasticity properties and characteristics
- self.weight_init = weight_init
- self.shape = shape
- self.batch_size = 1
-
- self.dt = dt
- self.Rscale = resist_scale
- if kwargs.get("Rscale") is not None:
- warn("The argument `Rscale` being used is deprecated.")
- self.Rscale = kwargs.get("Rscale")
- self.w_bounds = w_bound
- self.w_decay = w_decay ## synaptic decay
- self.eta0 = eta
-
- self.inputs = Compartment(None)
- self.outputs = Compartment(None)
- self.pre = Compartment(None)
- self.post = Compartment(None)
- self.weights = Compartment(None)
- self.eta = Compartment(jnp.ones((1, 1)) * eta)
-
- if weights is not None:
- warn("The argument `weights` being used is deprecated.")
- self._init(weights)
- else:
- assert self.shape is not None ## if using an init, MUST have shape
- if self.weight_init is None:
- info(self.name, "is using default weight initializer!")
- self.weight_init = {"dist": "uniform", "amin": 0.025,
- "amax": 0.8}
- weights = initialize_params(None, self.weight_init, self.shape)
- self._init(weights)
-
- def _init(self, weights):
- self.rows = weights.shape[0]
- self.cols = weights.shape[1]
-
- ## pre-computed empty zero pads
- preVals = jnp.zeros((self.batch_size, self.rows))
- postVals = jnp.zeros((self.batch_size, self.cols))
- ## Compartments
- self.inputs.set(preVals)
- self.outputs.set(postVals)
- self.pre.set(preVals)
- self.post.set(postVals)
- self.weights.set(weights)
-
- @transition(output_compartments=["outputs", "weights"])
- @staticmethod
- def advance_state(dt, Rscale, w_bounds, w_decay, inputs, weights,
- pre, post, eta):
- outputs = jnp.matmul(inputs, weights) * Rscale
- ########################################################################
- ## Run one step of 2-factor Hebbian adaptation online
- dW = jnp.matmul(pre.T, post)
- #db = jnp.sum(_post, axis=0, keepdims=True)
- ## reformulated bounding flag to be linear algebraic
- flag = (w_bounds > 0.) * 1.
- dW = (dW * (w_bounds - jnp.abs(weights))) * flag + (dW) * (1. - flag)
- ## add small amount of synaptic decay
- weights = weights + (dW - weights * w_decay) * eta
- weights = jnp.clip(weights, 0., w_bounds)
- ########################################################################
- return outputs, weights
-
- @transition(output_compartments=["inputs", "outputs", "pre", "post", "eta"])
- @staticmethod
- def reset(batch_size, rows, cols, eta0):
- preVals = jnp.zeros((batch_size, rows))
- postVals = jnp.zeros((batch_size, cols))
- return (
- preVals, # inputs
- postVals, # outputs
- preVals, # pre
- postVals, # post
- jnp.ones((1,1)) * eta0
- )
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name, weights=self.weights.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self._init( data['weights'] )
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
diff --git a/ngclearn/components/lava/synapses/staticSynapse.py b/ngclearn/components/lava/synapses/staticSynapse.py
deleted file mode 100755
index 20f39ebe..00000000
--- a/ngclearn/components/lava/synapses/staticSynapse.py
+++ /dev/null
@@ -1,122 +0,0 @@
-from ngclearn import numpy as jnp
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info, warn
-from ngclearn.components.synapses.hebbian import TraceSTDPSynapse
-from ngclearn.utils import tensorstats
-
-class StaticSynapse(Component): ## Lava-compliant fixed/non-evolvable synapse
- """
- A static (dense) synaptic cable; no form of synaptic evolution/adaptation is in-built to this component. This a
- Lava-compliant version of the static synapse component from the synapses sub-package of components.
-
- | --- Synapse Input Compartments: (Takes wired-in signals) ---
- | inputs - input (pre-synaptic) stimulus
- | --- Synapse Output Compartments: .set()ese signals are generated) ---
- | outputs - transformed (post-synaptic) signal
- | weights - current value matrix of synaptic efficacies (this is post-update if eta > 0)
-
- Args:
- name: the string name of this cell
-
- dt: integration time constant (ms)
-
- weight_init: a kernel to drive initialization of this synaptic cable's values;
- typically a tuple with 1st element as a string calling the name of
- initialization to use
-
- shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
- with number of inputs by number of outputs)
-
- resist_scale: a fixed scaling factor to apply to synaptic transform
- (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b
-
- Rscale: DEPRECATED argument (maps to resist_scale)
-
- weights: a provided, externally created weight value matrix that will
- be used instead of an auto-init call
- """
-
- # Define Functions
- def __init__(self, name, dt, weight_init=None, shape=None, resist_scale=1.,
- weights=None, **kwargs):
- super().__init__(name, **kwargs)
-
- ## synaptic plasticity properties and characteristics
- self.batch_size = 1
- self.dt = dt
- self.Rscale = resist_scale
- if kwargs.get("Rscale") is not None:
- warn("The argument `Rscale` being used is deprecated.")
- self.Rscale = kwargs.get("Rscale")
- self.shape = shape
- self.weight_init = weight_init
-
- self.inputs = Compartment(None)
- self.outputs = Compartment(None)
- self.weights = Compartment(None)
-
- if weights is not None:
- warn("The argument `weights` being used is deprecated.")
- self._init(weights)
- else:
- assert self.shape is not None ## if using an init, MUST have shape
- if self.weight_init is None:
- info(self.name, "is using default weight initializer!")
- self.weight_init = {"dist": "uniform", "amin": 0.025,
- "amax": 0.8}
- weights = initialize_params(None, self.weight_init, self.shape)
- self._init(weights)
-
- def _init(self, weights):
- self.rows = weights.shape[0]
- self.cols = weights.shape[1]
- ## pre-computed empty zero pads
- preVals = jnp.zeros((self.batch_size, self.rows))
- postVals = jnp.zeros((self.batch_size, self.cols))
- ## Compartments
- self.inputs.set(preVals)
- self.outputs.set(postVals)
- self.weights.set(weights)
-
- @transition(output_compartments=["outputs"])
- @staticmethod
- def advance_state(dt, Rscale, inputs, weights):
- outputs = jnp.matmul(inputs, weights) * Rscale
- return outputs
-
- @transition(output_compartments=["inputs", "outputs"])
- @staticmethod
- def reset(batch_size, rows, cols):
- preVals = jnp.zeros((batch_size, rows))
- postVals = jnp.zeros((batch_size, cols))
- return (
- preVals, # inputs
- postVals, # outputs
- )
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name, weights=self.weights.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self._init( data['weights'] )
-
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
diff --git a/ngclearn/components/lava/synapses/traceSTDPSynapse.py b/ngclearn/components/lava/synapses/traceSTDPSynapse.py
deleted file mode 100755
index 23a3287d..00000000
--- a/ngclearn/components/lava/synapses/traceSTDPSynapse.py
+++ /dev/null
@@ -1,181 +0,0 @@
-from ngclearn import numpy as jnp
-from ngcsimlib.logger import info, warn
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
-from ngclearn.utils import tensorstats
-
-class TraceSTDPSynapse(Component): ## Lava-compliant trace-STDP synapse
- """
- A synaptic cable that adjusts its efficacies via trace-based form of spike-timing-dependent plasticity (STDP).
- This is a Lava-compliant synaptic cable that adjusts with a hard-coded form of (stochastic) gradient ascent.
-
- | --- Synapse Input Compartments: (Takes wired-in signals) ---
- | inputs - input (pre-synaptic) stimulus
- | --- Synaptic Plasticity Input Compartments: (Takes in wired-in signals) ---
- | pre - pre-synaptic spike(s) to drive STDP update
- | x_pre - pre-synaptic trace value(s) to drive STDP update
- | post - post-synaptic spike(s) to drive STDP update
- | x_post - post-synaptic trace value(s) to drive STDP update
- | eta - global learning rate (unidimensional/scalar value)
- | --- Synapse Output Compartments: (These signals are generated) ---
- | outputs - transformed (post-synaptic) signal
- | weights - current value matrix of synaptic efficacies (this is post-update if eta > 0)
-
- Args:
- name: the string name of this cell
-
- dt: integration time constant (ms)
-
- resist_scale: a fixed scaling factor to apply to synaptic transform
- (Default: 1.), i.e., yields: out = ((W * Rscale) * in) + b
-
- weight_init: a kernel to drive initialization of this synaptic cable's values;
- typically a tuple with 1st element as a string calling the name of
- initialization to use
-
- shape: tuple specifying shape of this synaptic cable (usually a 2-tuple
- with number of inputs by number of outputs)
-
- Aplus: strength of long-term potentiation (LTP)
-
- Aminus: strength of long-term depression (LTD)
-
- eta: global learning rate (default: 1)
-
- w_decay: degree to which (L2) synaptic weight decay is applied to the
- computed Hebbian adjustment (Default: 0); note that decay is not
- applied to any configured biases
-
- w_bound: maximum weight to softly bound this cable's value matrix to; if
- set to 0, then no synaptic value bounding will be applied
-
- preTrace_target: controls degree of pre-synaptic disconnect, i.e., amount of decay
- (higher -> lower synaptic values)
-
- weights: matrix of synaptic weight values to initialize this synapse
- component to
-
- Rscale: DEPRECATED argument (maps to resist_scale)
- """
-
- # Define Functions
- def __init__(self, name, dt, resist_scale=1., weight_init=None, shape=None,
- Aplus=0.01, Aminus=0.001, eta=1., w_decay=0., w_bound=1.,
- preTrace_target=0., weights=None, **kwargs):
- super().__init__(name, **kwargs)
-
- ## synaptic plasticity properties and characteristics
- self.weight_init = weight_init
- self.shape = shape
- self.dt = dt
- self.Rscale = resist_scale
- if kwargs.get("Rscale") is not None:
- warn("The argument `Rscale` being used is deprecated.")
- self.Rscale = kwargs.get("Rscale")
- self.w_bounds = w_bound
- self.w_decay = w_decay ## synaptic decay
- self.eta0 = eta
- self.Aplus = Aplus
- self.Aminus = Aminus
- self.x_tar = preTrace_target
-
- ## Component size setup
- self.batch_size = 1
-
- self.eta = Compartment(jnp.ones((1, 1)) * eta)
-
- self.inputs = Compartment(None)
- self.outputs = Compartment(None)
- self.pre = Compartment(None) ## pre-synaptic spike
- self.x_pre = Compartment(None) ## pre-synaptic trace
- self.post = Compartment(None) ## post-synaptic spike
- self.x_post = Compartment(None) ## post-synaptic trace
- self.weights = Compartment(None)
-
- if weights is not None:
- warn("The argument `weights` being used is deprecated.")
- self._init(weights)
- else:
- assert self.shape is not None ## if using an init, MUST have shape
- if self.weight_init is None:
- info(self.name, "is using default weight initializer!")
- self.weight_init = {"dist": "uniform", "amin": 0.025,
- "amax": 0.8}
- weights = initialize_params(None, self.weight_init, self.shape)
- self._init(weights)
-
- def _init(self, weights):
- self.rows = weights.shape[0]
- self.cols = weights.shape[1]
- ## pre-computed empty zero pads
- preVals = jnp.zeros((self.batch_size, self.rows))
- postVals = jnp.zeros((self.batch_size, self.cols))
- ## Compartments
- self.inputs.set(preVals)
- self.outputs.set(postVals)
- self.pre.set(preVals) ## pre-synaptic spike
- self.x_pre.set(preVals) ## pre-synaptic trace
- self.post.set(postVals) ## post-synaptic spike
- self.x_post.set(postVals) ## post-synaptic trace
- self.weights.set(weights)
-
- @transition(output_compartments=["outputs", "weights"])
- @staticmethod
- def advance_state(dt, Rscale, Aplus, Aminus, w_bounds, w_decay, x_tar,
- inputs, weights, pre, x_pre, post, x_post, eta):
- outputs = jnp.matmul(inputs, weights) * Rscale
- ########################################################################
- ## Run one step of STDP online
- dWpost = jnp.matmul((x_pre - x_tar).T, post * Aplus)
- dWpre = -jnp.matmul(pre.T, x_post * Aminus)
- dW = dWpost + dWpre
- ## reformulated bounding flag to be linear algebraic
- flag = (w_bounds > 0.) * 1.
- dW = (dW * (w_bounds - jnp.abs(weights))) * flag + (dW) * (1. - flag)
- ## physically adjust synapses
- weights = weights + (dW - weights * w_decay) * eta
- #weights = weights + (dW - weights * w_decay) * dt/tau_w ## ODE format
- weights = jnp.clip(weights, 0., w_bounds)
- ########################################################################
- return outputs, weights
-
- @transition(output_compartments=["inputs", "outputs", "pre", "post", "x_pre", "x_post", "eta"])
- @staticmethod
- def reset(batch_size, rows, cols, eta0):
- preVals = jnp.zeros((batch_size, rows))
- postVals = jnp.zeros((batch_size, cols))
- return (
- preVals, # inputs
- postVals, # outputs
- preVals, # pre
- postVals, # post
- preVals, # x_pre
- postVals, # x_post
- jnp.ones((1, 1)) * eta0
- )
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name, weights=self.weights.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self._init( data['weights'] )
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
diff --git a/ngclearn/components/lava/traces/__init__.py b/ngclearn/components/lava/traces/__init__.py
deleted file mode 100755
index 5dc901bf..00000000
--- a/ngclearn/components/lava/traces/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .gatedTrace import GatedTrace
diff --git a/ngclearn/components/lava/traces/gatedTrace.py b/ngclearn/components/lava/traces/gatedTrace.py
deleted file mode 100755
index 941fe061..00000000
--- a/ngclearn/components/lava/traces/gatedTrace.py
+++ /dev/null
@@ -1,69 +0,0 @@
-from ngclearn import numpy as jnp
-from ngcsimlib.logger import info, warn
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
-from ngclearn.utils import tensorstats
-
-class GatedTrace(Component): ## gated/piecewise low-pass filter
- """
- A gated/piecewise variable trace (filter). This is a Lava-compliant trace component.
-
- | --- Cell Input Compartments: (Takes wired-in signals) ---
- | inputs - input (takes wired-in external signals)
- | --- Cell Output Compartments: (These signals are generated) ---
- | trace - traced value signal
-
- Args:
- name: the string name of this operator
-
- n_units: number of calculating entities or units
-
- dt: integration time constant (ms)
-
- tau_tr: trace time constant (in milliseconds, or ms)
- """
-
- # Define Functions
- def __init__(self, name, n_units, dt, tau_tr, **kwargs):
- super().__init__(name, **kwargs)
-
- ## trace control coefficients
- self.dt = dt
- self.tau_tr = tau_tr ## trace time constant
-
- ## Layer size setup
- self.batch_size = 1
- self.n_units = n_units
-
- restVals = jnp.zeros((self.batch_size, self.n_units))
- self.inputs = Compartment(restVals) # input compartment
- self.trace = Compartment(restVals)
-
- @transition(output_compartments=["trace"])
- @staticmethod
- def advance_state(dt, tau_tr, inputs, trace):
- trace = (trace * (1. - dt/tau_tr)) * (1. - inputs) + inputs
- return trace
-
- @transition(output_compartments=["inputs", "trace"])
- @staticmethod
- def reset(batch_size, n_units):
- restVals = jnp.zeros((batch_size, n_units))
- return restVals, restVals
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
diff --git a/ngclearn/components/monitor.py b/ngclearn/components/monitor.py
deleted file mode 100644
index 3b373cf3..00000000
--- a/ngclearn/components/monitor.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from ngclearn.components.base_monitor import Base_Monitor
-from ngclearn import transition
-
-class Monitor(Base_Monitor):
- """
- A jax implementation of `Base_Monitor`. Designed to be used with all
- non-lava ngclearn components
- """
- auto_resolve = False
-
- @staticmethod
- def _record_internal(compartments):
- @staticmethod
- def _record(**kwargs):
- return_vals = []
- for comp in compartments:
- new_val = kwargs[comp]
- current_store = kwargs[comp + "*store"]
- current_store = current_store.at[:-1].set(current_store[1:])
- current_store = current_store.at[-1].set(new_val)
- return_vals.append(current_store)
- return return_vals if len(compartments) > 1 else return_vals[0]
- return _record
-
- @staticmethod
- def build_advance_state(component):
- return super().build_advance_state(component)
-
- @staticmethod
- def build_reset(component):
- return super().build_reset(component)
diff --git a/ngclearn/components/neurons/__init__.py b/ngclearn/components/neurons/__init__.py
index e7165d7e..564577cd 100644
--- a/ngclearn/components/neurons/__init__.py
+++ b/ngclearn/components/neurons/__init__.py
@@ -1,5 +1,6 @@
## point to rate-coded cell componet types
from .graded.rateCell import RateCell
+from .graded.leakyNoiseCell import LeakyNoiseCell
from .graded.gaussianErrorCell import GaussianErrorCell
from .graded.laplacianErrorCell import LaplacianErrorCell
from .graded.bernoulliErrorCell import BernoulliErrorCell
@@ -15,3 +16,4 @@
from .spiking.izhikevichCell import IzhikevichCell
from .spiking.hodgkinHuxleyCell import HodgkinHuxleyCell
from .spiking.RAFCell import RAFCell
+
diff --git a/ngclearn/components/neurons/graded/__init__.py b/ngclearn/components/neurons/graded/__init__.py
index bde64b39..2974d91f 100644
--- a/ngclearn/components/neurons/graded/__init__.py
+++ b/ngclearn/components/neurons/graded/__init__.py
@@ -1,6 +1,8 @@
-## point to rate-coded cell componet types
+## point to rate-coded cell component types
from .rateCell import RateCell
+from .leakyNoiseCell import LeakyNoiseCell
from .gaussianErrorCell import GaussianErrorCell
from .laplacianErrorCell import LaplacianErrorCell
from .bernoulliErrorCell import BernoulliErrorCell
from .rewardErrorCell import RewardErrorCell
+
diff --git a/ngclearn/components/neurons/graded/bernoulliErrorCell.py b/ngclearn/components/neurons/graded/bernoulliErrorCell.py
index 6bf0ebe6..978aa1ce 100755
--- a/ngclearn/components/neurons/graded/bernoulliErrorCell.py
+++ b/ngclearn/components/neurons/graded/bernoulliErrorCell.py
@@ -1,14 +1,16 @@
-from ngclearn import resolver, Component, Compartment
+# %%
+
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, jit
-from ngclearn.utils import tensorstats
from ngclearn.utils.model_utils import sigmoid, d_sigmoid
-from ngcsimlib.compilers.process import transition
+
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
class BernoulliErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
"""
A simple (non-spiking) Bernoulli error cell - this is a fixed-point solution
- of a mismatch signal. Specifically, this cell operates as a factorized multivariate
+ of a mismatch signal. Specifically, this cell operates as a factorized multivariate
Bernoulli distribution.
| --- Cell Input Compartments: ---
@@ -59,14 +61,20 @@ def __init__(self, name, n_units, batch_size=1, input_logits=False, shape=None,
self.modulator = Compartment(restVals + 1.0) # to be set/consumed
self.mask = Compartment(restVals + 1.0)
- @transition(output_compartments=["dp", "dtarget", "L", "mask"])
- @staticmethod
- def advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bernoulli error cell output
+ # @transition(output_compartments=["dp", "dtarget", "L", "mask"])
+ @compilable
+ def advance_state(self, dt): ## compute Bernoulli error cell output
+ # Get the variables
+ p = self.p.get()
+ target = self.target.get()
+ modulator = self.modulator.get()
+ mask = self.mask.get()
+
# Moves Bernoulli error cell dynamics one step forward. Specifically, this routine emulates the error unit
# behavior of the local cost functional
eps = 0.0001
_p = p
- if input_logits: ## convert from "logits" to probs via sigmoidal link function
+ if self.input_logits: ## convert from "logits" to probs via sigmoidal link function
_p = sigmoid(p)
_p = jnp.clip(_p, eps, 1. - eps) ## post-process to prevent div by 0
x = target
@@ -78,7 +86,7 @@ def advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bern
log_p = jnp.log(_p) ## ln(p)
log_one_min_p = jnp.log(one_min_p) ## ln(1 - p)
L = jnp.sum(log_p * x + log_one_min_p * one_min_x) ## Bern LL
- if input_logits:
+ if self.input_logits:
dL_dp = x - _p ## d(Bern LL)/dp where _p = sigmoid(p)
else:
dL_dp = x/(_p) - one_min_x/one_min_p ## d(Bern LL)/dp
@@ -89,14 +97,21 @@ def advance_state(dt, p, target, modulator, mask, input_logits): ## compute Bern
dp = dp * modulator * mask ## NOTE: how does mask apply to a multivariate Bernoulli?
dtarget = dL_dx * modulator * mask
mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
- return dp, dtarget, jnp.squeeze(L), mask
-
- @transition(output_compartments=["dp", "dtarget", "target", "p", "modulator", "L", "mask"])
- @staticmethod
- def reset(batch_size, shape): ## reset core components/statistics
- _shape = (batch_size, shape[0])
- if len(shape) > 1:
- _shape = (batch_size, shape[0], shape[1], shape[2])
+
+ # Set state
+ # dp, dtarget, jnp.squeeze(L), mask
+ self.dp.set(dp)
+ self.dtarget.set(dtarget)
+ self.L.set(jnp.squeeze(L))
+ self.mask.set(mask)
+
+
+ # @transition(output_compartments=["dp", "dtarget", "target", "p", "modulator", "L", "mask"])
+ @compilable
+ def reset(self): ## reset core components/statistics
+ _shape = (self.batch_size, self.shape[0])
+ if len(self.shape) > 1:
+ _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
restVals = jnp.zeros(_shape) ## "rest"/reset values
dp = restVals
dtarget = restVals
@@ -105,7 +120,16 @@ def reset(batch_size, shape): ## reset core components/statistics
modulator = restVals + 1. ## reset modulator signal
L = 0. #jnp.zeros((1, 1)) ## rest loss
mask = jnp.ones(_shape) ## reset mask
- return dp, dtarget, target, p, modulator, L, mask
+
+ # Set compartment
+ self.dp.set(dp)
+ self.dtarget.set(dtarget)
+ self.target.set(target)
+ self.p.set(p)
+ self.modulator.set(modulator)
+ self.L.set(L)
+ self.mask.set(mask)
+
@classmethod
def help(cls): ## component help function
@@ -135,20 +159,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py
index 29b5f267..776dad46 100755
--- a/ngclearn/components/neurons/graded/gaussianErrorCell.py
+++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py
@@ -1,8 +1,9 @@
-from ngclearn import resolver, Component, Compartment
+# %%
+
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, jit
-from ngclearn.utils import tensorstats
-from ngcsimlib.compilers.process import transition
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
class GaussianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
"""
@@ -71,9 +72,15 @@ def eval_log_density(target, mu, Sigma):
log_density = -jnp.sum(jnp.square(_dmu)) * (0.5 / Sigma)
return log_density
- @transition(output_compartments=["dmu", "dtarget", "dSigma", "L", "mask"])
- @staticmethod
- def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian error cell output
+ @compilable
+ def advance_state(self, dt): ## compute Gaussian error cell output
+ # Get the variables
+ mu = self.mu.get()
+ target = self.target.get()
+ Sigma = self.Sigma.get()
+ modulator = self.modulator.get()
+ mask = self.mask.get()
+
# Moves Gaussian cell dynamics one step forward. Specifically, this routine emulates the error unit
# behavior of the local cost functional:
# FIXME: Currently, below does: L(targ, mu) = -(1/(2*sigma)) * ||targ - mu||^2_2
@@ -90,24 +97,39 @@ def advance_state(dt, mu, target, Sigma, modulator, mask): ## compute Gaussian e
dmu = dmu * modulator * mask ## not sure how mask will apply to a full covariance...
dtarget = dtarget * modulator * mask
mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
- return dmu, dtarget, dSigma, jnp.squeeze(L), mask
- @transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"])
- @staticmethod
- def reset(batch_size, shape, sigma_shape): ## reset core components/statistics
- _shape = (batch_size, shape[0])
- if len(shape) > 1:
- _shape = (batch_size, shape[0], shape[1], shape[2])
+ # Update compartments
+ self.dmu.set(dmu)
+ self.dtarget.set(dtarget)
+ self.dSigma.set(dSigma)
+ self.L.set(jnp.squeeze(L))
+ self.mask.set(mask)
+
+ # @transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"])
+ # @staticmethod
+ @compilable
+ def reset(self): ## reset core components/statistics
+ _shape = (self.batch_size, self.shape[0])
+ if len(self.shape) > 1:
+ _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
restVals = jnp.zeros(_shape)
dmu = restVals
dtarget = restVals
- dSigma = jnp.zeros(sigma_shape)
+ dSigma = jnp.zeros(self.sigma_shape)
target = restVals
mu = restVals
modulator = mu + 1.
L = 0. #jnp.zeros((1, 1))
mask = jnp.ones(_shape)
- return dmu, dtarget, dSigma, target, mu, modulator, L, mask
+
+ self.dmu.set(dmu)
+ self.dtarget.set(dtarget)
+ self.dSigma.set(dSigma)
+ self.target.set(target)
+ self.mu.set(mu)
+ self.modulator.set(modulator)
+ self.L.set(L)
+ self.mask.set(mask)
@classmethod
def help(cls): ## component help function
@@ -139,20 +161,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
diff --git a/ngclearn/components/neurons/graded/laplacianErrorCell.py b/ngclearn/components/neurons/graded/laplacianErrorCell.py
index 6d825fe0..c881372b 100755
--- a/ngclearn/components/neurons/graded/laplacianErrorCell.py
+++ b/ngclearn/components/neurons/graded/laplacianErrorCell.py
@@ -1,8 +1,9 @@
-from ngclearn import resolver, Component, Compartment
+# %%
+
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, jit
-from ngclearn.utils import tensorstats
-from ngcsimlib.compilers.process import transition
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
class LaplacianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cell
"""
@@ -33,7 +34,6 @@ class LaplacianErrorCell(JaxComponent): ## Rate-coded/real-valued error unit/cel
to a constant/fixed `scale`
"""
- # Define Functions
def __init__(self, name, n_units, batch_size=1, scale=1., shape=None, **kwargs):
super().__init__(name, **kwargs)
@@ -44,7 +44,7 @@ def __init__(self, name, n_units, batch_size=1, scale=1., shape=None, **kwargs):
else:
_shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor
scale_shape = (1, 1)
- if not isinstance(scale, float) and not isinstance(sigma, int):
+ if not isinstance(scale, float) and not isinstance(scale, int):
scale_shape = jnp.array(scale).shape
self.scale_shape = scale_shape
## Layer Size setup
@@ -67,9 +67,15 @@ def __init__(self, name, n_units, batch_size=1, scale=1., shape=None, **kwargs):
self.modulator = Compartment(restVals + 1.0) ## to be set/consumed
self.mask = Compartment(restVals + 1.0)
- @transition(output_compartments=["dshift", "dtarget", "dScale", "L", "mask"])
- @staticmethod
- def advance_state(dt, shift, target, Scale, modulator, mask): ## compute Laplacian error cell output
+ @compilable
+ def advance_state(self, dt): ## compute Laplacian error cell output
+ # Get the variables
+ shift = self.shift.get()
+ target = self.target.get()
+ Scale = self.Scale.get()
+ modulator = self.modulator.get()
+ mask = self.mask.get()
+
# Moves Laplacian cell dynamics one step forward. Specifically, this routine emulates the error unit
# behavior of the local cost functional:
# FIXME: Currently, below does: L(targ, shift) = -||targ - shift||_1/scale
@@ -85,21 +91,34 @@ def advance_state(dt, shift, target, Scale, modulator, mask): ## compute Laplaci
dshift = dshift * modulator * mask
dtarget = dtarget * modulator * mask
mask = mask * 0. + 1. ## "eat" the mask as it should only apply at time t
- return dshift, dtarget, dScale, jnp.squeeze(L), mask
- @transition(output_compartments=["dshift", "dtarget", "dScale", "target", "shift", "modulator", "L", "mask"])
- @staticmethod
- def reset(batch_size, n_units, scale_shape):
- restVals = jnp.zeros((batch_size, n_units))
+ # Update compartments
+ self.dshift.set(dshift)
+ self.dtarget.set(dtarget)
+ self.dScale.set(dScale)
+ self.L.set(jnp.squeeze(L))
+ self.mask.set(mask)
+
+ @compilable
+ def reset(self): ## reset core components/statistics
+ restVals = jnp.zeros((self.batch_size, self.n_units))
dshift = restVals
dtarget = restVals
- dScale = jnp.zeros(scale_shape)
+ dScale = jnp.zeros(self.scale_shape)
target = restVals
shift = restVals
modulator = shift + 1.
L = 0.
- mask = jnp.ones((batch_size, n_units))
- return dshift, dtarget, dScale, target, shift, modulator, L, mask
+ mask = jnp.ones((self.batch_size, self.n_units))
+
+ self.dshift.set(dshift)
+ self.dtarget.set(dtarget)
+ self.dScale.set(dScale)
+ self.target.set(target)
+ self.shift.set(shift)
+ self.modulator.set(modulator)
+ self.L.set(L)
+ self.mask.set(mask)
@classmethod
def help(cls): ## component help function
@@ -130,20 +149,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
diff --git a/ngclearn/components/neurons/graded/leakyNoiseCell.py b/ngclearn/components/neurons/graded/leakyNoiseCell.py
new file mode 100755
index 00000000..85c4cd03
--- /dev/null
+++ b/ngclearn/components/neurons/graded/leakyNoiseCell.py
@@ -0,0 +1,157 @@
+from jax import numpy as jnp, random, jit
+from ngcsimlib.logger import info
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+from ngclearn.components.jaxComponent import JaxComponent
+from ngclearn.utils.model_utils import create_function
+from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2, step_rk4
+
+def _dfz_fn(z, j_input, j_recurrent, eps, tau_x, sigma_rec, leak_scale): ## raw dynamics ODE
+ dz_dt = -(z * leak_scale) + (j_recurrent + j_input) + jnp.sqrt(2. * tau_x * jnp.square(sigma_rec)) * eps
+ return dz_dt * (1. / tau_x)
+
+def _dfz(t, z, params): ## raw dynamics ODE wrapper
+ j_input, j_recurrent, eps, tau_x, sigma_rec, leak_scale = params
+ return _dfz_fn(z, j_input, j_recurrent, eps, tau_x, sigma_rec, leak_scale)
+
+class LeakyNoiseCell(JaxComponent): ## Real-valued, leaky noise cell
+ """
+    A non-spiking cell whose activity evolves according to continuous-time noisy, leaky recurrent-state dynamics.
+
+ Reference: https://pmc.ncbi.nlm.nih.gov/articles/PMC4771709/
+
+    The specific differential equation that characterizes this cell (for adjusting x) is:
+
+    | tau_x * dx/dt = -(x * leak_scale) + j_rec + j_in + sqrt(2 * tau_x * (sigma_rec)^2) * eps
+ | where j_in is the set of incoming input signals
+ | and j_rec is the set of recurrent input signals
+ | and eps is a sample of unit Gaussian noise, i.e., eps ~ N(0, 1)
+
+ | --- Cell Input Compartments: ---
+ | j_input - input (bottom-up) electrical/stimulus current (takes in external signals)
+ | j_recurrent - recurrent electrical/stimulus pressure
+ | --- Cell State Compartments ---
+ | x - noisy rate activity / current value of state
+ | --- Cell Output Compartments: ---
+ | r - post-rectified activity, i.e., fx(x) = relu(x)
+
+ Args:
+ name: the string name of this cell
+
+ n_units: number of cellular entities (neural population size)
+
+ tau_x: state membrane time constant (milliseconds)
+
+ act_fx: rectification function (Default: "relu")
+
+        leak_scale: multiplicative factor applied to the state leak term -x (Default: 1.)
+
+ integration_type: type of integration to use for this cell's dynamics;
+ current supported forms include "euler" (Euler/RK-1 integration) and "midpoint" or "rk2"
+ (midpoint method/RK-2 integration) (Default: "euler")
+
+ :Note: setting the integration type to the midpoint method will increase the accuracy of the estimate of
+ the cell's evolution at an increase in computational cost (and simulation time)
+
+ sigma_rec: noise scaling factor / standard deviation (Default: 1)
+ """
+
+ # Define Functions
+ def __init__(
+ self, name, n_units, tau_x, act_fx="relu", integration_type="euler", batch_size=1, sigma_rec=1.,
+ leak_scale=1., shape=None, **kwargs
+ ):
+ super().__init__(name, **kwargs)
+
+
+ self.tau_x = tau_x
+        self.sigma_rec = sigma_rec ## noise scaling factor / standard deviation
+ self.leak_scale = leak_scale ## the leak scaling factor (most appropriate default is 1)
+
+ ## integration properties
+ self.integrationType = integration_type
+ self.intgFlag = get_integrator_code(self.integrationType)
+
+ ## Layer size setup
+ _shape = (batch_size, n_units) ## default shape is 2D/matrix
+ if shape is None:
+ shape = (n_units,) ## we set shape to be equal to n_units if nothing provided
+ else:
+ _shape = (batch_size, shape[0], shape[1], shape[2]) ## shape is 4D tensor
+ self.shape = shape
+ self.n_units = n_units
+ self.batch_size = batch_size
+
+ self.fx, self.dfx = create_function(fun_name=act_fx)
+
+ # compartments (state of the cell & parameters will be updated through stateless calls)
+ restVals = jnp.zeros(_shape)
+ self.j_input = Compartment(restVals, display_name="Input Stimulus Current", units="mA") # electrical current
+ self.j_recurrent = Compartment(restVals, display_name="Recurrent Stimulus Current", units="mA") # electrical current
+ self.x = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity
+ self.r = Compartment(restVals, display_name="Rectified Rate Activity") # rectified output
+
+ @compilable
+ def advance_state(self, t, dt):
+ ### run a step of integration over neuronal dynamics
+ key, skey = random.split(self.key.get(), 2)
+ eps = random.normal(skey, shape=self.x.get().shape) ## sample of unit distributional noise
+
+ #x = _run_cell(dt, self.j_input.get(), self.j_recurrent.get(), self.x.get(), eps, self.tau_x, self.sigma_rec, integType=self.intgFlag)
+ _step_fns = {
+ 0: step_euler,
+ 1: step_rk2,
+ 2: step_rk4,
+ }
+ _step_fn = _step_fns[self.intgFlag] #_step_fns.get(self.intgFlag, step_euler)
+ params = (self.j_input.get(), self.j_recurrent.get(), eps, self.tau_x, self.sigma_rec, self.leak_scale)
+ _, x = _step_fn(0., self.x.get(), _dfz, dt, params) ## update state activation dynamics
+ r = self.fx(x) ## calculate rectified / post-activation function value(s)
+
+ ## set compartments to next state values in accordance with dynamics
+ self.key.set(key)
+ self.x.set(x)
+ self.r.set(r)
+
+ @compilable
+ def reset(self):
+ _shape = (self.batch_size, self.shape[0])
+ if len(self.shape) > 1:
+ _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
+ restVals = jnp.zeros(_shape)
+ self.j_input.set(restVals)
+ self.j_recurrent.set(restVals)
+ self.x.set(restVals)
+ self.r.set(restVals)
+
+ @classmethod
+ def help(cls): ## component help function
+ properties = {
+ "cell_type": "LeakyNoiseCell - evolves neurons according to continuous-time noisy/leaky dynamics "
+ }
+ compartment_props = {
+ "inputs":
+ {"j_input": "External input stimulus value(s)",
+ "j_recurrent": "Recurrent/prior-state stimulus value(s)"},
+ "states":
+ {"x": "Update to continuous noisy, leaky dynamics; value at time t"},
+ "outputs":
+ {"r": "A linear rectifier applied to rate-coded dynamics; f(z)"},
+ }
+ hyperparams = {
+ "n_units": "Number of neuronal cells to model in this layer",
+ "batch_size": "Batch size dimension of this component",
+ "tau_x": "State time constant",
+ "sigma_rec": "The non-zero degree/scale of noise to inject into this neuron"
+ }
+ info = {cls.__name__: properties,
+ "compartments": compartment_props,
+ "dynamics": "tau_x * dz/dt = -z + j_input + j_recurrent + noise, where noise ~N(0, sigma_rec)",
+ "hyperparameters": hyperparams}
+ return info
+
+if __name__ == '__main__':
+ from ngcsimlib.context import Context
+ with Context("Bar") as bar:
+ X = LeakyNoiseCell("X", 9, 0.03)
+ print(X)
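
The new cell's dynamics can also be reproduced outside the component for quick inspection. A framework-free Euler step of the ODE implemented by _dfz_fn above (illustrative sketch only; ngc-learn's step_euler/step_rk2/step_rk4 handle the actual integration):

from jax import numpy as jnp, random

def leaky_noise_euler_step(x, j_in, j_rec, key, dt, tau_x, sigma_rec=1., leak_scale=1.):
    eps = random.normal(key, x.shape)  ## unit Gaussian noise, eps ~ N(0, 1)
    dx_dt = (-(x * leak_scale) + (j_rec + j_in)
             + jnp.sqrt(2. * tau_x * jnp.square(sigma_rec)) * eps) / tau_x
    return x + dx_dt * dt

x_next = leaky_noise_euler_step(jnp.zeros((1, 9)), jnp.ones((1, 9)), 0., random.PRNGKey(0), dt=1., tau_x=10.)
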
diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py
index e55ce8ff..f70b0f52 100755
--- a/ngclearn/components/neurons/graded/rateCell.py
+++ b/ngclearn/components/neurons/graded/rateCell.py
@@ -1,16 +1,16 @@
# %%
from jax import numpy as jnp, random, jit
-from functools import partial
-from ngclearn.utils import tensorstats
-# from ngclearn import resolver, Component, Compartment
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.compilers.process import transition
+
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.components.jaxComponent import JaxComponent
from ngclearn.utils.model_utils import create_function, threshold_soft, \
threshold_cauchy
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2, step_rk4
+from ngcsimlib.logger import info
+
def _dfz_internal_laplace(z, j, j_td, tau_m, leak_gamma): ## raw dynamics
z_leak = jnp.sign(z) ## d/dx of Laplace is signum
@@ -158,11 +158,12 @@ class RateCell(JaxComponent): ## Rate-coded/real-valued cell
resist_scale: a scaling factor applied to incoming pressure `j` (default: 1)
"""
- # Define Functions
def __init__(
self, name, n_units, tau_m, prior=("gaussian", 0.), act_fx="identity", output_scale=1., threshold=("none", 0.),
integration_type="euler", batch_size=1, resist_scale=1., shape=None, is_stateful=True, **kwargs):
- super().__init__(name, **kwargs)
+ jax_comp_kwargs = {k: v for k, v in kwargs.items() if k not in ('omega_0',)}
+ this_class_kwargs = {k: v for k, v in kwargs.items() if k in ('omega_0',)}
+ super().__init__(name, **jax_comp_kwargs)
## membrane parameter setup (affects ODE integration)
self.output_scale = output_scale
@@ -199,10 +200,9 @@ def __init__(
self.n_units = n_units
self.batch_size = batch_size
-
omega_0 = None
if act_fx == "sine":
- omega_0 = kwargs["omega_0"]
+ omega_0 = this_class_kwargs["omega_0"]
self.fx, self.dfx = create_function(fun_name=act_fx, args=omega_0)
# compartments (state of the cell & parameters will be updated through stateless calls)
@@ -212,70 +212,79 @@ def __init__(
self.j_td = Compartment(restVals, display_name="Modulatory Stimulus Current", units="mA") # top-down electrical current - pressure
self.z = Compartment(restVals, display_name="Rate Activity", units="mA") # rate activity
- @transition(output_compartments=["j", "j_td", "z", "zF"])
- @staticmethod
- def advance_state(
- dt, fx, dfx, tau_m, priorLeakRate, intgFlag, priorType, resist_scale, thresholdType, thr_lmbda, is_stateful,
- output_scale, j, j_td, z):
+ @compilable
+ def advance_state(self, dt):
+ # Get the compartment values
+ j = self.j.get()
+ j_td = self.j_td.get()
+ z = self.z.get()
+
#if tau_m > 0.:
- if is_stateful:
+ if self.is_stateful:
### run a step of integration over neuronal dynamics
## Notes:
## self.pressure <-- "top-down" expectation / contextual pressure
## self.current <-- "bottom-up" data-dependent signal
- dfx_val = dfx(z)
+ dfx_val = self.dfx(z)
j = _modulate(j, dfx_val)
- j = j * resist_scale
- tmp_z = _run_cell(dt, j, j_td, z,
- tau_m, leak_gamma=priorLeakRate,
- integType=intgFlag, priorType=priorType)
+ j = j * self.resist_scale
+ tmp_z = _run_cell(
+ dt, j, j_td, z, self.tau_m, leak_gamma=self.priorLeakRate, integType=self.intgFlag,
+ priorType=self.priorType
+ )
## apply optional thresholding sub-dynamics
- if thresholdType == "soft_threshold":
- tmp_z = threshold_soft(tmp_z, thr_lmbda)
- elif thresholdType == "cauchy_threshold":
- tmp_z = threshold_cauchy(tmp_z, thr_lmbda)
+ if self.thresholdType == "soft_threshold":
+ tmp_z = threshold_soft(tmp_z, self.thr_lmbda)
+ elif self.thresholdType == "cauchy_threshold":
+ tmp_z = threshold_cauchy(tmp_z, self.thr_lmbda)
z = tmp_z ## pre-activation function value(s)
- zF = fx(z) * output_scale ## post-activation function value(s)
+ zF = self.fx(z) * self.output_scale ## post-activation function value(s)
else:
## run in "stateless" mode (when no membrane time constant provided)
j_total = j + j_td
z = _run_cell_stateless(j_total)
- zF = fx(z) * output_scale
- return j, j_td, z, zF
-
- @transition(output_compartments=["j", "j_td", "z", "zF"])
- @staticmethod
- def reset(batch_size, shape): #n_units
- _shape = (batch_size, shape[0])
- if len(shape) > 1:
- _shape = (batch_size, shape[0], shape[1], shape[2])
+ zF = self.fx(z) * self.output_scale
+
+ # Update compartments
+ self.j.set(j)
+ self.j_td.set(j_td)
+ self.z.set(z)
+ self.zF.set(zF)
+
+ @compilable
+ def reset(self): #, batch_size, shape): #n_units
+ _shape = (self.batch_size, self.shape[0])
+ if len(self.shape) > 1:
+ _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
restVals = jnp.zeros(_shape)
- return tuple([restVals for _ in range(4)])
-
-
- def save(self, directory, **kwargs):
- ## do a protected save of constants, depending on whether they are floats or arrays
- tau_m = (self.tau_m if isinstance(self.tau_m, float)
- else jnp.ones([[self.tau_m]]))
- priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float)
- else jnp.ones([[self.priorLeakRate]]))
- resist_scale = (self.resist_scale if isinstance(self.resist_scale, float)
- else jnp.ones([[self.resist_scale]]))
-
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name,
- tau_m=tau_m, priorLeakRate=priorLeakRate,
- resist_scale=resist_scale) #, key=self.key.value)
-
- def load(self, directory, seeded=False, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- ## constants loaded in
- self.tau_m = data['tau_m']
- self.priorLeakRate = data['priorLeakRate']
- self.resist_scale = data['resist_scale']
- #if seeded:
- # self.key.set(data['key'])
+ self.j.set(restVals)
+ self.j_td.set(restVals)
+ self.z.set(restVals)
+ self.zF.set(restVals)
+
+ # def save(self, directory, **kwargs):
+ # ## do a protected save of constants, depending on whether they are floats or arrays
+ # tau_m = (self.tau_m if isinstance(self.tau_m, float)
+ # else jnp.ones([[self.tau_m]]))
+ # priorLeakRate = (self.priorLeakRate if isinstance(self.priorLeakRate, float)
+ # else jnp.ones([[self.priorLeakRate]]))
+ # resist_scale = (self.resist_scale if isinstance(self.resist_scale, float)
+ # else jnp.ones([[self.resist_scale]]))
+ #
+ # file_name = directory + "/" + self.name + ".npz"
+ # jnp.savez(file_name,
+ # tau_m=tau_m, priorLeakRate=priorLeakRate,
+ # resist_scale=resist_scale) #, key=self.key.value)
+ #
+ # def load(self, directory, seeded=False, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # data = jnp.load(file_name)
+ # ## constants loaded in
+ # self.tau_m = data['tau_m']
+ # self.priorLeakRate = data['priorLeakRate']
+ # self.resist_scale = data['resist_scale']
+ # #if seeded:
+ # # self.key.set(data['key'])
@classmethod
def help(cls): ## component help function
@@ -308,20 +317,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
diff --git a/ngclearn/components/neurons/graded/rewardErrorCell.py b/ngclearn/components/neurons/graded/rewardErrorCell.py
index fe9670c3..91a8056d 100755
--- a/ngclearn/components/neurons/graded/rewardErrorCell.py
+++ b/ngclearn/components/neurons/graded/rewardErrorCell.py
@@ -1,8 +1,9 @@
-from ngclearn import resolver, Component, Compartment
+# %%
+
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngclearn.utils import tensorstats
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
class RewardErrorCell(JaxComponent): ## Reward prediction error cell
"""
@@ -51,38 +52,53 @@ def __init__(self, name, n_units, alpha, ema_window_len=10,
self.accum_reward = Compartment(restVals) ## accumulated reward signal(s)
self.n_ep_steps = Compartment(jnp.zeros((self.batch_size, 1))) ## number of episode steps taken
- @transition(output_compartments=["mu", "rpe", "n_ep_steps", "accum_reward"])
- @staticmethod
- def advance_state(dt, use_online_predictor, alpha, mu, rpe, reward,
- n_ep_steps, accum_reward):
+ @compilable
+ def advance_state(self, dt):
+ # Get the variables
+ mu = self.mu.get()
+ reward = self.reward.get()
+ n_ep_steps = self.n_ep_steps.get()
+ accum_reward = self.accum_reward.get()
+
## compute/update RPE and predictor values
accum_reward = accum_reward + reward
rpe = reward - mu
- if use_online_predictor:
- mu = mu * (1. - alpha) + reward * alpha
+ if self.use_online_predictor:
+ mu = mu * (1. - self.alpha) + reward * self.alpha
n_ep_steps = n_ep_steps + 1
- return mu, rpe, n_ep_steps, accum_reward
- @transition(output_compartments=["mu"])
- @staticmethod
- def evolve(dt, use_online_predictor, ema_window_len, n_ep_steps, mu,
- accum_reward):
- if use_online_predictor:
+ # Update compartments
+ self.mu.set(mu)
+ self.rpe.set(rpe)
+ self.n_ep_steps.set(n_ep_steps)
+ self.accum_reward.set(accum_reward)
+
+ @compilable
+ def evolve(self, dt):
+ # Get the variables
+ mu = self.mu.get()
+ n_ep_steps = self.n_ep_steps.get()
+ accum_reward = self.accum_reward.get()
+
+ if self.use_online_predictor:
## total episodic reward signal
r = accum_reward/n_ep_steps
- mu = (1. - 1./ema_window_len) * mu + (1./ema_window_len) * r
- return mu
+ mu = (1. - 1./self.ema_window_len) * mu + (1./self.ema_window_len) * r
- @transition(output_compartments=["mu", "rpe", "accum_reward", "n_ep_steps"])
- @staticmethod
- def reset(batch_size, n_units):
- restVals = jnp.zeros((batch_size, n_units))
+ # Update compartment
+ self.mu.set(mu)
+
+ @compilable
+ def reset(self): ## reset core components/statistics
+ restVals = jnp.zeros((self.batch_size, self.n_units))
mu = restVals
rpe = restVals
accum_reward = restVals
- n_ep_steps = jnp.zeros((batch_size, 1))
- return mu, rpe, accum_reward, n_ep_steps
-
+ n_ep_steps = jnp.zeros((self.batch_size, 1))
+ self.mu.set(mu)
+ self.rpe.set(rpe)
+ self.accum_reward.set(accum_reward)
+ self.n_ep_steps.set(n_ep_steps)
@classmethod
def help(cls): ## component help function
@@ -115,16 +131,8 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
+if __name__ == '__main__':
+ from ngcsimlib.context import Context
+ with Context("Bar") as bar:
+ X = RewardErrorCell("X", 9, 0.03)
+ print(X)
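
The reward-prediction-error arithmetic above is small enough to restate in plain Python (illustrative names only): the per-step RPE and online predictor update, plus the episodic EMA update performed in evolve.

def rpe_step(mu, reward, alpha):
    rpe = reward - mu                        ## reward prediction error
    mu = mu * (1. - alpha) + reward * alpha  ## online predictor update
    return mu, rpe

def rpe_evolve(mu, accum_reward, n_ep_steps, ema_window_len):
    r = accum_reward / n_ep_steps            ## mean episodic reward
    return (1. - 1. / ema_window_len) * mu + (1. / ema_window_len) * r
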
diff --git a/ngclearn/components/neurons/spiking/IFCell.py b/ngclearn/components/neurons/spiking/IFCell.py
index 08416f6d..640d9995 100755
--- a/ngclearn/components/neurons/spiking/IFCell.py
+++ b/ngclearn/components/neurons/spiking/IFCell.py
@@ -1,18 +1,14 @@
from ngclearn.components.jaxComponent import JaxComponent
-from jax import numpy as jnp, random, jit, nn
-from functools import partial
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
-from ngcsimlib.logger import info, warn
+from jax import numpy as jnp, random, nn, Array, jit
+from ngcsimlib import deprecate_args
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
triangular_estimator,
straight_through_estimator)
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
@jit
@@ -35,7 +31,7 @@ class IFCell(JaxComponent): ## integrate-and-fire cell
The specific differential equation that characterizes this cell
is (for adjusting v, given current j, over time) is:
- | tau_m * dv/dt = (v_rest - v) + j * R
+ | tau_m * dv/dt = j * R
| where R is the membrane resistance and v_rest is the resting potential
| also, if a spike occurs, v is set to v_reset
@@ -91,10 +87,10 @@ class IFCell(JaxComponent): ## integrate-and-fire cell
"""
@deprecate_args(thr_jitter=None)
- def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
- v_reset=-60., refract_time=0., integration_type="euler",
- surrogate_type="straight_through", lower_clamp_voltage=True,
- **kwargs):
+ def __init__(
+ self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., refract_time=0.,
+ integration_type="euler", surrogate_type="straight_through", lower_clamp_voltage=True, **kwargs
+ ):
super().__init__(name, **kwargs)
## Integration properties
@@ -118,12 +114,12 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
self.n_units = n_units
## set up surrogate function for spike emission
- if surrogate_type == "arctan":
- self.spike_fx, self.d_spike_fx = arctan_estimator()
- elif surrogate_type == "triangular":
- self.spike_fx, self.d_spike_fx = triangular_estimator()
- else: ## default: straight_through
- self.spike_fx, self.d_spike_fx = straight_through_estimator()
+ # if surrogate_type == "arctan":
+ # self.spike_fx, self.d_spike_fx = arctan_estimator()
+ # elif surrogate_type == "triangular":
+ # self.spike_fx, self.d_spike_fx = triangular_estimator()
+ # else: ## default: straight_through
+ # self.spike_fx, self.d_spike_fx = straight_through_estimator()
## Compartment setup
@@ -136,76 +132,50 @@ def __init__(self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65.,
display_name="Refractory Time Period", units="ms")
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
units="ms") ## time-of-last-spike
- self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
+ #self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
- @transition(output_compartments=["v", "s", "rfr", "tols", "key", "surrogate"])
- @staticmethod
+ @compilable
def advance_state(
- t, dt, tau_m, resist_m, v_rest, v_reset, refract_T, thr, lower_clamp_voltage, intgFlag, d_spike_fx, key,
- j, v, rfr, tols
+ self, dt, t
):
## run one integration step for neuronal dynamics
- j = j * resist_m
+ j = self.j.get() * self.resist_m
### Runs integrator (or integrate-and-fire; IF) neuronal dynamics
## update voltage / membrane potential
- v_params = (j, rfr, tau_m, refract_T)
- if intgFlag == 1:
- _, _v = step_rk2(0., v, _dfv, dt, v_params)
+ v_params = (j, self.rfr.get(), self.tau_m, self.refract_T)
+ if self.intgFlag == 1:
+ _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)
else:
- _, _v = step_euler(0., v, _dfv, dt, v_params)
+ _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params)
## obtain action potentials/spikes
- s = (_v > thr) * 1.
+ s = (_v > self.thr) * 1.
## update refractory variables
- rfr = (rfr + dt) * (1. - s)
+ rfr = (self.rfr.get() + dt) * (1. - s)
## perform hyper-polarization of neuronal cells
- v = _v * (1. - s) + s * v_reset
+ v = _v * (1. - s) + s * self.v_reset
+
+ #surrogate = d_spike_fx(v, self.thr)
- surrogate = d_spike_fx(v, thr)
## update tols
- tols = (1. - s) * tols + (s * t)
- if lower_clamp_voltage: ## ensure voltage never < v_rest
- v = jnp.maximum(v, v_rest)
- return v, s, rfr, tols, key, surrogate
-
- @transition(output_compartments=["j", "v", "s", "rfr", "tols", "surrogate"])
- @staticmethod
- def reset(batch_size, n_units, v_rest, refract_T):
- restVals = jnp.zeros((batch_size, n_units))
- j = restVals #+ 0
- v = restVals + v_rest
- s = restVals #+ 0
- rfr = restVals + refract_T
- tols = restVals #+ 0
- surrogate = restVals + 1.
- return j, v, s, rfr, tols, surrogate
-
- def save(self, directory, **kwargs):
- ## do a protected save of constants, depending on whether they are floats or arrays
- tau_m = (self.tau_m if isinstance(self.tau_m, float)
- else jnp.asarray([[self.tau_m * 1.]]))
- thr = (self.thr if isinstance(self.thr, float)
- else jnp.asarray([[self.thr * 1.]]))
- v_rest = (self.v_rest if isinstance(self.v_rest, float)
- else jnp.asarray([[self.v_rest * 1.]]))
- v_reset = (self.v_reset if isinstance(self.v_reset, float)
- else jnp.asarray([[self.v_reset * 1.]]))
- v_decay = (self.v_decay if isinstance(self.v_decay, float)
- else jnp.asarray([[self.v_decay * 1.]]))
- resist_m = (self.resist_m if isinstance(self.resist_m, float)
- else jnp.asarray([[self.resist_m * 1.]]))
- tau_theta = (self.tau_theta if isinstance(self.tau_theta, float)
- else jnp.asarray([[self.tau_theta * 1.]]))
- theta_plus = (self.theta_plus if isinstance(self.theta_plus, float)
- else jnp.asarray([[self.theta_plus * 1.]]))
+ self.tols.set((1. - s) * self.tols.get() + (s * t))
+ if self.lower_clamp_voltage: ## ensure voltage never < v_rest
+            v = jnp.maximum(v, self.v_rest)
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name,
- tau_m=tau_m, thr=thr, v_rest=v_rest,
- v_reset=v_reset, v_decay=v_decay,
- resist_m=resist_m, tau_theta=tau_theta,
- theta_plus=theta_plus,
- key=self.key.value)
+        self.v.set(v)
+ self.s.set(s)
+ self.rfr.set(rfr)
+
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size, self.n_units))
+ if not self.j.targeted:
+ self.j.set(restVals)
+ self.v.set(restVals + self.v_rest)
+ self.s.set(restVals)
+ self.rfr.set(restVals + self.refract_T)
+ self.tols.set(restVals)
+ #surrogate = restVals + 1.
def load(self, directory, seeded=False, **kwargs):
file_name = directory + "/" + self.name + ".npz"
@@ -260,17 +230,3 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
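
For the simplified IF dynamics above (tau_m * dv/dt = j * R, followed by spike emission, reset, and refractory bookkeeping), a pure-JAX single-step sketch may help when reviewing; note the refractory gating shown here is only a plausible reading, since _dfv's body is not part of this hunk:

from jax import numpy as jnp

def if_step(v, j, rfr, dt, tau_m, resist_m=1., thr=-52., v_reset=-60., refract_T=0.):
    dv = (j * resist_m) * (dt / tau_m) * (rfr >= refract_T)  ## integrate outside the refractory window (assumed)
    _v = v + dv
    s = (_v > thr) * 1.                   ## emit spikes
    v = _v * (1. - s) + s * v_reset       ## hyper-polarize cells that spiked
    rfr = (rfr + dt) * (1. - s)           ## advance/clear refractory clocks
    return v, s, rfr
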
diff --git a/ngclearn/components/neurons/spiking/LIFCell.py b/ngclearn/components/neurons/spiking/LIFCell.py
index 371e8058..6fedf559 100644
--- a/ngclearn/components/neurons/spiking/LIFCell.py
+++ b/ngclearn/components/neurons/spiking/LIFCell.py
@@ -1,18 +1,13 @@
from ngclearn.components.jaxComponent import JaxComponent
-from jax import numpy as jnp, random, jit, nn
-from functools import partial
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
-from ngcsimlib.logger import info, warn
+from jax import numpy as jnp, random, nn, Array
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
triangular_estimator,
straight_through_estimator)
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
def _dfv(t, v, params): ## voltage dynamics wrapper
j, rfr, tau_m, refract_T, v_rest, g_L = params
@@ -24,7 +19,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
#@partial(jit, static_argnums=[3, 4])
-def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
+def _update_theta(dt, v_theta, s, tau_theta, theta_plus: Array=0.05):
### Runs homeostatic threshold update dynamics one step (via Euler integration).
#theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
#theta_plus = 0.05
@@ -112,23 +107,23 @@ class LIFCell(JaxComponent): ## leaky integrate-and-fire cell
v_min: minimum voltage to clamp dynamics to (Default: None)
""" ## batch_size arg?
- @deprecate_args(thr_jitter=None, v_decay="conduct_leak")
def __init__(
self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., conduct_leak=1., tau_theta=1e7,
theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler", surrogate_type="straight_through",
- v_min=None, max_one_spike=False, **kwargs
+ v_min=None, max_one_spike=False, key=None
):
- super().__init__(name, **kwargs)
+ super().__init__(name, key)
## Integration properties
self.integrationType = integration_type
self.intgFlag = get_integrator_code(self.integrationType)
+ self.one_spike = one_spike ## True => constrains system to simulate 1 spike per time step
+ self.max_one_spike = max_one_spike
## membrane parameter setup (affects ODE integration)
self.tau_m = tau_m ## membrane time constant
self.resist_m = resist_m ## resistance value
- self.one_spike = one_spike ## True => constrains system to simulate 1 spike per time step
- self.max_one_spike = max_one_spike
+
self.v_min = v_min ## ensures voltage is never < v_min
self.v_rest = v_rest #-65. # mV
@@ -146,146 +141,87 @@ def __init__(
self.batch_size = 1
self.n_units = n_units
- ## set up surrogate function for spike emission
- if surrogate_type == "secant_lif":
- self.spike_fx, self.d_spike_fx = secant_lif_estimator()
- elif surrogate_type == "arctan":
- self.spike_fx, self.d_spike_fx = arctan_estimator()
- elif surrogate_type == "triangular":
- self.spike_fx, self.d_spike_fx = triangular_estimator()
- else: ## default: straight_through
- self.spike_fx, self.d_spike_fx = straight_through_estimator()
-
+ # ## set up surrogate function for spike emission
+ # if surrogate_type == "secant_lif":
+ # spike_fx, d_spike_fx = secant_lif_estimator()
+ # elif surrogate_type == "arctan":
+ # spike_fx, d_spike_fx = arctan_estimator()
+ # elif surrogate_type == "triangular":
+ # spike_fx, d_spike_fx = triangular_estimator()
+ # else: ## default: straight_through
+ # spike_fx, d_spike_fx = straight_through_estimator()
## Compartment setup
restVals = jnp.zeros((self.batch_size, self.n_units))
self.j = Compartment(restVals, display_name="Current", units="mA")
- self.v = Compartment(restVals + self.v_rest,
- display_name="Voltage", units="mV")
+ self.v = Compartment(restVals + self.v_rest, display_name="Voltage", units="mV")
self.s = Compartment(restVals, display_name="Spikes")
self.s_raw = Compartment(restVals, display_name="Raw Spike Pulses")
- self.rfr = Compartment(restVals + self.refract_T,
- display_name="Refractory Time Period", units="ms")
- self.thr_theta = Compartment(restVals, display_name="Threshold Adaptive Shift",
- units="mV")
- self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
- units="ms") ## time-of-last-spike
- self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
-
- @transition(output_compartments=["v", "s", "s_raw", "rfr", "thr_theta", "tols", "key", "surrogate"])
- @staticmethod
- def advance_state(
- t, dt, tau_m, resist_m, v_rest, v_reset, g_L, refract_T, thr, tau_theta, theta_plus, one_spike, max_one_spike,
- v_min, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols
- ):
- skey = None ## this is an empty dkey if single_spike mode turned off
- if one_spike and not max_one_spike:
- key, skey = random.split(key, 2)
- ## run one integration step for neuronal dynamics
- j = j * resist_m
- ############################################################################
- ### Runs leaky integrator (leaky integrate-and-fire; LIF) neuronal dynamics.
- _v_thr = thr_theta + thr ## calc present voltage threshold
- #mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
- ## update voltage / membrane potential
- v_params = (j, rfr, tau_m, refract_T, v_rest, g_L)
- if intgFlag == 1:
- _, _v = step_rk2(0., v, _dfv, dt, v_params)
+ self.rfr = Compartment(restVals + self.refract_T, display_name="Refractory Time Period", units="ms")
+ self.thr_theta = Compartment(restVals, display_name="Threshold Adaptive Shift", units="mV")
+ self.tols = Compartment(restVals, display_name="Time-of-Last-Spike", units="ms") ## time-of-last-spike
+ # self.surrogate = Compartment(restVals + 1., display_name="Surrogate State Value")
+
+ @compilable
+ def advance_state(self, dt, t):
+ j = self.j.get() * self.resist_m
+
+ _v_thr = self.thr_theta.get() + self.thr ## calc present voltage threshold
+
+            v_params = (j, self.rfr.get(), self.tau_m, self.refract_T, self.v_rest, self.g_L)
+
+ if self.intgFlag == 1:
+ _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)
else:
- _, _v = step_euler(0., v, _dfv, dt, v_params)
- ## obtain action potentials/spikes/pulses
+ _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params)
+
s = (_v > _v_thr) * 1.
- v_prespike = v
- ## update refractory variables
- _rfr = (rfr + dt) * (1. - s)
- ## perform hyper-polarization of neuronal cells
- _v = _v * (1. - s) + s * v_reset
-
- raw_s = s + 0 ## preserve un-altered spikes
- ############################################################################
- ## this is a spike post-processing step
- if skey is not None:
+ _rfr = (self.rfr.get() + dt) * (1. - s)
+ _v = _v * (1. - s) + s * self.v_reset
+
+ raw_s = s
+
+ if self.one_spike and not self.max_one_spike:
+ key, skey = random.split(self.key.get(), 2)
+
m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able
rS = s * random.uniform(skey, s.shape)
- rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1],
- dtype=jnp.float32)
+ rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1], dtype=jnp.float32)
s = s * (1. - m_switch) + rS * m_switch
- if max_one_spike:
- rS = nn.one_hot(jnp.argmax(v_prespike, axis=1), num_classes=s.shape[1], dtype=jnp.float32) ## get max-volt spike
+ self.key.set(key)
+
+ if self.max_one_spike:
+ rS = nn.one_hot(jnp.argmax(self.v.get(), axis=1), num_classes=s.shape[1], dtype=jnp.float32) ## get max-volt spike
s = s * rS ## mask out non-max volt spikes
- ############################################################################
- raw_spikes = raw_s
- v = _v
- rfr = _rfr
- surrogate = d_spike_fx(v, _v_thr) #d_spike_fx(v, thr + thr_theta)
- if tau_theta > 0.:
+ if self.tau_theta > 0.:
## run one integration step for threshold dynamics
- thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
- ## update tols
- tols = (1. - s) * tols + (s * t)
- if v_min is not None: ## ensures voltage never < v_rest
- v = jnp.maximum(v, v_min)
- return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate
-
- @transition(output_compartments=["j", "v", "s", "s_raw", "rfr", "tols", "surrogate"])
- @staticmethod
- def reset(batch_size, n_units, v_rest, refract_T):
- restVals = jnp.zeros((batch_size, n_units))
- j = restVals #+ 0
- v = restVals + v_rest
- s = restVals #+ 0
- s_raw = restVals
- rfr = restVals + refract_T
- #thr_theta = restVals ## do not reset thr_theta
- tols = restVals #+ 0
- surrogate = restVals + 1.
- return j, v, s, s_raw, rfr, tols, surrogate
-
- def save(self, directory, **kwargs):
- ## do a protected save of constants, depending on whether they are floats or arrays
- tau_m = (self.tau_m if isinstance(self.tau_m, float)
- else jnp.asarray([[self.tau_m * 1.]]))
- thr = (self.thr if isinstance(self.thr, float)
- else jnp.asarray([[self.thr * 1.]]))
- v_rest = (self.v_rest if isinstance(self.v_rest, float)
- else jnp.asarray([[self.v_rest * 1.]]))
- v_reset = (self.v_reset if isinstance(self.v_reset, float)
- else jnp.asarray([[self.v_reset * 1.]]))
- g_L = (self.g_L if isinstance(self.g_L, float)
- else jnp.asarray([[self.g_L * 1.]]))
- resist_m = (self.resist_m if isinstance(self.resist_m, float)
- else jnp.asarray([[self.resist_m * 1.]]))
- tau_theta = (self.tau_theta if isinstance(self.tau_theta, float)
- else jnp.asarray([[self.tau_theta * 1.]]))
- theta_plus = (self.theta_plus if isinstance(self.theta_plus, float)
- else jnp.asarray([[self.theta_plus * 1.]]))
-
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name,
- threshold_theta=self.thr_theta.value,
- tau_m=tau_m, thr=thr, v_rest=v_rest,
- v_reset=v_reset, g_L=g_L,
- resist_m=resist_m, tau_theta=tau_theta,
- theta_plus=theta_plus,
- key=self.key.value)
-
- def load(self, directory, seeded=False, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.thr_theta.set(data['threshold_theta'])
- ## constants loaded in
- self.tau_m = data['tau_m']
- self.thr = data['thr']
- self.v_rest = data['v_rest']
- self.v_reset = data['v_reset']
- self.g_L = data['g_L']
- self.resist_m = data['resist_m']
- self.tau_theta = data['tau_theta']
- self.theta_plus = data['theta_plus']
-
- if seeded:
- self.key.set(data['key'])
+            thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus)
+ self.thr_theta.set(thr_theta)
+
+ ## update time-of-last spike variable(s)
+ self.tols.set((1. - s) * self.tols.get() + (s * t))
+
+ if self.v_min is not None: ## ensures voltage never < v_rest
+ _v = jnp.maximum(_v, self.v_min)
+
+
+ self.v.set(_v)
+ self.s.set(s)
+ self.s_raw.set(raw_s)
+ self.rfr.set(_rfr)
+
+
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size, self.n_units))
+ if not self.j.targeted:
+ self.j.set(restVals)
+ self.v.set(restVals + self.v_rest)
+ self.s.set(restVals)
+ self.s_raw.set(restVals)
+ self.rfr.set(restVals + self.refract_T)
+ self.tols.set(restVals)
@classmethod
def help(cls): ## component help function
@@ -315,17 +251,13 @@ def help(cls): ## component help function
"v_reset": "Reset membrane potential value",
"conduct_leak": "Conductance leak / voltage decay factor",
"tau_theta": "Threshold/homoestatic increment time constant",
- "theta_plus": "Amount to increment threshold by upon occurrence "
- "of spike",
+ "theta_plus": "Amount to increment threshold by upon occurrence of a spike",
"refract_time": "Length of relative refractory period (ms)",
- "one_spike": "Should only one spike be sampled/allowed to emit at "
- "any given time step?",
- "integration_type": "Type of numerical integration to use for the "
- "cell dynamics",
+ "one_spike": "Should only one spike be sampled/allowed to emit at any given time step?",
+ "integration_type": "Type of numerical integration to use for the cell dynamics",
"surrgoate_type": "Type of surrogate function to use approximate "
"derivative of spike w.r.t. voltage/current",
- "lower_bound_clamp": "Should voltage be lower bounded to be never "
- "be below `v_rest`"
+ "v_min": "Minimum voltage allowed before voltage variables are min-clipped/clamped"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
@@ -333,20 +265,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
diff --git a/ngclearn/components/neurons/spiking/RAFCell.py b/ngclearn/components/neurons/spiking/RAFCell.py
index df95de1d..6c2bdc5d 100755
--- a/ngclearn/components/neurons/spiking/RAFCell.py
+++ b/ngclearn/components/neurons/spiking/RAFCell.py
@@ -1,16 +1,15 @@
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit, nn
-from functools import partial
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
+from ngcsimlib import deprecate_args
from ngcsimlib.logger import info, warn
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+########################################################################################################################
+## RAF dynamics (multi-dimensional ODEs)
@jit
def _dfv_internal(j, v, w, tau_m, omega, b): ## "voltage" dynamics
# dy/dt = omega x + b y
@@ -34,6 +33,7 @@ def _dfw(t, w, params): ## angular driver dynamics wrapper
j, v, tau_w, omega, b = params
dv_dt = _dfw_internal(j, v, w, tau_w, omega, b)
return dv_dt
+########################################################################################################################
class RAFCell(JaxComponent):
"""
@@ -60,8 +60,7 @@ class RAFCell(JaxComponent):
| tols - time-of-last-spike
| References:
- | Izhikevich, Eugene M. "Resonate-and-fire neurons." Neural networks
- | 14.6-7 (2001): 883-894.
+ | Izhikevich, Eugene M. "Resonate-and-fire neurons." Neural networks 14.6-7 (2001): 883-894.
Args:
name: the string name of this cell
@@ -77,7 +76,7 @@ class RAFCell(JaxComponent):
omega: angular frequency (Default: 10)
- b: oscillation dampening factor (Default: -1)
+ dampen_factor: oscillation dampening factor (Default: -1) ("b" in Izhikevich 2001)
v_reset: reset condition for membrane potential (Default: 1 mV)
@@ -98,10 +97,10 @@ class RAFCell(JaxComponent):
at an increase in computational cost (and simulation time)
"""
- @deprecate_args(resist_m="resist_v", tau_m="tau_v")
+ @deprecate_args(resist_m="resist_v", tau_m="tau_v", b="dampen_factor")
def __init__(
- self, name, n_units, tau_v=1., tau_w=1., thr=1., omega=10., b=-1., v_reset=0., w_reset=0., v0=0., w0=0.,
- resist_v=1., integration_type="euler", batch_size=1, **kwargs
+ self, name, n_units, tau_v=1., tau_w=1., thr=1., omega=10., dampen_factor=-1., v_reset=0., w_reset=0.,
+ v0=0., w0=0., resist_v=1., integration_type="euler", batch_size=1, **kwargs
):
#v_rest=-72., v_reset=-75., w_reset=0., thr=5., v0=-70., w0=0., tau_w=400., thr=5., omega=10., b=-1.
super().__init__(name, **kwargs)
@@ -115,8 +114,8 @@ def __init__(
self.resist_v = resist_v
self.tau_w = tau_w
self.omega = omega ## angular frequency
- self.b = b ## dampening factor
- ## note: the smaller b is, the faster the oscillation dampens to resting state values
+ self.dampen_factor = dampen_factor ## dampening factor (b)
+ ## Note: the smaller that dampen_factor "b" is, the faster the oscillation dampens to resting state values
self.v_reset = v_reset
self.w_reset = w_reset
self.v0 = v0
@@ -137,42 +136,44 @@ def __init__(
restVals, display_name="Time-of-Last-Spike", units="ms"
) ## time-of-last-spike
- @transition(output_compartments=["j", "v", "w", "s", "tols"])
- @staticmethod
- def advance_state(t, dt, tau_v, resist_v, tau_w, thr, omega, b,
- v_reset, w_reset, intgFlag, j, v, w, tols):
+ @compilable
+ def advance_state(self, t, dt):
## continue with centered dynamics
- j_ = j * resist_v
- if intgFlag == 1: ## RK-2/midpoint
+ j_ = self.j.get() * self.resist_v
+ if self.intgFlag == 1: ## RK-2/midpoint
## Note: we integrate ODEs in order: first w, then v
- w_params = (j_, v, tau_w, omega, b)
- _, _w = step_rk2(0., w, _dfw, dt, w_params)
- v_params = (j_, _w, tau_v, omega, b)
- _, _v = step_rk2(0., v, _dfv, dt, v_params)
+ w_params = (j_, self.v.get(), self.tau_w, self.omega, self.dampen_factor)
+ _, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params)
+ v_params = (j_, _w, self.tau_v, self.omega, self.dampen_factor)
+ _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)
else: # integType == 0 (default -- Euler)
## Note: we integrate ODEs in order: first w, then v
- w_params = (j_, v, tau_w, omega, b)
- _, _w = step_euler(0., w, _dfw, dt, w_params)
- v_params = (j_, _w, tau_v, omega, b)
- _, _v = step_euler(0., v, _dfv, dt, v_params)
- s = (_v > thr) * 1. ## emit spikes/pulses
+ w_params = (j_, self.v.get(), self.tau_w, self.omega, self.dampen_factor)
+ _, _w = step_euler(0., self.w.get(), _dfw, dt, w_params)
+ v_params = (j_, _w, self.tau_v, self.omega, self.dampen_factor)
+ _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params)
+
+ s = (_v > self.thr) * 1. ## emit spikes/pulses
## hyperpolarize/reset/snap variables
- w = _w * (1. - s) + s * w_reset
- v = _v * (1. - s) + s * v_reset
-
- tols = (1. - s) * tols + (s * t) ## update times-of-last-spike(s)
- return j, v, w, s, tols
-
- @transition(output_compartments=["j", "v", "w", "s", "tols"])
- @staticmethod
- def reset(batch_size, n_units, v0, w0):
- restVals = jnp.zeros((batch_size, n_units))
- j = restVals # None
- v = restVals + v0
- w = restVals + w0
- s = restVals #+ 0
- tols = restVals #+ 0
- return j, v, w, s, tols
+ w = _w * (1. - s) + s * self.w_reset
+ v = _v * (1. - s) + s * self.v_reset
+
+ self.tols.set((1. - s) * self.tols.get() + (s * t)) ## update times-of-last-spike(s)
+
+ #self.j.set(j_)
+ self.v.set(v)
+ self.w.set(w)
+ self.s.set(s)
+
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size, self.n_units))
+ if not self.j.targeted:
+ self.j.set(restVals)
+ self.v.set(restVals + self.v0)
+ self.w.set(restVals + self.w0)
+ self.s.set(restVals)
+ self.tols.set(restVals)
@classmethod
def help(cls): ## component help function
@@ -198,7 +199,7 @@ def help(cls): ## component help function
"tau_w": "Recovery variable time constant",
"v_reset": "Reset membrane potential value",
"w_reset": "Reset angular driver value",
- "b": "Exponential dampening factor applied to oscillations",
+ "dampen_factor": "Exponential dampening factor applied to oscillations",
"omega": "Angular frequency of neuronal progress per second (radians)",
"v0": "Initial condition for membrane potential/voltage",
"w0": "Initial condition for membrane angular driver variable",
@@ -207,21 +208,7 @@ def help(cls): ## component help function
}
info = {cls.__name__: properties,
"compartments": compartment_props,
- "dynamics": "tau_v * dv/dt = omega * w + v * b; "
- "tau_w * dw/dt = w * b - v * omega + j",
+ "dynamics": "tau_v * dv/dt = omega * w + v * dampen_factor; "
+ "tau_w * dw/dt = w * dampen_factor - v * omega + j",
"hyperparameters": hyperparams}
return info
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
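
The renamed dampen_factor leaves the resonate-and-fire dynamics themselves unchanged; a pure-JAX Euler sketch of the system in the help string above (w integrated before v, matching the ordering noted in advance_state; names illustrative only):

from jax import numpy as jnp

def raf_step(v, w, j, dt, tau_v, tau_w, omega, dampen_factor, thr, v_reset=0., w_reset=0.):
    _w = w + (w * dampen_factor - v * omega + j) * (dt / tau_w)  ## angular driver first
    _v = v + (omega * _w + v * dampen_factor) * (dt / tau_v)     ## then membrane potential
    s = (_v > thr) * 1.                                          ## emit spikes
    w = _w * (1. - s) + s * w_reset                              ## snap variables on spike
    v = _v * (1. - s) + s * v_reset
    return v, w, s
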
diff --git a/ngclearn/components/neurons/spiking/WTASCell.py b/ngclearn/components/neurons/spiking/WTASCell.py
index c6f9edb6..b4602c74 100755
--- a/ngclearn/components/neurons/spiking/WTASCell.py
+++ b/ngclearn/components/neurons/spiking/WTASCell.py
@@ -1,14 +1,9 @@
from jax import numpy as jnp, random, jit, nn
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit, nn
-from functools import partial
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
-from ngcsimlib.logger import info, warn
-
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngcsimlib import deprecate_args
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.utils.model_utils import softmax
@@ -53,7 +48,7 @@ class WTASCell(JaxComponent): ## winner-take-all spiking cell
thr_jitter: scale of uniform jitter to add to initialization of thresholds
"""
- # Define Functions
+ @deprecate_args(thrBase="thr_base")
def __init__(
self, name, n_units, tau_m, resist_m=1., thr_base=0.4, thr_gain=0.002, refract_time=0., thr_jitter=0.05,
**kwargs
@@ -74,7 +69,7 @@ def __init__(
## base threshold setup
## according to eqn 26 of the source paper, the initial condition for the
## threshold should technically be between: 1/n_units < threshold0 << 0.5, e.g., 0.15
- key, subkey = random.split(self.key.value)
+ key, subkey = random.split(self.key.get())
self.threshold0 = thr_base + random.uniform(subkey, (1, n_units),
minval=-thr_jitter, maxval=thr_jitter,
dtype=jnp.float32)
@@ -88,42 +83,44 @@ def __init__(
self.rfr = Compartment(restVals + self.refract_T)
self.tols = Compartment(restVals) ## time-of-last-spike
- @transition(output_compartments=["v", "s", "thr", "rfr", "tols"])
- @staticmethod
- def advance_state(t, dt, tau_m, R_m, thr_gain, refract_T, j, v, thr, rfr, tols):
- mask = (rfr >= refract_T) * 1. ## check refractory period
- v = (j * R_m) * mask
+ @compilable
+ def advance_state(self, t, dt):
+ mask = (self.rfr.get() >= self.refract_T) * 1. ## check refractory period
+ v = (self.j.get() * self.R_m) * mask
vp = softmax(v) # convert to Categorical (spike) probabilities
# s = nn.one_hot(jnp.argmax(vp, axis=1), j.shape[1]) ## hard-max spike
- s = (vp > thr) * 1. ## calculate action potential
+ s = (vp > self.thr.get()) * 1. ## calculate action potential
q = 1. ## Note: thr_gain ==> "rho_b"
## increment threshold upon spike(s) occurrence
dthr = jnp.sum(s, axis=1, keepdims=True) - q
- thr = jnp.maximum(thr + dthr * thr_gain, 0.025) ## calc new threshold
- rfr = (rfr + dt) * (1. - s) + s * dt # set refract to dt
-
- tols = (1. - s) * tols + (s * t) ## update tols
- return v, s, thr, rfr, tols
-
- @transition(output_compartments=["j", "v", "s", "rfr", "tols"])
- @staticmethod
- def reset(batch_size, n_units, refract_T):
- restVals = jnp.zeros((batch_size, n_units))
- j = restVals #+ 0
- v = restVals #+ 0
- s = restVals #+ 0
- rfr = restVals + refract_T
- tols = restVals #+ 0
- return j, v, s, rfr, tols
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name, threshold=self.thr.value)
-
- def load(self, directory, seeded=False, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.thr.set( data['threshold'] )
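+ ## homeostatic threshold step: dthr = (#spikes - 1), so thr is unchanged when exactly one unit
+ ## fires, rises when several fire, and decays toward the 0.025 floor when none do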
+ thr = jnp.maximum(self.thr.get() + dthr * self.thr_gain, 0.025) ## calc new threshold
+ rfr = (self.rfr.get() + dt) * (1. - s) + s * dt # set refract to dt
+
+ self.tols.set((1. - s) * self.tols.get() + (s * t)) ## update times-of-last-spike(s)
+
+ self.v.set(v)
+ self.s.set(s)
+ self.thr.set(thr)
+ self.rfr.set(rfr)
+
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size, self.n_units))
+ if not self.j.targeted:
+ self.j.set(restVals)
+ self.v.set(restVals)
+ self.s.set(restVals)
+ self.rfr.set(restVals + self.refract_T)
+ self.tols.set(restVals)
+
+ # def save(self, directory, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # jnp.savez(file_name, threshold=self.thr.get())
+ #
+ # def load(self, directory, seeded=False, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # data = jnp.load(file_name)
+ # self.thr.set( data['threshold'] )
@classmethod
def help(cls): ## component help function
@@ -158,20 +155,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
diff --git a/ngclearn/components/neurons/spiking/__init__.py b/ngclearn/components/neurons/spiking/__init__.py
index 690087b7..b4c0b3db 100644
--- a/ngclearn/components/neurons/spiking/__init__.py
+++ b/ngclearn/components/neurons/spiking/__init__.py
@@ -9,3 +9,4 @@
from .izhikevichCell import IzhikevichCell
from .RAFCell import RAFCell
from .hodgkinHuxleyCell import HodgkinHuxleyCell
+
diff --git a/ngclearn/components/neurons/spiking/adExCell.py b/ngclearn/components/neurons/spiking/adExCell.py
index fdff5f4c..1e55b55d 100755
--- a/ngclearn/components/neurons/spiking/adExCell.py
+++ b/ngclearn/components/neurons/spiking/adExCell.py
@@ -1,14 +1,11 @@
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit, nn
-from functools import partial
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
+from ngcsimlib import deprecate_args
from ngcsimlib.logger import info, warn
-from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
- step_euler, step_rk2
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2
+
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
@jit
def _dfv_internal(j, v, w, tau_m, v_rest, sharpV, vT, R_m): ## raw voltage dynamics
@@ -32,7 +29,7 @@ def _dfw(t, w, params): ## recovery dynamics wrapper
dv_dt = _dfw_internal(j, v, w, a, tau_m, v_rest)
return dv_dt
-class AdExCell(JaxComponent):
+class AdExCell(JaxComponent): ## adaptive exponential integrate-and-fire cell
"""
The AdEx (adaptive exponential leaky integrate-and-fire) neuronal cell
model; a two-variable model. This cell model iteratively evolves
@@ -136,39 +133,40 @@ def __init__(
self.tols = Compartment(restVals, display_name="Time-of-Last-Spike",
units="ms") ## time-of-last-spike
- @transition(output_compartments=["j", "v", "w", "s", "tols"])
- @staticmethod
- def advance_state(
- t, dt, tau_m, R_m, tau_w, thr, a, b, sharpV, vT, v_rest, v_reset, intgFlag, j, v, w, tols
- ):
- if intgFlag == 1: ## RK-2/midpoint
- v_params = (j, w, tau_m, v_rest, sharpV, vT, R_m)
- _, _v = step_rk2(0., v, _dfv, dt, v_params)
- w_params = (j, v, a, tau_w, v_rest)
- _, _w = step_rk2(0., w, _dfw, dt, w_params)
+ @compilable
+ def advance_state(self, t, dt):
+ if self.intgFlag == 1: ## RK-2/midpoint
+ v_params = (self.j.get(), self.w.get(), self.tau_m, self.v_rest, self.sharpV, self.vT, self.R_m)
+ _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)
+ w_params = (self.j.get(), self.v.get(), self.a, self.tau_w, self.v_rest)
+ _, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params)
else: # intgFlag == 0 (default -- Euler)
- v_params = (j, w, tau_m, v_rest, sharpV, vT, R_m)
- _, _v = step_euler(0., v, _dfv, dt, v_params)
- w_params = (j, v, a, tau_w, v_rest)
- _, _w = step_euler(0., w, _dfw, dt, w_params)
- s = (_v > thr) * 1. ## emit spikes/pulses
+ v_params = (self.j.get(), self.w.get(), self.tau_m, self.v_rest, self.sharpV, self.vT, self.R_m)
+ _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params)
+ w_params = (self.j.get(), self.v.get(), self.a, self.tau_w, self.v_rest)
+ _, _w = step_euler(0., self.w.get(), _dfw, dt, w_params)
+ s = (_v > self.thr) * 1. ## emit spikes/pulses
## hyperpolarize/reset/snap variables
- v = _v * (1. - s) + s * v_reset
- w = _w * (1. - s) + s * (_w + b)
-
- tols = (1. - s) * tols + (s * t) ## update time-of-last spike variable(s)
- return j, v, w, s, tols
-
- @transition(output_compartments=["j", "v", "w", "s", "tols"])
- @staticmethod
- def reset(batch_size, n_units, v0, w0):
- restVals = jnp.zeros((batch_size, n_units))
- j = restVals # None
- v = restVals + v0
- w = restVals + w0
- s = restVals #+ 0
- tols = restVals #+ 0
- return j, v, w, s, tols
+ v = _v * (1. - s) + s * self.v_reset
+ w = _w * (1. - s) + s * (_w + self.b)
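+ ## on a spike, v snaps to v_reset and the adaptation variable w is bumped by b; non-spiking units keep their integrated values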
+
+ ## update time-of-last spike variable(s)
+ self.tols.set((1. - s) * self.tols.get() + (s * t))
+
+ #self.j.set(j) ## j is not getting modified in these dynamics
+ self.v.set(v)
+ self.w.set(w)
+ self.s.set(s)
+
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size, self.n_units))
+ if not self.j.targeted:
+ self.j.set(restVals)
+ self.v.set(restVals + self.v0)
+ self.w.set(restVals + self.w0)
+ self.s.set(restVals)
+ self.tols.set(restVals)
@classmethod
def help(cls): ## component help function
@@ -211,20 +209,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
diff --git a/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py b/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py
index 2cab7f56..9fe7f603 100755
--- a/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py
+++ b/ngclearn/components/neurons/spiking/fitzhughNagumoCell.py
@@ -1,16 +1,12 @@
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit, nn
-from functools import partial
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
+from ngcsimlib import deprecate_args
from ngcsimlib.logger import info, warn
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
@jit
def _dfv_internal(j, v, w, a, b, g, tau_m): ## raw voltage dynamics
@@ -34,7 +30,7 @@ def _dfw(t, w, params): ## recovery dynamics wrapper
dv_dt = _dfw_internal(j, v, w, a, b, g, tau_m)
return dv_dt
-class FitzhughNagumoCell(JaxComponent):
+class FitzhughNagumoCell(JaxComponent): ## F-H cell
"""
The Fitzhugh-Nagumo neuronal cell model; a two-variable simplification
of the Hodgkin-Huxley (squid axon) model. This cell model iteratively evolves
@@ -103,10 +99,10 @@ class FitzhughNagumoCell(JaxComponent):
at an increase in computational cost (and simulation time)
"""
- # Define Functions
- def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
- beta=0.8, gamma=3., v0=0., w0=0., v_thr=1.07, spike_reset=False,
- integration_type="euler", **kwargs):
+ def __init__(
+ self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7, beta=0.8, gamma=3., v0=0., w0=0.,
+ v_thr=1.07, spike_reset=False, integration_type="euler", **kwargs
+ ):
super().__init__(name, **kwargs)
## Integration properties
@@ -115,7 +111,7 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
## Cell properties
self.tau_m = tau_m
- self.R_m = resist_m
+ self.resist_m = resist_m ## resistance R_m
self.tau_w = tau_w
self.alpha = alpha
self.beta = beta
@@ -138,41 +134,44 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., tau_w=12.5, alpha=0.7,
self.s = Compartment(restVals)
self.tols = Compartment(restVals) ## time-of-last-spike
- @transition(output_compartments=["j", "v", "w", "s", "tols"])
- @staticmethod
- def advance_state(t, dt, tau_m, R_m, tau_w, v_thr, spike_reset, v0, w0, alpha,
- beta, gamma, intgFlag, j, v, w, tols):
- j_mod = j * R_m
- if intgFlag == 1:
- v_params = (j_mod, w, alpha, beta, gamma, tau_m)
- _, _v = step_rk2(0., v, _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt)
- w_params = (j_mod, v, alpha, beta, gamma, tau_w)
- _, _w = step_rk2(0., w, _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt)
+ @compilable
+ def advance_state(self, t, dt):
+ j_mod = self.j.get() * self.resist_m
+ if self.intgFlag == 1:
+ v_params = (j_mod, self.w.get(), self.alpha, self.beta, self.gamma, self.tau_m)
+ _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt)
+ w_params = (j_mod, self.v.get(), self.alpha, self.beta, self.gamma, self.tau_w)
+ _, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt)
else: # integType == 0 (default -- Euler)
- v_params = (j_mod, w, alpha, beta, gamma, tau_m)
- _, _v = step_euler(0., v, _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt)
- w_params = (j_mod, v, alpha, beta, gamma, tau_w)
- _, _w = step_euler(0., w, _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt)
- s = (_v > v_thr) * 1.
+ v_params = (j_mod, self.w.get(), self.alpha, self.beta, self.gamma, self.tau_m)
+ _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt)
+ w_params = (j_mod, self.v.get(), self.alpha, self.beta, self.gamma, self.tau_w)
+ _, _w = step_euler(0., self.w.get(), _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt)
+ s = (_v > self.v_thr) * 1.
v = _v
w = _w
- if spike_reset: ## if spike-reset used, variables snapped back to initial conditions
- v = v * (1. - s) + s * v0
- w = w * (1. - s) + s * w0
- tols = (1. - s) * tols + (s * t) ## update tols
- return j, v, w, s, tols
-
- @transition(output_compartments=["j", "v", "w", "s", "tols"])
- @staticmethod
- def reset(batch_size, n_units, v0, w0):
- restVals = jnp.zeros((batch_size, n_units))
- j = restVals # None
- v = restVals + v0
- w = restVals + w0
- s = restVals #+ 0
- tols = restVals #+ 0
- return j, v, w, s, tols
+ if self.spike_reset: ## if spike-reset used, variables snapped back to initial conditions
+ v = v * (1. - s) + s * self.v0
+ w = w * (1. - s) + s * self.w0
+
+ ## update time-of-last spike variable(s)
+ self.tols.set((1. - s) * self.tols.get() + (s * t))
+
+ # self.j.set(j) ## j is not getting modified in these dynamics
+ self.v.set(v)
+ self.w.set(w)
+ self.s.set(s)
+
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size, self.n_units))
+ if not self.j.targeted:
+ self.j.set(restVals)
+ self.v.set(restVals + self.v0)
+ self.w.set(restVals + self.w0)
+ self.s.set(restVals)
+ self.tols.set(restVals)
@classmethod
def help(cls): ## component help function
@@ -197,8 +196,7 @@ def help(cls): ## component help function
"resist_m": "Membrane resistance value",
"tau_w": "Recovery variable time constant",
"v_thr": "Base voltage threshold value",
- "spike_reset": "Should voltage/recover be snapped to initial "
- "condition(s) if spike emitted?",
+ "spike_reset": "Should voltage/recover be snapped to initial condition(s) if spike emitted?",
"alpha": "Dimensionless recovery variable shift factor `a",
"beta": "Dimensionless recovery variable scale factor `b`",
"gamma": "Power-term divisor constant",
@@ -213,20 +211,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
diff --git a/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py b/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py
index 29ab648e..87ec823b 100644
--- a/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py
+++ b/ngclearn/components/neurons/spiking/hodgkinHuxleyCell.py
@@ -1,15 +1,11 @@
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit, nn
-from functools import partial
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
+from ngcsimlib import deprecate_args
from ngcsimlib.logger import info, warn
-from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
- step_euler, step_rk2, step_rk4
+from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2, step_rk4
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
def _calc_biophysical_constants(v): ## computes H-H biophysical constants (which are functions of voltage v)
@@ -113,7 +109,6 @@ class HodgkinHuxleyCell(JaxComponent): ## Hodgkin-Huxley spiking cell
at an increase in computational cost (and simulation time)
"""
- # Define Functions
def __init__(
self, name, n_units, tau_v, resist_m=1., v_Na=115., v_K=-35., v_L=10.6, g_Na=100., g_K=5., g_L=0.3, thr=4.,
spike_reset=False, v_reset=0., integration_type="euler", **kwargs
@@ -126,7 +121,7 @@ def __init__(
## cell properties / biophysical parameter setup (affects ODE integration)
self.tau_v = tau_v ## membrane time constant
- self.R_m = resist_m ## resistance value
+ self.resist_m = resist_m ## resistance value R_m
self.spike_reset = spike_reset
self.thr = thr # mV ## base value for threshold
self.v_reset = v_reset ## base value to reset voltage to (if spike_reset = True)
@@ -151,38 +146,49 @@ def __init__(
self.s = Compartment(restVals, display_name="Spike pulse")
self.tols = Compartment(restVals, display_name="Time-of-last-spike") ## time-of-last-spike
- @transition(output_compartments=["v", "m", "n", "h", "s", "tols"])
- @staticmethod
- def advance_state(
- t, dt, spike_reset, v_reset, thr, tau_v, R_m, g_Na, g_K, g_L, v_Na, v_K, v_L, j, v, m, n, h, tols, intgFlag
- ):
- _j = j * R_m
- alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = _calc_biophysical_constants(v)
+ @compilable
+ def advance_state(self, t, dt):
+ _j = self.j.get() * self.resist_m
+ alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = _calc_biophysical_constants(self.v.get())
## integrate voltage / membrane potential
- if intgFlag == 1: ## midpoint method
- _, _v = step_rk2(0., v, dv_dt, dt, (_j, m + 0., n + 0., h + 0., tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L))
+ if self.intgFlag == 1: ## midpoint method
+ _, _v = step_rk2(
+ 0., self.v.get(), dv_dt, dt,
+ (_j, self.m.get() + 0., self.n.get() + 0., self.h.get() + 0., self.tau_v, self.g_Na, self.g_K,
+ self.g_L, self.v_Na, self.v_K, self.v_L)
+ )
## next, integrate different channels
- _, _n = step_rk2(0., n, dx_dt, dt, (alpha_n_of_v, beta_n_of_v))
- _, _m = step_rk2(0., m, dx_dt, dt, (alpha_m_of_v, beta_m_of_v))
- _, _h = step_rk2(0., h, dx_dt, dt, (alpha_h_of_v, beta_h_of_v))
- elif intgFlag == 4: ## Runge-Kutta 4th order
- _, _v = step_rk4(0., v, dv_dt, dt, (_j, m + 0., n + 0., h + 0., tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L))
+ _, _n = step_rk2(0., self.n.get(), dx_dt, dt, (alpha_n_of_v, beta_n_of_v))
+ _, _m = step_rk2(0., self.m.get(), dx_dt, dt, (alpha_m_of_v, beta_m_of_v))
+ _, _h = step_rk2(0., self.h.get(), dx_dt, dt, (alpha_h_of_v, beta_h_of_v))
+ elif self.intgFlag == 4: ## Runge-Kutta 4th order
+ _, _v = step_rk4(
+ 0., self.v.get(), dv_dt, dt,
+ (_j, self.m.get() + 0., self.n.get() + 0., self.h.get() + 0., self.tau_v, self.g_Na, self.g_K,
+ self.g_L, self.v_Na, self.v_K, self.v_L)
+ )
## next, integrate different channels
- _, _n = step_rk4(0., n, dx_dt, dt, (alpha_n_of_v, beta_n_of_v))
- _, _m = step_rk4(0., m, dx_dt, dt, (alpha_m_of_v, beta_m_of_v))
- _, _h = step_rk4(0., h, dx_dt, dt, (alpha_h_of_v, beta_h_of_v))
+ _, _n = step_rk4(0., self.n.get(), dx_dt, dt, (alpha_n_of_v, beta_n_of_v))
+ _, _m = step_rk4(0., self.m.get(), dx_dt, dt, (alpha_m_of_v, beta_m_of_v))
+ _, _h = step_rk4(0., self.h.get(), dx_dt, dt, (alpha_h_of_v, beta_h_of_v))
else: # integType == 0 (default -- Euler)
- _, _v = step_euler(0., v, dv_dt, dt, (_j, m + 0., n + 0., h + 0., tau_v, g_Na, g_K, g_L, v_Na, v_K, v_L))
+ _, _v = step_euler(
+ 0., self.v.get(), dv_dt, dt,
+ (_j, self.m.get() + 0., self.n.get() + 0., self.h.get() + 0., self.tau_v, self.g_Na, self.g_K,
+ self.g_L, self.v_Na, self.v_K, self.v_L)
+ )
## next, integrate different channels
- _, _n = step_euler(0., n, dx_dt, dt, (alpha_n_of_v, beta_n_of_v))
- _, _m = step_euler(0., m, dx_dt, dt, (alpha_m_of_v, beta_m_of_v))
- _, _h = step_euler(0., h, dx_dt, dt, (alpha_h_of_v, beta_h_of_v))
+ _, _n = step_euler(0., self.n.get(), dx_dt, dt, (alpha_n_of_v, beta_n_of_v))
+ _, _m = step_euler(0., self.m.get(), dx_dt, dt, (alpha_m_of_v, beta_m_of_v))
+ _, _h = step_euler(0., self.h.get(), dx_dt, dt, (alpha_h_of_v, beta_h_of_v))
## obtain action potentials/spikes/pulses
- s = (_v > thr) * 1.
- if spike_reset: ## if spike-reset used, variables snapped back to initial conditions
+ s = (_v > self.thr) * 1.
+ if self.spike_reset: ## if spike-reset used, variables snapped back to initial conditions
alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = (
- _calc_biophysical_constants(v * 0 + v_reset))
- _v = _v * (1. - s) + s * v_reset
+ _calc_biophysical_constants(self.v.get() * 0 + self.v_reset))
+ _v = _v * (1. - s) + s * self.v_reset
_n = _n * (1. - s) + s * (alpha_n_of_v / (alpha_n_of_v + beta_n_of_v))
_m = _m * (1. - s) + s * (alpha_m_of_v / (alpha_m_of_v + beta_m_of_v))
_h = _h * (1. - s) + s * (alpha_h_of_v / (alpha_h_of_v + beta_h_of_v))
@@ -191,32 +197,40 @@ def advance_state(
m = _m
n = _n
h = _h
- tols = (1. - s) * tols + (s * t) ## update tols
+ ## update time-of-last spike variable(s)
+ self.tols.set((1. - s) * self.tols.get() + (s * t))
- return v, m, n, h, s, tols
+ self.v.set(v)
+ self.m.set(m)
+ self.n.set(n)
+ self.h.set(h)
+ self.s.set(s)
- @transition(output_compartments=["j", "v", "m", "n", "h", "s", "tols"])
- @staticmethod
- def reset(batch_size, n_units):
- restVals = jnp.zeros((batch_size, n_units))
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size, self.n_units))
v = restVals # + 0
alpha_n_of_v, beta_n_of_v, alpha_m_of_v, beta_m_of_v, alpha_h_of_v, beta_h_of_v = _calc_biophysical_constants(v)
- j = restVals #+ 0
+ if not self.j.targeted:
+ self.j.set(restVals)
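+ ## gating variables n, m, h below are (re)set to their steady states x_inf = alpha_x(v) / (alpha_x(v) + beta_x(v)) at the rest voltage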
n = alpha_n_of_v / (alpha_n_of_v + beta_n_of_v)
m = alpha_m_of_v / (alpha_m_of_v + beta_m_of_v)
h = alpha_h_of_v / (alpha_h_of_v + beta_h_of_v)
- s = restVals #+ 0
- tols = restVals #+ 0
- return j, v, m, n, h, s, tols
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- #jnp.savez(file_name, threshold=self.thr.value)
-
- def load(self, directory, seeded=False, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- #self.thr.set( data['threshold'] )
+ self.v.set(v)
+ self.n.set(n)
+ self.m.set(m)
+ self.h.set(h)
+ self.s.set(restVals)
+ self.tols.set(restVals)
+
+ # def save(self, directory, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # #jnp.savez(file_name, threshold=self.thr.value)
+ #
+ # def load(self, directory, seeded=False, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # data = jnp.load(file_name)
+ # #self.thr.set( data['threshold'] )
@classmethod
def help(cls): ## component help function
@@ -257,20 +271,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
diff --git a/ngclearn/components/neurons/spiking/izhikevichCell.py b/ngclearn/components/neurons/spiking/izhikevichCell.py
index 0027f314..b94c3402 100755
--- a/ngclearn/components/neurons/spiking/izhikevichCell.py
+++ b/ngclearn/components/neurons/spiking/izhikevichCell.py
@@ -1,16 +1,11 @@
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit, nn
-from functools import partial
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
+from ngcsimlib import deprecate_args
from ngcsimlib.logger import info, warn
-from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
- step_euler, step_rk2
-
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn.utils.diffeq.ode_utils import get_integrator_code, step_euler, step_rk2
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
@jit
def _dfv_internal(j, v, w, b, tau_m): ## raw voltage dynamics
@@ -119,17 +114,16 @@ class IzhikevichCell(JaxComponent): ## Izhikevich neuronal cell
at an increase in computational cost (and simulation time)
"""
- # Define Functions
def __init__(self, name, n_units, tau_m=1., resist_m=1., v_thr=30., v_reset=-65.,
tau_w=50., w_reset=8., coupling_factor=0.2, v0=-65., w0=-14.,
integration_type="euler", **kwargs):
super().__init__(name, **kwargs)
## Cell properties
- self.R_m = resist_m
+ self.resist_m = resist_m ## resistance R_m
self.tau_m = tau_m
self.tau_w = tau_w
- self.coupling = coupling_factor
+ self.coupling_factor = coupling_factor
self.v_reset = v_reset
self.w_reset = w_reset
@@ -153,45 +147,47 @@ def __init__(self, name, n_units, tau_m=1., resist_m=1., v_thr=30., v_reset=-65.
self.s = Compartment(restVals)
self.tols = Compartment(restVals) ## time-of-last-spike
- @transition(output_compartments=["j", "v", "w", "s", "tols"])
- @staticmethod
- def advance_state(t, dt, tau_m, tau_w, v_thr, coupling, v_reset, w_reset, R_m,
- intgFlag, j, v, w, s, tols):
+ @compilable
+ def advance_state(self, t, dt):
## note: a = 0.1 --> fast spikes, a = 0.02 --> regular spikes
- a = 1. / tau_w ## we map time constant to variable "a" (a = 1/tau_w)
- _j = j * R_m
+ a = 1. / self.tau_w ## we map time constant to variable "a" (a = 1/tau_w)
+ _j = self.j.get() * self.resist_m
# _j = jnp.maximum(-30.0, _j) ## lower-bound/clip input current
## check for spikes
- s = (v > v_thr) * 1.
+ s = (self.v.get() > self.v_thr) * 1.
## for non-spikes, evolve according to dynamics
- if intgFlag == 1:
- v_params = (_j, w, coupling, tau_m)
- _, _v = step_rk2(0., v, _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt)
- w_params = (_j, v, coupling, tau_w)
- _, _w = step_rk2(0., w, _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt)
+ if self.intgFlag == 1:
+ v_params = (_j, self.w.get(), self.coupling_factor, self.tau_m)
+ _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params) # _v = step_rk2(v, v_params, _dfv, dt)
+ w_params = (_j, self.v.get(), self.coupling_factor, self.tau_w)
+ _, _w = step_rk2(0., self.w.get(), _dfw, dt, w_params) # _w = step_rk2(w, w_params, _dfw, dt)
else: # integType == 0 (default -- Euler)
- v_params = (_j, w, coupling, tau_m)
- _, _v = step_euler(0., v, _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt)
- w_params = (_j, v, coupling, tau_w)
- _, _w = step_euler(0., w, _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt)
+ v_params = (_j, self.w.get(), self.coupling_factor, self.tau_m)
+ _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params) # _v = step_euler(v, v_params, _dfv, dt)
+ w_params = (_j, self.v.get(), self.coupling_factor, self.tau_w)
+ _, _w = step_euler(0., self.w.get(), _dfw, dt, w_params) # _w = step_euler(w, w_params, _dfw, dt)
## for spikes, snap to particular states
- _v, _w = _post_process(s, _v, _w, v, w, v_reset, w_reset)
+ _v, _w = _post_process(s, _v, _w, self.v.get(), self.w.get(), self.v_reset, self.w_reset)
v = _v
w = _w
- tols = (1. - s) * tols + (s * t) ## update tols
- return j, v, w, s, tols
+ ## update time-of-last spike variable(s)
+ self.tols.set((1. - s) * self.tols.get() + (s * t))
+
+ # self.j.set(j) ## j is not getting modified in these dynamics
+ self.v.set(v)
+ self.w.set(w)
+ self.s.set(s)
- @transition(output_compartments=["j", "v", "w", "s", "tols"])
- @staticmethod
- def reset(batch_size, n_units, v0, w0):
- restVals = jnp.zeros((batch_size, n_units))
- j = restVals # None
- v = restVals + v0
- w = restVals + w0
- s = restVals #+ 0
- tols = restVals #+ 0
- return j, v, w, s, tols
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size, self.n_units))
+ if not self.j.targeted:
+ self.j.set(restVals)
+ self.v.set(restVals + self.v0)
+ self.w.set(restVals + self.w0)
+ self.s.set(restVals)
+ self.tols.set(restVals)
@classmethod
def help(cls): ## component help function
@@ -219,8 +215,7 @@ def help(cls): ## component help function
"v_rest": "Resting membrane potential value",
"v_reset": "Reset membrane potential value",
"w_reset": "Reset recover variable value",
- "coupling_factor": "Degree to which recovery variable is sensitive to "
- "subthreshold voltage fluctuations",
+ "coupling_factor": "Degree to which recovery variable is sensitive to subthreshold voltage fluctuations",
"v0": "Initial condition for membrane potential/voltage",
"w0": "Initial condition for recovery variable",
"integration_type": "Type of numerical integration to use for the cell dynamics"
@@ -232,20 +227,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
diff --git a/ngclearn/components/neurons/spiking/quadLIFCell.py b/ngclearn/components/neurons/spiking/quadLIFCell.py
index ec7bbd32..6d7c95b6 100755
--- a/ngclearn/components/neurons/spiking/quadLIFCell.py
+++ b/ngclearn/components/neurons/spiking/quadLIFCell.py
@@ -1,18 +1,15 @@
from ngclearn.components.jaxComponent import JaxComponent
-from jax import numpy as jnp, random, jit, nn
-from functools import partial
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
+from jax import numpy as jnp, random, jit, nn, Array
+from ngcsimlib import deprecate_args
from ngcsimlib.logger import info, warn
from ngclearn.utils.diffeq.ode_utils import get_integrator_code, \
step_euler, step_rk2
-from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
- triangular_estimator,
- straight_through_estimator)
+# from ngclearn.utils.surrogate_fx import (secant_lif_estimator, arctan_estimator,
+# triangular_estimator,
+# straight_through_estimator)
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.components.neurons.spiking.LIFCell import LIFCell
@@ -30,7 +27,7 @@ def _dfv(t, v, params): ## voltage dynamics wrapper
return dv_dt
#@partial(jit, static_argnums=[3, 4])
-def _update_theta(dt, v_theta, s, tau_theta, theta_plus=0.05):
+def _update_theta(dt, v_theta, s, tau_theta, theta_plus: float | Array = 0.05):
### Runs homeostatic threshold update dynamics one step (via Euler integration).
#theta_decay = 0.9999999 #0.999999762 #jnp.exp(-dt/1e7)
#theta_plus = 0.05
@@ -117,132 +114,88 @@ class QuadLIFCell(LIFCell): ## quadratic integrate-and-fire cell
(straight-through estimator), "triangular" (triangular estimator),
"arctan" (arc-tangent estimator), and "secant_lif" (the
LIF-specialized secant estimator)
+
+ v_min: minimum voltage to clamp dynamics to (Default: None)
""" ## batch_size arg?
- @deprecate_args(thr_jitter=None, critical_v="critical_V")
+ @deprecate_args(thr_jitter=None, critical_V="critical_v")
def __init__(
self, name, n_units, tau_m, resist_m=1., thr=-52., v_rest=-65., v_reset=-60., v_scale=-41.6, critical_v=1.,
tau_theta=1e7, theta_plus=0.05, refract_time=5., one_spike=False, integration_type="euler",
- surrogate_type="straight_through", lower_clamp_voltage=True, **kwargs
+ surrogate_type="straight_through", v_min=None, **kwargs
):
super().__init__(
name, n_units, tau_m, resist_m, thr, v_rest, v_reset, 1., tau_theta, theta_plus, refract_time,
- one_spike, integration_type, surrogate_type, lower_clamp_voltage, **kwargs
+ one_spike, integration_type, surrogate_type, v_min=v_min, **kwargs
)
+
## only two distinct additional constants distinguish the Quad-LIF cell
self.v_c = v_scale
self.a0 = critical_v
- @transition(output_compartments=["v", "s", "s_raw", "rfr", "thr_theta", "tols", "key", "surrogate"])
- @staticmethod
- def advance_state(
- t, dt, tau_m, resist_m, v_rest, v_reset, v_c, a0, refract_T, thr, tau_theta, theta_plus,
- one_spike, lower_clamp_voltage, intgFlag, d_spike_fx, key, j, v, rfr, thr_theta, tols
- ):
- skey = None ## this is an empty dkey if single_spike mode turned off
- if one_spike:
- key, skey = random.split(key, 2)
- ## run one integration step for neuronal dynamics
- j = j * resist_m
- ############################################################################
- ### Runs leaky integrator (leaky integrate-and-fire; LIF) neuronal dynamics.
- _v_thr = thr_theta + thr #v_theta + v_thr ## calc present voltage threshold
- #mask = (rfr >= refract_T).astype(jnp.float32) # get refractory mask
- ## update voltage / membrane potential
- v_params = (j, rfr, tau_m, refract_T, v_rest, v_c, a0)
- if intgFlag == 1:
- _, _v = step_rk2(0., v, _dfv, dt, v_params)
- else: #_v = v + (v_rest - v) * (dt/tau_m) + (j * mask)
- _, _v = step_euler(0., v, _dfv, dt, v_params)
- ## obtain action potentials/spikes
+ @compilable
+ def advance_state(self, t, dt):
+ j = self.j.get() * self.resist_m
+
+ _v_thr = self.thr_theta.get() + self.thr ## calc present voltage threshold
+
+ v_params = (j, self.rfr.get(), self.tau_m, self.refract_T, self.v_rest, self.v_c, self.a0)
+
+ if self.intgFlag == 1:
+ _, _v = step_rk2(0., self.v.get(), _dfv, dt, v_params)
+ else:
+ _, _v = step_euler(0., self.v.get(), _dfv, dt, v_params)
+
s = (_v > _v_thr) * 1.
- ## update refractory variables
- _rfr = (rfr + dt) * (1. - s)
- ## perform hyper-polarization of neuronal cells
- _v = _v * (1. - s) + s * v_reset
-
- raw_s = s + 0 ## preserve un-altered spikes
- ############################################################################
- ## this is a spike post-processing step
- if skey is not None:
- m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able
+ _rfr = (self.rfr.get() + dt) * (1. - s)
+ _v = _v * (1. - s) + s * self.v_reset
+
+ raw_s = s
+
+ #surrogate = d_spike_fx(v, _v_thr) # d_spike_fx(v, thr + thr_theta)
+
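+ ## spike post-processing: with `one_spike` active (and `max_one_spike` off), a single winner is drawn
+ ## uniformly at random from the units that spiked; raw_s keeps the unaltered spikes for the threshold update
+ ## (assumption: `max_one_spike` is a flag inherited from the parent LIFCell)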
+ if self.one_spike and not self.max_one_spike:
+ key, skey = random.split(self.key.get(), 2)
+
+ m_switch = (jnp.sum(s) > 0.).astype(jnp.float32) ## TODO: not batch-able
rS = s * random.uniform(skey, s.shape)
rS = nn.one_hot(jnp.argmax(rS, axis=1), num_classes=s.shape[1],
dtype=jnp.float32)
s = s * (1. - m_switch) + rS * m_switch
- ############################################################################
- raw_spikes = raw_s
- v = _v
- rfr = _rfr
+ self.key.set(key)
- surrogate = d_spike_fx(v, _v_thr) #d_spike_fx(v, thr + thr_theta)
- if tau_theta > 0.:
+ if self.max_one_spike:
+ rS = nn.one_hot(jnp.argmax(self.v.get(), axis=1), num_classes=s.shape[1],
+ dtype=jnp.float32) ## get max-volt spike
+ s = s * rS ## mask out non-max volt spikes
+
+ if self.tau_theta > 0.:
## run one integration step for threshold dynamics
- thr_theta = _update_theta(dt, thr_theta, raw_spikes, tau_theta, theta_plus)
+ thr_theta = _update_theta(dt, self.thr_theta.get(), raw_s, self.tau_theta, self.theta_plus)
+ self.thr_theta.set(thr_theta)
+
## update tols
- tols = (1. - s) * tols + (s * t)
- if lower_clamp_voltage: ## ensure voltage never < v_rest
- v = jnp.maximum(v, v_rest)
- return v, s, raw_spikes, rfr, thr_theta, tols, key, surrogate
-
- @transition(output_compartments=["j", "v", "s", "s_raw", "rfr", "tols", "surrogate"])
- @staticmethod
- def reset(batch_size, n_units, v_rest, refract_T):
- restVals = jnp.zeros((batch_size, n_units))
- j = restVals #+ 0
- v = restVals + v_rest
- s = restVals #+ 0
- s_raw = restVals
- rfr = restVals + refract_T
- #thr_theta = restVals ## do not reset thr_theta
- tols = restVals #+ 0
- surrogate = restVals + 1.
- return j, v, s, s_raw, rfr, tols, surrogate
-
- def save(self, directory, **kwargs):
- ## do a protected save of constants, depending on whether they are floats or arrays
- tau_m = (self.tau_m if isinstance(self.tau_m, float)
- else jnp.asarray([[self.tau_m * 1.]]))
- thr = (self.thr if isinstance(self.thr, float)
- else jnp.asarray([[self.thr * 1.]]))
- v_rest = (self.v_rest if isinstance(self.v_rest, float)
- else jnp.asarray([[self.v_rest * 1.]]))
- v_reset = (self.v_reset if isinstance(self.v_reset, float)
- else jnp.asarray([[self.v_reset * 1.]]))
- v_decay = (self.v_decay if isinstance(self.v_decay, float)
- else jnp.asarray([[self.v_decay * 1.]]))
- resist_m = (self.resist_m if isinstance(self.resist_m, float)
- else jnp.asarray([[self.resist_m * 1.]]))
- tau_theta = (self.tau_theta if isinstance(self.tau_theta, float)
- else jnp.asarray([[self.tau_theta * 1.]]))
- theta_plus = (self.theta_plus if isinstance(self.theta_plus, float)
- else jnp.asarray([[self.theta_plus * 1.]]))
-
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name,
- threshold_theta=self.thr_theta.value,
- tau_m=tau_m, thr=thr, v_rest=v_rest,
- v_reset=v_reset, v_decay=v_decay,
- resist_m=resist_m, tau_theta=tau_theta,
- theta_plus=theta_plus,
- key=self.key.value)
-
- def load(self, directory, seeded=False, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.thr_theta.set(data['threshold_theta'])
- ## constants loaded in
- self.tau_m = data['tau_m']
- self.thr = data['thr']
- self.v_rest = data['v_rest']
- self.v_reset = data['v_reset']
- self.v_decay = data['v_decay']
- self.resist_m = data['resist_m']
- self.tau_theta = data['tau_theta']
- self.theta_plus = data['theta_plus']
-
- if seeded:
- self.key.set(data['key'])
+ self.tols.set((1. - s) * self.tols.get() + (s * t))
+
+ if self.v_min is not None: ## ensure voltage never falls below v_min
+ _v = jnp.maximum(_v, self.v_min)
+
+ self.v.set(_v)
+ self.s.set(s)
+ self.s_raw.set(raw_s)
+ self.rfr.set(_rfr)
+
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size, self.n_units))
+ if not self.j.targeted:
+ self.j.set(restVals)
+ self.v.set(restVals + self.v_rest)
+ self.s.set(restVals)
+ self.s_raw.set(restVals)
+ self.rfr.set(restVals + self.refract_T)
+ self.tols.set(restVals)
+ #self.surrogate.set(restVals)
@classmethod
def help(cls): ## component help function
@@ -272,17 +225,13 @@ def help(cls): ## component help function
"v_reset": "Reset membrane potential value",
"v_decay": "Voltage leak/decay factor",
"tau_theta": "Threshold/homoestatic increment time constant",
- "theta_plus": "Amount to increment threshold by upon occurrence "
- "of spike",
+ "theta_plus": "Amount to increment threshold by upon occurrence of a spike",
"refract_time": "Length of relative refractory period (ms)",
- "one_spike": "Should only one spike be sampled/allowed to emit at "
- "any given time step?",
- "integration_type": "Type of numerical integration to use for the "
- "cell dynamics",
+ "one_spike": "Should only one spike be sampled/allowed to emit at any given time step?",
+ "integration_type": "Type of numerical integration to use for the cell dynamics",
"surrgoate_type": "Type of surrogate function to use approximate "
"derivative of spike w.r.t. voltage/current",
- "lower_bound_clamp": "Should voltage be lower bounded to be never "
- "be below `v_rest`"
+ "v_min": "Minimum voltage allowed before voltage variables are min-clipped/clamped"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
@@ -290,20 +239,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
diff --git a/ngclearn/components/neurons/spiking/sLIFCell.py b/ngclearn/components/neurons/spiking/sLIFCell.py
index 76736aec..6b0c6fd8 100644
--- a/ngclearn/components/neurons/spiking/sLIFCell.py
+++ b/ngclearn/components/neurons/spiking/sLIFCell.py
@@ -1,15 +1,13 @@
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit
from functools import partial
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
-from ngcsimlib.logger import info, warn
from ngclearn.utils.diffeq.ode_utils import step_euler
from ngclearn.utils.surrogate_fx import secant_lif_estimator
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
@jit
def _dfv_internal(j, v, rfr, tau_m, refract_T): ## raw voltage dynamics
@@ -101,7 +99,9 @@ class SLIFCell(JaxComponent): ## leaky integrate-and-fire cell
refract_time: relative refractory period time (ms; Default: 1 ms)
- rho_b: threshold sparsity factor (Default: 0)
+ rho_b: threshold sparsity factor (Default: 0); note that setting rho_b > 0 will
+ force the adaptive threshold to follow dynamics that ignore `thr_gain` and
+ `thr_leak`
sticky_spikes: if True, spike variables will be pinned to action potential
value (i.e, 1) throughout duration of the refractory period; this recovers
@@ -112,7 +112,6 @@ class SLIFCell(JaxComponent): ## leaky integrate-and-fire cell
batch_size: batch size dimension of this cell (Default: 1)
"""
- # Define Functions
def __init__(
self, name, n_units, tau_m, resist_m, thr, resist_inh=0., thr_persist=False, thr_gain=0.0, thr_leak=0.0,
rho_b=0., refract_time=0., sticky_spikes=False, thr_jitter=0.05, batch_size=1, **kwargs
@@ -132,7 +131,7 @@ def __init__(
## create simple recurrent inhibitory pressure
self.inh_R = resist_inh ## lateral inhibitory magnitude
- key, subkey = random.split(self.key.value)
+ key, subkey = random.split(self.key.get())
self.inh_weights = random.uniform(subkey, (n_units, n_units), minval=0.025, maxval=1.)
self.inh_weights = self.inh_weights * (1. - jnp.eye(n_units))
@@ -162,12 +161,8 @@ def __init__(
self.rfr = Compartment(restVals + self.refract_T) ## refractory variable(s)
self.surrogate = Compartment(restVals + 1.) ## surrogate signal
- @transition(output_compartments=["j", "s", "tols", "v", "thr", "rfr", "surrogate"])
- @staticmethod
- def advance_state(
- t, dt, inh_weights, R_m, inh_R, d_spike_fx, tau_m, spike_fx, refract_T, thrGain,
- thrLeak, rho_b, sticky_spikes, v_min, j, s, v, thr, rfr, tols
- ):
+ @compilable
+ def advance_state(self, t, dt):
#####################################################################################
#The following 3 lines of code modify electrical current j via application of a
#scalar membrane resistance value and an approximate form of lateral inhibition.
@@ -180,20 +175,31 @@ def advance_state(
#| R_m: membrane resistance (to multiply/scale j by),
#| inh_R: inhibitory resistance to scale lateral inhibitory current by; if inh_R = 0,
# NO lateral inhibitory pressure will be applied
- j = j * R_m
- if inh_R > 0.: ## if inh_R > 0, then lateral inhibition is applied
- j = j - (jnp.matmul(spikes, inh_weights) * inh_R)
+
+ # First, get the relevant compartment values
+ j = self.j.get()
+ # s = self.s.get() # NOTE: This is unused
+ tols = self.tols.get()
+ v = self.v.get()
+ thr = self.thr.get()
+ rfr = self.rfr.get()
+ surrogate = self.surrogate.get()
+ ## modify electrical current j via membrane resistance and lateral inhibition
+
+ j = j * self.R_m
+ if self.inh_R > 0.: ## if inh_R > 0, then lateral inhibition is applied
+ j = j - (jnp.matmul(self.s.get(), self.inh_weights) * self.inh_R)
#####################################################################################
- surrogate = d_spike_fx(j, c1=0.82, c2=0.08) ## calc surrogate deriv of spikes
+ surrogate = self.d_spike_fx(j, c1=0.82, c2=0.08) ## calc surrogate deriv of spikes
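+ ## (c1, c2 are shape constants of the surrogate derivative, presumably the secant-LIF estimator imported above; values kept from the prior implementation)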
## transition to: voltage(t+dt), spikes, threshold(t+dt), refractory_variables(t+dt)
- v_params = (j, rfr, tau_m, refract_T)
+ v_params = (j, rfr, self.tau_m, self.refract_T)
_, _v = step_euler(0., v, _dfv, dt, v_params)
- spikes = spike_fx(_v, thr)
+ spikes = self.spike_fx(_v, thr)
#_v = _hyperpolarize(_v, spikes)
_v = (1. - spikes) * _v ## hyper-polarize cells
- new_thr = _update_threshold(dt, thr, spikes, thrGain, thrLeak, rho_b)
- _rfr, spikes = _update_refract_and_spikes(dt, rfr, spikes, refract_T, sticky_spikes)
+ new_thr = _update_threshold(dt, thr, spikes, self.thrGain, self.thrLeak, self.rho_b)
+ _rfr, spikes = _update_refract_and_spikes(dt, rfr, spikes, self.refract_T, self.sticky_spikes)
v = _v
s = spikes
thr = new_thr
@@ -201,34 +207,48 @@ def advance_state(
## update tols
tols = (1. - s) * tols + (s * t)
- return j, s, tols, v, thr, rfr, surrogate
-
- @transition(output_compartments=["j", "s", "tols", "v", "thr", "rfr", "surrogate"])
- @staticmethod
- def reset(refract_T, thr_persist, threshold0, batch_size, n_units, thr):
- restVals = jnp.zeros((batch_size, n_units))
+ # return j, s, tols, v, thr, rfr, surrogate
+ self.j.set(j)
+ self.s.set(s)
+ self.tols.set(tols)
+ self.v.set(v)
+ self.thr.set(thr)
+ self.rfr.set(rfr)
+ self.surrogate.set(surrogate)
+
+ @compilable
+ def reset(self):
+ # refract_T, thr_persist, threshold0, batch_size, n_units, thr
+ restVals = jnp.zeros((self.batch_size, self.n_units))
voltage = restVals
- refract = restVals + refract_T
+ refract = restVals + self.refract_T
current = restVals
surrogate = restVals + 1.
timeOfLastSpike = restVals
spikes = restVals
- if not thr_persist: ## if thresh non-persistent, reset to base value
- thr = threshold0 + 0
- return current, spikes, timeOfLastSpike, voltage, thr, refract, surrogate
+ if not self.thr_persist: ## if thresh non-persistent, reset to base value
+ thr = self.threshold0 + 0
+ self.thr.set(thr)
+ # return current, spikes, timeOfLastSpike, voltage, thr, refract, surrogate
+ self.j.set(current)
+ self.s.set(spikes)
+ self.tols.set(timeOfLastSpike)
+ self.v.set(voltage)
+ self.rfr.set(refract)
+ self.surrogate.set(surrogate)
def save(self, directory, **kwargs):
file_name = directory + "/" + self.name + ".npz"
if self.thr_persist == False:
jnp.savez(file_name, threshold=self.threshold0) # save threshold0
else:
- jnp.savez(file_name, threshold=self.thr.value) # save the actual threshold param/compartment
+ jnp.savez(file_name, threshold=self.thr.get()) # save the actual threshold param/compartment
def load(self, directory, **kwargs):
file_name = directory + "/" + self.name + ".npz"
data = jnp.load(file_name)
self.thr.set(data['threshold'])
- self.threshold0 = self.thr.value + 0
+ self.threshold0 = self.thr.get() + 0
@classmethod
def help(cls): ## component help function
@@ -269,20 +289,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
diff --git a/ngclearn/components/other/expKernel.py b/ngclearn/components/other/expKernel.py
index a074c30e..7c99049f 100644
--- a/ngclearn/components/other/expKernel.py
+++ b/ngclearn/components/other/expKernel.py
@@ -2,12 +2,8 @@
from jax import numpy as jnp, random, jit
from functools import partial
from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
-from ngcsimlib.logger import info, warn
-
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
@partial(jit, static_argnums=[5,6])
def _apply_kernel(tf_curr, s, t, tau_w, win_len, krn_start, krn_end):
@@ -49,7 +45,6 @@ class ExpKernel(JaxComponent): ## exponential kernel
batch_size: batch size dimension of this cell (Default: 1)
"""
- # Define Functions
def __init__(self, name, n_units, dt, tau_w=500., nu=4., batch_size=1, **kwargs):
super().__init__(name, **kwargs)
@@ -67,21 +62,31 @@ def __init__(self, name, n_units, dt, tau_w=500., nu=4., batch_size=1, **kwargs)
## window of spike times
self.tf = Compartment(jnp.zeros((self.win_len, self.batch_size, self.n_units)))
- @transition(output_compartments=["epsp", "tf"])
- @staticmethod
- def advance_state(t, tau_w, win_len, inputs, tf):
+ @compilable
+ def advance_state(self, t):
+ # Get the variables
+ inputs = self.inputs.get()
+ tf = self.tf.get()
+
s = inputs
## update spike time window and corresponding window volume
- tf, epsp = _apply_kernel(tf, s, t, tau_w, win_len, krn_start=0,
- krn_end=win_len-1) #0:win_len-1)
- return epsp, tf
-
- @transition(output_compartments=["inputs", "epsp", "tf"])
- @staticmethod
- def reset(batch_size, n_units, win_len):
- restVals = jnp.zeros((batch_size, n_units))
- restTensor = jnp.zeros([win_len, batch_size, n_units], jnp.float32) # tf
- return restVals, restVals, restTensor # inputs, epsp, tf
+ tf, epsp = _apply_kernel(
+ tf, s, t, self.tau_w, self.win_len, krn_start=0, krn_end=self.win_len-1
+ ) #0:win_len-1)
+
+ # Update compartments
+ self.epsp.set(epsp)
+ self.tf.set(tf)
+
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size, self.n_units)) ## inputs, epsp
+ restTensor = jnp.zeros([self.win_len, self.batch_size, self.n_units], jnp.float32) ## tf
+ # NOTE: some compartments may lack the `targeted` field here; guard the attribute
+ # check before clearing inputs so externally wired inputs are left untouched
+ if hasattr(self.inputs, "targeted") and not self.inputs.targeted:
+     self.inputs.set(restVals)
+ self.epsp.set(restVals)
+ self.tf.set(restTensor)
@classmethod
def help(cls): ## component help function
diff --git a/ngclearn/components/other/varTrace.py b/ngclearn/components/other/varTrace.py
index 94510e75..d4de9f47 100644
--- a/ngclearn/components/other/varTrace.py
+++ b/ngclearn/components/other/varTrace.py
@@ -1,13 +1,10 @@
from ngclearn.components.jaxComponent import JaxComponent
from jax import numpy as jnp, random, jit
from functools import partial
-from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
-from ngcsimlib.logger import info, warn
-
-from ngcsimlib.compilers.process import transition
-#from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
@partial(jit, static_argnums=[4])
def _run_varfilter(dt, x, x_tr, decayFactor, gamma_tr, a_delta=0.):
@@ -77,17 +74,17 @@ class VarTrace(JaxComponent): ## low-pass filter
batch_size: batch size dimension of this cell (Default: 1)
"""
- # Define Functions
def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay_type="exp",
- n_nearest_spikes=0, batch_size=1, **kwargs):
- super().__init__(name, **kwargs)
+ n_nearest_spikes=0, batch_size=1, key=None):
+ super().__init__(name, key)
## Trace control coefficients
+ self.decay_type = decay_type ## lin --> linear decay; exp --> exponential decay
+
self.tau_tr = tau_tr ## trace time constant
self.a_delta = a_delta ## trace increment (if spike occurred)
self.P_scale = P_scale ## trace scale if non-additive trace to be used
self.gamma_tr = gamma_tr
- self.decay_type = decay_type ## lin --> linear decay; exp --> exponential decay
self.n_nearest_spikes = n_nearest_spikes
## Layer Size Setup
@@ -99,32 +96,37 @@ def __init__(self, name, n_units, tau_tr, a_delta, P_scale=1., gamma_tr=1, decay
self.outputs = Compartment(restVals) # output compartment
self.trace = Compartment(restVals)
- @transition(output_compartments=["outputs", "trace"])
- @staticmethod
- def advance_state(
- dt, decay_type, tau_tr, a_delta, P_scale, gamma_tr, inputs, trace, n_nearest_spikes
- ):
- decayFactor = 0.
- if "exp" in decay_type:
- decayFactor = jnp.exp(-dt/tau_tr)
- elif "lin" in decay_type:
- decayFactor = (1. - dt/tau_tr)
- _x_tr = gamma_tr * trace * decayFactor
- if n_nearest_spikes > 0: ## run k-nearest neighbor trace
- _x_tr = _x_tr + inputs * (a_delta - (trace/n_nearest_spikes))
+ @compilable
+ def advance_state(self, dt):
+ if "exp" in self.decay_type:
+ decayFactor = jnp.exp(-dt/self.tau_tr)
+ elif "lin" in self.decay_type:
+ decayFactor = (1. - dt/self.tau_tr)
+ else:
+ decayFactor = 0.
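+ ## any other decay_type leaves decayFactor at 0., so the previous trace is discarded and rebuilt from the current input alone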
+
+ _x_tr = self.gamma_tr * self.trace.get() * decayFactor
+ if self.n_nearest_spikes > 0:
+ _x_tr = _x_tr + self.inputs.get() * (self.a_delta - (self.trace.get() / self.n_nearest_spikes))
else:
- if a_delta > 0.: ## run full convolution trace
- _x_tr = _x_tr + inputs * a_delta
- else: ## run simple max-clamped trace
- _x_tr = _x_tr * (1. - inputs) + inputs * P_scale
- trace = _x_tr
- return trace, trace
-
- @transition(output_compartments=["inputs", "outputs", "trace"])
- @staticmethod
- def reset(batch_size, n_units):
- restVals = jnp.zeros((batch_size, n_units))
- return restVals, restVals, restVals
+ if self.a_delta > 0.:
+ _x_tr = _x_tr + self.inputs.get() * self.a_delta
+ else:
+ _x_tr = _x_tr * (1. - self.inputs.get()) + self.inputs.get() * self.P_scale
+
+ self.trace.set(_x_tr)
+ self.outputs.set(_x_tr)
+
+ @compilable
+ def reset(self):
+ restVals = jnp.zeros((self.batch_size, self.n_units))
+ # NOTE: some compartments may lack the `targeted` field here; guard the attribute
+ # check before clearing inputs so externally wired inputs are left untouched
+ if hasattr(self.inputs, "targeted") and not self.inputs.targeted:
+     self.inputs.set(restVals)
+ self.outputs.set(restVals)
+ self.trace.set(restVals)
@classmethod
def help(cls): ## component help function
@@ -159,19 +161,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
if __name__ == '__main__':
from ngcsimlib.context import Context
diff --git a/ngclearn/components/synapses/STPDenseSynapse.py b/ngclearn/components/synapses/STPDenseSynapse.py
index 4fc1a81b..31cf7c67 100755
--- a/ngclearn/components/synapses/STPDenseSynapse.py
+++ b/ngclearn/components/synapses/STPDenseSynapse.py
@@ -1,12 +1,10 @@
from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
-from ngclearn.utils.weight_distribution import initialize_params
from ngcsimlib.logger import info
+
+from ngclearn.utils.distribution_generator import DistributionGenerator
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.components.synapses import DenseSynapse
-from ngclearn.utils import tensorstats
class STPDenseSynapse(DenseSynapse): ## short-term plastic synaptic cable
"""
@@ -56,10 +54,10 @@ class STPDenseSynapse(DenseSynapse): ## short-term plastic synaptic cable
resources_int: initialization kernel for synaptic resources matrix
"""
- # Define Functions
- def __init__(self, name, shape, weight_init=None, bias_init=None,
- resist_scale=1., p_conn=1., tau_f=750., tau_d=50.,
- resources_init=None, **kwargs):
+ def __init__(
+ self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., tau_f=750., tau_d=50.,
+ resources_init=None, **kwargs
+ ):
super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
## STP meta-parameters
self.resources_init = resources_init
@@ -67,69 +65,72 @@ def __init__(self, name, shape, weight_init=None, bias_init=None,
self.tau_d = tau_d
## Set up short-term plasticity / dynamic synapse compartment values
- tmp_key, *subkeys = random.split(self.key.value, 4)
+ tmp_key, *subkeys = random.split(self.key.get(), 4)
preVals = jnp.zeros((self.batch_size, shape[0]))
self.u = Compartment(preVals) ## release prob variables
self.x = Compartment(preVals + 1) ## resource availability variables
- self.Wdyn = Compartment(self.weights.value * 0) ## dynamic synapse values
+ self.Wdyn = Compartment(self.weights.get() * 0) ## dynamic synapse values
if self.resources_init is None:
info(self.name, "is using default resources value initializer!")
- self.resources_init = {"dist": "uniform", "amin": 0.125, "amax": 0.175} # 0.15
+ #self.resources_init = {"dist": "uniform", "amin": 0.125, "amax": 0.175} # 0.15
+ self.resources_init = DistributionGenerator.uniform(low=0.125, high=0.175)
self.resources = Compartment(
- initialize_params(subkeys[2], self.resources_init, shape)
+ self.resources_init(shape, subkeys[2]) #initialize_params(subkeys[2], self.resources_init, shape)
) ## matrix U - synaptic resources matrix
- @transition(output_compartments=["outputs", "u", "x", "Wdyn"])
- @staticmethod
- def advance_state(
- tau_f, tau_d, Rscale, inputs, weights, biases, resources, u, x, Wdyn
- ):
- s = inputs
+ @compilable
+ def advance_state(self, t, dt):
+ s = self.inputs.get()
## compute short-term facilitation
#u = u - u * (1./tau_f) + (resources * (1. - u)) * s
- if tau_f > 0.: ## compute short-term facilitation
- u = u - u * (1./tau_f) + (resources * (1. - u)) * s
+ if self.tau_f > 0.: ## compute short-term facilitation
+ u = self.u.get() - self.u.get() * (1./self.tau_f) + (self.resources.get() * (1. - self.u.get())) * s
else:
- u = resources ## disabling STF yields fixed resource u variables
+ u = self.resources.get() ## disabling STF yields fixed resource u variables
## compute dynamic synaptic values/conductances
- Wdyn = (weights * u * x) * s + Wdyn * (1. - s) ## OR: -W/tau_w + W * u * x
- if tau_d > 0.:
- ## compute short-term depression
- x = x + (1. - x) * (1./tau_d) - u * x * s
- outputs = jnp.matmul(inputs, Wdyn * Rscale) + biases
- return outputs, u, x, Wdyn
-
- @transition(output_compartments=["inputs", "outputs", "u", "x", "Wdyn"])
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- inputs = preVals
- outputs = postVals
- u = preVals
- x = preVals + 1
- Wdyn = jnp.zeros(shape)
- return inputs, outputs, u, x, Wdyn
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- if self.bias_init != None:
- jnp.savez(file_name,
- weights=self.weights.value,
- biases=self.biases.value,
- resources=self.resources.value)
- else:
- jnp.savez(file_name,
- weights=self.weights.value,
- resources=self.resources.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.weights.set(data['weights'])
- self.resources.set(data['resources'])
- if "biases" in data.keys():
- self.biases.set(data['biases'])
+ Wdyn = (self.weights.get() * u * self.x.get()) * s + self.Wdyn.get() * (1. - s) ## OR: -W/tau_w + W * u * x
+ ## compute short-term depression
+ x = self.x.get()
+ if self.tau_d > 0.:
+ x = x + (1. - x) * (1./self.tau_d) - u * x * s
+ ## else, do nothing with x (keep it pointing to current x compartment)
+ outputs = jnp.matmul(self.inputs.get(), Wdyn * self.resist_scale) + self.biases.get()
+
+ self.outputs.set(outputs)
+ self.u.set(u)
+ self.x.set(x)
+ self.Wdyn.set(Wdyn)
+
+ @compilable
+ def reset(self):
+ preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
+ postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
+ if not self.inputs.targeted:
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.u.set(preVals)
+ self.x.set(preVals + 1)
+ self.Wdyn.set(jnp.zeros(self.shape.get()))
+
+ # def save(self, directory, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # if self.bias_init != None:
+ # jnp.savez(file_name,
+ # weights=self.weights.value,
+ # biases=self.biases.value,
+ # resources=self.resources.value)
+ # else:
+ # jnp.savez(file_name,
+ # weights=self.weights.value,
+ # resources=self.resources.value)
+ #
+ # def load(self, directory, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # data = jnp.load(file_name)
+ # self.weights.set(data['weights'])
+ # self.resources.set(data['resources'])
+ # if "biases" in data.keys():
+ # self.biases.set(data['biases'])
@classmethod
def help(cls): ## component help function
@@ -166,17 +167,3 @@ def help(cls): ## component help function
"dW/dt = W_full * u * x * inputs",
"hyperparameters": hyperparams}
return info
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
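As a rough aid to reading the rewritten `STPDenseSynapse.advance_state`, the facilitation/depression dynamics reduce to the per-synapse sketch below (scalar form only; shapes, the `resist_scale` factor, and bias terms are omitted, and all names are illustrative).

```python
def stp_step(s, u, x, w_dyn, w, resource, tau_f=750., tau_d=50.):
    ## s is a binary pre-synaptic spike (0. or 1.)
    if tau_f > 0.:   # short-term facilitation of the release probability u
        u = u - u * (1. / tau_f) + resource * (1. - u) * s
    else:            # facilitation disabled: u pinned to the fixed resource value
        u = resource
    ## the dynamic efficacy only changes on steps where a spike arrives
    w_dyn = (w * u * x) * s + w_dyn * (1. - s)
    if tau_d > 0.:   # short-term depression of the resource availability x
        x = x + (1. - x) * (1. / tau_d) - u * x * s
    return u, x, w_dyn

u, x, w_dyn = 0., 1., 0.
for s in [1., 0., 0., 1.]:                   # a short pre-synaptic spike train
    u, x, w_dyn = stp_step(s, u, x, w_dyn, w=1., resource=0.15)
```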
diff --git a/ngclearn/components/synapses/__init__.py b/ngclearn/components/synapses/__init__.py
index 2c21c231..95bf3f70 100644
--- a/ngclearn/components/synapses/__init__.py
+++ b/ngclearn/components/synapses/__init__.py
@@ -1,21 +1,19 @@
from .denseSynapse import DenseSynapse
from .staticSynapse import StaticSynapse
-
## short-term plasticity components
from .STPDenseSynapse import STPDenseSynapse
from .exponentialSynapse import ExponentialSynapse
-from .doubleExpSynapse import DoupleExpSynapse
+from .doubleExpSynapse import DoubleExpSynapse
from .alphaSynapse import AlphaSynapse
## dense synaptic components
-from .hebbian.hebbianSynapse import HebbianSynapse
+# from .hebbian.hebbianSynapse import HebbianSynapse
from .hebbian.traceSTDPSynapse import TraceSTDPSynapse
from .hebbian.expSTDPSynapse import ExpSTDPSynapse
from .hebbian.eventSTDPSynapse import EventSTDPSynapse
from .hebbian.BCMSynapse import BCMSynapse
-
## conv/deconv synaptic components
from .convolution.convSynapse import ConvSynapse
from .convolution.staticConvSynapse import StaticConvSynapse
@@ -26,10 +24,9 @@
from .convolution.hebbianDeconvSynapse import HebbianDeconvSynapse
from .convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse
-
## modulated synaptic components
from .modulated.MSTDPETSynapse import MSTDPETSynapse
-from .modulated.REINFORCESynapse import REINFORCESynapse
+# from .modulated.REINFORCESynapse import REINFORCESynapse
## patched synaptic components
from .patched.patchedSynapse import PatchedSynapse
diff --git a/ngclearn/components/synapses/alphaSynapse.py b/ngclearn/components/synapses/alphaSynapse.py
index cf5f9543..cbdbb8c8 100644
--- a/ngclearn/components/synapses/alphaSynapse.py
+++ b/ngclearn/components/synapses/alphaSynapse.py
@@ -1,12 +1,8 @@
from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.components.synapses import DenseSynapse
-from ngclearn.utils import tensorstats
class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable
"""
@@ -62,10 +58,9 @@ class AlphaSynapse(DenseSynapse): ## dynamic alpha synapse cable
"""
- # Define Functions
def __init__(
- self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1.,
- is_nonplastic=True, **kwargs
+ self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1.,
+ p_conn=1., is_nonplastic=True, **kwargs
):
super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
## dynamic synapse meta-parameters
@@ -82,55 +77,41 @@ def __init__(
self.g_syn = Compartment(postVals) ## conductance variable
self.h_syn = Compartment(postVals) ## intermediate conductance variable
if is_nonplastic:
- self.weights.set(self.weights.value * 0 + 1.)
+ self.weights.set(self.weights.get() * 0 + 1.)
- @transition(output_compartments=["outputs", "i_syn", "g_syn", "h_syn"])
- @staticmethod
- def advance_state(
- dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v
- ):
- s = inputs
+ @compilable
+ def advance_state(self, t, dt):
+ s = self.inputs.get()
## advance conductance variable(s)
- _out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron)
- dhsyn_dt = -h_syn/tau_decay + (_out * g_syn_bar) * (1./dt)
- h_syn = h_syn + dhsyn_dt * dt ## run Euler step to move intermediate conductance h
+        _out = jnp.matmul(s, self.weights.get()) ## sum all pre-synaptic spikes at t going into each post-neuron
+ dhsyn_dt = -self.h_syn.get()/self.tau_decay + (_out * self.g_syn_bar) * (1./dt)
+ h_syn = self.h_syn.get() + dhsyn_dt * dt ## run Euler step to move intermediate conductance h
- dgsyn_dt = -g_syn/tau_decay + h_syn * (1./dt) # or -g_syn/tau_decay + h_syn/tau_decay
- g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance g
+ dgsyn_dt = -self.g_syn.get()/self.tau_decay + h_syn * (1./dt) # or -g_syn/tau_decay + h_syn/tau_decay
+ g_syn = self.g_syn.get() + dgsyn_dt * dt ## run Euler step to move conductance g
## compute derive electrical current variable
- i_syn = -g_syn * Rscale
- if syn_rest is not None:
- i_syn = -(g_syn * Rscale) * (v - syn_rest)
- outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases
- return outputs, i_syn, g_syn, h_syn
-
- @transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "h_syn", "v"])
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- inputs = preVals
- outputs = postVals
- i_syn = postVals
- g_syn = postVals
- h_syn = postVals
- v = postVals
- return inputs, outputs, i_syn, g_syn, h_syn, v
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- if self.bias_init != None:
- jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
- else:
- jnp.savez(file_name, weights=self.weights.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.weights.set(data['weights'])
- if "biases" in data.keys():
- self.biases.set(data['biases'])
+ i_syn = -g_syn * self.resist_scale
+ if self.syn_rest is not None:
+ i_syn = -(g_syn * self.resist_scale) * (self.v.get() - self.syn_rest)
+ outputs = i_syn #jnp.matmul(inputs, Wdyn * self.resist_scale) + biases
+
+ self.outputs.set(outputs)
+ self.i_syn.set(i_syn)
+ self.g_syn.set(g_syn)
+ self.h_syn.set(h_syn)
+
+ @compilable
+ def reset(self):
+ preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
+ postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
+ if not self.inputs.targeted:
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.i_syn.set(postVals)
+ self.g_syn.set(postVals)
+ self.h_syn.set(postVals)
+ self.v.set(postVals)
@classmethod
def help(cls): ## component help function
@@ -170,17 +151,3 @@ def help(cls): ## component help function
"dgsyn_dt = -g_syn/tau_decay + h_syn",
"hyperparameters": hyperparams}
return info
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
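The rewritten `AlphaSynapse.advance_state` applies two chained Euler steps. Stripped of compartments, the update is roughly the sketch below, where `s_in` stands for the weight-summed spike input (`jnp.matmul(s, W)` in the component) and all names are illustrative.

```python
def alpha_synapse_step(s_in, h, g, v, dt, tau_decay, g_syn_bar,
                       resist_scale=1., syn_rest=None):
    ## spike input first drives the intermediate conductance h ...
    dh_dt = -h / tau_decay + (s_in * g_syn_bar) * (1. / dt)
    h = h + dh_dt * dt
    ## ... which then feeds the conductance g, yielding the alpha-shaped response
    dg_dt = -g / tau_decay + h * (1. / dt)
    g = g + dg_dt * dt
    if syn_rest is None:
        i_syn = -g * resist_scale                       # conductance-only current
    else:
        i_syn = -(g * resist_scale) * (v - syn_rest)    # scaled by the driving force
    return i_syn, h, g

i, h, g = alpha_synapse_step(s_in=1., h=0., g=0., v=-65., dt=0.1,
                             tau_decay=5., g_syn_bar=1., syn_rest=0.)
```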
diff --git a/ngclearn/components/synapses/convolution/__init__.py b/ngclearn/components/synapses/convolution/__init__.py
index ed305c38..0724f25f 100755
--- a/ngclearn/components/synapses/convolution/__init__.py
+++ b/ngclearn/components/synapses/convolution/__init__.py
@@ -6,3 +6,4 @@
from .hebbianDeconvSynapse import HebbianDeconvSynapse
from .traceSTDPConvSynapse import TraceSTDPConvSynapse
from .traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse
+
diff --git a/ngclearn/components/synapses/convolution/convSynapse.py b/ngclearn/components/synapses/convolution/convSynapse.py
index 12c5e674..5af6810f 100755
--- a/ngclearn/components/synapses/convolution/convSynapse.py
+++ b/ngclearn/components/synapses/convolution/convSynapse.py
@@ -1,15 +1,12 @@
from jax import random, numpy as jnp, jit
-from ngclearn.components.jaxComponent import JaxComponent
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
-from ngclearn.utils.weight_distribution import initialize_params
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngcsimlib.logger import info
-from ngclearn.utils import tensorstats
-import ngclearn.utils.weight_distribution as dist
+from ngclearn.utils.distribution_generator import DistributionGenerator
from ngclearn.components.synapses.convolution.ngcconv import conv2d
+from ngclearn.components.jaxComponent import JaxComponent
+
class ConvSynapse(JaxComponent): ## base-level convolutional cable
"""
A base convolutional synaptic cable.
@@ -47,7 +44,6 @@ class ConvSynapse(JaxComponent): ## base-level convolutional cable
batch_size: batch size dimension of this component
"""
- # Define Functions
def __init__(
self, name, shape, x_shape, filter_init=None, bias_init=None, stride=1, padding=None, resist_scale=1.,
batch_size=1, **kwargs
@@ -61,7 +57,7 @@ def __init__(
self.shape = shape ## shape of synaptic filter tensor
x_size, x_size = x_shape
self.x_size = x_size
- self.Rscale = resist_scale ## post-transformation scale factor
+ self.resist_scale = resist_scale ## post-transformation scale factor
self.padding = padding
self.stride = stride
@@ -69,7 +65,7 @@ def __init__(
k_size, k_size, n_in_chan, n_out_chan = shape
self.pad_args = None
if self.padding is not None and self.padding == "SAME":
- if (x_size % stride == 0):
+ if x_size % stride == 0:
pad_along_height = max(k_size - stride, 0)
else:
pad_along_height = max(k_size - (x_size % stride), 0)
@@ -83,8 +79,13 @@ def __init__(
self.pad_args = ((0, 0), (0, 0))
######################### set up compartments ##########################
- tmp_key, *subkeys = random.split(self.key.value, 4)
- weights = dist.initialize_params(subkeys[0], filter_init, shape) ## filter tensor
+ tmp_key, *subkeys = random.split(self.key.get(), 4)
+ #weights = dist.initialize_params(subkeys[0], filter_init, shape)
+ if self.filter_init is None:
+ info(self.name, "is using default weight initializer!")
+ self.filter_init = DistributionGenerator.uniform(0.025, 0.8)
+ weights = self.filter_init(shape, subkeys[0]) ## filter tensor
+
self.batch_size = batch_size # 1
## Compartment setup and shape computation
_x = jnp.zeros((self.batch_size, x_size, x_size, n_in_chan))
@@ -95,42 +96,42 @@ def __init__(
self.outputs = Compartment(jnp.zeros(self.out_shape))
self.weights = Compartment(weights)
if self.bias_init is None:
- info(self.name, "is using default bias value of zero (no bias "
- "kernel provided)!")
+ info(self.name, "is using default bias value of zero (no bias kernel provided)!")
self.biases = Compartment(
- dist.initialize_params(subkeys[2], bias_init, (1, shape[1])) if bias_init else 0.0
+ #dist.initialize_params(subkeys[2], bias_init, (1, shape[1])) if bias_init else 0.0
+ self.bias_init((1, shape[1]), subkeys[2]) if bias_init else 0.0
)
- @transition(output_compartments=["outputs"])
- @staticmethod
- def advance_state(Rscale, padding, stride, weights, biases, inputs):
- _x = inputs
- outputs = conv2d(_x, weights, stride_size=stride, padding=padding) * Rscale + biases
- return outputs
-
- @transition(output_compartments=["inputs", "outputs"])
- @staticmethod
- def reset(in_shape, out_shape):
- preVals = jnp.zeros(in_shape)
- postVals = jnp.zeros(out_shape)
- inputs = preVals
- outputs = postVals
- return inputs, outputs
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- if self.bias_init != None:
- jnp.savez(file_name, weights=self.weights.value,
- biases=self.biases.value)
- else:
- jnp.savez(file_name, weights=self.weights.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.weights.set(data['weights'])
- if "biases" in data.keys():
- self.biases.set(data['biases'])
+ @compilable
+ def advance_state(self): #Rscale, padding, stride, weights, biases, inputs):
+ _x = self.inputs.get()
+ ## FIXME: does resist_scale affect update rules?
+ outputs = conv2d(
+ _x, self.weights.get(), stride_size=self.stride, padding=self.padding
+ ) * self.resist_scale + self.biases.get()
+ self.outputs.set(outputs)
+
+ @compilable
+ def reset(self): #in_shape, out_shape):
+ preVals = jnp.zeros(self.in_shape)
+ postVals = jnp.zeros(self.out_shape)
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+
+ # def save(self, directory, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # if self.bias_init != None:
+ # jnp.savez(file_name, weights=self.weights.get(),
+ # biases=self.biases.get())
+ # else:
+ # jnp.savez(file_name, weights=self.weights.get())
+ #
+ # def load(self, directory, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # data = jnp.load(file_name)
+ # self.weights.set(data['weights'])
+ # if "biases" in data.keys():
+ # self.biases.set(data['biases'])
@classmethod
def help(cls): ## component help function
@@ -163,17 +164,3 @@ def help(cls): ## component help function
"dynamics": "outputs = [K @ inputs] * R + b",
"hyperparameters": hyperparams}
return info
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
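`ConvSynapse.advance_state` ultimately reduces to `outputs = conv2d(x, K) * resist_scale + b`. The snippet below shows how such an NHWC/HWIO convolution can be expressed directly with `jax.lax`; this is an independent illustration, not the actual implementation of `ngcconv.conv2d`, whose internals are not shown in this patch.

```python
import jax.numpy as jnp
from jax import lax

def conv_forward(x, K, stride=1, padding="SAME", resist_scale=1., b=0.):
    ## x: (batch, height, width, in_chan); K: (k, k, in_chan, out_chan)
    y = lax.conv_general_dilated(
        x, K,
        window_strides=(stride, stride),
        padding=padding,                              # "SAME" or "VALID"
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return y * resist_scale + b

x = jnp.zeros((1, 8, 8, 3))      # one 8x8 input with 3 channels
K = jnp.zeros((3, 3, 3, 16))     # 3x3 filters mapping 3 -> 16 channels
y = conv_forward(x, K)           # -> shape (1, 8, 8, 16) under "SAME" padding
```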
diff --git a/ngclearn/components/synapses/convolution/deconvSynapse.py b/ngclearn/components/synapses/convolution/deconvSynapse.py
index 13d78c6b..cf52d9d7 100755
--- a/ngclearn/components/synapses/convolution/deconvSynapse.py
+++ b/ngclearn/components/synapses/convolution/deconvSynapse.py
@@ -1,15 +1,13 @@
from jax import random, numpy as jnp, jit
-from ngclearn.components.jaxComponent import JaxComponent
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
-from ngclearn.utils.weight_distribution import initialize_params
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngcsimlib.logger import info
-from ngclearn.utils import tensorstats
-import ngclearn.utils.weight_distribution as dist
+from ngclearn.utils.distribution_generator import DistributionGenerator
from ngclearn.components.synapses.convolution.ngcconv import deconv2d
+from ngclearn.components.jaxComponent import JaxComponent
+
+
class DeconvSynapse(JaxComponent): ## base-level deconvolutional cable
"""
A base deconvolutional (transposed convolutional) synaptic cable.
@@ -47,7 +45,6 @@ class DeconvSynapse(JaxComponent): ## base-level deconvolutional cable
batch_size: batch size dimension of this component
"""
- # Define Functions
def __init__(
self, name, shape, x_shape, filter_init=None, bias_init=None, stride=1, padding=None, resist_scale=1.,
batch_size=1, **kwargs
@@ -61,7 +58,7 @@ def __init__(
self.shape = shape ## shape of synaptic filter tensor
x_size, x_size = x_shape
self.x_size = x_size
- self.Rscale = resist_scale ## post-transformation scale factor
+ self.resist_scale = resist_scale ## post-transformation scale factor
self.padding = padding
self.stride = stride
@@ -70,9 +67,13 @@ def __init__(
self.pad_args = None
######################### set up compartments ##########################
- tmp_key, *subkeys = random.split(self.key.value, 4)
- weights = dist.initialize_params(subkeys[0], filter_init,
- shape) ## filter tensor
+ tmp_key, *subkeys = random.split(self.key.get(), 4)
+ #weights = dist.initialize_params(subkeys[0], filter_init, shape)
+ if self.filter_init is None:
+ info(self.name, "is using default weight initializer!")
+ self.filter_init = DistributionGenerator.uniform(0.025, 0.8)
+ weights = self.filter_init(shape, subkeys[0]) ## filter tensor
+
self.batch_size = batch_size # 1
## Compartment setup and shape computation
_x = jnp.zeros((self.batch_size, x_size, x_size, n_in_chan))
@@ -85,40 +86,40 @@ def __init__(
if self.bias_init is None:
info(self.name, "is using default bias value of zero (no bias "
"kernel provided)!")
- self.biases = Compartment(dist.initialize_params(subkeys[2], bias_init,
- (1, shape[1]))
- if bias_init else 0.0)
-
- @transition(output_compartments=["outputs"])
- @staticmethod
- def advance_state(Rscale, padding, stride, weights, biases, inputs):
- _x = inputs
- out = deconv2d(_x, weights, stride_size=stride, padding=padding) * Rscale + biases
- return out
-
- @transition(output_compartments=["inputs", "outputs"])
- @staticmethod
- def reset(in_shape, out_shape):
- preVals = jnp.zeros(in_shape)
- postVals = jnp.zeros(out_shape)
- inputs = preVals
- outputs = postVals
- return inputs, outputs
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- if self.bias_init != None:
- jnp.savez(file_name, weights=self.weights.value,
- biases=self.biases.value)
- else:
- jnp.savez(file_name, weights=self.weights.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.weights.set(data['weights'])
- if "biases" in data.keys():
- self.biases.set(data['biases'])
+ self.biases = Compartment(
+ # dist.initialize_params(subkeys[2], bias_init, (1, shape[1])) if bias_init else 0.0
+ self.bias_init((1, shape[1]), subkeys[2]) if bias_init else 0.0
+ )
+
+ @compilable
+ def advance_state(self):
+ _x = self.inputs.get()
+ out = deconv2d(
+ _x, self.weights.get(), stride_size=self.stride, padding=self.padding
+ ) * self.resist_scale + self.biases.get()
+ self.outputs.set(out)
+
+ @compilable
+ def reset(self): #in_shape, out_shape):
+ preVals = jnp.zeros(self.in_shape)
+ postVals = jnp.zeros(self.out_shape)
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+
+ # def save(self, directory, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # if self.bias_init != None:
+ # jnp.savez(file_name, weights=self.weights.get(),
+ # biases=self.biases.get())
+ # else:
+ # jnp.savez(file_name, weights=self.weights.get())
+ #
+ # def load(self, directory, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # data = jnp.load(file_name)
+ # self.weights.set(data['weights'])
+ # if "biases" in data.keys():
+ # self.biases.set(data['biases'])
@classmethod
def help(cls): ## component help function
@@ -151,17 +152,3 @@ def help(cls): ## component help function
"dynamics": "outputs = [K @.T inputs] * R + b",
"hyperparameters": hyperparams}
return info
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
diff --git a/ngclearn/components/synapses/convolution/hebbianConvSynapse.py b/ngclearn/components/synapses/convolution/hebbianConvSynapse.py
index ff45f76b..a66242a4 100755
--- a/ngclearn/components/synapses/convolution/hebbianConvSynapse.py
+++ b/ngclearn/components/synapses/convolution/hebbianConvSynapse.py
@@ -1,13 +1,8 @@
from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
-from .convSynapse import ConvSynapse
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
-from ngclearn.utils import tensorstats
-import ngclearn.utils.weight_distribution as dist
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+from ngclearn.components.synapses.convolution.convSynapse import ConvSynapse
+
from ngclearn.components.synapses.convolution.ngcconv import (_conv_same_transpose_padding,
_conv_valid_transpose_padding)
from ngclearn.components.synapses.convolution.ngcconv import (conv2d, _calc_dX_conv,
@@ -17,8 +12,7 @@
class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable
"""
- A synaptic convolutional cable that adjusts its efficacies via a two-factor
- Hebbian adjustment rule.
+ A specialized synaptic convolutional cable that adjusts its efficacies via a two-factor Hebbian adjustment rule.
| --- Synapse Compartments: ---
| inputs - input (takes in external signals)
@@ -87,11 +81,11 @@ class HebbianConvSynapse(ConvSynapse): ## Hebbian-evolved convolutional cable
batch_size: batch size dimension of this component
"""
- # Define Functions
- def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None,
- stride=1, padding=None, resist_scale=1., w_bound=0.,
- is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd",
- batch_size=1, **kwargs):
+ def __init__(
+ self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None, stride=1, padding=None,
+ resist_scale=1., w_bound=0., is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd",
+ batch_size=1, **kwargs
+ ):
super().__init__(
name, shape, x_shape=x_shape, filter_init=filter_init, bias_init=bias_init, resist_scale=resist_scale,
stride=stride, padding=padding, batch_size=batch_size, **kwargs
@@ -107,9 +101,9 @@ def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=Non
######################### set up compartments ##########################
## Compartment setup and shape computation
- self.dWeights = Compartment(self.weights.value * 0)
+ self.dWeights = Compartment(self.weights.get() * 0)
self.dInputs = Compartment(jnp.zeros(self.in_shape))
- self.dBiases = Compartment(self.biases.value * 0)
+ self.dBiases = Compartment(self.biases.get() * 0)
self.pre = Compartment(jnp.zeros(self.in_shape))
self.post = Compartment(jnp.zeros(self.out_shape))
@@ -120,80 +114,75 @@ def __init__(self, name, shape, x_shape, eta=0., filter_init=None, bias_init=Non
self.antiPad = None
k_size, k_size, n_in_chan, n_out_chan = self.shape
if padding == "SAME":
- self.antiPad = _conv_same_transpose_padding(self.post.value.shape[1],
+ self.antiPad = _conv_same_transpose_padding(self.post.get().shape[1],
self.x_size, k_size, stride)
elif padding == "VALID":
- self.antiPad = _conv_valid_transpose_padding(self.post.value.shape[1],
+ self.antiPad = _conv_valid_transpose_padding(self.post.get().shape[1],
self.x_size, k_size, stride)
########################################################################
## set up outer optimization compartments
- self.opt_params = Compartment(get_opt_init_fn(optim_type)(
- [self.weights.value, self.biases.value]
- if bias_init else [self.weights.value]))
+ self.opt_params = Compartment(
+ get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()])
+ )
def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
k_size, k_size, n_in_chan, n_out_chan = shape
_x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
- _d = conv2d(_x, weights.value, stride_size=stride, padding=padding) * 0
+ _d = conv2d(_x, weights.get(), stride_size=stride, padding=padding) * 0
_dK = _calc_dK_conv(_x, _d, stride_size=stride, padding=pad_args)
## get filter update correction
- dx = _dK.shape[0] - weights.value.shape[0]
- dy = _dK.shape[1] - weights.value.shape[1]
+ dx = _dK.shape[0] - weights.get().shape[0]
+ dy = _dK.shape[1] - weights.get().shape[1]
self.delta_shape = (max(dx, 0), max(dy, 0))
## get input update correction
- _dx = _calc_dX_conv(weights.value, _d, stride_size=stride,
- anti_padding=pad_args)
+ _dx = _calc_dX_conv(weights.get(), _d, stride_size=stride, anti_padding=pad_args)
dx = (_dx.shape[1] - _x.shape[1])
dy = (_dx.shape[2] - _x.shape[2])
self.x_delta_shape = (dx, dy)
- @staticmethod
- def _compute_update(
- sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights
- ): ## synaptic kernel adjustment calculation co-routine
+ def _compute_update(self): #sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights
+ ## synaptic kernel adjustment calculation co-routine
## compute adjustment to filters
- dWeights = calc_dK_conv(pre, post, delta_shape=delta_shape, stride_size=stride, padding=pad_args)
- dWeights = dWeights * sign_value
- if w_decay > 0.: ## apply synaptic decay
- dWeights = dWeights - weights * w_decay
+ dWeights = calc_dK_conv(
+ self.pre.get(), self.post.get(), delta_shape=self.delta_shape, stride_size=self.stride, padding=self.pad_args
+ )
+ dWeights = dWeights * self.sign_value
+ if self.w_decay > 0.: ## apply synaptic decay
+ dWeights = dWeights - self.weights.get() * self.w_decay
## compute adjustment to base-rates (if applicable)
dBiases = 0. # jnp.zeros((1,1))
- if bias_init != None:
- dBiases = jnp.sum(post, axis=0, keepdims=True) * sign_value
+        if self.bias_init is not None:
+ dBiases = jnp.sum(self.post.get(), axis=0, keepdims=True) * self.sign_value
return dWeights, dBiases
- @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"])
- @staticmethod
- def evolve(
- opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init, stride, pad_args, delta_shape, pre, post,
- weights, biases, opt_params
- ):
+ @compilable
+ def evolve(self):
## calc dFilters / dBiases - update to filters and biases
- dWeights, dBiases = HebbianConvSynapse._compute_update(
- sign_value, w_decay, bias_init, stride, pad_args, delta_shape, pre, post, weights
- )
- if bias_init != None:
- opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases])
+ dWeights, dBiases = self._compute_update()
+ if self.bias_init is not None:
+ opt_params, [weights, biases] = self.opt(self.opt_params.get(), [self.weights.get(), self.biases.get()], [dWeights, dBiases])
else: ## ignore dBiases since no biases configured
- opt_params, [weights] = opt(opt_params, [weights], [dWeights])
-
+ opt_params, [weights] = self.opt(self.opt_params.get(), [self.weights.get()], [dWeights])
+ biases = None
## apply any enforced filter constraints
- if w_bounds > 0.:
- if is_nonnegative:
- weights = jnp.clip(weights, 0., w_bounds)
+ if self.w_bounds > 0.:
+ if self.is_nonnegative:
+ weights = jnp.clip(weights, 0., self.w_bounds)
else:
- weights = jnp.clip(weights, -w_bounds, w_bounds)
- return opt_params, weights, biases, dWeights, dBiases
-
- @transition(output_compartments=["dInputs"])
- @staticmethod
- def backtransmit(
- sign_value, x_size, shape, stride, padding, x_delta_shape, antiPad, post, weights
- ): ## action-backpropagating routine
+ weights = jnp.clip(weights, -self.w_bounds, self.w_bounds)
+
+ self.opt_params.set(opt_params)
+ self.weights.set(weights)
+ self.biases.set(biases)
+ self.dWeights.set(dWeights)
+ self.dBiases.set(dBiases)
+
+ @compilable
+ def backtransmit(self): ## action-backpropagating co-routine
## calc dInputs - adjustment w.r.t. input signal
- k_size, k_size, n_in_chan, n_out_chan = shape
+ k_size, k_size, n_in_chan, n_out_chan = self.shape
# antiPad = None
# if padding == "SAME":
# antiPad = _conv_same_transpose_padding(post.shape[1], x_size,
@@ -201,22 +190,20 @@ def backtransmit(
# elif padding == "VALID":
# antiPad = _conv_valid_transpose_padding(post.shape[1], x_size,
# k_size, stride)
- dInputs = calc_dX_conv(weights, post, delta_shape=x_delta_shape, stride_size=stride, anti_padding=antiPad)
+ dInputs = calc_dX_conv(self.weights.get(), self.post.get(), delta_shape=self.x_delta_shape, stride_size=self.stride, anti_padding=self.antiPad)
## flip sign of back-transmitted signal (if applicable)
- dInputs = dInputs * sign_value
- return dInputs
-
- @transition(output_compartments=["inputs", "outputs", "pre", "post", "dInputs"])
- @staticmethod
- def reset(in_shape, out_shape):
- preVals = jnp.zeros(in_shape)
- postVals = jnp.zeros(out_shape)
- inputs = preVals
- outputs = postVals
- pre = preVals
- post = postVals
- dInputs = preVals
- return inputs, outputs, pre, post, dInputs
+ dInputs = dInputs * self.sign_value
+ self.dInputs.set(dInputs)
+
+ @compilable
+ def reset(self): #in_shape, out_shape):
+ preVals = jnp.zeros(self.in_shape.get())
+ postVals = jnp.zeros(self.out_shape.get())
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.pre.set(preVals)
+ self.post.set(postVals)
+ self.dInputs.set(preVals)
@classmethod
def help(cls): ## component help function
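For orientation, the two-factor Hebbian rule that `evolve` applies (computed for filters via `calc_dK_conv` and routed through the configured optimizer in the component) has a much simpler dense analogue, sketched below. This only illustrates the rule plus the decay and bounding steps, using illustrative names and a plain SGD step in place of `get_opt_step_fn`.

```python
import jax.numpy as jnp

def hebbian_update(W, pre, post, eta, sign_value=1., w_decay=0.,
                   w_bound=0., is_nonnegative=False):
    ## two-factor rule: correlate pre- and post-synaptic activities
    dW = jnp.matmul(pre.T, post) * sign_value
    if w_decay > 0.:                      # optional synaptic decay
        dW = dW - W * w_decay
    W = W + dW * eta                      # plain gradient-ascent step
    if w_bound > 0.:                      # optional hard bounding of efficacies
        W = jnp.clip(W, 0. if is_nonnegative else -w_bound, w_bound)
    return W, dW
```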
diff --git a/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py b/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py
index f203400a..d3317728 100755
--- a/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py
+++ b/ngclearn/components/synapses/convolution/hebbianDeconvSynapse.py
@@ -1,13 +1,8 @@
from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
-from .deconvSynapse import DeconvSynapse
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
-from ngclearn.utils import tensorstats
-import ngclearn.utils.weight_distribution as dist
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+from ngclearn.components.synapses.convolution.deconvSynapse import DeconvSynapse
+
from ngclearn.components.synapses.convolution.ngcconv import (deconv2d, _calc_dX_deconv,
_calc_dK_deconv, calc_dX_deconv,
calc_dK_deconv)
@@ -15,8 +10,8 @@
class HebbianDeconvSynapse(DeconvSynapse): ## Hebbian-evolved deconvolutional cable
"""
- A synaptic deconvolutional (transposed convolutional) cable that adjusts its
- efficacies via a two-factor Hebbian adjustment rule.
+ A specialized synaptic deconvolutional (transposed convolutional) cable that adjusts its efficacies via a
+ two-factor Hebbian adjustment rule.
| --- Synapse Compartments: ---
| inputs - input (takes in external signals)
@@ -85,7 +80,6 @@ class HebbianDeconvSynapse(DeconvSynapse): ## Hebbian-evolved deconvolutional ca
batch_size: batch size dimension of this component
"""
- # Define Functions
def __init__(
self, name, shape, x_shape, eta=0., filter_init=None, bias_init=None, stride=1, padding=None,
resist_scale=1., w_bound=0., is_nonnegative=False, w_decay=0., sign_value=1., optim_type="sgd",
@@ -104,11 +98,11 @@ def __init__(
## optimization / adjustment properties (given learning dynamics above)
self.opt = get_opt_step_fn(optim_type, eta=self.eta)
- self.dWeights = Compartment(self.weights.value * 0)
+ self.dWeights = Compartment(self.weights.get() * 0)
self.dInputs = Compartment(jnp.zeros(self.in_shape))
self.pre = Compartment(jnp.zeros(self.in_shape))
self.post = Compartment(jnp.zeros(self.out_shape))
- self.dBiases = Compartment(self.biases.value * 0)
+ self.dBiases = Compartment(self.biases.get() * 0)
########################################################################
## Shape error correction -- do shape correction inference (for local updates)
@@ -117,85 +111,85 @@ def __init__(
########################################################################
## set up outer optimization compartments
- self.opt_params = Compartment(get_opt_init_fn(optim_type)(
- [self.weights.value, self.biases.value]
- if bias_init else [self.weights.value]))
+ self.opt_params = Compartment(
+ get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()])
+ )
def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
k_size, k_size, n_in_chan, n_out_chan = shape
_x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
- _d = deconv2d(_x, self.weights.value, stride_size=self.stride,
+ _d = deconv2d(_x, self.weights.get(), stride_size=self.stride,
padding=self.padding) * 0
_dK = _calc_dK_deconv(_x, _d, stride_size=self.stride, out_size=k_size)
## get filter update correction
- dx = _dK.shape[0] - self.weights.value.shape[0]
- dy = _dK.shape[1] - self.weights.value.shape[1]
+ dx = _dK.shape[0] - self.weights.get().shape[0]
+ dy = _dK.shape[1] - self.weights.get().shape[1]
self.delta_shape = (abs(dx), abs(dy))
## get input update correction
- _dx = _calc_dX_deconv(self.weights.value, _d, stride_size=self.stride,
+ _dx = _calc_dX_deconv(self.weights.get(), _d, stride_size=self.stride,
padding=self.padding)
dx = (_dx.shape[1] - _x.shape[1]) # abs()
dy = (_dx.shape[2] - _x.shape[2])
self.x_delta_shape = (dx, dy)
- @staticmethod
- def _compute_update(sign_value, w_decay, bias_init, shape, stride, padding, delta_shape, pre, post, weights):
- k_size, k_size, n_in_chan, n_out_chan = shape
+ def _compute_update(self):
+ k_size, k_size, n_in_chan, n_out_chan = self.shape
## compute adjustment to filters
dWeights = calc_dK_deconv(
- pre, post, delta_shape=delta_shape, stride_size=stride, out_size=k_size, padding=padding
+ self.pre.get(), self.post.get(), delta_shape=self.delta_shape, stride_size=self.stride, out_size=k_size,
+ padding=self.padding
)
- dWeights = dWeights * sign_value
- if w_decay > 0.: ## apply synaptic decay
- dWeights = dWeights - weights * w_decay
+ dWeights = dWeights * self.sign_value
+ if self.w_decay > 0.: ## apply synaptic decay
+ dWeights = dWeights - self.weights.get() * self.w_decay
## compute adjustment to base-rates (if applicable)
dBiases = 0. # jnp.zeros((1,1))
- if bias_init != None:
- dBiases = jnp.sum(post, axis=0, keepdims=True) * sign_value
+        if self.bias_init is not None:
+ dBiases = jnp.sum(self.post.get(), axis=0, keepdims=True) * self.sign_value
return dWeights, dBiases
- @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"])
- @staticmethod
- def evolve(
- opt, sign_value, w_decay, w_bounds, is_nonnegative, bias_init, shape, stride, padding, delta_shape,
- pre, post, weights, biases, opt_params
- ):
- dWeights, dBiases = HebbianDeconvSynapse._compute_update(
- sign_value, w_decay, bias_init, shape, stride, padding, delta_shape, pre, post, weights
- )
- if bias_init != None:
- opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases])
+ @compilable
+ def evolve(self):
+ dWeights, dBiases = self._compute_update()
+        if self.bias_init is not None:
+ opt_params, [weights, biases] = self.opt(self.opt_params.get(), [self.weights.get(), self.biases.get()], [dWeights, dBiases])
else: ## ignore dBiases since no biases configured
- opt_params, [weights] = opt(opt_params, [weights], [dWeights])
+ opt_params, [weights] = self.opt(self.opt_params.get(), [self.weights.get()], [dWeights])
+ biases = None
## apply any enforced filter constraints
- if w_bounds > 0.:
- if is_nonnegative:
- weights = jnp.clip(weights, 0., w_bounds)
+ if self.w_bounds > 0.:
+ if self.is_nonnegative:
+ weights = jnp.clip(weights, 0., self.w_bounds)
else:
- weights = jnp.clip(weights, -w_bounds, w_bounds)
- return opt_params, weights, biases, dWeights, dBiases
+ weights = jnp.clip(weights, -self.w_bounds, self.w_bounds)
+
+ self.opt_params.set(opt_params)
+ self.weights.set(weights)
+ self.biases.set(biases)
+ self.dWeights.set(dWeights)
+ self.dBiases.set(dBiases)
- @transition(output_compartments=["dInputs"])
- @staticmethod
- def backtransmit(sign_value, stride, padding, x_delta_shape, pre, post, weights): ## action-backpropagating routine
+ @compilable
+ def backtransmit(self): ## action-backpropagating co-routine
## calc dInputs
- dInputs = calc_dX_deconv(weights, post, delta_shape=x_delta_shape, stride_size=stride, padding=padding)
+ dInputs = calc_dX_deconv(
+ self.weights.get(), self.post.get(), delta_shape=self.x_delta_shape, stride_size=self.stride,
+ padding=self.padding
+ )
## flip sign of back-transmitted signal (if applicable)
- dInputs = dInputs * sign_value
- return dInputs
-
- @transition(output_compartments=["inputs", "outputs", "pre", "post", "dInputs"])
- @staticmethod
- def reset(in_shape, out_shape):
- preVals = jnp.zeros(in_shape)
- postVals = jnp.zeros(out_shape)
- inputs = preVals
- outputs = postVals
- pre = preVals
- post = postVals
- dInputs = preVals
- return inputs, outputs, pre, post, dInputs
+ dInputs = dInputs * self.sign_value
+ self.dInputs.set(dInputs)
+
+ @compilable
+ def reset(self): #in_shape, out_shape):
+ preVals = jnp.zeros(self.in_shape.get())
+ postVals = jnp.zeros(self.out_shape.get())
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.pre.set(preVals)
+ self.post.set(postVals)
+ self.dInputs.set(preVals)
@classmethod
def help(cls): ## component help function
diff --git a/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py b/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py
index 7fbb5021..86aa33c4 100755
--- a/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py
+++ b/ngclearn/components/synapses/convolution/traceSTDPConvSynapse.py
@@ -1,13 +1,8 @@
from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
-from .convSynapse import ConvSynapse
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
-from ngclearn.utils import tensorstats
-import ngclearn.utils.weight_distribution as dist
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+from ngclearn.components.synapses.convolution.convSynapse import ConvSynapse
+
from ngclearn.components.synapses.convolution.ngcconv import (_conv_same_transpose_padding,
_conv_valid_transpose_padding)
from ngclearn.components.synapses.convolution.ngcconv import (conv2d, _calc_dX_conv,
@@ -16,8 +11,8 @@
class TraceSTDPConvSynapse(ConvSynapse): ## trace-based STDP convolutional cable
"""
- A synaptic convolutional cable that adjusts its filter efficacies via a
- trace-based form of spike-timing-dependent plasticity (STDP).
+ A specialized synaptic convolutional cable that adjusts its filter efficacies via a trace-based form of
+ spike-timing-dependent plasticity (STDP).
| --- Synapse Compartments: ---
| inputs - input (takes in external signals)
@@ -73,7 +68,6 @@ class TraceSTDPConvSynapse(ConvSynapse): ## trace-based STDP convolutional cable
batch_size: batch size dimension of this component
"""
- # Define Functions
def __init__(
self, name, shape, x_shape, A_plus, A_minus, eta=0., pretrace_target=0., filter_init=None, stride=1,
padding=None, resist_scale=1., w_bound=0., w_decay=0., batch_size=1, **kwargs
@@ -93,7 +87,7 @@ def __init__(
######################### set up compartments ##########################
## Compartment setup and shape computation
- self.dWeights = Compartment(self.weights.value * 0)
+ self.dWeights = Compartment(self.weights.get() * 0)
self.dInputs = Compartment(jnp.zeros(self.in_shape))
self.preSpike = Compartment(jnp.zeros(self.in_shape))
self.preTrace = Compartment(jnp.zeros(self.in_shape))
@@ -108,72 +102,64 @@ def __init__(
k_size, k_size, n_in_chan, n_out_chan = self.shape
if padding == "SAME":
self.antiPad = _conv_same_transpose_padding(
- self.postSpike.value.shape[1],
+ self.postSpike.get().shape[1],
self.x_size, k_size, stride)
elif padding == "VALID":
self.antiPad = _conv_valid_transpose_padding(
- self.postSpike.value.shape[1],
+ self.postSpike.get().shape[1],
self.x_size, k_size, stride)
########################################################################
def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
k_size, k_size, n_in_chan, n_out_chan = shape
_x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
- _d = conv2d(_x, weights.value, stride_size=stride, padding=padding) * 0
+ _d = conv2d(_x, weights.get(), stride_size=stride, padding=padding) * 0
_dK = _calc_dK_conv(_x, _d, stride_size=stride, padding=pad_args)
## get filter update correction
- dx = _dK.shape[0] - weights.value.shape[0]
- dy = _dK.shape[1] - weights.value.shape[1]
+ dx = _dK.shape[0] - weights.get().shape[0]
+ dy = _dK.shape[1] - weights.get().shape[1]
#self.delta_shape = (dx, dy)
self.delta_shape = (max(dx, 0), max(dy, 0))
## get input update correction
- _dx = _calc_dX_conv(weights.value, _d, stride_size=stride,
+ _dx = _calc_dX_conv(weights.get(), _d, stride_size=stride,
anti_padding=pad_args)
dx = (_dx.shape[1] - _x.shape[1])
dy = (_dx.shape[2] - _x.shape[2])
self.x_delta_shape = (dx, dy)
- @staticmethod
- def _compute_update(
- pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
- ):
+ def _compute_update(self):
## Compute long-term potentiation to filters
dW_ltp = calc_dK_conv(
- preTrace - pretrace_target, postSpike * Aplus, delta_shape=delta_shape, stride_size=stride, padding=pad_args
+ self.preTrace.get() - self.pretrace_target, self.postSpike.get() * self.Aplus, delta_shape=self.delta_shape,
+ stride_size=self.stride, padding=self.pad_args
)
## Compute long-term depression to filters
dW_ltd = -calc_dK_conv(
- preSpike, postTrace * Aminus, delta_shape=delta_shape, stride_size=stride, padding=pad_args
+ self.preSpike.get(), self.postTrace.get() * self.Aminus, delta_shape=self.delta_shape,
+ stride_size=self.stride, padding=self.pad_args
)
dWeights = (dW_ltp + dW_ltd)
return dWeights
- @transition(output_compartments=["weights", "dWeights"])
- @staticmethod
- def evolve(
- pretrace_target, Aplus, Aminus, w_decay, w_bound, stride, pad_args, delta_shape, preSpike, preTrace,
- postSpike, postTrace, weights, eta
- ):
- dWeights = TraceSTDPConvSynapse._compute_update(
- pretrace_target, Aplus, Aminus, stride, pad_args, delta_shape, preSpike, preTrace, postSpike, postTrace
- )
- if w_decay > 0.: ## apply synaptic decay
- weights = weights + dWeights * eta - weights * w_decay ## conduct decayed STDP-ascent
+ @compilable
+ def evolve(self):
+ dWeights = self._compute_update()
+ if self.w_decay > 0.: ## apply synaptic decay
+ weights = self.weights.get() + dWeights * self.eta - self.weights.get() * self.w_decay ## conduct decayed STDP-ascent
else:
- weights = weights + dWeights * eta ## conduct STDP-ascent
+ weights = self.weights.get() + dWeights * self.eta ## conduct STDP-ascent
## Apply any enforced filter constraints
- if w_bound > 0.: ## enforce non-negativity
+ if self.w_bound > 0.: ## enforce non-negativity
eps = 0.01 # 0.001
- weights = jnp.clip(weights, eps, w_bound - eps)
- return weights, dWeights
-
- @transition(output_compartments=["dInputs"])
- @staticmethod
- def backtransmit(
- x_size, shape, stride, padding, x_delta_shape, antiPad, postSpike, weights
- ): ## action-backpropagating routine
+ weights = jnp.clip(weights, eps, self.w_bound - eps)
+
+ self.weights.set(weights)
+ self.dWeights.set(dWeights)
+
+ @compilable
+ def backtransmit(self): ## action-backpropagating co-routine
## calc dInputs - adjustment w.r.t. input signal
- k_size, k_size, n_in_chan, n_out_chan = shape
+ k_size, k_size, n_in_chan, n_out_chan = self.shape
# antiPad = None
# if padding == "SAME":
# antiPad = _conv_same_transpose_padding(postSpike.shape[1], x_size,
@@ -181,21 +167,22 @@ def backtransmit(
# elif padding == "VALID":
# antiPad = _conv_valid_transpose_padding(postSpike.shape[1], x_size,
# k_size, stride)
- dInputs = calc_dX_conv(weights, postSpike, delta_shape=x_delta_shape, stride_size=stride, anti_padding=antiPad)
- return dInputs
-
- @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace"])
- @staticmethod
- def reset(in_shape, out_shape):
- preVals = jnp.zeros(in_shape)
- postVals = jnp.zeros(out_shape)
- inputs = preVals
- outputs = postVals
- preSpike = preVals
- postSpike = postVals
- preTrace = preVals
- postTrace = postVals
- return inputs, outputs, preSpike, postSpike, preTrace, postTrace
+ dInputs = calc_dX_conv(
+ self.weights.get(), self.postSpike.get(), delta_shape=self.x_delta_shape, stride_size=self.stride,
+ anti_padding=self.antiPad
+ )
+ self.dInputs.set(dInputs)
+
+ @compilable
+ def reset(self): # in_shape, out_shape):
+ preVals = jnp.zeros(self.in_shape.get())
+ postVals = jnp.zeros(self.out_shape.get())
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.preSpike.set(preVals)
+ self.postSpike.set(postVals)
+ self.preTrace.set(preVals)
+ self.postTrace.set(postVals)
@classmethod
def help(cls): ## component help function
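The trace-STDP update in `evolve` combines a potentiation term (pre-synaptic traces gated by post-synaptic spikes) and a depression term (pre-synaptic spikes gated by post-synaptic traces). The dense-matrix analogue below spells the rule out; the convolutional components compute the corresponding kernel update through `calc_dK_conv`, so this is only an illustrative sketch with assumed names.

```python
import jax.numpy as jnp

def trace_stdp_update(W, pre_spk, pre_tr, post_spk, post_tr, A_plus, A_minus,
                      eta, x_tar=0., w_decay=0., w_bound=1.):
    dW_ltp = jnp.matmul((pre_tr - x_tar).T, post_spk * A_plus)   # long-term potentiation
    dW_ltd = -jnp.matmul(pre_spk.T, post_tr * A_minus)           # long-term depression
    dW = dW_ltp + dW_ltd
    W = W + dW * eta - W * w_decay        # STDP-ascent with optional decay
    if w_bound > 0.:                      # keep efficacies within (eps, w_bound - eps)
        eps = 0.01
        W = jnp.clip(W, eps, w_bound - eps)
    return W, dW
```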
diff --git a/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py b/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py
index 0e5d76b4..a894213e 100755
--- a/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py
+++ b/ngclearn/components/synapses/convolution/traceSTDPDeconvSynapse.py
@@ -1,22 +1,16 @@
from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
-from .deconvSynapse import DeconvSynapse
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
-from ngclearn.utils import tensorstats
-import ngclearn.utils.weight_distribution as dist
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+from ngclearn.components.synapses.convolution.deconvSynapse import DeconvSynapse
+
from ngclearn.components.synapses.convolution.ngcconv import (deconv2d, _calc_dX_deconv,
_calc_dK_deconv, calc_dX_deconv,
calc_dK_deconv)
-from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn
class TraceSTDPDeconvSynapse(DeconvSynapse): ## trace-based STDP deconvolutional cable
"""
- A synaptic deconvolutional (transposed convolutional) cable that adjusts its
- filter efficacies via a trace-based form of spike-timing-dependent plasticity (STDP).
+ A specialized synaptic deconvolutional (transposed convolutional) cable that adjusts its filter efficacies via a
+ trace-based form of spike-timing-dependent plasticity (STDP).
| --- Synapse Compartments: ---
| inputs - input (takes in external signals)
@@ -72,7 +66,6 @@ class TraceSTDPDeconvSynapse(DeconvSynapse): ## trace-based STDP deconvolutional
batch_size: batch size dimension of this component
"""
- # Define Functions
def __init__(
self, name, shape, x_shape, A_plus, A_minus, eta=0., pretrace_target=0., filter_init=None, stride=1,
padding=None, resist_scale=1., w_bound=0., w_decay=0., batch_size=1, **kwargs
@@ -92,7 +85,7 @@ def __init__(
######################### set up compartments ##########################
## Compartment setup and shape computation
- self.dWeights = Compartment(self.weights.value * 0)
+ self.dWeights = Compartment(self.weights.get() * 0)
self.dInputs = Compartment(jnp.zeros(self.in_shape))
self.preSpike = Compartment(jnp.zeros(self.in_shape))
self.preTrace = Compartment(jnp.zeros(self.in_shape))
@@ -108,76 +101,73 @@ def __init__(
def _init(self, batch_size, x_size, shape, stride, padding, pad_args, weights):
k_size, k_size, n_in_chan, n_out_chan = shape
_x = jnp.zeros((batch_size, x_size, x_size, n_in_chan))
- _d = deconv2d(_x, self.weights.value, stride_size=self.stride,
+ _d = deconv2d(_x, self.weights.get(), stride_size=self.stride,
padding=self.padding) * 0
_dK = _calc_dK_deconv(_x, _d, stride_size=self.stride, out_size=k_size)
## get filter update correction
- dx = _dK.shape[0] - self.weights.value.shape[0]
- dy = _dK.shape[1] - self.weights.value.shape[1]
+ dx = _dK.shape[0] - self.weights.get().shape[0]
+ dy = _dK.shape[1] - self.weights.get().shape[1]
self.delta_shape = (abs(dx), abs(dy))
## get input update correction
- _dx = _calc_dX_deconv(self.weights.value, _d, stride_size=self.stride,
+ _dx = _calc_dX_deconv(self.weights.get(), _d, stride_size=self.stride,
padding=self.padding)
dx = (_dx.shape[1] - _x.shape[1]) # abs()
dy = (_dx.shape[2] - _x.shape[2])
self.x_delta_shape = (dx, dy)
- @staticmethod
- def _compute_update(
- pretrace_target, Aplus, Aminus, shape, stride, padding, delta_shape, preSpike, preTrace, postSpike, postTrace
- ):
- k_size, k_size, n_in_chan, n_out_chan = shape
+ def _compute_update(self):
+ k_size, k_size, n_in_chan, n_out_chan = self.shape
## calc dFilters
- dW_ltp = calc_dK_deconv(preTrace - pretrace_target, postSpike * Aplus,
- delta_shape=delta_shape, stride_size=stride,
- out_size=k_size, padding=padding)
- dW_ltd = -calc_dK_deconv(preSpike, postTrace * Aminus,
- delta_shape=delta_shape, stride_size=stride,
- out_size=k_size, padding=padding)
+ dW_ltp = calc_dK_deconv(
+ self.preTrace.get() - self.pretrace_target, self.postSpike.get() * self.Aplus,
+ delta_shape=self.delta_shape, stride_size=self.stride, out_size=k_size, padding=self.padding
+ )
+ dW_ltd = -calc_dK_deconv(
+ self.preSpike.get(), self.postTrace.get() * self.Aminus, delta_shape=self.delta_shape,
+ stride_size=self.stride, out_size=k_size, padding=self.padding
+ )
dWeights = (dW_ltp + dW_ltd)
return dWeights
- @transition(output_compartments=["weights", "dWeights"])
- @staticmethod
- def evolve(
- pretrace_target, Aplus, Aminus, w_decay, w_bound, shape, stride, padding, delta_shape, preSpike, preTrace,
- postSpike, postTrace, weights, eta
- ):
- dWeights = TraceSTDPDeconvSynapse._compute_update(
- pretrace_target, Aplus, Aminus, shape, stride, padding, delta_shape,
- preSpike, preTrace, postSpike, postTrace
- )
- if w_decay > 0.: ## apply synaptic decay
- weights = weights + dWeights * eta - weights * w_decay ## conduct decayed STDP-ascent
+ @compilable
+ def evolve(self):
+ dWeights = self._compute_update()
+ # dWeights = TraceSTDPDeconvSynapse._compute_update(
+ # pretrace_target, Aplus, Aminus, shape, stride, padding, delta_shape,
+ # preSpike, preTrace, postSpike, postTrace
+ # )
+ if self.w_decay > 0.: ## apply synaptic decay and conduct decayed STDP-ascent
+ weights = self.weights.get() + dWeights * self.eta - self.weights.get() * self.w_decay
else:
- weights = weights + dWeights * eta ## conduct STDP-ascent
+ weights = self.weights.get() + dWeights * self.eta ## conduct STDP-ascent
## Apply any enforced filter constraints
- if w_bound > 0.: ## enforce non-negativity
+ if self.w_bound > 0.: ## enforce non-negativity
eps = 0.01 # 0.001
- weights = jnp.clip(weights, eps, w_bound - eps)
- return weights, dWeights
+ weights = jnp.clip(weights, eps, self.w_bound - eps)
- @transition(output_compartments=["dInputs"])
- @staticmethod
- def backtransmit(stride, padding, x_delta_shape, preSpike, postSpike, weights): ## action-backpropagating routine
+ self.weights.set(weights)
+ self.dWeights.set(dWeights)
+
+ @compilable
+ def backtransmit(self): ## action-backpropagating co-routine
## calc dInputs
- dInputs = calc_dX_deconv(weights, postSpike, delta_shape=x_delta_shape,
- stride_size=stride, padding=padding)
- return dInputs
-
- @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace"])
- @staticmethod
- def reset(in_shape, out_shape):
- preVals = jnp.zeros(in_shape)
- postVals = jnp.zeros(out_shape)
- inputs = preVals
- outputs = postVals
- preSpike = preVals
- postSpike = postVals
- preTrace = preVals
- postTrace = postVals
- return inputs, outputs, preSpike, postSpike, preTrace, postTrace
+ dInputs = calc_dX_deconv(
+ self.weights.get(), self.postSpike.get(), delta_shape=self.x_delta_shape, stride_size=self.stride,
+ padding=self.padding
+ )
+ self.dInputs.set(dInputs)
+
+ @compilable
+ def reset(self): # in_shape, out_shape):
+ preVals = jnp.zeros(self.in_shape.get())
+ postVals = jnp.zeros(self.out_shape.get())
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.preSpike.set(preVals)
+ self.postSpike.set(postVals)
+ self.preTrace.set(preVals)
+ self.postTrace.set(postVals)
@classmethod
def help(cls): ## component help function
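The hunks above convert `@transition` static methods into `@compilable` instance methods that read state with `Compartment.get()` and write it back with `Compartment.set()`. The sketch below illustrates that read-modify-write pattern in isolation; `ToyCompartment` and `ToySynapse` are purely illustrative stand-ins and are not the ngcsimlib `Compartment` API.

```python
# Minimal toy sketch of the stateful compartment pattern used by the refactored
# evolve()/reset() methods above. NOT the ngcsimlib API -- names are illustrative.
from jax import numpy as jnp

class ToyCompartment:
    def __init__(self, value):
        self._value = value
    def get(self):
        return self._value
    def set(self, value):
        self._value = value

class ToySynapse:
    def __init__(self, shape, eta=0.01):
        self.eta = eta
        self.weights = ToyCompartment(jnp.zeros(shape))
        self.dWeights = ToyCompartment(jnp.zeros(shape))

    def evolve(self):  ## mirrors the instance-method style used in the diff
        dW = self.dWeights.get()
        self.weights.set(self.weights.get() + dW * self.eta)

syn = ToySynapse((3, 2))
syn.dWeights.set(jnp.ones((3, 2)))
syn.evolve()
print(syn.weights.get())  # -> 0.01 everywhere
```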
diff --git a/ngclearn/components/synapses/denseSynapse.py b/ngclearn/components/synapses/denseSynapse.py
index fc4e7ea0..977f2464 100755
--- a/ngclearn/components/synapses/denseSynapse.py
+++ b/ngclearn/components/synapses/denseSynapse.py
@@ -1,12 +1,10 @@
from jax import random, numpy as jnp, jit
from ngclearn.components.jaxComponent import JaxComponent
-from ngclearn.utils import tensorstats
-from ngclearn.utils.weight_distribution import initialize_params
+from ngclearn.utils.distribution_generator import DistributionGenerator
from ngcsimlib.logger import info
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
class DenseSynapse(JaxComponent): ## base dense synaptic cable
"""
@@ -40,75 +38,55 @@ class DenseSynapse(JaxComponent): ## base dense synaptic cable
(lower values yield sparse structure)
"""
- # Define Functions
def __init__(
- self, name, shape, weight_init=None, bias_init=None, resist_scale=1.,
- p_conn=1., batch_size=1, **kwargs
+ self, name, shape, weight_init=None, bias_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs
):
super().__init__(name, **kwargs)
self.batch_size = batch_size
- self.weight_init = weight_init
- self.bias_init = bias_init
## Synapse meta-parameters
self.shape = shape
- self.Rscale = resist_scale
+ self.resist_scale = resist_scale
## Set up synaptic weight values
- tmp_key, *subkeys = random.split(self.key.value, 4)
- if self.weight_init is None:
+ tmp_key, *subkeys = random.split(self.key.get(), 4)
+
+ if weight_init is None:
info(self.name, "is using default weight initializer!")
- self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8}
- weights = initialize_params(subkeys[0], self.weight_init, shape)
- if 0. < p_conn < 1.: ## only non-zero and <1 probs allowed
+ # self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8}
+ weight_init = DistributionGenerator.uniform(0.025, 0.8)
+ weights = weight_init(shape, subkeys[0])
+
+ if 0. < p_conn < 1.: ## Modifier/constraint: only non-zero and <1 probs allowed
p_mask = random.bernoulli(subkeys[1], p=p_conn, shape=shape)
weights = weights * p_mask ## sparsify matrix
- self.batch_size = batch_size #1
## Compartment setup
preVals = jnp.zeros((self.batch_size, shape[0]))
postVals = jnp.zeros((self.batch_size, shape[1]))
+
self.inputs = Compartment(preVals)
self.outputs = Compartment(postVals)
self.weights = Compartment(weights)
## Set up (optional) bias values
- if self.bias_init is None:
- info(self.name, "is using default bias value of zero (no bias "
- "kernel provided)!")
- self.biases = Compartment(initialize_params(subkeys[2], bias_init,
- (1, shape[1]))
- if bias_init else 0.0)
-
- @transition(output_compartments=["outputs"])
- @staticmethod
- def advance_state(Rscale, inputs, weights, biases):
- outputs = (jnp.matmul(inputs, weights) * Rscale) + biases
- return outputs
-
- @transition(output_compartments=["inputs", "outputs"])
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- inputs = preVals
- outputs = postVals
- return inputs, outputs
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- if self.bias_init != None:
- jnp.savez(file_name, weights=self.weights.value,
- biases=self.biases.value)
- else:
- jnp.savez(file_name, weights=self.weights.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.weights.set(data['weights'])
- if "biases" in data.keys():
- self.biases.set(data['biases'])
+ if bias_init is None:
+ info(self.name, "is using default bias value of zero (no bias kernel provided)!")
+ self.biases = Compartment(bias_init((1, shape[1]), subkeys[2]) if bias_init else 0.0)
+ ## pin weight/bias initializers to component
+ self.weight_init = weight_init
+ self.bias_init = bias_init
+
+ @compilable
+ def advance_state(self):
+ self.outputs.set((jnp.matmul(self.inputs.get(), self.weights.get()) * self.resist_scale) + self.biases.get())
+
+ @compilable
+ def reset(self):
+ if not self.inputs.targeted:
+ self.inputs.set(jnp.zeros((self.batch_size, self.shape[0])))
+
+ self.outputs.set(jnp.zeros((self.batch_size, self.shape[1])))
@classmethod
def help(cls): ## component help function
@@ -141,20 +119,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
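The DenseSynapse hunks replace `initialize_params(...)` with a distribution object that is called as `init(shape, key)`. The plain-JAX sketch below mirrors that implied interface with a hypothetical `uniform_init` factory plus the `p_conn` sparsification step; it is not ngclearn's `DistributionGenerator`.

```python
# Plain-JAX sketch of the initializer-factory interface implied by the diff above.
# `uniform_init` is a hypothetical stand-in, not ngclearn's DistributionGenerator.
from jax import random, numpy as jnp

def uniform_init(amin, amax):
    def init(shape, key):
        return random.uniform(key, shape, minval=amin, maxval=amax)
    return init

key = random.PRNGKey(0)
key, w_key, mask_key = random.split(key, 3)
weight_init = uniform_init(0.025, 0.8)
weights = weight_init((5, 4), w_key)                      # dense weight matrix
p_mask = random.bernoulli(mask_key, p=0.5, shape=(5, 4))  # p_conn < 1 case
weights = weights * p_mask                                # sparsify matrix
print(weights.shape)
```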
diff --git a/ngclearn/components/synapses/doubleExpSynapse.py b/ngclearn/components/synapses/doubleExpSynapse.py
index 86225a68..91a05d60 100644
--- a/ngclearn/components/synapses/doubleExpSynapse.py
+++ b/ngclearn/components/synapses/doubleExpSynapse.py
@@ -1,14 +1,10 @@
from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.components.synapses import DenseSynapse
-from ngclearn.utils import tensorstats
-class DoupleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cable
+class DoubleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cable
"""
A dynamic double-exponential synaptic cable; this synapse evolves according to difference of two exponentials
synaptic conductance dynamics.
@@ -64,10 +60,9 @@ class DoupleExpSynapse(DenseSynapse): ## dynamic double-exponential synapse cabl
"""
- # Define Functions
def __init__(
- self, name, shape, tau_decay, tau_rise, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1.,
- is_nonplastic=True, **kwargs
+ self, name, shape, tau_decay, tau_rise, g_syn_bar, syn_rest, weight_init=None, bias_init=None,
+ resist_scale=1., p_conn=1., is_nonplastic=True, **kwargs
):
super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
## dynamic synapse meta-parameters
@@ -85,57 +80,44 @@ def __init__(
self.g_syn = Compartment(postVals) ## conductance variable
self.h_syn = Compartment(postVals) ## intermediate conductance variable
if is_nonplastic:
- self.weights.set(self.weights.value * 0 + 1.)
+ self.weights.set(self.weights.get() * 0 + 1.)
- @transition(output_compartments=["outputs", "i_syn", "g_syn", "h_syn"])
- @staticmethod
- def advance_state(
- dt, tau_decay, tau_rise, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, h_syn, v
- ):
- s = inputs
+ @compilable
+ def advance_state(self, t, dt):
+ s = self.inputs.get()
#A = tau_decay/(tau_decay - tau_rise) * jnp.power((tau_rise/tau_decay), tau_rise/(tau_rise - tau_decay))
- A = 1.
+ A = 1. ## FIXME: scale factor to use?
## advance conductance variable(s)
- _out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron)
- dhsyn_dt = -h_syn/tau_rise + ((_out * g_syn_bar) * (1. / tau_rise - 1. / tau_decay) * A) * (1./dt)
- h_syn = h_syn + dhsyn_dt * dt ## run Euler step to move intermediate conductance h
+ _out = jnp.matmul(s, self.weights.get()) ## sum all pre-syn spikes at t going into the post-neuron
+ dhsyn_dt = (-self.h_syn.get()/self.tau_rise +
+ ((_out * self.g_syn_bar) * (1. / self.tau_rise - 1. / self.tau_decay) * A) * (1./dt))
+ h_syn = self.h_syn.get() + dhsyn_dt * dt ## run Euler step to move intermediate conductance h
- dgsyn_dt = -g_syn/tau_decay + h_syn * (1./dt)
- g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance g
+ dgsyn_dt = -self.g_syn.get()/self.tau_decay + h_syn * (1./dt)
+ g_syn = self.g_syn.get() + dgsyn_dt * dt ## run Euler step to move conductance g
## compute derive electrical current variable
- i_syn = -g_syn * Rscale
- if syn_rest is not None:
- i_syn = -(g_syn * Rscale) * (v - syn_rest)
+ i_syn = -g_syn * self.resist_scale
+ if self.syn_rest is not None:
+ i_syn = -(g_syn * self.resist_scale) * (self.v.get() - self.syn_rest)
outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases
- return outputs, i_syn, g_syn, h_syn
-
- @transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "h_syn", "v"])
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- inputs = preVals
- outputs = postVals
- i_syn = postVals
- g_syn = postVals
- h_syn = postVals
- v = postVals
- return inputs, outputs, i_syn, g_syn, h_syn, v
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- if self.bias_init != None:
- jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
- else:
- jnp.savez(file_name, weights=self.weights.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.weights.set(data['weights'])
- if "biases" in data.keys():
- self.biases.set(data['biases'])
+
+ self.outputs.set(outputs)
+ self.i_syn.set(i_syn)
+ self.g_syn.set(g_syn)
+ self.h_syn.set(h_syn)
+
+ @compilable
+ def reset(self):
+ preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
+ postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
+ if not self.inputs.targeted:
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.i_syn.set(postVals)
+ self.g_syn.set(postVals)
+ self.h_syn.set(postVals)
+ self.v.set(postVals)
@classmethod
def help(cls): ## component help function
@@ -176,17 +158,3 @@ def help(cls): ## component help function
"dgsyn_dt = -g_syn/tau_decay + h_syn",
"hyperparameters": hyperparams}
return info
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
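The rewritten `advance_state()` above integrates two coupled conductance ODEs with explicit Euler steps. The standalone sketch below reproduces those same update equations outside the component class (unit `resist_scale`, no reversal potential), with toy values chosen only for illustration.

```python
# Standalone numerical sketch of the double-exponential conductance Euler steps
# used in advance_state() above (same equations, outside the component class).
from jax import numpy as jnp

dt, tau_rise, tau_decay, g_syn_bar = 0.1, 2.0, 10.0, 1.0
h_syn = jnp.zeros((1, 3))
g_syn = jnp.zeros((1, 3))
spikes_in = jnp.array([[1., 0., 1.]])            # pre-synaptic spike vector
W = jnp.ones((3, 3))

_out = jnp.matmul(spikes_in, W)                  # summed pre-synaptic drive
dhsyn_dt = -h_syn / tau_rise + (_out * g_syn_bar) * (1. / tau_rise - 1. / tau_decay) * (1. / dt)
h_syn = h_syn + dhsyn_dt * dt                    # Euler step for the rise variable
dgsyn_dt = -g_syn / tau_decay + h_syn * (1. / dt)
g_syn = g_syn + dgsyn_dt * dt                    # Euler step for the conductance
i_syn = -g_syn                                   # current (resist_scale = 1, no syn_rest)
print(i_syn)
```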
diff --git a/ngclearn/components/synapses/exponentialSynapse.py b/ngclearn/components/synapses/exponentialSynapse.py
index a873baf9..dc20c362 100644
--- a/ngclearn/components/synapses/exponentialSynapse.py
+++ b/ngclearn/components/synapses/exponentialSynapse.py
@@ -1,12 +1,8 @@
from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.components.synapses import DenseSynapse
-from ngclearn.utils import tensorstats
class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable
"""
@@ -61,10 +57,9 @@ class ExponentialSynapse(DenseSynapse): ## dynamic exponential synapse cable
"""
- # Define Functions
def __init__(
- self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1., p_conn=1.,
- is_nonplastic=True, **kwargs
+ self, name, shape, tau_decay, g_syn_bar, syn_rest, weight_init=None, bias_init=None, resist_scale=1.,
+ p_conn=1., is_nonplastic=True, **kwargs
):
super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, **kwargs)
## dynamic synapse meta-parameters
@@ -80,50 +75,35 @@ def __init__(
self.i_syn = Compartment(postVals) ## electrical current output
self.g_syn = Compartment(postVals) ## conductance variable
if is_nonplastic:
- self.weights.set(self.weights.value * 0 + 1.)
+ self.weights.set(self.weights.get() * 0 + 1.)
- @transition(output_compartments=["outputs", "i_syn", "g_syn"])
- @staticmethod
- def advance_state(
- dt, tau_decay, g_syn_bar, syn_rest, Rscale, inputs, weights, i_syn, g_syn, v
- ):
- s = inputs
+ @compilable
+ def advance_state(self, t, dt):
+ s = self.inputs.get()
## advance conductance variable
- _out = jnp.matmul(s, weights) ## sum all pre-syn spikes at t going into post-neuron)
- dgsyn_dt = -g_syn/tau_decay + (_out * g_syn_bar) * (1./dt)
- g_syn = g_syn + dgsyn_dt * dt ## run Euler step to move conductance
+ _out = jnp.matmul(s, self.weights.get()) ## sum all pre-syn spikes at t going into the post-neuron
+ dgsyn_dt = -self.g_syn.get()/self.tau_decay + (_out * self.g_syn_bar) * (1./dt)
+ g_syn = self.g_syn.get() + dgsyn_dt * dt ## run Euler step to move conductance
## compute derive electrical current variable
- i_syn = -g_syn * Rscale
- if syn_rest is not None:
- i_syn = -(g_syn * Rscale) * (v - syn_rest)
- outputs = i_syn #jnp.matmul(inputs, Wdyn * Rscale) + biases
- return outputs, i_syn, g_syn
-
- @transition(output_compartments=["inputs", "outputs", "i_syn", "g_syn", "v"])
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- inputs = preVals
- outputs = postVals
- i_syn = postVals
- g_syn = postVals
- v = postVals
- return inputs, outputs, i_syn, g_syn, v
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- if self.bias_init != None:
- jnp.savez(file_name, weights=self.weights.value, biases=self.biases.value)
- else:
- jnp.savez(file_name, weights=self.weights.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.weights.set(data['weights'])
- if "biases" in data.keys():
- self.biases.set(data['biases'])
+ i_syn = -g_syn * self.resist_scale
+ if self.syn_rest is not None:
+ i_syn = -(g_syn * self.resist_scale) * (self.v.get() - self.syn_rest)
+ outputs = i_syn #jnp.matmul(inputs, Wdyn * self.resist_scale) + biases
+
+ self.outputs.set(outputs)
+ self.i_syn.set(i_syn)
+ self.g_syn.set(g_syn)
+
+ @compilable
+ def reset(self):
+ preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
+ postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
+ if not self.inputs.targeted:
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.i_syn.set(postVals)
+ self.g_syn.set(postVals)
+ self.v.set(postVals)
@classmethod
def help(cls): ## component help function
@@ -162,17 +142,3 @@ def help(cls): ## component help function
"dgsyn_dt = (W * inputs) * g_syn_bar - g_syn/tau_decay ",
"hyperparameters": hyperparams}
return info
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
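The single-exponential synapse follows the same scheme with one conductance ODE; the sketch below walks through one Euler step including the conductance-based current used when a reversal/rest potential `syn_rest` is supplied. Values are illustrative only.

```python
# Standalone sketch of the single-exponential conductance step in advance_state()
# above, including the optional conductance-based current (syn_rest given).
from jax import numpy as jnp

dt, tau_decay, g_syn_bar, resist_scale, syn_rest = 0.1, 10.0, 1.0, 1.0, -75.0
g_syn = jnp.zeros((1, 2))
v = jnp.array([[-65.0, -70.0]])                  # post-synaptic membrane potentials
spikes_in = jnp.array([[1., 1.]])
W = jnp.eye(2)

_out = jnp.matmul(spikes_in, W)                  # summed pre-synaptic drive
dgsyn_dt = -g_syn / tau_decay + (_out * g_syn_bar) * (1. / dt)
g_syn = g_syn + dgsyn_dt * dt                    # Euler step for the conductance
i_syn = -(g_syn * resist_scale) * (v - syn_rest) # conductance-based current
print(i_syn)
```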
diff --git a/ngclearn/components/synapses/hebbian/BCMSynapse.py b/ngclearn/components/synapses/hebbian/BCMSynapse.py
index 6b391335..ff669a07 100755
--- a/ngclearn/components/synapses/hebbian/BCMSynapse.py
+++ b/ngclearn/components/synapses/hebbian/BCMSynapse.py
@@ -1,10 +1,8 @@
from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
-from ngclearn.components.synapses import DenseSynapse
-from ngclearn.utils import tensorstats
+from ngclearn.components.synapses.denseSynapse import DenseSynapse
class BCMSynapse(DenseSynapse): # BCM-adjusted synaptic cable
"""
@@ -66,13 +64,11 @@ class BCMSynapse(DenseSynapse): # BCM-adjusted synaptic cable
this to < 1. will result in a sparser synaptic structure
"""
- # Define Functions
def __init__(
self, name, shape, tau_w, tau_theta, theta0=-1., w_bound=0., w_decay=0., weight_init=None, resist_scale=1.,
p_conn=1., batch_size=1, **kwargs
):
- super().__init__(name, shape, weight_init, None, resist_scale, p_conn,
- batch_size=batch_size, **kwargs)
+ super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs)
## Synapse and BCM hyper-parameters
self.shape = shape ## shape of synaptic efficacy matrix
@@ -90,48 +86,51 @@ def __init__(
self.post = Compartment(postVals) ## post-synaptic statistic
self.post_term = Compartment(postVals)
self.theta = Compartment(postVals + self.theta0) ## synaptic modification thresholds
- self.dWeights = Compartment(self.weights.value * 0)
+ self.dWeights = Compartment(self.weights.get() * 0)
- @transition(output_compartments=["weights", "theta", "dWeights", "post_term"])
- @staticmethod
- def evolve(t, dt, tau_w, tau_theta, w_bound, w_decay, pre, post, theta, weights):
+ @compilable
+ def evolve(self, t, dt): #t, dt, tau_w, tau_theta, w_bound, w_decay, pre, post, theta, weights):
eps = 1e-7
- post_term = post * (post - theta) # post - theta
- post_term = post_term * (1. / (theta + eps))
- dWeights = jnp.matmul(pre.T, post_term)
- if w_bound > 0.:
- dWeights = dWeights * (w_bound - jnp.abs(weights))
+ post_term = self.post.get() * (self.post.get() - self.theta.get()) # post - theta
+ post_term = post_term * (1. / (self.theta.get() + eps))
+ dWeights = jnp.matmul(self.pre.get().T, post_term)
+ if self.w_bound > 0.:
+ dWeights = dWeights * (self.w_bound - jnp.abs(self.weights.get()))
## update synaptic efficacies according to a leaky ODE
- dWeights = -weights * w_decay + dWeights
- _W = weights + dWeights * dt / tau_w
+ dWeights = -self.weights.get() * self.w_decay + dWeights
+ _W = self.weights.get() + dWeights * dt / self.tau_w
## update synaptic modification threshold as a leaky ODE
- dtheta = jnp.mean(jnp.square(post), axis=0, keepdims=True) ## batch avg
- theta = theta + (-theta + dtheta) * dt / tau_theta
- return weights, theta, dWeights, post_term
-
- @transition(output_compartments=["inputs", "outputs", "pre", "post", "dWeights", "post_term"])
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- inputs = preVals
- outputs = postVals
- pre = preVals
- post = postVals
- dWeights = jnp.zeros(shape)
- post_term = postVals
- return inputs, outputs, pre, post, dWeights, post_term
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- jnp.savez(file_name,
- weights=self.weights.value, theta=self.theta.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.weights.set(data['weights'])
- self.theta.set(data['theta'])
+ dtheta = jnp.mean(jnp.square(self.post.get()), axis=0, keepdims=True) ## batch avg
+ theta = self.theta.get() + (-self.theta.get() + dtheta) * dt / self.tau_theta
+
+ #self.weights.set(weights)
+ self.theta.set(theta)
+ self.dWeights.set(dWeights)
+ self.post_term.set(post_term)
+
+ @compilable
+ def reset(self):
+ preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
+ postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
+
+ if not self.inputs.targeted:
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.pre.set(preVals)
+ self.post.set(postVals)
+ self.dWeights.set(jnp.zeros(self.shape.get()))
+ self.post_term.set(postVals)
+
+ # def save(self, directory, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # jnp.savez(file_name,
+ # weights=self.weights.value, theta=self.theta.value)
+ #
+ # def load(self, directory, **kwargs):
+ # file_name = directory + "/" + self.name + ".npz"
+ # data = jnp.load(file_name)
+ # self.weights.set(data['weights'])
+ # self.theta.set(data['theta'])
@classmethod
def help(cls): ## component help function
@@ -175,21 +174,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
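The BCM `evolve()` above computes a sliding-threshold Hebbian update. A compact plain-JAX restatement of the same arithmetic follows (the `w_decay` term and the `theta0` offset are omitted for brevity; values are toy examples).

```python
# Plain-JAX sketch of the BCM update computed in evolve() above:
# dW ~ pre^T [post * (post - theta) / theta], with theta a slow average of post^2.
from jax import numpy as jnp

dt, tau_w, tau_theta, eps = 1.0, 100.0, 500.0, 1e-7
pre = jnp.array([[1., 0., 1.]])
post = jnp.array([[0.5, 0.2]])
theta = jnp.ones((1, 2)) * 0.3
W = jnp.zeros((3, 2))

post_term = post * (post - theta) * (1. / (theta + eps))
dW = jnp.matmul(pre.T, post_term)
W = W + dW * dt / tau_w                                     # leaky weight ODE (no decay term)
dtheta = jnp.mean(jnp.square(post), axis=0, keepdims=True)  # batch-averaged post^2
theta = theta + (-theta + dtheta) * dt / tau_theta          # sliding modification threshold
print(W, theta)
```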
diff --git a/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py b/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py
index fde8758a..826b9ff9 100755
--- a/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py
+++ b/ngclearn/components/synapses/hebbian/eventSTDPSynapse.py
@@ -1,10 +1,7 @@
-from jax import numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
-from ngclearn.components.synapses import DenseSynapse
-from ngclearn.utils import tensorstats
+from jax import random, numpy as jnp, jit
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+from ngclearn.components.synapses.denseSynapse import DenseSynapse
class EventSTDPSynapse(DenseSynapse): # event-driven, post-synaptic STDP
"""
@@ -56,12 +53,11 @@ class EventSTDPSynapse(DenseSynapse): # event-driven, post-synaptic STDP
this to < 1. will result in a sparser synaptic structure
"""
- # Define Functions
- def __init__(self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1.,
- presyn_win_len=2., w_bound=1., weight_init=None, resist_scale=1.,
- p_conn=1., batch_size=1, **kwargs):
- super().__init__(name, shape, weight_init, None, resist_scale, p_conn,
- batch_size=batch_size, **kwargs)
+ def __init__(
+ self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1., presyn_win_len=2., w_bound=1.,
+ weight_init=None, resist_scale=1., p_conn=1., batch_size=1, **kwargs
+ ):
+ super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs)
## Synaptic hyper-parameters
self.eta = eta ## global learning rate governing plasticity
@@ -78,53 +74,47 @@ def __init__(self, name, shape, eta, lmbda=0.01, A_plus=1., A_minus=1.,
postVals = jnp.zeros((self.batch_size, shape[1]))
self.pre_tols = Compartment(preVals)
self.postSpike = Compartment(postVals)
- self.dWeights = Compartment(self.weights.value * 0)
+ self.dWeights = Compartment(self.weights.get() * 0)
self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate governing plasticity
- @staticmethod
- def _compute_update(
- t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols, postSpike, weights
- ): ## synaptic adjustment calculation co-routine
+ def _compute_update(self, t, dt): ## synaptic adjustment calculation co-routine
## check if a spike occurred in window of (t - presyn_win_len, t]
- m = (pre_tols > 0.) * 1. ## ignore default value of tols = 0 ms
- if presyn_win_len > 0.:
- lbound = ((t - presyn_win_len) < pre_tols) * 1.
+ m = (self.pre_tols.get() > 0.) * 1. ## ignore default value of tols = 0 ms
+ if self.presyn_win_len > 0.:
+ lbound = ((t - self.presyn_win_len) < self.pre_tols.get()) * 1.
preSpike = lbound * m
else:
- check_spike = (pre_tols == t) * 1.
+ check_spike = (self.pre_tols.get() == t) * 1.
preSpike = check_spike * m
## this implements a generalization of the rule in eqn 18 of the paper
- pos_shift = w_bound - (weights * (1. + lmbda))
- pos_shift = pos_shift * Aplus
- neg_shift = -weights * (1. + lmbda)
- neg_shift = neg_shift * Aminus
+ pos_shift = self.w_bound - (self.weights.get() * (1. + self.lmbda))
+ pos_shift = pos_shift * self.Aplus
+ neg_shift = -self.weights.get() * (1. + self.lmbda)
+ neg_shift = neg_shift * self.Aminus
dW = jnp.where(preSpike.T, pos_shift, neg_shift) # at pre-spikes => LTP, else decay
- dW = (dW * postSpike) ## gate to make sure only post-spikes trigger updates
+ dW = (dW * self.postSpike.get()) ## gate to make sure only post-spikes trigger updates
return dW
- @transition(output_compartments=["weights", "dWeights"])
- @staticmethod
- def evolve(
- t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols, postSpike, weights, eta
- ):
- dWeights = EventSTDPSynapse._compute_update(
- t, lmbda, presyn_win_len, Aminus, Aplus, w_bound, pre_tols, postSpike, weights
- )
- weights = weights + dWeights * eta # * (1. - w) * eta
- weights = jnp.clip(weights, 0.01, w_bound) ## Note: this step not in source paper
- return weights, dWeights
-
- @transition(output_compartments=["inputs", "outputs", "pre_tols", "postSpike", "dWeights"])
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- inputs = preVals
- outputs = postVals
- pre_tols = preVals ## pre-synaptic time-of-last-spike(s) record
- postSpike = postVals
- dWeights = jnp.zeros(shape)
- return inputs, outputs, pre_tols, postSpike, dWeights
+ @compilable
+ def evolve(self, t, dt):
+ dWeights = self._compute_update(t, dt)
+ weights = self.weights.get() + dWeights * self.eta # * (1. - w) * eta
+ weights = jnp.clip(weights, 0.01, self.w_bound) ## Note: this step not in source paper
+
+ self.weights.set(weights)
+ self.dWeights.set(dWeights)
+
+ @compilable
+ def reset(self):
+ preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
+ postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
+
+ if not self.inputs.targeted:
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.pre_tols.set(preVals) ## pre-synaptic time-of-last-spike(s) record
+ self.postSpike.set(postVals)
+ self.dWeights.set(jnp.zeros(self.shape.get()))
@classmethod
def help(cls): ## component help function
@@ -166,20 +156,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
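The event-driven `_compute_update()` above gates an LTP-versus-decay shift by whether each pre-synaptic unit spiked within a recent window, and only applies it where a post-synaptic spike occurred. The sketch below reproduces that logic with toy values.

```python
# Plain-JAX sketch of the event-driven update in _compute_update() above:
# pre-synaptic times-of-last-spike within (t - win, t] select LTP vs. decay,
# gated by the post-synaptic spikes.
from jax import numpy as jnp

t, presyn_win_len, lmbda, A_plus, A_minus, w_bound = 10.0, 2.0, 0.01, 1.0, 1.0, 1.0
pre_tols = jnp.array([[9.5, 0.0, 3.0]])   # 0 ms means "never spiked"
postSpike = jnp.array([[1., 0.]])
W = jnp.ones((3, 2)) * 0.5

m = (pre_tols > 0.) * 1.                                  # ignore default tols of 0 ms
preSpike = (((t - presyn_win_len) < pre_tols) * 1.) * m   # spiked within the window?
pos_shift = (w_bound - W * (1. + lmbda)) * A_plus         # LTP magnitude
neg_shift = (-W * (1. + lmbda)) * A_minus                 # decay magnitude
dW = jnp.where(preSpike.T, pos_shift, neg_shift)          # at pre-spikes => LTP, else decay
dW = dW * postSpike                                       # only post-spikes trigger updates
print(dW)
```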
diff --git a/ngclearn/components/synapses/hebbian/expSTDPSynapse.py b/ngclearn/components/synapses/hebbian/expSTDPSynapse.py
index ff184b9c..bb481512 100644
--- a/ngclearn/components/synapses/hebbian/expSTDPSynapse.py
+++ b/ngclearn/components/synapses/hebbian/expSTDPSynapse.py
@@ -1,10 +1,7 @@
from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
-from ngclearn.components.synapses import DenseSynapse
-from ngclearn.utils import tensorstats
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+from ngclearn.components.synapses.denseSynapse import DenseSynapse
class ExpSTDPSynapse(DenseSynapse):
"""
@@ -61,16 +58,19 @@ class ExpSTDPSynapse(DenseSynapse):
this to < 1. will result in a sparser synaptic structure
w_bound: maximum value/magnitude any synaptic efficacy can be (default: 1)
+
+ tau_w: synaptic weight decay coefficient to apply to STDP update
+
+ weight_mask: synaptic binary masking matrix to apply (to enforce a constant sparse structure; default: None)
"""
- # Define Functions
def __init__(
self, name, shape, A_plus, A_minus, exp_beta, eta=1., pretrace_target=0., weight_init=None, resist_scale=1.,
- p_conn=1., w_bound=1., batch_size=1, **kwargs
+ p_conn=1., w_bound=1., tau_w=0., weight_mask=None, batch_size=1, **kwargs
):
- super().__init__(name, shape, weight_init, None, resist_scale,
- p_conn, batch_size=batch_size, **kwargs)
+ super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs)
+ self.tau_w = tau_w
## Exp-STDP meta-parameters
self.shape = shape ## shape of synaptic efficacy matrix
self.eta = eta ## global learning rate governing plasticity
@@ -81,6 +81,12 @@ def __init__(
self.Rscale = resist_scale ## post-transformation scale factor
self.w_bound = w_bound #1. ## soft weight constraint
+ if weight_mask is None:
+ self.weight_mask = jnp.ones((1, 1))
+ else:
+ self.weight_mask = weight_mask
+ self.weights.set(self.weights.get() * self.weight_mask)
+
## Compartment setup
preVals = jnp.zeros((self.batch_size, shape[0]))
postVals = jnp.zeros((self.batch_size, shape[1]))
@@ -88,64 +94,61 @@ def __init__(
self.postSpike = Compartment(postVals)
self.preTrace = Compartment(preVals)
self.postTrace = Compartment(postVals)
- self.dWeights = Compartment(self.weights.value * 0)
+ self.dWeights = Compartment(self.weights.get() * 0)
self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate governing plasticity
- @staticmethod
- def _compute_update(
- dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
- ):
- pre = preSpike
- x_pre = preTrace
- post = postSpike
- x_post = postTrace
- W = weights
- x_tar = preTrace_target
+ def _compute_update(self): # dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
+ pre = self.preSpike.get()
+ x_pre = self.preTrace.get()
+ post = self.postSpike.get()
+ x_post = self.postTrace.get()
+ W = self.weights.get()
+ x_tar = self.preTrace_target
## equations 4 from Diehl and Cook - full exponential weight-dependent STDP
## calculate post-synaptic term
- post_term1 = jnp.exp(-exp_beta * W) * jnp.matmul(x_pre.T, post)
+ post_term1 = jnp.exp(-self.exp_beta * W) * jnp.matmul(x_pre.T, post)
x_tar_vec = x_pre * 0 + x_tar # need to broadcast scalar x_tar to mat/vec form
- post_term2 = jnp.exp(-exp_beta * (w_bound - W)) * jnp.matmul(x_tar_vec.T,
- post)
- dWpost = (post_term1 - post_term2) * Aplus
+ post_term2 = jnp.exp(-self.exp_beta * (self.w_bound - W)) * jnp.matmul(x_tar_vec.T, post)
+ dWpost = (post_term1 - post_term2) * self.Aplus
## calculate pre-synaptic term
dWpre = 0.
- if Aminus > 0.:
- dWpre = -jnp.exp(-exp_beta * W) * jnp.matmul(pre.T, x_post) * Aminus
+ if self.Aminus > 0.:
+ dWpre = -jnp.exp(-self.exp_beta * W) * jnp.matmul(pre.T, x_post) * self.Aminus
## calc final weighted adjustment
dW = (dWpost + dWpre)
return dW
- @transition(output_compartments=["weights", "dWeights"])
- @staticmethod
- def evolve(
- dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace,
- weights, eta
- ):
- dW = ExpSTDPSynapse._compute_update(
- dt, w_bound, preTrace_target, exp_beta, Aplus, Aminus,
- preSpike, postSpike, preTrace, postTrace, weights
- )
+ @compilable
+ def evolve(self):
+ dWeights = self._compute_update()
+ if self.tau_w > 0.:
+ decayTerm = self.weights.get() / self.tau_w
+ else:
+ decayTerm = 0.
+
## do a gradient ascent update/shift
- _W = weights + dW * eta
+ _W = self.weights.get() + (dWeights * self.eta) #- decayTerm
## enforce non-negativity
eps = 0.01
- _W = jnp.clip(_W, eps, w_bound - eps)
- return weights, dW
-
- @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights"])
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- inputs = preVals
- outputs = postVals
- preSpike = preVals
- postSpike = postVals
- preTrace = preVals
- postTrace = postVals
- dWeights = jnp.zeros(shape)
- return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights
+ _W = jnp.clip(_W, eps, self.w_bound - eps)
+ _W = jnp.where(self.weight_mask != 0., _W, 0.)
+
+ self.weights.set(_W)
+ self.dWeights.set(dWeights)
+
+ @compilable
+ def reset(self):
+ preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
+ postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
+
+ if not self.inputs.targeted:
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.preSpike.set(preVals)
+ self.postSpike.set(postVals)
+ self.preTrace.set(preVals)
+ self.postTrace.set(postVals)
+ self.dWeights.set(jnp.zeros(self.shape.get()))
@classmethod
def help(cls): ## component help function
@@ -183,6 +186,7 @@ def help(cls): ## component help function
"exp_beta": "Controls effect of exponential Hebbian shift / dependency (B)",
"eta": "Global learning rate initial condition",
"pretrace_target": "Pre-synaptic disconnecting/decay factor (x_tar)",
+ "weight_mask" : "Binary synaptic weight mask to apply to enforce a sparsity structure"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
@@ -192,20 +196,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
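The exponential weight-dependent STDP terms in `_compute_update()` above (the Diehl-and-Cook-style equation 4) can be stated compactly outside the component; the sketch below does so with toy spike/trace values.

```python
# Plain-JAX sketch of the exponential weight-dependent STDP terms computed in
# _compute_update() above.
from jax import numpy as jnp

exp_beta, w_bound, A_plus, A_minus, x_tar = 1.0, 1.0, 1.0, 1.0, 0.3
pre = jnp.array([[1., 0.]])        # pre-synaptic spikes
x_pre = jnp.array([[0.6, 0.1]])    # pre-synaptic traces
post = jnp.array([[0., 1.]])       # post-synaptic spikes
x_post = jnp.array([[0.2, 0.7]])   # post-synaptic traces
W = jnp.ones((2, 2)) * 0.5

post_term1 = jnp.exp(-exp_beta * W) * jnp.matmul(x_pre.T, post)
x_tar_vec = x_pre * 0 + x_tar      # broadcast the scalar target trace
post_term2 = jnp.exp(-exp_beta * (w_bound - W)) * jnp.matmul(x_tar_vec.T, post)
dWpost = (post_term1 - post_term2) * A_plus
dWpre = -jnp.exp(-exp_beta * W) * jnp.matmul(pre.T, x_post) * A_minus
dW = dWpost + dWpre
print(dW)
```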
diff --git a/ngclearn/components/synapses/hebbian/hebbianSynapse.py b/ngclearn/components/synapses/hebbian/hebbianSynapse.py
index faaee5c9..f0814443 100644
--- a/ngclearn/components/synapses/hebbian/hebbianSynapse.py
+++ b/ngclearn/components/synapses/hebbian/hebbianSynapse.py
@@ -1,16 +1,23 @@
+# %%
+
+import jax
+import pickle
from jax import random, numpy as jnp, jit
from functools import partial
from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn
-from ngclearn import resolver, Component, Compartment
-from ngcsimlib.compilers.process import transition
+
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
from ngclearn.components.synapses import DenseSynapse
from ngclearn.utils import tensorstats
-from ngcsimlib.deprecators import deprecate_args
+from ngcsimlib import deprecate_args
+from ngclearn.utils.io_utils import save_pkl, load_pkl
@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
-def _calc_update(pre, post, W, w_bound, is_nonnegative=True, signVal=1.,
- prior_type=None, prior_lmbda=0.,
- pre_wght=1., post_wght=1.):
+def _calc_update(
+ pre, post, W, w_bound, is_nonnegative=True, signVal=1., prior_type=None, prior_lmbda=0., pre_wght=1.,
+ post_wght=1.
+):
"""
Compute a tensor of adjustments to be applied to a synaptic value matrix.
@@ -160,15 +167,13 @@ class HebbianSynapse(DenseSynapse):
this to < 1. will result in a sparser synaptic structure
"""
- # Define Functions
@deprecate_args(_rebind=False, w_decay='prior')
def __init__(
self, name, shape, eta=0., weight_init=None, bias_init=None, w_bound=1., is_nonnegative=False,
- prior=("constant", 0.), w_decay=0., sign_value=1., optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
- resist_scale=1., batch_size=1, **kwargs
+ prior=("constant", 0.), w_decay=0., sign_value=1., optim_type="sgd", pre_wght=1., post_wght=1.,
+ p_conn=1., resist_scale=1., batch_size=1, **kwargs
):
- super().__init__(name, shape, weight_init, bias_init, resist_scale,
- p_conn, batch_size=batch_size, **kwargs)
+ super().__init__(name, shape, weight_init, bias_init, resist_scale, p_conn, batch_size=batch_size, **kwargs)
if w_decay > 0.:
prior = ('l2', w_decay)
@@ -204,13 +209,26 @@ def __init__(
self.dBiases = Compartment(jnp.zeros(shape[1]))
#key, subkey = random.split(self.key.value)
- self.opt_params = Compartment(get_opt_init_fn(optim_type)(
- [self.weights.value, self.biases.value]
- if bias_init else [self.weights.value]))
+ # NOTE: we don't save this compartment directly because it is a tuple and cannot be saved directly by numpy
+ self.opt_params = Compartment(
+ get_opt_init_fn(optim_type)([self.weights.get(), self.biases.get()] if bias_init else [self.weights.get()]),
+ auto_save=False
+ )
+
+ def save(self, directory: str):
+ super().save(directory)
+ # Also save the optimizer parameters
+ save_pkl(directory, self.name + "_opt_params", self.opt_params.get())
+
+ def load(self, directory: str):
+ super().load(directory)
+ # load the optimizer parameters in a custom way
+ self.opt_params.set(load_pkl(directory, self.name + "_opt_params"))
@staticmethod
- def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
- post_wght, pre, post, weights):
+ def _compute_update(
+ w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght, pre, post, weights
+ ):
## calculate synaptic update values
dW, db = _calc_update(
pre, post, weights, w_bound, is_nonnegative=is_nonnegative,
@@ -218,38 +236,65 @@ def _compute_update(w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda
post_wght=post_wght)
return dW, db
- @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"])
- @staticmethod
- def evolve(opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
- post_wght, bias_init, pre, post, weights, biases, opt_params):
+ @compilable
+ def calc_update(self):
+ # Get the variables
+ pre = self.pre.get()
+ post = self.post.get()
+ weights = self.weights.get()
+ biases = self.biases.get()
+ opt_params = self.opt_params.get()
+
## calculate synaptic update values
dWeights, dBiases = HebbianSynapse._compute_update(
- w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght, post_wght,
+ self.w_bound, self.is_nonnegative, self.sign_value, self.prior_type, self.prior_lmbda, self.pre_wght, self.post_wght,
+ pre, post, weights
+ )
+
+ self.dWeights.set(dWeights)
+ self.dBiases.set(dBiases)
+
+ @compilable
+ def evolve(self):
+ # Get the variables
+ pre = self.pre.get()
+ post = self.post.get()
+ weights = self.weights.get()
+ biases = self.biases.get()
+ opt_params = self.opt_params.get()
+
+ ## calculate synaptic update values
+ dWeights, dBiases = HebbianSynapse._compute_update(
+ self.w_bound, self.is_nonnegative, self.sign_value, self.prior_type, self.prior_lmbda, self.pre_wght, self.post_wght,
pre, post, weights
)
## conduct a step of optimization - get newly evolved synaptic weight value matrix
- if bias_init != None:
- opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases])
+ if self.bias_init != None:
+ opt_params, [weights, biases] = self.opt(opt_params, [weights, biases], [dWeights, dBiases])
else:
# ignore db since no biases configured
- opt_params, [weights] = opt(opt_params, [weights], [dWeights])
+ opt_params, [weights] = self.opt(opt_params, [weights], [dWeights])
## ensure synaptic efficacies adhere to constraints
- weights = _enforce_constraints(weights, w_bound, is_nonnegative=is_nonnegative)
- return opt_params, weights, biases, dWeights, dBiases
-
- @transition(output_compartments=["inputs", "outputs", "pre", "post", "dWeights", "dBiases"])
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- return (
- preVals, # inputs
- postVals, # outputs
- preVals, # pre
- postVals, # post
- jnp.zeros(shape), # dW
- jnp.zeros(shape[1]), # db
- )
+ weights = _enforce_constraints(weights, self.w_bound, is_nonnegative=self.is_nonnegative)
+
+ # Update compartments
+ self.opt_params.set(opt_params)
+ self.weights.set(weights)
+ self.biases.set(biases)
+ self.dWeights.set(dWeights)
+ self.dBiases.set(dBiases)
+
+ @compilable
+ def reset(self): #, batch_size, shape):
+ preVals = jnp.zeros((self.batch_size, self.shape[0]))
+ postVals = jnp.zeros((self.batch_size, self.shape[1]))
+ if not self.inputs.targeted:
+ self.inputs.set(preVals)
+ self.outputs.set(postVals) # outputs
+ self.pre.set(preVals) # pre
+ self.post.set(postVals) # post
+ self.dWeights.set(jnp.zeros(self.shape)) # dW
+ self.dBiases.set(jnp.zeros(self.shape[1])) # db
@classmethod
def help(cls): ## component help function
@@ -296,23 +341,10 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
Wab = HebbianSynapse("Wab", (2, 3), 0.0004, optim_type='adam',
sign_value=-1.0, prior=("l1l2", 0.001))
print(Wab)
+ print(Wab.opt_params.get())
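The new `save()`/`load()` overrides above sidestep `jnp.savez` for the optimizer state because `opt_params` is a nested tuple of arrays (a pytree). The sketch below shows the underlying idea with the standard-library `pickle` module and a throwaway path; it does not use ngclearn's `save_pkl`/`load_pkl` helpers, whose exact signatures are only implied by the diff.

```python
# Minimal sketch of pickling a pytree of optimizer state (why pickle is used
# instead of jnp.savez for opt_params). Plain stdlib + JAX, illustrative only.
import pickle
from jax import numpy as jnp

opt_params = (jnp.zeros(1), [jnp.zeros((2, 3)), jnp.zeros((2, 3))])  # (step, moments)
with open("/tmp/Wab_opt_params.pkl", "wb") as f:
    pickle.dump(opt_params, f)
with open("/tmp/Wab_opt_params.pkl", "rb") as f:
    restored = pickle.load(f)
print(type(restored), restored[0])
```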
diff --git a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py
index 777c26cc..1c7ac3ab 100755
--- a/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py
+++ b/ngclearn/components/synapses/hebbian/traceSTDPSynapse.py
@@ -1,10 +1,7 @@
from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-
-from ngclearn.components.synapses import DenseSynapse
-from ngclearn.utils import tensorstats
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+from ngclearn.components.synapses.denseSynapse import DenseSynapse
class TraceSTDPSynapse(DenseSynapse): # power-law / trace-based STDP
@@ -57,36 +54,37 @@ class TraceSTDPSynapse(DenseSynapse): # power-law / trace-based STDP
initialization to use
resist_scale: a fixed scaling factor to apply to synaptic transform
- (Default: 1.), i.e., yields: out = ((W * Rscale) * in)
+ (Default: 1.), i.e., yields: out = ((W * resist_scale) * in)
p_conn: probability of a connection existing (default: 1); setting
this to < 1. will result in a sparser synaptic structure
w_bound: maximum value/magnitude any synaptic efficacy can be (default: 1)
+
+ tau_w: synaptic weight decay coefficient to apply to STDP update
+
+ weight_mask: synaptic binary masking matrix to apply (to enforce a constant sparse structure; default: None)
"""
- # Define Functions
def __init__(
self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., weight_init=None, resist_scale=1.,
p_conn=1., w_bound=1., tau_w=0., weight_mask=None, batch_size=1, **kwargs
):
- super().__init__(name, shape, weight_init, None, resist_scale,
- p_conn, batch_size=batch_size, **kwargs)
+ super().__init__(name, shape, weight_init, None, resist_scale, p_conn, batch_size=batch_size, **kwargs)
- ## Synaptic hyper-parameters
- self.shape = shape ## shape of synaptic efficacy matrix
self.tau_w = tau_w
self.mu = mu ## controls power-scaling of STDP rule
self.preTrace_target = pretrace_target ## target (pre-synaptic) trace activity value # 0.7
self.Aplus = A_plus ## LTP strength
self.Aminus = A_minus ## LTD strength
- self.Rscale = resist_scale ## post-transformation scale factor
self.w_bound = w_bound #1. ## soft weight constraint
self.w_eps = 0. ## w_eps = 0.01
- self.weight_mask = weight_mask
- if self.weight_mask is None:
+
+ if weight_mask is None:
self.weight_mask = jnp.ones((1, 1))
- self.weights.set(self.weights.value * self.weight_mask)
+ else:
+ self.weight_mask = weight_mask
+ self.weights.set(self.weights.get() * self.weight_mask)
## Compartment setup
preVals = jnp.zeros((self.batch_size, shape[0]))
@@ -95,80 +93,59 @@ def __init__(
self.postSpike = Compartment(postVals)
self.preTrace = Compartment(preVals)
self.postTrace = Compartment(postVals)
- self.dWeights = Compartment(self.weights.value * 0)
- self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate
-
- #@transition(output_compartments=["outputs"])
- #@staticmethod
- #def advance_state(Rscale, inputs, weights, biases, weight_mask):
- # outputs = (jnp.matmul(inputs, weights * weight_mask) * Rscale) + biases
- # return outputs
-
- @staticmethod
- def _compute_update(
- dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
- ):
- pre = preSpike
- x_pre = preTrace
- post = postSpike
- x_post = postTrace
- W = weights
- x_tar = preTrace_target
- if mu > 0.:
- ## equations 3, 5, & 6 from Diehl and Cook - full power-law STDP
- post_shift = jnp.power(w_bound - W, mu)
- pre_shift = jnp.power(W, mu)
- dWpost = (post_shift * jnp.matmul((x_pre - x_tar).T, post)) * Aplus
- dWpre = 0.
- if Aminus > 0.:
- dWpre = -(pre_shift * jnp.matmul(pre.T, x_post)) * Aminus
+ self.dWeights = Compartment(self.weights.get() * 0)
+ self.eta = eta ## global learning rate
+
+ def _compute_update(self):
+ if self.mu > 0.:
+ post_shift = jnp.power(self.w_bound - self.weights.get(), self.mu)
+ pre_shift = jnp.power(self.weights.get(), self.mu)
+ dWpost = (post_shift * jnp.matmul((self.preTrace.get() - self.preTrace_target).T, self.postSpike.get())) * self.Aplus
+
+ if self.Aminus > 0.:
+ dWpre = -(pre_shift * jnp.matmul(self.preSpike.get().T, self.postTrace.get())) * self.Aminus
+ else:
+ dWpre = 0.
+
else:
- ## calculate post-synaptic term
- dWpost = jnp.matmul((x_pre - x_tar).T, post * Aplus)
-
- dWpre = 0.
- if Aminus > 0.:
- ## calculate pre-synaptic term
- dWpre = -jnp.matmul(pre.T, x_post * Aminus)
- ## calc final weighted adjustment
+ dWpost = jnp.matmul((self.preTrace.get() - self.preTrace_target).T, self.postSpike.get() * self.Aplus)
+ if self.Aminus > 0.:
+ dWpre = -jnp.matmul(self.preSpike.get().T, self.postTrace.get() * self.Aminus)
+ else:
+ dWpre = 0.
+
dW = (dWpost + dWpre)
return dW
- @transition(output_compartments=["weights", "dWeights"])
- @staticmethod
- def evolve(
- dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, tau_w, preSpike, postSpike, preTrace,
- postTrace, weights, eta, weight_mask
- ):
- #_wm = weight_mask #
- _wm = (weight_mask != 0.)
- dWeights = TraceSTDPSynapse._compute_update(
- dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
- )
- ## do a gradient ascent update/shift
- decayTerm = 0.
- if tau_w > 0.:
- decayTerm = weights / tau_w
- weights = weights + (dWeights * eta) - decayTerm #weight_mask * eta)
- ## enforce non-negativity
- #w_eps = 0. # 0.01 # 0.001
- weights = jnp.clip(weights, w_eps, w_bound - w_eps) # jnp.abs(w_bound))
- weights = weights * _wm # weight_mask
- return weights, dWeights
-
- @transition(output_compartments=["inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights"])
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- inputs = preVals
- outputs = postVals
- preSpike = preVals
- postSpike = postVals
- preTrace = preVals
- postTrace = postVals
- dWeights = jnp.zeros(shape)
- return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights
+ @compilable
+ def evolve(self):
+ dWeights = self._compute_update()
+ if self.tau_w > 0.:
+ decayTerm = self.weights.get() / self.tau_w
+ else:
+ decayTerm = 0.
+
+ # print(jnp.nonzero(dWeights))
+ w = self.weights.get() + (dWeights * self.eta) - decayTerm
+ w = jnp.clip(w, self.w_eps, self.w_bound - self.w_eps)
+ w = jnp.where(self.weight_mask != 0., w, 0.)
+ self.weights.set(w)
+ self.dWeights.set(dWeights)
+
+ @compilable
+ def reset(self):
+ preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
+ postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
+
+ if not self.inputs.targeted:
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.preSpike.set(preVals)
+ self.postSpike.set(postVals)
+ self.preTrace.set(preVals)
+ self.postTrace.set(postVals)
+ self.dWeights.set(jnp.zeros(self.shape.get()))
+
@classmethod
def help(cls): ## component help function
@@ -206,6 +183,7 @@ def help(cls): ## component help function
"eta": "Global learning rate initial condition",
"mu": "Power factor for STDP adjustment",
"pretrace_target": "Pre-synaptic disconnecting/decay factor (x_tar)",
+ "weight_mask" : "Binary synaptic weight mask to apply to enforce a sparsity structure"
}
info = {cls.__name__: properties,
"compartments": compartment_props,
@@ -214,19 +192,6 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
if __name__ == '__main__':
from ngcsimlib.context import Context
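The trace-based STDP path above (the `mu = 0` branch of `_compute_update()`), followed by the decay, clipping, and masking steps of `evolve()`, reduces to a few matrix operations; the sketch below restates them with toy values.

```python
# Plain-JAX sketch of trace-based STDP for mu = 0 (no power-law dependence),
# including the weight-decay, clipping, and binary-mask steps of evolve() above.
from jax import numpy as jnp

A_plus, A_minus, x_tar, eta, tau_w, w_eps, w_bound = 1.0, 1.0, 0.3, 0.01, 100.0, 0.0, 1.0
pre = jnp.array([[1., 0.]])
x_pre = jnp.array([[0.6, 0.1]])
post = jnp.array([[0., 1.]])
x_post = jnp.array([[0.2, 0.7]])
W = jnp.ones((2, 2)) * 0.5
weight_mask = jnp.array([[1., 1.], [1., 0.]])   # fixed sparsity pattern

dWpost = jnp.matmul((x_pre - x_tar).T, post * A_plus)
dWpre = -jnp.matmul(pre.T, x_post * A_minus)
dW = dWpost + dWpre
W = W + dW * eta - W / tau_w                    # STDP ascent with leaky decay
W = jnp.clip(W, w_eps, w_bound - w_eps)
W = jnp.where(weight_mask != 0., W, 0.)         # re-apply the binary mask
print(W)
```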
diff --git a/ngclearn/components/synapses/modulated/MSTDPETSynapse.py b/ngclearn/components/synapses/modulated/MSTDPETSynapse.py
index 6e5dd8c4..150ebc9b 100755
--- a/ngclearn/components/synapses/modulated/MSTDPETSynapse.py
+++ b/ngclearn/components/synapses/modulated/MSTDPETSynapse.py
@@ -1,12 +1,8 @@
from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
from ngclearn.components.synapses.hebbian import TraceSTDPSynapse
-from ngclearn.utils import tensorstats
class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligility traces
"""
@@ -72,78 +68,69 @@ class MSTDPETSynapse(TraceSTDPSynapse): # modulated trace-based STDP w/ eligilit
p_conn: probability of a connection existing (default: 1.); setting
this to < 1. will result in a sparser synaptic structure
+
+ w_bound: maximum value/magnitude any synaptic efficacy can be (default: 1)
"""
- # Define Functions
def __init__(
- self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., tau_elg=0., elg_decay=1., tau_w=0.,
- weight_init=None, resist_scale=1., p_conn=1., w_bound=1., batch_size=1, **kwargs
+ self, name, shape, A_plus, A_minus, eta=1., mu=0., pretrace_target=0., tau_elg=0., elg_decay=1.,
+ tau_w=0., weight_init=None, resist_scale=1., p_conn=1., w_bound=1., batch_size=1, **kwargs
):
- super().__init__(
+ super().__init__( # call to parent trace-stdp component
name, shape, A_plus, A_minus, eta=eta, mu=mu, pretrace_target=pretrace_target, weight_init=weight_init,
resist_scale=resist_scale, p_conn=p_conn, w_bound=w_bound, batch_size=batch_size, **kwargs
)
self.w_eps = 0.
self.tau_w = tau_w
## MSTDP/MSTDP-ET meta-parameters
- self.tau_elg = tau_elg
- self.elg_decay = elg_decay
+ self.tau_elg = tau_elg ## time constant for eligibility trace
+ self.elg_decay = elg_decay ## decay factor for the eligibility trace
## MSTDP/MSTDP-ET compartments
self.modulator = Compartment(jnp.zeros((self.batch_size, 1)))
self.eligibility = Compartment(jnp.zeros(shape))
self.outmask = Compartment(jnp.zeros((1, shape[1])))
- @transition(output_compartments=["weights", "dWeights", "eligibility"])
- @staticmethod
- def evolve(
- dt, w_bound, w_eps, preTrace_target, mu, Aplus, Aminus, tau_elg, elg_decay, tau_w, preSpike, postSpike,
- preTrace, postTrace, weights, dWeights, eta, modulator, eligibility, outmask
- ):
- # dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update
- # dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
- # )
+ @compilable
+ def evolve(self, t, dt):
+ # dW_dt = self._compute_update()
# dWeights = dW_dt ## can think of this as eligibility at time t
- if tau_elg > 0.: ## perform dynamics of M-STDP-ET
- eligibility = eligibility * jnp.exp(-dt / tau_elg) * elg_decay + dWeights/tau_elg
+ if self.tau_elg > 0.: ## perform dynamics of M-STDP-ET
+ eligibility = self.eligibility.get() * jnp.exp(-dt / self.tau_elg) * self.elg_decay + self.dWeights.get()/self.tau_elg
else: ## otherwise, just do M-STDP
- eligibility = dWeights ## dynamics of M-STDP had no eligibility tracing
+ eligibility = self.dWeights.get() ## dynamics of M-STDP have no eligibility tracing
## do a gradient ascent update/shift
decayTerm = 0.
- if tau_w > 0.:
- decayTerm = weights * (1. / tau_w)
- weights = weights + (eligibility * modulator * eta) * outmask - decayTerm ## do modulated update
+ if self.tau_w > 0.:
+ decayTerm = self.weights.get() * (1. / self.tau_w)
+ ## do modulated update
+ weights = self.weights.get() + (eligibility * self.modulator.get() * self.eta) * self.outmask.get() - decayTerm
- dW_dt = TraceSTDPSynapse._compute_update( ## use Hebbian/STDP rule to obtain a non-modulated update
- dt, w_bound, preTrace_target, mu, Aplus, Aminus, preSpike, postSpike, preTrace, postTrace, weights
- )
+ dW_dt = self._compute_update() ## apply a Hebbian/STDP rule to obtain a non-modulated update
dWeights = dW_dt ## can think of this as eligibility at time t
#w_eps = 0.01
- weights = jnp.clip(weights, w_eps, w_bound - w_eps) # jnp.abs(w_bound))
-
- return weights, dWeights, eligibility
-
- @transition(
- output_compartments=[
- "inputs", "outputs", "preSpike", "postSpike", "preTrace", "postTrace", "dWeights", "eligibility", "outmask"
- ]
- )
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- synVals = jnp.zeros(shape)
- inputs = preVals
- outputs = postVals
- preSpike = preVals
- postSpike = postVals
- preTrace = preVals
- postTrace = postVals
- dWeights = synVals
- eligibility = synVals
- outmask = postVals + 1.
- return inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights, eligibility, outmask
+ weights = jnp.clip(weights, self.w_eps, self.w_bound - self.w_eps) # jnp.abs(w_bound))
+ self.weights.set(weights)
+ self.dWeights.set(dWeights)
+ self.eligibility.set(eligibility)
+
+ @compilable
+ def reset(self):
+ preVals = jnp.zeros((self.batch_size.get(), self.shape.get()[0]))
+ postVals = jnp.zeros((self.batch_size.get(), self.shape.get()[1]))
+ synVals = jnp.zeros(self.shape.get())
+
+ if not self.inputs.targeted:
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.preSpike.set(preVals)
+ self.postSpike.set(postVals)
+ self.preTrace.set(preVals)
+ self.postTrace.set(postVals)
+ self.dWeights.set(synVals)
+ self.eligibility.set(synVals)
+ self.outmask.set(postVals + 1.)
@classmethod
def help(cls): ## component help function
@@ -195,17 +182,3 @@ def help(cls): ## component help function
"dW^{stdp}_{ij}/dt = A_plus * (z_j - x_tar) * s_i - A_minus * s_j * z_i",
"hyperparameters": hyperparams}
return info
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
diff --git a/ngclearn/components/synapses/modulated/REINFORCESynapse.py b/ngclearn/components/synapses/modulated/REINFORCESynapse.py
index 92b72d88..9219e930 100644
--- a/ngclearn/components/synapses/modulated/REINFORCESynapse.py
+++ b/ngclearn/components/synapses/modulated/REINFORCESynapse.py
@@ -1,27 +1,76 @@
+# %%
+
from jax import random, numpy as jnp, jit
-from ngcsimlib.compilers.process import transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
+from ngclearn import compilable, Compartment
+
from ngclearn.utils.model_utils import clip, d_clip
import jax
-import jax.numpy as jnp
-import numpy as np
+#import numpy as np
from ngclearn.components.synapses import DenseSynapse
from ngclearn.utils import tensorstats
from ngclearn.utils.model_utils import create_function
-def gaussian_logpdf(event, mean, stddev):
+def _gaussian_logpdf(event, mean, stddev):
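+    ## computes log N(x; mu, sigma) = -0.5 * [log(2*pi*sigma^2) + (x - mu)^2 / sigma^2], element-wise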
scale_sqrd = stddev ** 2
log_normalizer = jnp.log(2 * jnp.pi * scale_sqrd)
quadratic = (event - mean)**2 / scale_sqrd
return - 0.5 * (log_normalizer + quadratic)
+
+def _compute_update(
+ dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev
+):
+ learning_stddev_mask = jnp.asarray(scalar_stddev <= 0.0, dtype=jnp.float32)
+ # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
+ W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
+ # Forward pass
+ activation = act_fx(inputs)
+ mean = activation @ W_mu
+ fx_mean = mu_act_fx(mean)
+ logstd = activation @ W_logstd
+ clip_logstd = clip(logstd, -10.0, 2.0)
+ std = jnp.exp(clip_logstd)
+ std = learning_stddev_mask * std + (1.0 - learning_stddev_mask) * scalar_stddev # masking trick
+ # Sample using reparameterization trick
+ epsilon = jax.random.normal(seed, fx_mean.shape)
+ sample = epsilon * std + fx_mean
+ sample = jnp.clip(sample, mu_out_min, mu_out_max)
+ outputs = sample # the actual action that we take
+ # Compute log probability density of the Gaussian
+ log_prob = _gaussian_logpdf(sample, fx_mean, std).sum(-1)
+ # Compute objective (negative REINFORCE objective)
+ objective = (-log_prob * rewards).mean() * 1e-2
+
+ # Backward pass
+ batch_size = inputs.shape[0] # B
+ dL_dlogp = -rewards[:, None] * 1e-2 / batch_size # (B, 1)
+
+ # Compute gradients manually based on the derivation
+ # dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * -(sample-mu)/sigma^2
+ dlog_prob_dfxmean = (sample - fx_mean) / (std ** 2)
+ dL_dmean = dL_dlogp * dlog_prob_dfxmean * dmu_act_fx(mean) # (B, A)
+ dL_dWmu = activation.T @ dL_dmean
+
+ # dL/dlog(sigma) = -(r-r_hat) * dlog_prob/dlog(sigma) = -(r-r_hat) * (((sample-mu)/sigma)^2 - 1)
+ dlog_prob_dlogstd = - 1.0 / std + (sample - fx_mean)**2 / std**3
+ dL_dstd = dL_dlogp * dlog_prob_dlogstd
+ # Apply gradient clipping for logstd
+ dL_dlogstd = d_clip(logstd, -10.0, 2.0) * dL_dstd * std
+ dL_dWlogstd = activation.T @ dL_dlogstd # (I, B) @ (B, A) = (I, A)
+ dL_dWlogstd = dL_dWlogstd * learning_stddev_mask # there is no learning for the scalar stddev
+
+ # Update weights, negate the gradient because gradient ascent in ngc-learn
+ dW = jnp.concatenate([-dL_dWmu, -dL_dWlogstd], axis=-1)
+ # Finally, return metrics if needed
+ return dW, objective, outputs
+
+
class REINFORCESynapse(DenseSynapse):
"""
A stochastic synapse implementing the REINFORCE algorithm (policy gradient method). This synapse
uses Gaussian distributions for generating actions and performs gradient-based updates.
-
+
| --- Synapse Compartments: ---
| inputs - input (takes in external signals)
| outputs - output signals (sampled actions from Gaussian distribution)
@@ -73,8 +122,10 @@ def __init__(
) -> None:
# This is because we have weights mu and weight log sigma
input_dim, output_dim = shape
- super().__init__(name, (input_dim, output_dim * 2), weight_init, None, resist_scale,
- p_conn, batch_size=batch_size, **kwargs)
+ super().__init__(
+ name, (input_dim, output_dim * 2), weight_init, None, resist_scale, p_conn,
+ batch_size=batch_size, **kwargs
+ )
## Synaptic hyper-parameters
self.shape = shape ## shape of synaptic efficacy matrix
@@ -89,7 +140,7 @@ def __init__(
self.scalar_stddev = scalar_stddev
## Compartment setup
- self.dWeights = Compartment(self.weights.value * 0)
+ self.dWeights = Compartment(self.weights.get() * 0)
# self.eta = Compartment(jnp.ones((1, 1)) * eta) ## global learning rate # For eligiblity traces later
self.objective = Compartment(jnp.zeros(()))
self.outputs = Compartment(jnp.zeros((batch_size, output_dim)))
@@ -101,83 +152,63 @@ def __init__(
self.learning_mask = Compartment(jnp.zeros(()))
self.seed = Compartment(jax.random.PRNGKey(seed if seed is not None else 42))
- @staticmethod
- def _compute_update(dt, inputs, rewards, act_fx, weights, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev):
- learning_stddev_mask = jnp.asarray(scalar_stddev <= 0.0, dtype=jnp.float32)
- # (input_dim, output_dim * 2) => (input_dim, output_dim), (input_dim, output_dim)
- W_mu, W_logstd = jnp.split(weights, 2, axis=-1)
- # Forward pass
- activation = act_fx(inputs)
- mean = activation @ W_mu
- fx_mean = mu_act_fx(mean)
- logstd = activation @ W_logstd
- clip_logstd = clip(logstd, -10.0, 2.0)
- std = jnp.exp(clip_logstd)
- std = learning_stddev_mask * std + (1.0 - learning_stddev_mask) * scalar_stddev # masking trick
- # Sample using reparameterization trick
- epsilon = jax.random.normal(seed, fx_mean.shape)
- sample = epsilon * std + fx_mean
- sample = jnp.clip(sample, mu_out_min, mu_out_max)
- outputs = sample # the actual action that we take
- # Compute log probability density of the Gaussian
- log_prob = gaussian_logpdf(sample, fx_mean, std).sum(-1)
- # Compute objective (negative REINFORCE objective)
- objective = (-log_prob * rewards).mean() * 1e-2
-
- # Backward pass
- batch_size = inputs.shape[0] # B
- dL_dlogp = -rewards[:, None] * 1e-2 / batch_size # (B, 1)
-
- # Compute gradients manually based on the derivation
- # dL/dmu = -(r-r_hat) * dlog_prob/dmu = -(r-r_hat) * -(sample-mu)/sigma^2
- dlog_prob_dfxmean = (sample - fx_mean) / (std ** 2)
- dL_dmean = dL_dlogp * dlog_prob_dfxmean * dmu_act_fx(mean) # (B, A)
- dL_dWmu = activation.T @ dL_dmean
-
- # dL/dlog(sigma) = -(r-r_hat) * dlog_prob/dlog(sigma) = -(r-r_hat) * (((sample-mu)/sigma)^2 - 1)
- dlog_prob_dlogstd = - 1.0 / std + (sample - fx_mean)**2 / std**3
- dL_dstd = dL_dlogp * dlog_prob_dlogstd
- # Apply gradient clipping for logstd
- dL_dlogstd = d_clip(logstd, -10.0, 2.0) * dL_dstd * std
- dL_dWlogstd = activation.T @ dL_dlogstd # (I, B) @ (B, A) = (I, A)
- dL_dWlogstd = dL_dWlogstd * learning_stddev_mask # there is no learning for the scalar stddev
-
- # Update weights, negate the gradient because gradient ascent in ngc-learn
- dW = jnp.concatenate([-dL_dWmu, -dL_dWlogstd], axis=-1)
- # Finally, return metrics if needed
- return dW, objective, outputs
-
- @transition(output_compartments=["weights", "dWeights", "objective", "outputs", "accumulated_gradients", "step_count", "seed"])
- @staticmethod
- def evolve(dt, w_bound, inputs, rewards, act_fx, weights, eta, learning_mask, decay, accumulated_gradients, step_count, seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev):
+ @compilable
+ def evolve(self, dt):
+ # Get compartment values
+ weights = self.weights.get()
+ dWeights = self.dWeights.get()
+ objective = self.objective.get()
+ outputs = self.outputs.get()
+ accumulated_gradients = self.accumulated_gradients.get()
+ step_count = self.step_count.get()
+ seed = self.seed.get()
+ inputs = self.inputs.get()
+ rewards = self.rewards.get()
+
+ # Main logic
main_seed, sub_seed = jax.random.split(seed)
- dWeights, objective, outputs = REINFORCESynapse._compute_update(
- dt, inputs, rewards, act_fx, weights, sub_seed, mu_act_fx, dmu_act_fx, mu_out_min, mu_out_max, scalar_stddev
+ dWeights, objective, outputs = _compute_update(
+ dt, inputs, rewards, self.act_fx, weights, sub_seed, self.mu_act_fx, self.dmu_act_fx, self.mu_out_min, self.mu_out_max, self.scalar_stddev
)
## do a gradient ascent update/shift
- weights = (weights + dWeights * eta) * learning_mask + weights * (1.0 - learning_mask) # update the weights only where learning_mask is 1.0
+        weights = (weights + dWeights * self.eta) * self.learning_mask.get() + weights * (1.0 - self.learning_mask.get())  # update the weights only where learning_mask is 1.0
## enforce non-negativity
eps = 0.0 # 0.01 # 0.001
- weights = jnp.clip(weights, eps, w_bound - eps) # jnp.abs(w_bound))
+ weights = jnp.clip(weights, eps, self.w_bound - eps) # jnp.abs(w_bound))
step_count += 1
- accumulated_gradients = (step_count - 1) / step_count * accumulated_gradients * decay + 1.0 / step_count * dWeights # EMA update of accumulated gradients
- step_count = step_count * (1 - learning_mask) # reset the step count to 0 when we have learned
- return weights, dWeights, objective, outputs, accumulated_gradients, step_count, main_seed
-
- @transition(output_compartments=["inputs", "outputs", "objective", "rewards", "dWeights", "accumulated_gradients", "step_count", "seed"])
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
+ accumulated_gradients = (step_count - 1) / step_count * accumulated_gradients * self.decay + 1.0 / step_count * dWeights # EMA update of accumulated gradients
+ step_count = step_count * (1 - self.learning_mask.get()) # reset the step count to 0 when we have learned
+
+ # Set updated compartment values
+ self.weights.set(weights)
+ self.dWeights.set(dWeights)
+ self.objective.set(objective)
+ self.outputs.set(outputs)
+ self.accumulated_gradients.set(accumulated_gradients)
+ self.step_count.set(step_count)
+ self.seed.set(main_seed)
+
+ @compilable
+ def reset(self):
+ preVals = jnp.zeros((self.batch_size, self.shape[0]))
+ postVals = jnp.zeros((self.batch_size, self.shape[1]))
inputs = preVals
outputs = postVals
objective = jnp.zeros(())
- rewards = jnp.zeros((batch_size,))
- dWeights = jnp.zeros(shape)
- accumulated_gradients = jnp.zeros((shape[0], shape[1] * 2))
+ rewards = jnp.zeros((self.batch_size,))
+ dWeights = jnp.zeros(self.shape)
+ accumulated_gradients = jnp.zeros((self.shape[0], self.shape[1] * 2))
step_count = jnp.zeros(())
seed = jax.random.PRNGKey(42)
- return inputs, outputs, objective, rewards, dWeights, accumulated_gradients, step_count, seed
+
+        if hasattr(self.inputs, 'targeted') and not self.inputs.targeted:
+            self.inputs.set(inputs)
+ self.outputs.set(outputs)
+ self.objective.set(objective)
+ self.rewards.set(rewards)
+ self.dWeights.set(dWeights)
+ self.accumulated_gradients.set(accumulated_gradients)
+ self.step_count.set(step_count)
+ self.seed.set(seed)
@classmethod
def help(cls): ## component help function
@@ -222,16 +253,14 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
+
+if __name__ == '__main__':
+ from ngcsimlib.context import Context
+ with Context("Bar") as bar:
+ syn = REINFORCESynapse(
+ name="reinforce_syn",
+ shape=(3, 2)
+ )
+ # Wab = syn.weights.get()
+ print(syn)
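+        # A minimal usage sketch (hypothetical values; the MethodProcess wiring below mirrors the
+        # pattern used by the regression modules elsewhere in this changeset, so treat it as an
+        # assumption rather than this component's official driver):
+        #   from ngclearn import MethodProcess
+        #   syn.inputs.set(jnp.ones((1, 3)))
+        #   syn.rewards.set(jnp.ones((1,)))
+        #   evolve = MethodProcess(name="evolve") >> syn.evolve
+        #   evolve.run(dt=1.)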
+
diff --git a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py
index 1415f51a..ae58c6ac 100644
--- a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py
+++ b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py
@@ -1,16 +1,22 @@
+# %%
+
import matplotlib.pyplot as plt
from jax import random, numpy as jnp, jit
from functools import partial
from ngclearn.utils.optim import get_opt_init_fn, get_opt_step_fn
-from ngclearn import resolver, Component, Compartment
+
+from ngcsimlib.logger import info
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+
from ngclearn.components.synapses.patched import PatchedSynapse
from ngclearn.utils import tensorstats
-from ngcsimlib.compilers.process import transition
-@partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
-def _calc_update(pre, post, W, mask, w_bound, is_nonnegative=True, signVal=1.,
- prior_type=None, prior_lmbda=0.,
- pre_wght=1., post_wght=1.):
+# @partial(jit, static_argnums=[3, 4, 5, 6, 7, 8, 9])
+def _calc_update(
+ pre, post, W, mask, w_bound, is_nonnegative=True, signVal=1., prior_type=None, prior_lmbda=0., pre_wght=1.,
+ post_wght=1.
+):
"""
Compute a tensor of adjustments to be applied to a synaptic value matrix.
@@ -64,12 +70,12 @@ def _calc_update(pre, post, W, mask, w_bound, is_nonnegative=True, signVal=1.,
dW = dW + prior_lmbda * dW_reg
- if mask!=None:
+        if mask is not None:
dW = dW * mask
return dW * signVal, db * signVal
-@partial(jit, static_argnums=[1,2, 3])
+# @partial(jit, static_argnums=[1,2, 3])
def _enforce_constraints(W, block_mask, w_bound, is_nonnegative=True):
"""
Enforces constraints that the (synaptic) efficacies/values within matrix
@@ -89,12 +95,12 @@ def _enforce_constraints(W, block_mask, w_bound, is_nonnegative=True):
"""
_W = W
if w_bound > 0.:
- if is_nonnegative == True:
+ if is_nonnegative:
_W = jnp.clip(_W, 0., w_bound)
else:
_W = jnp.clip(_W, -w_bound, w_bound)
- if block_mask!=None:
+    if block_mask is not None:
_W = _W * block_mask
return _W
@@ -185,12 +191,15 @@ class HebbianPatchedSynapse(PatchedSynapse):
batch_size: the size of each mini batch
"""
- def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weight_init=None, bias_init=None,
- block_mask=None, w_bound=1., is_nonnegative=False, prior=(None, 0.), sign_value=1.,
- optim_type="sgd", pre_wght=1., post_wght=1., p_conn=1.,
- resist_scale=1., batch_size=1, **kwargs):
- super().__init__(name, shape, n_sub_models, stride_shape, block_mask, weight_init, bias_init, resist_scale,
- p_conn, batch_size=batch_size, **kwargs)
+ def __init__(
+ self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weight_init=None, bias_init=None,
+ block_mask=None, w_bound=1., is_nonnegative=False, prior=(None, 0.), sign_value=1., optim_type="sgd",
+ pre_wght=1., post_wght=1., p_conn=1., resist_scale=1., batch_size=1, **kwargs
+ ):
+ super().__init__(
+ name, shape, n_sub_models, stride_shape, block_mask, weight_init, bias_init, resist_scale, p_conn,
+ batch_size=batch_size, **kwargs
+ )
prior_type, prior_lmbda = prior
self.prior_type = prior_type
@@ -225,10 +234,10 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), eta=0., weig
self.dWeights = Compartment(jnp.zeros(self.shape))
self.dBiases = Compartment(jnp.zeros(self.shape[1]))
- #key, subkey = random.split(self.key.value)
+ #key, subkey = random.split(self.key.get())
self.opt_params = Compartment(get_opt_init_fn(optim_type)(
- [self.weights.value, self.biases.value]
- if bias_init else [self.weights.value]))
+ [self.weights.get(), self.biases.get()]
+ if bias_init else [self.weights.get()]))
@staticmethod
def _compute_update(block_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
@@ -241,38 +250,48 @@ def _compute_update(block_mask, w_bound, is_nonnegative, sign_value, prior_type,
return dW * jnp.where(0 != jnp.abs(weights), 1, 0) , db
- @transition(output_compartments=["opt_params", "weights", "biases", "dWeights", "dBiases"])
- @staticmethod
- def evolve(block_mask, opt, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda, pre_wght,
- post_wght, bias_init, pre, post, weights, biases, opt_params):
+ @compilable
+ def evolve(self):
+ # Get the variables
+ pre = self.pre.get()
+ post = self.post.get()
+ weights = self.weights.get()
+ biases = self.biases.get()
+ opt_params = self.opt_params.get()
+
## calculate synaptic update values
dWeights, dBiases = HebbianPatchedSynapse._compute_update(
- block_mask, w_bound, is_nonnegative, sign_value, prior_type, prior_lmbda,
- pre_wght, post_wght, pre, post, weights
+ self.block_mask, self.w_bound, self.is_nonnegative, self.sign_value, self.prior_type, self.prior_lmbda,
+ self.pre_wght, self.post_wght, pre, post, weights
)
## conduct a step of optimization - get newly evolved synaptic weight value matrix
- if bias_init != None:
- opt_params, [weights, biases] = opt(opt_params, [weights, biases], [dWeights, dBiases])
+        if self.bias_init is not None:
+ opt_params, [weights, biases] = self.opt(opt_params, [weights, biases], [dWeights, dBiases])
else:
# ignore db since no biases configured
- opt_params, [weights] = opt(opt_params, [weights], [dWeights])
+ opt_params, [weights] = self.opt(opt_params, [weights], [dWeights])
## ensure synaptic efficacies adhere to constraints
- weights = _enforce_constraints(weights, block_mask, w_bound, is_nonnegative=is_nonnegative)
- return opt_params, weights, biases, dWeights, dBiases
-
- @transition(output_compartments=["inputs", "outputs", "pre", "post", "dWeights", "dBiases"])
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
- return (
- preVals, # inputs
- postVals, # outputs
- preVals, # pre
- postVals, # post
- jnp.zeros(shape), # dW
- jnp.zeros(shape[1]), # db
- )
+ weights = _enforce_constraints(weights, self.block_mask, self.w_bound, is_nonnegative=self.is_nonnegative)
+
+ # Update compartments
+ self.opt_params.set(opt_params)
+ self.weights.set(weights)
+ self.biases.set(biases)
+ self.dWeights.set(dWeights)
+ self.dBiases.set(dBiases)
+
+ @compilable
+ def reset(self):
+ preVals = jnp.zeros((self.batch_size, self.shape[0]))
+ postVals = jnp.zeros((self.batch_size, self.shape[1]))
+        # NOTE: self.inputs may not expose the `targeted` field here; as a workaround, only clamp the
+        # inputs compartment when `targeted` exists and is False
+        if hasattr(self.inputs, "targeted") and not self.inputs.targeted:
+            self.inputs.set(preVals)  # inputs
+ self.outputs.set(postVals) # outputs
+ self.pre.set(preVals) # pre
+ self.post.set(postVals) # post
+ self.dWeights.set(jnp.zeros(self.shape)) # dW
+ self.dBiases.set(jnp.zeros(self.shape[1])) # db
@classmethod
@@ -323,35 +342,12 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
-
-
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
-
-
-
-
-
-
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
Wab = HebbianPatchedSynapse("Wab", (9, 30), 3, (0, 0), optim_type='adam',
sign_value=-1.0, prior=("l1l2", 0.001))
print(Wab)
- plt.imshow(Wab.weights.value, cmap='gray')
+ plt.imshow(Wab.weights.get(), cmap='gray')
plt.show()
diff --git a/ngclearn/components/synapses/patched/patchedSynapse.py b/ngclearn/components/synapses/patched/patchedSynapse.py
index 43d1dc16..540bd30a 100644
--- a/ngclearn/components/synapses/patched/patchedSynapse.py
+++ b/ngclearn/components/synapses/patched/patchedSynapse.py
@@ -1,21 +1,55 @@
+# %%
+
import matplotlib.pyplot as plt
from jax import random, numpy as jnp, jit
-from ngclearn import resolver, Component, Compartment
from ngclearn.components.jaxComponent import JaxComponent
-from ngclearn.utils import tensorstats
-from ngcsimlib.compilers.process import transition
-from ngclearn.utils.weight_distribution import initialize_params
-from ngcsimlib.logger import info
-import math
-
+from ngclearn.utils.distribution_generator import DistributionGenerator
-def create_multi_patch_synapses(key, shape, n_sub_models, sub_stride, weight_init):
+from ngcsimlib.logger import info
+from ngclearn import compilable #from ngcsimlib.parser import compilable
+from ngclearn import Compartment #from ngcsimlib.compartment import Compartment
+# from ngclearn.utils.weight_distribution import initialize_params
+
+
+# def _create_multi_patch_synapses(key, shape, n_sub_models, sub_stride, weight_init):
+# sub_shape = (shape[0] // n_sub_models, shape[1] // n_sub_models)
+# di, dj = sub_shape
+# si, sj = sub_stride
+
+# weight_shape = ((n_sub_models * di) + 2 * si, (n_sub_models * dj) + 2 * sj)
+# #weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, weight_shape, use_numpy=True)
+# large_weight_init = DistributionGenerator.constant(value=0.)
+# weights = large_weight_init(weight_shape, key[2])
+
+# for i in range(n_sub_models):
+# start_i = i * di
+# end_i = (i + 1) * di + 2 * si
+# start_j = i * dj
+# end_j = (i + 1) * dj + 2 * sj
+
+# shape_ = (end_i - start_i, end_j - start_j) # (di + 2 * si, dj + 2 * sj)
+
+# ## FIXME: this line below might be wonky...
+# weights.at[start_i: end_i, start_j: end_j].set( weight_init(shape_, key[2]) )
+# # weights[start_i : end_i,
+# # start_j : end_j] = initialize_params(key[2], init_kernel=weight_init, shape=shape_, use_numpy=True)
+# if si != 0:
+# weights.at[:si,:].set(0.) ## FIXME: this setter line might be wonky...
+# weights.at[-si:,:].set(0.) ## FIXME: this setter line might be wonky...
+# if sj != 0:
+# weights.at[:,:sj].set(0.) ## FIXME: this setter line might be wonky...
+# weights.at[:, -sj:].set(0.) ## FIXME: this setter line might be wonky...
+
+# return weights
+
+def _create_multi_patch_synapses(key, shape, n_sub_models, sub_stride, weight_init):
sub_shape = (shape[0] // n_sub_models, shape[1] // n_sub_models)
di, dj = sub_shape
si, sj = sub_stride
weight_shape = ((n_sub_models * di) + 2 * si, (n_sub_models * dj) + 2 * sj)
- weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, weight_shape, use_numpy=True)
+ # weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, weight_shape, use_numpy=True)
+ weights = DistributionGenerator.constant(value=0.)(weight_shape, key[2])
for i in range(n_sub_models):
start_i = i * di
@@ -25,22 +59,23 @@ def create_multi_patch_synapses(key, shape, n_sub_models, sub_stride, weight_ini
shape_ = (end_i - start_i, end_j - start_j) # (di + 2 * si, dj + 2 * sj)
- weights[start_i : end_i,
- start_j : end_j] = initialize_params(key[2],
- init_kernel=weight_init,
- shape=shape_,
- use_numpy=True)
+ # weights[start_i : end_i,
+ # start_j : end_j] = initialize_params(key[2],
+ # init_kernel=weight_init,
+ # shape=shape_,
+ # use_numpy=True)
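+        ## note: JAX arrays are immutable, so .at[...].set(...) returns a new array that must be re-bound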
+ weights = weights.at[start_i : end_i,
+ start_j : end_j].set(weight_init(shape_, key[2]))
if si!=0:
- weights[:si,:] = 0.
- weights[-si:,:] = 0.
+ weights = weights.at[:si,:].set(0.)
+ weights = weights.at[-si:,:].set(0.)
if sj!=0:
- weights[:,:sj] = 0.
- weights[:, -sj:] = 0.
+ weights = weights.at[:,:sj].set(0.)
+ weights = weights.at[:, -sj:].set(0.)
return weights
-
class PatchedSynapse(JaxComponent): ## base patched synaptic cable
"""
A patched dense synaptic cables that creates multiple small dense synaptic cables; no form of synaptic evolution/adaptation
@@ -66,7 +101,7 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable
with number of inputs by number of outputs)
n_sub_models: The number of submodels in each layer (Default: 1 similar functionality as DenseSynapse)
-
+
stride_shape: Stride shape of overlapping synaptic weight value matrix
(Default: (0, 0))
@@ -92,8 +127,10 @@ class PatchedSynapse(JaxComponent): ## base patched synaptic cable
this to < 1. will result in a sparser synaptic structure
"""
- def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), block_mask=None, weight_init=None, bias_init=None,
- resist_scale=1., p_conn=1., batch_size=1, **kwargs):
+ def __init__(
+ self, name, shape, n_sub_models=1, stride_shape=(0,0), block_mask=None, weight_init=None, bias_init=None,
+ resist_scale=1., p_conn=1., batch_size=1, **kwargs
+ ):
super().__init__(name, **kwargs)
self.Rscale = resist_scale
@@ -104,13 +141,16 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), block_mask=N
self.n_sub_models = n_sub_models
self.sub_stride = stride_shape
- tmp_key, *subkeys = random.split(self.key.value, 4)
+ tmp_key, *subkeys = random.split(self.key.get(), 4)
if self.weight_init is None:
info(self.name, "is using default weight initializer!")
- self.weight_init = {"dist": "fan_in_gaussian"}
+ #self.weight_init = {"dist": "fan_in_gaussian"}
+ self.weight_init = DistributionGenerator.fan_in_gaussian()
- weights = create_multi_patch_synapses(key=subkeys, shape=shape, n_sub_models=self.n_sub_models, sub_stride=self.sub_stride,
- weight_init=self.weight_init)
+ weights = _create_multi_patch_synapses(
+ key=subkeys, shape=shape, n_sub_models=self.n_sub_models, sub_stride=self.sub_stride,
+ weight_init=self.weight_init
+ )
self.block_mask = jnp.where(weights!=0, 1, 0)
self.sub_shape = (shape[0]//n_sub_models, shape[1]//n_sub_models)
@@ -133,39 +173,31 @@ def __init__(self, name, shape, n_sub_models=1, stride_shape=(0,0), block_mask=N
if self.bias_init is None:
info(self.name, "is using default bias value of zero (no bias "
"kernel provided)!")
- self.biases = Compartment(initialize_params(subkeys[2], bias_init,
- (1, self.shape[1]))
- if bias_init else 0.0)
-
- @transition(output_compartments=["outputs"])
- @staticmethod
- def advance_state(Rscale, inputs, weights, biases):
- outputs = (jnp.matmul(inputs, weights) * Rscale) + biases
- return outputs
-
- @transition(output_compartments=["inputs", "outputs"])
- @staticmethod
- def reset(batch_size, shape):
- preVals = jnp.zeros((batch_size, shape[0]))
- postVals = jnp.zeros((batch_size, shape[1]))
+ self.biases = Compartment(self.bias_init((1, self.shape[1]), subkeys[2]) if bias_init else 0.0)
+        #self.biases = Compartment(initialize_params(subkeys[2], bias_init, (1, self.shape[1])) if bias_init else 0.0)
+
+ @compilable
+ def advance_state(self):
+ # Get the variables
+ inputs = self.inputs.get()
+ weights = self.weights.get()
+ biases = self.biases.get()
+
+ outputs = (jnp.matmul(inputs, weights) * self.Rscale) + biases
+
+ # Update compartment
+ self.outputs.set(outputs)
+
+ @compilable
+ def reset(self):
+ preVals = jnp.zeros((self.batch_size, self.shape[0]))
+ postVals = jnp.zeros((self.batch_size, self.shape[1]))
inputs = preVals
outputs = postVals
- return inputs, outputs
-
- def save(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- if self.bias_init != None:
- jnp.savez(file_name, weights=self.weights.value,
- biases=self.biases.value)
- else:
- jnp.savez(file_name, weights=self.weights.value)
-
- def load(self, directory, **kwargs):
- file_name = directory + "/" + self.name + ".npz"
- data = jnp.load(file_name)
- self.weights.set(data['weights'])
- if "biases" in data.keys():
- self.biases.set(data['biases'])
+        # NOTE: self.inputs may not expose the `targeted` field here; as a workaround, only clamp the
+        # inputs compartment when `targeted` exists and is False
+        if hasattr(self.inputs, "targeted") and not self.inputs.targeted:
+            self.inputs.set(inputs)
+ self.outputs.set(outputs)
@classmethod
def help(cls): ## component help function
@@ -201,36 +233,11 @@ def help(cls): ## component help function
"hyperparameters": hyperparams}
return info
- def __repr__(self):
- comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
- maxlen = max(len(c) for c in comps) + 5
- lines = f"[{self.__class__.__name__}] PATH: {self.name}\n"
- for c in comps:
- stats = tensorstats(getattr(self, c).value)
- if stats is not None:
- line = [f"{k}: {v}" for k, v in stats.items()]
- line = ", ".join(line)
- else:
- line = "None"
- lines += f" {f'({c})'.ljust(maxlen)}{line}\n"
- return lines
-
-
-
-
-
-
if __name__ == '__main__':
from ngcsimlib.context import Context
with Context("Bar") as bar:
Wab = PatchedSynapse("Wab", (9, 30), 3)
print(Wab)
- plt.imshow(Wab.weights.value, cmap='gray')
+ plt.imshow(Wab.weights.get(), cmap='gray')
plt.show()
-
-
-
-
-
-
diff --git a/ngclearn/modules/__init__.py b/ngclearn/modules/__init__.py
index b18f84b7..38866e21 100644
--- a/ngclearn/modules/__init__.py
+++ b/ngclearn/modules/__init__.py
@@ -2,7 +2,3 @@
from .regression.lasso import Iterative_Lasso
from .regression.ridge import Iterative_Ridge
-
-
-
-
diff --git a/ngclearn/modules/regression/__init__.py b/ngclearn/modules/regression/__init__.py
index 064d5303..bc45b6b2 100644
--- a/ngclearn/modules/regression/__init__.py
+++ b/ngclearn/modules/regression/__init__.py
@@ -2,8 +2,3 @@
from .lasso import Iterative_Lasso
from .ridge import Iterative_Ridge
-
-
-
-
-
diff --git a/ngclearn/modules/regression/elastic_net.py b/ngclearn/modules/regression/elastic_net.py
index 9cec8948..5860d2bc 100644
--- a/ngclearn/modules/regression/elastic_net.py
+++ b/ngclearn/modules/regression/elastic_net.py
@@ -1,18 +1,17 @@
-from jax import random, jit
import numpy as np
-from ngclearn.utils import weight_distribution as dist
-from ngclearn import Context, numpy as jnp
-from ngclearn.components import (RateCell,
- HebbianSynapse,
- GaussianErrorCell,
- StaticSynapse)
-from ngclearn.utils.model_utils import scanner
+from ngclearn.utils.distribution_generator import DistributionGenerator as dist
+from ngclearn import numpy as jnp
+from jax import random, jit
+from ngclearn import Context, MethodProcess
+from ngclearn.components.synapses.hebbian.hebbianSynapse import HebbianSynapse
+from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell
+from ngcsimlib.global_state import stateManager
class Iterative_ElasticNet():
"""
A neural circuit implementation of the iterative Elastic Net (L1 and L2) algorithm
- using Hebbian learning update rule.
+ using a Hebbian learning update rule.
The circuit implements sparse regression through Hebbian synapses with Elastic Net regularization.
@@ -22,8 +21,6 @@ class Iterative_ElasticNet():
| dW_reg = (jnp.sign(W) * l1_ratio) + (W * (1-l1_ratio)/2)
| dW/dt = dW + lmbda * dW_reg
-
-
| --- Circuit Components: ---
| W - HebbianSynapse for learning regularized dictionary weights
| err - GaussianErrorCell for computing prediction errors
@@ -77,54 +74,43 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l
feature_dim = dict_dim
with Context(self.name) as self.circuit:
- self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=self.lr,
- sign_value=-1, weight_init=dist.constant(weight_fill),
- prior=('elastic_net', (lmbda, l1_ratio)), w_bound=0.,
- optim_type=optim_type, key=subkeys[0])
+ self.W = HebbianSynapse(
+ "W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1,
+ weight_init=dist.constant(value=weight_fill), prior=('elastic_net', (lmbda, l1_ratio)), w_bound=0.,
+ optim_type=optim_type, key=subkeys[0]
+ )
self.err = GaussianErrorCell("err", n_units=sys_dim)
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
self.W.batch_size = batch_size
self.err.batch_size = batch_size
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- self.err.mu << self.W.outputs
- self.W.post << self.err.dmu
+ self.W.outputs >> self.err.mu
+ self.err.dmu >> self.W.post
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- advance_cmd, advance_args =self.circuit.compile_by_key(self.W, ## execute prediction synapses
- self.err, ## finally, execute error neurons
- compile_key="advance_state")
- evolve_cmd, evolve_args =self.circuit.compile_by_key(self.W, compile_key="evolve")
- reset_cmd, reset_args =self.circuit.compile_by_key(self.err, self.W, compile_key="reset")
- # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- self.dynamic()
- def dynamic(self): ## create dynamic commands forself.circuit
- W, err = self.circuit.get_components("W", "err")
- self.self = W
- self.err = err
-
- @Context.dynamicCommand
- def batch_set(batch_size):
- self.W.batch_size = batch_size
- self.err.batch_size = batch_size
+ advance = (MethodProcess(name="advance_state")
+ >> self.W.advance_state
+ >> self.err.advance_state)
+ self.advance = advance
- @Context.dynamicCommand
- def clamps(y_scaled, X):
- self.W.inputs.set(X)
- self.W.pre.set(X)
- self.err.target.set(y_scaled)
+ evolve = (MethodProcess(name="evolve")
+ >> self.W.evolve)
+ self.evolve = evolve
- self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve")
- self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance")
- self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset")
+ reset = (MethodProcess(name="reset")
+ >> self.err.reset
+ >> self.W.reset)
+ self.reset = reset
+ def batch_set(self, batch_size):
+ self.W.batch_size = batch_size
+ self.err.batch_size = batch_size
- @scanner
- def _process(compartment_values, args):
- _t, _dt = args
- compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt)
- return compartment_values, compartment_values[self.W.weights.path]
-
+ def clamp(self, y_scaled, X):
+ self.W.inputs.set(X)
+ self.W.pre.set(X)
+ self.err.target.set(y_scaled)
def thresholding(self, scale=1.):
coef_old = self.coef_
@@ -138,18 +124,15 @@ def thresholding(self, scale=1.):
def fit(self, y, X):
- self.circuit.reset()
- self.circuit.clamps(y_scaled=y, X=X)
+ self.reset.run()
+ self.clamp(y_scaled=y, X=X)
for i in range(self.epochs):
- self.circuit._process(jnp.array([[self.dt * i, self.dt] for i in range(self.T)]))
- self.circuit.evolve(t=self.T, dt=self.dt)
-
- self.coef_ = np.array(self.W.weights.value)
-
- return self.coef_, self.err.mu.value, self.err.L.value
-
-
+ inputs = jnp.array(self.advance.pack_rows(self.T, t=lambda x: x, dt=self.dt))
+ stateManager.state, outputs = self.advance.scan(inputs)
+ self.evolve.run(t=self.T, dt=self.dt)
+ self.coef_ = np.array(self.W.weights.get())
+ return self.coef_, self.err.mu.get(), self.err.L.get()
diff --git a/ngclearn/modules/regression/lasso.py b/ngclearn/modules/regression/lasso.py
index c0d8c8ef..15a014bb 100644
--- a/ngclearn/modules/regression/lasso.py
+++ b/ngclearn/modules/regression/lasso.py
@@ -1,26 +1,19 @@
-import jax
-import pandas as pd
-from jax import random, jit
import numpy as np
-from scipy.integrate import solve_ivp
-import matplotlib.pyplot as plt
-from ngcsimlib.utils import Get_Compartment_Batch
-from ngclearn.utils.model_utils import normalize_matrix
-from ngclearn.utils import weight_distribution as dist
-from ngclearn import Context, numpy as jnp
-from ngclearn.components import (RateCell,
- HebbianSynapse,
- GaussianErrorCell,
- StaticSynapse)
-from ngclearn.utils.model_utils import scanner
+from ngclearn.utils.distribution_generator import DistributionGenerator as dist
+from ngclearn import numpy as jnp
+from jax import random, jit
+from ngclearn import Context, MethodProcess
+from ngclearn.components.synapses.hebbian.hebbianSynapse import HebbianSynapse
+from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell
+from ngcsimlib.global_state import stateManager
class Iterative_Lasso():
"""
A neural circuit implementation of the iterative Lasso (L1) algorithm
- using Hebbian learning update rule.
+ using a Hebbian learning update rule.
- The circuit implements sparse coding through Hebbian synapses with L1 regularization.
+ The circuit implements sparse coding-like regression through Hebbian synapses with L1 regularization.
The specific differential equation that characterizes this model is adding lmbda * sign(W)
to the dW (the gradient of loss/energy function):
@@ -80,52 +73,42 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l
feature_dim = dict_dim
with Context(self.name) as self.circuit:
- self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=self.lr,
- sign_value=-1, weight_init=dist.constant(weight_fill),
- prior=('lasso', lasso_lmbda), w_bound=0.,
- optim_type=optim_type, key=subkeys[0])
+ self.W = HebbianSynapse(
+ "W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1,
+ weight_init=dist.constant(value=weight_fill), prior=('lasso', lasso_lmbda), w_bound=0.,
+ optim_type=optim_type, key=subkeys[0]
+ )
self.err = GaussianErrorCell("err", n_units=sys_dim)
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
self.W.batch_size = batch_size
self.err.batch_size = batch_size
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- self.err.mu << self.W.outputs
- self.W.post << self.err.dmu
+ self.W.outputs >> self.err.mu
+ self.err.dmu >> self.W.post
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- advance_cmd, advance_args =self.circuit.compile_by_key(self.W, ## execute prediction synapses
- self.err, ## finally, execute error neurons
- compile_key="advance_state")
- evolve_cmd, evolve_args =self.circuit.compile_by_key(self.W, compile_key="evolve")
- reset_cmd, reset_args =self.circuit.compile_by_key(self.err, self.W, compile_key="reset")
- # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- self.dynamic()
-
- def dynamic(self): ## create dynamic commands for self.circuit
- W, err = self.circuit.get_components("W", "err")
- self.self = W
- self.err = err
-
- @Context.dynamicCommand
- def batch_set(batch_size):
- self.W.batch_size = batch_size
- self.err.batch_size = batch_size
-
- @Context.dynamicCommand
- def clamps(y_scaled, X):
- self.W.inputs.set(X)
- self.W.pre.set(X)
- self.err.target.set(y_scaled)
-
- self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve")
- self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance")
- self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset")
-
- @scanner
- def _process(compartment_values, args):
- _t, _dt = args
- compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt)
- return compartment_values, compartment_values[self.W.weights.path]
-
+
+ advance = (MethodProcess(name="advance_state")
+ >> self.W.advance_state
+ >> self.err.advance_state)
+ self.advance = advance
+
+ evolve = (MethodProcess(name="evolve")
+ >> self.W.evolve)
+ self.evolve = evolve
+
+ reset = (MethodProcess(name="reset")
+ >> self.err.reset
+ >> self.W.reset)
+ self.reset = reset
+
+ def batch_set(self, batch_size):
+ self.W.batch_size = batch_size
+ self.err.batch_size = batch_size
+
+ def clamp(self, y_scaled, X):
+ self.W.inputs.set(X)
+ self.W.pre.set(X)
+ self.err.target.set(y_scaled)
def thresholding(self, scale=2):
coef_old = self.coef_
@@ -136,23 +119,16 @@ def thresholding(self, scale=2):
return self.coef_, coef_old
-
def fit(self, y, X):
-
- self.circuit.reset()
- self.circuit.clamps(y_scaled=y, X=X)
+ self.reset.run()
+ self.clamp(y_scaled=y, X=X)
for i in range(self.epochs):
- self.circuit._process(jnp.array([[self.dt * i, self.dt] for i in range(self.T)]))
- self.circuit.evolve(t=self.T, dt=self.dt)
-
- self.coef_ = np.array(self.W.weights.value)
-
- return self.coef_, self.err.mu.value, self.err.L.value
-
-
-
-
+ inputs = jnp.array(self.advance.pack_rows(self.T, t=lambda x: x, dt=self.dt))
+ stateManager.state, outputs = self.advance.scan(inputs)
+ self.evolve.run(t=self.T, dt=self.dt)
+ self.coef_ = np.array(self.W.weights.get())
+ return self.coef_, self.err.mu.get(), self.err.L.get()
diff --git a/ngclearn/modules/regression/ridge.py b/ngclearn/modules/regression/ridge.py
index b1698aba..dfbacb03 100644
--- a/ngclearn/modules/regression/ridge.py
+++ b/ngclearn/modules/regression/ridge.py
@@ -1,21 +1,19 @@
-from jax import random, jit
import numpy as np
-from ngclearn.utils import weight_distribution as dist
-from ngclearn import Context, numpy as jnp
-from ngclearn.components import (RateCell,
- HebbianSynapse,
- GaussianErrorCell,
- StaticSynapse)
-from ngclearn.utils.model_utils import scanner
-
+from ngclearn.utils.distribution_generator import DistributionGenerator as dist
+from ngclearn import numpy as jnp
+from jax import random, jit
+from ngclearn import Context, MethodProcess
+from ngclearn.components.synapses.hebbian.hebbianSynapse import HebbianSynapse
+from ngclearn.components.neurons.graded.gaussianErrorCell import GaussianErrorCell
+from ngcsimlib.global_state import stateManager
class Iterative_Ridge():
"""
A neural circuit implementation of the iterative Ridge (L2) algorithm
- using Hebbian learning update rule.
+ using a Hebbian learning update rule.
- The circuit implements sparse regression through Hebbian synapses with L2 regularization.
+    This circuit implements shrinkage-based regression through Hebbian synapses with L2 regularization.
The specific differential equation that characterizes this model is adding lmbda * W
to the dW (the gradient of loss/energy function):
@@ -75,54 +73,43 @@ def __init__(self, key, name, sys_dim, dict_dim, batch_size, weight_fill=0.05, l
feature_dim = dict_dim
with Context(self.name) as self.circuit:
- self.W = HebbianSynapse("W", shape=(feature_dim, sys_dim), eta=self.lr,
- sign_value=-1, weight_init=dist.constant(weight_fill),
- prior=('ridge', ridge_lmbda), w_bound=0.,
- optim_type=optim_type, key=subkeys[0])
+ self.W = HebbianSynapse(
+ "W", shape=(feature_dim, sys_dim), eta=self.lr, sign_value=-1,
+ weight_init=dist.constant(value=weight_fill), prior=('ridge', ridge_lmbda), w_bound=0.,
+ optim_type=optim_type, key=subkeys[0]
+ )
self.err = GaussianErrorCell("err", n_units=sys_dim)
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
self.W.batch_size = batch_size
self.err.batch_size = batch_size
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- self.err.mu << self.W.outputs
- self.W.post << self.err.dmu
+ self.W.outputs >> self.err.mu
+ self.err.dmu >> self.W.post
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- advance_cmd, advance_args =self.circuit.compile_by_key(self.W, ## execute prediction synapses
- self.err, ## finally, execute error neurons
- compile_key="advance_state")
- evolve_cmd, evolve_args =self.circuit.compile_by_key(self.W, compile_key="evolve")
- reset_cmd, reset_args =self.circuit.compile_by_key(self.err, self.W, compile_key="reset")
- # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- self.dynamic()
-
- def dynamic(self): ## create dynamic commands forself.circuit
- W, err = self.circuit.get_components("W", "err")
- self.self = W
- self.err = err
-
- @Context.dynamicCommand
- def batch_set(batch_size):
- self.W.batch_size = batch_size
- self.err.batch_size = batch_size
-
- @Context.dynamicCommand
- def clamps(y_scaled, X):
- self.W.inputs.set(X)
- self.W.pre.set(X)
- self.err.target.set(y_scaled)
-
- self.circuit.wrap_and_add_command(jit(self.circuit.evolve), name="evolve")
- self.circuit.wrap_and_add_command(jit(self.circuit.advance_state), name="advance")
- self.circuit.wrap_and_add_command(jit(self.circuit.reset), name="reset")
-
-
- @scanner
- def _process(compartment_values, args):
- _t, _dt = args
- compartment_values = self.circuit.advance_state(compartment_values, t=_t, dt=_dt)
- return compartment_values, compartment_values[self.W.weights.path]
+ advance = (MethodProcess(name="advance_state")
+ >> self.W.advance_state
+ >> self.err.advance_state)
+ self.advance = advance
+
+ evolve = (MethodProcess(name="evolve")
+ >> self.W.evolve)
+ self.evolve = evolve
+
+ reset = (MethodProcess(name="reset")
+ >> self.err.reset
+ >> self.W.reset)
+ self.reset = reset
+
+ def batch_set(self, batch_size):
+ self.W.batch_size = batch_size
+ self.err.batch_size = batch_size
+
+ def clamp(self, y_scaled, X):
+ self.W.inputs.set(X)
+ self.W.pre.set(X)
+ self.err.target.set(y_scaled)
def thresholding(self, scale=2):
coef_old = self.coef_ #self.W.weights.value
@@ -135,21 +122,15 @@ def thresholding(self, scale=2):
def fit(self, y, X):
- self.circuit.reset()
- self.circuit.clamps(y_scaled=y, X=X)
+ self.reset.run()
+ self.clamp(y_scaled=y, X=X)
for i in range(self.epochs):
- self.circuit._process(jnp.array([[self.dt * i, self.dt] for i in range(self.T)]))
- self.circuit.evolve(t=self.T, dt=self.dt)
-
- self.coef_ = np.array(self.W.weights.value)
-
- return self.coef_, self.err.mu.value, self.err.L.value
-
-
-
-
-
+ inputs = jnp.array(self.advance.pack_rows(self.T, t=lambda x: x, dt=self.dt))
+ stateManager.state, outputs = self.advance.scan(inputs)
+ self.evolve.run(t=self.T, dt=self.dt)
+ self.coef_ = np.array(self.W.weights.get())
+ return self.coef_, self.err.mu.get(), self.err.L.get()
diff --git a/ngclearn/utils/JaxProcessesMixin.py b/ngclearn/utils/JaxProcessesMixin.py
new file mode 100644
index 00000000..ae1a655c
--- /dev/null
+++ b/ngclearn/utils/JaxProcessesMixin.py
@@ -0,0 +1,41 @@
+from ngcsimlib import JointProcess, MethodProcess
+from ngcsimlib.global_state import stateManager
+import jax
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from ngcsimlib._src.process.baseProcess import BaseProcess
+
+class JaxProcessesMixin:
+ def __init__(self: "BaseProcess"):
+ self._previous_result = None
+ self._previous_state = None
+
+ @property
+ def previous_result(self):
+ return self._previous_result
+
+ @property
+ def previous_state(self):
+ return self._previous_state
+
+ def clear(self):
+ self._previous_result = None
+ self._previous_state = None
+
+
+ def scan(self: "BaseProcess", inputs, current_state=None, save_state: bool = True, store_results: bool = True):
+ state = current_state or stateManager.state
+ final_state, result = jax.lax.scan(self.run.compiled, state, inputs)
+ if save_state:
+ self._previous_state = final_state
+ if store_results:
+ self._previous_result = result
+ return final_state, result
+
+
+
+class JaxJointProcess(JointProcess, JaxProcessesMixin):
+ pass
+
+class JaxMethodProcess(MethodProcess, JaxProcessesMixin):
+ pass
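+
+# A minimal usage sketch (assumes an already-wired MethodProcess named `advance`, as in the
+# regression modules of this changeset):
+#   inputs = jnp.array(advance.pack_rows(T, t=lambda x: x, dt=dt))
+#   stateManager.state, outputs = advance.scan(inputs)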
diff --git a/ngclearn/utils/__init__.py b/ngclearn/utils/__init__.py
old mode 100755
new mode 100644
index 9c9f984c..1d8c114e
--- a/ngclearn/utils/__init__.py
+++ b/ngclearn/utils/__init__.py
@@ -1,10 +1,4 @@
+from .distribution_generator import DistributionGenerator
+from .JaxProcessesMixin import JaxJointProcess as JointProcess, JaxMethodProcess as MethodProcess
from .model_utils import tensorstats
-from .jaxProcess import JaxProcess
-## forward imports from core ngc-learn utility sub-packages
-from . import viz
-from . import io_utils
-from . import metric_utils
-from . import model_utils
-from . import patch_utils
-from . import weight_distribution
-from . import surrogate_fx
+
diff --git a/ngclearn/utils/analysis/linear_probe.py b/ngclearn/utils/analysis/linear_probe.py
index e6eb2a31..a5546073 100644
--- a/ngclearn/utils/analysis/linear_probe.py
+++ b/ngclearn/utils/analysis/linear_probe.py
@@ -4,7 +4,7 @@
from ngclearn.utils.model_utils import drop_out, softmax, layer_normalize
from jax import jit, random, numpy as jnp, lax, nn
from functools import partial as bind
-import ngclearn.utils.weight_distribution as dist
+from ngclearn.utils.distribution_generator import DistributionGenerator
from ngclearn.utils.optim import adam, sgd
@bind(jax.jit, static_argnums=[2, 3])
@@ -88,10 +88,10 @@ def __init__(
## set up classifier
flat_input_dim = input_dim * source_seq_length
- weight_init = dist.fan_in_gaussian() # dist.gaussian(mu=0., sigma=0.05) # 0.02)
+ weight_init = DistributionGenerator.fan_in_gaussian() #dist.fan_in_gaussian() # dist.gaussian(mu=0., sigma=0.05) # 0.02)
Wln_mu = jnp.zeros((1, flat_input_dim))
Wln_scale = jnp.ones((1, flat_input_dim))
- W = dist.initialize_params(subkeys[0], weight_init, (flat_input_dim, out_dim))
+ W = weight_init((flat_input_dim, out_dim), subkeys[0]) #dist.initialize_params(subkeys[0], weight_init, (flat_input_dim, out_dim))
b = jnp.zeros((1, out_dim))
self.probe_params = [Wln_mu, Wln_scale, W, b]
diff --git a/ngclearn/utils/data_loader.py b/ngclearn/utils/data_loader.py
index 8bb1cf88..e90df8d9 100644
--- a/ngclearn/utils/data_loader.py
+++ b/ngclearn/utils/data_loader.py
@@ -6,24 +6,24 @@
class DataLoader(object):
"""
- A data loader object, meant to allow sampling w/o replacement of one or
- more named design matrices. Note that this object is iterable (and
- implements an __iter__() method).
+ A data loader object, meant to allow sampling w/o replacement of one or
+ more named design matrices. Note that this object is iterable (and
+ implements an __iter__() method).
- Args:
- design_matrices: list of named data design matrices - [("name", matrix), ...]
+ Args:
+ design_matrices: list of named data design matrices - [("name", matrix), ...]
- batch_size: number of samples to place inside a mini-batch
+ batch_size: number of samples to place inside a mini-batch
- disable_shuffle: if True, turns off sample shuffling (thus no sampling w/o replacement)
+ disable_shuffle: if True, turns off sample shuffling (thus no sampling w/o replacement)
- ensure_equal_batches: if True, ensures sampled batches are equal in size (Default = True).
- Note that this means the very last batch, if it's not the same size as the rest, will
- reuse random samples from previously seen batches (yielding a batch with a mix of
- vectors sampled with and without replacement).
+ ensure_equal_batches: if True, ensures sampled batches are equal in size (Default = True).
+ Note that this means the very last batch, if it's not the same size as the rest, will
+ reuse random samples from previously seen batches (yielding a batch with a mix of
+ vectors sampled with and without replacement).
- key: PRNG key to control determinism of any underlying random values
- associated with this synaptic cable
+ key: PRNG key to control determinism of any underlying random values
+ associated with this synaptic cable
"""
def __init__(self, design_matrices, batch_size, disable_shuffle=False,
ensure_equal_batches=True, key=None):
@@ -47,7 +47,7 @@ def __init__(self, design_matrices, batch_size, disable_shuffle=False,
def __iter__(self):
"""
- Yields a mini-batch of the form: [("name", batch),("name",batch),...]
+ Yields a mini-batch of the form: [("name", batch),("name",batch),...]
"""
if self.disable_shuffle == False:
self.key, *subkeys = random.split(self.key, 2)
diff --git a/ngclearn/utils/density/__init__.py b/ngclearn/utils/density/__init__.py
index e69de29b..f5ebbc56 100644
--- a/ngclearn/utils/density/__init__.py
+++ b/ngclearn/utils/density/__init__.py
@@ -0,0 +1,6 @@
+from .mixture import Mixture ## general mixture template parent class
+## point to supported density estimator models
+from .gaussianMixture import GaussianMixture ## mixture-of-Gaussians
+from .bernoulliMixture import BernoulliMixture ## mixture-of-Bernoullis
+from .exponentialMixture import ExponentialMixture ## mixture-of-exponentials
+
diff --git a/ngclearn/utils/density/bernoulliMixture.py b/ngclearn/utils/density/bernoulliMixture.py
new file mode 100644
index 00000000..7957aea6
--- /dev/null
+++ b/ngclearn/utils/density/bernoulliMixture.py
@@ -0,0 +1,221 @@
+from jax import numpy as jnp, random, jit, scipy
+from functools import partial
+import time, sys
+import numpy as np
+
+from ngclearn.utils.density.mixture import Mixture
+
+########################################################################################################################
+## internal routines for mixture model
+########################################################################################################################
+
+@jit
+def _log_bernoulli_pdf(X, p):
+ """
+ Calculates the multivariate Bernoulli log likelihood of a design matrix/dataset `X`, under a given parameter
+ probability `p`.
+
+ Args:
+ X: a design matrix (dataset) to compute the log likelihood of
+
+ p: a parameter mean vector (positive case probability)
+
+ Returns:
+ the log likelihood (scalar) of this design matrix X
+ """
+ #D = X.shape[1] * 1. ## get dimensionality
+ ## general format: x log(mu_k) + (1-x) log(1 - mu_k)
+ vec_ll = X * jnp.log(p) + (1. - X) * jnp.log(1. - p) ## binary cross-entropy (log Bernoulli)
+ log_ll = jnp.sum(vec_ll, axis=1, keepdims=True) ## get per-datapoint LL
+ return log_ll
+
+@jit
+def _calc_bernoulli_pdf_vals(X, p):
+ log_ll = _log_bernoulli_pdf(X, p) ## get log-likelihood
+ ll = jnp.exp(log_ll) ## likelihood
+ return log_ll, ll
+
+@jit
+def _calc_bernoulli_mixture_stats(raw_likeli, pi):
+ likeli = raw_likeli * pi
+ gamma = likeli / jnp.sum(likeli, axis=1, keepdims=True) ## responsibilities
+    likeli = jnp.sum(likeli, axis=1, keepdims=True) ## Sum_j[ pi_j * pdf_bern(x_n; mu_j) ]
+ log_likeli = jnp.log(likeli) ## vector of individual log p(x_n) values
+ complete_log_likeli = jnp.sum(log_likeli) ## complete log-likelihood for design matrix X, i.e., log p(X)
+ return log_likeli, complete_log_likeli, gamma
+
+@jit
+def _calc_priors_and_means(X, weights, pi): ## M-step co-routine
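+    ## closed-form M-step for the Bernoulli mixture (with `weights` = responsibilities gamma):
+    ##   pi_j = (1/N) * sum_n gamma_{nj}
+    ##   mu_j = sum_n gamma_{nj} * x_n / sum_n gamma_{nj}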
+ ## calc new means, responsibilities, and priors given current stats
+ N = X.shape[0] ## get number of samples
+ ## calc responsibilities
+ _pi = jnp.sum(weights, axis=0, keepdims=True) / N ## calc new priors
+ ## calc weighted means (weighted by responsibilities)
+ Z = jnp.sum(weights, axis=0, keepdims=True) ## partition function
+ M = (Z > 0.) * 1.
+    Z = Z * M + (1. - M) ## set zero-mass entries of the partition to 1 to avoid division by zero
+ means = jnp.matmul(weights.T, X) / Z.T
+ return _pi, means
+
+@partial(jit, static_argnums=[1])
+def _sample_prior_weights(dkey, n_samples, pi): ## samples prior weighting parameters (of mixture)
+ log_pi = jnp.log(pi) ## calc log(prior)
+ lats = random.categorical(dkey, logits=log_pi, shape=(n_samples, 1)) ## sample components/latents
+ return lats
+
+@partial(jit, static_argnums=[1])
+def _sample_component(dkey, n_samples, mu): ## samples a component (of mixture)
+ x_s = random.bernoulli(dkey, p=mu, shape=(n_samples, mu.shape[1])) ## draw Bernoulli samples
+ return x_s
+
+########################################################################################################################
+
+class BernoulliMixture(Mixture): ## Bernoulli mixture model (mixture-of-Bernoullis)
+ """
+ Implements a Bernoulli mixture model (BMM) -- or mixture of Bernoullis (MoB).
+ Adaptation of parameters is conducted via the Expectation-Maximization (EM)
+ learning algorithm. Note that this Bernoulli mixture assumes that each component
+    is a factorizable multivariate Bernoulli distribution. (A Categorical distribution
+ is assumed over the latent variables).
+
+ Args:
+ K: the number of components/latent variables within this BMM
+
+ max_iter: the maximum number of EM iterations to fit parameters to data (Default = 50)
+
+        init_kmeans: if True, initialize component parameters via a k-means pre-fit (currently unsupported)
+ """
+
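+    # A minimal usage sketch (assumes a hypothetical binary design matrix `X` of shape (N, D)):
+    #   bmm = BernoulliMixture(K=10, max_iter=25)
+    #   bmm.init(X)
+    #   bmm.fit(X, tol=1e-3, verbose=True)
+    #   log_ll_vec, total_ll = bmm.calc_log_likelihood(X)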
+ def __init__(self, K, max_iter=50, init_kmeans=False, key=None, **kwargs):
+ super().__init__(K, max_iter, **kwargs)
+ self.K = K
+ self.max_iter = int(max_iter)
+ self.init_kmeans = init_kmeans ## Unsupported currently
+ self.mu = [] ## component mean parameters
+ self.pi = None ## prior weight parameters
+ #self.z_weights = None # variables for parameterizing weights for SGD
+ self.key = random.PRNGKey(time.time_ns()) if key is None else key
+
+ def init(self, X):
+ """
+ Initializes this BMM in accordance to a supplied design matrix.
+
+ Args:
+ X: the design matrix to initialize this BMM to
+
+ """
+ dim = X.shape[1]
+ self.key, *skey = random.split(self.key, 3)
+ self.pi = jnp.ones((1, self.K)) / (self.K * 1.)
+ ptrs = random.permutation(skey[0], X.shape[0])
+ for j in range(self.K):
+ ptr = ptrs[j]
+ self.key, *skey = random.split(self.key, 3)
+ #self.mu.append(X[ptr:ptr+1,:] * 0 + (1./(dim * 1.)))
+ eps = random.uniform(skey[0], minval=0., maxval=0.9, shape=(1, dim)) ## jitter initial prob params
+ self.mu.append(eps)
+
+ def calc_log_likelihood(self, X):
+ """
+ Calculates the multivariate Bernoulli log likelihood of a design matrix/dataset `X`, under the current
+ parameters of this Bernoulli mixture.
+
+ Args:
+ X: the design matrix to estimate log likelihood values over under this BMM
+
+ Returns:
+ (column) vector of individual log likelihoods, scalar for the complete log likelihood p(X)
+ """
+ likeli = []
+ for j in range(self.K):
+ _, likeli_j = _calc_bernoulli_pdf_vals(X, self.mu[j])
+ likeli.append(likeli_j)
+ likeli = jnp.concat(likeli, axis=1)
+ log_likeli_vec, complete_log_likeli, gamma = _calc_bernoulli_mixture_stats(likeli, self.pi)
+ return log_likeli_vec, complete_log_likeli
+
+ def _E_step(self, X): ## Expectation (E) step, co-routine
+ likeli = []
+ for j in range(self.K):
+ _, likeli_j = _calc_bernoulli_pdf_vals(X, self.mu[j])
+ likeli.append(likeli_j)
+ likeli = jnp.concat(likeli, axis=1)
+ log_likeli_vec, complete_log_likeli, gamma = _calc_bernoulli_mixture_stats(likeli, self.pi)
+ ## gamma => ## data-dependent weights (responsibilities)
+ return gamma, log_likeli_vec, complete_log_likeli
+
+ def _M_step(self, X, weights): ## Maximization (M) step, co-routine
+ pi, means = _calc_priors_and_means(X, weights, self.pi)
+ self.pi = pi ## store new prior parameters
+ for j in range(self.K):
+ #r_j = weights[:, j:j + 1] ## get j-th responsibility slice
+ mu_j = means[j:j + 1, :]
+ self.mu[j] = mu_j ## store new mean(j) parameter
+ return pi, means
+
+ def fit(self, X, tol=1e-3, verbose=False):
+ """
+ Run full fitting process of this BMM.
+
+ Args:
+ X: the dataset to fit this BMM to
+
+ tol: the tolerance value for detecting convergence (via difference-of-means); will engage in early-stopping
+ if tol >= 0. (Default: 1e-3)
+
+ verbose: if True, this function will print out per-iteration measurements to I/O
+ """
+ means_prev = jnp.concat(self.mu, axis=0)
+ for i in range(self.max_iter):
+ gamma, pi, means, complete_loglikeli = self.update(X) ## carry out one E-step followed by an M-step
+ #means = jnp.concat(self.mu, axis=0)
+ dom = jnp.linalg.norm(means - means_prev) ## norm of difference-of-means
+ if verbose:
+ print(f"{i}: Mean-diff = {dom} log(p(X)) = {complete_loglikeli} nats")
+ #print(jnp.linalg.norm(means - means_prev))
+ if tol >= 0. and dom < tol:
+ print(f"Converged after {i + 1} iterations.")
+ break
+ means_prev = means
+
+ def update(self, X):
+ """
+ Performs a single iterative update (E-step followed by M-step) of parameters (assuming model initialized)
+
+ Args:
+ X: the dataset / design matrix to fit this BMM to
+ """
+ gamma, _, complete_likeli = self._E_step(X) ## carry out E-step
+ pi, means = self._M_step(X, gamma) ## carry out M-step
+ return gamma, pi, means, complete_likeli
+
+ def sample(self, n_samples, mode_j=-1):
+ """
+ Draw samples from the current underlying BMM model
+
+ Args:
+ n_samples: the number of samples to draw from this BMM
+
+ mode_j: if >= 0, will only draw samples from that specific component of this BMM,
+ ignoring the Categorical prior over latent variables/components (Default = -1)
+
+ Returns:
+ Design matrix of samples drawn under the distribution defined by this BMM
+ """
+ self.key, *skey = random.split(self.key, 3)
+ if mode_j >= 0: ## sample from a particular mode
+ mu_j = self.mu[mode_j] ## directly select a specific component
+ Xs = _sample_component(skey[0], n_samples=n_samples, mu=mu_j)
+ else: ## sample from full mixture distribution
+ ## sample (prior) components/latents
+ lats = _sample_prior_weights(skey[0], n_samples=n_samples, pi=self.pi)
+ ## then sample chosen component Bernoulli(s)
+ Xs = []
+ for j in range(self.K):
+ freq_j = int(jnp.sum((lats == j))) ## compute frequency over mode
+ self.key, *skey = random.split(self.key, 3)
+ x_s = _sample_component(skey[0], n_samples=freq_j, mu=self.mu[j])
+ Xs.append(x_s)
+ Xs = jnp.concat(Xs, axis=0)
+ return Xs
+
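+## ------------------------------------------------------------------------------------------------------------------
+## Illustrative usage sketch (not part of the class above; the key and synthetic binary data below are hypothetical):
+## a BMM is initialized on a design matrix, fit via EM, scored, and then sampled from.
+# from jax import numpy as jnp, random
+# key = random.PRNGKey(1234)
+# X_bin = (random.uniform(key, shape=(512, 16)) > 0.5) * 1.  ## synthetic binary design matrix
+# bmm = BernoulliMixture(K=4, max_iter=50, key=key)
+# bmm.init(X_bin)                                    ## data-dependent initialization of pi and mu
+# bmm.fit(X_bin, tol=1e-3, verbose=True)             ## EM fitting w/ early stopping on difference-of-means
+# log_ll_vec, log_pX = bmm.calc_log_likelihood(X_bin)
+# samples = bmm.sample(n_samples=25)                 ## draw from the full mixture
+# mode_samples = bmm.sample(n_samples=25, mode_j=0)  ## draw only from component 0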
diff --git a/ngclearn/utils/density/exponentialMixture.py b/ngclearn/utils/density/exponentialMixture.py
new file mode 100644
index 00000000..f718d57f
--- /dev/null
+++ b/ngclearn/utils/density/exponentialMixture.py
@@ -0,0 +1,216 @@
+from jax import numpy as jnp, random, jit, scipy
+from functools import partial
+import time, sys
+
+from ngclearn.utils.density.mixture import Mixture
+
+########################################################################################################################
+## internal routines for mixture model
+########################################################################################################################
+@jit
+def _log_exponential_pdf(X, lmbda):
+ """
+ Calculates the multivariate exponential log likelihood of a design matrix/dataset `X`, under a given parameter
+ rate vector `lmbda`.
+
+ Args:
+ X: a design matrix (dataset) to compute the log likelihood of
+
+ lmbda: a parameter rate vector
+
+ Returns:
+ the log likelihood (scalar) of this design matrix X
+ """
+ log_pdf = -jnp.matmul(X, lmbda.T) + jnp.sum(jnp.log(lmbda.T), axis=0)
+ return log_pdf
+
+@jit
+def _calc_exponential_mixture_stats(X, lmbda, pi):
+ log_exp_pdf = _log_exponential_pdf(X, lmbda)
+ log_likeli = log_exp_pdf + jnp.log(pi) ## raw log-likelihood
+ likeli = jnp.exp(log_likeli) ## raw likelihood
+ gamma = likeli / jnp.sum(likeli, axis=1, keepdims=True) ## responsibilities
+ weighted_log_likeli = jnp.sum(log_likeli * gamma, axis=1, keepdims=True) ## get weighted EMM log-likelihood
+ complete_loglikeli = jnp.sum(weighted_log_likeli) ## complete log-likelihood for design matrix X, i.e., log p(X)
+ return log_likeli, likeli, gamma, weighted_log_likeli, complete_loglikeli
+
+@jit
+def _calc_priors_and_rates(X, weights, pi): ## M-step co-routine
+ ## compute updates to pi params
+ Zk = jnp.sum(weights, axis=0, keepdims=True) ## summed weights/responsibilities; 1 x K
+ Z = jnp.sum(Zk) ## partition function
+ pi = Zk / Z
+ ## compute updates to lmbda params
+ Z = jnp.matmul(weights.T, X)
+ lmbda = Zk.T / Z
+ return pi, lmbda
+
+@partial(jit, static_argnums=[1])
+def _sample_prior_weights(dkey, n_samples, pi): ## samples prior weighting parameters (of mixture)
+ log_pi = jnp.log(pi) ## calc log(prior)
+ lats = random.categorical(dkey, logits=log_pi, shape=(n_samples, 1)) ## sample components/latents
+ return lats
+
+@partial(jit, static_argnums=[1])
+def _sample_component(dkey, n_samples, rate): ## samples a component (of mixture)
+ ## sampling ~[exp(rx)] is same as r * [~exp(x)]
+ x_s = random.exponential(dkey, shape=(n_samples, rate.shape[1])) * rate ## draw exponential samples
+ return x_s
+
+########################################################################################################################
+
+class ExponentialMixture(Mixture): ## Exponential mixture model (mixture-of-exponentials)
+ """
+ Implements an exponential mixture model (EMM) -- or mixture of exponentials (MoExp). Adaptation of parameters is
+ conducted via the Expectation-Maximization (EM) learning algorithm. Note that this exponential mixture assumes that
+ each component is a factorizable multivariate exponential distribution. (A Categorical distribution is assumed over
+ the latent variables).
+
+ The exponential distribution of each component (dimension `d`) is assumed to be:
+
+ | pdf(x_d; lmbda_d) = lmbda_d * exp(-lmbda_d x_d) for x_d >= 0, else 0 for x_d < 0;
+ | where lmbda is the rate parameter vector
+
+ Args:
+ K: the number of components/latent variables within this EMM
+
+ max_iter: the maximum number of EM iterations to fit parameters to data (Default = 50)
+
+ init_kmeans: if True, use the K-Means algorithm to initialize this EMM's components (currently unsupported)
+ """
+
+ def __init__(self, K, max_iter=50, init_kmeans=False, key=None, **kwargs):
+ super().__init__(K, max_iter, **kwargs)
+ self.K = K
+ self.max_iter = int(max_iter)
+ self.init_kmeans = init_kmeans ## Unsupported currently
+ self.rate = [] ## component rate parameters
+ self.pi = None ## prior weight parameters
+ #self.z_weights = None # variables for parameterizing weights for SGD
+ self.key = random.PRNGKey(time.time_ns()) if key is None else key
+
+ def init(self, X):
+ """
+ Initializes this EMM in accordance to a supplied design matrix.
+
+ Args:
+ X: the design matrix to initialize this EMM to
+
+ """
+ dim = X.shape[1]
+ self.key, *skey = random.split(self.key, 4)
+ ## Compute jittered initial pi (prior) param values
+ #self.pi = jnp.ones((1, self.K)) / (self.K * 1.)
+ pi = jnp.ones((1, self.K))
+ eps = random.uniform(skey[0], minval=0.99, maxval=1.01, shape=(1, self.K))
+ pi = pi * eps
+ self.pi = pi / jnp.sum(pi)
+
+ ## Compute jittered initial rate (lmbda) param values
+ lmbda_h = 1.0/jnp.mean(X, axis=0, keepdims=True)
+ lmbda = random.uniform(skey[1], minval=0.99, maxval=1.01, shape=(self.K, dim)) * lmbda_h
+ self.rate = []
+ for j in range(self.K): ## set rates/lmbdas
+ self.rate.append(lmbda[j:j+1, :])
+
+ def calc_log_likelihood(self, X):
+ """
+ Calculates the multivariate exponential log likelihood of a design matrix/dataset `X`, under the current
+ parameters of this exponential mixture.
+
+ Args:
+ X: the design matrix to estimate log likelihood values over under this EMM
+
+ Returns:
+ (column) vector of individual log likelihoods, scalar for the complete log likelihood p(X)
+ """
+ pi = self.pi ## get prior weight values
+ lmbda = jnp.concat(self.rate, axis=0) ## get rates as a block matrix
+ ## compute relevant log-likelihoods/likelihoods
+ log_ll, ll, gamma, weighted_loglikeli, complete_likeli = _calc_exponential_mixture_stats(X, lmbda, pi)
+ return weighted_loglikeli, complete_likeli
+
+ def _E_step(self, X): ## Expectation (E) step, co-routine
+ pi = self.pi ## get prior weight values
+ lmbda = jnp.concat(self.rate, axis=0) ## get rates as a block matrix
+ _, _, gamma, weighted_loglikeli, complete_likeli = _calc_exponential_mixture_stats(X, lmbda, pi)
+ ## Note: responsibility weights gamma have shape => N x K
+ return gamma, weighted_loglikeli, complete_likeli
+
+ def _M_step(self, X, weights): ## Maximization (M) step, co-routine
+ ## compute updates to pi and lmbda params
+ pi, lmbda = _calc_priors_and_rates(X, weights, self.pi)
+ self.pi = pi ## store new prior parameters
+ for j in range(self.K): ## store new rate/lmbda parameters
+ self.rate[j] = lmbda[j:j+1, :]
+ return pi, lmbda
+
+ def fit(self, X, tol=1e-3, verbose=False):
+ """
+ Run full fitting process of this EMM.
+
+ Args:
+ X: the dataset to fit this EMM to
+
+ tol: the tolerance value for detecting convergence (via difference-of-rates); will engage in early-stopping
+ if tol >= 0. (Default: 1e-3)
+
+ verbose: if True, this function will print out per-iteration measurements to I/O
+ """
+ rates_prev = jnp.concat(self.rate, axis=0)
+ for i in range(self.max_iter):
+ gamma, pi, rates, complete_loglikeli = self.update(X) ## carry out one E-step followed by an M-step
+ #rates = jnp.concat(self.rate, axis=0)
+ dor = jnp.linalg.norm(rates - rates_prev) ## norm of difference-of-rates
+ if verbose:
+ print(f"{i}: Rate-diff = {dor} log(p(X)) = {complete_loglikeli} nats")
+ #print(jnp.linalg.norm(rates - rates_prev))
+ if tol >= 0. and dor < tol:
+ print(f"Converged after {i + 1} iterations.")
+ break
+ rates_prev = rates
+
+ def update(self, X):
+ """
+ Performs a single iterative update (E-step followed by M-step) of parameters (assuming model initialized)
+
+ Args:
+ X: the dataset / design matrix to fit this EMM to
+
+ Returns:
+ responsibilities (gamma), priors (pi), rates (lambda), EMM log-likelihood
+ """
+ gamma, _, complete_log_likeli = self._E_step(X) ## carry out E-step
+ pi, rates = self._M_step(X, gamma) ## carry out M-step
+ return gamma, pi, rates, complete_log_likeli
+
+ def sample(self, n_samples, mode_j=-1):
+ """
+ Draw samples from the current underlying EMM model
+
+ Args:
+ n_samples: the number of samples to draw from this EMM
+
+ mode_j: if >= 0, will only draw samples from that specific component of this EMM,
+ ignoring the Categorical prior over latent variables/components (Default = -1)
+
+ Returns:
+ Design matrix of samples drawn under the distribution defined by this EMM
+ """
+ self.key, *skey = random.split(self.key, 3)
+ if mode_j >= 0: ## sample from a particular mode
+ rate_j = self.rate[mode_j] ## directly select a specific component
+ Xs = _sample_component(skey[0], n_samples=n_samples, rate=rate_j)
+ else: ## sample from full mixture distribution
+ ## sample (prior) components/latents
+ lats = _sample_prior_weights(skey[0], n_samples=n_samples, pi=self.pi)
+ ## then sample chosen component exponential(s)
+ Xs = []
+ for j in range(self.K):
+ freq_j = int(jnp.sum((lats == j))) ## compute frequency over mode
+ self.key, *skey = random.split(self.key, 3)
+ x_s = _sample_component(skey[0], n_samples=freq_j, rate=self.rate[j])
+ Xs.append(x_s)
+ Xs = jnp.concat(Xs, axis=0)
+ return Xs
+
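+## ------------------------------------------------------------------------------------------------------------------
+## Illustrative usage sketch (not part of the class above; the key and synthetic non-negative data are hypothetical):
+# from jax import random
+# key = random.PRNGKey(42)
+# X_pos = random.exponential(key, shape=(512, 8)) * 2.  ## synthetic non-negative design matrix
+# emm = ExponentialMixture(K=3, max_iter=50, key=key)
+# emm.init(X_pos)                          ## jittered initialization of pi and rate (lmbda) parameters
+# emm.fit(X_pos, tol=1e-3, verbose=True)   ## EM fitting w/ early stopping on difference-of-rates
+# weighted_ll, total_ll = emm.calc_log_likelihood(X_pos)
+# Xs = emm.sample(n_samples=16)            ## draw from the full mixture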
diff --git a/ngclearn/utils/density/gaussianMixture.py b/ngclearn/utils/density/gaussianMixture.py
new file mode 100644
index 00000000..506f2032
--- /dev/null
+++ b/ngclearn/utils/density/gaussianMixture.py
@@ -0,0 +1,277 @@
+from jax import numpy as jnp, random, jit, scipy
+from functools import partial
+import time, sys
+import numpy as np
+
+from ngclearn.utils.density.mixture import Mixture
+
+########################################################################################################################
+## internal routines for mixture model
+########################################################################################################################
+
+@partial(jit, static_argnums=[3])
+def _log_gaussian_pdf(X, mu, Sigma, use_chol_prec=True):
+ """
+ Calculates the multivariate Gaussian log likelihood of a design matrix/dataset `X`, under a given parameter mean
+ `mu` and parameter covariance `Sigma`.
+
+ Args:
+ X: a design matrix (dataset) to compute the log likelihood of
+ mu: a parameter mean vector
+ Sigma: a parameter covariance matrix
+ use_chol_prec: should this routine use Cholesky-factor computation of the precision (Default: True)
+
+ Returns:
+ the log likelihood (scalar) of this design matrix X
+ """
+ D = mu.shape[1] * 1. ## get dimensionality
+ if use_chol_prec: ## use Cholesky-factor calc of precision
+ C = jnp.linalg.cholesky(Sigma) # calc_prec_chol(mu, cov)
+ inv_C = jnp.linalg.pinv(C)
+ precision = jnp.matmul(inv_C.T, inv_C)
+ else: ## use Moore-Penrose pseudo-inverse calc of precision
+ precision = jnp.linalg.pinv(Sigma)
+ ## finish computing log-likelihood
+ sign_ld, abs_ld = jnp.linalg.slogdet(Sigma)
+ log_det_sigma = abs_ld * sign_ld ## log-determinant of the covariance Sigma
+ Z = X - mu ## calc deltas
+ quad_term = jnp.sum((jnp.matmul(Z, precision) * Z), axis=1, keepdims=True) ## LL quadratic term
+ return -(jnp.log(2. * np.pi) * D + log_det_sigma + quad_term) * 0.5
+
+@partial(jit, static_argnums=[3])
+def _calc_gaussian_pdf_vals(X, mu, Sigma, use_chol_prec=True):
+ log_likeli = _log_gaussian_pdf(X, mu, Sigma, use_chol_prec)
+ likeli = jnp.exp(log_likeli)
+ return log_likeli, likeli
+
+@jit
+def _calc_gaussian_mixture_stats(raw_likeli, pi):
+ likeli = raw_likeli * pi
+ gamma = likeli / jnp.sum(likeli, axis=1, keepdims=True) ## responsibilities
+ likeli = jnp.sum(likeli, axis=1, keepdims=True) ## Sum_j[ pi_j * pdf_gauss(x_n; mu_j, Sigma_j) ]
+ log_likeli = jnp.log(likeli) ## vector of individual log p(x_n) values
+ complete_log_likeli = jnp.sum(log_likeli) ## complete log-likelihood for design matrix X, i.e., log p(X)
+ return log_likeli, complete_log_likeli, gamma
+
+@partial(jit, static_argnums=[3])
+def _calc_weighted_cov(X, mu, weights, assume_diag_cov=False): ## M-step co-routine
+ ## calc new covariance Sigma given data, means, and responsibilities
+ diff = X - mu
+ sigma_j = jnp.matmul((weights * diff).T, diff) / jnp.sum(weights)
+ if assume_diag_cov:
+ sigma_j = sigma_j * jnp.eye(sigma_j.shape[1])
+ return sigma_j
+
+@jit
+def _calc_priors_and_means(X, weights, pi): ## M-step co-routine
+ ## calc new means, responsibilities, and priors given current stats
+ N = X.shape[0] ## get number of samples
+ ## calc responsibilities
+ _pi = jnp.sum(weights, axis=0, keepdims=True) / N ## calc new priors
+ ## calc weighted means (weighted by responsibilities)
+ Z = jnp.sum(weights, axis=0, keepdims=True) ## partition function
+ M = (Z > 0.) * 1.
+ Z = Z * M + (1. - M) ## removes div-by-0 cases (Z kept where Z > 0, set to 1 where Z == 0)
+ means = jnp.matmul(weights.T, X) / Z.T
+ return _pi, means
+
+@partial(jit, static_argnums=[1])
+def _sample_prior_weights(dkey, n_samples, pi): ## samples prior weighting parameters (of mixture)
+ log_pi = jnp.log(pi) ## calc log(prior)
+ lats = random.categorical(dkey, logits=log_pi, shape=(n_samples, 1)) ## sample components/latents
+ return lats
+
+@partial(jit, static_argnums=[1, 4])
+def _sample_component(dkey, n_samples, mu, Sigma, assume_diag_cov=False): ## samples a component (of mixture)
+ eps = random.normal(dkey, shape=(n_samples, mu.shape[1])) ## draw unit Gaussian noise
+ ## apply scale-shift transformation
+ if assume_diag_cov:
+ R = jnp.sum(jnp.sqrt(Sigma), axis=0, keepdims=True)
+ x_s = mu + eps * R
+ else:
+ R = jnp.linalg.cholesky(Sigma) ## decompose covariance via Cholesky
+ x_s = mu + jnp.matmul(eps, R.T) ## x = mu + eps R^T, so Cov(x) = R R^T = Sigma
+ return x_s
+
+# def _log_gaussian_pdf(X, mu, sigma):
+# C = jnp.linalg.cholesky(sigma) #calc_prec_chol(mu, cov)
+# inv_C = jnp.linalg.pinv(C)
+# prec_chol = jnp.matmul(inv_C, inv_C.T)
+# #prec_chol = jnp.linalg.inv(sigma)
+#
+# N, D = X.shape ## n_samples x dimensionality
+# # det(precision_chol) is half of det(precision)
+# sign_ld, abs_ld = jnp.linalg.slogdet(prec_chol)
+# log_det = abs_ld * sign_ld ## log determinant of Cholesky precision
+# y = jnp.matmul(X, prec_chol) - jnp.matmul(mu, prec_chol)
+# log_prob = jnp.sum(y * y, axis=1, keepdims=True)
+# #return -0.5 * (D * jnp.log(np.pi * 2) + log_prob) + log_det
+# #return -0.5 * (D * jnp.log(np.pi * 2) + log_det + log_prob)
+# return -jnp.log(np.pi * 2) * (D * 0.5) - log_det * 0.5 - log_prob * 0.5
+
+########################################################################################################################
+
+class GaussianMixture(Mixture): ## Gaussian mixture model (mixture-of-Gaussians)
+ """
+ Implements a Gaussian mixture model (GMM) -- or mixture of Gaussians (MoG).
+ Adaptation of parameters is conducted via the Expectation-Maximization (EM)
+ learning algorithm and leverages full covariance matrices in the component
+ multivariate Gaussians. (A Categorical distribution is assumed over the
+ latent variables).
+
+ Args:
+ K: the number of components/latent variables within this GMM
+
+ max_iter: the maximum number of EM iterations to fit parameters to data (Default = 50)
+
+ assume_diag_cov: if True, assumes a diagonal covariance for each component (Default = False)
+
+ init_kmeans: if True, first use the K-Means algorithm to initialize the component Gaussians of this GMM
+ (currently unsupported; Default = False)
+ """
+
+ def __init__(self, K, max_iter=50, assume_diag_cov=False, init_kmeans=False, key=None, **kwargs):
+ super().__init__(K, max_iter, **kwargs)
+ self.K = K
+ self.max_iter = int(max_iter)
+ self.assume_diag_cov = assume_diag_cov
+ self.init_kmeans = init_kmeans ## Unsupported currently
+ self.mu = [] ## component mean parameters
+ self.Sigma = [] ## component covariance parameters
+ self.pi = None ## prior weight parameters
+ #self.z_weights = None # variables for parameterizing weights for SGD
+ self.key = random.PRNGKey(time.time_ns()) if key is None else key
+
+ def init(self, X):
+ """
+ Initializes this GMM in accordance to a supplied design matrix.
+
+ Args:
+ X: the design matrix to initialize this GMM to
+
+ """
+ dim = X.shape[1]
+ self.key, *skey = random.split(self.key, 3)
+ self.pi = jnp.ones((1, self.K)) / (self.K * 1.)
+ ptrs = random.permutation(skey[0], X.shape[0])
+ for j in range(self.K):
+ ptr = ptrs[j]
+ #self.key, *skey = random.split(self.key, 3)
+ self.mu.append(X[ptr:ptr+1,:])
+ Sigma_j = jnp.eye(dim)
+ #sigma_j = random.uniform(skey[0], minval=0.01, maxval=0.9, shape=(dim, dim))
+ self.Sigma.append(Sigma_j)
+
+ def calc_log_likelihood(self, X):
+ """
+ Calculates the multivariate Gaussian log likelihood of a design matrix/dataset `X`, under the current
+ parameters of this Gaussian mixture model.
+
+ Args:
+ X: the design matrix to estimate log likelihood values over under this GMM
+
+ Returns:
+ (column) vector of individual log likelihoods, scalar for the complete log likelihood p(X)
+ """
+ likeli = []
+ for j in range(self.K):
+ _, likeli_j = _calc_gaussian_pdf_vals(X, self.mu[j], self.Sigma[j])
+ likeli.append(likeli_j)
+ likeli = jnp.concat(likeli, axis=1)
+ log_likeli_vec, complete_log_likeli, gamma = _calc_gaussian_mixture_stats(likeli, self.pi)
+ return log_likeli_vec, complete_log_likeli
+
+ def _E_step(self, X): ## Expectation (E) step, co-routine
+ likeli = []
+ for j in range(self.K):
+ _, likeli_j = _calc_gaussian_pdf_vals(X, self.mu[j], self.Sigma[j])
+ likeli.append(likeli_j)
+ likeli = jnp.concat(likeli, axis=1)
+ log_likeli_vec, complete_log_likeli, gamma = _calc_gaussian_mixture_stats(likeli, self.pi)
+ ## gamma => ## data-dependent weights (responsibilities)
+ return gamma, log_likeli_vec, complete_log_likeli
+
+ def _M_step(self, X, weights): ## Maximization (M) step, co-routine
+ pi, means = _calc_priors_and_means(X, weights, self.pi)
+ self.pi = pi ## store new prior parameters
+ # calc weighted covariances
+ for j in range(self.K):
+ r_j = weights[:, j:j + 1] ## get j-th responsibility slice
+ mu_j = means[j:j + 1, :]
+ sigma_j = _calc_weighted_cov(X, mu_j, r_j, assume_diag_cov=self.assume_diag_cov)
+ self.mu[j] = mu_j ## store new mean(j) parameter
+ self.Sigma[j] = sigma_j ## store new covariance(j) parameter
+ return pi, means
+
+ def fit(self, X, tol=1e-3, verbose=False):
+ """
+ Run full fitting process of this GMM.
+
+ Args:
+ X: the dataset to fit this GMM to
+
+ tol: the tolerance value for detecting convergence (via difference-of-means); will engage in early-stopping
+ if tol >= 0. (Default: 1e-3)
+
+ verbose: if True, this function will print out per-iteration measurements to I/O
+ """
+ means_prev = jnp.concat(self.mu, axis=0)
+ for i in range(self.max_iter):
+ gamma, pi, means, complete_loglikeli = self.update(X) ## carry out one E-step followed by an M-step
+ #means = jnp.concat(self.mu, axis=0)
+ dom = jnp.linalg.norm(means - means_prev) ## norm of difference-of-means
+ if verbose:
+ print(f"{i}: Mean-diff = {dom} log(p(X)) = {complete_loglikeli} nats")
+ #print(jnp.linalg.norm(means - means_prev))
+ if tol >= 0. and dom < tol:
+ print(f"Converged after {i + 1} iterations.")
+ break
+ means_prev = means
+
+ def update(self, X):
+ """
+ Performs a single iterative update (E-step followed by M-step) of parameters (assuming model initialized)
+
+ Args:
+ X: the dataset / design matrix to fit this GMM to
+ """
+ gamma, _, complete_likeli = self._E_step(X) ## carry out E-step
+ pi, means = self._M_step(X, gamma) ## carry out M-step
+ return gamma, pi, means, complete_likeli
+
+ def sample(self, n_samples, mode_j=-1):
+ """
+ Draw samples from the current underlying GMM model
+
+ Args:
+ n_samples: the number of samples to draw from this GMM
+
+ mode_j: if >= 0, will only draw samples from that specific component of this GMM,
+ ignoring the Categorical prior over latent variables/components (Default = -1)
+
+ Returns:
+ Design matrix of samples drawn under the distribution defined by this GMM
+ """
+ self.key, *skey = random.split(self.key, 3)
+ if mode_j >= 0: ## sample from a particular mode
+ mu_j = self.mu[mode_j] ## directly select a specific component
+ Sigma_j = self.Sigma[mode_j]
+ Xs = _sample_component(
+ skey[0], n_samples=n_samples, mu=mu_j, Sigma=Sigma_j, assume_diag_cov=self.assume_diag_cov
+ )
+ else: ## sample from full mixture distribution
+ ## sample (prior) components/latents
+ lats = _sample_prior_weights(skey[0], n_samples=n_samples, pi=self.pi)
+ ## then sample chosen component Gaussian(s)
+ Xs = []
+ for j in range(self.K):
+ freq_j = int(jnp.sum((lats == j))) ## compute frequency over mode
+ self.key, *skey = random.split(self.key, 3)
+ x_s = _sample_component( ## now physically sample component
+ skey[0], n_samples=freq_j, mu=self.mu[j], Sigma=self.Sigma[j], assume_diag_cov=self.assume_diag_cov
+ )
+ Xs.append(x_s)
+ Xs = jnp.concat(Xs, axis=0)
+ return Xs
+
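+## ------------------------------------------------------------------------------------------------------------------
+## Illustrative usage sketch (not part of the class above; the key and synthetic real-valued data are hypothetical):
+# from jax import random
+# key = random.PRNGKey(7)
+# X_real = random.normal(key, shape=(512, 2)) * 0.5 + 1.  ## synthetic real-valued design matrix
+# gmm = GaussianMixture(K=3, max_iter=50, assume_diag_cov=False, key=key)
+# gmm.init(X_real)                          ## means set to random data points, covariances to identity
+# gmm.fit(X_real, tol=1e-3, verbose=True)   ## EM fitting w/ early stopping on difference-of-means
+# log_ll_vec, log_pX = gmm.calc_log_likelihood(X_real)
+# Xs = gmm.sample(n_samples=64)             ## draw from the full mixture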
diff --git a/ngclearn/utils/density/gmm.py b/ngclearn/utils/density/gmm.py
deleted file mode 100644
index 6d2ed813..00000000
--- a/ngclearn/utils/density/gmm.py
+++ /dev/null
@@ -1,82 +0,0 @@
-from jax import numpy as jnp, random, jit
-from functools import partial
-import time, sys
-import numpy as np
-#from sklearn import mixture
-#from sklearn.cluster import KMeans
-from scipy.stats import multivariate_normal
-#from ngclearn.utils.stat_utils import calc_log_gauss_pdf
-from ngclearn.utils.model_utils import softmax
-#from kmeans import K_Means
-from sklearn import mixture
-
-#seed = 69
-#tf.random.set_seed(seed=seed)
-
-class GMM:
- """
- Implements a Gaussian mixture model (GMM) -- or mixture of Gaussians, MoG.
- Adaptation of parameters is conducted via the Expectation-Maximization (EM)
- learning algorithm and leverages full covariance matrices in the component
- multivariate Gaussians.
-
- Note this is a (JAX) wrapper model that houses the sklearn implementation for learning.
- The sampling process has been rewritten to utilize GPU matrix computation.
-
- Args:
- k: the number of components/latent variables within this GMM
-
- max_iter: the maximum number of EM iterations to fit parameters to data
- (Default = 5)
-
- assume_diag_cov: if True, assumes a diagonal covariance for each component
- (Default = False)
-
- init_kmeans: if True, first learn use the K-Means algorithm to initialize
- the component Gaussians of this GMM (Default = True)
- """
- def __init__(self, k, max_iter=5, assume_diag_cov=False, init_kmeans=True):
- self.use_sklearn = True
- self.k = k
- self.max_iter = int(max_iter)
- self.assume_diag_cov = assume_diag_cov
- self.init_kmeans = init_kmeans
- self.mu = []
- self.sigma = []
- self.prec = []
- self.weights = None
- self.z_weights = None # variables for parameterizing weights for SGD
-
- def fit(self, data):
- """
- Run full fitting process of this GMM.
-
- Args:
- data: the dataset to fit this GMM to
- """
- pass
-
- def update(self, X):
- """
- Performs a single iterative update of parameters (assuming model initialized)
-
- Args:
- X: the dataset / design matrix to fit this GMM to
- """
- pass
-
- def sample(self, n_s, mode_i=-1, samples_modes_evenly=False):
- """
- (Efficiently) Draw samples from the current underlying GMM model
-
- Args:
- n_s: the number of samples to draw from this GMM
-
- mode_i: if >= 0, will only draw samples from a specific component of this GMM
- (Default = -1), ignoring the Categorical prior over latent variables/components
-
- samples_modes_evenly: if True, will ignore the Categorical prior over latent
- variables/components and draw an approximately equal number of samples from
- each component
- """
- pass
diff --git a/ngclearn/utils/density/mixture.py b/ngclearn/utils/density/mixture.py
new file mode 100644
index 00000000..107df7c7
--- /dev/null
+++ b/ngclearn/utils/density/mixture.py
@@ -0,0 +1,33 @@
+
+
+class Mixture: ## General mixture structure
+ """
+ Implements a general mixture model template/structure. Effectively, this is the parent
+ class/template for mixtures of distributions.
+
+ Args:
+ K: the number of components/latent variables within this mixture model
+
+ max_iter: the maximum number of iterations to fit parameters to data (Default = 50)
+
+ """
+
+ def __init__(self, K, max_iter=50, **kwargs):
+ self.K = K
+ self.max_iter = max_iter
+
+ def init(self, X): ## model data-dependent initialization function
+ pass
+
+ def calc_log_likelihood(self, X): ## log-likelihood calculation routine
+ pass
+
+ def fit(self, X, tol=1e-3, verbose=False): ## outer fitting process
+ pass
+
+ def update(self, X): ## inner/iterative adjustment/update step
+ pass
+
+ def sample(self, n_samples, mode_j=-1): ## model sampling routine
+ pass
+
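+## ------------------------------------------------------------------------------------------------------------------
+## Illustrative sketch (hypothetical subclass, not part of this module): concrete mixtures, such as the Bernoulli,
+## exponential, and Gaussian mixtures in this package, subclass `Mixture` and override its routines.
+# class MyMixture(Mixture):
+#     def __init__(self, K, max_iter=50, **kwargs):
+#         super().__init__(K, max_iter, **kwargs)
+#         self.params = []  ## per-component parameters live here
+#
+#     def init(self, X):
+#         pass  ## set up self.params in accordance to the design matrix X
+#
+#     def update(self, X):
+#         pass  ## one E-step followed by one M-step
+#
+#     def fit(self, X, tol=1e-3, verbose=False):
+#         pass  ## repeatedly call update(X), stopping early once parameters move less than tol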
diff --git a/ngclearn/utils/diffeq/ode_utils.py b/ngclearn/utils/diffeq/ode_utils.py
index 30ddb2d4..55a70ace 100755
--- a/ngclearn/utils/diffeq/ode_utils.py
+++ b/ngclearn/utils/diffeq/ode_utils.py
@@ -1,12 +1,13 @@
"""
Routines and co-routines for ngc-learn's differential equation integration backend.
-Currently supported back-end forms of integration in ngc-learn include:
-0) Euler integration (RK-1);
-1) Midpoint method (RK-2);
-2) Heun's method (error-corrector RK-2);
-3) Ralston's method (error-corrector RK-2);
-4) 4th-order Runge-Kutta method (RK-4);
+| Currently supported back-end forms of integration in ngc-learn include:
+| 0) Euler integration (RK-1);
+| 1) Midpoint method (RK-2);
+| 2) Heun's method (error-corrector RK-2);
+| 3) Ralston's method (error-corrector RK-2);
+| 4) 4th-order Runge-Kutta method (RK-4);
+
"""
from jax import numpy as jnp, random, jit #, nn
diff --git a/ngclearn/utils/diffeq/odes.py b/ngclearn/utils/diffeq/odes.py
index b37e2408..733f1082 100644
--- a/ngclearn/utils/diffeq/odes.py
+++ b/ngclearn/utils/diffeq/odes.py
@@ -1,3 +1,16 @@
+"""
+In-built dynamical systems built on differential equations. Note that these systems are designed such that they
+directly operate with ngc-learn's ODE integration backend.
+
+| Currently in-built dynamical systems include:
+| 0) A continuous linear 2D system;
+| 1) A continuous cubic 2D system;
+| 2) A Lorenz attractor system;
+| 3) A continuous linear 3D system;
+| 4) A continuous oscillator system.
+
+"""
+
import jax.numpy as jnp
def linear_2D(t, x, params):
diff --git a/ngclearn/utils/distribution_generator.py b/ngclearn/utils/distribution_generator.py
new file mode 100644
index 00000000..0af4d6c6
--- /dev/null
+++ b/ngclearn/utils/distribution_generator.py
@@ -0,0 +1,383 @@
+import time
+from typing import TypedDict, List, Protocol, Sequence
+from typing_extensions import Unpack
+import jax
+import numpy
+
+from ngcsimlib.logger import error
+
+
+class DistributionParams(TypedDict, total=False):
+ """
+ Extra parameters to be used when generating distributions. (Attributes listed below)
+
+ Args:
+ amin: sets the lower bound of the distribution
+
+ amax: sets the upper bound of the distribution
+
+ lower_triangle: keeps the lower triangle, sets the rest to zero
+
+ upper_triangle: keeps the upper triangle, sets the rest to zero
+
+ hollow: produces a hollow distribution (zeros along the diagonal)
+
+ eye: produces an eye distribution (zeros the off-diagonal)
+
+ col_mask: single value, keeps n random columns; list values, keeps the provided column indices
+
+ row_mask: single value, keeps n random rows; list values, keeps the provided row indices
+
+ use_numpy: if True, generate using (seed-based) NumPy rather than JAX
+
+ dtype: the data type of the generated array (Default: float32)
+
+ """
+ amin: float
+ amax: float
+ lower_triangle: bool
+ upper_triangle: bool
+ hollow: bool
+ eye: bool
+ col_mask: int | List[int]
+ row_mask: int | List[int]
+ use_numpy: bool
+ dtype: numpy.dtype
+
+
+class DistributionInitializer(Protocol):
+ def __call__(self, shape: Sequence[int], dkey: jax.dtypes.prng_key | int | None = None) -> jax.Array: ...
+
+
+class DistributionGenerator(object):
+ @staticmethod
+ def constant(value: float, **params: Unpack[DistributionParams]) -> DistributionInitializer:
+ """
+ Produces a distribution initializer for a constant distribution.
+
+ Args:
+ value: the constant value to fill the array with
+ **params: the extra distribution parameters
+
+ Returns:
+ a distribution initializer
+ """
+ using_np = params.get("use_numpy", False)
+ if using_np:
+ def constant_generator(shape: Sequence[int], seed: int | None = None) -> numpy.ndarray:
+ matrix = numpy.ones(shape, params.get("dtype", numpy.float32)) * value
+ matrix = DistributionGenerator._process_params_numpy(matrix, params, seed)
+ return matrix
+ else:
+ def constant_generator(shape: Sequence[int], dKey: jax.dtypes.prng_key | None = None) -> jax.Array:
+ matrix = jax.numpy.ones(shape, params.get("dtype", jax.numpy.float32)) * value
+ matrix = DistributionGenerator._process_params_jax(matrix, params, dKey)
+ return matrix
+ return constant_generator
+
+ @staticmethod
+ def uniform(low: float = 0.0, high: float = 1.0, **params: Unpack[DistributionParams]) -> DistributionInitializer:
+ """
+ Produces a distribution initializer for a uniform distribution.
+
+ Args:
+ low: lower bound of the uniform distribution (inclusive)
+ high: upper bound of the uniform distribution (exclusive)
+ **params: the extra distribution parameters
+
+ Returns:
+ a distribution initializer
+ """
+ using_np = params.get("use_numpy", False)
+
+ if using_np:
+ def uniform_generator(shape: Sequence[int], seed: int | None = None) -> numpy.ndarray:
+ rng = numpy.random.default_rng(seed)
+ matrix = rng.uniform(low=low, high=high, size=shape).astype(
+ params.get("dtype", numpy.float32))
+ matrix = DistributionGenerator._process_params_numpy(matrix, params, seed)
+ return matrix
+ else:
+ def uniform_generator(shape: Sequence[int], dKey: jax.Array | None = None) -> jax.Array:
+ if dKey is None:
+ dKey = jax.random.PRNGKey(time.time_ns())
+ dKey, subKey = jax.random.split(dKey, 2)
+
+ matrix = jax.random.uniform(
+ dKey,
+ shape=shape,
+ minval=low,
+ maxval=high,
+ dtype=params.get("dtype", jax.numpy.float32)
+ )
+ matrix = DistributionGenerator._process_params_jax(matrix, params, subKey)
+ return matrix
+
+ return uniform_generator
+
+ @staticmethod
+ def gaussian(mean: float = 0.0, std: float = 1.0, **params: Unpack[DistributionParams]) -> DistributionInitializer:
+ """
+ Produces a distribution initializer for a Gaussian (normal) distribution.
+
+ Args:
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ **params: the extra distribution parameters
+
+ Returns:
+ a distribution initializer
+ """
+ using_numpy = params.get("use_numpy", False)
+
+ if using_numpy:
+ def gaussian_generator(shape: Sequence[int], seed: int | None = None) -> numpy.ndarray:
+ rng = numpy.random.default_rng(seed)
+ matrix = rng.normal(loc=mean, scale=std, size=shape).astype(
+ params.get("dtype", numpy.float32))
+ matrix = DistributionGenerator._process_params_numpy(matrix, params, seed)
+ return matrix
+ else:
+ def gaussian_generator(shape: Sequence[int], dKey: jax.Array | None = None) -> jax.Array:
+ if dKey is None:
+ dKey = jax.random.PRNGKey(time.time_ns())
+ dKey, subKey = jax.random.split(dKey, 2)
+ matrix = jax.random.normal(
+ dKey,
+ shape=shape,
+ dtype=params.get("dtype", jax.numpy.float32)
+ )
+ matrix = mean + std * matrix
+ matrix = DistributionGenerator._process_params_jax(matrix, params, subKey)
+ return matrix
+
+ return gaussian_generator
+
+ @staticmethod
+ def fan_in_uniform(**params: Unpack[DistributionParams]) -> DistributionInitializer:
+ """
+ Produces a distribution initializer using a fan-in uniform strategy.
+ The values are sampled from a uniform distribution in the range [-limit, limit],
+ where limit = sqrt(1 / fan_in), and fan_in is inferred from the shape.
+
+ | Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural
+ | networks." Proceedings of the thirteenth international conference on artificial intelligence and statistics.
+ | JMLR Workshop and Conference Proceedings, 2010.
+
+ Args:
+ **params: extra distribution parameters
+
+ Returns:
+ a distribution initializer
+ """
+ using_numpy = params.get("use_numpy", False)
+
+ def compute_limit(fan_in: int) -> float:
+ return float(numpy.sqrt(1.0 / fan_in))
+
+ if using_numpy:
+ def fan_in_uniform_generator(shape: Sequence[int], seed: int | None = None) -> numpy.ndarray:
+ if len(shape) < 2:
+ error("fan_in_uniform requires shape with at least 2 dimensions")
+ fan_in = shape[1]
+ limit = compute_limit(fan_in)
+
+ rng = numpy.random.default_rng(seed)
+ matrix = rng.uniform(low=-limit, high=limit, size=shape).astype(
+ params.get("dtype", numpy.float32))
+ matrix = DistributionGenerator._process_params_numpy(matrix, params, seed)
+ return matrix
+ else:
+ def fan_in_uniform_generator(shape: Sequence[int], dKey: jax.Array | None = None) -> jax.Array:
+ if len(shape) < 2:
+ error("fan_in_uniform requires shape with at least 2 dimensions")
+ fan_in = shape[1]
+ limit = compute_limit(fan_in)
+
+ if dKey is None:
+ dKey = jax.random.PRNGKey(time.time_ns())
+ dKey, subKey = jax.random.split(dKey, 2)
+
+ matrix = jax.random.uniform(
+ dKey,
+ shape=shape,
+ minval=-limit,
+ maxval=limit,
+ dtype=params.get("dtype", jax.numpy.float32)
+ )
+ matrix = DistributionGenerator._process_params_jax(matrix, params, subKey)
+ return matrix
+
+ return fan_in_uniform_generator
+
+ @staticmethod
+ def fan_in_gaussian(**params: Unpack[DistributionParams]) -> DistributionInitializer:
+ """
+ Produces a distribution initializer using a fan-in Gaussian (normal) strategy.
+ The values are sampled from a normal distribution with mean 0 and stddev = sqrt(1 / fan_in),
+ where fan_in is inferred from the shape.
+
+ | He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet
+ | classification." Proceedings of the IEEE international conference on computer vision. 2015.
+
+ Args:
+ **params: extra distribution parameters
+
+ Returns:
+ a distribution initializer
+ """
+ using_numpy = params.get("use_numpy", False)
+
+ def compute_std(fan_in: int) -> float:
+ return float(numpy.sqrt(1.0 / fan_in))
+
+ if using_numpy:
+ def fan_in_gaussian_generator(shape: Sequence[int], seed: int | None = None) -> numpy.ndarray:
+ if len(shape) < 2:
+ error("fan_in_gaussian requires shape with at least 2 dimensions")
+ fan_in = shape[0]
+ std = compute_std(fan_in)
+
+ rng = numpy.random.default_rng(seed)
+ matrix = rng.normal(loc=0.0, scale=std, size=shape).astype(
+ params.get("dtype", numpy.float32))
+ matrix = DistributionGenerator._process_params_numpy(matrix, params, seed)
+ return matrix
+ else:
+ def fan_in_gaussian_generator(shape: Sequence[int], dKey: jax.Array | None = None) -> jax.Array:
+ if len(shape) < 2:
+ error("fan_in_gaussian requires shape with at least 2 dimensions")
+ fan_in = shape[0]
+ std = compute_std(fan_in)
+
+ if dKey is None:
+ dKey = jax.random.PRNGKey(time.time_ns())
+ dKey, subKey = jax.random.split(dKey, 2)
+
+ matrix = jax.random.normal(
+ dKey,
+ shape=shape,
+ dtype=params.get("dtype", jax.numpy.float32)
+ )
+ matrix = matrix * std
+ matrix = DistributionGenerator._process_params_jax(matrix, params, subKey)
+ return matrix
+
+ return fan_in_gaussian_generator
+
+ @staticmethod
+ def _process_params_jax(ary: jax.Array, params: DistributionParams, dKey: jax.dtypes.prng_key | None) -> jax.Array:
+ if dKey is None:
+ dKey = jax.random.PRNGKey(time.time_ns())
+
+ amin = params.get("amin", None)
+ if amin is not None:
+ ary = jax.numpy.maximum(ary, amin)
+
+ amax = params.get("amax", None)
+ if amax is not None:
+ ary = jax.numpy.minimum(ary, amax)
+
+ lower_triangle = params.get("lower_triangle", False)
+ upper_triangle = params.get("upper_triangle", False)
+ if lower_triangle and upper_triangle:
+ error("lower_triangle and upper_triangle are mutually exclusive when initializing a distribution")
+
+ if lower_triangle:
+ ary = jax.numpy.tril(ary)
+ if upper_triangle:
+ ary = jax.numpy.triu(ary)
+
+ if params.get("hollow", False):
+ ary = (1.0 - jax.numpy.eye(*ary.shape)) * ary
+
+ if params.get("eye", False):
+ ary = jax.numpy.eye(*ary.shape) * ary
+
+ col_mask = params.get("col_mask", None)
+ if col_mask is not None:
+ if isinstance(col_mask, int):
+ dKey, subKey = jax.random.split(dKey, 2)
+ keep_indices = jax.random.choice(subKey, ary.shape[1], shape=(col_mask,), replace=False)
+ mask = jax.numpy.zeros(ary.shape[1], dtype=bool).at[
+ keep_indices].set(True)
+ mask = jax.numpy.broadcast_to(mask, ary.shape)
+ ary = jax.numpy.where(mask, ary, 0)
+ elif isinstance(col_mask, Sequence):
+ mask = jax.numpy.zeros(ary.shape[1], dtype=bool).at[
+ col_mask].set(True)
+ mask = jax.numpy.broadcast_to(mask, ary.shape)
+ ary = jax.numpy.where(mask, ary, 0)
+
+ row_mask = params.get("row_mask", None)
+ if row_mask is not None:
+ if isinstance(row_mask, int):
+ dKey, subKey = jax.random.split(dKey, 2)
+ keep_indices = jax.random.choice(subKey, ary.shape[0], shape=(row_mask,), replace=False)
+ mask = jax.numpy.zeros(ary.shape[0], dtype=bool).at[
+ keep_indices].set(True)
+ mask = jax.numpy.broadcast_to(mask, ary.shape)
+ ary = jax.numpy.where(mask, ary, 0)
+ elif isinstance(row_mask, Sequence):
+ mask = jax.numpy.zeros(ary.shape[0], dtype=bool).at[
+ row_mask].set(True)
+ mask = jax.numpy.broadcast_to(mask, ary.shape)
+ ary = jax.numpy.where(mask, ary, 0)
+
+ return ary.astype(params.get("dtype", jax.numpy.float32))
+
+ @staticmethod
+ def _process_params_numpy(ary: numpy.ndarray, params: DistributionParams, seed: int | None) -> numpy.ndarray:
+ amin = params.get("amin", None)
+ if amin is not None:
+ ary = numpy.maximum(ary, amin)
+
+ amax = params.get("amax", None)
+ if amax is not None:
+ ary = numpy.minimum(ary, amax)
+
+ lower_triangle = params.get("lower_triangle", False)
+ upper_triangle = params.get("upper_triangle", False)
+ if lower_triangle and upper_triangle:
+ error("lower_triangle and upper_triangle are mutually exclusive when initializing a distribution")
+
+ if lower_triangle:
+ ary = numpy.tril(ary)
+ if upper_triangle:
+ ary = numpy.triu(ary)
+
+ if params.get("hollow", False):
+ ary = (1.0 - numpy.eye(*ary.shape)) * ary
+
+ if params.get("eye", False):
+ ary = numpy.eye(*ary.shape) * ary
+
+ col_mask = params.get("col_mask", None)
+ if col_mask is not None:
+ if isinstance(col_mask, int):
+ rng = numpy.random.default_rng(seed)
+ keep_indices = rng.choice(ary.shape[1], size=col_mask, replace=False)
+ mask = numpy.zeros(ary.shape[1], dtype=bool)
+ mask[keep_indices] = True
+ mask = numpy.broadcast_to(mask, ary.shape)
+ ary = numpy.where(mask, ary, 0)
+ elif isinstance(col_mask, Sequence):
+ mask = numpy.zeros(ary.shape[1], dtype=bool)
+ mask[list(col_mask)] = True
+ mask = numpy.broadcast_to(mask, ary.shape)
+ ary = numpy.where(mask, ary, 0)
+
+ row_mask = params.get("row_mask", None)
+ if row_mask is not None:
+ if isinstance(row_mask, int):
+ rng = numpy.random.default_rng(seed)
+ keep_indices = rng.choice(ary.shape[0], size=row_mask, replace=False)
+ mask = numpy.zeros(ary.shape[0], dtype=bool)
+ mask[keep_indices] = True
+ mask = numpy.broadcast_to(mask, ary.shape)
+ ary = numpy.where(mask, ary, 0)
+ elif isinstance(row_mask, Sequence):
+ mask = numpy.zeros(ary.shape[0], dtype=bool)
+ mask[list(row_mask)] = True
+ mask = numpy.broadcast_to(mask, ary.shape)
+ ary = numpy.where(mask, ary, 0)
+
+ return ary
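+## ------------------------------------------------------------------------------------------------------------------
+## Illustrative usage sketch (hypothetical shapes/values, not part of the class above): building initializers and
+## drawing matrices, with a few of the extra DistributionParams applied.
+# import jax
+# key = jax.random.PRNGKey(0)
+# init_fn = DistributionGenerator.gaussian(mean=0., std=0.1, amin=-0.2, amax=0.2, hollow=True)
+# W = init_fn((64, 64), key)           ## clipped, zero-diagonal Gaussian matrix (JAX backend)
+# init_np = DistributionGenerator.uniform(low=-1., high=1., use_numpy=True, row_mask=8)
+# M = init_np((32, 16), 1234)          ## NumPy backend; only 8 randomly chosen rows kept non-zero
+# fanin_fn = DistributionGenerator.fan_in_uniform()
+# V = fanin_fn((128, 32), key)         ## limit = sqrt(1 / fan_in) with fan_in taken from the shape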
diff --git a/ngclearn/utils/feature_dictionaries/__init__.py b/ngclearn/utils/feature_dictionaries/__init__.py
new file mode 100644
index 00000000..ad4ea443
--- /dev/null
+++ b/ngclearn/utils/feature_dictionaries/__init__.py
@@ -0,0 +1 @@
+from .polynomialLibrary import PolynomialLibrary
diff --git a/ngclearn/utils/feature_dictionaries/polynomialLibrary.py b/ngclearn/utils/feature_dictionaries/polynomialLibrary.py
index f50686ef..42e13780 100644
--- a/ngclearn/utils/feature_dictionaries/polynomialLibrary.py
+++ b/ngclearn/utils/feature_dictionaries/polynomialLibrary.py
@@ -1,19 +1,16 @@
-import jax.numpy as jnp
-from jax import jit, random
-import jax.numpy as jnp
+from jax import jit, random, numpy as jnp
from typing import List, Tuple, Union
from dataclasses import dataclass
-
-
@dataclass
class PolynomialLibrary:
"""
A class for creating polynomial feature libraries in 1D, 2D, or 3D.
- Attributes:
- poly_order (int): Maximum order of polynomial terms
- include_bias (bool): Whether to include the bias term in the output
+ Args:
+ poly_order (int): Maximum order of polynomial terms (Attribute)
+
+ include_bias (bool): Whether to include the bias term in the output (Attribute)
"""
poly_order: int = None
@@ -65,6 +62,15 @@ def _create_library(self, *arrays: jnp.ndarray) -> Tuple[jnp.ndarray, List[str]]
def fit(self, X: List[jnp.ndarray]) -> Tuple[jnp.ndarray, List[str]]:
+ """
+ Fits this library to a design matrix X
+
+ Args:
+ X: the design matrix to fit this library to
+
+ Returns:
+ the fitted polynomial feature library (design matrix) and its list of term names
+ """
if not 1 <= len(X) <=3:
raise ValueError("Input must be 1D, 2D, or 3D; e.g. len(X) >= 1 ")
@@ -72,7 +78,6 @@ def fit(self, X: List[jnp.ndarray]) -> Tuple[jnp.ndarray, List[str]]:
arrays = [jnp.array(x).reshape(-1, 1) for x in X]
lib, names = self._create_library(*arrays)
-
start_idx = 1 if not self.include_bias else 0
return lib[:, start_idx+1:], names[start_idx:]
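+## Illustrative usage sketch (hypothetical data, not part of the patch above): building a 2D polynomial feature
+## library up to second order.
+# from jax import numpy as jnp
+# x = jnp.linspace(-1., 1., 50)
+# y = jnp.linspace(0., 2., 50)
+# library = PolynomialLibrary(poly_order=2, include_bias=False)
+# feats, names = library.fit([x, y])  ## polynomial feature (design) matrix and its term names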
diff --git a/ngclearn/utils/io_utils.py b/ngclearn/utils/io_utils.py
index 7422ebe0..8553af44 100755
--- a/ngclearn/utils/io_utils.py
+++ b/ngclearn/utils/io_utils.py
@@ -1,9 +1,10 @@
"""
File and OS input/output (reading/writing) utilities.
"""
-import jax
-from jax import numpy as jnp, grad, jit, vmap, random, lax
+# import jax
+# from jax import numpy as jnp, grad, jit, vmap, random, lax
import os, sys, pickle
+from typing import Any
def serialize(fname, object): ## object "saving" routine
"""
@@ -65,3 +66,15 @@ def makedirs(directories):
"""
for dir in directories:
makedir(dir)
+
+
+def save_pkl(directory: str, name: str, value: Any) -> None:
+ file_name = directory + "/" + name + ".pkl"
+ with open(file_name, 'wb') as f:
+ pickle.dump(value, f)
+
+def load_pkl(directory: str, name: str) -> Any:
+ file_name = directory + "/" + name + ".pkl"
+ with open(file_name, 'rb') as f:
+ data = pickle.load(f)
+ return data
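+## Illustrative usage sketch (hypothetical directory/values): round-tripping an arbitrary Python object via pickle.
+# stats = {"epoch": 3, "loss": 0.42}
+# save_pkl("exp/run0", "train_stats", stats)   ## writes exp/run0/train_stats.pkl (directory must already exist)
+# restored = load_pkl("exp/run0", "train_stats")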
diff --git a/ngclearn/utils/jaxProcess.py b/ngclearn/utils/jaxProcess.py
deleted file mode 100644
index dd1dabc3..00000000
--- a/ngclearn/utils/jaxProcess.py
+++ /dev/null
@@ -1,171 +0,0 @@
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.compilers.process import Process
-from jax.lax import scan as _scan
-from ngcsimlib.logger import warn
-from jax import numpy as jnp
-
-
-class JaxProcess(Process):
- """
- The JaxProcess is a subclass of the ngcsimlib Process class. The
- functionality added by this subclass is the use of the jax scanner to run a
- process quickly through the use of jax's JIT compiler.
- """
-
- def __init__(self, name):
- super().__init__(name)
- self._process_scan_method = None
- self._monitoring = []
-
- def _make_scanner(self):
- arg_order = self.get_required_args()
-
- def _pure(current_state, x):
- v = self.pure(current_state,
- **{key: value for key, value in zip(arg_order, x)})
- return v, [v[m] for m in self._monitoring]
-
- return _pure
-
- def watch(self, compartment):
- """
- Adds a compartment to the process to watch during a scan
-
- Args:
- compartment: the compartment to watch
- """
- if not isinstance(compartment, Compartment):
- warn(
- "Jax Process trying to watch a value that is not a compartment")
-
- self._monitoring.append(compartment.path)
- self._process_scan_method = self._make_scanner()
-
- def clear_watch_list(self):
- """
- Clears the watch list so no values are watched
- """
- self._monitoring = []
- self._process_scan_method = self._make_scanner()
-
- def transition(self, transition_call):
- """
- Appends to the base transition call to create pure method for use by its
- scanner
-
- Args:
- transition_call: the transition being passed into the default process
-
- Returns:
- this JaxProcess instance for chaining
- """
- super().transition(transition_call)
- self._process_scan_method = self._make_scanner()
- return self
-
- def scan(self, save_state=True, scan_length=None, **kwargs):
- """
- There a quite a few ways to initialize the scan method for the
- jaxProcess. To start the straight forward arguments is "save_state".
- The save_state flag is simply there to note if the state
- of the model should reflect the final state of the model after the scan
- is complete.
-
- This scan method can also watch and report intermediate compartment
- values defined through calling the JaxProcess.watch() method watching a
- compartment means at the end of each process cycle record the value of
- the compartment and then at the end a tuple of concatenated values will
- be returned that correspond to each compartment the process is watching.
-
- Where there are options for the arguments is when defining the keyword
- arguments for the process. The process will do its best to broadcast all
- the inputs to the largest size, so they can be scanned over. This means
- that is one is a (2, 3) and the other is a constant, it will broadcast
- constant to a (2, 3). This does mean that every keyword value that is
- passed to a method in the process will be the same size. This is a
- limitation of the jax scanner as all the values have to be concatenated
- into a single jax array to be passed into the scanner. The accepted
- types for arguments, are lists, tuples, numpy arrays, jax arrays, ints,
- and floats. If all the keyword arguments are passed as ints or floats
- the scan_length flag must be set so the scanner knows how many
- iterations to run. If any of the arguments are iterable it will
- automatically assume that the leading axis is the number of iterations
- to run.
-
-
- Args:
- save_state: A boolean flag to indicate if the model state should be saved
-
- scan_length: a value to be used to denote the number of iterations of the scanner if all keyword
- arguments are passed as ints or floats
-
- **kwargs: the required keyword arguments for the process to run
-
- Returns: the final state of the model, the stacked output of the scan method
-
- """
- arg_order = list(self.get_required_args())
-
- args = []
- max_axis = 1
- max_next_axis = 0
-
- for kwarg in arg_order:
- if kwarg not in kwargs.keys():
- warn("Missing kwarg in Process", self.name)
- return
-
- kval = kwargs.get(kwarg, None)
- if isinstance(kval, (float, int, list, tuple)):
- val = jnp.array(kval)
- else:
- val = kval
-
- max_axis = max(max_axis, len(val.shape))
- if max_axis == len(val.shape):
- max_next_axis = max(max_next_axis, val.shape[0])
- args.append(val)
-
- # Check axis && get max_next_axis
-
- if max_next_axis == 0:
- if scan_length is None:
- warn("scan_length must be defined if all keyword arguments are "
- "constants")
- return
- elif scan_length > 0:
- max_next_axis = scan_length
- else:
- warn("scan_length must be greater than 0")
- return
-
- for axis in range(max_axis):
- current_axis = max_next_axis
- max_next_axis = 0
- new_args = []
- for a in args:
- if len(a.shape) >= axis + 1:
- if a.shape[axis] == current_axis:
- new_args.append(a)
- else:
- warn("Keyword arguments must all be able to be "
- "broadcasted to the largest shape")
- return
- else:
- new_args.append(jnp.zeros(list(a.shape) + [current_axis],
- dtype=a.dtype) + a.reshape(
- *a.shape, 1))
-
- if len(a.shape) > axis + 1:
- max_next_axis = max(max_next_axis, a.shape[axis + 1])
-
- args = new_args
-
- args = jnp.array(args).transpose(
- [1, 0] + [i for i in range(2, max_axis + 1)])
- state, stacked = _scan(self._process_scan_method,
- init=self.get_required_state(
- include_special_compartments=True), xs=args)
- if save_state:
- self.updated_modified_state(state)
- return state, stacked
diff --git a/ngclearn/utils/masks/__init__.py b/ngclearn/utils/masks/__init__.py
new file mode 100644
index 00000000..dd79c3d8
--- /dev/null
+++ b/ngclearn/utils/masks/__init__.py
@@ -0,0 +1 @@
+from .multiblock2d import MaskCollator
diff --git a/ngclearn/utils/masks/multiblock2d.py b/ngclearn/utils/masks/multiblock2d.py
index 3588d684..e146b534 100644
--- a/ngclearn/utils/masks/multiblock2d.py
+++ b/ngclearn/utils/masks/multiblock2d.py
@@ -1,163 +1,159 @@
# %%
-
-# Adapted from meta jepa
-
import math
import numpy as np
from multiprocessing import Value
-class MaskCollator(object):
+class MaskCollator(object): # Adapted from the Meta JEPA code-base for use within ngc-learn
+ """
+ A mechanism for generating/creating patch masks, generally for self-supervised learning.
- def __init__(
- self,
- cfgs_mask,
- crop_size=(224, 224),
- patch_size=(16, 16),
- ):
- super(MaskCollator, self).__init__()
+ Args:
+ cfgs_mask: a list of mask configuration dictionaries (each specifying spatial_scale, aspect_ratio, num_blocks, and optionally max_keep)
- self.mask_generators = []
- for m in cfgs_mask:
- mask_generator = _MaskGenerator(
- crop_size=crop_size,
- patch_size=patch_size,
- pred_mask_scale=m.get('spatial_scale'),
- aspect_ratio=m.get('aspect_ratio'),
- npred=m.get('num_blocks'),
- max_keep=m.get('max_keep', None),
- )
- self.mask_generators.append(mask_generator)
+ crop_size: (height, width) dimensions of the image crop
- def step(self):
- for mask_generator in self.mask_generators:
- mask_generator.step()
+ patch_size: (height, width) dimensions of each patch
+ """
- def __call__(self, batch):
+ def __init__(self, cfgs_mask, crop_size=(224, 224), patch_size=(16, 16),):
+ super(MaskCollator, self).__init__()
- batch_size = len(batch)
+ self.mask_generators = []
+ for m in cfgs_mask:
+ mask_generator = _MaskGenerator(
+ crop_size=crop_size,
+ patch_size=patch_size,
+ pred_mask_scale=m.get('spatial_scale'),
+ aspect_ratio=m.get('aspect_ratio'),
+ npred=m.get('num_blocks'),
+ max_keep=m.get('max_keep', None),
+ )
+ self.mask_generators.append(mask_generator)
- collated_masks_pred, collated_masks_enc = [], []
- for i, mask_generator in enumerate(self.mask_generators):
- masks_enc, masks_pred = mask_generator(batch_size)
- collated_masks_enc.append(masks_enc)
- collated_masks_pred.append(masks_pred)
+ def step(self):
+ """
+ Advances each internal mask generator forward by one iteration (no values are returned).
- return collated_masks_enc, collated_masks_pred
+ """
+ for mask_generator in self.mask_generators:
+ mask_generator.step()
+ def __call__(self, batch):
+ batch_size = len(batch)
+ collated_masks_pred, collated_masks_enc = [], []
+ for i, mask_generator in enumerate(self.mask_generators):
+ masks_enc, masks_pred = mask_generator(batch_size)
+ collated_masks_enc.append(masks_enc)
+ collated_masks_pred.append(masks_pred)
-class _MaskGenerator(object):
+ return collated_masks_enc, collated_masks_pred
- def __init__(
- self,
- crop_size=(224, 224),
- patch_size=(16, 16),
- pred_mask_scale=(0.2, 0.8),
- aspect_ratio=(0.3, 3.0),
- npred=1,
- max_keep=None,
- ):
- super(_MaskGenerator, self).__init__()
- if not isinstance(crop_size, tuple):
- crop_size = (crop_size, ) * 2
- self.crop_size = crop_size
- self.height, self.width = crop_size[0] // patch_size[0], crop_size[1] // patch_size[1]
-
- self.patch_size = patch_size
- self.aspect_ratio = aspect_ratio
- self.pred_mask_scale = pred_mask_scale
- self.npred = npred
- self.max_keep = max_keep
- self._itr_counter = Value('i', -1) # collator is shared across worker processes
-
- def step(self):
- i = self._itr_counter
- with i.get_lock():
- i.value = (i.value + 1) % 2**16
- v = i.value
- return v
-
- def _sample_block_size(
- self,
- rng: np.random.RandomState,
- scale,
- aspect_ratio_scale
- ):
- # -- Sample spatial block mask scale
- _rand = rng.random()
- min_s, max_s = scale
- spatial_mask_scale = min_s + _rand * (max_s - min_s)
- spatial_num_keep = int(self.height * self.width * spatial_mask_scale)
-
- # -- Sample block aspect-ratio
- _rand = rng.random()
- min_ar, max_ar = aspect_ratio_scale
- aspect_ratio = min_ar + _rand * (max_ar - min_ar)
-
- # -- Compute block height and width (given scale and aspect-ratio)
- h = int(round(math.sqrt(spatial_num_keep * aspect_ratio)))
- w = int(round(math.sqrt(spatial_num_keep / aspect_ratio)))
- h = min(h, self.height)
- w = min(w, self.width)
-
- return (h, w)
-
- def _sample_block_mask(self, b_size, rng: np.random.RandomState):
- h, w = b_size
- top = rng.randint(0, self.height - h + 1)
- left = rng.randint(0, self.width - w + 1)
-
- mask = np.ones((self.height, self.width), dtype=np.int32)
- mask[top:top+h, left:left+w] = 0
-
- return mask
-
- def __call__(self, batch_size):
- """
- Create encoder and predictor masks when collating imgs into a batch
- # 1. sample pred block size using seed
- # 2. sample several pred block locations for each image (w/o seed)
- # 3. return pred masks and complement (enc mask)
- """
- seed = self.step()
- rng = np.random.RandomState(seed)
- p_size = self._sample_block_size(
- rng=rng,
- scale=self.pred_mask_scale,
- aspect_ratio_scale=self.aspect_ratio,
- )
-
- collated_masks_pred, collated_masks_enc = [], []
- min_keep_enc = min_keep_pred = self.height * self.width
- for _ in range(batch_size):
-
- empty_context = True
- while empty_context:
- # Create a mask for this sample
- mask_e = np.ones((self.height, self.width), dtype=np.int32)
- for _ in range(self.npred):
- mask_e *= self._sample_block_mask(p_size, rng)
- mask_e = mask_e.flatten()
- mask_p = np.where(mask_e == 0)[0]
- mask_e = np.where(mask_e != 0)[0]
+class _MaskGenerator(object):
- empty_context = len(mask_e) == 0
- if not empty_context:
- min_keep_pred = min(min_keep_pred, len(mask_p))
- min_keep_enc = min(min_keep_enc, len(mask_e))
- collated_masks_pred.append(mask_p)
- collated_masks_enc.append(mask_e)
+ def __init__(
+ self, crop_size=(224, 224), patch_size=(16, 16), pred_mask_scale=(0.2, 0.8), aspect_ratio=(0.3, 3.0),
+ npred=1, max_keep=None
+ ):
+ super(_MaskGenerator, self).__init__()
+ if not isinstance(crop_size, tuple):
+ crop_size = (crop_size, ) * 2
+ self.crop_size = crop_size
+ self.height, self.width = crop_size[0] // patch_size[0], crop_size[1] // patch_size[1]
+
+ self.patch_size = patch_size
+ self.aspect_ratio = aspect_ratio
+ self.pred_mask_scale = pred_mask_scale
+ self.npred = npred
+ self.max_keep = max_keep
+ self._itr_counter = Value('i', -1) # collator is shared across worker processes
+
+ def step(self):
+ i = self._itr_counter
+ with i.get_lock():
+ i.value = (i.value + 1) % 2**16
+ v = i.value
+ return v
+
+ def _sample_block_size(self, rng: np.random.RandomState, scale, aspect_ratio_scale):
+ # -- Sample spatial block mask scale
+ _rand = rng.random()
+ min_s, max_s = scale
+ spatial_mask_scale = min_s + _rand * (max_s - min_s)
+ spatial_num_keep = int(self.height * self.width * spatial_mask_scale)
+
+ # -- Sample block aspect-ratio
+ _rand = rng.random()
+ min_ar, max_ar = aspect_ratio_scale
+ aspect_ratio = min_ar + _rand * (max_ar - min_ar)
+
+ # -- Compute block height and width (given scale and aspect-ratio)
+ h = int(round(math.sqrt(spatial_num_keep * aspect_ratio)))
+ w = int(round(math.sqrt(spatial_num_keep / aspect_ratio)))
+ h = min(h, self.height)
+ w = min(w, self.width)
+
+ return (h, w)
+
+ def _sample_block_mask(self, b_size, rng: np.random.RandomState):
+ h, w = b_size
+ top = rng.randint(0, self.height - h + 1)
+ left = rng.randint(0, self.width - w + 1)
+
+ mask = np.ones((self.height, self.width), dtype=np.int32)
+ mask[top:top+h, left:left+w] = 0
+
+ return mask
+
+ def __call__(self, batch_size):
+ """
+ Create encoder and predictor masks when collating imgs into a batch:
+
+ | # 1. sample pred block size using seed
+ | # 2. sample several pred block locations for each image (w/o seed)
+ | # 3. return pred masks and complement (enc mask)
+
+ Args:
+ batch_size: number of samples to place within a generated batch
+
+ Returns:
+ collated encoder masks, collated predictor masks
+ """
+ seed = self.step()
+ rng = np.random.RandomState(seed)
+ p_size = self._sample_block_size(rng=rng, scale=self.pred_mask_scale, aspect_ratio_scale=self.aspect_ratio)
+
+ collated_masks_pred, collated_masks_enc = [], []
+ min_keep_enc = min_keep_pred = self.height * self.width
+ for _ in range(batch_size):
+ empty_context = True
+ while empty_context:
+ # Create a mask for this sample
+ mask_e = np.ones((self.height, self.width), dtype=np.int32)
+ for _ in range(self.npred):
+ mask_e *= self._sample_block_mask(p_size, rng)
+ mask_e = mask_e.flatten()
- if self.max_keep is not None:
- min_keep_enc = min(min_keep_enc, self.max_keep)
+ mask_p = np.where(mask_e == 0)[0]
+ mask_e = np.where(mask_e != 0)[0]
- # Truncate arrays to the minimum length to create uniform arrays
- collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred]
- collated_masks_pred = np.array(collated_masks_pred)
+ empty_context = len(mask_e) == 0
+ if not empty_context:
+ min_keep_pred = min(min_keep_pred, len(mask_p))
+ min_keep_enc = min(min_keep_enc, len(mask_e))
+ collated_masks_pred.append(mask_p)
+ collated_masks_enc.append(mask_e)
- collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc]
- collated_masks_enc = np.array(collated_masks_enc)
+ if self.max_keep is not None:
+ min_keep_enc = min(min_keep_enc, self.max_keep)
- return collated_masks_enc, collated_masks_pred
+ # Truncate arrays to the minimum length to create uniform arrays
+ collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred]
+ collated_masks_pred = np.array(collated_masks_pred)
+ collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc]
+ collated_masks_enc = np.array(collated_masks_enc)
+ return collated_masks_enc, collated_masks_pred
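A minimal usage sketch of the block-mask sampler above; this is a hypothetical standalone call that assumes `_MaskGenerator` has been imported from the module this hunk modifies (the surrounding collator class is omitted):

gen = _MaskGenerator(crop_size=(224, 224), patch_size=(16, 16),
                     pred_mask_scale=(0.2, 0.8), aspect_ratio=(0.3, 3.0), npred=4)
enc_masks, pred_masks = gen(batch_size=8)
# enc_masks: (8, n_keep_enc) patch indices kept as encoder/context input per sample
# pred_masks: (8, n_keep_pred) patch indices the predictor must reconstruct per sample
print(enc_masks.shape, pred_masks.shape)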
diff --git a/ngclearn/utils/metric_utils.py b/ngclearn/utils/metric_utils.py
index 0ab3a078..e5a61eb4 100755
--- a/ngclearn/utils/metric_utils.py
+++ b/ngclearn/utils/metric_utils.py
@@ -1,6 +1,6 @@
"""
-Metric and measurement routines and co-routines. These functions are useful
-for model-level/simulation analysis as well as experimental inspection and probing.
+Metric and measurement routines and co-routines. These functions are useful for model-level/simulation analysis as well
+as experimental inspection and probing (many of these are neuroscience-oriented measurement functions).
"""
from jax import numpy as jnp, jit
from functools import partial
@@ -26,7 +26,7 @@ def measure_fanoFactor(spikes, preserve_batch=False):
mu = jnp.mean(spikes, axis=0, keepdims=True)
sigSqr = jnp.square(jnp.std(spikes, axis=0, keepdims=True))
fano = sigSqr/mu
- if preserve_batch is False:
+ if not preserve_batch:
fano = jnp.mean(fano)
return fano
@@ -49,7 +49,7 @@ def measure_firingRate(spikes, preserve_batch=False):
counts = jnp.sum(spikes, axis=0, keepdims=True)
T = spikes.shape[0] * 1.
fireRates = counts/T
- if preserve_batch is False:
+ if not preserve_batch:
fireRates = jnp.mean(fireRates)
return fireRates
@@ -78,7 +78,7 @@ def measure_breadth_TC(spikes, preserve_batch=False):
sigSqr = jnp.square(jnp.std(spikes, axis=0, keepdims=True))
C = sigSqr/mu
BTC = 1./(1 + jnp.square(C))
- if preserve_batch is False:
+ if not preserve_batch:
BTC = jnp.mean(BTC)
return BTC
@@ -104,24 +104,24 @@ def measure_sparsity(codes, tolerance=0.):
#@partial(jit, static_argnums=[2])
def analyze_scores(mu, y, extract_label_indx=True): ## examines classification statistics
"""
- Analyzes a set of prediction matrix and target/ground-truth matrix or vector.
+ Analyzes a prediction (design) matrix against a target/ground-truth matrix or vector.
- Args:
- mu: prediction (design) matrix; shape is (N x C) where C is number of classes
- and N is the number of patterns examined
+ Args:
+ mu: prediction (design) matrix; shape is (N x C) where C is number of classes
+ and N is the number of patterns examined
- y: target / ground-truth (design) matrix; shape is (N x C) OR an array
- of class integers of length N (with "extract_label_indx = True")
+ y: target / ground-truth (design) matrix; shape is (N x C) OR an array
+ of class integers of length N (with "extract_label_indx = True")
- extract_label_indx: run an argmax to pull class integer indices from
- "y", assuming y is a one-hot binary encoding matrix (Default: True),
- otherwise, this assumes "y" is an array of class integer indices
- of length N
+ extract_label_indx: run an argmax to pull class integer indices from
+ "y", assuming y is a one-hot binary encoding matrix (Default: True),
+ otherwise, this assumes "y" is an array of class integer indices
+ of length N
- Returns:
- confusion matrix, precision, recall, misses (empty predictions/all-zero rows),
- accuracy, adjusted-accuracy (counts all misses as incorrect)
- """
+ Returns:
+ confusion matrix, precision, recall, misses (empty predictions/all-zero rows),
+ accuracy, adjusted-accuracy (counts all misses as incorrect)
+ """
miss_mask = (jnp.sum(mu, axis=1) == 0.) * 1.
misses = jnp.sum(miss_mask) ## how many misses?
labels = y
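A small usage sketch of `analyze_scores` on a toy prediction/target pair; the unpacking below follows the return order stated in the docstring above:

import jax.numpy as jnp
from ngclearn.utils.metric_utils import analyze_scores

mu = jnp.asarray([[0.9, 0.1], [0.2, 0.8], [0., 0.]])  # last row is an all-zero ("miss") prediction
y = jnp.asarray([[1., 0.], [0., 1.], [1., 0.]])       # one-hot ground-truth labels
conf_mat, precision, recall, misses, acc, adj_acc = analyze_scores(mu, y)
print(misses)  # the all-zero prediction row should be counted as a miss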
@@ -167,13 +167,46 @@ def measure_ACC(mu, y, extract_label_indx=True): ## measures/calculates accuracy
acc = jnp.sum( jnp.equal(guess, lab) )/(y.shape[0] * 1.)
return acc
+@partial(jit, static_argnums=[3])
+def measure_BIC(X, n_model_params, max_model_score, is_log=True):
+ """
+ Measures the Bayesian information criterion (BIC) with respect to the final
+ score obtained by the model on a given dataset.
+
+ | BIC = -2 ln(L) + K * ln(N);
+ | where N is number of data-points/rows of design matrix X,
+ |  K is the total number of parameters of the model of interest, and
+ | L is the max/best-found value of a likelihood-like score L of the model
+
+ Args:
+ X: dataset/design matrix that a model was fit to (max-likelihood optimized)
+
+ n_model_params: total number of model parameters (int)
+
+ max_model_score: max likelihood-like score obtained by model on X
+
+ is_log: whether the supplied `max_model_score` is already a log-likelihood; if False,
+ this routine applies a natural logarithm to the score (Default: True)
+
+ Returns:
+ scalar for the Bayesian information criterion score
+ """
+ ## BIC = K * ln(N) - 2 ln(L)
+ L_hat = max_model_score ## model's likelihood-like score (at max point)
+ K = n_model_params ## number of model params
+ N = X.shape[0] ## number of data-points
+ if not is_log:
+ L_hat = jnp.log(L_hat) ## get log likelihood
+ bic = -L_hat * 2. + jnp.log(N * 1.) * K
+ return bic
+
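A quick sanity check of the BIC formula above with hypothetical numbers (N = 1000 rows, K = 10 parameters, best log-likelihood of -2500):

import jax.numpy as jnp
from ngclearn.utils.metric_utils import measure_BIC

X = jnp.zeros((1000, 5))  # only X.shape[0] (N = 1000) is used by this metric
bic = measure_BIC(X, n_model_params=10, max_model_score=-2500.0)
print(bic)  # -2 * (-2500) + 10 * ln(1000) ≈ 5069.08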
@partial(jit, static_argnums=[2])
def measure_KLD(p_xHat, p_x, preserve_batch=False):
"""
- Measures the (raw) Kullback-Leibler divergence (KLD), assuming that the two
- input arguments contain valid probability distributions (in each row, if
- they are matrices). Note: If batch is preserved, this returns a column
- vector where each row is the KLD(x_pred, x_true) for that row's datapoint.
+ Measures the (raw) Kullback-Leibler divergence (KLD), assuming that the two input arguments contain valid
+ probability distributions (in each row, if they are matrices). Note: If batch is preserved, this returns a column
+ vector where each row is the KLD(x_pred, x_true) for that row's datapoint. (Further note that this function
+ does not assume any particular distribution when calculating KLD)
| Formula:
| KLD(p_xHat, p_x) = (1/N) [ sum_i(p_x * jnp.log(p_x)) - sum_i(p_x * jnp.log(p_xHat)) ]
@@ -198,17 +231,62 @@ def measure_KLD(p_xHat, p_x, preserve_batch=False):
N = p_x.shape[1]
term1 = jnp.sum(_p_x * jnp.log(_p_x), axis=1, keepdims=True) # * (1/N)
term2 = -jnp.sum(_p_x * jnp.log(_p_xHat), axis=1, keepdims=True) # * (1/N)
- kld = (term1 + term2) * (1/N)
- if preserve_batch is False:
+ kld = (term1 + term2) * (1/N) ## KLD-per-datapoint
+ if not preserve_batch:
kld = jnp.mean(kld)
return kld
+def measure_gaussian_KLD(mu1, Sigma1, mu2, Sigma2, use_chol_prec=True):
+ """
+ Calculates the Kullback-Leibler (KL) divergence between two multivariate Gaussian distributions, i.e.,
+ KL(N(mu1, Sigma1) || N(mu2, Sigma2)).
+ Formally, this means this routine calculates:
+
+ | KL(N1 || N2) = [log(det(Sigma2)/det(Sigma1)) + trace(Prec2 * Sigma1) + (z * Prec2 * z) - D] * (1/2)
+ | where N1 is the 1st Gaussian, i.e., N(mu1,Sigma1), and N2 is the 2nd Gaussian, i.e., N(mu2,Sigma2);
+ | and where: Prec2 = (Sigma2)^{-1}, z = mu2 - mu1, and D is the data dimensionality
+
+ Args:
+ mu1: mean vector of first Gaussian distribution
+
+ Sigma1: covariance matrix of first Gaussian distribution
+
+ mu2: mean vector of second Gaussian distribution
+
+ Sigma2: covariance matrix of second Gaussian distribution
+
+ use_chol_prec: should this routine use Cholesky-factor computation of the precision of Sigma2 (Default: True)
+
+ Returns:
+ scalar representing KL-divergence between N(mu1, Sigma1) and N(mu2, Sigma2)
+ """
+ D = mu1.shape[1] ## dimensionality of data
+ ## log(|Sigma2|/|Sigma1|) = log(|Sigma2|) - log(|Sigma1|)
+ sgn_s1, val_s1 = jnp.linalg.slogdet(Sigma1)
+ log_detSigma1 = val_s1 * sgn_s1
+ sgn_s2, val_s2 = jnp.linalg.slogdet(Sigma2)
+ log_detSigma2 = val_s2 * sgn_s2
+
+ if use_chol_prec: ## use Cholesky-factor calc of (Sigma2)^{-1}
+ C = jnp.linalg.cholesky(Sigma2) ## cholesky factor matrix
+ inv_C = jnp.linalg.pinv(C)
+ Prec2 = jnp.matmul(inv_C.T, inv_C)
+ else:
+ Prec2 = jnp.linalg.pinv(Sigma2) ## pseudo-inverse calc of (Sigma2)^{-1}
+
+ trace_term = jnp.trace(jnp.dot(Prec2, Sigma1)) ## trace term of KL divergence
+ delta_mu = mu2 - mu1
+ quadratic_term = jnp.sum((jnp.matmul(delta_mu, Prec2) * delta_mu), axis=1, keepdims=True)
+ #quadratic_term = jnp.matmul(jnp.matmul(delta_mu.T, Prec2), delta_mu) ## quadratic term of KL divergence
+ # calc full KL divergence
+ kld = ((log_detSigma2 - log_detSigma1) + trace_term + quadratic_term - D) * 0.5
+ return kld
+
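A small check of the Gaussian KL routine above: the divergence of a Gaussian with itself should be (numerically) zero, while a shifted and wider second Gaussian yields a positive value (the ≈0.693 figure follows from the single-quadratic-term formula stated in the docstring):

import jax.numpy as jnp
from ngclearn.utils.metric_utils import measure_gaussian_KLD

mu = jnp.asarray([[0., 0.]])  # (1 x D) mean row-vector
Sigma = jnp.eye(2)
print(measure_gaussian_KLD(mu, Sigma, mu, Sigma))            # ~0
print(measure_gaussian_KLD(mu, Sigma, mu + 1., Sigma * 2.))  # ≈ 0.693 (> 0)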
@partial(jit, static_argnums=[3])
def measure_CatNLL(p, x, offset=1e-7, preserve_batch=False):
"""
- Measures the negative Categorical log likelihood (Cat.NLL). Note: If batch is
- preserved, this returns a column vector where each row is the
- Cat.NLL(p, x) for that row's datapoint.
+ Measures the negative Categorical log likelihood (Cat.NLL). Note: If batch is preserved, this returns a column
+ vector where each row is the Cat.NLL(p, x) for that row's datapoint.
Args:
p: predicted probabilities; (N x C matrix, where C is number of categories)
@@ -225,17 +303,36 @@ def measure_CatNLL(p, x, offset=1e-7, preserve_batch=False):
"""
p_ = jnp.clip(p, offset, 1.0 - offset)
loss = -(x * jnp.log(p_))
- nll = jnp.sum(loss, axis=1, keepdims=True) #/(y_true.shape[0] * 1.0)
- if preserve_batch is False:
+ nll = jnp.sum(loss, axis=1, keepdims=True) #/(y_true.shape[0] * 1.0) ## CatNLL-per-datapoint
+ if not preserve_batch:
nll = jnp.mean(nll)
return nll #tf.reduce_mean(nll)
+@jit
+def measure_RMSE(mu, x, preserve_batch=False):
+ """
+ Measures root mean squared error (RMSE). Note: If batch is preserved, this returns a column vector where each
+ row is the RMSE(mu, x) for that row's datapoint. (This is a simple wrapper/extension of the in-built MSE.)
+
+ Args:
+ mu: predicted values (mean); (N x D matrix)
+
+ x: target values (data); (N x D matrix)
+
+ preserve_batch: if True, will return one score per sample in batch
+ (Default: False), otherwise, returns scalar mean score
+
+ Returns:
+ an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
+ """
+ mse = measure_MSE(mu, x, preserve_batch=preserve_batch)
+ return jnp.sqrt(mse) ## sqrt(MSE) is the root-mean-squared-error
+
@jit
def measure_MSE(mu, x, preserve_batch=False):
"""
- Measures mean squared error (MSE), or the negative Gaussian log likelihood
- with variance of 1.0. Note: If batch is preserved, this returns a column
- vector where each row is the MSE(mu, x) for that row's datapoint.
+ Measures mean squared error (MSE), or the negative Gaussian log likelihood with variance of 1.0. Note: If batch
+ is preserved, this returns a column vector where each row is the MSE(mu, x) for that row's datapoint.
Args:
mu: predicted values (mean); (N x D matrix)
@@ -250,17 +347,40 @@ def measure_MSE(mu, x, preserve_batch=False):
"""
diff = mu - x
se = jnp.square(diff) ## squared error
- mse = jnp.sum(se, axis=1, keepdims=True) # technically se at this point
- if preserve_batch is False:
+ mse = jnp.sum(se, axis=1, keepdims=True) ## technically squared-error per data-point
+ if not preserve_batch:
mse = jnp.mean(mse) # this is proper mse
return mse
+@jit
+def measure_MAE(shift, x, preserve_batch=False):
+ """
+ Measures mean absolute error (MAE), or the negative Laplacian log likelihood with scale of 1.0. Note: If batch
+ is preserved, this returns a column vector where each row is the MAE(shift, x) for that row's datapoint.
+
+ Args:
+ shift: predicted values (mean); (N x D matrix)
+
+ x: target values (data); (N x D matrix)
+
+ preserve_batch: if True, will return one score per sample in batch
+ (Default: False), otherwise, returns scalar mean score
+
+ Returns:
+ an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
+ """
+ diff = shift - x
+ ae = jnp.abs(diff) ## absolute error
+ mae = jnp.sum(ae, axis=1, keepdims=True) ## technically abs-error per data-point
+ if not preserve_batch:
+ mae = jnp.mean(mae) # this is proper mae
+ return mae
+
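A tiny worked example for the regression scores above (each sums over the feature axis per row, then averages over the batch by default):

import jax.numpy as jnp
from ngclearn.utils.metric_utils import measure_MSE, measure_RMSE, measure_MAE

mu = jnp.asarray([[1., 2.]])
x = jnp.asarray([[0., 0.]])
print(measure_MSE(mu, x))   # 1^2 + 2^2 = 5.0
print(measure_RMSE(mu, x))  # sqrt(5) ≈ 2.2361
print(measure_MAE(mu, x))   # |1| + |2| = 3.0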
@jit
def measure_BCE(p, x, offset=1e-7, preserve_batch=False): #1e-10
"""
- Calculates the negative Bernoulli log likelihood or binary cross entropy (BCE).
- Note: If batch is preserved, this returns a column vector where each row is
- the BCE(p, x) for that row's datapoint.
+ Calculates the negative Bernoulli log likelihood or binary cross entropy (BCE). Note: If batch is preserved,
+ this returns a column vector where each row is the BCE(p, x) for that row's datapoint.
Args:
p: predicted probabilities of shape; (N x D matrix)
@@ -276,7 +396,7 @@ def measure_BCE(p, x, offset=1e-7, preserve_batch=False): #1e-10
an (N x 1) column vector (if preserve_batch=True) OR (1,1) scalar otherwise
"""
p_ = jnp.clip(p, offset, 1 - offset)
- bce = -jnp.sum(x * jnp.log(p_) + (1.0 - x) * jnp.log(1.0 - p_),axis=1, keepdims=True)
- if preserve_batch is False:
+ bce = -jnp.sum(x * jnp.log(p_) + (1.0 - x) * jnp.log(1.0 - p_),axis=1, keepdims=True) ## BCE-per-datapoint
+ if not preserve_batch:
bce = jnp.mean(bce)
return bce
diff --git a/ngclearn/utils/model_utils.py b/ngclearn/utils/model_utils.py
index facad87e..0ba13ece 100755
--- a/ngclearn/utils/model_utils.py
+++ b/ngclearn/utils/model_utils.py
@@ -6,7 +6,6 @@
import jax
from jax import numpy as jnp, grad, jit, vmap, random, lax, nn
from jax.lax import scan as _scan
-from ngcsimlib.utils import Get_Compartment_Batch, Set_Compartment_Batch, get_current_context
import os, sys
from functools import partial
import numpy as np
@@ -299,8 +298,14 @@ def d_relu(x):
@jit
def telu(x):
"""
- Proposed by Fernandez and Mali 24, https://arxiv.org/abs/2412.20269 and https://arxiv.org/abs/2402.02790
- TeLU activation: f(x) = x * tanh(e^x)
+ The hyperbolic tangent exponential linear (TeLU) function:
+
+ | f(x) = x * tanh(e^x)
+
+ This was proposed by Fernandez and Mali (2024) in:
+
+ | https://arxiv.org/abs/2412.20269 and in,
+ | https://arxiv.org/abs/2402.02790
Args:
x: input (tensor) value
@@ -313,8 +318,10 @@ def telu(x):
@jit
def d_telu(x):
"""
-
- Derivative of TeLU: f'(x) = tanh(e^x) + x * e^x * (1 - tanh^2(e^x))
+ Derivative of the hyperbolic tangent exponential linear (TeLU) function.
+ Effectively, this is formally:
+
+ | f'(x) = tanh(e^x) + x * e^x * (1 - tanh^2(e^x))
Args:
x: input (tensor) value
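A quick numerical check of the TeLU derivative formula above against JAX autodiff; this is a sketch that assumes `telu` and `d_telu` are importable from `ngclearn.utils.model_utils` as defined in this file:

import jax
from jax import numpy as jnp
from ngclearn.utils.model_utils import telu, d_telu

x = jnp.linspace(-3., 3., 7)
auto_grad = jax.vmap(jax.grad(telu))(x)    # autodiff of f(x) = x * tanh(e^x)
print(jnp.allclose(auto_grad, d_telu(x)))  # expect True (up to float tolerance)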
@@ -719,39 +726,39 @@ def d_clip(x, min_val, max_val):
return jnp.where((x < min_val) | (x > max_val), 0.0, 1.0)
-def scanner(fn):
- """
- A wrapper for Jax's scanner that handles the "getting" of the current
- state and "setting" of the final state to and from the model.
-
- | @scanner
- | def process(current_state, args):
- | t = args[0]
- | dt = args[1]
- | current_state = model.advance_state(current_state, t, dt)
- | current_state = model.evolve(current_state, t, dt)
- | return current_state, (current_state[COMPONENT.COMPARTMENT.path], ...)
- |
- | outputs = models.process(jnp.array([[ARG0, ARG1] for i in range(NUM_LOOPS)]))
-
- | Notes on the scanner function call:
- | 1) `current_state` is a hash-map mapped to all compartment values by path
- | 2) `args` is the external arguments defined in the passed Jax array
- | 3) `outputs` is a tuple containing time-concatenated Jax arrays of the
- | compartment statistics you want tracked
-
- Args:
- fn: function that is executed at every time step of a Jax-unrolled loop,
- it must take in the current state and external arguments
-
- Returns:
- wrapped (fast) function that is Jax-scanned/jit-i-fied
- """
- def _scanned(_xs):
- vals, stacked = _scan(fn, init=Get_Compartment_Batch(), xs=_xs)
- Set_Compartment_Batch(vals)
- return stacked
-
- if get_current_context() is not None:
- get_current_context().__setattr__(fn.__name__, _scanned)
- return _scanned
+# def scanner(fn):
+# """
+# A wrapper for Jax's scanner that handles the "getting" of the current
+# state and "setting" of the final state to and from the model.
+#
+# | @scanner
+# | def process(current_state, args):
+# | t = args[0]
+# | dt = args[1]
+# | current_state = model.advance_state(current_state, t, dt)
+# | current_state = model.evolve(current_state, t, dt)
+# | return current_state, (current_state[COMPONENT.COMPARTMENT.path], ...)
+# |
+# | outputs = models.process(jnp.array([[ARG0, ARG1] for i in range(NUM_LOOPS)]))
+#
+# | Notes on the scanner function call:
+# | 1) `current_state` is a hash-map mapped to all compartment values by path
+# | 2) `args` is the external arguments defined in the passed Jax array
+# | 3) `outputs` is a tuple containing time-concatenated Jax arrays of the
+# | compartment statistics you want tracked
+#
+# Args:
+# fn: function that is executed at every time step of a Jax-unrolled loop,
+# it must take in the current state and external arguments
+#
+# Returns:
+# wrapped (fast) function that is Jax-scanned/jit-i-fied
+# """
+# def _scanned(_xs):
+# vals, stacked = _scan(fn, init=Get_Compartment_Batch(), xs=_xs)
+# Set_Compartment_Batch(vals)
+# return stacked
+#
+# if get_current_context() is not None:
+# get_current_context().__setattr__(fn.__name__, _scanned)
+# return _scanned
diff --git a/ngclearn/utils/optim/adam.py b/ngclearn/utils/optim/adam.py
index 12b1d756..4fb5c87a 100644
--- a/ngclearn/utils/optim/adam.py
+++ b/ngclearn/utils/optim/adam.py
@@ -1,16 +1,11 @@
# %%
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.resolver import resolver
-
import numpy as np
from jax import jit, numpy as jnp, random, nn, lax
from functools import partial
-import time
-def step_update(param, update, g1, g2, lr, beta1, beta2, time, eps):
+def step_update(param, update, g1, g2, eta, beta1, beta2, time_step, eps):
"""
Runs one step of Adam over a set of parameters given updates.
The dynamics for any set of parameters is as follows:
@@ -33,24 +28,24 @@ def step_update(param, update, g1, g2, lr, beta1, beta2, time, eps):
g2: second moment factor/correction factor to use in parameter update
(must be same shape as "update")
- lr: global step size value to be applied to updates to parameters
+ eta: global step size value to be applied to updates to parameters
beta1: 1st moment control factor
beta2: 2nd moment control factor
- time: current time t or iteration step/call to this Adam update
+ time_step: current time t or iteration step/call to this Adam update
eps: numerical stability coefficient (for calculating final update)
Returns:
- adjusted parameter tensor (same shape as "param")
+ adjusted parameter tensor (same shape as "param"), adjusted g1, adjusted g2
"""
_g1 = beta1 * g1 + (1. - beta1) * update
_g2 = beta2 * g2 + (1. - beta2) * jnp.square(update)
- g1_unb = _g1 / (1. - jnp.power(beta1, time))
- g2_unb = _g2 / (1. - jnp.power(beta2, time))
- _param = param - lr * g1_unb/(jnp.sqrt(g2_unb) + eps)
+ g1_unb = _g1 / (1. - jnp.power(beta1, time_step))
+ g2_unb = _g2 / (1. - jnp.power(beta2, time_step))
+ _param = param - eta * g1_unb/(jnp.sqrt(g2_unb) + eps)
return _param, _g1, _g2
@jit
@@ -83,9 +78,7 @@ def adam_step(opt_params, theta, updates, eta=0.001, beta1=0.9, beta2=0.999, eps
new_g1 = []
new_g2 = []
for i in range(len(theta)):
- px_i, g1_i, g2_i = step_update(theta[i], updates[i], g1[i],
- g2[i], eta, beta1,
- beta2, time_step, eps)
+ px_i, g1_i, g2_i = step_update(theta[i], updates[i], g1[i], g2[i], eta, beta1, beta2, time_step, eps)
new_theta.append(px_i)
new_g1.append(g1_i)
new_g2.append(g2_i)
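A hypothetical usage sketch of the decoupled Adam update above, mirroring the pattern used by the new NAG optimizer below; `adam_init` is assumed to set up the (g1, g2, time) statistics for a list of parameter tensors, as implied by its use in `optim_utils.py`:

from jax import numpy as jnp
from ngclearn.utils.optim.adam import adam_init, adam_step

theta = [jnp.ones((3, 2)), jnp.ones((2,))]                # model parameter tensors
updates = [jnp.ones((3, 2)) * 0.1, jnp.ones((2,)) * 0.1]  # adjustments from credit assignment
opt_params = adam_init(theta)                             # optimizer statistics (assumed signature)
opt_params, theta = adam_step(opt_params, theta, updates, eta=0.001)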
diff --git a/ngclearn/utils/optim/nag.py b/ngclearn/utils/optim/nag.py
new file mode 100644
index 00000000..045be116
--- /dev/null
+++ b/ngclearn/utils/optim/nag.py
@@ -0,0 +1,84 @@
+# %%
+
+import numpy as np
+from jax import jit, numpy as jnp, random, nn, lax
+from functools import partial
+import time
+
+
+def step_update(param, update, phi_old, eta, mu, time_step):
+ """
+ Runs one step of Nesterov's accelerated gradient (NAG) over a set of parameters given updates.
+ The dynamics for any set of parameters is as follows:
+
+    | phi = param - update * eta
+    | param = phi + (phi - phi_old) * mu, where the momentum term is zeroed when t <= 1 (the first iteration)
+
+ Args:
+ param: parameter tensor to change/adjust
+
+ update: update tensor to be applied to parameter tensor (must be same
+ shape as "param")
+
+ phi_old: previous friction/momentum parameter
+
+ eta: global step size value to be applied to updates to parameters
+
+ mu: friction/momentum control factor
+
+ time_step: current time t or iteration step/call to this NAG update
+
+ Returns:
+ adjusted parameter tensor (same shape as "param"), adjusted momentum/friction variable
+ """
+ phi = param - update * eta ## do a phantom gradient adjustment step
+ _param = phi + (phi - phi_old) * (mu * (time_step > 1.)) ## NAG-step
+ _phi_old = phi
+ return _param, _phi_old
+
+@jit
+def nag_step(opt_params, theta, updates, eta=0.01, mu=0.9): ## apply adjustment to theta
+ """
+ Implements Nesterov's accelerated gradient (NAG) algorithm as a decoupled update rule given adjustments produced
+ by a credit assignment algorithm/process.
+
+ Args:
+ opt_params: (ArrayLike) parameters of the optimization algorithm
+
+ theta: (ArrayLike) the weights of neural network
+
+ updates: (ArrayLike) the updates of neural network
+
+        eta: (float, optional) step size coefficient for NAG update (Default: 0.01)
+
+ mu: (float, optional) friction/momentum control factor. (Default: 0.9)
+
+ Returns:
+ ArrayLike: opt_params. New opt params, ArrayLike: theta. The updated weights
+ """
+ phi, time_step = opt_params
+ time_step = time_step + 1
+ new_theta = []
+ new_phi = []
+ for i in range(len(theta)):
+ px_i, phi_i = step_update(theta[i], updates[i], phi[i], eta, mu, time_step)
+ new_theta.append(px_i)
+ new_phi.append(phi_i)
+ return (new_phi, time_step), new_theta
+
+@jit
+def nag_init(theta):
+ time_step = jnp.asarray(0.0)
+ phi = [jnp.zeros(theta[i].shape) for i in range(len(theta))]
+ return phi, time_step
+
+if __name__ == '__main__':
+ weights = [jnp.asarray([3.0, 3.0]), jnp.asarray([3.0, 3.0])]
+ updates = [jnp.asarray([3.0, 3.0]), jnp.asarray([3.0, 3.0])]
+ opt_params = nag_init(weights)
+ opt_params, theta = nag_step(opt_params, weights, updates)
+ print(f"opt_params: {opt_params}, theta: {theta}")
+ weights = theta
+ print("##################")
+ opt_params, theta = nag_step(opt_params, weights, updates)
+ print(f"opt_params: {opt_params}, theta: {theta}")
diff --git a/ngclearn/utils/optim/optim_utils.py b/ngclearn/utils/optim/optim_utils.py
index f02de676..e521c07b 100755
--- a/ngclearn/utils/optim/optim_utils.py
+++ b/ngclearn/utils/optim/optim_utils.py
@@ -1,17 +1,20 @@
import functools
from .sgd import sgd_step, sgd_init
+from .nag import nag_step, nag_init
from .adam import adam_step, adam_init
def get_opt_init_fn(opt='adam'):
return {
'adam': adam_init,
+ 'nag': nag_init,
'sgd': sgd_init
}[opt]
def get_opt_step_fn(opt='adam', **kwargs):
- # **kwargs here is the hyper parameters you want to pass in the optimization function
+ ## **kwargs here holds the hyper-parameters you want to pass to the chosen optimizer's step function
return {
'adam': functools.partial(adam_step, **kwargs),
+ 'nag': functools.partial(nag_step, **kwargs),
'sgd': functools.partial(sgd_step, **kwargs),
}[opt]
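A short sketch of selecting an optimizer through this registry (here NAG, with its hyper-parameters bound through the keyword arguments):

from jax import numpy as jnp
from ngclearn.utils.optim.optim_utils import get_opt_init_fn, get_opt_step_fn

theta = [jnp.ones((4, 4))]
updates = [jnp.ones((4, 4)) * 0.05]
init_fn = get_opt_init_fn('nag')                    # -> nag_init
step_fn = get_opt_step_fn('nag', eta=0.01, mu=0.9)  # -> nag_step with bound hyper-parameters
opt_params = init_fn(theta)
opt_params, theta = step_fn(opt_params, theta, updates)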
diff --git a/ngclearn/utils/optim/sgd.py b/ngclearn/utils/optim/sgd.py
index 68594d4a..dfde125c 100755
--- a/ngclearn/utils/optim/sgd.py
+++ b/ngclearn/utils/optim/sgd.py
@@ -1,30 +1,22 @@
-# %%
+from jax import jit, numpy as jnp
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.resolver import resolver
-
-import numpy as np
-from jax import jit, numpy as jnp, random, nn, lax
-from functools import partial
-import time
-
-def step_update(param, update, lr):
+def step_update(param, update, eta):
"""
Runs one step of SGD over a set of parameters given updates.
Args:
- lr: global step size to apply when adjusting parameters
+ eta: global step size to apply when adjusting parameters
Returns:
adjusted parameter tensor (same shape as "param")
"""
- _param = param - lr * update
+ _param = param - update * eta
return _param
@jit
def sgd_step(opt_params, theta, updates, eta=0.001): ## apply adjustment to theta
- """Return a params update
+ """
+    Returns updated parameters in accordance with a stochastic gradient descent (SGD) recipe.
Args:
opt_params: (ArrayLike) parameters of the optimization algorithm
@@ -51,7 +43,6 @@ def sgd_step(opt_params, theta, updates, eta=0.001): ## apply adjustment to thet
def sgd_init(theta):
return jnp.asarray(0.0)
-
if __name__ == '__main__':
opt_params, theta = sgd_step((2.0), [1.0, 1.0], [3.0, 4.0], 3e-2)
print(f"opt_params: {opt_params}, theta: {theta}")
diff --git a/ngclearn/utils/patch.py b/ngclearn/utils/patch.py
new file mode 100644
index 00000000..94ad9af6
--- /dev/null
+++ b/ngclearn/utils/patch.py
@@ -0,0 +1,101 @@
+from typing import Literal
+
+from jax import numpy as jnp
+
+from ngcsimlib.logger import error, warn
+
+
+class PatchGenerator(object):
+ def __init__(self,
+ patch_height: int,
+ patch_width: int,
+ horizontal_alignment: Literal['left', 'right', 'center', 'fit']=None,
+ vertical_alignment: Literal['top', 'bottom', 'center', 'fit']=None,
+ horizontal_stride: int | None = None,
+ vertical_stride: int | None = None):
+ self.horizontal_alignment = horizontal_alignment or 'left'
+ self.horizontal_stride = horizontal_stride or 0
+ self.patch_height = patch_height
+
+ self.vertical_alignment = vertical_alignment or 'top'
+ self.vertical_stride = vertical_stride or 0
+ self.patch_width = patch_width
+
+ self.idx_cache = {}
+
+ self._current_height = None
+ self._current_width = None
+
+ self._max_patch = None
+ self._current_idx = -1
+ self._current_img = None
+
+ def __iter__(self):
+ if self._current_img is None:
+ error("Attempting to generate patches but no image has been provided")
+
+ self._current_idx = 0
+ return self
+
+ def target(self, img: jnp.ndarray):
+ height, width = img.shape[:2]
+ if height == self._current_height and width == self._current_width:
+ self._current_img = img
+ return
+
+ if self.patch_height > height or self.patch_width > width:
+ warn("Image to small for patches to be extracted, aborting")
+ return
+
+ horizontal_idxs = []
+ vertical_idxs = []
+
+ actual_patch_width = self.patch_width - self.horizontal_stride
+ if self.horizontal_alignment == 'left':
+ horizontal_idxs += range(0, width-self.patch_width, actual_patch_width)
+ elif self.horizontal_alignment == 'right':
+ horizontal_idxs += [i - self.patch_width for i in range(width, self.patch_width, -actual_patch_width)]
+ elif self.horizontal_alignment == 'center':
+ centerx = width // 2
+ horizontal_idxs += range(centerx, width-self.patch_width, actual_patch_width)
+ horizontal_idxs += [i - self.patch_width for i in range(centerx, self.patch_width, -actual_patch_width)]
+ elif self.horizontal_alignment == 'fit':
+ extra = ((width - self.patch_width) % actual_patch_width) // 2
+ horizontal_idxs += range(extra, width - self.patch_width + 1,
+ actual_patch_width)
+ else:
+ pass
+
+ actual_patch_height = self.patch_height - self.vertical_stride
+ ## vertical patch start offsets (mirrors the horizontal logic above)
+ if self.vertical_alignment == 'top':
+ vertical_idxs += range(0, height-self.patch_height, actual_patch_height)
+ elif self.vertical_alignment == 'bottom':
+ vertical_idxs += [i - self.patch_height for i in range(height, self.patch_height, -actual_patch_height)]
+ elif self.vertical_alignment == 'center':
+ centery = height // 2
+ vertical_idxs += range(centery, height-self.patch_height, actual_patch_height)
+ vertical_idxs += [i - self.patch_height for i in range(centery, self.patch_height, -actual_patch_height)]
+ elif self.vertical_alignment == 'fit':
+ extra = ((height - self.patch_height) % actual_patch_height) // 2
+ vertical_idxs += range(extra, height - self.patch_height + 1, actual_patch_height)
+
+ print(horizontal_idxs)
+
+ img = jnp.zeros((len(horizontal_idxs), width))
+ for row, idx in enumerate(horizontal_idxs):
+ img = img.at[row, idx:idx + self.patch_width].set(
+ img[row, idx:idx + self.patch_width] + 50)
+
+ import matplotlib.pyplot as plt
+
+ plt.imshow(img)
+ plt.show()
+
+
+## testing code
+# gen = PatchGenerator(patch_width=5, patch_height=5, horizontal_alignment='center', horizontal_stride=1)
+#
+# test_img = jnp.zeros((32, 32))
+#
+# gen.target(test_img)
diff --git a/ngclearn/utils/patch_utils.py b/ngclearn/utils/patch_utils.py
index f3116e84..a558e82e 100755
--- a/ngclearn/utils/patch_utils.py
+++ b/ngclearn/utils/patch_utils.py
@@ -39,17 +39,19 @@ class Create_Patches:
Args:
img: jax array of size (H, W)
- patched: (height_patch, width_patch)
- overlap: (height_overlap, width_overlap)
- add_frame: increases the img size by (height_patch - height_overlap, width_patch - width_overlap)
- create_patches: creates small patches out of the image based on the provided attributes.
+ patch_shape: (height_patch, width_patch)
+
+ overlap_shape: (height_overlap, width_overlap)
Returns:
- jnp.array: Array containing the patches
- shape: (num_patches, patch_height, patch_width)
+ jnp.array: Array containing the patches, shape: (num_patches, patch_height, patch_width)
"""
+ #patched: (height_patch, width_patch)
+ #overlap: (height_overlap, width_overlap)
+ #add_frame: increases the img size by (height_patch - height_overlap, width_patch - width_overlap)
+ #create_patches: creates small patches out of the image based on the provided attributes.
def __init__(self, img, patch_shape, overlap_shape):
self.img = img
@@ -90,6 +92,8 @@ def create_patches(self, add_frame=False, center=True):
Keyword Args:
add_frame: If true the function will add zero frames (increase the dimension) to the image
+ center:
+
Returns:
jnp.array: Array containing the patches
shape: (num_patches, patch_height, patch_width)
diff --git a/ngclearn/utils/viz/compartment_plot.py b/ngclearn/utils/viz/compartment_plot.py
new file mode 100644
index 00000000..639d813a
--- /dev/null
+++ b/ngclearn/utils/viz/compartment_plot.py
@@ -0,0 +1,38 @@
+"""
+Compartment value (time-course) visualization functions/utilities.
+"""
+import matplotlib.pyplot as plt
+import jax
+from typing import Sequence
+
+def create_plot(history: jax.Array, ax: plt.Axes | None = None,
+ indices: Sequence[int] | None = None):
+ """
+ Generates a line plot of compartment values over time (the row dimension
+ corresponds to the discrete time dimension).
+
+ Args:
+ history: a numpy array of shape (T x number_of_compartments/neurons)
+
+ ax: a hook/pointer to a currently external plot that this plot
+ should be made a sub-figure of
+
+ indices: optional indices of compartments/neurons (column integer indices)
+ to focus on plotting
+
+ """
+ n_count = history.shape[1] ## number of traces (columns) to plot
+ if ax is None:
+ nc = n_count if indices is None else len(indices)
+ fig_size = 5 if nc < 25 else int(nc / 5)
+ plt.figure(figsize=(fig_size, fig_size))
+
+ _ax = ax if ax is not None else plt
+
+ for k in range(history.shape[1]):
+ if indices is None or k in indices:
+ _ax.plot(history[:, k])
\ No newline at end of file
diff --git a/ngclearn/utils/viz/compartment_raster.py b/ngclearn/utils/viz/compartment_raster.py
new file mode 100755
index 00000000..d66a73eb
--- /dev/null
+++ b/ngclearn/utils/viz/compartment_raster.py
@@ -0,0 +1,49 @@
+"""
+Raster visualization functions/utilities.
+"""
+import matplotlib.pyplot as plt
+import jax
+from typing import Sequence
+
+def create_raster_plot(spike_train: jax.Array, ax: plt.Axes | None = None,
+ indices: Sequence[int] | None = None, s=0.5, c="black"):
+ """
+ Generates a raster plot of a given (binary) spike train (row dimension
+ corresponds to the discrete time dimension).
+
+ Args:
+ spike_train: a numpy binary array of shape (T x number_of_neurons)
+
+ ax: a hook/pointer to a currently external plot that this raster plot
+ should be made a sub-figure of
+
+ indices: optional indices of neurons (row integer indices) to focus on
+ plotting
+
+ s: size of the spike scatter points (Default = 0.5)
+
+ c: color of the spike scatter points (Default = black)
+
+ """
+ step_count = spike_train.shape[0]
+ n_count = spike_train.shape[1]
+ if ax is None:
+ nc = n_count if indices is None else len(indices)
+ fig_size = 5 if nc < 25 else int(nc / 5)
+ plt.figure(figsize=(fig_size, fig_size))
+
+ _ax = ax if ax is not None else plt
+
+ events = []
+ for n in range(n_count): ## iterate over neuron/unit indices (columns)
+ if indices is None or n in indices:
+ e = spike_train[:, n].nonzero() ## time steps at which unit n spiked
+ events.append(e[0])
+ _ax.eventplot(events, linelengths=s, colors=c)
+ if ax is None:
+ _ax.yticks(ticks=[i for i in (range(n_count if indices is None else len(indices)))],
+ labels=["N" + str(i) for i in (range(n_count) if indices is None else indices)])
+ _ax.xticks(ticks=[i for i in range(0, step_count+1, max(int(step_count / 5), 1))])
+ else:
+ _ax.set_yticks(ticks=[i for i in (range(n_count if indices is None else len(indices)))],
+ labels=["N" + str(i) for i in (range(n_count) if indices is None else indices)])
diff --git a/ngclearn/utils/viz/dim_reduce.py b/ngclearn/utils/viz/dim_reduce.py
index 4fd8c244..3f32057d 100755
--- a/ngclearn/utils/viz/dim_reduce.py
+++ b/ngclearn/utils/viz/dim_reduce.py
@@ -3,8 +3,8 @@
default_cmap = plt.cm.jet
import numpy as np
-from sklearn.decomposition import IncrementalPCA
-from sklearn.manifold import TSNE
+from sklearn.decomposition import IncrementalPCA ## sci-kit learning dependency
+from sklearn.manifold import TSNE ## sci-kit learning dependency
def extract_pca_latents(vectors): ## PCA mapping routine
"""
@@ -20,7 +20,6 @@ def extract_pca_latents(vectors): ## PCA mapping routine
"""
batch_size = 50
z_dim = vectors.shape[1]
- z_2D = None
if z_dim != 2:
ipca = IncrementalPCA(n_components=2, batch_size=batch_size)
ipca.fit(vectors)
@@ -31,26 +30,25 @@ def extract_pca_latents(vectors): ## PCA mapping routine
def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500): ## tSNE mapping routine
"""
- Projects collection of K vectors (stored in a matrix) to a two-dimensional (2D)
- visualization space via the t-distributed stochastic neighbor embedding
- algorithm (t-SNE). This algorithm also uses PCA to produce an
- intermediate project to speed up the t-SNE final mapping step. Note that
- if the input already has a 2D dimensionality, the original input is returned.
+ Projects collection of K vectors (stored in a matrix) to a two-dimensional (2D) visualization space via the
+ t-distributed stochastic neighbor embedding algorithm (t-SNE). This algorithm also uses PCA to produce an
+ intermediate projection to speed up the final t-SNE mapping step. Note that if the input already has a 2D
+ dimensionality, the original input is returned.
Args:
vectors: a matrix/codebook of (K x D) vectors to project
perplexity: the perplexity control factor for t-SNE (Default: 30)
- batch_size: number of sampled embedding vectors to use per iteration
- of online internal PCA
+ n_pca_comp: number of top PCA components (sorted by eigenvalues) to retain/extract before continuing
+ with t-SNE dimensionality reduction
+
+ batch_size: number of sampled embedding vectors to use per iteration of online internal PCA
Returns:
a matrix (K x 2) of projected vectors (to 2D space)
"""
- #batch_size = 500 #50
z_dim = vectors.shape[1]
- z_2D = None
if z_dim != 2:
print(" > Projecting latents via iPCA...")
n_comp = n_pca_comp #32 #10 #16 #50
@@ -69,11 +67,10 @@ def extract_tsne_latents(vectors, perplexity=30, n_pca_comp=32, batch_size=500):
z_2D = vectors
return z_2D
-def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1.,
- cmap=None):
+def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1., cmap=None):
"""
- Produces a label-overlaid (label map to distinct colors) scatterplot for
- visualizing two-dimensional latent codes (produced by either PCA or t-SNE).
+ Produces a label-overlaid (label map to distinct colors) scatterplot for visualizing two-dimensional latent codes
+ (produced by either PCA or t-SNE).
Args:
code_vectors: a matrix of shape (K x 2) with vectors to plot/visualize
@@ -92,8 +89,7 @@ def plot_latents(code_vectors, labels, plot_fname="2Dcode_plot.jpg", alpha=1.,
matplotlib.use('Agg') ## temporarily go in Agg plt backend for tsne plotting
print(" > Plotting 2D latent encodings...")
curr_backend = plt.rcParams["backend"]
- matplotlib.use(
- 'Agg') ## temporarily go in Agg plt backend for tsne plotting
+ matplotlib.use('Agg') ## temporarily go in Agg plt backend for tsne plotting
lab = labels
if lab.shape[1] > 1: ## extract integer class labels from a one-hot matrix
lab = np.argmax(lab, 1)
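An end-to-end sketch of the 2D projection utilities touched above, run on a hypothetical codebook with one-hot labels:

import numpy as np
from ngclearn.utils.viz.dim_reduce import extract_tsne_latents, plot_latents

codes = np.random.randn(500, 64)                         # K=500 latent vectors, D=64
labels = np.eye(10)[np.random.randint(0, 10, size=500)]  # one-hot class labels
z_2D = extract_tsne_latents(codes, perplexity=30, n_pca_comp=32)
plot_latents(z_2D, labels, plot_fname="latents_2D.jpg")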
diff --git a/ngclearn/utils/weight_distribution.py b/ngclearn/utils/weight_distribution.py
deleted file mode 100755
index c3c20893..00000000
--- a/ngclearn/utils/weight_distribution.py
+++ /dev/null
@@ -1,269 +0,0 @@
-"""
-Weight distribution initialization routines and co-routines, including
-parameter mapping functions for standard initializers.
-"""
-import numpy as np
-import jax
-from jax import numpy as jnp, jit, vmap, lax, nn, random
-from ngcsimlib.logger import critical
-
-################################################################################
-## supported distribution initializer configuration generator routines
-
-def constant(value, **kwargs):
- """
- Produce a configuration for a constant weight distribution initializer.
-
- Args:
- value: magnitude of the weight values (shared across all)
-
- Returns:
- a constant weight initializer configuration
- """
- dist_dict = {"dist": "constant", "value": value}
- return {**kwargs, **dist_dict}
-
-def fan_in_gaussian(**kwargs):
- """
- Produce a configuration for a fan-in scaled (centered) Gaussian
- distribution initializer.
-
- Returns:
- a fan-in scaled Gaussian distribution configuration
- """
- dist_dict = {"dist": "fan_in_gaussian"}
- return {**kwargs, **dist_dict}
-
-def gaussian(mu=0., sigma=1., **kwargs):
- """
- Produce a configuration for a Gaussian distribution initializer.
-
- Args:
- mu: mean of the weight values (default: 0)
-
- sigma: standard deviation of the weight values (default: 1)
-
- Returns:
- a Gaussian distribution configuration
- """
- assert sigma >= 0.
- dist_dict = {"dist": "gaussian", "mu": mu, "sigma": sigma}
- return {**kwargs, **dist_dict}
-
-def uniform(amin=0., amax=1., **kwargs):
- """
- Produce a configuration for a uniform distribution initializer.
-
- Args:
- amin: minimum value/bound of weight values (default: 0)
-
- amax: maximum value/bound of weight values (default: 1)
-
- Returns:
- a uniform distribution configuration
- """
- assert amin < amax
- dist_dict = {"dist": "uniform", "amin": amin, "amax": amax}
- return {**kwargs, **dist_dict}
-
-def fan_in_uniform(**kwargs):
- """
- Produce a configuration for a fan-in scaled unit uniform
- distribution initializer.
-
- Returns:
- a fan-in scaled (unit) uniform distribution configuration
- """
- dist_dict = {"dist": "fan_in_uniform"}
- return {**kwargs, **dist_dict}
-
-def hollow(scale, **kwargs):
- """
- Produce a configuration for a constant hollow distribution initializer.
-
- Args:
- scale: magnitude of all off-diagonal values
-
- Returns:
- a constant hollow distribution configuration
- """
- dist_dict = {"dist": "hollow", "scale": scale}
- return {**kwargs, **dist_dict}
-
-def eye(scale, **kwargs):
- """
- Produce a configuration for a constant diagonal/eye distribution initializer.
-
- Args:
- scale: magnitude of all (on-)diagonal values
-
- Returns:
- a constant diagonal/eye distribution configuration
- """
- dist_dict = {"dist": "eye", "scale": scale}
- return {**kwargs, **dist_dict}
-
-################################################################################
-## initializer co-routine(s)
-
-def initialize_params(dkey, init_kernel, shape, use_numpy=False):
- """
- Creates the intiial condition values for a parameter tensor.
-
- Args:
- dkey: PRNG key to control determinism of this routine
-
- init_kernel: dictionary specifying the distribution type and its
- parameters (default: `uniform` dist w/ `amin=0.02`, `amax=0.8`) --
- note that kernel dictionary may contain "post-processing" arguments
- that can be "stacked" on top of the base matrix, for example, you
- can pass in a dictionary:
- {"dist": "uniform", "hollow": True, "lower_triangle": True} which
- will create unit-uniform value matrix with upper triangle and main
- diagonal values masked to zero (lower-triangle masking applied after
- hollow matrix masking)
-
- :Note: Currently supported distribution (dist) kernel schemes include:
- "constant" (value);
- "uniform" (amin, amax);
- "gaussian" (mu, sigma);
- "fan_in_gaussian" (NO params);
- "fan_in_uniform" (NO params);
- "hollow" (scale);
- "eye" (scale);
- while currently supported post-processing keyword arguments include:
- "amin" (clip weights values to be >= amin);
- "amax" (clip weights values to be <= amin);
- "lower_triangle" (extract lower triangle of params, set rest to 0);
- "upper_triangle" (extract upper triangle of params, set rest to 0);
- "hollow" (zero out values along main diagonal);
- "eye" (zero out off-diagonal values);
- "n_row_active" (keep only n random rows non-masked/zero);
- "n_col_active" (keep only n random columns non-masked/zero)
-
- shape: tuple containing the dimensions/shape of the tensor to initialize
-
- use_numpy: if true, conducts weight value initialization/post-processing using
- exclusively Numpy, disabling Jax calls (default: False)
-
- Returns:
- output (tensor) value
- """
- if dkey is None:
- use_numpy = True
-
- _init_kernel = init_kernel
- if _init_kernel is None: ## the "universal default distribution" if None provided
- critical("No initialization kernel provided!")
- dist_type = _init_kernel.get("dist")
- params = None
- if dist_type == "hollow": ## scaled hollow-matrix init
- diag_scale = _init_kernel.get("scale", 1.)
- if use_numpy:
- params = (1. - np.eye(N=shape[0], M=shape[1])) * diag_scale
- else:
- params = (1. - jnp.eye(N=shape[0], M=shape[1])) * diag_scale
- elif dist_type == "eye": ## scaled diagonal/eye init
- off_diag_scale = _init_kernel.get("scale", 1.)
- if use_numpy:
- params = np.eye(N=shape[0], M=shape[1]) * off_diag_scale
- else:
- params = jnp.eye(N=shape[0], M=shape[1]) * off_diag_scale
- elif dist_type == "gaussian" or dist_type == "normal": ## normal distrib
- mu = _init_kernel.get("mu", 0.)
- sigma = _init_kernel.get("sigma", 1.)
- if use_numpy:
- params = np.random.normal(size=shape) * sigma + mu
- else:
- params = jax.random.normal(dkey, shape) * sigma + mu
- elif dist_type == "uniform": ## uniform distrib
- amin = _init_kernel.get("amin", 0.)
- amax = _init_kernel.get("amax", 1.)
- if use_numpy:
- params = np.random.uniform(low=amin, high=amax, size=shape)
- else:
- params = jax.random.uniform(dkey, shape, minval=amin, maxval=amax)
- elif dist_type == "fan_in_gaussian": ## fan-in scaled standard normal init
- if use_numpy:
- phi = np.random.normal(size=shape)
- else:
- phi = jax.random.normal(dkey, shape)
- phi = phi * jnp.sqrt(1.0 / (shape[0] * 1.))
- params = phi.astype(jnp.float32)
- elif dist_type == "fan_in_uniform": ## fan-in scaled unit uniform init
- phi = jnp.sqrt(1.0 / (shape[0] * 1.)) # sometimes "k" in other libraries
- if use_numpy:
- params = np.random.uniform(low=-phi, high=phi, size=shape)
- else:
- params = jax.random.uniform(dkey, shape, minval=-phi, maxval=phi)
- params = params.astype(jnp.float32)
- elif dist_type == "constant": ## constant value (everywhere) init
- scale = _init_kernel.get("value", 1.)
- if use_numpy:
- params = np.ones(shape) * scale
- else:
- params = jnp.ones(shape) * scale
- else:
- critical("Initialization scheme (" + dist_type + ") is not recognized/supported!")
- ## check for any additional distribution post-processing kwargs (e.g., clipping)
- clip_min = _init_kernel.get("amin")
- clip_max = _init_kernel.get("amax")
- lower_triangle = init_kernel.get("lower_triangle", False)
- upper_triangle = init_kernel.get("upper_triangle", False)
- is_hollow = _init_kernel.get("hollow", False)
- is_eye = _init_kernel.get("eye", False)
- n_row_active = _init_kernel.get("n_row_active", None)
- n_col_active = _init_kernel.get("n_col_active", None)
- block_diag_mask_width = _init_kernel.get("block_diag_mask_width", None)
- ## run any configured post-processing to condition the final value matrix
- if clip_min is not None: ## bound all values to be > clip_min
- if use_numpy:
- params = np.maximum(params, clip_min)
- else:
- params = jnp.maximum(params, clip_min)
- if clip_max is not None: ## bound all values to be < clip_max
- if use_numpy:
- params = np.minimum(params, clip_max)
- else:
- params = jnp.minimum(params, clip_max)
- if block_diag_mask_width is not None:
- k = int(params.shape[0] / block_diag_mask_width) #5
- n = block_diag_mask_width #2
- source = jnp.eye(k, k)
- block_mask = jnp.repeat(jnp.repeat(source, n, axis=1), n, axis=0)
- if block_mask.shape[0] == params.shape[0] and block_mask.shape[1] == params.shape[1]:
- params = params * block_mask
- else:
- critical(
- "Initialization block matrix w/ width (" + block_diag_mask_width +
- ") is not recognized/supported!"
- )
- if lower_triangle: ## extract lower triangle of params matrix
- ltri_params = jax.numpy.tril(params.shape[0])
- params = ltri_params
- if upper_triangle: ## extract upper triangle of params matrix
- ltri_params = jax.numpy.triu(params.shape[0])
- params = ltri_params
- if is_hollow: ## apply a hollow mask
- if use_numpy:
- params = (1. - np.eye(N=shape[0], M=shape[1])) * params
- else:
- params = (1. - jnp.eye(N=shape[0], M=shape[1])) * params
- if is_eye: ## apply an eye/diagonal mask
- if use_numpy:
- params = np.eye(N=shape[0], M=shape[1]) * params
- else:
- params = jnp.eye(N=shape[0], M=shape[1]) * params
- if n_row_active is not None: ## keep only n rows active (rest are zero)
- row_ind = random.permutation(dkey, shape[0])[0:n_row_active]
- mask = jnp.zeros(shape)
- mask = mask.at[row_ind, :].set(jnp.ones((shape[0], 1))) ## only set keep rows to ones
- params = params * mask
- if n_col_active is not None: ## keep only n cols active (rest are zero)
- row_col = random.permutation(dkey, shape[1])[0:n_col_active]
- mask = jnp.zeros(shape)
- mask = mask.at[:, row_col].set(jnp.zeros((1, shape[0]))) ## only set keep cols to ones
- params = params * mask
-
- return params ## return initial distribution conditions
-
diff --git a/pyproject.toml b/pyproject.toml
index 71681a99..b50245fd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,8 +8,8 @@ build-backend = "setuptools.build_meta" # using setuptool building engine
[project]
name = "ngclearn"
-version = "2.0.3"
-description = "Simulation software for building and analyzing arbitrary predictive coding, spiking network, and biomimetic neural systems."
+version = "3.0.0"
+description = "Simulation software for building and analyzing computational neuroscience models, brain-inspired computing systems, and NeuroAI agents."
authors = [
{name = "Alexander Ororbia", email = "ago@cs.rit.edu"},
{name = "William Gebhardt", email = "wdg1351@rit.edu"},
diff --git a/requirements.txt b/requirements.txt
index 36285e9d..1e871c1b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,10 @@
-numpy>=1.22.0
+numpy>=1.26.4
scikit-learn>=1.6.1
-scipy>=1.7.0
-matplotlib>=3.8.0
-patchify
+scipy>=1.15.2
+matplotlib>=3.10.1
+# patchify # patchify has issues with pip installation
jax>=0.4.28
jaxlib>=0.4.28
-ngcsimlib>=1.0.1
-imageio>=2.31.5
+ngcsimlib>=3.0.0
+imageio>=2.37.0
pandas>=2.2.3
diff --git a/tests/components/input_encoders/test_bernoulliCell.py b/tests/components/input_encoders/test_bernoulliCell.py
index f73951b7..43c616e7 100644
--- a/tests/components/input_encoders/test_bernoulliCell.py
+++ b/tests/components/input_encoders/test_bernoulliCell.py
@@ -1,15 +1,13 @@
+# %%
+
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
from ngclearn.components import BernoulliCell
#from ngcsimlib.compilers import compile_command, wrap_command
from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngclearn.utils import JaxProcess
-from ngcsimlib.context import Context
-#from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import MethodProcess, Context
def test_bernoulliCell1():
@@ -23,16 +21,13 @@ def test_bernoulliCell1():
with Context(name) as ctx:
a = BernoulliCell(name="a", n_units=1, key=subkeys[0])
- advance_process = (JaxProcess("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
## set up non-compiled utility commands
- @Context.dynamicCommand
def clamp(x):
a.inputs.set(x)
@@ -40,12 +35,12 @@ def clamp(x):
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0,ts]]]) ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts*1., dt=dt)
- outs.append(a.outputs.value)
+ clamp(x_t)
+ advance_process.run(t=ts*1., dt=dt)
+ outs.append(a.outputs.get())
outs = jnp.concatenate(outs, axis=1)
## output should equal input
diff --git a/tests/components/input_encoders/test_latencyCell.py b/tests/components/input_encoders/test_latencyCell.py
index 19843e54..4abf0552 100644
--- a/tests/components/input_encoders/test_latencyCell.py
+++ b/tests/components/input_encoders/test_latencyCell.py
@@ -1,17 +1,11 @@
+# %%
+
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
from ngclearn.components import LatencyCell
-from ngcsimlib.compilers import compile_command, wrap_command
from numpy.testing import assert_array_equal
-
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
-
+from ngclearn import MethodProcess, Context
def test_latencyCell1():
name = "latency_ctx"
@@ -29,23 +23,19 @@ def test_latencyCell1():
)
## create and compile core simulation commands
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="advance")
- calc_spike_times_process = (Process("calc_sptimes_proc")
+ calc_spike_times_process = (MethodProcess("calc_sptimes_proc")
>> a.calc_spike_times)
- ctx.wrap_and_add_command(jit(calc_spike_times_process.pure), name="calc_spike_times")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
## set up non-compiled utility commands
- @Context.dynamicCommand
def clamp(x):
a.inputs.set(x)
## input spike train
- inputs = jnp.asarray([[0.02, 0.5, 1., 0.0]])
+ x_t = jnp.asarray([[0.02, 0.5, 1., 0.0]])
targets = np.zeros((T, 4))
targets[0, 2] = 1.
@@ -55,14 +45,14 @@ def clamp(x):
targets = jnp.array(targets) ## gold-standard solution to check against
outs = []
- ctx.reset()
- ctx.clamp(inputs)
- ctx.calc_spike_times()
+ reset_process.run()
+ clamp(x_t)
+ calc_spike_times_process.run()
for ts in range(T):
- ctx.clamp(inputs)
- ctx.advance(t=ts * dt, dt=dt)
+ clamp(x_t)
+ advance_process.run(t=ts * dt, dt=dt)
## naively extract simple statistics at time ts and print them to I/O
- s = a.outputs.value
+ s = a.outputs.get()
outs.append(s)
#print(" {}: s {} ".format(ts, jnp.squeeze(s)))
outs = jnp.concatenate(outs, axis=0)
diff --git a/tests/components/input_encoders/test_phasorCell.py b/tests/components/input_encoders/test_phasorCell.py
index 2f4735ac..d9091888 100644
--- a/tests/components/input_encoders/test_phasorCell.py
+++ b/tests/components/input_encoders/test_phasorCell.py
@@ -1,16 +1,9 @@
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
from ngclearn.components import PhasorCell
-#from ngcsimlib.compilers import compile_command, wrap_command
from numpy.testing import assert_array_equal
-
-from ngcsimlib.compilers.process import Process, transition
-#from ngcsimlib.component import Component
-#from ngcsimlib.compartment import Compartment
-#from ngcsimlib.context import Context
-#from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import MethodProcess, Context
def test_phasorCell1():
@@ -24,16 +17,13 @@ def test_phasorCell1():
with Context(name) as ctx:
a = PhasorCell(name="a", n_units=1, target_freq=1000., disable_phasor=True, key=subkeys[0])
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
## set up non-compiled utility commands
- @Context.dynamicCommand
def clamp(x):
a.inputs.set(x)
@@ -41,12 +31,12 @@ def clamp(x):
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.outputs.value)
+ clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt)
+ outs.append(a.outputs.get())
#print(a.outputs.value)
outs = jnp.concatenate(outs, axis=1)
#print(outs)
diff --git a/tests/components/input_encoders/test_poissonCell.py b/tests/components/input_encoders/test_poissonCell.py
index 10c05867..fd29a13b 100644
--- a/tests/components/input_encoders/test_poissonCell.py
+++ b/tests/components/input_encoders/test_poissonCell.py
@@ -1,16 +1,10 @@
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import PoissonCell
-from ngcsimlib.compilers import compile_command, wrap_command
+from ngclearn.components.input_encoders.poissonCell import PoissonCell
from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import MethodProcess, Context
def test_poissonCell1():
@@ -24,16 +18,13 @@ def test_poissonCell1():
with Context(name) as ctx:
a = PoissonCell(name="a", n_units=1, target_freq=1000., key=subkeys[0])
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
## set up non-compiled utility commands
- @Context.dynamicCommand
def clamp(x):
a.inputs.set(x)
@@ -41,12 +32,12 @@ def clamp(x):
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.outputs.value)
+ clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt)
+ outs.append(a.outputs.get())
outs = jnp.concatenate(outs, axis=1)
## output should equal input
diff --git a/tests/components/neurons/graded/test_RateCell.py b/tests/components/neurons/graded/test_RateCell.py
index bbd91d2b..ecd1ce9a 100644
--- a/tests/components/neurons/graded/test_RateCell.py
+++ b/tests/components/neurons/graded/test_RateCell.py
@@ -1,18 +1,12 @@
# %%
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import RateCell
-from ngcsimlib.compilers import compile_command, wrap_command
+from ngclearn.components.neurons.graded.rateCell import RateCell
from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import Context, MethodProcess
def test_RateCell1():
@@ -26,19 +20,11 @@ def test_RateCell1():
threshold=("none", 0.), integration_type="euler",
batch_size=1, resist_scale=1., shape=None, is_stateful=True
)
- advance_process = (Process("advance_proc") >> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc") >> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ advance_process = (MethodProcess("advance_proc") >> a.advance_state)
+ reset_process = (MethodProcess("reset_proc") >> a.reset)
- # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- # ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- # ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
-
- @Context.dynamicCommand
- def clamp(x):
- a.j.set(x)
+ def clamp(x):
+ a.j.set(x)
## input spike train
x_seq = jnp.ones((1, 10))
@@ -46,15 +32,16 @@ def clamp(x):
y_seq = jnp.asarray([[0.02, 0.04, 0.06, 0.08, 0.09999999999999999, 0.11999999999999998, 0.13999999999999999, 0.15999999999999998, 0.17999999999999998, 0.19999999999999998]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.z.value)
+ clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt)
+ outs.append(a.z.get())
outs = jnp.concatenate(outs, axis=1)
# print(outs)
## output should equal input
# assert_array_equal(outs, y_seq, tol=1e-3)
np.testing.assert_allclose(outs, y_seq, atol=1e-3)
+#test_RateCell1()
diff --git a/tests/components/neurons/graded/test_bernoulliErrorCell.py b/tests/components/neurons/graded/test_bernoulliErrorCell.py
index 897c6ef3..11d25c5f 100644
--- a/tests/components/neurons/graded/test_bernoulliErrorCell.py
+++ b/tests/components/neurons/graded/test_bernoulliErrorCell.py
@@ -1,19 +1,10 @@
# %%
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
from ngclearn.components import BernoulliErrorCell
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
-
+from ngclearn import MethodProcess, Context
def test_bernoulliErrorCell():
np.random.seed(42)
@@ -25,21 +16,12 @@ def test_bernoulliErrorCell():
a = BernoulliErrorCell(
name="a", n_units=1, batch_size=1, input_logits=False, shape=None
)
- advance_process = (Process("advance_proc") >> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc") >> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
-
- # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- # ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- # ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
+ advance_process = (MethodProcess("advance_proc") >> a.advance_state)
+ reset_process = (MethodProcess("reset_proc") >> a.reset)
- @Context.dynamicCommand
def clamp(x):
a.p.set(x)
- @Context.dynamicCommand
def clamp_target(x):
a.target.set(x)
@@ -50,14 +32,14 @@ def clamp_target(x):
y_seq = jnp.asarray([[-2.8193381, -4976.9263, -2.1224928, -2939.0425, -1233.3916, -0.24662945, -708.30042, 0.28213939, 3550.8477, 1.3651246]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- ctx.clamp(x_t)
+ clamp(x_t)
target_xt = jnp.array([[target_seq[0, ts]]])
- ctx.clamp_target(target_xt)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.dp.value)
+ clamp_target(target_xt)
+ advance_process.run(t=ts * 1., dt=dt)
+ outs.append(a.dp.get())
outs = jnp.concatenate(outs, axis=1)
# print(outs)
## output should equal input
diff --git a/tests/components/neurons/graded/test_gaussianErrorCell.py b/tests/components/neurons/graded/test_gaussianErrorCell.py
index 1dd2a2e1..9c52746b 100644
--- a/tests/components/neurons/graded/test_gaussianErrorCell.py
+++ b/tests/components/neurons/graded/test_gaussianErrorCell.py
@@ -1,18 +1,11 @@
# %%
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
from ngclearn.components import GaussianErrorCell
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import MethodProcess, Context
def test_gaussianErrorCell():
@@ -25,21 +18,12 @@ def test_gaussianErrorCell():
a = GaussianErrorCell(
name="a", n_units=1, batch_size=1, sigma=1.0, shape=None
)
- advance_process = (Process("advance_proc") >> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc") >> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ advance_process = (MethodProcess("advance_proc") >> a.advance_state)
+ reset_process = (MethodProcess("reset_proc") >> a.reset)
- # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- # ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- # ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
-
- @Context.dynamicCommand
def clamp_mu(x):
a.mu.set(x)
- @Context.dynamicCommand
def clamp_target(x):
a.target.set(x)
@@ -53,15 +37,15 @@ def clamp_target(x):
dmu_outs = []
L_outs = []
- ctx.reset()
+ reset_process.run()
for ts in range(mu_seq.shape[1]):
mu_t = jnp.array([[mu_seq[0, ts]]]) ## get data at time t
- ctx.clamp_mu(mu_t)
+ clamp_mu(mu_t)
target_t = jnp.array([[target_seq[0, ts]]])
- ctx.clamp_target(target_t)
- ctx.run(t=ts * 1., dt=dt)
- dmu_outs.append(a.dmu.value)
- L_outs.append(a.L.value)
+ clamp_target(target_t)
+ advance_process.run(t=ts * 1., dt=dt)
+ dmu_outs.append(a.dmu.get())
+ L_outs.append(a.L.get())
dmu_outs = jnp.concatenate(dmu_outs, axis=1)
L_outs = jnp.array(L_outs)[None] # (1, 10)
@@ -74,4 +58,4 @@ def clamp_target(x):
np.testing.assert_allclose(dmu_outs, expected_dmu, atol=1e-5)
np.testing.assert_allclose(L_outs, expected_L, atol=1e-5)
-# test_gaussianErrorCell()
\ No newline at end of file
+# test_gaussianErrorCell()
diff --git a/tests/components/neurons/graded/test_laplacianErrorCell.py b/tests/components/neurons/graded/test_laplacianErrorCell.py
index 4167bad9..16cd8539 100644
--- a/tests/components/neurons/graded/test_laplacianErrorCell.py
+++ b/tests/components/neurons/graded/test_laplacianErrorCell.py
@@ -1,21 +1,12 @@
# %%
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
from ngclearn.components import LaplacianErrorCell
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
+from ngclearn import MethodProcess, Context
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
-
-
-def test_laplacianErrorCell():
+def test_laplacianErrorCell1():
np.random.seed(42)
name = "laplacian_error_ctx"
dkey = random.PRNGKey(42)
@@ -25,27 +16,18 @@ def test_laplacianErrorCell():
a = LaplacianErrorCell(
name="a", n_units=1, batch_size=1, scale=1.0, shape=None
)
- advance_process = (Process("advance_proc") >> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc") >> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
-
- # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- # ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- # ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
+
+ advance_process = (MethodProcess("advance_proc") >> a.advance_state)
+ reset_process = (MethodProcess("reset_proc") >> a.reset)
- @Context.dynamicCommand
- def clamp_modulator(x):
- a.modulator.set(x)
+ def clamp_modulator(x):
+ a.modulator.set(x)
- @Context.dynamicCommand
- def clamp_shift(x):
- a.shift.set(x)
+ def clamp_shift(x):
+ a.shift.set(x)
- @Context.dynamicCommand
- def clamp_target(x):
- a.target.set(x)
+ def clamp_target(x):
+ a.target.set(x)
## input sequence
modulator_seq = jnp.ones((1, 10))
@@ -59,22 +41,22 @@ def clamp_target(x):
dshift_outs = []
L_outs = []
- ctx.reset()
+ reset_process.run()
for ts in range(shift_seq.shape[1]):
shift_t = jnp.array([[shift_seq[0, ts]]]) ## get data at time t
- ctx.clamp_shift(shift_t)
+ clamp_shift(shift_t)
modulator_t = jnp.array([[modulator_seq[0, ts]]])
- ctx.clamp_modulator(modulator_t)
+ clamp_modulator(modulator_t)
target_t = jnp.array([[target_seq[0, ts]]])
- ctx.clamp_target(target_t)
- ctx.run(t=ts * 1., dt=dt)
- dshift_outs.append(a.dshift.value)
+ clamp_target(target_t)
+ advance_process.run(t=ts * 1., dt=dt)
+ dshift_outs.append(a.dshift.get())
# print(f"a.L.value: {a.L.value}")
# print(f"a.shift.value: {a.shift.value}")
# print(f"a.target.value: {a.target.value}")
# print(f"a.Scale.value: {a.Scale.value}")
# print(f"a.mask.value: {a.mask.value}")
- L_outs.append(a.L.value)
+ L_outs.append(a.L.get())
dshift_outs = jnp.concatenate(dshift_outs, axis=1)
L_outs = jnp.array(L_outs)[None] # (1, 10)
@@ -87,3 +69,4 @@ def clamp_target(x):
np.testing.assert_allclose(dshift_outs, expected_dshift, atol=1e-5)
np.testing.assert_allclose(L_outs, expected_L, atol=1e-5)
+#test_laplacianErrorCell1()
diff --git a/tests/components/neurons/graded/test_leakyNoiseCell.py b/tests/components/neurons/graded/test_leakyNoiseCell.py
new file mode 100644
index 00000000..096c4f68
--- /dev/null
+++ b/tests/components/neurons/graded/test_leakyNoiseCell.py
@@ -0,0 +1,47 @@
+# %%
+
+from jax import numpy as jnp, random, jit
+import numpy as np
+np.random.seed(42)
+from ngclearn.components.neurons.graded.leakyNoiseCell import LeakyNoiseCell
+from numpy.testing import assert_array_equal
+
+from ngclearn import Context, MethodProcess
+
+
+def test_LeakyNoiseCell1():
+ name = "leaky_noise_ctx"
+ dkey = random.PRNGKey(42)
+ dkey, *subkeys = random.split(dkey, 100)
+ dt = 1. # ms
+ with Context(name) as ctx:
+ a = LeakyNoiseCell(
+ name="a", n_units=1, tau_x=50., act_fx="identity", integration_type="euler", batch_size=1, sigma_rec=0.,
+ leak_scale=0.
+ )
+ advance_process = (MethodProcess("advance_proc") >> a.advance_state)
+ reset_process = (MethodProcess("reset_proc") >> a.reset)
+
+ def clamp(x):
+ a.j_input.set(x)
+
+ ## input spike train
+ x_seq = jnp.ones((1, 10))
+ ## desired output/epsp pulses
+ y_seq = jnp.asarray([[0.02, 0.04, 0.06, 0.08, 0.09999999999999999, 0.11999999999999998, 0.13999999999999999, 0.15999999999999998, 0.17999999999999998, 0.19999999999999998]], dtype=jnp.float32)
+
+ outs = []
+ reset_process.run()
+ for ts in range(x_seq.shape[1]):
+ x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
+ clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt)
+ outs.append(a.x.get())
+ outs = jnp.concatenate(outs, axis=1)
+ # print(outs)
+ # print(y_seq)
+ ## output should approximately equal input
+ # assert_array_equal(outs, y_seq, tol=1e-3)
+ np.testing.assert_allclose(outs, y_seq, atol=1e-3)
+
+#test_LeakyNoiseCell1()
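
# The gold-standard y_seq in the new LeakyNoiseCell test follows from plain Euler
# integration: with tau_x = 50, dt = 1, sigma_rec = 0 and leak_scale = 0, a constant
# clamped input of 1 adds dt / tau_x = 0.02 to the state on every step (this assumes
# the cell integrates roughly dx/dt = (input - leak_scale * x) / tau_x, which is
# inferred from the constructor arguments rather than stated anywhere in this patch).
tau_x, dt_step, n_steps = 50., 1., 10
expected = [dt_step / tau_x * (k + 1) for k in range(n_steps)]
# -> [0.02, 0.04, 0.06, ..., ~0.2], matching y_seq within the test's atol of 1e-3.
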
diff --git a/tests/components/neurons/graded/test_rewardErrorCell.py b/tests/components/neurons/graded/test_rewardErrorCell.py
index 6ecb7710..e465d07c 100644
--- a/tests/components/neurons/graded/test_rewardErrorCell.py
+++ b/tests/components/neurons/graded/test_rewardErrorCell.py
@@ -1,18 +1,11 @@
# %%
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
from ngclearn.components import RewardErrorCell
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import MethodProcess, Context
def test_rewardErrorCell():
@@ -27,21 +20,10 @@ def test_rewardErrorCell():
name="a", n_units=1, alpha=alpha, ema_window_len=10,
use_online_predictor=True, batch_size=1
)
- advance_process = (Process("advance_proc") >> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc") >> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- evolve_process = (Process("evolve_proc") >> a.evolve)
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
+ advance_process = (MethodProcess("advance_proc") >> a.advance_state)
+ reset_process = (MethodProcess("reset_proc") >> a.reset)
+ evolve_process = (MethodProcess("evolve_proc") >> a.evolve)
- # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- # ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- # ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- # evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
- # ctx.add_command(wrap_command(jit(ctx.evolve)), name="evolve")
-
- @Context.dynamicCommand
def clamp_reward(x):
a.reward.set(x)
@@ -71,18 +53,18 @@ def clamp_reward(x):
mu_outs = []
rpe_outs = []
accum_reward_outs = []
- ctx.reset()
+ reset_process.run()
for ts in range(reward_seq.shape[1]):
reward_t = jnp.array([[reward_seq[0, ts]]]) ## get reward at time t
- ctx.clamp_reward(reward_t)
- ctx.run(t=ts * 1., dt=dt)
- mu_outs.append(a.mu.value)
- rpe_outs.append(a.rpe.value)
- accum_reward_outs.append(a.accum_reward.value)
+ clamp_reward(reward_t)
+ advance_process.run(t=ts * 1., dt=dt)
+ mu_outs.append(a.mu.get())
+ rpe_outs.append(a.rpe.get())
+ accum_reward_outs.append(a.accum_reward.get())
# Test evolve function
- ctx.evolve(t=10 * 1., dt=dt)
- final_mu = a.mu.value
+ evolve_process.run(t=10 * 1., dt=dt)
+ final_mu = a.mu.get()
# print(f"final_mu: {final_mu}")
mu_outs = jnp.concatenate(mu_outs, axis=1)
@@ -103,4 +85,4 @@ def clamp_reward(x):
expected_final_mu = (1 - 1/10) * mu_outs[0, -1] + (1/10) * (accum_reward_outs[0, -1] / 10)
np.testing.assert_allclose(final_mu, expected_final_mu, atol=1e-5)
-# test_rewardErrorCell()
\ No newline at end of file
+#test_rewardErrorCell()
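
# The expected_final_mu check in this test is a single exponential-moving-average
# step over the accumulated reward, with window length n = ema_window_len = 10:
#   mu_new = (1 - 1/n) * mu_old + (1/n) * (accum_reward / n)
# A tiny standalone version of that arithmetic (hypothetical helper, not part of the
# RewardErrorCell API):
def ema_step(mu_old, accum_reward, n=10):
    return (1. - 1. / n) * mu_old + (1. / n) * (accum_reward / n)
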
diff --git a/tests/components/neurons/spiking/test_IFCell.py b/tests/components/neurons/spiking/test_IFCell.py
index 28f3d8c0..3db38d72 100644
--- a/tests/components/neurons/spiking/test_IFCell.py
+++ b/tests/components/neurons/spiking/test_IFCell.py
@@ -2,15 +2,10 @@
from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import IFCell
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import Context, MethodProcess
+from ngclearn.components.neurons.spiking.IFCell import IFCell
+from numpy.testing import assert_array_equal
def test_IFCell1():
name = "if_ctx"
@@ -18,35 +13,28 @@ def test_IFCell1():
dkey = random.PRNGKey(1234)
dkey, *subkeys = random.split(dkey, 6)
dt = 1. # ms
- trace_increment = 0.1
# ---- build a simple Poisson cell system ----
with Context(name) as ctx:
a = IFCell(
name="a", n_units=1, tau_m=5., resist_m=10., key=subkeys[0]
)
- #"""
- advance_process = (Process("advance_proc")
+ # """
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- #ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
+ # ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- """
-
+ # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ # """
## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- a.j.set(x)
+ # @Context.dynamicCommand
+ # def clamp(x):
+ # a.j.set(x)
+
+ def clamp(x):
+ a.j.set(x)
## input spike train
x_seq = jnp.asarray([[1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0.]], dtype=jnp.float32)
@@ -54,15 +42,16 @@ def clamp(x):
y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run() # ctx.reset()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.s.value)
+ clamp(x_t) # ctx.clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt)
+ outs.append(a.s.get())
outs = jnp.concatenate(outs, axis=1)
- print(outs)
-
+ # print(outs)
+ # print(y_seq)
+
## output should equal input
assert_array_equal(outs, y_seq)
diff --git a/tests/components/neurons/spiking/test_LIFCell.py b/tests/components/neurons/spiking/test_LIFCell.py
index 6f5f7c1a..b918d9a1 100644
--- a/tests/components/neurons/spiking/test_LIFCell.py
+++ b/tests/components/neurons/spiking/test_LIFCell.py
@@ -2,15 +2,10 @@
from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import LIFCell
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import Context, MethodProcess
+from ngclearn.components.neurons.spiking.LIFCell import LIFCell
+from numpy.testing import assert_array_equal
def test_LIFCell1():
name = "lif_ctx"
@@ -26,27 +21,21 @@ def test_LIFCell1():
)
#"""
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- #ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
+ #ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ #ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
#"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- """
-
## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- a.j.set(x)
+ # @Context.dynamicCommand
+ # def clamp(x):
+ # a.j.set(x)
+
+ def clamp(x):
+ a.j.set(x)
## input spike train
x_seq = jnp.asarray([[1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0.]], dtype=jnp.float32)
@@ -54,15 +43,16 @@ def clamp(x):
y_seq = jnp.asarray([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run() #ctx.reset()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.s.value)
+ clamp(x_t) #ctx.clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt)
+ outs.append(a.s.get())
outs = jnp.concatenate(outs, axis=1)
- #print(outs)
-
+ # print(outs)
+ # print(y_seq)
+
## output should equal input
assert_array_equal(outs, y_seq)
diff --git a/tests/components/neurons/spiking/test_RAFCell.py b/tests/components/neurons/spiking/test_RAFCell.py
index a8a7fbfc..3a076ba6 100644
--- a/tests/components/neurons/spiking/test_RAFCell.py
+++ b/tests/components/neurons/spiking/test_RAFCell.py
@@ -1,17 +1,11 @@
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
import numpy as np
-
np.random.seed(42)
-from ngclearn.components import RAFCell
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import Context, MethodProcess
+from ngclearn.components.neurons.spiking.RAFCell import RAFCell
+from numpy.testing import assert_array_equal
def test_RAFCell1():
@@ -26,28 +20,22 @@ def test_RAFCell1():
name="a", n_units=1, tau_v=20., resist_v=1., key=subkeys[0]
)
- #"""
- advance_process = (Process("advance_proc")
+ # """
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
+ # ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- """
-
+ # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ # """
## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- a.j.set(x)
+ # @Context.dynamicCommand
+ # def clamp(x):
+ # a.j.set(x)
+
+ def clamp(x):
+ a.j.set(x)
## input spike train
x_seq = jnp.asarray([[0., 1., 0., 0., 0., 0., 1., 0., 0.]], dtype=jnp.float32)
@@ -55,14 +43,13 @@ def clamp(x):
y_seq = jnp.asarray([[0., 0., 0., 1., 0., 0., 0., 0., 1.]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run() # ctx.reset()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.s.value)
+ clamp(x_t) # ctx.clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt)
+ outs.append(a.s.get())
outs = jnp.concatenate(outs, axis=1)
- #print(outs)
## output should equal input
assert_array_equal(outs, y_seq)
diff --git a/tests/components/neurons/spiking/test_WTASCell.py b/tests/components/neurons/spiking/test_WTASCell.py
index b56b87e5..82384701 100644
--- a/tests/components/neurons/spiking/test_WTASCell.py
+++ b/tests/components/neurons/spiking/test_WTASCell.py
@@ -1,17 +1,11 @@
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import WTASCell
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import Context, MethodProcess
+from ngclearn.components.neurons.spiking.WTASCell import WTASCell
+from numpy.testing import assert_array_equal
def test_WTASCell1():
@@ -27,27 +21,22 @@ def test_WTASCell1():
)
#"""
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess(name="advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
+ #ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess(name="reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ #ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
#"""
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- """
+ # ## set up non-compiled utility commands
+ # @Context.dynamicCommand
+ # def clamp(x):
+ # a.j.set(x)
- ## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- a.j.set(x)
+ def clamp(x):
+ a.j.set(x)
## input spike train
x_seq = jnp.asarray([[0., 1.], [0., 1.], [1., 0.], [1., 0.]], dtype=jnp.float32)
@@ -55,14 +44,15 @@ def clamp(x):
y_seq = x_seq
outs = []
- ctx.reset()
+ reset_process.run()
for ts in range(x_seq.shape[0]):
x_t = x_seq[ts:ts+1, :] ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.s.value)
+ clamp(x_t) #ctx.clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt)
+ outs.append(a.s.get())
outs = jnp.concatenate(outs, axis=0)
- #print(outs)
+ # print(outs)
+ # print(y_seq)
#exit()
## output should equal input
assert_array_equal(outs, y_seq)
diff --git a/tests/components/neurons/spiking/test_adExCell.py b/tests/components/neurons/spiking/test_adExCell.py
index 2c0b9338..cb1dd528 100644
--- a/tests/components/neurons/spiking/test_adExCell.py
+++ b/tests/components/neurons/spiking/test_adExCell.py
@@ -1,17 +1,11 @@
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
import numpy as np
-
np.random.seed(42)
-from ngclearn.components import AdExCell
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import Context, MethodProcess
+from ngclearn.components.neurons.spiking.adExCell import AdExCell
+from numpy.testing import assert_array_equal
def test_adExCell1():
@@ -26,28 +20,22 @@ def test_adExCell1():
name="a", n_units=1, tau_m=50., resist_m=30., thr=-66., key=subkeys[0]
)
- #"""
- advance_process = (Process("advance_proc")
+ # """
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
+ # ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- """
-
+ # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ # """
## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- a.j.set(x)
+ # @Context.dynamicCommand
+ # def clamp(x):
+ # a.j.set(x)
+
+ def clamp(x):
+ a.j.set(x)
## input spike train
x_seq = jnp.ones((1, 10))
@@ -55,16 +43,18 @@ def clamp(x):
y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run() # ctx.reset()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.s.value)
+ clamp(x_t) # ctx.clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt)
+ outs.append(a.s.get())
+
outs = jnp.concatenate(outs, axis=1)
- #print(outs)
+ # print(outs)
+ # print(y_seq)
## output should equal input
assert_array_equal(outs, y_seq)
-#test_adExCell1()
+#test_adExCell1()
diff --git a/tests/components/neurons/spiking/test_fitzhughNagumoCell.py b/tests/components/neurons/spiking/test_fitzhughNagumoCell.py
index eecc28e5..5ca0f489 100644
--- a/tests/components/neurons/spiking/test_fitzhughNagumoCell.py
+++ b/tests/components/neurons/spiking/test_fitzhughNagumoCell.py
@@ -1,17 +1,11 @@
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
import numpy as np
-
np.random.seed(42)
-from ngclearn.components import FitzhughNagumoCell
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import Context, MethodProcess
+from ngclearn.components.neurons.spiking.fitzhughNagumoCell import FitzhughNagumoCell
+from numpy.testing import assert_array_equal
def test_fitzhughNagumoCell1():
@@ -26,28 +20,22 @@ def test_fitzhughNagumoCell1():
name="a", n_units=1, tau_m=1., resist_m=5., v_thr=2.1, key=subkeys[0]
)
- #"""
- advance_process = (Process("advance_proc")
+ # """
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
+ # ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- """
-
+ # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ # """
## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- a.j.set(x)
+ # @Context.dynamicCommand
+ # def clamp(x):
+ # a.j.set(x)
+
+ def clamp(x):
+ a.j.set(x)
## input spike train
x_seq = jnp.asarray([[0., 0., 1., 1., 1., 1., 0., 0., 0., 0.]], dtype=jnp.float32)
@@ -55,14 +43,16 @@ def clamp(x):
y_seq = jnp.asarray([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run() # ctx.reset()
for ts in range(x_seq.shape[1]):
- x_t = x_seq[:, ts:ts+1] ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.s.value)
+ x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
+ clamp(x_t) # ctx.clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt)
+ outs.append(a.s.get())
+
outs = jnp.concatenate(outs, axis=1)
- #print(outs)
+ # print(outs)
+ # print(y_seq)
## output should equal input
assert_array_equal(outs, y_seq)
diff --git a/tests/components/neurons/spiking/test_hodgkinHuxleyCell.py b/tests/components/neurons/spiking/test_hodgkinHuxleyCell.py
index d86c3fd0..aeb80c48 100644
--- a/tests/components/neurons/spiking/test_hodgkinHuxleyCell.py
+++ b/tests/components/neurons/spiking/test_hodgkinHuxleyCell.py
@@ -1,17 +1,11 @@
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
import numpy as np
-
np.random.seed(42)
-from ngclearn.components import HodgkinHuxleyCell
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_almost_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import Context, MethodProcess
+from ngclearn.components.neurons.spiking.hodgkinHuxleyCell import HodgkinHuxleyCell
+from numpy.testing import assert_array_almost_equal
import matplotlib.pyplot as plt
@@ -30,27 +24,21 @@ def test_hodgkinHuxleyCell1():
)
# """
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
+ # ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
# """
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- """
-
## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- a.j.set(x)
+ # @Context.dynamicCommand
+ # def clamp(x):
+ # a.j.set(x)
+
+ def clamp(x):
+ a.j.set(x)
## input spike train
x_seq = jnp.zeros((1, 20))
@@ -61,12 +49,15 @@ def clamp(x):
0.40085957, 0.42394499, 0.44698984, 0.46999594]], dtype=jnp.float32)
v = []
- ctx.reset()
+ reset_process.run() # ctx.reset()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- v.append(a.v.value[0, 0])
+ clamp(x_t) # ctx.clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt)
+ v.append(a.v.get()[0, 0])
+ # print(outs)
+ # print(y_seq)
+
outs = jnp.array(v)
diff = np.abs(outs - y_seq)
## delta/error should be approximately zero
diff --git a/tests/components/neurons/spiking/test_izhikevichCell.py b/tests/components/neurons/spiking/test_izhikevichCell.py
index 165752d9..04ec6bcb 100644
--- a/tests/components/neurons/spiking/test_izhikevichCell.py
+++ b/tests/components/neurons/spiking/test_izhikevichCell.py
@@ -1,17 +1,11 @@
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
import numpy as np
-
np.random.seed(42)
-from ngclearn.components import IzhikevichCell
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import Context, MethodProcess
+from ngclearn.components.neurons.spiking.izhikevichCell import IzhikevichCell
+from numpy.testing import assert_array_equal
def test_izhikevichCell1():
@@ -26,28 +20,22 @@ def test_izhikevichCell1():
name="a", n_units=1, tau_m=1., resist_m=4., v_thr=30., key=subkeys[0]
)
- #"""
- advance_process = (Process("advance_proc")
+ # """
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
+ # ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- """
-
+ # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ # """
## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- a.j.set(x)
+ # @Context.dynamicCommand
+ # def clamp(x):
+ # a.j.set(x)
+
+ def clamp(x):
+ a.j.set(x)
## input spike train
x_seq = jnp.asarray([[0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]], dtype=jnp.float32)
@@ -55,16 +43,16 @@ def clamp(x):
y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run() # ctx.reset()
for ts in range(x_seq.shape[1]):
- x_t = x_seq[:, ts:ts+1] ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.s.value)
- print(a.v.value)
+ x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
+ clamp(x_t) # ctx.clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt)
+ outs.append(a.s.get())
outs = jnp.concatenate(outs, axis=1)
- print(outs)
- #exit()
+ # print(outs)
+ # print(y_seq)
+
## output should equal input
assert_array_equal(outs, y_seq)
diff --git a/tests/components/neurons/spiking/test_quadLIFCell.py b/tests/components/neurons/spiking/test_quadLIFCell.py
index d79418ff..58756dba 100644
--- a/tests/components/neurons/spiking/test_quadLIFCell.py
+++ b/tests/components/neurons/spiking/test_quadLIFCell.py
@@ -2,15 +2,10 @@
from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import QuadLIFCell
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import Context, MethodProcess
+from ngclearn.components.neurons.spiking.quadLIFCell import QuadLIFCell
+from numpy.testing import assert_array_equal
def test_quadLIFCell1():
name = "quadlif_ctx"
@@ -18,35 +13,29 @@ def test_quadLIFCell1():
dkey = random.PRNGKey(1234)
dkey, *subkeys = random.split(dkey, 6)
dt = 1. # ms
- trace_increment = 0.1
+ critical_V = 1.
# ---- build a simple Poisson cell system ----
with Context(name) as ctx:
a = QuadLIFCell(
- name="a", n_units=1, tau_m=30., resist_m=1., key=subkeys[0]
+ name="a", n_units=1, tau_m=30., resist_m=1., critical_V=critical_V, key=subkeys[0]
)
- #"""
- advance_process = (Process("advance_proc")
+ # """
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- #ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
+ # ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- """
-
+ # ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ # """
## set up non-compiled utility commands
- @Context.dynamicCommand
- def clamp(x):
- a.j.set(x)
+ # @Context.dynamicCommand
+ # def clamp(x):
+ # a.j.set(x)
+
+ def clamp(x):
+ a.j.set(x)
## input spike train
x_seq = jnp.asarray([[1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 0., 0.]], dtype=jnp.float32)
@@ -54,14 +43,13 @@ def clamp(x):
y_seq = jnp.asarray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0.]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run() # ctx.reset()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.s.value)
+ clamp(x_t) # ctx.clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt) # ctx.run(t=ts * 1., dt=dt)
+ outs.append(a.s.get())
outs = jnp.concatenate(outs, axis=1)
- #print(outs)
## output should equal input
assert_array_equal(outs, y_seq)
diff --git a/tests/components/neurons/spiking/test_sLIFCell.py b/tests/components/neurons/spiking/test_sLIFCell.py
index b1b5f517..697f1790 100644
--- a/tests/components/neurons/spiking/test_sLIFCell.py
+++ b/tests/components/neurons/spiking/test_sLIFCell.py
@@ -1,16 +1,11 @@
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
from ngclearn.components import SLIFCell
-from ngcsimlib.compilers import compile_command, wrap_command
from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import MethodProcess, Context
+
def test_sLIFCell1():
name = "slif_ctx"
@@ -25,26 +20,12 @@ def test_sLIFCell1():
name="a", n_units=1, tau_m=50., resist_m=10., thr=0.3, key=subkeys[0]
)
- #"""
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- #ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
-
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- """
## set up non-compiled utility commands
- @Context.dynamicCommand
def clamp(x):
a.j.set(x)
@@ -54,12 +35,12 @@ def clamp(x):
y_seq = jnp.asarray([[0., 1., 0., 0., 0., 1., 0.]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.s.value)
+ clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt)
+ outs.append(a.s.get())
outs = jnp.concatenate(outs, axis=1)
## output should equal input
diff --git a/tests/components/other/test_expKernel.py b/tests/components/other/test_expKernel.py
index 0ece0bad..9375da66 100644
--- a/tests/components/other/test_expKernel.py
+++ b/tests/components/other/test_expKernel.py
@@ -1,16 +1,8 @@
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
from ngclearn.components import ExpKernel
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import MethodProcess, Context
def test_expKernel1():
name = "expKernel_ctx"
@@ -25,16 +17,12 @@ def test_expKernel1():
name="a", n_units=1, dt=1., tau_w=500., nu=4., key=subkeys[0]
)
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
-
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
## set up non-compiled utility commands
- @Context.dynamicCommand
def clamp(x):
a.inputs.set(x)
@@ -44,16 +32,16 @@ def clamp(x):
y_seq = jnp.asarray([[0., 1., 0.998002, 0.996008, 1.9940181]], dtype=jnp.float32)
outs = []
- ctx.reset()
+ reset_process.run()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.epsp.value)
+ clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt)
+ outs.append(a.epsp.get())
outs = jnp.concatenate(outs, axis=1)
#print(outs)
## output should equal input
np.testing.assert_allclose(outs, y_seq, atol=1e-8)
-#test_expKernel1()
+#test_expKernel1()
diff --git a/tests/components/other/test_varTrace.py b/tests/components/other/test_varTrace.py
index 88444588..8b8ba84d 100644
--- a/tests/components/other/test_varTrace.py
+++ b/tests/components/other/test_varTrace.py
@@ -1,16 +1,11 @@
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
from ngclearn.components import VarTrace
-from ngcsimlib.compilers import compile_command, wrap_command
from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import MethodProcess, Context
+
def test_varTrace1():
name = "trace_ctx"
@@ -26,35 +21,32 @@ def test_varTrace1():
key=subkeys[0]
)
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
## set up non-compiled utility commands
- @Context.dynamicCommand
def clamp(x):
a.inputs.set(x)
## input spike train
x_seq = jnp.asarray([[1., 1., 0., 0., 1.]], dtype=jnp.float32)
## desired output pulses
- y_seq = x_seq * trace_increment
+ y_seq = x_seq * trace_increment
outs = []
- ctx.reset()
+ reset_process.run()
for ts in range(x_seq.shape[1]):
x_t = jnp.array([[x_seq[0, ts]]]) ## get data at time t
- ctx.clamp(x_t)
- ctx.run(t=ts * 1., dt=dt)
- outs.append(a.outputs.value)
+ clamp(x_t)
+ advance_process.run(t=ts * 1., dt=dt)
+ outs.append(a.outputs.get())
outs = jnp.concatenate(outs, axis=1)
#print(outs)
## output should equal input
assert_array_equal(outs, y_seq)
-#test_varTrace1()
+#test_varTrace1()
diff --git a/tests/components/synapses/convolution/test_hebbianConvSynapse.py b/tests/components/synapses/convolution/test_hebbianConvSynapse.py
index db6cd662..e5ee3d74 100644
--- a/tests/components/synapses/convolution/test_hebbianConvSynapse.py
+++ b/tests/components/synapses/convolution/test_hebbianConvSynapse.py
@@ -1,16 +1,11 @@
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import HebbianConvSynapse
-import ngclearn.utils.weight_distribution as dist
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
+from ngclearn import Context, MethodProcess
+from ngclearn.utils.distribution_generator import DistributionGenerator as dist
+from ngclearn.components.synapses.convolution.hebbianConvSynapse import HebbianConvSynapse
+from numpy.testing import assert_array_equal
def test_HebbianConvSynapse1():
name = "hebb_conv_ctx"
@@ -36,41 +31,24 @@ def test_HebbianConvSynapse1():
stride=stride, padding=padding_style, batch_size=batch_size, key=subkeys[0]
)
- #"""
- evolve_process = (Process("evolve_proc")
+ evolve_process = (MethodProcess("evolve_process")
>> a.evolve)
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
- backtransmit_process = (Process("btransmit_proc")
- >> a.backtransmit)
- ctx.wrap_and_add_command(jit(backtransmit_process.pure), name="backtransmit")
+ backtransmit_process = (MethodProcess("backtransmit_process")
+ >> a.backtransmit)
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
- ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
- backpass_cmd, backpass_args = ctx.compile_by_key(a, compile_key="backtransmit")
- ctx.add_command(wrap_command(jit(ctx.backtransmit)), name="backtransmit")
- """
x = jnp.ones(x_shape)
- ctx.reset()
+ reset_process.run() # ctx.reset()
a.inputs.set(x)
- ctx.run(t=1., dt=dt)
- y = a.outputs.value
+ advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt)
+ y = a.outputs.get()
y_truth = jnp.array(
[[[[4.],[2.]],
@@ -79,17 +57,16 @@ def test_HebbianConvSynapse1():
assert_array_equal(y, y_truth)
# print(y)
+ # print("y.Tr:\n", y_truth)
# print("======")
- # print("NGC-Learn.shape = ", node.outputs.value.shape)
+ # print("NGC-Learn.shape = ", node.outputs.get().shape)
a.pre.set(x)
a.post.set(y)
- ctx.adapt(t=1., dt=dt)
- dK = a.dWeights.value
- #print(dK)
- ctx.backtransmit(t=1., dt=dt)
- dx = a.dInputs.value
- #print(dx)
+ evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt)
+ dK = a.dWeights.get()
+ backtransmit_process.run(t=1., dt=dt) # ctx.backtransmit(t=1., dt=dt)
+ dx = a.dInputs.get()
dK_truth = jnp.array(
[[[[9.]],
[[6.]]],
@@ -102,6 +79,10 @@ def test_HebbianConvSynapse1():
[[6.],
[9.]]]]
)
+ # print(dK)
+ # print("dK.Tr:\n", dK_truth)
+ # print(dx)
+ # print("dx.Tr:\n", dx_truth)
assert_array_equal(dK, dK_truth)
assert_array_equal(dx, dx_truth)
diff --git a/tests/components/synapses/convolution/test_hebbianDeconvSynapse.py b/tests/components/synapses/convolution/test_hebbianDeconvSynapse.py
index a91e69d4..57ed9756 100644
--- a/tests/components/synapses/convolution/test_hebbianDeconvSynapse.py
+++ b/tests/components/synapses/convolution/test_hebbianDeconvSynapse.py
@@ -1,16 +1,11 @@
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import HebbianDeconvSynapse
-import ngclearn.utils.weight_distribution as dist
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
+from ngclearn import Context, MethodProcess
+from ngclearn.utils.distribution_generator import DistributionGenerator as dist
+from ngclearn.components.synapses.convolution.hebbianDeconvSynapse import HebbianDeconvSynapse
+from numpy.testing import assert_array_equal
def test_HebbianDeconvSynapse1():
name = "hebb_deconv_ctx"
@@ -36,43 +31,24 @@ def test_HebbianDeconvSynapse1():
stride=stride, padding=padding_style, batch_size=batch_size, key=subkeys[0]
)
- #"""
- evolve_process = (Process("evolve_proc")
- >> a.evolve)
- #ctx.wrap_and_add_command(evolve_process.pure, name="run")
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
+ evolve_process = (MethodProcess("evolve_process")
+ >> a.evolve)
- backtransmit_process = (Process("btransmit_proc")
+ backtransmit_process = (MethodProcess("backtransmit_process")
>> a.backtransmit)
- ctx.wrap_and_add_command(jit(backtransmit_process.pure), name="backtransmit")
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
- ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
- backpass_cmd, backpass_args = ctx.compile_by_key(a, compile_key="backtransmit")
- ctx.add_command(wrap_command(jit(ctx.backtransmit)), name="backtransmit")
- """
x = jnp.ones(x_shape)
- ctx.reset()
+ reset_process.run() # ctx.reset()
a.inputs.set(x)
- ctx.run(t=1., dt=dt)
- y = a.outputs.value
+ advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt)
+ y = a.outputs.get()
y_truth = jnp.array(
[[[[1.],[2.]],
@@ -80,18 +56,17 @@ def test_HebbianDeconvSynapse1():
)
assert_array_equal(y, y_truth)
- #print(y)
- #print("======")
+ # print(y)
+ # print("y.Tr:\n", y_truth)
+ # print("======")
- # print("NGC-Learn.shape = ", node.outputs.value.shape)
+ # print("NGC-Learn.shape = ", node.outputs.get().shape)
a.pre.set(x)
a.post.set(y)
- ctx.adapt(t=1., dt=dt)
- dK = a.dWeights.value
- #print(dK)
- ctx.backtransmit(t=1., dt=dt)
- dx = a.dInputs.value
- #print(dx)
+ evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt)
+ dK = a.dWeights.get()
+ backtransmit_process.run(t=1., dt=dt) # ctx.backtransmit(t=1., dt=dt)
+ dx = a.dInputs.get()
dK_truth = jnp.array(
[[[[4.]],
[[6.]]],
@@ -104,6 +79,10 @@ def test_HebbianDeconvSynapse1():
[[6.],
[4.]]]]
)
+ # print(dK)
+ # print("dK.Tr:\n", dK_truth)
+ # print(dx)
+ # print("dx.Tr:\n", dx_truth)
assert_array_equal(dK, dK_truth)
assert_array_equal(dx, dx_truth)
diff --git a/tests/components/synapses/convolution/test_traceSTDPConvSynapse.py b/tests/components/synapses/convolution/test_traceSTDPConvSynapse.py
index bf113760..2df6a7cc 100644
--- a/tests/components/synapses/convolution/test_traceSTDPConvSynapse.py
+++ b/tests/components/synapses/convolution/test_traceSTDPConvSynapse.py
@@ -2,15 +2,11 @@
from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import TraceSTDPConvSynapse
-import ngclearn.utils.weight_distribution as dist
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
+from ngclearn import Context, MethodProcess
+from ngclearn.utils.distribution_generator import DistributionGenerator as dist
+from ngclearn.components.synapses.convolution.traceSTDPConvSynapse import TraceSTDPConvSynapse
+from numpy.testing import assert_array_equal
def test_TraceSTDPConvSynapse1():
name = "stdp_conv_ctx"
@@ -36,34 +32,17 @@ def test_TraceSTDPConvSynapse1():
stride=stride, padding=padding_style, batch_size=batch_size, key=subkeys[0]
)
- #"""
- evolve_process = (Process("evolve_proc")
+ evolve_process = (MethodProcess("evolve_process")
>> a.evolve)
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
- backtransmit_process = (Process("btransmit_proc")
+ backtransmit_process = (MethodProcess("backtransmit_process")
>> a.backtransmit)
- ctx.wrap_and_add_command(jit(backtransmit_process.pure), name="backtransmit")
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
- ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
- backpass_cmd, backpass_args = ctx.compile_by_key(a, compile_key="backtransmit")
- ctx.add_command(wrap_command(jit(ctx.backtransmit)), name="backtransmit")
- """
## fake out a mix of pre-synaptic spikes/no-spikes
x = np.ones(x_shape)
@@ -75,25 +54,25 @@ def test_TraceSTDPConvSynapse1():
[[1.], [0.]]]]
)
- ctx.reset()
+ reset_process.run() # ctx.reset()
a.inputs.set(x)
- ctx.run(t=1., dt=dt)
- y = (a.outputs.value > 0.) * 1. ## fake out post-syn spikes
+ advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt)
+ y = (a.outputs.get() > 0.) * 1. ## fake out post-syn spikes
assert_array_equal(y, y_truth)
- #print(y)
- #print("======")
+ # print(y)
+ # print("y.Tr:\n", y_truth)
+ # print("======")
- # print("NGC-Learn.shape = ", node.outputs.value.shape)
+ # print("NGC-Learn.shape = ", node.outputs.get().shape)
a.preSpike.set(x)
a.postSpike.set(y)
a.preTrace.set(x * 0.4) ## fake out pre-syn trace values
a.postTrace.set(y * 1.3) ## fake out post-syn trace values
- ctx.adapt(t=1., dt=dt)
- dK = a.dWeights.value
- #print(dK)
- ctx.backtransmit(t=1., dt=dt)
- dx = a.dInputs.value
- #print(dx)
+ evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt)
+ dK = a.dWeights.get()
+
+ backtransmit_process.run(t=1., dt=dt) # ctx.backtransmit(t=1., dt=dt)
+ dx = a.dInputs.get()
dK_truth = jnp.array(
[[[[-1.8]],
[[-0.9]]],
@@ -106,6 +85,10 @@ def test_TraceSTDPConvSynapse1():
[[2.],
[3.]]]]
)
+ # print(dK)
+ # print("dK.Tr:\n", dK_truth)
+ # print(dx)
+ # print("dx.Tr:\n", dx_truth)
assert_array_equal(dK, dK_truth)
assert_array_equal(dx, dx_truth)
diff --git a/tests/components/synapses/convolution/test_traceSTDPDeconvSynapse.py b/tests/components/synapses/convolution/test_traceSTDPDeconvSynapse.py
index 76be1c2a..03753a22 100644
--- a/tests/components/synapses/convolution/test_traceSTDPDeconvSynapse.py
+++ b/tests/components/synapses/convolution/test_traceSTDPDeconvSynapse.py
@@ -2,15 +2,11 @@
from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import TraceSTDPDeconvSynapse
-import ngclearn.utils.weight_distribution as dist
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
+from ngclearn import Context, MethodProcess
+from ngclearn.utils.distribution_generator import DistributionGenerator as dist
+from ngclearn.components.synapses.convolution.traceSTDPDeconvSynapse import TraceSTDPDeconvSynapse
+from numpy.testing import assert_array_equal
def test_TraceSTDPDeconvSynapse1():
name = "stdp_deconv_ctx"
@@ -37,36 +33,17 @@ def test_TraceSTDPDeconvSynapse1():
stride=stride, padding=padding_style, batch_size=batch_size, key=subkeys[0]
)
- #"""
- evolve_process = (Process("evolve_proc")
- >> a.evolve)
- #ctx.wrap_and_add_command(evolve_process.pure, name="run")
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
+ evolve_process = (MethodProcess("evolve_process")
+ >> a.evolve)
- backtransmit_process = (Process("btransmit_proc")
+ backtransmit_process = (MethodProcess("backtransmit_process")
>> a.backtransmit)
- ctx.wrap_and_add_command(jit(backtransmit_process.pure), name="backtransmit")
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
- ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
- backpass_cmd, backpass_args = ctx.compile_by_key(a, compile_key="backtransmit")
- ctx.add_command(wrap_command(jit(ctx.backtransmit)), name="backtransmit")
- """
## fake out a mix of pre-synaptic spikes/no-spikes
x = np.ones(x_shape)
@@ -78,25 +55,24 @@ def test_TraceSTDPDeconvSynapse1():
[[1.], [1.]]]]
)
- ctx.reset()
+ reset_process.run() # ctx.reset()
a.inputs.set(x)
- ctx.run(t=1., dt=dt)
- y = (a.outputs.value > 0.) * 1. ## fake out post-syn spikes
+ advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt)
+ y = (a.outputs.get() > 0.) * 1. ## fake out post-syn spikes
assert_array_equal(y, y_truth)
- #print(y)
- #print("======")
+ # print(y)
+ # print("y.Tr:\n", y_truth)
+ # print("======")
- # print("NGC-Learn.shape = ", node.outputs.value.shape)
+ # print("NGC-Learn.shape = ", node.outputs.get().shape)
a.preSpike.set(x)
a.postSpike.set(y)
a.preTrace.set(x * 0.4) ## fake out pre-syn trace values
a.postTrace.set(y * 1.3) ## fake out post-syn trace values
- ctx.adapt(t=1., dt=dt)
- dK = a.dWeights.value
- #print(dK)
- ctx.backtransmit(t=1., dt=dt)
- dx = a.dInputs.value
- #print(dx)
+ evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt)
+ dK = a.dWeights.get()
+ backtransmit_process.run(t=1., dt=dt) # ctx.backtransmit(t=1., dt=dt)
+ dx = a.dInputs.get()
dK_truth = jnp.array(
[[[[0.]],
[[-0.9]]],
@@ -109,6 +85,10 @@ def test_TraceSTDPDeconvSynapse1():
[[2.],
[1.]]]]
)
+ # print(dK)
+ # print("dK.Tr:\n", dK_truth)
+ # print(dx)
+ # print("dx.Tr:\n", dx_truth)
assert_array_equal(dK, dK_truth)
assert_array_equal(dx, dx_truth)
diff --git a/tests/components/synapses/hebbian/test_BCMSynapse.py b/tests/components/synapses/hebbian/test_BCMSynapse.py
index 7597f549..7dbff4de 100644
--- a/tests/components/synapses/hebbian/test_BCMSynapse.py
+++ b/tests/components/synapses/hebbian/test_BCMSynapse.py
@@ -2,14 +2,11 @@
from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import BCMSynapse
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
+from ngclearn import Context, MethodProcess
+#from ngclearn.utils.distribution_generator import DistributionGenerator as dist
+from ngclearn.components.synapses.hebbian.BCMSynapse import BCMSynapse
+from numpy.testing import assert_array_equal
def test_BCMSynapse1():
name = "bcm_stdp_ctx"
@@ -23,42 +20,26 @@ def test_BCMSynapse1():
name="a", shape=(1,1), tau_w=40., tau_theta=20., key=subkeys[0]
)
- #"""
- evolve_process = (Process("evolve_proc")
+ evolve_process = (MethodProcess("evolve_process")
>> a.evolve)
- #ctx.wrap_and_add_command(evolve_process.pure, name="run")
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
- ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
- """
pre_value = jnp.ones((1, 1)) * 0.425
post_value = jnp.ones((1, 1)) * 1.55
truth = jnp.array([[-1.6798127]])
- ctx.reset()
+ reset_process.run() # ctx.reset()
a.pre.set(pre_value)
a.post.set(post_value)
- ctx.run(t=1., dt=dt)
- ctx.adapt(t=1., dt=dt)
- #print(a.dWeights.value)
- assert_array_equal(a.dWeights.value, truth)
-
+ advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt)
+ evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt)
+ # print(a.dWeights.get())
+ # print(truth)
+ assert_array_equal(a.dWeights.get(), truth)
#test_BCMSynapse1()
diff --git a/tests/components/synapses/hebbian/test_eventSTDPSynapse.py b/tests/components/synapses/hebbian/test_eventSTDPSynapse.py
index b51c16de..a3c9a371 100644
--- a/tests/components/synapses/hebbian/test_eventSTDPSynapse.py
+++ b/tests/components/synapses/hebbian/test_eventSTDPSynapse.py
@@ -2,14 +2,11 @@
from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import EventSTDPSynapse
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
+from ngclearn import Context, MethodProcess
+#from ngclearn.utils.distribution_generator import DistributionGenerator as dist
+from ngclearn.components.synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse
+from numpy.testing import assert_array_equal
def test_eventSTDPSynapse1():
name = "event_stdp_ctx"
@@ -24,46 +21,32 @@ def test_eventSTDPSynapse1():
name="a", shape=(1,1), eta=0., presyn_win_len=2., key=subkeys[0]
)
- #"""
- evolve_process = (Process("evolve_proc")
- >> a.evolve)
- #ctx.wrap_and_add_command(evolve_process.pure, name="run")
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
+ evolve_process = (MethodProcess("evolve_process")
+ >> a.evolve)
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
- ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
- """
a.weights.set(jnp.ones((1, 1)) * 0.1)
t = 12. ## fake out current time
- ## Case 1: outside of pre-syn time window
+ ## Case 1: outside pre-syn time window
input_tols = jnp.ones((1, 1,)) * 9.
out_spike = jnp.ones((1, 1))
## check pre-synaptic STDP only
truth = jnp.array([[-0.101]])
- ctx.reset()
+ reset_process.run() # ctx.reset()
a.pre_tols.set(input_tols)
a.postSpike.set(out_spike)
- ctx.run(t=t, dt=dt)
- ctx.adapt(t=t, dt=dt)
- #print(a.dWeights.value)
- assert_array_equal(a.dWeights.value, truth)
+ advance_process.run(t=t, dt=dt) # ctx.run(t=t, dt=dt)
+ evolve_process.run(t=t, dt=dt) # ctx.adapt(t=t, dt=dt)
+ # print(a.dWeights.get())
+ # print(truth)
+ assert_array_equal(a.dWeights.get(), truth)
## Case 2: within pre-syn time window
input_tols = jnp.ones((1, 1,)) * 11.
@@ -71,13 +54,14 @@ def test_eventSTDPSynapse1():
## check pre-synaptic STDP only
truth = jnp.array([[0.899]])
- ctx.reset()
+ reset_process.run() # ctx.reset()
a.pre_tols.set(input_tols)
a.postSpike.set(out_spike)
- ctx.run(t=t, dt=dt)
- ctx.adapt(t=t, dt=dt)
- #print(a.dWeights.value)
- assert_array_equal(a.dWeights.value, truth)
+ advance_process.run(t=t, dt=dt) # ctx.run(t=t, dt=dt)
+ evolve_process.run(t=t, dt=dt) # ctx.adapt(t=t, dt=dt)
+ # print(a.dWeights.get())
+ # print(truth)
+ assert_array_equal(a.dWeights.get(), truth)
#test_eventSTDPSynapse1()
diff --git a/tests/components/synapses/hebbian/test_expSTDPSynapse.py b/tests/components/synapses/hebbian/test_expSTDPSynapse.py
index 9765315d..ca18d89f 100644
--- a/tests/components/synapses/hebbian/test_expSTDPSynapse.py
+++ b/tests/components/synapses/hebbian/test_expSTDPSynapse.py
@@ -1,15 +1,13 @@
+
from jax import numpy as jnp, random, jit
from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import ExpSTDPSynapse
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
+from ngclearn import Context, MethodProcess
+#from ngclearn.utils.distribution_generator import DistributionGenerator as dist
+from ngclearn.components.synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse
+from numpy.testing import assert_array_equal
def test_expSTDPSynapse1():
name = "exp_stdp_ctx"
@@ -20,33 +18,18 @@ def test_expSTDPSynapse1():
# ---- build a simple Poisson cell system ----
with Context(name) as ctx:
a = ExpSTDPSynapse(
- name="a", shape=(1,1), A_plus=1., A_minus=1., exp_beta=1.25, key=subkeys[0]
+ name="a", shape=(1,1), A_plus=1., A_minus=1., exp_beta=1.25, eta=0., key=subkeys[0]
)
- #"""
- evolve_process = (Process("evolve_proc")
- >> a.evolve)
- #ctx.wrap_and_add_command(evolve_process.pure, name="run")
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
+ evolve_process = (MethodProcess("evolve_process")
+ >> a.evolve)
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
- ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
- """
a.weights.set(jnp.ones((1, 1)) * 0.1)
in_spike = jnp.ones((1, 1))
@@ -56,26 +39,30 @@ def test_expSTDPSynapse1():
## check pre-synaptic STDP only
truth = jnp.array([[1.1031212]])
- ctx.reset()
+ reset_process.run() # ctx.reset()
a.preSpike.set(in_spike * 0)
a.preTrace.set(in_trace)
a.postSpike.set(out_spike)
a.postTrace.set(out_trace)
- ctx.run(t=1., dt=dt)
- ctx.adapt(t=1., dt=dt)
- #print(a.dWeights.value)
- assert_array_equal(a.dWeights.value, truth)
+ advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt)
+ evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt)
+ # print("W: ",a.weights.get())
+ # print(a.dWeights.get())
+ # print(truth)
+ assert_array_equal(a.dWeights.get(), truth)
truth = jnp.array([[-0.57362294]])
- ctx.reset()
+ reset_process.run() # ctx.reset()
a.preSpike.set(in_spike)
a.preTrace.set(in_trace)
a.postSpike.set(out_spike * 0)
a.postTrace.set(out_trace)
- ctx.run(t=1., dt=dt)
- ctx.adapt(t=1., dt=dt)
- #print(a.dWeights.value)
- assert_array_equal(a.dWeights.value, truth)
+ advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt)
+ evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt)
+ # print("W: ", a.weights.get())
+ # print(a.dWeights.get())
+ # print(truth)
+ assert_array_equal(a.dWeights.get(), truth)
#test_expSTDPSynapse1()
diff --git a/tests/components/synapses/hebbian/test_hebbianSynapse.py b/tests/components/synapses/hebbian/test_hebbianSynapse.py
index 35a2b191..ba5dc463 100644
--- a/tests/components/synapses/hebbian/test_hebbianSynapse.py
+++ b/tests/components/synapses/hebbian/test_hebbianSynapse.py
@@ -1,18 +1,13 @@
# %%
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
+
import numpy as np
np.random.seed(42)
-from ngclearn.components import HebbianSynapse
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
+from ngclearn.components.synapses.hebbian.hebbianSynapse import HebbianSynapse
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from numpy.testing import assert_array_equal
+from ngclearn import Context, MethodProcess
def test_hebbianSynapse():
@@ -29,37 +24,23 @@ def test_hebbianSynapse():
with Context(name) as ctx:
a = HebbianSynapse(
- name="a",
- shape=shape,
+ name="a",
+ shape=shape,
resist_scale=resist_scale,
batch_size=batch_size,
prior = ("gaussian", 0.01)
)
- advance_process = (Process("advance_proc") >> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc") >> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- evolve_process = (Process("evolve_proc") >> a.evolve)
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
-
- # Compile and add commands
- # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- # ctx.add_command(wrap_command(jit(reset_cmd)), name="reset")
- # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- # ctx.add_command(wrap_command(jit(advance_cmd)), name="run")
- # evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
- # ctx.add_command(wrap_command(jit(evolve_cmd)), name="evolve")
-
- @Context.dynamicCommand
+ advance_process = (MethodProcess("advance_proc") >> a.advance_state)
+ reset_process = (MethodProcess("reset_proc") >> a.reset)
+ evolve_process = (MethodProcess("evolve_proc") >> a.evolve)
+
def clamp_inputs(x):
a.inputs.set(x)
- @Context.dynamicCommand
def clamp_pre(x):
a.pre.set(x)
- @Context.dynamicCommand
def clamp_post(x):
a.post.set(x)
@@ -70,16 +51,17 @@ def clamp_post(x):
in_pre = jnp.ones((1, 10)) * 1.0
in_post = jnp.ones((1, 5)) * 0.75
- ctx.reset()
+ reset_process.run()
clamp_pre(in_pre)
clamp_post(in_post)
- ctx.run(t=1. * dt, dt=dt)
- ctx.evolve(t=1. * dt, dt=dt)
+ advance_process.run(t=1. * dt, dt=dt)
+ evolve_process.run(t=1. * dt, dt=dt)
- print(a.weights.value)
+ #print(a.weights.get())
# Basic assertions to check learning dynamics
- assert a.weights.value.shape == (10, 5), ""
- assert a.weights.value[0, 0] == 0.5, ""
+ assert a.weights.get().shape == (10, 5), ""
+ assert a.weights.get()[0, 0] == 0.5, ""
+
+#test_hebbianSynapse()
-# test_hebbianSynapse()
\ No newline at end of file
diff --git a/tests/components/synapses/hebbian/test_traceSTDPSynapse.py b/tests/components/synapses/hebbian/test_traceSTDPSynapse.py
index 4e1e42de..a7d94d45 100644
--- a/tests/components/synapses/hebbian/test_traceSTDPSynapse.py
+++ b/tests/components/synapses/hebbian/test_traceSTDPSynapse.py
@@ -2,14 +2,11 @@
from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import TraceSTDPSynapse
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
+from ngclearn import Context, MethodProcess
+#from ngclearn.utils.distribution_generator import DistributionGenerator as dist
+from ngclearn.components.synapses.hebbian.traceSTDPSynapse import TraceSTDPSynapse
+from numpy.testing import assert_array_equal
def test_traceSTDPSynapse1():
name = "trace_stdp_ctx"
@@ -20,33 +17,18 @@ def test_traceSTDPSynapse1():
# ---- build a simple Poisson cell system ----
with Context(name) as ctx:
a = TraceSTDPSynapse(
- name="a", shape=(1,1), A_plus=1., A_minus=1., key=subkeys[0]
+ name="a", shape=(1,1), A_plus=1., A_minus=1., eta=0., key=subkeys[0]
)
- #"""
- evolve_process = (Process("evolve_proc")
- >> a.evolve)
- #ctx.wrap_and_add_command(evolve_process.pure, name="run")
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
+ evolve_process = (MethodProcess("evolve_process")
+ >> a.evolve)
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
- ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
- """
a.weights.set(jnp.ones((1, 1)) * 0.1)
in_spike = jnp.ones((1, 1))
@@ -56,25 +38,29 @@ def test_traceSTDPSynapse1():
## check pre-synaptic STDP only
truth = jnp.array([[1.25]])
- ctx.reset()
+ reset_process.run() # ctx.reset()
a.preSpike.set(in_spike * 0)
a.preTrace.set(in_trace)
a.postSpike.set(out_spike)
a.postTrace.set(out_trace)
- ctx.run(t=1., dt=dt)
- ctx.adapt(t=1., dt=dt)
- #print(a.dWeights.value)
- assert_array_equal(a.dWeights.value, truth)
+ advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt)
+ evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt)
+ # print("W: ", a.weights.get())
+ # print(a.dWeights.get())
+ # print(truth)
+ assert_array_equal(a.dWeights.get(), truth)
truth = jnp.array([[-0.65]])
- ctx.reset()
+ reset_process.run() # ctx.reset()
a.preSpike.set(in_spike)
a.preTrace.set(in_trace)
a.postSpike.set(out_spike * 0)
a.postTrace.set(out_trace)
- ctx.run(t=1., dt=dt)
- ctx.adapt(t=1., dt=dt)
- #print(a.dWeights.value)
- assert_array_equal(a.dWeights.value, truth)
+ advance_process.run(t=1., dt=dt) # ctx.run(t=1., dt=dt)
+ evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1., dt=dt)
+ # print("W: ", a.weights.get())
+ # print(a.dWeights.get())
+ # print(truth)
+ assert_array_equal(a.dWeights.get(), truth)
#test_traceSTDPSynapse1()
diff --git a/tests/components/synapses/modulated/test_MSTDPETSynapse.py b/tests/components/synapses/modulated/test_MSTDPETSynapse.py
index e1c7ce36..8726c5be 100644
--- a/tests/components/synapses/modulated/test_MSTDPETSynapse.py
+++ b/tests/components/synapses/modulated/test_MSTDPETSynapse.py
@@ -2,15 +2,11 @@
from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import MSTDPETSynapse
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-import ngclearn.utils.weight_distribution as dist
+from ngclearn import Context, MethodProcess
+#import ngclearn.utils.weight_distribution as dist
+from ngclearn.components.synapses.modulated.MSTDPETSynapse import MSTDPETSynapse
+from numpy.testing import assert_array_equal
def test_MSTDPETSynapse1():
name = "mstdpet_ctx"
@@ -24,30 +20,14 @@ def test_MSTDPETSynapse1():
name="a", shape=(1,1), A_plus=1., A_minus=1., eta=0.1, key=subkeys[0]
)
- #"""
- advance_process = (Process("advance_proc")
+ evolve_process = (MethodProcess("evolve_process")
+ >> a.evolve)
+
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- evolve_process = (Process("evolve_proc")
- >> a.evolve)
- #ctx.wrap_and_add_command(evolve_process.pure, name="run")
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
-
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
-
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
- ctx.add_command(wrap_command(jit(ctx.evolve)), name="adapt")
- """
a.weights.set(jnp.ones((1, 1)) * 0.75)
@@ -59,28 +39,28 @@ def test_MSTDPETSynapse1():
r_pos = jnp.ones((1, 1))
#print(a.weights.value)
- ctx.reset()
+ reset_process.run() # ctx.reset()
a.preSpike.set(in_spike * 0)
a.preTrace.set(in_trace)
a.postSpike.set(out_spike)
a.postTrace.set(out_trace)
a.modulator.set(r_pos)
- ctx.run(t=1. * dt, dt=dt)
- ctx.adapt(t=1. * dt, dt=dt)
- ctx.adapt(t=1. * dt, dt=dt)
- #print(a.weights.value)
- assert_array_equal(a.weights.value, jnp.array([[0.875]]))
+ advance_process.run(t=1., dt=dt) # ctx.run(t=1. * dt, dt=dt)
+ evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1. * dt, dt=dt)
+ evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1. * dt, dt=dt)
+ #print(a.weights.get())
+ assert_array_equal(a.weights.get(), jnp.array([[0.875]]))
- ctx.reset()
+ reset_process.run() # ctx.reset()
a.preSpike.set(in_spike * 0)
a.preTrace.set(in_trace)
a.postSpike.set(out_spike)
a.postTrace.set(out_trace)
a.modulator.set(r_neg)
- ctx.run(t=1. * dt, dt=dt)
- ctx.adapt(t=1. * dt, dt=dt)
- ctx.adapt(t=1. * dt, dt=dt)
- #print(a.weights.value)
- assert_array_equal(a.weights.value, jnp.array([[0.75]]))
+ advance_process.run(t=1., dt=dt) # ctx.run(t=1. * dt, dt=dt)
+ evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1. * dt, dt=dt)
+ evolve_process.run(t=1., dt=dt) # ctx.adapt(t=1. * dt, dt=dt)
+ #print(a.weights.get())
+ assert_array_equal(a.weights.get(), jnp.array([[0.75]]))
#test_MSTDPETSynapse1()
diff --git a/tests/components/synapses/modulated/test_REINFORCESynapse.py b/tests/components/synapses/modulated/test_REINFORCESynapse.py
index b81c909d..f0235789 100644
--- a/tests/components/synapses/modulated/test_REINFORCESynapse.py
+++ b/tests/components/synapses/modulated/test_REINFORCESynapse.py
@@ -2,17 +2,12 @@
import jax
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components.synapses.modulated.REINFORCESynapse import REINFORCESynapse, gaussian_logpdf
-from ngcsimlib.compilers import compile_command, wrap_command
+from ngclearn.components.synapses.modulated.REINFORCESynapse import REINFORCESynapse, _gaussian_logpdf
from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
+from ngclearn import Context, MethodProcess
import jax
import jax.numpy as jnp
@@ -39,22 +34,16 @@ def test_REINFORCESynapse1():
scalar_stddev=-1.0
)
- evolve_process = (Process("evolve_proc") >> a.evolve)
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
+ evolve_process = (MethodProcess("evolve_proc") >> a.evolve)
+ reset_process = (MethodProcess("reset_proc") >> a.reset)
- reset_process = (Process("reset_proc") >> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
-
- @Context.dynamicCommand
def clamp_inputs(x):
a.inputs.set(x)
- @Context.dynamicCommand
def clamp_rewards(x):
assert x.ndim == 1, "Rewards must be a 1D array"
a.rewards.set(x)
- @Context.dynamicCommand
def clamp_weights(x):
a.weights.set(x)
@@ -69,7 +58,7 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
std = jnp.exp(logstd.clip(-10.0, 2.0))
sample = jax.random.normal(seed, mean.shape) * std + mean
sample = jnp.clip(sample, mu_out_min, mu_out_max)
- logp = gaussian_logpdf(jax.lax.stop_gradient(sample), mean, std).sum(-1)
+ logp = _gaussian_logpdf(jax.lax.stop_gradient(sample), mean, std).sum(-1)
return (-logp * outputs).mean() * 1e-2
grad_fn = jax.value_and_grad(fn)
@@ -80,7 +69,7 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
expected_weights = jnp.concatenate([expected_weights_mu, expected_weights_logstd], axis=-1)
initial_ngclearn_weights = jnp.concatenate([expected_weights_mu, expected_weights_logstd], axis=-1)[None]
expected_gradient_list = []
- ctx.reset()
+ reset_process.run()
# Loop through 3 steps
for step in range(10):
@@ -94,12 +83,12 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
clamp_weights(initial_ngclearn_weights)
clamp_rewards(outputs)
clamp_inputs(inputs)
- ctx.adapt(t=1., dt=dt)
- print(f"[ngclearn] objective: {a.objective.value}")
- print(f"[ngclearn] weights: {a.weights.value}")
- print(f"[ngclearn] dWeights: {a.dWeights.value}")
- print(f"[ngclearn] step_count: {a.step_count.value}")
- print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.value}")
+ evolve_process.run(t=1., dt=dt)
+ print(f"[ngclearn] objective: {a.objective.get()}")
+ print(f"[ngclearn] weights: {a.weights.get()}")
+ print(f"[ngclearn] dWeights: {a.dWeights.get()}")
+ print(f"[ngclearn] step_count: {a.step_count.get()}")
+ print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.get()}")
# -------- Expectation ---------
print("--------------")
expected_objective, expected_grads = grad_fn(
@@ -116,12 +105,12 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
print(f"[Expectation] dWeights: {expected_grads}")
print(f"[Expectation] objective: {expected_objective}")
np.testing.assert_allclose(
- a.dWeights.value[0],
+ a.dWeights.get()[0],
expected_grads,
atol=1e-8
)
np.testing.assert_allclose(
- a.objective.value,
+ a.objective.get(),
expected_objective,
atol=1e-8
)
@@ -131,7 +120,7 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
decay_list = jnp.asarray([decay**i for i in range(len(expected_gradient_list))])[::-1]
expected_accumulated_gradients = jnp.mean(jnp.stack(expected_gradient_list, 0) * decay_list[:, None, None], axis=0)
np.testing.assert_allclose(
- a.accumulated_gradients.value[0],
+ a.accumulated_gradients.get()[0],
expected_accumulated_gradients,
atol=1e-9
)
@@ -163,22 +152,16 @@ def test_REINFORCESynapse2():
scalar_stddev=scalar_stddev
)
- evolve_process = (Process("evolve_proc") >> a.evolve)
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="adapt")
-
- reset_process = (Process("reset_proc") >> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
+ evolve_process = (MethodProcess("evolve_proc") >> a.evolve)
+ reset_process = (MethodProcess("reset_proc") >> a.reset)
- @Context.dynamicCommand
def clamp_inputs(x):
a.inputs.set(x)
- @Context.dynamicCommand
def clamp_rewards(x):
assert x.ndim == 1, "Rewards must be a 1D array"
a.rewards.set(x)
- @Context.dynamicCommand
def clamp_weights(x):
a.weights.set(x)
@@ -194,7 +177,7 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
std = scalar_stddev
sample = jax.random.normal(seed, mean.shape) * std + mean
sample = jnp.clip(sample, mu_out_min, mu_out_max)
- logp = gaussian_logpdf(jax.lax.stop_gradient(sample), mean, std).sum(-1)
+ logp = _gaussian_logpdf(jax.lax.stop_gradient(sample), mean, std).sum(-1)
return (-logp * outputs).mean() * 1e-2
grad_fn = jax.value_and_grad(fn)
@@ -205,7 +188,7 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
expected_weights = jnp.concatenate([expected_weights_mu, expected_weights_logstd], axis=-1)
initial_ngclearn_weights = jnp.concatenate([expected_weights_mu, expected_weights_logstd], axis=-1)[None]
expected_gradient_list = []
- ctx.reset()
+ reset_process.run()
# Loop through 3 steps
for step in range(10):
@@ -219,12 +202,12 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
clamp_weights(initial_ngclearn_weights)
clamp_rewards(outputs)
clamp_inputs(inputs)
- ctx.adapt(t=1., dt=dt)
- print(f"[ngclearn] objective: {a.objective.value}")
- print(f"[ngclearn] weights: {a.weights.value}")
- print(f"[ngclearn] dWeights: {a.dWeights.value}")
- print(f"[ngclearn] step_count: {a.step_count.value}")
- print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.value}")
+ evolve_process.run(t=1., dt=dt)
+ print(f"[ngclearn] objective: {a.objective.get()}")
+ print(f"[ngclearn] weights: {a.weights.get()}")
+ print(f"[ngclearn] dWeights: {a.dWeights.get()}")
+ print(f"[ngclearn] step_count: {a.step_count.get()}")
+ print(f"[ngclearn] accumulated_gradients: {a.accumulated_gradients.get()}")
# -------- Expectation ---------
print("--------------")
expected_objective, expected_grads = grad_fn(
@@ -241,12 +224,12 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
print(f"[Expectation] dWeights: {expected_grads}")
print(f"[Expectation] objective: {expected_objective}")
np.testing.assert_allclose(
- a.dWeights.value[0],
+ a.dWeights.get()[0],
expected_grads,
atol=1e-8
)
np.testing.assert_allclose(
- a.objective.value,
+ a.objective.get(),
expected_objective,
atol=1e-8
)
@@ -256,7 +239,7 @@ def fn(params: dict, inputs: jax.Array, outputs: jax.Array, seed: jax.Array):
decay_list = jnp.asarray([decay**i for i in range(len(expected_gradient_list))])[::-1]
expected_accumulated_gradients = jnp.mean(jnp.stack(expected_gradient_list, 0) * decay_list[:, None, None], axis=0)
np.testing.assert_allclose(
- a.accumulated_gradients.value[0],
+ a.accumulated_gradients.get()[0],
expected_accumulated_gradients,
atol=1e-9
)
diff --git a/tests/components/synapses/patched/test_hebbianPatchedSynapse.py b/tests/components/synapses/patched/test_hebbianPatchedSynapse.py
index d0997c82..4a33ae6f 100644
--- a/tests/components/synapses/patched/test_hebbianPatchedSynapse.py
+++ b/tests/components/synapses/patched/test_hebbianPatchedSynapse.py
@@ -1,19 +1,13 @@
# %%
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
+from ngclearn.utils.distribution_generator import DistributionGenerator as dist
from ngclearn.components import HebbianPatchedSynapse
-from ngcsimlib.compilers import compile_command, wrap_command
from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
-
+from ngclearn import MethodProcess, Context
def test_hebbianPatchedSynapse():
np.random.seed(42)
@@ -31,58 +25,45 @@ def test_hebbianPatchedSynapse():
with Context(name) as ctx:
a = HebbianPatchedSynapse(
- name="a",
- shape=shape,
- n_sub_models=n_sub_models,
+ name="a",
+ shape=shape,
+ n_sub_models=n_sub_models,
stride_shape=stride_shape,
resist_scale=resist_scale,
batch_size=batch_size
)
- advance_process = (Process("advance_proc") >> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc") >> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- evolve_process = (Process("evolve_proc") >> a.evolve)
- ctx.wrap_and_add_command(jit(evolve_process.pure), name="evolve")
-
- # Compile and add commands
- # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- # ctx.add_command(wrap_command(jit(reset_cmd)), name="reset")
- # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- # ctx.add_command(wrap_command(jit(advance_cmd)), name="run")
- # evolve_cmd, evolve_args = ctx.compile_by_key(a, compile_key="evolve")
- # ctx.add_command(wrap_command(jit(evolve_cmd)), name="evolve")
-
- @Context.dynamicCommand
+ advance_process = (MethodProcess("advance_proc") >> a.advance_state)
+ reset_process = (MethodProcess("reset_proc") >> a.reset)
+ evolve_process = (MethodProcess("evolve_proc") >> a.evolve)
+
def clamp_inputs(x):
a.inputs.set(x)
- @Context.dynamicCommand
def clamp_pre(x):
a.pre.set(x)
- @Context.dynamicCommand
def clamp_post(x):
a.post.set(x)
- a.weights.set(jnp.ones((12, 12)) * 0.5)
+ a.weights.set(jnp.ones((12, 12)) * 0.5)
in_pre = jnp.ones((10, 12)) * 1.0
in_post = jnp.ones((10, 12)) * 0.75
- ctx.reset()
+ reset_process.run()
clamp_pre(in_pre)
clamp_post(in_post)
- ctx.run(t=1. * dt, dt=dt)
- ctx.evolve(t=1. * dt, dt=dt)
+ advance_process.run(t=1. * dt, dt=dt)
+ evolve_process.run(t=1. * dt, dt=dt)
- print(a.weights.value)
+ #print(a.weights.get())
# Basic assertions to check learning dynamics
- assert a.weights.value.shape == (12, 12), ""
- assert a.weights.value[0, 0] == 0.5, ""
+ assert a.weights.get().shape == (12, 12), ""
+ assert a.weights.get()[0, 0] == 0.5, ""
+
+#test_hebbianPatchedSynapse()
-# test_hebbianPatchedSynapse()
\ No newline at end of file
diff --git a/tests/components/synapses/patched/test_patchedSynapse.py b/tests/components/synapses/patched/test_patchedSynapse.py
index 8dd99d06..e9b96d80 100644
--- a/tests/components/synapses/patched/test_patchedSynapse.py
+++ b/tests/components/synapses/patched/test_patchedSynapse.py
@@ -1,18 +1,13 @@
# %%
from jax import numpy as jnp, random, jit
-from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
+from ngclearn.utils.distribution_generator import DistributionGenerator as dist
from ngclearn.components import PatchedSynapse
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-from ngcsimlib.utils.compartment import Get_Compartment_Batch
+from ngclearn import MethodProcess, Context
+
def test_patchedSynapse():
@@ -35,35 +30,29 @@ def test_patchedSynapse():
stride_shape=stride_shape,
resist_scale=resist_scale,
batch_size=batch_size,
- weight_init={"dist": "gaussian", "std": 0.1},
- bias_init={"dist": "constant", "value": 0.0}
+ weight_init=dist.gaussian(std=0.1), #{"dist": "gaussian", "std": 0.1},
+ bias_init=dist.constant(value=0.) #{"dist": "constant", "value": 0.0}
)
- advance_process = (Process("advance_proc") >> a.advance_state)
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc") >> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
-
- # Compile and add commands
- # reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- # ctx.add_command(wrap_command(jit(reset_cmd)), name="reset")
- # advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- # ctx.add_command(wrap_command(jit(advance_cmd)), name="run")
+ advance_process = (MethodProcess("advance_proc") >> a.advance_state)
+ reset_process = (MethodProcess("reset_proc") >> a.reset)
- @Context.dynamicCommand
def clamp_inputs(x):
a.inputs.set(x)
inputs_seq = jnp.asarray(np.random.randn(1, 12))
- weights = a.weights.value
- biases = a.biases.value
+ weights = a.weights.get()
+ biases = a.biases.get()
expected_outputs = (jnp.matmul(inputs_seq, weights) * resist_scale) + biases
outputs_outs = []
- ctx.reset()
- ctx.clamp_inputs(inputs_seq)
- ctx.run(t=0., dt=dt)
- outputs_outs.append(a.outputs.value)
+ reset_process.run()
+ clamp_inputs(inputs_seq)
+ advance_process.run(t=0., dt=dt)
+ outputs_outs.append(a.outputs.get())
outputs_outs = jnp.concatenate(outputs_outs, axis=1)
# Verify outputs match expected values
np.testing.assert_allclose(outputs_outs, expected_outputs, atol=1e-5)
+
+#test_patchedSynapse()
+
diff --git a/tests/components/synapses/test_STPDenseSynapse.py b/tests/components/synapses/test_STPDenseSynapse.py
index 78ac2e12..32607959 100644
--- a/tests/components/synapses/test_STPDenseSynapse.py
+++ b/tests/components/synapses/test_STPDenseSynapse.py
@@ -2,15 +2,10 @@
from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import STPDenseSynapse
-from ngcsimlib.compilers import compile_command, wrap_command
-from numpy.testing import assert_array_equal
-from ngcsimlib.compilers.process import Process, transition
-from ngcsimlib.component import Component
-from ngcsimlib.compartment import Compartment
-from ngcsimlib.context import Context
-import ngclearn.utils.weight_distribution as dist
+from ngclearn import Context, MethodProcess
+from ngclearn.utils.distribution_generator import DistributionGenerator
+from ngclearn.components.synapses.STPDenseSynapse import STPDenseSynapse
def test_STPDenseSynapse1():
name = "stp_ctx"
@@ -21,26 +16,15 @@ def test_STPDenseSynapse1():
# ---- build a simple Poisson cell system ----
with Context(name) as ctx:
a = STPDenseSynapse(
- name="a", shape=(1,1), resources_init=dist.constant(value=1.),key=subkeys[0]
+ name="a", shape=(1,1), resources_init=DistributionGenerator.constant(value=1.),key=subkeys[0]
)
- #"""
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
- #"""
- """
- reset_cmd, reset_args = ctx.compile_by_key(a, compile_key="reset")
- ctx.add_command(wrap_command(jit(ctx.reset)), name="reset")
- advance_cmd, advance_args = ctx.compile_by_key(a, compile_key="advance_state")
- ctx.add_command(wrap_command(jit(ctx.advance_state)), name="run")
- """
a.weights.set(jnp.ones((1, 1)))
in_pulse = jnp.ones((1, 1)) * 0.425
@@ -49,16 +33,19 @@ def test_STPDenseSynapse1():
outs = []
Wdyn = []
- ctx.reset()
+ reset_process.run() # ctx.reset()
for t in range(3):
a.inputs.set(in_pulse)
- ctx.run(t=t * dt, dt=dt)
- outs.append(a.outputs.value)
- Wdyn.append(a.Wdyn.value)
+ advance_process.run(t=t * 1., dt=dt) # ctx.run(t=t * dt, dt=dt)
+ outs.append(a.outputs.get())
+ Wdyn.append(a.Wdyn.get())
outs = jnp.concatenate(outs, axis=1)
Wdyn = jnp.concatenate(Wdyn, axis=1)
# print(outs)
+ # print(outs_truth)
+ # print("...")
# print(Wdyn)
+ # print(Wdyn_truth)
np.testing.assert_allclose(outs, outs_truth, atol=1e-8)
np.testing.assert_allclose(Wdyn, Wdyn_truth, atol=1e-8)
diff --git a/tests/components/synapses/test_exponentialSynapse.py b/tests/components/synapses/test_exponentialSynapse.py
index 83ad19ee..9217932f 100644
--- a/tests/components/synapses/test_exponentialSynapse.py
+++ b/tests/components/synapses/test_exponentialSynapse.py
@@ -1,11 +1,11 @@
from jax import numpy as jnp, random, jit
+from ngcsimlib.context import Context
import numpy as np
np.random.seed(42)
-from ngclearn.components import ExponentialSynapse
-from ngcsimlib.compilers.process import Process
-from ngcsimlib.context import Context
-import ngclearn.utils.weight_distribution as dist
+from ngclearn import Context, MethodProcess
+from ngclearn.utils.distribution_generator import DistributionGenerator
+from ngclearn.components.synapses.exponentialSynapse import ExponentialSynapse
def test_exponentialSynapse1():
name = "expsyn_ctx"
@@ -19,18 +19,15 @@ def test_exponentialSynapse1():
# ---- build a single exp-synapse system ----
with Context(name) as ctx:
a = ExponentialSynapse(
- name="a", shape=(1,1), tau_decay=tau_syn, g_syn_bar=2.4, syn_rest=E_rest, weight_init=dist.constant(value=1.),
- key=subkeys[0]
+ name="a", shape=(1,1), tau_decay=tau_syn, g_syn_bar=2.4, syn_rest=E_rest,
+ weight_init=DistributionGenerator.constant(value=1.), key=subkeys[0]
)
- advance_process = (Process("advance_proc")
+ advance_process = (MethodProcess("advance_proc")
>> a.advance_state)
- # ctx.wrap_and_add_command(advance_process.pure, name="run")
- ctx.wrap_and_add_command(jit(advance_process.pure), name="run")
- reset_process = (Process("reset_proc")
+ reset_process = (MethodProcess("reset_proc")
>> a.reset)
- ctx.wrap_and_add_command(jit(reset_process.pure), name="reset")
sp_train = jnp.array([1., 0., 1.], dtype=jnp.float32)
post_syn_neuron_volt = jnp.ones((1, 1)) * -65. ## post-syn neuron is at rest
@@ -38,15 +35,16 @@ def test_exponentialSynapse1():
outs_truth = jnp.array([[156., 78., 195.]])
outs = []
- ctx.reset()
+ reset_process.run() # ctx.reset()
for t in range(3):
in_pulse = jnp.expand_dims(sp_train[t], axis=0)
a.inputs.set(in_pulse)
a.v.set(post_syn_neuron_volt)
- ctx.run(t=t * dt, dt=dt)
- #print("g: ",a.g_syn.value)
- #print("i: ", a.i_syn.value)
- outs.append(a.outputs.value)
+ advance_process.run(t=t * 1., dt=dt) # ctx.run(t=t * dt, dt=dt)
+ # print("in: ", a.inputs.get())
+ # print("g: ",a.g_syn.get())
+ # print("i: ", a.i_syn.get())
+ outs.append(a.outputs.get())
outs = jnp.concatenate(outs, axis=1)
#print(outs)