Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move load state dict after initialize parallel state in nlp_model #9382

Merged
merged 4 commits into from
Jun 15, 2024

Conversation

ryxli
Copy link
Contributor

@ryxli ryxli commented Jun 5, 2024

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Move loading sharded state dict after initialize parallel state.

Addresses #8460

Notes:

Traceback (most recent call last):
  File "train_sft.py", line 64, in main
    ptl_model, updated_cfg = load_megatron_model(GPTSFTModel, cfg, trainer, additional_keys=additional_keys)
  File "/workspace/NeMo/nemo/collections/nlp/models/nlp_model.py", line 397, in load_from_checkpoint
    checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir)
  File "/workspace/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 133, in load
    validate_sharding_integrity(nested_values(sharded_state_dict))
  File "/workspace/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 437, in validate_sharding_integrity
    _validate_sharding_for_key(shardings)
  File "/workspace/Megatron-LM/megatron/core/dist_checkpointing/serialization.py", line 475, in _validate_sharding_for_key
    raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
megatron.core.dist_checkpointing.core.CheckpointingException: Invalid access pattern for ShardedTensor(key='model.embedding.word_embeddings.weight')

Collection: [Note which collection this PR will affect]
nlp/language_modeling

Changelog

  • Add specific line by line info of high level changes in this PR.

Usage

  • You can potentially add a usage example below

For the nemo aligner SFT use case:

  1. change the implementation of nemo_aligner.models.nlp.gpt.gpt_sft_model.GPTSFTModel like so
from omegaconf import DictConfig
from pytorch_lightning.trainer.trainer import Trainer

from nemo_aligner.models.nlp.gpt.gpt_sft_model import GPTSFTModel

from torch.nn.modules.module import Module


