From bf59c126c5e18bb75413ccc2ddea8055abebcade Mon Sep 17 00:00:00 2001
From: Quentin-Anthony
Date: Fri, 10 Mar 2023 02:57:48 +0000
Subject: [PATCH 1/4] Add remaining mup params from paper's list. Add prototype support for deferred init

---
 megatron/model/gpt2_model.py         |  6 +++
 megatron/model/transformer.py        |  5 ++-
 megatron/neox_arguments/neox_args.py |  5 +++
 megatron/training.py                 | 56 ++++++++++++++++++++++------
 4 files changed, 59 insertions(+), 13 deletions(-)

diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py
index 296196f38..df0ead29e 100644
--- a/megatron/model/gpt2_model.py
+++ b/megatron/model/gpt2_model.py
@@ -222,6 +222,9 @@ def init_specs(self):
                 heads=self.neox_args.num_attention_heads,
             )
 
+        if self.neox_args.use_mup and self.neox_args.mup_input_temp is not None:
+            self.specs.append(lambda x: x * self.neox_args.mup_input_temp)
+
         # Transformer layers
         for i in range(self.neox_args.num_layers):
             layer_type = self.neox_args.attention_config[i]
@@ -260,6 +263,9 @@ def init_specs(self):
             LayerSpec(NormPipe, norm, self.neox_args.hidden_size, eps=eps)
         )
 
+        if self.neox_args.use_mup and self.neox_args.mup_output_temp is not None:
+            self.specs.append(lambda x: x * self.neox_args.mup_output_temp / self.neox_args.hidden_size)
+
         # outputs are now a single tensor: hidden_states
 
         def _logits_helper(embedding, lm_output):
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index e72ed91e5..291589e91 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -227,7 +227,10 @@ def __init__(
             )
 
         coeff = None
-        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+        if neox_args.use_mup:
+            self.norm_factor = self.hidden_size_per_attention_head / neox_args.mup_attn_temp
+        else:
+            self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
         if self.apply_query_key_layer_scaling:
             coeff = max(1, self.layer_number)
             self.norm_factor *= coeff
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index 95e6b6b8e..c8492b76d 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -1030,6 +1030,11 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     What to scale width by when creating the delta model for mup
     """
 
+    mup_deferred_init: bool = False
+    """
+    Whether to defer initialization of the base and delta models via torchdistx instead of fully materializing them (set to true for big target models)
+    """
+
 
 @dataclass
 class NeoXArgsTextgen(NeoXArgsTemplate):
diff --git a/megatron/training.py b/megatron/training.py
index e2a6f8243..21bbf80bc 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -76,12 +76,28 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
     # Instantiation of the base model fails in the init function (init_functions.py) because we haven't called set_base_shapes on it at this point, so disable it temporarily here
     neox_args.use_mup = False
 
-    base_model = GPT2ModelPipe(
-        neox_args=neox_args,
-        num_tokentypes=0,
-        parallel_output=True,
-        topology=mpu.get_topology(),
-        use_cache=use_cache)
+    print(f'MEM BEFORE BASE MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}')
+    if neox_args.mup_deferred_init:
+        try:
+            from torchdistx import deferred_init
+        except ModuleNotFoundError:
+            print("Please install torchdistx https://github.com/pytorch/torchdistx")
+            raise Exception
+        base_model = torchdistx.deferred_init.deferred_init(GPT2ModelPipe(
+            neox_args=neox_args,
+            num_tokentypes=0,
+            parallel_output=True,
+            topology=mpu.get_topology(),
+            use_cache=use_cache))
+    else:
+        base_model = GPT2ModelPipe(
+            neox_args=neox_args,
+            num_tokentypes=0,
+            parallel_output=True,
+            topology=mpu.get_topology(),
+            use_cache=use_cache)
+
+    print(f'MEM AFTER BASE MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}')
 
     if not neox_args.is_pipe_parallel:
         base_model = base_model.to_sequential()
@@ -99,12 +115,28 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
     old_hidden_size = neox_args.hidden_size
     neox_args.hidden_size = neox_args.hidden_size * neox_args.mup_width_scale
 
-    delta_model = GPT2ModelPipe(
-        neox_args=neox_args,
-        num_tokentypes=0,
-        parallel_output=True,
-        topology=mpu.get_topology(),
-        use_cache=use_cache)
+    print(f'MEM BEFORE DELTA MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}')
+    if neox_args.mup_deferred_init:
+        try:
+            from torchdistx import deferred_init
+        except ModuleNotFoundError:
+            print("Please install torchdistx https://github.com/pytorch/torchdistx")
+            raise Exception
+        delta_model = torchdistx.deferred_init.deferred_init(GPT2ModelPipe(
+            neox_args=neox_args,
+            num_tokentypes=0,
+            parallel_output=True,
+            topology=mpu.get_topology(),
+            use_cache=use_cache))
+    else:
+        delta_model = GPT2ModelPipe(
+            neox_args=neox_args,
+            num_tokentypes=0,
+            parallel_output=True,
+            topology=mpu.get_topology(),
+            use_cache=use_cache)
+
+    print(f'MEM AFTER DELTA MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}')
 
     if not neox_args.is_pipe_parallel:
         delta_model = delta_model.to_sequential()

From 585efeab4ed06bf6408957c22fa03b1892905e72 Mon Sep 17 00:00:00 2001
From: github-actions
Date: Wed, 15 Mar 2023 18:22:59 +0000
Subject: [PATCH 2/4] Update NeoXArgs docs automatically

---
 configs/neox_arguments.md | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index 0a7885be0..0493906dd 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments
 
 - **git_hash**: str
 
-    Default = ebaeec1
+    Default = 5f09348
 
     current git hash of repository
 
@@ -1548,6 +1548,14 @@ Training Arguments
 
 
 
+- **mup_deferred_init**: bool
+
+    Default = False
+
+    Whether to defer initialization of the base and delta models via torchdistx instead of fully materializing them (set to true for big target models)
+
+
+
 ## NeoXArgsDeepspeedConfig
 
 Args for deepspeed config

From 4ab0c08abeb2367f834aea9b81a49bf5e4b93f0e Mon Sep 17 00:00:00 2001
From: curt-tigges
Date: Wed, 12 Apr 2023 21:13:24 +0000
Subject: [PATCH 3/4] Fixed syntax errors in training.py and added requirements-mup file

---
 megatron/training.py              | 48 +++++++++++++++++++------------
 requirements/requirements-mup.txt |  1 +
 2 files changed, 31 insertions(+), 18 deletions(-)
 create mode 100644 requirements/requirements-mup.txt

diff --git a/megatron/training.py b/megatron/training.py
index 55132d864..03ef6b721 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -76,20 +76,24 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
     # Instantiation of the base model fails in the init function (init_functions.py) because we haven't called set_base_shapes on it at this point, so disable it temporarily here
     neox_args.use_mup = False
 
-    print(f'MEM BEFORE BASE MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}')
+    print(
+        f"MEM BEFORE BASE MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
+    )
     if neox_args.mup_deferred_init:
         try:
             from torchdistx import deferred_init
         except ModuleNotFoundError:
             print("Please install torchdistx https://github.com/pytorch/torchdistx")
             raise Exception
-        base_model = torchdistx.deferred_init.deferred_init(GPT2ModelPipe(
-            neox_args=neox_args,
-            num_tokentypes=0,
-            parallel_output=True,
-            topology=mpu.get_topology(),
-            use_cache=use_cache))
+        base_model = torchdistx.deferred_init.deferred_init(
+            GPT2ModelPipe(
+                neox_args=neox_args,
+                num_tokentypes=0,
+                parallel_output=True,
+                topology=mpu.get_topology(),
+                use_cache=use_cache,
+            )
+        )
     else:
         base_model = GPT2ModelPipe(
             neox_args=neox_args,
@@ -99,7 +103,9 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
             use_cache=use_cache,
         )
 
-    print(f'MEM AFTER BASE MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}')
+    print(
+        f"MEM AFTER BASE MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
+    )
 
     if not neox_args.is_pipe_parallel:
         base_model = base_model.to_sequential()
@@ -117,20 +123,24 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
     old_hidden_size = neox_args.hidden_size
     neox_args.hidden_size = neox_args.hidden_size * neox_args.mup_width_scale
 
-    print(f'MEM BEFORE DELTA MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}')
+    print(
+        f"MEM BEFORE DELTA MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
+    )
     if neox_args.mup_deferred_init:
         try:
             from torchdistx import deferred_init
         except ModuleNotFoundError:
             print("Please install torchdistx https://github.com/pytorch/torchdistx")
             raise Exception
-        delta_model = torchdistx.deferred_init.deferred_init(GPT2ModelPipe(
-            neox_args=neox_args,
-            num_tokentypes=0,
-            parallel_output=True,
-            topology=mpu.get_topology(),
-            use_cache=use_cache))
+        delta_model = torchdistx.deferred_init.deferred_init(
+            GPT2ModelPipe(
+                neox_args=neox_args,
+                num_tokentypes=0,
+                parallel_output=True,
+                topology=mpu.get_topology(),
+                use_cache=use_cache,
+            )
+        )
     else:
         delta_model = GPT2ModelPipe(
             neox_args=neox_args,
@@ -140,7 +150,9 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
             use_cache=use_cache,
         )
 
-    print(f'MEM AFTER DELTA MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}')
+    print(
+        f"MEM AFTER DELTA MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
+    )
 
     if not neox_args.is_pipe_parallel:
         delta_model = delta_model.to_sequential()
diff --git a/requirements/requirements-mup.txt b/requirements/requirements-mup.txt
new file mode 100644
index 000000000..f6f475e84
--- /dev/null
+++ b/requirements/requirements-mup.txt
@@ -0,0 +1 @@
+mup==1.0.0
\ No newline at end of file

From 54f82ca056d9d65a90310aa5cbdfeebc9e3bf708 Mon Sep 17 00:00:00 2001
From: curt-tigges
Date: Mon, 24 Apr 2023 17:16:10 +0000
Subject: [PATCH 4/4] Fixed misc bugs in mup implementation

---
 megatron/training.py | 55 ++++++++++++++++++++++----------------------
 1 file changed, 27 insertions(+), 28 deletions(-)

diff --git a/megatron/training.py b/megatron/training.py
index 03ef6b721..1d01b0d8c 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -76,23 +76,22 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
     # Instantiation of the base model fails in the init function (init_functions.py) because we haven't called set_base_shapes on it at this point, so disable it temporarily here
     neox_args.use_mup = False
 
-    print(
-        f"MEM BEFORE BASE MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
-    )
+    # print(
+    #     f"MEM BEFORE BASE MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
+    # )
     if neox_args.mup_deferred_init:
         try:
             from torchdistx import deferred_init
         except ModuleNotFoundError:
             print("Please install torchdistx https://github.com/pytorch/torchdistx")
             raise Exception
-        base_model = torchdistx.deferred_init.deferred_init(
-            GPT2ModelPipe(
-                neox_args=neox_args,
-                num_tokentypes=0,
-                parallel_output=True,
-                topology=mpu.get_topology(),
-                use_cache=use_cache,
-            )
+        base_model = deferred_init.deferred_init(
+            GPT2ModelPipe,
+            neox_args=neox_args,
+            num_tokentypes=0,
+            parallel_output=True,
+            topology=mpu.get_topology(),
+            use_cache=use_cache,
         )
     else:
         base_model = GPT2ModelPipe(
@@ -103,9 +102,9 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
             use_cache=use_cache,
         )
 
-    print(
-        f"MEM AFTER BASE MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
-    )
+    # print(
+    #     f"MEM AFTER BASE MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
+    # )
 
     if not neox_args.is_pipe_parallel:
         base_model = base_model.to_sequential()
@@ -123,23 +122,23 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
     old_hidden_size = neox_args.hidden_size
     neox_args.hidden_size = neox_args.hidden_size * neox_args.mup_width_scale
 
-    print(
-        f"MEM BEFORE DELTA MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
-    )
+    # print(
+    #     f"MEM BEFORE DELTA MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
+    # )
     if neox_args.mup_deferred_init:
+        print("Using MUP deferred init")
         try:
             from torchdistx import deferred_init
         except ModuleNotFoundError:
             print("Please install torchdistx https://github.com/pytorch/torchdistx")
             raise Exception
-        delta_model = torchdistx.deferred_init.deferred_init(
-            GPT2ModelPipe(
-                neox_args=neox_args,
-                num_tokentypes=0,
-                parallel_output=True,
-                topology=mpu.get_topology(),
-                use_cache=use_cache,
-            )
+        delta_model = deferred_init.deferred_init(
+            GPT2ModelPipe,
+            neox_args=neox_args,
+            num_tokentypes=0,
+            parallel_output=True,
+            topology=mpu.get_topology(),
+            use_cache=use_cache,
         )
     else:
         delta_model = GPT2ModelPipe(
@@ -150,9 +149,9 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
         use_cache=use_cache,
     )
 
-    print(
-        f"MEM AFTER DELTA MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
-    )
+    # print(
+    #     f"MEM AFTER DELTA MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
+    # )
 
     if not neox_args.is_pipe_parallel:
         delta_model = delta_model.to_sequential()
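For reviewers who want to try the deferred-init path outside of NeoX, here is a minimal, hedged sketch of the pattern the final patch converges on: build fake-initialized base and delta models with torchdistx and hand their shapes to the mup package to write a base-shapes file. This is an illustration under stated assumptions, not code from this series: the toy MLP class, the widths, and the toy_base_shapes.bsh filename are invented, and mup.get_shapes / mup.make_base_shapes / mup.set_base_shapes are the assumed public API of the mup package rather than anything shown in the diffs above; only the "from torchdistx import deferred_init" usage mirrors PATCH 4.

```python
# Sketch: derive muP base shapes from deferred (fake-initialized) models.
# Assumes `pip install mup torchdistx`; MLP, widths, and filename are illustrative.
import torch.nn as nn
import mup
from torchdistx import deferred_init  # same import style as PATCH 4


class MLP(nn.Module):
    """Toy stand-in for GPT2ModelPipe; `width` is the dimension muP rescales."""

    def __init__(self, width: int, d_in: int = 32, d_out: int = 10):
        super().__init__()
        self.fc_in = nn.Linear(d_in, width)
        self.fc_out = nn.Linear(width, d_out)

    def forward(self, x):
        return self.fc_out(self.fc_in(x).relu())


# deferred_init constructs the modules with fake tensors, so no real parameter
# storage is allocated -- the point of the mup_deferred_init flag for big models.
base_model = deferred_init.deferred_init(MLP, width=256)
delta_model = deferred_init.deferred_init(MLP, width=512)

# Deriving base shapes only reads parameter *shapes*, which fake tensors carry,
# so the fake-initialized modules are sufficient here.
base_shapes = mup.get_shapes(base_model)
delta_shapes = mup.get_shapes(delta_model)
mup.make_base_shapes(base_shapes, delta_shapes, savefile="toy_base_shapes.bsh")
```

A real training run would still construct (or materialize) the target model normally and pass it through mup.set_base_shapes with the saved file before building the muP optimizer; the sketch only covers the shape-saving step that mup_deferred_init is meant to make cheap for large target models.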