In [1]:
import os, time, math, glob
import imageio
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

In [2]:
# Load the Fashion-MNIST train and test splits as (images, labels) pairs.
fashion_mnist = keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

x_train.shape, x_test.shape

((60000, 28, 28), (10000, 28, 28))

In [3]:
# Preprocess data
def preprocess(data):
    """Scale images from [0, 255] to float32 in [-1, 1] and add a trailing channel axis.

    The [-1, 1] range matches the generator's tanh output, so real and
    generated images reach the discriminator on the same scale.

    Args:
        data: Array of images with values in [0, 255], e.g. ``(N, 28, 28)`` uint8.

    Returns:
        float32 array of shape ``data.shape + (1,)``.
    """
    norm_factor = 255. / 2.
    # Cast to float32 up front: uint8 minus a Python float would otherwise
    # produce float64, doubling memory for the full training set.
    scaled = (np.asarray(data, dtype=np.float32) - norm_factor) / norm_factor
    # np.newaxis (None) appends the channel axis, same as tf.newaxis.
    return scaled[..., np.newaxis]

def deprocess(data):
    """Undo the [-1, 1] scaling of ``preprocess``, mapping values back to [0, 255].

    No clipping or dtype conversion is performed, and the shape is unchanged.
    """
    half_range = 255. / 2.
    rescaled = data * half_range
    return rescaled + half_range

# Scale both splits identically; each gains a trailing channel axis.
x_train = preprocess(x_train)
x_test = preprocess(x_test)

x_train.shape, x_test.shape

((60000, 28, 28, 1), (10000, 28, 28, 1))

In [4]:
# Input-pipeline constants
BUFFER_SIZE = 60000  # shuffle buffer; equals the number of training images (full shuffle)
BATCH_SIZE = 256     # images per training batch

