In [1]:
import sys
sys.path.append('../')
import torch
import torch.nn as nn
import torch.optim as optim
#from Code import Neurons
#from Code.ANN import lstmPolicyPredictor, FullyConnected
from Code.envs.MountainCar import LookupPolicy, PassiveEnv
#from Code.SNN import DynNetwork, SequenceWrapper
import time
from collections import OrderedDict

In [2]:

BATCH_SIZE = 64#512
SIM_TIME = 1
MAX_ITER = 50
USE_JIT = True

device = torch.device('cpu')

env = PassiveEnv(device)

#torch.backends.cudnn.enabled = False


In [3]:
from Code.Networks import Selector, DynNetwork, OuterWrapper, LSTMWrapper, ReLuWrapper, DummyNeuron, make_SequenceWrapper
from Code.NewNeurons import SeqOnlySpike, CooldownNeuron

base_config = {
    'ALPHA': 0,
    'BETA': 0,
    'OFFSET': 2,
    'RESET_ZERO': False,
    'DECODING': 'potential',
    'SPIKE_FN': 'ss'
}

heavyside = {
    **base_config,
    'BETA': 1,
}

mem_loop = OrderedDict([
    ('input', 1),
    ('pre_mem', [['input', 'output'], SeqOnlySpike(128, base_config), nn.Linear]),
    ('output', [['pre_mem'], CooldownNeuron(128, heavyside), nn.Linear]),
])

architecture = OrderedDict([
    ('input', 1),
    ('mem_loop', [['input'], make_SequenceWrapper(DynNetwork(mem_loop), USE_JIT), None]),
    ('post_mem', [['input', 'mem_loop'], SeqOnlySpike(128, base_config), nn.Linear]),
    ('output', [['post_mem'], DummyNeuron(1), nn.Linear]),
])

architecturelstm = OrderedDict([
    ('input', 3),
    ('obs', [['input'], Selector(0, 2), None]),
    ('probe', [['input'], Selector(2, 1), None]),
    ('lstm', [['obs'], LSTMWrapper(2, 128), None]),
    ('post_mem', [['probe', 'lstm'], ReLuWrapper(128), nn.Linear]),
    ('output', [['post_mem'], DummyNeuron(2), nn.Linear]),
])

#TODO: fix output


In [4]:
#144, 150, 137, 150

model = OuterWrapper(DynNetwork(architecture), device, USE_JIT)

#model = (OuterWrapper(DynNetwork(architecturelstm), device, True))




In [5]:
teacher = LookupPolicy(device)
mse = nn.MSELoss(reduction='none')
optimizer = optim.Adam(model.parameters(), lr=1e-4)#0.000011e-6

In [6]:
start = time.time()

for i in range(3000):
    model.zero_grad()
    inputs, targets, mask = env.getBatch(BATCH_SIZE)
    if i%100 == 0:
        for p in model.parameters():
            if torch.isnan(p).any():
                raise Exception('Corrupted Model')
    outputs, _ = model(inputs/0.52)
    loss = (mse(outputs.squeeze(dim=2)*0.0234, targets) * mask).sum() / mask.sum()
    loss.backward()
    optimizer.step()
    if i%10 == 0:
        print(loss.item(), (loss/targets.view(-1).var()).item(), i)

print('Total time: ', time.time()-start)



0.0009567285305820405 1.7927567958831787 0
0.0006994449067860842 1.3951340913772583 10
0.0005513299838639796 1.0600453615188599 20
0.0005037413793615997 0.9186327457427979 30
0.0004624322464223951 0.9252512454986572 40
0.0004421067424118519 0.8451938033103943 50
0.0004104664549231529 0.7728644609451294 60
0.0003841915458906442 0.7240529656410217 70
0.0003398799162823707 0.626100480556488 80
0.0003006798797287047 0.5537849068641663 90
0.00029048335272818804 0.5383895039558411 100
0.0002307336835656315 0.4425481855869293 110
0.0002537095278967172 0.4836163818836212 120
0.00022239568352233618 0.40035516023635864 130
0.00020677111751865596 0.385418176651001 140
0.00018125712813343853 0.35809046030044556 150
0.0001942008238984272 0.37526851892471313 160
0.0001934255415108055 0.37782618403434753 170
0.0001830184191931039 0.34564605355262756 180
0.00019408411753829569 0.3763570785522461 190
0.00017673123511485755 0.32548022270202637 200
0.0001735884288791567 0.3365263342857361 210
0.000169477

