In [7]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

In [8]:
# Load the MNIST training images; the labels and the test split are discarded.
mnist = tf.keras.datasets.mnist
(train_images, _), (_, _) = mnist.load_data()
# Flatten every image to a single row vector and scale the pixel values
# from [0, 255] down to [0, 1].
train_images = train_images.reshape(len(train_images), -1).astype(np.float32) / 255

In [9]:
# Xavier Initialization
def xavier_init(shape, dtype=tf.dtypes.float32):
    """Sample an initial weight tensor from a zero-mean normal distribution.

    The standard deviation is ``1 / sqrt(in_dim / 2)``, where ``in_dim`` is
    ``shape[0]``. Only the first dimension of ``shape`` influences the scale.

    Args:
        shape: Shape of the tensor to create; ``shape[0]`` sets the scale.
        dtype: Floating-point dtype of the returned tensor.

    Returns:
        A randomly initialised tensor with the requested shape and dtype.
    """
    in_dim = shape[0]
    # Compute the scale as a plain Python float instead of a float32 tensor:
    # tf.random.normal converts `stddev` to `dtype`, and a float32 tensor
    # cannot be converted when a different dtype (e.g. float64) is requested.
    xavier_stddev = float(1. / np.sqrt(in_dim / 2.))
    return tf.random.normal(shape=shape, stddev=xavier_stddev, dtype=dtype)



# Generator
def generator(z):
    """Map a batch of noise vectors to generated samples.

    Applies a ReLU hidden layer (G_W1, G_b1) followed by a sigmoid output
    layer (G_W2, G_b2), so every output value lies in (0, 1).

    Args:
        z: Batch of noise vectors whose last dimension matches G_W1's rows.

    Returns:
        Tensor of sigmoid activations, one row per input noise vector.
    """
    hidden = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    logits = tf.matmul(hidden, G_W2) + G_b2
    return tf.nn.sigmoid(logits)

# Discriminator
def discriminator(x):
    """Score a batch of samples with the discriminator network.

    Applies a ReLU hidden layer (D_W1, D_b1) followed by a linear output
    layer (D_W2, D_b2).

    Args:
        x: Batch of samples whose last dimension matches D_W1's rows.

    Returns:
        Tuple ``(probability, logit)``: the sigmoid of the output layer and
        the raw pre-sigmoid output.
    """
    hidden = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    logit = tf.matmul(hidden, D_W2) + D_b2
    return tf.nn.sigmoid(logit), logit

