# Self CycleGAN Keras TPU

<table class="tfo-notebook-buttons" align="left" >
 <td>
    <a target="_blank" href="https://colab.research.google.com/github/HighCWu/SelfGAN/blob/master/implementations/cyclegan/self_cyclegan_keras_tpu.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/HighCWu/SelfGAN/blob/master/implementations/cyclegan/self_cyclegan_keras_tpu.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

In [0]:
# TensorFlow 1.x is required: later cells use tf.contrib.tpu / tf.contrib.cluster_resolver,
# which do not exist in TensorFlow 2.
! pip install 'tensorflow>1.12,<2.0' -q

## Utils

In [0]:
class ReplayBuffer():
    """Fixed-capacity FIFO of previously generated batches.

    Until ``max_size`` items have been pushed, every push returns the item just
    pushed. After that, each push evicts and returns the oldest stored item, so
    the caller receives the batch produced ``max_size`` pushes earlier. (This is
    a strict delay line, not the randomized image pool of the CycleGAN paper.)

    Args:
        max_size: number of items retained; must be positive.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Store ``data`` and return the item to use now (see class docstring)."""
        is_full = len(self.data) >= self.max_size
        returned = self.data.pop(0) if is_full else data
        self.data.append(data)
        return returned

## Datasets

In [0]:
import h5py
import sys
import time
import numpy as np
import tensorflow as tf
import threading
import glob
import random
import os
import numpy as np

from PIL import Image

from tensorflow.python.keras.utils import Sequence
from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer

class ImageDataset:
    """Unpaired two-domain image dataset read from ``<root>/<mode>A`` and ``<root>/<mode>B``.

    Item ``index`` pairs file ``index % len(files_A)`` of domain A with file
    ``index % len(files_B)`` of domain B. Images are converted to RGB (or L when
    the module-level ``channels`` is 1), resized to the module-level
    ``img_rows`` x ``img_cols`` and scaled to [-1, 1].

    Args:
        root: dataset directory containing the ``<mode>A`` / ``<mode>B`` folders.
        mode: split prefix of the sub-folders (default ``'train'``).

    Raises:
        FileNotFoundError: if either domain folder has no files.
    """

    def __init__(self, root, mode='train'):
        self.files_A = sorted(glob.glob(os.path.join(root, '%sA' % mode) + '/*.*'))
        self.files_B = sorted(glob.glob(os.path.join(root, '%sB' % mode) + '/*.*'))
        # Fail here rather than with ``index % 0`` later inside the DataLoader thread.
        if not self.files_A or not self.files_B:
            raise FileNotFoundError(
                'No images found under %s (%sA: %d files, %sB: %d files)'
                % (root, mode, len(self.files_A), mode, len(self.files_B)))

    def _load(self, path):
        """Read one image file as a float array in [-1, 1] of shape (img_rows, img_cols, channels)."""
        with Image.open(path) as src:
            img = src.convert('RGB' if channels == 3 else 'L')
        # PIL's resize expects (width, height), i.e. (cols, rows).
        img = img.resize((img_cols, img_rows), Image.LANCZOS)
        img = np.asarray(img) / 255 * 2 - 1
        if channels == 1:
            img = img.reshape(img.shape + (1,))
        return img

    def __getitem__(self, index):
        imgA = self._load(self.files_A[index % len(self.files_A)])
        imgB = self._load(self.files_B[index % len(self.files_B)])
        return imgA, imgB

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))
      
