In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Dense, Reshape, Flatten, LeakyReLU,
                                     Conv2D, UpSampling2D, Dropout, LayerNormalization,
                                     multiply)
from tensorflow.keras.optimizers import Adam

In [2]:
# Only the training images are used; the labels and the test split are discarded.
(X_train, _), _ = mnist.load_data()

# Shift/scale pixel values so that 0 -> -1 and 255 -> +1 (the range of the
# generator's tanh output), then append a trailing channel axis: (28,28,1).
X_train = ((X_train.astype(np.float32) - 127.5) / 127.5)[..., np.newaxis]

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


In [3]:
def build_mapping(latent_dim=100, dlatent_dim=100):
    """Build the mapping network that turns a noise vector into a style vector.

    Args:
        latent_dim: Length of the input noise vector.
        dlatent_dim: Length of the output style vector.

    Returns:
        A Keras ``Model`` named "MappingNetwork": one 128-unit ReLU layer
        followed by a linear projection to ``dlatent_dim``.
    """
    noise_in = Input(shape=(latent_dim,))
    hidden = Dense(128, activation="relu")(noise_in)
    style_out = Dense(dlatent_dim)(hidden)
    return Model(inputs=noise_in, outputs=style_out, name="MappingNetwork")

In [4]:
def build_generator(latent_dim=100, dlatent_dim=100):
    """Build the generator that turns a style vector into a single-channel image.

    The style vector is projected to a 7x7x128 feature map and passed through
    two upsample/conv stages. After each stage the features are multiplied
    channel-wise by a linear projection of the style vector (a simplified
    style modulation).

    Args:
        latent_dim: Not used by this function; accepted so the signature
            mirrors ``build_mapping``.
        dlatent_dim: Length of the style vector the model takes as input.

    Returns:
        A Keras ``Model`` named "Generator" with a tanh-activated,
        one-channel output.
    """
    w_in = Input(shape=(dlatent_dim,))

    # Seed feature map: dense projection reshaped to 7x7 with 128 channels.
    seed = Dense(7 * 7 * 128)(w_in)
    feat = Reshape((7, 7, 128))(seed)

    # Style modulation (simplified)
    for n_channels in (128, 64):
        feat = UpSampling2D()(feat)
        feat = Conv2D(n_channels, kernel_size=3, padding="same")(feat)
        feat = LayerNormalization()(feat)
        feat = LeakyReLU(0.2)(feat)

        # One scalar gain per channel, derived from the style vector and
        # reshaped to (1,1,C) so it broadcasts over the spatial dimensions.
        gain = Dense(n_channels, activation="linear")(w_in)
        gain = Reshape((1, 1, n_channels))(gain)
        feat = multiply([feat, gain])

    out_img = Conv2D(1, kernel_size=3, padding="same", activation="tanh")(feat)
    return Model(w_in, out_img, name="Generator")

In [5]:
def build_discriminator(img_shape=(28,28,1)):
    """Build the convolutional real-vs-generated image classifier.

    Args:
        img_shape: Shape of one input image, ``(height, width, channels)``.

    Returns:
        A Keras ``Model`` named "Discriminator" whose single sigmoid output is
        the predicted probability that the input image is real.
    """
    img_in = Input(shape=img_shape)

    # Two stride-2 conv stages (64 then 128 filters), each followed by
    # LeakyReLU and 25% dropout.
    features = img_in
    for n_filters in (64, 128):
        features = Conv2D(n_filters, kernel_size=3, strides=2, padding="same")(features)
        features = LeakyReLU(0.2)(features)
        features = Dropout(0.25)(features)

    flat = Flatten()(features)
    score = Dense(1, activation="sigmoid")(flat)
    return Model(img_in, score, name="Discriminator")

In [6]:
# Length of the noise vector z and of the style vector w produced by the mapping network.
latent_dim = 100
dlatent_dim = 100

# Adam with learning rate 2e-4 and beta_1 = 0.5.
# A Keras optimizer keeps per-variable state for the set of variables it is
# first used with, so a single instance must not be shared by two models that
# train different variables. `optimizer` is kept for the stacked generator
# model compiled later in the notebook; the discriminator gets its own instance.
optimizer = Adam(0.0002, 0.5)
d_optimizer = Adam(0.0002, 0.5)

mapping = build_mapping(latent_dim, dlatent_dim)
generator = build_generator(latent_dim, dlatent_dim)
discriminator = build_discriminator()

# Compiled while still trainable, so the discriminator can be trained on its own.
discriminator.compile(loss="binary_crossentropy", optimizer=d_optimizer, metrics=["accuracy"])

