Fixing MuP #1061

Open
wants to merge 10 commits into main
Changes from 5 commits
10 changes: 8 additions & 2 deletions megatron/model/transformer.py
@@ -301,8 +301,14 @@ def __init__(
coeff = max(1, self.layer_number)
self.norm_factor *= coeff

if neox_args.use_mup:
self.norm_factor = self.hidden_size_per_attention_head
# TODO
# Right now there is no way to correctly set use_mup here. Possible options:
# - refactor model init (hard)
# - do this via another config argument, e.g. "mup_norm_factor" (probably easy)
# - ignore it; this never changed anything in my experiments
#
# if neox_args.use_mup:
#     self.norm_factor = self.hidden_size_per_attention_head

self.rpe = rpe

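The second option in the TODO above (a dedicated config argument) could look roughly like the sketch below. This is only an illustration: `mup_norm_factor` is an assumed, hypothetical config field, not an existing neox_args setting.

```python
# Hypothetical sketch of the "mup_norm_factor" config-argument option from the
# TODO above; the argument name is assumed, not part of the current neox_args.
import math


def attention_norm_factor(hidden_size_per_attention_head, mup_norm_factor=None):
    """Divisor applied to attention logits.

    Standard attention divides by sqrt(d_head); muP divides by d_head itself,
    which an explicit config override could select without relying on use_mup.
    """
    if mup_norm_factor is not None:
        return mup_norm_factor  # e.g. hidden_size_per_attention_head under muP
    return math.sqrt(hidden_size_per_attention_head)


# With d_head = 64: standard scaling divides by 8.0, the muP override by 64.
assert attention_norm_factor(64) == 8.0
assert attention_norm_factor(64, mup_norm_factor=64) == 64
```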
10 changes: 6 additions & 4 deletions megatron/model/word_embeddings.py
@@ -50,7 +50,7 @@ def __init__(
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
self.use_mup = neox_args.use_mup
self.use_mup = neox_args.use_mup  # TODO: as of now this will always be False
self.mup_embedding_mult = neox_args.mup_embedding_mult
self.mup_rp_embedding_mult = neox_args.mup_rp_embedding_mult

@@ -155,9 +155,11 @@ def forward(self, input_ids, position_ids, tokentype_ids=None):
# Dropout.
embeddings = self.embedding_dropout(embeddings)

if self.use_mup:
with torch.no_grad():
embeddings.mul_(self.mup_embedding_mult)
# TODO:
# Not only is this always False because of the way the model is initialized, it also throws an error.
# if self.use_mup:
#     with torch.no_grad():
#         embeddings.mul_(self.mup_embedding_mult)

return embeddings

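One way the embedding multiplier could be re-enabled without the failing in-place op mentioned in the TODO is an ordinary out-of-place multiply that stays in the autograd graph. A minimal sketch with toy tensors, not the PR's change:

```python
# Hedged sketch: apply a constant forward multiplier out of place instead of
# embeddings.mul_() under torch.no_grad(), which the TODO above reports as
# erroring out. The shapes and multiplier value here are illustrative.
import torch

mup_embedding_mult = 10.0  # stand-in for neox_args.mup_embedding_mult

embeddings = torch.randn(4, 8, 16, requires_grad=True)
scaled = embeddings * mup_embedding_mult  # out-of-place and differentiable
scaled.sum().backward()                   # backward runs; no in-place error
print(embeddings.grad.abs().mean())       # gradients reflect the multiplier
```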
13 changes: 8 additions & 5 deletions megatron/mpu/layers.py
@@ -428,7 +428,7 @@ def __init__(
self.init_method = init_method
self.stride = stride
self.mup_rescale_parameters = mup_rescale_parameters
self.use_mup = neox_args.use_mup
self.use_mup = False

# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -539,6 +539,7 @@ def mup_reinitialize_weights(self, neox_args):
partition_dim=0,
stride=self.stride,
)
self.use_mup = True

def set_parallel_output(self, value: bool):
assert isinstance(value, bool)
@@ -547,8 +548,9 @@ def set_parallel_output(self, value: bool):
) # if gather_output is True, parallel output is False, so we set the opposite

def forward(self, input_):
if self.use_mup and self.mup_rescale_parameters:
input_ /= self.width_mult()
if self.mup_rescale_parameters:
if hasattr(self.weight, "infshape"):
input_ /= self.weight.infshape.width_mult()
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.
@@ -623,7 +625,7 @@ def __init__(
self.stride = stride
self.keep_master_weight_for_test = keep_master_weight_for_test
self.mup_rescale_parameters = mup_rescale_parameters
self.use_mup = neox_args.use_mup
self.use_mup = False

# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -728,13 +730,14 @@ def mup_reinitialize_weights(self, neox_args):
partition_dim=1,
stride=self.stride,
)
self.use_mup = True

def set_parallel_output(self, parallel_output: bool):
assert isinstance(parallel_output, bool)
self.parallel_output = parallel_output

def forward(self, input_):
if self.use_mup and self.mup_rescale_parameters:
if self.mup_rescale_parameters:
input_ /= self.width_mult()
# Set up backprop all-reduce.
if self.input_is_parallel:
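The new `hasattr(self.weight, "infshape")` check in the forward pass relies on mup attaching an `infshape` object to every parameter once base shapes are set. A small illustrative sketch of that mechanism, using a toy model rather than the NeoX parallel layers:

```python
# Hedged sketch of the mechanism behind the hasattr(self.weight, "infshape")
# check: mup.set_base_shapes attaches an `infshape` to each parameter, and
# infshape.width_mult() gives the width ratio used for input rescaling.
# The models below are toy stand-ins, not the NeoX parallel layers.
import torch.nn as nn
from mup import set_base_shapes


def make_model(width):
    return nn.Sequential(nn.Linear(16, width), nn.ReLU(), nn.Linear(width, 16))


base = make_model(64)     # base-width model
target = make_model(256)  # 4x wider target model
set_base_shapes(target, base)

w = target[0].weight
print(hasattr(w, "infshape"))   # True, so the forward() check would fire
print(w.infshape.width_mult())  # 4.0: 256 / 64 for the widened dimension
```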
15 changes: 10 additions & 5 deletions megatron/training.py
@@ -62,13 +62,18 @@ def mup_weights_reinit(neox_args, model):
def has_method(o, name):
return callable(getattr(o, name, None))

# HACK: relies on the parent module's class name to avoid re-initializing the output layer; highly prone to future bugs
# HACK: only works with non-tied input-output layers

previous = ""
for layer in model.modules():
# This normally would happen in set_base_shapes if we actually were able to use the MuReadout class
if hasattr(layer, "mup_rescale_parameters") and layer.mup_rescale_parameters:
layer._rescale_parameters()

if has_method(layer, "mup_reinitialize_weights"):
layer.mup_reinitialize_weights(neox_args)
if previous != "ParallelLinearPipe":
if has_method(layer, "mup_reinitialize_weights"):
layer.mup_reinitialize_weights(neox_args)
previous = layer.__class__.__name__


def save_base_shapes(neox_args, base_shapes, use_cache):
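The class-name trick the HACK comments describe works because `nn.Module.modules()` yields a parent module immediately before its children, so the module visited just before the output linear is its wrapper. An illustrative sketch with a stand-in wrapper class (not the actual ParallelLinearPipe):

```python
# Hedged illustration of the traversal fact the HACK above relies on:
# nn.Module.modules() yields a parent module immediately before its children,
# so the module visited right before the output linear is its wrapper class.
# "OutputPipeWrapper" is a stand-in, not the actual ParallelLinearPipe.
import torch.nn as nn


class OutputPipeWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.final_linear = nn.Linear(8, 8)


model = nn.Sequential(nn.Linear(8, 8), OutputPipeWrapper())

previous = ""
for layer in model.modules():
    if isinstance(layer, nn.Linear):
        # The output linear sees previous == "OutputPipeWrapper" and would be
        # skipped, mirroring the `previous != "ParallelLinearPipe"` check.
        print(layer, "previous:", previous)
    previous = layer.__class__.__name__
```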
@@ -530,9 +535,9 @@ def get_optimizer(model, neox_args):
# Use Adam
if neox_args.use_mup:
try:
from mup import MuAdam
from mup import MuAdamW  # TODO: was there any particular reason the original code did not use MuAdamW?

adam_optimizer = MuAdam
adam_optimizer = MuAdamW
except ModuleNotFoundError:
print("Please install mup https://github.com/microsoft/mup")
raise Exception
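For reference, typical usage of `MuAdamW` from the mup package looks like the sketch below (toy model, not the NeoX optimizer setup); it mirrors `torch.optim.AdamW` but scales per-parameter learning rates according to each parameter's `infshape`:

```python
# Hedged usage sketch of MuAdamW with a toy model; the NeoX code wires this up
# through get_optimizer instead. set_base_shapes must run before the optimizer
# is constructed so parameters carry their infshape metadata.
import torch.nn as nn
from mup import MuAdamW, set_base_shapes


def make_model(width):
    return nn.Sequential(nn.Linear(16, width), nn.ReLU(), nn.Linear(width, 16))


model = make_model(256)
set_base_shapes(model, make_model(64))  # attach infshape to parameters
optimizer = MuAdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
```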
9 changes: 9 additions & 0 deletions mup/CODE_OF_CONDUCT.md
@@ -0,0 +1,9 @@
# Microsoft Open Source Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

Resources:

- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
589 changes: 589 additions & 0 deletions mup/CoordCheck.ipynb

Large diffs are not rendered by default.
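The `CoordCheck.ipynb` added under `mup/` presumably runs muP's coordinate check, which tracks activation scales across model widths over a few training steps. A hedged sketch of what such a check typically looks like with the mup library's coord_check utilities, using a toy model and random data rather than the NeoX setup:

```python
# Hedged sketch of a muP coordinate check, following the mup README usage;
# the notebook's actual contents are not rendered above. Toy model and random
# data are illustrative only, and default get_coord_data settings are assumed.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from mup import set_base_shapes
from mup.coord_check import get_coord_data, plot_coord_data


def make_model(width):
    return nn.Sequential(nn.Linear(16, width), nn.ReLU(), nn.Linear(width, 10))


def lazy_mup_model(width):
    # set_base_shapes returns the model; width 64 serves as the base here
    return lambda: set_base_shapes(make_model(width), make_model(64))


models = {w: lazy_mup_model(w) for w in (64, 128, 256)}
data = TensorDataset(torch.randn(64, 16), torch.randint(0, 10, (64,)))
dataloader = DataLoader(data, batch_size=8)

# Train each width for a few steps, record activation scales, and plot them.
df = get_coord_data(models, dataloader, optimizer="adam", lr=1e-3, nsteps=3)
plot_coord_data(df, save_to="coord_check.png")
```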

21 changes: 21 additions & 0 deletions mup/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Microsoft Corporation.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.