In [None]:
import os
import math
import argparse
from PIL import Image
from keras import backend, models, layers, optimizers
from keras.datasets import mnist
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# 연쇄 방식 (Chained Sequential) GAN Modeling - OOP

In [None]:
def mse_4d(y_true, y_pred):
    """Per-sample mean squared error for rank-4 tensors (Keras backend version).

    Squares the element-wise difference and averages it over axes 1, 2 and 3,
    leaving one loss value per entry along axis 0.

    NOTE(review): ``backend.mean`` / ``backend.square`` exist in the Keras 2
    backend module; confirm they are available in the installed Keras version.
    """
    squared_error = backend.square(y_pred - y_true)
    return backend.mean(squared_error, axis=(1, 2, 3))
    
def mse_4d_tf(y_true, y_pred):
    """Per-sample mean squared error for rank-4 tensors (TensorFlow version).

    Reduces over axes 1, 2 and 3 and returns one value per entry along axis 0.
    Used as the generator's compile-time loss in this notebook.
    """
    diff = y_pred - y_true
    return tf.reduce_mean(tf.square(diff), axis=(1, 2, 3))

class GAN_Seq_OOP(models.Sequential):
    """GAN built as a chained ``Sequential`` model: ``generator -> discriminator``.

    The outer ``Sequential`` (``self``) is the combined model used to train the
    generator. Its discriminator stage is frozen while the combined model is
    compiled. The discriminator is also compiled on its own, so it can be
    trained directly on mixed real/fake batches.

    Args:
        input_dim: Length of the latent vector ``z`` fed to the generator.
    """

    def __init__(self, input_dim=64):
        super().__init__()
        self.input_dim = input_dim
        self.generator = self._generator()
        self.discriminator = self._discriminator()

        # Chain the two sub-models: z -> generator -> discriminator -> p(real).
        self.add(self.generator)
        self.add(self.discriminator)
        self.compile_all()

    def compile_all(self):
        """Compile the generator, the combined GAN and the discriminator.

        The combined model is compiled while ``self.discriminator.trainable``
        is False, so generator updates leave the discriminator weights alone.
        The discriminator is then re-enabled and compiled for its own updates.
        The freeze happens here, not only in ``__init__``, so calling this
        method again still produces a combined model with a frozen
        discriminator.
        """
        d_optim = optimizers.SGD(learning_rate=0.0005, momentum=0.9, nesterov=True)
        g_optim = optimizers.SGD(learning_rate=0.0005, momentum=0.9, nesterov=True)

        # The generator is never trained on its own in this notebook; it keeps
        # the original compile settings (mse_4d_tf loss, default SGD).
        self.generator.compile(loss=mse_4d_tf, optimizer="SGD")

        # Combined model: only generator weights should be updated.
        self.discriminator.trainable = False
        self.compile(loss='binary_crossentropy', optimizer=g_optim)

        # Stand-alone discriminator: its own weights are updated.
        self.discriminator.trainable = True
        self.discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)

    def _generator(self):
        """Build the generator: latent vector -> 28x28x1 image with tanh output."""
        model = models.Sequential()
        model.add(layers.Dense(1024, activation='tanh', input_dim=self.input_dim))
        model.add(layers.Dense(7 * 7 * 128, activation='tanh'))
        model.add(layers.BatchNormalization())
        # The input shape is inferred from the previous layer; no input_shape needed here.
        model.add(layers.Reshape((7, 7, 128)))
        model.add(layers.UpSampling2D(size=(2, 2)))   # 7x7 -> 14x14
        model.add(layers.Conv2D(64, (5, 5), padding='same', activation='tanh'))
        model.add(layers.UpSampling2D(size=(2, 2)))   # 14x14 -> 28x28
        model.add(layers.Conv2D(1, (5, 5), padding='same', activation='tanh'))
        return model

    def _discriminator(self):
        """Build the discriminator: 28x28x1 image -> probability that it is real."""
        model = models.Sequential()
        model.add(layers.Conv2D(64, (5, 5), padding='same', activation='tanh', input_shape=(28, 28, 1)))
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))   # 28x28 -> 14x14
        model.add(layers.Conv2D(128, (5, 5), activation='tanh'))   # 'valid' padding: 14x14 -> 10x10
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))   # 10x10 -> 5x5
        model.add(layers.Flatten())
        model.add(layers.Dense(1024, activation='tanh'))
        model.add(layers.Dense(1, activation='sigmoid'))
        return model

    def get_z(self, ln):
        """Sample ``ln`` latent vectors uniformly from [-1, 1), shape ``(ln, input_dim)``."""
        return np.random.uniform(-1, 1, (ln, self.input_dim))

    def train_both(self, x):
        """Run one discriminator step and one generator step on a batch.

        Args:
            x: Batch of real images; its first axis is the batch size.

        Returns:
            ``(d_loss, g_loss)`` as returned by the two ``train_on_batch`` calls.
        """
        ln = x.shape[0]

        # Discriminator step: real images are labelled 1, generated images 0.
        z = self.get_z(ln)
        w = self.generator.predict(z, verbose=0)
        xw = np.concatenate((x, w))
        y2 = np.array([1] * ln + [0] * ln).reshape(-1, 1)
        d_loss = self.discriminator.train_on_batch(xw, y2)

        # Generator step: push D(G(z)) toward 1 while the discriminator is frozen.
        z = self.get_z(ln)
        self.discriminator.trainable = False
        try:
            g_loss = self.train_on_batch(z, np.array([1] * ln).reshape(-1, 1))
        finally:
            # Always re-enable, even on error, so later discriminator steps update weights.
            self.discriminator.trainable = True

        return d_loss, g_loss

# Training and Evaluation

