diff --git a/docs/source/engines.rst b/docs/source/engines.rst index cc0ec3c659..90c7be2a1e 100644 --- a/docs/source/engines.rst +++ b/docs/source/engines.rst @@ -15,16 +15,18 @@ Multi-GPU data parallel Workflows --------- -.. automodule:: monai.engines.workflow -.. currentmodule:: monai.engines.workflow +.. currentmodule:: monai.engines + +`BaseWorkflow` +~~~~~~~~~~~~~~ +.. autoclass:: BaseWorkflow + :members: `Workflow` ~~~~~~~~~~ .. autoclass:: Workflow :members: -.. currentmodule:: monai.engines - `Trainer` ~~~~~~~~~ .. autoclass:: Trainer diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py index d04401829f..f24bc0fc37 100644 --- a/monai/engines/__init__.py +++ b/monai/engines/__init__.py @@ -24,3 +24,4 @@ engine_apply_transform, get_devices_spec, ) +from .workflow import BaseWorkflow, Workflow diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 48e2dc1774..f6f0a6a059 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -10,6 +10,7 @@ # limitations under the License. import warnings +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Union import torch @@ -37,6 +38,18 @@ EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") +class BaseWorkflow(ABC): + """ + Base class for any MONAI style workflow. + `run()` is designed to execute the train, evaluation or inference logic. + + """ + + @abstractmethod + def run(self, *args, **kwargs): + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optional_import """ Workflow defines the core work process inheriting from Ignite engine.