In [1]:
import sys
sys.path.append('../')
from torch.distributions.bernoulli import Bernoulli
from torch.distributions.geometric import Geometric
from Code.envs.storerecall import make_batch
import torch
import torch.nn as nn
import torch.optim as optim
import time
from collections import OrderedDict


In [2]:
BATCH_SIZE = 128

USE_JIT = False

#TODO: test device
device = torch.device('cpu')


In [3]:
spec2 = {'beta': 0.9,
   'lr': 0.001,
   'lr_decay': 0.8,
   '1-beta': False,
   'ported_weights': True,
   'NoBias': True,
   'iterations': 5000,
   'batch_size': 128,
   'mem_beta': 1,
   'spkfn': 'ss',
   'decay_out': False,
   'architecture': '1L',
   'control_neuron': 'LIF',
   'mem_neuron': 'Cooldown'}

spec = {'beta': 0.9,
   'lr': 0.01,
   'lr_decay': 0.8,
   '1-beta': False,
   'ported_weights': True,
   'NoBias': True,
   'iterations': 5000,
   'batch_size': 128,
   'mem_beta': 0.95,
   'spkfn': 'bellec',
   'decay_out': False,
   'architecture': '2L',
   'control_neuron': 'LIF',
   'mem_neuron': 'Adaptive'}

spec['iterations'] = 100

In [4]:
from Code.Networks import Selector, DynNetwork, OuterWrapper, LSTMWrapper, ReLuWrapper, DummyNeuron, make_SequenceWrapper, ParallelNetwork, MeanModule
from Code.NewNeurons2 import SeqOnlySpike, CooldownNeuron, OutputNeuron, LIFNeuron, NoResetNeuron, AdaptiveNeuron


built_config = {
    'BETA': spec['beta'],
    'OFFSET': 7, # TODO: was 3 for config24
    'SPIKE_FN': spec['spkfn'],
    '1-beta': spec['1-beta'],
    'ADAPDECAY': 0.9985,
    'ADAPSCALE': 180
}

mem_config = {
    **built_config,
    'BETA': spec['mem_beta']
}

n_input = 3
n_control = 10
n_mem = 10

control_lookup = {
    'LIF': LIFNeuron,
    'Disc': SeqOnlySpike,
    'NoReset': NoResetNeuron
}

mem_lookup = {
    'Adaptive': AdaptiveNeuron,
    'Cooldown': CooldownNeuron,
    'NoReset': NoResetNeuron
}

control_neuron = control_lookup[spec['control_neuron']](n_control, built_config)
mem_neuron = mem_lookup[spec['mem_neuron']](n_mem, mem_config)
out_neuron = OutputNeuron(n_control+n_mem, built_config) if spec['decay_out'] else DummyNeuron(n_control+n_mem, built_config)


loop_2L = OrderedDict([
    ('input', n_input),
    ('control', [['input', 'mem'], control_neuron, nn.Linear]),
    ('mem', [['control'], mem_neuron, nn.Linear]),
    ('output', [['control', 'mem'], out_neuron, None]),
])

loop_1L = OrderedDict([
    ('input', n_input),
    ('control', [['input', 'control', 'mem'], control_neuron, nn.Linear]),
    ('mem', [['input', 'control', 'mem'], mem_neuron, nn.Linear]),
    ('output', [['control', 'mem'], out_neuron, None]),
])

loop = loop_1L if spec['architecture'] == '1L' else loop_2L

outer = OrderedDict([
    ('input', n_input),
    ('loop', [['input'], make_SequenceWrapper(ParallelNetwork(loop, bias=(not spec['NoBias'])), USE_JIT), None]),
    ('output', [['loop'], DummyNeuron(1, None), nn.Linear]),
])

model = OuterWrapper(DynNetwork(outer), device, USE_JIT)


#loop_model = OuterWrapper(make_SequenceWrapper(ParallelNetwork(loop), USE_JIT), device, USE_JIT)

#final_linear = nn.Linear(n_control+n_mem, 10).to(device)
'''
if spec['ported_weights']:
    o_weights = pickle.load(open('weight_transplant_enc', 'rb'))

    o1 = torch.tensor(o_weights['RecWeights/RecurrentWeight:0']).t()
    o2 = torch.tensor(o_weights['InputWeights/InputWeight:0']).t()
    o3 = torch.cat((o2, o1), dim=1)
    with torch.no_grad():
        model.pretrace.layers.loop.model.layers.control_synapse.weight.data[:,:300] = o3[:120] if spec['architecture'] == '1L' else o3[:120, :181]
        model.pretrace.layers.loop.model.layers.mem_synapse.weight.data[:,:300] = o3[120:] if spec['architecture'] == '1L' else o3[120:, 180:]
        model.pretrace.layers.output_synapse.weight.data = torch.tensor(o_weights['out_weight:0']).t()
'''
params = list(model.parameters())

model.to(device)


