In [1]:
import os

import numpy as np
import utils as ut
import log
import matplotlib.pyplot as plt
from keras.datasets import fashion_mnist
from tensorflow.keras.layers import Dense, Flatten, Reshape, LeakyReLU, BatchNormalization, Input
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam

Init Plugin
Init Graph Optimizer
Init Kernel


In [2]:
# Module-level logger from the project's `log` helper; train_gan and save_models
# below report progress through it.
logger = log.get_logger(__name__)

In [7]:
@ut.timer
def load_data():
    """Load the Fashion-MNIST training images, scaled to [-1, 1].

    Only the training images are kept; labels and the test split are discarded
    because the GAN below is trained without labels.

    Returns:
        float32 array of shape (num_images, 28, 28, 1) with values in [-1, 1],
        matching the generator's ``tanh`` output range.
    """
    (X_train, _), (_, _) = fashion_mnist.load_data()
    # Cast before scaling: dividing the raw uint8 array directly promotes it to
    # float64, doubling memory for no benefit since Keras trains in float32.
    X_train = X_train.astype(np.float32) / 127.5 - 1.0
    # Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
    return np.expand_dims(X_train, axis=-1)


# Generator
@ut.timer
def create_generator(latent_dim=100, img_shape=(28, 28, 1)):
    """Build the fully-connected generator that maps noise vectors to images.

    Args:
        latent_dim: Length of the input noise vector. Defaults to 100 to match
            the noise size used elsewhere in this notebook.
        img_shape: Shape of one generated image. Defaults to (28, 28, 1), the
            Fashion-MNIST shape produced by ``load_data``.

    Returns:
        An uncompiled ``Sequential`` model. Its last Dense layer uses ``tanh``,
        so generated pixels lie in [-1, 1], the range ``load_data`` scales real
        images to.
    """
    generator = Sequential([
        # Declare the input with an explicit Input layer (as create_discriminator
        # does) instead of the legacy ``input_dim=`` kwarg on the first Dense.
        Input(shape=(latent_dim,)),
        Dense(128),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(256),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(512),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(1024),
        LeakyReLU(alpha=0.2),
        # One unit per output pixel, then reshape the flat vector into an image.
        Dense(int(np.prod(img_shape)), activation='tanh'),
        Reshape(img_shape),
    ])

    return generator


# Discriminator
@ut.timer
def create_discriminator():
    """Build the fully-connected discriminator for 28x28x1 images.

    Returns:
        An uncompiled ``Sequential`` model that flattens the image, applies two
        Dense + LeakyReLU hidden layers (512 then 256 units) and ends in a single
        sigmoid unit. Its output is a score in [0, 1], trained towards 1 for real
        images and 0 for generated ones.
    """
    hidden_sizes = (512, 256)

    model = Sequential()
    model.add(Input(shape=(28, 28, 1)))
    model.add(Flatten())
    for width in hidden_sizes:
        model.add(Dense(width))
        model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    return model


# GAN
@ut.timer
def create_gan(discriminator, generator):
    """Stack generator -> discriminator into the combined model that trains the generator.

    Side effect: sets ``discriminator.trainable = False`` on the caller's
    discriminator object. Keras applies ``trainable`` changes only to models
    compiled afterwards. So the combined model (compiled later) leaves the
    discriminator's weights untouched, while a discriminator compiled before this
    call still trains on its own.

    Returns:
        An uncompiled functional ``Model`` taking a length-100 noise vector and
        returning the discriminator's real/fake score for the generated image.
    """
    discriminator.trainable = False
    noise_in = Input(shape=(100,))
    validity = discriminator(generator(noise_in))
    return Model(noise_in, validity)