class DataLoader:
    """Iterate over batches of ``dataset`` served from a background-filled HDF5 cache.

    A daemon thread walks ``dataset`` sequentially (wrapping around), groups
    ``batch_size`` samples into a batch and writes it into a ring of
    ``pool_size // batch_size`` HDF5 slots in ``cache_filepath``. Once every
    slot has been written it keeps overwriting them in order. Batches are read
    back through a Keras ``OrderedEnqueuer`` over the slots written so far.
    ``__init__`` blocks until ``min_cached_batches`` slots (capped at the ring
    size) are filled. One pass of the iterator yields ``len(self)`` batches.

    Args:
        dataset: object supporting ``len(dataset)`` and ``dataset[i] -> (x, y)``.
        batch_size: samples per batch.
        workers: worker threads used by the enqueuer.
        max_queue_size: enqueuer prefetch queue length.
        cache_filepath: HDF5 file used as the cache; truncated on start.
        pool_size: approximate number of samples held in the cache.
        min_cached_batches: batches to cache before returning (default 100).

    Raises:
        ValueError: if ``pool_size < batch_size`` (the ring would have no slot).
        RuntimeError: if the reader thread stops before the cache is warm.
    """

    def __init__(self, dataset, batch_size, workers=1, max_queue_size=10,
                 cache_filepath='tmp/cache.h5', pool_size=8000, min_cached_batches=100):
        self.idx = 0
        bsz = batch_size
        batch_pool = pool_size // bsz
        if batch_pool < 1:
            raise ValueError('pool_size (%d) must be >= batch_size (%d)' % (pool_size, bsz))
        self.length = batch_pool

        # os.makedirs('') raises, so only create a directory when the path has one.
        cache_dir = os.path.dirname(cache_filepath)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        cache_file = h5py.File(cache_filepath, 'w')
        # 'i': next slot to write (== number of cached slots until the ring wraps);
        # 'full': every slot has been written at least once.
        counter = {'i': 0, 'full': False}
        reader_errors = []

        def data_reader(counter, cache_file):
            idx = 0
            try:
                while True:
                    X = []
                    Y = []
                    for _ in range(bsz):
                        x, y = dataset[idx]
                        X.append(x)
                        Y.append(y)
                        idx = (idx + 1) % len(dataset)

                    slot = counter['i']
                    if not counter['full']:
                        cache_file.create_dataset('x/{}'.format(slot), data=X)
                        cache_file.create_dataset('y/{}'.format(slot), data=Y)
                    else:
                        cache_file['x/{}'.format(slot)][...] = X
                        cache_file['y/{}'.format(slot)][...] = Y
                    next_slot = (slot + 1) % batch_pool
                    # Publish 'full' before 'i' wraps to 0, so a reader never
                    # observes an empty ring (which would make ``index % 0`` fail).
                    if next_slot == 0:
                        counter['full'] = True
                    counter['i'] = next_slot
                    time.sleep(0.5)
            except Exception as e:
                # Keep the cause so the warm-up loop below can report it.
                reader_errors.append(e)
                raise

        def cached_slots():
            """Number of slots that currently hold a batch."""
            return batch_pool if counter['full'] else counter['i']

        target = max(1, min(min_cached_batches, batch_pool))
        dt = threading.Thread(target=data_reader, args=(counter, cache_file))
        # Daemon: the reader loops forever and must not keep the interpreter alive.
        dt.daemon = True
        dt.start()
        # Poll with a short sleep (not a busy spin, which also starves the reader
        # of the GIL). Checking 'full' guarantees termination when the ring has
        # fewer slots than requested, because 'i' wraps back to 0 at that point.
        while not counter['full'] and counter['i'] < target:
            if not dt.is_alive():
                cause = reader_errors[0] if reader_errors else None
                raise RuntimeError('DataLoader reader thread stopped before the cache was filled') from cause
            sys.stdout.flush()
            print('\rWaiting for enough cache.%d/%d' % (counter['i'], target), end='')
            time.sleep(0.05)
        print('')

        class generator(Sequence):

            def __len__(_self):
                return self.length

            def __getitem__(_self, index):
                # Resolve the slot once so x and y always come from the same batch.
                key = index % cached_slots()
                # ``dset[()]`` reads the whole dataset; ``Dataset.value`` was removed in h5py 3.
                return cache_file['x/{}'.format(key)][()].astype('float'), \
                       cache_file['y/{}'.format(key)][()].astype('float')

        enqueuer = OrderedEnqueuer(
                      generator(),
                      use_multiprocessing=False,
                      shuffle=False)
        enqueuer.start(workers=workers, max_queue_size=max_queue_size)
        self.o_g = enqueuer.get()

    def __iter__(self):
        self.idx = -1
        return self

    def __next__(self):
        self.idx += 1
        if self.idx >= self.length:
            raise StopIteration()
        return next(self.o_g)

    def __len__(self):
        return self.length


## Prepare

In [0]:
from tensorflow import keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Lambda, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.python.keras.layers.advanced_activations import LeakyReLU
from tensorflow.python.keras.layers.convolutional import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam

import tensorflow as tf
import tensorflow.keras.backend as K

import matplotlib.pyplot as plt

import os
import sys

import numpy as np

# Output folders: sample grids (sample_images) and the downloaded dataset.
os.makedirs('images', exist_ok=True)
os.makedirs('data', exist_ok=True)

# Image geometry shared by the dataset, generators and discriminators.
img_rows = 128
img_cols = 128
channels = 3
img_shape = (img_rows, img_cols, channels)
batch_size = 16
# Base filter count of the U-Net generator.
gf = 64
# NOTE(review): df is not referenced anywhere visible; Discriminator hard-codes 32..256 filters.
df = 64
# Save a sample grid every `sample_interval` training steps (global batch index).
sample_interval = 100
epochs = 200
# Loss weights used in SelfGAN's total loss: self_loss + lambda_cyc*cycle + lambda_id*identity.
lambda_cyc = 10
lambda_id = lambda_cyc*0.1

dataset_name = 'horse2zebra'

In [0]:
! pip install wget -q
import wget, zipfile
import os

# Fetch and unpack the CycleGAN dataset into data/.
dataset_url = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{}.zip'.format(dataset_name)
out_fname = '{}.zip'.format(dataset_name)
# Assumes the archive's top-level folder is named after dataset_name (the
# ImageDataset path used later relies on the same layout).
dataset_dir = os.path.join('data', dataset_name)

