## Niestandardowa pętla uczenia modelu w Keras i TensorFlow

Poprzedni przykład pokazał, jak zbudować model GAN. Aby jednak uzyskać najwyższą wydajność i skorzystać z kompilacji modelu za pomocą `XLA`, musimy zaimplementować niestandardową pętlę uczenia.

In [None]:
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf
from keras import layers, models, optimizers, losses, datasets, callbacks
import matplotlib.pyplot as plt
from typing import Tuple, Sequence


class FeatureMatchingLoss(losses.Loss):
    """Feature-matching loss over multi-output discriminator results.

    Both arguments of ``call`` are indexable collections of tensors. Every
    entry except the last one of ``y_pred`` is compared; the last entry is
    the discriminator score and is skipped.
    """

    def call(self, y_true, y_pred):
        """Return the summed mean absolute difference of paired entries.

        Args:
            y_true: Indexable collection of reference tensors.
            y_pred: Indexable collection of tensors with the same layout;
                its length decides how many pairs are compared.

        Returns:
            ``0.0`` plus one ``tf.reduce_mean(tf.abs(...))`` term for each
            compared pair.
        """
        # The last output is the score, so only the entries before it
        # contribute to the loss.
        feature_count = len(y_pred) - 1
        per_feature = [
            tf.reduce_mean(tf.abs(y_true[idx] - y_pred[idx]))
            for idx in range(feature_count)
        ]
        return sum(per_feature, 0.0)


# Tworzenie generatora
def get_generator(noise_size: int = 64, classes: int = 10) -> models.Model:
    """Build the conditional generator as a functional Keras model.

    Args:
        noise_size: Length of the noise vector; also used as the output
            width of the category embedding so the two can be added.
        classes: Number of entries in the category embedding table.

    Returns:
        A model named ``"generator"`` that takes ``[noise, category]`` and
        ends in a 3-filter convolution with a sigmoid activation.
    """
    noise_in = layers.Input(shape=[noise_size], dtype=tf.float32, name="noise")
    category_in = layers.Input(shape=[1], dtype=tf.int32, name="category")

    # Category embedding, merged into the noise by element-wise addition.
    category_embedding = layers.Embedding(classes, noise_size)(category_in)
    net = layers.Add()([noise_in, category_embedding])

    net = layers.Dense(512)(net)
    net = layers.LeakyReLU()(net)
    # -1 lets Keras infer the channel dimension of the 4x4 grid.
    net = layers.Reshape([4, 4, -1])(net)

    # Three conv -> group norm -> activation -> upsample stages.
    for width in (512, 256, 128):
        net = layers.Conv2D(width, 3, padding="same")(net)
        net = layers.GroupNormalization(groups=-1)(net)
        net = layers.LeakyReLU()(net)
        net = layers.UpSampling2D()(net)

    net = layers.Conv2D(64, 3, padding="same")(net)
    net = layers.LeakyReLU()(net)
    image = layers.Conv2D(3, 3, padding="same", activation="sigmoid")(net)

    return models.Model(
        inputs=[noise_in, category_in], outputs=image, name="generator"
    )


# Tworzenie dyskryminatora
def get_discriminator(
    input_shape=(32, 32, 3), classes: int = 10, noise_size: int = 64
) -> models.Model:
    """Build the conditional discriminator as a multi-output Keras model.

    Args:
        input_shape: Shape of a single input image, without the batch axis.
        classes: Number of entries in the category embedding table.
        noise_size: Output width of the category embedding.

    Returns:
        A model named ``"discriminator"`` that takes ``[images, category]``.
        Its outputs are the three per-stage feature maps followed by the
        un-activated single-filter score map as the last element.
    """
    image_in = layers.Input(shape=input_shape, dtype=tf.float32, name="images")
    category_in = layers.Input(shape=[1], dtype=tf.int32, name="category")

    category_embedding = layers.Embedding(classes, noise_size)(category_in)

    net = image_in
    feature_maps: list = []
    for width in (64, 128, 256):
        # Project the category embedding to this stage's filter count and
        # reshape it to 1x1 so it can be added to the convolution output.
        condition = layers.Dense(width, activation="relu")(category_embedding)
        condition = layers.Reshape([1, 1, -1])(condition)

        net = layers.Conv2D(width, 3, padding="same")(net)
        net = layers.Add()([net, condition])
        net = layers.LayerNormalization()(net)
        net = layers.LeakyReLU()(net)

        net = layers.Conv2D(width, 3, padding="same", activation="relu")(net)
        net = layers.BatchNormalization()(net)
        net = layers.LeakyReLU()(net)
        feature_maps.append(net)

    # Final score map: no activation.
    score = layers.Conv2D(1, 3, padding="same")(net)

    # The model exposes every stage's features plus the score (last).
    return models.Model(
        inputs=[image_in, category_in],
        outputs=[*feature_maps, score],
        name="discriminator",
    )


