In [1]:
import sys
sys.path.append('../../')
import torch
import torch.nn as nn
import torch.optim as optim
from Code.ANN import lstmPolicyPredictor, FullyConnected
from Code.envs.MountainCar import MultiMountainCar, LookupPolicy, PassiveEnv
from wip.Code.train import make_dataset_simple
from Code.SNN import DynNetwork, SequenceWrapper
import time
from collections import OrderedDict
from Code import Neurons

In [2]:
#TODO: output reset mechanism

# Neuron hyperparameter sets referenced by the layer definitions.
# `config` is the baseline; the other two are independent copies of it with
# selected keys overridden (key order is the same in all three).
config = dict(
    ALPHA=0.7,
    BETA=0.9,  # 0.95 is noted in the original as an alternative value
    RESET_ZERO=False,
    DECODING='potential',
    SPIKE_FN='ss',
)

# Identical to `config` except for a smaller BETA (0.95 noted as an alternative).
secondconfig = {**config, 'BETA': 0.5}

# Identical to `config` except that ALPHA and BETA are both zero.
outconfig = {**config, 'ALPHA': 0, 'BETA': 0}


# Layer layouts keyed by layer name, in insertion order. Apart from 'input',
# every entry lists a size, the names of the layers it reads from, a neuron
# class and a neuron config (presumably interpreted this way by DynNetwork --
# TODO confirm against Code.SNN).

# Layout with a single 'mem' population. Nothing in the visible cells uses it.
architecture1 = OrderedDict()
architecture1['input'] = [1]
architecture1['pre_mem'] = [64, ['input', 'mem'], Neurons.LIFNeuron, config]
architecture1['mem'] = [32, ['pre_mem'], Neurons.CooldownNeuron, config]  # CooldownNeuron
architecture1['post_mem'] = [64, ['input', 'mem'], Neurons.LIFNeuron, config]
architecture1['output'] = [1, ['post_mem'], Neurons.OutputNeuron, outconfig]  # OutputNeuron

# Layout with an extra 'short_mem' population driven by `secondconfig`; each
# non-input entry additionally carries nn.Linear as a fifth element.
architecture = OrderedDict()
architecture['input'] = [1]
architecture['pre_mem'] = [64, ['input', 'mem', 'short_mem'], Neurons.LIFNeuron, config, nn.Linear]
architecture['mem'] = [32, ['pre_mem'], Neurons.CooldownNeuron, config, nn.Linear]  # CooldownNeuron
architecture['short_mem'] = [32, ['pre_mem'], Neurons.CooldownNeuron, secondconfig, nn.Linear]
architecture['post_mem'] = [64, ['input', 'mem', 'short_mem'], Neurons.LIFNeuron, config, nn.Linear]
architecture['output'] = [1, ['post_mem'], Neurons.OutputNeuron, outconfig, nn.Linear]  # OutputNeuron


# Number of parallel environments / sequences per training batch.
BATCH_SIZE = 64
# Second argument handed to DynNetwork when the model is built.
SIM_TIME = 1

In [3]:
# All tensors, environments and policies in this notebook are created for the CPU device.
device = torch.device('cpu')

# Environment that the training code resets with BATCH_SIZE and steps with a batch of actions.
env = MultiMountainCar(device)
# Episode-length value passed to make_dataset_simple when a dataset is generated.
MAX_ITER = 200
#model = lstmPolicyPredictor(1,32,64)

#model = FullyConnected([1, 128, 128, 1])

# Build the network from the `architecture` layout, then wrap it for sequence processing.
model_raw = DynNetwork(architecture, SIM_TIME)
# NOTE(review): the meaning of the trailing `False` is not visible in this notebook -- confirm in Code.SNN.
model = SequenceWrapper(model_raw, BATCH_SIZE, device, False)

# Reference policy; the training code halves its output and uses that as the BCE target.
teacher = LookupPolicy(device)


