In [1]:
import sys
sys.path.append('../../')
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision import transforms
import time
from collections import OrderedDict
from torch.utils.data import DataLoader
from wip.Code.distortion import distortion_transform

In [2]:
BATCH_SIZE = 128

USE_JIT = False

device = torch.device('cuda')

In [3]:
mnist = MNIST('../../', transform=transforms.ToTensor()) #distortion_transform([0,15], 3)
test = MNIST('../../', transform=transforms.ToTensor(), train=False)

In [4]:
data_loader = DataLoader(mnist, batch_size=BATCH_SIZE, drop_last=True, num_workers=0, shuffle=True)

test_loader = DataLoader(test, batch_size=1024, drop_last=False, num_workers=0)


In [5]:
from Code.Networks import Selector, DynNetwork, OuterWrapper, LSTMWrapper, ReLuWrapper, DummyNeuron, make_SequenceWrapper, ParallelNetwork
from Code.NewNeurons import SeqOnlySpike, CooldownNeuron, OutputNeuron, LIFNeuron, NoResetNeuron, AdaptiveNeuron

base_config = {
    'ALPHA': 0,
    'BETA': 0.9,
    'OFFSET': 2,
    'RESET_ZERO': False,
    'DECODING': 'potential',
    'SPIKE_FN': 'ss'
}

heavyside = {
    **base_config,
    'BETA': 1,
    'OFFSET': 7
}

mem_lif = {
    **base_config,
    'BETA': 0.5
}

'''mem_loop = OrderedDict([
    ('input', 1),
    ('pre_mem', [['input', 'output'], LIFNeuron(128, mem_lif), nn.Linear]),
    ('output', [['pre_mem'], CooldownNeuron(128, heavyside), nn.Linear]),
])'''

mem_loop = OrderedDict([
    ('input', 1),
    ('pre_mem', [['input', 'output', 'shortterm'], NoResetNeuron(128, base_config), nn.Linear]),
    ('shortterm', [['pre_mem'], CooldownNeuron(64, base_config), nn.Linear]),
    ('output', [['pre_mem'], CooldownNeuron(128, heavyside), nn.Linear]),
])

post_mem = OrderedDict([
    ('input', 128),
    ('pre_mem', [['input'], make_SequenceWrapper(LIFNeuron(128, base_config), USE_JIT), nn.Linear]),
    ('output', [['pre_mem'], DummyNeuron(10), nn.Linear]),
])

architecture = OrderedDict([
    ('input', 1),
    ('mem_loop', [['input'], make_SequenceWrapper(DynNetwork(mem_loop), USE_JIT), None]),
    ('post_mem', [['mem_loop'], SeqOnlySpike(128, base_config), nn.Linear]),
    ('output', [['post_mem'], OutputNeuron(10, heavyside), nn.Linear]),
])

architecturelstm = OrderedDict([
    ('input', 1),
    ('lstm', [['input'], LSTMWrapper(1, 128), None]),
    ('post_mem', [['lstm'], ReLuWrapper(128), nn.Linear]),
    ('output', [['post_mem'], DummyNeuron(10), nn.Linear]),
])


adap_config = {
    'ALPHA': 0,
    'BETA': 0.95,
    'OFFSET': 2,
    'ADAPDECAY': 0.998,
    'RESET_ZERO': False,
    'DECODING': 'potential',
    'SPIKE_FN': 'bellec'
}

adap_arch = OrderedDict([
    ('input', 81),
    ('bundled', [['input', 'adaptive', 'regular'], DummyNeuron(81+100+120), None]),
    ('adaptive', [['bundled'], AdaptiveNeuron(100, adap_config), nn.Linear]),
    ('regular', [['bundled'], LIFNeuron(120, adap_config), nn.Linear]),
    ('output', [['bundled'], DummyNeuron(10), nn.Linear]),
])

adap_arch2 = OrderedDict([
    ('input', 81),
    ('adaptive', [['input', 'adaptive', 'regular'], AdaptiveNeuron(100, adap_config), nn.Linear]),
    ('regular', [['input', 'adaptive', 'regular'], LIFNeuron(120, adap_config), nn.Linear]),
    ('output', [['adaptive', 'regular'], DummyNeuron(10), nn.Linear]),
])

