## `GradientTape` – serce TensorFlow

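`GradientTape` to główny mechanizm automatycznego różniczkowania w TensorFlow: rejestruje operacje wykonane wewnątrz bloku `with` i na ich podstawie oblicza gradienty funkcji zdefiniowanych w Kerasie lub bezpośrednio w TensorFlow. Dopiero dobre zrozumienie `GradientTape` pozwala w pełni wykorzystać potencjał TensorFlow – na przykład napisać własną pętlę treningową, taką jak poniższa pętla trenująca GAN.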

In [None]:
import tensorflow as tf
from keras import layers, models, optimizers, losses, datasets
import matplotlib.pyplot as plt
import tqdm
# Generacja cifar10 przy użyciu GANów


# Generator construction.
def get_generator(noise_size: int = 64, classes: int = 10) -> models.Model:
    """Build the class-conditional generator.

    The noise vector is summed with a learned embedding of the class label,
    projected by a Dense layer, reshaped to a 4x4 feature map and passed
    through three conv/upsampling stages (4 -> 8 -> 16 -> 32). A final
    3-filter convolution with a sigmoid produces the image, so pixel values
    lie in [0, 1].

    Args:
        noise_size: length of the noise vector; also used as the embedding width.
        classes: number of class labels the embedding accepts.

    Returns:
        An uncompiled Keras functional model with inputs ``[noise, category]``.
    """
    noise = layers.Input(shape=[noise_size], dtype=tf.float32, name="noise")
    category = layers.Input(shape=[1], dtype=tf.int32, name="category")

    # For a length-1 integer input the embedding has shape
    # (batch, 1, noise_size); Add relies on Keras broadcasting the rank-2
    # noise against it.
    label_code = layers.Embedding(classes, noise_size)(category)

    h = layers.Add()([noise, label_code])
    h = layers.Dense(512)(h)
    h = layers.LeakyReLU()(h)
    # 512 units / (4 * 4) -> 32 channels; -1 lets Reshape infer that count.
    h = layers.Reshape([4, 4, -1])(h)

    # Each stage ends with UpSampling2D, doubling height and width.
    for width in (512, 256, 128):
        h = layers.Conv2D(width, 3, padding="same")(h)
        h = layers.BatchNormalization()(h)
        h = layers.LeakyReLU()(h)
        h = layers.UpSampling2D()(h)

    image = layers.Conv2D(3, 3, padding="same", activation="sigmoid")(h)
    return models.Model(inputs=[noise, category], outputs=image, name="generator")


# Discriminator construction.
def get_discriminator(
    input_shape=(32, 32, 3), classes: int = 10, noise_size: int = 64
) -> models.Model:
    """Build the class-conditional discriminator.

    At each of the three convolutional stages the label embedding is projected
    by a small Dense layer, reshaped to (batch, 1, 1, filters) and added to the
    feature maps as a per-channel bias. Every stage ends with 2x2 max pooling,
    so a 32x32 input yields a 4x4 grid of raw (unactivated) scores.

    Args:
        input_shape: shape of one image, without the batch axis.
        classes: number of class labels the embedding accepts.
        noise_size: width of the label embedding (this notebook passes the
            generator's noise size here).

    Returns:
        An uncompiled Keras functional model with inputs ``[images, category]``.
    """
    images = layers.Input(shape=input_shape, dtype=tf.float32, name="images")
    category = layers.Input(shape=[1], dtype=tf.int32, name="category")

    label_code = layers.Embedding(classes, noise_size)(category)

    h = images
    for width in (64, 128, 256):
        # Label-dependent bias with a singleton spatial extent, so Add
        # broadcasts it over every pixel of the feature map.
        bias = layers.Dense(width, activation="relu")(label_code)
        bias = layers.Reshape([1, 1, -1])(bias)

        h = layers.Conv2D(width, 3, padding="same", activation="relu")(h)
        h = layers.Add()([h, bias])
        h = layers.Conv2D(width, 3, padding="same", activation="relu")(h)
        h = layers.BatchNormalization()(h)
        h = layers.LeakyReLU()(h)
        h = layers.MaxPooling2D()(h)

    h = layers.Conv2D(256, 3, padding="same", activation="relu")(h)
    # No activation on purpose: the hinge loss consumes raw scores.
    scores = layers.Conv2D(1, 3, padding="same")(h)

    return models.Model(inputs=[images, category], outputs=scores, name="discriminator")


# Build the conditional GAN. Both models are compiled only to attach an Adam
# optimizer and loss object(s); the training loop reads them back through
# `model.optimizer` and `model.loss`.
classes = 10
noise_size = 64
image_size = (32, 32, 3)


def _make_adam():
    """Return a fresh Adam optimizer (lr=1e-4, beta_1=0.0, beta_2=0.99)."""
    return optimizers.Adam(0.0001, beta_1=0.0, beta_2=0.99)


generator = get_generator(noise_size=noise_size, classes=classes)
disciminator = get_discriminator(
    input_shape=image_size, classes=classes, noise_size=noise_size
)

# Explicit builds with an open batch dimension: [images or noise, category].
disciminator.build(input_shape=[(None, *image_size), (None, 1)])
generator.build(input_shape=[(None, noise_size), (None, 1)])

# Discriminator: hinge loss on its raw (unactivated) scores.
disciminator.compile(optimizer=_make_adam(), loss=losses.Hinge())
# Generator: index 0 is the adversarial loss (BCE on logits), index 1 the
# pixel-wise MSE; the training loop picks each one by index.
generator.compile(
    optimizer=_make_adam(),
    loss=[losses.BinaryCrossentropy(from_logits=True), losses.MeanSquaredError()],
)

# Uncomment to inspect the architectures:
# disciminator.summary()
# generator.summary()

# Load CIFAR-10; only the training split is used.
(train_images, train_labels), (_, _) = datasets.cifar10.load_data()
# Scale pixels from uint8 [0, 255] to [0, 1] to match the generator's sigmoid
# output. Casting to float32 before dividing avoids materialising a float64
# copy of the whole training set (twice the memory); the models consume
# float32 anyway.
train_images = train_images.astype("float32") / 255.0
train_labels = train_labels.reshape(-1, 1)  # one integer label per row

# Training hyperparameters.
batch_size = 128
epochs = 100

print(train_images.shape, train_labels.shape)

# Hand-written training loop built on `tf.GradientTape`: every iteration runs
# one discriminator update followed by one generator update on the same batch.


def _discriminator_step(real_images, real_labels):
    """Update the discriminator once with the hinge loss and return that loss.

    Fake images are generated outside the tape, so this step produces no
    gradients for the generator. They reuse the real batch's labels instead of
    random categories, which makes the discriminator's task easier.
    """
    noise = tf.random.normal((tf.shape(real_images)[0], noise_size))
    # training=True so BatchNormalization uses batch statistics and updates its
    # moving averages; a plain model call runs it in inference mode instead.
    fake_images = generator([noise, real_labels], training=True)

    with tf.GradientTape() as tape:
        real_output = disciminator([real_images, real_labels], training=True)
        fake_output = disciminator([fake_images, real_labels], training=True)
        # Hinge-loss targets: +1 for real images, -1 for generated ones.
        loss_real = disciminator.loss(tf.ones_like(real_output), real_output)
        loss_fake = disciminator.loss(-tf.ones_like(fake_output), fake_output)
        loss = loss_real + loss_fake

    variables = disciminator.trainable_variables
    grads = tape.gradient(loss, variables)
    disciminator.optimizer.apply_gradients(zip(grads, variables))
    return loss


def _generator_step(real_images, real_labels):
    """Update the generator once and return its total loss.

    The loss adds an adversarial term (BCE of the discriminator's logits
    against "real" targets) and a pixel-wise MSE between the generated images
    and the real images of the batch.
    """
    with tf.GradientTape() as tape:
        noise = tf.random.normal((tf.shape(real_images)[0], noise_size))
        fake_images = generator([noise, real_labels], training=True)
        fake_output = disciminator([fake_images, real_labels], training=True)
        adversarial_loss = generator.loss[0](tf.ones_like(fake_output), fake_output)
        pixel_loss = generator.loss[1](real_images, fake_images)
        loss = adversarial_loss + pixel_loss

    # Only the generator's variables are updated, so the discriminator stays
    # fixed without toggling `disciminator.trainable`. Toggling it would also
    # switch the discriminator's BatchNormalization to inference mode for this
    # step only, making it behave differently than in its own update.
    variables = generator.trainable_variables
    grads = tape.gradient(loss, variables)
    generator.optimizer.apply_gradients(zip(grads, variables))
    return loss


def _show_samples():
    """Plot one generated image per class label, in a single row."""
    noise = tf.random.normal((classes, noise_size))
    labels = tf.reshape(tf.range(classes, dtype=tf.int32), (-1, 1))
    # Inference mode: BatchNormalization uses its moving statistics.
    images = generator([noise, labels], training=False)

    plt.figure(figsize=(classes, 1))
    for col in range(classes):
        plt.subplot(1, classes, col + 1)
        plt.imshow(images[col].numpy())
        plt.axis("off")
    plt.show()


num_samples = len(train_images)
# The discriminator is never frozen through `trainable` during training.
disciminator.trainable = True

for epoch in range(epochs):
    disc_loss_total = 0.0
    gen_loss_total = 0.0
    num_steps = 0
    # New random order every epoch so batch composition changes between epochs.
    order = tf.random.shuffle(tf.range(num_samples)).numpy()

    with tqdm.tqdm(total=num_samples) as pbar:
        for start in range(0, num_samples, batch_size):
            batch_idx = order[start : start + batch_size]
            real_images = tf.convert_to_tensor(
                train_images[batch_idx], dtype=tf.float32
            )
            real_labels = tf.convert_to_tensor(
                train_labels[batch_idx], dtype=tf.int32
            )

            disc_loss_total += float(_discriminator_step(real_images, real_labels))
            gen_loss_total += float(_generator_step(real_images, real_labels))
            num_steps += 1

            pbar.set_description(
                f"Epoch {epoch + 1}/{epochs}, "
                f"Discriminator loss: {disc_loss_total / num_steps:.4f}, "
                f"Generator loss: {gen_loss_total / num_steps:.4f}"
            )
            # The final batch can be smaller than batch_size.
            pbar.update(len(batch_idx))

    # Show example images after each epoch.
    _show_samples()
