
Commit

Merge branch 'sbak/pyt-dist-async-comm-opt' into 'main'
Optimize metadata communication for dist-checkpointing in saving through reuse of cached metadata

See merge request ADLR/megatron-lm!1470
ericharper committed Jul 1, 2024
2 parents 5161b16 + 6421b1c commit 0bc3547
Showing 7 changed files with 261 additions and 48 deletions.
51 changes: 34 additions & 17 deletions megatron/core/dist_checkpointing/strategies/async_utils.py
@@ -16,7 +16,7 @@


class AsyncRequest(NamedTuple):
""" Represents an async request that needs to be scheduled for execution.
"""Represents an async request that needs to be scheduled for execution.
Args:
async_fn (Callable, optional): async function to call. None represents noop.
@@ -32,7 +32,7 @@ class AsyncRequest(NamedTuple):
is_frozen: bool = False

def add_finalize_fn(self, fn: Callable) -> None:
""" Adds a new finalize function to the request.
"""Adds a new finalize function to the request.
Args:
fn (Callable): function to add to the async request. This function
@@ -46,7 +46,7 @@ def add_finalize_fn(self, fn: Callable) -> None:
self.finalize_fns.append(fn)

def execute_sync(self) -> None:
""" Helper to synchronously execute the request.
"""Helper to synchronously execute the request.
This logic is equivalent to what should happen in case of the async call.
"""
@@ -57,7 +57,7 @@ def execute_sync(self) -> None:
finalize_fn()

def freeze(self) -> 'AsyncRequest':
""" Freezes the async request, disallowing adding new finalization functions.
"""Freezes the async request, disallowing adding new finalization functions.
Returns:
AsyncRequest: new async request with all same fields except for the
@@ -67,7 +67,7 @@ def freeze(self) -> 'AsyncRequest':


class DistributedAsyncCaller:
""" Wrapper around mp.Process that ensures correct semantic of distributed finalization.
"""Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
"""
@@ -76,9 +76,13 @@ def __init__(self):
self.process: Optional[mp.Process] = None
self.start_time: Optional[float] = None

def schedule_async_call(self, async_fn: Optional[Callable], save_args: Tuple,) -> None:
""" Spawn a process with `async_fn` as the target.
def schedule_async_call(
self,
async_fn: Optional[Callable],
save_args: Tuple,
) -> None:
"""Spawn a process with `async_fn` as the target.
This method must be called on all ranks.
Args:
@@ -88,14 +92,27 @@ def schedule_async_call(self, async_fn: Optional[Callable], save_args: Tuple,) -
"""
if async_fn is None:
return # nothing to do
start_sync = time()
torch.cuda.synchronize()
end_sync = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, takes {end_sync - start_sync} to finish D2H "
)

ctx = mp.get_context('fork')
self.start_time = time()
self.process = ctx.Process(target=async_fn, args=save_args,)
self.process = ctx.Process(
target=async_fn,
args=save_args,
)
self.process.start()
init_time = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, takes {init_time - self.start_time} to schedule async ckpt "
)

def is_current_async_call_done(self, blocking=False) -> bool:
""" Check if async save is finished on all ranks.
"""Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check.
This method must be called on all ranks.
@@ -132,7 +149,7 @@ def is_current_async_call_done(self, blocking=False) -> bool:


class _ActiveAsyncRequest(NamedTuple):
""" Helper to represent an active async call.
"""Helper to represent an active async call.
Args:
idx (int): index of the call (starting from 0)
@@ -147,7 +164,7 @@ class _ActiveAsyncRequest(NamedTuple):


class AsyncCallsQueue:
""" Manages a queue of async calls.
"""Manages a queue of async calls.
Allows adding a new async call with `schedule_async_request` and finalizing
active calls with `maybe_finalize_async_calls`.
@@ -158,8 +175,8 @@ def __init__(self):
self.call_idx: int = -1

def schedule_async_request(self, async_request: AsyncRequest) -> int:
""" Start a new async call and add it to a queue of active async calls.
"""Start a new async call and add it to a queue of active async calls.
This method must be called on all ranks.
Args:
@@ -177,7 +194,7 @@ def schedule_async_request(self, async_request: AsyncRequest) -> int:
return self.call_idx

def maybe_finalize_async_calls(self, blocking=False) -> List[int]:
""" Finalizes all available calls.
"""Finalizes all available calls.
This method must be called on all ranks.
@@ -206,9 +223,9 @@ def maybe_finalize_async_calls(self, blocking=False) -> List[int]:
return call_idx_finalized

def get_num_unfinalized_calls(self):
""" Get the number of active async calls. """
"""Get the number of active async calls."""
return len(self.async_calls)

def close(self):
""" Finalize all calls upon closing. """
"""Finalize all calls upon closing."""
self.maybe_finalize_async_calls(blocking=True)
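For context, a minimal usage sketch of the queue API above. The save function, its arguments, and the positional field order of `AsyncRequest` (async function, its args, finalize functions) are assumptions based on the docstrings in this file; none of this is part of the commit itself.

```python
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest


def persist_shards(path, tensors):
    """Hypothetical async_fn; runs in the forked process on each rank."""
    ...


queue = AsyncCallsQueue()

# Field order (async_fn, its args, finalize_fns) assumed from the docstrings above.
request = AsyncRequest(persist_shards, ('/tmp/ckpt', []), [])
# Finalize functions run on the main process once the forked process has exited on all ranks.
request.add_finalize_fn(lambda: print('checkpoint finalized'))

call_idx = queue.schedule_async_request(request)  # must be called on all ranks

# At a later training step, finalize whatever calls have finished (non-blocking check).
finished = queue.maybe_finalize_async_calls(blocking=False)
if call_idx in finished:
    print('async checkpoint done')

queue.close()  # blocks until all outstanding calls are finalized
```

The non-blocking check is what lets training overlap checkpoint finalization with subsequent iterations.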
Original file line number Diff line number Diff line change
@@ -111,7 +111,8 @@ def gen_file():
self.write_results = ctx.Manager().dict()
else:
self.write_results = {}
logger.debug(f"D2H and push, time: {time() - start}")
end = time()
logger.debug(f"D2H and push, time: {end - start}")

def get_save_function_and_args(self) -> Tuple[Optional[Callable], Tuple]:
"""
@@ -197,7 +198,11 @@ def write_preloaded_data(
f"{local_proc_idx} consumed: {mem_after - mem_before}, before: {mem_before}, after: {mem_after}"
)

def write_data(self, plan: SavePlan, planner: SavePlanner,) -> Future[List[WriteResult]]:
def write_data(
self,
plan: SavePlan,
planner: SavePlanner,
) -> Future[List[WriteResult]]:
raise NotImplementedError('write_data not implemented for FileSystemWriterAsync')

def retrieve_write_results(self) -> List[WriteResult]:
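To show where this writer plugs into the async machinery above, here is a rough sketch assuming the writer has already had `prepare_write_data` called on it. Only `get_save_function_and_args` and `retrieve_write_results` appear in this diff; the helper name and the finalize logic are illustrative.

```python
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncRequest


def build_async_request(storage_writer):
    """storage_writer: a FileSystemWriterAsync after prepare_write_data(...)."""
    # The writer exposes the function (and its arguments) that the forked
    # process will execute to persist the preloaded data.
    async_fn, async_fn_args = storage_writer.get_save_function_and_args()

    def finalize_fn():
        # Runs on the main process after every rank's forked writer has exited;
        # the per-process write results feed the metadata finalization step.
        return storage_writer.retrieve_write_results()

    # Positional field order assumed, as in the sketch after async_utils.py above.
    return AsyncRequest(async_fn, async_fn_args, [finalize_fn])
```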
15 changes: 13 additions & 2 deletions megatron/core/dist_checkpointing/strategies/fully_parallel.py
@@ -93,15 +93,23 @@ def __init__(

self.cached_distribution: Optional[SaveLoadDistribution] = None

def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
def async_save(
self,
sharded_state_dict: ShardedStateDict,
checkpoint_dir: Path,
):
if not isinstance(self.base_strategy, AsyncSaveShardedStrategy):
raise CheckpointingException(
f'Cannot apply async_save to non-async base strategy {self.base_strategy}'
)
self.apply_saving_parallelization(sharded_state_dict)
return self.base_strategy.async_save(sharded_state_dict, checkpoint_dir)

def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
def save(
self,
sharded_state_dict: ShardedStateDict,
checkpoint_dir: Path,
):
self.apply_saving_parallelization(sharded_state_dict)
return self.base_strategy.save(sharded_state_dict, checkpoint_dir)

@@ -120,6 +128,7 @@ def apply_saving_parallelization(self, sharded_state_dict: ShardedStateDict) ->
Returns: None
"""
start = time()
if self.do_cache_distribution and self.cached_distribution is not None:
logger.debug(f'Apply *cached* save parallelization')
precomputed_distribution = self.cached_distribution
@@ -137,6 +146,8 @@ def apply_saving_parallelization(self, sharded_state_dict: ShardedStateDict) ->
validate_sharding_integrity(nested_values(sharded_state_dict))
if self.do_cache_distribution:
self.cached_distribution = precomputed_distribution
end = time()
logger.debug(f"parallel save sharding, time: {end - start}")

@property
def can_handle_sharded_objects(self):
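A sketch of how the cached distribution might be exercised from the caller side. The wrapper class name, its constructor arguments, and the base strategy are assumptions not shown in this excerpt; only `save`, `async_save`, `do_cache_distribution`, and `cached_distribution` appear above.

```python
from pathlib import Path

# Class names and constructor arguments below are assumed, not part of this diff.
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
    FullyParallelSaveStrategyWrapper,
)
from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy

base = TorchDistSaveShardedStrategy('torch_dist', 1)  # assumed async-capable base strategy
wrapper = FullyParallelSaveStrategyWrapper(
    base,
    parallelization_group=None,   # default process group; argument name assumed
    do_cache_distribution=True,   # keep the computed save distribution for reuse
)

sharded_state_dict = {}  # placeholder; a real ShardedStateDict is required

# The first save computes the shard distribution and stores it in
# wrapper.cached_distribution; later saves reuse it instead of recomputing.
wrapper.async_save(sharded_state_dict, Path('/tmp/ckpt/iter_0001'))
wrapper.async_save(sharded_state_dict, Path('/tmp/ckpt/iter_0002'))
```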
52 changes: 40 additions & 12 deletions megatron/core/dist_checkpointing/strategies/state_dict_saver.py
@@ -11,7 +11,7 @@
from torch.distributed.checkpoint import CheckpointException
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner
from torch.distributed.checkpoint.utils import _DistWrapper, _get_failure_dict

if TYPE_CHECKING:
@@ -27,7 +27,8 @@ def save_state_dict_async_plan(
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
planner: Optional[SavePlanner] = None,
) -> Tuple['FileSystemWriterAsync', Metadata, _DistWrapper]:
cached_ckpt_structure: Optional[Tuple[SavePlan, SavePlan, bool]] = None,
) -> Tuple[Tuple['FileSystemWriterAsync', Metadata, _DistWrapper], SavePlan, SavePlan, bool]:
"""
First stage of saving a state dict to storage.
@@ -50,55 +51,82 @@
process_group (dist.ProcessGroup, optional): process group used for save planning
coordinator_rank (int, optional): coordinator rank for planning. Defaults to 0.
planner (SavePlanner, optional): save planner for torch.distributed.checkpoint format
cached_ckpt_structure (Tuple[SavePlan, SavePlan, bool], optional):
The elements of this tuple are used in the following order:
cached_central_plan (SavePlan): the globally coordinated save plan
cached from the previous iteration
cached_local_plan (SavePlan): the local save plan
cached from the previous iteration
validated_cache_reuse (bool): flag indicating that the global metadata and the planning
dict are consistent across iterations
Returns: Tuple of:
- storage writer (the one passed as input)
- metadata from planning
- distributed wrapper used for planning
The return value of this function should be passed as an input to
`save_state_dict_async_finalize`.
`save_state_dict_async_finalize`, and the cached plans can be passed back in to skip the `reduce_scatter` step of planning.
"""
cached_central_plan, cached_local_plan, validated_cache_reuse = (None, None, False)
if cached_ckpt_structure:
cached_central_plan, cached_local_plan, validated_cache_reuse = cached_ckpt_structure

rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
dist_wrapper = _DistWrapper(process_group, True, coordinator_rank)
if planner is None:
planner = DefaultSavePlanner()
assert planner is not None

global_metadata = None
logger.debug(f"rank: {rank}, starting state dict save")
local_plan = cached_local_plan

def local_step():
nonlocal local_plan
assert planner is not None
planner.set_up_planner(state_dict, dist_wrapper.is_coordinator)
storage_writer.set_up_storage_writer(dist_wrapper.is_coordinator)
local_plan = planner.create_local_plan()
if not validated_cache_reuse and local_plan is None:
local_plan = planner.create_local_plan()
local_plan = storage_writer.prepare_local_plan(local_plan)
return local_plan

def global_step(all_local_plans):
nonlocal global_metadata

assert planner is not None
all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
return all_local_plans

# Execute local and global planning
start_plan = time()
central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step)
logger.debug(f"rank: {rank}, plan time: {time() - start_plan}")

if validated_cache_reuse and cached_central_plan:
logger.debug(f"rank: {rank}, Passed cache reusable")
local_step()
central_plan = cached_central_plan
else:
central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step)
central_plan = planner.finish_plan(central_plan)
end_plan = time()
logger.debug(f"rank: {rank}, plan time: {end_plan - start_plan}")
# Prepare async writing of tensors.
# The `storage_writer` will store the information about tensors it needs to save
start = time()
final_local_plan = planner.finish_plan(central_plan)
storage_writer.prepare_write_data(final_local_plan, planner)
storage_writer.prepare_write_data(central_plan, planner)
end = time()
logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}")
return storage_writer, cast(Metadata, global_metadata), dist_wrapper
return (
(storage_writer, cast(Metadata, global_metadata), dist_wrapper),
central_plan,
local_plan,
cached_central_plan == central_plan,
)


def save_state_dict_async_finalize(
storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper,
storage_writer: 'FileSystemWriterAsync',
global_metadata: Metadata,
dist_wrapper: _DistWrapper,
) -> None:
"""
Finalization of save_state_dict_async_plan.
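A sketch, not part of this commit, of how a caller might thread the cached plans returned by `save_state_dict_async_plan` through successive checkpoint saves so that the planning `reduce_scatter` is skipped once the cache is validated. The first two positional arguments (state dict and storage writer), the surrounding bookkeeping, and the deferred finalize wrapper are assumptions; only the two functions from state_dict_saver.py above are taken from the diff.

```python
from typing import Optional, Tuple

from megatron.core.dist_checkpointing.strategies.state_dict_saver import (
    save_state_dict_async_finalize,
    save_state_dict_async_plan,
)

# (cached_central_plan, cached_local_plan, validated_cache_reuse) from the previous save.
cached_ckpt_structure: Optional[Tuple] = None


def plan_async_save(state_dict, storage_writer):
    """Run the planning stage for one checkpoint save; must be called on all ranks."""
    global cached_ckpt_structure
    (writer, global_metadata, dist_wrapper), central_plan, local_plan, cache_reuse = (
        save_state_dict_async_plan(
            state_dict,
            storage_writer,
            cached_ckpt_structure=cached_ckpt_structure,
        )
    )
    # Cache the plans; once cache_reuse is True, the next call reuses the cached
    # central plan directly and skips the reduce_scatter in planning.
    cached_ckpt_structure = (central_plan, local_plan, cache_reuse)

    # Tensor writing itself happens asynchronously (see async_utils.py and the
    # FileSystemWriterAsync sections above). Once it completes, run this to
    # commit the checkpoint metadata:
    return lambda: save_state_dict_async_finalize(writer, global_metadata, dist_wrapper)
```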