gpt_layer_modelopt_spec.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ApexGuardDefaults is NeMo's stand-in class used when optional dependencies
# are missing; the fallback branch below needs it in scope.
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults

try:
    from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
    from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
    from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
    from megatron.core.transformer.custom_layers.transformer_engine import TENorm
    from megatron.core.transformer.dot_product_attention import DotProductAttention
    from megatron.core.transformer.enums import AttnMaskType
    from megatron.core.transformer.identity_op import IdentityOp
    from megatron.core.transformer.mlp import MLP, MLPSubmodules
    from megatron.core.transformer.spec_utils import ModuleSpec
    from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

    HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError) as e:
    # megatron.core is unavailable: substitute ApexGuardDefaults placeholders
    # so this module can still be imported, and keep the error for later.
    TransformerLayer = TransformerLayerSubmodules = ApexGuardDefaults
    MLP = MLPSubmodules = ModuleSpec = IdentityOp = ApexGuardDefaults
    AttnMaskType = DotProductAttention = TENorm = ApexGuardDefaults
    ColumnParallelLinear = RowParallelLinear = SelfAttention = SelfAttentionSubmodules = ApexGuardDefaults

    HAVE_MEGATRON_CORE = False
    IMPORT_ERROR = e


# Use this spec for Model Optimizer PTQ and TensorRT-LLM export
def get_gpt_layer_modelopt_spec() -> ModuleSpec:
    """Mix the native spec with TENorm.

    This is essentially the native local spec, except that the layernorm
    implementation uses TENorm from Transformer Engine. TENorm supports both
    FusedLayerNorm and RMSNorm and avoids the Apex dependency.
    """
    if not HAVE_MEGATRON_CORE:
        raise Exception(IMPORT_ERROR)
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=TENorm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=DotProductAttention,
                    linear_proj=RowParallelLinear,
                    q_layernorm=IdentityOp,
                    k_layernorm=IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=TENorm,
            mlp=ModuleSpec(
                module=MLP,
                submodules=MLPSubmodules(
                    linear_fc1=ColumnParallelLinear,
                    linear_fc2=RowParallelLinear,
                ),
            ),
            mlp_bda=get_bias_dropout_add,
            # Map TE-layernorm-fusion keys back
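            # so that checkpoints stay interchangeable with the default
            # Transformer Engine spec, where the layernorm is fused into the
            # following linear layer (hence the `layer_norm_` target prefix).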
            sharded_state_dict_keys_map={
                'input_layernorm.': 'self_attention.linear_fc1.layer_norm_'
                if False
                else 'self_attention.linear_qkv.layer_norm_',
                'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
            },
        ),
    )
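

if __name__ == "__main__":
    # Usage sketch (illustrative, not part of the original file): build one
    # TransformerLayer from this spec. This assumes torch.distributed and
    # megatron.core's model-parallel state are already initialized, e.g. via
    # parallel_state.initialize_model_parallel(); the TransformerConfig
    # values below are arbitrary example numbers, not taken from this file.
    from megatron.core.transformer.spec_utils import build_module
    from megatron.core.transformer.transformer_config import TransformerConfig

    config = TransformerConfig(
        num_layers=1,  # example values only
        hidden_size=1024,
        num_attention_heads=16,
    )
    # build_module instantiates spec.module (TransformerLayer) and passes the
    # spec's submodules and params through to its constructor.
    layer = build_module(get_gpt_layer_modelopt_spec(), config=config, layer_number=1)
    print(layer)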