
Make DDP/FSDP a regular transform #122

Open
t-vi opened this issue Apr 3, 2024 · 5 comments
Assignees: t-vi
Labels: distributed · enhancement (New feature or request) · help wanted (Extra attention is needed)

Comments

@t-vi
Collaborator

t-vi commented Apr 3, 2024

🚀 Feature

Make DDP/FSDP a regular transform (which, to a large part, means making transforms flexible enough to support this).

Motivation

Currently, DDP/FSDP is not a regular transform, which leads to issues like #94 and limits composability and sequencing.
A key piece is that a DDP/FSDP transform would need to make the prologue adjustments that we currently make during tracing with DDP/FSDP, so transforms need to be allowed to mutate prologues. This is in line with the needs of other transforms that change prologue signatures (LoRA, quantization, but also value-and-grad things), so this generalization should happen in any case.

cc @carmocca @awaelchli @crcrpar

@t-vi t-vi added the enhancement (New feature or request) and help wanted (Extra attention is needed) labels Apr 3, 2024
@t-vi t-vi self-assigned this Apr 3, 2024
@IvanYashchuk
Collaborator

What is meant by making DDP/FSDP a regular transform? What are you planning to do?
Today it's not a transform at all, as I commented in #94 (comment): thunder.distributed.ddp/fsdp only annotate parameters for tracing. This is also described in the tutorial https://github.com/Lightning-AI/lightning-thunder/blob/main/notebooks/dev_tutorials/fsdp_tutorial.ipynb

I don't see how sharding could happen after the thunder.jit(model) call. What ideas do you have?
The current workflow (sketched below) is:

  1. Shard the model: done with thunder.distributed.fsdp(model), or with torch.distributed.fsdp.FullyShardedDataParallel in PyTorch.
  2. Set up the optimizer using the sharded model, so that the optimizer state is sharded as well.
  3. Call thunder.jit(sharded_model) (or torch.compile(sharded_model) in PyTorch).
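
For concreteness, a minimal sketch of this workflow, assuming a torchrun launch and a placeholder MyModel module:

```python
import torch
import torch.distributed as dist

import thunder
import thunder.distributed

dist.init_process_group(backend="nccl")  # torchrun sets the required env vars
device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())

model = MyModel().to(device)  # MyModel stands in for any torch.nn.Module

# 1. Shard the model across ranks.
sharded_model = thunder.distributed.fsdp(model)

# 2. Create the optimizer from the sharded parameters, so its state is sharded too.
optimizer = torch.optim.Adam(sharded_model.parameters(), lr=1e-3)

# 3. Jit-compile the already-sharded model.
jitted_model = thunder.jit(sharded_model)
```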

@t-vi
Collaborator Author

t-vi commented Apr 11, 2024

I would like to move step 3 up (for thunder.jit), i.e. call thunder.jit first and apply the sharding afterwards.
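
Something along these lines (a hypothetical sketch: applying fsdp to an already-jitted module is precisely what this issue proposes to enable, not what the current API supports):

```python
import torch
import thunder
import thunder.distributed

# 3. (now first) Jit the unsharded model.
jitted_model = thunder.jit(model)

# 1. Apply FSDP as a regular transform on the jitted module.
#    Hypothetical call order: this is the proposal, not today's API.
sharded_model = thunder.distributed.fsdp(jitted_model)

# 2. Set up the optimizer on the now-sharded parameters.
optimizer = torch.optim.Adam(sharded_model.parameters(), lr=1e-3)
```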

@IvanYashchuk
Collaborator

Is the preferred order then 3 -> 1 -> 2?

@t-vi
Collaborator Author

t-vi commented Apr 22, 2024

So, per discussions with @crcrpar and @IvanYashchuk (thank you!):

  • The transform needs to modify the model as well as the prologue and compute traces (and invalidate or modify old cache entries).
  • It needs to come before the autograd transform and stay compatible with it.
  • We need a good way to represent the changes to the model state; the current goal is to put this on the ThunderModule and leave the user's modules intact. This is why Change prologue details to prepare for fsdp as a transform #228 simplifies parameter access, and it is also why we want to access the thunder module in the prologue (which has come up before as well). See the sketch after this list.

(obviously, the good ideas are Masaki's and Ivan's, the not-so-good ones my own)
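
To make these requirements concrete, here is a rough sketch of what such a transform could look like. All class and method names are hypothetical and purely illustrative; they are not thunder's actual transform API.

```python
# Hypothetical sketch: none of these names are thunder's real transform API.
class FSDPTransform:
    """Sharding as a 'regular' transform applied to a jitted module."""

    def transform_module(self, thunder_module):
        # Shard the parameters and record the sharded state on the
        # ThunderModule, leaving the user's original nn.Module intact.
        ...

    def transform_prologue(self, prologue_trace):
        # Mutate the prologue so it unpacks and checks the *sharded*
        # parameters; this is the prologue mutation this issue asks for.
        # Cache entries built against the old prologue must be invalidated
        # or updated.
        ...

    def transform_compute(self, compute_trace):
        # Insert the collectives (e.g. all-gather before use, reduce-scatter
        # of gradients) into the compute trace; this has to run before the
        # autograd transform and stay compatible with it.
        ...
```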

@mruberry
Collaborator

Triage review: let's start the design review with a draft PR to discuss.
