diff --git a/docs/images/tutorials/lava/raster_O.png b/docs/images/tutorials/lava/raster_O.png new file mode 100644 index 00000000..9b102d01 Binary files /dev/null and b/docs/images/tutorials/lava/raster_O.png differ diff --git a/docs/images/tutorials/lava/raster_T.png b/docs/images/tutorials/lava/raster_T.png new file mode 100644 index 00000000..dd45e4dc Binary files /dev/null and b/docs/images/tutorials/lava/raster_T.png differ diff --git a/docs/images/tutorials/lava/raster_X.png b/docs/images/tutorials/lava/raster_X.png new file mode 100644 index 00000000..e4fd9e1d Binary files /dev/null and b/docs/images/tutorials/lava/raster_X.png differ diff --git a/docs/tutorials/foundations/monitors.md b/docs/tutorials/foundations/monitors.md new file mode 100644 index 00000000..bc57849d --- /dev/null +++ b/docs/tutorials/foundations/monitors.md @@ -0,0 +1,72 @@ +# NGC Monitors + +Ngc-monitors are a way of storing a rolling window of compartment values +automatically. Their intended purpose is not to be used inside of a model but +just as an auxiliary way to view the internal state of the module even when it is +compiled. A monitor will track the last `n` values it has observed within the +compartment with the oldest value being at `index=0` and the newest being at +`index=n-1`. + +## Building a Monitor + +Monitors are constructed exactly like regular components are for general models. +Simply import the monitor `from ngclearn.components import Monitor`. Now, +inside of your model, build it like a regular component. + +```python +with Context("model") as model: + M = Monitor("M", default_window_length=100) +``` + +## Watching compartments + +There are then two key ways of watching compartments, the first way looks similar +to the wiring paradigm found in connecting standard ngclearn components +together. The primary difference is that connecting compartments to the monitor +does not require a compartment, they are wired directly into the `Monitor` +following the pattern below: + +```python + M << z0.s +``` + +This will wire the spike output of `z0` into the monitor with a view window +length of the `default_window_length`. In the event that you want a view window +that is not the default viewing length, you can use the `watch()` method +instead as in below: + +```python + M.watch(z0.s, customWindowLength) +``` + +There is no limit to the number of compartments that a monitor can watch or the +length of the window that it can store. However, as it is constantly shifting +values, tracking large matrices, such as those containing synapse values +over many timesteps, may get expensive. + +For the monitor to run during your `advance_state` and `reset` calls, make sure +to add to the monitor to the list of components to compile. Currently, +monitors do not work with non-compiled methods +(This is a planned feature for future developments of ngc-learn). + +## Extracting Values + +To look at the currently stored window of any compartment being tracked, there +are two methods available to you. The first method requires that you have +access to the compartment that the monitor is watching. To read out the +monitors values, you can call: + +```python +M.view(z0.s) +``` + +In the event that you do not have access to the compartment, all of the stored +values can be found via the path using the following: + +```python +M.get_store("path/to/compartment") +``` + +The stored windows are kept in a tree of dictionaries, where each node is a part +of the path and the leaves are compartment objects holding the +tracked value windows. \ No newline at end of file diff --git a/docs/tutorials/lava/hebbian_learning.md b/docs/tutorials/lava/hebbian_learning.md index ed576427..c45a01b6 100644 --- a/docs/tutorials/lava/hebbian_learning.md +++ b/docs/tutorials/lava/hebbian_learning.md @@ -1,17 +1,19 @@ # Training a Spiking Network On-chip -In this tutorial we will build generate a simple dataset (consisting of binary +In this tutorial we will build generate a simple dataset (consisting of binary patterns of X's, O's, and T's) and train a model in the Loihi simulator -using Hebbian learning in the form of trace-based spike-timing-dependent +using Hebbian learning in the form of trace-based spike-timing-dependent plasticity. ## Setting up ngc-learn -The first step of this project consist of setting up the configuration file for -ngc-learn. Create a folder in your project directory root called `json_files` and +The first step of this project consist of setting up the configuration file for +ngc-learn. Create a folder in your project directory root called `json_files` +and then create a `config.json` configuration inside of that folder. -Now for this project we will not be loading anything dynamically, so we can simply add: +Now for this project we will not be loading anything dynamically, so we can +simply add: ```json { @@ -21,12 +23,12 @@ Now for this project we will not be loading anything dynamically, so we can simp } ``` -The above configuration will skip the dynamic loading of modules, which is +The above configuration will skip the dynamic loading of modules, which is important for Lava-based model transference and simulation. -Next, in order to run code with the Loihi simulator, the base version of numpy -needs to be used instead of JAX's wrapped numpy (which is what ngc-learn resorts -to by default). To change all of ngc-learn over to using the base version of +Next, in order to run code with the Loihi simulator, the base version of numpy +needs to be used instead of JAX's wrapped numpy (which is what ngc-learn resorts +to by default). To change all of ngc-learn over to using the base version of numpy, simply add the following to your configuration: ```json @@ -35,21 +37,22 @@ numpy, simply add the following to your configuration: } ``` -Now your project is configured for ngc-lava and Lava usage and we can move on +Now your project is configured for ngc-lava and Lava usage and we can move on to data generation. ## Generating Data -For this project we will be using three different patterns to train a simple -biophysical spiking neural network; the data will simply consist of binary -image patterns of either an `X`, `O`, and a `T`. To create the file needed to -generate these patterns, create a Python script named `data_generator.py` in +For this project we will be using three different patterns to train a simple +biophysical spiking neural network; the data will simply consist of binary +image patterns of either an `X`, `O`, and a `T`. To create the file needed to +generate these patterns, create a Python script named `data_generator.py` in your project root. Next, we will import `numpy` and `random` and define the following three generator methods: ```python from ngclearn import numpy as np + def make_X(size): X = np.zeros((size, size)) for i in range(0, size): @@ -57,6 +60,7 @@ def make_X(size): X[i, size - 1 - i] = np.random.uniform(0.75, 1) return X + def make_O(size): O = np.zeros((size, size)) for i in range(0, (size // 2) - 1): @@ -66,11 +70,13 @@ def make_O(size): O[(size // 2) + i, size - 2 - i] = np.random.uniform(0.75, 1) return O + def make_T(size): T = np.zeros((size, size)) T[1, 1:size - 1] = np.random.uniform(0.75, 1, (1, size - 2)) for i in range(2, size - 1): - T[i, (size // 2) - 1: (size // 2) + 1] = np.random.uniform(0.75, 1, (1, 2)) + T[i, (size // 2) - 1: (size // 2) + 1] = np.random.uniform(0.75, 1, + (1, 2)) return T ``` @@ -78,56 +84,63 @@ Each of these methods will create a pattern of the desired size and shape. ## Building the Model -Found below is all of the imports that will be needed to run the model we desire +Found below is all of the imports that will be needed to run the model we desire in Lava: ```python from ngclava import LavaContext from ngclearn import numpy as np -from ngclearn.components.lava import LIFCell, GatedTrace, TraceSTDPSynapse, StaticSynapse +from ngclearn.components.lava import LIFCell, GatedTrace, TraceSTDPSynapse, StaticSynapse, Monitor +import ngclearn.utils.viz as viz_utils import ngclearn.utils.weight_distribution as dist -from ngclearn.utils.viz.synapse_plot import visualize from data_generator import make_X, make_O, make_T ``` -To start off building this model, we will define all of the hyperparameters +To start off building this model, we will define all of the hyperparameters needed to create the necessary model components: ```python -#Training Params +# Training Params epochs = 35 view_length = 200 rest_length = 1000 -#Model Params +# Model Params n_in = 64 # Input layer size n_hid = 25 # Hidden layer size dt = 1. # ms # integration time constant -np.random.seed(42) ## seed the internal numpy calls +np.random.seed(42) ## seed the internal numpy calls ``` -After this we will create the lava context, the components, as well as the wiring: +After this we will create the lava context, the components, as well as the +wiring: ```python with LavaContext("Model") as model: - z0 = LIFCell("z0", n_units=n_in, thr_theta_init=dist.constant(0.), dt=dt, - tau_m=50., v_decay=0., tau_theta=500., refract_T=0.) ## IF cell - z1e = LIFCell("z1e", n_units=n_hid, thr_theta_init=dist.uniform(amin=-2, amax=2.), - dt=dt, tau_m=100., tau_theta=500.) ## excitatory LIF cell - z1i = LIFCell("z1i", n_units=n_hid, thr_theta_init=dist.uniform(amin=-2, amax=2.), - dt=dt, tau_m=100., thr=-40., v_rest=-60., v_reset=-45., - theta_plus=0.) ## inhibitory LIF cell + z0 = LIFCell("z0", n_units=n_in, thr_theta_init=dist.constant(0.), dt=dt, + tau_m=50., v_decay=0., tau_theta=500., + refract_T=0.) ## IF cell + z1e = LIFCell("z1e", n_units=n_hid, + thr_theta_init=dist.uniform(amin=-2, amax=2.), + dt=dt, tau_m=100., tau_theta=500.) ## excitatory LIF cell + z1i = LIFCell("z1i", n_units=n_hid, + thr_theta_init=dist.uniform(amin=-2, amax=2.), + dt=dt, tau_m=100., thr=-40., v_rest=-60., v_reset=-45., + theta_plus=0.) ## inhibitory LIF cell tr0 = GatedTrace("tr0", n_units=n_in, dt=dt, tau_tr=20.) tr1 = GatedTrace("tr1", n_units=n_hid, dt=dt, tau_tr=20.) - W1 = TraceSTDPSynapse("W1", weight_init=dist.uniform(amin=0, amax=0.3), - shape=(n_in, n_hid), dt=dt, Aplus=0.011, Aminus=0.0011, + W1 = TraceSTDPSynapse("W1", weight_init=dist.uniform(amin=0, amax=0.3), + shape=(n_in, n_hid), dt=dt, Aplus=0.011, + Aminus=0.0011, preTrace_target=0.055) - W1ie = StaticSynapse("W1ie", weight_init=dist.hollow(120.), - shape=(n_hid, n_hid),dt=dt) - W1ei = StaticSynapse("W1ei", weight_init=dist.eye(22.5), + W1ie = StaticSynapse("W1ie", weight_init=dist.hollow(120.), shape=(n_hid, n_hid), dt=dt) + W1ei = StaticSynapse("W1ei", weight_init=dist.eye(22.5), + shape=(n_hid, n_hid), dt=dt) + + M = Monitor("M", default_window_length=view_length) ## wire z0 to z1e via W1 and z1i to z1e via W1ie W1.inputs << z0.s @@ -149,11 +162,15 @@ with LavaContext("Model") as model: W1.pre << z0.s W1.x_post << tr1.trace W1.post << z1e.s + + # set up monitoring of z1e's spike output + M << z1e.s ``` -After the components have been set up, we have to "lag out" the synapses that -will cause recurrent (locking) problems when running on the Loihi2. This will -cause each of these synapses to run one time-step behind and fixes many recurrency +After the components have been set up, we have to "lag out" the synapses that +will cause recurrent (locking) problems when running on the Loihi2. This will +cause each of these synapses to run one time-step behind and fixes many +recurrency issues (as described [here](lava_context.md)). ```python @@ -162,41 +179,45 @@ issues (as described [here](lava_context.md)). model.set_lag('W1ei') ``` -Now that the model is all set up, we have to tell the Lava compiler to actually +Now that the model is all set up, we have to tell the Lava compiler to actually build all the Lava objects with the following: ```python model.rebuild_lava() ``` -This line will stop the automatic build of components when leaving this -with-block and provides access to all of the Lava components inside of this +This line will stop the automatic build of components when leaving this +with-block and provides access to all of the Lava components inside of this with-block. -Next, we set up two methods, a `clamp` method to set the input data and +Next, we set up two methods, a `clamp` method to set the input data and `viz` to visualize all of the different receptive fields of our model: ```python lz0, lW1 = model.get_lava_components('z0', 'W1') + @model.dynamicCommand def clamp(x): model.pause() lz0.j_exc.set(x) + @model.dynamicCommand def viz(): - visualize([lW1.weights.get()], [(8, 8)], "lava_fields") + viz_utils.synapse_plot.visualize([lW1.weights.get()], [(8, 8)], "lava_fields") ``` -Now that everything is set up to build the runtime and start training the model +## Running The Model + +Now that everything is set up to build the runtime and start training the model inside of the Loihi simulator. To set up the runtime we call the following: ```python -model.set_up_runtime("z0", rest_image=np.zeros((1, 64))) + model.set_up_runtime("z0", rest_image=np.zeros((1, 64))) ``` -This will set up a runtime with `z0` as the root node and also uses a resting +This will set up a runtime with `z0` as the root node and also uses a resting image of all zeros to allow the system to return to its resting state. Now the training loop will be as follows: @@ -217,13 +238,75 @@ with model.runtime: model.view(T, view_length) model.rest(rest_length) - print("\nDone Training") + +``` + +## Evaluating the On-Chip Trained Model + +The code above will work to train the model on a Loihi neuromorphic chip, but, +currently, we do not have a way of viewing how effective the model learned +really is. To set up this evaluation, we can call +the `viz` method defined above to view the receptive fields that our spiking +model has acquired: + +```python model.viz() - model.save_to_json(".", model_name="trained") ``` -Running this should produce a set of receptive fields that look like the +Running this should produce a set of receptive fields that look like the following:
+ +While viewing the receptive fields qualitatively tells us that our spiking +model has trained, we may also want to view the +[raster plots](ngclearn.utils.viz.raster) -- visual depictions of the +underlying spike patterns acquired in the hidden layer of our model -- for each +of our three image patterns (as they are fed into our trained model). To do +this, we will make use of the monitor we defined above in the following manner: + +```python + ## Turning off learning + lW1.eta.set(np.array([0])) + + model.view(np.reshape(make_T(8), (1, 64)), view_length) + model.write_to_ngc() + spikes = M.view(z1e.s) + viz_utils.raster.create_raster_plot(spikes, tag="T", plot_fname="raster_T") + model.rest(rest_length) + print("Done T") + + model.view(np.reshape(make_X(8), (1, 64)), view_length) + model.write_to_ngc() + spikes = M.view(z1e.s) + viz_utils.raster.create_raster_plot(spikes, tag="X", plot_fname="raster_X") + model.rest(rest_length) + print("Done X") + + model.view(np.reshape(make_O(8), (1, 64)), view_length) + model.write_to_ngc() + spikes = M.view(z1e.s) + viz_utils.raster.create_raster_plot(spikes, tag="O", plot_fname="raster_O") + model.rest(rest_length) + print("Done O") +``` + +The above should result in raster plots where the spikes correspond to the +receptive fields of each trained letter pattern. Specifically, you should see +that the top left field is `N0` and the bottom right is `N24`. Your raster plots +should look like the ones below: + +
+
+
+ +Finally to save the model to disk, you can call the following: + +```python + model.save_to_json(".", model_name="trained") +``` + +which will save your on-chip trained Loihi model to disk for later use. + + diff --git a/docs/tutorials/lava/monitors.md b/docs/tutorials/lava/monitors.md new file mode 100644 index 00000000..b5d82243 --- /dev/null +++ b/docs/tutorials/lava/monitors.md @@ -0,0 +1,17 @@ +# Monitors + +While lava does have its own version of monitors, ngclearn offers an +in-built version for convenience. It is +recommended that you use the ngclearn monitors as they have expanded +functionality and are designed to interact with the Lava components well. For an +overview of/details on how monitors work please see +[this](../foundations/monitors.md). The +only difference is that Lava has its own monitor found +in `ngclearn.components.lava`. + +## Sharp Edges and Bits + +- Due to the fact that a Lava component of the monitor must be built, it has to + be defined inside the `LavaContext`. +- To view the values found within the monitor via the `view()` and `get_path()` + methods, `model.write_to_ngc()` must be called to refresh the values. \ No newline at end of file diff --git a/ngclearn/components/__init__.py b/ngclearn/components/__init__.py index aabfbdee..d1bcc40e 100644 --- a/ngclearn/components/__init__.py +++ b/ngclearn/components/__init__.py @@ -26,3 +26,5 @@ from .synapses.hebbian.expSTDPSynapse import ExpSTDPSynapse from .synapses.hebbian.eventSTDPSynapse import EventSTDPSynapse from .synapses.hebbian.BCMSynapse import BCMSynapse + +from .monitor import Monitor diff --git a/ngclearn/components/base_monitor.py b/ngclearn/components/base_monitor.py new file mode 100644 index 00000000..79b2c535 --- /dev/null +++ b/ngclearn/components/base_monitor.py @@ -0,0 +1,216 @@ +import json + +from ngclearn import Component, Compartment +from ngclearn import numpy as np +from ngcsimlib.utils import add_component_resolver, add_resolver_meta, get_current_path +from ngcsimlib.logger import warn, critical + + +class Base_Monitor(Component): + """ + An abstract base for monitors for both ngclearn and ngclava. Compartments wired directly into this component will + have their value tracked during `advance_state` loops automatically. + + Note the monitor only works for compiled methods currently + + + Using default window length: + myMonitor << myComponent.myCompartment + + Using custom window length: + myMonitor.watch(myComponent.myCompartment, customWindowLength) + + To get values out of the monitor either path to the stored value directly, or pass in a compartment directly. All + paths are the same as their local path variable. + + Using a compartment: + myMonitor.view(myComponent.myCompartment) + + Using a path: + myMonitor.get_store(myComponent.myCompartment.path).value + + There can only be one monitor in existence at a time due to the way it interacts with resolvers and the compilers + for ngclearn. + + Args: + name: The name of the component. + + default_window_length: The default window length. + """ + + _singleton = None # Only one Monitor + + @staticmethod + def build_advance(compartments): + """ + A method to build the method to advance the stored values. + + Args: + compartments: A list of compartments to store values + + Returns: The method to advance the stored values. + + """ + critical( + "build_advance() is not defined on this monitor, use either the monitor found in ngclearn.components or " + "ngclearn.components.lava (If using lava)") + + + @staticmethod + def build_reset(compartments): + """ + A method to build the method to reset the stored values. + Args: + compartments: A list of compartments to reset + + Returns: The method to reset the stored values. + """ + @staticmethod + def _reset(**kwargs): + return_vals = [] + for _, comp in compartments: + current_store = kwargs[comp + "*store"] + return_vals.append(np.zeros(current_store.shape)) + return return_vals if len(compartments) > 1 else return_vals[0] + + return _reset + + def __init__(self, name, default_window_length=100, **kwargs): + if Base_Monitor._singleton is not None: + critical("Only one monitor can be built") + else: + Base_Monitor._singleton = True + super().__init__(name, **kwargs) + self.store = {} + self.compartments = [] + self.default_window_length = default_window_length + + def __lshift__(self, other): + if isinstance(other, Compartment): + self.watch(other, self.default_window_length) + else: + warn("Only Compartments can be monitored not", type(other)) + + def watch(self, compartment, window_length): + """ + Sets the monitor to watch a specific compartment, for a specified window length. + + Args: + compartment: the compartment object to monitor + + window_length: the window length + """ + cs, end = self._add_path(compartment.path) + + shape = compartment.value.shape + new_comp = Compartment(np.zeros(shape)) + new_comp_store = Compartment(np.zeros((window_length, *shape))) + + comp_key = "*".join(compartment.path.split("/")) + store_comp_key = comp_key + "*store" + + new_comp._setup(self, comp_key) + new_comp_store._setup(self, store_comp_key) + + new_comp << compartment + + cs[end] = new_comp_store + setattr(self, comp_key, new_comp) + setattr(self, store_comp_key, new_comp_store) + self.compartments.append(new_comp.path) + + self._update_resolver() + + def _update_resolver(self): + output_compartments = [] + compartments = [] + for comp in self.compartments: + output_compartments.append(comp.split("/")[-1] + "*store") + compartments.append((0, comp.split("/")[-1])) + + args = [] + parameters = [] + + add_component_resolver(self.__class__.__name__, "advance_state", + (self.build_advance(compartments), output_compartments)) + add_resolver_meta(self.__class__.__name__, "advance_state", + (args, parameters, compartments + [(0, o) for o in output_compartments], False)) + + add_component_resolver(self.__class__.__name__, "reset", (self.build_reset(compartments), output_compartments)) + add_resolver_meta(self.__class__.__name__, "reset", + (args, parameters, [(0, o) for o in output_compartments], False)) + + def _add_path(self, path): + _path = path.split("/")[1:] + end = _path.pop(-1) + + current_store = self.store + for p in _path: + if p not in current_store.keys(): + current_store[p] = {} + current_store = current_store[p] + + return current_store, end + + def view(self, compartment): + """ + Gets the value associated with the specified compartment + + Args: + compartment: The compartment to extract the stored value of + + Returns: The stored value, None if not monitoring that compartment + + """ + _path = compartment.path.split("/")[1:] + store = self.get_store(_path) + return store.value if store is not None else store + + def get_store(self, path): + current_store = self.store + for p in path: + if p not in current_store.keys(): + return None + current_store = current_store[p] + return current_store + + def save(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".json" + _dict = {"sources": {}, "stores": {}} + for key in self.compartments: + n = key.split("/")[-1] + _dict["sources"][key] = self.__dict__[n].value.shape + _dict["stores"][key + "*store"] = self.__dict__[n + "*store"].value.shape + + with open(file_name, "w") as f: + json.dump(_dict, f) + + def load(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".json" + with open(file_name, "r") as f: + vals = json.load(f) + + for comp_path, shape in vals["stores"].items(): + + compartment_path = comp_path.split("/")[-1] + new_path = get_current_path() + "/" + "/".join(compartment_path.split("*")[-3:-1]) + + cs, end = self._add_path(new_path) + + new_comp = Compartment(np.zeros(shape)) + new_comp._setup(self, compartment_path) + + cs[end] = new_comp + setattr(self, compartment_path, new_comp) + + + + for comp_path, shape in vals['sources'].items(): + compartment_path = comp_path.split("/")[-1] + new_comp = Compartment(np.zeros(shape)) + new_comp._setup(self, compartment_path) + + setattr(self, compartment_path, new_comp) + self.compartments.append(new_comp.path) + + self._update_resolver() diff --git a/ngclearn/components/lava/__init__.py b/ngclearn/components/lava/__init__.py index d70f8f67..962f843a 100644 --- a/ngclearn/components/lava/__init__.py +++ b/ngclearn/components/lava/__init__.py @@ -6,3 +6,6 @@ from .synapses.hebbianSynapse import HebbianSynapse ## Lava-compliant encoders/traces from .traces.gatedTrace import GatedTrace + +#monitor +from .monitor import Monitor \ No newline at end of file diff --git a/ngclearn/components/lava/monitor.py b/ngclearn/components/lava/monitor.py new file mode 100644 index 00000000..6fcb8d9a --- /dev/null +++ b/ngclearn/components/lava/monitor.py @@ -0,0 +1,22 @@ +from ngclearn.components.base_monitor import Base_Monitor + + +class Monitor(Base_Monitor): + """ + A numpy implementation of `Base_Monitor`. Designed to be used with all lava compatible ngclearn components + """ + + @staticmethod + def build_advance(compartments): + @staticmethod + def _advance(**kwargs): + return_vals = [] + for _, comp in compartments: + new_val = kwargs[comp] + current_store = kwargs[comp + "*store"] + current_store[:-1] = current_store[1:] + current_store[-1] = new_val + return_vals.append(current_store) + return return_vals if len(compartments) > 1 else return_vals[0] + + return _advance diff --git a/ngclearn/components/lava/neurons/LIFCell.py b/ngclearn/components/lava/neurons/LIFCell.py index a32d0292..0d7175fc 100644 --- a/ngclearn/components/lava/neurons/LIFCell.py +++ b/ngclearn/components/lava/neurons/LIFCell.py @@ -108,7 +108,7 @@ def __init__(self, name, n_units, dt, tau_m, thr_theta_init=None, resist_m=1., self._init(thr_theta0) def _init(self, thr_theta0): - self.thr_theta = Compartment(thr_theta0) + self.thr_theta.set(thr_theta0) @staticmethod def _advance_state(dt, tau_m, R_m, v_rest, v_reset, v_decay, refract_T, thr, tau_theta, diff --git a/ngclearn/components/lava/synapses/hebbianSynapse.py b/ngclearn/components/lava/synapses/hebbianSynapse.py index 946e48ea..65a80ae3 100644 --- a/ngclearn/components/lava/synapses/hebbianSynapse.py +++ b/ngclearn/components/lava/synapses/hebbianSynapse.py @@ -96,11 +96,11 @@ def _init(self, weights): preVals = jnp.zeros((self.batch_size, self.rows)) postVals = jnp.zeros((self.batch_size, self.cols)) ## Compartments - self.inputs = Compartment(preVals) - self.outputs = Compartment(postVals) - self.pre = Compartment(preVals) - self.post = Compartment(postVals) - self.weights = Compartment(weights) + self.inputs.set(preVals) + self.outputs.set(postVals) + self.pre.set(preVals) + self.post.set(postVals) + self.weights.set(weights) @staticmethod def _advance_state(dt, Rscale, w_bounds, w_decay, inputs, weights, diff --git a/ngclearn/components/lava/synapses/staticSynapse.py b/ngclearn/components/lava/synapses/staticSynapse.py index e025bb0a..1df58ab2 100755 --- a/ngclearn/components/lava/synapses/staticSynapse.py +++ b/ngclearn/components/lava/synapses/staticSynapse.py @@ -12,7 +12,7 @@ class StaticSynapse(Component): ## Lava-compliant fixed/non-evolvable synapse | --- Synapse Input Compartments: (Takes wired-in signals) --- | inputs - input (pre-synaptic) stimulus - | --- Synapse Output Compartments: (These signals are generated) --- + | --- Synapse Output Compartments: .set()ese signals are generated) --- | outputs - transformed (post-synaptic) signal | weights - current value matrix of synaptic efficacies (this is post-update if eta > 0) @@ -75,9 +75,9 @@ def _init(self, weights): preVals = jnp.zeros((self.batch_size, self.rows)) postVals = jnp.zeros((self.batch_size, self.cols)) ## Compartments - self.inputs = Compartment(preVals) - self.outputs = Compartment(postVals) - self.weights = Compartment(weights) + self.inputs.set(preVals) + self.outputs.set(postVals) + self.weights.set(weights) @staticmethod def _advance_state(dt, Rscale, inputs, weights): diff --git a/ngclearn/components/lava/synapses/traceSTDPSynapse.py b/ngclearn/components/lava/synapses/traceSTDPSynapse.py index e0d322ea..16103be8 100755 --- a/ngclearn/components/lava/synapses/traceSTDPSynapse.py +++ b/ngclearn/components/lava/synapses/traceSTDPSynapse.py @@ -83,7 +83,7 @@ def __init__(self, name, dt, resist_scale=1., weight_init=None, shape=None, ## Component size setup self.batch_size = 1 - self.eta = Compartment(jnp.ones((1,1)) * eta) + self.eta = Compartment(jnp.ones((1, 1)) * eta) self.inputs = Compartment(None) self.outputs = Compartment(None) @@ -112,13 +112,13 @@ def _init(self, weights): preVals = jnp.zeros((self.batch_size, self.rows)) postVals = jnp.zeros((self.batch_size, self.cols)) ## Compartments - self.inputs = Compartment(preVals) - self.outputs = Compartment(postVals) - self.pre = Compartment(preVals) ## pre-synaptic spike - self.x_pre = Compartment(preVals) ## pre-synaptic trace - self.post = Compartment(postVals) ## post-synaptic spike - self.x_post = Compartment(postVals) ## post-synaptic trace - self.weights = Compartment(weights) + self.inputs.set(preVals) + self.outputs.set(postVals) + self.pre.set(preVals) ## pre-synaptic spike + self.x_pre.set(preVals) ## pre-synaptic trace + self.post.set(postVals) ## post-synaptic spike + self.x_post.set(postVals) ## post-synaptic trace + self.weights.set(weights) @staticmethod def _advance_state(dt, Rscale, Aplus, Aminus, w_bounds, w_decay, x_tar, @@ -155,7 +155,7 @@ def _reset(batch_size, rows, cols, eta0): postVals, # post preVals, # x_pre postVals, # x_post - jnp.ones((1,1)) * eta0 + jnp.ones((1, 1)) * eta0 ) @resolver(_reset) diff --git a/ngclearn/components/monitor.py b/ngclearn/components/monitor.py new file mode 100644 index 00000000..595ab69a --- /dev/null +++ b/ngclearn/components/monitor.py @@ -0,0 +1,19 @@ +from ngclearn.components.base_monitor import Base_Monitor + +class Monitor(Base_Monitor): + """ + A jax implementation of `Base_Monitor`. Designed to be used with all non-lava ngclearn components + """ + @staticmethod + def build_advance(compartments): + @staticmethod + def _advance(**kwargs): + return_vals = [] + for _, comp in compartments: + new_val = kwargs[comp] + current_store = kwargs[comp + "*store"] + current_store = current_store.at[:-1].set(current_store[1:]) + current_store = current_store.at[-1].set(new_val) + return_vals.append(current_store) + return return_vals if len(compartments) > 1 else return_vals[0] + return _advance diff --git a/ngclearn/utils/__init__.py b/ngclearn/utils/__init__.py index c2dd6534..f897f5a7 100755 --- a/ngclearn/utils/__init__.py +++ b/ngclearn/utils/__init__.py @@ -1,5 +1,6 @@ from .model_utils import tensorstats ## forward imports from core ngc-learn utility sub-packages +from . import viz from . import io_utils from . import metric_utils from . import model_utils