In [1]:
#GAN implementation using keras and tensorflow

#imports
import numpy as np
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, BatchNormalization
from keras.layers import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import plot_model

In [2]:
class Discriminator(object):
  """Fully connected binary classifier scoring images as real (1) or fake (0).

  Args:
    width, height, channels: input image dimensions; the model's input shape
      is (width, height, channels).
    latent_size: accepted for signature parity with Generator; unused here.
  """

  def __init__(self, width=28, height=28, channels=1, latent_size=100):
    # Number of input pixels; also used as the width of the first hidden layer.
    self.CAPACITY = width*height*channels
    self.SHAPE = (width, height, channels)
    # NOTE(review): `decay` is only accepted by the pre-2.11 Keras optimizers.
    self.OPTIMIZER = Adam(learning_rate=0.0002, decay=8e-9)
    self.Discriminator = self.model()

    self.Discriminator.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER, metrics=['accuracy'])
    # Print the architecture at construction time, like Generator and GAN do.
    self.Discriminator.summary()

  def model(self):
    """Build and return the uncompiled discriminator network.

    Layers: Flatten -> Dense(CAPACITY) -> LeakyReLU -> Dense(CAPACITY // 2)
    -> LeakyReLU -> Dense(1, sigmoid).
    """
    model = Sequential()
    model.add(Flatten(input_shape=self.SHAPE))
    # input_shape only matters on the first layer, so it is not repeated here.
    model.add(Dense(self.CAPACITY))
    model.add(LeakyReLU(alpha=0.2))

    model.add(Dense(self.CAPACITY // 2))
    model.add(LeakyReLU(alpha=0.2))

    # Single probability that the input image is real.
    model.add(Dense(1, activation='sigmoid'))
    return model

  def summary(self):
    """Print the Keras model summary to the screen."""
    return self.Discriminator.summary()

  def save_model(self, to_file='/content/GAN_100/Discriminator_model.png'):
    """Render the discriminator architecture to an image file via plot_model."""
    plot_model(self.Discriminator, to_file=to_file)

  def save_mode(self, to_file='/content/GAN_100/Discriminator_model.png'):
    """Misspelled alias of save_model, kept so existing callers keep working."""
    return self.save_model(to_file)


In [3]:
class Generator(object):
  """Fully connected generator mapping a latent noise vector to an image.

  Args:
    width, height, channels: dimensions of generated images; the output shape
      is (width, height, channels) with tanh-activated values in (-1, 1).
    latent_size: length of the input noise vector.
  """

  def __init__(self,  width=28, height=28, channels=1, latent_size = 100):
    self.W = width
    self.H = height
    self.C = channels
    # NOTE(review): `decay` is only accepted by the pre-2.11 Keras optimizers.
    self.OPTIMIZER = Adam(learning_rate=0.0002, decay=8e-9)

    self.LATENT_SPACE_SIZE = latent_size
    # One standard-normal latent vector; not read by any method in this class.
    self.latent_space = np.random.normal(0, 1, (self.LATENT_SPACE_SIZE,))

    self.Generator = self.model()
    # In this notebook the generator's weights are updated through GAN.gan_model;
    # compiling here just gives the standalone model a loss/optimizer.
    self.Generator.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)

    self.Generator.summary()

  def model(self, block_starting_size=128, num_blocks=4):
    """Build and return the uncompiled generator network.

    Args:
      block_starting_size: units in the first Dense block; each following
        block doubles the previous width.
      num_blocks: total number of Dense/LeakyReLU/BatchNorm blocks.

    Returns:
      A Sequential model mapping (LATENT_SPACE_SIZE,) to (W, H, C).
    """
    model = Sequential()

    block_size = block_starting_size

    model.add(Dense(block_size, input_shape=(self.LATENT_SPACE_SIZE,)))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    for _ in range(num_blocks - 1):
      block_size = block_size * 2
      model.add(Dense(block_size))
      model.add(LeakyReLU(alpha=0.2))
      model.add(BatchNormalization(momentum=0.8))

    # tanh matches the [-1, 1] scaling applied to the real training images.
    model.add(Dense(self.W*self.H*self.C, activation='tanh'))
    model.add(Reshape((self.W, self.H, self.C)))

    return model

  def summary(self):
    """Print the Keras model summary to the screen."""
    return self.Generator.summary()

  def save_model(self, to_file='/content/GAN_100/Generator_model.png'):
    """Render the generator architecture to an image file via plot_model."""
    # Plot this class's own model; there is no discriminator attribute here.
    plot_model(self.Generator, to_file=to_file)


In [4]:
class GAN(object):
  """Stacked generator -> discriminator model used to train the generator.

  Args:
    discriminator: the Keras discriminator model (e.g. Discriminator().Discriminator),
      already compiled on its own for discriminator updates.
    generator: the Keras generator model (e.g. Generator().Generator).
  """

  def __init__(self, discriminator, generator):
    # NOTE(review): `decay` is only accepted by the pre-2.11 Keras optimizers.
    self.OPTIMIZER = Adam(learning_rate=0.0002, decay=8e-9)
    self.Generator = generator
    self.Discriminator = discriminator
    # Freeze the discriminator inside the combined model so gan_model only
    # updates the generator. Keras applies `trainable` at compile time, so a
    # discriminator compiled before this line still trains when used directly.
    self.Discriminator.trainable = False
    self.gan_model = self.model()
    self.gan_model.compile(loss='binary_crossentropy', optimizer=self.OPTIMIZER)
    self.gan_model.summary()

  def model(self):
    """Build and return the uncompiled generator -> discriminator stack."""
    model = Sequential()

    model.add(self.Generator)
    model.add(self.Discriminator)

    return model

  def summary(self):
    """Print the Keras model summary to the screen."""
    return self.gan_model.summary()

  def save_model(self, to_file='/content/GAN_100/GAN_model.png'):
    """Render the combined model architecture to an image file via plot_model."""
    plot_model(self.gan_model, to_file=to_file)


In [5]:
%matplotlib inline
from keras.datasets import mnist
import matplotlib.pyplot as plt

In [11]:
class Trainer:
  """Builds the generator, discriminator and combined GAN, loads MNIST, and trains.

  Args:
    width, height, channels: image dimensions passed to the models.
    latent_size: length of the generator's input noise vector.
    epochs: number of training iterations (one batch each) run by train().
    batch: images per discriminator update (half real, rest generated) and
      noise vectors per generator update.
    checkpoint: a sample grid is saved every `checkpoint` iterations.
    model_type: MNIST digit (0-9) to train on, or -1 for all digits.
    output_dir: directory that sample grids are written to (created if missing).
  """

  def __init__(self, width=28, height=28, channels=1, latent_size=100, epochs=50000,
               batch=32, checkpoint=50, model_type=1, output_dir='/content/GAN_100'):
    self.W = width
    self.H = height
    self.C = channels

    self.EPOCHS = epochs
    self.BATCH = batch
    self.CHECKPOINT = checkpoint
    self.model_type = model_type
    self.OUTPUT_DIR = output_dir

    self.LATENT_SPACE_SIZE = latent_size

    self.generator = Generator(height=self.H, width=self.W, channels=self.C, latent_size=self.LATENT_SPACE_SIZE)
    self.discriminator = Discriminator(height=self.H, width=self.W, channels=self.C)
    self.gan = GAN(generator=self.generator.Generator, discriminator=self.discriminator.Discriminator)

    self.load_MNIST()

  def load_MNIST(self, model_type=None):
    """Load MNIST training images into self.X_train, scaled to [-1, 1].

    Args:
      model_type: digit 0-9 to keep, or -1 to keep every digit. Defaults to
        self.model_type.

    Raises:
      ValueError: if model_type is not one of -1..9.
    """
    if model_type is None:
      model_type = self.model_type
    allowed_types = [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    if model_type not in allowed_types:
      # Fail fast: an unknown digit would leave X_train empty and train() would
      # later crash with an unrelated error.
      raise ValueError('Error: Only Integer values from -1 to 9 are allowed')

    (self.X_train, self.Y_train), (_, _) = mnist.load_data()
    if model_type != -1:
      self.X_train = self.X_train[np.where(self.Y_train == int(model_type))[0]]

    # Map uint8 pixels [0, 255] to [-1, 1], the range of the generator's tanh output.
    self.X_train = (np.float32(self.X_train) - 127.5) / 127.5
    # Append a channel axis: (N, rows, cols) -> (N, rows, cols, 1).
    self.X_train = np.expand_dims(self.X_train, axis=3)
    return

  def train(self):
    """Alternate discriminator and generator updates for self.EPOCHS iterations.

    Each iteration:
      1. trains the discriminator on BATCH images: BATCH // 2 randomly chosen
         real images labelled 1 plus the remaining generated images labelled 0;
      2. trains the generator through the combined GAN model on BATCH noise
         vectors labelled 1.
    Losses are printed every iteration; a sample grid is saved every CHECKPOINT.

    Raises:
      ValueError: if self.X_train is empty.
    """
    count_real_images = int(self.BATCH/2)
    # Generated images fill the rest of the batch so images and labels always
    # have the same length, including for odd batch sizes.
    count_generated_images = self.BATCH - count_real_images
    num_train = len(self.X_train)
    if num_train == 0:
      raise ValueError('X_train is empty; load_MNIST() selected no images')

    for e in range(self.EPOCHS):
      # Grab real images at random positions (without replacement when the
      # dataset is large enough) rather than one contiguous slice.
      real_idx = np.random.choice(num_train, count_real_images,
                                  replace=num_train < count_real_images)
      x_real_images = self.X_train[real_idx].reshape(count_real_images, self.W, self.H, self.C)
      y_real_labels = np.ones([count_real_images, 1])

      # Generated images for this discriminator batch.
      latent_space_samples = self.sample_latent_space(count_generated_images)
      x_generated_images = self.generator.Generator.predict(latent_space_samples)
      y_generated_labels = np.zeros([count_generated_images, 1])

      # Combine real and generated samples and train the discriminator.
      x_batch = np.concatenate([x_real_images, x_generated_images])
      y_batch = np.concatenate([y_real_labels, y_generated_labels])
      # train_on_batch returns [loss, accuracy] because of metrics=['accuracy'].
      discriminator_loss = self.discriminator.Discriminator.train_on_batch(x_batch, y_batch)[0]

      # Train the generator: it is rewarded when the frozen discriminator says "real".
      x_latent_space_samples = self.sample_latent_space(self.BATCH)
      y_generated_labels = np.ones([self.BATCH, 1])
      generator_loss = self.gan.gan_model.train_on_batch(x_latent_space_samples, y_generated_labels)

      print('Epoch: '+str(int(e))+', [Discriminator :: Loss: '+str(discriminator_loss)+'], [ Generator :: Loss: '+str(generator_loss)+']')
      if e % self.CHECKPOINT == 0:
        self.plot_checkpoint(e)
    return

  def sample_latent_space(self, instances):
    """Return `instances` standard-normal latent vectors, shape (instances, LATENT_SPACE_SIZE)."""
    return np.random.normal(0, 1, (instances, self.LATENT_SPACE_SIZE))

  def plot_checkpoint(self, e):
    """Save a 4x4 grid of generated samples to OUTPUT_DIR/sample_<e>.png."""
    import os  # local import keeps this cell independent of the top import cell

    os.makedirs(self.OUTPUT_DIR, exist_ok=True)
    filename = os.path.join(self.OUTPUT_DIR, 'sample_' + str(e) + '.png')
    noise = self.sample_latent_space(16)
    images = self.generator.Generator.predict(noise)

    fig = plt.figure(figsize=(10, 10))
    for i in range(images.shape[0]):
      ax = fig.add_subplot(4, 4, i + 1)
      image = np.reshape(images[i, :, :, :], [self.H, self.W])
      ax.imshow(image, cmap='gray')
      ax.axis('off')
    # Lay out once after every panel exists instead of on each iteration.
    fig.tight_layout()
    fig.savefig(filename)
    # Close the figure so repeated checkpoints do not accumulate memory.
    plt.close(fig)
    return

In [14]:
# Experiment configuration: image geometry, latent size and training schedule.
HEIGHT, WIDTH, CHANNEL = 28, 28, 1   # MNIST digits: 28x28, single channel
LATENT_SPACE_SIZE = 100              # length of the generator's noise vector
EPOCHS = 40_000                      # training iterations (one batch each)
BATCH = 32                           # images per discriminator update
CHECKPOINT = 500                     # save a sample grid every CHECKPOINT iterations
MODEL_TYPE = -1                      # MNIST digit to train on; -1 means all digits

In [15]:
# Build the models and load the MNIST subset selected by MODEL_TYPE.
trainer = Trainer(
    height=HEIGHT,
    width=WIDTH,
    channels=CHANNEL,
    latent_size=LATENT_SPACE_SIZE,
    epochs=EPOCHS,
    batch=BATCH,
    checkpoint=CHECKPOINT,
    model_type=MODEL_TYPE,
)

Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_16 (Dense)            (None, 128)               12928     
                                                                 
 leaky_re_lu_12 (LeakyReLU)  (None, 128)               0         
                                                                 
 batch_normalization_8 (Batc  (None, 128)              512       
 hNormalization)                                                 
                                                                 
 dense_17 (Dense)            (None, 256)               33024     
                                                                 
 leaky_re_lu_13 (LeakyReLU)  (None, 256)               0         
                                                                 
 batch_normalization_9 (Batc  (None, 256)              1024      
 hNormalization)                                      

In [16]:
# Run the training loop: losses are printed every iteration and a sample grid
# is saved every CHECKPOINT iterations.
trainer.train()

Epoch: 32144, [Discriminator :: Loss: 0.3806874454021454], [ Generator :: Loss: 1.899539589881897]
Epoch: 32145, [Discriminator :: Loss: 0.20215913653373718], [ Generator :: Loss: 2.0367908477783203]
Epoch: 32146, [Discriminator :: Loss: 0.5031000375747681], [ Generator :: Loss: 1.3483779430389404]
Epoch: 32147, [Discriminator :: Loss: 0.5605377554893494], [ Generator :: Loss: 1.395275592803955]
Epoch: 32148, [Discriminator :: Loss: 0.4069874882698059], [ Generator :: Loss: 1.3901937007904053]
Epoch: 32149, [Discriminator :: Loss: 0.5948561429977417], [ Generator :: Loss: 1.5265759229660034]
Epoch: 32150, [Discriminator :: Loss: 0.42391619086265564], [ Generator :: Loss: 1.8759994506835938]
Epoch: 32151, [Discriminator :: Loss: 0.3222929835319519], [ Generator :: Loss: 2.0250778198242188]
Epoch: 32152, [Discriminator :: Loss: 0.32820451259613037], [ Generator :: Loss: 2.3267173767089844]
Epoch: 32153, [Discriminator :: Loss: 0.28307604789733887], [ Generator :: Loss: 2.0048184394836426

KeyboardInterrupt: ignored