Skip to content

Commit

Permalink
Restored previous version of ToTensor and added ToTensorV2
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Sep 3, 2019
1 parent 1983417 commit a502680
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 8 deletions.
76 changes: 73 additions & 3 deletions albumentations/pytorch/transforms.py
@@ -1,19 +1,89 @@
from __future__ import absolute_import

import warnings

import numpy as np
import torch
from torchvision.transforms import functional as F

from ..core.transforms_interface import BasicTransform


__all__ = ['ToTensor']
__all__ = ['ToTensor', 'ToTensorV2']


def img_to_tensor(im, normalize=None):
tensor = torch.from_numpy(np.moveaxis(im / (255. if im.dtype == np.uint8 else 1), -1, 0).astype(np.float32))
if normalize is not None:
return F.normalize(tensor, **normalize)
return tensor


def mask_to_tensor(mask, num_classes, sigmoid):
# todo
if num_classes > 1:
if not sigmoid:
# softmax
long_mask = np.zeros((mask.shape[:2]), dtype=np.int64)
if len(mask.shape) == 3:
for c in range(mask.shape[2]):
long_mask[mask[..., c] > 0] = c
else:
long_mask[mask > 127] = 1
long_mask[mask == 0] = 0
mask = long_mask
else:
mask = np.moveaxis(mask / (255. if mask.dtype == np.uint8 else 1), -1, 0).astype(np.float32)
else:
mask = np.expand_dims(mask / (255. if mask.dtype == np.uint8 else 1), 0).astype(np.float32)
return torch.from_numpy(mask)


class ToTensor(BasicTransform):
"""Convert image and mask to `torch.Tensor` and divide by 255 if image or mask are `uint8` type.
WARNING! Please use this with care and look into sources before usage.
Args:
num_classes (int): only for segmentation
sigmoid (bool, optional): only for segmentation, transform mask to LongTensor or not.
normalize (dict, optional): dict with keys [mean, std] to pass it into torchvision.normalize
"""

def __init__(self, num_classes=1, sigmoid=True, normalize=None):
super(ToTensor, self).__init__(always_apply=True, p=1.)
self.num_classes = num_classes
self.sigmoid = sigmoid
self.normalize = normalize
warnings.warn("ToTensor is deprecated and will be replaced by ToTensorV2 "
"in albumentations 0.4.0", DeprecationWarning)

def __call__(self, force_apply=True, **kwargs):
kwargs.update({'image': img_to_tensor(kwargs['image'], self.normalize)})
if 'mask' in kwargs.keys():
kwargs.update({'mask': mask_to_tensor(kwargs['mask'], self.num_classes, sigmoid=self.sigmoid)})

for k, v in kwargs.items():
if self._additional_targets.get(k) == 'image':
kwargs.update({k: img_to_tensor(kwargs[k], self.normalize)})
if self._additional_targets.get(k) == 'mask':
kwargs.update({k: mask_to_tensor(kwargs[k], self.num_classes, sigmoid=self.sigmoid)})
return kwargs

@property
def targets(self):
raise NotImplementedError

def get_transform_init_args_names(self):
return 'num_classes', 'sigmoid', 'normalize'


class ToTensorV2(BasicTransform):
"""Convert image and mask to `torch.Tensor`.
"""

def __init__(self):
super(ToTensor, self).__init__(always_apply=True)
super(ToTensorV2, self).__init__(always_apply=True)

@property
def targets(self):
Expand All @@ -29,7 +99,7 @@ def apply_to_mask(self, mask, **params):
return torch.from_numpy(mask)

def get_transform_init_args_names(self):
return {}
return []

def get_params_dependent_on_targets(self, params):
return {}
33 changes: 28 additions & 5 deletions tests/test_pytorch.py
@@ -1,22 +1,23 @@
import pytest
import numpy as np
import torch

import albumentations as A
from albumentations.pytorch.transforms import ToTensor
from albumentations.pytorch.transforms import ToTensor, ToTensorV2


def test_torch_to_tensor_augmentations(image, mask):
aug = ToTensor()
def test_torch_to_tensor_v2_augmentations(image, mask):
aug = ToTensorV2()
data = aug(image=image, mask=mask, force_apply=True)
assert isinstance(data['image'], torch.Tensor) and data['image'].shape == image.shape[::-1]
assert isinstance(data['mask'], torch.Tensor) and data['mask'].shape == mask.shape
assert data['image'].dtype == torch.uint8
assert data['mask'].dtype == torch.uint8


def test_additional_targets_for_totensor():
def test_additional_targets_for_totensorv2():
aug = A.Compose(
[ToTensor()], additional_targets={'image2': 'image', 'mask2': 'mask'})
[ToTensorV2()], additional_targets={'image2': 'image', 'mask2': 'mask'})
for i in range(10):
image1 = np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8)
image2 = image1.copy()
Expand All @@ -29,3 +30,25 @@ def test_additional_targets_for_totensor():
assert isinstance(res['mask2'], torch.Tensor) and res['mask2'].shape == mask2.shape
assert np.array_equal(res['image'], res['image2'])
assert np.array_equal(res['mask'], res['mask2'])


def test_torch_to_tensor_augmentations(image, mask):
with pytest.warns(DeprecationWarning):
aug = ToTensor()
data = aug(image=image, mask=mask, force_apply=True)
assert data['image'].dtype == torch.float32
assert data['mask'].dtype == torch.float32


def test_additional_targets_for_totensor():
with pytest.warns(DeprecationWarning):
aug = A.Compose(
[ToTensor(num_classes=4)], additional_targets={'image2': 'image', 'mask2': 'mask'})
for i in range(10):
image1 = np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8)
image2 = image1.copy()
mask1 = np.random.randint(low=0, high=256, size=(100, 100, 4), dtype=np.uint8)
mask2 = mask1.copy()
res = aug(image=image1, image2=image2, mask=mask1, mask2=mask2)
assert np.array_equal(res['image'], res['image2'])
assert np.array_equal(res['mask'], res['mask2'])

0 comments on commit a502680

Please sign in to comment.