
sync with original repo #2

Merged
merged 10 commits on Apr 29, 2024
3 changes: 3 additions & 0 deletions Dockerfile.ci
@@ -2,3 +2,6 @@ ARG FROM_IMAGE_NAME
FROM ${FROM_IMAGE_NAME}

COPY . megatron-lm

RUN cp -r /workspace/megatron-lm /opt && \
pip install /opt/megatron-lm
2 changes: 1 addition & 1 deletion megatron/core/package_info.py
@@ -2,7 +2,7 @@


MAJOR = 0
MINOR = 6
MINOR = 7
PATCH = 0
PRE_RELEASE = 'rc0'

82 changes: 69 additions & 13 deletions megatron/core/utils.py
@@ -1,6 +1,8 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Utility functions used throughout Megatron core"""
import array
import hashlib
import logging
import math
import operator
@@ -21,6 +23,8 @@
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedTensor

logger = logging.getLogger(__name__)


def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
@@ -194,6 +198,60 @@ def init_(tensor):
return init_


def check_param_hashes_across_dp_replicas(model: List[torch.nn.Module]) -> bool:
"""Computes hashes of all parameters in model, all-gathers hashes across DP replicas,
and then checks for equality between the locally-computed hashes and the hashes
from DP replica 0.

NOTE: This function computes SHA-1 hashes on the CPU and thus needs to move all param
tensors from GPU to CPU first; as a result, this function is not intended to be called
very frequently in the main training loop.

Args:
model (List[torch.nn.Module]): List of model chunks whose parameter hashes need to
be checked.

Returns:
True if all param hashes match with corresponding hash on DP replica 0, False
otherwise.
"""

# Compute per-parameter hashes on this rank.
params = []
local_param_hashes = []
for model_chunk_id, model_chunk in enumerate(model):
for (param_name, param) in model_chunk.named_parameters():
param_hash = torch.frombuffer(
array.array(
'B', hashlib.sha1(param.data.to("cpu").float().numpy(force=True)).digest()
),
dtype=torch.uint8,
)
params.append((model_chunk_id, param_name, param))
local_param_hashes.append(param_hash)
local_param_hashes = torch.stack(local_param_hashes)

# Collect per-parameter hashes across all ranks in DP group.
all_param_hashes = [
torch.zeros_like(local_param_hashes)
for _ in range(parallel_state.get_data_parallel_world_size())
]
torch.distributed.all_gather(
all_param_hashes, local_param_hashes, group=parallel_state.get_data_parallel_group_gloo()
)

# Make sure local per-parameter hash matches DP rank 0.
param_hashes_match = torch.equal(local_param_hashes, all_param_hashes[0])
if not param_hashes_match:
for i, (model_chunk_id, param_name, param) in enumerate(params):
if not torch.equal(local_param_hashes[i], all_param_hashes[0][i]):
rank = torch.distributed.get_rank()
logger.info(
f"[Rank {rank}] Hash not matching for {param_name} in model chunk {model_chunk_id}"
)
return param_hashes_match


def make_tp_sharded_tensor_for_checkpoint(
tensor, key, tp_axis=0, replica_id=None, prepend_offsets=(), **kwargs
):
@@ -490,7 +548,6 @@ class StragglerDetector:
stop_batch (list[int]): stop time for get_batch
sock (socket): the controller socket
ctrlr (Thread): the controller thread
logger (Logger): the logger instance for this instance
"""

_configured = False
@@ -541,7 +598,6 @@ def __init__(self) -> None:
self.stop_batch = None
self.sock = None
self.ctrlr = None
self.logger = logging.getLogger(__name__)

def configure(
self,
@@ -714,9 +770,9 @@ def elapsed(self) -> Tuple[float, float, int, int, int, int]:
power = 0
clock = 0
if ls_ev != le_ev:
self.logger.warning(f"Event Start/Stop out of sync {ls_ev}/{le_ev}")
logger.warning(f"Event Start/Stop out of sync {ls_ev}/{le_ev}")
elif ls_bs != ls_be:
self.logger.warning(f"get_batch Start/Stop out of sync {ls_bs}/{ls_be}")
logger.warning(f"get_batch Start/Stop out of sync {ls_bs}/{ls_be}")
else:
temp = torch.cuda.temperature()
power = torch.cuda.power_draw()
@@ -770,7 +826,7 @@ def report(self, total_flops: float = 0.0, log_interval: int = 0) -> bool:
now = f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
min_flops, min_frank, _ = o_dt.aflops[0]()
max_flops, max_frank, _ = o_dt.aflops[-1]()
self.logger.info(
logger.info(
f"{now} | "
f"MnRtt/Rnk: {o_dt.min_elapsed} | "
f"MxRtt/Rnk: {o_dt.max_elapsed} | "
@@ -791,12 +847,12 @@
line = f"^^^^ Bottom {self.mmcnt} Ranks with lowest Etpt(TF):"
for i in range(self.mmcnt):
line += f" {o_dt.aflops[i]},"
self.logger.info(line)
logger.info(line)
line = f"^^^^ Top {self.mmcnt} Ranks with highest Etpt(TF):"
shift = self.world - self.mmcnt
for i in range(self.mmcnt):
line += f" {o_dt.aflops[i+shift]},"
self.logger.info(line)
logger.info(line)
ret = True

# Check/Communicate if tracking is turned off or on
@@ -828,7 +884,7 @@ def _check_toggle(self) -> None:
self.stop = self.null_method
state = "OFF"
if self.rank == 0 and off is not self._off:
self.logger.info(f"Toggling StragglerDetector State {state}")
logger.info(f"Toggling StragglerDetector State {state}")

def _handler(self) -> None:
"""Thread function for the controller.
@@ -842,7 +898,7 @@ def _handler(self) -> None:

if self.rank == 0:
state = "OFF" if self._off else "ON"
self.logger.info(
logger.info(
f"Controller ready to recv " f"commands on port {self.port}. Current state {state}"
)
while True:
@@ -856,9 +912,9 @@
final_resp = f"{resp}{msg_len}\r\n\r\n{msg}"
conn.send(final_resp.encode())
conn.close()
self.logger.info(msg)
logger.info(msg)
except Exception as err:
self.logger.error(f"Error in straggler handler.. {str(err)}")
logger.error(f"Error in straggler handler.. {str(err)}")
return

def _controller(self):
@@ -879,7 +935,7 @@ def _controller(self):
)
self.ctrlr.start()
except Exception as err:
self.logger.warning(f"StragglerDetector cannot be controlled.. {str(err)}")
logger.warning(f"StragglerDetector cannot be controlled.. {str(err)}")

def _min_max(
self,
@@ -1086,7 +1142,7 @@ def __exit__(
ret = False
if ex_type is not None:
err = traceback.format_exception(ex_tb)
self.logger.warning(f"{str(ex_val)}\n{err}")
logger.warning(f"{str(ex_val)}\n{err}")
ret = True
self.stop()
return ret
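
As an aside on the new check_param_hashes_across_dp_replicas helper: the stand-alone sketch below shows, for a single local tensor and without any distributed setup, how each parameter's SHA-1 digest becomes a uint8 tensor via the torch.frombuffer pattern added above. The tensor contents are illustrative only.

import array
import hashlib

import torch

# Minimal single-process sketch of the per-parameter hashing used by
# check_param_hashes_across_dp_replicas (no DP group or all_gather here).
param = torch.randn(4, 4)

# SHA-1 over the CPU float32 bytes of the parameter; digest() is 20 bytes.
digest = hashlib.sha1(param.data.to("cpu").float().numpy(force=True)).digest()

# Wrap the digest in a uint8 tensor so per-parameter hashes can be stacked
# and all-gathered across the data-parallel group.
param_hash = torch.frombuffer(array.array('B', digest), dtype=torch.uint8)
print(param_hash.shape)  # torch.Size([20])
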
2 changes: 2 additions & 0 deletions megatron/training/arguments.py
@@ -1007,6 +1007,8 @@ def _add_training_args(parser):
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
group.add_argument('--check-weight-hash-across-dp-replicas-interval', type=int, default=None,
help='Interval to check weight hashes are the same across DP replicas. If not specified, weight hashes are not checked.')

# deprecated
group.add_argument('--checkpoint-activations', action='store_true',
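
To illustrate the semantics of the new flag, here is a hypothetical stand-alone parser that mirrors only this argument (Megatron's real _add_training_args registers many more options); note that argparse exposes the dashed flag as an underscored attribute on the parsed namespace.

import argparse

# Hypothetical minimal parser mirroring only the new flag.
parser = argparse.ArgumentParser()
parser.add_argument('--check-weight-hash-across-dp-replicas-interval', type=int, default=None,
                    help='Interval to check weight hashes are the same across DP replicas. '
                         'If not specified, weight hashes are not checked.')

args = parser.parse_args(['--check-weight-hash-across-dp-replicas-interval', '200'])
assert args.check_weight_hash_across_dp_replicas_interval == 200

# The default is None, in which case the periodic check is skipped entirely.
assert parser.parse_args([]).check_weight_hash_across_dp_replicas_interval is None
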
17 changes: 14 additions & 3 deletions megatron/training/training.py
@@ -2,11 +2,11 @@

"""Pretrain utilities."""

import gc
import dataclasses
from datetime import datetime
import math
import gc
import logging
import math
import os
import sys
from .log_handler import CustomHandler
@@ -19,7 +19,7 @@
import torch

from megatron.core import mpu, tensor_parallel
from megatron.core.utils import get_model_config, StragglerDetector
from megatron.core.utils import check_param_hashes_across_dp_replicas, get_model_config, StragglerDetector
from megatron.training.checkpointing import load_checkpoint
from megatron.training.checkpointing import save_checkpoint
from megatron.legacy.model import Float16Module
@@ -1057,6 +1057,17 @@ def track_e2e_metrics():
stimer.report(total_flops, args.log_interval)
total_flops = 0.0

if args.check_weight_hash_across_dp_replicas_interval is not None and \
iteration % args.check_weight_hash_across_dp_replicas_interval == 0:
if args.use_distributed_optimizer and args.overlap_param_gather:
optimizer.disable_pre_hook()
assert check_param_hashes_across_dp_replicas(model), \
"Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
if args.use_distributed_optimizer and args.overlap_param_gather:
optimizer.enable_pre_hook()

# Autoresume
if args.adlr_autoresume and \
(iteration % args.adlr_autoresume_interval == 0):
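
The same helper can also be used for a one-off sanity check outside the training loop, for example right after resuming from a checkpoint. The sketch below is illustrative: verify_replicas_in_sync is a hypothetical wrapper, not part of this diff, and it assumes torch.distributed and Megatron's parallel state are already initialized and that model is the usual list of model chunks.

import torch

from megatron.core.utils import check_param_hashes_across_dp_replicas


def verify_replicas_in_sync(model, iteration):
    # Hypothetical one-off consistency check; the helper hashes every parameter
    # on the CPU, so keep it out of the hot path of the training loop.
    ok = check_param_hashes_across_dp_replicas(model)
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        status = "match" if ok else "DO NOT match"
        print(f">>> Weight hashes {status} across DP replicas at iteration {iteration}")
    return ok
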
19 changes: 9 additions & 10 deletions tests/functional_tests/jet_recipes/MR-bert.yaml
@@ -6,6 +6,7 @@ spec:
name: "{model}_{variant}_{scope}_{platforms}_{nodes}N{gpus}G_\
{'mcore_' if use_mcore else ''}{'te_' if use_te else ''}\
tp{tp_size}_pp{pp_size}{'_vp'+str(vp_size) if vp_size else ''}\
{'_resume_'+str(ckpt_format) if ckpt_resume else ''}\
{'_'+args_meta if args_meta else ''}"
model: bert
variant: 345m
@@ -14,7 +15,6 @@
nodes: 1
gpus: 8
platforms: dgx_a100
steps: 50
use_te: False
use_mcore: True
vp_size: null
@@ -25,7 +25,8 @@
precision: bf16
time_limit: 1200
artifacts: {/workspace/data/bert_data: text/the_pile/bert_shard00}
checkpoint_resume_test: 0
ckpt_format: torch_dist
ckpt_resume: 0
script: |-
ls
cd /workspace/megatron-lm
@@ -39,20 +40,18 @@
TP_SIZE={tp_size} \
PP_SIZE={pp_size} \
NUM_NODES={nodes} \
MAX_STEPS={steps} \
MAX_STEPS={100 if ckpt_resume else 50} \
USE_CORE={"1" if use_mcore else "0"} \
VP_SIZE={vp_size if vp_size is not None else '""'} \
MBS={micro_batch_size} \
GBS={batch_size} \
CHECKPOINT_RESUME_TEST={checkpoint_resume_test} \
CHECKPOINT_RESUME_TEST={ckpt_resume} \
JOB_NAME={key.split("/")[1]} \
ADDITIONAL_PARAMS={extra_args if extra_args is not None else '""'}
products:
# MCore
- {tp_size: [2], pp_size: [2]}
- {tp_size: [2], pp_size: [2], extra_args: ['"--spec local"'], args_meta: ["local_spec"]}
- {tp_size: [2], pp_size: [2], ckpt_resume: [0, 1]}
- {tp_size: [2], pp_size: [2], ckpt_resume: [0, 1], extra_args: ['"--spec local"'], args_meta: ["local_spec"]}
# Non-MCore
- {use_mcore: [False], tp_size: [2], pp_size: [2], extra_args: ['"--transformer-impl local"']}
- {use_mcore: [False], tp_size: [1], pp_size: [4], vp_size: [2], extra_args: ['"--transformer-impl local"']}
# Checkpoint resume
- {checkpoint_resume_test: [1], scope: [merge-request-resume], steps: [100], use_mcore: [False], tp_size: [1], pp_size: [2], extra_args: ['"--transformer-impl local"']}
- {use_mcore: [False], tp_size: [2], pp_size: [2], ckpt_resume: [0, 1], ckpt_format: [torch], extra_args: ['"--transformer-impl local"']}
- {use_mcore: [False], tp_size: [1], pp_size: [4], vp_size: [2], ckpt_resume: [0, 1], ckpt_format: [torch], extra_args: ['"--transformer-impl local"']}
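
The name field in these recipes is a Python-style format expression. The hypothetical snippet below renders it for one of the new ckpt_resume products, purely to show how the resume suffix composes with the rest of the name; the values come from the spec defaults and the product entry {tp_size: [2], pp_size: [2], ckpt_resume: [0, 1]}.

# Hypothetical rendering of the recipe's name template for ckpt_resume=1.
model, variant, scope, platforms = "bert", "345m", "merge-request", "dgx_a100"
nodes, gpus, use_mcore, use_te = 1, 8, True, False
tp_size, pp_size, vp_size = 2, 2, None
ckpt_resume, ckpt_format, args_meta = 1, "torch_dist", None

name = (
    f"{model}_{variant}_{scope}_{platforms}_{nodes}N{gpus}G_"
    f"{'mcore_' if use_mcore else ''}{'te_' if use_te else ''}"
    f"tp{tp_size}_pp{pp_size}{'_vp' + str(vp_size) if vp_size else ''}"
    f"{'_resume_' + str(ckpt_format) if ckpt_resume else ''}"
    f"{'_' + args_meta if args_meta else ''}"
)
print(name)  # bert_345m_merge-request_dgx_a100_1N8G_mcore_tp2_pp2_resume_torch_dist
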
45 changes: 45 additions & 0 deletions tests/functional_tests/jet_recipes/MR-gpt-nemo.yaml
@@ -0,0 +1,45 @@
type: basic
format_version: 1
maintainers: [maanug]
loggers: [stdout]
launchers:
type:slurm:
ntasks_per_node: '{gpus}'
no_container_mount_home: 'true'
spec:
name: "{model}_{variant}_{scope}_{platforms}_{nodes}N{gpus}G_\
mbs{mbs}_gbs{gbs}_ \
{'mcore_' if use_mcore else ''}{'te_' if use_te else ''}\
tp{tp_size}_pp{pp_size}{'_vp'+str(vp_size) if vp_size else ''}\
{'_'+args_meta if args_meta else ''}"
model: gpt3-nemo
variant: 126m
build: mcore-nemo
scope: merge-request
nodes: 1
gpus: 8
platforms: dgx_a100
steps: 50
extra_args: null
args_meta: null
precision: bf16
time_limit: 1200
use_mcore: True
use_te: True
vp_size: null
script: |-
cd /opt/NeMo

/opt/megatron-lm/tests/functional_tests/test_scripts/gpt3/pretrain_gpt3_nemo_test.sh \
TP_SIZE={tp_size} \
PP_SIZE={pp_size} \
NUM_NODES={nodes} \
MAX_STEPS={steps} \
VP_SIZE={vp_size if vp_size is not None else '""'} \
MBS={mbs} \
GBS={gbs} \
JOB_NAME={key.split("/")[1]} \
ADDITIONAL_PARAMS={extra_args if extra_args is not None else '""'}
products:
- {tp_size: [1], pp_size: [1], mbs: [4], gbs: [64], vp_size: [null]}
- {tp_size: [2], pp_size: [4], mbs: [1], gbs: [8], vp_size: [3], extra_args: ['"model.sequence_parallel=True model.overlap_p2p_comm=True model.batch_p2p_comm=False"'], args_meta: ["seq_par_overlap_p2p"]}