In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, BatchNormalization, LeakyReLU, Reshape, Conv2DTranspose, Dropout, Conv2D, Flatten
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam

In [None]:
from tensorflow.keras.datasets import fashion_mnist

In [None]:
# Fashion-MNIST comes pre-split; each split is an (images, labels) pair.
train_split, test_split = fashion_mnist.load_data()
X_train, y_train = train_split
X_test, y_test = test_split

In [None]:
# Raw images have no channel axis yet; it is added by the reshape step further down.
np.shape(X_train)

In [None]:
# Preview one training image together with its class label.
sample_idx = 99
fig, ax = plt.subplots()
# squeeze() drops a trailing channel axis, so this also works if re-run after the reshape step.
ax.imshow(np.squeeze(X_train[sample_idx]), cmap='gray')
ax.set_title(f'Training image {sample_idx} (label {y_train[sample_idx]})')
ax.axis('off');

In [None]:
# Scale pixels from [0, 255] to [0, 1].
# Cast to float32 (Keras' default float type) first: dividing the uint8 array directly
# yields float64, doubling memory and forcing a dtype cast on every batch fed to the models.
X_train = X_train.astype('float32') / 255

In [None]:
# Add a trailing channel axis: the discriminator's Input expects (28, 28, 1) images.
X_train = np.reshape(X_train, (-1, 28, 28, 1))

In [None]:
# Map [0, 1] -> [-1, 1] to match the generator's tanh output range.
# NOTE: not idempotent — re-running this cell rescales the data again.
X_train = 2 * X_train - 1

In [None]:
# Verify the scaling produced the [-1, 1] range used by the generator's tanh output.
pixel_min, pixel_max = X_train.min(), X_train.max()
print(pixel_min)
print(pixel_max)
assert -1.0 <= pixel_min and pixel_max <= 1.0, "X_train must be scaled to [-1, 1]"

In [None]:
# Training configuration.
random_seed = 42
# Seeds Python, NumPy and TensorFlow so shuffling, weight init and noise are reproducible.
tf.keras.utils.set_random_seed(random_seed)

buffer_size = len(X_train)  # shuffle over the whole training set
batch_size = 128
latent_dim = 100  # length of the generator's input noise vector
epochs = 50

In [None]:
# Shuffle, batch, and prefetch so input preparation overlaps with the training step.
# The final batch is kept even if smaller than batch_size (no drop_remainder).
train_dataset = (
    tf.data.Dataset.from_tensor_slices(X_train)
    .shuffle(buffer_size)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

## Generator and Discriminator

In [None]:
# Generator: maps a latent_dim noise vector to a 28x28x1 image with values in [-1, 1].
# Spatial size: 7x7 (Reshape) -> 7x7 (stride 1) -> 14x14 (stride 2) -> 28x28 (stride 2).
generator = Sequential([
    Input(shape=(latent_dim,)),
    # Biases are omitted before BatchNormalization, whose learned shift makes them redundant.
    Dense(units=7 * 7 * 256, use_bias=False),
    BatchNormalization(),
    LeakyReLU(0.2),

    Reshape((7, 7, 256)),

    Conv2DTranspose(128, (5, 5), strides=1, padding='same', use_bias=False),
    BatchNormalization(),
    LeakyReLU(0.2),

    Conv2DTranspose(64, (5, 5), strides=2, padding='same', use_bias=False),
    BatchNormalization(),
    # Explicit 0.2 to match the other activations (Keras' default slope is 0.3).
    LeakyReLU(0.2),

    # tanh keeps outputs in [-1, 1], the range the training images are scaled to.
    Conv2DTranspose(1, (5, 5), strides=2, padding='same', use_bias=False, activation='tanh'),
])

In [None]:
# Discriminator: two strided-conv downsampling stages (28 -> 14 -> 7), each followed by
# LeakyReLU and Dropout, then a single unactivated Dense unit (a raw logit per image).
discriminator = Sequential(
    [Input(shape=(28, 28, 1))]
    + [
        layer
        for n_filters in (64, 128)
        for layer in (
            Conv2D(n_filters, (5, 5), strides=2, padding='same'),
            LeakyReLU(0.2),
            Dropout(0.3),
        )
    ]
    + [Flatten(), Dense(1)]
)

In [None]:
# Sanity-check both networks on dummy inputs before training.
test_noise = tf.random.normal([1, latent_dim])
generated_image = generator(test_noise, training=False)
print(f"Generator output shape: {generated_image.shape}")
assert generated_image.shape == (1, 28, 28, 1), "generator output must match the discriminator input"

test_image = tf.random.normal([1, 28, 28, 1])
decision = discriminator(test_image, training=False)
print(f"Discriminator output shape: {decision.shape}")
assert decision.shape == (1, 1), "discriminator must emit exactly one logit per image"

In [None]:
# The discriminator's final Dense(1) has no activation, so the loss must consume raw logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def discriminator_loss(real_output, fake_output):
    """Binary cross-entropy loss for the discriminator.

    Real-image logits are scored against label 1 and generated-image logits
    against label 0; the two loss terms are summed.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return tf.add(loss_on_real, loss_on_fake)

In [None]:
def generator_loss(fake_output):
    """Non-saturating generator loss.

    Scores the discriminator's logits on generated images against the "real"
    label 1, so minimizing it pushes the generator to fool the discriminator.
    """
    wanted_labels = tf.ones_like(fake_output)
    return cross_entropy(wanted_labels, fake_output)

In [None]:
# Each network gets its own Adam instance (separate moment estimates), both at lr 1e-4.
generator_optimizer = Adam(learning_rate=1e-4)
discriminator_optimizer = Adam(learning_rate=1e-4)

## Training function

In [None]:
@tf.function
def train_step(images):
    """Run one simultaneous generator/discriminator update on a batch of real images.

    Args:
        images: batch of real images; assumed (batch, 28, 28, 1) scaled to [-1, 1].

    Returns:
        (gen_loss, disc_loss) loss tensors for this batch.
    """
    # Size the noise batch from the real batch, not the global batch_size: the last
    # batch of train_dataset is smaller when the dataset size isn't a multiple of it,
    # and a mismatch would pit a different number of fakes against the reals.
    noise = tf.random.normal([tf.shape(images)[0], latent_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Two tapes so each network's gradients come only from its own loss.
    gradients_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

In [None]:
def generate_images(model, epoch, seed):
    """Plot a square-ish grid of images sampled from the generator.

    Args:
        model: generator called as ``model(seed, training=False)``; assumed to return
            (n, H, W, 1) images in [-1, 1] (tanh output).
        epoch: epoch number shown in the figure title.
        seed: noise batch fed to ``model``; pass the same tensor on every call so
            grids from different epochs are directly comparable.
    """
    predictions = model(seed, training=False)
    # Map tanh range [-1, 1] back to pixel range [0, 255].
    images = np.asarray(predictions)[..., 0] * 127.5 + 127.5
    n_images = images.shape[0]
    n_cols = int(np.ceil(np.sqrt(n_images)))
    n_rows = int(np.ceil(n_images / n_cols))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i < n_images:
            # Fixed vmin/vmax: avoid per-image autoscaling so brightness is comparable.
            ax.imshow(images[i], cmap='gray', vmin=0, vmax=255)
    fig.suptitle(f'Epoch {epoch}')
    plt.show()
    # Release the figure; this is called repeatedly inside the training loop.
    plt.close(fig)

In [None]:
# Fixed noise reused at every snapshot, so the grids show how the same latent points
# evolve as training progresses (a fresh draw each time makes epochs incomparable).
progress_seed = tf.random.normal([16, latent_dim])
snapshot_every = 10  # epochs between sample grids / loss reports

for epoch in range(epochs):
    # Sum loss tensors and convert to Python floats only when reporting, instead of per batch.
    gen_loss_sum = 0.0
    disc_loss_sum = 0.0
    n_batches = 0
    for image_batch in train_dataset:
        gen_loss, disc_loss = train_step(image_batch)
        gen_loss_sum += gen_loss
        disc_loss_sum += disc_loss
        n_batches += 1

    if (epoch + 1) % snapshot_every == 0:
        print(f'Epoch {epoch + 1}, '
              f'Gen Loss (epoch mean): {float(gen_loss_sum) / n_batches:.4f}, '
              f'Disc Loss (epoch mean): {float(disc_loss_sum) / n_batches:.4f}')
        generate_images(generator, epoch + 1, seed=progress_seed)