@ut.timer
def compile_models(discriminator, generator):
    """Compile the discriminator, then build and compile the combined GAN.

    The discriminator is compiled *before* ``create_gan`` freezes it. As a
    result, ``discriminator.train_on_batch`` still updates its weights, while
    the combined model only updates the generator.

    Returns:
        The compiled combined GAN model.
    """
    def _adam():
        # A fresh optimizer per model: the two models must not share optimizer state.
        return Adam(learning_rate=0.0002, beta_1=0.5)

    discriminator.compile(optimizer=_adam(), loss='binary_crossentropy', metrics=['accuracy'])
    combined = create_gan(discriminator, generator)
    combined.compile(optimizer=_adam(), loss='binary_crossentropy')
    return combined


@ut.timer
def sample_images(generator, epoch, img_out_path, datetime, rows=5, cols=5):
    """Save a ``rows`` x ``cols`` grid of generator samples as a PNG.

    Args:
        generator: Generator model; fed standard-normal noise whose length is
            read from the model's own input shape.
        epoch: Training step number; used only in the output filename.
        img_out_path: Directory the PNG is written into.
        datetime: Run timestamp string used as the filename prefix (a string,
            not the stdlib module; the name is kept for caller compatibility).
        rows: Number of grid rows (default 5).
        cols: Number of grid columns (default 5).
    """
    latent_dim = generator.input_shape[-1]
    noise = np.random.normal(0, 1, (rows * cols, latent_dim))
    # verbose=0: otherwise every call prints a progress bar into the cell output.
    gen_imgs = generator.predict(noise, verbose=0)

    # Rescale images from the generator's tanh range [-1, 1] to [0, 1] for imshow.
    gen_imgs = 0.5 * gen_imgs + 0.5

    # squeeze=False keeps `axs` 2-D even for a 1-row or 1-column grid.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    try:
        for ax, img in zip(axs.flat, gen_imgs):
            ax.imshow(img[:, :, 0], cmap='gray')
            ax.axis('off')
        fig.savefig(os.path.join(img_out_path, f"{datetime}_epoch_{epoch}.png"))
    finally:
        # Close this specific figure (even if saving failed) so repeated calls
        # during training do not accumulate open figures.
        plt.close(fig)


# Train GAN
@ut.timer
def train_gan(X_train, generator, discriminator, gan, epochs, batch_size, sample_interval, img_out_path, datetime):
    """Train the GAN with alternating discriminator / generator updates.

    Note: each "epoch" here is ONE training step on a single random minibatch
    (sampled with replacement), not a full pass over ``X_train``.

    Args:
        X_train: Real images, presumably as returned by ``load_data``
            (values in [-1, 1], shape (N, 28, 28, 1)). TODO confirm for other callers.
        generator, discriminator, gan: Models from create_generator,
            create_discriminator and compile_models; trained in place.
        epochs: Number of training steps.
        batch_size: Images per discriminator batch (real and fake each) and
            noise vectors per generator batch.
        sample_interval: Log losses and save a sample grid every this many steps.
        img_out_path: Directory for sample grids.
        datetime: Run timestamp string used in sample filenames.

    Returns:
        The (generator, discriminator, gan) tuple, i.e. the same objects passed in.

    Raises:
        ValueError: If ``sample_interval`` is not a positive integer (it would
            otherwise fail with ZeroDivisionError after the first training step).
    """
    if sample_interval < 1:
        raise ValueError(f"sample_interval must be a positive integer, got {sample_interval}")

    latent_dim = generator.input_shape[-1]
    # Target labels are identical every step, so build them once.
    real_y = np.ones((batch_size, 1))
    fake_y = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # Train discriminator on one batch of real and one batch of generated images.
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_imgs = X_train[idx]
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # verbose=0: avoid printing a progress bar on every training step.
        gen_imgs = generator.predict(noise, verbose=0)

        d_loss_real = discriminator.train_on_batch(real_imgs, real_y)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_y)
        # Element-wise mean of the real/fake results; with the accuracy metric
        # from compile_models this is [loss, accuracy].
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train generator: label fresh fakes as "real" so the generator is rewarded
        # for fooling the discriminator (frozen inside `gan`, see create_gan).
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, real_y)

        if epoch % sample_interval == 0:
            logger.info(f"Epoch {epoch}, D-Loss: {d_loss[0]}, G-Loss: {g_loss}")
            sample_images(generator, epoch, img_out_path, datetime)
    return generator, discriminator, gan


