In [1]:

import sys
#TODO: path
sys.path.append('../../')
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision import transforms
import time
from collections import OrderedDict
from torch.utils.data import DataLoader
import pickle

#TODO: change to 256
# Number of images per training mini-batch.
BATCH_SIZE = 128#256

# Forwarded to make_SequenceWrapper / OuterWrapper when the network is built.
USE_JIT = False

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# MNIST train / test splits stored under '../../'; images are converted to
# float tensors by ToTensor. Only the training split triggers a download.
mnist = MNIST('../../', transform=transforms.ToTensor(), download=True) #distortion_transform([0,15], 3)
test = MNIST('../../', transform=transforms.ToTensor(), train=False)


# drop_last=True keeps every training batch at exactly BATCH_SIZE samples.
data_loader = DataLoader(mnist, batch_size=BATCH_SIZE, drop_last=True, num_workers=0, shuffle=True)

# The test loader keeps the final partial batch and is not shuffled.
test_loader = DataLoader(test, batch_size=1024, drop_last=False, num_workers=0)

# Experiment specification. The entries are consumed further down when the
# neuron configuration, the neuron classes, the loop architecture and the
# optimizer learning rate are selected.
like_bellec = dict([
    ('spkfn', 'bellec'),
    ('spkconfig', 0),
    ('architecture', '1L'),
    ('beta', 0.95),
    ('control_neuron', 'LIF'),
    ('mem_neuron', 'Adaptive'),
    ('lr', 1e-2),
    ('1-beta', True),
    ('decay_out', True),
])

# `spec` is an alias (not a copy) of the chosen specification.
spec = like_bellec
#spec['decay_out'] = False
#TODO: remove
#spec['1-beta'] = False

from Code.Networks import Selector, DynNetwork, OuterWrapper, LSTMWrapper, ReLuWrapper, DummyNeuron, make_SequenceWrapper, ParallelNetwork
from Code.NewNeurons2 import SeqOnlySpike, CooldownNeuron, OutputNeuron, LIFNeuron, NoResetNeuron, AdaptiveNeuron

# Configuration dict handed to every neuron constructor below. Three entries
# are copied from `spec`; OFFSET, ADAPDECAY and ADAPSCALE are fixed here.
# NOTE(review): how each key is interpreted is defined in Code.NewNeurons2,
# which is not visible from this notebook.
built_config = {
    'BETA': spec['beta'],
    'OFFSET': 2, # TODO: this?
    'SPIKE_FN': spec['spkfn'],
    '1-beta': spec['1-beta'],
    'ADAPDECAY': 0.9985,
    'ADAPSCALE': 180
}

# Sizes of the two recurrent populations ("control" and "mem").
n_control = 120
n_mem = 100

# Map the string in spec['control_neuron'] to the neuron class to instantiate.
control_lookup = {
    'LIF': LIFNeuron,
    'Disc': SeqOnlySpike,
    'NoReset': NoResetNeuron
}

# Map the string in spec['mem_neuron'] to the neuron class to instantiate.
mem_lookup = {
    'Adaptive': AdaptiveNeuron,
    'Cooldown': CooldownNeuron,
    'NoReset': NoResetNeuron
}

# Instantiate one neuron module per population. The output stage covers both
# populations (n_control+n_mem units) and is an OutputNeuron when
# spec['decay_out'] is truthy, otherwise a DummyNeuron.
control_neuron = control_lookup[spec['control_neuron']](n_control, built_config)
mem_neuron = mem_lookup[spec['mem_neuron']](n_mem, built_config)
out_neuron = OutputNeuron(n_control+n_mem, built_config) if spec['decay_out'] else DummyNeuron(n_control+n_mem, built_config)


# Two alternative wirings for ParallelNetwork. 'input' is given as a plain
# size: 81 matches the width produced by encode_input (2*40 threshold
# channels + 1 trigger channel). Every other entry looks like
# [source layer names, neuron module, synapse class or None].
# NOTE(review): that reading is inferred from usage only — confirm against
# Code.Networks.ParallelNetwork.

