In [1]:
import sys
sys.path.append('../')
import torch
import torch.nn as nn
import torch.optim as optim
from Code.envs.GPEnv import PassiveEnv
import time
from collections import OrderedDict

In [2]:

# Experiment configuration.
BATCH_SIZE = 511  # previously 512
SIM_TIME = 1      # not referenced by any other cell of this notebook
MAX_ITER = 30     # passed to PassiveEnv; presumably the sequence length — TODO confirm
USE_JIT = True
SEED = 0

# Seed torch (CPU and all CUDA devices) so batches and weight initialisation are
# reproducible across runs. NOTE(review): if PassiveEnv samples with numpy or
# `random`, those generators need seeding as well.
torch.manual_seed(SEED)

# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

env = PassiveEnv(BATCH_SIZE, MAX_ITER, device, dims=2)

#torch.backends.cudnn.enabled = False


In [3]:
from Code.Networks import Selector, DynNetwork, OuterWrapper, LSTMWrapper, ReLuWrapper, DummyNeuron, make_SequenceWrapper
from Code.NewNeurons import SeqOnlySpike, CooldownNeuron

# Neuron hyper-parameters shared by the spiking layers. The meaning of each key is
# defined in Code.NewNeurons / Code.Networks (not visible here) — TODO: document there.
base_config = dict(
    ALPHA=0,
    BETA=0,
    OFFSET=2,
    RESET_ZERO=False,
    DECODING='potential',
    SPIKE_FN='ss',
)

# Copy of base_config with BETA switched to 1; used by the CooldownNeuron layer below.
heavyside = dict(base_config, BETA=1)

# Each spec is an OrderedDict mapping a layer name either to an int (for 'input')
# or to [input layer names, neuron module, connection type].
# NOTE(review): the exact contract is owned by DynNetwork — confirm there.
# Entries are inserted one at a time in the original order, so modules are
# constructed (and any random initialisation is consumed) in an unchanged sequence.

# Recurrent spiking memory loop: 'pre_mem' reads both the input and its own 'output'.
mem_loop = OrderedDict()
mem_loop['input'] = 3
mem_loop['pre_mem'] = [['input', 'output'], SeqOnlySpike(128, base_config), nn.Linear]
mem_loop['output'] = [['pre_mem'], CooldownNeuron(128, heavyside), nn.Linear]

# Spiking variant: the 5 input features are split into 'obs' and 'probe'
# (presumably Selector(start, count) — TODO confirm), 'obs' drives the memory loop.
architecture = OrderedDict()
architecture['input'] = 5
architecture['obs'] = [['input'], Selector(0, 3), None]
architecture['probe'] = [['input'], Selector(3, 2), None]
architecture['mem_loop'] = [['obs'], make_SequenceWrapper(DynNetwork(mem_loop), USE_JIT), None]
architecture['post_mem'] = [['probe', 'mem_loop'], SeqOnlySpike(128, base_config), nn.Linear]
architecture['output'] = [['post_mem'], DummyNeuron(2), nn.Linear]


# LSTM baseline with the same input split, an LSTM in place of the spiking loop,
# and a ReLU layer in place of the spiking 'post_mem' layer.
architecturelstm = OrderedDict()
architecturelstm['input'] = 5
architecturelstm['obs'] = [['input'], Selector(0, 3), None]
architecturelstm['probe'] = [['input'], Selector(3, 2), None]
architecturelstm['lstm'] = [['obs'], LSTMWrapper(3, 128), None]
architecturelstm['post_mem'] = [['probe', 'lstm'], ReLuWrapper(128), nn.Linear]
architecturelstm['output'] = [['post_mem'], DummyNeuron(2), nn.Linear]

# TODO (original author): fix output — details not recorded.


In [4]:
# Build the model to train. The LSTM baseline is active; to train the spiking
# variant instead, wrap `architecture`:
#     model = OuterWrapper(DynNetwork(architecture), device, USE_JIT)
# Original note (meaning not recorded): 144, 150, 137, 150
model = OuterWrapper(DynNetwork(architecturelstm), device, USE_JIT)

In [5]:
# Squared error averaged over every element of the prediction tensor.
mse = nn.MSELoss(reduction='mean')

# Adam over all model parameters. An earlier experiment instead optimised only
# model.model.layers['output_linear'].parameters() with lr=1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [6]:
# Train the model on freshly sampled batches from the environment. Every
# LOG_EVERY steps the parameters are checked for NaN/Inf and the loss is printed.
N_STEPS = 30000
LOG_EVERY = 100

start = time.perf_counter()

