In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import tensorflow as tf
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))
tf.device(device_name)

Found GPU at: /device:GPU:0


<tensorflow.python.eager.context._EagerDeviceContext at 0x7f00c1bc2278>

In [3]:
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Input, Flatten, Conv2D, Dense, Dropout, Reshape, Conv2DTranspose
from keras.models import Model
from keras.utils.vis_utils import plot_model

In [4]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

In [5]:
import os
import tensorflow as tf
import numpy as np
import cv2
import random
import scipy.misc
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm.notebook import tqdm
import cv2

In [6]:
#import data
data_path = "/content/drive/My Drive/Study/DL/GAN/data/"
output_path = "/content/drive/My Drive/Study/DL/GAN/output/"
img_list = os.listdir(data_path)

In [9]:
def load_data():
  # load image as pixel array
  data = []
  for i in tqdm(img_list):
    img = cv2.imread(data_path + i)
    img = cv2.resize(img, (128, 128))
    img = list(img)
    data.append(img)

  return np.array(data)

X_train = load_data()

HBox(children=(FloatProgress(value=0.0), HTML(value='')))




In [50]:
class GAN():
  def __init__(self, X_train):

    self.X_train = X_train
    self.img_rows, self.img_cols, self.channels = self.X_train.shape[1:]
    self.img_shape = self.X_train.shape[1:]
    self.latent_dim = 100

    self.noise_rows = 8
    self.noise_cols = 8
    self.noise_channels = 3

    optimizer = Adam(0.0002, 0.5)

    # Build and compile the discriminator
    self.discriminator = self.build_discriminator()
    self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    # Build the generator
    self.generator = self.build_generator()

    # The generator takes noise as input and generates imgs
    z = Input(shape=(self.latent_dim,))
    img = self.generator(z)

    # For the combined model we will only train the generator
    self.discriminator.trainable = False

    # The discriminator takes generated images as input and determines validity
    validity = self.discriminator(img)

    # The combined model  (stacked generator and discriminator)
    # Trains the generator to fool the discriminator
    self.combined = Model(z, validity)
    self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    self.g_loss = []
    self.d_loss = []

  def build_generator(self):

    model = Sequential()

    model.add(Dense(self.noise_rows*self.noise_cols*self.noise_channels, activation='relu'))
    model.add(Reshape((self.noise_rows, self.noise_cols, self.noise_channels)))
    model.add(Conv2DTranspose(filters=512, kernel_size=[5, 5], strides=[2, 2], padding="same", activation='relu'))
    model.add(Conv2DTranspose(filters=256, kernel_size=[5, 5], strides=[2, 2], padding="same", activation='relu'))
    model.add(Conv2DTranspose(filters=128, kernel_size=[5, 5], strides=[2, 2], padding="same", activation='relu'))
    model.add(Conv2DTranspose(filters=self.channels, kernel_size=[5, 5], strides=[2, 2], padding="same",activation='relu'))

    noise = Input(shape=(self.latent_dim,))
    img = model(noise)

    return Model(noise, img)


  def build_discriminator(self):

    model = Sequential()

    model.add(Conv2D(filters=64, kernel_size=[5, 5], strides=[2, 2], padding="SAME", kernel_initializer='glorot_uniform' ,activation='relu'))
    model.add(Conv2D(filters=128, kernel_size=[5, 5], strides=[2, 2], padding="SAME", activation='relu'))
    model.add(Conv2D(filters=256, kernel_size=[5, 5], strides=[2, 2], padding="SAME", activation='relu'))
    model.add(Conv2D(filters=512, kernel_size=[5, 5], strides=[2, 2], padding="SAME", activation='relu'))
    model.add(Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(Dense(128, activation='relu'))
    model.add(Dense(64, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))

    img = Input(shape=self.img_shape)
    validity = model(img)

    return Model(img, validity)



  def train(self, epochs, batch_size=128, sample_interval=50):

    # Load the dataset
    X_train = self.X_train

    # Rescale 0 to 1
    X_train = X_train / 255
    #X_train = np.expand_dims(X_train, axis=3)

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in tqdm(range(epochs)):

      # ---------------------
      #  Train Discriminator
      # ---------------------

      # Select a random batch of images
      idx = np.random.randint(0, X_train.shape[0], batch_size)
      imgs = X_train[idx]

      noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

      # Generate a batch of new images
      gen_imgs = self.generator.predict(noise)

      # Train the discriminator
      d_loss_real = self.discriminator.train_on_batch(imgs, valid)
      d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
      d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
      self.d_loss.append(d_loss)

      # ---------------------
      #  Train Generator
      # ---------------------

      noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

      # Train the generator (to have the discriminator label samples as valid)
      g_loss = self.combined.train_on_batch(noise, valid)
      self.g_loss.append(g_loss)

      # If at save interval => save generated image samples and Plot the progress
      if epoch % sample_interval == 0:
        self.sample_images(epoch)
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

  

  def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim)) #generating r*c random noise samples
        gen_imgs = self.generator.predict(noise)


        #clipping the output to range [0, 1 ]
        np.clip(gen_imgs, 0, 1, out=gen_imgs)


        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt])
                axs[i,j].axis('off')
                cnt += 1

        fig = plt.gcf()
        fig.set_size_inches(18, 18)
        fig.tight_layout()
        fig.savefig(output_path + "%d.png" % epoch)
        plt.close()

In [51]:
if __name__ == '__main__':
    gan = GAN(X_train)
    gan.train(epochs=30000, batch_size=32, sample_interval=200)

HBox(children=(FloatProgress(value=0.0, max=30000.0), HTML(value='')))

0 [D loss: 0.703272, acc.: 0.00%] [G loss: 0.692785]
200 [D loss: 0.000326, acc.: 100.00%] [G loss: 7.430267]
400 [D loss: 0.187056, acc.: 96.88%] [G loss: 2.702280]
600 [D loss: 0.556788, acc.: 67.19%] [G loss: 1.302315]
800 [D loss: 0.649593, acc.: 57.81%] [G loss: 2.734946]
1000 [D loss: 0.163179, acc.: 95.31%] [G loss: 2.905136]
1200 [D loss: 0.517689, acc.: 76.56%] [G loss: 1.330114]
1400 [D loss: 0.255898, acc.: 98.44%] [G loss: 1.928577]
1600 [D loss: 0.157979, acc.: 96.88%] [G loss: 2.761956]
1800 [D loss: 0.161223, acc.: 100.00%] [G loss: 2.038683]
2000 [D loss: 1.717680, acc.: 35.94%] [G loss: 3.707345]
2200 [D loss: 0.052008, acc.: 98.44%] [G loss: 3.615581]
2400 [D loss: 0.116154, acc.: 96.88%] [G loss: 3.331873]
2600 [D loss: 0.558298, acc.: 76.56%] [G loss: 1.495312]
2800 [D loss: 0.187989, acc.: 92.19%] [G loss: 5.380606]
3000 [D loss: 0.065275, acc.: 98.44%] [G loss: 4.241436]
3200 [D loss: 0.041566, acc.: 100.00%] [G loss: 3.874578]
3400 [D loss: 0.224641, acc.: 98.44%

KeyboardInterrupt: ignored