Skip to content

Commit

Permalink
[Fix]: fix type change of labels in albumentations (open-mmlab#9074)
Browse files Browse the repository at this point in the history
  • Loading branch information
hhaAndroid authored and MambaWong committed Oct 21, 2022
1 parent bf92abe commit 560e109
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
10 changes: 10 additions & 0 deletions mmdet/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,6 +1542,16 @@ def _postprocess_results(
ori_masks: Optional[Union[BitmapMasks,
PolygonMasks]] = None) -> dict:
"""Post-processing Albu output."""
# albumentations may return np.array or list on different versions
if 'gt_bboxes_labels' in results and isinstance(
results['gt_bboxes_labels'], list):
results['gt_bboxes_labels'] = np.array(
results['gt_bboxes_labels'], dtype=np.int64)
if 'gt_ignore_flags' in results and isinstance(
results['gt_ignore_flags'], list):
results['gt_ignore_flags'] = np.array(
results['gt_ignore_flags'], dtype=np.bool)

if 'bboxes' in results:
if isinstance(results['bboxes'], list):
results['bboxes'] = np.array(
Expand Down
31 changes: 31 additions & 0 deletions tests/test_datasets/test_transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,37 @@ def test_transform(self):

self.assertEqual(results['img'].dtype, np.uint8)

# test bbox
albu_transform = dict(
type='Albu',
transforms=[dict(type='ChannelShuffle', p=1)],
bbox_params=dict(
type='BboxParams',
format='pascal_voc',
label_fields=['gt_bboxes_labels', 'gt_ignore_flags']),
keymap={
'img': 'image',
'gt_bboxes': 'bboxes'
})
albu_transform = TRANSFORMS.build(albu_transform)
results = {
'img':
np.random.random((224, 224, 3)),
'img_shape': (224, 224),
'gt_bboxes_labels':
np.array([1, 2, 3], dtype=np.int64),
'gt_bboxes':
np.array([[10, 10, 20, 20], [20, 20, 40, 40], [40, 40, 80, 80]],
dtype=np.float32),
'gt_ignore_flags':
np.array([0, 0, 1], dtype=bool),
}
results = albu_transform(results)
self.assertEqual(results['img'].dtype, np.float64)
self.assertEqual(results['gt_bboxes'].dtype, np.float32)
self.assertEqual(results['gt_ignore_flags'].dtype, np.bool)
self.assertEqual(results['gt_bboxes_labels'].dtype, np.int64)

@unittest.skipIf(albumentations is None, 'albumentations is not installed')
def test_repr(self):
albu_transform = dict(
Expand Down

0 comments on commit 560e109

Please sign in to comment.