From 0d0d7a46b0ce5ddcbf9e3479638b4255694d00a5 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 11 Dec 2021 00:53:14 +0800 Subject: [PATCH 1/3] [DLMED] add Iteration base class Signed-off-by: Nic Ma --- monai/apps/deepgrow/interaction.py | 8 +++++--- monai/engines/__init__.py | 1 + monai/engines/utils.py | 16 +++++++++++++++- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index 81e82c958d..c25886a188 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -13,13 +13,13 @@ import torch from monai.data import decollate_batch, list_data_collate -from monai.engines import SupervisedEvaluator, SupervisedTrainer +from monai.engines import Iteration, SupervisedEvaluator, SupervisedTrainer from monai.engines.utils import IterationEvents from monai.transforms import Compose from monai.utils.enums import CommonKeys -class Interaction: +class Interaction(Iteration): """ Ignite process_function used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation. This implementation is based on: @@ -51,7 +51,9 @@ def __init__( self.train = train self.key_probability = key_probability - def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchdata: Dict[str, torch.Tensor]): + def __call__( # type: ignore + self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchdata: Dict[str, torch.Tensor] + ): if batchdata is None: raise ValueError("Must provide batch data for current iteration.") diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index f24bc0fc37..439e48df70 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -14,6 +14,7 @@ from .trainer import GanTrainer, SupervisedTrainer, Trainer from .utils import ( GanKeys, + Iteration, IterationEvents, PrepareBatch, PrepareBatchDefault, diff --git a/monai/engines/utils.py b/monai/engines/utils.py index f06a3bf255..482058974b 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -20,8 +20,9 @@ from monai.utils.enums import CommonKeys if TYPE_CHECKING: - from ignite.engine import EventEnum + from ignite.engine import Engine, EventEnum else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") __all__ = [ @@ -32,6 +33,7 @@ "PrepareBatch", "PrepareBatchDefault", "PrepareBatchExtraInput", + "Iteration", "default_make_latent", "engine_apply_transform", "default_metric_cmp_fn", @@ -200,6 +202,18 @@ def _get_data(key: str): return image, label, tuple(args), kwargs +class Iteration(ABC): + """ + Base class of customized iteration in the trainer or evaluator workflows. + It takes ignite Engine and the data of current batch as input. 
+ + """ + + @abstractmethod + def __call__(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + def default_make_latent( num_latents: int, latent_size: int, device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False ) -> torch.Tensor: From abb656fc242e66a4ba154a32fbf1a9ed4c0b3b12 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 13 Dec 2021 06:19:23 +0800 Subject: [PATCH 2/3] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/apps/deepgrow/interaction.py | 9 ++++----- monai/engines/__init__.py | 1 - monai/engines/evaluator.py | 32 ++++++++++++++++++++---------- monai/engines/trainer.py | 24 ++++++++++++++-------- monai/engines/utils.py | 16 +-------------- monai/engines/workflow.py | 12 +++++++---- 6 files changed, 51 insertions(+), 43 deletions(-) diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index c25886a188..bb349f02fb 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -13,15 +13,16 @@ import torch from monai.data import decollate_batch, list_data_collate -from monai.engines import Iteration, SupervisedEvaluator, SupervisedTrainer +from monai.engines import SupervisedEvaluator, SupervisedTrainer from monai.engines.utils import IterationEvents from monai.transforms import Compose from monai.utils.enums import CommonKeys -class Interaction(Iteration): +class Interaction: """ Ignite process_function used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation. + For more details please refer to: https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. This implementation is based on: Sakinis et al., Interactive segmentation of medical images through @@ -51,9 +52,7 @@ def __init__( self.train = train self.key_probability = key_probability - def __call__( # type: ignore - self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchdata: Dict[str, torch.Tensor] - ): + def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchdata: Dict[str, torch.Tensor]): if batchdata is None: raise ValueError("Must provide batch data for current iteration.") diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index 439e48df70..f24bc0fc37 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -14,7 +14,6 @@ from .trainer import GanTrainer, SupervisedTrainer, Trainer from .utils import ( GanKeys, - Iteration, IterationEvents, PrepareBatch, PrepareBatchDefault, diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 3ddbca45bf..325dc4b37d 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch from torch.utils.data import DataLoader @@ -45,9 +45,13 @@ class Evaluator(Workflow): epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for current iteration. 
+ prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_val_metric: compute metric when every iteration completed, and save average value to @@ -80,7 +84,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, @@ -147,9 +151,13 @@ class SupervisedEvaluator(Evaluator): epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for current iteration. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. @@ -184,7 +192,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, inferer: Optional[Inferer] = None, postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, @@ -275,9 +283,13 @@ class EnsembleEvaluator(Evaluator): the length must exactly match the number of networks. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for current iteration. 
+ prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. @@ -313,7 +325,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, inferer: Optional[Inferer] = None, postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index c7a8b49e30..ab6c77642e 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch from torch.optim.optimizer import Optimizer @@ -75,9 +75,13 @@ class SupervisedTrainer(Trainer): epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for current iteration. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. 
@@ -115,7 +119,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, inferer: Optional[Inferer] = None, postprocessing: Optional[Transform] = None, key_train_metric: Optional[Dict[str, Metric]] = None, @@ -241,12 +245,16 @@ class GanTrainer(Trainer): non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. d_prepare_batch: callback function to prepare batchdata for D inferer. - Defaults to return ``GanKeys.REALS`` in batchdata dict. + Defaults to return ``GanKeys.REALS`` in batchdata dict. for more details please refer to: + https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. g_prepare_batch: callback function to create batch of latent input for G inferer. - Defaults to return random latents. + Defaults to return random latents. for more details please refer to: + https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. g_update_latents: Calculate G loss with new latent codes. Defaults to ``True``. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_train_metric: compute metric when every iteration completed, and save average value to @@ -286,7 +294,7 @@ def __init__( d_prepare_batch: Callable = default_prepare_batch, g_prepare_batch: Callable = default_make_latent, g_update_latents: bool = True, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, postprocessing: Optional[Transform] = None, key_train_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 482058974b..f06a3bf255 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -20,9 +20,8 @@ from monai.utils.enums import CommonKeys if TYPE_CHECKING: - from ignite.engine import Engine, EventEnum + from ignite.engine import EventEnum else: - Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") __all__ = [ @@ -33,7 +32,6 @@ "PrepareBatch", "PrepareBatchDefault", "PrepareBatchExtraInput", - "Iteration", "default_make_latent", "engine_apply_transform", "default_metric_cmp_fn", @@ -202,18 +200,6 @@ def _get_data(key: str): return image, label, tuple(args), kwargs -class Iteration(ABC): - """ - Base class of customized iteration in the trainer or evaluator workflows. - It takes ignite Engine and the data of current batch as input. 
- - """ - - @abstractmethod - def __call__(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): - raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") - - def default_make_latent( num_latents: int, latent_size: int, device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False ) -> torch.Tensor: diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index f6f0a6a059..79fcf2dadf 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -11,7 +11,7 @@ import warnings from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union import torch import torch.distributed as dist @@ -67,9 +67,13 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona epoch_length: number of iterations for one epoch, default to `len(data_loader)`. non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to the host. For other cases, this argument has no effect. - prepare_batch: function to parse image and label for every iteration. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. iteration_update: the callable function for every iteration, expect to accept `engine` - and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_metric: compute metric when every iteration completed, and save average value to @@ -107,7 +111,7 @@ def __init__( epoch_length: Optional[int] = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable] = None, + iteration_update: Optional[Callable[[Engine, Any], Any]] = None, postprocessing: Optional[Callable] = None, key_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, From 28bb5f1e97ae434f1a24beb092f2ed765e49ddd1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 13 Dec 2021 22:53:09 +0800 Subject: [PATCH 3/3] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/apps/deepgrow/interaction.py | 2 +- monai/engines/evaluator.py | 12 ++++++------ monai/engines/trainer.py | 10 +++++----- monai/engines/workflow.py | 4 ++-- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/monai/apps/deepgrow/interaction.py b/monai/apps/deepgrow/interaction.py index bb349f02fb..692de633aa 100644 --- a/monai/apps/deepgrow/interaction.py +++ b/monai/apps/deepgrow/interaction.py @@ -22,7 +22,7 @@ class Interaction: """ Ignite process_function used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation. - For more details please refer to: https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. 
+ For more details please refer to: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. This implementation is based on: Sakinis et al., Interactive segmentation of medical images through diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 325dc4b37d..27aa805404 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -47,11 +47,11 @@ class Evaluator(Workflow): with respect to the host. For other cases, this argument has no effect. prepare_batch: function to parse expected data (usually `image`, `label` and other network args) from `engine.state.batch` for every iteration, for more details please refer to: - https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. if not provided, use `self._iteration()` instead. for more details please refer to: - https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_val_metric: compute metric when every iteration completed, and save average value to @@ -153,11 +153,11 @@ class SupervisedEvaluator(Evaluator): with respect to the host. For other cases, this argument has no effect. prepare_batch: function to parse expected data (usually `image`, `label` and other network args) from `engine.state.batch` for every iteration, for more details please refer to: - https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. if not provided, use `self._iteration()` instead. for more details please refer to: - https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. @@ -285,11 +285,11 @@ class EnsembleEvaluator(Evaluator): with respect to the host. For other cases, this argument has no effect. prepare_batch: function to parse expected data (usually `image`, `label` and other network args) from `engine.state.batch` for every iteration, for more details please refer to: - https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. if not provided, use `self._iteration()` instead. for more details please refer to: - https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. 
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index ab6c77642e..c6521db854 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -77,11 +77,11 @@ class SupervisedTrainer(Trainer): with respect to the host. For other cases, this argument has no effect. prepare_batch: function to parse expected data (usually `image`, `label` and other network args) from `engine.state.batch` for every iteration, for more details please refer to: - https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. if not provided, use `self._iteration()` instead. for more details please refer to: - https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. @@ -246,15 +246,15 @@ class GanTrainer(Trainer): with respect to the host. For other cases, this argument has no effect. d_prepare_batch: callback function to prepare batchdata for D inferer. Defaults to return ``GanKeys.REALS`` in batchdata dict. for more details please refer to: - https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. g_prepare_batch: callback function to create batch of latent input for G inferer. Defaults to return random latents. for more details please refer to: - https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. g_update_latents: Calculate G loss with new latent codes. Defaults to ``True``. iteration_update: the callable function for every iteration, expect to accept `engine` and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. if not provided, use `self._iteration()` instead. for more details please refer to: - https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_train_metric: compute metric when every iteration completed, and save average value to diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 79fcf2dadf..dbe1c36b4f 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -69,11 +69,11 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona with respect to the host. For other cases, this argument has no effect. 
prepare_batch: function to parse expected data (usually `image`, `label` and other network args) from `engine.state.batch` for every iteration, for more details please refer to: - https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/__init__.py#L33. + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. iteration_update: the callable function for every iteration, expect to accept `engine` and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. if not provided, use `self._iteration()` instead. for more details please refer to: - https://github.com/pytorch/ignite/blob/v0.4.7/ignite/engine/engine.py#L831. + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_metric: compute metric when every iteration completed, and save average value to
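
The `iteration_update` and `prepare_batch` arguments documented in the hunks above follow the ignite process_function contract: the callable receives the engine and `engine.state.batch`, and whatever it returns is stored in `engine.state.output`. Below is a minimal, self-contained sketch (not part of these patches) of such a custom iteration callable wired into `SupervisedTrainer`. The toy network, synthetic dataset, and names like `my_iteration` are illustrative assumptions only, and the sketch deliberately omits the AMP handling and `IterationEvents` firing that MONAI's default `_iteration` performs.

    from typing import Dict

    import torch
    from torch.utils.data import DataLoader, Dataset

    from ignite.engine import Engine
    from monai.engines import SupervisedTrainer
    from monai.utils.enums import CommonKeys


    class _RandDataset(Dataset):
        # tiny synthetic dataset, purely for illustration
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {
                CommonKeys.IMAGE: torch.rand(1, 16, 16),
                CommonKeys.LABEL: torch.randint(0, 2, (1, 16, 16)).float(),
            }


    def my_iteration(engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict:
        # parse `engine.state.batch` with the engine's configured prepare_batch,
        # mirroring what the default `_iteration` does
        inputs, targets = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        engine.network.train()
        engine.optimizer.zero_grad()
        preds = engine.inferer(inputs, engine.network)
        loss = engine.loss_function(preds, targets)
        loss.backward()
        engine.optimizer.step()
        # the returned dict becomes `engine.state.output` for downstream handlers
        return {
            CommonKeys.IMAGE: inputs,
            CommonKeys.LABEL: targets,
            CommonKeys.PRED: preds,
            CommonKeys.LOSS: loss.item(),
        }


    net = torch.nn.Conv2d(1, 1, 3, padding=1)
    trainer = SupervisedTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=DataLoader(_RandDataset(), batch_size=2),
        network=net,
        optimizer=torch.optim.Adam(net.parameters(), lr=1e-3),
        loss_function=torch.nn.BCEWithLogitsLoss(),
        iteration_update=my_iteration,  # used as the ignite process_function
    )
    trainer.run()

The deepgrow `Interaction` class touched in these patches plays the same role: it is passed as `iteration_update`, receives the engine and the current batch, simulates clicks, and then delegates to `engine._iteration(...)` for the actual forward/backward pass.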