In [None]:
from __future__ import print_function
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os
import time
import sys
import glob
import math
import random
from IPython.display import clear_output

sys.path.insert(0, './SpikingNN')

from spiking_model import*

import train as yolo_model

# Run on the GPU when CUDA is available, otherwise on the CPU.
# To force CPU execution, replace the branch below with: device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Running on", device)

In [None]:
# Randomize data point indices for the data set batches
def random_batch(seed, n, offset, train_split, validation_split, test_split):
    """Shuffle the file indices ``offset .. offset + n - 1`` and cut them into three splits.

    Note that this reseeds the *global* ``random`` module with ``seed`` before
    shuffling, so later uses of ``random`` are affected as well.

    Args:
        seed: seed passed to ``random.seed``.
        n: number of indices to shuffle.
        offset: first index value.
        train_split, validation_split, test_split: ``range`` objects in absolute
            index space (starting at ``offset``). They are shifted by ``-offset`` to
            slice the shuffled list.

    Returns:
        ``(train, valid, test)`` numpy arrays of indices (an empty split comes back
        as an empty float64 array).

    Raises:
        Exception: ``("Malformed data batches", len(train), len(valid), len(test))``
            when the three splits together do not hold exactly ``n`` distinct indices.
    """
    shuffled = list(range(offset, offset + n))
    random.seed(seed)
    random.shuffle(shuffled)

    def _take(split):
        # Slice the shuffled list using the split's position relative to `offset`
        return np.array(shuffled[split.start - offset : split.stop - offset])

    train = _take(train_split)
    valid = _take(validation_split)
    test = _take(test_split)

    # Sanity check: every index must land in exactly one split
    distinct = np.unique(np.concatenate((train, valid, test)))
    if len(distinct) != n:
        raise Exception("Malformed data batches", len(train), len(valid), len(test))
    return train, valid, test

# 2x nearest-neighbour upsampling module.
# NOTE(review): `up` is not referenced in any of the visible cells; it may be a
# leftover from an earlier experiment — confirm before removing.
up = torch.nn.Upsample(scale_factor=2, mode='nearest')

def inference(inference_type, conf_thres=0.1, iou_thres=0.5):
    """Run the SNN + YOLO pipeline over a held-out split and report its mAP.

    Uses the notebook globals ``valid_set``/``test_set``, ``snn``, ``yolo`` and ``whwh``.

    Args:
        inference_type: ``"validation"`` selects ``valid_set``; any other value
            selects ``test_set``. Also used as the label in the printed log.
        conf_thres: confidence threshold forwarded to ``yolo.test``.
        iou_thres: IoU threshold forwarded to ``yolo.test``.

    Returns:
        The mAP score returned by ``yolo.test_end()``.
    """
    with torch.no_grad():  # evaluation only: no autograd graph is needed
        split = valid_set if inference_type == "validation" else test_set
        print("Running", inference_type)
        yolo.test_init()

        start_time = time.time()
        for i in split:
            data, targets = load_data(i)
            # Same skip rule as the training loop: no boxes, or the known-bad samples 118 / 202
            if len(targets) == 0 or i == 118 or i == 202:
                print("\tSkipping", i)
                continue

            snn.reset_potentials()

            # Feed the sample one slice along dim 0 at a time; the /5 input scaling
            # matches the training loop
            num_frames = data.size()[0]
            for j in range(num_frames):
                snn.feed(data[j] / 5)

            # Average the collected SNN output over the number of slices fed.
            # NOTE(review): the intent is to normalise outputs to 0-1, which presumes
            # collect() accumulates at most one unit per slice — confirm in spiking_model.
            intermediate = snn.collect() / num_frames

            corner_to_center(targets)
            yolo.test(intermediate, targets, i, data[-1], whwh, conf_thres=conf_thres, iou_thres=iou_thres)

            print("\t[%.0f] Real time: %.1fs" %
                  (i, time.time() - start_time))

        map_score, losses = yolo.test_end()
        print(inference_type.capitalize(), "Time: %.0fs | mAP: %.4f\n" %
              (time.time()-start_time, map_score), "| Loss: ", losses, "\n")
        return map_score
                
