In [None]:
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import math, os

config = tf.ConfigProto()
# Grow GPU memory on demand instead of reserving all of it up front.
config.gpu_options.allow_growth = True
# encoder()/decoder() pin their ops to '/gpu:0'; allow TF to fall back to an
# available device (e.g. CPU) instead of failing when that GPU does not exist.
config.allow_soft_placement = True

In [None]:
# Clear the default graph so that re-running the cells below rebuilds the model
# from scratch instead of adding a second copy of every op/variable.
tf.reset_default_graph()

# Define some handy network layers
def lrelu(x, rate=0.1):
    """Leaky ReLU: x where x >= 0, x * rate where x < 0.

    Written as max(min(x * rate, 0), x), which equals that definition for
    0 <= rate <= 1.
    """
    clipped_negative = tf.minimum(x * rate, 0)
    return tf.maximum(clipped_negative, x)

def conv2d_lrelu(inputs, num_outputs, kernel_size, stride):
    """2-D convolution (Xavier-initialised, linear output) followed by lrelu."""
    init = tf.contrib.layers.xavier_initializer()
    pre_activation = tf.contrib.layers.convolution2d(
        inputs, num_outputs, kernel_size, stride,
        weights_initializer=init,
        activation_fn=tf.identity)
    return lrelu(pre_activation)

def conv2d_t_relu(inputs, num_outputs, kernel_size, stride):
    """Transposed 2-D convolution (Xavier-initialised, linear output) followed by ReLU."""
    init = tf.contrib.layers.xavier_initializer()
    upsampled = tf.contrib.layers.convolution2d_transpose(
        inputs, num_outputs, kernel_size, stride,
        weights_initializer=init,
        activation_fn=tf.identity)
    return tf.nn.relu(upsampled)

def fc_lrelu(inputs, num_outputs):
    """Fully connected layer (Xavier-initialised, linear output) followed by lrelu."""
    init = tf.contrib.layers.xavier_initializer()
    dense = tf.contrib.layers.fully_connected(
        inputs, num_outputs,
        weights_initializer=init,
        activation_fn=tf.identity)
    return lrelu(dense)

def fc_relu(inputs, num_outputs):
    """Fully connected layer (Xavier-initialised, linear output) followed by ReLU."""
    init = tf.contrib.layers.xavier_initializer()
    dense = tf.contrib.layers.fully_connected(
        inputs, num_outputs,
        weights_initializer=init,
        activation_fn=tf.identity)
    return tf.nn.relu(dense)

In [None]:
# Hyperparameters
batch_size = 200  # images per training step; also the number of prior samples drawn for the MMD term
z_dim = 10  # latent size; the encoder outputs 2 * z_dim values, split into mean and log-variance
# Lambda = 0.005 # kl  (commented out; presumably a KL weight from an earlier variant)
Lambda1 = 1  # weight of the MMD term in the loss
Lambda2 = 0.01  # weight of the HSIC term (which is subtracted from the loss)
# NOTE: despite the name, the HSIC term is switched ON for a step when
# np.random.random() >= prob_to_hsic, i.e. with probability 1 - prob_to_hsic.
prob_to_hsic = 0.75
Sigma = 1  # stddev of the Gaussian prior samples the latent codes are matched to via MMD
steps = 5001  # training iterations; logging happens at i % 100 == 0, so the last log is at step 5000
K = 'Gaussian'  # kernel for the MMD term; any other value selects the inverse multiquadratic kernel
# bandwidth = np.sqrt(z_dim) * Sigma # or -1
bandwidth = 0  # <= 0 makes MMD pick its bandwidth with the median heuristic (comp_med)


In [None]:
def encoder(x, z_dim, reuse=False):
    """Encode an image batch into 2 * z_dim values per example.

    Callers split the output in half along axis 1 into a mean and a
    log-variance. Two stride-2 lrelu convolutions are followed by a linear
    fully connected layer. Variables live under the 'encoder' scope with
    tf.AUTO_REUSE, so repeated calls share weights; reuse=True additionally
    marks the scope as reusing.

    x: image batch, assumed [batch, 28, 28, 1] as fed from train_x.
    """
    with tf.device('/gpu:0'), tf.variable_scope('encoder', reuse=tf.AUTO_REUSE) as scope:
        if reuse:
            scope.reuse_variables()

        h = conv2d_lrelu(x, 64, 4, 2)
        h = conv2d_lrelu(h, 128, 4, 2)
        flat_size = np.prod(h.get_shape().as_list()[1:])
        h = tf.reshape(h, [-1, flat_size])

        return tf.contrib.layers.fully_connected(h, z_dim * 2, activation_fn=tf.identity)

