In [1]:
import sys
sys.path.append('../')
import torch
import torch.nn as nn
import torch.optim as optim
from Code import Neurons
from Code.ANN import lstmPolicyPredictor, FullyConnected
from Code.envs.MountainCar import LookupPolicy, PassiveEnv
from Code.SNN import DynNetwork, SequenceWrapper
import time
from collections import OrderedDict

In [2]:
#TODO: output reset mechanism

# Neuron hyper-parameter sets referenced by the layer specifications below.
# NOTE(review): the meaning of each key is defined by Code.Neurons, which is
# not visible in this notebook — confirm before tuning.

# Shared by the LIF layers and the 'mem' cooldown layer.
config = dict(
    ALPHA=0.7,
    BETA=0.9,  # author's note: 0.95 (alternative value, unconfirmed)
    RESET_ZERO=False,
    DECODING='potential',
    SPIKE_FN='ss',
)

# Differs from `config` only in BETA; used by the 'short_mem' layer.
secondconfig = dict(
    ALPHA=0.7,
    BETA=0.5,  # author's note: 0.95 (alternative value, unconfirmed)
    RESET_ZERO=False,
    DECODING='potential',
    SPIKE_FN='ss',
)

# ALPHA and BETA are both zero here; used by the output layer.
outconfig = dict(
    ALPHA=0,
    BETA=0,
    RESET_ZERO=False,
    DECODING='potential',
    SPIKE_FN='ss',
)


# Variant of the network with a single memory layer. No cell visible in this
# notebook references it (the model is built from `architecture`).
# Entry layout is presumably [layer size, names of source layers, neuron class,
# neuron config] — TODO confirm against DynNetwork in Code.SNN.
# 'pre_mem', 'mem' and 'post_mem' all point at the same `config` dict object.
architecture1 = OrderedDict()
architecture1['input'] = [1]
architecture1['pre_mem'] = [64, ['input', 'mem'], Neurons.LIFNeuron, config]
architecture1['mem'] = [32, ['pre_mem'], Neurons.CooldownNeuron, config]
architecture1['post_mem'] = [64, ['input', 'mem'], Neurons.LIFNeuron, config]
architecture1['output'] = [1, ['post_mem'], Neurons.OutputNeuron, outconfig]

# Layer specification handed to DynNetwork when the model is built.
# Entry layout is presumably [layer size, names of source layers, neuron class,
# neuron config, connection module class] — TODO confirm against DynNetwork.
# Two cooldown layers ('mem' and 'short_mem') are listed as sources of both
# 'pre_mem' and 'post_mem'; they differ only in the config they are given.
architecture = OrderedDict()
architecture['input'] = [1]
architecture['pre_mem'] = [64, ['input', 'mem', 'short_mem'], Neurons.LIFNeuron, config, nn.Linear]
architecture['mem'] = [32, ['pre_mem'], Neurons.CooldownNeuron, config, nn.Linear]
architecture['short_mem'] = [32, ['pre_mem'], Neurons.CooldownNeuron, secondconfig, nn.Linear]
architecture['post_mem'] = [64, ['input', 'mem', 'short_mem'], Neurons.LIFNeuron, config, nn.Linear]
architecture['output'] = [1, ['post_mem'], Neurons.OutputNeuron, outconfig, nn.Linear]


# Second argument given to DynNetwork when the model is constructed.
SIM_TIME = 1
# Argument to env.getBatch during training; also passed to SequenceWrapper.
BATCH_SIZE = 64

In [3]:
# Seed torch's global RNG before anything is constructed so that torch-based
# random initialisation is repeatable under Restart Kernel -> Run All.
# NOTE(review): presumably this covers the nn.Linear modules named in
# `architecture`; if PassiveEnv draws its batches from another RNG
# (numpy / random), that one needs seeding as well — not visible from here.
SEED = 0
torch.manual_seed(SEED)

# CPU device handed to the environment, the sequence wrapper and the teacher.
device = torch.device('cpu')

# Source of training batches (queried through env.getBatch in the training loop).
env = PassiveEnv(device)

