# Self DCGAN Keras TPU

<table class="tfo-notebook-buttons" align="left" >
 <td>
    <a target="_blank" href="https://colab.research.google.com/github/HighCWu/SelfGAN/blob/master/implementations/dcgan/self_dcgan_keras_tpu.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/HighCWu/SelfGAN/blob/master/implementations/dcgan/self_dcgan_keras_tpu.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

In [0]:
# Pin TF 1.x: the TPU path below uses tf.contrib.tpu (keras_to_tpu_model), which does not exist in TF 2.
! pip install 'tensorflow>1.12,<2.0' -q

## Datasets

In [0]:
import h5py
import sys
import time
import numpy as np
import tensorflow as tf
import threading
import glob
import random
import os
import numpy as np

from PIL import Image

from tensorflow.python.keras.utils import Sequence
from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer

class ImageDataset:
    """Indexable image collection built from every file found under `root`.

    Items are decoded with PIL, converted to RGB (3 channels) or grayscale,
    resized to (rows, cols) and scaled from [0, 255] to [-1, 1].  Grayscale
    items get a trailing channel axis.  Indices wrap modulo the number of files.

    Args:
        root: directory searched recursively with the pattern ``root/**/*.*``.
        img_rows, img_cols, channels: optional output geometry.  When left as
            None, the notebook globals of the same names are read on each access.
    """

    def __init__(self, root, img_rows=None, img_cols=None, channels=None):
        self.root = root
        self.files = sorted(glob.glob(root + '/**/*.*', recursive=True))
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channels = channels

    def _geometry(self):
        """Return (rows, cols, channels), falling back to the notebook globals."""
        rows = img_rows if self.img_rows is None else self.img_rows
        cols = img_cols if self.img_cols is None else self.img_cols
        ch = channels if self.channels is None else self.channels
        return rows, cols, ch

    def __getitem__(self, index):
        if not self.files:
            raise IndexError('ImageDataset is empty: no files under {!r}'.format(self.root))
        rows, cols, ch = self._geometry()
        path = self.files[index % len(self.files)]
        # The context manager closes the file handle; convert() returns a loaded copy.
        with Image.open(path) as src:
            img = src.convert('RGB' if ch == 3 else 'L')
        # PIL sizes are (width, height): width is the column count, height the row count.
        img = img.resize((cols, rows), Image.LANCZOS)
        img = np.asarray(img)
        img = img / 255 * 2 - 1
        if ch == 1:
            img = img.reshape(img.shape + (1,))

        return img

    def __len__(self):
        return len(self.files)
      
