Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support S3 checkpointing for the torch strategy in distributed checkpointing #748

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions megatron/core/dist_checkpointing/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
""" Module for managing distributed checkpoints metadata. """

import json
import os
from cloudpathlib import AnyPath
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume it requires extra dependencies.
Can we make it optional?

from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional

CONFIG_FNAME = 'metadata.json'
Expand Down Expand Up @@ -33,7 +34,7 @@ class CheckpointingConfig:
common_backend_version: int = 1


def check_is_distributed_checkpoint(checkpoint_dir):
def check_is_distributed_checkpoint(checkpoint_dir: str):
""" Checks if `metadata.json` exists in the checkpoint and is a valid config.

Args:
Expand All @@ -54,7 +55,7 @@ def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
Returns:
CheckpointingConfig (optional): None if checkpoint is not a valid distributed checkpoint
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME)
config_path = AnyPath(os.path.join(checkpoint_dir, CONFIG_FNAME))
if not config_path.exists():
return None
with config_path.open() as f:
Expand All @@ -72,6 +73,6 @@ def save_config(config: CheckpointingConfig, checkpoint_dir: str):
Returns:
None
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME)
config_path = AnyPath(os.path.join(checkpoint_dir, CONFIG_FNAME))
with config_path.open('w') as f:
json.dump(asdict(config), f)
60 changes: 39 additions & 21 deletions megatron/core/dist_checkpointing/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
"""

import logging
import io
import os
from collections import Counter, defaultdict
from itertools import chain
from pathlib import Path
from cloudpathlib import AnyPath, S3Path
from torch.serialization import MAP_LOCATION
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -57,6 +59,22 @@
logger = logging.getLogger(__name__)


def _save(obj: object, path: AnyPath):
fileobj = io.BytesIO()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm worried 2-stage writing might affect performance in the baseline scenario (no cloud).
Can we make it optional as well?

torch.save(obj, fileobj)
fileobj.seek(0)
with path.open('wb') as f:
f.write(fileobj.read())


def _load(path: AnyPath, map_location: MAP_LOCATION = None):
fileobj = io.BytesIO()
with path.open('rb') as f:
fileobj.write(f.read())
fileobj.seek(0)
return torch.load(fileobj, map_location)


