In [None]:
#-*- coding: utf-8 -*-

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, LeakyReLU, UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt

# Create the folder that receives the generated sample images, if it is missing.
# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
import os
os.makedirs("./gan_images", exist_ok=True)

# Seed NumPy's and TensorFlow's global random generators so noise and batch
# sampling start from a fixed state on every run.
np.random.seed(3)
tf.random.set_seed(3)

# Build the generator: maps a 100-dim noise vector to a 28x28x1 image.
# A dense layer is reshaped to 7x7x128, then two (upsample + 5x5 conv) stages
# grow it to 14x14 and 28x28; the final tanh keeps pixels in [-1, 1].
generator = Sequential([
    Dense(128*7*7, input_dim=100, activation=LeakyReLU(0.2)),
    BatchNormalization(),
    Reshape((7, 7, 128)),
    UpSampling2D(),                                   # 7x7 -> 14x14
    Conv2D(64, kernel_size=5, padding='same'),
    BatchNormalization(),
    Activation(LeakyReLU(0.2)),
    UpSampling2D(),                                   # 14x14 -> 28x28
    Conv2D(1, kernel_size=5, padding='same', activation='tanh'),
])

# Build the discriminator: scores a 28x28x1 image with a single sigmoid unit
# (trained with label 1 for real images and 0 for generated ones).
# Two stride-2 5x5 conv stages, each followed by LeakyReLU and dropout.
discriminator = Sequential([
    Conv2D(64, kernel_size=5, strides=2, input_shape=(28,28,1), padding="same"),
    Activation(LeakyReLU(0.2)),
    Dropout(0.3),
    Conv2D(128, kernel_size=5, strides=2, padding="same"),
    Activation(LeakyReLU(0.2)),
    Dropout(0.3),
    Flatten(),
    Dense(1, activation='sigmoid'),
])
discriminator.compile(loss='binary_crossentropy', optimizer='adam')
# Freeze only AFTER compiling. Under tf.keras 2.x the trainable flag is captured
# at compile time, so discriminator.train_on_batch keeps updating its weights
# while the gan model compiled next treats them as frozen.
# NOTE(review): confirm this still holds on Keras 3, where trainable may be
# honoured dynamically.
discriminator.trainable = False

