In [6]:
import numpy as np

# from IPython.core.debugger import Tracer

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam

# import matplotlib as plt
import matplotlib.pyplot as plt 
# plt.switch_backend('agg')   # allows code to run without a system DISPLAY

In [7]:
class GAN(object):
    """Fully connected GAN that learns to generate ``(width, height, channels)`` images.

    Three Keras models are built:
      * ``G`` -- generator: Gaussian noise of length ``latent_dim`` -> image (tanh output).
      * ``D`` -- discriminator: image -> sigmoid score (1 = real, 0 = generated).
      * ``stacked_generator_discriminator`` -- noise -> G -> frozen D, used to train G.
    """

    def __init__(self, width=28, height=28, channels=1, latent_dim=100):
        """Build and compile G, D and the stacked model (prints each model summary).

        Args:
            width, height, channels: shape of one image; generator output and
                discriminator input are both ``(width, height, channels)``.
            latent_dim: length of the noise vector fed to the generator
                (defaults to 100, the previously hard-coded value).
        """
        self.width = width
        self.height = height
        self.channels = channels
        self.latent_dim = latent_dim

        self.shape = (self.width, self.height, self.channels)

        # Kept as a public attribute for backward compatibility; used by D.
        self.optimizer = self._new_optimizer()

        # Every model gets its own optimizer instance. Sharing one Adam across
        # models couples them through its step counter (which drives `decay`
        # and bias correction), and newer tf.keras optimizers refuse to update
        # variables they were not originally built for.
        self.G = self.__generator()
        # G is only trained through the stacked model; compiling it keeps
        # self.G usable on its own.
        self.G.compile(loss='binary_crossentropy', optimizer=self._new_optimizer())

        self.D = self.__discriminator()
        self.D.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])

        self.stacked_generator_discriminator = self.__stacked_generator_discriminator()
        self.stacked_generator_discriminator.compile(loss='binary_crossentropy', optimizer=self._new_optimizer())

    def _new_optimizer(self):
        """Return a fresh Adam optimizer with this GAN's hyper-parameters."""
        return Adam(lr=0.0002, beta_1=0.5, decay=8e-8)

    def __generator(self):
        """Build G: three Dense -> LeakyReLU -> BatchNorm stages, then a tanh image layer."""
        model = Sequential()
        model.add(Dense(256, input_shape=(self.latent_dim,)))
        model.add(LeakyReLU(alpha=0.2))  # LeakyReLU activation
        model.add(BatchNormalization(momentum=0.8))  # BatchNormalization to stabilise training
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        # tanh keeps generated pixels in [-1, 1].
        model.add(Dense(self.width * self.height * self.channels, activation='tanh'))
        model.add(Reshape(self.shape))
        model.summary()

        return model

    def __discriminator(self):
        """Build D: flatten the image, two LeakyReLU Dense layers, sigmoid output."""
        flat_dim = self.width * self.height * self.channels

        model = Sequential()
        model.add(Flatten(input_shape=self.shape))
        model.add(Dense(flat_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(flat_dim // 2))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        return model

    def __stacked_generator_discriminator(self):
        """Chain G -> D with D frozen, so training this model updates only G.

        D was already compiled (with trainable=True) before it is frozen here.
        Keras captures `trainable` at compile time, so D.train_on_batch keeps
        updating D while this stacked model leaves D's weights untouched.
        """
        self.D.trainable = False

        model = Sequential()
        model.add(self.G)
        model.add(self.D)

        return model

    def train(self, X_train, epochs=30000, batch = 32, save_interval = 100):
        """Alternate one discriminator update and one generator update per iteration.

        Args:
            X_train: ndarray of real images indexed along axis 0; each image must
                be reshapeable to ``self.shape``. Values should lie in [-1, 1]
                to match the generator's tanh output.
            epochs: number of training *iterations*. Despite the name, each
                iteration processes one mini-batch.
            batch: generator batch size. The discriminator sees ``batch // 2``
                real plus ``batch // 2`` generated images.
            save_interval: every ``save_interval`` iterations, starting at 0,
                the method prints the losses, saves the stacked model to
                ``model.h5`` and writes a sample grid via ``plot_images``.

        Raises:
            ValueError: if ``batch < 2``, ``save_interval < 1`` or ``X_train``
                is empty.
        """
        half_batch = batch // 2
        if half_batch < 1:
            raise ValueError("batch must be >= 2, got %r" % (batch,))
        if save_interval < 1:
            raise ValueError("save_interval must be >= 1, got %r" % (save_interval,))
        if len(X_train) == 0:
            raise ValueError("X_train must contain at least one image")

        # Labels never change between iterations, so build them once:
        # real images -> 1, generated images -> 0.
        y_discriminator = np.concatenate((np.ones((half_batch, 1)), np.zeros((half_batch, 1))))
        # G is trained to make the frozen D answer "real" (1) for generated images.
        y_generator = np.ones((batch, 1))

        for cnt in range(epochs):

            ## train discriminator
            # Sample real images i.i.d. with replacement. randint's `high` is
            # exclusive, so every index of X_train is reachable.
            real_idx = np.random.randint(0, len(X_train), half_batch)
            legit_images = X_train[real_idx].reshape(half_batch, self.width, self.height, self.channels)

            gen_noise = np.random.normal(0, 1, (half_batch, self.latent_dim))
            synthetic_images = self.G.predict(gen_noise)

            x_combined_batch = np.concatenate((legit_images, synthetic_images))
            d_loss = self.D.train_on_batch(x_combined_batch, y_discriminator)

            ## train generator
            noise = np.random.normal(0, 1, (batch, self.latent_dim))  # Gaussian latent noise
            g_loss = self.stacked_generator_discriminator.train_on_batch(noise, y_generator)

            if cnt % save_interval == 0:
                # d_loss is [loss, accuracy] because D was compiled with metrics=['accuracy'].
                print ('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f]' % (cnt, d_loss[0], g_loss))
                self.stacked_generator_discriminator.save("model.h5")
                self.plot_images(save2file=True, step=cnt)

    def plot_images(self, save2file=False, samples=16, step=0):
        """Generate ``samples`` images with G and draw them in a square grid.

        Args:
            save2file: if True, write the figure to ``mnist_<step zero-padded
                to 5>.png`` and close it. Otherwise, display it with ``plt.show()``.
            samples: number of images to draw. The grid is ``ceil(sqrt(samples))``
                cells on each side, which gives 4x4 for the default 16.
            step: training iteration; used only in the output filename.
        """
        filename = "mnist_%s.png" % str(step).zfill(5)
        noise = np.random.normal(0, 1, (samples, self.latent_dim))

        images = self.G.predict(noise)

        grid = int(np.ceil(np.sqrt(samples)))
        fig = plt.figure()
        for i in range(images.shape[0]):
            ax = fig.add_subplot(grid, grid, i + 1)
            image = images[i]
            if self.channels == 1:
                # Drop the channel axis without reinterpreting the memory layout:
                # (width, height, 1) -> (width, height).
                ax.imshow(image[:, :, 0], cmap='gray')
            else:
                # Colour images: map tanh output from [-1, 1] to imshow's [0, 1].
                ax.imshow((image + 1) / 2)
            ax.axis('off')
        fig.tight_layout()

        if save2file:
            fig.savefig(filename)
            plt.close(fig)
        else:
            plt.show()

In [None]:
if __name__ == '__main__':
    # Only the training images are needed; labels and the test split are discarded.
    (train_images, _), (_, _) = mnist.load_data()

    # (x - 127.5) / 127.5 maps pixel values 0..255 onto [-1, 1], the range of
    # the generator's tanh output (assumes uint8 pixels). A channel axis is
    # then inserted at position 3 so each image matches GAN().shape == (28, 28, 1).
    scaled_images = (train_images.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(scaled_images, axis=3)

    gan = GAN()
    gan.train(X_train)

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_7 (Dense)              (None, 256)               25856     
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 256)               1024      
_________________________________________________________________
dense_8 (Dense)              (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_4 (Batch (None, 512)               2048      
_________________________________________________________________
dense_9 (Dense)              (None, 1024)             

In [None]:
# Show the working directory, then bundle every generated sample image (*.png) into pic.tar.
!pwd
!tar cvf pic.tar *.png

/content
mnist_00000.png
mnist_00100.png
mnist_00200.png
mnist_00300.png
mnist_00400.png
mnist_00500.png
mnist_00600.png
mnist_00700.png
mnist_00800.png
mnist_00900.png
mnist_01000.png
mnist_01100.png
mnist_01200.png
mnist_01300.png
mnist_01400.png
mnist_01500.png
mnist_01600.png
mnist_01700.png
mnist_01800.png
mnist_01900.png
mnist_02000.png
mnist_02100.png
mnist_02200.png
mnist_02300.png
mnist_02400.png
mnist_02500.png
mnist_02600.png
mnist_02700.png
mnist_02800.png
mnist_02900.png
mnist_03000.png
mnist_03100.png
mnist_03200.png
mnist_03300.png
mnist_03400.png
mnist_03500.png
mnist_03600.png
mnist_03700.png
mnist_03800.png
mnist_03900.png
mnist_04000.png
mnist_04100.png
mnist_04200.png
mnist_04300.png
mnist_04400.png
mnist_04500.png
mnist_04600.png
mnist_04700.png
mnist_04800.png
mnist_04900.png
mnist_05000.png
mnist_05100.png
mnist_05200.png
mnist_05300.png
mnist_05400.png
mnist_05500.png
mnist_05600.png
mnist_05700.png
mnist_05800.png
mnist_05900.png
mnist_06000.png
mnist_06100.png

In [None]:
# Confirm that the archive created above exists.
!ls pic.tar

pic.tar