# Skip the download when a previous run already extracted the data, so
# Restart & Run All does not refetch the archive every time.
if not os.path.isdir(dataset_dir):
    wget.download(dataset_url, out=out_fname)
    # Context manager closes the archive even if extraction fails.
    with zipfile.ZipFile(out_fname) as zip_ref:
        zip_ref.extractall('data/')
    os.remove(out_fname)

In [0]:
# UpSample on TPU
from tensorflow.python.keras.utils import get_custom_objects
from tensorflow.python.keras.utils import conv_utils
from tensorflow.keras.layers import Layer, InputSpec
from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.keras import backend as K

def _blur2d(x, f=(1, 2, 1), normalize=True, flip=False, stride=1):
    """Depthwise-convolve an NCHW tensor with a 2-D FIR filter.

    Args:
        x: rank-4 NCHW tensor with fully known non-batch dimensions.
        f: 1-D taps (expanded to a separable 2-D kernel by outer product) or a
            2-D kernel.
        normalize: divide the kernel by its sum.
        flip: flip the kernel in both spatial axes.
        stride: spatial stride of the convolution ('SAME' padding).

    Returns:
        The filtered tensor, cast back to ``x``'s original dtype.
    """
    assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:])
    assert isinstance(stride, int) and stride >= 1

    # Finalize filter kernel.
    f = np.array(f, dtype=np.float32)
    if f.ndim == 1:
        f = f[:, np.newaxis] * f[np.newaxis, :]
    assert f.ndim == 2
    if normalize:
        f /= np.sum(f)
    if flip:
        f = f[::-1, ::-1]

    # No-op => early exit. Tested on the 2-D kernel (after the expansion below
    # the shape is 4-D and can never equal (1, 1)), and only for stride 1, since
    # a strided 1x1 identity kernel still subsamples.
    if f.shape == (1, 1) and f[0, 0] == 1 and stride == 1:
        return x

    # One copy of the kernel per input channel for depthwise_conv2d.
    f = f[:, :, np.newaxis, np.newaxis]
    f = np.tile(f, [1, 1, int(x.shape[1]), 1])

    # Convolve using depthwise_conv2d.
    orig_dtype = x.dtype
    x = tf.cast(x, tf.float32)  # tf.nn.depthwise_conv2d() doesn't support fp16
    f = tf.constant(f, dtype=x.dtype, name='filter')
    strides = [1, 1, stride, stride]
    x = tf.nn.depthwise_conv2d(x, f, strides=strides, padding='SAME', data_format='NCHW')
    x = tf.cast(x, orig_dtype)
    return x

def _upscale2d(x, factor=2, gain=1):
    """Nearest-neighbour upscale of an NCHW tensor by an integer ``factor``.

    Implemented with reshape + tile (instead of resize ops) and optionally
    multiplies the input by ``gain`` first.
    """
    assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:])
    assert isinstance(factor, int) and factor >= 1

    if gain != 1:
        x = x * gain
    if factor == 1:
        return x

    # Insert singleton axes after H and W, repeat them `factor` times, then fold back.
    _, c, h, w = x.shape
    expanded = tf.reshape(x, [-1, c, h, 1, w, 1])
    tiled = tf.tile(expanded, [1, 1, 1, factor, 1, factor])
    return tf.reshape(tiled, [-1, c, h * factor, w * factor])

def _downscale2d(x, factor=2, gain=1):
    """Box-filter downscale of an NCHW tensor by an integer ``factor``, scaled by ``gain``.

    For factor 2 on float32 the work is done by a stride-2 ``_blur2d`` whose taps
    carry the gain; otherwise the input is scaled and average-pooled.
    """
    assert x.shape.ndims == 4 and all(dim.value is not None for dim in x.shape[1:])
    assert isinstance(factor, int) and factor >= 1

    # Fast path: separable taps of sqrt(gain)/factor per axis give gain/factor**2 per pixel.
    if factor == 2 and x.dtype == tf.float32:
        taps = [np.sqrt(gain) / factor for _ in range(factor)]
        return _blur2d(x, f=taps, normalize=False, stride=factor)

    scaled = x * gain if gain != 1 else x
    if factor == 1:
        return scaled

    # NOTE: Requires tf_config['graph_options.place_pruned_graph']=True to work.
    window = [1, 1, factor, factor]
    return tf.nn.avg_pool(scaled, ksize=window, strides=window, padding='VALID', data_format='NCHW')
  
