Multi node deepspeed can't load_from_checkpoint #11947

Closed
thomas-happify opened this issue Feb 16, 2022 · 10 comments

Labels
bug (Something isn't working) · checkpointing (Related to checkpointing) · strategy: deepspeed · won't fix (This will not be worked on)

Comments

thomas-happify commented Feb 16, 2022

🐛 Bug

load_from_checkpoint() doesn't work under multi-node training.

Epoch 0: 100%|██████████| 2/2 [00:00<00:00, 62.84it/s, loss=-1.71, v_num=0]
Processing zero checkpoint 'logs/last.ckpt/global_step1'
Traceback (most recent call last):
  File "boringmodel.py", line 109, in <module>
    convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path)
  File "/azureml-envs/azureml_7a1dbd612ddc63955eb40a628f0d1b4a/lib/python3.8/site-packages/pytorch_lightning/utilities/deepspeed.py", line 81, in convert_zero_checkpoint_to_fp32_state_dict
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
  File "/azureml-envs/azureml_7a1dbd612ddc63955eb40a628f0d1b4a/lib/python3.8/site-packages/deepspeed/utils/zero_to_fp32.py", line 377, in get_fp32_state_dict_from_zero_checkpoint
    return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
  File "/azureml-envs/azureml_7a1dbd612ddc63955eb40a628f0d1b4a/lib/python3.8/site-packages/deepspeed/utils/zero_to_fp32.py", line 137, in _get_fp32_state_dict_from_zero_checkpoint
    zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
  File "/azureml-envs/azureml_7a1dbd612ddc63955eb40a628f0d1b4a/lib/python3.8/site-packages/deepspeed/utils/zero_to_fp32.py", line 93, in parse_optim_states
    raise ValueError(
ValueError: Expected 2 of '*_optim_states.pt' under 'logs/last.ckpt/global_step1' but found 1 files. Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes.
[2022-02-16T19:29:46.052376] Command finished with exit code 256 interpreted as return code 1

[Screenshot attachment: Screen Shot 2022-02-16 at 2 37 05 PM]

It looks like zero_pp_rank_1_mp_rank_00_optim_states.pt is stored on node 1, which the trainer.is_global_zero process doesn't have access to.
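
For reference, a minimal check along these lines (a sketch, assuming the global_step1 layout from the traceback above) shows the mismatch that parse_optim_states raises:

import glob

# Path taken from the traceback above; on node 0 only rank 0's shard is visible.
ckpt_dir = "logs/last.ckpt/global_step1"
optim_files = glob.glob(f"{ckpt_dir}/*_optim_states.pt")
print(f"optimizer shards visible on this node: {len(optim_files)}")  # 1 here, but world_size is 2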

Here are the full logs
node_0_log.txt
node_1_log.txt

Please feel free to contact me if you need temporary access to the Azure compute and I will work something out.

Thanks!

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import DeepSpeedPlugin
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict


def set_environment_variables_for_nccl_backend(single_node=False, master_port=6105):
    if not single_node:
        master_node_params = os.environ["AZ_BATCH_MASTER_NODE"].split(":")
        os.environ["MASTER_ADDR"] = master_node_params[0]

        # Do not overwrite master port with that defined in AZ_BATCH_MASTER_NODE
        if "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = str(master_port)
    else:
        os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
        os.environ["MASTER_PORT"] = "54965"

    os.environ["NCCL_SOCKET_IFNAME"] = "^docker0,lo"
    os.environ["NODE_RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
    # additional variables
    os.environ["MASTER_ADDRESS"] = os.environ["MASTER_ADDR"]
    os.environ["LOCAL_RANK"] = os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]
    os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]

    print("NODE_RANK = {}".format(os.environ["NODE_RANK"]))
    print("WORLD_SIZE = {}".format(os.environ["WORLD_SIZE"]))
    print("MASTER_ADDR = {}".format(os.environ["MASTER_ADDR"]))
    print("MASTER_PORT = {}".format(os.environ["MASTER_PORT"]))
    print("NCCL_SOCKET_IFNAME new value = {}".format(os.environ["NCCL_SOCKET_IFNAME"]))


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


if __name__ == "__main__":
    num_nodes = 2
    
    set_environment_variables_for_nccl_backend(single_node=False if num_nodes>1 else True)

    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        strategy=DeepSpeedPlugin(stage=2),
        precision=16,
        gpus=1,
        num_nodes=num_nodes,
        callbacks=ModelCheckpoint(dirpath='logs', save_last=True)
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

    # once saved via the model checkpoint callback,
    # it saves a folder containing the deepspeed checkpoint rather than a single file
    checkpoint_path = "logs/last.ckpt/"

    if trainer.is_global_zero:
        single_ckpt_path = "single_model.pt"

        # magically converts the folder into a single lightning loadable pytorch file (for ZeRO 1,2 and 3)
        convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, single_ckpt_path)
        loaded_parameters = BoringModel.load_from_checkpoint(single_ckpt_path).parameters()

        model = model.cpu()
        # Assert model parameters are identical after loading
        for orig_param, saved_model_param in zip(model.parameters(), loaded_parameters):
            if model.dtype == torch.half:
                # the consolidated checkpoint is saved in fp32; cast it to half to compare with the fp16 model
                saved_model_param = saved_model_param.half()
            assert torch.equal(orig_param, saved_model_param)

Expected behavior

load_from_checkpoint() should be able to load a DeepSpeed checkpoint saved during multi-node training.

Environment

  • PyTorch Lightning Version: 1.5.4
  • PyTorch Version: 1.8.1
  • Python version: 3.8
  • OS: Ubuntu 18.04
  • CUDA/cuDNN version: CUDA 11.1 / cuDNN 8
  • GPU models and configuration: 1 × V100 GPU per node
  • How you installed PyTorch: pip

Additional context

cc @awaelchli @ananthsub @ninginthecloud @rohitgr7 @SeanNaren @akihironitta

@thomas-happify added the bug (Something isn't working) label Feb 16, 2022
@thomas-happify changed the title Feb 16, 2022

thomas-happify commented Feb 16, 2022

However, DDP works perfectly.
DeepSpeed also worked perfectly back in pytorch-lightning==1.3.5, when the save_full_weights argument in DeepSpeedPlugin was still available.
Is there a reason why PL removed it?
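
For context, the old behaviour was enabled roughly like this in 1.3.x (a sketch from memory of the API mentioned above, not a tested snippet); it made rank zero write a single consolidated checkpoint instead of per-rank shards:

from pytorch_lightning.plugins import DeepSpeedPlugin

# save_full_weights was a DeepSpeedPlugin argument in the 1.3.x releases (removed later)
plugin = DeepSpeedPlugin(stage=2, save_full_weights=True)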

@carmocca added the strategy: deepspeed and checkpointing (Related to checkpointing) labels Feb 16, 2022
SeanNaren (Contributor) commented:

Hey @thomas-happify, we removed save_full_weights in PR #8397 to try to simplify the codebase and rely fully on DeepSpeed. The user then needs to consolidate the files into a single checkpoint offline if they so choose (as you have done in your example code).
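
For anyone following along, the offline consolidation is the same utility the reproduction script already calls; the catch is that it must run somewhere that can see every rank's shard (a sketch, assuming the shards have all been gathered under logs/last.ckpt/):

from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

# Only succeeds if every rank's zero_pp_rank_*_optim_states.pt is present under the folder.
convert_zero_checkpoint_to_fp32_state_dict("logs/last.ckpt/", "single_model.pt")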

To make sure I understand the issue: it seems that DeepSpeed is saving the corresponding optim states on local drives that do not know of each other? Or is it a shared drive?


thomas-happify commented Feb 17, 2022

@SeanNaren
Yes, it's saving to the local drives, so when load_from_checkpoint or convert_zero_checkpoint_to_fp32_state_dict runs on global rank zero, it can't access the other nodes' optim states.

thomas-happify (Author) commented:

@SeanNaren do you face this issue in a multi-node setting, or is this Azure-specific?

SeanNaren (Contributor) commented:

Hey @thomas-happify, I'll try to find time to think about the right solution. I think it makes sense that all processes save their own individual shard to disk; however, in a perfect world it would be possible for all processes to save to a shared disk (thus making the checkpoint available on all machines). I think this is doable via an NFS drive, but I'm unsure how viable that is in your setup.
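
A sketch of that shared-storage idea (the /mnt/shared path below is hypothetical; substitute whatever NFS or blob mount every node can see):

from pytorch_lightning.callbacks import ModelCheckpoint

# Every rank writes its zero_pp_rank_*_optim_states.pt shard into the same shared
# directory, so the conversion on global rank zero can find all world_size shards.
checkpoint_callback = ModelCheckpoint(dirpath="/mnt/shared/logs", save_last=True)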

thomas-happify (Author) commented:

Thanks @SeanNaren!
Let me know if you need any help; I'd be happy to pitch in.


stale bot commented Apr 17, 2022

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

The stale bot added the won't fix (This will not be worked on) label Apr 17, 2022
The stale bot closed this as completed Apr 28, 2022
ZeyiLiao commented:

Hi, any update here? I'm running into a similar issue: only rank 0's optim states and model states get saved.

guozhiyao commented:

Hi, have you solved the problem? I'm hitting the same issue.

yinweisu commented:

I encountered this issue too. Any updates?
