Merge pull request #509 from mahbodnr/tensor-wmin-and-wmax
Tensor support for wmin/wmax
Hananel-Hazan committed Aug 18, 2021
2 parents 9e16cad + 03610a6 commit 15bb926
Showing 4 changed files with 230 additions and 54 deletions.
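
What the change enables, sketched as a short usage example (not part of the diff; the node classes and shapes are illustrative assumptions about the usual bindsnet API):

import torch

from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection

source = Input(n=100)
target = LIFNodes(n=50)

# Scalar bounds keep working as before:
conn_scalar = Connection(source=source, target=target, wmin=0.0, wmax=1.0)

# After this change, bounds may also be tensors of the same size as w (100 x 50),
# e.g. a different ceiling per target neuron:
wmax = torch.linspace(0.1, 1.0, 50).repeat(100, 1)
conn_tensor = Connection(source=source, target=target, wmin=torch.zeros(100, 50), wmax=wmax)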
6 changes: 6 additions & 0 deletions .gitignore
@@ -122,10 +122,16 @@ examples/saved_checkpoints
# PyCharm project folder.
.idea/

+# VS Code workspace
+.vscode/
+
+# macOS
+.DS_Store
+
figures/

# Analyzer log default directory.
logs/

# PyTorch tensorboard log default directory.
runs/
9 changes: 5 additions & 4 deletions bindsnet/learning/learning.py
@@ -86,7 +86,8 @@ def update(self) -> None:

# Bound weights.
if (
-self.connection.wmin != -np.inf or self.connection.wmax != np.inf
+(self.connection.wmin != -np.inf).any()
+or (self.connection.wmax != np.inf).any()
) and not isinstance(self, NoOp):
self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
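
With scalar bounds the comparison above produced a single Python bool; tensor bounds produce a bool tensor, hence the .any() reduction before the if. A standalone sketch of the same check in plain torch (tensor min/max for clamp_ assumes PyTorch >= 1.9):

import numpy as np
import torch

w = torch.randn(3, 4)
wmin = torch.full((3, 4), -np.inf)  # unbounded below
wmax = torch.rand(3, 4)             # per-synapse ceiling

# Tensor comparisons return bool tensors; reduce with .any() before branching.
if (wmin != -np.inf).any() or (wmax != np.inf).any():
    w.clamp_(wmin, wmax)            # elementwise clamp against the tensor bounds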

@@ -282,9 +283,9 @@ def __init__(
)

assert self.source.traces, "Pre-synaptic nodes must record spike traces."
-assert (
-connection.wmin != -np.inf and connection.wmax != np.inf
-), "Connection must define finite wmin and wmax."
+assert (connection.wmin != -np.inf).any() and (
+connection.wmax != np.inf
+).any(), "Connection must define finite wmin and wmax."

self.wmin = connection.wmin
self.wmax = connection.wmax
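
Note the tensorized assert's semantics: it passes once at least one element of wmin and one element of wmax are finite; it does not require every synapse to be bounded. A small sketch:

import numpy as np
import torch

wmin = torch.tensor([-np.inf, 0.0])  # only partially bounded below
wmax = torch.tensor([1.0, 1.0])

assert (wmin != -np.inf).any() and (wmax != np.inf).any()  # passes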
117 changes: 70 additions & 47 deletions bindsnet/network/topology.py
@@ -23,7 +23,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
-**kwargs
+**kwargs,
) -> None:
# language=rst
"""
@@ -40,8 +40,10 @@ def __init__(
:param LearningRule update_rule: Modifies connection parameters according to
some rule.
-:param float wmin: The minimum value on the connection weights.
-:param float wmax: The maximum value on the connection weights.
+:param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or
+tensor of same size as w.
+:param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or
+tensor of same size as w.
:param float norm: Total weight per target neuron normalization.
"""
super().__init__()
@@ -59,8 +61,16 @@ def __init__(
from ..learning import NoOp

self.update_rule = kwargs.get("update_rule", NoOp)
-self.wmin = kwargs.get("wmin", -np.inf)
-self.wmax = kwargs.get("wmax", np.inf)
+
+# Float32 necessary for comparisons with +/-inf
+self.wmin = Parameter(
+torch.as_tensor(kwargs.get("wmin", -np.inf), dtype=torch.float32),
+requires_grad=False,
+)
+self.wmax = Parameter(
+torch.as_tensor(kwargs.get("wmax", np.inf), dtype=torch.float32),
+requires_grad=False,
+)
self.norm = kwargs.get("norm", None)
self.decay = kwargs.get("decay", None)
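
Two details of the new registration, sketched below: torch.as_tensor accepts scalars and tensors alike, and the explicit float32 (per the in-diff comment) keeps the bounds in a dtype where comparisons with +/-inf behave as intended; an integer default such as wmin=0 would otherwise become an int64 tensor. Registering the bounds as Parameter(requires_grad=False) should also let them follow the module across .to(device) calls.

import numpy as np
import torch
from torch.nn import Parameter

wmin = Parameter(torch.as_tensor(0, dtype=torch.float32), requires_grad=False)
print(wmin.dtype)                                     # torch.float32 (int64 without the explicit dtype)
print(torch.as_tensor(-np.inf, dtype=torch.float32))  # tensor(-inf)
print(torch.as_tensor(torch.zeros(2, 3), dtype=torch.float32).shape)  # torch.Size([2, 3])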

@@ -72,7 +82,7 @@ def __init__(
nu=nu,
reduction=reduction,
weight_decay=weight_decay,
-**kwargs
+**kwargs,
)

@abstractmethod
@@ -127,7 +137,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
-**kwargs
+**kwargs,
) -> None:
# language=rst
"""
@@ -146,20 +156,22 @@ def __init__(
some rule.
:param torch.Tensor w: Strengths of synapses.
:param torch.Tensor b: Target population bias.
-:param float wmin: Minimum allowed value on the connection weights.
-:param float wmax: Maximum allowed value on the connection weights.
+:param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or
+tensor of same size as w.
+:param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or
+tensor of same size as w.
:param float norm: Total weight per target neuron normalization constant.
"""
super().__init__(source, target, nu, reduction, weight_decay, **kwargs)

w = kwargs.get("w", None)
if w is None:
-if self.wmin == -np.inf or self.wmax == np.inf:
+if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
w = torch.clamp(torch.rand(source.n, target.n), self.wmin, self.wmax)
else:
w = self.wmin + torch.rand(source.n, target.n) * (self.wmax - self.wmin)
else:
-if self.wmin != -np.inf or self.wmax != np.inf:
+if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any():
w = torch.clamp(torch.as_tensor(w), self.wmin, self.wmax)

self.w = Parameter(w, requires_grad=False)
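
The init rule above, restated as a standalone sketch: if any bound is infinite, a U(0, 1) draw is clamped wherever finite bounds exist; if all bounds are finite, the draw is rescaled elementwise into [wmin, wmax]:

import numpy as np
import torch

def init_weights(n_source, n_target, wmin, wmax):
    wmin = torch.as_tensor(wmin, dtype=torch.float32)
    wmax = torch.as_tensor(wmax, dtype=torch.float32)
    if (wmin == -np.inf).any() or (wmax == np.inf).any():
        return torch.clamp(torch.rand(n_source, n_target), wmin, wmax)
    return wmin + torch.rand(n_source, n_target) * (wmax - wmin)

w = init_weights(100, 50, wmin=0.0, wmax=torch.rand(100, 50))  # per-synapse ceilings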
@@ -260,7 +272,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
-**kwargs
+**kwargs,
) -> None:
# language=rst
"""
@@ -283,8 +295,10 @@ def __init__(
some rule.
:param torch.Tensor w: Strengths of synapses.
:param torch.Tensor b: Target population bias.
-:param float wmin: Minimum allowed value on the connection weights.
-:param float wmax: Maximum allowed value on the connection weights.
+:param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or
+tensor of same size as w.
+:param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or
+tensor of same size as w.
:param float norm: Total weight per target neuron normalization constant.
"""
super().__init__(source, target, nu, reduction, weight_decay, **kwargs)
@@ -326,8 +340,9 @@ def __init__(
), error

w = kwargs.get("w", None)
+inf = torch.tensor(np.inf)
if w is None:
-if self.wmin == -np.inf or self.wmax == np.inf:
+if (self.wmin == -inf).any() or (self.wmax == inf).any():
w = torch.clamp(
torch.rand(self.out_channels, self.in_channels, *self.kernel_size),
self.wmin,
@@ -339,7 +354,7 @@
)
w += self.wmin
else:
-if self.wmin != -np.inf or self.wmax != np.inf:
+if (self.wmin != -inf).any() or (self.wmax != inf).any():
w = torch.clamp(w, self.wmin, self.wmax)

self.w = Parameter(w, requires_grad=False)
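
torch.clamp only requires the bounds to broadcast against w, so for a Conv2dConnection weight of shape (out_channels, in_channels, kH, kW) a per-filter cap can be written compactly. Note this goes beyond what the docstrings promise (same size as w); broadcasting is an assumption about clamp, not about the diff:

import torch

out_channels, in_channels, kh, kw = 8, 1, 3, 3
w = torch.rand(out_channels, in_channels, kh, kw)
wmax = torch.linspace(0.2, 1.0, out_channels).view(-1, 1, 1, 1)  # one cap per filter
w = torch.clamp(w, min=None, max=wmax)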
@@ -410,7 +425,7 @@ def __init__(
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
-**kwargs
+**kwargs,
) -> None:
# language=rst
"""
@@ -500,7 +515,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = 0.0,
-**kwargs
+**kwargs,
) -> None:
# language=rst
"""
@@ -528,8 +543,10 @@
some rule.
:param torch.Tensor w: Strengths of synapses.
:param torch.Tensor b: Target population bias.
-:param float wmin: Minimum allowed value on the connection weights.
-:param float wmax: Maximum allowed value on the connection weights.
+:param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or
+tensor of same size as w.
+:param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or
+tensor of same size as w.
:param float norm: Total weight per target neuron normalization constant.
:param Tuple[int, int] input_shape: Shape of input population if it's not
``[sqrt, sqrt]``.
@@ -561,9 +578,10 @@ def __init__(
conv_prod = int(np.prod(conv_size))
kernel_prod = int(np.prod(kernel_size))

-assert (
-target.n == n_filters * conv_prod
-), "Target layer size must be n_filters * (kernel_size ** 2)."
+assert target.n == n_filters * conv_prod, (
+f"Total neurons in target layer must be {n_filters * conv_prod}. "
+f"Got {target.n}."
+)

locations = torch.zeros(
kernel_size[0], kernel_size[1], conv_size[0], conv_size[1]
@@ -584,20 +602,21 @@
w = kwargs.get("w", None)

if w is None:
+# Calculate unbounded weights
w = torch.zeros(source.n, target.n)
for f in range(n_filters):
for c in range(conv_prod):
for k in range(kernel_prod):
-if self.wmin == -np.inf or self.wmax == np.inf:
-w[self.locations[k, c], f * conv_prod + c] = np.clip(
-np.random.rand(), self.wmin, self.wmax
-)
-else:
-w[
-self.locations[k, c], f * conv_prod + c
-] = self.wmin + np.random.rand() * (self.wmax - self.wmin)
+w[self.locations[k, c], f * conv_prod + c] = np.random.rand()
+
+# Bind weights to given range
+if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
+w = torch.clamp(w, self.wmin, self.wmax)
+else:
+w = self.wmin + w * (self.wmax - self.wmin)

else:
-if self.wmin != -np.inf or self.wmax != np.inf:
+if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any():
w = torch.clamp(w, self.wmin, self.wmax)

self.w = Parameter(w, requires_grad=False)
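
The LocalConnection rewrite also changes the shape of the computation: rather than drawing and bounding one scalar at a time inside the triple loop, it fills the unbounded weights first and binds the whole tensor in a single vectorized step. In miniature:

import torch

w = torch.rand(5, 5)                          # unbounded U(0, 1) fill
wmin, wmax = torch.zeros(5, 5), 0.5 * torch.ones(5, 5)
w = wmin + w * (wmax - wmin)                  # one vectorized bind, no per-element np.clip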
@@ -671,7 +690,7 @@ def __init__(
target: Nodes,
nu: Optional[Union[float, Sequence[float]]] = None,
weight_decay: float = 0.0,
-**kwargs
+**kwargs,
) -> None:
# language=rst
"""
@@ -683,21 +702,23 @@
Keyword arguments:
:param LearningRule update_rule: Modifies connection parameters according to
some rule.
-:param torch.Tensor w: Strengths of synapses.
-:param float wmin: Minimum allowed value on the connection weights.
-:param float wmax: Maximum allowed value on the connection weights.
+:param Union[float, torch.Tensor] w: Strengths of synapses. Can be single value or tensor of size ``target``.
+:param Union[float, torch.Tensor] wmin: Minimum allowed value(s) on the connection weights. Single value, or
+tensor of same size as w.
+:param Union[float, torch.Tensor] wmax: Maximum allowed value(s) on the connection weights. Single value, or
+tensor of same size as w.
:param float norm: Total weight per target neuron normalization constant.
"""
super().__init__(source, target, nu, weight_decay, **kwargs)

w = kwargs.get("w", None)
if w is None:
-if self.wmin == -np.inf or self.wmax == np.inf:
+if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
w = torch.clamp((torch.randn(1)[0] + 1) / 10, self.wmin, self.wmax)
else:
w = self.wmin + ((torch.randn(1)[0] + 1) / 10) * (self.wmax - self.wmin)
else:
-if self.wmin != -np.inf or self.wmax != np.inf:
+if (self.wmin != -np.inf).any() or (self.wmax != np.inf).any():
w = torch.clamp(w, self.wmin, self.wmax)

self.w = Parameter(w, requires_grad=False)
@@ -752,7 +773,7 @@ def __init__(
nu: Optional[Union[float, Sequence[float]]] = None,
reduction: Optional[callable] = None,
weight_decay: float = None,
-**kwargs
+**kwargs,
) -> None:
# language=rst
"""
Expand All @@ -767,7 +788,7 @@ def __init__(
Keyword arguments:
-:param torch.Tensor w: Strengths of synapses.
+:param torch.Tensor w: Strengths of synapses. Must be in ``torch.sparse`` format.
:param float sparsity: Fraction of sparse connections to use.
:param LearningRule update_rule: Modifies connection parameters according to
some rule.
@@ -791,16 +812,17 @@
i = torch.bernoulli(
1 - self.sparsity * torch.ones(*source.shape, *target.shape)
)
-if self.wmin == -np.inf or self.wmax == np.inf:
+if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
v = torch.clamp(
-torch.rand(*source.shape, *target.shape)[i.bool()],
+torch.rand(*source.shape, *target.shape),
self.wmin,
self.wmax,
-)
+)[i.bool()]
else:
-v = self.wmin + torch.rand(*source.shape, *target.shape)[i.bool()] * (
-self.wmax - self.wmin
-)
+v = (
+self.wmin
++ torch.rand(*source.shape, *target.shape) * (self.wmax - self.wmin)
+)[i.bool()]
w = torch.sparse.FloatTensor(i.nonzero().t(), v)
elif w is not None and self.sparsity is None:
assert w.is_sparse, "Weight matrix is not sparse (see torch.sparse module)"
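
The reordering in the sparse init matters for tensor bounds: boolean indexing flattens the drawn values to 1-D, after which full-shape wmin/wmax no longer pair up elementwise, so the dense tensor is bounded first and masked second. A sketch of the resulting construction:

import torch

mask = torch.bernoulli(0.5 * torch.ones(4, 4)).bool()   # which synapses exist
wmin, wmax = torch.zeros(4, 4), torch.rand(4, 4)

v = (wmin + torch.rand(4, 4) * (wmax - wmin))[mask]     # bind first, then mask (1-D values)
w = torch.sparse.FloatTensor(mask.nonzero().t(), v)     # COO sparse weights, as in the diff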
@@ -818,7 +840,8 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
:return: Incoming spikes multiplied by synaptic weights (with or without
decaying spike activation).
"""
-return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1)
+return torch.mm(self.w, s.view(s.shape[1], 1).float()).squeeze(-1)
+# return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1)
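
A shape sketch of the compute() change: if spikes arrive with a leading batch dimension of 1, s has shape (1, n), and torch.mm (strictly 2-D) needs an (n, 1) column on the right; s.view(s.shape[1], 1) produces that where unsqueeze(-1) would give a 3-D (1, n, 1). The batch-of-one shape is an assumption, with a dense stand-in for the sparse weight matrix:

import torch

n = 4
w = torch.eye(n)        # stand-in for the (n, n) sparse weight matrix
s = torch.ones(1, n)    # one batch of incoming spikes

out = torch.mm(w, s.view(s.shape[1], 1).float()).squeeze(-1)
print(out.shape)        # torch.Size([4])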

def update(self, **kwargs) -> None:
# language=rst
