# Self GAN Keras TPU

<table class="tfo-notebook-buttons" align="left" >
 <td>
    <a target="_blank" href="https://colab.research.google.com/github/HighCWu/SelfGAN/blob/master/implementations/gan/self_gan_keras_tpu.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/HighCWu/SelfGAN/blob/master/implementations/gan/self_gan_keras_tpu.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

In [0]:
# Install a TensorFlow release newer than 1.12 and older than 2.0 (later cells
# import from tensorflow.contrib); -q keeps pip's output quiet.
! pip install 'tensorflow>1.12,<2.0' -q

## Prepare

In [0]:
from tensorflow import keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Lambda
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
from tensorflow.python.keras.layers.convolutional import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam

import tensorflow as tf
import tensorflow.keras.backend as K

import matplotlib.pyplot as plt

import os
import sys

import numpy as np

# Output directory for the sample grids written by sample_images().
os.makedirs('images', exist_ok=True)

# Image geometry and the derived shape used for the Keras image inputs.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)

# Length of the noise vector fed to the generator.
latent_dim = 100

# Training schedule: samples per step, how often (in steps) a sample grid is
# saved, and the total number of training steps run by the training loop.
batch_size = 64
sample_interval = 1000
epochs = 200000

In [0]:
class Generator:
    """Ordered list of Keras layers that maps a latent vector to an image.

    The object is not a Keras ``Model``; calling it applies the stored layers
    one after another to the given tensor. The final ``Reshape`` layer is named
    ``'output'`` so that ``extra_outputs`` (which matches on layer names) can
    find the generated image tensor.
    """

    def __init__(self):
        # First stage also declares the expected input width.
        self.layers = [
            Dense(256, input_dim=latent_dim),
            LeakyReLU(alpha=0.2),
            BatchNormalization(momentum=0.8),
        ]
        # Two further Dense -> LeakyReLU -> BatchNorm stages of growing width.
        for units in (512, 1024):
            self.layers.extend([
                Dense(units),
                LeakyReLU(alpha=0.2),
                BatchNormalization(momentum=0.8),
            ])
        # Project to one value per pixel in [-1, 1] and fold into image shape.
        self.layers.extend([
            Dense(np.prod(img_shape), activation='tanh'),
            Reshape(img_shape, name='output'),
        ])

    def __call__(self, x):
        """Apply every stored layer in order to ``x`` and return the result."""
        tensor = x
        for apply_layer in self.layers:
            tensor = apply_layer(tensor)
        return tensor

class Discriminator:
    """Ordered list of Keras layers that maps an image to a validity score.

    Like ``Generator``, this is a plain callable rather than a Keras ``Model``.
    The same layer instances are reused on every call, so every tensor passed
    through one ``Discriminator`` object shares its weights.
    """

    def __init__(self):
        self.layers = [
            Flatten(input_shape=img_shape),
            Dense(512),
            LeakyReLU(alpha=0.2),
            Dense(256),
            LeakyReLU(alpha=0.2),
            # Single sigmoid unit: one score in (0, 1) per input image.
            Dense(1, activation='sigmoid'),
        ]

    def __call__(self, x):
        """Apply every stored layer in order to ``x`` and return the result."""
        tensor = x
        for apply_layer in self.layers:
            tensor = apply_layer(tensor)
        return tensor
  
