In [None]:
from __future__ import division                                                 
from __future__ import print_function                                           
from __future__ import absolute_import 

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from keras.datasets import mnist

from keras.models import Sequential, Model
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers.convolutional import Conv2D, UpSampling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam

In [None]:
class ACGAN(object):
    """Auxiliary-Classifier GAN (ACGAN) for MNIST-sized images, built with Keras.

    Construction builds and compiles three models:

    * ``self.generator``: ``[noise, label] -> image``. The integer label is
      embedded into a ``noise_size`` vector and multiplied element-wise with
      the noise before entering the generator network.
    * ``self.discriminator``: ``image -> [validity, label]``. It has a sigmoid
      real/fake head and a softmax head over ``num_classes + 1`` classes. The
      extra index ``num_classes`` is the label given to generated images
      during training.
    * ``self.adversarial``: the generator followed by the frozen
      discriminator. It is used to train the generator.
    """

    def __init__(self):
        self.configuration()

        self.build_generator()
        self.build_discriminator()
        # The discriminator is compiled here while still trainable;
        # build_adversarial sets trainable=False only afterwards, so the flag
        # affects the adversarial model compiled later, not this standalone one.
        self.compile_models()

        self.build_adversarial()
        self.compile_adversarial()

    def configuration(self):
        """Set the hyperparameters and shapes read by every other method."""
        self.verbose = 0  # 1 -> print Keras model summaries while building
        self.image_width = 28
        self.image_height = 28
        self.image_channels = 1
        self.num_classes = 10

        self.label_size = 1    # one integer class id per sample
        self.noise_size = 100  # latent length; also the label-embedding width
        self.image_shape = (self.image_height, self.image_width, self.image_channels)

        # Adam with learning rate 2e-4 and beta_1 0.5 (passed positionally).
        self.optimizer = Adam(0.0002, 0.5)

    def generator_network(self):
        """Return the Sequential generator body: ``noise_size`` vector -> image.

        A 7x7x128 seed is upsampled twice (to 28x28), so the output spatial size
        is fixed at 28x28 regardless of ``image_width``/``image_height``.
        The final tanh puts pixel values in [-1, 1].
        """
        net = Sequential()
        net.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.noise_size))
        net.add(Reshape((7, 7, 128)))
        net.add(BatchNormalization(momentum=0.8))
        net.add(UpSampling2D())
        net.add(Conv2D(128, kernel_size=3, padding="same"))
        net.add(Activation("relu"))
        net.add(BatchNormalization(momentum=0.8))
        net.add(UpSampling2D())
        net.add(Conv2D(64, kernel_size=3, padding="same"))
        net.add(Activation("relu"))
        net.add(BatchNormalization(momentum=0.8))
        net.add(Conv2D(self.image_channels, kernel_size=3, padding='same'))
        net.add(Activation("tanh"))
        if self.verbose == 1:
            net.summary()
        return net

    def generator_input(self):
        """Return the generator's ``[noise, label]`` Keras Input tensors."""
        noise = Input(shape=(self.noise_size,))
        label = Input(shape=(self.label_size,), dtype='int32')
        return [noise, label]

    def generator_output(self, model, gen_in):
        """Condition the noise on the label and run it through ``model``.

        The label is embedded to ``noise_size`` values, flattened, and
        multiplied element-wise with the noise. NOTE(review): the shapes only
        line up when ``label_size == 1``.
        """
        noise, label = gen_in[0], gen_in[1]
        label_embedding = Flatten()(Embedding(self.num_classes, self.noise_size)(label))
        image_seed = multiply([noise, label_embedding])
        return model(image_seed)

    def build_generator(self):
        """Build the conditional generator Model and store it as ``self.generator``."""
        model = self.generator_network()
        gen_in = self.generator_input()
        gen_out = self.generator_output(model, gen_in)
        self.generator = Model(gen_in, gen_out)

    def discriminator_network(self):
        """Return the Sequential discriminator body: image -> flat feature vector."""
        net = Sequential()
        net.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.image_shape, padding="same"))
        net.add(LeakyReLU(alpha=0.2))
        net.add(Dropout(0.25))
        net.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
        # Pad bottom/right by one pixel (e.g. 7x7 -> 8x8 for 28x28 input) so
        # the next stride-2 conv halves the map evenly.
        net.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
        net.add(LeakyReLU(alpha=0.2))
        net.add(Dropout(0.25))
        net.add(BatchNormalization(momentum=0.8))
        net.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        net.add(LeakyReLU(alpha=0.2))
        net.add(Dropout(0.25))
        net.add(BatchNormalization(momentum=0.8))
        net.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
        net.add(LeakyReLU(alpha=0.2))
        net.add(Dropout(0.25))
        net.add(Flatten())
        if self.verbose == 1:
            net.summary()
        return net

    def discriminator_input(self):
        """Return the discriminator's image Input tensor."""
        return Input(shape=self.image_shape)

    def discriminator_output(self, model, dis_in):
        """Attach the two heads to the shared features.

        Returns:
            ``[validity, label]``. ``validity`` is one sigmoid unit (real vs.
            fake). ``label`` is a softmax over ``num_classes + 1`` classes; the
            last index stands for "generated".
        """
        features = model(dis_in)
        validity = Dense(1, activation="sigmoid")(features)
        label = Dense(self.num_classes + 1, activation="softmax")(features)
        return [validity, label]

    def build_discriminator(self):
        """Build the two-headed discriminator Model and store it as ``self.discriminator``."""
        # discriminator_network already prints its own summary when verbose,
        # so it is not printed a second time here.
        model = self.discriminator_network()
        dis_in = self.discriminator_input()
        dis_out = self.discriminator_output(model, dis_in)
        self.discriminator = Model(dis_in, dis_out)

    def compile_models(self):
        """Compile the generator and the (still trainable) discriminator."""
        gen_losses = ['binary_crossentropy']
        # One loss per discriminator head: validity, then class label.
        dis_losses = ['binary_crossentropy', 'sparse_categorical_crossentropy']
        self.generator.compile(loss=gen_losses, optimizer=self.optimizer)
        self.discriminator.compile(loss=dis_losses, optimizer=self.optimizer, metrics=['accuracy'])

    def adversarial_input(self):
        """Create fresh ``[noise, label]`` inputs and feed them through the generator.

        Returns:
            ``([noise, label], generated_image_tensor)``.
        """
        noise = Input(shape=(self.noise_size,))
        # Same integer dtype as generator_input, since the label feeds an Embedding.
        label = Input(shape=(self.label_size,), dtype='int32')
        image_gen = self.generator([noise, label])
        return [noise, label], image_gen

    def adversarial_output(self, image_gen):
        """Run generated images through the discriminator; return ``[validity, label]``."""
        validity, identification = self.discriminator(image_gen)
        return [validity, identification]

    def compile_adversarial(self):
        """Compile the stacked generator->discriminator model."""
        losses = ['binary_crossentropy', 'sparse_categorical_crossentropy']
        self.adversarial.compile(loss=losses, optimizer=self.optimizer)

    def build_adversarial(self):
        """Stack generator and frozen discriminator into ``self.adversarial``."""
        ad_in, image_gen = self.adversarial_input()
        # Freeze before building/compiling the stacked model so only the
        # generator's weights are updated through it.
        self.discriminator.trainable = False
        ad_out = self.adversarial_output(image_gen)
        self.adversarial = Model(ad_in, ad_out)

    def train(self, epochs, batch_size=128, save_interval=1000):
        """Train discriminator and generator adversarially on MNIST.

        Each iteration ("epoch", per the original naming) does three things:
        it trains the discriminator on ``batch_size // 2`` real images, then on
        as many generated images, and finally trains the generator through
        the frozen discriminator on ``batch_size`` samples.

        Args:
            epochs: number of training iterations.
            batch_size: generator batch size; must be at least 2.
            save_interval: write a sample grid (``save_imgs``) every
                ``save_interval`` iterations. A grid is always written after
                the last iteration. Values <= 0 disable the periodic grids.

        Raises:
            ValueError: if ``batch_size < 2``. The discriminator half batch
                would be empty and the class weights would divide by zero.
        """
        half_batch = batch_size // 2
        if half_batch < 1:
            raise ValueError("batch_size must be at least 2, got %r" % (batch_size,))

        X_train, y_train, _, _ = self.load_data()
        class_weights = self._class_weights(half_batch)

        for epoch in range(epochs):
            self._train_step(epoch, X_train, y_train, batch_size, half_batch, class_weights)
            periodic = save_interval > 0 and epoch % save_interval == 0
            if periodic or epoch == epochs - 1:
                self.save_imgs(epoch)
        self.save_model()

    def _class_weights(self, half_batch):
        """Return per-head class weights ``[validity_weights, label_weights]``.

        The validity head weights both classes equally. On the label head,
        every real class is weighted ``num_classes`` times more than the
        "generated" class ``num_classes``.
        """
        validity_weights = {0: 1, 1: 1}
        label_weights = {i: self.num_classes / half_batch for i in range(self.num_classes)}
        label_weights[self.num_classes] = 1 / half_batch
        return [validity_weights, label_weights]

    def _train_step(self, epoch, X_train, y_train, batch_size, half_batch, class_weights):
        """Run one discriminator update (real + fake) and one generator update, then log."""
        # Discriminator on real images with their true labels.
        idx = np.random.randint(0, X_train.shape[0], half_batch)
        real_imgs = X_train[idx]
        real_labels = y_train[idx]

        # Discriminator on generated images, labelled with the extra "fake" class.
        noise = np.random.normal(0, 1, (half_batch, self.noise_size))
        gen_labels = np.random.randint(0, self.num_classes, half_batch).reshape(-1, 1)
        gen_imgs = self.generator.predict([noise, gen_labels])

        valid = np.ones((half_batch, 1))
        fake = np.zeros((half_batch, 1))
        fake_labels = np.full((half_batch, 1), self.num_classes)

        d_loss_real = self.discriminator.train_on_batch(imgs_or_none(real_imgs), [valid, real_labels], class_weight=class_weights) if False else \
            self.discriminator.train_on_batch(real_imgs, [valid, real_labels], class_weight=class_weights)
        d_loss_fake = self.discriminator.train_on_batch(gen_imgs, [fake, fake_labels], class_weight=class_weights)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Generator: make the frozen discriminator call its images real and of
        # the requested class.
        noise = np.random.normal(0, 1, (batch_size, self.noise_size))
        valid = np.ones((batch_size, 1))
        sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)
        g_loss = self.adversarial.train_on_batch([noise, sampled_labels], [valid, sampled_labels], class_weight=class_weights)

        # NOTE(review): indices 3 and 4 assume Keras reports
        # [loss, validity_loss, label_loss, validity_acc, label_acc] — confirm
        # against self.discriminator.metrics_names for your Keras version.
        print("%d [D loss: %f, acc.: %.2f%%, op_acc: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[3], 100*d_loss[4], g_loss[0]))

    def load_data(self):
        """Load MNIST and preprocess the training split.

        Returns:
            ``(X_train, y_train, X_test, y_test)``. ``X_train`` is float32,
            scaled from 0..255 to [-1, 1] (matching the generator's tanh
            range), and given a trailing channel axis. ``y_train`` is reshaped
            to a column. NOTE(review): the test split is returned unprocessed.
        """
        (X_train, y_train), (X_test, y_test) = mnist.load_data()
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)
        y_train = y_train.reshape(-1, 1)
        return X_train, y_train, X_test, y_test

    def save_imgs(self, epoch):
        """Write a 2x5 grid of generated samples to ``./mnist_<epoch>.png``."""
        rows, cols = 2, 5
        noise = np.random.normal(0, 1, (rows * cols, self.noise_size))
        sampled_labels = (np.arange(rows * cols) % self.num_classes).reshape(-1, 1)

        gen_imgs = self.generator.predict([noise, sampled_labels])
        gen_imgs = 0.5 * gen_imgs + 0.5  # tanh range [-1, 1] -> [0, 1] for imshow

        fig, axs = plt.subplots(rows, cols)
        fig.suptitle("ACGAN: Generated digits", fontsize=12)
        # axs.flat walks the grid row by row, matching the label order.
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            ax.set_title("Digit: %d" % int(sampled_labels[cnt, 0]))
            ax.axis('off')
        fig.savefig("./mnist_%d.png" % epoch)
        plt.close(fig)

    def save_model(self):
        """Save architecture (JSON) and weights (HDF5) of all three models to ``./``."""
        def save(model, model_name):
            model_path = "./%s.json" % model_name
            weights_path = "./%s_weights.hdf5" % model_name
            with open(model_path, 'w') as arch_file:
                arch_file.write(model.to_json())
            model.save_weights(weights_path)

        save(self.generator, "mnist_acgan_generator")
        save(self.discriminator, "mnist_acgan_discriminator")
        save(self.adversarial, "mnist_acgan_adversarial")


In [None]:
# Hyperparameters for the training run below.
EPOCHS = 6000
BATCH_SIZE = 32
SAVE_INTERVAL = 50

gan = ACGAN()
gan.train(epochs=EPOCHS, batch_size=BATCH_SIZE, save_interval=SAVE_INTERVAL)