In [1]:
from tensorflow import keras
import tensorflow_datasets.public_api as tfds
import tensorflow as tf
import pathlib

Создаём датасет с цветочками

In [8]:
# (лучше загрузить свой с инета)
data_dir = tf.keras.utils.get_file(origin="C:\\Users\\samki\\Downloads\\flowers.zip",
                                   fname="flower_photos",
                                   untar=True)
data_dir = pathlib.Path(data_dir)

# Константы
batch_size = 20
img_height = img_width = 192

# Разбиваем датасет на тренировочную группу и группу валидации
train_data = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    image_size=(img_height, img_width),
    batch_size=batch_size,
    seed=5432,
)

validation_data = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    image_size=(img_height, img_width),
    batch_size=batch_size,
    seed=5432,
)


train_data = train_data.map(lambda x, y: (x / 255.0, x / 255.0))
validation_data = validation_data.map(lambda x, y: (x / 255.0, x / 255.0))

Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.


Модель

In [66]:
from keras.layers import (Input, Dense, Conv2D, Conv2DTranspose, BatchNormalization,
                          Dropout, Flatten, Reshape, Lambda, MaxPool2D, Concatenate)
from keras.models import Model
from keras.losses import binary_crossentropy
from keras.layers import LeakyReLU
from keras.optimizers import RMSprop, Adam
from keras import backend as K

latent_dim = 16
dropout = 0.1

# Энкодер
input_img = Input(shape=(img_height, img_width, 3))

x = Dense(latent_dim*16, activation="relu")(input_img)
x = Dropout(dropout)(BatchNormalization()(x))
x = Dense(latent_dim*8, activation="relu")(x)
x = Dropout(dropout)(BatchNormalization()(x))
x = Dense(latent_dim*4, activation="relu")(x)
x = Dropout(dropout)(BatchNormalization()(x))
x = Dense(latent_dim*2, activation="relu")(x)
x = Dropout(dropout)(BatchNormalization()(x))

z_mean = Dense(latent_dim, activation="sigmoid")(x)
z_log_var = Dense(latent_dim, activation="sigmoid")(x)

# Сэмплирование из Q с трюком репараметризации
@tf.function
def sampling(args):
    z_mean, z_log_var = args
    epsilon = tf.random.normal(shape=tf.shape(z_mean))
    return z_mean + tf.exp(z_log_var / 2) * epsilon

l = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

encoder = Model(input_img, l, "encoder")

# Декодер
z = Input(shape=(latent_dim, ))

x = Dense(latent_dim*2)(z)
x = LeakyReLU()(x)
x = Dropout(dropout)(BatchNormalization()(x))

x = Dense(latent_dim*4)(x)
x = LeakyReLU()(x)
x = Dropout(dropout)(BatchNormalization()(x))

x = Dense(latent_dim*8)(x)
x = LeakyReLU()(x)
x = Dropout(dropout)(BatchNormalization()(x))

x = Dense(latent_dim*16)(x)
x = LeakyReLU()(x)
x = Dropout(dropout)(BatchNormalization()(x))

x = Dense(img_width * img_height * 3, activation="sigmoid")(x)
output_img = Reshape((img_height, img_width, 3))(x)

decoder = Model(z, output_img, name="decoder")

# VAE модель
vae = Model(input_img, decoder(l), name="vae")

def vae_loss(x, decoded):
    x = K.reshape(x, shape=(batch_size, img_width * img_height * 3))
    decoded = K.reshape(decoded, shape=(batch_size, img_width * img_height * 3))

    xent_loss = img_width * img_height * 3 * binary_crossentropy(x, decoded)
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return (xent_loss + kl_loss) / 2 / img_width / img_height / 3

vae.compile(optimizer="adam", loss="binary_crossentropy")
vae.summary()



ValueError: Exception encountered when calling layer "decoder" (type Functional).

Input 0 of layer "batch_normalization_494" is incompatible with the layer: expected ndim=2, found ndim=4. Full shape received: (None, 192, 192, 32)

Call arguments received by layer "decoder" (type Functional):
  • inputs=tf.Tensor(shape=(None, 192, 192, 16), dtype=float32)
  • training=False
  • mask=None

In [24]:
vae.fit(
    train_data,
    epochs=10,
    validation_data=validation_data,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
  1/147 [..............................] - ETA: 1:01 - loss: 0.5627

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

# Генерация изображений с помощью автоэнкодера
generated_images = vae.predict(validation_data)
num_images = 10

# Конвертируем в список
validation_img = generated_img = []
for img, _ in validation_data:
    validation_img.extend(img.numpy())

# Отображение сгенерированных изображений
plt.figure(figsize=(17, 4))
for i in range(num_images):
    # Исходное изображение
    plt.subplot(2, num_images, i + 1)
    plt.imshow(validation_img[i])
    plt.title("Original")
    plt.axis("off")

    # Сгенерированное изображение
    plt.subplot(2, num_images, i + num_images + 1)
    plt.imshow(generated_images[i])
    plt.title("Generated")
    plt.axis("off")

plt.tight_layout()
plt.show()