# "2L": control reads input+mem, mem reads control only.
loop_2L = OrderedDict([
    ('input', 81),
    ('control', [['input', 'mem'], control_neuron, nn.Linear]),
    ('mem', [['control'], mem_neuron, nn.Linear]),
    ('output', [['control', 'mem'], out_neuron, None]),
])

# "1L": control and mem both read input+control+mem.
loop_1L = OrderedDict([
    ('input', 81),
    ('control', [['input', 'control', 'mem'], control_neuron, nn.Linear]),
    ('mem', [['input', 'control', 'mem'], mem_neuron, nn.Linear]),
    ('output', [['control', 'mem'], out_neuron, None]),
])

# Both topologies share the same neuron instances; only one is used.
loop = loop_1L if spec['architecture'] == '1L' else loop_2L

loop_model = OuterWrapper(make_SequenceWrapper(ParallelNetwork(loop), USE_JIT), device, USE_JIT)

# Linear read-out from the n_control+n_mem output units to 10 class scores.
final_linear = nn.Linear(n_control+n_mem, 10).to(device)



# Optimise every parameter of the recurrent model and of the read-out with
# Adam at the learning rate from `spec`; the loss is cross-entropy.
params = list(loop_model.parameters())+list(final_linear.parameters())
lr = spec['lr']
optimizer = optim.Adam(params, lr=lr)
ce = nn.CrossEntropyLoss()

# Disabled experiment, kept as a bare string literal so it never executes.
# If enabled it would zero mem_synapse.weight[i, i+201] for the 100 mem units
# and control_synapse.weight[i, i+81] for the 120 control units.
# NOTE(review): presumably these are the self-connections, assuming the
# synapse input columns are laid out as [81 input | 120 control | 100 mem] —
# the embedded TODO says this was never verified.
'''

#TODO: check correctness here

with torch.no_grad():
    for i in range(100):
        loop_model.pretrace.model.layers.mem_synapse.weight[i, i+201] = 0

    for i in range(120):
        loop_model.pretrace.model.layers.control_synapse.weight[i, i+81] = 0

'''


# Trigger channel of shape (783+56, 1, 1): 0 for the first 783 steps and
# 1 for the trailing 56 steps.
trigger_signal = torch.cat(
    (torch.zeros([783, 1, 1], device=device), torch.ones([56, 1, 1], device=device)),
    dim=0,
)
def encode_input(curr, last, n_thresholds=40, n_trigger_steps=56):
    """Encode consecutive intensity pairs as threshold-crossing channels.

    For each time step and sample, two groups of `n_thresholds` binary
    channels are produced, followed by one trigger channel:

    * group 0 (channels ``[0, n_thresholds)``): channel ``j`` is 1 when
      ``n_thresholds * curr < j < n_thresholds * last`` (value decreased
      across level ``j``);
    * group 1 (channels ``[n_thresholds, 2*n_thresholds)``): channel ``j`` is
      1 when ``n_thresholds * last < j < n_thresholds * curr`` (value
      increased across level ``j``);
    * trigger (last channel): 0 during the input steps, 1 during the
      `n_trigger_steps` steps appended after the input.

    The trigger channel is built here on ``curr.device`` so the function no
    longer depends on a module-level tensor living on a possibly different
    device, and the sequence length is taken from `curr` instead of being
    fixed at 783.

    Args:
        curr: tensor of shape ``(steps, batch, 1)`` with the current values.
        last: tensor of the same shape with the previous values.
        n_thresholds: number of threshold levels per direction (default 40).
        n_trigger_steps: number of trailing steps during which only the
            trigger channel is active (default 56).

    Returns:
        Float tensor of shape
        ``(steps + n_trigger_steps, batch, 2 * n_thresholds + 1)`` on
        ``curr.device``. With the defaults and 783 input steps this is
        ``(839, batch, 81)``, identical to the previous implementation.
    """
    n_steps, batch = curr.shape[0], curr.shape[1]
    # Threshold indices 0..n_thresholds-1, built once and reused.
    levels = torch.arange(n_thresholds, device=curr.device)
    scaled_curr = n_thresholds * curr
    scaled_last = n_thresholds * last

    out = torch.zeros([n_steps + n_trigger_steps, batch, 2 * n_thresholds + 1], device=curr.device)
    # Broadcasting (steps, batch, 1) against (n_thresholds,) gives
    # (steps, batch, n_thresholds) masks.
    out[:n_steps, :, :n_thresholds] = ((levels < scaled_last) & (levels > scaled_curr)).float()
    out[:n_steps, :, n_thresholds:2 * n_thresholds] = ((levels > scaled_last) & (levels < scaled_curr)).float()
    # Trigger channel is only active after the input sequence has ended.
    out[n_steps:, :, -1] = 1
    return out

