In [1]:
from tensorflow import keras
import tensorflow_datasets.public_api as tfds
import tensorflow as tf

Загружаем датасет с цветами (папка `flowers`) и делим его на обучающую и валидационную выборки

In [2]:
# Constants
img_side = 192  # images are resized to img_side x img_side pixels


def _load_flowers_split(subset, batch_size, data_dir, validation_split, seed):
    """Load one subset ("training" or "validation") of the image directory.

    Both subsets are loaded through this single helper so they always share the
    same ``validation_split``, ``seed`` and ``shuffle`` arguments; Keras splits
    the (shuffled) file list with those, and mismatched arguments could make
    the training and validation subsets overlap.
    """
    return tf.keras.utils.image_dataset_from_directory(
        data_dir,
        validation_split=validation_split,
        subset=subset,
        image_size=(img_side, img_side),
        seed=seed,
        shuffle=True,
        batch_size=batch_size,
    )


def _as_autoencoder_pairs(images, _labels):
    """Drop the class label and use the scaled image as its own target.

    Pixel values are divided by 255 so inputs and targets share the same
    scale (compatible with the model's sigmoid output / binary cross-entropy).
    """
    scaled = images / 255.0
    return scaled, scaled


def init_data_with_batch_size(batch_size, data_dir="flowers", validation_split=0.1, seed=5432):
    """Build the (training, validation) datasets for the autoencoder.

    Args:
        batch_size: number of images per batch in both datasets.
        data_dir: directory with one sub-folder per class (labels are discarded).
        validation_split: fraction of the files reserved for validation.
        seed: seed shared by both subsets so the split is consistent.

    Returns:
        A ``(train_data, val_data)`` tuple of ``tf.data.Dataset`` objects whose
        elements are ``(image_batch, image_batch)`` pairs scaled by 1/255.
        Nothing is written to module-level globals; callers use the return value.
    """
    train_data = _load_flowers_split("training", batch_size, data_dir, validation_split, seed)
    val_data = _load_flowers_split("validation", batch_size, data_dir, validation_split, seed)

    # Labels are irrelevant: the task is reconstruction, not classification.
    return train_data.map(_as_autoencoder_pairs), val_data.map(_as_autoencoder_pairs)

Модель

In [3]:
from keras.layers import (Input, Dense, Conv2D, Conv2DTranspose, BatchNormalization, UpSampling2D,
                          Dropout, Flatten, Reshape, Lambda, MaxPool2D, Concatenate)
from keras.models import Model
from keras.optimizers import RMSprop, Adam
from keras import backend as K

# Hyperparameters shared by the encoder and the decoder.
filters = 32        # channels in every convolution
hidden_units = 128  # width of the tanh Dense bottleneck in the decoder
encoder_stages = 6  # each stage halves the spatial size (192 -> 3 for img_side=192)

# Encoder: repeated stages of 5x5 conv (ReLU) -> batch norm -> 2x2 max-pool.
input_img = Input(shape=(img_side, img_side, 3))
x = input_img

for _ in range(encoder_stages):
    x = Conv2D(filters, (5, 5), activation="relu", padding="same")(x)
    x = BatchNormalization()(x)
    x = MaxPool2D((2, 2), padding="same")(x)

# Tried and abandoned: flattening the encoder output into a dense code, i.e.
#   x = Flatten()(x)
#   encoder_output = Dense(hidden_units)(x)

# Decoder building block
def decode_layer(x, upsampling=(2, 2)):
    """Append one decoder stage to the Keras tensor ``x`` and return the result.

    The stage is two (3x3 conv with ``filters`` channels, ReLU) -> batch-norm
    pairs, then ``UpSampling2D(upsampling)`` followed by one more batch norm.

    Args:
        x: input Keras tensor.
        upsampling: upsampling factor passed to ``UpSampling2D``.

    Returns:
        The output tensor of the stage.
    """
    for _ in range(2):
        x = Conv2D(filters, (3, 3), activation="relu", padding="same")(x)
        x = BatchNormalization()(x)

    upsampled = UpSampling2D(upsampling)(x)
    return BatchNormalization()(upsampled)

# Abandoned variant (needs at least a 4x4 map): start the decoder from
#   x = Reshape((4, 4, hidden_units // 16))(encoder_output)

