Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
- Updated code and tests to support batch of tensors, e.g. input of shape (B, C, H, W).
  • Loading branch information
vfdev-5 committed Dec 4, 2020
1 parent 77aff2a commit 7b90935
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 64 deletions.
61 changes: 0 additions & 61 deletions test/test_transforms.py
Expand Up @@ -1624,67 +1624,6 @@ def test_random_grayscale(self):
# Checking if RandomGrayscale can be printed as string
trans3.__repr__()

def test_random_erasing(self):
"""Unit tests for random erasing transform"""
for is_scripted in [False, True]:
torch.manual_seed(12)
img = torch.rand(3, 60, 60)

# Test Set 0: invalid value
random_erasing = transforms.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
img_re = random_erasing(img)

# Test Set 1: Erasing with int value
random_erasing = transforms.RandomErasing(value=0.2)
if is_scripted:
random_erasing = torch.jit.script(random_erasing)

i, j, h, w, v = transforms.RandomErasing.get_params(
img, scale=random_erasing.scale, ratio=random_erasing.ratio, value=[random_erasing.value, ]
)
img_output = F.erase(img, i, j, h, w, v)
self.assertEqual(img_output.size(0), 3)

# Test Set 2: Check if the unerased region is preserved
true_output = img.clone()
true_output[:, i:i + h, j:j + w] = random_erasing.value
self.assertTrue(torch.equal(true_output, img_output))

# Test Set 3: Erasing with random value
random_erasing = transforms.RandomErasing(value="random")
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)

self.assertEqual(img_re.size(0), 3)

# Test Set 4: Erasing with tuple value
random_erasing = transforms.RandomErasing(value=(0.2, 0.2, 0.2))
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)
self.assertEqual(img_re.size(0), 3)
true_output = img.clone()
true_output[:, i:i + h, j:j + w] = torch.tensor(random_erasing.value)[:, None, None]
self.assertTrue(torch.equal(true_output, img_output))

# Test Set 5: Testing the inplace behaviour
random_erasing = transforms.RandomErasing(value=(0.2,), inplace=True)
if is_scripted:
random_erasing = torch.jit.script(random_erasing)

img_re = random_erasing(img)
self.assertTrue(torch.equal(img_re, img))

# Test Set 6: Checking when no erased region is selected
img = torch.rand([3, 300, 1])
random_erasing = transforms.RandomErasing(ratio=(0.1, 0.2), value="random")
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)
self.assertTrue(torch.equal(img_re, img))


if __name__ == '__main__':
unittest.main()
24 changes: 24 additions & 0 deletions test/test_transforms_tensor.py
Expand Up @@ -464,6 +464,30 @@ def test_compose(self):
with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
torch.jit.script(t)

def test_random_erasing(self):
img = torch.rand(3, 60, 60)

# Test Set 0: invalid value
random_erasing = T.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
random_erasing(img)

tensor, _ = self._create_data(24, 32, channels=3, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

test_configs = [
{"value": 0.2},
{"value": "random"},
{"value": (0.2, 0.2, 0.2)},
{"value": "random", "ratio": (0.1, 0.2)},
]

for config in test_configs:
fn = T.RandomErasing(**config)
scripted_fn = torch.jit.script(fn)
self._test_transform_vs_scripted(fn, scripted_fn, tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/functional.py
Expand Up @@ -1024,5 +1024,5 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
if not inplace:
img = img.clone()

img[:, i:i + h, j:j + w] = v
img[..., i:i + h, j:j + w] = v
return img
4 changes: 2 additions & 2 deletions torchvision/transforms/transforms.py
Expand Up @@ -1375,7 +1375,7 @@ def __repr__(self):

class RandomErasing(torch.nn.Module):
""" Randomly selects a rectangle region in an image and erases its pixels.
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/pdf/1708.04896.pdf
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896
Args:
p: probability that the random erasing operation will be performed.
Expand Down Expand Up @@ -1439,7 +1439,7 @@ def get_params(
Returns:
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
"""
img_c, img_h, img_w = img.shape
img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
area = img_h * img_w

for _ in range(10):
Expand Down

0 comments on commit 7b90935

Please sign in to comment.