In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import glob
import os
from tqdm import tqdm

import matplotlib.pyplot as plt
from skimage import io
from skimage import transform
import tensorflow as tf
import utils

%matplotlib inline

In [None]:
# Load the train/test images and labels through the project-local `utils` module.
# NOTE(review): presumably the images are 4-D (N, H, W, 1) and the labels one-hot with
# 10 columns, judging by how they are unpacked and fed to the placeholders further down
# -- confirm against utils.load_mnist.
pics_train, labels_train, pics_test, labels_test = utils.load_mnist()

In [None]:
# Report the array shapes of both splits, one value per line,
# with a blank line separating the two groups.
print("Training data:", pics_train.shape, labels_train.shape, sep="\n")
print()
print("Test data:", pics_test.shape, labels_test.shape, sep="\n")

In [None]:
# Visual sanity check on the training split.
# NOTE(review): presumably plots a random sample of images with their labels -- see utils.show_random_mnist.
utils.show_random_mnist(pics_train, labels_train)

In [None]:
# Print the distinct pixel values of the first training image to see the input value range.
print(np.unique(pics_train[0]))

## Defining model

In [None]:
# Dataset dimensions: the image array is unpacked as 4-D (count, height, width, last axis ignored).
N, H, W, _ = pics_train.shape
# Number of pixels per image, i.e. the size of one flattened image.
F = H * W
# Width of the network's output layer and of the label placeholder.
NUM_CLASSES = 10

In [None]:
def load_architecture():
    """Build a fully connected classifier in a fresh default graph.

    The default graph is reset first, so every call starts from scratch.
    Layout: flatten -> two 256-unit ReLU dense layers -> a dense layer with
    ``NUM_CLASSES`` units and no activation (raw logits). Reads the
    notebook-level ``H``, ``W`` and ``NUM_CLASSES``.

    Returns:
        tuple: ``(x, y, out)`` -- the float32 image placeholder of shape
        ``[None, H, W, 1]``, the uint8 label placeholder of shape
        ``[None, NUM_CLASSES]`` and the logits tensor.
    """
    tf.reset_default_graph()

    images = tf.placeholder(tf.float32, shape=[None, H, W, 1], name="x")
    targets = tf.placeholder(tf.uint8, shape=[None, NUM_CLASSES], name="y")

    # A single Xavier initializer object is shared by all dense layers.
    xavier = tf.contrib.layers.xavier_initializer()

    net = tf.contrib.layers.flatten(images)
    for width in (256, 256):
        net = tf.layers.dense(net, units=width, activation=tf.nn.relu, kernel_initializer=xavier)

    # Output layer: no activation, the loss applies the softmax itself.
    logits = tf.layers.dense(net, units=NUM_CLASSES, kernel_initializer=xavier)

    return images, targets, logits

In [None]:
def load_loss(y, out):
    """Mean softmax cross-entropy between the labels and the logits.

    Args:
        y: label tensor, passed as ``labels``.
        out: unnormalised class scores, passed as ``logits``.

    Returns:
        The mean of the per-example losses, as an op named ``"loss"``.
    """
    # The per-example op keeps its existing graph name "mean_loss".
    per_example = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=out, name="mean_loss")
    return tf.reduce_mean(per_example, name="loss")

In [None]:
def load_accuracy(y, out):
    """Fraction of examples whose arg-max prediction equals the arg-max label.

    Args:
        y: label tensor; the arg-max over its last axis is the true class.
        out: logits tensor; the arg-max over its last axis is the prediction.

    Returns:
        The mean of the float32-cast matches, as an op named ``"acc"``.
    """
    hits = tf.equal(tf.argmax(out, axis=-1), tf.argmax(y, axis=-1))
    return tf.reduce_mean(tf.cast(hits, tf.float32), name="acc")

In [None]:
def load_trainer(loss, learning_rate=0.001):
    """Create the gradient-descent update op for ``loss``.

    Args:
        loss: scalar tensor to minimise.
        learning_rate: step size handed to the optimizer. Defaults to 0.001,
            the value that used to be hard-coded here.

    Returns:
        The op returned by ``GradientDescentOptimizer.minimize``; running it
        performs one update step.
    """
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    return optimizer.minimize(loss)

In [None]:
def register_scalars(m):
    """Add one scalar summary per entry of ``m``.

    Args:
        m: mapping from summary tag to the tensor to log under that tag.
    """
    for tag, tensor in m.items():
        tf.summary.scalar(tag, tensor)

In [None]:
def register_images(m):
    """Add one image summary per entry of ``m``.

    Args:
        m: mapping from summary tag to the image tensor to log under that tag.
    """
    for tag, tensor in m.items():
        tf.summary.image(tag, tensor)

In [None]:
def trainable_parameters():
    """Count the scalar parameters held by all trainable variables.

    Returns:
        int: the sum, over ``tf.trainable_variables()``, of the product of
        each variable's dimensions (0 when there are no trainable variables).
    """
    def _element_count(variable):
        # get_shape() yields tf.Dimension objects; multiply their sizes.
        count = 1
        for axis in variable.get_shape():
            count *= axis.value
        return count

    return sum(_element_count(variable) for variable in tf.trainable_variables())

