In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, LeakyReLU, Dropout, Reshape, Conv2D, Conv2DTranspose, Flatten, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD, Adam

import numpy as np
# import pandas as pd
import matplotlib.pyplot as plt
import sys
import os

# Load the Fashion-MNIST train/test split
fashion_mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Rescale pixels from [0, 255] to [-1, 1] so real images share the range of the
# generator's tanh output. Casting to float32 first avoids NumPy's default
# float64 result, which would double the memory used by the arrays.
x_train = x_train.astype("float32") / 255.0 * 2 - 1
x_test = x_test.astype("float32") / 255.0 * 2 - 1

# Image dimensions (the images are kept 2-D here; nothing is flattened)
N, H, W = x_train.shape  # N - samples, H - height, W - width
D = H * W  # D - number of pixels per image

# Add the trailing channel axis explicitly: the discriminator is built for
# [28, 28, 1] inputs, so this avoids relying on Keras to reshape every batch.
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]

latent_dim = 100  # latent space dimension

# Generator model


def build_generator(latent_dim):
    """Build the generator network: noise vector -> single-channel image.

    Args:
        latent_dim: Length of the input noise vector.

    Returns:
        An uncompiled Keras ``Model`` whose output is tanh-activated, i.e.
        pixel values lie in [-1, 1]. With the 7x7 seed map and two stride-2
        transposed convolutions the output is 28x28x1.
    """
    latent_in = Input(shape=(latent_dim,))

    # Project the noise onto a 7x7 feature map with 128 channels.
    # The slope is passed positionally (as build_discriminator does) so it
    # binds to the first parameter whatever its name: `alpha` in older Keras,
    # `negative_slope` in Keras 3, where the `alpha` keyword is deprecated.
    h = Dense(7 * 7 * 128, activation=LeakyReLU(0.2))(latent_in)
    h = Reshape([7, 7, 128])(h)
    h = BatchNormalization()(h)

    # Upsample 7x7 -> 14x14
    h = Conv2DTranspose(64, kernel_size=5, strides=2, padding="same",
                        activation="relu")(h)
    h = BatchNormalization()(h)

    # Upsample 14x14 -> 28x28 and squash to [-1, 1]
    img_out = Conv2DTranspose(1, kernel_size=5, strides=2, padding="same",
                              activation="tanh")(h)
    return Model(latent_in, img_out)

# Discriminator model


def build_discriminator(img_size):
    """Assemble the convolutional real-vs-generated image classifier.

    Args:
        img_size: Shape of one input image, e.g. ``[28, 28, 1]``.

    Returns:
        An uncompiled Keras ``Model`` that outputs one sigmoid score per image.
    """
    image_in = Input(shape=img_size)

    # Two identical downsampling stages: strided 5x5 convolution, then dropout
    features = image_in
    for _ in range(2):
        features = Conv2D(64, kernel_size=5, strides=2, padding="same",
                          activation=LeakyReLU(0.3))(features)
        features = Dropout(0.5)(features)

    # Collapse the feature maps and reduce them to a single probability
    flat = Flatten()(features)
    score = Dense(1, activation="sigmoid")(flat)
    return Model(image_in, score)


# Build and compile the discriminator (compiled before it is frozen below)
discriminator = build_discriminator([28, 28, 1])
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    metrics=['accuracy'])

# Build the generator; it is only trained through the combined model
generator = build_generator(latent_dim)

# Symbolic noise input, passed through the generator
z = Input(shape=(latent_dim,))
print(z.shape)
img = generator(z)

# Freeze the discriminator before wiring it into the combined model, so that
# training the combined model is meant to update only the generator.
# NOTE(review): the separately compiled discriminator still learning from its
# own train_on_batch calls relies on Keras capturing `trainable` at compile
# time (tf.keras 2.x behaviour) - confirm on the installed Keras version.
discriminator.trainable = False

# Discriminator's score for the generated image.
# Label convention used by the training code: 1 = real, 0 = fake.
fake_pred = discriminator(img)

# Combined model: noise -> generator -> (frozen) discriminator -> score
gan = Model(z, fake_pred)

