In [None]:
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Reshape, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.data import Dataset
from tensorflow.random import normal
from tensorflow import concat, dtypes, float32

In [None]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Keep only images of the digit 0: this GAN learns to generate zeros only.
# Scale pixels from the raw [0, 255] range to float32 [0, 1] (standard input
# normalisation). On the raw scale the discriminator can tell real from generated
# images by brightness alone, and the generator's output layer would have to learn
# a ~255x scale before producing anything convincing.
zeros = (x_train[y_train == 0] / 255.0).astype(np.float32)

In [None]:
# Discriminator: flattens a 28x28 image and scores it with a single sigmoid unit
# (output near 1 = "real", near 0 = "generated").
discriminator = Sequential(
    [
        Flatten(input_shape=[28, 28]),
        Dense(150, activation="relu"),
        Dense(100, activation="relu"),
        Dense(1, activation="sigmoid"),
    ]
)

In [None]:
# Binary real-vs-fake classification, so binary cross-entropy on the sigmoid output.
discriminator.compile(optimizer="adam", loss="binary_crossentropy")

In [None]:
coding_size = 100  # length of the random noise vector fed to the generator

# Generator: works like the decoder half of an autoencoder, widening a
# coding_size noise vector to 784 values and reshaping them into a 28x28 image.
generator = Sequential(
    [
        Dense(100, activation="relu", input_shape=[coding_size]),
        Dense(150, activation="relu"),
        Dense(784, activation="relu"),
        Reshape([28, 28]),
    ]
)

# No compile() here: the generator is never trained on its own, only through
# the combined `gan` model (generator followed by discriminator).

In [None]:
# Chain generator -> discriminator: noise goes in, the discriminator's verdict
# comes out, and the loss gradient flows back into the generator.
gan = Sequential()
gan.add(generator)
gan.add(discriminator)

# Inside the combined model only the generator should learn, so the
# discriminator is frozen before compiling `gan`.
discriminator.trainable = False
gan.compile(optimizer="adam", loss="binary_crossentropy")

In [None]:
batch_size = 32

# `zeros` is already fully in memory; shuffle() draws each element at random
# from a buffer of 1000 upcoming elements (a larger buffer = a more thorough shuffle).
dataset_obj = Dataset.from_tensor_slices(zeros)
dataset_obj = dataset_obj.shuffle(buffer_size=1000)

# Fixed-size batches (the final partial batch is dropped); prefetch(1) prepares
# the next batch while the current one is being processed.
dataset = (
    dataset_obj
    .batch(batch_size, drop_remainder=True)
    .prefetch(1)
)

In [None]:
generator, discriminator = gan.layers

epochs = 1
n_batch = len(zeros) // batch_size  # equals len(dataset) because drop_remainder=True
log_every = 20  # print losses every this many batches (and on the last batch)


def train_discriminator(batch):
    """Phase 1: one discriminator update on generated (label 0) and real (label 1) images.

    Args:
        batch: a batch of real images yielded by `dataset`.

    Returns:
        The loss value reported by `discriminator.train_on_batch`.
    """
    # NOTE(review): Keras records `trainable` when a model is compiled, so this toggle
    # may not change what train_on_batch updates without re-compiling; confirm for
    # the installed Keras version.
    discriminator.trainable = True

    # Size fakes and labels from the actual batch so a partial batch can't
    # desynchronise images and labels.
    n = batch.shape[0]
    noise = normal(shape=[n, coding_size])
    gen_data = generator(noise)
    real_data = dtypes.cast(batch, float32)
    input_data = concat([gen_data, real_data], axis=0)
    # First n rows are generated (label 0), next n rows are real (label 1).
    y = np.concatenate([np.zeros(shape=(n, 1)), np.ones(shape=(n, 1))])

    return discriminator.train_on_batch(input_data, y)


def train_generator(n=None):
    """Phase 2: one generator update through the combined `gan` model.

    Noise is labelled 1 ("real"), so minimising the loss pushes the generator
    toward images the (frozen) discriminator scores as real.

    Args:
        n: number of noise samples; defaults to the global `batch_size`.

    Returns:
        The loss value reported by `gan.train_on_batch`.
    """
    # NOTE(review): see the compile-time `trainable` caveat in train_discriminator.
    discriminator.trainable = False

    n = batch_size if n is None else n
    noise = normal(shape=[n, coding_size])
    y = np.ones(shape=(n, 1))

    return gan.train_on_batch(noise, y)


for epoch in range(epochs):
    print(f"epoch = {epoch + 1}")
    for i, batch in enumerate(dataset, start=1):
        d_loss = train_discriminator(batch)  # phase 1
        g_loss = train_generator(batch.shape[0])  # phase 2
        # Periodic loss report instead of one line per batch.
        if i % log_every == 0 or i == n_batch:
            print(
                f"batch = {i}/{n_batch}  "
                f"d_loss = {float(np.mean(d_loss)):.4f}  "
                f"g_loss = {float(np.mean(g_loss)):.4f}"
            )
        

In [None]:
# Sample one latent vector; it is reused below as the generator's input.
noise = normal(shape=[1, coding_size])

# A 1 x coding_size array drawn with equal aspect is a near-invisible sliver,
# so stretch it to fill the axes.
fig, ax = plt.subplots(figsize=(8, 1.5))
ax.imshow(noise, aspect="auto")
ax.set_title(f"Latent noise vector (1 x {coding_size})")
ax.set_xlabel("latent dimension")
ax.set_yticks([]);

In [None]:
# Decode the latent vector above into an image with the trained generator.
images = generator(noise)

# Grayscale matches how MNIST digits are normally viewed (default colormap is viridis).
fig, ax = plt.subplots()
ax.imshow(images[0], cmap="gray")
ax.set_title("Generated digit")
ax.axis("off");