class DataLoader:
    """Streams fixed-size batches from `dataset` through an on-disk HDF5 ring cache.

    A background thread reads `batch_size` consecutive samples at a time from
    `dataset`, writes each batch into `cache_filepath` under the key ``x/<slot>``
    and wraps around after ``pool_size // batch_size`` slots.  A Keras
    ``OrderedEnqueuer`` reads batches back from the cache, so iteration is
    decoupled from the (possibly slow) sample decoding.

    Args:
        dataset: indexable object supporting ``len()`` and ``dataset[i]``.
        batch_size: number of samples per batch.
        workers: number of enqueuer worker threads.
        max_queue_size: maximum number of batches buffered by the enqueuer.
        cache_filepath: HDF5 file used as the batch cache (overwritten).
        pool_size: number of samples the ring cache holds (rounded down to
            whole batches); also defines ``len(self)`` in batches.
    Raises:
        ValueError: if ``pool_size < batch_size`` (the ring would have no slots).
        RuntimeError: if the writer thread dies before the cache is ready.
    """

    def __init__(self, dataset, batch_size, workers=1, max_queue_size=10, cache_filepath='tmp/cache.h5', pool_size=64000):
        self.idx = 0
        bsz = batch_size
        batch_pool = pool_size // bsz
        if batch_pool < 1:
            raise ValueError('pool_size ({}) must be >= batch_size ({})'.format(pool_size, bsz))
        self.length = batch_pool

        cache_dir = os.path.dirname(cache_filepath)
        if cache_dir:  # os.makedirs('') raises when the path has no directory part
            os.makedirs(cache_dir, exist_ok=True)
        cache_file = h5py.File(cache_filepath, 'w')
        # State shared between the writer thread and the readers:
        #   'i'     next slot to write, in [0, batch_pool)
        #   'full'  True once every slot has been written at least once
        #   'error' exception that stopped the writer, if any
        counter = {'i': 0, 'full': False, 'error': None}

        def data_reader(counter, cache_file):
            """Fill the ring cache forever, one batch of `bsz` samples per slot."""
            idx = 0
            try:
                while True:
                    X = []
                    for _ in range(bsz):
                        X.append(dataset[idx])
                        idx = (idx + 1) % len(dataset)

                    key = 'x/{}'.format(counter['i'])
                    if not counter['full']:
                        cache_file.create_dataset(key, data=X)
                    else:
                        cache_file[key][...] = X
                    next_i = (counter['i'] + 1) % batch_pool
                    # Publish 'full' before wrapping 'i' to 0 so a reader never
                    # observes i == 0 together with full == False (modulo by zero).
                    if next_i == 0:
                        counter['full'] = True
                    counter['i'] = next_i
                    time.sleep(0.5)
            except Exception as e:
                # Record the failure so the wait loop below can surface it.
                counter['error'] = e
                raise

        # Daemon thread: the writer loops forever and must not keep the
        # interpreter alive at shutdown.
        dt = threading.Thread(target=data_reader, args=(counter, cache_file), daemon=True)
        dt.start()
        # Wait for 100 cached batches, or for the whole ring when it is smaller
        # (the slot index wraps to 0 before it could ever reach 100 then).
        target = min(100, batch_pool)
        while not counter['full'] and counter['i'] < target:
            if not dt.is_alive():
                raise RuntimeError('cache writer thread stopped before the cache was ready') from counter['error']
            print('\rWaiting for enough cache.%d/%d' % (counter['i'], target), end='', flush=True)
            time.sleep(0.05)
        print('')

        class generator(Sequence):
            """Keras Sequence serving batches from the filled part of the ring cache."""

            def __len__(_self):
                return self.length

            def __getitem__(_self, index):
                # Read 'i' before 'full': the writer sets full=True before it wraps
                # i to 0, so reading i == 0 here implies full is already True.
                available = counter['i']
                if counter['full']:
                    available = batch_pool
                # `[()]` reads the whole dataset (Dataset.value was removed in h5py 3).
                return cache_file['x/{}'.format(index % available)][()].astype('float')

        enqueuer = OrderedEnqueuer(
                      generator(),
                      use_multiprocessing=False,
                      shuffle=False)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        self.o_g = enqueuer.get()

    def __iter__(self):
        self.idx = -1
        return self

    def __next__(self):
        self.idx += 1
        if self.idx >= self.length:
            raise StopIteration()
        return next(self.o_g)

    def __len__(self):
        return self.length


## Prepare

In [0]:
from tensorflow import keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Lambda
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
from tensorflow.python.keras.layers.convolutional import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam

import tensorflow as tf
import tensorflow.keras.backend as K

import matplotlib.pyplot as plt

import os
import sys

import numpy as np

# Output folders: saved sample grids and the extracted LSUN bedroom images.
os.makedirs('images', exist_ok=True)
os.makedirs('data/bedroom', exist_ok=True)

# Colab form field selecting the training set (keep this line's format for the form parser).
data_use = 'mnist' #@param ["bedroom", "mnist"] {allow-input: false}
# MNIST: 28x28 grayscale; any other choice: 64x64 RGB bedroom images.
if data_use=='mnist':
    img_rows = 28
    img_cols = 28
    channels = 1
else:
    img_rows = 64
    img_cols = 64
    channels = 3
img_shape = (img_rows, img_cols, channels)
latent_dim = 100       # length of the generator's noise vector
batch_size = 64
sample_interval = 200  # save a sample grid every this many global steps
epochs = 200

