# Image Generator
---

S.Yu. Papulin (papulin.study@yandex.ru)

### Contents

- [GAN Model](#GAN-Model)
    - [Preparing dataset](#Preparing-dataset)
    - [Generator](#Generator)
    - [Discriminator](#Discriminator)
    - [Image generator based on GAN](#Image-generator-based-on-GAN)
    - [Conditional GAN with classifier](#Conditional-GAN-with-classifier)
    - [Conditional GAN](#Conditional-GAN)
    - [Generating image based on text prompt](#Generating-image-based-on-text-prompt)
- [Stable Diffusion Model (Pretrained)](#Stable-Diffusion-Model-(Pretrained))
- [Sources](#Sources)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split

In [None]:
import tensorflow as tf
from tensorflow.keras import (
    layers, 
    models, 
    Model, 
    utils, 
    losses, 
    optimizers, 
    metrics
)

In [None]:
from tensorflow.keras.datasets import mnist

In [None]:
from sklearn import datasets

In [None]:
# Seed shared by every train/test split so the splits are reproducible
RANDOM_STATE: int = 100

## GAN Model

Generative adversarial network (GAN)

### Preparing dataset

In [None]:
def load_mnist_8x8():
    """Load scikit-learn's 8x8 digits, scale them to [0, 1] and split 80/20.

    Returns:
        ``((X_trainval, y_trainval), (X_test, y_test))`` where the images are
        float32 arrays and the split uses ``RANDOM_STATE``.
    """
    from sklearn import datasets
    digits = datasets.load_digits()
    # load_digits pixel intensities are integers in 0..16
    images = digits.images.astype('float32') / 16.0
    X_trainval, X_test, y_trainval, y_test = train_test_split(
        images, digits.target, test_size=0.2, random_state=RANDOM_STATE
    )
    return (X_trainval, y_trainval), (X_test, y_test)


def load_mnist_28x28():
    """Load the Keras 28x28 MNIST digits scaled to [0, 1].

    Images are converted to float32 before scaling, so the dtype matches
    ``load_mnist_8x8`` (dividing the uint8 arrays directly by 255.0 would
    produce float64).

    Returns:
        ``((X_trainval, y_trainval), (X_test, y_test))`` using Keras' own
        train/test split.
    """
    from tensorflow.keras.datasets import mnist
    (X_trainval, y_trainval), (X_test, y_test) = mnist.load_data()
    X_trainval = X_trainval.astype('float32') / 255.0
    X_test = X_test.astype('float32') / 255.0
    return (X_trainval, y_trainval), (X_test, y_test)
    

In [None]:
# Load the 8x8 digits (80% train+validation, 20% test)
(X_trainval, y_trainval), (X_test, y_test) = load_mnist_8x8()

In [None]:
# Shapes of the train+validation and test image arrays
X_trainval.shape, X_test.shape

In [None]:
# Pixel value range after scaling
X_trainval.min(), X_trainval.max()

In [None]:
# Distinct class labels and the number of samples per label
targets, counts = np.unique(y_trainval, return_counts=True)
targets, counts

In [None]:
# Number of distinct digit classes found in the training labels
NUM_CLASSES = len(targets)
NUM_CLASSES

### Generator

In [None]:
# Length of the random latent (noise) vector fed to every generator
NOISE_DIM: int = 10

In [None]:
def build_simple_generator_model(noise_dim=NOISE_DIM):
    """Build a fully connected generator: noise vector -> 8x8x1 image in [-1, 1].

    Two ReLU Dense layers (16 and 32 units), each followed by
    BatchNormalization, then a tanh Dense layer reshaped to (8, 8, 1).
    """
    return models.Sequential(
        [
            layers.Input(shape=(noise_dim,)),
            layers.Dense(16, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(32, activation='relu'),
            layers.BatchNormalization(),
            layers.Dense(8 * 8 * 1, activation='tanh'),
            layers.Reshape((8, 8, 1)),
        ],
        name='generator',
    )


def build_advanced_generator_model(noise_dim=NOISE_DIM):
    """Build a convolutional generator: noise vector -> 8x8x1 image in [-1, 1].

    A Dense layer embeds the noise into a 4x4x64 feature map. A strided
    transposed convolution upsamples it to 8x8x16, and a final transposed
    convolution produces one tanh channel.
    """
    return models.Sequential([
        layers.Input(shape=(noise_dim,)),
        # input embedding, viewed as a 4x4 map with 64 filters
        layers.Dense(4 * 4 * 64, activation='leaky_relu'),
        layers.Reshape(target_shape=(4, 4, 64)),
        # 4x4x64 -> 8x8x16
        layers.Conv2DTranspose(filters=16, kernel_size=(3, 3), strides=(2, 2), padding='same'),
        layers.LeakyReLU(0.02),
        layers.BatchNormalization(),
        # 8x8x16 -> 8x8x1
        layers.Conv2DTranspose(
            filters=1, kernel_size=(5, 5), strides=(1, 1), padding='same', activation='tanh'
        ),
    ])

# Upscaling block
#     model.add(layers.UpSampling2D((2, 2)))
#     model.add(layers.Conv2D(filters=16, kernel_size=3, padding='same'))
#     model.add(layers.LeakyReLU())
#     model.add(layers.BatchNormalization())


In [None]:
# Build the convolutional generator and print its layers
generator_model = build_advanced_generator_model()
generator_model.summary()

In [None]:
# Generate one standard-normal noise vector; NOISE_DIM (not a literal 10)
# keeps this in sync with the generator's input size
sample_noise = tf.random.normal([1, NOISE_DIM])
# Generate image
output = generator_model(sample_noise)
# Output shape and value range (the final layer uses tanh)
output.shape, tf.reduce_min(output), tf.reduce_max(output)

In [None]:
# Show the single generated image
plt.figure(figsize=(2,2))
plt.imshow(output[0])
plt.axis("off")
plt.show()

### Discriminator

**Building dataset**

In [None]:
# Image side length (assumes square images: height == width)
image_size = X_trainval.shape[1]
image_size

In [None]:
# Number of real images in the train+validation split
n_samples = len(X_trainval)

In [None]:
# Real images: add a channel axis and label them 1
X_disc_true = X_trainval.reshape(-1, image_size, image_size, 1)
y_disc_true = np.ones(n_samples)

X_disc_true.shape, y_disc_true.shape

In [None]:
# Fake images: one generator output per real image, labelled 0
sample_noise = tf.random.normal([n_samples, NOISE_DIM])
X_disc_fake = generator_model(sample_noise)
y_disc_fake = np.zeros(n_samples)

X_disc_fake.shape, y_disc_fake.shape

In [None]:
# Build the discriminator dataset by stacking real and fake images (and their labels)
X_disc = np.r_[X_disc_true, X_disc_fake]
y_disc = np.r_[y_disc_true, y_disc_fake]

X_disc.shape, y_disc.shape

In [None]:
# Hold out 20% of the combined real/fake data for testing
X_disc_trainval, X_disc_test, y_disc_trainval, y_disc_test = train_test_split(
    X_disc, y_disc, test_size=0.2, random_state=RANDOM_STATE
)
X_disc_trainval.shape, X_disc_test.shape

**Building discriminator**

In [None]:
def build_simple_discriminator_model():
    """Build a dense discriminator: 8x8x1 image -> one real/fake logit.

    The output layer has no activation, so the model returns raw logits.
    """
    return models.Sequential(
        [
            layers.Input(shape=(8, 8, 1)),
            layers.Flatten(),
            layers.Dense(128, activation='relu'),
            layers.Dropout(0.1),
            layers.Dense(1),
        ],
        name='discriminator',
    )


def build_advanced_discriminator_model():
    """Build a small CNN discriminator: 8x8x1 image -> one real/fake logit."""
    return models.Sequential([
        layers.Input(shape=(8, 8, 1)),
        layers.Conv2D(16, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(32, activation='relu'),
        # no activation: the output is a logit
        layers.Dense(1),
    ])

In [None]:
# Build discriminator model
discriminator_model = build_simple_discriminator_model()
discriminator_model.summary()

**Fitting**

In [None]:
# The discriminator's last Dense layer has no activation, so it outputs logits (from_logits=True)
discriminator_model.compile(
    optimizer=optimizers.Adam(1e-3),
    loss=losses.BinaryCrossentropy(from_logits=True),
    metrics=[metrics.BinaryAccuracy()]
)

In [None]:
# Train the discriminator; 10% of the training data is held out for validation
discriminator_model.fit(
    x=X_disc_trainval,
    y=y_disc_trainval,
    validation_split=0.1,
    batch_size=64,
    epochs=10
)

**Evaluating**

In [None]:
# Loss and binary accuracy on the held-out test set
discriminator_model.evaluate(X_disc_test, y_disc_test)

In [None]:
# Compute logits, probabilities (sigmoid of logits) and 0/1 predictions for 10 test images
logits = discriminator_model(X_disc_test[:10])
pobabilities = tf.nn.sigmoid(logits)
predictions = tf.round(pobabilities)

In [None]:
# Predictions, flattened from (n, 1) to a 1-D NumPy array
tf.reshape(predictions, [-1]).numpy()

In [None]:
# True labels of the same 10 test images (1 = real, 0 = fake)
y_disc_test[:10]

### Image generator based on GAN

**Preparing dataset**

In [None]:
def convert_to_tf_dataset(X, y, batch_size=64):
    """Wrap ``(X, y)`` into a batched, prefetched ``tf.data.Dataset``.

    The data is not shuffled; batches keep the original sample order.
    """
    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)


In [None]:
# Shapes before splitting off the validation set
X_trainval.shape, y_trainval.shape

In [None]:
# Carve 10% of train+validation off as the validation set
X_train, X_val, y_train, y_val = train_test_split(
    X_trainval, y_trainval, test_size=0.1, random_state=RANDOM_STATE
)

In [None]:
# Batch all three splits identically
train_ds, val_ds, test_ds = (
    convert_to_tf_dataset(X_train, y_train, batch_size=64),
    convert_to_tf_dataset(X_val, y_val, batch_size=64),
    convert_to_tf_dataset(X_test, y_test, batch_size=64),
)

#### Generating images similar to the training set

In [None]:
class VanillaGANModel(Model):
    """Unconditional GAN: a generator trained to fool a real/fake discriminator.

    The generator maps ``NOISE_DIM``-dimensional standard-normal noise to
    images. The discriminator returns one logit per image (1 = real,
    0 = fake). Each network has its own Adam optimizer, so ``compile()`` is
    called without arguments. ``cheater_accuracy`` is the share of fake
    images the discriminator classifies as real.

    During ``train_step`` both networks are called with ``training=True``.
    Without it, layers such as BatchNormalization and Dropout would run in
    inference mode inside this custom training loop.
    """

    def __init__(self, generator_model, discriminator_model, **kwargs):
        super().__init__(**kwargs)
        # models
        self.generator = generator_model
        self.discriminator = discriminator_model
        # optimizers
        self.generator_optimizer = optimizers.Adam(1e-4)
        self.discriminator_optimizer = optimizers.Adam(1e-4)
        # metrics
        self.train_generator_loss = metrics.Mean(name='train_generator_loss')
        self.train_discriminator_loss = metrics.Mean(name='train_discriminator_loss')
        self.test_generator_loss = metrics.Mean(name='test_generator_loss')
        self.test_discriminator_loss = metrics.Mean(name='test_discriminator_loss')
        self.train_cheater_accuracy = metrics.BinaryAccuracy(name='train_accuracy')
        self.test_cheater_accuracy = metrics.BinaryAccuracy(name='test_accuracy')

    def call(self, inputs, training=False):
        """Generate ``inputs`` fake images from fresh random noise.

        Args:
            inputs: number of images to generate (int or scalar int tensor).
            training: forwarded to the generator.
        """
        batch_size = inputs
        # generate noise
        noise = tf.random.normal([batch_size, NOISE_DIM])
        # generate fake images
        return self.generator(noise, training=training)

    def compute_generator_loss(self, X_fake, training=False):
        """Return the mean BCE between the discriminator's logits on fakes and the "real" label (1)."""
        # classify whether a fake image is real
        y_fake_logits = self.discriminator(X_fake, training=training)
        # penalty for generating images the discriminator rejects
        loss = losses.binary_crossentropy(
            y_true=tf.ones_like(y_fake_logits),
            y_pred=y_fake_logits,
            from_logits=True
        )
        return tf.reduce_mean(loss)

    def compute_discriminator_loss(self, X_true, X_fake, training=False):
        """Return the average of BCE(real logits, 1) and BCE(fake logits, 0).

        Labels: 1 - real image, 0 - fake image.

        Returns:
            ``(mean loss, discriminator logits for X_fake)``.
        """
        # classify whether a real image is real
        y_true_logits = self.discriminator(X_true, training=training)
        # classify whether a fake image is real
        y_fake_logits = self.discriminator(X_fake, training=training)
        # penalty for misclassification of real images
        loss_true = losses.binary_crossentropy(
            y_true=tf.ones_like(y_true_logits),
            y_pred=y_true_logits,
            from_logits=True
        )
        # penalty for misclassification of fake images
        loss_fake = losses.binary_crossentropy(
            y_true=tf.zeros_like(y_fake_logits),
            y_pred=y_fake_logits,
            from_logits=True
        )
        total_loss = (loss_true + loss_fake) / 2.0
        return tf.reduce_mean(total_loss), y_fake_logits

    @tf.function
    def train_step(self, data):
        """Run one simultaneous generator/discriminator update on a batch of real images."""
        X_true, _ = data

        with (
            tf.GradientTape() as tape_gen,
            tf.GradientTape() as tape_disc
        ):
            # number of images to generate
            batch_size = tf.shape(X_true)[0]
            # generate fake images (training mode)
            X_fake = self.call(batch_size, training=True)
            # compute losses
            discriminator_loss, y_fake_logits = self.compute_discriminator_loss(
                X_true, X_fake, training=True
            )
            generator_loss = self.compute_generator_loss(X_fake, training=True)

        # compute gradients and update weights
        gradients_discriminator = tape_disc.gradient(
            target=discriminator_loss,
            sources=self.discriminator.trainable_variables
        )
        self.discriminator_optimizer.apply_gradients(
            zip(gradients_discriminator, self.discriminator.trainable_variables)
        )

        gradients_generator = tape_gen.gradient(
            target=generator_loss,
            sources=self.generator.trainable_variables
        )
        self.generator_optimizer.apply_gradients(
            zip(gradients_generator, self.generator.trainable_variables)
        )

        # update metrics
        self.train_discriminator_loss.update_state(discriminator_loss)
        self.train_generator_loss.update_state(generator_loss)
        self.train_cheater_accuracy.update_state(
            y_true=tf.ones_like(y_fake_logits),
            y_pred=tf.nn.sigmoid(y_fake_logits)
        )

        return {
            'generator_loss': self.train_generator_loss.result(),
            'discriminator_loss': self.train_discriminator_loss.result(),
            'cheater_accuracy': self.train_cheater_accuracy.result()
        }

    def test_step(self, data):
        """Evaluate both losses on a batch of real images (inference mode, no updates)."""
        X_true, _ = data
        batch_size = tf.shape(X_true)[0]
        # generate fake images
        X_fake = self.call(batch_size, training=False)
        # compute losses
        generator_loss = self.compute_generator_loss(X_fake, training=False)
        discriminator_loss, y_fake_logits = self.compute_discriminator_loss(
            X_true, X_fake, training=False
        )
        # update metrics
        self.test_discriminator_loss.update_state(discriminator_loss)
        self.test_generator_loss.update_state(generator_loss)
        self.test_cheater_accuracy.update_state(
            y_true=tf.ones_like(y_fake_logits),
            y_pred=tf.nn.sigmoid(y_fake_logits)
        )
        return {
            'generator_loss': self.test_generator_loss.result(),
            'discriminator_loss': self.test_discriminator_loss.result(),
            'cheater_accuracy': self.test_cheater_accuracy.result()
        }

In [None]:
# Number of training epochs for the simple vanilla GAN
EPOCHS = 1000
# Vanilla GAN built from the simple dense generator and discriminator
gan_model = VanillaGANModel(
    generator_model=build_simple_generator_model(),
    discriminator_model=build_simple_discriminator_model()
)
gan_model.summary()

In [None]:
# EPOCHS = 300
# gan_model = VanillaGANModel(
#     generator_model=build_advanced_generator_model(),
#     discriminator_model=build_advanced_discriminator_model()
# )
# gan_model.summary()

In [None]:
# Optimizers are created inside the model, so compile() needs no arguments
gan_model.compile()

In [None]:
# Train the vanilla GAN. The epoch count is EPOCHS (set when the model is
# built); NUM_EPOCHS is not defined at this point in the notebook.
train_history = gan_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    verbose=1,  # change to 0 for the simple model
)
train_history

In [None]:
# Generate 30 images (the model's input is the number of images to generate)
fake_image = gan_model(tf.constant(30))
fake_image.shape

In [None]:
def display_images(I):
    """Show a batch of images in a grid, 10 images per row.

    Args:
        I: array or tensor of images indexed along the first axis.
    """
    per_row = 10
    count = I.shape[0]
    rows = (count + per_row - 1) // per_row  # ceiling division
    plt.figure(figsize=[14, 1.5 * rows])
    for position, image in enumerate(I, start=1):
        plt.subplot(rows, per_row, position)
        plt.imshow(image)
        plt.axis("off")

    plt.show()

In [None]:
# Display the generated images
display_images(fake_image)

In [None]:
# model.discriminator(fake_image).numpy() > 0

In [None]:
def display_smooth_plot(x, y, color, label, step=10, degree=5):
    """Plot a noisy series with two smoothed overlays on the current axes.

    Draws (1) the raw series faintly, (2) the means over consecutive
    non-overlapping windows of ``step`` points, and (3) a least-squares
    polynomial trend of order ``degree``, which carries the legend ``label``.

    Args:
        x, y: equal-length sequences of x and y values.
        color: matplotlib color used for all three lines.
        label: legend label for the trend line.
        step: window size for the windowed means; a trailing partial window
            is dropped.
        degree: polynomial degree of the trend line (default 5).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # polynomial trend
    poly_func = np.poly1d(np.polyfit(x, y, degree))
    x_trend = np.linspace(x.min(), x.max(), 100)
    y_trend = poly_func(x_trend)
    # Windowed means. x and y are averaged over the same windows so both
    # arrays always have equal length. The former x[:-step:step] slice was
    # one element short whenever len(y) was a multiple of step, and
    # x[step] failed for short series.
    n_full = len(y) // step * step
    x_means = x[:n_full].reshape(-1, step).mean(axis=1)
    y_means = y[:n_full].reshape(-1, step).mean(axis=1)
    plt.plot(x, y, color=color, alpha=0.3, linestyle='-')
    plt.plot(x_means, y_means, color=color, linestyle='-', alpha=0.5)
    plt.plot(x_trend, y_trend, color=color, linestyle='-', label=label)


def display_loss_plots(train_history):
    """Plot GAN training curves in a 2x2 grid.

    Panels: training losses, validation losses, training cheater accuracy
    and validation cheater accuracy. The first epoch is skipped in every
    panel.

    Args:
        train_history: the ``History`` object returned by ``Model.fit``.
    """
    plt.figure(figsize=[14, 8])
    history = train_history.history
    epochs = np.arange(1, len(history['generator_loss']) + 1)
    # (title, y-axis label, [(history key, color, legend label), ...])
    panels = [
        ('Training', 'loss', [
            ('generator_loss', 'g', 'generator'),
            ('discriminator_loss', 'orange', 'discriminator'),
        ]),
        ('Validation', 'loss', [
            ('val_generator_loss', 'g', 'generator'),
            ('val_discriminator_loss', 'orange', 'discriminator'),
        ]),
        ('Training Accuracy', 'accuracy', [
            ('cheater_accuracy', 'g', 'generator'),
        ]),
        ('Validation Accuracy', 'accuracy', [
            ('val_cheater_accuracy', 'g', 'generator'),
        ]),
    ]
    for position, (title, y_label, curves) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.title(title)
        for key, color, curve_label in curves:
            display_smooth_plot(
                x=epochs[1:],
                y=history[key][1:],
                color=color,
                label=curve_label
            )
        plt.xlabel('epochs')
        plt.ylabel(y_label)
        plt.grid(True)
        plt.legend()

    plt.tight_layout()

    plt.show()

In [None]:
# Plot loss and cheater-accuracy curves of the vanilla GAN
display_loss_plots(train_history)

### Conditional GAN with classifier

#### Classifier

In [None]:
def build_classifier_model():
    """Build a small CNN that classifies 8x8x1 digit images.

    The last layer is a softmax, so the model outputs class probabilities,
    not logits. Use it with ``from_logits=False`` losses.

    Returns:
        ``Sequential`` model named ``'classifier'`` with ``NUM_CLASSES``
        outputs.
    """
    model = models.Sequential(name='classifier')
    model.add(layers.Input(shape=(8, 8, 1)))
    model.add(layers.Conv2D(16, (3, 3), activation='relu', padding='same'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(32, (3, 3), activation='relu', padding='same'))
    model.add(layers.Flatten())
    model.add(layers.Dense(32, activation='relu'))
    # one output per class, consistent with NUM_CLASSES used by the generators
    model.add(layers.Dense(NUM_CLASSES, activation='softmax'))
    return model

In [None]:
# Build the digit classifier and print its layers
classifier_model = build_classifier_model()
classifier_model.summary()

In [None]:
# build_classifier_model ends with a softmax layer, so the model outputs
# probabilities. from_logits=True would apply softmax a second time.
classifier_model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[metrics.SparseCategoricalAccuracy()],
)

In [None]:
# Train the classifier on the real training images and their digit labels
classifier_model.fit(
    train_ds,
    epochs=30
)

In [None]:
# Loss and accuracy on the test split
classifier_model.evaluate(test_ds)

In [None]:
# Freeze the classifier so its weights are not trainable when used inside the GAN
classifier_model.trainable = False

#### GAN generator

In [None]:
def build_simple_generator_model(noise_dim=NOISE_DIM):
    """Build a dense conditional generator: (noise, digit label) -> 8x8x1 image.

    The noise is projected to ``4*4*NUM_CLASSES`` features and the label is
    one-hot encoded. The two are concatenated and passed through two
    Dense+BatchNorm stages and a tanh output reshaped to (8, 8, 1).
    """
    noise_in = layers.Input(shape=(noise_dim,))
    label_in = layers.Input(shape=(1,))

    noise_features = layers.Dense(4 * 4 * NUM_CLASSES, use_bias=False)(noise_in)
    label_one_hot = layers.CategoryEncoding(
        num_tokens=NUM_CLASSES,
        output_mode='one_hot'
    )(label_in)

    hidden = layers.Concatenate()([noise_features, label_one_hot])
    for units in (16, 32):
        hidden = layers.Dense(units, activation='relu')(hidden)
        hidden = layers.BatchNormalization()(hidden)
    pixels = layers.Dense(8 * 8 * 1, activation='tanh')(hidden)
    image = layers.Reshape((8, 8, 1))(pixels)

    return models.Model(
        inputs=[noise_in, label_in],
        outputs=image,
        name='generator'
    )

In [None]:
def build_advanced_generator_model(noise_dim=NOISE_DIM):
    """Build a convolutional conditional generator: (noise, digit label) -> 8x8x1 image.

    The noise and an MLP embedding of the one-hot label are each mapped to
    ``4*4*NUM_CLASSES`` features and summed. The sum is reshaped to a 4x4
    feature map, upsampled to 8x8 by a strided transposed convolution, and
    turned into a single tanh channel.
    """
    noise_in = layers.Input(shape=(noise_dim,))
    label_in = layers.Input(shape=(1,))

    # noise branch
    noise_features = layers.Dense(4 * 4 * NUM_CLASSES, use_bias=False)(noise_in)

    # label branch: one-hot -> MLP embedding with the same width as the noise branch
    label_features = layers.CategoryEncoding(num_tokens=NUM_CLASSES, output_mode="one_hot")(label_in)
    label_features = layers.Dense(512, activation='relu')(label_features)
    label_features = layers.Dense(4 * 4 * NUM_CLASSES, activation='relu')(label_features)
    label_features = layers.LayerNormalization()(label_features)

    # merge the branches into a 4x4 feature map
    feature_map = layers.Add()([noise_features, label_features])
    feature_map = layers.Reshape(target_shape=(4, 4, NUM_CLASSES))(feature_map)

    # 4x4 -> 8x8
    feature_map = layers.Conv2DTranspose(
        filters=16, kernel_size=(3, 3), strides=(2, 2), padding='same'
    )(feature_map)
    feature_map = layers.BatchNormalization()(feature_map)
    feature_map = layers.LeakyReLU()(feature_map)

    # 8x8x16 -> 8x8x1 output image
    image = layers.Conv2DTranspose(
        filters=1, kernel_size=(5, 5), strides=(1, 1), padding='same', activation='tanh'
    )(feature_map)
    return Model(inputs=[noise_in, label_in], outputs=image, name='generator')

#### Building conditional GAN with classifier

In [None]:
class ConditionalClassifierGANModel(Model):
    """GAN whose generator is steered towards a requested digit by a frozen classifier.

    The generator maps (noise, label) to an image. The discriminator is
    unconditional (image -> real/fake logit). Label consistency comes from
    an extra generator-loss term: the sparse cross-entropy between the
    classifier's predicted probabilities on the fakes and the requested
    labels. Only the generator and discriminator are optimised. Each has
    its own Adam optimizer, so ``compile()`` is called without arguments.

    During ``train_step`` the generator and discriminator are called with
    ``training=True`` so that BatchNormalization and Dropout behave in
    training mode inside this custom loop.
    """

    def __init__(self, generator_model, discriminator_model, classifier_model, **kwargs):
        super().__init__(**kwargs)
        # models
        self.generator = generator_model
        self.discriminator = discriminator_model
        self.classifier = classifier_model
        # optimizers
        self.generator_optimizer = optimizers.Adam(1e-4)
        self.discriminator_optimizer = optimizers.Adam(1e-4)
        # metrics
        self.train_generator_loss = metrics.Mean(name='train_generator_loss')
        self.train_discriminator_loss = metrics.Mean(name='train_discriminator_loss')
        self.test_generator_loss = metrics.Mean(name='test_generator_loss')
        self.test_discriminator_loss = metrics.Mean(name='test_discriminator_loss')
        self.train_cheater_accuracy = metrics.BinaryAccuracy(name='train_accuracy')
        self.test_cheater_accuracy = metrics.BinaryAccuracy(name='test_accuracy')

    def call(self, inputs, training=False):
        """Generate one fake image per digit label in ``inputs`` (1-D int tensor).

        ``training`` is forwarded to the generator.
        """
        y_fake = inputs
        noise = tf.random.normal([tf.shape(y_fake)[0], NOISE_DIM])
        return self.generator([noise, y_fake], training=training)

    def compute_discriminator_loss(self, X_true, X_fake, training=False):
        """Return the average of BCE(real logits, 1) and BCE(fake logits, 0).

        Returns:
            ``(mean loss, discriminator logits for X_fake)``.
        """
        y_true_logits = self.discriminator(X_true, training=training)
        y_fake_logits = self.discriminator(X_fake, training=training)
        loss_true = losses.binary_crossentropy(
            y_true=tf.ones_like(y_true_logits),
            y_pred=y_true_logits,
            from_logits=True
        )
        loss_fake = losses.binary_crossentropy(
            y_true=tf.zeros_like(y_fake_logits),
            y_pred=y_fake_logits,
            from_logits=True
        )
        total_loss = (loss_true + loss_fake) / 2.0
        return tf.reduce_mean(total_loss), y_fake_logits

    def compute_generator_loss(self, X_fake, y_fake, training=False):
        """Return the adversarial BCE plus 2x the classifier loss on ``y_fake``.

        The adversarial term scores fakes against the "real" label (1). The
        classifier term is the sparse cross-entropy between the classifier's
        probabilities on ``X_fake`` and the requested labels ``y_fake``.
        """
        y_fake_logits = self.discriminator(X_fake, training=training)
        # the classifier is frozen, so it always runs in inference mode
        y_fake_target_probs = self.classifier(X_fake, training=False)
        generator_loss = losses.binary_crossentropy(
            y_true=tf.ones_like(y_fake_logits),
            y_pred=y_fake_logits,
            from_logits=True
        )
        # classifier outputs softmax probabilities, hence from_logits=False (default)
        classifier_loss = losses.sparse_categorical_crossentropy(
            y_true=y_fake,
            y_pred=y_fake_target_probs
        )
        # fixed weight of the label-consistency term
        total_loss = generator_loss + 2.0 * classifier_loss
        return tf.reduce_mean(total_loss)

    def compute_generator_loss_without_classifier(self, X_fake, y_fake, training=False):
        """Return the adversarial generator loss only.

        ``y_fake`` is accepted so the signature matches
        ``compute_generator_loss``, but it is not used.
        """
        y_fake_logits = self.discriminator(X_fake, training=training)
        generator_loss = losses.binary_crossentropy(
            y_true=tf.ones_like(y_fake_logits),
            y_pred=y_fake_logits,
            from_logits=True
        )
        return tf.reduce_mean(generator_loss)

    @tf.function
    def train_step(self, data):
        """Run one generator/discriminator update; the real labels are ignored."""
        # unpack input data
        X_true, _ = data

        with (
            tf.GradientTape() as tape_gen,
            tf.GradientTape() as tape_disc
        ):
            # number of images to generate
            batch_size = tf.shape(X_true)[0]
            # random digit label for every fake image
            y_fake = tf.random.uniform([batch_size], minval=0, maxval=NUM_CLASSES, dtype=tf.int32)
            # generate fake images
            X_fake = self.call(y_fake, training=True)
            # compute losses
            discriminator_loss, y_fake_logits = self.compute_discriminator_loss(
                X_true, X_fake, training=True
            )
            generator_loss = self.compute_generator_loss(X_fake, y_fake, training=True)

        # compute gradients and update weights
        gradients_discriminator = tape_disc.gradient(
            target=discriminator_loss,
            sources=self.discriminator.trainable_variables
        )
        self.discriminator_optimizer.apply_gradients(
            zip(gradients_discriminator, self.discriminator.trainable_variables)
        )

        gradients_generator = tape_gen.gradient(
            target=generator_loss,
            sources=self.generator.trainable_variables
        )
        self.generator_optimizer.apply_gradients(
            zip(gradients_generator, self.generator.trainable_variables)
        )

        # update metrics
        self.train_discriminator_loss.update_state(discriminator_loss)
        self.train_generator_loss.update_state(generator_loss)
        self.train_cheater_accuracy.update_state(
            y_true=tf.ones_like(y_fake_logits),
            y_pred=tf.nn.sigmoid(y_fake_logits)
        )

        return {
            'generator_loss': self.train_generator_loss.result(),
            'discriminator_loss': self.train_discriminator_loss.result(),
            'cheater_accuracy': self.train_cheater_accuracy.result()
        }

    def test_step(self, data):
        """Evaluate both losses on a batch of real images (inference mode, no updates)."""
        X_true, _ = data
        batch_size = tf.shape(X_true)[0]
        y_fake = tf.random.uniform([batch_size], minval=0, maxval=NUM_CLASSES, dtype=tf.int32)
        # generate fake images
        X_fake = self.call(y_fake, training=False)
        # compute losses
        generator_loss = self.compute_generator_loss(X_fake, y_fake, training=False)
        discriminator_loss, y_fake_logits = self.compute_discriminator_loss(
            X_true, X_fake, training=False
        )
        # update metrics
        self.test_discriminator_loss.update_state(discriminator_loss)
        self.test_generator_loss.update_state(generator_loss)
        self.test_cheater_accuracy.update_state(
            y_true=tf.ones_like(y_fake_logits),
            y_pred=tf.nn.sigmoid(y_fake_logits)
        )
        return {
            'generator_loss': self.test_generator_loss.result(),
            'discriminator_loss': self.test_discriminator_loss.result(),
            'cheater_accuracy': self.test_cheater_accuracy.result()
        }

In [None]:
# Number of training epochs for the classifier-guided GAN
NUM_EPOCHS = 700

# Conditional generator + unconditional discriminator + frozen classifier
cc_gan_model = ConditionalClassifierGANModel(
    generator_model=build_simple_generator_model(),
    discriminator_model=build_simple_discriminator_model(),
    classifier_model=classifier_model
)
cc_gan_model.summary()

In [None]:
# NUM_EPOCHS = 300

# cc_gan_model = ConditionalClassifierGANModel(
#     generator_model=build_advanced_generator_model(),
#     discriminator_model=build_advanced_discriminator_model(),
#     classifier_model=classifier_model
# )
# cc_gan_model.summary()

In [None]:
# Optimizers are created inside the model, so compile() needs no arguments
cc_gan_model.compile()

In [None]:
# NUM_EPOCHS = 300

# Train the classifier-guided conditional GAN
train_history = cc_gan_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=NUM_EPOCHS,
    verbose=1
)
train_history

In [None]:
# Generate images given targets
fake_targets = tf.constant([5, 5, 5, 9, 9, 9, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4])
fake_images = cc_gan_model(fake_targets)
# Clip this batch (not the earlier vanilla-GAN `fake_image`) to [0, 1] for display
fake_images = tf.clip_by_value(fake_images, 0, 1)
display_images(fake_images)

In [None]:
# Plot loss and cheater-accuracy curves of the classifier-guided GAN
display_loss_plots(train_history)

### Conditional GAN

In [None]:
def build_simple_discriminator_model():
    """Build a dense conditional discriminator: (8x8x1 image, digit label) -> one real/fake logit.

    The flattened image is concatenated with the one-hot label before a
    ReLU Dense layer with dropout and a linear (logit) output.
    """
    image_in = layers.Input(shape=(8, 8, 1))
    label_in = layers.Input(shape=(1,))
    label_one_hot = layers.CategoryEncoding(num_tokens=NUM_CLASSES, output_mode='one_hot')(label_in)
    features = layers.Concatenate()([layers.Flatten()(image_in), label_one_hot])
    features = layers.Dense(128, activation='relu')(features)
    features = layers.Dropout(0.1)(features)
    logit = layers.Dense(1)(features)
    return models.Model(
        inputs=[image_in, label_in],
        outputs=logit,
        name='discriminator'
    )
    

def build_advanced_discriminator_model():
    """Build a CNN conditional discriminator: (8x8x1 image, digit label) -> one real/fake logit.

    The label is one-hot encoded over ``NUM_CLASSES`` classes, consistent
    with the generators and the simple discriminator (previously a
    hard-coded 10), and embedded with a ReLU Dense layer. The image is
    convolved and pooled, then both are concatenated before the output
    head.
    """
    # inputs
    inputs_image = layers.Input(shape=(8, 8, 1))
    inputs_target = layers.Input(shape=(1,))
    # refactor target input
    x_target = layers.CategoryEncoding(num_tokens=NUM_CLASSES, output_mode="one_hot")(inputs_target)
    x_target = layers.Dense(32, activation='relu')(x_target)
    # discriminator
    x = layers.Conv2D(16, (3, 3), activation="relu", padding="same")(inputs_image)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Concatenate()([layers.Flatten()(x), x_target])
    x = layers.Dense(32, activation="relu")(x)
    x = layers.Dropout(0.1)(x)
    # output: a single logit (no activation)
    output = layers.Dense(1)(x)
    return models.Model(
        inputs=[inputs_image, inputs_target],
        outputs=output,
        name='discriminator'
    )


In [None]:
class ConditionalGANModel(Model):
    """Conditional GAN whose discriminator also sees the digit label.

    The generator maps (noise, label) to an image. The discriminator maps
    (image, label) to one real/fake logit: real images are paired with their
    true labels, fakes with the labels they were generated for. Each network
    has its own Adam optimizer, so ``compile()`` is called without arguments.

    During ``train_step`` both networks are called with ``training=True`` so
    that BatchNormalization and Dropout behave in training mode inside this
    custom loop.
    """

    def __init__(self, generator_model, discriminator_model, **kwargs):
        super().__init__(**kwargs)
        # models
        self.generator = generator_model
        self.discriminator = discriminator_model
        # optimizers
        self.generator_optimizer = optimizers.Adam(1e-4)
        self.discriminator_optimizer = optimizers.Adam(1e-4)
        # metrics
        self.train_generator_loss = metrics.Mean(name='train_generator_loss')
        self.train_discriminator_loss = metrics.Mean(name='train_discriminator_loss')
        self.test_generator_loss = metrics.Mean(name='test_generator_loss')
        self.test_discriminator_loss = metrics.Mean(name='test_discriminator_loss')
        self.train_cheater_accuracy = metrics.BinaryAccuracy(name='train_accuracy')
        self.test_cheater_accuracy = metrics.BinaryAccuracy(name='test_accuracy')

    def call(self, inputs, training=False):
        """Generate one fake image per digit label in ``inputs`` (1-D int tensor).

        ``training`` is forwarded to the generator.
        """
        y_fake = inputs
        noise = tf.random.normal([tf.shape(y_fake)[0], NOISE_DIM])
        return self.generator([noise, y_fake], training=training)

    def compute_discriminator_loss(self, X_true, X_fake, y_true, y_fake, training=False):
        """Return the average of BCE(real logits, 1) and BCE(fake logits, 0).

        Returns:
            ``(mean loss, discriminator logits for (X_fake, y_fake))``.
        """
        y_true_logits = self.discriminator([X_true, y_true], training=training)
        y_fake_logits = self.discriminator([X_fake, y_fake], training=training)
        loss_true = losses.binary_crossentropy(
            y_true=tf.ones_like(y_true_logits),
            y_pred=y_true_logits,
            from_logits=True
        )
        loss_fake = losses.binary_crossentropy(
            y_true=tf.zeros_like(y_fake_logits),
            y_pred=y_fake_logits,
            from_logits=True
        )
        total_loss = (loss_true + loss_fake) / 2.0
        return tf.reduce_mean(total_loss), y_fake_logits

    def compute_generator_loss(self, X_fake, y_fake, training=False):
        """Return the mean BCE between the discriminator's logits on (fake, label) pairs and the "real" label."""
        y_fake_logits = self.discriminator([X_fake, y_fake], training=training)
        loss = losses.binary_crossentropy(
            y_true=tf.ones_like(y_fake_logits),
            y_pred=y_fake_logits,
            from_logits=True
        )
        return tf.reduce_mean(loss)

    @tf.function
    def train_step(self, data):
        """Run one generator/discriminator update on a batch of (image, label) pairs."""
        # unpack input data
        X_true, y_true = data

        with (
            tf.GradientTape() as tape_gen,
            tf.GradientTape() as tape_disc
        ):
            # number of images to generate
            batch_size = tf.shape(X_true)[0]
            # random digit label for every fake image
            y_fake = tf.random.uniform([batch_size], minval=0, maxval=NUM_CLASSES, dtype=tf.int32)
            # generate fake images
            X_fake = self.call(y_fake, training=True)
            # compute losses
            discriminator_loss, y_fake_logits = self.compute_discriminator_loss(
                X_true=X_true,
                X_fake=X_fake,
                y_true=y_true,
                y_fake=y_fake,
                training=True
            )
            generator_loss = self.compute_generator_loss(X_fake, y_fake, training=True)

        # compute gradients and update weights
        gradients_discriminator = tape_disc.gradient(
            target=discriminator_loss,
            sources=self.discriminator.trainable_variables
        )
        self.discriminator_optimizer.apply_gradients(
            zip(gradients_discriminator, self.discriminator.trainable_variables)
        )

        gradients_generator = tape_gen.gradient(
            target=generator_loss,
            sources=self.generator.trainable_variables
        )
        self.generator_optimizer.apply_gradients(
            zip(gradients_generator, self.generator.trainable_variables)
        )

        # update metrics
        self.train_discriminator_loss.update_state(discriminator_loss)
        self.train_generator_loss.update_state(generator_loss)
        self.train_cheater_accuracy.update_state(
            y_true=tf.ones_like(y_fake_logits),
            y_pred=tf.nn.sigmoid(y_fake_logits)
        )

        return {
            'generator_loss': self.train_generator_loss.result(),
            'discriminator_loss': self.train_discriminator_loss.result(),
            'cheater_accuracy': self.train_cheater_accuracy.result()
        }

    def test_step(self, data):
        """Evaluate both losses on a batch of (image, label) pairs (inference mode, no updates)."""
        X_true, y_true = data
        batch_size = tf.shape(X_true)[0]
        y_fake = tf.random.uniform([batch_size], minval=0, maxval=NUM_CLASSES, dtype=tf.int32)
        # generate fake images
        X_fake = self.call(y_fake, training=False)
        # compute losses
        generator_loss = self.compute_generator_loss(X_fake, y_fake, training=False)
        discriminator_loss, y_fake_logits = self.compute_discriminator_loss(
            X_true, X_fake, y_true, y_fake, training=False
        )
        # update metrics
        self.test_discriminator_loss.update_state(discriminator_loss)
        self.test_generator_loss.update_state(generator_loss)
        self.test_cheater_accuracy.update_state(
            y_true=tf.ones_like(y_fake_logits),
            y_pred=tf.nn.sigmoid(y_fake_logits)
        )
        return {
            'generator_loss': self.test_generator_loss.result(),
            'discriminator_loss': self.test_discriminator_loss.result(),
            'cheater_accuracy': self.test_cheater_accuracy.result()
        }

In [None]:
# Number of training epochs for the label-conditioned GAN
NUM_EPOCHS = 700

# Conditional generator + conditional discriminator (no classifier)
c_gan_model = ConditionalGANModel(
    generator_model=build_simple_generator_model(),
    discriminator_model=build_simple_discriminator_model()
)
c_gan_model.summary()

In [None]:
# NUM_EPOCHS = 500

# c_gan_model = ConditionalGANModel(
#     generator_model=build_advanced_generator_model(),
#     discriminator_model=build_advanced_discriminator_model()
# )
# c_gan_model.summary()

In [None]:
# Optimizers are created inside the model, so compile() needs no arguments
c_gan_model.compile()

In [None]:
# Train the label-conditioned GAN
train_history = c_gan_model.fit(
    train_ds, validation_data=val_ds, epochs=NUM_EPOCHS, verbose=1
)

In [None]:
# Generate images given targets with the conditional GAN trained above
# (c_gan_model, not the earlier classifier-guided cc_gan_model)
fake_targets = tf.constant([5, 5, 5, 9, 9, 9, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4])
fake_images = c_gan_model(fake_targets)
# Clip this batch (not the earlier vanilla-GAN `fake_image`) to [0, 1] for display
fake_images = tf.clip_by_value(fake_images, 0, 1)
display_images(fake_images)

In [None]:
# Plot loss and cheater-accuracy curves of the label-conditioned GAN
display_loss_plots(train_history)

Generated images are mostly identical for a given target (a form of mode collapse). How can this problem be solved?

### Generating image based on text prompt

#### Loading word embeddings

In [None]:
def load_vectors(path_to_file, encoding='utf-8'):
    """Load words and their weights from a GloVe-style text file.

    Each non-empty line holds a word followed by its whitespace-separated
    vector components.

    Args:
        path_to_file: path to the text file.
        encoding: text encoding used to open the file. It is passed
            explicitly so loading does not depend on the platform's default
            encoding (vocabularies contain non-ASCII words).

    Returns:
        Tuple ``(words, embeddings)``: a 1-D array of words and an array of
        float32 vectors, where row i belongs to ``words[i]``. All lines must
        have the same number of components to get a 2-D array.

    Raises:
        ValueError: if a line has no vector components or a component is not
            a number.
    """
    words = []
    embeddings = []
    with open(path_to_file, encoding=encoding) as f:
        for line in f:
            # Skip blank lines (e.g. a trailing empty line) instead of
            # failing on the unpacking below.
            if not line.strip():
                continue
            word, coefs = line.split(maxsplit=1)
            words.append(word)
            # Unlike np.fromstring, this raises on malformed numbers instead
            # of silently truncating the vector.
            embeddings.append(np.asarray(coefs.split(), dtype='float32'))
    return np.array(words), np.array(embeddings)

In [None]:
# Dimensionality of the pretrained word vectors; also selects which file is loaded.
EMBEDDING_DIM = 100
# NOTE(review): machine-specific absolute path — adjust to where glove.6B is stored.
FILEPATH = f'/media/sf_practice/data/debug_glove/glove.6B/glove.6B.{EMBEDDING_DIM}d.txt'

# Load words and their embeddings
words, embeddings = load_vectors(FILEPATH)
words[:5]

#### Embeddings for digits

In [None]:
# English names of the digits: labels[i] is the name of digit i.
labels = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']

In [None]:
# Embedding matrix of the digit names: row i holds the word vector of labels[i].
# NOTE(review): assumes NUM_CLASSES == len(labels); a smaller NUM_CLASSES would
# make the indexing below go out of bounds.
E_digits = np.zeros((NUM_CLASSES, EMBEDDING_DIM), dtype='float32')
for index, label in enumerate(labels):
    # Boolean mask selecting the vocabulary entry equal to `label`
    # (assume_unique=True tells np.isin that `words` has no duplicates).
    # Exactly one match is expected: a missing or repeated word yields a
    # shape that cannot be assigned to a single row, and this line raises.
    E_digits[index] = embeddings[np.isin(words, label, assume_unique=True)]

E_digits.shape

#### Target decoder as layer

In [None]:
class TargetDecoder(layers.Layer):
    """Map embedding vectors to the index of the most similar digit embedding.

    Similarity is cosine similarity: the inputs and the rows of the stored
    digit-embedding matrix are both L2-normalised before the dot product, and
    the index of the largest similarity is returned.

    Args:
        E_digits: matrix whose row i is the embedding of digit i, with shape
            (num_classes, embedding_dim). It is stored as a non-trainable
            float32 weight of the layer.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, E_digits, **kwargs):
        super().__init__(**kwargs)
        E_digits = np.asarray(E_digits, dtype='float32')
        self.E_digits = self.add_weight(
            shape=E_digits.shape,
            initializer=tf.constant_initializer(E_digits),
            trainable=False,
            dtype=tf.float32,
        )

    def call(self, inputs):
        """Return the decoded digit index for each input vector.

        Args:
            inputs: rank-2 tensor (batch, embedding_dim).

        Returns:
            int64 tensor of shape (batch,) with the index of the most similar
            row of the digit-embedding matrix. An all-zero input row has zero
            similarity with every digit, so its result is an arbitrary
            tie-break and carries no meaning.
        """
        # Use the layer's own weight rather than a module-level array, so the
        # layer is self-contained and decodes with the weights it was built with.
        similarity = tf.matmul(
            tf.math.l2_normalize(inputs, axis=1),
            tf.math.l2_normalize(self.E_digits, axis=1),
            transpose_b=True,
        )
        return tf.math.argmax(similarity, axis=-1)

In [None]:
# Initialize target decoder with the digit-name embeddings (vector -> digit index)
target_decoder = TargetDecoder(E_digits)

In [None]:
# Decode digit target from word: take the vector of 'five' and add a batch
# axis, giving a single-row matrix.
# NOTE(review): np.argmax(words == 'five') returns 0 if 'five' is absent, which
# would silently use the first word's vector instead of failing.
Q = embeddings[np.argmax(words == 'five')][np.newaxis, :]
target_decoder(Q)

#### Decoder block

In [None]:
# Length every text is padded/truncated to by the vectorizer.
MAX_TEXT_LENGTH = 20
# Vocabulary size: the loaded words plus two reserved ids (padding and [UNK]).
NUM_FEATURES = len(words) + 2

In [None]:
# Weights of embedding layer: row i is the vector for token id i.
# Rows 0 (padding) and 1 ([UNK]) stay zero; the word vectors start at row 2.
E = np.zeros((NUM_FEATURES, EMBEDDING_DIM))
# E[1] = np.random.normal(0, 0.1, EMBEDDING_DIM) # [UNK]
E[2:] = embeddings

In [None]:
def build_vectorizer_layer():
    """Return a TextVectorization layer whose vocabulary is the loaded word list.

    Texts become integer token-id sequences of length MAX_TEXT_LENGTH, with
    at most NUM_FEATURES distinct ids.
    """
    settings = dict(
        max_tokens=NUM_FEATURES,
        output_sequence_length=MAX_TEXT_LENGTH,
        output_mode="int",
    )
    text_vectorizer = layers.TextVectorization(**settings)
    # Use the pretrained words as the vocabulary instead of adapting on data.
    text_vectorizer.set_vocabulary(words)
    return text_vectorizer


def build_embedding_layer():
    """Return a frozen Embedding layer preloaded with the matrix ``E``.

    Token ids in [0, NUM_FEATURES) are mapped to EMBEDDING_DIM-dimensional
    vectors; the layer is not trainable, so the pretrained vectors stay fixed.
    """
    config = {
        'input_dim': NUM_FEATURES,
        'output_dim': EMBEDDING_DIM,
        'trainable': False,
    }
    glove_layer = layers.Embedding(**config)
    # Weights only exist after build(); create them so they can be overwritten.
    glove_layer.build((1, ))
    glove_layer.set_weights([E])
    return glove_layer

In [None]:
def build_text_encoder_on_glove():
    """Build a model that maps raw text to a digit index.

    Stages: tokenize into word ids, look up the frozen pretrained vectors,
    average them over the token axis, then pick the digit whose name vector is
    most similar (TargetDecoder). Padding/unknown ids map to zero vectors, which
    only rescale the average; TargetDecoder normalises it away.
    """
    stages = [
        build_vectorizer_layer(),          # str -> int token ids
        build_embedding_layer(),           # token ids -> word vectors
        layers.GlobalAveragePooling1D(),   # mean over the token axis
        TargetDecoder(E_digits),           # mean vector -> nearest digit index
    ]
    return models.Sequential(stages)

In [None]:
# Initialize text decoder: maps raw text to a digit index
text_decoder = build_text_encoder_on_glove()

In [None]:
# Example prompts. 'First digit' contains no digit name, so its target is just
# whichever digit name is most similar to the averaged word vectors.
texts = tf.constant([
    'Generate digit two', 
    'five', 
    'First digit'
])

# Decode texts into digit indices.
# NOTE: this rebinds `targets`, which earlier held the class labels from np.unique.
targets = text_decoder(texts)
targets

In [None]:
# Generate images: one image per decoded target, using the trained conditional GAN
generated_images = c_gan_model(targets)
display_images(generated_images)

In [None]:
class ImageDigitGenerator:
    """Turn text prompts into digit images.

    A text decoder maps prompts to digit targets; an image generator then
    maps those targets to images. Both are used as plain callables.
    """

    def __init__(self, text_decoder_model, image_generator_model):
        self.text_decoder = text_decoder_model
        self.image_generator = image_generator_model

    def generate(self, texts):
        """Return the images generated for ``texts``."""
        return self.image_generator(self.text_decoder(texts))

    def generate_and_display(self, texts):
        """Generate images for ``texts``, display them, and return them."""
        images = self.generate(texts)
        display_images(images)
        return images


In [None]:
# Text -> digit index (GloVe-based decoder) -> image (trained conditional GAN).
image_generator = ImageDigitGenerator(
    text_decoder_model=text_decoder,
    image_generator_model=c_gan_model
)

In [None]:
# Generate images from the example prompts, then show them.
generated_images = image_generator.generate(texts)
display_images(generated_images)

In [None]:
# Generate and show in a single call; the images are also returned.
generated_images = image_generator.generate_and_display(texts)
generated_images.shape

## Stable Diffusion Model (Pretrained)

In [None]:
# %pip install --upgrade keras-cv tensorflow-datasets

In [None]:
import keras_cv

In [None]:
# Stable diffusion v1.5 (it downloads weights on first call)
# Generated images are 512x512 pixels.
model = keras_cv.models.StableDiffusion(
    img_width=512,
    img_height=512
)

In [None]:
# Text prompt used by the generation cells below.
prompt = "photograph of an astronaut riding a horse"

**Pipeline**

In [None]:
# Generate images: the whole pipeline (prompt encoding + 20 diffusion steps)
# in one call, producing a batch of 3 images.
images = model.text_to_image(
    prompt=prompt,
    batch_size=3,
    num_steps=20
)

In [None]:
# Presumably (batch, height, width, channels) — verify.
images.shape

In [None]:
# NOTE(review): Stable Diffusion output is presumably uint8 in [0, 255], unlike
# the [0, 1] GAN outputs above — confirm display_images handles both.
display_images(images)

**Separate steps**

In [None]:
# Step 1: encode the prompt separately from image generation.
prompt_encoded = model.encode_text(prompt)
prompt_encoded.shape

In [None]:
# Step 2: run diffusion from the pre-encoded prompt. unconditional_guidance_scale
# is the classifier-free guidance scale (higher = follows the prompt more closely).
images = model.generate_image(
    encoded_text=prompt_encoded,
    batch_size=1,
    num_steps=20,
    unconditional_guidance_scale=7.5,
    negative_prompt=None
)
images.shape

In [None]:
# Show the first generated image without axes.
plt.imshow(images[0])
plt.axis('off')
plt.show()

## Sources

- [Conditional GAN](https://keras.io/examples/generative/conditional_gan/)
- [High-performance image generation using Stable Diffusion in KerasCV](https://www.tensorflow.org/tutorials/generative/generate_images_with_stable_diffusion)
- [Stable Diffusion 3 in KerasHub!](https://keras.io/keras_hub/guides/stable_diffusion_3_in_keras_hub/)