In [3]:
import os
import time
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow.train import Checkpoint, CheckpointManager
from tensorflow.data import Dataset
from tensorflow.data.experimental import AUTOTUNE
from tensorflow.keras.utils import plot_model
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import Mean
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import TruncatedNormal, RandomNormal
from tensorflow.keras.layers import Input, Dense, Reshape, BatchNormalization, Conv2D, Conv2DTranspose, \
        LeakyReLU, Flatten, SpatialDropout2D, Dropout, MaxPool2D, GlobalAvgPool2D, Concatenate, LayerNormalization

from IPython import display

In [4]:
# Dataset directory. Set the CELEBA_DIR environment variable to point at your copy
# of the images instead of editing this cell; the original author's local path is
# kept as the fallback so existing runs behave exactly as before.
BASE_PATH = os.environ.get('CELEBA_DIR', 'C:/Users/s4571730/Downloads/img_align_celeba')
RANDOM_STATE = 7
SHUFFLE_BUFFER = 32_000
# NOTE(review): tf.image.resize and most TF image ops take (height, width); confirm
# this tuple is consumed in the order the input pipeline expects.
IMAGE_SIZE = (178, 218)
BATCH_SIZE = 128
# Shape of the noise tensor fed to the generator (used as its Flatten input_shape).
GEN_NOISE_SHAPE = (6, 5, 8)
# Presumably the number of sample images generated for progress previews — confirm.
PREDICT_COUNT = 9
# Optimizer hyper-parameters — presumably passed to Adam (learning rate, beta_1).
GEN_LR = 4e-6
GEN_BETA_1 = 0.5
DISC_LR = 1e-6
DISC_BETA_1 = 0.9
# Negative-slope coefficients for the LeakyReLU activations of each network.
GEN_RELU_ALPHA = 0.2
DISC_RELU_ALPHA = 0.3
EPOCHS = 50
# Presumably the label_smoothing value for the discriminator's BinaryCrossentropy — confirm.
DISC_LABEL_SMOOTHING = 0.25
PLOTS_DPI = 150

In [5]:
def generator_model():
    """Build the image generator as a Keras ``Sequential`` model named 'Generator'.

    A noise tensor of shape ``GEN_NOISE_SHAPE`` is flattened and projected by a
    dense layer onto a 6x5x512 feature map. Four stride-2 ``Conv2DTranspose``
    layers with ``'same'`` padding then double the spatial size each time
    (6x5 -> 12x10 -> 24x20 -> 48x40 -> 96x80). A final 5x5 transposed
    convolution with ``tanh`` produces 3 channels, so every output value lies
    in [-1, 1].

    All kernels share one ``TruncatedNormal(0, 0.02)`` initializer. The
    ``SpatialDropout2D`` layers are only active when the model runs in training
    mode.

    Returns:
        The (uncompiled) ``Sequential`` generator.
    """
    init = TruncatedNormal(mean=0.0, stddev=0.02)
    net = Sequential(name='Generator')

    def leaky(**kwargs):
        # A fresh LeakyReLU each call: used both as standalone layers and as the
        # `activation=` callable of Dense / Conv2DTranspose layers.
        return LeakyReLU(GEN_RELU_ALPHA, **kwargs)

    # NOTE(review): use_bias=False is conventional right before BatchNormalization
    # (BN's beta replaces the bias), but several layers below that have no BN
    # after them (Gen_Dense, Gen_Conv_T_1/3/7 and the tanh output) also drop the
    # bias. This is left unchanged to keep parameter counts and saved checkpoints
    # compatible. Layer names are non-contiguous (Gen_SD_3, Gen_Conv_T_6,
    # Gen_BN_4, ...) and are likewise kept as-is.
    stack = [
        # Noise -> flat vector -> 6x5x512 seed feature map.
        Flatten(input_shape=GEN_NOISE_SHAPE),
        Dense(6 * 5 * 512, use_bias=False, kernel_initializer=init,
              activation=leaky(), name='Gen_Dense'),
        Reshape((6, 5, 512), name='Gen_Reshape'),
        SpatialDropout2D(0.2, name='Gen_SD_1'),
        # 6x5 -> 12x10
        Conv2DTranspose(512, (3, 3), padding='same', activation=leaky(),
                        use_bias=False, kernel_initializer=init, name='Gen_Conv_T_1'),
        Conv2DTranspose(256, (3, 3), padding='same', strides=(2, 2),
                        use_bias=False, kernel_initializer=init, name='Gen_Conv_T_2'),
        BatchNormalization(name='Gen_BN_1'),
        leaky(name='Gen_LR_1'),
        # 12x10 -> 24x20
        Conv2DTranspose(128, (4, 4), padding='same', activation=leaky(),
                        use_bias=False, kernel_initializer=init, name='Gen_Conv_T_3'),
        Conv2DTranspose(64, (4, 4), padding='same', strides=(2, 2),
                        use_bias=False, kernel_initializer=init, name='Gen_Conv_T_4'),
        BatchNormalization(name='Gen_BN_2'),
        leaky(name='Gen_LR_2'),
        SpatialDropout2D(0.15, name='Gen_SD_3'),
        # 24x20 -> 48x40
        Conv2DTranspose(8, (6, 6), padding='same', strides=(2, 2),
                        use_bias=False, kernel_initializer=init, name='Gen_Conv_T_6'),
        BatchNormalization(name='Gen_BN_4'),
        leaky(name='Gen_LR_4'),
        SpatialDropout2D(0.15, name='Gen_SD_4'),
        # 48x40 -> 96x80, then project to 3 channels squashed into [-1, 1].
        Conv2DTranspose(8, (7, 7), padding='same', activation=leaky(), use_bias=False,
                        kernel_initializer=init, strides=(2, 2), name='Gen_Conv_T_7'),
        Conv2DTranspose(3, (5, 5), padding='same', kernel_initializer=init, use_bias=False,
                        activation='tanh', name='Gen_Conv_T_8'),
    ]
    for layer in stack:
        net.add(layer)
    return net
    
# Instantiate the generator once and print its layer-by-layer summary
# (output shapes and parameter counts, rendered in the cell output below).
generator = generator_model()
generator.summary()

Model: "Generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 240)               0         
_________________________________________________________________
Gen_Dense (Dense)            (None, 15360)             3686400   
_________________________________________________________________
Gen_Reshape (Reshape)        (None, 6, 5, 512)         0         
_________________________________________________________________
Gen_SD_1 (SpatialDropout2D)  (None, 6, 5, 512)         0         
_________________________________________________________________
Gen_Conv_T_1 (Conv2DTranspos (None, 6, 5, 512)         2359296   
_________________________________________________________________
Gen_Conv_T_2 (Conv2DTranspos (None, 12, 10, 256)       1179648   
_________________________________________________________________
Gen_BN_1 (BatchNormalization (None, 12, 10, 256)       10