Skip to content

Commit

Permalink
[TIR] Handle DeclBuffer in LowerCustomDatatypes (#15041)
Browse files Browse the repository at this point in the history
* [TIR] Handle DeclBuffer in LowerCustomDatatypes

Preserve DeclBuffer node when transforming with `LowerCustomDatatypes`
This is a subset of changes, being split out from
#14778 into independent portions.

* Fix lint error

* Fix parsing error introduced by lint fix
  • Loading branch information
Lunderberg committed Jun 10, 2023
1 parent eea6268 commit dee3c2a
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 85 deletions.
5 changes: 5 additions & 0 deletions src/tir/transforms/lower_custom_datatypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ class CustomDatatypesLowerer : public StmtExprMutator {
}
}

Stmt VisitStmt_(const DeclBufferNode* op) final {
auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
return VisitBufferAccess(std::move(node));
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto modified = VisitBufferAccess(node);
Expand Down
260 changes: 175 additions & 85 deletions tests/python/unittest/test_custom_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
register_op,
)
from tvm.tir.op import call_pure_extern
from tvm.script import tir as T


# note: we can't use relay.testing models because params are randomly initialized,
Expand Down Expand Up @@ -116,88 +117,105 @@ def setup_myfloat():
Own Datatypes framework.
"""

# To use datatype operations in an external library, you should first load
# the library containing the datatype implementation:
# CDLL("libposit.so", RTLD_GLOBAL)
# In this case, the datatype library we are using is built right into TVM,
# so we do not need to explicitly load any library.
def _setup_myfloat_inner():
# To use datatype operations in an external library, you should first load
# the library containing the datatype implementation:
# CDLL("libposit.so", RTLD_GLOBAL)
# In this case, the datatype library we are using is built right into TVM,
# so we do not need to explicitly load any library.

# You can pick a code for your datatype arbitrarily, as long as it is
# greater than 128 and has not already been chosen.
register("myfloat", 131)
# You can pick a code for your datatype arbitrarily, as long as it is
# greater than 128 and has not already been chosen.
register("myfloat", 131)

register_op(
create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm", "float", "myfloat"
)
register_op(
create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm", "myfloat", "float"
)
register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm", "myfloat")
register_op(
create_lower_func(
{
32: "Custom32Sub",
}
),
"Sub",
"llvm",
"myfloat",
)
register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm", "myfloat")
register_op(
create_lower_func(
{
32: "FloatToCustom32",
}
),
"FloatImm",
"llvm",
"myfloat",
)
register_op(
create_lower_func(
{
32: "Custom32Div",
}
),
"Div",
"llvm",
"myfloat",
)
register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm", "myfloat")
register_op(
create_lower_func({32: "Custom32Sqrt"}),
"Call",
"llvm",
"myfloat",
intrinsic_name="tir.sqrt",
)
register_op(
create_lower_func({32: "Custom32Exp"}), "Call", "llvm", "myfloat", intrinsic_name="tir.exp"
)
register_op(
create_lower_func({32: "Custom32Log"}), "Call", "llvm", "myfloat", intrinsic_name="tir.log"
)
register_op(
create_lower_func({32: "Custom32Sigmoid"}),
"Call",
"llvm",
"myfloat",
intrinsic_name="tir.sigmoid",
)
register_op(
create_lower_func({32: "Custom32Tanh"}),
"Call",
"llvm",
"myfloat",
intrinsic_name="tir.tanh",
)
register_op(lower_ite, "Call", "llvm", "myfloat", intrinsic_name="tir.if_then_else")
register_op(
lower_call_pure_extern, "Call", "llvm", "myfloat", intrinsic_name="tir.call_pure_extern"
)
register_op(
create_lower_func({(32, 32): "FloatToCustom32"}), "Cast", "llvm", "float", "myfloat"
)
register_op(
create_lower_func({(32, 32): "Custom32ToFloat"}), "Cast", "llvm", "myfloat", "float"
)
register_op(create_lower_func({32: "Custom32Add"}), "Add", "llvm", "myfloat")
register_op(
create_lower_func(
{
32: "Custom32Sub",
}
),
"Sub",
"llvm",
"myfloat",
)
register_op(create_lower_func({32: "Custom32Mul"}), "Mul", "llvm", "myfloat")
register_op(
create_lower_func(
{
32: "FloatToCustom32",
}
),
"FloatImm",
"llvm",
"myfloat",
)
register_op(
create_lower_func(
{
32: "Custom32Div",
}
),
"Div",
"llvm",
"myfloat",
)
register_op(create_lower_func({32: "Custom32Max"}), "Max", "llvm", "myfloat")
register_op(
create_lower_func({32: "Custom32Sqrt"}),
"Call",
"llvm",
"myfloat",
intrinsic_name="tir.sqrt",
)
register_op(
create_lower_func({32: "Custom32Exp"}),
"Call",
"llvm",
"myfloat",
intrinsic_name="tir.exp",
)
register_op(
create_lower_func({32: "Custom32Log"}),
"Call",
"llvm",
"myfloat",
intrinsic_name="tir.log",
)
register_op(
create_lower_func({32: "Custom32Sigmoid"}),
"Call",
"llvm",
"myfloat",
intrinsic_name="tir.sigmoid",
)
register_op(
create_lower_func({32: "Custom32Tanh"}),
"Call",
"llvm",
"myfloat",
intrinsic_name="tir.tanh",
)
register_op(lower_ite, "Call", "llvm", "myfloat", intrinsic_name="tir.if_then_else")
register_op(
lower_call_pure_extern, "Call", "llvm", "myfloat", intrinsic_name="tir.call_pure_extern"
)

register_min_func(create_min_lower_func({32: "MinCustom32"}, "myfloat"), "myfloat")

register_min_func(create_min_lower_func({32: "MinCustom32"}, "myfloat"), "myfloat")
try:
_setup_myfloat_inner()
except tvm._ffi.base.TVMError as e:
# Ignore this specific error which can happen if another test
# that uses "myfloat" has already run.
if "float is already registered" not in str(e):
raise e


def setup_posites2():
Expand Down Expand Up @@ -513,12 +531,8 @@ def run_batchnorm(src_dtype, dst_dtype, rtol=1e-6, atol=1e-6):


def test_myfloat():
try:
setup_myfloat()
except tvm._ffi.base.TVMError as e:
if "float is already registered" not in str(e):
# Ignore this specific error which can happen if this test runs twice within the same process
raise e
setup_myfloat()

run_ops("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
run_conv2d("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
run_batchnorm("float32", "custom[myfloat]32", rtol=1e-6, atol=1e-6)
Expand All @@ -529,6 +543,82 @@ def test_myfloat():
# 'custom[myfloat]32')


class TestMyfloatLowering(tvm.testing.CompareBeforeAfter):
setup_myfloat()

transform = tvm.tir.transform.LowerCustomDatatypes()

def before(self):
dtype = "custom[myfloat]32"

@T.prim_func
def func(A_data: T.handle(dtype)):
T.func_attr({"target": T.target("llvm")})
A = T.Buffer(16, dtype=dtype, data=A_data)
B_data = T.allocate([16], dtype=dtype)
B = T.Buffer(16, dtype=dtype, data=B_data)
for i in range(16):
B[i] = A[i] + 1.0

return func

def expected(self):
dtype = "custom[myfloat]32"

@T.prim_func
def func(A_data: T.handle(dtype)):
T.func_attr({"target": T.target("llvm")})
A_uint32 = T.Buffer(16, "uint32", data=A_data)
B_data = T.allocate([16], dtype="uint32")
B_uint32 = T.Buffer(16, "uint32", data=B_data)
for i in range(16):
B_uint32[i] = T.call_pure_extern(
"uint32",
"FloatToCustom32",
T.call_pure_extern("float32", "Custom32ToFloat", A_uint32[i]) + T.float32(1),
)

return func


class TestMyfloatLoweringDeclBuffer(tvm.testing.CompareBeforeAfter):
"""Like TestMyfloatLoweringDeclBuffer, but using DeclBuffer"""

setup_myfloat()

transform = tvm.tir.transform.LowerCustomDatatypes()

def before(self):
dtype = "custom[myfloat]32"

@T.prim_func
def func(A_data: T.handle(dtype)):
T.func_attr({"target": T.target("llvm")})
A = T.decl_buffer(16, dtype=dtype, data=A_data)
B = T.decl_buffer(16, dtype=dtype)
for i in range(16):
B[i] = A[i] + 1.0

return func

def expected(self):
dtype = "custom[myfloat]32"

@T.prim_func
def func(A_data: T.handle(dtype)):
T.func_attr({"target": T.target("llvm")})
A_uint32 = T.decl_buffer(16, "uint32", data=A_data)
B_uint32 = T.decl_buffer(16, dtype="uint32")
for i in range(16):
B_uint32[i] = T.call_pure_extern(
"uint32",
"FloatToCustom32",
T.call_pure_extern("float32", "Custom32ToFloat", A_uint32[i]) + T.float32(1),
)

return func


def _has_posit():
return tvm.support.libinfo()["USE_BYODT_POSIT"] == "ON"

Expand Down

0 comments on commit dee3c2a

Please sign in to comment.