In [2]:
import tensorflow as tf
import numpy as np
import tensorflow.examples.tutorials.mnist.input_data as input_data

# Forest hyper-parameters. The routing code visits decision levels
# d = 0 .. depth, so each tree has depth + 1 decision levels and
# 2 ** (depth + 1) leaves.
depth = 3
num_leaf = 2 ** (depth + 1)
num_label = 10    # number of output classes
num_tree = 5      # trees in the ensemble; their predictions are averaged
num_batch = 128   # fixed batch size baked into the placeholders
num_epochs = 10   # passes over the training set

def initialize_w(shape):
    """Return a trainable variable of `shape` drawn from a normal with stddev 0.01."""
    initial_value = tf.random_normal(shape, stddev=0.01)
    return tf.Variable(initial_value)


def initialize_prob_w(shape, minval=-5, maxval=5):
    """Return a trainable variable of `shape` drawn uniformly from [minval, maxval)."""
    initial_value = tf.random_uniform(shape, minval=minval, maxval=maxval)
    return tf.Variable(initial_value)


def model(X, w, w2, w3, w4_e, w_d_e, w_last_e, prob_keep, prob_keep_hidden):
    """Build the shared CNN trunk and one decision/leaf head per tree.

    Args:
        X: input image batch (NHWC, as fed by the placeholder).
        w, w2, w3: kernels for the three conv -> relu -> 2x2 max-pool stages.
        w4_e, w_d_e, w_last_e: per-tree lists (zipped tree by tree) of the
            dense-layer weights, decision-node weights and leaf logits.
        prob_keep: dropout keep probability applied to the conv features.
        prob_keep_hidden: dropout keep probability for each tree's hidden layer.

    Returns:
        (dn_e, pr_leaf_e): per-tree lists of sigmoid decision probabilities
        and softmax class distributions over each tree's leaves.
    """
    def conv_relu_pool(inputs, kernel):
        # Stride-1 SAME convolution, ReLU, then a 2x2 stride-2 max-pool.
        activated = tf.nn.relu(tf.nn.conv2d(inputs, kernel, [1, 1, 1, 1], 'SAME'))
        return tf.nn.max_pool(activated, ksize=[1, 2, 2, 1],
                              strides=[1, 2, 2, 1], padding='SAME')

    features = tf.nn.dropout(conv_relu_pool(X, w), prob_keep)
    features = tf.nn.dropout(conv_relu_pool(features, w2), prob_keep)
    features = conv_relu_pool(features, w3)

    # Flatten to the input width expected by the trees' dense layer.
    flat_width = w4_e[0].get_shape().as_list()[0]
    features = tf.nn.dropout(tf.reshape(features, [-1, flat_width]), prob_keep)

    dn_e, pr_leaf_e = [], []
    for w4, w_d, w_last in zip(w4_e, w_d_e, w_last_e):
        hidden = tf.nn.dropout(tf.nn.relu(tf.matmul(features, w4)), prob_keep_hidden)
        # Probability of taking the "left" branch at each decision node.
        dn_e.append(tf.nn.sigmoid(tf.matmul(hidden, w_d)))
        # Class distribution stored at each leaf (input independent).
        pr_leaf_e.append(tf.nn.softmax(w_last))
    return dn_e, pr_leaf_e

# Load MNIST with one-hot labels and reshape the image arrays into
# (N, 28, 28, 1) batches for the conv layers.
mnist = input_data.read_data_sets("MNIST/", one_hot=True)
train_split, test_split = mnist.train, mnist.test
X_train, y_train = train_split.images.reshape(-1, 28, 28, 1), train_split.labels
X_test, y_test = test_split.images.reshape(-1, 28, 28, 1), test_split.labels

# Graph inputs: a fixed-size batch of 28x28 single-channel images and their
# one-hot labels.
X = tf.placeholder("float", [num_batch, 28, 28, 1])
Y = tf.placeholder("float", [num_batch, num_label])

# Shared convolutional trunk: 3x3 kernels, 1 -> 32 -> 64 -> 128 channels.
w = initialize_w([3, 3, 1, 32])
w2 = initialize_w([3, 3, 32, 64])
w3 = initialize_w([3, 3, 64, 128])

# Per-tree parameters, created tree by tree (same order as before): dense
# layer on the flattened 128 * 4 * 4 conv features, decision-node weights,
# and leaf logits.
per_tree_params = [
    (initialize_w([128 * 4 * 4, 625]),
     initialize_prob_w([625, num_leaf], -1, 1),
     initialize_prob_w([num_leaf, num_label], -2, 2))
    for _ in range(num_tree)
]
w4_ensemble = [params[0] for params in per_tree_params]
w_d_ensemble = [params[1] for params in per_tree_params]
w_last_ensemble = [params[2] for params in per_tree_params]

# Dropout keep probabilities, fed at run time.
prob_keep = tf.placeholder("float")
prob_keep_hidden = tf.placeholder("float")

dn_e, pr_leaf_e = model(X, w, w2, w3, w4_ensemble, w_d_ensemble, w_last_ensemble,
                        prob_keep, prob_keep_hidden)

