# In [8]:  (Jupyter cell prompt from the notebook export)
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, LeakyReLU
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist
import datetime
import matplotlib.pyplot as plt
import sys
import numpy as np

img_shape = (28, 28, 1)
# Load the MNIST training images (labels and test split are discarded), scale
# pixels from [0, 255] to [-1, 1] to match the generator's tanh range, and
# append a trailing channel axis.
(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train / 127.5 - 1.)[..., np.newaxis]
# Define the generator network
def build_generator():
    """Build the MLP generator that maps a 100-d noise vector to an image.

    The stack is three Dense blocks (256 -> 512 -> 1024 units), each followed by
    LeakyReLU(0.2) and BatchNormalization(momentum=0.8). A tanh Dense layer with
    ``np.prod(img_shape)`` units is then reshaped to the module-level
    ``img_shape``.

    Returns:
        A functional ``Model`` whose input is a ``(100,)`` noise tensor and
        whose output is the generated image.
    """
    latent_shape = (100,)

    stack = Sequential([
        Dense(256, input_shape=latent_shape),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(512),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(1024),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(np.prod(img_shape), activation='tanh'),
        Reshape(img_shape),
    ])

    z = Input(shape=latent_shape)
    return Model(z, stack(z))

# Define the discriminator network
def build_discriminator():
    """Build a shared-weight convolutional discriminator with two image inputs.

    A single Sequential conv stack (28x28 -> 14 -> 7 -> zero-pad to 8 -> 4 -> 4
    spatially, ending in a sigmoid Dense(1)) is applied to each of two
    ``(28, 28, 1)`` inputs. Both branches therefore use the same weights.

    Returns:
        A functional ``Model`` taking ``[img1, img2]`` and returning
        ``[score1, score2]``, one sigmoid score per input image.
    """
    img_shape = (28, 28, 1)

    trunk = Sequential([
        Conv2D(32, kernel_size=3, strides=2, input_shape=img_shape, padding="same"),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        Conv2D(64, kernel_size=3, strides=2, padding="same"),
        ZeroPadding2D(padding=((0, 1), (0, 1))),
        BatchNormalization(momentum=0.8),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        Conv2D(128, kernel_size=3, strides=2, padding="same"),
        BatchNormalization(momentum=0.8),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        Conv2D(256, kernel_size=3, strides=1, padding="same"),
        BatchNormalization(momentum=0.8),
        LeakyReLU(alpha=0.2),
        Dropout(0.25),
        Flatten(),
        Dense(1, activation='sigmoid'),
    ])

    # Both inputs are created first, then each is routed through the same trunk.
    inputs = [Input(shape=img_shape) for _ in range(2)]
    outputs = [trunk(x) for x in inputs]
    return Model(inputs, outputs)

# Build and compile the GAN
def build_gan(generator, discriminator):
    """Compile ``discriminator`` and build the stacked generator -> discriminator model.

    Side effects:
        * Compiles ``discriminator`` with two binary cross-entropy losses and
          per-output accuracy, using its own Adam optimizer.
        * Sets ``discriminator.trainable = False``. This happens after the
          discriminator was compiled, so only the combined model treats it as
          frozen.

    Args:
        generator: model mapping a ``(100,)`` noise tensor to an image.
        discriminator: two-input / two-output model such as
            ``build_discriminator()`` returns.

    Returns:
        The compiled ``combined`` model: noise -> ``[score1, score2]``. Training
        it updates only the generator's weights.
    """
    # Give each compiled model its own optimizer. On TF >= 2.11 a Keras optimizer
    # is built for the first variable set it updates. A shared instance then
    # rejects the generator's variables ("optimizer cannot recognize variable"),
    # surfacing as an error inside train_on_batch. This is the likely cause of
    # the StagingError this notebook hit.
    d_optimizer = Adam(0.0002, 0.5)
    g_optimizer = Adam(0.0002, 0.5)

    # One metrics list per output: accepted by tf.keras 2.x and required for
    # multi-output models by newer Keras versions.
    discriminator.compile(loss=['binary_crossentropy', 'binary_crossentropy'],
                          optimizer=d_optimizer,
                          metrics=[['accuracy'], ['accuracy']])

    noise = Input(shape=(100,))
    # Generate once and feed the same image to both discriminator inputs. Calling
    # the generator twice on identical noise only duplicated the work and the
    # BatchNorm updates.
    img = generator(noise)

    # Freeze the discriminator for the combined model only. It was compiled above
    # while trainable, so its own train_on_batch still updates it under tf.keras
    # 2.x semantics. NOTE(review): Keras 3 re-reads `trainable` at train time;
    # confirm if upgrading.
    discriminator.trainable = False

    validity1, validity2 = discriminator([img, img])

    combined = Model(noise, [validity1, validity2])
    combined.compile(loss=['binary_crossentropy', 'binary_crossentropy'],
                     optimizer=g_optimizer)

    return combined

