# MNIST - Arduino

This notebook has code for interfacing with the Arduino:
1. Generate a serialized version of the network
1. Generate test vectors using only numpy operations for verification
1. Set up serial communications with an Arduino and run the validation and test sets

In [1]:
from __future__ import absolute_import, division, print_function
import os, sys, pdb, pickle
from itertools import product
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt

import serial
import binascii

import tensorflow as tf
import keras
from keras.datasets import mnist
from keras.models import Model, Sequential, load_model
from keras.layers import Input, Dense, Dropout, Flatten, Conv2D, MaxPooling2D, AveragePooling2D, Lambda, Activation, Add, concatenate
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras.engine.topology import Layer
from keras import regularizers, activations
from keras import backend as K

from quantization_layers import *
from network_parameterization import *

# Hide every CUDA device so TensorFlow/Keras run on the CPU only.
# NOTE(review): this only takes effect if TensorFlow has not yet initialised its devices
# (it is set after `import tensorflow` here); setting it before the import would be more robust.
os.environ.update(CUDA_VISIBLE_DEVICES='')

Using TensorFlow backend.


## Load data, model, and print model statistics

In [2]:
num_classes = 10

# Load MNIST, keep the raw pixel values as int32 and append a single channel axis.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
img_rows, img_cols = x_train.shape[1:3]
input_shape = (img_rows, img_cols, 1)
x_train = x_train.astype('i').reshape((x_train.shape[0],) + input_shape)
x_test = x_test.astype('i').reshape((x_test.shape[0],) + input_shape)

# One-hot encode the labels.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Hold out a fixed, seeded set of 10,000 training images for validation.
np.random.seed(0)
val_set = np.zeros(x_train.shape[0], dtype='bool')
val_set[np.random.choice(x_train.shape[0], 10000, replace=False)] = True
x_val, y_val = x_train[val_set], y_train[val_set]
x_train, y_train = x_train[~val_set], y_train[~val_set]

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_val.shape[0], 'val samples')
print(x_test.shape[0], 'test samples')

x_train shape: (50000, 28, 28, 1)
50000 train samples
10000 val samples
10000 test samples


In [3]:
# Restore the trained quantized model; its custom layers/functions must be registered
# so Keras can deserialize them.
model_file = 'models/modelQL_0.h5'
model = load_model(
    model_file,
    custom_objects={
        'DenseQ': DenseQ,
        'ConvQ': ConvQ,
        'ResidQ': ResidQ,
        'quantize': quantize,
        'concatenate': concatenate,
    },
)
print(' => '.join(layer.name for layer in model.layers))

input_11 => average_pooling2d_11 => lambda_11 => conv_q_31 => conv_q_32 => conv_q_33 => max_pooling2d_11 => dropout_11 => flatten_11 => dense_q_11 => activation_11


In [4]:
# Layer-by-layer description handed to compute_storage. The entries appear to mirror the
# model above (A = average pool, C = conv with 5/8/11 filters, M = max pool,
# D = dropout 0.1, S = 10-way output) -- NOTE(review): confirm field meanings against
# network_parameterization.compute_storage.
config = [
    ('A', 2, 4),
    ('C', 5, 3, 3, 1, 1, 4, 8, 4),
    ('C', 8, 3, 3, 1, 1, 4, 8, 4),
    ('C', 11, 3, 3, 1, 1, 4, 8, 4),
    ('M', 2, 4),
    ('D', 0.1, 4),
    ('S', 10, 4, 8, 8),
]
storage = compute_storage(config, verbose=True)
print(sum(storage), 'Bytes')

[14, 14, 1] 0 98.0
[12, 12, 5] 27.5 364.5
[10, 10, 8] 215.5 434.5
[8, 8, 11] 622.5 434.5
[4, 4, 11] 622.5 434.5
[4, 4, 11] 622.5 434.5
[10] 1512.5 434.5
1947.0 Bytes


## Generate the network serialization

