Skip to content

Commit

Permalink
[Feature] Support crop sequence (open-mmlab#648)
Browse files Browse the repository at this point in the history
  • Loading branch information
ckkelvinchan committed Dec 19, 2021
1 parent 1b2848b commit 3798730
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 4 deletions.
5 changes: 3 additions & 2 deletions mmedit/datasets/pipelines/__init__.py
Expand Up @@ -8,7 +8,8 @@
TemporalReverse, UnsharpMasking)
from .compose import Compose
from .crop import (Crop, CropAroundCenter, CropAroundFg, CropAroundUnknown,
CropLike, FixedCrop, ModCrop, PairedRandomCrop)
CropLike, CropSequence, FixedCrop, ModCrop,
PairedRandomCrop)
from .formating import (Collect, FormatTrimap, GetMaskedImage, ImageToTensor,
ToTensor)
from .generate_assistant import GenerateCoordinateAndCell, GenerateHeatmap
Expand Down Expand Up @@ -41,5 +42,5 @@
'CropLike', 'GenerateHeatmap', 'MATLABLikeResize', 'CopyValues',
'Quantize', 'RandomBlur', 'RandomJPEGCompression', 'RandomNoise',
'DegradationsWithShuffle', 'RandomResize', 'UnsharpMasking',
'RandomVideoCompression'
'RandomVideoCompression', 'CropSequence'
]
41 changes: 41 additions & 0 deletions mmedit/datasets/pipelines/crop.py
Expand Up @@ -85,6 +85,47 @@ def __repr__(self):
return repr_str


@PIPELINES.register_module()
class CropSequence(Crop):
"""Crop a sequence to specific size for training.
The main difference to 'Crop' is that the region to be cropped is the same
for every images in the sequence.
Args:
keys (Sequence[str]): The images to be cropped.
crop_size (Tuple[int]): Target spatial size (h, w).
random_crop (bool): If set to True, it will random crop
image. Otherwise, it will work as center crop.
"""

def _crop(self, data):
if not isinstance(data, list):
raise TypeError(f'Input must be a list, but got {type(data)}.')

# determine crop location. Must be the same for all images
data_h, data_w = data[0].shape[:2]
crop_h, crop_w = self.crop_size
crop_h = min(data_h, crop_h)
crop_w = min(data_w, crop_w)

if self.random_crop:
x_offset = np.random.randint(0, data_w - crop_w + 1)
y_offset = np.random.randint(0, data_h - crop_h + 1)
else:
x_offset = max(0, (data_w - crop_w)) // 2
y_offset = max(0, (data_h - crop_h)) // 2
crop_bbox = [x_offset, y_offset, crop_w, crop_h]

data_list = []
for item in data:
item = item[y_offset:y_offset + crop_h, x_offset:x_offset + crop_w,
...]
data_list.append(item)

return data_list, crop_bbox


@PIPELINES.register_module()
class FixedCrop:
"""Crop paired data (at a specific position) to specific size for training.
Expand Down
39 changes: 37 additions & 2 deletions tests/test_data/test_pipelines/test_crop.py
Expand Up @@ -5,8 +5,9 @@
import pytest

from mmedit.datasets.pipelines import (Crop, CropAroundCenter, CropAroundFg,
CropAroundUnknown, CropLike, FixedCrop,
ModCrop, PairedRandomCrop)
CropAroundUnknown, CropLike,
CropSequence, FixedCrop, ModCrop,
PairedRandomCrop)


class TestAugmentations:
Expand Down Expand Up @@ -76,6 +77,40 @@ def test_crop(self):
random_crop.__class__.__name__ +
"keys=['img'], crop_size=(512, 512), random_crop=True")

def test_crop_sequence(self):
# input must be a list
crop = CropSequence(['gt'], (4, 8), False)
results = {'gt': np.random.rand(16, 16, 1)}
with pytest.raises(TypeError):
crop(results)

# test center crop
results = {'gt': [np.random.rand(16, 16, 1)] * 5}
inputs = copy.deepcopy(results)
results = crop(results)
for i, output in enumerate(results['gt']):
assert np.array_equal(inputs['gt'][i][6:10, 4:12, :], output)

# test random crop
crop = CropSequence(['gt'], (4, 8), True)
results = {'gt': [np.random.rand(16, 16, 1)] * 5}
inputs = copy.deepcopy(results)
results = crop(results)
assert 0 <= results['gt_crop_bbox'][0] <= 9
assert 0 <= results['gt_crop_bbox'][1] <= 13
assert results['gt_crop_bbox'][2] == 8
assert results['gt_crop_bbox'][3] == 4

# test random crop for lager size than the original shape
crop = CropSequence(['gt'], (19, 31), True)
results = {'gt': [np.random.rand(16, 16, 1)] * 5}
inputs = copy.deepcopy(results)
results = crop(results)
assert np.array_equal(inputs['gt'], results['gt'])
assert str(crop) == (
crop.__class__.__name__ +
"keys=['gt'], crop_size=(19, 31), random_crop=True")

def test_fixed_crop(self):
with pytest.raises(TypeError):
FixedCrop(['img_a', 'img_b'], (0.23, 0.1))
Expand Down

0 comments on commit 3798730

Please sign in to comment.