6 changes: 4 additions & 2 deletions axlearn/common/attention.py
@@ -2756,6 +2756,7 @@ def set_double_shard_weights_config(
batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"),
fsdp_axis_names: Union[str, Sequence[str]] = "fsdp",
tp_axis_names: Union[str, Sequence[str]] = "model",
+ seq_axis_names: Union[str, Sequence[str]] = "seq",
):
"""Sets `cfg` to shard FFN and attention weights over both fsdp and tp axes.

@@ -2764,6 +2765,7 @@ def set_double_shard_weights_config(
batch_axis_names: Axis name(s) over which we shard the batch dimension of output tensors.
fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors.
tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors.
+ seq_axis_names: Axis name(s) over which we shard sequence-parallel tensors.
"""

# pytype: disable=attribute-error
@@ -2780,8 +2782,8 @@ def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config):
ff_layer.linear1.param_partition_spec = (fsdp_axis_names, tp_axis_names)
ff_layer.linear2.param_partition_spec = (tp_axis_names, fsdp_axis_names)
# Encourage the right activation sharding.
- ff_layer.linear1.output_partition_spec = (batch_axis_names, None, tp_axis_names)
- ff_layer.linear2.output_partition_spec = (batch_axis_names, None, tp_axis_names)
+ ff_layer.linear1.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)
+ ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names)

if not isinstance(cfg, Sequence):
cfg = [cfg]
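Taken together, the change lets callers name a sequence axis when sharding FFN activations. A minimal sketch of the updated call, mirroring the tests below; the config object and axis names here are illustrative, not from this PR:

```python
from axlearn.common.attention import TransformerLayer, set_double_shard_weights_config

# Illustrative layer config, as constructed in the tests below.
layer_cfg: TransformerLayer.Config = TransformerLayer.default_config()
set_double_shard_weights_config(
    layer_cfg,
    batch_axis_names=("data", "fsdp"),
    fsdp_axis_names="fsdp",
    tp_axis_names="model",
    seq_axis_names="seq",  # new in this PR: shards the seq dim of FFN activations
)
```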
18 changes: 14 additions & 4 deletions axlearn/common/attention_test.py
@@ -3300,6 +3300,7 @@ class ConfigHelperTest(TestCase):
batch_axis_names=("data", ("replica", "data", "fsdp")),
fsdp_axis_names=("fsdp",),
tp_axis_names=("model",),
+ seq_axis_names=("seq",),
)
def test_set_double_shard_weights_config(
self,
@@ -3308,6 +3309,7 @@ def test_set_double_shard_weights_config(
batch_axis_names,
fsdp_axis_names,
tp_axis_names,
+ seq_axis_names,
):
cfg: TransformerLayer.Config = TransformerLayer.default_config().set(
cross_attention=cross_attention_cfg
@@ -3318,6 +3320,7 @@
batch_axis_names=batch_axis_names,
fsdp_axis_names=fsdp_axis_names,
tp_axis_names=tp_axis_names,
+ seq_axis_names=seq_axis_names,
)

ff_layer = cfg.feed_forward
@@ -3328,10 +3331,12 @@ def test_set_double_shard_weights_config(
ff_layer.linear2.param_partition_spec, (tp_axis_names, fsdp_axis_names)
)
self.assertSequenceEqual(
- ff_layer.linear1.output_partition_spec, (batch_axis_names, None, tp_axis_names)
+ ff_layer.linear1.output_partition_spec,
+ (batch_axis_names, seq_axis_names, tp_axis_names),
)
self.assertSequenceEqual(
- ff_layer.linear2.output_partition_spec, (batch_axis_names, None, tp_axis_names)
+ ff_layer.linear2.output_partition_spec,
+ (batch_axis_names, seq_axis_names, tp_axis_names),
)

self_atten = cfg.self_attention.attention
@@ -3370,6 +3375,7 @@ def test_set_double_shard_weights_config(
batch_axis_names=("data", ("replica", "data", "fsdp")),
fsdp_axis_names=("fsdp",),
tp_axis_names=("model",),
+ seq_axis_names=("seq",),
)
def test_set_double_shard_weights_config_for_list_of_configs(
self,
@@ -3378,6 +3384,7 @@ def test_set_double_shard_weights_config_for_list_of_configs(
batch_axis_names,
fsdp_axis_names,
tp_axis_names,
+ seq_axis_names,
):
cfg_layer: TransformerLayer.Config = TransformerLayer.default_config().set(
cross_attention=cross_attention_cfg
@@ -3389,6 +3396,7 @@
batch_axis_names=batch_axis_names,
fsdp_axis_names=fsdp_axis_names,
tp_axis_names=tp_axis_names,
+ seq_axis_names=seq_axis_names,
)

for cfg in cfg_layers:
@@ -3400,10 +3408,12 @@
ff_layer.linear2.param_partition_spec, (tp_axis_names, fsdp_axis_names)
)
self.assertSequenceEqual(
- ff_layer.linear1.output_partition_spec, (batch_axis_names, None, tp_axis_names)
+ ff_layer.linear1.output_partition_spec,
+ (batch_axis_names, seq_axis_names, tp_axis_names),
)
self.assertSequenceEqual(
- ff_layer.linear2.output_partition_spec, (batch_axis_names, None, tp_axis_names)
+ ff_layer.linear2.output_partition_spec,
+ (batch_axis_names, seq_axis_names, tp_axis_names),
)

self_atten = cfg.self_attention.attention
@@ -1,6 +1,7 @@
batch_axis_names[0]: 'data'
- batch_axis_names[1]: 'expert'
- batch_axis_names[2]: 'fsdp'
+ batch_axis_names[1]: 'seq'
+ batch_axis_names[2]: 'expert'
+ batch_axis_names[3]: 'fsdp'
checkpointer.gc_loop_interval_seconds: 60
checkpointer.keep_every_n_steps: 50000
checkpointer.keep_last_n: 3
@@ -106,25 +107,30 @@ learner.optimizer.args[1].weight_decay: 0.1
learner.optimizer.fn: 'axlearn.common.optimizers.chain'
max_step: 500000
mesh_axis_names[0]: 'data'
- mesh_axis_names[1]: 'expert'
- mesh_axis_names[2]: 'fsdp'
- mesh_axis_names[3]: 'model'
+ mesh_axis_names[1]: 'seq'
+ mesh_axis_names[2]: 'expert'
+ mesh_axis_names[3]: 'fsdp'
+ mesh_axis_names[4]: 'model'
mesh_rules[0][0]: 'tpu-v4-(1024|2048)'
mesh_rules[0][1][0]: -1
mesh_rules[0][1][1]: 1
- mesh_rules[0][1][2]: 16
- mesh_rules[0][1][3]: 1
+ mesh_rules[0][1][2]: 1
+ mesh_rules[0][1][3]: 16
+ mesh_rules[0][1][4]: 1
mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(256|512|1024)'
mesh_rules[1][1][0]: -1
mesh_rules[1][1][1]: 1
- mesh_rules[1][1][2]: 8
- mesh_rules[1][1][3]: 1
+ mesh_rules[1][1][2]: 1
+ mesh_rules[1][1][3]: 8
+ mesh_rules[1][1][4]: 1
mesh_shape[0]: 1
mesh_shape[1]: 1
- mesh_shape[2]: -1
- mesh_shape[3]: 1
+ mesh_shape[2]: 1
+ mesh_shape[3]: -1
+ mesh_shape[4]: 1
model.batch_axis_names[0]: 'data'
- model.batch_axis_names[1]: 'fsdp'
+ model.batch_axis_names[1]: 'expert'
+ model.batch_axis_names[2]: 'fsdp'
model.decoder.attention_mask.klass: 'axlearn.common.attention.CausalAttentionLogitBiasLayer'
model.decoder.dim: 4096
model.decoder.dropout_rate: 0.0
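The golden config now carries a five-axis mesh, with 'seq' inserted after 'data' and all devices still placed on the 'fsdp' axis (the -1 entry in mesh_shape). A sketch of how such a mesh resolves in JAX; the reshape and device count are illustrative, not code from this PR:

```python
import numpy as np
import jax
from jax.sharding import Mesh

mesh_axis_names = ("data", "seq", "expert", "fsdp", "model")
# mesh_shape (1, 1, 1, -1, 1): the -1 axis absorbs all remaining devices,
# so the 'fsdp' axis size equals the total device count.
devices = np.asarray(jax.devices()).reshape(1, 1, 1, -1, 1)
mesh = Mesh(devices, mesh_axis_names)
```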
@@ -141,8 +147,9 @@ model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
model.decoder.eos_token_id: 1
model.decoder.klass: 'axlearn.common.decoder.Decoder'
model.decoder.logits_partition_spec[0][0]: 'data'
- model.decoder.logits_partition_spec[0][1]: 'fsdp'
- model.decoder.logits_partition_spec[1]: None
+ model.decoder.logits_partition_spec[0][1]: 'expert'
+ model.decoder.logits_partition_spec[0][2]: 'fsdp'
+ model.decoder.logits_partition_spec[1]: 'seq'
model.decoder.logits_partition_spec[2]: 'model'
model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.output_norm.eps: 1e-05
@@ -160,21 +167,25 @@ model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.Tr
model.decoder.transformer.layer.feed_forward.linear1.bias: False
model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data'
- model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'fsdp'
- model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: None
+ model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert'
+ model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp'
+ model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1][0]: 'seq'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model'
- model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert'
- model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp'
+ model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'seq'
+ model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'expert'
+ model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'fsdp'
model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
model.decoder.transformer.layer.feed_forward.linear2.bias: False
model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data'
- model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'fsdp'
- model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: None
+ model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert'
+ model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp'
+ model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1][0]: 'seq'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model'
model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
- model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert'
- model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp'
+ model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'seq'
+ model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'expert'
+ model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'fsdp'
model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05
model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None
model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm'
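Read as a `jax.sharding.PartitionSpec`, the updated FFN output spec shards the batch dimension over ('data', 'expert', 'fsdp'), the sequence dimension over 'seq' (previously unsharded, i.e. None), and the hidden dimension over 'model'. A sketch, with the axis grouping copied from the entries above:

```python
from jax.sharding import PartitionSpec

# (batch, seq, hidden) activation sharding for linear1/linear2 outputs.
ffn_output_spec = PartitionSpec(("data", "expert", "fsdp"), "seq", "model")
```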
@@ -195,8 +206,9 @@ model.decoder.transformer.layer.self_attention.attention.input_linear.input_line
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedQKVLinear'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
- model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert'
- model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp'
+ model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'seq'
+ model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'expert'
+ model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None
model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear'
@@ -209,8 +221,9 @@ model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.
model.decoder.transformer.layer.self_attention.attention.num_heads: 32
model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False
model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
- model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert'
- model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp'
+ model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'seq'
+ model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'expert'
+ model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
@@ -234,6 +247,7 @@ model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in'
model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
model.param_init.init_by_param_name['.*weight$'].scale: 1.0
model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
+ model.seq_axis_names: 'seq'
model.z_loss_scale: 0.0
name: 'gpt_trainer'
prune_empty_state_updates: True