In [1]:
from tensorflow.keras.datasets.mnist import load_data
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Conv2D,LeakyReLU,Dropout,Flatten,Conv2DTranspose,Conv2DTranspose,Reshape
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt

In [2]:
def discriminator(input_shape):
    """Build and compile the convolutional discriminator.

    The network applies two stride-2 3x3 convolutions (64 filters each), each
    followed by LeakyReLU and 40% dropout, then flattens into a single
    sigmoid unit.

    Args:
        input_shape: Shape of one input image, forwarded to the first
            ``Conv2D`` layer (the notebook passes ``(28, 28, 1)``).

    Returns:
        A ``Sequential`` model compiled with binary cross-entropy loss, the
        Adam optimizer and an ``accuracy`` metric.
    """
    model = Sequential()
    model.add(Conv2D(64,(3,3),strides=(2,2),padding="same",input_shape=input_shape))
    model.add(LeakyReLU(alpha=.2))
    model.add(Dropout(.4))
    model.add(Conv2D(64,(3,3),strides=(2,2),padding="same"))
    model.add(LeakyReLU(alpha=.2))
    model.add(Dropout(.4))
    model.add(Flatten())
    model.add(Dense(1,activation="sigmoid"))
    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and is rejected by newer Keras optimizers.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss="binary_crossentropy",optimizer=opt,metrics=["accuracy"])
    return model

In [3]:
def load_real_samples():
    """Load the MNIST training images prepared for the discriminator.

    Only the training images are kept; labels and the test split are
    discarded. A trailing channel axis is appended, the array is cast to
    float32 and every value is divided by 255.0.

    Returns:
        The prepared image array.
    """
    (train_images, _), _ = load_data()
    with_channel = np.expand_dims(train_images, axis=-1)
    return with_channel.astype("float32") / 255.0
    
def generate_real_samples(dataset,batch_size):
    """Pick a random batch of real images and label each one 1.

    Args:
        dataset: Indexable array of images; rows are sampled with replacement.
        batch_size: Number of images to draw.

    Returns:
        Tuple ``(x, y)`` where ``x`` holds the selected rows of ``dataset``
        and ``y`` is a ``(batch_size, 1)`` array of ones.
    """
    chosen = np.random.randint(0, len(dataset), batch_size)
    labels = np.ones((batch_size, 1))
    return dataset[chosen], labels

def generate_fake_samples(batch_size):
    """Create a batch of uniform-noise 28x28x1 images, each labelled 0.

    Args:
        batch_size: Number of noise images to produce.

    Returns:
        Tuple ``(x, y)`` where ``x`` has shape ``(batch_size, 28, 28, 1)``
        with values drawn uniformly from [0, 1) and ``y`` is a
        ``(batch_size, 1)`` array of zeros.
    """
    # Drawing directly in the target shape consumes the same random stream
    # as drawing a flat vector and reshaping it.
    noise_images = np.random.rand(batch_size, 28, 28, 1)
    labels = np.zeros((batch_size, 1))
    return noise_images, labels

