
Conversation

singleheart commented Feb 6, 2026

Description

Add the original GLU (Gated Linear Unit) activation function as described in
Dauphin et al. (2017) and referenced in
Shazeer (2020), "GLU Variants Improve Transformer".

GLU is defined as:

$$\text{GLU}(a, b) = \sigma(a) \odot b$$

where $\sigma$ is the sigmoid function and the input is split into two halves $a$ and $b$ along the last dimension.

Transformer Engine already supports several GLU variants (GEGLU, ReGLU, SReGLU, SwiGLU, etc.)
but was missing the original sigmoid-gated GLU. This PR fills that gap so that users can
simply pass activation="glu" to LayerNormMLP or TransformerLayer.
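
For reference, here is a minimal PyTorch sketch of the computation above and of the intended usage. It is illustrative only: it does not call TE kernels, and the LayerNormMLP constructor arguments shown are assumptions, not taken from this PR.

    import torch

    def glu_reference(x: torch.Tensor) -> torch.Tensor:
        # GLU: split the last dimension into (a, b) and return sigmoid(a) * b
        a, b = x.chunk(2, dim=-1)
        return torch.sigmoid(a) * b

    y = glu_reference(torch.randn(2, 8))  # last dim must be even; output width is halved

    # Hypothetical usage once this PR is merged (argument names are assumptions):
    # import transformer_engine.pytorch as te
    # mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096, activation="glu")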

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

  • transformer_engine/common/activation/glu.cu (new file): CUDA kernels nvte_glu and nvte_dglu using existing sigmoid/dsigmoid primitives from math.h and the gated_act_fn/dgated_act_fn templates.
  • transformer_engine/common/include/transformer_engine/activation.h: Added GLU to NVTE_Activation_Type enum; declared nvte_glu and nvte_dglu with doxygen documentation.
  • transformer_engine/common/CMakeLists.txt: Registered activation/glu.cu in both arch_specific_sources and fast_math build lists.
  • transformer_engine/pytorch/csrc/extensions/activation.cpp: Added glu() and dglu() C++ wrapper functions.
  • transformer_engine/pytorch/csrc/extensions.h: Declared glu and dglu.
  • transformer_engine/pytorch/csrc/extensions/pybind.cpp: Exposed tex.glu and tex.dglu to Python.
  • transformer_engine/pytorch/module/layernorm_mlp.py: Added "glu" to _get_act_func_supported_list (all 3 recipe branches), the FC1 output-doubling condition (illustrated in the sketch after this list), the ONNX export activation_map, and the docstring.
  • transformer_engine/pytorch/ops/basic/activation.py: Added GLU operation class with forward (tex.glu) and backward (tex.dglu).
  • transformer_engine/pytorch/ops/basic/__init__.py: Exported GLU.
  • transformer_engine/pytorch/transformer.py: Updated TransformerLayer docstring to list 'glu' as a supported activation.
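
The FC1 output-doubling condition mentioned for layernorm_mlp.py follows the usual gated-activation pattern: FC1 must produce twice the FFN hidden size because the activation consumes two halves and emits one. A rough sketch of that sizing rule; the set and function names are illustrative, not the module's actual code:

    GATED_ACTIVATIONS = {"geglu", "reglu", "sreglu", "swiglu", "qgeglu", "glu"}

    def fc1_output_features(ffn_hidden_size: int, activation: str) -> int:
        # Gated activations consume two halves and emit one, so FC1's output is doubled.
        return 2 * ffn_hidden_size if activation in GATED_ACTIVATIONS else ffn_hidden_size

    assert fc1_output_features(4096, "glu") == 8192   # gated: doubled
    assert fc1_output_features(4096, "gelu") == 4096  # elementwise: unchanged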

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

greptile-apps bot (Contributor) commented Feb 6, 2026

Greptile Overview

Greptile Summary

This PR adds the original sigmoid-gated GLU activation across the Transformer Engine stack:

  • Common/CUDA: introduces nvte_glu / nvte_dglu entrypoints implemented via existing gated-activation templates with sigmoid/dsigmoid and wires the new activation/glu.cu into the CMake source lists.
  • C API: extends NVTE_Activation_Type with GLU and exposes the corresponding function declarations.
  • PyTorch: adds C++ extension wrappers, pybind exports (tex.glu, tex.dglu), a new te_ops.GLU fusible op, and integrates "glu" into LayerNormMLP (supported activation list, FC1 output sizing for gated activations, and ONNX export reference).
  • JAX: maps the tuple ("sigmoid", "linear") to NVTE_Activation_Type.GLU and adds dispatch for forward/backward to the new common entrypoints.
  • Tests: extends activation coverage to include "glu" in numerics and sanity suites and adds a fusible-op test case.

No additional must-fix issues were found beyond items already discussed in existing PR threads; the changes appear internally consistent (shape doubling for gated activations, TE vs PyTorch GLU convention warnings, and binding wiring).

Confidence Score: 4/5

  • This PR looks safe to merge once existing threaded test-reference issues are resolved.
  • Core implementation is a straightforward addition of GLU using existing gated-activation infrastructure, with consistent C/CUDA/JAX/PyTorch wiring and test coverage updates. I did not find additional definite runtime or correctness bugs introduced in this commit beyond the previously noted test-reference issue in the PR threads.
  • tests/pytorch/test_fusible_ops.py

Important Files Changed

  • tests/pytorch/test_fusible_ops.py: Added GLU activation to parametrization and reference computation; the current reference uses reshape/flip/reshape, likely incorrect for an in_shape that is already doubled (see existing thread), and adds an unreachable sigmoid branch (see the convention check after this list).
  • tests/pytorch/test_numerics.py: Added 'glu' to all_activations and mapped it to nn.Sigmoid() in _supported_act; this is incorrect for GLU (which requires gating half the input and multiplying), so the numerics reference will be wrong.
  • transformer_engine/common/activation/glu.cu: Introduced nvte_glu and nvte_dglu using gated_act_fn with sigmoid/dsigmoid primitives; needs verification of template expectations and correct fp32 instantiation.
  • transformer_engine/common/include/transformer_engine/activation.h: Added the GLU enum value and declared nvte_glu/nvte_dglu with docs; needs to ensure enum ordering matches downstream bindings.
  • transformer_engine/jax/cpp_extensions/activation.py: Mapped the ('sigmoid', 'linear') activation tuple to NVTE_Activation_Type.GLU; should confirm this tuple matches the JAX-side API for GLU activation selection.
  • transformer_engine/pytorch/csrc/extensions/activation.cpp: Implemented glu/dglu wrappers using activation_helper/dactivation_helper; glu uses gate factor 2 like other gated activations.
  • transformer_engine/pytorch/module/layernorm_mlp.py: Added 'glu' to supported activations, doubled FC1 output for GLU, ONNX export mapping uses sigmoid(a)*b; looks coherent.
  • transformer_engine/pytorch/ops/basic/activation.py: Added a GLU activation op class delegating to tex.glu/tex.dglu; the doc warns about the TE vs PyTorch convention.
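
As context for the test-reference rows above: TE's new GLU gates with sigmoid on the first half (sigma(a) * b), while torch.nn.functional.glu gates on the second half (a * sigma(b)), so a reference built on F.glu has to swap the halves, which is presumably what the reshape/flip trick in the fusible-ops test is doing. A minimal standalone check of that equivalence (illustrative, not the repo's test code):

    import torch

    x = torch.randn(4, 8)                   # last dim must be even
    a, b = x.chunk(2, dim=-1)

    y_te = torch.sigmoid(a) * b             # sigma(a) * b, the convention in this PR
    # torch.nn.functional.glu computes a * sigma(b), so swap the halves first:
    y_torch = torch.nn.functional.glu(torch.cat([b, a], dim=-1), dim=-1)

    torch.testing.assert_close(y_te, y_torch)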

Sequence Diagram

sequenceDiagram
  participant User as User Code
  participant LNMLP as LayerNormMLP/TransformerLayer
  participant Ops as te_ops.GLU (fuser)
  participant Tex as transformer_engine_torch (pybind)
  participant Cpp as pytorch::glu/dglu wrappers
  participant Common as nvte_glu/nvte_dglu (common)
  participant CUDA as gated_act_fn/dgated_act_fn kernels

  User->>LNMLP: activation="glu"
  LNMLP->>Ops: build activation op (GLU)
  Ops->>Tex: tex.glu(input, quantizer)
  Tex->>Cpp: transformer_engine::pytorch::glu
  Cpp->>Common: nvte_glu(input_tensor, output_tensor)
  Common->>CUDA: gated_act_fn(sigmoid gate)
  CUDA-->>Common: output
  Common-->>Cpp: output
  Cpp-->>Tex: output
  Tex-->>Ops: output
  Ops-->>LNMLP: activated output

  Note over Ops,Tex: Backward uses tex.dglu
  Ops->>Tex: tex.dglu(grad, fwd_input, quantizer)
  Tex->>Cpp: transformer_engine::pytorch::dglu
  Cpp->>Common: nvte_dglu(grad, input, grad_input)
  Common->>CUDA: dgated_act_fn(sigmoid/dsigmoid)
  CUDA-->>Common: grad_input
  Common-->>Cpp: grad_input
  Cpp-->>Tex: grad_input

greptile-apps bot (Contributor) left a comment

5 files reviewed, 3 comments

greptile-apps bot (Contributor) commented Feb 6, 2026

Additional Comments (3)

tests/pytorch/test_sanity.py
"glu" not added to test list - new activation won't be tested

all_activations = [
    "gelu",
    "geglu",
    "glu",
    "qgelu",
    "qgeglu",
    "relu",
    "reglu",
    "srelu",
    "sreglu",
    "silu",
    "swiglu",
    "clamped_swiglu",
]

tests/pytorch/test_numerics.py
"glu" missing from test list

all_activations = [
    "gelu",
    "geglu",
    "glu",
    "qgelu",
    "qgeglu",
    "relu",
    "reglu",
    "srelu",
    "sreglu",
    "silu",
    "swiglu",
]

tests/pytorch/test_fusible_ops.py
"glu" missing from test parameters - add glu to tuple and handle in test logic below (around line 1631)

singleheart and others added 2 commits February 6, 2026 20:15

singleheart force-pushed the feature/add-sigmoid-glu branch from 8adb19a to f7c6385 on February 6, 2026 11:15
greptile-apps bot (Contributor) left a comment

5 files reviewed, no comments

timmoon10 and others added 2 commits February 6, 2026 18:40

A comment from timmoon10 was marked as outdated.

timmoon10 previously approved these changes Feb 6, 2026
timmoon10 (Collaborator) left a comment:
LGTM, pending CI

If we're adding GLU, it would also be natural to have sigmoid. Not critical though.

greptile-apps bot (Contributor) left a comment

11 files reviewed, 4 comments

greptile-apps bot (Contributor) commented Feb 6, 2026

Additional Comments (3)

transformer_engine/jax/csrc/extensions/activation.cpp
Missing GLU dispatch

NVTE_Activation_Type::GLU was added to the shared activation enum, but the JAX FFI dispatch here doesn’t handle it. Calling act_lu/dact_lu with GLU will hit the default branch and raise NVTE_ERROR("Unsupported ActivationEnum").

Fix: add explicit case NVTE_Activation_Type::GLU: branches that call nvte_glu(...) and nvte_dglu(...) (the backward path is in the DActLuDBiasQuantizeFFI switch later in this file as well).


transformer_engine/jax/csrc/extensions/pybind.cpp
GLU not exported to Python

The pybind enum export for NVTE_Activation_Type doesn’t include GLU, so JAX Python can’t reference NVTE_Activation_Type.GLU even though it exists in the C++ header. This will block using GLU from Python-side activation selection.

Fix: add .value("GLU", NVTE_Activation_Type::GLU) to this enum binding.


transformer_engine/jax/cpp_extensions/activation.py
ActivationEnum missing GLU

ActivationEnum doesn’t have an entry mapping any activation tuple to NVTE_Activation_Type.GLU, so GLU can’t be selected through the existing string/tuple activation routing in JAX.

Fix: add the appropriate key (consistent with the existing conventions, e.g. how GEGLU/REGLU/SWIGLU are represented) mapping to NVTE_Activation_Type.GLU.
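
A sketch of the kind of ActivationEnum entry this comment asks for, following the existing tuple convention. The surrounding entries and the stand-in enum below are assumptions for illustration, not the repo's actual code (the real NVTE_Activation_Type comes from the JAX pybind module):

    from enum import Enum

    class NVTE_Activation_Type(Enum):   # stand-in for the pybind enum, for illustration only
        GEGLU = 0
        SWIGLU = 1
        GLU = 2

    ActivationEnum = {
        ("gelu", "linear"): NVTE_Activation_Type.GEGLU,    # existing-style gated entries (assumed)
        ("silu", "linear"): NVTE_Activation_Type.SWIGLU,
        ("sigmoid", "linear"): NVTE_Activation_Type.GLU,   # the new GLU mapping proposed here
    }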

A comment from timmoon10 was marked as outdated.

greptile-apps bot (Contributor) left a comment

3 files reviewed, 2 comments

greptile-apps bot (Contributor) commented Feb 6, 2026

Additional Comments (1)

tests/pytorch/test_fusible_ops.py
glu is missing from the make_op dictionary, causing a KeyError when the test runs

        make_op = dict(
            gelu=te_ops.GELU,
            geglu=te_ops.GEGLU,
            glu=te_ops.GLU,
            qgelu=te_ops.QGELU,
            qgeglu=te_ops.QGEGLU,
            relu=te_ops.ReLU,
            reglu=te_ops.ReGLU,
            srelu=te_ops.SReLU,
            sreglu=te_ops.SReGLU,
            silu=te_ops.SiLU,
            swiglu=te_ops.SwiGLU,
        )[activation]

Comment on lines +1632 to +1633
elif activation == "sigmoid":
y_ref = torch.nn.functional.sigmoid(x_ref)
A project Member commented:
Sigmoid is not an option in the test; is this leftover code?

A Collaborator replied:
It's not that important, but I find it awkward to have GeLU/GeGLU, ReLU/ReGLU, SiLU/SwiGLU, and then just GLU.

ptrendx (Member) commented Feb 7, 2026

Could we also add a test for the regular LayerNormMLP module in test_numerics.py?
Also @jberchtold-nvidia FYI - we will also want to expose it in JAX.

greptile-apps bot (Contributor) left a comment

11 files reviewed, 2 comments

greptile-apps bot (Contributor) commented Feb 7, 2026

Additional Comments (2)

transformer_engine/jax/csrc/extensions/pybind.cpp
Missing GLU enum export

NVTE_Activation_Type gained a new GLU value in transformer_engine/common/include/transformer_engine/activation.h, but the JAX pybind enum here doesn’t export it. This will break JAX-side code that expects the enum values to match the core header (e.g., passing NVTE_Activation_Type.GLU from Python will fail / be unavailable).

Also update the JAX activation dispatch switch to handle NVTE_Activation_Type::GLU (see transformer_engine/jax/csrc/extensions/activation.cpp:105-143).


transformer_engine/jax/csrc/extensions/activation.cpp
GLU not handled in switch

The activation dispatch switch (act_type) doesn’t include a case NVTE_Activation_Type::GLU, even though GLU was added to the shared NVTE_Activation_Type enum. If JAX passes the new enum value, this currently falls into default: and raises NVTE_ERROR("Unsupported ActivationEnum").

Add a GLU case that calls nvte_glu(...), and ensure the backward/quantized paths (other switches later in this file) are also updated consistently.

timmoon10 (Collaborator) commented:

/te-ci

greptile-apps bot (Contributor) left a comment

3 files reviewed, 1 comment

greptile-apps bot (Contributor) commented Feb 7, 2026

Additional Comments (1)

tests/pytorch/test_fusible_ops.py
missing glu in test's make_op dictionary - test will fail with KeyError when running the glu case

        make_op = dict(
            gelu=te_ops.GELU,
            geglu=te_ops.GEGLU,
            glu=te_ops.GLU,
            qgelu=te_ops.QGELU,
            qgeglu=te_ops.QGEGLU,
            relu=te_ops.ReLU,
            reglu=te_ops.ReGLU,
            srelu=te_ops.SReLU,
            sreglu=te_ops.SReGLU,
            silu=te_ops.SiLU,
            swiglu=te_ops.SwiGLU,
        )[activation]

greptile-apps bot (Contributor) left a comment

5 files reviewed, no comments

singleheart (Author) commented, quoting ptrendx:
Could we also add a test for the regular LayerNormMLP module in test_numerics.py? Also @jberchtold-nvidia FYI - we will also want to expose it in JAX.

Done. The latest commits address both points:

Tests:

  • Added GLU to all_activations in test_numerics.py (covers LayerNormMLP tests)
  • Added GLU to all_activations in test_sanity.py
  • Fixed missing glu=te_ops.GLU entry in make_op dict in test_fusible_ops.py

JAX support:

  • jax/csrc/extensions/activation.cpp: Added NVTE_Activation_Type::GLU cases for both forward (nvte_glu) and backward (nvte_dglu) dispatch
  • jax/csrc/extensions/pybind.cpp: Exported GLU enum value to Python
  • jax/cpp_extensions/activation.py: Added ("sigmoid", "linear"): NVTE_Activation_Type.GLU to ActivationEnum

ptrendx (Member) commented Feb 9, 2026

/te-ci

greptile-apps bot (Contributor) left a comment

16 files reviewed, 1 comment

Comment on lines +1632 to +1633
elif activation == "sigmoid":
y_ref = torch.nn.functional.sigmoid(x_ref)
greptile-apps bot (Contributor) commented:
Unreachable sigmoid branch

test_activation never parametrizes activation="sigmoid" (see the @pytest.mark.parametrize("activation", ...) list above), but this new elif activation == "sigmoid": branch was added anyway, so it can’t execute. This makes the test misleading/fragile (it looks like sigmoid is covered when it isn’t). Either add "sigmoid" to the parametrization or remove this branch.

ptrendx (Member) commented Feb 10, 2026

@singleheart The CI failed because you did not add the GLU implementation reference computation to the test_numerics.py:

self = TorchGLU(), activation = 'glu'
    def __init__(self, activation: str):
        super().__init__()
>       self.act = _supported_act[activation]
E       KeyError: 'glu'
../../tests/pytorch/test_numerics.py:497: KeyError

Could you fix that?

singleheart (Author) replied, quoting ptrendx:
@singleheart The CI failed because you did not add the GLU implementation reference computation to the test_numerics.py:

self = TorchGLU(), activation = 'glu'
    def __init__(self, activation: str):
        super().__init__()
>       self.act = _supported_act[activation]
E       KeyError: 'glu'
../../tests/pytorch/test_numerics.py:497: KeyError

Could you fix that?

@ptrendx Thanks for catching this! I've added "glu": nn.Sigmoid() to the _supported_act dict in test_numerics.py. This should fix the KeyError.

greptile-apps bot (Contributor) left a comment

16 files reviewed, no comments

ptrendx (Member) commented Feb 11, 2026

/te-ci

ptrendx (Member) left a comment:
LGTM