# Chain generator and discriminator into one model: noise -> image -> score.
# gan_train fits it on noise labelled 1 ("real"); only the generator's weights
# are trainable here.
ginput = Input(shape=(100,))
fake_image = generator(ginput)
dis_output = discriminator(fake_image)
gan = Model(inputs=ginput, outputs=dis_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')
gan.summary()

# Training loop for the GAN, plus a helper that saves sample grids to disk.
def _save_sample_images(step, rows=5, cols=5, latent_dim=100, out_dir="gan_images"):
    """Save a rows x cols grid of generator samples to ``{out_dir}/gan_mnist_{step}.png``.

    Noise is drawn from a standard normal distribution. Generator output is
    mapped from [-1, 1] (tanh) back to [0, 1] for display. The figure is closed
    after saving so repeated calls do not accumulate open matplotlib figures.
    """
    noise = np.random.normal(0, 1, (rows * cols, latent_dim))
    samples = generator.predict(noise, verbose=0)
    samples = 0.5 * samples + 0.5  # rescale [-1, 1] -> [0, 1]

    # squeeze=False always yields a 2-D array of Axes, even for a 1x1 grid.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    for ax, img in zip(axs.flat, samples):
        ax.imshow(img[:, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig(f"{out_dir}/gan_mnist_{step}.png")
    plt.close(fig)  # release the figure; otherwise every call leaks one


def gan_train(epoch, batch_size, saving_interval):
    """Train the module-level ``generator`` / ``discriminator`` / ``gan`` on MNIST.

    Each of the ``epoch`` iterations processes one mini-batch (not a full pass
    over the dataset) and performs:
      1. one discriminator update on ``batch_size`` real images labelled 1;
      2. one discriminator update on as many generated images labelled 0;
      3. one generator update through ``gan``, using the same noise labelled 1
         so the generator learns to fool the discriminator.

    Args:
        epoch: Number of iterations; the printed index runs 0 .. epoch-1.
        batch_size: Number of images per mini-batch.
        saving_interval: Save a sample grid every this many iterations,
            starting at iteration 0. A value of 0 disables saving (previously
            this raised ZeroDivisionError).
    """
    # Only the training images are needed; labels and the test split are unused.
    (X_train, _), (_, _) = mnist.load_data()
    X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
    # Map pixel values from [0, 255] to [-1, 1], matching the generator's tanh range.
    X_train = (X_train - 127.5) / 127.5

    true = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for i in range(epoch):
        # Discriminator step on real images (indices drawn with replacement).
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        d_loss_real = discriminator.train_on_batch(X_train[idx], true)

        # Discriminator step on generated images.
        # verbose=0 suppresses a per-call progress bar on newer tf.keras versions.
        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise, verbose=0)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)

        # Average the two discriminator losses, then update the generator via gan.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        g_loss = gan.train_on_batch(noise, true)

        print('epoch:%d' % i, ' d_loss:%.4f' % d_loss, ' g_loss:%.4f' % g_loss)

        # Periodically write a grid of samples into the gan_images folder.
        if saving_interval and i % saving_interval == 0:
            _save_sample_images(i)

# Run 4001 iterations (indices 0..4000, hence the +1) with 32-image batches,
# saving a grid of generated samples every 200 iterations.
gan_train(epoch=4001, batch_size=32, saving_interval=200)


Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 100)]             0         
_________________________________________________________________
sequential (Sequential)      (None, 28, 28, 1)         865281    
_________________________________________________________________
sequential_1 (Sequential)    (None, 1)                 212865    
Total params: 1,078,146
Trainable params: 852,609
Non-trainable params: 225,537
_________________________________________________________________
epoch:0  d_loss:0.7101  g_loss:0.6986
epoch:1  d_loss:0.4348  g_loss:0.3323
epoch:2  d_loss:0.5332  g_loss:0.1001
epoch:3  d_loss:0.7526  g_loss:0.0611
epoch:4  d_loss:0.6900  g_loss:0.1292
epoch:5  d_loss:0.5406  g_loss:0.3777
epoch:6  d_loss:0.4938  g_loss:0.7602
epoch:7  d_loss:0.5136  g_loss:0.9300
epoch:8  d_loss:0.5171  g_loss:1.0347
epoch:9  d_loss:0.4714  g_loss:0.98

epoch:189  d_loss:0.4637  g_loss:1.5764
epoch:190  d_loss:0.4224  g_loss:1.6979
epoch:191  d_loss:0.5082  g_loss:1.4673
epoch:192  d_loss:0.4191  g_loss:1.9144
epoch:193  d_loss:0.4829  g_loss:1.7104
epoch:194  d_loss:0.5730  g_loss:1.3834
epoch:195  d_loss:0.4788  g_loss:1.4562
epoch:196  d_loss:0.4230  g_loss:1.6893
epoch:197  d_loss:0.5470  g_loss:1.7690
epoch:198  d_loss:0.6082  g_loss:1.6508
epoch:199  d_loss:0.4954  g_loss:1.7604
epoch:200  d_loss:0.6700  g_loss:1.5292
epoch:201  d_loss:0.7016  g_loss:1.8097
epoch:202  d_loss:0.5170  g_loss:1.7146
epoch:203  d_loss:0.3877  g_loss:1.8057
epoch:204  d_loss:0.3659  g_loss:1.8382
epoch:205  d_loss:0.4335  g_loss:1.9323
epoch:206  d_loss:0.4016  g_loss:1.9957
epoch:207  d_loss:0.4534  g_loss:1.7417
epoch:208  d_loss:0.4842  g_loss:1.9819
epoch:209  d_loss:0.4537  g_loss:2.4148
epoch:210  d_loss:0.4209  g_loss:2.7010
epoch:211  d_loss:0.4199  g_loss:2.2708
epoch:212  d_loss:0.4943  g_loss:1.9082
epoch:213  d_loss:0.6050  g_loss:1.7624


