In [1]:
import sys
sys.path.append('../')
from Code.envs.SuccessiveLookups import SuccessiveLookups
from Code.train import train, OptWrapper
import torch
import torch.nn as nn

In [2]:

# Experiment configuration for this notebook. The cells below read it to build
# the task environments, the network and the optimizer, and it is also passed
# as a whole to build_standard_loop.
spec = {
    # Settings for the 'control' population (its n_neurons is also read when
    # computing out_neuron_size).
    'control_config': {
        'neuron_type': 'LIF',
        'n_neurons': 100,
        'TAU': 5,
        '1-beta': False, # feature used in bellec et al. (should be deactivated for better results)
        'SPIKE_FN': 'bellec' #surrogate gradient function, either 'bellec' (better) or 'superspike'
    },
    # Settings for the 'mem' population.
    'mem_config': {
        'neuron_type': 'Cooldown', #Supported types: 'Adaptive', 'NoReset', 'Cooldown', 'LIF'
        'n_neurons': 10,
        'TAU': 25,
        '1-beta': False,
        'SPIKE_FN': 'bellec',
        # NOTE(review): GAMMA / TAU_THR are consumed inside the project code;
        # their exact meaning is not visible from this notebook.
        'GAMMA': 0.27,
        'TAU_THR': 100,
    },
    # Task settings forwarded to SuccessiveLookups: 'n_sequence' is used for the
    # training problem, 'val_sequence' for the validation problem, and
    # 'round_length' for both.
    'exp_config': {
        'n_sequence': 30,
        'val_sequence': 100,
        'round_length': 50
    },
    'lr': 0.001,
    'lr_decay': 1,
    'iterations': 5000,
    'batch_size': 64,
    'architecture': '1L',
    # Use the GPU when one is usable, otherwise fall back to the CPU so that the
    # notebook still runs from a fresh kernel on machines without CUDA.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'validation_frequency': 100
}

In [3]:
# Turn the configured device string into a torch.device once; it is reused for
# both task environments and for the model.
DEVICE = torch.device(spec['device'])

# Training environment: first argument is the configured number of iterations,
# sequence setting is 'n_sequence'.
train_problem = SuccessiveLookups(
    spec['iterations'],
    spec['batch_size'],
    spec['exp_config']['n_sequence'],
    spec['exp_config']['round_length'],
    DEVICE,
)

# Validation environment: same batch size and round length, but built with 1 as
# the first argument and with 'val_sequence' instead of 'n_sequence'.
val_problem = SuccessiveLookups(
    1,
    spec['batch_size'],
    spec['exp_config']['val_sequence'],
    spec['exp_config']['round_length'],
    DEVICE,
)

# The training problem reports the sizes needed to build the network.
n_in, n_out, input_rate = train_problem.get_infos()




In [4]:
from Code.networks import OuterWrapper, SequentialNetwork, ParallelNetwork, SequenceWrapper, BaseNeuron, build_standard_loop

# Recurrent core, assembled by the project helper from the spec and input size.
loop = build_standard_loop(spec, n_in)

# Combined neuron count of the control and memory populations.
out_neuron_size = sum(
    spec[section]['n_neurons'] for section in ('control_config', 'mem_config')
)

# Layer description handed to SequentialNetwork.
# NOTE(review): each list entry looks like [source layer names, module,
# synapse type or None] -- confirm against SequentialNetwork.
outer = {
    'input': n_in,
    'loop': [
        ['input'],
        SequenceWrapper(ParallelNetwork(loop)),
        None,
    ],
    'output': [
        ['loop'],
        BaseNeuron(n_out, None),
        nn.Linear,
    ],
}

model = OuterWrapper(SequentialNetwork(outer))
# Kept as the last expression so the cell still displays the module structure.
model.to(DEVICE)

OuterWrapper(
  (model): SequentialNetwork(
    (layers): ModuleDict(
      (loop): SequenceWrapper(
        (model): ParallelNetwork(
          (layers): ModuleDict(
            (control_synapse): Linear(in_features=168, out_features=100, bias=True)
            (control): LIFNeuron()
            (mem_synapse): Linear(in_features=168, out_features=10, bias=True)
            (mem): CooldownNeuron()
            (output): BaseNeuron()
          )
        )
      )
      (output_synapse): Linear(in_features=110, out_features=8, bias=True)
      (output): BaseNeuron()
    )
  )
)

In [5]:
# Optimizer wrapper over all model parameters, using the learning rate and
# decay factor from the spec.
optimizer = OptWrapper(
    model.parameters(),
    spec['lr'],
    spec['lr_decay'],
    2500,  # NOTE(review): hard-coded; presumably the lr-decay interval -- confirm in OptWrapper
)

In [None]:
# Run training, validating at the interval configured in the spec.
train(
    train_problem,
    val_problem,
    optimizer,
    model,
    None,  # NOTE(review): meaning of this positional argument is not visible here -- see Code.train.train
    validate_every=spec['validation_frequency'],
)

Val Acc: 12.28% | Val Time: 4.0s | Time per it: 3.6s
It:   20 | Loss: 2.082 | Acc: 13.03%
It:   40 | Loss: 2.080 | Acc: 12.49%
It:   60 | Loss: 2.078 | Acc: 12.08%
