In [None]:
import random
import io
import math

import PIL, PIL.Image, PIL.ImageDraw, PIL.ImageFilter
import bqplot, bqplot.pyplot
import ipywidgets, ipyevents

import keras
import numpy as np
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)

In [None]:
# MNIST train/test splits of 28x28 greyscale handwritten digits with integer labels.
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data()

# Add the trailing channel axis expected by the (28, 28, 1) model inputs and scale
# pixel values from 0..255 to [0, 1]. `-1` lets the sample count follow the data
# rather than hard-coding the split sizes.
x_train = train_images.reshape((-1, 28, 28, 1)).astype('float32') / 255
x_test = test_images.reshape((-1, 28, 28, 1)).astype('float32') / 255

# One-hot targets for categorical cross-entropy. num_classes is pinned so the
# width is always 10, even for a subset that happens to lack some digit.
y_train = keras.utils.to_categorical(train_labels, num_classes=10)
y_test = keras.utils.to_categorical(test_labels, num_classes=10)

In [None]:
# Preview r rows x c columns of randomly chosen training digits, each cell s px wide,
# filled column by column onto a white RGB sheet.
r, c, s = 6, 8, 64
digits = train_images[np.random.randint(train_images.shape[0], size=r * c)]
canvas = PIL.Image.new('RGB', (c * s + 2, r * s + 2), color='white')
for idx, digit in enumerate(digits):
    col, row = divmod(idx, r)
    # Invert to dark-on-white and shrink by 8 px so neighbouring cells are separated.
    tile = PIL.Image.fromarray(255 - digit).resize((s - 8, s - 8))
    canvas.paste(tile, box=(s * col, s * row))

buf = io.BytesIO()
canvas.save(buf, 'gif')
img = ipywidgets.Image(value=buf.getvalue())
display(img)

In [None]:
def _history_figure(y_label):
    """Create one bqplot figure with two placeholder lines for training curves.

    The first line (default colour) carries the axis options and is meant for the
    training series; the second, orange line is meant for the validation series.
    Both start with dummy data that `plot_history` later overwrites.

    Args:
        y_label: text for the y axis.

    Returns:
        The new bqplot figure (also left as pyplot's current figure).
    """
    axes = {'x': {'label': 'Epochs'},
            'y': {'label': y_label,
                  'label_offset': '50px',
                  'tick_style': {'font-size': 10}}}
    figure = bqplot.pyplot.figure(min_aspect_ratio=4/3, max_aspect_ratio=4/3)
    bqplot.pyplot.plot([0, 1], [0.5, 0.5], axes_options=axes)
    bqplot.pyplot.plot([0, 1], [0.75, 0.75], colors=['orange'])
    return figure


def init_plots():
    """Create the live training figures.

    Returns:
        (loss_plt, acc_plt): two bqplot figures, for losses and accuracy, each
        holding a training line (marks[0]) and an orange validation line (marks[1]).
    """
    loss_plt = _history_figure('Losses')
    acc_plt = _history_figure('Accuracy')
    return loss_plt, acc_plt