epoch:394  d_loss:0.3253  g_loss:2.6947
epoch:395  d_loss:0.3020  g_loss:2.3101
epoch:396  d_loss:0.3444  g_loss:2.4914
epoch:397  d_loss:0.2350  g_loss:2.1339
epoch:398  d_loss:0.4001  g_loss:1.9881
epoch:399  d_loss:0.3565  g_loss:2.0522
epoch:400  d_loss:0.1958  g_loss:2.7600
epoch:401  d_loss:0.2936  g_loss:3.0516
epoch:402  d_loss:0.3371  g_loss:2.9259
epoch:403  d_loss:0.2753  g_loss:2.7366
epoch:404  d_loss:0.3394  g_loss:2.5247
epoch:405  d_loss:0.2851  g_loss:2.1581
epoch:406  d_loss:0.2967  g_loss:2.2373
epoch:407  d_loss:0.2988  g_loss:2.1534
epoch:408  d_loss:0.2096  g_loss:2.5564
epoch:409  d_loss:0.4301  g_loss:2.6206
epoch:410  d_loss:0.2855  g_loss:2.3584
epoch:411  d_loss:0.3827  g_loss:2.3540
epoch:412  d_loss:0.4157  g_loss:1.8347
epoch:413  d_loss:0.4139  g_loss:2.2096
epoch:414  d_loss:0.3987  g_loss:2.1810
epoch:415  d_loss:0.4739  g_loss:2.5484
epoch:416  d_loss:0.3712  g_loss:2.9468
epoch:417  d_loss:0.5261  g_loss:2.1291
epoch:418  d_loss:0.4104  g_loss:1.7595


epoch:599  d_loss:0.3823  g_loss:2.2019
epoch:600  d_loss:0.4252  g_loss:2.0761
epoch:601  d_loss:0.4706  g_loss:1.1524
epoch:602  d_loss:0.5723  g_loss:1.4273
epoch:603  d_loss:0.3842  g_loss:1.9712
epoch:604  d_loss:0.3747  g_loss:1.8383
epoch:605  d_loss:0.6018  g_loss:1.5016
epoch:606  d_loss:0.4642  g_loss:1.5166
epoch:607  d_loss:0.4373  g_loss:1.3455
epoch:608  d_loss:0.4681  g_loss:1.7327
epoch:609  d_loss:0.4287  g_loss:2.9110
epoch:610  d_loss:0.2708  g_loss:2.9007
epoch:611  d_loss:0.3373  g_loss:2.4249
epoch:612  d_loss:0.5315  g_loss:1.5955
epoch:613  d_loss:0.8984  g_loss:1.8191
epoch:614  d_loss:1.1900  g_loss:1.3283
epoch:615  d_loss:0.9430  g_loss:2.0688
epoch:616  d_loss:0.4376  g_loss:3.5174
epoch:617  d_loss:0.8031  g_loss:3.6323
epoch:618  d_loss:0.7345  g_loss:2.9927
epoch:619  d_loss:0.7787  g_loss:2.4889
epoch:620  d_loss:0.7621  g_loss:1.2560
epoch:621  d_loss:0.5095  g_loss:1.2415
epoch:622  d_loss:0.8562  g_loss:0.9358
epoch:623  d_loss:0.6231  g_loss:1.1426


epoch:804  d_loss:0.4436  g_loss:1.3748
epoch:805  d_loss:0.5602  g_loss:1.3375
epoch:806  d_loss:0.3797  g_loss:1.5228
epoch:807  d_loss:0.3954  g_loss:1.7928
epoch:808  d_loss:0.5096  g_loss:1.7117
epoch:809  d_loss:0.5043  g_loss:1.7685
epoch:810  d_loss:0.4319  g_loss:1.8627
epoch:811  d_loss:0.3595  g_loss:1.5751
epoch:812  d_loss:0.3971  g_loss:1.7138
epoch:813  d_loss:0.4426  g_loss:1.2824
epoch:814  d_loss:0.2865  g_loss:1.4140
epoch:815  d_loss:0.6511  g_loss:1.1255
epoch:816  d_loss:0.4986  g_loss:1.5061
epoch:817  d_loss:0.4229  g_loss:2.0103
epoch:818  d_loss:0.3622  g_loss:1.9417
epoch:819  d_loss:0.4283  g_loss:2.0144
epoch:820  d_loss:0.6665  g_loss:1.6832
epoch:821  d_loss:0.7525  g_loss:1.1647
epoch:822  d_loss:0.5217  g_loss:1.5240
epoch:823  d_loss:0.7001  g_loss:1.5896
epoch:824  d_loss:0.5321  g_loss:1.6300
epoch:825  d_loss:0.5699  g_loss:1.7717
epoch:826  d_loss:0.4373  g_loss:1.8551
epoch:827  d_loss:0.5855  g_loss:1.4746
epoch:828  d_loss:0.4836  g_loss:1.4230


