In [28]:
import tensorflow as tf

from keras.layers import Input, Dense, LeakyReLU, Dropout, BatchNormalization
from keras.models import Model
from keras.optimizers import SGD, Adam
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


In [29]:
# Load MNIST through the Keras datasets API; it comes pre-split into
# train and test sets of (images, labels).
mnist = tf.keras.datasets.mnist
(xtrain, ytrain), (xtest, ytest) = mnist.load_data()

# Rescale pixels to [-1, 1] (assuming raw values in [0, 255]) so real images
# share the range of the generator's tanh output layer.
xtrain = xtrain / 255.0 * 2 - 1
xtest = xtest / 255.0 * 2 - 1

In [30]:
# Flatten every H x W image into a single length-D row vector, since both
# networks are built from Dense layers over flat inputs.
N, H, W = xtrain.shape
D = H * W
xtrain, xtest = xtrain.reshape(-1, D), xtest.reshape(-1, D)

In [31]:
latent_dim: int = 100  # length of the noise vector fed to the generator (hyperparameter)

In [32]:
# Build Generator
# latent dim is a hyperparameter

def build_gen(latent_dim, img_size=None):
    """Build the MLP generator that maps latent noise to a flattened image.

    Architecture: three Dense layers of width 256, 512 and 1024 with
    LeakyReLU(alpha=0.2) activations, each followed by
    BatchNormalization(momentum=0.8), then a Dense output layer with tanh.

    Args:
        latent_dim: length of the input noise vector.
        img_size: length of the flattened output image. When omitted, falls
            back to the notebook global ``D`` (the original behaviour). Pass
            it explicitly, as ``build_disc`` takes ``img_size``, to avoid the
            hidden dependency on notebook state.

    Returns:
        An uncompiled Keras ``Model`` from ``(latent_dim,)`` inputs to
        ``(img_size,)`` outputs in tanh's range [-1, 1].
    """
    if img_size is None:
        img_size = D  # backward-compatible fallback to the flatten cell's global

    i = Input(shape=(latent_dim,))
    x = i
    # Same layers, created in the same order, as the original unrolled code.
    for units in (256, 512, 1024):
        x = Dense(units, activation=LeakyReLU(alpha=0.2))(x)
        x = BatchNormalization(momentum=0.8)(x)
    x = Dense(img_size, activation='tanh')(x)
    return Model(i, x)

In [33]:
# Build the discriminator
def build_disc(img_size):
    """Build the MLP discriminator that scores a flattened image as real vs. fake.

    Two hidden Dense layers (512 then 256 units, LeakyReLU with alpha=0.2)
    feed a single sigmoid unit.

    Args:
        img_size: length of the flattened input image vector.

    Returns:
        An uncompiled Keras ``Model`` mapping ``(img_size,)`` inputs to one
        sigmoid score per sample.
    """
    inputs = Input(shape=(img_size,))
    hidden = inputs
    for width in (512, 256):
        hidden = Dense(width, activation=LeakyReLU(alpha=0.2))(hidden)
    score = Dense(1, activation="sigmoid")(hidden)
    return Model(inputs, score)

In [34]:
# compile the models

# making the discriminator
# It is compiled here, while still trainable, and trained directly via
# disc.train_on_batch in the training loop.
disc = build_disc(D)
disc.compile(
    loss='binary_crossentropy',
    # NOTE(review): positional args presumably mean learning_rate=0.0002,
    # beta_1=0.5 — confirm against the installed Keras Adam signature.
    optimizer=Adam(0.0002, 0.5),
    metrics=["accuracy"]
)
# making the generator
# The generator is never compiled on its own; it is only trained through
# combinedModel below.
gen = build_gen(latent_dim)

# input to represent the noise sample from the latent space
z = Input(shape=(latent_dim,))

# generate an image from the noise
img = gen(z)

# freeze the discriminator's learning
# Order matters: this line runs AFTER disc.compile(...) above and BEFORE
# combinedModel.compile(...) below. The intent is that training combinedModel
# updates only the generator's weights, while disc.train_on_batch still
# updates the discriminator.
# NOTE(review): relies on Keras capturing `trainable` at compile time —
# confirm this still holds for the installed Keras version.
disc.trainable = False

# discriminator's real/fake score for the generated image
fakepred = disc(img)

# Combined model: noise -> gen -> (frozen) disc. The training loop fits it
# against all-ones ("real") labels, pushing the generator toward images the
# discriminator scores as real.
combinedModel = Model(z, fakepred)

combinedModel.compile(
    loss='binary_crossentropy',
    optimizer=Adam(0.0002, 0.5),
)

In [35]:
# training

# Hyperparameters. NB: each "epoch" in the training loop is a single
# mini-batch step, not a full pass over the training set.
batch_size, epochs, sample_period = 32, 30000, 200

# Fixed target labels: 1 marks a real image, 0 a generated one.
ones, zeros = np.ones(batch_size), np.zeros(batch_size)

# Per-step loss histories, appended to by the training loop.
d_losses, g_losses = [], []



