diff --git a/examples/multimodal/layer_specs.py b/examples/multimodal/layer_specs.py index ad24850b631..621ab01d297 100644 --- a/examples/multimodal/layer_specs.py +++ b/examples/multimodal/layer_specs.py @@ -1,4 +1,6 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from functools import partial + import torch from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add @@ -170,8 +172,8 @@ def get_mamba_layer_spec_te(padding=False) -> ModuleSpec: mlp_layer=ModuleSpec( module=MLPLayer, submodules=TransformerLayerSubmodules( - mlp=ModuleSpec( - module=MLP, + mlp=partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=not_none(TELayerNormColumnParallelLinear), linear_fc2=not_none(TERowParallelLinear), @@ -184,10 +186,10 @@ def get_mamba_layer_spec_te(padding=False) -> ModuleSpec: ) -def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: +def get_mlp_module_spec(use_te: bool = True): # Dense MLP w/ or w/o TE modules. - return ModuleSpec( - module=MLP, + return partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=not_none(TEColumnParallelLinear) if use_te else ColumnParallelLinear, linear_fc2=not_none(TERowParallelLinear) if use_te else RowParallelLinear, @@ -195,9 +197,9 @@ def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: ) -def get_norm_mlp_module_spec_te() -> ModuleSpec: - return ModuleSpec( - module=MLP, +def get_norm_mlp_module_spec_te(): + return partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=not_none(TELayerNormColumnParallelLinear), linear_fc2=not_none(TERowParallelLinear), diff --git a/examples/multimodal/nvlm/internvit.py b/examples/multimodal/nvlm/internvit.py index 0018bb5ccb9..d38ac64c16b 100644 --- a/examples/multimodal/nvlm/internvit.py +++ b/examples/multimodal/nvlm/internvit.py @@ -160,10 +160,10 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata={}): return super().sharded_state_dict(prefix, sharded_offsets, metadata) -def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: +def get_mlp_module_spec(use_te: bool = True): # Dense MLP w/ or w/o TE modules. - return ModuleSpec( - module=MLP, + return partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, diff --git a/examples/multimodal/radio/radio_g.py b/examples/multimodal/radio/radio_g.py index 9883d58db61..30d2be7df2e 100644 --- a/examples/multimodal/radio/radio_g.py +++ b/examples/multimodal/radio/radio_g.py @@ -1,8 +1,6 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. from functools import partial -import torch - from examples.multimodal.layer_scaling import ( LayerScalingTransformerLayer, get_bias_dropout_add_layer_scaling, @@ -14,7 +12,7 @@ from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.transformer.transformer_layer import TransformerLayerSubmodules from megatron.core.typed_torch import not_none from megatron.core.extensions.transformer_engine import HAVE_TE @@ -51,10 +49,10 @@ LNImpl = WrappedTorchNorm -def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: +def get_mlp_module_spec(use_te: bool = True): # Dense MLP w/ or w/o TE modules. 
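# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: the migration applied throughout
# the layer-spec files above. The `mlp` entry of a layer spec changes from a
# ModuleSpec wrapping the MLP class to a functools.partial of the new
# MLP.as_mlp_submodule classmethod, carrying the same MLPSubmodules payload.
# Assumes Transformer Engine is installed (hence no not_none() guards here).
from functools import partial

from megatron.core.extensions.transformer_engine import (
    TELayerNormColumnParallelLinear,
    TERowParallelLinear,
)
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec

# Before: a ModuleSpec that TransformerLayer had to special-case in order to
# inject the tensor-parallel process group.
old_mlp_spec = ModuleSpec(
    module=MLP,
    submodules=MLPSubmodules(
        linear_fc1=TELayerNormColumnParallelLinear,
        linear_fc2=TERowParallelLinear,
    ),
)

# After: a plain callable (an "MlpBuilder") that TransformerLayer can invoke
# directly with config / pg_collection / is_mtp_layer keyword arguments.
new_mlp_builder = partial(
    MLP.as_mlp_submodule,
    submodules=MLPSubmodules(
        linear_fc1=TELayerNormColumnParallelLinear,
        linear_fc2=TERowParallelLinear,
    ),
)
# ---------------------------------------------------------------------------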
- return ModuleSpec( - module=MLP, + return partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=not_none(TEColumnParallelLinear) if use_te else ColumnParallelLinear, linear_fc2=not_none(TERowParallelLinear) if use_te else RowParallelLinear, @@ -62,9 +60,9 @@ def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: ) -def get_norm_mlp_module_spec_te() -> ModuleSpec: - return ModuleSpec( - module=MLP, +def get_norm_mlp_module_spec_te(): + return partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=not_none(TELayerNormColumnParallelLinear), linear_fc2=not_none(TERowParallelLinear), diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 17358f8a921..b9dd9a323a3 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -43,7 +43,7 @@ ) from megatron.core.tensor_parallel.utils import divide from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.torch_norm import LayerNormInterface from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import ( @@ -2428,6 +2428,31 @@ def forward(self, hidden_states: torch.Tensor, **kwargs) -> Tuple[Tensor, Option return out, bias + @classmethod + def as_mlp_submodule( + cls, + submodules: MLPSubmodules, + config: TransformerConfig, + pg_collection: ProcessGroupCollection, + is_mtp_layer: bool, + is_expert: bool = False, + input_size: int | None = None, + ffn_hidden_size: int | None = None, + ) -> MLP: + """Helper function to build an MLP as a TransformerLayer's mlp submodule.""" + del is_mtp_layer + assert hasattr( + pg_collection, 'tp' + ), 'TP process group is required for TEFusedMLP in TransformerLayer' + return cls( + config=config, + submodules=submodules, + tp_group=pg_collection.tp, + is_expert=is_expert, + input_size=input_size, + ffn_hidden_size=ffn_hidden_size, + ) + else: TEFusedMLP = None # type: ignore[assignment, misc] diff --git a/megatron/core/models/T5/t5_spec.py b/megatron/core/models/T5/t5_spec.py index 9f465df5c21..5bbe49c337b 100644 --- a/megatron/core/models/T5/t5_spec.py +++ b/megatron/core/models/T5/t5_spec.py @@ -1,4 +1,6 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+from functools import partial + from megatron.core.extensions.transformer_engine import HAVE_TE from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear @@ -69,8 +71,8 @@ def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec: ), ), self_attn_bda=get_bias_dropout_add, - mlp=ModuleSpec( - module=MLP, + mlp=partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=not_none(TELayerNormColumnParallelLinear), linear_fc2=not_none(TERowParallelLinear), @@ -111,8 +113,8 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec: ), ), cross_attn_bda=get_bias_dropout_add, - mlp=ModuleSpec( - module=MLP, + mlp=partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=not_none(TELayerNormColumnParallelLinear), linear_fc2=not_none(TERowParallelLinear), @@ -143,8 +145,8 @@ def encoder_model_with_local_spec() -> ModuleSpec: ), self_attn_bda=get_bias_dropout_add, pre_mlp_layernorm=LNImpl, - mlp=ModuleSpec( - module=MLP, + mlp=partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear ), @@ -190,8 +192,8 @@ def decoder_model_with_local_spec() -> ModuleSpec: ), cross_attn_bda=get_bias_dropout_add, pre_mlp_layernorm=LNImpl, - mlp=ModuleSpec( - module=MLP, + mlp=partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear ), diff --git a/megatron/core/models/bert/bert_layer_specs.py b/megatron/core/models/bert/bert_layer_specs.py index 53cc0f4280d..dc0099fa66e 100644 --- a/megatron/core/models/bert/bert_layer_specs.py +++ b/megatron/core/models/bert/bert_layer_specs.py @@ -1,5 +1,6 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
import warnings +from functools import partial from megatron.core.extensions.transformer_engine import HAVE_TE from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add @@ -66,8 +67,8 @@ def get_bert_layer_with_transformer_engine_submodules() -> TransformerLayerSubmo ), ), self_attn_bda=get_bias_dropout_add, - mlp=ModuleSpec( - module=MLP, + mlp=partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=not_none(TELayerNormColumnParallelLinear), linear_fc2=not_none(TERowParallelLinear), @@ -117,8 +118,8 @@ def __getattr__(name): ), self_attn_bda=get_bias_dropout_add, pre_mlp_layernorm=LNImpl, - mlp=ModuleSpec( - module=MLP, + mlp=partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear), ), mlp_bda=get_bias_dropout_add, diff --git a/megatron/core/models/gpt/experimental_attention_variant_module_specs.py b/megatron/core/models/gpt/experimental_attention_variant_module_specs.py index 6608073136c..c94ad022303 100644 --- a/megatron/core/models/gpt/experimental_attention_variant_module_specs.py +++ b/megatron/core/models/gpt/experimental_attention_variant_module_specs.py @@ -24,10 +24,12 @@ ) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import ( + MlpBuilder, TransformerLayer, TransformerLayerSubmodules, get_transformer_layer_offset, ) +from megatron.core.typed_torch import not_none try: import transformer_engine as te # type: ignore[import-untyped] # pylint: disable=unused-import @@ -213,14 +215,18 @@ def get_transformer_block_with_experimental_attention_variant_spec( moe_layer_pattern = [0] * config.num_layers if 1 in moe_layer_pattern: - moe_layer_spec = _get_moe_module_spec(config=config, backend=backend) + moe_layer_spec, fuse_layernorm_pre_moe = _get_moe_module_spec( + config=config, backend=backend + ) else: - moe_layer_spec = None + moe_layer_spec, fuse_layernorm_pre_moe = None, False if 0 in moe_layer_pattern: - dense_mlp_layer_spec = _get_dense_mlp_module_spec(config=config, backend=backend) + dense_mlp_layer_spec, fuse_layernorm_pre_dense = _get_dense_mlp_module_spec( + config=config, backend=backend + ) else: - dense_mlp_layer_spec = None + dense_mlp_layer_spec, fuse_layernorm_pre_dense = None, False # Get GPT decoder block layer specs rms_norm = config.normalization == "RMSNorm" @@ -232,6 +238,11 @@ def get_transformer_block_with_experimental_attention_variant_spec( else standard_attention_spec ) mlp = moe_layer_spec if moe_layer_pattern[layer_number] == 1 else dense_mlp_layer_spec + fuse_pre_mlp_layernorm = ( + fuse_layernorm_pre_moe + if moe_layer_pattern[layer_number] == 1 + else fuse_layernorm_pre_dense + ) input_layernorm = ( IdentityOp if attention.metainfo["fuse_input_layernorm"] @@ -239,7 +250,7 @@ def get_transformer_block_with_experimental_attention_variant_spec( ) pre_mlp_layernorm = ( IdentityOp - if mlp.metainfo["fuse_pre_mlp_layernorm"] + if fuse_pre_mlp_layernorm else backend.layer_norm(rms_norm=rms_norm, for_qk=False) ) @@ -251,7 +262,7 @@ def get_transformer_block_with_experimental_attention_variant_spec( self_attention=attention, self_attn_bda=get_bias_dropout_add, pre_mlp_layernorm=pre_mlp_layernorm, - mlp=mlp, + mlp=not_none(mlp), mlp_bda=get_bias_dropout_add, ), ) @@ -410,41 +421,50 @@ def _get_self_attention_module_spec( def _get_dense_mlp_module_spec( config: TransformerConfig, backend: BackendSpecProvider = None -) -> ModuleSpec: +) -> tuple[MlpBuilder, bool]: """Get dense MLP module spec. 
For hybrid models that mix dense MLP and experimental attention architectures. - Warning: This function may be deprecated in the future.""" + Warning: This function may be deprecated in the future. + + Returns: + A tuple of (MLP module spec, whether to fuse pre-MLP layernorm) + """ if backend is None: backend = _get_backend_spec_provider(config=config) from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec_for_backend - mlp_spec = get_mlp_module_spec_for_backend(backend=backend, num_experts=None) - mlp_spec.metainfo["fuse_pre_mlp_layernorm"] = backend.fuse_layernorm_and_linear() - - return mlp_spec + return ( + get_mlp_module_spec_for_backend(backend=backend, num_experts=None), + backend.fuse_layernorm_and_linear(), + ) def _get_moe_module_spec( config: TransformerConfig, backend: BackendSpecProvider = None -) -> ModuleSpec: +) -> tuple[MlpBuilder, bool]: """Get MoE module spec. For hybrid models that mix MoE and experimental attention architectures. - Warning: This function may be deprecated in the future.""" + Warning: This function may be deprecated in the future. + + Returns: + A tuple of (MoE module spec, whether to fuse pre-MoE layernorm) + """ if backend is None: backend = _get_backend_spec_provider(config=config) from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec_for_backend - moe_spec = get_moe_module_spec_for_backend( - backend=backend, - num_experts=config.num_moe_experts, - moe_grouped_gemm=config.moe_grouped_gemm, - use_te_activation_func=config.use_te_activation_func, + return ( + get_moe_module_spec_for_backend( + backend=backend, + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + use_te_activation_func=config.use_te_activation_func, + ), + False, ) - moe_spec.metainfo["fuse_pre_mlp_layernorm"] = False - return moe_spec diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 5e90f0b36be..3fba1687aac 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -1,5 +1,6 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import warnings +from functools import partial from typing import Optional, Union from megatron.core.extensions.transformer_engine import HAVE_TE @@ -34,11 +35,12 @@ ) from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import ( + MlpBuilder, TransformerLayer, TransformerLayerSubmodules, get_transformer_layer_offset, ) -from megatron.core.typed_torch import copy_signature +from megatron.core.typed_torch import copy_signature, not_none from megatron.core.utils import is_te_min_version if HAVE_TE: @@ -485,7 +487,7 @@ def get_mlp_module_spec( moe_grouped_gemm: Optional[bool] = False, fp8: Optional[str] = None, # pylint: disable=unused-argument use_te_op_fuser: Optional[bool] = False, -) -> ModuleSpec: +) -> MlpBuilder: """Helper function to get module spec for MLP/MoE""" if fp8 is not None: warnings.warn( @@ -516,7 +518,7 @@ def get_mlp_module_spec_for_backend( moe_grouped_gemm: Optional[bool] = False, use_te_op_fuser: Optional[bool] = False, use_te_activation_func: bool = False, -) -> ModuleSpec: +) -> MlpBuilder: """Helper function to get module spec for MLP/MoE""" linear_fc2 = backend.row_parallel_linear() @@ -524,14 +526,14 @@ def get_mlp_module_spec_for_backend( if num_experts is None: # Dense MLP w/ or w/o TE modules. 
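# ---------------------------------------------------------------------------
# Illustrative note, not part of the patch: because get_mlp_module_spec() and
# get_mlp_module_spec_for_backend() now return a functools.partial (an
# MlpBuilder) instead of a ModuleSpec, callers can no longer read
# `spec.submodules`. The updated tests later in this diff use either the
# partial's keyword dict or the get_submodules() helper from spec_utils.
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec

mlp_builder = get_mlp_module_spec(use_te=True)  # dense MLP builder (num_experts=None)
mlp_submodules = mlp_builder.keywords["submodules"]  # pattern used in test_multimodal_projector.py
# ---------------------------------------------------------------------------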
- module = TEFusedMLP if use_te_op_fuser else MLP + module = not_none(TEFusedMLP).as_mlp_submodule if use_te_op_fuser else MLP.as_mlp_submodule if backend.fuse_layernorm_and_linear(): linear_fc1 = backend.column_parallel_layer_norm_linear() assert linear_fc1 is not None else: linear_fc1 = backend.column_parallel_linear() - return ModuleSpec( - module=module, + return partial( + module, submodules=MLPSubmodules( linear_fc1=linear_fc1, linear_fc2=linear_fc2, activation_func=activation_func ), diff --git a/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py b/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py index f4385429422..a7ecd9e9658 100644 --- a/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py +++ b/megatron/core/models/gpt/heterogeneous/heterogeneous_layer_specs.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. import warnings +from functools import partial from typing import Optional from megatron.core.extensions.transformer_engine import HAVE_TE @@ -128,17 +129,19 @@ def _get_heterogenous_attention_spec( def _get_heterogenous_mlp_spec(mlp_config: MLPConfig, use_te: bool): if mlp_config.no_op: - mlp = ModuleSpec(module=IdentityOp) + return IdentityOp elif mlp_config.replace_with_linear: - mlp = ModuleSpec( - module=( - TELayerNormColumnParallelLinearGathered if use_te else ColumnParallelLinearGathered + return partial( + ( + not_none(TELayerNormColumnParallelLinearGathered) + if use_te + else ColumnParallelLinearGathered ), - params={"tp_comm_buffer_name": "linear_mlp"}, + tp_comm_buffer_name="linear_mlp", ) else: - mlp = ModuleSpec( - module=MLP, + return partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=( not_none(TELayerNormColumnParallelLinear) if use_te else ColumnParallelLinear @@ -146,7 +149,6 @@ def _get_heterogenous_mlp_spec(mlp_config: MLPConfig, use_te: bool): linear_fc2=not_none(TERowParallelLinear) if use_te else RowParallelLinear, ), ) - return mlp def _get_sharded_state_dict_keys_map(block_config: TransformerBlockConfig, use_te: bool): diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py index 53bca85f502..59dc0a59471 100755 --- a/megatron/core/models/gpt/moe_module_specs.py +++ b/megatron/core/models/gpt/moe_module_specs.py @@ -1,5 +1,4 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - from functools import partial from typing import Optional @@ -13,14 +12,14 @@ from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules from megatron.core.transformer.moe.router import InferenceTopKRouter from megatron.core.transformer.moe.shared_experts import SharedExpertMLP -from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import MlpBuilder def get_moe_module_spec( use_te: Optional[bool] = True, num_experts: Optional[int] = None, moe_grouped_gemm: Optional[bool] = False, -) -> ModuleSpec: +) -> MlpBuilder: """Helper function to get module spec for MoE. Called by mamba_layer_specs.py for standard (non-inference) MoE specs. 
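# ---------------------------------------------------------------------------
# Illustrative note, not part of the patch: as the following hunks show, the
# MoE spec helpers likewise stop returning a ModuleSpec and instead return
# partial(MoELayer, submodules=MoESubmodules(...)), so an MoE layer is built
# through the same MlpBuilder call as a dense MLP. TransformerLayer (see the
# transformer_layer.py hunks further down) invokes it roughly as
#
#     mlp_builder = get_moe_module_spec(use_te=True, num_experts=8, moe_grouped_gemm=True)
#     moe_layer = mlp_builder(config=config, pg_collection=pg_collection, is_mtp_layer=False)
#
# (num_experts=8 is an arbitrary example value). The "fuse_pre_mlp_layernorm"
# flag, previously stashed in ModuleSpec.metainfo, is now returned as the
# second element of the tuples produced by _get_moe_module_spec /
# _get_dense_mlp_module_spec in the experimental-attention specs above.
# ---------------------------------------------------------------------------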
@@ -46,7 +45,7 @@ def get_moe_module_spec_for_backend( num_experts: Optional[int] = None, moe_grouped_gemm: Optional[bool] = False, use_te_activation_func: bool = False, -) -> ModuleSpec: +) -> MlpBuilder: """Helper function to get module spec for MoE""" assert num_experts is not None @@ -63,15 +62,12 @@ def get_moe_module_spec_for_backend( shared_experts = partial(SharedExpertMLP, submodules=mlp) # MoE module spec - moe_module_spec = ModuleSpec( - module=MoELayer, - submodules=MoESubmodules(experts=experts, shared_experts=shared_experts), - metainfo={"fuse_pre_mlp_layernorm": False}, + return partial( + MoELayer, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts) ) - return moe_module_spec -def get_inference_optimized_moe_spec() -> ModuleSpec: +def get_inference_optimized_moe_spec() -> MlpBuilder: """MoE module spec for inference-optimized transformer impl. Uses InferenceSpecProvider to select inference-optimized modules: @@ -93,10 +89,9 @@ def get_inference_optimized_moe_spec() -> ModuleSpec: ), ) - return ModuleSpec( - module=MoELayer, + return partial( + MoELayer, submodules=MoESubmodules( router=InferenceTopKRouter, experts=experts, shared_experts=shared_experts ), - metainfo={"fuse_pre_mlp_layernorm": False}, ) diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index d2a85d004ef..cb6a762a37f 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -1,4 +1,5 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +from functools import partial from megatron.core.extensions.transformer_engine import ( TEColumnParallelLinear, @@ -122,8 +123,8 @@ mlp_layer=ModuleSpec( module=MLPLayer, submodules=TransformerLayerSubmodules( - mlp=ModuleSpec( - module=MLP, + mlp=partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear ), @@ -182,8 +183,8 @@ mlp_layer=ModuleSpec( module=MLPLayer, submodules=TransformerLayerSubmodules( - mlp=ModuleSpec( - module=MLP, + mlp=partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=InferenceLayerNormColumnParallelLinear, linear_fc2=InferenceRowParallelLinear, diff --git a/megatron/core/models/vision/vit_layer_specs.py b/megatron/core/models/vision/vit_layer_specs.py index 51074b3836e..71c78ad3df4 100644 --- a/megatron/core/models/vision/vit_layer_specs.py +++ b/megatron/core/models/vision/vit_layer_specs.py @@ -1,4 +1,5 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from functools import partial from megatron.core.extensions.transformer_engine import ( TEDotProductAttention, @@ -85,10 +86,10 @@ def get_vit_layer_with_local_spec() -> ModuleSpec: # Helper function to get module spec for MLP/MoE -def _get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: +def _get_mlp_module_spec(use_te: bool = True): # Dense MLP w/ or w/o TE modules. 
- return ModuleSpec( - module=MLP, + return partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear, linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, diff --git a/megatron/core/post_training/modelopt/mamba/model_specs.py b/megatron/core/post_training/modelopt/mamba/model_specs.py index e9f83b49f71..72ced9932b3 100755 --- a/megatron/core/post_training/modelopt/mamba/model_specs.py +++ b/megatron/core/post_training/modelopt/mamba/model_specs.py @@ -1,4 +1,5 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +from functools import partial from megatron.core.extensions.transformer_engine import TEDotProductAttention from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add @@ -108,8 +109,8 @@ def _get_mamba_stack_local_spec( module=TransformerLayer, submodules=TransformerLayerSubmodules( pre_mlp_layernorm=Norm, - mlp=ModuleSpec( - module=MLP, + mlp=partial( + MLP.as_mlp_submodule, submodules=MLPSubmodules( linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear ), diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py index 8a19fef87ec..fd85dd1131f 100644 --- a/megatron/core/transformer/mlp.py +++ b/megatron/core/transformer/mlp.py @@ -25,6 +25,7 @@ ) from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl, weighted_bias_swiglu_impl +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.typed_torch import apply_module, not_none @@ -366,6 +367,31 @@ def backward_dw(self): self.linear_fc2.backward_dw() self.linear_fc1.backward_dw() + @classmethod + def as_mlp_submodule( + cls, + submodules: MLPSubmodules, + config: TransformerConfig, + pg_collection: ProcessGroupCollection, + is_mtp_layer: bool, + is_expert: bool = False, + input_size: int | None = None, + ffn_hidden_size: int | None = None, + ) -> MLP: + """Helper function to build an MLP as a TransformerLayer's mlp submodule.""" + del is_mtp_layer + assert hasattr( + pg_collection, 'tp' + ), 'TP process group is required for MLP in TransformerLayer' + return cls( + config=config, + submodules=submodules, + tp_group=pg_collection.tp, + is_expert=is_expert, + input_size=input_size, + ffn_hidden_size=ffn_hidden_size, + ) + # pylint: disable=missing-function-docstring def apply_swiglu_sharded_factory( diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 34e9fb17a02..9a8de547e3c 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -733,7 +733,7 @@ class SequentialMLP(MegatronModule): # TODO(M4): breaking api, switched from pass in tp_group to pass in pg_collection. 
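# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch, assuming a ProcessGroupCollection
# whose `tp` group is populated: the as_mlp_submodule classmethod added to MLP
# above (and to TEFusedMLP earlier in this diff) only adapts the MlpBuilder
# calling convention to the existing constructor: it discards is_mtp_layer and
# forwards pg_collection.tp as tp_group.
#
#     mlp = MLP.as_mlp_submodule(
#         submodules=mlp_submodules,    # an MLPSubmodules instance
#         config=transformer_config,    # a TransformerConfig
#         pg_collection=pg_collection,  # must expose a `tp` process group
#         is_mtp_layer=False,           # accepted for interface compatibility, unused
#     )
#     # ...which, with the default is_expert / input_size / ffn_hidden_size,
#     # is equivalent to:
#     mlp = MLP(config=transformer_config, submodules=mlp_submodules,
#               tp_group=pg_collection.tp)
# ---------------------------------------------------------------------------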
def __init__( self, - num_local_experts, + num_local_experts: int, config: TransformerConfig, submodules: MLPSubmodules, pg_collection: Optional[ProcessGroupCollection] = None, diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index ab71108c7fc..e7965a7779d 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -526,7 +526,7 @@ def forward( hidden_states: torch.Tensor, intermediate_tensors=None, padding_mask: Optional[torch.Tensor] = None, - ): + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """Forward pass for the MoE layer. The forward pass comprises four main steps: diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index cf63199347c..12b8c2c50f9 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -6,7 +6,7 @@ import warnings from abc import ABC from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Protocol, Union import torch import torch.distributed @@ -199,6 +199,33 @@ def get_transformer_layer_offset( return offset +class MlpInterface(Protocol): + """Interface for MLP implementations in the transformer layer.""" + + def forward( + self, + hidden_states: Tensor, + /, + *, + intermediate_tensors: tuple[Tensor, ...] | None = None, + padding_mask: Tensor | None = None, + ) -> tuple[Tensor, Tensor | None]: + """Forward method for the MLP interface.""" + ... + + +class MlpBuilder(Protocol): + """MLP builder protocol for building MLPs in the transformer layer.""" + + def __call__( + self, + *, + config: TransformerConfig, + pg_collection: ProcessGroupCollection, + is_mtp_layer: bool, + ) -> MlpInterface: ... + + @dataclass class TransformerLayerSubmodules: """ @@ -236,7 +263,7 @@ class TransformerLayerSubmodules: cross_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp pre_mlp_layernorm: LayerNormBuilder = IdentityOp - mlp: Union[ModuleSpec, type] = IdentityOp + mlp: MlpBuilder | type[IdentityOp] = IdentityOp mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method @@ -360,39 +387,30 @@ def __init__( eps=self.config.layernorm_epsilon, ) # [Module 8: MLP block] - additional_mlp_kwargs = {} # import here to avoid circular import from megatron.core.extensions.transformer_engine import TEFusedMLP - from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP from megatron.core.transformer.moe.moe_layer import MoELayer # MLP expects tp_group but MoELayer expects pg_collection to be passed in. # We can change MLP to accept pg_collection but it makes the logic implicit # The conditional below is to make the logic explicit # if submodules.mlp is not a ModuleSpec,we dont have to handle passing additional kwargs - if isinstance(submodules.mlp, ModuleSpec): - if submodules.mlp.module in (MoELayer, TEGroupedMLP, SequentialMLP): - additional_mlp_kwargs["pg_collection"] = pg_collection - # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. 
- if submodules.mlp.module == MoELayer: - additional_mlp_kwargs["is_mtp_layer"] = self.is_mtp_layer - elif submodules.mlp.module == MLP: - assert hasattr( - pg_collection, 'tp' - ), 'TP process group is required for MLP in TransformerLayer' - additional_mlp_kwargs["tp_group"] = pg_collection.tp - elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP: - assert hasattr( - pg_collection, 'tp' - ), 'TP process group is required for TEFusedMLP in TransformerLayer' - additional_mlp_kwargs["tp_group"] = pg_collection.tp - else: - log_single_rank( - logger, - logging.WARNING, - f"Unknown MLP type: {type(submodules.mlp)}. Using default kwargs.", - ) - self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs) + if isinstance(submodules.mlp, ModuleSpec) and submodules.mlp.module in (MLP, TEFusedMLP): + submodules.mlp = functools.partial( + submodules.mlp.module.as_mlp_submodule, + submodules=submodules.mlp.submodules, + **submodules.mlp.params, + ) + log_single_rank( + logger, + logging.WARNING, + f"Rewrapping ModuleSpec with module {type(submodules.mlp)} to forward kwargs. " + "Consider migrating the `mlp` submodule spec to a direct call of the " + "`as_mlp_submodule` classmethod instead.", + ) + self.mlp = submodules.mlp( + config=self.config, pg_collection=pg_collection, is_mtp_layer=self.is_mtp_layer + ) if hasattr(self.mlp, 'set_layer_number'): self.mlp.set_layer_number(self.layer_number) @@ -784,7 +802,7 @@ def _forward_mlp( from megatron.core.extensions.transformer_engine import te_checkpoint mlp_output_with_bias = te_checkpoint( - self.mlp, + apply_module(self.mlp), False, tensor_parallel.random.get_cuda_rng_tracker, self.pg_collection.tp, @@ -793,7 +811,7 @@ def _forward_mlp( ) else: mlp_output_with_bias = tensor_parallel.checkpoint( - functools.partial(self.mlp, padding_mask=padding_mask), + functools.partial(apply_module(self.mlp), padding_mask=padding_mask), False, pre_mlp_layernorm_output, ) @@ -803,7 +821,7 @@ def _forward_mlp( chunks = pre_mlp_layernorm_output.chunk(num_chunks, dim=0) # Compute outputs for each chunk - outputs = [self.mlp(chunk) for chunk in chunks] + outputs = [apply_module(self.mlp)(chunk) for chunk in chunks] # Aggregate chunk outputs mlp_output = torch.cat([out for out, _ in outputs], dim=0) @@ -815,7 +833,9 @@ def _forward_mlp( # Set the residual for fused reduce-scatter + add + layer-norm + all-gather # operation in MLP's fc2. 
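# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: with the MlpInterface/MlpBuilder
# protocols above, any callable with the right keyword signature can serve as
# TransformerLayerSubmodules.mlp; TransformerLayer no longer special-cases the
# module type. A hypothetical minimal builder and module (toy names, purely
# for illustration) could look like:
import torch
from torch import Tensor


class NoOpMlp(torch.nn.Module):
    """Toy module matching MlpInterface: returns (output, bias) with no bias."""

    def forward(self, hidden_states: Tensor, *, intermediate_tensors=None, padding_mask=None):
        return hidden_states, None


def noop_mlp_builder(*, config, pg_collection, is_mtp_layer) -> NoOpMlp:
    """Toy MlpBuilder: ignores its arguments and returns the no-op module."""
    return NoOpMlp()


# A layer spec could then set `mlp=noop_mlp_builder` directly, and the builder
# call shown above (config=..., pg_collection=..., is_mtp_layer=...) would work
# without any ModuleSpec rewrapping or deprecation warning.
# ---------------------------------------------------------------------------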
self._set_fc2_residual(residual) - mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output, padding_mask=padding_mask) + mlp_output_with_bias = apply_module(self.mlp)( + pre_mlp_layernorm_output, padding_mask=padding_mask + ) nvtx_range_pop(suffix="mlp") @@ -1157,7 +1177,7 @@ def _te_cuda_graph_replay(self, *args, **kwargs): hidden_states, probs = self.mlp.preprocess(hidden_states, probs, routing_map) nvtx_range_pop(suffix="mlp") return residual, hidden_states, probs, shared_expert_output - mlp_output_with_bias = self.mlp(hidden_states) + mlp_output_with_bias = apply_module(self.mlp)(hidden_states) self.mlp.cudagraph_tensor_store.clear() nvtx_range_pop(suffix="mlp") @@ -1407,7 +1427,7 @@ def _forward_mlp_router(self, hidden_states, padding_mask=None): if self.config.fp32_residual_connection: residual = residual.float() - router_outputs = self.mlp( + router_outputs = apply_module(self.mlp)( pre_mlp_layernorm_output, intermediate_tensors=(), padding_mask=padding_mask ) @@ -1441,7 +1461,7 @@ def _forward_mlp_expert_compute(self, hidden_states, probs): setattr(obj, hier_attr_name[-1], attr) self.mlp.fwd_execution_map = "expert_compute" - return self.mlp(None, intermediate_tensors=(hidden_states, probs)) + return apply_module(self.mlp)(None, intermediate_tensors=(hidden_states, probs)) def _forward_mlp_postprocess(self, residual, output, shared_expert_output, mlp_bias): """ @@ -1460,7 +1480,7 @@ def _forward_mlp_postprocess(self, residual, output, shared_expert_output, mlp_b setattr(self.mlp.token_dispatcher, name, attr) self.mlp.fwd_execution_map = "postprocess" - output = self.mlp(None, intermediate_tensors=(output, shared_expert_output)) + output = apply_module(self.mlp)(None, intermediate_tensors=(output, shared_expert_output)) return self._forward_post_mlp((output, mlp_bias), residual) def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None): diff --git a/pretrain_vlm.py b/pretrain_vlm.py index 9da1afa669f..dcbdbd05e25 100644 --- a/pretrain_vlm.py +++ b/pretrain_vlm.py @@ -6,20 +6,12 @@ import torch -from megatron.core import mpu, parallel_state, tensor_parallel -from megatron.core.datasets.blended_megatron_dataset_builder import ( - BlendedMegatronDatasetBuilder, -) -from megatron.core.datasets.multimodal_dataset import ( - MockMultimodalDataset, - MultimodalDatasetConfig, -) +from megatron.core import parallel_state, tensor_parallel +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.multimodal_dataset import MockMultimodalDataset, MultimodalDatasetConfig from megatron.core.enums import ModelType from megatron.core.models.multimodal import context_parallel -from megatron.core.models.multimodal.llava_model import ( - DEFAULT_IMAGE_TOKEN_INDEX, - LLaVAModel, -) +from megatron.core.models.multimodal.llava_model import DEFAULT_IMAGE_TOKEN_INDEX, LLaVAModel from megatron.core.models.multimodal.llava_spec import ( decoder_model_with_local_default_spec, decoder_model_with_transformer_engine_default_spec, @@ -30,14 +22,8 @@ get_vit_layer_with_transformer_engine_spec, ) from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.spec_utils import import_module -from megatron.training import ( - get_args, - get_timers, - get_tokenizer, - pretrain, - print_rank_0, -) +from megatron.core.transformer.spec_utils import get_submodules, import_module +from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0 from megatron.training.arguments 
import core_transformer_config_from_args from pretrain_gpt import loss_func @@ -70,13 +56,23 @@ def model_provider( args = get_args() vision_model_type = "clip" - assert args.ckpt_format == 'torch', "Only ckpt-format torch is supported for VLM training currently." - assert not (args.context_parallel_size > 1 and args.pipeline_model_parallel_size > 1), "PP+CP is not yet supported by this script. \ + assert ( + args.ckpt_format == 'torch' + ), "Only ckpt-format torch is supported for VLM training currently." + assert not ( + args.context_parallel_size > 1 and args.pipeline_model_parallel_size > 1 + ), "PP+CP is not yet supported by this script. \ Current mock dataset does not support natively packed sequence dataset required for correct PP comm shapes." num_image_embeddings = get_num_image_embeddings( - args.img_h, args.img_w, args.patch_dim, vision_model_type, args.disable_vision_class_token, - class_token_len=1, pixel_shuffle=False, use_tile_tags=False + args.img_h, + args.img_w, + args.patch_dim, + vision_model_type, + args.disable_vision_class_token, + class_token_len=1, + pixel_shuffle=False, + use_tile_tags=False, ) old_seq_length = args.seq_length @@ -99,7 +95,7 @@ def model_provider( args.tensor_model_parallel_size, args.sequence_parallel, args.decoder_tp_comm_overlap, - args.decoder_seq_length + args.decoder_seq_length, ) args.decoder_seq_length = decoder_seq_len + mp_padding_needed @@ -115,8 +111,9 @@ def model_provider( else: language_transformer_config.num_layers = args.num_layers if args.decoder_tp_comm_overlap: - assert args.transformer_impl == "transformer_engine", \ - "TransformerEngine is needed to support Decoder TP Comm overlap" + assert ( + args.transformer_impl == "transformer_engine" + ), "TransformerEngine is needed to support Decoder TP Comm overlap" language_transformer_config.tp_comm_overlap = args.decoder_tp_comm_overlap if args.spec is not None: @@ -132,10 +129,24 @@ def model_provider( # Prepare mask type for any required padding to support CP/SP sequence sharding. 
if mp_padding_needed > 0: - if language_transformer_layer_spec.submodules.self_attention.params.get('attn_mask_type', '') == AttnMaskType.causal: - language_transformer_layer_spec.submodules.self_attention.params['attn_mask_type'] = AttnMaskType.padding_causal - elif language_transformer_layer_spec.submodules.self_attention.params.get('attn_mask_type', '') == AttnMaskType.no_mask: - language_transformer_layer_spec.submodules.self_attention.params['attn_mask_type'] = AttnMaskType.padding + if ( + language_transformer_layer_spec.submodules.self_attention.params.get( + 'attn_mask_type', '' + ) + == AttnMaskType.causal + ): + language_transformer_layer_spec.submodules.self_attention.params['attn_mask_type'] = ( + AttnMaskType.padding_causal + ) + elif ( + language_transformer_layer_spec.submodules.self_attention.params.get( + 'attn_mask_type', '' + ) + == AttnMaskType.no_mask + ): + language_transformer_layer_spec.submodules.self_attention.params['attn_mask_type'] = ( + AttnMaskType.padding + ) if args.transformer_impl == "transformer_engine": vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec() @@ -148,7 +159,7 @@ def model_provider( vision_transformer_config.first_pipeline_num_layers = None vision_transformer_config.last_pipeline_num_layers = None vision_transformer_config.vision_model_type = vision_model_type - vision_transformer_config.context_parallel_size = 1 # Force CP=1 for Vision Transformer + vision_transformer_config.context_parallel_size = 1 # Force CP=1 for Vision Transformer if vision_transformer_config.sequence_parallel: print_rank_0("> Disabling Sequence parallelism in Vision Transformer. Not yet supported") vision_transformer_config.sequence_parallel = False @@ -158,7 +169,7 @@ def model_provider( vision_projection_type = "mlp" vision_projection_config = deepcopy(language_transformer_config) - vision_projection_config.context_parallel_size = 1 # Force CP=1 for Vision Projection + vision_projection_config.context_parallel_size = 1 # Force CP=1 for Vision Projection if vision_projection_config.sequence_parallel: print_rank_0("> Disabling Sequence parallelism in Vision Projection. Not yet supported") vision_projection_config.sequence_parallel = False @@ -170,7 +181,9 @@ def model_provider( vision_transformer_config.pipeline_model_parallel_size = 1 vision_projection_config.pipeline_model_parallel_size = 1 - vision_projection_modules = deepcopy(language_transformer_layer_spec.submodules.mlp.submodules) + vision_projection_modules = deepcopy( + get_submodules(language_transformer_layer_spec.submodules.mlp) + ) language_max_sequence_length = args.decoder_seq_length if args.context_parallel_size > 1: @@ -320,41 +333,47 @@ def get_batch(data_iterator): vision_model_type = "clip" # Calculate the number of image embedding tokens will be added to text tokens num_image_embeddings_per_tile = get_num_image_embeddings( - args.img_h, args.img_w, args.patch_dim, vision_model_type, - args.disable_vision_class_token, 1, False + args.img_h, + args.img_w, + args.patch_dim, + vision_model_type, + args.disable_vision_class_token, + 1, + False, ) # Pad to make sure the text sequence can be sharded equally by CP chunks. 
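# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: get_submodules() from spec_utils
# (used just above for the vision projection, and throughout the updated tests
# below) recovers the submodules dataclass from a layer spec's `mlp` entry now
# that it is a functools.partial rather than a ModuleSpec. This mirrors the
# usage in tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py.
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_submodules,
)
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.spec_utils import get_submodules

layer_submodules = get_gpt_layer_with_transformer_engine_submodules()
mlp_submodules = get_submodules(layer_submodules.mlp)
assert isinstance(mlp_submodules, MLPSubmodules)
# ---------------------------------------------------------------------------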
image_token_mask = tokens == DEFAULT_IMAGE_TOKEN_INDEX num_images_per_sample = torch.sum(image_token_mask, dim=-1) - img_seq_len = (num_image_embeddings_per_tile * num_images_per_sample - num_images_per_sample).max() + img_seq_len = ( + num_image_embeddings_per_tile * num_images_per_sample - num_images_per_sample + ).max() mp_padding_needed_for_text = context_parallel.get_padding( tokens.shape[1] + img_seq_len, args.context_parallel_size, args.tensor_model_parallel_size, args.sequence_parallel, args.decoder_tp_comm_overlap, - args.decoder_seq_length + args.decoder_seq_length, ) if mp_padding_needed_for_text > 0: - tokens, position_ids, labels, loss_mask = [torch.nn.functional.pad(item, (0, mp_padding_needed_for_text)) for item in (tokens, position_ids, labels, loss_mask)] - packed_seq_params = context_parallel.get_packed_seq_params(tokens, img_seq_len, mp_padding_needed_for_text, cp_size, args.use_packed_sequence) + tokens, position_ids, labels, loss_mask = [ + torch.nn.functional.pad(item, (0, mp_padding_needed_for_text)) + for item in (tokens, position_ids, labels, loss_mask) + ] + packed_seq_params = context_parallel.get_packed_seq_params( + tokens, img_seq_len, mp_padding_needed_for_text, cp_size, args.use_packed_sequence + ) if packed_seq_params.qkv_format == 'thd': # Reshape from [B,S] to [T,1] - tokens = ( - tokens.contiguous() - .view(tokens.shape[0] * tokens.shape[1]) - .unsqueeze(0) - ) + tokens = tokens.contiguous().view(tokens.shape[0] * tokens.shape[1]).unsqueeze(0) position_ids = ( position_ids.contiguous() .view(position_ids.shape[0] * position_ids.shape[1]) .unsqueeze(0) ) labels = labels.view(labels.shape[0] * labels.shape[1]).unsqueeze(0) - loss_mask = loss_mask.view( - loss_mask.shape[0] * loss_mask.shape[1] - ).unsqueeze(0) + loss_mask = loss_mask.view(loss_mask.shape[0] * loss_mask.shape[1]).unsqueeze(0) attention_mask = None # Use the attention mask type defined in layer spec. Typically no mask for the vision model and causal mask for the vision model. @@ -376,11 +395,19 @@ def forward_step(data_iterator, model: LLaVAModel): # Get the batch. timers('batch-generator', log_level=2).start() - tokens, position_ids, labels, images, loss_mask, attention_mask, packed_seq_params = get_batch(data_iterator) + tokens, position_ids, labels, images, loss_mask, attention_mask, packed_seq_params = get_batch( + data_iterator + ) timers('batch-generator').stop() output_tensor, loss_mask = model( - images, tokens, position_ids, attention_mask, labels, loss_mask, packed_seq_params=packed_seq_params + images, + tokens, + position_ids, + attention_mask, + labels, + loss_mask, + packed_seq_params=packed_seq_params, ) return output_tensor, partial(loss_func, loss_mask) @@ -401,15 +428,21 @@ def add_vlm_extra_args(parser): default=False, help="Drop vision model class token", ) - group.add_argument("--dataloader-seq-length", type=int, help="Make dataloader to produce sequences of specific length.") - group.add_argument("--decoder-tp-comm-overlap", action="store_true", default=False, help="Enables the overlap of " - "Tensor parallel communication and GEMM kernels in Decoder only. 
" - "Please provide decoder-seq-length when using this feature.") group.add_argument( - "--use-packed-sequence", + "--dataloader-seq-length", + type=int, + help="Make dataloader to produce sequences of specific length.", + ) + group.add_argument( + "--decoder-tp-comm-overlap", action="store_true", default=False, - help="Use packed sequence", + help="Enables the overlap of " + "Tensor parallel communication and GEMM kernels in Decoder only. " + "Please provide decoder-seq-length when using this feature.", + ) + group.add_argument( + "--use-packed-sequence", action="store_true", default=False, help="Use packed sequence" ) return parser diff --git a/tests/functional_tests/test_cases/common/moe_perf/__main__.py b/tests/functional_tests/test_cases/common/moe_perf/__main__.py index f1dea5f93c1..64671b36714 100644 --- a/tests/functional_tests/test_cases/common/moe_perf/__main__.py +++ b/tests/functional_tests/test_cases/common/moe_perf/__main__.py @@ -20,10 +20,10 @@ get_gpt_layer_with_transformer_engine_submodules, ) from megatron.core.transformer.moe.fused_a2a import HAVE_DEEP_EP, HAVE_HYBRIDEP -from megatron.core.transformer.moe.moe_layer import MoELayer +from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules from megatron.core.transformer.moe.moe_utils import RandomSTE +from megatron.core.transformer.spec_utils import get_submodules from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import is_te_min_version from megatron.training.initialize import _set_random_seed from tests.unit_tests.test_utilities import Utils @@ -86,10 +86,12 @@ def _build_transformer_config(case: MoEPerformanceCase) -> TransformerConfig: # NOTE: Only TE backend is covered in this test. -def _resolve_moe_submodules(case: MoEPerformanceCase): - return get_gpt_layer_with_transformer_engine_submodules( - num_experts=case.model.num_experts, moe_grouped_gemm=True - ).mlp.submodules +def _resolve_moe_submodules(case: MoEPerformanceCase) -> MoESubmodules: + return get_submodules( + get_gpt_layer_with_transformer_engine_submodules( + num_experts=case.model.num_experts, moe_grouped_gemm=True + ).mlp + ) def _load_baselines() -> Dict[str, Dict[str, float]]: diff --git a/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py b/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py index 037e368ea2f..7d0f3c1b3a3 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py +++ b/tests/unit_tests/dist_checkpointing/models/test_mlp_glu.py @@ -1,23 +1,18 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
-import inspect import logging import pytest import torch -from torch.optim import Adam from megatron.core import parallel_state from megatron.core.dist_checkpointing import ShardedTensor, load, load_plain_tensors, save -from megatron.core.dist_checkpointing.dict_utils import diff, nested_values -from megatron.core.dist_checkpointing.optimizer import ( - get_param_id_to_sharded_param_map, - optim_state_to_sharding_state, -) +from megatron.core.dist_checkpointing.dict_utils import diff from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_layer_with_transformer_engine_submodules, ) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer.mlp import MLP, apply_swiglu_sharded_factory +from megatron.core.transformer.mlp import MLP, MLPSubmodules, apply_swiglu_sharded_factory +from megatron.core.transformer.spec_utils import get_submodules from megatron.core.transformer.transformer_config import TransformerConfig from tests.unit_tests.dist_checkpointing import TempNamedDir from tests.unit_tests.test_utilities import Utils @@ -33,9 +28,9 @@ def initialize_mlp(glu=True): use_cpu_initialization=True, gated_linear_unit=glu, ) - return MLP( - transformer_config, get_gpt_layer_with_transformer_engine_submodules().mlp.submodules - ) + mlp_submodules = get_submodules(get_gpt_layer_with_transformer_engine_submodules().mlp) + assert isinstance(mlp_submodules, MLPSubmodules) + return MLP(transformer_config, mlp_submodules) class TestParallelMLPWithGLU: diff --git a/tests/unit_tests/distributed/test_grad_sync_with_expert_parallel.py b/tests/unit_tests/distributed/test_grad_sync_with_expert_parallel.py index 15c479774c6..6461dbee751 100644 --- a/tests/unit_tests/distributed/test_grad_sync_with_expert_parallel.py +++ b/tests/unit_tests/distributed/test_grad_sync_with_expert_parallel.py @@ -12,7 +12,8 @@ get_gpt_layer_with_transformer_engine_submodules, ) from megatron.core.transformer import TransformerConfig -from megatron.core.transformer.moe.moe_layer import MoELayer +from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules +from megatron.core.transformer.spec_utils import get_submodules from tests.unit_tests.test_utilities import Utils @@ -42,15 +43,15 @@ def __init__( params_dtype=torch.bfloat16, add_bias_linear=False, ) - submodules = get_gpt_layer_with_transformer_engine_submodules( - num_experts=num_moe_experts, moe_grouped_gemm=moe_grouped_gemm + submodules = get_submodules( + get_gpt_layer_with_transformer_engine_submodules( + num_experts=num_moe_experts, moe_grouped_gemm=moe_grouped_gemm + ).mlp ) + assert isinstance(submodules, MoESubmodules) super().__init__() self.layers = torch.nn.ModuleList( - [ - MoELayer(transformer_config, submodules.mlp.submodules).cuda() - for _ in range(num_layers) - ] + [MoELayer(transformer_config, submodules).cuda() for _ in range(num_layers)] ) diff --git a/tests/unit_tests/inference/text_generation_controllers/test_vlm_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_vlm_text_generation_controller.py index 5344ae9f8eb..8b3962dd7b7 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_vlm_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_vlm_text_generation_controller.py @@ -23,8 +23,9 @@ from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_submodules from megatron.core.models.multimodal.llava_model import LLaVAModel from megatron.core.tensor_parallel.random 
import model_parallel_cuda_manual_seed +from megatron.core.transformer.mlp import MLPSubmodules from megatron.core.transformer.module import Float16Module -from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.spec_utils import ModuleSpec, get_submodules from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import TransformerLayer from tests.unit_tests.test_utilities import Utils @@ -71,7 +72,8 @@ def setup_method(self, method): vision_layer_spec = ModuleSpec( module=TransformerLayer, submodules=copy.deepcopy(language_layer_submodules) ) - vision_projection_spec = copy.deepcopy(language_layer_submodules.mlp.submodules) + vision_projection_spec = copy.deepcopy(get_submodules(language_layer_submodules.mlp)) + assert isinstance(vision_projection_spec, MLPSubmodules) language_config.language_model_type = "dummy" vision_config.vision_model_type = "clip" diff --git a/tests/unit_tests/models/test_llava_model.py b/tests/unit_tests/models/test_llava_model.py index 970b0e54635..2a8fd553f85 100644 --- a/tests/unit_tests/models/test_llava_model.py +++ b/tests/unit_tests/models/test_llava_model.py @@ -15,7 +15,8 @@ from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.mlp import MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec, get_submodules from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.utils import is_te_min_version @@ -53,7 +54,8 @@ def setup_method(self, method): vision_layer_spec = ModuleSpec( module=TransformerLayer, submodules=deepcopy(language_layer_submodules) ) - vision_projection_spec = deepcopy(language_layer_submodules.mlp.submodules) + vision_projection_spec = deepcopy(get_submodules(language_layer_submodules.mlp)) + assert isinstance(vision_projection_spec, MLPSubmodules) language_config.language_model_type = "dummy" vision_config.vision_model_type = "clip" @@ -491,7 +493,7 @@ def setup_and_teardown_llava_model(request): vision_layer_spec = ModuleSpec( module=TransformerLayer, submodules=deepcopy(language_layer_submodules) ) - vision_projection_spec = deepcopy(language_layer_submodules.mlp.submodules) + vision_projection_spec = deepcopy(get_submodules(language_layer_submodules.mlp)) language_config.language_model_type = "dummy" vision_model_type = request.param @@ -601,7 +603,7 @@ def _init_llava_model(self, cp_size, tp_size, sequence_parallel): vision_layer_spec = ModuleSpec( module=TransformerLayer, submodules=deepcopy(language_layer_submodules) ) - vision_projection_spec = deepcopy(language_layer_submodules.mlp.submodules) + vision_projection_spec = deepcopy(get_submodules(language_layer_submodules.mlp)) language_config.language_model_type = "dummy" vision_config.vision_model_type = "clip" diff --git a/tests/unit_tests/models/test_multimodal_projector.py b/tests/unit_tests/models/test_multimodal_projector.py index 52fda330c2e..5df34954b6a 100644 --- a/tests/unit_tests/models/test_multimodal_projector.py +++ b/tests/unit_tests/models/test_multimodal_projector.py @@ -20,7 +20,7 @@ def setup_method(self, method): transformer_config = TransformerConfig( num_layers=1, hidden_size=64, num_attention_heads=4, 
use_cpu_initialization=True
         )
-        mlp_layer_spec = get_mlp_module_spec().submodules
+        mlp_layer_spec = get_mlp_module_spec().keywords['submodules']
         affine_layer_spec = MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=None)
         self.mlp = MultimodalProjector(
diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py
index dc554612811..84d6b3ae4eb 100644
--- a/tests/unit_tests/test_utils.py
+++ b/tests/unit_tests/test_utils.py
@@ -20,7 +20,8 @@
 )
 from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
 from megatron.core.transformer import TransformerConfig
-from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
+from megatron.core.transformer.spec_utils import get_submodules
 from tests.unit_tests.test_utilities import Utils
 
 success_string = "hello,world"
@@ -315,12 +316,11 @@ def test_param_norm_moe(use_distributed_optimizer: bool):
         add_bias_linear=False,
         bf16=True,
     )
-    model = MoELayer(
-        transformer_config,
-        get_gpt_layer_with_transformer_engine_submodules(
-            num_experts=2, moe_grouped_gemm=True
-        ).mlp.submodules,
-    ).to(device='cuda')
+    submodules = get_submodules(
+        get_gpt_layer_with_transformer_engine_submodules(num_experts=2, moe_grouped_gemm=True).mlp
+    )
+    assert isinstance(submodules, MoESubmodules)
+    model = MoELayer(transformer_config, submodules).to(device='cuda')
     model.requires_grad_(True)
     # Initialize the model with all 1.0 for weights.
     for param in model.parameters():
diff --git a/tests/unit_tests/transformer/moe/test_grouped_mlp.py b/tests/unit_tests/transformer/moe/test_grouped_mlp.py
index 58ca0cf1e3f..b1f15a9eef1 100644
--- a/tests/unit_tests/transformer/moe/test_grouped_mlp.py
+++ b/tests/unit_tests/transformer/moe/test_grouped_mlp.py
@@ -10,7 +10,8 @@
 )
 from megatron.core.transformer.module import Float16Module
 from megatron.core.transformer.moe.experts import TEGroupedMLP
-from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
+from megatron.core.transformer.spec_utils import get_submodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.utils import is_te_min_version
 from megatron.training.arguments import parse_args
@@ -58,8 +59,11 @@ def setup_method(self, method, use_cpu_initialization=False, swiglu=True):
         ## Vanilla sequential GEMM
         # Set random seed for reproducability
         _set_random_seed(seed_=123, data_parallel_random_init=False)
-        submodules = get_gpt_layer_local_submodules(self.num_experts, moe_grouped_gemm=False)
-        self.sequential_mlp = MoELayer(tf_config, submodules.mlp.submodules)
+        sequential_submodules = get_submodules(
+            get_gpt_layer_local_submodules(self.num_experts, moe_grouped_gemm=False).mlp
+        )
+        assert isinstance(sequential_submodules, MoESubmodules)
+        self.sequential_mlp = MoELayer(tf_config, sequential_submodules)
 
         self.args = parse_args(ignore_unknown_args=True)
         self.args.bf16 = True
@@ -71,12 +75,13 @@ def setup_method(self, method, use_cpu_initialization=False, swiglu=True):
         ## Grouped GEMM
         _set_random_seed(seed_=123, data_parallel_random_init=False)
         tf_config.moe_grouped_gemm = True
-        self.grouped_mlp = MoELayer(
-            tf_config,
+        grouped_submodules = get_submodules(
             get_gpt_layer_with_transformer_engine_submodules(
                 self.num_experts, moe_grouped_gemm=True
-            ).mlp.submodules,
+            ).mlp
         )
+        assert isinstance(grouped_submodules, MoESubmodules)
+        self.grouped_mlp = MoELayer(tf_config, grouped_submodules)
         assert isinstance(self.grouped_mlp.experts, TEGroupedMLP)
         self.grouped_mlp = Float16Module(self.grouped_mlp.config, self.grouped_mlp).module
 
diff --git a/tests/unit_tests/transformer/moe/test_latent_moe_layer.py b/tests/unit_tests/transformer/moe/test_latent_moe_layer.py
index f62de67860a..bb5ced291fc 100644
--- a/tests/unit_tests/transformer/moe/test_latent_moe_layer.py
+++ b/tests/unit_tests/transformer/moe/test_latent_moe_layer.py
@@ -4,14 +4,11 @@
 import torch
 
 from megatron.core.models.gpt.gpt_layer_specs import (
-    get_gpt_decoder_block_spec,
-    get_gpt_layer_local_spec,
-    get_gpt_layer_with_transformer_engine_spec,
+    get_gpt_layer_local_submodules,
+    get_gpt_layer_with_transformer_engine_submodules,
 )
-from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
-from megatron.core.transformer.moe.moe_layer import MoELayer
-from megatron.core.transformer.moe.router import Router
-from megatron.core.transformer.transformer_block import TransformerBlock
+from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
+from megatron.core.transformer.spec_utils import get_submodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.utils import is_te_min_version
 from megatron.training.initialize import _set_random_seed
@@ -53,16 +50,16 @@ def test_latent_moe_layer(
             moe_latent_size=moe_latent_size,
         )
         if use_te:
-            transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
+            transformer_layer_submodules = get_gpt_layer_with_transformer_engine_submodules(
                 num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm
             )
         else:
-            transformer_layer_spec = get_gpt_layer_local_spec(
+            transformer_layer_submodules = get_gpt_layer_local_submodules(
                 num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm
             )
-        moe_layer = MoELayer(
-            self.transformer_config, transformer_layer_spec.submodules.mlp.submodules
-        )
+        submodules = get_submodules(transformer_layer_submodules.mlp)
+        assert isinstance(submodules, MoESubmodules)
+        moe_layer = MoELayer(self.transformer_config, submodules)
         moe_layer.cuda()
 
         config = moe_layer.config
diff --git a/tests/unit_tests/transformer/moe/test_moe_layer.py b/tests/unit_tests/transformer/moe/test_moe_layer.py
index 0004b7fef98..c4a3283ac5b 100644
--- a/tests/unit_tests/transformer/moe/test_moe_layer.py
+++ b/tests/unit_tests/transformer/moe/test_moe_layer.py
@@ -9,7 +9,8 @@
     get_gpt_layer_with_transformer_engine_submodules,
 )
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
-from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
+from megatron.core.transformer.spec_utils import get_submodules
 from megatron.core.transformer.transformer_block import TransformerBlock
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.utils import is_te_min_version
@@ -44,10 +45,13 @@ def test_te_moe_layer(self, num_moe_experts, moe_token_dispatcher_type, grouped_
             moe_ffn_hidden_size=128,
             add_bias_linear=False,
         )
-        submodules = get_gpt_layer_with_transformer_engine_submodules(
-            num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm
+        submodules = get_submodules(
+            get_gpt_layer_with_transformer_engine_submodules(
+                num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm
+            ).mlp
         )
-        moe_layer = MoELayer(self.transformer_config, submodules.mlp.submodules)
+        assert isinstance(submodules, MoESubmodules)
+        moe_layer = MoELayer(self.transformer_config, submodules)
         Utils.destroy_model_parallel()
 
     @pytest.mark.parametrize("moe_token_dispatcher_type", ["allgather", "alltoall"])
@@ -70,10 +74,13 @@ def test_legacy_moe_layer(self, num_moe_experts, moe_token_dispatcher_type, grou
             moe_grouped_gemm=grouped_gemm,
             add_bias_linear=False,
         )
-        transformer_layer_submodules = get_gpt_layer_local_submodules(
-            num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm
+        submodules = get_submodules(
+            get_gpt_layer_local_submodules(
+                num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm
+            ).mlp
         )
-        moe_layer = MoELayer(self.transformer_config, transformer_layer_submodules.mlp.submodules)
+        assert isinstance(submodules, MoESubmodules)
+        moe_layer = MoELayer(self.transformer_config, submodules)
         Utils.destroy_model_parallel()
 
     @pytest.mark.skip(
@@ -105,15 +112,18 @@ def test_moe_with_late_initialize(
             bf16=True,
             params_dtype=torch.bfloat16,
         )
-        submodules = get_gpt_layer_with_transformer_engine_submodules(
-            num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm
+        submodules = get_submodules(
+            get_gpt_layer_with_transformer_engine_submodules(
+                num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm
+            ).mlp
         )
+        assert isinstance(submodules, MoESubmodules)
 
         # Fake initialization as NeMo does
         Utils.fake_initialize_model_parallel(
             tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size
         )
-        moe_layer = MoELayer(transformer_config, submodules.mlp.submodules).cuda()
+        moe_layer = MoELayer(transformer_config, submodules).cuda()
 
         Utils.initialize_model_parallel(
             tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size
@@ -229,11 +239,12 @@ def test_moe_layer_fp16_forward_backward(
         params_dtype=torch.float16,
     )
 
-    submodules = get_gpt_layer_local_submodules(
-        num_experts=num_moe_experts, moe_grouped_gemm=False
+    submodules = get_submodules(
+        get_gpt_layer_local_submodules(num_experts=num_moe_experts, moe_grouped_gemm=False).mlp
     )
+    assert isinstance(submodules, MoESubmodules)
 
-    moe_layer = MoELayer(transformer_config, submodules.mlp.submodules).cuda()
+    moe_layer = MoELayer(transformer_config, submodules).cuda()
 
     hidden_states = torch.randn(
         sequence_length,
@@ -332,15 +343,20 @@ def test_moe_layer_recompute_forward_backward(
 
     # Use TE spec for fp8, local spec otherwise
     if fp8:
-        transformer_layer_submodules = get_gpt_layer_with_transformer_engine_submodules(
-            num_experts=num_moe_experts, moe_grouped_gemm=False
+        submodules = get_submodules(
+            get_gpt_layer_with_transformer_engine_submodules(
+                num_experts=num_moe_experts, moe_grouped_gemm=False
+            ).mlp
         )
     else:
-        transformer_layer_submodules = get_gpt_layer_local_submodules(
-            num_experts=num_moe_experts, moe_grouped_gemm=False
+        submodules = get_submodules(
+            get_gpt_layer_local_submodules(
+                num_experts=num_moe_experts, moe_grouped_gemm=False
+            ).mlp
         )
+    assert isinstance(submodules, MoESubmodules)
 
-    moe_layer = MoELayer(transformer_config, transformer_layer_submodules.mlp.submodules).cuda()
+    moe_layer = MoELayer(transformer_config, submodules).cuda()
 
     hidden_states = torch.randn(
         sequence_length,
diff --git a/tests/unit_tests/transformer/moe/test_routers.py b/tests/unit_tests/transformer/moe/test_routers.py
index 8f3dbbe96e0..a03766d668f 100644
--- a/tests/unit_tests/transformer/moe/test_routers.py
+++ b/tests/unit_tests/transformer/moe/test_routers.py
@@ -7,9 +7,10 @@
 import torch
 
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_submodules
-from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
 from megatron.core.transformer.moe.moe_utils import get_updated_expert_bias, router_gating_linear
 from megatron.core.transformer.moe.router import Router
+from megatron.core.transformer.spec_utils import get_submodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.training.initialize import _set_random_seed
 from tests.unit_tests.test_utilities import Utils
@@ -44,10 +45,11 @@ def setup_method(self, method):
             params_dtype=torch.bfloat16,
             add_bias_linear=False,
         )
-        submodules = get_gpt_layer_local_submodules(
-            num_experts=num_moe_experts, moe_grouped_gemm=False
+        submodules = get_submodules(
+            get_gpt_layer_local_submodules(num_experts=num_moe_experts, moe_grouped_gemm=False).mlp
         )
-        self.sequential_mlp = MoELayer(self.transformer_config, submodules.mlp.submodules)
+        assert isinstance(submodules, MoESubmodules)
+        self.sequential_mlp = MoELayer(self.transformer_config, submodules)
         self.router = cast(Router, self.sequential_mlp.router)
 
     def teardown_method(self, method):
@@ -313,10 +315,11 @@ def setup_method(self, method):
         )
 
         # init MoE layer
-        submodules = get_gpt_layer_local_submodules(
-            num_experts=num_moe_experts, moe_grouped_gemm=False
+        submodules = get_submodules(
+            get_gpt_layer_local_submodules(num_experts=num_moe_experts, moe_grouped_gemm=False).mlp
         )
-        self.moe_layer = MoELayer(self.transformer_config, submodules.mlp.submodules).cuda()
+        assert isinstance(submodules, MoESubmodules)
+        self.moe_layer = MoELayer(self.transformer_config, submodules).cuda()
         self.router = cast(Router, self.moe_layer.router)
 
     def teardown_method(self, method):
@@ -418,10 +421,11 @@ def setup_method(self, method):
             params_dtype=torch.bfloat16,
             add_bias_linear=False,
         )
-        submodules = get_gpt_layer_local_submodules(
-            num_experts=num_moe_experts, moe_grouped_gemm=False
+        submodules = get_submodules(
+            get_gpt_layer_local_submodules(num_experts=num_moe_experts, moe_grouped_gemm=False).mlp
         )
-        self.moe_layer = MoELayer(self.transformer_config, submodules.mlp.submodules)
+        assert isinstance(submodules, MoESubmodules)
+        self.moe_layer = MoELayer(self.transformer_config, submodules)
         self.router = cast(Router, self.moe_layer.router)
         assert self.router.expert_bias is not None
         assert self.router.local_tokens_per_expert is not None
@@ -464,11 +468,14 @@ def test_router_forward_aux_free(self):
     def test_router_forward_fusion_equivalence(self, score_function):
         with torch.no_grad():
             # Build two fresh routers to avoid bias update interference
-            submodules = get_gpt_layer_local_submodules(
-                num_experts=self.transformer_config.num_moe_experts, moe_grouped_gemm=False
+            submodules = get_submodules(
+                get_gpt_layer_local_submodules(
+                    num_experts=self.transformer_config.num_moe_experts, moe_grouped_gemm=False
+                ).mlp
             )
-            moe_layer_ref = MoELayer(self.transformer_config, submodules.mlp.submodules)
-            moe_layer_fused = MoELayer(self.transformer_config, submodules.mlp.submodules)
+            assert isinstance(submodules, MoESubmodules)
+            moe_layer_ref = MoELayer(self.transformer_config, submodules)
+            moe_layer_fused = MoELayer(self.transformer_config, submodules)
 
             router_ref = moe_layer_ref.router.cuda()
             router_fused = moe_layer_fused.router.cuda()
diff --git a/tests/unit_tests/transformer/moe/test_sequential_mlp.py b/tests/unit_tests/transformer/moe/test_sequential_mlp.py
index e618f6a8318..6a0637a4e65 100644
--- a/tests/unit_tests/transformer/moe/test_sequential_mlp.py
+++ b/tests/unit_tests/transformer/moe/test_sequential_mlp.py
@@ -1,6 +1,4 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-from importlib.metadata import version
-
 import pytest
 import torch
 
@@ -10,8 +8,9 @@
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer.mlp import MLPSubmodules
 from megatron.core.transformer.moe.experts import SequentialMLP
-from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
 from megatron.core.transformer.moe.moe_utils import get_default_pg_collection
+from megatron.core.transformer.spec_utils import get_submodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.utils import is_te_min_version
 from tests.unit_tests.test_utilities import Utils
@@ -37,10 +36,11 @@ def setup_method(self, method):
             moe_router_topk=1,
             add_bias_linear=False,
         )
-        submodules = get_gpt_layer_local_submodules(
-            num_experts=num_moe_experts, moe_grouped_gemm=False
+        submodules = get_submodules(
+            get_gpt_layer_local_submodules(num_experts=num_moe_experts, moe_grouped_gemm=False).mlp
         )
-        self.sequential_mlp = MoELayer(transformer_config, submodules.mlp.submodules)
+        assert isinstance(submodules, MoESubmodules)
+        self.sequential_mlp = MoELayer(transformer_config, submodules)
 
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
diff --git a/tests/unit_tests/transformer/moe/test_shared_experts.py b/tests/unit_tests/transformer/moe/test_shared_experts.py
index d99dc1a0b05..245caf355cc 100644
--- a/tests/unit_tests/transformer/moe/test_shared_experts.py
+++ b/tests/unit_tests/transformer/moe/test_shared_experts.py
@@ -5,7 +5,8 @@
 
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_submodules
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
-from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
+from megatron.core.transformer.spec_utils import get_submodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from tests.unit_tests.test_utilities import Utils
 
@@ -41,10 +42,11 @@ def test_gpu_forward(self, shared_expert_gate):
             add_bias_linear=False,
             moe_shared_expert_gate=shared_expert_gate,
         )
-        submodules = get_gpt_layer_local_submodules(
-            num_experts=num_moe_experts, moe_grouped_gemm=False
+        submodules = get_submodules(
+            get_gpt_layer_local_submodules(num_experts=num_moe_experts, moe_grouped_gemm=False).mlp
         )
-        self.moe_layer = MoELayer(transformer_config, submodules.mlp.submodules)
+        assert isinstance(submodules, MoESubmodules)
+        self.moe_layer = MoELayer(transformer_config, submodules)
 
         assert isinstance(self.moe_layer, MoELayer)
 
@@ -101,10 +103,11 @@ def test_gpu_forward(self):
             moe_router_topk=1,
             add_bias_linear=False,
         )
-        submodules = get_gpt_layer_local_submodules(
-            num_experts=num_moe_experts, moe_grouped_gemm=False
+        submodules = get_submodules(
+            get_gpt_layer_local_submodules(num_experts=num_moe_experts, moe_grouped_gemm=False).mlp
        )
-        self.moe_layer = MoELayer(transformer_config, submodules.mlp.submodules)
+        assert isinstance(submodules, MoESubmodules)
+        self.moe_layer = MoELayer(transformer_config, submodules)
 
         assert isinstance(self.moe_layer, MoELayer)
 
diff --git a/tests/unit_tests/transformer/moe/test_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_token_dispatcher.py
index 6ff8fcdc6e5..bdf28359c8b 100644
--- a/tests/unit_tests/transformer/moe/test_token_dispatcher.py
+++ b/tests/unit_tests/transformer/moe/test_token_dispatcher.py
@@ -1,6 +1,5 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
-import copy
 import dataclasses
 
 import pytest
@@ -8,8 +7,9 @@
 
 from megatron.core import config, parallel_state
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_submodules
-from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
 from megatron.core.transformer.moe.moe_utils import get_capacity
+from megatron.core.transformer.spec_utils import get_submodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.typed_torch import apply_module
 from megatron.core.utils import is_te_min_version
@@ -99,11 +99,15 @@ def __init__(
         self.moe_layer = self.new_moe_layer()
 
     def new_moe_layer(self, **kargs):
-        submodules = get_gpt_layer_local_submodules(
-            num_experts=self.config.num_moe_experts, moe_grouped_gemm=self.config.moe_grouped_gemm
+        submodules = get_submodules(
+            get_gpt_layer_local_submodules(
+                num_experts=self.config.num_moe_experts,
+                moe_grouped_gemm=self.config.moe_grouped_gemm,
+            ).mlp
         )
+        assert isinstance(submodules, MoESubmodules)
         new_config = dataclasses.replace(self.config, **kargs)
-        moe_layer = MoELayer(new_config, submodules.mlp.submodules).cuda().to(dtype=self.test_dtype)
+        moe_layer = MoELayer(new_config, submodules).cuda().to(dtype=self.test_dtype)
         moe_layer.set_layer_number(0)
         return moe_layer
 
diff --git a/tests/unit_tests/transformer/test_cuda_graphs.py b/tests/unit_tests/transformer/test_cuda_graphs.py
index bfde9ff9cf1..3c840105d4a 100644
--- a/tests/unit_tests/transformer/test_cuda_graphs.py
+++ b/tests/unit_tests/transformer/test_cuda_graphs.py
@@ -12,6 +12,7 @@
 from megatron.core.models.gpt.gpt_layer_specs import (
     get_gpt_decoder_block_spec,
     get_gpt_layer_with_transformer_engine_spec,
+    get_gpt_layer_with_transformer_engine_submodules,
     get_gpt_mtp_block_spec,
 )
 from megatron.core.models.gpt.gpt_model import GPTModel
@@ -35,11 +36,13 @@
     _CudagraphGlobalRecord,
 )
 from megatron.core.transformer.enums import CudaGraphScope
+from megatron.core.transformer.mlp import MLPSubmodules
 from megatron.core.transformer.moe.fused_a2a import reset_hybrid_ep_buffer
+from megatron.core.transformer.spec_utils import ModuleSpec, get_submodules
 from megatron.core.transformer.transformer_block import TransformerBlock
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.transformer.transformer_layer import TransformerLayerSubmodules
-from megatron.core.utils import is_fa_min_version, is_te_min_version
+from megatron.core.transformer.transformer_layer import TransformerLayer
+from megatron.core.utils import is_te_min_version
 from megatron.training.arguments import core_transformer_config_from_args, parse_args, validate_args
 from megatron.training.global_vars import (
     destroy_global_vars,
@@ -327,10 +330,10 @@ def setup_method(self, method):
         )
 
         # Get layer specs
-        language_layer_spec = get_gpt_layer_with_transformer_engine_spec()
+        language_layer_submodules = get_gpt_layer_with_transformer_engine_submodules()
         vision_layer_spec = get_vit_layer_with_transformer_engine_spec()
-        assert isinstance(language_layer_spec.submodules, TransformerLayerSubmodules)
-        vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules)
+        vision_projection_spec = deepcopy(get_submodules(language_layer_submodules.mlp))
+        assert isinstance(vision_projection_spec, MLPSubmodules)
 
         # Set vision model type
         vision_config.vision_model_type = "clip"
@@ -339,7 +342,9 @@ def setup_method(self, method):
         # Create LLaVA model with both encoder and decoder
         self.llava_model = LLaVAModel(
             language_transformer_config=language_config,
-            language_transformer_layer_spec=language_layer_spec,
+            language_transformer_layer_spec=ModuleSpec(
+                module=TransformerLayer, submodules=language_layer_submodules
+            ),
             language_vocab_size=8192,
             language_max_sequence_length=4096,
             vision_transformer_config=vision_config,
diff --git a/tests/unit_tests/transformer/test_mlp.py b/tests/unit_tests/transformer/test_mlp.py
index 45e0df9ad0a..46b22a1bad8 100644
--- a/tests/unit_tests/transformer/test_mlp.py
+++ b/tests/unit_tests/transformer/test_mlp.py
@@ -6,7 +6,8 @@
 
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_submodules
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
-from megatron.core.transformer.mlp import MLP
+from megatron.core.transformer.mlp import MLP, MLPSubmodules
+from megatron.core.transformer.spec_utils import get_submodules
 from megatron.core.transformer.transformer_config import TransformerConfig
 from tests.unit_tests.test_utilities import Utils
 
@@ -19,7 +20,9 @@ def setup_method(self, method):
         transformer_config = TransformerConfig(
             num_layers=2, hidden_size=12, num_attention_heads=4, use_cpu_initialization=True
         )
-        self.mlp = MLP(transformer_config, get_gpt_layer_local_submodules().mlp.submodules)
+        mlp_submodules = get_submodules(get_gpt_layer_local_submodules().mlp)
+        assert isinstance(mlp_submodules, MLPSubmodules)
+        self.mlp = MLP(transformer_config, mlp_submodules)
 
     def teardown_method(self, method):
         Utils.destroy_model_parallel()
diff --git a/tests/unit_tests/transformer/test_transformer_block_custom_pgs.py b/tests/unit_tests/transformer/test_transformer_block_custom_pgs.py
index c55babe35ca..c415b6dda26 100644
--- a/tests/unit_tests/transformer/test_transformer_block_custom_pgs.py
+++ b/tests/unit_tests/transformer/test_transformer_block_custom_pgs.py
@@ -1,5 +1,5 @@
 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
-
+from __future__ import annotations
 import copy
 import os
 
@@ -61,20 +61,18 @@ def __init__(
         config: TransformerConfig,
         submodules: TransformerLayerSubmodules,
        layer_number: int = 1,
-        hidden_dropout: Optional[float] = None,
-        pg_collection: ProcessGroupCollection = None,
-        vp_stage: Optional[int] = None,
+        hidden_dropout: float | None = None,
+        pg_collection: ProcessGroupCollection | None = None,
+        vp_stage: int | None = None,
     ):
-        # Temporarily replace attention and MLP with IdentityOp,
+        # Temporarily replace attention with IdentityOp,
         # This is a temporary workaround for the test until we have a better interface
         # will rebuild them with custom process groups after super init
         def _modify_submodules(submodules: TransformerLayerSubmodules):
             submodules.self_attention = IdentityOp
-            submodules.mlp = IdentityOp
             return submodules
 
         original_attention = submodules.self_attention
-        original_mlp = submodules.mlp
         new_submodules = _modify_submodules(copy.copy(submodules))
 
         super().__init__(
@@ -92,10 +90,6 @@ def _modify_submodules(submodules: TransformerLayerSubmodules):
         self.self_attention = build_module(
             original_attention, config=self.config, layer_number=layer_number
         )
-        assert (
-            'tp_group' in submodules.mlp.params
-        ), "tp_group should be in the params of the submodules"
-        self.mlp = build_module(original_mlp, config=self.config)
 
 
 def create_reference_mlp(hidden_size, ffn_hidden_size, seed=12345):
@@ -168,7 +162,22 @@ def copy_weights_to_tp_mlp(ref_mlp, tp_mlp, tp_group):
         tp_mlp.linear_fc2.bias.copy_(ref_fc2.bias.to(tp_mlp.linear_fc2.bias.device))
 
 
-def _gpt_te_layer_spec_with_hetro_pgs(attn_pg_collection, mlp_pg_collection):
+def _gpt_te_layer_spec_with_hetro_pgs(
+    attn_pg_collection, mlp_pg_collection: ProcessGroupCollection
+):
+
+    def build_mlp(
+        config: TransformerConfig, pg_collection: ProcessGroupCollection, is_mtp_layer: bool
+    ):
+        del pg_collection, is_mtp_layer
+        return MLP(
+            config,
+            submodules=MLPSubmodules(
+                linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
+            ),
+            tp_group=mlp_pg_collection.tp,
+        )
+
     return ModuleSpec(
         module=HeterogenousTransformerLayer,
         submodules=TransformerLayerSubmodules(
@@ -183,13 +192,7 @@ def _gpt_te_layer_spec_with_hetro_pgs(attn_pg_collection, mlp_pg_collection):
             ),
             self_attn_bda=get_bias_dropout_add,
             pre_mlp_layernorm=IdentityOp,
-            mlp=ModuleSpec(
-                module=MLP,
-                params={'tp_group': mlp_pg_collection.tp},
-                submodules=MLPSubmodules(
-                    linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
-                ),
-            ),
+            mlp=build_mlp,
             mlp_bda=get_bias_dropout_add,
         ),
     )
diff --git a/tests/unit_tests/transformer/test_vision_cuda_graphs.py b/tests/unit_tests/transformer/test_vision_cuda_graphs.py
index bfd431e67a3..30aea20be4b 100644
--- a/tests/unit_tests/transformer/test_vision_cuda_graphs.py
+++ b/tests/unit_tests/transformer/test_vision_cuda_graphs.py
@@ -1,7 +1,6 @@
 # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 
 import gc
-import os
 from copy import deepcopy
 from types import SimpleNamespace
 from unittest.mock import MagicMock
@@ -10,7 +9,10 @@
 import torch
 
 from megatron.core import parallel_state
-from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
+from megatron.core.models.gpt.gpt_layer_specs import (
+    get_gpt_layer_with_transformer_engine_spec,
+    get_gpt_layer_with_transformer_engine_submodules,
+)
 from megatron.core.models.vision.vit_layer_specs import get_vit_layer_with_transformer_engine_spec
 from megatron.core.tensor_parallel.random import (
     HAVE_TE,
@@ -23,9 +25,11 @@
     _layer_is_graphable,
     _wrap_graph_for_vision,
     get_vision_cuda_graph_seq_length,
-    set_current_microbatch,
 )
+from megatron.core.transformer.mlp import MLPSubmodules
+from megatron.core.transformer.spec_utils import ModuleSpec, get_submodules
 from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.transformer.transformer_layer import TransformerLayer
 from megatron.core.utils import is_te_min_version
 from tests.unit_tests.test_utilities import Utils
 
@@ -236,16 +240,19 @@ def setup_method(self, method):
             pipeline_dtype=torch.bfloat16,
         )
 
-        language_layer_spec = get_gpt_layer_with_transformer_engine_spec()
+        language_layer_submodules = get_gpt_layer_with_transformer_engine_submodules()
         vision_layer_spec = get_vit_layer_with_transformer_engine_spec()
-        vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules)
+        vision_projection_spec = deepcopy(get_submodules(language_layer_submodules.mlp))
+        assert isinstance(vision_projection_spec, MLPSubmodules)
 
         self.vision_config.vision_model_type = "clip"
         language_config.language_model_type = "dummy"
 
         self.llava_model = LLaVAModel(
             language_transformer_config=language_config,
-            language_transformer_layer_spec=language_layer_spec,
+            language_transformer_layer_spec=ModuleSpec(
+                module=TransformerLayer, submodules=language_layer_submodules
+            ),
             language_vocab_size=8192,
             language_max_sequence_length=4096,
             vision_transformer_config=self.vision_config,
@@ -494,9 +501,10 @@ def setup_method(self, method):
             pipeline_dtype=torch.bfloat16,
         )
 
-        language_layer_spec = get_gpt_layer_with_transformer_engine_spec()
+        language_layer_submodules = get_gpt_layer_with_transformer_engine_submodules()
         vision_layer_spec = get_vit_layer_with_transformer_engine_spec()
-        vision_projection_spec = deepcopy(language_layer_spec.submodules.mlp.submodules)
+        vision_projection_spec = deepcopy(get_submodules(language_layer_submodules.mlp))
+        assert isinstance(vision_projection_spec, MLPSubmodules)
 
         self.vision_config.vision_model_type = "clip"
         language_config.language_model_type = "dummy"
@@ -504,7 +512,9 @@ def setup_method(self, method):
         self.is_first_stage = is_first_stage
         self.llava_model = LLaVAModel(
             language_transformer_config=language_config,
-            language_transformer_layer_spec=language_layer_spec,
+            language_transformer_layer_spec=ModuleSpec(
+                module=TransformerLayer, submodules=language_layer_submodules
+            ),
             language_vocab_size=8192,
             language_max_sequence_length=4096,
             vision_transformer_config=self.vision_config,