In [17]:
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont

In [18]:
def build_generator(latent_dim=100, img_shape=(28, 28, 1)):
    """Build an MLP generator that maps latent noise vectors to images.

    Architecture: three hidden stages of Dense -> LeakyReLU(0.2) ->
    BatchNormalization(momentum=0.8) with widths 256, 512 and 1024, then a
    tanh Dense layer (so outputs lie in [-1, 1]) reshaped to ``img_shape``.

    Args:
        latent_dim: Length of the input noise vector. The rest of this
            notebook samples noise of length 100, which is the default.
        img_shape: Shape of each generated image, e.g. (height, width, channels).

    Returns:
        An uncompiled ``keras.models.Sequential`` model.
    """
    img_shape = tuple(img_shape)
    model = keras.models.Sequential()
    for i, units in enumerate((256, 512, 1024)):
        # Only the first layer declares the input size; later layers infer it.
        dense_kwargs = {'input_dim': latent_dim} if i == 0 else {}
        model.add(keras.layers.Dense(units, **dense_kwargs))
        model.add(keras.layers.LeakyReLU(alpha=0.2))
        model.add(keras.layers.BatchNormalization(momentum=0.8))
    # One output unit per pixel, then reshape the flat vector into an image.
    model.add(keras.layers.Dense(int(np.prod(img_shape)), activation='tanh'))
    model.add(keras.layers.Reshape(img_shape))

    return model

In [19]:
def build_discriminator(img_shape=(28, 28, 1)):
    """Build an MLP discriminator that scores images as real (1) or fake (0).

    The image is flattened, passed through Dense(512) and Dense(256) layers
    with LeakyReLU(0.2) activations, and reduced to a single sigmoid unit.

    Args:
        img_shape: Shape of each input image; should match the generator's
            output shape.

    Returns:
        An uncompiled ``keras.models.Sequential`` model whose output has shape
        (batch, 1).
    """
    model = keras.models.Sequential([
        keras.layers.Flatten(input_shape=tuple(img_shape)),
        keras.layers.Dense(512),
        keras.layers.LeakyReLU(alpha=0.2),
        keras.layers.Dense(256),
        keras.layers.LeakyReLU(alpha=0.2),
        keras.layers.Dense(1, activation='sigmoid'),
    ])
    return model

In [20]:
def build_gan(generator, discriminator):
    """Chain ``generator`` into ``discriminator`` as one end-to-end model.

    The combined model feeds noise through the generator and scores the
    result with the discriminator. The two sub-models are shared, not copied,
    so training this model updates whichever of their weights are trainable
    at compile time.
    """
    return keras.models.Sequential([generator, discriminator])

In [21]:
def sample_images(generator, epoch, rows=5, cols=5):
    """Display a ``rows`` x ``cols`` grid of images sampled from ``generator``.

    Args:
        generator: Model mapping (n, 100) noise arrays to (n, H, W, C) images.
            Outputs are assumed to be in [-1, 1] (tanh), as produced by
            ``build_generator``.
        epoch: Training iteration the samples come from; used as the figure title.
        rows: Number of grid rows.
        cols: Number of grid columns. ``rows * cols`` images are generated.
    """
    noise = np.random.normal(0, 1, (rows * cols, 100))
    # verbose=0: suppress the per-call progress bar that predict() may print.
    gen_images = generator.predict(noise, verbose=0)
    gen_images = 0.5 * gen_images + 0.5  # Rescale from [-1, 1] to [0, 1]

    # squeeze=False keeps `axs` 2-D even when rows or cols is 1.
    fig, axs = plt.subplots(rows, cols, figsize=(10, 10), sharey=True, sharex=True, squeeze=False)
    fig.suptitle(f"Epoch {epoch}")

    # axs.flat walks the grid row-major, matching image order in gen_images.
    for cnt, ax in enumerate(axs.flat):
        # Fixed vmin/vmax so every tile shares the same [0, 1] gray scale
        # instead of being auto-contrasted per image.
        ax.imshow(gen_images[cnt, :, :, 0], cmap='gray', vmin=0, vmax=1)
        ax.axis('off')

    plt.show()

