Sockeye 2 Interleaved Multi-head Attention Operators #814
Conversation
Force-pushed from 45ef7b5 to e2f54a2
Thanks! This should be mergeable once MXNet 1.7 is out, right? Would you expect perplexity scores for training to be exactly the same with and without this change?
Thanks @blchu!
I tried this change together with the latest nightly build of MXNet (mxnet==1.7.0b20200527) and I am still getting slightly different perplexity values than the current sockeye_2 master.
I see that you merged the input projections of keys and values in the MultiHeadAttention
class, but even when undoing this the small difference remains. Do you know why?
Hm, depending on what you're testing, my best guess is that the weights may be initialized slightly differently by the Xavier initializer, since the key and value weights are now a single array. Could you clarify what you mean by undoing the merged input projections (i.e. what changes did you make to do that)? In my own tests I didn't observe any negative impact on training after my fix, but I can check again to make sure.
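(To illustrate the initializer point: Xavier scales its samples by the shape of each parameter array, so fusing the key and value projections into one array changes the sampled range. The array names and sizes below are made up for the sketch.)

```python
import mxnet as mx

# Xavier's scale depends on the fan-in/fan-out of each individual array, so a
# separate (32, 32) key weight and a fused (64, 32) key/value weight are drawn
# from different ranges even with the same seed. Hypothetical sizes.
init = mx.init.Xavier()
w_k = mx.nd.zeros((32, 32))      # separate key projection
w_kv = mx.nd.zeros((64, 32))     # fused key/value projection
init(mx.init.InitDesc('k_weight'), w_k)
init(mx.init.InitDesc('kv_weight'), w_kv)
print(w_k.abs().max(), w_kv.abs().max())  # different magnitudes
```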
Here's a commit that still gives me different perplexity values when compared to the sockeye_2 branch: 1676d6a
Looking at the changes in your commit, it looks like the discrepancy is likely still there because of the way that you've combined the weights together. The weights need to be interleaved head by head rather than simply concatenated. You can reference the code changes in sockeye/inference.py from 6b25ad5 to see how the weights would need to be combined.
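(For illustration only: a small NumPy sketch of the difference between plain concatenation and head-wise interleaving of the key/value projection weights. The names and sizes are hypothetical; the actual layout Sockeye expects is the one in the sockeye/inference.py changes referenced above.)

```python
import numpy as np

heads, head_dim, model_dim = 4, 8, 32       # hypothetical sizes

w_k = np.random.randn(heads * head_dim, model_dim)   # key projection
w_v = np.random.randn(heads * head_dim, model_dim)   # value projection

# Plain concatenation: all key rows, then all value rows.
w_concat = np.concatenate([w_k, w_v], axis=0)

# Head-wise interleaving: for each head, its key rows followed by its value rows.
w_k_heads = w_k.reshape(heads, head_dim, model_dim)
w_v_heads = w_v.reshape(heads, head_dim, model_dim)
w_interleaved = np.concatenate([w_k_heads, w_v_heads], axis=1).reshape(-1, model_dim)

# Same rows, different order: projections computed with w_concat will not
# line up with what an interleaved layout expects.
print(np.array_equal(w_concat, w_interleaved))  # False in general
```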
Right, thanks for the clarification. I will try interleaving them instead of concatenating.
I tried testing the perplexity values myself. As it turns out, since the model now transposes the batch and time dimensions, not only do the attention input arrays need to be interleaved together if the keys and values are initialized separately, but they also need to be transposed for the dropout and layer norms to produce the same results. Accounting for all of this, I'm observing that the array values largely align with the main sockeye_2 branch. I've pushed the changes I made to run the tests to a branch on my fork here: https://github.com/blchu/sockeye/commits/sockeye_2_test_interleaved_perplex
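(A rough sketch of why the transposition matters when comparing the two implementations: with a fixed seed, dropout draws its mask per array position, so a batch-major run and a time-major run only line up if the layouts are matched first. Shapes below are hypothetical.)

```python
import mxnet as mx

mx.random.seed(0)
x = mx.nd.ones((2, 10, 32))                              # (batch, time, dim)
y_bm = mx.nd.Dropout(x, p=0.5, mode='always')

mx.random.seed(0)
x_tm = mx.nd.transpose(x, axes=(1, 0, 2))                # (time, batch, dim)
y_tm = mx.nd.Dropout(x_tm, p=0.5, mode='always')

# Even with the same seed, the masks land on different (batch, time) positions,
# so transposing the time-major output back does not reproduce the batch-major one.
print(mx.nd.norm(y_bm - mx.nd.transpose(y_tm, axes=(1, 0, 2))))  # generally non-zero
```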
Force-pushed from e2f54a2 to 975c041
Now that we have an MXNet 1.7 branch, we can update/review this PR and merge it into that branch so it will be immediately usable for anyone who wants to manually install an MXNet pre-release. When MXNet 1.7 wheels are available, we can merge that branch into master.
It looks like there are some merge conflicts between this PR and the SSRU updates. @barbaradarques and @blchu, as the experts on these parts of the code, can you help us find the best way to resolve the conflicts?
Looking over the SSRU PR, there is significant overlap since both modify the decoder autoregressive layer. Most of the conflicts should be relatively straightforward to resolve, but one of my changes transposes the batch and time dimensions for the duration of the model to improve attention performance. @barbaradarques, is that compatible with the SSRU layer, or can it be made compatible? Otherwise I'll look for some workarounds.
Actually, I think there's some positive synergy between the changes, as the SSRU code currently transposes to time-major temporarily in order to use the foreach operator: https://github.com/awslabs/sockeye/blob/master/sockeye/layers.py#L893
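(Minimal sketch of the pattern being referenced, with a stand-in cell instead of the real SSRU recurrence: foreach scans over the first axis, so the input has to be time-major.)

```python
import mxnet as mx

def step(x_t, states):
    new_state = states[0] + x_t      # stand-in for the real SSRU recurrence
    return new_state, [new_state]

batch, time, dim = 2, 3, 4
inputs = mx.nd.random.uniform(shape=(batch, time, dim))
inputs_tm = mx.nd.transpose(inputs, axes=(1, 0, 2))          # (time, batch, dim)
outputs, final_states = mx.nd.contrib.foreach(step, inputs_tm,
                                              [mx.nd.zeros((batch, dim))])
# outputs: (time, batch, dim); with a fully time-major model, the
# transpose/re-transpose around the scan is no longer needed.
```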
Hi @blchu, as @tdomhan mentioned, the transposition shouldn't be a problem for SSRU layers. On top of that, it would let us remove the internal transposes the SSRU currently does.
We've officially updated Sockeye to MXNet 1.7, so we can merge this into master as soon as the conflicts are resolved.
I'm currently working on resolving the merge conflicts. Right now I have all tests passing; however, I'm seeing some strange performance and perplexity numbers in training, so I'm trying to get that resolved.
@blchu, feel free to commit your WIP; maybe we can help figure out the inconsistencies.
With further testing, I did notice that while perplexity was noticeably worse during training than what I had measured before, this is also the case with the mx17 branch without any of the changes from this PR, so it may be due to something entirely unrelated. The BLEU scores at each training checkpoint were no different. There's still an issue with inference during a training checkpoint when using an SSRU decoder, though I think I've narrowed the cause to a different state shape than in pure inference ([batch, dim] vs the expected [time, batch, dim]).
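(Purely as an illustration of the shape mismatch described above, not the actual fix: adding a leading time axis would make a (batch, dim) state match the (time, batch, dim) layout the decoder expects.)

```python
import mxnet as mx

state = mx.nd.random.uniform(shape=(2, 32))   # (batch, dim), hypothetical sizes
state = mx.nd.expand_dims(state, axis=0)      # (1, batch, dim) = (time, batch, dim)
print(state.shape)
```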
Thanks @blchu, though I think something went wrong with the merge/conflict resolution (the PR shows 49 changed files). The base branch of this PR has been changed back to master (which now includes the mx17 branch).
Ah, hm, that might have been because of the way I did the merge and the fact that I was originally building off of the sockeye_2 branch. I think the changed-files count is using a diff against that branch rather than master, perhaps? I'm not entirely sure how to fix that.
Hi @blchu, since a lot has changed since the original fork, it might help if you create a new branch from master and apply on top of it the changes from the six files you actually modified.
I've just created a new PR by cherry-picking my commits onto a new branch off of master, since I can't modify the source branch of this PR. The new PR can be seen here: #884.
Replaced the batched dot product in multi-head attention with the interleaved_matmul attention operators to improve performance. Also changed batch-major data to time-major format within the model to comply with the new operators' requirements.
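(For reference, a minimal sketch of calling the MXNet 1.7 interleaved self-attention operators with an already-projected, head-interleaved QKV array in time-major layout. The sizes are illustrative, and masking/scaling details are omitted; this is not the PR's actual code.)

```python
import mxnet as mx

seq_len, batch, heads, head_dim = 10, 2, 4, 16

# Projected queries/keys/values, interleaved per head along the last axis:
# shape (seq_len, batch, heads * 3 * head_dim), i.e. time-major.
qkv = mx.nd.random.uniform(shape=(seq_len, batch, heads * 3 * head_dim))

# Attention scores: (batch * heads, seq_len, seq_len)
scores = mx.nd.contrib.interleaved_matmul_selfatt_qk(qkv, heads=heads)
probs = mx.nd.softmax(scores, axis=-1)

# Attended values: (seq_len, batch, heads * head_dim)
context = mx.nd.contrib.interleaved_matmul_selfatt_valatt(qkv, probs, heads=heads)
```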
Pull Request Checklist
- Changes are complete (if posting work-in-progress code, prefix the title of the pull request with "[WIP]" until you can check this box)
- Unit tests pass (pytest)
- System tests pass (pytest test/system)
- Passed code style checking (./style-check.sh)
- Updated major/minor version in sockeye/__init__.py. Major version bump if this is a backwards incompatible change.

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.