
Sockeye 2 Interleaved Multi-head Attention Operators #814

Closed · wants to merge 12 commits

Conversation

@blchu (Contributor) commented May 22, 2020

Replaced the batched dot product in multi-head attention with the interleaved_matmul attention operators to improve performance. Also changes data from batch-major to time-major format while inside the model, to comply with the new operators' input requirements.
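The layout change can be illustrated with a small sketch (plain NumPy with made-up shapes, not Sockeye's actual code):

```python
import numpy as np

# Illustrative shapes only: batch=2 sentences, time=5 steps, dim=8.
batch_major = np.random.rand(2, 5, 8)              # (batch, time, dim)

# The interleaved_matmul_* operators expect time-major inputs, so the
# model transposes once on the way in and once on the way out.
time_major = np.transpose(batch_major, (1, 0, 2))  # (time, batch, dim)

assert time_major.shape == (5, 2, 8)
# Round-tripping recovers the original layout exactly.
assert np.array_equal(np.transpose(time_major, (1, 0, 2)), batch_major)
```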

Pull Request Checklist

  • Changes are complete (if posting work-in-progress code, prefix your pull request title with '[WIP]'
    until you can check this box).
  • Unit tests pass (pytest)
  • Were system tests modified? If so did you run these at least 5 times to account for the variation across runs?
  • System tests pass (pytest test/system)
  • Passed code style checking (./style-check.sh)
  • You have considered writing a test
  • Updated major/minor version in sockeye/__init__.py. Major version bump if this is a backwards incompatible change.
  • Updated CHANGELOG.md

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.

@fhieber (Contributor) commented May 27, 2020

Thanks! This should be mergeable once MXNet 1.7 is out, right? Would you expect perplexity scores for training to be exactly the same with and without this change?

@fhieber (Contributor) left a review comment

Thanks @blchu!
I tried this change together with the latest nightly build of MXNet (mxnet==1.7.0b20200527) and I am still getting slightly different perplexity values than current sockeye_2 master.
I see that you merged the input projections of keys and values in the MultiHeadAttention class, but even when undoing this the small difference remains. Do you know why?

Review comments (resolved): sockeye/layers.py, sockeye/model.py
@blchu (Contributor, Author) commented May 27, 2020

Hm, depending on what you're testing, my best guess is that the weights may be initialized slightly differently by the Xavier initializer, since the key and value weights are now a single array. Could you clarify what you mean by undoing the merged input projections (i.e. what changes did you make)? In my own tests I didn't observe any further negative impact on training after my fix, but I can check again to make sure.

@fhieber (Contributor) commented May 28, 2020

Here's a commit that still gives me different perplexity values when compared to the sockeye_2 branch: 1676d6a
I might have missed something though.

@blchu (Contributor, Author) commented Jun 2, 2020

Looking at the changes in your commit, the discrepancy is likely still there because of the way you've combined the weights. The interleaved_matmul_* ops don't assume that the key and value projections are simply concatenated; rather, they are striped together over the head dimension (hence the name "interleaved").

You can reference the code changes in sockeye/inference.py from 6b25ad5 to see how the weights would need to be combined.
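A minimal NumPy sketch of the difference, under my reading of the comment above (the shapes and the exact striping are illustrative assumptions; the authoritative layout is the sockeye/inference.py change in 6b25ad5):

```python
import numpy as np

num_heads, head_dim, model_dim = 4, 8, 32  # illustrative sizes

w_k = np.random.rand(num_heads * head_dim, model_dim)  # key projection
w_v = np.random.rand(num_heads * head_dim, model_dim)  # value projection

# Naive concatenation: all key rows, then all value rows -- NOT what the
# interleaved ops expect.
w_concat = np.concatenate([w_k, w_v], axis=0)

# Interleaved layout: for each head, that head's key rows immediately
# followed by that head's value rows (striped over the head dimension).
w_kv = np.concatenate(
    [w.reshape(num_heads, head_dim, model_dim) for w in (w_k, w_v)],
    axis=1,                       # per head: (heads, 2*head_dim, model_dim)
).reshape(num_heads * 2 * head_dim, model_dim)

# The first block holds head-0 keys, then head-0 values.
assert np.array_equal(w_kv[:head_dim], w_k[:head_dim])
assert np.array_equal(w_kv[head_dim:2 * head_dim], w_v[:head_dim])
```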

@fhieber (Contributor) commented Jun 3, 2020

> Looking at the changes in your commit, it looks like the discrepancy is likely still there because of the way that you've combined the weights together. The interleaved_matmul_* ops don't assume that the key and value projections are simply concatenated, but rather are striped together over the head dimension (hence the name "interleaved").
>
> You can reference the code changes in sockeye/inference.py from 6b25ad5 to see how the weights would need to be combined.

Right, thanks for the clarification. I will try interleaving them instead of concatenating.

@blchu (Contributor, Author) commented Jun 12, 2020

I tried testing the perplexity values myself. It turns out that since the model now transposes the batch and time dimensions, not only do the attention input arrays need to be interleaved together if the keys and values are initialized separately, but they also need to be transposed for the dropout and layer norms to produce the same results. Accounting for all of this, I'm observing that the array values largely align with the main sockeye_2 branch.
There is still a small difference in perplexity values in the system test, likely due to a ~1e-8 difference in outputs from the attention operator itself, since I believe its implementation uses a different GEMM than batch_dot does. I don't believe this should cause any significant issue, though.
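The layer-norm part of this is easy to sanity-check in isolation: normalization happens over the hidden axis, so swapping the batch and time axes leaves per-position outputs unchanged (a NumPy sketch with invented shapes, not Sockeye code; a fixed dropout mask, by contrast, would have to be transposed along with the data):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalizes over the last (hidden) axis only.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x_bm = np.random.rand(2, 5, 8)            # (batch, time, dim)
x_tm = np.transpose(x_bm, (1, 0, 2))      # (time, batch, dim)

# Per-position results are identical; only the leading axes are swapped.
out_bm = layer_norm(x_bm)
out_tm = layer_norm(x_tm)
assert np.allclose(np.transpose(out_tm, (1, 0, 2)), out_bm)
```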

I've pushed the changes I made to run the tests to a branch on my fork: https://github.com/blchu/sockeye/commits/sockeye_2_test_interleaved_perplex

@blchu blchu force-pushed the sockeye_2_interleaved_optim branch from e2f54a2 to 975c041 Compare June 19, 2020 02:23
@blchu blchu requested review from fhieber and removed request for davvil June 23, 2020 17:36
@mjdenkowski mjdenkowski changed the base branch from sockeye_2 to master September 2, 2020 17:20
@mjdenkowski (Contributor) commented

Now that we have an MXNet 1.7 branch, we can update/review this PR and merge it into that branch so it will be immediately usable for anyone who wants to manually install an MXNet pre-release. When MXNet 1.7 wheels are available, mx17 can merge into master and the improvements in this PR will be included automatically.

@mjdenkowski (Contributor) commented

It looks like there are some merge conflicts between this PR and the SSRU updates. @barbaradarques and @blchu, as the experts on these parts of the code, can you help us find the best way to resolve the conflicts?

@blchu (Contributor, Author) commented Sep 8, 2020

Looking over the SSRU PR, there is significant overlap, since both modify the decoder autoregressive layer. Most of the conflicts should be relatively straightforward to resolve, but one of my changes transposes the batch and time dimensions for the duration of the model to improve attention performance. @barbaradarques, is that compatible with the SSRU layer, or can it be made compatible? Otherwise I'll look for workarounds.

@tdomhan (Contributor) commented Sep 9, 2020

Actually, I think there's some positive synergy between the changes, as the SSRU code currently temporarily transposes to time-major in order to use the foreach operator: https://github.com/awslabs/sockeye/blob/master/sockeye/layers.py#L893
When merging, we'd just need to make sure to remove that transpose in favor of your model-wide transpose.
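Why the scan wants time-major input can be sketched as follows (a plain-Python stand-in for a foreach-style scan, with a toy running-sum step rather than the actual SSRU cell):

```python
import numpy as np

def scan_over_time(step_fn, inputs, init_state):
    # Iterating over axis 0 is natural only when that axis is time,
    # hence the time-major requirement of foreach-style operators.
    state, outputs = init_state, []
    for x_t in inputs:            # inputs: (time, batch, dim)
        out, state = step_fn(x_t, state)
        outputs.append(out)
    return np.stack(outputs), state

# Toy recurrent step: running sum over time (stand-in for the SSRU cell).
step = lambda x, s: (x + s, x + s)

x = np.ones((5, 2, 3))            # already time-major: no transpose needed
outs, final = scan_over_time(step, x, np.zeros((2, 3)))
assert outs.shape == (5, 2, 3)
assert np.array_equal(final, np.full((2, 3), 5.0))
```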

@barbaradarques (Contributor) commented Sep 9, 2020

Hi @blchu, as @tdomhan mentioned, the transposition shouldn't be a problem for SSRU layers. On top of removing the internal transpose() operations, you'd also need to update get_state_shape(), but that should be it.

@mjdenkowski (Contributor) commented

We've officially updated Sockeye to MXNet 1.7, so we can merge this into master as soon as the conflicts are resolved.

@blchu (Contributor, Author) commented Sep 14, 2020

I'm currently working on resolving the merge conflicts. Right now I have all tests passing; however, I'm seeing some strange performance and perplexity numbers in training, so I'm trying to get that resolved.

@barbaradarques (Contributor) commented

@blchu, feel free to commit your WIP; maybe we can help figure out the inconsistencies.

@blchu (Contributor, Author) commented Sep 16, 2020

With further testing, I did notice that perplexity was noticeably worse during training than what I had measured before, but this is also the case with the mx17 branch without any of the changes from this PR, so it may be due to something entirely unrelated. The BLEU scores at each training checkpoint were no different.

There's still an issue with inference during a training checkpoint when using an SSRU decoder, though I think I've narrowed the cause down to a different state shape than in pure inference ([batch, dim] vs the expected [time, batch, dim]).
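A hypothetical sketch of that mismatch (shapes invented for illustration; whether re-adding a singleton time axis is the right fix would depend on the actual decoder code):

```python
import numpy as np

# Checkpoint decoding produced a 2-D state while the SSRU layer expects a
# time-major 3-D state (shapes here are purely illustrative).
state_2d = np.random.rand(4, 16)   # (batch, dim) -- observed
expected = (1, 4, 16)              # (time, batch, dim) with time=1

# One plausible reconciliation: reintroduce the singleton time axis
# before the layer consumes the state.
state_3d = np.expand_dims(state_2d, axis=0)
assert state_3d.shape == expected
assert np.array_equal(state_3d[0], state_2d)
```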

@fhieber (Contributor) commented Sep 17, 2020

Thanks @blchu, though I think something went wrong with the merge/conflict resolution (the PR shows 49 changed files). The base branch of this PR has been changed back to master (which now includes the mx17 branch).

@blchu (Contributor, Author) commented Sep 18, 2020

Ah, that might be because of the way I did the merge and the fact that I was originally building off of the sockeye_2 branch. I think the changed-files count is computed as a diff against that branch rather than master. I'm not entirely sure how to fix that.

@blchu blchu changed the base branch from master to sockeye_2 September 18, 2020 06:10
@blchu blchu changed the base branch from sockeye_2 to master September 18, 2020 06:10
@barbaradarques (Contributor) commented

Hi @blchu, since a lot has changed since the original fork, it might help to create a new branch from master and re-apply the changes to the 6 files you actually modified on top of it.

@blchu (Contributor, Author) commented Sep 23, 2020

I've just created a new PR by cherry-picking my commits onto a new branch off of master, since I can't modify the source branch of this PR. The new PR is #884.

@blchu blchu closed this Sep 23, 2020