def SelfGAN():
    """Build the single-model Self GAN graph.

    The returned Keras ``Model`` takes five inputs, in this order:
    ``[noise, real_img, fake_img, valid, fake]`` where ``valid`` / ``fake`` are
    the target labels fed alongside the images. Its only declared output is
    ``s_loss``, a weighted combination of three binary cross-entropy terms
    (generated images vs. ``valid``, real images vs. ``valid``, fed-in fake
    images vs. ``fake``).

    The per-term losses are wrapped in identity Lambda layers named
    ``'gen_loss'``, ``'real_loss'`` and ``'fake_loss'`` and the generator ends
    in a layer named ``'output'``; ``extra_outputs`` relies on those names, so
    they must not be changed.

    Returns:
        A ``Model`` mapping the five inputs to ``[s_loss]``.
    """
    generator = Generator()
    discriminator = Discriminator()

    real_img = Input(shape=img_shape)
    fake_img = Input(shape=img_shape)

    noise = Input(shape=(latent_dim,))
    gen_img = generator(noise)

    # The same discriminator layers score all three image sources.
    validity_gen = discriminator(gen_img)
    validity_real = discriminator(real_img)
    validity_fake = discriminator(fake_img)

    # compute loss
    # The layer is called with [prediction, target]. Keras' signature is
    # binary_crossentropy(y_true, y_pred), so the target (x[1]) goes first;
    # passing them the other way round treats the labels as predictions.
    adversarial_loss = Lambda(lambda x: keras.losses.binary_crossentropy(x[1], x[0]))

    valid = Input(shape=(1,))
    fake = Input(shape=(1,))
    gen_loss = adversarial_loss([validity_gen, valid])
    real_loss = adversarial_loss([validity_real, valid])
    fake_loss = adversarial_loss([validity_fake, fake])
    # Identity layers whose only purpose is to give the loss tensors names
    # that extra_outputs can look up.
    gen_loss = Lambda(lambda x: x*1.0, name='gen_loss')(gen_loss)
    real_loss = Lambda(lambda x: x*1.0, name='real_loss')(real_loss)
    fake_loss = Lambda(lambda x: x*1.0, name='fake_loss')(fake_loss)

    # Scalar weights: each is the batch-mean distance between the
    # discriminator's score and the label used for that term.
    v_g = Lambda(lambda x: 1 - K.mean(x))(validity_gen)
    v_r = Lambda(lambda x: 1 - K.mean(x))(validity_real)
    v_f = Lambda(lambda x: K.mean(x))(validity_fake)
    v_sum = Lambda(lambda x: x[0]+x[1]+x[2])([v_g,v_r,v_f])
    # Weighted sum of the three losses, weights normalised by v_sum (x[0]).
    s_loss = Lambda(lambda x: x[2]*x[1]/x[0]
                            + x[4]*x[3]/x[0]
                            + x[6]*x[5]/x[0])([v_sum, v_r, real_loss, v_g, gen_loss, v_f, fake_loss])

    return Model([noise, real_img, fake_img, valid, fake], [s_loss])
  
def sample_images(model, epoch):
    """Save a 5x5 grid of generated images to ``images/<epoch>.png``.

    Args:
        model: Object whose ``predict`` returns a list; the last entry is used
            as the batch of generated images.
        epoch: Integer used to build the output file name.

    Relies on the module-level ``batch_size``, ``latent_dim``, ``last_imgs``,
    ``valid`` and ``fake`` to assemble the prediction inputs.
    """
    rows, cols = 5, 5
    z = np.random.normal(0, 1, (batch_size, latent_dim))
    predictions = model.predict([z, last_imgs, last_imgs, valid, fake])
    # A full batch is predicted; only enough images to fill the grid are kept.
    grid_imgs = predictions[-1][:rows * cols]

    # Rescale images 0 - 1
    grid_imgs = 0.5 * grid_imgs + 0.5

    fig, axs = plt.subplots(rows, cols)
    for cnt in range(rows * cols):
        i, j = divmod(cnt, cols)  # row-major position of image `cnt`
        axs[i, j].imshow(grid_imgs[cnt, :, :, 0], cmap='gray')
        axs[i, j].axis('off')
    fig.savefig("images/%d.png" % epoch)
    plt.close()


In [0]:
# Function override
from tensorflow.contrib.tpu.python.tpu.keras_support import TPUFunction
from tensorflow.keras.models import Model
from tensorflow.python.estimator import model_fn as model_fn_lib 
ModeKeys = model_fn_lib.ModeKeys

