In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
from keras.datasets.mnist import load_data
from tensorflow.keras.optimizers import Adam
from keras.models import Sequential
import matplotlib.pyplot as plt

In [5]:
def discriminator(in_shape = (784,)):
    """Build and compile the discriminator network.

    The network is a fully connected stack of 256 -> 64 -> 8 -> 1 units.
    Hidden layers use the 'leaky_relu' activation, dropout with rate 0.3
    follows each of the first two hidden layers, and the single output unit
    uses a sigmoid.

    Args:
        in_shape: Shape of one input sample, excluding the batch axis.

    Returns:
        A Sequential model compiled with binary cross-entropy loss, a
        default-configured Adam optimizer and an accuracy metric.
    """
    layers = keras.layers
    net = Sequential([
        keras.Input(shape=in_shape),
        layers.Dense(256, activation='leaky_relu'),
        layers.Dropout(0.3),
        layers.Dense(64, activation='leaky_relu'),
        layers.Dropout(0.3),
        layers.Dense(8, activation='leaky_relu'),
        layers.Dense(1, activation='sigmoid'),
    ])
    net.compile(
        optimizer=Adam(),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return net

In [6]:
# summary() prints the layer table itself and returns None, so it is called
# directly; wrapping it in print() would only add a stray "None" line.
discriminator().summary()

None


In [27]:
def generator(rv_dim = 50):
    """Build and compile the generator network.

    Dense layers of 64 -> 256 -> 512 -> 784 units map a latent vector to a
    flat 784-value output. Hidden layers use 'leaky_relu', dropout with rate
    0.2 follows the 256-unit layer, and the output layer uses tanh, so every
    output value lies in [-1, 1].

    Args:
        rv_dim: Length of the latent (random) input vector.

    Returns:
        A Sequential model compiled with 'mse' loss, a default-configured
        Adam optimizer and an accuracy metric.
    """
    layers = keras.layers
    net = Sequential([
        keras.Input(shape=(rv_dim,)),
        layers.Dense(64, activation='leaky_relu'),
        layers.Dense(256, activation='leaky_relu'),
        layers.Dropout(0.2),
        layers.Dense(512, activation='leaky_relu'),
        layers.Dense(784, activation='tanh'),
    ])
    net.compile(
        optimizer=Adam(),
        loss='mse',
        metrics=['accuracy'],
    )
    return net
    

In [28]:
# summary() prints the layer table itself and returns None, so it is called
# directly; wrapping it in print() would only add a stray "None" line.
generator().summary()

None


In [11]:
def gan_model(g_model,d_model):
    """Stack the generator and the discriminator into one compiled model.

    The combined model feeds the generator's output straight into the
    discriminator. This function does not touch `d_model.trainable`; whether
    the discriminator's weights are frozen is left to the caller.

    Args:
        g_model: Generator model (latent vector -> image).
        d_model: Discriminator model (image -> real/fake probability).

    Returns:
        A Sequential model compiled with binary cross-entropy loss, a
        default-configured Adam optimizer and an accuracy metric.
    """
    combined = Sequential([g_model, d_model])
    combined.compile(
        optimizer=Adam(),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return combined

In [12]:
# summary() prints the layer table itself and returns None, so it is called
# directly; wrapping it in print() would only add a stray "None" line.
gan_model(generator(),discriminator()).summary()

None


In [13]:
def load_real_data():
    """Load the MNIST training images as flat, rescaled float vectors.

    Only the training images are kept; labels and the test split are
    discarded. Each image is flattened to 784 values, cast to float64 and
    mapped through (p - 127.5) / 127.5, so an input range of 0..255 ends up
    in [-1, 1].

    Returns:
        A float64 array with 784 columns, one row per training image.
    """
    (train_images, _), _ = load_data()
    flat = train_images.reshape(-1, 784).astype('f8')
    centred = flat - 127.5
    return centred / 127.5

In [20]:
def generate_real_samples(data,n_samples = 100):
    """Draw a random batch of real samples, each labelled 1.

    Row indices are drawn uniformly with replacement from NumPy's global
    random state, so the same row may appear more than once in a batch.

    Args:
        data: Array whose first axis indexes samples.
        n_samples: Number of rows to draw.

    Returns:
        A tuple (batch, labels): the selected rows of `data`, and an array
        of ones with shape (number of selected rows, 1).
    """
    picks = np.random.randint(data.shape[0], size=n_samples)
    batch = data[picks]
    labels = np.ones((batch.shape[0], 1))
    return batch, labels

In [21]:
# data = load_real_data()
# generate_real_samples(data,1)

In [22]:
def generate_rv(rv_dim,n_sample=100):
    """Sample latent vectors from a standard normal distribution.

    Values come from NumPy's global random state.

    Args:
        rv_dim: Length of each latent vector.
        n_sample: Number of latent vectors to draw.

    Returns:
        An array of shape (n_sample, rv_dim).
    """
    latent_shape = (n_sample, rv_dim)
    return np.random.standard_normal(size=latent_shape)

In [23]:
def generate_fake_images(g_model,rv_dim,n_samples = 100):
    """Produce a batch of generator outputs, each labelled 0.

    Args:
        g_model: Model exposing `predict`; it is fed freshly sampled latent
            vectors from `generate_rv`.
        rv_dim: Length of each latent vector.
        n_samples: Number of fake samples to generate.

    Returns:
        A tuple (images, labels): whatever `g_model.predict` returns for the
        latent batch, and an array of zeros with shape (n_samples, 1).
    """
    latent = generate_rv(rv_dim, n_samples)
    images = g_model.predict(latent)
    labels = np.zeros((n_samples, 1))
    return images, labels

In [29]:
def save_fig(g_model,rv_dim,epoch):
    """Save a 10x10 grid of generated images to ./GAN_output.

    The output directory is created if it does not exist yet, so the first
    call no longer fails when the folder is missing. The grid is drawn on a
    dedicated figure that is always closed afterwards, even if saving fails.

    Args:
        g_model: Model exposing `predict`; each output row is reshaped to
            28x28 for display.
        rv_dim: Length of each latent vector passed to `generate_rv`.
        epoch: Value embedded in the output file name.
    """
    # Local import: `os` is not among the imports visible at the top of the notebook.
    import os

    n = 10
    out_dir = './GAN_output'
    # savefig() does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)

    rv = generate_rv(rv_dim,n*n)
    f_imgs = g_model.predict(rv)

    # Use an explicit new figure so that any figure left open by another
    # cell is neither drawn on nor closed by accident.
    fig = plt.figure()
    try:
        for i in range(n * n):
            ax = fig.add_subplot(n, n, 1+i)
            ax.axis('off')
            ax.imshow(f_imgs[i].reshape((28,28)), interpolation='nearest',cmap = 'gray')
        filename = f'{out_dir}/generated_plot_e{epoch}.png'
        fig.savefig(filename)
    finally:
        # One figure is created per call; close it to avoid accumulating
        # open figures over a long training run.
        plt.close(fig)

In [33]:
def train(data,g_model,d_model,gan_model,rv_dim,epochs = 51,batch_size = 256):
    """Alternate discriminator and generator updates for `epochs` epochs.

    Each epoch runs `data.shape[0] // batch_size` iterations. In every
    iteration the discriminator is trained on half a batch of real samples
    and half a batch of generated samples, then the generator is trained
    through the combined model on a full batch of latent vectors labelled 1.
    After each epoch the most recently returned losses are printed, and every
    10th epoch (starting with the first) a grid of generated images is saved.

    Args:
        data: Array of real samples; the first axis indexes samples.
        g_model: Generator model.
        d_model: Compiled discriminator model.
        gan_model: Compiled combined model (generator followed by
            discriminator).
        rv_dim: Length of each latent vector.
        epochs: Number of epochs to run.
        batch_size: Number of latent vectors per generator update; the
            discriminator sees `batch_size // 2` real and as many fake samples.

    Raises:
        ValueError: If `data` has fewer rows than `batch_size`, which would
            leave zero iterations per epoch.
    """
    nbatchs = data.shape[0]//batch_size
    if nbatchs < 1:
        # Without this guard the inner loop never runs and the epoch summary
        # below fails on the unassigned d_loss / g_loss.
        raise ValueError(
            f'data has {data.shape[0]} samples, fewer than batch_size={batch_size}; '
            'no training step would run'
        )
    half_batch = batch_size//2

    for e in range(epochs):
        for bn in range(nbatchs):
            x_real,y_real = generate_real_samples(data,half_batch)
            x_fake,y_fake = generate_fake_images(g_model,rv_dim,half_batch)

            # NOTE(review): whether flipping `trainable` after compile() takes
            # effect depends on the Keras version in use -- confirm.
            d_model.trainable = True

            r_loss,_ = d_model.train_on_batch(x_real,y_real)
            f_loss,_ = d_model.train_on_batch(x_fake,y_fake)

            d_loss = 0.5*(r_loss+f_loss)

            d_model.trainable = False

            # Generator step: latent vectors are labelled 1 ("real") so the
            # loss rewards outputs the discriminator accepts.
            x_rv = generate_rv(rv_dim,batch_size)
            y = np.ones(shape=(x_rv.shape[0],1))

            g_loss,_ = gan_model.train_on_batch(x_rv,y)

        # Reports the values returned by the last train_on_batch calls of the epoch.
        # NOTE(review): some Keras versions return running averages from
        # train_on_batch unless metrics are reset -- confirm what is shown here.
        print(f'Epoch: {e+1}, d_loss: {d_loss}, g_loss: {g_loss}')
        if e%10 == 0:
            save_fig(g_model,rv_dim,e)

In [34]:
# Keras' interactive logging is switched off for the duration of training.
# The try/finally makes sure it is switched back on even when training raises
# or the cell is interrupted, so later cells are not left silenced.
keras.utils.disable_interactive_logging()
try:
    rv_dim = 50
    g_model = generator(rv_dim)
    d_model = discriminator()
    gan = gan_model(g_model,d_model)
    data = load_real_data()
    train(data,g_model,d_model,gan,rv_dim)
finally:
    keras.utils.enable_interactive_logging()

Epoch: 1, d_loss: 0.16273373365402222, g_loss: 3.8132822513580322
Epoch: 2, d_loss: 0.19393290579319, g_loss: 4.101649284362793
Epoch: 3, d_loss: 0.23413152992725372, g_loss: 3.913792371749878
Epoch: 4, d_loss: 0.23362623155117035, g_loss: 3.821084499359131
Epoch: 5, d_loss: 0.24170395731925964, g_loss: 3.806110143661499
Epoch: 6, d_loss: 0.22938503324985504, g_loss: 3.8057026863098145
Epoch: 7, d_loss: 0.23645555973052979, g_loss: 3.7547895908355713
Epoch: 8, d_loss: 0.2443949282169342, g_loss: 3.624877452850342
Epoch: 9, d_loss: 0.25534456968307495, g_loss: 3.5069990158081055
Epoch: 10, d_loss: 0.2648019790649414, g_loss: 3.4007954597473145
Epoch: 11, d_loss: 0.27410632371902466, g_loss: 3.2989392280578613
Epoch: 12, d_loss: 0.28568947315216064, g_loss: 3.194089651107788
Epoch: 13, d_loss: 0.29776066541671753, g_loss: 3.077631950378418
Epoch: 14, d_loss: 0.3095737397670746, g_loss: 2.977318286895752
Epoch: 15, d_loss: 0.32097727060317993, g_loss: 2.886955976486206
Epoch: 16, d_loss: 