def corner_to_center(targets):
    """Shift each box anchor from its top-left corner to its centre, in place.

    Each row of ``targets`` is read as ``[.., .., x, y, w, h, ...]``. Column 2 is
    moved by half of column 4 and column 3 by half of column 5. Returns ``None``.
    """
    for box in targets:
        box[2] += box[4] / 2
        box[3] += box[5] / 2
        
def load_data(i):
    """Load sample ``i`` as ``(data, targets)``.

    ``data`` is read from ``<video_path><i>.pt`` and ``targets`` from
    ``<bb_path><i>.pt`` (both paths are notebook globals), in that order.
    """
    frames = torch.load(video_path + str(i) + ".pt")
    boxes = torch.load(bb_path + str(i) + ".pt")
    return frames, boxes
    
def save_snn(snn, save_path):
    """Save ``snn.state_dict()`` to ``save_path`` and log the destination.

    Args:
        snn: module whose state dict is written.
        save_path: destination file, e.g. ``snn_save_path + "best.t7"``.
    """
    # Write to the path the caller asked for. Previously this used the global
    # `snn_save_path` prefix, so "last" and "best" checkpoints overwrote each other.
    torch.save(snn.state_dict(), save_path)
    print('Saved SNN', save_path)

In [None]:
# Dataset variables
video_path = "N-Caltech101/tensor_data/"
bb_path = "N-Caltech101/tensor_annotations/"
names = ["Airplane", "Motorbike"]  # Class names
num_epochs = 100
num_files = 1590  # Sample files are numbered file_offset .. file_offset + num_files - 1
file_offset = 5
train_ratio = 0.8
validation_ratio = 0.2
test_ratio = 0
validate = True
test = False
width = 240
height = 176

# YOLO variables
yolo_load_file = './weights/spiking_best.pt'
yolo_save_path = './weights/spiking_'
yolo_cfg = "../cfg/spiking-yolo.cfg"  # YOLO config file
whwh = torch.Tensor([398,269,345,223]).to(device)  # Output scaling for YOLO (approximated)

# SNN variables
snn_load_file = "./weights/spiking_best.t7"
snn_save_path = "./weights/spiking_"
snn_lr = 1e-5
accum_grad = 25  # Number of samples whose gradients are accumulated before each optimizer step
show_grad = False
clip_thresh = math.inf  # Max total gradient norm for clip_grad_norm_; math.inf disables clipping
use_scheduler = True
sched_min = 1e-7  # Minimum lr multiplier for the scheduler
init_gain = 2.25  # Random weight scaling for SNN
decay = 0.75      # This ratio of membrane potential is removed each iteration

load_snn = True
load_yolo = True
save = False

split_seed = 78943522  # Fixed seed so the train/validation/test split is reproducible

# Consecutive, non-overlapping index ranges; random_batch maps them onto a shuffled index list
train_split = range(file_offset, int(file_offset + train_ratio * num_files))
validation_split = range(train_split.stop, int(train_split.stop + validation_ratio * num_files))
test_split = range(validation_split.stop, int(validation_split.stop + test_ratio * num_files))
train_set, valid_set, test_set = random_batch(split_seed, num_files, file_offset, train_split, validation_split, test_split)

# Init SNN
snn = SCNN(device, decay=decay, init_gain=init_gain, input_size=(width, height, 2))
if load_snn:
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine
    snn.load_state_dict(torch.load(snn_load_file, map_location=device))
    print("Loaded SNN", snn_load_file)

print(snn)
snn.to(device)
optimizer = torch.optim.Adam(snn.parameters(), lr=snn_lr)

if use_scheduler:
    # Cosine decay (https://arxiv.org/pdf/1812.01187.pdf): the multiplier applied to the lr
    # goes from 1 at epoch 0 down to sched_min at epoch num_epochs
    lf = lambda x: (((1 + math.cos(x * math.pi / num_epochs)) / 2) ** 1.0) * (1 - sched_min) + sched_min
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # LambdaLR's constructor already leaves last_epoch at 0, so this is a no-op for a
    # fresh run; presumably kept as the hook for resuming (starting_epoch - 1).
    scheduler.last_epoch = 0