In [31]:
def train_gan(gan, generator, discriminator, x_train, epochs=10000, batch_size=128, sample_interval=1000):
    """Train a GAN by alternating discriminator and generator updates.

    Note: each "epoch" here is ONE training iteration on a randomly drawn
    mini-batch, not a full pass over ``x_train``.

    Args:
        gan: Combined generator->discriminator model, compiled so that
            ``train_on_batch`` updates the generator.
        generator: Model mapping (n, 100) noise arrays to images.
        discriminator: Model compiled with ``metrics=['accuracy']``, so its
            ``train_on_batch`` returns ``[loss, accuracy]``.
        x_train: Array of real images indexed along axis 0. Presumably in the
            same value range as the generator output ([-1, 1]); confirm
            against the caller.
        epochs: Number of training iterations.
        batch_size: Noise samples per generator update. Each discriminator
            update sees ``batch_size // 2`` real and ``batch_size // 2``
            generated images.
        sample_interval: Show generated samples every this many iterations
            (including iteration 0).
    """
    half_batch = batch_size // 2

    # Label arrays are constant across iterations, so build them once.
    # They are shaped (n, 1) to match the discriminator's single sigmoid output.
    real_labels = np.ones((half_batch, 1))
    fake_labels = np.zeros((half_batch, 1))
    valid_y = np.ones((batch_size, 1))

    for epoch in range(epochs):
        # Train the discriminator on a half batch of real and a half batch of fake images
        idx = np.random.randint(0, x_train.shape[0], half_batch)
        real_images = x_train[idx]

        noise = np.random.normal(0, 1, (half_batch, 100))
        # verbose=0: otherwise predict() may print a progress bar every iteration.
        generated_images = generator.predict(noise, verbose=0)

        d_loss_real = discriminator.train_on_batch(real_images, real_labels)
        d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
        # Element-wise mean of [loss, accuracy] over the real and fake halves.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train the generator via the GAN: label fakes as real so the
        # generator is pushed toward fooling the discriminator.
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = gan.train_on_batch(noise, valid_y)

        # Print the progress
        print(f"{epoch + 1}/{epochs} [D loss: {d_loss[0]} | D accuracy: {100 * d_loss[1]}] [G loss: {g_loss}]")

        # If at save interval => show generated image samples
        if epoch % sample_interval == 0:
            sample_images(generator, epoch)

In [None]:
def create_digit_image(digit, img_size=(28, 28)):
    """Render ``digit`` as white text centred on a black grayscale image.

    Args:
        digit: Value to draw; converted with ``str()``.
        img_size: PIL image size as (width, height).

    Returns:
        A uint8 numpy array of shape (height, width) with background 0 and
        text pixels up to 255.
    """
    img = Image.new('L', img_size, color=0)
    draw = ImageDraw.Draw(img)
    font = ImageFont.load_default()
    text = str(digit)
    if hasattr(draw, 'textbbox'):
        # ImageDraw.textsize was removed in Pillow 10; textbbox (Pillow >= 8)
        # replaces it and also reports the glyph offset from the origin.
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    else:
        # Fallback for old Pillow releases without textbbox.
        left, top = 0, 0
        right, bottom = draw.textsize(text, font=font)
    text_width, text_height = right - left, bottom - top
    # Subtract the bbox origin so the visible glyphs, not the text anchor, are centred.
    position = ((img_size[0] - text_width) // 2 - left, (img_size[1] - text_height) // 2 - top)
    draw.text(position, text, fill=255, font=font)
    return np.array(img)

def generate_digit_dataset(samples_per_digit=100, img_size=(28, 28)):
    """Build a dataset of rendered digits 0-9, normalized to [-1, 1].

    Samples are ordered digit-major: ``samples_per_digit`` images of 0, then
    of 1, and so on up to 9. ``create_digit_image`` as written here uses no
    randomness, so all samples of a given digit are identical.

    Args:
        samples_per_digit: Number of images generated per digit.
        img_size: PIL image size (width, height) passed to ``create_digit_image``.

    Returns:
        A float32 array of shape (10 * samples_per_digit, height, width, 1)
        with values in [-1, 1], matching the generator's tanh output range.
    """
    images = [
        create_digit_image(digit, img_size)
        for digit in range(10)
        for _ in range(samples_per_digit)
    ]
    # The explicit reshape keeps the (N, H, W) shape even when N == 0.
    dataset = np.asarray(images, dtype=np.float32).reshape(-1, img_size[1], img_size[0])
    dataset = dataset / 127.5 - 1.0  # Normalize [0, 255] to [-1, 1]
    return np.expand_dims(dataset, axis=-1)

digit_dataset = generate_digit_dataset()

# The dataset is ordered digit-major (all samples of 0, then all of 1, ...),
# so index 0..9 would show ten copies of "0". Stride by one digit's worth of
# samples to preview one example of each digit instead.
samples_per_digit = len(digit_dataset) // 10
fig, axs = plt.subplots(2, 5)
for digit, ax in enumerate(axs.flat):
    ax.imshow(digit_dataset[digit * samples_per_digit, :, :, 0], cmap='gray')
    ax.set_title(str(digit))
    ax.axis('off')
plt.show()

In [None]:
discriminator = build_discriminator()
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    metrics=['accuracy'],
)
# Freeze AFTER compiling. The discriminator's own compiled training step
# keeps updating its weights, while the combined GAN (compiled in the next
# cell) sees it as frozen and trains only the generator. This ordering is the
# classic tf.keras (Keras 2) pattern.
# NOTE(review): Keras 3 handles trainable changes after compile differently; verify on upgrade.
discriminator.trainable = False

In [None]:
generator = build_generator()

# Assuming the previous cell has set discriminator.trainable = False,
# compiling here makes gan.train_on_batch update only the generator's weights.
gan = build_gan(generator, discriminator)
gan.compile(
    loss='binary_crossentropy',
    optimizer=keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
)

In [None]:
# Training hyperparameters spelled out explicitly (same values as the defaults).
train_gan(
    gan,
    generator,
    discriminator,
    x_train=digit_dataset,
    epochs=10000,
    batch_size=128,
    sample_interval=1000,
)