In [4]:
# NOTE(review): this import lives outside the import cell at the top of the notebook.
import gym
# Single Gym environment used for the validation and manual rollouts.
testenv = gym.make('MountainCar-v0')

def validate(num_runs, render=False):
    """Roll out the current ``model`` in ``testenv`` and print the mean episode length.

    The model is fed only the first observation component, one step at a time,
    with its recurrent state carried across steps. A positive model output
    selects action 2, anything else action 0.

    Args:
        num_runs: Number of episodes to average over (must be non-zero).
        render: When true, ``testenv.render()`` is called before every step.

    Returns:
        None. Prints ``'Validation: '`` followed by the mean number of steps
        per episode. An episode that has not reported ``done`` after the
        300-step cap is counted as 300 steps instead of being silently
        dropped from the total.
    """
    max_steps = 300
    total_steps = 0  # avoids shadowing the builtin `sum`
    # Evaluation needs no gradients; without no_grad the recurrent state would
    # drag an autograd graph across the whole episode.
    with torch.no_grad():
        for _ in range(num_runs):
            obs = testenv.reset()
            state = None
            for t in range(max_steps):
                if render:
                    testenv.render()
                model_input = torch.tensor([[[obs[0]]]], dtype=torch.float)
                output, state = model(model_input, state)
                action = 2 if output > 0 else 0
                obs, _, done, _ = testenv.step(action)
                if done:
                    total_steps += t + 1
                    break
            else:
                # The loop ran out without `done`: count the full cap.
                total_steps += max_steps
    print('Validation: ', total_steps/num_runs)




In [5]:
# reduction='none' keeps one loss value per element so a mask can be applied before reducing.
bce = nn.BCELoss(reduction='none') #reduction='sum'
# Adam over all model parameters; 0.00001 is noted as an alternative learning rate.
optimizer = optim.Adam(model.parameters(), lr=1e-4)#0.00001

In [6]:
#torch.autograd.set_detect_anomaly(True)

In [10]:
# Online imitation training: the model drives the batched environment itself
# and is trained to reproduce the teacher's (halved) output at every step.
#
# A new optimizer is constructed here, so running this cell starts from fresh
# Adam statistics (0.00006 is noted as an alternative learning rate).
optimizer = optim.Adam(model.parameters(), lr=1e-4)#0.00006

for i in range(5000):
    model.zero_grad()
    observation = env.reset(BATCH_SIZE)
    state = None
    loss = 0
    div = 0
    for t in range(MAX_ITER):
        # Only the first observation column is shown to the model.
        output, state = model(observation[:,:1].unsqueeze(0), state)
        logits = output.squeeze()
        target = teacher(observation)/2
        action = (logits > 0) * 2.0
        observation, _, done, _ = env.step(action)
        # NOTE(review): the mask uses `done` as reported after this step, so the
        # step on which an environment finishes contributes no loss -- confirm
        # that this is intended.
        active = (~done).float()
        # BCE on logits is numerically stable where sigmoid followed by BCELoss
        # saturates for large-magnitude outputs.
        step_loss = nn.functional.binary_cross_entropy_with_logits(logits, target, reduction='none')
        loss = loss + (step_loss * active).sum()
        div = div + active.sum()
        if done.all():
            break
    # `div` counts unmasked steps; clamping only matters when it is zero and
    # then avoids a 0/0 NaN loss.
    loss = loss / div.clamp(min=1)
    if i%10 == 0:
        print(loss.item(), t+1, i) #, ((outputs>0.5) != targets).sum()
    loss.backward()
    optimizer.step()
    if i%100 == 0:
        validate(10)
    