# Discriminator loss - minimax loss
def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Computes ``-mean(log(D(x)) + log(1 - D(G(z))))``; a small epsilon is
    added inside each logarithm to keep it finite.

    Args:
        real_output: Discriminator probabilities for real samples.
        fake_output: Discriminator probabilities for generated samples.

    Returns:
        Scalar loss tensor.
    """
    eps = 1e-8
    log_real = tf.math.log(real_output + eps)
    log_not_fake = tf.math.log(1. - fake_output + eps)
    return -tf.reduce_mean(log_real + log_not_fake)

# Generator loss - non-saturating form of the minimax loss (maximizes log D(G(z)))
def generator_loss(fake_output):
    """Return the generator loss for one batch.

    Computes ``-mean(log(D(G(z))))``; a small epsilon is added inside the
    logarithm to keep it finite.

    Args:
        fake_output: Discriminator probabilities for generated samples.

    Returns:
        Scalar loss tensor.
    """
    log_fake = tf.math.log(fake_output + 1e-8)
    return -tf.reduce_mean(log_fake)

# Optimizers
# One Adam optimizer per network, both with the same learning rate.
D_optimizer = tf.optimizers.Adam(learning_rate=1e-4)
G_optimizer = tf.optimizers.Adam(learning_rate=1e-4)

# Training step
@tf.function
def train_step(images):
    """Run one simultaneous update of the discriminator and the generator.

    NOTE: ``train_step`` is defined a second time further down in this cell,
    and that later definition is the one in effect at run time. This copy is
    kept consistent with it; one of the two should be removed.

    Args:
        images: Batch of real samples fed to the discriminator.

    Returns:
        Tuple ``(D_loss, G_loss)`` of scalar loss tensors for this step.
    """
    # Draw the noise with a TensorFlow op. A NumPy call such as sample_Z()
    # would execute only once, while tf.function traces this Python body, and
    # its result would be baked into the graph as a constant, so every
    # training step would reuse the same noise batch.
    noise = tf.random.uniform([mb_size, Z_dim], minval=-1., maxval=1.)

    # Two tapes, because each network is differentiated against its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        G_sample = generator(noise)
        D_real, _ = discriminator(images)
        D_fake, _ = discriminator(G_sample)

        D_loss = discriminator_loss(D_real, D_fake)
        G_loss = generator_loss(D_fake)

    gradients_of_discriminator = disc_tape.gradient(D_loss, theta_D)
    gradients_of_generator = gen_tape.gradient(G_loss, theta_G)

    D_optimizer.apply_gradients(zip(gradients_of_discriminator, theta_D))
    G_optimizer.apply_gradients(zip(gradients_of_generator, theta_G))
    return D_loss, G_loss

def sample_Z(m, n):
    """Draw an ``m`` x ``n`` array of noise, uniform on [-1, 1), as float32.

    Args:
        m: Number of rows (noise vectors).
        n: Number of columns (entries per noise vector).

    Returns:
        A NumPy float32 array of shape ``(m, n)``.
    """
    noise = np.random.uniform(low=-1., high=1., size=(m, n))
    return noise.astype(np.float32)

# Plotting function
def plot(samples):
    """Draw samples as 28x28 greyscale images on a 4x4 grid.

    Args:
        samples: Iterable of arrays, each reshapeable to ``(28, 28)``. The
            grid has 16 cells, so at most 16 samples can be drawn.

    Returns:
        The matplotlib figure holding the grid.
    """
    fig = plt.figure(figsize=(4, 4))
    grid = gridspec.GridSpec(4, 4)
    grid.update(wspace=0.05, hspace=0.05)

    for idx, image in enumerate(samples):
        ax = plt.subplot(grid[idx])
        # Hide the axis frame and tick labels so only the image is visible.
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(image.reshape(28, 28), cmap='Greys_r')

    return fig

# Training step
@tf.function
def train_step(images):
    """Run one simultaneous update of the discriminator and the generator.

    Args:
        images: Batch of real samples fed to the discriminator.

    Returns:
        Tuple ``(D_loss, G_loss)`` of scalar loss tensors for this step.
    """
    # Draw the noise with a TensorFlow op. A NumPy call such as sample_Z()
    # would execute only once, while tf.function traces this Python body, and
    # its result would be baked into the graph as a constant, so every
    # training step would reuse the same noise batch.
    noise = tf.random.uniform([mb_size, Z_dim], minval=-1., maxval=1.)

    # Two tapes, because each network is differentiated against its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        G_sample = generator(noise)
        D_real, _ = discriminator(images)
        D_fake, _ = discriminator(G_sample)

        D_loss = discriminator_loss(D_real, D_fake)
        G_loss = generator_loss(D_fake)

    gradients_of_discriminator = disc_tape.gradient(D_loss, theta_D)
    gradients_of_generator = gen_tape.gradient(G_loss, theta_G)

    D_optimizer.apply_gradients(zip(gradients_of_discriminator, theta_D))
    G_optimizer.apply_gradients(zip(gradients_of_generator, theta_G))
    return D_loss, G_loss

In [10]:
# Model Parameters
mb_size = 128  # minibatch size
Z_dim = 100    # length of each generator noise vector
X_dim = 784    # length of a flattened 28x28 image
h_dim = 128    # width of the hidden layer in both networks

# Discriminator Weights and Biases
D_W1 = tf.Variable(xavier_init([X_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
D_W2 = tf.Variable(xavier_init([h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))
theta_D = [D_W1, D_b1, D_W2, D_b2]

# Generator Weights and Biases
# The input width is taken from Z_dim (not a literal 100) so that changing
# the noise dimension keeps the generator consistent with the sampled noise.
G_W1 = tf.Variable(xavier_init([Z_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
theta_G = [G_W1, G_b1, G_W2, G_b2]

In [11]:
# Training loop
# exist_ok=True makes directory creation idempotent without a separate
# (race-prone) existence check.
os.makedirs('out/', exist_ok=True)

i = 0  # index used in the file name of the next saved sample grid
for it in range(30000):
    # Pick a minibatch of real images uniformly at random, with replacement.
    batch_index = np.random.randint(0, train_images.shape[0], mb_size)
    X_mb = train_images[batch_index]

    D_loss_curr, G_loss_curr = train_step(X_mb)  # Capture the loss values here

    # Every 1000 iterations: save a grid of 16 generated samples and log losses.
    if it % 1000 == 0:
        G_sample = generator(sample_Z(16, Z_dim))
        samples = G_sample.numpy()
        fig = plot(samples)
        plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)  # release the figure so figures do not pile up in memory

        print('Iter-{}; D_loss: {}; G_loss: {}'.format(it, D_loss_curr.numpy(), G_loss_curr.numpy()))


Iter-0; D_loss: 1.3915553092956543; G_loss: 1.5876109600067139
Iter-1000; D_loss: 0.36768266558647156; G_loss: 2.3177099227905273
Iter-2000; D_loss: 0.3169485926628113; G_loss: 1.947006344795227
Iter-3000; D_loss: 0.4201398193836212; G_loss: 1.4058104753494263
Iter-4000; D_loss: 0.406434029340744; G_loss: 1.597029447555542
Iter-5000; D_loss: 0.4634042978286743; G_loss: 1.9123272895812988
Iter-6000; D_loss: 0.5394510626792908; G_loss: 1.6409804821014404
Iter-7000; D_loss: 0.46160179376602173; G_loss: 1.8026102781295776
Iter-8000; D_loss: 0.5959134101867676; G_loss: 1.6496095657348633
Iter-9000; D_loss: 0.7266929149627686; G_loss: 1.5127217769622803
Iter-10000; D_loss: 0.6457779407501221; G_loss: 1.572351336479187
Iter-11000; D_loss: 0.6271353960037231; G_loss: 1.529844880104065
Iter-12000; D_loss: 0.6523762941360474; G_loss: 1.5388013124465942
Iter-13000; D_loss: 0.6657979488372803; G_loss: 1.5472614765167236
Iter-14000; D_loss: 0.6550084352493286; G_loss: 1.4824557304382324
Iter-15000;