Skip to content

Commit

Permalink
New-style checkpointing (again) (#307)
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Oct 3, 2023
1 parent 973090f commit 602968a
Show file tree
Hide file tree
Showing 6 changed files with 358 additions and 86 deletions.
141 changes: 141 additions & 0 deletions olmo/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""
Custom distributed checkpointing.
"""

import io
import logging
import pickle
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, cast

import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint.filesystem import WriteResult, _StorageInfo
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
from torch.distributed.checkpoint.planner import LoadItemType
from torch.futures import Future

from .aliases import PathOrStr
from .util import get_bytes_range, resource_path, upload

__all__ = ["RemoteFileSystemWriter", "RemoteFileSystemReader"]


log = logging.getLogger(__name__)


class RemoteFileSystemWriter(dist_cp.FileSystemWriter):
"""
A subclass of :class:`~torch.distributed.checkpoint.FileSystemWriter` that can upload files
directly to a cloud bucket when ``upload_to`` is specified.
"""

def __init__(
self,
path: PathOrStr,
single_file_per_rank: bool = True,
sync_files: bool = True,
thread_count: int = 1,
per_thread_copy_ahead: int = 10_000_000,
upload_to: Optional[str] = None,
save_overwrite: bool = False,
) -> None:
super().__init__(
path,
single_file_per_rank=single_file_per_rank,
sync_files=sync_files,
thread_count=thread_count,
per_thread_copy_ahead=per_thread_copy_ahead,
)
self.upload_to = None if upload_to is None else upload_to.rstrip("/")
self.save_overwrite = save_overwrite

def write_data(
self,
plan: dist_cp.SavePlan,
planner: dist_cp.SavePlanner,
) -> Future[List[WriteResult]]:
fut = super().write_data(plan, planner)
if self.upload_to is not None:
files_to_upload = set()
for write_result in fut.wait():
files_to_upload.add(write_result.storage_data.relative_path)

with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
futures = []
for fname in files_to_upload:
source = self.path / fname
target = f"{self.upload_to}/{fname}"
log.info(f"Uploading {source} to {target}...")
futures.append(executor.submit(upload, source, target, save_overwrite=self.save_overwrite))
for f in as_completed(futures):
f.result()
return fut

def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
super().finish(metadata, results)
if self.upload_to is not None:
source = self.path / ".metadata"
target = f"{self.upload_to}/.metadata"
log.info(f"Uploading {source} to {target}...")
upload(source, target, save_overwrite=self.save_overwrite)


class RemoteFileSystemReader(dist_cp.StorageReader):
"""
A :class:`~torch.distributed.checkpoint.StorageReader` based on :class:`~torch.distributed.checkpoint.FileSystemReader`
that can read data directly from cloud storage as well as a local directory.
"""

def __init__(self, path: PathOrStr, local_cache: Optional[PathOrStr] = None):
super().__init__()
self.path = str(path).rstrip("/")
self.cache = None if local_cache is None else Path(local_cache)
self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict()

def _get_bytes(self, relative_path: str, offset: int, length: int) -> bytes:
if self.cache is not None and (path := self.cache / relative_path).is_file():
return get_bytes_range(path, offset, length)
else:
return get_bytes_range(f"{self.path}/{relative_path}", offset, length)

def read_data(self, plan: dist_cp.LoadPlan, planner: dist_cp.LoadPlanner) -> Future[None]:
# Modified from `FileSystemReader.read_data()`
for read_item in plan.items:
sinfo = self.storage_data[read_item.storage_index]
content = self._get_bytes(sinfo.relative_path, sinfo.offset, sinfo.length)
bytes = io.BytesIO(content)
bytes.seek(0)
if read_item.type == LoadItemType.BYTE_IO:
planner.load_bytes(read_item, bytes)
else:
tensor = cast(torch.Tensor, torch.load(bytes, map_location="cpu"))
tensor = narrow_tensor_by_index(tensor, read_item.storage_offsets, read_item.lengths)
target_tensor = planner.resolve_tensor(read_item).detach()

assert (
target_tensor.size() == tensor.size()
), f"req {read_item.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
target_tensor.copy_(tensor)
planner.commit_tensor(read_item, target_tensor)

fut: Future = Future()
fut.set_result(None)
return fut

def read_metadata(self) -> Metadata:
with resource_path(self.path, ".metadata", local_cache=self.cache).open("rb") as metadata_file:
return pickle.load(metadata_file)

def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
del is_coordinator
self.storage_data = metadata.storage_data
assert self.storage_data is not None

def prepare_local_plan(self, plan: dist_cp.LoadPlan) -> dist_cp.LoadPlan:
return plan

def prepare_global_plan(self, global_plan: List[dist_cp.LoadPlan]) -> List[dist_cp.LoadPlan]:
return global_plan
2 changes: 1 addition & 1 deletion olmo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ class TrainConfig(BaseConfig):

load_path: Optional[str] = None
"""
The path to a (sharded) training checkpoint to restore/resume from.
The path to a training checkpoint to restore/resume from.
"""

max_duration: int = 10000
Expand Down
64 changes: 50 additions & 14 deletions olmo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import logging
import math
import os
from abc import abstractmethod
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, cast

Expand All @@ -20,7 +19,13 @@

from .aliases import PathOrStr
from .beam_search import BeamSearch, Constraint, FinalSequenceScorer, Sampler
from .config import ActivationType, BlockType, LayerNormType, ModelConfig
from .config import (
ActivationType,
BlockType,
CheckpointType,
LayerNormType,
ModelConfig,
)
from .exceptions import OlmoConfigurationError
from .initialization import init_weights

Expand Down Expand Up @@ -1060,27 +1065,58 @@ def step(
)

@classmethod
def from_checkpoint(cls, checkpoint_dir: PathOrStr, device: str = "cpu") -> Olmo:
def from_checkpoint(
cls, checkpoint_dir: PathOrStr, device: str = "cpu", checkpoint_type: Optional[CheckpointType] = None
) -> Olmo:
"""
Load an OLMo model from a checkpoint.
"""
from cached_path import cached_path
from .util import resource_path

# Guess checkpoint type.
if checkpoint_type is None:
try:
if resource_path(checkpoint_dir, "model.pt").is_file():
checkpoint_type = CheckpointType.unsharded
else:
checkpoint_type = CheckpointType.sharded
except FileNotFoundError:
checkpoint_type = CheckpointType.sharded

# Load config.
config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml"))
config_path = resource_path(checkpoint_dir, "config.yaml")
model_config = ModelConfig.load(config_path, key="model", validate_paths=False)

# Initialize model (always on CPU to start with so we don't run out of GPU memory).
model_config.init_device = "cpu"
model = Olmo(model_config)
model.config.init_device = device
if checkpoint_type == CheckpointType.unsharded:
# Initialize model (always on CPU to start with so we don't run out of GPU memory).
model_config.init_device = "cpu"
model = Olmo(model_config)

# Load state dict directly to target device.
state_dict_path = resource_path(checkpoint_dir, "model.pt")
state_dict = torch.load(state_dict_path, map_location="cpu")
model.load_state_dict(model._make_state_dict_compatible(state_dict))
model = model.to(torch.device(device))
else:
from torch.distributed.checkpoint import load_state_dict

# Load state dict directly to target device.
state_dict_path = cached_path(os.path.join(checkpoint_dir, "model.pt"))
state_dict = torch.load(state_dict_path, map_location="cpu")
model.load_state_dict(model._make_state_dict_compatible(state_dict))
from .checkpoint import RemoteFileSystemReader

# Initialize model on target device. In this case the state dict is loaded in-place
# so it's not necessary to start on CPU if the target device is a GPU.
model_config.init_device = device
model = Olmo(model_config)

# Load state dict in place.
state_dict = {"model": model.state_dict()}
load_state_dict(
state_dict,
RemoteFileSystemReader(f"{str(checkpoint_dir).rstrip('/')}/model_and_optim"),
no_dist=True,
)
model.load_state_dict(state_dict["model"])

return model.to(torch.device(device)).eval()
return model.eval()

def _make_state_dict_compatible(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# For backwards compatibility prior to fixing https://github.com/allenai/LLM/issues/222
Expand Down

0 comments on commit 602968a

Please sign in to comment.