In [39]:
from sklearn.neural_network import MLPClassifier
from sklearn.datasets import load_digits
from sklearn.preprocessing import scale, StandardScaler
from sklearn.model_selection import train_test_split
import numpy as np

In [40]:
(digits_x, digits_y) = load_digits(return_X_y = True)
# Scale the input to the range <0, 1>
digits_x = digits_x / np.max(digits_x)
scaler = StandardScaler()
scaler.fit(digits_x)
preprocessed_x = scaler.transform(digits_x)
train_x, test_x, train_y, test_y = train_test_split(preprocessed_x, digits_y)

In [41]:
import matplotlib.pyplot as plt

In [42]:
mlp = MLPClassifier(hidden_layer_sizes=(16, 16, 16), activation="relu", solver="lbfgs")
mlp.fit(train_x, train_y)

MLPClassifier(activation='relu', alpha=0.0001, batch_size='auto', beta_1=0.9,
              beta_2=0.999, early_stopping=False, epsilon=1e-08,
              hidden_layer_sizes=(16, 16, 16), learning_rate='constant',
              learning_rate_init=0.001, max_fun=15000, max_iter=200,
              momentum=0.9, n_iter_no_change=10, nesterovs_momentum=True,
              power_t=0.5, random_state=None, shuffle=True, solver='lbfgs',
              tol=0.0001, validation_fraction=0.1, verbose=False,
              warm_start=False)

In [43]:
from sklearn.metrics import accuracy_score
print(accuracy_score(mlp.predict(test_x), test_y))

0.9577777777777777


In [44]:
mlp.coefs_[0].shape

(64, 16)

In [45]:
def to_fixed(n):
    return np.uint16(round(n * 2048))
import struct
def output_row(r):
    xs = [to_fixed(n) for n in r]
    xs = ["%04x" % x for x in xs]
    s = "".join([(xs[i+1] if i+1 < len(xs) else "0000") + xs[i] for i in range(0, len(xs), 2)])
    s = s + "0" * (16 * 4 - len(s))
    s = ", ".join(["0x" + s[i:i+8] for i in range(0, len(s), 8)])
    return "dat " + s


In [46]:
import math
for (layer, (coefs, bias)) in enumerate(zip(mlp.coefs_, mlp.intercepts_)):
    for part in range(int(math.floor(coefs.shape[0]) / 16)):
        print(f"l{layer}_{part}:")
        for row in range(coefs.shape[1]):
            print(output_row(coefs[(part * 16):(part * 16 + 16) , row]))
    print("b" + str(layer) + ":")
    print(output_row(bias))
            

l0_0:
dat 0x021100a2, 0x005f0415, 0x0273fc5d, 0x003605ea, 0x00f8ff22, 0xfe1bffe0, 0xfa10f8fd, 0x00d8fdd2
dat 0x040cfe5c, 0x068c0680, 0x024503c8, 0x022605d6, 0x00d2ff26, 0x04ee016e, 0xfd6704fd, 0x0022fd72
dat 0xfe8c01a0, 0x0456faf8, 0x01200487, 0x01d5fd5e, 0x0122ffee, 0xf990ff0d, 0x02ec042b, 0x02bb0167
dat 0xff5bfff8, 0xfecefdb1, 0xfeb800d0, 0xfe97fcda, 0xff00ff5a, 0xfc76f8a1, 0xff8904b4, 0xffddfcd3
dat 0xfe150171, 0x014002b0, 0x05cc0293, 0xff9c00b0, 0xff6301cb, 0x01d30234, 0x0083fde2, 0x0035ffbb
dat 0xff1afeaf, 0xffc1fe4d, 0xfd66fd37, 0xfe88fdd5, 0xfd5b0152, 0xfe2cfe3d, 0x02e1fd41, 0x0194002e
dat 0xfe11013a, 0xfeadfe1b, 0xfbcdff19, 0xfd43fbe8, 0xfe7cfe3c, 0x0037006d, 0xfde3fcb1, 0xfff4ff63
dat 0xff9f01bb, 0x0269011d, 0x04600302, 0x00be00c6, 0xff52008c, 0x008b0215, 0x005c0093, 0xfea00231
dat 0x00e6fed6, 0xfc57fe4a, 0x01aa04f6, 0xfdfafac9, 0xffdaffb3, 0xff40fe3e, 0xfd9dfe74, 0xffaafeff
dat 0x017d00e7, 0x01eb01d3, 0xfe6aff37, 0xff230512, 0xffd5ff9d, 0x02fb012a, 0xfc2c00fd, 0xfe63fe15
dat 

