<a href="https://colab.research.google.com/github/EunSu0/github/blob/main/mnist_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install keras numpy matplotlib tqdm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
# 학습한 모델의 오차를 줄이기 위해 경사 하강법 사용
# 모델 학습을 시각적으로 보여주는 tqdm 사용
import keras

from keras.models import Model, Sequential
from keras.layers import Dense, Input
# from keras.layers.advanced_activations import LeakyReLU
from keras.layers import LeakyReLU
from keras.optimizers import Adam
from keras.datasets import mnist
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# 데이터 불러오기
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_test = (x_test.astype(np.float32) - 127.5)/127.5
mnist_data = x_test.reshape(10000, 784)
print(mnist_data.shape)
len(mnist_data)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
(10000, 784)


10000

In [4]:
# 생성자 신경망 만들기
def create_generator():
  generator = Sequential()
  generator.add(Dense(units=256, input_dim=100))
  generator.add(LeakyReLU(0.2))
  generator.add(Dense(units=512))
  generator.add(LeakyReLU(0.2))
  generator.add(Dense(units=784, activation='tanh'))
  return generator
g = create_generator()
g.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 256)               25856     
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 256)               0         
                                                                 
 dense_1 (Dense)             (None, 512)               131584    
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 512)               0         
                                                                 
 dense_2 (Dense)             (None, 784)               402192    
                                                                 
Total params: 559,632
Trainable params: 559,632
Non-trainable params: 0
_________________________________________________________________


In [5]:
# 판별자 신경망 만들기
def create_discriminator():
  discriminator = Sequential()
  discriminator.add(Dense(units=512, input_dim=784))
  discriminator.add(LeakyReLU(0.2))
  discriminator.add(Dense(units=256))
  discriminator.add(LeakyReLU(0.2))
  discriminator.add(Dense(units=1, activation='sigmoid'))
  discriminator.compile(loss='binary_crossentropy',
optimizer = Adam(learning_rate=0.002, beta_1=0.5))
  return discriminator
d = create_discriminator()
d.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_3 (Dense)             (None, 512)               401920    
                                                                 
 leaky_re_lu_2 (LeakyReLU)   (None, 512)               0         
                                                                 
 dense_4 (Dense)             (None, 256)               131328    
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 256)               0         
                                                                 
 dense_5 (Dense)             (None, 1)                 257       
                                                                 
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________


In [6]:
# gan 생성 함수 만들기
def create_gan(discriminator, generator):
  discriminator.trainable=False
  gan_input = Input(shape=(100,))
  x = generator(gan_input)
  gan_output = discriminator(x)
  gan = Model(inputs=gan_input, outputs=gan_output)
  gan.compile(loss = 'binary_crossentropy', optimizer='adam')
  return gan
gan = create_gan(d,g)
gan.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 100)]             0         
                                                                 
 sequential (Sequential)     (None, 784)               559632    
                                                                 
 sequential_1 (Sequential)   (None, 1)                 533505    
                                                                 
Total params: 1,093,137
Trainable params: 559,632
Non-trainable params: 533,505
_________________________________________________________________


In [7]:
# 결과 확인 함수 만들기
def plot_generated_images(generator):
  noise = np.random.normal(loc=0, scale=1, size=[100,100])
  generated_images = generator.predict(noise)
  generated_images = generated_images.reshape(100,28,28)
  plt.figure(figsize=(10,10))
  for i in range(generated_images.shape[0]):
    plt.subplot(10, 10, i+1)
    plt.imshow(generated_images[i], interpolation='nearest')
    plt.axis('off')
  plt.tight_layout()

