Skip to content

[Bug] TVM ONNX LayerNormalization lacks numerical stability for large float32 inputs #19592

@beanduan22

Description

@beanduan22
import sys

import numpy as np
import onnx
import tvm
from onnx import TensorProto, helper
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx

# Reproducer: one float32 row with a large common offset (~8e4) and a tiny
# spread (std ~1.1).  80000**2 = 6.4e9, where adjacent float32 values are 512
# apart, so a variance computed as E[x^2] - E[x]^2 in float32 would lose every
# significant bit.  The centered two-pass form E[(x - mean)^2] stays exact.
# NOTE(review): that TVM uses the one-pass form is presumed from the symptom,
# not verified here.
EPSILON = 1e-5  # shared by the ONNX node and the reference so they cannot drift

x = np.array([[80000.0, 80001.0, 80002.0, 80003.0]], dtype=np.float32)
s = np.ones(4, dtype=np.float32)
b = np.zeros(4, dtype=np.float32)


def _reference_layer_norm(data, scale, bias, eps):
    """Return LayerNormalization over the last axis, computed in float64.

    The mean is computed first and the variance is taken from the centered
    values (two-pass).  This keeps the result accurate even when
    |mean| >> std.  The result is cast back to float32 to match the model
    output dtype.
    """
    data64 = data.astype(np.float64)
    centered = data64 - data64.mean(axis=-1, keepdims=True)
    variance = (centered ** 2).mean(axis=-1, keepdims=True)
    normalized = centered / np.sqrt(variance + eps)
    return (normalized * scale + bias).astype(np.float32)


node = helper.make_node(
    "LayerNormalization", ["x", "s", "b"], ["y"], axis=-1, epsilon=EPSILON
)
graph = helper.make_graph(
    [node],
    "g",
    [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4]),
        helper.make_tensor_value_info("s", TensorProto.FLOAT, [4]),
        helper.make_tensor_value_info("b", TensorProto.FLOAT, [4]),
    ],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
model.ir_version = 9
onnx.checker.check_model(model)

mod = from_onnx(model, keep_params_in_input=False)
with tvm.transform.PassContext(opt_level=3):
    ex = tvm.compile(mod, target=tvm.target.Target("llvm"))
vm = relax.VirtualMachine(ex, tvm.cpu())
y = vm["main"](
    tvm.runtime.tensor(x, tvm.cpu()),
    tvm.runtime.tensor(s, tvm.cpu()),
    tvm.runtime.tensor(b, tvm.cpu()),
).numpy()

expected = _reference_layer_norm(x, s, b, EPSILON)
# np.allclose treats NaN/inf as mismatches (equal_nan defaults to False), so
# the originally reported all-NaN output is still flagged.  Finite output that
# is numerically wrong is now caught as well.
wrong = bool(
    np.isfinite(expected).all()
    and not np.allclose(y, expected, rtol=1e-3, atol=1e-3)
)

print("TVM:", y.tolist())
print("Expected:", expected.tolist())
print("Wrong:", wrong)
# Exit 0 means the bug reproduced; exit 1 means TVM's output matched.
sys.exit(0 if wrong else 1)

Output:

  TVM: [[nan, nan, nan, nan]]
  Expected: [[-1.3416354656219482, -0.4472118020057678, 0.4472118020057678, 1.3416354656219482]]
  Wrong: True

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions