<a href="https://colab.research.google.com/github/aju22/pix2pix/blob/main/Pix2Pix_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import tensorflow.keras.layers as layers

# Model Architecture

The Pix2Pix GAN is a general approach to image-to-image translation. It is based on the conditional generative adversarial network (cGAN), in which the target image is generated conditioned on a given input image. Pix2Pix also modifies the loss function so that the generated image is both plausible within the target domain and a plausible translation of the input image.


---


![Model](https://paper-attachments.dropbox.com/s_84D9D849F786EC83B26BF2A0F74F0C33230682E8BA1D41AD8C3F3D770D23236A_1566175741579_dlhacks-perceptual-adversarial-networks-for-imagetoimage-transformation-7-638.jpg)


---


### Generator

A U-Net model architecture is used for the generator, instead of the common encoder-decoder model.

### Discriminator

Unlike the traditional GAN model that uses a deep convolutional neural network to classify images, the Pix2Pix model uses a PatchGAN. This is a deep convolutional neural network designed to classify patches of an input image as real or fake, rather than the entire image.



In [None]:
class CNNBlock(layers.Layer):
  """Stride-2 conv block: (transposed) conv -> optional BatchNorm -> LeakyReLU -> optional Dropout.

  Args:
    filters: Number of output filters of the convolution.
    size: Kernel size of the convolution.
    apply_batchnorm: Whether to batch-normalise the conv output. Only
      honoured when ``downsample`` is True; upsampling blocks always
      apply batch normalisation.
    apply_dropout: Whether to apply ``Dropout(0.5)`` after the activation.
      Only honoured when ``downsample`` is False; downsampling blocks never
      apply dropout.
    downsample: If True the block uses a stride-2 ``Conv2D``; otherwise a
      stride-2 ``Conv2DTranspose``.
  """

  def __init__(self, filters, size, apply_batchnorm = True, apply_dropout = False, downsample = True):
    super().__init__()

    # Upsampling blocks ignore the flag and always normalise.
    self.apply_batchnorm = apply_batchnorm if downsample else True
    # Downsampling blocks ignore the flag and never use dropout.
    self.apply_dropout = apply_dropout if not downsample else False

    # Same hyper-parameters either way; only the layer type differs.
    conv_cls = layers.Conv2D if downsample else layers.Conv2DTranspose
    self.conv = conv_cls(filters, size, strides=2, padding='same', use_bias=False)

    # These sub-layers are always created, but only used in call() when
    # the matching flag above is set.
    self.bn = layers.BatchNormalization()
    self.leaky_relu = layers.LeakyReLU()
    self.dropout = layers.Dropout(0.5)

  def call(self, x):
    """Apply the block to ``x`` and return the transformed tensor."""
    x = self.conv(x)

    if self.apply_batchnorm:
      x = self.bn(x)

    x = self.leaky_relu(x)

    if self.apply_dropout:
      # Fixed: the original referenced the misspelled attribute
      # `self.droput`, which raised AttributeError on the first call.
      x = self.dropout(x)

    return x

In [None]:
class Generator(tf.keras.Model):
  """U-Net style generator built from ``CNNBlock`` stages with skip connections.

  Eight stride-2 downsampling blocks are followed by seven stride-2
  upsampling blocks. After each upsampling block the output is concatenated
  with the output of the mirrored downsampling block. A final stride-2
  transposed convolution with ``tanh`` activation produces the output.

  The shape comments below assume a 256x256 input -- confirm against the
  data pipeline.

  Args:
    out_channels: Number of channels produced by the final layer.
  """

  def __init__(self, out_channels):

    super().__init__()

    self.downsample = tf.keras.Sequential([
                                           CNNBlock(64, 4, apply_batchnorm=False),  # (batch_size, 128, 128, 64)
                                           CNNBlock(128, 4),  # (batch_size, 64, 64, 128)
                                           CNNBlock(256, 4),  # (batch_size, 32, 32, 256)
                                           CNNBlock(512, 4),  # (batch_size, 16, 16, 512)
                                           CNNBlock(512, 4),  # (batch_size, 8, 8, 512)
                                           CNNBlock(512, 4),  # (batch_size, 4, 4, 512)
                                           CNNBlock(512, 4),  # (batch_size, 2, 2, 512)
                                           CNNBlock(512, 4),  # (batch_size, 1, 1, 512)
                                         ])

    # Shapes listed here are those after concatenation with the skip tensor.
    self.upsample = tf.keras.Sequential([
                                         CNNBlock(512, 4, apply_dropout=True, downsample = False),  # (batch_size, 2, 2, 1024)
                                         CNNBlock(512, 4, apply_dropout=True, downsample = False),  # (batch_size, 4, 4, 1024)
                                         CNNBlock(512, 4, apply_dropout=True, downsample = False),  # (batch_size, 8, 8, 1024)
                                         CNNBlock(512, 4, downsample = False),  # (batch_size, 16, 16, 1024)
                                         CNNBlock(256, 4, downsample = False),  # (batch_size, 32, 32, 512)
                                         CNNBlock(128, 4, downsample = False),  # (batch_size, 64, 64, 256)
                                         CNNBlock(64, 4, downsample = False),   # (batch_size, 128, 128, 128)
                                        ])

    self.final = tf.keras.layers.Conv2DTranspose(out_channels, 4, strides=2, padding='same', activation='tanh')


  def call(self, x):
    """Run the encoder, then the decoder with skip connections, then the output layer."""

    # The Sequential containers are only used to hold the blocks; they are
    # never called as a whole because every intermediate output is needed.
    # Fixed: a Sequential is not iterable, so iterate over `.layers`.
    skips = []
    for down in self.downsample.layers:
      x = down(x)
      skips.append(x)

    # The last encoder output is the bottleneck fed to the decoder, so it
    # has no skip connection; the rest are consumed deepest-first.
    skips = reversed(skips[:-1])

    for up, skip in zip(self.upsample.layers, skips):
      x = up(x)
      x = tf.keras.layers.Concatenate()([x, skip])

    x = self.final(x)

    return x

In [None]:
class Discriminator(tf.keras.Model):
  """Convolutional discriminator scoring an (input, target) image pair.

  The two images are concatenated and passed through three downsampling
  ``CNNBlock`` stages, then zero-pad -> conv -> BatchNorm -> LeakyReLU ->
  zero-pad, and finally a single-filter convolution. No activation is
  applied to the final convolution's output.
  """

  def __init__(self):
    super().__init__()

    # Functional concatenate (default axis) used to join the image pair.
    self.concat = layers.concatenate

    # Three stride-2 downsampling stages; the first skips batch norm.
    self.down1 = CNNBlock(64, 4, apply_batchnorm=False)
    self.down2 = CNNBlock(128, 4)
    self.down3 = CNNBlock(256, 4)

    # Stride-1 stages with explicit zero padding ahead of each convolution.
    self.zero_pad1 = layers.ZeroPadding2D()
    self.conv = layers.Conv2D(filters=512, kernel_size=4, strides=1, use_bias=False)
    self.bn1 = layers.BatchNormalization()
    self.leaky_relu = layers.LeakyReLU()
    self.zero_pad2 = layers.ZeroPadding2D()
    self.final = layers.Conv2D(filters=1, kernel_size=4, strides=1)

  def call(self, input_tensor):
    """Score a pair of images.

    Args:
      input_tensor: Indexable pair ``[inp, tar]``; elements 0 and 1 are
        concatenated before being processed.

    Returns:
      The output of the final single-filter convolution.
    """
    source_image = input_tensor[0]
    target_image = input_tensor[1]
    features = self.concat([source_image, target_image])

    pipeline = (
        self.down1,
        self.down2,
        self.down3,
        self.zero_pad1,
        self.conv,
        self.bn1,
        self.leaky_relu,
        self.zero_pad2,
        self.final,
    )
    for stage in pipeline:
      features = stage(features)

    return features

In [None]:
# Build the generator; out_channels sets the channel count of its final
# transposed convolution (3 presumably means RGB -- confirm against the dataset).
gen_model = Generator(out_channels=3)

In [None]:
# Build the discriminator; its constructor takes no arguments.
disc_model = Discriminator()