diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 63b6399b1fe..fbda5932735 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -14,6 +14,26 @@ All transformations accept PIL Image, Tensor Image or batch of Tensor Images as Tensor Images is a tensor of ``(B, C, H, W)`` shape, where ``B`` is a number of images in the batch. Deterministic or random transformations applied on the batch of Tensor Images identically transform all the images of the batch. + +Scriptable transforms +^^^^^^^^^^^^^^^^^^^^^ + +In order to script the transformations, please use ``torch.nn.Sequential`` instead of :class:`Compose`. + +.. code:: python + + transforms = torch.nn.Sequential( + transforms.CenterCrop(10), + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ) + scripted_transforms = torch.jit.script(transforms) + +Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor`` and does not require +`lambda` functions or ``PIL.Image``. + +For any custom transformations to be used with ``torch.jit.script``, they should be derived from ``torch.nn.Module``. + + .. autoclass:: Compose Transforms on PIL Image diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 182c70712fe..74e128d9b1c 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -376,6 +376,63 @@ def test_to_grayscale(self): "RandomGrayscale", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max" ) + def test_normalize(self): + tensor, _ = self._create_data(26, 34, device=self.device) + batch_tensors = torch.rand(4, 3, 44, 56, device=self.device) + + tensor = tensor.to(dtype=torch.float32) / 255.0 + # test for class interface + fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + scripted_fn = torch.jit.script(fn) + + self._test_transform_vs_scripted(fn, scripted_fn, tensor) + self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors) + + def test_linear_transformation(self): + c, h, w = 3, 24, 32 + + tensor, _ = self._create_data(h, w, channels=c, device=self.device) + + matrix = torch.rand(c * h * w, c * h * w, device=self.device) + mean_vector = torch.rand(c * h * w, device=self.device) + + fn = T.LinearTransformation(matrix, mean_vector) + scripted_fn = torch.jit.script(fn) + + self._test_transform_vs_scripted(fn, scripted_fn, tensor) + + batch_tensors = torch.rand(4, c, h, w, device=self.device) + # We skip some tests from _test_transform_vs_scripted_on_batch as + # results for scripted and non-scripted transformations are not exactly the same + torch.manual_seed(12) + transformed_batch = fn(batch_tensors) + torch.manual_seed(12) + s_transformed_batch = scripted_fn(batch_tensors) + self.assertTrue(transformed_batch.equal(s_transformed_batch)) + + def test_compose(self): + tensor, _ = self._create_data(26, 34, device=self.device) + tensor = tensor.to(dtype=torch.float32) / 255.0 + + transforms = T.Compose([ + T.CenterCrop(10), + T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ]) + s_transforms = torch.nn.Sequential(*transforms.transforms) + + scripted_fn = torch.jit.script(s_transforms) + torch.manual_seed(12) + transformed_tensor = transforms(tensor) + torch.manual_seed(12) + transformed_tensor_script = scripted_fn(tensor) + self.assertTrue(transformed_tensor.equal(transformed_tensor_script), msg="{}".format(transforms)) + + t = T.Compose([ + lambda x: x, + ]) + with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"): + torch.jit.script(t) + @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") class CUDATester(Tester): diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index a29b7399cde..b10bad7103f 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -283,7 +283,7 @@ def to_pil_image(pic, mode=None): return Image.fromarray(npimg, mode=mode) -def normalize(tensor, mean, std, inplace=False): +def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor: """Normalize a tensor image with mean and standard deviation. .. note:: @@ -292,7 +292,7 @@ def normalize(tensor, mean, std, inplace=False): See :class:`~torchvision.transforms.Normalize` for more details. Args: - tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + tensor (Tensor): Tensor image of size (C, H, W) or (B, C, H, W) to be normalized. mean (sequence): Sequence of means for each channel. std (sequence): Sequence of standard deviations for each channel. inplace(bool,optional): Bool to make this operation inplace. @@ -300,11 +300,11 @@ def normalize(tensor, mean, std, inplace=False): Returns: Tensor: Normalized Tensor image. """ - if not torch.is_tensor(tensor): - raise TypeError('tensor should be a torch tensor. Got {}.'.format(type(tensor))) + if not isinstance(tensor, torch.Tensor): + raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor))) - if tensor.ndimension() != 3: - raise ValueError('Expected tensor to be a tensor image of size (C, H, W). Got tensor.size() = ' + if tensor.ndim < 3: + raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = ' '{}.'.format(tensor.size())) if not inplace: @@ -316,9 +316,9 @@ def normalize(tensor, mean, std, inplace=False): if (std == 0).any(): raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype)) if mean.ndim == 1: - mean = mean[:, None, None] + mean = mean.view(-1, 1, 1) if std.ndim == 1: - std = std[:, None, None] + std = std.view(-1, 1, 1) tensor.sub_(mean).div_(std) return tensor diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 1c182826887..ee892bb2856 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -3,7 +3,7 @@ import random import warnings from collections.abc import Sequence -from typing import Tuple, List, Optional +from typing import Tuple, List, Optional, Any import torch from PIL import Image @@ -33,7 +33,7 @@ } -class Compose(object): +class Compose: """Composes several transforms together. Args: @@ -44,6 +44,19 @@ class Compose(object): >>> transforms.CenterCrop(10), >>> transforms.ToTensor(), >>> ]) + + .. note:: + In order to script the transformations, please use ``torch.nn.Sequential`` as below. + + >>> transforms = torch.nn.Sequential( + >>> transforms.CenterCrop(10), + >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + >>> ) + >>> scripted_transforms = torch.jit.script(transforms) + + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. + """ def __init__(self, transforms): @@ -63,7 +76,7 @@ def __repr__(self): return format_string -class ToTensor(object): +class ToTensor: """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. Converts a PIL Image or numpy.ndarray (H x W x C) in the range @@ -94,7 +107,7 @@ def __repr__(self): return self.__class__.__name__ + '()' -class PILToTensor(object): +class PILToTensor: """Convert a ``PIL Image`` to a tensor of the same type. Converts a PIL Image (H x W x C) to a Tensor of shape (C x H x W). @@ -114,7 +127,7 @@ def __repr__(self): return self.__class__.__name__ + '()' -class ConvertImageDtype(object): +class ConvertImageDtype: """Convert a tensor image to the given ``dtype`` and scale the values accordingly Args: @@ -139,7 +152,7 @@ def __call__(self, image: torch.Tensor) -> torch.Tensor: return F.convert_image_dtype(image, self.dtype) -class ToPILImage(object): +class ToPILImage: """Convert a tensor or an ndarray to PIL Image. Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape @@ -178,7 +191,7 @@ def __repr__(self): return format_string -class Normalize(object): +class Normalize(torch.nn.Module): """Normalize a tensor image with mean and standard deviation. Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` channels, this transform will normalize each channel of the input @@ -196,11 +209,12 @@ class Normalize(object): """ def __init__(self, mean, std, inplace=False): + super().__init__() self.mean = mean self.std = std self.inplace = inplace - def __call__(self, tensor): + def forward(self, tensor: Tensor) -> Tensor: """ Args: tensor (Tensor): Tensor image of size (C, H, W) to be normalized. @@ -358,7 +372,7 @@ def __repr__(self): format(self.padding, self.fill, self.padding_mode) -class Lambda(object): +class Lambda: """Apply a user-defined lambda as a transform. Args: @@ -366,7 +380,8 @@ class Lambda(object): """ def __init__(self, lambd): - assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" + if not callable(lambd): + raise TypeError("Argument lambd should be callable, got {}".format(repr(type(lambd).__name__))) self.lambd = lambd def __call__(self, img): @@ -376,7 +391,7 @@ def __repr__(self): return self.__class__.__name__ + '()' -class RandomTransforms(object): +class RandomTransforms: """Base class for a list of transformations with randomness Args: @@ -408,7 +423,7 @@ class RandomApply(RandomTransforms): """ def __init__(self, transforms, p=0.5): - super(RandomApply, self).__init__(transforms) + super().__init__(transforms) self.p = p def __call__(self, img): @@ -897,7 +912,7 @@ def __repr__(self): return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) -class LinearTransformation(object): +class LinearTransformation(torch.nn.Module): """Transform a tensor image with a square transformation matrix and a mean_vector computed offline. Given transformation_matrix and mean_vector, will flatten the torch.*Tensor and @@ -916,6 +931,7 @@ class LinearTransformation(object): """ def __init__(self, transformation_matrix, mean_vector): + super().__init__() if transformation_matrix.size(0) != transformation_matrix.size(1): raise ValueError("transformation_matrix should be square. Got " + "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) @@ -925,10 +941,14 @@ def __init__(self, transformation_matrix, mean_vector): " as any one of the dimensions of the transformation_matrix [{}]" .format(tuple(transformation_matrix.size()))) + if transformation_matrix.device != mean_vector.device: + raise ValueError("Input tensors should be on the same device. Got {} and {}" + .format(transformation_matrix.device, mean_vector.device)) + self.transformation_matrix = transformation_matrix self.mean_vector = mean_vector - def __call__(self, tensor): + def forward(self, tensor: Tensor) -> Tensor: """ Args: tensor (Tensor): Tensor image of size (C, H, W) to be whitened. @@ -936,13 +956,20 @@ def __call__(self, tensor): Returns: Tensor: Transformed image. """ - if tensor.size(0) * tensor.size(1) * tensor.size(2) != self.transformation_matrix.size(0): - raise ValueError("tensor and transformation matrix have incompatible shape." + - "[{} x {} x {}] != ".format(*tensor.size()) + - "{}".format(self.transformation_matrix.size(0))) - flat_tensor = tensor.view(1, -1) - self.mean_vector + shape = tensor.shape + n = shape[-3] * shape[-2] * shape[-1] + if n != self.transformation_matrix.shape[0]: + raise ValueError("Input tensor and transformation matrix have incompatible shape." + + "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + + "{}".format(self.transformation_matrix.shape[0])) + + if tensor.device.type != self.mean_vector.device.type: + raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. " + "Got {} vs {}".format(tensor.device, self.mean_vector.device)) + + flat_tensor = tensor.view(-1, n) - self.mean_vector transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) - tensor = transformed_tensor.view(tensor.size()) + tensor = transformed_tensor.view(shape) return tensor def __repr__(self):