diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py
index f602254b..9c582f78 100644
--- a/bindsnet/learning/learning.py
+++ b/bindsnet/learning/learning.py
@@ -13,7 +13,6 @@
 )
 from ..utils import im2col_indices
 
-
 class LearningRule(ABC):
     # language=rst
     """
@@ -56,9 +55,12 @@ def __init__(
 
         # Parameter update reduction across minibatch dimension.
         if reduction is None:
-            reduction = torch.mean
-
-        self.reduction = reduction
+            if self.source.batch_size == 1:
+                self.reduction = torch.squeeze
+            else:
+                self.reduction = torch.sum
+        else:
+            self.reduction = reduction
 
         # Weight decay.
         self.weight_decay = weight_decay
diff --git a/bindsnet/network/nodes.py b/bindsnet/network/nodes.py
index fba2cb70..0ffbf829 100644
--- a/bindsnet/network/nodes.py
+++ b/bindsnet/network/nodes.py
@@ -375,12 +375,10 @@ def forward(self, x: torch.Tensor) -> None:
         :param x: Inputs to the layer.
         """
         # Integrate input voltages.
-        self.v += (self.refrac_count == 0).float() * x
+        self.v += (self.refrac_count <= 0).float() * x
 
         # Decrement refractory counters.
-        self.refrac_count = (self.refrac_count > 0).float() * (
-            self.refrac_count - self.dt
-        )
+        self.refrac_count -= self.dt
 
         # Check for spiking neurons.
         self.s = self.v >= self.thresh
@@ -509,16 +507,16 @@ def forward(self, x: torch.Tensor) -> None:
 
         self.v = self.decay * (self.v - self.rest) + self.rest
 
         # Integrate inputs.
-        self.v += (self.refrac_count == 0).float() * x
-
+        x.masked_fill_(self.refrac_count > 0, 0.0)  # OPTIM 2
         # Decrement refractory counters.
-        self.refrac_count = (self.refrac_count > 0).float() * (
-            self.refrac_count - self.dt
-        )
+        self.refrac_count -= self.dt  # OPTIM 1
+
+        self.v += x  # interlaced
 
         # Check for spiking neurons.
         self.s = self.v >= self.thresh
+
         # Refractoriness and voltage reset.
         self.refrac_count.masked_fill_(self.s, self.refrac)
         self.v.masked_fill_(self.s, self.reset)
@@ -653,13 +651,11 @@ def forward(self, x: torch.Tensor) -> None:
         self.i *= self.i_decay
 
         # Decrement refractory counters.
-        self.refrac_count = (self.refrac_count > 0).float() * (
-            self.refrac_count - self.dt
-        )
+        self.refrac_count -= self.dt
 
         # Integrate inputs.
         self.i += x
-        self.v += (self.refrac_count == 0).float() * self.i
+        self.v += (self.refrac_count <= 0).float() * self.i
 
         # Check for spiking neurons.
         self.s = self.v >= self.thresh
@@ -776,7 +772,7 @@ def __init__(
             "tc_decay", torch.tensor(tc_decay)
         )  # Time constant of neuron voltage decay.
         self.register_buffer(
-            "decay", torch.empty_like(self.tc_decay)
+            "decay", torch.empty_like(self.tc_decay, dtype=torch.float32)
         )  # Set in compute_decays.
         self.register_buffer(
             "theta_plus", torch.tensor(theta_plus)
@@ -808,12 +804,10 @@ def forward(self, x: torch.Tensor) -> None:
             self.theta *= self.theta_decay
 
         # Integrate inputs.
-        self.v += (self.refrac_count == 0).float() * x
+        self.v += (self.refrac_count <= 0).float() * x
 
         # Decrement refractory counters.
-        self.refrac_count = (self.refrac_count > 0).float() * (
-            self.refrac_count - self.dt
-        )
+        self.refrac_count -= self.dt
 
         # Check for spiking neurons.
         self.s = self.v >= self.thresh + self.theta
@@ -965,12 +959,10 @@ def forward(self, x: torch.Tensor) -> None:
             self.theta *= self.theta_decay
 
         # Integrate inputs.
-        self.v += (self.refrac_count == 0).float() * x
+        self.v += (self.refrac_count <= 0).float() * x
 
         # Decrement refractory counters.
-        self.refrac_count = (self.refrac_count > 0).float() * (
-            self.refrac_count - self.dt
-        )
+        self.refrac_count -= self.dt
 
         # Check for spiking neurons.
         self.s = self.v >= self.thresh + self.theta
@@ -1298,7 +1290,7 @@ def forward(self, x: torch.Tensor) -> None:
         self.v = self.decay * (self.v - self.rest) + self.rest
 
         # Integrate inputs.
-        self.v += (self.refrac_count == 0).float() * self.eps_0 * x
+        self.v += (self.refrac_count <= 0).float() * self.eps_0 * x
 
         # Compute (instantaneous) probabilities of spiking, clamp between 0 and 1 using exponentials.
         # Also known as 'escape noise', this simulates nearby neurons.
@@ -1306,9 +1298,7 @@ def forward(self, x: torch.Tensor) -> None:
         self.s_prob = 1.0 - torch.exp(-self.rho * self.dt)
 
         # Decrement refractory counters.
-        self.refrac_count = (self.refrac_count > 0).float() * (
-            self.refrac_count - self.dt
-        )
+        self.refrac_count -= self.dt
 
         # Check for spiking neurons (spike when probability > some random number).
         self.s = torch.rand_like(self.s_prob) < self.s_prob
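Note on the refractory-counter change (not part of the patch): replacing the clamped update refrac_count = (refrac_count > 0) * (refrac_count - dt) with an unconditional refrac_count -= dt lets the counter go negative, so the integration gate switches from == 0 to <= 0. Below is a minimal sketch of the apparent intent, assuming refractory periods are integer multiples of dt; the helper functions step_old and step_new and the toy tensors are illustrative only, not bindsnet API.

import torch

def step_old(v, refrac_count, x, dt):
    # Pre-patch update: integrate only where the counter is exactly zero,
    # then clamp the counter at zero while decrementing.
    v = v + (refrac_count == 0).float() * x
    refrac_count = (refrac_count > 0).float() * (refrac_count - dt)
    return v, refrac_count

def step_new(v, refrac_count, x, dt):
    # Patched update: integrate wherever the counter has run out (<= 0)
    # and decrement unconditionally, so the counter may drift negative.
    v = v + (refrac_count <= 0).float() * x
    refrac_count = refrac_count - dt
    return v, refrac_count

v_old = v_new = torch.zeros(4)
rc_old = rc_new = torch.tensor([0.0, 1.0, 2.0, 3.0])  # remaining refractory time (ms)
x = torch.full((4,), 0.5)                             # constant input current
dt = 1.0

for _ in range(5):
    v_old, rc_old = step_old(v_old, rc_old, x, dt)
    v_new, rc_new = step_new(v_new, rc_new, x, dt)

print(torch.allclose(v_old, v_new))  # True: both variants gate integration identically

The same reasoning applies to every forward() touched above; in the LIFNodes hunk the gate is applied to the input via x.masked_fill_(self.refrac_count > 0, 0.0) before the decrement rather than by a masked multiply, which gates on the same pre-decrement counter (note that it zeroes the input tensor in place).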