In [9]:
def quant_int(x, bits, scale, signed):
    '''
    Quantize real values to integer codes on a uniform power-of-two grid.

    Values are mapped to x * qmax / scale, rounded (np.round, i.e. halves to even)
    and clipped to the representable code range, where qmax = 2**(bits-1) for
    signed codes and 2**bits for unsigned codes. Signed quantizers with bits <= 2
    are midrise (levels at +-0.5, +-1.5, ...).

    Parameters
    ----------
    x : array_like
        Real values to quantize.
    bits : int
        Code width in bits.
    scale : float
        Real value that maps to qmax (just above the largest code).
    signed : bool
        Whether the codes are signed.

    Returns
    -------
    ndarray
        Clipped integer codes. The dtype is int8 whenever the code range fits
        (signed up to 8 bits, unsigned up to 7 bits); otherwise the narrowest of
        int16/int32 that holds the range, instead of silently wrapping in int8.
    '''
    midrise = signed and (bits <= 2)
    qmax = 2**(bits-1) if signed else 2**bits
    lo = (-qmax + midrise*0.5) if signed else 0
    hi = qmax - 1 + midrise*0.5
    s = x * qmax / scale
    rounded = np.floor(s) + 0.5 if midrise else np.round(s)
    # NOTE(review): for midrise the levels are half-integers and the integer cast
    # truncates toward zero, so e.g. -0.5 and 0.5 both become 0. Only 4- and 8-bit
    # signed quantization is used in this notebook; confirm the intended integer
    # encoding before relying on 2-bit codes.
    if lo >= -2**7 and hi <= 2**7 - 1:
        dtype = 'i1'
    elif lo >= -2**15 and hi <= 2**15 - 1:
        dtype = 'i2'
    else:
        dtype = 'i4'
    return np.clip(rounded, lo, hi).astype(dtype)