In [None]:
def decoder(z, reuse=False):
    """Decode latent codes z ([batch, z_dim]) into single-channel images in (0, 1).

    Layer stack:
      - ReLU fully connected layer reshaped to 7x7x128;
      - stride-2 transposed convolution with 64 channels and ReLU;
      - stride-2 transposed convolution with 1 channel and a sigmoid output.

    Variables live under the 'decoder' scope with tf.AUTO_REUSE; reuse=True
    additionally marks the scope as reusing.
    """
    with tf.device('/gpu:0'), tf.variable_scope('decoder', reuse=tf.AUTO_REUSE) as scope:
        if reuse:
            scope.reuse_variables()

        h = fc_relu(z, 7 * 7 * 128)
        batch = tf.shape(h)[0]
        h = tf.reshape(h, tf.stack([batch, 7, 7, 128]))
        h = conv2d_t_relu(h, 64, 4, 2)

        return tf.contrib.layers.convolution2d_transpose(h, 1, 4, 2, activation_fn=tf.sigmoid)

In [None]:
# MMD implementation
def comp_med(X, kernel='Gaussian'):
    """Median-heuristic kernel bandwidth for the rows of X.

    Takes the median h of the strictly positive pairwise squared Euclidean
    distances between rows of X. Zero distances (the diagonal and duplicate
    rows) are excluded.

    Returns a scalar float32 tensor:
      - sqrt(h / 2) for kernel == 'Gaussian';
      - sqrt(h) for any other kernel;
      - 1.0 if there are no positive distances (e.g. a single row, or all
        rows identical).
    """
    sq_dists = compute_diff(X, X)
    # Keep the positive distance *values*. tf.where(cond) with a single argument
    # returns the coordinates of the True entries, not the entries themselves.
    positive = tf.boolean_mask(sq_dists, tf.greater(sq_dists, 0))

    def _bandwidth_from_median():
        h = tf.contrib.distributions.percentile(positive, 50.0)
        if kernel == 'Gaussian':
            return tf.sqrt(tf.cast(h / 2, tf.float32))
        return tf.sqrt(tf.cast(h, tf.float32))

    # A Python `==` on a symbolic shape never looks at runtime values, so the
    # empty-input fallback has to be an in-graph conditional.
    return tf.cond(tf.equal(tf.size(positive), 0),
                   lambda: tf.constant(1.0, tf.float32),
                   _bandwidth_from_median)

def compute_diff(X,Y):
    """Pairwise squared Euclidean distances: out[i, j] = ||X[i] - Y[j]||^2.

    X is [n, d] and Y is [m, d]; the result is [n, m].
    """
    pairwise = tf.expand_dims(X, 1) - Y  # [n, 1, d] - [m, d] -> [n, m, d]
    return tf.einsum('ijk,ijk->ij', pairwise, pairwise)
    
def compute_gram(X,Y, bandwidth, kernel='Gaussian'):
    """Kernel (Gram) matrix between the rows of X and Y, shape [x_size, y_size].

    'Gaussian': exp(-||x - y||^2 / (2 * bandwidth^2))
    otherwise:  inverse multiquadratic (bandwidth^2 + ||x - y||^2)^(-1/2)
    """
    sq_dists = compute_diff(X,Y)
    if kernel != 'Gaussian':
        # inverse multiquadratics kernel (a c / (c + H) variant was tried before)
        return tf.pow(tf.cast(bandwidth**2, tf.float32) + sq_dists, -0.5)
    return tf.exp(- sq_dists / 2 / tf.cast(bandwidth**2, tf.float32))

# MMD

def MMD(X, Y, kernel='Gaussian', bandwidth=0):
    """Biased (V-statistic) estimate of the squared MMD between samples X and Y.

    MMD^2 = mean k(X, X) + mean k(Y, Y) - 2 * mean k(X, Y).

    bandwidth <= 0 selects the median heuristic on the pooled samples. The
    kernel is passed through so that the heuristic matches it: comp_med
    scales Gaussian and non-Gaussian bandwidths differently, as HSIC already
    relies on.
    """
    if bandwidth <= 0:  # median heuristic
        bandwidth = comp_med(tf.concat([X, Y], 0), kernel)

    XX = tf.reduce_mean(compute_gram(X, X, bandwidth, kernel))
    YY = tf.reduce_mean(compute_gram(Y, Y, bandwidth, kernel))
    XY = tf.reduce_mean(compute_gram(X, Y, bandwidth, kernel))

    return XX + YY - 2 * XY  # biased V-stats

In [None]:
# HSIC implementation

def HSIC(X, Y, kernel='Gaussian', bandwidthX=0, bandwidthY=0, normalized=True):
    """Hilbert-Schmidt Independence Criterion between paired samples X and Y.

    Args:
        X, Y: tensors with the same number of rows n; rank-1 inputs are
            treated as a single feature column.
        kernel: kernel name passed to compute_gram / comp_med.
        bandwidthX, bandwidthY: kernel bandwidths; 0 selects the median
            heuristic (comp_med).
        normalized: if True, return tr(Kc Lc) / (||Kc||_F * ||Lc||_F), i.e.
            the centred kernel alignment. Otherwise return the biased
            estimate tr(K H L H) / n^2.

    Here H = I - 11^T / n is the centring matrix, Kc = H K H and Lc = H L H.
    """
    if tf.keras.backend.ndim(X) == 1:
        X = tf.expand_dims(X, axis=1)
    if tf.keras.backend.ndim(Y) == 1:
        Y = tf.expand_dims(Y, axis=1)
    n = tf.shape(X)[0]
    n_float = tf.cast(n, tf.float32)

    if bandwidthX == 0:
        bandwidthX = comp_med(X, kernel)
    if bandwidthY == 0:
        bandwidthY = comp_med(Y, kernel)

    # Named Kx/Ly so they do not shadow the notebook-level kernel name `K`.
    Kx = compute_gram(X, X, bandwidthX, kernel)
    Ly = compute_gram(Y, Y, bandwidthY, kernel)
    H = tf.eye(n) - tf.ones([n, n]) / n_float
    Kc = tf.matmul(tf.matmul(H, Kx), H)
    # sum(L * Kc^T) == tr(L Kc) == tr(K H L H), without materialising the product.
    trace = tf.reduce_sum(Ly * tf.transpose(Kc))
    if normalized:
        # Reuse Kc rather than recomputing H K H; only H L H is new here.
        Lc = tf.matmul(tf.matmul(H, Ly), H)
        HKH = tf.norm(Kc, ord='fro', axis=[0, 1])
        HLH = tf.norm(Lc, ord='fro', axis=[0, 1])
        return trace / (HKH * HLH)
    # Square after casting so large n cannot overflow int32.
    return trace / (n_float * n_float)

In [None]:
def reparameterize(mean, logvar, random=True):
    """Reparameterisation trick: mean + exp(logvar / 2) * eps with eps ~ N(0, I).

    With random=False the mean is returned unchanged (deterministic code).
    """
    if not random:  # deterministic
        return mean
    noise = tf.random_normal(shape=tf.shape(mean))
    std = tf.exp(logvar * .5)
    return noise * std + mean
    
def push_forward(encoder, train_x):
    """Run `encoder` on train_x (with the global z_dim) and draw one latent sample per row."""
    encoded = encoder(train_x, z_dim)
    mean, logvar = tf.split(encoded, num_or_size_splits=2, axis=1)
    return reparameterize(mean, logvar)

In [None]:
# Build the computation graph for training
train_x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])  # image batch
train_y = tf.placeholder(tf.float32, shape=[None])  # digit labels; consumed by the HSIC term

train_z = push_forward(encoder, train_x)  # sampled latent codes for train_x
train_xr = decoder(train_z)  # reconstructions of train_x

# Build the computation graph for generating samples from latent codes fed via gen_z
gen_z = tf.placeholder(tf.float32, shape=[None, z_dim])
gen_x = decoder(gen_z, reuse=True)

# Second pass through the shared encoder weights, used later to extract features.
# NOTE(review): as in push_forward, the second split half is a log-variance despite
# the name pretrained_var, and reparameterize() samples by default, so
# pretrained_z is stochastic; pretrained_mean would give deterministic features.
pretrained_mean, pretrained_var = tf.split(encoder(train_x, z_dim, reuse=True), num_or_size_splits=2, axis=1)
pretrained_z = reparameterize(pretrained_mean, pretrained_var)

In [None]:
# Compare the encoded z with samples from the Gaussian prior N(0, Sigma^2 I)
# and compute their MMD distance.
true_samples = tf.random_normal([batch_size, z_dim], stddev=Sigma)
loss_mmd = Lambda1 * MMD(true_samples, train_z, kernel=K, bandwidth=bandwidth)

# Pixel-wise mean squared reconstruction error (named loss_nll, but it is an MSE).
loss_nll = tf.reduce_mean(tf.square(train_xr - train_x))

# Scalar boolean fed every training step: True switches the HSIC term on.
hsic_signal = tf.placeholder(tf.bool, shape=[])
# 1.0 when the HSIC term is active, 0.0 otherwise.
hsic_trigger = tf.cast(hsic_signal, tf.float32)

# Build/evaluate HSIC only on steps where it is switched on. Multiplying by a
# 0/1 trigger would still compute HSIC (several n x n matrix products and
# median computations) on every step, and 0 * NaN is still NaN if HSIC ever
# produces a NaN.
loss_hsic = tf.cond(hsic_signal,
                    lambda: Lambda2 * HSIC(train_z, train_y),
                    lambda: tf.constant(0.0, tf.float32))

# HSIC is subtracted, so training maximises dependence between z and the labels.
loss = loss_nll + loss_mmd - loss_hsic

trainer = tf.train.AdamOptimizer(1e-3).minimize(loss)

In [None]:
from tensorflow.examples.tutorials.mnist import input_data
# MNIST with integer class labels (one_hot not requested), matching the
# shape-[None] train_y placeholder used by the HSIC term.
mnist = input_data.read_data_sets('mnist_data')

# Convert a numpy array of shape [batch_size, height, width, 1] into a displayable array
# of shape [height*cnt, width*cnt], where cnt = floor(sqrt(batch_size)), by tiling the
# images row-major into a cnt x cnt grid.
def convert_to_display(samples):
    """Tile the first cnt*cnt images of `samples` into a single 2-D array.

    samples: [batch, height, width, 1] or [batch, height, width].
    Images beyond the largest square grid are dropped, so batch sizes that are
    not perfect squares no longer make the reshape fail.
    """
    samples = np.asarray(samples)
    cnt = int(math.floor(math.sqrt(samples.shape[0])))
    height, width = samples.shape[1], samples.shape[2]
    # [cnt*cnt, h, w(, 1)] -> [grid_row, grid_col, h, w]
    grid = samples[:cnt * cnt].reshape(cnt, cnt, height, width)
    # -> [grid_row, h, grid_col, w] -> [grid_row * h, grid_col * w]
    return grid.transpose(0, 2, 1, 3).reshape(height * cnt, width * cnt)

In [None]:
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

# Session that trains the VAE and is later reused to generate samples and
# extract latent features.
sess = tf.Session(config=config)
sess.run(init_op)

In [None]:
import time
# Start training

# Histories recorded every 100 steps for the learning-curve plot.
mmd_list = []
o_loss_list = []
hisc_list = []  # HSIC term values (name kept for the plotting cell)
steps_list = []

# The HSIC term is switched on for a step with probability 1 - prob_to_hsic;
# bandwidth <= 0 makes the MMD term use the median heuristic.
start_time = time.time()
for i in range(steps):
    batch_x, batch_y = mnist.train.next_batch(batch_size)
    batch_x = batch_x.reshape(-1, 28, 28, 1)
    use_hsic = bool(np.random.random() >= prob_to_hsic)
    _, o_loss, nll, mmd, hsic = sess.run(
        [trainer, loss, loss_nll, loss_mmd, loss_hsic],
        feed_dict={train_x: batch_x, train_y: batch_y, hsic_signal: use_hsic})

    if i % 100 == 0:
        # `i` counts training steps (mini-batches), not epochs; loss_nll is a pixel MSE.
        print("step: {}, Overall loss is {}, reconstruction (MSE) loss is {}, mmd loss is {}, hsic is {}".format(
            i, o_loss, nll, mmd, hsic))
        elapsed_time = time.time() - start_time
        print("time elapsed: {0:.2f}s".format(elapsed_time))
        start_time = time.time()
        # storing data for plot
        mmd_list.append(mmd)
        o_loss_list.append(o_loss)
        steps_list.append(i)
        hisc_list.append(hsic)

    if i % 500 == 0:
        # Decode 100 draws from the same N(0, Sigma^2) prior the MMD term targets
        # (shown as a 10x10 grid) and report the MMD of a test batch's codes.
        test_x, test_y = mnist.test.next_batch(batch_size)
        test_x = test_x.reshape(-1, 28, 28, 1)
        prior_draws = np.random.normal(scale=Sigma, size=(100, z_dim))
        samples, gen_mmd, my_z = sess.run(
            [gen_x, loss_mmd, pretrained_z],
            feed_dict={gen_z: prior_draws, train_x: test_x})
        plt.imshow(convert_to_display(samples), cmap='Greys_r')
        plt.show()
        print("test mmd loss: {}, my_z: {}".format(gen_mmd, my_z[0]))

saver.save(sess, "./hsic_test.ckpt")

In [None]:
# MLP classifier for MNIST using the trained latent space as inputs

from __future__ import print_function