def load(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
Expand Down Expand Up @@ -92,7 +110,7 @@ def load(

sharded_strategy = _verify_checkpoint_and_load_strategy(checkpoint_dir, sharded_strategy)

checkpoint_dir = Path(checkpoint_dir)
checkpoint_dir = AnyPath(checkpoint_dir)
common_state_dict = load_common_state_dict(checkpoint_dir)
if not sharded_state_dict:
return common_state_dict
Expand Down Expand Up @@ -121,7 +139,7 @@ def load(
if validate_access_integrity:
validate_sharding_integrity(nested_values(sharded_state_dict))

loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)
loaded_state_dict = sharded_strategy.load(sharded_state_dict, str(checkpoint_dir))

loaded_state_dict = apply_factory_merges(loaded_state_dict, sh_ten_factories)

Expand All @@ -140,7 +158,7 @@ def _verify_checkpoint_and_load_strategy(
if compatible with the checkpoint content. If None, the default load strategy
for the checkpoint backend will be returned.
"""
if not Path(checkpoint_dir).exists():
if not AnyPath(checkpoint_dir).exists():
raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist')

saved_config = maybe_load_config(checkpoint_dir)
Expand All @@ -161,31 +179,31 @@ def _verify_checkpoint_and_load_strategy(


# TODO: implement it as common torch strategy
def load_common_state_dict(checkpoint_dir: AnyPath) -> StateDict:
    """ Load common (non-sharded) objects state dict from the checkpoint.

    Args:
        checkpoint_dir (AnyPath): checkpoint directory (local path or cloud
            URI, e.g. s3://)

    Returns:
        StateDict: state dict with non-sharded objects from the checkpoint

    Raises:
        CheckpointingException: if the common state file is missing from
            the checkpoint directory.
    """
    load_path = AnyPath(checkpoint_dir) / COMMON_STATE_FNAME
    try:
        # Tensors are mapped to CPU so loading does not depend on the
        # device layout of the process that saved the checkpoint.
        return _load(load_path, map_location='cpu')
    except FileNotFoundError as e:
        err_msg = f'Common file {load_path} does not exist'
        # List the directory contents to help diagnose what is present.
        ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
        logger.debug(f'{err_msg}. Checkpoint directory content: {ckpt_files}')
        raise CheckpointingException(err_msg) from e


def load_sharded_objects(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
def load_sharded_objects(sharded_state_dict: ShardedStateDict, checkpoint_dir: AnyPath):
""" Replaces all ShardedObject from a given state dict with values loaded from the checkpoint.

Args:
sharded_state_dict (ShardedStateDict): sharded state dict defining what objects should be loaded.
checkpoint_dir (Path): checkpoint directory
checkpoint_dir (AnyPath): checkpoint directory

Returns:
None: state dict is modified in place
Expand All @@ -198,7 +216,7 @@ def load_sharded_object(sh_obj: ShardedObject):
sh_obj.data = None
load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
try:
loaded_obj = torch.load(load_path)
loaded_obj = _load(load_path)
except FileNotFoundError as e:
err_msg = f'Object shard {load_path} not found'
obj_subdir = checkpoint_dir / sh_obj.key
Expand Down Expand Up @@ -232,7 +250,7 @@ def load_tensors_metadata(
given, a default for a given backend is used.
"""
sharded_strategy = _verify_checkpoint_and_load_strategy(checkpoint_dir, sharded_strategy)
return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir))
return sharded_strategy.load_tensors_metadata(str(checkpoint_dir))


def load_plain_tensors(checkpoint_dir: str):
Expand Down Expand Up @@ -277,10 +295,10 @@ def save(
validate_access_integrity (bool default = True): checks if each tensor shard is accessed
exactly once (as main replica) by some process
"""
checkpoint_dir = Path(checkpoint_dir)
checkpoint_dir = AnyPath(checkpoint_dir)

if torch.distributed.get_rank() == 0:
if not checkpoint_dir.exists():
if (not isinstance(checkpoint_dir, S3Path)) and (not checkpoint_dir.exists()):
raise CheckpointingException(
f'Checkpoint destination directory does not exist: {checkpoint_dir}'
)
Expand Down Expand Up @@ -313,20 +331,20 @@ def save(
sharded_state_dict, checkpoint_dir, validate_access_integrity
)

sharded_strategy.save(sharded_state_dict, checkpoint_dir)
sharded_strategy.save(sharded_state_dict, str(checkpoint_dir))
if torch.distributed.get_rank() == 0:
save_config(
CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), checkpoint_dir
CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), str(checkpoint_dir)
)
torch.distributed.barrier()


# TODO: implement it as common torch strategy
def _save_common_dict(
state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
state_dict: StateDict, checkpoint_dir: AnyPath, validate_consistency: bool = False
):
if torch.distributed.get_rank() == 0:
torch.save(state_dict, checkpoint_dir / COMMON_STATE_FNAME)
_save(state_dict, checkpoint_dir / COMMON_STATE_FNAME)
if validate_consistency:
# TODO: implement checking consistency with rank 0 common dict on other ranks
pass
Expand All @@ -337,7 +355,7 @@ def _save_common_dict(


def _extract_and_save_sharded_objects(
state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
state_dict: StateDict, checkpoint_dir: AnyPath, validate_consistency: bool = False
):
sharded_objects, state_dict = extract_matching_values(
state_dict, lambda v: isinstance(v, ShardedObject)
Expand All @@ -346,8 +364,8 @@ def _extract_and_save_sharded_objects(
for sh_obj in sharded_objects:
if is_main_replica(sh_obj.replica_id):
save_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
os.makedirs(save_path.parent, exist_ok=True)
torch.save(sh_obj.data, save_path)
save_path.parent.mkdir(exist_ok=True)
_save(sh_obj.data, save_path)
return state_dict


Expand Down
11 changes: 5 additions & 6 deletions megatron/core/dist_checkpointing/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional

from ..mapping import CheckpointingException, ShardedStateDict, ShardedTensor, StateDict
Expand Down Expand Up @@ -77,19 +76,19 @@ class LoadCommonStrategy(LoadStrategyBase):
""" Load strategy for common (non-sharded) objects """

@abstractmethod
def load(self, checkpoint_dir: Path):
def load(self, checkpoint_dir: str):
raise NotImplementedError


class LoadShardedStrategy(LoadStrategyBase):
""" Load strategy for sharded tensors """

@abstractmethod
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: str):
raise NotImplementedError

@abstractmethod
def load_tensors_metadata(self, checkpoint_dir: Path):
def load_tensors_metadata(self, checkpoint_dir: str):
"""Load tensors metadata from the checkpoint.

Returns a dictionary similar to a sharded state dict, but note that
Expand All @@ -108,13 +107,13 @@ class SaveCommonStrategy(SaveStrategyBase):
""" Save strategy for common (non-sharded) objects """

@abstractmethod
def save(self, common_state_dict: StateDict, checkpoint_dir: Path):
def save(self, common_state_dict: StateDict, checkpoint_dir: str):
raise NotImplementedError


class SaveShardedStrategy(SaveStrategyBase):
    """ Save strategy for sharded tensors """

    @abstractmethod
    def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: str):
        """Persist all sharded tensors from `sharded_state_dict`.

        Args:
            sharded_state_dict (ShardedStateDict): state dict of sharded
                tensors to persist
            checkpoint_dir (str): destination directory; a plain string so
                cloud URIs (e.g. s3://) can be passed as well as local paths
        """
        raise NotImplementedError
18 changes: 9 additions & 9 deletions megatron/core/dist_checkpointing/strategies/tensorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from functools import partial
from itertools import starmap
from logging import getLogger
from pathlib import Path

import os
import tensorstore as ts
import torch

Expand All @@ -30,7 +30,7 @@ def __init__(self, load_directly_on_device: bool = False):
super().__init__()
self.load_directly_on_device = load_directly_on_device

def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: str):
if torch.distributed.get_rank() == 0:
print(f'Loading distributed checkpoint with {self.__class__.__name__}')
if self.load_directly_on_device:
Expand All @@ -43,7 +43,7 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
dict_list_map_inplace(load_fn, sharded_state_dict)
return sharded_state_dict

def load_tensors_metadata(self, checkpoint_dir: Path):
def load_tensors_metadata(self, checkpoint_dir: str):
def get_ts_shape_dtype(path):
arr = open_ts_array(path)
return arr.shape, arr.dtype.numpy_dtype
Expand Down Expand Up @@ -73,7 +73,7 @@ def _merge_slice(dim_slice, dim_size):

def _load_from_array(
sharded_tensor: ShardedTensor,
checkpoint_dir: Path,
checkpoint_dir: str,
load_directly_on_device: bool = False,
apply_flattened_range: bool = True,
):
Expand All @@ -86,9 +86,9 @@ def _load_from_array(
return ten


def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: str):
assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
arr = open_ts_array(checkpoint_dir / sharded_tensor.key)
arr = open_ts_array(os.path.join(checkpoint_dir, sharded_tensor.key))
if sharded_tensor.global_shape == arr.shape:
x = (
arr[sharded_tensor.global_slice()].read().result()
Expand All @@ -108,16 +108,16 @@ def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
return x


def open_ts_array(arr_path: Path):
def open_ts_array(arr_path: str):
"""Opens a Zarr file array with Tensorstore with basic setting.

Arguments:
arr_path (Path): path to a Zarr (Tensorstore) array
arr_path (str): path to a Zarr (Tensorstore) array
"""
spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}}
spec['kvstore'] = {
'driver': 'file',
'path': str(arr_path),
'path': arr_path,
}
try:
arr = ts.open(ts.Spec(spec), open=True).result()
Expand Down
Loading