Skip to content

[Unity][Frontend] NNModule tensor_ir_op support#16278

Merged
junrushao merged 1 commit intoapache:unityfrom
Hzfengsy:tensor_ir_op
Dec 26, 2023
Merged

[Unity][Frontend] NNModule tensor_ir_op support#16278
junrushao merged 1 commit intoapache:unityfrom
Hzfengsy:tensor_ir_op

Conversation

@Hzfengsy
Copy link
Copy Markdown
Member

@Hzfengsy Hzfengsy commented Dec 26, 2023

This PR adds support for tensor_ir_op in NNModule, which enables us to call TensorIR function in NNModule.

Also this PR adds a test case for extern op.

cc @junrushao @MasterJH5574

This PR adds support for `tensor_ir_op` in NNModule, which enables us to
call TensorIR function in NNModule.

Also this PR adds a test case for extern op.
Copy link
Copy Markdown
Contributor

@MasterJH5574 MasterJH5574 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you @Hzfengsy!


def tensor_ir_op(
func: _tir.PrimFunc,
name_hint: str,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There’s a bit of complication here: if the PrimFunc provided is a public function (has “global_symbol” field in its attrs), Relax is not allowed to rename it, and in this case, it’s not a name hint but a name instead. Therefore, we will have to check symbol duplication and potentially throw an error if it happens.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could probably leave this logic to future work, but let’s rename name_hint to name to better reflect this point

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree and thanks for pointing it out. However, the current Python interface AddFunction also treats it as name_hint, which may be renamed if conflicts exist.

It would be an independent problem out of the scope of this PR.



def test_tensor_ir_op():
num_q_heads, num_kv_heads, head_dim = 8, 8, 16
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This unittest is a bit more complicated than I expected :)) in the simplest case, we could probably just supply a “B = A + 1”-style TIR

Copy link
Copy Markdown
Member

@junrushao junrushao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merging this in for now, and please follow up with my comments in subsequent PRs

@junrushao junrushao merged commit 889d2f6 into apache:unity Dec 26, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants