In [10]:
import numpy as np
from numpy import zeros
from numpy import ones
from numpy import expand_dims
from numpy import hstack
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.utils import to_categorical
from keras.models import Model
from keras import layers
import tensorflow_datasets as tfds
import tensorflow as tf
from matplotlib import pyplot

class InfoGAN:
    def __init__(self):
        
        # number of values for the categorical control code
        self.n_cat = 10
        # size of the latent space
        self.latent_dim = 62
        #Define size of generator input
        self.gen_input_size = self.latent_dim + self.n_cat
        self.batch_size=32
        self.kernel_init = RandomNormal(stddev=0.02)

    def normalize_img(self,image,label):
        label=1
        return tf.cast(image, tf.float32) / 255.,label
         
    def load_MNIST(self):
        ds, ds_info = tfds.load('mnist', split='train', with_info=True,as_supervised=True)
        shape=ds_info.features['image'].shape
        ds = ds.map(self.normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds = ds.cache()
        # For true randomness, we set the shuffle buffer to the full dataset size.
        ds = ds.shuffle(ds_info.splits['train'].num_examples)
        # Batch after shuffling to get unique batches at each epoch.
        ds = ds.batch(self.batch_size)
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
        return ds,ds_info, shape
    
    def load_SVHN(self):
        ds, ds_info = tfds.load('svhn_cropped', split='train', with_info=True,as_supervised=True)
        shape=ds_info.features['image'].shape
        ds = ds.map(self.normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds = ds.cache()
        # For true randomness, we set the shuffle buffer to the full dataset size.
        ds = ds.shuffle(ds_info.splits['train'].num_examples)
        # Batch after shuffling to get unique batches at each epoch.
        ds = ds.batch(self.batch_size)
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
        return ds,ds_info, shape
    
    def load_CELEBA(self):
        #To Do, access via tensorflow datasets seems to be broken due to the ds being hosted on Drive
        pass

    def show_real_examples(self,ds, ds_info):
        fig = tfds.show_examples(ds, ds_info)

    #Define the generator model:
    def get_generator(self,shape):
        input_latent = layers.Input(shape=(self.gen_input_size))
        
        #The various datasets have various sizes and #channels so the dimension of the generator needs to be variable
        res=int(shape[0]/4)
        depth=int(shape[-1])
        
        n_nodes = 128 * res * res

        g = layers.Dense(n_nodes)(input_latent)
        g = layers.ReLU()(g)
        g = layers.BatchNormalization()(g)
        g = layers.Reshape((res, res, 128))(g)

        g = layers.Conv2D(64, 4, padding='same',activation='relu',kernel_initializer=self.kernel_init)(g)
        g = layers.ReLU()(g)
        g = layers.BatchNormalization()(g)

        g = layers.Conv2DTranspose(32, 4, strides=(2,2), padding='same',kernel_initializer=self.kernel_init)(g)
        g = layers.ReLU()(g)
        g = layers.BatchNormalization()(g)
        g = layers.Conv2DTranspose(depth, 4, strides=(2,2), padding='same',kernel_initializer=self.kernel_init)(g)
        # Sigmoid output
        output_layer = layers.Activation('sigmoid')(g)
        # define the generator model
        gen_model = Model(input_latent, output_layer)
        return gen_model
    
    
    def get_discr_q(self,input_shape):
        input_image=layers.Input(shape=input_shape)

        l=layers.Conv2D(32, 4, strides=(2,2),padding='same',kernel_initializer=self.kernel_init)(input_image)
        l=layers.ReLU()(l)

        l=layers.Conv2D(64, 4, strides=(2,2), padding='same',kernel_initializer=self.kernel_init)(l)
        l=layers.ReLU()(l)
        l=layers.BatchNormalization()(l)

        l=layers.Conv2D(128, 4, strides=(2,2), padding='same',kernel_initializer=self.kernel_init)(l)
        l=layers.ReLU()(l)
        l=layers.BatchNormalization()(l)

        l=layers.Flatten()(l)
        #Classification head of the discriminator
        out=layers.Dense(1,activation='sigmoid')(l)
        discr_model = Model(input_image, out)
        discr_model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5))
        # create q model layers
        q = layers.Dense(128)(l)
        q = layers.BatchNormalization()(q)
        q = layers.LeakyReLU(alpha=0.1)(q)
        # q model output
        out_codes = layers.Dense(self.n_cat, activation='softmax')(q)
        # define q model
        q_model = Model(input_image, out_codes)
        return discr_model, q_model
    
    def infogan(self,gen_model, discr_model, q_model):
        # make weights in the discriminator (some shared with the q model) as not trainable
        for layer in discr_model.layers:
            if not isinstance(layer, layers.BatchNormalization):
                layer.trainable = False
        # connect g outputs to d inputs
        discr_output = discr_model(gen_model.output)
        # connect g outputs to q inputs
        q_output = q_model(gen_model.output)
        # define composite model
        infogan_model = Model(gen_model.input, [discr_output, q_output])
        # compile model
        opt = Adam(lr=0.0002, beta_1=0.5)
        infogan_model.compile(loss=['binary_crossentropy', 'categorical_crossentropy'], optimizer=opt)
        return infogan_model
    

    # generate points in latent space as input for the generator
    def generate_noise(self, n_samples):
        # generate points in the latent space
        z_latent=randn(n_samples,self.latent_dim)
        # generate categorical one-hot codes
        cat_codes=np.eye(self.n_cat)[np.random.choice(self.n_cat, n_samples)]
        # concatenate latent points and control codes
        z_input = hstack((z_latent, cat_codes))
        return [z_input, cat_codes]
    
    # use the generator to generate n fake examples, with class labels
    def generate_fake_samples(self,generator, n_samples):
        # generate points in latent space and control codes
        z_input, _ = self.generate_noise(n_samples)
        # predict outputs
        images = generator.predict(z_input)
        # create class labels
        y = zeros((n_samples, 1))
        return images, y
    
    def summarize_performance(self,step, gen_model, gan_model, n_samples=100):
        # prepare fake examples
        X, _ = self.generate_fake_samples(gen_model, n_samples)

        # plot images
        for i in range(n_samples):
            # define subplot
            pyplot.subplot(10, 10, 1 + i)
            # turn off axis
            pyplot.axis('off')
            # plot raw pixel data
            #pyplot.imshow(X[i, :, :, 0], cmap='gray_r')
            pyplot.imshow(X[i, :, :, :])
        # save plot to file
        filename1 = 'generated_plot_%04d.png' % (step+1)
        pyplot.savefig(filename1)
        pyplot.close()
        # save the generator model
        filename2 = 'model_%04d.h5' % (step+1)
        gen_model.save(filename2)
        # save the gan model
        filename3 = 'gan_model_%04d.h5' % (step+1)
        gan_model.save(filename3)
        print('>Saved: %s, %s, and %s' % (filename1, filename2, filename3))
        
    # train the generator and discriminator
    def train(self,gen_model, discr_model, infogan_model, dataset, n_epochs=100):
        # calculate the number of batches per training epoch
        bat_per_epo = len(ds)
        # calculate the number of training iterations
        n_steps = bat_per_epo * n_epochs
        print(n_steps)
        # manually enumerate epochs
        for i in range(n_steps):
            # update discriminator and q model weights
            d_loss1 = discr_model.fit(ds,steps_per_epoch=1,verbose=0)
            # generate 'fake' examples
            X_fake, y_fake = self.generate_fake_samples(gen_model, self.batch_size)
            # update discriminator model weights
            d_loss2 = discr_model.fit(X_fake, y_fake,verbose=0,batch_size=self.batch_size)
            # prepare points in latent space as input for the generator
            z_input, cat_codes = self.generate_noise(self.batch_size)
            # create inverted labels for the fake samples
            y_gan = ones((self.batch_size, 1))
            # update the g via the d and q error
            gan_loss = infogan_model.fit(z_input, [y_gan, cat_codes],verbose=0,batch_size=self.batch_size)
            # summarize loss on this batch
            print('>%d, d[%.3f,%.3f], g[%.3f] q[%.3f]' % (i+1, d_loss1.history['loss'][0], d_loss2.history['loss'][0], list(gan_loss.history.items())[1][1][0],list(gan_loss.history.items())[2][1][0]))
            # evaluate the model performance every 'epoch'
            if (i+1) % (bat_per_epo * 10) == 0:
                self.summarize_performance(i, gen_model, infogan_model)

