In [46]:
import numpy as np
import pandas as pd
import tensorflow as tf
import cv2
import os
import glob
from tensorflow.layers import batch_normalization
from tensorflow.keras.layers import UpSampling2D
from sklearn.preprocessing import StandardScaler 
import random

In [37]:

class helper_functions:
    """TF1 graph-building helpers shared by the Discriminator and the Generator.

    Weights and biases are created with ``tf.get_variable`` inside ``AUTO_REUSE``
    scopes, so asking twice for the same ``name`` returns the same variable.
    """

    def get_weights(self, shape, name, stddev = 0.02):
        """Return (creating on first use) the weight variable ``weights/<name>``.

        Args:
            shape: variable shape.
            name: variable name inside the ``weights`` scope.
            stddev: standard deviation of the normal initializer. Defaults to 0.02
                (the DCGAN initialization); TF's own default of 1.0 makes wide
                layers produce very large activations/logits.
        """
        with tf.variable_scope('weights', reuse = tf.AUTO_REUSE):
            wt_init = tf.random_normal_initializer(stddev = stddev)
            return tf.get_variable(name = name, shape = shape, initializer = wt_init)


    def get_bias(self, shape, name):
        """Return (creating on first use) the zero-initialized bias variable ``biases/<name>``."""
        with tf.variable_scope('biases', reuse = tf.AUTO_REUSE):
            init = tf.constant_initializer(0)
            return tf.get_variable(name = name, shape = shape, initializer = init)


    def conv_layer(self, data, weights, bias, name, strides, batch_normalize = False, discriminator = True,
                   training = True):
        """SAME-padded conv2d -> (bias if discriminator) -> (batch norm) -> leaky ReLU.

        Args:
            data: input feature map (NHWC).
            weights: conv kernel variable.
            bias: bias variable; only used when ``discriminator`` is True.
            name: variable scope for this layer (opened with AUTO_REUSE).
            strides: conv2d strides.
            batch_normalize: apply batch normalization before the activation.
            discriminator: add ``bias`` to the conv output.
            training: passed to batch normalization. True normalizes with the
                current batch statistics. Note that the moving-average update ops
                (``tf.GraphKeys.UPDATE_OPS``) must be run by the caller if
                inference-mode statistics are ever needed.
        Returns:
            The activated feature map.
        """
        with tf.variable_scope(name, reuse = tf.AUTO_REUSE):
            conv_res = tf.nn.conv2d(data, filter = weights, strides = strides, padding = 'SAME')

            # we add bias to the conv_res only if it is discriminator
            # (generator blocks rely on the batch-norm offset instead)
            if discriminator:
                conv_res = tf.nn.bias_add(conv_res, bias)

            if batch_normalize:
                # An explicit name makes a second call with the same scope `name`
                # (e.g. the discriminator applied to real, then to generated images)
                # reuse the same BN variables under AUTO_REUSE instead of creating
                # an auto-uniquified "batch_normalization_1" set.
                # training=True: the tf.layers default (False) would never use batch
                # statistics and only apply the constant initial moving mean/var.
                conv_res = batch_normalization(conv_res, momentum = 0.5, training = training,
                                               name = 'batch_normalization')

            return tf.nn.leaky_relu(conv_res)


    def conv_block(self, data, weights, bias, block_name, strides, first_or_last_block = False, disc = True,
                   training = True):
        """Conv layer with batch norm everywhere except the edge block.

        first block -> discriminator || last_block -> generator: those blocks are
        built without batch normalization (``first_or_last_block=True``).
        ``training`` is forwarded to ``conv_layer``.
        """
        batch_normalize = not first_or_last_block
        return self.conv_layer(data, weights, bias, block_name, strides, batch_normalize, disc, training)


    def get_flatten_data_shape(self, data_shape):
        """Product of all known (non-None) dimensions of ``data_shape``; 1 for an empty shape."""
        total_neurons = 1
        for val in data_shape:
            if val is not None:
                total_neurons = total_neurons * val
        return total_neurons


    def fully_connected_layer(self, data, shape_flatten, weights, bias, block_name):
        """Flatten ``data`` to [-1, shape_flatten] and return ``data @ weights + bias`` (logits).

        ``block_name`` is accepted for interface symmetry but currently unused.
        """
        data = tf.reshape(data, [-1, shape_flatten])
        data = tf.matmul(data, weights)
        return tf.nn.bias_add(data, bias)
    


In [38]:


class Discriminator:
    """Convolutional DCGAN discriminator producing one logit per input image.

    Five 3x3 conv blocks feed a single-unit fully-connected layer:
      * block 1: stride 1, bias, no batch norm;
      * blocks 2-5: ``conv_strides`` (stride 2, SAME padding), bias + batch norm.
    For a 64x64 input the spatial size goes 64 -> 32 -> 16 -> 8 -> 4, so the last
    feature map is [-1, 4, 4, 512].
    """
    def __init__(self, image_shape, helper_functions):
        # image_shape is indexed as (width, height, channels)
        self.image_width = image_shape[0]
        self.image_height = image_shape[1]
        self.image_channels = image_shape[2]
        # output channels: blocks 1-2 use no_filters[0], blocks 3-5 use no_filters[1..3]
        self.no_filters = [64, 128, 256, 512]
        self.conv_strides = [1, 2, 2, 1]
        self.helper = helper_functions


    def d_propagate_forward(self, data):
        """Build the discriminator graph on ``data`` and return logits.

        Variables are requested from ``self.helper`` by fixed names, so calling this
        again (e.g. on generated images) asks for the same variables. Side effects:
        sets ``dwts_1..5``, ``dbias_1..5``, ``dfc_wts``, ``dfc_bias`` and records
        ``conv_5_shape`` / ``shape_flatten``, which the Generator reads.
        """
        features = tf.reshape(data, [-1, self.image_width, self.image_height, self.image_channels])
        nf = self.no_filters

        # (weight name, bias name, scope, in channels, out channels, strides, first block?)
        layer_specs = (
            ('d_w_1', 'd_b_1', 'd_conv_1', self.image_channels, nf[0], [1, 1, 1, 1], True),
            ('d_w_2', 'd_b_2', 'd_conv_2', nf[0], nf[0], self.conv_strides, False),
            ('d_w_3', 'd_b_3', 'd_conv_3', nf[0], nf[1], self.conv_strides, False),
            ('d_w_4', 'd_b_4', 'd_conv_4', nf[1], nf[2], self.conv_strides, False),
            ('d_w_5', 'd_b_5', 'd_conv_5', nf[2], nf[3], self.conv_strides, False),
        )

        for idx, (w_name, b_name, scope, c_in, c_out, strides, is_first) in enumerate(layer_specs, start = 1):
            kernel = self.helper.get_weights([3, 3, c_in, c_out], w_name)
            offset = self.helper.get_bias([c_out], b_name)
            setattr(self, 'dwts_%d' % idx, kernel)
            setattr(self, 'dbias_%d' % idx, offset)
            features = self.helper.conv_block(features, kernel, offset, scope, strides, is_first, True)

        # final feature map: [-1, 4, 4, 512] for 64x64 inputs
        self.conv_5_shape = features.get_shape().as_list()
        self.shape_flatten = self.helper.get_flatten_data_shape(features.get_shape().as_list())

        self.dfc_wts = self.helper.get_weights([self.shape_flatten, 1], 'd_fc_wt')
        self.dfc_bias = self.helper.get_bias([1], 'd_fc_bias')
        return self.helper.fully_connected_layer(features, self.shape_flatten, self.dfc_wts, self.dfc_bias,
                                                 'd_fc_layer')
        

In [53]:

class Generator:
    """DCGAN generator: maps a batch of noise vectors to images in [-1, 1] (tanh).

    The noise is projected onto the shape of the discriminator's last conv feature
    map, so ``Discriminator.d_propagate_forward`` must have been built first.
    """
    def __init__(self, image_shape, noise_size, helper_functions, discriminator):
        self.image_width = image_shape[0]
        self.image_height = image_shape[1]
        self.image_channels = image_shape[2]
        self.noise_size = noise_size
        # The last entry is the number of output channels; it has to match the image
        # channels so generated images can be fed to the discriminator.
        self.no_gfilters = [512, 256, 128, 64, self.image_channels]
        self.ghelper = helper_functions
        self.disc_obj = discriminator
        self.conv_strides = [1, 1, 1, 1]


    def g_propagate_forward(self, data):
        '''
        Build the generator graph on ``data`` (noise, shape [-1, noise_size]).

        Shapes for a 64x64 image (discriminator conv_5 = [-1, 4, 4, 512]):
        projection  = [-1, 4, 4, 512]
        gen_up_1    = [-1, 8, 8, 512]
        gen_conv_2  = [-1, 8, 8, 256]
        gen_up_2    = [-1, 16, 16, 256]
        gen_conv_3  = [-1, 16, 16, 128]
        gen_up_3    = [-1, 32, 32, 128]
        gen_conv_4  = [-1, 32, 32, 64]
        gen_up_4    = [-1, 64, 64, 64]
        gen_conv_5  = [-1, 64, 64, image_channels]

        Returns tanh(gen_conv_5).
        Raises AttributeError if the discriminator graph has not been built yet.
        '''
        conv_5_shape = getattr(self.disc_obj, 'conv_5_shape', None)
        if conv_5_shape is None:
            raise AttributeError('Discriminator.d_propagate_forward must be called before '
                                 'Generator.g_propagate_forward (conv_5_shape is not set)')
        width = conv_5_shape[1]
        height = conv_5_shape[2]
        channels = conv_5_shape[3]

        # Fully-connected projection of the noise onto the flattened conv_5 feature map;
        # the input size is the configured noise size, not a hard-coded 100.
        self.gwts_1 = self.ghelper.get_weights([self.noise_size, self.disc_obj.shape_flatten], 'g_w_1')
        data = tf.nn.relu(tf.matmul(data, self.gwts_1))
        data = tf.reshape(data, [-1, width, height, channels])
        gen_up_1 = UpSampling2D()(data)

        # input channels follow the projected feature map instead of a hard-coded 512
        self.gwts_2 = self.ghelper.get_weights([5, 5, channels, self.no_gfilters[1]], 'g_w_2')
        gen_conv_2 = self.ghelper.conv_block(gen_up_1, self.gwts_2, None, 'g_conv_2', self.conv_strides, False, False)
        gen_up_2 = UpSampling2D()(gen_conv_2)

        self.gwts_3 = self.ghelper.get_weights([5, 5, self.no_gfilters[1], self.no_gfilters[2]], 'g_w_3')
        gen_conv_3 = self.ghelper.conv_block(gen_up_2, self.gwts_3, None, 'g_conv_3', self.conv_strides, False, False)
        gen_up_3 = UpSampling2D()(gen_conv_3)

        self.gwts_4 = self.ghelper.get_weights([5, 5, self.no_gfilters[2], self.no_gfilters[3]], 'g_w_4')
        gen_conv_4 = self.ghelper.conv_block(gen_up_3, self.gwts_4, None, 'g_conv_4', self.conv_strides, False, False)
        gen_up_4 = UpSampling2D()(gen_conv_4)

        # output layer: no batch norm, no leaky ReLU; tanh maps to [-1, 1]
        self.gwts_5 = self.ghelper.get_weights([5, 5, self.no_gfilters[3], self.no_gfilters[4]], 'g_w_5')
        gen_conv_5 = tf.nn.conv2d(gen_up_4, self.gwts_5, [1, 1, 1, 1], padding = 'SAME')
        return tf.nn.tanh(gen_conv_5)
        
        
        

In [48]:

