In [1]:
import sys
sys.path.append('../')
import torch
import torch.nn as nn
import torch.optim as optim
from Code.envs.GPEnv import PassiveEnv
import time
from collections import OrderedDict

In [2]:

# Run configuration for the whole notebook.
# BATCH_SIZE and MAX_ITER are handed to PassiveEnv below as its first and
# second positional arguments; their exact meaning is defined by PassiveEnv.
# NOTE(review): presumably "sequences per batch" and "time steps per sequence" — confirm.
BATCH_SIZE = 511#512
# NOTE(review): SIM_TIME is not referenced by any cell visible in this notebook.
SIM_TIME = 1
MAX_ITER = 50
# Forwarded to make_SequenceWrapper and OuterWrapper when the model is built.
USE_JIT = True

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first tensor allocation.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

env = PassiveEnv(BATCH_SIZE, MAX_ITER, device)

#torch.backends.cudnn.enabled = False


In [3]:
from Code.Networks import Selector, DynNetwork, OuterWrapper, LSTMWrapper, ReLuWrapper, DummyNeuron, make_SequenceWrapper
from Code.NewNeurons import SeqOnlySpike, CooldownNeuron

# Hyper-parameter dictionary handed to the spiking neuron constructors below.
# The meaning of each key is defined by the neuron classes in Code.NewNeurons.
base_config = dict(
    ALPHA=0,
    BETA=0,
    OFFSET=2,
    RESET_ZERO=False,
    DECODING='potential',
    SPIKE_FN='ss',
)

# Same settings as base_config, except BETA is set to 1.
heavyside = dict(base_config, BETA=1)

# Layer descriptions are ordered mappings from layer name to a spec.
# 'input' maps to an integer; every other entry is a three-element list.
# NOTE(review): the list looks like [source layer names, neuron/module,
# connection constructor or None] — confirm against DynNetwork.

# Inner network: 'pre_mem' reads from both 'input' and 'output', and 'output'
# reads from 'pre_mem', so the two layers reference each other.
mem_loop = OrderedDict()
mem_loop['input'] = 2
mem_loop['pre_mem'] = [['input', 'output'], SeqOnlySpike(128, base_config), nn.Linear]
mem_loop['output'] = [['pre_mem'], CooldownNeuron(128, heavyside), nn.Linear]

# Spiking model: 'obs' and 'probe' are Selector views of the 3-wide input,
# 'mem_loop' wraps the inner network above, and 'post_mem' combines the probe
# with the memory output before the 2-unit 'output' layer.
architecture = OrderedDict()
architecture['input'] = 3
architecture['obs'] = [['input'], Selector(0, 2), None]
architecture['probe'] = [['input'], Selector(2, 1), None]
architecture['mem_loop'] = [['obs'], make_SequenceWrapper(DynNetwork(mem_loop), USE_JIT), None]
architecture['post_mem'] = [['probe', 'mem_loop'], SeqOnlySpike(128, base_config), nn.Linear]
architecture['output'] = [['post_mem'], DummyNeuron(2), nn.Linear]


# Same layout as `architecture`, with the memory loop replaced by an
# LSTMWrapper and the spiking post-memory layer replaced by a ReLuWrapper.
architecturelstm = OrderedDict()
architecturelstm['input'] = 3
architecturelstm['obs'] = [['input'], Selector(0, 2), None]
architecturelstm['probe'] = [['input'], Selector(2, 1), None]
architecturelstm['lstm'] = [['obs'], LSTMWrapper(2, 128), None]
architecturelstm['post_mem'] = [['probe', 'lstm'], ReLuWrapper(128), nn.Linear]
architecturelstm['output'] = [['post_mem'], DummyNeuron(2), nn.Linear]

#TODO: fix output


In [4]:
#144, 150, 137, 150
# NOTE(review): the numbers above are not explained anywhere in this notebook — confirm what they record or remove them.

# Build the network described by `architecture` and wrap it in OuterWrapper,
# passing the target device and the USE_JIT flag from the configuration cell.
model = OuterWrapper(DynNetwork(architecture), device, USE_JIT)

# Alternative model built from `architecturelstm`; note it passes a literal
# True as the third argument instead of USE_JIT.
#model = (OuterWrapper(DynNetwork(architecturelstm), device, True))

In [5]:
# Adam step size, named so it can be tuned in one place.
LEARNING_RATE = 1e-3

