In [0]:
import numpy as np
import matplotlib.pyplot as plt
import random

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.layers import Input, Dense, Activation, Dropout, LeakyReLU, BatchNormalization
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Flatten, Reshape, ZeroPadding2D


In [0]:
# Image geometry (Keras channels-last order: rows, cols, channels) and latent-vector size.
z_dim = 100
width = 28
height = 28
channels = 1
img_shape = (height, width, channels)

# Seed every RNG the notebook draws from so reruns are comparable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

(train_x, train_y), (test_x, test_y) = keras.datasets.mnist.load_data()

# Add the channel axis and rescale pixels from [0, 255] to [-1, 1]. The generator ends in a
# tanh, so real images must live in the same range; with [0, 1] inputs the discriminator could
# tell real from fake simply by checking for negative pixel values.
train_x = (train_x.reshape(-1, height, width, channels).astype('float32') - 127.5) / 127.5
test_x = (test_x.reshape(-1, height, width, channels).astype('float32') - 127.5) / 127.5

In [0]:
# Optimizers. An optimizer keeps per-variable state, and recent Keras optimizers refuse to
# update variables they were not built for, so the discriminator and the stacked GAN each
# get their own instance. `learning_rate` replaces the deprecated/removed `lr` argument.
adam = Adam(learning_rate = 0.0002, beta_1 = 0.5)      # discriminator
gan_adam = Adam(learning_rate = 0.0002, beta_1 = 0.5)  # stacked gan (updates the generator)

# Generator: latent vector (z_dim,) -> 7x7x128 feature map -> two 2x upsamplings ->
# 28x28 single-channel image with tanh output in [-1, 1].
generator = keras.Sequential(
[
    Input(shape = (z_dim,)),
    Dense(128 * 7 * 7, activation = 'relu'),
    Reshape((7, 7, 128)),
    UpSampling2D(),   # 7x7 -> 14x14
    Conv2D(128, kernel_size = 3, padding = 'same', activation = 'relu'),
    BatchNormalization(momentum = 0.8),
    UpSampling2D(),   # 14x14 -> 28x28
    Conv2D(64, kernel_size = 3, padding = 'same', activation = 'relu'),
    BatchNormalization(momentum = 0.8),
    Conv2D(channels, kernel_size = 3, padding = 'same', activation = 'tanh')
])

# Discriminator: image -> strided convolutions -> single sigmoid probability that the input is real.
discriminator = keras.Sequential(
[
    Input(shape = img_shape),
    Conv2D(32, kernel_size = 3, strides = 2, padding = 'same'),
    LeakyReLU(alpha = 0.2),
    Conv2D(64, kernel_size = 3, strides = 2, padding = 'same'),
    BatchNormalization(momentum = 0.8),
    LeakyReLU(alpha = 0.2),
    Conv2D(128, kernel_size = 3, strides = 2, padding = 'same'),
    BatchNormalization(momentum = 0.8),
    LeakyReLU(alpha = 0.2),
    Conv2D(256, kernel_size = 3, strides = 1, padding = 'same'),
    BatchNormalization(momentum = 0.8),
    LeakyReLU(alpha = 0.2),
    Flatten(),
    Dense(1, activation = 'sigmoid')
])

# The generator is never fit directly (it is trained through `gan` below and only used for
# inference via predict), so it is intentionally not compiled.
discriminator.compile(loss = 'binary_crossentropy', optimizer = adam, metrics = ['accuracy'])

# Freeze the discriminator inside the stacked model so gan.train_on_batch only updates the generator.
discriminator.trainable = False

inputs = Input(shape = (z_dim,))
hidden = generator(inputs)
output = discriminator(hidden)
gan = keras.Model(inputs, output)
gan.compile(loss = 'binary_crossentropy', optimizer = gan_adam)

In [0]:
def _loss_values(history):
    '''
    Extract the loss from each entry of a ``train_on_batch`` history.

    @history: list of ``train_on_batch`` results. A model compiled with metrics
        returns [loss, metric, ...]; a model compiled without metrics returns a
        bare scalar loss. Both forms are accepted.
    @return: list of floats, one loss per entry.
    '''
    return [float(np.ravel(entry)[0]) for entry in history]


