9 changes: 9 additions & 0 deletions ngclearn/components/lava/neurons/LIFCell.py
@@ -18,6 +18,15 @@ class LIFCell(Component): ## Lava-compliant leaky integrate-and-fire cell
| where R is the membrane resistance and v_rest is the resting potential
| gamma_d is voltage decay -- 1 recovers LIF dynamics and 0 recovers IF dynamics

| --- Cell Input Compartments: (Takes wired-in signals) ---
| j_exc - excitatory electrical input
| j_inh - inhibitory electrical input
| --- Cell Output Compartments: (These signals are generated) ---
| v - membrane potential/voltage state
| s - emitted binary spikes/action potentials
| rfr - (relative) refractory variable state
| thr_theta - homeostatic/adaptive threshold increment state

Args:
name: the string name of this cell

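For reference, a minimal NumPy sketch of the voltage dynamics described above (not the Lava/ngc-learn implementation; the function name and the tau_m default are assumed):

import numpy as np

def lif_voltage_step(v, j_exc, j_inh, dt, tau_m=20., R=1., v_rest=-65., gamma_d=1.):
    ## gamma_d = 1 recovers leaky (LIF) dynamics; gamma_d = 0 recovers pure integrator (IF) dynamics
    j = j_exc - j_inh                       ## net electrical input from the two input compartments
    dv_dt = gamma_d * (v_rest - v) + R * j  ## leak toward v_rest plus resistance-scaled drive
    return v + (dt / tau_m) * dv_dt

v = np.full((1, 10), -65.)  ## ten cells starting at rest
v = lif_voltage_step(v, j_exc=np.ones((1, 10)), j_inh=np.zeros((1, 10)), dt=1.)

Spike emission (s), the refractory variable (rfr), and the adaptive threshold (thr_theta) are produced by the cell itself and are omitted here.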
10 changes: 10 additions & 0 deletions ngclearn/components/lava/synapses/hebbianSynapse.py
@@ -10,6 +10,16 @@ class HebbianSynapse(Component): ## Lava-compliant Hebbian synapse
adjustment rule. This is a Lava-compliant synaptic cable that adjusts
with a hard-coded form of (stochastic) gradient ascent.

| --- Synapse Input Compartments: (Takes wired-in signals) ---
| inputs - input (pre-synaptic) stimulus
| --- Synaptic Plasticity Input Compartments: (Takes wired-in signals) ---
| pre - pre-synaptic signal to drive first term of Hebbian update
| post - post-synaptic signal to drive 2nd term of Hebbian update
| eta - global learning rate (unidimensional/scalar value)
| --- Synapse Output Compartments: (These signals are generated) ---
| outputs - transformed (post-synaptic) signal
| weights - current value matrix of synaptic efficacies (this is post-update if eta > 0)

Args:
name: the string name of this cell

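For reference, the two-term update driven by the pre, post, and eta compartments can be sketched as a simple outer product (illustrative NumPy only; shapes and the function name are assumed, not the Lava kernel):

import numpy as np

def hebbian_update(W, pre, post, eta):
    ## pre: (1, n_pre) first-term signal, post: (1, n_post) second-term signal
    dW = np.matmul(pre.T, post)  ## two-factor Hebbian adjustment
    return W + eta * dW          ## no change is applied when eta = 0

W = hebbian_update(np.zeros((4, 3)), np.ones((1, 4)), np.ones((1, 3)), eta=0.01)

The weights compartment then exposes the (post-update) value of W.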
6 changes: 6 additions & 0 deletions ngclearn/components/lava/synapses/staticSynapse.py
@@ -10,6 +10,12 @@ class StaticSynapse(Component): ## Lava-compliant fixed/non-evolvable synapse
is in-built to this component. This is a Lava-compliant version of the
static synapse component from the synapses sub-package of components.

| --- Synapse Input Compartments: (Takes wired-in signals) ---
| inputs - input (pre-synaptic) stimulus
| --- Synapse Output Compartments: (These signals are generated) ---
| outputs - transformed (post-synaptic) signal
| weights - current value matrix of synaptic efficacies (held fixed; this synapse does not adapt)

Args:
name: the string name of this cell

14 changes: 13 additions & 1 deletion ngclearn/components/lava/synapses/traceSTDPSynapse.py
@@ -10,6 +10,18 @@ class TraceSTDPSynapse(Component): ## Lava-compliant trace-STDP synapse
spike-timing-dependent plasticity (STDP). This is a Lava-compliant synaptic
cable that adjusts with a hard-coded form of (stochastic) gradient ascent.

| --- Synapse Input Compartments: (Takes wired-in signals) ---
| inputs - input (pre-synaptic) stimulus
| --- Synaptic Plasticity Input Compartments: (Takes wired-in signals) ---
| pre - pre-synaptic spike(s) to drive STDP update
| x_pre - pre-synaptic trace value(s) to drive STDP update
| post - post-synaptic spike(s) to drive STDP update
| x_post - post-synaptic trace value(s) to drive STDP update
| eta - global learning rate (unidimensional/scalar value)
| --- Synapse Output Compartments: (These signals are generated) ---
| outputs - transformed (post-synaptic) signal
| weights - current value matrix of synaptic efficacies (this is post-update if eta > 0)

Args:
name: the string name of this cell

@@ -29,7 +41,7 @@ class TraceSTDPSynapse(Component): ## Lava-compliant trace-STDP synapse

Aminus: strength of long-term depression (LTD)

eta: global learning rate
eta: global learning rate (default: 1)

w_decay: degree to which (L2) synaptic weight decay is applied to the
computed Hebbian adjustment (Default: 0); note that decay is not
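For reference, a minimal sketch of a trace-based STDP step of the kind described by these compartments (illustrative NumPy, not the Lava kernel; w_decay handling is omitted since its description is truncated here):

import numpy as np

def trace_stdp_update(W, pre, x_pre, post, x_post, Aplus, Aminus, eta=1.):
    ## LTP: pre-synaptic traces (x_pre) paired with post-synaptic spikes (post)
    ## LTD: pre-synaptic spikes (pre) paired with post-synaptic traces (x_post)
    dW = Aplus * np.matmul(x_pre.T, post) - Aminus * np.matmul(pre.T, x_post)
    return W + eta * dW

All of pre, x_pre, post, and x_post are assumed to be row vectors (1 x n_pre or 1 x n_post).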
5 changes: 5 additions & 0 deletions ngclearn/components/lava/traces/gatedTrace.py
@@ -8,6 +8,11 @@ class GatedTrace(Component): ## gated/piecewise low-pass filter
"""
A gated/piecewise variable trace (filter).

| --- Cell Input Compartments: (Takes wired-in signals) ---
| inputs - input (takes wired-in external signals)
| --- Cell Output Compartments: (These signals are generated) ---
| trace - traced value signal

Args:
name: the string name of this operator

24 changes: 24 additions & 0 deletions ngclearn/components/other/expKernel.py
@@ -83,6 +83,30 @@ def reset(self, inputs, epsp, tf):
self.epsp.set(epsp)
self.tf.set(tf)

def help(self): ## component help function
properties = {
"cell type": "ExpKernel - maintains an exponential kernel over "
"incoming signal values (such as sequences of discrete pulses)"
}
compartment_props = {
"input_compartments":
{"inputs": "Takes in external input signal values"},
"outputs_compartments":
{"epsp": "Excitatory postsynaptic potential/pulse emitted at time t",
"tr": "Value signal (rolling) time window"},
}
hyperparams = {
"n_units": "Number of neuronal cells to model in this layer",
"dt": "Integration time constant (kernel needs knowledge of `dt`)",
"nu": "Spike time interval for window",
"tau_w": "Spike window time constant"
}
info = {self.name: properties,
"compartments": compartment_props,
"dynamics": "epsp ~ Sum_{tf} exp(-(t - tf)/tau_w)",
"hyperparameters": hyperparams}
return info

def __repr__(self):
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
maxlen = max(len(c) for c in comps) + 5
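The dynamics string above, epsp ~ Sum_{tf} exp(-(t - tf)/tau_w), can be evaluated directly; a minimal NumPy sketch (the nu windowing is omitted):

import numpy as np

def exp_kernel_epsp(t, spike_times, tau_w=500.):
    ## sum an exponentially decaying pulse over all recorded spike/pulse times tf <= t
    tf = np.asarray(spike_times, dtype=float)
    return float(np.sum(np.exp(-(t - tf) / tau_w) * (tf <= t)))

epsp = exp_kernel_epsp(t=25., spike_times=[3., 10., 24.])  ## the most recent pulse dominates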
27 changes: 27 additions & 0 deletions ngclearn/components/other/varTrace.py
@@ -107,6 +107,33 @@ def reset(self, inputs, outputs, trace):
self.outputs.set(outputs)
self.trace.set(trace)

def help(self): ## component help function
properties = {
"cell type": "VarTrace - maintains a low pass filter over incoming signal "
"values (such as sequences of discrete pulses)"
}
compartment_props = {
"input_compartments":
{"inputs": "Takes in external input signal values"},
"outputs_compartments":
{"trace": "Continuous low-pass filtered signal values, at time t",
"outputs": "Continuous low-pass filtered signal values, "
"at time t (same as `trace`)"},
}
hyperparams = {
"n_units": "Number of neuronal cells to model in this layer",
"tau_tr": "Trace/filter time constant",
"a_delta": "Increment to apply to trace (if not set to 0); "
"otherwise, traces clamp to 1 and then decay",
"decay_type": "Indicator of what type of decay dynamics to use "
"as filter is updated at time t"
}
info = {self.name: properties,
"compartments": compartment_props,
"dynamics": "tau_tr * dz/dt ~ -z + inputs",
"hyperparameters": hyperparams}
return info

def __repr__(self):
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
maxlen = max(len(c) for c in comps) + 5
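One reading of the dynamics string tau_tr * dz/dt ~ -z + inputs together with the a_delta description, written as a forward-Euler sketch (an approximation; the decay_type variants are not reproduced):

import numpy as np

def var_trace_step(z, s, dt, tau_tr=30., a_delta=0.5):
    z = z * (1. - dt / tau_tr)     ## exponential (low-pass) decay between pulses
    if a_delta > 0.:
        z = z + s * a_delta        ## additive increment per incoming binary pulse s
    else:
        z = z * (1. - s) + s       ## clamp to 1 on a pulse, then let the trace decay
    return z

z = var_trace_step(np.zeros((1, 5)), s=np.array([[0., 1., 0., 0., 1.]]), dt=1.)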
25 changes: 25 additions & 0 deletions ngclearn/components/synapses/denseSynapse.py
@@ -132,6 +132,31 @@ def load(self, directory, **kwargs):
if "biases" in data.keys():
self.biases.set(data['biases'])

def help(self): ## component help function
properties = {
"cell type": "DenseSynapse - performs a synaptic transformation of inputs to produce "
"output signals (e.g., a scaled linear multivariate transformation)"
}
compartment_props = {
"input_compartments":
{"inputs": "Takes in external input signal values",
"key": "JAX RNG key"},
"outputs_compartments":
{"outputs": "Output of synaptic transformation"},
}
hyperparams = {
"shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
"weight_init": "Initialization conditions for synaptic weight (W) values",
"bias_init": "Initialization conditions for bias/base-rate (b) values",
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)"
}
info = {self.name: properties,
"compartments": compartment_props,
"dynamics": "outputs = [(W * Rscale) * inputs] + b",
"hyperparameters": hyperparams}
return info

def __repr__(self):
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
maxlen = max(len(c) for c in comps) + 5
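The forward transformation in the dynamics string is a single affine map; a minimal NumPy sketch (not the JAX-compiled component):

import numpy as np

def dense_forward(inputs, W, b=0., Rscale=1.):
    ## inputs: (batch, n_in), W: (n_in, n_out), b: (n_out,) or 0 if no bias is used
    return np.matmul(inputs, W * Rscale) + b

outputs = dense_forward(np.ones((1, 4)), W=np.full((4, 3), 0.1), b=np.zeros(3), Rscale=2.)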
35 changes: 35 additions & 0 deletions ngclearn/components/synapses/hebbian/BCMSynapse.py
@@ -153,6 +153,41 @@ def load(self, directory, **kwargs):
self.weights.set(data['weights'])
self.theta.set(data['theta'])

def help(self): ## component help function
properties = {
"cell type": "BCMSTDPSynapse - performs an adaptable synaptic transformation "
"of inputs to produce output signals; synapses are adjusted via "
"BCM theory"
}
compartment_props = {
"input_compartments":
{"inputs": "Takes in external input signal values",
"key": "JAX RNG key",
"pre": "Pre-synaptic statistic for BCM (z_j)",
"post": "Post-synaptic statistic for BCM (z_i)"},
"outputs_compartments":
{"outputs": "Output of synaptic transformation",
"theta": "Synaptic threshold variable",
"dWeights": "Synaptic weight value adjustment matrix produced at time t"},
}
hyperparams = {
"shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
"weight_init": "Initialization conditions for synaptic weight (W) values",
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
"tau_theta": "Time constant for synaptic threshold variable `theta`",
"tau_w": "Time constant for BCM synaptic adjustment",
"w_bound": "Soft synaptic bound applied to synapses post-update",
"w_decay": "Synaptic decay term"
}
info = {self.name: properties,
"compartments": compartment_props,
"dynamics": "outputs = [(W * Rscale) * inputs] ;"
"tau_w dW_{ij}/dt = z_j * (z_i - theta) - W_{ij} * w_decay;"
"tau_theta d(theta_{i})/dt = (-theta_{i} + (z_i)^2)",
"hyperparameters": hyperparams}
return info

def __repr__(self):
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
maxlen = max(len(c) for c in comps) + 5
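A minimal Euler-step sketch of the BCM dynamics written above (shapes assumed to be row vectors; not taken from the component itself):

import numpy as np

def bcm_step(W, theta, pre, post, dt, tau_w=1000., tau_theta=100., w_decay=0.):
    ## pre (z_j): (1, n_pre), post (z_i): (1, n_post), theta: (1, n_post)
    dW = np.matmul(pre.T, post - theta) - W * w_decay  ## LTP/LTD gated by the sliding threshold
    dtheta = -theta + np.square(post)                  ## threshold tracks squared post-synaptic activity
    return W + (dt / tau_w) * dW, theta + (dt / tau_theta) * dtheta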
31 changes: 31 additions & 0 deletions ngclearn/components/synapses/hebbian/eventSTDPSynapse.py
@@ -123,6 +123,37 @@ def reset(self, inputs, outputs, preSpike, postSpike, dWeights):
self.postSpike.set(postSpike)
self.dWeights.set(dWeights)

def help(self): ## component help function
properties = {
"cell type": "EventSTDPSynapse - performs an adaptable synaptic transformation "
"of inputs to produce output signals; synapses are adjusted with "
"event-based post-synaptic spike-timing-dependent plasticity"
}
compartment_props = {
"input_compartments":
{"inputs": "Takes in external input signal values",
"key": "JAX RNG key",
"preSpike": "Pre-synaptic spike compartment value/term for STDP (s_j)",
"postSpike": "Post-synaptic spike compartment value/term for STDP (s_i)"},
"outputs_compartments":
{"outputs": "Output of synaptic transformation",
"dWeights": "Synaptic weight value adjustment matrix produced at time t"},
}
hyperparams = {
"shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
"weight_init": "Initialization conditions for synaptic weight (W) values",
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
"lmbda": "Degree of synaptic disconnect",
"eta": "Global learning rate (multiplier beyond A_plus and A_minus)",
}
info = {self.name: properties,
"compartments": compartment_props,
"dynamics": "outputs = [(W * Rscale) * inputs] ;"
"dW_{ij}/dt = eta * [ (1 - W_{ij}(1 + lmbda)) * s_j - W_{ij} * (1 + lmbda) * s_j ]",
"hyperparameters": hyperparams}
return info

def __repr__(self):
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
maxlen = max(len(c) for c in comps) + 5
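Read literally, the update string above corresponds to something like the following (a NumPy transcription of that string only; how the post-synaptic spike s_i gates the update is not spelled out in the string and is therefore not shown):

import numpy as np

def event_stdp_update(W, s_pre, eta=1., lmbda=0.01):
    ## s_pre (s_j): (1, n_pre) pre-synaptic spikes, broadcast down each column of W
    dW = (1. - W * (1. + lmbda)) * s_pre.T - W * (1. + lmbda) * s_pre.T
    return W + eta * dW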
37 changes: 37 additions & 0 deletions ngclearn/components/synapses/hebbian/expSTDPSynapse.py
@@ -172,6 +172,43 @@ def reset(self, inputs, outputs, preSpike, postSpike, preTrace, postTrace, dWeights):
self.postTrace.set(postTrace)
self.dWeights.set(dWeights)

def help(self): ## component help function
properties = {
"cell type": "ExpSTDPSynapse - performs an adaptable synaptic transformation "
"of inputs to produce output signals; synapses are adjusted with "
"exponential trace-based spike-timing-dependent plasticity"
}
compartment_props = {
"input_compartments":
{"inputs": "Takes in external input signal values",
"key": "JAX RNG key",
"preSpike": "Pre-synaptic spike compartment value/term for STDP (s_j)",
"postSpike": "Post-synaptic spike compartment value/term for STDP (s_i)",
"preTrace": "Pre-synaptic trace value term for STDP (z_j)",
"postTrace": "Post-synaptic trace value term for STDP (z_i)"},
"outputs_compartments":
{"outputs": "Output of synaptic transformation",
"dWeights": "Synaptic weight value adjustment matrix produced at time t"},
}
hyperparams = {
"shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
"weight_init": "Initialization conditions for synaptic weight (W) values",
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
"A_plus": "Strength of long-term potentiation (LTP)",
"A_minus": "Strength of long-term depression (LTD)",
"exp_beta": "Controls effect of exponential Hebbian shift / dependency (B)",
"eta": "Global learning rate (multiplier beyond A_plus and A_minus)",
"preTrace_target": "Pre-synaptic disconnecting/decay factor (x_tar)",
}
info = {self.name: properties,
"compartments": compartment_props,
"dynamics": "outputs = [(W * Rscale) * inputs] ;"
"dW_{ij}/dt = A_plus * [z_j * exp(-B w) - x_tar * exp(-B(w_max - w))] * s_i -"
"A_minus * s_j * [z_i * exp(-B w)]",
"hyperparameters": hyperparams}
return info

def __repr__(self):
comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))]
maxlen = max(len(c) for c in comps) + 5
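A minimal NumPy sketch of the exponential weight dependence in the dynamics string above (w_max and the exact bounding behavior are assumptions beyond what the string states):

import numpy as np

def exp_stdp_update(W, s_pre, z_pre, s_post, z_post, A_plus, A_minus,
                    exp_beta=1., x_tar=0., w_max=1., eta=1.):
    ## LTP: pre-synaptic traces minus the x_tar factor, both weight-dependent, gated by post spikes s_i
    ltp = (z_pre.T * np.exp(-exp_beta * W)
           - x_tar * np.exp(-exp_beta * (w_max - W))) * s_post
    ## LTD: post-synaptic traces with the same exp(-B * W) dependence, gated by pre spikes s_j
    ltd = (z_post * np.exp(-exp_beta * W)) * s_pre.T
    return W + eta * (A_plus * ltp - A_minus * ltd)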
37 changes: 37 additions & 0 deletions ngclearn/components/synapses/hebbian/hebbianSynapse.py
@@ -207,6 +207,43 @@ def _reset(batch_size, shape, weight_init, bias_init):
jnp.zeros(shape[1]), # db
)

def help(self): ## component help function
properties = {
"cell type": "HebbianSynapse - performs an adaptable synaptic transformation "
"of inputs to produce output signals; synapses are adjusted via "
"two-term/factor Hebbian adjustment"
}
compartment_props = {
"input_compartments":
{"inputs": "Takes in external input signal values",
"key": "JAX RNG key",
"pre": "Pre-synaptic statistic for BCM (z_j)",
"post": "Post-synaptic statistic for BCM (z_i)"},
"outputs_compartments":
{"outputs": "Output of synaptic transformation",
"theta": "Synaptic threshold variable",
"dWeights": "Synaptic weight value adjustment matrix produced at time t"},
}
hyperparams = {
"shape": "Shape of synaptic weight value matrix; number inputs x number outputs",
"weight_init": "Initialization conditions for synaptic weight (W) values",
"bias_init": "Initialization conditions for bias/base-rate (b) values",
"resist_scale": "Resistance level scaling factor (applied to output of transformation)",
"p_conn": "Probability of a connection existing (otherwise, it is masked to zero)",
"is_nonnegative": "Should synapses be constrained to be non-negative post-updates?",
"sign_value": "Scalar `flipping` constant -- changes direction to Hebbian descent if < 0",
"pre_wght" : "Pre-synaptic weighting coefficient (q_pre)",
"post_wght" : "Post-synaptic weighting coefficient (q_post)",
"w_bound": "Soft synaptic bound applied to synapses post-update",
"w_decay": "Synaptic decay term"
}
info = {self.name: properties,
"compartments": compartment_props,
"dynamics": "outputs = [(W * Rscale) * inputs] + b ;"
"dW_{ij}/dt = eta * [(z_j * q_pre) * (z_i * q_post)] - W_{ij} * w_decay",
"hyperparameters": hyperparams}
return info

@resolver(_reset)
def reset(self, inputs, outputs, pre, post, dW, db):
self.inputs.set(inputs)
Expand Down
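For the hebbianSynapse.py dynamics string above, a compact NumPy sketch of the adjustment (sign_value, is_nonnegative, and the bias update are omitted):

import numpy as np

def hebbian_dW(W, pre, post, eta=0.01, pre_wght=1., post_wght=1., w_decay=0.):
    ## dW_{ij} = eta * (z_j * q_pre) * (z_i * q_post) - W_{ij} * w_decay
    return eta * np.matmul((pre * pre_wght).T, post * post_wght) - W * w_decay

dW = hebbian_dW(np.zeros((4, 3)), np.ones((1, 4)), np.ones((1, 3)))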