# --------------------------------------------------

In [47]:
def encode_digit(digit):
    assert(len(digit) == 64)
    words = []
    for i in range(0, len(digit), 4):
        b = ["%02x" % int(x * 255) for x in digit[i:i+4]]
        b.reverse()
        words.append("".join(b))
    return words

In [50]:
s = StandardScaler()
s.fit(digits_x)
px = scaler.transform(digits_x)
for (digit, p) in zip(digits_x, px):
    argv = ", ".join([str(x) for x in digit])
    print(argv)
    print("Expected %s" % str(mlp.predict([p])))
    print("---------------------------------")

0.0, 0.0, 0.3125, 0.8125, 0.5625, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.8125, 0.9375, 0.625, 0.9375, 0.3125, 0.0, 0.0, 0.1875, 0.9375, 0.125, 0.0, 0.6875, 0.5, 0.0, 0.0, 0.25, 0.75, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.3125, 0.5, 0.0, 0.0, 0.5625, 0.5, 0.0, 0.0, 0.25, 0.6875, 0.0, 0.0625, 0.75, 0.4375, 0.0, 0.0, 0.125, 0.875, 0.3125, 0.625, 0.75, 0.0, 0.0, 0.0, 0.0, 0.375, 0.8125, 0.625, 0.0, 0.0, 0.0
Expected [0]
---------------------------------
0.0, 0.0, 0.0, 0.75, 0.8125, 0.3125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6875, 1.0, 0.5625, 0.0, 0.0, 0.0, 0.0, 0.1875, 0.9375, 1.0, 0.375, 0.0, 0.0, 0.0, 0.4375, 0.9375, 1.0, 1.0, 0.125, 0.0, 0.0, 0.0, 0.0, 0.0625, 1.0, 1.0, 0.1875, 0.0, 0.0, 0.0, 0.0, 0.0625, 1.0, 1.0, 0.375, 0.0, 0.0, 0.0, 0.0, 0.0625, 1.0, 1.0, 0.375, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6875, 1.0, 0.625, 0.0, 0.0
Expected [1]
---------------------------------
0.0, 0.0, 0.0, 0.25, 0.9375, 0.75, 0.0, 0.0, 0.0, 0.0, 0.1875, 1.0, 0.9375, 0.875, 0.0, 0.0, 0.0, 0.0, 0.5, 0.8125, 0.5, 1.0, 0.0, 0.0, 0.0, 0.0,

Expected [8]
---------------------------------
0.0, 0.1875, 0.9375, 1.0, 0.8125, 0.0625, 0.0, 0.0, 0.0, 0.625, 0.8125, 0.5625, 1.0, 0.25, 0.0, 0.0, 0.0, 0.0625, 0.0625, 0.0, 1.0, 0.375, 0.0, 0.0, 0.0, 0.0, 0.0, 0.625, 0.9375, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.625, 1.0, 0.1875, 0.0, 0.0, 0.0, 0.0, 0.1875, 1.0, 0.4375, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3125, 1.0, 0.8125, 0.75, 0.4375, 0.125, 0.0, 0.0, 0.125, 0.8125, 0.8125, 0.8125, 1.0, 0.9375, 0.0
Expected [2]
---------------------------------
0.0, 0.1875, 0.8125, 1.0, 0.5625, 0.0, 0.0, 0.0, 0.0, 0.625, 0.9375, 0.8125, 0.9375, 0.125, 0.0, 0.0, 0.0, 0.9375, 0.25, 0.25, 1.0, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3125, 1.0, 0.125, 0.0, 0.0, 0.0, 0.0, 0.0625, 0.875, 0.8125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.625, 1.0, 0.3125, 0.0, 0.0, 0.0, 0.0, 0.25, 1.0, 0.8125, 0.5, 0.625, 0.5625, 0.0625, 0.0, 0.125, 1.0, 1.0, 0.875, 0.75, 0.5625, 0.0625
Expected [2]
---------------------------------
0.0, 0.0, 0.4375, 0.6875, 0.75, 0.875, 0.125, 0.0, 0.0, 0.5, 1.0, 0.5625, 0.

