18 changes: 10 additions & 8 deletions examples/multimodal/layer_specs.py
@@ -1,4 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from functools import partial

import torch

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
@@ -170,8 +172,8 @@ def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
mlp_layer=ModuleSpec(
module=MLPLayer,
submodules=TransformerLayerSubmodules(
mlp=ModuleSpec(
module=MLP,
mlp=partial(
MLP.as_mlp_submodule,
submodules=MLPSubmodules(
linear_fc1=not_none(TELayerNormColumnParallelLinear),
linear_fc2=not_none(TERowParallelLinear),
@@ -184,20 +186,20 @@ def get_mamba_layer_spec_te(padding=False) -> ModuleSpec:
)


def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
def get_mlp_module_spec(use_te: bool = True):
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
return partial(
MLP.as_mlp_submodule,
submodules=MLPSubmodules(
linear_fc1=not_none(TEColumnParallelLinear) if use_te else ColumnParallelLinear,
linear_fc2=not_none(TERowParallelLinear) if use_te else RowParallelLinear,
),
)


def get_norm_mlp_module_spec_te() -> ModuleSpec:
return ModuleSpec(
module=MLP,
def get_norm_mlp_module_spec_te():
return partial(
MLP.as_mlp_submodule,
submodules=MLPSubmodules(
linear_fc1=not_none(TELayerNormColumnParallelLinear),
linear_fc2=not_none(TERowParallelLinear),
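Across these layer specs, the recurring change is that the `mlp` entry of `TransformerLayerSubmodules` switches from a declarative `ModuleSpec(module=MLP, ...)` to a `functools.partial` over `MLP.as_mlp_submodule`, i.e. from a spec the layer instantiates to a builder callable the layer invokes. A minimal before/after sketch, assuming Megatron-LM is importable and that `MLP.as_mlp_submodule` accepts the same keyword arguments as the `TEFusedMLP.as_mlp_submodule` classmethod added later in this diff:

```python
# Sketch only, not part of the diff. Uses the local (non-TE) linear layers for brevity;
# MLP.as_mlp_submodule is assumed to mirror the TEFusedMLP.as_mlp_submodule signature below.
from functools import partial

from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec

submodules = MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear)

# Before: a ModuleSpec that the transformer layer instantiates itself.
mlp_spec = ModuleSpec(module=MLP, submodules=submodules)

# After: a builder; the layer supplies config, process groups, etc. when it calls it.
mlp_builder = partial(MLP.as_mlp_submodule, submodules=submodules)
```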
6 changes: 3 additions & 3 deletions examples/multimodal/nvlm/internvit.py
@@ -160,10 +160,10 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata={}):
return super().sharded_state_dict(prefix, sharded_offsets, metadata)


def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
def get_mlp_module_spec(use_te: bool = True):
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
return partial(
MLP.as_mlp_submodule,
submodules=MLPSubmodules(
linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
16 changes: 7 additions & 9 deletions examples/multimodal/radio/radio_g.py
@@ -1,8 +1,6 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from functools import partial

import torch

from examples.multimodal.layer_scaling import (
LayerScalingTransformerLayer,
get_bias_dropout_add_layer_scaling,
@@ -14,7 +12,7 @@
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.transformer.transformer_layer import TransformerLayerSubmodules
from megatron.core.typed_torch import not_none
from megatron.core.extensions.transformer_engine import HAVE_TE

@@ -51,20 +49,20 @@
LNImpl = WrappedTorchNorm


def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
def get_mlp_module_spec(use_te: bool = True):
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
return partial(
MLP.as_mlp_submodule,
submodules=MLPSubmodules(
linear_fc1=not_none(TEColumnParallelLinear) if use_te else ColumnParallelLinear,
linear_fc2=not_none(TERowParallelLinear) if use_te else RowParallelLinear,
),
)


def get_norm_mlp_module_spec_te() -> ModuleSpec:
return ModuleSpec(
module=MLP,
def get_norm_mlp_module_spec_te():
return partial(
MLP.as_mlp_submodule,
submodules=MLPSubmodules(
linear_fc1=not_none(TELayerNormColumnParallelLinear),
linear_fc2=not_none(TERowParallelLinear),
27 changes: 26 additions & 1 deletion megatron/core/extensions/transformer_engine.py
@@ -43,7 +43,7 @@
)
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.torch_norm import LayerNormInterface
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import (
@@ -2428,6 +2428,31 @@ def forward(self, hidden_states: torch.Tensor, **kwargs) -> Tuple[Tensor, Option

return out, bias

@classmethod
def as_mlp_submodule(
cls,
submodules: MLPSubmodules,
config: TransformerConfig,
pg_collection: ProcessGroupCollection,
is_mtp_layer: bool,
is_expert: bool = False,
input_size: int | None = None,
ffn_hidden_size: int | None = None,
) -> MLP:
"""Helper function to build an MLP as a TransformerLayer's mlp submodule."""
del is_mtp_layer
assert hasattr(
pg_collection, 'tp'
), 'TP process group is required for TEFusedMLP in TransformerLayer'
return cls(
config=config,
submodules=submodules,
tp_group=pg_collection.tp,
is_expert=is_expert,
input_size=input_size,
ffn_hidden_size=ffn_hidden_size,
)

else:
TEFusedMLP = None # type: ignore[assignment, misc]

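The new `as_mlp_submodule` classmethod is the target of those `partial(...)` builders: the spec freezes only the submodule wiring, and the transformer layer later supplies the runtime context (`config`, `pg_collection`, `is_mtp_layer`). A toy, self-contained sketch of that hand-off (all names below are stand-ins, not Megatron APIs):

```python
# Toy stand-in illustrating the builder pattern this hunk introduces; the keyword
# arguments mirror the as_mlp_submodule signature above, everything else is made up.
from dataclasses import dataclass
from functools import partial


@dataclass
class ToySubmodules:
    linear_fc1: str
    linear_fc2: str


class ToyMLP:
    def __init__(self, config, submodules, tp_group):
        self.config, self.submodules, self.tp_group = config, submodules, tp_group

    @classmethod
    def as_mlp_submodule(cls, submodules, config, pg_collection, is_mtp_layer):
        # Drop arguments the class does not need, translate the rest into constructor kwargs.
        del is_mtp_layer
        return cls(config=config, submodules=submodules, tp_group=pg_collection["tp"])


# The layer spec stores only the submodule wiring...
build_mlp = partial(ToyMLP.as_mlp_submodule, submodules=ToySubmodules("fc1", "fc2"))

# ...and the layer passes config and process groups when it builds the module.
mlp = build_mlp(config={"hidden_size": 8}, pg_collection={"tp": "tp-group"}, is_mtp_layer=False)
print(type(mlp).__name__, mlp.tp_group)  # -> ToyMLP tp-group
```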
18 changes: 10 additions & 8 deletions megatron/core/models/T5/t5_spec.py
@@ -1,4 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from functools import partial

from megatron.core.extensions.transformer_engine import HAVE_TE
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
@@ -69,8 +71,8 @@ def encoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
),
),
self_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
mlp=partial(
MLP.as_mlp_submodule,
submodules=MLPSubmodules(
linear_fc1=not_none(TELayerNormColumnParallelLinear),
linear_fc2=not_none(TERowParallelLinear),
@@ -111,8 +113,8 @@ def decoder_model_with_transformer_engine_default_spec() -> ModuleSpec:
),
),
cross_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
mlp=partial(
MLP.as_mlp_submodule,
submodules=MLPSubmodules(
linear_fc1=not_none(TELayerNormColumnParallelLinear),
linear_fc2=not_none(TERowParallelLinear),
@@ -143,8 +145,8 @@ def encoder_model_with_local_spec() -> ModuleSpec:
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=ModuleSpec(
module=MLP,
mlp=partial(
MLP.as_mlp_submodule,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
),
@@ -190,8 +192,8 @@ def decoder_model_with_local_spec() -> ModuleSpec:
),
cross_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=ModuleSpec(
module=MLP,
mlp=partial(
MLP.as_mlp_submodule,
submodules=MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
),
9 changes: 5 additions & 4 deletions megatron/core/models/bert/bert_layer_specs.py
@@ -1,5 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import warnings
from functools import partial

from megatron.core.extensions.transformer_engine import HAVE_TE
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
@@ -66,8 +67,8 @@ def get_bert_layer_with_transformer_engine_submodules() -> TransformerLayerSubmo
),
),
self_attn_bda=get_bias_dropout_add,
mlp=ModuleSpec(
module=MLP,
mlp=partial(
MLP.as_mlp_submodule,
submodules=MLPSubmodules(
linear_fc1=not_none(TELayerNormColumnParallelLinear),
linear_fc2=not_none(TERowParallelLinear),
@@ -117,8 +118,8 @@ def __getattr__(name):
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=ModuleSpec(
module=MLP,
mlp=partial(
MLP.as_mlp_submodule,
submodules=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear),
),
mlp_bda=get_bias_dropout_add,
@@ -24,10 +24,12 @@
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import (
MlpBuilder,
TransformerLayer,
TransformerLayerSubmodules,
get_transformer_layer_offset,
)
from megatron.core.typed_torch import not_none

try:
import transformer_engine as te # type: ignore[import-untyped] # pylint: disable=unused-import
@@ -213,14 +215,18 @@ def get_transformer_block_with_experimental_attention_variant_spec(
moe_layer_pattern = [0] * config.num_layers

if 1 in moe_layer_pattern:
moe_layer_spec = _get_moe_module_spec(config=config, backend=backend)
moe_layer_spec, fuse_layernorm_pre_moe = _get_moe_module_spec(
config=config, backend=backend
)
else:
moe_layer_spec = None
moe_layer_spec, fuse_layernorm_pre_moe = None, False

if 0 in moe_layer_pattern:
dense_mlp_layer_spec = _get_dense_mlp_module_spec(config=config, backend=backend)
dense_mlp_layer_spec, fuse_layernorm_pre_dense = _get_dense_mlp_module_spec(
config=config, backend=backend
)
else:
dense_mlp_layer_spec = None
dense_mlp_layer_spec, fuse_layernorm_pre_dense = None, False

# Get GPT decoder block layer specs
rms_norm = config.normalization == "RMSNorm"
@@ -232,14 +238,19 @@
else standard_attention_spec
)
mlp = moe_layer_spec if moe_layer_pattern[layer_number] == 1 else dense_mlp_layer_spec
fuse_pre_mlp_layernorm = (
fuse_layernorm_pre_moe
if moe_layer_pattern[layer_number] == 1
else fuse_layernorm_pre_dense
)
input_layernorm = (
IdentityOp
if attention.metainfo["fuse_input_layernorm"]
else backend.layer_norm(rms_norm=rms_norm, for_qk=False)
)
pre_mlp_layernorm = (
IdentityOp
if mlp.metainfo["fuse_pre_mlp_layernorm"]
if fuse_pre_mlp_layernorm
else backend.layer_norm(rms_norm=rms_norm, for_qk=False)
)

@@ -251,7 +262,7 @@
self_attention=attention,
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=pre_mlp_layernorm,
mlp=mlp,
mlp=not_none(mlp),
mlp_bda=get_bias_dropout_add,
),
)
@@ -410,41 +421,50 @@ def _get_self_attention_module_spec(

def _get_dense_mlp_module_spec(
config: TransformerConfig, backend: BackendSpecProvider = None
) -> ModuleSpec:
) -> tuple[MlpBuilder, bool]:
"""Get dense MLP module spec.
For hybrid models that mix dense MLP and experimental attention architectures.

Warning: This function may be deprecated in the future."""
Warning: This function may be deprecated in the future.

Returns:
A tuple of (MLP module spec, whether to fuse pre-MLP layernorm)
"""

if backend is None:
backend = _get_backend_spec_provider(config=config)

from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec_for_backend

mlp_spec = get_mlp_module_spec_for_backend(backend=backend, num_experts=None)
mlp_spec.metainfo["fuse_pre_mlp_layernorm"] = backend.fuse_layernorm_and_linear()

return mlp_spec
return (
get_mlp_module_spec_for_backend(backend=backend, num_experts=None),
backend.fuse_layernorm_and_linear(),
)


def _get_moe_module_spec(
config: TransformerConfig, backend: BackendSpecProvider = None
) -> ModuleSpec:
) -> tuple[MlpBuilder, bool]:
"""Get MoE module spec.
For hybrid models that mix MoE and experimental attention architectures.

Warning: This function may be deprecated in the future."""
Warning: This function may be deprecated in the future.

Returns:
A tuple of (MoE module spec, whether to fuse pre-MoE layernorm)
"""

if backend is None:
backend = _get_backend_spec_provider(config=config)

from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec_for_backend

moe_spec = get_moe_module_spec_for_backend(
backend=backend,
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
use_te_activation_func=config.use_te_activation_func,
return (
get_moe_module_spec_for_backend(
backend=backend,
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
use_te_activation_func=config.use_te_activation_func,
),
False,
)
moe_spec.metainfo["fuse_pre_mlp_layernorm"] = False
return moe_spec
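Here `_get_dense_mlp_module_spec` and `_get_moe_module_spec` stop attaching the fuse flag to `ModuleSpec.metainfo` and instead return a `(builder, fuse_pre_mlp_layernorm)` tuple, which fits the new callable-builder style since a `functools.partial` has no `metainfo` dict to hang the flag on. A toy sketch of the caller-side pattern (names are made up; only the tuple shape mirrors the diff):

```python
# Toy sketch of the (builder, fuse_pre_mlp_layernorm) return shape; not Megatron code.
from functools import partial


def make_dense_mlp(submodules, config):
    # Stand-in for the real MLP builder; just records its inputs.
    return ("MLP", submodules, config)


def get_dense_mlp_module_spec_sketch(fuse_layernorm_and_linear):
    builder = partial(make_dense_mlp, submodules=("fc1", "fc2"))
    return builder, fuse_layernorm_and_linear  # (builder, fuse_pre_mlp_layernorm)


builder, fuse_pre_mlp_layernorm = get_dense_mlp_module_spec_sketch(True)
pre_mlp_layernorm = "IdentityOp" if fuse_pre_mlp_layernorm else "LayerNorm"
mlp = builder(config={"ffn_hidden_size": 32})
print(pre_mlp_layernorm, mlp[0])  # -> IdentityOp MLP
```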