def train_discriminator(model,dataset,iteration,steps):
    """Train the discriminator alone on real images versus uniform-noise fakes.

    Each iteration performs one update on a half batch of real images and one
    update on a half batch of noise images, then prints both accuracies.

    Args:
        model: Compiled discriminator whose ``train_on_batch`` returns a
            ``(loss, accuracy)`` pair.
        dataset: Array of real images to sample from.
        iteration: Number of training iterations to run.
        steps: Total number of samples per iteration, split evenly between
            real and fake images.
    """
    # The original computed this value but never used it, so every update ran
    # on `steps` samples of each class instead of half; `train` in this
    # notebook splits its batch the same way. Keep at least one sample so a
    # tiny `steps` does not yield an empty batch.
    half_batch = max(1, steps//2)
    for i in range(iteration):
        x_real,y_real = generate_real_samples(dataset,half_batch)
        _,real_accuracy  = model.train_on_batch(x_real,y_real)

        x_fake,y_fake = generate_fake_samples(half_batch)
        _,fake_accuracy = model.train_on_batch(x_fake,y_fake)
        print('>%d real=%.0f%% fake=%.0f%%' % (i+1, real_accuracy*100, fake_accuracy*100))


In [5]:
def generator(latent_dim):
    """Build the (uncompiled) generator network.

    A dense layer expands the latent vector into a 7x7x128 feature map, two
    stride-2 transposed convolutions upsample it to 14x14 and then 28x28, and
    a final 7x7 convolution with a sigmoid activation produces one channel.

    Args:
        latent_dim: Length of the latent input vector.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    base_units = 128 * 7 * 7
    layers = [
        # project the latent vector and reshape it into a 7x7 feature map
        Dense(base_units, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        Reshape((7, 7, 128)),
        # 7x7 -> 14x14
        Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'),
        LeakyReLU(alpha=0.2),
        # 14x14 -> 28x28
        Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'),
        LeakyReLU(alpha=0.2),
        # collapse to a single output channel
        Conv2D(1, (7,7), activation='sigmoid', padding='same'),
    ]
    model = Sequential()
    for layer in layers:
        model.add(layer)
    return model

In [6]:
def generate_latent_points(latent_dim,batch_size):
    """Sample standard-normal latent vectors for the generator.

    Args:
        latent_dim: Length of each latent vector.
        batch_size: Number of vectors to sample.

    Returns:
        Array of shape ``(batch_size, latent_dim)``.
    """
    # Sampling in the final shape draws the same values, in the same order,
    # as sampling a flat vector and reshaping it.
    return np.random.randn(batch_size, latent_dim)

def generate_fake_images_and_labels(g_model,latent_dim,batch_size):
    """Run the generator on fresh latent points and label the output 0.

    Args:
        g_model: Generator model exposing ``predict``.
        latent_dim: Length of each latent vector.
        batch_size: Number of images to generate.

    Returns:
        Tuple ``(x, y)`` where ``x`` is the generator's prediction for the
        sampled latent points and ``y`` is a ``(batch_size, 1)`` array of
        zeros.
    """
    noise = generate_latent_points(latent_dim, batch_size)
    generated = g_model.predict(noise)
    return generated, np.zeros((batch_size, 1))

In [7]:
def GanModel(g_model,d_model):
    """Stack the generator and discriminator into one compiled model.

    Note that this sets ``d_model.trainable = False`` on the object passed
    in, i.e. it mutates the caller's discriminator.

    Args:
        g_model: Generator model; becomes the first stage.
        d_model: Discriminator model; becomes the second stage.

    Returns:
        A ``Sequential`` model (generator followed by discriminator) compiled
        with binary cross-entropy loss and the Adam optimizer.
    """
    # Freeze the discriminator inside the combined model so that training the
    # stack is meant to update only the generator's weights.
    d_model.trainable = False
    model = Sequential()
    model.add(g_model)
    model.add(d_model)
    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and is rejected by newer Keras optimizers.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

In [8]:
def train_gan(gan_model, latent_dim, n_epochs=100, n_batch=256):
    """Update the combined model using latent points labelled as real.

    Args:
        gan_model: Compiled combined generator/discriminator model.
        latent_dim: Length of each latent vector.
        n_epochs: Number of single-batch updates to perform.
        n_batch: Number of latent points per update.
    """
    for _ in range(n_epochs):
        noise = generate_latent_points(latent_dim, n_batch)
        # Target label 1 for every generated sample.
        real_targets = np.ones((n_batch, 1))
        gan_model.train_on_batch(noise, real_targets)

In [9]:
def save_plot(examples, epoch, n=10):
    """Write an n-by-n grid of images to ``generated_plot_eNNN.png``.

    Args:
        examples: Image array indexed as ``examples[i, :, :, 0]``; the first
            ``n * n`` entries are drawn.
        epoch: Zero-based epoch index; the file name uses ``epoch + 1``.
        n: Number of rows and columns in the grid.
    """
    cell_count = n * n
    for cell in range(cell_count):
        # subplot positions are 1-based
        plt.subplot(n, n, cell + 1)
        plt.axis('off')
        # reversed grayscale colormap
        plt.imshow(examples[cell, :, :, 0], cmap='gray_r')
    target = 'generated_plot_e%03d.png' % (epoch+1)
    plt.savefig(target)
    plt.close()
 
def summarize_performance(epoch, g_model, d_model, dataset, latent_dim, n_samples=100):
    """Report discriminator accuracy, save a sample grid and save the generator.

    Args:
        epoch: Zero-based epoch index; output file names use ``epoch + 1``.
        g_model: Generator model; used to create fakes and then saved.
        d_model: Compiled discriminator whose ``evaluate`` returns
            ``(loss, accuracy)``.
        dataset: Array of real images to sample from.
        latent_dim: Length of each latent vector.
        n_samples: Number of real and of generated images to evaluate.
    """
    # Accuracy on a random batch of real images.
    real_images, real_labels = generate_real_samples(dataset, n_samples)
    _, acc_real = d_model.evaluate(real_images, real_labels, verbose=0)
    # Accuracy on a freshly generated batch.
    fake_images, fake_labels = generate_fake_images_and_labels(g_model, latent_dim, n_samples)
    _, acc_fake = d_model.evaluate(fake_images, fake_labels, verbose=0)
    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))
    # Persist the generated grid and the generator itself.
    save_plot(fake_images, epoch)
    g_model.save('generator_model_%03d.h5' % (epoch + 1))

def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs, n_batch):
    """Alternate discriminator and generator updates over the dataset.

    Every step trains the discriminator on half a batch of real images plus
    half a batch of generated images, then trains the generator (through the
    combined model) on a full batch of latent points labelled 1. Losses are
    printed per step, and ``summarize_performance`` runs every tenth epoch.

    Args:
        g_model: Generator model.
        d_model: Compiled discriminator whose ``train_on_batch`` returns
            ``(loss, accuracy)``.
        gan_model: Compiled combined model whose ``train_on_batch`` returns
            a scalar loss.
        dataset: Array of real images; ``dataset.shape[0]`` sets the number
            of steps per epoch.
        latent_dim: Length of each latent vector.
        n_epochs: Number of passes to run.
        n_batch: Batch size for the generator update; the discriminator sees
            half of it from each class.
    """
    steps_per_epoch = int(dataset.shape[0] / n_batch)
    half = int(n_batch / 2)
    for epoch in range(n_epochs):
        for step in range(steps_per_epoch):
            # Discriminator update on a mixed real/generated batch.
            real_images, real_labels = generate_real_samples(dataset, half)
            fake_images, fake_labels = generate_fake_images_and_labels(g_model, latent_dim, half)
            mixed_images = np.vstack((real_images, fake_images))
            mixed_labels = np.vstack((real_labels, fake_labels))
            d_loss, _ = d_model.train_on_batch(mixed_images, mixed_labels)
            # Generator update: latent points are labelled 1 so the loss
            # reflects how well the generated images pass as real.
            noise = generate_latent_points(latent_dim, n_batch)
            real_targets = np.ones((n_batch, 1))
            g_loss = gan_model.train_on_batch(noise, real_targets)
            print('>%d, %d/%d, d=%.3f, g=%.3f' % (epoch+1, step+1, steps_per_epoch, d_loss, g_loss))
        # Periodic evaluation and checkpoint.
        if (epoch+1) % 10 == 0:
            summarize_performance(epoch, g_model, d_model, dataset, latent_dim)

In [None]:
# Run configuration.
latent_dim = 100
epochs = 11
batch = 256
# Build the discriminator first, then the generator, then the combined model
# that chains them.
d_model = discriminator(input_shape=(28,28,1))
g_model = generator(latent_dim)
gan_model = GanModel(g_model, d_model)
# Real images the discriminator learns from.
dataset = load_real_samples()
# Run the adversarial training loop.
train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=epochs, n_batch=batch)

>1, 1/234, d=0.689, g=0.743
>1, 2/234, d=0.685, g=0.759
>1, 3/234, d=0.673, g=0.780
>1, 4/234, d=0.670, g=0.794
>1, 5/234, d=0.660, g=0.810
>1, 6/234, d=0.653, g=0.829
>1, 7/234, d=0.651, g=0.837
>1, 8/234, d=0.648, g=0.844
>1, 9/234, d=0.646, g=0.845
>1, 10/234, d=0.650, g=0.834
>1, 11/234, d=0.649, g=0.815
>1, 12/234, d=0.656, g=0.790
>1, 13/234, d=0.655, g=0.766
>1, 14/234, d=0.664, g=0.747
>1, 15/234, d=0.657, g=0.732
>1, 16/234, d=0.659, g=0.723
>1, 17/234, d=0.659, g=0.714
>1, 18/234, d=0.652, g=0.709
>1, 19/234, d=0.648, g=0.705
>1, 20/234, d=0.638, g=0.702
>1, 21/234, d=0.634, g=0.700
>1, 22/234, d=0.621, g=0.700
>1, 23/234, d=0.615, g=0.699
>1, 24/234, d=0.606, g=0.699
>1, 25/234, d=0.597, g=0.699
>1, 26/234, d=0.589, g=0.699
>1, 27/234, d=0.580, g=0.699
>1, 28/234, d=0.576, g=0.700
>1, 29/234, d=0.560, g=0.701
>1, 30/234, d=0.552, g=0.702
>1, 31/234, d=0.539, g=0.703
>1, 32/234, d=0.527, g=0.704
>1, 33/234, d=0.514, g=0.705
>1, 34/234, d=0.500, g=0.706
>1, 35/234, d=0.497, g=