<a href="https://colab.research.google.com/github/KudohAtsuo/MNISTGAN/blob/main/MNISTGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **1) Importing Python Packages for GAN**


In [None]:
from keras.datasets import mnist

from keras.models import Sequential
from keras.layers import BatchNormalization
from keras.layers import Dense, Reshape, Flatten
from keras.layers.advanced_activations import LeakyReLU
from tensorflow.keras.optimizers import Adam

import numpy as np

# Output folder for the sample grids written during training.
# os.makedirs(..., exist_ok=True) is safe to re-run and does not need a shell,
# unlike `!mkdir`, which reports an error when the folder already exists.
import os

os.makedirs("generated_images", exist_ok=True)

## **2) Variables for Neural Networks & Data**

In [None]:
# MNIST: keep only the training images; labels and the test split are discarded.
(x_train, _), (_, _) = mnist.load_data()

# A single image is 28 x 28 pixels with one (grayscale) channel.
img_width, img_heighty, channels = 28, 28, 1
img_shape = (img_heighty, img_width, channels)

# Length of the random noise vector the generator takes as input.
latent_dim = 100

# Adam optimizer with a learning rate of 1e-4.
adam = Adam(learning_rate=0.0001)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


## **3) Building Generator**





In [None]:
def build_generator():
    '''
    Build the generator model.

      - Sequential stack of three hidden blocks with 256, 512 and 1024 units.
        Each block is Dense -> LeakyReLU(alpha=0.2) -> BatchNormalization(momentum=0.8).
      - The first Dense layer takes a noise vector of length latent_dim.
      - The final Dense layer has np.prod(img_shape) units with tanh activation.
        It is reshaped to img_shape so its output matches the real-image shape.

    Prints the model summary and returns the (uncompiled) model.
    '''
    model = Sequential()

    for position, units in enumerate((256, 512, 1024)):
        if position == 0:
            # Only the first layer declares the input size (the noise vector).
            model.add(Dense(units, input_dim=latent_dim))
        else:
            model.add(Dense(units))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

    # tanh bounds every output pixel to [-1, 1].
    model.add(Dense(np.prod(img_shape), activation='tanh'))
    model.add(Reshape(img_shape))

    model.summary()
    return model

generator = build_generator()



    

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 256)               25856     
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 256)               0         
                                                                 
 batch_normalization (BatchN  (None, 256)              1024      
 ormalization)                                                   
                                                                 
 dense_1 (Dense)             (None, 512)               131584    
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 512)               0         
                                                                 
 batch_normalization_1 (Batc  (None, 512)              2048      
 hNormalization)                                        

## **4) Building Discriminator**

In [None]:
def build_discriminator():
    '''
    Build the discriminator (real-vs-fake image classifier).

      - Sequential model whose input is an image of shape img_shape.
        The image is flattened to a vector first.
      - Two hidden Dense layers (512, then 256 units), each followed by LeakyReLU(alpha=0.2).
      - A single output unit with sigmoid activation, because this is binary classification.

    Prints the model summary and returns the (uncompiled) model.
    '''
    model = Sequential()

    stack = (
        Flatten(input_shape=img_shape),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid'),
    )
    for layer in stack:
        model.add(layer)

    model.summary()
    return model

discriminator = build_discriminator()


    



Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense_4 (Dense)             (None, 512)               401920    
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 512)               0         
                                                                 
 dense_5 (Dense)             (None, 256)               131328    
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 256)               0         
                                                                 
 dense_6 (Dense)             (None, 1)                 257       
                                                                 
Total params: 533,505
Trainable params: 533,505
Non-tr

## **5) Connecting Neural Networks to build GAN**

In [None]:
# Compile the discriminator while it is still trainable, so that
# discriminator.train_on_batch() updates its weights.
# Each model gets its own Adam instance (learning rate 1e-4, the same value as
# the `adam` optimizer configured above). The string 'adam' would silently use
# Keras' default learning rate, and sharing one optimizer between two compiled
# models would mix their optimizer state.
discriminator.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy')

# Combined model: noise -> generator -> discriminator -> probability of "real".
# The discriminator is frozen *before* GAN is compiled. As a result,
# GAN.train_on_batch() only updates the generator's weights. The discriminator's
# own compiled model keeps the trainable setting it was compiled with.
discriminator.trainable = False

GAN = Sequential()
GAN.add(generator)
GAN.add(discriminator)

