Skip to content

Commit

Permalink
Implemented method for setting decays
Browse files Browse the repository at this point in the history
  • Loading branch information
Huizerd committed Apr 18, 2019
1 parent d712b1a commit ac68277
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 13 deletions.
16 changes: 11 additions & 5 deletions bindsnet/conversion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Sequence[int]] = Non
self.refrac = refrac

self.v = self.reset * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def forward(self, x: torch.Tensor) -> None:
# language=rst
Expand Down Expand Up @@ -157,7 +157,14 @@ def reset_(self) -> None:
super().reset_()

self.v = self.reset * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()


class PassThroughNodes(nodes.Nodes):
Expand Down Expand Up @@ -209,7 +216,6 @@ class PermuteConnection(topology.AbstractConnection):

def __init__(self, source: nodes.Nodes, target: nodes.Nodes, dims: Sequence,
nu: Optional[Union[float, Sequence[float]]] = None, weight_decay: float = 0.0, **kwargs) -> None:

# language=rst
"""
Constructor for ``PermuteConnection``.
Expand Down Expand Up @@ -411,7 +417,6 @@ def _ann_to_snn_helper(prev, current, node_type, **kwargs):
height = (input_width - current.kernel_size[1] + 2 * current.padding[1]) / current.stride[1] + 1
shape = (1, out_channels, int(width), int(height))


layer = node_type(shape=shape, reset=0, thresh=1, refrac=0, **kwargs)
connection = topology.Conv2dConnection(
source=prev, target=layer, kernel_size=current.kernel_size, stride=current.stride,
Expand Down Expand Up @@ -468,7 +473,8 @@ def _ann_to_snn_helper(prev, current, node_type, **kwargs):


def ann_to_snn(ann: Union[nn.Module, str], input_shape: Sequence[int], data: Optional[torch.Tensor] = None,
percentile: float = 99.9, node_type: Optional[nodes.Nodes] = SubtractiveResetIFNodes, **kwargs) -> Network:
percentile: float = 99.9, node_type: Optional[nodes.Nodes] = SubtractiveResetIFNodes,
**kwargs) -> Network:
# language=rst
"""
Converts an artificial neural network (ANN) written as a ``torch.nn.Module`` into a near-equivalent spiking neural
Expand Down
1 change: 1 addition & 0 deletions bindsnet/network/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def add_layer(self, layer: Nodes, name: str) -> None:
self.layers[name] = layer
layer.network = self
layer.dt = self.dt
layer._compute_decays()

def add_connection(self, connection: AbstractConnection, source: str, target: str) -> None:
# language=rst
Expand Down
88 changes: 80 additions & 8 deletions bindsnet/network/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non
if self.traces:
self.x = torch.zeros(self.shape) # Firing traces.
self.tc_trace = torch.tensor(tc_trace) # Time constant of spike trace decay.
self.trace_decay = torch.exp(-self.dt / self.tc_trace) # Spike trace decay (per timestep).
self.trace_decay = None # Set in _compute_decays.

if self.sum_input:
self.summed = torch.zeros(self.shape) # Summed inputs.
Expand Down Expand Up @@ -89,6 +89,15 @@ def reset_(self) -> None:
if self.sum_input:
self.summed = torch.zeros(self.shape) # Summed inputs.

@abstractmethod
def _compute_decays(self) -> None:
# language=rst
"""
Abstract base class method for setting decays.
"""
if self.traces:
self.trace_decay = torch.exp(-self.dt / self.tc_trace) # Spike trace decay (per timestep).


class AbstractInput(ABC):
# language=rst
Expand Down Expand Up @@ -136,6 +145,13 @@ def reset_(self) -> None:
"""
super().reset_()

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()


class RealInput(Nodes, AbstractInput):
"""
Expand Down Expand Up @@ -184,6 +200,13 @@ def reset_(self) -> None:
"""
super().reset_()

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()


class McCullochPitts(Nodes):
# language=rst
Expand Down Expand Up @@ -230,6 +253,13 @@ def reset_(self) -> None:
"""
super().reset_()

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()


class IFNodes(Nodes):
# language=rst
Expand Down Expand Up @@ -300,6 +330,13 @@ def reset_(self) -> None:
self.v = self.reset * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()


