add dot-product attention #4674
Conversation
Almost LGTM.
@@ -1396,6 +1398,85 @@ def simple_attention(encoded_sequence,
        input=scaled, pooling_type=SumPooling(), name="%s_pooling" % name)


@wrap_name_default()
def dot_product_attention(encoded_sequence,
                          attending_sequence,
attending_sequence --> attended_sequence
Done
                          name=None):
    """
    Calculate and return a context vector with dot-product attention mechanism.
    Size of the context vector equals to size of the attending_sequence.
The dimension of the context vector equals the dimension of the attended sequence.
Done
        c_{i} & = \\sum_{j=1}^{T_{x}}a_{i,j}z_{j}

    where :math:`h_{j}` is the jth element of encoded_sequence,
    :math:`z_{j}` is the jth element of attending_sequence,
attended sequence
Done
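For reference, a minimal NumPy sketch of what the docstring's formula describes (shapes and variable names here are illustrative, not part of the patch): attention scores are dot products between the transformed decoder state and each element of encoded_sequence, the scores are normalized with a softmax, and the context vector is the weighted sum over the attended sequence.

```python
import numpy as np

# Illustrative shapes: T_x encoder steps, hidden size d (not taken from the PR).
T_x, d = 5, 8
encoded_sequence = np.random.randn(T_x, d)   # h_1 ... h_{T_x}
attended_sequence = np.random.randn(T_x, d)  # z_1 ... z_{T_x}
transformed_state = np.random.randn(d)       # s, same size as each h_j

# e_j = s . h_j : one dot-product score per encoder position
scores = encoded_sequence @ transformed_state

# a_j = softmax_j(e_j)
weights = np.exp(scores - scores.max())
weights /= weights.sum()

# c = sum_j a_j z_j : the context vector, same size as the attended sequence
context = weights @ attended_sequence
```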
    .. code-block:: python

        context = dot_product_attention(encoded_sequence=enc_seq,
                                        attending_sequence=att_seq,
attending_sequence --> attended_sequence
Done
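With the rename applied, the usage example in the docstring presumably reads along these lines (a sketch; only the keyword argument changes):

```python
context = dot_product_attention(encoded_sequence=enc_seq,
                                attended_sequence=att_seq,
                                transformed_state=state)
```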
                                        attending_sequence=att_seq,
                                        transformed_state=state,)

    :param name: name of the dot-product attention model.
A prefix attached to the name of each layer that is defined inside dot_product_attention.
Done
    :type softmax_param_attr: ParameterAttribute
    :param encoded_sequence: output of the encoder
    :type encoded_sequence: LayerOutput
    :param attending_sequence: attention weight is computed by a feed forward neural
The attention weight ...
Done
                               hidden state of previous time step and encoder's output.
                               attending_sequence is the sequence to be attended.
    :type attending_sequence: LayerOutput
    :param transformed_state: transformed hidden state of decoder in previous time step,
- The transformed ...
- Is "transformed hidden state" the commonly accepted term used in the original paper?
I use this just for flexibility considerations.
                               attending_sequence is the sequence to be attended.
    :type attending_sequence: LayerOutput
    :param transformed_state: transformed hidden state of decoder in previous time step,
                              its size should equal to encoded_sequence's. Here we do the
whose dimension should be equal to encoded_sequence's dimension. Or use a period at the end of the last sentence and change "its" into "Its".
Done
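To illustrate the "transformation outside for flexibility" point: the caller is expected to project the decoder's previous hidden state to encoded_sequence's size before passing it in, so the choice of projection stays up to the caller. A rough sketch with the v1 config helpers (the fc_layer wiring, layer names, and sizes below are assumptions, not part of this patch):

```python
from paddle.trainer_config_helpers import *

# Hypothetical sizes and layer names, used only to make the sketch concrete.
encoder_size = 512
enc_seq = data_layer(name="encoded_sequence", size=encoder_size)
att_seq = data_layer(name="attended_sequence", size=256)
decoder_prev = data_layer(name="decoder_prev_state", size=128)

# Project the previous decoder state to encoded_sequence's size outside the
# helper; the projection itself is left to the caller for flexibility.
transformed = fc_layer(input=decoder_prev,
                       size=encoder_size,
                       act=LinearActivation(),
                       name="transform_decoder_state")

context = dot_product_attention(encoded_sequence=enc_seq,
                                attended_sequence=att_seq,
                                transformed_state=transformed)
```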
                              transformation outside dot_product_attention for flexibility
                              consideration.
    :type transformed_state: LayerOutput
    :return: a context vector
The context vector.
Done
    :return: a context vector
    :rtype: LayerOutput
    """
    assert transformed_state.size == encoded_sequence.size
Please leave a message to explain the check.
Done
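For reference, a sketch of how the check might read with a message attached (the exact wording in the updated patch may differ):

```python
# Hypothetical wording; the merged change may phrase this differently.
assert transformed_state.size == encoded_sequence.size, (
    "the dimension of transformed_state must equal that of encoded_sequence, "
    "because attention scores are dot products between the two")
```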
LGTM.
@lcy-seso