class GAN(models.Model):
    """Conditional GAN bundling a generator and a discriminator.

    The class overrides ``train_step`` so that ``fit`` alternates one
    discriminator update (hinge loss) with one generator update
    (adversarial term plus feature matching) for every batch.

    Args:
        classes: Number of categories passed to both sub-model builders.
        noise_size: Length of the generator's noise vector.
        image_size: Shape of a single image, without the batch axis.
        *args: Forwarded to ``models.Model.__init__``.
        **kwargs: Forwarded to ``models.Model.__init__``.
    """

    def __init__(
        self,
        classes: int,
        noise_size: int,
        image_size: Tuple[int, int, int],
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        # Keep the noise width on the instance so train_step does not depend
        # on a module-level variable that happens to share the name.
        self.noise_size = noise_size

        # Create the generator and the discriminator in the constructor.
        self.generator = get_generator(noise_size=noise_size, classes=classes)
        self.discriminator = get_discriminator(
            input_shape=image_size, classes=classes, noise_size=noise_size
        )

        # Build both models.
        self.generator.build(input_shape=[(None, noise_size), (None, 1)])
        self.discriminator.build(input_shape=[(None, *image_size), (None, 1)])

    def compile(self, *args, **kwargs):
        """Compile the wrapper and give each sub-model its optimizer and loss.

        All arguments are forwarded to ``models.Model.compile``. The
        generator gets Adam (lr 0.0002) with ``FeatureMatchingLoss``; the
        discriminator gets Adam (lr 0.0004) with ``losses.Hinge``.
        """
        super().compile(*args, **kwargs)
        self.generator.compile(
            optimizer=optimizers.Adam(0.0002, beta_1=0.0, beta_2=0.99),
            loss=FeatureMatchingLoss(),
        )
        self.discriminator.compile(
            optimizer=optimizers.Adam(0.0004, beta_1=0.0, beta_2=0.99),
            loss=losses.Hinge(),
        )

    def call(self, inputs, training=False):
        """Run the generator on ``inputs`` (``[noise, category]``)."""
        return self.generator(inputs, training=training)

    # Key method that defines how the model is trained.
    # Optionally decorate it with `@tf.function(jit_compile=True)` to speed
    # up training.
    def train_step(self, data):
        """Perform one discriminator update followed by one generator update.

        Args:
            data: A ``(real_images, real_labels)`` batch.

        Returns:
            A dict with the scalar losses ``"d_loss"`` and ``"g_loss"``.
        """
        real_images, real_labels = data

        bs = tf.shape(real_images)[0]

        noise = tf.random.normal((bs, self.noise_size))
        # training=True is passed explicitly: called from train_step, the
        # sub-models would otherwise run in inference mode.
        fake_images = self.generator([noise, real_labels], training=True)

        # --- Discriminator update ---
        self.discriminator.trainable = True
        with tf.GradientTape() as tape:
            # The last output is the score, hence [-1].
            real_output = self.discriminator(
                [real_images, real_labels], training=True
            )[-1]
            fake_output = self.discriminator(
                [fake_images, real_labels], training=True
            )[-1]

            # Hinge targets: +1 for real samples, -1 for generated ones.
            hinge_real = tf.ones_like(real_output)
            hinge_fake = -tf.ones_like(fake_output)

            disc_loss_real = self.discriminator.loss(hinge_real, real_output)
            disc_loss_fake = self.discriminator.loss(hinge_fake, fake_output)
            total_d_loss = disc_loss_real + disc_loss_fake

        grads = tape.gradient(total_d_loss, self.discriminator.trainable_variables)
        self.discriminator.optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_variables)
        )

        # --- Generator update ---
        self.discriminator.trainable = False
        with tf.GradientTape() as tape:
            noise = tf.random.normal((bs, self.noise_size))
            fake_images = self.generator([noise, real_labels], training=True)

            # The discriminator is frozen for this step, so it is evaluated
            # in inference mode; all outputs (features + score) are kept.
            fake_output = self.discriminator(
                [fake_images, real_labels], training=False
            )
            real_output = self.discriminator(
                [real_images, real_labels], training=False
            )

            # Adversarial term: push the score of generated images up.
            mean_loss = -tf.reduce_mean(fake_output[-1])
            # Feature matching between real and generated feature maps.
            fm_loss = self.generator.loss(real_output, fake_output)

            total_g_loss = mean_loss + fm_loss

        grads = tape.gradient(total_g_loss, self.generator.trainable_variables)
        self.generator.optimizer.apply_gradients(
            zip(grads, self.generator.trainable_variables)
        )

        return {"d_loss": total_d_loss, "g_loss": total_g_loss}


