<a href="https://colab.research.google.com/github/GarlandZhang/pg-toons/blob/master/cartoongan_impl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [22]:
# Environment setup cell (shell commands run through IPython's `!` escape).
# Remove any keras-contrib checkout left behind by a previous run so the clone
# below does not fail on an existing directory.
# NOTE(review): on a fresh runtime the directory does not exist and `rm -r`
# prints an error; it is harmless because the following lines still run.
!rm -r keras-contrib/ 
# !pip uninstall -y tensorflow-gpu==2.0.0-alpha0
# !pip uninstall -y tensorflow
# !pip install tensorflow-gpu==2.0.0-alpha0
# Clone keras-contrib, pip-install it straight from GitHub, run the checkout's
# convert_to_tf_keras.py script, then install the checkout with USE_TF_KERAS=1.
# The steps are chained with `&&`, so a failure stops the remaining steps.
!git clone https://www.github.com/keras-team/keras-contrib.git \
    && cd keras-contrib \
    && pip install git+https://www.github.com/keras-team/keras-contrib.git \
    && python convert_to_tf_keras.py \
    && USE_TF_KERAS=1 python setup.py install
# clear_output()
# tensorflow-addons is imported later as `tfa` (used for InstanceNormalization).
# --no-deps installs the package without installing its dependencies.
!pip install -q  --no-deps tensorflow-addons~=0.7

In [23]:
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Layer, InputSpec, DepthwiseConv2D, Conv2D, BatchNormalization, Add, ReLU, LeakyReLU, ZeroPadding2D, Activation
import tensorflow_addons as tfa

In [30]:
import numpy as np

In [6]:
class ReflectionPadding2D(Layer):
  """Pads the spatial axes of a 4-D tensor by mirroring its border values.

  Unlike zero padding, which surrounds the image with a constant frame, the
  padded region is filled with a reflection of the pixels next to the edge
  (`tf.pad` in "REFLECT" mode). The batch axis (0) and the last axis (3) are
  never padded, i.e. the layer treats its input as (batch, height, width,
  channels).

  Args:
    padding: `(pad_height, pad_width)`, or a single int used for both axes.
      Each spatial axis is padded symmetrically (the same amount before and
      after), mirroring the convention of `ZeroPadding2D` so the two layers
      are interchangeable in `get_padding`.
    **kwargs: forwarded to the Keras `Layer` constructor (e.g. `name`).

  Attributes:
    padding: the per-axis `(before, after)` pairs handed to `tf.pad`.
  """

  def __init__(self, padding=(1, 1), **kwargs):
    super(ReflectionPadding2D, self).__init__(**kwargs)
    if isinstance(padding, int):
      padding = (padding, padding)
    pad_h, pad_w = tuple(padding)
    # One (before, after) pair per axis of the 4-D input.
    self.padding = ((0, 0), (pad_h, pad_h), (pad_w, pad_w), (0, 0))
    self.input_spec = [InputSpec(ndim=4)]

  def compute_output_shape(self, s):
    """Returns the input shape with both spatial axes grown by twice the pad."""
    s = tf.TensorShape(s).as_list()
    pad_h = self.padding[1][0]
    pad_w = self.padding[2][0]
    # Unknown (None) dimensions stay unknown instead of raising a TypeError.
    height = None if s[1] is None else s[1] + 2 * pad_h
    width = None if s[2] is None else s[2] + 2 * pad_w
    return s[0], height, width, s[3]

  def call(self, x):
    return tf.pad(x, self.padding, "REFLECT")

  def get_config(self):
    """Includes `padding` so the layer can be rebuilt from its config."""
    config = super(ReflectionPadding2D, self).get_config()
    config['padding'] = (self.padding[1][0], self.padding[2][0])
    return config

