Gradient Accumulation #72

Closed
zhuohan123 opened this issue Aug 18, 2021 · 3 comments

@zhuohan123
Member

The current gradient accumulation can introduce extra all-reduce overhead when combined with data parallelism.

@merrymercy
Member

merrymercy commented Aug 18, 2021

Style 1

@parallelize
def train_step(data, params):
    ...   # computes and returns grad

acc_grad = 0
for num_acc in range(5):
    grad = train_step(data, params)   # call all-reduce
    acc_grad += grad

We need to support a new sharding specification in both XLA and JAX: PartialResult.
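
For intuition, here is a minimal JAX sketch (hypothetical shapes and names, not the parax API) of the pattern a PartialResult spec would express: train_step returns each device's local, un-reduced gradient, accumulation happens with no communication, and a single all-reduce runs at the end.

import jax
import jax.numpy as jnp

def loss_fn(params, data):
    return jnp.sum((data @ params) ** 2)

# Per-device gradient with no all-reduce: the result is a partial gradient.
local_grad_step = jax.pmap(jax.grad(loss_fn), in_axes=(None, 0))

# One all-reduce over the accumulated partial gradients.
sync = jax.pmap(lambda g: jax.lax.psum(g, "i"), axis_name="i")

n_dev = jax.local_device_count()
params = jnp.ones(4)
acc_grad = jnp.zeros((n_dev, 4))
for _ in range(5):
    data = jnp.ones((n_dev, 2, 4))       # one shard per device
    acc_grad = acc_grad + local_grad_step(params, data)
grad = sync(acc_grad)                     # all-reduce once instead of five times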

Style 2

@parallelize
def train_step(data, params):
    acc_grad = 0
    for num_acc in range(5):
        grad = compute_grad(data, params)   # per-micro-batch gradient (compute_grad is illustrative)
        acc_grad += grad
    return acc_grad

XLA needs to simplify allreduce(x) + allreduce(y) to allreduce(x+y)
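
The rewrite is valid because all-reduce is a sum reduction, which is linear; a tiny NumPy check (the leading axis stands in for the devices):

import numpy as np

rng = np.random.default_rng(0)
x_local = rng.normal(size=(4, 3))   # per-device values of x (4 devices)
y_local = rng.normal(size=(4, 3))   # per-device values of y

allreduce = lambda v: v.sum(axis=0)              # stand-in for all-reduce(sum)
lhs = allreduce(x_local) + allreduce(y_local)    # allreduce(x) + allreduce(y)
rhs = allreduce(x_local + y_local)               # allreduce(x + y)
assert np.allclose(lhs, rhs)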

Style 3

@parallelize
def forward_backward_step(data, params):
    ...   # computes and returns grad

@parallelize
def gradient_update_step(grad, params):
    ...   # returns new_params

acc_grad = 0
for num_acc in range(5):
    grad = forward_backward_step(data, params)   # call all-reduce
    acc_grad += grad
params = gradient_update_step(acc_grad, params)

We need to change our interface

@zhisbug mentioned this issue Aug 18, 2021
@merrymercy
Member

merrymercy commented Aug 21, 2021

Background

To enable efficient gradient accumulation in both SPMD and pipeline parallelism, we must generate at least two XLA executables: accumulate_grad(state, micro_batch, old_grad, sync) -> new_grad and apply_grad(state, grad) -> new_state.

In SPMD mode, they are used as follows

micro_batches = split(batch, num_micro_batches)
acc_grad = 0

for i in range(num_micro_batches):
    sync = (i == num_micro_batches - 1)  # Whether to use all-reduce to sync the local gradients
    acc_grad = accumulate_grad(state, micro_batches[i], acc_grad, sync)

apply_grad(state, acc_grad)
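
For concreteness, a single-device sketch of what the two executables could compute, with plain parameters standing in for state and an SGD update as a placeholder (not the actual parax implementation):

import jax
import jax.numpy as jnp

def loss_fn(params, micro_batch):
    x, y = micro_batch
    return jnp.mean((x @ params - y) ** 2)

def accumulate_grad(params, micro_batch, old_grad, sync):
    grad = jax.grad(loss_fn)(params, micro_batch)
    new_grad = jax.tree_util.tree_map(jnp.add, old_grad, grad)
    # Under SPMD data parallelism, `sync` would trigger the cross-device
    # all-reduce here (lax.psum under pmap); a single device needs none.
    return new_grad

def apply_grad(params, grad, lr=0.1):
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grad)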

In Pipeline mode, they are used as follows

# For each worker
acc_grad = 0

for i in range(num_micro_batches):
    cur_micro_batch = recv_micro_batch_from_previous_stage()
    sync = (i == num_micro_batches - 1)  # Whether to use all-reduce to sync the local gradients
    acc_grad = accumulate_grad(state, cur_micro_batch, acc_grad, sync)

apply_grad(state, acc_grad)

Discussions

How to get these two functions?

  • A1: Force users to provide two functions.
    • A1.1: Users define
      compute_grad(state, batch) -> grad
      apply_grad(state, grad) -> new_state
      We then derive accumulate_grad from compute_grad
    • A1.2: Users define
      forward(state, batch) -> loss
      apply_grad(state, grad) -> new_state
      We then derive accumulate_grad from forward. This method is less general but works better for pipeline_marker.
    • Pros: By forcing users to provide these functions, we can make fewer assumptions and guesses.
    • Cons: Not compatible with existing jax/flax programs
  • A2: Use parax.grad to replace jax.grad:
    @parallelize
    def func(optimizer, batch):
        def loss_func(params):
            return ...

        grads = parax.grad(loss_func)(optimizer.target)
        new_optimizer = optimizer.apply_gradient(grads)
        return new_optimizer
    parax.grad inserts a separator after the gradient computation. We can reuse pipeline_marker for this separator. The separator partitions the original jaxpr into compute_grad and apply_grad, and we then derive accumulate_grad from compute_grad (see the sketch after this list).
  • A3: We use static analysis (batch dimension propagation) to separate the whole computational graph into compute_grad and apply_grad.
    • Pros: Compatible with existing jax/flax programs
    • Cons: Does not work with pipeline_marker. It is hard to make the static analysis robust.
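
A minimal sketch of A2, assuming a hypothetical pipeline_marker that is semantically an identity: parax.grad wraps jax.grad and tags the returned gradients, so the traced jaxpr can later be cut at the marker into compute_grad and apply_grad.

import jax

def pipeline_marker(x):
    # Placeholder: in parax this would be a real marker primitive that
    # survives tracing; here it is only an identity for illustration.
    return x

def grad(fun, **kwargs):
    # A parax.grad-style wrapper around jax.grad (illustrative only).
    jax_grad = jax.grad(fun, **kwargs)
    def wrapped(*args, **kw):
        grads = jax_grad(*args, **kw)
        return jax.tree_util.tree_map(pipeline_marker, grads)
    return wrapped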

Where should the accumulation loop be?

  • B1: Put the loop in HLO (see the fori_loop sketch after this list)
    • Pros
      • Less python runtime overhead
    • Cons
      • Does not work in pipeline mode with our existing implementation
      • Has to deal with the while loop in HLO passes
  • B2: Put the loop in our python runtime
    The pros and cons are the opposite of the above.
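
A sketch of B1 with standard JAX, where the accumulation loop is lowered into the executable as an HLO while loop via jax.lax.fori_loop (the loss and micro-batch layout are illustrative assumptions):

import jax
import jax.numpy as jnp

def loss_fn(params, micro_batch):
    return jnp.sum((micro_batch @ params) ** 2)

@jax.jit
def accumulate_all(params, micro_batches):   # micro_batches: (num_micro, bs, dim)
    def body(i, acc):
        return acc + jax.grad(loss_fn)(params, micro_batches[i])
    init = jnp.zeros_like(params)
    return jax.lax.fori_loop(0, micro_batches.shape[0], body, init)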

How to enable in-place updates?

The two functions accumulate_grad(state, micro_batch, old_grad, sync) -> new_grad and apply_grad(state, grad) -> new_state should be compiled with the following constraints:

  1. old_grad, new_grad and their corresponding parameters in state should share the same sharding spec.
  2. old_grad and new_grad should share the same memory location.
  3. new_state and state should share the same memory location and the same sharding spec.

To share the same memory location, we have to set donate_invars (or alias) in XLA.
To share the same sharding spec, we have to pass these constraints to the ILP solver.
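
Constraint 2 can be illustrated with standard JAX buffer donation, which maps to XLA input/output aliasing and plays the role of donate_invars here (the loss is a placeholder):

import functools
import jax
import jax.numpy as jnp

def loss_fn(params, micro_batch):
    return jnp.sum((micro_batch @ params) ** 2)

@functools.partial(jax.jit, donate_argnums=(2,))
def accumulate_grad(params, micro_batch, old_grad):
    # Donating old_grad lets XLA write new_grad into the same buffer, so
    # gradient accumulation does not allocate a second copy of the gradients.
    return old_grad + jax.grad(loss_fn)(params, micro_batch)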

How to handle the sync argument of accumulate_grad?

  • C1: Compile two versions of accumulate_grad, one that syncs and one that does not (see the sketch after this list).
  • C2: Compile one executable and use two branches in XLA.
  • C3: Compile two executables, accumulate_grad and sync_grad, and dispatch them in our python runtime loop. However, this makes it impossible to overlap all-reduce and computation.
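
A sketch of C1 with plain JAX: close over a Python-level sync flag so that two separate pmap executables get compiled, one ending with the all-reduce and one without (names and the loss are illustrative):

import functools
import jax
import jax.numpy as jnp

def loss_fn(params, micro_batch):
    return jnp.sum((micro_batch @ params) ** 2)

def make_accumulate_grad(sync):
    @functools.partial(jax.pmap, axis_name="dp", in_axes=(None, 0, 0))
    def accumulate_grad(params, micro_batch, old_grad):
        new_grad = old_grad + jax.grad(loss_fn)(params, micro_batch)
        if sync:  # resolved at trace time
            new_grad = jax.lax.psum(new_grad, "dp")
        return new_grad
    return accumulate_grad

accumulate_grad_nosync = make_accumulate_grad(False)
accumulate_grad_sync = make_accumulate_grad(True)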

How to handle other optimizations?

We want the memory of the gradients to be contiguous. This can benefit all-reduce.
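
One possible way to get there, sketched with standard JAX utilities rather than our own runtime: flatten the gradient pytree into a single contiguous buffer, issue one large all-reduce over it, and unflatten afterwards.

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

grads = {"w": jnp.ones((4, 4)), "b": jnp.ones(4)}
flat_grads, unravel = ravel_pytree(grads)   # one contiguous (20,) buffer
# A single all-reduce over flat_grads (lax.psum under pmap) replaces one
# all-reduce per parameter tensor.
grads = unravel(flat_grads)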

@zhuohan123
Member Author

#87 #90
