In [2]:
import numpy as np
import torch
from collections import namedtuple

import matplotlib.pyplot as plt
from jupyterthemes import jtplot
jtplot.style(theme='monokai', context='notebook', ticks=True, grid=False) ## for Dark plots
#jtplot.style(theme='grade3')  ## for Light plots

This notebook translates the e-prop (eligibility propagation) algorithm from TensorFlow to PyTorch.

In [None]:
def sum_of_sines_target(seq_len, n_sines=4, periods=[1000, 500, 333, 200], weights=None, phases=None, normalize=True):
    '''
    Build a one dimensional target signal as a weighted sum of sinusoids.

    :param seq_len: number of samples of the generated signal
    :param n_sines: number of sinusoids to combine (must equal len(periods))
    :param periods: list of sinusoid periods; if None, n_sines periods are drawn uniformly in [100, 1000)
    :param weights: weight of each sinusoid; if None, drawn uniformly in [0.5, 2)
    :param phases: phase of each sinusoid; if None, drawn uniformly in [0, 2 * pi)
    :param normalize: if True, shift the signal so that its first sample is 0 and divide it by its
                      largest absolute value (the divisor is floored at 1e-6)
    :return: one dimensional vector of size seq_len containing the weighted sum of sinusoids
    '''
    if periods is None:
        periods = [np.random.uniform(low=100, high=1000) for _ in range(n_sines)]
    assert n_sines == len(periods)

    # Weights are drawn before phases so that the sequence of random numbers consumed stays the same.
    if weights is None:
        weights = np.random.uniform(low=0.5, high=2, size=n_sines)
    if phases is None:
        phases = np.random.uniform(low=0., high=np.pi * 2, size=n_sines)

    signal = 0
    for k in range(n_sines):
        # Each sinusoid spans a whole number (seq_len // period) of cycles over the sequence.
        start = 0 + phases[k]
        stop = np.pi * 2 * (seq_len // periods[k]) + phases[k]
        signal = signal + np.sin(np.linspace(start, stop, seq_len)) * weights[k]

    if not normalize:
        return signal

    signal = signal - signal[0]
    peak = max(np.abs(np.min(signal)), np.abs(np.max(signal)))
    return signal / np.maximum(peak, 1e-6)

In [4]:
def pseudo_derivative(v_scaled, dampening_factor):
    '''
    Define the pseudo derivative used to derive through spikes.

    The pseudo derivative is a triangle centred on 0: dampening_factor * max(1 - |v_scaled|, 0).

    :param v_scaled: tensor holding the scaled version of the voltage being 0 at threshold and -1 at rest
    :param dampening_factor: parameter that stabilizes learning (scales the height of the triangle)
    :return: tensor with the same shape as v_scaled containing the pseudo derivative
    '''
    # torch.maximum only accepts tensors for both operands, so the lower bound is applied with clamp.
    return torch.clamp(1 - torch.abs(v_scaled), min=0) * dampening_factor

class sgt_heaviside(torch.autograd.Function):
    '''
    Heaviside step function with a surrogate (pseudo derivative) gradient.

    Forward pass: 1.0 where input > thr, 0.0 elsewhere.
    Backward pass: grad_output * dampening_factor * max(1 - |input - thr|, 0) for `input`,
    and no gradient for `thr`.

    Use it through ``sgt_heaviside.apply(input, thr)``.
    '''
    # Height of the triangular pseudo derivative; shared by every call.
    dampening_factor = 20.0

    @staticmethod
    def forward(ctx, input, thr):
        '''
        :param ctx: autograd context used to pass tensors to the backward pass
        :param input: tensor that is compared to the threshold
        :param thr: threshold, either a python number or a tensor broadcastable to input
        :return: tensor of the dtype of input, 1.0 where input > thr and 0.0 elsewhere
        '''
        # save_for_backward only accepts tensors, so a python-number threshold is wrapped first.
        thr = torch.as_tensor(thr, dtype=input.dtype, device=input.device)
        ctx.save_for_backward(input, thr)
        return (input > thr).to(input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        '''
        :param ctx: autograd context filled by forward
        :param grad_output: gradient of the loss with respect to the output of forward
        :return: (gradient with respect to input, None for thr)
        '''
        input, thr = ctx.saved_tensors
        dE_dz = grad_output
        # The step happens at input == thr, so the triangle is centred on thr
        # (identical to max(1 - |input|, 0) when thr is 0).
        dz_dinput = torch.clamp(1 - torch.abs(input - thr), min=0) * sgt_heaviside.dampening_factor
        return dE_dz * dz_dinput, None

In [11]:
LightLIFStateTuple = namedtuple('LightLIFStateTuple', ('v', 'z'))
class LightLIF(torch.nn.Module):
    '''
    A PyTorch recurrent cell simulating Leaky Integrate and Fire (LIF) neurons.

    One call to forward advances the network by one time step:
        new_v = decay * v + (1 - decay) * (inputs @ w_in + z @ w_rec) - z * thr * dt
        new_z = SpikeFunction((new_v - thr) / thr, dampening_factor) / dt
    with decay = exp(-dt / tau) and the diagonal of w_rec forced to 0.

    NOTE: forward relies on a callable named `SpikeFunction(v_scaled, dampening_factor)` that is
    not defined in this cell; it has to be defined elsewhere in the notebook.

    :param n_in: number of input neurons
    :param n_rec: number of recurrent neurons
    :param tau: membrane time constant
    :param thr: threshold voltage
    :param dt: time step
    :param dtype: data type of the weights
    :param dampening_factor: parameter to stabilize learning
    :param stop_z_gradients: if true, some gradients are stopped to get an equivalence between eprop and bptt
    '''
    def __init__(self, n_in, n_rec, tau=20., thr=0.615, dt=1., dtype=torch.float32, dampening_factor=0.3,
                 stop_z_gradients=False):
        super(LightLIF, self).__init__()

        self.dampening_factor = dampening_factor
        self.dt = dt
        self.n_in = n_in
        self.n_rec = n_rec
        self.data_type = dtype
        self.stop_z_gradients = stop_z_gradients

        self._num_units = self.n_rec

        self.tau = tau
        # torch.exp only accepts tensors, so tau is wrapped before computing the decay factor.
        # Stored as a buffer so that it follows the module across devices.
        self.register_buffer('_decay', torch.exp(-dt / torch.as_tensor(tau, dtype=dtype)))
        self.thr = thr

        # Input weights, scaled by 1 / sqrt(fan-in).
        self.w_in_var = torch.nn.Parameter(torch.randn(n_in, n_rec, dtype=dtype) / n_in ** 0.5)

        # Recurrent weights, scaled by 1 / sqrt(fan-in).
        self.w_rec_var = torch.nn.Parameter(torch.randn(n_rec, n_rec, dtype=dtype) / n_rec ** 0.5)
        self.register_buffer('recurrent_disconnect_mask', torch.eye(n_rec, dtype=torch.bool))

    @property
    def w_in_val(self):
        '''Input weights actually used in forward (the raw parameter).'''
        return self.w_in_var

    @property
    def w_rec_val(self):
        '''
        Recurrent weights actually used in forward, with the diagonal set to 0 (disconnect autapses).
        Computed on access so that it always reflects the current value of w_rec_var.
        '''
        return torch.where(self.recurrent_disconnect_mask, torch.zeros_like(self.w_rec_var),
                           self.w_rec_var)

    def state_size(self):
        '''Return the size of each entry of the state tuple.'''
        return LightLIFStateTuple(v=self.n_rec, z=self.n_rec)

    def output_size(self):
        '''Return the sizes of the two outputs of forward (spikes, voltages).'''
        return [self.n_rec, self.n_rec]

    def zero_state(self, batch_size, dtype=None, n_rec=None):
        '''
        Build an all-zero initial state.

        :param batch_size: number of sequences processed in parallel
        :param dtype: data type of the state tensors (defaults to the dtype given at construction)
        :param n_rec: number of neurons (defaults to self.n_rec)
        :return: LightLIFStateTuple of zero tensors of shape (batch_size, n_rec)
        '''
        if n_rec is None: n_rec = self.n_rec
        if dtype is None: dtype = self.data_type
        device = self.w_in_var.device

        v0 = torch.zeros(size=(batch_size, n_rec), dtype=dtype, device=device)
        z0 = torch.zeros(size=(batch_size, n_rec), dtype=dtype, device=device)

        return LightLIFStateTuple(v=v0, z=z0)

    def forward(self, inputs, state, scope=None, dtype=torch.float32):
        '''
        Advance the neurons by one time step.

        :param inputs: input tensor, multiplied on the right by the (n_in, n_rec) input weights
        :param state: LightLIFStateTuple holding the previous voltages `v` and spikes `z`
        :param scope: unused, kept from the tensorflow cell signature
        :param dtype: unused, kept from the tensorflow cell signature
        :return: ([new_z, new_v], new_state)
        '''
        # state in tensorflow comes from the RNN cell module. We don't have that in pytorch,
        # so that is replaced by a simple tuple
        thr = self.thr
        z = state.z
        v = state.v
        decay = self._decay

        if self.stop_z_gradients:
            # requires_grad is an attribute, not a method: detach() blocks the gradient
            # while leaving the forward value unchanged.
            z = z.detach()

        # update the voltage
        i_t = torch.matmul(inputs, self.w_in_val) + torch.matmul(z, self.w_rec_val)
        I_reset = z * self.thr * self.dt
        new_v = decay * v + (1 - decay) * i_t - I_reset

        # Spike generation
        v_scaled = (new_v - thr) / thr
        new_z = SpikeFunction(v_scaled, self.dampening_factor)
        new_z = new_z * 1 / self.dt
        new_state = LightLIFStateTuple(v=new_v, z=new_z)
        return [new_z, new_v], new_state

In [None]:
# The typename matches the variable name so that state tuples can be pickled.
LightALIFStateTuple = namedtuple('LightALIFStateTuple', (
    'z',
    'v',
    'b'))

class LightALIF(LightLIF):
    '''
    LIF cell with an adaptive threshold (ALIF), built on top of LightLIF.

    On top of the LightLIF dynamics, each neuron carries an adaptation variable b:
        new_b = decay_b * b + (1 - decay_b) * z,   decay_b = exp(-dt / tau_adaptation)
    and spikes against the adaptive threshold thr + beta * new_b.

    NOTE: forward relies on a callable named `SpikeFunction(v_scaled, dampening_factor)` that is
    not defined in this cell; it has to be defined elsewhere in the notebook.

    :param n_in: number of input neurons
    :param n_rec: number of recurrent neurons
    :param tau: membrane time constant
    :param thr: baseline threshold voltage
    :param dt: time step
    :param dtype: data type of the weights
    :param dampening_factor: parameter to stabilize learning
    :param tau_adaptation: time constant of the threshold adaptation
    :param beta: scaling of the adaptation variable in the threshold
    :param stop_z_gradients: if true, some gradients are stopped to get an equivalence between eprop and bptt
    '''
    def __init__(self, n_in, n_rec, tau=20., thr=0.03, dt=1., dtype=torch.float32, dampening_factor=0.3,
                 tau_adaptation=200., beta=1.6, stop_z_gradients=False):

        # The parent constructor has required arguments; forward everything it knows about.
        super(LightALIF, self).__init__(n_in=n_in, n_rec=n_rec, tau=tau, thr=thr, dt=dt, dtype=dtype,
                                        dampening_factor=dampening_factor,
                                        stop_z_gradients=stop_z_gradients)
        self.tau_adaptation = tau_adaptation
        self.beta = beta
        # torch.exp does not accept python floats; numpy handles the scalar case.
        self.decay_b = float(np.exp(-dt / tau_adaptation))

    def state_size(self):
        '''Return the size of each entry of the state tuple.'''
        return LightALIFStateTuple(v=self.n_rec, z=self.n_rec, b=self.n_rec)

    def output_size(self):
        '''Return the sizes of the three outputs of forward (spikes, voltages, adaptation).'''
        return [self.n_rec, self.n_rec, self.n_rec]

    def zero_state(self, batch_size, dtype=None):
        '''
        Build an all-zero initial state.

        :param batch_size: number of sequences processed in parallel
        :param dtype: data type of the state tensors (defaults to the dtype given at construction)
        :return: LightALIFStateTuple of zero tensors of shape (batch_size, n_rec)
        '''
        if dtype is None: dtype = self.data_type
        device = self.w_in_var.device
        v0 = torch.zeros(size=(batch_size, self.n_rec), dtype=dtype, device=device)
        z0 = torch.zeros(size=(batch_size, self.n_rec), dtype=dtype, device=device)
        b0 = torch.zeros(size=(batch_size, self.n_rec), dtype=dtype, device=device)
        return LightALIFStateTuple(v=v0, z=z0, b=b0)


    def forward(self, inputs, state, scope=None, dtype=torch.float32):
        '''
        Advance the neurons by one time step.

        :param inputs: input tensor, multiplied on the right by the (n_in, n_rec) input weights
        :param state: LightALIFStateTuple holding the previous spikes `z`, voltages `v` and adaptation `b`
        :param scope: unused, kept from the tensorflow cell signature
        :param dtype: unused, kept from the tensorflow cell signature
        :return: ([new_z, new_v, new_b], new_state)
        '''
        z = state.z
        v = state.v
        b = state.b
        decay = self._decay

        # the eligibility traces of b see the spike of the own neuron
        # (new_b is therefore computed before the gradient through z is stopped)
        new_b = self.decay_b * b + (1. - self.decay_b) * z
        thr = self.thr + new_b * self.beta
        if self.stop_z_gradients:
            # requires_grad is an attribute, not a method: detach() blocks the gradient
            # while leaving the forward value unchanged.
            z = z.detach()

        # update the voltage (the reset uses the baseline threshold, not the adaptive one)
        i_t = torch.matmul(inputs, self.w_in_val) + torch.matmul(z, self.w_rec_val)
        I_reset = z * self.thr * self.dt
        new_v = decay * v + (1 - decay) * i_t - I_reset

        # Spike generation against the adaptive threshold
        v_scaled = (new_v - thr) / thr
        new_z = SpikeFunction(v_scaled, self.dampening_factor)
        new_z = new_z * 1 / self.dt

        new_state = LightALIFStateTuple(v=new_v,z=new_z, b=new_b)
        return [new_z, new_v, new_b], new_state

In [None]:
EligALIFStateTuple = namedtuple('EligALIFStateTuple', ('s', 'z', 'z_local', 'r'))

class EligALIF(torch.nn.Module):
    '''
    Adaptive LIF cell exposing the quantities needed by e-prop (eligibility traces and loss gradients).

    The hidden state `s` stacks the voltage v and the threshold adaptation b on its last axis.
    One call to forward computes:
        new_b = decay_b * b + z_local
        new_v = decay * v + inputs @ w_in + z @ w_rec - z * thr * dt
        new_z = 0 while refractory, else SpikeFunction((new_v - thr - beta * new_b) / thr) / dt

    NOTE: compute_z relies on a callable named `SpikeFunction(v_scaled, dampening_factor)` that is
    not defined in this cell; it has to be defined elsewhere in the notebook.

    :param n_in: number of input neurons
    :param n_rec: number of recurrent neurons
    :param tau: membrane time constant (scalar, or one value per recurrent neuron)
    :param thr: baseline threshold voltage (scalar, or one value per recurrent neuron)
    :param dt: time step
    :param dtype: data type of the weights and constants
    :param dampening_factor: parameter to stabilize learning
    :param tau_adaptation: time constant of the threshold adaptation
    :param beta: scaling of the adaptation variable in the threshold
    :param stop_z_gradients: if true, gradients are blocked in the spike transmission
    :param n_refractory: number of refractory time steps after a spike
    '''
    def __init__(self, n_in, n_rec, tau=20., thr=0.03, dt=1., dtype=torch.float32, dampening_factor=0.3,
                 tau_adaptation=200., beta=1.6,
                 stop_z_gradients=False, n_refractory=1):
        super(EligALIF, self).__init__()

        if tau_adaptation is None: raise ValueError("alpha parameter for adaptive bias must be set")
        if beta is None: raise ValueError("beta parameter for adaptive bias must be set")

        self.n_refractory = n_refractory
        self.tau_adaptation = tau_adaptation
        self.beta = beta

        # Scalars are expanded to one value per neuron; everything becomes a tensor of `dtype`
        # because torch.exp and friends do not accept python numbers.
        if np.isscalar(tau): tau = torch.full((n_rec,), float(tau), dtype=dtype)
        if np.isscalar(thr): thr = torch.full((n_rec,), float(thr), dtype=dtype)

        tau = torch.as_tensor(tau, dtype=dtype)
        thr = torch.as_tensor(thr, dtype=dtype)
        dt = torch.as_tensor(dt, dtype=dtype)

        self.dampening_factor = dampening_factor
        self.stop_z_gradients = stop_z_gradients
        self.n_in = n_in
        self.n_rec = n_rec
        self.data_type = dtype

        self._num_units = self.n_rec

        # Constants are registered as buffers so that they follow the module across devices.
        self.register_buffer('dt', dt)
        self.register_buffer('tau', tau)
        self.register_buffer('thr', thr)
        self.register_buffer('_decay', torch.exp(-dt / tau))
        self.register_buffer('decay_b', torch.exp(-dt / torch.as_tensor(tau_adaptation, dtype=dtype)))

        # Input weights, scaled by 1 / sqrt(fan-in).
        self.w_in_var = torch.nn.Parameter(torch.randn(n_in, n_rec, dtype=dtype) / n_in ** 0.5)

        # Recurrent weights, scaled by 1 / sqrt(fan-in).
        self.w_rec_var = torch.nn.Parameter(torch.randn(n_rec, n_rec, dtype=dtype) / n_rec ** 0.5)
        self.register_buffer('recurrent_disconnect_mask', torch.eye(n_rec, dtype=torch.bool))

        self.variable_list = [self.w_in_var, self.w_rec_var]
        self.built = True

    @property
    def w_in_val(self):
        '''Input weights actually used in forward (the raw parameter).'''
        return self.w_in_var

    @property
    def w_rec_val(self):
        '''
        Recurrent weights actually used in forward, with the diagonal set to 0 (disconnect self-connection).
        Computed on access so that it always reflects the current value of w_rec_var.
        '''
        return torch.where(self.recurrent_disconnect_mask, torch.zeros_like(self.w_rec_var),
                           self.w_rec_var)

    def state_size(self):
        '''Return the size of each entry of the state tuple.'''
        return EligALIFStateTuple(s=[self.n_rec, 2],
                                  z=self.n_rec, r=self.n_rec, z_local=self.n_rec)

    def output_size(self):
        '''Return the sizes of the two outputs of forward (spikes, hidden state).'''
        return [self.n_rec, [self.n_rec, 2]]

    def zero_state(self, batch_size, dtype=None, n_rec=None):
        '''
        Build an all-zero initial state.

        :param batch_size: number of sequences processed in parallel
        :param dtype: data type of the state tensors (defaults to the dtype given at construction)
        :param n_rec: number of neurons (defaults to self.n_rec)
        :return: EligALIFStateTuple with s of shape (batch_size, n_rec, 2) and z, z_local, r of
                 shape (batch_size, n_rec)
        '''
        if n_rec is None: n_rec = self.n_rec
        if dtype is None: dtype = self.data_type
        device = self.w_in_var.device

        s0 = torch.zeros(size=(batch_size, n_rec, 2), dtype=dtype, device=device)
        z0 = torch.zeros(size=(batch_size, n_rec), dtype=dtype, device=device)
        z_local0 = torch.zeros(size=(batch_size, n_rec), dtype=dtype, device=device)
        r0 = torch.zeros(size=(batch_size, n_rec), dtype=dtype, device=device)

        return EligALIFStateTuple(s=s0, z=z0, r=r0, z_local=z_local0)

    def compute_z(self, v, b):
        '''
        Compute the spikes from the voltage v and the threshold adaptation b.

        :return: SpikeFunction((v - thr - beta * b) / thr, dampening_factor) / dt
        '''
        adaptive_thr = self.thr + b * self.beta
        v_scaled = (v - adaptive_thr) / self.thr
        z = SpikeFunction(v_scaled, self.dampening_factor)
        z = z * 1 / self.dt
        return z

    def compute_v_relative_to_threshold_values(self,hidden_states):
        '''
        Compute the scaled voltage (v - thr - beta * b) / thr from hidden states whose last axis
        holds (v, b).
        '''
        v = hidden_states[..., 0]
        b = hidden_states[..., 1]

        adaptive_thr = self.thr + b * self.beta
        v_scaled = (v - adaptive_thr) / self.thr
        return v_scaled

    def forward(self, inputs, state, scope=None, dtype=torch.float32, stop_gradient=None):
        '''
        Advance the neurons by one time step.

        :param inputs: input tensor, multiplied on the right by the (n_in, n_rec) input weights
        :param state: EligALIFStateTuple of the previous time step
        :param scope: unused, kept from the tensorflow cell signature
        :param dtype: unused, kept from the tensorflow cell signature
        :param stop_gradient: overrides self.stop_z_gradients when not None
        :return: ([new_z, new_s], new_state)
        '''
        decay = self._decay
        z = state.z
        z_local = state.z_local
        s = state.s
        r = state.r
        v, b = s[..., 0], s[..., 1]

        # This stop_gradient allows computing e-prop with auto-diff.
        #
        # needed for correct auto-diff computation of gradient for threshold adaptation
        # stop_gradient: forward pass unchanged, gradient is blocked in the backward pass
        use_stop_gradient = stop_gradient if stop_gradient is not None else self.stop_z_gradients
        if use_stop_gradient:
            z = z.detach()

        new_b = self.decay_b * b + z_local # threshold update does not have to depend on the stopped-gradient-z, it's local

        i_t = torch.matmul(inputs, self.w_in_val) + torch.matmul(z, self.w_rec_val) # gradients are blocked in spike transmission
        I_reset = z * self.thr * self.dt
        new_v = decay * v + i_t - I_reset

        # Spike generation: neurons that are still refractory cannot spike.
        is_refractory = r > 0
        zeros_like_spikes = torch.zeros_like(z)
        new_z = torch.where(is_refractory, zeros_like_spikes, self.compute_z(new_v, new_b))
        # Same value as new_z; it is carried separately because only `z` is detached at the next step.
        new_z_local = new_z
        new_r = r + self.n_refractory * new_z - 1
        # The refractory counter is not differentiated.
        new_r = torch.clamp(new_r, 0., float(self.n_refractory)).detach()
        new_s = torch.stack((new_v, new_b), dim=-1)

        new_state = EligALIFStateTuple(s=new_s, z=new_z, r=new_r, z_local=new_z_local)
        return [new_z, new_s], new_state


    def compute_eligibility_traces(self, v_scaled, z_pre, z_post, is_rec):
        '''
        Compute the eligibility traces of a batch of sequences.

        Inputs are batch major: v_scaled and z_post are (batch, time, n_post), z_pre is
        (batch, time, n_pre). With alpha = exp(-dt / tau), rho = decay_b and psi the pseudo
        derivative (set to 0 while refractory), the recurrences are
            epsilon_v[t + 1] = alpha * epsilon_v[t] + z_pre[t + 1],      epsilon_v[0] = z_pre[0]
            epsilon_a[t + 1] = (rho - beta * psi[t]) * epsilon_a[t] + psi[t] * epsilon_v[t],  epsilon_a[0] = 0
            e_trace[t]       = psi[t] * (epsilon_v[t] - beta * epsilon_a[t])

        :param v_scaled: scaled voltages of the post-synaptic neurons
        :param z_pre: spikes of the pre-synaptic neurons
        :param z_post: spikes of the post-synaptic neurons
        :param is_rec: if truthy, the entries with pre index == post index are set to 0
        :return: (e_trace, epsilon_v, epsilon_a, psi); the first three are (batch, time, n_pre, n_post)
                 and psi is (batch, time, n_post)
        '''
        n_neurons = z_post.shape[2]
        rho = self.decay_b
        beta = self.beta
        alpha = self._decay
        n_ref = self.n_refractory

        # everything should be time major
        z_pre = z_pre.permute(1, 0, 2)
        v_scaled = v_scaled.permute(1, 0, 2)
        z_post = z_post.permute(1, 0, 2)
        n_time = z_post.shape[0]

        psi_no_ref = self.dampening_factor / self.thr * torch.clamp(1. - torch.abs(v_scaled), min=0.)

        # Refractory counter: set to n_ref - 1 after a spike, then counts down to 0.
        # Entry t depends on the spikes up to t - 1.
        count = torch.zeros_like(z_post[0], dtype=torch.int32)
        counts = [count]
        for t in range(n_time - 1):
            count = torch.where(z_post[t] > 0,
                                torch.full_like(count, n_ref - 1),
                                torch.clamp(count - 1, min=0))
            counts.append(count)
        refractory_count = torch.stack(counts, dim=0)

        is_refractory = refractory_count > 0
        psi = torch.where(is_refractory, torch.zeros_like(psi_no_ref), psi_no_ref)

        # Low-pass filter of the pre-synaptic spikes, one copy per post-synaptic neuron.
        eps_v = z_pre.new_ones((1, 1, n_neurons)) * z_pre[0][:, :, None]
        eps_v_steps = [eps_v]
        for t in range(1, n_time):
            eps_v = alpha * eps_v + z_pre[t][:, :, None]
            eps_v_steps.append(eps_v)
        epsilon_v = torch.stack(eps_v_steps, dim=0)

        # Eligibility vector component of the threshold adaptation.
        eps_a = torch.zeros_like(eps_v_steps[0])
        eps_a_steps = [eps_a]
        for t in range(n_time - 1):
            psi_t = psi[t][:, None, :]
            eps_a = (rho - beta * psi_t) * eps_a + psi_t * eps_v_steps[t]
            eps_a_steps.append(eps_a)
        epsilon_a = torch.stack(eps_a_steps, dim=0)

        e_trace = psi[:, :, None, :] * (epsilon_v - beta * epsilon_a)

        # back to batch major
        e_trace = e_trace.permute( 1, 0, 2, 3 )
        epsilon_v = epsilon_v.permute( 1, 0, 2, 3 )
        epsilon_a = epsilon_a.permute( 1, 0, 2, 3 )
        psi = psi.permute( 1, 0, 2 )

        if is_rec:
            # Self-connections are disconnected, so their traces are zeroed (out of place, to
            # leave the tensors recorded by autograd untouched).
            off_diagonal = 1 - torch.eye(n_neurons, dtype=e_trace.dtype, device=e_trace.device)[None, None, :, :]
            e_trace = e_trace * off_diagonal
            epsilon_v = epsilon_v * off_diagonal
            epsilon_a = epsilon_a * off_diagonal

        return e_trace, epsilon_v, epsilon_a, psi

    def compute_loss_gradient(self, learning_signal, z_pre, z_post, v_post, b_post,
                              decay_out=None,zero_on_diagonal=None):
        '''
        Compute the e-prop estimate of the loss gradient with respect to a weight matrix.

        :param learning_signal: learning signals, (batch, time, n_post)
        :param z_pre: spikes of the pre-synaptic neurons, (batch, time, n_pre)
        :param z_post: spikes of the post-synaptic neurons, (batch, time, n_post)
        :param v_post: voltages of the post-synaptic neurons, (batch, time, n_post)
        :param b_post: threshold adaptations of the post-synaptic neurons, (batch, time, n_post)
        :param decay_out: if not None, the eligibility traces are low-pass filtered over time with
                          filtered[t] = decay_out * filtered[t - 1] + (1 - decay_out) * e_trace[t]
        :param zero_on_diagonal: forwarded as `is_rec` to compute_eligibility_traces
        :return: (gradient of shape (n_pre, n_post), e_trace, epsilon_v, epsilon_a)
        '''
        thr_post = self.thr + self.beta * b_post
        v_scaled = (v_post - thr_post) / self.thr

        e_trace, epsilon_v, epsilon_a, _ = self.compute_eligibility_traces(v_scaled, z_pre, z_post, zero_on_diagonal)

        if decay_out is not None:
            e_trace_time_major = e_trace.permute(1, 0, 2, 3)
            filtered_e = torch.zeros_like(e_trace_time_major[0])
            filtered_steps = []
            for e in e_trace_time_major:
                filtered_e = filtered_e * decay_out + e * (1 - decay_out)
                filtered_steps.append(filtered_e)
            # Stacking on dim 1 puts the result straight back in batch major layout.
            e_trace = torch.stack(filtered_steps, dim=1)

        gradient = torch.einsum('btj,btij->ij', learning_signal, e_trace)
        return gradient, e_trace, epsilon_v, epsilon_a