[Storage Cleaner] Unsharding improvements #483

Merged · 6 commits · Mar 6, 2024
78 changes: 70 additions & 8 deletions scripts/storage_cleaner.py
@@ -15,6 +15,7 @@
 import boto3.session
 import botocore.exceptions as boto_exceptions
 import google.cloud.storage as gcs
+import omegaconf
 import torch
 import wandb
 from boto3.s3.transfer import TransferConfig
@@ -622,6 +623,8 @@ class DeleteBadRunsConfig(StorageCleanerConfig):
 @dataclass
 class UnshardCheckpointsConfig(StorageCleanerConfig):
     latest_checkpoint_only: bool
+    delete_sharded_checkpoints: bool
+    checkpoint_num: Optional[int]


 @dataclass
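A hypothetical construction of the extended config, for reference (field names as in the dataclass above; dry_run and temp_dir come from the StorageCleanerConfig base, as wired up in perform_operation further down). latest_checkpoint_only and checkpoint_num are mutually exclusive, which _get_sharded_checkpoint_dirs enforces:

    # Hypothetical usage sketch; the temp dir path is made up for illustration.
    config = UnshardCheckpointsConfig(
        dry_run=True,                      # log actions without touching storage
        temp_dir="/tmp/storage_cleaner",   # scratch space, from the base config
        latest_checkpoint_only=False,      # must stay False when checkpoint_num is set
        delete_sharded_checkpoints=False,  # delete sharded copies after unsharding
        checkpoint_num=5000,               # only unshard the step5000 checkpoint
    )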
@@ -765,9 +768,13 @@ def delete_bad_runs(run_paths: List[str], config: DeleteBadRunsConfig):
     shutil.rmtree(config.temp_dir)


-def _is_sharded_checkpoint_dir(directory: str) -> bool:
+def _is_checkpoint_dir(directory: str) -> bool:
     storage = _get_storage_adapter_for_path(directory)
-    return storage.is_dir(directory) and re.match(r"step\d+$", Path(directory).name) is not None
+    return storage.is_dir(directory) and re.match(r"step\d+(-unsharded)?$", Path(directory).name) is not None
+
+
+def _is_sharded_checkpoint_dir(directory: str) -> bool:
+    return _is_checkpoint_dir(directory) and re.match(r"step\d+$", Path(directory).name) is not None


 def _get_checkpoint_number(checkpoint_dir: str) -> int:
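The two predicates differ only in their directory-name regexes (plus the storage is_dir check). A minimal sketch of how the patterns classify names, using a hypothetical classify helper that mirrors the matching order:

    import re

    def classify(name: str) -> str:
        # Mirrors the two predicates above, minus the storage is_dir check.
        if re.match(r"step\d+$", name):
            return "sharded checkpoint"
        if re.match(r"step\d+(-unsharded)?$", name):
            return "unsharded checkpoint"
        return "not a checkpoint"

    for name in ("step1000", "step1000-unsharded", "step1000-tmp", "latest"):
        print(f"{name}: {classify(name)}")
    # step1000: sharded checkpoint
    # step1000-unsharded: unsharded checkpoint
    # step1000-tmp: not a checkpoint
    # latest: not a checkpoint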
@@ -781,17 +788,31 @@ def _get_checkpoint_number(checkpoint_dir: str) -> int:


 def _get_sharded_checkpoint_dirs(
-    run_dir_storage: StorageAdapter, run_dir: str, run_dir_or_archive: str, latest_checkpoint_only: bool
+    run_dir_storage: StorageAdapter,
+    run_dir: str,
+    run_dir_or_archive: str,
+    latest_checkpoint_only: bool,
+    checkpoint_num: Optional[int] = None,
 ) -> List[str]:
     run_subdir_names = run_dir_storage.list_dirs(run_dir)
     run_subdirectories = list(map(lambda dir_name: os.path.join(run_dir, dir_name), run_subdir_names))
     sharded_checkpoint_directories = list(filter(_is_sharded_checkpoint_dir, run_subdirectories))

+    if latest_checkpoint_only and checkpoint_num is not None:
+        raise ValueError("Cannot set both 'latest_checkpoint_only' and 'checkpoint_num'")
+
     if latest_checkpoint_only:
         latest_checkpoint_directory = max(sharded_checkpoint_directories, default=None, key=_get_checkpoint_number)
         sharded_checkpoint_directories = (
             [latest_checkpoint_directory] if latest_checkpoint_directory is not None else []
         )
+    elif checkpoint_num is not None:
+        sharded_checkpoint_directories = [
+            sharded_checkpoint_dir
+            for sharded_checkpoint_dir in sharded_checkpoint_directories
+            if _get_checkpoint_number(sharded_checkpoint_dir) == checkpoint_num
+        ]
+        assert len(sharded_checkpoint_directories) <= 1
Review comment on the assert above:

Member:
Will this script warn anywhere if there are no matching checkpoints here?

@2015aroras (Contributor Author), Mar 5, 2024:
It will log at info level (line 817) that there are no matching directories and then exit gracefully. There will be ~10 lines of logs from the whole program in this scenario, so it won't be a needle in a haystack.

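The reply above concerns the empty case of a selection that reduces to three mutually exclusive modes: all sharded checkpoints, the latest only, or a single step number. A self-contained sketch of those semantics on bare directory names (select_checkpoints and step_number are hypothetical helpers; the real function resolves full paths through the storage adapter):

    import re
    from typing import List, Optional

    def step_number(name: str) -> int:
        match = re.match(r"step(\d+)$", name)
        assert match is not None
        return int(match.group(1))

    def select_checkpoints(
        names: List[str], latest_only: bool = False, checkpoint_num: Optional[int] = None
    ) -> List[str]:
        if latest_only and checkpoint_num is not None:
            raise ValueError("Cannot set both 'latest_only' and 'checkpoint_num'")
        sharded = [n for n in names if re.match(r"step\d+$", n)]
        if latest_only:
            latest = max(sharded, default=None, key=step_number)
            return [latest] if latest is not None else []
        if checkpoint_num is not None:
            # Step numbers are unique per run, so at most one name can match.
            return [n for n in sharded if step_number(n) == checkpoint_num]
        return sharded

    names = ["step500", "step1000", "step1000-unsharded"]
    print(select_checkpoints(names))                      # ['step500', 'step1000']
    print(select_checkpoints(names, latest_only=True))    # ['step1000']
    print(select_checkpoints(names, checkpoint_num=500))  # ['step500']
    print(select_checkpoints(names, checkpoint_num=123))  # [] -- logged, then graceful exit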
     log.info(
         "Found %d sharded checkpoint directories for %s", len(sharded_checkpoint_directories), run_dir_or_archive
@@ -844,13 +865,29 @@ def _unshard_checkpoint(
     sharding_output_dir = local_storage.create_temp_dir(directory=unsharding_config.temp_dir)

     try:
-        config = TrainConfig.load(Path(sharding_input_dir) / "config.yaml", validate_paths=False)
-        sharded_checkpoint_type = config.sharded_checkpointer
+        # `TrainConfig` is not backwards-compatible with all older checkpoints, so
+        # we need to load the yaml directly.
+        raw_config = om.load(str(Path(sharding_input_dir) / "config.yaml"))
+        assert isinstance(raw_config, omegaconf.DictConfig)
+
+        sharded_checkpoint_type_str = raw_config.get("sharded_checkpointer", "torch_legacy")
+        if sharded_checkpoint_type_str == "legacy":
+            # At some point, the enum string for ShardedCheckpointerType.torch_legacy was "legacy"
+            sharded_checkpoint_type_str = "torch_legacy"
+
+        sharded_checkpoint_type = ShardedCheckpointerType[sharded_checkpoint_type_str]
+
+        # The ShardedCheckpointers require a `TrainConfig` to be passed in, but
+        # legacy configs are not all compatible with this class. None of the config
+        # settings are needed for unsharding, so we pass in a dummy config instead.
+        # This is a hack, but decoupling unsharding from checkpoint saving/loading
+        # seems like overkill.
+        dummy_config = TrainConfig.new()
         checkpointer: Checkpointer
         if sharded_checkpoint_type == ShardedCheckpointerType.torch_legacy:
-            checkpointer = TorchLegacyShardedCheckpointer(config)
+            checkpointer = TorchLegacyShardedCheckpointer(dummy_config)
         elif sharded_checkpoint_type == ShardedCheckpointerType.local:
-            checkpointer = LocalShardedCheckpointer(config)
+            checkpointer = LocalShardedCheckpointer(dummy_config)
         else:
             raise NotImplementedError(sharded_checkpoint_type)

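The fix above sidesteps TrainConfig validation entirely: load the raw yaml with omegaconf (apparently aliased as om in this script), read a single key with a default, and map the legacy enum string forward. A standalone sketch of that pattern, with a stand-in enum in place of the real ShardedCheckpointerType:

    import enum

    from omegaconf import DictConfig, OmegaConf

    class ShardedCheckpointerType(enum.Enum):
        # Stand-in for the real enum; only the member names matter here.
        torch_legacy = "torch_legacy"
        local = "local"

    # As if loaded from an old checkpoint's config.yaml.
    raw_config = OmegaConf.create("sharded_checkpointer: legacy")
    assert isinstance(raw_config, DictConfig)

    type_str = raw_config.get("sharded_checkpointer", "torch_legacy")
    if type_str == "legacy":
        # Older checkpoints serialized torch_legacy as "legacy".
        type_str = "torch_legacy"

    print(ShardedCheckpointerType[type_str])  # ShardedCheckpointerType.torch_legacy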
@@ -911,11 +948,14 @@ def _unshard_checkpoints(
 ):
     log.info("Starting unsharding checkpoints of run directory or archive %s", run_dir_or_archive)

+    if config.delete_sharded_checkpoints and _is_archive(run_dir_or_archive, run_storage):
+        raise ValueError("Cannot delete sharded checkpoints of run archive files")
+
     run_dir = _unarchive_if_archive(run_dir_or_archive, run_storage)
     run_dir_storage = _get_storage_adapter_for_path(run_dir)

     sharded_checkpoint_directories = _get_sharded_checkpoint_dirs(
-        run_dir_storage, run_dir, run_dir_or_archive, config.latest_checkpoint_only
+        run_dir_storage, run_dir, run_dir_or_archive, config.latest_checkpoint_only, config.checkpoint_num
     )
     for sharded_checkpoint_directory in sharded_checkpoint_directories:
         sharded_checkpoint_dir_name = Path(sharded_checkpoint_directory).name
@@ -947,6 +987,14 @@ def _unshard_checkpoints(
         log.info("Unsharding sharded checkpoint %s to %s", sharded_checkpoint_directory, dest_directory)
         _unshard_checkpoint(sharded_checkpoint_directory, dest_directory, run_dir, config)
+
+        if config.delete_sharded_checkpoints:
+            assert run_dir == run_dir_or_archive
+            if config.dry_run:
+                log.info("Would delete sharded checkpoint %s", sharded_checkpoint_directory)
+            else:
+                log.info("Deleting sharded checkpoint %s", sharded_checkpoint_directory)
+                run_dir_storage.delete_path(sharded_checkpoint_directory)


 def unshard_run_checkpoints(run_path: str, checkpoints_dest_dir: str, config: UnshardCheckpointsConfig):
     storage = _get_storage_adapter_for_path(run_path)
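The assert run_dir == run_dir_or_archive above can never fire because archives are rejected up front when deletion is requested, and _unarchive_if_archive returns non-archive paths unchanged. A condensed, self-contained sketch of the guard-plus-dry-run flow (is_archive and the logger are hypothetical stand-ins for the script's _is_archive and storage adapter):

    import logging

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger("storage_cleaner_sketch")

    def is_archive(path: str) -> bool:
        # Hypothetical stand-in for the script's _is_archive check.
        return path.endswith((".tar", ".tar.gz", ".zip"))

    def unshard_then_maybe_delete(path: str, delete_sharded: bool, dry_run: bool) -> None:
        if delete_sharded and is_archive(path):
            raise ValueError("Cannot delete sharded checkpoints of run archive files")
        # ... unsharding would happen here ...
        if delete_sharded:
            if dry_run:
                log.info("Would delete sharded checkpoint %s", path)
            else:
                log.info("Deleting sharded checkpoint %s", path)
                # The real code calls run_dir_storage.delete_path(path) here.

    unshard_then_maybe_delete("runs/my-run/step1000", delete_sharded=True, dry_run=True)
    # INFO:storage_cleaner_sketch:Would delete sharded checkpoint runs/my-run/step1000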
@@ -1252,6 +1300,8 @@ def perform_operation(args: argparse.Namespace):
             dry_run=args.dry_run,
             temp_dir=temp_dir,
             latest_checkpoint_only=args.latest_checkpoint_only,
+            delete_sharded_checkpoints=args.delete_sharded_checkpoints,
+            checkpoint_num=args.checkpoint_num,
         )
         if args.run_path is not None:
             unshard_run_checkpoints(args.run_path, args.dest_dir, unshard_checkpoints_config)
@@ -1327,6 +1377,18 @@ def _add_unsharding_subparser(subparsers: _SubParsersAction):
         action="store_true",
         help="If set, only the latest checkpoint of each run (if sharded) is unsharded.",
     )
+    unsharding_runs_parser.add_argument(
+        "--delete_sharded",
+        dest="delete_sharded_checkpoints",
+        action="store_true",
+        help="If set, deletes sharded checkpoints after they have been successfully unsharded.",
+    )
+    unsharding_runs_parser.add_argument(
+        "--checkpoint_num",
+        type=int,
+        default=None,
+        help="If provided, unsharding is restricted to this checkpoint of the run.",
+    )


 def _add_move_subparser(subparsers: _SubParsersAction):
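Taken together, the new flags compose with the existing unsharding entry point. Hypothetical invocations, assuming the subcommand keeps the unshard <run_path> <dest_dir> shape suggested by unshard_run_checkpoints (check the script's --help for the exact interface):

    # Unshard only checkpoint step5000 of a run, deleting the sharded copy on success:
    python scripts/storage_cleaner.py unshard <run_path> <dest_dir> --checkpoint_num 5000 --delete_sharded

    # Unshard only the latest checkpoint of the run:
    python scripts/storage_cleaner.py unshard <run_path> <dest_dir> --latest_checkpoint_only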