# Per-iteration training statistics; each list grows by one entry per batch.
stats = {
    'grad_norm': [],
    'loss': [],
    'acc': [],
    'batch_var': []
}

# Not referenced by any of the visible cells; kept for compatibility.
grad_norm_history = []
def record_norm(parameters=None):
    """Append the global gradient L2 norm to ``stats['grad_norm']``.

    The global norm is the L2 norm of the vector of per-parameter gradient
    norms. Parameters whose ``.grad`` is ``None`` (they did not take part in
    the last backward pass) are skipped instead of raising AttributeError.
    If no parameter has a gradient, 0.0 is recorded.

    Args:
        parameters: iterable of tensors to inspect. Defaults to the
            module-level ``params`` list, matching the original behaviour.
    """
    if parameters is None:
        parameters = params
    norms = [p.grad.norm().item() for p in parameters if p.grad is not None]
    stats['grad_norm'].append(torch.tensor(norms).norm().item())


# Training iteration budget; the training cell loops while i < ITERATIONS.
ITERATIONS = 36000

In [4]:
# Training loop: iterate over epochs until the iteration budget is used up.
# The budget is only checked between epochs, so the epoch in which
# ITERATIONS is reached still runs to completion.
start = time.time()
i = 1          # 1-based global iteration counter
sumloss = 0    # running sums for the 20-iteration summary line
sumacc = 0
k = 0          # epoch counter
while i < ITERATIONS:
    print('Epoch: ', k)
    k = k + 1
    for inp, target in data_loader:
        batchstart = time.time()
        # Flatten each image to a pixel sequence and make time the leading
        # axis: (batch, ...) -> (pixels, batch, 1). The batch size is read
        # from the tensor itself rather than assumed to equal BATCH_SIZE.
        x = inp.view(inp.shape[0], -1, 1).transpose(0,1).to(device)
        # Encode the transition from each pixel to the next one.
        x = encode_input(x[1:], x[:-1])
        #print(x.shape)
        target = target.to(device)
        optimizer.zero_grad()
        outputs, _ = loop_model(x)
        # Average the network output over the last 56 steps, i.e. the steps
        # during which the trigger channel is active.
        meaned = outputs[-56:].mean(dim=0) #TODO: what is this value really in bellec?
        out_final = final_linear(meaned)
        loss = ce(out_final, target)

        loss.backward()
        optimizer.step()

        with torch.no_grad():
            record_norm()
            stats['loss'].append(loss.item())
            acc = (torch.argmax(out_final, 1) == target).float().mean().item()
            stats['acc'].append(acc)
            # Variance across the batch, averaged over the output units.
            batch_var = meaned.var(0).mean().item()
            stats['batch_var'].append(batch_var)
        print(loss.item(), acc, batch_var, loop_model.pretrace.model.layers.control_synapse.weight.grad.norm().item()*20, target[0].item(), outputs.norm().item())

        sumloss += loss.item()
        sumacc += acc
        if i%20 == 0:
            print(loss.item(), sumloss/20, sumacc/20, time.time()-batchstart, batch_var) #torch.argmax(outputs[-1], 1).float().var()
            sumloss = 0
            sumacc = 0
        if i%2500 == 0:
            lr = lr * 0.8
            # Decay the learning rate in place. Re-creating the optimizer
            # here would throw away Adam's running moment estimates.
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('Learning Rate: ', lr)
        i += 1
    # Checkpoint the statistics after every epoch; the context manager makes
    # sure the file is flushed and closed.
    with open('loc_stats', 'wb') as stats_file:
        pickle.dump(stats, stats_file)
    #model.save('../../models/adap_clip5_'+str(k))
    #post_model.save('../../models/post_big11_'+str(k))


