In [None]:
import tensorflow as tf
import matplotlib .pyplot as plt
import numpy as np
import tensorflow_datasets as tfds


# Enable memory growth on every GPU that TensorFlow reports.
gpus = tf.config.experimental.list_physical_devices('GPU')
for device in gpus:
    tf.config.experimental.set_memory_growth(device, enable=True)

In [None]:
# Display the list of physical GPU devices TensorFlow found.
gpus

In [None]:
# Load the Fashion-MNIST training split. Each element is a dict that holds
# an "image" and a "label" entry.
dataset = tfds.load("fashion_mnist", split="train")

# Walk the dataset element by element as NumPy values instead of pulling
# everything into memory at once.
data_iterator = dataset.as_numpy_iterator()

# Pull one element and discard it; the preview below starts from the next one.
next(data_iterator)

plt.figure(figsize=(6, 6))

# Show the next six samples in the first two rows of a 3x3 grid,
# titled with their label.
for position in range(1, 7):
    plt.subplot(3, 3, position)
    sample_img = next(data_iterator)
    plt.imshow(np.squeeze(sample_img["image"]))
    plt.title(sample_img["label"])

plt.show()

In [None]:
# We then create a function to normalize the images from 0-255 to 0-1

def scale_imgs(data):
    """Return the "image" entry of `data` divided by 255.

    For pixel values in the 0-255 range this rescales them to 0-1.
    Only the image is returned; every other entry of `data` is dropped.
    """
    return data["image"] / 255

# Preprocess the data loaded from tensorflow datasets as one pipeline:
# scale -> cache -> shuffle -> batch -> prefetch.
dataset = (
    dataset
    .map(scale_imgs)   # keep only the image, divided by 255
    .cache()
    .shuffle(60000)
    .batch(128)
    .prefetch(64)
)

# Shape of the first batch produced by the pipeline.
next(dataset.as_numpy_iterator()).shape

In [None]:
from tensorflow.keras import layers

In [None]:
def generator_model(latent_dim=128):
    """Build the generator network that turns a latent vector into an image.

    Args:
        latent_dim: Length of the random latent vector fed to the first Dense
            layer. Defaults to 128, the value that was previously hard-coded.

    Returns:
        An uncompiled `tf.keras.Sequential` model whose last layer is a
        single-channel convolution with a sigmoid activation.
    """
    layers = tf.keras.layers
    model = tf.keras.Sequential()

    # Project the latent vector onto 7*7*128 units and reshape them into a
    # 7x7 feature map with 128 channels.
    model.add(layers.Dense(7*7*128, input_dim=latent_dim))
    model.add(layers.LeakyReLU(0.2))
    model.add(layers.Reshape((7, 7, 128)))

    # Two upsampling blocks; each UpSampling2D doubles height and width
    # with its default size, taking 7x7 -> 14x14 -> 28x28.
    for _ in range(2):
        model.add(layers.UpSampling2D())
        model.add(layers.Conv2D(128, 5, padding="same"))
        model.add(layers.LeakyReLU(0.2))

    # Two convolutional blocks; "same" padding keeps the spatial size.
    for _ in range(2):
        model.add(layers.Conv2D(128, 4, padding="same"))
        model.add(layers.LeakyReLU(0.2))

    # Reduce to one channel since the target image is 28,28,1; the sigmoid
    # keeps outputs between 0 and 1.
    model.add(layers.Conv2D(1, 4, padding="same", activation="sigmoid"))

    return model


generator = generator_model()

generator.summary()

In [None]:
# Build the Discriminator model