epoch:1009  d_loss:0.2266  g_loss:2.7370
epoch:1010  d_loss:0.2962  g_loss:2.2410
epoch:1011  d_loss:0.4948  g_loss:1.8033
epoch:1012  d_loss:0.3142  g_loss:2.1941
epoch:1013  d_loss:0.2102  g_loss:2.4316
epoch:1014  d_loss:0.1837  g_loss:3.1537
epoch:1015  d_loss:0.2924  g_loss:2.2940
epoch:1016  d_loss:0.4004  g_loss:2.3916
epoch:1017  d_loss:0.2705  g_loss:2.3617
epoch:1018  d_loss:0.3943  g_loss:2.4322
epoch:1019  d_loss:0.3139  g_loss:2.0891
epoch:1020  d_loss:0.2492  g_loss:2.5429
epoch:1021  d_loss:0.2353  g_loss:2.6164
epoch:1022  d_loss:0.5050  g_loss:2.2631
epoch:1023  d_loss:0.4501  g_loss:1.9179
epoch:1024  d_loss:0.5178  g_loss:1.8362
epoch:1025  d_loss:0.5258  g_loss:2.1470
epoch:1026  d_loss:0.5985  g_loss:1.9094
epoch:1027  d_loss:0.3144  g_loss:2.2916
epoch:1028  d_loss:0.5138  g_loss:2.5313
epoch:1029  d_loss:0.4386  g_loss:2.2251
epoch:1030  d_loss:0.3217  g_loss:1.8217
epoch:1031  d_loss:0.4791  g_loss:1.7753
epoch:1032  d_loss:0.4332  g_loss:1.8236
epoch:1033  d_lo

epoch:1209  d_loss:0.5049  g_loss:1.5776
epoch:1210  d_loss:0.6392  g_loss:1.5348
epoch:1211  d_loss:0.5438  g_loss:1.6022
epoch:1212  d_loss:0.4860  g_loss:1.8699
epoch:1213  d_loss:0.5217  g_loss:1.8652
epoch:1214  d_loss:0.4663  g_loss:1.5815
epoch:1215  d_loss:0.3682  g_loss:1.5308
epoch:1216  d_loss:0.4385  g_loss:1.5041
epoch:1217  d_loss:0.4572  g_loss:1.5268
epoch:1218  d_loss:0.4133  g_loss:1.8472
epoch:1219  d_loss:0.2902  g_loss:2.0294
epoch:1220  d_loss:0.3002  g_loss:2.0677
epoch:1221  d_loss:0.3844  g_loss:1.7478
epoch:1222  d_loss:0.3008  g_loss:1.7936
epoch:1223  d_loss:0.4611  g_loss:1.6846
epoch:1224  d_loss:0.4195  g_loss:1.8165
epoch:1225  d_loss:0.3088  g_loss:2.1955
epoch:1226  d_loss:0.4084  g_loss:2.2698
epoch:1227  d_loss:0.3980  g_loss:2.2201
epoch:1228  d_loss:0.4659  g_loss:1.8287
epoch:1229  d_loss:0.5262  g_loss:1.3467
epoch:1230  d_loss:0.4003  g_loss:1.6707
epoch:1231  d_loss:0.4365  g_loss:2.3716
epoch:1232  d_loss:0.3510  g_loss:2.3765
epoch:1233  d_lo

