Commit 8ff696d

Merge pull request #2 from NVIDIA/main

sync with original repo

vlad-karpuhin committed Apr 29, 2024
2 parents 03c670c + 0d983e6
Showing 17 changed files with 357 additions and 125 deletions.
3 changes: 3 additions & 0 deletions Dockerfile.ci
@@ -2,3 +2,6 @@ ARG FROM_IMAGE_NAME
FROM ${FROM_IMAGE_NAME}

COPY . megatron-lm

RUN cp -r /workspace/megatron-lm /opt && \
pip install /opt/megatron-lm
2 changes: 1 addition & 1 deletion megatron/core/package_info.py
@@ -2,7 +2,7 @@


MAJOR = 0
MINOR = 6
MINOR = 7
PATCH = 0
PRE_RELEASE = 'rc0'

82 changes: 69 additions & 13 deletions megatron/core/utils.py
@@ -1,6 +1,8 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Utility functions used throughout Megatron core"""
import array
import hashlib
import logging
import math
import operator
@@ -21,6 +23,8 @@
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedTensor

logger = logging.getLogger(__name__)


def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
@@ -194,6 +198,60 @@ def init_(tensor):
return init_


def check_param_hashes_across_dp_replicas(model: List[torch.nn.Module]) -> bool:
    """Computes hashes of all parameters in model, all-gathers hashes across DP replicas,
    and then checks for equality between the locally-computed hashes and the hashes
    from DP replica 0.
    NOTE: This function computes SHA-1 hashes on the CPU and thus needs to move all param
    tensors from GPU to CPU first; as a result, this function is not intended to be called
    very frequently in the main training loop.
    Args:
        model (List[torch.nn.Module]): List of model chunks whose parameter hashes need to
            be checked.
    Returns:
        True if all param hashes match with corresponding hash on DP replica 0, False
        otherwise.
    """

    # Compute per-parameter hashes on this rank.
    params = []
    local_param_hashes = []
    for model_chunk_id, model_chunk in enumerate(model):
        for (param_name, param) in model_chunk.named_parameters():
            param_hash = torch.frombuffer(
                array.array(
                    'B', hashlib.sha1(param.data.to("cpu").float().numpy(force=True)).digest()
                ),
                dtype=torch.uint8,
            )
            params.append((model_chunk_id, param_name, param))
            local_param_hashes.append(param_hash)
    local_param_hashes = torch.stack(local_param_hashes)

    # Collect per-parameter hashes across all ranks in DP group.
    all_param_hashes = [
        torch.zeros_like(local_param_hashes)
        for _ in range(parallel_state.get_data_parallel_world_size())
    ]
    torch.distributed.all_gather(
        all_param_hashes, local_param_hashes, group=parallel_state.get_data_parallel_group_gloo()
    )

    # Make sure local per-parameter hash matches DP rank 0.
    param_hashes_match = torch.equal(local_param_hashes, all_param_hashes[0])
    if not param_hashes_match:
        for i, (model_chunk_id, param_name, param) in enumerate(params):
            if not torch.equal(local_param_hashes[i], all_param_hashes[0][i]):
                rank = torch.distributed.get_rank()
                logger.info(
                    f"[Rank {rank}] Hash not matching for {param_name} in model chunk {model_chunk_id}"
                )
    return param_hashes_match


def make_tp_sharded_tensor_for_checkpoint(
tensor, key, tp_axis=0, replica_id=None, prepend_offsets=(), **kwargs
):
@@ -490,7 +548,6 @@ class StragglerDetector:
stop_batch (list[int]): stop time for get_batch
sock (socket): the controller socket
ctrlr (Thread): the controller thread
logger (Logger): the logger instance for this instance
"""

_configured = False
@@ -541,7 +598,6 @@ def __init__(self) -> None:
self.stop_batch = None
self.sock = None
self.ctrlr = None
self.logger = logging.getLogger(__name__)

def configure(
self,
@@ -714,9 +770,9 @@ def elapsed(self) -> Tuple[float, float, int, int, int, int]:
power = 0
clock = 0
if ls_ev != le_ev:
self.logger.warning(f"Event Start/Stop out of sync {ls_ev}/{le_ev}")
logger.warning(f"Event Start/Stop out of sync {ls_ev}/{le_ev}")
elif ls_bs != ls_be:
self.logger.warning(f"get_batch Start/Stop out of sync {ls_bs}/{ls_be}")
logger.warning(f"get_batch Start/Stop out of sync {ls_bs}/{ls_be}")
else:
temp = torch.cuda.temperature()
power = torch.cuda.power_draw()
@@ -770,7 +826,7 @@ def report(self, total_flops: float = 0.0, log_interval: int = 0) -> bool:
now = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
min_flops, min_frank, _ = o_dt.aflops[0]()
max_flops, max_frank, _ = o_dt.aflops[-1]()
self.logger.info(
logger.info(
f"{now} | "
f"MnRtt/Rnk: {o_dt.min_elapsed} | "
f"MxRtt/Rnk: {o_dt.max_elapsed} | "
@@ -791,12 +847,12 @@ def report(self, total_flops: float = 0.0, log_interval: int = 0) -> bool:
line = f"^^^^ Bottom {self.mmcnt} Ranks with lowest Etpt(TF):"
for i in range(self.mmcnt):
line += f" {o_dt.aflops[i]},"
self.logger.info(line)
logger.info(line)
line = f"^^^^ Top {self.mmcnt} Ranks with highest Etpt(TF):"
shift = self.world - self.mmcnt
for i in range(self.mmcnt):
line += f" {o_dt.aflops[i+shift]},"
self.logger.info(line)
logger.info(line)
ret = True

# Check/Communicate if tracking is turned off or on
@@ -828,7 +884,7 @@ def _check_toggle(self) -> None:
self.stop = self.null_method
state = "OFF"
if self.rank == 0 and off is not self._off:
self.logger.info(f"Toggling StragglerDetector State {state}")
logger.info(f"Toggling StragglerDetector State {state}")

def _handler(self) -> None:
"""Thread function for the controller.
@@ -842,7 +898,7 @@ def _handler(self) -> None:

if self.rank == 0:
state = "OFF" if self._off else "ON"
self.logger.info(
logger.info(
f"Controller ready to recv " f"commands on port {self.port}. Current state {state}"
)
while True:
@@ -856,9 +912,9 @@ def _handler(self) -> None:
final_resp = f"{resp}{msg_len}\r\n\r\n{msg}"
conn.send(final_resp.encode())
conn.close()
self.logger.info(msg)
logger.info(msg)
except Exception as err:
self.logger.error(f"Error in stragler handler.. {str(err)}")
logger.error(f"Error in stragler handler.. {str(err)}")
return

def _controller(self):
@@ -879,7 +935,7 @@ def _controller(self):
)
self.ctrlr.start()
except Exception as err:
self.logger.warning(f"StragglerDetector cannot be controlled.. {str(err)}")
logger.warning(f"StragglerDetector cannot be controlled.. {str(err)}")

def _min_max(
self,
@@ -1086,7 +1142,7 @@ def __exit__(
ret = False
if ex_type is not None:
err = traceback.format_exception(ex_tb)
self.logger.warning(f"{str(ex_val)}\n{err}")
logger.warning(f"{str(ex_val)}\n{err}")
ret = True
self.stop()
return ret
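For orientation, here is a minimal sketch of how the new check_param_hashes_across_dp_replicas utility might be invoked outside the trainer. The verify_replica_weights wrapper is an illustrative assumption, not part of this commit, and torch.distributed plus Megatron's parallel_state are assumed to be initialized already.

# Hypothetical usage sketch (not from this commit). Assumes torch.distributed and
# megatron.core.parallel_state have been initialized, and `model_chunks` is the
# list of torch.nn.Module chunks the training script already builds.
import torch

from megatron.core.utils import check_param_hashes_across_dp_replicas


def verify_replica_weights(model_chunks):
    # Collective call: every data-parallel rank must enter it, since it
    # all-gathers SHA-1 digests over the data-parallel Gloo group.
    hashes_match = check_param_hashes_across_dp_replicas(model_chunks)
    if torch.distributed.get_rank() == 0:
        status = "match" if hashes_match else "DIVERGED"
        print(f"DP replica weight hashes {status}")
    return hashes_match

Because the hashes are computed on CPU, the training loop below only runs this check every --check-weight-hash-across-dp-replicas-interval iterations and temporarily disables the distributed optimizer's param-gather pre-hook around the collective.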
2 changes: 2 additions & 0 deletions megatron/training/arguments.py
@@ -1007,6 +1007,8 @@ def _add_training_args(parser):
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
group.add_argument('--check-weight-hash-across-dp-replicas-interval', type=int, default=None,
help='Interval to check weight hashes are same across DP replicas. If not specified, weight hashes not checked.')

# deprecated
group.add_argument('--checkpoint-activations', action='store_true',
17 changes: 14 additions & 3 deletions megatron/training/training.py
@@ -2,11 +2,11 @@

"""Pretrain utilities."""

import gc
import dataclasses
from datetime import datetime
import math
import gc
import logging
import math
import os
import sys
from .log_handler import CustomHandler
@@ -19,7 +19,7 @@
import torch

from megatron.core import mpu, tensor_parallel
from megatron.core.utils import get_model_config, StragglerDetector
from megatron.core.utils import check_param_hashes_across_dp_replicas, get_model_config, StragglerDetector
from megatron.training.checkpointing import load_checkpoint
from megatron.training.checkpointing import save_checkpoint
from megatron.legacy.model import Float16Module
@@ -1057,6 +1057,17 @@ def track_e2e_metrics():
stimer.report(total_flops, args.log_interval)
total_flops = 0.0

if args.check_weight_hash_across_dp_replicas_interval is not None and \
        iteration % args.check_weight_hash_across_dp_replicas_interval == 0:
    if args.use_distributed_optimizer and args.overlap_param_gather:
        optimizer.disable_pre_hook()
    assert check_param_hashes_across_dp_replicas(model), \
        "Parameter hashes not matching across DP replicas"
    torch.distributed.barrier()
    print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
    if args.use_distributed_optimizer and args.overlap_param_gather:
        optimizer.enable_pre_hook()

# Autoresume
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
19 changes: 9 additions & 10 deletions tests/functional_tests/jet_recipes/MR-bert.yaml
@@ -6,6 +6,7 @@ spec:
name: "{model}_{variant}_{scope}_{platforms}_{nodes}N{gpus}G_\
{'mcore_' if use_mcore else ''}{'te_' if use_te else ''}\
tp{tp_size}_pp{pp_size}{'_vp'+str(vp_size) if vp_size else ''}\
{'_resume_'+str(ckpt_format) if ckpt_resume else ''}\
{'_'+args_meta if args_meta else ''}"
model: bert
variant: 345m
@@ -14,7 +15,6 @@
nodes: 1
gpus: 8
platforms: dgx_a100
steps: 50
use_te: False
use_mcore: True
vp_size: null
@@ -25,7 +25,8 @@
precision: bf16
time_limit: 1200
artifacts: {/workspace/data/bert_data: text/the_pile/bert_shard00}
checkpoint_resume_test: 0
ckpt_format: torch_dist
ckpt_resume: 0
script: |-
ls
cd /workspace/megatron-lm
@@ -39,20 +40,18 @@
TP_SIZE={tp_size} \
PP_SIZE={pp_size} \
NUM_NODES={nodes} \
MAX_STEPS={steps} \
MAX_STEPS={100 if ckpt_resume else 50} \
USE_CORE={"1" if use_mcore else "0"} \
VP_SIZE={vp_size if vp_size is not None else '""'} \
MBS={micro_batch_size} \
GBS={batch_size} \
CHECKPOINT_RESUME_TEST={checkpoint_resume_test} \
CHECKPOINT_RESUME_TEST={ckpt_resume} \
JOB_NAME={key.split("/")[1]} \
ADDITIONAL_PARAMS={extra_args if extra_args is not None else '""'}
products:
# MCore
- {tp_size: [2], pp_size: [2]}
- {tp_size: [2], pp_size: [2], extra_args: ['"--spec local"'], args_meta: ["local_spec"]}
- {tp_size: [2], pp_size: [2], ckpt_resume: [0, 1]}
- {tp_size: [2], pp_size: [2], ckpt_resume: [0, 1], extra_args: ['"--spec local"'], args_meta: ["local_spec"]}
# Non-MCore
- {use_mcore: [False], tp_size: [2], pp_size: [2], extra_args: ['"--transformer-impl local"']}
- {use_mcore: [False], tp_size: [1], pp_size: [4], vp_size: [2], extra_args: ['"--transformer-impl local"']}
# Checkpoint resume
- {checkpoint_resume_test: [1], scope: [merge-request-resume], steps: [100], use_mcore: [False], tp_size: [1], pp_size: [2], extra_args: ['"--transformer-impl local"']}
- {use_mcore: [False], tp_size: [2], pp_size: [2], ckpt_resume: [0, 1], ckpt_format: [torch], extra_args: ['"--transformer-impl local"']}
- {use_mcore: [False], tp_size: [1], pp_size: [4], vp_size: [2], ckpt_resume: [0, 1], ckpt_format: [torch], extra_args: ['"--transformer-impl local"']}
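To make the template change concrete, here is a rough, hypothetical Python rendering of the updated name field and the new MAX_STEPS expression for one MR-bert product entry. The recipe engine itself is not shown in this diff, so treating the {...} fields as Python expressions over the spec values is an assumption based on the template syntax.

# Rough illustration (assumption, not part of the recipe tooling shown here).
model, variant, scope, platforms, nodes, gpus = "bert", "345m", "merge-request", "dgx_a100", 1, 8
use_mcore, use_te = True, False
tp_size, pp_size, vp_size = 2, 2, None
ckpt_resume, ckpt_format, args_meta = 1, "torch_dist", None

name = (
    f"{model}_{variant}_{scope}_{platforms}_{nodes}N{gpus}G_"
    f"{'mcore_' if use_mcore else ''}{'te_' if use_te else ''}"
    f"tp{tp_size}_pp{pp_size}{'_vp' + str(vp_size) if vp_size else ''}"
    f"{'_resume_' + str(ckpt_format) if ckpt_resume else ''}"
    f"{'_' + args_meta if args_meta else ''}"
)
max_steps = 100 if ckpt_resume else 50
print(name)       # bert_345m_merge-request_dgx_a100_1N8G_mcore_tp2_pp2_resume_torch_dist
print(max_steps)  # 100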
45 changes: 45 additions & 0 deletions tests/functional_tests/jet_recipes/MR-gpt-nemo.yaml
@@ -0,0 +1,45 @@
type: basic
format_version: 1
maintainers: [maanug]
loggers: [stdout]
launchers:
type:slurm:
ntasks_per_node: '{gpus}'
no_container_mount_home: 'true'
spec:
name: "{model}_{variant}_{scope}_{platforms}_{nodes}N{gpus}G_\
mbs{mbs}_gbs{gbs}_ \
{'mcore_' if use_mcore else ''}{'te_' if use_te else ''}\
tp{tp_size}_pp{pp_size}{'_vp'+str(vp_size) if vp_size else ''}\
{'_'+args_meta if args_meta else ''}"
model: gpt3-nemo
variant: 126m
build: mcore-nemo
scope: merge-request
nodes: 1
gpus: 8
platforms: dgx_a100
steps: 50
extra_args: null
args_meta: null
precision: bf16
time_limit: 1200
use_mcore: True
use_te: True
vp_size: null
script: |-
cd /opt/NeMo
/opt/megatron-lm/tests/functional_tests/test_scripts/gpt3/pretrain_gpt3_nemo_test.sh \
TP_SIZE={tp_size} \
PP_SIZE={pp_size} \
NUM_NODES={nodes} \
MAX_STEPS={steps} \
VP_SIZE={vp_size if vp_size is not None else '""'} \
MBS={mbs} \
GBS={gbs} \
JOB_NAME={key.split("/")[1]} \
ADDITIONAL_PARAMS={extra_args if extra_args is not None else '""'}
products:
- {tp_size: [1], pp_size: [1], mbs: [4], gbs: [64], vp_size: [null]}
- {tp_size: [2], pp_size: [4], mbs: [1], gbs: [8], vp_size: [3], extra_args: ['"model.sequence_parallel=True model.overlap_p2p_comm=True model.batch_p2p_comm=False"'], args_meta: ["seq_par_overlap_p2p"]}