0.6772387027740479 200 0
Validation:  200.0
0.6757628917694092 200 10
0.6727081537246704 200 20
0.6703775525093079 200 30
0.6707730889320374 200 40
0.6648341417312622 200 50
0.6750781536102295 200 60
0.6877434253692627 200 70
0.679787814617157 200 80
0.6696767210960388 200 90
0.6473906636238098 200 100
Validation:  200.0
0.650684118270874 200 110
0.6186470985412598 200 120
0.6612014770507812 200 130
0.6292560696601868 200 140
0.6390781998634338 200 150
0.6116876006126404 200 160
0.661375105381012 200 170
0.6208813190460205 200 180
0.6210055351257324 200 190
0.621647298336029 200 200
Validation:  200.0
0.6275950074195862 200 210
0.6094839572906494 200 220
0.6008831262588501 200 230
0.5922849774360657 200 240
0.5756155252456665 200 250
0.5645986199378967 200 260
0.5616487264633179 200 270
0.5956995487213135 200 280
0.5689993500709534 200 290
0.5781638622283936 200 300
Validation:  200.0
0.5647542476654053 200 310
0.5438669323921204 200 320
0.5398746728897095 200 330
0.5409409403800964 20

In [6]:
# Scratch dictionary for debugging. Later cells read the keys 'out' and 'tar',
# but no visible code writes them -- NOTE(review): those cells rely on hidden kernel state.
mydict = {}

def train_batch(inputs, targets, mask):
    """Run one optimisation step of the global ``model`` on a masked batch.

    Args:
        inputs: Tensor passed to ``model`` without an initial state.
        targets: Per-element binary targets, same shape as the model output.
        mask: Per-element weights (multiplied into the element-wise loss);
            the loss is normalised by ``mask.sum()``.

    Returns:
        The masked mean binary cross-entropy as a Python float. If the mask
        sums to zero there is nothing to learn from: the parameters are left
        untouched and ``nan`` is returned (the original 0/0 loss was also
        ``nan`` but its backward pass wrote NaNs into the parameters).
    """
    mask_total = mask.sum()
    if mask_total.item() == 0:
        return float('nan')
    model.zero_grad()
    outputs, _ = model(inputs)
    # BCE on logits is numerically stable where sigmoid followed by BCELoss
    # saturates for large-magnitude outputs.
    per_element = nn.functional.binary_cross_entropy_with_logits(outputs, targets, reduction='none')
    loss = (per_element * mask).sum() / mask_total
    loss.backward()
    optimizer.step()
    return loss.item()


def train_dataset(num_batches, num_epochs):
    """Generate one dataset and fit the model to it for several epochs.

    The dataset comes from a single ``make_dataset_simple`` call. Every epoch
    draws a fresh random permutation over dimension 1 of the observations and
    feeds ``num_batches`` consecutive slices of ``BATCH_SIZE`` indices to
    ``train_batch``. After each epoch all model parameters are checked for
    NaNs and the loss of the last batch is printed.

    Args:
        num_batches: Number of batches generated and trained on per epoch.
        num_epochs: Number of passes over the generated dataset.

    Raises:
        Exception: ``'Corrupted Model'`` if any parameter contains a NaN.
    """
    obs, target, mask = make_dataset_simple(num_batches, BATCH_SIZE, MAX_ITER, model, teacher, device, env)
    for _ in range(num_epochs):
        shuffled = torch.randperm(obs.shape[1], device=device)
        for batch_no in range(num_batches):
            start = batch_no * BATCH_SIZE
            # One index slice selects the same sequences from all three tensors.
            picked = shuffled[start:start + BATCH_SIZE]
            loss = train_batch(obs[:, picked], target[:, picked], mask[:, picked])
        if any(torch.isnan(param).any() for param in model.parameters()):
            raise Exception('Corrupted Model')
        print(loss)



In [7]:
# Alternate dataset-based training with a validation run, 20 rounds in total.
for bigstep in range(20):
    print('Bigstep: ', bigstep)
    train_dataset(num_batches=100, num_epochs=10)
    validate(num_runs=10)