epoch:1409  d_loss:0.3499  g_loss:1.7709
epoch:1410  d_loss:0.3978  g_loss:2.2200
epoch:1411  d_loss:0.4301  g_loss:1.9850
epoch:1412  d_loss:0.4681  g_loss:1.7129
epoch:1413  d_loss:0.4363  g_loss:1.7156
epoch:1414  d_loss:0.3134  g_loss:1.7839
epoch:1415  d_loss:0.4299  g_loss:1.8967
epoch:1416  d_loss:0.5847  g_loss:1.5771
epoch:1417  d_loss:0.3930  g_loss:1.5241
epoch:1418  d_loss:0.3452  g_loss:1.8161
epoch:1419  d_loss:0.5398  g_loss:1.7847
epoch:1420  d_loss:0.5720  g_loss:1.6520
epoch:1421  d_loss:0.5561  g_loss:1.6455
epoch:1422  d_loss:0.5453  g_loss:1.6730
epoch:1423  d_loss:0.4679  g_loss:1.5797
epoch:1424  d_loss:0.4793  g_loss:1.6036
epoch:1425  d_loss:0.4419  g_loss:1.5105
epoch:1426  d_loss:0.4272  g_loss:1.4716
epoch:1427  d_loss:0.5507  g_loss:1.4138
epoch:1428  d_loss:0.4040  g_loss:1.6970
epoch:1429  d_loss:0.4871  g_loss:1.3223
epoch:1430  d_loss:0.4772  g_loss:1.7493
epoch:1431  d_loss:0.5664  g_loss:1.4209
epoch:1432  d_loss:0.4281  g_loss:1.5434
epoch:1433  d_lo

epoch:1609  d_loss:0.3811  g_loss:1.7602
epoch:1610  d_loss:0.3968  g_loss:2.0627
epoch:1611  d_loss:0.4258  g_loss:1.8320
epoch:1612  d_loss:0.5030  g_loss:1.5498
epoch:1613  d_loss:0.5183  g_loss:1.6611
epoch:1614  d_loss:0.3818  g_loss:1.8019
epoch:1615  d_loss:0.4559  g_loss:1.6208
epoch:1616  d_loss:0.6024  g_loss:1.2699
epoch:1617  d_loss:0.4858  g_loss:1.7412
epoch:1618  d_loss:0.5532  g_loss:1.6089
epoch:1619  d_loss:0.3665  g_loss:2.0628
epoch:1620  d_loss:0.4437  g_loss:1.8588
epoch:1621  d_loss:0.4871  g_loss:1.6925
epoch:1622  d_loss:0.4047  g_loss:1.4784
epoch:1623  d_loss:0.4998  g_loss:1.4809
epoch:1624  d_loss:0.3896  g_loss:1.6503
epoch:1625  d_loss:0.4158  g_loss:1.5072
epoch:1626  d_loss:0.4408  g_loss:1.7270
epoch:1627  d_loss:0.3827  g_loss:2.2795
epoch:1628  d_loss:0.3937  g_loss:2.1273
epoch:1629  d_loss:0.4127  g_loss:2.1309
epoch:1630  d_loss:0.4087  g_loss:1.7754
epoch:1631  d_loss:0.3083  g_loss:1.7791
epoch:1632  d_loss:0.3699  g_loss:1.6536
epoch:1633  d_lo

epoch:1809  d_loss:0.4312  g_loss:1.5888
epoch:1810  d_loss:0.4173  g_loss:1.8721
epoch:1811  d_loss:0.4842  g_loss:1.9239
epoch:1812  d_loss:0.5501  g_loss:1.9668
epoch:1813  d_loss:0.4196  g_loss:1.4099
epoch:1814  d_loss:0.5736  g_loss:1.2355
epoch:1815  d_loss:0.4113  g_loss:1.2989
epoch:1816  d_loss:0.4290  g_loss:1.5104
epoch:1817  d_loss:0.5229  g_loss:1.9939
epoch:1818  d_loss:0.5374  g_loss:1.8793
epoch:1819  d_loss:0.5551  g_loss:1.7893
epoch:1820  d_loss:0.4149  g_loss:1.7911
epoch:1821  d_loss:0.4375  g_loss:1.6638
epoch:1822  d_loss:0.5832  g_loss:1.3747
epoch:1823  d_loss:0.4881  g_loss:1.4990
epoch:1824  d_loss:0.4010  g_loss:1.8659
epoch:1825  d_loss:0.5256  g_loss:1.8441
epoch:1826  d_loss:0.4901  g_loss:1.6802
epoch:1827  d_loss:0.5514  g_loss:1.9275
epoch:1828  d_loss:0.3828  g_loss:1.8208
epoch:1829  d_loss:0.5916  g_loss:1.3647
epoch:1830  d_loss:0.4953  g_loss:1.5119
epoch:1831  d_loss:0.6422  g_loss:1.3384
epoch:1832  d_loss:0.3978  g_loss:1.3047
epoch:1833  d_lo

