In [1]:
# Standard library
import os

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import BatchNormalization, Dense, Flatten, LeakyReLU, Reshape
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

# Physical GPU devices visible to TensorFlow; later cells display this list
# and use it to configure memory growth.
gpus = tf.config.list_physical_devices('GPU')

In [2]:
# Display the GPU devices TensorFlow detected (an empty list means none are visible).
gpus

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
# Turn on memory growth for every detected GPU, then report how many physical
# and logical GPU devices TensorFlow sees. Nothing happens when `gpus` is empty.
if gpus:
  try:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    # Logical devices are listed only so both counts can be printed together.
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # NOTE(review): per the TensorFlow GPU guide this is raised when memory
    # growth is changed after the GPUs were already initialized. The error is
    # only printed, so the notebook keeps running with the existing setting.
    print(e)

1 Physical GPUs, 1 Logical GPUs


# Building our GAN Model

In [4]:
def build_generator(latent_dim=100, img_shape=(28, 28, 1)):
    """Build the fully connected generator network (uncompiled).

    The network is three Dense -> LeakyReLU -> BatchNormalization stages
    (256, 512 and 1024 units) followed by a tanh Dense layer that is reshaped
    into an image.

    Args:
        latent_dim: Length of the input noise vector. Defaults to 100, the
            size of the noise vectors sampled elsewhere in this notebook.
        img_shape: Shape of one generated image. Defaults to (28, 28, 1),
            i.e. a 28x28 single-channel (grayscale) image.

    Returns:
        A Keras ``Sequential`` model whose output has shape ``img_shape`` per
        sample, with values in [-1, 1] because of the final tanh activation.
    """
    # One output unit per pixel; 28 * 28 * 1 = 784 with the default shape.
    n_pixels = int(np.prod(img_shape))

    model = Sequential([
        Dense(256, input_shape=(latent_dim,)),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(512),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(1024),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        # Final Dense layer emits the flattened image; Reshape restores its layout.
        Dense(n_pixels, activation='tanh'),
        Reshape(tuple(img_shape))
    ])

    return model

In [5]:
def build_discriminator(img_shape=(28, 28, 1)):
    """Build the fully connected discriminator network (uncompiled).

    The image is flattened and passed through two Dense + LeakyReLU stages
    (512 and 256 units) before a single sigmoid unit.

    Args:
        img_shape: Shape of one input image. Defaults to (28, 28, 1), the
            shape produced by the generator's final Reshape layer.

    Returns:
        A Keras ``Sequential`` model that outputs one value in [0, 1] per
        image (sigmoid activation).
    """
    model = Sequential([
        Flatten(input_shape=tuple(img_shape)),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid')
    ])

    return model

In [6]:
# Instantiate both networks with their default architectures.
generator = build_generator()
discriminator = build_discriminator()

# Compile the discriminator on its own so it can be trained directly on
# real/fake batches. Adam receives 0.0002 and 0.5 positionally (learning rate
# and beta_1 in the Keras Adam signature); accuracy is tracked alongside the loss.
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

# Freeze the discriminator only AFTER it has been compiled: the stacked model
# compiled in the following cell is meant to update the generator's weights only.
# NOTE(review): this pattern assumes Keras keeps the already-compiled
# discriminator trainable despite the flag change. Confirm on the installed
# Keras version that discriminator.train_on_batch still updates its weights.
discriminator.trainable = False

In [7]:
# Combined model: noise -> generator -> discriminator. Training it against
# "real" labels is how the generator is updated in the training loop.
gan_model = Sequential()
gan_model.add(generator)
gan_model.add(discriminator)
gan_model.compile(
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    loss='binary_crossentropy',
)

## Load our Dataset

In [8]:
# load_data() returns ((train_images, train_labels), (test_images, test_labels)).
# Only the training images are needed here, so pick them out directly.
X_train = mnist.load_data()[0][0]

# Map pixel values from [0, 255] to [-1, 1] so real images share the range of
# the generator's tanh output.
X_train = X_train / 127.5 - 1.0

In [9]:
# Inspect the array shape before the channel axis is added.
X_train.shape

(60000, 28, 28)

In [10]:
# Add a trailing channel axis so every image has shape (28, 28, 1): that is the
# input shape declared on the discriminator's Flatten layer and the shape the
# generator's Reshape layer emits (both models use Dense layers, not Conv2D).
# The ndim guard keeps this cell idempotent, so re-running it does not stack a
# second singleton axis onto X_train.
if X_train.ndim == 3:
    X_train = np.expand_dims(X_train, axis=-1)

In [11]:
# Confirm the trailing channel axis is present; the discriminator's Flatten
# layer is declared with input_shape=(28, 28, 1).
X_train.shape

(60000, 28, 28, 1)

## Function to visualize our generated images

