In [1]:
import tensorflow as tf
import numpy as np

In [2]:
num_classes = 10
img_rows, img_cols = 28, 28
input_shape = (img_rows, img_cols, 1)


(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
y_test = tf.keras.utils.to_categorical(y_test)
y_train = tf.keras.utils.to_categorical(y_train)

In [3]:
x = tf.placeholder(tf.float32, shape = (None, img_rows, img_cols, 1))
y = tf.placeholder(tf.float32, shape = (None, num_classes))

weight_initializer =  tf.glorot_uniform_initializer()

num_filters = 32
# First convolution layer
with tf.variable_scope("conv1", reuse=tf.AUTO_REUSE):
    filter_shape = [3, 3, 1, num_filters]
    W_conv1 = tf.get_variable("W", shape=filter_shape, initializer=weight_initializer)
    b_conv1 = tf.get_variable("b", shape=[num_filters], initializer=tf.zeros_initializer)
    conv2d = tf.nn.conv2d(x, W_conv1, strides = [1, 1, 1, 1], padding = 'VALID')
    activation_conv1 = tf.nn.relu(conv2d + b_conv1)

# Max pooling
pool2d = tf.nn.max_pool(activation_conv1, ksize=[1, 2, 2, 1], strides = [1, 2, 2, 1], padding = 'VALID')

# Flatten
flatten = tf.layers.flatten(pool2d)

# Dense layer 1
num_hidden = 128
with tf.variable_scope("dense1", reuse=tf.AUTO_REUSE):
    matrix_shape = [32 * 13 * 13, num_hidden]
    W_dense1 = tf.get_variable("W_dense1", shape=matrix_shape, initializer=weight_initializer)
    b_dense1 = tf.get_variable("b_dense1", shape=[num_hidden], initializer=tf.zeros_initializer)

    dense1 = tf.nn.relu(tf.matmul(flatten, W_dense1) + b_dense1)

# Dense layer 2
with tf.variable_scope("dense2", reuse=tf.AUTO_REUSE):
    matrix_shape = [num_hidden, num_classes]
    W_dense2 = tf.get_variable("W_dense2", shape=matrix_shape, initializer=weight_initializer)
    b_dense2 = tf.get_variable("b_dense2", shape=[num_classes], initializer=tf.zeros_initializer)

    output = tf.nn.softmax(tf.matmul(dense1, W_dense2) + b_dense2)

In [4]:
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=output, logits=y))
# Same values as Keras defaults
train_step = tf.train.AdamOptimizer(1e-3).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


In [5]:
def batch_data(source, target, batch_size):
   # Shuffle data
    shuffle_indices = np.random.permutation(np.arange(len(target)))
    source = source[shuffle_indices]
    target = target[shuffle_indices]
    while True:
        for batch_i in range(0, len(source)//batch_size):
            start_i = batch_i * batch_size
            source_batch = source[start_i:start_i + batch_size]
            target_batch = target[start_i:start_i + batch_size]

            yield np.array(source_batch), np.array(target_batch)

In [6]:
batch_size = 20
num_steps = 10**4
batch_generator = batch_data(x_train, y_train, batch_size)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for i in range(num_steps):
        batch_x, batch_y = next(batch_generator)
        sess.run(train_step, feed_dict={x: batch_x, y: batch_y})
        
    train_accuracy = accuracy.eval(feed_dict={x: x_test, y: y_test})
    print("step %d, training accuracy %g"%(i, train_accuracy))

step 9999, training accuracy 0.9816


## Info
You'll notice that using the low-level APIs allows us to peak into the weights matrices:

In [10]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(W_dense2)
    print(sess.run(W_dense2))

<tf.Variable 'dense2/W_dense2:0' shape=(128, 10) dtype=float32_ref>
[[-0.0048102  -0.06678391 -0.15051834 ...  0.11293034  0.12195925
   0.03477891]
 [-0.09328046  0.03687395 -0.07295442 ...  0.1477557   0.10135518
  -0.06747966]
 [ 0.01071291  0.154963   -0.18757385 ...  0.08024196  0.15282069
   0.17554466]
 ...
 [ 0.02673787 -0.12205153  0.12581103 ...  0.14929445  0.04802634
   0.1382155 ]
 [ 0.07793574 -0.06643188  0.13112439 ...  0.12457545 -0.07694344
  -0.1506127 ]
 [-0.06204127 -0.15288     0.1785077  ... -0.1529056   0.13262792
  -0.10990766]]
