In [None]:
# Import libraries
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

# Load the MNIST training split. The labels are kept because they are used to
# colour the latent-space scatter plot at the end of this cell.
(x_train, y_train), (_, _) = mnist.load_data()

# Scale pixels to [0, 1] and append a channel axis: mnist.load_data() returns
# images of shape (N, 28, 28), while the model input is (28, 28, 1).
x_train = x_train.astype('float32')[..., None] / 255.

# Size of the latent space. 2 makes the latent codes directly plottable below.
latent_dim = 2

# Build the variational autoencoder
input_layer = tf.keras.Input(shape=(28, 28, 1))

# Encoder: one shared convolutional trunk (28x28 -> 14x14 -> 7x7) ...
encoded = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(input_layer)
encoded = layers.MaxPooling2D((2, 2), padding='same')(encoded)
encoded = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
encoded = layers.MaxPooling2D((2, 2), padding='same')(encoded)
encoded = layers.Flatten()(encoded)

# ... followed by two linear heads. No activation on purpose: a ReLU here
# would force the mean to be >= 0 and the variance exp(log_var) to be >= 1.
encoded_mean = layers.Dense(latent_dim, name='z_mean')(encoded)
encoded_log_var = layers.Dense(latent_dim, name='z_log_var')(encoded)


def sampling(args):
    """Draw a latent sample with the reparameterisation trick.

    Computes ``z = mean + exp(0.5 * log_var) * eps`` with ``eps ~ N(0, I)``,
    so gradients flow through ``mean`` and ``log_var``.

    Args:
        args: pair ``(mean, log_var)`` of same-shaped tensors, where
            ``log_var`` is the log of the per-dimension variance.

    Returns:
        A tensor with the same shape as ``mean``.
    """
    mean, log_var = args
    epsilon = tf.random.normal(shape=tf.shape(mean))
    return mean + tf.exp(0.5 * log_var) * epsilon


class Sampling(layers.Layer):
    """Sample the latent code and register the KL-divergence penalty.

    A ``compile()`` loss function only receives ``(y_true, y_pred)``, so it
    cannot see the encoder's mean/log-variance. Calling ``add_loss`` from
    inside ``call`` makes Keras add the KL term to the training loss.
    """

    def call(self, inputs):
        mean, log_var = inputs
        # KL(N(mean, exp(log_var)) || N(0, I)), summed over latent dimensions
        # and averaged over the batch.
        kl_divergence = 0.5 * tf.reduce_sum(
            tf.exp(log_var) + tf.square(mean) - 1. - log_var, axis=-1)
        self.add_loss(tf.reduce_mean(kl_divergence))
        return sampling(inputs)


latent_layer = Sampling(name='z')([encoded_mean, encoded_log_var])

# Decoder: project the latent code to a 7x7x32 feature map (the Dense width
# must equal the Reshape target size), then upsample 7 -> 14 -> 28.
decoded = layers.Dense(7 * 7 * 32, activation='relu')(latent_layer)
decoded = layers.Reshape((7, 7, 32))(decoded)
decoded = layers.Conv2DTranspose(32, (3, 3), strides=2, activation='relu', padding='same')(decoded)
decoded = layers.Conv2DTranspose(32, (3, 3), strides=2, activation='relu', padding='same')(decoded)
decoded = layers.Conv2D(1, (3, 3), activation='sigmoid', padding='same')(decoded)

vae = tf.keras.Model(input_layer, decoded, name='vae')

# Encoder-only view that outputs the posterior mean; it shares its layers
# (and therefore its trained weights) with `vae`.
encoder = tf.keras.Model(input_layer, encoded_mean, name='encoder')


# Loss function
def vae_loss(inputs, outputs):
    """Return the per-sample reconstruction loss.

    The KL term is added separately by the ``Sampling`` layer. Binary
    cross-entropy is summed over all pixels rather than averaged. This puts it
    on the same per-sample scale as the summed KL term; averaging would shrink
    the reconstruction term by a factor of 28 * 28.

    Args:
        inputs: target images, values in [0, 1].
        outputs: reconstructed images from the sigmoid output layer.

    Returns:
        A 1-D tensor with one loss value per sample; Keras averages it over
        the batch.
    """
    # binary_crossentropy averages over the last (channel) axis only.
    pixel_bce = tf.keras.losses.binary_crossentropy(inputs, outputs)
    return tf.reduce_sum(pixel_bce, axis=[1, 2])


vae.compile(optimizer='adam', loss=vae_loss)


vae.fit(x_train, x_train,
        epochs=10,
        batch_size=128,
        shuffle=True)


import matplotlib.pyplot as plt

# Plot each training image's posterior mean in the 2-D latent space,
# coloured by its digit label. predict() runs in batches to bound memory.
encoded_digits = encoder.predict(x_train, batch_size=128)
plt.scatter(encoded_digits[:, 0], encoded_digits[:, 1], c=y_train, cmap='tab10', s=2)
plt.colorbar()
plt.show()