def discriminator_model():
    """Build the discriminator network that scores 28x28x1 images.

    The model stacks three Conv2D -> LeakyReLU -> MaxPool2D -> Dropout blocks
    with 64, 128 and 256 filters, then flattens and ends in a single sigmoid
    unit.

    Returns:
        An uncompiled `tf.keras.Sequential` model with one output in (0, 1).
    """
    layers = tf.keras.layers
    model = tf.keras.Sequential()

    # LeakyReLU is added as an explicit layer instead of
    # activation="LeakyReLU": that string is not a documented activation
    # identifier and is not accepted by every Keras version. The layer's
    # default slope is kept so the network behaves as it did with the string.
    model.add(layers.Conv2D(filters=64, kernel_size=3, input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.MaxPool2D(pool_size=2, strides=2))
    model.add(layers.Dropout(0.4))

    for filters in (128, 256):
        model.add(layers.Conv2D(filters=filters, kernel_size=3))
        model.add(layers.LeakyReLU())
        model.add(layers.MaxPool2D(pool_size=2, strides=2))
        model.add(layers.Dropout(0.4))

    model.add(layers.Flatten())
    model.add(layers.Dense(units=1, activation="sigmoid"))

    return model

discriminator = discriminator_model()
discriminator.summary()

In [None]:
from tensorflow .keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.models import Model

In [None]:
# The discriminator uses a learning rate ten times smaller than the generator.
generator_optimizer = Adam(learning_rate=1e-4)
discriminator_optimizer = Adam(learning_rate=1e-5)

# Each network gets its own binary cross-entropy loss object.
generator_loss, discriminator_loss = BinaryCrossentropy(), BinaryCrossentropy()

In [None]:
class my_gan(Model):
    """Keras model that trains a generator and a discriminator together.

    Label convention used in `train_step`: 0 = real image, 1 = generated image.
    """

    def __init__(self, generator, discriminator, *args, latent_dim=128, **kwargs):
        """Store the two sub-models.

        Args:
            generator: Model mapping latent vectors to images.
            discriminator: Model mapping images to a single probability.
            latent_dim: Length of the latent vectors sampled for the
                generator. Keyword-only; defaults to 128.
            *args, **kwargs: Forwarded to `Model.__init__`.
        """
        super().__init__(*args, **kwargs)
        self.generator = generator
        self.discriminator = discriminator
        self.latent_dim = latent_dim

    def compile(self, generator_optimizer, discriminator_optimizer, generator_loss, discriminator_loss, *args, **kwargs):
        """Attach one optimizer and one loss function per sub-model.

        Remaining positional and keyword arguments go to `Model.compile`.
        """
        super().compile(*args, **kwargs)

        self.generator_optimizer = generator_optimizer
        self.discriminator_optimizer = discriminator_optimizer
        self.generator_loss = generator_loss
        self.discriminator_loss = discriminator_loss

    def _sample_latent(self, batch_size):
        """Draw `batch_size` standard-normal latent vectors of length `latent_dim`."""
        return tf.random.normal((batch_size, self.latent_dim))

    def train_step(self, batch):
        """Run one discriminator update followed by one generator update.

        Args:
            batch: A batch of real images.

        Returns:
            Dict with the "discriminator_loss" and "generator_loss" of this step.
        """
        real_images = batch
        # Use the size of the incoming batch rather than a fixed number so the
        # last, smaller batch of an epoch gets as many fakes as it has reals.
        batch_size = tf.shape(real_images)[0]
        fake_images = self.generator(self._sample_latent(batch_size), training=False)

        # ---- Discriminator update ----
        with tf.GradientTape() as d_tape:
            # training=True so that the dropout layers are active.
            yhat_real = self.discriminator(real_images, training=True)
            yhat_fake = self.discriminator(fake_images, training=True)
            yhat_realfake = tf.concat([yhat_real, yhat_fake], axis=0)

            # Target labels: zeros for real images, ones for fake images.
            y_realfake = tf.concat([tf.zeros_like(yhat_real), tf.ones_like(yhat_fake)], axis=0)

            # Add noise to the target labels: real targets are pushed up by
            # up to 0.5 and fake targets are pushed down by up to 0.58.
            # NOTE(review): the two magnitudes differ; confirm that is intended.
            real_noise = 0.5*tf.random.uniform(tf.shape(yhat_real))
            fake_noise = -0.58*tf.random.uniform(tf.shape(yhat_fake))
            y_realfake += tf.concat([real_noise, fake_noise], axis=0)

            total_discriminator_loss = self.discriminator_loss(y_realfake, yhat_realfake)

        # Backpropagation: only the discriminator's variables are updated here.
        discriminator_grad = d_tape.gradient(total_discriminator_loss, self.discriminator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(zip(discriminator_grad, self.discriminator.trainable_variables))

        # ---- Generator update ----
        with tf.GradientTape() as g_tape:
            generated_images = self.generator(self._sample_latent(batch_size), training=True)

            # The discriminator scores the new images in inference mode
            # (training=False). Its weights are not changed in this step
            # because gradients are taken only for the generator's variables.
            predicted_labels = self.discriminator(generated_images, training=False)

            # Real images are labelled zero, so the generator is trained
            # against an all-zeros target: its loss is low when the
            # discriminator takes the generated images for real ones.
            total_generator_loss = self.generator_loss(tf.zeros_like(predicted_labels), predicted_labels)

        # Backpropagation: only the generator's variables are updated here.
        generator_grad = g_tape.gradient(total_generator_loss, self.generator.trainable_variables)
        self.generator_optimizer.apply_gradients(zip(generator_grad, self.generator.trainable_variables))

        return {"discriminator_loss":total_discriminator_loss, "generator_loss":total_generator_loss}

In [None]:
# Wrap the generator and discriminator in a single "my_gan" model.
image_gan = my_gan(generator, discriminator)

# Compile it with one optimizer and one loss per sub-model.
image_gan.compile(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator_loss=generator_loss,
    discriminator_loss=discriminator_loss,
)

In [None]:
from tensorflow.keras.preprocessing.image import array_to_img

In [None]:
# Train the GAN on the preprocessed dataset for 200 epochs; the returned
# object keeps the recorded losses in `hist.history`.
hist = image_gan.fit(dataset, epochs=200)

In [None]:
# Display the recorded training history.
hist.history

In [None]:
# Plot both recorded loss curves on one figure.
plt.suptitle("Total_Model_Loss")
for loss_name in ("discriminator_loss", "generator_loss"):
    plt.plot(hist.history[loss_name], label=loss_name)
plt.legend()
plt.show()

## Use trained generator to generate images

In [None]:
# Sample 15 latent vectors of length 128 (the generator's input size) and
# generate one image per vector. The noise is shaped (15, 128) so it matches
# the generator input directly, without relying on Keras to squeeze a
# trailing singleton axis.
Test_image = generator.predict(tf.random.normal((15, 128)))

# Show the shape of the result rather than dumping every pixel value.
Test_image.shape

In [None]:
# Size the grid from the number of generated images so every image gets a
# valid subplot slot (a fixed 3x3 grid only has room for nine).
n_images = len(Test_image)
n_cols = 3
n_rows = max(1, (n_images + n_cols - 1) // n_cols)

plt.figure(figsize=(2 * n_cols, 2 * n_rows))

# A separate loop variable keeps `Test_image` intact, so the cell can be re-run.
for i, generated in enumerate(Test_image):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(np.squeeze(generated))
    plt.title(i)

plt.show()

## Save the models

In [None]:
# generator.save('generator.h5')
# discriminator.save('discriminator.h5')

## Load up saved model

In [None]:
# Note: before enabling this, import `os` and make the path match the file the model was saved to.
# generator.load_weights(os.path.join('saved_model', 'generator.h5'))