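# Baseline MNIST

Trains `vbranch.models.simple_fcnet` on flattened MNIST digits with the Adam optimizer and reports training and validation loss/accuracy after each epoch.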

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
import tensorflow.keras.datasets.mnist as mnist

In [3]:
import vbranch

## Load Data

In [4]:
# Download (or read from the Keras cache) the train/test splits.
train_split, test_split = mnist.load_data()
X_train, y_train = train_split
X_test, y_test = test_split

In [5]:
# Each image is flattened to a 28 * 28 = 784 feature vector; labels are one of 10 digit classes.
input_dim = 28 * 28
num_classes = 10

In [6]:
# Flatten every image into a single row of `input_dim` pixel values.
X_train_flat, X_test_flat = (images.reshape([-1, input_dim]) for images in (X_train, X_test))

# One-hot encode the integer class labels for the cross-entropy loss.
to_one_hot = tf.keras.utils.to_categorical
y_train_one_hot = to_one_hot(y_train, num_classes)
y_test_one_hot = to_one_hot(y_test, num_classes)

## Build and Train Model

In [7]:
# Training schedule. An "epoch" here means STEPS_PER_EPOCH batches
# (100 * 32 = 3,200 examples), not a full pass over the training data.
BATCH_SIZE, EPOCHS, STEPS_PER_EPOCH = 32, 10, 100

In [8]:
tf.reset_default_graph()
# NOTE(review): a graph-level seed (tf.set_random_seed) would likely make every
# re-initialization of the training iterator replay the same shuffle order,
# which matters because the training loop re-runs train_init_op each epoch.
# Confirm before adding one.

# Scale pixel intensities to [0, 1]. Dividing a float32 array by a Python
# float keeps the dtype float32.
train_data = (X_train_flat.astype('float32') / 255., y_train_one_hot)
test_data = (X_test_flat.astype('float32') / 255., y_test_one_hot)

# The batch size is fed when the iterator is initialized. This lets the same
# iterator serve small training batches and one full-size test batch.
batch_size = tf.placeholder('int64')

# Shuffle individual examples *before* batching. Shuffling after batch() only
# permutes fixed, contiguous batches. The buffer spans the whole training set
# because the training loop re-runs train_init_op every epoch, which restarts
# this pipeline from its first element. With a small buffer, every epoch would
# keep drawing from roughly the same leading slice of the data.
# Note: the buffer holds a copy of the entire training set in memory.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_data)
    .shuffle(buffer_size=len(X_train_flat))
    .batch(batch_size)
    .repeat()
)

test_dataset = (
    tf.data.Dataset.from_tensor_slices(test_data)
    .batch(batch_size)
    .repeat()
)

# One reinitializable iterator shared by both datasets. Running the matching
# initializer op switches which dataset feeds `inputs` / `labels_one_hot`.
iter_ = tf.data.Iterator.from_structure(train_dataset.output_types,
                                       train_dataset.output_shapes)
inputs, labels_one_hot = iter_.get_next()

train_init_op = iter_.make_initializer(train_dataset)
test_init_op = iter_.make_initializer(test_dataset)

In [9]:
# Build the network on top of the iterator's input tensor; the loss below
# consumes its result as logits.
build_model = vbranch.models.simple_fcnet
outputs = build_model(inputs, num_classes)

In [10]:
# List the model's trainable parameters (displayed as the cell output).
trainable_vars = tf.trainable_variables()
trainable_vars

[<tf.Variable 'fc1_w:0' shape=(784, 10) dtype=float32_ref>,
 <tf.Variable 'fc1_b:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'bn1_scale:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'bn1_beta:0' shape=(10,) dtype=float32_ref>]

In [11]:
# Batch-mean softmax cross-entropy; `outputs` are treated as unnormalized logits.
per_example_xent = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=labels_one_hot, logits=outputs)
loss = tf.reduce_mean(per_example_xent)

# NOTE(review): the model has batch-norm parameters (bn1_*). If simple_fcnet
# registers moving-average updates in tf.GraphKeys.UPDATE_OPS, they are not run
# by this train_op. Confirm whether a control dependency is needed.
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
train_op = optimizer.minimize(loss)

In [12]:
# Predicted class = index of the largest logit. Softmax is monotonic, so
# applying it first adds nothing. Its float rounding can also turn near-equal
# logits into exact ties.
pred_class = tf.argmax(outputs, axis=-1)
true_class = tf.argmax(labels_one_hot, axis=-1)

# One-hot predictions, kept in case later cells use `pred`.
pred = tf.one_hot(pred_class, num_classes)
# Fraction of examples in the current batch whose predicted class matches the label.
acc = tf.reduce_mean(tf.cast(tf.equal(pred_class, true_class), tf.float32))

In [13]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(1, EPOCHS + 1):
        print("Epoch {}/{}".format(epoch, EPOCHS))
        progbar = tf.keras.utils.Progbar(STEPS_PER_EPOCH)

        # The test evaluation at the end of each epoch re-points the shared
        # iterator, so the training pipeline is re-initialized here. This
        # restarts it from its first element.
        sess.run(train_init_op, feed_dict={batch_size: BATCH_SIZE})

        for step in range(1, STEPS_PER_EPOCH + 1):
            _, loss_value, acc_value = sess.run([train_op, loss, acc])
            metrics = [("loss", loss_value), ("acc", acc_value)]

            if step == STEPS_PER_EPOCH:
                # Last step of the epoch: switch to the test set and score it
                # as a single full-size batch.
                sess.run(test_init_op, feed_dict={batch_size: len(X_test_flat)})
                val_loss, val_acc = sess.run([loss, acc])
                metrics += [("val_loss", val_loss), ("val_acc", val_acc)]

            progbar.update(step, values=metrics)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
