In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

In [2]:
(X_train, _), (_, _) = mnist.load_data()
X_train = X_train / 255.0
X_train = np.expand_dims(X_train, axis=-1)

image_shape = (28, 28, 1)
latent_dim = 100

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


In [3]:
def build_generator(latent_dim):
    model = models.Sequential()
    model.add(layers.Dense(7 * 7 * 256, input_dim=latent_dim))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Reshape((7, 7, 256)))
    model.add(layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding='same'))
    model.add(layers.LeakyReLU(alpha=0.2))
    model.add(layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding='same'))
    model.add(layers.LeakyReLU(alpha=0.2))

    model.add(layers.Conv2DTranspose(1, kernel_size=7, strides=1, padding='same', activation='tanh'))

    return model

In [4]:
def build_discriminator(image_shape):
    model = models.Sequential()

    model.add(layers.Conv2D(64, kernel_size=4, strides=2, padding='same', input_shape=image_shape))
    model.add(layers.LeakyReLU(alpha=0.2))

    model.add(layers.Conv2D(128, kernel_size=4, strides=2, padding='same'))
    model.add(layers.LeakyReLU(alpha=0.2))

    model.add(layers.Flatten())
    model.add(layers.Dense(1, activation='sigmoid'))  # Output: real or fake

    return model

In [5]:
def build_gan(generator, discriminator):
    discriminator.trainable = False
    model = models.Sequential()
    model.add(generator)
    model.add(discriminator)
    
    return model

In [6]:
discriminator = build_discriminator(image_shape)
discriminator.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy', metrics=['accuracy'])
generator = build_generator(latent_dim)

gan = build_gan(generator, discriminator)
gan.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy')

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [7]:
def train_gan(epochs, batch_size, latent_dim, dataset):
    real_labels = np.ones((batch_size, 1))
    fake_labels = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        idx = np.random.randint(0, dataset.shape[0], batch_size)
        real_images = dataset[idx]

        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        fake_images = generator.predict(noise)

        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, real_labels)  # We use real labels because we want to fool the discriminator
        if epochs % 10 == 0:
            print(f"{epoch}/{epochs} [D loss: {d_loss[0]} | D accuracy: {100 * d_loss[1]}] [G loss: {g_loss}]")

In [8]:
epochs = 2000
batch_size = 64

train_gan(epochs, batch_size, latent_dim, X_train)

[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 58ms/step




0/2000 [D loss: 0.6944079399108887 | D accuracy: 33.59375] [G loss: [array(0.694335, dtype=float32), array(0.694335, dtype=float32), array(0.25, dtype=float32)]]
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 60ms/step
1/2000 [D loss: 0.6944798231124878 | D accuracy: 29.16666865348816] [G loss: [array(0.6948317, dtype=float32), array(0.6948317, dtype=float32), array(0.25, dtype=float32)]]
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 56ms/step
2/2000 [D loss: 0.6953021883964539 | D accuracy: 25.78125] [G loss: [array(0.6957233, dtype=float32), array(0.6957233, dtype=float32), array(0.234375, dtype=float32)]]
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step
3/2000 [D loss: 0.6961241960525513 | D accuracy: 24.06529039144516] [G loss: [array(0.6966315, dtype=float32), array(0.6966315, dtype=float32), array(0.22460938, dtype=float32)]]
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 55ms/step
4/2000 [D loss: 0.696955204010