In [1]:
import sys
sys.path.append('../../')
import torch
import torch.nn as nn
import torch.optim as optim
from Code.ANN import lstmPolicyPredictor, FullyConnected
from Code.envs.GPEnv import PassiveEnv
from wip.Code.train import make_dataset_simple
from Code.SNN import DynNetwork, SequenceWrapper
from Code.ANN import LSTMWrapper, ReLuWrapper
import time
from collections import OrderedDict
from Code import Neurons

In [2]:
#TODO: output reset mechanism

config = {
    'ALPHA': 0.5, #0.7
    'BETA': 0.9, #0.95
    'RESET_ZERO': False,
    'DECODING': 'potential',
    'SPIKE_FN': 'ss'
}

secondconfig = {
    'ALPHA': 0.7,
    'BETA': 0.5, #0.95
    'RESET_ZERO': False,
    'DECODING': 'potential',
    'SPIKE_FN': 'ss'
}

heavyside = {
    'ALPHA': 0,
    'BETA': 1, #0.95
    'RESET_ZERO': False,
    'SPIKE_FN': 'ss'
}

outconfig = {
    'ALPHA': 0,
    'BETA': 0,
    'RESET_ZERO': False,
    'DECODING': 'potential',
    'SPIKE_FN': 'ss'
}

class Selector(nn.Module):
    def __init__(self, params, size):
        super().__init__()
        self.start = params
        self.end = params + size

    def forward(self, x, h):
        return x[:, self.start:self.end], ()

    def get_initial_state(self, batch_size):
        return ()

architecture1 = OrderedDict([
    ('input', [3]),
    ('obs', [2, ['input'], Selector, 0, None]),
    ('probe', [1, ['input'], Selector, 2, None]),
    ('pre_mem', [128, ['obs', 'mem', 'short_mem'], Neurons.LIFNeuron, config, nn.Linear]),
    ('mem', [64, ['pre_mem'], Neurons.CooldownNeuron, config, nn.Linear]), #CooldownNeuron 32
    ('short_mem', [64, ['pre_mem'], Neurons.CooldownNeuron, secondconfig, nn.Linear]), #32
    ('post_mem', [128, ['probe', 'mem', 'short_mem'], Neurons.LIFNeuron, config, nn.Linear]),
    ('output', [2, ['post_mem'], Neurons.OutputNeuron, outconfig, nn.Linear]), #OutputNeuron
])

architecture0 = OrderedDict([
    ('input', [3]),
    ('obs', [2, ['input'], Selector, 0, None]),
    ('probe', [1, ['input'], Selector, 2, None]),
    ('pre_mem', [128, ['obs', 'mem'], Neurons.LIFNeuron, outconfig, nn.Linear]),
    ('mem', [128, ['pre_mem'], Neurons.CooldownNeuron, heavyside, nn.Linear]), #CooldownNeuron 32
    ('post_mem', [128, ['probe', 'mem'], Neurons.LIFNeuron, outconfig, nn.Linear]),
    ('output', [2, ['post_mem'], Neurons.OutputNeuron, outconfig, nn.Linear]), #OutputNeuron
])

architecture2 = OrderedDict([
    ('input', [3]),
    ('obs', [2, ['input'], Selector, 0, None]),
    ('probe', [1, ['input'], Selector, 2, None]),
    ('pre_mem', [64, ['obs', 'mem', 'short_mem'], Neurons.LIFNeuron, config, nn.Linear]),
    ('mem', [128, ['pre_mem'], Neurons.CooldownNeuron, config, nn.Linear]), #CooldownNeuron 32
    ('short_mem', [128, ['pre_mem'], Neurons.CooldownNeuron, secondconfig, nn.Linear]), #32
    ('post_mem', [64, ['probe', 'mem', 'short_mem'], Neurons.LIFNeuron, config, nn.Linear]),
    ('secondproc', [64, ['post_mem'], Neurons.LIFNeuron, config, nn.Linear]),
    ('output', [2, ['secondproc'], Neurons.OutputNeuron, outconfig, nn.Linear]), #OutputNeuron
])

