[Draft] show how to not use a global mesh. #549

Closed
wants to merge 1 commit into from

Conversation

nouiz
Collaborator

@nouiz nouiz commented Dec 1, 2023

The TE/JAX back-end currently has a notion of a global mesh,
but JAX itself doesn't have that.

I think I found a way to remove this notion of a global mesh, and I hope it will simplify the code.

@denera @mingxu1067 What do you think of that?
The impl() lambda can capture the mesh information instead of having a global mesh.
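
To illustrate the idea, here is a minimal sketch of a closure that captures the mesh instead of reading it from a global. It is not the code from this PR: `make_all_reduce_sum`, the `"dp"` axis name, and the use of `shard_map` are assumptions chosen only to keep the example small and runnable on any number of devices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases keep it under jax.experimental
    from jax.experimental.shard_map import shard_map


def make_all_reduce_sum(mesh, axis_name):
    """Hypothetical helper: build an all-reduce bound to an explicit mesh."""

    def impl(x):
        # `mesh` is captured from the enclosing scope; no global lookup.
        return shard_map(
            lambda shard: jax.lax.psum(shard, axis_name),
            mesh=mesh,
            in_specs=(P(axis_name),),
            out_specs=P(),
        )(x)

    return impl


# Build a 1-D mesh over whatever devices are available (assumed axis "dp").
devices = np.array(jax.devices())
mesh = Mesh(devices, ("dp",))

all_reduce_sum = make_all_reduce_sum(mesh, "dp")

# One row per device; the reduction sums the rows across the "dp" axis.
x = jnp.arange(4.0 * len(devices)).reshape(len(devices), 4)
result = all_reduce_sum(x)
```

Because the mesh is an argument of the factory, two call sites can use different meshes without touching shared state.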

@mingxu1067
Collaborator

This might be a good way to avoid getting the mesh from a global variable.
Usually, all_reduce_sum_along_dp_fsdp and all_reduce_max_along_all_axes_except_PP are invoked in the partition function of custom_partitioning, which is passed the mesh currently in use, so we could pass that mesh on to the p* functions.

@nouiz nouiz added the jax label Feb 14, 2024
@ksivaman
Member

After a discussion with @nouiz, we will re-open a fresh PR.

@nouiz
Collaborator Author

nouiz commented Aug 16, 2024

New version of that PR: #1112
