In [None]:
from keras import backend, models
from keras.layers import Dense, Conv1D, Reshape, Flatten, Lambda
from keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt

# 연쇄 방식 (Chained) GAN Modeling - OOP: generator g feeds discriminator d (gd = g -> d)

In [None]:
def add_decorate(x):
    """Append each feature's squared deviation from its row mean.

    The mean is taken over the last axis (keepdims, so it broadcasts back).
    The result is ``x`` followed by ``(x - mean)**2`` along the last axis,
    so the last axis doubles in length.
    """
    centered = x - backend.mean(x, axis=-1, keepdims=True)
    return backend.concatenate([x, backend.square(centered)], axis=-1)

def add_decorate_shape(input_shape):
    """Return the output shape of ``add_decorate`` for a rank-2 input shape.

    The second (last) dimension is doubled; the first is passed through.
    Raises AssertionError if ``input_shape`` does not have exactly two entries.
    """
    dims = list(input_shape)
    assert len(dims) == 2
    # add_decorate concatenates x with a same-width tensor on the last axis.
    dims[1] = dims[1] * 2
    return tuple(dims)

# Adam hyper-parameters used by this notebook's models.
LEARNING_RATE = 2e-4
# NOTE(review): this single optimizer instance is shared by every model that
# model_compile() compiles; newer Keras optimizers reject being reused across
# different sets of variables, so one instance per model is safer.
adam = Adam(beta_2=0.999, beta_1=0.9, learning_rate=LEARNING_RATE)

def _new_adam():
    """Return a new Adam optimizer using this notebook's hyper-parameters."""
    return Adam(learning_rate=LEARNING_RATE, beta_1=0.9, beta_2=0.999)


def model_compile(model, optimizer=None):
    """Compile ``model`` with binary cross-entropy loss and an accuracy metric.

    Args:
        model: Keras model to compile (compiled in place).
        optimizer: optimizer to use. When omitted, a fresh Adam instance is
            created for this model.

    Returns:
        The same ``model``, for convenience (``Model.compile`` returns None).
    """
    if optimizer is None:
        # One optimizer per model: optimizers keep per-variable state, and a
        # single instance shared by d, g and gd is rejected by newer Keras
        # optimizers (TF >= 2.11 / Keras 3).
        optimizer = _new_adam()
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model

class GAN:
    """Chained GAN: generator ``g``, discriminator ``d``, and stacked model ``gd``.

    ``gd`` is ``g`` followed by ``d``; training ``gd`` with target 1 pushes
    ``g`` towards samples that ``d`` classifies as real.

    Args:
        ni_d: length of each sample vector (generator input and output,
            discriminator input).
        nh_d: width of the discriminator's hidden Dense layers.
        nh_g: number of filters in the generator's hidden Conv1D layers.
    """

    def __init__(self, ni_d, nh_d, nh_g):
        self.ni_d = ni_d
        self.nh_d = nh_d
        self.nh_g = nh_g

        self.d = self.gen_d()
        self.g = self.gen_g()
        self.gd = self.make_gd()

    def gen_d(self):
        """Build and compile the discriminator.

        Input shape (batch, ni_d); output (batch, 1) sigmoid score for "real".
        """
        ni_d = self.ni_d
        # Hidden width is nh_d (this previously read self.nh_g by mistake).
        nh_d = self.nh_d

        d = models.Sequential()
        # add_decorate appends each feature's squared deviation from the row
        # mean (doubling the width), giving d direct access to the spread.
        d.add(Lambda(add_decorate, output_shape=add_decorate_shape, input_shape=(ni_d,)))
        d.add(Dense(nh_d, activation='relu'))
        d.add(Dense(nh_d, activation='relu'))
        d.add(Dense(1, activation='sigmoid'))

        model_compile(d)
        return d

    def gen_g(self):
        """Build and compile the generator.

        Kernel-size-1 Conv1D layers map each of the ni_d input positions to
        one output value independently; output shape is (batch, ni_d).
        """
        ni_d = self.ni_d
        nh_g = self.nh_g

        g = models.Sequential()
        # View the vector as a length-ni_d sequence with a single channel.
        g.add(Reshape((ni_d, 1), input_shape=(ni_d,)))
        g.add(Conv1D(nh_g, 1, activation='relu'))
        g.add(Conv1D(nh_g, 1, activation='sigmoid'))
        g.add(Conv1D(1, 1))
        g.add(Flatten())

        model_compile(g)
        return g

    def make_gd(self):
        """Stack g and d into the model used to train the generator.

        d is frozen only while gd is compiled (the usual Keras GAN pattern) so
        that gd's updates are meant to reach g alone; it is unfrozen again so
        d can still be trained on its own.
        """
        g, d = self.g, self.d
        gd = models.Sequential()
        gd.add(g)
        gd.add(d)
        d.trainable = False
        model_compile(gd)
        d.trainable = True
        return gd

    def d_train_on_batch(self, real, gen):
        """Run one discriminator update: ``real`` rows labelled 1, ``gen`` rows 0.

        Returns the value of ``d.train_on_batch`` (loss and metrics).
        """
        x = np.concatenate([real, gen], axis=0)
        y = np.concatenate([np.ones(real.shape[0]), np.zeros(gen.shape[0])])
        # Make sure d is trainable for its own update, whatever the caller did.
        self.d.trainable = True
        return self.d.train_on_batch(x, y)

    def gd_train_on_batch(self, z):
        """Run one generator update through ``gd`` with every target set to 1.

        Returns the value of ``gd.train_on_batch`` (loss and metrics).
        """
        y = np.ones(z.shape[0])
        # Keep d frozen so this step is meant to update only g.
        self.d.trainable = False
        return self.gd.train_on_batch(z, y)

