In [1]:
import sys
sys.path.append('../')
from Code.envs.SequentialMNIST import SequentialMNIST
from Code.train import train, OptWrapper
from Code.everything5 import build_standard_loop
import torch
import torch.nn as nn
from PIL import Image
import numpy as np

In [2]:

# Experiment configuration.
# - control_config / mem_config: each is handed whole to its neuron constructor
#   by build_standard_loop (neuron_type picks the class, n_neurons its size).
# - lr / lr_decay feed the optimiser; iterations / batch_size the data problems.
# - architecture selects the loop topology in build_standard_loop ('1L' or '2L').
spec = dict(
    control_config={
        'neuron_type': 'LIF',
        'n_neurons': 120,
        'TAU': 20,
        '1-beta': False,
        'SPIKE_FN': 'bellec',
    },
    # NOTE(review): GAMMA / TAU_THR are only interpreted by the neuron class
    # chosen via 'neuron_type' — confirm their meaning there.
    mem_config={
        'neuron_type': 'NoReset',
        'n_neurons': 100,
        'TAU': 700,
        '1-beta': False,
        'SPIKE_FN': 'bellec',
        'GAMMA': 0.27,
        'TAU_THR': 10000,
    },
    experiment='SequentialMNIST',
    lr=0.001,
    lr_decay=0.9,
    iterations=36000,
    batch_size=64,
    architecture='1L',
)

In [3]:
# Use the GPU when present, but fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing here.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the torch and numpy RNGs before the model is built (next cell) so that
# weight initialisation and torch/numpy-based sampling are repeatable.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Training and validation problems share the batch size from `spec`.
# NOTE(review): '../' is presumably the repository root used to locate the
# dataset, and -1 presumably means "no iteration limit" — confirm in
# Code.envs.SequentialMNIST.
train_problem = SequentialMNIST(spec['iterations'], spec['batch_size'], DEVICE, '../')
val_problem = SequentialMNIST(-1, spec['batch_size'], DEVICE, '../', validate=True)

# n_in / input_rate feed build_standard_loop; n_out sizes the read-out layer.
n_in, n_out, input_rate = train_problem.get_infos()


In [4]:
from Code.everything6 import OuterWrapper, DynNetwork, ParallelNetwork2, SequenceWrapper, MeanModule, BaseNeuron, ParallelNetwork

from Code.everything6 import LIFNeuron, SeqOnlySpike, AdaptiveNeuron, CooldownNeuron, NoResetNeuron
def build_standard_loop(spec, n_input, input_rate):
    """Build the layer dict for the recurrent control/memory loop.

    NOTE: this notebook-local definition shadows the ``build_standard_loop``
    imported from ``Code.everything5`` in the first cell; it uses the neuron
    classes imported from ``Code.everything6`` just above.

    Args:
        spec: experiment dict. ``spec['control_config']`` and
            ``spec['mem_config']`` each need ``'neuron_type'`` (a key of the
            lookup table below) and ``'n_neurons'``; the whole sub-dict is also
            passed to the neuron constructor. ``spec['architecture']`` selects
            the topology and must be ``'1L'`` or ``'2L'``.
        n_input: size of the external input.
        input_rate: stored next to ``n_input`` in the ``'input'`` entry.

    Returns:
        Dict with keys, in insertion order, 'input', 'control', 'mem', 'output'.
        ``'input'`` maps to ``(n_input, input_rate)``; each other entry is
        ``[source layer names, neuron module, synapse class or None]``.
        With '1L', control and mem both read from input, control and mem.
        With '2L', control reads from input and mem, and mem reads only from
        control. The output neuron reads from control and mem and is sized to
        the sum of their ``out_size``.

    Raises:
        ValueError: if ``spec['architecture']`` is neither '1L' nor '2L'
            (previously any other value silently built the '2L' loop).
        KeyError: if a ``neuron_type`` is not in the lookup table.
    """
    architecture = spec['architecture']
    # Validate before constructing any neurons so a typo fails fast instead of
    # silently training the other topology.
    if architecture not in ('1L', '2L'):
        raise ValueError(f"Unknown architecture {architecture!r}; expected '1L' or '2L'")

    neuron_lookup = {
        'LIF': LIFNeuron,
        'Disc': SeqOnlySpike,
        'Adaptive': AdaptiveNeuron,
        'Cooldown': CooldownNeuron,
        'NoReset': NoResetNeuron,
    }

    def make_neuron(config):
        # Instantiate the neuron class named by config['neuron_type'].
        neuron_type = config['neuron_type']
        if neuron_type not in neuron_lookup:
            raise KeyError(f"Unknown neuron_type {neuron_type!r}; expected one of {sorted(neuron_lookup)}")
        return neuron_lookup[neuron_type](config['n_neurons'], config)

    control_neuron = make_neuron(spec['control_config'])
    mem_neuron = make_neuron(spec['mem_config'])
    out_neuron = BaseNeuron(control_neuron.out_size + mem_neuron.out_size, None)

    # Incoming connections of the two trainable populations. Separate list
    # objects per entry, as in the original literal dicts.
    if architecture == '1L':
        control_sources = ['input', 'control', 'mem']
        mem_sources = ['input', 'control', 'mem']
    else:
        control_sources = ['input', 'mem']
        mem_sources = ['control']

    return {
        'input': (n_input, input_rate),
        'control': [control_sources, control_neuron, nn.Linear],
        'mem': [mem_sources, mem_neuron, nn.Linear],
        'output': [['control', 'mem'], out_neuron, None],
    }

