In [6]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Activation, BatchNormalization, Conv2D, concatenate, Input, ZeroPadding2D
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras import activations, Sequential
import graphviz

In [7]:
class Discriminator:
    """Convolutional discriminator that scores an (input, target) image pair.

    The two images are concatenated along the channel axis, passed through
    three stride-2 downsampling blocks and two zero-padded stride-1
    convolutions, and reduced to a single-channel map of raw scores. No final
    activation is applied; ``loss`` treats the outputs as logits.

    NOTE(review): the downsampling blocks use ReLU with batch normalisation in
    every block. The Pix2Pix reference discriminator uses LeakyReLU and skips
    batch norm in the first block -- confirm this deviation is intentional.
    """

    def __init__(self, shape_of_image):
        """Build the Keras model and its optimizer.

        Args:
            shape_of_image: Shape of a single image without the batch axis.
                It is used for both ``Input`` layers, so both images must
                share it (the notebook passes ``(256, 256, 3)``).
        """
        self.input_layer = Input(shape=shape_of_image, name='input_image')
        self.target_layer = Input(shape=shape_of_image, name='target_image')
        # Channel-wise concatenation doubles the channel count,
        # e.g. two (256, 256, 3) images -> (256, 256, 6).
        # `current_layer` is the running tensor that `downsample` extends.
        self.current_layer = concatenate([self.input_layer, self.target_layer])

        initializer = tf.random_normal_initializer(0., 0.02)

        # Each call consumes and replaces `self.current_layer`, so the call
        # order defines the architecture. Only the last result is needed here.
        # With a 256x256 input the spatial size goes 256 -> 128 -> 64 -> 32.
        self.downsample(64, 4)
        self.downsample(128, 4)
        d3 = self.downsample(256, 4)

        # ZeroPadding2D adds one pixel per side by default and the following
        # Conv2D uses the default 'valid' padding with a 4x4 kernel, so each
        # pad+conv pair shrinks the map by one pixel (32 -> 31 -> 30 for a
        # 256x256 input).
        zero_padding_1 = ZeroPadding2D()(d3)
        conv = Conv2D(filters=512, kernel_size=4, strides=1, kernel_initializer=initializer)(zero_padding_1)
        batch_norm = BatchNormalization()(conv)
        activation = Activation(activations.relu)(batch_norm)
        zero_padding_2 = ZeroPadding2D()(activation)

        # Single-channel map of raw scores (logits): no activation is applied.
        self.output_layer = Conv2D(filters=1, kernel_size=4, strides=1, kernel_initializer=initializer)(zero_padding_2)
        self.model = keras.Model(inputs=[self.input_layer, self.target_layer], outputs=self.output_layer)
        self.optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    def downsample(self, filters, kernel_size):
        """Append a Conv2D -> BatchNormalization -> ReLU block to the network.

        The block is applied to ``self.current_layer`` and the result is
        stored back into ``self.current_layer``, so successive calls chain
        onto each other.

        Args:
            filters: Number of output channels of the convolution.
            kernel_size: Kernel size of the convolution.

        Returns:
            The output tensor of the new block (also kept in
            ``self.current_layer``). The convolution uses stride 2 with
            'same' padding.
        """
        initializer = tf.random_normal_initializer(0., 0.02)

        block = Sequential([
            Conv2D(filters=filters, kernel_size=kernel_size, strides=2, padding='same', kernel_initializer=initializer),
            BatchNormalization(),
            Activation(activations.relu),
        ])

        self.current_layer = block(self.current_layer)
        return self.current_layer

    def loss(self, real_output, generated_output):
        """Compute the discriminator loss from its outputs on real and generated pairs.

        discriminator_loss = real_loss + generated_loss

        Args:
            real_output: Discriminator logits for the real (target) image.
            generated_output: Discriminator logits for the generated image.

        Returns:
            Sum of the two binary cross-entropy terms, each reduced with the
            default reduction of ``BinaryCrossentropy``.
        """
        # The model ends without an activation, so its outputs are logits.
        binary_cross_entropy = BinaryCrossentropy(from_logits=True)

        # Real pairs are scored against an all-ones target.
        real_loss = binary_cross_entropy(tf.ones_like(real_output), real_output)

        # Generated pairs are scored against an all-zeros target.
        generated_loss = binary_cross_entropy(tf.zeros_like(generated_output), generated_output)

        discriminator_loss = real_loss + generated_loss
        return discriminator_loss

In [8]:
# Build the discriminator for 256x256 three-channel image pairs and print
# the layer-by-layer summary of the underlying Keras model.
discriminator = Discriminator(shape_of_image=(256, 256, 3))
discriminator_model = discriminator.model
discriminator_model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 target_image (InputLayer)      [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 concatenate_1 (Concatenate)    (None, 256, 256, 6)  0           ['input_image[0][0]',            
                                                                  'target_image[0][0]']     

In [9]:
# Render the model graph (including tensor shapes) to a PNG file.
# plot_model only works when pydot and graphviz are installed.
keras.utils.plot_model(
    discriminator_model,
    to_file='../images/discriminatorModel.png',
    show_shapes=True,
    dpi=64,
)

You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.


In [10]:
# Display the optimizer instance that Discriminator.__init__ created.
discriminator.optimizer

<keras.optimizers.optimizer_v2.adam.Adam at 0x13687e2cf70>