In [1]:
import torch
from torch import nn
from mylib import custom

In [2]:
class BellecSpike(torch.autograd.Function):
    """Heaviside spike function with a triangular surrogate (pseudo-)derivative.

    Forward: returns ``1.0`` where ``input > 0`` and ``0.0`` elsewhere, as a
    float32 tensor (via ``.float()``).

    Backward: replaces the (zero almost everywhere) true derivative with
    ``DAMPENING * max(0, 1 - |input|)``. This is a triangle centred on 0 that
    vanishes for ``|input| >= 1``. The name suggests the Bellec et al.
    pseudo-derivative; NOTE(review): confirm the reference.

    Use via ``BellecSpike.apply(x)``.
    """

    # Scale applied to the surrogate gradient (the value this code has always used).
    DAMPENING = 0.3

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # clamp(min=0) replaces torch.max(torch.zeros([1]), ...). This avoids a
        # per-call allocation. The surrogate also keeps ``input``'s dtype and
        # shape, instead of being promoted/broadcast against a 1-D float32 zero.
        surrogate = torch.clamp(1 - torch.abs(input), min=0)
        return grad_output * surrogate * BellecSpike.DAMPENING

class NoResetNeuron(custom.CustomModule):
    """Leaky integrator that emits spikes but never resets its membrane.

    Per step the membrane evolves as ``mem = beta * mem_prev + factor * x``.
    A spike is emitted (via ``BellecSpike``) wherever ``mem - THRESHOLD > 0``.

    Args:
        params: dict from which these keys are read:
            * ``'BETA'``: membrane decay factor ``beta``.
            * ``'1-beta'``: input-scaling mode. ``'improved'`` gives
              ``factor = sqrt(1 - beta**2)``; any other truthy value gives
              ``factor = 1 - beta``; a falsy value gives ``factor = 1``.
    """

    # Firing threshold, shared by ``forward`` and ``get_initial_output``.
    # LIFNeuron's reset subtracts the 0/1 spike, which matches a threshold of 1.
    THRESHOLD = 1

    def __init__(self, params):
        super().__init__()
        self.beta = params['BETA']
        scaling_mode = params['1-beta']
        if scaling_mode == 'improved':
            # For i.i.d. unit-variance input, the stationary membrane variance is
            # factor**2 / (1 - beta**2), which this choice makes equal to 1.
            self.factor = (1 - self.beta ** 2) ** 0.5
        elif scaling_mode:
            self.factor = 1 - self.beta
        else:
            self.factor = 1
        self.spike_fn = BellecSpike.apply
        # Not read in this class; presumably consumed by mylib.custom — confirm.
        self.target_var = 1
        self.est_rate = 0.5

    def enter_in_shape(self, in_shape):
        """Create the learnable initial membrane potential, shaped ``in_shape``."""
        self.initial_mem = nn.Parameter(torch.zeros(in_shape), requires_grad=True)

    def get_initial_state(self, batch_size):
        """Return the initial state dict: ``initial_mem`` expanded over the batch.

        NOTE(review): ``self.in_size`` is not assigned in this class; it is
        presumably provided by ``custom.CustomModule`` — confirm.
        """
        return {'mem': self.initial_mem.expand([batch_size, self.in_size])}

    def get_initial_output(self, batch_size):
        """Return the spikes produced by thresholding the initial membrane."""
        initial_mem = self.initial_mem.expand([batch_size, self.in_size])
        return self.spike_fn(initial_mem - self.THRESHOLD)

    def forward(self, x, h):
        """Advance the neuron by one time step.

        Args:
            x: input for this step. Its dimension names (``x.names``) are
                refined onto the membrane before it is thresholded.
            h: previous state dict holding the membrane under ``'mem'``.

        Returns:
            ``(spikes, new_h)``, where ``new_h['mem']`` is the updated
            (un-reset) membrane.
        """
        mem = self.beta * h['mem'] + self.factor * x
        spikes = self.spike_fn(mem.refine_names(*x.names) - self.THRESHOLD)
        return spikes, {'mem': mem}


class LIFNeuron(NoResetNeuron):
    """``NoResetNeuron`` plus a subtractive reset after every spike.

    After the parent's update, the emitted 0/1 spike is subtracted from the
    membrane. Gradients flow through this reset: the spike is not detached.
    """

    def __init__(self, params):
        super().__init__(params)
        # Replaces the parent's default estimate of 0.5.
        self.est_rate = 0.06

    def forward(self, x, h):
        spikes, state = super().forward(x, h)
        state['mem'] = state['mem'] - spikes
        return spikes, state

class DiscontinuousNeuron(nn.Module):
    """Stateless spiking layer: spikes wherever ``x`` exceeds a fixed threshold.

    Args:
        params: dict; only ``'THRESHOLD'`` is read.
    """

    def __init__(self, params):
        super().__init__()
        self.threshold = params['THRESHOLD']
        self.spike_fn = BellecSpike.apply

    def forward(self, x):
        shifted = x - self.threshold
        return self.spike_fn(shifted)

In [3]:
# Wrap the building blocks as mylib.custom layer factories.
# NOTE(review): the semantics of prepare_input / single_step / shape_change are
# defined in mylib.custom and are not visible in this notebook — confirm there.
LIF = custom.register_recurrent(module_class=LIFNeuron, prepare_input='flatten', single_step=True, shape_change='none')
# Linear is built from in_shape[0] (so it presumably expects the flattened 1-D
# shape) and reports (out_size,) as its output shape.
Linear = custom.register_non_recurrent(module_class=(lambda in_shape, out_size: nn.Linear(in_shape[0], out_size)),
                              prepare_input='flatten', shape_change=(lambda in_shape, out_size: (out_size,)))
Disc = custom.register_non_recurrent(module_class=DiscontinuousNeuron, prepare_input='flatten', shape_change='none')
Conv2d = custom.register_non_recurrent(module_class=nn.Conv2d, prepare_input='keep', shape_change='auto')
# TODO: add an fs_conv layer that handles dimension changes.

In [4]:

# Hyper-parameters for a single LIF layer.
# '1-beta': False leaves the input unscaled (factor = 1 in NoResetNeuron).
# NOTE(review): 'SPIKE_FN' is not read by NoResetNeuron/LIFNeuron, which always
# use BellecSpike; it may be consumed by mylib.custom — confirm.
lif_config = dict([
    ('SPIKE_FN', 'bellec'),
    ('BETA', 0.8),
    ('1-beta', False),
])

net = LIF(lif_config)

In [5]:
# Materialise the layer as a torch module. The displayed repr shows LIFNeuron
# wrapped in a RecurrentWrapper inside an OuterModule.
# NOTE(review): (100,) is presumably the per-sample input shape — confirm in mylib.custom.
net.make_model((100,))

OuterModule(
  (inner): RecurrentWrapper(
    (inner): LIFNeuron()
  )
)