# Data Generator

In [None]:
class DataGenerator:
    """Samplers for real data and generator input.

    ``real_sample(n)`` draws an (n, ni_d) array from N(mu, sigma);
    ``in_sample(n)`` draws an (n, ni_d) array uniformly from [0, 1).
    Both use numpy's global random state.
    """

    def __init__(self, mu, sigma, ni_d):
        self._mu = mu
        self._sigma = sigma
        self._ni_d = ni_d

    def real_sample(self, n_batch):
        shape = (n_batch, self._ni_d)
        return np.random.normal(self._mu, self._sigma, shape)

    def in_sample(self, n_batch):
        return np.random.rand(n_batch, self._ni_d)

# Training and Evaluation

In [None]:
class Machine:
    """Train a GAN to imitate normally distributed data and plot its progress.

    Args:
        n_batch: samples per training batch.
        ni_d: length of each sample vector.
        n_iter_d: discriminator updates per training step.
        n_iter_g: generator (stacked ``gd``) updates per training step.
        data_mean: mean of the real-data normal distribution.
        data_stddev: standard deviation of the real-data normal distribution.
        nh_d: discriminator hidden width passed to GAN.
        nh_g: generator hidden width passed to GAN.
    """

    def __init__(self, n_batch=10, ni_d=100, n_iter_d=1, n_iter_g=5,
                 data_mean=4, data_stddev=1.25, nh_d=50, nh_g=50):
        self.n_iter_d = n_iter_d
        self.n_iter_g = n_iter_g
        self.n_batch = n_batch

        self.data = DataGenerator(data_mean, data_stddev, ni_d)
        self.gan = GAN(ni_d=ni_d, nh_d=nh_d, nh_g=nh_g)

    @staticmethod
    def print_stat(real, gen):
        """Print (mean, std) of the real and generated samples."""
        def stat(d):
            return (np.mean(d), np.std(d))
        print(f'Mean and Std of Real: {stat(real)}')
        print(f'Mean and Std of Gen: {stat(gen)}')

    def train_d(self):
        """One discriminator update on a fresh real batch and a generated batch."""
        real = self.data.real_sample(self.n_batch)

        z = self.data.in_sample(self.n_batch)
        # verbose=0: recent Keras versions otherwise print a progress bar on
        # every predict() call, flooding the notebook during training.
        gen = self.gan.g.predict(z, verbose=0)

        # Unfreeze d for its own update; train_gd() leaves it frozen.
        self.gan.d.trainable = True
        self.gan.d_train_on_batch(real, gen)

    def train_gd(self):
        """One generator update through the stacked model with d frozen."""
        z = self.data.in_sample(self.n_batch)

        self.gan.d.trainable = False
        self.gan.gd_train_on_batch(z)

    def train_each(self):
        """One training step: n_iter_d discriminator then n_iter_g generator updates."""
        for _ in range(self.n_iter_d):
            self.train_d()
        for _ in range(self.n_iter_g):
            self.train_gd()

    def train(self, epochs):
        """Run ``epochs`` training steps."""
        for _ in range(epochs):
            self.train_each()

    def test(self, n_test):
        """Generate ``n_test`` samples; returns (generated, generator input)."""
        z = self.data.in_sample(n_test)
        gen = self.gan.g.predict(z, verbose=0)
        return gen, z

    def show_hist(self, real, gen, z):
        """Overlay step histograms of real, generated and input values.

        Draws on the current matplotlib axes and returns them.
        """
        ax = plt.gca()
        ax.hist(real.reshape(-1), histtype='step', label='Real')
        ax.hist(gen.reshape(-1), histtype='step', label='Generated')
        ax.hist(z.reshape(-1), histtype='step', label='Input')
        ax.set_xlabel('Value')
        ax.set_ylabel('Count')
        ax.set_title('Real vs. generated vs. input distribution')
        ax.legend(loc=0)
        return ax

    def test_and_show(self, n_test):
        """Sample ``n_test`` generated and real values, plot them and print stats."""
        gen, z = self.test(n_test)
        real = self.data.real_sample(n_test)
        self.show_hist(real, gen, z)
        Machine.print_stat(real, gen)

    def run_epochs(self, epochs, n_test):
        """Train for ``epochs`` steps, then evaluate on ``n_test`` samples."""
        self.train(epochs)
        self.test_and_show(n_test)

    def run(self, n_repeat=200, n_show=200, n_test=100):
        """Train for n_repeat * n_show steps, plotting after every n_show steps."""
        for ii in range(n_repeat):
            print(f'Stage {ii}, (Epoch: {ii * n_show})')
            self.run_epochs(n_show, n_test)
            plt.show()

# Usage

In [None]:
# Seed numpy's global RNG so the sampled real data and generator inputs are
# reproducible across runs.
# NOTE(review): Keras weight initialisation presumably uses its own RNG and is
# not covered by this seed — consider keras.utils.set_random_seed if available.
SEED = 0
np.random.seed(SEED)

# run() trains for N_REPEAT * N_SHOW steps in total (40,000 here), plotting a
# histogram after every N_SHOW steps, so this cell takes a long time.
N_REPEAT, N_SHOW, N_TEST = 200, 200, 100
machine = Machine(n_batch=1, ni_d=100)
machine.run(n_repeat=N_REPEAT, n_show=N_SHOW, n_test=N_TEST)