Add hand keypoints and train with Freihand dataset #534

Open
wants to merge 3 commits into master
283 changes: 283 additions & 0 deletions alphapose/datasets/Freihand.py
@@ -0,0 +1,283 @@
# -----------------------------------------------------
# Copyright (c) Shanghai Jiao Tong University. All rights reserved.
# -----------------------------------------------------

import copy
import os
import pickle as pk

import numpy as np

from abc import abstractmethod, abstractproperty

import scipy.misc
import torch.utils.data as data
from pycocotools.coco import COCO

from alphapose.models.builder import DATASET
from alphapose.utils.bbox import bbox_clip_xyxy, bbox_xywh_to_xyxy

from alphapose.utils.presets import SimpleTransform


class CustomDataset(data.Dataset):
"""Custom dataset.
Annotation file must be in `coco` format.

Parameters
----------
train: bool, default is True
If true, will set as training mode.
dpg: bool, default is False
If true, will activate `dpg` for data augmentation.
skip_empty: bool, default is True
Whether to skip the entire image if no valid label is found.
cfg: dict, dataset configuration.
"""

CLASSES = ['hand']
EVAL_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
num_joints = 21

def __init__(self,
train=True,
dpg=False,
skip_empty=True,
lazy_import=False,
**cfg):

self._cfg = cfg
self._preset_cfg = cfg['PRESET']
self._root = cfg['ROOT']
self._img_prefix = cfg['IMG_PREFIX']
self._ann_file = os.path.join(self._root, cfg['ANN'])

self._lazy_import = lazy_import
self._skip_empty = skip_empty
self._train = train
self._dpg = dpg

if 'AUG' in cfg.keys():
self._scale_factor = cfg['AUG']['SCALE_FACTOR']
self._rot = cfg['AUG']['ROT_FACTOR']
self.num_joints_half_body = cfg['AUG']['NUM_JOINTS_HALF_BODY']
self.prob_half_body = cfg['AUG']['PROB_HALF_BODY']
else:
self._scale_factor = 0
self._rot = 0
self.num_joints_half_body = -1
self.prob_half_body = -1

self._input_size = self._preset_cfg['IMAGE_SIZE']
self._output_size = self._preset_cfg['HEATMAP_SIZE']

self._sigma = self._preset_cfg['SIGMA']

self._check_centers = False

self.num_class = len(self.CLASSES)

self.upper_body_ids = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
self.lower_body_ids = (11, 12, 13, 14, 15, 16)

if self._preset_cfg['TYPE'] == 'simple':
self.transformation = SimpleTransform(
self, scale_factor=self._scale_factor,
input_size=self._input_size,
output_size=self._output_size,
rot=self._rot, sigma=self._sigma,
train=self._train, add_dpg=self._dpg)
else:
raise NotImplementedError

self._items, self._labels = self._lazy_load_json()

def __getitem__(self, idx):
# get image id
img_path = self._items[idx]
img_id = int(os.path.splitext(os.path.basename(img_path))[0])

# load ground truth, including bbox, keypoints, image size
label = copy.deepcopy(self._labels[idx])
img = scipy.misc.imread(img_path, mode='RGB')

# transform ground truth into training label and apply data augmentation
img, label, label_mask, bbox = self.transformation(img, label)
return img, label, label_mask, img_id, bbox

def __len__(self):
return len(self._items)

def _lazy_load_ann_file(self):
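# Parse the COCO-format annotation file once and cache the resulting COCO
# object to '<ann_file>.pkl'; later runs with lazy_import=True reuse the
# pickle and skip the slow JSON parse.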
if os.path.exists(self._ann_file + '.pkl') and self._lazy_import:
print('Lazy load json...')
with open(self._ann_file + '.pkl', 'rb') as fid:
return pk.load(fid)
else:
_database = COCO(self._ann_file)
if os.access(self._ann_file + '.pkl', os.W_OK):
with open(self._ann_file + '.pkl', 'wb') as fid:
pk.dump(_database, fid, pk.HIGHEST_PROTOCOL)
return _database

def _lazy_load_json(self):
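# Same caching idea for the flattened (image path, label) lists: build them
# once via _load_jsons(), then reuse '<ann_file>_annot_keypoint.pkl'.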
if os.path.exists(self._ann_file + '_annot_keypoint.pkl') and self._lazy_import:
print('Lazy load annot...')
with open(self._ann_file + '_annot_keypoint.pkl', 'rb') as fid:
items, labels = pk.load(fid)
else:
items, labels = self._load_jsons()
if os.access(self._ann_file + '_annot_keypoint.pkl', os.W_OK):
with open(self._ann_file + '_annot_keypoint.pkl', 'wb') as fid:
pk.dump((items, labels), fid, pk.HIGHEST_PROTOCOL)

return items, labels

@abstractmethod
def _load_jsons(self):
pass

@abstractproperty
def CLASSES(self):
return None

@abstractproperty
def num_joints(self):
return None

@abstractproperty
def joint_pairs(self):
"""Joint pairs which defines the pairs of joint to be swapped
when the image is flipped horizontally."""
return None


@DATASET.register_module
class Freihand(CustomDataset):
""" Freihand dataset.

Parameters
----------
train: bool, default is True
If true, will set as training mode.
skip_empty: bool, default is True
Whether to skip the entire image if no valid label is found. Use `False` if this dataset is
for validation to avoid COCO metric errors.
dpg: bool, default is False
If true, will activate `dpg` for data augmentation.
"""
CLASSES = ['hand']
EVAL_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
num_joints = 21
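# The 21 joints follow the usual FreiHAND/MANO hand ordering: 0 = wrist,
# then four joints per finger (thumb 1-4, index 5-8, middle 9-12,
# ring 13-16, pinky 17-20).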

@property
def joint_pairs(self):
"""Joint pairs which defines the pairs of joint to be swapped
when the image is flipped horizontally."""
return [[1, 2], [3, 4], [5, 6], [7, 8],
[9, 10], [11, 12], [13, 14], [15, 16], [17, 18], [19, 20]]

def _load_jsons(self):
"""Load all image paths and labels from JSON annotation files into buffer."""
items = []
labels = []

_freihand = self._lazy_load_ann_file()
classes = [c['name'] for c in _freihand.loadCats(_freihand.getCatIds())]
assert classes == self.CLASSES, "Incompatible category names with Freihand. "

self.json_id_to_contiguous = {
v: k for k, v in enumerate(_freihand.getCatIds())}

# iterate through the annotations
image_ids = sorted(_freihand.getImgIds())
for entry in _freihand.loadImgs(image_ids):
dirname, filename = entry['Freihand_url'].split('/')[-2:]
abs_path = os.path.join(self._root, dirname, filename)
if not os.path.exists(abs_path):
raise IOError('Image: {} does not exist.'.format(abs_path))
label = self._check_load_keypoints(_freihand, entry)
if not label:
continue

# the number of items is per annotated hand instance, not per image
for obj in label:
items.append(abs_path)
labels.append(obj)

return items, labels

def _check_load_keypoints(self, coco, entry):
"""Check and load ground-truth keypoints"""
ann_ids = coco.getAnnIds(imgIds=entry['id'], iscrowd=False)
objs = coco.loadAnns(ann_ids)
# check valid bboxes
valid_objs = []
width = entry['width']
height = entry['height']

for obj in objs:
contiguous_cid = self.json_id_to_contiguous[obj['category_id']]
if contiguous_cid >= self.num_class:
# not class of interest
continue
if max(obj['keypoints']) == 0:
continue
# convert from (x, y, w, h) to (xmin, ymin, xmax, ymax) and clip bound
xmin, ymin, xmax, ymax = bbox_clip_xyxy(bbox_xywh_to_xyxy(obj['bbox']), width, height)
# require non-zero box area
if obj['area'] <= 0 or xmax <= xmin or ymax <= ymin:
continue
if obj['num_keypoints'] == 0:
continue
# joints 3d: (num_joints, 3, 2); 3 is for x, y, z; 2 is for position, visibility
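# obj['keypoints'] is a flat COCO-style list [x0, y0, v0, x1, y1, v1, ...],
# so joint i is read from indices 3*i, 3*i+1 and 3*i+2 below.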
joints_3d = np.zeros((self.num_joints, 3, 2), dtype=np.float32)
for i in range(self.num_joints):
joints_3d[i, 0, 0] = obj['keypoints'][i * 3 + 0]
joints_3d[i, 1, 0] = obj['keypoints'][i * 3 + 1]
# joints_3d[i, 2, 0] = 0
visible = min(1, obj['keypoints'][i * 3 + 2])
joints_3d[i, :2, 1] = visible
# joints_3d[i, 2, 1] = 0

if np.sum(joints_3d[:, 0, 1]) < 1:
# no visible keypoint
continue

if self._check_centers and self._train:
bbox_center, bbox_area = self._get_box_center_area((xmin, ymin, xmax, ymax))
kp_center, num_vis = self._get_keypoints_center_count(joints_3d)
ks = np.exp(-2 * np.sum(np.square(bbox_center - kp_center)) / bbox_area)
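# ks is a Gaussian score of how far the visible-keypoint centroid drifts
# from the bbox centre (normalised by bbox area); boxes whose keypoints
# sit far from the centre fail the threshold below and are skipped.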
if (num_vis / 80.0 + 47 / 80.0) > ks:
continue

valid_objs.append({
'bbox': (xmin, ymin, xmax, ymax),
'width': width,
'height': height,
'joints_3d': joints_3d
})

if not valid_objs:
if not self._skip_empty:
# dummy invalid labels if no valid objects are found
valid_objs.append({
'bbox': np.array([-1, -1, 0, 0]),
'width': width,
'height': height,
'joints_3d': np.zeros((self.num_joints, 2, 2), dtype=np.float32)
})
return valid_objs

def _get_box_center_area(self, bbox):
"""Get bbox center"""
c = np.array([(bbox[0] + bbox[2]) / 2.0, (bbox[1] + bbox[3]) / 2.0])
area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
return c, area

def _get_keypoints_center_count(self, keypoints):
"""Get geometric center of all keypoints"""
keypoint_x = np.sum(keypoints[:, 0, 0] * (keypoints[:, 0, 1] > 0))
keypoint_y = np.sum(keypoints[:, 1, 0] * (keypoints[:, 1, 1] > 0))
num = float(np.sum(keypoints[:, 0, 1]))
return np.array([keypoint_x / num, keypoint_y / num]), num
107 changes: 107 additions & 0 deletions alphapose/datasets/Freihand_det.py
@@ -0,0 +1,107 @@
# -----------------------------------------------------
# Copyright (c) Shanghai Jiao Tong University. All rights reserved.
# Written by Jiefeng Li (jeff.lee.sjtu@gmail.com)
# -----------------------------------------------------

"""Freihand Hand Detection Box dataset."""
import json
import os

import scipy.misc
import torch
import torch.utils.data as data
from tqdm import tqdm

from alphapose.utils.presets import SimpleTransform
from detector.apis import get_detector
from alphapose.models.builder import DATASET


@DATASET.register_module
class Freihand_det(data.Dataset):
""" COCO human detection box dataset.

"""
EVAL_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
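# Each entry of the detection json is expected to look roughly like
# {'image_id': 123, 'bbox': [x, y, w, h], 'score': 0.97} (values here are
# only illustrative); __getitem__ reads 'image_id', 'bbox' and 'score'.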

def __init__(self,
det_file=None,
opt=None,
**cfg):

self._cfg = cfg
self._opt = opt
self._preset_cfg = cfg['PRESET']
self._root = cfg['ROOT']
self._img_prefix = cfg['IMG_PREFIX']
if not det_file:
det_file = cfg['DET_FILE']
self._ann_file = os.path.join(self._root, cfg['ANN'])

if os.path.exists(det_file):
print("Detection results exist, will use it")
else:
print("Will create detection results to {}".format(det_file))
self.write_coco_json(det_file)

assert os.path.exists(det_file), "Error: no detection results found"
with open(det_file, 'r') as fid:
self._det_json = json.load(fid)

self._input_size = self._preset_cfg['IMAGE_SIZE']
self._output_size = self._preset_cfg['HEATMAP_SIZE']

self._sigma = self._preset_cfg['SIGMA']

if self._preset_cfg['TYPE'] == 'simple':
self.transformation = SimpleTransform(
self, scale_factor=0,
input_size=self._input_size,
output_size=self._output_size,
rot=0, sigma=self._sigma,
train=False, add_dpg=False)

def __getitem__(self, index):
det_res = self._det_json[index]
if not isinstance(det_res['image_id'], int):
img_id, _ = os.path.splitext(os.path.basename(det_res['image_id']))
img_id = int(img_id)
else:
img_id = det_res['image_id']
img_path = './data/Freihand/val/%08d.jpg' % img_id

# Load image
image = scipy.misc.imread(img_path, mode='RGB')

imght, imgwidth = image.shape[0], image.shape[1]  # image is (H, W, 3)
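# det_res['bbox'] comes in COCO (x, y, w, h) form; convert it to
# (x1, y1, x2, y2) corners before handing it to test_transform.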
x1, y1, w, h = det_res['bbox']
bbox = [x1, y1, x1 + w, y1 + h]
inp, bbox = self.transformation.test_transform(image, bbox)
return inp, torch.Tensor(bbox), torch.Tensor([det_res['bbox']]), torch.Tensor([det_res['image_id']]), torch.Tensor([det_res['score']]), torch.Tensor([imght]), torch.Tensor([imgwidth])

def __len__(self):
return len(self._det_json)

def write_coco_json(self, det_file):
from pycocotools.coco import COCO
import pathlib

_coco = COCO(self._ann_file)
image_ids = sorted(_coco.getImgIds())
det_model = get_detector(self._opt)
dets = []
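# Run the detector configured via opt on every image listed in the
# annotation file and collect its detections as a COCO-style result list.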
for entry in tqdm(_coco.loadImgs(image_ids)):
abs_path = os.path.join(
self._root, self._img_prefix, entry['file_name'])
det = det_model.detect_one_img(abs_path)
if det:
dets += det
pathlib.Path(os.path.split(det_file)[0]).mkdir(parents=True, exist_ok=True)
with open(det_file, 'w') as fid:
    json.dump(dets, fid)

@property
def joint_pairs(self):
"""Joint pairs which defines the pairs of joint to be swapped
when the image is flipped horizontally."""
return [[1, 2], [3, 4], [5, 6], [7, 8],
[9, 10], [11, 12], [13, 14], [15, 16], [17, 18], [19, 20]]
2 changes: 1 addition & 1 deletion alphapose/datasets/__init__.py
@@ -4,4 +4,4 @@
from .mscoco import Mscoco
from .mpii import Mpii

__all__ = ['CustomDataset', 'Mscoco', 'Mscoco_det', 'Mpii', 'ConcatDataset']
__all__ = ['CustomDataset', 'Mscoco', 'Mscoco_det', 'Mpii', 'ConcatDataset', 'Freihand']