# Spike - Recurrent Neural Netork

-- more descriptions will be added --

## Imports and CoLab checking
In this section first we check if this code is running in the Google CoLab or in a Local Jupiter Notebook.
Afterwad, the proper libraries will be imported.

In [1]:
import sys, os
IN_COLAB = 'google.colab' in sys.modules
print 'Colab:', IN_COLAB
IN_NOTEBOOK = IN_COLAB or get_ipython().__class__.__name__ == 'ZMQInteractiveShell'

Colab: False


In [2]:
if IN_COLAB:
    !pip install -U -q PyDrive
    from pydrive.auth import GoogleAuth
    from pydrive.drive import GoogleDrive
    from google.colab import auth
    from oauth2client.client import GoogleCredentials
    # Authenticate and create the PyDrive client.
    auth.authenticate_user()
    gauth = GoogleAuth()
    gauth.credentials = GoogleCredentials.get_application_default()
    drive = GoogleDrive(gauth)
    
    from google.colab import drive
    drive.mount('/content/drive')
    DATASET_FOLDER_PATH = "/content/drive/My Drive/ColabCodes/data/"
else:
    DATASET_FOLDER_PATH = "/Users/aref/dvs-dataset/"
    
if not os.path.exists(DATASET_FOLDER_PATH):
    raise Exception('can not access data folder.')

In [3]:
# numpy essentials imports
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from matplotlib.gridspec import GridSpec
from sklearn.model_selection import train_test_split

# py-torch required imports
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms

from torch.autograd import Variable

# misc imports
from time import sleep, time

## Reading Dataset
This section is for reading the cached dataset from the defined path.

In [4]:
# mapping
gesture_mapping = {
    0: 'no_gesture',
    1: 'hand_clapping',
    2: 'right_hand_wave',
    3: 'left_hand_wave',
    4: 'right_arm_clockwise',
    5: 'right_arm_counter_clockwise',
    6: 'left_arm_clockwise',
    7: 'left_arm_counter_clockwise',
    8: 'arm_roll',
    9: 'air_drums',
    10: 'air_guitar',
    11: 'other_gestures',
}


In [5]:
def data_loader(trail):
    dataset_path = DATASET_FOLDER_PATH + 'cleaned_cache_' + trail
    dataset_len = 98 if trail == 'train' else 24
    
    for counter in range(dataset_len):    
        x_data = np.load(file='%s/x_%s_%d.npy' % (dataset_path, trail, counter+1))
        y_data = np.load(file='%s/y_%s_%d.npy' % (dataset_path, trail, counter+1))
        z_data = np.load(file='%s/z_%s_%d.npy' % (dataset_path, trail, counter+1))
        
        yield x_data, y_data, z_data

In [6]:
def serialize_events(x_data, y_data, z_data):    
    scale_w = nb_image_weight / 128
    scale_h = nb_image_height / 128
    
    batch_x = []
    batch_y = []
    
    x_series_instance = []
    y_series_instance = []
    
    ev_times = z_data[:, 1]
    frames = 0
    max_time = np.max(ev_times)
    current_time = np.min(ev_times)
    while (current_time < max_time):
        ev_lb = ev_times > current_time
        ev_ub = ev_times < (current_time + retina_time_window_ms)

        event_x = x_data[ev_lb & ev_ub, :]
        event_y = y_data[ev_lb & ev_ub]

        retina = np.zeros([nb_image_height, nb_image_weight])
        retina[event_x[:,0] * scale_h, event_x[:,1] * scale_w] = 1

        frame_y = np.round(np.median(event_y))

        x_series_instance.append(retina.flatten())
        y_series_instance.append(frame_y)
        current_time += retina_time_window_ms
        
        frames += 1
        if len(x_series_instance) == nb_steps:
            batch_x.append(np.array(x_series_instance))
            batch_y.append(np.round(np.median(np.array(y_series_instance))))
            x_series_instance = []
            y_series_instance = []
            
            if len(batch_x) == batch_size:
                yield np.array(batch_x), np.array(batch_y)
                batch_x = []
                batch_y = []

In [7]:
%matplotlib inline

