In [None]:
from keras.datasets import mnist
import matplotlib.pyplot as plt
from keras.layers import Dense, Dropout, Input, BatchNormalization, Flatten, Reshape
from keras.models import Sequential, Model
from keras.layers.advanced_activations import LeakyReLU
import numpy as np
from tqdm import tqdm_notebook

class GAN():
  """Fully-connected GAN for 28x28 single-channel (MNIST) images.

  Builds three Keras models:
    * ``self.G``   -- generator: ``latent_dim`` noise vector -> image.
    * ``self.D``   -- discriminator: image -> sigmoid "real" score.
    * ``self.gan`` -- ``D(G(z))`` with D frozen, used to train the generator.
  """

  def __init__(self):
    self.img_rows = 28
    self.img_cols = 28
    self.channels = 1
    # The channel axis is optional: drop it here only if training() also
    # stops expanding X_train to 4-D.
    self.img_shape = (self.img_rows, self.img_cols, self.channels)
    self.latent_dim = 100

    self.D = self.create_discriminator()
    self.D.compile(loss='binary_crossentropy',
                   optimizer='adam')

    self.G = self.create_generator()
    z = Input(shape=(self.latent_dim,))
    counterfeit = self.G(z)
    # Freeze D only *after* compiling it: per the Keras FAQ, the trainable
    # state is taken into account at compile time, so D.train_on_batch still
    # updates D, while the combined model compiled below only updates G.
    self.D.trainable = False
    police = self.D(counterfeit)
    self.gan = Model(inputs=z, outputs=police)
    self.gan.compile(loss='binary_crossentropy', optimizer='adam')

  def create_generator(self):
    """Build the generator: latent vector -> image of shape ``self.img_shape``.

    Three LeakyReLU(0.2) dense layers (256/512/1024 units, each followed by
    BatchNormalization) and a tanh output, so generated pixels lie in [-1, 1].
    The Sequential stack is wrapped in a functional ``Model`` with an explicit
    ``Input``. Prints the stack's summary as a side effect.
    """
    generator_model = Sequential()

    generator_model.add(Dense(256, activation=LeakyReLU(0.2), input_dim=self.latent_dim))
    generator_model.add(BatchNormalization(momentum=0.8))

    generator_model.add(Dense(512, activation=LeakyReLU(0.2)))
    generator_model.add(BatchNormalization(momentum=0.8))

    generator_model.add(Dense(1024, activation=LeakyReLU(0.2)))
    generator_model.add(BatchNormalization(momentum=0.8))

    # One unit per pixel, then reshape the flat vector back into an image.
    generator_model.add(Dense(np.prod(self.img_shape), activation='tanh'))
    generator_model.add(Reshape(self.img_shape))

    generator_model.summary()
    noise = Input(shape=(self.latent_dim,))
    gen_img = generator_model(noise)

    return Model(inputs=noise, outputs=gen_img)

  def create_discriminator(self):
    """Build the discriminator: image -> single sigmoid "real" probability.

    Flattens the image, then applies dense layers of 1024/512/256 units, each
    followed by Dropout(0.2). Prints the stack's summary as a side effect.
    """
    discriminator_model = Sequential()

    discriminator_model.add(Flatten(input_shape=self.img_shape))
    # NOTE(review): the generator uses LeakyReLU(0.2) but these layers use
    # 0.02, and the last hidden layer uses plain ReLU -- possibly a typo.
    # Kept unchanged to preserve the architecture; confirm intent.
    discriminator_model.add(Dense(1024, activation=LeakyReLU(.02)))
    discriminator_model.add(Dropout(.2))
    discriminator_model.add(Dense(512, activation=LeakyReLU(.02)))
    discriminator_model.add(Dropout(.2))
    discriminator_model.add(Dense(256, activation='relu'))
    discriminator_model.add(Dropout(.2))
    discriminator_model.add(Dense(1, activation='sigmoid'))

    discriminator_model.summary()
    img = Input(shape=self.img_shape)
    validity = discriminator_model(img)
    return Model(inputs=img, outputs=validity)

  def sample_images(self, epoch, rows=5, cols=5):
    """Save a ``rows`` x ``cols`` grid of generated images to disk.

    Args:
      epoch: iteration number embedded in the output file name
        ``'gan_generated_image <epoch>.png'`` (written to the working dir).
      rows: number of grid rows (default 5).
      cols: number of grid columns (default 5).
    """
    noise = np.random.normal(0, 1, (rows * cols, self.latent_dim))
    gen_imgs = self.G.predict(noise)
    # Map the generator's tanh range [-1, 1] to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5
    # squeeze=False keeps axs 2-D even for a single row or column.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    # axs.flat walks the grid row by row, matching the original i/j loop.
    for cnt, ax in enumerate(axs.flat):
      ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
      ax.axis('off')
    fig.savefig('gan_generated_image %d.png' % epoch)
    # Close this specific figure (not just "the current one") to free memory.
    plt.close(fig)

  def training(self, epochs, batch_size=128, sample_interval=10000):
    """Train on MNIST for ``epochs + 1`` single-minibatch iterations.

    Each iteration is one minibatch step, not a full pass over the data.
    G is updated through the frozen-D combined model. D is then updated on a
    real batch and on G's output for the same noise.

    Args:
      epochs: index of the last iteration; iterations 0..epochs inclusive run.
      batch_size: samples in each real batch and each fake batch.
      sample_interval: save an image grid whenever the iteration index is a
        multiple of this value. ``0`` or ``None`` disables sampling (``0``
        previously raised ZeroDivisionError).

    Returns:
      dict with lists ``'g_loss'``, ``'d_loss_real'`` and ``'d_loss_fake'``,
      one entry per iteration, holding what ``train_on_batch`` returned.
    """
    (X_train, _), (_, _) = mnist.load_data()
    # Scale pixels to [-1, 1] to match the generator's tanh output. Casting to
    # float32 first avoids a float64 copy of the whole dataset (Keras trains in
    # float32 by default anyway).
    X_train = X_train.astype('float32') / 127.5 - 1
    # Add the channel axis so images match self.img_shape. To avoid this,
    # remove self.channels from self.img_shape in __init__.
    X_train = np.expand_dims(X_train, axis=3)

    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    history = {'g_loss': [], 'd_loss_real': [], 'd_loss_fake': []}

    for epoch in tqdm_notebook(range(epochs + 1)):
      idx = np.random.randint(0, X_train.shape[0], batch_size)
      imgs = X_train[idx]

      noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

      # Generator step: train G so the frozen D labels G(noise) as real.
      history['g_loss'].append(self.gan.train_on_batch(noise, valid))

      # Discriminator step on real images and on the just-updated G's output.
      gen_imgs = self.G.predict(noise)
      history['d_loss_real'].append(self.D.train_on_batch(imgs, valid))
      history['d_loss_fake'].append(self.D.train_on_batch(gen_imgs, fake))

      if sample_interval and epoch % sample_interval == 0:
        self.sample_images(epoch)

    return history

if __name__ == '__main__':
  # Build G, D and the combined model, then run 80000 + 1 minibatch
  # iterations, saving a sample grid every 10000 iterations.
  gan = GAN()
  gan.training(
    epochs=80000,
    sample_interval=10000,
    batch_size=128,
  )