From 492cbf15af0b1fb26b9f1e1ce16b0de7e3bd31e8 Mon Sep 17 00:00:00 2001 From: Francois Chollet <francois.chollet@gmail.com> Date: Mon, 4 Dec 2023 13:44:13 -0800 Subject: [PATCH] Reduce memory consumption for load_model (WiP) --- keras/backend/common/stateless_scope.py | 11 +++- keras/backend/common/variables.py | 56 +++++++++++-------- keras/backend/common/variables_test.py | 8 --- keras/backend/tensorflow/core.py | 7 ++- keras/export/export_lib.py | 1 + keras/initializers/__init__.py | 5 +- .../initializers/random_initializers_test.py | 2 +- keras/layers/layer.py | 2 +- keras/optimizers/base_optimizer.py | 2 +- keras/saving/saving_lib.py | 31 ++++++---- 10 files changed, 73 insertions(+), 52 deletions(-) diff --git a/keras/backend/common/stateless_scope.py b/keras/backend/common/stateless_scope.py index b5b9d58573fc..dafd867e9978 100644 --- a/keras/backend/common/stateless_scope.py +++ b/keras/backend/common/stateless_scope.py @@ -37,6 +37,7 @@ def __init__( state_mapping=None, collect_losses=False, initialize_variables=True, + allow_variable_creation=False, ): from keras import backend from keras.backend.common.variables import KerasVariable @@ -44,6 +45,7 @@ def __init__( self.collect_losses = collect_losses self.initialize_variables = initialize_variables self.losses = [] + self.allow_variable_creation = allow_variable_creation self.state_mapping = {} state_mapping = state_mapping or {} for k, v in state_mapping: @@ -77,18 +79,21 @@ def add_update(self, update): self.state_mapping[id(variable)] = value def get_current_value(self, variable): - return self.state_mapping.get(id(variable), None) + value = self.state_mapping.get(id(variable), None) + if value is None: + value = variable._value + return value def __exit__(self, *args, **kwargs): global_state.set_global_attribute( "stateless_scope", self.original_scope ) - if self.original_scope is None and self.initialize_variables: + is_eager_friendly = self.original_scope is None or self.original_scope.allow_variable_creation + if self.initialize_variables and is_eager_friendly: # We're back in eager scope; # if any variables were created within the stateless # scope, we initialize them here. from keras.backend.common.variables import initialize_all_variables - initialize_all_variables() diff --git a/keras/backend/common/variables.py b/keras/backend/common/variables.py index 21eac4c4f293..5e9ea0a4ab3e 100644 --- a/keras/backend/common/variables.py +++ b/keras/backend/common/variables.py @@ -45,49 +45,61 @@ def __init__( ) if in_stateless_scope(): + scope = get_stateless_scope() if callable(initializer): self._value = None self._initializer = initializer self._shape = self._validate_shape(shape) register_uninitialized_variable(self) else: - raise ValueError( - "You are attempting to create a variable " - "while in a stateless scope. This is disallowed. " - "Make sure that all variables are created " - "before you start using your layer/model objects.\n\n" - "In some cases, you might be seeing this error " - "because you need to " - "implement a `def build(self, input_shape)` method " - "on your layer/model, which will " - "create its variables.\n\n" - "In some other cases, you might be seeing this error " - "because you are instantiating a `Variable` and " - "assigning it to a layer without going through " - "self.add_variable()/self.add_weight(). Always prefer " - "using these methods " - "(with a `shape` and `initializer` argument)." - ) + if not scope.allow_variable_creation: + raise ValueError( + "You are attempting to create a variable " + "with an eager value " + "while in a stateless scope. This is disallowed.\n\n" + f"Variable: {self} with init value {initializer}\n\n" + "Make sure that all variables are created " + "before you start using your layer/model objects.\n\n" + "In some cases, you might be seeing this error " + "because you need to " + "implement a `def build(self, input_shape)` method " + "on your layer/model, which will " + "create its variables.\n\n" + "In some other cases, you might be seeing this error " + "because you are instantiating a `Variable` and " + "assigning it to a layer without going through " + "self.add_variable()/self.add_weight(). Always prefer " + "using these methods " + "(with a `shape` and `initializer` argument)." + ) + # Special case where the stateless scope allows + # eager variable creation. Used for model loading. + self._initialize(initializer) + self._shape = tuple(self._value.shape) else: if callable(initializer): shape = self._validate_shape(shape) value = initializer(shape, dtype=dtype) else: value = initializer + self._initialize(value) self._shape = tuple(self._value.shape) self._ndim = len(self._shape) def _deferred_initialize(self): if self._value is not None: - raise ValueError(f"Variable {self.path} is already initialized.") + # Variables are allowed to already be force-initialized + # (e.g. this is what model loading does). + return - if in_stateless_scope(): + if in_stateless_scope() and not get_stateless_scope().allow_variable_creation: raise ValueError( "You are attempting to initialize a variable " "while in a stateless scope. This is disallowed. " "Make sure that all variables are initialized " - "before you start using your layer/model objects." + "before you start using your layer/model objects. " + f"Variable: {self}" ) value = self._initializer(self._shape, dtype=self._dtype) self._initialize(value) @@ -128,7 +140,7 @@ def value(self): ) return self._maybe_autocast(self._value) - def assign(self, value): + def assign(self, value, force=False): value = self._convert_to_tensor(value, dtype=self.dtype) if not shape_equal(value.shape, self.shape): raise ValueError( @@ -139,7 +151,7 @@ def assign(self, value): f"Received: value.shape={value.shape}. " f"Target variable: {self}" ) - if in_stateless_scope(): + if in_stateless_scope() and not force: scope = get_stateless_scope() scope.add_update((self, value)) else: diff --git a/keras/backend/common/variables_test.py b/keras/backend/common/variables_test.py index 42bd4d45524d..a6e7392e3e3c 100644 --- a/keras/backend/common/variables_test.py +++ b/keras/backend/common/variables_test.py @@ -56,14 +56,6 @@ def test_variable_initialization_without_shape(self): ): backend.Variable(initializer=initializers.RandomNormal()) - def test_deferred_initialize_already_initialized(self): - """Test deferred init on an already initialized variable.""" - v = backend.Variable(initializer=np.ones((2, 2))) - with self.assertRaisesRegex( - ValueError, f"Variable {v.path} is already initialized." - ): - v._deferred_initialize() - def test_variable_initialize(self): """Test initializing a variable.""" v = backend.Variable(initializer=np.array([1, 2, 3])) diff --git a/keras/backend/tensorflow/core.py b/keras/backend/tensorflow/core.py index 9385ceac28d2..9c49271063f4 100644 --- a/keras/backend/tensorflow/core.py +++ b/keras/backend/tensorflow/core.py @@ -11,6 +11,7 @@ from keras.backend.common.name_scope import name_scope as base_name_scope from keras.backend.common.stateless_scope import StatelessScope from keras.backend.common.stateless_scope import in_stateless_scope +from keras.backend.common.stateless_scope import get_stateless_scope from keras.utils.naming import auto_name SUPPORTS_SPARSE_TENSORS = True @@ -34,9 +35,11 @@ def _initialize(self, value): def _deferred_initialize(self): if self._value is not None: - raise ValueError(f"Variable {self.path} is already initialized.") + # Variables are allowed to already be force-initialized + # (e.g. this is what model loading does). + return - if in_stateless_scope(): + if in_stateless_scope() and not get_stateless_scope().allow_variable_creation: raise ValueError( "You are attempting to initialize a variable " "while in a stateless scope. This is disallowed. " diff --git a/keras/export/export_lib.py b/keras/export/export_lib.py index 8796d8e3c5e4..3387a58f1cf4 100644 --- a/keras/export/export_lib.py +++ b/keras/export/export_lib.py @@ -610,6 +610,7 @@ def __init__( def _add_existing_weight(self, weight): """Tracks an existing weight.""" + # TODO: wrap in KerasVariable self._track_variable(weight) def call(self, inputs, training=False, **kwargs): diff --git a/keras/initializers/__init__.py b/keras/initializers/__init__.py index e9019a2d40f8..132f5f13f341 100644 --- a/keras/initializers/__init__.py +++ b/keras/initializers/__init__.py @@ -109,9 +109,10 @@ def get(identifier): else: obj = identifier + if inspect.isclass(obj): + obj = obj() + if callable(obj): - if inspect.isclass(obj): - obj = obj() return obj else: raise ValueError( diff --git a/keras/initializers/random_initializers_test.py b/keras/initializers/random_initializers_test.py index d47d412a7515..cfacfaec631e 100644 --- a/keras/initializers/random_initializers_test.py +++ b/keras/initializers/random_initializers_test.py @@ -149,7 +149,7 @@ def test_orthogonal_initializer(self): def test_get_method(self): obj = initializers.get("glorot_normal") - self.assertTrue(obj, initializers.GlorotNormal) + self.assertTrue(isinstance(obj, initializers.GlorotNormal)) obj = initializers.get(None) self.assertEqual(obj, None) diff --git a/keras/layers/layer.py b/keras/layers/layer.py index 39fd87e59714..6af4f8c724d3 100644 --- a/keras/layers/layer.py +++ b/keras/layers/layer.py @@ -1154,7 +1154,7 @@ def load_own_variables(self, store): f"Expected: {[v.name for v in all_vars]}" ) for i, v in enumerate(all_vars): - v.assign(store[f"{i}"]) + v.assign(store[f"{i}"], force=True) def _track_variable(self, variable): if variable.trainable: diff --git a/keras/optimizers/base_optimizer.py b/keras/optimizers/base_optimizer.py index 53eb8240e906..1d62b9937caf 100644 --- a/keras/optimizers/base_optimizer.py +++ b/keras/optimizers/base_optimizer.py @@ -445,7 +445,7 @@ def load_own_variables(self, store): warnings.warn(msg, stacklevel=2) return for i, variable in enumerate(self.variables): - variable.assign(store[str(i)]) + variable.assign(store[str(i)], force=True) def _get_current_learning_rate(self): if isinstance( diff --git a/keras/saving/saving_lib.py b/keras/saving/saving_lib.py index 7342bcc8f22b..a54b026dd585 100644 --- a/keras/saving/saving_lib.py +++ b/keras/saving/saving_lib.py @@ -10,6 +10,7 @@ import numpy as np from keras.backend.common import global_state +from keras.backend.common.stateless_scope import StatelessScope from keras.layers.layer import Layer from keras.losses.loss import Loss from keras.metrics.metric import Metric @@ -148,11 +149,6 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True): if not compile: # Disable compilation config_dict["compile_config"] = None - # Construct the model from the configuration file in the archive. - with ObjectSharingScope(): - model = deserialize_keras_object( - config_dict, custom_objects, safe_mode=safe_mode - ) all_filenames = zf.namelist() if _VARS_FNAME + ".h5" in all_filenames: @@ -171,13 +167,24 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True): else: asset_store = None - _load_state( - model, - weights_store=weights_store, - assets_store=asset_store, - inner_path="", - visited_trackables=set(), - ) + # We use a stateless scope to prevent variable initialization + # (since the values would be discarded at loading time). + with StatelessScope( + allow_variable_creation=True + ): + with ObjectSharingScope(): + # Construct the model from the configuration file. + model = deserialize_keras_object( + config_dict, custom_objects, safe_mode=safe_mode + ) + # Populate variable values. + _load_state( + model, + weights_store=weights_store, + assets_store=asset_store, + inner_path="", + visited_trackables=set(), + ) weights_store.close() if asset_store: asset_store.close()