In [1]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU, UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt

# Create the folder that receives the generated sample images.
# exist_ok=True makes this a no-op when the folder already exists, without the
# check-then-create race of a separate os.path.exists() test.
import os
os.makedirs("./gan_images", exist_ok=True)

# Seed NumPy and TensorFlow so that runs are more reproducible
# (seeding alone does not guarantee full determinism).
np.random.seed(3)
tf.random.set_seed(3)

In [2]:
# Build the generator: a 100-dim noise vector -> one 28x28x1 image with tanh output in [-1, 1].
generator = Sequential([
    # Project the noise vector and reshape it into a 7x7 map with 128 channels.
    Dense(128*7*7, input_dim=100, activation=LeakyReLU(0.2)),
    BatchNormalization(),
    Reshape((7, 7, 128)),
    # First upsampling stage (7x7 -> 14x14) followed by a 64-filter convolution.
    UpSampling2D(),
    Conv2D(64, kernel_size=5, padding='same'),
    BatchNormalization(),
    Activation(LeakyReLU(0.2)),
    # Second upsampling stage (14x14 -> 28x28) down to a single tanh channel.
    UpSampling2D(),
    Conv2D(1, kernel_size=5, padding='same', activation='tanh'),
])

In [3]:
# Build the discriminator: a 28x28x1 image -> a single sigmoid score (1 = real, 0 = fake).
discriminator = Sequential([
    Conv2D(64, kernel_size=5, strides=2, input_shape=(28,28,1), padding="same"),
    Activation(LeakyReLU(0.2)),
    Dropout(0.3),
    Conv2D(128, kernel_size=5, strides=2, padding="same"),
    Activation(LeakyReLU(0.2)),
    Dropout(0.3),
    Flatten(),
    Dense(1, activation='sigmoid'),
])
# Order matters here: compile while still trainable, so discriminator.train_on_batch
# keeps updating these weights; then freeze, so the combined `gan` model (compiled
# later) treats the discriminator as fixed and only trains the generator.
discriminator.compile(loss='binary_crossentropy', optimizer='adam')
discriminator.trainable = False

In [4]:
# Chain generator -> discriminator into one model: noise in, real/fake score out.
# The discriminator was frozen before this compile, so training `gan` only moves the generator.
ginput = Input(shape=(100,))
dis_output = discriminator(generator(ginput))
gan = Model(inputs=ginput, outputs=dis_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')
gan.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
sequential (Sequential)      (None, 28, 28, 1)         865281    
_________________________________________________________________
sequential_1 (Sequential)    (None, 1)                 212865    
Total params: 1,078,146
Trainable params: 852,609
Non-trainable params: 225,537
_________________________________________________________________


In [None]:
#신경망을 실행시키는 함수를 만듭니다.
def gan_train(epoch, batch_size, saving_interval):

  # MNIST 데이터 불러오기

  (X_train, _), (_, _) = mnist.load_data()  # 앞서 불러온 적 있는 MNIST를 다시 이용합니다. 단, 테스트과정은 필요없고 이미지만 사용할 것이기 때문에 X_train만 불러왔습니다.
  X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
  X_train = (X_train - 127.5) / 127.5  # 픽셀값은 0에서 255사이의 값입니다. 이전에 255로 나누어 줄때는 이를 0~1사이의 값으로 바꾸었던 것인데, 여기서는 127.5를 빼준 뒤 127.5로 나누어 줌으로 인해 -1에서 1사이의 값으로 바뀌게 됩니다.
  #X_train.shape, Y_train.shape, X_test.shape, Y_test.shape

  true = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))

  for i in range(epoch):
          # 실제 데이터를 판별자에 입력하는 부분입니다.
          idx = np.random.randint(0, X_train.shape[0], batch_size)
          imgs = X_train[idx]
          d_loss_real = discriminator.train_on_batch(imgs, true)

          #가상 이미지를 판별자에 입력하는 부분입니다.
          noise = np.random.normal(0, 1, (batch_size, 100))
          gen_imgs = generator.predict(noise)
          d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)

          #판별자와 생성자의 오차를 계산합니다.
          d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
          g_loss = gan.train_on_batch(noise, true)

          print('epoch:%d' % i, ' d_loss:%.4f' % d_loss, ' g_loss:%.4f' % g_loss)

        # 이부분은 중간 과정을 이미지로 저장해 주는 부분입니다. 본 장의 주요 내용과 관련이 없어
        # 소스코드만 첨부합니다. 만들어진 이미지들은 gan_images 폴더에 저장됩니다.
          if i % saving_interval == 0:
              #r, c = 5, 5
              noise = np.random.normal(0, 1, (25, 100))
              gen_imgs = generator.predict(noise)

              # Rescale images 0 - 1
              gen_imgs = 0.5 * gen_imgs + 0.5

              fig, axs = plt.subplots(5, 5)
              count = 0
              for j in range(5):
                  for k in range(5):
                      axs[j, k].imshow(gen_imgs[count, :, :, 0], cmap='gray')
                      axs[j, k].axis('off')
                      count += 1
              fig.savefig("gan_images/gan_mnist_%d.png" % i)

