In [15]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU,BatchNormalization, Reshape, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
import os

In [16]:
def build_generator():
    model = Sequential([
        Dense(256, input_shape=(100,)),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(512),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(1024),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(28 * 28 * 1, activation='tanh'),  # This is the final Dense layer which will have image size of 28 by 28 and 1 channel representing grayscale
        Reshape((28, 28, 1))
    ])

    return model

In [17]:
def build_discriminator():
    model = Sequential([
        Flatten(input_shape=(28,28,1)),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid')
    ])

    return model

In [18]:
generator = build_generator()
discriminator = build_discriminator()

discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

discriminator.trainable = False

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(**kwargs)


In [19]:
gan_model = Sequential([generator, discriminator])
gan_model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

In [20]:
# this is data with labels, since we don't require labels nor testing data we will ignore them with a _.
(X_train, _), (_, _) = mnist.load_data()

# Normalize data to be between -1 and 1 (matches our output range [-1, 1] of the 'tanh' activation)
X_train = X_train / 127.5 - 1.0

In [21]:
X_train.shape

(60000, 28, 28)

In [22]:
# Reshape the data to add a channel dimension (since we are using Conv2D layers)
X_train = np.expand_dims(X_train, axis=-1)

In [24]:
# Channel dimension is needed for our discriminator
X_train.shape

(60000, 28, 28, 1)

In [27]:
def save_imgs(epoch, generator):
    noise = np.random.normal(0, 1, (25, 100))
    gen_imgs = generator.predict(noise)
    gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale images to [0, 1]
    
    fig, axs = plt.subplots(5, 5, figsize=(5, 5))
    count = 0
    for i in range(5):
        for j in range(5):
            axs[i, j].imshow(gen_imgs[count, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            count += 1
    fig.savefig(f"gan_images/mnist_{epoch}.png")
    plt.close()

In [29]:
def train_gan(epochs, batch_size=128, save_interval=50, start_epoch=0):
    valid = np.ones((batch_size, 1))    # Real Label
    fake = np.zeros((batch_size, 1))    # Fake Label

    if not os.path.exists('gan_checkpoints'):
        os.makedirs('gan_checkpoints')

    for epoch in range(start_epoch, epochs):
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_imgs = X_train[idx]

        # Generate fake images using our generator, we make 128 images each with 100 dimensions for noise
        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise)
        
        # Train the discriminator on batches of real and fake images
        d_loss_real = discriminator.train_on_batch(real_imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = gan_model.train_on_batch(noise, valid)

        if epoch % save_interval == 0:
            print(f"{epoch} [D_loss: {d_loss[0]}, acc.: {100 * d_loss[1]}%] [G loss: {g_loss}]]")
            save_imgs(epoch, generator)
            # Save weights
            generator.save_weights('gan_checkpoints/generator_weights_epoch_{epoch}.h5')
            discriminator.save_weights('gan_checkpoints/discriminator_weights_epoch_{epoch}.h5')


In [14]:
train_gan(epochs=10000, batch_size=64, save_interval=1000)

[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step  




0 [D_loss: 0.4981812834739685, acc.: 80.859375%] [G loss: [array(0.57944167, dtype=float32), array(0.57944167, dtype=float32), array(0.6171875, dtype=float32)]]]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 109ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step 
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37

KeyboardInterrupt: 