def plot_loss(losses):
    '''
    Plot the discriminator and generator losses recorded once per epoch by train().

    @losses: dict with keys 'D' and 'G'.
        losses['D'] holds the discriminator's [loss, accuracy] results (it is
        compiled with an accuracy metric).
        losses['G'] holds the stacked gan's results, which are bare scalar losses
        because that model is compiled without metrics.
    '''
    d_loss = _loss_values(losses['D'])
    g_loss = _loss_values(losses['G'])

    _, ax = plt.subplots(figsize = (10, 8))
    ax.plot(d_loss, label = 'Discriminator loss')
    ax.plot(g_loss, label = 'Generator loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('GAN training losses')
    ax.legend()
    plt.show()
    
def plot_generated(n_ex = 2, dim = (1, 2), figsize = (12, 2)):
    '''
    Sample ``n_ex`` latent vectors, run them through ``generator`` and show the images in a grid.

    @n_ex: number of images to generate.
    @dim: (rows, cols) of the subplot grid; must have room for ``n_ex`` images.
    @figsize: matplotlib figure size in inches.
    @raises ValueError: if the grid has fewer than ``n_ex`` cells.
    '''
    rows, cols = dim
    if n_ex > rows * cols:
        raise ValueError('dim=%r has room for %d images, but n_ex=%d' % (dim, rows * cols, n_ex))

    noise = np.random.normal(0, 1, size = (n_ex, z_dim))
    # verbose=0 keeps Keras' per-call progress bar out of the notebook output.
    generated_images = generator.predict(noise, verbose = 0)
    # Drop the channel axis: (n, rows, cols, 1) -> (n, rows, cols) for imshow.
    generated_images = generated_images.reshape(generated_images.shape[0], height, width)

    plt.figure(figsize = figsize)
    for i, image in enumerate(generated_images, start = 1):
        plt.subplot(rows, cols, i)
        plt.imshow(image, interpolation = 'nearest', cmap = 'gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [0]:
losses = {"D":[], "G":[]}  # per-epoch train_on_batch results, appended by train()
samples = []  # one batch of generated images per epoch, appended by train()

def train(epochs = 100, plt_frq = 10, batch_size = 32):
    '''
    Train the GAN by alternating discriminator and generator updates.

    Each epoch runs ``len(train_x) // batch_size`` steps. A step draws a random
    batch of real images (with replacement), trains the discriminator on the real
    images (label 1) plus an equal number of generated images (label 0), then
    trains the generator through the stacked ``gan`` model using "real" labels.

    @epochs: number of epochs to run.
    @plt_frq: print a header and plot samples on epoch 1 and every ``plt_frq``
        epochs (values <= 0 mean "epoch 1 only").
    @batch_size: number of real (and of generated) images per step.
    @raises ValueError: if ``batch_size`` exceeds the number of training images.

    Side effects: appends the last step's losses of every epoch to ``losses``,
    appends the last generated batch of every epoch to ``samples`` (storing every
    batch would grow without bound), plots samples, and finally plots the losses.
    '''
    batch_count = train_x.shape[0] // batch_size
    if batch_count == 0:
        raise ValueError('batch_size (%d) exceeds the number of training images (%d)'
                         % (batch_size, train_x.shape[0]))
    print('Epochs:', epochs)
    print('Batch size:', batch_size)
    print('Batches per epoch:', batch_count)

    # Discriminator targets: real images first (1), generated images second (0).
    d_labels = np.concatenate((np.ones(batch_size), np.zeros(batch_size)))
    # Generator targets: it is rewarded when the discriminator calls its output real.
    g_labels = np.ones(batch_size)

    for e in range(1, epochs + 1):
        report = e == 1 or (plt_frq > 0 and e % plt_frq == 0)
        if report:
            print('-'*15, 'Epoch %d' % e, '-'*15)

        for _ in range(batch_count):
            image_batch = train_x[np.random.randint(0, train_x.shape[0], size = batch_size)]

            # Generate fake images from noise. predict_on_batch avoids predict()'s per-call
            # dataset setup and progress-bar output, which dominate at this batch size.
            noise = np.random.normal(0, 1, size = (batch_size, z_dim))
            generated_images = generator.predict_on_batch(noise)
            imgs = np.concatenate((image_batch, generated_images))

            # Train the discriminator on real + generated images. The trainable flag is
            # toggled around each update so it is set correctly for the model being trained
            # (depending on the Keras version it may only be read at compile/trace time).
            discriminator.trainable = True
            d_loss = discriminator.train_on_batch(imgs, d_labels)

            # Train the generator through the stacked model with the discriminator frozen.
            noise = np.random.normal(0, 1, size = (batch_size, z_dim))
            discriminator.trainable = False
            g_loss = gan.train_on_batch(noise, g_labels)

        # Only store losses and samples from the final batch of each epoch.
        losses["D"].append(d_loss)
        losses["G"].append(g_loss)
        samples.append(generated_images)

        if report:
            plot_generated()
    plot_loss(losses)

In [0]:
# Run the full training schedule (explicit values match train()'s defaults).
train(epochs = 100, plt_frq = 10, batch_size = 32)