In [7]:
#torch.save(model, '../models/rsnn_passive')

In [10]:
from matplotlib import pyplot as plt
#model = torch.load('../models/snn_passive3')
%matplotlib


inputs, targets, mask = env.getBatch(1)
outputs, _ = model(inputs/0.52)
plt.close()
plt.plot(inputs[:, 0, 0], targets)
plt.plot(inputs[:, 0, 0], outputs.squeeze().detach()*0.0234)

Using matplotlib backend: TkAgg


[<matplotlib.lines.Line2D at 0x7fc3e2609e10>]

In [9]:
#TODO: output reset mechanism

config = {
    'ALPHA': 0.7,
    'BETA': 0.9, #0.95
    'RESET_ZERO': False,
    'DECODING': 'potential',
    'SPIKE_FN': 'ss'
}

secondconfig = {
    'ALPHA': 0.7,
    'BETA': 0.5, #0.95
    'RESET_ZERO': False,
    'DECODING': 'potential',
    'SPIKE_FN': 'ss'
}

outconfig = {
    'ALPHA': 0,
    'BETA': 0,
    'RESET_ZERO': False,
    'DECODING': 'potential',
    'SPIKE_FN': 'ss'
}

'''
architecture1 = OrderedDict([
    ('input', [1]),
    ('pre_mem', [64, ['input', 'mem'], Neurons.LIFNeuron, config]),
    ('mem', [32, ['pre_mem'], Neurons.CooldownNeuron, config]), #CooldownNeuron
    ('post_mem', [64, ['input', 'mem'], Neurons.LIFNeuron, config]),
    ('output', [1, ['post_mem'], Neurons.OutputNeuron, outconfig]), #OutputNeuron
])

architecture = OrderedDict([
    ('input', [1]),
    ('pre_mem', [64, ['input', 'mem', 'short_mem'], Neurons.LIFNeuron, config, nn.Linear]),
    ('mem', [32, ['pre_mem'], Neurons.CooldownNeuron, config, nn.Linear]), #CooldownNeuron
    ('short_mem', [32, ['pre_mem'], Neurons.CooldownNeuron, secondconfig, nn.Linear]),
    ('post_mem', [64, ['input', 'mem', 'short_mem'], Neurons.LIFNeuron, config, nn.Linear]),
    ('output', [1, ['post_mem'], Neurons.OutputNeuron, outconfig, nn.Linear]), #OutputNeuron
])


BATCH_SIZE = 64
SIM_TIME = 1


device = torch.device('cpu')

env = PassiveEnv(device)

#model = lstmPolicyPredictor(1,32,64)

#model = FullyConnected([1, 128, 128, 1])

model_raw = DynNetwork(architecture, SIM_TIME)
model = SequenceWrapper(model_raw, BATCH_SIZE, device, False)
'''


"\narchitecture1 = OrderedDict([\n    ('input', [1]),\n    ('pre_mem', [64, ['input', 'mem'], Neurons.LIFNeuron, config]),\n    ('mem', [32, ['pre_mem'], Neurons.CooldownNeuron, config]), #CooldownNeuron\n    ('post_mem', [64, ['input', 'mem'], Neurons.LIFNeuron, config]),\n    ('output', [1, ['post_mem'], Neurons.OutputNeuron, outconfig]), #OutputNeuron\n])\n\narchitecture = OrderedDict([\n    ('input', [1]),\n    ('pre_mem', [64, ['input', 'mem', 'short_mem'], Neurons.LIFNeuron, config, nn.Linear]),\n    ('mem', [32, ['pre_mem'], Neurons.CooldownNeuron, config, nn.Linear]), #CooldownNeuron\n    ('short_mem', [32, ['pre_mem'], Neurons.CooldownNeuron, secondconfig, nn.Linear]),\n    ('post_mem', [64, ['input', 'mem', 'short_mem'], Neurons.LIFNeuron, config, nn.Linear]),\n    ('output', [1, ['post_mem'], Neurons.OutputNeuron, outconfig, nn.Linear]), #OutputNeuron\n])\n\n\nBATCH_SIZE = 64\nSIM_TIME = 1\n\n\ndevice = torch.device('cpu')\n\nenv = PassiveEnv(device)\n\n#model = lstmPolicyPr