def save_models(generator, discriminator, gan, model_out_path, datetime):
    """Save the generator, discriminator and combined GAN as HDF5 files.

    Files are written into the directory ``model_out_path`` as
    ``<datetime>_generator.h5``, ``<datetime>_discriminator.h5`` and
    ``<datetime>_gan.h5``.

    NOTE(review): create_gan sets ``discriminator.trainable = False``. The saved
    discriminator presumably records that flag, so reset it before resuming
    discriminator training from a reloaded copy. Confirm.
    """
    for name, model in (("generator", generator), ("discriminator", discriminator), ("gan", gan)):
        # os.path.join inserts a separator only when needed. Plain string
        # concatenation wrote "<dir><datetime>_..." into the parent directory
        # whenever model_out_path lacked a trailing slash.
        model.save(os.path.join(model_out_path, f"{datetime}_{name}.h5"))
    logger.info("Models saved successfully")

In [8]:
# Load config and data, build and compile the models.
conf = ut.load_config()
X_train = load_data()
generator = create_generator()
discriminator = create_discriminator()
gan = compile_models(discriminator, generator)
dt = ut.get_datetime()

# Training hyperparameters and output locations come from the `a3` config section.
gan_params = conf.a3.gan_params
paths = conf.a3.paths

# Train the GAN.
generator, discriminator, gan = train_gan(
    X_train,
    generator,
    discriminator,
    gan,
    gan_params.epochs,
    gan_params.batch_size,
    gan_params.sample_interval,
    paths.training_inspection_plots,
    dt,
)

# Persist all three models, tagged with this run's timestamp.
save_models(generator, discriminator, gan, paths.model, dt)

06-May-23 14:36:55 - INFO - Starting 'load_config'.
06-May-23 14:36:56 - INFO - Finished 'load_config' in 0.0515 secs.
06-May-23 14:36:56 - INFO - Starting 'load_data'.
06-May-23 14:36:56 - INFO - Finished 'load_data' in 0.4083 secs.
06-May-23 14:36:56 - INFO - Starting 'create_generator'.
06-May-23 14:36:56 - INFO - Finished 'create_generator' in 0.0397 secs.
06-May-23 14:36:56 - INFO - Starting 'create_discriminator'.




06-May-23 14:36:56 - INFO - Finished 'create_discriminator' in 0.0141 secs.
06-May-23 14:36:56 - INFO - Starting 'compile_models'.
06-May-23 14:36:56 - INFO - Starting 'create_gan'.
06-May-23 14:36:56 - INFO - Finished 'create_gan' in 0.0232 secs.
06-May-23 14:36:56 - INFO - Finished 'compile_models' in 0.0295 secs.
06-May-23 14:36:56 - INFO - Starting 'train_gan'.
2023-05-06 14:36:56.547798: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2023-05-06 14:36:56.692833: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.




2023-05-06 14:36:57.041730: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
06-May-23 14:36:57 - INFO - Epoch 0, D-Loss: 0.7250036001205444, G-Loss: 0.657907247543335
06-May-23 14:36:57 - INFO - Starting 'sample_images'.
2023-05-06 14:36:57.261960: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
06-May-23 14:36:57 - INFO - Finished 'sample_images' in 0.3681 secs.
06-May-23 14:36:57 - INFO - Epoch 4, D-Loss: 0.6821849551051855, G-Loss: 0.16276264190673828
06-May-23 14:36:57 - INFO - Starting 'sample_images'.
06-May-23 14:36:57 - INFO - Finished 'sample_images' in 0.1584 secs.
06-May-23 14:36:58 - INFO - Epoch 8, D-Loss: 0.6865379419177771, G-Loss: 0.41157853603363037
06-May-23 14:36:58 - INFO - Starting 'sample_images'.
06-May-23 14:36:58 - INFO - Finished 'sample_images' in 0.2436 secs.
06-May-23 14:36:58 - INFO - Epoch 12, D-Loss