Add support for caffe-style ResNet and more flexible pretrained model loading (#51)

* add support for resnetv1

* fix ordereddict manipulation

* add caffe-style resnet results in README

* update cfg files for caffe-style resnet
bowenc0221 authored and leoxiaobin committed Nov 12, 2018
1 parent 7efb8ec commit 2d723e3
Showing 6 changed files with 309 additions and 2 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -31,9 +31,18 @@ This is an official pytorch implementation of [*Simple Baselines for Human Pose
| 256x192_pose_resnet_152_d256d256d256 | 0.720 | 0.893 | 0.798 | 0.687 | 0.789 | 0.778 | 0.934 | 0.846 | 0.736 | 0.839 |
| 384x288_pose_resnet_152_d256d256d256 | 0.743 | 0.896 | 0.811 | 0.705 | 0.816 | 0.797 | 0.937 | 0.858 | 0.751 | 0.863 |


#### Results on Caffe-style ResNet
| Arch | AP | AP .5 | AP .75 | AP (M) | AP (L) | AR | AR .5 | AR .75 | AR (M) | AR (L) |
|---|---|---|---|---|---|---|---|---|---|---|
| 256x192_pose_resnet_50_caffe_d256d256d256 | 0.704 | 0.914 | 0.782 | 0.677 | 0.744 | 0.735 | 0.921 | 0.805 | 0.704 | 0.783 |
| 256x192_pose_resnet_101_caffe_d256d256d256 | 0.720 | 0.915 | 0.803 | 0.693 | 0.764 | 0.753 | 0.928 | 0.821 | 0.720 | 0.802 |


### Note:
- Flip test is used.
- Person detector has person AP of 56.4 on COCO val2017 dataset.
- The difference between PyTorch-style and Caffe-style ResNet is the position of the stride-2 convolution in the bottleneck block: PyTorch-style applies it to the 3x3 convolution, while Caffe-style applies it to the first 1x1 convolution (see the sketch below).
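
For illustration only (these helper functions are not part of the repository or its README; they simply show where the stride lands in each bottleneck variant):

```python
import torch.nn as nn

def pytorch_style_convs(inplanes, planes, stride):
    # PyTorch-style (torchvision) bottleneck: the stride-2 convolution is the 3x3 conv
    conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
    conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                      padding=1, bias=False)
    return conv1, conv2

def caffe_style_convs(inplanes, planes, stride):
    # Caffe-style bottleneck (Bottleneck_CAFFE in this commit): the stride moves to the first 1x1 conv
    conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
    conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                      padding=1, bias=False)
    return conv1, conv2
```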

## Environment
The code is developed with Python 3.6 on Ubuntu 16.04 and requires NVIDIA GPUs. It has been developed and tested with 4 NVIDIA P100 GPU cards; other platforms or GPU cards are not fully tested.
77 changes: 77 additions & 0 deletions experiments/coco/resnet101/256x192_d256x3_adam_lr1e-3_caffe.yaml
@@ -0,0 +1,77 @@
GPUS: '0'
DATA_DIR: ''
OUTPUT_DIR: 'output'
LOG_DIR: 'log'
WORKERS: 4
PRINT_FREQ: 100

DATASET:
  DATASET: 'coco'
  ROOT: 'data/coco/'
  TEST_SET: 'val2017'
  TRAIN_SET: 'train2017'
  FLIP: true
  ROT_FACTOR: 40
  SCALE_FACTOR: 0.3
MODEL:
  NAME: 'pose_resnet'
  PRETRAINED: 'models/pytorch/imagenet/resnet101-caffe.pth.tar'
  STYLE: 'caffe'
  IMAGE_SIZE:
  - 192
  - 256
  NUM_JOINTS: 17
  EXTRA:
    TARGET_TYPE: 'gaussian'
    HEATMAP_SIZE:
    - 48
    - 64
    SIGMA: 2
    FINAL_CONV_KERNEL: 1
    DECONV_WITH_BIAS: false
    NUM_DECONV_LAYERS: 3
    NUM_DECONV_FILTERS:
    - 256
    - 256
    - 256
    NUM_DECONV_KERNELS:
    - 4
    - 4
    - 4
    NUM_LAYERS: 101
LOSS:
  USE_TARGET_WEIGHT: true
TRAIN:
  BATCH_SIZE: 32
  SHUFFLE: true
  BEGIN_EPOCH: 0
  END_EPOCH: 140
  RESUME: false
  OPTIMIZER: 'adam'
  LR: 0.001
  LR_FACTOR: 0.1
  LR_STEP:
  - 90
  - 120
  WD: 0.0001
  GAMMA1: 0.99
  GAMMA2: 0.0
  MOMENTUM: 0.9
  NESTEROV: false
TEST:
  BATCH_SIZE: 32
  COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
  BBOX_THRE: 1.0
  FLIP_TEST: false
  IMAGE_THRE: 0.0
  IN_VIS_THRE: 0.2
  MODEL_FILE: ''
  NMS_THRE: 1.0
  OKS_THRE: 0.9
  USE_GT_BBOX: true
DEBUG:
  DEBUG: true
  SAVE_BATCH_IMAGES_GT: true
  SAVE_BATCH_IMAGES_PRED: true
  SAVE_HEATMAPS_GT: true
  SAVE_HEATMAPS_PRED: true
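
For reference (not part of the repository), a minimal sketch of how such an experiment file can be loaded and inspected, assuming PyYAML and EasyDict (the latter is what lib/core/config.py already uses); load_experiment_cfg is a hypothetical helper, not the repository's config-loading API:

import yaml
from easydict import EasyDict as edict

# Hypothetical helper: read an experiment yaml into an EasyDict so nested keys
# can be accessed with attribute syntax, mirroring how config.py exposes them.
def load_experiment_cfg(path):
    with open(path) as f:
        return edict(yaml.safe_load(f))

cfg = load_experiment_cfg(
    'experiments/coco/resnet101/256x192_d256x3_adam_lr1e-3_caffe.yaml')
print(cfg.MODEL.STYLE)             # 'caffe'  -> selects the caffe-style bottleneck
print(cfg.MODEL.PRETRAINED)        # caffe-converted ImageNet ResNet-101 weights
print(cfg.MODEL.EXTRA.NUM_LAYERS)  # 101
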
77 changes: 77 additions & 0 deletions experiments/coco/resnet152/256x192_d256x3_adam_lr1e-3_caffe.yaml
@@ -0,0 +1,77 @@
GPUS: '0'
DATA_DIR: ''
OUTPUT_DIR: 'output'
LOG_DIR: 'log'
WORKERS: 4
PRINT_FREQ: 100

DATASET:
  DATASET: 'coco'
  ROOT: 'data/coco/'
  TEST_SET: 'val2017'
  TRAIN_SET: 'train2017'
  FLIP: true
  ROT_FACTOR: 40
  SCALE_FACTOR: 0.3
MODEL:
  NAME: 'pose_resnet'
  PRETRAINED: 'models/pytorch/imagenet/resnet152-caffe.pth.tar'
  STYLE: 'caffe'
  IMAGE_SIZE:
  - 192
  - 256
  NUM_JOINTS: 17
  EXTRA:
    TARGET_TYPE: 'gaussian'
    HEATMAP_SIZE:
    - 48
    - 64
    SIGMA: 2
    FINAL_CONV_KERNEL: 1
    DECONV_WITH_BIAS: false
    NUM_DECONV_LAYERS: 3
    NUM_DECONV_FILTERS:
    - 256
    - 256
    - 256
    NUM_DECONV_KERNELS:
    - 4
    - 4
    - 4
    NUM_LAYERS: 152
LOSS:
  USE_TARGET_WEIGHT: true
TRAIN:
  BATCH_SIZE: 32
  SHUFFLE: true
  BEGIN_EPOCH: 0
  END_EPOCH: 140
  RESUME: false
  OPTIMIZER: 'adam'
  LR: 0.001
  LR_FACTOR: 0.1
  LR_STEP:
  - 90
  - 120
  WD: 0.0001
  GAMMA1: 0.99
  GAMMA2: 0.0
  MOMENTUM: 0.9
  NESTEROV: false
TEST:
  BATCH_SIZE: 32
  COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
  BBOX_THRE: 1.0
  FLIP_TEST: false
  IMAGE_THRE: 0.0
  IN_VIS_THRE: 0.2
  MODEL_FILE: ''
  NMS_THRE: 1.0
  OKS_THRE: 0.9
  USE_GT_BBOX: true
DEBUG:
  DEBUG: true
  SAVE_BATCH_IMAGES_GT: true
  SAVE_BATCH_IMAGES_PRED: true
  SAVE_HEATMAPS_GT: true
  SAVE_HEATMAPS_PRED: true
77 changes: 77 additions & 0 deletions experiments/coco/resnet50/256x192_d256x3_adam_lr1e-3_caffe.yaml
@@ -0,0 +1,77 @@
GPUS: '0'
DATA_DIR: ''
OUTPUT_DIR: 'output'
LOG_DIR: 'log'
WORKERS: 4
PRINT_FREQ: 100

DATASET:
  DATASET: 'coco'
  ROOT: 'data/coco/'
  TEST_SET: 'val2017'
  TRAIN_SET: 'train2017'
  FLIP: true
  ROT_FACTOR: 40
  SCALE_FACTOR: 0.3
MODEL:
  NAME: 'pose_resnet'
  PRETRAINED: 'models/pytorch/imagenet/resnet50-caffe.pth.tar'
  STYLE: 'caffe'
  IMAGE_SIZE:
  - 192
  - 256
  NUM_JOINTS: 17
  EXTRA:
    TARGET_TYPE: 'gaussian'
    HEATMAP_SIZE:
    - 48
    - 64
    SIGMA: 2
    FINAL_CONV_KERNEL: 1
    DECONV_WITH_BIAS: false
    NUM_DECONV_LAYERS: 3
    NUM_DECONV_FILTERS:
    - 256
    - 256
    - 256
    NUM_DECONV_KERNELS:
    - 4
    - 4
    - 4
    NUM_LAYERS: 50
LOSS:
  USE_TARGET_WEIGHT: true
TRAIN:
  BATCH_SIZE: 32
  SHUFFLE: true
  BEGIN_EPOCH: 0
  END_EPOCH: 140
  RESUME: false
  OPTIMIZER: 'adam'
  LR: 0.001
  LR_FACTOR: 0.1
  LR_STEP:
  - 90
  - 120
  WD: 0.0001
  GAMMA1: 0.99
  GAMMA2: 0.0
  MOMENTUM: 0.9
  NESTEROV: false
TEST:
  BATCH_SIZE: 32
  COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
  BBOX_THRE: 1.0
  FLIP_TEST: false
  IMAGE_THRE: 0.0
  IN_VIS_THRE: 0.2
  MODEL_FILE: ''
  NMS_THRE: 1.0
  OKS_THRE: 0.9
  USE_GT_BBOX: true
DEBUG:
  DEBUG: true
  SAVE_BATCH_IMAGES_GT: true
  SAVE_BATCH_IMAGES_PRED: true
  SAVE_HEATMAPS_GT: true
  SAVE_HEATMAPS_PRED: true
2 changes: 2 additions & 0 deletions lib/core/config.py
@@ -55,6 +55,8 @@
config.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256
config.MODEL.EXTRA = MODEL_EXTRAS[config.MODEL.NAME]

config.MODEL.STYLE = 'pytorch'

config.LOSS = edict()
config.LOSS.USE_TARGET_WEIGHT = True

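The new key only sets a default; the experiment files above override it with STYLE: 'caffe' when their values are merged into this config. A toy illustration of that override (the _merge helper is hypothetical, standing in for the repository's own config-update logic):

from easydict import EasyDict as edict

config = edict()
config.MODEL = edict()
config.MODEL.STYLE = 'pytorch'   # default added by this commit

def _merge(cfg, overrides):      # hypothetical stand-in for the config-update logic
    for k, v in overrides.items():
        if isinstance(v, dict):
            _merge(cfg[k], v)
        else:
            cfg[k] = v

_merge(config, {'MODEL': {'STYLE': 'caffe'}})  # what a *_caffe.yaml effectively does
print(config.MODEL.STYLE)        # 'caffe'
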
69 changes: 67 additions & 2 deletions lib/models/pose_resnet.py
@@ -13,6 +13,7 @@

import torch
import torch.nn as nn
from collections import OrderedDict


BN_MOMENTUM = 0.1
@@ -98,6 +99,48 @@ def forward(self, x):
        return out


class Bottleneck_CAFFE(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck_CAFFE, self).__init__()
        # add stride to conv1x1 (caffe-style puts the stride-2 downsampling here instead of on the 3x3 conv)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
                                  momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class PoseResNet(nn.Module):

    def __init__(self, block, layers, cfg, **kwargs):
@@ -228,9 +271,27 @@ def init_weights(self, pretrained=''):
                    nn.init.normal_(m.weight, std=0.001)
                    nn.init.constant_(m.bias, 0)

            pretrained_state_dict = torch.load(pretrained)
            # pretrained_state_dict = torch.load(pretrained)
            logger.info('=> loading pretrained model {}'.format(pretrained))
            self.load_state_dict(pretrained_state_dict, strict=False)
            # self.load_state_dict(pretrained_state_dict, strict=False)
            checkpoint = torch.load(pretrained)
            if isinstance(checkpoint, OrderedDict):
                state_dict = checkpoint
            elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                state_dict_old = checkpoint['state_dict']
                state_dict = OrderedDict()
                # strip the 'module.' prefix that nn.DataParallel adds when the model is saved
                for key in state_dict_old.keys():
                    if key.startswith('module.'):
                        # state_dict[key[7:]] = state_dict[key]
                        # state_dict.pop(key)
                        state_dict[key[7:]] = state_dict_old[key]
                    else:
                        state_dict[key] = state_dict_old[key]
            else:
                raise RuntimeError(
                    'No state_dict found in checkpoint file {}'.format(pretrained))
            self.load_state_dict(state_dict, strict=False)
        else:
            logger.error('=> imagenet pretrained model does not exist')
            logger.error('=> please download it first')
@@ -246,9 +307,13 @@ def init_weights(self, pretrained=''):

def get_pose_net(cfg, is_train, **kwargs):
    num_layers = cfg.MODEL.EXTRA.NUM_LAYERS
    style = cfg.MODEL.STYLE

    block_class, layers = resnet_spec[num_layers]

    if style == 'caffe':
        block_class = Bottleneck_CAFFE

    model = PoseResNet(block_class, layers, cfg, **kwargs)

    if is_train and cfg.MODEL.INIT_WEIGHTS:
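
Outside of PoseResNet, the new checkpoint handling in init_weights can be restated as a standalone helper; this is a condensed sketch of the same idea (the function name is illustrative, not part of the repository):

from collections import OrderedDict

import torch

def load_flexible_state_dict(model, path):
    # Accept either a bare state_dict (an OrderedDict) or a training checkpoint
    # that wraps it under 'state_dict', and strip the 'module.' prefix that
    # nn.DataParallel adds when a model is saved from the wrapped module.
    checkpoint = torch.load(path)
    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = OrderedDict(
            (k[len('module.'):] if k.startswith('module.') else k, v)
            for k, v in checkpoint['state_dict'].items()
        )
    else:
        raise RuntimeError('No state_dict found in checkpoint file {}'.format(path))
    model.load_state_dict(state_dict, strict=False)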
