Since the critic is much larger than in the previous version, it needs to be trained more slowly to avoid generator training issues.

In [None]:
from tensorflow.keras import datasets, callbacks, optimizers

from model import WassersteinGANGradientPenalty, build_discriminator_large, build_generator_large
from custom_callbacks import ImageGenerator
from prepare_data import prepare_fashion_mnist_data

import logging
logging.basicConfig(
    format="[%(asctime)s][%(levelname)s] %(message)s",
    level=logging.INFO,
)


# --- Hyperparameters -------------------------------------------------------
# Latent dimension: passed to the generator builder, the GAN (latent_dim)
# and the ImageGenerator callback.
N_Z = 64
# Passed as `gp_weight` to WassersteinGANGradientPenalty (gradient-penalty
# coefficient).
GP_WEIGHT = 10
# Learning rate of the generator's Adam optimizer.
LEARNING_RATE_G = 1e-4
# Learning rate of the critic's RMSprop optimizer.
# NOTE(review): the notebook text says the larger critic should be trained
# more slowly, yet this equals LEARNING_RATE_G — confirm this is intended.
LEARNING_RATE_D = 1e-4
# beta_1 of the generator's Adam optimizer (the critic uses RMSprop).
ADAM_BETA_1 = 0.5
EPOCHS = 400
BATCH_SIZE = 512


# --- Data ------------------------------------------------------------------
logging.info("Preparing the data...")
# x_test is not used by the training cells visible here.
(x_train, x_test) = prepare_fashion_mnist_data()

# --- Models ----------------------------------------------------------------
logging.info("Build generator...")
generator = build_generator_large(N_Z)
logging.info("Build discriminator...")
discriminator = build_discriminator_large()

logging.info("Build GAN...")
# The discriminator network is used as the WGAN-GP critic.
gan = WassersteinGANGradientPenalty(
    critic=discriminator,
    generator=generator,
    latent_dim=N_Z,
    gp_weight=GP_WEIGHT,
)

# Critic and generator get separate optimizers: RMSprop for the critic,
# Adam (with a custom beta_1) for the generator.
gan.compile(
    c_optimizer=optimizers.RMSprop(learning_rate=LEARNING_RATE_D),
    g_optimizer=optimizers.Adam(learning_rate=LEARNING_RATE_G, beta_1=ADAM_BETA_1),
)


# Save the model weights once per epoch. The path has no epoch placeholder,
# so each save targets the same file path.
# NOTE(review): the ".ckpt" suffix works with Keras 2 (tf.keras); Keras 3
# (TF >= 2.16) requires a ".weights.h5" suffix when save_weights_only=True —
# confirm the TensorFlow/Keras version in use.
model_checkpoint_callback = callbacks.ModelCheckpoint(
    filepath="./checkpoint/checkpoint.ckpt",
    save_freq="epoch",
    save_weights_only=True,
    verbose=0,
)
# Write training logs to ./logs for TensorBoard.
tensorboard_callback = callbacks.TensorBoard(log_dir="./logs")

logging.info("Fit GAN...")
gan.fit(
    x_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[
        model_checkpoint_callback,
        tensorboard_callback,
        # Project callback (custom_callbacks); presumably renders num_img
        # generated samples during training — confirm in custom_callbacks.
        ImageGenerator(latent_dim=N_Z, num_img=10),
    ],
)