In [7]:
def get_padding(pad_type, padding):
  """Builds the padding layer selected by `pad_type`.

  Args:
    pad_type: 'constant' for `ZeroPadding2D`, 'reflect' for
      `ReflectionPadding2D`.
    padding: passed unchanged to the chosen layer's constructor.

  Returns:
    A new padding layer instance.

  Raises:
    ValueError: if `pad_type` is neither 'constant' nor 'reflect'.
  """
  if pad_type == 'constant':
    return ZeroPadding2D(padding)
  if pad_type == 'reflect':
    return ReflectionPadding2D(padding)
  raise ValueError(f'Invalid padding type: {pad_type}')

In [24]:
def get_norm(norm_type):
  """Builds a fresh normalization layer of the requested kind.

  Args:
    norm_type: 'instance' for `tfa.layers.InstanceNormalization`, 'batch' for
      `BatchNormalization`.

  Returns:
    A new, default-configured layer *instance*, ready to be placed in a
    `Sequential` layer list.

  Raises:
    ValueError: if `norm_type` is neither 'instance' nor 'batch'.
  """
  if norm_type == 'instance':
    return tfa.layers.InstanceNormalization()
  elif norm_type == 'batch':
    # Instantiate the layer: returning the bare class (as the 'instance'
    # branch never did) cannot be used as a layer by callers.
    return BatchNormalization()
  else:
    raise ValueError(f'Invalid norm type: {norm_type}')

In [34]:
class FlatConv(Model):
  """Size-preserving block: pad -> Conv2D (stride 1) -> norm -> ReLU.

  Args:
    filters: number of output channels of the convolution.
    kernel_size: integer side length of the square kernel.
    norm_type: forwarded to `get_norm` ('instance' or 'batch').
    pad_type: forwarded to `get_padding` ('constant' or 'reflect').
    **kwargs: accepted for call-site compatibility; currently unused.
  """

  def __init__(self,
               filters,
               kernel_size,
               norm_type='instance',
               pad_type='constant',
               **kwargs):
    super().__init__(name='FlatConv')

    # An unpadded stride-1 convolution shrinks each spatial side from W to
    # W - kernel_size + 1. Padding (kernel_size - 1) // 2 on both ends of each
    # side restores the original size (exactly, when kernel_size is odd).
    half_kernel = (kernel_size - 1) // 2

    layers = [
        get_padding(pad_type, (half_kernel, half_kernel)),
        Conv2D(filters, kernel_size),
        get_norm(norm_type),
        ReLU(),
    ]
    self.model = Sequential(layers)

  def call(self, x, training=False):
    """Runs the wrapped layer stack on `x`."""
    return self.model(x, training=training)


In [35]:
class ConvBlock(Model):
  """Two padded convolutions followed by one norm and one ReLU.

  Layer order: pad -> Conv2D(strides=stride) -> pad -> Conv2D -> norm -> ReLU.
  Only the first convolution is strided, and there is no normalization or
  activation between the two convolutions.

  Args:
    filters: number of output channels of *both* convolutions.
    kernel_size: integer side length of the square kernel.
    stride: stride of the first convolution (2 halves the spatial size).
    norm_type: forwarded to `get_norm` ('instance' or 'batch').
    pad_type: forwarded to `get_padding` ('constant' or 'reflect').
    **kwargs: accepted for call-site compatibility; currently unused.
  """

  def __init__(self,
               filters,
               kernel_size,
               stride=1,
               norm_type='instance',
               pad_type='constant',
               **kwargs):
    super().__init__(name='ConvBlock')

    # Same per-side padding for both convolutions: (kernel_size - 1) // 2.
    half_kernel = (kernel_size - 1) // 2
    pad = (half_kernel, half_kernel)

    layers = [
        get_padding(pad_type, pad),
        Conv2D(filters, kernel_size, strides=stride),
        get_padding(pad_type, pad),
        Conv2D(filters, kernel_size),
        get_norm(norm_type),
        ReLU(),
    ]
    self.model = Sequential(layers)

  def call(self, x, training=False):
    """Runs the wrapped layer stack on `x`."""
    return self.model(x, training=training)