Expected [9]
---------------------------------
0.0, 0.0, 0.3125, 0.75, 0.75, 0.5, 0.0625, 0.0, 0.0, 0.0, 0.625, 1.0, 1.0, 0.9375, 0.0, 0.0, 0.0, 0.0, 0.6875, 1.0, 1.0, 0.5, 0.0, 0.0, 0.0, 0.25, 1.0, 1.0, 1.0, 0.25, 0.0, 0.0, 0.0, 0.1875, 1.0, 1.0, 0.625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8125, 1.0, 1.0, 0.1875, 0.0, 0.0, 0.0, 0.0, 0.8125, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.125, 0.625, 0.75, 0.0, 0.0, 0.0
Expected [1]
---------------------------------
0.0, 0.0, 0.0, 0.4375, 0.875, 1.0, 0.375, 0.0, 0.0, 0.0, 0.625, 1.0, 0.75, 0.9375, 0.5625, 0.0, 0.0, 0.0, 0.5, 0.1875, 0.125, 1.0, 0.4375, 0.0, 0.0, 0.0, 0.0625, 0.5, 0.8125, 1.0, 0.875, 0.0, 0.0, 0.125, 0.8125, 1.0, 1.0, 0.75, 0.0625, 0.0, 0.0, 0.375, 0.75, 0.375, 1.0, 0.1875, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3125, 0.8125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5625, 0.375, 0.0, 0.0, 0.0
Expected [7]
---------------------------------
0.0, 0.0, 0.1875, 0.6875, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5625, 0.8125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.9375, 0.25, 0.0, 0.0

Expected [7]
---------------------------------
0.0, 0.0, 0.0, 0.8125, 0.375, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3125, 1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6875, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8125, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.875, 0.9375, 1.0, 0.875, 0.3125, 0.0, 0.0, 0.0, 0.8125, 0.4375, 0.0, 0.0, 0.8125, 0.0625, 0.0, 0.0, 0.625, 0.375, 0.0, 0.3125, 0.875, 0.0, 0.0, 0.0, 0.125, 0.8125, 0.75, 0.9375, 0.25, 0.0
Expected [6]
---------------------------------
0.0, 0.0, 0.625, 0.9375, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.25, 0.6875, 0.1875, 0.3125, 0.0, 0.0, 0.0, 0.875, 0.3125, 0.4375, 0.625, 0.4375, 0.0, 0.0, 0.0, 0.25, 0.8125, 0.75, 0.6875, 0.0, 0.0, 0.0, 0.0, 0.125, 0.875, 0.75, 0.0, 0.0, 0.0, 0.0, 0.0625, 0.875, 0.4375, 0.75, 0.25, 0.0, 0.0, 0.0, 0.4375, 0.625, 0.0, 0.1875, 0.75, 0.0, 0.0, 0.0, 0.0625, 0.625, 0.6875, 0.75, 0.625, 0.0, 0.0
Expected [8]
---------------------------------
0.0, 0.0, 0.0, 0.5, 0.75, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3125, 1.0, 0.1875, 0.0, 0.125, 0.0, 0.0, 

