In [1]:
from PIL import Image 
import numpy as np 
import math
import os 

import tensorflow as tf  
from keras.datasets import mnist, cifar10 
from keras import models, layers, optimizers
from keras.models import Model, Sequential
from keras.layers import Dense, Conv2D, BatchNormalization, \
                         Reshape, UpSampling2D, MaxPooling2D, Flatten
import keras.backend as K 
# Print the active Keras image data format. The layer shapes used further
# down ((7, 7, 128) and (28, 28, 1)) place the channel axis last.
print(K.image_data_format())

channels_last


In [2]:
def mse_4d(y_true, y_pred) :
    """Mean squared error computed with the Keras backend.

    The squared difference is averaged over axes 1, 2 and 3, so axis 0 is
    the only axis that is left untouched by the reduction.

    Args:
        y_true: Target tensor.
        y_pred: Predicted tensor, compatible in shape with ``y_true``.

    Returns:
        The result of ``K.mean`` over axes (1, 2, 3) of the squared error.
    """
    squared_error = K.square(y_pred - y_true)
    return K.mean(squared_error, axis=(1, 2, 3))
def mse_4d_tf(y_true, y_pred) :
    """Mean squared error computed with TensorFlow ops.

    TensorFlow counterpart of ``mse_4d``: the squared difference is averaged
    over axes 1, 2 and 3.

    Args:
        y_true: Target tensor.
        y_pred: Predicted tensor, compatible in shape with ``y_true``.

    Returns:
        The result of ``tf.reduce_mean`` over axes (1, 2, 3) of the squared
        error.
    """
    # tf.reduce_mean averages along the listed axes only; axis 0 is kept.
    return tf.reduce_mean(tf.square(y_pred - y_true), axis = (1, 2, 3))

In [3]:
class GAN(Sequential) :
    """Generator and discriminator stacked into a single Sequential model.

    The instance itself is the stacked model (generator followed by the
    discriminator) and is the model trained during the generator step.
    ``self.generator`` and ``self.discriminator`` remain available as
    separate models so that images can be sampled and the discriminator can
    be trained on its own.
    """

    def __init__(self, input_dim = 32):
        """Build both sub-models, stack them and compile all three models.

        Args:
            input_dim: Length of the noise vector fed to the generator.
        """
        super().__init__()
        self.input_dim = input_dim
        self.generator = self.GENERATOR()
        self.discriminator = self.DISCRIMINATOR()
        self.add(self.generator)
        # The discriminator is flagged as non-trainable before it is added, so
        # the stacked model is compiled with that flag set (see compile_all).
        self.discriminator.trainable = False
        self.add(self.discriminator)

        self.compile_all()

    def compile_all(self) :
        """Compile the generator, the stacked model and the discriminator."""
        # `learning_rate` is the current name of the deprecated `lr` argument.
        d_optim = optimizers.SGD(learning_rate = 5e-4, momentum=0.9, nesterov= True)
        g_optim = optimizers.SGD(learning_rate = 5e-4, momentum=0.9, nesterov= True)
        self.generator.compile(loss = mse_4d_tf, optimizer = 'SGD')
        # Compiled while discriminator.trainable is still False (set in __init__).
        self.compile(loss = 'binary_crossentropy', optimizer = g_optim)
        # Re-enable training before compiling the stand-alone discriminator.
        self.discriminator.trainable = True
        self.discriminator.compile(loss = 'binary_crossentropy', optimizer = d_optim)

    def GENERATOR(self) :
        """Build the generator: noise vector -> (28, 28, 1) image in [-1, 1].

        Two dense layers feed a (7, 7, 128) feature map that is upsampled
        twice by a factor of two (7 -> 14 -> 28); the final convolution has a
        single filter with a tanh activation.

        Returns:
            The uncompiled generator ``Sequential`` model.
        """
        input_dim = self.input_dim

        model = Sequential([
            Dense(1024, activation='relu', input_dim = input_dim),
            # CIFAR-10 images are 32 px wide, MNIST images are 28 px wide.
            Dense(7*7*128, activation='tanh'),
            BatchNormalization(),
            Reshape((7, 7, 128), input_shape = (7 * 7 * 128, )),
            UpSampling2D(size = (2, 2)),
            Conv2D(128, (5,5), padding ='same', activation='tanh'),
            UpSampling2D(size = (2,2)),
            # Use Conv2D(3, ...) for colour images / Conv2D(1, ...) for grayscale.
            Conv2D(1, (5, 5), padding = 'same', activation='tanh')
        ])
        return model

    def DISCRIMINATOR(self) :
        """Build the discriminator: (28, 28, 1) image -> one sigmoid output.

        Returns:
            The uncompiled discriminator ``Sequential`` model.
        """
        model = Sequential([
            Conv2D(128, (5, 5), padding = 'same', activation='tanh',
                   input_shape = (28, 28, 1)),
            MaxPooling2D(pool_size=(2, 2)),
            Conv2D(256, (5, 5), activation='tanh'),
            MaxPooling2D(pool_size = (2, 2)),
            Flatten(),
            Dense(1024, activation='tanh'),
            Dense(1, activation='sigmoid')
        ])
        return model

    def get_z(self, ln) :
        """Draw ``ln`` noise vectors uniformly from [-1, 1).

        Args:
            ln: Number of noise vectors to draw.

        Returns:
            NumPy array of shape ``(ln, self.input_dim)``.
        """
        input_dim = self.input_dim
        return np.random.uniform(-1, 1, (ln, input_dim))

    def train_both(self, x) :
        """Run one discriminator update followed by one generator update.

        Args:
            x: Batch of real images. Its first dimension decides how many
                fake images and noise vectors are drawn.

        Returns:
            Tuple ``(d_loss, g_loss)`` holding the values returned by the
            discriminator and stacked-model ``train_on_batch`` calls.
        """
        ln = x.shape[0]

        # Discriminator step: real images are labelled 1, generated images 0.
        z = self.get_z(ln)
        w = self.generator.predict(z, verbose = 0)
        xw = np.concatenate((x, w))
        y2 = np.array([1] * ln + [0] * ln).reshape(-1, 1)
        d_loss = self.discriminator.train_on_batch(xw, y2)

        # Generator step: fresh noise is pushed through the stacked model with
        # target label 1, i.e. the generator is rewarded for fooling the
        # discriminator.
        z = self.get_z(ln)
        self.discriminator.trainable = False
        g_loss = self.train_on_batch(z, np.array([1] * ln).reshape(-1, 1))
        # Restore the flag with a real boolean for the next discriminator step.
        self.discriminator.trainable = True

        return d_loss, g_loss
    

