In [None]:
import io
import random

import PIL
import PIL.Image
import bqplot, bqplot.pyplot
import ipywidgets

import keras
import numpy as np
import tensorflow as tf
# Only let TensorFlow log messages at ERROR level or above through.
# NOTE(review): tf.logging is the TensorFlow 1.x API (TF 2.x moved it to
# tf.compat.v1.logging) — confirm the TF version this notebook is pinned to.
tf.logging.set_verbosity(tf.logging.ERROR)

In [None]:
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

# Flatten each 28x28 image into a 784-vector and scale pixels from [0, 255]
# to [0, 1]. The sample count comes from the data, not a hard-coded literal.
x_train = train_images.reshape((len(train_images), 28 * 28)).astype('float32') / 255
x_test = test_images.reshape((len(test_images), 28 * 28)).astype('float32') / 255

# One-hot encode the digit labels. num_classes is explicit so both splits always
# get 10 columns, even if some digit were absent from a split.
NUM_CLASSES = 10
y_train = keras.utils.to_categorical(train_labels, num_classes=NUM_CLASSES)
y_test = keras.utils.to_categorical(test_labels, num_classes=NUM_CLASSES)

In [None]:
# Gallery of r rows x c columns of random training digits, each tile s px square.
r, c, s = 6, 8, 96
margin = 4  # blank pixels on every side of a digit inside its tile

# Sample without replacement so no digit appears twice in the gallery.
picks = np.random.choice(train_images.shape[0], size=r*c, replace=False)
digits = train_images[picks]

# The extra 2 px leave a 1-px white border around the whole grid.
canvas = PIL.Image.new('RGB', (c*s+2, r*s+2), color='white')
for i, d in enumerate(digits):
    # Invert so digits render dark-on-white, then shrink to fit inside the margin.
    dimg = PIL.Image.fromarray(255 - d).resize((s - 2*margin, s - 2*margin))
    # Tiles are filled column by column: r digits per column.
    col, row = divmod(i, r)
    canvas.paste(dimg, box=(1 + col*s + margin, 1 + row*s + margin))

buf = io.BytesIO()
canvas.save(buf, 'gif')
# Declare the real encoding; the widget's default format is 'png'.
img = ipywidgets.Image(value=buf.getvalue(), format='gif')
display(img)

In [None]:
def _new_history_figure(y_label):
    """Create one bqplot figure holding a 'training' line and a 'validation' line.

    Both lines start on placeholder points. marks[0] is the training line
    (default colour) and marks[1] is the validation line (orange).

    Args:
        y_label: text for the y-axis label.

    Returns:
        The newly created bqplot figure.
    """
    axes_options = {'x': {'label': 'Epochs'},
                    'y': {'label': y_label,
                          'label_offset': '50px',
                          'tick_style': {'font-size': 10}}}

    fig = bqplot.pyplot.figure(min_aspect_ratio=4/3, max_aspect_ratio=4/3)
    # Placeholder data only; the marks' x/y are replaced once training starts.
    bqplot.pyplot.plot([0, 1], [0.5, 0.5], axes_options=axes_options,
                       labels=['training'], display_legend=True)
    bqplot.pyplot.plot([0, 1], [0.75, 0.75], colors=['orange'],
                       labels=['validation'], display_legend=True)
    return fig


def init_plots():
    """Create the live loss and accuracy figures used during training.

    Each figure carries two line marks: marks[0] for the training curve and
    marks[1] for the validation curve, with a legend naming them.

    Returns:
        Tuple (loss_plt, acc_plt) of bqplot figures.
    """
    loss_plt = _new_history_figure('Losses')
    acc_plt = _new_history_figure('Accuracy')
    return (loss_plt, acc_plt)

class plot_history(keras.callbacks.Callback):
    """Keras callback that redraws live loss/accuracy curves after every epoch.

    Args:
        loss_plt: figure whose marks[0] / marks[1] receive the training /
            validation loss series.
        acc_plt: figure whose marks[0] / marks[1] receive the training /
            validation accuracy series.

    Attributes:
        history: dict of per-epoch metric lists keyed 'loss', 'val_loss',
            'acc' and 'val_acc'. A metric missing from the epoch logs is
            recorded as None.
    """

    # history key -> log keys to try, in order. With metrics=['accuracy'],
    # older Keras logs 'acc'/'val_acc' while Keras >= 2.3 logs
    # 'accuracy'/'val_accuracy'; accept either so the chart never goes blank.
    _LOG_KEYS = {
        'loss': ('loss',),
        'val_loss': ('val_loss',),
        'acc': ('acc', 'accuracy'),
        'val_acc': ('val_acc', 'val_accuracy'),
    }

    def __init__(self, loss_plt, acc_plt):
        # Let the base Callback run its own initialisation first.
        super().__init__()
        self.loss_plt = loss_plt
        self.acc_plt = acc_plt
        self.history = {'loss': [], 'val_loss': [], 'acc': [], 'val_acc': []}

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metrics and push the full series to both figures."""
        # Avoid a shared mutable default argument.
        logs = logs or {}
        for name, candidates in self._LOG_KEYS.items():
            value = next((logs[key] for key in candidates if key in logs), None)
            self.history[name].append(value)

        self._redraw(self.loss_plt, self.history['loss'], self.history['val_loss'])
        self._redraw(self.acc_plt, self.history['acc'], self.history['val_acc'])

    @staticmethod
    def _redraw(fig, train_values, val_values):
        """Point fig's training (marks[0]) and validation (marks[1]) lines at the series."""
        # Epochs are numbered from 1 on the x-axis.
        x_data = range(1, len(train_values) + 1)
        fig.marks[0].x = x_data
        fig.marks[0].y = train_values
        fig.marks[1].x = x_data
        fig.marks[1].y = val_values