# ---- Soft routing: mu[b, leaf] = probability that sample b reaches a leaf ----
#
# For each tree, stack the decision probabilities d with their complements
# (1 - d) and flatten. A single tf.gather can then pick d (left branch) or
# 1 - d (right branch) for any (sample, node) pair:
#   d[b, node]     lives at  b * num_leaf + node
#   1 - d[b, node] lives at  num_batch * num_leaf + b * num_leaf + node
flat_dn_e = []
for decision_p in dn_e:
    decision_p_pack = tf.stack([decision_p, 1.0 - decision_p])
    flat_dn_e.append(tf.reshape(decision_p_pack, [-1]))

# Row b holds b * num_leaf in every column: the start of sample b's row.
batch_ind0 = tf.tile(tf.expand_dims(tf.range(0, num_batch * num_leaf, num_leaf), 1),
                     [1, num_leaf])


def _complement_offsets(in_repeat, out_repeat):
    """Build a (num_batch, num_leaf) offset array that selects d or 1 - d.

    Runs of `in_repeat` zeros (select d) alternate with runs of `in_repeat`
    copies of num_batch * num_leaf (select 1 - d). The pair is repeated
    `out_repeat` times and laid out row-major.
    """
    pattern = [0] * in_repeat + [num_batch * num_leaf] * in_repeat
    return np.array(pattern * out_repeat).reshape(num_batch, num_leaf)


# Root level (node 0): the left half of the leaves takes d, the right half 1 - d.
in_repeat = num_leaf // 2
out_repeat = num_batch
batch_neg_ind = _complement_offsets(in_repeat, out_repeat)

mu_e = [tf.gather(flat_decision_p, tf.add(batch_ind0, batch_neg_ind))
        for flat_decision_p in flat_dn_e]

for d in range(1, depth + 1):
    # Level d holds nodes 2**d - 1 .. 2**(d+1) - 2; each node covers
    # 2**(depth - d + 1) consecutive leaves.
    indices = tf.range(2 ** d, 2 ** (d + 1)) - 1
    tile_ind = tf.reshape(tf.tile(tf.expand_dims(indices, 1), [1, 2 ** (depth - d + 1)]),
                          [1, -1])
    batch_ind = tf.add(batch_ind0, tf.tile(tile_ind, [num_batch, 1]))

    # Each deeper level alternates left/right in runs half as long.
    in_repeat //= 2
    out_repeat *= 2
    batch_neg_ind = _complement_offsets(in_repeat, out_repeat)

    # Multiply this level's branch probability into every leaf's path probability.
    mu_e = [tf.multiply(mu, tf.gather(flat_decision_p, tf.add(batch_ind, batch_neg_ind)))
            for mu, flat_decision_p in zip(mu_e, flat_dn_e)]

# ---- Prediction: P(y | x) = sum over leaves of mu[b, leaf] * pi[leaf, y] ----
# Each mu in mu_e is (num_batch, num_leaf) and each leaf_p is
# (num_leaf, num_label), so a tree's class distribution is a matrix product.
# The original reduce_mean over leaves divided this by num_leaf, so py_x did
# not sum to 1.
py_x_e = tf.stack([tf.matmul(mu, leaf_p) for mu, leaf_p in zip(mu_e, pr_leaf_e)])
# The forest prediction is the average of the trees' distributions.
py_x = tf.reduce_mean(py_x_e, 0)

# Cross-entropy: sum over classes, mean over the batch. The epsilon keeps
# log() finite if py_x underflows to 0.
cost = tf.reduce_mean(-tf.reduce_sum(tf.multiply(tf.log(py_x + 1e-10), Y), 1))

train_step = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
predict = tf.argmax(py_x, 1)

# Kept as a plain session (not a `with` block) so later notebook cells can
# still use `sess`.
sess = tf.Session()
# initialize_all_variables() is deprecated; global_variables_initializer()
# is its replacement.
sess.run(tf.global_variables_initializer())


def _full_batches(num_examples, batch_size):
    """Yield (start, end) bounds of every complete batch of `batch_size` rows.

    The placeholders have a fixed leading dimension, so a trailing partial
    batch is skipped. The old zip(range(0, n, b), range(b, n, b)) form also
    dropped the last full batch whenever n was a multiple of b.
    """
    for start in range(0, num_examples - batch_size + 1, batch_size):
        yield start, start + batch_size


for epoch in range(num_epochs):
    for start, end in _full_batches(len(X_train), num_batch):
        sess.run(train_step, feed_dict={X: X_train[start:end], Y: y_train[start:end],
                                        prob_keep: 0.8, prob_keep_hidden: 0.5})

    # Evaluate with dropout disabled (keep probability 1.0).
    results = []
    for start, end in _full_batches(len(X_test), num_batch):
        predictions = sess.run(predict, feed_dict={X: X_test[start:end], prob_keep: 1.0,
                                                   prob_keep_hidden: 1.0})
        results.extend(np.argmax(y_test[start:end], axis=1) == predictions)

    print('Epoch: %d, Test Accuracy: %f' % (epoch + 1, np.mean(results)))

Extracting MNIST/train-images-idx3-ubyte.gz
Extracting MNIST/train-labels-idx1-ubyte.gz
Extracting MNIST/t10k-images-idx3-ubyte.gz
Extracting MNIST/t10k-labels-idx1-ubyte.gz
Epoch: 1, Test Accuracy: 0.951623


KeyboardInterrupt: 