diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst
index 869467c496..9030fa3ced 100644
--- a/docs/source/handlers.rst
+++ b/docs/source/handlers.rst
@@ -111,6 +111,11 @@ SmartCache handler
 .. autoclass:: SmartCacheHandler
     :members:
 
+Parameter Scheduler handler
+---------------------------
+.. autoclass:: ParamSchedulerHandler
+    :members:
+
 EarlyStop handler
 -----------------
 .. autoclass:: EarlyStopHandler
diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index 2112b074a0..f88531ea8e 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -21,6 +21,7 @@
 from .mean_dice import MeanDice
 from .metric_logger import MetricLogger, MetricLoggerKeys
 from .metrics_saver import MetricsSaver
+from .parameter_scheduler import ParamSchedulerHandler
 from .roc_auc import ROCAUC
 from .segmentation_saver import SegmentationSaver
 from .smartcache_handler import SmartCacheHandler
diff --git a/monai/handlers/parameter_scheduler.py b/monai/handlers/parameter_scheduler.py
new file mode 100644
index 0000000000..2aa0224a5a
--- /dev/null
+++ b/monai/handlers/parameter_scheduler.py
@@ -0,0 +1,163 @@
+import logging
+from bisect import bisect_right
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
+
+from monai.utils import exact_version, optional_import
+
+if TYPE_CHECKING:
+    from ignite.engine import Engine, Events
+else:
+    Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")
+    Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
+
+
+class ParamSchedulerHandler:
+    """
+    General-purpose scheduler for parameter values. It supports linear, exponential, step and multistep
+    schedules out of the box; a Callable can also be passed for custom scheduling logic.
+
+    Args:
+        parameter_setter (Callable): Function that sets the required parameter.
+        value_calculator (Union[str,Callable]): Either a string ('linear', 'exponential', 'step' or 'multistep')
+            or a Callable for custom logic.
+        vc_kwargs (Dict): Dictionary that stores the required parameters for the value_calculator.
+        epoch_level (bool): Whether the step is based on epoch or iteration. Defaults to False.
+        name (Optional[str]): Identifier of logging.logger to use; if None, defaults to ``engine.logger``.
+        event: Event to which the handler attaches. Defaults to Events.ITERATION_COMPLETED.
+    """
+
+    def __init__(
+        self,
+        parameter_setter: Callable,
+        value_calculator: Union[str, Callable],
+        vc_kwargs: Dict,
+        epoch_level: bool = False,
+        name: Optional[str] = None,
+        event=Events.ITERATION_COMPLETED,
+    ):
+        self.epoch_level = epoch_level
+        self.event = event
+
+        self._calculators = {
+            "linear": self._linear,
+            "exponential": self._exponential,
+            "step": self._step,
+            "multistep": self._multistep,
+        }
+
+        self._parameter_setter = parameter_setter
+        self._vc_kwargs = vc_kwargs
+        self._value_calculator = self._get_value_calculator(value_calculator=value_calculator)
+
+        self.logger = logging.getLogger(name)
+        self._name = name
+
+    def _get_value_calculator(self, value_calculator):
+        if isinstance(value_calculator, str):
+            return self._calculators[value_calculator]
+        if callable(value_calculator):
+            return value_calculator
+        raise ValueError(
+            f"value_calculator must be either a string from {list(self._calculators.keys())} or a Callable."
+        )
+
+    def __call__(self, engine: Engine):
+        if self.epoch_level:
+            self._vc_kwargs["current_step"] = engine.state.epoch
+        else:
+            self._vc_kwargs["current_step"] = engine.state.iteration
+
+        new_value = self._value_calculator(**self._vc_kwargs)
+        self._parameter_setter(new_value)
+
+    def attach(self, engine: Engine) -> None:
+        """
+        Args:
+            engine: Ignite Engine that is used for training.
+        """
+        if self._name is None:
+            self.logger = engine.logger
+        engine.add_event_handler(self.event, self)
+
+    @staticmethod
+    def _linear(
+        initial_value: float, step_constant: int, step_max_value: int, max_value: float, current_step: int
+    ) -> float:
+        """
+        Keeps the parameter value constant at initial_value until step_constant steps have passed, then
+        increases it linearly until it reaches max_value at step_max_value steps, staying constant afterwards.
+
+        Args:
+            initial_value (float): Starting value of the parameter.
+            step_constant (int): Step index until parameter's value is kept constant.
+            step_max_value (int): Step index at which parameter's value becomes max_value.
+            max_value (float): Max parameter value.
+            current_step (int): Current step index.
+
+        Returns:
+            float: new parameter value
+        """
+        if current_step <= step_constant:
+            delta = 0.0
+        elif current_step > step_max_value:
+            delta = max_value - initial_value
+        else:
+            delta = (max_value - initial_value) / (step_max_value - step_constant) * (current_step - step_constant)
+
+        return initial_value + delta
+
+    @staticmethod
+    def _exponential(initial_value: float, gamma: float, current_step: int) -> float:
+        """
+        Decays the parameter value by gamma every step.
+
+        Based on the closed form of ExponentialLR from PyTorch:
+        https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L457
+
+        Args:
+            initial_value (float): Starting value of the parameter.
+            gamma (float): Multiplicative factor of parameter value decay.
+            current_step (int): Current step index.
+
+        Returns:
+            float: new parameter value
+        """
+        return initial_value * gamma ** current_step
+
+    @staticmethod
+    def _step(initial_value: float, gamma: float, step_size: int, current_step: int) -> float:
+        """
+        Decays the parameter value by gamma every step_size steps.
+
+        Based on StepLR from PyTorch:
+        https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L377
+
+        Args:
+            initial_value (float): Starting value of the parameter.
+            gamma (float): Multiplicative factor of parameter value decay.
+            step_size (int): Period of parameter value decay.
+            current_step (int): Current step index.
+
+        Returns:
+            float: new parameter value
+        """
+        return initial_value * gamma ** (current_step // step_size)
+
+    @staticmethod
+    def _multistep(initial_value: float, gamma: float, milestones: List[int], current_step: int) -> float:
+        """
+        Decays the parameter value by gamma once the number of steps reaches one of the milestones.
+
+        Based on MultiStepLR from PyTorch:
+        https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L424
+
+        Args:
+            initial_value (float): Starting value of the parameter.
+            gamma (float): Multiplicative factor of parameter value decay.
+            milestones (List[int]): List of step indices. Must be increasing.
+            current_step (int): Current step index.
+
+        Returns:
+            float: new parameter value
+        """
+        return initial_value * gamma ** bisect_right(milestones, current_step)
diff --git a/tests/min_tests.py b/tests/min_tests.py
index 9b96e7eaab..98f6d822a7 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -47,6 +47,7 @@ def run_testsuit():
         "test_handler_prob_map_producer",
         "test_handler_rocauc",
         "test_handler_rocauc_dist",
+        "test_handler_parameter_scheduler",
         "test_handler_segmentation_saver",
         "test_handler_smartcache",
         "test_handler_stats",
diff --git a/tests/test_handler_parameter_scheduler.py b/tests/test_handler_parameter_scheduler.py
new file mode 100644
index 0000000000..5b3e845ace
--- /dev/null
+++ b/tests/test_handler_parameter_scheduler.py
@@ -0,0 +1,123 @@
+import unittest
+
+import torch
+from ignite.engine import Engine, Events
+from torch.nn import Module
+
+from monai.handlers.parameter_scheduler import ParamSchedulerHandler
+
+
+class ToyNet(Module):
+    def __init__(self, value):
+        super(ToyNet, self).__init__()
+        self.value = value
+
+    def forward(self, input):
+        return input
+
+    def get_value(self):
+        return self.value
+
+    def set_value(self, value):
+        self.value = value
+
+
+class TestHandlerParameterScheduler(unittest.TestCase):
+    def test_linear_scheduler(self):
+        # Testing step_constant
+        net = ToyNet(value=-1)
+        engine = Engine(lambda e, b: None)
+        ParamSchedulerHandler(
+            parameter_setter=net.set_value,
+            value_calculator="linear",
+            vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10},
+            epoch_level=True,
+            event=Events.EPOCH_COMPLETED,
+        ).attach(engine)
+        engine.run([0] * 8, max_epochs=2)
+        torch.testing.assert_allclose(net.get_value(), 0)
+
+        # Testing linear increase
+        net = ToyNet(value=-1)
+        engine = Engine(lambda e, b: None)
+        ParamSchedulerHandler(
+            parameter_setter=net.set_value,
+            value_calculator="linear",
+            vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10},
+            epoch_level=True,
+            event=Events.EPOCH_COMPLETED,
+        ).attach(engine)
+        engine.run([0] * 8, max_epochs=3)
+        torch.testing.assert_allclose(net.get_value(), 3.333333, atol=0.001, rtol=0.0)
+
+        # Testing max_value
+        net = ToyNet(value=-1)
+        engine = Engine(lambda e, b: None)
+        ParamSchedulerHandler(
+            parameter_setter=net.set_value,
+            value_calculator="linear",
+            vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10},
+            epoch_level=True,
+            event=Events.EPOCH_COMPLETED,
+        ).attach(engine)
+        engine.run([0] * 8, max_epochs=10)
+        torch.testing.assert_allclose(net.get_value(), 10)
+
+    def test_exponential_scheduler(self):
+        net = ToyNet(value=-1)
+        engine = Engine(lambda e, b: None)
+        ParamSchedulerHandler(
+            parameter_setter=net.set_value,
+            value_calculator="exponential",
+            vc_kwargs={"initial_value": 10, "gamma": 0.99},
+            epoch_level=True,
+            event=Events.EPOCH_COMPLETED,
+        ).attach(engine)
+        engine.run([0] * 8, max_epochs=2)
+        torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)
+
+    def test_step_scheduler(self):
+        net = ToyNet(value=-1)
+        engine = Engine(lambda e, b: None)
+        ParamSchedulerHandler(
+            parameter_setter=net.set_value,
+            value_calculator="step",
+            vc_kwargs={"initial_value": 10, "gamma": 0.99, "step_size": 5},
+            epoch_level=True,
+            event=Events.EPOCH_COMPLETED,
+        ).attach(engine)
+        engine.run([0] * 8, max_epochs=10)
+        torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)
+
+    def test_multistep_scheduler(self):
+        net = ToyNet(value=-1)
+        engine = Engine(lambda e, b: None)
+        ParamSchedulerHandler(
+            parameter_setter=net.set_value,
+            value_calculator="multistep",
+            vc_kwargs={"initial_value": 10, "gamma": 0.99, "milestones": [3, 6]},
+            epoch_level=True,
+            event=Events.EPOCH_COMPLETED,
+        ).attach(engine)
+        engine.run([0] * 8, max_epochs=10)
+        torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)
+
+    def test_custom_scheduler(self):
+        def custom_logic(initial_value, gamma, current_step):
+            return initial_value * gamma ** (current_step % 9)
+
+        net = ToyNet(value=-1)
+        engine = Engine(lambda e, b: None)
+        ParamSchedulerHandler(
+            parameter_setter=net.set_value,
+            value_calculator=custom_logic,
+            vc_kwargs={"initial_value": 10, "gamma": 0.99},
+            epoch_level=True,
+            event=Events.EPOCH_COMPLETED,
+        ).attach(engine)
+        engine.run([0] * 8, max_epochs=2)
+        torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)
+
+
+if __name__ == "__main__":
+    unittest.main()
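
For reference, here is a minimal usage sketch, not part of the patch itself, showing how the new handler can be wired into an ignite trainer: it warms up a loss weight linearly per iteration. The LossWithWeight wrapper and its set_weight setter are hypothetical stand-ins; any callable that stores the scheduled value works as parameter_setter.

# Usage sketch (assumption: illustrative only, not included in this patch).
# Linearly warms up a loss weight from 0 to 1 between iterations 10 and 100.
from ignite.engine import Engine, Events

from monai.handlers import ParamSchedulerHandler


class LossWithWeight:
    def __init__(self):
        self.weight = 0.0  # the value being scheduled

    def set_weight(self, value):
        self.weight = value


loss = LossWithWeight()
trainer = Engine(lambda engine, batch: None)  # stand-in for a real training step

ParamSchedulerHandler(
    parameter_setter=loss.set_weight,
    value_calculator="linear",
    vc_kwargs={"initial_value": 0.0, "step_constant": 10, "step_max_value": 100, "max_value": 1.0},
    epoch_level=False,  # step once per iteration
    event=Events.ITERATION_COMPLETED,  # the default event, spelled out for clarity
).attach(trainer)

trainer.run([0] * 20, max_epochs=10)  # 200 iterations in total
print(loss.weight)  # 1.0, since 200 iterations > step_max_value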