In [36]:
def sample_images(epoch, rows=5, columns=5):
    """Save a ``rows`` x ``columns`` grid of generator samples to a file named ``epoch``.

    Draws fresh latent noise, maps it through the notebook-global generator
    ``gen``, and rescales the tanh outputs from [-1, 1] to [0, 1]. Each
    sample is drawn in its own grid cell as an H x W grayscale image.

    Args:
        epoch: training step, used as the output file name. With no
            extension, matplotlib picks the format from
            ``rcParams["savefig.format"]``.
        rows: number of grid rows (default 5, as before).
        columns: number of grid columns (default 5, as before).

    Relies on globals from earlier cells: ``gen``, ``latent_dim``, ``H``, ``W``.
    """
    n_images = rows * columns
    # Same noise distribution as the training loop (uniform [0, 1)), so the
    # samples come from the kind of input the generator was trained on.
    noise = np.random.rand(n_images, latent_dim)
    # verbose=0: no per-call progress bar in the notebook output.
    imgs = gen.predict(noise, verbose=0)

    # Generator output is tanh in [-1, 1]; map it to [0, 1] for display.
    imgs = 0.5 * imgs + 0.5

    # squeeze=False keeps `axs` a 2-D array even when rows or columns is 1.
    fig, axs = plt.subplots(rows, columns, squeeze=False)
    # Bug fix: the original never advanced its image index, so every panel
    # showed imgs[0]. Pair each axis with its own sample instead.
    for ax, sample in zip(axs.flat, imgs):
        ax.imshow(sample.reshape(H, W), cmap='gray')
        ax.axis('off')
    fig.savefig(f"{epoch}")
    # Close this specific figure (not merely "the current" one) so repeated
    # calls during training do not accumulate open figures.
    plt.close(fig)

In [37]:
# training

# Each iteration performs one discriminator update on a real batch and a fake
# batch, then one generator update through combinedModel. Despite its name,
# `epochs` counts these mini-batch steps, not passes over xtrain.
log_period = 100  # print a progress line every `log_period` steps

for epoch in range(epochs):

    # --- Discriminator step ---
    # Random real batch (indices drawn independently, so repeats are possible).
    idx = np.random.randint(0, xtrain.shape[0], batch_size)
    real_imgs = xtrain[idx]

    # Fake batch from the current generator. verbose=0 keeps Keras versions
    # whose predict shows a progress bar by default from printing one on
    # every iteration of this loop.
    noise = np.random.rand(batch_size, latent_dim)
    fake_imgs = gen.predict(noise, verbose=0)

    # Real images are labelled 1, generated images 0.
    d_loss_real, d_acc_real = disc.train_on_batch(real_imgs, ones)
    d_loss_fake, d_acc_fake = disc.train_on_batch(fake_imgs, zeros)

    # Report the mean of the real-batch and fake-batch metrics.
    d_loss = 0.5 * (d_loss_real + d_loss_fake)
    d_acc = 0.5 * (d_acc_real + d_acc_fake)

    # --- Generator step ---
    # Fresh noise labelled "real" (ones): training combinedModel on it pushes
    # the generator toward images the discriminator scores as real.
    noise = np.random.rand(batch_size, latent_dim)
    g_loss = combinedModel.train_on_batch(noise, ones)

    d_losses.append(d_loss)
    g_losses.append(g_loss)

    if epoch % log_period == 0:
        print(f"epoch:{epoch+1}/{epochs}, d_loss:{d_loss:.2f}, d_acc:{d_acc:.2f}, g_loss:{g_loss:.2f} ")

    if epoch % sample_period == 0:
        sample_images(epoch)

epoch:1/30000, d_loss:1.06, d_acc:0.45, g_loss:0.98 
epoch:101/30000, d_loss:0.01, d_acc:1.00, g_loss:4.74 
epoch:201/30000, d_loss:0.03, d_acc:1.00, g_loss:5.32 
epoch:301/30000, d_loss:1.02, d_acc:0.53, g_loss:1.71 
epoch:401/30000, d_loss:0.66, d_acc:0.56, g_loss:1.06 
epoch:501/30000, d_loss:0.65, d_acc:0.52, g_loss:0.70 
epoch:601/30000, d_loss:0.65, d_acc:0.53, g_loss:0.77 
epoch:701/30000, d_loss:0.61, d_acc:0.70, g_loss:0.73 
epoch:801/30000, d_loss:0.61, d_acc:0.75, g_loss:0.79 
epoch:901/30000, d_loss:0.61, d_acc:0.70, g_loss:0.80 
epoch:1001/30000, d_loss:0.63, d_acc:0.69, g_loss:0.83 
epoch:1101/30000, d_loss:0.60, d_acc:0.67, g_loss:0.89 
epoch:1201/30000, d_loss:0.58, d_acc:0.73, g_loss:0.93 
epoch:1301/30000, d_loss:0.58, d_acc:0.77, g_loss:0.98 
epoch:1401/30000, d_loss:0.53, d_acc:0.81, g_loss:0.97 
epoch:1501/30000, d_loss:0.55, d_acc:0.72, g_loss:0.96 
epoch:1601/30000, d_loss:0.53, d_acc:0.78, g_loss:1.04 
epoch:1701/30000, d_loss:0.59, d_acc:0.64, g_loss:0.99 
epoc

KeyboardInterrupt: 