In [None]:
from tensorflow import keras
import tensorflow_datasets.public_api as tfds
import tensorflow as tf
import pathlib

Загружаем датасет с фотографиями цветов и разбиваем его на обучающую и валидационную выборки

In [None]:
# Source archive of the flower_photos dataset. `get_file` downloads `origin`
# as a URL, so a bare local path such as "C:\\...\\flowers.zip" is rejected as
# an unknown URL scheme. To use a local copy, pass its file URI instead, e.g.
# pathlib.Path(r"C:\path\to\flower_photos.tgz").as_uri().
# The archive must be a tar (.tgz): `untar=True` does not unpack .zip files.
flowers_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"

data_dir = tf.keras.utils.get_file(origin=flowers_url, fname="flower_photos", untar=True)
data_dir = pathlib.Path(data_dir)

# Constants
batch_size = 20
img_height = img_width = 128
validation_split = 0.1
# Both subsets must use the same split seed, otherwise the training and
# validation subsets can overlap.
split_seed = 5432


def load_subset(subset):
    """Load one subset ("training" or "validation") of the images in `data_dir`.

    Both subsets are built with the same `validation_split` and `split_seed`,
    which is what keeps them disjoint.
    """
    return tf.keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=validation_split,
        subset=subset,
        image_size=(img_height, img_width),
        batch_size=batch_size,
        seed=split_seed,
    )


def to_autoencoder_pair(images, labels):
    """Scale pixels to [0, 1] and use the image batch as its own target.

    The class labels are dropped: the task is to reconstruct the input,
    not to classify it.
    """
    images = images / 255.0
    return images, images


# Split into training and validation subsets, then normalise.
train_data = load_subset("training").map(to_autoencoder_pair)
validation_data = load_subset("validation").map(to_autoencoder_pair)

Модель: свёрточный автоэнкодер

In [None]:
from keras.layers import (Input, Dense, Conv2D, Conv2DTranspose, BatchNormalization, UpSampling2D,
                          Dropout, Flatten, Reshape, Lambda, MaxPool2D, Concatenate)
from keras.models import Model
from keras.losses import binary_crossentropy
from keras.layers import LeakyReLU
from keras.optimizers import RMSprop, Adam
from keras import backend as K

# Constants
filters = 64
hidden_units = 1024
dropout = 0.0
# Number of conv blocks in the encoder (each halves H and W via max pooling)
# and in the decoder (each doubles H and W via upsampling).
num_blocks = 5

# Spatial size of the encoder's last feature map (128 / 2**5 = 4).
# The decoder must start from this same grid: num_blocks upsamplings then
# restore exactly (img_height, img_width). Starting from a larger grid,
# e.g. 32x32, would overshoot to 1024x1024 and break the final Reshape.
latent_height = img_height // 2 ** num_blocks
latent_width = img_width // 2 ** num_blocks
assert latent_height * 2 ** num_blocks == img_height and latent_width * 2 ** num_blocks == img_width, \
    "img_height and img_width must be divisible by 2 ** num_blocks"
assert hidden_units % (latent_height * latent_width) == 0, \
    "hidden_units must be divisible by latent_height * latent_width"
latent_channels = hidden_units // (latent_height * latent_width)


def conv_block(x):
    """Apply three Conv2D(3x3, ReLU) -> BatchNormalization -> Dropout stages.

    padding="same" keeps height and width; the output has `filters` channels.
    """
    for _ in range(3):
        x = Conv2D(filters, (3, 3), activation="relu", padding="same", kernel_initializer="he_normal")(x)
        x = Dropout(dropout)(BatchNormalization()(x))
    return x


def encode_layer(x):
    """Encoder stage: conv block followed by 2x2 max pooling (halves H and W)."""
    return MaxPool2D((2, 2), padding="same")(conv_block(x))


def decode_layer(x):
    """Decoder stage: conv block followed by 2x upsampling (doubles H and W)."""
    return UpSampling2D()(conv_block(x))


# Encoder
input_img = Input(shape=(img_height, img_width, 3))
x = input_img
for _ in range(num_blocks):
    x = encode_layer(x)
x = Flatten()(x)
encoder_output = Dense(hidden_units)(x)

# Decoder: mirror of the encoder, starting from the bottleneck grid.
x = Reshape((latent_height, latent_width, latent_channels))(encoder_output)
for _ in range(num_blocks):
    x = decode_layer(x)

x = Conv2D(3, (11, 11), activation="sigmoid", padding="same", kernel_initializer="he_normal")(x)
# Reshape is a no-op on the shape here; it asserts the output is (H, W, 3).
output_img = Reshape((img_height, img_width, 3))(x)

# Model. NOTE: despite the name, this is a plain deterministic autoencoder
# (no sampling layer, no KL term). The name is kept because later cells use `vae`.
vae = Model(input_img, output_img, name="vae")

vae.compile(
    optimizer=Adam(3e-3),
    loss="binary_crossentropy",
)

vae.summary()

In [None]:
epochs = 20

# Keep the History object so the per-epoch loss / val_loss values can be
# inspected or plotted later. Otherwise they are only reachable via Out[N],
# which does not survive a kernel restart.
history = vae.fit(
    train_data,
    epochs=epochs,
    validation_data=validation_data,
)

In [None]:
import matplotlib.pyplot as plt
from random import randint
import numpy as np

# Reconstruct validation images with the autoencoder and show them next to the originals.

# Collect all validation images into one list. The second element of each
# pair is the target, which for this autoencoder is the same image, so it is ignored.
validation_img = [image for batch, _ in validation_data for image in batch.numpy()]
if not validation_img:
    raise ValueError("validation_data is empty: nothing to reconstruct")

generated_images = vae.predict(np.array(validation_img))
num_images = 10

# Pick distinct random indices in [0, len). randint(0, len) is inclusive at both
# ends and could return len (IndexError) or repeat an image. Never request
# more images than the validation set contains.
num_shown = min(num_images, len(validation_img))
shown_indices = np.random.default_rng().choice(len(validation_img), size=num_shown, replace=False)

# Top row: originals; bottom row: reconstructions of the same images.
fig, axes = plt.subplots(2, num_shown, figsize=(18, 5), squeeze=False)
for column, index in enumerate(shown_indices):
    panels = (
        (axes[0, column], validation_img[index], "Original"),
        (axes[1, column], generated_images[index], "Generated"),
    )
    for ax, image, title in panels:
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")

fig.tight_layout()
plt.show()