In [1]:
import os

import json
import torch
import numpy as np
import queue
import pprint
import random
import argparse
import importlib
import threading
import traceback

from tqdm import tqdm_notebook as tqdm
from config import system_configs
from nnet.py_factory import NetworkFactory
from torch.multiprocessing import Process, Queue, Pool
from db.datasets import datasets

In [2]:
# Report whether PyTorch can see a CUDA device; train() below calls nnet.cuda().
print(torch.cuda.is_available())

True


In [3]:
# Enable cuDNN and its benchmark mode. Per the PyTorch documentation, benchmark
# mode lets cuDNN time several convolution algorithms and reuse the fastest one.
torch.backends.cudnn.enabled   = True
torch.backends.cudnn.benchmark = True

In [4]:
def prefetch_data(db, queue, sample_data, data_aug):
    """Worker loop: keep sampling batches from ``db`` and pushing them to ``queue``.

    Args:
        db: database object handed unchanged to ``sample_data``.
        queue: object with a ``put`` method that receives every sampled batch.
        sample_data: callable ``sample_data(db, ind, data_aug=...)`` returning
            ``(data, next_ind)``; ``next_ind`` is fed back in on the next call.
        data_aug: forwarded to ``sample_data`` as the ``data_aug`` keyword.

    The loop has no exit condition. It only ends when ``sample_data`` or
    ``queue.put`` raises an ``Exception``, in which case the traceback is
    printed and the same exception is re-raised.
    """
    print("start prefetching data...")

    # Seed NumPy's global RNG with this process's id.
    np.random.seed(os.getpid())

    cursor = 0
    try:
        while True:
            batch, cursor = sample_data(db, cursor, data_aug=data_aug)
            queue.put(batch)
    except Exception as err:
        # Print the traceback here, then propagate the error to the caller.
        traceback.print_exc()
        raise err

def pin_memory(data_queue, pinned_data_queue, sema):
    """Move batches from ``data_queue`` to ``pinned_data_queue``, pinning them on the way.

    Args:
        data_queue: source queue; each item is a mapping with ``"xs"`` and
            ``"ys"`` entries holding iterables of objects that expose
            ``pin_memory()``.
        pinned_data_queue: destination queue for the processed batches.
        sema: stop signal. After each forwarded batch the function tries a
            non-blocking ``acquire``; once that succeeds the function returns.

    The stop signal is only polled after a batch has been forwarded, so the
    call stays blocked in ``data_queue.get()`` while the source is empty.
    """
    stop_requested = False
    while not stop_requested:
        batch = data_queue.get()

        # Replace both tensor lists with their pin_memory() results, in place.
        for key in ("xs", "ys"):
            batch[key] = [tensor.pin_memory() for tensor in batch[key]]

        pinned_data_queue.put(batch)

        stop_requested = sema.acquire(blocking=False)

def init_parallel_jobs(dbs, queue, fn, data_aug):
    """Start one daemon ``prefetch_data`` process per database.

    Args:
        dbs: iterable of database objects; each gets its own process.
        queue: queue shared by all processes, passed to ``prefetch_data``.
        fn: sampling function passed to ``prefetch_data``.
        data_aug: augmentation flag passed to ``prefetch_data``.

    Returns:
        List of the started ``Process`` objects, in the order of ``dbs``.
    """
    job_args = [(db, queue, fn, data_aug) for db in dbs]
    workers  = [Process(target=prefetch_data, args=job) for job in job_args]

    for worker in workers:
        # The daemon flag has to be set before start().
        worker.daemon = True
        worker.start()

    return workers

