From 287b934ba7170cdbbaf0550c354e2530f7c9fb46 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Tue, 3 Aug 2021 16:14:02 +0100
Subject: [PATCH 1/8] Add functions to collate deepspeed zero 3 checkpoints

---
 .../utilities/collate_deepspeed_checkpoint.py | 104 ++++++++++++++++++
 .../test_deepspeed_collate_checkpoint.py      |  62 +++++++++++
 2 files changed, 166 insertions(+)
 create mode 100644 pytorch_lightning/utilities/collate_deepspeed_checkpoint.py
 create mode 100644 tests/utilities/test_deepspeed_collate_checkpoint.py

diff --git a/pytorch_lightning/utilities/collate_deepspeed_checkpoint.py b/pytorch_lightning/utilities/collate_deepspeed_checkpoint.py
new file mode 100644
index 0000000000000..6d1cd15b3854f
--- /dev/null
+++ b/pytorch_lightning/utilities/collate_deepspeed_checkpoint.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python
+# Modified script from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py
+
+# This script extracts fp32 consolidated weights from ZeRO 2 and 3 DeepSpeed checkpoints. It gets
+# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
+# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
+# application. Additionally, the script has been modified to keep the Lightning state inside the
+# state dict so that ``Model.load_from_checkpoint('...')`` can still be used on the collated file.
+#
+# example usage within the lightning checkpoint directory where 'latest' is found:
+# python -m pytorch_lightning.utilities.collate_deepspeed_checkpoint . pytorch_model.ckpt
+
+import argparse
+import os
+
+import torch
+
+from pytorch_lightning.utilities import _DEEPSPEED_AVAILABLE
+
+if _DEEPSPEED_AVAILABLE:
+    from deepspeed.utils.zero_to_fp32 import (
+        get_fp32_state_dict_from_zero_checkpoint,
+        get_model_state_file,
+        get_optim_files,
+    )
+
+device = torch.device("cpu")
+
+
+def ds_checkpoint_dir(checkpoint_dir, tag=None):
+    if tag is None:
+        latest_path = os.path.join(checkpoint_dir, "latest")
+        if os.path.isfile(latest_path):
+            with open(latest_path) as fd:
+                tag = fd.read().strip()
+        else:
+            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+    directory = os.path.join(checkpoint_dir, tag)
+
+    if not os.path.isdir(directory):
+        raise FileNotFoundError(f"Directory '{directory}' doesn't exist")
+    return directory
+
+
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
+    """
+    Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+    loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+    Args:
+        - ``checkpoint_dir``: path to the desired checkpoint folder.
+            (one that contains the tag-folder, like ``global_step14``)
+        - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+        - ``tag``: checkpoint tag used as a unique identifier for checkpoint.
+            If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder,
+            e.g., ``global_step14``
+    """
+
+    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+    # additional logic to ensure we keep the lightning state dict as well from rank 0.
+    deepspeed_states = [
+        "module",
+        "optimizer",
+        "lr_scheduler",
+        "csr_tensor_module_names",
+        "skipped_steps",
+        "global_steps",
+        "dp_world_size",
+        "mp_world_size",
+    ]
+    checkpoint_dir = ds_checkpoint_dir(checkpoint_dir)
+    optim_files = get_optim_files(checkpoint_dir)
+    optim_state = torch.load(optim_files[0], map_location=device)
+    zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
+    model_file = get_model_state_file(checkpoint_dir, zero_stage)
+    client_state = torch.load(model_file, map_location=device)
+    client_state = {key: value for key, value in client_state.items() if key not in deepspeed_states}
+    # State dict keys will include reference to wrapper LightningDeepSpeedModule
+    # Delete `module` prefix before saving.
+    state_dict = {k.partition("module.")[2]: v for k, v in state_dict.items()}
+    client_state["state_dict"] = state_dict
+
+    print(f"Saving fp32 state dict to {output_file}")
+    torch.save(client_state, output_file)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12"
+    )
+    parser.add_argument(
+        "output_file",
+        type=str,
+        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)",
+    )
+    parser.add_argument("-d", "--debug", action="store_true", help="enable debug")
+    args = parser.parse_args()
+
+    # variable is used within DeepSpeed utilities
+    debug = args.debug
+
+    convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
diff --git a/tests/utilities/test_deepspeed_collate_checkpoint.py b/tests/utilities/test_deepspeed_collate_checkpoint.py
new file mode 100644
index 0000000000000..10d8191400190
--- /dev/null
+++ b/tests/utilities/test_deepspeed_collate_checkpoint.py
@@ -0,0 +1,62 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import subprocess
+
+import torch
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.plugins import DeepSpeedPlugin
+from pytorch_lightning.utilities.collate_deepspeed_checkpoint import convert_zero_checkpoint_to_fp32_state_dict
+from tests.helpers.boring_model import BoringModel
+from tests.helpers.runif import RunIf
+
+
+@RunIf(min_gpus=2, deepspeed=True, special=False)
+def test_deepspeed_collate_checkpoint(tmpdir):
+    """
+    Test to ensure that with DeepSpeed Stage 3 we can collate the sharded checkpoints into a single file.
+    """
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir, plugins=[DeepSpeedPlugin(stage=3)], gpus=2, fast_dev_run=True, precision=16
+    )
+    trainer.fit(model)
+    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    checkpoint_path = trainer.accelerator.broadcast(checkpoint_path)
+    trainer.save_checkpoint(checkpoint_path)
+    trainer.accelerator.barrier()
+    if trainer.is_global_zero:
+        # ensure function call works
+        output_path = os.path.join(tmpdir, "single_model.pt")
+        convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, output_path)
+        _assert_checkpoint_equal(model, output_path)
+
+        # ensure the utility script works
+        output_path = os.path.join(tmpdir, "single_model_script.pt")
+        cmd = f"python -m pytorch_lightning.utilities.collate_deepspeed_checkpoint {checkpoint_path} {output_path}"
+        exit_code = subprocess.call(cmd, shell=True)
+        assert exit_code == 0
+        _assert_checkpoint_equal(model, output_path)
+
+
+def _assert_checkpoint_equal(model, output_path):
+    assert os.path.exists(output_path)
+    single_output = torch.load(output_path)
+    state_dict = model.state_dict()
+    for orig_param, saved_model_param in zip(state_dict.values(), single_output["state_dict"].values()):
+        if model.dtype == torch.half:
+            # the model was trained in fp16, so cast the collated fp32 weights to half before comparing
+            saved_model_param = saved_model_param.half()
+        assert torch.equal(orig_param.cpu(), saved_model_param)

From 52195eeb540fefaf645e25297d241574acfd6c2a Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Tue, 3 Aug 2021 16:17:36 +0100
Subject: [PATCH 2/8] Add CHANGELOG.md

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ff4a53ed8ef62..d12bc82b8c804 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,7 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `ResultCollection` state_dict to Loop `state_dict` and support for distributed reload. ([#8641](https://github.com/PyTorchLightning/pytorch-lightning/pull/8641))
 
--
+- Added DeepSpeed collate checkpoint utility function ([#8701](https://github.com/PyTorchLightning/pytorch-lightning/pull/8701))
 
 -
 

From 97fcc56f4f291212087ccb3bf0be158b257dcbfb Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Tue, 3 Aug 2021 16:23:50 +0100
Subject: [PATCH 3/8] Add licence

---
 pytorch_lightning/utilities/collate_deepspeed_checkpoint.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/utilities/collate_deepspeed_checkpoint.py b/pytorch_lightning/utilities/collate_deepspeed_checkpoint.py
index 6d1cd15b3854f..c3f5b64ee01b1 100644
--- a/pytorch_lightning/utilities/collate_deepspeed_checkpoint.py
+++ b/pytorch_lightning/utilities/collate_deepspeed_checkpoint.py
@@ -1,6 +1,8 @@
 #!/usr/bin/env python
+# Copyright (c) Microsoft Corporation
+# Licensed under the MIT license.
 # Modified script from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py
-
+#
 # This script extracts fp32 consolidated weights from ZeRO 2 and 3 DeepSpeed checkpoints. It gets
 # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
 # the future. 
Once extracted, the weights don't require DeepSpeed and can be used in any From bd990367a12ba3d20b574990bafaf4946bafe049 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 3 Aug 2021 16:30:40 +0100 Subject: [PATCH 4/8] Remove script --- .../utilities/collate_deepspeed_checkpoint.py | 20 ------------------- .../test_deepspeed_collate_checkpoint.py | 8 -------- 2 files changed, 28 deletions(-) diff --git a/pytorch_lightning/utilities/collate_deepspeed_checkpoint.py b/pytorch_lightning/utilities/collate_deepspeed_checkpoint.py index c3f5b64ee01b1..c875cfc403c27 100644 --- a/pytorch_lightning/utilities/collate_deepspeed_checkpoint.py +++ b/pytorch_lightning/utilities/collate_deepspeed_checkpoint.py @@ -12,7 +12,6 @@ # example usage within the lightning checkpoint directory where 'latest' is found: # python -m pytorch_lightning.utilities.collate_deepspeed_checkpoint . pytorch_model.ckpt -import argparse import os import torch @@ -85,22 +84,3 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag= print(f"Saving fp32 state dict to {output_file}") torch.save(client_state, output_file) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12" - ) - parser.add_argument( - "output_file", - type=str, - help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)", - ) - parser.add_argument("-d", "--debug", action="store_true", help="enable debug") - args = parser.parse_args() - - # variable is used within DeepSpeed utilities - debug = args.debug - - convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) diff --git a/tests/utilities/test_deepspeed_collate_checkpoint.py b/tests/utilities/test_deepspeed_collate_checkpoint.py index 10d8191400190..aa0cd190f3872 100644 --- a/tests/utilities/test_deepspeed_collate_checkpoint.py +++ b/tests/utilities/test_deepspeed_collate_checkpoint.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 import os
-import subprocess
 
 import torch
 
@@ -43,13 +42,6 @@ def test_deepspeed_collate_checkpoint(tmpdir):
         convert_zero_checkpoint_to_fp32_state_dict(checkpoint_path, output_path)
         _assert_checkpoint_equal(model, output_path)
 
-        # ensure the utility script works
-        output_path = os.path.join(tmpdir, "single_model_script.pt")
-        cmd = f"python -m pytorch_lightning.utilities.collate_deepspeed_checkpoint {checkpoint_path} {output_path}"
-        exit_code = subprocess.call(cmd, shell=True)
-        assert exit_code == 0
-        _assert_checkpoint_equal(model, output_path)
-
 
 def _assert_checkpoint_equal(model, output_path):
     assert os.path.exists(output_path)

From 8c36e82eba66cdbaa7dc96a195b0ad91ae96304b Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Tue, 3 Aug 2021 16:31:53 +0100
Subject: [PATCH 5/8] Better name

---
 .../utilities/{collate_deepspeed_checkpoint.py => deepspeed.py} | 0
 tests/utilities/test_deepspeed_collate_checkpoint.py            | 2 +-
 2 files changed, 1 insertion(+), 1 deletion(-)
 rename pytorch_lightning/utilities/{collate_deepspeed_checkpoint.py => deepspeed.py} (100%)

diff --git a/pytorch_lightning/utilities/collate_deepspeed_checkpoint.py b/pytorch_lightning/utilities/deepspeed.py
similarity index 100%
rename from pytorch_lightning/utilities/collate_deepspeed_checkpoint.py
rename to pytorch_lightning/utilities/deepspeed.py
diff --git a/tests/utilities/test_deepspeed_collate_checkpoint.py b/tests/utilities/test_deepspeed_collate_checkpoint.py
index aa0cd190f3872..9405d0bf2c827 100644
--- a/tests/utilities/test_deepspeed_collate_checkpoint.py
+++ b/tests/utilities/test_deepspeed_collate_checkpoint.py
@@ -17,7 +17,7 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DeepSpeedPlugin
-from pytorch_lightning.utilities.collate_deepspeed_checkpoint import convert_zero_checkpoint_to_fp32_state_dict
+from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf

From 14054e9d0306d77d9c771238ca23ea25a008aa6c Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Tue, 3 Aug 2021 16:41:58 +0100
Subject: [PATCH 6/8] Add example usage

---
 pytorch_lightning/utilities/deepspeed.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/utilities/deepspeed.py b/pytorch_lightning/utilities/deepspeed.py
index c875cfc403c27..7b9eda1fad47b 100644
--- a/pytorch_lightning/utilities/deepspeed.py
+++ b/pytorch_lightning/utilities/deepspeed.py
@@ -10,7 +10,13 @@
 # state dict so that ``Model.load_from_checkpoint('...')`` can still be used on the collated file.
 #
 # example usage within the lightning checkpoint directory where 'latest' is found:
-# python -m pytorch_lightning.utilities.collate_deepspeed_checkpoint . pytorch_model.ckpt
+#
+# from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
+#
+# Lightning DeepSpeed saves a directory instead of a single file:
+# save_path = "lightning_logs/version_0/checkpoints/epoch=0-step=0.ckpt/"
+# output_path = "lightning_model.pt"
+# convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path)
 
 import os
 
@@ -28,7 +34,7 @@
 device = torch.device("cpu")
 
 
-def ds_checkpoint_dir(checkpoint_dir, tag=None):
+def ds_checkpoint_dir(checkpoint_dir: str, tag: str = None):
     if tag is None:
         latest_path = os.path.join(checkpoint_dir, "latest")
         if os.path.isfile(latest_path):
@@ -44,7 +50,7 @@
     return directory
 
 
-def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
+def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir: str, output_file: str, tag: str = None):
     """
     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.

From 99122279ad430ef38790047699b4889be5f614a7 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Tue, 3 Aug 2021 20:20:25 +0100
Subject: [PATCH 7/8] Update licence, address feedback

---
 pytorch_lightning/utilities/deepspeed.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/utilities/deepspeed.py b/pytorch_lightning/utilities/deepspeed.py
index 7b9eda1fad47b..7688154c28257 100644
--- a/pytorch_lightning/utilities/deepspeed.py
+++ b/pytorch_lightning/utilities/deepspeed.py
@@ -1,6 +1,18 @@
 #!/usr/bin/env python
-# Copyright (c) Microsoft Corporation
-# Licensed under the MIT license.
+# Copyright 2020 The PyTorch Lightning team and Microsoft Corporation. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
 # Modified script from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py
 #
 # This script extracts fp32 consolidated weights from ZeRO 2 and 3 DeepSpeed checkpoints. 
It gets @@ -31,7 +43,7 @@ get_optim_files, ) -device = torch.device("cpu") +CPU_DEVICE = torch.device("cpu") def ds_checkpoint_dir(checkpoint_dir: str, tag: str = None): @@ -78,10 +90,10 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir: str, output_file: ] checkpoint_dir = ds_checkpoint_dir(checkpoint_dir) optim_files = get_optim_files(checkpoint_dir) - optim_state = torch.load(optim_files[0], map_location=device) + optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE) zero_stage = optim_state["optimizer_state_dict"]["zero_stage"] model_file = get_model_state_file(checkpoint_dir, zero_stage) - client_state = torch.load(model_file, map_location=device) + client_state = torch.load(model_file, map_location=CPU_DEVICE) client_state = {key: value for key, value in client_state.items() if key not in deepspeed_states} # State dict keys will include reference to wrapper LightningDeepSpeedModule # Delete `module` prefix before saving. From 69072ecfb795604d6666ee15644526247bc39d3b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 4 Aug 2021 09:59:12 +0100 Subject: [PATCH 8/8] Set special to True --- tests/utilities/test_deepspeed_collate_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_deepspeed_collate_checkpoint.py b/tests/utilities/test_deepspeed_collate_checkpoint.py index 9405d0bf2c827..45c8f1a9a1d4f 100644 --- a/tests/utilities/test_deepspeed_collate_checkpoint.py +++ b/tests/utilities/test_deepspeed_collate_checkpoint.py @@ -22,7 +22,7 @@ from tests.helpers.runif import RunIf -@RunIf(min_gpus=2, deepspeed=True, special=False) +@RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_collate_checkpoint(tmpdir): """ Test to ensure that with DeepSpeed Stage 3 we can collate the sharded checkpoints into a single file.
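
For orientation after the full series: a minimal end-to-end sketch of the intended workflow, written against the final state of the branch (the `pytorch_lightning.utilities.deepspeed` module path from PATCH 5 and the directory-style checkpoints described in PATCH 6). `TinyModel`, `tiny.ckpt`, and `lightning_model.pt` are illustrative placeholders rather than anything from the diffs, and the sketch assumes a machine with 2 GPUs and DeepSpeed installed:

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict


class TinyModel(LightningModule):
    """Stand-in for any user LightningModule."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=2)


if __name__ == "__main__":
    model = TinyModel()
    trainer = Trainer(gpus=2, precision=16, plugins=[DeepSpeedPlugin(stage=3)], max_steps=4)
    trainer.fit(model)

    # Under DeepSpeed, save_checkpoint writes a sharded checkpoint *directory*
    # (the one containing the 'latest' tag file), not a single file.
    checkpoint_dir = "tiny.ckpt"
    trainer.save_checkpoint(checkpoint_dir)
    trainer.accelerator.barrier()

    # Collate the shards on rank zero only; the output is an ordinary
    # single-file checkpoint that loads without DeepSpeed.
    if trainer.is_global_zero:
        convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, "lightning_model.pt")
        restored = TinyModel.load_from_checkpoint("lightning_model.pt")

The rank-zero guard mirrors the `trainer.is_global_zero` check in the test: every rank participates in training and in the barrier, but only one process should write the collated file.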