diff --git a/bindsnet/conversion/__init__.py b/bindsnet/conversion/__init__.py index cb450aae..d8468ae8 100644 --- a/bindsnet/conversion/__init__.py +++ b/bindsnet/conversion/__init__.py @@ -124,7 +124,7 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Sequence[int]] = Non self.refrac = refrac self.v = self.reset * torch.ones(self.shape) # Neuron voltages. - self.refrac_count = torch.zeros(self.shape) # Refractory period counters. + self.refrac_count = torch.zeros(self.shape) # Refractory period counters. def forward(self, x: torch.Tensor) -> None: # language=rst @@ -157,7 +157,14 @@ def reset_(self) -> None: super().reset_() self.v = self.reset * torch.ones(self.shape) # Neuron voltages. - self.refrac_count = torch.zeros(self.shape) # Refractory period counters. + self.refrac_count = torch.zeros(self.shape) # Refractory period counters. + + def _compute_decays(self) -> None: + # language=rst + """ + Sets the relevant decays. + """ + super()._compute_decays() class PassThroughNodes(nodes.Nodes): @@ -209,7 +216,6 @@ class PermuteConnection(topology.AbstractConnection): def __init__(self, source: nodes.Nodes, target: nodes.Nodes, dims: Sequence, nu: Optional[Union[float, Sequence[float]]] = None, weight_decay: float = 0.0, **kwargs) -> None: - # language=rst """ Constructor for ``PermuteConnection``. @@ -411,7 +417,6 @@ def _ann_to_snn_helper(prev, current, node_type, **kwargs): height = (input_width - current.kernel_size[1] + 2 * current.padding[1]) / current.stride[1] + 1 shape = (1, out_channels, int(width), int(height)) - layer = node_type(shape=shape, reset=0, thresh=1, refrac=0, **kwargs) connection = topology.Conv2dConnection( source=prev, target=layer, kernel_size=current.kernel_size, stride=current.stride, @@ -468,7 +473,8 @@ def _ann_to_snn_helper(prev, current, node_type, **kwargs): def ann_to_snn(ann: Union[nn.Module, str], input_shape: Sequence[int], data: Optional[torch.Tensor] = None, - percentile: float = 99.9, node_type: Optional[nodes.Nodes] = SubtractiveResetIFNodes, **kwargs) -> Network: + percentile: float = 99.9, node_type: Optional[nodes.Nodes] = SubtractiveResetIFNodes, + **kwargs) -> Network: # language=rst """ Converts an artificial neural network (ANN) written as a ``torch.nn.Module`` into a near-equivalent spiking neural diff --git a/bindsnet/network/__init__.py b/bindsnet/network/__init__.py index 894b8fa4..9e23dc54 100644 --- a/bindsnet/network/__init__.py +++ b/bindsnet/network/__init__.py @@ -107,6 +107,7 @@ def add_layer(self, layer: Nodes, name: str) -> None: self.layers[name] = layer layer.network = self layer.dt = self.dt + layer._compute_decays() def add_connection(self, connection: AbstractConnection, source: str, target: str) -> None: # language=rst diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py index 70679a48..04fee6a1 100644 --- a/bindsnet/network/nodes.py +++ b/bindsnet/network/nodes.py @@ -47,6 +47,7 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non if self.traces: self.x = torch.zeros(self.shape) # Firing traces. self.tc_trace = torch.tensor(tc_trace) # Time constant of spike trace decay. + self.trace_decay = None # Set in _compute_decays. if self.sum_input: self.summed = torch.zeros(self.shape) # Summed inputs. @@ -64,7 +65,7 @@ def forward(self, x: torch.Tensor) -> None: """ if self.traces: # Decay and set spike traces. - self.x *= torch.exp(-self.dt / self.tc_trace) + self.x *= self.trace_decay self.x.masked_fill_(self.s, 1) if self.sum_input: @@ -88,6 +89,15 @@ def reset_(self) -> None: if self.sum_input: self.summed = torch.zeros(self.shape) # Summed inputs. + @abstractmethod + def _compute_decays(self) -> None: + # language=rst + """ + Abstract base class method for setting decays. + """ + if self.traces: + self.trace_decay = torch.exp(-self.dt / self.tc_trace) # Spike trace decay (per timestep). + class AbstractInput(ABC): # language=rst @@ -135,6 +145,13 @@ def reset_(self) -> None: """ super().reset_() + def _compute_decays(self) -> None: + # language=rst + """ + Sets the relevant decays. + """ + super()._compute_decays() + class RealInput(Nodes, AbstractInput): """ @@ -169,7 +186,7 @@ def forward(self, x: torch.Tensor) -> None: if self.traces: # Decay and set spike traces. - self.x *= torch.exp(-self.dt / self.tc_trace) + self.x *= self.trace_decay self.x.masked_fill_(self.s != 0, 1) if self.sum_input: @@ -183,6 +200,13 @@ def reset_(self) -> None: """ super().reset_() + def _compute_decays(self) -> None: + # language=rst + """ + Sets the relevant decays. + """ + super()._compute_decays() + class McCullochPitts(Nodes): # language=rst @@ -229,6 +253,13 @@ def reset_(self) -> None: """ super().reset_() + def _compute_decays(self) -> None: + # language=rst + """ + Sets the relevant decays. + """ + super()._compute_decays() + class IFNodes(Nodes): # language=rst @@ -299,6 +330,13 @@ def reset_(self) -> None: self.v = self.reset * torch.ones(self.shape) # Neuron voltages. self.refrac_count = torch.zeros(self.shape) # Refractory period counters. + def _compute_decays(self) -> None: + # language=rst + """ + Sets the relevant decays. + """ + super()._compute_decays() + class LIFNodes(Nodes): # language=rst @@ -335,6 +373,7 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non self.thresh = torch.tensor(thresh) # Spike threshold voltage. self.refrac = torch.tensor(refrac) # Post-spike refractory period. self.tc_decay = torch.tensor(tc_decay) # Time constant of neuron voltage decay. + self.decay = None # Set in _compute_decays. self.lbound = lbound # Lower bound of voltage. self.v = self.rest * torch.ones(self.shape) # Neuron voltages. @@ -348,7 +387,7 @@ def forward(self, x: torch.Tensor) -> None: :param x: Inputs to the layer. """ # Decay voltages. - self.v = self.rest + torch.exp(-self.dt / self.tc_decay) * (self.v - self.rest) + self.v = self.rest + self.decay * (self.v - self.rest) # Integrate inputs. self.v += (self.refrac_count == 0).float() * x @@ -378,6 +417,14 @@ def reset_(self) -> None: self.v = self.rest * torch.ones(self.shape) # Neuron voltages. self.refrac_count = torch.zeros(self.shape) # Refractory period counters. + def _compute_decays(self) -> None: + # language=rst + """ + Sets the relevant decays. + """ + super()._compute_decays() + self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep). + class CurrentLIFNodes(Nodes): # language=rst @@ -416,7 +463,9 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non self.thresh = torch.tensor(thresh) # Spike threshold voltage. self.refrac = torch.tensor(refrac) # Post-spike refractory period. self.tc_decay = torch.tensor(tc_decay) # Time constant of neuron voltage decay. + self.decay = None # Set in _compute_decays. self.tc_i_decay = torch.tensor(tc_i_decay) # Time constant of synaptic input current decay. + self.i_decay = None # Set in _compute_decays. self.lbound = lbound # Lower bound of voltage. self.v = self.rest * torch.ones(self.shape) # Neuron voltages. @@ -431,8 +480,8 @@ def forward(self, x: torch.Tensor) -> None: :param x: Inputs to the layer. """ # Decay voltages and current. - self.v = self.rest + torch.exp(-self.dt / self.tc_decay) * (self.v - self.rest) - self.i *= torch.exp(-self.dt / self.tc_i_decay) + self.v = self.rest + self.decay * (self.v - self.rest) + self.i *= self.i_decay # Decrement refractory counters. self.refrac_count = (self.refrac_count > 0).float() * (self.refrac_count - self.dt) @@ -464,6 +513,15 @@ def reset_(self) -> None: self.i = torch.zeros(self.shape) # Synaptic input currents. self.refrac_count = torch.zeros(self.shape) # Refractory period counters. + def _compute_decays(self) -> None: + # language=rst + """ + Sets the relevant decays. + """ + super()._compute_decays() + self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep). + self.i_decay = torch.exp(-self.dt / self.tc_i_decay) # Synaptic input current decay (per timestep). + class AdaptiveLIFNodes(Nodes): # language=rst @@ -503,8 +561,10 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non self.thresh = torch.tensor(thresh) # Spike threshold voltage. self.refrac = torch.tensor(refrac) # Post-spike refractory period. self.tc_decay = torch.tensor(tc_decay) # Time constant of neuron voltage decay. + self.decay = None # Set in _compute_decays. self.theta_plus = torch.tensor(theta_plus) # Constant threshold increase on spike. self.tc_theta_decay = torch.tensor(tc_theta_decay) # Time constant of adaptive threshold decay. + self.theta_decay = None # Set in _compute_decays. self.lbound = lbound # Lower bound of voltage. self.v = self.rest * torch.ones(self.shape) # Neuron voltages. @@ -519,8 +579,8 @@ def forward(self, x: torch.Tensor) -> None: :param x: Inputs to the layer. """ # Decay voltages and adaptive thresholds. - self.v = self.rest + torch.exp(-self.dt / self.tc_decay) * (self.v - self.rest) - self.theta *= torch.exp(-self.dt / self.tc_theta_decay) + self.v = self.rest + self.decay * (self.v - self.rest) + self.theta *= self.theta_decay # Integrate inputs. self.v += (self.refrac_count == 0).float() * x @@ -551,6 +611,15 @@ def reset_(self) -> None: self.v = self.rest * torch.ones(self.shape) # Neuron voltages. self.refrac_count = torch.zeros(self.shape) # Refractory period counters. + def _compute_decays(self) -> None: + # language=rst + """ + Sets the relevant decays. + """ + super()._compute_decays() + self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep). + self.theta_decay = torch.exp(-self.dt / self.tc_theta_decay) # Adaptive threshold decay (per timestep). + class DiehlAndCookNodes(Nodes): # language=rst @@ -592,8 +661,10 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non self.thresh = torch.tensor(thresh) # Spike threshold voltage. self.refrac = torch.tensor(refrac) # Post-spike refractory period. self.tc_decay = torch.tensor(tc_decay) # Time constant of neuron voltage decay. + self.decay = None # Set in _compute_decays. self.theta_plus = torch.tensor(theta_plus) # Constant threshold increase on spike. - self.theta_decay = torch.tensor(tc_theta_decay) # Time constant of adaptive threshold decay. + self.tc_theta_decay = torch.tensor(tc_theta_decay) # Time constant of adaptive threshold decay. + self.theta_decay = None # Set in _compute_decays. self.lbound = lbound # Lower bound of voltage. self.one_spike = one_spike # One spike per timestep. @@ -609,8 +680,8 @@ def forward(self, x: torch.Tensor) -> None: :param x: Inputs to the layer. """ # Decay voltages and adaptive thresholds. - self.v = self.rest + torch.exp(-self.dt / self.tc_decay) * (self.v - self.rest) - self.theta *= torch.exp(-self.dt / self.theta_decay) + self.v = self.rest + self.decay * (self.v - self.rest) + self.theta *= self.theta_decay # Integrate inputs. self.v += (self.refrac_count == 0).float() * x @@ -648,6 +719,15 @@ def reset_(self) -> None: self.v = self.rest * torch.ones(self.shape) # Neuron voltages. self.refrac_count = torch.zeros(self.shape) # Refractory period counters. + def _compute_decays(self) -> None: + # language=rst + """ + Sets the relevant decays. + """ + super()._compute_decays() + self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep). + self.theta_decay = torch.exp(-self.dt / self.tc_theta_decay) # Adaptive threshold decay (per timestep). + class IzhikevichNodes(Nodes): # language=rst