In [1]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization
from keras.layers.activation import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
from keras.models import load_model

In [2]:
#Define input image dimensions
#Large images take too much time and resources.
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)

In [3]:
def train(epochs, batch_size=128, save_interval=50):

    # Load the dataset
    (X_train, _), (_, _) = mnist.load_data()

    # Convert to float and Rescale -1 to 1 (Can also do 0 to 1)
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5

#Add channels dimension. As the input to our gen and discr. has a shape 28x28x1.
    X_train = np.expand_dims(X_train, axis=3) 

    half_batch = int(batch_size / 2)

    for epoch in range(epochs):

        idx = np.random.randint(0, X_train.shape[0], half_batch)
        imgs = X_train[idx]
        noise = np.random.normal(0, 1, (half_batch, 100))
        gen_imgs = generator.predict(noise)

        d_loss_real = discriminator.train_on_batch(imgs, np.ones((half_batch, 1)))
        d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((half_batch, 1)))
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) 

        noise = np.random.normal(0, 1, (batch_size, 100)) 
        valid_y = np.array([1] * batch_size) #Creates an array of all ones of size=batch size
        g_loss = combined.train_on_batch(noise, valid_y)


        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

        # If at save interval => save generated image samples
        if epoch % save_interval == 0:
            save_imgs(epoch)

def save_imgs(epoch):
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    gen_imgs = generator.predict(noise)

    # Rescale images 0 - 1
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("resume3/mnist_%d.png" % epoch)
    plt.close()

In [4]:
# load the discriminator 
discriminator = load_model("discriminator_model.h5")
optimizer_weights = np.load("discriminator_optimizer.npy", allow_pickle=True)
discriminator.optimizer.set_weights(optimizer_weights)

In [5]:
generator = load_model("generator_model.h5")

In [6]:
combined = load_model("gan_model.h5")

In [8]:
train(epochs=2000, batch_size=32, save_interval=100)

0 [D loss: 0.000905, acc.: 100.00%] [G loss: 0.000713]
1 [D loss: 0.001298, acc.: 100.00%] [G loss: 0.000696]
2 [D loss: 0.000609, acc.: 100.00%] [G loss: 0.000696]
3 [D loss: 0.000620, acc.: 100.00%] [G loss: 0.000708]
4 [D loss: 0.000389, acc.: 100.00%] [G loss: 0.000695]
5 [D loss: 0.000101, acc.: 100.00%] [G loss: 0.000696]
6 [D loss: 0.001250, acc.: 100.00%] [G loss: 0.000726]
7 [D loss: 0.004850, acc.: 100.00%] [G loss: 0.000721]
8 [D loss: 0.000576, acc.: 100.00%] [G loss: 0.000669]
9 [D loss: 0.000918, acc.: 100.00%] [G loss: 0.000688]
10 [D loss: 0.004434, acc.: 100.00%] [G loss: 0.000711]
11 [D loss: 0.000119, acc.: 100.00%] [G loss: 0.000713]
12 [D loss: 0.000808, acc.: 100.00%] [G loss: 0.000676]
13 [D loss: 0.018741, acc.: 100.00%] [G loss: 0.000689]
14 [D loss: 0.001162, acc.: 100.00%] [G loss: 0.000687]
15 [D loss: 0.001167, acc.: 100.00%] [G loss: 0.000696]
16 [D loss: 0.000943, acc.: 100.00%] [G loss: 0.000690]
17 [D loss: 0.007916, acc.: 100.00%] [G loss: 0.000710]
18

In [9]:
# save the generator, discriminator and gan model
discriminator.trainable = True
discriminator.save("discriminator_model.h5")
discriminator.save("discriminator_model")

optimizer_weights = discriminator.optimizer.get_weights()
np.save("discriminator_optimizer.npy", optimizer_weights)

INFO:tensorflow:Assets written to: discriminator_model\assets


  arr = np.asanyarray(arr)


In [10]:
generator.save("generator_model.h5")
generator.save("generator_model")

gen_optimizer_weights = generator.optimizer.get_weights()
np.save("generator_optimizer.npy", gen_optimizer_weights)



INFO:tensorflow:Assets written to: generator_model\assets


INFO:tensorflow:Assets written to: generator_model\assets


In [11]:
discriminator.trainable = False
combined.save("gan_model.h5")
combined.save("gan_model")

gan_optimizer_weights = combined.optimizer.get_weights()
np.save("gan_optimizer.npy", gan_optimizer_weights)



INFO:tensorflow:Assets written to: gan_model\assets


INFO:tensorflow:Assets written to: gan_model\assets
  arr = np.asanyarray(arr)
