Skip to content

[Bug]trtexec build engine succeeded on the H100 GPU but failed on the 5090 GPU #2891

@Hukongtao

Description

@Hukongtao

Describe the bug

Referenced Official Documentation:
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/onnx/onnx_export.html
Strangely, during the process of generating the TensorRT engine, it succeeds on an H100 GPU but fails on a 5090 GPU.

Steps/Code to reproduce bug

First Step: export to onnx

import time

import torch
import torch.nn as nn
import transformer_engine.pytorch as te
from transformer_engine.pytorch.export import te_translation_table

from transformers import AutoModelForImageClassification


def _measure_time(f):

    time_taken = []
    num_iterations = 10
    f()  # warm-up
    f()

    for _ in range(num_iterations):
        start_time = time.time()
        f()
        torch.cuda.synchronize()
        end_time = time.time()
        time_taken.append(end_time - start_time)
    return round(sum(time_taken) / num_iterations, 3)


def convert_model(
    model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True
):
    """
    Recursively converts the linear and layernorm layers of a model to their
    `transformer_engine` counterparts (or back to plain ``torch.nn`` modules).

    Args:
        model: module whose matching children are replaced in place via ``setattr``.
        to_transformer_engine: if True convert ``nn`` -> ``te``; if False, ``te`` -> ``nn``.
        _convert_linear: enable conversion of Linear layers.
        _convert_ln: enable conversion of LayerNorm layers.

    Note: weights are copied with ``Tensor.copy_``; callers should wrap this in
    ``torch.no_grad()`` (the script below does).
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
            # FP8 GEMMs require dimensions divisible by 16. Skip just this
            # module — a `return` here would silently abort conversion of the
            # entire remaining model the first time an odd-shaped layer (e.g.
            # a classifier head) is encountered.
            if any(p % 16 != 0 for p in module.weight.shape):
                continue

            has_bias = module.bias is not None
            te_module = te.Linear(
                module.in_features,
                module.out_features,
                bias=has_bias,
                params_dtype=module.weight.dtype,
            )
            te_module.weight.copy_(module.weight)
            if has_bias:
                te_module.bias.copy_(module.bias)

            setattr(model, name, te_module)
        # Note: @xrsrke (Phuc) found that te.LayerNorm doesn't have any real memory savings or speedups over nn.LayerNorm
        elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
            has_bias = module.bias is not None
            te_module = te.LayerNorm(
                module.normalized_shape[0],
                eps=module.eps,
                params_dtype=module.weight.dtype,
            )
            te_module.weight.copy_(module.weight)
            if has_bias:
                te_module.bias.copy_(module.bias)

            setattr(model, name, te_module)
        elif (
            isinstance(module, te.Linear)
            and not to_transformer_engine
            and _convert_linear
        ):
            has_bias = module.bias is not None
            # torch.nn.Linear takes `dtype`, not `params_dtype` — the latter
            # raises TypeError (it is a transformer_engine-only kwarg).
            new_module = nn.Linear(
                module.in_features,
                module.out_features,
                bias=has_bias,
                dtype=module.weight.dtype,
            )
            new_module.weight.copy_(module.weight)
            if has_bias:
                new_module.bias.copy_(module.bias)

            setattr(model, name, new_module)
        elif (
            isinstance(module, te.LayerNorm)
            and not to_transformer_engine
            and _convert_ln
        ):
            # Same fix as above: torch.nn.LayerNorm expects `dtype`.
            new_module = nn.LayerNorm(
                module.normalized_shape[0],
                eps=module.eps,
                dtype=module.weight.dtype,
            )
            new_module.weight.copy_(module.weight)
            new_module.bias.copy_(module.bias)

            setattr(model, name, new_module)
        else:
            # Not a convertible leaf: recurse into the child's own children.
            convert_model(
                module,
                to_transformer_engine=to_transformer_engine,
                _convert_linear=_convert_linear,
                _convert_ln=_convert_ln,
            )


# Load a pretrained DINOv2-large image classifier from the Hugging Face hub
# (or from a local checkout, see the commented-out path).
# model_path = "/data/hukongtao/models/huggingface/dinov2-large/"
model_path = "facebook/dinov2-large"
model = AutoModelForImageClassification.from_pretrained(model_path)
print(model)
# Swap nn.Linear / nn.LayerNorm children for their Transformer Engine
# counterparts in place; no_grad because conversion copies weights with copy_.
with torch.no_grad():
    convert_model(model)
# Sanity check: this layer should now be a te.Linear if conversion succeeded.
print(type(model.dinov2.encoder.layer[0].mlp.fc1))


# Eval mode + GPU, and a fixed random batch of 16 RGB 256x256 images.
model = model.eval().cuda()
inps = (torch.randn([16, 3, 256, 256], device="cuda"),)


def _inference(fp8_enabled):
    """Run a single forward pass of the global ``model`` on ``inps``.

    FP8 execution is toggled through Transformer Engine's autocast context;
    gradients are disabled since this is inference-only timing.
    """
    with torch.no_grad():
        with te.autocast(enabled=fp8_enabled):
            model(*inps)


# Time the same model with FP8 disabled and enabled.
te_fp32_time = _measure_time(lambda: _inference(fp8_enabled=False))
te_fp8_time = _measure_time(lambda: _inference(fp8_enabled=True))

# NOTE(review): confirm that _measure_time's return unit actually matches the
# "ms" label used below before comparing numbers across runs.
print(f"Average inference time FP32: {te_fp32_time} ms")
print(f"Average inference time FP8: {te_fp8_time} ms")


def export(model, fname, inputs, fp8=True):
    """Export ``model`` to an ONNX file at ``fname``.

    Args:
        model: the (already converted) Transformer Engine model.
        fname: output path for the ONNX file.
        inputs: tuple of example input tensors used for tracing.
        fp8: whether to export under the FP8 autocast recipe.
    """
    grad_guard = torch.no_grad()
    recipe_guard = te.autocast(enabled=fp8)
    with grad_guard, recipe_guard:
        # ! IMPORTANT !
        # A warm-up forward pass is mandatory before export, and the FP8
        # recipe used for it must match the recipe active during export.
        model(*inputs)

        print(f"Exporting {fname}")
        # te_translation_table supplies the ONNX lowerings for the FP8
        # quantize/dequantize operators. Only dynamo=True export is
        # supported; dynamo=False is deprecated and unsupported.
        with te.onnx_export(enabled=True):
            torch.onnx.export(
                model,
                inputs,
                fname,
                output_names=["output"],
                dynamo=True,
                custom_translation_table=te_translation_table,
            )


# Example usage: export both an FP8 and an FP32 variant of the model.
# (The FP8 file is the one later fed to trtexec, which fails on the 5090.)
export(model, "model_fp8.onnx", inps, fp8=True)
export(model, "model_fp32.onnx", inps, fp8=False)

Second step: build tensorrt engine

trtexec --onnx=model_fp8.onnx  --saveEngine=model_fp8.engine  > output_fp8.log 2>&1

During the second step, execution failed on the 5090 GPU but succeeded on the H100 GPU.

The failure log is as follows:
output_fp8.log

Expected behavior

Build TensorRT Engine Successfully

Environment overview (please complete the following information)

  • Method of Transformer Engine install: pip3 install --no-build-isolation . # Build and install

Environment details

If NVIDIA docker image is used you don't need to specify these.
Otherwise, please provide:

  • OS version: Ubuntu 22.04
  • PyTorch version: 2.10.0+cu130
  • Python version: 3.12
  • Transformer Engine version: 2.13.0+5aa4823d
  • CUDA version: 13.0
  • CUDNN version

Device details

  • 5090 GPU

Additional context

Add any other context about the problem here.

Metadata

Metadata

Assignees

Labels

bug — Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions