In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Geometry of the MNIST digits the GAN is trained on: 28x28 pixels, 1 grayscale channel.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)

# Length of the latent noise vector fed to the generator.
z_dim = 100

In [9]:
class generator(tf.keras.Model):
    """Fully-connected GAN generator mapping latent noise vectors to images.

    Architecture: Dense(128) -> LeakyReLU(0.01) -> Dense(prod(image_shape))
    -> tanh -> Reshape(image_shape). The tanh keeps pixel values in [-1, 1],
    the same range the training images are scaled to in ``train``.

    Args:
        z_dim: Length of the input noise vector. Stored as ``self.z_dim``; the
            first Dense layer infers its input size when it is first called.
        image_shape: Shape of one generated image, excluding the batch axis.
            Defaults to the notebook-level ``img_shape`` when omitted.
    """

    def __init__(self, z_dim, image_shape=None):
        super().__init__()
        if image_shape is None:
            image_shape = img_shape  # notebook-level config constant
        shape = tuple(image_shape)
        self.z_dim = z_dim
        self.fc1 = tf.keras.layers.Dense(128)
        # One output unit per pixel, then reshaped back into an image.
        self.fc2 = tf.keras.layers.Dense(int(np.prod(shape)))
        self.reshape = tf.keras.layers.Reshape(shape)

    def call(self, input_tensor):
        """Map a ``(batch, z_dim)`` noise batch to ``(batch, *image_shape)`` images."""
        x = self.fc1(input_tensor)
        x = tf.nn.leaky_relu(x, alpha=0.01)
        x = self.fc2(x)
        x = tf.nn.tanh(x)
        # Restore the image layout so downstream code can index [n, row, col, channel].
        x = self.reshape(x)
        return x

In [10]:
class discriminator(tf.keras.Model):
    """Fully-connected GAN discriminator scoring images as real or fake.

    Flattens each image, applies Dense(128) -> LeakyReLU(0.01) -> Dense(1) ->
    sigmoid, and returns one probability per example (shape ``(batch, 1)``)
    that the input is a real image (real images are labelled 1 in the loss).

    Args:
        img_shape: Accepted for symmetry with ``generator``; not used, because
            ``Flatten`` adapts to whatever image shape it receives.
    """

    def __init__(self, img_shape):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(128)
        self.fc2 = tf.keras.layers.Dense(1)

    def call(self, input_tensor):
        """Return sigmoid probabilities of shape ``(batch, 1)``."""
        hidden = tf.nn.leaky_relu(self.fc1(self.flatten(input_tensor)), alpha=0.01)
        return tf.nn.sigmoid(self.fc2(hidden))

In [11]:
# Instantiate the two adversaries; the generator's noise size comes from the config cell.
g = generator(z_dim)
d = discriminator(img_shape)

In [12]:
# Separate Adam instances (Keras default hyperparameters) so the generator and
# the discriminator each keep their own optimizer state.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(),
    tf.keras.optimizers.Adam(),
)

def generator_loss(fake_ouput):
    """Non-saturating generator loss: BCE of the discriminator's fake scores vs. label 1.

    Args:
        fake_ouput: Discriminator probabilities for generated images, e.g. shape
            ``(batch, 1)``. The parameter keeps its original spelling so any
            keyword callers continue to work.

    Returns:
        Scalar tensor; lower when the discriminator scores fakes closer to 1.
    """
    # Labels are shaped like the predictions themselves (not like their .shape
    # vector) so each prediction is paired with its own target of 1.
    targets = tf.ones_like(fake_ouput)
    return tf.keras.losses.BinaryCrossentropy()(targets, fake_ouput)

def discriminator_loss(real_output, fake_output):
    """Discriminator BCE loss: real images labelled 1, generated images labelled 0.

    Args:
        real_output: Discriminator probabilities for real images.
        fake_output: Discriminator probabilities for generated images.

    Returns:
        Scalar tensor: the average of the real-image and fake-image BCE terms.
    """
    bce = tf.keras.losses.BinaryCrossentropy()
    real_loss = bce(tf.ones_like(real_output), real_output)
    fake_loss = bce(tf.zeros_like(fake_output), fake_output)
    return 0.5 * (real_loss + fake_loss)