print('Total time: ', time.time()-start)




#TODO: what about data augmentation?

Epoch:  0
2.3025848865509033 0.125 0.0 0.006161743076518178 5 0.0
2.3026351928710938 0.0625 0.0 0.006130308611318469 2 0.0
2.3023581504821777 0.125 0.0 0.005754795856773853 9 0.0
2.3022964000701904 0.109375 0.0 0.006037699640728533 6 0.0
2.305745840072632 0.0859375 0.0 0.005270394613035023 2 0.0
2.3032639026641846 0.09375 0.0 0.004994799382984638 7 0.0
2.3048617839813232 0.09375 0.0 0.0045761046931147575 6 0.0
2.3012917041778564 0.0703125 0.0 0.007414498832076788 6 0.0
2.301555633544922 0.1171875 6.604082831604193e-19 0.004308718489482999 5 0.6755813360214233
2.2982327938079834 0.125 3.1570808051867457e-18 0.005354820750653744 5 1.7809325456619263
2.3051164150238037 0.125 3.581160009784616e-15 0.005532428622245789 6 4.745429515838623
2.300459623336792 0.1171875 1.9023432494254342e-15 0.006046640337444842 0 6.503476619720459
2.304121971130371 0.0859375 8.262989359582307e-14 0.004799041780643165 1 8.712943077087402
2.3019790649414062 0.09375 2.434931770101123e-11 0.008298408356495202 0 1

KeyboardInterrupt: 

In [2]:
# Print the name and shape of every registered parameter, first for the
# recurrent model and then for the linear read-out.
for module in (loop_model, final_linear):
    for name, p in module.named_parameters():
        print(name, p.shape)



pretrace.model.layers.control_synapse.weight torch.Size([120, 301])
pretrace.model.layers.control_synapse.bias torch.Size([120])
pretrace.model.layers.control.initial_mem torch.Size([120])
pretrace.model.layers.mem_synapse.weight torch.Size([100, 301])
pretrace.model.layers.mem_synapse.bias torch.Size([100])
pretrace.model.layers.mem.initial_mem torch.Size([100])
pretrace.model.layers.output.initial_mem torch.Size([220])
weight torch.Size([10, 220])
bias torch.Size([10])


In [2]:
import pickle
# Load the reference weights used for the transplant below.
# NOTE: pickle.load can execute arbitrary code; only use this with a file
# from a trusted source.
with open('../../weight_transplant', 'rb') as weights_file:
    o_weights = pickle.load(weights_file)


In [3]:
# Transplant externally trained weights (loaded into `o_weights`) into the
# model. The stored matrices are transposed and concatenated as
# [recurrent | input] columns; `o3.shape` is inspected in a later cell.
o1 = torch.tensor(o_weights['RecWeights/RecurrentWeight:0']).t()
o2 = torch.tensor(o_weights['InputWeights/InputWeight:0']).t()
o3 = torch.cat((o1,o2), dim=1)
# NOTE(review): o3 orders its columns as [recurrent | input], whereas the
# disabled diagonal-zeroing snippet in the first cell indexes self-connections
# at columns 81+i / 201+i, which would imply an [input | control | mem] column
# layout for the synapse weights. Verify the ordering against
# Code.Networks.ParallelNetwork before trusting this transplant.
with torch.no_grad():
    loop_model.pretrace.model.layers.control_synapse.bias.zero_()
    loop_model.pretrace.model.layers.mem_synapse.bias.zero_()
    # Rows 0..119 go to the control population, the remaining rows to mem.
    # Only the first 300 input columns are overwritten; any further column
    # keeps its current value.
    loop_model.pretrace.model.layers.control_synapse.weight[:, :300] = o3[:120]
    loop_model.pretrace.model.layers.mem_synapse.weight[:, :300] = o3[120:]
    final_linear.bias.zero_()
    # In-place copy keeps the existing Parameter (device, layout) instead of
    # swapping its storage for a transposed CPU view via `.data`.
    final_linear.weight.copy_(torch.tensor(o_weights['out_weight:0']).t())