def upscale2d(x, factor=2):
    """Upscale an NCHW tensor by ``factor`` with hand-written gradients.

    Forward pass: ``_upscale2d`` (tile-based nearest neighbour).
    Gradient: ``_downscale2d(dy, factor, gain=factor**2)``, i.e. the incoming
    gradient averaged over each factor x factor block and multiplied by
    factor**2 (= the block sum). The gradient function is itself wrapped in
    ``tf.custom_gradient`` so that its own gradient is ``_upscale2d`` again.
    Ops are created under the 'Upscale2D' variable scope.
    """
    with tf.variable_scope('Upscale2D'):
        @tf.custom_gradient
        def func(x):
            y = _upscale2d(x, factor)
            # Custom gradient of the forward op; also given a custom gradient
            # (second order) via the nested decorator.
            @tf.custom_gradient
            def grad(dy):
                dx = _downscale2d(dy, factor, gain=factor**2)
                return dx, lambda ddx: _upscale2d(ddx, factor)
            return y, grad
        return func(x)
      
class UpSampling2D(keras.layers.Layer):
  """TPU-friendly nearest-neighbour upsampling layer for 2D inputs.
  Repeats the rows and columns of the data by ``size`` using ``upscale2d``
  (tile-based forward pass with a custom gradient). The underlying op takes a
  single integer factor, so the row and column factors must be equal.
  Arguments:
      size: int, or tuple of 2 equal integers.
          The upsampling factor for rows and columns.
      data_format: A string,
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
  Input shape:
      4D tensor with shape:
      - If `data_format` is `"channels_last"`:
          `(batch, rows, cols, channels)`
      - If `data_format` is `"channels_first"`:
          `(batch, channels, rows, cols)`
  Output shape:
      4D tensor with shape:
      - If `data_format` is `"channels_last"`:
          `(batch, upsampled_rows, upsampled_cols, channels)`
      - If `data_format` is `"channels_first"`:
          `(batch, channels, upsampled_rows, upsampled_cols)`
  Raises:
      ValueError: if the row and column factors differ (previously the column
          factor was silently ignored).
  """

  def __init__(self, size=(2, 2), data_format=None, **kwargs):
    super(UpSampling2D, self).__init__(**kwargs)
    self.data_format = conv_utils.normalize_data_format(data_format)
    self.size = conv_utils.normalize_tuple(size, 2, 'size')
    if self.size[0] != self.size[1]:
      raise ValueError('UpSampling2D only supports equal row/column factors, '
                       'got size=%r' % (self.size,))
    self.input_spec = InputSpec(ndim=4)

  def call(self, inputs):
    if self.data_format == 'channels_first':
      return upscale2d(inputs, self.size[0])
    else:
      # upscale2d works on NCHW: transpose in, upscale, transpose back.
      T = tf.transpose(inputs, [0,3,1,2])
      up_T = upscale2d(T, self.size[0])
      return tf.transpose(up_T, [0,2,3,1])

  def get_config(self):
    config = {'size': self.size, 'data_format': self.data_format}
    base_config = super(UpSampling2D, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

# Register so models using this layer can be deserialized/cloned by name.
get_custom_objects().update({'UpSampling2D': UpSampling2D})

class InstanceNormalization(Layer):
    """Instance normalization layer.
    Normalize the activations of the previous layer at each step,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.
    Statistics are computed per sample over every axis except the batch
    axis and, if given, `axis`; the result is
    `(x - mean) / sqrt(variance + epsilon)`, then optionally scaled by `gamma`
    and shifted by `beta`.
    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `InstanceNormalization`.
            Setting `axis=None` will normalize all values in each
            instance of the batch.
            Axis 0 is the batch dimension. `axis` cannot be set to 0 to avoid errors.
        epsilon: Small float added to variance to avoid dividing by zero
            (it also keeps the gradient of the square root finite when a
            feature map is constant).
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a Sequential model.
    # Output shape
        Same shape as input.
    # Raises
        ValueError: in `build`, if `axis` is 0, or if `axis` is given for a
            rank-2 input.
    # References
        - [Layer Normalization](https://arxiv.org/abs/1607.06450)
        - [Instance Normalization: The Missing Ingredient for Fast Stylization](
        https://arxiv.org/abs/1607.08022)
    """
    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = InputSpec(ndim=ndim)

        # One shared gamma/beta when normalizing everything, else one per `axis` entry.
        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        # Keep `axis` (if any) and the batch axis; reduce over the rest.
        if self.axis is not None:
            del reduction_axes[self.axis]

        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        # sqrt(var + eps) rather than std + eps: d sqrt(v)/dv is infinite at
        # v = 0, which turned constant feature maps (e.g. dead ReLU channels)
        # into NaN gradients.
        variance = K.var(inputs, reduction_axes, keepdims=True)
        stddev = K.sqrt(variance + self.epsilon)
        normed = (inputs - mean) / stddev

        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# Register so models using this layer can be deserialized/cloned by name.
get_custom_objects().update({'InstanceNormalization': InstanceNormalization})

In [0]:
class Generator:
    """U-Net Generator.

    ``self.model`` is a Keras ``Model`` mapping an ``img_shape`` image to an
    ``img_shape`` image in [-1, 1] (tanh output): four stride-2 conv
    downsampling stages, three upsample+conv stages that concatenate the
    matching encoder feature map, and a final upsample + conv.
    """
    def __init__(self):

        def conv2d(layer_input, filters, f_size=4):
            """Layers used during downsampling"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            d = InstanceNormalization()(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = InstanceNormalization()(u)
            u = Concatenate()([u, skip_input])
            return u

        # Image input
        d0 = Input(shape=img_shape)

        # Downsampling
        d1 = conv2d(d0, gf)
        d2 = conv2d(d1, gf*2)
        d3 = conv2d(d2, gf*4)
        d4 = conv2d(d3, gf*8)

        # Upsampling
        u1 = deconv2d(d4, d3, gf*4)
        u2 = deconv2d(u1, d2, gf*2)
        u3 = deconv2d(u2, d1, gf)

        u4 = UpSampling2D(size=2)(u3)
        output_img = Conv2D(channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)

        self.model = Model(d0, output_img)

    def __call__(self, x):
        """Apply the generator to tensor ``x`` by re-calling each layer of ``self.model``.

        Reusing the same layer objects shares weights across calls. The layers
        are replayed individually (presumably so the outer model holds these
        layers directly instead of a nested Model -- NOTE(review) confirm this
        is needed for the TPU conversion).

        Connectivity is taken from each layer's first inbound node (created
        when ``self.model`` was built) and matched by tensor identity, instead
        of parsing tensor names, which breaks if a graph name scope was
        uniquified differently from the layer name.

        Returns:
            The output tensor of the last layer.
        """
        # Original symbolic tensor (by id) -> its counterpart in the new graph.
        new_tensors = {id(self.model.layers[0].get_output_at(0)): x}
        y = x
        for layer in self.model.layers[1:]:
            orig_inputs = layer.get_input_at(0)
            if isinstance(orig_inputs, list):
                y = [new_tensors[id(t)] for t in orig_inputs]
            else:
                y = new_tensors[id(orig_inputs)]
            y = layer(y)
            new_tensors[id(layer.get_output_at(0))] = y

        return y
        


class Discriminator:
    """Convolutional discriminator kept as a flat list of Keras layers.

    Four conv stages (32, 64, 128, 256 filters, kernel 3) each followed by
    LeakyReLU(0.2) and Dropout(0.25), with InstanceNormalization on all but
    the first stage and a bottom/right zero-pad after the second conv; then
    Flatten and a single sigmoid unit. Calling the instance applies the layers
    in order, so repeated calls share weights.
    """

    def __init__(self):
        stack = []

        def down_block(filters, strides, normalize=True, pad=False, **conv_kwargs):
            """Append Conv2D -> [ZeroPadding2D] -> [InstanceNormalization] -> LeakyReLU -> Dropout."""
            stack.append(Conv2D(filters, kernel_size=3, strides=strides, padding="same", **conv_kwargs))
            if pad:
                stack.append(ZeroPadding2D(padding=((0,1),(0,1))))
            if normalize:
                stack.append(InstanceNormalization())
            stack.append(LeakyReLU(alpha=0.2))
            stack.append(Dropout(0.25))

        down_block(32, 2, normalize=False, input_shape=img_shape)
        down_block(64, 2, pad=True)
        down_block(128, 2)
        down_block(256, 1)
        stack.append(Flatten())
        stack.append(Dense(1, activation='sigmoid'))
        self.layers = stack

    def __call__(self, x):
        """Run ``x`` through every stored layer in order and return the validity tensor."""
        out = x
        for layer_fn in self.layers:
            out = layer_fn(out)
        return out
  
def SelfGAN():
    """Build the combined Self-CycleGAN training model.

    Creates two U-Net generators (``G_AB`` produces B from A, ``G_BA``
    produces A from B) and two discriminators (``D_A``, ``D_B``) and wires
    them into one Keras ``Model`` whose single output is the per-sample total
    loss::

        all_loss = self_loss + lambda_cyc * cycle_loss + lambda_id * identity_loss

    Inputs, in order:
        realA, realB: real images of ``img_shape``.
        fakeA, fakeB: previously generated images of ``img_shape`` (the
            notebook feeds replay-buffer outputs here).
        valid, fake: per-sample targets of shape ``(1,)`` (the notebook feeds
            ones and zeros).

    Layers named ``genA``, ``genB``, ``identity_loss``, ``self_loss`` and
    ``cycle_loss`` expose intermediate tensors for extra fetches.
    Reads the module-level ``img_shape``, ``lambda_cyc`` and ``lambda_id``.
    """
    G_AB = Generator()
    G_BA = Generator()
    D_A = Discriminator()
    D_B = Discriminator()

    realA = Input(shape=img_shape)
    realB = Input(shape=img_shape)
    fakeA = Input(shape=img_shape)
    fakeB = Input(shape=img_shape)

    # Identity gen: each generator applied to an image already in its output domain
    idenA = G_BA(realA)
    idenB = G_AB(realB)

    # Translations; the x*1.0 Lambdas only give these tensors stable layer names
    genA = G_BA(realB)
    genB = G_AB(realA)
    genA = Lambda(lambda x: x*1.0, name='genA')(genA)
    genB = Lambda(lambda x: x*1.0, name='genB')(genB)
    validity_genA = D_A(genA)
    validity_realA = D_A(realA)
    validity_fakeA = D_A(fakeA)
    validity_genB = D_B(genB)
    validity_realB = D_B(realB)
    validity_fakeB = D_B(fakeB)

    # Cycle gen
    recA = G_BA(genB)
    recB = G_AB(genA)

    # Per-sample criteria
    criterion_GAN = Lambda(lambda x: keras.losses.mean_squared_error(x[0], x[1]))
    criterion_cycle = Lambda(lambda x: keras.losses.mean_absolute_error(x[0], x[1]))
    criterion_identity = Lambda(lambda x: keras.losses.mean_absolute_error(x[0], x[1]))

    # Identity loss
    loss_id_A = criterion_identity([Flatten()(idenA), Flatten()(realA)])
    loss_id_B = criterion_identity([Flatten()(idenB), Flatten()(realB)])

    loss_identity = Lambda(lambda x: (x[0]+x[1])/2, name='identity_loss')([loss_id_A, loss_id_B])

    valid = Input(shape=(1,))
    fake = Input(shape=(1,))

    def self_gan_loss(validity_gen, validity_real, validity_fake):
        """Self GAN loss for one discriminator.

        Combines the MSE of (generated -> valid), (real -> valid) and
        (replayed fake -> fake), each weighted by a batch-level scalar
        (|1 - mean D(gen)|, |1 - mean D(real)|, mean D(fake)) normalized by
        the sum of the three scalars.
        """
        gen_loss = criterion_GAN([validity_gen, valid])
        real_loss = criterion_GAN([validity_real, valid])
        fake_loss = criterion_GAN([validity_fake, fake])

        v_g = Lambda(lambda x: K.abs(1 - K.mean(x)))(validity_gen)
        v_r = Lambda(lambda x: K.abs(1 - K.mean(x)))(validity_real)
        v_f = Lambda(lambda x: K.mean(x))(validity_fake)
        v_sum = Lambda(lambda x: x[0]+x[1]+x[2])([v_g,v_r,v_f])
        # K.epsilon() keeps the result finite when all three weights are 0
        # (possible once sigmoid outputs saturate to exactly 0/1).
        return Lambda(lambda x: (x[2]*x[1] + x[4]*x[3] + x[6]*x[5]) / (x[0] + K.epsilon()))(
            [v_sum, v_r, real_loss, v_g, gen_loss, v_f, fake_loss])

    # D_A terms: B->A translations, real A, replayed fake A
    s_lossA = self_gan_loss(validity_genA, validity_realA, validity_fakeA)
    # D_B terms: A->B translations, real B, replayed fake B
    s_lossB = self_gan_loss(validity_genB, validity_realB, validity_fakeB)

    s_loss = Lambda(lambda x: (x[0]+x[1])/2, name='self_loss')([s_lossA, s_lossB])

    # Cycle loss
    loss_cycle_A = criterion_cycle([Flatten()(recA), Flatten()(realA)])
    loss_cycle_B = criterion_cycle([Flatten()(recB), Flatten()(realB)])

    loss_cycle = Lambda(lambda x: (x[0]+x[1])/2, name='cycle_loss')([loss_cycle_A , loss_cycle_B])

    def loss_All(x, lambda_cyc=10, lambda_id=1):
      loss_s, loss_cycle, loss_identity = x
      return loss_s + \
             lambda_cyc * loss_cycle + \
             lambda_id * loss_identity

    all_loss = Lambda(loss_All,
                      arguments={'lambda_cyc':lambda_cyc,
                                 'lambda_id':lambda_id})([s_loss,
                                                          loss_cycle,
                                                          loss_identity])

    return Model([realA, realB, fakeA, fakeB, valid, fake], [all_loss])
  
def sample_images(model, epoch, imgA, imgB, last_imgA, last_imgB, valid, fake):
    """Save a 4-row preview grid to ``images/<epoch>.png``.

    Rows, each showing up to 6 samples side by side: real A, ``ret[-1]``
    (genB = G_AB(A)), real B, ``ret[-2]`` (genA = G_BA(B)). This assumes
    ``model.predict`` returns genA and genB as its last two outputs.

    Args:
        model: the (TPU) SelfGAN model.
        epoch: number used in the file name (the caller passes the global step).
        imgA, imgB, last_imgA, last_imgB, valid, fake: model inputs, NHWC
            images in [-1, 1].
    """
    ret = model.predict([imgA, imgB, last_imgA, last_imgB, valid, fake])

    def strip(batch):
        """Lay out up to 6 NHWC images side by side as one (H, n*W, C) image."""
        batch = batch[:6]
        return batch.transpose(1, 0, 2, 3).reshape((batch.shape[1], -1, batch.shape[-1]))

    gen_imgs = np.concatenate([strip(imgA), strip(ret[-1]), strip(imgB), strip(ret[-2])])

    # Rescale images 0 - 1 (clip guards float round-off, which imshow warns about)
    gen_imgs = np.clip(0.5 * gen_imgs + 0.5, 0, 1)
    # imshow needs a 2-D array for single-channel images.
    cmap = None
    if gen_imgs.shape[-1] == 1:
        gen_imgs = gen_imgs[..., 0]
        cmap = 'gray'

    fig, axs = plt.subplots(1, 1)
    axs.imshow(gen_imgs, cmap=cmap)
    axs.axis('off')
    fig.savefig("images/%d.png" % epoch)
    # Close this specific figure (not just the "current" one) to avoid leaks.
    plt.close(fig)


In [0]:
# Function override
from tensorflow.contrib.tpu.python.tpu.keras_support import TPUFunction
from tensorflow.keras.models import Model
from tensorflow.python.estimator import model_fn as model_fn_lib 
ModeKeys = model_fn_lib.ModeKeys

def extra_outputs(self, sort_names=('self_loss', 'cycle_loss', 'identity_loss', 'genA', 'genB')):
  """Collect outputs of selected named layers to append to a Keras fetch list.

  For each name in ``sort_names`` (in order), appends the ``output`` of every
  layer in ``self.layers`` whose name contains that name as a substring.

  The training loop reads the resulting ``train_on_batch`` list by position
  (total loss, then self, cycle and identity losses, with genA/genB last), so
  ``'identity_loss'`` must be listed: without it, the slot read as the
  identity loss actually held the genA image batch.

  Args:
    self: a Keras model (anything with ``layers`` having ``name``/``output``).
    sort_names: layer-name fragments to fetch, in output order.

  Returns:
    List of output tensors.
  """
  outputs = []
  for name in sort_names:
    outputs.extend(layer.output for layer in self.layers if name in layer.name)
  return outputs

def _make_predict_function(self):
  """Replacement for ``Model._make_predict_function`` that also fetches extra tensors.

  Builds ``self.predict_function`` once (it is reused while not None) as a
  ``K.function`` from the model's feed inputs to ``self.outputs`` followed by
  ``extra_outputs(self)``, so ``predict`` returns those tensors after the
  regular outputs. NOTE(review): otherwise mirrors private tf.keras internals
  of the pinned TF 1.x version; re-check when upgrading.
  """
  if not hasattr(self, 'predict_function'):
    self.predict_function = None
  if self.predict_function is None:
    inputs = self._feed_inputs
    # Gets network outputs. Does not update weights.
    # Does update the network states.
    kwargs = getattr(self, '_function_kwargs', {})
    with K.name_scope(ModeKeys.PREDICT):
      self.predict_function = K.function(
          inputs,
          self.outputs+extra_outputs(self),
          updates=self.state_updates,
          name='predict_function',
          **kwargs)
      
def _make_fit_function(self):
  """Replacement for ``Model._make_fit_function`` that also fetches extra tensors.

  The fit function's outputs become ``[total_loss] + <stateful metric tensors>
  + extra_outputs(self)``. The training loop relies on these extra values being
  returned by ``train_on_batch``. NOTE(review): uses private tf.keras APIs.
  """
  metrics_tensors = [
      self._all_stateful_metrics_tensors[m] for m in self.metrics_names[1:]
  ]
  self._make_train_function_helper(
      '_fit_function', [self.total_loss] + metrics_tensors + extra_outputs(self))
  
# Monkey-patch every tf.keras Model (presumably including the TPU-cloned copy)
# to use the two functions above.
Model._make_predict_function = _make_predict_function
Model._make_fit_function = _make_fit_function

def _process_outputs(self, outfeed_outputs):
    """Processes the outputs of a model function execution.

    Replacement for ``TPUFunction._process_outputs``. ``outfeed_outputs`` is a
    flat, replica-major list: replica ``i`` contributes
    ``len(self._outfeed_spec)`` consecutive entries.

    In PREDICT mode every output is concatenated across replicas. In other
    modes non-scalar outputs (per-sample losses, generated images) are
    concatenated across replicas, while 0-d outputs (e.g. the total loss) are
    taken from the first replica only, since they cannot be concatenated.

    Args:
      outfeed_outputs: The sharded outputs of the TPU computation.
    Returns:
      The aggregated outputs of the TPU computation to be used in the rest of
      the model execution.
    """
    outputs_per_replica = len(self._outfeed_spec)
    num_towers = self._tpu_assignment.num_towers
    # outputs[j][i] is output j of replica i.
    outputs = [
        [outfeed_outputs[i * outputs_per_replica + j] for i in range(num_towers)]
        for j in range(outputs_per_replica)
    ]

    if self.execution_mode == ModeKeys.PREDICT:
        return [np.concatenate(group) for group in outputs]

    return [np.concatenate(group) if len(group[0].shape) > 0 else group[0]
            for group in outputs]
    
# Patch the TPU Keras runtime so per-replica outputs are aggregated by the
# _process_outputs function defined above (non-scalar outputs are kept for
# every replica instead of only the first one).
TPUFunction._process_outputs = _process_outputs

In [0]:
tf.keras.backend.clear_session()

optimizer = Adam(0.0002, 0.5)
model = SelfGAN()
# The model's only output is the (non-negative) total loss; MAE against zero
# targets therefore minimizes it directly.
model.compile(loss='mae',optimizer=optimizer)

# Fail with an actionable message instead of a bare KeyError off-TPU.
tpu_address = os.environ.get('COLAB_TPU_ADDR')
if not tpu_address:
    raise RuntimeError('COLAB_TPU_ADDR is not set: run this notebook on a Colab TPU runtime '
                       '(Runtime > Change runtime type > TPU).')
TPU_WORKER = 'grpc://' + tpu_address
model = tf.contrib.tpu.keras_to_tpu_model(
    model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)))

def initialize_uninitialized_variables():
    """Run the initializer for every global variable the Keras session reports as uninitialized.

    Variables that are already initialized keep their current values.
    """
    sess = K.get_session()
    pending = {raw.decode('ascii') for raw in sess.run(tf.report_uninitialized_variables())}
    to_init = [var for var in tf.global_variables() if var.name.split(':')[0] in pending]
    sess.run(tf.variables_initializer(to_init))
initialize_uninitialized_variables()

# Adversarial ground truths
# Per-sample targets fed to the model's `valid` / `fake` inputs.
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))


# Replay buffers of past generated batches, fed back as the fakeA / fakeB inputs.
bufferA = ReplayBuffer()
bufferB = ReplayBuffer()
# Placeholder "previous fakes" for the first step: constant images of 1s (A) and -1s (B).
last_imgA = np.ones((batch_size,)+img_shape)
last_imgB = -np.ones((batch_size,)+img_shape)
# Zero targets for the single loss output (compiled with MAE).
loss_zeros = np.zeros((batch_size,))

In [0]:
# Init tpu model
# These first calls build/compile the TPU programs for training and prediction.
# NOTE: train_on_batch is a real optimizer step, here on the constant
# placeholder images.
train_outputs = model.train_on_batch([last_imgA, last_imgB, last_imgA, last_imgB, valid, fake], [loss_zeros])
predict_output = model.predict([last_imgA, last_imgB, last_imgA, last_imgB, valid, fake])

# Derive the path from dataset_name so changing the config cell is enough.
dataLoader = DataLoader(ImageDataset(os.path.join('data', dataset_name)), batch_size)

In [0]:
for epoch in range(epochs):
  
    for i, (imgA, imgB) in enumerate(dataLoader):

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # NOTE: despite the header, this single train_on_batch call optimizes
        # the combined SelfGAN loss over every trainable weight (generators and
        # discriminators alike; no layer is frozen).

        # Generate a batch of new images

        outputs = model.train_on_batch([imgA, imgB, last_imgA, last_imgB, valid, fake], [loss_zeros])
        # outputs = [total_loss] + extra_outputs(model), read by position.
        # NOTE(review): outputs[3] is the identity loss only if extra_outputs()
        # lists 'identity_loss' right after 'cycle_loss'; verify its name list.
        # NOTE(review): the /8 and /(batch_size/8) factors presumably undo
        # per-core aggregation on an 8-core TPU; confirm against _process_outputs.
        all_loss = outputs[0]/8
        self_loss = np.mean(outputs[1])/(batch_size/8)
        cycle_loss = np.mean(outputs[2])/(batch_size/8)
        identity_loss = np.mean(outputs[3])/(batch_size/8)
        # The last two extra outputs are the genA / genB image batches; the
        # buffers return what to feed as fakeA / fakeB on the next step.
        last_imgA = bufferA.push_and_pop(outputs[-2])
        last_imgB = bufferB.push_and_pop(outputs[-1])
        # Plot the progress
        if i % 25 == 0:
            sys.stdout.flush()
            print ("\r[Epoch %d/%d] [Batch %d/%d]  [All loss: %f  self loss: %f cycle loss: %f  iden loss: %f]" % (epoch, epochs, i,
                                                                            len(dataLoader), all_loss,
                                                                            self_loss, cycle_loss, identity_loss),end='')

        # If at save interval => save generated image samples
        # (the file name uses the global step, not the epoch)
        if (epoch*len(dataLoader) + i) % sample_interval == 0:
            sample_images(model, epoch*len(dataLoader) + i, imgA, imgB, last_imgA, last_imgB, valid, fake)
            