# Baseline MNIST

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
import tensorflow.keras.datasets.mnist as mnist

In [3]:
import vbranch

## Load Data

In [4]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()

In [5]:
input_dim = 784
num_classes = 10

In [6]:
X_train_flat = X_train.reshape([-1, input_dim])
X_test_flat = X_test.reshape([-1, input_dim])

y_train_one_hot = tf.keras.utils.to_categorical(y_train, num_classes)
y_test_one_hot = tf.keras.utils.to_categorical(y_test, num_classes)

## Build Model

In [7]:
tf.reset_default_graph()

# Wrapping all together -> Switch between train and test set using Initializable iterator
EPOCHS = 10
# create a placeholder to dynamically switch between batch sizes
batch_size = tf.placeholder(tf.int64)
x = tf.placeholder(tf.float32, shape=[None, input_dim])
y = tf.placeholder(tf.float32, shape=[None, num_classes])

dataset = tf.data.Dataset.from_tensor_slices((x, y))
dataset = dataset.batch(batch_size).repeat().shuffle(buffer_size=400)

data_iter = dataset.make_initializable_iterator()
inputs, labels_one_hot = data_iter.get_next()

In [8]:
outputs = vbranch.models.simple_fcnet(inputs, input_dim, num_classes)

In [9]:
tf.trainable_variables()

[<tf.Variable 'fc1_w:0' shape=(784, 10) dtype=float32_ref>,
 <tf.Variable 'fc1_b:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'fc1_bn_scale:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'fc1_bn_beta:0' shape=(10,) dtype=float32_ref>]

In [10]:
loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels_one_hot, logits=outputs)
train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

In [11]:
BATCH_SIZE = 32
n_batches = 100

In [12]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # initialise iterator with train data
    sess.run(data_iter.initializer, feed_dict={x: X_train_flat, y: y_train_one_hot, 
                                               batch_size: BATCH_SIZE})
    for e in range(EPOCHS):
        print("Epoch {}/{}".format(e + 1, n_batches))
        progbar = tf.keras.utils.Progbar(n_batches)
        ep_losses = []
        for i in range(n_batches):
            _, loss_value = sess.run([train_op, loss])
            ep_losses.append(loss_value[0])
            progbar.update(i + 1, values=[("loss", np.mean(ep_losses)),])
        
    # initialise iterator with test data
    sess.run(data_iter.initializer, feed_dict={ x: X_test_flat, y: y_test_one_hot, 
                                          batch_size: 1000})
    test_loss = sess.run(loss)[0]
    print('Test Loss: {:4f}'.format(test_loss))

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Test Loss: 0.581682
