# Understanding Training

In [68]:
import sys
sys.path.append('../')
import os
import torch
from network_parser import parse
from datasets import loadMNIST, loadCIFAR10, loadFashionMNIST, loadNMNIST_Spiking 
import logging
import cnns
# from utils import learningStats
# from utils import aboutCudaDevices
# from utils import EarlyStopping
import functions.loss_f as loss_f
import numpy as np
from datetime import datetime
import time
# import pycuda.driver as cuda
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils import clip_grad_value_
import global_v as glv

# from sklearn.metrics import confusion_matrix
import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt
import argparse

# Anil adds
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as P
import yaml


max_accuracy = 0
min_loss = 1000

In [2]:
File = '../Networks/MNIST_CNN.yaml'
with open(File) as file:
    params = yaml.full_load(file)
    
params['Network']['data_path'] = '../' + params['Network']['data_path'] # add relative dir path

In [3]:
dtype = torch.float32

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("selected device: ", device)

glv.init(dtype, device, params['Network']['n_steps'], params['Network']['tau_s'] )
data_path = os.path.expanduser(params['Network']['data_path'])
train_loader, test_loader = loadMNIST.get_mnist(data_path, params['Network'])

net = cnns.Network(params['Network'], params['Layers'], list(train_loader.dataset[0][0].shape)).to(device)
error = loss_f.SpikeLoss(params['Network']).to(device)
optimizer = torch.optim.AdamW(net.get_parameters(), lr=params['Network']['lr'], betas=(0.9, 0.999))

best_acc = 0; best_epoch = 0

l_states = None
early_stopping =None

# for e in range(params['Network']['epochs']):
#     train(net, train_loader, optimizer, e, l_states, params['Network'], params['Layers'], error)
#     test(net, test_loader, e, l_states, params['Network'], params['Layers'], early_stopping)

selected device:  cuda
loading MNIST
Network Structure:
conv_1
[1, 28, 28]
[15, 24, 24]
[15, 1, 5, 5, 1]
-----------------------------------------
pooling_1
[15, 24, 24]
[15, 12, 12]
[1, 1, 2, 2, 1]
-----------------------------------------
conv_2
[15, 12, 12]
[40, 8, 8]
[40, 15, 5, 5, 1]
-----------------------------------------
pooling_2
[40, 8, 8]
[40, 4, 4]
[1, 1, 2, 2, 1]
-----------------------------------------
linear
FC_1
[40, 4, 4]
[300, 1, 1]
[300, 640]
-----------------------------------------
linear
output
[300, 1, 1]
[10, 1, 1]
[10, 300]
-----------------------------------------
-----------------------------------------




# Experiment without Training

In [4]:
n_steps = params['Network']['n_steps']
n_class = params['Network']['n_class']

In [5]:
net

Network(
  (my_parameters): ParameterList(
      (0): Parameter containing: [torch.cuda.FloatTensor of size 15x1x5x5x1 (GPU 0)]
      (1): Parameter containing: [torch.cuda.FloatTensor of size 40x15x5x5x1 (GPU 0)]
      (2): Parameter containing: [torch.cuda.FloatTensor of size 300x640 (GPU 0)]
      (3): Parameter containing: [torch.cuda.FloatTensor of size 10x300 (GPU 0)]
  )
)

In [6]:
x, label = next(iter(train_loader))
x.shape, label.shape
targets = torch.zeros((label.shape[0], n_class, 1, 1, n_steps), dtype=dtype).to(device) 
if len(x.shape) < 5:
    x = x.unsqueeze_(-1).repeat(1, 1, 1, 1, n_steps)
x.shape, targets.shape

(torch.Size([50, 1, 28, 28, 5]), torch.Size([50, 10, 1, 1, 5]))

In [7]:
y = net(x.to(device).type(dtype),0,False)

In [8]:
y.reshape(50,10,5)[0]

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], device='cuda:0', grad_fn=<SelectBackward>)

## Desired Spikes

In [9]:
if n_steps >= 10:
    desired_spikes = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1]).repeat(int(n_steps/10))
