In [1]:
import math
import pathlib
from collections.abc import Iterable

import tensorflow as tf
import tensorflow_datasets.public_api as tfds
from tensorflow import keras

Загружаем датасет с фотографиями цветов и делим его на обучающую и валидационную выборки

In [8]:
# Data source: the TensorFlow "flower_photos" example archive (a .tgz with one
# sub-folder of JPEGs per class). `get_file` expects a URL for `origin`; change
# FLOWERS_URL to point at a mirror if needed rather than hard-coding an absolute
# local path, which does not exist on other machines.
FLOWERS_URL = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file(origin=FLOWERS_URL,
                                   fname="flower_photos",
                                   untar=True)
data_dir = pathlib.Path(data_dir)

# Constants
batch_size = 20
img_height = img_width = 192
validation_split = 0.2
split_seed = 5432  # must be identical for both subsets so they do not overlap

# Split the dataset into a training subset and a validation subset using the
# same settings for both.
_split_kwargs = dict(
    validation_split=validation_split,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    seed=split_seed,
)
train_data = tf.keras.utils.image_dataset_from_directory(
    data_dir, subset="training", **_split_kwargs)
validation_data = tf.keras.utils.image_dataset_from_directory(
    data_dir, subset="validation", **_split_kwargs)


def to_reconstruction_pair(images, _labels):
    """Scale pixel values by 1/255 and use the images as both input and target.

    The class labels are dropped: the autoencoder is trained to reproduce its
    own input, so the target is the (scaled) image itself.
    """
    scaled = images / 255.0
    return scaled, scaled


train_data = train_data.map(to_reconstruction_pair)
validation_data = validation_data.map(to_reconstruction_pair)

Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.


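Модель: свёрточный автоэнкодер (энкодер сжимает изображение в латентный вектор размерности `latent_dim`, декодер восстанавливает из него изображение)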

In [23]:
from keras.layers import Input, Dense, Conv2D, Conv2DTranspose, BatchNormalization, Dropout, Flatten, Reshape, Lambda, MaxPool2D
from keras.models import Model
from keras.losses import binary_crossentropy
from keras.layers import LeakyReLU
from keras.optimizers import RMSprop, Adam
from keras import backend as K

def multiply(arr: Iterable[int]) -> int:
    """Return the product of all integers in ``arr`` (1 for an empty iterable).

    Used to turn a shape such as ``(3, 3, 64)`` into the number of units a
    Dense layer must produce before a ``Reshape`` to that shape. Accepts any
    iterable of ints (tuple, list, generator, ...).
    """
    return math.prod(arr)

latent_dim = 64
n_encoder_stages = 8

# Encoder: `n_encoder_stages` identical Conv -> BatchNorm -> MaxPool stages.
# With padding="same" each pool halves the spatial side, rounding up, so a
# 192x192 input shrinks 192 -> 96 -> 48 -> 24 -> 12 -> 6 -> 3 -> 2 -> 1.
input_img = Input(shape=(img_height, img_width, 3))

x = input_img
for _ in range(n_encoder_stages):
    x = Conv2D(32, (3, 3), activation="relu", padding="same")(x)
    x = BatchNormalization()(x)
    x = MaxPool2D((2, 2), padding="same")(x)

x = Flatten()(x)

x = Dense(latent_dim, activation="relu")(x)
x = BatchNormalization()(x)

# Latent code: a single deterministic vector (no mean/log-variance split and no
# sampling step).
middle_layer = Dense(latent_dim)(x)
encoder = Model(input_img, middle_layer, "encoder")

# Decoder: maps a latent vector back to an (img_height, img_width, 3) image.
z = Input(shape=(latent_dim, ))
dec_input_shape = (3, 3, 64)
n_upsampling_stages = 6  # each stride-2 Conv2DTranspose doubles the side: 3 * 2**6 = 192

# The grid must grow back to exactly the input size, otherwise the
# reconstruction cannot be compared pixel-by-pixel with the input.
assert dec_input_shape[0] * 2 ** n_upsampling_stages == img_height
assert dec_input_shape[1] * 2 ** n_upsampling_stages == img_width

x = z
for _ in range(2):
    # Keep the Dense linear so LeakyReLU actually sees negative values; a ReLU
    # activation here would make the following LeakyReLU a no-op.
    x = Dense(multiply(dec_input_shape))(x)
    x = LeakyReLU()(x)
    x = BatchNormalization()(x)

x = Reshape(dec_input_shape)(x)
for _ in range(n_upsampling_stages):
    x = Conv2DTranspose(128, (3, 3), activation="relu", strides=2, padding="same")(x)
    x = BatchNormalization()(x)

# Sigmoid keeps outputs in [0, 1], the same range as the input scaled by 1/255.
decoded = Conv2DTranspose(3, (3, 3), activation="sigmoid", padding="same")(x)

decoder = Model(z, decoded, name="decoder")

# Full model: the encoder's latent code is fed straight into the decoder.
# NOTE(review): despite the name "vae", this is a deterministic autoencoder —
# `middle_layer` is a single Dense(latent_dim) output with no mean/log-variance
# split and no sampling step.
vae = Model(input_img, decoder(middle_layer), name="vae")

def vae_loss(x, decoded, z_mean=None, z_log_var=None):
    """VAE-style loss: summed binary cross-entropy plus an optional KL term.

    Not used by the ``vae.compile`` call in this notebook, which trains with
    ``loss="mae"``.

    Args:
        x: batch of target images, assumed (batch, img_height, img_width, 3)
            with values in [0, 1] — TODO confirm at the call site.
        decoded: reconstructions with the same shape as ``x``.
        z_mean: encoder mean tensor, (batch, latent_dim). Optional.
        z_log_var: encoder log-variance tensor, (batch, latent_dim). Optional.
            The KL divergence to N(0, I) is added only when both ``z_mean`` and
            ``z_log_var`` are given. The current encoder produces neither, so
            by default only the reconstruction term is used.

    Returns:
        Per-example loss tensor equal to
        (reconstruction_sum + kl) / (2 * img_height * img_width * 3).
    """
    n_values = img_height * img_width * 3

    # Flatten with -1 rather than the fixed batch_size: the last batch of an
    # epoch can be smaller (2936 training images are not a multiple of 20).
    x = K.reshape(x, shape=(-1, n_values))
    decoded = K.reshape(decoded, shape=(-1, n_values))

    # binary_crossentropy averages over the last axis; rescale to a sum.
    xent_loss = n_values * binary_crossentropy(x, decoded)

    kl_loss = 0.0
    if z_mean is not None and z_log_var is not None:
        # KL(N(mu, sigma^2) || N(0, 1)), summed over the latent dimensions.
        kl_loss = -0.5 * K.sum(
            1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),
            axis=-1,
        )

    return (xent_loss + kl_loss) / (2 * n_values)

# Train with a plain mean-absolute-error reconstruction loss.
# NOTE(review): beta_2=0.7 is far below Adam's default of 0.999 (a much noisier
# second-moment estimate) — confirm this is intentional.
vae.compile(
    loss="mae",
    optimizer=Adam(learning_rate=1e-3, beta_1=0.8, beta_2=0.7),
)
vae.summary()

Model: "vae"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_35 (InputLayer)       [(None, 192, 192, 3)]     0         
                                                                 
 conv2d_136 (Conv2D)         (None, 192, 192, 32)      896       
                                                                 
 batch_normalization_281 (Ba  (None, 192, 192, 32)     128       
 tchNormalization)                                               
                                                                 
 max_pooling2d_136 (MaxPooli  (None, 96, 96, 32)       0         
 ng2D)                                                           
                                                                 
 conv2d_137 (Conv2D)         (None, 96, 96, 32)        9248      
                                                                 
 batch_normalization_282 (Ba  (None, 96, 96, 32)       128     

In [None]:
# The datasets yield (image, image) pairs, so this trains the model to
# reconstruct its input; the validation loss is reported after every epoch.
vae.fit(
    x=train_data,
    validation_data=validation_data,
    epochs=10,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10

In [None]:
import matplotlib.pyplot as plt

# Reconstruct a handful of validation images with the autoencoder and show each
# original above its reconstruction.
num_images = 10

# Collect just enough validation batches to cover `num_images` examples, and
# predict on exactly those tensors. image_dataset_from_directory shuffles by
# default, so predicting on `validation_data` and then iterating it a second
# time would pair originals with reconstructions of different images.
batches = []
collected = 0
for img_batch, _ in validation_data:
    batches.append(img_batch)
    collected += int(img_batch.shape[0])
    if collected >= num_images:
        break
if not batches:
    raise ValueError("validation_data yielded no images to display")

validation_img = tf.concat(batches, axis=0)[:num_images].numpy()
num_images = len(validation_img)  # fewer than requested if the set is small
generated_images = vae.predict(validation_img)

# Originals on the top row, reconstructions on the bottom row.
fig, axes = plt.subplots(2, num_images, figsize=(17, 4), squeeze=False)
for i in range(num_images):
    axes[0, i].imshow(validation_img[i])
    axes[0, i].set_title("Original")
    axes[0, i].axis("off")

    axes[1, i].imshow(generated_images[i])
    axes[1, i].set_title("Generated")
    axes[1, i].axis("off")

fig.tight_layout()
plt.show()