Enable inference mode for evaluation #12715

Merged · 8 commits · Apr 21, 2022
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -9,8 +9,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

+- Enabled `torch.inference_mode` for evaluation and prediction ([#12715](https://github.com/PyTorchLightning/pytorch-lightning/pull/12715))


- Added support for setting `val_check_interval` to a value higher than the amount of training batches when `check_val_every_n_epoch=None` ([#11993](https://github.com/PyTorchLightning/pytorch-lightning/pull/11993))


- Include the `pytorch_lightning` version as a header in the CLI config files ([#12532](https://github.com/PyTorchLightning/pytorch-lightning/pull/12532))


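For context, a minimal sketch (mine, not part of the diff) of what the new changelog entry means in plain PyTorch; `evaluation_forward` is an illustrative name, not Lightning API:

```python
import torch

def evaluation_forward(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    # Evaluation and prediction previously ran under torch.no_grad();
    # after this PR they run under torch.inference_mode() where supported,
    # which additionally skips autograd's version-counter and view tracking.
    with torch.inference_mode():
        return model(batch)
```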
22 changes: 18 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -19,13 +19,15 @@
import traceback
import warnings
from argparse import ArgumentParser, Namespace
+from contextlib import contextmanager
from copy import deepcopy
from datetime import timedelta
from pathlib import Path
-from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Type, Union
+from typing import Any, Callable, cast, Dict, Generator, Iterable, List, Optional, Type, Union
from weakref import proxy

import torch
+import torch.distributed as dist
from packaging.version import Version
from torch.optim import Optimizer
from torch.utils.data import DataLoader
@@ -97,7 +99,7 @@
from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_training
+from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_9
from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
@@ -1316,7 +1318,7 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
        # reset trainer on this loop and all child loops in case user connected a custom loop
        self._evaluation_loop.trainer = self

-        with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad():
+        with self.profiler.profile(f"run_{self.state.stage}_evaluation"), _evaluation_context():
            eval_loop_results = self._evaluation_loop.run()

        # remove the tensors from the eval results
@@ -1332,7 +1334,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
        self.reset_predict_dataloader(self.lightning_module)
        # reset trainer on this loop and all child loops in case user connected a custom loop
        self.predict_loop.trainer = self
-        with torch.no_grad():
+        with _evaluation_context():
            return self.predict_loop.run()

    def _run_sanity_check(self) -> None:
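Both call sites above now delegate to a single `_evaluation_context` helper (added further down in this diff). As a generic illustration of the `contextlib.contextmanager` delegation pattern it uses — my sketch, with an illustrative name (`_grad_free`), independent of Lightning:

```python
from contextlib import contextmanager
from typing import Generator

import torch

@contextmanager
def _grad_free() -> Generator:
    # Pick a context manager class at runtime, then enter it around the
    # caller's block; this mirrors the pattern used by the PR's helper.
    cm_class = torch.inference_mode if hasattr(torch, "inference_mode") else torch.no_grad
    with cm_class():
        yield

with _grad_free():
    y = torch.ones(2) * 2
print(y.requires_grad)  # False: no graph was recorded
```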
@@ -2748,6 +2750,18 @@ def configure_optimizers(self):
        return max_estimated_steps


+@contextmanager
+def _evaluation_context() -> Generator:
+    # inference mode is not supported with gloo backend (#9431)
+    context_manager_class = (
+        torch.inference_mode
+        if _TORCH_GREATER_EQUAL_1_9 and not (dist.is_initialized() and dist.get_backend() == "gloo")
+        else torch.no_grad
+    )
+    with context_manager_class():
+        yield

Review thread on the `gloo` check (later marked resolved by rohitgr7):

Member: I don't find any information in the docs or anywhere else about `inference_mode` not being compatible with the gloo backend. A comment here in the code would probably be appropriate.

Member: Thanks for adding the comment. I'm definitely still not satisfied, and will investigate. This is very sus, especially since I cannot find any open or closed issue ticket on the PyTorch GitHub.

Contributor (author): I couldn't find it either. Let me open an issue on the PyTorch GitHub.
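For background on why a guard like this can matter — my illustration, not taken from the PR: tensors created under `torch.inference_mode()` are "inference tensors", and PyTorch forbids in-place updates to them outside the mode, so any code path that mutates evaluation tensors in place (as a collective backend might) can fail. A minimal sketch, assuming PyTorch >= 1.9:

```python
import torch

with torch.inference_mode():
    preds = torch.ones(3)  # an "inference tensor"

print(preds.requires_grad)  # False: no autograd tracking at all
try:
    preds.add_(1)  # in-place update outside inference mode
except RuntimeError as err:
    print(f"RuntimeError: {err}")  # forbidden for inference tensors
```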


def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]:
    if batches is None:
        # batches is optional to know if the user passed a value so that we can show the above info messages only to the