In [12]:
def save_imgs(epoch, generator):
    """Render a 5x5 grid of generated digits and save it as a PNG.

    Args:
        epoch: Training step index; used only in the output file name
            ``gan_images/mnist_{epoch}.png``.
        generator: Model exposing ``predict``; it is fed 25 noise vectors of
            length 100 and is expected to return images indexable as
            ``[sample, row, col, channel]``.
    """
    # Create the output folder on demand so savefig cannot fail on a fresh
    # checkout where 'gan_images' does not exist yet.
    os.makedirs("gan_images", exist_ok=True)

    noise = np.random.normal(0, 1, (25, 100))
    # verbose=0 keeps Keras' per-call progress bar out of the training log.
    gen_imgs = generator.predict(noise, verbose=0)
    gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale images from [-1, 1] to [0, 1]

    fig, axs = plt.subplots(5, 5, figsize=(5, 5))
    # axs.flat walks the grid row by row, matching the order of the samples.
    for ax, img in zip(axs.flat, gen_imgs):
        ax.imshow(img[:, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig(f"gan_images/mnist_{epoch}.png")
    # Close this specific figure (not just the "current" one) to release its memory.
    plt.close(fig)

# Training our GAN Model

In [13]:
def train_gan(epochs, batch_size=128, save_interval=50, start_epoch=0):
    """Alternate discriminator and generator updates on random MNIST batches.

    Despite the naming, each "epoch" is a single training step on one randomly
    sampled batch, not a full pass over ``X_train``.

    Relies on the notebook-level ``generator``, ``discriminator``,
    ``gan_model`` and ``X_train`` objects, and on ``save_imgs``.

    Args:
        epochs: Exclusive upper bound of the step counter; steps run over
            ``range(start_epoch, epochs)``.
        batch_size: Number of real images and of generated images used per step.
        save_interval: Whenever the step index is a multiple of this value the
            losses are printed, a sample grid is written via ``save_imgs`` and
            both networks' weights are saved under ``gan_checkpoints/``.
        start_epoch: First step index. It only offsets the counter (and so the
            file names); no weights are loaded by this function.
    """
    valid = np.ones((batch_size, 1))    # Real Label
    fake = np.zeros((batch_size, 1))    # Fake Label

    # exist_ok avoids the separate exists() check and its race window.
    os.makedirs('gan_checkpoints', exist_ok=True)

    for epoch in range(start_epoch, epochs):
        # Sample a random batch of real images (indices drawn with replacement).
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        print(f"Epoch: {epoch}")
        real_imgs = X_train[idx]

        # Generate one fake image per noise vector (each vector has 100 entries).
        # verbose=0 stops Keras from printing a progress bar on every step.
        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise, verbose=0)

        # Train the discriminator on the real batch, then on the fake batch.
        # Each call returns [loss, accuracy] because the discriminator was
        # compiled with metrics=['accuracy']; the two results are averaged.
        d_loss_real = discriminator.train_on_batch(real_imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train the generator through the stacked model: fresh noise is
        # labelled "real" so the generator is pushed to fool the discriminator.
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = gan_model.train_on_batch(noise, valid)

        if epoch % save_interval == 0:
            print(f"{epoch} [D_loss: {d_loss[0]}, acc.: {100 * d_loss[1]}%] [G loss: {g_loss}]")
            save_imgs(epoch, generator)
            # Save weights
            generator.save_weights(f'gan_checkpoints/generator_weights_epoch_{epoch}.weights.h5')
            discriminator.save_weights(f'gan_checkpoints/discriminator_weights_epoch_{epoch}.weights.h5')


In [14]:
# Run steps 0..10000 with batches of 64. Using 10001 as the upper bound means
# step 10000 is executed and, being a multiple of save_interval, is checkpointed.
train_gan(epochs=10001, batch_size=64, save_interval=1000)

Epoch: 0
0 [D_loss: 0.5335235595703125, acc.: 50.78125%] [G loss: 0.48786062002182007]]
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Epoch: 6
Epoch: 7
Epoch: 8
Epoch: 9
Epoch: 10
Epoch: 11
Epoch: 12
Epoch: 13
Epoch: 14
Epoch: 15
Epoch: 16
Epoch: 17
Epoch: 18
Epoch: 19
Epoch: 20
Epoch: 21
Epoch: 22
Epoch: 23
Epoch: 24
Epoch: 25
Epoch: 26
Epoch: 27
Epoch: 28
Epoch: 29
Epoch: 30
Epoch: 31
Epoch: 32
Epoch: 33
Epoch: 34
Epoch: 35
Epoch: 36
Epoch: 37
Epoch: 38
Epoch: 39
Epoch: 40
Epoch: 41
Epoch: 42
Epoch: 43
Epoch: 44
Epoch: 45
Epoch: 46
Epoch: 47
Epoch: 48
Epoch: 49
Epoch: 50
Epoch: 51
Epoch: 52
Epoch: 53
Epoch: 54
Epoch: 55
Epoch: 56
Epoch: 57
Epoch: 58
Epoch: 59
Epoch: 60
Epoch: 61
Epoch: 62
Epoch: 63
Epoch: 64
Epoch: 65
Epoch: 66
Epoch: 67
Epoch: 68
Epoch: 69
Epoch: 70
Epoch: 71
Epoch: 72
Epoch: 73
Epoch: 74
Epoch: 75
Epoch: 76
Epoch: 77
Epoch: 78
Epoch: 79
Epoch: 80
Epoch: 81
Epoch: 82
Epoch: 83
Epoch: 84
Epoch: 85
Epoch: 86
Epoch: 87
Epoch: 88
Epoch: 89
Epoch: 90
Epoch: 91
Epoch: 92
E