In [1]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, LeakyReLU, UpSampling2D, Conv2D
from keras.models import Sequential, Model

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt

# Folder where the generated sample images are saved. Per the original note, this
# folder may be deleted after some time, so it is (re)created here when missing.
import os
# exist_ok=True makes the cell idempotent on re-run and avoids the race between a
# separate exists() check and makedirs().
os.makedirs("./gan_images", exist_ok=True)

In [2]:
# Generator: maps a 100-dim noise vector to a 28x28x1 image whose values lie in
# [-1, 1] (tanh output). It is never compiled on its own; it is trained only
# through the combined `gan` model.
generator = Sequential([
    # Project the noise to 128*7*7 units.
    # LeakyReLU(0.2): x for x >= 0, 0.2*x for x < 0.
    Dense(128 * 7 * 7, input_dim=100, activation=LeakyReLU(0.2)),
    BatchNormalization(),  # batch normalization
    # Reshape the flat units into a 7x7 map with 128 channels so the
    # upsampling / conv stack can process it.
    Reshape((7, 7, 128)),
    # 7x7 -> 14x14: UpSampling2D doubles the height and width.
    UpSampling2D(),
    Conv2D(64, kernel_size=5, padding='same'),  # 64 filters
    BatchNormalization(),
    Activation(LeakyReLU(0.2)),
    # 14x14 -> 28x28, then a single-channel output convolution.
    UpSampling2D(),
    Conv2D(1, kernel_size=5, padding='same', activation='tanh'),
])
generator.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 6272)              633472    
                                                                 
 batch_normalization (BatchN  (None, 6272)             25088     
 ormalization)                                                   
                                                                 
 reshape (Reshape)           (None, 7, 7, 128)         0         
                                                                 
 up_sampling2d (UpSampling2D  (None, 14, 14, 128)      0         
 )                                                               
                                                                 
 conv2d (Conv2D)             (None, 14, 14, 64)        204864    
                                                                 
 batch_normalization_1 (Batc  (None, 14, 14, 64)       2

In [4]:
# Discriminator: classifies a 28x28x1 image as real (label 1) or generated (label 0).
# It is compiled here while still trainable, so its own train_on_batch calls update
# its weights; it is frozen only inside the combined `gan` model compiled later.
# (This relies on `trainable` being captured at compile() time, as in tf.keras 2.x.)
discriminator = Sequential([
    # 28x28, 1 channel -> 14x14x64 (stride 2)
    Conv2D(64, kernel_size=5, strides=2, input_shape=(28, 28, 1), padding="same"),
    Activation(LeakyReLU(0.2)),
    Dropout(0.3),
    # 14x14x64 -> 7x7x128 (stride 2)
    Conv2D(128, kernel_size=5, strides=2, padding="same"),
    Activation(LeakyReLU(0.2)),
    Dropout(0.3),
    Flatten(),
    # Single sigmoid unit: estimated probability that the input is real.
    Dense(1, activation='sigmoid'),
])
discriminator.compile(loss='binary_crossentropy', optimizer='adam')  # trainable is True at this point
discriminator.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_2 (Conv2D)           (None, 14, 14, 64)        1664      
                                                                 
 activation_1 (Activation)   (None, 14, 14, 64)        0         
                                                                 
 dropout (Dropout)           (None, 14, 14, 64)        0         
                                                                 
 conv2d_3 (Conv2D)           (None, 7, 7, 128)         204928    
                                                                 
 activation_2 (Activation)   (None, 7, 7, 128)         0         
                                                                 
 dropout_1 (Dropout)         (None, 7, 7, 128)         0         
                                                                 
 flatten (Flatten)           (None, 6272)             

In [5]:
# Freeze the discriminator *before* compiling the combined model, so that
# gan.train_on_batch updates only the generator's weights.
discriminator.trainable = False

# Combined GAN: noise (100,) -> generator -> discriminator -> P(real).
ginput = Input(shape=(100,))
dis_output = discriminator(generator(ginput))
gan = Model(inputs=ginput, outputs=dis_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')
gan.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 100)]             0         
                                                                 
 sequential (Sequential)     (None, 28, 28, 1)         865281    
                                                                 
 sequential_1 (Sequential)   (None, 1)                 212865    
                                                                 
Total params: 1,078,146
Trainable params: 852,609
Non-trainable params: 225,537
_________________________________________________________________


In [7]:
def _save_sample_grid(step, n_rows=5, n_cols=5):
    """Save an n_rows x n_cols grid of generator samples to gan_images/gan_mnist_<step>.png.

    Uses the module-level ``generator``. Samples are mapped from the tanh range
    [-1, 1] back to [0, 1] before plotting. The figure is always closed after
    saving, so repeated calls do not accumulate open matplotlib figures.
    """
    # The output folder may have been deleted since it was first created.
    os.makedirs("gan_images", exist_ok=True)

    noise = np.random.normal(0, 1, (n_rows * n_cols, 100))
    samples = generator.predict(noise, verbose=0)
    samples = 0.5 * samples + 0.5  # [-1, 1] -> [0, 1]

    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid.
    fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
    try:
        # axs.flat walks the grid row by row, matching the sample order.
        for count, ax in enumerate(axs.flat):
            ax.imshow(samples[count, :, :, 0], cmap='gray')
            ax.axis('off')
        fig.savefig("gan_images/gan_mnist_%d.png" % step)
    finally:
        plt.close(fig)


def gan_train(epoch, batch_size, saving_interval):
    """Train the GAN on MNIST with alternating discriminator / generator updates.

    Args:
        epoch: number of training iterations. Despite the name, each iteration
            processes one randomly drawn mini-batch, not a full pass over MNIST.
        batch_size: number of real images and of generated images per update.
        saving_interval: save a 5x5 grid of generated samples every
            ``saving_interval`` iterations, starting at iteration 0. A falsy
            value (0 or None) disables saving instead of raising ZeroDivisionError.

    Uses the module-level ``generator``, ``discriminator`` and ``gan`` models and
    prints the losses of every iteration.
    """
    # Load MNIST; only the training images are used (labels and test split are unused).
    (X_train, _), (_, _) = mnist.load_data()
    X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
    # Pixel values 0..255 -> -1..1, matching the generator's tanh output range.
    X_train = (X_train - 127.5) / 127.5

    true = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for i in range(epoch):
        # Train the discriminator on real images labeled 1. The batch_size indices
        # are drawn (with replacement) from [0, X_train.shape[0]).
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        d_loss_real = discriminator.train_on_batch(imgs, true)

        # Train the discriminator on generated images labeled 0.
        # Each noise row is a 100-dim standard-normal vector.
        noise = np.random.normal(0, 1, (batch_size, 100))
        gen_imgs = generator.predict(noise, verbose=0)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)

        # Average discriminator loss; then train the generator through the
        # combined model (discriminator frozen there) by labeling the noise as real.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        g_loss = gan.train_on_batch(noise, true)

        print('epoch:%d' % i, ' d_loss:%.4f' % d_loss, ' g_loss:%.4f' % g_loss)

        # Periodically save intermediate generator samples as an image.
        if saving_interval and i % saving_interval == 0:
            _save_sample_grid(i)