In [None]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, MaxPooling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import numpy as np
import matplotlib.pyplot as plt

import os

# Directory where GAN.save_sampling() writes its sample grids.
save_path = "./img_save/"


class GAN():
    """MNIST GAN: a fully-connected generator plus a convolutional discriminator.

    ``__init__`` builds three Keras models:

    * ``genNet``   -- noise vector of length ``latent_dim`` -> image of ``img_shape``.
    * ``disNet``   -- image -> single real/fake score; compiled on its own so it
      can be trained directly with ``train_on_batch``.
    * ``modelNet`` -- ``disNet(genNet(z))``, compiled with ``disNet.trainable``
      set to False; used to train the generator.
    """

    def __init__(self):
        self.img_row = 28
        self.img_col = 28
        self.img_channel = 1
        self.img_shape = (self.img_row, self.img_col, self.img_channel)
        self.latent_dim = 100

        # input noise layer
        z = Input(shape=(self.latent_dim,))

        # generator network and the image it produces from z
        self.genNet = self.generator_build()
        imgGenerate = self.genNet(z)

        # discriminator network (compiled inside discriminator_build)
        self.disNet = self.discriminator_build()

        # Freeze the discriminator before compiling the combined model so that
        # training modelNet only targets the generator weights. disNet itself
        # was compiled earlier, while it was still trainable.
        self.disNet.trainable = False

        # combined model: noise -> generated image -> real/fake score
        label = self.disNet(imgGenerate)
        self.modelNet = Model(z, label)

        # `lr` keyword matches the standalone-Keras API this file imports.
        optimizer = Adam(lr=0.001, beta_1=0.9)
        self.modelNet.compile(loss='binary_crossentropy',
                              optimizer=optimizer,
                              metrics=['accuracy'])

    def generator_build(self):
        """Build the generator: noise of shape (latent_dim,) -> image of img_shape.

        Three Dense + BatchNormalization stages (256, 512, 1024 units) are
        followed by a Dense layer with one unit per pixel, reshaped to
        ``img_shape``.

        Returns:
            A Keras ``Model`` mapping a noise input to a generated image.
        """
        model = Sequential()

        model.add(Dense(256, activation='relu', input_dim=self.latent_dim))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512, activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024, activation='relu'))
        model.add(BatchNormalization(momentum=0.8))

        # sigmoid (not relu) bounds pixels to (0, 1), the same range the real
        # images are rescaled to in train_model.
        model.add(Dense(int(np.prod(self.img_shape)), activation='sigmoid'))
        model.add(Reshape(self.img_shape))

        # explicit input & output layer
        noise = Input(shape=(self.latent_dim,))
        imgGenerate = model(noise)

        return Model(noise, imgGenerate)

    def discriminator_build(self):
        """Build and compile the discriminator: image -> score for "real".

        Two 3x3 convolutions with LeakyReLU and dropout, a 2x2 max-pool, then
        a small dense head ending in a single sigmoid unit.

        Returns:
            A compiled Keras ``Model`` (binary_crossentropy, Adam, accuracy).
        """
        model = Sequential()

        model.add(Conv2D(64, (3, 3), strides=(1, 1), input_shape=self.img_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(Conv2D(128, (3, 3), strides=(1, 1)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))

        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        model.add(Flatten())
        model.add(Dense(64, activation='relu'))
        model.add(Dropout(0.25))
        # binary_crossentropy expects a probability: sigmoid gives (0, 1),
        # whereas relu could output exactly 0 or values above 1.
        model.add(Dense(1, activation='sigmoid'))

        # explicit input & output layer
        img = Input(shape=self.img_shape)
        label = model(img)

        disNet = Model(img, label)

        optimizer = Adam(lr=0.001, beta_1=0.9)
        disNet.compile(loss='binary_crossentropy',
                       optimizer=optimizer,
                       metrics=['accuracy'])

        return disNet

    def _sample_noise(self, n):
        """Draw ``n`` latent vectors, uniform in [0, 1), shape (n, latent_dim)."""
        return np.random.random((n, self.latent_dim))

    @staticmethod
    def _mix_batch(real_imgs, fake_imgs):
        """Stack real (label 1) and fake (label 0) images and shuffle them jointly.

        Images and labels are permuted with the same index array, so every
        image keeps the label of its origin.

        Returns:
            ``(X, Y)`` where ``X`` has ``len(real_imgs) + len(fake_imgs)`` rows
            and ``Y`` is a float column vector of shape (len(X), 1).
        """
        X = np.concatenate((real_imgs, fake_imgs))
        Y = np.concatenate((np.ones((len(real_imgs), 1)),
                            np.zeros((len(fake_imgs), 1))))
        perm = np.random.permutation(len(X))
        return X[perm], Y[perm]

    def save_sampling(self, epoch):
        """Generate a 5x5 grid of samples and save it as ``<save_path>/<epoch>.png``.

        Creates ``save_path`` (including parents) if it does not exist.
        """
        row, col = 5, 5
        img_gens = self.genNet.predict(self._sample_noise(row * col))
        img_gens = img_gens.reshape((len(img_gens), self.img_row, self.img_col))

        fig, axs = plt.subplots(row, col)
        for ax, img in zip(axs.flat, img_gens):
            ax.imshow(img, cmap='gray')
            ax.axis('off')

        if not os.path.isdir(save_path):
            os.makedirs(save_path)
            print("make new", save_path, "path!")

        fig.savefig(os.path.join(save_path, "%d.png" % epoch))
        # close this specific figure so repeated calls don't accumulate figures
        plt.close(fig)

    def graph_summary(self, loss_dis, acc_dis, loss_gen, acc_gen):
        """Plot the four per-iteration training histories in a 2x2 grid.

        Args:
            loss_dis, acc_dis: discriminator loss / accuracy per iteration.
            loss_gen, acc_gen: combined-model (generator) loss / accuracy per iteration.
        """
        panels = (
            (loss_dis, "discriminator loss"),
            (acc_dis, "discriminator accuracy"),
            (loss_gen, "generator loss"),
            (acc_gen, "generator accuracy"),
        )
        fig, axs = plt.subplots(2, 2)
        # draw on each Axes object instead of the pyplot "current axes",
        # which would put every series on the last subplot
        for ax, (values, title) in zip(axs.flat, panels):
            ax.scatter(np.arange(len(values)), values)
            ax.set_title(title)
            ax.set_xlabel("iteration")
        fig.tight_layout()

        plt.show()

    def train_model(self, epochs, batch_size=128, trainGen_interval=5, sample_interval=50):
        """Alternately train the discriminator and the generator on MNIST.

        Each iteration (called "epoch" here; it is one batch, not a full pass
        over the data) does the following:

        1. Train ``disNet`` on ``half_batch`` randomly drawn real images
           (label 1), shuffled together with ``half_batch`` generated images
           (label 0).
        2. Train ``modelNet`` (discriminator frozen) on fresh noise with label
           1. This rewards the generator for images that disNet scores as real.

        Args:
            epochs: number of training iterations.
            batch_size: total discriminator batch, split evenly between real
                and generated images (at least one of each).
            trainGen_interval: currently unused; kept for interface compatibility.
            sample_interval: save a sample grid every ``sample_interval``
                iterations, starting at iteration 0. 0/None disables sampling.

        After training, the loss/accuracy histories are plotted with graph_summary.
        """
        half_batch = max(1, batch_size // 2)

        # load dataset (training split only; labels are not needed)
        (real_pool, _), (_, _) = mnist.load_data()

        # rescale to [0, 1] & add a trailing channel axis
        real_pool = real_pool / 255.
        real_pool = np.expand_dims(real_pool, axis=3)

        # target for the generator step: "these generated images are real"
        valid = np.ones((half_batch, 1))

        loss_dis = []
        acc_dis = []
        loss_gen = []
        acc_gen = []

        for epoch in range(epochs):
            # fresh random real images each iteration (not always the first few)
            idx = np.random.randint(0, len(real_pool), half_batch)
            real_imgs = real_pool[idx]
            fake_imgs = self.genNet.predict(self._sample_noise(half_batch))

            # train the discriminator net on correctly-labelled mixed samples
            X_train, Y_train = self._mix_batch(real_imgs, fake_imgs)
            l, a = self.disNet.train_on_batch(X_train, Y_train)
            loss_dis.append(l)
            acc_dis.append(a)

            # train the whole net with the discriminator frozen
            l, a = self.modelNet.train_on_batch(self._sample_noise(half_batch), valid)
            loss_gen.append(l)
            acc_gen.append(a)

            if sample_interval and epoch % sample_interval == 0:
                self.save_sampling(epoch)

        self.graph_summary(loss_dis, acc_dis, loss_gen, acc_gen)

In [None]:
if __name__ == "__main__":
    # Build the GAN and train it on MNIST for 2000 iterations; each
    # discriminator batch of 32 is split into 16 real and 16 generated images.
    gan = GAN()
    gan.train_model(epochs=2000, batch_size=32, sample_interval=5)