# Init YOLO
# NOTE(review): `load_yolo` is only printed later; yolo_load_file is always passed here,
# so load_yolo = False has no effect. Confirm how Train handles a missing load_file
# before wiring the flag in.
yolo = yolo_model.Train(yolo_cfg, snn.c, num_epochs, accum_grad, train_split.stop - train_split.start, load_file=yolo_load_file)

In [None]:
# Echo the run's hyper-parameters so the log is self-describing
print("Decay:", decay)
print("SNN lr:", snn_lr)
print("Scheduler min lr multiplier:", sched_min)  # a multiplier on snn_lr, not an lr itself
print("Init gain:", init_gain)
print("Scheduler:", use_scheduler)
print("Pre-trained:", load_snn, load_yolo, "\n")

# Baseline mAP before any training; guarded like the per-epoch validation so that
# validate = False or validation_ratio = 0 (empty valid_set) skips it
if validate and validation_ratio:
    inference("validation")

step_counter = 1  # Position within the current gradient-accumulation window (1..accum_grad)
best_map = 0
maps = []         # Validation mAP after each epoch
for epoch in range(num_epochs):
    print('Epoch [%d/%d]' % (epoch+1, num_epochs))
    running_loss = 0
    pred_count = 0
    start_time = time.time()

    for i in train_set:
        data, targets = load_data(i)

        if len(targets) == 0 or i == 118 or i == 202: # Malformed data points
            print("\tSkip", i)
            continue

        snn.reset_potentials()

        # - Forward -
        for j in range(data.size()[0]):
            input_data = data[j] / 5
            snn.feed(input_data)

        # Feed intermediate output to YOLO
        intermediate = snn.collect() / data.size()[0] # Normalize outputs to 0-1

        # Only every accum_grad-th sample triggers an optimizer step
        step = (step_counter == accum_grad)
        if step:
            step_counter = 1
        else:
            step_counter += 1

        # NOTE(review): no backward() call appears in this loop, so presumably
        # yolo.predict runs backward internally — confirm in train.py
        loss = yolo.predict(intermediate, targets, step, i - file_offset,
                            snn_grad=(snn.named_parameters() if show_grad and step else None))

        running_loss += loss.item()
        pred_count += 1

        # - Backward -
        if step:
            # With clip_thresh = math.inf this only measures the total gradient norm
            grad_size = torch.nn.utils.clip_grad_norm_(snn.parameters(), clip_thresh) # Avoid exploding gradient
            optimizer.step()
            snn.clear()
            snn.zero_grad()
            optimizer.zero_grad()

        print("\t[%.0f] Real time: %.1fs | Loss: %.5f | Spike-rate: %.0f" %
              (i, time.time() - start_time, loss.item(), torch.sum(intermediate).item()),
              "| Grad: %.4f" % grad_size if step else "")

    if use_scheduler:
        scheduler.step()

    # Avoid ZeroDivisionError when every sample of the epoch was skipped
    mean_loss = running_loss / pred_count if pred_count else float("nan")
    # memory_reserved replaces the deprecated memory_cached (same quantity)
    print("GPU: %.3gGB | Training loss: %.6f | Time: %.0fs\n" %
          (torch.cuda.memory_reserved() / 1e9, mean_loss, time.time()-start_time))

    if validate and validation_ratio:
        map_score = inference("validation")
        maps.append(map_score)

    yolo.end_epoch()

    if save:
        save_snn(snn, snn_save_path + "last.t7")
        yolo.save(yolo_save_path + "last.pt")
        if validate and validation_ratio:
            if map_score > best_map:
                save_snn(snn, snn_save_path + "best.t7")
                yolo.save(yolo_save_path + "best.pt")
                best_map = map_score

# Summarise the per-epoch validation scores (empty when validation was disabled)
if maps:
    print("All mAP scores", maps)

# Final evaluation on the held-out test split, if one was configured
if test and test_ratio:
    inference("test")