In [1]:
import sys
sys.path.append('../')
from Code.envs.SequentialMNIST import SequentialMNIST
from Code.train import train, OptWrapper
import torch
import torch.nn as nn


In [2]:

# Experiment configuration: every tunable hyper-parameter of this run lives here.
spec = {
    # Recurrent "control" population (constructed as neuron_lookup[neuron_type](n_neurons, config)).
    'control_config': {
        'neuron_type': 'LIF',
        'n_neurons': 120,
        'TAU': 20,
        '1-beta': False,
        'SPIKE_FN': 'bellec'
    },
    # Recurrent "mem" population (constructed the same way as control_config).
    'mem_config': {
        'neuron_type': 'Cooldown',
        'n_neurons': 100,
        'TAU': 700,
        '1-beta': False,
        'SPIKE_FN': 'bellec',
        'GAMMA': 0.27,
        'TAU_THR': 10000,
    },
    'experiment': 'SequentialMNIST',
    'lr': 0.001,
    'lr_decay': 0.9,
    'iterations': 36000,
    'batch_size': 128,
    # '1L' or '2L' - selects the recurrent loop wiring in the model-building cell.
    'architecture': '1L',
    'device': 'cuda',
    'validation_frequency': 100,
    # Seed for torch's RNGs so torch-based randomness (e.g. nn.Linear weight init)
    # is reproducible across Restart & Run All.
    'seed': 0,
}

# Seed torch on all devices before any data or model is created.
# NOTE(review): if SequentialMNIST shuffles with numpy/random, seed those RNGs too.
torch.manual_seed(spec['seed'])

In [3]:
# Build the training and validation problems on the configured device.
DEVICE = torch.device(spec['device'])
# Relative path handed to SequentialMNIST; the same string the first cell appends to sys.path.
DATA_ROOT = '../'

train_problem = SequentialMNIST(spec['iterations'], spec['batch_size'], DEVICE, DATA_ROOT)
# NOTE(review): -1 is passed as the iteration count for validation - presumably
# "no fixed limit / whole validation set"; confirm against SequentialMNIST.
val_problem = SequentialMNIST(-1, spec['batch_size'], DEVICE, DATA_ROOT, validate=True)

# Problem metadata used to size the network's input and output layers.
n_in, n_out, input_rate = train_problem.get_infos()


In [4]:
from Code.networks import OuterWrapper, SequenceWrapper, MeanModule, BaseNeuron, ParallelNetwork, LIFNeuron,\
    AdaptiveNeuron, CooldownNeuron, NoResetNeuron, SequentialNetwork


# Map the 'neuron_type' strings used in spec to neuron classes.
neuron_lookup = {
    'LIF': LIFNeuron,
    'Adaptive': AdaptiveNeuron,
    'Cooldown': CooldownNeuron,
    'NoReset': NoResetNeuron,
}

# Each neuron class is constructed as (n_neurons, config_dict).
control_config = spec['control_config']
mem_config = spec['mem_config']
control_neuron = neuron_lookup[control_config['neuron_type']](control_config['n_neurons'], control_config)
mem_neuron = neuron_lookup[mem_config['neuron_type']](mem_config['n_neurons'], mem_config)

# The loop's 'output' node reads both populations, so its size is the sum of theirs.
parallel_out_size = control_neuron.out_size + mem_neuron.out_size
out_neuron = BaseNeuron(parallel_out_size, None)

# Graph format: 'input' is the input width; every other entry is
# [source node names, neuron module, synapse class or None].
# Presumably the network builds an nn.Linear synapse from the concatenated sources
# (the printed repr shows `<name>_synapse: Linear` layers) - confirm in Code.networks.
# Both dicts share the same neuron instances; only the selected one is wrapped below.

# 2L: control reads input + mem, mem reads only control.
loop_2L = {
    'input': n_in,
    'control': [['input', 'mem'], control_neuron, nn.Linear],
    'mem': [['control'], mem_neuron, nn.Linear],
    'output': [['control', 'mem'], out_neuron, None],
}

# 1L: both populations read input, control and mem (all-to-all recurrence).
loop_1L = {
    'input': n_in,
    'control': [['input', 'control', 'mem'], control_neuron, nn.Linear],
    'mem': [['input', 'control', 'mem'], mem_neuron, nn.Linear],
    'output': [['control', 'mem'], out_neuron, None],
}

# Explicit lookup: an unknown/misspelled architecture must fail loudly instead of
# silently falling back to the 2L wiring.
loops_by_architecture = {'1L': loop_1L, '2L': loop_2L}
if spec['architecture'] not in loops_by_architecture:
    raise ValueError(
        f"Unknown architecture {spec['architecture']!r}; "
        f"expected one of {sorted(loops_by_architecture)}"
    )
loop = loops_by_architecture[spec['architecture']]
loop_model = SequenceWrapper(ParallelNetwork(loop))
out_neuron_size = loop_model.out_size

# TODO: this has to be ordered. Dict literals preserve insertion order (Python 3.7+);
# presumably SequentialNetwork applies the layers in key order - confirm before
# reordering the keys.
outer = {
    'input': n_in,
    'loop': [['input'], loop_model, None],
    # NOTE(review): -56 is MeanModule's second argument - presumably selects which
    # time steps are averaged (e.g. the last 56); confirm against MeanModule.
    'mean': [['loop'], MeanModule(out_neuron_size, -56), None],
    'output': [['mean'], BaseNeuron(n_out, None), nn.Linear]
}

model = OuterWrapper(SequentialNetwork(outer))
model.to(DEVICE)

OuterWrapper(
  (model): SequentialNetwork(
    (layers): ModuleDict(
      (loop): SequenceWrapper(
        (model): ParallelNetwork(
          (layers): ModuleDict(
            (control_synapse): Linear(in_features=301, out_features=120, bias=True)
            (control): LIFNeuron()
            (mem_synapse): Linear(in_features=301, out_features=100, bias=True)
            (mem): NoResetNeuron()
            (output): BaseNeuron()
          )
        )
      )
      (mean): MeanModule()
      (output_synapse): Linear(in_features=220, out_features=10, bias=True)
      (output): BaseNeuron()
    )
  )
)

In [5]:
# Optimizer over all model parameters with learning-rate decay from spec.
# NOTE(review): the 4th positional argument (2500) is presumably the number of
# iterations between learning-rate decays - confirm against OptWrapper.
optimizer = OptWrapper(
    model.parameters(),
    spec['lr'],
    spec['lr_decay'],
    2500,
)

In [6]:
# Run training, validating every spec['validation_frequency'] iterations.
# NOTE(review): the 5th positional argument is passed as None; its meaning is
# defined by Code.train.train - confirm before changing it.
train(
    train_problem,
    val_problem,
    optimizer,
    model,
    None,
    validate_every=spec['validation_frequency'],
)

Val Acc: 13.68% | Val Time: 101.2s | Time per it: 51.5s
It:   20 | Loss: 2.315 | Acc: 10.00%
It:   40 | Loss: 2.311 | Acc: 8.98%
It:   60 | Loss: 2.311 | Acc: 9.53%
It:   80 | Loss: 2.308 | Acc: 11.88%
It:  100 | Loss: 2.307 | Acc: 11.25%
Val Acc: 10.07% | Val Time: 99.7s | Time per it: 2.6s


KeyboardInterrupt: 