In [5]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Conv2D, Activation, UpSampling2D, Dropout, BatchNormalization, Conv2DTranspose, Concatenate
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras import activations, Sequential
import graphviz

In [6]:
LAMBDA = 100

In [7]:
class Generator:
    
    def __init__(self, shape_of_image) :
        self.input_layer = Input(shape=shape_of_image)
        self.previous_layer = self.input_layer
        
        # downsampling
        d1 = self.downsample(64, 4)
        d2 = self.downsample(128, 4) 
        d3 = self.downsample(256, 4) 
        d4 = self.downsample(512, 4) 
        d5 = self.downsample(512, 4) 
        d6 = self.downsample(512, 4) 
        d7 = self.downsample(512, 4) 
        d8 = self.downsample(512, 4) 
        
        # upsampling
        self.upsample(512, 4, d7,  dropout_rate=0.5) 
        self.upsample(512, 4, d6,  dropout_rate=0.5) 
        self.upsample(512, 4, d5,  dropout_rate=0.5) 
        self.upsample(512, 4, d4) 
        self.upsample(256, 4, d3) 
        self.upsample(128, 4, d2) 
        self.upsample(64,  4, d1) 
        
        self.output_layer = Conv2DTranspose(3, 4, strides=2, padding='same', activation='tanh')(self.previous_layer)
        self.model = keras.Model(inputs=self.input_layer, outputs=self.output_layer, name="gen")
        self.optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
        
    
    def downsample(self, filters, kernel_size):
        sequential = Sequential()
        initializer = tf.random_normal_initializer(0., 0.02)
        conv = Conv2D(filters=filters, kernel_size=kernel_size, strides=2, padding='same', kernel_initializer=initializer)
        sequential.add(conv)
        norm = BatchNormalization()
        sequential.add(norm)
        relu = Activation(activations.relu)
        sequential.add(relu)
        self.previous_layer = sequential(self.previous_layer)
        return self.previous_layer
    
    def upsample(self, filters, kernel_size, skip_layer=None, dropout_rate=0.0):
        initializer = tf.random_normal_initializer(0., 0.02)
        sequential = Sequential()
        conv_transpose = Conv2DTranspose(filters=filters, kernel_size=kernel_size, strides=2, padding='same', kernel_initializer=initializer)
        sequential.add(conv_transpose)
        norm = BatchNormalization()
        sequential.add(norm)
        if (dropout_rate > 0) :
            dropout = Dropout(dropout_rate)
            sequential.add(dropout)
            
        relu = Activation(activations.relu)
        sequential.add(relu)
        self.previous_layer = sequential(self.previous_layer)
        if skip_layer is not None:
            self.previous_layer = tf.keras.layers.Concatenate()([self.previous_layer, skip_layer])

    def loss(self, discriminator_output, generator_output, target): 
        # generator_loss = cross_entropy_loss + LAMBDA * L1_loss
        
        binary_cross_entropy = BinaryCrossentropy(from_logits=True)
        # cross-entropy loss -> discriminator(generated image) & 1
        cross_entropy_loss = binary_cross_entropy(tf.ones_like(discriminator_output), discriminator_output)

        # mean absolute error -> original_image & generated image
        l1_loss = tf.reduce_mean(tf.abs(target - generator_output))

        generator_loss = cross_entropy_loss + (LAMBDA * l1_loss)
        return generator_loss, cross_entropy_loss, l1_loss

In [8]:
generator = Generator((256, 256, 3))
generator_model = generator.model

In [9]:
generator_model.summary()

Model: "gen"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 sequential_15 (Sequential)     (None, 128, 128, 64  3392        ['input_2[0][0]']                
                                )                                                                 
                                                                                                  
 sequential_16 (Sequential)     (None, 64, 64, 128)  131712      ['sequential_15[0][0]']          
                                                                                                

In [10]:
keras.utils.plot_model(generator_model, show_shapes=True, dpi=64, to_file='../images/generatorModel.png')

You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.


In [11]:
print(generator.optimizer)

<keras.optimizers.optimizer_v2.adam.Adam object at 0x000002C09ED34460>
