3 changes: 0 additions & 3 deletions src/MaxText/configs/base.yml
@@ -261,9 +261,6 @@ pipeline_delay_activation_forwarding: False # This delays the activation forward
# the communication and compute in each iteration are now independent. However this comes at the cost of doubling the pipeline bubble,
# and you must set the number of microbatches to at least 2 * num_stages (the minimum 2 * num_stages is set by default with this delay).

-model_fsdp_ag_once: False # This controls whether the Zero-1 optimization is active.
-# This is a memory/time tradeoff - True: This is Zero-1 Sharding. Use ZeroOneTransformer to gather weights once per gradient step.
-# False: This is Zero-3 Sharding. Use the standard Transformer, which gathers for each microbatch's fwd/bwd pass.
pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights over FSDP before the first pipeline iteration.
# This is a memory/time tradeoff - we now have to store the FSDP gathered weights and gradients (typically in bf16), as opposed
# to only one stage's worth, however we only execute one all-gather and reduce across per repeat, as opposed
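To make the memory/time tradeoff in the removed comment concrete, here is a minimal JAX sketch of the two behaviors (toy loss and function names; `gather` stands in for the FSDP all-gather and is not the MaxText implementation). Zero-3 re-gathers the sharded weights inside every microbatch's forward/backward pass, while Zero-1 gathers once and keeps the full weights resident for the whole gradient-accumulation step.

import jax
import jax.numpy as jnp

def loss_fn(w, batch):
  # Toy regression loss; `w` is the full (gathered) weight matrix.
  x, y = batch
  return jnp.mean((x @ w - y) ** 2)

def zero3_grads(sharded_w, gather, microbatches):
  # Zero-3 style: `gather` sits inside the differentiated function, so an
  # all-gather is issued for every microbatch's forward pass (and the
  # matching reduction for every backward pass). Gradients accumulate in
  # sharded form. `gather` must be differentiable (identity for testing).
  def step(acc, mb):
    g = jax.grad(lambda sw: loss_fn(gather(sw), mb))(sharded_w)
    return acc + g, None
  grads, _ = jax.lax.scan(step, jnp.zeros_like(sharded_w), microbatches)
  return grads

def zero1_grads(sharded_w, gather, microbatches):
  # Zero-1 style: gather once, before the loop; the full weights and a
  # full-size gradient accumulator stay live for the whole gradient step.
  full_w = gather(sharded_w)
  def step(acc, mb):
    g = jax.grad(loss_fn)(full_w, mb)
    return acc + g, None
  grads, _ = jax.lax.scan(step, jnp.zeros_like(full_w), microbatches)
  return grads

Under jit, the Zero-3 variant emits one gather per scan iteration, while the Zero-1 variant hoists a single gather out of the loop at the cost of holding full_w and a full-size gradient accumulator in memory for the entire step.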
6 changes: 0 additions & 6 deletions src/MaxText/configs/types.py
@@ -1493,10 +1493,6 @@ class DerivedValues(BaseModel):
None,
description="Boolean flag indicating if pipeline parallelism is active across ICI or DCN.",
)
-model_fsdp_ag_once: bool = Field(
-False,
-description="An alias for `pipeline_fsdp_ag_once` for backward compatibility.",
-)

context_parallel_size: None | int = Field(
None,
@@ -1989,8 +1985,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
):
self.logical_axis_rules.append(["aqt_amax_history", ("stage",)])

-self.model_fsdp_ag_once = self.pipeline_fsdp_ag_once # Backward compatibility alias

# H. RUN ALL CROSS-FIELD VALIDATIONS
if self.load_parameters_path and self.load_full_state_path:
raise ValueError("At most one of `load_parameters_path` or `load_full_state_path` should be set.")
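For reference, the two deletions above implemented a simple backward-compatibility alias: a derived field mirroring the canonical flag, overwritten during post-init. A minimal Pydantic v2 sketch of that pattern (toy class name and simplified fields, not the actual MaxText schema):

from pydantic import BaseModel, Field

class PipelineConfig(BaseModel):
  pipeline_fsdp_ag_once: bool = Field(False)
  model_fsdp_ag_once: bool = Field(False)  # legacy alias, like the field removed above

  def model_post_init(self, context) -> None:
    # Mirror the canonical flag so code reading the old name still works.
    self.model_fsdp_ag_once = self.pipeline_fsdp_ag_once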
8 changes: 4 additions & 4 deletions src/MaxText/layers/decoders.py
@@ -717,11 +717,11 @@ def __call__(
)
if cfg.using_pipeline_parallelism:
if cfg.pipeline_fsdp_ag_once:
-partition_spec = self.pipeline_module.get_weight_sharding(
+logical_partition_spec = self.pipeline_module.get_weight_sharding(
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
)
else:
-partition_spec = None # This partition spec is only used for the fsdp_ag_once feature.
+logical_partition_spec = None # This partition spec is only used for the fsdp_ag_once feature.
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
dense_layer = RemattedBlockLayers[0]
@@ -750,9 +750,9 @@ def __call__(
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
)(y, *broadcast_args)
-y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec)
+y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
else: # Not DeepSeek
-y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec)
+y = self.pipeline_module(y, *broadcast_args, logical_partition_spec=logical_partition_spec)
remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers
if remaining_layers > 0:
logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules)
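The rename from `partition_spec` to `logical_partition_spec` presumably signals that the spec returned by `get_weight_sharding` is expressed in logical axis names, which still need to be translated through `logical_axis_rules` before they can shard arrays on the device mesh. A minimal Flax sketch of that translation (the axis names here are illustrative, though "embed"/"fsdp"-style names are common in MaxText):

import flax.linen as nn
from jax.sharding import PartitionSpec

# Rules mapping logical axis names to physical mesh axes (illustrative).
logical_axis_rules = (("embed", "fsdp"), ("mlp", "tensor"))

# A logical spec of the kind get_weight_sharding would produce.
logical_spec = ("embed", "mlp")

# Translate logical names to mesh axes; yields PartitionSpec('fsdp', 'tensor'),
# which is usable with jax.sharding.NamedSharding.
mesh_spec = nn.logical_to_mesh_axes(logical_spec, rules=logical_axis_rules)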
118 changes: 0 additions & 118 deletions src/MaxText/layers/models.py
@@ -35,7 +35,6 @@
from MaxText.layers.encoders import VisionEncoder, vision_encoder_as_linen
from MaxText.layers.quantizations import AqtQuantization as Quant
from MaxText.layers.multi_token_prediction import multi_token_prediction_block_as_linen
-from MaxText.maxtext_utils import all_gather_over_fsdp

# ------------------------------------------------------------------------------
# The network: Transformer Definitions
@@ -517,120 +516,3 @@ def __call__(
return hidden_state, kv_caches

return logits


-class ZeroOneTransformer(nn.Module):
-"""
-A wrapper for the base Transformer model designed to implement the Zero-1
-FSDP optimization.
-
-The goal of this optimization is to reduce communication overhead. In the standard
-FSDP implementation, an all-gather operation on the model weights is performed twice
-for each gradient accumulation microbatch (once for the forward pass, once for the backward pass).
-This class changes that behavior. When enabled, it performs the all-gather operation
-only *once* per full gradient accumulation step. It gathers the full weights into
-memory, runs all the microbatch forward and backward passes, and then releases the
-full weights. This trades higher peak memory usage for significantly reduced
-network communication, which can improve training speed if sufficient memory is
-available.
-"""
-
-config: Config
-mesh: Mesh
-quant: Quant
-# Possible model_mode values can be found in MaxText.common_types.
-# We generally use MaxText.common_types.MODEL_MODE_TRAIN or
-# MaxText.common_types.MODEL_MODE_PREFILL for initializations here.
-# TODO: Make model_mode required after confirming no users are affected.
-model_mode: str = MODEL_MODE_TRAIN # May be different than the model_mode passed to __call__
-
-def setup(self):
-"""Sets up the underlying Transformer model.
-
-This method initializes the `self.model` attribute by calling the
-`transformer_as_linen` factory function.
-"""
-self.model = transformer_as_linen(self.config, self.mesh, self.quant, self.model_mode)
-
-def __call__(
-self,
-decoder_input_tokens: jnp.ndarray,
-decoder_positions: jnp.ndarray,
-decoder_segment_ids=None,
-encoder_images: None | jnp.ndarray = None,
-encoder_image_masks: None | jnp.ndarray = None,
-enable_dropout=True,
-model_mode=MODEL_MODE_TRAIN,
-previous_chunk=None,
-true_length: None | int = None,
-slot: None | int = None,
-page_state: None | page_manager.PageState = None,
-partition_spec=None,
-decoder_target_tokens: None | jnp.ndarray = None,
-decoder_target_mask: None | jnp.ndarray = None,
-nnx_method: str | None = None,
-):
-"""Applies the Zero-1 FSDP wrapped Transformer model.
-
-This method handles the all-gather operation for model weights before
-applying the underlying Transformer model, and then releases them.
-
-Args:
-decoder_input_tokens: Input tokens for the decoder.
-decoder_positions: Positional encodings for the decoder inputs.
-decoder_segment_ids: Segment IDs for the decoder inputs (optional).
-encoder_images: Encoder images for multimodal models (optional).
-enable_dropout: Whether to enable dropout. Defaults to True.
-previous_chunk: Previous chunk for incremental decoding (optional).
-true_length: True length of the prompt before padding (optional).
-slot: An integer representing the decode batch index selected for this
-request (optional).
-page_state: Page state for paged attention (optional).
-partition_spec: Partition specification for FSDP all-gather.
-decoder_target_tokens: Target tokens for the decoder (optional, used in
-MTP).
-decoder_target_mask: Target mask for the decoder (optional, used in MTP).
-nnx_method: Method to call on the NNX module (optional).
-
-Returns:
-Logits from the Transformer model.
-"""
-if self.is_initializing():
-return self.model(
-decoder_input_tokens=decoder_input_tokens,
-decoder_positions=decoder_positions,
-decoder_segment_ids=decoder_segment_ids,
-encoder_images=encoder_images,
-encoder_image_masks=encoder_image_masks,
-enable_dropout=enable_dropout,
-model_mode=model_mode,
-previous_chunk=previous_chunk,
-true_length=true_length,
-slot=slot,
-page_state=page_state,
-)
-all_model_weights = all_gather_over_fsdp(
-self.model.variables,
-partition_spec,
-mesh=self.mesh,
-logical_axis_rules=self.config.logical_axis_rules,
-)
-
-return self.model.apply(
-all_model_weights,
-decoder_input_tokens=decoder_input_tokens,
-decoder_positions=decoder_positions,
-decoder_segment_ids=decoder_segment_ids,
-encoder_images=encoder_images,
-encoder_image_masks=encoder_image_masks,
-enable_dropout=enable_dropout,
-model_mode=model_mode,
-previous_chunk=previous_chunk,
-true_length=true_length,
-slot=slot,
-page_state=page_state,
-mutable=False,
-decoder_target_tokens=decoder_target_tokens,
-decoder_target_mask=decoder_target_mask,
-nnx_method=nnx_method,
-)
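The deleted wrapper relied on `all_gather_over_fsdp` to materialize the full weights once per step. A hedged sketch of how such a helper can be built (hypothetical implementation; the real MaxText helper may differ): constrain each weight to a sharding with the "fsdp" mesh axis replaced by None (replicated), so that XLA inserts a single all-gather whose result is reused by every microbatch inside the same jitted step.

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def gather_over_fsdp(variables, specs, mesh: Mesh):
  """Force an all-gather by replicating the 'fsdp' dimension of each spec.

  Assumes each PartitionSpec entry is a single mesh-axis name (no nested
  tuples), that `specs` mirrors the leaf structure of `variables`, and that
  this runs inside a jitted computation over `mesh`.
  """
  def drop_fsdp(spec: PartitionSpec) -> PartitionSpec:
    # None means "replicated along this dimension", which triggers the gather.
    return PartitionSpec(*(None if axis == "fsdp" else axis for axis in spec))

  flat_vars, treedef = jax.tree_util.tree_flatten(variables)
  flat_specs = jax.tree_util.tree_leaves(
      specs, is_leaf=lambda s: isinstance(s, PartitionSpec))
  gathered = [
      jax.lax.with_sharding_constraint(v, NamedSharding(mesh, drop_fsdp(s)))
      for v, s in zip(flat_vars, flat_specs)
  ]
  return jax.tree_util.tree_unflatten(treedef, gathered)

Because the constraint sits outside the microbatch loop, the gathered copy stays live for the whole gradient step, which is exactly the memory/time tradeoff the class docstring describes.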