Merge branch 'jxin/new_tp_dp_mapping' into 'main'
Support alternative mapping TP->PP->DP

See merge request ADLR/megatron-lm!1183
jaredcasper committed Apr 12, 2024
2 parents b5aba3a + 6513cde commit fbb375d
Showing 7 changed files with 652 additions and 176 deletions.
417 changes: 276 additions & 141 deletions megatron/core/parallel_state.py

Large diffs are not rendered by default.
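
Since the parallel_state.py diff is not rendered above, here is a rough standalone sketch (not the Megatron-Core code) of what the order string controls. It assumes the first name in an order such as 'tp-dp-pp' or 'tp-pp-dp' varies fastest across consecutive global ranks, and shows how that choice changes which ranks share a data-parallel group.

    # Illustrative sketch only (not megatron/core/parallel_state.py): how an
    # order string determines which global ranks land in the same group. The
    # first name in the order string is assumed to vary fastest.
    def rank_groups(sizes, order, dim):
        """Return the rank groups for parallelism dimension `dim`."""
        names = order.split('-')
        strides, stride = {}, 1
        for name in names:                # fastest-varying dimension first
            strides[name] = stride
            stride *= sizes[name]
        world_size = stride

        groups = {}
        for rank in range(world_size):
            # Decompose the global rank into one index per dimension.
            idx = {n: (rank // strides[n]) % sizes[n] for n in names}
            # Ranks that differ only in `dim` belong to the same group.
            key = tuple(idx[n] for n in names if n != dim)
            groups.setdefault(key, []).append(rank)
        return list(groups.values())

    sizes = {'tp': 2, 'pp': 2, 'dp': 2}
    print(rank_groups(sizes, 'tp-dp-pp', 'dp'))  # [[0, 2], [1, 3], [4, 6], [5, 7]]
    print(rank_groups(sizes, 'tp-pp-dp', 'dp'))  # [[0, 4], [1, 5], [2, 6], [3, 7]]

In this toy 2x2x2 case, data-parallel peers sit 2 ranks apart under 'tp-dp-pp' but tp*pp = 4 ranks apart under 'tp-pp-dp'.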

8 changes: 8 additions & 0 deletions megatron/training/arguments.py
@@ -511,6 +511,10 @@ def validate_args(args, defaults={}):
     if args.use_dist_ckpt and not args.use_mcore_models:
         raise RuntimeError('--use-dist-ckpt only support Megatron Core, please add --use-mcore-models.')

+    if args.use_tp_pp_dp_mapping:
+        assert args.context_parallel_size * args.expert_model_parallel_size <= 1, \
+            "context_parallel and expert_model_parallel can't be used with tp-pp-dp mapping."
+
     # Print arguments.
     _print_args("arguments", args)

@@ -1330,6 +1334,10 @@ def _add_distributed_args(parser):
                        'configurations. The number of min/max thread groups and thread '
                        'group cluster size of each communicator can be configured by '
                        'setting `min_ctas`, `max_ctas`, and `cga_cluster_size`.')
+    group.add_argument('--use-tp-pp-dp-mapping', action='store_true', default=False,
+                       help='If set, distributed ranks initialize order is changed '
+                       'from tp-dp-pp to tp-pp-dp. Make sure EP and CP aren\'t used '
+                       'with this option enabled')
     return parser


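As a side note on the check above, the product form of the assertion simply requires both sizes to be 1. A minimal standalone sketch of the rule, with a hypothetical helper name (the MR adds the check inline in validate_args):

    # Hedged sketch; `check_tp_pp_dp_mapping` is a hypothetical name, not part of the MR.
    def check_tp_pp_dp_mapping(use_tp_pp_dp_mapping: bool,
                               context_parallel_size: int,
                               expert_model_parallel_size: int) -> None:
        if use_tp_pp_dp_mapping:
            assert context_parallel_size * expert_model_parallel_size <= 1, \
                "context_parallel and expert_model_parallel can't be used with tp-pp-dp mapping."

    check_tp_pp_dp_mapping(True, 1, 1)    # passes: CP and EP both disabled
    # check_tp_pp_dp_mapping(True, 2, 1)  # would raise AssertionError
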
1 change: 1 addition & 0 deletions megatron/training/initialize.py
@@ -256,6 +256,7 @@ def _initialize_distributed():
                 expert_model_parallel_size=args.expert_model_parallel_size,
                 distributed_timeout_minutes=args.distributed_timeout_minutes,
                 nccl_communicator_config_path=args.nccl_communicator_config_path,
+                order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-pp-dp',
             )
             if args.rank == 0:
                 print(
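For reference, a minimal sketch (assumptions: eight processes launched via torchrun, and initialize_model_parallel keeps the keyword names shown in the hunk above) of selecting the new ordering directly through Megatron Core rather than via the training-script flag:

    # Hedged sketch, not taken from the MR. Assumes WORLD_SIZE == 8
    # (tp=2 x pp=2 x dp=2) and a torchrun-style env:// launch.
    import torch
    from megatron.core import parallel_state

    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=2,
        order='tp-pp-dp',  # the value initialize.py passes when --use-tp-pp-dp-mapping is set
    )
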
7 changes: 4 additions & 3 deletions tests/unit_tests/dist_checkpointing/models/common.py
@@ -29,19 +29,20 @@ def common_test_simple_sharded_state_dict_save_load(initialize_model_fn, tmp_pat


 def common_test_parallel_reconfiguration_e2e(initialize_model_fn, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp,
-                                             src_layer_spec_fn, dst_layer_spec_fn):
+                                             src_layer_spec_fn, dst_layer_spec_fn,
+                                             load_order="tp-dp-pp", store_order="tp-dp-pp"):
     """ Test model saving and loading with different TP/PP """
     with TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_A') as ckpt_dir_A, \
             TempNamedDir(tmp_path_dist_ckpt / 'test_gpt_model_reconfiguration_model_B') as ckpt_dir_B:
         # Save checkpoint A
-        Utils.initialize_model_parallel(*src_tp_pp)
+        Utils.initialize_model_parallel(*src_tp_pp, order=load_order)
         gpt_model_A = initialize_model_fn(1, src_layer_spec_fn)
         save(gpt_model_A.sharded_state_dict(), ckpt_dir_A)
         regular_state_dict_A = gpt_model_A.state_dict()
         Utils.destroy_model_parallel()

         # Load checkpoint A with different TP/PP and save as checkpoint B
-        Utils.initialize_model_parallel(*dest_tp_pp)
+        Utils.initialize_model_parallel(*dest_tp_pp, order=store_order)
         gpt_model_B = initialize_model_fn(2, dst_layer_spec_fn)
         state_dict = load(gpt_model_B.sharded_state_dict(), ckpt_dir_A)
         gpt_model_B.load_state_dict(state_dict)
9 changes: 7 additions & 2 deletions tests/unit_tests/dist_checkpointing/models/test_gpt_model.py
@@ -43,6 +43,11 @@ def test_sharded_state_dict_save_load(self, tmp_path_dist_ckpt,


 class TestGPTModelReconfiguration:
+    @pytest.mark.parametrize("load_order,store_order", [
+        ('tp-dp-pp', 'tp-dp-pp'),
+        ('tp-pp-dp', 'tp-pp-dp'),
+        ('tp-dp-pp', 'tp-pp-dp'),
+    ])
     @pytest.mark.parametrize("src_tp_pp,dest_tp_pp,src_layer_spec_fn,dst_layer_spec_fn", [
         ((2, 4), (4, 2), gpt_te_spec, gpt_te_spec),
         ((1, 8), (8, 1), gpt_te_spec, gpt_te_spec),
@@ -53,10 +58,10 @@ class TestGPTModelReconfiguration:
         ((1, 8), (2, 1), gpt_local_spec, gpt_te_spec),
     ])
     def test_parallel_reconfiguration_e2e(self, tmp_path_dist_ckpt, src_tp_pp, dest_tp_pp,
-                                          src_layer_spec_fn, dst_layer_spec_fn):
+                                          src_layer_spec_fn, dst_layer_spec_fn, load_order, store_order):
         """ Test model saving and loading with different TP/PP """
         common_test_parallel_reconfiguration_e2e(initialize_gpt_model, tmp_path_dist_ckpt, src_tp_pp,
-                                                 dest_tp_pp, src_layer_spec_fn, dst_layer_spec_fn)
+                                                 dest_tp_pp, src_layer_spec_fn, dst_layer_spec_fn, load_order, store_order)


     def test_state_dict_comparison(self, tmp_path_dist_ckpt):
(Diffs for the remaining two changed files are not shown.)