-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multihead scaled dot product attention. #7791
Conversation
7442fe0
to
4a24f76
Compare
9c550a9
to
c0ac68b
Compare
c0ac68b
to
90f334e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
90f334e
to
d163592
Compare
dc11205
to
738cc13
Compare
f4e5bd0
to
d6f2d79
Compare
d6f2d79
to
d00eb53
Compare
python/paddle/v2/fluid/layers/nn.py
Outdated
if len(x_shape) == 1: | ||
x_shape = [1] + x_shape | ||
if len(y_shape) == 1: | ||
y_shape = [1] + y_shape |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the rank of y
is 1, it is treated as [D, 1]
in nontransposed form.
python/paddle/v2/fluid/nets.py
Outdated
[bs, max_sequence_length, num_heads * hidden_dim]. | ||
""" | ||
|
||
if len(x.shape) == 3: return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might it be return x
here.
doc/api/v2/fluid/nets.rst
Outdated
dot_product_attention | ||
--------------------- | ||
scaled_dot_product_attention | ||
---------------------------- | ||
.. autofunction:: paddle.v2.fluid.nets.dot_product_attention |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might it be paddle.v2.fluid.nets.scaled_dot_product_attention
here.
No description provided.