# Data Generation

## Generative Adversarial Networks

Resources:
* https://keras.io/examples/generative/conditional_gan/
  * See https://github.com/ipython/ipython/issues/10045 for last step (embedding gifs in notebook)

#### Imports

In [1]:
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import imageio

#### Constants and Hyperparameters

In [2]:
# Root directory of the training images.
train_path = "datasets/train"

# Side length of the (square) images and their number of colour channels.
image_size, num_channels = 224, 3

# Number of label classes used to condition the GAN.
num_classes = 3

# Length of the random latent vector fed to the generator.
latent_dim = 128

# Mini-batch size (64 is noted as an alternative value).
batch_size = 4

#### Loading & Pre-processing Dataset

In [3]:
# %run common.py

def make_dataset_generator():
    """Create a Keras directory iterator that yields augmented training batches.

    Images are read from ``train_path``, resized to ``image_size`` x
    ``image_size``, scaled by 1/255 and randomly augmented (brightness,
    rotation, shear, zoom, horizontal flip). Labels use
    ``class_mode='categorical'``.

    Returns:
        The iterator produced by ``ImageDataGenerator.flow_from_directory``,
        yielding ``(images, labels)`` batches.
    """
    augmentation = dict(
        rescale=1./255,
        brightness_range=(0.5, 1.5),
        rotation_range=20,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
    )
    augmenter = ImageDataGenerator(**augmentation)

    return augmenter.flow_from_directory(
        train_path,
        target_size=(image_size, image_size),
        batch_size=batch_size,
        class_mode='categorical',
    )

# For testing: pull a single batch from the directory iterator and print the
# shapes of the image and label arrays as a sanity check of the input pipeline.
train_generator = make_dataset_generator()
x_train, y_train = next(train_generator)
print(x_train.shape, y_train.shape)

def make_dataset():
    """Wrap the Keras directory iterator in a ``tf.data.Dataset``.

    The batch dimension of the output signature is left as ``None`` rather
    than ``batch_size``: the final batch of a pass over the directory is
    smaller whenever the number of images is not a multiple of
    ``batch_size``, and a fixed batch dimension makes ``from_generator``
    reject that batch.

    NOTE(review): the underlying Keras iterator appears to loop over the
    directory indefinitely, so consumers of this dataset probably need
    ``take()`` or ``steps_per_epoch`` -- confirm before relying on it.

    Returns:
        A dataset of ``(images, labels)`` float32 batches with shapes
        ``(None, image_size, image_size, num_channels)`` and
        ``(None, num_classes)``.
    """
    return tf.data.Dataset.from_generator(
        make_dataset_generator,
        output_signature=(
            tf.TensorSpec(
                shape=(None, image_size, image_size, num_channels),
                dtype=tf.float32,
            ),
            tf.TensorSpec(shape=(None, num_classes), dtype=tf.float32),
        ),
    )


#### Calculating number of input channels for generator & discriminator

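The generator takes the latent noise vector concatenated with the one-hot class label; the discriminator takes the image channels plus one extra channel per class.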

In [4]:
# Discriminator input depth: image channels plus one channel per class.
discriminator_in_channels = num_classes + num_channels
# Generator input length: latent vector plus the one-hot class label.
generator_in_channels = num_classes + latent_dim

#### Creating the generator & discriminator

In [5]:
def make_generator():
    """Build the generator: label-conditioned noise vector in, image out.

    The input vector of length ``generator_in_channels`` is projected to a
    14x14 feature map, upsampled by four stride-2 transposed convolutions
    (14 -> 28 -> 56 -> 112 -> 224) and finally mapped to ``num_channels``
    output channels with a sigmoid activation.

    Returns:
        A ``keras.Sequential`` model named ``"generator"``.
    """
    # Project the input to 14 * 14 * generator_in_channels coefficients and
    # reshape them into a 14x14xgenerator_in_channels feature map.
    stack = [
        keras.layers.InputLayer((generator_in_channels,)),
        layers.Dense(14 * 14 * generator_in_channels),
        layers.LeakyReLU(alpha=0.2),
        layers.Reshape((14, 14, generator_in_channels)),
    ]

    # Four upsampling stages, each doubling the spatial size, take the
    # 14x14 map to a 224x224x128 map.
    for _ in range(4):
        stack.append(
            layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same")
        )
        stack.append(layers.LeakyReLU(alpha=0.2))

    # Collapse the 128 feature channels into a 224x224xnum_channels image.
    stack.append(
        layers.Conv2D(num_channels, (7, 7), padding="same", activation="sigmoid")
    )

    return keras.Sequential(stack, name="generator")


def make_discriminator():
    """Build the discriminator: image plus label channels in, one score out.

    The input has ``discriminator_in_channels`` channels. Two stride-2
    convolutions (64 then 128 filters), each followed by a LeakyReLU, are
    reduced by global max pooling and a final ``Dense(1)`` with no
    activation, so the output is a raw (unactivated) score.

    Returns:
        A ``keras.Sequential`` model named ``"discriminator"``.
    """
    stack = [
        keras.layers.InputLayer((image_size, image_size, discriminator_in_channels)),
    ]

    # Strided convolution stages: 64 filters, then 128.
    for filters in (64, 128):
        stack.append(
            layers.Conv2D(filters, (3, 3), strides=(2, 2), padding="same")
        )
        stack.append(layers.LeakyReLU(alpha=0.2))

    # Pool over the spatial dimensions and emit a single score per sample.
    stack.append(layers.GlobalMaxPooling2D())
    stack.append(layers.Dense(1))

    return keras.Sequential(stack, name="discriminator")


# Check dimensions: build both networks, then print each one's layer summary
# followed by the shape of its input tensor.
generator = make_generator()
discriminator = make_discriminator()

for _network in (generator, discriminator):
    _network.summary()
    print("Input: ", _network.input.shape, "\n\n")

Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 25676)             3389232   
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 25676)             0         
_________________________________________________________________
reshape (Reshape)            (None, 14, 14, 131)       0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 28, 28, 128)       268416    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 28, 28, 128)       0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 56, 56, 128)       262272    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 56, 56, 128)       0 

#### Creating a `ConditionalGAN` model

In [6]:
class ConditionalGAN(keras.Model):
    """Conditional GAN that trains a generator and a discriminator together.

    Both networks are conditioned on the class label: the one-hot label is
    concatenated to the latent vector for the generator, and appended as
    constant extra channels to the image for the discriminator.

    Target convention for the discriminator: generated (fake) images are
    labelled 1 and real images are labelled 0.
    """

    def __init__(self, discriminator, generator, latent_dim):
        """Store the two sub-networks and create the running-loss trackers.

        Args:
            discriminator: Model scoring images concatenated with label
                channels.
            generator: Model mapping a latent vector concatenated with a
                one-hot label to an image.
            latent_dim: Length of the random latent vector.
        """
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    @property
    def metrics(self):
        """Return the generator and discriminator loss trackers."""
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def call(self, inputs):
        """Placeholder forward pass; training goes through ``train_step``.

        Returns:
            ``None``. Call ``self.generator`` directly to produce images.
        """
        return None

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Attach the optimizers and the loss used by ``train_step``.

        Args:
            d_optimizer: Optimizer applied to the discriminator weights.
            g_optimizer: Optimizer applied to the generator weights.
            loss_fn: Loss called as ``loss_fn(targets, predictions)``.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Args:
            data: Tuple ``(real_images, one_hot_labels)`` for one batch.

        Returns:
            Dict with the running mean losses under ``"g_loss"`` and
            ``"d_loss"``.
        """
        # Unpack the data.
        real_images, one_hot_labels = data

        # Broadcast each one-hot label over the spatial dimensions so that
        # channel c of the label map equals label[c] at every pixel and can
        # be concatenated with the images for the discriminator.
        # (Flattening with tf.repeat and reshaping to (-1, H, W, C) would
        # scramble the channel axis, because the repeated values come out in
        # channel-major order.)
        image_one_hot_labels = tf.tile(
            one_hot_labels[:, None, None, :], [1, image_size, image_size, 1]
        )

        # Sample random points in the latent space and concatenate the labels.
        # The batch size is read from the data so a short final batch works.
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        random_vector_labels = tf.concat(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Decode the noise (guided by labels) to fake images.
        generated_images = self.generator(random_vector_labels)

        # Append the label channels to both fake and real images, then stack
        # them into one batch: fakes first, reals second.
        fake_image_and_labels = tf.concat([generated_images, image_one_hot_labels], -1)
        real_image_and_labels = tf.concat([real_images, image_one_hot_labels], -1)
        combined_images = tf.concat(
            [fake_image_and_labels, real_image_and_labels], axis=0
        )

        # Targets matching that order: 1 for fake images, 0 for real images.
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )

        # Train the discriminator.
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample fresh random points in the latent space for the generator step.
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        random_vector_labels = tf.concat(
            [random_latent_vectors, one_hot_labels], axis=1
        )

        # Targets that say "all real images" (0 under the convention above).
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)! Gradients are taken with respect to the
        # generator weights only.
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_vector_labels)
            fake_image_and_labels = tf.concat([fake_images, image_one_hot_labels], -1)
            predictions = self.discriminator(fake_image_and_labels)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Monitor loss.
        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }

    def get_config(self):
        """Return ``latent_dim`` and the current mean losses as a dict.

        The loss values are read with ``.numpy()`` from the two trackers.
        """
        return {
            "latent_dim": self.latent_dim,
            "gen_loss_tracker": self.gen_loss_tracker.result().numpy(),
            "disc_loss_tracker": self.disc_loss_tracker.result().numpy()
        }

#### Training the Conditional GAN

> I accidentally disconnected the Jupyter notebook halfway, which is why the output didn't get streamed back here.

In [None]:
# Wrap the generator and discriminator built earlier in the conditional GAN.
cond_gan = ConditionalGAN(
    discriminator=discriminator, generator=generator, latent_dim=latent_dim
)
# Both networks use Adam with the same learning rate. The loss is configured
# with from_logits=True because the discriminator's final Dense(1) layer has
# no activation.
cond_gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

# Train for 20 epochs directly on a fresh Keras directory iterator.
cond_gan.fit(make_dataset_generator(), epochs=20)

Found 274 images belonging to 3 classes.
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20

#### Saving the model

In [52]:
# %%capture output

# Show the config dict: latent_dim plus the current mean losses.
print(cond_gan.get_config())

# Alternative ways of persisting the model, kept for reference:
# cond_gan.save_weights("models/gan_1_weights/")
# cond_gan.discriminator.save("models/gan_1/discriminator/")
# cond_gan.generator.save("models/gan_1/generator/")
# Save the whole GAN to disk.
# NOTE(review): save_traces=False presumably avoids tracing the model's
# functions, which matters because `call` is only a placeholder -- confirm
# against the Keras `Model.save` documentation.
cond_gan.save("models/gan_1/", save_traces=False)

# output.show()

{'latent_dim': 128, 'gen_loss_tracker': 1.1581442, 'disc_loss_tracker': 0.57097614}
INFO:tensorflow:Assets written to: models/gan_1/assets


In [53]:
# Reload the saved GAN from disk and display the generator's layer config to
# check that the architecture survived the save/load round trip.
reloaded_model = tf.keras.models.load_model("models/gan_1")

reloaded_model.generator.get_config()

{'name': 'generator',
 'layers': [{'class_name': 'InputLayer',
   'config': {'batch_input_shape': (None, 131),
    'dtype': 'float32',
    'sparse': False,
    'ragged': False,
    'name': 'input_1'}},
  {'class_name': 'Dense',
   'config': {'name': 'dense',
    'trainable': True,
    'dtype': 'float32',
    'units': 25676,
    'activation': 'linear',
    'use_bias': True,
    'kernel_initializer': {'class_name': 'GlorotUniform',
     'config': {'seed': None}},
    'bias_initializer': {'class_name': 'Zeros', 'config': {}},
    'kernel_regularizer': None,
    'bias_regularizer': None,
    'activity_regularizer': None,
    'kernel_constraint': None,
    'bias_constraint': None}},
  {'class_name': 'LeakyReLU',
   'config': {'name': 'leaky_re_lu',
    'trainable': True,
    'dtype': 'float32',
    'alpha': 0.20000000298023224}},
  {'class_name': 'Reshape',
   'config': {'name': 'reshape',
    'trainable': True,
    'dtype': 'float32',
    'target_shape': (14, 14, 131)}},
  {'class_name': '

#### Verifying the trained generator

#### Next Steps

1. Verify the trained generator's output
1. `ModelCheckpoint` (or some other means of saving the generator)


GAN from Lecture Notes, for reference.

```python
# Length of one image flattened to a vector (height * width * channels).
num_pixels = image_size * image_size * num_channels

def create_generator():
    """Build and compile the fully connected reference generator.

    Maps a 100-dimensional input through Dense layers of 256 and 512 units
    (each followed by LeakyReLU(0.2)) to ``num_pixels`` outputs with a tanh
    activation.

    Returns:
        The compiled ``Sequential`` generator.
    """
    generator = Sequential()
    generator.add(Dense(units = 256, input_dim = 100))
    generator.add(LeakyReLU(0.2))
    generator.add(Dense(units = 512))
    generator.add(LeakyReLU(0.2))
    # One output per pixel value of the flattened image.
    generator.add(Dense(units = num_pixels, activation = 'tanh'))
    # `adam_optimizer` is not defined in this notebook; it comes from the
    # lecture notes this snippet was copied from.
    generator.compile(loss = 'binary_crossentropy', optimizer = adam_optimizer())

    return generator

def create_discriminator():
    """Build and compile the fully connected reference discriminator.

    Takes a flattened image of ``num_pixels`` values through Dense layers of
    1024, 512 and 256 units (LeakyReLU(0.2) after each, Dropout(0.3) after
    the first two) to a single sigmoid output.

    Returns:
        The compiled ``Sequential`` discriminator.
    """
    discriminator = Sequential()
    discriminator.add(Dense(units = 1024, input_dim = num_pixels))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))

    discriminator.add(Dense(units = 512))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))
    
    discriminator.add(Dense(units = 256))
    discriminator.add(LeakyReLU(0.2))
    
    # Single sigmoid unit, trained with binary cross-entropy below.
    discriminator.add(Dense(units = 1, activation = 'sigmoid'))

    # `adam_optimizer` is not defined in this notebook; it comes from the
    # lecture notes this snippet was copied from.
    discriminator.compile(loss = 'binary_crossentropy', optimizer = adam_optimizer())
    return discriminator
```

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=ddd5f9cd-14c1-4462-aac1-a464a84065be' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>