In [1]:
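# Conditional GAN on MNIST, after Mirza & Osindero, "Conditional Generative Adversarial Nets" (2014): https://arxiv.org/pdf/1411.1784.pdf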

In [1]:
from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, BatchNormalization, MaxPooling2D, Activation, ZeroPadding2D, multiply, Embedding
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam, SGD

import matplotlib.pyplot as plt

import sys

import numpy as np

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
class CGAN():
    """Conditional GAN (Mirza & Osindero, 2014) producing class-conditioned MNIST digits.

    The generator maps (noise, class label) to a 28x28 image in [-1, 1] (tanh output).
    The discriminator scores (image, class label) pairs as real (1) or generated (0).
    ``combined_model`` chains generator -> frozen discriminator and is what trains the
    generator.

    Attributes:
        latent_dim: length of the noise vector fed to the generator.
        rows, cols, channels: image geometry (28 x 28 x 1 for MNIST).
        num_classes: number of conditioning classes.
        img_shape: ``(rows, cols, channels)``.
        sgd: SGD optimizer that no model below uses; kept for backward compatibility.
        adam: Adam optimizer used by the discriminator.
        adam_combined: separate Adam optimizer used by the combined (generator) model.
        discriminator, generator, combined_model: the Keras models.
    """
    def __init__(self):
        self.latent_dim = 100
        self.rows = 28
        self.cols = 28
        self.channels = 1
        self.num_classes = 10
        self.img_shape = (self.rows, self.cols, self.channels)

        # Not used by either model; kept so existing code that reads ``self.sgd`` still works.
        self.sgd = SGD(lr=0.0005, decay=2e-5, momentum=0.9, nesterov=True)
        # One optimizer instance per compiled model. A shared instance would have both
        # models advance the same ``iterations`` counter, which drives the ``decay``
        # schedule and Adam's bias correction, so each step would be counted repeatedly.
        self.adam = Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=5e-4)
        self.adam_combined = Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=5e-4)

        self.discriminator = self.discriminator_model()
        self.discriminator.compile(loss='binary_crossentropy', optimizer=self.adam, metrics=["accuracy"])

        self.generator = self.generator_model()
        noise = Input(shape=(self.latent_dim,))
        # int32 to match the label inputs declared by the generator and discriminator.
        label = Input(shape=(1,), dtype='int32')

        generator_output = self.generator([noise, label])
        # Freeze the discriminator only inside the combined model. It was compiled above
        # while trainable, so discriminator.train_on_batch still updates its weights.
        # Keras emits a "Discrepancy between trainable weights" warning for this; it is expected.
        self.discriminator.trainable = False
        discriminator_output = self.discriminator([generator_output, label])

        self.combined_model = Model([noise, label], discriminator_output)
        self.combined_model.compile(loss='binary_crossentropy', optimizer=self.adam_combined, metrics=["accuracy"])

    def generator_model(self):
        """Build the (uncompiled) generator and print its summary.

        Inputs are a noise vector of length ``latent_dim`` and an int32 class label of
        shape (1,). The label is embedded into a ``latent_dim`` vector and multiplied
        element-wise with the noise. Output is an image with ``channels`` channels in
        [-1, 1] (tanh). The spatial size is fixed at 28x28 by the hard-coded 7x7 base
        and two 2x upsamplings; it does not follow ``self.rows``/``self.cols``.
        """
        gen_input = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
        label_input = multiply([label_embedding, gen_input])

        dense_1 = Dense(1024, activation="tanh")(label_input)
        dense_1 = Dense(128 * 7 * 7, activation="tanh")(dense_1)
        dense_1 = BatchNormalization()(dense_1)
        dense_reshape = Reshape((7, 7, 128))(dense_1)

        # 7x7 -> 14x14
        upsample_1 = UpSampling2D(size=2)(dense_reshape)
        conv_1 = Conv2D(64, kernel_size=5, padding="same", activation='tanh')(upsample_1)
        conv_1 = BatchNormalization()(conv_1)

        # 14x14 -> 28x28, one output map per image channel
        upsample_2 = UpSampling2D(size=2)(conv_1)
        out = Conv2D(self.channels, kernel_size=5, padding="same", activation='tanh')(upsample_2)

        model = Model([gen_input, label], out)
        # summary() prints on its own and returns None; print(model.summary()) added a stray "None".
        model.summary()
        return model

    def discriminator_model(self):
        """Build the (uncompiled) discriminator and print its summary.

        Inputs are an image of shape ``img_shape`` and an int32 class label of shape (1,).
        The label is embedded into an ``img_shape``-sized plane and concatenated to the
        image along the channel axis. Output is a single sigmoid probability that the
        (image, label) pair is real.
        """
        dis_input = Input(shape=(self.rows, self.cols, self.channels))
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.num_classes, np.prod(self.img_shape))(label))
        label_embedding = Reshape(self.img_shape)(label_embedding)

        concatenated = Concatenate(axis=-1)([dis_input, label_embedding])

        conv_1 = Conv2D(64, kernel_size=5, padding="same", activation="tanh")(concatenated)
        pool_1 = MaxPooling2D(pool_size=(2, 2))(conv_1)

        conv_2 = Conv2D(128, kernel_size=5, padding="same", activation="tanh")(pool_1)
        pool_2 = MaxPooling2D(pool_size=(2, 2))(conv_2)

        flatten = Flatten()(pool_2)
        flatten = Dense(1024, activation='tanh')(flatten)

        out = Dense(1, activation="sigmoid")(flatten)

        model = Model([dis_input, label], out)
        # summary() prints on its own and returns None.
        model.summary()
        return model

    def train(self, epochs, batch_size=128, save_interval=100):
        """Alternately train the discriminator and generator on MNIST.

        Args:
            epochs: number of passes. Each pass runs ``len(X_train) // batch_size``
                iterations on minibatches drawn at random (with replacement).
            batch_size: minibatch size for both the discriminator and generator steps.
            save_interval: print losses every this many iterations. Despite the name,
                images and the model are saved once per epoch. A value <= 0 disables
                the periodic log.

        Side effects: after every epoch, writes a sample grid to
        ``images/mnist_<epoch>.png`` and saves ``combined_model.h5``.
        """
        (X_train, y_train), (_, _) = mnist.load_data()
        # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh range.
        # Cast first: plain division would produce float64 and double the memory.
        X_train = X_train.astype(np.float32) / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)
        y_train = y_train.reshape(-1, 1)
        print(X_train.shape)

        # Discriminator targets: 1 = real, 0 = generated.
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for i in range(epochs):
            for iter_count in range(X_train.shape[0] // batch_size):
                # Discriminator step: a real batch and a generated batch with the same labels.
                idx = np.random.randint(0, X_train.shape[0], batch_size)
                imgs, labels = X_train[idx], y_train[idx]

                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
                # NOTE(review): predict() runs the generator's BatchNormalization layers in
                # inference mode (moving statistics), while combined_model training below
                # uses batch statistics. If D and G accuracies both sit near 100%, this
                # mismatch is worth investigating.
                gen_imgs = self.generator.predict([noise, labels])

                d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
                d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)
                # [loss, accuracy] averaged over the real and generated batches
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # Generator step: push the frozen discriminator to call random-label fakes "valid".
                sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)
                g_loss = self.combined_model.train_on_batch([noise, sampled_labels], valid)

                if save_interval > 0 and iter_count % save_interval == 0:
                    print("%d %d [D loss: %f, acc.: %.2f%%] [G loss: %f, acc.: %.2f%%]" % (i, iter_count, d_loss[0], 100*d_loss[1], g_loss[0], 100*g_loss[1]))

            self.save_imgs(i)
            self.combined_model.save('combined_model.h5')

    def save_imgs(self, epoch):
        """Save a 5x5 grid of generated digits to ``images/mnist_<epoch>.png``.

        Each cell is generated from fresh noise and a random class label. The label is
        shown as the cell title so the conditioning can be checked visually. The
        ``images`` directory is created if it does not exist.
        """
        import os

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        sampled_labels = np.random.randint(0, self.num_classes, r * c).reshape(-1, 1)
        gen_imgs = self.generator.predict([noise, sampled_labels])

        # Rescale images from [-1, 1] to [0, 1]
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        for cnt, ax in enumerate(axs.flat):
            ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            ax.set_title(str(sampled_labels[cnt, 0]), fontsize=8)
            ax.axis('off')

        # savefig does not create missing directories; a fresh checkout has no images/.
        if not os.path.isdir("images"):
            os.makedirs("images")
        fig.savefig("images/mnist_%d.png" % epoch)
        # Close this specific figure so repeated calls don't accumulate open figures.
        plt.close(fig)
        

In [None]:
if __name__ == '__main__':
    # Build the conditional GAN (this prints both model summaries), then train it.
    # Losses are logged every 50 iterations; after each epoch a sample grid goes to
    # images/ and the combined model is saved to combined_model.h5.
    cond_gan = CGAN()
    cond_gan.train(
        epochs=200,
        batch_size=128,
        save_interval=50,
    )

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, 1)            0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, 1, 784)       7840        input_2[0][0]                    
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 784)          0           embedding_1[0][0]                
__________________________________________________________________________________________________
input_1 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
reshape_1 

  'Discrepancy between trainable weights and collected trainable'


0 0 [D loss: 7.681523, acc.: 48.44%] [G loss: 0.542429, acc.: 72.66%]
0 50 [D loss: 0.000000, acc.: 100.00%] [G loss: 0.148961, acc.: 98.44%]
0 100 [D loss: 0.005460, acc.: 99.61%] [G loss: 0.022843, acc.: 99.22%]
0 150 [D loss: 0.000044, acc.: 100.00%] [G loss: 0.065049, acc.: 97.66%]
0 200 [D loss: 0.028384, acc.: 99.22%] [G loss: 0.000924, acc.: 100.00%]
0 250 [D loss: 0.000282, acc.: 100.00%] [G loss: 0.017913, acc.: 99.22%]
0 300 [D loss: 0.000100, acc.: 100.00%] [G loss: 0.027685, acc.: 99.22%]
0 350 [D loss: 0.000000, acc.: 100.00%] [G loss: 0.000016, acc.: 100.00%]
0 400 [D loss: 0.000003, acc.: 100.00%] [G loss: 0.004705, acc.: 100.00%]
0 450 [D loss: 0.000025, acc.: 100.00%] [G loss: 0.000078, acc.: 100.00%]
1 0 [D loss: 0.000424, acc.: 100.00%] [G loss: 0.013315, acc.: 99.22%]
1 50 [D loss: 0.005553, acc.: 99.61%] [G loss: 0.000383, acc.: 100.00%]
1 100 [D loss: 0.004364, acc.: 100.00%] [G loss: 0.015243, acc.: 100.00%]
1 150 [D loss: 0.001404, acc.: 100.00%] [G loss: 0.0150