
Improving AxoNN's memory consumption #95

Merged · 3 commits merged into develop from mem-efficiency on Jul 9, 2024
Conversation

@siddharth9820 (Collaborator) commented Jul 9, 2024

Compared to FSDP, AxoNN was consuming a lot more memory. I have identified the main causes and fixed them in this PR:

  1. Embedding layer - Prior to this PR, our tensor parallelism did not parallelize embedding layers. During training, the parameters, gradients and optimizer states of the embedding layer can take up a lot of space. For example, in Llama-3 8B this amounts to roughly 10 GB! (A back-of-the-envelope estimate follows this list.)
  2. Weight caching / ZeRO-2-esque behavior without activation checkpointing - In depth tensor parallelism, if the user isn't using activation checkpointing, the entire parameter state is materialized in GPU memory. Users expect finetuning to work without activation checkpointing, since batch sizes are much smaller than in pretraining. To alleviate this, we now store only the sharded weights during the forward pass and do an extra all-gather in the backward pass, which saves a lot of memory (~16 GB for Llama-3 8B). A minimal sketch of this pattern follows the list.
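For item 1, here is a rough back-of-the-envelope estimate of where the ~10 GB figure comes from. It assumes Llama-3's 128,256-token vocabulary, a hidden size of 4096, and a mixed-precision Adam setup that keeps bf16 parameters and gradients plus fp32 master weights and two fp32 moments; the exact byte count per parameter depends on the optimizer implementation, so this is an estimate, not a measurement.

```python
# Back-of-the-envelope estimate (assumed sizes, not measured values).
vocab_size, hidden_size = 128_256, 4_096     # Llama-3 8B embedding dimensions
embedding_params = vocab_size * hidden_size  # ~525M parameters

# bf16 param (2) + bf16 grad (2) + fp32 master copy (4) + two fp32 Adam moments (4 + 4)
bytes_per_param = 2 + 2 + 4 + 4 + 4

print(f"{embedding_params * bytes_per_param / 1e9:.1f} GB")  # ~8-10 GB, matching item 1
```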
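For item 2, here is a minimal sketch of the re-gather pattern; it is not AxoNN's actual implementation. The idea is that only the sharded weight is saved between forward and backward, and the full weight is rebuilt with an extra all-gather when the backward pass needs it. The flat sharding layout, the helper names, and the 2-D input are simplifications for illustration.

```python
import torch
import torch.distributed as dist


def _gather_full_weight(weight_shard: torch.Tensor, full_shape) -> torch.Tensor:
    """All-gather a flat per-rank weight shard into the full (out, in) weight."""
    world = dist.get_world_size()
    full = torch.empty(weight_shard.numel() * world,
                       dtype=weight_shard.dtype, device=weight_shard.device)
    dist.all_gather_into_tensor(full, weight_shard)
    return full.view(full_shape)


class DepthShardedLinear(torch.autograd.Function):
    """Linear op whose weight lives as a shard; the full weight is only transient."""

    @staticmethod
    def forward(ctx, x, weight_shard, full_shape):
        full_weight = _gather_full_weight(weight_shard, full_shape)
        ctx.save_for_backward(x, weight_shard)   # cache the shard, NOT the full weight
        ctx.full_shape = full_shape
        out = x @ full_weight.t()                # full weight is freed right after this GEMM
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, weight_shard = ctx.saved_tensors
        # The extra all-gather: rebuild the full weight for the backward GEMMs.
        full_weight = _gather_full_weight(weight_shard, ctx.full_shape)
        grad_x = grad_out @ full_weight
        grad_full_weight = grad_out.t() @ x
        # Reduce-scatter the full weight gradient back to a per-rank shard.
        grad_shard = torch.empty_like(weight_shard)
        dist.reduce_scatter_tensor(grad_shard, grad_full_weight.flatten().contiguous())
        return grad_x, grad_shard, None
```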

Here, I am showing the peak memory consumption, time, and loss curves for instruction finetuning (IFT) of Llama-3 8B on 4 A100 GPUs, with a micro batch size of 2 and a batch size of 8, on the Alpaca dataset. We are using depth_tp=4 for AxoNN. Precision is bf16-mixed.

Blue - FSDP for reference
Red (dotted) - AxoNN prior to this PR
Red (solid) - AxoNN after this PR

[Figure: peak memory, per-iteration time, and loss curves for the three configurations above]

Observations -

  1. [first graph] This PR has dropped AxoNN's peak memory usage from 70 GB to 40.5 GB per GPU! This is about 8% lower than FSDP (44 GB).
  2. [second graph] There is a slight increase in per-iteration time (from 0.48s to 0.57s). This is expected, since we are doing an extra all-gather. Even so, we are still 23% faster than FSDP. (Note that fabric.backward has been modified to overlap the extra all-gathers in the backward pass as well!)
  3. [third graph] Loss curves are near identical, which is good!

Another goal of this and future PRs is to make the PyTorch Lightning user experience identical to FSDP's.

  • We now offer optimize_communication as a flag in AxonnStrategy. The strategy places the forward and backward passes under the appropriate context managers, so the user doesn't need to do anything apart from setting this flag to True during initialization. A usage sketch follows this item.
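The sketch below shows how this might look from the user's side with Lightning Fabric. Only the optimize_communication flag and the fabric.backward overlap behavior are described in this PR; the import path, the G_intra_d argument name, and the toy model/optimizer are assumptions for illustration.

```python
import torch
import lightning as L
from axonn.lightning import AxonnStrategy  # assumed import path

# optimize_communication is the flag added by this PR; G_intra_d=4 mirrors the
# depth_tp=4 setting from the experiments above (argument name assumed).
strategy = AxonnStrategy(G_intra_d=4, optimize_communication=True)
fabric = L.Fabric(devices=4, strategy=strategy, precision="bf16-mixed")
fabric.launch()

model = torch.nn.Linear(4096, 4096)   # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = fabric.setup(model, optimizer)

x = torch.randn(2, 4096, device=fabric.device)
loss = model(x).float().pow(2).mean()  # dummy loss
fabric.backward(loss)                  # also overlaps the extra backward all-gathers
optimizer.step()
optimizer.zero_grad()
```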

The following will be tackled in a separate PR:

  • We should still offer the old setting (i.e., ZeRO-2-style parameter caching) as well, and warn the user that it should be used with activation checkpointing.
  • We need to offer an activation checkpointing API similar to FSDP's, wherein the user just passes the module to be checkpointed (see the sketch after this list). FSDP has spoilt everyone :)
  • We need to offer an option to directly save a combined checkpoint from rank 0.
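For the activation checkpointing bullet, the FSDP-style API being referred to looks roughly like PyTorch's apply_activation_checkpointing wrapper, where the user only says which submodules to checkpoint. The tiny TransformerBlock and model below are placeholders, not anything from AxoNN.

```python
from functools import partial

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


class TransformerBlock(nn.Module):
    """Placeholder for the user's transformer block."""
    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(64, 64)

    def forward(self, x):
        return self.mlp(x)


model = nn.Sequential(TransformerBlock(), TransformerBlock())

# The user just points at the modules to checkpoint; the wrapper does the rest.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT),
    check_fn=lambda module: isinstance(module, TransformerBlock),
)
```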

@siddharth9820 siddharth9820 merged commit c5ef65d into develop Jul 9, 2024
4 checks passed
@siddharth9820 siddharth9820 deleted the mem-efficiency branch July 9, 2024 17:26