In [2]:
import os
from math import log2, ceil, floor

import tensorflow as tf
import numpy as np

np.seterr(all='raise')

{'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}

In [3]:
HOME_DIR = os.getcwd()
MODEL_DIR = os.path.join(HOME_DIR, "models")

In [4]:
def quantize_nearest(x, scale, zero, qtype):
    if qtype not in {np.int8, np.uint8}:
        raise Exception("Only quantization to int8 or uint8 is supported")
    
    (min, max) = (-128, 127) if qtype == np.int8 else (0, 255)

    return np.clip(np.rint(x / scale) + zero, min, max).astype(qtype)

def fc_and_requantize(input_tensor, weights, bias, q_i, q_w, q_o):
    
    if input_tensor.dtype != np.int8:
        raise Exception("Input must be of type int8")
    
    if weights.dtype != np.int8:
        raise Exception("Weights must be of type int8")
    
    if bias.dtype != np.int32:
        raise Exception("Input and weights must be of type int32")
    
    (s_i, z_i), (s_w, z_w), (s_o, z_o) = q_i, q_w, q_o
    
    if z_w != 0:
        raise Exception("Expected zero point of weights to be 0")

    s = s_i * s_w / s_o

    # 1) shift input tensor
    input_tensor_32 = input_tensor.astype(np.int32) - z_i
    weights_32 = weights.astype(np.int32)

    # 2) compute the bmm
    bmm = np.matmul(input_tensor_32, weights_32.transpose()) + bias

    # 3) requantize
    rq = np.rint(s * bmm) + z_o

    # 4) saturating cast
    output = np.clip(rq, -228, 127).astype(np.int8)

    return output

In [5]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train.astype(np.float32) / 255.0, x_test.astype(np.float32) / 255.0

In [6]:
interpreter = tf.lite.Interpreter(os.path.join(MODEL_DIR, "two_layer_perceptron_frozen.tflite"), experimental_preserve_all_tensors=True)
interpreter.allocate_tensors()

In [7]:
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

print(input_details)
print(input_details["dtype"])

print(output_details)
print(output_details["dtype"])

{'name': 'serving_default_flatten_4_input:0', 'index': 0, 'shape': array([ 1, 28, 28], dtype=int32), 'shape_signature': array([-1, 28, 28], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.003921568859368563, 0), 'quantization_parameters': {'scales': array([0.00392157], dtype=float32), 'zero_points': array([0], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}
<class 'numpy.uint8'>
{'name': 'StatefulPartitionedCall:0', 'index': 10, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([-1, 10], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.20425860583782196, 159), 'quantization_parameters': {'scales': array([0.2042586], dtype=float32), 'zero_points': array([159], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}
<class 'numpy.uint8'>


In [8]:
interpreter.get_tensor_details()

[{'name': 'serving_default_flatten_4_input:0',
  'index': 0,
  'shape': array([ 1, 28, 28], dtype=int32),
  'shape_signature': array([-1, 28, 28], dtype=int32),
  'dtype': numpy.uint8,
  'quantization': (0.003921568859368563, 0),
  'quantization_parameters': {'scales': array([0.00392157], dtype=float32),
   'zero_points': array([0], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}},
 {'name': 'sequential_4/flatten_4/Const',
  'index': 1,
  'shape': array([2], dtype=int32),
  'shape_signature': array([2], dtype=int32),
  'dtype': numpy.int32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}},
 {'name': 'sequential_4/dense_7/BiasAdd/ReadVariableOp',
  'index': 2,
  'shape': array([10], dtype=int32),
  'shape_signature': array([10], dtype=int32),
  'dtype': numpy.int32,
  'quantization': (0.0006746734725311399, 0),
  'quant

In [9]:
for i, t in enumerate(interpreter.get_tensor_details()):
    print(i, ": ", t["name"], sep="")

0: serving_default_flatten_4_input:0
1: sequential_4/flatten_4/Const
2: sequential_4/dense_7/BiasAdd/ReadVariableOp
3: sequential_4/dense_7/MatMul
4: sequential_4/dense_6/BiasAdd/ReadVariableOp
5: sequential_4/dense_6/MatMul
6: tfl.quantize
7: sequential_4/flatten_4/Reshape
8: sequential_4/dense_6/MatMul;sequential_4/re_lu_2/Relu;sequential_4/dense_6/BiasAdd
9: StatefulPartitionedCall:01
10: StatefulPartitionedCall:0


In [10]:
interpreter._get_ops_details()

[{'index': 0,
  'op_name': 'QUANTIZE',
  'inputs': array([0], dtype=int32),
  'outputs': array([6], dtype=int32)},
 {'index': 1,
  'op_name': 'RESHAPE',
  'inputs': array([6, 1], dtype=int32),
  'outputs': array([7], dtype=int32)},
 {'index': 2,
  'op_name': 'FULLY_CONNECTED',
  'inputs': array([7, 5, 4], dtype=int32),
  'outputs': array([8], dtype=int32)},
 {'index': 3,
  'op_name': 'FULLY_CONNECTED',
  'inputs': array([8, 3, 2], dtype=int32),
  'outputs': array([9], dtype=int32)},
 {'index': 4,
  'op_name': 'QUANTIZE',
  'inputs': array([9], dtype=int32),
  'outputs': array([10], dtype=int32)}]

In [11]:
input_1_idx = 7
weight_1_idx = 5
bias_1_idx = 4
output_1_idx = 8

w_1 = interpreter.get_tensor(weight_1_idx)
b_1 = interpreter.get_tensor(bias_1_idx)
q_1_i = interpreter.get_tensor_details()[input_1_idx]["quantization"]
q_1_w = interpreter.get_tensor_details()[weight_1_idx]["quantization"]
q_1_o = interpreter.get_tensor_details()[output_1_idx]["quantization"]

In [12]:
input_2_idx = 8
weight_2_idx = 3
bias_2_idx = 2
output_2_idx = 9

w_2 = interpreter.get_tensor(weight_2_idx)
b_2 = interpreter.get_tensor(bias_2_idx)
q_2_i = interpreter.get_tensor_details()[input_2_idx]["quantization"]
q_2_w = interpreter.get_tensor_details()[weight_2_idx]["quantization"]
q_2_o = interpreter.get_tensor_details()[output_2_idx]["quantization"]

In [13]:
peek = 0

In [29]:
def manual_model(input_tensor):

    global peek

    # 1) shift input tensor by -128 to sitch from input type (uint8) to TF Lite internal type (int8) 
    shifted_input = input_tensor.astype(np.int32)
    shifted_input = shifted_input - 128
    shifted_input = shifted_input.astype(np.int8)

    # 2) flatten input
    flattened_input = shifted_input.reshape(interpreter.get_tensor(1)) # [-1, 784]
    
    # 3) first fully-connected layer
    fc1 = fc_and_requantize(flattened_input, w_1, b_1, q_1_i, q_1_w, q_1_o)

    # 4) relu
    print(fc1)
    relu1 = np.maximum(fc1, 0)
    print(relu1)

    peek = fc1

    # 5) second fully-connected layer
    fc2 = fc_and_requantize(relu1, w_2, b_2, q_2_i, q_2_w, q_2_o)

    # 4) undo the shift to switch from TF Lite internal type (int8) to output type (uint8)
    output = fc2.astype(np.int32)
    output = output + 128
    output = output.astype(np.uint8)

    return output

## Execution

In [33]:
chosen_image = 150

In [34]:
test_image = x_test[chosen_image]

# Need to quantize the inputs outside the model!
input_scale, input_zero_point = input_details["quantization"]
input_tensor = quantize_nearest(test_image, input_scale, input_zero_point, np.uint8)
input_tensor = np.expand_dims(input_tensor, axis=0)

# Run the TF Lite model
interpreter.set_tensor(input_details["index"], input_tensor)
interpreter.invoke()
tflite_output = interpreter.get_tensor(output_details["index"])[0]

# Run the manual model
manual_output = manual_model(input_tensor)[0]

print("Test image {}\nTF Lite output:\t\t{}\nManual model output:\t{}\nCorrect label: {}".format(chosen_image, tflite_output, manual_output, y_test[chosen_image]))

[[ -23  -67  -43 -102 -109 -120  112 -108   71  -79  111 -120  -56  -82
  -116   -8  -78  -46  -85 -109  -88 -114  123  116   80  -69  -58 -110]]
[[  0   0   0   0   0   0 112   0  71   0 111   0   0   0   0   0   0   0
    0   0   0   0 123 116  80   0   0   0]]
Test image 150
TF Lite output:		[138 106 149 160 174 152 141 146 169 207]
Manual model output:	[160 132 139 162  68 145  85 200  16  28]
Correct label: 9


In [28]:
(manual_output == tflite_output).all()

False

In [33]:
# np.set_printoptions(precision=30, suppress=True)
# x_test[chosen_image]

In [44]:
def print_2d(t):
    for st in t[0]:
        print("[{}],".format(", ".join([str(e) for e in st])))

In [45]:
print_2d(input_tensor)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 131, 254, 254, 215, 163, 163, 143, 63, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 70, 192, 198, 198, 198, 234, 253, 237, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 86, 253, 253, 144, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 80, 253, 253, 137, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 33, 228, 253, 227, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 12, 228, 255, 238, 91, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

# Displaying parameters for Rust

In [37]:
interpreter.get_input_details()[0]["quantization"]

(0.003921568859368563, 0)

In [40]:
print(q_1_i, q_1_w, q_1_o, sep="\n")
print(q_2_i, q_2_w, q_2_o, sep="\n")

(0.003921568859368563, -128)
(0.006542891729623079, 0)
(0.059290364384651184, -128)
(0.059290364384651184, -128)
(0.011379142291843891, 0)
(0.20425860583782196, 31)


In [65]:
for e in w_2.flatten():
    print("    ", e, ",", sep="")

    -17,
    23,
    23,
    -24,
    2,
    4,
    19,
    26,
    2,
    -34,
    36,
    32,
    -79,
    41,
    -41,
    25,
    40,
    -27,
    21,
    1,
    15,
    -57,
    56,
    -17,
    -40,
    -49,
    -32,
    -7,
    -10,
    -31,
    -87,
    24,
    20,
    -35,
    -17,
    -84,
    39,
    -28,
    -16,
    -43,
    41,
    -68,
    25,
    -56,
    42,
    39,
    -17,
    -10,
    74,
    29,
    30,
    37,
    22,
    -56,
    -24,
    24,
    40,
    -10,
    0,
    -109,
    30,
    2,
    33,
    45,
    -102,
    3,
    -15,
    7,
    -33,
    48,
    34,
    -39,
    28,
    -20,
    -20,
    -63,
    -22,
    11,
    8,
    52,
    -3,
    31,
    -3,
    4,
    32,
    -18,
    -60,
    50,
    -22,
    -31,
    -22,
    -31,
    -57,
    37,
    37,
    16,
    -24,
    16,
    -19,
    5,
    8,
    -34,
    -9,
    37,
    1,
    22,
    -16,
    36,
    15,
    35,
    34,
    -38,
    34,
    -31,
    50,
    24,
    -21,
    25,
    -46,
    20,


In [66]:
for e in w_2.flatten():
    print("    ", e, ",", sep="")

    -17,
    23,
    23,
    -24,
    2,
    4,
    19,
    26,
    2,
    -34,
    36,
    32,
    -79,
    41,
    -41,
    25,
    40,
    -27,
    21,
    1,
    15,
    -57,
    56,
    -17,
    -40,
    -49,
    -32,
    -7,
    -10,
    -31,
    -87,
    24,
    20,
    -35,
    -17,
    -84,
    39,
    -28,
    -16,
    -43,
    41,
    -68,
    25,
    -56,
    42,
    39,
    -17,
    -10,
    74,
    29,
    30,
    37,
    22,
    -56,
    -24,
    24,
    40,
    -10,
    0,
    -109,
    30,
    2,
    33,
    45,
    -102,
    3,
    -15,
    7,
    -33,
    48,
    34,
    -39,
    28,
    -20,
    -20,
    -63,
    -22,
    11,
    8,
    52,
    -3,
    31,
    -3,
    4,
    32,
    -18,
    -60,
    50,
    -22,
    -31,
    -22,
    -31,
    -57,
    37,
    37,
    16,
    -24,
    16,
    -19,
    5,
    8,
    -34,
    -9,
    37,
    1,
    22,
    -16,
    36,
    15,
    35,
    34,
    -38,
    34,
    -31,
    50,
    24,
    -21,
    25,
    -46,
    20,


In [68]:
for e in b_2.flatten():
    print("    ", e, ",", sep="")

    -300,
    288,
    384,
    -241,
    73,
    423,
    -301,
    -47,
    -173,
    -227,


# Back of the envelope

In [17]:
input_fc1_tflite = interpreter.get_tensor(7)
post_fc1_relu_tflite = interpreter.get_tensor(8)

In [20]:
interpreter.get_tensor(8)

array([[ -87, -100, -128, -106,  -73, -127, -101, -112,  -76,    5,  -92,
        -120, -128,  -81,  -81,  -54,  -82, -128,  -80,  -74,  -70,  -91,
        -128, -110,  -66, -128,  -52, -105]], dtype=int8)

In [19]:
peek

array([[ -87, -100,  119, -106,  -73, -127, -101, -112,  -76,    5,  -92,
        -120,  115,  -81,  -81,  -54,  -82, -128,  -80,  -74,  -70,  -91,
         117, -110,  -66,   82,  -52, -105]], dtype=int8)

In [21]:
for t in interpreter.get_tensor_details():
    print(t["name"])

serving_default_flatten_4_input:0
sequential_4/flatten_4/Const
sequential_4/dense_7/BiasAdd/ReadVariableOp
sequential_4/dense_7/MatMul
sequential_4/dense_6/BiasAdd/ReadVariableOp
sequential_4/dense_6/MatMul
tfl.quantize
sequential_4/flatten_4/Reshape
sequential_4/dense_6/MatMul;sequential_4/re_lu_2/Relu;sequential_4/dense_6/BiasAdd
StatefulPartitionedCall:01
StatefulPartitionedCall:0


In [32]:
interpreter.get_tensor(8)
# [[ -87 -100  119 -106  -73 -127 -101 -112  -76    5  -92 -120  115  -81
#    -81  -54  -82 -128  -80  -74  -70  -91  117 -110  -66   82  -52 -105]]

array([[ -87, -100, -128, -106,  -73, -127, -101, -112,  -76,    5,  -92,
        -120, -128,  -81,  -81,  -54,  -82, -128,  -80,  -74,  -70,  -91,
        -128, -110,  -66, -128,  -52, -105]], dtype=int8)

In [None]:
# mi [ -87, -100,  119, -106,  -73, -127, -101, -112,  -76,  5,  -92, -120,  115,  -81,  -81,  -54,  -82, -128,  -80,  -74,  -70,  -91,  117, -110,  -66,   82,  -52, -105]
# in [ -87, -100, -128, -106,  -73, -127, -101, -112,  -76,  5,  -92, -120, -128,  -81,  -81,  -54,  -82, -128,  -80,  -74,  -70,  -91, -128, -110,  -66, -128,  -52, -105]
#                  h                                                                                                                    h                 h                