## Capsule GAN 

In [None]:
# math library
import numpy as np

# Tensorflow library
import tensorflow as tf
# TF 1.x tutorial loader; the training loop further down draws its real
# image batches from `mnist1.train.next_batch`.
from tensorflow.examples.tutorials.mnist import input_data
mnist1 = input_data.read_data_sets('MNIST Dataset/')
# Clear the default graph before any placeholders/layers are created
# (presumably so that re-running the notebook does not pile new ops onto a
# previously built graph).
tf.reset_default_graph()
from keras.layers.advanced_activations import LeakyReLU
from keras.layers import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers import Lambda, Reshape
import matplotlib.pyplot as plt

from keras.datasets import mnist, cifar10

# Device check: list every device TensorFlow can see
from tensorflow.python.client import device_lib
print('Devices: ', device_lib.list_local_devices())

# GPU check: a non-empty device name means a GPU is available
if tf.test.gpu_device_name():
    print('Default GPU Device: {}' .format(tf.test.gpu_device_name()))
else:
    print('No GPU found')
    
# Training hyper-parameters
batch_size = 32   # samples drawn per training step
n_noise = 100     # length of the latent vector fed to the generator

# Graph inputs: real images (28x28, single channel) and latent noise
X_in = tf.placeholder(tf.float32, [None, 28, 28, 1], name='X')
noise = tf.placeholder(tf.float32, [None, n_noise])

# Run-time switches supplied through feed_dict by the training loop
keep_prob = tf.placeholder(tf.float32, name='keep_prob')
is_training = tf.placeholder(tf.bool, name='is_training')

def binary_cross_entropy(x, z):
    """Element-wise binary cross-entropy between targets `x` and predictions `z`.

    A small constant is added inside each logarithm so that a prediction of
    exactly 0 or 1 never evaluates `log(0)`. No reduction is applied: the
    caller decides how to average the returned tensor.
    """
    eps = 1e-12
    positive_term = x * tf.log(z + eps)
    negative_term = (1. - x) * tf.log(1. - z + eps)
    return -(positive_term + negative_term)


## Loading The Data

In [None]:
def load_dataset(dataset, width, height, channels):
    """Load the training images of `dataset`, rescaled to the range [-1, 1].

    Args:
        dataset: Name of the dataset to load. Only 'mnist' is supported.
        width: First entry of the returned image-shape list.
        height: Second entry of the returned image-shape list.
        channels: Third entry of the returned image-shape list.

    Returns:
        A tuple `(X_train, img_shape)`: `X_train` is a float32 array with a
        trailing channel axis added, and `img_shape` is the list
        `[width, height, channels]` exactly as passed in (the arguments are
        not used to resize or validate the data).

    Raises:
        ValueError: If `dataset` is not a supported dataset name.
    """
    # Fail loudly instead of falling through to an unbound `X_train`.
    if dataset != 'mnist':
        raise ValueError("Unsupported dataset: {!r} (expected 'mnist')".format(dataset))

    # Only the training images are needed; labels and the test split are dropped.
    (X_train, _), _ = mnist.load_data()

    # Rescale pixel values from [0, 255] to [-1, 1]
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    # Add the channel axis: (N, H, W) -> (N, H, W, 1)
    X_train = np.expand_dims(X_train, axis=3)

    img_shape = [width, height, channels]
    return X_train, img_shape

# Load MNIST and report the array shape next to the declared image shape
dataset_, shape_ = load_dataset(dataset='mnist', width=28, height=28, channels=1)
print("Dataset shape: " + str(dataset_.shape) + " Image shape: " + str(shape_))

## Squash Function

In [None]:
def squash(vectors, axis = 1):
    """Capsule "squash" non-linearity applied along `axis`.

    Each vector is rescaled by ||v||^2 / (1 + ||v||^2) / sqrt(||v||^2 + eps),
    where the squared norm is summed over `axis` (kept as a size-1 dimension
    so the scale broadcasts back onto `vectors`). The epsilon keeps the
    square root away from zero for all-zero vectors.
    """
    eps = 1e-07
    sq_norm = tf.reduce_sum(tf.square(vectors), axis, keep_dims=True)
    shrink = sq_norm / (1 + sq_norm)
    norm = tf.sqrt(sq_norm + eps)
    return (shrink / norm) * vectors

## Defining the model