GAN.compile(optimizer=Adam(learning_rate=0.0001), loss='binary_crossentropy')


## **6) Outputting Images**


In [None]:
#@title
# Sample-image helper called periodically during training.
import matplotlib.pyplot as plt
import glob
import imageio
import PIL

# Running counter used to name the saved PNGs ("0.00000001.png", "0.00000002.png", ...).
# It is module-global, so numbering continues across repeated train() calls.
save_name = 0.00000000

def save_imgs(epoch):
    '''
    Generate a 5 x 5 grid of samples from the current generator and save it as a PNG.

    Args:
        epoch: training step the samples come from; shown as the figure title.

    Side effects:
        - Increments the global ``save_name`` by 1e-8 and prints it.
        - Writes ``generated_images/<save_name>.png`` (save_name formatted with 8 decimals).
        - Prints 'saved'.
    '''
    global save_name
    # 25 images fill a 5 x 5 grid.
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, latent_dim))
    # verbose=0 suppresses the per-call progress bar in the notebook output.
    gen_imgs = generator.predict(noise, verbose=0)

    # Rounding keeps the float counter from accumulating representation error.
    save_name = round(save_name + 0.00000001, 8)
    print("%.8f" % save_name)

    # Rescale generator output from [-1, 1] to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    fig.suptitle(f"epoch {epoch}")
    for ax, img in zip(axs.flat, gen_imgs):
        ax.imshow(img[:, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig("generated_images/%.8f.png" % save_name)
    print('saved')
    # Close this specific figure so figures do not pile up across calls.
    plt.close(fig)

## **7) Training GAN**

In [None]:
def train(epochs, batch_size=64, save_interval=200):
    '''
    Train the GAN with alternating discriminator / generator updates.

    Each "epoch" here is one training step on a single random mini-batch,
    not a full pass over the dataset.

    Args:
        epochs: number of training steps to run.
        batch_size: number of real images, and of generated images, per step.
        save_interval: every ``save_interval`` steps (step 0 included), print the
            losses and call save_imgs(). Must be >= 1.

    Raises:
        ValueError: if save_interval < 1.

    Each step:
        - draws batch_size real images (label 1) and batch_size generated images (label 0);
        - trains the discriminator on both batches;
        - trains the generator through the combined GAN model on fresh noise labelled 1,
          i.e. rewards it for fooling the discriminator. This assumes GAN was
          compiled with the discriminator frozen.
    '''
    if save_interval < 1:
        raise ValueError(f'save_interval must be >= 1, got {save_interval}')

    (X_train, _), (_, _) = mnist.load_data()
    # Scale pixels from [0, 255] to [-1, 1]. Add the channel axis so batches have
    # shape (batch, 28, 28, 1), matching img_shape and the generator's output.
    X_train = X_train.astype('float32') / 127.5 - 1.
    X_train = np.expand_dims(X_train, axis=-1)

    valid = np.ones((batch_size, 1))
    fakes = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # Discriminator step: a random real batch and a freshly generated batch.
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]

        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        gen_imgs = generator.predict(noise, verbose=0)

        d_loss_real = discriminator.train_on_batch(imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fakes)
        d_loss = np.add(d_loss_real, d_loss_fake) / 2

        # Generator step: new noise labelled "real", so the generator learns to fool D.
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = GAN.train_on_batch(noise, valid)

        if epoch % save_interval == 0:
            print(f'***********{epoch} [D loss: {d_loss}], [G loss: {g_loss}]')
            save_imgs(epoch)

train(30000, batch_size=64, save_interval=200)






***********0 [D loss: 0.02138662338256836], [G loss: 7.206105709075928]
0.00000001
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (28, 28)
(28, 28, 1) (25, 28, 28, 1) (

KeyboardInterrupt: ignored

## **8) Making GIF**

In [None]:
# Stitch the saved sample grids, in file-name order, into an animated GIF.
anim_file = 'dcgan.gif'

filenames = sorted(glob.glob('generated_images/*.png'))
# Fail early with a clear message instead of a NameError on `filename` below.
if not filenames:
    raise FileNotFoundError('No PNG frames found in generated_images/ - run training first.')

with imageio.get_writer(anim_file, mode='I') as writer:
    for filename in filenames:
        image = imageio.imread(filename)
        writer.append_data(image)
    # Append the last frame a second time (presumably so the animation
    # pauses on the final result).
    writer.append_data(image)