epoch:2009  d_loss:0.4360  g_loss:1.5833
epoch:2010  d_loss:0.3797  g_loss:1.7310
epoch:2011  d_loss:0.3725  g_loss:1.5902
epoch:2012  d_loss:0.4498  g_loss:1.7197
epoch:2013  d_loss:0.4014  g_loss:1.7861
epoch:2014  d_loss:0.4923  g_loss:1.5376
epoch:2015  d_loss:0.4105  g_loss:1.8058
epoch:2016  d_loss:0.5013  g_loss:1.7398
epoch:2017  d_loss:0.4696  g_loss:1.6332
epoch:2018  d_loss:0.4661  g_loss:1.3656
epoch:2019  d_loss:0.5648  g_loss:1.5846
epoch:2020  d_loss:0.5147  g_loss:1.5346
epoch:2021  d_loss:0.4551  g_loss:1.8179
epoch:2022  d_loss:0.3809  g_loss:1.6154
epoch:2023  d_loss:0.5285  g_loss:1.8084
epoch:2024  d_loss:0.5921  g_loss:1.7611
epoch:2025  d_loss:0.7111  g_loss:1.5845
epoch:2026  d_loss:0.5735  g_loss:1.3457
epoch:2027  d_loss:0.4177  g_loss:1.4061
epoch:2028  d_loss:0.4741  g_loss:1.6135
epoch:2029  d_loss:0.4881  g_loss:1.7062
epoch:2030  d_loss:0.5596  g_loss:1.7117
epoch:2031  d_loss:0.5139  g_loss:1.5060
epoch:2032  d_loss:0.7203  g_loss:1.5569
epoch:2033  d_lo

epoch:2209  d_loss:0.5426  g_loss:1.6890
epoch:2210  d_loss:0.5706  g_loss:1.5702
epoch:2211  d_loss:0.5946  g_loss:1.3549
epoch:2212  d_loss:0.5148  g_loss:1.2580
epoch:2213  d_loss:0.5833  g_loss:1.1725
epoch:2214  d_loss:0.3819  g_loss:1.4146
epoch:2215  d_loss:0.5702  g_loss:1.7377
epoch:2216  d_loss:0.4565  g_loss:1.9135
epoch:2217  d_loss:0.5258  g_loss:1.9052
epoch:2218  d_loss:0.5139  g_loss:1.8921
epoch:2219  d_loss:0.5324  g_loss:1.6074
epoch:2220  d_loss:0.4157  g_loss:1.5325
epoch:2221  d_loss:0.5843  g_loss:1.3990
epoch:2222  d_loss:0.5605  g_loss:1.4305
epoch:2223  d_loss:0.5790  g_loss:1.5314
epoch:2224  d_loss:0.5507  g_loss:1.9949
epoch:2225  d_loss:0.6430  g_loss:1.6290
epoch:2226  d_loss:0.4879  g_loss:1.4503
epoch:2227  d_loss:0.4465  g_loss:1.3685
epoch:2228  d_loss:0.5722  g_loss:1.2471
epoch:2229  d_loss:0.5358  g_loss:1.3094
epoch:2230  d_loss:0.5625  g_loss:1.2659
epoch:2231  d_loss:0.5487  g_loss:1.3476
epoch:2232  d_loss:0.5043  g_loss:1.8977
epoch:2233  d_lo