In [1]:
import torch
from torch import nn
from mylib import custom
import mylib as ml

In [2]:
class BellecSpike(torch.autograd.Function):
    """Step-function spike non-linearity with a triangular surrogate gradient.

    The forward pass emits 1.0 wherever the input is strictly positive and 0.0
    elsewhere.  The backward pass substitutes the derivative of the step by
    ``SURROGATE_SCALE * max(0, 1 - |input|)``.
    """

    # Scale applied to the triangular pseudo-derivative in ``backward``.
    SURROGATE_SCALE = 0.3

    @staticmethod
    def forward(ctx, input):
        """Return ``(input > 0)`` as a float32 tensor and save ``input`` for backward."""
        ctx.save_for_backward(input)
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        """Scale ``grad_output`` by the triangular pseudo-derivative of the saved input."""
        input, = ctx.saved_tensors
        # clamp(min=0) computes max(0, .) directly: no float32 zero tensor has to
        # be allocated per call, and the result keeps the dtype of ``input``.
        pseudo_derivative = torch.clamp(1 - torch.abs(input), min=0)
        return grad_output * pseudo_derivative * BellecSpike.SURROGATE_SCALE

class NoResetNeuron(custom.CustomModule):
    """Leaky membrane neuron that spikes above 1 and never resets its membrane.

    Every step computes ``mem = BETA * mem + factor * x`` and emits
    ``spike_fn(mem - 1)``.  ``factor`` is chosen from ``params['1-beta']``:
    ``'improved'`` gives ``sqrt(1 - BETA**2)``, any other truthy value gives
    ``1 - BETA`` and a falsy value gives ``1``.
    """

    def __init__(self, params):
        """Read ``params['BETA']`` and ``params['1-beta']`` and set up the neuron."""
        super().__init__()
        self.beta = params['BETA']
        scaling_mode = params['1-beta']
        if scaling_mode == 'improved':
            self.factor = (1 - self.beta ** 2) ** (0.5)
        elif scaling_mode:
            self.factor = (1 - self.beta)
        else:
            self.factor = 1
        self.spike_fn = BellecSpike.apply
        # Not read anywhere in this class; presumably consumed by the
        # surrounding framework -- TODO confirm.
        self.target_var = 1
        self.est_rate = 0.5

    def enter_in_shape(self, in_shape):
        """Create the trainable initial membrane (zeros of ``in_shape``) and record its size."""
        self.initial_mem = nn.Parameter(torch.zeros(in_shape))
        self.in_size = in_shape[0]

    def get_initial_state(self, batch_size):
        """Return the initial state dict, with ``'mem'`` expanded over the batch."""
        batched_mem = self.initial_mem.expand([batch_size, self.in_size])
        return {'mem': batched_mem}

    def get_initial_output(self, batch_size):
        """Return the spikes produced by the (batch-expanded) initial membrane."""
        batched_mem = self.initial_mem.expand([batch_size, self.in_size])
        return self.spike_fn(batched_mem - 1)

    def forward(self, x, h):
        """Advance one step; return ``(spikes, {'mem': new membrane})``."""
        membrane = self.beta * h['mem'] + self.factor * x
        next_state = {'mem': membrane}
        # The firing threshold is fixed at 1.
        return self.spike_fn(membrane - 1), next_state


class LIFNeuron(NoResetNeuron):
    """``NoResetNeuron`` whose membrane is lowered by the emitted spikes after each step."""

    def __init__(self, params):
        """Build the parent neuron, then override ``est_rate`` with 0.06."""
        super().__init__(params)
        self.est_rate = 0.06

    def forward(self, x, h):
        """Run the parent step, then subtract the spikes from the new membrane."""
        spikes, state = super().forward(x, h)
        # Out-of-place subtraction; the spikes are not detached here.
        # TODO (from the original): remove the `.detach()` variant.
        state['mem'] = state['mem'] - spikes
        return spikes, state

class DiscontinuousNeuron(nn.Module):
    """Stateless spiking activation: applies ``spike_fn`` to ``x - THRESHOLD``."""

    def __init__(self, params):
        """Store ``params['THRESHOLD']`` and select ``BellecSpike`` as the spike function."""
        super().__init__()
        self.spike_fn = BellecSpike.apply
        self.threshold = params['THRESHOLD']

    def forward(self, x):
        """Return the spikes for the input shifted down by the threshold."""
        shifted = x - self.threshold
        return self.spike_fn(shifted)