In [5]:
def train(training_dbs, validation_db, start_iter=0):
    """Train the network while background workers prefetch and pin batches.

    Hyper-parameters (learning rate, iteration count, snapshot / validation /
    display intervals, decay settings, prefetch size, pretrained model path)
    are read from ``system_configs``.

    Args:
        training_dbs: sequence of training database objects. One prefetch
            process is started per entry; the first entry also provides
            ``data`` (used to pick the ``sample.<data>`` module) and is
            passed to ``NetworkFactory``.
        validation_db: database object used for validation batches. A
            prefetch process is started for it only when
            ``system_configs.val_iter`` is truthy.
        start_iter: iteration to resume from. When non-zero, the parameters
            saved for that iteration are loaded and the learning rate is
            divided by ``decay_rate`` once per completed ``stepsize``.

    Raises:
        ValueError: if ``system_configs.pretrain`` is set but the path does
            not exist.

    The pinning threads are signalled and all prefetch processes are
    terminated when the function exits, including when it exits because of
    an exception or a keyboard interrupt.
    """
    learning_rate    = system_configs.learning_rate
    max_iteration    = system_configs.max_iter
    pretrained_model = system_configs.pretrain
    snapshot         = system_configs.snapshot
    val_iter         = system_configs.val_iter
    display          = system_configs.display
    decay_rate       = system_configs.decay_rate
    stepsize         = system_configs.stepsize

    # queues filled by the prefetch processes
    training_queue   = Queue(system_configs.prefetch_size)
    validation_queue = Queue(5)

    # in-process queues holding the batches after pin_memory()
    pinned_training_queue   = queue.Queue(system_configs.prefetch_size)
    pinned_validation_queue = queue.Queue(5)

    # load data sampling function
    data_file   = "sample.{}".format(training_dbs[0].data)
    sample_data = importlib.import_module(data_file).sample_data

    # Stop signals for the pinning threads. Each semaphore is drained to zero
    # here; releasing it later lets the thread's non-blocking acquire succeed.
    training_pin_semaphore   = threading.Semaphore()
    validation_pin_semaphore = threading.Semaphore()
    training_pin_semaphore.acquire()
    validation_pin_semaphore.acquire()

    # Bound up front so the cleanup in `finally` is valid on every path,
    # including when validation is disabled (val_iter falsy).
    training_tasks   = []
    validation_tasks = []

    try:
        # allocating resources for parallel reading
        training_tasks = init_parallel_jobs(training_dbs, training_queue, sample_data, True)
        if val_iter:
            validation_tasks = init_parallel_jobs([validation_db], validation_queue, sample_data, False)

        training_pin_args   = (training_queue, pinned_training_queue, training_pin_semaphore)
        training_pin_thread = threading.Thread(target=pin_memory, args=training_pin_args)
        training_pin_thread.daemon = True
        training_pin_thread.start()

        validation_pin_args   = (validation_queue, pinned_validation_queue, validation_pin_semaphore)
        validation_pin_thread = threading.Thread(target=pin_memory, args=validation_pin_args)
        validation_pin_thread.daemon = True
        validation_pin_thread.start()

        print("building model...")
        nnet = NetworkFactory(training_dbs[0])

        if pretrained_model is not None:
            if not os.path.exists(pretrained_model):
                raise ValueError("pretrained model does not exist")
            print("loading from pretrained model")
            nnet.load_pretrained_params(pretrained_model)

        if start_iter:
            # Re-apply every decay step that happened before start_iter.
            learning_rate /= (decay_rate ** (start_iter // stepsize))

            nnet.load_params(start_iter)
            nnet.set_lr(learning_rate)
            print("training starts from iteration {} with learning_rate {}".format(start_iter + 1, learning_rate))
        else:
            nnet.set_lr(learning_rate)

        print("training start...")
        nnet.cuda()
        nnet.train_mode()
        for iteration in tqdm(range(start_iter + 1, max_iteration + 1)):
            training = pinned_training_queue.get(block=True)
            training_loss, focal_loss, pull_loss, push_loss, regr_loss = nnet.train(**training)

            if display and iteration % display == 0:
                print("training loss at iteration {}: {}".format(iteration, training_loss.item()))
                print("focal loss at iteration {}:    {}".format(iteration, focal_loss.item()))
                print("pull loss at iteration {}:     {}".format(iteration, pull_loss.item()))
                print("push loss at iteration {}:     {}".format(iteration, push_loss.item()))
                print("regr loss at iteration {}:     {}".format(iteration, regr_loss.item()))

            # Drop the references to the loss values before the next step.
            del training_loss, focal_loss, pull_loss, push_loss, regr_loss

            if val_iter and validation_db.db_inds.size and iteration % val_iter == 0:
                nnet.eval_mode()
                validation = pinned_validation_queue.get(block=True)
                validation_loss = nnet.validate(**validation)
                print("validation loss at iteration {}: {}".format(iteration, validation_loss.item()))
                nnet.train_mode()

            if iteration % snapshot == 0:
                nnet.save_params(iteration)

            if iteration % stepsize == 0:
                learning_rate /= decay_rate
                nnet.set_lr(learning_rate)
    finally:
        # sending signal to stop the pinning threads
        training_pin_semaphore.release()
        validation_pin_semaphore.release()

        # terminating data fetching processes
        for task in training_tasks + validation_tasks:
            task.terminate()

In [6]:
# Command-line parsing, kept commented out for reference; the notebook uses
# the hard-coded `args` dict defined right after it instead.
# def parse_args():
#     parser = argparse.ArgumentParser(description="Train CenterNet")
#     parser.add_argument("cfg_file", help="config file", type=str)
#     parser.add_argument("--iter", dest="start_iter",
#                         help="train at iteration i",
#                         default=0, type=int)
#     parser.add_argument("--threads", dest="threads", default=4, type=int)

#     #args = parser.parse_args()
#     args, unparsed = parser.parse_known_args()
#     return args
# args = parse_args()

# Run settings: config name (without ".json"), iteration to resume from,
# and number of training dataset copies to create.
args = {'cfg_file': 'CenterNet-52',
        'start_iter': 0,
        'threads': 1}

# Read <config_dir>/<cfg_file>.json.
cfg_file = os.path.join(system_configs.config_dir, args['cfg_file'] + ".json")
with open(cfg_file, "r") as f:
    configs = json.load(f)

# Store the config name under "snapshot_name", then apply the "system"
# section to system_configs. The splits and dataset name read below come
# from system_configs, so this update has to happen first.
configs["system"]["snapshot_name"] = args['cfg_file']
system_configs.update_config(configs["system"])

train_split = system_configs.train_split
val_split   = system_configs.val_split

print("loading all datasets...")
dataset = system_configs.dataset
# threads = max(torch.cuda.device_count() * 2, 4)
threads = args['threads']
print("using {} threads".format(threads))
# One training dataset object per worker: train() starts one prefetch
# process for each entry of training_dbs.
training_dbs  = [datasets[dataset](configs["db"], train_split) for _ in range(threads)]
validation_db = datasets[dataset](configs["db"], val_split)

print("system config...")
pprint.pprint(system_configs.full)

print("db config...")
pprint.pprint(training_dbs[0].configs)

print("len of db: {}".format(len(training_dbs[0].db_inds)))

loading all datasets...
using 1 threads
loading from cache file: cache/coco_trainval2014.pkl
loading annotations into memory...
Done (t=7.85s)
creating index...
index created!
loading from cache file: cache/coco_minival2014.pkl
loading annotations into memory...
Done (t=0.87s)
creating index...
index created!
system config...
{'batch_size': 4,
 'cache_dir': 'cache',
 'chunk_sizes': [6, 6, 6, 6, 6, 6, 6, 6],
 'config_dir': 'config',
 'data_dir': './data',
 'data_rng': RandomState(MT19937) at 0x7F7B8BB62A98,
 'dataset': 'MSCOCO',
 'decay_rate': 10,
 'display': 5,
 'learning_rate': 0.00025,
 'max_iter': 480000,
 'nnet_rng': RandomState(MT19937) at 0x7F7B8BB62CA8,
 'opt_algo': 'adam',
 'prefetch_size': 6,
 'pretrain': None,
 'result_dir': 'results',
 'sampling_function': 'kp_detection',
 'snapshot': 5000,
 'snapshot_name': 'CenterNet-52',
 'stepsize': 450000,
 'test_split': 'testdev',
 'train_split': 'trainval',
 'val_iter': 500,
 'val_split': 'minival',
 'weight_decay': False,
 'weight_de

In [7]:
# Launch training. args['start_iter'] is 0 here, so train() begins at
# iteration 1 and does not load previously saved parameters.
train(training_dbs, validation_db, args['start_iter'])

start prefetching data...
shuffling indices...
start prefetching data...
shuffling indices...
building model...
module_file: models.CenterNet-52
total parameters: 104844152
setting learning rate to: 0.00025
training start...


HBox(children=(IntProgress(value=0, max=480000), HTML(value='')))

training loss at iteration 5: 420.2887878417969
focal loss at iteration 5:    419.49114990234375
pull loss at iteration 5:     0.016489481553435326
push loss at iteration 5:     0.33171120285987854
regr loss at iteration 5:     0.44944384694099426
training loss at iteration 10: 162.9038543701172
focal loss at iteration 10:    162.19921875
pull loss at iteration 10:     0.02546476200222969
push loss at iteration 10:     0.332808256149292
regr loss at iteration 10:     0.346366822719574
training loss at iteration 15: 149.5183563232422
focal loss at iteration 15:    148.78408813476562
pull loss at iteration 15:     0.08973975479602814
push loss at iteration 15:     0.3529437482357025
regr loss at iteration 15:     0.29157406091690063
training loss at iteration 20: 21.524978637695312
focal loss at iteration 20:    20.76079559326172
pull loss at iteration 20:     0.1051899716258049
push loss at iteration 20:     0.3286711871623993
regr loss at iteration 20:     0.3303218483924866
training l

Process Process-2:
Process Process-1:
Traceback (most recent call last):
  File "/home/teng/miniconda3/envs/py36/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/teng/miniconda3/envs/py36/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "<ipython-input-4-d0866c996085>", line 8, in prefetch_data
    queue.put(data)
  File "/home/teng/miniconda3/envs/py36/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/teng/miniconda3/envs/py36/lib/python3.6/multiprocessing/queues.py", line 82, in put
    if not self._sem.acquire(block, timeout):
  File "/home/teng/miniconda3/envs/py36/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
KeyboardInterrupt
  File "<ipython-input-4-d0866c996085>", line 8, in prefetch_data
    queue.put(data)
  File "/home/teng/minic

KeyboardInterrupt: 