-
Notifications
You must be signed in to change notification settings - Fork 348
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Combining Manual Pipeline Parallelism & Automatic SPMD Parallelism #46
Conversation
@@ -94,7 +94,7 @@ def loss_func(params): | |||
|
|||
hidden_states = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) | |||
attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32) | |||
label = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) | |||
label = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) * 23.0 * np.arange(hidden_size)[None, None, :] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's this?
@@ -0,0 +1,11 @@ | |||
# XLA Pipeline Marker Custom Call |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding a custom call is not as simple as I thought. How can we simplify this? For example, can putting the code to tensorflow-parax simplify the compilation process?
|
||
xla_computation = xla_client.XlaComputation(hlo_proto) | ||
num_devices = np.prod(strategy_config.logical_mesh_shape) | ||
assert num_devices == len(self.backend.devices()) | ||
|
||
compiled = compile_with_given_strategy( | ||
compiled = compile_without_auto_sharding( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not always correct. When not using 3d parallel, we will pass an unoptimized HLO Proto. In this case, we need to call compile_with_given_strategy
. I can fix this for you later.
No description provided.