def visualize_dataset():
    ev_times = z_data[:, 1]
    frames = 0
    max_time = np.max(ev_times)
    current_time = np.min(ev_times)
    while (current_time < max_time):
        ev_lb = ev_times > current_time
        ev_ub = ev_times < (current_time + retina_time_window_ms)

        event_x = x_data[ev_lb & ev_ub, :]
        event_y = y_data[ev_lb & ev_ub]

        retina = np.zeros([128, 128])
        retina[event_x[:,1], event_x[:,0]] = 1

        frame_y = gesture_mapping[np.round(np.median(event_y))]

        plt.imshow(retina)
        plt.title(frame_y)
        plt.show()
        plt.pause(0.2)

        frames += 1
        if frames > 15:
            break

        current_time += retina_time_window_ms


# visualize_dataset()

## Network Architechture & Tools

In [8]:
dtype = torch.float

# Check whether a GPU is available
if torch.cuda.is_available():
    print('using cuda...')
    device = torch.device("cuda")     
else:
    print('using cpu...')
    device = torch.device("cpu")

using cpu...


In [27]:
nb_image_height = 64
nb_image_weight = 64
nb_inputs  = nb_image_height*nb_image_weight
nb_hidden  = 128
nb_outputs = 12

time_step = 1e-3
nb_steps  = 100

batch_size = 16

retina_time_window_ms = 45 * 1000
USE_RECURRENT_NEURONS = True

In [10]:
# tau_mem = 10e-3
# tau_syn = 5e-3

# alpha   = float(np.exp(-time_step/tau_syn))
# beta    = float(np.exp(-time_step/tau_mem))

In [10]:
class SpikingNeuronLayerRNN(nn.Module):
    def __init__(self, device, n_inputs, n_hidden, decay_multiplier=0.9, threshold=2.0, penalty_threshold=2.5):
        super(SpikingNeuronLayerRNN, self).__init__()
        self.device = device
        self.n_inputs = n_inputs
        self.n_hidden = n_hidden
        self.decay_multiplier = decay_multiplier
        self.threshold = threshold
        self.penalty_threshold = penalty_threshold
        
        self.fc = nn.Linear(n_inputs, n_hidden)
        
        self.init_parameters()
        self.reset_state()
        self.to(self.device)
        
    def init_parameters(self):
        for param in self.parameters():
            if param.dim() >= 2:
                nn.init.xavier_uniform_(param)
        
    def reset_state(self):
        self.prev_inner = torch.zeros([self.n_hidden]).to(self.device)
        self.prev_outer = torch.zeros([self.n_hidden]).to(self.device)

    def forward(self, x):
        """
        Call the neuron at every time step.
        
        x: activated_neurons_below
        
        return: a tuple of (state, output) for each time step. Each item in the tuple
        are then themselves of shape (batch_size, n_hidden) and are PyTorch objects, such 
        that the whole returned would be of shape (2, batch_size, n_hidden) if casted.
        """
        if self.prev_inner.dim() == 1:
            # Adding batch_size dimension directly after doing a `self.reset_state()`:
            batch_size = x.shape[0]
            self.prev_inner = torch.stack(batch_size * [self.prev_inner])
            self.prev_outer = torch.stack(batch_size * [self.prev_outer])
        
        # 1. Weight matrix multiplies the input x
        input_excitation = self.fc(x)
        
        # 2. We add the result to a decayed version of the information we already had.
        inner_excitation = input_excitation + self.prev_inner * self.decay_multiplier
        
        # 3. We compute the activation of the neuron to find its output value, 
        #    but before the activation, there is also a negative bias that refrain thing from firing too much.
        outer_excitation = F.relu(inner_excitation - self.threshold)
        
        # 4. If the neuron fires, the activation of the neuron is subtracted to its inner state 
        #    (and with an extra penalty for increase refractory time), 
        #    because it discharges naturally so it shouldn't fire twice. 
        do_penalize_gate = (outer_excitation > 0).float()
        # TODO: remove following /2?
        inner_excitation = inner_excitation - (self.penalty_threshold/self.threshold * inner_excitation) * do_penalize_gate
        
        # 5. The outer excitation has a negative part after the positive part. 
        outer_excitation = outer_excitation #+ torch.abs(self.prev_outer) * self.decay_multiplier / 2.0
        
        # 6. Setting internal values before returning. 
        #    And the returning value is the one of the previous time step to delay 
        #    activation of 1 time step of "processing" time. For logits, we don't take activation.
        delayed_return_state = self.prev_inner
        delayed_return_output = self.prev_outer
        self.prev_inner = inner_excitation
        self.prev_outer = outer_excitation
        return delayed_return_state, delayed_return_output


