In [104]:
! pip install -q tensorflow-gpu

In [105]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
import time

In [106]:
device_name = "/device:GPU:0"

In [107]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [108]:
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [109]:
def make_generator_model(z_size=20,
                             output_size=(28, 28, 1),
                             n_filters=128,
                             n_blocks=2):
  size_factor = 2**n_blocks
  hidden_size = (
      output_size[0]//size_factor, 
      output_size[1]//size_factor
  )
  
  model = tf.keras.Sequential([
      tf.keras.layers.Input(shape=(z_size,)),
      
      tf.keras.layers.Dense(
          units=n_filters*np.prod(hidden_size), 
          use_bias=False),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.LeakyReLU(),
      tf.keras.layers.Reshape(
          (hidden_size[0], hidden_size[1], n_filters)),
  
      tf.keras.layers.Conv2DTranspose(
          filters=n_filters, kernel_size=(5, 5), strides=(1, 1),
          padding='same', use_bias=False),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.LeakyReLU()
  ])
      
  nf = n_filters
  for i in range(n_blocks):
      nf = nf // 2
      model.add(
          tf.keras.layers.Conv2DTranspose(
              filters=nf, kernel_size=(5, 5), strides=(2, 2),
              padding='same', use_bias=False))
      model.add(tf.keras.layers.BatchNormalization())
      model.add(tf.keras.layers.LeakyReLU())
              
  model.add(
      tf.keras.layers.Conv2DTranspose(
          filters=output_size[2], kernel_size=(5, 5), 
          strides=(1, 1), padding='same', use_bias=False, 
          activation='tanh'))
      
  return model

In [110]:
def make_discriminator_model(input_shape=(28, 28, 1),
                             n_filters=64,
                             n_blocks=2):
  model = tf.keras.Sequential([
      tf.keras.layers.Input(shape=input_shape),
      tf.keras.layers.Conv2D(
          filters=n_filters, kernel_size=5, 
          strides=(1, 1), padding='same'),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.LeakyReLU()
  ])
  nf = n_filters
  for i in range(n_blocks):
      nf = nf*2
      model.add(
          tf.keras.layers.Conv2D(
              filters=nf, kernel_size=(5, 5), 
              strides=(2, 2),padding='same'))
      model.add(tf.keras.layers.BatchNormalization())
      model.add(tf.keras.layers.LeakyReLU())
      model.add(tf.keras.layers.Dropout(0.3))
      
  model.add(tf.keras.layers.Conv2D(
          filters=1, kernel_size=(7, 7), padding='valid'))

  model.add(tf.keras.layers.Reshape((1,)))

  return model


In [111]:
mnist_bldr = tfds.builder('mnist')
mnist_bldr.download_and_prepare()
mnist = mnist_bldr.as_dataset(shuffle_files=False)

In [112]:
def preprocess(ex, mode='uniform'):
  image = ex['image']
  image = tf.image.convert_image_dtype(image, tf.float32)

  image = image * 2 - 1.0
  if mode == 'uniform':
    input_z = tf.random.uniform(shape=(z_size,), minval=-1.0, maxval=1.0)
  elif mode == 'normal':
    input_z = tf.random.nomral(shape=(z_size,))
  return input_z, image

In [113]:
gen_model = make_generator_model()
gen_model.summary()

disc_model = make_discriminator_model()
disc_model.summary()

