Skip to content
1 change: 1 addition & 0 deletions monai/apps/deepgrow/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
class Interaction:
"""
Ignite process_function used to introduce interactions (simulation of clicks) for Deepgrow Training/Evaluation.
For more details please refer to: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
This implementation is based on:

Sakinis et al., Interactive segmentation of medical images through
Expand Down
32 changes: 22 additions & 10 deletions monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -45,9 +45,13 @@ class Evaluator(Workflow):
epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function to parse image and label for current iteration.
prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
from `engine.state.batch` for every iteration, for more details please refer to:
https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
iteration_update: the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
if not provided, use `self._iteration()` instead. for more details please refer to:
https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_val_metric: compute metric when every iteration completed, and save average value to
Expand Down Expand Up @@ -80,7 +84,7 @@ def __init__(
epoch_length: Optional[int] = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable] = None,
iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
Expand Down Expand Up @@ -147,9 +151,13 @@ class SupervisedEvaluator(Evaluator):
epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function to parse image and label for current iteration.
prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
from `engine.state.batch` for every iteration, for more details please refer to:
https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
iteration_update: the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
if not provided, use `self._iteration()` instead. for more details please refer to:
https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
Expand Down Expand Up @@ -184,7 +192,7 @@ def __init__(
epoch_length: Optional[int] = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable] = None,
iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
inferer: Optional[Inferer] = None,
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
Expand Down Expand Up @@ -275,9 +283,13 @@ class EnsembleEvaluator(Evaluator):
the length must exactly match the number of networks.
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function to parse image and label for current iteration.
prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
from `engine.state.batch` for every iteration, for more details please refer to:
https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
iteration_update: the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
if not provided, use `self._iteration()` instead. for more details please refer to:
https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
Expand Down Expand Up @@ -313,7 +325,7 @@ def __init__(
epoch_length: Optional[int] = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable] = None,
iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
inferer: Optional[Inferer] = None,
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
Expand Down
24 changes: 16 additions & 8 deletions monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
from torch.optim.optimizer import Optimizer
Expand Down Expand Up @@ -75,9 +75,13 @@ class SupervisedTrainer(Trainer):
epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function to parse image and label for current iteration.
prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
from `engine.state.batch` for every iteration, for more details please refer to:
https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
iteration_update: the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
if not provided, use `self._iteration()` instead. for more details please refer to:
https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
Expand Down Expand Up @@ -115,7 +119,7 @@ def __init__(
epoch_length: Optional[int] = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable] = None,
iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
inferer: Optional[Inferer] = None,
postprocessing: Optional[Transform] = None,
key_train_metric: Optional[Dict[str, Metric]] = None,
Expand Down Expand Up @@ -241,12 +245,16 @@ class GanTrainer(Trainer):
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
d_prepare_batch: callback function to prepare batchdata for D inferer.
Defaults to return ``GanKeys.REALS`` in batchdata dict.
Defaults to return ``GanKeys.REALS`` in batchdata dict. for more details please refer to:
https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
g_prepare_batch: callback function to create batch of latent input for G inferer.
Defaults to return random latents.
Defaults to return random latents. for more details please refer to:
https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
g_update_latents: Calculate G loss with new latent codes. Defaults to ``True``.
iteration_update: the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
if not provided, use `self._iteration()` instead. for more details please refer to:
https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_train_metric: compute metric when every iteration completed, and save average value to
Expand Down Expand Up @@ -286,7 +294,7 @@ def __init__(
d_prepare_batch: Callable = default_prepare_batch,
g_prepare_batch: Callable = default_make_latent,
g_update_latents: bool = True,
iteration_update: Optional[Callable] = None,
iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
postprocessing: Optional[Transform] = None,
key_train_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
Expand Down
12 changes: 8 additions & 4 deletions monai/engines/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import warnings
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -67,9 +67,13 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona
epoch_length: number of iterations for one epoch, default to `len(data_loader)`.
non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function to parse image and label for every iteration.
prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
from `engine.state.batch` for every iteration, for more details please refer to:
https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
iteration_update: the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
if not provided, use `self._iteration()` instead. for more details please refer to:
https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_metric: compute metric when every iteration completed, and save average value to
Expand Down Expand Up @@ -107,7 +111,7 @@ def __init__(
epoch_length: Optional[int] = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable] = None,
iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
postprocessing: Optional[Callable] = None,
key_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
Expand Down