# Compile the combined model used for the generator updates
gan.compile(loss='binary_crossentropy',
            optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

# Training configuration

batch_size = 32
epochs = 30000  # number of training iterations (one mini-batch each)
sample_period = 200  # a grid of generated images is saved every `sample_period` iterations

# Batch labels: ones are the "real" targets, zeros the "fake" targets
ones = np.ones(batch_size)
zeros = np.zeros(batch_size)

d_losses = []  # discriminator loss per iteration
g_losses = []  # generator loss per iteration

# Folder for the sampled image grids. exist_ok=True makes re-running the cell
# safe and avoids the check-then-create race of a separate os.path.exists test.
os.makedirs('gan_images', exist_ok=True)


def sample_images(epoch):
    """Save a 5x5 grid of generator samples to ``gan_images/<epoch>.png``.

    Reads the module-level ``generator``, ``latent_dim``, ``H`` and ``W``.

    Args:
        epoch: Training iteration number; used only to name the output file.
    """
    rows, cols = 5, 5
    # One latent vector per grid cell
    noise = np.random.randn(rows * cols, latent_dim)
    # verbose=0 stops Keras printing a progress bar on every call
    imgs = generator.predict(noise, verbose=0)

    # Map the generator's tanh range [-1, 1] to [0, 1] for display
    imgs = 0.5 * imgs + 0.5

    fig, axs = plt.subplots(rows, cols)
    # axs.flat walks the grid row by row, matching the order of the samples
    for ax, img in zip(axs.flat, imgs):
        ax.imshow(img.reshape(H, W), cmap='gray')
        ax.axis('off')
    fig.savefig("gan_images/%d.png" % epoch)
    # Close this specific figure rather than whichever one is "current"
    plt.close(fig)


for epoch in range(epochs):
    # --- Train the discriminator ---
    # Random mini-batch of real images
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_imgs = x_train[idx]

    # Matching mini-batch of generated images.
    # verbose=0 stops Keras printing a progress bar on every iteration.
    noise = np.random.randn(batch_size, latent_dim)
    fake_imgs = generator.predict(noise, verbose=0)

    # Real images are labelled 1, generated images 0
    d_loss_real, d_acc_real = discriminator.train_on_batch(real_imgs, ones)
    d_loss_fake, d_acc_fake = discriminator.train_on_batch(fake_imgs, zeros)
    d_loss = 0.5 * (d_loss_real + d_loss_fake)
    d_acc = 0.5 * (d_acc_real + d_acc_fake)

    # --- Train the generator ---
    # Two generator updates per discriminator update, each on fresh noise and
    # targeting the "real" label. Only the loss of the second update is kept.
    for _ in range(2):
        noise = np.random.randn(batch_size, latent_dim)
        g_loss = gan.train_on_batch(noise, ones)

    # Save the losses
    d_losses.append(d_loss)
    g_losses.append(g_loss)

    if epoch % 100 == 0:
        # Adjacent f-strings instead of a backslash continuation, which leaked
        # the source indentation into the printed line.
        print(f"epoch: {epoch+1}/{epochs}, d_loss: {d_loss:.2f}, "
              f"d_acc: {d_acc:.2f}, g_loss: {g_loss:.2f}")

    if epoch % sample_period == 0:
        sample_images(epoch)


(None, 100)
epoch: 1/30000, d_loss: 0.72,       d_acc: 0.36, g_loss: 0.69
epoch: 101/30000, d_loss: 0.00,       d_acc: 1.00, g_loss: 0.00
epoch: 201/30000, d_loss: 0.04,       d_acc: 1.00, g_loss: 0.00
epoch: 301/30000, d_loss: 0.80,       d_acc: 0.47, g_loss: 0.92
epoch: 401/30000, d_loss: 0.75,       d_acc: 0.34, g_loss: 0.65
epoch: 501/30000, d_loss: 0.71,       d_acc: 0.45, g_loss: 0.72
epoch: 601/30000, d_loss: 0.70,       d_acc: 0.50, g_loss: 0.72
epoch: 701/30000, d_loss: 0.66,       d_acc: 0.66, g_loss: 0.74
epoch: 801/30000, d_loss: 0.70,       d_acc: 0.55, g_loss: 0.70
epoch: 901/30000, d_loss: 0.67,       d_acc: 0.58, g_loss: 0.73
epoch: 1001/30000, d_loss: 0.69,       d_acc: 0.55, g_loss: 0.74
epoch: 1101/30000, d_loss: 0.69,       d_acc: 0.55, g_loss: 0.75
epoch: 1201/30000, d_loss: 0.66,       d_acc: 0.62, g_loss: 0.79
epoch: 1301/30000, d_loss: 0.64,       d_acc: 0.75, g_loss: 0.79
epoch: 1401/30000, d_loss: 0.69,       d_acc: 0.52, g_loss: 0.74
epoch: 1501/30000, d_loss

In [None]:

# Inspect the dimensions of the training array.
# NOTE(review): the saved output below, (60000, 784), does not match the
# image-shaped array built in the first cell (its shape is unpacked there as
# N, H, W) - it looks stale; re-run after a kernel restart to confirm.
x_train.shape

(60000, 784)