In [0]:
import os, zipfile
from google.colab import files
if data_use == 'bedroom':
    print('Please upload your kaggle api json.')
    files.upload()

    ! mkdir /root/.kaggle
    ! mv ./kaggle.json /root/.kaggle
    ! chmod 600 /root/.kaggle/kaggle.json
    ! kaggle datasets download -d jhoward/lsun_bedroom

    out_fname = 'lsun_bedroom.zip'
    zip_ref = zipfile.ZipFile(out_fname)
    zip_ref.extractall('./')
    zip_ref.close()
    os.remove(out_fname)

    out_fname = 'sample.zip'
    zip_ref = zipfile.ZipFile(out_fname)
    zip_ref.extractall('data/bedroom/')
    zip_ref.close()
    os.remove(out_fname)

In [0]:
# UpSample on TPU
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.utils import get_custom_objects
from tensorflow.python.keras.utils import conv_utils

def _blur2d(x, f=[1,2,1], normalize=True, flip=False, stride=1):
    assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:])
    assert isinstance(stride, int) and stride >= 1

    # Finalize filter kernel.
    f = np.array(f, dtype=np.float32)
    if f.ndim == 1:
        f = f[:, np.newaxis] * f[np.newaxis, :]
    assert f.ndim == 2
    if normalize:
        f /= np.sum(f)
    if flip:
        f = f[::-1, ::-1]
    f = f[:, :, np.newaxis, np.newaxis]
    f = np.tile(f, [1, 1, int(x.shape[1]), 1])

    # No-op => early exit.
    if f.shape == (1, 1) and f[0,0] == 1:
        return x

    # Convolve using depthwise_conv2d.
    orig_dtype = x.dtype
    x = tf.cast(x, tf.float32)  # tf.nn.depthwise_conv2d() doesn't support fp16
    f = tf.constant(f, dtype=x.dtype, name='filter')
    strides = [1, 1, stride, stride]
    x = tf.nn.depthwise_conv2d(x, f, strides=strides, padding='SAME', data_format='NCHW')
    x = tf.cast(x, orig_dtype)
    return x

def _upscale2d(x, factor=2, gain=1):
    assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:])
    assert isinstance(factor, int) and factor >= 1

    # Apply gain.
    if gain != 1:
        x *= gain

    # No-op => early exit.
    if factor == 1:
        return x

    # Upscale using tf.tile().
    s = x.shape
    x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1])
    x = tf.tile(x, [1, 1, 1, factor, 1, factor])
    x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor])
    return x

def _downscale2d(x, factor=2, gain=1):
    assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:])
    assert isinstance(factor, int) and factor >= 1

    # 2x2, float32 => downscale using _blur2d().
    if factor == 2 and x.dtype == tf.float32:
        f = [np.sqrt(gain) / factor] * factor
        return _blur2d(x, f=f, normalize=False, stride=factor)

    # Apply gain.
    if gain != 1:
        x *= gain

    # No-op => early exit.
    if factor == 1:
        return x

    # Large factor => downscale using tf.nn.avg_pool().
    # NOTE: Requires tf_config['graph_options.place_pruned_graph']=True to work.
    ksize = [1, 1, factor, factor]
    return tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding='VALID', data_format='NCHW')
  
def upscale2d(x, factor=2):
    """Nearest-neighbour upscale of an NCHW tensor with hand-written gradients.

    Forward pass: `_upscale2d(x, factor)`.  The gradient w.r.t. `x` is
    `_downscale2d(dy, factor, gain=factor**2)`, which sums each
    factor x factor block of `dy` (block average times factor**2).  The gradient
    of that gradient is defined as `_upscale2d(ddx, factor)`, which gives a
    second-order derivative as well.
    """
    with tf.variable_scope('Upscale2D'):
        @tf.custom_gradient
        def func(x):
            y = _upscale2d(x, factor)
            @tf.custom_gradient
            def grad(dy):
                # Block-sum of dy: the adjoint of pixel replication.
                dx = _downscale2d(dy, factor, gain=factor**2)
                # Second-order gradient: replicate ddx back up to the upscaled size.
                return dx, lambda ddx: _upscale2d(ddx, factor)
            return y, grad
        return func(x)
      
