Add sigmoid GLU #2656
Conversation
Greptile Overview
Greptile Summary
This PR adds the original sigmoid-gated GLU activation across the Transformer Engine stack.
No additional must-fix issues were found beyond items already discussed in existing PR threads; the changes appear internally consistent (shape doubling for gated activations, TE vs. PyTorch GLU convention warnings, and binding wiring).
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User as User Code
participant LNMLP as LayerNormMLP/TransformerLayer
participant Ops as te_ops.GLU (fuser)
participant Tex as transformer_engine_torch (pybind)
participant Cpp as pytorch::glu/dglu wrappers
participant Common as nvte_glu/nvte_dglu (common)
participant CUDA as gated_act_fn/dgated_act_fn kernels
User->>LNMLP: activation="glu"
LNMLP->>Ops: build activation op (GLU)
Ops->>Tex: tex.glu(input, quantizer)
Tex->>Cpp: transformer_engine::pytorch::glu
Cpp->>Common: nvte_glu(input_tensor, output_tensor)
Common->>CUDA: gated_act_fn(sigmoid gate)
CUDA-->>Common: output
Common-->>Cpp: output
Cpp-->>Tex: output
Tex-->>Ops: output
Ops-->>LNMLP: activated output
Note over Ops,Tex: Backward uses tex.dglu
Ops->>Tex: tex.dglu(grad, fwd_input, quantizer)
Tex->>Cpp: transformer_engine::pytorch::dglu
Cpp->>Common: nvte_dglu(grad, input, grad_input)
Common->>CUDA: dgated_act_fn(sigmoid/dsigmoid)
CUDA-->>Common: grad_input
Common-->>Cpp: grad_input
Cpp-->>Tex: grad_input
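To make the call chain above concrete, here is a minimal usage sketch. It assumes the standard transformer_engine.pytorch.LayerNormMLP constructor arguments (hidden_size, ffn_hidden_size, and an activation keyword, as the PR description suggests); the sizes are illustrative only, not taken from this PR.

```python
import torch
import transformer_engine.pytorch as te

# Hypothetical sizes for illustration only.
hidden_size, ffn_hidden_size, seq_len, batch_size = 1024, 4096, 128, 2

# Select the new activation added by this PR; FC1 output is doubled internally
# so the gated activation can split it into two halves.
mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, activation="glu").cuda()

x = torch.randn(seq_len, batch_size, hidden_size, device="cuda", requires_grad=True)
y = mlp(x)           # forward path: tex.glu -> nvte_glu -> gated_act_fn
y.sum().backward()   # backward path: tex.dglu -> nvte_dglu -> dgated_act_fn
```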
5 files reviewed, 3 comments
Additional Comments (3)
Signed-off-by: Kim, Jin <jinn.kim@sk.com>
for more information, see https://pre-commit.ci Signed-off-by: Kim, Jin <jinn.kim@sk.com>
Force-pushed from 8adb19a to f7c6385
5 files reviewed, no comments
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
LGTM, pending CI
If we're adding GLU, it would also be natural to have sigmoid. Not critical though.
11 files reviewed, 4 comments
Additional Comments (3)
Fix: add explicit …
The pybind enum export for … Fix: add …
Fix: add the appropriate key (consistent with the existing conventions, e.g. how GEGLU/REGLU/SWIGLU are represented) mapping to …
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
3 files reviewed, 2 comments
Additional Comments (1)
elif activation == "sigmoid":
    y_ref = torch.nn.functional.sigmoid(x_ref)
Sigmoid is not an option in the test; is this leftover code?
It's not that important, but I find it awkward to have GeLU/GeGLU, ReLU/ReGLU, SiLU/SwiGLU, and then just GLU.
Could we also add a test for the regular LayerNormMLP module in test_numerics.py?
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
11 files reviewed, 2 comments
Additional Comments (2)
Also update the JAX activation dispatch switch to handle …
The activation dispatch … Add a …
/te-ci
3 files reviewed, 1 comment
Additional Comments (1)
Signed-off-by: Kim, Jin <jinn.kim@sk.com>
Signed-off-by: Kim, Jin <jinn.kim@sk.com>
5 files reviewed, no comments
Done. The latest commits address both points:
Tests: …
JAX support: …
/te-ci
16 files reviewed, 1 comment
elif activation == "sigmoid":
    y_ref = torch.nn.functional.sigmoid(x_ref)
Unreachable sigmoid branch
test_activation never parametrizes activation="sigmoid" (see the @pytest.mark.parametrize("activation", ...) list above), but this new elif activation == "sigmoid": branch was added anyway, so it can’t execute. This makes the test misleading/fragile (it looks like sigmoid is covered when it isn’t). Either add "sigmoid" to the parametrization or remove this branch.
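For illustration, a sketch of the parametrization fix suggested here; the activation names below are placeholders, not the actual list in the test file.

```python
import pytest

# Placeholder list; the real test's parametrization differs. Including "sigmoid"
# (or removing the branch) resolves the unreachable-branch concern above.
@pytest.mark.parametrize("activation", ("gelu", "relu", "glu", "sigmoid"))
def test_activation(activation):
    ...
```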
@singleheart The CI failed because you did not add a reference computation for the GLU implementation to test_numerics.py. Could you fix that?
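A hedged sketch of what such a reference branch might look like; the helper name and structure are hypothetical, not the actual test_numerics.py code, and depending on TE's gating convention the roles of the two halves may be swapped.

```python
import torch

def reference_activation(fc1_out: torch.Tensor, activation: str) -> torch.Tensor:
    """Hypothetical non-TE reference activation used to validate numerics."""
    if activation == "glu":
        # Gated activation: FC1 output has doubled width; split and gate.
        a, b = fc1_out.chunk(2, dim=-1)
        return a * torch.sigmoid(b)
    if activation == "gelu":
        return torch.nn.functional.gelu(fc1_out)
    raise ValueError(f"Unsupported activation: {activation}")
```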
…IDIA#2656) Signed-off-by: Kim, Jin <jinn.kim@sk.com>
@ptrendx Thanks for catching this! I've added the GLU reference computation to test_numerics.py.
16 files reviewed, no comments
/te-ci
ptrendx left a comment
LGTM
Description
Add the original GLU (Gated Linear Unit) activation function as described in
Dauphin et al. (2017) and referenced in
Shazeer (2020), "GLU Variants Improve Transformer".
GLU is defined as:
$$\mathrm{GLU}(x) = a \otimes \sigma(b)$$
where $\sigma$ is the sigmoid function and the input is split into two halves $a$ and $b$ along the last dimension.
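For intuition, a minimal PyTorch sketch of this definition, following the torch.nn.functional.glu convention of gating the first half with the sigmoid of the second; note that TE's internal gated-activation convention may order the halves differently, as discussed in the review comments above.

```python
import torch

def glu_reference(x: torch.Tensor) -> torch.Tensor:
    """Sigmoid-gated GLU: split the last dimension in half and gate the first
    half with the sigmoid of the second (torch.nn.functional.glu convention)."""
    a, b = x.chunk(2, dim=-1)
    return a * torch.sigmoid(b)

x = torch.randn(4, 8)
# For this convention the result matches PyTorch's built-in GLU.
assert torch.allclose(glu_reference(x), torch.nn.functional.glu(x, dim=-1))
```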
Transformer Engine already supports several GLU variants (GEGLU, ReGLU, SReGLU, SwiGLU, etc.)
but was missing the original sigmoid-gated GLU. This PR fills that gap so that users can
simply pass
activation="glu"toLayerNormMLPorTransformerLayer.Type of change
Changes
transformer_engine/common/activation/glu.cu(new file): CUDA kernelsnvte_gluandnvte_dgluusing existingsigmoid/dsigmoidprimitives frommath.hand thegated_act_fn/dgated_act_fntemplates.transformer_engine/common/include/transformer_engine/activation.h: AddedGLUtoNVTE_Activation_Typeenum; declarednvte_gluandnvte_dgluwith doxygen documentation.transformer_engine/common/CMakeLists.txt: Registeredactivation/glu.cuin botharch_specific_sourcesandfast_mathbuild lists.transformer_engine/pytorch/csrc/extensions/activation.cpp: Addedglu()anddglu()C++ wrapper functions.transformer_engine/pytorch/csrc/extensions.h: Declaredgluanddglu.transformer_engine/pytorch/csrc/extensions/pybind.cpp: Exposedtex.gluandtex.dgluto Python.transformer_engine/pytorch/module/layernorm_mlp.py: Added"glu"to_get_act_func_supported_list(all 3 recipe branches), FC1 output-doubling condition, ONNX exportactivation_map, and docstring.transformer_engine/pytorch/ops/basic/activation.py: AddedGLUoperation class with forward (tex.glu) and backward (tex.dglu).transformer_engine/pytorch/ops/basic/__init__.py: ExportedGLU.transformer_engine/pytorch/transformer.py: UpdatedTransformerLayerdocstring to list'glu'as a supported activation.Checklist:
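As a supplement to the `glu.cu` item above, here is a hedged sketch of the backward math the dglu path computes, written against the torch.nn.functional.glu convention out = a * sigmoid(b). It only checks the analytic gradient against autograd and does not reflect the actual kernel code or TE's internal half ordering.

```python
import torch

def dglu_reference(grad_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Analytic gradient of out = a * sigmoid(b) w.r.t. the concatenated input x = [a, b]."""
    a, b = x.chunk(2, dim=-1)
    s = torch.sigmoid(b)
    da = grad_out * s                    # d(out)/da = sigmoid(b)
    db = grad_out * a * s * (1.0 - s)    # d(out)/db = a * sigmoid(b) * (1 - sigmoid(b))
    return torch.cat([da, db], dim=-1)

x = torch.randn(4, 8, requires_grad=True)
out = torch.nn.functional.glu(x, dim=-1)
grad_out = torch.randn_like(out)
out.backward(grad_out)
# The analytic gradient matches autograd for this convention.
assert torch.allclose(dglu_reference(grad_out, x.detach()), x.grad, atol=1e-6)
```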