class ConvFSModule(custom.CustomModule): # inherit from conv?
    """3-D convolution applied to a whole sequence at once.

    The ``nn.Conv3d`` is created lazily in ``enter_in_shape`` because its
    number of input channels is taken from the input shape.  All constructor
    arguments are stored and forwarded to ``nn.Conv3d`` after ``in_channels``.
    """

    def __init__(self, *args, **kwargs):
        """Remember the ``nn.Conv3d`` arguments (everything except ``in_channels``)."""
        super().__init__()
        self.args = args
        self.kwargs = kwargs
        self.inner = None  # set by enter_in_shape

    def enter_in_shape(self, in_shape):
        """Build the convolution with ``in_shape[0]`` input channels."""
        # in_channels is passed positionally: combining `in_channels=...` with
        # `*self.args` would bind the first positional argument to in_channels
        # as well and raise "got multiple values for argument 'in_channels'".
        self.inner = nn.Conv3d(in_shape[0], *self.args, **self.kwargs)

    def forward(self, x, _):
        """Convolve a 5-D input; the second argument is ignored.

        Axes 0, 1, 2 are moved to positions 2, 0, 1 before the convolution and
        moved back afterwards.  Assumes ``x`` is laid out as
        (time, batch, channel, H, W), like the notebook's example tensor, so
        that time becomes the depth axis of the Conv3d -- TODO confirm.

        Returns the convolved tensor and an empty tuple in place of a state.
        """
        x = x.permute(1, 2, 0, 3, 4)
        x = self.inner(x)
        return x.permute(2, 0, 1, 3, 4), ()


In [3]:
# Layer factories registered with the `custom` framework.  The `prepare_input`
# and `shape_change` values are interpreted by mylib.

# Recurrent LIF neuron, stepped one time step at a time.
LIF = custom.register_recurrent(
    module_class=LIFNeuron,
    prepare_input='flatten',
    single_step=True,
    shape_change='none',
)

# Fully connected layer: in_shape[0] inputs, `out_size` outputs.
Linear = custom.register_non_recurrent(
    module_class=(lambda in_shape, out_size: nn.Linear(in_shape[0], out_size)),
    prepare_input='flatten',
    shape_change=(lambda in_shape, out_size: (out_size,)),
)

# Stateless threshold activation.
Disc = custom.register_non_recurrent(
    module_class=DiscontinuousNeuron,
    prepare_input='flatten',
    shape_change='none',
)

# 2-D convolution whose input channel count comes from in_shape[0].
Conv2d = custom.register_non_recurrent(
    module_class=(lambda in_shape, *args, **kwargs: nn.Conv2d(in_shape[0], *args, **kwargs)),
    prepare_input='keep',
    shape_change='auto',
)

# 3-D convolution over the full sequence (not single-step).
ConvFS = custom.register_recurrent(
    module_class=ConvFSModule,
    prepare_input='keep',
    single_step=False,
    unroll_full_state=False,
    shape_change='auto',
)
#do fs_conv with dim changes

In [4]:
# Dummy input: `seq` time steps, `batch` samples, each of shape `inp_shape`.
seq = 8
batch = 32
inp_shape = (1, 92, 76)
example = torch.zeros(seq, batch, *inp_shape)
FRAME_STACK = 4
N_OUT = 3

# Parameters handed to LIF(...).
lif_config = {
    'SPIKE_FN': 'bellec',
    'BETA': 0.5,  #0.8,
    '1-beta': False,
}

# Parameters handed to Disc(...).
disc_config = {
    'SPIKE_FN': 'bellec',
    'THRESHOLD': 1,
}

'''conv = ConvPath([
    Conv3({'out_channels': 32, 'kernel_size': 8, 'stride': 4, 'frame_stack': FRAME_STACK}, CONV_NEURON),
    Conv2({'out_channels': 64, 'kernel_size': 4, 'stride': 2}, CONV_NEURON),
    Conv2({'out_channels': 64, 'kernel_size': 3, 'stride': 1}, CONV_NEURON)])'''
# TODO: neurons
# Convolutional front end: each convolution is followed by a Disc activation.
conv_stack = ml.Sequential(
    ConvFS(out_channels=32, kernel_size=(FRAME_STACK, 8, 8), stride=(1, 4, 4)),
    Disc(disc_config),
    Conv2d(out_channels=64, kernel_size=4, stride=2),
    Disc(disc_config),
    Conv2d(out_channels=64, kernel_size=3, stride=1),
    Disc(disc_config),
)

# Last hidden layer: the network input stacked with the layer's own output
# placeholder goes through Linear(512) and LIF (presumably this makes the
# layer recurrent -- confirm against mylib).
# Older spelling: ml.Layer([ll.input, ll.output], [Linear(512), LIF(lif_config)])
ll = ml.Network()
ll.output = ml.Placeholder()
ll.output = ll.input.stack(ll.output).apply(
    Linear(512),
    LIF(lif_config),
)

# Full model description: conv front end, hidden layer, linear read-out.
overall = ml.Sequential(conv_stack, ll, Linear(N_OUT))

In [5]:
# Build the new-style model from the `overall` description for inputs of shape
# `inp_shape` (shape of one sample at one time step, without time/batch axes).
model = overall.make_model(inp_shape)




In [6]:
from factories_old import Network, Linear, Seq, ExecPath, Conv3, Conv2, ConvPath, Disc, LIF, ReLU, LSTM, PotOut


In [7]:

# Reference model described with the legacy `factories_old` API.
# NOTE: Linear, Disc and LIF below are the factories_old versions imported in
# the previous cell; they shadow the mylib registrations of the same names made
# earlier in the notebook, so re-running earlier cells after this point would
# pick up the legacy factories.
CONV_NEURON = Disc(disc_config)  # Seq(LIF(lif_config)) # ReLU()

