
Commit

[BugFix] Use x*x*x instead of pow(x,3) (#16518)
CharlieFRuan committed Feb 4, 2024
1 parent a6e8fee commit 5ebdd49
Showing 3 changed files with 79 additions and 37 deletions.
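Before the per-file diffs, a quick sanity check on the rewrite itself. The change relies on the algebraic identity x + 0.044715*x^3 = x * (1 + 0.044715*x*x), so no pow is needed in the GELU-tanh legalization. A minimal numpy sketch (mine, not part of the commit) confirming the two forms agree:

    import math
    import numpy as np

    x = np.linspace(-4.0, 4.0, 101, dtype=np.float32)
    c = math.sqrt(2.0 / math.pi)

    # Original form: pow(x, 3) inside the tanh argument
    old = 0.5 * x * (1.0 + np.tanh(c * (x + 0.044715 * np.power(x, 3))))
    # Rewritten form: the cubic term factored into multiplications
    new = 0.5 * x * (1.0 + np.tanh(c * x * (1.0 + 0.044715 * x * x)))

    np.testing.assert_allclose(old, new, rtol=1e-5, atol=1e-6)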
3 changes: 2 additions & 1 deletion python/tvm/relax/transform/legalize_ops/nn.py
@@ -350,7 +350,8 @@ def te_gelu_tanh(x: te.Tensor):
                 tir.const(1.0, dtype)
                 + topi.tanh(
                     tir.const(math.sqrt(2.0 / math.pi), dtype)
-                    * (x + tir.const(0.044715, dtype) * topi.power(x, 3))
+                    * x
+                    * (1 + tir.const(0.044715, dtype) * x * x)
                 )
             )
         )
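Why pow(x, 3) can misbehave for x < 0 (my reading of the motivation, not stated in the diff): some backends lower a generic pow(x, y) to exp(y * log(x)), which yields NaN for negative x even when the mathematical result is well defined, e.g. (-2)^3 = -8. A short numpy illustration of that failure mode:

    import numpy as np

    x, y = np.float32(-2.0), np.float32(3.0)
    print(np.power(x, y))  # -8.0: the integer-valued exponent is handled exactly
    with np.errstate(invalid="ignore"):
        print(np.exp(y * np.log(x)))  # nan: the exp/log lowering breaks for x < 0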
26 changes: 26 additions & 0 deletions src/tir/op/op.cc
@@ -688,6 +688,32 @@ TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr a, Span span)
 PrimExpr pow(PrimExpr x, PrimExpr y, Span span) {
   BinaryOpMatchTypes(x, y, span);
   ICHECK(x.dtype().is_float()) << "power only applies to float";
+
+  // If we detect pow(x, y) with a constant exponent y >= 3, suggest expanding
+  // the power into repeated multiplication instead.
+  if (y.dtype().is_int()) {
+    using tir::IntImmNode;
+    const IntImmNode* px = y.as<IntImmNode>();
+    if (px) {
+      if (px->value >= 3) {
+        LOG(WARNING)
+            << "Detected pow(x, y) where y >= 3; it is recommended to avoid this, as it may "
+               "lead to unintended behavior when x < 0. Consider rewriting it as `x * x * x ...` "
+               "or `pow(x, 2) * pow(x, 2) ...` instead.";
+      }
+    }
+  } else if (y.dtype().is_float()) {
+    using tir::FloatImmNode;
+    const FloatImmNode* fx = y.as<FloatImmNode>();
+    if (fx) {
+      if (fx->value >= 3.0) {
+        LOG(WARNING)
+            << "Detected pow(x, y) where y >= 3; it is recommended to avoid this, as it may "
+               "lead to unintended behavior when x < 0. Consider rewriting it as `x * x * x ...` "
+               "or `pow(x, 2) * pow(x, 2) ...` instead.";
+      }
+    }
+  }
+
   static auto op = Op::Get("tir.pow");
   return tir::Call(x.dtype(), op, {x, y}, span);
 }
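The new warning recommends `x * x * x ...` or `pow(x, 2) * pow(x, 2) ...`. A hypothetical helper (mine, not part of TVM) showing how an integer power can be expanded into multiplications via square-and-multiply:

    def expand_int_pow(x, n: int):
        """Compute x**n for n >= 1 using only multiplication (square-and-multiply)."""
        assert n >= 1
        result = None
        base = x
        while n:
            if n & 1:  # this bit of the exponent contributes the current square
                result = base if result is None else result * base
            base = base * base
            n >>= 1
        return result

    assert expand_int_pow(-2.0, 3) == -8.0  # stays correct for negative bases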
87 changes: 51 additions & 36 deletions tests/python/relax/test_transform_legalize_ops_nn.py
@@ -1259,10 +1259,11 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
     def gelu_tanh(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")):
         T.func_attr({"tir.noalias": T.bool(True)})
         T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
-        T_power = T.alloc_buffer((T.int64(2), T.int64(3)))
         T_multiply_2 = T.alloc_buffer((T.int64(2), T.int64(3)))
-        T_add = T.alloc_buffer((T.int64(2), T.int64(3)))
         T_multiply_3 = T.alloc_buffer((T.int64(2), T.int64(3)))
+        T_multiply_4 = T.alloc_buffer((T.int64(2), T.int64(3)))
+        T_add = T.alloc_buffer((T.int64(2), T.int64(3)))
+        T_multiply_5 = T.alloc_buffer((T.int64(2), T.int64(3)))
         compute = T.alloc_buffer((T.int64(2), T.int64(3)))
         T_add_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
         for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
@@ -1272,43 +1273,49 @@ def gelu_tanh(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")):
                 T.writes(T_multiply_1[v_ax0, v_ax1])
                 T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
-            with T.block("T_power"):
+            with T.block("T_multiply_1"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(A[v_ax0, v_ax1])
-                T.writes(T_power[v_ax0, v_ax1])
-                T_power[v_ax0, v_ax1] = T.pow(A[v_ax0, v_ax1], T.float32(3))
+                T.writes(T_multiply_2[v_ax0, v_ax1])
+                T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * A[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
-            with T.block("T_multiply_1"):
+            with T.block("T_multiply_2"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                T.reads(T_power[v_ax0, v_ax1])
-                T.writes(T_multiply_2[v_ax0, v_ax1])
-                T_multiply_2[v_ax0, v_ax1] = T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
+                T.reads(A[v_ax0, v_ax1])
+                T.writes(T_multiply_3[v_ax0, v_ax1])
+                T_multiply_3[v_ax0, v_ax1] = T.float32(0.044714999999999998) * A[v_ax0, v_ax1]
+        for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+            with T.block("T_multiply_3"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(T_multiply_3[v_ax0, v_ax1], A[v_ax0, v_ax1])
+                T.writes(T_multiply_4[v_ax0, v_ax1])
+                T_multiply_4[v_ax0, v_ax1] = T_multiply_3[v_ax0, v_ax1] * A[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
             with T.block("T_add"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                T.reads(A[v_ax0, v_ax1], T_multiply_2[v_ax0, v_ax1])
+                T.reads(T_multiply_4[v_ax0, v_ax1])
                 T.writes(T_add[v_ax0, v_ax1])
-                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T_multiply_2[v_ax0, v_ax1]
+                T_add[v_ax0, v_ax1] = T.float32(1) + T_multiply_4[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
-            with T.block("T_multiply_2"):
+            with T.block("T_multiply_4"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                T.reads(T_add[v_ax0, v_ax1])
-                T.writes(T_multiply_3[v_ax0, v_ax1])
-                T_multiply_3[v_ax0, v_ax1] = T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
+                T.reads(T_multiply_2[v_ax0, v_ax1], T_add[v_ax0, v_ax1])
+                T.writes(T_multiply_5[v_ax0, v_ax1])
+                T_multiply_5[v_ax0, v_ax1] = T_multiply_2[v_ax0, v_ax1] * T_add[v_ax0, v_ax1]
         for i0, i1 in T.grid(T.int64(2), T.int64(3)):
             with T.block("compute"):
                 v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                T.reads(T_multiply_3[v_i0, v_i1])
+                T.reads(T_multiply_5[v_i0, v_i1])
                 T.writes(compute[v_i0, v_i1])
-                compute[v_i0, v_i1] = T.tanh(T_multiply_3[v_i0, v_i1])
+                compute[v_i0, v_i1] = T.tanh(T_multiply_5[v_i0, v_i1])
         for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
             with T.block("T_add_1"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(compute[v_ax0, v_ax1])
                 T.writes(T_add_1[v_ax0, v_ax1])
                 T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
-            with T.block("T_multiply_3"):
+            with T.block("T_multiply_5"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1])
                 T.writes(T_multiply[v_ax0, v_ax1])
@@ -1344,11 +1351,13 @@ def gelu_tanh(var_A: T.handle, var_T_multiply: T.handle):
         m, n = T.int64(), T.int64()
         A = T.match_buffer(var_A, (m, n))
         T_multiply = T.match_buffer(var_T_multiply, (m, n))
         # with T.block("root"):
         T_multiply_1 = T.alloc_buffer((m, n))
-        T_power = T.alloc_buffer((m, n))
         T_multiply_2 = T.alloc_buffer((m, n))
-        T_add = T.alloc_buffer((m, n))
         T_multiply_3 = T.alloc_buffer((m, n))
+        T_multiply_4 = T.alloc_buffer((m, n))
+        T_add = T.alloc_buffer((m, n))
+        T_multiply_5 = T.alloc_buffer((m, n))
         compute = T.alloc_buffer((m, n))
         T_add_1 = T.alloc_buffer((m, n))
         for ax0, ax1 in T.grid(m, n):
@@ -1358,43 +1367,49 @@ def gelu_tanh(var_A: T.handle, var_T_multiply: T.handle):
                 T.writes(T_multiply_1[v_ax0, v_ax1])
                 T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(m, n):
-            with T.block("T_power"):
+            with T.block("T_multiply_1"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(A[v_ax0, v_ax1])
-                T.writes(T_power[v_ax0, v_ax1])
-                T_power[v_ax0, v_ax1] = T.pow(A[v_ax0, v_ax1], T.float32(3))
+                T.writes(T_multiply_2[v_ax0, v_ax1])
+                T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * A[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(m, n):
-            with T.block("T_multiply_1"):
+            with T.block("T_multiply_2"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                T.reads(T_power[v_ax0, v_ax1])
-                T.writes(T_multiply_2[v_ax0, v_ax1])
-                T_multiply_2[v_ax0, v_ax1] = T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
+                T.reads(A[v_ax0, v_ax1])
+                T.writes(T_multiply_3[v_ax0, v_ax1])
+                T_multiply_3[v_ax0, v_ax1] = T.float32(0.044714999999999998) * A[v_ax0, v_ax1]
+        for ax0, ax1 in T.grid(m, n):
+            with T.block("T_multiply_3"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(T_multiply_3[v_ax0, v_ax1], A[v_ax0, v_ax1])
+                T.writes(T_multiply_4[v_ax0, v_ax1])
+                T_multiply_4[v_ax0, v_ax1] = T_multiply_3[v_ax0, v_ax1] * A[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(m, n):
             with T.block("T_add"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                T.reads(A[v_ax0, v_ax1], T_multiply_2[v_ax0, v_ax1])
+                T.reads(T_multiply_4[v_ax0, v_ax1])
                 T.writes(T_add[v_ax0, v_ax1])
-                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T_multiply_2[v_ax0, v_ax1]
+                T_add[v_ax0, v_ax1] = T.float32(1) + T_multiply_4[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(m, n):
-            with T.block("T_multiply_2"):
+            with T.block("T_multiply_4"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                T.reads(T_add[v_ax0, v_ax1])
-                T.writes(T_multiply_3[v_ax0, v_ax1])
-                T_multiply_3[v_ax0, v_ax1] = T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
+                T.reads(T_multiply_2[v_ax0, v_ax1], T_add[v_ax0, v_ax1])
+                T.writes(T_multiply_5[v_ax0, v_ax1])
+                T_multiply_5[v_ax0, v_ax1] = T_multiply_2[v_ax0, v_ax1] * T_add[v_ax0, v_ax1]
         for i0, i1 in T.grid(m, n):
             with T.block("compute"):
                 v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                T.reads(T_multiply_3[v_i0, v_i1])
+                T.reads(T_multiply_5[v_i0, v_i1])
                 T.writes(compute[v_i0, v_i1])
-                compute[v_i0, v_i1] = T.tanh(T_multiply_3[v_i0, v_i1])
+                compute[v_i0, v_i1] = T.tanh(T_multiply_5[v_i0, v_i1])
         for ax0, ax1 in T.grid(m, n):
             with T.block("T_add_1"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(compute[v_ax0, v_ax1])
                 T.writes(T_add_1[v_ax0, v_ax1])
                 T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1]
         for ax0, ax1 in T.grid(m, n):
-            with T.block("T_multiply_3"):
+            with T.block("T_multiply_5"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1])
                 T.writes(T_multiply[v_ax0, v_ax1])
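As a cross-check that the updated expectation still computes GELU-tanh, here is a sketch (mine, not from the test suite) replaying the expected TIR dataflow above in numpy, with names mirroring the alloc_buffers:

    import math
    import numpy as np

    A = np.random.default_rng(0).standard_normal((2, 3)).astype(np.float32)
    T_multiply_1 = np.float32(0.5) * A
    T_multiply_2 = np.float32(0.79788456080286541) * A
    T_multiply_3 = np.float32(0.044714999999999998) * A
    T_multiply_4 = T_multiply_3 * A
    T_add = np.float32(1) + T_multiply_4
    T_multiply_5 = T_multiply_2 * T_add
    compute = np.tanh(T_multiply_5)
    T_add_1 = np.float32(1) + compute
    T_multiply = T_multiply_1 * T_add_1  # final block: 0.5 * A * (1 + tanh(...))

    ref = 0.5 * A * (1 + np.tanh(math.sqrt(2 / math.pi) * (A + 0.044715 * A**3)))
    np.testing.assert_allclose(T_multiply, ref, rtol=1e-5, atol=1e-6)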
