In [None]:
from tensorflow import keras
import tensorflow_datasets as tfds
import tensorflow as tf

In [None]:
# Constants
# Side length in pixels: images are loaded at img_side x img_side, and the
# model's input and output tensors use the same size.
img_side = 128

# Dataset loader. No train/validation split is made here: every image found
# under "flowers" ends up in one dataset.
def init_data_with_batch_size(batch_size):
    """Build the autoencoder dataset from the "flowers" directory.

    Images are loaded at img_side x img_side, shuffled, and grouped into
    batches of `batch_size`. The class labels are discarded because the model
    reconstructs its input: every element is an (inputs, targets) pair in
    which both entries are the same image batch divided by 255.

    The result is only returned. It is deliberately NOT stored in a global,
    so building a second dataset (e.g. with batch size 1 for previews) cannot
    replace the one a caller is training on.

    Args:
        batch_size: number of images per batch.

    Returns:
        The mapped dataset yielding (scaled_images, scaled_images) tuples.
    """
    def _as_autoencoder_pair(images, _labels):
        # Scale once and reuse the tensor as both input and target.
        scaled = images / 255.0
        return scaled, scaled

    labelled_batches = tf.keras.utils.image_dataset_from_directory(
        "flowers",
        image_size=(img_side, img_side),
        shuffle=True,
        batch_size=batch_size,
    )

    return labelled_batches.map(_as_autoencoder_pair)

Модель

In [11]:
from keras.layers import (Input, Dense, Conv2D, Conv2DTranspose, BatchNormalization, UpSampling2D,
                          Dropout, Flatten, Reshape, Lambda, MaxPool2D, Concatenate, add, Cropping2D)
from keras.models import Model
from keras.optimizers import RMSprop, Adam
from math import log2

# Hyperparameters
filters = 64        # channel count of every encoder convolution
hidden_units = 64   # length of the latent ("meaning") vector
core_size = (3, 3)  # kernel size shared by all convolutions

input_img = Input(shape=(img_side, img_side, 3))

# --- Encoder ---
# Stem: lift the 3-channel image to `filters` channels so that the first
# residual add([x_temp, x]) below receives two tensors of the same shape.
x_temp = BatchNormalization()(input_img)
x_temp = Conv2D(filters, core_size, activation="relu", padding="same")(x_temp)

# Five residual stages. Each one normalises, applies two convolutions, adds
# the result back onto its input and then halves the resolution with a 2x2
# max-pool (128 -> 4 pixels per side for img_side = 128).
for _stage in range(5):
    x = BatchNormalization()(x_temp)
    for _ in range(2):
        x = Conv2D(filters, core_size, activation="relu", padding="same")(x)

    x_temp = MaxPool2D((2, 2))(add([x_temp, x]))


# Compress the pooled feature map into the latent vector, then lay that vector
# out as a square single-channel map for the decoder. latent_side**2 has to
# equal hidden_units, i.e. hidden_units must be a perfect square.
latent_side = int(hidden_units**.5)

x = BatchNormalization()(x_temp)
x = Flatten()(x)
x = Dense(hidden_units, activation="relu")(x)
x_temp = Reshape((latent_side, latent_side, 1))(x)


"""Декодер"""
# Number of base_decode_layer() calls since the last reset. Call number n
# (counting from 0) uses filters // 2**n channels; assign 0 to start a new
# narrowing sequence.
num = 0


def base_decode_layer(x_input):
    """Decoder block whose width halves with every call.

    Applies batch normalisation followed by three same-padded, stride-1
    transposed convolutions, so the spatial size is unchanged. All three use
    filters // 2**num channels, where `num` is the module-level call counter.
    The counter is advanced by one before returning, so consecutive calls
    produce filters, filters/2, filters/4, ... channels.

    `num` must stay at or below log2(filters); beyond that the integer
    division yields 0 channels.

    Args:
        x_input: feature-map tensor to decode.

    Returns:
        The output tensor of the third transposed convolution.
    """
    global num
    channels = filters // 2**num

    x = BatchNormalization()(x_input)
    for _ in range(3):
        x = Conv2DTranspose(channels, core_size, activation="relu", padding="same")(x)

    # Advance the counter so the next block is narrower. This used to be
    # `num += 0`, a no-op that kept every block at the full `filters` width.
    num += 1
    return x

