Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added matmul_v2+transpose+reshape fuse pass #36481

Merged

Conversation

jakpiase
Copy link
Contributor

PR types

New features

PR changes

OPs

Describe

Added matmul_v2+transpose+reshape fuse pass(same behavior as matmul+transpose+reshape fuse pass). It is used almost only in BERT-like models but was requested in #36461 as highest priority for 2.2 release

Copy link
Contributor

@lidanqing-intel lidanqing-intel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I haven't finish all reviews. Thank you very much for your work !

Copy link
Contributor

@arlesniak arlesniak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the contribution!

Maybe you could make some additional fuse tests like it's done in python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py:429 ?

Copy link
Contributor

@lidanqing-intel lidanqing-intel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you for your contribution Jakub !

Copy link
Contributor

@arlesniak arlesniak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@lidanqing-intel lidanqing-intel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jakpiase
Copy link
Contributor Author

@jczaja please merge this PR

@jczaja jczaja merged commit 856cb9c into PaddlePaddle:develop Oct 21, 2021
@jczaja
Copy link
Contributor

jczaja commented Oct 21, 2021

@jakpiase It is my pleasure to merge your another impactful contribution:)

def set_op_type(self):
self.op_type = "matmul_v2"


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This UT seems inheritate tests of matmul_v1 and did not test matmul_v2 broadcasting with transpose, reshape fuses. So, should matmul_v2 broadcasting case be detected at graph pattern detecting stage and excluded in fuses ? Cause it is not tested.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, I have recently tested it locally with broadcasting(and it works), from what I know this fuse pass is for Ernie model, and there were no broadcasting and I wanted to finish that PR as soon as possible, that's why I have not included that here, if you want I can add these additional broadcasting tests now

@baoachun
Copy link
Contributor

baoachun commented Dec 2, 2021

Hi @jakpiase @lidanqing-intel , the matmul_v2_transpose_reshape_fuse_pass will get MKLDNNDeviceContext error when turning on GLOG_v, chould you please check it out? You can refer this pr #37416 to reproduce the problem.
图片

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants