In [1]:
from __future__ import division, print_function, unicode_literals
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt

import numpy as np
import tensorflow as tf

## Loading MNIST Data

In [2]:
tf.reset_default_graph()
np.random.seed(42)
tf.set_random_seed(42)

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST Dataset")

Extracting MNIST Dataset\train-images-idx3-ubyte.gz
Extracting MNIST Dataset\train-labels-idx1-ubyte.gz
Extracting MNIST Dataset\t10k-images-idx3-ubyte.gz
Extracting MNIST Dataset\t10k-labels-idx1-ubyte.gz


In [3]:
batch_size = 64
n_noise = 64
X = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name="X")
noise = tf.placeholder(dtype=tf.float32, shape=[None, n_noise])
y = tf.placeholder(shape=[None], dtype=tf.int64, name="y")

keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob')
is_training = tf.placeholder(dtype=tf.bool, name='is_training')

def lrelu(x):
    return tf.maximum(x, tf.multiply(x, 0.2))

def binary_cross_entropy(x, z):
    eps = 1e-12
    return (-(x * tf.log(z + eps) + (1. - x) * tf.log(1. - z + eps)))

## Generator Module

In [4]:
def generator(z, keep_prob=keep_prob, is_training=is_training):
    activation = lrelu
    momentum = 0.99
    with tf.variable_scope("generator", reuse=None):
        x = z
        d1 = 4
        d2 = 1
        x = tf.layers.dense(x, units=d1 * d1 * d2, activation=activation)
        x = tf.layers.dropout(x, keep_prob)      
        x = tf.contrib.layers.batch_norm(x, is_training=is_training, decay=momentum)  
        x = tf.reshape(x, shape=[-1, d1, d1, d2])
        x = tf.image.resize_images(x, size=[7, 7])
        x = tf.layers.conv2d_transpose(x, kernel_size=5, filters=64, strides=2, padding='same', activation=activation)
        x = tf.layers.dropout(x, keep_prob)
        x = tf.contrib.layers.batch_norm(x, is_training=is_training, decay=momentum)
        x = tf.layers.conv2d_transpose(x, kernel_size=5, filters=64, strides=2, padding='same', activation=activation)
        x = tf.layers.dropout(x, keep_prob)
        x = tf.contrib.layers.batch_norm(x, is_training=is_training, decay=momentum)
        x = tf.layers.conv2d_transpose(x, kernel_size=5, filters=64, strides=1, padding='same', activation=activation)
        x = tf.layers.dropout(x, keep_prob)
        x = tf.contrib.layers.batch_norm(x, is_training=is_training, decay=momentum)
        x = tf.layers.conv2d_transpose(x, kernel_size=5, filters=1, strides=1, padding='same', activation=tf.nn.sigmoid)
        return x

## Discriminator Module

In [5]:
# Primary Capsule parameters
caps1_n_maps = 32
caps1_n_caps = caps1_n_maps * 6 * 6  # 1152 primary capsules
caps1_n_dims = 8
    
# Digit Capsule parameters
caps2_n_caps = 10
caps2_n_dims = 16    
init_sigma = 0.01
    
# Margin Loss parameters
m_plus = 0.9
m_minus = 0.1
lambda_ = 0.5

# Decoder parameters
n_hidden1 = 512
n_hidden2 = 1024
n_output = 28 * 28

In [6]:
# Squashing the vector between 0 and 1 to get the probabilities value
def squash(s, axis=-1, epsilon=1e-7, name=None):
    with tf.name_scope(name, default_name="squash"):
        squared_norm = tf.reduce_sum(tf.square(s), axis=axis,
                                     keep_dims=True)
        safe_norm = tf.sqrt(squared_norm + epsilon)
        squash_factor = squared_norm / (1. + squared_norm)
        unit_vector = s / safe_norm
        return squash_factor * unit_vector

In [7]:
# Safe normalization so that we don't have divide by zero error
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False, name=None):
    with tf.name_scope(name, default_name="safe_norm"):
        squared_norm = tf.reduce_sum(tf.square(s), axis=axis,
                                     keep_dims=keep_dims)
        return tf.sqrt(squared_norm + epsilon)

In [8]:
# Code by Parag Mital (github.com/pkmital/CADL)
def montage(images):
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    m = np.ones((images.shape[1] * n_plots + n_plots + 1, images.shape[2] * n_plots + n_plots + 1)) * 0.5
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m

In [9]:
# To compute the outputs we first use two convolutional layer
conv1_params = {
            "filters": 256,
            "kernel_size": 9,
            "strides": 1,
            "padding": "valid",
            "activation": tf.nn.relu,
        }

