Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
Merge pull request #23 from PhoenixDL/progressive_resizing_multiproc
Browse files Browse the repository at this point in the history
progressive resizing with multiprocessing
  • Loading branch information
mibaumgartner committed Dec 24, 2019
2 parents 1411e19 + 90dfa93 commit 6b90197
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
52 changes: 46 additions & 6 deletions rising/transforms/spatial.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import torch
import random
from torch.multiprocessing import Value
from .abstract import RandomDimsTransform, AbstractTransform, BaseTransform, RandomProcess
from typing import Union, Sequence, Callable
from itertools import permutations
Expand Down Expand Up @@ -199,8 +202,6 @@ def forward(self, **data) -> dict:


class ProgressiveResize(Resize):
step = 0

def __init__(self, scheduler: schduler_type, mode: str = 'nearest',
align_corners: bool = None, preserve_range: bool = False,
keys: Sequence = ('data',), grad: bool = False, **kwargs):
Expand All @@ -226,18 +227,57 @@ def __init__(self, scheduler: schduler_type, mode: str = 'nearest',
enable gradient computation inside transformation
kwargs:
keyword arguments passed to augment_fn
Warnings
--------
When this transformation is used in combination with multiprocessing,
the step counter is not perfectly synchronized between multiple
processes. As a result the step count may jump between values
in a range of the number of processes used.
"""
super().__init__(size=0, mode=mode, align_corners=align_corners,
preserve_range=preserve_range,
keys=keys, grad=grad, **kwargs)
self.scheduler = scheduler
self._step = Value('i', 0)

def reset_step(self) -> ProgressiveResize:
    """
    Reset the shared step counter to 0.

    Returns
    -------
    ProgressiveResize
        returns self to allow chaining
    """
    # Hold the Value's lock so workers never observe a partial update.
    with self._step.get_lock():
        self._step.value = 0
    return self

def increment(self) -> ProgressiveResize:
    """
    Advance the shared step counter by one.

    Returns
    -------
    ProgressiveResize
        returns self to allow chaining
    """
    counter = self._step
    # Lock the shared Value so the read-modify-write is atomic
    # across worker processes.
    with counter.get_lock():
        counter.value = counter.value + 1
    return self

@property
def step(self) -> int:
    """
    Current step

    Returns
    -------
    int
        number of steps
    """
    # Plain read of the shared counter; the property is read-only, so
    # callers advance it via increment() / reset_step() instead.
    return self._step.value

def forward(self, **data) -> dict:
"""
Expand All @@ -254,7 +294,7 @@ def forward(self, **data) -> dict:
augmented batch
"""
self.kwargs["size"] = self.scheduler(self.step)
self.step += 1
self.increment()
return super().forward(**data)


Expand Down
15 changes: 15 additions & 0 deletions tests/transforms/test_spatial_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tests.transforms import chech_data_preservation
from rising.transforms.spatial import *
from rising.transforms.functional.spatial import resize
from rising.loading import DataLoader


class TestSpatialTransforms(unittest.TestCase):
Expand Down Expand Up @@ -83,6 +84,20 @@ def test_size_step_scheduler_error(self):
with self.assertRaises(TypeError):
scheduler = SizeStepScheduler([10, 20], [32, 64])

def test_progressive_resize_integration(self):
    """Check that the resize schedule reaches its first and final sizes."""
    sizes = [1, 3, 6]
    scheduler = SizeStepScheduler([1, 2], sizes)
    trafo = ProgressiveResize(scheduler)

    dset = [self.batch_dict] * 10
    loader = DataLoader(dset, num_workers=4, batch_transforms=trafo)

    data_shape = [tuple(i["data"].shape) for i in loader]

    self.assertIn((1, 1, 1, 1, 1), data_shape)
    # NOTE(review): the intermediate shape (1, 1, 3, 3, 3) is deliberately
    # not asserted — with num_workers=4 the shared step counter may jump
    # by up to the number of worker processes (see the ProgressiveResize
    # warning), so the middle scheduler step can be skipped entirely.
    self.assertIn((1, 1, 6, 6, 6), data_shape)


if __name__ == '__main__':
unittest.main()

0 comments on commit 6b90197

Please sign in to comment.