In [None]:
#-*- coding: utf-8 -*-

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU, UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt

# Create the folder the sample images are saved into (no-op if it exists).
import os
# exist_ok=True avoids the check-then-create race of os.path.exists()+makedirs,
# and fails immediately (FileExistsError) if a plain *file* named gan_images
# is in the way, instead of failing later inside fig.savefig().
os.makedirs("./gan_images", exist_ok=True)

# Fix the NumPy and TensorFlow seeds so runs are repeatable
# (GPU kernels may still introduce some nondeterminism).
np.random.seed(3)
tf.random.set_seed(3)

# Generator: maps a 100-dim noise vector to a 28x28x1 image.
# Passing the layer list to Sequential adds them one by one, in this order.
generator = Sequential([
    # Project the noise, then reshape it into a 7x7 map with 128 channels.
    Dense(128 * 7 * 7, input_dim=100, activation=LeakyReLU(0.2)),
    BatchNormalization(),
    Reshape((7, 7, 128)),
    # Upsample 7x7 -> 14x14.
    UpSampling2D(),
    Conv2D(64, kernel_size=5, padding='same'),
    BatchNormalization(),
    Activation(LeakyReLU(0.2)),
    # Upsample 14x14 -> 28x28; tanh puts pixel values in [-1, 1].
    UpSampling2D(),
    Conv2D(1, kernel_size=5, padding='same', activation='tanh'),
])

# Discriminator: scores a 28x28x1 image; training uses target 1 for real
# MNIST images and 0 for generated ones.
discriminator = Sequential([
    Conv2D(64, kernel_size=5, strides=2, input_shape=(28, 28, 1), padding="same"),
    Activation(LeakyReLU(0.2)),
    Dropout(0.3),
    Conv2D(128, kernel_size=5, strides=2, padding="same"),
    Activation(LeakyReLU(0.2)),
    Dropout(0.3),
    Flatten(),
    Dense(1, activation='sigmoid'),
])
discriminator.compile(optimizer='adam', loss='binary_crossentropy')
# Freeze only AFTER compiling: Keras keeps the trainable state captured at
# compile time for this model's own train_on_batch, so the discriminator still
# learns when trained directly, while any model compiled afterwards (the
# stacked `gan`) treats its weights as frozen.
discriminator.trainable = False

# Stacked GAN model: noise -> generator -> (frozen) discriminator -> score.
# Because the discriminator was frozen before this compile, training `gan`
# updates only the generator's weights.
ginput = Input((100,))
dis_output = discriminator(generator(ginput))
gan = Model(inputs=ginput, outputs=dis_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')
gan.summary()

# Training loop: alternately update the discriminator and the generator.
def _load_scaled_mnist():
    """Return the MNIST training images as float32, shape (N, 28, 28, 1), in [-1, 1]."""
    # Only the training images are needed; labels and the test split are unused.
    (x_train, _), (_, _) = mnist.load_data()
    x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
    # Pixels are 0..255; (x - 127.5) / 127.5 maps them to [-1, 1], the same
    # range as the generator's tanh output.
    return (x_train - 127.5) / 127.5


def _save_sample_grid(iteration):
    """Save a 5x5 grid of generator samples to gan_images/gan_mnist_<iteration>.png."""
    noise = np.random.normal(0, 1, (25, 100))
    gen_imgs = generator.predict(noise, verbose=0)
    # Map the tanh output from [-1, 1] back to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(5, 5)
    # axs.flat walks the grid row by row, the same order as a nested j/k loop.
    for count, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[count, :, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig("gan_images/gan_mnist_%d.png" % iteration)
    # Release the figure; otherwise each call leaks one open figure, and
    # pyplot warns once more than 20 are open (figure.max_open_warning).
    plt.close(fig)


def gan_train(epoch, batch_size, saving_interval):
    """Train the module-level ``generator``/``discriminator`` pair on MNIST.

    Despite its name, ``epoch`` counts iterations: each iteration draws ONE
    random mini-batch rather than passing over the whole dataset. Per
    iteration:

    1. train ``discriminator`` on ``batch_size`` real images with target 1;
    2. train ``discriminator`` on ``batch_size`` generated images with target 0;
    3. train the stacked ``gan`` model (discriminator frozen) with target 1 on
       the same noise, pushing the generator to fool the discriminator.

    Args:
        epoch: number of iterations to run (iterations ``0 .. epoch - 1``).
        batch_size: number of real images, and of fake images, per iteration.
        saving_interval: every ``saving_interval`` iterations, starting at
            iteration 0, a 5x5 sample grid is written to ``gan_images/``.
            A falsy value (0 or None) disables saving instead of raising
            ZeroDivisionError.
    """
    x_train = _load_scaled_mnist()

    real_labels = np.ones((batch_size, 1))
    fake_labels = np.zeros((batch_size, 1))

    for i in range(epoch):
        # Real images -> discriminator.
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        d_loss_real = discriminator.train_on_batch(x_train[idx], real_labels)

        # Generated images -> discriminator. verbose=0 keeps newer Keras
        # versions from printing a progress bar on every predict() call.
        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise, verbose=0)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_labels)

        # Discriminator loss = mean of the real and fake halves; the generator
        # is trained through `gan` by asking for "real" on generated images.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        g_loss = gan.train_on_batch(noise, real_labels)

        print('epoch:%d' % i, ' d_loss:%.4f' % d_loss, ' g_loss:%.4f' % g_loss)

        # Periodically save sample images to monitor training progress.
        if saving_interval and i % saving_interval == 0:
            _save_sample_grid(i)

