Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix topi.rms_norm with float32 upscale #16091

Merged
merged 3 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 12 additions & 16 deletions include/tvm/topi/nn/rms_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,32 +41,31 @@ using namespace tvm::te;
* \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
* \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
* d_{axis_k} == r_k
* \param bias Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
* d_{axis_k} == r_k
* \param axis The axis to normalize over.
* \param epsilon The epsilon value to avoid division by zero.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
* \return The normalized tensor, with the same shape as data.
*/
inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Tensor& bias,
const Array<Integer>& axis, double epsilon, std::string name = "T_rms_norm",
inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Integer>& axis,
double epsilon, std::string name = "T_rms_norm",
std::string tag = kInjective) {
const auto& data_type = data->dtype;
const auto& weight_type = weight.defined() ? weight->dtype : data_type;
ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
const auto& bias_type = bias.defined() ? bias->dtype : data_type;
ICHECK(data_type == bias_type) << "rms_norm: data and bias must have the same type";

auto square = multiply(data, data);
const auto& data_fp32 = cast(data, DataType::Float(32));
const auto& weight_fp32 = cast(weight, DataType::Float(32));

auto square = multiply(data_fp32, data_fp32);
auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);

auto ndim = data->shape.size();
auto ndim = data_fp32->shape.size();
ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
auto reduce_extent = make_const(data->dtype, 1);
auto reduce_extent = make_const(data_fp32->dtype, 1);
for (int i : real_axis) {
reduce_extent *= data->shape[i];
reduce_extent *= data_fp32->shape[i];
}
auto rms_norm_func = [&](const Array<Var>& indices) {
Array<Var> reduce_indices, non_reduce_indices;
Expand All @@ -78,15 +77,12 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Tensor& b
}
}
auto output =
data(indices) * weight(reduce_indices) *
data_fp32(indices) * weight_fp32(reduce_indices) *
tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
if (bias.defined()) {
output += bias(reduce_indices);
}
return output;
};
auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
return rms_norm;
auto rms_norm = tvm::te::compute(data_fp32->shape, rms_norm_func, name, tag);
return cast(rms_norm, data_type);
}

} // namespace nn
Expand Down
7 changes: 2 additions & 5 deletions python/tvm/topi/nn/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .. import cpp


def rms_norm(data, weight, bias, axis, epsilon=1e-5):
def rms_norm(data, weight, axis, epsilon=1e-5):
"""Root mean square normalization operator. The output will have the same data type as input.

Parameters
Expand All @@ -29,9 +29,6 @@ def rms_norm(data, weight, bias, axis, epsilon=1e-5):
weight: tvm.te.Tensor
K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

bias: tvm.te.Tensor
Optional, K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k

axis : list of int
Axis over the normalization applied

Expand All @@ -43,4 +40,4 @@ def rms_norm(data, weight, bias, axis, epsilon=1e-5):
result : tvm.te.Tensor
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
return cpp.nn.rms_norm(data, weight, bias, axis, epsilon)
return cpp.nn.rms_norm(data, weight, axis, epsilon)
9 changes: 5 additions & 4 deletions python/tvm/topi/testing/rms_norm_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np


def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
def rms_norm_python(data, weight, axis, epsilon=1e-5):
"""Root mean square normalization operator in Python.

Parameters
Expand All @@ -44,8 +44,9 @@ def rms_norm_python(data, weight, bias, axis, epsilon=1e-5):
result : np.ndarray
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
dtype = data.dtype
data = data.astype("float32")
weight = weight.astype("float32")
square_mean = np.mean(np.square(data), axis, keepdims=True)
result = data * weight / np.sqrt(square_mean + epsilon)
if bias is not None:
result += bias
return result
return result.astype(dtype)
2 changes: 1 addition & 1 deletion src/topi/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetVal

/* Ops from nn/rms_norm.h */
TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = nn::rms_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
*rv = nn::rms_norm(args[0], args[1], args[2], static_cast<double>(args[3]));
});

} // namespace topi
Expand Down
14 changes: 6 additions & 8 deletions tests/python/topi/python/test_topi_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,33 +34,31 @@
# only test on llvm because schedule is missing
@tvm.testing.parametrize_targets("llvm")
@pytest.mark.parametrize(
"shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,))]
"shape,axis",
[([4, 16], (1,)), ([4, 16, 16], (1, 2)), ([("a", 4), ("b", 16)], (1,)), ([2, 8192], (1,))],
)
@pytest.mark.parametrize("dtype", ["float32", "float16"])
def test_rms_norm(target, dev, shape, axis, dtype, episilon=1e-5, rtol=5e-3, atol=1e-4):
shape_te = [te.var(v[0]) if isinstance(v, tuple) else v for v in shape]
scale_shape_te = [shape_te[dim] for dim in axis]
data = te.placeholder(shape_te, dtype=dtype, name="data")
weight = te.placeholder(scale_shape_te, dtype=dtype, name="weight")
bias = te.placeholder(scale_shape_te, dtype=dtype, name="weight")
B = topi.nn.rms_norm(data, weight, bias, axis, episilon)
B = topi.nn.rms_norm(data, weight, axis, episilon)

shape_np = [v[1] if isinstance(v, tuple) else v for v in shape]
scale_shape_np = [shape_np[dim] for dim in axis]
data_np = np.random.uniform(size=shape_np).astype(dtype)
weight_np = np.random.uniform(size=scale_shape_np).astype(dtype)
bias_np = np.random.uniform(size=scale_shape_np).astype(dtype)
b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, bias_np, axis, episilon)
b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, axis, episilon)

with tvm.target.Target(target):
s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule)
s = s_func([B])
data_tvm = tvm.nd.array(data_np, dev)
weight_tvm = tvm.nd.array(weight_np, dev)
bias_tvm = tvm.nd.array(bias_np, dev)
b_tvm = tvm.nd.array(np.zeros(shape_np, dtype=dtype), dev)
f = tvm.build(s, [data, weight, bias, B], target)
f(data_tvm, weight_tvm, bias_tvm, b_tvm)
f = tvm.build(s, [data, weight, B], target)
f(data_tvm, weight_tvm, b_tvm)
tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)


Expand Down