In [1]:

import sys
#TODO: path
sys.path.append('../../')
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision import transforms
import time
from collections import OrderedDict
from torch.utils.data import DataLoader
import pickle

# --- Run configuration and data loading ---------------------------------
#TODO: change to 256
BATCH_SIZE = 128#256

USE_JIT = False

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Width of the encoded input: encode_input() emits 2*40 threshold-crossing
# channels plus one trigger channel.
n_in = 81

# Training and test splits of MNIST, stored two directories up.
mnist = MNIST('../../', transform=transforms.ToTensor(), download=True) #distortion_transform([0,15], 3)
test = MNIST('../../', transform=transforms.ToTensor(), train=False)


# drop_last=True keeps every training batch at exactly BATCH_SIZE samples,
# which the training loop relies on when it reshapes the images.
data_loader = DataLoader(mnist, batch_size=BATCH_SIZE, drop_last=True, num_workers=0, shuffle=True)

test_loader = DataLoader(test, batch_size=1024, drop_last=False, num_workers=0)

# Experiment specification. Key order matches the original literal; the keys
# are read further down when the neuron config and the network are built.
like_bellec = dict([
    ('spkfn', 'bellec'),          # spike-function name forwarded as SPIKE_FN
    ('spkconfig', 0),
    ('architecture', '1L'),       # '1L' selects loop_1L, anything else loop_2L
    ('beta', 0.95),               # forwarded as BETA
    ('control_neuron', 'LIF'),    # key into control_lookup
    ('mem_neuron', 'Adaptive'),   # key into mem_lookup
    ('lr', 1e-2),                 # initial Adam learning rate
    ('1-beta', True),
    ('decay_out', True),          # True -> OutputNeuron, False -> DummyNeuron
])

# The active spec is the very same dict object (no copy is made).
spec = like_bellec
# Optional overrides, currently disabled:
#   spec['decay_out'] = False
#   spec['1-beta'] = False   (TODO: remove)

from Code.Networks import Selector, DynNetwork, OuterWrapper, LSTMWrapper, ReLuWrapper, DummyNeuron, make_SequenceWrapper, ParallelNetwork
from Code.NewNeurons2 import SeqOnlySpike, CooldownNeuron, OutputNeuron, LIFNeuron, NoResetNeuron, AdaptiveNeuron

# Configuration dict handed to every neuron constructor below.
built_config = {
    'BETA': spec['beta'],
    'OFFSET': 2,  # TODO: this?
    'SPIKE_FN': spec['spkfn'],
    '1-beta': spec['1-beta'],
    'ADAPDECAY': 0.9985,
    'ADAPSCALE': 180,
}

# Population sizes of the two recurrent groups.
n_control = 120
n_mem = 100

# Spec name -> neuron class, one table per population.
control_lookup = dict(LIF=LIFNeuron, Disc=SeqOnlySpike, NoReset=NoResetNeuron)
mem_lookup = dict(Adaptive=AdaptiveNeuron, Cooldown=CooldownNeuron, NoReset=NoResetNeuron)

# Resolve the classes first, then instantiate in the same order as before:
# control, mem, output.
control_cls = control_lookup[spec['control_neuron']]
mem_cls = mem_lookup[spec['mem_neuron']]
out_cls = OutputNeuron if spec['decay_out'] else DummyNeuron

control_neuron = control_cls(n_control, built_config)
mem_neuron = mem_cls(n_mem, built_config)
# The readout population spans both recurrent groups.
out_neuron = out_cls(n_control+n_mem, built_config)


# Network wiring tables consumed by ParallelNetwork.
# Each entry maps a layer name to [source layer names, neuron instance,
# synapse class]; 'input' maps to the input width instead.
# (presumably ParallelNetwork concatenates the listed sources before the
# synapse -- TODO confirm against Code.Networks.)
# NOTE: both tables reference the same control/mem/out neuron instances.