def serialize(qw, input_size):
    '''
    Pack a quantized network into the hex string that is streamed to the Arduino.

    Layout (each field is one 4-bit nibble unless noted otherwise):
      * 4 hex chars: big-endian payload length in bytes (everything after this field)
      * header: number of layers, input channels, input rows (two nibbles, high
        first), input cols (two nibbles, high first)
      * per layer: a type code taken from the first letter of the layer name
        (A=0, C=1, D=2, M=3, R=4); layers with parameters add the bias count and
        the two shift amounts in l['s']
      * a zero padding nibble if needed to reach a whole byte
      * per parameterised layer: every weight as a nibble (W flattened in C order),
        then every bias as one byte, high nibble first
      * a zero padding nibble if needed to reach a whole byte
    Negative nibbles are written in 4-bit two's complement.

    Parameters
    ----------
    qw : list of dict
        Layer descriptors with 'name' and 'params'; parameterised layers also carry
        integer arrays 'W' and 'b' and the shift pair 's'.
    input_size : tuple
        (rows, cols, channels) of the network input.

    Returns
    -------
    str
        Lower-case hex string: 4-digit length prefix followed by the payload.

    Raises
    ------
    KeyError
        If a layer name does not start with a known type letter.
    ValueError
        If any field falls outside -8..15 (it cannot be encoded in one nibble) or
        the payload exceeds 0xFFFF bytes (it cannot be described by the prefix).
    '''
    layer_codes = {'A': 0, 'C': 1, 'D': 2, 'M': 3, 'R': 4}
    nibbles = [len(qw), input_size[2],
               input_size[0] // 16, input_size[0] % 16,
               input_size[1] // 16, input_size[1] % 16]
    for l in qw:
        code = layer_codes[l['name'][0].upper()]
        if l['params']:
            nibbles += [code, l['b'].size] + list(l['s'])
        else:
            nibbles.append(code)
    if len(nibbles) % 2:
        nibbles.append(0)
    for l in qw:
        if not l['params']:
            continue
        nibbles += list(l['W'].flatten())
        # e // 16 and e % 16 are the high/low nibbles of the byte e & 0xFF.
        for e in l['b']:
            nibbles += [e // 16, e % 16]
    if len(nibbles) % 2:
        nibbles.append(0)

    hex_digits = []
    for value in nibbles:
        value = int(value)
        # Anything outside -8..15 would print as more than one hex digit and
        # shift every following field of the stream.
        if not -8 <= value <= 15:
            raise ValueError('value %d does not fit in a 4-bit field' % value)
        hex_digits.append('%x' % (value & 0xF))
    nib = ''.join(hex_digits)
    msg_len = len(nib) // 2
    if msg_len > 0xFFFF:
        raise ValueError('payload of %d bytes does not fit the 16-bit length prefix' % msg_len)
    return '%04x' % msg_len + nib

# Convert each Keras layer into an integer descriptor for `serialize`/`evaluate`.
# Quantized layers store [W, b, tw, tb, ta]; tw/tb are the power-of-two exponents used to
# scale the weights/biases, and `tx` carries the exponent of the incoming activations
# (plus the weight exponent) so that the bias alignment shift s1 and the output
# shift s2 can be derived.
qw = []
tx = 0
for layer in model.layers:
    ws = layer.get_weights()
    if len(ws) < 2:
        # Parameter-free layers: only the pooling layers are kept.
        if layer.name.startswith(('average', 'max')):
            qw.append({'name': layer.name, 'params': False})
        continue

    w, b = ws[0], ws[1]
    tw, tb, ta = np.round(ws[2]), np.round(ws[3]), np.round(ws[4])

    tx = tx + tw
    w = quant_int(w, 4, 2**tw, True).astype('i1')
    if layer.name.startswith('conv'):
        # Move the last kernel axis (output channels) to the front: (out, rows, cols, in).
        w = np.transpose(w, (3, 0, 1, 2))
    b = quant_int(b, 8, 2**tb, True).astype('i1')
    s1, s2 = tx - tb, 2 + ta - tb
    tx = ta
    qw.append({'name': layer.name, 'params': True, 'W': w, 'b': b, 's': [int(s1), int(s2)]})

dump = serialize(qw, (28, 28, 1))
print('Dump is %d nibbles long (%.1f B)' % (len(dump) - 4, (len(dump) - 4) / 2))
print(dump)

Dump is 3050 nibbles long (1525.0 B)
05f5611c1c0150318141b1532a27304888b8bc8e67062038e88784217b578e0efd047480558181f06fe8114475add415fe81d527ec42a3ead2c862d28feb482fc6d4e7edd1aea57f685f7d8948f6841c6b33258fc5711cd0707446d404138fb231989e9b70981b0183cc38412578774407764ea141cf9b18a2e08e2e64de7562bf6d28b7df6eb38509483f11e91a3d001ca7db26e09d6088f7589c72715f1e7cf4c9d71f5685849580b016f2150e217812fb5d60d6f5cf46420917c4a4797cd83fd2871a087f0183112871fa8784600ce27f8d1f8ed31c302ee7bbf07ea57ec7f8073e7e479577318389b88df8381783282cef87d8e0838ff827f78cc1478e5be8d78bd8a79e86ed8742a1698872180d4c635470d03c1762e37c0da766287f8718e8c6889a89b88d0c02080e4ddfa3f73ba3a4267c0fd14e7f825042c259f1e85798cf58f188583ca788442c828608e78488f608df88a888488580875380774bf08edc8e7a908e8e72bd72e4218e74e448f39f1fd315c72948ece4f5eae8049d89fff871b722d83ac60e38d788791838867845a783f87287aec2df8082e7d18c80e41788cb8eafc2ab3f2872854ef1028cd717c078c1de2a2f708d58b648872fc331834ebca48772d1583f21d67871ec85b8074ee7dd83888b61c78dfd70df88227

## Implement the quantized neural network entirely with numpy

This lets us pin down exactly which integer computations are performed. We also run the validation set through it to find any discrepancies between this implementation and TensorFlow's.

In [9]:
def shift_round(x, s):
    '''
    Divide an integer array by 2**(s+1), rounding exact halves to the nearest even
    value (the way TensorFlow rounds numbers).

    Parameters
    ----------
    x : ndarray of int
        Values to rescale (arithmetic right shifts, so negatives round correctly).
    s : int
        One less than the number of bits shifted out.

    Returns
    -------
    ndarray
        Rounded quotient with the same integer dtype as `x`.
    '''
    nbits = s + 1
    quotient = np.right_shift(x, nbits)              # floor(x / 2**(s+1))
    remainder = np.bitwise_and(x, (1 << nbits) - 1)  # bits shifted out
    half = 1 << s
    # Round up above the midpoint; at exactly the midpoint only when the floor is odd.
    round_up = (remainder > half) | ((remainder == half) & (np.bitwise_and(quotient, 1) == 1))
    return quotient + round_up

In [10]:
def evaluate(qw, image):
    '''
    Run the quantized network `qw` on a single image using only integer numpy ops.

    The arithmetic mirrors the fixed-point pipeline used for the Arduino, so
    intermediate and final values can be compared exactly.

    Parameters
    ----------
    qw : list of dict
        Layer descriptors. Layers are dispatched on their name prefix:
          * 'average': 2x2 sum pooling followed by floor division by 64
          * 'max':     2x2 max pooling
          * 'conv':    valid convolution with W of shape (out, kh, kw, in), int16
                       accumulation, << s1, + bias, ReLU, rounded >> (s2+1),
                       clipped to 15
          * 'dense':   matrix product, << s1, + bias; the logits stay 16-bit and
                       s2 is not applied
        Layers with any other name leave the activations unchanged.
    image : ndarray
        Integer array of shape (rows, cols, channels); spatial sizes must be even
        wherever a 2x2 pooling layer is applied.

    Returns
    -------
    ndarray
        Activations after the last layer of `qw` (int16 logits after a dense layer).
    '''
    act = image
    for l in qw:
        name = l['name']
        if name.startswith('average'):
            dr, dc, df = act.shape
            # For 8-bit pixels the 2x2 sum is at most 1020, so // 64 yields 0..15.
            act = act.reshape((dr//2, 2, dc//2, 2, df)).sum(axis=(1, 3)) // 2**6
        elif name.startswith('max'):
            dr, dc, df = act.shape
            act = act.reshape((dr//2, 2, dc//2, 2, df)).max(axis=(1, 3))
        elif name.startswith('conv'):
            W, b = l['W'], l['b']
            s1, s2 = l['s']
            kh, kw = W.shape[1:3]
            oh, ow = act.shape[0] - kh + 1, act.shape[1] - kw + 1
            # Sum one kernel tap at a time over the whole output map. Accumulating in
            # int64 and wrapping to int16 once is identical to wrapping each sum into
            # an int16 slot, because integer arithmetic is modular.
            act64 = act.astype('i8')
            W64 = W.astype('i8')
            acc = np.zeros((oh, ow, W.shape[0]), dtype='i8')
            for ki in range(kh):
                for kj in range(kw):
                    acc += np.tensordot(act64[ki:ki+oh, kj:kj+ow, :], W64[:, ki, kj, :],
                                        axes=([2], [1]))
            part = acc.astype('i2')
            part = np.maximum(np.left_shift(part, s1) + b[np.newaxis, np.newaxis, :], 0)
            act = np.minimum(shift_round(part, s2), 15)
        elif name.startswith('dense'):
            W, b = l['W'], l['b']
            s1, _s2 = l['s']  # s2 unused: logits are kept at 16 bits
            part = np.dot(act.flatten(), W)
            act = np.left_shift(part, s1) + b
    return act

Here we run the validation set through the numpy implementation to check whether it exactly matches TensorFlow. It disagrees on two images. Both differences come from using 16-bit logits here instead of the 8-bit logits used in TensorFlow training, and in both cases the numpy prediction is the correct one. 16-bit logits should give better results on average, so we keep them.

In [11]:
# Compare the integer numpy model with TensorFlow image by image and report any
# images where the two predictions differ.
correct = 0
ybs = np.argmax(model.predict(x_val / 256), -1)  # TensorFlow predictions (inputs scaled by 1/256)
print('Baseline accuracy: %.4f' % (np.count_nonzero(ybs == np.argmax(y_val, -1)) / x_val.shape[0]))
for i in range(len(x_val)):
    yt = np.argmax(y_val[i])
    yp = np.argmax(evaluate(qw, x_val[i]))
    correct += int(yp == yt)
    print('\r%05d/%05d: %.4f' % (i + 1, x_val.shape[0], correct / (i + 1)), end='')
    if yp != ybs[i]:
        print('    [W] Mismatch on %d - true: %d - base: %d - quant: %d' % (i, yt, ybs[i], yp))

Baseline accuracy: 0.9909
01901/10000: 0.9905    [W] Mismatch on 1900 - true: 9 - base: 4 - quant: 9
03547/10000: 0.9910    [W] Mismatch on 3546 - true: 9 - base: 7 - quant: 9
10000/10000: 0.9911

## Generate test vectors

These vectors let us compare against the Arduino's intermediate results layer by layer to make sure everything matches exactly.

In [12]:
# Print the output of every layer for x_val[0] ('input' = the raw image) as hex test vectors.
# Intermediate values use '%x' (no zero padding); the final logits are printed as 16-bit
# two's complement. int() is taken first because np.int16 + 2**16 raises OverflowError
# under NumPy 2's scalar promotion rules.
for i in range(-1, len(qw)):
    s = (qw[i]['name'].split('_')[0] if i >= 0 else 'input') + ' '
    vals = evaluate(qw[:i+1], x_val[0]).flatten()
    if i < len(qw) - 1:
        s += ''.join('%x' % v for v in vals)
    else:
        s += ''.join('%04x' % (int(v) & 0xFFFF) for v in vals)
    print(s)

input 000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000bcbe520000000000000000000001a2f2f1e5ffed7d00000000000000002d9ab9b9dffdfd85afffbc1300000000000000006efdfdfdf6a1e4fdfdfe5c000000000000000080f5fd9e8915030e9fde9800000000000000008bfedf190024aafef46a0000000000000000037d4fda1b1ab2fdec7100000000000000000079bfde450dffdfd6d000000000000000000008dfdfdfdfefd9a1d000000000000000000006efdfdfdfeb326000000000000000000003abfefefeb30000000000000000000000abfdfdfdfdb2000000000000000000001a7bfefdcb9cfdc8000000000000000000005dfdfe79d5dfd9e000000000000000000040effd4c820dbfd7e000000000000000000085febf056ceafe6a000000000000000000084fdbe555fdec9a0000000000000000000099fda9c0fdfd4d00000000000000000000070fdfdfeec8190000000000000000000001176f3bf710000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000
average 000000000000000000000000000000000000006000000000569

In [13]:
# Same test vectors without labels: one line per layer output for x_val[0].
# The final logits are 16-bit two's complement; int() avoids NumPy 2 overflow on np.int16.
for i in range(len(qw)):
    vals = evaluate(qw[:i+1], x_val[0]).flatten()
    if i < len(qw) - 1:
        print(''.join('%x' % v for v in vals))
    else:
        print(''.join('%04x' % (int(v) & 0xFFFF) for v in vals))

0000000000000000000000000000000000000060000000005696e000000007ea6f900000000aa09d1000000004fce20000000004fe30000000002edd0000000001d75c0000000004d1c80000000004ee700000000000540000000000000000000000
6220062200622006220062200622006220062200423022230044000622006220062200622006220042301124020420005200063010802009000622006220062200622003230202806083030c0010c0060b2080e0900904062200622006220062200026050580a0c0340a0460200b08024080e03204162200622006220062200116040100a030950300004000060b0100a0703106220062200622006220050302003080505a0f0050c071000b0703206120062200622006220052200015050270806015080a1030b07023061200622006220062200522000250404a060a0040504a040f30605061210622006220062200622004230102a08080060504200006040c00604062200622006220062200622002140202609040870600008000050c04204162200622006220062200622004130100306020690a048040a2400707031062200622006220062200622006110060000500305008060060703106220062200622006220062200
000602200004024004010460052001b0144000e040a001e031c000d0307000a05010202040000c3000

## Set up communications with the Arduino - debug a single image

In [7]:
import time
import serial.tools.list_ports
# List the serial ports that are currently visible, to help pick the Arduino's port.
for port in serial.tools.list_ports.comports():
    print(port)

# USB serial port for Arduino communications. Override with the ARDUINO_PORT
# environment variable instead of editing the notebook.
device = os.environ.get('ARDUINO_PORT', '/dev/ttyUSB0')

/dev/ttyUSB0 - USB2.0-Serial


In [10]:
# https://playground.arduino.cc/interfacing/python
# http://forum.arduino.cc/index.php?topic=38981.msg287027#msg287027
# Debug run: reset the board, send the serialized network and one validation image,
# then print every reply byte with its arrival time.
# https://playground.arduino.cc/interfacing/python
# http://forum.arduino.cc/index.php?topic=38981.msg287027#msg287027
ser = serial.Serial(device, 115200, timeout=2)
try:
    # Pulse DTR to reset the MCU (not necessary if the port was closed with ser.close()),
    # then drain whatever it prints while booting: read() returns empty after `timeout`.
    ser.setDTR(False)
    time.sleep(0.022)
    ser.setDTR(True)
    while ser.read():
        pass
    ser.write(binascii.a2b_hex(dump))

    in_str = ''.join('%02x' % p for p in x_val[0].flatten())
    ser.write(binascii.a2b_hex(in_str))
    t0 = time.time()
    while True:
        sr = ser.read()
        if not sr:
            break
        print('%.3f: %02x' % (time.time() - t0, ord(sr)))
finally:
    # Always release the port, even if a read/write raises, so later cells can reopen it.
    ser.close()

0.682: 2a
0.682: 01
0.683: 5b
0.683: fd
0.683: d0
0.683: 03
0.683: 34
0.683: 03
0.683: 5d
0.683: fe
0.683: 30
0.683: fe
0.684: 97
0.684: fe
0.684: ae
0.684: fb
0.684: aa
0.684: 0d
0.684: 6a
0.684: 01
0.684: 08


## Set up communications with the Arduino - run entire validation and test sets

In [547]:
# Run the whole validation set on the Arduino. For each image the board's raw reply is
# compared byte-for-byte with the reply predicted by the numpy model (num_match), and the
# class the board reports is compared with the true label (num_correct).
ser = serial.Serial(device, 115200, timeout=1)
try:
    # The board may reset when the port is opened; give it time to boot before sending.
    time.sleep(2)
    ser.write(binascii.a2b_hex(dump))

    num_correct = 0
    num_match = 0
    num_total = 0
    for i in range(x_val.shape[0]):
        print('\rOn %05d/%05d' % (i + 1, x_val.shape[0]), end='')
        cur = x_val[i]
        mr = evaluate(qw, cur)
        # Expected reply: every int16 logit as two bytes (low byte first), then the
        # predicted class as one byte. int() avoids NumPy 2 overflow on np.int16 arithmetic.
        des = ''.join('%02x%02x' % (int(v) & 0xFF, (int(v) >> 8) & 0xFF) for v in mr)
        des += '%02x' % np.argmax(mr)

        in_str = ''.join('%02x' % p for p in cur.flatten())
        ser.write(binascii.a2b_hex(in_str))
        collect = []
        while True:
            sr = ser.read()  # empty once the read timeout expires
            if not sr:
                break
            collect.append('%02x' % ord(sr))
        result = ''.join(collect)
        # The last reply byte is the board's predicted class; an empty reply
        # (timeout) counts as a wrong prediction instead of crashing.
        predicted = int(result[-2:], 16) if result else -1
        num_correct += int(predicted == np.argmax(y_val[i]))
        num_match += (result == des)
        num_total += 1
        print('\rOn %05d/%05d - Match: %.4f - Accuracy: %.4f - %s' % (
            i + 1, x_val.shape[0], num_match / num_total, num_correct / num_total, result), end='')
    print()
finally:
    ser.close()
print('Overall Match: %.4f - Accuracy: %.4f' % (num_match / num_total, num_correct / num_total))

On 10000/10000 - Match: 1.0000 - Accuracy: 0.9911 - 86fccffc44ff840801fed00bfffe16fc6a03b60505
Overall Match: 1.0000 - Accuracy: 0.9911


In [680]:
# Run the whole test set on the Arduino. For each image the board's raw reply is
# compared byte-for-byte with the reply predicted by the numpy model (num_match), and the
# class the board reports is compared with the true label (num_correct).
ser = serial.Serial(device, 115200, timeout=1)
try:
    # The board may reset when the port is opened; give it time to boot before sending.
    time.sleep(2)
    ser.write(binascii.a2b_hex(dump))

    num_correct = 0
    num_match = 0
    num_total = 0
    for i in range(x_test.shape[0]):
        print('\rOn %05d/%05d' % (i + 1, x_test.shape[0]), end='')
        cur = x_test[i]
        mr = evaluate(qw, cur)
        # Expected reply: every int16 logit as two bytes (low byte first), then the
        # predicted class as one byte. int() avoids NumPy 2 overflow on np.int16 arithmetic.
        des = ''.join('%02x%02x' % (int(v) & 0xFF, (int(v) >> 8) & 0xFF) for v in mr)
        des += '%02x' % np.argmax(mr)

        in_str = ''.join('%02x' % p for p in cur.flatten())
        ser.write(binascii.a2b_hex(in_str))
        collect = []
        while True:
            sr = ser.read()  # empty once the read timeout expires
            if not sr:
                break
            collect.append('%02x' % ord(sr))
        result = ''.join(collect)
        # The last reply byte is the board's predicted class; an empty reply
        # (timeout) counts as a wrong prediction instead of crashing.
        predicted = int(result[-2:], 16) if result else -1
        num_correct += int(predicted == np.argmax(y_test[i]))
        num_match += (result == des)
        num_total += 1
        print('\rOn %05d/%05d - Match: %.4f - Accuracy: %.4f - %s' % (
            i + 1, x_test.shape[0], num_match / num_total, num_correct / num_total, result), end='')
    print()
finally:
    ser.close()
print('Overall Match: %.4f - Accuracy: %.4f' % (num_match / num_total, num_correct / num_total))

On 10000/10000 - Match: 1.0000 - Accuracy: 0.9915 - 720593fb6401ecff0100f0013f0c6ef93603cefd06
Overall Match: 1.0000 - Accuracy: 0.9915