---------------------------------
0.0, 0.0, 0.125, 0.6875, 0.625, 0.0625, 0.0, 0.0, 0.0, 0.0, 0.625, 0.8125, 0.875, 0.8125, 0.0, 0.0, 0.0, 0.0, 0.8125, 0.0, 0.0, 0.875, 0.3125, 0.0, 0.0, 0.1875, 0.5625, 0.0, 0.0, 0.5625, 0.375, 0.0, 0.0, 0.3125, 0.5625, 0.0, 0.0, 0.3125, 0.5, 0.0, 0.0, 0.375, 0.75, 0.0, 0.0, 0.5, 0.25, 0.0, 0.0, 0.0, 0.875, 0.6875, 0.3125, 0.875, 0.0625, 0.0, 0.0, 0.0, 0.1875, 0.8125, 0.875, 0.3125, 0.0, 0.0
Expected [0]
---------------------------------
0.0, 0.0, 0.0, 0.0, 0.375, 1.0, 0.4375, 0.0, 0.0, 0.0625, 0.3125, 0.6875, 1.0, 1.0, 0.5, 0.0, 0.0, 0.6875, 1.0, 1.0, 0.8125, 1.0, 0.5, 0.0, 0.0, 0.1875, 0.4375, 0.0625, 0.25, 1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3125, 1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3125, 1.0, 0.4375, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4375, 1.0, 0.5625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 0.5, 0.0
Expected [1]
---------------------------------
0.0, 0.0, 0.125, 0.8125, 1.0, 0.5625, 0.0, 0.0, 0.0, 0.0, 0.75, 0.75, 0.4375, 1.0, 0.1875, 0.0, 0.0, 0.0625, 0.87

# --------------------------------------------------

In [56]:
test_vec = digits_x[7]
print(encode_digit(test_vec))
buff = "dat 0x" + ", 0x".join(encode_digit(test_vec))
print(buff)

['7f6f0000', '0fefffcf', '6f6f0000', '00bfaf3f', '00000000', '000fcf7f', '7f7f3f00', '005fefef', 'efaf1f00', '00003fef', 'ff000000', '0000004f', 'ef8f0000', '0000000f', '4fcf0000', '00000000']
dat 0x7f6f0000, 0x0fefffcf, 0x6f6f0000, 0x00bfaf3f, 0x00000000, 0x000fcf7f, 0x7f7f3f00, 0x005fefef, 0xefaf1f00, 0x00003fef, 0xff000000, 0x0000004f, 0xef8f0000, 0x0000000f, 0x4fcf0000, 0x00000000


In [None]:
digits_y[1]

In [None]:
np.max(test_x)

In [None]:
test_y[0]

In [None]:
mlp.predict(test_vec.reshape(1, -1))

In [None]:
def apply_layer(mat, vec, bias):
    x = vec @ mat
    assert len(x) == len(bias)
    for i in range(len(x)):
        x[i] += bias[i]
        if x[i] < 0:
            x[i] = 0
        if x[i] > 16:
            x[i] = 16
    return x

def predict(x, verbose=True):
    i = 0
    for (coef, bias) in zip(mlp.coefs_, mlp.intercepts_):
        if verbose:
            print("LAYER " + str(i) + ": " + str(x))
        i += 1
        x = apply_layer(coef, x, bias)
    if verbose:
        print("LAYER " + str(i) + ": " + str(x))
    return x

In [None]:
predict(test_vec)

In [None]:
min(np.min(coefs) for coefs in mlp.coefs_)

In [None]:
max(np.max(coefs) for coefs in mlp.coefs_)

In [None]:
plt.matshow(scaler.mean_.reshape(8, 8))

In [None]:
plt.matshow(scaler.scale_.reshape(8, 8))

In [None]:
for i in range(0, len(scaler.mean_), 16):
    print(f'm_{int(i/16)}:')
    print(output_row([-x for x in scaler.mean_[i:i+16]]))

In [None]:
for i in range(0, len(scaler.scale_), 16):
    print(f's_{int(i/16)}:')
    print(output_row([1.0 / x for x in scaler.scale_[i:i+16]]))