-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Closed
Description
I have an ONNX model that I want to run with my own custom data type.
Following the approach in bring_your_own_datatypes.py, I wrote the following code:
import numpy as np
import tvm
from tvm.contrib import graph_runtime
import os
import onnx
from tvm import relay
import ctypes
from tvm.relay.frontend.change_datatype import ChangeDatatype
# Load the shared library that implements float_st. RTLD_GLOBAL makes its
# symbols visible process-wide (presumably so the extern functions named in
# the lowering tables below can be resolved at run time -- confirm).
ctypes.CDLL("./float_st.so", mode=ctypes.RTLD_GLOBAL)
# Register the custom type name with TVM under type code 150.
tvm.target.datatype.register("float_st", 150)
ctx = tvm.cpu()

# Load the example ONNX model from disk.
onnx_model = onnx.load("./model.onnx")
# Convert to Relay; the frontend needs the model plus a mapping from the
# input layer name to its shape.
shape_dict = {"input_1": (100, 128, 3)}
module, params = relay.frontend.from_onnx(onnx_model, shape=shape_dict)
# Graph executor bound to the imported module; used for the final run.
ex = relay.create_executor("graph", mod=module)
def convert_ndarray(dst_dtype, array):
    """Cast an NDArray to ``dst_dtype`` by running a one-op Relay program.

    A Relay function is built whose single input has the shape and dtype of
    ``array`` and whose body is ``astype(dst_dtype)``; it is then evaluated
    with a graph executor.

    Parameters
    ----------
    dst_dtype : str
        Target datatype string handed to ``astype``.
    array
        Value to convert; must expose ``shape`` and ``dtype``.

    Returns
    -------
    The value produced by evaluating the cast function on ``array``.
    """
    source = relay.var("x", shape=array.shape, dtype=str(array.dtype))
    cast_fn = relay.Function([source], source.astype(dst_dtype))
    # Build and run with vectorization switched off, matching the final
    # model run elsewhere in this script.
    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
        return relay.create_executor("graph").evaluate(cast_fn)(array)
# Register the two Cast lowerings between "float" and "float_st" for the llvm
# target. Each table maps (source bits, destination bits) to the name of the
# extern function that performs the conversion.
for extern_table, cast_src, cast_dst in (
    ({(32, 128): "FloatToFloatst"}, "float", "float_st"),
    ({(128, 32): "FloatstToFloat"}, "float_st", "float"),
):
    tvm.target.datatype.register_op(
        tvm.target.datatype.create_lower_func(extern_table),
        "Cast",
        "llvm",
        cast_src,
        cast_dst,
    )
src_dtype = "float32"
dst_dtype = "custom[float_st]128"

module = relay.transform.InferType()(module)
# Currently, custom datatypes only work if you run simplify_inference beforehand
module = tvm.relay.transform.SimplifyInference()(module)
# Run type inference before changing datatype
module = tvm.relay.transform.InferType()(module)
# Change datatype from float to float_st and re-infer types
cdtype = ChangeDatatype(src_dtype, dst_dtype)
expr = cdtype.visit(module["main"])
# NOTE(review): the report accompanying this script observes that the
# rewritten expr still contains zeros(shape=[100, 128], dtype="float32")
# calls, which then fail with a dtype mismatch in BroadcastRel. Presumably
# the dtype held in those ops' attributes is not rewritten here -- confirm
# against the ChangeDatatype implementation.
module = tvm.relay.transform.InferType()(module)

# We need to convert our input. The array is named input_data so that it
# does not shadow the builtin input().
data_shape = [100, 128, 3]
input_data = np.random.uniform(size=data_shape).astype(src_dtype)
input_st = convert_ndarray(dst_dtype, input_data)
# We also convert the parameters:
params_st = {
    name: convert_ndarray(dst_dtype, value) for name, value in params.items()
}
# Register every lowering rule the model needs for float_st on the llvm
# target. Each row is (op name, lowering, intrinsic name):
#   * a string lowering is the extern function to call for the 128-bit type
#     and is wrapped with create_lower_func;
#   * a callable lowering is one of TVM's ready-made lowering functions and
#     is registered as-is;
#   * the intrinsic name is only set for "Call" ops.
# Rows are registered in the order listed.
_datatype = tvm.target.datatype
_float_st_lowerings = (
    ("FloatImm", "FloatToFloatst", None),
    ("Call", _datatype.lower_ite, "tir.if_then_else"),
    ("Call", _datatype.lower_call_pure_extern, "tir.call_pure_extern"),
    ("Mul", "FloatstMul", None),
    ("Div", "FloatstDiv", None),
    ("Call", "FloatstSqrt", "tir.sqrt"),
    ("Sub", "FloatstSub", None),
    ("Call", "FloatstExp", "tir.exp"),
    ("Max", "FloatstMax", None),
)
for _op_name, _lowering, _intrinsic in _float_st_lowerings:
    if isinstance(_lowering, str):
        _lowering = _datatype.create_lower_func({128: _lowering})
    _datatype.register_op(
        _lowering,
        _op_name,
        "llvm",
        "float_st",
        intrinsic_name=_intrinsic,
    )

# Register the extern function that supplies the minimum float_st value.
_datatype.register_min_func(
    _datatype.create_min_lower_func({128: "MinFloatst"}, "float_st"),
    "float_st",
)
# Vectorization is not implemented with custom datatypes, so the model is
# built and run with it disabled.
with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
    # Evaluate the datatype-changed function on the converted input and
    # parameters.
    result_myfloat = ex.evaluate(expr)(input_st, **params_st)
    # Cast the result back to float32 and copy it into a NumPy array so it
    # can be printed.
    result_myfloat = convert_ndarray(src_dtype, result_myfloat).asnumpy()
    print(result_myfloat)
However, when I try to run it, I get the following error:
data types custom[float_st]128 and float32do not match in BroadcastRel
data types custom[float_st]128 and float32do not match in BroadcastRel
After changing the model's data type from float to the custom data type with the line expr = cdtype.visit(module["main"]), I printed expr and noticed that two instructions in the model's tree still have the float32 data type:
%387 = zeros(shape=[100, 128], dtype="float32");
%402 = zeros(shape=[100, 128], dtype="float32");
I guess the error is related to these zeros, but I don't know how to change their data type to my custom data type. Any ideas?
I have uploaded the ONNX model as well as float_st.so here so that the error can be reproduced.
Thanks in advance.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels