Add functions to collate deepspeed zero 3 checkpoints #8701

Merged
merged 8 commits on Aug 4, 2021
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -18,7 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `ResultCollection` state_dict to Loop `state_dict` and support for distributed reload. ([#8641](https://github.com/PyTorchLightning/pytorch-lightning/pull/8641))


-
- Added DeepSpeed collate checkpoint utility function ([#8701](https://github.com/PyTorchLightning/pytorch-lightning/pull/8701))


-
104 changes: 104 additions & 0 deletions pytorch_lightning/utilities/collate_deepspeed_checkpoint.py
@@ -0,0 +1,104 @@
#!/usr/bin/env python
# Modified script from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py

# This script extracts fp32 consolidated weights from ZeRO 2 and ZeRO 3 DeepSpeed checkpoints. It gets
# copied into the top-level checkpoint dir, so the user can easily do the conversion at any point in
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
# application. Additionally, the script has been modified to keep the Lightning state inside the
# state dict so that ``Model.load_from_checkpoint('...')`` still works.
#
# Example usage from within the Lightning checkpoint directory where 'latest' is found:
#   python -m pytorch_lightning.utilities.collate_deepspeed_checkpoint . pytorch_model.ckpt
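#
# The same conversion can also be invoked programmatically (the checkpoint path below is illustrative):
#   from pytorch_lightning.utilities.collate_deepspeed_checkpoint import (
#       convert_zero_checkpoint_to_fp32_state_dict,
#   )
#   convert_zero_checkpoint_to_fp32_state_dict("path/to/lightning/checkpoint_dir", "pytorch_model.ckpt")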

import argparse
import os

import torch

from pytorch_lightning.utilities import _DEEPSPEED_AVAILABLE

if _DEEPSPEED_AVAILABLE:
    from deepspeed.utils.zero_to_fp32 import (
        get_fp32_state_dict_from_zero_checkpoint,
        get_model_state_file,
        get_optim_files,
    )

device = torch.device("cpu")


def ds_checkpoint_dir(checkpoint_dir, tag=None):
    if tag is None:
        latest_path = os.path.join(checkpoint_dir, "latest")
        if os.path.isfile(latest_path):
            with open(latest_path) as fd:
                tag = fd.read().strip()
        else:
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")

    directory = os.path.join(checkpoint_dir, tag)

    if not os.path.isdir(directory):
        raise FileNotFoundError(f"Directory '{directory}' doesn't exist")
    return directory


def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
    """
    Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.

    Args:
        - ``checkpoint_dir``: path to the desired checkpoint folder
          (one that contains the tag-folder, like ``global_step14``)
        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. ``path/pytorch_model.bin``)
        - ``tag``: checkpoint tag used as a unique identifier for the checkpoint.
          If not provided, will attempt to load the tag from the file named ``latest`` in the
          checkpoint folder, e.g. ``global_step14``
    """
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    # additional logic to ensure we keep the Lightning state dict from rank 0 as well
    deepspeed_states = [
        "module",
        "optimizer",
        "lr_scheduler",
        "csr_tensor_module_names",
        "skipped_steps",
        "global_steps",
        "dp_world_size",
        "mp_world_size",
    ]
    checkpoint_dir = ds_checkpoint_dir(checkpoint_dir)
    optim_files = get_optim_files(checkpoint_dir)
    optim_state = torch.load(optim_files[0], map_location=device)
    zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
    model_file = get_model_state_file(checkpoint_dir, zero_stage)
    client_state = torch.load(model_file, map_location=device)
    client_state = {key: value for key, value in client_state.items() if key not in deepspeed_states}
    # state dict keys include a reference to the wrapper ``LightningDeepSpeedModule``;
    # drop the ``module.`` prefix before saving
    state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict.keys()}
    client_state["state_dict"] = state_dict

    print(f"Saving fp32 state dict to {output_file}")
    torch.save(client_state, output_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12"
    )
    parser.add_argument(
        "output_file",
        type=str,
        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)",
    )
    parser.add_argument("-d", "--debug", action="store_true", help="enable debug")
    args = parser.parse_args()

    # variable is used within DeepSpeed utilities
    debug = args.debug

    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
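Once collated, the output file behaves like an ordinary checkpoint. A minimal loading sketch, assuming a hypothetical LightningModule subclass ``LitModel`` and the ``pytorch_model.ckpt`` output path from the example usage above:

```python
import torch

from my_project import LitModel  # hypothetical LightningModule subclass used for illustration

# the collated file is a plain PyTorch checkpoint, so no DeepSpeed is required to load it
ckpt = torch.load("pytorch_model.ckpt", map_location="cpu")
model = LitModel()
model.load_state_dict(ckpt["state_dict"])

# alternatively, because the Lightning client state is kept alongside the weights,
# the standard Lightning entry point also works
model = LitModel.load_from_checkpoint("pytorch_model.ckpt")
```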
62 changes: 62 additions & 0 deletions tests/utilities/test_deepspeed_collate_checkpoint.py
@@ -0,0 +1,62 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin
from pytorch_lightning.utilities.collate_deepspeed_checkpoint import convert_zero_checkpoint_to_fp32_state_dict
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf


@RunIf(min_gpus=2, deepspeed=True, special=False)
def test_deepspeed_collate_checkpoint(tmpdir):
    """
    Test to ensure that with DeepSpeed Stage 3 we can collate the sharded checkpoints into a single file.
    """
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
    )
    trainer.fit(model)
    checkpoint_path = os.path.join(tmpdir, "model.pt")
    checkpoint_path = trainer.accelerator.broadcast(checkpoint_path)
    trainer.save_checkpoint(checkpoint_path)
    trainer.accelerator.barrier()
    if trainer.is_global_zero:
        # ensure the function call works
        output_path = os.path.join(tmpdir, "single_model.pt")
        convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, output_path)
        _assert_checkpoint_equal(model, output_path)

        # ensure the utility script works
        output_path = os.path.join(tmpdir, "single_model_script.pt")
        cmd = f"python -m pytorch_lightning.utilities.collate_deepspeed_checkpoint {checkpoint_path} {output_path}"
        exit_code = subprocess.call(cmd, shell=True)
        assert exit_code == 0
        _assert_checkpoint_equal(model, output_path)


def _assert_checkpoint_equal(model, output_path):
    assert os.path.exists(output_path)
    single_output = torch.load(output_path)
    state_dict = model.state_dict()
    for orig_param, saved_model_param in zip(state_dict.values(), single_output["state_dict"].values()):
        if model.dtype == torch.half:
            # the model parameters are fp16 while the collated weights are saved in fp32,
            # so cast the saved weights to fp16 for the comparison
            saved_model_param = saved_model_param.half()
        assert torch.equal(orig_param.cpu(), saved_model_param)