Merge branch 'dnarayanan/check_param_hashes' into 'main'
Compute hashes on each rank, and compare across DP replicas

See merge request ADLR/megatron-lm!1368
ericharper committed Apr 26, 2024
2 parents 57ba5a8 + 2afccb6 commit 0d983e6
Showing 5 changed files with 150 additions and 62 deletions.
82 changes: 69 additions & 13 deletions megatron/core/utils.py
@@ -1,6 +1,8 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Utility functions used throughout Megatron core"""
import array
import hashlib
import logging
import math
import operator
@@ -21,6 +23,8 @@
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedTensor

logger = logging.getLogger(__name__)


def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
@@ -194,6 +198,60 @@ def init_(tensor):
return init_


def check_param_hashes_across_dp_replicas(model: List[torch.nn.Module]) -> bool:
    """Computes hashes of all parameters in model, all-gathers hashes across DP replicas,
    and then checks for equality between the locally-computed hashes and the hashes
    from DP replica 0.
    NOTE: This function computes SHA-1 hashes on the CPU and thus needs to move all param
    tensors from GPU to CPU first; as a result, this function is not intended to be called
    very frequently in the main training loop.
    Args:
        model (List[torch.nn.Module]): List of model chunks whose parameter hashes need to
            be checked.
    Returns:
        True if all param hashes match with corresponding hash on DP replica 0, False
        otherwise.
    """

    # Compute per-parameter hashes on this rank.
    params = []
    local_param_hashes = []
    for model_chunk_id, model_chunk in enumerate(model):
        for (param_name, param) in model_chunk.named_parameters():
            param_hash = torch.frombuffer(
                array.array(
                    'B', hashlib.sha1(param.data.to("cpu").float().numpy(force=True)).digest()
                ),
                dtype=torch.uint8,
            )
            params.append((model_chunk_id, param_name, param))
            local_param_hashes.append(param_hash)
    local_param_hashes = torch.stack(local_param_hashes)

    # Collect per-parameter hashes across all ranks in DP group.
    all_param_hashes = [
        torch.zeros_like(local_param_hashes)
        for _ in range(parallel_state.get_data_parallel_world_size())
    ]
    torch.distributed.all_gather(
        all_param_hashes, local_param_hashes, group=parallel_state.get_data_parallel_group_gloo()
    )

    # Make sure local per-parameter hash matches DP rank 0.
    param_hashes_match = torch.equal(local_param_hashes, all_param_hashes[0])
    if not param_hashes_match:
        for i, (model_chunk_id, param_name, param) in enumerate(params):
            if not torch.equal(local_param_hashes[i], all_param_hashes[0][i]):
                rank = torch.distributed.get_rank()
                logger.info(
                    f"[Rank {rank}] Hash not matching for {param_name} in model chunk {model_chunk_id}"
                )
    return param_hashes_match
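As an aside (not part of this commit), the per-parameter hashing step above can be exercised in isolation. A minimal single-process sketch, assuming a recent PyTorch where Tensor.numpy(force=True) is available; the tensor contents are illustrative only:

import array
import hashlib

import torch

# A toy "parameter"; .numpy(force=True) detaches and copies to host memory,
# so the same line also works for CUDA tensors.
param = torch.randn(4, 8)

digest = hashlib.sha1(param.data.to("cpu").float().numpy(force=True)).digest()
param_hash = torch.frombuffer(array.array('B', digest), dtype=torch.uint8)

print(param_hash.shape)  # torch.Size([20]) -- a SHA-1 digest is 20 bytes

Stacking one such 20-byte tensor per parameter yields the fixed-size local_param_hashes tensor that is all-gathered over the data-parallel gloo group (gloo, since the hashes live in CPU memory).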


def make_tp_sharded_tensor_for_checkpoint(
tensor, key, tp_axis=0, replica_id=None, prepend_offsets=(), **kwargs
):
@@ -490,7 +548,6 @@ class StragglerDetector:
stop_batch (list[int]): stop time for get_batch
sock (socket): the controller socket
ctrlr (Thread): the controller thread
logger (Logger): the logger instance for this instance
"""

_configured = False
@@ -541,7 +598,6 @@ def __init__(self) -> None:
self.stop_batch = None
self.sock = None
self.ctrlr = None
self.logger = logging.getLogger(__name__)

def configure(
self,
@@ -714,9 +770,9 @@ def elapsed(self) -> Tuple[float, float, int, int, int, int]:
power = 0
clock = 0
if ls_ev != le_ev:
self.logger.warning(f"Event Start/Stop out of sync {ls_ev}/{le_ev}")
logger.warning(f"Event Start/Stop out of sync {ls_ev}/{le_ev}")
elif ls_bs != ls_be:
self.logger.warning(f"get_batch Start/Stop out of sync {ls_bs}/{ls_be}")
logger.warning(f"get_batch Start/Stop out of sync {ls_bs}/{ls_be}")
else:
temp = torch.cuda.temperature()
power = torch.cuda.power_draw()
@@ -770,7 +826,7 @@ def report(self, total_flops: float = 0.0, log_interval: int = 0) -> bool:
now = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
min_flops, min_frank, _ = o_dt.aflops[0]()
max_flops, max_frank, _ = o_dt.aflops[-1]()
self.logger.info(
logger.info(
f"{now} | "
f"MnRtt/Rnk: {o_dt.min_elapsed} | "
f"MxRtt/Rnk: {o_dt.max_elapsed} | "
@@ -791,12 +847,12 @@
line = f"^^^^ Bottom {self.mmcnt} Ranks with lowest Etpt(TF):"
for i in range(self.mmcnt):
line += f" {o_dt.aflops[i]},"
self.logger.info(line)
logger.info(line)
line = f"^^^^ Top {self.mmcnt} Ranks with highest Etpt(TF):"
shift = self.world - self.mmcnt
for i in range(self.mmcnt):
line += f" {o_dt.aflops[i+shift]},"
self.logger.info(line)
logger.info(line)
ret = True

# Check/Communicate if tracking is turned off or on
@@ -828,7 +884,7 @@ def _check_toggle(self) -> None:
self.stop = self.null_method
state = "OFF"
if self.rank == 0 and off is not self._off:
self.logger.info(f"Toggling StragglerDetector State {state}")
logger.info(f"Toggling StragglerDetector State {state}")

def _handler(self) -> None:
"""Thread function for the controller.
@@ -842,7 +898,7 @@ def _handler(self) -> None:

if self.rank == 0:
state = "OFF" if self._off else "ON"
self.logger.info(
logger.info(
f"Controller ready to recv " f"commands on port {self.port}. Current state {state}"
)
while True:
@@ -856,9 +912,9 @@
final_resp = f"{resp}{msg_len}\r\n\r\n{msg}"
conn.send(final_resp.encode())
conn.close()
self.logger.info(msg)
logger.info(msg)
except Exception as err:
self.logger.error(f"Error in stragler handler.. {str(err)}")
logger.error(f"Error in stragler handler.. {str(err)}")
return

def _controller(self):
@@ -879,7 +935,7 @@ def _controller(self):
)
self.ctrlr.start()
except Exception as err:
self.logger.warning(f"StragglerDetector cannot be controlled.. {str(err)}")
logger.warning(f"StragglerDetector cannot be controlled.. {str(err)}")

def _min_max(
self,
@@ -1086,7 +1142,7 @@ def __exit__(
ret = False
if ex_type is not None:
err = traceback.format_exception(ex_tb)
self.logger.warning(f"{str(ex_val)}\n{err}")
logger.warning(f"{str(ex_val)}\n{err}")
ret = True
self.stop()
return ret
2 changes: 2 additions & 0 deletions megatron/training/arguments.py
@@ -1007,6 +1007,8 @@ def _add_training_args(parser):
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
group.add_argument('--check-weight-hash-across-dp-replicas-interval', type=int, default=None,
                   help='Interval to check weight hashes are same across DP replicas. If not specified, weight hashes not checked.')
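For reference (not part of the diff), argparse turns the dashes in this flag into underscores on the parsed namespace, which is why training.py below reads args.check_weight_hash_across_dp_replicas_interval. A minimal sketch:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--check-weight-hash-across-dp-replicas-interval', type=int, default=None,
                    help='Interval to check weight hashes are same across DP replicas.')

# Dashes become underscores on the parsed args object.
args = parser.parse_args(['--check-weight-hash-across-dp-replicas-interval', '10'])
assert args.check_weight_hash_across_dp_replicas_interval == 10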

# deprecated
group.add_argument('--checkpoint-activations', action='store_true',
17 changes: 14 additions & 3 deletions megatron/training/training.py
@@ -2,11 +2,11 @@

"""Pretrain utilities."""

import gc
import dataclasses
from datetime import datetime
import math
import gc
import logging
import math
import os
import sys
from .log_handler import CustomHandler
@@ -19,7 +19,7 @@
import torch

from megatron.core import mpu, tensor_parallel
from megatron.core.utils import get_model_config, StragglerDetector
from megatron.core.utils import check_param_hashes_across_dp_replicas, get_model_config, StragglerDetector
from megatron.training.checkpointing import load_checkpoint
from megatron.training.checkpointing import save_checkpoint
from megatron.legacy.model import Float16Module
@@ -1057,6 +1057,17 @@ def track_e2e_metrics():
stimer.report(total_flops, args.log_interval)
total_flops = 0.0

if args.check_weight_hash_across_dp_replicas_interval is not None and \
        iteration % args.check_weight_hash_across_dp_replicas_interval == 0:
    if args.use_distributed_optimizer and args.overlap_param_gather:
        optimizer.disable_pre_hook()
    assert check_param_hashes_across_dp_replicas(model), \
        "Parameter hashes not matching across DP replicas"
    torch.distributed.barrier()
    print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
    if args.use_distributed_optimizer and args.overlap_param_gather:
        optimizer.enable_pre_hook()
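As a side note (not part of the commit), the gather-and-compare pattern used here can be demonstrated without Megatron's parallel_state, using plain torch.distributed with a gloo backend. A hedged sketch, launched with e.g. torchrun --nproc_per_node=2 sketch.py; the file name and tensor contents are illustrative, not part of the Megatron API:

import array
import hashlib

import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = dist.get_rank()

# Every rank hashes the same deterministic "parameter", so the check passes;
# a real divergence would make the comparison below fail on some rank.
param = torch.ones(16)
digest = hashlib.sha1(param.numpy()).digest()
local_hashes = torch.frombuffer(array.array('B', digest), dtype=torch.uint8)

# Gather each rank's hashes and compare the local copy with replica 0's.
all_hashes = [torch.zeros_like(local_hashes) for _ in range(dist.get_world_size())]
dist.all_gather(all_hashes, local_hashes)
if not torch.equal(local_hashes, all_hashes[0]):
    print(f"[Rank {rank}] parameter hashes diverge from rank 0")

dist.destroy_process_group()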

# Autoresume
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
2 changes: 1 addition & 1 deletion tests/functional_tests/jet_recipes/MR-gpt.yaml
@@ -81,7 +81,7 @@ products:
- {tp_size: [1], pp_size: [4], vp_size: [1], extra_args: ['"--decoupled-lr 0.0002"'], args_meta: ["decoupled_lr"]}
- {tp_size: [1], pp_size: [4], vp_size: [1], extra_args: ['"--use-distributed-optimizer --overlap-grad-reduce"'], args_meta: ["dist_optimizer_overlap_grad_reduce"]}
- {tp_size: [1], pp_size: [4], vp_size: [1], ckpt_resume: [0, 1], extra_args: ['"--use-distributed-optimizer --overlap-grad-reduce --untie-embeddings-and-output-weights"'], args_meta: ["dist_optimizer_overlap_grad_reduce_untied"]}
- {tp_size: [1], pp_size: [4], vp_size: [1], ckpt_resume: [0, 1], extra_args: ['"--use-distributed-optimizer --overlap-grad-reduce --overlap-param-gather"'], args_meta: ["dist_optimizer_overlap_grad_reduce_param_gather"]}
- {tp_size: [1], pp_size: [4], vp_size: [1], ckpt_resume: [0, 1], extra_args: ['"--use-distributed-optimizer --overlap-grad-reduce --overlap-param-gather --check-weight-hash-across-dp-replicas-interval 10"'], args_meta: ["dist_optimizer_overlap_grad_reduce_param_gather"]}
# Non-MCore, only legacy checkpoints supported
- {use_mcore: [False], use_te: [False, True], tp_size: [2], pp_size: [2], ckpt_resume: [0, 1], ckpt_format: [torch]}
- {use_mcore: [False], tp_size: [1], pp_size: [4], vp_size: [1], ckpt_resume: [0, 1], ckpt_format: [torch]}