# 4001 iterations (0..4000 inclusive, so iteration 4000 is also saved),
# batch size 32, sample images saved every 200 iterations.
gan_train(epoch=4001, batch_size=32, saving_interval=200)


In C:\Users\A\Anaconda3\lib\site-packages\matplotlib\mpl-data\stylelib\_classic_test.mplstyle: 
The text.latex.unicode rcparam was deprecated in Matplotlib 3.0 and will be removed in 3.2.
In C:\Users\A\Anaconda3\lib\site-packages\matplotlib\mpl-data\stylelib\_classic_test.mplstyle: 
The savefig.frameon rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.
In C:\Users\A\Anaconda3\lib\site-packages\matplotlib\mpl-data\stylelib\_classic_test.mplstyle: 
The pgf.debug rcparam was deprecated in Matplotlib 3.0 and will be removed in 3.2.
In C:\Users\A\Anaconda3\lib\site-packages\matplotlib\mpl-data\stylelib\_classic_test.mplstyle: 
The verbose.level rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.
In C:\Users\A\Anaconda3\lib\site-packages\matplotlib\mpl-data\stylelib\_classic_test.mplstyle: 
The verbose.fileo rcparam was deprecated in Matplotlib 3.1 and will be removed in 3.3.


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
sequential (Sequential)      (None, 28, 28, 1)         865281    
_________________________________________________________________
sequential_1 (Sequential)    (None, 1)                 212865    
Total params: 1,078,146
Trainable params: 852,609
Non-trainable params: 225,537
_________________________________________________________________
epoch:0  d_loss:0.7053  g_loss:0.6912
epoch:1  d_loss:0.4624  g_loss:0.3282
epoch:2  d_loss:0.5686  g_loss:0.1025
epoch:3  d_loss:0.6585  g_loss:0.0836
epoch:4  d_loss:0.5769  g_loss:0.1651
epoch:5  d_loss:0.5088  g_loss:0.4211
epoch:6  d_loss:0.4875  g_loss:0.6973
epoch:7  d_loss:0.5091  g_loss:0.8700
epoch:8  d_loss:0.4967  g_loss:0.8962
epoch:9  d_loss:0.4865  g_loss:0.87

epoch:189  d_loss:0.4662  g_loss:1.2892
epoch:190  d_loss:0.4029  g_loss:1.8643
epoch:191  d_loss:0.5575  g_loss:1.7463
epoch:192  d_loss:0.5555  g_loss:1.6062
epoch:193  d_loss:0.5942  g_loss:1.3192
epoch:194  d_loss:0.5378  g_loss:1.2569
epoch:195  d_loss:0.5349  g_loss:1.1366
epoch:196  d_loss:0.6193  g_loss:1.1818
epoch:197  d_loss:0.3990  g_loss:1.6045
epoch:198  d_loss:0.5498  g_loss:1.7734
epoch:199  d_loss:0.6170  g_loss:1.5667
epoch:200  d_loss:0.5853  g_loss:1.3204
epoch:201  d_loss:0.5179  g_loss:1.1256
epoch:202  d_loss:0.6931  g_loss:1.1956
epoch:203  d_loss:0.3751  g_loss:1.9254
epoch:204  d_loss:0.5199  g_loss:1.6857
epoch:205  d_loss:0.5035  g_loss:1.6165
epoch:206  d_loss:0.5111  g_loss:1.8119
epoch:207  d_loss:0.6063  g_loss:1.4648
epoch:208  d_loss:0.6251  g_loss:1.6088
epoch:209  d_loss:0.4465  g_loss:1.4735
epoch:210  d_loss:0.5211  g_loss:1.6889
epoch:211  d_loss:0.4916  g_loss:1.9064
epoch:212  d_loss:0.4926  g_loss:1.7887
epoch:213  d_loss:0.6706  g_loss:1.5994