def extra_outputs(self):
  """Collect output tensors of layers selected by name.

  Args:
    self: Object exposing ``layers``; each layer has ``name`` and ``output``.

  Returns:
    A list with the ``output`` of every layer whose name contains ``'loss'``
    (in layer order), followed by the ``output`` of every layer whose name
    contains ``'output'`` (in layer order). A layer matching both substrings
    appears in both parts.
  """
  loss_tensors = [layer.output for layer in self.layers if 'loss' in layer.name]
  named_outputs = [layer.output for layer in self.layers if 'output' in layer.name]
  return loss_tensors + named_outputs

def _make_predict_function(self):
  """Build ``self.predict_function`` once, fetching extra tensors as well.

  The backend function returns the model's declared outputs followed by the
  tensors from ``extra_outputs(self)``. If ``predict_function`` is already
  set to something other than ``None``, nothing is rebuilt.
  """
  if not hasattr(self, 'predict_function'):
    self.predict_function = None
  if self.predict_function is not None:
    return

  feed_inputs = self._feed_inputs
  # Gets network outputs. Does not update weights.
  # Does update the network states.
  fn_kwargs = getattr(self, '_function_kwargs', {})
  with K.name_scope(ModeKeys.PREDICT):
    fetches = self.outputs + extra_outputs(self)
    self.predict_function = K.function(
        feed_inputs,
        fetches,
        updates=self.state_updates,
        name='predict_function',
        **fn_kwargs)
      
def _make_fit_function(self):
  """Build the fit function so that it also fetches ``extra_outputs(self)``.

  The fetch list is: total loss, then the stateful metric tensor for every
  name in ``self.metrics_names`` except the first (presumably the loss entry
  itself; confirm against Keras), then the extra name-selected tensors.
  """
  stateful_metric_tensors = []
  for metric_name in self.metrics_names[1:]:
    stateful_metric_tensors.append(self._all_stateful_metrics_tensors[metric_name])
  fetches = [self.total_loss] + stateful_metric_tensors + extra_outputs(self)
  self._make_train_function_helper('_fit_function', fetches)
  
# Monkey-patch Keras' Model with the functions defined above so that its
# predict and fit functions also fetch the tensors chosen by extra_outputs
# (layers whose names contain 'loss' or 'output').
Model._make_predict_function = _make_predict_function
Model._make_fit_function = _make_fit_function

def _process_outputs(self, outfeed_outputs):
    """Aggregate the sharded outputs of a model function execution.

    ``outfeed_outputs`` is a flat list in which each replica contributes
    ``len(self._outfeed_spec)`` consecutive entries. The entries are regrouped
    per output and then merged across replicas.

    Args:
      outfeed_outputs: The sharded outputs of the TPU computation.
    Returns:
      One aggregated value per output. In PREDICT mode every output is
      concatenated across replicas. In any other mode, outputs whose first
      shard has a non-empty shape are concatenated, while scalar outputs take
      the first replica's value only.
    """
    # TODO(xiejw): Decide how to reduce outputs, or discard all but first.
    predicting = self.execution_mode == ModeKeys.PREDICT

    per_replica = len(self._outfeed_spec)
    # gathered[j] collects output j from every replica, in replica order.
    gathered = [[] for _ in range(per_replica)]
    for replica in range(self._tpu_assignment.num_towers):
        offset = replica * per_replica
        shard = outfeed_outputs[offset:offset + per_replica]
        for j, bucket in enumerate(gathered):
            bucket.append(shard[j])

    if predicting:
        return [np.concatenate(bucket) for bucket in gathered]

    return [
        np.concatenate(bucket) if len(bucket[0].shape) > 0 else bucket[0]
        for bucket in gathered
    ]
    
# Monkey-patch TPUFunction so it aggregates replica outputs with the
# _process_outputs function defined above.
TPUFunction._process_outputs = _process_outputs