# Two conv/upsampling stages on top of the encoder output.
x = decode_layer(x)
x = decode_layer(x)

# Dense bottleneck of `hidden_units` tanh units.
bottleneck = Flatten()(x)
bottleneck = Dense(hidden_units, activation="tanh")(bottleneck)
x = BatchNormalization()(bottleneck)

# A wide Dense layer regenerates a single-channel map at half resolution
# (a quarter of the pixels). It costs many parameters, but per the author it
# keeps the reconstruction from looking blurry.
quarter_area = img_side**2 // 4
half_side = img_side // 2
x = Dense(quarter_area, activation="tanh")(x)
x = Reshape((half_side, half_side, 1))(x)

# One last decode stage brings it back to full resolution; a wide 25x25
# convolution then maps to 3 channels with a sigmoid (values in [0, 1]).
x = decode_layer(x)
x = Conv2D(3, (25, 25), activation="sigmoid", padding="same")(x)
# Shape assertion: a no-op when img_side is even.
output_img = Reshape((img_side, img_side, 3))(x)

# NOTE: despite its name this is a plain deterministic autoencoder
# (no sampling step, no KL term) trained only on reconstruction loss.
vae = Model(input_img, output_img, name="vae")

vae.compile(
    optimizer=keras.optimizers.Adamax(),
    loss=keras.losses.binary_crossentropy,
    # Multiplies the loss (and hence its gradients) by 1000.
    loss_weights=[1000],
)

vae.summary()

Model: "vae"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 192, 192, 3)]     0         
                                                                 
 conv2d (Conv2D)             (None, 192, 192, 32)      2432      
                                                                 
 batch_normalization (BatchN  (None, 192, 192, 32)     128       
 ormalization)                                                   
                                                                 
 max_pooling2d (MaxPooling2D  (None, 96, 96, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 96, 96, 32)        25632     
                                                                 
 batch_normalization_1 (Batc  (None, 96, 96, 32)       128     

In [None]:
# Training: small batches, 30 epochs; each dataset element pairs an image
# with itself as the reconstruction target.
train_batch_size = 4
train_epochs = 30

train_data, val_data = init_data_with_batch_size(train_batch_size)
vae.fit(
    x=train_data,
    validation_data=val_data,
    epochs=train_epochs,
)

Found 1671 files belonging to 5 classes.
Using 1504 files for training.
Found 1671 files belonging to 5 classes.
Using 167 files for validation.
Epoch 1/30


In [None]:
# from keras.models import save_model, load_model
# vae.load_weights("123")

import matplotlib.pyplot as plt
from random import randint
import numpy as np

# Reload the validation subset with batch_size=1: per the author, larger
# batches here made an out-of-memory (OOM) error very likely.
_, val_data = init_data_with_batch_size(1)

# Flatten the batched dataset into a list of individual input images.
# (One list, one name: no aliasing with the predictions below.)
validation_img = [image for batch, _target in val_data for image in batch.numpy()]

# Reconstructions, index-aligned with validation_img.
generated_images = vae.predict(np.array(validation_img))
num_images = 5  # how many original/reconstruction pairs print_images() shows

plt.figure(figsize=(20, 10))

def print_images():
    """Show ``num_images`` random validation images above their reconstructions.

    Reads the notebook globals ``validation_img`` (inputs), ``generated_images``
    (model outputs, index-aligned with ``validation_img``) and ``num_images``,
    and draws into the current matplotlib figure as a 2 x ``num_images`` grid:
    originals on the top row, reconstructions on the bottom row.

    Raises:
        ValueError: if there are no validation images to sample from.
    """
    if len(validation_img) == 0:
        raise ValueError("no validation images to display")

    for i in range(num_images):
        # randint includes both bounds, so the last valid index is len - 1.
        random_num = randint(0, len(validation_img) - 1)

        # Original image (top row)
        plt.subplot(2, num_images, i + 1)
        plt.imshow(validation_img[random_num])
        plt.title("Original")
        plt.axis("off")

        # Reconstructed image (bottom row)
        plt.subplot(2, num_images, i + num_images + 1)
        plt.imshow(generated_images[random_num])
        plt.title("Generated")
        plt.axis("off")
    plt.tight_layout()
    plt.show()

print_images()