[Unity][Frontend] NNModule tensor_ir_op support#16278
Conversation
This PR adds support for `tensor_ir_op` in NNModule, which enables us to call TensorIR function in NNModule. Also this PR adds a test case for extern op.
MasterJH5574
left a comment
There was a problem hiding this comment.
LGTM. Thank you @Hzfengsy!
|
|
||
| def tensor_ir_op( | ||
| func: _tir.PrimFunc, | ||
| name_hint: str, |
There was a problem hiding this comment.
There’s a bit of complication here: if the PrimFunc provided is a public function (has “global_symbol” field in its attrs), Relax is not allowed to rename it, and in this case, it’s not a name hint but a name instead. Therefore, we will have to check symbol duplication and potentially throw an error if it happens.
There was a problem hiding this comment.
We could probably leave this logic to future work, but let’s rename name_hint to name to better reflect this point
There was a problem hiding this comment.
I agree and thanks for pointing it out. However, the current Python interface AddFunction also treats it as name_hint, which may be renamed if conflicts exist.
It would be an independent problem out of the scope of this PR.
|
|
||
|
|
||
| def test_tensor_ir_op(): | ||
| num_q_heads, num_kv_heads, head_dim = 8, 8, 16 |
There was a problem hiding this comment.
This unittest is a bit more complicated than I expected :)) in the simplest case, we could probably just supply a “B = A + 1”-style TIR
junrushao
left a comment
There was a problem hiding this comment.
Merging this in for now, and please follow up with my comments in subsequent PRs
This PR adds support for
tensor_ir_opin NNModule, which enables us to call TensorIR function in NNModule.Also this PR adds a test case for extern op.
cc @junrushao @MasterJH5574