In [7]:
# Stack mapping -> generator -> discriminator into one model that maps noise
# to a "real" probability; this stacked model is what trains the generator.
z = Input(shape=(latent_dim,))
w = mapping(z)
img = generator(w)
# Freeze the discriminator before compiling the stacked model. The intent is
# that updates through `gan` change only the mapping network and generator.
# Order matters: this runs after `discriminator.compile(...)` and before
# `gan.compile(...)`.
# NOTE(review): whether the standalone `discriminator` still trains after this
# flag flips depends on when the installed Keras version reads `trainable`
# (at compile time vs. when the train step is first built) - confirm.
discriminator.trainable = False
validity = discriminator(img)
gan = Model(z, validity)
gan.compile(loss="binary_crossentropy", optimizer=optimizer)

In [8]:
def train(epochs=10000, batch_size=64, save_interval=1000):
    """Alternate discriminator and generator updates on random mini-batches.

    Despite the parameter name, one "epoch" here is a single iteration: one
    discriminator update on a real batch, one on a generated batch, then one
    generator update through ``gan``. It is not a full pass over ``X_train``.

    Args:
        epochs: Number of training iterations to run.
        batch_size: Number of real images (and of generated images) per
            iteration.
        save_interval: Call ``save_imgs`` every this many iterations. A falsy
            value (``0`` or ``None``) disables the previews instead of raising
            ``ZeroDivisionError`` on the modulo.

    Returns:
        ``(d_losses, g_losses)``: two lists with one entry per iteration, the
        averaged discriminator loss and the value returned by
        ``gan.train_on_batch``.
    """
    d_losses, g_losses = [], []

    # Target labels: 1 for real images, 0 for generated ones.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(1, epochs + 1):
        # ---------------------
        # Train Discriminator
        # ---------------------
        # NOTE(review): the logged discriminator accuracy in this notebook
        # (~12%) is below chance, which suggests the discriminator is not
        # being updated by these calls. `discriminator.trainable` is set to
        # False at module level before training starts; depending on the
        # Keras version that flag may be read when the train step is first
        # built rather than at compile time. Confirm, and if so toggle the
        # flag around the two phases (with a separate optimizer per model).
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_imgs = X_train[idx]

        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # Call the models directly instead of `predict`: for a single small
        # batch Keras recommends `model(x, training=False)`, which avoids
        # predict()'s per-call setup cost on every iteration.
        w = mapping(noise, training=False)
        gen_imgs = np.asarray(generator(w, training=False))

        d_loss_real = discriminator.train_on_batch(real_imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # Element-wise mean of the [loss, accuracy] pairs from the two batches.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        # Train Generator
        # ---------------------
        # Fresh noise, labelled "valid": the generator is rewarded when the
        # discriminator classifies its images as real.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, valid)

        # Save losses
        d_losses.append(d_loss[0])
        g_losses.append(g_loss)

        # Print progress
        if epoch % 100 == 0:
            print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc: {100*d_loss[1]:.2f}] [G loss: {g_loss:.4f}]")

        # Save images (skipped entirely when save_interval is 0/None)
        if save_interval and epoch % save_interval == 0:
            save_imgs(epoch)

    return d_losses, g_losses

In [9]:
def save_imgs(epoch, examples=25):
    """Sample images from the generator and display them in a square grid.

    Despite the name, nothing is written to disk: the figure is only shown
    with ``plt.show()``.

    Args:
        epoch: Iteration number, used only in the figure title.
        examples: Number of images to generate. The grid is 5x5 for up to 25
            images and grows to the smallest square that fits for more, so
            larger values no longer fail with an out-of-range subplot index.
    """
    noise = np.random.normal(0, 1, (examples, latent_dim))
    w = mapping.predict(noise, verbose=0)
    gen_imgs = generator.predict(w, verbose=0)

    # Map the generator's tanh output from [-1, 1] to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    # Keep the original 5x5 layout whenever it is large enough; otherwise use
    # the smallest square grid that holds every image.
    side = max(5, int(np.ceil(np.sqrt(examples))))

    plt.figure(figsize=(side, side))
    for i in range(examples):
        plt.subplot(side, side, i + 1)
        plt.imshow(gen_imgs[i, :, :, 0], cmap="gray")
        plt.axis("off")
    plt.suptitle(f"Epoch {epoch}")
    plt.show()


In [None]:
# Run 5000 training iterations, showing a grid of generated samples every 1000.
d_losses, g_losses = train(epochs=5000, batch_size=64, save_interval=1000)

# Plot losses
fig, ax = plt.subplots(figsize=(8, 6))
for series, label in ((d_losses, "Discriminator Loss"), (g_losses, "Generator Loss")):
    ax.plot(series, label=label)
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.legend()
ax.set_title("Training Losses")
plt.show()



100 [D loss: 1.0395, acc: 12.07] [G loss: 0.3167]
200 [D loss: 1.0980, acc: 11.91] [G loss: 0.2734]