adap_arch3 = OrderedDict([
    ('input', 81),
    ('regular', [['input', 'regular'], LIFNeuron(120, adap_config), nn.Linear]),
    ('output', [['regular'], DummyNeuron(10), nn.Linear]),
])

adap_arch4 = OrderedDict([
    ('input', 81),
    ('adap', [['input', 'adap'], NoResetNeuron(120, base_config), nn.Linear]),
    ('output', [['adap'], DummyNeuron(10), nn.Linear]),
])

mem_loop2 = OrderedDict([
    ('input', 81),
    ('pre_mem', [['input', 'output'], NoResetNeuron(128, base_config), nn.Linear]),
    ('output', [['pre_mem'], CooldownNeuron(128, heavyside), nn.Linear]),
])

mem_loop3 = OrderedDict([
    ('input', 81),
    ('pre_mem', [['input', 'output'], LIFNeuron(128, adap_config), nn.Linear]),
    ('output', [['pre_mem'], AdaptiveNeuron(128, adap_config), nn.Linear]),
])

cd_full = OrderedDict([
    ('input', 81),
    ('mem_loop', [['input'], make_SequenceWrapper(DynNetwork(mem_loop2), USE_JIT), None]),
    ('post_mem', [['mem_loop'], make_SequenceWrapper(LIFNeuron(128, base_config), USE_JIT), nn.Linear]),
    ('output', [['post_mem'], DummyNeuron(10), nn.Linear]),
])

ada_full = OrderedDict([
    ('input', 81),
    ('mem_loop', [['input'], make_SequenceWrapper(DynNetwork(mem_loop3), USE_JIT), None]),
    ('post_mem', [['mem_loop'], make_SequenceWrapper(LIFNeuron(128, adap_config), USE_JIT), nn.Linear]),
    ('output', [['post_mem'], DummyNeuron(10), nn.Linear]),
])

#TODO: fix output


In [6]:
#mem_model = OuterWrapper(torch.load('../../models/mem_big10_5'), device, USE_JIT)
#post_model = OuterWrapper(torch.load('../../models/post_big10_5'), device, USE_JIT)
#mem_model = OuterWrapper(n_mem, device, USE_JIT)
#mem_model = OuterWrapper(make_SequenceWrapper(DynNetwork(mem_loop), USE_JIT), device, USE_JIT)
#post_model = OuterWrapper(DynNetwork(post_mem), device, USE_JIT)

#144, 150, 137, 150
#model = OuterWrapper(torch.load('../../models/snn4_3'), device, USE_JIT)

#model = OuterWrapper(DynNetwork(architecture), device, USE_JIT)

#model = OuterWrapper(DynNetwork(architecturelstm), device, USE_JIT)

#model = OuterWrapper(DynNetwork(cd_full), device, USE_JIT)
#model = OuterWrapper(make_SequenceWrapper(ParallelNetwork(adap_arch2), USE_JIT), device, USE_JIT)
#model = OuterWrapper(make_SequenceWrapper(DynNetwork(adap_arch), USE_JIT), device, USE_JIT)
#model = OuterWrapper(DynNetwork(ada_full), device, USE_JIT)
model = OuterWrapper(make_SequenceWrapper(DynNetwork(adap_arch4), USE_JIT), device, USE_JIT)

#with torch.no_grad():
#    model.model.layers.lstm.lstm.bias_hh_l0[:256] += 3



In [7]:
'''
with torch.no_grad():
    for i in range(100):
        model.pretrace.model.layers.adaptive_synapse.weight[i, i+81] = 0

    for i in range(120):
        model.pretrace.model.layers.regular_synapse.weight[i, i+181] = 0

'''

'\nwith torch.no_grad():\n    for i in range(100):\n        model.pretrace.model.layers.adaptive_synapse.weight[i, i+81] = 0\n\n    for i in range(120):\n        model.pretrace.model.layers.regular_synapse.weight[i, i+181] = 0\n\n'