# Candidate last hidden layers.
ll_rsnn = Seq(
    Network(
        ExecPath(['input', 'output'], [Linear(512), LIF(lif_config)], 'output')
    )
)
ll_snn = Network(
    ExecPath(['input'], [Linear(512), Seq(LIF(lif_config))], 'output')
)
ll_lstm = LSTM(512)
ll_ffann = Network(
    ExecPath(['input'], [Linear(512), ReLU()], 'output')
)
LAST_LAYER = ll_rsnn

# Convolutional front end with the same hyper-parameters as `conv_stack`.
conv = ConvPath([
    Conv3(
        {'out_channels': 32, 'kernel_size': 8, 'stride': 4, 'frame_stack': FRAME_STACK},
        CONV_NEURON,
    ),
    Conv2({'out_channels': 64, 'kernel_size': 4, 'stride': 2}, CONV_NEURON),
    Conv2({'out_channels': 64, 'kernel_size': 3, 'stride': 1}, CONV_NEURON),
])
new_model_fac = Network(
    conv,
    ExecPath(['conv'], [LAST_LAYER, Linear(N_OUT)], 'output'),
)


def make_model():
    """Instantiate the legacy model for inputs of shape ``inp_shape``."""
    return new_model_fac(inp_shape)

In [8]:
# Instantiate the legacy reference model (make_model calls new_model_fac(inp_shape)).
old_model = make_model()

In [9]:
# List every parameter of the legacy model together with its shape.
for param_name, param in old_model.named_parameters():
    print(param_name, param.shape)

paths.conv.0.conv.conv.weight torch.Size([32, 1, 4, 8, 8])
paths.conv.0.conv.conv.bias torch.Size([32])
paths.conv.1.conv.conv.weight torch.Size([64, 32, 4, 4])
paths.conv.1.conv.conv.bias torch.Size([64])
paths.conv.2.conv.conv.weight torch.Size([64, 64, 3, 3])
paths.conv.2.conv.conv.bias torch.Size([64])
paths.output.0.model.paths.output.0.linear.weight torch.Size([512, 3584])
paths.output.0.model.paths.output.0.linear.bias torch.Size([512])
paths.output.0.model.paths.output.1.initial_mem torch.Size([512])
paths.output.1.linear.weight torch.Size([3, 512])
paths.output.1.linear.bias torch.Size([3])


In [10]:
# List every parameter of the new model together with its shape, for
# comparison with the legacy listing.
for param_name, param in model.named_parameters():
    print(param_name, param.shape)

inner.mlist.0.mlist.0.inner.inner.weight torch.Size([32, 1, 4, 8, 8])
inner.mlist.0.mlist.0.inner.inner.bias torch.Size([32])
inner.mlist.0.mlist.2.inner.weight torch.Size([64, 32, 4, 4])
inner.mlist.0.mlist.2.inner.bias torch.Size([64])
inner.mlist.0.mlist.4.inner.weight torch.Size([64, 64, 3, 3])
inner.mlist.0.mlist.4.inner.bias torch.Size([64])
inner.mlist.1.layers.c0.layers.output.mlist.0.inner.weight torch.Size([512, 3584])
inner.mlist.1.layers.c0.layers.output.mlist.0.inner.bias torch.Size([512])
inner.mlist.1.layers.c0.layers.output.mlist.1.inner.initial_mem torch.Size([512])
inner.mlist.2.inner.weight torch.Size([3, 512])
inner.mlist.2.inner.bias torch.Size([3])


In [11]:
# Transfer the legacy model's weights into the new model.  Parameters are
# matched purely by position in named_parameters(); the two listings printed
# by the previous cells show the same order and shapes.
new_named_params = list(model.named_parameters())
old_named_params = list(old_model.named_parameters())

# Fail loudly instead of silently ignoring surplus parameters on either side.
if len(new_named_params) != len(old_named_params):
    raise ValueError(
        'parameter count mismatch: new model has {}, old model has {}'.format(
            len(new_named_params), len(old_named_params)))

with torch.no_grad():
    for (new_name, new_param), (old_name, old_param) in zip(new_named_params, old_named_params):
        # copy_ writes the values into the new model's own storage, so the two
        # models do not end up sharing (aliasing) the same weight memory.
        new_param.copy_(old_param.reshape(new_param.shape))

In [12]:
# Forward pass through the new model on the dummy input; `h` receives the
# second element of the returned pair.
out, h = model(example)

# The legacy model is called with a named tensor and an explicit initial state.
named_example = example.refine_names('time', 'batch', 'C', 'H', 'W')
old_initial_state = old_model.get_initial_state(batch)
out_old, _ = old_model(named_example, old_initial_state)

  return super(Tensor, self).refine_names(names)


In [13]:
# Equivalence check: True only if every element of the legacy output (tensor
# names stripped with rename(None)) is close to the new model's output under
# torch.isclose's default tolerances.
torch.isclose(out_old.rename(None), out).all()

tensor(True)