# Define the training function
def train_gan(generator, discriminator, combined, epochs, batch_size, sample_interval, data=None):
    """Run ``epochs`` alternating discriminator/generator training steps.

    Each step draws one random batch of real images. The discriminator trains
    on a batch of real images labelled 1, then on a batch of generated images
    labelled 0. The generator (via ``combined``) trains to make the
    discriminator output 1 for its samples.

    Args:
        generator: noise ``(100,)`` -> image model.
        discriminator: compiled two-input / two-output discriminator.
        combined: compiled noise -> ``[score1, score2]`` model from ``build_gan``.
        epochs: number of training steps (one batch each, not full passes).
        batch_size: images per step.
        sample_interval: save a sample grid every this many steps.
        data: optional image array of shape ``(N, 28, 28, 1)`` scaled to
            [-1, 1]. When ``None``, MNIST is loaded and scaled here.
    """
    if data is None:
        (data, _), (_, _) = mnist.load_data()
        # Scale to [-1, 1] to match the generator's tanh output. The previous
        # extra rescale to [0, 1] let the discriminator tell real from fake by
        # value range alone.
        data = data / 127.5 - 1.
        data = np.expand_dims(data, axis=3)

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):

        # Select a random batch of real images
        idx = np.random.randint(0, data.shape[0], batch_size)
        imgs = data[idx]

        # Generate a batch of new images (verbose=0 suppresses per-step progress bars)
        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise, verbose=0)

        # Train the discriminator on an all-real batch, then an all-fake batch.
        # Both of its inputs share weights, so the same images go to both.
        # (Previously the same real images were labelled both valid and fake.)
        d_real = discriminator.train_on_batch([imgs, imgs], [valid, valid], return_dict=True)
        d_fake = discriminator.train_on_batch([gen_imgs, gen_imgs], [fake, fake], return_dict=True)
        d_loss = 0.5 * (d_real["loss"] + d_fake["loss"])
        # Average every per-output accuracy metric. Index-based access picked a
        # loss value instead of accuracy.
        accs = [v for res in (d_real, d_fake) for k, v in res.items() if "accuracy" in k]
        d_acc = float(np.mean(accs)) if accs else float("nan")

        # Train the generator (discriminator frozen inside `combined`)
        g_loss = combined.train_on_batch(noise, [valid, valid], return_dict=True)["loss"]

        # Print the progress
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss, 100 * d_acc, g_loss))

        # Save generated images at sample interval
        if epoch % sample_interval == 0:
            save_imgs(generator, epoch)

# Define a function to save generated images
def save_imgs(generator, epoch, rows=2, cols=2):
    """Sample a ``rows`` x ``cols`` grid from ``generator`` and save it to disk.

    The figure is written to ``gan_mnist_<epoch>.png`` in the working
    directory and then closed.

    Args:
        generator: noise ``(100,)`` -> image model with tanh output in [-1, 1].
        epoch: step number used in the output filename.
        rows: grid rows (default 2, matching the original layout).
        cols: grid columns (default 2, matching the original layout).
    """
    noise = np.random.normal(0, 1, (rows * cols, 100))
    gen_imgs = generator.predict(noise, verbose=0)
    # Map tanh output from [-1, 1] back to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    # squeeze=False keeps `axs` 2-D so a 1-row or 1-column grid still works.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    # axs.flat walks the grid row-major, same order as the original i/j loops.
    for cnt, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig("gan_mnist_%d.png" % epoch)
    # Close this specific figure rather than whatever pyplot considers "current".
    plt.close(fig)

# Build and train the GAN
generator = build_generator()
discriminator = build_discriminator()
combined = build_gan(generator, discriminator)
# sample_interval must be smaller than epochs; with 200 > 100 only the untrained
# generator's output at step 0 was ever saved.
train_gan(generator, discriminator, combined, epochs=100, batch_size=32, sample_interval=20)



# Captured cell output from the original run: "StagingError: ignored"