# Getting Started

## Overview

Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your framework-specific code (JAX in this guide). It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.

This guide covers getting started with Transformer Engine in JAX. We recommend familiarizing yourself with the basics of JAX first, using these resources:
- Thinking in JAX: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html
- JAX 101: https://docs.jax.dev/en/latest/jax-101.html
- Key concepts in JAX: https://docs.jax.dev/en/latest/key-concepts.html#jax-arrays-jax-array
- Flax 101: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/index.html

If you do not wish to learn or use JAX, there is another guide in this same directory, quickstart.ipynb, that covers getting started with PyTorch.

## Let's build a Transformer layer (*)!
<small>(*) _This layer is based on the GPT decoder layer. For simplicity, and to mirror the PyTorch tutorial, whose defaults use no attention mask, we also use no attention mask here. Without a causal mask the attention behaves like that of an encoder, which does not exist in the GPT architecture. Since the `TransformerLayer` used later in this guide supports any attention mask, we leave experimenting with different attention masks to the reader._</small>

<div class="alert alert-info">

<b>Summary</b>
    
We build a basic Transformer layer using regular JAX modules. This will be our baseline for later comparisons with Transformer Engine.

</div>

Let's start with creating the transformer layer using plain JAX/Flax. Figure 1 shows the overall structure.

<figure align="center">
<img src="transformer_layer.png" width="20%">
<figcaption> Figure 1: Structure of a GPT encoder layer.</figcaption>
</figure>

We construct the components as follows:

- `LayerNorm`: `nn.LayerNorm` (JAX/Flax)
- `QKV Projection`: `nn.Dense` (conceptually three `Dense` layers for Q, K, and V separately, but we fuse into a single `Dense` layer that is three times larger)
- `DotProductAttention`: `DotProductAttention` from [quickstart_jax_utils.py](quickstart_jax_utils.py)
- `Projection`: `nn.Dense` (JAX/Flax)
- `Dropout`: `nn.Dropout` (JAX/Flax)
- `MLP`: `BasicMLP` from [quickstart_jax_utils.py](quickstart_jax_utils.py)
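
The fused QKV projection mentioned in the list above can be checked numerically. The following is a minimal sketch in pure JAX, assuming toy sizes and a plain `[Q | K | V]` column layout for the fused kernel (both are illustrative choices, not taken from the tutorial code). It shows that one projection with a kernel three times wider gives the same result as three separate projections.

```python
import jax
import jax.numpy as jnp

hidden = 8  # hypothetical toy size
kx, kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (5, hidden))
w_q = jax.random.normal(kq, (hidden, hidden))
w_k = jax.random.normal(kk, (hidden, hidden))
w_v = jax.random.normal(kv, (hidden, hidden))

# Three separate projections.
q_ref, k_ref, v_ref = x @ w_q, x @ w_k, x @ w_v

# One fused projection with a kernel that is three times wider.
w_qkv = jnp.concatenate([w_q, w_k, w_v], axis=1)  # (hidden, 3 * hidden)
q, k, v = jnp.split(x @ w_qkv, 3, axis=-1)

fused_matches = all(
    bool(jnp.allclose(a, b, atol=1e-4))
    for a, b in ((q, q_ref), (k, k_ref), (v, v_ref))
)
print(fused_matches)
```

The fused form launches one larger matrix multiplication instead of three smaller ones, which is typically friendlier to the GPU.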

Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_jax_utils.py](quickstart_jax_utils.py). Putting it all together:  


In [25]:
try:
    from datasets import load_dataset
except ModuleNotFoundError:
    %pip install --quiet datasets
    from datasets import load_dataset

In [26]:
import jax
import jax.numpy as jnp
from flax import linen as nn
import quickstart_jax_utils as utils
from typing import Optional

In [27]:
class BasicTransformerLayer(nn.Module):
    """Basic Transformer layer using plain JAX/Flax modules
    
    This is the JAX/Flax equivalent of the PyTorch BasicTransformerLayer
    from the quickstart.ipynb notebook.
    """
    
    hidden_size: int
    ffn_hidden_size: int
    num_attention_heads: int
    layernorm_eps: float = 1e-5
    attention_dropout: float = 0.1
    hidden_dropout: float = 0.1
    
    def setup(self):
        self.kv_channels = self.hidden_size // self.num_attention_heads
        
    @nn.compact
    def __call__(
        self, 
        x: jnp.ndarray, 
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = False
    ) -> jnp.ndarray:
        res = x
        x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)
        
        # Fused QKV projection
        qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x)
        qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)
        q, k, v = jnp.split(qkv, 3, axis=3)  # last axis (3 * kv_channels) -> three chunks of kv_channels
        
        # Attention self-implemented. Comment out if not used
        attention = utils.DotProductAttention(
            num_attention_heads=self.num_attention_heads,
            kv_channels=self.kv_channels,
            attention_dropout=self.attention_dropout,
        )
        x = attention(q, k, v, attention_mask, deterministic=deterministic)
        
        # Attention built-in. Comment out if not used
        # attention = flax.nnx.MultiheadAttention(
        #     num_heads=self.num_attention_heads,
        #     in_features=self.hidden_size,
        #     qkv_features=self.kv_channels,
        #     dropout_rate=self.attention_dropout,
        #     deterministic=True
        # )
        # x = attention(q, k, v, attention_mask, deterministic=deterministic)

        # Projection and dropout
        x = nn.Dense(features=self.hidden_size, use_bias=True)(x)
        x = nn.Dropout(rate=self.hidden_dropout)(x, deterministic=deterministic)
        x = res + x
        
        # Second residual connection
        res = x
        x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)
        
        # MLP
        mlp = utils.BasicMLP(
            hidden_size=self.hidden_size,
            ffn_hidden_size=self.ffn_hidden_size,
        )
        x = mlp(x)
        
        return x + res
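
The reshape-and-split step in `__call__` is easier to follow with concrete shapes. Here is a small sketch with hypothetical toy sizes (the names `seq`, `batch`, `heads`, and `head_dim` are illustrative) that mirrors the same two lines and checks which slice of the fused tensor each head receives.

```python
import jax.numpy as jnp

seq, batch, heads, head_dim = 6, 2, 4, 8  # hypothetical toy sizes
hidden = heads * head_dim

# Stand-in for the output of the fused QKV projection.
qkv = jnp.arange(seq * batch * 3 * hidden, dtype=jnp.float32)
qkv = qkv.reshape(seq, batch, 3 * hidden)

# Group the last axis per head: each head owns a contiguous [q | k | v] slab.
qkv_heads = qkv.reshape(seq, batch, heads, 3 * head_dim)
q, k, v = jnp.split(qkv_heads, 3, axis=3)

print(q.shape, k.shape, v.shape)
```

With this layout Q, K, and V are interleaved per head along the feature axis rather than stored as three large contiguous blocks. Because the projection kernel is learned, either layout can represent the same function.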


That's it! We now have a simple Transformer layer in JAX/Flax.


## Testing Performance

Now let's test the performance of our BasicTransformerLayer:


In [28]:
# Layer configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 4
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = jnp.bfloat16

# Synthetic data (separate PRNG keys so that x and dy are independent samples)
key, x_key, dy_key, dropout_key = jax.random.split(jax.random.PRNGKey(42), 4)
x = jax.random.normal(x_key, (sequence_length, batch_size, hidden_size)).astype(dtype)
dy = jax.random.normal(dy_key, (sequence_length, batch_size, hidden_size)).astype(dtype)
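
A note on PRNG keys: JAX random functions are pure, so the same key always yields the same sample. Independent tensors therefore need separate keys obtained from `jax.random.split`. The sketch below illustrates this with toy shapes (the variable names are illustrative).

```python
import jax
import jax.numpy as jnp

root = jax.random.PRNGKey(42)
a = jax.random.normal(root, (4,))
b = jax.random.normal(root, (4,))   # same key, so the same values as `a`

key_a, key_b = jax.random.split(root)
c = jax.random.normal(key_a, (4,))
d = jax.random.normal(key_b, (4,))  # different key, so different values

same_key_equal = bool(jnp.array_equal(a, b))
split_keys_equal = bool(jnp.array_equal(c, d))
print(same_key_equal, split_keys_equal)
```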


In [29]:
# Initialize the BasicTransformerLayer
basic_transformer = BasicTransformerLayer(
    hidden_size=hidden_size,
    ffn_hidden_size=ffn_hidden_size,
    num_attention_heads=num_attention_heads,
)

# Initialize parameters
params = basic_transformer.init(key, x, attention_mask=None, deterministic=False)

print("Pure Flax BasicTransformerLayer initialized successfully!")
print(f"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}")


Pure Flax BasicTransformerLayer initialized successfully!
Parameter shapes: {'params': {'BasicMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}
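
To get a feel for the size of this layer, the printed shapes can be turned into a parameter count. This is a small sketch that copies the shapes from the output above into a plain dictionary and treats each shape tuple as a leaf; it does not allocate any arrays.

```python
import math
import jax

# Shapes copied from the printed parameter tree above.
param_shapes = {
    "BasicMLP_0": {
        "Dense_0": {"bias": (16384,), "kernel": (4096, 16384)},
        "Dense_1": {"bias": (4096,), "kernel": (16384, 4096)},
    },
    "Dense_0": {"bias": (12288,), "kernel": (4096, 12288)},
    "Dense_1": {"bias": (4096,), "kernel": (4096, 4096)},
    "LayerNorm_0": {"bias": (4096,), "scale": (4096,)},
    "LayerNorm_1": {"bias": (4096,), "scale": (4096,)},
}

# Treat each shape tuple as a leaf instead of a nested container.
shapes = jax.tree_util.tree_leaves(
    param_shapes, is_leaf=lambda node: isinstance(node, tuple)
)
total_params = sum(math.prod(shape) for shape in shapes)
print(f"{total_params:,} parameters")
```

The total is about 201 million parameters, dominated by the `12 * hidden_size**2` kernel weights of the four dense layers.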


In [30]:
# Test forward pass
y = basic_transformer.apply(params, x, attention_mask=None, deterministic=True)
print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Output dtype: {y.dtype}")
print("Forward pass completed successfully!")


Input shape: (2048, 4, 4096)
Output shape: (2048, 4, 4096)
Output dtype: float32
Forward pass completed successfully!


In [31]:
import importlib
import quickstart_jax_utils
importlib.reload(quickstart_jax_utils)

utils.speedometer(
    model_apply_fn=basic_transformer.apply,
    variables=params,  # Ensure the correct `params` is passed
    input=x,
    output_grad=dy,
    dropout_key=dropout_key,
    forward_kwargs={"attention_mask": None, "deterministic": False},
)

Mean time: 28.229827880859375 ms
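
For readers curious about what such a measurement involves, below is a minimal sketch of a forward-plus-backward timing helper in pure JAX. It is not the implementation of `utils.speedometer`; the helper name `time_fwd_bwd`, its arguments, and the toy model are all hypothetical. The two points it illustrates are excluding compilation through warmup iterations and calling `jax.block_until_ready`, since JAX dispatches work asynchronously.

```python
import time
import jax
import jax.numpy as jnp

def time_fwd_bwd(apply_fn, params, x, dy, warmup=2, iters=5):
    """Mean wall-clock time in ms of one forward + backward pass (hypothetical helper)."""

    @jax.jit
    def step(params, x, dy):
        y, vjp_fn = jax.vjp(apply_fn, params, x)
        return y, vjp_fn(dy)  # gradients w.r.t. params and x

    for _ in range(warmup):  # keep compilation out of the measurement
        jax.block_until_ready(step(params, x, dy))

    start = time.perf_counter()
    for _ in range(iters):
        out = step(params, x, dy)
    jax.block_until_ready(out)  # wait for the asynchronously dispatched work
    return (time.perf_counter() - start) / iters * 1e3

# Toy stand-in for a model: a single dense layer followed by tanh.
def toy_apply(params, x):
    return jnp.tanh(x @ params["w"] + params["b"])

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
toy_params = {"w": jax.random.normal(k1, (64, 64)), "b": jnp.zeros((64,))}
toy_x = jax.random.normal(k2, (128, 64))
toy_dy = jax.random.normal(k3, (128, 64))

mean_ms = time_fwd_bwd(toy_apply, toy_params, toy_x, toy_dy)
print(f"Mean time: {mean_ms:.3f} ms")
```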


## Meet Transformer Engine

<div class="alert alert-info">

<b>Summary</b>
    
Now that we have a basic Transformer layer in JAX/Flax, let's use Transformer Engine to speed up the training. The following examples show how to use TE modules.

</div>

The JAX/Flax BasicTransformerLayer above is equivalent to the PyTorch version in the main quickstart.ipynb notebook. It uses:

- `nn.LayerNorm`: JAX/Flax LayerNorm
- `nn.Dense`: JAX/Flax Dense layer for QKV projection  
- `DotProductAttention`: Custom attention from [quickstart_jax_utils.py](quickstart_jax_utils.py) (**)
- `nn.Dense`: JAX/Flax Dense layer for projection
- `nn.Dropout`: JAX/Flax Dropout
- `BasicMLP`: Custom MLP from [quickstart_jax_utils.py](quickstart_jax_utils.py)

<small>(**) _The code below also shows, in commented-out form, how to use the built-in attention sub-layer from either pure Flax or TE Flax instead of the custom attention in [quickstart_jax_utils.py](quickstart_jax_utils.py). The custom implementation is provided as a reference for roughly how attention is implemented in our source._</small>
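
For orientation, scaled dot-product attention can be sketched in a few lines of pure JAX. This is an illustrative sketch, not the code in quickstart_jax_utils.py: the function name `sdpa`, the `[batch, heads, seq, head_dim]` layout, and the mask convention (`True` means "do not attend") are assumptions made for this example, and dropout is omitted.

```python
import jax
import jax.numpy as jnp

def sdpa(q, k, v, mask=None):
    """Scaled dot-product attention; q, k, v are [batch, heads, seq, head_dim]."""
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    if mask is not None:
        # Assumed convention: True marks positions that must NOT be attended to.
        scores = jnp.where(mask, jnp.finfo(scores.dtype).min, scores)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
shape = (2, 4, 5, 8)
q, k, v = (jax.random.normal(kx, shape) for kx in (kq, kk, kv))

seq = shape[2]
# Causal mask: each query position may not look at later key positions.
causal_mask = jnp.triu(jnp.ones((seq, seq), dtype=bool), k=1)

out_full = sdpa(q, k, v)                 # no mask, encoder-style attention
out_causal = sdpa(q, k, v, causal_mask)  # decoder-style attention
print(out_full.shape, out_causal.shape)
```

With the causal mask, the first query position can only attend to itself, so its output equals the first value vector. Without a mask, every position attends to the whole sequence.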

Below we show how to use Transformer Engine JAX/Flax modules for better performance:


In [32]:
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

TE provides a set of JAX modules that can be used to build Transformer layers. The simplest of the provided modules are the `DenseGeneral` and `LayerNorm` layers, which we can use instead of `flax.linen.Dense` and `flax.linen.LayerNorm`. Let's modify our `BasicTransformerLayer`:

In [33]:

class BasicTEMLP(nn.Module):
    hidden_size: int
    ffn_hidden_size: int

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True)(x)
        x = nn.gelu(x, approximate=True)
        x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
        return x

class BasicTETransformerLayer(nn.Module):
    hidden_size: int
    ffn_hidden_size: int 
    num_attention_heads: int  
    layernorm_eps: float = 1e-5
    attention_dropout: float = 0.1 
    hidden_dropout: float = 0.1

    def setup(self):
        self.kv_channels = self.hidden_size // self.num_attention_heads

    @nn.compact
    def __call__(
        self, 
        x: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = False
    ) -> jnp.ndarray:
        res = x
        x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)

        # Fused QKV projection
        qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x)
        qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)
        q, k, v = jnp.split(qkv, 3, axis=3)

        attention = utils.DotProductAttention(
            num_attention_heads=self.num_attention_heads,
            kv_channels=self.kv_channels,
            attention_dropout=self.attention_dropout,
        )
        x = attention(q, k, v, attention_mask, deterministic=deterministic)
        
        # To use the built-in attention layer from JAX/Flax instead, uncomment the block below.
        # If you do, remove the subsequent Dense layer that projects
        # the concatenated attention output back to the hidden size.
        # attention = flax.nnx.MultiheadAttention(
        #     num_heads=self.num_attention_heads,
        #     qkv_features=self.kv_channels,
        #     dropout_rate=self.attention_dropout,
        #     attention_mask_type='no_mask'
        # )
        # x = attention(q, k, v, attention_mask, deterministic=deterministic)

        # To use the built-in attention layer from TE JAX instead, uncomment the block below.
        # If you do, remove the subsequent Dense layer that projects
        # the concatenated attention output back to the hidden size.
        # attention = te_flax.MultiHeadAttention(
        #     num_attention_heads = self.num_attention_heads,
        #     head_dim=self.kv_channels,
        #     attention_dropout=self.attention_dropout,
        #     attention_mask_type='no_mask',
        # )

        # x = attention(q, k, v, attention_mask, deterministic=deterministic)

        # Project the concatenated attention output back to the hidden size.
        # Delete this if you use the built-in multi-head attention module from either flax or te_flax.
        x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
        x = nn.Dropout(rate=self.hidden_dropout)(x, deterministic=deterministic)
        x = res + x

        # Second residual connection
        res = x
        x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)

        # MLP
        mlp = BasicTEMLP(
            hidden_size=self.hidden_size,
            ffn_hidden_size=self.ffn_hidden_size
        )

        x = mlp(x)

        return x + res

In [34]:
import quickstart_jax_utils
importlib.reload(quickstart_jax_utils)

basic_te_transformer = BasicTETransformerLayer(
    hidden_size, 
    ffn_hidden_size, 
    num_attention_heads,
)

te_params = basic_te_transformer.init(key, x, attention_mask=None, deterministic=False)

# Test forward pass
y = basic_te_transformer.apply(te_params, x, attention_mask=None, deterministic=True)

utils.speedometer(
    model_apply_fn=basic_te_transformer.apply,
    variables=te_params,  # Ensure the correct `params` is passed
    input=x,
    output_grad=dy,
    dropout_key=dropout_key,
    forward_kwargs={"attention_mask": None, "deterministic": False},
)

Mean time: 17.390952110290527 ms



## Fused TE Modules

<div class="alert alert-info">

<b>Summary</b>
    
We optimize the example Transformer layer with TE modules for fused operations.

</div>

The `DenseGeneral` layer is enough to build any Transformer model and it enables usage of Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations like kernel fusion, increasing the achievable speedup.

Transformer Engine therefore provides coarser modules that span multiple layers:

* `LayerNormDenseGeneral`
* `LayerNormMLP`
* `TransformerLayer`

For a complete list of the modules that TE Flax supports, see https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#modules
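
To make concrete what a module such as `LayerNormDenseGeneral` computes, here is a minimal reference sketch in pure JAX of a layer norm followed by a dense projection, written once as two separate steps and once as a single jitted function. It only illustrates the math. The helper names are hypothetical, and TE's actual fused kernels and FP8 handling are not represented.

```python
import jax
import jax.numpy as jnp

def layernorm(x, scale, bias, eps=1e-5):
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + eps) * scale + bias

def dense(x, kernel, bias):
    return x @ kernel + bias

@jax.jit
def layernorm_dense(x, ln_scale, ln_bias, kernel, bias):
    # Written as one function, so the compiler sees both operations together.
    return dense(layernorm(x, ln_scale, ln_bias), kernel, bias)

kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (4, 16))
ln_scale, ln_bias = jnp.ones((16,)), jnp.zeros((16,))
kernel, bias = jax.random.normal(kw, (16, 32)), jnp.zeros((32,))

two_step = dense(layernorm(x, ln_scale, ln_bias), kernel, bias)
one_step = layernorm_dense(x, ln_scale, ln_bias, kernel, bias)
print(one_step.shape)
```

Knowing that the two operations always occur together is what allows a library to avoid writing the normalized activations back to memory between them.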

Let's build a third iteration of our Transformer layer with `LayerNormDenseGeneral` and `LayerNormMLP`:

In [35]:
class FusedTETransformerLayer(nn.Module):
    hidden_size: int
    ffn_hidden_size: int 
    num_attention_heads: int  
    layernorm_eps: float = 1e-5
    attention_dropout: float = 0.1 
    hidden_dropout: float = 0.1

    def setup(self):
        self.kv_channels = self.hidden_size // self.num_attention_heads

    @nn.compact
    def __call__(
        self, 
        x: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = False
    ) -> jnp.ndarray:

        res = x

        # Fused LayerNorm + QKV projection (the second return value is unused here)
        qkv, _ = te_flax.LayerNormDenseGeneral(
            features=3 * self.hidden_size,
            epsilon=self.layernorm_eps,
            use_bias=True,
        )(x)

        qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)
        q, k, v = jnp.split(qkv, 3, axis=3)

        attention = utils.DotProductAttention(
            num_attention_heads=self.num_attention_heads,
            kv_channels=self.kv_channels,
            attention_dropout=self.attention_dropout,
        )
        x = attention(q, k, v, attention_mask, deterministic=deterministic)

        # To use the built-in attention layer from TE JAX instead, uncomment the block below.
        # If you do, remove the subsequent Dense layer that projects
        # the concatenated attention output back to the hidden size.
        # attention = te_flax.MultiHeadAttention(
        #     num_attention_heads = self.num_attention_heads,
        #     head_dim=self.kv_channels,
        #     attention_dropout=self.attention_dropout,
        #     attention_mask_type='no_mask',
        # )
        # x = attention(q, k, v, attention_mask, deterministic=deterministic)

        # Project the concatenated attention output back to the hidden size.
        # Delete this if you use the built-in multi-head attention module from either flax or te_flax.
        x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
        x = nn.Dropout(rate=self.hidden_dropout)(x, deterministic=deterministic)
        x = res + x

        # Second residual connection, with LayerNorm and MLP fused into one module
        res = x
        x, _ = te_flax.LayerNormMLP(
            intermediate_dim=self.ffn_hidden_size,
            epsilon=self.layernorm_eps,
            use_bias=True,
        )(x, deterministic=deterministic)

        return x + res

In [36]:
fused_te_transformer = FusedTETransformerLayer(
    hidden_size, 
    ffn_hidden_size, 
    num_attention_heads
)

fused_te_params = fused_te_transformer.init(key, x, attention_mask=None, deterministic=False)

In [37]:
# Test forward pass
y = fused_te_transformer.apply(fused_te_params, x, attention_mask=None, deterministic=True)

utils.speedometer(
    model_apply_fn=fused_te_transformer.apply,
    variables=fused_te_params,  # Ensure the correct `params` is passed
    input=x,
    output_grad=dy,
    dropout_key=dropout_key,
    forward_kwargs={"attention_mask": None, "deterministic": False},
)

Mean time: 18.087706565856934 ms


Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures and it provides the highest degree of performance optimization:

In [38]:
te_transformer = te_flax.TransformerLayer(
    hidden_size=hidden_size,
    mlp_hidden_size=ffn_hidden_size, 
    num_attention_heads=num_attention_heads,
    mlp_activations=("gelu",),
    self_attn_mask_type='no_mask',
    layernorm_epsilon=1e-5,
    use_bias=True
    )

te_transformer_params = te_transformer.init(key, x, deterministic=False)

In [39]:
# Test forward pass
y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)

utils.speedometer(
    model_apply_fn=te_transformer.apply,
    variables=te_transformer_params,  # Ensure the correct `params` is passed
    input=x,
    output_grad=dy,
    dropout_key=dropout_key,
    forward_kwargs={"attention_mask": None, "deterministic": False},
)

Mean time: 12.37576961517334 ms


## Enabling FP8

<div class="alert alert-info">

<b>Summary</b>
    
We configure a TE module to perform compute in FP8.

</div>

Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](.../api/jax.rst#transformer_engine.jax.fp8_autocast) context manager. Note that `fp8_autocast` should only wrap the forward pass and must exit before the backward pass starts. See the [FP8 tutorial](fp8_primer.ipynb) (currently only available for PyTorch) for a detailed explanation of FP8 recipes and the supported options.

<div class="alert alert-warning">

<b>Important: FP8 Metadata Initialization</b>

When using FP8, the model **must be initialized within the `fp8_autocast` context**. This creates a special collection called `fp8_metas` that contains scaling factors and other metadata required for FP8 computation. If you initialize a model outside of `fp8_autocast` and then try to use it with FP8, you will get a `ScopeCollectionNotFound` error because the `fp8_metas` collection was never created.

</div>

In [40]:
from transformer_engine.common.recipe import Format, DelayedScaling

te_transformer = te_flax.TransformerLayer(
    hidden_size=hidden_size,
    mlp_hidden_size=ffn_hidden_size, 
    num_attention_heads=num_attention_heads,
    mlp_activations=("gelu",),
    self_attn_mask_type='no_mask',
    layernorm_epsilon=1e-5,
    use_bias=True
)

fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    te_transformer_params = te_transformer.init(key, x, deterministic=False)
    y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)
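
The `DelayedScaling` recipe above chooses scaling factors from a history of recent absolute maxima (`amax_history_len=16`, `amax_compute_algo="max"`). The sketch below illustrates that idea in pure JAX under simplifying assumptions: the helper names are hypothetical, any margin is ignored, the history is filled with a single value, and only the E4M3 format (largest finite value 448) is shown. TE's actual bookkeeping in `fp8_metas` is more involved.

```python
import jax
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite value of the FP8 E4M3 format

def delayed_scale(amax_history, fp8_max=E4M3_MAX):
    """Scale derived from the largest recent amax ("max" algorithm, margin ignored)."""
    return fp8_max / jnp.max(amax_history)

def fake_quantize(x, scale, fp8_max=E4M3_MAX):
    """Scale, round to E4M3, then scale back, to expose the rounding error."""
    x_scaled = jnp.clip(x * scale, -fp8_max, fp8_max)
    return x_scaled.astype(jnp.float8_e4m3fn).astype(jnp.float32) / scale

x = 0.01 * jax.random.normal(jax.random.PRNGKey(0), (1024,))
amax = jnp.max(jnp.abs(x))
# Hypothetical history of 16 earlier steps, here simply filled with the current amax.
amax_history = jnp.full((16,), amax)

scale = delayed_scale(amax_history)
x_dq = fake_quantize(x, scale)
max_abs_err = float(jnp.max(jnp.abs(x - x_dq)))
print(f"scale={float(scale):.1f}, max abs error={max_abs_err:.2e}")
```

Scaling maps the observed range of the tensor onto the representable range of the FP8 format, which is why the scaling metadata has to exist before the first FP8 forward pass.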

In [41]:
utils.speedometer(
    model_apply_fn=te_transformer.apply,
    model_init_fn=te_transformer.init,
    variables=te_transformer_params,  # Includes both params and fp8_metas
    input=x,
    output_grad=dy,
    dropout_key=dropout_key,
    forward_kwargs={"attention_mask": None, "deterministic": False},
    fp8_autocast_kwargs={"enabled": True, "fp8_recipe": fp8_recipe},
)

Mean time: 7.956786155700684 ms
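
To put the measurements in this guide side by side, the reported mean times can be converted into speedups over the plain Flax baseline. This is simple arithmetic on the numbers printed above. The dictionary labels are informal descriptions, and the results will differ on other hardware and software versions.

```python
# Mean times in ms, copied from the outputs reported in this guide.
mean_ms = {
    "BasicTransformerLayer (Flax)": 28.229827880859375,
    "BasicTETransformerLayer": 17.390952110290527,
    "FusedTETransformerLayer": 18.087706565856934,
    "te_flax.TransformerLayer": 12.37576961517334,
    "te_flax.TransformerLayer + FP8": 7.956786155700684,
}

baseline = mean_ms["BasicTransformerLayer (Flax)"]
speedups = {name: baseline / t for name, t in mean_ms.items()}

for name, s in speedups.items():
    print(f"{name:32s} {mean_ms[name]:7.2f} ms  {s:4.2f}x")
```

On the numbers reported here, the TE `TransformerLayer` is roughly 2.3x faster than the baseline, and enabling FP8 raises that to roughly 3.5x.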