In [8]:
trigger_signal = torch.ones([783+56, 1, 1], device=device)
trigger_signal[:783] = 0
def encode_input(curr, last):
    out = torch.zeros([783+56, curr.shape[1], 2,40], device=curr.device)
    out[:783, :, 0, :] = ((torch.arange(40, device=curr.device) < 40 * last) & (torch.arange(40, device=curr.device) > 40 * curr)).float()
    out[:783, :, 1, :] = ((torch.arange(40, device=curr.device) > 40 * last) & (torch.arange(40, device=curr.device) < 40 * curr)).float()
    out = torch.cat((out.view([783+56, curr.shape[1], 80]), trigger_signal.expand([783+56, curr.shape[1], 1])), dim=-1)
    return out

In [9]:
#params = list(mem_model.parameters())+list(post_model.parameters())
params = model.parameters()
ce = nn.CrossEntropyLoss()
optimizer = optim.Adam(params, lr=1e-3)#0.000011e-6
#optimizer = optim.SGD(params, lr=1e-5)
#optimizer = optim.Adam(mem_model.parameters(), lr=1e-3)#0.000011e-6

In [10]:
gradient_history = {}

for name, p in model.named_parameters():
    gradient_history[name] = {'iter': [], 'value':[], 'avg':1}

In [11]:
def manage_gradients(iter):
    do_gradient = True
    for name, p in model.named_parameters():
        v = p.grad.norm().item()
        if v > 100 * gradient_history[name]['avg']:
            #print(name, v)
            do_gradient = False
        else:
            gradient_history[name]['avg'] = 0.9 * gradient_history[name]['avg'] + 0.1 * v
    return do_gradient

In [12]:
def monitor_gradients(iter):
    for name, p in model.named_parameters():
        v = p.grad.norm().item()
        if iter == 0 or v > 2* gradient_history[name]['value'][-1] or v < 0.5 * gradient_history[name]['value'][-1]:
            print(name, v)
            gradient_history[name]['value'].append(v)
            gradient_history[name]['iter'].append(iter)
        if iter - gradient_history[name]['iter'][-1] >= 20:
            gradient_history[name]['value'].append(v)
            gradient_history[name]['iter'].append(iter)



In [13]:
start = time.time()
i = 0
sumloss = 0
sumacc = 0
for k in range(150):
    print('Epoch: ', k)
    for inp, target in data_loader:
        batchstart = time.time()
        x = inp.view(BATCH_SIZE, -1, 1).transpose(0,1).to(device)
        x = encode_input(x[1:], x[:-1])
        #print(x.shape)
        target = target.to(device)
        optimizer.zero_grad()
        outputs, _ = model(x)
        #mem, _ = mem_model(x)
        #outputs, _ = post_model(mem[-1].expand(56, BATCH_SIZE, 256))
        outputs = outputs[-56:]
        loss = ce(outputs.mean(dim=0), target)

        loss.backward()
        #monitor_gradients(i)
        if manage_gradients(i):
            optimizer.step()
            print('.', end='')
        else:
            print('|', end='')
        sumloss += loss.item()
        sumacc += (torch.argmax(outputs.mean(dim=0), 1) == target).float().mean().item()
        if i%20 == 0:
            print(loss.item(), sumloss/20, sumacc/20, time.time()-batchstart, outputs.var(1).mean().item()) #torch.argmax(outputs[-1], 1).float().var()
            sumloss = 0
            sumacc = 0
        i += 1
    model.save('../../models/adap_clip5_'+str(k))
    #post_model.save('../../models/post_big11_'+str(k))


print('Total time: ', time.time()-start)


Epoch:  0
.2.3569815158843994 0.11784907579421997 0.004296875 0.8987224102020264 0.02374180033802986
....................2.2997829914093018 2.311808669567108 0.095703125 1.0932927131652832 0.005496765952557325
....................2.2981131076812744 2.303507947921753 0.119921875 0.9066684246063232 0.0015271632000803947
....................2.2867183685302734 2.2997230529785155 0.123828125 0.9010732173919678 0.0006056968704797328
....................2.303955554962158 2.3040708661079408 0.104296875 0.904564380645752 0.0038315022829920053
....................2.283956289291382 2.282125735282898 0.12890625 0.8752143383026123 0.0007966957637108862
....................2.2900655269622803 2.2937514901161196 0.13046875 0.937474250793457 0.0008277239976450801
....................2.297236204147339 2.3033241271972655 0.116015625 0.861459493637085 0.002906516194343567
....................2.2945237159729004 2.3237852454185486 0.113671875 0.9113798141479492 0.0005186368362046778
....................2.28