In [None]:
def load_model():
    """Assemble the complete training graph and its summaries.

    Builds the network, the loss, the accuracy and the update op, registers
    the scalar summaries ``info_loss`` / ``info_acc`` and the image summary
    ``input``, then merges every registered summary into one op.

    Returns:
        tuple: ``(x, y, out, loss, acc, upd, info)`` -- placeholders, logits,
        loss, accuracy, update op and the merged summary op, in that order.
    """
    images, labels, logits = load_architecture()
    mean_loss = load_loss(labels, logits)
    accuracy = load_accuracy(labels, logits)
    train_op = load_trainer(mean_loss)

    # Summaries must be registered before merge_all() collects them.
    register_scalars({"info_loss": mean_loss, "info_acc": accuracy})
    register_images({"input": images})
    summaries = tf.summary.merge_all()

    return images, labels, logits, mean_loss, accuracy, train_op, summaries

## Train model

In [None]:
def load_session():
    """Open a session and initialise every global variable in it.

    Returns:
        The new ``tf.Session``, after the variable initializer has been run.
    """
    session = tf.Session()
    init_op = tf.global_variables_initializer()
    session.run(init_op)
    return session

In [None]:
def train(sess, model, pics_train, labels_train, pics_val, labels_val, epochs, batch_size, train_writer, val_writer, val_every=1):
    """Run mini-batch training and log summaries for TensorBoard.

    Every epoch visits the training set in a fresh random order, in batches of
    exactly ``batch_size`` examples; a trailing partial batch is dropped. Each
    step runs the update op and writes the training summary. Every
    ``val_every`` steps the merged summary is also evaluated on the whole
    validation set and written to ``val_writer``.

    Args:
        sess: session in which the model's variables are initialised.
        model: the 7-tuple ``(x, y, out, loss, acc, upd, info)`` from
            ``load_model``.
        pics_train: training images, indexed along axis 0.
        labels_train: training labels aligned with ``pics_train``.
        pics_val: validation images, fed in a single batch.
        labels_val: validation labels aligned with ``pics_val``.
        epochs: number of passes over the training data.
        batch_size: number of examples per update step.
        train_writer: summary writer receiving the per-step training summaries.
        val_writer: summary writer receiving the validation summaries.
        val_every: evaluate on the validation set every this many steps.
            Defaults to 1 (every step, the original behaviour); raise it to
            make training cheaper.
    """
    n_samples = pics_train.shape[0]
    order = np.arange(n_samples)

    x, y, out, loss, acc, upd, info = model

    step = 0

    for _ in tqdm(range(epochs)):
        # Shuffle the index array only; gathering each batch through it keeps
        # images and labels aligned without copying the whole dataset.
        np.random.shuffle(order)

        # Stopping at n_samples - batch_size skips the trailing partial batch.
        # The size check uses the `batch_size` argument, not a notebook global.
        for start in range(0, n_samples - batch_size + 1, batch_size):
            batch = order[start:start + batch_size]
            X_batch = pics_train[batch]
            Y_batch = labels_train[batch]

            train_summary, _ = sess.run([info, upd], feed_dict={x: X_batch, y: Y_batch})
            train_writer.add_summary(train_summary, step)

            if step % val_every == 0:
                val_summary, = sess.run([info], feed_dict={x: pics_val, y: labels_val})
                val_writer.add_summary(val_summary, step)

            step += 1

In [None]:
def predict(imgs, model, session=None, rows=3, columns=3):
    """Plot randomly chosen images titled with the model's predicted class.

    Args:
        imgs: image array indexed along axis 0; each single-image slice is fed
            to the model's input placeholder.
        model: the tuple returned by ``load_model``; only its first element
            (input placeholder) and third element (logits) are used.
        session: session used to evaluate the logits. When omitted, the
            notebook-level ``sess`` is used, so existing two-argument calls
            keep working.
        rows: number of rows in the image grid (default 3).
        columns: number of columns in the image grid (default 3).
    """
    if session is None:
        # Backward compatibility: earlier versions always read the global `sess`.
        session = sess

    x, _, out = model[:3]

    n_images = imgs.shape[0]
    fig = plt.figure(figsize=(10, 10))
    for slot in range(1, rows * columns + 1):
        idx = np.random.choice(n_images)
        # Slice rather than index so the leading batch axis of size 1 is kept.
        img = imgs[idx:idx + 1]
        (logits,) = session.run([out], feed_dict={x: img})

        ax = fig.add_subplot(rows, columns, slot)
        ax.imshow(np.squeeze(img), cmap="gray")
        ax.set_title(str(np.argmax(np.squeeze(logits))))
    plt.show()

### Train on full data

In [None]:
# Build the graph (network, loss, accuracy, update op, summaries), then open a
# session with freshly initialised variables. The session must be created after
# the graph, because load_model() resets the default graph.
model = load_model()
sess = load_session()
print("Trainable parameters: {}".format(trainable_parameters()))

In [None]:
# Training configuration.
EPOCHS = 70
BATCH_SIZE = 64
LOGS_DIR = "logs"

# Separate writers so TensorBoard can overlay the training and validation curves.
t_writer = tf.summary.FileWriter(os.path.join(LOGS_DIR, "all", "train"), graph=sess.graph)
v_writer = tf.summary.FileWriter(os.path.join(LOGS_DIR, "all", "val"), graph=sess.graph)

# NOTE(review): the test split doubles as the validation set here, so the logged
# validation metrics are not an independent estimate -- confirm this is intended.
try:
    train(sess, model, pics_train, labels_train, pics_test, labels_test, EPOCHS, BATCH_SIZE, t_writer, v_writer)
finally:
    # Write buffered events to disk even if training is interrupted (e.g. a
    # manual kernel interrupt). flush() leaves the writers usable afterwards.
    t_writer.flush()
    v_writer.flush()

In [None]:
# Show the model's predictions on randomly chosen training images.
predict(pics_train, model)

In [None]:
# Show the model's predictions on randomly chosen test images.
predict(pics_test, model)