architecturelstm0 = OrderedDict([
    ('input', [3]),
    ('obs', [2, ['input'], Selector, 0, None]),
    ('probe', [1, ['input'], Selector, 2, None]),
    ('lstm', [128, ['obs'], LSTMWrapper, None, nn.Linear]),
    ('post_mem', [128, ['probe', 'lstm'], ReLuWrapper, None, nn.Linear]),
    ('output', [2, ['post_mem'], Neurons.OutputNeuron, outconfig, nn.Linear]), #OutputNeuron
])


BATCH_SIZE = 1024
SIM_TIME = 1
MAX_ITER = 50




In [3]:
device = torch.device('cuda')

env = PassiveEnv(BATCH_SIZE, MAX_ITER, device)
#model = lstmPolicyPredictor(1,32,64)

#model = FullyConnected([1, 128, 128, 1])

model_raw = DynNetwork(architecturelstm0, SIM_TIME)
model = SequenceWrapper(model_raw, BATCH_SIZE, device, False)



In [4]:
mse = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters(), lr=5e-5)#0.000011e-6
#optimizer = optim.Adam(model.model.layers['output_linear'].parameters(), lr=1e-4)#0.000011e-6

In [5]:
start = time.time()

for i in range(1000):
    model.zero_grad()
    inputs, targets = env.getBatch()
    if i%100 == 0:
        for p in model.parameters():
            if torch.isnan(p).any():
                raise Exception('Corrupted Model')
    outputs, _ = model(inputs)
    processed = torch.empty_like(outputs)
    processed[:, :, 1] = outputs[:, :, 1]
    processed[:, :, 0] = torch.sigmoid(outputs[:, :, 0])
    loss = mse(processed, targets)
    #loss = mse(outputs[..., 1], targets[..., 1])
    loss.backward()
    optimizer.step()
    if i%100 == 0:
        print(loss.item(), (loss/targets.view(-1).var()).item(), i)

print('Total time: ', time.time()-start)




0.44402673840522766 0.8076158165931702 0
0.30475252866744995 0.5456461906433105 100
0.26499995589256287 0.47420331835746765 200
0.25821569561958313 0.4648166596889496 300
0.26339057087898254 0.4727363884449005 400
0.2609795033931732 0.4643869996070862 500
0.2565402686595917 0.46144288778305054 600
0.26218268275260925 0.4628727436065674 700
0.2490558922290802 0.4608836770057678 800
0.25311094522476196 0.45142123103141785 900
Total time:  266.6013057231903


In [28]:
from matplotlib import pyplot as plt
#model = torch.load('../models/snn_passive3')
%matplotlib


inputs, targets = env.getBatch()
outputs, _ = model(inputs)
plt.scatter(inputs[:, 0, 2].cpu(), targets[:, 0, 1].cpu(), label='Mean_Target')
plt.scatter(inputs[:, 0, 2].cpu(), outputs[:, 0, 1].detach().cpu(), label='Mean')
plt.scatter(inputs[:, 0, 2].cpu(), targets[:, 0, 0].cpu(), label='Var_Target')
plt.scatter(inputs[:, 0, 2].cpu(), torch.sigmoid(outputs[:, 0, 0].cpu()).detach(), label='Var')
plt.legend()


Using matplotlib backend: TkAgg


<matplotlib.legend.Legend at 0x7fbeccdc5e10>

In [25]:
torch.save(model, '../../models/rsnn_gppred2')


  "type " + obj.__name__ + ". It won't be checked "


In [7]:
obs, targets = env.getBatch()

In [8]:
obs[:, 0]