In [0]:
# Reset the Keras backend session before building the model.
tf.keras.backend.clear_session()

# Build the combined model. Its only output is the s_loss tensor, which is
# trained towards the all-zero targets in `s_loss_zeros` with an MAE loss.
optimizer = Adam(0.0002, 0.5)
model = SelfGAN()
model.compile(optimizer=optimizer, loss='mae')

# Convert the compiled model to a TPU model; the TPU address comes from the
# Colab runtime's COLAB_TPU_ADDR environment variable.
TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']
tpu_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)
tpu_strategy = tf.contrib.tpu.TPUDistributionStrategy(tpu_resolver)
model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=tpu_strategy)

# Load the dataset (only the training images are used)
(X_train, _), (_, _) = mnist.load_data()

# Rescale -1 to 1 and append a trailing channel axis
X_train = np.expand_dims(X_train / 127.5 - 1., axis=3)

# Adversarial ground truths
valid, fake = np.ones((batch_size, 1)), np.zeros((batch_size, 1))

# Placeholder "previous generated images" for the first step, and the
# all-zero training target for the s_loss output.
last_imgs = np.zeros((batch_size,) + img_shape)
s_loss_zeros = np.zeros(batch_size)

In [0]:
def initialize_uninitialized_variables():
    """Run an initializer for only those global variables still uninitialized.

    Asks the current Keras session which variable names are reported as
    uninitialized, selects the matching entries of ``tf.global_variables()``
    (compared by name without the ``:<index>`` suffix) and runs an
    initializer for just those.
    """
    session = K.get_session()
    pending_names = {
        raw_name.decode('ascii')
        for raw_name in session.run(tf.report_uninitialized_variables())
    }
    pending_vars = [
        var for var in tf.global_variables()
        if var.name.split(':')[0] in pending_names
    ]
    session.run(tf.variables_initializer(pending_vars))
initialize_uninitialized_variables()

# Init tpu model: run one training step and one prediction on placeholder
# inputs (last_imgs is used for both image inputs).
noise = np.random.normal(0, 1, (batch_size, latent_dim))
warmup_inputs = [noise, last_imgs, last_imgs, valid, fake]
train_outputs = model.train_on_batch(warmup_inputs, [s_loss_zeros])
predict_output = model.predict(warmup_inputs)

In [0]:
# Each iteration is a single training step on one batch, not a pass over the
# whole dataset.
for epoch in range(epochs):

    # ---------------------
    #  Joint training step
    # ---------------------

    # Select a random batch of real images
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgs = X_train[idx]

    noise = np.random.normal(0, 1, (batch_size, latent_dim))

    # One optimizer step on the combined model. The images generated in the
    # previous step (`last_imgs`) are fed back in as the "fake" image input.
    # With the patched fit function, `outputs` holds the total loss followed
    # by the extra tensors picked by extra_outputs; the last entry is the
    # generator's output batch.
    outputs = model.train_on_batch([noise, imgs, last_imgs, valid, fake], [s_loss_zeros])
    # NOTE(review): the divisor 8 presumably corresponds to the number of TPU
    # replicas - confirm before changing the hardware or batch size.
    s_loss = outputs[0]/8
    gen_loss = np.mean(outputs[2])/(batch_size/8)
    real_loss = np.mean(outputs[1])/(batch_size/8)
    fake_loss = np.mean(outputs[3])/(batch_size/8)
    last_imgs = outputs[-1]

    # Plot the progress
    if epoch % 200 == 0:
        # flush=True pushes the line out right after it is written; flushing
        # before the print would only emit the previous report.
        print("\r%d [S loss: %f  G loss: %f R loss: %f  F loss: %f]" % (epoch, s_loss,
                                                                        gen_loss, real_loss, fake_loss),
              end='', flush=True)

    # If at save interval => save generated image samples
    if epoch % sample_interval == 0:
        sample_images(model, epoch)
