From 359ee9bed7684eb1b7f37397a6eaa9cf2ccef134 Mon Sep 17 00:00:00 2001 From: Tom Gunter Date: Fri, 8 Dec 2023 12:51:30 -0800 Subject: [PATCH] set_double_shard_weights_config(...) now supports a seq_axis_names arg. --- axlearn/common/attention.py | 6 +- axlearn/common/attention_test.py | 18 +++-- .../fuji-7B-single.txt | 66 +++++++++++-------- .../fuji-7B.txt | 66 +++++++++++-------- .../fuji-test.txt | 52 +++++++++------ axlearn/experiments/text/gpt/common.py | 28 ++++---- 6 files changed, 145 insertions(+), 91 deletions(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index a1688ccbe..c02b69265 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -2756,6 +2756,7 @@ def set_double_shard_weights_config( batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"), fsdp_axis_names: Union[str, Sequence[str]] = "fsdp", tp_axis_names: Union[str, Sequence[str]] = "model", + seq_axis_names: Union[str, Sequence[str]] = "seq", ): """Sets `cfg` to shard FFN and attention weights over both fsdp and tp axes. @@ -2764,6 +2765,7 @@ def set_double_shard_weights_config( batch_axis_names: Axis name(s) over which we shard the batch dimension of output tensors. fsdp_axis_names: Axis name(s) over which we shard fully-sharded-data-parallel tensors. tp_axis_names: Axis name(s) over which we shard tensor-parallel tensors. + seq_axis_names: Axis name(s) over which we shard sequence-parallel tensors. """ # pytype: disable=attribute-error @@ -2780,8 +2782,8 @@ def set_ffn_partition_specs(ff_layer: TransformerFeedForwardLayer.Config): ff_layer.linear1.param_partition_spec = (fsdp_axis_names, tp_axis_names) ff_layer.linear2.param_partition_spec = (tp_axis_names, fsdp_axis_names) # Encourage the right activation sharding. - ff_layer.linear1.output_partition_spec = (batch_axis_names, None, tp_axis_names) - ff_layer.linear2.output_partition_spec = (batch_axis_names, None, tp_axis_names) + ff_layer.linear1.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) + ff_layer.linear2.output_partition_spec = (batch_axis_names, seq_axis_names, tp_axis_names) if not isinstance(cfg, Sequence): cfg = [cfg] diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py index 9f8fed410..ead34b37a 100644 --- a/axlearn/common/attention_test.py +++ b/axlearn/common/attention_test.py @@ -3300,6 +3300,7 @@ class ConfigHelperTest(TestCase): batch_axis_names=("data", ("replica", "data", "fsdp")), fsdp_axis_names=("fsdp",), tp_axis_names=("model",), + seq_axis_names=("seq",), ) def test_set_double_shard_weights_config( self, @@ -3308,6 +3309,7 @@ def test_set_double_shard_weights_config( batch_axis_names, fsdp_axis_names, tp_axis_names, + seq_axis_names, ): cfg: TransformerLayer.Config = TransformerLayer.default_config().set( cross_attention=cross_attention_cfg @@ -3318,6 +3320,7 @@ def test_set_double_shard_weights_config( batch_axis_names=batch_axis_names, fsdp_axis_names=fsdp_axis_names, tp_axis_names=tp_axis_names, + seq_axis_names=seq_axis_names, ) ff_layer = cfg.feed_forward @@ -3328,10 +3331,12 @@ def test_set_double_shard_weights_config( ff_layer.linear2.param_partition_spec, (tp_axis_names, fsdp_axis_names) ) self.assertSequenceEqual( - ff_layer.linear1.output_partition_spec, (batch_axis_names, None, tp_axis_names) + ff_layer.linear1.output_partition_spec, + (batch_axis_names, seq_axis_names, tp_axis_names), ) self.assertSequenceEqual( - ff_layer.linear2.output_partition_spec, (batch_axis_names, None, tp_axis_names) + 
ff_layer.linear2.output_partition_spec, + (batch_axis_names, seq_axis_names, tp_axis_names), ) self_atten = cfg.self_attention.attention @@ -3370,6 +3375,7 @@ def test_set_double_shard_weights_config( batch_axis_names=("data", ("replica", "data", "fsdp")), fsdp_axis_names=("fsdp",), tp_axis_names=("model",), + seq_axis_names=("seq",), ) def test_set_double_shard_weights_config_for_list_of_configs( self, @@ -3378,6 +3384,7 @@ def test_set_double_shard_weights_config_for_list_of_configs( batch_axis_names, fsdp_axis_names, tp_axis_names, + seq_axis_names, ): cfg_layer: TransformerLayer.Config = TransformerLayer.default_config().set( cross_attention=cross_attention_cfg @@ -3389,6 +3396,7 @@ def test_set_double_shard_weights_config_for_list_of_configs( batch_axis_names=batch_axis_names, fsdp_axis_names=fsdp_axis_names, tp_axis_names=tp_axis_names, + seq_axis_names=seq_axis_names, ) for cfg in cfg_layers: @@ -3400,10 +3408,12 @@ def test_set_double_shard_weights_config_for_list_of_configs( ff_layer.linear2.param_partition_spec, (tp_axis_names, fsdp_axis_names) ) self.assertSequenceEqual( - ff_layer.linear1.output_partition_spec, (batch_axis_names, None, tp_axis_names) + ff_layer.linear1.output_partition_spec, + (batch_axis_names, seq_axis_names, tp_axis_names), ) self.assertSequenceEqual( - ff_layer.linear2.output_partition_spec, (batch_axis_names, None, tp_axis_names) + ff_layer.linear2.output_partition_spec, + (batch_axis_names, seq_axis_names, tp_axis_names), ) self_atten = cfg.self_attention.attention diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-single.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-single.txt index 63c72805d..56e31f468 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-single.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-single.txt @@ -1,6 +1,7 @@ batch_axis_names[0]: 'data' -batch_axis_names[1]: 'expert' -batch_axis_names[2]: 'fsdp' +batch_axis_names[1]: 'seq' +batch_axis_names[2]: 'expert' +batch_axis_names[3]: 'fsdp' checkpointer.gc_loop_interval_seconds: 60 checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 @@ -106,25 +107,30 @@ learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' max_step: 500000 mesh_axis_names[0]: 'data' -mesh_axis_names[1]: 'expert' -mesh_axis_names[2]: 'fsdp' -mesh_axis_names[3]: 'model' +mesh_axis_names[1]: 'seq' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'model' mesh_rules[0][0]: 'tpu-v4-(1024|2048)' mesh_rules[0][1][0]: -1 mesh_rules[0][1][1]: 1 -mesh_rules[0][1][2]: 16 -mesh_rules[0][1][3]: 1 +mesh_rules[0][1][2]: 1 +mesh_rules[0][1][3]: 16 +mesh_rules[0][1][4]: 1 mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(256|512|1024)' mesh_rules[1][1][0]: -1 mesh_rules[1][1][1]: 1 -mesh_rules[1][1][2]: 8 -mesh_rules[1][1][3]: 1 +mesh_rules[1][1][2]: 1 +mesh_rules[1][1][3]: 8 +mesh_rules[1][1][4]: 1 mesh_shape[0]: 1 mesh_shape[1]: 1 -mesh_shape[2]: -1 -mesh_shape[3]: 1 +mesh_shape[2]: 1 +mesh_shape[3]: -1 +mesh_shape[4]: 1 model.batch_axis_names[0]: 'data' -model.batch_axis_names[1]: 'fsdp' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' model.decoder.attention_mask.klass: 'axlearn.common.attention.CausalAttentionLogitBiasLayer' model.decoder.dim: 4096 model.decoder.dropout_rate: 0.0 @@ -141,8 +147,9 @@ model.decoder.emb.token_emb.param_partition_spec[1]: 'model' 
model.decoder.eos_token_id: 1 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' -model.decoder.logits_partition_spec[0][1]: 'fsdp' -model.decoder.logits_partition_spec[1]: None +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' model.decoder.logits_partition_spec[2]: 'model' model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 @@ -160,21 +167,25 @@ model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.Tr model.decoder.transformer.layer.feed_forward.linear1.bias: False model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: None +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1][0]: 'seq' model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'fsdp' model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' model.decoder.transformer.layer.feed_forward.linear2.bias: False model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: None +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1][0]: 'seq' model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'fsdp' model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' @@ -195,8 +206,9 @@ 
model.decoder.transformer.layer.self_attention.attention.input_linear.input_line model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedQKVLinear' model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'fsdp' model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' @@ -209,8 +221,9 @@ model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common. model.decoder.transformer.layer.self_attention.attention.num_heads: 32 model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'fsdp' model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' @@ -234,6 +247,7 @@ model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' model.param_init.init_by_param_name['.*weight$'].scale: 1.0 model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names: 'seq' model.z_loss_scale: 0.0 name: 'gpt_trainer' prune_empty_state_updates: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B.txt index 4e74f1735..fc517aa65 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B.txt @@ -1,6 +1,7 @@ batch_axis_names[0]: 'data' 
-batch_axis_names[1]: 'expert' -batch_axis_names[2]: 'fsdp' +batch_axis_names[1]: 'seq' +batch_axis_names[2]: 'expert' +batch_axis_names[3]: 'fsdp' checkpointer.gc_loop_interval_seconds: 60 checkpointer.keep_every_n_steps: 50000 checkpointer.keep_last_n: 3 @@ -106,25 +107,30 @@ learner.optimizer.args[1].weight_decay: 0.1 learner.optimizer.fn: 'axlearn.common.optimizers.chain' max_step: 500000 mesh_axis_names[0]: 'data' -mesh_axis_names[1]: 'expert' -mesh_axis_names[2]: 'fsdp' -mesh_axis_names[3]: 'model' +mesh_axis_names[1]: 'seq' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'model' mesh_rules[0][0]: 'tpu-v4-(1024|2048)' mesh_rules[0][1][0]: -1 mesh_rules[0][1][1]: 1 -mesh_rules[0][1][2]: 16 -mesh_rules[0][1][3]: 1 +mesh_rules[0][1][2]: 1 +mesh_rules[0][1][3]: 16 +mesh_rules[0][1][4]: 1 mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(256|512|1024)' mesh_rules[1][1][0]: -1 mesh_rules[1][1][1]: 1 -mesh_rules[1][1][2]: 8 -mesh_rules[1][1][3]: 1 +mesh_rules[1][1][2]: 1 +mesh_rules[1][1][3]: 8 +mesh_rules[1][1][4]: 1 mesh_shape[0]: 1 mesh_shape[1]: 1 -mesh_shape[2]: -1 -mesh_shape[3]: 1 +mesh_shape[2]: 1 +mesh_shape[3]: -1 +mesh_shape[4]: 1 model.batch_axis_names[0]: 'data' -model.batch_axis_names[1]: 'fsdp' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' model.decoder.attention_mask.klass: 'axlearn.common.attention.CausalAttentionLogitBiasLayer' model.decoder.dim: 4096 model.decoder.dropout_rate: 0.0 @@ -141,8 +147,9 @@ model.decoder.emb.token_emb.param_partition_spec[1]: 'model' model.decoder.eos_token_id: 1 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' -model.decoder.logits_partition_spec[0][1]: 'fsdp' -model.decoder.logits_partition_spec[1]: None +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' model.decoder.logits_partition_spec[2]: 'model' model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 @@ -160,21 +167,25 @@ model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.Tr model.decoder.transformer.layer.feed_forward.linear1.bias: False model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: None +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1][0]: 'seq' model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'fsdp' model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' 
model.decoder.transformer.layer.feed_forward.linear2.bias: False model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: None +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1][0]: 'seq' model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'fsdp' model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' @@ -195,8 +206,9 @@ model.decoder.transformer.layer.self_attention.attention.input_linear.input_line model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedQKVLinear' model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'fsdp' model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' @@ -209,8 +221,9 @@ model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common. 
model.decoder.transformer.layer.self_attention.attention.num_heads: 32 model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'seq' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'fsdp' model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' @@ -234,6 +247,7 @@ model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' model.param_init.init_by_param_name['.*weight$'].scale: 1.0 model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' +model.seq_axis_names: 'seq' model.z_loss_scale: 0.0 name: 'gpt_trainer' prune_empty_state_updates: True diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test.txt index f83ca1068..3539e9931 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test.txt @@ -1,6 +1,7 @@ batch_axis_names[0]: 'data' -batch_axis_names[1]: 'expert' -batch_axis_names[2]: 'fsdp' +batch_axis_names[1]: 'seq' +batch_axis_names[2]: 'expert' +batch_axis_names[3]: 'fsdp' checkpointer.gc_loop_interval_seconds: 60 checkpointer.keep_every_n_steps: 3000 checkpointer.keep_last_n: 3 @@ -106,15 +107,18 @@ learner.optimizer.args[1].weight_decay: 0.01 learner.optimizer.fn: 'axlearn.common.optimizers.chain' max_step: 3000 mesh_axis_names[0]: 'data' -mesh_axis_names[1]: 'expert' -mesh_axis_names[2]: 'fsdp' -mesh_axis_names[3]: 'model' +mesh_axis_names[1]: 'seq' +mesh_axis_names[2]: 'expert' +mesh_axis_names[3]: 'fsdp' +mesh_axis_names[4]: 'model' mesh_shape[0]: 1 mesh_shape[1]: 1 mesh_shape[2]: 1 mesh_shape[3]: 1 +mesh_shape[4]: 1 model.batch_axis_names[0]: 'data' -model.batch_axis_names[1]: 'fsdp' +model.batch_axis_names[1]: 'expert' +model.batch_axis_names[2]: 'fsdp' model.decoder.attention_mask.klass: 'axlearn.common.attention.CausalAttentionLogitBiasLayer' model.decoder.dim: 8 model.decoder.dropout_rate: 0.0 @@ -131,8 +135,9 @@ model.decoder.emb.token_emb.param_partition_spec[1]: 'model' model.decoder.eos_token_id: 1 model.decoder.klass: 'axlearn.common.decoder.Decoder' model.decoder.logits_partition_spec[0][0]: 'data' -model.decoder.logits_partition_spec[0][1]: 'fsdp' -model.decoder.logits_partition_spec[1]: None +model.decoder.logits_partition_spec[0][1]: 'expert' +model.decoder.logits_partition_spec[0][2]: 'fsdp' +model.decoder.logits_partition_spec[1]: 'seq' model.decoder.logits_partition_spec[2]: 'model' model.decoder.output_dropout.klass: 
'axlearn.common.layers.Dropout' model.decoder.output_norm.eps: 1e-05 @@ -150,21 +155,25 @@ model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.Tr model.decoder.transformer.layer.feed_forward.linear1.bias: False model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: None +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1][0]: 'seq' model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'seq' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'fsdp' model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' model.decoder.transformer.layer.feed_forward.linear2.bias: False model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: None +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1][0]: 'seq' model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'seq' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'expert' +model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'fsdp' model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' @@ -185,8 +194,9 @@ model.decoder.transformer.layer.self_attention.attention.input_linear.input_line model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedQKVLinear' model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' 
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'seq'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'expert'
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'fsdp'
 model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model'
 model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None
 model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear'
@@ -199,8 +209,9 @@ model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.
 model.decoder.transformer.layer.self_attention.attention.num_heads: 4
 model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False
 model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp'
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'seq'
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'expert'
+model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'fsdp'
 model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
 model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
 model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
@@ -224,6 +235,7 @@ model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in'
 model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
 model.param_init.init_by_param_name['.*weight$'].scale: 1.0
 model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
+model.seq_axis_names: 'seq'
 model.z_loss_scale: 0.0
 name: 'gpt_trainer'
 prune_empty_state_updates: True
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index cfee48fc4..78e6409f6 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -68,7 +68,7 @@
 
 # The default mesh-axis names for LM training, from least to most communication intensive.
 # See mesh_shape_from_axes() docstring for more details.
-MESH_AXIS_NAMES = ("data", "expert", "fsdp", "model")
+MESH_AXIS_NAMES = ("data", "seq", "expert", "fsdp", "model")
 
 
 def scaled_hidden_dim(scale: float, *, round_up_to_multiples_of: int = 256) -> FunctionConfigBase:
@@ -131,13 +131,15 @@ def tfds_input(
 
 
 def mesh_shape_from_axes(
-    *, data: int = 1, expert: int = 1, fsdp: int = 1, model: int = 1
+    *, data: int = 1, seq: int = 1, expert: int = 1, fsdp: int = 1, model: int = 1
-) -> Tuple[int, int, int, int]:
+) -> Tuple[int, int, int, int, int]:
-    """Builds a 4D logical mesh from the provided spec.
+ """Builds a 5D logical mesh from the provided spec. Args: data: For data-paralellism. Expect model state to be fully replicated over this axis. Useful for e.g. multi-slice/granule partitioning with slow networking between granules. + seq: Used for sequence-parallelism. Typically this means sharding the activation sequence + dimension, and possibly a subset of the weights. expert: Designed to be used for partitioning "experts" in mixture-of-expert models. E.g. . fsdp: Fully-sharded-data-parallelism a.k.a. async-with-compute model-parallelism. @@ -148,9 +150,9 @@ def mesh_shape_from_axes( Returns: A tuple describing the logical mesh shape (from least to most communication intensive). """ - assert MESH_AXIS_NAMES == ("data", "expert", "fsdp", "model") + assert MESH_AXIS_NAMES == ("data", "seq", "expert", "fsdp", "model") # We set the minimum size for a mesh axis to 1 as anything lower is degenerate, except -1. - return tuple((max(x, 1) if x != -1 else -1 for x in [data, expert, fsdp, model])) + return tuple((max(x, 1) if x != -1 else -1 for x in [data, seq, expert, fsdp, model])) def model_config( @@ -233,22 +235,23 @@ def model_config( ) } ) - batch_axis_names = ("data", "fsdp") + batch_axis_names = ("data", "expert", "fsdp") cfg = causal_lm.Model.default_config().set( decoder=decoder_cfg, param_init=model_param_init, batch_axis_names=batch_axis_names, - seq_axis_names=None, + seq_axis_names="seq", ) cfg.dtype = jnp.float32 - # Shard some FFN and attention weights over both FSDP and model axes. + # Shard some FFN and attention weights over multiple axes. set_double_shard_weights_config( cfg.decoder.transformer.layer, batch_axis_names=batch_axis_names, - fsdp_axis_names=("expert", "fsdp"), + fsdp_axis_names=("seq", "expert", "fsdp"), tp_axis_names="model", + seq_axis_names=("seq",), ) - cfg.decoder.logits_partition_spec = (batch_axis_names, None, "model") + cfg.decoder.logits_partition_spec = (batch_axis_names, "seq", "model") set_bias_recursively(cfg, False) set_norm_recursively(cfg, normalization) cfg.z_loss_scale = z_loss_scale @@ -536,9 +539,8 @@ def config_fn() -> InstantiableConfig: ), f"Len mismatch: {mesh_axis_names} vs. {mesh_shape}" cfg.mesh_axis_names = mesh_axis_names cfg.mesh_shape = mesh_shape - # Set batch sharding spec to be all but the last axis (assumed for tensor-parallelism). - assert mesh_axis_names[-1] == "model" - cfg.batch_axis_names = mesh_axis_names[:-1] + # Set batch sharding spec to exclude the "model" axis (assumed for tensor-parallelism). + cfg.batch_axis_names = tuple(el for el in mesh_axis_names if el != "model") cfg.mesh_rules = mesh_rules # Maybe load state. if init_state_builder: