In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, Input
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import plot_model
import matplotlib.pyplot as plt
import os

# Load and preprocess MNIST
(x_train, y_train), (_, _) = mnist.load_data()
x_train = (x_train.astype(np.float32) - 127.5) / 127.5  # Normalize to [-1, 1]
x_train = np.expand_dims(x_train, axis=-1)
num_classes = 10
img_shape = x_train.shape[1:]

# Generator
def build_generator(latent_dim, num_classes):
    label = Input(shape=(1,), dtype='int32')
    noise = Input(shape=(latent_dim,))
    label_embedding = layers.Embedding(num_classes, latent_dim)(label)
    label_embedding = layers.Flatten()(label_embedding)
    model_input = layers.multiply([noise, label_embedding])
    x = layers.Dense(128)(model_input)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Dense(256)(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Dense(512)(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Dense(np.prod(img_shape), activation='tanh')(x)
    img = layers.Reshape(img_shape)(x)
    return Model([noise, label], img, name="generator")

# Discriminator
def build_discriminator(img_shape, num_classes):
    img = Input(shape=img_shape)
    label = Input(shape=(1,), dtype='int32')
    label_embedding = layers.Embedding(num_classes, np.prod(img_shape))(label)
    label_embedding = layers.Flatten()(label_embedding)
    flat_img = layers.Flatten()(img)
    model_input = layers.multiply([flat_img, label_embedding])
    x = layers.Dense(512)(model_input)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Dense(256)(x)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Dense(1, activation='sigmoid')(x)
    return Model([img, label], x, name="discriminator")

latent_dim = 100
generator = build_generator(latent_dim, num_classes)
discriminator = build_discriminator(img_shape, num_classes)
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Save model architectures
os.makedirs("visualizations", exist_ok=True)
plot_model(generator, to_file="visualizations/generator_architecture.png", show_shapes=True, show_layer_names=True)
plot_model(discriminator, to_file="visualizations/discriminator_architecture.png", show_shapes=True, show_layer_names=True)

# Combined model
noise = Input(shape=(latent_dim,))
label = Input(shape=(1,), dtype='int32')
img = generator([noise, label])
discriminator.trainable = False
valid = discriminator([img, label])
combined = Model([noise, label], valid)
combined.compile(loss='binary_crossentropy', optimizer='adam')

# Visualization function for generated images
def save_generated_images(epoch, generator, latent_dim, num_classes):
    r, c = 2, 5  # 2 rows, 5 columns for 10 digits
    noise = np.random.normal(0, 1, (num_classes, latent_dim))
    labels = np.arange(0, num_classes).reshape(-1, 1)
    gen_imgs = generator.predict([noise, labels], verbose=0)
    gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale to [0, 1]

    fig, axs = plt.subplots(r, c, figsize=(c*2, r*2))
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].set_title(f"Digit: {cnt}")
            axs[i, j].axis('off')
            cnt += 1
    fig.suptitle(f"Generated digits at epoch {epoch}")
    fig.tight_layout()
    plt.savefig(f"visualizations/generated_{epoch}.png")
    plt.close()

# Training loop
epochs = 10000
batch_size = 64
half_batch = batch_size // 2

for epoch in range(epochs):
    # Train Discriminator
    idx = np.random.randint(0, x_train.shape[0], half_batch)
    imgs, labels = x_train[idx], y_train[idx]
    noise = np.random.normal(0, 1, (half_batch, latent_dim))
    gen_labels = np.random.randint(0, num_classes, half_batch).reshape(-1, 1)
    gen_imgs = generator.predict([noise, gen_labels], verbose=0)
    d_loss_real = discriminator.train_on_batch([imgs, labels], np.ones((half_batch, 1)))
    d_loss_fake = discriminator.train_on_batch([gen_imgs, gen_labels], np.zeros((half_batch, 1)))
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train Generator
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    sampled_labels = np.random.randint(0, num_classes, batch_size).reshape(-1, 1)
    valid_y = np.ones((batch_size, 1))
    g_loss = combined.train_on_batch([noise, sampled_labels], valid_y)

    # Print and visualize progress
    if epoch % 1000 == 0:
        print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")
        save_generated_images(epoch, generator, latent_dim, num_classes)

You must install pydot (`pip install pydot`) for `plot_model` to work.
You must install pydot (`pip install pydot`) for `plot_model` to work.




0 [D loss: 0.6867, acc.: 70.31%] [G loss: 0.6931]
1000 [D loss: 0.7403, acc.: 43.16%] [G loss: 0.6006]
2000 [D loss: 0.7453, acc.: 43.09%] [G loss: 0.5928]
3000 [D loss: 0.7471, acc.: 43.11%] [G loss: 0.5900]
4000 [D loss: 0.7480, acc.: 43.12%] [G loss: 0.5885]
5000 [D loss: 0.7485, acc.: 43.12%] [G loss: 0.5876]
6000 [D loss: 0.7490, acc.: 43.14%] [G loss: 0.5870]
7000 [D loss: 0.7492, acc.: 43.13%] [G loss: 0.5865]
8000 [D loss: 0.7494, acc.: 43.10%] [G loss: 0.5862]
9000 [D loss: 0.7496, acc.: 43.09%] [G loss: 0.5859]
