Gradient Accumulation #72
Style 1

```python
@parallelize
def train_step(data, params) -> grad:
    ...

acc_grad = 0
for num_acc in range(5):
    grad = train_step(data, params)  # calls all-reduce
    acc_grad += grad
```

We need to support a new sharding specification, `PartialResult`, in both XLA and JAX.

Style 2

```python
@parallelize
def train_step(data, params) -> grad:
    acc_grad = 0
    for num_acc in range(5):
        grad = train_step(data, params)
        acc_grad += grad
```

XLA needs to simplify the loop.

Style 3

```python
@parallelize
def forward_backward_step(data, params) -> grad:
    pass

@parallelize
def gradient_update_step(grad, params) -> new_params:
    pass

acc_grad = 0
for num_acc in range(5):
    grad = forward_backward_step(data, params)  # calls all-reduce
    acc_grad += grad
params = gradient_update_step(acc_grad, params)
```

We need to change our interface.
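For reference, a plain single-device JAX sketch of what the Style 3 split could look like. The loss, model, and SGD update below are placeholders, not the actual interface:

```python
import jax
import jax.numpy as jnp

# Placeholder loss; stands in for the real model's forward pass.
def loss_fn(params, data):
    inputs, targets = data
    preds = inputs @ params
    return jnp.mean((preds - targets) ** 2)

def forward_backward_step(data, params):
    # Returns only the gradient; no optimizer update happens here.
    return jax.grad(loss_fn, argnums=0)(params, data)

def gradient_update_step(grad, params, lr=1e-3):
    # Plain SGD as a stand-in for the real optimizer.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grad)

params = jnp.zeros((8, 1))
acc_grad = jnp.zeros_like(params)
for num_acc in range(5):
    data = (jnp.ones((4, 8)), jnp.ones((4, 1)))  # dummy micro-batch
    acc_grad = acc_grad + forward_backward_step(data, params)
params = gradient_update_step(acc_grad, params)
```

The point of the split is that gradient computation and the parameter update can be compiled and parallelized separately.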
Background

To enable efficient gradient accumulation in both SPMD and pipeline parallelism, we must generate at least two XLA executables: `accumulate_grad` and `apply_grad`.

In SPMD mode, they are used as follows:

```python
micro_batches = split(batch, num_micro_batches)
acc_grad = 0
for i in range(num_micro_batches):
    sync = (i == num_micro_batches - 1)  # whether to use all-reduce to sync the local gradients
    acc_grad = accumulate_grad(state, micro_batches[i], acc_grad, sync)
apply_grad(state, acc_grad)
```

In pipeline mode, they are used as follows:

```python
# For each worker
acc_grad = 0
for i in range(num_micro_batches):
    cur_micro_batch = recv_micro_batch_from_previous_stage()
    sync = (i == num_micro_batches - 1)  # whether to use all-reduce to sync the local gradients
    acc_grad = accumulate_grad(state, cur_micro_batch, acc_grad, sync)
apply_grad(state, acc_grad)
```
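For concreteness, here is a rough plain-JAX sketch of what these two executables could look like under `pmap`. The loss, the SGD update, and the simplified signatures are placeholders, not the actual generated executables; `sync` is a static argument so the all-reduce only appears in the program compiled for the last micro-batch:

```python
import functools
import jax
import jax.numpy as jnp

def loss_fn(params, micro_batch):
    return jnp.mean((micro_batch @ params) ** 2)  # toy loss

@functools.partial(jax.pmap, axis_name="dp", static_broadcasted_argnums=(3,))
def accumulate_grad(params, micro_batch, acc_grad, sync):
    grad = jax.grad(loss_fn)(params, micro_batch)
    acc_grad = acc_grad + grad
    if sync:  # all-reduce the local gradients only on the last micro-batch
        acc_grad = jax.lax.psum(acc_grad, "dp")
    return acc_grad

@functools.partial(jax.pmap, axis_name="dp")
def apply_grad(params, acc_grad):
    return params - 1e-3 * acc_grad  # plain SGD as a stand-in optimizer
```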
Discussions

- How to get these two functions?
- Where should the accumulation loop be?
- How to enable in-place updates? For the two functions to share the same memory location, we have to set `donate_invars` (or alias) in XLA.
- How to handle the …
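At the JAX level, the corresponding mechanism is `donate_argnums`, which marks inputs whose buffers XLA may alias with outputs. A minimal sketch with a toy jitted accumulation function (names are illustrative only):

```python
import jax
import jax.numpy as jnp

def accumulate(grad, acc_grad):
    return acc_grad + grad

# donate_argnums=(1,) tells XLA it may reuse acc_grad's buffer for the output,
# so the accumulated gradient can be updated in place.
accumulate_donated = jax.jit(accumulate, donate_argnums=(1,))

acc_grad = jnp.zeros((1024,))
for _ in range(5):
    grad = jnp.ones((1024,))  # stands in for a real gradient
    acc_grad = accumulate_donated(grad, acc_grad)
```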
Current gradient accumulation can introduce extra all-reduce overhead when combined with data parallelism.
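As a rough estimate: with `num_micro_batches = N`, all-reducing the gradient on every micro-batch issues N collectives over the full gradient per training iteration, while accumulating locally and synchronizing only on the last micro-batch (the `sync` flag above) issues just one, so the naive combination costs roughly N times the communication volume.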