In [1]:
import sys
sys.path.append('../')
import torch
import torch.nn as nn
import torch.optim as optim
from Code import macroNeurons as Neurons
from Code.lstm import lstmPolicyPredictor, FullyConnected
from Code.envs.MountainCar import MultiMountainCar, LookupPolicy, PassiveEnv
from Code.SNN import RSNN, FeedForwardSNN, magicRSNN, AdaptiveFF, newRSNN
import time

In [2]:
# Neuron hyperparameters. The same dict is handed to Neurons.set_config (per
# neuron type, via mconfig) and later to the model constructor.
config = {
    'ALPHA': 0.7, 'BETA': 0.95,
    'RESET_ZERO': False,
    'THRESH_ADD': 1, 'THRESH_DECAY': 1,
    'DECODING': 'potential',
    'SPIKE_FN': 'ss',
    'SIM_TIME': 10,
}
# Every neuron type maps to the *same* dict object (not a copy), so a later
# in-place change to `config` is visible under all three keys.
mconfig = {
    neuron_name: config
    for neuron_name in ('CooldownNeuron', 'LIFNeuron', 'OutputNeuron')
}
Neurons.set_config(mconfig)



class CooldownNeuron(nn.Module):
    """Stateful unit layer whose membrane value decays by ``BETA`` each step.

    The first half of the ``size`` units carries sign +1 and the second half
    sign -1 (see ``self.sgn``); the sign flips the spike condition for the
    second half.
    """

    def __init__(self, params, size):
        """Build the layer.

        :param params: config dict; only ``params['BETA']`` is read here, the
            whole dict is kept on ``self.config``.
        :param size: number of units in the layer.
        """
        super(CooldownNeuron, self).__init__()
        self.beta = params['BETA']
        self.config = params
        # NOTE(review): SuperSpike is not defined or imported in the visible part
        # of this notebook; it must already be in scope when this cell runs.
        self.spike_fn = SuperSpike.apply
        self.elu = torch.nn.ELU()
        # Learnable initial membrane value, one entry per unit.
        self.initial_mem = nn.Parameter(torch.zeros([size]), requires_grad=True)
        # +1 for the first half of the units, -1 for the second half.
        # NOTE(review): plain tensor attribute, so it does not follow
        # module.to(device); confirm whether that matters for your setup.
        self.sgn = torch.ones([size], requires_grad=False)
        self.sgn[(size // 2):] *= -1
        self.size = size

    def get_initial_state(self, batch_size):
        """Return the initial state dict, ``initial_mem`` expanded to ``[batch_size, size]``."""
        return {'mem': self.initial_mem.expand([batch_size, self.size])}

    def get_initial_output(self, batch_size):
        """Return a ``[batch_size, size]`` float tensor: 1.0 for negative-sign units, else 0.0."""
        return (self.sgn < 0).float().expand([batch_size, self.size])

    def forward(self, x, h):
        """Advance the layer by one step.

        :param x: input tensor; ``x.shape[0]`` is used as the batch size when
            no state is supplied.
        :param h: state dict with key ``'mem'``, or a falsy value (``None`` /
            empty dict) to start from the initial state.
        :return: ``(spikes, new_h)`` where ``new_h`` is the updated state dict.
            The original cell computed both values but returned nothing.
            NOTE(review): confirm the tuple order against the callers in
            Code.SNN.
        """
        if not h:
            h = self.get_initial_state(x.shape[0])
        new_h = {}
        new_h['mem'] = self.beta * h['mem'] + self.elu(x - 2) + 1
        # The sign flips the threshold comparison for the second half of the units.
        spikes = self.spike_fn(self.sgn * (new_h['mem'] - 1))
        return spikes, new_h


In [3]:
# Data source for this notebook: env.getBatch(n) is unpacked below as
# (inputs, targets, mask).
env = PassiveEnv()

In [4]:
# Batch size passed to env.getBatch in the training loop.
BATCH_SIZE = 64

In [5]:
# Model selection. Exactly one `model = ...` line should be active; the
# commented lines are alternative architectures tried earlier.
#model = lstmPolicyPredictor(2,8,16)
#model = RSNN(config, 1, 32, 32, 1, Neurons.LIFNeuron, Neurons.AdaptiveNeuron, Neurons.OutputNeuron)
#model = magicRSNN(config, 1, 32, 32, 1, Neurons.LIFNeuron, Neurons.MagicNeuron, Neurons.OutputNeuron)
#model = FeedForwardSNN(config, [1, 128, 128, 1], Neurons.LIFNeuron, Neurons.OutputNeuron)
#model = FullyConnected([1, 128, 128, 1])
#model = AdaptiveFF(config, 1, 64, 32, 64, 1, Neurons.LIFNeuron, Neurons.AdaptiveNeuron, Neurons.OutputNeuron)
#model = AdaptiveFF(config, 1, 64, 32, 64, 1, Neurons.LIFNeuron, Neurons.FlipFlopNeuron, Neurons.OutputNeuron)
#model = RSNN(config, 1, 32, 32, 1, Neurons.LIFNeuron, Neurons.FlipFlopNeuron, Neurons.OutputNeuron)
#model = newRSNN(config, 1, 32, 32, 32, 1, Neurons.LIFNeuron, Neurons.FlipFlopNeuron, Neurons.OutputNeuron)
# Active model. NOTE(review): this call passes six integer arguments while the
# commented newRSNN call above passes five; confirm against newRSNN's signature.
model = newRSNN(config, 1, 64, 32, 64, 128, 1, Neurons.LIFNeuron, Neurons.CooldownNeuron, Neurons.OutputNeuron)
#model = newRSNN(config, 1, 64, 32, 64, 128, 1, Neurons.CooldownNeuron, Neurons.CooldownNeuron, Neurons.OutputNeuron)


# Lookup-table policy; only queried in the scratch cells near the end of the notebook.
teacher = LookupPolicy()


#TODO: test superspike instead of Bellec

In [6]:
#inputs, targets, mask = env.getBatch(BATCH_SIZE)
#model = torch.jit.trace(model, inputs)

In [7]:
# Buffer written by logger(): indexed as [t // 10, 0, i], with the last axis
# sized by SIM_TIME. Entries that are never written keep the initial value 1.
# NOTE(review): 13 rows presumably covers t in 0..129 - confirm against the sequence length.
spikes = torch.ones((13, 1, config['SIM_TIME']))
def logger(h, t, i):
    """Store a spike count in the global ``spikes`` buffer on every 10th step.

    When ``t`` is a multiple of 10, the entries ``[:32]`` of the first row of
    ``h['spikes']`` are summed and written to ``spikes[t // 10, 0, i]``.
    For any other ``t`` nothing happens. Always returns ``None``.
    """
    if t % 10 != 0:
        return
    row = t // 10
    first_sample = h['spikes'][0]
    spikes[row, 0, i] = first_sample[:32].sum()
    

In [8]:
#torch.autograd.set_detect_anomaly(True)

In [9]:
#bce = nn.BCELoss(reduction='none') #reduction='sum'
# Element-wise MSE (no reduction) so the training loop can apply `mask`
# before averaging.
mse = nn.MSELoss(reduction='none')
# Adam over all model parameters; the trailing comment holds earlier learning-rate notes.
optimizer = optim.Adam(model.parameters(), lr=1e-4)#0.000011e-6

In [10]:
# Training run: 20 optimisation steps on fresh batches from `env`, printing the
# masked MSE (raw and divided by the target variance) every 10th step and the
# total wall-clock time at the end.
start = time.time()
for i in range(20):
    model.zero_grad()
    inputs, targets, mask = env.getBatch(BATCH_SIZE)

    # Abort early if any parameter already contains NaN (checked every 100th
    # step, i.e. only at i == 0 for a 20-step run).
    if i % 100 == 0 and any(torch.isnan(p).any() for p in model.parameters()):
        raise Exception('Corrupted Model')

    # Inputs are divided by 0.4 and outputs by 50 before comparison with the targets.
    outputs, _ = model(inputs / 0.4)
    scaled_outputs = outputs.squeeze(dim=2) / 50
    masked_error = mse(scaled_outputs, targets) * mask
    loss = masked_error.sum() / mask.sum()

    loss.backward()
    optimizer.step()

    if i % 10 == 0:
        normalised_loss = loss / targets.view(-1).var()
        print(loss.item(), normalised_loss.item(), i)
print(time.time() - start)

0.033898670226335526 61.55609130859375 0
0.022605909034609795 42.211151123046875 10
44.07919239997864


In [10]:
# Display the weight matrix of the model's `adaptive_linear` layer.
model.adaptive_linear.weight

Parameter containing:
tensor([[ 0.0886,  0.0034,  0.0122,  ..., -0.1079, -0.0454,  0.1268],
        [ 0.0730, -0.1046, -0.0994,  ...,  0.0973,  0.0527, -0.0051],
        [ 0.1256, -0.0232,  0.0665,  ..., -0.0530, -0.1733, -0.0602],
        ...,
        [ 0.0011, -0.0089, -0.0370,  ..., -0.0695,  0.0889,  0.1279],
        [ 0.0916,  0.1379, -0.0358,  ...,  0.0395, -0.0350, -0.0625],
        [ 0.1300, -0.0407, -0.0327,  ...,  0.0911,  0.0870,  0.1360]],
       requires_grad=True)

In [11]:
#torch.save(model, '../models/lstm_passive')
#0.0002

In [12]:
#torch.save(model, '../models/withalpha_006')
#0.001

In [10]:
from matplotlib import pyplot as plt
#model = torch.load('../models/snn_passive3')
%matplotlib



Using matplotlib backend: TkAgg


In [11]:
def doplot():
    """Plot targets and scaled model outputs for one freshly drawn sample.

    Draws a batch of size 1 from ``env``, runs ``model`` on the inputs divided
    by 0.4, closes the current figure, and plots both the targets and the
    model outputs (divided by 50) against ``inputs[:, 0, 0]``.
    """
    inputs, targets, mask = env.getBatch(1)
    prediction, _ = model(inputs/0.4)
    x_axis = inputs[:, 0, 0]
    plt.close()
    plt.plot(x_axis, targets)
    plt.plot(x_axis, prediction.squeeze().detach()/50)

In [12]:
# Draw one sample and plot targets against model outputs.
doplot()

In [16]:

# Dump every parameter tensor of the model (verbose output).
for p in model.parameters():
    print(p)

Parameter containing:
tensor([-0.0204, -0.0907,  0.0328,  0.0208, -0.0162, -0.0155, -0.0051,  0.0004,
        -0.0103,  0.0376,  0.0286,  0.0266,  0.0085,  0.0174, -0.0050,  0.1202,
         0.0269, -0.0510, -0.0024,  0.0244, -0.0004, -0.0123,  0.0150,  0.0328,
        -0.0040, -0.0520, -0.0198,  0.0324, -0.0129, -0.0187,  0.0372, -0.0176,
        -0.0086, -0.0481,  0.0040,  0.0074, -0.0169,  0.0119, -0.0310,  0.0245,
        -0.0128, -0.0212, -0.0158,  0.0223,  0.0099,  0.0035,  0.0097, -0.0083,
         0.0372, -0.0129, -0.0219, -0.0241,  0.0306,  0.0258,  0.0064,  0.0308,
         0.0259, -0.0114,  0.0161,  0.0028,  0.0459,  0.0226,  0.0388, -0.0146],
       requires_grad=True)
Parameter containing:
tensor([-0.0074, -0.0094, -0.0251, -0.0063, -0.0061, -0.0195,  0.0115, -0.0146,
         0.0162, -0.0154, -0.0255, -0.0009,  0.0177,  0.0074, -0.0161, -0.0006,
        -0.0131, -0.0003, -0.0208, -0.0075, -0.0098, -0.0122, -0.0198,  0.0032,
        -0.0114, -0.0136, -0.0323, -0.0263, -0.0

In [17]:
# Show only the first parameter tensor yielded by the model.
next(model.parameters())

Parameter containing:
tensor([-0.0204, -0.0907,  0.0328,  0.0208, -0.0162, -0.0155, -0.0051,  0.0004,
        -0.0103,  0.0376,  0.0286,  0.0266,  0.0085,  0.0174, -0.0050,  0.1202,
         0.0269, -0.0510, -0.0024,  0.0244, -0.0004, -0.0123,  0.0150,  0.0328,
        -0.0040, -0.0520, -0.0198,  0.0324, -0.0129, -0.0187,  0.0372, -0.0176,
        -0.0086, -0.0481,  0.0040,  0.0074, -0.0169,  0.0119, -0.0310,  0.0245,
        -0.0128, -0.0212, -0.0158,  0.0223,  0.0099,  0.0035,  0.0097, -0.0083,
         0.0372, -0.0129, -0.0219, -0.0241,  0.0306,  0.0258,  0.0064,  0.0308,
         0.0259, -0.0114,  0.0161,  0.0028,  0.0459,  0.0226,  0.0388, -0.0146],
       requires_grad=True)

In [9]:
# Single-sample forward/backward pass (populates .grad on the parameters).
model.zero_grad()
inputs, targets, mask = env.getBatch(1)
# The training loop and doplot() unpack the model result as (outputs, state),
# so calling .squeeze on the raw result fails for those models. Accept both a
# tuple and a bare tensor so this cell works with either kind of model.
result = model(inputs/0.4)
outputs = result[0] if isinstance(result, tuple) else result
# NOTE(review): outputs are divided by 30 here but by 50 in the training loop
# and in doplot() - confirm which scale is intended.
loss = (mse(outputs.squeeze(dim=2)/30, targets) * mask).sum() / mask.sum()
loss.backward()

In [25]:
# Query the lookup policy on 20 two-component states: first component fixed at 0,
# second component swept linearly from -0.07 to 0.
teacher(torch.cat((torch.zeros([20, 1]), torch.linspace(-0.07, 0, 20).unsqueeze(1)), dim=1))

tensor([0., 0., 0., 0., 0., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])

In [28]:
# Show the 20x2 input grid built for the teacher sweep (column 0 is all zeros,
# column 1 runs from -0.07 to 0).
torch.cat((torch.zeros([20, 1]), torch.linspace(-0.07, 0, 20).unsqueeze(1)), dim=1)

tensor([[ 0.0000, -0.0700],
        [ 0.0000, -0.0663],
        [ 0.0000, -0.0626],
        [ 0.0000, -0.0589],
        [ 0.0000, -0.0553],
        [ 0.0000, -0.0516],
        [ 0.0000, -0.0479],
        [ 0.0000, -0.0442],
        [ 0.0000, -0.0405],
        [ 0.0000, -0.0368],
        [ 0.0000, -0.0332],
        [ 0.0000, -0.0295],
        [ 0.0000, -0.0258],
        [ 0.0000, -0.0221],
        [ 0.0000, -0.0184],
        [ 0.0000, -0.0147],
        [ 0.0000, -0.0111],
        [ 0.0000, -0.0074],
        [ 0.0000, -0.0037],
        [ 0.0000,  0.0000]])

In [35]:
# Query the lookup policy on a single state (0, -0.052).
teacher(torch.tensor([0, -0.052]))

tensor(2.)

In [8]:

# Display the module structure of the current model.
model

FullyConnected(
  (layers): ModuleList(
    (0): Linear(in_features=1, out_features=128, bias=True)
    (1): Linear(in_features=128, out_features=128, bias=True)
    (2): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [15]:

# Print the generated code of the model.
# NOTE(review): presumably requires a traced/scripted model; the torch.jit.trace
# call earlier in this notebook is commented out - confirm before re-running.
print(model.code)

def forward(self,
    input: Tensor) -> Tensor:
  _0 = getattr(self.layers, "2")
  _1 = getattr(self.layers, "1")
  _2 = getattr(self.layers, "0")
  out = torch.zeros([1], dtype=6, layout=0, device=torch.device("cpu"), pin_memory=False)
  input0 = torch.relu((_2).forward(input, ))
  input1 = torch.relu((_1).forward(input0, ))
  out0 = torch.add(out, (_0).forward(input1, ), alpha=1)
  input2 = torch.relu((_2).forward1(input, ))
  input3 = torch.relu((_1).forward1(input2, ))
  out1 = torch.add(out0, (_0).forward1(input3, ), alpha=1)
  input4 = torch.relu((_2).forward2(input, ))
  input5 = torch.relu((_1).forward2(input4, ))
  out2 = torch.add(out1, (_0).forward2(input5, ), alpha=1)
  input6 = torch.relu((_2).forward3(input, ))
  input7 = torch.relu((_1).forward3(input6, ))
  out3 = torch.add(out2, (_0).forward3(input7, ), alpha=1)
  input8 = torch.relu((_2).forward4(input, ))
  input9 = torch.relu((_1).forward4(input8, ))
  out4 = torch.add(out3, (_0).forward4(input9, ), alpha=1)
  input

In [13]:

# Switch tensor printing to the 'default' print-options profile.
torch.set_printoptions(profile='default')