In [None]:
# discriminator structure 
def discriminator(img, phase = False, reuse = None):
    """Capsule-style discriminator producing one sigmoid score per image.

    Args:
        img: Image tensor; reshaped to [-1, 28, 28, 1] on entry.
        phase: Forwarded as `is_training` to both batch-norm layers.
        reuse: Forwarded to the 'discriminator' variable scope. Pass True on
            every call after the first so the same weights are shared.

    Returns:
        Tensor of shape [batch, 1] with values in (0, 1).
    """
    n_capsule_dim = 8        # length of each primary-capsule vector
    n_capsule_maps = 32      # capsule feature maps produced by the primary conv
    n_routing_units = 160    # width of the dense "routing" layers
    n_routing_rounds = 3

    with tf.variable_scope('discriminator', reuse = reuse):

        x = tf.reshape(img, shape=[-1, 28, 28, 1])
        x = tf.layers.conv2d(x, kernel_size=9, filters=256, strides=1, padding='VALID')
        print(x.shape)
        x = LeakyReLU()(x)
        print(x.shape)

        x = tf.contrib.layers.batch_norm(x, center = True, scale = True, is_training = phase, scope = 'bn')

        # Primary capsules: conv features regrouped into 8-D capsule vectors
        x = tf.contrib.layers.conv2d(x, num_outputs = n_capsule_dim * n_capsule_maps, kernel_size = 9, stride = 2, padding = 'VALID')
        print(x.shape)
        x = Reshape(target_shape=[-1, n_capsule_dim], name='primarycap_reshape')(x)
        # Squash along the capsule dimension (last axis). squash's default
        # axis=1 would instead normalise across all capsules of a sample.
        x = Lambda(squash, arguments={'axis': -1}, name = 'primary_squash')(x)

        x = tf.contrib.layers.batch_norm(x, center = True, scale = True, is_training = phase, scope = 'bn1')
        print(x.shape)

        # Flatten the capsules and project them to the prediction vectors
        x = tf.contrib.layers.flatten(x)
        print(x.shape)
        uhat = tf.layers.dense(x, n_routing_units, bias_initializer=tf.zeros_initializer(), kernel_initializer = tf.keras.initializers.he_uniform(seed = 123))
        print(uhat.shape)

        # Dense approximation of routing-by-agreement: each round derives
        # coupling coefficients from the previous output (uhat itself in the
        # first round) and re-weights uhat with them. The layers are created
        # in the same order as the unrolled version, so auto-generated
        # variable names are unchanged.
        s_j = uhat
        for _ in range(n_routing_rounds):
            c = tf.nn.softmax(s_j)
            c = tf.layers.dense(c, n_routing_units)
            s_j = LeakyReLU()(tf.multiply(uhat, c))
            print(s_j.shape)

        pred = tf.layers.dense(s_j, 1, activation = tf.nn.sigmoid)
        print(pred.shape)
        return pred
    

In [None]:
# Generator model
def generator(noise, keep_prob = keep_prob, is_training = is_training):
    """Map latent noise to a single-channel image with values in [-1, 1].

    Args:
        noise: Latent input tensor.
        keep_prob: Accepted for interface compatibility; not used in the body.
        is_training: Forwarded to every batch-norm layer.

    Returns:
        Output of the final tanh-activated transposed convolution. Starting
        from the 7x7x128 reshape, two stride-2 'same' stages upsample the
        feature map before the stride-1 output layer.
    """
    bn_decay = 0.8

    with tf.variable_scope("generator", reuse=None):
        print(noise.shape)

        # Project the noise and fold it into a 7x7x128 feature map
        h = tf.layers.dense(noise, units=128 * 7 * 7, activation=tf.nn.relu)
        print(h.shape)
        h = Reshape((7, 7, 128))(h)
        h = tf.contrib.layers.batch_norm(h, is_training=is_training, decay=bn_decay)
        print(h.shape)

        # Two upsampling stages, each followed by batch norm
        for n_filters in (128, 64):
            h = tf.layers.conv2d_transpose(h, kernel_size=3, filters=n_filters, strides=2, padding='same', activation=tf.nn.relu)
            print(h.shape)
            h = tf.contrib.layers.batch_norm(h, is_training=is_training, decay=bn_decay)
            print(h.shape)

        # Output layer: one channel squashed by tanh
        h = tf.layers.conv2d_transpose(h, kernel_size=3, filters=1, strides=1, padding='same', activation=tf.nn.tanh)
        print(h.shape)

        return h

In [None]:
# Build the generator once and the discriminator twice (real and generated
# inputs); the second call passes reuse=True so both share the same weights.
# NOTE(review): neither discriminator call passes `phase`, so its batch-norm
# layers receive the default phase=False rather than the `is_training`
# placeholder -- confirm this is intended.
g = generator(noise, keep_prob, is_training)
d_real = discriminator(X_in)
d_fake = discriminator(g, reuse=True)

# Split the trainable variables by the variable-scope prefix of their names.
vars_g = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
vars_d = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]

# L2 penalty (scale 1e-6) over every trainable variable of each network.
d_reg = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(1e-6), vars_d)
g_reg = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(1e-6), vars_g)

