In [1]:
import tensorflow as tf
import tensorflow.keras as keras


The model architecture used in this notebook is very similar to what was used in [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py). Some of the differences are:

* CycleGAN uses [instance normalization](https://arxiv.org/abs/1607.08022) instead of [batch normalization](https://arxiv.org/abs/1502.03167).
* The [CycleGAN](https://arxiv.org/abs/1703.10593) paper uses a modified resnet based generator. This notebook is using a modified unet generator for simplicity.

There are 2 generators (G and F) and 2 discriminators (X and Y) being trained here.

### Instance Normalization
First we will create an ``InstanceNormalization`` layer from this [paper](https://arxiv.org/abs/1607.08022).

In [16]:
# Random stand-in for one 256x256 image with 3 channels (no batch axis yet).
test_image = tf.random.normal(shape=(256, 256, 3))
test_image.shape

TensorShape([256, 256, 3])

In [2]:
class InstanceNormalization(keras.layers.Layer):
  """Instance normalization layer (https://arxiv.org/abs/1607.08022).

  The mean and variance are computed over axes 1 and 2 of the input (height
  and width for the `(batch, height, width, channels)` tensors used in this
  notebook), independently for every sample and channel. The normalized
  values are then multiplied by a learned `scale` and shifted by a learned
  `offset`, each holding one value per entry of the last input axis.

  Args:
    epsilon: Small constant added to the variance before the reciprocal
      square root is taken, so a zero variance does not divide by zero.
    **kwargs: Forwarded unchanged to `keras.layers.Layer` (e.g. `name`).
  """

  def __init__(self, epsilon=1e-5, **kwargs):
    super(InstanceNormalization, self).__init__(**kwargs)
    self.epsilon = epsilon

  def build(self, input_shape):
    """Creates the `scale` and `offset` weights, sized by the last input axis."""
    # `input_shape[-1:]` keeps only the size of the last axis, giving one
    # scale/offset value per channel.
    self.scale = self.add_weight(
        name="scale",
        shape=input_shape[-1:],
        initializer=tf.random_normal_initializer(1., 0.02),
        trainable=True
    )
    self.offset = self.add_weight(
        name="offset",
        shape=input_shape[-1:],
        initializer="zeros",
        trainable=True
    )

  def call(self, x):
    """Normalizes `x` over axes 1 and 2, then applies `scale` and `offset`."""
    # keepdims=True keeps the reduced axes as size 1 so that `mean` and
    # `variance` broadcast against `x`.
    mean, variance = tf.nn.moments(
        x, axes=[1, 2], keepdims=True
    )
    inv = tf.math.rsqrt(variance + self.epsilon)
    normalized = (x - mean) * inv
    return self.scale * normalized + self.offset


### Encoder (Downsampler)
We are going to create a class that downsamples the image from a higher resolution to a lower resolution. The structure of the downsampler is as follows:

```py
[ Conv2D ] => [ Batchnorm ] => [ LeakyRelu ]
```

Args:

```
  Args:
    in_features: number of filters (output channels) of the convolution
    kernel_size: filter size
    norm_type: Normalization type; either 'batchnorm' or 'instancenorm'.
    apply_norm: If True, adds the normalization layer selected by norm_type
```

Return:
```
Downsample Keras Layer
```

In [27]:
class Encoder(keras.layers.Layer):
  """Downsampling block: Conv2D (stride 2) => optional normalization => LeakyReLU.

  Args:
    in_features: Number of filters of the convolution. Despite the name, it
      is passed to `Conv2D` as its filter count, i.e. the number of output
      channels.
    kernel_size: Size of the convolution kernel.
    norm_type: Normalization type; either 'batchnorm' or 'instancenorm'
      (case-insensitive). Only consulted when `apply_norm` is true.
    apply_norm: If true, the normalization layer is applied after the
      convolution.

  Raises:
    ValueError: If `apply_norm` is true and `norm_type` is not one of the
      supported values.
  """

  def __init__(self, in_features, kernel_size,
               norm_type="batchnorm", apply_norm=True):
    super(Encoder, self).__init__()
    initializer = tf.random_normal_initializer(0., 0.02)

    self.apply_norm = apply_norm

    # Stride 2 with "same" padding halves the height and width. No bias is
    # used on the convolution.
    self.conv = keras.layers.Conv2D(
        in_features, kernel_size=kernel_size, strides=2,
        padding="same",
        kernel_initializer=initializer, use_bias=False
    )

    # Only build a normalization layer when it will actually be used, and fail
    # here rather than in `call` when the requested type is unknown.
    self.norm = None
    if apply_norm:
      norm_key = norm_type.lower()
      if norm_key == 'batchnorm':
        self.norm = keras.layers.BatchNormalization()
      elif norm_key == 'instancenorm':
        self.norm = InstanceNormalization()
      else:
        raise ValueError(
            "norm_type must be 'batchnorm' or 'instancenorm', got %r"
            % (norm_type,)
        )

    self.l_relu = keras.layers.LeakyReLU()

  def call(self, x):
    """Applies convolution, then normalization (if enabled), then LeakyReLU."""
    x = self.conv(x)
    if self.apply_norm:
      x = self.norm(x)
    return self.l_relu(x)

In [29]:
# Smoke test: add a batch axis to the test image and run it through an
# Encoder with 3 filters and a 4x4 kernel; height and width should be halved.
down_model = Encoder(in_features=3, kernel_size=4)
down_model(tf.expand_dims(test_image, axis=0)).shape

TensorShape([1, 128, 128, 3])

### Decoder (Upsampler).

In the ``Decoder`` (or `Upsampler`) we are going to use ``Conv2DTranspose``. You can also use ``UpSampling2D`` together with ``Conv2D``, which is what we will do later on. The structure of this layer is as follows:

```
[ Conv2DTranspose ] => [ Batchnorm ] => [ Dropout ] => [ ReLu ]
```

Args:
```
  Args:
    in_features: number of filters (output channels) of the transposed convolution
    kernel_size: filter size
    norm_type: Normalization type; either 'batchnorm' or 'instancenorm'.
    apply_dropout: If True, adds the dropout layer
```

Returns:
```
Returns:
    Upsample Keras Layer.
```



In [30]:
class Decoder(keras.layers.Layer):
  """Upsampling block: Conv2DTranspose (stride 2) => normalization => optional Dropout => ReLU.

  Args:
    in_features: Number of filters of the transposed convolution. Despite the
      name, it is passed to `Conv2DTranspose` as its filter count, i.e. the
      number of output channels.
    kernel_size: Size of the transposed-convolution kernel.
    norm_type: Normalization type; either 'batchnorm' or 'instancenorm'
      (case-insensitive).
    apply_dropout: If true, dropout with rate 0.5 is applied after the
      normalization layer.

  Raises:
    ValueError: If `norm_type` is not one of the supported values.
  """

  def __init__(self, in_features, kernel_size,
               norm_type="batchnorm", apply_dropout=False):
    super(Decoder, self).__init__()
    initializer = tf.random_normal_initializer(0., 0.02)

    self.apply_dropout = apply_dropout

    # Stride 2 with "same" padding doubles the height and width. No bias is
    # used on the transposed convolution.
    self.conv_2d_transpose = keras.layers.Conv2DTranspose(
      in_features, kernel_size=kernel_size, strides=2,
      padding="same",
      kernel_initializer=initializer,
      use_bias=False
    )

    # `call` always applies the normalization layer, so an unknown type must
    # be rejected here instead of leaving `self.norm` as None.
    norm_key = norm_type.lower()
    if norm_key == 'batchnorm':
      self.norm = keras.layers.BatchNormalization()
    elif norm_key == 'instancenorm':
      self.norm = InstanceNormalization()
    else:
      raise ValueError(
          "norm_type must be 'batchnorm' or 'instancenorm', got %r"
          % (norm_type,)
      )

    self.dropout = keras.layers.Dropout(.5, name="decoder_dropout")
    self.relu = keras.layers.ReLU()

  def call(self, x):
    """Applies transposed convolution, normalization, optional dropout, then ReLU."""
    x = self.conv_2d_transpose(x)
    x = self.norm(x)
    if self.apply_dropout:
      x = self.dropout(x)
    return self.relu(x)

### The ``discriminator`` model.

We are going to use the [PatchGAN](https://arxiv.org/abs/1611.07004) discriminator, which takes the following args:

```
Args:
    norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'.
    target: Bool, indicating whether target image is an input or not.
```
and returns:

```
Returns:
  Discriminator model
```


In [8]:
# Side length, in pixels, of the square images the models below are built for.
IMG_WIDTH = 256
IMG_HEIGHT = IMG_WIDTH

In [31]:
def discriminator(norm_type='batchnorm', target=True):
  """Builds the PatchGAN discriminator (https://arxiv.org/abs/1611.07004).

  The network stacks three `Encoder` blocks, a zero-padded 4x4 convolution
  with normalization and LeakyReLU, and a final zero-padded 4x4 convolution
  with a single filter and no activation.

  Args:
    norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'
      (case-insensitive).
    target: Bool, indicating whether the target image is an input or not.
      When true the model takes `[input_image, target_image]` and
      concatenates them along the channel axis; otherwise it takes only
      `input_image`.

  Returns:
    A `keras.Model` named "discriminator_model" whose output has shape
    (batch_size, 30, 30, 1) for IMG_HEIGHT x IMG_WIDTH = 256 x 256 inputs.

  Raises:
    ValueError: If `norm_type` is not one of the supported values.
  """
  # Validate up front: otherwise an unknown type would leave the normalization
  # output as None and fail later with an unrelated-looking error.
  norm_key = norm_type.lower()
  if norm_key not in ('batchnorm', 'instancenorm'):
    raise ValueError(
        "norm_type must be 'batchnorm' or 'instancenorm', got %r"
        % (norm_type,)
    )

  initializer = tf.random_normal_initializer(0., 0.02)

  inp = keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, 3], name='input_image')
  x = inp

  if target:
    tar = keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, 3], name='target_image')
    x = keras.layers.concatenate([inp, tar], name="concatenated_inputs")  # (batch_size, 256, 256, channels*2)

  # The first encoder block skips normalization (apply_norm=False).
  down_1 = Encoder(64, 4, norm_type, False)(x)  # (batch_size, 128, 128, 64)
  down_2 = Encoder(128, 4, norm_type)(down_1)  # (batch_size, 64, 64, 128)
  down_3 = Encoder(256, 4, norm_type)(down_2)  # (batch_size, 32, 32, 256)
  zero_pad1 = keras.layers.ZeroPadding2D(name="zero_padding_layer_1")(down_3)  # (batch_size, 34, 34, 256)

  conv = keras.layers.Conv2D(512, 4, strides=1, kernel_initializer=initializer,
                             use_bias=False,
                             name="conv_layer")(zero_pad1)  # (batch_size, 31, 31, 512)

  if norm_key == 'batchnorm':
    norm = keras.layers.BatchNormalization(name="batch_norm")(conv)
  else:  # 'instancenorm', guaranteed by the validation above
    norm = InstanceNormalization()(conv)

  leaky_relu = keras.layers.LeakyReLU(name="leaky_relu")(norm)

  zero_pad2 = keras.layers.ZeroPadding2D(name="zero_padding_layer_2")(leaky_relu)  # (batch_size, 33, 33, 512)
  last = keras.layers.Conv2D(1, 4, strides=1,
                             kernel_initializer=initializer,
                             name="output_layer")(zero_pad2)  # (batch_size, 30, 30, 1)

  model_inputs = [inp, tar] if target else inp
  return keras.Model(inputs=model_inputs, outputs=last, name="discriminator_model")


In [32]:
# Build the discriminator with its default settings (batch normalization,
# paired input/target images) and print a layer-by-layer summary.
discriminator(norm_type='batchnorm', target=True).summary()

Model: "discriminator_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_image (InputLayer)        [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
target_image (InputLayer)       [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
concatenated_inputs (Concatenat (None, 256, 256, 6)  0           input_image[0][0]                
                                                                 target_image[0][0]               
__________________________________________________________________________________________________
encoder_15 (Encoder)            (None, 128, 128, 64) 6144        concatenated_in

### The ``generator`` model.

We are going to build a `unet_generator` based on [this paper](https://arxiv.org/abs/1611.07004) which takes the following args.

Args:
```
Args:
  output_channels: Output channels
  norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'.
```

Return:
```
  Returns:
    Generator model
```