# Mean squared error averaged over every element of the prediction tensor.
mse = nn.MSELoss(reduction='mean')
# Optimise every parameter exposed by the wrapped model.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Alternative kept for reference: optimise only the 'output_linear' layer, with a smaller step size.
#optimizer = optim.Adam(model.model.layers['output_linear'].parameters(), lr=1e-4)#0.000011e-6

In [12]:
# Training schedule: total number of optimisation steps, and how often the
# parameters are checked for NaNs and the current loss is printed.
N_ITERATIONS = 5000
LOG_EVERY = 100

start = time.time()

for i in range(N_ITERATIONS):
    model.zero_grad()
    inputs, targets = env.getBatch()
    report = i % LOG_EVERY == 0
    if report:
        # Fail fast once any parameter has become NaN; further steps cannot recover.
        if any(torch.isnan(p).any() for p in model.parameters()):
            raise RuntimeError('Corrupted Model')
    outputs, _ = model(inputs)
    # Squash channel 0 with a sigmoid and pass every remaining channel through
    # unchanged. Built out-of-place so no element is left uninitialised, which
    # torch.empty_like would do for any channel beyond index 1.
    processed = torch.cat((torch.sigmoid(outputs[..., :1]), outputs[..., 1:]), dim=-1)
    loss = mse(processed, targets)
    #loss = mse(outputs[..., 1], targets[..., 1])
    loss.backward()
    optimizer.step()
    if report:
        print(loss.item(), i)

# CUDA kernels are launched asynchronously; wait for the queued work to finish
# so the reported wall-clock time covers the final iterations as well.
if device.type == 'cuda':
    torch.cuda.synchronize(device)
print('Total time: ', time.time()-start)




0.06238705292344093 0
0.06155238673090935 100
0.062144119292497635 200
0.06209178641438484 300
0.0633515864610672 400
0.062264688313007355 500
0.061702221632003784 600
0.059004347771406174 700
0.06729067862033844 800
0.06163351610302925 900
0.0598258338868618 1000
0.057840652763843536 1100
0.060749683529138565 1200
0.057178471237421036 1300
0.05777277424931526 1400
0.05582823604345322 1500
0.057751890271902084 1600
0.0569889210164547 1700
0.05970709025859833 1800
0.056200895458459854 1900
0.05615023151040077 2000
0.05575781315565109 2100
0.05669473484158516 2200
0.05375907942652702 2300
0.05325467139482498 2400
0.05234714224934578 2500
0.059537969529628754 2600
0.054924216121435165 2700
0.05745537579059601 2800
0.05450347438454628 2900
0.05202874913811684 3000
0.05298013240098953 3100
0.05402892455458641 3200
0.05610252171754837 3300
0.05217156559228897 3400
0.05297120660543442 3500
0.05259671062231064 3600
0.054004088044166565 3700
0.052552755922079086 3800
0.05084218084812164 3900
0.

In [7]:
from matplotlib import pyplot as plt
#model = torch.load('../models/snn_passive3')
%matplotlib


# Draw one fresh batch and evaluate the model on it. No gradients are needed
# for plotting, so skip building the autograd graph.
inputs, targets = env.getBatch()
with torch.no_grad():
    outputs, _ = model(inputs)

# x axis: input channel 2 at the first index of dimension 1 (the channel the
# architecture routes to its 'probe' layer). Copied to the host once and reused.
probe = inputs[:, 0, 2].cpu()

# Use a dedicated figure so re-running the cell does not draw on top of an
# earlier plot when an interactive backend is active.
fig, ax = plt.subplots()
ax.scatter(probe, targets[:, 0, 1].cpu(), label='Mean_Target')
ax.scatter(probe, outputs[:, 0, 1].cpu(), label='Mean')
ax.scatter(probe, targets[:, 0, 0].cpu(), label='Var_Target')
# Channel 0 goes through a sigmoid, matching the transform applied to it
# before the loss in the training cell.
ax.scatter(probe, torch.sigmoid(outputs[:, 0, 0]).cpu(), label='Var')
ax.set_xlabel('probe input (channel 2)')
ax.set_ylabel('target / prediction')
ax.set_title('Model output vs. target on one batch')
ax.legend()


Using matplotlib backend: TkAgg


<matplotlib.legend.Legend at 0x7f58c8ec9cf8>

In [13]:
#model.save('../models/rsnn_gppred4')