# Two-layer variant: control reads input+mem, mem reads control only.
loop_2L = OrderedDict([
    ('input', n_in),
    ('control', [['input', 'mem'], control_neuron, nn.Linear]),
    ('mem', [['control'], mem_neuron, nn.Linear]),
    ('output', [['control', 'mem'], out_neuron, None]),
])

# One-layer variant: control and mem both read input, control and mem.
loop_1L = OrderedDict([
    ('input', n_in),
    ('control', [['input', 'control', 'mem'], control_neuron, nn.Linear]),
    ('mem', [['input', 'control', 'mem'], mem_neuron, nn.Linear]),
    ('output', [['control', 'mem'], out_neuron, None]),
])

# Pick the wiring requested by the spec.
loop = loop_1L if spec['architecture'] == '1L' else loop_2L

# Wrap the wiring into the model that the training loop calls.
loop_model = OuterWrapper(make_SequenceWrapper(ParallelNetwork(loop), USE_JIT), device, USE_JIT)

# Linear readout from the n_control+n_mem output features to 10 logits.
final_linear = nn.Linear(n_control+n_mem, 10).to(device)



# Loss used by the training loop.
ce = nn.CrossEntropyLoss()

# Everything Adam updates: the recurrent model first, then the readout.
params = [*loop_model.parameters(), *final_linear.parameters()]
lr = spec['lr']
optimizer = optim.Adam(params, lr=lr)

# Disabled code kept inside a bare string literal; nothing below is executed.
# It appears to zero the self-connections of the mem (column offset 201 =
# 81 + 120) and control (column offset 81) synapses, given the
# ['input', 'control', 'mem'] source order of loop_1L -- TODO confirm
# before re-enabling.
'''

#TODO: check correctness here

with torch.no_grad():
    for i in range(100):
        loop_model.pretrace.model.layers.mem_synapse.weight[i, i+201] = 0

    for i in range(120):
        loop_model.pretrace.model.layers.control_synapse.weight[i, i+81] = 0

'''


# Trigger channel over the 783+56 time steps: 0 for the first 783 steps,
# 1 for the final 56. Shape (839, 1, 1), expanded over the batch by the user.
trigger_signal = torch.cat((
    torch.zeros([783, 1, 1], device=device),
    torch.ones([56, 1, 1], device=device),
), dim=0)
def encode_input(curr, last, n_steps=783, n_trigger=56):
    """Encode a pixel sequence as threshold-crossing channels plus a trigger.

    The result has 81 channels per time step:

    * channels 0-39:  channel ``j`` is 1 when ``40*curr < j < 40*last``
      (the value dropped across level ``j``),
    * channels 40-79: channel ``j-40`` is 1 when ``40*last < j-40 < 40*curr``
      (the value rose across level ``j-40``),
    * channel 80: trigger, 0 during the first ``n_steps`` steps and 1 during
      the trailing ``n_trigger`` steps.

    All tensors are created on ``curr.device``, so the function does not
    depend on the module-level ``trigger_signal`` or on the global device.

    Args:
        curr: values at the current step, shape ``(n_steps, batch, 1)``.
        last: values at the previous step, same shape as ``curr``.
        n_steps: number of encoded time steps (default 783, as before).
        n_trigger: number of trailing trigger-only steps (default 56).

    Returns:
        Float tensor of shape ``(n_steps + n_trigger, batch, 81)``.
    """
    batch = curr.shape[1]
    # Threshold levels 0..39, built once instead of four times.
    levels = torch.arange(40, device=curr.device)
    scaled_curr = 40 * curr
    scaled_last = 40 * last

    out = torch.zeros([n_steps + n_trigger, batch, 81], device=curr.device)
    # Falling crossings: level lies strictly between the new and the old value.
    out[:n_steps, :, :40] = ((levels < scaled_last) & (levels > scaled_curr)).float()
    # Rising crossings: level lies strictly between the old and the new value.
    out[:n_steps, :, 40:80] = ((levels > scaled_last) & (levels < scaled_curr)).float()
    # Trigger channel is on only after the sequence has ended.
    out[n_steps:, :, 80] = 1
    return out