class InputDataToSpikingPerceptronLayer(nn.Module):
    
    def __init__(self, device):
        super(InputDataToSpikingPerceptronLayer, self).__init__()
        self.device = device
        
        self.reset_state()
        self.to(self.device)
        
    def reset_state(self):
        #     self.prev_state = torch.zeros([self.n_hidden]).to(self.device)
        pass
    
    def forward(self, x, is_2D=True):
        return x
#         x = x.view(x.size(0), -1)  # Flatten 2D image to 1D for FC
#         random_activation_perceptron = torch.rand(x.shape).to(self.device)
#         return random_activation_perceptron * x


class OutputDataToSpikingPerceptronLayer(nn.Module):
    
    def __init__(self, average_output=True):
        """
        average_output: might be needed if this is used within a regular neural net as a layer.
        Otherwise, sum may be numerically more stable for gradients with setting average_output=False.
        """
        super(OutputDataToSpikingPerceptronLayer, self).__init__()
        if average_output:
            self.reducer = lambda x, dim: x.sum(dim=dim)
        else:
            self.reducer = lambda x, dim: x.mean(dim=dim)
    
    def forward(self, x):
        if type(x) == list:
            x = torch.stack(x)
        return self.reducer(x, 0)


class SpikingNet(nn.Module):
    
    def __init__(self, device):
        super(SpikingNet, self).__init__()
        self.device = device
        self.n_time_steps = nb_steps
        
        self.input_conversion = InputDataToSpikingPerceptronLayer(device)
        
        
        self.layer1 = SpikingNeuronLayerRNN(
            device, n_inputs=nb_inputs, n_hidden=nb_hidden,
            decay_multiplier=0.9, threshold=1.0, penalty_threshold=1.5
        )
        
        self.layer2 = SpikingNeuronLayerRNN(
            device, n_inputs=nb_hidden, n_hidden=nb_outputs,
            decay_multiplier=0.9, threshold=1.0, penalty_threshold=1.5
        )
        
        self.output_conversion = OutputDataToSpikingPerceptronLayer(average_output=False)  # Sum on outputs.
        
        self.to(self.device)
    
    def forward_through_time(self, x):
        """
        This acts as a layer. Its input is non-time-related, and its output too.
        So the time iterations happens inside, and the returned layer is thus
        passed through global average pooling on the time axis before the return 
        such as to be able to mix this pipeline with regular backprop layers such
        as the input data and the output data.
        """
        self.input_conversion.reset_state()
        self.layer1.reset_state()
        self.layer2.reset_state()

        out = []
        
        all_layer1_states = []
        all_layer1_outputs = []
        all_layer2_states = []
        all_layer2_outputs = []
        for counter in range(self.n_time_steps):
            xi = self.input_conversion(x[counter, :])
            
            # For layer 1, we take the regular output.
            layer1_state, layer1_output = self.layer1(xi)
            
            # We take inner state of layer 2 because it's pre-activation and thus acts as out logits.
            layer2_state, layer2_output = self.layer2(layer1_output)
            
            all_layer1_states.append(layer1_state)
            all_layer1_outputs.append(layer1_output)
            all_layer2_states.append(layer2_state)
            all_layer2_outputs.append(layer2_output)
            out.append(layer2_state)
            
        out = self.output_conversion(out)
        return out, [[all_layer1_states, all_layer1_outputs], [all_layer2_states, all_layer2_outputs]]
    
    def forward(self, x):
        out, _ = self.forward_through_time(x)
        return F.log_softmax(out, dim=-1)

    def visualize_all_neurons(self, x):
        assert x.shape[0] == 1 and len(x.shape) == 4, (
            "Pass only 1 example to SpikingNet.visualize(x) with outer dimension shape of 1.")
        _, layers_state = self.forward_through_time(x)

        for i, (all_layer_states, all_layer_outputs) in enumerate(layers_state):
            layer_state  =  torch.stack(all_layer_states).data.cpu().numpy().squeeze().transpose()
            layer_output = torch.stack(all_layer_outputs).data.cpu().numpy().squeeze().transpose()
            
            self.plot_layer(layer_state, title="Inner state values of neurons for layer {}".format(i))
            self.plot_layer(layer_output, title="Output spikes (activation) values of neurons for layer {}".format(i))
    
    def visualize_neuron(self, x, layer_idx, neuron_idx):
        assert x.shape[0] == 1 and len(x.shape) == 4, (
            "Pass only 1 example to SpikingNet.visualize(x) with outer dimension shape of 1.")
        _, layers_state = self.forward_through_time(x)

        all_layer_states, all_layer_outputs = layers_state[layer_idx]
        layer_state  =  torch.stack(all_layer_states).data.cpu().numpy().squeeze().transpose()
        layer_output = torch.stack(all_layer_outputs).data.cpu().numpy().squeeze().transpose()

        self.plot_neuron(layer_state[neuron_idx], title="Inner state values neuron {} of layer {}".format(neuron_idx, layer_idx))
        self.plot_neuron(layer_output[neuron_idx], title="Output spikes (activation) values of neuron {} of layer {}".format(neuron_idx, layer_idx))

    def plot_layer(self, layer_values, title):
        """
        This function is derived from: 
            https://github.com/guillaume-chevalier/LSTM-Human-Activity-Recognition 
        Which was released under the MIT License. 
        """
        width = max(16, layer_values.shape[0] / 8)
        height = max(4, layer_values.shape[1] / 8)
        plt.figure(figsize=(width, height))
        plt.imshow(
            layer_values,
            interpolation="nearest",
            cmap=plt.cm.rainbow
        )
        plt.title(title)
        plt.colorbar()
        plt.xlabel("Time")
        plt.ylabel("Neurons of layer")
        plt.show()

    def plot_neuron(self, neuron_through_time, title):
        width = max(16, len(neuron_through_time) / 8)
        height = 4
        plt.figure(figsize=(width, height))
        plt.title(title)
        plt.plot(neuron_through_time)
        plt.xlabel("Time")
        plt.ylabel("Neuron's activation")
        plt.show()

