# Pose Estimation Training

In [1]:
import argparse
from collections import OrderedDict
import importlib
import os
import time

import gluoncv as gcv
from gluoncv.model_zoo import get_model
from mxboard import SummaryWriter    
import mxnet as mx
from mxnet import gluon, autograd, nd
from mxnet.gluon import nn, loss
import numpy as np

import multi_pose.models
import multi_pose.datasets.coco_data
importlib.reload(multi_pose.models)
importlib.reload(multi_pose.datasets.coco_data)

from multi_pose.models import build_model
from multi_pose.datasets.coco_data import get_loader
from multi_pose.utils import AverageMeter

### Configuration

In [2]:
# --- Dataset and output locations ---
data_dir = 'data/dataset/COCO/images'
mask_dir = 'data/dataset/COCO/mask'
json_path = 'data/dataset/COCO/COCO.json'
logdir = 'logs'
model_path = 'model_checkpoints/'
load_model = ''  # checkpoint filename under model_path; '' means no checkpoint is loaded

# --- Optimisation ---
optim = 'sgd'
lr = 0.001
momentum = 0.9
wd = 0.0
nesterov = False
epochs_pre = 50  # epochs of the first training stage
epochs_ft = 50   # epochs of the fine-tuning stage
batch_size = 16

# --- Model and runtime ---
model_trunk = 'resnet18_v1b'
dtype = 'float32'
gpuIDs = [0]
print_freq = 20
log_key = 'notebook_tests'

# Single-device training: only the first listed GPU id is used; -1 selects the CPU.
ctx = mx.cpu() if gpuIDs[0] == -1 else mx.gpu(gpuIDs[0])

# Augmentation / target-generation parameters handed to get_loader.
params_transform = {
    'mode': 5,
    # aug_scale
    'scale_min': 0.8,
    'scale_max': 1.2,
    'scale_prob': 1,
    'target_dist': 0.6,
    # aug_rotate
    'max_rotate_degree': 20,
    # centre perturbation; NOTE(review): key spelling kept as-is, presumably
    # what the loader looks up — do not "fix" without checking the loader.
    'center_perterb_max': 20,
    # aug_flip
    'flip_prob': 0.5,
    'np': 56,
    'sigma': 7.0,
    'limb_width': 1.,
}

Helper functions

In [3]:
def build_names(num_stages=6):
    """Return the per-stage loss names used as logging/meter keys.

    Every stage contributes two names, ``loss_stage{j}_L1`` followed by
    ``loss_stage{j}_L2`` (``j`` is 1-based), in stage order. In ``get_loss``
    the ``L1`` entry holds the PAF loss and the ``L2`` entry the heatmap loss.

    Args:
        num_stages: number of stages to generate names for. Defaults to 6,
            the value that used to be hard-coded, so existing callers are
            unaffected.

    Returns:
        list of str with ``2 * num_stages`` entries.
    """
    return ['loss_stage%d_L%d' % (stage, branch)
            for stage in range(1, num_stages + 1)
            for branch in (1, 2)]

Training and evaluation loops

## Data Loading

In [4]:
# Downsampling factor handed to get_loader: 8 for the vgg19/mobilenet trunks, 4 otherwise.
downsample = 4 if model_trunk not in ('mobilenet', 'vgg19') else 8

In [5]:
print("Loading dataset...")
# Training loader: shuffled, augmented. The 4th positional argument (384) is
# presumably the network input size — confirm against get_loader.
train_data = get_loader(
    json_path, data_dir, mask_dir, 384, downsample, batch_size,
    num_workers=8,
    shuffle=True,
    training=True,
    params_transform=params_transform,
)
print('train dataset len: %d' % len(train_data._dataset))

# Validation loader: same geometry, no shuffling, training=False.
valid_data = get_loader(
    json_path, data_dir, mask_dir, 384, downsample,
    batch_size=batch_size,
    num_workers=8,
    shuffle=False,
    training=False,
    params_transform=params_transform,
)
print('val dataset len: %d' % len(valid_data._dataset))

Loading dataset...
train dataset len: 121522
val dataset len: 4873


## Network

Creating the network

In [6]:
# Backbone actually trained below (overrides the value from the configuration cell).
model_trunk = 'resnet101_v1b'
# The data loaders were already built with a `downsample` derived from the
# earlier model_trunk; fail fast if this override would invalidate them.
assert downsample == (8 if model_trunk in ('mobilenet', 'vgg19') else 4), (
    'model_trunk override changes the downsample factor; '
    're-run the data-loading cells after setting model_trunk')

In [7]:
# Build the pose network on top of the selected trunk.
# NOTE(review): num_joints=19 presumably means 18 COCO keypoints + background — confirm.
model = build_model(
    trunk=model_trunk,
    num_joints=19,
    is_train=True,
    pretrained_ctx=ctx,
)

In [None]:
# Optionally resume from a checkpoint stored under model_path.
if load_model != '':
    checkpoint_file = os.path.join(model_path, load_model)
    model.load_parameters(checkpoint_file, ctx=ctx)

In [9]:
# Compile to a cached graph; static_alloc/static_shape reuse buffers across
# iterations and assume the input shape does not change between batches.
model.hybridize(static_alloc=True, static_shape=True)

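Stage 1: train the pose heads first. Only parameters handed to a trainer are updated: for VGG19 only the `CPM` layers of the trunk are trained, for MobileNet the trunk is frozen, and for ResNet the whole `resnet` trunk is trained together with the heads.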

### Loss Function

In [10]:
def get_loss(saved_for_loss, heat_temp, heat_weight,
             vec_temp, vec_weight):
    """Masked L2 loss summed over all stages, plus scalar diagnostics for logging.

    Args:
        saved_for_loss: sequence of stage outputs laid out as
            ``[paf_1, heat_1, paf_2, heat_2, ...]``; even indices are compared
            with the PAF target, odd indices with the heatmap target. A trailing
            unpaired element is ignored by the loss loop.
        heat_temp: heatmap target.
        heat_weight: mask multiplied into both heatmap prediction and target.
        vec_temp: PAF target.
        vec_weight: mask multiplied into both PAF prediction and target.

    Returns:
        tuple ``(total_loss, saved_for_log)``:
        * ``total_loss`` — sum over stages and both branches of the
          ``gluon.loss.L2Loss`` outputs (still an NDArray; callers reduce it).
        * ``saved_for_log`` — OrderedDict of Python scalars: the mean loss of
          each stage/branch keyed by ``build_names()``, plus ``max_ht``/``min_ht``
          over the last output excluding its final channel (presumably the
          background channel — confirm) and ``max_paf``/``min_paf`` over the
          second-to-last output.
    """
    names = build_names()
    saved_for_log = OrderedDict()
    loss_fn = gluon.loss.L2Loss()

    # The masked targets are identical for every stage, so compute them once.
    gt_paf = vec_temp * vec_weight
    gt_heat = heat_weight * heat_temp

    total_loss = 0
    for j in range(len(saved_for_loss) // 2):
        loss_paf = loss_fn(saved_for_loss[2 * j] * vec_weight, gt_paf)
        loss_heat = loss_fn(saved_for_loss[2 * j + 1] * heat_weight, gt_heat)

        total_loss = total_loss + loss_paf + loss_heat
        saved_for_log[names[2 * j]] = loss_paf.mean().asscalar()
        saved_for_log[names[2 * j + 1]] = loss_heat.mean().asscalar()

    # Copy each tensor to the host once rather than once per statistic.
    last_heat = saved_for_loss[-1][:, 0:-1, :, :].asnumpy()
    last_paf = saved_for_loss[-2].asnumpy()
    saved_for_log['max_ht'] = last_heat.max()
    saved_for_log['min_ht'] = last_heat.min()
    saved_for_log['max_paf'] = last_paf.max()
    saved_for_log['min_paf'] = last_paf.min()

    return total_loss, saved_for_log

### Training Utils

In [11]:
def run_epoch(iterator, model, epoch, is_train=True, trainer_trunk=None, trainer_pose=None, dtype='float32'):
    """Run one pass over ``iterator``, training or evaluating ``model``.

    Reads the notebook globals ``ctx`` (device), ``print_freq``, ``writer``
    (mxboard SummaryWriter) and ``log_key``.

    Args:
        iterator: loader yielding ``(img, heatmap_target, heat_mask, paf_target, paf_mask)``.
        model: pose network; cast to ``dtype`` before the pass.
        epoch: epoch index, used for console output and the TensorBoard step.
        is_train: if True, record gradients and step the trainers; otherwise run
            forward passes only (no autograd graph is recorded).
        trainer_trunk: optional Trainer for the backbone, stepped before ``trainer_pose``.
        trainer_pose: Trainer for the pose heads; required when ``is_train`` is True.
        dtype: dtype for the model and all input tensors.

    Returns:
        float: running average of the batch-mean total loss, weighted by batch size.

    Raises:
        ValueError: if ``is_train`` is True but ``trainer_pose`` is None.
    """
    if is_train and trainer_pose is None:
        raise ValueError('trainer_pose is required when is_train=True')

    batch_time = AverageMeter()
    losses = AverageMeter()
    model.cast(dtype)

    meter_dict = {name: AverageMeter() for name in build_names()}
    for name in ('max_ht', 'min_ht', 'max_paf', 'min_paf'):
        meter_dict[name] = AverageMeter()

    end = time.time()
    for i, batch in enumerate(iterator):
        # Move every tensor of the batch to the device and the requested dtype.
        img, heatmap_target, heat_mask, paf_target, paf_mask = [
            x.as_in_context(ctx).astype(dtype, copy=False) for x in batch]

        # Record the graph only when training; evaluation runs in predict mode
        # so no autograd graph is built for the validation pass.
        scope = autograd.record() if is_train else autograd.predict_mode()
        with scope:
            out = model(img)
            # vgg19/mobilenet return a tuple in out[0] and the stage outputs in
            # out[1]; resnet returns the stage outputs directly.
            stage_outputs = out[1] if isinstance(out[0], tuple) else list(out)
            total_loss, saved_for_log = get_loss(stage_outputs, heatmap_target, heat_mask,
                                                 paf_target, paf_mask)

        num_samples = img.shape[0]
        for name, value in saved_for_log.items():
            meter_dict[name].update(value, num_samples)
        losses.update(total_loss.astype('float32').mean().asscalar(), num_samples)

        if is_train:
            total_loss.backward()
            if trainer_trunk is not None:
                trainer_trunk.step(num_samples)
            trainer_pose.step(num_samples)

        # Wall-clock time of the whole iteration (loading, forward, backward, update).
        batch_time.update(time.time() - end)
        end = time.time()

        if is_train and i % print_freq == 0:
            step = i + epoch * len(iterator)
            print('Epoch: [{0}][{1}/{2}]\t'.format(epoch, i, len(iterator)))
            print('Batch time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(batch_time=batch_time))
            print('Loss {loss.val:.4f} ({loss.avg:.4f})'.format(loss=losses))
            writer.add_scalar('data/max_ht', {log_key: meter_dict['max_ht'].avg}, step)
            writer.add_scalar('data/max_paf', {log_key: meter_dict['max_paf'].avg}, step)
            writer.add_scalar('data/loss', {log_key: losses.avg}, step)
            for name in saved_for_log:
                print('{name}: {loss.val:.4f} ({loss.avg:.4f})\t'.format(name=name, loss=meter_dict[name]))
            writer.flush()
            print()
    return losses.avg
    

In [12]:
# Key under which run_epoch and the epoch loops write their scalars; derived from the trunk name.
log_key = "{}refactor".format(model_trunk)
writer = SummaryWriter(logdir=logdir)

In [None]:
# Stage 1: train the pose heads first. run_epoch only updates parameters handed
# to trainer_trunk / trainer_pose, so anything not selected below stays fixed.
if model_trunk == 'vgg19':
    # Only model0 parameters whose names match 'CPM' are trained on the trunk side.
    trainer_trunk = gluon.Trainer(model.model0.collect_params('.*CPM.*'), 'sgd',
                                  {'learning_rate': lr, 'momentum': momentum, 'wd': wd})
    trainer_pose = gluon.Trainer(model.collect_params('block.*'), 'sgd',
                                 {'learning_rate': lr, 'momentum': momentum, 'wd': wd})
elif model_trunk == 'mobilenet':
    # Trunk completely frozen in this stage.
    trainer_pose = gluon.Trainer(model.collect_params('block.*'), 'sgd',
                                 {'learning_rate': lr, 'momentum': momentum, 'wd': wd})
    trainer_trunk = None
elif 'resnet' in model_trunk:
    # NOTE(review): this covers every '.*resnet.*' parameter, so for ResNet the
    # backbone is trained in this stage too (with a fixed Adam LR of 0.001) — confirm intended.
    trainer_trunk = gluon.Trainer(model.collect_params('.*resnet.*'), 'adam', {'learning_rate': 0.001})
    trainer_pose = gluon.Trainer(model.collect_params('.*final.*'), 'adam', {'learning_rate': 0.001})
else:
    # Without this, an unknown trunk would hit a NameError below or silently
    # reuse trainers left over from an earlier run of this cell.
    raise ValueError('Unsupported model_trunk: {!r}'.format(model_trunk))

for epoch in range(epochs_pre):
    # train for one epoch
    train_loss = run_epoch(train_data, model, epoch, is_train=True,
                           trainer_trunk=trainer_trunk, trainer_pose=trainer_pose)
    model.save_parameters(os.path.join(
        model_path, '{}_{}_pose2_{}.params'.format(log_key, model_trunk, epoch)))
    # evaluate on validation set
    val_loss = run_epoch(valid_data, model, epoch, is_train=False)

    writer.add_scalar('epoch/train_loss', {log_key: train_loss}, epoch)
    writer.add_scalar('epoch/val_loss', {log_key: val_loss}, epoch)

Epoch: [0][0/7596]	
Data time 7.833 (7.833)	
Loss 0.0256 (0.0256)
loss_stage1_L1: 0.0009 (0.0009)	
loss_stage1_L2: 0.0247 (0.0247)	
max_ht: 0.1141 (0.1141)	
min_ht: -0.1305 (-0.1305)	
max_paf: 0.1436 (0.1436)	
min_paf: -0.1307 (-0.1307)	

Epoch: [0][20/7596]	
Data time 0.316 (0.686)	
Loss 0.0043 (0.0079)
loss_stage1_L1: 0.0009 (0.0015)	
loss_stage1_L2: 0.0034 (0.0064)	
max_ht: 0.0638 (0.2051)	
min_ht: -0.0634 (-0.1967)	
max_paf: 0.0685 (0.2161)	
min_paf: -0.0731 (-0.2152)	

Epoch: [0][40/7596]	
Data time 0.335 (0.514)	
Loss 0.0024 (0.0054)
loss_stage1_L1: 0.0009 (0.0012)	
loss_stage1_L2: 0.0015 (0.0042)	
max_ht: 0.0373 (0.1497)	
min_ht: -0.0350 (-0.1361)	
max_paf: 0.0403 (0.1611)	
min_paf: -0.0381 (-0.1545)	

Epoch: [0][60/7596]	
Data time 0.327 (0.454)	
Loss 0.0020 (0.0044)
loss_stage1_L1: 0.0007 (0.0011)	
loss_stage1_L2: 0.0012 (0.0033)	
max_ht: 0.0366 (0.1126)	
min_ht: -0.0062 (-0.0993)	
max_paf: 0.0498 (0.1226)	
min_paf: -0.0240 (-0.1159)	

Epoch: [0][80/7596]	
Data time 0.327 (0.4

Fine-tuning the whole model (trunk and pose heads) with the learning rate reduced to one tenth

In [None]:
# Fine-tuning hyper-parameters; these override the values from the configuration cell.
optim, lr, wd = 'adam', 0.001, 0.0

In [None]:
# Stage 2: fine-tune trunk and pose heads together at one tenth of `lr`,
# using the optimizer named by `optim`.
ft_optimizer_params = {'learning_rate': lr * 0.1, 'wd': wd}
if model_trunk == 'vgg19':
    # NOTE(review): stage 1 trained model0's '.*CPM.*' parameters; they are only
    # updated here if 'block.*' also matches them — confirm.
    trunk_params = model.model0.collect_params('.*vgg19_.*')
    pose_params = model.collect_params('block.*')
elif model_trunk == 'mobilenet':
    trunk_params = model.model0.collect_params('.*mobilenet.*')
    pose_params = model.collect_params('block.*')
elif 'resnet' in model_trunk:
    trunk_params = model.collect_params('.*resnet.*')
    pose_params = model.collect_params('.*final.*')
else:
    # Otherwise the trainers from stage 1 would be silently reused.
    raise ValueError('Unsupported model_trunk: {!r}'.format(model_trunk))
trainer_trunk = gluon.Trainer(trunk_params, optim, dict(ft_optimizer_params))
trainer_pose = gluon.Trainer(pose_params, optim, dict(ft_optimizer_params))

# Append the suffix only once so re-running this cell doesn't yield '_ft_ft'.
if not log_key.endswith('_ft'):
    log_key += '_ft'

for epoch in range(epochs_pre, epochs_pre + epochs_ft):
    # train for one epoch
    train_loss = run_epoch(train_data, model, epoch, is_train=True,
                           trainer_trunk=trainer_trunk, trainer_pose=trainer_pose, dtype=dtype)
    model.save_parameters(os.path.join(
        model_path, '{}_{}_pose_ft_{}_{}.params'.format(log_key, model_trunk, epoch, dtype)))
    # Validate in the training dtype; with the default, run_epoch would cast the
    # model back to float32 here (and training would re-cast it) every epoch.
    val_loss = run_epoch(valid_data, model, epoch, is_train=False, dtype=dtype)

    writer.add_scalar('epoch_ft/train_loss', {log_key: train_loss}, epoch)
    writer.add_scalar('epoch_ft/val_loss', {log_key: val_loss}, epoch)

In [None]:
# Close the mxboard writer once both training stages have finished logging.
writer.close()    

### Exporting the model

In [22]:
# Device for the export cells below.
# NOTE(review): this rebinds the global `ctx` that run_epoch reads, while the model
# was built on the configuration-cell device; re-running a training cell after
# this one would likely fail with a context mismatch.
ctx = mx.gpu(1)

Exporting the symbol for the heatmap branch

In [21]:
# Trace the network symbolically and save only output [0][1] (the heatmap branch) as a graph file.
# NOTE(review): run_epoch treats a tuple in out[0] as the vgg19/mobilenet layout; with the
# resnet trunk selected above, out[0] is not a tuple, so this index may select the wrong
# output or fail — confirm. The 'mobilenet' file name is hard-coded regardless of model_trunk.
model(mx.sym.var('data'))[0][1].save('model_checkpoints/export_mobilenet-heatmap-symbol.json')

Exporting the parameters for everything

In [15]:
# Export the full hybridized network (graph + parameters, epoch 0). The parameter
# file 'model_checkpoints/export_mobilenet-0000.params' is reused by the heatmap-only import.
model.export('model_checkpoints/export_mobilenet', 0)

Load the full parameter file into the heatmap-only graph with `ignore_extra=True` (parameters the heatmap branch does not use are skipped), then export again to obtain a slimmer, faster network.

In [25]:
# Wrap the heatmap-only graph in a SymbolBlock (no parameters yet).
model_heatmap = gluon.nn.SymbolBlock.imports(
    'model_checkpoints/export_mobilenet-heatmap-symbol.json',
    ['data'],
    ctx=ctx,
)
# Fill it from the full export; parameters the heatmap graph does not use are skipped.
model_heatmap.load_parameters(
    'model_checkpoints/export_mobilenet-0000.params',
    ctx=ctx,
    ignore_extra=True,
)
# Re-export so the saved parameter file only contains what the slim graph needs.
model_heatmap.hybridize()
model_heatmap.export('model_checkpoints/export_mobilenet_heatmap', 0)

Reload the slim network from disk and measure its inference runtime

In [67]:
# Fresh load of the slim network (graph + its own parameters).
# NOTE(review): rebinds the global `ctx` that run_epoch also reads.
ctx = mx.gpu(2)
model_heatmap = gluon.nn.SymbolBlock.imports(
    'model_checkpoints/export_mobilenet_heatmap-symbol.json',
    ['data'],
    'model_checkpoints/export_mobilenet_heatmap-0000.params',
    ctx=ctx,
)

In [68]:
# Warm-up: the first call after hybridize builds the cached graph, so it is kept
# out of the timing cell that follows.
model_heatmap.hybridize(static_alloc=True, static_shape=True)
out = model_heatmap(mx.nd.ones((1, 3, 368, 368), ctx=ctx))
out.wait_to_read()  # block until the asynchronous forward pass has finished

In [69]:
%%time
# Time one forward pass of the slim heatmap network. The input allocation is inside the
# timed region; wait_to_read() blocks until the asynchronous computation completes.
# NOTE(review): the data loaders use 384, this benchmark uses 368 — confirm intended.
model_heatmap(mx.nd.ones((1,3,368,368), ctx=ctx)).wait_to_read()

CPU times: user 16 ms, sys: 16 ms, total: 32 ms
Wall time: 30.3 ms


Float16 inference

In [70]:
# Cast parameters to float16, except those whose names contain gamma/beta/running_mean/
# running_var (presumably BatchNorm parameters and statistics), which keep their dtype.
for key, item in model_heatmap.collect_params().items():
    if not ('gamma' in key or 'beta' in key or 'running_mean' in key or 'running_var' in key):
        item.cast('float16')