In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, LeakyReLU, Dropout, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

In [None]:
# --- 1. shared configuration: image geometry and latent-vector size ---
IMG_ROWS, IMG_COLS, CHANNELS = 28, 28, 1
# (rows, cols, channels) tuple used as the discriminator's input shape
IMG_SHAPE = (IMG_ROWS, IMG_COLS, CHANNELS)
# length of the random noise vector fed to the generator
NOISE_DIM = 100

In [None]:
# --- 2. Generator ---
# Turn a 100-dim noise vector into a 28x28x1 image.
def build_generator():
    """Assemble the generator network.

    The model takes a noise vector of length ``NOISE_DIM``, projects it to a
    7x7x256 feature map, doubles the spatial size twice with stride-2
    transposed convolutions (7x7 -> 14x14 -> 28x28) and ends with a
    tanh-activated convolution that emits ``CHANNELS`` channels.

    Returns:
        An uncompiled Keras ``Sequential`` model named "Generator".
    """
    net = Sequential(name="Generator")

    # Dense projection of the noise, then fold it into a small feature map.
    net.add(Dense(7 * 7 * 256, input_dim=NOISE_DIM))
    net.add(BatchNormalization())
    net.add(LeakyReLU(alpha=0.2))
    net.add(Reshape((7, 7, 256)))

    # Each stage doubles height and width: 7x7 -> 14x14 -> 28x28.
    for n_filters in (128, 64):
        net.add(Conv2DTranspose(n_filters, kernel_size=5, strides=2, padding='same'))
        net.add(BatchNormalization())
        net.add(LeakyReLU(alpha=0.2))

    # Output layer: tanh keeps pixel values in [-1, 1].
    net.add(Conv2D(CHANNELS, kernel_size=5, padding='same', activation='tanh'))

    return net

In [None]:

# --- 3. Discriminator ---
# used to classify real vs fake images
def build_discriminator():
    """Assemble the discriminator network.

    Two stride-2 convolution stages (64 then 128 filters), each followed by
    LeakyReLU and dropout, feed a flattened feature vector into a single
    sigmoid unit. Input images have shape ``IMG_SHAPE``.

    Returns:
        An uncompiled Keras ``Sequential`` model named "Discriminator".
    """
    net = Sequential(name="Discriminator")

    for stage, n_filters in enumerate((64, 128)):
        # Only the very first layer declares the input shape.
        first_layer_kwargs = {'input_shape': IMG_SHAPE} if stage == 0 else {}
        net.add(Conv2D(n_filters, kernel_size=5, strides=2, padding='same', **first_layer_kwargs))
        net.add(LeakyReLU(alpha=0.2))
        net.add(Dropout(0.3))

    # Collapse the feature map and emit one probability per image.
    net.add(Flatten())
    net.add(Dense(1, activation='sigmoid'))

    return net

In [None]:
# --- 4. gan model ---
def build_gan(generator, discriminator):
    """Stack the generator and the discriminator into one model.

    Side effect: sets ``discriminator.trainable`` to ``False`` on the object
    passed in, before the stacked model is created.

    Args:
        generator: model mapping noise vectors to images.
        discriminator: model mapping images to a real/fake probability.

    Returns:
        An uncompiled Keras ``Sequential`` model named "GAN" whose layers are
        the generator followed by the discriminator.
    """
    # Freeze the discriminator for the combined model.
    discriminator.trainable = False

    return Sequential([generator, discriminator], name="GAN")

In [None]:
# --- 5. build the models and prepare the training data ---
import os
# Folder for the sample grids. The training cell saves its PNGs under
# "06-gan_images/", so that is the directory that has to exist; creating a
# differently named folder here would leave savefig without a target.
if not os.path.exists('06-gan_images'):
    print("Creating directory '06-gan_images'...")
    os.makedirs('06-gan_images')

# build and compile the discriminator (this happens before it is frozen below)
discriminator = build_discriminator()
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    metrics=['accuracy'],
)

generator = build_generator()

# when training the GAN model, we want to freeze the discriminator
discriminator.trainable = False
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

# prepare the training data (labels and the test split are not used)
(X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
# normalize to [-1, 1] because the generator uses 'tanh' activation;
# cast to float32 first so the full training set is not held as float64
X_train = X_train.astype('float32') / 127.5 - 1.0
# add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1)
X_train = np.expand_dims(X_train, axis=3)

# define training parameters
# NOTE: the training cell performs one batch update per "epoch", so `epochs`
# counts training iterations rather than full passes over the data.
epochs = 12000
batch_size = 64
sample_interval = 1000 # generate and save images every 1000 iterations

# real and fake labels (one target value per image in a batch)
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))

