Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve our mup implementation #837

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 9 additions & 1 deletion configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = ebaeec1
Default = 5f09348

current git hash of repository

Expand Down Expand Up @@ -1548,6 +1548,14 @@ Training Arguments



- **mup_deferred_init**: bool

Default = False

Whether to fully initialize the base and delta models (set to true for big target models)



## NeoXArgsDeepspeedConfig

Args for deepspeed config
Expand Down
6 changes: 6 additions & 0 deletions megatron/model/gpt2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,9 @@ def init_specs(self):
heads=self.neox_args.num_attention_heads,
)

if self.neox_args.use_mup and self.neox_args.mup_input_temp is not None:
self.specs.append(lambda x: x * self.neox_args.mup_input_temp)

# Transformer layers
for i in range(self.neox_args.num_layers):
layer_type = self.neox_args.attention_config[i]
Expand Down Expand Up @@ -260,6 +263,9 @@ def init_specs(self):
LayerSpec(NormPipe, norm, self.neox_args.hidden_size, eps=eps)
)

if self.neox_args.use_mup and self.neox_args.output_temp is not None:
self.specs.append(lambda x: x * self.neox_args.mup_output_temp / self.neox_args.hidden_size)

# outputs are now a single tensor: hidden_states

def _logits_helper(embedding, lm_output):
Expand Down
5 changes: 4 additions & 1 deletion megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,10 @@ def __init__(
)

coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if neox_args.use_mup:
self.norm_factor = self.hidden_size_per_attention_head / neox_args.mup_attn_temp
else:
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = max(1, self.layer_number)
self.norm_factor *= coeff
Expand Down
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,11 @@ class NeoXArgsTraining(NeoXArgsTemplate):
What to scale width by when creating the delta model for mup
"""

mup_deferred_init: bool = False
"""
Whether to fully initialize the base and delta models (set to true for big target models)
"""


@dataclass
class NeoXArgsTextgen(NeoXArgsTemplate):
Expand Down
73 changes: 59 additions & 14 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,35 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
# Instantiation of the base model fails in the init function (init_functions.py) because we haven't called set_base_shapes on it at this point, so disable it temporarily here
neox_args.use_mup = False

base_model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
topology=mpu.get_topology(),
use_cache=use_cache,
)
# print(
# f"MEM BEFORE BASE MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
# )
if neox_args.mup_deferred_init:
try:
from torchdistx import deferred_init
except ModuleNotFoundError:
print("Please install torchdistx https://github.com/pytorch/torchdistx")
raise Exception
base_model = deferred_init.deferred_init(
GPT2ModelPipe,
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
topology=mpu.get_topology(),
use_cache=use_cache,
)
else:
base_model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
topology=mpu.get_topology(),
use_cache=use_cache,
)

# print(
# f"MEM AFTER BASE MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
# )

if not neox_args.is_pipe_parallel:
base_model = base_model.to_sequential()
Expand All @@ -100,13 +122,36 @@ def save_base_shapes(neox_args, base_shapes, use_cache):
old_hidden_size = neox_args.hidden_size
neox_args.hidden_size = neox_args.hidden_size * neox_args.mup_width_scale

delta_model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
topology=mpu.get_topology(),
use_cache=use_cache,
)
# print(
# f"MEM BEFORE DELTA MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
# )
if neox_args.mup_deferred_init:
print("Using MUP deferred init")
try:
from torchdistx import deferred_init
except ModuleNotFoundError:
print("Please install torchdistx https://github.com/pytorch/torchdistx")
raise Exception
delta_model = deferred_init.deferred_init(
GPT2ModelPipe,
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
topology=mpu.get_topology(),
use_cache=use_cache,
)
else:
delta_model = GPT2ModelPipe(
neox_args=neox_args,
num_tokentypes=0,
parallel_output=True,
topology=mpu.get_topology(),
use_cache=use_cache,
)

# print(
# f"MEM AFTER BASE MUP: {torch.cuda.memory_allocated(device_index)} on rank {torch.distributed.get_rank()}"
# )

if not neox_args.is_pipe_parallel:
delta_model = delta_model.to_sequential()
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-mup.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mup==1.0.0