# Per-iteration training statistics; each key starts with its own empty list
# and is appended to once per batch by the training loop.
stats = {metric: [] for metric in ('grad_norm', 'loss', 'acc', 'batch_var')}

# Kept for backward compatibility; nothing in this notebook appends to it.
grad_norm_history = []
def record_norm(parameters=None, history=None):
    """Append the global gradient norm of the trained parameters to a history.

    The global norm is the L2 norm of the per-parameter gradient norms.
    Parameters that have no gradient (``p.grad is None``, e.g. tensors that
    did not take part in the last backward pass) are skipped instead of
    raising ``AttributeError``.

    Args:
        parameters: iterable of tensors to inspect; defaults to the
            module-level ``params`` list.
        history: list to append to; defaults to ``stats['grad_norm']``.
    """
    if parameters is None:
        parameters = params
    if history is None:
        history = stats['grad_norm']
    norms = [p.grad.norm().item() for p in parameters if p.grad is not None]
    # torch.tensor([]).norm() is 0, so an empty list records 0.0.
    history.append(torch.tensor(norms).norm().item())

# Upper bound on the training-loop iteration counter (36 thousand batches).
ITERATIONS = 36 * 1000

In [None]:
# Training loop: run the encoded images through the recurrent model and
# classify from the mean output over the last 56 (trigger) time steps.
start = time.time()
i = 1          # global iteration counter (1-based)
sumloss = 0    # running sums for the 20-iteration progress report
sumacc = 0
k = 0          # epoch counter
while i < ITERATIONS:
    print('Epoch: ', k)
    k = k + 1
    for inp, target in data_loader:
        if i >= ITERATIONS:
            # Honour the iteration budget even in the middle of an epoch.
            break
        batchstart = time.time()
        # Flatten every image into a sequence of shape (time, batch, 1).
        x = inp.view(BATCH_SIZE, -1, 1).transpose(0,1).to(device)
        # Encode each step from its value and the previous step's value.
        x = encode_input(x[1:], x[:-1])
        target = target.to(device)
        optimizer.zero_grad()
        outputs, _ = loop_model(x)
        meaned = outputs[-56:].mean(dim=0) #TODO: what is this value really in bellec?
        out_final = final_linear(meaned)
        loss = ce(out_final, target)

        loss.backward()
        optimizer.step()

        with torch.no_grad():
            record_norm()
            stats['loss'].append(loss.item())
            acc = (torch.argmax(out_final, 1) == target).float().mean().item()
            stats['acc'].append(acc)
            # Variance across the batch, averaged over features.
            batch_var = meaned.var(0).mean().item()
            stats['batch_var'].append(batch_var)

        sumloss += loss.item()
        sumacc += acc
        if i%20 == 0:
            print(loss.item(), sumloss/20, sumacc/20, time.time()-batchstart, batch_var) #torch.argmax(outputs[-1], 1).float().var()
            sumloss = 0
            sumacc = 0
        if i%2500 == 0:
            lr = lr * 0.8
            # Decay the learning rate in place; building a new Adam here
            # would throw away its running moment estimates.
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print('Learning Rate: ', lr)
        i += 1
    # Checkpoint the statistics once per epoch; `with` closes the handle.
    with open('loc_stats', 'wb') as stats_file:
        pickle.dump(stats, stats_file)
    #model.save('../../models/adap_clip5_'+str(k))
    #post_model.save('../../models/post_big11_'+str(k))


print('Total time: ', time.time()-start)




#TODO: what about data augmentation?

