diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py
index c69c0e0547..b058935f67 100644
--- a/monai/engines/evaluator.py
+++ b/monai/engines/evaluator.py
@@ -9,7 +9,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
 
 import torch
 from torch.utils.data import DataLoader
@@ -84,23 +86,23 @@ class Evaluator(Workflow):
     def __init__(
         self,
         device: torch.device,
-        val_data_loader: Union[Iterable, DataLoader],
-        epoch_length: Optional[int] = None,
+        val_data_loader: Iterable | DataLoader,
+        epoch_length: int | None = None,
         non_blocking: bool = False,
         prepare_batch: Callable = default_prepare_batch,
-        iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
-        postprocessing: Optional[Transform] = None,
-        key_val_metric: Optional[Dict[str, Metric]] = None,
-        additional_metrics: Optional[Dict[str, Metric]] = None,
+        iteration_update: Callable[[Engine, Any], Any] | None = None,
+        postprocessing: Transform | None = None,
+        key_val_metric: dict[str, Metric] | None = None,
+        additional_metrics: dict[str, Metric] | None = None,
         metric_cmp_fn: Callable = default_metric_cmp_fn,
-        val_handlers: Optional[Sequence] = None,
+        val_handlers: Sequence | None = None,
         amp: bool = False,
-        mode: Union[ForwardMode, str] = ForwardMode.EVAL,
-        event_names: Optional[List[Union[str, EventEnum]]] = None,
-        event_to_attr: Optional[dict] = None,
+        mode: ForwardMode | str = ForwardMode.EVAL,
+        event_names: list[str | EventEnum] | None = None,
+        event_to_attr: dict | None = None,
         decollate: bool = True,
-        to_kwargs: Optional[Dict] = None,
-        amp_kwargs: Optional[Dict] = None,
+        to_kwargs: dict | None = None,
+        amp_kwargs: dict | None = None,
     ) -> None:
         super().__init__(
             device=device,
@@ -144,7 +146,7 @@ def run(self, global_epoch: int = 1) -> None:
         self.state.iteration = 0
         super().run()
 
-    def get_validation_stats(self) -> Dict[str, float]:
+    def get_validation_stats(self) -> dict[str, float]:
         return {"best_validation_metric": self.state.best_metric, "best_validation_epoch": self.state.best_metric_epoch}
@@ -199,25 +201,25 @@ class SupervisedEvaluator(Evaluator):
     def __init__(
         self,
         device: torch.device,
-        val_data_loader: Union[Iterable, DataLoader],
+        val_data_loader: Iterable | DataLoader,
         network: torch.nn.Module,
-        epoch_length: Optional[int] = None,
+        epoch_length: int | None = None,
         non_blocking: bool = False,
         prepare_batch: Callable = default_prepare_batch,
-        iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
-        inferer: Optional[Inferer] = None,
-        postprocessing: Optional[Transform] = None,
-        key_val_metric: Optional[Dict[str, Metric]] = None,
-        additional_metrics: Optional[Dict[str, Metric]] = None,
+        iteration_update: Callable[[Engine, Any], Any] | None = None,
+        inferer: Inferer | None = None,
+        postprocessing: Transform | None = None,
+        key_val_metric: dict[str, Metric] | None = None,
+        additional_metrics: dict[str, Metric] | None = None,
         metric_cmp_fn: Callable = default_metric_cmp_fn,
-        val_handlers: Optional[Sequence] = None,
+        val_handlers: Sequence | None = None,
         amp: bool = False,
-        mode: Union[ForwardMode, str] = ForwardMode.EVAL,
-        event_names: Optional[List[Union[str, EventEnum]]] = None,
-        event_to_attr: Optional[dict] = None,
+        mode: ForwardMode | str = ForwardMode.EVAL,
+        event_names: list[str | EventEnum] | None = None,
+        event_to_attr: dict | None = None,
         decollate: bool = True,
-        to_kwargs: Optional[Dict] = None,
-        amp_kwargs: Optional[Dict] = None,
+        to_kwargs: dict | None = None,
+        amp_kwargs: dict | None = None,
     ) -> None:
         super().__init__(
             device=device,
@@ -243,7 +245,7 @@ def __init__(
         self.network = network
         self.inferer = SimpleInferer() if inferer is None else inferer
 
-    def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
+    def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]):
         """
         callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
         Return below items in a dictionary:
@@ -252,7 +254,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
             - PRED: prediction result of model.
 
         Args:
-            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+            engine: `SupervisedEvaluator` to execute operation for an iteration.
             batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
 
         Raises:
@@ -261,26 +263,25 @@
         """
         if batchdata is None:
             raise ValueError("Must provide batch data for current iteration.")
-        batch = self.prepare_batch(
-            batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs  # type: ignore
-        )
+        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
         if len(batch) == 2:
             inputs, targets = batch
-            args: Tuple = ()
-            kwargs: Dict = {}
+            args: tuple = ()
+            kwargs: dict = {}
         else:
             inputs, targets, args, kwargs = batch
         # put iteration outputs into engine.state
-        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}  # type: ignore
+        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
         # execute forward computation
-        with self.mode(self.network):
-            if self.amp:
-                with torch.cuda.amp.autocast(**engine.amp_kwargs):  # type: ignore
-                    engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)  # type: ignore
+        with engine.mode(engine.network):
+
+            if engine.amp:
+                with torch.cuda.amp.autocast(**engine.amp_kwargs):
+                    engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
             else:
-                engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)  # type: ignore
+                engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
         engine.fire_event(IterationEvents.FORWARD_COMPLETED)
         engine.fire_event(IterationEvents.MODEL_COMPLETED)
@@ -342,26 +343,26 @@ class EnsembleEvaluator(Evaluator):
     def __init__(
         self,
         device: torch.device,
-        val_data_loader: Union[Iterable, DataLoader],
+        val_data_loader: Iterable | DataLoader,
         networks: Sequence[torch.nn.Module],
-        pred_keys: Optional[KeysCollection] = None,
-        epoch_length: Optional[int] = None,
+        pred_keys: KeysCollection | None = None,
+        epoch_length: int | None = None,
         non_blocking: bool = False,
         prepare_batch: Callable = default_prepare_batch,
-        iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
-        inferer: Optional[Inferer] = None,
-        postprocessing: Optional[Transform] = None,
-        key_val_metric: Optional[Dict[str, Metric]] = None,
-        additional_metrics: Optional[Dict[str, Metric]] = None,
+        iteration_update: Callable[[Engine, Any], Any] | None = None,
+        inferer: Inferer | None = None,
+        postprocessing: Transform | None = None,
+        key_val_metric: dict[str, Metric] | None = None,
+        additional_metrics: dict[str, Metric] | None = None,
         metric_cmp_fn: Callable = default_metric_cmp_fn,
-        val_handlers: Optional[Sequence] = None,
+        val_handlers: Sequence | None = None,
         amp: bool = False,
-        mode: Union[ForwardMode, str] = ForwardMode.EVAL,
-        event_names: Optional[List[Union[str, EventEnum]]] = None,
-        event_to_attr: Optional[dict] = None,
+        mode: ForwardMode | str = ForwardMode.EVAL,
+        event_names: list[str | EventEnum] | None = None,
+        event_to_attr: dict | None = None,
         decollate: bool = True,
-        to_kwargs: Optional[Dict] = None,
-        amp_kwargs: Optional[Dict] = None,
+        to_kwargs: dict | None = None,
+        amp_kwargs: dict | None = None,
     ) -> None:
         super().__init__(
             device=device,
@@ -392,7 +393,7 @@ def __init__(
             raise ValueError("length of `pred_keys` must be same as the length of `networks`.")
         self.inferer = SimpleInferer() if inferer is None else inferer
 
-    def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
+    def _iteration(self, engine: EnsembleEvaluator, batchdata: dict[str, torch.Tensor]):
         """
         callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
         Return below items in a dictionary:
@@ -404,7 +405,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
            - pred_keys[N]: prediction result of network N.
 
         Args:
-            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+            engine: `EnsembleEvaluator` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
 
         Raises:
@@ -413,31 +414,29 @@
         """
         if batchdata is None:
             raise ValueError("Must provide batch data for current iteration.")
-        batch = self.prepare_batch(
-            batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs  # type: ignore
-        )
+        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
         if len(batch) == 2:
             inputs, targets = batch
-            args: Tuple = ()
-            kwargs: Dict = {}
+            args: tuple = ()
+            kwargs: dict = {}
         else:
             inputs, targets, args, kwargs = batch
         # put iteration outputs into engine.state
-        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}  # type: ignore
+        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
 
-        for idx, network in enumerate(self.networks):
-            with self.mode(network):
-                if self.amp:
-                    with torch.cuda.amp.autocast(**engine.amp_kwargs):  # type: ignore
+        for idx, network in enumerate(engine.networks):
+            with engine.mode(network):
+                if engine.amp:
+                    with torch.cuda.amp.autocast(**engine.amp_kwargs):
                         if isinstance(engine.state.output, dict):
                             engine.state.output.update(
-                                {self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}
+                                {engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}
                             )
                 else:
                     if isinstance(engine.state.output, dict):
                         engine.state.output.update(
-                            {self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)}
+                            {engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)}
                        )
         engine.fire_event(IterationEvents.FORWARD_COMPLETED)
         engine.fire_event(IterationEvents.MODEL_COMPLETED)
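Both files touched by this patch add `from __future__ import annotations` while dropping `Dict`/`List`/`Optional`/`Tuple`/`Union` from the `typing` import. With postponed evaluation of annotations (PEP 563), annotations are stored as strings and never evaluated at import time, so PEP 585/604 spellings such as `dict[str, Metric] | None` remain compatible with interpreters older than Python 3.9/3.10. A minimal, self-contained sketch of the same pattern (illustrative only, not part of the patch):

    from __future__ import annotations

    # The annotation below is kept as a string at runtime, so this module
    # imports cleanly even where `dict[str, float] | None` could not be
    # evaluated as an expression.
    def get_stats(extra: dict[str, float] | None = None) -> dict[str, float]:
        return {"best_validation_metric": -1.0, **(extra or {})}

    print(get_stats({"best_validation_epoch": 3.0}))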
diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py
index 12753765ef..fbdb309bb5 100644
--- a/monai/engines/trainer.py
+++ b/monai/engines/trainer.py
@@ -9,7 +9,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
 
 import torch
 from torch.optim.optimizer import Optimizer
@@ -55,7 +57,7 @@ def run(self) -> None:
         self.scaler = torch.cuda.amp.GradScaler() if self.amp else None
         super().run()
 
-    def get_train_stats(self) -> Dict[str, float]:
+    def get_train_stats(self) -> dict[str, float]:
         return {"total_epochs": self.state.max_epochs, "total_iterations": self.state.epoch_length}
@@ -116,27 +118,27 @@ def __init__(
         self,
         device: torch.device,
         max_epochs: int,
-        train_data_loader: Union[Iterable, DataLoader],
+        train_data_loader: Iterable | DataLoader,
         network: torch.nn.Module,
         optimizer: Optimizer,
         loss_function: Callable,
-        epoch_length: Optional[int] = None,
+        epoch_length: int | None = None,
         non_blocking: bool = False,
         prepare_batch: Callable = default_prepare_batch,
-        iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
-        inferer: Optional[Inferer] = None,
-        postprocessing: Optional[Transform] = None,
-        key_train_metric: Optional[Dict[str, Metric]] = None,
-        additional_metrics: Optional[Dict[str, Metric]] = None,
+        iteration_update: Callable[[Engine, Any], Any] | None = None,
+        inferer: Inferer | None = None,
+        postprocessing: Transform | None = None,
+        key_train_metric: dict[str, Metric] | None = None,
+        additional_metrics: dict[str, Metric] | None = None,
         metric_cmp_fn: Callable = default_metric_cmp_fn,
-        train_handlers: Optional[Sequence] = None,
+        train_handlers: Sequence | None = None,
         amp: bool = False,
-        event_names: Optional[List[Union[str, EventEnum]]] = None,
-        event_to_attr: Optional[dict] = None,
+        event_names: list[str | EventEnum] | None = None,
+        event_to_attr: dict | None = None,
         decollate: bool = True,
         optim_set_to_none: bool = False,
-        to_kwargs: Optional[Dict] = None,
-        amp_kwargs: Optional[Dict] = None,
+        to_kwargs: dict | None = None,
+        amp_kwargs: dict | None = None,
     ) -> None:
         super().__init__(
             device=device,
@@ -165,7 +167,7 @@ def __init__(
         self.inferer = SimpleInferer() if inferer is None else inferer
         self.optim_set_to_none = optim_set_to_none
 
-    def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
+    def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]):
         """
         Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
         Return below items in a dictionary:
@@ -175,7 +177,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
            - LOSS: loss value computed by loss function.
 
         Args:
-            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+            engine: `SupervisedTrainer` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
 
         Raises:
@@ -184,39 +186,37 @@
         """
         if batchdata is None:
             raise ValueError("Must provide batch data for current iteration.")
-        batch = self.prepare_batch(
-            batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs  # type: ignore
-        )
+        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
         if len(batch) == 2:
             inputs, targets = batch
-            args: Tuple = ()
-            kwargs: Dict = {}
+            args: tuple = ()
+            kwargs: dict = {}
         else:
             inputs, targets, args, kwargs = batch
         # put iteration outputs into engine.state
-        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}  # type: ignore
+        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
 
         def _compute_pred_loss():
-            engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs)
+            engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs)
             engine.fire_event(IterationEvents.FORWARD_COMPLETED)
-            engine.state.output[Keys.LOSS] = self.loss_function(engine.state.output[Keys.PRED], targets).mean()
+            engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
             engine.fire_event(IterationEvents.LOSS_COMPLETED)
 
-        self.network.train()
-        self.optimizer.zero_grad(set_to_none=self.optim_set_to_none)
+        engine.network.train()
+        engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
 
-        if self.amp and self.scaler is not None:
-            with torch.cuda.amp.autocast(**engine.amp_kwargs):  # type: ignore
+        if engine.amp and engine.scaler is not None:
+            with torch.cuda.amp.autocast(**engine.amp_kwargs):
                 _compute_pred_loss()
-            self.scaler.scale(engine.state.output[Keys.LOSS]).backward()  # type: ignore
+            engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
-            self.scaler.step(self.optimizer)
-            self.scaler.update()
+            engine.scaler.step(engine.optimizer)
+            engine.scaler.update()
         else:
             _compute_pred_loss()
-            engine.state.output[Keys.LOSS].backward()  # type: ignore
+            engine.state.output[Keys.LOSS].backward()
             engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
-            self.optimizer.step()
+            engine.optimizer.step()
         engine.fire_event(IterationEvents.MODEL_COMPLETED)
 
         return engine.state.output
@@ -295,25 +295,25 @@ def __init__(
         d_network: torch.nn.Module,
         d_optimizer: Optimizer,
         d_loss_function: Callable,
-        epoch_length: Optional[int] = None,
-        g_inferer: Optional[Inferer] = None,
-        d_inferer: Optional[Inferer] = None,
+        epoch_length: int | None = None,
+        g_inferer: Inferer | None = None,
+        d_inferer: Inferer | None = None,
         d_train_steps: int = 1,
         latent_shape: int = 64,
         non_blocking: bool = False,
         d_prepare_batch: Callable = default_prepare_batch,
         g_prepare_batch: Callable = default_make_latent,
         g_update_latents: bool = True,
-        iteration_update: Optional[Callable[[Engine, Any], Any]] = None,
-        postprocessing: Optional[Transform] = None,
-        key_train_metric: Optional[Dict[str, Metric]] = None,
-        additional_metrics: Optional[Dict[str, Metric]] = None,
+        iteration_update: Callable[[Engine, Any], Any] | None = None,
+        postprocessing: Transform | None = None,
+        key_train_metric: dict[str, Metric] | None = None,
+        additional_metrics: dict[str, Metric] | None = None,
         metric_cmp_fn: Callable = default_metric_cmp_fn,
-        train_handlers: Optional[Sequence] = None,
+        train_handlers: Sequence | None = None,
         decollate: bool = True,
         optim_set_to_none: bool = False,
-        to_kwargs: Optional[Dict] = None,
-        amp_kwargs: Optional[Dict] = None,
+        to_kwargs: dict | None = None,
+        amp_kwargs: dict | None = None,
     ):
         if not isinstance(train_data_loader, DataLoader):
             raise ValueError("train_data_loader must be PyTorch DataLoader.")
@@ -351,13 +351,13 @@ def __init__(
         self.optim_set_to_none = optim_set_to_none
 
     def _iteration(
-        self, engine: Engine, batchdata: Union[Dict, Sequence]
-    ) -> Dict[str, Union[torch.Tensor, int, float, bool]]:
+        self, engine: GanTrainer, batchdata: dict | Sequence
+    ) -> dict[str, torch.Tensor | int | float | bool]:
         """
         Callback function for Adversarial Training processing logic of 1 iteration in Ignite Engine.
 
         Args:
-            engine: Ignite Engine, it can be a trainer, validator or evaluator.
+            engine: `GanTrainer` to execute operation for an iteration.
            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
 
         Raises:
@@ -367,42 +367,40 @@
         if batchdata is None:
             raise ValueError("must provide batch data for current iteration.")
 
-        d_input = self.prepare_batch(
-            batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs  # type: ignore
-        )
-        batch_size = self.data_loader.batch_size  # type: ignore
-        g_input = self.g_prepare_batch(
+        d_input = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
+        batch_size = engine.data_loader.batch_size  # type: ignore
+        g_input = engine.g_prepare_batch(
             num_latents=batch_size,
-            latent_size=self.latent_shape,
-            device=engine.state.device,  # type: ignore
-            non_blocking=engine.non_blocking,  # type: ignore
-            **engine.to_kwargs,  # type: ignore
+            latent_size=engine.latent_shape,
+            device=engine.state.device,
+            non_blocking=engine.non_blocking,
+            **engine.to_kwargs,
         )
-        g_output = self.g_inferer(g_input, self.g_network)
+        g_output = engine.g_inferer(g_input, engine.g_network)
 
         # Train Discriminator
         d_total_loss = torch.zeros(1)
-        for _ in range(self.d_train_steps):
-            self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none)
-            dloss = self.d_loss_function(g_output, d_input)
+        for _ in range(engine.d_train_steps):
+            engine.d_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
+            dloss = engine.d_loss_function(g_output, d_input)
             dloss.backward()
-            self.d_optimizer.step()
+            engine.d_optimizer.step()
             d_total_loss += dloss.item()
 
         # Train Generator
-        if self.g_update_latents:
-            g_input = self.g_prepare_batch(
+        if engine.g_update_latents:
+            g_input = engine.g_prepare_batch(
                 num_latents=batch_size,
-                latent_size=self.latent_shape,
-                device=engine.state.device,  # type: ignore
-                non_blocking=engine.non_blocking,  # type: ignore
-                **engine.to_kwargs,  # type: ignore
+                latent_size=engine.latent_shape,
+                device=engine.state.device,
+                non_blocking=engine.non_blocking,
+                **engine.to_kwargs,
            )
-        g_output = self.g_inferer(g_input, self.g_network)
-        self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none)
-        g_loss = self.g_loss_function(g_output)
+        g_output = engine.g_inferer(g_input, engine.g_network)
+        engine.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
+        g_loss = engine.g_loss_function(g_output)
         g_loss.backward()
-        self.g_optimizer.step()
+        engine.g_optimizer.step()
 
         return {
             GanKeys.REALS: d_input,
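A side effect of the `self.*` to `engine.*` rewrite above is that the iteration logic reads everything from the `engine` argument it is handed, so equivalent logic can also be supplied as a standalone function through the `iteration_update` argument instead of subclassing. A rough sketch of such a callback, mirroring the simplified non-AMP training path and assuming `prepare_batch` returns an `(inputs, targets)` pair (illustrative only; the import locations of `IterationEvents` and `CommonKeys` are assumptions based on the modules this diff touches):

    from monai.engines.utils import IterationEvents
    from monai.utils.enums import CommonKeys as Keys

    def train_iteration(engine, batchdata):
        # move the batch to the engine's device with its own prepare_batch helper
        inputs, targets = engine.prepare_batch(
            batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs
        )
        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

        engine.network.train()
        engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)

        # forward, loss, backward, step -- firing the same iteration events as above
        engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network)
        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
        engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean()
        engine.fire_event(IterationEvents.LOSS_COMPLETED)
        engine.state.output[Keys.LOSS].backward()
        engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
        engine.optimizer.step()
        engine.fire_event(IterationEvents.MODEL_COMPLETED)
        return engine.state.output

Such a function would be passed as `SupervisedTrainer(..., iteration_update=train_iteration)`; when `iteration_update` is provided, the workflow uses it in place of the built-in `_iteration`.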
diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py
index 75123da153..5b3cb556a5 100644
--- a/monai/engines/workflow.py
+++ b/monai/engines/workflow.py
@@ -250,8 +250,8 @@ def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None):
             metric.attach(self, name)
 
         @self.on(Events.EPOCH_COMPLETED)
-        def _compare_metrics(engine: Engine) -> None:
-            key_metric_name = engine.state.key_metric_name  # type: ignore
+        def _compare_metrics(engine: Workflow) -> None:
+            key_metric_name = engine.state.key_metric_name
             if key_metric_name is not None:
                 current_val_metric = engine.state.metrics[key_metric_name]
                 if not is_scalar(current_val_metric):
@@ -261,10 +261,10 @@ def _compare_metrics(engine: Engine) -> None:
                     )
                     return
 
-            if self.metric_cmp_fn(current_val_metric, engine.state.best_metric):  # type: ignore
+            if self.metric_cmp_fn(current_val_metric, engine.state.best_metric):
                 self.logger.info(f"Got new best metric of {key_metric_name}: {current_val_metric}")
-                engine.state.best_metric = current_val_metric  # type: ignore
-                engine.state.best_metric_epoch = engine.state.epoch  # type: ignore
+                engine.state.best_metric = current_val_metric
+                engine.state.best_metric_epoch = engine.state.epoch
 
     def _register_handlers(self, handlers: Sequence):
         """
@@ -289,7 +289,7 @@ def run(self) -> None:
             return
         super().run(data=self.data_loader, max_epochs=self.state.max_epochs)
 
-    def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]):
+    def _iteration(self, engine, batchdata: Dict[str, torch.Tensor]):
         """
         Abstract callback function for the processing logic of 1 iteration in Ignite Engine.
         Need subclass to implement different logics, like SupervisedTrainer/Evaluator, GANTrainer, etc.
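The `_compare_metrics` handler above is what keeps `engine.state.best_metric` and `engine.state.best_metric_epoch` up to date, which is also what `get_validation_stats()` reports. Its bookkeeping reduces to the comparison sketched below, using a comparator with the same shape as `default_metric_cmp_fn` (assumed here to mean "larger is better"); the metric values are made up for illustration:

    def metric_cmp_fn(current_metric: float, prev_best: float) -> bool:
        return current_metric > prev_best

    best_metric, best_metric_epoch = -1.0, -1
    for epoch, key_metric in enumerate([0.72, 0.81, 0.79], start=1):
        if metric_cmp_fn(key_metric, best_metric):
            best_metric, best_metric_epoch = key_metric, epoch

    print({"best_validation_metric": best_metric, "best_validation_epoch": best_metric_epoch})
    # {'best_validation_metric': 0.81, 'best_validation_epoch': 2}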