class DCGAN:
    """Wires a Discriminator and a Generator together and trains them on a folder of images.

    Real images are read with OpenCV (BGR channel order), resized to ``image_shape``
    and scaled to [-1, 1] so they live in the same range as the generator's tanh
    output.
    """
    def __init__(self, image_shape, disc_lr, gen_lr, noise_size, batch_size, no_of_epochs,
                                        train_data_path):
        self.image_shape = image_shape
        self.noise_size = noise_size
        self.batch_size = batch_size
        self.no_of_epochs = no_of_epochs
        self.disc_lr = disc_lr
        self.gen_lr = gen_lr
        self.hf = helper_functions()
        self.Discriminator = Discriminator(self.image_shape, self.hf)
        self.Generator = Generator(self.image_shape, self.noise_size, self.hf, self.Discriminator)
        self.disc_X = tf.placeholder(tf.float32, [None, self.image_shape[0], self.image_shape[1], self.image_shape[2]])
        self.gen_X = tf.placeholder(tf.float32, [None, self.noise_size])
        self.train_data_path = train_data_path
        # May include non-image entries (e.g. .DS_Store); get_input_batches skips them.
        self.train_image_names = os.listdir(self.train_data_path + '/')


    def get_input_batches(self, images_batch_names):
        """Load, resize and scale one batch of training images.

        Files OpenCV cannot decode are skipped. Returns a float32 array of shape
        [n_loaded, image_shape[0], image_shape[1], image_shape[2]] with values in
        [-1, 1] (n_loaded may be 0).
        """
        rows, cols, channels = self.image_shape[0], self.image_shape[1], self.image_shape[2]
        read_flag = cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR
        images_arr = []
        for image_name in images_batch_names:
            image = cv2.imread(os.path.join(self.train_data_path, image_name), read_flag)
            if image is None:
                # not an image OpenCV can read (hidden file, corrupt file, ...)
                continue
            # cv2.resize takes dsize as (width=columns, height=rows)
            image = cv2.resize(image, (cols, rows)).astype(np.float32)
            image = image.reshape(rows, cols, channels)
            # [0, 255] -> [-1, 1], matching the generator's tanh output range
            images_arr.append(image / 127.5 - 1.0)
        if not images_arr:
            return np.empty((0, rows, cols, channels), dtype = np.float32)
        return np.stack(images_arr)


    def calculate_loss(self, logits, labels):
        """Mean sigmoid cross-entropy between ``logits`` and ``labels``."""
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits))


    def get_disc_and_gen_loss(self):
        """Build the losses and store them in ``self.disc_loss`` and ``self.gen_loss``.

        disc_X -> real images || gen_X -> noise vectors. The discriminator is built on
        the real batch first because the Generator reads its conv_5 shape.
        """
        disc_logits_real = self.Discriminator.d_propagate_forward(self.disc_X)
        self.gen_noisy_images = self.Generator.g_propagate_forward(self.gen_X)
        disc_logits_fake = self.Discriminator.d_propagate_forward(self.gen_noisy_images)

        # D: real -> 1, generated -> 0.  G: generated -> 1.
        disc_fake_loss = self.calculate_loss(disc_logits_fake, tf.zeros_like(disc_logits_fake))
        disc_real_loss = self.calculate_loss(disc_logits_real, tf.ones_like(disc_logits_real))

        self.gen_loss = self.calculate_loss(disc_logits_fake, tf.ones_like(disc_logits_fake))
        self.disc_loss = tf.add(disc_fake_loss, disc_real_loss)


    @staticmethod
    def _tile_images(images, rows, cols):
        """Arrange the first rows*cols images of an [n, h, w, c] array into one [rows*h, cols*w, c] grid."""
        images = np.asarray(images)
        count = rows * cols
        if images.shape[0] < count:
            raise ValueError('need %d images for a %dx%d grid, got %d' % (count, rows, cols, images.shape[0]))
        _, height, width, channels = images.shape
        grid = images[:count].reshape(rows, cols, height, width, channels)
        return grid.transpose(0, 2, 1, 3, 4).reshape(rows * height, cols * width, channels)


    def generate_sample_images(self, epoch, rows = 5, cols = 5, out_dir = 'samples'):
        """Save a rows x cols grid of generated images to ``<out_dir>/<epoch>.png``.

        Generator output in [-1, 1] is mapped to uint8 [0, 255]. Channel order is the
        one the model was trained on (BGR for data from ``get_input_batches``), which
        is also what ``cv2.imwrite`` expects. Requires ``train_GAN`` to have opened
        ``self.session``.
        """
        noisy_data = np.random.uniform(-1, 1, (rows * cols, self.noise_size))
        images = self.session.run(self.gen_noisy_images, feed_dict = {self.gen_X : noisy_data})
        # scale between 0, 1
        images = np.clip(images * 0.5 + 0.5, 0.0, 1.0)
        grid = self._tile_images(images, rows, cols)
        os.makedirs(out_dir, exist_ok = True)
        cv2.imwrite(os.path.join(out_dir, '%d.png' % epoch), (grid * 255.0).round().astype(np.uint8))


    @staticmethod
    def _vars_with_prefix(variables, prefix):
        """Variables with a '/'-separated name component starting with ``prefix`` (e.g. 'd_', 'g_')."""
        return [var for var in variables
                if any(part.startswith(prefix) for part in var.name.split('/'))]


    def train_GAN(self):
        """Build losses/optimizers, then alternate one D step and one G step per batch.

        Opens ``self.session`` and leaves it open so ``generate_sample_images`` can
        reuse it. Prints the last batch losses every 10 epochs and saves a sample
        grid every 500 epochs.
        """
        with tf.variable_scope('optim', reuse = tf.AUTO_REUSE):
            self.get_disc_and_gen_loss()
            trainable_vars = tf.trainable_variables()
            # Select by name-component prefix. A substring test such as
            # `'g' in var.name` also matches discriminator variables
            # ('weights/d_w_1', '.../gamma'), letting the G optimizer update D.
            discriminator_vars = self._vars_with_prefix(trainable_vars, 'd_')
            generator_vars = self._vars_with_prefix(trainable_vars, 'g_')

            self.train_discriminator = tf.train.AdamOptimizer(self.disc_lr, beta1 = 0.5).minimize(
                self.disc_loss, var_list = discriminator_vars)
            self.train_generator = tf.train.AdamOptimizer(self.gen_lr, beta1 = 0.5).minimize(
                self.gen_loss, var_list = generator_vars)

            self.session = tf.Session()
            self.session.run(tf.global_variables_initializer())
            total_no_of_samples = len(self.train_image_names)
            disc_loss_, gen_loss_ = None, None

            for epoch in range(self.no_of_epochs):
                random.shuffle(self.train_image_names)

                for index in range(0, total_no_of_samples, self.batch_size):
                    train_batch = self.get_input_batches(
                        self.train_image_names[index : index + self.batch_size])
                    if len(train_batch) == 0:
                        # every file in this slice was unreadable
                        continue

                    disc_noisy_batch = np.random.uniform(-1, 1, (self.batch_size, self.noise_size))
                    _, disc_loss_ = self.session.run([self.train_discriminator, self.disc_loss],
                                                     feed_dict = {self.disc_X : train_batch,
                                                                  self.gen_X : disc_noisy_batch})

                    gen_noisy_batch = np.random.uniform(-1, 1, (self.batch_size, self.noise_size))
                    _, gen_loss_ = self.session.run([self.train_generator, self.gen_loss],
                                                    feed_dict = {self.gen_X : gen_noisy_batch})

                if epoch % 10 == 0:
                    print(epoch, disc_loss_, gen_loss_)

                if epoch % 500 == 0:
                    self.generate_sample_images(epoch)
            

In [None]:
if __name__ == '__main__':
    image_shape = [64, 64, 3]
    # NOTE(review): the DCGAN paper uses 2e-4 for both networks; 2e-3 is 10x that — confirm intended.
    disc_lr = 0.002
    gen_lr  = 0.002
    noise_size = 100
    batch_size = 50
    no_of_epochs = 1
    # Dataset directory; set DCGAN_TRAIN_DATA to avoid editing a machine-specific path.
    train_data_path = os.environ.get('DCGAN_TRAIN_DATA', '/Users/vijay/Downloads/Datasets/Simpsons')
    DC_GAN = DCGAN(image_shape, disc_lr, gen_lr, noise_size, batch_size, no_of_epochs, train_data_path)
    DC_GAN.train_GAN()