class UpSampling2D(keras.layers.Layer):
  """TPU-friendly nearest-neighbour upsampling layer for 2D inputs.

  Replaces the keras `UpSampling2D` imported earlier in this notebook.  It is
  built on `upscale2d` (reshape + tile with custom gradients), which supports a
  single integer factor, so rows and columns are repeated by the same amount.
  Arguments:
      size: int, or tuple of 2 equal integers.
          The upsampling factor for rows and columns.
      data_format: A string,
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
  Input shape:
      4D tensor with shape:
      - If `data_format` is `"channels_last"`:
          `(batch, rows, cols, channels)`
      - If `data_format` is `"channels_first"`:
          `(batch, channels, rows, cols)`
  Output shape:
      4D tensor with shape:
      - If `data_format` is `"channels_last"`:
          `(batch, upsampled_rows, upsampled_cols, channels)`
      - If `data_format` is `"channels_first"`:
          `(batch, channels, upsampled_rows, upsampled_cols)`
  Raises:
      ValueError: if the row and column factors in `size` differ.
  """

  def __init__(self, size=(2, 2), data_format=None, **kwargs):
    super(UpSampling2D, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.size = conv_utils.normalize_tuple(size, 2, 'size')
    # upscale2d takes a single factor; reject sizes it would silently ignore.
    if self.size[0] != self.size[1]:
      raise ValueError('UpSampling2D only supports equal row/column factors, '
                       'got size={}'.format(self.size))
    self.input_spec = InputSpec(ndim=4)

  def call(self, inputs):
    if self.data_format == 'channels_first':
      return upscale2d(inputs, self.size[0])
    # upscale2d works on NCHW: move channels in front, upscale, move them back.
    nchw = tf.transpose(inputs, [0, 3, 1, 2])
    up_nchw = upscale2d(nchw, self.size[0])
    return tf.transpose(up_nchw, [0, 2, 3, 1])

  def get_config(self):
    config = {'size': self.size, 'data_format': self.data_format}
    base_config = super(UpSampling2D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

# Register the class so saved models referring to 'UpSampling2D' resolve to this layer.
get_custom_objects().update({'UpSampling2D': UpSampling2D})

In [0]:
class Generator:
    """DCGAN generator: maps a latent noise vector to an image in [-1, 1].

    The layer objects are created once in ``__init__`` and applied in list
    order by ``__call__``.
    """

    def __init__(self):
        # Project the noise to a (img_rows/4, img_cols/4, 128) map, then upsample x2 twice.
        self.layers = [
            Dense(128 * img_rows * img_cols // (4**2), activation="relu", input_dim=latent_dim),
            Reshape((img_rows//4, img_cols//4, 128)),
            UpSampling2D(),
            Conv2D(128, kernel_size=3, padding="same"),
            BatchNormalization(momentum=0.8),
            Activation("relu"),
            UpSampling2D(),
            Conv2D(64, kernel_size=3, padding="same"),
            BatchNormalization(momentum=0.8),
            Activation("relu"),
            Conv2D(channels, kernel_size=3, padding="same"),
            # Named 'output' so this tensor can be located by layer name.
            Activation("tanh", name='output'),
        ]

    def __call__(self, x):
        """Feed ``x`` through every layer in sequence and return the result."""
        result = x
        for stage in self.layers:
            result = stage(result)
        return result

class Discriminator:
    """DCGAN discriminator: maps an image to a single sigmoid score.

    The layer objects are created once and reused on every call, so all
    calls share the same weights.
    """

    def __init__(self):
        self.layers = [
            Conv2D(32, kernel_size=3, strides=2, input_shape=img_shape, padding="same"),
            LeakyReLU(alpha=0.2),
            Dropout(0.25),
            Conv2D(64, kernel_size=3, strides=2, padding="same"),
            ZeroPadding2D(padding=((0,1),(0,1))),
            BatchNormalization(momentum=0.8),
            LeakyReLU(alpha=0.2),
            Dropout(0.25),
            Conv2D(128, kernel_size=3, strides=2, padding="same"),
            BatchNormalization(momentum=0.8),
            LeakyReLU(alpha=0.2),
            Dropout(0.25),
            Conv2D(256, kernel_size=3, strides=1, padding="same"),
            BatchNormalization(momentum=0.8),
            LeakyReLU(alpha=0.2),
            Dropout(0.25),
            Flatten(),
            Dense(1, activation='sigmoid'),
        ]

    def __call__(self, x):
        """Feed ``x`` through every layer in sequence and return the score tensor."""
        score = x
        for stage in self.layers:
            score = stage(score)
        return score
  
def SelfGAN():
    """Build the single-model SelfGAN training graph.

    One discriminator is applied to three batches: generator samples
    (`noise` -> `gen_img`), real images and previous "fake" images.  Each
    batch gets a binary cross-entropy term against the `valid`/`fake` label
    inputs.  The three terms are combined with weights derived from the
    discriminator's mean scores:

        v_g = 1 - mean(D(G(z))),  v_r = 1 - mean(D(real)),  v_f = mean(D(fake))
        s_loss = (v_r * real_loss + v_g * gen_loss + v_f * fake_loss) / (v_g + v_r + v_f)

    Returns:
        keras Model with inputs [noise, real_img, fake_img, valid, fake] and
        the single output s_loss.  The `*_loss` layer names and the generator's
        'output' layer are looked up by name elsewhere in the notebook, so they
        must keep these names.
    """
    generator = Generator()
    discriminator = Discriminator()

    real_img = Input(shape=img_shape)
    fake_img = Input(shape=img_shape)

    noise = Input(shape=(latent_dim,))
    gen_img = generator(noise)
    validity_gen = discriminator(gen_img)
    validity_real = discriminator(real_img)
    validity_fake = discriminator(fake_img)

    # compute loss
    # Called as adversarial_loss([prediction, label]).  keras' binary_crossentropy
    # takes (y_true, y_pred), so the label x[1] must come first.
    adversarial_loss = Lambda(lambda x: keras.losses.binary_crossentropy(x[1], x[0]))

    valid = Input(shape=(1,))
    fake = Input(shape=(1,))
    gen_loss = adversarial_loss([validity_gen, valid])
    real_loss = adversarial_loss([validity_real, valid])
    fake_loss = adversarial_loss([validity_fake, fake])
    # Identity Lambdas that only attach stable names to the three loss tensors.
    gen_loss = Lambda(lambda x: x*1.0, name='gen_loss')(gen_loss)
    real_loss = Lambda(lambda x: x*1.0, name='real_loss')(real_loss)
    fake_loss = Lambda(lambda x: x*1.0, name='fake_loss')(fake_loss)

    # Per-term weights from the discriminator's mean scores.
    v_g = Lambda(lambda x: 1 - K.mean(x))(validity_gen)
    v_r = Lambda(lambda x: 1 - K.mean(x))(validity_real)
    v_f = Lambda(lambda x: K.mean(x))(validity_fake)
    v_sum = Lambda(lambda x: x[0]+x[1]+x[2])([v_g,v_r,v_f])
    s_loss = Lambda(lambda x: x[2]*x[1]/x[0] \
                            + x[4]*x[3]/x[0] \
                            + x[6]*x[5]/x[0])([v_sum, v_r, real_loss, v_g, gen_loss, v_f, fake_loss])

    return Model([noise, real_img, fake_img, valid, fake], [s_loss])
  
def sample_images(model, epoch):
    """Render a 5x5 grid of generator samples and save it to images/<epoch>.png.

    Reads the notebook globals batch_size, latent_dim, channels, last_imgs,
    valid and fake.  The generator output is taken as the last tensor
    returned by ``model.predict``.
    """
    rows, cols = 5, 5
    z = np.random.normal(0, 1, (batch_size, latent_dim))
    predictions = model.predict([z, last_imgs, last_imgs, valid, fake])
    samples = predictions[-1][:rows * cols]

    # tanh range [-1, 1] -> [0, 1] for display.
    samples = 0.5 * samples + 0.5

    fig, axs = plt.subplots(rows, cols)
    # axs.flat walks the grid row by row, matching sample order.
    for sample_idx, ax in enumerate(axs.flat):
        if channels == 1:
            ax.imshow(samples[sample_idx, :, :, 0], cmap='gray')
        else:
            ax.imshow(samples[sample_idx, :, :, :])
        ax.axis('off')
    fig.savefig("images/%d.png" % epoch)
    plt.close()


In [0]:
# Function override
from tensorflow.contrib.tpu.python.tpu.keras_support import TPUFunction
from tensorflow.keras.models import Model
from tensorflow.python.estimator import model_fn as model_fn_lib 
ModeKeys = model_fn_lib.ModeKeys

def extra_outputs(self):
  """Return the extra tensors that fit/predict should fetch for a model.

  The result is the outputs of every layer whose name contains 'loss', in layer
  order, followed by the outputs of every layer whose name contains 'output'.
  """
  loss_tensors = [layer.output for layer in self.layers if 'loss' in layer.name]
  named_outputs = [layer.output for layer in self.layers if 'output' in layer.name]
  return loss_tensors + named_outputs

def _make_predict_function(self):
  if not hasattr(self, 'predict_function'):
    self.predict_function = None
  if self.predict_function is None:
    inputs = self._feed_inputs
    # Gets network outputs. Does not update weights.
    # Does update the network states.
    kwargs = getattr(self, '_function_kwargs', {})
    with K.name_scope(ModeKeys.PREDICT):
      self.predict_function = K.function(
          inputs,
          self.outputs+extra_outputs(self),
          updates=self.state_updates,
          name='predict_function',
          **kwargs)
      
def _make_fit_function(self):
  """Build the fit function, fetching loss, metrics and extra_outputs(self).

  The fetch list passed to _make_train_function_helper is [total_loss], then
  the stateful metric tensors for metrics_names[1:], then extra_outputs(self).
  """
  stateful = self._all_stateful_metrics_tensors
  fetches = [self.total_loss]
  fetches.extend(stateful[name] for name in self.metrics_names[1:])
  fetches.extend(extra_outputs(self))
  self._make_train_function_helper('_fit_function', fetches)
  
# Monkey-patch tf.keras Model so predict/fit also fetch the tensors chosen by extra_outputs().
Model._make_predict_function = _make_predict_function
Model._make_fit_function = _make_fit_function

def _process_outputs(self, outfeed_outputs):
    """Processes the outputs of a model function execution.
    Args:
      outfeed_outputs: The sharded outputs of the TPU computation.
    Returns:
      The aggregated outputs of the TPU computation to be used in the rest of
      the model execution.
    """
    # TODO(xiejw): Decide how to reduce outputs, or discard all but first.
    if self.execution_mode == ModeKeys.PREDICT:
      outputs = [[] for _ in range(len(self._outfeed_spec))]
      outputs_per_replica = len(self._outfeed_spec)

      for i in range(self._tpu_assignment.num_towers):
        output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) *
                                       outputs_per_replica]
        for j in range(outputs_per_replica):
          outputs[j].append(output_group[j])

      return [np.concatenate(group) for group in outputs]
    else:
      outputs = [[] for _ in range(len(self._outfeed_spec))]
      outputs_per_replica = len(self._outfeed_spec)

      for i in range(self._tpu_assignment.num_towers):
        output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) *
                                       outputs_per_replica]
        for j in range(outputs_per_replica):
          outputs[j].append(output_group[j])
      
      ret = []
      for group in outputs:
        if len(group[0].shape) > 0:
          ret.append(np.concatenate(group))
        else:
          ret.append(group[0])
      return ret
    
# Install the replacement output aggregation on the TPU Keras function wrapper.
TPUFunction._process_outputs = _process_outputs

In [0]:
# Reset the Keras graph/session before building the model.
tf.keras.backend.clear_session()

optimizer = Adam(0.0002, 0.5)  # learning rate 2e-4, beta_1 0.5
model = SelfGAN()
# The model's only output is s_loss; it is trained toward zero targets with MAE,
# i.e. |s_loss| is minimized.
model.compile(loss='mae',optimizer=optimizer)

# Convert the compiled model to a TPU model (requires the COLAB_TPU_ADDR env var).
TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']
model = tf.contrib.tpu.keras_to_tpu_model(
    model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)))

