Skip to content

Commit

Permalink
Use better function name for te_lowering and annotate current target …
Browse files Browse the repository at this point in the history
…at TE functions
  • Loading branch information
cgerum authored and MichaelJKlaiber committed Aug 9, 2022
1 parent 0418ad8 commit 0f0b1bf
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions python/tvm/relay/backend/contrib/uma/api/lower.py
Expand Up @@ -82,15 +82,20 @@ def _get_tensors(te_cached_func):

return args + outputs

f = tvm._ffi.get_global_func("relay.backend.LowerToTE")
te_cached_func = f(relay_prim_func)
lower_to_te = tvm._ffi.get_global_func("relay.backend.LowerToTE")
te_cached_func = lower_to_te(relay_prim_func)
x = _get_tensors(te_cached_func)
tir_prim_func = te.create_prim_func(x)
tir_prim_func = tir_prim_func.with_attr(
"global_symbol", relay_prim_func.attrs["global_symbol"]
)
# TODO: The target should probably come from somewhere else instead of being created here.
tir_prim_func = tir_prim_func.with_attr("target", tvm.target.Target(self.target_name))

compiler_attr = relay_prim_func.attrs["Compiler"]
target = tvm.target.Target.current()
if target.kind.name != compiler_attr:
target = tvm.target.Target(compiler_attr)

tir_prim_func = tir_prim_func.with_attr("target", target)
tir_prim_func = tir_prim_func.with_attr("relay_attrs", relay_prim_func.attrs)
return tir_prim_func

Expand Down

0 comments on commit 0f0b1bf

Please sign in to comment.