In [6]:
from tensorflow import keras
import tensorflow_datasets.public_api as tfds
import tensorflow as tf
import pathlib

Создаём датасет с цветочками

In [92]:
data_dir = tf.keras.utils.get_file(origin="C:\\Users\\samki\\Downloads\\flowers.zip",
                                   fname="flower_photos",
                                   untar=True)
data_dir = pathlib.Path(data_dir)

# Константы
batch_size = 20
img_height = img_width = 180

# Разбиваем датасет на тренировочную группу и группу валидации
train_data = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    image_size=(img_height, img_width),
    batch_size=batch_size,
    seed=5432,
)

validation_data = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    image_size=(img_height, img_width),
    batch_size=batch_size,
    seed=5432,
)

# Нормализуем в [0; 1], + убираем лейблы (т.к. у нас задача не распознавать изображения)
train_data = train_data.map(lambda x, y: (x / 255.0, x / 255.0))
validation_data = validation_data.map(lambda x, y: (x / 255.0, x / 255.0))

Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.


Модель

In [100]:
from keras.layers import (Input, Dense, Conv2D, Conv2DTranspose, BatchNormalization,
                          Dropout, Flatten, Reshape, Lambda, MaxPool2D, Concatenate)
from keras.models import Model
from keras.losses import binary_crossentropy
from keras.layers import LeakyReLU
from keras.optimizers import RMSprop, Adam
from keras import backend as K

# Константы
latent_dim = 20
dropout = 0.0

# Энкодер
input_img = Input(shape=(img_height, img_width, 3))

x = Conv2D(latent_dim*10, (3, 3), activation="relu")(input_img)
x = Dropout(dropout)(BatchNormalization()(x))
x = MaxPool2D((2, 2))(x)

x = Conv2D(latent_dim*8, (3, 3), activation="relu")(x)
x = Dropout(dropout)(BatchNormalization()(x))
x = MaxPool2D((2, 2))(x)

x = Conv2D(latent_dim*6, (3, 3), activation="relu")(x)
x = Dropout(dropout)(BatchNormalization()(x))
x = MaxPool2D((2, 2))(x)

x = Conv2D(latent_dim*4, (3, 3), activation="relu")(x)
x = Dropout(dropout)(BatchNormalization()(x))
x = MaxPool2D((2, 2))(x)

x = Conv2D(latent_dim*2, (3, 3), activation="relu")(x)
x = Dropout(dropout)(BatchNormalization()(x))
x = MaxPool2D((2, 2))(x)

encoder_output = Flatten()(x)

# # Специальные промежуточные слои
# middle_mean = Dense(latent_dim, activation="sigmoid")(x)
# middle_log_var = Dense(latent_dim, activation="sigmoid")(x)
#
# # Сэмплирование из Q с трюком репараметризации (что бы это не значило)
# @tf.function
# def sampling(args):
#     middle_mean, middle_log_var = args
#     epsilon = tf.random.normal(shape=(latent_dim, ))  # Генерим нормальный шум
#     return middle_mean + tf.exp(middle_log_var / 2) * epsilon   # Что-то возвращаем
#
# # l ═ новый кастомный слой
# l = Lambda(sampling, output_shape=(latent_dim,))([middle_mean, middle_log_var])


# Декодер
decoder_input = Input(shape=(latent_dim, ))

# Делаем одновременно этот слой и входом и продолжением энкодера
# (т.к. как слой в который можно подать значения)
# vector_from_encoder = Concatenate()([decoder_input, encoder_output])

x = Dense(latent_dim, activation="relu")(encoder_output)
x = Dropout(dropout)(BatchNormalization()(x))
x = Dense(latent_dim*2, activation="relu")(x)
x = Dropout(dropout)(BatchNormalization()(x))
x = Dense(latent_dim*4, activation="relu")(x)
x = Dropout(dropout)(BatchNormalization()(x))
x = Dense(latent_dim*6, activation="relu")(x)
x = Dropout(dropout)(BatchNormalization()(x))
x = Dense(latent_dim*8, activation="relu")(x)
x = Dropout(dropout)(BatchNormalization()(x))

