Skip to content

Conversation

@FanhaiLu1
Copy link
Collaborator

@FanhaiLu1 FanhaiLu1 commented May 24, 2024

This PR add below changes:

1: move torch_xla2.default_env() to function. jax_mode = torch_xla2.default_env() block jax multiple controller in init state
2: ray engine create is different than default run server one, it will have prefill and decode engines later
3: removed duplciated JetEngineEnvironment
4: Not support shard_on_batch and ragged attention in ray multiple for now

Copy link
Collaborator

@wang2yn84 wang2yn84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you help me understand how jax_mode = torch_xla2.default_env() block jax multiple controller in init state?

@wang2yn84
Copy link
Collaborator

Can you help me understand how jax_mode = torch_xla2.default_env() block jax multiple controller in init state?

Can you help me understand why is it?

@FanhaiLu1
Copy link
Collaborator Author

Can you help me understand how jax_mode = torch_xla2.default_env() block jax multiple controller in init state?

Can you help me understand why is it?

The is jax call under this function ( or deeper). For any jax function call, it will try to init the multiple controller env (though MPI barrier), which mean need to wait all the chips finished. So in ray multiple host, if there is a jax function call in head node, it will wait all the chips be ready, but only the head node chip is ready at this time, the the whole application will stuck there.

For current use case, it happens when Ray head load the class even, it call the jax and stuck there even before start execute main function.

@FanhaiLu1 FanhaiLu1 merged commit 2880904 into AI-Hypercomputer:main May 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants