In [None]:
import tensorflow as tf 
from tensorflow.keras import layers

import numpy as np
from functools import partial

In [None]:
# Running means of the per-step losses; `train` prints and resets them every log_freq steps.
g_loss_metrics, d_loss_metrics, total_loss_metrics = (
    tf.metrics.Mean(name=metric_name)
    for metric_name in ('g_loss', 'd_loss', 'total_loss')
)

In [None]:
# --- Training schedule ------------------------------------------------------
ITERATION = 10_000          # total optimisation steps run by `train`
BATCH_SIZE = 512
BUFFER_SIZE = 60_000        # shuffle buffer; 60k matches the Fashion-MNIST training-set size

# --- Model / loss -----------------------------------------------------------
Z_DIM = 100                 # length of the latent vector fed to the generator
IMAGE_SHAPE = (28, 28, 1)   # (height, width, channels) of one image
GP_WEIGHT = 10.0            # multiplier on the gradient penalty added to the critic loss

# --- Optimisers -------------------------------------------------------------
G_LR = 4e-4
D_LR = 4e-4

# --- Reproducibility --------------------------------------------------------
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)

In [None]:
# Fixed latent vectors, presumably for previewing generator output on the same
# inputs over training (36 samples; likely a 6x6 grid; TODO confirm with the
# plotting code). They are drawn from the same Uniform[-1, 1] prior that
# `get_random_z` uses during training. A standard-normal draw would produce
# latent values outside the range the generator ever sees.
test_z = tf.random.uniform([36, Z_DIM], minval=-1, maxval=1)

In [None]:
def get_random_z(z_dim, batch_size):
    """Sample a ``(batch_size, z_dim)`` float32 batch of latent vectors from Uniform[-1, 1)."""
    latent_shape = [batch_size, z_dim]
    return tf.random.uniform(latent_shape, -1, 1)