Bigstep:  0
0.664725124835968
0.5898089408874512
0.49153774976730347
0.6977851390838623
0.45482107996940613
0.4534326195716858
0.44856181740760803
0.6130235195159912
0.428638756275177
0.4103873074054718
Validation:  109.3
Bigstep:  1
0.25005558133125305
0.20389588177204132
0.18793132901191711
0.1697760373353958
0.16045846045017242
0.14852465689182281
0.14243480563163757
0.13182225823402405
0.13465696573257446
0.14020313322544098
Validation:  103.3
Bigstep:  2
0.16003140807151794
0.1574556976556778
0.15382830798625946
0.13881447911262512
0.1414133608341217
0.14536212384700775
0.14187590777873993
0.128886878490448
0.14120197296142578
0.14318718016147614
Validation:  103.3
Bigstep:  3
0.15053102374076843
0.1493670642375946
0.13940149545669556
0.14449100196361542
0.14355769753456116
0.14049240946769714
0.13741865754127502


KeyboardInterrupt: 

In [11]:
# Run a single validation episode with rendering enabled.
validate(1, render=True)

Validation:  105.0


In [None]:
# Manual rollout: drive the single test environment with the model's action
# (2 for a positive output, otherwise 0) and print the raw model output at
# every step, for at most 300 steps.
obs = testenv.reset()
state = None
for t in range(300):
    output, state = model(torch.tensor([[[obs[0]]]], dtype=torch.float), state)
    action = 2 if output > 0 else 0
    obs, _, done, _ = testenv.step(action)
    print(output)
    if done:
        # Stepping a Gym environment after it has reported done is undefined,
        # so the rollout stops with the episode.
        break


In [8]:
# Abort if any model parameter contains a NaN.
if any(torch.isnan(param).any() for param in model.parameters()):
    raise Exception('Corrupted Model')

In [None]:
# Show the halved teacher output for whatever `observation` the online training
# cell left in the kernel (this cell only works after that cell has run).
teacher(observation)/2



In [None]:
# Show the `observation` batch left in the kernel by the online training cell.
observation

In [13]:
#torch.save(model, '../../models/rsnn_mountaincar')




In [None]:
# NOTE(review): `model` is a SequenceWrapper in this notebook; whether it exposes
# `input_linear` is not visible here and this cell has no execution count --
# possibly left over from one of the commented-out model variants. Confirm or remove.
model.input_linear.bias


In [28]:

# Dumps the whole debug dictionary. NOTE(review): no visible code populates
# `mydict`, so the output below comes from hidden kernel state.
mydict

{'out': tensor([[[-0.0653],
          [ 0.0964],
          [ 0.1616],
          ...,
          [ 0.0449],
          [ 0.1112],
          [ 0.0767]],
 
         [[-0.0532],
          [-0.1147],
          [-0.1384],
          ...,
          [-0.0960],
          [-0.1202],
          [-0.1077]],
 
         [[-0.0761],
          [-0.3994],
          [-0.5322],
          ...,
          [-0.2986],
          [-0.4298],
          [-0.3611]],
 
         ...,
 
         [[-0.9187],
          [-0.9111],
          [-0.9117],
          ...,
          [-0.9049],
          [-0.9030],
          [-0.9077]],
 
         [[-0.9186],
          [-0.9110],
          [-0.9113],
          ...,
          [-0.9055],
          [-0.9033],
          [-0.9079]],
 
         [[-0.9185],
          [-0.9108],
          [-0.9109],
          ...,
          [-0.9060],
          [-0.9035],
          [-0.9081]]], grad_fn=<AddBackward0>), 'tar': tensor([[[ 1.0000],
          [ 0.0000],
          [ 0.0000],
          ...,
     

In [32]:
# Shape of the element-wise loss between the stored outputs and targets.
# NOTE(review): depends on `mydict['out']` / `mydict['tar']`, which no visible code sets.
bce(torch.sigmoid(mydict['out']), mydict['tar']).shape

torch.Size([200, 64, 1])

In [37]:
# Stored targets at index 110 of the first dimension.
# NOTE(review): depends on `mydict['tar']`, which no visible code sets.
mydict['tar'][110]


tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])