# 4001 iterations (so iteration 4000 is included), batches of 32, a sample grid saved every 200 iterations.
gan_train(epoch=4001, batch_size=32, saving_interval=200)

epoch:0  d_loss:0.7053  g_loss:0.6912
epoch:1  d_loss:0.4617  g_loss:0.3278
epoch:2  d_loss:0.5693  g_loss:0.1024
epoch:3  d_loss:0.6593  g_loss:0.0837
epoch:4  d_loss:0.5770  g_loss:0.1641
epoch:5  d_loss:0.5092  g_loss:0.4206
epoch:6  d_loss:0.4873  g_loss:0.6972
epoch:7  d_loss:0.5090  g_loss:0.8721
epoch:8  d_loss:0.4976  g_loss:0.8954
epoch:9  d_loss:0.4877  g_loss:0.8695
epoch:10  d_loss:0.4186  g_loss:0.7792
epoch:11  d_loss:0.4463  g_loss:0.8067
epoch:12  d_loss:0.4999  g_loss:0.8511
epoch:13  d_loss:0.4883  g_loss:0.8900
epoch:14  d_loss:0.5858  g_loss:0.9510
epoch:15  d_loss:0.4796  g_loss:1.0425
epoch:16  d_loss:0.4165  g_loss:1.0446
epoch:17  d_loss:0.6310  g_loss:0.8198
epoch:18  d_loss:0.5576  g_loss:0.6647
epoch:19  d_loss:0.4578  g_loss:0.7245
epoch:20  d_loss:0.4308  g_loss:0.7659
epoch:21  d_loss:0.5038  g_loss:0.7922
epoch:22  d_loss:0.5036  g_loss:0.6539
epoch:23  d_loss:0.4523  g_loss:0.6055
epoch:24  d_loss:0.4401  g_loss:0.6659
epoch:25  d_loss:0.4545  g_loss:0.6

epoch:208  d_loss:0.4659  g_loss:2.0016
epoch:209  d_loss:0.2614  g_loss:2.0860
epoch:210  d_loss:0.4314  g_loss:2.1297
epoch:211  d_loss:0.4489  g_loss:1.7915
epoch:212  d_loss:0.4381  g_loss:1.8466
epoch:213  d_loss:0.5164  g_loss:2.1236
epoch:214  d_loss:0.4625  g_loss:1.9765
epoch:215  d_loss:0.4179  g_loss:1.7598
epoch:216  d_loss:0.6044  g_loss:1.8791
epoch:217  d_loss:0.4373  g_loss:1.6120
epoch:218  d_loss:0.4465  g_loss:1.6686
epoch:219  d_loss:0.5025  g_loss:1.5804
epoch:220  d_loss:0.4757  g_loss:1.5960
epoch:221  d_loss:0.7635  g_loss:1.3110
epoch:222  d_loss:0.5012  g_loss:1.4184
epoch:223  d_loss:0.5141  g_loss:1.9045
epoch:224  d_loss:0.6441  g_loss:1.6723
epoch:225  d_loss:0.5635  g_loss:1.5432
epoch:226  d_loss:0.4989  g_loss:1.7377
epoch:227  d_loss:0.6532  g_loss:1.3154
epoch:228  d_loss:0.5450  g_loss:1.5431
epoch:229  d_loss:0.5180  g_loss:1.2782
epoch:230  d_loss:0.6460  g_loss:1.7234
epoch:231  d_loss:0.6732  g_loss:1.9067
epoch:232  d_loss:0.7757  g_loss:1.7305


epoch:413  d_loss:0.3704  g_loss:1.8495
epoch:414  d_loss:0.5720  g_loss:1.5808
epoch:415  d_loss:0.3747  g_loss:1.8245
epoch:416  d_loss:0.4241  g_loss:1.9611
epoch:417  d_loss:0.3082  g_loss:2.3366
epoch:418  d_loss:0.3547  g_loss:2.6584
epoch:419  d_loss:0.3455  g_loss:2.3312
epoch:420  d_loss:0.2645  g_loss:2.2580
epoch:421  d_loss:0.4537  g_loss:2.0253
epoch:422  d_loss:0.3413  g_loss:2.3164
epoch:423  d_loss:0.4410  g_loss:2.3725
epoch:424  d_loss:0.3902  g_loss:2.5719
epoch:425  d_loss:0.3239  g_loss:2.0420
epoch:426  d_loss:0.2851  g_loss:2.1722
epoch:427  d_loss:0.3483  g_loss:2.1144
epoch:428  d_loss:0.3561  g_loss:2.2115
epoch:429  d_loss:0.3350  g_loss:2.8098
epoch:430  d_loss:0.3113  g_loss:2.9677
epoch:431  d_loss:0.2834  g_loss:2.5672
epoch:432  d_loss:0.2543  g_loss:2.9672
epoch:433  d_loss:0.3603  g_loss:2.6056
epoch:434  d_loss:0.4121  g_loss:2.4312
epoch:435  d_loss:0.5847  g_loss:2.1842
epoch:436  d_loss:0.2034  g_loss:2.7992
epoch:437  d_loss:0.3888  g_loss:2.9874
