From b41340fb8ac6f1425e443127ef9dc37c92d2b1f4 Mon Sep 17 00:00:00 2001
From: Carl Case
Date: Wed, 3 May 2023 21:45:25 +0000
Subject: [PATCH] Fix for pytorch 2.0 compatibility

---
 deepspeed/runtime/utils.py       | 6 +++++-
 deepspeed/runtime/zero/stage2.py | 3 +--
 deepspeed/runtime/zero/stage3.py | 3 +--
 3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index 35df3422304e..e414484f9010 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -15,9 +15,13 @@
 import torch
 
 import torch.distributed as dist
-from torch._six import inf
 import torch.distributed as dist
 
+try:
+    from torch._six import inf as inf
+except ModuleNotFoundError:
+    from torch import inf as inf
+
 from deepspeed.utils import logger
 from numpy import prod
 
diff --git a/deepspeed/runtime/zero/stage2.py b/deepspeed/runtime/zero/stage2.py
index e98a374902ce..2fc5aafeb18c 100755
--- a/deepspeed/runtime/zero/stage2.py
+++ b/deepspeed/runtime/zero/stage2.py
@@ -6,13 +6,12 @@
 from torch.distributed.distributed_c10d import _get_global_rank
 import torch.distributed as dist
 import math
-from torch._six import inf
 from torch.autograd import Variable
 import collections
 
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
-from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
+from deepspeed.runtime.utils import inf, see_memory_usage, is_model_parallel_parameter
 from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
 from deepspeed.ops.adam import DeepSpeedCPUAdam
 from deepspeed.ops.op_builder import UtilsBuilder
 
diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 2b6e12abd84b..6ee4d20a3f52 100755
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -11,12 +11,11 @@
 from torch.distributed.distributed_c10d import _get_global_rank
 import torch.distributed as dist
 import math
-from torch._six import inf
 from torch.autograd import Variable
 
 from deepspeed.utils.logging import logger
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
-from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
+from deepspeed.runtime.utils import inf, see_memory_usage, is_model_parallel_parameter
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.partition_parameters import _init_external_params
 from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS
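
Not part of the patch: below is a minimal standalone sketch of the same
compatibility shim, paired with an illustrative grad_norm() helper (a
hypothetical function, not a DeepSpeed API) showing how the imported inf is
typically consumed by gradient-clipping code. torch._six was removed in
PyTorch 2.0, while recent releases expose inf as a top-level torch alias, so
the try/except covers both old and new installs.

    import torch

    try:
        # Pre-2.0 PyTorch: inf lives in the private torch._six shim.
        from torch._six import inf
    except ModuleNotFoundError:
        # PyTorch 2.0+ removed torch._six; inf is a top-level alias.
        from torch import inf

    def grad_norm(parameters, norm_type=2.0):
        # Illustrative helper, not DeepSpeed's implementation: total gradient
        # norm across parameters, treating norm_type == inf as the max-abs
        # (infinity) norm, the case the imported constant exists to support.
        grads = [p.grad.detach() for p in parameters if p.grad is not None]
        if not grads:
            return 0.0
        if norm_type == inf:
            return max(g.abs().max().item() for g in grads)
        total = sum(g.norm(norm_type).item() ** norm_type for g in grads)
        return total ** (1.0 / norm_type)

Called as grad_norm(model.parameters(), norm_type=inf) after backward(), this
mirrors the pattern in which deepspeed/runtime/utils.py compares a requested
norm type against inf, which is why the other two files can now import inf
from deepspeed.runtime.utils instead of torch._six.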