In [13]:
def sample_images(generator, image_grid_rows=4, image_grid_cols=4):
    """Draw a grid of images from ``generator`` and display it with matplotlib.

    Args:
        generator: Callable mapping an ``(n, z_dim)`` noise batch to images with
            pixel values in [-1, 1]. Both image-shaped ``(n, *img_shape)`` and
            flat ``(n, prod(img_shape))`` outputs are accepted.
        image_grid_rows: Number of grid rows.
        image_grid_cols: Number of grid columns.
    """
    n_images = image_grid_rows * image_grid_cols
    z = np.random.normal(0, 1, (n_images, z_dim))
    # np.asarray converts an eager tf.Tensor to NumPy; the reshape makes flat
    # generator outputs displayable too.
    gen_imgs = np.reshape(np.asarray(generator(z)), (n_images, *img_shape))
    # Map the tanh range [-1, 1] to [0, 1] for imshow.
    gen_imgs = 0.5 * gen_imgs + 0.5

    # squeeze=False keeps axs 2-D even for a 1xN or Nx1 grid.
    fig, axs = plt.subplots(image_grid_rows, image_grid_cols, figsize=(4, 4), squeeze=False)
    # axs.flat walks the grid row by row, matching the sample order.
    for cnt, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        ax.axis('off')
    # Render now instead of accumulating open figures until the calling cell ends.
    plt.show()

In [14]:
def train(iterations, batch_size, sample_interval):
    """Train the notebook-level GAN (``g`` and ``d``) on the MNIST training split.

    Each iteration draws a random batch of real images and a batch of latent
    noise, and runs both networks once. It then applies one optimizer step to
    each network from that shared forward pass. Every ``sample_interval``
    iterations it prints the mean losses over all iterations so far and shows
    a sample grid via ``sample_images``.

    Args:
        iterations: Number of training steps.
        batch_size: Number of real (and generated) images per step.
        sample_interval: Print/sample every this many iterations; must be > 0.

    Raises:
        ValueError: If ``sample_interval`` is not positive.
    """
    if sample_interval <= 0:
        raise ValueError(f'sample_interval must be positive, got {sample_interval}')

    (x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
    # Scale uint8 pixels [0, 255] to [-1, 1] to match the generator's tanh output.
    # float32 matches the models' dtype and halves memory compared with float64.
    x_train = x_train.astype(np.float32) / 127.5 - 1.0
    x_train = x_train[..., np.newaxis]  # add a trailing channel axis

    # Accumulated across all iterations so the printed values are running means.
    total_g_loss = 0.0
    total_d_loss = 0.0

    for iteration in range(1, iterations + 1):
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        imgs = x_train[idx]
        z = np.random.normal(0, 1, (batch_size, z_dim)).astype(np.float32)

        # Two tapes: each network's gradient is taken only w.r.t. its own loss.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_img = g(z)
            real_output = d(imgs)
            fake_output = d(generated_img)

            gen_loss = generator_loss(fake_output)
            disc_loss = discriminator_loss(real_output, fake_output)

        gen_gradient = gen_tape.gradient(gen_loss, g.trainable_variables)
        disc_gradient = disc_tape.gradient(disc_loss, d.trainable_variables)

        generator_optimizer.apply_gradients(zip(gen_gradient, g.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(disc_gradient, d.trainable_variables))

        # float() extracts the scalar so no tensors are kept alive in the running totals.
        total_g_loss += float(gen_loss)
        total_d_loss += float(disc_loss)

        if iteration % sample_interval == 0:
            print(f'iteration : {iteration} G loss = {total_g_loss/iteration:.4f} D loss = {total_d_loss/iteration:.4f}')
            sample_images(g)

# Full run: 20k steps with batches of 256, printing losses and sampling every 1k steps.
train(iterations=20000, batch_size=256, sample_interval=1000)

In [None]:
# Quick check: run the generator on a single fresh noise vector.
z = np.random.normal(0, 1, (1, z_dim))
g(z)