epoch:394  d_loss:0.5178  g_loss:1.4697
epoch:395  d_loss:0.4546  g_loss:2.3090
epoch:396  d_loss:0.5045  g_loss:2.0336
epoch:397  d_loss:0.5008  g_loss:2.3218
epoch:398  d_loss:0.6389  g_loss:1.3246
epoch:399  d_loss:0.6280  g_loss:1.2662
epoch:400  d_loss:0.4974  g_loss:1.2156
epoch:401  d_loss:0.3518  g_loss:1.6166
epoch:402  d_loss:0.4543  g_loss:1.7706
epoch:403  d_loss:0.4132  g_loss:2.0090
epoch:404  d_loss:0.4532  g_loss:1.7882
epoch:405  d_loss:0.4102  g_loss:1.6244
epoch:406  d_loss:0.3922  g_loss:1.3669
epoch:407  d_loss:0.4105  g_loss:1.7943
epoch:408  d_loss:0.3884  g_loss:2.1429
epoch:409  d_loss:0.3125  g_loss:1.9239
epoch:410  d_loss:0.3952  g_loss:2.3543
epoch:411  d_loss:0.3586  g_loss:2.1128
epoch:412  d_loss:0.3397  g_loss:1.9871
epoch:413  d_loss:0.4289  g_loss:1.7101
epoch:414  d_loss:0.4648  g_loss:1.3635
epoch:415  d_loss:0.4161  g_loss:1.9510
epoch:416  d_loss:0.4219  g_loss:1.9064
epoch:417  d_loss:0.2840  g_loss:2.2133
epoch:418  d_loss:0.3534  g_loss:1.9230


epoch:599  d_loss:0.5542  g_loss:1.3819
epoch:600  d_loss:0.6859  g_loss:2.0115
epoch:601  d_loss:0.5974  g_loss:2.6352
epoch:602  d_loss:0.9172  g_loss:2.0357
epoch:603  d_loss:0.5759  g_loss:1.9354
epoch:604  d_loss:0.6454  g_loss:1.6180
epoch:605  d_loss:0.3615  g_loss:1.8853
epoch:606  d_loss:0.5067  g_loss:1.4537
epoch:607  d_loss:0.3836  g_loss:2.3132
epoch:608  d_loss:0.4193  g_loss:2.2239
epoch:609  d_loss:0.3683  g_loss:2.6817
epoch:610  d_loss:0.4066  g_loss:2.1245
epoch:611  d_loss:0.2786  g_loss:2.2737
epoch:612  d_loss:0.2191  g_loss:2.5533
epoch:613  d_loss:0.3216  g_loss:2.0510
epoch:614  d_loss:0.2929  g_loss:2.2297
epoch:615  d_loss:0.2166  g_loss:2.2309
epoch:616  d_loss:0.1716  g_loss:2.1411
epoch:617  d_loss:0.4128  g_loss:2.3478
epoch:618  d_loss:0.3391  g_loss:2.4204
epoch:619  d_loss:0.2164  g_loss:3.2180
epoch:620  d_loss:0.2164  g_loss:3.3598
epoch:621  d_loss:0.3331  g_loss:2.9860
epoch:622  d_loss:0.3674  g_loss:2.6091
epoch:623  d_loss:0.3383  g_loss:2.1232


epoch:804  d_loss:0.3791  g_loss:2.4253
epoch:805  d_loss:0.5776  g_loss:2.1087
epoch:806  d_loss:0.5339  g_loss:2.2544
epoch:807  d_loss:0.4128  g_loss:1.8349
epoch:808  d_loss:0.5209  g_loss:1.4052
epoch:809  d_loss:0.5592  g_loss:1.8980
epoch:810  d_loss:0.5336  g_loss:1.6750
epoch:811  d_loss:0.6404  g_loss:1.9071
epoch:812  d_loss:0.6848  g_loss:1.6961
epoch:813  d_loss:0.5687  g_loss:1.4651
epoch:814  d_loss:0.6161  g_loss:1.2747
epoch:815  d_loss:0.6751  g_loss:1.4227
epoch:816  d_loss:0.5181  g_loss:1.8526
epoch:817  d_loss:0.3872  g_loss:1.8631
epoch:818  d_loss:0.4359  g_loss:2.1548
epoch:819  d_loss:0.5572  g_loss:1.9283
epoch:820  d_loss:0.4831  g_loss:1.4516
epoch:821  d_loss:0.4797  g_loss:1.4665
epoch:822  d_loss:0.4428  g_loss:1.4166
epoch:823  d_loss:0.5540  g_loss:1.5976
epoch:824  d_loss:0.4037  g_loss:2.0264
epoch:825  d_loss:0.3334  g_loss:2.1953
epoch:826  d_loss:0.4201  g_loss:2.1403
epoch:827  d_loss:0.3934  g_loss:2.0322
epoch:828  d_loss:0.3816  g_loss:2.2354
