In [None]:
from __future__ import print_function
import torchvision
import torchvision.transforms as transforms
import numpy as np
import os
import time
import sys
import glob
import math
import random
from IPython.display import clear_output

# Make the local SpikingNN sources importable before pulling in their modules.
sys.path.insert(0, './SpikingNN')

# `torch` is used throughout this notebook but no explicit import of it is
# visible in this cell; import it by name so the notebook does not depend on
# it leaking in through the wildcard import below.
import torch

from spiking_model import*
#from automotive_dataset.data_loader import Prophesee

import train as yolo_model

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")  # uncomment to force CPU execution
print("Running on", device)

In [None]:
# Randomize data point indices for the data set batches
def random_batch(seed, n, offset, train_split, validation_split, test_split):
    """Shuffle the indices ``offset .. offset + n - 1`` and cut them into three splits.

    Args:
        seed: Seed passed to ``random.seed``. Note that this reseeds Python's
            global ``random`` generator as a side effect.
        n: Number of data points.
        offset: Index of the first data point.
        train_split, validation_split, test_split: Objects exposing ``start``
            and ``stop`` (e.g. ``range``), expressed in the same index space
            as ``offset``. Each selects the slice
            ``[start - offset, stop - offset)`` of the shuffled index list.

    Returns:
        ``(train, valid, test)`` as integer numpy arrays of shuffled indices.

    Raises:
        ValueError: If the three splits do not cover every index exactly once
            (some index is missing, or the splits overlap).
    """
    indices = [i + offset for i in range(n)]
    random.seed(seed)
    random.shuffle(indices)

    def take(split):
        # dtype=int keeps empty splits integer-typed; np.array([]) alone
        # would silently produce a float64 array.
        return np.array(indices[split.start - offset:split.stop - offset], dtype=int)

    train = take(train_split)
    valid = take(validation_split)
    test = take(test_split)

    # Verify results: the union must contain all n indices, and the combined
    # length must also be n, otherwise two splits share data points.
    covered = len(np.unique(np.concatenate((train, valid, test))))
    total = len(train) + len(valid) + len(test)
    if covered != n or total != n:
        raise ValueError("Malformed data batches", len(train), len(valid), len(test))
    return train, valid, test

# 2x nearest-neighbour upsampling module; inference() applies it to the
# accumulated input frame and passes the result to yolo.test.
up = torch.nn.Upsample(scale_factor=2, mode='nearest')

def inference(inference_type, conf_thres=0.1, iou_thres=0.5):
    """Evaluate the detector on the validation or test split and return its mAP.

    Relies on notebook-level names: ``valid_set``, ``test_set``, ``yolo``,
    ``load_data``, ``corner_to_center``, ``up``, ``whwh``, ``height``,
    ``width`` and ``device``.

    Args:
        inference_type: ``"validation"`` selects ``valid_set``; any other
            value selects ``test_set``. Also used in the progress output.
        conf_thres: Forwarded to ``yolo.test``.
        iou_thres: Forwarded to ``yolo.test``.

    Returns:
        The first value returned by ``yolo.test_end()`` (reported as MAP).
    """
    with torch.no_grad():
        split = valid_set if inference_type == "validation" else test_set
        print("Running", inference_type)
        yolo.test_init()

        start_time = time.time()
        for i in split:
            data, targets = load_data(i)

            # Skip samples without annotations and the two file indices that
            # the training loop also excludes as malformed.
            if len(targets) == 0 or i == 118 or i == 202:
                print("\tSkipping", i)
                continue

            # - Forward -
            # Sum every slice along the first dimension of the recording into
            # a single frame.
            input_data = torch.zeros((1, 2, height, width), device=device)
            for j in range(data.size()[0]):
                input_data += data[j]

            # Scale by the peak value. An all-zero frame is left untouched,
            # since dividing by a zero peak would fill the input with NaN.
            peak = torch.max(input_data)
            if peak != 0:
                input_data = input_data / peak

            # Targets are converted in place before being handed to the model.
            corner_to_center(targets)
            yolo.test(input_data, targets, i, up(input_data), whwh, conf_thres=conf_thres, iou_thres=iou_thres)

            print("\t[%.0f] Real time: %.1fs" % 
              (i+1, time.time() - start_time))

        map_score, losses = yolo.test_end()
        print(inference_type.capitalize(), "Time: %.0fs | MAP: %.4f" % 
              (time.time()-start_time, map_score), "| Loss: ", losses, "\n")
        return map_score

def load_data(i):
    """Load the recording and its annotations for data point ``i``.

    Both files are named ``<i>.pt`` and are read with ``torch.load`` from the
    notebook-level ``video_path`` and ``bb_path`` directories respectively.

    Returns:
        A ``(recording, annotations)`` tuple, in that order.
    """
    filename = str(i) + ".pt"
    recording = torch.load(video_path + filename)
    annotations = torch.load(bb_path + filename)
    return (recording, annotations)

def corner_to_center(targets):
    """Shift every target's anchor from its top-left corner to its centre, in place.

    For each row, element 2 grows by half of element 4 and element 3 grows by
    half of element 5. Nothing is returned; ``targets`` itself is modified.
    """
    for row in range(len(targets)):
        box = targets[row]
        half_extent_2 = box[4] / 2
        half_extent_3 = box[5] / 2
        # Move anchor from top left to center
        box[2] += half_extent_2
        box[3] += half_extent_3