def initialize_uninitialized_variables():
    """Run initializers only for global variables that are not yet initialized.

    Variables that already hold values are left untouched.
    """
    sess = K.get_session()
    pending = {name.decode('ascii') for name in sess.run(tf.report_uninitialized_variables())}
    # Variable names carry an output suffix (':0'); compare on the op name.
    to_init = [var for var in tf.global_variables() if var.name.split(':')[0] in pending]
    init_op = tf.variables_initializer(to_init)
    sess.run(init_op)
initialize_uninitialized_variables()

# Adversarial ground truths
valid = np.ones((batch_size, 1))   # label for generated and real batches
fake = np.zeros((batch_size, 1))   # label for the previous-step "fake" batch

# Placeholder fake batch until the generator has produced output; the training
# loop replaces it with each step's generator output.
last_imgs = np.zeros((batch_size,)+img_shape)
# MAE target for the s_loss output.
s_loss_zeros = np.zeros((batch_size,))

In [0]:
# Init tpu model
# Warm-up: one train step and one predict step on placeholder batches before the
# data loader starts (presumably so the TPU programs are compiled up front).
noise = np.random.normal(0, 1, (batch_size, latent_dim))
train_outputs = model.train_on_batch([noise, last_imgs, last_imgs, valid, fake], [s_loss_zeros])
predict_output = model.predict([noise, last_imgs, last_imgs, valid, fake])

