Merge pull request #231 from Huizerd/master
Compute exponential decays only once
djsaunde committed Apr 22, 2019
2 parents 020ba61 + ac68277 commit bd5f836
Showing 3 changed files with 102 additions and 15 deletions.
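
In short: instead of evaluating torch.exp(-self.dt / tc) on every forward() call, each Nodes subclass now caches its per-timestep decay factors once in a _compute_decays() hook, which the network calls right after assigning the layer's dt. A minimal sketch of the pattern, using a hypothetical SimpleLIF class that is not part of bindsnet:

    import torch

    class SimpleLIF:
        # Illustrative only: the decay-caching pattern introduced by this PR.
        def __init__(self, shape, rest=-65.0, tc_decay=100.0):
            self.rest = torch.tensor(rest)
            self.tc_decay = torch.tensor(tc_decay)  # Time constant of voltage decay.
            self.decay = None                       # Cached per-timestep decay; set in _compute_decays.
            self.dt = None                          # Assigned by the network.
            self.v = self.rest * torch.ones(shape)  # Neuron voltages.

        def _compute_decays(self) -> None:
            # Evaluate the exponential once, after dt is known.
            self.decay = torch.exp(-self.dt / self.tc_decay)

        def forward(self, x: torch.Tensor) -> None:
            # Reuse the cached factor instead of calling torch.exp every timestep.
            self.v = self.rest + self.decay * (self.v - self.rest)
            self.v += x
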
16 changes: 11 additions & 5 deletions bindsnet/conversion/__init__.py
@@ -124,7 +124,7 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Sequence[int]] = Non
self.refrac = refrac

self.v = self.reset * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def forward(self, x: torch.Tensor) -> None:
# language=rst
@@ -157,7 +157,14 @@ def reset_(self) -> None:
super().reset_()

self.v = self.reset * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()


class PassThroughNodes(nodes.Nodes):
@@ -209,7 +216,6 @@ class PermuteConnection(topology.AbstractConnection):

def __init__(self, source: nodes.Nodes, target: nodes.Nodes, dims: Sequence,
nu: Optional[Union[float, Sequence[float]]] = None, weight_decay: float = 0.0, **kwargs) -> None:

# language=rst
"""
Constructor for ``PermuteConnection``.
@@ -411,7 +417,6 @@ def _ann_to_snn_helper(prev, current, node_type, **kwargs):
height = (input_width - current.kernel_size[1] + 2 * current.padding[1]) / current.stride[1] + 1
shape = (1, out_channels, int(width), int(height))


layer = node_type(shape=shape, reset=0, thresh=1, refrac=0, **kwargs)
connection = topology.Conv2dConnection(
source=prev, target=layer, kernel_size=current.kernel_size, stride=current.stride,
@@ -468,7 +473,8 @@ def _ann_to_snn_helper(prev, current, node_type, **kwargs):


def ann_to_snn(ann: Union[nn.Module, str], input_shape: Sequence[int], data: Optional[torch.Tensor] = None,
percentile: float = 99.9, node_type: Optional[nodes.Nodes] = SubtractiveResetIFNodes, **kwargs) -> Network:
percentile: float = 99.9, node_type: Optional[nodes.Nodes] = SubtractiveResetIFNodes,
**kwargs) -> Network:
# language=rst
"""
Converts an artificial neural network (ANN) written as a ``torch.nn.Module`` into a near-equivalent spiking neural
1 change: 1 addition & 0 deletions bindsnet/network/__init__.py
@@ -107,6 +107,7 @@ def add_layer(self, layer: Nodes, name: str) -> None:
self.layers[name] = layer
layer.network = self
layer.dt = self.dt
layer._compute_decays()

def add_connection(self, connection: AbstractConnection, source: str, target: str) -> None:
# language=rst
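
With the add_layer hook above, decays are computed exactly once per layer, at the moment the layer joins a network and receives its dt. A hedged usage sketch, assuming the constructor signatures Network(dt=...) and LIFNodes(n=...), which are not shown in this diff:

    from bindsnet.network import Network
    from bindsnet.network.nodes import LIFNodes

    network = Network(dt=1.0)
    layer = LIFNodes(n=100)             # layer.decay is still None here.
    network.add_layer(layer, name="X")  # add_layer sets layer.dt, then calls layer._compute_decays().
    # layer.decay now holds torch.exp(-layer.dt / layer.tc_decay), reused on every forward() call.
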
100 changes: 90 additions & 10 deletions bindsnet/network/nodes.py
@@ -47,6 +47,7 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non
if self.traces:
self.x = torch.zeros(self.shape) # Firing traces.
self.tc_trace = torch.tensor(tc_trace) # Time constant of spike trace decay.
self.trace_decay = None # Set in _compute_decays.

if self.sum_input:
self.summed = torch.zeros(self.shape) # Summed inputs.
@@ -64,7 +65,7 @@ def forward(self, x: torch.Tensor) -> None:
"""
if self.traces:
# Decay and set spike traces.
self.x *= torch.exp(-self.dt / self.tc_trace)
self.x *= self.trace_decay
self.x.masked_fill_(self.s, 1)

if self.sum_input:
@@ -88,6 +89,15 @@ def reset_(self) -> None:
if self.sum_input:
self.summed = torch.zeros(self.shape) # Summed inputs.

@abstractmethod
def _compute_decays(self) -> None:
# language=rst
"""
Abstract base class method for setting decays.
"""
if self.traces:
self.trace_decay = torch.exp(-self.dt / self.tc_trace) # Spike trace decay (per timestep).


class AbstractInput(ABC):
# language=rst
@@ -135,6 +145,13 @@ def reset_(self) -> None:
"""
super().reset_()

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()


class RealInput(Nodes, AbstractInput):
"""
@@ -169,7 +186,7 @@ def forward(self, x: torch.Tensor) -> None:

if self.traces:
# Decay and set spike traces.
self.x *= torch.exp(-self.dt / self.tc_trace)
self.x *= self.trace_decay
self.x.masked_fill_(self.s != 0, 1)

if self.sum_input:
@@ -183,6 +200,13 @@ def reset_(self) -> None:
"""
super().reset_()

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()


class McCullochPitts(Nodes):
# language=rst
@@ -229,6 +253,13 @@ def reset_(self) -> None:
"""
super().reset_()

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()


class IFNodes(Nodes):
# language=rst
@@ -299,6 +330,13 @@ def reset_(self) -> None:
self.v = self.reset * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()


class LIFNodes(Nodes):
# language=rst
@@ -335,6 +373,7 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non
self.thresh = torch.tensor(thresh) # Spike threshold voltage.
self.refrac = torch.tensor(refrac) # Post-spike refractory period.
self.tc_decay = torch.tensor(tc_decay) # Time constant of neuron voltage decay.
self.decay = None # Set in _compute_decays.
self.lbound = lbound # Lower bound of voltage.

self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
@@ -348,7 +387,7 @@ def forward(self, x: torch.Tensor) -> None:
:param x: Inputs to the layer.
"""
# Decay voltages.
self.v = self.rest + torch.exp(-self.dt / self.tc_decay) * (self.v - self.rest)
self.v = self.rest + self.decay * (self.v - self.rest)

# Integrate inputs.
self.v += (self.refrac_count == 0).float() * x
@@ -378,6 +417,14 @@ def reset_(self) -> None:
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()
self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep).


class CurrentLIFNodes(Nodes):
# language=rst
@@ -416,7 +463,9 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non
self.thresh = torch.tensor(thresh) # Spike threshold voltage.
self.refrac = torch.tensor(refrac) # Post-spike refractory period.
self.tc_decay = torch.tensor(tc_decay) # Time constant of neuron voltage decay.
self.decay = None # Set in _compute_decays.
self.tc_i_decay = torch.tensor(tc_i_decay) # Time constant of synaptic input current decay.
self.i_decay = None # Set in _compute_decays.
self.lbound = lbound # Lower bound of voltage.

self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
@@ -431,8 +480,8 @@ def forward(self, x: torch.Tensor) -> None:
:param x: Inputs to the layer.
"""
# Decay voltages and current.
self.v = self.rest + torch.exp(-self.dt / self.tc_decay) * (self.v - self.rest)
self.i *= torch.exp(-self.dt / self.tc_i_decay)
self.v = self.rest + self.decay * (self.v - self.rest)
self.i *= self.i_decay

# Decrement refractory counters.
self.refrac_count = (self.refrac_count > 0).float() * (self.refrac_count - self.dt)
@@ -464,6 +513,15 @@ def reset_(self) -> None:
self.i = torch.zeros(self.shape) # Synaptic input currents.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()
self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep).
self.i_decay = torch.exp(-self.dt / self.tc_i_decay) # Synaptic input current decay (per timestep).


class AdaptiveLIFNodes(Nodes):
# language=rst
@@ -503,8 +561,10 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non
self.thresh = torch.tensor(thresh) # Spike threshold voltage.
self.refrac = torch.tensor(refrac) # Post-spike refractory period.
self.tc_decay = torch.tensor(tc_decay) # Time constant of neuron voltage decay.
self.decay = None # Set in _compute_decays.
self.theta_plus = torch.tensor(theta_plus) # Constant threshold increase on spike.
self.tc_theta_decay = torch.tensor(tc_theta_decay) # Time constant of adaptive threshold decay.
self.theta_decay = None # Set in _compute_decays.
self.lbound = lbound # Lower bound of voltage.

self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
@@ -519,8 +579,8 @@ def forward(self, x: torch.Tensor) -> None:
:param x: Inputs to the layer.
"""
# Decay voltages and adaptive thresholds.
self.v = self.rest + torch.exp(-self.dt / self.tc_decay) * (self.v - self.rest)
self.theta *= torch.exp(-self.dt / self.tc_theta_decay)
self.v = self.rest + self.decay * (self.v - self.rest)
self.theta *= self.theta_decay

# Integrate inputs.
self.v += (self.refrac_count == 0).float() * x
@@ -551,6 +611,15 @@ def reset_(self) -> None:
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()
self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep).
self.theta_decay = torch.exp(-self.dt / self.tc_theta_decay) # Adaptive threshold decay (per timestep).


class DiehlAndCookNodes(Nodes):
# language=rst
@@ -592,8 +661,10 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non
self.thresh = torch.tensor(thresh) # Spike threshold voltage.
self.refrac = torch.tensor(refrac) # Post-spike refractory period.
self.tc_decay = torch.tensor(tc_decay) # Time constant of neuron voltage decay.
self.decay = None # Set in _compute_decays.
self.theta_plus = torch.tensor(theta_plus) # Constant threshold increase on spike.
self.theta_decay = torch.tensor(tc_theta_decay) # Time constant of adaptive threshold decay.
self.tc_theta_decay = torch.tensor(tc_theta_decay) # Time constant of adaptive threshold decay.
self.theta_decay = None # Set in _compute_decays.
self.lbound = lbound # Lower bound of voltage.
self.one_spike = one_spike # One spike per timestep.

@@ -609,8 +680,8 @@ def forward(self, x: torch.Tensor) -> None:
:param x: Inputs to the layer.
"""
# Decay voltages and adaptive thresholds.
self.v = self.rest + torch.exp(-self.dt / self.tc_decay) * (self.v - self.rest)
self.theta *= torch.exp(-self.dt / self.theta_decay)
self.v = self.rest + self.decay * (self.v - self.rest)
self.theta *= self.theta_decay

# Integrate inputs.
self.v += (self.refrac_count == 0).float() * x
@@ -648,6 +719,15 @@ def reset_(self) -> None:
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()
self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep).
self.theta_decay = torch.exp(-self.dt / self.tc_theta_decay) # Adaptive threshold decay (per timestep).


class IzhikevichNodes(Nodes):
# language=rst