In [12]:
# def compute_classification_accuracy(x_data, y_data):
#     """ Computes classification accuracy on supplied data in batches. """
#     accs = []
#     max_count = len(x_data) / batch_size
#     for counter in range(max_count):
#         i = counter*batch_size
#         x_local = torch.from_numpy(x_data[i:i+batch_size, :, :]).type(dtype)
#         y_local = torch.from_numpy(y_data[i:i+batch_size])
#         output, _ = run_snn(x_local)
#         m, _ = torch.max(output,1) # max over time
#         _, am = torch.max(m,1)      # argmax over output units
#         tmp = np.mean((y_local.type(torch.long)==am).detach().cpu().numpy()) # compare to labels
#         accs.append(tmp)
#     return np.mean(accs)


In [13]:
# def print_accuracy_values():
#     errors = []
#     for x_data, y_data, z_data in data_loader('train'):
#         errors.append(compute_classification_accuracy(x_data, y_data))
#     print("Train accuracy: %.3f" % (np.mean(np.array(errors))))

#     errors = []
#     for x_data, y_data, z_data in data_loader('test'):
#         errors.append(compute_classification_accuracy(x_data, y_data))
#     print("Test accuracy: %.3f" % (np.mean(np.array(errors))))

In [11]:
from time import time
from datetime import datetime

_last_time_length = None
_iteration_left = None

def expector_timer(func):
    def wrapper(*args, **kwargs):
        global _last_time_length
        t_start = time()
        
        if _last_time_length is not None:
            expectation = t_start
            if _iteration_left is not None:
                expectation +=  _last_time_length * _iteration_left
            else:
                expectation +=  _last_time_length
            datestr = datetime.fromtimestamp(expectation).strftime("%H:%M:%S")
            print '[expecting to finish at %s]' % datestr
        res = func(*args, **kwargs)
        length = time() - t_start
        print '[operation took %ds]' % length
        print '[operation finished at %s]' % datetime.fromtimestamp(time()).strftime("%H:%M:%S")
        
        if _last_time_length is None:
            _last_time_length = length
        else:
            _last_time_length = .9 * _last_time_length + .1 * length
        return res
    return wrapper

In [15]:
# %matplotlib inline
# @expector_timer
# def train_once():
#     global loss_hist
#     loss_hist = train(lr=2e-4, nb_epochs=5, loss_hist=loss_hist)

#     plt.figure(figsize=(3.3,2),dpi=150)
#     plt.plot(loss_hist)
#     plt.xlabel("Epoch")
#     plt.ylabel("Loss")
#     plt.show()
#     sns.despine()
#     print_accuracy_values()
    