tensor([[ 0.0725,  1.3078,  0.7469],
        [ 0.7469,  0.1004,  0.1551],
        [ 0.1551,  0.7648,  0.5999],
        [ 0.5999, -0.3369,  0.3660],
        [ 0.3660,  0.6587,  0.2637],
        [ 0.2637,  0.3779,  0.2201],
        [ 0.2201,  0.3814,  0.1194],
        [ 0.1194,  1.0581,  0.8709],
        [ 0.8709,  0.5644,  0.1411],
        [ 0.1411,  0.8821,  0.6016],
        [ 0.6016, -0.3385,  0.0749],
        [ 0.0749,  1.3016,  0.7182],
        [ 0.7182, -0.0277,  0.5491],
        [ 0.5491, -0.1718,  0.2822],
        [ 0.2822,  0.4266,  0.7817],
        [ 0.7817,  0.2554,  0.1489],
        [ 0.1489,  0.8160,  0.9466],
        [ 0.9466,  0.6245,  0.5037],
        [ 0.5037,  0.1100,  0.6078],
        [ 0.6078, -0.3423,  0.3834],
        [ 0.3834,  0.6589,  0.1964],
        [ 0.1964,  0.4742,  0.7783],
        [ 0.7783,  0.2406,  0.3649],
        [ 0.3649,  0.6579,  0.9045],
        [ 0.9045,  0.6172,  0.5368],
        [ 0.5368, -0.1034,  0.7161],
        [ 0.7161, -0.0366,  0.4711],
 

In [9]:
targets[:, 0]


tensor([[ 1.7684e-20,  1.7391e-10],
        [ 5.0578e-01,  9.3009e-01],
        [ 1.1508e-01,  3.4039e-02],
        [ 2.4592e-02, -4.2329e-02],
        [ 6.7668e-01,  4.2292e-01],
        [ 9.8833e-01,  3.8013e-01],
        [ 9.9864e-01,  1.0591e+00],
        [ 2.3511e-01,  1.0906e-01],
        [ 9.9999e-01,  8.8214e-01],
        [ 9.9982e-01, -3.3843e-01],
        [ 9.9999e-01,  1.3017e+00],
        [ 9.8745e-01, -4.8751e-02],
        [ 9.7186e-01, -1.9440e-01],
        [ 9.9988e-01,  4.2651e-01],
        [ 9.9765e-01,  2.5635e-01],
        [ 9.9999e-01,  8.1605e-01],
        [ 8.5181e-01,  5.0005e-01],
        [ 9.9471e-01,  1.1237e-01],
        [ 9.9999e-01, -3.4233e-01],
        [ 9.9987e-01,  6.5898e-01],
        [ 9.9999e-01,  4.7421e-01],
        [ 9.9999e-01,  2.4056e-01],
        [ 9.9999e-01,  6.5790e-01],
        [ 9.9935e-01,  6.1698e-01],
        [ 9.9999e-01, -1.0340e-01],
        [ 9.9999e-01, -3.6590e-02],
        [ 9.9994e-01,  3.2506e-01],
        [ 9.9712e-01,  5.774

In [12]:
#model2 = torch.load('../../models/rsnn_gppred1')


In [15]:
#model2.cpu()

SequenceWrapper(
  (pretrace): DynNetwork(
    (layers): ModuleDict(
      (obs): Selector()
      (probe): Selector()
      (pre_mem_linear): Linear(in_features=130, out_features=128, bias=True)
      (pre_mem): LIFNeuron()
      (mem_linear): Linear(in_features=128, out_features=128, bias=True)
      (mem): CooldownNeuron(
        (elu): ELU(alpha=1.0)
      )
      (post_mem_linear): Linear(in_features=129, out_features=128, bias=True)
      (post_mem): LIFNeuron()
      (output_linear): Linear(in_features=128, out_features=2, bias=True)
      (output): OutputNeuron()
    )
  )
  (model): DynNetwork(
    (layers): ModuleDict(
      (obs): Selector()
      (probe): Selector()
      (pre_mem_linear): Linear(in_features=130, out_features=128, bias=True)
      (pre_mem): LIFNeuron()
      (mem_linear): Linear(in_features=128, out_features=128, bias=True)
      (mem): CooldownNeuron(
        (elu): ELU(alpha=1.0)
      )
      (post_mem_linear): Linear(in_features=129, out_features=128, 

In [19]:
for p in model.model.layers['output_linear'].parameters():
    print(p.shape)

torch.Size([2, 128])
torch.Size([2])


In [18]:
model.model.layers['output_linear'].parameters()

<generator object Module.parameters at 0x7fbeccb995c8>