In [None]:
# 적재적 생성 신경망 훈련시키기
batch_size = 128
epochs = 5000
for e in tqdm(range(epochs)):
  noise = np.random.normal(0, 1, [batch_size, 100])
  generated_images = g.predict(noise)
  image_batch = mnist_data[np.random.randint(low=0, high=mnist_data.shape[0], size=batch_size)]
  X = np.concatenate([image_batch, generated_images])
  y_dis = np.zeros(2*batch_size)
  y_dis[:batch_size] = 1
  d.trainable = True
  d.train_on_batch(X, y_dis)
  noise = np.random.normal(0, 1, [batch_size, 100])
  y_gen = np.ones(batch_size)
  d.trainable = False
  gan.train_on_batch(noise, y_gen)
  if e == 0 or e % 1000 ==0:
    plot_generated_images(g)

  0%|          | 0/5000 [00:00<?, ?it/s]



  0%|          | 1/5000 [00:05<7:36:34,  5.48s/it]



  0%|          | 2/5000 [00:05<3:16:38,  2.36s/it]



  0%|          | 3/5000 [00:05<1:53:29,  1.36s/it]



  0%|          | 4/5000 [00:05<1:13:20,  1.14it/s]



  0%|          | 5/5000 [00:06<51:01,  1.63it/s]  



  0%|          | 6/5000 [00:06<37:36,  2.21it/s]



  0%|          | 7/5000 [00:06<29:15,  2.84it/s]



  0%|          | 8/5000 [00:06<25:03,  3.32it/s]



  0%|          | 9/5000 [00:06<21:07,  3.94it/s]



  0%|          | 10/5000 [00:06<19:54,  4.18it/s]



  0%|          | 11/5000 [00:07<18:00,  4.62it/s]



  0%|          | 12/5000 [00:07<17:32,  4.74it/s]



  0%|          | 13/5000 [00:07<16:27,  5.05it/s]



  0%|          | 14/5000 [00:07<14:56,  5.56it/s]



  0%|          | 15/5000 [00:07<14:52,  5.59it/s]



  0%|          | 16/5000 [00:07<14:50,  5.60it/s]



  0%|          | 17/5000 [00:08<14:37,  5.68it/s]



  0%|          | 18/5000 [00:08<14:46,  5.62it/s]



  0%|          | 19/5000 [00:08<14:06,  5.89it/s]



  0%|          | 20/5000 [00:08<14:42,  5.64it/s]



  0%|          | 21/5000 [00:08<15:04,  5.51it/s]



  0%|          | 22/5000 [00:09<14:36,  5.68it/s]



  0%|          | 23/5000 [00:09<14:16,  5.81it/s]



  0%|          | 24/5000 [00:09<14:01,  5.92it/s]



  0%|          | 25/5000 [00:09<14:03,  5.90it/s]



  1%|          | 26/5000 [00:09<14:21,  5.78it/s]



  1%|          | 27/5000 [00:09<14:37,  5.67it/s]



  1%|          | 28/5000 [00:10<14:50,  5.59it/s]



  1%|          | 29/5000 [00:10<14:44,  5.62it/s]



  1%|          | 30/5000 [00:10<13:46,  6.01it/s]



  1%|          | 31/5000 [00:10<14:18,  5.79it/s]



  1%|          | 32/5000 [00:10<14:00,  5.91it/s]



  1%|          | 33/5000 [00:10<14:37,  5.66it/s]



  1%|          | 34/5000 [00:11<15:00,  5.51it/s]



  1%|          | 35/5000 [00:11<14:27,  5.72it/s]



  1%|          | 36/5000 [00:11<15:32,  5.32it/s]



  1%|          | 37/5000 [00:11<14:59,  5.52it/s]



  1%|          | 38/5000 [00:11<16:36,  4.98it/s]



  1%|          | 39/5000 [00:12<17:30,  4.72it/s]



  1%|          | 40/5000 [00:12<19:16,  4.29it/s]



  1%|          | 41/5000 [00:12<19:14,  4.30it/s]



  1%|          | 42/5000 [00:12<19:40,  4.20it/s]



  1%|          | 43/5000 [00:13<19:44,  4.19it/s]



  1%|          | 44/5000 [00:13<20:12,  4.09it/s]



  1%|          | 45/5000 [00:13<20:29,  4.03it/s]



  1%|          | 46/5000 [00:13<20:58,  3.94it/s]



  1%|          | 47/5000 [00:14<21:14,  3.89it/s]



  1%|          | 48/5000 [00:14<20:29,  4.03it/s]



  1%|          | 49/5000 [00:14<19:13,  4.29it/s]



  1%|          | 50/5000 [00:14<18:19,  4.50it/s]



  1%|          | 51/5000 [00:15<17:40,  4.67it/s]



  1%|          | 52/5000 [00:15<16:27,  5.01it/s]



  1%|          | 53/5000 [00:15<15:29,  5.32it/s]



  1%|          | 54/5000 [00:15<14:59,  5.50it/s]



  1%|          | 55/5000 [00:15<14:36,  5.64it/s]



  1%|          | 56/5000 [00:15<15:18,  5.38it/s]



  1%|          | 57/5000 [00:16<15:47,  5.21it/s]



  1%|          | 58/5000 [00:16<15:00,  5.49it/s]



  1%|          | 59/5000 [00:16<14:33,  5.66it/s]



  1%|          | 60/5000 [00:16<15:09,  5.43it/s]



  1%|          | 61/5000 [00:16<14:33,  5.65it/s]



  1%|          | 62/5000 [00:16<15:09,  5.43it/s]



  1%|▏         | 63/5000 [00:17<15:31,  5.30it/s]



  1%|▏         | 64/5000 [00:17<15:37,  5.26it/s]



  1%|▏         | 65/5000 [00:17<16:26,  5.00it/s]



  1%|▏         | 66/5000 [00:17<16:26,  5.00it/s]



  1%|▏         | 67/5000 [00:17<15:04,  5.46it/s]



  1%|▏         | 68/5000 [00:18<13:54,  5.91it/s]



  1%|▏         | 69/5000 [00:18<13:08,  6.26it/s]



  1%|▏         | 70/5000 [00:18<13:26,  6.12it/s]



  1%|▏         | 71/5000 [00:18<13:42,  5.99it/s]



  1%|▏         | 72/5000 [00:18<13:07,  6.25it/s]



  1%|▏         | 73/5000 [00:18<13:08,  6.25it/s]



  1%|▏         | 74/5000 [00:19<13:03,  6.29it/s]



  2%|▏         | 75/5000 [00:19<13:26,  6.11it/s]



  2%|▏         | 76/5000 [00:19<13:59,  5.87it/s]



  2%|▏         | 77/5000 [00:19<14:12,  5.78it/s]



  2%|▏         | 78/5000 [00:19<14:22,  5.71it/s]



  2%|▏         | 79/5000 [00:19<14:47,  5.54it/s]



  2%|▏         | 80/5000 [00:20<14:06,  5.81it/s]



  2%|▏         | 81/5000 [00:20<14:30,  5.65it/s]



  2%|▏         | 82/5000 [00:20<14:06,  5.81it/s]



  2%|▏         | 83/5000 [00:20<14:30,  5.65it/s]



  2%|▏         | 84/5000 [00:20<15:17,  5.36it/s]



  2%|▏         | 85/5000 [00:21<15:39,  5.23it/s]



  2%|▏         | 86/5000 [00:21<14:49,  5.53it/s]



  2%|▏         | 87/5000 [00:21<15:03,  5.44it/s]



  2%|▏         | 88/5000 [00:21<14:25,  5.68it/s]



  2%|▏         | 89/5000 [00:21<14:01,  5.84it/s]



  2%|▏         | 90/5000 [00:21<14:37,  5.60it/s]



  2%|▏         | 91/5000 [00:22<14:04,  5.81it/s]



  2%|▏         | 92/5000 [00:22<14:34,  5.61it/s]



  2%|▏         | 93/5000 [00:22<15:15,  5.36it/s]



  2%|▏         | 94/5000 [00:22<14:40,  5.57it/s]



  2%|▏         | 95/5000 [00:22<15:27,  5.29it/s]



  2%|▏         | 96/5000 [00:23<15:21,  5.32it/s]



  2%|▏         | 97/5000 [00:23<15:00,  5.45it/s]



  2%|▏         | 98/5000 [00:23<14:45,  5.53it/s]



  2%|▏         | 99/5000 [00:23<14:31,  5.62it/s]



  2%|▏         | 100/5000 [00:23<13:35,  6.01it/s]



  2%|▏         | 101/5000 [00:23<13:55,  5.87it/s]



  2%|▏         | 102/5000 [00:24<14:03,  5.81it/s]



  2%|▏         | 103/5000 [00:24<13:19,  6.12it/s]



  2%|▏         | 104/5000 [00:24<13:35,  6.01it/s]



  2%|▏         | 105/5000 [00:24<14:17,  5.71it/s]



  2%|▏         | 106/5000 [00:24<15:55,  5.12it/s]



  2%|▏         | 107/5000 [00:25<17:06,  4.77it/s]



  2%|▏         | 108/5000 [00:25<18:19,  4.45it/s]



  2%|▏         | 109/5000 [00:25<18:38,  4.37it/s]



  2%|▏         | 110/5000 [00:25<19:33,  4.17it/s]



  2%|▏         | 111/5000 [00:26<19:55,  4.09it/s]



  2%|▏         | 112/5000 [00:26<21:00,  3.88it/s]



  2%|▏         | 113/5000 [00:26<21:24,  3.80it/s]



  2%|▏         | 114/5000 [00:26<21:22,  3.81it/s]



  2%|▏         | 115/5000 [00:27<20:47,  3.92it/s]



  2%|▏         | 116/5000 [00:27<18:30,  4.40it/s]



  2%|▏         | 117/5000 [00:27<17:52,  4.55it/s]



  2%|▏         | 118/5000 [00:27<17:11,  4.73it/s]



  2%|▏         | 119/5000 [00:27<16:10,  5.03it/s]



  2%|▏         | 120/5000 [00:28<16:03,  5.06it/s]



  2%|▏         | 121/5000 [00:28<14:59,  5.42it/s]



  2%|▏         | 122/5000 [00:28<15:42,  5.17it/s]



  2%|▏         | 123/5000 [00:28<14:55,  5.45it/s]



  2%|▏         | 124/5000 [00:28<14:23,  5.64it/s]



  2%|▎         | 125/5000 [00:28<14:43,  5.52it/s]



  3%|▎         | 126/5000 [00:29<15:19,  5.30it/s]



  3%|▎         | 127/5000 [00:29<15:29,  5.24it/s]



  3%|▎         | 128/5000 [00:29<15:33,  5.22it/s]



  3%|▎         | 129/5000 [00:29<15:36,  5.20it/s]



  3%|▎         | 130/5000 [00:29<15:23,  5.28it/s]



  3%|▎         | 131/5000 [00:30<15:12,  5.33it/s]



  3%|▎         | 132/5000 [00:30<14:41,  5.52it/s]



  3%|▎         | 133/5000 [00:30<14:40,  5.53it/s]



  3%|▎         | 134/5000 [00:30<14:53,  5.44it/s]



  3%|▎         | 135/5000 [00:30<14:25,  5.62it/s]



  3%|▎         | 136/5000 [00:30<14:04,  5.76it/s]



  3%|▎         | 137/5000 [00:31<14:06,  5.75it/s]