In [16]:
class ResBlock(Model):
  """Residual block: the input is added to the output of a two-conv branch.

  Branch order: pad -> Conv2D -> norm -> ReLU -> pad -> Conv2D -> norm.
  The result is `branch(x) + x`, so `filters` has to equal the channel count
  of the input for the addition to be valid.

  Args:
    filters: number of output channels of both convolutions.
    kernel_size: integer side length of the square kernel.
    norm_type: forwarded to `get_norm` ('instance' or 'batch').
    pad_type: forwarded to `get_padding` ('constant' or 'reflect').
    **kwargs: accepted for call-site compatibility; currently unused.
  """

  def __init__(self,
               filters,
               kernel_size,
               norm_type='instance',
               pad_type='constant',
               **kwargs):
    super().__init__(name='ResBlock')

    # Per-side padding that keeps the stride-1 convolutions size-preserving.
    half_kernel = (kernel_size - 1) // 2
    pad = (half_kernel, half_kernel)

    branch_layers = [
        get_padding(pad_type, pad),
        Conv2D(filters, kernel_size),
        get_norm(norm_type),
        ReLU(),
        get_padding(pad_type, pad),
        Conv2D(filters, kernel_size),
        get_norm(norm_type),
    ]
    self.model = Sequential(branch_layers)

    self.add = Add()

  def call(self, x, training=False):
    """Returns the element-wise sum of the branch output and `x`."""
    residual = self.model(x, training=training)
    return self.add([residual, x])

In [39]:
class UpSampleConv(Model):
  """Doubles the spatial size by bilinear resizing, then applies a ConvBlock.

  Args:
    filters: number of output channels of the wrapped `ConvBlock`.
    kernel_size: integer side length of the square kernel.
    norm_type: forwarded to `ConvBlock`.
    pad_type: forwarded to `ConvBlock`.
    **kwargs: accepted for call-site compatibility; currently unused.
  """

  def __init__(self,
               filters,
               kernel_size,
               norm_type='instance',
               pad_type='constant',
               **kwargs):
    super().__init__(name='UpSampleConv')
    # Stride 1: the size change comes from the resize in `call`, not the conv.
    self.model = ConvBlock(filters,
                           kernel_size,
                           stride=1,
                           norm_type=norm_type,
                           pad_type=pad_type)

  def call(self, x, training=False):
    """Resizes `x` to twice its height and width, then convolves it.

    Downsampling can halve the size simply by using stride 2 in a
    convolution; upsampling instead resizes the feature map directly.
    """
    upscaled = tf.keras.backend.resize_images(x,
                                              height_factor=2,
                                              width_factor=2,
                                              data_format="channels_last",
                                              interpolation='bilinear')
    return self.model(upscaled, training=training)