else:
    desired_spikes = torch.tensor([0, 1, 1, 1, 1]).repeat(int(n_steps/5))
desired_spikes = desired_spikes.view(1, 1, 1, 1, n_steps).to(device)
print('>> before',desired_spikes)
desired_spikes = loss_f.psp(desired_spikes, params['Network']).view(1, 1, 1, n_steps)
print('>> after',desired_spikes)

>> before tensor([[[[[0, 1, 1, 1, 1]]]]], device='cuda:0')
>> after tensor([[[[0.0000, 0.3333, 0.5556, 0.7037, 0.8025]]]], device='cuda:0')


# Training

In [10]:
# global max_accuracy
# global min_loss
# Datasets
data_path = os.path.expanduser(params['Network']['data_path'])
trainloader, testloader = loadMNIST.get_mnist(data_path, params['Network'])
# Network Config
network_config = params['Network']
n_steps        = network_config['n_steps']
n_class        = network_config['n_class']
batch_size     = network_config['batch_size']
# Training Functions
err  = loss_f.SpikeLoss(network_config).to(device)
opti = torch.optim.AdamW(net.get_parameters(), lr=network_config['lr'], betas=(0.9, 0.999))

loading MNIST


## Train for one epoch

In [73]:
train_loss = 0
correct = 0
total = 0
epoch = 0

In [75]:
for batch_idx, (inputs, labels) in enumerate(trainloader):
    start_time = time.time()
    targets = torch.zeros((labels.shape[0], n_class, 1, 1, n_steps), dtype=dtype).to(device) 
    
    # begin offline
    # this is the case for each item in the for loop, can be done offline, functionally, to speed up!
    if len(inputs.shape) < 5: 
        inputs = inputs.unsqueeze_(-1).repeat(1, 1, 1, 1, n_steps)
    labels = labels.to(device)
    inputs = inputs.to(device)
    inputs.type(dtype)
    # end offline 
    
    outputs = net.forward(inputs, epoch, True)
    
    # begin function
    # seems systematic enough, can be done in a function
    if network_config['loss'] == "count":
        # set target signal
        desired_count = network_config['desired_count']
        undesired_count = network_config['undesired_count']
        targets = torch.ones((outputs.shape[0], outputs.shape[1], 1, 1), dtype=dtype).to(device) * undesired_count
        for i in range(len(labels)):
            targets[i, labels[i], ...] = desired_count
        loss = err.spike_count(outputs, targets, network_config, layers_config[list(layers_config.keys())[-1]])
    elif network_config['loss'] == "kernel":
        targets.zero_()
        for i in range(len(labels)):
            targets[i, labels[i], ...] = desired_spikes
        loss = err.spike_kernel(outputs, targets, network_config)
    elif network_config['loss'] == "softmax":
        # set target signal
        loss = err.spike_soft_max(outputs, labels)
    else:
        raise Exception('Unrecognized loss function.')
    # end function

    opti.zero_grad()
    loss.backward()
    clip_grad_norm_(net.get_parameters(), 1) # what's this for?
    opti.step()
    net.weight_clipper()
    
    # begin argmax 
    # PyTorch has argmax function, re-write this and clean up squeezes!
    spike_counts = torch.sum(outputs, dim=4).squeeze_(-1).squeeze_(-1).detach().cpu().numpy()
    predicted = np.argmax(spike_counts, axis=1)
    # end argmax
    
    # It is not efficient to return to item at each epoch, do we need that?
    # Don't think when using W&B at least.
    train_loss += torch.sum(loss).item()
    labels = labels.cpu().numpy()
    total += len(labels)
    correct += (predicted == labels).sum().item()

total_accuracy = correct / total
total_loss = train_loss / total
end_time = time.time() - start_time
print('>> result for one epoch: {:.3}, time it takes {:.2}s'.format(total_accuracy, end_time))

>> result for one epoch: 0.99, time it takes 0.059s


### IMPRESSIVE RESULTS
> result for one epoch: 0.987, time it takes 0.071s