
Commit

TF mostly fixed an issue so remove workaround for ReconstructingRegressor
ageron committed Oct 7, 2021
1 parent 7d5649f commit ff86dfd
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions 12_custom_models_and_training_with_tensorflow.ipynb
@@ -3467,7 +3469,9 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"**Note**: due to an issue introduced in TF 2.2 ([#46858](https://github.com/tensorflow/tensorflow/issues/46858)), it is currently not possible to use `add_loss()` along with the `build()` method. So the following code differs from the book: I create the `reconstruct` layer in the constructor instead of the `build()` method. Unfortunately, this means that the number of units in this layer must be hard-coded (alternatively, it could be passed as an argument to the constructor)."
+"**Note**: the following code has two differences from the code in the book:\n",
+"1. It creates a `keras.metrics.Mean()` metric in the constructor and uses it in the `call()` method to track the mean reconstruction loss. Since we only want to do this during training, we add a `training` argument to the `call()` method, and if `training` is `True`, then we update `reconstruction_mean` and call `self.add_metric()` to ensure it is displayed properly.\n",
+"2. Due to an issue introduced in TF 2.2 ([#46858](https://github.com/tensorflow/tensorflow/issues/46858)), we must not call `super().build()` inside the `build()` method."
 ]
 },
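Outside the diff itself: the streaming-average behavior that `keras.metrics.Mean` provides for `reconstruction_error` can be sketched in plain Python. `StreamingMean` below is a hypothetical, dependency-free stand-in, not the Keras implementation — it only illustrates the `update_state()`/`result()`/`reset_state()` contract the model relies on.

```python
class StreamingMean:
    """Toy stand-in for keras.metrics.Mean: keeps a running average
    across update_state() calls, as done for reconstruction_error."""

    def __init__(self, name="mean"):
        self.name = name
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        # Accumulate every value; the mean is computed lazily in result().
        for v in values:
            self.total += float(v)
            self.count += 1

    def result(self):
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        # Called by Keras at the start of each epoch.
        self.total, self.count = 0.0, 0

m = StreamingMean(name="reconstruction_error")
m.update_state([1.0, 2.0])
m.update_state([6.0])
print(m.result())  # 3.0
```

This is why the metric must live on the model (created once in the constructor) rather than be recreated inside `call()`: its state has to persist across batches within an epoch.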
{
@@ -3476,21 +3478,19 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"class ReconstructingRegressor(keras.models.Model):\n",
+"class ReconstructingRegressor(keras.Model):\n",
 "    def __init__(self, output_dim, **kwargs):\n",
 "        super().__init__(**kwargs)\n",
 "        self.hidden = [keras.layers.Dense(30, activation=\"selu\",\n",
 "                                          kernel_initializer=\"lecun_normal\")\n",
 "                       for _ in range(5)]\n",
 "        self.out = keras.layers.Dense(output_dim)\n",
-"        self.reconstruct = keras.layers.Dense(8) # workaround for TF issue #46858\n",
 "        self.reconstruction_mean = keras.metrics.Mean(name=\"reconstruction_error\")\n",
 "\n",
-"    #Commented out due to TF issue #46858, see the note above\n",
-"    #def build(self, batch_input_shape):\n",
-"    #    n_inputs = batch_input_shape[-1]\n",
-"    #    self.reconstruct = keras.layers.Dense(n_inputs)\n",
-"    #    super().build(batch_input_shape)\n",
+"    def build(self, batch_input_shape):\n",
+"        n_inputs = batch_input_shape[-1]\n",
+"        self.reconstruct = keras.layers.Dense(n_inputs)\n",
+"        #super().build(batch_input_shape)\n",
 "\n",
 "    def call(self, inputs, training=None):\n",
 "        Z = inputs\n",
@@ -3526,9 +3526,9 @@
 "output_type": "stream",
 "text": [
 "Epoch 1/2\n",
-"363/363 [==============================] - 1s 761us/step - loss: 1.6313 - reconstruction_error: 1.0474\n",
+"363/363 [==============================] - 1s 810us/step - loss: 1.6313 - reconstruction_error: 1.0474\n",
 "Epoch 2/2\n",
-"363/363 [==============================] - 0s 667us/step - loss: 0.4536 - reconstruction_error: 0.4022\n"
+"363/363 [==============================] - 0s 683us/step - loss: 0.4536 - reconstruction_error: 0.4022\n"
 ]
 }
],
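For context on the `reconstruction_error` values in the output above: the quantity fed into the metric is a scaled mean squared error between the reconstructed inputs and the original inputs, which the model also registers via `self.add_loss()`. Below is a dependency-free sketch of that computation; the function name and the 0.05 scaling factor are taken from the book's example and should be treated as assumptions here, not as an API of this notebook.

```python
def reconstruction_loss(inputs, reconstruction, scale=0.05):
    """Scaled MSE between reconstructed and original inputs,
    mirroring the auxiliary loss passed to self.add_loss()."""
    assert len(inputs) == len(reconstruction)
    mse = sum((r - x) ** 2 for x, r in zip(inputs, reconstruction)) / len(inputs)
    return scale * mse

print(reconstruction_loss([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))  # 0.05 * (4/3)
```

Because this auxiliary loss is added on top of the main regression loss, the reported `loss` column is larger than the regression error alone, while `reconstruction_error` tracks the unscaled metric separately.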
