5 changes: 1 addition & 4 deletions pytorch_lightning/accelerators/gpu.py
@@ -42,10 +42,7 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:

     def on_train_start(self) -> None:
         # clear cache before training
-        # use context because of:
-        # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
-        with torch.cuda.device(self.root_device):
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()

     @staticmethod
     def set_nvidia_flags(local_rank: int) -> None:
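The gpu.py change drops the `torch.cuda.device` context (originally added because of the linked forum thread) and calls `torch.cuda.empty_cache()` directly. A minimal sketch of the two call patterns, with hypothetical function names; the assumption behind the simpler form is that the accelerator has already made `root_device` the current CUDA device during setup, so switching devices before flushing the cache is redundant:

import torch


def clear_cache_with_context(root_device: torch.device) -> None:
    # old pattern: temporarily make `root_device` current so its caching
    # allocator (not that of whichever device happens to be current) is flushed
    with torch.cuda.device(root_device):
        torch.cuda.empty_cache()


def clear_cache() -> None:
    # new pattern: flushes the cache of the current device; equivalent to the
    # above once torch.cuda.set_device(root_device) has already been called
    torch.cuda.empty_cache()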
10 changes: 8 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -41,7 +41,13 @@
     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    distributed_available,
+    rank_zero_info,
+    rank_zero_only,
+    ReduceOp,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed

@@ -335,7 +341,7 @@ def post_dispatch(self) -> None:
         self.cluster_environment.teardown()

     def barrier(self, *args, **kwargs) -> None:
-        if not torch.distributed.is_initialized():
+        if not distributed_available():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
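In both DDP plugins the barrier guard now uses `distributed_available()` instead of the raw `torch.distributed.is_initialized()` check. A rough, hypothetical stand-in for what such a helper checks (the real implementation lives in pytorch_lightning.utilities.distributed and may include additional cases such as TPU process groups):

import torch


def _distributed_available() -> bool:
    # check is_available() first: builds without the distributed package
    # should short-circuit before touching is_initialized()
    return torch.distributed.is_available() and torch.distributed.is_initialized()

Guarding on availability as well as initialization lets `barrier()` fall through as a no-op on CPU-only or single-process runs instead of raising.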
10 changes: 8 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -36,7 +36,13 @@
 )
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    distributed_available,
+    rank_zero_info,
+    rank_zero_only,
+    ReduceOp,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.seed import reset_seed

 if _TORCH_GREATER_EQUAL_1_8:
@@ -312,7 +318,7 @@ def __recover_child_process_weights(self, best_path, last_path):
         self.lightning_module.load_state_dict(ckpt)

     def barrier(self, *args, **kwargs) -> None:
-        if not torch.distributed.is_initialized():
+        if not distributed_available():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
5 changes: 2 additions & 3 deletions pytorch_lightning/plugins/training_type/horovod.py
@@ -15,13 +15,12 @@
 from typing import Any, List, Optional, Union

 import torch
-import torch.distributed
 from torch.optim.lr_scheduler import _LRScheduler, Optimizer

 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
-from pytorch_lightning.utilities.distributed import group, rank_zero_only, ReduceOp
+from pytorch_lightning.utilities.distributed import distributed_available, group, rank_zero_only, ReduceOp

 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
@@ -125,7 +124,7 @@ def start_predicting(self, trainer):
         self.join()

     def barrier(self, *args, **kwargs):
-        if torch.distributed.is_initialized():
+        if distributed_available():
             self.join()

     def broadcast(self, obj: object, src: int = 0) -> object:
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -137,5 +137,4 @@ def teardown(self) -> None:
         # GPU teardown
         self.lightning_module.cpu()
         # clean up memory
-        with torch.cuda.device(self.root_device):
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/single_device.py
@@ -85,5 +85,4 @@ def teardown(self) -> None:
         # GPU teardown
         self.lightning_module.cpu()
         # clean up memory
-        with torch.cuda.device(self.root_device):
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()
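Both teardown hooks keep the same two-step order after the context manager is removed: move the LightningModule back to the CPU first, then release the cached blocks. A small sketch of that pattern with a hypothetical helper name; the ordering matters because cached memory can only be handed back once the tensors occupying it have been freed or moved off the GPU:

import torch


def release_gpu_memory(module: torch.nn.Module) -> None:
    # 1. move parameters and buffers off the GPU so their allocations become free
    module.cpu()
    # 2. return unused cached blocks to the driver; assumed to be a silent
    #    no-op when no CUDA context was ever initialized
    torch.cuda.empty_cache()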
14 changes: 3 additions & 11 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -20,13 +20,7 @@
 import torch

 import pytorch_lightning as pl
-from pytorch_lightning.utilities import (
-    _OMEGACONF_AVAILABLE,
-    DeviceType,
-    rank_zero_deprecation,
-    rank_zero_info,
-    rank_zero_warn,
-)
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
@@ -68,8 +62,7 @@ def resume_start(self) -> None:
             return

         # clear cache before restore
-        if self.trainer._device_type == DeviceType.GPU:
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()

         # Try to read the checkpoint file at `checkpoint_path`. If not exist, do not restore checkpoint.
         fs = get_filesystem(checkpoint_path)
@@ -87,8 +80,7 @@ def resume_end(self) -> None:
         self._loaded_checkpoint = dict()

         # clear cache after restore
-        if self.trainer._device_type == DeviceType.GPU:
-            torch.cuda.empty_cache()
+        torch.cuda.empty_cache()

         # wait for all to catch up
         self.trainer.training_type_plugin.barrier("CheckpointConnector.resume_end")
15 changes: 7 additions & 8 deletions pytorch_lightning/utilities/memory.py
@@ -76,11 +76,10 @@ def is_out_of_cpu_memory(exception):
 def garbage_collection_cuda():
     """Garbage collection Torch (CUDA) memory."""
     gc.collect()
-    if torch.cuda.is_available():
-        try:
-            # This is the last thing that should cause an OOM error, but seemingly it can.
-            torch.cuda.empty_cache()
-        except RuntimeError as exception:
-            if not is_oom_error(exception):
-                # Only handle OOM errors
-                raise
+    try:
+        # This is the last thing that should cause an OOM error, but seemingly it can.
+        torch.cuda.empty_cache()
+    except RuntimeError as exception:
+        if not is_oom_error(exception):
+            # Only handle OOM errors
+            raise
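With the `torch.cuda.is_available()` guard removed, `garbage_collection_cuda()` relies on `torch.cuda.empty_cache()` returning without effect when no CUDA context has been initialized. A short usage sketch under that assumption:

from pytorch_lightning.utilities.memory import garbage_collection_cuda

# Safe to call on CPU-only machines as well: gc.collect() always runs, and
# empty_cache() is assumed to do nothing when CUDA was never initialized.
garbage_collection_cuda()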