In [4]:
def combine_images(generated_images) :
    """Tile channel 0 of a batch of images into a single 2-D grid image.

    The grid has ``int(sqrt(num))`` columns and as many rows as are needed
    to hold every image; images are placed row by row and grid cells without
    an image stay zero.

    Args:
        generated_images: Array indexed as (image, row, column, channel).
            Only channel 0 of each image is copied.

    Returns:
        2-D array with the same dtype as ``generated_images``.
    """
    total = generated_images.shape[0]
    n_cols = int(math.sqrt(total))
    n_rows = int(math.ceil(float(total) / n_cols))
    tile_h, tile_w = generated_images.shape[1:3]

    canvas = np.zeros((n_rows * tile_h, n_cols * tile_w),
                      dtype = generated_images.dtype)
    for position in range(total) :
        # Row-major placement: grid row first, then grid column.
        grid_row, grid_col = divmod(position, n_cols)
        top = grid_row * tile_h
        left = grid_col * tile_w
        canvas[top:top + tile_h, left:left + tile_w] = \
            generated_images[position][:, :, 0]
    return canvas

def get_x(X_train, index, BATCH_SIZE) :
    """Return the ``index``-th consecutive mini-batch of ``X_train``.

    Args:
        X_train: Sliceable collection of samples.
        index: Zero-based batch number.
        BATCH_SIZE: Number of samples per batch.

    Returns:
        The slice ``X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]``;
        it is shorter than ``BATCH_SIZE`` when the data runs out.
    """
    start = index * BATCH_SIZE
    stop = start + BATCH_SIZE
    return X_train[start:stop]

def save_images(generated_images, output_fold, epoch, index) :
    """Write a batch of generated images to disk as one tiled PNG.

    The batch is tiled with ``combine_images``, rescaled with
    ``value * 127.5 + 127.5`` (which maps -1 to 0 and 1 to 255), cast to
    ``uint8`` and saved as ``<output_fold>/<epoch>_<index>.png``.

    Args:
        generated_images: Batch of images accepted by ``combine_images``.
        output_fold: Directory the PNG is written to.
        epoch: Value used as the first part of the file name.
        index: Value used as the second part of the file name.
    """
    grid = combine_images(generated_images)
    pixels = (grid * 127.5 + 127.5).astype(np.uint8)
    file_name = str(epoch) + '_' + str(index) + '.png'
    target = output_fold + '/' + file_name
    Image.fromarray(pixels).save(target)

def load_data(n_train) :
    """Load the MNIST training images and keep the first ``n_train`` of them.

    The training labels and the whole test split are discarded.

    Args:
        n_train: Number of leading training images to keep.

    Returns:
        The first ``n_train`` entries of the MNIST training image array.
    """
    (train_images, _train_labels), (_, _) = mnist.load_data()
    subset = train_images[:n_train]
    return subset