# Alternative models that were tried in place of the spiking network:
#model = lstmPolicyPredictor(1,32,64)
#model = FullyConnected([1, 128, 128, 1])

# Network built from the `architecture` spec, then wrapped for training.
model_raw = DynNetwork(architecture, SIM_TIME)
# NOTE(review): the meaning of the trailing positional `False` is defined by
# SequenceWrapper — confirm against Code.SNN.
model = SequenceWrapper(model_raw, BATCH_SIZE, device, False)

# Not used by any cell visible in this notebook.
teacher = LookupPolicy(device)




In [4]:
# Adam over every parameter exposed by the wrapped model.
# (The original trailing note '0.000011e-6' looks like two earlier learning
# rates, 0.00001 and 1e-6, run together — unconfirmed.)
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)
# Squared error kept per element (reduction='none') so the training loop can
# weight it by its mask before averaging.
mse = nn.MSELoss(reduction='none')

In [5]:
# Scale factors applied to the network input and output.
# NOTE(review): 0.52 and 0.0234 are taken verbatim from the original loop;
# their origin is not visible here (presumably input / target magnitudes) —
# confirm. The plotting cell uses the same two values.
INPUT_SCALE = 0.52
OUTPUT_SCALE = 0.0234
N_ITERATIONS = 3000

start = time.time()

try:
    for i in range(N_ITERATIONS):
        model.zero_grad()
        inputs, targets, mask = env.getBatch(BATCH_SIZE)

        # Every 100 iterations, abort if any weight has become NaN.
        if i % 100 == 0:
            for p in model.parameters():
                if torch.isnan(p).any():
                    raise RuntimeError('Corrupted Model')

        outputs, _ = model(inputs / INPUT_SCALE)

        # Mask-weighted mean of the element-wise squared error.
        loss = (mse(outputs.squeeze(dim=2) * OUTPUT_SCALE, targets) * mask).sum() / mask.sum()
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            # Columns: loss, loss divided by the variance of this batch's
            # targets, iteration index.
            # NOTE(review): the variance is taken over all targets, whether or
            # not `mask` selects them — confirm that is intended.
            print(loss.item(), (loss / targets.view(-1).var()).item(), i)
finally:
    # Runs on normal completion, on KeyboardInterrupt and on the NaN abort, so
    # the elapsed time is reported even when training is stopped by hand.
    print('Total time: ', time.time() - start)



0.0006075938581489027 1.1876280307769775 0
0.000551170960534364 1.0760575532913208 10
0.000534778693690896 0.9697951674461365 20
0.0005007761646993458 0.9358938932418823 30
0.00045337475603446364 0.8922518491744995 40
0.0004190044419374317 0.8155099749565125 50
0.00041013999725691974 0.8276923298835754 60
0.0004053821903653443 0.7553707957267761 70
0.0003979110042564571 0.7828456163406372 80
0.00039412942714989185 0.7704535722732544 90
0.00038543003029190004 0.7238419055938721 100
0.0003818403056357056 0.7543662190437317 110
0.00036740131326951087 0.7336852550506592 120
0.0003707163850776851 0.7311034798622131 130
0.0003630129504017532 0.7385413646697998 140
0.0003764972498174757 0.7359243035316467 150
0.0003733926569111645 0.7392640113830566 160
0.0003569390100892633 0.6853536367416382 170
0.0003670275618787855 0.7248878479003906 180
0.00035694229882210493 0.6760374903678894 190
0.000368733424693346 0.6836230754852295 200
0.00036111805820837617 0.6809335350990295 210
0.000371309404727

KeyboardInterrupt: 

In [None]:
#torch.save(model, '../models/rsnn_passive')

In [None]:
from matplotlib import pyplot as plt
#model = torch.load('../models/snn_passive3')
%matplotlib


inputs, targets, mask = env.getBatch(1)
outputs, _ = model(inputs/0.52)
plt.close()
plt.plot(inputs[:, 0, 0], targets)
plt.plot(inputs[:, 0, 0], outputs.squeeze().detach()*0.0234)