Epoch:  0
2.3014233112335205 2.302599775791168 0.10078125 2.278226852416992 6.317734113636675e-10
2.2966341972351074 2.3008622407913206 0.128515625 2.2598235607147217 1.4453613630394102e-06
2.1057357788085938 2.2633551955223083 0.182421875 2.242460250854492 0.0001427228271495551
1.9980257749557495 2.101233237981796 0.176953125 2.267404556274414 0.0007529367576353252
1.9813475608825684 2.0273540318012238 0.201953125 2.2706615924835205 0.0017053470946848392
1.9449255466461182 1.9629006505012512 0.248046875 2.2592811584472656 0.00238977768458426
1.9568041563034058 1.9248777449131012 0.291796875 2.263963222503662 0.0016061929054558277
1.8096979856491089 1.8500484466552733 0.291796875 2.2819485664367676 0.004246073309332132
1.8092256784439087 1.8441750228404998 0.30234375 2.3208301067352295 0.005605306010693312
1.7024308443069458 1.7333423137664794 0.3609375 2.262558937072754 0.004948707763105631
1.409973382949829 1.6296229720115663 0.408203125 2.2701761722564697 0.007296879310160875
1.3958

In [None]:
# Print name and shape of every parameter: the recurrent model first, then
# the readout layer.
for module in (loop_model, final_linear):
    for name, p in module.named_parameters():
        print(name, p.shape)



In [2]:
import pickle
# Load the externally trained weights that get transplanted into the model.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('../../weight_transplant_enc', 'rb') as weights_file:
    o_weights = pickle.load(weights_file)


In [3]:
# Transplant the loaded weights into the control/mem synapses and the readout.
# The key names look like TensorFlow variable names; presumably the arrays are
# stored as (in, out), hence the transposes -- TODO confirm.
o1 = torch.tensor(o_weights['RecWeights/RecurrentWeight:0']).t()
o2 = torch.tensor(o_weights['InputWeights/InputWeight:0']).t()
# Columns: input weights first, then recurrent weights.
o3 = torch.cat((o2,o1), dim=1)
with torch.no_grad():
    # Zero the biases in place.
    loop_model.pretrace.model.layers.control_synapse.bias *= 0
    loop_model.pretrace.model.layers.mem_synapse.bias *= 0
    # Rows [:120] feed the control population, rows [120:] the mem population.
    # Only the first 300 input columns are overwritten.
    # NOTE(review): with n_in = 81 the synapses would have 81+120+100 = 301
    # input columns, so one column would keep its initial value and the
    # column alignment should be double-checked -- TODO confirm.
    loop_model.pretrace.model.layers.control_synapse.weight.data[:,:300] = o3[:120]
    loop_model.pretrace.model.layers.mem_synapse.weight.data[:,:300] = o3[120:]
    final_linear.bias *= 0
    # Replaces the readout weight storage with a tensor built from the
    # loaded array; the .to(device) calls below move it to the target device.
    final_linear.weight.data = torch.tensor(o_weights['out_weight:0']).t()
loop_model.to(device)
final_linear.to(device)

# Rebuild the optimiser over the transplanted tensors only; the synapse
# biases are left out of this list, so Adam does not update them.
params = [ loop_model.pretrace.model.layers.control_synapse.weight, loop_model.pretrace.model.layers.mem_synapse.weight, final_linear.weight, final_linear.bias]
optimizer = optim.Adam(params, lr=lr)


In [None]:
# Zero the synapse and readout biases and scale both synapse weight matrices
# by 20, all in place. Running this cell again scales the weights again.
with torch.no_grad():
    layers = loop_model.pretrace.model.layers
    for synapse in (layers.control_synapse, layers.mem_synapse):
        synapse.bias.mul_(0)
        synapse.weight.mul_(20)
    final_linear.bias.mul_(0)
loop_model.to(device)
final_linear.to(device)

# Re-resolve the layers after .to() and rebuild the optimiser over the two
# synapse weight matrices and the readout weight and bias.
layers = loop_model.pretrace.model.layers
params = [
    layers.control_synapse.weight,
    layers.mem_synapse.weight,
    final_linear.weight,
    final_linear.bias,
]
optimizer = optim.Adam(params, lr=lr)

