Skip to content

BYODT- Error: my custom data type and float32 do not match in BroadcastRel #7811

@sahooora

Description

@sahooora

I‌ have an onnx model that I want to run with my own data type.

So I got the idea from bring_your_own_datatypes.py and developed the following code:

import numpy as np
import tvm
from tvm.contrib import graph_runtime
import os
import onnx
from tvm import relay
import ctypes
from tvm.relay.frontend.change_datatype import ChangeDatatype


#register the custom type with TVM
ctypes.CDLL('./float_st.so', ctypes.RTLD_GLOBAL)
tvm.target.datatype.register("float_st", 150)

ctx = tvm.cpu()

# load example onnx model 
onnx_model = onnx.load('./model.onnx')

# convert to relay, needs the onnx model and input layer name and shape
module, params = relay.frontend.from_onnx( onnx_model, {"input_1": (100,128,3)} )


ex = tvm.relay.create_executor("graph", mod=module)


def convert_ndarray(dst_dtype, array):
    """Converts an NDArray into the specified datatype"""
    x = relay.var("x", shape=array.shape, dtype=str(array.dtype))
    cast = relay.Function([x], x.astype(dst_dtype))
    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
        a = relay.create_executor("graph").evaluate(cast)(array)
        return a


tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func(
        {
            (32, 128): "FloatToFloatst", 
        }
    ),
    "Cast",
    "llvm",
    "float",
    "float_st",
)

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({(128, 32): "FloatstToFloat"}),
    "Cast",
    "llvm",
    "float_st",
    "float",
)



src_dtype = "float32"
dst_dtype = "custom[float_st]128"


module = relay.transform.InferType()(module)

# Currently, custom datatypes only work if you run simplify_inference beforehand
module = tvm.relay.transform.SimplifyInference()(module)

# Run type inference before changing datatype
module = tvm.relay.transform.InferType()(module)

# Change datatype from float to float_st and re-infer types
cdtype = ChangeDatatype(src_dtype, dst_dtype)
expr = cdtype.visit(module["main"])
module = tvm.relay.transform.InferType()(module)

# We need to convert our input:
data_shape = [100,128,3]
input = np.random.uniform(size=data_shape).astype('float32')
input_st = convert_ndarray(dst_dtype, input)

# We also convert the parameters:
params_st = {k: convert_ndarray(dst_dtype, v) for k, v in params.items()}


#register all the needed functions:
tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({128: "FloatToFloatst"}),
    "FloatImm",
    "llvm",
    "float_st",
)

tvm.target.datatype.register_op(
    tvm.target.datatype.lower_ite, "Call", "llvm", "float_st", intrinsic_name="tir.if_then_else"
)

tvm.target.datatype.register_op(
    tvm.target.datatype.lower_call_pure_extern,
    "Call",
    "llvm",
    "float_st",
    intrinsic_name="tir.call_pure_extern",
)

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({128: "FloatstMul"}),
    "Mul",
    "llvm",
    "float_st",
)
tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({128: "FloatstDiv"}),
    "Div",
    "llvm",
    "float_st",
)

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({128: "FloatstSqrt"}),
    "Call",
    "llvm",
    "float_st",
    intrinsic_name="tir.sqrt",
)

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({128: "FloatstSub"}),
    "Sub",
    "llvm",
    "float_st",
)

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({128: "FloatstExp"}),
    "Call",
    "llvm",
    "float_st",
    intrinsic_name="tir.exp",
)

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({128: "FloatstMax"}),
    "Max",
    "llvm",
    "float_st",
)

tvm.target.datatype.register_min_func(
    tvm.target.datatype.create_min_lower_func({128: "MinFloatst"}, "float_st"),
    "float_st",
)



# Vectorization is not implemented with custom datatypes.
with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
    result_myfloat = ex.evaluate(expr)(input_st, **params_st)
    result_myfloat = convert_ndarray(src_dtype, result_myfloat).asnumpy()
    print(result_myfloat)

but when I try to run it I get the following error:

data types custom[float_st]128 and float32do not match in BroadcastRel
data types custom[float_st]128 and float32do not match in BroadcastRel

After changing the model's data type form float to custom data type using line expr = cdtype.visit(module["main"]) with printing expr I noticed that there are still two instructions in the model's tree which have float32 data type:

%387 = zeros(shape=[100, 128], dtype="float32");
%402 = zeros(shape=[100, 128], dtype="float32");

I guess the error is related to these zeros but I don't know how can change their data type to my custom data type. Any idea?

I uploaded the onnx model as well as float_st.so here for reproducing the error

Thanka in advance

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions