diff --git a/docs/images/tutorials/lava/raster_O.png b/docs/images/tutorials/lava/raster_O.png
new file mode 100644
index 00000000..9b102d01
Binary files /dev/null and b/docs/images/tutorials/lava/raster_O.png differ
diff --git a/docs/images/tutorials/lava/raster_T.png b/docs/images/tutorials/lava/raster_T.png
new file mode 100644
index 00000000..dd45e4dc
Binary files /dev/null and b/docs/images/tutorials/lava/raster_T.png differ
diff --git a/docs/images/tutorials/lava/raster_X.png b/docs/images/tutorials/lava/raster_X.png
new file mode 100644
index 00000000..e4fd9e1d
Binary files /dev/null and b/docs/images/tutorials/lava/raster_X.png differ
diff --git a/docs/tutorials/foundations/monitors.md b/docs/tutorials/foundations/monitors.md
new file mode 100644
index 00000000..bc57849d
--- /dev/null
+++ b/docs/tutorials/foundations/monitors.md
@@ -0,0 +1,72 @@
+# NGC Monitors
+
+Ngc-monitors are a way of storing a rolling window of compartment values
+automatically. Their intended purpose is not to be used inside of a model but
+just as an auxiliary way to view the internal state of the module even when it is
+compiled. A monitor will track the last `n` values it has observed within the
+compartment with the oldest value being at `index=0` and the newest being at
+`index=n-1`.
+
+## Building a Monitor
+
+Monitors are constructed exactly like regular components are for general models.
+Simply import the monitor `from ngclearn.components import Monitor`. Now,
+inside of your model, build it like a regular component.
+
+```python
+with Context("model") as model:
+ M = Monitor("M", default_window_length=100)
+```
+
+## Watching compartments
+
+There are then two key ways of watching compartments, the first way looks similar
+to the wiring paradigm found in connecting standard ngclearn components
+together. The primary difference is that connecting compartments to the monitor
+does not require a compartment, they are wired directly into the `Monitor`
+following the pattern below:
+
+```python
+ M << z0.s
+```
+
+This will wire the spike output of `z0` into the monitor with a view window
+length of the `default_window_length`. In the event that you want a view window
+that is not the default viewing length, you can use the `watch()` method
+instead as in below:
+
+```python
+ M.watch(z0.s, customWindowLength)
+```
+
+There is no limit to the number of compartments that a monitor can watch or the
+length of the window that it can store. However, as it is constantly shifting
+values, tracking large matrices, such as those containing synapse values
+over many timesteps, may get expensive.
+
+For the monitor to run during your `advance_state` and `reset` calls, make sure
+to add to the monitor to the list of components to compile. Currently,
+monitors do not work with non-compiled methods
+(This is a planned feature for future developments of ngc-learn).
+
+## Extracting Values
+
+To look at the currently stored window of any compartment being tracked, there
+are two methods available to you. The first method requires that you have
+access to the compartment that the monitor is watching. To read out the
+monitors values, you can call:
+
+```python
+M.view(z0.s)
+```
+
+In the event that you do not have access to the compartment, all of the stored
+values can be found via the path using the following:
+
+```python
+M.get_store("path/to/compartment")
+```
+
+The stored windows are kept in a tree of dictionaries, where each node is a part
+of the path and the leaves are compartment objects holding the
+tracked value windows.
\ No newline at end of file
diff --git a/docs/tutorials/lava/hebbian_learning.md b/docs/tutorials/lava/hebbian_learning.md
index ed576427..c45a01b6 100644
--- a/docs/tutorials/lava/hebbian_learning.md
+++ b/docs/tutorials/lava/hebbian_learning.md
@@ -1,17 +1,19 @@
# Training a Spiking Network On-chip
-In this tutorial we will build generate a simple dataset (consisting of binary
+In this tutorial we will build generate a simple dataset (consisting of binary
patterns of X's, O's, and T's) and train a model in the Loihi simulator
-using Hebbian learning in the form of trace-based spike-timing-dependent
+using Hebbian learning in the form of trace-based spike-timing-dependent
plasticity.
## Setting up ngc-learn
-The first step of this project consist of setting up the configuration file for
-ngc-learn. Create a folder in your project directory root called `json_files` and
+The first step of this project consist of setting up the configuration file for
+ngc-learn. Create a folder in your project directory root called `json_files`
+and
then create a `config.json` configuration inside of that folder.
-Now for this project we will not be loading anything dynamically, so we can simply add:
+Now for this project we will not be loading anything dynamically, so we can
+simply add:
```json
{
@@ -21,12 +23,12 @@ Now for this project we will not be loading anything dynamically, so we can simp
}
```
-The above configuration will skip the dynamic loading of modules, which is
+The above configuration will skip the dynamic loading of modules, which is
important for Lava-based model transference and simulation.
-Next, in order to run code with the Loihi simulator, the base version of numpy
-needs to be used instead of JAX's wrapped numpy (which is what ngc-learn resorts
-to by default). To change all of ngc-learn over to using the base version of
+Next, in order to run code with the Loihi simulator, the base version of numpy
+needs to be used instead of JAX's wrapped numpy (which is what ngc-learn resorts
+to by default). To change all of ngc-learn over to using the base version of
numpy, simply add the following to your configuration:
```json
@@ -35,21 +37,22 @@ numpy, simply add the following to your configuration:
}
```
-Now your project is configured for ngc-lava and Lava usage and we can move on
+Now your project is configured for ngc-lava and Lava usage and we can move on
to data generation.
## Generating Data
-For this project we will be using three different patterns to train a simple
-biophysical spiking neural network; the data will simply consist of binary
-image patterns of either an `X`, `O`, and a `T`. To create the file needed to
-generate these patterns, create a Python script named `data_generator.py` in
+For this project we will be using three different patterns to train a simple
+biophysical spiking neural network; the data will simply consist of binary
+image patterns of either an `X`, `O`, and a `T`. To create the file needed to
+generate these patterns, create a Python script named `data_generator.py` in
your project root. Next, we will import `numpy` and `random` and
define the following three generator methods:
```python
from ngclearn import numpy as np
+
def make_X(size):
X = np.zeros((size, size))
for i in range(0, size):
@@ -57,6 +60,7 @@ def make_X(size):
X[i, size - 1 - i] = np.random.uniform(0.75, 1)
return X
+
def make_O(size):
O = np.zeros((size, size))
for i in range(0, (size // 2) - 1):
@@ -66,11 +70,13 @@ def make_O(size):
O[(size // 2) + i, size - 2 - i] = np.random.uniform(0.75, 1)
return O
+
def make_T(size):
T = np.zeros((size, size))
T[1, 1:size - 1] = np.random.uniform(0.75, 1, (1, size - 2))
for i in range(2, size - 1):
- T[i, (size // 2) - 1: (size // 2) + 1] = np.random.uniform(0.75, 1, (1, 2))
+ T[i, (size // 2) - 1: (size // 2) + 1] = np.random.uniform(0.75, 1,
+ (1, 2))
return T
```
@@ -78,56 +84,63 @@ Each of these methods will create a pattern of the desired size and shape.
## Building the Model
-Found below is all of the imports that will be needed to run the model we desire
+Found below is all of the imports that will be needed to run the model we desire
in Lava:
```python
from ngclava import LavaContext
from ngclearn import numpy as np
-from ngclearn.components.lava import LIFCell, GatedTrace, TraceSTDPSynapse, StaticSynapse
+from ngclearn.components.lava import LIFCell, GatedTrace, TraceSTDPSynapse, StaticSynapse, Monitor
+import ngclearn.utils.viz as viz_utils
import ngclearn.utils.weight_distribution as dist
-from ngclearn.utils.viz.synapse_plot import visualize
from data_generator import make_X, make_O, make_T
```
-To start off building this model, we will define all of the hyperparameters
+To start off building this model, we will define all of the hyperparameters
needed to create the necessary model components:
```python
-#Training Params
+# Training Params
epochs = 35
view_length = 200
rest_length = 1000
-#Model Params
+# Model Params
n_in = 64 # Input layer size
n_hid = 25 # Hidden layer size
dt = 1. # ms # integration time constant
-np.random.seed(42) ## seed the internal numpy calls
+np.random.seed(42) ## seed the internal numpy calls
```
-After this we will create the lava context, the components, as well as the wiring:
+After this we will create the lava context, the components, as well as the
+wiring:
```python
with LavaContext("Model") as model:
- z0 = LIFCell("z0", n_units=n_in, thr_theta_init=dist.constant(0.), dt=dt,
- tau_m=50., v_decay=0., tau_theta=500., refract_T=0.) ## IF cell
- z1e = LIFCell("z1e", n_units=n_hid, thr_theta_init=dist.uniform(amin=-2, amax=2.),
- dt=dt, tau_m=100., tau_theta=500.) ## excitatory LIF cell
- z1i = LIFCell("z1i", n_units=n_hid, thr_theta_init=dist.uniform(amin=-2, amax=2.),
- dt=dt, tau_m=100., thr=-40., v_rest=-60., v_reset=-45.,
- theta_plus=0.) ## inhibitory LIF cell
+ z0 = LIFCell("z0", n_units=n_in, thr_theta_init=dist.constant(0.), dt=dt,
+ tau_m=50., v_decay=0., tau_theta=500.,
+ refract_T=0.) ## IF cell
+ z1e = LIFCell("z1e", n_units=n_hid,
+ thr_theta_init=dist.uniform(amin=-2, amax=2.),
+ dt=dt, tau_m=100., tau_theta=500.) ## excitatory LIF cell
+ z1i = LIFCell("z1i", n_units=n_hid,
+ thr_theta_init=dist.uniform(amin=-2, amax=2.),
+ dt=dt, tau_m=100., thr=-40., v_rest=-60., v_reset=-45.,
+ theta_plus=0.) ## inhibitory LIF cell
tr0 = GatedTrace("tr0", n_units=n_in, dt=dt, tau_tr=20.)
tr1 = GatedTrace("tr1", n_units=n_hid, dt=dt, tau_tr=20.)
- W1 = TraceSTDPSynapse("W1", weight_init=dist.uniform(amin=0, amax=0.3),
- shape=(n_in, n_hid), dt=dt, Aplus=0.011, Aminus=0.0011,
+ W1 = TraceSTDPSynapse("W1", weight_init=dist.uniform(amin=0, amax=0.3),
+ shape=(n_in, n_hid), dt=dt, Aplus=0.011,
+ Aminus=0.0011,
preTrace_target=0.055)
- W1ie = StaticSynapse("W1ie", weight_init=dist.hollow(120.),
- shape=(n_hid, n_hid),dt=dt)
- W1ei = StaticSynapse("W1ei", weight_init=dist.eye(22.5),
+ W1ie = StaticSynapse("W1ie", weight_init=dist.hollow(120.),
shape=(n_hid, n_hid), dt=dt)
+ W1ei = StaticSynapse("W1ei", weight_init=dist.eye(22.5),
+ shape=(n_hid, n_hid), dt=dt)
+
+ M = Monitor("M", default_window_length=view_length)
## wire z0 to z1e via W1 and z1i to z1e via W1ie
W1.inputs << z0.s
@@ -149,11 +162,15 @@ with LavaContext("Model") as model:
W1.pre << z0.s
W1.x_post << tr1.trace
W1.post << z1e.s
+
+ # set up monitoring of z1e's spike output
+ M << z1e.s
```
-After the components have been set up, we have to "lag out" the synapses that
-will cause recurrent (locking) problems when running on the Loihi2. This will
-cause each of these synapses to run one time-step behind and fixes many recurrency
+After the components have been set up, we have to "lag out" the synapses that
+will cause recurrent (locking) problems when running on the Loihi2. This will
+cause each of these synapses to run one time-step behind and fixes many
+recurrency
issues (as described [here](lava_context.md)).
```python
@@ -162,41 +179,45 @@ issues (as described [here](lava_context.md)).
model.set_lag('W1ei')
```
-Now that the model is all set up, we have to tell the Lava compiler to actually
+Now that the model is all set up, we have to tell the Lava compiler to actually
build all the Lava objects with the following:
```python
model.rebuild_lava()
```
-This line will stop the automatic build of components when leaving this
-with-block and provides access to all of the Lava components inside of this
+This line will stop the automatic build of components when leaving this
+with-block and provides access to all of the Lava components inside of this
with-block.
-Next, we set up two methods, a `clamp` method to set the input data and
+Next, we set up two methods, a `clamp` method to set the input data and
`viz` to visualize all of the different receptive fields of our model:
```python
lz0, lW1 = model.get_lava_components('z0', 'W1')
+
@model.dynamicCommand
def clamp(x):
model.pause()
lz0.j_exc.set(x)
+
@model.dynamicCommand
def viz():
- visualize([lW1.weights.get()], [(8, 8)], "lava_fields")
+ viz_utils.synapse_plot.visualize([lW1.weights.get()], [(8, 8)], "lava_fields")
```
-Now that everything is set up to build the runtime and start training the model
+## Running The Model
+
+Now that everything is set up to build the runtime and start training the model
inside of the Loihi simulator. To set up the runtime we call the following:
```python
-model.set_up_runtime("z0", rest_image=np.zeros((1, 64)))
+ model.set_up_runtime("z0", rest_image=np.zeros((1, 64)))
```
-This will set up a runtime with `z0` as the root node and also uses a resting
+This will set up a runtime with `z0` as the root node and also uses a resting
image of all zeros to allow the system to return to its resting state.
Now the training loop will be as follows:
@@ -217,13 +238,75 @@ with model.runtime:
model.view(T, view_length)
model.rest(rest_length)
-
print("\nDone Training")
+
+```
+
+## Evaluating the On-Chip Trained Model
+
+The code above will work to train the model on a Loihi neuromorphic chip, but,
+currently, we do not have a way of viewing how effective the model learned
+really is. To set up this evaluation, we can call
+the `viz` method defined above to view the receptive fields that our spiking
+model has acquired:
+
+```python
model.viz()
- model.save_to_json(".", model_name="trained")
```
-Running this should produce a set of receptive fields that look like the
+Running this should produce a set of receptive fields that look like the
following:
+
+While viewing the receptive fields qualitatively tells us that our spiking
+model has trained, we may also want to view the
+[raster plots](ngclearn.utils.viz.raster) -- visual depictions of the
+underlying spike patterns acquired in the hidden layer of our model -- for each
+of our three image patterns (as they are fed into our trained model). To do
+this, we will make use of the monitor we defined above in the following manner:
+
+```python
+ ## Turning off learning
+ lW1.eta.set(np.array([0]))
+
+ model.view(np.reshape(make_T(8), (1, 64)), view_length)
+ model.write_to_ngc()
+ spikes = M.view(z1e.s)
+ viz_utils.raster.create_raster_plot(spikes, tag="T", plot_fname="raster_T")
+ model.rest(rest_length)
+ print("Done T")
+
+ model.view(np.reshape(make_X(8), (1, 64)), view_length)
+ model.write_to_ngc()
+ spikes = M.view(z1e.s)
+ viz_utils.raster.create_raster_plot(spikes, tag="X", plot_fname="raster_X")
+ model.rest(rest_length)
+ print("Done X")
+
+ model.view(np.reshape(make_O(8), (1, 64)), view_length)
+ model.write_to_ngc()
+ spikes = M.view(z1e.s)
+ viz_utils.raster.create_raster_plot(spikes, tag="O", plot_fname="raster_O")
+ model.rest(rest_length)
+ print("Done O")
+```
+
+The above should result in raster plots where the spikes correspond to the
+receptive fields of each trained letter pattern. Specifically, you should see
+that the top left field is `N0` and the bottom right is `N24`. Your raster plots
+should look like the ones below:
+
+
+
+
+
+Finally to save the model to disk, you can call the following:
+
+```python
+ model.save_to_json(".", model_name="trained")
+```
+
+which will save your on-chip trained Loihi model to disk for later use.
+
+
diff --git a/docs/tutorials/lava/monitors.md b/docs/tutorials/lava/monitors.md
new file mode 100644
index 00000000..b5d82243
--- /dev/null
+++ b/docs/tutorials/lava/monitors.md
@@ -0,0 +1,17 @@
+# Monitors
+
+While lava does have its own version of monitors, ngclearn offers an
+in-built version for convenience. It is
+recommended that you use the ngclearn monitors as they have expanded
+functionality and are designed to interact with the Lava components well. For an
+overview of/details on how monitors work please see
+[this](../foundations/monitors.md). The
+only difference is that Lava has its own monitor found
+in `ngclearn.components.lava`.
+
+## Sharp Edges and Bits
+
+- Due to the fact that a Lava component of the monitor must be built, it has to
+ be defined inside the `LavaContext`.
+- To view the values found within the monitor via the `view()` and `get_path()`
+ methods, `model.write_to_ngc()` must be called to refresh the values.
\ No newline at end of file
diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py
index aabfbdee..d1bcc40e 100644
--- a/ngclearn/components/__init__.py
+++ b/ngclearn/components/__init__.py
@@ -26,3 +26,5 @@
from .synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse
from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse
from .synapses.hebbian.BCMSynapse import BCMSynapse
+
+from .monitor import Monitor
diff --git a/ngclearn/components/base_monitor.py b/ngclearn/components/base_monitor.py
new file mode 100644
index 00000000..79b2c535
--- /dev/null
+++ b/ngclearn/components/base_monitor.py
@@ -0,0 +1,216 @@
+import json
+
+from ngclearn import Component, Compartment
+from ngclearn import numpy as np
+from ngcsimlib.utils import add_component_resolver, add_resolver_meta, get_current_path
+from ngcsimlib.logger import warn, critical
+
+
+class Base_Monitor(Component):
+ """
+ An abstract base for monitors for both ngclearn and ngclava. Compartments wired directly into this component will
+ have their value tracked during `advance_state` loops automatically.
+
+ Note the monitor only works for compiled methods currently
+
+
+ Using default window length:
+ myMonitor << myComponent.myCompartment
+
+ Using custom window length:
+ myMonitor.watch(myComponent.myCompartment, customWindowLength)
+
+ To get values out of the monitor either path to the stored value directly, or pass in a compartment directly. All
+ paths are the same as their local path variable.
+
+ Using a compartment:
+ myMonitor.view(myComponent.myCompartment)
+
+ Using a path:
+ myMonitor.get_store(myComponent.myCompartment.path).value
+
+ There can only be one monitor in existence at a time due to the way it interacts with resolvers and the compilers
+ for ngclearn.
+
+ Args:
+ name: The name of the component.
+
+ default_window_length: The default window length.
+ """
+
+ _singleton = None # Only one Monitor
+
+ @staticmethod
+ def build_advance(compartments):
+ """
+ A method to build the method to advance the stored values.
+
+ Args:
+ compartments: A list of compartments to store values
+
+ Returns: The method to advance the stored values.
+
+ """
+ critical(
+ "build_advance() is not defined on this monitor, use either the monitor found in ngclearn.components or "
+ "ngclearn.components.lava (If using lava)")
+
+
+ @staticmethod
+ def build_reset(compartments):
+ """
+ A method to build the method to reset the stored values.
+ Args:
+ compartments: A list of compartments to reset
+
+ Returns: The method to reset the stored values.
+ """
+ @staticmethod
+ def _reset(**kwargs):
+ return_vals = []
+ for _, comp in compartments:
+ current_store = kwargs[comp + "*store"]
+ return_vals.append(np.zeros(current_store.shape))
+ return return_vals if len(compartments) > 1 else return_vals[0]
+
+ return _reset
+
+ def __init__(self, name, default_window_length=100, **kwargs):
+ if Base_Monitor._singleton is not None:
+ critical("Only one monitor can be built")
+ else:
+ Base_Monitor._singleton = True
+ super().__init__(name, **kwargs)
+ self.store = {}
+ self.compartments = []
+ self.default_window_length = default_window_length
+
+ def __lshift__(self, other):
+ if isinstance(other, Compartment):
+ self.watch(other, self.default_window_length)
+ else:
+ warn("Only Compartments can be monitored not", type(other))
+
+ def watch(self, compartment, window_length):
+ """
+ Sets the monitor to watch a specific compartment, for a specified window length.
+
+ Args:
+ compartment: the compartment object to monitor
+
+ window_length: the window length
+ """
+ cs, end = self._add_path(compartment.path)
+
+ shape = compartment.value.shape
+ new_comp = Compartment(np.zeros(shape))
+ new_comp_store = Compartment(np.zeros((window_length, *shape)))
+
+ comp_key = "*".join(compartment.path.split("/"))
+ store_comp_key = comp_key + "*store"
+
+ new_comp._setup(self, comp_key)
+ new_comp_store._setup(self, store_comp_key)
+
+ new_comp << compartment
+
+ cs[end] = new_comp_store
+ setattr(self, comp_key, new_comp)
+ setattr(self, store_comp_key, new_comp_store)
+ self.compartments.append(new_comp.path)
+
+ self._update_resolver()
+
+ def _update_resolver(self):
+ output_compartments = []
+ compartments = []
+ for comp in self.compartments:
+ output_compartments.append(comp.split("/")[-1] + "*store")
+ compartments.append((0, comp.split("/")[-1]))
+
+ args = []
+ parameters = []
+
+ add_component_resolver(self.__class__.__name__, "advance_state",
+ (self.build_advance(compartments), output_compartments))
+ add_resolver_meta(self.__class__.__name__, "advance_state",
+ (args, parameters, compartments + [(0, o) for o in output_compartments], False))
+
+ add_component_resolver(self.__class__.__name__, "reset", (self.build_reset(compartments), output_compartments))
+ add_resolver_meta(self.__class__.__name__, "reset",
+ (args, parameters, [(0, o) for o in output_compartments], False))
+
+ def _add_path(self, path):
+ _path = path.split("/")[1:]
+ end = _path.pop(-1)
+
+ current_store = self.store
+ for p in _path:
+ if p not in current_store.keys():
+ current_store[p] = {}
+ current_store = current_store[p]
+
+ return current_store, end
+
+ def view(self, compartment):
+ """
+ Gets the value associated with the specified compartment
+
+ Args:
+ compartment: The compartment to extract the stored value of
+
+ Returns: The stored value, None if not monitoring that compartment
+
+ """
+ _path = compartment.path.split("/")[1:]
+ store = self.get_store(_path)
+ return store.value if store is not None else store
+
+ def get_store(self, path):
+ current_store = self.store
+ for p in path:
+ if p not in current_store.keys():
+ return None
+ current_store = current_store[p]
+ return current_store
+
+ def save(self, directory, **kwargs):
+ file_name = directory + "/" + self.name + ".json"
+ _dict = {"sources": {}, "stores": {}}
+ for key in self.compartments:
+ n = key.split("/")[-1]
+ _dict["sources"][key] = self.__dict__[n].value.shape
+ _dict["stores"][key + "*store"] = self.__dict__[n + "*store"].value.shape
+
+ with open(file_name, "w") as f:
+ json.dump(_dict, f)
+
+ def load(self, directory, **kwargs):
+ file_name = directory + "/" + self.name + ".json"
+ with open(file_name, "r") as f:
+ vals = json.load(f)
+
+ for comp_path, shape in vals["stores"].items():
+
+ compartment_path = comp_path.split("/")[-1]
+ new_path = get_current_path() + "/" + "/".join(compartment_path.split("*")[-3:-1])
+
+ cs, end = self._add_path(new_path)
+
+ new_comp = Compartment(np.zeros(shape))
+ new_comp._setup(self, compartment_path)
+
+ cs[end] = new_comp
+ setattr(self, compartment_path, new_comp)
+
+
+
+ for comp_path, shape in vals['sources'].items():
+ compartment_path = comp_path.split("/")[-1]
+ new_comp = Compartment(np.zeros(shape))
+ new_comp._setup(self, compartment_path)
+
+ setattr(self, compartment_path, new_comp)
+ self.compartments.append(new_comp.path)
+
+ self._update_resolver()
diff --git a/ngclearn/components/lava/__init__.py b/ngclearn/components/lava/__init__.py
index d70f8f67..962f843a 100644
--- a/ngclearn/components/lava/__init__.py
+++ b/ngclearn/components/lava/__init__.py
@@ -6,3 +6,6 @@
from .synapses.hebbianSynapse import HebbianSynapse
## Lava-compliant encoders/traces
from .traces.gatedTrace import GatedTrace
+
+#monitor
+from .monitor import Monitor
\ No newline at end of file
diff --git a/ngclearn/components/lava/monitor.py b/ngclearn/components/lava/monitor.py
new file mode 100644
index 00000000..6fcb8d9a
--- /dev/null
+++ b/ngclearn/components/lava/monitor.py
@@ -0,0 +1,22 @@
+from ngclearn.components.base_monitor import Base_Monitor
+
+
+class Monitor(Base_Monitor):
+ """
+ A numpy implementation of `Base_Monitor`. Designed to be used with all lava compatible ngclearn components
+ """
+
+ @staticmethod
+ def build_advance(compartments):
+ @staticmethod
+ def _advance(**kwargs):
+ return_vals = []
+ for _, comp in compartments:
+ new_val = kwargs[comp]
+ current_store = kwargs[comp + "*store"]
+ current_store[:-1] = current_store[1:]
+ current_store[-1] = new_val
+ return_vals.append(current_store)
+ return return_vals if len(compartments) > 1 else return_vals[0]
+
+ return _advance
diff --git a/ngclearn/components/lava/neurons/LIFCell.py b/ngclearn/components/lava/neurons/LIFCell.py
index a32d0292..0d7175fc 100644
--- a/ngclearn/components/lava/neurons/LIFCell.py
+++ b/ngclearn/components/lava/neurons/LIFCell.py
@@ -108,7 +108,7 @@ def __init__(self, name, n_units, dt, tau_m, thr_theta_init=None, resist_m=1.,
self._init(thr_theta0)
def _init(self, thr_theta0):
- self.thr_theta = Compartment(thr_theta0)
+ self.thr_theta.set(thr_theta0)
@staticmethod
def _advance_state(dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T, thr, tau_theta,
diff --git a/ngclearn/components/lava/synapses/hebbianSynapse.py b/ngclearn/components/lava/synapses/hebbianSynapse.py
index 946e48ea..65a80ae3 100644
--- a/ngclearn/components/lava/synapses/hebbianSynapse.py
+++ b/ngclearn/components/lava/synapses/hebbianSynapse.py
@@ -96,11 +96,11 @@ def _init(self, weights):
preVals = jnp.zeros((self.batch_size, self.rows))
postVals = jnp.zeros((self.batch_size, self.cols))
## Compartments
- self.inputs = Compartment(preVals)
- self.outputs = Compartment(postVals)
- self.pre = Compartment(preVals)
- self.post = Compartment(postVals)
- self.weights = Compartment(weights)
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.pre.set(preVals)
+ self.post.set(postVals)
+ self.weights.set(weights)
@staticmethod
def _advance_state(dt, Rscale, w_bounds, w_decay, inputs, weights,
diff --git a/ngclearn/components/lava/synapses/staticSynapse.py b/ngclearn/components/lava/synapses/staticSynapse.py
index e025bb0a..1df58ab2 100755
--- a/ngclearn/components/lava/synapses/staticSynapse.py
+++ b/ngclearn/components/lava/synapses/staticSynapse.py
@@ -12,7 +12,7 @@ class StaticSynapse(Component): ## Lava-compliant fixed/non-evolvable synapse
| --- Synapse Input Compartments: (Takes wired-in signals) ---
| inputs - input (pre-synaptic) stimulus
- | --- Synapse Output Compartments: (These signals are generated) ---
+ | --- Synapse Output Compartments: .set()ese signals are generated) ---
| outputs - transformed (post-synaptic) signal
| weights - current value matrix of synaptic efficacies (this is post-update if eta > 0)
@@ -75,9 +75,9 @@ def _init(self, weights):
preVals = jnp.zeros((self.batch_size, self.rows))
postVals = jnp.zeros((self.batch_size, self.cols))
## Compartments
- self.inputs = Compartment(preVals)
- self.outputs = Compartment(postVals)
- self.weights = Compartment(weights)
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.weights.set(weights)
@staticmethod
def _advance_state(dt, Rscale, inputs, weights):
diff --git a/ngclearn/components/lava/synapses/traceSTDPSynapse.py b/ngclearn/components/lava/synapses/traceSTDPSynapse.py
index e0d322ea..16103be8 100755
--- a/ngclearn/components/lava/synapses/traceSTDPSynapse.py
+++ b/ngclearn/components/lava/synapses/traceSTDPSynapse.py
@@ -83,7 +83,7 @@ def __init__(self, name, dt, resist_scale=1., weight_init=None, shape=None,
## Component size setup
self.batch_size = 1
- self.eta = Compartment(jnp.ones((1,1)) * eta)
+ self.eta = Compartment(jnp.ones((1, 1)) * eta)
self.inputs = Compartment(None)
self.outputs = Compartment(None)
@@ -112,13 +112,13 @@ def _init(self, weights):
preVals = jnp.zeros((self.batch_size, self.rows))
postVals = jnp.zeros((self.batch_size, self.cols))
## Compartments
- self.inputs = Compartment(preVals)
- self.outputs = Compartment(postVals)
- self.pre = Compartment(preVals) ## pre-synaptic spike
- self.x_pre = Compartment(preVals) ## pre-synaptic trace
- self.post = Compartment(postVals) ## post-synaptic spike
- self.x_post = Compartment(postVals) ## post-synaptic trace
- self.weights = Compartment(weights)
+ self.inputs.set(preVals)
+ self.outputs.set(postVals)
+ self.pre.set(preVals) ## pre-synaptic spike
+ self.x_pre.set(preVals) ## pre-synaptic trace
+ self.post.set(postVals) ## post-synaptic spike
+ self.x_post.set(postVals) ## post-synaptic trace
+ self.weights.set(weights)
@staticmethod
def _advance_state(dt, Rscale, Aplus, Aminus, w_bounds, w_decay, x_tar,
@@ -155,7 +155,7 @@ def _reset(batch_size, rows, cols, eta0):
postVals, # post
preVals, # x_pre
postVals, # x_post
- jnp.ones((1,1)) * eta0
+ jnp.ones((1, 1)) * eta0
)
@resolver(_reset)
diff --git a/ngclearn/components/monitor.py b/ngclearn/components/monitor.py
new file mode 100644
index 00000000..595ab69a
--- /dev/null
+++ b/ngclearn/components/monitor.py
@@ -0,0 +1,19 @@
+from ngclearn.components.base_monitor import Base_Monitor
+
+class Monitor(Base_Monitor):
+ """
+ A jax implementation of `Base_Monitor`. Designed to be used with all non-lava ngclearn components
+ """
+ @staticmethod
+ def build_advance(compartments):
+ @staticmethod
+ def _advance(**kwargs):
+ return_vals = []
+ for _, comp in compartments:
+ new_val = kwargs[comp]
+ current_store = kwargs[comp + "*store"]
+ current_store = current_store.at[:-1].set(current_store[1:])
+ current_store = current_store.at[-1].set(new_val)
+ return_vals.append(current_store)
+ return return_vals if len(compartments) > 1 else return_vals[0]
+ return _advance
diff --git a/ngclearn/utils/__init__.py b/ngclearn/utils/__init__.py
index c2dd6534..f897f5a7 100755
--- a/ngclearn/utils/__init__.py
+++ b/ngclearn/utils/__init__.py
@@ -1,5 +1,6 @@
from .model_utils import tensorstats
## forward imports from core ngc-learn utility sub-packages
+from . import viz
from . import io_utils
from . import metric_utils
from . import model_utils