In [None]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, MaxPooling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD

import numpy as np
import matplotlib.pyplot as plt

import os

# Output locations: per-interval sample grids go to the first directory,
# the final loss/accuracy summary figure goes to the second.
img_path, plot_path = "./img_save/CNN_adam/", "./plot_save/"

class GAN():
    """Vanilla GAN trained on MNIST with Keras.

    ``genNet`` maps a ``latent_dim`` noise vector to an ``img_shape`` image
    through a tanh output, so generated pixels lie in ``[-1, 1]``.
    ``disNet`` is a fully connected (``dis_type == 0``) or convolutional
    (``dis_type == 1``) binary classifier with a sigmoid output.
    ``modelNet`` stacks the generator with the frozen discriminator, so
    training it only updates the generator.
    """

    # 0 -> fully connected discriminator, 1 -> convolutional discriminator
    dis_type = 0
    # print array/tensor shapes during setup and for the first training batch
    debug = True

    def __init__(self):
        # Validate the discriminator choice before building anything, so a bad
        # value fails fast instead of leaving self.disNet undefined.
        if self.dis_type not in (0, 1):
            raise ValueError(
                "dis_type must be 0 (fc) or 1 (cnn), got %r" % (self.dis_type,))

        self.img_row = 28
        self.img_col = 28
        self.img_channel = 1
        self.img_shape = (self.img_row, self.img_col, self.img_channel)
        self.latent_dim = 100

        # input noise layer
        z = Input(shape=(self.latent_dim,))

        # generator network and the image it produces from z
        self.genNet = self.generator_build()
        imgGenerate = self.genNet(z)

        # discriminator network (compiled inside its builder)
        if self.dis_type == 0:
            self.disNet = self.discriminator_build_fc()
        else:
            self.disNet = self.discriminator_build_cnn()

        # Freeze the discriminator for the stacked model only.
        # disNet was compiled above while trainable, so its own
        # train_on_batch still updates its weights.
        self.disNet.trainable = False

        if self.debug:
            print(imgGenerate.shape)

        # stacked model: noise -> generator -> frozen discriminator -> label
        label = self.disNet(imgGenerate)
        self.modelNet = Model(z, label)

        optimizer = Adam(lr=0.0002, beta_1=0.5)
        self.modelNet.compile(loss='binary_crossentropy',
                              optimizer=optimizer,
                              metrics=['accuracy'])

    def generator_build(self):
        """Build the generator.

        Three Dense+LeakyReLU layers (256, 512, 1024 units) are followed by a
        tanh Dense layer reshaped to ``img_shape``.

        Returns:
            A Keras ``Model`` mapping ``(latent_dim,)`` noise to an image.
        """
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        # tanh -> generated pixels lie in [-1, 1]
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def discriminator_build_cnn(self, optimizer_type='adam'):
        """Build and compile a small convolutional discriminator.

        Args:
            optimizer_type: currently ignored. SGD with Nesterov momentum is
                always used; the parameter is kept for interface compatibility.

        Returns:
            A compiled Keras ``Model`` mapping an image to a probability in
            ``(0, 1)`` that the image is real.
        """
        model = Sequential()

        # input_shape given explicitly, as in discriminator_build_fc
        model.add(Conv2D(32, (3, 3), strides=(1, 1), input_shape=self.img_shape))
        model.add(LeakyReLU(alpha=0.2))

        model.add(Conv2D(64, (3, 3), strides=(1, 1)))
        model.add(LeakyReLU(alpha=0.2))

        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        model.add(Flatten())
        model.add(Dense(64, activation='relu'))
        model.add(Dropout(0.25))
        # sigmoid (not relu): binary_crossentropy expects a probability in (0, 1)
        model.add(Dense(1, activation='sigmoid'))

        # input & output layer
        img = Input(shape=self.img_shape)
        label = model(img)

        disNet = Model(img, label)

        optimizer = SGD(lr=0.0001, decay=1e-7, momentum=0.9, nesterov=True)
        disNet.compile(loss='binary_crossentropy',
                       optimizer=optimizer,
                       metrics=['accuracy'])

        model.summary()

        return disNet

    def discriminator_build_fc(self):
        """Build and compile a fully connected discriminator.

        Returns:
            A compiled Keras ``Model`` mapping an image to a sigmoid
            probability that the image is real.
        """
        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=self.img_shape)
        label = model(img)

        disNet = Model(img, label)

        optimizer = Adam(lr=0.0002, beta_1=0.5)
        disNet.compile(loss='binary_crossentropy',
                       optimizer=optimizer,
                       metrics=['accuracy'])

        disNet.summary()

        return disNet

    def save_sampling(self, epoch):
        """Save a 5x5 grid of generated samples to ``img_path + "<epoch>.png"``.

        Args:
            epoch: training iteration number, used as the file name.
        """
        row, col = 5, 5
        noise = np.random.random((row * col, self.latent_dim))
        img_gens = self.genNet.predict(noise)
        # map tanh output from [-1, 1] to [0, 1] for display
        img_gens = 0.5 * img_gens + 0.5
        img_gens = img_gens.reshape((len(img_gens), self.img_row, self.img_col))

        fig, axs = plt.subplots(row, col)
        for ax, img in zip(axs.flat, img_gens):
            ax.imshow(img, cmap='gray', vmin=0., vmax=1.)
            ax.axis('off')

        # makedirs: img_path is nested, and os.mkdir fails if the parent is missing
        if not os.path.isdir(img_path):
            os.makedirs(img_path, exist_ok=True)
            print("make new", img_path, "path!")

        fig.savefig(img_path + "%d.png" % epoch)
        plt.close(fig)

    def graph_summary(self, loss_dis, acc_dis, loss_gen, acc_gen, sample_interval=1):
        """Plot loss and accuracy curves and save them to ``plot_path + "fig.png"``.

        Args:
            loss_dis, acc_dis: discriminator metrics, one entry per sample point.
            loss_gen, acc_gen: stacked-model (generator) metrics, same spacing.
            sample_interval: number of iterations between consecutive entries;
                used to place points on the x-axis (default 1 = one per entry).
        """
        def _x(series):
            # x-position of each recorded point, in training iterations
            return np.arange(len(series)) * sample_interval

        fig, axs = plt.subplots(2, 1, figsize=(20, 10))

        axs[0].plot(_x(loss_dis), loss_dis, color='red', label='loss_dis')
        axs[0].plot(_x(loss_gen), loss_gen, color='orange', label='loss_gen')
        axs[0].set(xlabel='Epoch', ylabel='Loss', title="Loss of Discriminator & Generator")
        axs[0].legend()

        axs[1].plot(_x(acc_dis), acc_dis, color='blue', label='acc_dis')
        axs[1].plot(_x(acc_gen), acc_gen, color='green', label='acc_gen')
        axs[1].set(xlabel='Epoch', ylabel='Acc', title="Accuracy of Discriminator & Generator")
        axs[1].legend()

        if not os.path.isdir(plot_path):
            os.makedirs(plot_path, exist_ok=True)
            print("make new", plot_path, "path!")
        fig.savefig(plot_path + "fig.png")
        # release the figure; repeated calls would otherwise accumulate figures
        plt.close(fig)

    def build_poolSample(self, batch_size):
        """Load MNIST training images and build the fixed label vector.

        Args:
            batch_size: number of real (and of fake) images per discriminator batch.

        Returns:
            ``(real_pool, label_pool, index)``:
            - ``real_pool`` holds the images as float32 in ``[-1, 1]`` with shape
              ``(num, row, col, 1)``.
            - ``label_pool`` is ``batch_size`` ones followed by ``batch_size``
              zeros, shape ``(2 * batch_size, 1)``.
            - ``index`` lists every position in ``real_pool``.
        """
        (real_pool, _), (_, _) = mnist.load_data()

        # rescale [0, 255] -> [-1, 1] to match the generator's tanh output range
        real_pool = real_pool.astype(np.float32) / 127.5 - 1.

        # expand dimension to (num, row, col, 1)
        real_pool = np.expand_dims(real_pool, axis=3)

        # labels: first half real (1), second half fake (0)
        label_positive = np.ones((batch_size, 1))
        label_negative = np.zeros((batch_size, 1))
        label_pool = np.concatenate((label_positive, label_negative))

        index = list(range(len(real_pool)))

        if self.debug:
            print("index.len:", len(index))

        return real_pool, label_pool, index

    def build_trainSample(self, real_pool, label_pool, index, batch_size=128, shuffle=True):
        """Assemble one discriminator batch of real and generated images.

        Args:
            real_pool: array of real images, indexable by entries of ``index``.
            label_pool: ``(2 * batch_size, 1)`` labels, first half 1 and second half 0.
            index: candidate positions in ``real_pool`` to sample from.
            batch_size: number of real images, and also the number of fake images.
            shuffle: if truthy, jointly permute images and labels.

        Returns:
            ``(noise, X_train, Y_train)``:
            - ``noise`` is the generator input used for the fake half.
            - ``X_train`` stacks ``batch_size`` real and ``batch_size`` fake
              images, possibly shuffled.
            - ``Y_train`` holds the matching labels.
        """
        noise = np.random.random((batch_size, self.latent_dim))
        fake_img = self.genNet.predict(noise)

        # distinct random real images for this batch
        index_epoch = np.random.choice(index, size=batch_size, replace=False)

        if self.debug:
            print("index_epoch.shape:", index_epoch.shape)

        X_train_positive = real_pool[index_epoch]
        X_train_negative = fake_img
        X_train_epoch = np.concatenate((X_train_positive, X_train_negative))

        if self.debug:
            print("X_train_positive.shape:", X_train_positive.shape)
            print("X_train_negative.shape:", X_train_negative.shape)
            print("X_train_epoch.shape:", X_train_epoch.shape)

        if shuffle:
            # one permutation applied to both arrays keeps image/label pairs aligned
            perm = np.random.permutation(len(X_train_epoch))
            X_train, Y_train = X_train_epoch[perm], label_pool[perm]
        else:
            X_train, Y_train = X_train_epoch, label_pool

        if self.debug:
            print("X_train.shape:", X_train.shape)
            print("Y_train.shape:", Y_train.shape)
            # only report shapes for the first batch
            self.debug = False

        return noise, X_train, Y_train

    def train_model(self, epochs, batch_size=128, trainGen_interval=5, sample_interval=20, save_interval=200, shuffle=True):
        """Alternate discriminator and generator updates for ``epochs`` iterations.

        Each iteration ("epoch" here) trains on a single batch, not a full
        pass over the dataset.

        Args:
            epochs: number of training iterations.
            batch_size: total discriminator batch, split evenly into real and fake.
            trainGen_interval: currently unused; kept for interface compatibility.
            sample_interval: record metrics every this many iterations.
            save_interval: save a sample image grid every this many iterations.
            shuffle: forwarded to ``build_trainSample``.
        """
        # half of each discriminator batch is real, half is generated
        batch_size = batch_size // 2

        real_pool, label_pool, index = self.build_poolSample(batch_size)

        # The generator is trained to make the frozen discriminator output
        # "real" (1). It must not reuse the (possibly shuffled) D labels.
        valid = np.ones((batch_size, 1))

        loss_dis = []
        acc_dis = []
        loss_mod = []
        acc_mod = []

        for epoch in range(epochs):

            noise, X_train, Y_train = self.build_trainSample(real_pool, label_pool, index, batch_size, shuffle)

            # train the discriminator on real + fake images
            l_dis, a_dis = self.disNet.train_on_batch(X_train, Y_train)

            # train the generator through the stacked model (discriminator frozen)
            l_mod, a_mod = self.modelNet.train_on_batch(noise, valid)

            print("%d [D loss: %f, acc: %.2f%%] [G loss: %f, acc: %.2f%%]" % (epoch, l_dis, a_dis * 100., l_mod, a_mod * 100))

            if epoch % sample_interval == 0:
                loss_dis.append(l_dis)
                acc_dis.append(a_dis)
                loss_mod.append(l_mod)
                acc_mod.append(a_mod)

            if epoch % save_interval == 0:
                self.save_sampling(epoch)

        self.graph_summary(loss_dis, acc_dis, loss_mod, acc_mod, sample_interval=sample_interval)

In [None]:
if __name__ == "__main__":
    # Build the networks, then train for 2000 iterations. train_model splits
    # the 64-image batch evenly into real and generated halves, records
    # metrics every 20 iterations, and saves a sample grid every 200.
    gan = GAN()
    gan.train_model(
        epochs=2000,
        batch_size=64,
        sample_interval=20,
        save_interval=200,
    )