In [6]:
# Batch and shuffle dataset.
# cache() must come before shuffle(): caching after shuffle/batch replays the
# first epoch's order forever, so every epoch would see identical batches.
# prefetch() goes last; its buffer is counted in batches, so let tf.data tune it.
train_ds = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
# Evaluation data does not need shuffling.
test_ds = (
    tf.data.Dataset.from_tensor_slices(x_test)
    .cache()
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [7]:
# GAN model
class GAN:
    """DCGAN-style generator/discriminator pair for 28x28 single-channel images.

    The generator maps latent vectors to images in [-1, 1] (tanh output).
    The discriminator maps an image to one unnormalised logit.

    Args:
        latent_dims: Length of the latent (noise) vector fed to the generator.
        input_shape: Shape of one real image, e.g. ``(28, 28, 1)``.
        batch_size: Nominal training batch size, stored for reference.
            ``train_step`` draws one noise vector per real image in each batch,
            so a short final batch is handled correctly.
    """

    def __init__(self, latent_dims, input_shape, batch_size):
        self.latent_dims = latent_dims
        self.image_shape = input_shape
        self.batch_size = batch_size

        # The discriminator's last Dense layer has no activation, so the loss
        # is computed on logits.
        self.cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

        # One optimizer per network; each is updated only from its own gradients.
        self.generator_optimizer = keras.optimizers.Adam(1e-4)
        self.discriminator_optimizer = keras.optimizers.Adam(1e-4)

        self.generator = self.build_generator()
        self.discriminator = self.build_discriminator()

    def build_generator(self):
        """Build the generator network.

        Latent vector -> Dense -> 7x7x256 -> (stride 2) 14x14x128 -> 14x14x64
        -> (stride 2) 28x28x1, with tanh output in [-1, 1].
        """
        model = keras.Sequential([
            keras.layers.Dense(7*7*256, input_shape=(self.latent_dims, )),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
            keras.layers.Reshape((7, 7, 256)),
            keras.layers.Conv2DTranspose(128, (4, 4), strides=2, padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
            keras.layers.Conv2DTranspose(64, (4, 4), strides=1, padding='same'),
            keras.layers.BatchNormalization(),
            keras.layers.LeakyReLU(),
            keras.layers.Conv2DTranspose(1, (4, 4), strides=2, padding='same', activation='tanh')
        ])
        return model

    def build_discriminator(self):
        """Build the discriminator: two conv blocks with dropout, then a single-logit Dense."""
        model = keras.Sequential([
            keras.layers.Conv2D(64, (4, 4), strides=1, padding='same', input_shape=self.image_shape),
            keras.layers.LeakyReLU(),
            keras.layers.Dropout(0.3),
            keras.layers.Conv2D(128, (4, 4), strides=2, padding='same'),
            keras.layers.LeakyReLU(),
            keras.layers.Dropout(0.3),
            keras.layers.Flatten(),
            keras.layers.Dense(1)
        ])
        return model

    def generator_loss(self, fakes):
        """Non-saturating generator loss: BCE of the fake logits against a target of 1."""
        return self.cross_entropy(tf.ones_like(fakes), fakes)

    def discriminator_loss(self, reals, fakes):
        """Sum of BCE for real logits against 1 and fake logits against 0."""
        real_loss = self.cross_entropy(tf.ones_like(reals), reals)
        fake_loss = self.cross_entropy(tf.zeros_like(fakes), fakes)
        return real_loss + fake_loss

    @tf.function
    def train_step(self, images):
        """Run one optimisation step for both networks on a batch of real images.

        Returns:
            ``(generator_loss, discriminator_loss)`` for this step.
        """
        # Size the noise from the actual batch. Unless the dataset drops its
        # remainder, the last batch of an epoch is smaller than
        # self.batch_size, and a full batch of fakes would unbalance the losses.
        batch_size = tf.shape(images)[0]
        noise = tf.random.normal(shape=(batch_size, self.latent_dims))

        with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
            generated_images = self.generator(noise, training=True)

            real_output = self.discriminator(images, training=True)
            fake_output = self.discriminator(generated_images, training=True)

            generator_loss = self.generator_loss(fakes=fake_output)
            discriminator_loss = self.discriminator_loss(reals=real_output, fakes=fake_output)

        generator_gradients = generator_tape.gradient(generator_loss, self.generator.trainable_variables)
        discriminator_gradients = discriminator_tape.gradient(discriminator_loss, self.discriminator.trainable_variables)

        self.generator_optimizer.apply_gradients(zip(generator_gradients, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(zip(discriminator_gradients, self.discriminator.trainable_variables))
        return generator_loss, discriminator_loss

    def train(self, dataset, epochs, seed):
        """Train for ``epochs`` passes over ``dataset``, saving one sample grid per epoch.

        Args:
            dataset: Iterable of real-image batches.
            epochs: Number of passes over ``dataset``.
            seed: Fixed latent vectors passed to ``generate_and_save`` so the
                snapshots show the same samples at every epoch.

        Returns:
            True once training finishes.
        """
        for epoch in range(1, epochs + 1):
            start = time.time()
            for image_batch in dataset:
                self.train_step(image_batch)
            # Snapshot once per epoch, after all batches. Doing this inside the
            # batch loop would render and save a figure on every step.
            self.generate_and_save(epoch, seed)
            end = time.time()
            print(f'Epoch: {epoch}/{epochs} | Time: {math.ceil(end - start)} seconds')
        self.generate_and_save(epochs, seed, show=True)
        return True

    def generate_and_save(self, epoch, test, show=False, output_dir='generated-images'):
        """Render generator samples for ``test`` as a square grid and save it as a PNG.

        Args:
            epoch: Epoch number used in the file name. It is zero-padded so a
                lexicographic sort of the saved files follows epoch order.
            test: Batch of latent vectors, ``(n, latent_dims)``.
            show: If True, also display the figure after saving.
            output_dir: Existing directory the PNG is written to.
        """
        generated_images = self.generator(test, training=False)
        num_images = generated_images.shape[0]
        # Smallest square grid that holds every sample (4x4 for 16 samples).
        grid = math.ceil(math.sqrt(num_images))

        fig = plt.figure(figsize=(4, 4), dpi=200)
        for i in range(num_images):
            plt.subplot(grid, grid, i + 1)
            plt.imshow(deprocess(generated_images[i, :, :, 0]), cmap='gray')
            plt.axis('off')

        fig.savefig(os.path.join(output_dir, f'image_at_epoch__{epoch:04d}.png'))
        if show:
            plt.show()
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)

In [8]:
# Training constants definition
EPOCHS = 30
LATENT_DIMS = 100          # length of the generator's input noise vector
INPUT_SHAPE = (28, 28, 1)  # height, width, channels of one preprocessed image
NUM_EXAMPLES = 16          # samples drawn for each progress snapshot

In [9]:
# Create the directory for generated images (does nothing if it already exists)
os.makedirs('generated-images', exist_ok=True)

In [10]:
# Fixed latent vectors reused for every progress snapshot, so the saved grids
# show the same NUM_EXAMPLES samples evolving across epochs. Seed TensorFlow's
# global RNG first so this tensor is identical on every Restart & Run All.
tf.random.set_seed(42)
seed = tf.random.normal([NUM_EXAMPLES, LATENT_DIMS])

In [11]:
# Build the generator/discriminator pair together with their optimizers.
gan = GAN(
    input_shape=INPUT_SHAPE,
    latent_dims=LATENT_DIMS,
    batch_size=BATCH_SIZE,
)

In [12]:
# Layer-by-layer output shapes and parameter counts of the generator.
gan.generator.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 12544)             1266944   
                                                                 
 batch_normalization (BatchN  (None, 12544)            50176     
 ormalization)                                                   
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 12544)             0         
                                                                 
 reshape (Reshape)           (None, 7, 7, 256)         0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 14, 14, 128)      524416    
 nspose)                                                         
                                                                 
 batch_normalization_1 (Batc  (None, 14, 14, 128)      5

In [13]:
# Layer-by-layer output shapes and parameter counts of the discriminator.
gan.discriminator.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 28, 28, 64)        1088      
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 28, 28, 64)        0         
                                                                 
 dropout (Dropout)           (None, 28, 28, 64)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 14, 14, 128)       131200    
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 14, 14, 128)       0         
                                                                 
 dropout_1 (Dropout)         (None, 14, 14, 128)       0         
                                                                 
 flatten (Flatten)           (None, 25088)            

In [None]:
# Train the GAN; per-epoch timing is printed and a sample grid is saved each epoch.
# The trailing ';' stops Jupyter from echoing train()'s boolean return value.
gan.train(dataset=train_ds, epochs=EPOCHS, seed=seed);

In [15]:
# Create a gif with generated images
gif_file = 'gan-fashion-mnist-generated.gif'


def epoch_of(path):
    """Return the epoch number N encoded in an ``image_at_epoch__N.png`` path."""
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.rsplit('_', 1)[-1])


# Order frames by epoch number. A plain string sort puts epoch 10 before
# epoch 2 whenever the numbers in the file names are not zero-padded.
filenames = sorted(glob.glob('generated-images/image_at_epoch__*.png'), key=epoch_of)
if not filenames:
    raise FileNotFoundError('No images found in generated-images/; run the training cell first.')

with imageio.get_writer(gif_file, mode='I') as writer:
    for filename in filenames:
        writer.append_data(imageio.imread(filename))
    # Append the last frame a second time so the final result is shown for longer.
    writer.append_data(imageio.imread(filenames[-1]))