class LIFNodes(Nodes):
# language=rst
Expand Down Expand Up @@ -336,7 +373,7 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non
self.thresh = torch.tensor(thresh) # Spike threshold voltage.
self.refrac = torch.tensor(refrac) # Post-spike refractory period.
self.tc_decay = torch.tensor(tc_decay) # Time constant of neuron voltage decay.
self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep).
self.decay = None # Set in _compute_decays.
self.lbound = lbound # Lower bound of voltage.

self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
Expand Down Expand Up @@ -380,6 +417,14 @@ def reset_(self) -> None:
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()
self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep).


class CurrentLIFNodes(Nodes):
# language=rst
Expand Down Expand Up @@ -418,9 +463,9 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non
self.thresh = torch.tensor(thresh) # Spike threshold voltage.
self.refrac = torch.tensor(refrac) # Post-spike refractory period.
self.tc_decay = torch.tensor(tc_decay) # Time constant of neuron voltage decay.
self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep).
self.decay = None # Set in _compute_decays.
self.tc_i_decay = torch.tensor(tc_i_decay) # Time constant of synaptic input current decay.
self.i_decay = torch.exp(-self.dt / self.tc_i_decay) # Synaptic input current decay (per timestep).
self.i_decay = None # Set in _compute_decays.
self.lbound = lbound # Lower bound of voltage.

self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
Expand Down Expand Up @@ -468,6 +513,15 @@ def reset_(self) -> None:
self.i = torch.zeros(self.shape) # Synaptic input currents.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()
self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep).
self.i_decay = torch.exp(-self.dt / self.tc_i_decay) # Synaptic input current decay (per timestep).


class AdaptiveLIFNodes(Nodes):
# language=rst
Expand Down Expand Up @@ -507,10 +561,10 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non
self.thresh = torch.tensor(thresh) # Spike threshold voltage.
self.refrac = torch.tensor(refrac) # Post-spike refractory period.
self.tc_decay = torch.tensor(tc_decay) # Time constant of neuron voltage decay.
self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep).
self.decay = None # Set in _compute_decays.
self.theta_plus = torch.tensor(theta_plus) # Constant threshold increase on spike.
self.tc_theta_decay = torch.tensor(tc_theta_decay) # Time constant of adaptive threshold decay.
self.theta_decay = torch.exp(-self.dt / self.tc_theta_decay) # Adaptive threshold decay (per timestep).
self.theta_decay = None # Set in _compute_decays.
self.lbound = lbound # Lower bound of voltage.

self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
Expand Down Expand Up @@ -557,6 +611,15 @@ def reset_(self) -> None:
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()
self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep).
self.theta_decay = torch.exp(-self.dt / self.tc_theta_decay) # Adaptive threshold decay (per timestep).


class DiehlAndCookNodes(Nodes):
# language=rst
Expand Down Expand Up @@ -598,10 +661,10 @@ def __init__(self, n: Optional[int] = None, shape: Optional[Iterable[int]] = Non
self.thresh = torch.tensor(thresh) # Spike threshold voltage.
self.refrac = torch.tensor(refrac) # Post-spike refractory period.
self.tc_decay = torch.tensor(tc_decay) # Time constant of neuron voltage decay.
self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep).
self.decay = None # Set in _compute_decays.
self.theta_plus = torch.tensor(theta_plus) # Constant threshold increase on spike.
self.tc_theta_decay = torch.tensor(tc_theta_decay) # Time constant of adaptive threshold decay.
self.theta_decay = torch.exp(-self.dt / self.tc_theta_decay) # Adaptive threshold decay (per timestep).
self.theta_decay = None # Set in _compute_decays.
self.lbound = lbound # Lower bound of voltage.
self.one_spike = one_spike # One spike per timestep.

Expand Down Expand Up @@ -656,6 +719,15 @@ def reset_(self) -> None:
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.

def _compute_decays(self) -> None:
# language=rst
"""
Sets the relevant decays.
"""
super()._compute_decays()
self.decay = torch.exp(-self.dt / self.tc_decay) # Neuron voltage decay (per timestep).
self.theta_decay = torch.exp(-self.dt / self.tc_theta_decay) # Adaptive threshold decay (per timestep).


class IzhikevichNodes(Nodes):
# language=rst
Expand Down

0 comments on commit ac68277

Please sign in to comment.