# MNIST training images come from keras.datasets, reshaped to NHWC and scaled
# from [0, 255] to [-1, 1]; bedroom images are decoded from data/bedroom on access.
dataLoader = DataLoader(ImageDataset('data/bedroom') if data_use == 'bedroom' else 
                            mnist.load_data()[0][0].reshape((-1,28,28,1))/127.5-1, batch_size)

In [0]:
for epoch in range(epochs):

    for i, imgs in enumerate(dataLoader):

        # ---------------------
        #  Joint training step
        # ---------------------
        # One train_on_batch updates the whole SelfGAN model: `imgs` is the real
        # batch and `last_imgs` (the previous step's generator output) is the
        # fake batch.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))

        outputs = model.train_on_batch([noise, imgs, last_imgs, valid, fake], [s_loss_zeros])
        # NOTE(review): the fetch list is [total_loss] + metric tensors (presumably
        # none) + extra_outputs(model).  The index -> loss mapping below and the
        # /8 scaling (presumably the number of TPU replicas) should be confirmed.
        s_loss = outputs[0]/8
        gen_loss = np.mean(outputs[2])/(batch_size/8)
        real_loss = np.mean(outputs[1])/(batch_size/8)
        fake_loss = np.mean(outputs[3])/(batch_size/8)
        # This step's generator output becomes the next step's fake batch.
        last_imgs = outputs[-1]

        # Plot the progress; flush after writing so the '\r' status line appears immediately.
        if i % 25 == 0:
            print ("\r[Epoch %d/%d] [Batch %d/%d]  [S loss: %f  G loss: %f R loss: %f  F loss: %f]" % (epoch, epochs, i,
                                                                            len(dataLoader), s_loss,
                                                                            gen_loss, real_loss, fake_loss),end='', flush=True)

        # If at save interval => save generated image samples
        if (epoch*len(dataLoader) + i) % sample_interval == 0:
            sample_images(model, epoch*len(dataLoader) + i)
            