Model: "sequential_28"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_17 (Dense)             (None, 6272)              125440    
_________________________________________________________________
batch_normalization_97 (Batc (None, 6272)              25088     
_________________________________________________________________
leaky_re_lu_97 (LeakyReLU)   (None, 6272)              0         
_________________________________________________________________
reshape_27 (Reshape)         (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose_66 (Conv2DT (None, 7, 7, 128)         409600    
_________________________________________________________________
batch_normalization_98 (Batc (None, 7, 7, 128)         512       
_________________________________________________________________
leaky_re_lu_98 (LeakyReLU)   (None, 7, 7, 128)       

In [114]:
num_epochs = 100
batch_size = 128
image_size = (28, 28)
z_size = 20
mode_z = 'uniform' # options: 'uniform' and 'normal'
lambda_gp = 10.0

tf.random.set_seed(1)
np.random.seed(1)

In [115]:
mnist_trainset = mnist['train'].map(preprocess).shuffle(10000).batch(batch_size, drop_remainder=True)

In [116]:
def create_samples(g_model, input_z):
  g_output = g_model(input_z, training=False)
  images = tf.reshape(g_output, (batch_size, *image_size))
  return (images + 1) / 2.0

In [119]:
if mode_z == 'uniform':
    fixed_z = tf.random.uniform(
        shape=(batch_size, z_size),
        minval=-1, maxval=1)
elif mode_z == 'normal':
    fixed_z = tf.random.normal(
        shape=(batch_size, z_size))

In [120]:
with tf.device(device_name):
    gen_model = make_generator_model()
    gen_model.build(input_shape=(None, z_size))
    gen_model.summary()

    disc_model = make_discriminator_model()
    disc_model.build(input_shape=(None, np.prod(image_size)))
    disc_model.summary()

g_optimizer = tf.keras.optimizers.Adam(0.0002)
d_optimizer = tf.keras.optimizers.Adam(0.0002)

all_losses = []
epoch_samples = []

Model: "sequential_32"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_19 (Dense)             (None, 6272)              125440    
_________________________________________________________________
batch_normalization_111 (Bat (None, 6272)              25088     
_________________________________________________________________
leaky_re_lu_111 (LeakyReLU)  (None, 6272)              0         
_________________________________________________________________
reshape_31 (Reshape)         (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose_74 (Conv2DT (None, 7, 7, 128)         409600    
_________________________________________________________________
batch_normalization_112 (Bat (None, 7, 7, 128)         512       
_________________________________________________________________
leaky_re_lu_112 (LeakyReLU)  (None, 7, 7, 128)       

In [None]:
start_time = time.time()
for epoch in range(1, num_epochs + 1):
  epoch_losses = []
  for i, (input_z, input_real) in enumerate(mnist_trainset):
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
      g_output = gen_model(input_z)
      d_critics_real = disc_model(input_real, training=True)
      d_critics_fake = disc_model(g_output, training=True)
      g_loss = -tf.math.reduce_mean(d_critics_fake)

      d_loss_real = -tf.math.reduce_mean(d_critics_real)
      d_loss_fake = tf.math.reduce_mean(d_critics_fake)
      d_loss = d_loss_real + d_loss_fake
      
      with tf.GradientTape() as gp_tape:
        alpha = tf.random.uniform(shape=[d_critics_real.shape[0], 1, 1, 1], minval=0.0, maxval=1.0)
        interpolated = (alpha * input_real + (1 - alpha) * g_output)
        gp_tape.watch(interpolated)
        d_critics_intp = disc_model(interpolated)
      
      grads_intp = gp_tape.gradient(d_critics_intp, [interpolated, ])[0]
      grads_intp_l2 = tf.sqrt(tf.reduce_sum(tf.square(grads_intp), axis=[1, 2, 3]))
      grad_penalty = tf.reduce_mean(tf.square(grads_intp_l2 - 1.0))
      d_loss += lambda_gp * grad_penalty

    d_grads = d_tape.gradient(d_loss, disc_model.trainable_variables)
    d_optimizer.apply_gradients(grads_and_vars=zip(d_grads, disc_model.trainable_variables))
    
    g_grads = g_tape.gradient(g_loss, gen_model.trainable_variables)
    g_optimizer.apply_gradients(grads_and_vars=zip(g_grads, gen_model.trainable_variables))

    epoch_losses.append((g_loss, d_loss,
                         d_loss_real.numpy(),
                         d_loss_fake.numpy()))
  all_losses.append(epoch_losses)

  print('Epoch {:03d} | ET {:.2f} min | Avg Losses >> G/D {:6.2f}/{:6.2f} [D-Real: {:6.2f} D-Fake: {:6.2f}]'
  .format(epoch, (time.time() - start_time)/60, *list(np.mean(all_losses[-1], axis=0))))
  epoch_samples.append(create_samples(gen_model, fixed_z).numpy())

Epoch 001 | ET 3.38 min | Avg Losses >> G/D 313.21/-537.45 [D-Real: -312.98 D-Fake: -313.21]
Epoch 002 | ET 6.76 min | Avg Losses >> G/D 535.42/-677.91 [D-Real: -518.04 D-Fake: -535.42]
Epoch 003 | ET 10.13 min | Avg Losses >> G/D 170.77/109.16 [D-Real: -113.76 D-Fake: -170.77]
Epoch 004 | ET 13.50 min | Avg Losses >> G/D 124.14/ 46.52 [D-Real: -57.32 D-Fake: -124.14]
Epoch 005 | ET 16.86 min | Avg Losses >> G/D 105.50/-21.52 [D-Real:  -1.31 D-Fake: -105.50]
Epoch 006 | ET 20.23 min | Avg Losses >> G/D  83.08/-36.46 [D-Real: -10.38 D-Fake: -83.08]
Epoch 007 | ET 23.60 min | Avg Losses >> G/D 122.31/-31.51 [D-Real:  46.68 D-Fake: -122.31]
Epoch 008 | ET 26.98 min | Avg Losses >> G/D 155.76/-43.46 [D-Real:  93.31 D-Fake: -155.76]
Epoch 009 | ET 30.35 min | Avg Losses >> G/D 109.14/-31.11 [D-Real:  53.23 D-Fake: -109.14]
Epoch 010 | ET 33.72 min | Avg Losses >> G/D 150.26/-57.91 [D-Real:  74.24 D-Fake: -150.26]
Epoch 011 | ET 37.09 min | Avg Losses >> G/D 146.68/-35.58 [D-Real:  80.99 D-F

In [None]:
fig = plt.figure(figsize=(8, 6))


ax = fig.add_subplot(1, 1, 1)
g_losses = [item[0] for item in itertools.chain(*all_losses)]
d_losses = [item[1] for item in itertools.chain(*all_losses)]
plt.plot(g_losses, label='Generator loss', alpha=0.95)
plt.plot(d_losses, label='Discriminator loss', alpha=0.95)
plt.legend(fontsize=20)
ax.set_xlabel('Iteration', size=15)
ax.set_ylabel('Loss', size=15)

epochs = np.arange(1, 101)
epoch2iter = lambda e: e*len(all_losses[-1])
epoch_ticks = [1, 20, 40, 60, 80, 100]
newpos   = [epoch2iter(e) for e in epoch_ticks]
ax2 = ax.twiny()
ax2.set_xticks(newpos)
ax2.set_xticklabels(epoch_ticks)
ax2.xaxis.set_ticks_position('bottom')
ax2.xaxis.set_label_position('bottom')
ax2.spines['bottom'].set_position(('outward', 60))
ax2.set_xlabel('Epoch', size=15)
ax2.set_xlim(ax.get_xlim())
ax.tick_params(axis='both', which='major', labelsize=15)
ax2.tick_params(axis='both', which='major', labelsize=15)

plt.show()

In [None]:
selected_epochs = [1, 2, 4, 10, 50, 100]
fig = plt.figure(figsize=(10, 14))
for i,e in enumerate(selected_epochs):
    for j in range(5):
        ax = fig.add_subplot(6, 5, i*5+j+1)
        ax.set_xticks([])
        ax.set_yticks([])
        if j == 0:
            ax.text(
                -0.06, 0.5, 'Epoch {}'.format(e),
                rotation=90, size=18, color='red',
                horizontalalignment='right',
                verticalalignment='center', 
                transform=ax.transAxes)
        
        image = epoch_samples[e-1][j]
        ax.imshow(image, cmap='gray_r')

plt.show()