In [None]:

# start training

# Directory the sample grids are written to. Create it up front so the first
# savefig (after `sample_interval` iterations) cannot fail on a missing folder.
SAMPLE_DIR = "06-gan_images"
os.makedirs(SAMPLE_DIR, exist_ok=True)


def save_sample_grid(step, rows=4, cols=4):
    """Generate rows*cols images and save them as one PNG grid.

    Args:
        step: iteration number used in the file name (mnist_<step>.png).
        rows: number of grid rows.
        cols: number of grid columns.
    """
    noise = np.random.normal(0, 1, (rows * cols, NOISE_DIM))
    # verbose=0 keeps predict from printing a progress bar on every call
    gen_imgs = generator.predict(noise, verbose=0)
    # rescale images from the generator's tanh range [-1, 1] to [0, 1]
    gen_imgs = 0.5 * gen_imgs + 0.5

    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    for ax, img in zip(axs.flat, gen_imgs):
        ax.imshow(img[:, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig(os.path.join(SAMPLE_DIR, f"mnist_{step}.png"))
    # close this specific figure so figures do not accumulate over the run
    plt.close(fig)


for epoch in range(epochs):

    # ---------------------
    #  train the discriminator
    # ---------------------

    # 1. get a batch of real images (random indices, drawn with replacement)
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    real_imgs = X_train[idx]

    # 2. generate a batch of fake images
    noise = np.random.normal(0, 1, (batch_size, NOISE_DIM))
    fake_imgs = generator.predict(noise, verbose=0)

    # 3. train the discriminator
    # let the discriminator learn to classify real and fake images separately;
    # each call returns [loss, accuracy] because of metrics=['accuracy']
    d_loss_real = discriminator.train_on_batch(real_imgs, real_labels)
    d_loss_fake = discriminator.train_on_batch(fake_imgs, fake_labels)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # ---------------------
    #  train the generator
    # ---------------------

    # 4. generate new noise
    noise = np.random.normal(0, 1, (batch_size, NOISE_DIM))

    # 5. train the generator
    # use "real" labels to fool the discriminator
    g_loss = gan.train_on_batch(noise, real_labels)

    # print progress
    if (epoch + 1) % 100 == 0:
        print(f"{epoch + 1} [D loss: {d_loss[0]:.4f}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")

    # every sample_interval iterations, generate and save images
    if (epoch + 1) % sample_interval == 0:
        save_sample_grid(epoch + 1)



100 [D loss: 0.6701, acc.: 64.06%] [G loss: 0.7727]
200 [D loss: 0.6757, acc.: 63.28%] [G loss: 0.7671]
300 [D loss: 0.6885, acc.: 57.81%] [G loss: 0.7588]
400 [D loss: 0.6398, acc.: 65.62%] [G loss: 0.8241]
500 [D loss: 0.6231, acc.: 69.53%] [G loss: 0.8254]
600 [D loss: 0.6574, acc.: 62.50%] [G loss: 0.8102]
700 [D loss: 0.6594, acc.: 61.72%] [G loss: 0.7866]
800 [D loss: 0.6777, acc.: 60.94%] [G loss: 0.7642]
900 [D loss: 0.6944, acc.: 50.00%] [G loss: 0.7495]
1000 [D loss: 0.6851, acc.: 58.59%] [G loss: 0.7779]
1100 [D loss: 0.6677, acc.: 59.38%] [G loss: 0.7474]
1200 [D loss: 0.6859, acc.: 55.47%] [G loss: 0.6959]
1300 [D loss: 0.7002, acc.: 47.66%] [G loss: 0.7372]
1400 [D loss: 0.6770, acc.: 59.38%] [G loss: 0.7411]
1500 [D loss: 0.6822, acc.: 61.72%] [G loss: 0.7789]
1600 [D loss: 0.6889, acc.: 46.88%] [G loss: 0.7189]
1700 [D loss: 0.6885, acc.: 52.34%] [G loss: 0.6840]
1800 [D loss: 0.6839, acc.: 53.91%] [G loss: 0.8264]
1900 [D loss: 0.7036, acc.: 52.34%] [G loss: 0.6928]
20

KeyboardInterrupt: 