
Support embedding & Test auto-sharding on the whole BERT model & Refine auto-sharding interface #49

Merged
3 commits merged into master from whole-bert on Jul 10, 2021

Conversation

@merrymercy (Member) commented Jul 6, 2021

  • Monkey patch `nn.Embed` in flax to use one-hot + matmul instead of gather/scatter
  • Test the auto-sharding solver on the whole BERT model (copied from huggingface). The result (transformer + embedding) has exactly the same partition strategy and communication cost as Megatron-LM's solution. I will do some benchmarks in the next PR.
  • Refine the auto-sharding interface to better fit Combining Manual Pipeline Parallelism & Automatic SPMD Parallelism #46

Dependency:
requires flax >= 0.3.4
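The one-hot + matmul trick mentioned in the first bullet rewrites an embedding lookup (a gather) as a dense matrix product, which is typically easier for an SPMD sharding solver to partition. A minimal sketch of the equivalence, written here with NumPy for clarity (the actual patch targets flax's `nn.Embed`; the function name `onehot_embed` is illustrative, not from the PR):

```python
import numpy as np

def onehot_embed(ids, table):
    """Embedding lookup expressed as one-hot + matmul.

    ids:   integer array of shape (batch, seq)
    table: embedding table of shape (vocab, dim)

    Equivalent to the gather `table[ids]`, but expressed as a
    matmul so a sharding solver can treat it like any other
    dense layer (e.g. partition the vocab dimension).
    """
    vocab = table.shape[0]
    # (batch, seq, vocab) one-hot encoding of the token ids
    onehot = np.eye(vocab, dtype=table.dtype)[ids]
    # Contract over the vocab dimension -> (batch, seq, dim)
    return onehot @ table

# Sanity check: the matmul form matches the plain gather.
table = np.arange(12, dtype=np.float32).reshape(4, 3)
ids = np.array([[0, 2], [3, 1]])
assert np.allclose(onehot_embed(ids, table), table[ids])
```

The trade-off is extra FLOPs and memory for the one-hot tensor, accepted here because it lets the solver find Megatron-LM-style partitions for the embedding.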

@merrymercy merrymercy changed the title Support embedding & Test auto-sharding on the whole BERT model Support embedding & Test auto-sharding on the whole BERT model & Refine auto-sharding interface Jul 10, 2021
@merrymercy merrymercy merged commit 47fb454 into master Jul 10, 2021
@merrymercy merrymercy deleted the whole-bert branch July 10, 2021 21:34