In [1]:
import sys
sys.path.append('../')
import torch
import torch.nn as nn
import torch.optim as optim
from Code.envs.GPEnv import PassiveEnv
import time
from collections import OrderedDict

In [2]:

BATCH_SIZE = 511#512
SIM_TIME = 1
MAX_ITER = 50

device = torch.device('cuda')

env = PassiveEnv(BATCH_SIZE, MAX_ITER, device)

#torch.backends.cudnn.enabled = False


In [3]:
from Code.Networks import Selector, DynNetwork, SequenceWrapper, OuterWrapper, LSTMWrapper, ReLuWrapper, DummyNeuron, SequenceTracer
from Code.NewNeurons import SeqOnlySpike, CooldownNeuron, OutputNeuron

base_config = {
    'ALPHA': 0,
    'BETA': 0,
    'RESET_ZERO': False,
    'DECODING': 'potential',
    'SPIKE_FN': 'ss'
}

heavyside = {
    'ALPHA': 0,
    'BETA': 1, #0.95
    'RESET_ZERO': False,
    'SPIKE_FN': 'ss'
}

mem_loop = OrderedDict([
    ('input', 2),
    ('pre_mem', [['input', 'output'], SeqOnlySpike(128, base_config), nn.Linear]),
    ('output', [['pre_mem'], CooldownNeuron(128, heavyside), nn.Linear]),
])

architecture = OrderedDict([
    ('input', 3),
    ('obs', [['input'], Selector(0, 2), None]),
    ('probe', [['input'], Selector(2, 1), None]),
    ('mem_loop', [['obs'], SequenceTracer(SequenceWrapper(DynNetwork(mem_loop))), None]),
    ('post_mem', [['probe', 'mem_loop'], SeqOnlySpike(128, base_config), nn.Linear]),
    ('output', [['post_mem'], DummyNeuron(2), nn.Linear]),
])

architecturelstm = OrderedDict([
    ('input', 3),
    ('obs', [['input'], Selector(0, 2), None]),
    ('probe', [['input'], Selector(2, 1), None]),
    ('lstm', [['obs'], LSTMWrapper(2, 128), None]),
    ('post_mem', [['probe', 'lstm'], ReLuWrapper(128), nn.Linear]),
    ('output', [['post_mem'], DummyNeuron(2), nn.Linear]),
])

#TODO: fix output


In [4]:
#144, 150, 137, 150

model = OuterWrapper(DynNetwork(architecture), device, BATCH_SIZE, True)

#model = (OuterWrapper(DynNetwork(architecturelstm), device, True))

In [5]:
mse = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=1e-3)#0.000011e-6
#optimizer = optim.Adam(model.model.layers['output_linear'].parameters(), lr=1e-4)#0.000011e-6

In [6]:
start = time.time()

for i in range(30000):
    model.zero_grad()
    inputs, targets = env.getBatch()
    if i%100 == 0:
        for p in model.parameters():
            if torch.isnan(p).any():
                raise Exception('Corrupted Model')
    outputs, _ = model(inputs)
    processed = torch.empty_like(outputs)
    processed[:, :, 1] = outputs[:, :, 1]
    processed[:, :, 0] = torch.sigmoid(outputs[:, :, 0])
    loss = mse(processed, targets)
    #loss = mse(outputs[..., 1], targets[..., 1])
    loss.backward()
    optimizer.step()
    if i%100 == 0:
        print(loss.item(), (loss/targets.view(-1).var()).item(), i)

print('Total time: ', time.time()-start)




0.45530739426612854 0.8238160610198975 0
0.24472668766975403 0.4488821029663086 100
0.25527527928352356 0.45863816142082214 200
0.25438210368156433 0.458394855260849 300
0.26301300525665283 0.4604100286960602 400
0.2676936686038971 0.4658801555633545 500
0.24671602249145508 0.4420275092124939 600
0.2546716034412384 0.4663085639476776 700
0.24688367545604706 0.4592960476875305 800
0.25606802105903625 0.44151419401168823 900
0.2626591920852661 0.45510581135749817 1000
0.25889548659324646 0.452882319688797 1100
0.24958185851573944 0.4435475468635559 1200
0.2506442666053772 0.45471885800361633 1300
0.24214377999305725 0.4501224160194397 1400
0.25714731216430664 0.4649292230606079 1500
0.2522107660770416 0.43640196323394775 1600
0.24609218537807465 0.44134917855262756 1700
0.25715550780296326 0.45595037937164307 1800
0.25728780031204224 0.4579957127571106 1900
0.24441227316856384 0.4433169662952423 2000
0.25768542289733887 0.459758996963501 2100
0.2416718602180481 0.4621039628982544 2200
0.