#     np.save(file="_last_result", arr=[w1, w2])
#     print('[check point saved.]')

In [28]:
@expector_timer
def train(model, device, x_local, y_local, optimizer, epoch, logging_interval=100):
    # This method is derived from: 
    # https://github.com/pytorch/examples/blob/master/mnist/main.py
    # Was licensed BSD-3-clause
    
    print '>>>>>>>>>>>>>>', x_local.shape
    print '>>>>>>>>>>>>>>', y_local.shape
    model.train()
    for d, t in zip(x_local, y_local):
        d, t = d.to(device), t.to(device)
        optimizer.zero_grad()
        output = model(d)
        print '%%%%%%%%%', output.shape
        target = torch.ones_like(output) * t
        print output.shape, '~>', target.shape
        loss = F.nll_loss(output, target.type(torch.long))
        loss.backward()
        optimizer.step()
    
#     model.train()
#     for batch_idx, (data, target) in enumerate(zip(x_local, y_local)):
#         print '~~~~~~~~~~~~~~~~~~~~>', target
#         data, target = data.to(device), target.to(device)
#         optimizer.zero_grad()
#         output = model(data)
#         print output.shape, '~>', target.shape
#         loss = F.nll_loss(output, target)
#         loss.backward()
#         optimizer.step()
        
#         if batch_idx % logging_interval == 0:
#             pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
#             correct = pred.eq(target.view_as(pred)).float().mean().item()
#             print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.2f}%'.format(
#                 epoch, batch_idx * len(data), len([1]),
#                 100. * batch_idx / x_local.size(0), loss.item(),
#                 100. * correct))

def test(model, device, x_local, y_local):
    # This method is derived from: 
    # https://github.com/pytorch/examples/blob/master/mnist/main.py
    # Was licensed BSD-3-clause
    
    model.eval()
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in zip(x_local, y_local):
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Note: with `reduce=True`, I'm not sure what would happen with a final batch size 
            # that would be smaller than regular previous batch sizes. For now it works.
            test_loss += F.nll_loss(output, target, reduce=True).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_set_loader.dataset)
    print("")
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        test_loss, 
        correct, len([1]),
        100. * correct / len([1])))
    print("")
    
    
    
    
spiking_model = SpikingNet(device)
train_many_epochs(spiking_model, device)

. >>>>>>>>>>>>>> torch.Size([2, 10, 4096])
>>>>>>>>>>>>>> torch.Size([2])
%%%%%%%%% torch.Size([4096, 12])
torch.Size([4096, 12]) ~> torch.Size([4096, 12])
**** 

TypeError: max() missing 1 required positional arguments: "dim"

## Spike Simulation
Here is the code for simulating spikes as well as codes for training them.

In [24]:
# if os.path.exists('_last_result') and ('w1' not in globals() or 'w2' not in globals()):
#     w1, w2 = np.load(file="_last_result")
#     print 'continueing last progress'
# else:
#     weight_scale = 7*(1.0-beta) # this should give us some spikes to begin with

#     w1 = torch.empty(((nb_inputs+1) if USE_RECURRENT_NEURONS else nb_inputs, nb_hidden),  device=device, dtype=dtype, requires_grad=True)
#     torch.nn.init.normal_(w1, mean=0.0, std=weight_scale/np.sqrt(nb_inputs))

#     w2 = torch.empty((nb_hidden, nb_outputs), device=device, dtype=dtype, requires_grad=True)
#     torch.nn.init.normal_(w2, mean=0.0, std=weight_scale/np.sqrt(nb_hidden))

#     loss_hist = []
#     print('init')

init


In [25]:
# def run_snn(inputs):    
#     if USE_RECURRENT_NEURONS:
#         h1 = torch.einsum("abc,cd->abd", (inputs, w1[:-1]))
#     else:
#         h1 = torch.einsum("abc,cd->abd", (inputs, w1))
#     hr = torch.zeros((batch_size,nb_hidden), device=device, dtype=dtype)
#     syn = torch.zeros((batch_size,nb_hidden), device=device, dtype=dtype)
#     mem = torch.zeros((batch_size,nb_hidden), device=device, dtype=dtype)

#     mem_rec = [mem]
#     spk_rec = [mem]

#     # Compute hidden layer activity
#     for t in range(nb_steps):
#         mthr = mem-1.0
#         out = spike_fn(mthr)
#         rst = torch.zeros_like(mem)
#         c   = (mthr > 0)
#         rst[c] = torch.ones_like(mem)[c]
       