class plot_history(keras.callbacks.Callback):
    """Keras callback that streams per-epoch loss/accuracy curves into bqplot figures.

    Each figure must hold two line marks: marks[0] receives the training series and
    marks[1] the validation series (the layout produced by `init_plots`).

    Args:
        loss_plt: figure updated with the 'loss' / 'val_loss' history.
        acc_plt: figure updated with the training / validation accuracy history.
    """

    def __init__(self, loss_plt, acc_plt):
        # Let the Keras base class initialise its own state (model, params, ...).
        super().__init__()
        self.loss_plt = loss_plt
        self.acc_plt = acc_plt
        # Keys are kept as 'acc'/'val_acc' regardless of the Keras version's log names.
        self.history = {'loss': [], 'val_loss': [], 'acc': [], 'val_acc': []}

    @staticmethod
    def _pick_metric(logs, *names):
        """Return the value of the first of `names` present in `logs`, else None."""
        for name in names:
            if name in logs:
                return logs[name]
        return None

    @staticmethod
    def _redraw(figure, train_series, val_series):
        """Point the figure's train/validation marks at the given series (epochs from 1)."""
        x_data = range(1, len(train_series) + 1)
        figure.marks[0].x = x_data
        figure.marks[0].y = train_series
        figure.marks[1].x = x_data
        figure.marks[1].y = val_series

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metrics and redraw both figures.

        Args:
            epoch: index of the finished epoch (unused; the x axis is derived
                from the number of recorded entries).
            logs: metric dict supplied by Keras; None is treated as empty.
        """
        logs = logs or {}
        self.history['loss'].append(logs.get('loss'))
        self.history['val_loss'].append(logs.get('val_loss'))
        # metrics=['accuracy'] is logged as 'acc' by older Keras and as 'accuracy'
        # by newer releases; accept either so the accuracy curve is never all None.
        self.history['acc'].append(self._pick_metric(logs, 'acc', 'accuracy'))
        self.history['val_acc'].append(self._pick_metric(logs, 'val_acc', 'val_accuracy'))

        self._redraw(self.loss_plt, self.history['loss'], self.history['val_loss'])
        self._redraw(self.acc_plt, self.history['acc'], self.history['val_acc'])

In [None]:
def build_network_dnn():
    """Build and compile a small fully connected MNIST classifier.

    Architecture: (28, 28, 1) input -> flatten to 784 values -> Dense(32, relu)
    -> Dense(64, relu) -> Dense(10, softmax). Compiled with RMSprop and
    categorical cross-entropy (so targets must be one-hot, like `y_train`),
    reporting accuracy.

    Returns:
        The compiled `keras.models.Model`.
    """
    inp = keras.layers.Input(shape=(28, 28, 1), name='Input')
    x = keras.layers.Flatten()(inp)
    x = keras.layers.Dense(32, activation='relu')(x)
    x = keras.layers.Dense(64, activation='relu')(x)
    out = keras.layers.Dense(10, activation='softmax', name='predictions')(x)
    network = keras.models.Model(inputs=inp, outputs=out)

    # `RMSprop` is the canonical optimizer class; the lowercase `rmsprop` alias
    # only exists in older Keras releases (in tf.keras it names a module).
    network.compile(optimizer=keras.optimizers.RMSprop(), loss='categorical_crossentropy', metrics=['accuracy'])

    return network
    
def build_network_cnn():
    """Build and compile the convolutional MNIST recognizer.

    Four blocks of Conv2D(3x3, 'same', elu) + MaxPooling2D(2) with 8, 16, 32 and 64
    filters (layers 'Conv_1'/'Pool_1' ... 'Conv_4'/'Pool_4'), then Conv_5 with 128
    filters, global max pooling and a 10-way softmax named 'predictions'.
    Compiled with RMSprop and categorical cross-entropy (one-hot targets), reporting
    accuracy.

    Returns:
        The compiled `keras.models.Model`, named 'recognizer'.
    """
    inp = keras.layers.Input(shape=(28, 28, 1), name='Input')
    x = inp
    for stage, filters in enumerate((8, 16, 32, 64), start=1):
        x = keras.layers.Conv2D(filters, 3, padding='same', activation='elu', name='Conv_{}'.format(stage))(x)
        x = keras.layers.MaxPooling2D(2, name='Pool_{}'.format(stage))(x)
    # NOTE(review): with default 'valid' pooling the spatial size goes
    # 28 -> 14 -> 7 -> 3 -> 1, so Conv_5 sees a 1x1 map - confirm that is intended.
    x = keras.layers.Conv2D(128, 3, padding='same', activation='elu', name='Conv_5')(x)
    x = keras.layers.GlobalMaxPooling2D(name='Global_Pool')(x)
    out = keras.layers.Dense(10, activation='softmax', name='predictions')(x)
    network = keras.models.Model(inputs=inp, outputs=out, name='recognizer')

    # `RMSprop` is the canonical optimizer class; the lowercase `rmsprop` alias
    # only exists in older Keras releases (in tf.keras it names a module).
    network.compile(optimizer=keras.optimizers.RMSprop(), loss='categorical_crossentropy', metrics=['accuracy'])

    return network

In [None]:
# Build the convolutional recognizer (build_network_dnn() is the dense alternative)
# and print its layer-by-layer summary.
network = build_network_cnn()
network.summary()

In [None]:
%%time
# Training batches are augmented with small random shifts, rotations and zooms;
# validation batches go through a generator with no transforms.
traingen = keras.preprocessing.image.ImageDataGenerator(
    width_shift_range=0.1,
    height_shift_range=0.1,
    rotation_range=10,
    zoom_range=0.1,
    fill_mode='nearest',
)
valgen = keras.preprocessing.image.ImageDataGenerator()

train_generator = traingen.flow(x_train, y_train, batch_size=256)
val_generator = valgen.flow(x_test, y_test, batch_size=256)

# Live loss/accuracy figures, updated by the plot_history callback after each epoch.
loss_plt, acc_plt = init_plots()
display(ipywidgets.HBox([loss_plt, acc_plt]))

epochs = 12
history = network.fit_generator(
    train_generator,
    epochs=epochs,
    validation_data=val_generator,
    use_multiprocessing=True,
    workers=3,
    verbose=1,
    callbacks=[plot_history(loss_plt, acc_plt)],
)

In [None]:
# Drawing board: a white 256x256 greyscale image plus a drawing context on it;
# `buf` is a scratch byte buffer.
canvas = PIL.Image.new('L', (256, 256), color='white')
draw = PIL.ImageDraw.Draw(canvas)
buf = io.BytesIO()

# Widget that displays the canvas, framed so the drawable area is visible.
img = ipywidgets.Image(layout=ipywidgets.Layout(border='1px solid black',
                                                object_fit='contain',
                                                object_position="center center"))

# Mouse listener used for painting, and a second listener that suppresses the
# browser's default action on 'dragstart' for the image.
im_events = ipyevents.Event(source=img, watched_events=['mousemove', 'mousedown', 'mouseup'])
no_drag = ipyevents.Event(source=img, watched_events=['dragstart'], prevent_default_action = True)

# Bar chart of the ten class scores, y axis fixed to [0, 1], bold tick labels.
columns = list(range(10))
axes_pred = {
    'x': {'label': '', 'tick_style': {'font-weight': 'bold', 'font-size': "16px"}},
    'y': {'label': '', 'tick_style': {'font-weight': 'bold', 'font-size': "16px"}},
}
opts_pred = {'y': {'min': 0.0, 'max': 1.0}}
pred_plt = bqplot.pyplot.figure(min_aspect_ratio=1.0, max_aspect_ratio=1.0)
bar_plt = bqplot.pyplot.bar(columns, [0.0] * 10, options=opts_pred, axes_options=axes_pred)

# Controls, and the HTML field that shows the predicted digit.
button_clear, button_evaluate = (ipywidgets.Button(description=caption) for caption in ('Clear', 'Evaluate'))
result_field = ipywidgets.HTML('')

def paint_frame():
    """Encode `canvas` as PNG and push the bytes into the `img` widget.

    A fresh buffer is used for every frame. Rewinding a shared buffer without
    truncating it leaves the tail of a longer previous PNG behind the new one,
    so `getvalue()` would return stale bytes (e.g. after clearing a drawing).
    """
    with io.BytesIO() as frame_buf:
        canvas.save(frame_buf, 'png')
        img.value = frame_buf.getvalue()
    
# Render the still-blank canvas into `img` so the board is visible before any drawing.
paint_frame()

do_draw = False    # True while a stroke is in progress (button held down on the board)
ps = 8             # pen radius in canvas pixels
last_point = None  # previous pen position of the current stroke, or None


def print_coords(event):
    """Paint on `canvas` in response to ipyevents mouse events on `img`.

    Despite its name, nothing is printed. 'mousedown' starts a stroke;
    'mouseup' or 'mouseleave' ends it. Each 'mousemove' during a stroke paints
    a black disc of radius `ps` at (dataX, dataY). That disc is joined to the
    previous position by a line of the same width, so fast mouse movements
    leave no gaps. The image widget is then refreshed.

    Args:
        event: ipyevents event dict; 'type' is always read, 'dataX'/'dataY'
            for positions.
    """
    global do_draw, last_point
    kind = event['type']
    if kind == 'mousedown':
        do_draw = True
        # Remember where the stroke starts, if the event carries a position,
        # so the first move is joined to it.
        if 'dataX' in event and 'dataY' in event:
            last_point = (event['dataX'], event['dataY'])
        else:
            last_point = None
    elif kind in ('mouseup', 'mouseleave'):
        do_draw = False
        last_point = None
    elif kind == 'mousemove' and do_draw:
        x, y = event['dataX'], event['dataY']
        if last_point is not None:
            draw.line([last_point, (x, y)], fill='black', width=2 * ps)
        draw.ellipse([x - ps, y - ps, x + ps, y + ps], outline='black', fill='black')
        last_point = (x, y)
        paint_frame()

im_events.on_dom_event(print_coords)

# End the stroke when the pointer leaves the image. A button release outside the
# image is not delivered to the image's listeners, which would otherwise leave
# do_draw stuck on and keep painting after the pointer comes back.
leave_events = ipyevents.Event(source=img, watched_events=['mouseleave'])
leave_events.on_dom_event(print_coords)

def clear_canvas(event):
    """Wipe the drawing board and reset the prediction text and bar chart.

    Click handler for `button_clear`; `event` (the clicked button) is unused.
    """
    # Fill the whole canvas based on its actual size rather than a hard-coded
    # 255, so clearing keeps working if the board is created at another size.
    width, height = canvas.size
    draw.rectangle([0, 0, width - 1, height - 1], fill='white')
    paint_frame()
    result_field.value = ''
    bar_plt.y = [0.0]*10

button_clear.on_click(clear_canvas)

def evaluate_canvas(event):
    """Classify the digit drawn on `canvas` with `network` and show the result.

    Click handler for `button_evaluate`; `event` (the clicked button) is unused.

    Preprocessing of the drawing before it is fed to `network` as a
    (1, 28, 28, 1) float32 batch:
      - apply a Gaussian blur (radius 2);
      - downsample to 28x28;
      - contrast-stretch to [0, 1];
      - invert, so dark strokes become values near 1.

    The predicted digit is written to `result_field`. The bar chart shows
    log2(1 + p) for each class probability p. A uniform canvas (e.g. right
    after Clear) is not classified.
    """
    small = canvas.filter(PIL.ImageFilter.GaussianBlur(2)).resize((28, 28))
    frame = np.asarray(small).astype('float64')
    lo, hi = frame.min(), frame.max()
    if hi == lo:
        # Nothing drawn: contrast stretching would divide 0 by 0 and feed NaNs
        # to the network, so report it and reset the chart instead.
        result_field.value = "<h2><b>Nothing to evaluate: the canvas is blank</b></h2>"
        bar_plt.y = [0.0] * 10
        return
    frame = (frame - lo) / (hi - lo)
    frame = 1.0 - frame.reshape((28, 28, 1)).astype('float32')
    prediction = network.predict(np.expand_dims(frame, axis=0))
    # Take the single batch row so the label renders as "7", not "[7]".
    predicted_label = int(np.argmax(prediction[0]))
    result_field.value = "<h2><b>Prediction: {}</b></h2>".format(predicted_label)
    bar_plt.y = [math.log2(1 + p) for p in prediction[0]]
    
# Run evaluate_canvas every time the "Evaluate" button is clicked.
button_evaluate.on_click(evaluate_canvas)

# Board and score chart side by side (vertically centred), then the prediction
# text and the two buttons underneath.
result_box = ipywidgets.HBox([img, pred_plt])
result_box.layout.align_items = 'center'
display(result_box, result_field, button_clear, button_evaluate)


In [None]:
# Spot-check one test image: print ground truth vs. prediction and show the digit
# next to the network's class probabilities.
entry = 11

prediction = network.predict(x_test[entry:entry+1])
predicted_label = np.argmax(prediction, axis=1)

print("Ground truth: {} | Predicted: {}".format(np.argmax(y_test[entry]), predicted_label[0]))

# Named `digit`, not `img`: `img` is the drawing-board widget that paint_frame()
# assigns to, so rebinding it to an array would break the drawing board.
digit = test_images[entry].reshape((28, 28))

pimg = ipywidgets.Image(width=256, height=256)

imbuf = io.BytesIO()
dimg = PIL.Image.fromarray((255 - digit).astype('uint8')).resize((256, 256))
dimg.save(imbuf, 'gif')
pimg.value = imbuf.getvalue()

columns = range(10)
axes_pred = {'x': {'label': '', 'tick_style': {'font-weight': 'bold', 'font-size': "16px"}},
             'y': {'label': '',
                   'tick_style': {'font-weight': 'bold', 'font-size': "16px"}}}
opts_pred = {'y': {'min': 0.0, 'max': 1.0}}

pred_plt = bqplot.pyplot.figure(min_aspect_ratio=1.0, max_aspect_ratio=1.0)
# Plot the single batch row so there is one bar height per class column.
bqplot.pyplot.bar(columns, prediction[0], options=opts_pred, axes_options=axes_pred)
pimg.layout.object_fit = 'contain'
pimg.layout.object_position = "center center"
pimg.layout.border = '1px solid black'
result_box = ipywidgets.HBox([pimg, pred_plt])
result_box.layout.align_items = 'center'
display(result_box)