KeyboardInterrupt: 

In [7]:
from matplotlib import pyplot as plt
#model = torch.load('../models/snn_passive3')
%matplotlib


inputs, targets = env.getBatch()
outputs, _ = model(inputs)
plt.scatter(inputs[:, 0, 2].cpu(), targets[:, 0, 1].cpu(), label='Mean_Target')
plt.scatter(inputs[:, 0, 2].cpu(), outputs[:, 0, 1].detach().cpu(), label='Mean')
plt.scatter(inputs[:, 0, 2].cpu(), targets[:, 0, 0].cpu(), label='Var_Target')
plt.scatter(inputs[:, 0, 2].cpu(), torch.sigmoid(outputs[:, 0, 0].cpu()).detach(), label='Var')
plt.legend()


Using matplotlib backend: TkAgg


<matplotlib.legend.Legend at 0x7fad51adaf98>

In [8]:
#model.save('../models/rsnn_gppred3')

In [None]:
#model = (OuterWrapper(torch.load('../models/jittest'), device, True))



In [None]:

229.6/286.2

In [None]:
inputs.shape

In [None]:
model.model

In [None]:
sqw = SequenceWrapper(DynNetwork(mem_loop))
sqw.model = model.model.layers.mem_loop.model
scripted_sqw = torch.jit.script(sqw)

In [None]:
type(((3,)))

In [None]:
import torch.jit

import typing

In [None]:
typing

In [None]:
type((2,3))

In [None]:
dir(typing)

In [None]:
typing.get_origin(int)

In [None]:
typing.get_type_hints(np.zeros_like)

In [None]:
import numpy as np

In [None]:
def test(a:int):
    pass

In [None]:
test.__annotations__

In [None]:
import typing_inspect

In [None]:
typing_inspect.get_generic_type((2,(4,5)))
typing_inspect.

In [None]:
from typing import Tuple, TupleMeta
def getType(obj):
    if type(obj) is tuple:
        return TupleMeta.__getitem__(Tuple, tuple(getType(x) for x in obj))
    return type(obj)

In [None]:
def getTypeName(obj):
    if type(obj) is tuple:
        return 'Tuple['+', '.join([getTypeName(x) for x in obj])+']'
    return 'Tensor'

In [None]:
getTypeName((1,(2,torch.Tensor([3]))))

In [None]:
'# type: (torch.Tensor, '+typing.TupleMeta.__repr__()+') -> Tensor'

In [None]:
TupleMeta.__repr__(Tuple)

In [None]:
Tuple.__name__

In [None]:
type(int)

In [None]:
type(torch.tensor([1.0]))

In [None]:
Tuple[x + 1 for x in [1,2]]

In [None]:
a = Tuple[Tuple[int], int]

In [None]:
dir(a)

In [None]:
a = Tuple[int]

In [None]:
a.__args__ = (int)

In [None]:
a._subs_tree()

In [None]:
b = Tuple[int]


In [None]:
b

In [None]:
typing.get_type_hints(model.model.forward)

In [None]:
dir(model.model)

In [None]:
model.model.forward_magic_method()


In [None]:
model.model._methods

In [None]:

type(Tuple)#%%

import typing

In [None]:
Tuple.__getitem__(int)

In [None]:
typing.TupleMeta.__getitem__(Tuple, [int])

In [None]:
model.model.__annotations__

In [None]:
dir(torch.jit.frontend.build_def)

In [None]:
og_build_def = torch.jit.frontend.build_def


In [None]:
def my_build_def(ctx, py_def, type_line, self_name=None):
    print(type_line)
    #print(dir(py_def))
    #print(py_def.type_comment)
    #print(py_def.args.args[2].annotation)
    return og_build_def(ctx, py_def, type_line, self_name)

In [None]:
torch.jit.frontend.build_def = my_build_def