class PlotImagesCallback(callbacks.Callback):
    """Plot one generated image per class at the end of every epoch.

    The noise and the labels are created once in the constructor, so the
    same inputs are fed to the generator after each epoch.

    Args:
        noise_size: Length of the generator's noise vector.
        classes: Number of classes; one image is generated for each.
    """

    def __init__(self, noise_size: int, classes: int):
        super().__init__()
        self.classes = classes
        # One fixed noise vector and one label per class.
        self.noise = tf.random.normal((classes, noise_size))
        self.labels = tf.reshape(tf.range(classes, dtype=tf.int32), (-1, 1))
        # Titles for the first ten class indices; any further index falls
        # back to its number in on_epoch_end.
        self.class_names = [
            "plane",
            "car",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        ]

    def on_epoch_end(self, epoch, logs=None):
        """Generate one image per class and show them in a single row."""
        images = self.model.generator([self.noise, self.labels], training=False)

        plt.figure(figsize=(self.classes, 1))
        for i in range(self.classes):
            plt.subplot(1, self.classes, i + 1)
            plt.imshow(images[i])
            title = self.class_names[i] if i < len(self.class_names) else str(i)
            plt.title(title)
            plt.axis("off")
        plt.show()


# Hyperparameters shared by the model, the callback and the data pipeline
classes = 10
noise_size = 64
image_size = (32, 32, 3)
batch_size = 128
epochs = 100


# Load CIFAR-10 (only the training split is used)
(train_images, train_labels), (_, _) = datasets.cifar10.load_data()
# Scale pixels to [0, 1]; casting to float32 first avoids promoting the
# whole dataset to float64, which would double its memory footprint.
train_images = train_images.astype("float32") / 255.0
train_labels = train_labels.reshape(-1, 1)

# Create the tf.data.Dataset
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))

# Create and compile the GAN model
gan = GAN(classes, noise_size, image_size)
gan.compile()

# Train the model; prefetch overlaps input preparation with the train step
gan.fit(
    train_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE),
    epochs=epochs,
    callbacks=[PlotImagesCallback(noise_size, classes)],
)


E0000 00:00:1733433414.318855 1037836 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733433414.321955 1037836 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
I0000 00:00:1733433417.505067 1037836 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 21769 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:06:00.0, compute capability: 8.6


Epoch 1/100


I0000 00:00:1733433426.149458 1037914 service.cc:148] XLA service 0x7f7e30002d80 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1733433426.149497 1037914 service.cc:156]   StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
I0000 00:00:1733433427.293340 1037914 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1733433444.899059 1037914 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m390/391[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 147ms/step - d_loss: 2.0396 - g_loss: 0.5082