Refactor checkpointing, bring back legacy sharded checkpointing as the default (#316)

Co-authored-by: Iz Beltagy <iz@beltagy.net>
epwalsh and ibeltagy committed Oct 8, 2023
1 parent fed4cf3 commit d2e84fe
Showing 8 changed files with 483 additions and 154 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -17,6 +17,7 @@ pip-wheel-metadata/
.vscode/
/*.iml
pyrightconfig.json
.ruff.toml


# jupyter notebooks
274 changes: 258 additions & 16 deletions olmo/checkpoint.py
@@ -1,30 +1,251 @@
"""
Custom distributed checkpointing.
"""

import io
import logging
import pickle
import shutil
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, Tuple, cast

import torch
import torch.distributed.checkpoint as dist_cp
from packaging import version
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint.filesystem import WriteResult, _StorageInfo
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
from torch.distributed.checkpoint.planner import LoadItemType
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.planner import LoadItemType, ReadItem
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.api import (
ShardedOptimStateDictConfig,
ShardedStateDictConfig,
)
from torch.futures import Future

from .aliases import PathOrStr
from .util import get_bytes_range, resource_path, upload
from .optim import Optimizer, fix_optim_state_dict
from .util import (
barrier,
default_thread_count,
dir_is_empty,
get_bytes_range,
get_fs_local_rank,
resource_path,
upload,
)

__all__ = ["RemoteFileSystemWriter", "RemoteFileSystemReader"]
__all__ = [
"save_fsdp_model_and_optim_state",
"load_fsdp_model_and_optim_state",
"load_fsdp_optim_state",
"save_state_dict",
"load_state_dict",
"load_model_state",
"RemoteFileSystemWriter",
"RemoteFileSystemReader",
]


log = logging.getLogger(__name__)

MODEL_AND_OPTIM_FOLDER = "model_and_optim"


def save_fsdp_model_and_optim_state(
checkpoint_dir: PathOrStr,
fsdp_model: FSDP,
optim: Optimizer,
*,
upload_to: Optional[str] = None,
save_overwrite: bool = False,
):
"""
Use this to save a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
functions. This should be used during distributed training and should be called by all ranks.
:param checkpoint_dir: The directory to save to.
:param fsdp_model: The FSDP model.
:param optim: The FSDP model's optimizer.
:param upload_to: Optional, a remote "directory" to upload the checkpoint files to.
:param save_overwrite: Overwrite existing files.
:raises FileExistsError: If a model and optim checkpoint already exists in ``checkpoint_dir`` and ``save_overwrite=False``.
"""
checkpoint_dir = Path(checkpoint_dir)
target_dir = checkpoint_dir / MODEL_AND_OPTIM_FOLDER
if save_overwrite:
if get_fs_local_rank() == 0:
shutil.rmtree(target_dir, ignore_errors=True)
elif not dir_is_empty(target_dir):
raise FileExistsError(target_dir)
barrier()
if get_fs_local_rank() == 0:
target_dir.mkdir(exist_ok=True, parents=True)
barrier()
with FSDP.state_dict_type(
fsdp_model,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
):
model_and_optim_state = {
"model": fsdp_model.state_dict(),
"optim": FSDP.optim_state_dict(fsdp_model, optim),
}
dist_cp.save_state_dict(
model_and_optim_state,
RemoteFileSystemWriter(
target_dir,
upload_to=None if upload_to is None else f"{upload_to.rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}",
save_overwrite=save_overwrite,
),
)
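
For orientation, a minimal usage sketch of the new save helper (illustrative only, not part of this diff; the local directory and remote prefix are hypothetical, and ``fsdp_model``/``optim`` are assumed to be an FSDP-wrapped model and its optimizer). Every rank calls it collectively:

    # Illustrative sketch: all ranks call this with the same arguments.
    save_fsdp_model_and_optim_state(
        "checkpoints/step1000",                            # hypothetical local dir
        fsdp_model,
        optim,
        upload_to="s3://my-bucket/checkpoints/step1000",   # hypothetical remote prefix
        save_overwrite=True,
    )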


def load_fsdp_model_and_optim_state(
checkpoint_dir: PathOrStr,
fsdp_model: FSDP,
optim: Optimizer,
*,
local_cache: Optional[PathOrStr] = None,
load_optimizer_state: bool = True,
):
"""
Use this to load a state dict for an FSDP model and its optimizer via :module:`torch.distributed.checkpoint`
functions. This should be used during distributed training and should be called by all ranks.
:param checkpoint_dir: The checkpoint directory to load from. This can be a local or remote directory.
:param fsdp_model: The FSDP model.
:param optim: The FSDP model's optimizer.
:param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
remote "directory" but there might be a cached version of the same artifacts.
:param load_optimizer_state: Set to ``False`` to skip loading the optimizer state.
:raises FileNotFoundError: If the ``checkpoint_dir`` doesn't contain a model and optimizer checkpoint.
"""
load_path = str(checkpoint_dir).rstrip("/")
local_cache = None if local_cache is None else Path(local_cache)
with FSDP.state_dict_type(
fsdp_model,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
):
# Load the model state dict in place.
log.info("Loading model state...")
model_state = {"model": fsdp_model.state_dict()}
dist_cp.load_state_dict(
model_state,
RemoteFileSystemReader(
f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
),
)
fsdp_model.load_state_dict(model_state["model"])

if not load_optimizer_state:
return

# Load optim state dict in place.
log.info("Loading sharded optimizer state...")
optim_state = load_sharded_optimizer_state_dict(
model_state_dict=model_state["model"],
optimizer_key="optim",
storage_reader=RemoteFileSystemReader(
f"{load_path}/{MODEL_AND_OPTIM_FOLDER}",
local_cache=None if local_cache is None else local_cache / MODEL_AND_OPTIM_FOLDER,
),
)
del model_state
torch.cuda.empty_cache()
load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])
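
The matching load path, sketched under the same hypothetical paths (again called by all ranks; the remote directory is read directly, with the local directory used as a cache):

    # Illustrative sketch: restores the model and optimizer shards in place.
    load_fsdp_model_and_optim_state(
        "s3://my-bucket/checkpoints/step1000",   # hypothetical remote checkpoint dir
        fsdp_model,
        optim,
        local_cache="checkpoints/step1000",
        load_optimizer_state=True,
    )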


def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[str, Any]):
log.info("Flattening sharded optimizer state...")
    # NOTE: Careful! The order of these arguments has changed from 2.0 to 2.1... ¯\_(ツ)_/¯
if version.parse(torch.__version__) < version.parse("2.1.0"):
flattened_osd = FSDP.optim_state_dict_to_load(optim_state, fsdp_model, optim) # type: ignore
else:
flattened_osd = FSDP.optim_state_dict_to_load(fsdp_model, optim, optim_state) # type: ignore
del optim_state
log.info("Loading flattened optimizer state...")
# Put optim state on CPU since `Optimizer.load_state_dict()` will create a deepcopy of the whole state dict,
# which takes up unnecessary GPU memory.
for state in flattened_osd["state"].values():
for k in state.keys():
v = state[k]
if isinstance(v, torch.Tensor):
state[k] = v.to(device="cpu")
torch.cuda.empty_cache()
optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))


def save_state_dict(
checkpoint_dir: PathOrStr,
fname: str,
state_dict: Dict[str, Any],
*,
upload_to: Optional[str] = None,
save_overwrite: bool = False,
):
"""
Save a regular state dict to the file ``fname`` within ``checkpoint_dir`` using :func:`torch.save()`.
    This can be used during distributed training or not. If used during distributed training, ``fname`` should
    be unique for each rank.
:param checkpoint_dir: The directory to save to.
:param fname: The target file within ``checkpoint_dir`` to save to. This should be a path relative to the ``checkpoint_dir``.
:param state_dict: The state dict to save.
:param upload_to: Optional, a remote "directory" to upload the file to.
:param save_overwrite: Overwrite existing files.
:raises FileExistsError: If the ``fname`` already exists within ``checkpoint_dir`` and ``save_overwrite=False``.
"""
checkpoint_dir = Path(checkpoint_dir)
target_path = checkpoint_dir / fname
if save_overwrite:
target_path.unlink(missing_ok=True)
elif target_path.is_file():
raise FileExistsError(target_path)
barrier()
target_path.parent.mkdir(exist_ok=True, parents=True)
barrier()
torch.save(state_dict, target_path)
if upload_to is not None:
upload_target = f"{upload_to.rstrip('/')}/{fname}"
log.info(f"Uploading {target_path} to {upload_target}...")
upload(target_path, upload_target, save_overwrite=save_overwrite)
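
As a sketch of the per-rank file-name convention mentioned in the docstring (the file names and state dict below are hypothetical examples, not something this commit prescribes):

    import torch.distributed as dist

    # Illustrative sketch: give each rank its own file under the checkpoint dir.
    save_state_dict(
        "checkpoints/step1000",
        f"train/rank{dist.get_rank()}.pt",
        {"global_step": 1000},                             # hypothetical trainer state
        upload_to="s3://my-bucket/checkpoints/step1000",
    )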


def load_state_dict(checkpoint_dir: PathOrStr, fname: str, *, local_cache: Optional[PathOrStr] = None):
"""
Load a regular state dict from the file ``fname`` within ``checkpoint_dir`` using :func:`torch.load()`.
This can be used during distributed training or not.
:param checkpoint_dir: A local or remote checkpoint directory.
:param fname: The target file within the ``checkpoint_dir``. This should be a path relative to the ``checkpoint_dir``.
:param local_cache: A local cache of the checkpoint directory. Use this when the ``checkpoint_dir`` is a
remote "directory" but there might be a cached version of the same artifacts.
:raises FileNotFoundError: If ``fname`` doesn't exist in the ``checkpoint_dir`` or the local cache.
"""
return torch.load(resource_path(str(checkpoint_dir).rstrip("/"), fname, local_cache=local_cache))
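
And the corresponding load, which can resolve ``fname`` from a local cache of a remote checkpoint directory; paths are again hypothetical:

    # Illustrative sketch: load one rank's trainer state back (rank 0 shown).
    trainer_state = load_state_dict(
        "s3://my-bucket/checkpoints/step1000",
        "train/rank0.pt",
        local_cache="checkpoints/step1000",
    )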


def load_model_state(checkpoint_dir: PathOrStr, model: torch.nn.Module):
"""
Load model state from a distributed FSDP model checkpoint created from :func:`save_fsdp_model_and_optim_state()`.
Note that ``model`` should not be wrapped with FSDP.
"""
state_dict = {"model": model.state_dict()}
dist_cp.load_state_dict(
state_dict,
RemoteFileSystemReader(f"{str(checkpoint_dir).rstrip('/')}/{MODEL_AND_OPTIM_FOLDER}"),
no_dist=True,
)
model.load_state_dict(state_dict["model"])
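
A sketch of how this could be used to pull weights into a plain, unwrapped module for inference (the model construction is hypothetical):

    # Illustrative sketch: `model` is a plain nn.Module, not wrapped with FSDP.
    model = build_model(model_config)                      # hypothetical constructor
    load_model_state("checkpoints/step1000", model)
    model.eval()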


class RemoteFileSystemWriter(dist_cp.FileSystemWriter):
"""
@@ -37,16 +258,18 @@ def __init__(
path: PathOrStr,
single_file_per_rank: bool = True,
sync_files: bool = True,
thread_count: int = 1,
thread_count: Optional[int] = None,
per_thread_copy_ahead: int = 10_000_000,
upload_to: Optional[str] = None,
save_overwrite: bool = False,
) -> None:
if thread_count is not None and thread_count <= 0:
raise ValueError("thread count must be at least 1")
super().__init__(
path,
single_file_per_rank=single_file_per_rank,
sync_files=sync_files,
thread_count=thread_count,
thread_count=thread_count or default_thread_count(),
per_thread_copy_ahead=per_thread_copy_ahead,
)
self.upload_to = None if upload_to is None else upload_to.rstrip("/")
@@ -89,23 +312,40 @@ class RemoteFileSystemReader(dist_cp.StorageReader):
that can read data directly from cloud storage as well as a local directory.
"""

def __init__(self, path: PathOrStr, local_cache: Optional[PathOrStr] = None):
def __init__(
self, path: PathOrStr, *, local_cache: Optional[PathOrStr] = None, thread_count: Optional[int] = None
):
super().__init__()
if thread_count is not None and thread_count <= 0:
raise ValueError("thread count must be at least 1")
self.path = str(path).rstrip("/")
self.cache = None if local_cache is None else Path(local_cache)
self.thread_count = thread_count or default_thread_count()
self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()
self._metadata: Optional[Metadata] = None

def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
if self.cache is not None and (path := self.cache / relative_path).is_file():
return get_bytes_range(path, offset, length)
else:
return get_bytes_range(f"{self.path}/{relative_path}", offset, length)

def _get_content_for_read(self, read_item: ReadItem) -> Tuple[ReadItem, bytes]:
sinfo = self.storage_data[read_item.storage_index]
content = self._get_bytes(sinfo.relative_path, sinfo.offset, sinfo.length)
return (read_item, content)

def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
read_item_content_futures = []
for read_item in plan.items:
read_item_content_futures.append(executor.submit(self._get_content_for_read, read_item))
read_item_content_results = []
for f in as_completed(read_item_content_futures):
read_item_content_results.append(f.result())

# Modified from `FileSystemReader.read_data()`
for read_item in plan.items:
sinfo = self.storage_data[read_item.storage_index]
content = self._get_bytes(sinfo.relative_path, sinfo.offset, sinfo.length)
for read_item, content in read_item_content_results:
bytes = io.BytesIO(content)
bytes.seek(0)
if read_item.type == LoadItemType.BYTE_IO:
@@ -126,8 +366,10 @@ def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
return fut

def read_metadata(self) -> Metadata:
with resource_path(self.path, ".metadata", local_cache=self.cache).open("rb") as metadata_file:
return pickle.load(metadata_file)
if self._metadata is None:
with resource_path(self.path, ".metadata", local_cache=self.cache).open("rb") as metadata_file:
self._metadata = pickle.load(metadata_file)
return self._metadata

def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
del is_coordinator
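Because the reader is a standard ``StorageReader`` implementation, it can also be handed directly to ``dist_cp.load_state_dict`` outside of the helpers above; a hedged sketch with hypothetical paths and a hypothetical unwrapped ``model``:

    # Illustrative sketch: read a sharded checkpoint straight from cloud storage.
    reader = RemoteFileSystemReader(
        "s3://my-bucket/checkpoints/step1000/model_and_optim",
        local_cache="/tmp/olmo-cache/model_and_optim",
        thread_count=8,
    )
    state = {"model": model.state_dict()}                  # hypothetical unwrapped model
    dist_cp.load_state_dict(state, reader, no_dist=True)
    model.load_state_dict(state["model"])
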
16 changes: 13 additions & 3 deletions olmo/config.py
@@ -305,9 +305,14 @@ class ModelConfig(BaseConfig):
layer_norm_with_affine: bool = True
"""
Whether to include bias and weight parameters for the layer norms.
This only affects layer norms that are immediately followed by a linear layer in the forward pass.
Other layer norms, such as those applied to attention keys and queries, will always include an elementwise
affine transform.
This only affects layer norms that are immediately followed by a linear layer in the forward pass,
so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
to ``False``.
"""

attention_layer_norm_with_affine: bool = True
"""
Toggle affine transform for the QK norms.
"""

max_sequence_length: int = 1024
@@ -804,6 +809,11 @@ class TrainConfig(BaseConfig):
curve (according to the current learning rate schedule settings), and continues from there.
"""

new_style_checkpoints: bool = False
"""
Whether to use new-style sharded checkpointing or not.
"""

@property
def autocast_precision(self) -> torch.dtype:
if self.precision == "amp_bf16":
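To tie the two new config fields together, a hedged sketch of setting them in Python, assuming ``cfg`` is a loaded ``TrainConfig`` with its usual ``model: ModelConfig`` field (the values are illustrative, not recommended defaults):

    # Illustrative sketch: drop all layer-norm affines, including the QK norms,
    # and opt in to the new-style sharded checkpointing.
    cfg.model.layer_norm_with_affine = False
    cfg.model.attention_layer_norm_with_affine = False
    cfg.new_style_checkpoints = True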