OuterWrapper(
  (pretrace): DynNetwork(
    (layers): ModuleDict(
      (loop): SequenceWrapper(
        (model): ParallelNetwork(
          (layers): ModuleDict(
            (control_synapse): Linear(in_features=13, out_features=10, bias=False)
            (control): LIFNeuron()
            (mem_synapse): Linear(in_features=10, out_features=10, bias=False)
            (mem): AdaptiveNeuron()
            (output): DummyNeuron()
          )
        )
      )
      (output_synapse): Linear(in_features=20, out_features=1, bias=True)
      (output): DummyNeuron()
    )
  )
  (model): DynNetwork(
    (layers): ModuleDict(
      (loop): SequenceWrapper(
        (model): ParallelNetwork(
          (layers): ModuleDict(
            (control_synapse): Linear(in_features=13, out_features=10, bias=False)
            (control): LIFNeuron()
            (mem_synapse): Linear(in_features=10, out_features=10, bias=False)
            (mem): AdaptiveNeuron()
            (output): DummyNeuron()
         

In [5]:
lr = spec['lr']
optimizer = optim.Adam(params, lr=lr)
bce = nn.BCEWithLogitsLoss(reduction='none')

ITERATIONS = spec['iterations']#36000

In [6]:
stats = {
    'grad_norm': [],
    'loss': [],
    'acc': [],
    'batch_var': [],
    'val': []
}

grad_norm_history = []
def record_norm():
    norms = []
    for p in params:
        norms.append(p.grad.norm().item())
    stats['grad_norm'].append(torch.tensor(norms).norm().item())


In [7]:
store_dist = (lambda : Geometric(torch.tensor([0.2], device=device)).sample().int().item()+1)
recall_dist = (lambda : Geometric(torch.tensor([0.2], device=device)).sample().int().item()+1)
SEQ_LEN = 13
CHAR_DUR = 200

In [8]:
start = time.time()
i = 1
sumloss = 0
sumacc = 0

while i < ITERATIONS:
    batchstart = time.time()
    optimizer.zero_grad()
    data = make_batch(BATCH_SIZE, SEQ_LEN, store_dist, recall_dist, device)
    data = data.repeat_interleave(CHAR_DUR, 0)
    input = data[:, :, :3]
    target = data[:, :, 3]
    recall = data[:, :, 0]
    #TODO: repeat data over

    output, _ = model(input)
    output = output.squeeze()
    assert output.shape == target.shape, f'shapes on loss {output.shape}, {target.shape}'
    #TODO: mask with recall
    loss = (bce(output, target)*recall).mean() #correct shape?
    loss.backward()
    optimizer.step()

    with torch.no_grad():
        record_norm()
        stats['loss'].append(loss.item())
        acc = (((output > 0.5).float() == target).float()*recall).mean().item()
        #stats['acc'].append(acc)
        batch_var = 3 #out_final.var(0).mean().item()
        #stats['batch_var'].append(batch_var)

        print(loss.item(), acc)


    sumloss += loss.item()
    sumacc += acc
    if i%20 == 0:
        print(loss.item(), sumloss/20, sumacc/20, time.time()-batchstart, batch_var) #torch.argmax(outputs[-1], 1).float().var()
        sumloss = 0
        sumacc = 0
    if i%2500 == 0:
        lr = lr * spec['lr_decay']
        optimizer = optim.Adam(params, lr=lr)
        print('Learning Rate: ', lr)
    i += 1
    #config['stats'] = stats
    #config['progress'] = i
    #with open('configs/' + run_id + '.json', 'w') as config_file:
    #    json.dump(config, config_file, indent=2)
    #model.save('models/'+run_id)


print('Total time: ', time.time()-start)


0.07343670725822449 0.0518689900636673
0.07853973656892776 0.0625
0.0686025395989418 0.04627403989434242
0.06591444462537766 0.04447115212678909
0.06657587736845016 0.043870192021131516
0.06644821912050247 0.04747596010565758
0.07201509922742844 0.05288461595773697
0.06714820116758347 0.04747596010565758
0.0709623396396637 0.05528846010565758
0.06872430443763733 0.048076923936605453
0.06852830201387405 0.043870192021131516
0.06576674431562424 0.045673076063394547
0.07093815505504608 0.043870192021131516
0.07038183510303497 0.04747596010565758
0.07215932011604309 0.053485576063394547
0.07257229834794998 0.0625
0.06927400082349777 0.05228365212678909
0.06896336376667023 0.053485576063394547
0.0663054883480072 0.048076923936605453
0.0711168646812439 0.051682692021131516
0.0711168646812439 0.06971869207918643 0.05012950673699379 17.891602277755737 3
0.07085995376110077 0.05288461595773697
0.06787104904651642 0.05048076808452606
0.07125521451234818 0.05288461595773697
0.07160301506519318 0.

In [38]:
test_data = make_batch(1, 13, store_dist, recall_dist, device)

In [39]:
a = test_data[:, 0 ,0].nonzero()

In [40]:
b = test_data[:a, 0, 1].nonzero()[-1]

TypeError: only integer tensors of a single element can be converted to an index

In [41]:
test_data[b, 0, 2] == test_data[a, 0, 3]


tensor([[True],
        [True]])