#         if USE_RECURRENT_NEURONS:
#             rs =  torch.ones_like(mem)
#             for i in range(batch_size):
#                 rs[i,:] = w1[-1, :]
#             rs[mthr <= 0] = 0.0
#         else:
#             rs = 0.0
        
#         new_syn = alpha*syn +h1[:,t] +rs
#         new_mem = beta*mem +syn -rst

#         mem = new_mem
#         syn = new_syn

#         mem_rec.append(mem)
#         spk_rec.append(out)

#     mem_rec = torch.stack(mem_rec,dim=1)
#     spk_rec = torch.stack(spk_rec,dim=1)

#     # Readout layer
#     h2= torch.einsum("abc,cd->abd", (spk_rec, w2))
#     flt = torch.zeros((batch_size,nb_outputs), device=device, dtype=dtype)
#     out = torch.zeros((batch_size,nb_outputs), device=device, dtype=dtype)
#     out_rec = [out]
#     for t in range(nb_steps):
# #         rs = torch.ones_like(out)
# #         for i in range(batch_size):
# #             rs[i,:] = w2[-1, :]
# #         rs[out <= 0] = 0
        
#         new_flt = alpha*flt +h2[:,t] #+rs
#         new_out = beta*out +flt

#         flt = new_flt
#         out = new_out

#         out_rec.append(out)

#     out_rec = torch.stack(out_rec,dim=1)
#     other_recs = [mem_rec, spk_rec]
#     return out_rec, other_recs

In [26]:
# def train(lr=2e-3, nb_epochs=10, loss_hist=None):
#     params = [w1,w2]
#     optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9,0.999))

#     log_softmax_fn = nn.LogSoftmax(dim=1)
#     loss_fn = nn.NLLLoss()
    
#     if loss_hist is None:
#         loss_hist = []
#     for e in range(nb_epochs):
#         print_counter = 0
#         local_loss = []
#         for data_packet in data_loader('train'):
#             if print_counter % 4 == 0:
#                 print '.',
#             print_counter += 1
            
#             for x_data, y_data in serialize_events(*data_packet):
#                 x_local = torch.from_numpy(x_data).type(dtype)
#                 y_local = torch.from_numpy(y_data.astype(np.long))
                                
#                 output,_ = run_snn(x_local)
#                 m,_=torch.max(output,1)
#                 log_p_y = log_softmax_fn(m)
#                 loss_val = loss_fn(log_p_y, y_local)

#                 optimizer.zero_grad()
#                 loss_val.backward()
#                 optimizer.step()
#                 local_loss.append(loss_val.item())
#         mean_loss = np.mean(local_loss)
#         print("Epoch %i: loss=%.5f"%(e+1,mean_loss))
#         loss_hist.append(mean_loss)
        
#     return loss_hist


In [None]:
# max_epoch_counter = 1
# for i in range(max_epoch_counter):
#     _iteration_left = max_epoch_counter - i
#     train_once()

. . . . . . . .

In [14]:
def train_many_epochs(model, device): 
    def _train(epoch, lr):
        optimizer = optim.SGD(model.parameters(), lr, momentum=0.5)
        print_counter = 0
        for data_packet in data_loader('train'):
            if print_counter % 4 == 0:
                print '.',
            print_counter += 1
            
            for x_data, y_data in serialize_events(*data_packet):
                x_local = torch.from_numpy(x_data).type(dtype).to(device)
                y_local = torch.from_numpy(y_data.astype(np.long)).to(device)

                if x_local.shape[0] != batch_size:
                    continue
            
                train(model, device, x_local, y_local, optimizer, epoch, logging_interval=10)
#         test(model, device, test_set_loader)

    _iteration_left = 6
    _train(1, 0.1)
    _iteration_left = 5
    _train(2, 0.05)
    _iteration_left = 3
    _train(3, 0.01)
    _iteration_left = 0

In [21]:
# TODO check 'IN_NOTEBOOK' to either plot the output or not

spiking_model = SpikingNet(device)
train_many_epochs(spiking_model, device)

. >>>>>>>>>>>>>> torch.Size([16, 100, 4096])
>>>>>>>>>>>>>> torch.Size([16])
%%%%%%%%% torch.Size([4096, 12])


RuntimeError: _thnn_nll_loss_forward not supported on CPUType for Long