Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,26 +1296,26 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
expert_assignments=selected_experts,
)
wi_tile_size = (
self.config.wi_tile_fwd_batch_seq,
self.config.wi_tile_fwd_embed_dim,
self.config.wi_tile_fwd_mlp_dim,
self.config.wi_tile_dlhs_batch_seq,
self.config.wi_tile_dlhs_embed_dim,
self.config.wi_tile_dlhs_mlp_dim,
self.config.wi_tile_drhs_batch_seq,
self.config.wi_tile_drhs_embed_dim,
self.config.wi_tile_drhs_mlp_dim,
self.config.wi_tile_fwd_batch_seq, # m (LHS batch)
self.config.wi_tile_fwd_embed_dim, # k (contracting)
self.config.wi_tile_fwd_mlp_dim, # n (RHS batch)
self.config.wi_tile_dlhs_batch_seq, # m (LHS batch)
self.config.wi_tile_dlhs_mlp_dim, # k (contracting)
self.config.wi_tile_dlhs_embed_dim, # n (RHS batch)
self.config.wi_tile_drhs_batch_seq, # Called m in megablox, but this is contracting
self.config.wi_tile_drhs_embed_dim, # Called k in megablox, but this is LHS batch dim
self.config.wi_tile_drhs_mlp_dim, # Called n in megablox, and indeed is the RHS batch dim
)
wo_tile_size = (
self.config.wo_tile_fwd_batch_seq,
self.config.wo_tile_fwd_embed_dim,
self.config.wo_tile_fwd_mlp_dim,
self.config.wo_tile_dlhs_batch_seq,
self.config.wo_tile_dlhs_embed_dim,
self.config.wo_tile_dlhs_mlp_dim,
self.config.wo_tile_drhs_batch_seq,
self.config.wo_tile_drhs_embed_dim,
self.config.wo_tile_drhs_mlp_dim,
self.config.wo_tile_fwd_batch_seq, # m (LHS batch)
self.config.wo_tile_fwd_mlp_dim, # k (contracting)
self.config.wo_tile_fwd_embed_dim, # n (RHS batch)
self.config.wo_tile_dlhs_batch_seq, # m (LHS batch)
self.config.wo_tile_dlhs_embed_dim, # k (contracting)
self.config.wo_tile_dlhs_mlp_dim, # n (RHS batch)
self.config.wo_tile_drhs_batch_seq, # Called m in megablox, but this is contracting
self.config.wo_tile_drhs_mlp_dim, # Called k in megablox, but this is LHS batch dim
self.config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim
)

layer_w0 = gmm_fn(
Expand Down
37 changes: 19 additions & 18 deletions src/maxtext/models/deepseek_batchsplit_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,26 +977,27 @@ def gmm(
wo_gather_axes = []

wi_tile_size = (
config.wi_tile_fwd_batch_seq,
config.wi_tile_fwd_embed_dim,
config.wi_tile_fwd_mlp_dim,
config.wi_tile_dlhs_batch_seq,
config.wi_tile_dlhs_embed_dim,
config.wi_tile_dlhs_mlp_dim,
config.wi_tile_drhs_batch_seq,
config.wi_tile_drhs_embed_dim,
config.wi_tile_drhs_mlp_dim,
config.wi_tile_fwd_batch_seq, # m (LHS batch)
config.wi_tile_fwd_embed_dim, # k (contracting)
config.wi_tile_fwd_mlp_dim, # n (RHS batch)
config.wi_tile_dlhs_batch_seq, # m (LHS batch)
config.wi_tile_dlhs_mlp_dim, # k (contracting)
config.wi_tile_dlhs_embed_dim, # n (RHS batch)
config.wi_tile_drhs_batch_seq, # Called m in megablox, but this is contracting
config.wi_tile_drhs_embed_dim, # Called k in megablox, but this is LHS batch dim
config.wi_tile_drhs_mlp_dim, # Called n in megablox, and indeed is the RHS batch dim
)

wo_tile_size = (
config.wo_tile_fwd_batch_seq,
config.wo_tile_fwd_embed_dim,
config.wo_tile_fwd_mlp_dim,
config.wo_tile_dlhs_batch_seq,
config.wo_tile_dlhs_embed_dim,
config.wo_tile_dlhs_mlp_dim,
config.wo_tile_drhs_batch_seq,
config.wo_tile_drhs_embed_dim,
config.wo_tile_drhs_mlp_dim,
config.wo_tile_fwd_batch_seq, # m (LHS batch)
config.wo_tile_fwd_mlp_dim, # k (contracting)
config.wo_tile_fwd_embed_dim, # n (RHS batch)
config.wo_tile_dlhs_batch_seq, # m (LHS batch)
config.wo_tile_dlhs_embed_dim, # k (contracting)
config.wo_tile_dlhs_mlp_dim, # n (RHS batch)
config.wo_tile_drhs_batch_seq, # Called m in megablox, but this is contracting
config.wo_tile_drhs_mlp_dim, # Called k in megablox, but this is LHS batch dim
config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim
)

if config.use_qwix_quantization:
Expand Down
Loading