conv2_params = {
            "filters": caps1_n_maps * caps1_n_dims, # 256 convolutional filters
            "kernel_size": 9,
            "strides": 2,
            "padding": "valid",
            "activation": tf.nn.relu
        }
    

In [10]:
def marginLoss(y_pred, caps2_output, X, y):
          
        with tf.variable_scope("loss", reuse = True):
        
            # Margin Loss
            T = tf.one_hot(y, depth=caps2_n_caps)
    
            caps2_output_norm = safe_norm(caps2_output, axis=-2, keep_dims=True)
        
            present_error_raw = tf.square(tf.maximum(0., m_plus - caps2_output_norm))
            present_error = tf.reshape(present_error_raw, shape=(-1, 10))
    
            absent_error_raw = tf.square(tf.maximum(0., caps2_output_norm - m_minus))
            absent_error = tf.reshape(absent_error_raw, shape=(-1, 10))
    
            L = tf.add(T * present_error, lambda_ * (1.0 - T) * absent_error)
    
            margin_loss = tf.reduce_mean(tf.reduce_sum(L, axis=1))
    
            # Reconstruction 
            mask_with_labels = tf.placeholder_with_default(False, shape=())
    
            reconstruction_targets = tf.cond(mask_with_labels, lambda: y, lambda: y_pred)
    
            reconstruction_mask = tf.one_hot(reconstruction_targets,depth=caps2_n_caps)
    
            reconstruction_mask_reshaped = tf.reshape( reconstruction_mask, [-1, 1, caps2_n_caps, 1, 1])
    
            caps2_output_masked = tf.multiply( caps2_output, reconstruction_mask_reshaped)
    
            decoder_input = tf.reshape(caps2_output_masked, [-1, caps2_n_caps * caps2_n_dims])
    
            # Decoder
            with tf.variable_scope(tf.get_variable_scope(), reuse = True):
                hidden1 = tf.layers.dense(decoder_input, n_hidden1, activation=tf.nn.relu)
                hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu)
                decoder_output = tf.layers.dense(hidden2, n_output, activation=tf.nn.sigmoid)
    
            # Reconstruction Loss
            X_flat = tf.reshape(X, [-1, n_output])
            squared_difference = tf.square(X_flat - decoder_output)
            reconstruction_loss = tf.reduce_sum(squared_difference)
    
            # Final Loss
            alpha = 0.0005
            loss = tf.add(margin_loss, alpha * reconstruction_loss)
        
        return loss

In [11]:
def discriminator(X, reuse = None):
    
    if(reuse):
        tf.get_variable_scope().reuse_variables()
    
    conv1 = tf.layers.conv2d(X, **conv1_params)
    conv2 = tf.layers.conv2d(conv1, **conv2_params)
    
    with tf.variable_scope('PrimaryCaps_layer'):
        caps1_raw = tf.reshape(conv2, [-1, caps1_n_caps, caps1_n_dims])
    
        caps1_output = squash(caps1_raw)
    
    with tf.variable_scope('DigitCaps_layer'):
        W_init = tf.random_normal(shape=(1, caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_n_dims), stddev=init_sigma, dtype=tf.float32)
        W = tf.Variable(W_init)
    
        batch_size = tf.shape(X)[0]
        W_tiled = tf.tile(W, [batch_size, 1, 1, 1, 1])
    
        caps1_output_expanded = tf.expand_dims(caps1_output, -1)
        caps1_output_tile = tf.expand_dims(caps1_output_expanded, 2)
        caps1_output_tiled = tf.tile(caps1_output_tile, [1, 1, caps2_n_caps, 1, 1])
    
        caps2_predicted = tf.matmul(W_tiled, caps1_output_tiled)
    
    with tf.variable_scope('Routing_by_Agreement'):
        raw_weights = tf.zeros([batch_size, caps1_n_caps, caps2_n_caps, 1, 1],dtype=np.float32)
    
        routing_weights = tf.nn.softmax(raw_weights, dim=2)
        weighted_predictions = tf.multiply(routing_weights, caps2_predicted)
        weighted_sum = tf.reduce_sum(weighted_predictions, axis=1, keep_dims=True)
        
        caps2_output_round_1 = squash(weighted_sum, axis=-2)
    
        caps2_output_round_1_tiled = tf.tile( caps2_output_round_1, [1, caps1_n_caps, 1, 1, 1])
    
        agreement = tf.matmul(caps2_predicted, caps2_output_round_1_tiled, transpose_a=True)
    
        raw_weights_round_2 = tf.add(raw_weights, agreement)
    
        routing_weights_round_2 = tf.nn.softmax(raw_weights_round_2, dim=2)
        weighted_predictions_round_2 = tf.multiply(routing_weights_round_2, caps2_predicted)
        weighted_sum_round_2 = tf.reduce_sum(weighted_predictions_round_2, axis=1, keep_dims=True)
        caps2_output_round_2 = squash(weighted_sum_round_2, axis=-2)
    
        caps2_output = caps2_output_round_2
    
    with tf.variable_scope('First_fully_connected_layer'):
        d_w3 = tf.get_variable('d_w3', [16*10, 1024], initializer=tf.truncated_normal_initializer(stddev=0.02))
        d_b3 = tf.get_variable('d_b3', [1024], initializer=tf.constant_initializer(0))
        d3 = tf.reshape(caps2_output, [-1, 16*10])
        d3 = tf.matmul(d3, d_w3)
        d3 = d3 + d_b3
        d3 = tf.nn.relu(d3)

    #The last fully-connected layer holds the output, such as the class scores.
    # Second fully connected layer
    with tf.variable_scope('second_fully_connected_layer'):
        d_w4 = tf.get_variable('d_w4', [1024, 1], initializer=tf.truncated_normal_initializer(stddev=0.02))
        d_b4 = tf.get_variable('d_b4', [1], initializer=tf.constant_initializer(0))

    #At the end of the network, we do a final matrix multiply and 
    #return the activation value. 
    #For those of you comfortable with CNNs, this is just a simple binary classifier. Nothing fancy.
    # Final layer
        d4 = tf.matmul(d3, d_w4) + d_b4
    return d4

