From bc4324171a75e0201617d6f3eb175466fa61a572 Mon Sep 17 00:00:00 2001 From: Simon Caby Date: Wed, 17 Jun 2020 10:57:54 +0200 Subject: [PATCH 01/11] faster LIF forward --- bindsnet/network/nodes.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py index fba2cb70..87903119 100644 --- a/bindsnet/network/nodes.py +++ b/bindsnet/network/nodes.py @@ -509,16 +509,16 @@ def forward(self, x: torch.Tensor) -> None: self.v = self.decay * (self.v - self.rest) + self.rest # Integrate inputs. - self.v += (self.refrac_count == 0).float() * x - + x.masked_fill_(self.refrac_count > 0, 0.0) # OPTIM 2 # Decrement refractory counters. - self.refrac_count = (self.refrac_count > 0).float() * ( - self.refrac_count - self.dt - ) + self.refrac_count -= self.dt # OPTIM 1 + + self.v += x # interlaced # Check for spiking neurons. self.s = self.v >= self.thresh + # Refractoriness and voltage reset. self.refrac_count.masked_fill_(self.s, self.refrac) self.v.masked_fill_(self.s, self.reset) From a77b24755bbbafb5f9288c593991c1776474a8e0 Mon Sep 17 00:00:00 2001 From: Simon Caby Date: Wed, 17 Jun 2020 13:17:41 +0200 Subject: [PATCH 02/11] default reduction is sum. Faster code when no batch is used --- bindsnet/learning/learning.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index f602254b..9260870e 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -56,9 +56,13 @@ def __init__( # Parameter update reduction across minibatch dimension. if reduction is None: - reduction = torch.mean - - self.reduction = reduction + if self.source.batch_size == 1: + def NoBatch(a, dim):# get rid of the batch dim + return a[0] + reduction = NoBatch + else: + reduction = torch.sum + self.reduction = reduction # Weight decay. self.weight_decay = weight_decay From 05f2ebe2a66d2074e80a743f096341f76bb40609 Mon Sep 17 00:00:00 2001 From: Simon Caby Date: Wed, 17 Jun 2020 13:26:07 +0200 Subject: [PATCH 03/11] fixed indent bug --- bindsnet/learning/learning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index 9260870e..1691fe93 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -62,7 +62,7 @@ def NoBatch(a, dim):# get rid of the batch dim reduction = NoBatch else: reduction = torch.sum - self.reduction = reduction + self.reduction = reduction # Weight decay. self.weight_decay = weight_decay From f24f94cc72b52fce78717a3509e028acd693e65b Mon Sep 17 00:00:00 2001 From: Simon Caby Date: Sun, 21 Jun 2020 17:02:25 +0200 Subject: [PATCH 04/11] changed NoBatch function location --- bindsnet/learning/learning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index 1691fe93..17f5efb8 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -13,6 +13,8 @@ ) from ..utils import im2col_indices +def NoBatch(a, dim):# get rid of the batch dim + return a[0] class LearningRule(ABC): # language=rst @@ -57,8 +59,6 @@ def __init__( # Parameter update reduction across minibatch dimension. if reduction is None: if self.source.batch_size == 1: - def NoBatch(a, dim):# get rid of the batch dim - return a[0] reduction = NoBatch else: reduction = torch.sum From 3e44698e24c33878c0bcfef892805b8c25f09c8c Mon Sep 17 00:00:00 2001 From: Simon Caby Date: Mon, 22 Jun 2020 20:42:40 +0200 Subject: [PATCH 05/11] moved NoBatch inside class --- bindsnet/learning/learning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index 17f5efb8..6c7439fc 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -13,15 +13,15 @@ ) from ..utils import im2col_indices -def NoBatch(a, dim):# get rid of the batch dim - return a[0] - class LearningRule(ABC): # language=rst """ Abstract base class for learning rules. """ + def NoBatch(a, dim):# get rid of the batch dim + return a[0] + def __init__( self, connection: AbstractConnection, From 25f7d94068c25ca6b7184973c0a15d2cbcc205aa Mon Sep 17 00:00:00 2001 From: Simon Caby Date: Mon, 22 Jun 2020 22:13:48 +0200 Subject: [PATCH 06/11] squeeze! --- bindsnet/learning/learning.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index 6c7439fc..b9009527 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -19,9 +19,6 @@ class LearningRule(ABC): Abstract base class for learning rules. """ - def NoBatch(a, dim):# get rid of the batch dim - return a[0] - def __init__( self, connection: AbstractConnection, @@ -59,10 +56,9 @@ def __init__( # Parameter update reduction across minibatch dimension. if reduction is None: if self.source.batch_size == 1: - reduction = NoBatch + self.reduction = torch.squeeze else: - reduction = torch.sum - self.reduction = reduction + self.reduction = torch.sum # Weight decay. self.weight_decay = weight_decay From 925f7805e946763adc1ce1ab784ac4bfc2ae94e4 Mon Sep 17 00:00:00 2001 From: Simon Caby Date: Mon, 22 Jun 2020 22:52:18 +0200 Subject: [PATCH 07/11] optim for D&C nodes --- bindsnet/network/nodes.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py index 87903119..ba0d10fb 100644 --- a/bindsnet/network/nodes.py +++ b/bindsnet/network/nodes.py @@ -965,12 +965,10 @@ def forward(self, x: torch.Tensor) -> None: self.theta *= self.theta_decay # Integrate inputs. - self.v += (self.refrac_count == 0).float() * x + self.v += (self.refrac_count <= 0).float() * x # Decrement refractory counters. - self.refrac_count = (self.refrac_count > 0).float() * ( - self.refrac_count - self.dt - ) + self.refrac_count -= self.dt # Check for spiking neurons. self.s = self.v >= self.thresh + self.theta From b41bd6499821e4c97667d41c0a5f34786500a679 Mon Sep 17 00:00:00 2001 From: Simon Caby Date: Mon, 22 Jun 2020 22:54:13 +0200 Subject: [PATCH 08/11] optim for SRM0 nodes --- bindsnet/network/nodes.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py index ba0d10fb..1ab8f2aa 100644 --- a/bindsnet/network/nodes.py +++ b/bindsnet/network/nodes.py @@ -1296,7 +1296,7 @@ def forward(self, x: torch.Tensor) -> None: self.v = self.decay * (self.v - self.rest) + self.rest # Integrate inputs. - self.v += (self.refrac_count == 0).float() * self.eps_0 * x + self.v += (self.refrac_count <= 0).float() * self.eps_0 * x # Compute (instantaneous) probabilities of spiking, clamp between 0 and 1 using exponentials. # Also known as 'escape noise', this simulates nearby neurons. @@ -1304,9 +1304,7 @@ def forward(self, x: torch.Tensor) -> None: self.s_prob = 1.0 - torch.exp(-self.rho * self.dt) # Decrement refractory counters. - self.refrac_count = (self.refrac_count > 0).float() * ( - self.refrac_count - self.dt - ) + self.refrac_count -= self.dt # Check for spiking neurons (spike when probability > some random number). self.s = torch.rand_like(self.s_prob) < self.s_prob From 2efc37639cf39024aa5c9dcfa59d2459544ac6eb Mon Sep 17 00:00:00 2001 From: Simon Caby Date: Mon, 22 Jun 2020 22:57:39 +0200 Subject: [PATCH 09/11] optim for currentLIF & IF nodes --- bindsnet/network/nodes.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py index 1ab8f2aa..9bf1fc71 100644 --- a/bindsnet/network/nodes.py +++ b/bindsnet/network/nodes.py @@ -375,12 +375,10 @@ def forward(self, x: torch.Tensor) -> None: :param x: Inputs to the layer. """ # Integrate input voltages. - self.v += (self.refrac_count == 0).float() * x + self.v += (self.refrac_count <= 0).float() * x # Decrement refractory counters. - self.refrac_count = (self.refrac_count > 0).float() * ( - self.refrac_count - self.dt - ) + self.refrac_count -= self.dt # Check for spiking neurons. self.s = self.v >= self.thresh @@ -653,13 +651,11 @@ def forward(self, x: torch.Tensor) -> None: self.i *= self.i_decay # Decrement refractory counters. - self.refrac_count = (self.refrac_count > 0).float() * ( - self.refrac_count - self.dt - ) + self.refrac_count -= self.dt # Integrate inputs. self.i += x - self.v += (self.refrac_count == 0).float() * self.i + self.v += (self.refrac_count <= 0).float() * self.i # Check for spiking neurons. self.s = self.v >= self.thresh @@ -808,12 +804,10 @@ def forward(self, x: torch.Tensor) -> None: self.theta *= self.theta_decay # Integrate inputs. - self.v += (self.refrac_count == 0).float() * x + self.v += (self.refrac_count <= 0).float() * x # Decrement refractory counters. - self.refrac_count = (self.refrac_count > 0).float() * ( - self.refrac_count - self.dt - ) + self.refrac_count -= self.dt # Check for spiking neurons. self.s = self.v >= self.thresh + self.theta From d80d5dca7f8b7b2e0bbcde8b7c59a8dcfb3d4fe8 Mon Sep 17 00:00:00 2001 From: Simon Caby Date: Tue, 23 Jun 2020 17:43:44 +0200 Subject: [PATCH 10/11] default reduction corrected --- bindsnet/learning/learning.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index b9009527..9c582f78 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -59,6 +59,8 @@ def __init__( self.reduction = torch.squeeze else: self.reduction = torch.sum + else: + self.reduction = reduction # Weight decay. self.weight_decay = weight_decay From 3af105b3582b4dafe8d42cefe2e071e5623c91b8 Mon Sep 17 00:00:00 2001 From: Simon Caby Date: Fri, 26 Jun 2020 21:10:15 +0200 Subject: [PATCH 11/11] last changes --- bindsnet/network/nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py index 9bf1fc71..0ffbf829 100644 --- a/bindsnet/network/nodes.py +++ b/bindsnet/network/nodes.py @@ -772,7 +772,7 @@ def __init__( "tc_decay", torch.tensor(tc_decay) ) # Time constant of neuron voltage decay. self.register_buffer( - "decay", torch.empty_like(self.tc_decay) + "decay", torch.empty_like(self.tc_decay, dtype=torch.float32) ) # Set in compute_decays. self.register_buffer( "theta_plus", torch.tensor(theta_plus)