
Commit

Cleaned up loaders
AlexEMG committed Nov 19, 2019
1 parent b225322 commit 4cc161e
Showing 4 changed files with 61 additions and 71 deletions.
1 change: 1 addition & 0 deletions deeplabcut/pose_cfg.yaml
@@ -33,6 +33,7 @@ mirror: False

#Data loaders, i.e. with additional data augmentation options (as of 2.0.9+):
dataset_type: default
batch_size: 1
#default will use no extra data loaders. Other options: 'tensorpack', 'deterministic'
#types of datasets, see factory: deeplabcut/pose_estimation_tensorflow/dataset/factory.py
#For deterministic, see https://github.com/AlexEMG/DeepLabCut/pull/324
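With batch_size promoted to a first-class key in pose_cfg.yaml, a loader can read it straight from the parsed config. A minimal sketch of the lookup, assuming the config is parsed with PyYAML into a plain dict (file name illustrative):

import yaml

# Read the new batch_size key; the default of 1 mirrors the value added
# to pose_cfg.yaml above, so configs written before this commit still work.
with open("pose_cfg.yaml") as f:
    cfg = yaml.safe_load(f)

batch_size = cfg.get("batch_size", 1)
dataset_type = cfg.get("dataset_type", "default")
print("dataset_type=%s, batch_size=%d" % (dataset_type, batch_size))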
@@ -35,8 +35,16 @@ class PoseDataset:
def __init__(self, cfg):
self.cfg = cfg
self.data = self.load_dataset()
self.batch_size = cfg.get('batch_size',1)
self.num_images = len(self.data)
self.batch_size = cfg.batch_size
self.max_input_sizesquare=cfg.get('max_input_size', 1500)**2
self.min_input_sizesquare=cfg.get('min_input_size', 64)**2
self.locref_scale = 1.0 / cfg.locref_stdev
self.stride = cfg.stride
self.half_stride = cfg.stride / 2
self.scale = cfg.global_scale
self.scale_jitter_lo=cfg.get('scale_jitter_lo',.75)
self.scale_jitter_up=cfg.get('scale_jitter_up',1.25)
print("Batch Size is %d" % self.batch_size)

def load_dataset(self):
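The constructor now caches every config-derived constant once, replacing repeated hasattr guards and cfg lookups in the hot path. A minimal sketch of the pattern, assuming a plain dict-like config with .get():

# Derive loader constants from the config once, at construction time.
# Squaring the size bounds up front lets the validity check below compare
# raw pixel areas directly, with no square roots per sample.
class LoaderConstants:
    def __init__(self, cfg):
        self.max_input_sizesquare = cfg.get('max_input_size', 1500) ** 2
        self.min_input_sizesquare = cfg.get('min_input_size', 64) ** 2
        self.locref_scale = 1.0 / cfg['locref_stdev']
        self.stride = cfg['stride']
        self.half_stride = cfg['stride'] / 2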
@@ -75,15 +83,10 @@ def load_dataset(self):
else:
print("Loading pickle data with float coordinates!")
file_name = cfg.dataset.split(".")[0] + ".pickle"
# Load Matlab file dataset annotation
#mlab = sio.loadmat(file_name)
#mlab = sio.loadmat(os.path.join(self.cfg.project_path,file_name))
with open(os.path.join(self.cfg.project_path,file_name), 'rb') as f:
# Pickle the 'data' dictionary using the highest protocol available.
pickledata=pickle.load(f)

self.raw_data = pickledata
#mlab = mlab['dataset']
num_images = len(pickledata) #mlab.shape[1]
data = []
has_gt = True
@@ -168,7 +171,7 @@ def get_batch(self):
scale = self.get_scale()
size = self.data[idx].im_size
target_size = np.ceil(size[1:3]*scale).astype(int)
if self.is_valid_size(target_size):
if self.is_valid_size(target_size[1] * target_size[0]):
break

stride = self.cfg.stride
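Note the signature change: is_valid_size now takes the product of the rescaled height and width instead of the size array. A sketch of the sampling loop around it, reusing the names from the diff (get_scale and is_valid_size as defined in this file):

import numpy as np

# Resample the augmentation scale until the rescaled image area falls
# inside [min_input_size**2, max_input_size**2]. im_size = (height, width).
def sample_valid_scale(im_size, get_scale, is_valid_size):
    while True:
        scale = get_scale()
        target_size = np.ceil(np.asarray(im_size) * scale).astype(int)
        if is_valid_size(target_size[0] * target_size[1]):  # area, not array
            return scale, target_size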
@@ -260,24 +263,18 @@ def get_scale(self):
scale *= scale_jitter
return scale

def is_valid_size(self, target_size):
im_width = target_size[1]
im_height = target_size[0]
if hasattr(self.cfg, 'min_input_size'):
min_input_size = self.cfg.min_input_size
if im_height < min_input_size or im_width < min_input_size:
return False
if hasattr(self.cfg, 'max_input_size'):
max_input_size = self.cfg.max_input_size
if im_width * im_height > max_input_size * max_input_size:
return False
def is_valid_size(self, target_size_product):
if target_size_product > self.max_input_sizesquare:
return False

if target_size_product < self.min_input_sizesquare:
return False

return True
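With the defaults cached in __init__ (max_input_size 1500, min_input_size 64), the precomputed bounds are 2,250,000 and 4,096 pixels. A worked example of the area check:

MAX_SQ = 1500 ** 2   # 2,250,000 px
MIN_SQ = 64 ** 2     # 4,096 px

def is_valid_size(target_size_product):
    return MIN_SQ <= target_size_product <= MAX_SQ

print(is_valid_size(1280 * 720))    # True:  921,600 px fits both bounds
print(is_valid_size(4000 * 3000))   # False: 12,000,000 px exceeds the cap
print(is_valid_size(50 * 60))       # False: 3,000 px is below the floor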

def gaussian_scmap(self, joint_id, coords, data_item, size, scale):
stride = self.cfg.stride
#dist_thresh = float(self.cfg.pos_dist_thresh * scale)
num_joints = self.cfg.num_joints
half_stride = stride / 2
scmap = np.zeros(cat([size, arr([num_joints])]))
locref_size = cat([size, arr([num_joints * 2])])
locref_mask = np.zeros(locref_size)
@@ -286,20 +283,19 @@ def gaussian_scmap(self, joint_id, coords, data_item, size, scale):
width = size[1]
height = size[0]
dist_thresh = float((width+height)/6)
locref_scale = 1.0 / self.cfg.locref_stdev
dist_thresh_sq = dist_thresh ** 2

std = dist_thresh/4
# Grid of coordinates
grid = np.mgrid[:height, :width].transpose((1,2,0))
grid = grid*stride + half_stride
grid = grid*self.stride + self.half_stride
for person_id in range(len(coords)):
for k, j_id in enumerate(joint_id[person_id]):
joint_pt = coords[person_id][k, :]
j_x = np.asscalar(joint_pt[0])
j_x_sm = round((j_x - half_stride) / stride)
j_x_sm = round((j_x - self.half_stride) / self.stride)
j_y = np.asscalar(joint_pt[1])
j_y_sm = round((j_y - half_stride) / stride)
j_y_sm = round((j_y - self.half_stride) / self.stride)
map_j = grid.copy()
# Distance between the joint point and each coordinate
dist = np.linalg.norm(grid - (j_y, j_x), axis=2)**2
@@ -309,14 +305,13 @@ def gaussian_scmap(self, joint_id, coords, data_item, size, scale):
locref_mask[dist<=dist_thresh_sq,j_id * 2 + 1]=1
dx = j_x - grid.copy()[:, :, 1 ]
dy = j_y - grid.copy()[:, :, 0 ]
locref_map[..., j_id * 2 + 0] = dx * locref_scale
locref_map[..., j_id * 2 + 1] = dy * locref_scale
locref_map[..., j_id * 2 + 0] = dx * self.locref_scale
locref_map[..., j_id * 2 + 1] = dy * self.locref_scale
weights = self.compute_scmap_weights(scmap.shape, joint_id, data_item)
return scmap, weights, locref_map, locref_mask
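For reference, the Gaussian target assembled here follows exp(-d^2 / (2*std^2)) with std = dist_thresh/4 and d the pixel distance from the joint; the assignment itself sits in rows collapsed above, so the truncation below is an assumption mirroring the dist <= dist_thresh_sq mask used for locref. A self-contained numpy sketch for one joint:

import numpy as np

def gaussian_joint_map(height, width, j_y, j_x, dist_thresh, stride, half_stride):
    grid = np.mgrid[:height, :width].transpose((1, 2, 0))
    grid = grid * stride + half_stride               # cell centers in pixels
    dist = np.linalg.norm(grid - (j_y, j_x), axis=2) ** 2
    std = dist_thresh / 4
    scmap_j = np.exp(-dist / (2 * std ** 2))
    scmap_j[dist > dist_thresh ** 2] = 0             # assumed hard cutoff
    return scmap_j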

def compute_scmap_weights(self, scmap_shape, joint_id, data_item):
cfg = self.cfg
if cfg.weigh_only_present_joints:
if self.cfg.weigh_only_present_joints:
weights = np.zeros(scmap_shape)
for person_joint_id in joint_id:
for j_id in person_joint_id:
@@ -326,18 +321,15 @@ def compute_scmap_weights(self, scmap_shape, joint_id, data_item):
return weights

def compute_target_part_scoremap_numpy(self, joint_id, coords, data_item, size, scale):
stride = self.cfg.stride
dist_thresh = float(self.cfg.pos_dist_thresh * scale)
dist_thresh_sq = dist_thresh ** 2
num_joints = self.cfg.num_joints
half_stride = stride / 2

scmap = np.zeros(cat([size, arr([num_joints])]))
locref_size = cat([size, arr([num_joints * 2])])
locref_mask = np.zeros(locref_size)
locref_map = np.zeros(locref_size)

locref_scale = 1.0 / self.cfg.locref_stdev
dist_thresh_sq = dist_thresh ** 2

width = size[1]
height = size[0]
grid = np.mgrid[:height, :width].transpose((1,2,0))
@@ -346,17 +338,17 @@ def compute_target_part_scoremap_numpy(self, joint_id, coords, data_item, size,
for k, j_id in enumerate(joint_id[person_id]):
joint_pt = coords[person_id][k, :]
j_x = np.asscalar(joint_pt[0])
j_x_sm = round((j_x - half_stride) / stride)
j_x_sm = round((j_x - self.half_stride) / self.stride)
j_y = np.asscalar(joint_pt[1])
j_y_sm = round((j_y - half_stride) / stride)
j_y_sm = round((j_y - self.half_stride) / self.stride)
min_x = round(max(j_x_sm - dist_thresh - 1, 0))
max_x = round(min(j_x_sm + dist_thresh + 1, width - 1))
min_y = round(max(j_y_sm - dist_thresh - 1, 0))
max_y = round(min(j_y_sm + dist_thresh + 1, height - 1))
x = grid.copy()[:, :, 1]
y = grid.copy()[:, :, 0]
dx = j_x - x*stride - half_stride
dy = j_y - y*stride - half_stride
dx = j_x - x*self.stride - self.half_stride
dy = j_y - y*self.stride - self.half_stride
dist = dx**2 + dy**2
mask1 = (dist <= dist_thresh_sq)
mask2 = ((x >= min_x) & (x <= max_x))
@@ -365,8 +357,8 @@ def compute_target_part_scoremap_numpy(self, joint_id, coords, data_item, size,
scmap[mask, j_id] = 1
locref_mask[mask, j_id*2+0] = 1
locref_mask[mask, j_id*2+1] = 1
locref_map[mask, j_id * 2 + 0] = (dx * locref_scale)[mask]
locref_map[mask, j_id * 2 + 1] = (dy * locref_scale)[mask]
locref_map[mask, j_id * 2 + 0] = (dx * self.locref_scale)[mask]
locref_map[mask, j_id * 2 + 1] = (dy * self.locref_scale)[mask]

weights = self.compute_scmap_weights(scmap.shape, joint_id, data_item)
return scmap, weights, locref_map, locref_mask
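The location-refinement targets built above store, for each cell near a joint, the offset from the cell center to the joint, scaled by locref_scale = 1/locref_stdev. A micro-example along one row (7.2801 is the customary pose_cfg locref_stdev; treat it as an assumption here):

import numpy as np

stride, half_stride = 8, 4.0
locref_scale = 1.0 / 7.2801                # 1 / locref_stdev
j_x = 35.0                                 # joint x-position in pixels

x = np.arange(5) * stride + half_stride    # cell centers: 4, 12, 20, 28, 36
dx = (j_x - x) * locref_scale              # offsets the locref head regresses
print(np.round(dx, 3))                     # [ 4.258  3.159  2.06   0.962 -0.137]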
@@ -20,7 +20,17 @@ class PoseDataset:
def __init__(self, cfg):
self.cfg = cfg
self.data = self.load_dataset()

self.num_images = len(self.data)
self.max_input_sizesquare=cfg.get('max_input_size', 1500)**2
self.min_input_sizesquare=cfg.get('min_input_size', 64)**2
self.locref_scale = 1.0 / cfg.locref_stdev
self.stride = cfg.stride
self.half_stride = cfg.stride / 2
self.scale = cfg.global_scale
self.scale_jitter_lo=cfg.get('scale_jitter_lo',.75)
self.scale_jitter_up=cfg.get('scale_jitter_up',1.25)

if self.cfg.mirror:
self.symmetric_joints = mirror_joints_map(cfg.all_joints, cfg.num_joints)
self.curr_img = 0
@@ -120,11 +130,7 @@ def get_training_sample(self, imidx):
return self.data[imidx]

def get_scale(self):
cfg = self.cfg
scale = cfg.global_scale
if hasattr(cfg, 'scale_jitter_lo') and hasattr(cfg, 'scale_jitter_up'):
scale_jitter = rand.uniform(cfg.scale_jitter_lo, cfg.scale_jitter_up)
scale *= scale_jitter
scale = rand.uniform(self.scale_jitter_lo, self.scale_jitter_up)*self.scale
return scale
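The hasattr guard is gone: get_scale now always jitters, with the .75/1.25 defaults cached in __init__. Assuming e.g. global_scale = 0.8 (a typical pose_cfg value), the effective scale is uniform over [0.6, 1.0]:

import numpy.random as rand

global_scale = 0.8                          # assumed typical pose_cfg value
scale_jitter_lo, scale_jitter_up = 0.75, 1.25
scale = rand.uniform(scale_jitter_lo, scale_jitter_up) * global_scale
# 0.8 * U(0.75, 1.25)  ->  uniform over [0.6, 1.0]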

def next_batch(self):
@@ -141,18 +147,12 @@ def next_batch(self):
def is_valid_size(self, image_size, scale):
im_width = image_size[2]
im_height = image_size[1]

max_input_size = 100
if im_height < max_input_size or im_width < max_input_size:
input_width = im_width * scale
input_height = im_height * scale
if input_height * input_width > self.max_input_sizesquare:
return False
if input_height * input_width < self.min_input_sizesquare:
return False

if hasattr(self.cfg, 'max_input_size'):
max_input_size = self.cfg.max_input_size
input_width = im_width * scale
input_height = im_height * scale
if input_height * input_width > max_input_size * max_input_size:
return False

return True
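Unlike the previous file's version, this check receives the raw image size plus the scale, so it applies the scale before comparing areas against the same cached bounds. A worked example:

scale = 0.6
im_width, im_height = 640, 480
area = (im_width * scale) * (im_height * scale)   # 384 * 288 = 110,592 px
valid = 64 ** 2 <= area <= 1500 ** 2              # True: inside both bounds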

def make_batch(self, data_item, scale, mirror):
@@ -218,18 +218,14 @@ def make_batch(data_item, scale, mirror):
return batch

def compute_target_part_scoremap(self, joint_id, coords, data_item, size, scale):
stride = self.cfg.stride
dist_thresh = self.cfg.pos_dist_thresh * scale
dist_thresh_sq = dist_thresh ** 2
num_joints = self.cfg.num_joints
half_stride = stride / 2

scmap = np.zeros(cat([size, arr([num_joints])]))
locref_size = cat([size, arr([num_joints * 2])])
locref_mask = np.zeros(locref_size)
locref_map = np.zeros(locref_size)

locref_scale = 1.0 / self.cfg.locref_stdev
dist_thresh_sq = dist_thresh ** 2

width = size[1]
height = size[0]

@@ -240,20 +236,20 @@ def compute_target_part_scoremap(self, joint_id, coords, data_item, size, scale)
j_y = np.asscalar(joint_pt[1])

# don't loop over entire heatmap, but just relevant locations
j_x_sm = round((j_x - half_stride) / stride)
j_y_sm = round((j_y - half_stride) / stride)
j_x_sm = round((j_x - self.half_stride) / self.stride)
j_y_sm = round((j_y - self.half_stride) / self.stride)
min_x = round(max(j_x_sm - dist_thresh - 1, 0))
max_x = round(min(j_x_sm + dist_thresh + 1, width - 1))
min_y = round(max(j_y_sm - dist_thresh - 1, 0))
max_y = round(min(j_y_sm + dist_thresh + 1, height - 1))

for j in range(min_y, max_y + 1): # range(height):
pt_y = j * stride + half_stride
pt_y = j * self.stride + self.half_stride
for i in range(min_x, max_x + 1): # range(width):
# pt = arr([i*stride+half_stride, j*stride+half_stride])
# diff = joint_pt - pt
# The code above is too slow in python
pt_x = i * stride + half_stride
pt_x = i * self.stride + self.half_stride
dx = j_x - pt_x
dy = j_y - pt_y
dist = dx ** 2 + dy ** 2
@@ -262,16 +258,15 @@ def compute_target_part_scoremap(self, joint_id, coords, data_item, size, scale)
scmap[j, i, j_id] = 1
locref_mask[j, i, j_id * 2 + 0] = 1
locref_mask[j, i, j_id * 2 + 1] = 1
locref_map[j, i, j_id * 2 + 0] = dx * locref_scale
locref_map[j, i, j_id * 2 + 1] = dy * locref_scale
locref_map[j, i, j_id * 2 + 0] = dx * self.locref_scale
locref_map[j, i, j_id * 2 + 1] = dy * self.locref_scale

weights = self.compute_scmap_weights(scmap.shape, joint_id, data_item)

return scmap, weights, locref_map, locref_mask

def compute_scmap_weights(self, scmap_shape, joint_id, data_item):
cfg = self.cfg
if cfg.weigh_only_present_joints:
if self.cfg.weigh_only_present_joints:
weights = np.zeros(scmap_shape)
for person_joint_id in joint_id:
for j_id in person_joint_id:
6 changes: 4 additions & 2 deletions examples/testscript.py
@@ -36,6 +36,8 @@
dfolder=None
net_type='resnet_50' #'mobilenet_v2_0.35' #'resnet_50'
augmenter_type='default' #'tensorpack'
augmenter_type2='imgaug'
augmenter_type3='tensorpack'
numiter=5

print("CREATING PROJECT")
@@ -137,7 +139,7 @@ def make_frame(t):
deeplabcut.merge_datasets(path_config_file)

print("CREATING TRAININGSET")
deeplabcut.create_training_dataset(path_config_file,net_type=net_type)
deeplabcut.create_training_dataset(path_config_file,net_type=net_type,augmenter_type=augmenter_type2)

cfg=deeplabcut.auxiliaryfunctions.read_config(path_config_file)
posefile=os.path.join(cfg['project_path'],'dlc-models/iteration-'+str(cfg['iteration'])+'/'+ cfg['Task'] + cfg['date'] + '-trainset' + str(int(cfg['TrainingFraction'][0] * 100)) + 'shuffle' + str(1),'train/pose_cfg.yaml')
@@ -183,7 +185,7 @@ def make_frame(t):

print("CREATING TRAININGSET for shuffle 2")
print("will be used for 3D testscript...")
deeplabcut.create_training_dataset(path_config_file,Shuffles=[2],net_type=net_type)
deeplabcut.create_training_dataset(path_config_file,Shuffles=[2],net_type=net_type,augmenter_type=augmenter_type3)

posefile=os.path.join(cfg['project_path'],'dlc-models/iteration-'+str(cfg['iteration'])+'/'+ cfg['Task'] + cfg['date'] + '-trainset' + str(int(cfg['TrainingFraction'][0] * 100)) + 'shuffle' + str(2),'train/pose_cfg.yaml')
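Taken together, the two create_training_dataset calls above give each shuffle its own loader, so the test script exercises the imgaug and tensorpack pipelines side by side. The same pattern, as a sketch reusing the script's variables:

for shuffle, augmenter in [(1, augmenter_type2), (2, augmenter_type3)]:
    deeplabcut.create_training_dataset(path_config_file, Shuffles=[shuffle],
                                       net_type=net_type, augmenter_type=augmenter)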