In [6]:
# Probe the synapses with a one-hot vector (only element 0 set) and report
# the norm of the concatenated control+mem responses.
# NOTE(review): 221 presumably equals 1 + 120 + 100, i.e. a model built with
# a single input channel; with n_in = 81 the synapses would expect 301
# features -- TODO confirm which configuration this cell belongs to.
samp_inp = torch.zeros([221], device=device)
samp_inp[0] = 1
torch.cat((loop_model.pretrace.model.layers.control_synapse(samp_inp),
           loop_model.pretrace.model.layers.mem_synapse(samp_inp)), dim=0).norm()

tensor(14.9409, device='cuda:0', grad_fn=<NormBackward0>)

In [8]:
# Response of control unit 0 to the probe vector, divided by 20
# (presumably undoing the x20 weight scaling applied in another cell).
loop_model.pretrace.model.layers.control_synapse(samp_inp)[0]/20

tensor(0.0882, device='cuda:0', grad_fn=<DivBackward0>)

In [4]:
# Smoke run: 10 time steps, batch of 1, a single input feature held at 1.
# NOTE(review): a feature size of 1 does not match n_in = 81 from the setup
# cell; this presumably ran against a differently configured model.
loop_model(torch.ones([10,1,1], device=device))

tensor(0.0882, device='cuda:0', grad_fn=<SelectBackward>)
tensor(0.1720, device='cuda:0', grad_fn=<SelectBackward>)
tensor(0.2516, device='cuda:0', grad_fn=<SelectBackward>)
tensor(0.3272, device='cuda:0', grad_fn=<SelectBackward>)
tensor(0.3991, device='cuda:0', grad_fn=<SelectBackward>)
tensor(0.4673, device='cuda:0', grad_fn=<SelectBackward>)
tensor(0.5321, device='cuda:0', grad_fn=<SelectBackward>)
tensor(0.5937, device='cuda:0', grad_fn=<SelectBackward>)
tensor(0.6523, device='cuda:0', grad_fn=<SelectBackward>)
tensor(0.7078, device='cuda:0', grad_fn=<SelectBackward>)


(tensor([[[0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.]],
 
         ...,
 
         [[0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
        grad_fn=<CopySlices>),
 (((tensor([[ 0.7078,  0.1606,  0.3927,  0.8992,  0.7494, -0.3921,  0.3812, -0.0607,
             -0.0414,  0.1648,  0.0578,  0.5835,  0.3054,  0.0488,  0.1781,  0.1339,
              0.5995, -0.0823,  0.1256, -0.3427, -1.0244,  0.2623,  0.3469, -0.2978,
              0.9108, -0.5836,  0.0184, -0.0751,  0.6150,  0.5896,  0.0622,  0.1517,
             -0.3562, -0.7948, -0.1396,  0.0627,  0.4937,  0.4825, -0.1554, -0.1213,
             -0.4207, -0.5698, -0.6847,  0.7828, -0.2045, -0.1758, -0.5027,  0.3120,
             -0.6476, -0.0854, -0.3593,  0.1552, -0.2050, -0.4737, -0.0113,  0.1719,
              0.0267,  0.1214, -0.2545, -0.1456, -0.2698, -0.1443, -0.32

In [None]:
# List every array stored in the loaded checkpoint together with its shape.
for weight_name in o_weights:
    print(weight_name, o_weights[weight_name].shape)

In [None]:
# Norm of the loaded readout weights, for comparison with final_linear.
torch.tensor(o_weights['out_weight:0']).norm()

In [None]:
# Norm of the readout weights currently held by the model.
final_linear.weight.norm()

In [None]:
# Shape of the loaded readout weights after transposing.
torch.tensor(o_weights['out_weight:0']).t().shape

In [None]:
# Displays the full readout weight tensor.
# NOTE(review): consider showing .shape or a slice instead of the whole tensor.
final_linear.weight.data

In [None]:
# Norm of the loaded recurrent weight matrix.
torch.tensor(o_weights['RecWeights/RecurrentWeight:0']).norm()


In [None]:
# Weight norm of a freshly initialised 220x220 linear layer, as a reference
# scale; the value changes from run to run because the init is random.
nn.Linear(220,220).weight.norm()

In [None]:
# Shape of the concatenated [input | recurrent] weight matrix.
o3.shape