# Import MNIST data
# Load MNIST with one-hot labels under a separate name. Re-binding the global
# `mnist` would break re-running the training cell, whose train_y placeholder
# expects integer labels of shape [None]. Reuse the 'mnist_data' directory so
# the dataset is not downloaded a second time.
mnist_onehot = input_data.read_data_sets('mnist_data', one_hot=True)


# Parameters
learning_rate = 0.001
training_epochs = 25
# Separate name so the VAE's global `batch_size` is not overwritten.
clf_batch_size = 100
display_step = 1

# Network Parameters
n_hidden_1 = 256 # 1st layer number of neurons
n_hidden_2 = 256 # 2nd layer number of neurons
# n_input = 784 # MNIST data input (img shape: 28*28)
n_input = z_dim  # inputs are latent codes from the trained encoder
n_classes = 10 # MNIST total classes (0-9 digits)

# tf Graph input
X = tf.placeholder("float", [None, n_input])
Y = tf.placeholder("float", [None, n_classes])

# Store layers weight & bias
weights = {
    'h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
    'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
    'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes]))
}
biases = {
    'b1': tf.Variable(tf.random_normal([n_hidden_1])),
    'b2': tf.Variable(tf.random_normal([n_hidden_2])),
    'out': tf.Variable(tf.random_normal([n_classes]))
}


# Create model
def multilayer_perceptron(x):
    """Two ReLU hidden layers and a linear output layer producing class logits.

    Without the nonlinearities the three stacked affine maps would collapse
    into a single linear classifier.
    """
    layer_1 = tf.nn.relu(tf.add(tf.matmul(x, weights['h1']), biases['b1']))
    layer_2 = tf.nn.relu(tf.add(tf.matmul(layer_1, weights['h2']), biases['b2']))
    return tf.matmul(layer_2, weights['out']) + biases['out']

# Construct model
logits = multilayer_perceptron(X)

# Define loss and optimizer
loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=Y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)
# Initializing the variables
init = tf.global_variables_initializer()

# Same GPU options as the VAE session, which is still alive and holding memory.
with tf.Session(config=config) as sess2:
    sess2.run(init)

    total_batch = int(mnist_onehot.train.num_examples / clf_batch_size)
    # Training cycle
    for epoch in range(training_epochs):
        avg_cost = 0.
        # Loop over all batches
        for i in range(total_batch):
            batch_x, batch_y = mnist_onehot.train.next_batch(clf_batch_size)
            # Encode the images with the trained VAE (session `sess`) to get latent features.
            batch_z = sess.run(pretrained_z, feed_dict={train_x: batch_x.reshape(-1, 28, 28, 1)})
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess2.run([train_op, loss_op], feed_dict={X: batch_z, Y: batch_y})
            # Compute average loss
            avg_cost += c / total_batch
        # Display logs per epoch step
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch+1), "cost={:.9f}".format(avg_cost))
    print("Optimization Finished!")

    # Test model
    pred = tf.nn.softmax(logits)  # Apply softmax to logits
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(Y, 1))
    # Calculate accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    test_images = mnist_onehot.test.images.reshape(-1, 28, 28, 1)
    test_z = sess.run(pretrained_z, feed_dict={train_x: test_images})
    acc = sess2.run(accuracy, feed_dict={X: test_z, Y: mnist_onehot.test.labels})
    print("Accuracy:", acc)

In [None]:
# saving plots
import matplotlib.patches as mpatches

# plotting learning curve
plt.plot(steps_list, o_loss_list, color='red')
plt.plot(steps_list, mmd_list, color='green')
# hisc_list holds the weighted HSIC term, which is 0 on steps where it was switched off.
plt.plot(steps_list, hisc_list, color='blue')
# adding legends
red_patch = mpatches.Patch(color='red', label='overall loss')
green_patch = mpatches.Patch(color='green', label='mmd loss')
blue_patch = mpatches.Patch(color='blue', label='hsic loss')

plt.legend(handles=[red_patch, blue_patch, green_patch], loc=1)

plt.xlabel("steps")
plt.ylabel("loss")
plt.title("overall, hsic and mmd loss vs steps")

# File name format: <Lambda1>_<Lambda2>_<HSIC-on ratio = 1 - prob_to_hsic>_<accuracy>.png
# savefig does not create missing directories, so make sure the target exists.
out_dir = "./test_data"
if not os.path.isdir(out_dir):
    os.makedirs(out_dir)
plt.savefig(os.path.join(out_dir, "{0}_{1}_{2}_{3:.3f}.png".format(Lambda1, Lambda2, 1-prob_to_hsic, acc)))