Skip to content

Commit

Permalink
feat: update args, configs, and requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
DayOfThePenguin committed May 6, 2024
1 parent 9c66895 commit 3388c51
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 240 deletions.
58 changes: 20 additions & 38 deletions configs/125M-dmoe.yml
Original file line number Diff line number Diff line change
@@ -1,49 +1,39 @@
# GPT-2 pretraining setup
{
# See README for MoE config docs!
"moe_type": "megablocks",
"moe_token_dropping": false,
# Have 4 experts per layer (every 2 layers by default)
"moe_num_experts": 4,
# parallelism settings
"enable_expert_tensor_parallelism": true,
"pipe_parallel_size": 1, # not yet supported for MoE
"model_parallel_size": 1,
"moe_expert_parallel_size": 1,
# parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
# across the node boundaries )
"pipe_parallel_size": 2, # MoE support PP
"model_parallel_size": 2, # MoE uses model parallel group to split both experts and attention weights

# model settings
"num_layers": 12,
"hidden_size": 768,
"num_attention_heads": 12,
"hidden_size": 1024,
"num_attention_heads": 16,
"seq_length": 2048,
"max_position_embeddings": 2048,
"norm": "layernorm",
"pos_emb": "rotary",
"no_weight_tying": true,
"gpt_j_residual": false,
"output_layer_parallelism": "column",

# moe settings
"moe_num_experts": 8,

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",
"layernorm_fusion": false,


# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0006,
"betas": [0.9, 0.95],
"betas": [0.9, 0.999],
"eps": 1.0e-8,
}
},
"min_lr": 0.00006,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 0,
Expand All @@ -58,6 +48,7 @@
# batch / data settings
"train_micro_batch_size_per_gpu": 4,
"data_impl": "mmap",
"split": "949,50,1",

# activation checkpointing
"checkpoint_activations": true,
Expand All @@ -67,35 +58,26 @@

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"weight_decay": 0.0,
"hidden_dropout": 0.0,
"attention_dropout": 0.0,

# precision settings
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"precision": "bfloat16",

"fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32
# misc. training settings
"train_iters": 320000,
"train_iters": 5,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.01,
"min_lr": 0.0006,
"warmup": 0.0,
"checkpoint_factor": 10000,
"eval_interval": 1000,
"eval_iters": 10,

# logging
"log_interval": 10,
"steps_per_print": 10,
"log_interval": 1,
"steps_per_print": 1,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,

# networking
"hostfile": "/mock_path"
}
101 changes: 0 additions & 101 deletions configs/125M-moe.yml

This file was deleted.

147 changes: 52 additions & 95 deletions configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -1056,14 +1056,6 @@ Parallelism Arguments



- **expert_interval**: int

Default = 2

Have one MoE layer every expert_interval layers



## NeoXArgsTemplate

NeoXArgsTemplate()
Expand Down Expand Up @@ -1185,93 +1177,6 @@ Text Generation arguments



- **moe_top_k**: int

Default = 1

Activate top K experts in MoE



- **use_tutel**: bool

Default = False

Use Tutel optimizations in MoE



- **num_experts**: int

Default = 1

Number of MoE experts



- **moe_loss_coeff**: float

Default = 0.1

Coefficient for MoE loss



- **moe_train_capacity_factor**: float

Default = 1.0

The capacity of the expert at train time



- **moe_eval_capacity_factor**: float

Default = 1.0

The capacity of the expert at eval time



- **moe_min_capacity**: int

Default = 4

The minimum capacity per expert regardless of the capacity_factor



- **moe_token_dropping**: bool

Default = True

Whether to drop tokens when exceeding capacity



- **create_moe_param_group**: bool

Default = True

Whether to create a separate parameter group for MoE parameters



- **moe_use_residual**: bool

Default = True

Whether to use residual in MoE



- **moe_expert_parallel_size**: int

Default = 1

Number of parallel experts in MoE



## NeoXArgsTokenizer

Expand Down Expand Up @@ -2304,3 +2209,55 @@ Args for deepspeed runner (deepspeed.launcher.runner).

Adds a `--account` to the DeepSpeed launch command. In DeeperSpeed this is passed on to the SlurmLauncher as well. Sometimes necessary for cluster rules, or so I've heard.

## NeoXArgsMoE

Args for Mixture of Experts configuration


- **moe_num_experts**: int

Default = 1

The number of experts in MoE layers. MoE
layers not used if set to 1



- **moe_expert_interval**: int

Default = 1

Have one MoE layer every expert_interval layers


- **moe_top_k**: int

Default = 1

The number of experts each token is routed to
in MoE layers.



- **moe_router_type**: typing.Literal['sinkhorn', 'topk']

Default = 'sinkhorn'

What token routing algorithm to use.



- **moe_lbl_in_fp32**: bool

Default = 0.1

Whether to compute the load balancing loss in fp32.



- **moe_jitter_eps**: float

Default = None

Coefficient for MoE routing jitter. Jitter is
not used if set to None
8 changes: 2 additions & 6 deletions megatron/model/moe_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,13 +434,9 @@ def forward(self, x: torch.Tensor, tokens_per_expert: torch.Tensor):
# self.activation_func
# )

llama_x_w1T = gg.ops.gmm(
x, w1, grouped_gemm_batch_sizes, trans_b=True
)
llama_x_w1T = gg.ops.gmm(x, w1, grouped_gemm_batch_sizes, trans_b=True)

llama_x_w3T = gg.ops.gmm(
x, w3, grouped_gemm_batch_sizes, trans_b=True
)
llama_x_w3T = gg.ops.gmm(x, w3, grouped_gemm_batch_sizes, trans_b=True)

llama_act_x_w1T = self.activation_func(llama_x_w1T)

Expand Down
Loading

0 comments on commit 3388c51

Please sign in to comment.