In [41]:
class Generator(Model):
  """Image-to-image generator: encode, transform with residual blocks, decode.

  Pipeline (channels in terms of `base_filters` = F):
    FlatConv k7 -> F
    ConvBlock k3 stride 2 -> 2F      (spatial size halved)
    ConvBlock k3 stride 2 -> 4F      (spatial size halved again)
    `num_resblocks` x ResBlock k3 -> 4F
    UpSampleConv k3 -> 2F            (spatial size doubled)
    UpSampleConv k3 -> F             (spatial size doubled again)
    pad 3 -> Conv2D k7 -> 3 channels -> tanh

  The output therefore has 3 channels with values in [-1, 1] (tanh).

  Args:
    norm_type: forwarded to every block ('instance' or 'batch').
    pad_type: forwarded to every block ('constant' or 'reflect').
    base_filters: channel count F of the first convolution.
    num_resblocks: number of residual blocks in the bottleneck.
  """

  def __init__(self,
               norm_type='instance',
               pad_type='constant',
               base_filters=64,
               num_resblocks=8):
    super(Generator, self).__init__(name='Generator')
    self.flat_conv = FlatConv(filters=base_filters, kernel_size=7, norm_type=norm_type, pad_type=pad_type)
    # ConvBlock uses `filters` for both of its convolutions; the former
    # `mid_filters=` keyword was silently ignored and has been dropped.
    self.down_conv1 = ConvBlock(filters=base_filters * 2, kernel_size=3, stride=2, norm_type=norm_type, pad_type=pad_type)
    self.down_conv2 = ConvBlock(filters=base_filters * 4, kernel_size=3, stride=2, norm_type=norm_type, pad_type=pad_type)
    self.res_blocks = Sequential([ResBlock(filters=base_filters * 4, kernel_size=3, norm_type=norm_type, pad_type=pad_type) for _ in range(num_resblocks)])
    self.up_conv1 = UpSampleConv(filters=base_filters * 2, kernel_size=3, norm_type=norm_type, pad_type=pad_type)
    self.up_conv2 = UpSampleConv(filters=base_filters, kernel_size=3, norm_type=norm_type, pad_type=pad_type)

    # Padding of 3 per side keeps the 7x7 stride-1 output convolution
    # size-preserving.
    padding = (3, 3)
    self.final_conv = Sequential([
                                  get_padding(pad_type, padding),
                                  Conv2D(filters=3, kernel_size=7),
                                  Activation('tanh')
    ])

  def call(self, x, training=False):
    """Maps a (batch, height, width, channels) tensor to a 3-channel image."""
    x = self.flat_conv(x, training=training)
    x = self.down_conv1(x, training=training)
    x = self.down_conv2(x, training=training)
    x = self.res_blocks(x, training=training)
    x = self.up_conv1(x, training=training)
    x = self.up_conv2(x, training=training)
    x = self.final_conv(x, training=training)
    return x

  @staticmethod
  def _round_trip_size(size):
    """Spatial size after two stride-2 convolutions and two 2x upsamplings."""
    if size is None:
      return None
    # A kernel-3, pad-1, stride-2 convolution maps d to ceil(d / 2).
    halved = (size + 1) // 2
    quartered = (halved + 1) // 2
    return quartered * 4

  def compute_output_shape(self, input_shape):
    """Returns the output shape for `input_shape`.

    The channel axis is always 3 (the final convolution's filter count). Each
    spatial axis equals the input size when that size is a multiple of 4 and
    is rounded up to the next multiple of 4 otherwise.
    """
    batch, height, width, _ = tf.TensorShape(input_shape).as_list()
    return tf.TensorShape([batch,
                           self._round_trip_size(height),
                           self._round_trip_size(width),
                           3])



In [42]:
# Smoke test: build the generator on a symbolic 256x256 RGB input, print the
# layer summary and compare the input and output shapes.
g = Generator()
shape = (1, 256, 256, 3)
# `nx` is only used for its shape below; its random values are never fed to
# the model.
nx = np.random.rand(*shape).astype(np.float32)
# Symbolic input with the batch size pinned to nx.shape[0].
t = tf.keras.Input(shape=nx.shape[1:], batch_size=nx.shape[0])
# Calling the model on the symbolic input builds its layers so summary() works.
out = g(t, training=False)
g.summary()
print(f'in: {nx.shape} vs out: {out.shape}')

Model: "Generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
FlatConv (FlatConv)          (1, 256, 256, 64)         9600      
_________________________________________________________________
ConvBlock (ConvBlock)        (1, 128, 128, 128)        221696    
_________________________________________________________________
ConvBlock (ConvBlock)        (1, 64, 64, 256)          885760    
_________________________________________________________________
sequential_115 (Sequential)  (1, 64, 64, 256)          9449472   
_________________________________________________________________
UpSampleConv (UpSampleConv)  (1, 128, 128, 128)        442880    
_________________________________________________________________
UpSampleConv (UpSampleConv)  (1, 256, 256, 64)         110848    
_________________________________________________________________
sequential_118 (Sequential)  (1, 256, 256, 3)          94