In [12]:
g = generator(noise, keep_prob, is_training)
d_real = tf.cast(discriminator(X), tf.float32)
d_fake = tf.cast(discriminator(g, reuse=True), tf.float32)
print(d_real)

vars_g = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
vars_d = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]

d_reg = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(1e-6), vars_d)
g_reg = tf.contrib.layers.apply_regularization(tf.contrib.layers.l2_regularizer(1e-6), vars_g)

loss_d_real = tf.nn.sigmoid_cross_entropy_with_logits(logits=d_real, labels=tf.ones_like(d_real))
loss_d_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits=d_fake, labels=tf.zeros_like(d_fake))
loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = d_fake, labels = tf.ones_like(d_fake)))
loss_d = tf.reduce_mean(0.5 * (loss_d_real + loss_d_fake))

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer_d = tf.train.RMSPropOptimizer(learning_rate=0.00015).minimize(loss_d + d_reg, var_list=vars_d)
    optimizer_g = tf.train.RMSPropOptimizer(learning_rate=0.00015).minimize(loss_g + g_reg, var_list=vars_g)
        
sess = tf.Session()
sess.run(tf.global_variables_initializer())

ValueError: Variable conv2d_2/kernel does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?

In [None]:
for i in range(60000):
    train_d = True
    train_g = True
    keep_prob_train = 0.6 # 0.5
    
    print("Iteration: ", i)
    n = np.random.uniform(0.0, 1.0, [batch_size, n_noise]).astype(np.float32)   
    batch, _ = mnist.train.next_batch(batch_size=batch_size)
    batch = batch.reshape([-1, 28, 28, 1])
    
    d_real_ls, d_fake_ls, g_ls, d_ls = sess.run([loss_d_real, loss_d_fake, loss_g, loss_d], feed_dict={X: batch, noise: n, keep_prob: keep_prob_train, is_training:True})
    
    d_real_ls = np.mean(d_real_ls)
    d_fake_ls = np.mean(d_fake_ls)
    g_ls = g_ls
    d_ls = d_ls
    
    if g_ls * 1.5 < d_ls:
        train_g = False
        pass
    if d_ls * 2 < g_ls:
        train_d = False
        pass
    
    if train_d:
        sess.run(optimizer_d, feed_dict={noise: n, X: batch, keep_prob: keep_prob_train, is_training:True})
        
        
    if train_g:
        sess.run(optimizer_g, feed_dict={noise: n, keep_prob: keep_prob_train, is_training:True})
        
    if not i % 50:
        print (i, d_ls, g_ls, d_real_ls, d_fake_ls)
        if not train_g:
            print("not training generator")
        if not train_d:
            print("not training discriminator")
        gen_img = sess.run(g, feed_dict = {noise: n, keep_prob: 1.0, is_training:False})
        imgs = [img[:,:,0] for img in gen_img]
        m = montage(imgs)
        gen_img = m
        plt.axis('off')
        plt.imshow(gen_img, cmap='gray')
        plt.show()