def train(args) :
    """Train a GAN on MNIST and write samples, weights and loss logs to disk.

    Every 10th epoch (and the last one) a tiled PNG of generated images is
    saved. After training, the generator and discriminator weights and the
    per-batch losses (one row per epoch) are written to ``args.output_fold``.

    Args:
        args: Object exposing the attributes ``batch_size``, ``epochs``,
            ``output_fold``, ``input_dim`` and ``n_train``.

    Raises:
        ValueError: If fewer than ``batch_size`` training images are loaded,
            so that not a single full mini-batch can be formed.
    """
    BATCH_SIZE = args.batch_size
    epochs = args.epochs
    output_fold = args.output_fold
    input_dim = args.input_dim
    n_train = args.n_train

    os.makedirs(output_fold, exist_ok= True)
    print('Output_fold is', output_fold)

    X_train = load_data(n_train)

    # (v - 127.5) / 127.5 maps 0 to -1 and 255 to 1, the same range as the
    # tanh output of the generator; a trailing channel axis is then appended.
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = X_train.reshape(X_train.shape + (1,))

    # Only full batches are used; a trailing partial batch is dropped.
    n_batches = X_train.shape[0] // BATCH_SIZE
    if n_batches == 0 :
        raise ValueError(
            'n_train (%d samples loaded) must be at least batch_size (%d)'
            % (X_train.shape[0], BATCH_SIZE)
        )

    gan = GAN(input_dim)

    d_loss_ll = []
    g_loss_ll = []

    for epoch in range(epochs) :
        if epoch % 10 == 0 :
            print('Epoch is', epoch)
            print('Number of batches', n_batches)

        d_loss_l = []
        g_loss_l = []
        for index in range(n_batches) :
            x = get_x(X_train, index, BATCH_SIZE)

            d_loss, g_loss = gan.train_both(x)

            d_loss_l.append(d_loss)
            g_loss_l.append(g_loss)

        if epoch % 10 == 0 or epoch == epochs - 1 :
            # Every training batch is full, so one batch worth of samples is
            # exactly BATCH_SIZE images.
            z = gan.get_z(BATCH_SIZE)
            w = gan.generator.predict(z, verbose = 0)
            save_images(w, output_fold, epoch, 0)

        d_loss_ll.append(d_loss_l)
        g_loss_ll.append(g_loss_l)

    gan.generator.save_weights(output_fold + '/' + 'generator', True)
    gan.discriminator.save_weights(output_fold + '/' + 'discriminator', True)

    # np.savetxt takes the file name and the data as separate arguments.
    np.savetxt(output_fold + '/' + 'd_loss', d_loss_ll)
    np.savetxt(output_fold + '/' + 'g_loss' , g_loss_ll)

In [5]:
def main() :
    """Run ``train`` with a fixed, hard-coded set of settings."""
    class ARGS :
        """Attribute container holding the settings read by ``train``."""
        def __init__(self):
            self.batch_size = 64
            self.epochs = 4000
            self.output_fold = 'GAN_OUT'
            self.input_dim = 10
            self.n_train = 128

    settings = ARGS()
    train(settings)
# Start training as soon as this cell is executed.
main()

Output_fold is GAN_OUT


  super().__init__(name, **kwargs)


Epoch is 0
Number of batches 2
Epoch is 10
Number of batches 2
Epoch is 20
Number of batches 2
Epoch is 30
Number of batches 2
Epoch is 40
Number of batches 2
Epoch is 50
Number of batches 2
Epoch is 60
Number of batches 2
Epoch is 70
Number of batches 2
Epoch is 80
Number of batches 2
Epoch is 90
Number of batches 2
Epoch is 100
Number of batches 2
Epoch is 110
Number of batches 2
Epoch is 120
Number of batches 2
Epoch is 130
Number of batches 2
Epoch is 140
Number of batches 2
Epoch is 150
Number of batches 2
Epoch is 160
Number of batches 2
Epoch is 170
Number of batches 2
Epoch is 180
Number of batches 2
Epoch is 190
Number of batches 2
Epoch is 200
Number of batches 2
Epoch is 210
Number of batches 2
Epoch is 220
Number of batches 2
Epoch is 230
Number of batches 2
Epoch is 240
Number of batches 2
Epoch is 250
Number of batches 2
Epoch is 260
Number of batches 2
Epoch is 270
Number of batches 2
Epoch is 280
Number of batches 2
Epoch is 290
Number of batches 2
Epoch is 300
Number o

KeyboardInterrupt: 