Skip to content

Commit

Permalink
Merge pull request #501 from mahbodnr/master
Browse files Browse the repository at this point in the history
Minor changes
  • Loading branch information
Hananel-Hazan committed Jul 19, 2021
2 parents c9f3e9d + 4dddce6 commit 91ffafb
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 25 deletions.
41 changes: 24 additions & 17 deletions bindsnet/learning/learning.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC
from typing import Union, Optional, Sequence
import warnings

import torch
import numpy as np
Expand All @@ -26,7 +27,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
**kwargs
**kwargs,
) -> None:
# language=rst
"""
Expand All @@ -49,13 +50,19 @@ def __init__(
# Learning rate(s).
if nu is None:
nu = [0.0, 0.0]
elif isinstance(nu, float) or isinstance(nu, int):
elif isinstance(nu, (float, int)):
nu = [nu, nu]

self.nu = torch.zeros(2, dtype=torch.float)
self.nu[0] = nu[0]
self.nu[1] = nu[1]

if (self.nu == torch.zeros(2)).all() and not isinstance(self, NoOp):
warnings.warn(
f"nu is set to [0., 0.] for {type(self).__name__} learning rule. "
+ "It will disable the learning process."
)

# Parameter update reduction across minibatch dimension.
if reduction is None:
if self.source.batch_size == 1:
Expand Down Expand Up @@ -96,7 +103,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
**kwargs
**kwargs,
) -> None:
# language=rst
"""
Expand All @@ -113,7 +120,7 @@ def __init__(
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
**kwargs
**kwargs,
)

def update(self, **kwargs) -> None:
Expand All @@ -137,7 +144,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
**kwargs
**kwargs,
) -> None:
# language=rst
"""
Expand All @@ -155,7 +162,7 @@ def __init__(
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
**kwargs
**kwargs,
)

assert (
Expand Down Expand Up @@ -253,7 +260,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
**kwargs
**kwargs,
) -> None:
# language=rst
"""
Expand All @@ -271,7 +278,7 @@ def __init__(
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
**kwargs
**kwargs,
)

assert self.source.traces, "Pre-synaptic nodes must record spike traces."
Expand Down Expand Up @@ -391,7 +398,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
**kwargs
**kwargs,
) -> None:
# language=rst
"""
Expand All @@ -409,7 +416,7 @@ def __init__(
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
**kwargs
**kwargs,
)

assert (
Expand Down Expand Up @@ -496,7 +503,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
**kwargs
**kwargs,
) -> None:
# language=rst
"""
Expand All @@ -520,7 +527,7 @@ def __init__(
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
**kwargs
**kwargs,
)

if isinstance(connection, (Connection, LocalConnection)):
Expand Down Expand Up @@ -690,7 +697,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
**kwargs
**kwargs,
) -> None:
# language=rst
"""
Expand All @@ -715,7 +722,7 @@ def __init__(
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
**kwargs
**kwargs,
)

if isinstance(connection, (Connection, LocalConnection)):
Expand Down Expand Up @@ -788,7 +795,7 @@ def _connection_update(self, **kwargs) -> None:
self.p_minus += a_minus * target_s

# Calculate point eligibility value.
self.eligibility = torch.ger(self.p_plus, target_s) + torch.ger(
self.eligibility = torch.outer(self.p_plus, target_s) + torch.outer(
source_s, self.p_minus
)

Expand Down Expand Up @@ -894,7 +901,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
**kwargs
**kwargs,
) -> None:
# language=rst
"""
Expand All @@ -919,7 +926,7 @@ def __init__(
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
**kwargs
**kwargs,
)

# Trace is needed for computing epsilon.
Expand Down
2 changes: 1 addition & 1 deletion bindsnet/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def __init__(
)
self.add_connection(input_output_conn, source="X", target="Y")

# add internal inhibetory connections
# add internal inhibitory connections
w = torch.ones(self.n_neurons, self.n_neurons) - torch.diag(
torch.ones(self.n_neurons)
)
Expand Down
8 changes: 7 additions & 1 deletion bindsnet/network/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _get_inputs(self, layers: Iterable = None) -> Dict[str, torch.Tensor]:
self.batch_size,
target.res_window_size,
*target.shape,
device=target.s.device
device=target.s.device,
)
else:
inputs[c[1]] = torch.zeros(
Expand Down Expand Up @@ -275,6 +275,7 @@ def run(
learning.
:param Dict[Tuple[str], torch.Tensor] masks: Mapping of connection names to
boolean masks determining which weights to clamp to zero.
:param Bool progress_bar: Show a progress bar while running the network.
**Example:**
Expand Down Expand Up @@ -306,6 +307,11 @@ def run(
plt.title('Input spiking')
plt.show()
"""
# Check input type
assert type(inputs) == dict, (
"'inputs' must be a dict of names of layers "
+ f"(str) and relevant input tensors. Got {type(inputs).__name__} instead."
)
# Parse keyword arguments.
clamps = kwargs.get("clamp", {})
unclamps = kwargs.get("unclamp", {})
Expand Down
11 changes: 5 additions & 6 deletions bindsnet/network/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,9 @@ def __init__(
self.register_buffer(
"tc_trace", torch.tensor(tc_trace)
) # Time constant of spike trace decay.
if self.traces_additive:
self.register_buffer(
"trace_scale", torch.tensor(trace_scale)
) # Scaling factor for spike trace.
self.register_buffer(
"trace_scale", torch.tensor(trace_scale)
) # Scaling factor for spike trace.
self.register_buffer(
"trace_decay", torch.empty_like(self.tc_trace)
) # Set in compute_decays.
Expand Down Expand Up @@ -101,7 +100,7 @@ def forward(self, x: torch.Tensor) -> None:
if self.traces_additive:
self.x += self.trace_scale * self.s.float()
else:
self.x.masked_fill_(self.s.bool(), 1)
self.x.masked_fill_(self.s.bool(), self.trace_scale)

if self.sum_input:
# Add current input to running sum.
Expand Down Expand Up @@ -1538,7 +1537,7 @@ def ExponentialKernel(self, dt):

def RectangularKernel(self, dt):
t = torch.arange(0, self.res_window_size, dt)
kernelVec = 1 / (selftau * 2)
kernelVec = 1 / (self.tau * 2)
return torch.flip(kernelVec, [0])

def TriangularKernel(self, dt):
Expand Down

0 comments on commit 91ffafb

Please sign in to comment.