## DCGAN Implementation

In [7]:
# import libraries
import tensorflow as tf
import keras
import numpy as np
import matplotlib.pyplot as plt

In [9]:
class DCGAN:
    """Holds the generator, discriminator and input pipeline of a DCGAN.

    The networks are not created by the constructor; call
    ``build_generator`` and ``build_discriminator`` to populate
    ``self.generator`` and ``self.discriminator``.
    """

    def __init__(self, X = None, batch_size = 64, epochs = 10, optimizer = None, noise_dim = 30):
        """Store the hyper-parameters and wrap ``X`` in a batched dataset.

        Args:
            X: Training samples accepted by
                ``tf.data.Dataset.from_tensor_slices``. When ``None`` no
                dataset is created and ``self.dataset`` is ``None``.
            batch_size: Number of samples per batch; also the default number
                of samples produced by ``generate``.
            epochs: Number of training epochs (stored only; not used by the
                methods in this class body).
            optimizer: Optimizer instance to store. When ``None`` a fresh
                ``tf.keras.optimizers.Adam(0.001)`` is created per instance.
            noise_dim: Length of the latent noise vector fed to the generator.
        """
        super().__init__()
        self.batch_size = batch_size
        self.epochs = epochs
        # Build the default optimizer here rather than in the signature: a
        # default evaluated at definition time would be one stateful object
        # shared by every DCGAN instance.
        self.optimizer = tf.keras.optimizers.Adam(0.001) if optimizer is None else optimizer
        self.noise_dim = noise_dim
        self.generator = None
        self.discriminator = None
        self.gan = None
        if X is None:
            # from_tensor_slices cannot handle None; leave the dataset unset.
            self.dataset = None
        else:
            dataset = tf.data.Dataset.from_tensor_slices(X).shuffle(1000)
            self.dataset = dataset.batch(self.batch_size).prefetch(1)

    @staticmethod
    def print_images(images, n_cols = None):
        """Plot a batch of images on a grid with matplotlib.

        Args:
            images: Array-like batch of images exposing ``.shape``; a
                trailing channel axis of size 1 on a 4-D batch is dropped
                before plotting.
            n_cols: Number of grid columns. Defaults to one column per image
                (a single row).
        """
        n_images = len(images)
        if n_images == 0:
            # Nothing to draw; also avoids dividing by n_cols == 0 below.
            return
        n_cols = n_cols or n_images
        n_rows = (n_images - 1) // n_cols + 1
        if len(images.shape) == 4 and images.shape[-1] == 1:
            # Drop the single-channel axis (the last one) so that imshow
            # receives 2-D arrays.
            images = np.squeeze(images, axis=-1)
        plt.figure(figsize=(n_cols, n_rows))
        for index, image in enumerate(images):
            # Subplot indices are 1-based.
            plt.subplot(n_rows, n_cols, index + 1)
            plt.imshow(image)
            plt.axis("off")
        plt.show()

    def generate(self, num_examples=None):
        """Sample noise vectors and run them through the generator.

        Args:
            num_examples: Number of samples to generate. Defaults to
                ``self.batch_size``.

        Returns:
            The generator output for ``num_examples`` normally distributed
            noise vectors of length ``self.noise_dim``.

        Raises:
            RuntimeError: If the generator has not been built yet.
        """
        if self.generator is None:
            raise RuntimeError(
                "Generator has not been built; call build_generator() first."
            )
        num_examples = self.batch_size if num_examples is None else num_examples
        z = tf.random.normal(shape=[num_examples, self.noise_dim])
        return self.generator(z)

    def build_generator(self):
        """Create the generator network and store it in ``self.generator``.

        A dense projection of the noise vector is reshaped to a 7x7x256
        feature map and then upsampled by five stride-2 transposed
        convolutions with SAME padding (7 -> 14 -> 28 -> 56 -> 112 -> 224).
        The last layer has one channel and a tanh activation.
        """
        # NOTE(review): the Dense/Conv2DTranspose layers already apply
        # LeakyReLU(0.2) via `activation=`, and most are followed by another
        # LeakyReLU layer, so the activation is applied twice. Kept as in the
        # original architecture - confirm whether this is intended.
        self.generator = tf.keras.models.Sequential([
            tf.keras.layers.Dense(7*7*256, activation=tf.keras.layers.LeakyReLU(0.2), input_shape=[self.noise_dim]),
            tf.keras.layers.BatchNormalization(),
            # Reshape takes the target shape as a single sequence argument.
            tf.keras.layers.Reshape([7, 7, 256]),
            tf.keras.layers.Conv2DTranspose(64, kernel_size=5, strides=2, padding="SAME", activation=tf.keras.layers.LeakyReLU(0.2)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(0.2),
            tf.keras.layers.Conv2DTranspose(32, kernel_size=5, strides=2, padding="SAME", activation=tf.keras.layers.LeakyReLU(0.2)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(0.2),
            tf.keras.layers.Conv2DTranspose(16, kernel_size=5, strides=2, padding="SAME", activation=tf.keras.layers.LeakyReLU(0.2)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(0.2),
            tf.keras.layers.Conv2DTranspose(16, kernel_size=5, strides=2, padding="SAME", activation=tf.keras.layers.LeakyReLU(0.2)),
            tf.keras.layers.LeakyReLU(0.2),
            tf.keras.layers.Conv2DTranspose(1, kernel_size=5, strides=2, padding="SAME", activation="tanh")
        ])

    def build_discriminator(self):
        """Create the discriminator network and store it in ``self.discriminator``.

        Four stride-2 convolutions, each followed by LeakyReLU and
        Dropout(0.4), feed a flattened single-unit sigmoid output. No input
        shape is declared, so the model is built on its first call.
        """
        # NOTE(review): the first Conv2D uses activation="tanh" and the other
        # convolutions already apply LeakyReLU(0.2) via `activation=`, each
        # followed by a further LeakyReLU layer. Kept as in the original
        # architecture - confirm whether this is intended.
        self.discriminator = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(16, kernel_size=5, strides=2, padding="SAME", activation="tanh"),
            tf.keras.layers.LeakyReLU(0.2),
            tf.keras.layers.Dropout(0.4),
            tf.keras.layers.Conv2D(16, kernel_size=5, strides=2, padding="SAME", activation=tf.keras.layers.LeakyReLU(0.2)),
            tf.keras.layers.LeakyReLU(0.2),
            tf.keras.layers.Dropout(0.4),
            tf.keras.layers.Conv2D(32, kernel_size=5, strides=2, padding="SAME", activation=tf.keras.layers.LeakyReLU(0.2)),
            tf.keras.layers.LeakyReLU(0.2),
            tf.keras.layers.Dropout(0.4),
            tf.keras.layers.Conv2D(64, kernel_size=5, strides=2, padding="SAME", activation=tf.keras.layers.LeakyReLU(0.2)),
            tf.keras.layers.LeakyReLU(0.2),
            tf.keras.layers.Dropout(0.4),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(1, activation="sigmoid")
        ])
        
            
        