diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 605817c330..69cd1d1e94 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -261,9 +261,6 @@ pipeline_delay_activation_forwarding: False # This delays the activation forward # the communication and compute in each iteration are now independent. However this comes at the cost of doubling the pipeline bubble, # and you must set the number of microbatches to at least 2 * num_stages (the minimum 2 * num_stages is set by default with this delay). -model_fsdp_ag_once: False # This controls whether the Zero-1 optimization is active. -# This is a memory/time tradeoff - True: This is Zero-1 Sharding. Use ZeroOneTransformer to gather weights once per gradient step. -# False: This is Zero-3 Sharing. Use the standard Transformer, which gathers for each microbatch's fwd/bwd pass. pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights over FSDP before the first pipeline iteration. # This is a memory/time tradeoff - we now have to store the FSDP gathered weights and gradients (typically in bf16), as opposed # to only one stage's worth, however we only execute one all-gather and reduce across per repeat, as opposed diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index b9f243a073..e0fced3edf 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -1493,10 +1493,6 @@ class DerivedValues(BaseModel): None, description="Boolean flag indicating if pipeline parallelism is active across ICI or DCN.", ) - model_fsdp_ag_once: bool = Field( - False, - description="An alias for `pipeline_fsdp_ag_once` for backward compatibility.", - ) context_parallel_size: None | int = Field( None, @@ -1989,8 +1985,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de ): self.logical_axis_rules.append(["aqt_amax_history", ("stage",)]) - self.model_fsdp_ag_once = self.pipeline_fsdp_ag_once # Backward compatibility alias - # H. RUN ALL CROSS-FIELD VALIDATIONS if self.load_parameters_path and self.load_full_state_path: raise ValueError("At most one of `load_parameters_path` or `load_full_state_path` should be set.") diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 4f712d84be..1bbed113bd 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -717,11 +717,11 @@ def __call__( ) if cfg.using_pipeline_parallelism: if cfg.pipeline_fsdp_ag_once: - partition_spec = self.pipeline_module.get_weight_sharding( + logical_partition_spec = self.pipeline_module.get_weight_sharding( y, decoder_segment_ids, decoder_positions, deterministic, model_mode ) else: - partition_spec = None # This partition spec is only used for the fsdp_ag_once feature. + logical_partition_spec = None # This partition spec is only used for the fsdp_ag_once feature. if cfg.decoder_block == DecoderBlockType.DEEPSEEK: assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." 
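Note on the rename above: get_weight_sharding returns *logical* axis names, and the pipeline now maps them to mesh axes itself through the logical_to_mesh helpers added in sharding.py, so the argument is called logical_partition_spec to match what is actually passed. A minimal sketch of that mapping (the rules here are made up for illustration; real rules come from config.logical_axis_rules, and MaxText's logical_to_mesh_axes additionally drops size-1 mesh axes):

  from jax.sharding import PartitionSpec as P

  rules = {"embed": "fsdp", "mlp": "tensor"}  # hypothetical logical-name -> mesh-axis rules
  logical_spec = P("embed", "mlp")            # the kind of spec get_weight_sharding yields
  physical_spec = P(*(rules.get(axis) for axis in logical_spec))  # -> spec over mesh axes: ('fsdp', 'tensor')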
dense_layer = RemattedBlockLayers[0] @@ -750,9 +750,9 @@ def __call__( in_axes_tuple=(nn.broadcast,) * len(broadcast_args), model_mode=model_mode, )(y, *broadcast_args) - y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) + y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec) else: # Not DeepSeek - y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) + y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec) remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers if remaining_layers > 0: logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) diff --git a/src/MaxText/layers/models.py b/src/MaxText/layers/models.py index d79e121412..331941e13c 100644 --- a/src/MaxText/layers/models.py +++ b/src/MaxText/layers/models.py @@ -35,7 +35,6 @@ from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen from MaxText.layers.quantizations import AqtQuantization as Quant from MaxText.layers.multi_token_prediction import multi_token_prediction_block_as_linen -from MaxText.maxtext_utils import all_gather_over_fsdp # ------------------------------------------------------------------------------ # The network: Transformer Definitions @@ -517,120 +516,3 @@ def __call__( return hidden_state, kv_caches return logits - - -class ZeroOneTransformer(nn.Module): - """ - A wrapper for the base Transformer model designed to implement the Zero-1 - FSDP optimization. - - The goal of this optimization is to reduce communication overhead. In the standard - FSDP implementation, an all-gather operation on the model weights is performed twice - for each gradient accumulation microbatch (once for the forward pass, once for the backward pass). - This class changes that behavior. When enabled, it performs the all-gather operation - only *once* per full gradient accumulation step. It gathers the full weights into - memory, runs all the microbatch forward and backward passes, and then releases the - full weights. This trades higher peak memory usage for significantly reduced - network communication, which can improve training speed if sufficient memory is - available. - """ - - config: Config - mesh: Mesh - quant: Quant - # Possible model_mode values can be found in MaxText.common_types. - # We generally use MaxText.common_types.MODEL_MODE_TRAIN or - # MaxText.common_types.MODEL_MODE_PREFILL for initializations here. - # TODO: Make model_mode required after confirming no users are affected. - model_mode: str = MODEL_MODE_TRAIN # May be different than the model_mode passed to __call__ - - def setup(self): - """Sets up the underlying Transformer model. - - This method initializes the `self.model` attribute by calling the - `transformer_as_linen` factory function. 
- """ - self.model = transformer_as_linen(self.config, self.mesh, self.quant, self.model_mode) - - def __call__( - self, - decoder_input_tokens: jnp.ndarray, - decoder_positions: jnp.ndarray, - decoder_segment_ids=None, - encoder_images: None | jnp.ndarray = None, - encoder_image_masks: None | jnp.ndarray = None, - enable_dropout=True, - model_mode=MODEL_MODE_TRAIN, - previous_chunk=None, - true_length: None | int = None, - slot: None | int = None, - page_state: None | page_manager.PageState = None, - partition_spec=None, - decoder_target_tokens: None | jnp.ndarray = None, - decoder_target_mask: None | jnp.ndarray = None, - nnx_method: str | None = None, - ): - """Applies the Zero-1 FSDP wrapped Transformer model. - - This method handles the all-gather operation for model weights before - applying the underlying Transformer model, and then releases them. - - Args: - decoder_input_tokens: Input tokens for the decoder. - decoder_positions: Positional encodings for the decoder inputs. - decoder_segment_ids: Segment IDs for the decoder inputs (optional). - encoder_images: Encoder images for multimodal models (optional). - enable_dropout: Whether to enable dropout. Defaults to True. - previous_chunk: Previous chunk for incremental decoding (optional). - true_length: True length of the prompt before padding (optional). - slot: An integer representing the decode batch index selected for this - request (optional). - page_state: Page state for paged attention (optional). - partition_spec: Partition specification for FSDP all-gather. - decoder_target_tokens: Target tokens for the decoder (optional, used in - MTP). - decoder_target_mask: Target mask for the decoder (optional, used in MTP). - nnx_method: Method to call on the NNX module (optional). - - Returns: - Logits from the Transformer model. 
- """ - if self.is_initializing(): - return self.model( - decoder_input_tokens=decoder_input_tokens, - decoder_positions=decoder_positions, - decoder_segment_ids=decoder_segment_ids, - encoder_images=encoder_images, - encoder_image_masks=encoder_image_masks, - enable_dropout=enable_dropout, - model_mode=model_mode, - previous_chunk=previous_chunk, - true_length=true_length, - slot=slot, - page_state=page_state, - ) - all_model_weights = all_gather_over_fsdp( - self.model.variables, - partition_spec, - mesh=self.mesh, - logical_axis_rules=self.config.logical_axis_rules, - ) - - return self.model.apply( - all_model_weights, - decoder_input_tokens=decoder_input_tokens, - decoder_positions=decoder_positions, - decoder_segment_ids=decoder_segment_ids, - encoder_images=encoder_images, - encoder_image_masks=encoder_image_masks, - enable_dropout=enable_dropout, - model_mode=model_mode, - previous_chunk=previous_chunk, - true_length=true_length, - slot=slot, - page_state=page_state, - mutable=False, - decoder_target_tokens=decoder_target_tokens, - decoder_target_mask=decoder_target_mask, - nnx_method=nnx_method, - ) diff --git a/src/MaxText/layers/pipeline.py b/src/MaxText/layers/pipeline.py index c7284fb22c..795f677b9a 100644 --- a/src/MaxText/layers/pipeline.py +++ b/src/MaxText/layers/pipeline.py @@ -20,15 +20,22 @@ import numpy as np from jax import numpy as jnp -from jax.sharding import Mesh +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P import jax import jax.ad_checkpoint from flax.core import meta from flax import linen as nn +from flax.linen.spmd import LogicallyPartitioned -from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT -from MaxText.sharding import all_gather_over_fsdp +from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT, ShardMode +from MaxText.sharding import ( + maybe_shard_with_logical, + maybe_shard_with_name, + create_sharding, + logical_to_mesh_axes, + logical_to_mesh, +) class Pipeline(nn.Module): @@ -68,6 +75,39 @@ def setup(self): # pylint: disable=missing-function-docstring self.batch_axis_name = "activation_batch" self.seq_len_axis_name = "activation_length_no_exp" + # TODO(b/470167805): replace self.spmd_axis_name with "stage" when JAX >= 0.8.2. 
+ self.spmd_axis_name = "stage" if self.config.shard_mode == ShardMode.AUTO else None + + self.stages_in_logical = ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed") + self.stages_in_spec = logical_to_mesh_axes(self.stages_in_logical, self.mesh, rules=self.config.logical_axis_rules) + self.stages_in_sharding = ( + NamedSharding(self.mesh, self.stages_in_spec) if self.config.shard_mode == ShardMode.EXPLICIT else None + ) + + self.state_io_logical = ("activation_stage", None, self.batch_axis_name, self.seq_len_axis_name, "activation_embed") + self.state_io_spec = logical_to_mesh_axes(self.state_io_logical, self.mesh, rules=self.config.logical_axis_rules) + self.state_io_sharding = ( + NamedSharding(self.mesh, self.state_io_spec) if self.config.shard_mode == ShardMode.EXPLICIT else None + ) + self.input_sharding = ( + create_sharding( + self.mesh, + (None, self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), + rules=self.config.logical_axis_rules, + ) + if self.config.shard_mode == ShardMode.EXPLICIT + else None + ) + self.output_sharding = ( + create_sharding( + self.mesh, + (self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), + rules=self.config.logical_axis_rules, + ) + if self.config.shard_mode == ShardMode.EXPLICIT + else None + ) + def need_circ_storage(self): return ( self.config.num_pipeline_repeats > 1 @@ -85,6 +125,20 @@ def iterations_to_complete_first_microbatch(self): + self.iterations_to_complete_first_microbatch_one_repeat() ) + def _maybe_shard_with_logical(self, inputs, logical_axes): + """Wrapper of maybe_shard_with_logical""" + return maybe_shard_with_logical( + inputs, + logical_axes, + shard_mode=self.config.shard_mode, + mesh=self.mesh, + rules=self.config.logical_axis_rules, + ) + + def _maybe_shard_with_name(self, inputs, sharding_name): + """Wrapper of maybe_shard_with_name""" + return maybe_shard_with_name(inputs, sharding_name, shard_mode=self.config.shard_mode) + def init_states(self, inputs): """Initialize components of state: state_io, shift, circular_storage and circular_storage_mover Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed] @@ -101,37 +155,23 @@ def init_states(self, inputs): # Shift is used to rotate the output of each pipeline into the input of the next # shift has shape [num_stages, micro_size, sequence, embed] shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) - - shift = nn.with_logical_constraint( - shift, - ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), - rules=self.config.logical_axis_rules, - mesh=self.mesh, - ) + shift = self._maybe_shard_with_logical(shift, self.stages_in_logical) # Prev outputs has the same shape of the output (and shift) if self.config.pipeline_delay_activation_forwarding: prev_outputs = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) - prev_outputs = nn.with_logical_constraint( - prev_outputs, - ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), - rules=self.config.logical_axis_rules, - mesh=self.mesh, - ) + prev_outputs = self._maybe_shard_with_logical(prev_outputs, self.stages_in_logical) else: prev_outputs = None # state_io (state input output) at first holds all of the input batches, but also will hold the outputs # as the pipeline runs/finishes # state_io has shape [num_stages, microbatches/stages, micro_size, sequence, embed] - state_io = jnp.reshape(inputs, 
(self.num_stages, self.microbatches_per_stage) + inputs.shape[1:]) - # We shard the pipeline_microbatch_size axis by data/fsdp, not num_microbatches since those are looped over. - state_io = nn.with_logical_constraint( - state_io, - ("activation_stage", None, self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), - rules=self.config.logical_axis_rules, - mesh=self.mesh, + state_io = jnp.reshape( + inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:], out_sharding=self.state_io_sharding ) + # We shard the pipeline_microbatch_size axis by data/fsdp, not num_microbatches since those are looped over. + state_io = self._maybe_shard_with_logical(state_io, self.state_io_logical) # circ_storage is used to hold the final pipeline stage outputs before it is used for the next repeat. It is only # needed when num_microbatches > num_stages, else instead the final stage will immediately pass to the first without @@ -143,7 +183,7 @@ def init_states(self, inputs): # TP, DP (e.g. there are many devices that shard stage 0) # We may look into alternatives using less storage if this becomes an issue (ideas in b/347603101). if self.use_circ_storage: - circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype) + circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype, out_sharding=self.state_io_sharding) else: circ_storage = None @@ -175,6 +215,7 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): # Setup potential input from state_io, which has a rotating microbatch index (size of microbatches_per_stage) state_io_batch_idx = loop_iteration % self.microbatches_per_stage state_io_slice = state_io[:, state_io_batch_idx] + shift = self._maybe_shard_with_logical(shift, self.stages_in_logical) if self.use_circ_storage: # Setup potential input from circ_storage, which also has a rotating index for microbatch, @@ -189,6 +230,7 @@ def get_iteration_inputs(self, loop_iteration, state_io, circ_storage, shift): # state_io we instead grab from the last stage's output (possibly buffered when num_microbatches > num_stages, e.g. # from circ_storage). first_stage_in = jnp.where(loop_iteration < self.config.num_pipeline_microbatches, state_io_slice, circular_stage_in) + first_stage_in = self._maybe_shard_with_logical(first_stage_in, self.stages_in_logical) # Note that first_stage_in may correspond to bubble computation during the last few iterations. 
     # However, these bubble computation results remain in the shift buffer (do not make it back to state_io) and are
@@ -198,26 +240,40 @@ def select_state_or_input(first_stage_in, shift):
       # Selects input for stage 0, shift for other stages
-      return jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_stage_in, shift)
+      return jnp.where(
+          jax.lax.broadcasted_iota("int32", shift.shape, 0, out_sharding=self.stages_in_sharding) == 0,
+          first_stage_in,
+          shift,
+      )

     # Selects input (from stream_io) for stage 0, other stages get from shift (the rotated previous output)
     stages_in = select_state_or_input(first_stage_in, shift)
-    stages_in = nn.with_logical_constraint(
-        stages_in,
-        ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"),
-        rules=self.config.logical_axis_rules,
-        mesh=self.mesh,
-    )
+    stages_in = self._maybe_shard_with_logical(stages_in, self.stages_in_logical)
     return stages_in

-  def shard_dim_by_stages(self, x, dim: int):
-    # Shards a dimension by stages. Currently, the sharding of other dimensions are left up the compiler, alternatively
-    # we may want to copy over the sharding from the other input axes.
-    dims_mapping = [jax.sharding.PartitionSpec.UNCONSTRAINED] * x.ndim
+  def shard_dim_by_stages(self, x, dim: int, physical_partition_spec: P | None, is_stage_weight: bool = False):
+    """Shards x using the provided partition_spec, but adds the "stage" mesh axis to the existing sharding at
+    the specified dimension."""
+    placeholder = None if self.config.shard_mode == ShardMode.EXPLICIT else P.UNCONSTRAINED
+    if physical_partition_spec is None:
+      dims_mapping = [placeholder] * x.ndim
+    else:
+      physical_partition_spec = self._remove_fsdp_from_physical_partition_spec(physical_partition_spec)
+      dims_mapping = list(physical_partition_spec)
+      # If not a stage weight, we handle the repeat dimension offset
+      if not is_stage_weight:
+        dims_mapping = [placeholder] * (dim + 1) + dims_mapping[dim:]  # inflate one dimension for num_repeats
     dims_mapping[dim] = "stage"
     dims_mapping = tuple(dims_mapping)
-    sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(*dims_mapping))
-    return jax.lax.with_sharding_constraint(x, sharding)
+    # We add the reduced rule only when a pspec is given for a stage weight
+    if physical_partition_spec and is_stage_weight and self.config.shard_mode == ShardMode.EXPLICIT:
+      batch_mesh_axis = ["data", "fsdp"]
+      reduced_mark = [mesh_axis for mesh_axis in batch_mesh_axis if self.mesh.shape[mesh_axis] > 1]
+      pspec = P(*dims_mapping, reduced=set(reduced_mark))
+    else:
+      pspec = P(*dims_mapping)
+    sharding = jax.sharding.NamedSharding(self.mesh, pspec)
+    return self._maybe_shard_with_name(x, sharding)

   def get_microbatch_and_repeat_ids(self, loop_iteration):
     """Gets the microbatch_ids and repeat_ids for all stages on this loop_iteration. Works for both circular and
@@ -228,15 +284,19 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
     repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
     return microbatch_ids, repeat_ids

-  def vmap_parallel_gather(self, weights, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights):
+  def vmap_parallel_gather(
+      self, weights, physical_partition_spec, repeat_ids, repeat_dim_in_weights, stages_dim_in_weights
+  ):
     """Use vmap to implement a sharded parallel gather.

     Parallel gather means each stage has its own weights, and gets one slice from it.
Args: weights: Per-stage data to be gathered from. + physical_partition_spec: Physical partition spec of the input weight. repeat_ids: Integer tensor of shape [num_stages], the repeats of the stages. repeat_dim_in_weights: The dimension in weights where repeat_ids are applied. The output will not have this dimension. stages_dim_in_weights: The dimension in weights that represents parallel stages. + Returns: The per-stage gathered values. The shape is weights.shape but with repeat_dim_in_weights removed. @@ -246,12 +306,18 @@ def _gather_one(x, repeat_id): return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, repeat_id, 1, repeat_dim_in_weights), repeat_dim_in_weights) gathered_weights_stage_dim = 0 - repeat_ids = self.shard_dim_by_stages(repeat_ids, 0) - weights = self.shard_dim_by_stages(weights, stages_dim_in_weights) + repeat_ids = self.shard_dim_by_stages(repeat_ids, 0, physical_partition_spec=None) + # num_repeats x num_stages x *param_dim + weights = self.shard_dim_by_stages( + weights, stages_dim_in_weights, physical_partition_spec=physical_partition_spec, is_stage_weight=False + ) stage_weights = jax.vmap(_gather_one, in_axes=(stages_dim_in_weights, 0), out_axes=gathered_weights_stage_dim)( weights, repeat_ids ) - stage_weights = self.shard_dim_by_stages(stage_weights, gathered_weights_stage_dim) + # num_stages x *param_dim + stage_weights = self.shard_dim_by_stages( + stage_weights, gathered_weights_stage_dim, physical_partition_spec=physical_partition_spec, is_stage_weight=True + ) return stage_weights def vmap_gather(self, xs, ids, ids_dim): @@ -271,11 +337,13 @@ def vmap_gather(self, xs, ids, ids_dim): """ def _gather_one(x, i): - return jnp.squeeze(jax.lax.dynamic_slice_in_dim(x, i, 1, ids_dim), ids_dim) + idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim)) + replicated_sharding = NamedSharding(self.mesh, P()) + return x.at[idx].get(out_sharding=replicated_sharding) - ids = self.shard_dim_by_stages(ids, 0) + ids = self.shard_dim_by_stages(ids, 0, physical_partition_spec=None) outs = jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) - return self.shard_dim_by_stages(outs, 0) + return self.shard_dim_by_stages(outs, 0, physical_partition_spec=None) def get_new_loop_state(self, output, loop_state): """ @@ -296,16 +364,31 @@ def get_new_loop_state(self, output, loop_state): loop_iteration = loop_state["loop_iteration"] old_prev_outputs = loop_state["prev_outputs"] + @jax.shard_map( + mesh=self.mesh, + in_specs=self.stages_in_spec, + out_specs=self.stages_in_spec, + check_vma=True, + ) def _rotate_right(arr): - # Use lax.slice to avoid generating a gather. - last = jax.lax.slice_in_dim(arr, self.num_stages - 1, self.num_stages, axis=0) - except_last = jax.lax.slice_in_dim(arr, 0, self.num_stages - 1, axis=0) - return jnp.concatenate([last, except_last], axis=0) + # we use +1 for right shifting + stage_size = jax.lax.axis_size("stage") + perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] + arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) + return arr + @jax.shard_map( + mesh=self.mesh, + in_specs=self.stages_in_spec, + out_specs=self.stages_in_spec, + check_vma=True, + ) def _shift_right(arr): - padding = [[1, 0]] + [[0, 0]] * (arr.ndim - 1) - # Use lax.slice to guarantee the gradient is a pad. 
- return jax.lax.slice(jnp.pad(arr, padding), [0] * arr.ndim, arr.shape) + stage_idx = jax.lax.axis_index("stage") + stage_size = jax.lax.axis_size("stage") + perm = [(i, (i + 1) % stage_size) for i in range(stage_size)] + arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) + return jnp.where(stage_idx == 0, jnp.zeros_like(arr), arr) # Shift either rotates or shifts depending on if the last stage immediately must send to first or not # For non-circular pipelines, the last stage does not need to send to first @@ -349,17 +432,30 @@ def _rotate_right_and_update(circ_storage_mover_in, circ_storage_in): stream_buf_idx = loop_iteration % self.microbatches_per_stage stream_slice = old_state_io[:, stream_buf_idx] - def _update_state_io(state_in, stream_slice, output): + def _rotate_left(arr, stage_size): + # we use -1 for left shifting + perm = [(i, (i - 1) % stage_size) for i in range(stage_size)] + arr = jax.lax.ppermute(arr, axis_name="stage", perm=perm) + return arr + + def _shift_left(arr, stage_size, output): + stage_idx = jax.lax.axis_index("stage") + arr = _rotate_left(arr, stage_size) + return jnp.where(stage_idx == stage_size - 1, output, arr) + + @jax.shard_map( + mesh=self.mesh, + in_specs=(self.state_io_spec, self.stages_in_spec, self.stages_in_spec, P()), + out_specs=self.state_io_spec, + ) + def _update_state_io(state_in, stream_slice, output, stream_buf_idx): # Shift the current slice to the left, then fill the last stage with the final output. - padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) - stream_slice = jax.lax.slice_in_dim(jnp.pad(stream_slice, padding), 1, stream_slice.shape[0] + 1, axis=0) - stream_slice = jnp.where( - jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice - ) + stage_size = jax.lax.axis_size("stage") + stream_slice = _shift_left(stream_slice, stage_size, output) stream_slice = jnp.expand_dims(stream_slice, 1) return jax.lax.dynamic_update_slice_in_dim(state_in, stream_slice, stream_buf_idx, axis=1) - new_state = _update_state_io(old_state_io, stream_slice, output) + new_state = _update_state_io(old_state_io, stream_slice, output, stream_buf_idx) new_loop_state = { "state_io": new_state, @@ -382,7 +478,7 @@ def permute_output_micro_per_stage_dim(self, output): output = output[:, permutation] return output - def get_current_stage_weights(self, pipeline_weights, loop_iteration): + def get_current_stage_weights(self, pipeline_weights, loop_iteration, physical_partition_spec=None): """ Gets the current weights used for one iteration. Outputs a pytree whose arrays have leading dimension of stages, e.g. {'mlp': 'wo': [stages, mlp, embed]}. Stage 0 will use the 0th index of this pytree, Stage 1 the 1st index, etc. @@ -390,22 +486,16 @@ def get_current_stage_weights(self, pipeline_weights, loop_iteration): for circular pipelines each stage grabs only the weights corresponding to the current repeat. 
""" if self.config.num_pipeline_repeats > 1: - return self.get_current_repeat_from_stages(pipeline_weights, loop_iteration) + return self.get_current_repeat_from_stages( + pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec + ) else: return pipeline_weights - def get_current_repeat_from_stages(self, weights, loop_iteration): + def get_current_repeat_from_stages(self, weights, loop_iteration, physical_partition_spec=None): """get current repeat from stages""" _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) - def gather_weights_for_stages_in(weights): - return jax.tree.map( - functools.partial( - self.vmap_parallel_gather, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1 - ), - weights, - ) - circular_metadata_params = { nn.PARTITION_NAME: "circular_repeats", "sub_weight_split_dims_mapping": (None,), @@ -417,7 +507,21 @@ def gather_weights_for_stages_in(weights): weights, 0, circular_metadata_params ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one circular # entry per stage. - weights = gather_weights_for_stages_in(weights) + weights = self._remove_logically_partition(weights) + + def gather_weights_for_stages_in(w, spec=None): + return self.vmap_parallel_gather( + w, + repeat_ids=repeat_ids, + repeat_dim_in_weights=0, + stages_dim_in_weights=1, + physical_partition_spec=spec, + ) + + if physical_partition_spec is None: + weights = jax.tree.map(gather_weights_for_stages_in, weights) + else: + weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) return weights def get_vmap_func_for_init(self): @@ -430,7 +534,7 @@ def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positi vmap_func = nn.vmap( func_to_vmap, in_axes=(0, 0, 0, None, None), - spmd_axis_name="stage", + spmd_axis_name=self.spmd_axis_name, # TODO(b/470167805): replace self.spmd_axis_name with "stage" when JAX >= 0.8.2. variable_axes={"params": 0, "_overwrite_with_gradient": 0}, split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, metadata_params={ @@ -468,7 +572,7 @@ def func_to_vmap( vmap_func = nn.vmap( func_to_vmap, in_axes=(0, 0, 0, 0, None, None), - spmd_axis_name="stage", + spmd_axis_name=self.spmd_axis_name, # TODO(b/470167805): replace self.spmd_axis_name with "stage" when JAX >= 0.8.2. variable_axes={"params": 0}, split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout}, metadata_params={ @@ -481,7 +585,15 @@ def func_to_vmap( return vmap_func def run_one_iteration( - self, loop_state, pipeline_weights, positions, segment_ids, deterministic, model_mode, decoder_layer_instance + self, + loop_state, + pipeline_weights, + positions, + segment_ids, + deterministic, + model_mode, + decoder_layer_instance, + logical_partition_spec=None, ): """Run one loop iteration - gets weights and inputs for each stage, run the stages in parallel, and update the loop state.""" @@ -492,6 +604,9 @@ def run_one_iteration( microbatch_ids, _ = self.get_microbatch_and_repeat_ids(loop_iteration) + # Convert logical spec to physical spec + physical_partition_spec = logical_to_mesh(logical_partition_spec, self.mesh, rules=self.config.logical_axis_rules) + stages_inputs = self.get_iteration_inputs(loop_iteration, state_io, circ_storage, shift) # We checkpoint stages_inputs since we are grabbing only one slice of the state_io, don't need to save the entire # buffer. 
@@ -504,14 +619,7 @@ def run_one_iteration( if self.config.num_pipeline_repeats > 1: _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration) - def prepare_vars_for_main_vmap(weights): - def gather_weights_for_stages_in(weights): - return jax.tree.map( - functools.partial( - self.vmap_parallel_gather, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1 - ), - weights, - ) + def prepare_vars_for_main_vmap(weights, physical_partition_spec=None): circular_metadata_params = { nn.PARTITION_NAME: "circular_repeats", @@ -524,17 +632,33 @@ def gather_weights_for_stages_in(weights): weights, 0, circular_metadata_params ) # Remove the circular metadata axis, this axis will be removed when passed to the main vmap, only one # circular entry per stage. - weights = gather_weights_for_stages_in(weights) + weights = self._remove_logically_partition(weights) + + def gather_weights_for_stages_in(w, spec=None): + return self.vmap_parallel_gather( + w, repeat_ids=repeat_ids, repeat_dim_in_weights=0, stages_dim_in_weights=1, physical_partition_spec=spec + ) + + if physical_partition_spec is None: + weights = jax.tree.map(gather_weights_for_stages_in, weights) + else: + weights = jax.tree.map(gather_weights_for_stages_in, weights, physical_partition_spec) return weights + prepare_vars_for_main_vmap_partial = functools.partial( + prepare_vars_for_main_vmap, physical_partition_spec=physical_partition_spec + ) vmap_func = nn.map_variables( vmap_func, mapped_collections=["params", "_overwrite_with_gradient", "non_trainable", "summaries", "intermediates"], mutable=True, - trans_in_fn=prepare_vars_for_main_vmap, + trans_in_fn=prepare_vars_for_main_vmap_partial, ) - stage_weights = self.get_current_stage_weights(pipeline_weights, loop_iteration) + stage_weights = self.get_current_stage_weights( + pipeline_weights, loop_iteration, physical_partition_spec=physical_partition_spec + ) + stages_output = vmap_func( decoder_layer_instance, stage_weights, @@ -582,51 +706,68 @@ def get_partition_spec_leaf(leaf): return partition_spec_tree partition_spec_with_extra_layer = get_partition_spec(weights) - partition_spec = {"params": partition_spec_with_extra_layer["params"]["layers"]} - return partition_spec + logical_partition_spec = {"params": partition_spec_with_extra_layer["params"]["layers"]} + return logical_partition_spec - def get_physical_spec_no_fsdp(self, full_logical): - """ - Get physical spec without fsdp. + @staticmethod + def get_logical_spec_repeats_removed(full_logical): + if full_logical is None: + return None - TODO: Remove the expert sharding on attention weights as well, since those act like fsdp. 
+ def _remove_from_spec(spec): + return jax.sharding.PartitionSpec(*[dim for dim in spec if dim != "circular_repeats"]) - Args: - full_logical: original logical partition specs of all weights + return jax.tree.map(_remove_from_spec, full_logical) - Returns: - Modified physical spec with "fsdp" and "fsdp_transpose" removed - """ + # TODO(chengnuojin) Remove this function and its usage after pipeline nnx migration + @staticmethod + def _remove_logically_partition(weights): + def _remove_logically_partition_leaf(v): + return getattr(v, "value") if isinstance(v, LogicallyPartitioned) else v + + return jax.tree.map( + _remove_logically_partition_leaf, + weights, + is_leaf=lambda v: isinstance(v, LogicallyPartitioned), + ) - def remove_fsdp_sharding(sharding_tree): - def _remove_fsdp_from_partition_spec(named_sharding): - if isinstance(named_sharding, jax.sharding.NamedSharding): - new_spec = [] - for axis in named_sharding.spec: - if axis is None: - new_spec.append(None) - elif isinstance(axis, str): - if axis not in ("fsdp", "fsdp_transpose"): - new_spec.append(axis) - else: - new_spec.append(None) - elif isinstance(axis, (list, tuple)): - new_axis = [a for a in axis if a not in ("fsdp", "fsdp_transpose")] - new_spec.append(tuple(new_axis)) - else: - raise ValueError(f"Unsupported axis type: {type(axis)}") - return jax.sharding.NamedSharding(named_sharding.mesh, jax.sharding.PartitionSpec(*new_spec)) - return named_sharding - - return jax.tree.map(_remove_fsdp_from_partition_spec, sharding_tree) - - physical = nn.logical_to_mesh_sharding(full_logical, mesh=self.mesh, rules=self.config.logical_axis_rules) - physical_no_fsdp = remove_fsdp_sharding(physical) - return physical_no_fsdp - - def all_gather_over_fsdp(self, sharding_info): - physical_constraint_no_fsdp = self.get_physical_spec_no_fsdp(sharding_info) - return jax.lax.with_sharding_constraint(self.layers.variables, physical_constraint_no_fsdp) + @staticmethod + def _remove_fsdp_from_physical_partition_spec(pps): + """Removes 'fsdp' and 'fsdp_transpose' from a physical PartitionSpec.""" + if isinstance(pps, P): + new_spec = [] + # Iterate through each axis in the original PartitionSpec. + for axis in pps: + if axis is None: + new_spec.append(None) + elif isinstance(axis, str): + # If the axis is 'fsdp', replace it with None to signify replication. + if axis not in ("fsdp", "fsdp_transpose"): + new_spec.append(axis) + else: + new_spec.append(None) + elif isinstance(axis, (list, tuple)): + # If the axis is a collection, filter out 'fsdp'. + new_axis = [a for a in axis if a not in ("fsdp", "fsdp_transpose")] + new_spec.append(tuple(new_axis)) + else: + raise ValueError(f"Unsupported_axis_type: {type(axis)}") + # Return a new sharding object with the modified spec. 
+ return P(*new_spec) + return pps + + def all_gather_over_fsdp(self, variables, logical_partition_spec): + physical_partition_spec = logical_to_mesh( + logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules + ) + physical_partition_spec_no_fsdp = jax.tree.map( + self._remove_fsdp_from_physical_partition_spec, physical_partition_spec + ) + return jax.tree.map( + lambda w, p: self._maybe_shard_with_name(w, NamedSharding(self.mesh, p)), + variables, + physical_partition_spec_no_fsdp, + ) @nn.compact def __call__( @@ -636,7 +777,7 @@ def __call__( positions: jnp.ndarray, deterministic: bool, model_mode=MODEL_MODE_TRAIN, - partition_spec=None, # Pytree of sharding specifications of the weights (aka self.layers.variables) + logical_partition_spec=None, # Pytree of sharding specifications of the weights (aka self.layers.variables) ) -> jnp.ndarray: """The main method that maps the series of decoder layer inputs to final layer outputs. Has the same signature of a single decoder layer, and expects the same shapes, e.g. the inputs should have shape @@ -649,14 +790,15 @@ def __call__( self.pipeline_microbatch_size, self.config.max_target_length, self.config.emb_dim, - ) + ), + out_sharding=self.input_sharding, ) + example_inputs = jax.lax.broadcast(inputs[0], [self.num_stages]) # dummy inputs fed to initialize the module - # weights. ag_sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(None, None)) if positions is not None: # AG positions - positions = jax.lax.with_sharding_constraint(positions, ag_sharding) + positions = self._maybe_shard_with_name(positions, ag_sharding) positions = positions.reshape( (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) @@ -667,7 +809,7 @@ def __call__( example_position = None position_idx = None if segment_ids is not None: - segment_ids = jax.lax.with_sharding_constraint(segment_ids, ag_sharding) + segment_ids = self._maybe_shard_with_name(segment_ids, ag_sharding) segment_ids = segment_ids.reshape( (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) ) @@ -731,6 +873,7 @@ def __call__( ) # We only need to run one set of stages to initialize the variables, instead of looping over all microbatches for # the full total_iterations. 
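To make the gather-once path concrete: _remove_fsdp_from_physical_partition_spec above (and remove_fsdp_sharding in sharding.py further down) rewrite a physical spec so the fsdp axes become replicated, and constraining the weights to that layout is what yields a single all-gather per gradient step rather than one per microbatch. A condensed stand-alone equivalent of the rule, applied to a toy spec (illustrative only):

  from jax.sharding import PartitionSpec as P

  def strip_fsdp(spec):
    # Drop "fsdp"/"fsdp_transpose" from every entry; keep all other mesh axes unchanged.
    drop = ("fsdp", "fsdp_transpose")
    new_axes = []
    for axis in spec:
      if isinstance(axis, (list, tuple)):
        new_axes.append(tuple(a for a in axis if a not in drop))
      else:
        new_axes.append(None if axis in drop else axis)
    return P(*new_axes)

  print(strip_fsdp(P(("fsdp", "tensor"), "fsdp", "sequence")))  # dim 0: ('tensor',), dim 1: replicated, dim 2: 'sequence'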
+ example_inputs = self._maybe_shard_with_logical(example_inputs, (None, None, None, None)) stage_outputs = vmap_func( self.layers, example_inputs, example_segmentation, example_position, deterministic, model_mode ) @@ -744,24 +887,34 @@ def __call__( broadcasted_stage_outpus = jax.lax.broadcast( stage_outputs[0], [self.config.micro_batch_size_to_train_on // self.pipeline_microbatch_size] ) + return jnp.reshape( broadcasted_stage_outpus, [self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim], + out_sharding=self.output_sharding, ) if self.config.pipeline_fsdp_ag_once: - all_pipeline_weights = all_gather_over_fsdp( - self.layers.variables, partition_spec, mesh=self.mesh, logical_axis_rules=self.config.logical_axis_rules - ) + variables = self._remove_logically_partition(self.layers.variables) + all_pipeline_weights = self.all_gather_over_fsdp(variables, logical_partition_spec) else: all_pipeline_weights = self.layers.variables + logical_partition_spec = self.get_logical_spec_repeats_removed(logical_partition_spec) + def run_iteration_scannable(model, loop_state, xs): # flax transforms like nn.scan and nn.remat can only be applied to nn.module classes or nn.module instances, so we # explicitly wrap the run_one_iteration in this method - the 1st argument model (`self`) is a nn.module instance. return ( model.run_one_iteration( - loop_state, all_pipeline_weights, positions, segment_ids, deterministic, model_mode, model.layers + loop_state, + all_pipeline_weights, + positions, + segment_ids, + deterministic, + model_mode, + model.layers, + logical_partition_spec=logical_partition_spec, ), None, ) @@ -809,7 +962,9 @@ def run_iteration_scannable(model, loop_state, xs): # reshape outputs to match input shape of total batch instead of microbatches [batch, sequence, embed] final_output = jnp.reshape( - final_output, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim) + final_output, + (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim), + out_sharding=self.output_sharding, ) return final_output diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index dd62544bc6..f6f3460a2f 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -76,11 +76,11 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): return sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) -def all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules): +def all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules, shard_mode): max_logging.log( "WARNING: Function maxtext_utils.all_gather_over_fsdp is deprecated. Please use sharding.all_gather_over_fsdp." 
) - return sharding.all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules) + return sharding.all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules, shard_mode) def get_functional_train_with_signature( diff --git a/src/MaxText/model_creation_utils.py b/src/MaxText/model_creation_utils.py index 816f83815f..dcb453caae 100644 --- a/src/MaxText/model_creation_utils.py +++ b/src/MaxText/model_creation_utils.py @@ -96,16 +96,10 @@ def from_config( def get_transformer_model(config, mesh, quant, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rngs | None = None): """Returns the transformer model based on the configuration.""" - if config.model_fsdp_ag_once: - if rngs is not None: - raise NotImplementedError - else: - return models.ZeroOneTransformer(config, mesh, quant=quant, model_mode=model_mode) + if rngs is not None: + return models.Transformer(config, mesh, quant=quant, rngs=rngs, model_mode=model_mode) else: - if rngs is not None: - return models.Transformer(config, mesh, quant=quant, rngs=rngs, model_mode=model_mode) - else: - return models.transformer_as_linen(config, mesh, quant=quant, model_mode=model_mode) + return models.transformer_as_linen(config, mesh, quant=quant, model_mode=model_mode) def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rngs | None = None): diff --git a/src/MaxText/sharding.py b/src/MaxText/sharding.py index 0126489740..40b2a379cf 100644 --- a/src/MaxText/sharding.py +++ b/src/MaxText/sharding.py @@ -83,6 +83,28 @@ def logical_to_mesh_axes(logical_names, mesh, rules=None): return remove_size_one_mesh_axis(tensor_spec, mesh) +def logical_to_mesh(tree, mesh, rules=None): + """Remove size one mesh axes given logical pspec pytree.""" + if tree is None: + return None + return jax.tree.map( + lambda x: logical_to_mesh_axes(x, mesh, rules=rules), + tree, + is_leaf=lambda x: isinstance(x, P), + ) + + +def logical_to_mesh_sharding(tree, mesh, rules=None): + """Return sharding pytree given logical specs pytree""" + if tree is None: + return None + return jax.tree.map( + lambda x: NamedSharding(mesh, x), + logical_to_mesh(tree, mesh, rules=rules), + is_leaf=lambda x: isinstance(x, P), + ) + + def create_sharding(mesh, logical_names, rules=None): """Create NamedSharding with given logical names.""" return NamedSharding(mesh, logical_to_mesh_axes(logical_names, mesh, rules=rules)) @@ -460,6 +482,36 @@ def get_formatted_sharding_annotations(params, mesh=None): return "\n".join(annotation_lines) +def remove_fsdp_sharding(sharding_tree): + """Recursively traverses the sharding tree to remove fsdp axes.""" + + def _remove_fsdp_from_partition_spec(named_sharding): + """Removes 'fsdp' and 'fsdp_transpose' from a PartitionSpec.""" + if isinstance(named_sharding, jax.sharding.NamedSharding): + new_spec = [] + # Iterate through each axis in the original PartitionSpec. + for axis in named_sharding.spec: + if axis is None: + new_spec.append(None) + elif isinstance(axis, str): + # If the axis is 'fsdp', replace it with None to signify replication. + if axis not in ("fsdp", "fsdp_transpose"): + new_spec.append(axis) + else: + new_spec.append(None) + elif isinstance(axis, (list, tuple)): + # If the axis is a collection, filter out 'fsdp'. + new_axis = [a for a in axis if a not in ("fsdp", "fsdp_transpose")] + new_spec.append(tuple(new_axis)) + else: + raise ValueError(f"Unsupported_axis_type: {type(axis)}") + # Return a new sharding object with the modified spec. 
+ return jax.sharding.NamedSharding(named_sharding.mesh, jax.sharding.PartitionSpec(*new_spec)) + return named_sharding + + return jax.tree.map(_remove_fsdp_from_partition_spec, sharding_tree) + + def get_physical_spec_no_fsdp(full_logical, mesh, logical_axis_rules): """ Generates a physical sharding spec for fully replicated weights. @@ -484,43 +536,14 @@ def get_physical_spec_no_fsdp(full_logical, mesh, logical_axis_rules): mesh axis. """ - def remove_fsdp_sharding(sharding_tree): - """Recursively traverses the sharding tree to remove fsdp axes.""" - - def _remove_fsdp_from_partition_spec(named_sharding): - """Removes 'fsdp' and 'fsdp_transpose' from a PartitionSpec.""" - if isinstance(named_sharding, jax.sharding.NamedSharding): - new_spec = [] - # Iterate through each axis in the original PartitionSpec. - for axis in named_sharding.spec: - if axis is None: - new_spec.append(None) - elif isinstance(axis, str): - # If the axis is 'fsdp', replace it with None to signify replication. - if axis not in ("fsdp", "fsdp_transpose"): - new_spec.append(axis) - else: - new_spec.append(None) - elif isinstance(axis, (list, tuple)): - # If the axis is a collection, filter out 'fsdp'. - new_axis = [a for a in axis if a not in ("fsdp", "fsdp_transpose")] - new_spec.append(tuple(new_axis)) - else: - raise ValueError(f"Unsupported_axis_type: {type(axis)}") - # Return a new sharding object with the modified spec. - return jax.sharding.NamedSharding(named_sharding.mesh, jax.sharding.PartitionSpec(*new_spec)) - return named_sharding - - return jax.tree.map(_remove_fsdp_from_partition_spec, sharding_tree) - # Convert the high-level logical spec to a physical one using default rules. - physical = nn.logical_to_mesh_sharding(full_logical, mesh=mesh, rules=logical_axis_rules) + physical = logical_to_mesh_sharding(full_logical, mesh=mesh, rules=logical_axis_rules) # Apply the function to remove the FSDP sharding, defining our target layout. physical_no_fsdp = remove_fsdp_sharding(physical) return physical_no_fsdp -def all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules): +def all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules, shard_mode): """Performs an all-gather on FSDP-sharded variables via a sharding constraint. This function triggers an all-gather operation on the model's parameters. It does so by applying a sharding constraint that specifies a fully @@ -535,6 +558,7 @@ def all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules): sharding_info: The logical partition spec of the currently sharded `variables`. mesh: The JAX device mesh. logical_axis_rules: Rules for converting logical axes to physical mesh axes. + shard_mode: auto or explicit shard mode. Returns: The model's variables with the all-gather operation applied, resulting @@ -544,4 +568,4 @@ def all_gather_over_fsdp(variables, sharding_info, mesh, logical_axis_rules): physical_constraint_no_fsdp = get_physical_spec_no_fsdp(sharding_info, mesh, logical_axis_rules) # Apply the constraint to the model's current variables. This tells JAX to # gather the weights into this layout. 
- return jax.lax.with_sharding_constraint(variables, physical_constraint_no_fsdp) + return maybe_shard_with_name(variables, physical_constraint_no_fsdp, shard_mode=shard_mode) diff --git a/src/MaxText/train.py b/src/MaxText/train.py index f5e8cf377b..c98cd9b963 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -261,7 +261,11 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat ) extra_dpo_args = [reference_params] if config.shard_optimizer_over_data: - params = jax.tree.map(jax.lax.with_sharding_constraint, params, params_shardings) + params = jax.tree.map( + functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + params, + params_shardings, + ) grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True) diff --git a/src/MaxText/vocabulary_tiling.py b/src/MaxText/vocabulary_tiling.py index 61345ffe55..ceda62653e 100644 --- a/src/MaxText/vocabulary_tiling.py +++ b/src/MaxText/vocabulary_tiling.py @@ -99,7 +99,7 @@ def _reshape(inputs, out_shape, out_sharding): labels = _maybe_shard_with_name(labels, label_spec) segmentation = _maybe_shard_with_name(segmentation, label_spec) # TODO (chengnuojin) all gather only embedding table instead of all params after NNX module is enabled - gathered_params = all_gather_over_fsdp(params, param_spec, model.mesh, config.logical_axis_rules) + gathered_params = all_gather_over_fsdp(params, param_spec, model.mesh, config.logical_axis_rules, config.shard_mode) # Customized forward and backward maps for the embedding tiling @jax.custom_vjp diff --git a/tests/integration_tests/train_tests.py b/tests/integration_tests/train_tests.py index cc81a8e063..44f1234089 100644 --- a/tests/integration_tests/train_tests.py +++ b/tests/integration_tests/train_tests.py @@ -401,20 +401,10 @@ def test_gpu_cudnn_flash_jax(self): ] train_main(cudnn_flash_jax) - @pytest.mark.integration_test - @pytest.mark.tpu_only - def test_tpu_base_model_ag_once(self): - train_main(TrainTests.CONFIGS["base"] + ["model_fsdp_ag_once=True"]) - @pytest.mark.integration_test def test_base_model_shardy_false(self): train_main(TrainTests.CONFIGS["base"] + ["shardy=False"]) - @pytest.mark.integration_test - @pytest.mark.gpu_only - def test_gpu_synthetic_model_ag_once(self): - train_main(TrainTests.CONFIGS["synthetic"] + ["model_fsdp_ag_once=True"]) - @pytest.mark.integration_test @pytest.mark.gpu_only @pytest.mark.scheduled_only diff --git a/tests/pipeline_parallelism_test.py b/tests/pipeline_parallelism_test.py index 43efb62ca0..aab3aa2274 100644 --- a/tests/pipeline_parallelism_test.py +++ b/tests/pipeline_parallelism_test.py @@ -102,7 +102,7 @@ def get_inputs(batch_size, sequence, features): init_pipeline_params = my_pipeline.init( jax.random.PRNGKey(0), inputs, inputs_position, inputs_segmentation, deterministic, model_mode ) - partition_spec = my_pipeline.get_weight_sharding( + logical_partition_spec = my_pipeline.get_weight_sharding( inputs, inputs_position, inputs_segmentation, deterministic, model_mode ) @@ -115,16 +115,22 @@ def pipeline_parallelism_dummy_loss_extra( deterministic, model_mode, dummy_targets, - partition_spec=None, + logical_partition_spec=None, ): outputs = my_pipeline.apply( - params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode, partition_spec=partition_spec + params, + inputs, + inputs_position, + inputs_segmentation, + deterministic, + model_mode, + 
logical_partition_spec=logical_partition_spec, ) loss = jnp.linalg.norm(outputs - dummy_targets) return loss pipeline_parallelism_dummy_loss = functools.partial( - pipeline_parallelism_dummy_loss_extra, partition_spec=partition_spec + pipeline_parallelism_dummy_loss_extra, logical_partition_spec=logical_partition_spec ) def regular_sequential_layers(params, inputs, inputs_position, inputs_segmentation, deterministic, model_mode):
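A closing usage note on the signature change threaded through sharding.py, maxtext_utils.py and vocabulary_tiling.py: all_gather_over_fsdp now takes the shard mode as its final argument and places the gathered weights through maybe_shard_with_name. In AUTO mode the underlying mechanism is a sharding constraint; a runnable toy demonstration of that idea (not MaxText code; mesh and shapes are made up):

  import jax
  import jax.numpy as jnp
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  mesh = Mesh(jax.devices(), ("fsdp",))
  n = len(jax.devices())
  w = jax.device_put(jnp.ones((2 * n, 4)), NamedSharding(mesh, P("fsdp", None)))  # fsdp-sharded "weight"

  @jax.jit
  def gather_once(w):
    # Constraining to a spec with fsdp removed makes XLA materialize the full weight with one all-gather.
    return jax.lax.with_sharding_constraint(w, NamedSharding(mesh, P(None, None)))

  print(gather_once(w).sharding)  # expect a layout replicated over the fsdp axis

In EXPLICIT shard mode the placement goes through the same maybe_shard_with_name helper, which is expected to request the target sharding directly rather than as a compiler hint.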