# Inner recurrent loop: control + memory populations from build_standard_loop,
# wrapped in ParallelNetwork2 and SequenceWrapper (see the printed module tree).
loop = build_standard_loop(spec, n_in, input_rate)
loop_model = SequenceWrapper(ParallelNetwork2(loop))
# Width of the loop's output; with the current spec this is 220 = 120 + 100,
# matching the read-out Linear(in_features=220) in the printed repr.
out_neuron_size = loop_model.out_size

# Outer network: input -> loop -> mean -> output (10 classes per the repr).
# TODO: the entries must stay in this order — DynNetwork presumably relies on
# dict insertion order; confirm.
outer = {
    'input': n_in,
    'loop': [['input'], loop_model, None],
    # NOTE(review): -56 is an unexplained magic number (presumably selects the
    # time steps that are averaged) — confirm against MeanModule and move it
    # into `spec`.
    'mean': [['loop'], MeanModule(out_neuron_size, -56), None],
    # The nn.Linear entry becomes `output_synapse` in the printed repr.
    'output': [['mean'], BaseNeuron(n_out, None), nn.Linear]
}

model = OuterWrapper(DynNetwork(outer))
# Last expression: moves the parameters to DEVICE and displays the module tree.
model.to(DEVICE)

{'control': [[('input', 1), ('control', 1), ('mem', 1)], LIFNeuron(), <class 'torch.nn.modules.linear.Linear'>], 'mem': [[('input', 1), ('control', 1), ('mem', 1)], NoResetNeuron(), <class 'torch.nn.modules.linear.Linear'>], 'output': [[('control', 1), ('mem', 1)], BaseNeuron(), None]}


OuterWrapper(
  (model): DynNetwork(
    (layers): ModuleDict(
      (loop): SequenceWrapper(
        (model): ParallelNetwork2(
          (layers): ModuleDict(
            (control): LIFNeuron()
            (control_synapse): Linear(in_features=301, out_features=120, bias=True)
            (mem): NoResetNeuron()
            (mem_synapse): Linear(in_features=301, out_features=100, bias=False)
            (output): BaseNeuron()
          )
        )
      )
      (mean): MeanModule()
      (output_synapse): Linear(in_features=220, out_features=10, bias=True)
      (output): BaseNeuron()
    )
  )
)

In [5]:
# Wrap every model parameter in the project's optimiser helper, using the
# learning-rate settings from the spec cell.
optimizer = OptWrapper(
    model.parameters(),
    spec['lr'],
    spec['lr_decay'],
    2500,  # NOTE(review): magic number, presumably the decay interval in iterations — confirm and move into `spec`
)

