## **Making New Layers and Models via Subclassing**

#### **Setup**

In [12]:
import tensorflow as tf
from tensorflow import keras

#### **The `Layer` Class: Combining State (Weights) and Computation**
###### **A layer encapsulates both a state (the layer's weights) and a transformation from inputs to outputs (its forward pass). Let's look at a densely connected layer. It has a state: the variables `w` and `b`.**

In [13]:
class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(
            initial_value = w_init(shape=(input_dim, units), dtype="float32"),
            trainable = True
        )
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(
            initial_value = b_init(shape=(units,), dtype="float32"),
            trainable = True
        )
    def call(self, inputs):
        return(tf.matmul(inputs, self.w) + self.b)

###### **We can use the layer by calling it on some tensor input(s):**

In [14]:
x = tf.ones((2,2))
linear_layer = Linear(4, 2)     # units=4, input_dim=2; input_dim must equal the last dimension (number of columns) of x
y = linear_layer(x)
print(y)

tf.Tensor(
[[ 0.03454653 -0.0629256   0.04774752 -0.13105682]
 [ 0.03454653 -0.0629256   0.04774752 -0.13105682]], shape=(2, 4), dtype=float32)
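Both output rows are identical because both rows of `x` are identical (all ones). The same forward pass can be sketched in NumPy, assuming randomly initialized `w` and zero `b` shaped like those of `Linear(4, 2)` (the values are not the layer's actual weights):

```python
import numpy as np

rng = np.random.default_rng(0)
units, input_dim = 4, 2

# Mirror Linear(4, 2): w has shape (input_dim, units), b has shape (units,)
# stddev 0.05 matches the default of tf.random_normal_initializer
w = rng.normal(loc=0.0, scale=0.05, size=(input_dim, units)).astype("float32")
b = np.zeros(units, dtype="float32")

x = np.ones((2, input_dim), dtype="float32")   # same input as the cell above
y = x @ w + b                                  # same computation as Linear.call()

print(y.shape)                  # (2, 4)
print(np.allclose(y[0], y[1]))  # True: identical input rows give identical output rows
```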


###### **[NB] Because `w` and `b` are set as layer attributes, the layer tracks them automatically:**

In [4]:
assert linear_layer.weights == [linear_layer.w, linear_layer.b]
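Keras does this bookkeeping internally. As a rough mental model only (a toy sketch with hypothetical `Weight` and `TrackingLayer` classes, not Keras's actual implementation), attribute assignment can be intercepted with `__setattr__` and the assigned variables recorded in order:

```python
class Weight:
    """Hypothetical stand-in for a variable, for illustration only."""
    def __init__(self, name):
        self.name = name


class TrackingLayer:
    """Toy layer that records Weight attributes in assignment order."""
    def __init__(self):
        # Bypass our own __setattr__ to create the tracking list itself
        object.__setattr__(self, "_tracked", [])

    def __setattr__(self, key, value):
        if isinstance(value, Weight):
            self._tracked.append(value)   # track variables as they are assigned
        object.__setattr__(self, key, value)

    @property
    def weights(self):
        return list(self._tracked)


toy = TrackingLayer()
toy.w = Weight("w")
toy.b = Weight("b")
print([v.name for v in toy.weights])  # ['w', 'b']
```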

###### **We can also add weights to a layer using the `add_weight()` method, a shortcut that creates the variable and tracks it in one step:**

In [15]:
class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        self.w = self.add_weight(
            shape=(input_dim, units),
            initializer="random_normal",
            trainable=True
        )
        self.b = self.add_weight(
            shape=(units,),
            initializer="zeros",
            trainable=True
        )
    def call(self, inputs):
        return(tf.matmul(inputs, self.w) + self.b)

x = tf.ones((2,2))
linear_layer = Linear(4, 2)     # units=4, input_dim=2; input_dim must equal the last dimension (number of columns) of x
y = linear_layer(x)
print(y)

tf.Tensor(
[[-0.04173263 -0.12612289 -0.10119284 -0.00206137]
 [-0.04173263 -0.12612289 -0.10119284 -0.00206137]], shape=(2, 4), dtype=float32)


#### **Layers can have Non-Trainable Weights**
###### **Besides trainable weights, we can add non-trainable weights to a layer as well. These weights are not updated by backpropagation when we train the layer.<br>Here's how we can add and use a non-trainable weight:**

In [16]:
class ComputeSum(keras.layers.Layer):
    def __init__(self, input_dim):
        super(ComputeSum, self).__init__()
        self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), trainable=False)
    
    def call(self, inputs):
        self.total.assign_add(tf.reduce_sum(inputs, axis=0))
        return(self.total)

x = tf.ones((2,2))
my_sum = ComputeSum(2)
y = my_sum(x)
print(y.numpy())
y = my_sum(x)
print(y.numpy())

[2. 2.]
[4. 4.]


###### **It's part of `layer.weights`, but it gets categorized as a non-trainable weight:**

In [7]:
print("weights:", len(my_sum.weights))
print("non-trainable weights:", len(my_sum.non_trainable_weights))
print("trainable weights:", len(my_sum.trainable_weights))

weights: 1
non-trainable weights: 1
trainable weights: 0
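In a custom training loop this distinction matters because gradients are computed and applied only for `layer.trainable_weights`. A framework-free sketch of that rule, using hypothetical parameter names and made-up values:

```python
import numpy as np

# Hypothetical state: `w` is trainable, `total` is a non-trainable accumulator
params = {"w": np.array([1.0, 2.0]), "total": np.array([4.0, 4.0])}
trainable = {"w": True, "total": False}

# Gradients exist only for the trainable weights
# (the analogue of tape.gradient(loss, layer.trainable_weights))
grads = {"w": np.array([0.5, 0.5])}

learning_rate = 0.1
for name in params:
    if trainable[name]:
        params[name] = params[name] - learning_rate * grads[name]

print(params["w"])      # updated by the gradient step
print(params["total"])  # unchanged
```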


#### **Unknown Input Shape: Deferring Weight Creation Until the Shape of the Inputs Is Known**
###### **Our `Linear` layer above took an `input_dim` argument that was used to compute the shape of the weights `w` and `b` in `__init__()`.<br>In many cases, we may not know the size of the inputs in advance, and we would like to create the weights lazily once that value becomes known, some time after instantiating the layer.<br>Keras recommends creating layer weights in the `build(self, input_shape)` method of the layer:**

In [17]:
class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units
    
    def build(self, input_shape):
        self.w = self.add_weight(
            shape = (input_shape[-1], self.units),      # input_shape[-1] is the size of the last (feature) axis of the inputs
            initializer = "random_normal",
            trainable = True
        )
        self.b = self.add_weight(
            shape = (self.units,),
            initializer = "random_normal",
            trainable = True
        )
    
    def call(self, inputs):
        return(tf.matmul(inputs, self.w) + self.b)

###### **The `__call__()` method of our layer will automatically run `build()` the first time it is called. We now have a lazy layer, which is easier to use.**

In [9]:
x = tf.ones((2,2))
# At instantiation, we don't know on what inputs this is going to be called
linear_layer = Linear(32)
# The Layer's weights are created dynamically the first time the layer is called
y = linear_layer(x)
print(y)

tf.Tensor(
[[-0.00219322  0.00839042 -0.04551635 -0.0311306   0.09423123  0.01757101
  -0.06232182 -0.16253401  0.07216877  0.14174268 -0.13624954 -0.06687343
  -0.15582576 -0.02411817  0.00293126 -0.16645454  0.02450774 -0.08350323
  -0.01096236 -0.06474096  0.16049442  0.0244691  -0.07178485 -0.12176354
   0.01975246  0.13974911 -0.05417458 -0.15414973  0.03189974 -0.02327025
  -0.04507712  0.03664723]
 [-0.00219322  0.00839042 -0.04551635 -0.0311306   0.09423123  0.01757101
  -0.06232182 -0.16253401  0.07216877  0.14174268 -0.13624954 -0.06687343
  -0.15582576 -0.02411817  0.00293126 -0.16645454  0.02450774 -0.08350323
  -0.01096236 -0.06474096  0.16049442  0.0244691  -0.07178485 -0.12176354
   0.01975246  0.13974911 -0.05417458 -0.15414973  0.03189974 -0.02327025
  -0.04507712  0.03664723]], shape=(2, 32), dtype=float32)
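A toy, framework-free sketch of the build-on-first-call pattern (the `LazyLinear` class and its methods are illustrative, not Keras internals). It also checks that later calls reuse the weights created on the first call:

```python
import numpy as np


class LazyLinear:
    """Toy lazy layer: weights are created on the first call, from the input shape."""

    def __init__(self, units):
        self.units = units
        self.built = False

    def build(self, input_shape):
        rng = np.random.default_rng(0)
        # The last input dimension is only known once real inputs arrive
        self.w = rng.normal(scale=0.05, size=(input_shape[-1], self.units))
        self.b = np.zeros(self.units)
        self.built = True

    def __call__(self, inputs):
        # Build once, then run the forward computation
        if not self.built:
            self.build(inputs.shape)
        return self.call(inputs)

    def call(self, inputs):
        return inputs @ self.w + self.b


lazy = LazyLinear(32)
y1 = lazy(np.ones((2, 2)))
w_after_first_call = lazy.w
y2 = lazy(np.ones((5, 2)))
print(y1.shape, y2.shape)            # (2, 32) (5, 32)
print(lazy.w is w_after_first_call)  # True: later calls reuse the same weights
```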


###### **Implementing `build()` separately, as shown above, nicely separates creating weights (once) from using weights (in every call). Layer implementers may also defer weight creation to the first `__call__()`, but must make sure that later calls reuse the same weights. In addition, since `__call__()` is likely to be executed for the first time inside a `tf.function`, any variable creation that takes place in `__call__()` should be wrapped in a `tf.init_scope`.**

#### **Layers Are Recursively Composable**
###### **If we assign a layer instance as an attribute of another layer, the outer layer will start tracking the weights created by the inner layer. Keras recommends creating such sublayers in the `__init__()` method and leaving it to the first `__call__()` to trigger building their weights.**

In [18]:
class MLPBlock(keras.layers.Layer):
    def __init__(self):
        super(MLPBlock, self).__init__()
        self.linear1 = Linear(32)
        self.linear2 = Linear(32)
        self.linear3 = Linear(1)
    def call(self, inputs):
        x = self.linear1(inputs)
        x = tf.nn.relu(x)
        x = self.linear2(x)
        x = tf.nn.relu(x)
        return(self.linear3(x))
    
mlp = MLPBlock()
y = mlp(tf.ones(shape=(3,64)))      # The first call to the `mlp` will create the weights
print("weights:", len(mlp.weights))
print("trainable_weights:", len(mlp.trainable_weights))

weights: 6
trainable_weights: 6
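The count of 6 is one kernel and one bias for each of the three `Linear` sublayers. A plain-Python tally of their shapes, assuming the `(3, 64)` input used above:

```python
import math

# Shapes created by MLPBlock's three Linear sublayers when first called on inputs of shape (3, 64)
input_dim = 64
layer_units = [32, 32, 1]

shapes = []
for units in layer_units:
    shapes.append((input_dim, units))   # kernel `w`
    shapes.append((units,))             # bias `b`
    input_dim = units                   # the next sublayer sees this sublayer's output size

num_params = sum(math.prod(shape) for shape in shapes)
print(len(shapes))   # 6
print(num_params)    # 3169
```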


#### **The `add_loss()` Method**
###### **When writing the `call()` method of a layer, we can create loss tensors that we will want to use later, when writing the training loop. This is done by calling `self.add_loss(value)`:**

In [19]:
# A layer that creates an activity regularization loss
class ActivityRegularizationLayer(keras.layers.Layer):
    def __init__(self, rate=1e-2):
        super(ActivityRegularizationLayer, self).__init__()
        self.rate = rate
    def call(self, inputs):
        self.add_loss(self.rate * tf.reduce_sum(inputs))
        return(inputs)

###### **These losses (including those created by any inner layer) can be retrieved via `layer.losses`. This property is reset at the start of every `__call__()` to the top-level layer, so that `layer.losses` always contains the loss values created during the last forward pass.**

In [20]:
class OuterLayer(keras.layers.Layer):
    def __init__(self):
        super(OuterLayer, self).__init__()
        self.activity_reg = ActivityRegularizationLayer(1e-2)
    def call(self, inputs):
        return(self.activity_reg(inputs))

layer = OuterLayer()
assert len(layer.losses) == 0       # No losses yet since the layer has never been called

_ = layer(tf.zeros((1, 1)))
assert len(layer.losses) == 1       # We created one loss value

# `layer.losses` gets reset at the start of each __call__
_ = layer(tf.zeros((1, 1)))
assert len(layer.losses) == 1       # This is the loss created during the call above
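A toy, framework-free sketch of the bookkeeping described above (all class names are hypothetical, and this is not how Keras implements it internally): the outer layer clears every collected loss at the start of each top-level call, and its `losses` property gathers the losses of its sublayers.

```python
class ToyLayer:
    """Toy sketch of Keras-style loss bookkeeping; not the real implementation."""

    def __init__(self, sublayers=()):
        self.sublayers = list(sublayers)
        self._own_losses = []

    def add_loss(self, value):
        self._own_losses.append(value)

    @property
    def losses(self):
        # A layer reports its own losses plus those of all its sublayers
        collected = list(self._own_losses)
        for sub in self.sublayers:
            collected.extend(sub.losses)
        return collected

    def _clear_losses(self):
        self._own_losses = []
        for sub in self.sublayers:
            sub._clear_losses()

    def __call__(self, inputs, top_level=True):
        if top_level:
            self._clear_losses()   # reset at the start of each top-level call
        return self.call(inputs)

    def call(self, inputs):
        return inputs


class ToyActivityRegularization(ToyLayer):
    def __init__(self, rate=1e-2):
        super().__init__()
        self.rate = rate

    def call(self, inputs):
        self.add_loss(self.rate * sum(inputs))
        return inputs


class ToyOuter(ToyLayer):
    def __init__(self):
        self.activity_reg = ToyActivityRegularization(1e-2)
        super().__init__(sublayers=[self.activity_reg])

    def call(self, inputs):
        # Inner calls are not top-level, so they do not clear anything
        return self.activity_reg(inputs, top_level=False)


outer = ToyOuter()
outer([1.0, 2.0])
outer([1.0, 2.0])
print(len(outer.losses))   # 1, not 2: the second call cleared the first call's loss
```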

###### **In addition, the `losses` property also contains regularization losses created for the weights of any inner layer:**

In [21]:
class OuterLayerWithKernelRegularizer(keras.layers.Layer):
    def __init__(self):
        super(OuterLayerWithKernelRegularizer, self).__init__()
        self.dense = keras.layers.Dense(32, kernel_regularizer = tf.keras.regularizers.L2(1e-3))
    def call(self, inputs):
        return(self.dense(inputs))

layer = OuterLayerWithKernelRegularizer()
_ = layer(tf.zeros((1,1)))

# kernel_regularizer uses this formula "1e-3 * sum(layer.dense.kernel ** 2)"
print(layer.losses)

[<tf.Tensor: shape=(), dtype=float32, numpy=0.0019959928>]
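Since the dense kernel has shape `(1, 32)` here, a sum of squared kernel entries of about 2 yields the printed penalty of roughly 0.002. A NumPy check of the formula, using a made-up kernel of constant 0.25 entries (not the layer's actual random initialization):

```python
import numpy as np

l2_factor = 1e-3
# Hypothetical kernel with the same shape as `layer.dense.kernel` above: (1, 32)
kernel = np.full((1, 32), 0.25, dtype="float32")

# Same formula as the comment above: factor * sum(kernel ** 2)
penalty = l2_factor * np.sum(np.square(kernel))
print(penalty)   # about 0.002, since 32 * 0.25**2 == 2
```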


###### **These losses are meant to be taken into account when writing training loops, like this:**

In [22]:
import numpy as np

# Instantiate an optimizer and a loss function
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Train the `layer` defined above (OuterLayerWithKernelRegularizer, 32 output logits)
model = layer

# Small random dataset for illustration: 1 input feature, integer labels in [0, 32)
train_dataset = tf.data.Dataset.from_tensor_slices((
    np.random.random((64, 1)).astype("float32"),
    np.random.randint(0, 32, size=(64,)),
)).batch(16)

# Iterate over the batches of a dataset
for x_batch_train, y_batch_train in train_dataset:
    with tf.GradientTape() as tape:
        logits = model(x_batch_train)       # Logits for this minibatch
        loss_value = loss_fn(y_batch_train, logits)     # Loss value for this minibatch
        # Add the extra losses created during this forward pass:
        loss_value += sum(model.losses)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

###### **These losses also work seamlessly with `fit()` (they get automatically summed and added to the main loss, if any):**

In [14]:
import numpy as np

inputs = keras.Input(shape=(3,))
outputs = ActivityRegularizationLayer()(inputs)
model = keras.Model(inputs, outputs)

# If there is a loss passed in `compile`, the regularization losses get added to it
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2,3)), np.random.random((2,3)))

# It's also possible not to pass any loss in `compile`, since the model already has a loss to minimize, via the `add_loss` call during the forward pass.
model.compile(optimizer="adam")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))



<keras.callbacks.History at 0x268a182aa90>
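Roughly, the quantity `fit()` minimizes after the first `compile()` call above is the compiled MSE plus the `add_loss()` term; after the second `compile()` call it is the `add_loss()` term alone. A NumPy sketch of that sum on made-up data (it mirrors the idea, not Keras's exact internal reductions):

```python
import numpy as np

rate = 1e-2  # ActivityRegularizationLayer's default rate
x = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])   # hypothetical inputs
y = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])   # hypothetical targets

outputs = x                                 # the layer returns its inputs unchanged
mse = np.mean(np.square(y - outputs))       # the loss passed to compile()
activity_loss = rate * np.sum(x)            # the value passed to add_loss()
total = mse + activity_loss
print(mse, activity_loss, total)
```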

#### **The `add_metric()` Method**
###### **Like `add_loss()`, layers also have an `add_metric()` method for tracking the moving average of a quantity during training.<br>Consider the following layer: a "logistic endpoint" layer. It takes predictions and targets as input, computes a loss which it tracks via `add_loss()`, and computes an accuracy scalar, which it tracks via `add_metric()`.**

In [23]:
class LogisticEndpoint(keras.layers.Layer):
    def __init__(self, name=None):
        super(LogisticEndpoint, self).__init__(name=name)
        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy_fn = keras.metrics.BinaryAccuracy()
    
    def call(self, targets, logits, sample_weights=None):
        # Compute the training-time loss value and add it to the layer using `self.add_loss()`
        loss = self.loss_fn(targets, logits, sample_weights)
        self.add_loss(loss)
        # Compute the accuracy as a metric and add it to the layer using `self.add_metric()`
        acc = self.accuracy_fn(targets, logits, sample_weights)
        self.add_metric(acc, name="accuracy")
    
        # Return the inference-time prediction tensor (for `.predict()`)
        return(tf.nn.softmax(logits))

###### **Metrics tracked in this way are accessible via `layer.metrics`:**

In [24]:
layer = LogisticEndpoint()

targets = tf.ones((2,2))
logits = tf.ones((2,2))
y = layer(targets, logits)

print("layer_metrics:", layer.metrics)
print("current_accuracy_value", float(layer.metrics[0].result()))

layer_metrics: [<keras.metrics.BinaryAccuracy object at 0x0000026620384370>]
current_accuracy_value 1.0
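The accuracy is 1.0 because `BinaryAccuracy` thresholds the predictions at 0.5 (by default), and every prediction here (1.0) is above it, matching the all-ones targets. Metrics like this keep running totals across calls, which is what makes them running averages during training. A toy NumPy sketch of that state (the `RunningBinaryAccuracy` class is hypothetical, not the Keras implementation):

```python
import numpy as np


class RunningBinaryAccuracy:
    """Toy stateful metric in the style of keras.metrics.BinaryAccuracy (illustrative only)."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold
        self.correct = 0.0
        self.count = 0.0

    def update_state(self, y_true, y_pred):
        # A prediction counts as class 1 when it is above the threshold
        matches = (np.asarray(y_pred) > self.threshold) == (np.asarray(y_true) == 1)
        self.correct += matches.sum()
        self.count += matches.size

    def result(self):
        return self.correct / self.count if self.count else 0.0

    def reset_state(self):
        self.correct, self.count = 0.0, 0.0


acc = RunningBinaryAccuracy()
acc.update_state(np.ones((2, 2)), np.ones((2, 2)))    # all correct, as in the cell above
print(acc.result())   # 1.0
acc.update_state(np.ones((2, 2)), np.zeros((2, 2)))   # all wrong
print(acc.result())   # 0.5: the result accumulates across batches
```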


###### **Just like `add_loss()`, these metrics are tracked by `fit()`:**

In [17]:
inputs = keras.Input(shape=(3,), name="inputs")
targets = keras.Input(shape=(10,), name="targets")
logits = keras.layers.Dense(10)(inputs)
predictions = LogisticEndpoint(name="predictions")(targets, logits)     # argument order matches call(self, targets, logits)

model = keras.Model(inputs=[inputs, targets], outputs=predictions)
model.compile(optimizer="adam")

data = {
    "inputs": np.random.random((3,3)),
    "targets": np.random.random((3,10)),
}
model.fit(data)



<keras.callbacks.History at 0x268a2b05760>

#### **We can Optionally Enable Serialization on our Layers**
###### **If we need our custom layers to be serializable as part of a Functional model, we can optionally implement a `get_config()` method:**

In [25]:
class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super(Linear, self).__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape = (input_shape[-1], self.units),
            initializer = "random_normal",
            trainable = True
        )
        self.b = self.add_weight(
            shape = (self.units,),
            initializer = "random_normal",
            trainable = True
        )

    def call(self, inputs):
        return(tf.matmul(inputs, self.w) + self.b)

    def get_config(self):
        return({"units": self.units})

# Now we can recreate the layer from its config
layer = Linear()
config = layer.get_config()
print(config)
my_layer = Linear.from_config(config)

{'units': 32}


###### **[NB] The `__init__()` method of the base `Layer` class takes some keyword arguments, in particular a `name` and a `dtype`. It's good practice to pass these arguments to the parent class in `__init__()` and to include them in the layer config:**

In [26]:
class Linear(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super(Linear, self).__init__(**kwargs)
        self.units = units
    
    def build(self, input_shape):
        self.w = self.add_weight(
            shape = (input_shape[-1], self.units),
            initializer = "random_normal",
            trainable = True
        )
        self.b = self.add_weight(
            shape = (self.units,),
            initializer = "random_normal",
            trainable = True
        )
    
    def call(self, inputs):
        return(tf.matmul(inputs, self.w) + self.b)

    def get_config(self):
        config = super(Linear, self).get_config()
        config.update({"units": self.units})
        return(config)

layer = Linear(64)
config = layer.get_config()
print(config)
my_layer = Linear.from_config(config)

{'name': 'linear_7', 'trainable': True, 'dtype': 'float32', 'units': 64}


###### **If we need more flexibility when deserializing the layer from its config, we can also override the `from_config()` class method. Following is the base implementation of `from_config()`:**

In [27]:
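# Base implementation of `Layer.from_config()`: a classmethod that passes the config back to `__init__()`
@classmethod
def from_config(cls, config):
    return(cls(**config))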
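The contract is that `get_config()` returns everything `__init__()` needs, so `cls(**config)` rebuilds an equivalent layer (with freshly created weights, since the config holds no weight values). A plain-Python sketch of the round trip through JSON, using a hypothetical `ConfigurableLayer` class rather than a real Keras layer:

```python
import json


class ConfigurableLayer:
    """Hypothetical stand-in for a Keras layer, showing the get_config()/from_config() contract."""

    def __init__(self, units=32, name=None):
        self.units = units
        self.name = name or "configurable_layer"

    def get_config(self):
        # Everything needed to call __init__() again
        return {"units": self.units, "name": self.name}

    @classmethod
    def from_config(cls, config):
        # Same shape as the base implementation shown above
        return cls(**config)


original = ConfigurableLayer(units=64, name="dense_like")
config_json = json.dumps(original.get_config())   # configs are plain, JSON-serializable dicts
restored = ConfigurableLayer.from_config(json.loads(config_json))
print(restored.units, restored.name)   # 64 dense_like
```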

#### **Privileged `training` Argument in the `call()` Method**
###### **Some layers, in particular the `BatchNormalization` layer and the `Dropout` layer, have different behaviors during training and inference. For such layers, it is standard practice to expose a `training` (boolean) argument in the `call()` method.<br>By exposing this argument in `call()`, we enable the built-in training and evaluation loops (e.g. `fit()`) to correctly use the layer in training and inference.**

In [28]:
class CustomDropout(keras.layers.Layer):
    def __init__(self, rate, **kwargs):
        super(CustomDropout, self).__init__(**kwargs)
        self.rate = rate
    
    def call(self, inputs, training=None):
        if training:
            return(tf.nn.dropout(inputs, rate=self.rate))
        return(inputs)
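`tf.nn.dropout` implements inverted dropout: during training it zeroes each unit with probability `rate` and scales the kept units by `1 / (1 - rate)`, so no rescaling is needed at inference. A NumPy sketch of the same behavior (the `inverted_dropout` helper is hypothetical):

```python
import numpy as np


def inverted_dropout(inputs, rate, training, rng):
    """NumPy sketch of inverted dropout, the scheme tf.nn.dropout uses."""
    if not training:
        return inputs                            # inference: identity
    keep = rng.random(inputs.shape) >= rate      # keep each unit with probability 1 - rate
    # Scale kept units by 1 / (1 - rate) so the expected value is unchanged
    return np.where(keep, inputs / (1.0 - rate), 0.0)


rng = np.random.default_rng(0)
x = np.ones((4, 5))
eval_out = inverted_dropout(x, 0.5, training=False, rng=rng)
train_out = inverted_dropout(x, 0.5, training=True, rng=rng)
print(eval_out)   # unchanged
print(train_out)  # every entry is 0.0 (dropped) or 2.0 (kept and scaled by 1 / (1 - 0.5))
```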

#### **Privileged `mask` Argument in the `call()` Method**
###### **The other privileged argument supported by `call()` is the `mask` argument.<br>A mask is a boolean tensor (one boolean value per input timestep) used to skip certain input timesteps when processing timeseries data. We will find it in all Keras RNN layers.<br>Keras will automatically pass the correct `mask` argument to `__call__()` for layers that support it, when a mask is generated by a prior layer. Mask-generating layers are the `Embedding` layer configured with `mask_zero=True` and the `Masking` layer.**
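A mask is simply a boolean array marking the real (non-padded) timesteps; `Embedding(mask_zero=True)` derives it as `inputs != 0`. A NumPy sketch of how a layer might use it, here to average only over real timesteps (the token ids, the one-dimensional features and the averaging step are illustrative assumptions):

```python
import numpy as np

# Hypothetical padded batch of token ids; 0 is the padding id (as with mask_zero=True)
token_ids = np.array([[5, 8, 2, 0, 0],
                      [3, 0, 0, 0, 0]])
mask = token_ids != 0   # boolean mask, shape (batch, timesteps)

# Hypothetical per-timestep features with a single dimension, for readability
features = token_ids[..., None].astype("float32")

# Masked mean over time: padded timesteps contribute nothing
summed = (features * mask[..., None]).sum(axis=1)
counts = mask.sum(axis=1, keepdims=True)
masked_mean = summed / counts
print(masked_mean)   # one value per sequence: 5.0 and 3.0
```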

#### **The `Model` Class**
###### **In general, we will use the `Layer` class to define inner computation blocks, and the `Model` class to define the outer model, the object we will train. For instance, in a ResNet50 model, we would have several ResNet blocks subclassing `Layer`, and a single `Model` enclosing the entire ResNet network.**
###### **The `Model` class has the same API as `Layer`, with the following differences:**
* *It exposes built-in training, evaluation and prediction loops (`model.fit()`, `model.evaluate()`, `model.predict()`).*
* *It exposes the list of its inner layers via the `model.layers` property.*
* *It exposes saving and serialization APIs (`save()`, `save_weights()`, ...).*
###### **Meanwhile, the `Layer` class corresponds to what we refer to in the literature as a "layer" (as in "convolutional layer" or "recurrent layer") or as a "block" (as in "ResNet block" or "Inception block"). The `Model` class corresponds to what is referred to in the literature as a "model" (as in "deep learning model") or as a "network" (as in "deep neural network").**
###### **For instance, a mini-ResNet built from custom ResNet blocks can be wrapped in a `Model` that we could train with `fit()` and save with `save()` or `save_weights()`:**


In [None]:
class ResNet(tf.keras.Model):
    def __init__(self, num_classes=1000):
        super(ResNet, self).__init__()
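        # `ResNetBlock` is assumed to be a custom `Layer` subclass defined elsewhere
        self.block1 = ResNetBlock()
        self.block2 = ResNetBlock()
        self.global_pool = keras.layers.GlobalAveragePooling2D()
        self.classifier = keras.layers.Dense(num_classes)
    def call(self, inputs):
        x = self.block1(inputs)
        x = self.block2(x)
        x = self.global_pool(x)
        return(self.classifier(x))

resnet = ResNet()
dataset = ...       # placeholder: a dataset of (images, labels) batches
resnet.fit(dataset, epochs=10)      # requires a prior resnet.compile(...) call
resnet.save(filepath)       # `filepath` is a placeholder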

#### **Putting It All Together: an End-to-End Example**
###### **Here's what we've learned so far:**
* *A `Layer` encapsulates a state (created in `__init__()` or `build()`) and some computation (defined in `call()`).*
* *Layers can be recursively nested to create new, bigger computation blocks.*
* *Layers can create and track losses(typically regularization losses) as well as metrics, via `add_loss()` and `add_metric()`.*
* *The outer container, the thing we want to train, is a `Model`. A `Model` is just like a `Layer`, but with added training and serialization utilities.*
###### **Let's put all of these things together into an end-to-end example: we're going to implement a Variational AutoEncoder (VAE) and train it on MNIST digits.<br>Our VAE will be a subclass of `Model`, built as a nested composition of layers that subclass `Layer`. It will feature a regularization loss (KL divergence).**

In [29]:
from tensorflow.keras import layers

class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""
    
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return(z_mean + tf.exp(0.5 * z_log_var) * epsilon)
    
class Encoder(layers.Layer):
    """Maps MNIST digits to a triplet (z_mean, z_log_var, z)."""
    
    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()
    
    def call(self, inputs):
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return(z_mean, z_log_var, z)

class Decoder(layers.Layer):
    """Converts z, the encoded digit vector, back into a readable digit."""

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")
    
    def call(self, inputs):
        x = self.dense_proj(inputs)
        return(self.dense_output(x))
    
class VariationalAutoEncoder(keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training."""

    def __init__(self, original_dim, intermediate_dim=64, latent_dim=32, name="autoencoder", **kwargs):
        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)
    
    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add KL divergence regularization loss
        kl_loss = -0.5 * tf.reduce_mean(z_log_var-tf.square(z_mean)-tf.exp(z_log_var)+1)
        self.add_loss(kl_loss)
        return(reconstructed)
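`Sampling` implements the reparameterization trick: rather than drawing `z` directly from a normal distribution with mean `z_mean` and variance `exp(z_log_var)`, it draws standard normal noise and shifts and scales it, so the sample stays differentiable with respect to `z_mean` and `z_log_var`. A NumPy check that such samples have the intended mean and standard deviation (the parameter values are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(42)
z_mean = np.array([1.0, -2.0])
z_log_var = np.array([0.0, np.log(4.0)])   # variances 1 and 4, i.e. standard deviations 1 and 2

# Reparameterization: sample noise independently of the parameters, then shift and scale it
epsilon = rng.standard_normal(size=(200_000, 2))
z = z_mean + np.exp(0.5 * z_log_var) * epsilon

print(z.mean(axis=0))   # close to [1, -2]
print(z.std(axis=0))    # close to [1, 2]
```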

###### **Now, let's write a simple training loop on MNIST:**

In [7]:
original_dim = 784
vae = VariationalAutoEncoder(original_dim, 64, 32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()
loss_metric = tf.keras.metrics.Mean()

(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32")/255

train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

epochs = 2

# Iterate over epochs
for epoch in range(epochs):
    print("Start of epoch %d" % (epoch,))
    # Iterate over the batches of the dataset.
    for step, x_batch_train in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            reconstructed = vae(x_batch_train)
            # Compute reconstruction loss
            loss = loss_fn(x_batch_train, reconstructed)
            loss += sum(vae.losses)     # Add KLD regularization loss
        grads = tape.gradient(loss, vae.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae.trainable_weights))
        loss_metric(loss)
        if(step%100 == 0):
            print("step %d: mean loss = %.4f" % (step, loss_metric.result()))

Start of epoch 0
step 0: mean loss = 0.3184
step 100: mean loss = 0.1257
step 200: mean loss = 0.0992
step 300: mean loss = 0.0892
step 400: mean loss = 0.0843
step 500: mean loss = 0.0809
step 600: mean loss = 0.0788
step 700: mean loss = 0.0772
step 800: mean loss = 0.0760
step 900: mean loss = 0.0750
Start of epoch 1
step 0: mean loss = 0.0747
step 100: mean loss = 0.0740
step 200: mean loss = 0.0735
step 300: mean loss = 0.0731
step 400: mean loss = 0.0727
step 500: mean loss = 0.0723
step 600: mean loss = 0.0720
step 700: mean loss = 0.0717
step 800: mean loss = 0.0715
step 900: mean loss = 0.0712


###### **Note that, since the VAE subclasses `Model`, it features built-in training loops. So we could also have trained it like this:**

In [8]:
vae = VariationalAutoEncoder(784, 64, 32)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss = tf.keras.losses.MeanSquaredError()

vae.compile(optimizer, loss)
vae.fit(x_train, x_train, epochs=2, batch_size=64)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x2661f128b50>

#### **The `Functional API`: Beyond Object-Oriented Development**
###### **We can also build models using the [Functional API](https://github.com/abs-sayem/deep_learning/blob/main/keras/functional_api/readme.md). Importantly, choosing one style or another (object-oriented or Functional API) doesn't prevent us from leveraging components written in the other style: we can always mix and match.<br>For instance, the Functional API example below reuses the same `Sampling` layer we defined in the example above:**

In [30]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


original_dim = 784
intermediate_dim = 64
latent_dim = 32

# Define Encoder Model
original_inputs = tf.keras.Input(shape=(original_dim,), name="encoder_input")
x = layers.Dense(intermediate_dim, activation="relu")(original_inputs)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()((z_mean, z_log_var))
encoder = tf.keras.Model(inputs=original_inputs, outputs=z, name="encoder")

# Define Decoder Model
latent_inputs = tf.keras.Input(shape=(latent_dim,), name="z_sampling")
x = layers.Dense(intermediate_dim, activation="relu")(latent_inputs)
outputs = layers.Dense(original_dim, activation="sigmoid")(x)
decoder = tf.keras.Model(inputs=latent_inputs, outputs=outputs, name="decoder")

# Define VariationalAutoEncoder(VAE) Model
outputs = decoder(z)
vae = tf.keras.Model(inputs=original_inputs, outputs=outputs, name="vae")

# Add KL Divergence Regularization Loss
# KL divergence between N(z_mean, exp(z_log_var)) and N(0, I), averaged over batch and latent dimensions
kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
vae.add_loss(kl_loss)

# Train
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss = tf.keras.losses.MeanSquaredError()
vae.compile(optimizer, loss=loss)
vae.fit(x_train, x_train, epochs=3, batch_size=64)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x266214a7160>
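The `kl_loss` expression used in both VAE versions comes from the closed-form KL divergence between a diagonal Gaussian and a standard normal, $\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big) = -\tfrac{1}{2}\sum_j \big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big)$, where `z_log_var` plays the role of $\log\sigma^2$. Because the code uses `tf.reduce_mean` instead of a sum, its value is this per-sample sum divided by `latent_dim` and averaged over the batch. A NumPy sketch comparing the closed form with a Monte Carlo estimate (the helper names and parameter values are assumptions for illustration):

```python
import numpy as np


def kl_closed_form(z_mean, z_log_var):
    # KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions
    return -0.5 * np.sum(1 + z_log_var - np.square(z_mean) - np.exp(z_log_var))


def log_normal_pdf(z, mean, log_var):
    # Elementwise log-density of a normal distribution with the given mean and log-variance
    return -0.5 * (np.log(2 * np.pi) + log_var + np.square(z - mean) / np.exp(log_var))


rng = np.random.default_rng(0)
z_mean = np.array([0.5, -1.0])
z_log_var = np.array([-0.5, 0.3])

# Monte Carlo estimate of E_q[log q(z) - log p(z)] with q = N(z_mean, exp(z_log_var)), p = N(0, I)
z = z_mean + np.exp(0.5 * z_log_var) * rng.standard_normal((500_000, 2))
log_q = log_normal_pdf(z, z_mean, z_log_var).sum(axis=1)
log_p = log_normal_pdf(z, 0.0, 0.0).sum(axis=1)
kl_monte_carlo = np.mean(log_q - log_p)

print(kl_closed_form(z_mean, z_log_var), kl_monte_carlo)   # the two should agree closely
```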