From 7e253e92effaa8748d07e8ed844a38e19cb55d81 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 20 Apr 2022 20:34:44 +0800 Subject: [PATCH 1/5] [DLMED] update Workflow.py Signed-off-by: Nic Ma --- monai/engines/workflow.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 75123da153..bc42509d35 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -250,8 +250,8 @@ def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): metric.attach(self, name) @self.on(Events.EPOCH_COMPLETED) - def _compare_metrics(engine: Engine) -> None: - key_metric_name = engine.state.key_metric_name # type: ignore + def _compare_metrics(engine: Workflow) -> None: + key_metric_name = engine.state.key_metric_name if key_metric_name is not None: current_val_metric = engine.state.metrics[key_metric_name] if not is_scalar(current_val_metric): @@ -261,10 +261,10 @@ def _compare_metrics(engine: Engine) -> None: ) return - if self.metric_cmp_fn(current_val_metric, engine.state.best_metric): # type: ignore + if self.metric_cmp_fn(current_val_metric, engine.state.best_metric): self.logger.info(f"Got new best metric of {key_metric_name}: {current_val_metric}") - engine.state.best_metric = current_val_metric # type: ignore - engine.state.best_metric_epoch = engine.state.epoch # type: ignore + engine.state.best_metric = current_val_metric + engine.state.best_metric_epoch = engine.state.epoch def _register_handlers(self, handlers: Sequence): """ From 33bde2967ab2ed9492d4577412de9c19589e7be6 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 20 Apr 2022 23:37:27 +0800 Subject: [PATCH 2/5] [DLMED] update all the engines Signed-off-by: Nic Ma --- monai/engines/evaluator.py | 44 ++++++++++---------- monai/engines/trainer.py | 83 +++++++++++++++++++------------------- monai/engines/workflow.py | 2 +- 3 files changed, 66 insertions(+), 63 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index c69c0e0547..17cd4bbd34 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch @@ -243,7 +244,7 @@ def __init__( self.network = network self.inferer = SimpleInferer() if inferer is None else inferer - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): + def _iteration(self, engine: SupervisedEvaluator, batchdata: Dict[str, torch.Tensor]): """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: @@ -252,7 +253,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): - PRED: prediction result of model. Args: - engine: Ignite Engine, it can be a trainer, validator or evaluator. + engine: `SupervisedEvaluator` to execute operation for an iteration. batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data. Raises: @@ -261,8 +262,8 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch( - batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore + batch = engine.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs ) if len(batch) == 2: inputs, targets = batch @@ -272,15 +273,16 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # execute forward computation - with self.mode(self.network): - if self.amp: - with torch.cuda.amp.autocast(**engine.amp_kwargs): # type: ignore - engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore + with engine.mode(engine.network): + + if engine.amp: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) else: - engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) # type: ignore + engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) @@ -392,7 +394,7 @@ def __init__( raise ValueError("length of `pred_keys` must be same as the length of `networks`.") self.inferer = SimpleInferer() if inferer is None else inferer - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): + def _iteration(self, engine: EnsembleEvaluator, batchdata: Dict[str, torch.Tensor]): """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: @@ -404,7 +406,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): - pred_keys[N]: prediction result of network N. Args: - engine: Ignite Engine, it can be a trainer, validator or evaluator. + engine: `EnsembleEvaluator` to execute operation for an iteration. batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data. Raises: @@ -413,8 +415,8 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch( - batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore + batch = engine.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs ) if len(batch) == 2: inputs, targets = batch @@ -424,20 +426,20 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} - for idx, network in enumerate(self.networks): - with self.mode(network): - if self.amp: - with torch.cuda.amp.autocast(**engine.amp_kwargs): # type: ignore + for idx, network in enumerate(engine.networks): + with engine.mode(network): + if engine.amp: + with torch.cuda.amp.autocast(**engine.amp_kwargs): if isinstance(engine.state.output, dict): engine.state.output.update( - {self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)} + {engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)} ) else: if isinstance(engine.state.output, dict): engine.state.output.update( - {self.pred_keys[idx]: self.inferer(inputs, network, *args, **kwargs)} + {engine.pred_keys[idx]: engine.inferer(inputs, network, *args, **kwargs)} ) engine.fire_event(IterationEvents.FORWARD_COMPLETED) engine.fire_event(IterationEvents.MODEL_COMPLETED) diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 12753765ef..54bf8966c2 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch @@ -165,7 +166,7 @@ def __init__( self.inferer = SimpleInferer() if inferer is None else inferer self.optim_set_to_none = optim_set_to_none - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): + def _iteration(self, engine: SupervisedTrainer, batchdata: Dict[str, torch.Tensor]): """ Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: @@ -175,7 +176,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): - LOSS: loss value computed by loss function. Args: - engine: Ignite Engine, it can be a trainer, validator or evaluator. + engine: `SupervisedTrainer` to execute operation for an iteration. batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data. Raises: @@ -184,8 +185,8 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = self.prepare_batch( - batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore + batch = engine.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs ) if len(batch) == 2: inputs, targets = batch @@ -194,29 +195,29 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): else: inputs, targets, args, kwargs = batch # put iteration outputs into engine.state - engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} # type: ignore + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} def _compute_pred_loss(): - engine.state.output[Keys.PRED] = self.inferer(inputs, self.network, *args, **kwargs) + engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) engine.fire_event(IterationEvents.FORWARD_COMPLETED) - engine.state.output[Keys.LOSS] = self.loss_function(engine.state.output[Keys.PRED], targets).mean() + engine.state.output[Keys.LOSS] = engine.loss_function(engine.state.output[Keys.PRED], targets).mean() engine.fire_event(IterationEvents.LOSS_COMPLETED) - self.network.train() - self.optimizer.zero_grad(set_to_none=self.optim_set_to_none) + engine.network.train() + engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none) - if self.amp and self.scaler is not None: - with torch.cuda.amp.autocast(**engine.amp_kwargs): # type: ignore + if engine.amp and engine.scaler is not None: + with torch.cuda.amp.autocast(**engine.amp_kwargs): _compute_pred_loss() - self.scaler.scale(engine.state.output[Keys.LOSS]).backward() # type: ignore + engine.scaler.scale(engine.state.output[Keys.LOSS]).backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) - self.scaler.step(self.optimizer) - self.scaler.update() + engine.scaler.step(engine.optimizer) + engine.scaler.update() else: _compute_pred_loss() - engine.state.output[Keys.LOSS].backward() # type: ignore + engine.state.output[Keys.LOSS].backward() engine.fire_event(IterationEvents.BACKWARD_COMPLETED) - self.optimizer.step() + engine.optimizer.step() engine.fire_event(IterationEvents.MODEL_COMPLETED) return engine.state.output @@ -351,13 +352,13 @@ def __init__( self.optim_set_to_none = optim_set_to_none def _iteration( - self, engine: Engine, batchdata: Union[Dict, Sequence] + self, engine: GanTrainer, batchdata: Union[Dict, Sequence] ) -> Dict[str, Union[torch.Tensor, int, float, bool]]: """ Callback function for Adversarial Training processing logic of 1 iteration in Ignite Engine. Args: - engine: Ignite Engine, it can be a trainer, validator or evaluator. + engine: `GanTrainer` to execute operation for an iteration. batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data. Raises: @@ -367,42 +368,42 @@ def _iteration( if batchdata is None: raise ValueError("must provide batch data for current iteration.") - d_input = self.prepare_batch( - batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs # type: ignore + d_input = engine.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs ) - batch_size = self.data_loader.batch_size # type: ignore - g_input = self.g_prepare_batch( + batch_size = engine.data_loader.batch_size # type: ignore + g_input = engine.g_prepare_batch( num_latents=batch_size, - latent_size=self.latent_shape, - device=engine.state.device, # type: ignore - non_blocking=engine.non_blocking, # type: ignore - **engine.to_kwargs, # type: ignore + latent_size=engine.latent_shape, + device=engine.state.device, + non_blocking=engine.non_blocking, + **engine.to_kwargs, ) - g_output = self.g_inferer(g_input, self.g_network) + g_output = engine.g_inferer(g_input, engine.g_network) # Train Discriminator d_total_loss = torch.zeros(1) - for _ in range(self.d_train_steps): - self.d_optimizer.zero_grad(set_to_none=self.optim_set_to_none) - dloss = self.d_loss_function(g_output, d_input) + for _ in range(engine.d_train_steps): + engine.d_optimizer.zero_grad(set_to_none=engine.optim_set_to_none) + dloss = engine.d_loss_function(g_output, d_input) dloss.backward() - self.d_optimizer.step() + engine.d_optimizer.step() d_total_loss += dloss.item() # Train Generator - if self.g_update_latents: - g_input = self.g_prepare_batch( + if engine.g_update_latents: + g_input = engine.g_prepare_batch( num_latents=batch_size, - latent_size=self.latent_shape, - device=engine.state.device, # type: ignore - non_blocking=engine.non_blocking, # type: ignore - **engine.to_kwargs, # type: ignore + latent_size=engine.latent_shape, + device=engine.state.device, + non_blocking=engine.non_blocking, + **engine.to_kwargs, ) - g_output = self.g_inferer(g_input, self.g_network) - self.g_optimizer.zero_grad(set_to_none=self.optim_set_to_none) - g_loss = self.g_loss_function(g_output) + g_output = engine.g_inferer(g_input, engine.g_network) + engine.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none) + g_loss = engine.g_loss_function(g_output) g_loss.backward() - self.g_optimizer.step() + engine.g_optimizer.step() return { GanKeys.REALS: d_input, diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index bc42509d35..5b3cb556a5 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -289,7 +289,7 @@ def run(self) -> None: return super().run(data=self.data_loader, max_epochs=self.state.max_epochs) - def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]): + def _iteration(self, engine, batchdata: Dict[str, torch.Tensor]): """ Abstract callback function for the processing logic of 1 iteration in Ignite Engine. Need subclass to implement different logics, like SupervisedTrainer/Evaluator, GANTrainer, etc. From a13117c9cd6db596689c3bc3256838af0393437d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 20 Apr 2022 15:38:20 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/engines/evaluator.py | 94 +++++++++++++++++++------------------- monai/engines/trainer.py | 56 +++++++++++------------ 2 files changed, 75 insertions(+), 75 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 17cd4bbd34..b10e60aa3a 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -85,23 +85,23 @@ class Evaluator(Workflow): def __init__( self, device: torch.device, - val_data_loader: Union[Iterable, DataLoader], - epoch_length: Optional[int] = None, + val_data_loader: Iterable | DataLoader, + epoch_length: int | None = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable[[Engine, Any], Any]] = None, - postprocessing: Optional[Transform] = None, - key_val_metric: Optional[Dict[str, Metric]] = None, - additional_metrics: Optional[Dict[str, Metric]] = None, + iteration_update: Callable[[Engine, Any], Any] | None = None, + postprocessing: Transform | None = None, + key_val_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, metric_cmp_fn: Callable = default_metric_cmp_fn, - val_handlers: Optional[Sequence] = None, + val_handlers: Sequence | None = None, amp: bool = False, - mode: Union[ForwardMode, str] = ForwardMode.EVAL, - event_names: Optional[List[Union[str, EventEnum]]] = None, - event_to_attr: Optional[dict] = None, + mode: ForwardMode | str = ForwardMode.EVAL, + event_names: list[str | EventEnum] | None = None, + event_to_attr: dict | None = None, decollate: bool = True, - to_kwargs: Optional[Dict] = None, - amp_kwargs: Optional[Dict] = None, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -145,7 +145,7 @@ def run(self, global_epoch: int = 1) -> None: self.state.iteration = 0 super().run() - def get_validation_stats(self) -> Dict[str, float]: + def get_validation_stats(self) -> dict[str, float]: return {"best_validation_metric": self.state.best_metric, "best_validation_epoch": self.state.best_metric_epoch} @@ -200,25 +200,25 @@ class SupervisedEvaluator(Evaluator): def __init__( self, device: torch.device, - val_data_loader: Union[Iterable, DataLoader], + val_data_loader: Iterable | DataLoader, network: torch.nn.Module, - epoch_length: Optional[int] = None, + epoch_length: int | None = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable[[Engine, Any], Any]] = None, - inferer: Optional[Inferer] = None, - postprocessing: Optional[Transform] = None, - key_val_metric: Optional[Dict[str, Metric]] = None, - additional_metrics: Optional[Dict[str, Metric]] = None, + iteration_update: Callable[[Engine, Any], Any] | None = None, + inferer: Inferer | None = None, + postprocessing: Transform | None = None, + key_val_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, metric_cmp_fn: Callable = default_metric_cmp_fn, - val_handlers: Optional[Sequence] = None, + val_handlers: Sequence | None = None, amp: bool = False, - mode: Union[ForwardMode, str] = ForwardMode.EVAL, - event_names: Optional[List[Union[str, EventEnum]]] = None, - event_to_attr: Optional[dict] = None, + mode: ForwardMode | str = ForwardMode.EVAL, + event_names: list[str | EventEnum] | None = None, + event_to_attr: dict | None = None, decollate: bool = True, - to_kwargs: Optional[Dict] = None, - amp_kwargs: Optional[Dict] = None, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -244,7 +244,7 @@ def __init__( self.network = network self.inferer = SimpleInferer() if inferer is None else inferer - def _iteration(self, engine: SupervisedEvaluator, batchdata: Dict[str, torch.Tensor]): + def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]): """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: @@ -267,8 +267,8 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: Dict[str, torch.Ten ) if len(batch) == 2: inputs, targets = batch - args: Tuple = () - kwargs: Dict = {} + args: tuple = () + kwargs: dict = {} else: inputs, targets, args, kwargs = batch @@ -277,7 +277,7 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: Dict[str, torch.Ten # execute forward computation with engine.mode(engine.network): - + if engine.amp: with torch.cuda.amp.autocast(**engine.amp_kwargs): engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) @@ -344,26 +344,26 @@ class EnsembleEvaluator(Evaluator): def __init__( self, device: torch.device, - val_data_loader: Union[Iterable, DataLoader], + val_data_loader: Iterable | DataLoader, networks: Sequence[torch.nn.Module], - pred_keys: Optional[KeysCollection] = None, - epoch_length: Optional[int] = None, + pred_keys: KeysCollection | None = None, + epoch_length: int | None = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable[[Engine, Any], Any]] = None, - inferer: Optional[Inferer] = None, - postprocessing: Optional[Transform] = None, - key_val_metric: Optional[Dict[str, Metric]] = None, - additional_metrics: Optional[Dict[str, Metric]] = None, + iteration_update: Callable[[Engine, Any], Any] | None = None, + inferer: Inferer | None = None, + postprocessing: Transform | None = None, + key_val_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, metric_cmp_fn: Callable = default_metric_cmp_fn, - val_handlers: Optional[Sequence] = None, + val_handlers: Sequence | None = None, amp: bool = False, - mode: Union[ForwardMode, str] = ForwardMode.EVAL, - event_names: Optional[List[Union[str, EventEnum]]] = None, - event_to_attr: Optional[dict] = None, + mode: ForwardMode | str = ForwardMode.EVAL, + event_names: list[str | EventEnum] | None = None, + event_to_attr: dict | None = None, decollate: bool = True, - to_kwargs: Optional[Dict] = None, - amp_kwargs: Optional[Dict] = None, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -394,7 +394,7 @@ def __init__( raise ValueError("length of `pred_keys` must be same as the length of `networks`.") self.inferer = SimpleInferer() if inferer is None else inferer - def _iteration(self, engine: EnsembleEvaluator, batchdata: Dict[str, torch.Tensor]): + def _iteration(self, engine: EnsembleEvaluator, batchdata: dict[str, torch.Tensor]): """ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: @@ -420,8 +420,8 @@ def _iteration(self, engine: EnsembleEvaluator, batchdata: Dict[str, torch.Tenso ) if len(batch) == 2: inputs, targets = batch - args: Tuple = () - kwargs: Dict = {} + args: tuple = () + kwargs: dict = {} else: inputs, targets, args, kwargs = batch diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 54bf8966c2..7697edd51f 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -56,7 +56,7 @@ def run(self) -> None: self.scaler = torch.cuda.amp.GradScaler() if self.amp else None super().run() - def get_train_stats(self) -> Dict[str, float]: + def get_train_stats(self) -> dict[str, float]: return {"total_epochs": self.state.max_epochs, "total_iterations": self.state.epoch_length} @@ -117,27 +117,27 @@ def __init__( self, device: torch.device, max_epochs: int, - train_data_loader: Union[Iterable, DataLoader], + train_data_loader: Iterable | DataLoader, network: torch.nn.Module, optimizer: Optimizer, loss_function: Callable, - epoch_length: Optional[int] = None, + epoch_length: int | None = None, non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, - iteration_update: Optional[Callable[[Engine, Any], Any]] = None, - inferer: Optional[Inferer] = None, - postprocessing: Optional[Transform] = None, - key_train_metric: Optional[Dict[str, Metric]] = None, - additional_metrics: Optional[Dict[str, Metric]] = None, + iteration_update: Callable[[Engine, Any], Any] | None = None, + inferer: Inferer | None = None, + postprocessing: Transform | None = None, + key_train_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, metric_cmp_fn: Callable = default_metric_cmp_fn, - train_handlers: Optional[Sequence] = None, + train_handlers: Sequence | None = None, amp: bool = False, - event_names: Optional[List[Union[str, EventEnum]]] = None, - event_to_attr: Optional[dict] = None, + event_names: list[str | EventEnum] | None = None, + event_to_attr: dict | None = None, decollate: bool = True, optim_set_to_none: bool = False, - to_kwargs: Optional[Dict] = None, - amp_kwargs: Optional[Dict] = None, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, ) -> None: super().__init__( device=device, @@ -166,7 +166,7 @@ def __init__( self.inferer = SimpleInferer() if inferer is None else inferer self.optim_set_to_none = optim_set_to_none - def _iteration(self, engine: SupervisedTrainer, batchdata: Dict[str, torch.Tensor]): + def _iteration(self, engine: SupervisedTrainer, batchdata: dict[str, torch.Tensor]): """ Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine. Return below items in a dictionary: @@ -190,8 +190,8 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: Dict[str, torch.Tenso ) if len(batch) == 2: inputs, targets = batch - args: Tuple = () - kwargs: Dict = {} + args: tuple = () + kwargs: dict = {} else: inputs, targets, args, kwargs = batch # put iteration outputs into engine.state @@ -296,25 +296,25 @@ def __init__( d_network: torch.nn.Module, d_optimizer: Optimizer, d_loss_function: Callable, - epoch_length: Optional[int] = None, - g_inferer: Optional[Inferer] = None, - d_inferer: Optional[Inferer] = None, + epoch_length: int | None = None, + g_inferer: Inferer | None = None, + d_inferer: Inferer | None = None, d_train_steps: int = 1, latent_shape: int = 64, non_blocking: bool = False, d_prepare_batch: Callable = default_prepare_batch, g_prepare_batch: Callable = default_make_latent, g_update_latents: bool = True, - iteration_update: Optional[Callable[[Engine, Any], Any]] = None, - postprocessing: Optional[Transform] = None, - key_train_metric: Optional[Dict[str, Metric]] = None, - additional_metrics: Optional[Dict[str, Metric]] = None, + iteration_update: Callable[[Engine, Any], Any] | None = None, + postprocessing: Transform | None = None, + key_train_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, metric_cmp_fn: Callable = default_metric_cmp_fn, - train_handlers: Optional[Sequence] = None, + train_handlers: Sequence | None = None, decollate: bool = True, optim_set_to_none: bool = False, - to_kwargs: Optional[Dict] = None, - amp_kwargs: Optional[Dict] = None, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, ): if not isinstance(train_data_loader, DataLoader): raise ValueError("train_data_loader must be PyTorch DataLoader.") @@ -352,8 +352,8 @@ def __init__( self.optim_set_to_none = optim_set_to_none def _iteration( - self, engine: GanTrainer, batchdata: Union[Dict, Sequence] - ) -> Dict[str, Union[torch.Tensor, int, float, bool]]: + self, engine: GanTrainer, batchdata: dict | Sequence + ) -> dict[str, torch.Tensor | int | float | bool]: """ Callback function for Adversarial Training processing logic of 1 iteration in Ignite Engine. From 5a95621afdc60a30894254f35c3797819e81ea29 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 20 Apr 2022 23:44:47 +0800 Subject: [PATCH 4/5] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/engines/evaluator.py | 11 ++++------- monai/engines/trainer.py | 9 +++------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 17cd4bbd34..67d0a117c9 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -10,6 +10,7 @@ # limitations under the License. from __future__ import annotations + from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch @@ -262,9 +263,7 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: Dict[str, torch.Ten """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = engine.prepare_batch( - batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs - ) + batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -277,7 +276,7 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: Dict[str, torch.Ten # execute forward computation with engine.mode(engine.network): - + if engine.amp: with torch.cuda.amp.autocast(**engine.amp_kwargs): engine.state.output[Keys.PRED] = engine.inferer(inputs, engine.network, *args, **kwargs) @@ -415,9 +414,7 @@ def _iteration(self, engine: EnsembleEvaluator, batchdata: Dict[str, torch.Tenso """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = engine.prepare_batch( - batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs - ) + batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) if len(batch) == 2: inputs, targets = batch args: Tuple = () diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 54bf8966c2..113a16d7cb 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -10,6 +10,7 @@ # limitations under the License. from __future__ import annotations + from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import torch @@ -185,9 +186,7 @@ def _iteration(self, engine: SupervisedTrainer, batchdata: Dict[str, torch.Tenso """ if batchdata is None: raise ValueError("Must provide batch data for current iteration.") - batch = engine.prepare_batch( - batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs - ) + batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) if len(batch) == 2: inputs, targets = batch args: Tuple = () @@ -368,9 +367,7 @@ def _iteration( if batchdata is None: raise ValueError("must provide batch data for current iteration.") - d_input = engine.prepare_batch( - batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs - ) + d_input = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) batch_size = engine.data_loader.batch_size # type: ignore g_input = engine.g_prepare_batch( num_latents=batch_size, From 3721648ae790258136b577c545a434ea21ccfd1a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 21 Apr 2022 00:01:01 +0800 Subject: [PATCH 5/5] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/engines/evaluator.py | 2 +- monai/engines/trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index c5b1124b81..b058935f67 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch from torch.utils.data import DataLoader diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index 7ce8ee105a..fbdb309bb5 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence import torch from torch.optim.optimizer import Optimizer