In [None]:
def build_network(hidden_units=32, learning_rate=0.01, momentum=0.9):
    """Build and compile a small fully connected MNIST classifier.

    Layers: 784 -> hidden_units (ReLU) -> hidden_units (ReLU) -> 10 (softmax).
    The model is trained with SGD plus momentum on categorical cross-entropy
    and reports accuracy.

    Args:
        hidden_units: width of both ReLU layers (default 32).
        learning_rate: SGD step size (default 0.01).
        momentum: SGD momentum (default 0.9).

    Returns:
        A compiled keras Sequential model named 'MNIST_DNN'.
    """
    network = keras.models.Sequential(name='MNIST_DNN')
    # NOTE: despite its name, 'input' is the first hidden Dense layer; the
    # name is kept so model summaries stay unchanged.
    network.add(keras.layers.Dense(hidden_units, activation='relu', name='input', input_shape=(28*28*1,)))
    network.add(keras.layers.Dense(hidden_units, activation='relu', name='hidden'))
    network.add(keras.layers.Dense(10, activation='softmax', name='output'))

    # Use the canonical SGD class rather than the lowercase `sgd` alias, which
    # not every Keras release provides. `lr` is kept because older standalone
    # Keras accepts only that spelling.
    optimizer = keras.optimizers.SGD(lr=learning_rate, momentum=momentum)
    network.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

    return network

In [None]:
# Fresh, untrained model; print its layer and parameter table.
network = build_network()
network.summary()

# Loss and accuracy charts side by side; the training callback updates them live.
loss_plt, acc_plt = init_plots()
charts = ipywidgets.HBox(children=[loss_plt, acc_plt])
display(charts)

In [None]:
# Train with evaluation after every epoch. verbose=0 because progress is drawn
# into the live charts by the plot_history callback instead of printed.
# NOTE(review): the test split doubles as validation data here, so the
# validation curves are not an untouched held-out estimate.
epochs = 50
live_charts = plot_history(loss_plt, acc_plt)
history = network.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=epochs,
    verbose=0,
    validation_data=(x_test, y_test),
    callbacks=[live_charts],
)

In [None]:
entry = 11  # index of the test sample to inspect

# predict() works on batches, so slice a batch of one instead of indexing a row.
prediction = network.predict(x_test[entry:entry+1])
predicted_label = np.argmax(prediction, axis=1)

print("Ground truth: {} | Predicted: {}".format(np.argmax(y_test[entry]),predicted_label[0]))

img = test_images[entry].reshape((28,28))

# Enlarged, inverted (dark-on-white) copy of the input digit. The bytes are
# GIF-encoded, so declare that format instead of the widget's default 'png'.
pimg = ipywidgets.Image(width=256, height=256, format='gif')
imbuf = io.BytesIO()
dimg = PIL.Image.fromarray((255-img).astype('uint8')).resize((256, 256))
dimg.save(imbuf, 'gif')
pimg.value = imbuf.getvalue()
pimg.layout.object_fit = 'contain'
pimg.layout.object_position = "center center"
pimg.layout.border = '1px solid black'

# prediction holds one row per sample in the batch; plot row 0 as a single
# 1-D series, with one bar per class the model outputs.
class_probs = prediction[0]
columns = list(range(len(class_probs)))
axes_pred = {'x': {'label': '', 'tick_style': {'font-weight': 'bold', 'font-size': "16px"}},
             'y': {'label': '',
                   'tick_style': {'font-weight': 'bold', 'font-size': "16px"}}}
opts_pred = {'y': {'min': 0.0, 'max':1.0}}

pred_plt = bqplot.pyplot.figure(min_aspect_ratio=1.0, max_aspect_ratio=1.0,
                                title='Predicted class probabilities')
bqplot.pyplot.bar(columns, class_probs, options=opts_pred, axes_options=axes_pred)

result_box = ipywidgets.HBox([pimg, pred_plt])
result_box.layout.align_items = 'center'
display(result_box)