In [61]:
from __future__ import absolute_import, division, print_function
import tensorflow as tf
tf.enable_eager_execution()

import os
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
import time


In [62]:
from IPython.display import display, Image


### Load the data

In [10]:
# Load only the MNIST training split; the labels are unused because this GAN is unconditional.
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
# Add a trailing channel axis -> (N, 28, 28, 1), cast to float32, and map pixel
# values from [0, 255] to [-1, 1] so they match the generator's tanh output range.
train_images = (
    train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') - 127.5
) / 127.5

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [63]:
BUFFER_SIZE = 60000  # shuffle buffer as large as the MNIST training split
BATCH_SIZE = 256
# Shuffle the real images, then group them into batches. No drop_remainder is
# used, so the final batch of each epoch can be smaller than BATCH_SIZE.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)


In [94]:
class Generator(tf.keras.Model):
    """DCGAN generator: maps a batch of noise vectors to 28x28x1 images in [-1, 1].

    Path: Dense -> reshape to 7x7x64 -> three 5x5 transposed convolutions with
    strides 1, 2, 2 (7x7 -> 7x7 -> 14x14 -> 28x28). Every hidden stage is
    followed by batch norm and ReLU; the output goes through tanh.
    """

    def __init__(self):
        super(Generator, self).__init__()
        # NOTE(review): this bias is redundant because batchnorm1 immediately
        # re-centres the activations; left as-is so existing checkpoints keep
        # the same set of variables.
        self.fc1 = tf.keras.layers.Dense(7 * 7 * 64, use_bias=True)
        self.batchnorm1 = tf.keras.layers.BatchNormalization()
        self.conv1 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(1, 1), padding='same', use_bias=False)
        self.batchnorm2 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False)
        self.batchnorm3 = tf.keras.layers.BatchNormalization()
        self.conv3 = tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False)

    def call(self, x, training=True):
        """Run the generator forward.

        Args:
            x: batch of noise vectors (presumably shape (batch, noise_dim) — TODO confirm).
            training: forwarded to every batch-norm layer.

        Returns:
            Tensor of shape (batch, 28, 28, 1) squashed into [-1, 1] by tanh.
        """
        h = tf.nn.relu(self.batchnorm1(self.fc1(x), training=training))
        h = tf.reshape(h, shape=(-1, 7, 7, 64))
        # Each hidden upsampling stage is deconv -> batch norm -> ReLU.
        for deconv, norm in ((self.conv1, self.batchnorm2), (self.conv2, self.batchnorm3)):
            h = tf.nn.relu(norm(deconv(h), training=training))
        return tf.nn.tanh(self.conv3(h))