# Discriminator targets: 1 for real images, 0 for generated ones. These two
# terms are left unreduced; the training loop averages them for logging.
loss_d_real = binary_cross_entropy(tf.ones_like(d_real), d_real)
loss_d_fake = binary_cross_entropy(tf.zeros_like(d_fake), d_fake)
# Generator target: make the discriminator output 1 on generated images.
loss_g = tf.reduce_mean(binary_cross_entropy(tf.ones_like(d_fake), d_fake))
loss_d = tf.reduce_mean(0.5 * (loss_d_real + loss_d_fake))
# Ops registered under UPDATE_OPS (presumably the batch-norm moving-average
# updates) are made to run before either optimizer step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # Alternative optimizer kept for reference (RMSProp, lr 0.00015):
    #optimizer_d = tf.train.RMSPropOptimizer(learning_rate=0.00015).minimize(loss_d + d_reg, var_list=vars_d)
    #optimizer_g = tf.train.RMSPropOptimizer(learning_rate=0.00015).minimize(loss_g + g_reg, var_list=vars_g)
    # Each optimizer only updates its own network's variables (var_list).
    optimizer_d = tf.train.AdamOptimizer(learning_rate=0.0002).minimize(loss_d + d_reg, var_list=vars_d)
    optimizer_g = tf.train.AdamOptimizer(learning_rate=0.0002).minimize(loss_g + g_reg, var_list=vars_g)
    
    
# One session for the rest of the notebook; all variables are initialised here.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [None]:
# Code by Parag Mital (github.com/pkmital/CADL)
def montage(images):
    """Tile a stack of equally sized 2-D images into one square grid image.

    Args:
        images: List or array indexed as (count, height, width).

    Returns:
        A 2-D float array holding the images in row-major order on a grid of
        ceil(sqrt(count)) cells per side. Cells are separated by a one-pixel
        gutter and surrounded by a one-pixel border; gutter, border and any
        unused cells keep the fill value 0.5.
    """
    if isinstance(images, list):
        images = np.array(images)

    count = images.shape[0]
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(count)))

    canvas = np.full(
        (img_h * n_plots + n_plots + 1, img_w * n_plots + n_plots + 1), 0.5)

    for idx in range(count):
        row, col = divmod(idx, n_plots)
        # The +1 offsets skip the border; +row / +col skip the gutters.
        top = 1 + row + row * img_h
        left = 1 + col + col * img_w
        canvas[top:top + img_h, left:left + img_w] = images[idx]
    return canvas

In [None]:
# Running count of real training samples drawn so far
vals = 0

for i in range(30000):
    train_d = True
    train_g = True
    keep_prob_train = 0.6 # 0.5

    print("Iteration: ", i)
    n = np.random.uniform(0.0, 1.0, [batch_size, n_noise]).astype(np.float32)

    # The tutorial loader yields flattened images scaled to [0, 1], while the
    # generator ends in tanh and therefore emits values in [-1, 1]. Map the
    # real images to the same range so the discriminator cannot separate
    # real from generated images by value range alone.
    real_flat = mnist1.train.next_batch(batch_size=batch_size)[0]
    batch = np.reshape(real_flat, [-1, 28, 28, 1]) * 2.0 - 1.0

    vals += batch_size

    # Evaluate the current losses (no weight update) to decide which
    # network gets trained this iteration.
    d_real_ls, d_fake_ls, g_ls, d_ls = sess.run([loss_d_real, loss_d_fake, loss_g, loss_d],
                                                feed_dict={X_in: batch, noise: n, keep_prob: keep_prob_train, is_training:True})

    # loss_d_real / loss_d_fake come back per sample; average them for logging.
    d_real_ls = np.mean(d_real_ls)
    d_fake_ls = np.mean(d_fake_ls)

    # Balance heuristic: skip the update of whichever network is clearly ahead.
    if g_ls * 1.5 < d_ls:
        train_g = False
    if d_ls * 2 < g_ls:
        train_d = False

    if train_d:
        sess.run(optimizer_d, feed_dict={noise: n, X_in: batch, keep_prob: keep_prob_train, is_training:True})

    if train_g:
        sess.run(optimizer_g, feed_dict={noise: n, keep_prob: keep_prob_train, is_training:True})

    # Every 50 iterations: log the losses and show a grid of generated samples.
    if not i % 50:
        print (i, d_ls, g_ls, d_real_ls, d_fake_ls)
        if not train_g:
            print("not training generator")
        if not train_d:
            print("not training discriminator")
        gen_img = sess.run(g, feed_dict = {noise: n, keep_prob: 1.0, is_training:False})
        # Drop the channel axis so montage receives 2-D images.
        imgs = [img[:,:,0] for img in gen_img]
        gen_img = montage(imgs)
        plt.axis('off')
        plt.imshow(gen_img, cmap='gray')
        plt.show()

In [None]:
# Peek at the first 32 pre-processed images returned by load_dataset
temp = dataset_[:32]
print(temp.shape)