From c2b350017b48435f827b6bc2097cdfe02a160142 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B3=8A=E9=9C=86?=
Date: Wed, 8 Nov 2023 11:23:20 +0800
Subject: [PATCH] [Op] Prevent inconsistent number of Ops and devices during
 distributed training.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 泊霆
---
 tensorflow/python/training/saver.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 981d01dd7be..acc9723c183 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -550,8 +550,12 @@ def _GroupByDevices(self, saveables):
     """
     per_device = collections.defaultdict(lambda: [])
     for saveable in saveables:
-      canonical_device = set(
-          pydev.canonical_name(spec.tensor.device) for spec in saveable.specs)
+      canonical_device = set()
+      for spec in saveable.specs:
+        device_name = pydev.canonical_name(spec.tensor.device)
+        device_spec = pydev.DeviceSpec.from_string(device_name)
+        device_spec.device_type = "CPU"
+        canonical_device.add(device_spec.to_string())
       if len(canonical_device) != 1:
         raise ValueError("All tensors of a saveable object must be "
                          "on the same device: %s" % saveable.name)
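
Note (illustration only, not part of the patch): the hunk above rewrites the per-saveable
device set so that every canonical device name has its device_type forced to "CPU" before
grouping, which makes placements that differ only in GPU vs. CPU collapse onto one key.
The following is a minimal standalone sketch of that grouping logic, written against the
public tf.DeviceSpec API (TF 2.x style, using immutable specs and .replace() rather than
the in-place mutation the patch applies to the older pydev.DeviceSpec). The helper name
group_by_cpu_device and the dict-of-device-strings input are hypothetical stand-ins for
saveable.specs[i].tensor.device, not TensorFlow API.

import collections

import tensorflow as tf


def group_by_cpu_device(saveable_devices):
  """Group saveables by canonical device name with device_type forced to CPU.

  `saveable_devices` maps a saveable name to the list of raw device strings
  of its spec tensors (a stand-in for `saveable.specs[i].tensor.device`).
  """
  per_device = collections.defaultdict(list)
  for name, devices in saveable_devices.items():
    canonical = set()
    for raw in devices:
      # Normalize the device type so GPU:0 and CPU:0 on the same task
      # produce the same grouping key.
      spec = tf.DeviceSpec.from_string(raw).replace(device_type="CPU")
      canonical.add(spec.to_string())
    if len(canonical) != 1:
      raise ValueError("All tensors of a saveable object must be "
                       "on the same device: %s" % name)
    per_device[canonical.pop()].append(name)
  return sorted(per_device.items())


# Before the change, the two placements below yield two distinct canonical
# devices and the saveable is rejected (or split across groups); after forcing
# device_type to CPU they map to a single key, so each worker builds the same
# number of save/restore ops.
print(group_by_cpu_device({
    "var_a": ["/job:worker/task:0/device:GPU:0",
              "/job:worker/task:0/device:CPU:0"],
}))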