Skip to content

Commit

Permalink
Create PrecisionPlugin for megatron_ckpt_to_nemo.py trainer (#7767)
Browse files Browse the repository at this point in the history
* Create PrecisionPlugin for megatron_ckpt_to_nemo.py trainer

Signed-off-by: Abhishree <abhishreetm@gmail.com>

* Add ddp_find_unused_parameters_true for punctuation_capitalization_train_evaluate.py

Signed-off-by: Abhishree <abhishreetm@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add '32-true' for precision values

Signed-off-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com>

---------

Signed-off-by: Abhishree <abhishreetm@gmail.com>
Signed-off-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
athitten and pre-commit-ci[bot] committed Oct 20, 2023
1 parent 37c7c50 commit 0e273e0
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
41 changes: 36 additions & 5 deletions examples/nlp/language_modeling/megatron_ckpt_to_nemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import torch
from genericpath import isdir
from megatron.core import parallel_state
from omegaconf import open_dict
from omegaconf import OmegaConf, open_dict
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.trainer import Trainer

Expand All @@ -42,7 +42,12 @@
from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model
from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
NLPDDPStrategy,
NLPSaveRestoreConnector,
PipelineMixedPrecisionPlugin,
)
from nemo.utils import AppState, logging
from nemo.utils.distributed import initialize_distributed
from nemo.utils.model_utils import inject_model_parallel_rank
Expand Down Expand Up @@ -92,6 +97,14 @@ def get_args():
)
parser.add_argument("--local_rank", type=int, required=False, default=os.getenv('LOCAL_RANK', -1))
parser.add_argument("--bcp", action="store_true", help="Whether on BCP platform")
parser.add_argument(
"--precision",
type=str,
required=False,
default='16-mixed',
choices=['32-true', '16-mixed', 'bf16-mixed'],
help="Precision value for the trainer that matches with precision of the ckpt",
)

args = parser.parse_args()
return args
Expand All @@ -109,9 +122,27 @@ def convert(local_rank, rank, world_size, args):
if args.model_type == 'gpt':
strategy = NLPDDPStrategy()

trainer = Trainer(
devices=args.gpus_per_node, num_nodes=num_nodes, accelerator='gpu', plugins=plugins, strategy=strategy
)
cfg = {
'trainer': {
'devices': args.gpus_per_node,
'num_nodes': num_nodes,
'accelerator': 'gpu',
'precision': args.precision,
},
'model': {'native_amp_init_scale': 2 ** 32, 'native_amp_growth_interval': 1000, 'hysteresis': 2},
}
cfg = OmegaConf.create(cfg)

scaler = None
# If FP16 create a GradScaler as the build_model_parallel_config of MegatronBaseModel expects it
if cfg.trainer.precision == '16-mixed':
scaler = GradScaler(
init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
hysteresis=cfg.model.get('hysteresis', 2),
)
plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)

app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size
app_state.tensor_model_parallel_size = args.tensor_model_parallel_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@

@hydra_runner(config_path="conf", config_name="punctuation_capitalization_config")
def main(cfg: DictConfig) -> None:
# PTL 2.0 has find_unused_parameters as False by default, so its required to set it to True
# when there are unused parameters like here
if cfg.trainer.strategy == 'ddp':
cfg.trainer.strategy = "ddp_find_unused_parameters_true"
torch.manual_seed(42)
cfg = OmegaConf.merge(OmegaConf.structured(PunctuationCapitalizationConfig()), cfg)
trainer = pl.Trainer(**cfg.trainer)
Expand Down

0 comments on commit 0e273e0

Please sign in to comment.