In [None]:
# Dataset variables
# Directories holding one "<index>.pt" file per data point (see load_data).
video_path = "N-Caltech101/tensor_data/" # alternative location: "/home/olfjoh-5/sdc1/event_bin/train/"
bb_path = "N-Caltech101/tensor_annotations/"
names = ["Airplane", "Motorbike"]  # Class names
num_epochs = 100
# random_batch shuffles the indices file_offset .. file_offset + num_files - 1.
num_files = 1590
file_offset = 5
# Fractions of num_files assigned to each split; random_batch raises if the
# resulting splits do not cover all num_files indices.
train_ratio = 0.8
validation_ratio = 0.2
test_ratio = 0
validate = True  # Run a validation pass after every epoch (needs validation_ratio > 0)
test = False  # Run a final pass over the test split (needs test_ratio > 0)
# Frame size; input frames are built with shape (1, 2, height, width).
width = 240
height = 176

# YOLO variables
yolo_load_file = './weights/original_best.pt'
yolo_save_path = './weights/original_'  # Prefix; "last.pt" / "best.pt" are appended when saving
yolo_cfg = "../cfg/pure-yolo-deeper.cfg"  # YOLO config file
# Output scaling for YOLO. The factor 2 presumably matches the 2x upsampled
# frame handed to yolo.test - TODO confirm.
whwh = torch.Tensor([width*2,height*2,width*2,height*2]).to(device)  # Output scaling for YOLO

# NOTE(review): load_yolo is only printed by the training cell; yolo_load_file
# is passed to Train below regardless of its value - confirm this is intended.
load_yolo = True
save = False  # Write checkpoints at the end of each epoch
accum_grad = 25  # Number of samples between optimizer steps (see step_counter in the training loop)

# Consecutive index ranges for the three splits, laid out back to back
# starting at file_offset.
train_split = range(file_offset, int(file_offset + train_ratio * num_files))
validation_split = range(train_split.stop, int(train_split.stop + validation_ratio * num_files))
test_split = range(validation_split.stop, int(validation_split.stop + test_ratio * num_files))
# Fixed seed so the shuffled split membership is the same on every run.
train_set, valid_set, test_set = random_batch(78943522, num_files, file_offset, train_split, validation_split, test_split)

# Init YOLO
# The literal 2 is presumably the number of classes (len(names)) - TODO confirm
# against Train's signature. The fifth argument is the number of training files.
yolo = yolo_model.Train(yolo_cfg, 2, num_epochs, accum_grad, train_split.stop - train_split.start, load_file=yolo_load_file)

In [None]:
print("Pre-trained:", load_yolo, "\n")

# State shared across epochs by the training loop.
step_counter = 1  # Position within the current gradient-accumulation window
best_map = 0      # Best validation mAP seen so far
maps = []         # Validation mAP recorded after each epoch

# Baseline validation pass before any training. It is guarded by the same
# flags as the per-epoch validation so it is skipped when validation is
# disabled or the validation split is empty.
if validate and validation_ratio:
    inference("validation")
# `torch.cuda.memory_cached` is deprecated in favour of `memory_reserved`;
# the old name is only used on PyTorch versions that lack the new one.
gpu_memory_bytes = getattr(torch.cuda, "memory_reserved", None) or torch.cuda.memory_cached

for epoch in range(num_epochs):
    print('Epoch [%d/%d]' % (epoch+1, num_epochs))
    running_loss = 0
    pred_count = 0
    start_time = time.time()

    for i in train_set:
        data, targets = load_data(i)

        if len(targets) == 0 or i == 118 or i == 202: # Malformed data points 
            print("\tSkipping", i)
            continue

        # - Forward -
        # Sum every slice along the first dimension of the recording into a
        # single frame.
        input_data = torch.zeros((1, 2, height, width), device=device)
        for j in range(data.size()[0]):
            input_data += data[j]

        # Scale by the peak value. An all-zero frame is left untouched, since
        # dividing by a zero peak would fill the input with NaN.
        peak = torch.max(input_data)
        if peak != 0:
            input_data = input_data / peak

        # Gradient accumulation: request an optimizer step on every
        # accum_grad-th processed sample, then restart the window.
        step = (step_counter == accum_grad)
        step_counter = 1 if step else step_counter + 1

        # NOTE(review): inference() runs corner_to_center(targets) before
        # calling the model, while targets are passed here unconverted -
        # confirm yolo.predict expects that.
        loss = yolo.predict(input_data, targets, step, i - file_offset, snn_grad=None)

        loss_value = loss.item()
        running_loss += loss_value
        pred_count += 1

        # The bracketed value is the file index, not a duration.
        print("\t[%.0f] Real time: %.1fs | Loss: %.5f " % 
          (i, time.time() - start_time, loss_value))

    # Every sample may have been skipped; avoid a ZeroDivisionError then.
    mean_loss = running_loss / pred_count if pred_count else float("nan")
    gpu_gb = gpu_memory_bytes() / 1e9 if torch.cuda.is_available() else 0.0
    print("GPU: %.3gGB | Training loss: %.6f | Time: %.0fs\n" % 
          (gpu_gb, mean_loss, time.time()-start_time))

    if validate and validation_ratio:
        map_score = inference("validation")
        maps.append(map_score)

    yolo.end_epoch()

    if save:
        # Always keep the latest weights; keep a separate copy whenever the
        # validation mAP improves.
        yolo.save(yolo_save_path + "last.pt")
        if validate and validation_ratio and map_score > best_map:
            yolo.save(yolo_save_path + "best.pt")
            best_map = map_score
            
# Report the validation mAP history when at least one score was recorded.
has_map_history = bool(maps)
if has_map_history:
    print("All mAP scores", maps)

# Final pass over the test split, only when `test` is set and `test_ratio`
# is non-zero.
run_test = test and test_ratio
if run_test:
    inference("test")