[air] pyarrow.fs persistence (5/n): ray.train.Checkpoint save direction (ray-project#37888)

This PR:
1. Uses the storage context to upload the new `ray.train.Checkpoint` (from ray-project#37925)
directly from the Train worker.
2. Gets checkpoint reporting to work in the save direction, bypassing the Train `CheckpointManager` in favor of a single, simplified checkpoint manager (from ray-project#37962). See the sketch after this list.
3. Updates the e2e test to check for worker-uploaded checkpoints.
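
As a rough sketch of the save direction from a worker's perspective (illustrative only: the training function, file names, and metric values are made up, and the `Checkpoint` import path mirrors the internal module touched in this diff):

```python
import os
import tempfile

from ray import train
from ray.train._checkpoint import Checkpoint


def train_fn(config):
    for epoch in range(3):
        # ... run one epoch of training ...
        # Write checkpoint files into a local temporary directory.
        tmpdir = tempfile.mkdtemp()
        with open(os.path.join(tmpdir, "model.pt"), "wb") as f:
            f.write(b"model-bytes")
        # The worker-side session persists this checkpoint directly to
        # (storage_filesystem, checkpoint_fs_path); per this PR it also
        # deletes the local directory after the upload completes.
        train.report(
            {"epoch": epoch, "loss": 1.0 / (epoch + 1)},
            checkpoint=Checkpoint.from_directory(tmpdir),
        )
```

Passed as the `train_loop_per_worker` of a `DataParallelTrainer`, each such `train.report` call routes through the `new_report` / `new_checkpoint` session methods added below when the storage context is enabled.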

### Follow-ups needed

1. `Trial` path resolution is still incorrect (it uses the legacy path), which causes some issues with the custom-filesystem test case; that test currently skips some assertions. This fix is up next.
2. Trial restoration is explicitly disabled at the moment. This is up next as well.
3. Artifacts are currently synced by the driver (the Train worker lives on the same node), which is why the test case passes. The upload should instead happen from the worker, and the test case should be updated to check for that.
4. The `on_checkpoint` hook for `tune.Callback` takes in a `_TrackedCheckpoint`. For now, I skip invoking the callbacks; it is TBD what to expose to user callbacks here.
5. Checkpoints cannot yet be ordered by auto-filled metrics, only by user-specified metrics. For example, `CheckpointConfig(checkpoint_score_attribute="training_iteration", checkpoint_score_order="min")` does not work yet; see the sketch below for a configuration that does.
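
For contrast, ordering checkpoints by a user-reported metric is the currently supported path; a minimal sketch (the metric name and `num_to_keep` value are illustrative):

```python
from ray.air.config import CheckpointConfig, RunConfig

# Keep the two checkpoints with the lowest user-reported "loss".
run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="loss",
        checkpoint_score_order="min",
    )
)
```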

Signed-off-by: NripeshN <nn2012@hw.ac.uk>
justinvyu authored and NripeshN committed Aug 15, 2023
1 parent 5b5f393 commit 332c343
Showing 11 changed files with 476 additions and 100 deletions.
2 changes: 1 addition & 1 deletion python/ray/train/BUILD
@@ -464,7 +464,7 @@ py_test(

py_test(
    name = "test_new_persistence",
-    size = "small",
+    size = "medium",
    srcs = ["tests/test_new_persistence.py"],
    tags = ["team:ml", "exclusive"],
    deps = [":train_lib", ":conftest"]
43 changes: 43 additions & 0 deletions python/ray/train/_internal/session.py
@@ -440,6 +440,46 @@ def _set_legacy_checkpoint_uri(self, uri: str):
        """
        self.legacy_checkpoint_uri = uri

    def new_checkpoint(self, checkpoint):
        from ray.train._checkpoint import Checkpoint as NewCheckpoint

        if not isinstance(checkpoint, NewCheckpoint):
            raise ValueError(
                "You must pass a `ray.train.checkpoint.Checkpoint` "
                "object to `train.report`. `ray.air.Checkpoint` is deprecated."
            )

        # Persist the reported checkpoint files to storage.
        persisted_checkpoint = self.storage.persist_current_checkpoint(checkpoint)

        self.loaded_checkpoint = persisted_checkpoint

        metadata = self._auto_fill_checkpoint_metrics({})

        # Save the rank of the worker that created this checkpoint.
        metadata.update({CHECKPOINT_RANK_KEY: self.world_rank})

        result = TrainingResult(
            type=TrainingResultType.CHECKPOINT,
            data=persisted_checkpoint,
            metadata=metadata,
        )

        # Add result to a thread-safe queue.
        self.result_queue.put(result, block=True)

        # Acquire lock to stop the training thread until
        # checkpoint has been processed.
        self.continue_lock.acquire()

    def new_report(self, metrics: Dict, checkpoint=None) -> None:
        if checkpoint:
            self.new_checkpoint(checkpoint)

        # TODO(justinvyu): Unify checkpoint / report logic to just report a single
        # (metrics, Checkpoint) result for the consumer to handle.
        self._report_legacy(**metrics)

    def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None:
        # TODO(xwjiang): tons of optimizations.

@@ -457,6 +497,9 @@ def report(self, metrics: Dict, checkpoint: Optional[Checkpoint] = None) -> None
                    "store your Torch objects."
                )

        if _use_storage_context():
            return self.new_report(metrics, checkpoint=checkpoint)

        if checkpoint:
            self.checkpoint(checkpoint)
        self._report_legacy(**metrics)
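
The queue-and-lock handoff at the end of `new_checkpoint` is the usual producer/consumer pattern between the user's training thread and the session's result consumer; a generic, self-contained sketch of that pattern (the names and the semaphore standing in for `continue_lock` are assumptions, not the session's actual internals):

```python
import queue
import threading

result_queue = queue.Queue(maxsize=1)
# Starts at 0 so that acquire() blocks until the consumer releases it.
continue_lock = threading.Semaphore(0)


def report_from_training_thread(result):
    # Producer (training thread): hand the result to the consumer...
    result_queue.put(result, block=True)
    # ...then block until the consumer has finished processing it.
    continue_lock.acquire()


def consume_one_result():
    # Consumer (session thread): process a result, then unblock the producer.
    result = result_queue.get(block=True)
    print(f"processed: {result}")
    continue_lock.release()


t = threading.Thread(target=report_from_training_thread, args=({"loss": 0.1},))
t.start()
consume_one_result()
t.join()
```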
54 changes: 53 additions & 1 deletion python/ray/train/_internal/storage.py
@@ -4,7 +4,7 @@
import os
from pathlib import Path
import shutil
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING

try:
    import fsspec
@@ -30,6 +30,9 @@
from ray.tune.syncer import Syncer, SyncConfig, _BackgroundSyncer
from ray.tune.result import _get_defaults_results_dir

if TYPE_CHECKING:
    from ray.train._checkpoint import Checkpoint


logger = logging.getLogger(__file__)

@@ -472,6 +475,55 @@ def _check_validation_file(self):
                "to the configured storage path."
            )

    def persist_current_checkpoint(self, checkpoint: "Checkpoint") -> "Checkpoint":
        """Persists a given checkpoint to the current checkpoint path on the filesystem.

        "Current" is defined by the `current_checkpoint_index` attribute of the
        storage context.

        This method copies the checkpoint files to the storage location,
        drops a marker at the storage path to indicate that the checkpoint
        is completely uploaded, then deletes the original checkpoint directory.
        For example, the original directory is typically a local temp directory.

        Args:
            checkpoint: The checkpoint to persist to (fs, checkpoint_fs_path).

        Returns:
            Checkpoint: A Checkpoint pointing to the persisted checkpoint location.
        """
        # TODO(justinvyu): Fix this cyclical import.
        from ray.train._checkpoint import Checkpoint

        logger.debug(
            "Copying checkpoint files to storage path:\n"
            "({source_fs}, {source}) -> ({dest_fs}, {destination})".format(
                source=checkpoint.path,
                destination=self.checkpoint_fs_path,
                source_fs=checkpoint.filesystem,
                dest_fs=self.storage_filesystem,
            )
        )
        self.storage_filesystem.create_dir(self.checkpoint_fs_path)
        _pyarrow_fs_copy_files(
            source=checkpoint.path,
            destination=self.checkpoint_fs_path,
            source_filesystem=checkpoint.filesystem,
            destination_filesystem=self.storage_filesystem,
        )

        # Delete local checkpoint files.
        # TODO(justinvyu): What if checkpoint.path == self.checkpoint_fs_path?
        # TODO(justinvyu): What if users don't want to delete the local checkpoint?
        checkpoint.filesystem.delete_dir(checkpoint.path)

        uploaded_checkpoint = Checkpoint(
            filesystem=self.storage_filesystem,
            path=self.checkpoint_fs_path,
        )
        logger.debug(f"Checkpoint successfully created at: {uploaded_checkpoint}")
        return uploaded_checkpoint

    @property
    def experiment_path(self) -> str:
        """The path the experiment directory, where the format matches the
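
For reference, `persist_current_checkpoint` delegates the copy to a small pyarrow wrapper; a standalone sketch of the same operation using `pyarrow.fs` directly (the bucket, paths, and local filesystem here are assumptions, and the internal `_pyarrow_fs_copy_files` helper may add behavior such as retries):

```python
import pyarrow.fs

# Source: a local temporary checkpoint directory written by the worker.
source_fs = pyarrow.fs.LocalFileSystem()
source_path = "/tmp/checkpoint_tmp_abc123"

# Destination: the storage filesystem and checkpoint path for this iteration.
dest_fs, dest_path = pyarrow.fs.FileSystem.from_uri(
    "s3://my-bucket/my-experiment/trial_0/checkpoint_000000"
)

dest_fs.create_dir(dest_path)
pyarrow.fs.copy_files(
    source_path,
    dest_path,
    source_filesystem=source_fs,
    destination_filesystem=dest_fs,
)

# Per this PR, the local copy is removed once the upload has finished.
source_fs.delete_dir(source_path)
```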
17 changes: 15 additions & 2 deletions python/ray/train/data_parallel_trainer.py
@@ -6,7 +6,7 @@
from ray._private.thirdparty.tabulate.tabulate import tabulate

import ray
-from ray import tune
+from ray import train, tune
from ray.air.checkpoint import Checkpoint
from ray.air._internal.checkpointing import add_preprocessor_to_checkpoint
from ray.air.config import DatasetConfig, RunConfig, ScalingConfig, CheckpointConfig
@@ -17,6 +17,7 @@
from ray.train._internal.backend_executor import BackendExecutor, TrialInfo
from ray.train._internal.checkpoint import TuneCheckpointManager
from ray.train._internal.data_config import DataConfig, _LegacyDataConfigWrapper
from ray.train._internal.storage import _use_storage_context
from ray.train._internal.utils import construct_train_func
from ray.train.constants import TRAIN_DATASET_KEY, WILDCARD_KEY
from ray.train.trainer import BaseTrainer, GenDataset
@@ -429,7 +430,19 @@ def _report(self, training_iterator: TrainingIterator) -> None:
        for results in training_iterator:
            # TODO(ml-team): add ability to report results from multiple workers.
            first_worker_results = results[0]
-            tune.report(**first_worker_results)
+            if _use_storage_context():
+                assert (
+                    isinstance(first_worker_results, tuple)
+                    and len(first_worker_results) == 2
+                )
+                metrics, checkpoint = first_worker_results
+                logger.debug(
+                    "Report (metrics, checkpoint) to the Tune session:\n"
+                    f" metrics={metrics}\n checkpoint={checkpoint}"
+                )
+                train.report(metrics, checkpoint=checkpoint)
+            else:
+                tune.report(**first_worker_results)

    def training_loop(self) -> None:
        scaling_config = self._validate_scaling_config(self.scaling_config)
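
Under the storage context, each per-worker result that reaches the trainer is expected to be a `(metrics, checkpoint)` tuple rather than a bare metrics dict; a schematic of the rank-0 forwarding done in `_report` (the iterator, result shapes, and `report_fn` here are stand-ins for illustration):

```python
from typing import Any, Dict, Iterable, List, Optional, Tuple

# (metrics, checkpoint) as produced by each Train worker.
WorkerResult = Tuple[Dict[str, Any], Optional[Any]]


def forward_rank0_results(
    training_iterator: Iterable[List[WorkerResult]], report_fn
) -> None:
    """Forward the rank-0 worker's (metrics, checkpoint) result upward."""
    for results in training_iterator:
        metrics, checkpoint = results[0]
        report_fn(metrics, checkpoint=checkpoint)


# Example with a fake iterator: collect results instead of reporting to Tune.
collected = []
fake_iterator = [[({"loss": 0.5}, None)], [({"loss": 0.25}, "checkpoint-1")]]
forward_rank0_results(
    fake_iterator, lambda m, checkpoint=None: collected.append((m, checkpoint))
)
print(collected)
```

In the actual `_report`, the forwarding call is `train.report`, which passes the rank-0 result on to the Tune session.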
[Diffs for the remaining 7 changed files are not shown here.]
