Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 67 additions & 68 deletions monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence

import torch
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -84,23 +86,23 @@ class Evaluator(Workflow):
def __init__(
self,
device: torch.device,
val_data_loader: Union[Iterable, DataLoader],
epoch_length: Optional[int] = None,
val_data_loader: Iterable | DataLoader,
epoch_length: int | None = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
iteration_update: Callable[[Engine, Any], Any] | None = None,
postprocessing: Transform | None = None,
key_val_metric: dict[str, Metric] | None = None,
additional_metrics: dict[str, Metric] | None = None,
metric_cmp_fn: Callable = default_metric_cmp_fn,
val_handlers: Optional[Sequence] = None,
val_handlers: Sequence | None = None,
amp: bool = False,
mode: Union[ForwardMode, str] = ForwardMode.EVAL,
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
mode: ForwardMode | str = ForwardMode.EVAL,
event_names: list[str | EventEnum] | None = None,
event_to_attr: dict | None = None,
decollate: bool = True,
to_kwargs: Optional[Dict] = None,
amp_kwargs: Optional[Dict] = None,
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
) -> None:
super().__init__(
device=device,
Expand Down Expand Up @@ -144,7 +146,7 @@ def run(self, global_epoch: int = 1) -> None:
self.state.iteration = 0
super().run()

def get_validation_stats(self) -> Dict[str, float]:
def get_validation_stats(self) -> dict[str, float]:
return {"best_validation_metric": self.state.best_metric, "best_validation_epoch": self.state.best_metric_epoch}


Expand Down Expand Up @@ -199,25 +201,25 @@ class SupervisedEvaluator(Evaluator):
def __init__(
self,
device: torch.device,
val_data_loader: Union[Iterable, DataLoader],
val_data_loader: Iterable | DataLoader,
network: torch.nn.Module,
epoch_length: Optional[int] = None,
epoch_length: int | None = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
inferer: Optional[Inferer] = None,
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
iteration_update: Callable[[Engine, Any], Any] | None = None,
inferer: Inferer | None = None,
postprocessing: Transform | None = None,
key_val_metric: dict[str, Metric] | None = None,
additional_metrics: dict[str, Metric] | None = None,
metric_cmp_fn: Callable = default_metric_cmp_fn,
val_handlers: Optional[Sequence] = None,
val_handlers: Sequence | None = None,
amp: bool = False,
mode: Union[ForwardMode, str] = ForwardMode.EVAL,
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
mode: ForwardMode | str = ForwardMode.EVAL,
event_names: list[str | EventEnum] | None = None,
event_to_attr: dict | None = None,
decollate: bool = True,
to_kwargs: Optional[Dict] = None,
amp_kwargs: Optional[Dict] = None,
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
) -> None:
super().__init__(
device=device,
Expand All @@ -243,7 +245,7 @@ def __init__(
self.network = network
self.inferer = SimpleInferer() if inferer is None else inferer

def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]):
"""
callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
Return below items in a dictionary:
Expand All @@ -252,7 +254,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
- PRED: prediction result of model.

Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
engine: `SupervisedEvaluator` to execute operation for an iteration.
batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

Raises:
Expand All @@ -261,26 +263,25 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
"""
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")
batch = self.prepare_batch(
batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore
)
batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
if len(batch) == 2:
inputs, targets = batch
args: Tuple = ()
kwargs: Dict = {}
args: tuple = ()
kwargs: dict = {}
else:
inputs, targets, args, kwargs = batch

# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

# execute forward computation
with self.mode(self.network):
if self.amp:
with torch.cuda.amp.autocast(**engine.amp_kwargs): # type: ignore
engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore
with engine.mode(engine.network):

if engine.amp:
with torch.cuda.amp.autocast(**engine.amp_kwargs):
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
else:
engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore
engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
engine.fire_event(IterationEvents.MODEL_COMPLETED)

Expand Down Expand Up @@ -342,26 +343,26 @@ class EnsembleEvaluator(Evaluator):
def __init__(
self,
device: torch.device,
val_data_loader: Union[Iterable, DataLoader],
val_data_loader: Iterable | DataLoader,
networks: Sequence[torch.nn.Module],
pred_keys: Optional[KeysCollection] = None,
epoch_length: Optional[int] = None,
pred_keys: KeysCollection | None = None,
epoch_length: int | None = None,
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
inferer: Optional[Inferer] = None,
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
iteration_update: Callable[[Engine, Any], Any] | None = None,
inferer: Inferer | None = None,
postprocessing: Transform | None = None,
key_val_metric: dict[str, Metric] | None = None,
additional_metrics: dict[str, Metric] | None = None,
metric_cmp_fn: Callable = default_metric_cmp_fn,
val_handlers: Optional[Sequence] = None,
val_handlers: Sequence | None = None,
amp: bool = False,
mode: Union[ForwardMode, str] = ForwardMode.EVAL,
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
mode: ForwardMode | str = ForwardMode.EVAL,
event_names: list[str | EventEnum] | None = None,
event_to_attr: dict | None = None,
decollate: bool = True,
to_kwargs: Optional[Dict] = None,
amp_kwargs: Optional[Dict] = None,
to_kwargs: dict | None = None,
amp_kwargs: dict | None = None,
) -> None:
super().__init__(
device=device,
Expand Down Expand Up @@ -392,7 +393,7 @@ def __init__(
raise ValueError("length of `pred_keys` must be same as the length of `networks`.")
self.inferer = SimpleInferer() if inferer is None else inferer

def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
def _iteration(self, engine: EnsembleEvaluator, batchdata: dict[str, torch.Tensor]):
"""
callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
Return below items in a dictionary:
Expand All @@ -404,7 +405,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
- pred_keys[N]: prediction result of network N.

Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
engine: `EnsembleEvaluator` to execute operation for an iteration.
batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.

Raises:
Expand All @@ -413,31 +414,29 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
"""
if batchdata is None:
raise ValueError("Must provide batch data for current iteration.")
batch = self.prepare_batch(
batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore
)
batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
if len(batch) == 2:
inputs, targets = batch
args: Tuple = ()
kwargs: Dict = {}
args: tuple = ()
kwargs: dict = {}
else:
inputs, targets, args, kwargs = batch

# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

for idx, network in enumerate(self.networks):
with self.mode(network):
if self.amp:
with torch.cuda.amp.autocast(**engine.amp_kwargs): # type: ignore
for idx, network in enumerate(engine.networks):
with engine.mode(network):
if engine.amp:
with torch.cuda.amp.autocast(**engine.amp_kwargs):
if isinstance(engine.state.output, dict):
engine.state.output.update(
{self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}
{engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}
)
else:
if isinstance(engine.state.output, dict):
engine.state.output.update(
{self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}
{engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}
)
engine.fire_event(IterationEvents.FORWARD_COMPLETED)
engine.fire_event(IterationEvents.MODEL_COMPLETED)
Expand Down
Loading