In [95]:
class Discriminator(tf.keras.Model):
    """DCGAN discriminator: scores each image in a batch with one raw logit.

    Two strided 5x5 convolutions (64 then 128 filters), each followed by leaky
    ReLU and dropout, then flatten and a single-unit dense layer. No sigmoid is
    applied here; the losses consume logits directly.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')
        self.conv2 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')
        # One Dropout layer is reused after both convolutions.
        self.dropout = tf.keras.layers.Dropout(0.3)
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(1)

    def call(self, x, training=True):
        """Return one logit per input image.

        Args:
            x: image batch (presumably (batch, 28, 28, 1) scaled to [-1, 1] — TODO confirm).
            training: enables dropout when True.
        """
        h = x
        for conv in (self.conv1, self.conv2):
            h = self.dropout(tf.nn.leaky_relu(conv(h)), training=training)
        return self.fc1(self.flatten(h))

In [96]:
# Instantiate both networks (generator first, then discriminator).
generator, discriminator = Generator(), Discriminator()

In [97]:
# Wrap each model's forward pass with defun so it is traced into a graph
# function instead of being executed op-by-op in eager mode.
for _model in (generator, discriminator):
    _model.call = tf.contrib.eager.defun(_model.call)
del _model

In [114]:
def discriminator_loss(real_output, generated_output):
    """Sigmoid cross-entropy loss for the discriminator.

    Real images are labelled 1 and generated images 0.

    Args:
        real_output: discriminator logits for real images.
        generated_output: discriminator logits for generated images.

    Returns:
        The sum of the two cross-entropy terms, each reduced with the default
        reduction of ``tf.losses.sigmoid_cross_entropy``.
    """
    real_term = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.ones_like(real_output),
        logits=real_output,
    )
    fake_term = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.zeros_like(generated_output),
        logits=generated_output,
    )
    return real_term + fake_term

def generator_loss(generated_output):
    """Non-saturating generator loss.

    The generator is rewarded when the discriminator labels its samples as
    real, so the target is all ones.

    Args:
        generated_output: discriminator logits for generated images.

    Returns:
        Sigmoid cross-entropy of those logits against an all-ones target.
    """
    all_real = tf.ones_like(generated_output)
    return tf.losses.sigmoid_cross_entropy(multi_class_labels=all_real, logits=generated_output)

# Separate Adam optimizers so each network keeps its own moment estimates.
discriminator_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
generator_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)


In [115]:
# Directory for training checkpoints. The spelling was corrected from
# 'training_checkpints'; checkpoints written under the old name will not be found.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
# Track both models and both optimizers so a saved training state can be restored.
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [116]:
# Training configuration.
EPOCHS = 150
noise_dim = 10  # length of the latent vector fed to the generator
num_examples_to_generate = 16  # generate_and_save_images plots these on a 4x4 grid
# Fixed latent vectors reused for every snapshot, so images from different
# epochs are directly comparable.
random_vector_for_generation = tf.random_normal(shape=[num_examples_to_generate, noise_dim])

In [117]:
def generate_and_save_images(model, epoch, test_input):
    """Render a 4x4 grid of generator samples, save it as a PNG, and display it.

    Args:
        model: generator callable that accepts ``training=False``.
        epoch: epoch number embedded in the filename
            (``image_at_epoch_{epoch:04d}.png``).
        test_input: batch of latent vectors; at most 16 rows fit on the grid.
    """
    # Inference mode (for the Generator, batch norm uses its moving statistics).
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        # Undo the [-1, 1] scaling applied to the training images.
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    # Explicit extension; other cells glob/open these files by name.
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # Free the figure: this is called once per epoch during training.
    plt.close(fig)
        

In [122]:
def _train_step(images, noise_dim):
    """Run one simultaneous generator/discriminator update on a batch of real images."""
    # Size the noise batch from the real batch: the dataset's final batch can be
    # smaller than BATCH_SIZE because it is batched without drop_remainder.
    noise = tf.random_normal([int(images.shape[0]), noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        generated_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(generated_output)
        disc_loss = discriminator_loss(real_output, generated_output)

    # trainable_variables excludes batch-norm moving statistics, which get no gradient.
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    # Each optimizer must update its own network with its own gradients.
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables))


def train(dataset, epochs, noise_dim, image_every=1, checkpoint_every=15):
    """Adversarially train the module-level ``generator`` and ``discriminator``.

    Also uses the module-level optimizers, ``checkpoint``, ``checkpoint_prefix``
    and ``random_vector_for_generation``.

    Args:
        dataset: iterable of real image batches (e.g. ``train_dataset``).
        epochs: number of passes over ``dataset``.
        noise_dim: length of the latent vector fed to the generator.
        image_every: save/display a sample grid every this many epochs.
        checkpoint_every: write a checkpoint every this many epochs.
    """
    # `display` imported at the top of the notebook is a function, not the
    # IPython.display module, so import clear_output directly.
    from IPython.display import clear_output

    for epoch in range(epochs):
        start = time.time()

        for images in dataset:
            _train_step(images, noise_dim)

        if (epoch + 1) % image_every == 0:
            clear_output(wait=True)
            generate_and_save_images(generator, epoch + 1, random_vector_for_generation)

        if (epoch + 1) % checkpoint_every == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

    # Final snapshot after the last epoch.
    clear_output(wait=True)
    generate_and_save_images(generator, epochs, random_vector_for_generation)
       

In [None]:
# Full training run; sample grids and checkpoints are written as it progresses.
train(train_dataset, epochs=EPOCHS, noise_dim=noise_dim)

In [None]:
# Restore the most recent checkpoint written during training
# (the Checkpoint object is named `checkpoint`, not `checkpoints`).
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
def display_image(epoch_no):
    """Open the sample grid saved by generate_and_save_images for one epoch.

    Args:
        epoch_no: epoch number embedded in the filename.

    Returns:
        The opened PIL image; Jupyter renders it when it is a cell's last expression.
    """
    # `import PIL` alone does not guarantee the PIL.Image submodule is loaded.
    import PIL.Image
    # Must match the name used by generate_and_save_images ('image_at_epoch_', underscore).
    return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))


display_image(EPOCHS)

In [None]:
import shutil

# Stitch the per-epoch sample grids into an animated GIF. A frame is kept only
# when round(2*sqrt(i)) increases, so later epochs are sampled more sparsely.
filenames = sorted(glob.glob('image_at_epoch_*.png'))
with imageio.get_writer('dcgan.gif', mode='I') as writer:
    last = -1
    for i, filename in enumerate(filenames):
        frame = 2 * (i ** 0.5)
        if round(frame) > round(last):
            last = frame
        else:
            continue
        writer.append_data(imageio.imread(filename))
    # Repeat the final snapshot so the animation ends on the latest result.
    # Guarded so an empty glob does not leave `filename` unbound.
    if filenames:
        writer.append_data(imageio.imread(filenames[-1]))

# Hack to display the GIF inline: show a copy with a .png extension
# (presumably because IPython's Image infers the format from the extension).
shutil.copyfile('dcgan.gif', 'dcgan.gif.png')
Image(filename='dcgan.gif.png')
