Commit

Revert "Enable inference mode for evaluation (#12715)"
This reverts commit 4df546a.
akihironitta committed Apr 22, 2022
1 parent 11c5348 commit a3da197
Showing 2 changed files with 4 additions and 22 deletions.
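For context, the reverted change (#12715) had evaluation and prediction run under `torch.inference_mode` rather than `torch.no_grad`. A minimal standalone sketch (not part of this commit) of how the two context managers differ; inference mode requires torch >= 1.9, matching the `_TORCH_GREATER_EQUAL_1_9` guard removed below:

import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)

# torch.no_grad() disables gradient tracking for the wrapped computation.
with torch.no_grad():
    y_no_grad = model(x)

# torch.inference_mode() is stricter: it also skips view tracking and
# version-counter bumps, so it is faster, but its output tensors can never
# be used in autograd-recorded operations afterwards.
with torch.inference_mode():
    y_inference = model(x)

print(y_no_grad.requires_grad, y_inference.requires_grad)  # False False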
4 changes: 0 additions & 4 deletions CHANGELOG.md
@@ -9,12 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Enabled `torch.inference_mode` for evaluation and prediction ([#12715](https://github.com/PyTorchLightning/pytorch-lightning/pull/12715))
-
-
-
 - Added support for setting `val_check_interval` to a value higher than the amount of training batches when `check_val_every_n_epoch=None` ([#11993](https://github.com/PyTorchLightning/pytorch-lightning/pull/11993))
 
 
 - Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532))
 
22 changes: 4 additions & 18 deletions pytorch_lightning/trainer/trainer.py
@@ -19,15 +19,13 @@
 import traceback
 import warnings
 from argparse import ArgumentParser, Namespace
-from contextlib import contextmanager
 from copy import deepcopy
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Callable, cast, Dict, Generator, Iterable, List, Optional, Type, Union
+from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Type, Union
 from weakref import proxy
 
 import torch
-import torch.distributed as dist
 from packaging.version import Version
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -99,7 +97,7 @@
 from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_9
+from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
@@ -1318,7 +1316,7 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
         # reset trainer on this loop and all child loops in case user connected a custom loop
         self._evaluation_loop.trainer = self
 
-        with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context():
+        with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad():
             eval_loop_results = self._evaluation_loop.run()
 
         # remove the tensors from the eval results
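The restored line stacks the profiler context with `torch.no_grad` in a single `with` statement. A toy sketch of the same stacked-context pattern (the `profile` helper here is a hypothetical stand-in, not Lightning's profiler API):

import time
import torch
from contextlib import contextmanager

@contextmanager
def profile(action_name: str):
    # Stand-in for Lightning's profiler: time the wrapped block.
    start = time.monotonic()
    try:
        yield
    finally:
        print(f"{action_name}: {time.monotonic() - start:.4f}s")

model = torch.nn.Linear(4, 2).eval()

# Both contexts apply to the whole evaluation run, as in _run_evaluate:
with profile("run_validate_evaluation"), torch.no_grad():
    out = model(torch.randn(8, 4))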
@@ -1334,7 +1332,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
         self.reset_predict_dataloader(self.lightning_module)
         # reset trainer on this loop and all child loops in case user connected a custom loop
         self.predict_loop.trainer = self
-        with _evaluation_context():
+        with torch.no_grad():
            return self.predict_loop.run()
 
     def _run_sanity_check(self) -> None:
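As with evaluation, prediction now falls back to `torch.no_grad`. A quick standalone check (a sketch, not from this repository) that no autograd graph is recorded under it:

import torch

model = torch.nn.Linear(4, 2)

with torch.no_grad():
    preds = model(torch.randn(8, 4))

print(preds.requires_grad)  # False
print(preds.grad_fn)        # None: the forward pass was not recorded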
@@ -2750,18 +2748,6 @@ def configure_optimizers(self):
         return max_estimated_steps
 
 
-@contextmanager
-def _evaluation_context() -> Generator:
-    # inference mode is not supported with gloo backend (#9431)
-    context_manager_class = (
-        torch.inference_mode
-        if _TORCH_GREATER_EQUAL_1_9 and not (dist.is_initialized() and dist.get_backend() == "gloo")
-        else torch.no_grad
-    )
-    with context_manager_class():
-        yield
-
-
 def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]:
     if batches is None:
         # batches is optional to know if the user passed a value so that we can show the above info messages only to the
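The deleted helper chose the evaluation context at runtime because inference mode is not supported with the gloo distributed backend (#9431, cited in the removed comment). A self-contained sketch of the same selection logic, assuming torch >= 1.9 so that `torch.inference_mode` exists:

from contextlib import contextmanager
from typing import Generator

import torch
import torch.distributed as dist

@contextmanager
def evaluation_context() -> Generator:
    # Fall back to no_grad when a gloo process group is active, since
    # inference-mode tensors do not work with gloo collectives.
    use_inference = not (dist.is_initialized() and dist.get_backend() == "gloo")
    cm = torch.inference_mode if use_inference else torch.no_grad
    with cm():
        yield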
