In [37]:
import numpy as np
from numpy import zeros
from numpy import ones
from numpy import expand_dims
from numpy import hstack
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.utils import to_categorical
from keras.models import Model
from keras import layers
import tensorflow_datasets as tfds
import tensorflow as tf
class InfoGAN:
    """Builders for an InfoGAN: generator, discriminator/Q pair, composite model,
    plus tensorflow_datasets loaders and noise samplers.

    Args:
        n_cat: number of values of the categorical control code.
        latent_dim: size of the unstructured Gaussian noise vector.
        batch_size: batch size used by the dataset loaders.
    """

    def __init__(self, n_cat=10, latent_dim=62, batch_size=128):
        # Number of values for the categorical control code.
        self.n_cat = n_cat
        # Size of the unstructured latent (noise) space.
        self.latent_dim = latent_dim
        # Generator input = noise vector concatenated with the one-hot code.
        self.gen_input_size = self.latent_dim + self.n_cat
        self.batch_size = batch_size

    def normalize_img(self, image, label):
        """Scale *image* to float32 in [0, 1] and replace *label* with 1.

        Every dataset example is a real image, so the original class label is
        discarded and replaced by the "real" target 1 (fakes from
        ``generate_noise_samples`` are labelled 0). Assumes pixel values in
        [0, 255].
        """
        return tf.cast(image, tf.float32) / 255., 1

    def _load_dataset(self, name):
        """Load the tfds ``train`` split of *name*, normalized, shuffled and batched.

        Returns:
            ``(ds, ds_info, shape)``: ``ds`` yields batches of ``(image, 1)``
            pairs of size ``self.batch_size``; ``shape`` is the per-image shape
            reported by ``ds_info``.
        """
        ds, ds_info = tfds.load(name, split='train', with_info=True, as_supervised=True)
        shape = ds_info.features['image'].shape
        # Shuffle the raw examples before the float32 cast so the full-dataset
        # shuffle buffer holds the compact original pixels, not float copies.
        ds = ds.shuffle(ds_info.splits['train'].num_examples)
        ds = ds.map(self.normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        ds = ds.batch(self.batch_size)
        ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
        return ds, ds_info, shape

    def load_MNIST(self):
        """Return ``(ds, ds_info, shape)`` for the batched MNIST train split."""
        return self._load_dataset('mnist')

    def load_SVHN(self):
        """Return ``(ds, ds_info, shape)`` for the batched cropped-SVHN train split."""
        return self._load_dataset('svhn_cropped')

    def load_CELEBA(self):
        """Not implemented.

        Access to CelebA through tensorflow_datasets appears to be broken
        because the dataset is hosted on Google Drive.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "CelebA loading is not implemented: the tensorflow_datasets download "
            "(hosted on Google Drive) appears to be broken")

    def show_examples(self, ds, ds_info):
        """Plot a grid of examples from *ds* with ``tfds.show_examples`` and return the figure."""
        # NOTE(review): tfds documents that examples should not be batched; the
        # datasets returned by the load_* methods are batched, so callers
        # likely need to pass ``ds.unbatch()`` — confirm.
        return tfds.show_examples(ds, ds_info)

    def get_generator(self, shape):
        """Build the generator mapping a ``gen_input_size`` vector to an image.

        Args:
            shape: target image shape ``(height, width, channels)``. It must be
                square with a side divisible by 4, because the network starts
                at ``side // 4`` and upsamples twice by a factor of 2.

        Returns:
            An uncompiled ``Model`` whose sigmoid output lies in [0, 1], the same
            range ``normalize_img`` produces for real images.

        Raises:
            ValueError: if *shape* is not square or its side is not divisible by 4.
        """
        side = int(shape[0])
        depth = int(shape[-1])
        if side != int(shape[1]) or side % 4:
            raise ValueError(
                f"generator needs a square image with side divisible by 4, got {tuple(shape)}")
        # Two stride-2 transposed convolutions take res -> 4 * res == side.
        res = side // 4

        input_latent = layers.Input(shape=(self.gen_input_size,))

        g = layers.Dense(512 * res * res)(input_latent)
        g = layers.ReLU()(g)
        g = layers.BatchNormalization()(g)
        g = layers.Reshape((res, res, 512))(g)

        g = layers.Conv2D(128, 4, padding='same')(g)
        g = layers.ReLU()(g)
        g = layers.BatchNormalization()(g)

        g = layers.Conv2DTranspose(64, 4, strides=(2, 2), padding='same')(g)
        g = layers.ReLU()(g)
        g = layers.BatchNormalization()(g)
        g = layers.Conv2DTranspose(depth, 4, strides=(2, 2), padding='same')(g)
        # Sigmoid output keeps generated pixels in [0, 1].
        output_layer = layers.Activation('sigmoid')(g)
        return Model(input_latent, output_layer)

    def get_discr_q(self, input_shape):
        """Build the discriminator and the auxiliary Q network on one shared conv trunk.

        Args:
            input_shape: image shape ``(height, width, channels)``.

        Returns:
            ``(discr_model, q_model)``. ``discr_model`` outputs a sigmoid
            real/fake probability and is compiled with binary cross-entropy.
            ``q_model`` outputs a softmax over the ``n_cat`` code values and is
            not compiled here (``infogan`` attaches its categorical loss).
        """
        input_image = layers.Input(shape=input_shape)

        l = layers.Conv2D(64, 4, strides=(2, 2), padding='same')(input_image)
        l = layers.ReLU()(l)

        l = layers.Conv2D(128, 4, strides=(2, 2), padding='same')(l)
        l = layers.ReLU()(l)
        l = layers.BatchNormalization()(l)

        l = layers.Conv2D(256, 4, strides=(2, 2), padding='same')(l)
        l = layers.ReLU()(l)
        l = layers.BatchNormalization()(l)

        l = layers.Flatten()(l)
        # Real/fake head of the discriminator.
        out = layers.Dense(1, activation='sigmoid')(l)
        discr_model = Model(input_image, out)
        discr_model.compile(loss='binary_crossentropy',
                            optimizer=Adam(learning_rate=0.001, beta_1=0.9))
        # Q head: branches off the same flattened features as the discriminator.
        q = layers.Dense(128)(l)
        q = layers.BatchNormalization()(q)
        q = layers.LeakyReLU(alpha=0.1)(q)
        out_codes = layers.Dense(self.n_cat, activation='softmax')(q)
        q_model = Model(input_image, out_codes)
        return discr_model, q_model

    def infogan(self, gen_model, discr_model, q_model):
        """Stack the generator on the discriminator and Q heads for generator updates.

        Returns:
            A compiled ``Model`` from the generator input to
            ``[real/fake probability, code softmax]`` with binary and
            categorical cross-entropy losses respectively.
        """
        # Freeze discriminator layers (its conv trunk is shared with q_model);
        # BatchNormalization layers are skipped and keep their trainable flag.
        # The Q-only Dense/BN layers are not in discr_model, so they stay trainable.
        # NOTE(review): discr_model was compiled before these flags change; whether
        # its own training still updates these weights depends on the Keras
        # version's compile semantics for `trainable` — confirm.
        for layer in discr_model.layers:
            if not isinstance(layer, layers.BatchNormalization):
                layer.trainable = False
        fake_images = gen_model.output
        discr_output = discr_model(fake_images)
        q_output = q_model(fake_images)
        infogan_model = Model(gen_model.input, [discr_output, q_output])
        opt = Adam(learning_rate=0.001, beta_1=0.9)
        infogan_model.compile(loss=['binary_crossentropy', 'categorical_crossentropy'],
                              optimizer=opt)
        return infogan_model

    def generate_noise(self, n_samples):
        """Sample *n_samples* generator inputs.

        Returns:
            ``[z_input, cat_codes]``: ``z_input`` has shape
            ``(n_samples, gen_input_size)`` (Gaussian noise followed by a one-hot
            code); ``cat_codes`` is that ``(n_samples, n_cat)`` one-hot code alone.
        """
        z_latent = randn(n_samples, self.latent_dim)
        # Uniformly chosen category per sample, one-hot encoded via identity rows.
        cat_codes = np.eye(self.n_cat)[np.random.choice(self.n_cat, n_samples)]
        z_input = hstack((z_latent, cat_codes))
        return [z_input, cat_codes]

    def generate_noise_samples(self, generator, n_samples):
        """Generate *n_samples* fake images with *generator*, labelled 0 ("fake").

        Returns:
            ``(images, y)`` where ``images`` is ``generator.predict`` on fresh
            noise and ``y`` is an ``(n_samples, 1)`` array of zeros.
        """
        z_input, _ = self.generate_noise(n_samples)
        images = generator.predict(z_input)
        y = zeros((n_samples, 1))
        return images, y

In [38]:
# Create the InfoGAN helper and load the batched SVHN training split.
gan = InfoGAN()
ds, ds_info, shape = gan.load_SVHN()

# Build the generator and the discriminator/Q pair for that image shape.
gen_test = gan.get_generator(shape)
discr_test, q_test = gan.get_discr_q(shape)

In [23]:
# Compose generator + (frozen) discriminator + Q head into the composite model.
info_test = gan.infogan(gen_model=gen_test, discr_model=discr_test, q_model=q_test)

In [35]:
# load_SVHN already batches to gan.batch_size, so reuse the dataset as-is:
# batching again would nest batches into 5-D tensors that the discriminator
# (which takes 4-D image batches) cannot consume.
ds_train = ds


In [29]:
# Inspect a single batch instead of materializing the whole normalized
# (float32) training set in memory.
test = list(ds.take(1).as_numpy_iterator())


In [48]:
# Smoke test: run a single training step of the discriminator on ds_train.
discr_test.fit(x=ds_train, steps_per_epoch=1)




<tensorflow.python.keras.callbacks.History at 0x7ffc72a3fed0>