Skip to content

Commit

Permalink
Merge pull request #77 from Hananel-Hazan/hananel
Browse files Browse the repository at this point in the history
Fix bugs and code improvements in Izhikevich neuron model
  • Loading branch information
Hananel-Hazan committed Jun 17, 2018
2 parents 8bf1b9a + 8023ce4 commit 57a1ee5
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions bindsnet/network/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,14 @@ def step(self, inpts, dt):
self.refrac_count[self.refrac_count != 0] -= dt

# Check for spiking neurons.
self.s = (self.v >= self.thresh) & (self.refrac_count == 0)
self.s = (self.v >= self.thresh)

# Refractoriness and voltage reset.
self.refrac_count.masked_fill_(self.s, self.refrac)
self.v.masked_fill_(self.s, self.reset)

# Integrate input and decay voltages.
self.v += inpts
self.v += (self.refrac_count == 0).float() * inpts

super().step(inpts, dt)

Expand Down Expand Up @@ -228,7 +228,7 @@ def __init__(self, n=None, shape=None, traces=False, thresh=-52.0, rest=-65.0,
self.reset = reset # Post-spike reset voltage.
self.thresh = thresh # Spike threshold voltage.
self.refrac = refrac # Post-spike refractory period.
self.decay = decay # Rate of decay of neuron voltage.
self.decay = decay # Rate of decay of neuron voltage.

self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
Expand All @@ -249,7 +249,7 @@ def step(self, inpts, dt):
self.refrac_count[self.refrac_count != 0] -= dt

# Check for spiking neurons.
self.s = (self.v >= self.thresh) & (self.refrac_count == 0)
self.s = (self.v >= self.thresh)

# Refractoriness and voltage reset.
self.refrac_count.masked_fill_(self.s, self.refrac)
Expand Down Expand Up @@ -297,7 +297,7 @@ def __init__(self, n=None, shape=None, traces=False, thresh=-52.0, rest=-65.0,
self.reset = reset # Post-spike reset voltage.
self.thresh = thresh # Spike threshold voltage.
self.refrac = refrac # Post-spike refractory period.
self.decay = decay # Rate of decay of neuron voltage.
self.decay = decay # Rate of decay of neuron voltage.
self.i_decay = i_decay # Rate of decay of synaptic input current.

self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
Expand Down Expand Up @@ -372,7 +372,7 @@ def __init__(self, n=None, shape=None, traces=False, rest=-65.0, reset=-65.0, th
self.reset = reset # Post-spike reset voltage.
self.thresh = thresh # Spike threshold voltage.
self.refrac = refrac # Post-spike refractory period.
self.decay = decay # Rate of decay of neuron voltage.
self.decay = decay # Rate of decay of neuron voltage.
self.theta_plus = theta_plus # Constant threshold increase on spike.
self.theta_decay = theta_decay # Rate of decay of adaptive thresholds.

Expand All @@ -398,7 +398,7 @@ def step(self, inpts, dt):
self.refrac_count[self.refrac_count != 0] -= dt

# Check for spiking neurons.
self.s = (self.v >= self.thresh + self.theta) & (self.refrac_count == 0)
self.s = (self.v >= self.thresh + self.theta)

# Refractoriness, voltage reset, and adaptive thresholds.
self.refrac_count.masked_fill_(self.s, self.refrac)
Expand Down Expand Up @@ -448,7 +448,7 @@ def __init__(self, n=None, shape=None, traces=False, rest=-65.0, reset=-65.0, th
self.reset = reset # Post-spike reset voltage.
self.thresh = thresh # Spike threshold voltage.
self.refrac = refrac # Post-spike refractory period.
self.decay = decay # Rate of decay of neuron voltage.
self.decay = decay # Rate of decay of neuron voltage.
self.theta_plus = theta_plus # Constant threshold increase on spike.
self.theta_decay = theta_decay # Rate of decay of adaptive thresholds.

Expand All @@ -473,7 +473,7 @@ def step(self, inpts, dt):
self.refrac_count[self.refrac_count != 0] -= dt

# Check for spiking neurons.
self.s = (self.v >= self.thresh + self.theta) & (self.refrac_count == 0)
self.s = (self.v >= self.thresh + self.theta)

# Refractoriness, voltage reset, and adaptive thresholds.
self.refrac_count.masked_fill_(self.s, self.refrac)
Expand Down Expand Up @@ -529,7 +529,7 @@ def __init__(self, n=None, shape=None, traces=False, excitatory=True, rest=-65.0
self.reset = reset # Post-spike reset voltage.
self.thresh = thresh # Spike threshold voltage.
self.refrac = refrac # Post-spike refractory period.
self.decay = decay # Rate of decay of neuron voltage.
self.decay = decay # Rate of decay of neuron voltage.

if excitatory:
self.r = torch.rand(n)
Expand All @@ -540,13 +540,12 @@ def __init__(self, n=None, shape=None, traces=False, excitatory=True, rest=-65.0
else:
self.r = torch.rand(n)
self.a = 0.02 + 0.08 * self.r
self.b = 0.25 - 0.05 * torch.ones(n)
self.c = -65.0 * (self.re ** 2)
self.b = 0.25 - 0.05 * self.r
self.c = -65.0 * torch.ones(n)
self.d = 2 * torch.ones(n)

self.v = self.rest * torch.ones(n) # Neuron voltages.
self.u = self.b * self.v # Neuron recovery.
self.refrac_count = torch.zeros(n) # Refractory period counters.

def step(self, inpts, dt):
'''
Expand All @@ -557,15 +556,14 @@ def step(self, inpts, dt):
| :code:`inpts` (:code:`torch.Tensor`): Inputs to the layer.
| :code:`dt` (:code:`float`): Simulation time step.
'''
# Decrement refrac counters.
self.refrac_count[self.refrac_count != 0] -= dt

# Check for spiking neurons.
self.s = (self.v >= self.thresh) & (self.refrac_count == 0)
self.s = (self.v >= self.thresh)

# Refractoriness and voltage reset.
self.refrac_count.masked_fill_(self.s, self.refrac)
self.v.masked_fill_(self.s, self.reset)
self.v = torch.where(self.s, self.c, self.v)
self.u = torch.where(self.s, self.u + self.d, self.u)


# Apply v and u updates.
self.v += dt * (0.04 * (self.v ** 2) + 5 * self.v + 140 - self.u + inpts)
Expand All @@ -579,5 +577,4 @@ def _reset(self):
'''
super()._reset()
self.v = self.rest * torch.ones(self.shape) # Neuron voltages.
self.u = self.b * self.v # Neuron recovery.
self.refrac_count = torch.zeros(self.shape) # Refractory period counters.
self.u = self.b * self.v # Neuron recovery.

0 comments on commit 57a1ee5

Please sign in to comment.