In [None]:
def combine_images(generated_images):
    """Tile a batch of images into one 2-D mosaic, using channel 0 only.

    The grid has ``int(sqrt(n))`` columns and as many rows as needed to hold
    all ``n`` images. Images fill the grid row by row; unused cells stay zero.

    Args:
        generated_images: Array indexed as ``[image, row, col, channel]``.

    Returns:
        2-D array with the same dtype as ``generated_images``.
    """
    count = generated_images.shape[0]
    cols = int(math.sqrt(count))
    rows = int(math.ceil(float(count) / cols))
    tile_h, tile_w = generated_images.shape[1:3]
    mosaic = np.zeros((rows * tile_h, cols * tile_w), dtype=generated_images.dtype)
    for position in range(count):
        r, c = divmod(position, cols)
        top, left = r * tile_h, c * tile_w
        mosaic[top:top + tile_h, left:left + tile_w] = generated_images[position, :, :, 0]
    return mosaic

def get_x(x_train, index, batch_size):
    """Return the ``index``-th consecutive slice of ``batch_size`` samples.

    The slice is shorter, or empty, once ``x_train`` runs out.
    """
    start = index * batch_size
    stop = start + batch_size
    return x_train[start:stop]

def save_images(generated_images, output_fold, epoch, index):
    """Save a batch of images as one PNG mosaic at ``<output_fold>/<epoch>_<index>.png``.

    Pixel values are expected in ``[-1, 1]`` and are mapped linearly to
    ``[0, 255]``. The result is clipped before the ``uint8`` cast so that any
    out-of-range value saturates instead of wrapping around.
    """
    image = combine_images(generated_images)
    image = np.clip(image * 127.5 + 127.5, 0, 255)
    path = os.path.join(output_fold, f"{epoch}_{index}.png")
    Image.fromarray(image.astype(np.uint8)).save(path)

def load_data(n_train):
    """Return the first ``n_train`` MNIST training images.

    The training labels and the whole test split are discarded.
    """
    (images, _labels), (_x_test, _y_test) = mnist.load_data()
    first_images = images[:n_train]
    return first_images

def train(args):
    """Train ``GAN_Seq_OOP`` on the first ``args.n_train`` MNIST images.

    On every 10th epoch, and on the last one, a grid of generated samples is
    saved to ``args.output_fold``. After training, the generator and
    discriminator weights and the per-batch losses are written to the same
    folder.

    Args:
        args: Object with ``batch_size``, ``epochs``, ``output_fold``,
            ``input_dim`` and ``n_train`` attributes.

    Raises:
        ValueError: If fewer than ``batch_size`` training images are
            available, so no full batch can be formed.
    """
    batch_size = args.batch_size
    epochs = args.epochs
    output_fold = args.output_fold
    input_dim = args.input_dim
    n_train = args.n_train

    os.makedirs(output_fold, exist_ok=True)

    # Scale pixels from [0, 255] to [-1, 1] and add a trailing channel axis.
    x_train = load_data(n_train)
    x_train = (x_train.astype(np.float32) - 127.5) / 127.5
    x_train = x_train.reshape(x_train.shape + (1,))

    # Only full batches are used; a trailing partial batch is dropped.
    n_batches = x_train.shape[0] // batch_size
    if n_batches == 0:
        raise ValueError(
            f"need at least batch_size={batch_size} training images, "
            f"got {x_train.shape[0]}"
        )

    gan = GAN_Seq_OOP(input_dim)

    d_loss_ll = []
    g_loss_ll = []
    for epoch in range(epochs):
        if epoch % 10 == 0:
            print(f'Epoch is {epoch}')
            print(f'Number of batches {n_batches}')

        d_loss_l = []
        g_loss_l = []
        for index in range(n_batches):
            x = get_x(x_train, index, batch_size)
            d_loss, g_loss = gan.train_both(x)
            d_loss_l.append(d_loss)
            g_loss_l.append(g_loss)
        if epoch % 10 == 0 or epoch == epochs - 1:
            # Sample one batch worth of images; every training batch has
            # batch_size rows, so this does not depend on the loop variable x.
            z = gan.get_z(batch_size)
            w = gan.generator.predict(z, verbose=0)
            save_images(w, output_fold, epoch, 0)
        d_loss_ll.append(d_loss_l)
        g_loss_ll.append(g_loss_l)

    # Keras 3 requires weight filenames to end in ".weights.h5". tf.keras 2.x
    # also accepts this name and picks HDF5 format from the ".h5" suffix.
    gan.generator.save_weights(os.path.join(output_fold, 'generator.weights.h5'), overwrite=True)
    gan.discriminator.save_weights(os.path.join(output_fold, 'discriminator.weights.h5'), overwrite=True)

    # One row per epoch, one column per batch.
    np.savetxt(os.path.join(output_fold, 'd_loss'), d_loss_ll)
    np.savetxt(os.path.join(output_fold, 'g_loss'), g_loss_ll)
    

# Usage

In [None]:
class ARGS:
    """Training configuration consumed by ``train``.

    ``ARGS()`` reproduces the original settings: 16-sample batches, 1000
    epochs, outputs in ``GAN_OUT``, a 10-dimensional latent vector and the
    first 32 MNIST images. Any setting can be overridden by keyword,
    e.g. ``ARGS(epochs=5)``.
    """

    def __init__(self, batch_size=16, epochs=1000, output_fold='GAN_OUT',
                 input_dim=10, n_train=32):
        self.batch_size = batch_size
        self.epochs = epochs
        self.output_fold = output_fold  # folder for sample PNGs, weights and loss logs
        self.input_dim = input_dim      # latent vector length fed to the generator
        self.n_train = n_train          # number of MNIST training images used

# Build the default configuration and run the full training loop with it.
args = ARGS()
train(args=args)