class _GPTSFTModel(GPTSFTModel):
    """
    Megatron GPT Supervised Fine-Tuning, support load from mcore distributed ckpt.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer):
        super().__init__(cfg, trainer=trainer)

    def load_state_dict(self, state_dict, strict: bool = True):
        if self.use_peft and self.setup_complete:
            super().load_state_dict(state_dict, strict=strict)
        else:
            if self.mcore_gpt and not self.use_fsdp:
                # mcore distributed checkpoint
                # we defer loading the state_dict until the class has been initialized
                return Module.load_state_dict(self, state_dict, strict=False)
            else:
                super().load_state_dict(state_dict, strict=strict)
  1. Load the checkpoint overriding the files like so
from fsspec.implementations.local import LocalFileSystem


def load_from_checkpoint_dir(cfg, trainer, modify_confg_fn, checkpoint_logger=None):
    cls = _GPTSFTModel # from above
    app_state = AppState()
    if cfg.model.tensor_model_parallel_size > 1 or cfg.model.pipeline_model_parallel_size > 1:
        app_state.model_parallel_size = (
            cfg.model.tensor_model_parallel_size * cfg.model.pipeline_model_parallel_size * cfg.get("expert_model_parallel_size", 1)
        )
        app_state.tensor_model_parallel_size = cfg.model.tensor_model_parallel_size
        app_state.pipeline_model_parallel_size = cfg.model.pipeline_model_parallel_size
        app_state.expert_model_parallel_size = cfg.get("expert_model_parallel_size", 1)
        (
            app_state.tensor_model_parallel_rank,
            app_state.pipeline_model_parallel_rank,
            app_state.expert_model_parallel_rank,
            app_state.model_parallel_size,
            app_state.data_parallel_size,
            app_state.pipeline_model_parallel_split_rank,
            app_state.virtual_pipeline_model_parallel_rank,
        ) = fake_initialize_model_parallel(
            world_size=app_state.model_parallel_size,
            rank=trainer.global_rank,
            tensor_model_parallel_size_=cfg.model.tensor_model_parallel_size,
            pipeline_model_parallel_size_=cfg.model.pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank_=cfg.model.get("pipeline_model_parallel_split_rank", None),
            expert_model_parallel_size_=cfg.get("expert_model_parallel_size", 1),
        )
    checkpoint_path = os.path.join(
        cfg.model.pretrained_checkpoint.checkpoint_dir,
        cfg.model.pretrained_checkpoint.checkpoint_name,
    )

    if not os.path.isdir(checkpoint_path):
        # legacy checkpoint needs model parallel rank injection
        checkpoint_path = inject_model_parallel_rank(checkpoint_path)
    fs = LocalFileSystem()
    with fs.open(cfg.model.pretrained_checkpoint.hparams_file, "r") as fin:
        hparams_file = OmegaConf.load(fin)
    gpt_cfg = modify_confg_fn(hparams_file.cfg, cfg, add_cfg_to_tree=True)
    with tempfile.NamedTemporaryFile(suffix=".yaml") as f:
        OmegaConf.save(config=gpt_cfg, f=f.name)
        model = cls.load_from_checkpoint(checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name)
        return model, gpt_cfg
  1. in the config, resume the training from a dist ckpt dir:
model:
  pretrained_checkpoint:
    checkpoint_dir: /experiments/example_gpt_model_A/checkpoints/
    checkpoint_name: megatron_gpt--val_loss=1.23-step=5000-consumed_samples=100000.0
    hparams_file: /hparams.yaml
  • this hparams file, had to manually create as by default hparams.yaml is not saved with dist ckpt format last I checked. It seems like the behavior is to only create and store with .nemo or the legacy pp tp partition format with the PTL callback. Unrelated, but it might be nice to also save the hparams.yaml file when saving a distributed ckpt (not with .nemo)

GitHub Actions CI

The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.

The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • [ x] Bugfix
  • Documentation

If you haven't finished some of the above items you can still open "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

@mikolajblaz
Copy link
Collaborator

Look good to me 👍

Can you apply a similar sharded_state_dict delay here? https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/parts/nlp_overrides.py#L947

mikolajblaz
mikolajblaz previously approved these changes Jun 7, 2024
Ryan Li added 2 commits June 7, 2024 22:11
Signed-off-by: Ryan Li <rynli@amazon.com>
Signed-off-by: Ryan Li <rynli@amazon.com>
@ryxli
Copy link
Contributor Author

ryxli commented Jun 7, 2024

@mikolajblaz

Look good to me 👍

Can you apply a similar sharded_state_dict delay here? https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/parts/nlp_overrides.py#L947

Done, thanks for review.

@ryxli ryxli requested a review from mikolajblaz June 12, 2024 00:18
@ryxli
Copy link
Contributor Author

ryxli commented Jun 13, 2024

@ericharper or @mikolajblaz

Could you add the run CICD label?

@ericharper ericharper merged commit ec0eb59 into NVIDIA:main Jun 15, 2024
108 of 109 checks passed
@ryxli ryxli deleted the fix_load_dist_ckpt branch June 15, 2024 03:46
JesusPaz pushed a commit to JesusPaz/NeMo that referenced this pull request Jun 18, 2024
…IDIA#9382)

* move load state dict after initialize parallel state

Signed-off-by: Ryan Li <rynli@amazon.com>

* delay sharded_state_dict in save_to

Signed-off-by: Ryan Li <rynli@amazon.com>

---------

Signed-off-by: Ryan Li <rynli@amazon.com>
Co-authored-by: Ryan Li <rynli@amazon.com>
rohitrango pushed a commit to rohitrango/NeMo that referenced this pull request Jun 25, 2024
…IDIA#9382)

* move load state dict after initialize parallel state

Signed-off-by: Ryan Li <rynli@amazon.com>

* delay sharded_state_dict in save_to

Signed-off-by: Ryan Li <rynli@amazon.com>

---------

Signed-off-by: Ryan Li <rynli@amazon.com>
Co-authored-by: Ryan Li <rynli@amazon.com>
@ko3n1g ko3n1g mentioned this pull request Jul 18, 2024
2 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants