From 9bd8178da1ff5955c47a6567d44ccf3410fb5f9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B3=8A=E9=9C=86?= Date: Wed, 10 Jan 2024 15:36:46 +0800 Subject: [PATCH] [Embedding] Backward compatibility with 2306. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 泊霆 --- tensorflow/python/ops/kv_variable_ops.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/kv_variable_ops.py b/tensorflow/python/ops/kv_variable_ops.py index 1ef9550ef6d..840aadf2541 100644 --- a/tensorflow/python/ops/kv_variable_ops.py +++ b/tensorflow/python/ops/kv_variable_ops.py @@ -530,11 +530,16 @@ def _init_from_proto(self, variable_def, import_scope=None): cache_op = op elif self._initializer_op.type == "InitializeKvVariableOp": init_op = self._initializer_op - - self._init_op_for_restore = g.as_graph_element( + if variable_def.initialize_op_for_restore: + self._init_op_for_restore = g.as_graph_element( ops.prepend_name_scope( variable_def.initialize_op_for_restore, import_scope=import_scope)) + else: #Backward compatibility with 2306 + self._init_op_for_restore = g.as_graph_element( + ops.prepend_name_scope( + variable_def.initializer_name, + import_scope=import_scope)) self._trainable = getattr(variable_def, "trainable", True) if variable_def.snapshot_name: self._cached_value = g.as_graph_element(