KeyboardInterrupt: 

In [None]:
#model.save('../../models/seq_mnist_rsnn1')

In [None]:
confusion = torch.zeros([10,10])
i = 0
with torch.no_grad():
    acc = 0
    for inp, target in test_loader:
        x = inp.view(inp.shape[0], -1, 1).transpose(0,1).to(device)
        target = target.to(device)
        mem, _ = mem_model(x)
        outputs, _ = post_model(mem[-1].expand(56, x.shape[1], 256))
        choice = torch.argmax(outputs.mean(dim=0), 1)
        acc += (choice == target).float().mean()
        i += 1
        for k in range(len(target)):
            confusion[choice[k], target[k]] += 1
    print(acc/i)
print(confusion)

In [None]:
max = confusion.max().item()
from PIL import Image
img = Image.new('L',(10,10),color=128)
for i in range(10):
    for k in range(10):
        img.putpixel((i, k), int(confusion[i,k]/max*255))

In [None]:
img.resize((500, 500))

In [None]:
testi = MNIST('../../', train=False)

In [None]:
show = []
schoice = []
starget = []
for img, target in testi:
    x = transforms.ToTensor()(img).view(-1, 1, 1).to(device)
    mem, _ = mem_model(x)
    outputs, _ = post_model(mem[-1].expand(56, 1, 256))
    choice = torch.argmax(outputs.mean(dim=0), 1).item()
    if choice != target:
        show.append(img)
        schoice.append(choice)
        starget.append(target)
        if len(show) == 10:
            break


In [None]:
show[7].resize((500,500))


In [None]:
print(schoice)
print(starget)

In [None]:
mem_model.model.model.layers.shortterm_synapse.named_parameters()

In [None]:
for name, p in n_mem.named_parameters():
    print(name, p.shape)

In [None]:
for name, p in mem_model2.named_parameters():
    print(name, p.shape)

In [None]:
mem_model2 = torch.load('../../models/mem_nores3_76')
n_mem = make_SequenceWrapper(DynNetwork(mem_loop), USE_JIT)
with torch.no_grad():
    n_mem.model.layers.output_synapse.weight = mem_model2.model.layers.output_synapse.weight
    n_mem.model.layers.output_synapse.bias = mem_model2.model.layers.output_synapse.bias
    n_mem.model.layers.output.initial_mem = mem_model2.model.layers.output.initial_mem
    n_mem.model.layers.pre_mem_synapse.bias = mem_model2.model.layers.pre_mem_synapse.bias
    n_mem.model.layers.pre_mem.initial_mem = mem_model2.model.layers.pre_mem.initial_mem
    n_mem.model.layers.pre_mem_synapse.weight[:, :129] = mem_model2.model.layers.pre_mem_synapse.weight

In [None]:
data_loader.__iter__().__next__()[1]

In [None]:

28*28

In [None]:

inp, target = data_loader.__iter__().__next__()
x = inp.view(BATCH_SIZE, -1, 1).transpose(0,1).to(device)
x = encode_input(x[1:], x[:-1]).cpu()
inpimg = transforms.ToPILImage()(x[:,4,:])

In [None]:
inpimg

In [None]:
target[4]

In [None]:
for name, p in model.named_parameters():
    print(name, p.shape)



In [None]:
import pickle
trans, img = pickle.load(open('../some_img', 'rb'))


In [None]:
timg = torch.tensor(img).view(-1, 1, 1).to(device)

In [None]:
mytrans = torch.cat((torch.zeros((1, 81)), encode_input(timg[1:], timg[:-1]).squeeze().cpu()), dim=0)
their_trans = torch.cat((torch.zeros((840, 1)), torch.tensor(trans).squeeze()), dim=1)

In [None]:
mytrans.shape

In [None]:
their_trans.shape

In [None]:
from PIL import Image
pimg = Image.new('RGB',(840,81),color=128)

In [None]:
for i in range(81):
    for k in range(840):
        pimg.putpixel((k, i), (int(mytrans[k, i])*255, int(their_trans[k, i*2%81])*255, 0))


In [None]:
pimg.save('input_comparison', 'png')

In [None]:
pimg