<a href="https://colab.research.google.com/github/NikuDubenco/code_replications/blob/master/GAN_coded_in_Keras.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## GAN coded in Keras

In [0]:
class GAN():
  """Fully-connected GAN for 28x28 single-channel (MNIST) images.

  Three Keras models are built and compiled in ``__init__``:

  * ``self.generator`` maps a ``latent_dim``-long noise vector to an image
    in [-1, 1] (tanh output layer).
  * ``self.discriminator`` maps an image to a sigmoid "real vs. fake" score.
  * ``self.combined`` stacks the generator and a frozen discriminator; it is
    the model through which the generator is trained.
  """

  def __init__(self):
    self.img_rows = 28
    self.img_cols = 28
    self.channels = 1
    self.img_shape = (self.img_rows, self.img_cols, self.channels)
    # length of the noise vector consumed by the generator
    self.latent_dim = 100

    # build and compile the discriminator
    self.discriminator = self.build_discriminator()
    self.discriminator.compile(loss='binary_crossentropy',
                               optimizer=self._make_optimizer(),
                               metrics=['accuracy'])

    # build and compile the generator (train() never fits it directly; it is
    # only updated through self.combined)
    self.generator = self.build_generator()
    self.generator.compile(loss='binary_crossentropy',
                           optimizer=self._make_optimizer())

    # the generator takes noise as input and generates imgs
    z = Input(shape=(self.latent_dim,))
    img = self.generator(z)

    # Freeze the discriminator for the combined model only. Keras applies the
    # `trainable` flag when a model is compiled, so the discriminator compiled
    # above still updates its weights in its own train_on_batch calls.
    self.discriminator.trainable = False

    # the discriminator takes generated images as input and determines validity
    valid = self.discriminator(img)

    # the combined model (stacked generator and discriminator) takes
    # noise as input => generates images => determines validity
    self.combined = Model(z, valid)
    self.combined.compile(loss='binary_crossentropy',
                          optimizer=self._make_optimizer())

  @staticmethod
  def _make_optimizer():
    """Return a fresh Adam optimizer (learning rate 2e-4, beta_1 0.5).

    Each compiled model gets its own instance. Recent Keras optimizers track
    the variables they are first applied to, so sharing a single instance
    across separately compiled models can fail.
    """
    return Adam(.0002, .5)

  def build_generator(self):
    """Build the generator: noise of shape (latent_dim,) -> image of img_shape.

    Three Dense/LeakyReLU/BatchNorm stages widen to 1024 units, then a tanh
    Dense layer produces prod(img_shape) values reshaped to the image shape.
    """
    noise_shape = (self.latent_dim,)

    model = Sequential()
    model.add(Dense(256, input_shape=noise_shape))
    model.add(LeakyReLU(alpha=.2))
    model.add(BatchNormalization(momentum=.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=.2))
    model.add(BatchNormalization(momentum=.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=.2))
    model.add(BatchNormalization(momentum=.8))
    model.add(Dense(np.prod(self.img_shape), activation='tanh'))
    model.add(Reshape(self.img_shape))

    model.summary()

    noise = Input(shape=noise_shape)
    img = model(noise)

    return Model(noise, img)

  def build_discriminator(self):
    """Build the discriminator: image of img_shape -> single sigmoid score."""
    model = Sequential()
    model.add(Flatten(input_shape=self.img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=.2))
    model.add(Dense(1, activation='sigmoid'))

    model.summary()

    img = Input(shape=self.img_shape)
    validity = model(img)

    return Model(img, validity)

  @staticmethod
  def _scale_images(images):
    """Rescale raw pixels to float32 in [-1, 1] and append a channel axis.

    (x - 127.5) / 127.5 maps 0 -> -1 and 255 -> 1, matching the generator's
    tanh output range.
    """
    scaled = (images.astype(np.float32) - 127.5) / 127.5
    return np.expand_dims(scaled, axis=-1)

  def train(self, epochs, batch_size=128, save_interval=50):
    """Train the GAN on the MNIST training images.

    Args:
      epochs: number of training iterations. Each iteration samples one random
        batch; it is not a full pass over the dataset.
      batch_size: generator batch size. The discriminator is trained on
        batch_size // 2 real and batch_size // 2 generated images per
        iteration. Must be at least 2.
      save_interval: call save_imgs every this many iterations (starting at
        iteration 0); a value <= 0 disables saving.

    Raises:
      ValueError: if batch_size < 2 (the half batch would be empty).
    """
    if batch_size < 2:
      raise ValueError('batch_size must be at least 2, got %r' % (batch_size,))

    # load the dataset (labels are not needed) and rescale to [-1, 1]
    (X_train, _), (_, _) = mnist.load_data()
    X_train = self._scale_images(X_train)

    half_batch = batch_size // 2

    # targets do not change between iterations, so build them once
    real_y = np.ones((half_batch, 1))
    fake_y = np.zeros((half_batch, 1))
    # the generator wants the discriminator to label the generated samples as valid (ones)
    valid_y = np.ones((batch_size, 1))

    for epoch in range(epochs):
      # train discriminator --------------------------------

      # select a random half batch of real images
      idx = np.random.randint(0, X_train.shape[0], half_batch)
      imgs = X_train[idx]

      # generate a half batch of new images
      noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
      gen_imgs = self.generator.predict(noise, verbose=0)

      d_loss_real = self.discriminator.train_on_batch(imgs, real_y)
      d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake_y)
      d_loss = .5 * np.add(d_loss_real, d_loss_fake)

      # train generator --------------------------------------
      noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
      g_loss = self.combined.train_on_batch(noise, valid_y)

      # plot the progress
      print('%d [D loss: %f, acc.: %.2f%%] [G loss: %f]' % (epoch, d_loss[0],
                                                            100 * d_loss[1],
                                                            g_loss))

      # periodically snapshot generated samples
      if save_interval > 0 and epoch % save_interval == 0:
        self.save_imgs(epoch)

  def save_imgs(self, epoch):
    """Save a 5x5 grid of generated samples to gan/images/mnist_<epoch>.png.

    The output directory is created if it does not exist.
    """
    import os  # local import keeps this notebook cell self-contained

    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, self.latent_dim))
    gen_imgs = self.generator.predict(noise, verbose=0)

    # rescale images from the tanh range [-1, 1] to [0, 1] for display
    gen_imgs = .5 * gen_imgs + .5

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, matching sample order
    for cnt, ax in enumerate(axs.flat):
      ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
      ax.axis('off')

    out_dir = os.path.join('gan', 'images')
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, 'mnist_%d.png' % epoch))
    plt.close(fig)
      
if __name__ == '__main__':
  # Build the GAN, then run 30000 training iterations with batches of 32.
  gan = GAN()
  gan.train(epochs=30000,
            batch_size=32,
            save_interval=200)
     