for step in range(N_STEPS):
    model.zero_grad()
    inputs, targets = env.getBatch()
    if step % LOG_EVERY == 0:
        # Fail fast if an earlier update produced non-finite (NaN or Inf) weights.
        if any(not torch.isfinite(p).all() for p in model.parameters()):
            raise RuntimeError('Corrupted Model')
    outputs, _ = model(inputs)
    # Output channel 0 is squashed into (0, 1) with a sigmoid; the remaining
    # channel(s) are regressed directly. Concatenating (instead of writing slices
    # into torch.empty_like) guarantees no uninitialised values reach the loss.
    processed = torch.cat((torch.sigmoid(outputs[..., :1]), outputs[..., 1:]), dim=-1)
    loss = mse(processed, targets)
    loss.backward()
    optimizer.step()
    if step % LOG_EVERY == 0:
        print(loss.item(), step)

print('Total time: ', time.perf_counter() - start)




0.21926622092723846 0
0.1784408688545227 100
0.17320823669433594 200
0.17467765510082245 300
0.1751784384250641 400
0.17785213887691498 500
0.17801135778427124 600
0.17599299550056458 700
0.17127470672130585 800
0.1682662069797516 900
0.16456447541713715 1000
0.16388700902462006 1100
0.16164179146289825 1200
0.1549331694841385 1300
0.1544068306684494 1400
0.15725761651992798 1500
0.1535523235797882 1600
0.15380309522151947 1700
0.14938625693321228 1800
0.1514005959033966 1900
0.15436838567256927 2000
0.15108436346054077 2100
0.1547524780035019 2200
0.15933620929718018 2300
0.1535252332687378 2400
0.14758893847465515 2500
0.150563582777977 2600
0.15436677634716034 2700
0.15482568740844727 2800
0.1470799744129181 2900
0.15140101313591003 3000
0.1524466574192047 3100
0.14984077215194702 3200
0.15640103816986084 3300
0.1531815230846405 3400
0.15134358406066895 3500
0.15324170887470245 3600
0.149271160364151 3700
0.148248553276062 3800
0.15539143979549408 3900
0.14655809104442596 4000
0.147

In [7]:
from matplotlib import pyplot as plt
#model = torch.load('../models/snn_passive3')
%matplotlib


inputs, targets = env.getBatch()
outputs, _ = model(inputs)
plt.scatter(inputs[:, 0, 3].cpu(), targets[:, 0, 1].cpu(), label='Mean_Target')
plt.scatter(inputs[:, 0, 3].cpu(), outputs[:, 0, 1].detach().cpu(), label='Mean')
plt.scatter(inputs[:, 0, 3].cpu(), targets[:, 0, 0].cpu(), label='Var_Target')
plt.scatter(inputs[:, 0, 3].cpu(), torch.sigmoid(outputs[:, 0, 0].cpu()).detach(), label='Var')
plt.legend()


Using matplotlib backend: TkAgg


<matplotlib.legend.Legend at 0x7f678805ac88>

In [11]:
# Uncomment to persist the trained model (`save` is provided by the project's
# OuterWrapper — TODO confirm its on-disk format).
#model.save('../models/lstm_gppred2d1')



In [9]:
# Inspect input feature 3 at index 0 of dim 1 (the x-axis used by the plot cell).
# The recorded output has 30 entries along dim 0, matching MAX_ITER = 30, which
# suggests dim 0 is the sequence/time axis — TODO confirm against PassiveEnv.
inputs[:, 0, 3]

tensor([0.5200, 0.6739, 0.9393, 0.4634, 0.1120, 0.7135, 0.4439, 0.5267, 0.8205,
        0.4589, 0.1338, 0.4874, 0.2105, 0.2922, 0.0298, 0.0973, 0.8926, 0.3645,
        0.7103, 0.0149, 0.8659, 0.6343, 0.3866, 0.3968, 0.8706, 0.2963, 0.0755,
        0.7173, 0.4407, 0.7983], device='cuda:0')

In [10]:

import matplotlib.pyplot as plt
# Importing Axes3D registers the '3d' projection on older matplotlib versions.
from mpl_toolkits.mplot3d import Axes3D

# 3-D view of the evaluation batch: target vs. predicted channel 1 ('Mean') over
# input features 3 and 4 (presumably the two 'probe' coordinates — TODO confirm).
x_coord = inputs[:, 0, 3].cpu()
y_coord = inputs[:, 0, 4].cpu()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x_coord, y_coord, targets[:, 0, 1].cpu(), label='Mean_Target')
ax.scatter(x_coord, y_coord, outputs[:, 0, 1].detach().cpu(), label='Mean')
ax.set_xlabel('inputs[:, 0, 3]')
ax.set_ylabel('inputs[:, 0, 4]')
ax.set_zlabel('channel 1 (Mean)')
ax.legend();

<matplotlib.legend.Legend at 0x7f67801e0470>