This repository has been archived by the owner on Jun 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 604
/
JointsDataset.py
executable file
·222 lines (176 loc) · 7.66 KB
/
JointsDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import logging
import random
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from utils.transforms import get_affine_transform
from utils.transforms import affine_transform
from utils.transforms import fliplr_joints
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class JointsDataset(Dataset):
    """Base dataset producing (image, heatmap target, target weight, meta).

    The constructor only stores configuration and leaves ``num_joints``,
    ``flip_pairs``, ``parent_ids`` and ``db`` empty; ``_get_db`` and
    ``evaluate`` raise ``NotImplementedError``, so a subclass is expected to
    fill these in (NOTE(review): confirm against the concrete datasets).

    Each record in ``self.db`` is read with the keys ``'image'``,
    ``'joints_3d'``, ``'joints_3d_vis'``, ``'center'`` and ``'scale'``, plus
    the optional keys ``'filename'``, ``'imgnum'`` and ``'score'``.
    """

    def __init__(self, cfg, root, image_set, is_train, transform=None):
        """Copy the relevant settings out of ``cfg``.

        :param cfg: config object; reads ``OUTPUT_DIR``,
            ``DATASET.{DATA_FORMAT, SCALE_FACTOR, ROT_FACTOR, FLIP}``,
            ``MODEL.IMAGE_SIZE`` and
            ``MODEL.EXTRA.{TARGET_TYPE, HEATMAP_SIZE, SIGMA}``
        :param root: dataset root, stored as-is
        :param image_set: image-set identifier, stored as-is
        :param is_train: enables random scale/rotation/flip in ``__getitem__``
        :param transform: optional callable applied to the warped image
        """
        self.num_joints = 0
        self.pixel_std = 200
        self.flip_pairs = []
        self.parent_ids = []

        self.is_train = is_train
        self.root = root
        self.image_set = image_set

        self.output_path = cfg.OUTPUT_DIR
        self.data_format = cfg.DATASET.DATA_FORMAT

        self.scale_factor = cfg.DATASET.SCALE_FACTOR
        self.rotation_factor = cfg.DATASET.ROT_FACTOR
        self.flip = cfg.DATASET.FLIP

        self.image_size = cfg.MODEL.IMAGE_SIZE
        self.target_type = cfg.MODEL.EXTRA.TARGET_TYPE
        self.heatmap_size = cfg.MODEL.EXTRA.HEATMAP_SIZE
        self.sigma = cfg.MODEL.EXTRA.SIGMA

        self.transform = transform
        self.db = []

    def _get_db(self):
        """Build the record list; not implemented in this base class."""
        raise NotImplementedError

    def evaluate(self, cfg, preds, output_dir, *args, **kwargs):
        """Score predictions; not implemented in this base class."""
        raise NotImplementedError

    def __len__(self):
        """Return the number of records in ``self.db``."""
        return len(self.db)

    def __getitem__(self, idx):
        """Load, augment and encode the sample at position ``idx``.

        :param idx: index into ``self.db``
        :return: ``(image, target, target_weight, meta)`` where ``image`` is
            the warped (and optionally transformed) picture, ``target`` and
            ``target_weight`` are torch tensors built from
            ``generate_target`` and ``meta`` is a dict describing the sample
        :raises ValueError: if the image cannot be read
        """
        # Work on a private copy: the flip below rewrites the center in place.
        db_rec = copy.deepcopy(self.db[idx])

        image_file = db_rec['image']
        filename = db_rec.get('filename', '')
        imgnum = db_rec.get('imgnum', '')

        read_flags = cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
        if self.data_format == 'zip':
            from utils import zipreader
            data_numpy = zipreader.imread(image_file, read_flags)
        else:
            data_numpy = cv2.imread(image_file, read_flags)

        if data_numpy is None:
            logger.error('=> fail to read {}'.format(image_file))
            raise ValueError('Fail to read {}'.format(image_file))

        joints = db_rec['joints_3d']
        joints_vis = db_rec['joints_3d_vis']

        c = db_rec['center']
        s = db_rec['scale']
        score = db_rec.get('score', 1)
        r = 0

        if self.is_train:
            sf = self.scale_factor
            rf = self.rotation_factor
            # Scale jitter, clipped to [1 - sf, 1 + sf].
            s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
            # Rotate only 60% of the time; the coin is tossed before the
            # angle is drawn, so the random streams are consumed in the same
            # order as a conditional expression would consume them.
            if random.random() <= 0.6:
                r = np.clip(np.random.randn() * rf, -rf * 2, rf * 2)
            else:
                r = 0

            if self.flip and random.random() <= 0.5:
                data_numpy = data_numpy[:, ::-1, :]
                joints, joints_vis = fliplr_joints(
                    joints, joints_vis, data_numpy.shape[1], self.flip_pairs)
                # Mirror the box center along x to match the flipped image.
                c[0] = data_numpy.shape[1] - c[0] - 1

        trans = get_affine_transform(c, s, r, self.image_size)
        # Named "image" rather than "input" to avoid shadowing the builtin.
        image = cv2.warpAffine(
            data_numpy,
            trans,
            (int(self.image_size[0]), int(self.image_size[1])),
            flags=cv2.INTER_LINEAR)

        if self.transform:
            image = self.transform(image)

        # Move visible joints into the warped image's coordinate frame.
        for i in range(self.num_joints):
            if joints_vis[i, 0] > 0.0:
                joints[i, 0:2] = affine_transform(joints[i, 0:2], trans)

        target, target_weight = self.generate_target(joints, joints_vis)

        target = torch.from_numpy(target)
        target_weight = torch.from_numpy(target_weight)

        meta = {
            'image': image_file,
            'filename': filename,
            'imgnum': imgnum,
            'joints': joints,
            'joints_vis': joints_vis,
            'center': c,
            'scale': s,
            'rotation': r,
            'score': score
        }

        return image, target, target_weight, meta

    def select_data(self, db):
        """Keep records whose visible joints are centered in their box.

        For each record the mean position of the visible joints is compared
        with the box center through a Gaussian similarity; the record is kept
        when that similarity exceeds a threshold that grows with the number
        of visible joints. Records with no visible joint are dropped.

        :param db: iterable of records with ``'joints_3d'``,
            ``'joints_3d_vis'``, ``'scale'`` and ``'center'``
        :return: list of the selected records
        """
        db_selected = []
        for rec in db:
            num_vis = 0
            joints_x = 0.0
            joints_y = 0.0
            for joint, joint_vis in zip(
                    rec['joints_3d'], rec['joints_3d_vis']):
                if joint_vis[0] <= 0:
                    continue
                num_vis += 1
                joints_x += joint[0]
                joints_y += joint[1]
            if num_vis == 0:
                continue

            joints_x, joints_y = joints_x / num_vis, joints_y / num_vis

            area = rec['scale'][0] * rec['scale'][1] * (self.pixel_std ** 2)
            joints_center = np.array([joints_x, joints_y])
            bbox_center = np.array(rec['center'])
            diff_norm2 = np.linalg.norm((joints_center - bbox_center), 2)
            ks = np.exp(-1.0 * (diff_norm2 ** 2) / ((0.2) ** 2 * 2.0 * area))

            metric = (0.2 / 16) * num_vis + 0.45 - 0.2 / 16
            if ks > metric:
                db_selected.append(rec)

        logger.info('=> num db: {}'.format(len(db)))
        logger.info('=> num selected db: {}'.format(len(db_selected)))
        return db_selected

    def generate_target(self, joints, joints_vis):
        '''
        Render one Gaussian heatmap per joint.

        :param joints:  [num_joints, 3]
        :param joints_vis: [num_joints, 3]
        :return: target, target_weight(1: visible, 0: invisible)
        :raises AssertionError: if ``self.target_type`` is not 'gaussian'
        '''
        # Raised explicitly (instead of using ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if self.target_type != 'gaussian':
            raise AssertionError('Only support gaussian map now!')

        target_weight = np.ones((self.num_joints, 1), dtype=np.float32)
        target_weight[:, 0] = joints_vis[:, 0]

        heat_w, heat_h = self.heatmap_size[0], self.heatmap_size[1]
        target = np.zeros((self.num_joints, heat_h, heat_w), dtype=np.float32)

        # Input pixels per heatmap cell. asarray lets image_size and
        # heatmap_size be plain lists/tuples as well as ndarrays.
        feat_stride = (np.asarray(self.image_size, dtype=np.float64)
                       / np.asarray(self.heatmap_size, dtype=np.float64))

        # The kernel depends only on sigma, so build it once for all joints.
        # The gaussian is not normalized, we want the center value to equal 1
        tmp_size = self.sigma * 3
        size = 2 * tmp_size + 1
        x = np.arange(0, size, 1, np.float32)
        y = x[:, np.newaxis]
        x0 = y0 = size // 2
        g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * self.sigma ** 2))

        for joint_id in range(self.num_joints):
            mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
            mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)

            # Patch corners on the heatmap: ul inclusive, br exclusive.
            ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
            br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]

            # Check that any part of the gaussian is in-bounds. br is
            # exclusive, so br == 0 already means the patch is fully outside.
            if ul[0] >= heat_w or ul[1] >= heat_h \
                    or br[0] <= 0 or br[1] <= 0:
                target_weight[joint_id] = 0
                continue

            # Usable gaussian range
            g_x = max(0, -ul[0]), min(br[0], heat_w) - ul[0]
            g_y = max(0, -ul[1]), min(br[1], heat_h) - ul[1]
            # Image range
            img_x = max(0, ul[0]), min(br[0], heat_w)
            img_y = max(0, ul[1]), min(br[1], heat_h)

            # Invisible joints keep an all-zero heatmap.
            if target_weight[joint_id, 0] > 0.5:
                target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
                    g[g_y[0]:g_y[1], g_x[0]:g_x[1]]

        return target, target_weight