x = Dense(latent_dim, activation="relu")(x)
x = Dropout(dropout)(BatchNormalization()(x))

x = Dense(img_width * img_height * 3, activation="sigmoid")(x)
output_img = Reshape((img_height, img_width, 3))(x)

# Модели
encoder = Model(input_img, encoder_output, "encoder")
decoder = Model(decoder_input, output_img, "decoder")
vae = Model(input_img, output_img, name="vae")

# Создаём свою функцию ошибки (спиздил с инета)
@tf.function
def vae_loss(answer, decoded, middle_mean, middle_log_var):
    # Что-то делаем
    answer = K.reshape(answer, shape=(batch_size, img_width * img_height * 3))
    decoded = K.reshape(decoded, shape=(batch_size, img_width * img_height * 3))

    # Мутим какую-то муть
    xent_loss = img_width * img_height * 3 * binary_crossentropy(answer, decoded)
    kl_loss = -0.5 * K.sum(1 + middle_log_var - K.square(middle_mean) - K.exp(middle_log_var), axis=-1)

    # Что-то там вычисляем
    loss = (xent_loss + kl_loss) / (2 * img_width * img_height * 3)

    # Возвращаем скаляр
    return K.mean(loss)

# vae.add_loss(vae_loss(input_img, output_img, middle_mean, middle_log_var))
vae.compile(optimizer="adam", loss="binary_crossentropy")

vae.summary()
vae.fit(
    train_data,
    epochs=20,
    validation_data=validation_data,
)

Model: "vae"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_170 (InputLayer)      [(None, 180, 180, 3)]     0         
                                                                 
 conv2d_280 (Conv2D)         (None, 178, 178, 200)     5600      
                                                                 
 batch_normalization_604 (Ba  (None, 178, 178, 200)    800       
 tchNormalization)                                               
                                                                 
 dropout_604 (Dropout)       (None, 178, 178, 200)     0         
                                                                 
 max_pooling2d_280 (MaxPooli  (None, 89, 89, 200)      0         
 ng2D)                                                           
                                                                 
 conv2d_281 (Conv2D)         (None, 87, 87, 160)       288160  

KeyboardInterrupt: 

In [94]:
import matplotlib.pyplot as plt
from random import randint

# Генерация изображений с помощью автоэнкодера
generated_images = vae.predict(validation_data)
num_images = 10

# Конвертируем в список
validation_img = generated_img = []
for img, _ in validation_data:
    validation_img.extend(img.numpy())

# Отображение сгенерированных изображений
plt.figure(figsize=(17, 4))
for i in range(num_images):
    random_num = randint(0, len(validation_img))
    # Исходное изображение
    plt.subplot(2, num_images, i + 1)
    plt.imshow(validation_img[random_num])
    plt.title("Original")
    plt.axis("off")

    # Сгенерированное изображение
    plt.subplot(2, num_images, i + num_images + 1)
    plt.imshow(generated_images[random_num])
    plt.title("Generated")
    plt.axis("off")

plt.tight_layout()
plt.show()

ValueError: in user code:

    File "C:\ProgramData\miniconda3\lib\site-packages\keras\engine\training.py", line 2041, in predict_function  *
        return step_function(self, iterator)
    File "C:\ProgramData\miniconda3\lib\site-packages\keras\engine\training.py", line 2027, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\ProgramData\miniconda3\lib\site-packages\keras\engine\training.py", line 2015, in run_step  **
        outputs = model.predict_step(data)
    File "C:\ProgramData\miniconda3\lib\site-packages\keras\engine\training.py", line 1983, in predict_step
        return self(x, training=False)
    File "C:\ProgramData\miniconda3\lib\site-packages\keras\utils\traceback_utils.py", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "C:\ProgramData\miniconda3\lib\site-packages\keras\engine\input_spec.py", line 216, in assert_input_compatibility
        raise ValueError(

    ValueError: Layer "vae" expects 2 input(s), but it received 1 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, 180, 180, 3) dtype=float32>]
