Skip to content

Commit

Permalink
fix checkpointing to remote file paths (#2925)
Browse files · Browse the repository at this point in the history
  • Loading branch information
f4hy committed Aug 12, 2020
1 parent d13e5c9 commit 56396ab
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 15 deletions.
15 changes: 10 additions & 5 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
from pytorch_lightning.utilities.cloud_io import gfile, makedirs
from pytorch_lightning.utilities.cloud_io import gfile, makedirs, is_remote_path


class ModelCheckpoint(Callback):
Expand Down Expand Up @@ -122,10 +122,10 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
if gfile.isdir(filepath):
self.dirpath, self.filename = filepath, '{epoch}'
else:
filepath = os.path.realpath(filepath)
if not is_remote_path(filepath): # dont normalize remote paths
filepath = os.path.realpath(filepath)
self.dirpath, self.filename = os.path.split(filepath)
if not gfile.exists(self.dirpath):
makedirs(self.dirpath)
makedirs(self.dirpath) # calls with exist_ok
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
Expand Down Expand Up @@ -174,7 +174,12 @@ def _del_model(self, filepath):
# dependencies exist then this will work fine.
gfile.remove(filepath)
except AttributeError:
os.remove(filepath)
if is_remote_path(filepath):
log.warning("Unable to remove stale checkpoints due to running gfile in compatibility mode."
" Please install tensorflow to run gfile in full mode"
" if writing checkpoints to remote locations")
else:
os.remove(filepath)

def _save_model(self, filepath, trainer, pl_module):

Expand Down
9 changes: 7 additions & 2 deletions pytorch_lightning/trainer/distrib_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def train_fx(trial_hparams, cluster_manager, _):
"""

import io
import os
import re
from abc import ABC, abstractmethod
Expand All @@ -146,6 +147,7 @@ def train_fx(trial_hparams, cluster_manager, _):
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.cloud_io import cloud_open


try:
Expand Down Expand Up @@ -435,10 +437,13 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
# Can't use the new zipfile serialization for 1.6.0 because there's a bug in
# torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
# More details can be found here: https://github.com/pytorch/pytorch/issues/42239
bytesbuffer = io.BytesIO()
if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
torch.save(model.state_dict(), last_path, _use_new_zipfile_serialization=False)
torch.save(model.state_dict(), bytesbuffer, _use_new_zipfile_serialization=False)
else:
torch.save(model.state_dict(), last_path)
torch.save(model.state_dict(), bytesbuffer)
with cloud_open(last_path, 'wb') as f:
f.write(bytesbuffer.getvalue())
mp_queue.put(last_path)

def save_spawn_weights(self, model):
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.cloud_io import is_remote_path

# warnings to ignore in trainer
warnings.filterwarnings(
Expand Down Expand Up @@ -880,7 +881,7 @@ def default_root_dir(self) -> str:
The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
"""
if "://" in str(self._default_root_dir):
if is_remote_path(self._default_root_dir):
# it is a remote uri, use as is
return self._default_root_dir
return os.path.normpath(self._default_root_dir)
Expand All @@ -891,7 +892,7 @@ def weights_save_path(self) -> str:
The default root location to save weights (checkpoints), e.g., when the
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
"""
if "://" in str(self._weights_save_path):
if is_remote_path(self._weights_save_path):
# it is a remote uri, use as is
return self._weights_save_path
return os.path.normpath(self._weights_save_path)
Expand Down
12 changes: 7 additions & 5 deletions pytorch_lightning/trainer/training_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
"""

import io
import os
import re
import signal
Expand All @@ -104,7 +105,7 @@
)
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.cloud_io import gfile, makedirs
from pytorch_lightning.utilities.cloud_io import cloud_open, gfile, makedirs

try:
import torch_xla
Expand Down Expand Up @@ -269,15 +270,16 @@ def _atomic_save(self, checkpoint, filepath: str):
filepath: The path to which the checkpoint will be saved.
This points to the file that the checkpoint will be stored in.
"""
tmp_path = str(filepath) + ".part"
bytesbuffer = io.BytesIO()
# Can't use the new zipfile serialization for 1.6.0 because there's a bug in
# torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
# More details can be found here: https://github.com/pytorch/pytorch/issues/42239
if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False)
torch.save(checkpoint, bytesbuffer, _use_new_zipfile_serialization=False)
else:
torch.save(checkpoint, tmp_path)
os.replace(tmp_path, filepath)
torch.save(checkpoint, bytesbuffer)
with cloud_open(filepath, 'wb') as f:
f.write(bytesbuffer.getvalue())

def save_checkpoint(self, filepath, weights_only: bool = False):
checkpoint = self.dump_checkpoint(weights_only)
Expand Down
11 changes: 10 additions & 1 deletion pytorch_lightning/utilities/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ def load(path_or_url: str, map_location=None):
return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)


def is_remote_path(path: pathlike):
    """Check whether *path* points at a remote location.

    A path counts as remote when it carries a URL-style scheme separator,
    e.g. ``s3://bucket/path``, ``hdfs://...`` or ``gcs://...``; anything
    without ``://`` is treated as a local filesystem path.
    """
    path_str = str(path)
    return "://" in path_str


def modern_gfile():
"""Check the version number of tensorboard.
Expand Down Expand Up @@ -61,6 +69,7 @@ def cloud_open(path: pathlike, mode: str, newline: str = None):

def makedirs(path: pathlike):
    """Create *path* (including parents), supporting remote URIs when possible.

    When a modern ``gfile`` implementation is available (full tensorboard /
    tensorflow install), delegate to ``gfile.makedirs`` which understands
    remote paths (``s3://``, ``gs://``, ...) and does not fail when the
    directory already exists (called with exist_ok semantics).  Otherwise only
    minimal dependencies are installed, so only local paths can work and we
    fall back to ``os.makedirs``.
    """
    if hasattr(gfile, "makedirs") and modern_gfile():
        return gfile.makedirs(str(path))
    # otherwise minimal dependencies are installed and only local files will work
    # (the legacy gfile shim may not provide makedirs at all, so never call it here)
    return os.makedirs(path, exist_ok=True)

0 comments on commit 56396ab

Please sign in to comment.