In [None]:
def make_discriminator(input_shape):
    """Build the WGAN-GP critic, which maps an image batch to one unbounded score per image.

    Two strided 5x5 convolutions (64 then 128 filters, each followed by
    LeakyReLU and 30% dropout) downsample the input. A linear ``Dense(1)``
    head then produces the score. The head has no activation because the
    Wasserstein losses from ``get_loss_fn`` consume raw scores. The model
    contains no batch normalisation; the WGAN-GP paper advises against it in
    the critic because the gradient penalty is computed per example.

    Args:
        input_shape: Shape of a single image, without the batch dimension
            (e.g. ``IMAGE_SHAPE``).

    Returns:
        A ``tf.keras.Sequential`` critic model.
    """
    return tf.keras.Sequential([
        layers.Conv2D(64, 5, strides=2, padding='same',
                      input_shape=input_shape),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Conv2D(128, 5, strides=2, padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Flatten(),
        layers.Dense(1)
    ])


# Backward-compatible alias for the original, misspelled name, which existing callers still use.
make_discriminaor = make_discriminator

In [None]:
def make_generator(input_shape):
    """Build the generator, which maps latent vectors to 28x28x1 images in [-1, 1].

    A bias-free Dense layer is reshaped into a 7x7x256 feature map. Two
    transposed-conv stages follow (stride 1 -> 7x7x128, then stride 2 ->
    14x14x64), each with BatchNorm and LeakyReLU. A final stride-2 transposed
    conv with tanh yields 28x28x1.

    Args:
        input_shape: Shape of one latent vector, without the batch dimension
            (e.g. ``(Z_DIM,)``).

    Returns:
        A ``tf.keras.Sequential`` generator model.
    """
    def upsample_stage(filters, strides):
        # Transposed conv -> BatchNorm -> LeakyReLU. The conv has no bias
        # because BatchNorm supplies the shift.
        return [
            layers.Conv2DTranspose(
                filters, 5, strides=strides, padding='same', use_bias=False),
            layers.BatchNormalization(),
            layers.LeakyReLU(),
        ]

    stack = [
        layers.Dense(7 * 7 * 256, use_bias=False, input_shape=input_shape),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((7, 7, 256)),
        *upsample_stage(128, 1),
        *upsample_stage(64, 2),
        layers.Conv2DTranspose(
            1, 5, strides=2, padding='same', use_bias=False, activation='tanh'),
    ]
    return tf.keras.Sequential(stack)

In [None]:
def get_loss_fn():
    """Return the Wasserstein losses as the pair ``(d_loss_fn, g_loss_fn)``.

    ``d_loss_fn(real_logits, fake_logits)`` is ``mean(fake) - mean(real)``,
    which the critic minimises. ``g_loss_fn(fake_logits)`` is ``-mean(fake)``,
    which the generator minimises.
    """
    def critic_loss(real_logits, fake_logits):
        # Minimising this pushes real scores up and fake scores down.
        fake_mean = tf.reduce_mean(fake_logits)
        real_mean = tf.reduce_mean(real_logits)
        return fake_mean - real_mean

    def generator_loss(fake_logits):
        # The generator wants the critic to score its samples highly.
        return -tf.reduce_mean(fake_logits)

    return critic_loss, generator_loss

In [None]:
def gradient_penalty(generator, real_images, fake_images):
    """Compute the WGAN-GP gradient penalty.

    Random points are sampled on straight lines between paired real and fake
    images. The critic is penalised whenever the L2 norm of its gradient with
    respect to those points deviates from 1.

    Args:
        generator: Callable that scores a batch of images. Despite the name,
            this is the critic/discriminator (``train_step`` passes
            ``partial(D, training=True)``). The name is kept so existing
            callers keep working.
        real_images: Batch of real images, batch dimension first.
        fake_images: Batch of generated images with the same shape as
            ``real_images``.

    Returns:
        Scalar float32 tensor: the batch mean of ``(||grad|| - 1) ** 2``.
    """
    real_images = tf.cast(real_images, tf.float32)
    fake_images = tf.cast(fake_images, tf.float32)

    # Use one mixing coefficient per example, broadcast over every non-batch
    # axis. The shape comes from the actual batch rather than the global
    # BATCH_SIZE, so other batch sizes do not cause a shape mismatch.
    image_shape = tf.shape(real_images)
    alpha_shape = tf.concat(
        [image_shape[:1], tf.ones_like(image_shape[1:])], axis=0)
    alpha = tf.random.uniform(alpha_shape, 0., 1.)
    inter = real_images + alpha * (fake_images - real_images)

    with tf.GradientTape() as tape:
        tape.watch(inter)
        predictions = generator(inter)
    gradients = tape.gradient(predictions, inter)

    # Take the per-example L2 norm over all non-batch axes. The epsilon keeps
    # sqrt differentiable when a gradient is exactly zero. Without it, the
    # second-order gradient taken by the caller's tape would be NaN (0 * inf).
    reduce_axes = tf.range(1, tf.rank(gradients))
    slopes = tf.sqrt(
        tf.reduce_sum(tf.square(gradients), axis=reduce_axes) + 1e-12)
    return tf.reduce_mean((slopes - 1.) ** 2)

In [None]:
(train_x, _), (_, _) = tf.keras.datasets.fashion_mnist.load_data()
# Add the channel axis and convert to float32 up front. Subtracting a Python
# float from the uint8 array would otherwise promote it to float64, doubling
# memory and making the dataset yield float64 batches.
train_x = train_x.reshape((-1, *IMAGE_SHAPE)).astype('float32')
# Scale pixels from [0, 255] to [-1, 1], matching the generator's tanh output range.
train_x = (train_x - 127.5) / 127.5
train_ds = (
    tf.data.Dataset.from_tensor_slices(train_x)
    .shuffle(BUFFER_SIZE)
    # Every batch has exactly BATCH_SIZE examples, then the data loops forever.
    .batch(BATCH_SIZE, drop_remainder=True)
    .repeat()
)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [None]:
# Build the generator (latent vector -> image) and the critic (image -> score).
# G is created first so Keras' auto-generated layer names are assigned in the same order.
G = make_generator(input_shape=(Z_DIM,))
D = make_discriminaor(input_shape=IMAGE_SHAPE)

In [None]:
# Both networks use Adam with identical betas; the learning rate is configured per network.
# NOTE(review): the WGAN-GP paper uses beta_1=0.0, beta_2=0.9. Consider those
# values if training is unstable.
_make_adam = partial(tf.keras.optimizers.Adam, beta_1=0.5, beta_2=0.999)
g_optim = _make_adam(G_LR)
d_optim = _make_adam(D_LR)

In [None]:
# `get_loss_fn` returns the pair in (critic, generator) order.
(d_loss_fn,
 g_loss_fn) = get_loss_fn()

In [None]:
@tf.function
def train_step(real_images):
    """Run one simultaneous WGAN-GP update of the critic ``D`` and the generator ``G``.

    Both networks are updated from the same forward pass (one critic step per
    generator step). NOTE(review): the WGAN-GP paper uses several critic steps
    per generator step; consider that if training is unstable.

    Args:
        real_images: Batch of real images, batch dimension first.

    Returns:
        ``(g_loss, d_loss)`` scalar tensors. ``d_loss`` includes the
        ``GP_WEIGHT``-scaled gradient penalty.
    """
    real_images = tf.cast(real_images, tf.float32)
    # Size the latent batch from the real batch rather than assuming BATCH_SIZE.
    z = get_random_z(Z_DIM, tf.shape(real_images)[0])
    with tf.GradientTape() as d_tape, tf.GradientTape() as g_tape:
        fake_images = G(z, training=True)

        fake_logits = D(fake_images, training=True)
        real_logits = D(real_images, training=True)

        d_loss = d_loss_fn(real_logits, fake_logits)
        g_loss = g_loss_fn(fake_logits)

        # The penalty only enters the critic loss. Pause the generator tape so
        # it does not record the (second-order) penalty graph it never needs.
        # d_tape keeps recording so the penalty still trains D.
        with g_tape.stop_recording():
            gp = gradient_penalty(partial(D, training=True),
                                  real_images, fake_images)
        d_loss += gp * GP_WEIGHT

    d_gradients = d_tape.gradient(d_loss, D.trainable_variables)
    g_gradients = g_tape.gradient(g_loss, G.trainable_variables)

    d_optim.apply_gradients(zip(d_gradients, D.trainable_variables))
    g_optim.apply_gradients(zip(g_gradients, G.trainable_variables))

    return g_loss, d_loss

In [None]:
# training loop
def train(ds, log_freq=20, iterations=None):
    """Run the training loop, printing running-mean losses periodically.

    Args:
        ds: Iterable of real-image batches. It must yield at least
            ``iterations`` batches; ``train_ds`` repeats indefinitely.
            ``next`` raises ``StopIteration`` if it runs out.
        log_freq: Print, then reset, the running loss means every
            ``log_freq`` steps (step 0 included). Must be >= 1.
        iterations: Number of training steps. ``None`` means the ``ITERATION``
            constant, read at call time.

    Raises:
        ValueError: If ``log_freq`` is less than 1.
    """
    if log_freq < 1:
        raise ValueError(f"log_freq must be >= 1, got {log_freq}")
    if iterations is None:
        iterations = ITERATION

    def _reset(metric):
        # `reset_states` is deprecated in favour of `reset_state` (and absent in
        # newer Keras). Fall back to the old name for older TensorFlow releases.
        reset = getattr(metric, 'reset_state', None) or metric.reset_states
        reset()

    template = '[{}/{}] D_loss={:.5f} G_loss={:.5f} Total_loss={:.5f}'
    batches = iter(ds)
    for step in range(iterations):
        images = next(batches)
        g_loss, d_loss = train_step(images)

        g_loss_metrics(g_loss)
        d_loss_metrics(d_loss)
        total_loss_metrics(g_loss + d_loss)
        if step % log_freq == 0:
            print(template.format(step, iterations, d_loss_metrics.result(),
                                  g_loss_metrics.result(), total_loss_metrics.result()))
            for metric in (g_loss_metrics, d_loss_metrics, total_loss_metrics):
                _reset(metric)

In [None]:
if __name__ == "__main__":
    # Notebook kernels also run cells with __name__ == "__main__", so training starts when this cell executes.
    train(ds=train_ds)

[0/10000] D_loss=7.44371 G_loss=0.07432 Total_loss=7.51803
[20/10000] D_loss=4.76322 G_loss=-5.70243 Total_loss=-0.93921
[40/10000] D_loss=2.81742 G_loss=-12.15286 Total_loss=-9.33544
[60/10000] D_loss=2.04043 G_loss=-11.12871 Total_loss=-9.08827
[80/10000] D_loss=1.86185 G_loss=-9.90501 Total_loss=-8.04316
[100/10000] D_loss=1.26050 G_loss=-9.54142 Total_loss=-8.28092
[120/10000] D_loss=0.52433 G_loss=-6.88027 Total_loss=-6.35594
[140/10000] D_loss=1.00488 G_loss=-7.24561 Total_loss=-6.24073
[160/10000] D_loss=0.70468 G_loss=-3.80791 Total_loss=-3.10324
[180/10000] D_loss=0.34709 G_loss=-3.17050 Total_loss=-2.82341
[200/10000] D_loss=1.15397 G_loss=-5.76687 Total_loss=-4.61291
[220/10000] D_loss=0.63578 G_loss=-4.68936 Total_loss=-4.05359
[240/10000] D_loss=-1.45479 G_loss=-3.99399 Total_loss=-5.44879
[260/10000] D_loss=0.76380 G_loss=-4.99020 Total_loss=-4.22640
[280/10000] D_loss=0.70176 G_loss=-4.96434 Total_loss=-4.26258
[300/10000] D_loss=-0.06895 G_loss=-3.54576 Total_loss=-3.61