In [11]:
# Create the InfoGAN helper and load the SVHN training pipeline
gan = InfoGAN()
ds, ds_info, shape = gan.load_SVHN()

# Generator, discriminator and Q network sized from the dataset's image shape
gen_model = gan.get_generator(shape)
discr_model, q_model = gan.get_discr_q(shape)

# Composite model that routes generator output through the discriminator and Q model
infogan_model = gan.infogan(gen_model, discr_model, q_model)

In [13]:
# Run one epoch of training on the dataset loaded above
gan.train(
    gen_model,
    discr_model,
    infogan_model,
    ds,
    n_epochs=1,
)

2290
>1, d[0.364,1.022], g[0.948] q[0.160]
>2, d[0.307,0.542], g[1.481] q[0.056]
>3, d[0.789,0.516], g[1.059] q[0.158]
>4, d[0.495,0.699], g[1.445] q[0.163]
>5, d[0.439,0.639], g[1.404] q[0.161]
>6, d[0.671,0.637], g[1.319] q[0.089]
>7, d[0.482,0.452], g[1.412] q[0.173]
>8, d[0.652,0.526], g[1.441] q[0.052]
>9, d[0.382,0.464], g[1.440] q[0.071]
>10, d[0.573,0.475], g[1.268] q[0.094]
>11, d[0.452,0.535], g[1.403] q[0.066]
>12, d[0.539,0.511], g[1.358] q[0.110]
>13, d[0.668,0.450], g[1.189] q[0.190]
>14, d[0.416,0.540], g[1.334] q[0.055]
>15, d[0.778,0.839], g[1.294] q[0.145]
>16, d[0.722,0.774], g[1.187] q[0.500]
>17, d[0.356,0.498], g[1.388] q[0.210]
>18, d[0.507,0.399], g[1.665] q[0.080]
>19, d[0.874,0.543], g[1.126] q[0.076]
>20, d[0.454,0.506], g[1.514] q[0.118]
>21, d[0.264,0.420], g[1.208] q[0.082]
>22, d[0.776,0.659], g[1.225] q[0.125]
>23, d[0.360,0.452], g[1.078] q[0.141]
>24, d[0.881,0.597], g[1.220] q[0.060]
>25, d[0.675,0.768], g[1.314] q[0.081]
>26, d[0.811,1.029], g[0.988]

KeyboardInterrupt: 