def decoding_layer(x_temp):
    """Residual decode step followed by a stride-2 upsampling convolution.

    The input goes through base_decode_layer(), is summed with that result
    (skip connection), and the sum is enlarged by a same-padded, stride-2
    Conv2DTranspose with filters // 2**num channels.

    Args:
        x_temp: feature-map tensor to decode.

    Returns:
        The output tensor of the stride-2 transposed convolution.
    """
    decoded = base_decode_layer(x_temp)
    skip_sum = add([x_temp, decoded])

    # `num` is read only here, after base_decode_layer() has run, so the
    # upsampling layer sees whatever value that call left behind.
    upsample = Conv2DTranspose(filters//2**num, core_size, activation="relu", padding="same", strides=2)
    return upsample(skip_sum)

# Decoder assembly.
#
# Only the layers that actually feed `output_img` are built here. Branches
# whose result was overwritten before anything consumed it (a Dense branch on
# the raw latent map, and extra upsampling of `x_temp` past the point where
# the Dense branch taps it) never reached the output, so dropping them leaves
# the resulting model's topology unchanged.
#
# Spatial sizes in the comments assume `x_temp` enters as an 8x8 map
# (hidden_units = 64) and img_side = 128. Channel counts are whatever
# filters // 2**num evaluates to at each step.

# Stage 1: residual decode plus two stride-2 transposed convolutions
# (one inside decoding_layer, one here): 8x8 -> 16x16 -> 32x32.
x_temp = decoding_layer(x_temp)
x_temp = Conv2DTranspose(filters//2**num, core_size, activation="relu", padding="same", strides=2)(x_temp)

# Stage 2: Dense "sharpening" branch. Collapse to one channel, mix every
# pixel with every other pixel through a Dense layer of shape**2 units, and
# fold the result back into a single-channel square map.
x = Conv2D(1, core_size, activation="relu", padding="same")(x_temp)
x = BatchNormalization()(x)
shape = x.shape[1]
x = Flatten()(x)
x = Dense(shape**2, activation="relu")(x)
x = Reshape((shape, shape, 1))(x)
x = BatchNormalization()(x)

# Stage 3: two residual decode + upsample steps: 32x32 -> 64x64 -> 128x128.
# No Dense layer at these resolutions; it would need far too many parameters.
for _ in range(2):
    x = decoding_layer(x)

# Stage 4: refine at full resolution without changing the spatial size.
# Restart the width counter, then apply log2(filters) + 1 base blocks, one
# per halving step from `filters` down to 1. Whether the width really shrinks
# depends on base_decode_layer() advancing `num`.
num = 0
for _ in range(int(log2(filters)) + 1):
    x = base_decode_layer(x)


# Map the features to three colour channels in [0, 1].
x = Conv2D(3, core_size, activation="sigmoid", padding="same")(x)
# A same-size Reshape: fails at graph-construction time if the decoder did not
# land exactly on img_side x img_side x 3.
output_img = Reshape((img_side, img_side, 3))(x)

# Wrap the encoder/decoder graph into a trainable model. The graph built in
# this cell is a plain deterministic autoencoder trained on per-pixel binary
# cross-entropy; the name "vae" is kept because later cells refer to it.
vae = Model(inputs=input_img, outputs=output_img, name="vae")

vae.compile(
    loss="binary_crossentropy",
    loss_weights=[1000],  # scales the reported/optimised loss by 1000
    optimizer=keras.optimizers.Adam(3e-3),
)

vae.summary()

Model: "vae"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_6 (InputLayer)           [(None, 128, 128, 3  0           []                               
                                )]                                                                
                                                                                                  
 batch_normalization_104 (Batch  (None, 128, 128, 3)  12         ['input_6[0][0]']                
 Normalization)                                                                                   
                                                                                                  
 conv2d_69 (Conv2D)             (None, 128, 128, 64  1792        ['batch_normalization_104[0][0]']
                                )                                                               

In [None]:
# Train for 50 epochs on batches of 4 images. No separate targets are passed:
# the dataset already yields (input, target) pairs.
train_data = init_data_with_batch_size(batch_size=4)
vae.fit(x=train_data, epochs=50)

Found 1671 files belonging to 5 classes.
Epoch 1/50
Epoch 2/50

In [None]:
import matplotlib.pyplot as plt
from random import randint
import numpy as np

# Visual check: show a few originals above their reconstructions.
preview_pool_size = 16  # how many images to pull from the dataset
num_images = 5          # how many (original, generated) columns to draw

# take() stops after `preview_pool_size` batches. Filtering with an `if`
# inside a comprehension would keep iterating over (and decoding) every
# remaining image in the dataset.
preview_batches = init_data_with_batch_size(1).take(preview_pool_size)
# Each element is (inputs, targets) with a batch dimension of 1; keep the
# single input image of every batch.
data = np.stack([inputs[0].numpy() for inputs, _targets in preview_batches])
generated_images = vae.predict(data, verbose=0)

fig, axes = plt.subplots(2, num_images, figsize=(20, 10), squeeze=False)

for column in range(num_images):
    # Bound the index by what was actually loaded, so a dataset smaller than
    # the requested pool cannot cause an IndexError.
    random_num = randint(0, len(data) - 1)

    panels = (
        (axes[0, column], data[random_num], "Original"),
        (axes[1, column], generated_images[random_num], "Generated"),
    )
    for ax, image, title in panels:
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")

fig.tight_layout()
plt.show()

In [None]:
"""Выводим Архитектуру"""
# Render the layer graph to a PNG, showing tensor shapes but not layer names.
img_file = "architecture.png"
tf.keras.utils.plot_model(
    model=vae,
    to_file=img_file,
    show_shapes=True,
    show_layer_names=False,
)