loop_model.to(device)
final_linear.to(device)

# From here on only the two synapse weight matrices and the read-out layer
# are optimised; synapse biases and the other parameters stay fixed.
params = [ loop_model.pretrace.model.layers.control_synapse.weight, loop_model.pretrace.model.layers.mem_synapse.weight, final_linear.weight, final_linear.bias]
optimizer = optim.Adam(params, lr=lr)


In [2]:
# Zero the synapse and read-out biases and scale both synapse weight
# matrices by 20, in place. Re-running this cell scales the weights again.
with torch.no_grad():
    for synapse in (loop_model.pretrace.model.layers.control_synapse,
                    loop_model.pretrace.model.layers.mem_synapse):
        synapse.bias.mul_(0)
    for synapse in (loop_model.pretrace.model.layers.control_synapse,
                    loop_model.pretrace.model.layers.mem_synapse):
        synapse.weight.mul_(20)
    final_linear.bias.mul_(0)
loop_model.to(device)
final_linear.to(device)

# Restrict optimisation to the synapse weights and the read-out layer, and
# start from a fresh Adam optimizer at the current learning rate.
params = [
    loop_model.pretrace.model.layers.control_synapse.weight,
    loop_model.pretrace.model.layers.mem_synapse.weight,
    final_linear.weight,
    final_linear.bias,
]
optimizer = optim.Adam(params, lr=lr)

In [None]:
# Show the name and shape of every array stored in the loaded weight file.
for n in o_weights:
    p = o_weights[n]
    print(n, p.shape)

In [18]:
# Norm of the stored read-out weights.
torch.tensor(o_weights['out_weight:0']).norm()

tensor(4.3832)

In [24]:
# Norm of the model's current read-out weights.
final_linear.weight.norm()

tensor(4.3832, grad_fn=<NormBackward0>)

In [20]:
# Shape of the stored read-out weights after transposing.
torch.tensor(o_weights['out_weight:0']).t().shape

torch.Size([10, 220])

In [22]:
# Inspect the raw values (and device) of the read-out weight tensor.
final_linear.weight.data

tensor([[-0.0619, -0.0065,  0.0658,  ...,  0.0344,  0.0014, -0.0114],
        [-0.0217, -0.0400,  0.0581,  ...,  0.0532,  0.0516, -0.0228],
        [-0.0063,  0.0566, -0.0608,  ..., -0.0267,  0.0659,  0.0009],
        ...,
        [-0.0648, -0.0411,  0.0619,  ...,  0.0228, -0.0491, -0.0315],
        [-0.0311, -0.0110, -0.0269,  ...,  0.0393, -0.0209, -0.0272],
        [-0.0143,  0.0655, -0.0398,  ...,  0.0148,  0.0179,  0.0602]],
       device='cuda:0')

In [29]:
# Norm of the stored recurrent weight matrix.
torch.tensor(o_weights['RecWeights/RecurrentWeight:0']).norm()


tensor(14.8264)

In [31]:
# Weight norm of a freshly constructed 220x220 nn.Linear — presumably a
# baseline to compare against the stored recurrent weight norm.
nn.Linear(220,220).weight.norm()

tensor(8.5887, grad_fn=<NormBackward0>)

In [36]:
# Shape of the concatenated [recurrent | input] matrix built for the transplant.
o3.shape

torch.Size([220, 300])