In [6]:
# Train on train_problem, reporting running loss/accuracy every 20 iterations
# and validation accuracy every 100 (per the recorded log below).
# Recorded run: interrupted manually at ~400 iterations with validation
# accuracy still ~10%, i.e. chance level for the 10 output classes.
# NOTE(review): the meaning of the final `None` argument is not visible here —
# confirm against Code.train.train.
train(train_problem, val_problem, optimizer, model, None)

Val Acc: 10.92% | Val Time: 97.8s | Time per it: 49.8s
It:   20 | Loss: 2.316 | Acc: 10.00%
It:   40 | Loss: 2.310 | Acc: 10.86%
It:   60 | Loss: 2.307 | Acc: 11.95%
It:   80 | Loss: 2.308 | Acc: 10.78%
It:  100 | Loss: 2.309 | Acc: 10.94%
Val Acc: 11.36% | Val Time: 104.2s | Time per it: 2.6s
It:  120 | Loss: 2.304 | Acc: 9.61%
It:  140 | Loss: 2.307 | Acc: 10.62%
It:  160 | Loss: 2.304 | Acc: 10.47%
It:  180 | Loss: 2.303 | Acc: 13.05%
It:  200 | Loss: 2.303 | Acc: 11.25%
Val Acc: 9.72% | Val Time: 103.0s | Time per it: 2.7s
It:  220 | Loss: 2.306 | Acc: 10.31%
It:  240 | Loss: 2.302 | Acc: 12.27%
It:  260 | Loss: 2.306 | Acc: 10.78%
It:  280 | Loss: 2.305 | Acc: 10.08%
It:  300 | Loss: 2.308 | Acc: 9.30%
Val Acc: 10.33% | Val Time: 97.5s | Time per it: 2.5s
It:  320 | Loss: 2.304 | Acc: 10.00%
It:  340 | Loss: 2.299 | Acc: 10.47%
It:  360 | Loss: 2.309 | Acc: 9.06%
It:  380 | Loss: 2.304 | Acc: 11.25%
It:  400 | Loss: 2.305 | Acc: 11.41%
Val Acc: 9.78% | Val Time: 97.4s | Time per i

KeyboardInterrupt: 

In [None]:
# Visual sanity check: run only the recurrent loop on one freshly sampled
# training batch with logging enabled, then render the input, control and
# memory traces of the first batch element as one grayscale image.
# (Renamed from `input`, which shadowed the Python builtin.)
raster_input = train_problem.make_inputs()
with torch.no_grad():  # inspection only: don't build an autograd graph
    out, _, log = OuterWrapper(model.model.layers.loop).to(DEVICE)(raster_input, logging=True)
ar_length = raster_input.shape[0]
# Traces to stack; log['loop']['output'] could be appended here as well.
array_list = [raster_input, log['control'], log['mem']]
array_list2 = []
for ar in array_list:
    array_list2.append(ar[:, 0])  # all time steps of the first batch element
    # Constant 0.5 column that renders as a mid-gray separator line.
    array_list2.append(torch.ones((ar_length, 1), device=raster_input.device) * 0.5)
# Drop the trailing separator, transpose so time runs horizontally, and invert
# so a value of 1 is drawn black. Clamp first: values outside [0, 1] would
# otherwise wrap around when cast to uint8.
big_ar = (1 - torch.cat(array_list2[:-1], dim=1).detach().clamp(0, 1)).t() * 255
img = Image.fromarray(big_ar.cpu().numpy().astype(np.uint8), 'L')


In [None]:


# Display the raster built above: one band per trace (input, control, mem),
# separated by gray lines, with time along the horizontal axis.
img

In [None]:
# Mean of all logged control-population values (the average firing rate if
# they are 0/1 spikes).
torch.mean(log['control'])


In [None]:
# Display the DynNetwork held inside the OuterWrapper.
model.model